概要
ライフゲームを書きました。
素のpythonだと何をやっても激遅だったので、numbaで高速化しました。
方針
まず実装の方針を決めます。主要な関数としては以下のものがあればできると思いました。
- update_cell
1セルの状態を更新する
- update_field
フィールド全体を更新する
- main
メインループ、描画など
最初からnumbaを使ってみるつもりでしたが、numbaは割と制約が多いので、基本的にpython的なコードにするとJITコンパイルに失敗します。それを意識してコーディングしました。
(nopython=Trueオプションを付けてコンパイルできる状態でないと、まったく速くなりません。みなさんも注意してください)
実装の説明
実装の詳細について説明します。
グローバル変数
グローバル変数として以下の2つを定義しました。
field_w = 200 field_h = 200
フィールドのサイズはグローバル変数で書いておいた方が楽だろう、という判断です。なお、とりあえず200*200を指定していますが、私のマシンでは600*600くらいまでは1ステップ1秒未満で計算できます。見てて楽しいのはもっと小さいフィールドですが。
get_ijlst関数
ライフゲームを書こうとしたとき、誰もが思うのは「周囲8セルの座標を出すのが面倒くさい」ということでしょう。(i, j+1), (i, j-1), (i+1, j+1),...みたいにやっていけば良いことはわかりますが、フィールドからのはみ出しなどを考慮すると大変そうです。
そこで、その部分を簡略化するべく関数を1つ作りました。
@nb.jit(nopython=True) def get_ijlst(x, limit): ret = [] if 0 < x: ret.append(x-1) if x < limit-1: ret.append(x+1) ret.append(x) return ret
基本的には[x-1, x+1, x]のlistを返しますが、0 <= x < limitの範囲に収まらない要素は返り値のlistの中に含めないような処理をするための関数です。なお、これは次に説明するupdate_cell関数から呼ぶため、jitコンパイルしています。
update_cell関数
先にコードを示します。
@nb.jit(nopython=True) def cell_update(i, j, field, out): i_lst = get_ijlst(i, field_h) j_lst = get_ijlst(j, field_w) s = 0 for ni in i_lst: for nj in j_lst: s += field[ni, nj] s -= field[i,j] if s < 2: out[i,j] = 0 elif s == 2: out[i,j] = field[i,j] elif s == 3: out[i,j] = 1 elif s >= 4: out[i,j] = 0 else: raise Exception
座標値のi,jとnumpy配列のfield, outを受け取り、fieldに従って計算した次の状態をoutに書き込みます。
上のforループのあたりのコードは周囲8マスの総和の計算ですが、実は中心の(i,j)の値もループ対象にして総和を格納する変数sに加算し、後から中心の値をsから引いています。ループの中にifなどを入れて判定するより処理速度的に安上がりだろうという判断です。
その下にあるif文はライフゲームのルールを実装しています。周囲8マスの総和をsとおくと、
- sが2未満なら死(過疎)
- sが2なら元と同じ値
- sが3なら誕生する。元の生死にかかわらず1
- sが4以上なら死(過密)
と表せます。なお、これ以外のパターンはルール上ありえないので、万が一へんな値が来たときに備えてelse節で例外を投げています(限りなくデバッグ用に近い)。
update_field関数
こちらはシンプルです。
def update_field(pair_lst): for i in range(field_h): for j in range(field_w): cell_update(i, j, pair_lst[0], pair_lst[1]) pair_lst.append(pair_lst.pop(0))
工夫したのはpair_lstでしょうか。これは同じサイズ(shape=(field_h, field_w))の2つのnumpy配列を要素に持つlistを受け取ることを想定しています。このlistは呼び出し元(main)で定義します。
最後の行が何をしているのか、初見では理解できないと思いますが、
>>> lst = [0,1] >>> lst.append(lst.pop(0)) >>> lst [1, 0]
このように値を入れ替えられるというアイデアです。つまり、2つの配列を最初に作り、ずっと同じ2つを新旧を入れ替えながら使うということです。これによりオーバーヘッドの削減を狙っています。
main
必要な配列を定義し、更新・描画のループを回しているだけです。手抜きによりmatplotlibでアニメーション描画しています。
def main(): field = (np.random.random(size=(field_h, field_w)) > 0.9).astype(np.int16) out = np.zeros(shape=(field_h, field_w)).astype(np.int16) pair_lst = [field, out] img = plt.imshow(field) for i in range(1000): update_field(pair_lst) img.set_data(pair_lst[0]) plt.pause(0.001)
スポンサーリンク
コード全文
コードの全体を以下に示します。
import numpy as np import numba as nb import matplotlib.pyplot as plt field_w = 40 field_h = 60 @nb.jit(nopython=True) def get_ijlst(x, limit): ret = [] if 0 < x: ret.append(x-1) if x < limit-1: ret.append(x+1) ret.append(x) return ret @nb.jit(nopython=True) def update_cell(i, j, field, out): i_lst = get_ijlst(i, field_h) j_lst = get_ijlst(j, field_w) s = 0 for ni in i_lst: for nj in j_lst: s += field[ni, nj] s -= field[i,j] if s < 2: out[i,j] = 0 elif s == 2: out[i,j] = field[i,j] elif s == 3: out[i,j] = 1 elif s >= 4: out[i,j] = 0 else: raise Exception def update_field(pair_lst): for i in range(field_h): for j in range(field_w): update_cell(i, j, pair_lst[0], pair_lst[1]) pair_lst.append(pair_lst.pop(0)) def main(): field = (np.random.random(size=(field_h, field_w)) > 0.9).astype(np.int16) out = np.zeros(shape=(field_h, field_w)).astype(np.int16) pair_lst = [field, out] img = plt.imshow(field) for i in range(1000): update_field(pair_lst) img.set_data(pair_lst[0]) plt.pause(0.001) if __name__ == "__main__": main()
計測
描画処理をコメントアウトし、JITコンパイルを付けたときと外したときで200*200のフィールドを20ステップ進めるのにかかる時間を計測してみました。
- JITコンパイルなし
8.8秒
- JITコンパイルあり
1.4秒
6倍強の高速化が達成されました。・・・ってちょっと微妙ですね。威張るほどでもない。
numbaの型指定をしていないからかもしれないし、そもそもこんなものという可能性もあります。
画像
50*50のフィールドで、グライダーが生まれたタイミングを見計らって一枚スクショしてみました。
色合いが変なのはcmapをデフォルトのまま変えていないからです。
動いているのが見たい方は、コードをコピペして手元環境で実行してください。
まとめ
案外シンプルに書けたし、numbaでの高速化を試す良い機会にもなったと思います。
追記2
最近numbaの正しい使い方を知りました。
この知見を活かして型の情報を書いてみました。2つに分けていた関数は大した処理ではないのでまとめました。
@nb.jit("void(i8, i8, i2[:, :], i2[:, :])", nopython=True) def update_cell(i, j, field, out): i_lst = [i] j_lst = [j] if 0 < i: i_lst.append(i-1) if 0 < j: j_lst.append(j-1) if i < field_h: i_lst.append(i+1) if j < field_w: j_lst.append(j+1) s = 0 for ni in i_lst: for nj in j_lst: s += field[ni, nj] s -= field[i,j] if s < 2: out[i,j] = 0 elif s == 2: out[i,j] = field[i,j] elif s == 3: out[i,j] = 1 elif s >= 4: out[i,j] = 0 else: raise Exception
元のコードでは1.4秒(ライブラリのアップデートにも関わらずほとんど変わらなかった)かかっていたものが、0.9秒に高速化されました。
こうなるともう少し速くならないかと思うのが人情で、こっちもJITコンパイルすることにします。
@nb.jit("void(i2[:, :], i2[:, :])", nopython=True) def update_field(a, b): for i in range(field_h): for j in range(field_w): update_cell(i, j, a, b)
呼び出し側はこうします。
def main(): field = (np.random.random(size=(field_h, field_w)) > 0.9).astype(np.int16) out = np.zeros(shape=(field_h, field_w)).astype(np.int16) pair_lst = [field, out] t1 = time.time() # img = plt.imshow(field) for i in range(20): update_field(*pair_lst) pair_lst.append(pair_lst.pop(0)) # img.set_data(pair_lst[0]) # plt.pause(0.001) t2 = time.time() print(t2- t1)
0.4秒に高速化されました。こんな感じで速くなるので、なかなか大したものだと思いました。