静かなる名辞

pythonとプログラミングのこと



【python】numpyで多次元配列のargsortと値の取り出し

はじめに

 numpy配列のargsort()メソッドは値をソートした結果のインデックスの配列を返します。

>>> import numpy as np
>>> a = np.array([2,0,1,8,1,1,0,7])  # 適当な配列を定義
>>> idx = a.argsort()  # argsort
>>> idx  # こんな配列になる
array([1, 6, 2, 4, 5, 0, 7, 3])
>>> a[idx]  # 配列をインデックスとして使う
array([0, 0, 1, 1, 1, 2, 7, 8])

numpy.ndarray.argsort — NumPy v1.15 Manual

 同様のことを多次元配列でも行うことができます。その方法について説明します。

多次元配列でargsort()

 argsort()メソッド自体は多次元配列でも問題なく呼べます。が、ソートなので、axisを指定する必要があります。

>>> a = np.array([[2,0,1,8],[1,1,0,7]])
>>> a
array([[2, 0, 1, 8],
       [1, 1, 0, 7]])
>>> a.argsort()
array([[1, 2, 0, 3],
       [2, 0, 1, 3]])
>>> a.argsort(axis=0)  # axis=0では縦にソート
array([[1, 0, 1, 1],
       [0, 1, 0, 0]])
>>> a.argsort(axis=1)  # axis=1で横にソート。デフォルトと同じ動作
array([[1, 2, 0, 3],
       [2, 0, 1, 3]])

 ドキュメントによると、デフォルトのaxisは-1のようです。ちょっとめずらしい気がしますね。

値の取り出し

 残念ながら2次元配列なので、値を取り出そうとするとなかなか思い通りにはいきません。

>>> a[a.argsort()]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
IndexError: index 2 is out of bounds for axis 0 with size 2

 こんなコードを書いた人も多いのでは。

sorted_array = a.sort()
sorted_index = a.argsort()

 どう考えても二回ソートするのは二度手間です。

 そこで、numpy.take_along_axis()関数を使います。ちなみに、これはnumpy 1.15から入った新しい関数です。

numpy.take_along_axis — NumPy v1.15 Manual

 以下のように使えます。

>>> np.take_along_axis(a, a.argsort(axis=1), axis=1)
array([[0, 1, 2, 8],
       [0, 1, 1, 7]])

 便利ですね。