静かなる名辞

pythonとプログラミングのこと

2019/03/22:TechAcademyがteratailの質問・回答を盗用していた件
2019/03/26:TechAcademy盗用事件 公式発表と深まる疑念


【python】numbaを使ってライフゲームを書いてみた

概要

 ライフゲームを書きました。

 素のpythonだと何をやっても激遅だったので、numbaで高速化しました。

方針

 まず実装の方針を決めます。主要な関数としては以下のものがあればできると思いました。

  • update_cell

 1セルの状態を更新する

  • update_field

 フィールド全体を更新する

  • main

 メインループ、描画など

 最初からnumbaを使ってみるつもりでしたが、numbaは割と制約が多いので、基本的にpython的なコードにするとJITコンパイルに失敗します。それを意識してコーディングしました。
(nopython=Trueオプションを付けてコンパイルできる状態でないと、まったく速くなりません。みなさんも注意してください)

実装の説明

 実装の詳細について説明します。

グローバル変数

 グローバル変数として以下の2つを定義しました。

field_w = 200
field_h = 200

 フィールドのサイズはグローバル変数で書いておいた方が楽だろう、という判断です。なお、とりあえず200*200を指定していますが、私のマシンでは600*600くらいまでは1ステップ1秒未満で計算できます。見てて楽しいのはもっと小さいフィールドですが。

get_ijlst関数

 ライフゲームを書こうとしたとき、誰もが思うのは「周囲8セルの座標を出すのが面倒くさい」ということでしょう。(i, j+1), (i, j-1), (i+1, j+1),...みたいにやっていけば良いことはわかりますが、フィールドからのはみ出しなどを考慮すると大変そうです。

 そこで、その部分を簡略化するべく関数を1つ作りました。

@nb.jit(nopython=True)
def get_ijlst(x, limit):
    ret = []
    if 0 < x:
        ret.append(x-1)
    if x < limit-1:
        ret.append(x+1)
    ret.append(x)
    return ret

 基本的には[x-1, x+1, x]のlistを返しますが、0 <= x < limitの範囲に収まらない要素は返り値のlistの中に含めないような処理をするための関数です。なお、これは次に説明するupdate_cell関数から呼ぶため、jitコンパイルしています。

update_cell関数

 先にコードを示します。

@nb.jit(nopython=True)
def cell_update(i, j, field, out):
    i_lst = get_ijlst(i, field_h)
    j_lst = get_ijlst(j, field_w)

    s = 0
    for ni in i_lst:
        for nj in j_lst:
            s += field[ni, nj]
    s -= field[i,j]

    if s < 2:
        out[i,j] = 0
    elif s == 2:
        out[i,j] = field[i,j]
    elif s == 3:
        out[i,j] = 1
    elif s >= 4:
        out[i,j] = 0
    else:
        raise Exception

 座標値のi,jとnumpy配列のfield, outを受け取り、fieldに従って計算した次の状態をoutに書き込みます。

 上のforループのあたりのコードは周囲8マスの総和の計算ですが、実は中心の(i,j)の値もループ対象にして総和を格納する変数sに加算し、後から中心の値をsから引いています。ループの中にifなどを入れて判定するより処理速度的に安上がりだろうという判断です。

 その下にあるif文はライフゲームのルールを実装しています。周囲8マスの総和をsとおくと、

  • sが2未満なら死(過疎)
  • sが2なら元と同じ値
  • sが3なら誕生する。元の生死にかかわらず1
  • sが4以上なら死(過密)

 と表せます。なお、これ以外のパターンはルール上ありえないので、万が一へんな値が来たときに備えてelse節で例外を投げています(限りなくデバッグ用に近い)。

update_field関数

 こちらはシンプルです。

def update_field(pair_lst):
    for i in range(field_h):
        for j in range(field_w):
            cell_update(i, j, pair_lst[0], pair_lst[1])
    pair_lst.append(pair_lst.pop(0))

 工夫したのはpair_lstでしょうか。これは同じサイズ(shape=(field_h, field_w))の2つのnumpy配列を要素に持つlistを受け取ることを想定しています。このlistは呼び出し元(main)で定義します。

 最後の行が何をしているのか、初見では理解できないと思いますが、

>>> lst = [0,1]
>>> lst.append(lst.pop(0))
>>> lst
[1, 0]

 このように値を入れ替えられるというアイデアです。つまり、2つの配列を最初に作り、ずっと同じ2つを新旧を入れ替えながら使うということです。これによりオーバーヘッドの削減を狙っています。

main

 必要な配列を定義し、更新・描画のループを回しているだけです。手抜きによりmatplotlibでアニメーション描画しています。

def main():
    field = (np.random.random(size=(field_h, field_w)) > 0.9).astype(np.int16)
    out = np.zeros(shape=(field_h, field_w)).astype(np.int16)
    pair_lst = [field, out]

    img = plt.imshow(field)
    for i in range(1000):
        update_field(pair_lst)
        img.set_data(pair_lst[0])
        plt.pause(0.001)

スポンサーリンク



コード全文

 コードの全体を以下に示します。

import numpy as np
import numba as nb
import matplotlib.pyplot as plt

field_w = 40
field_h = 60

@nb.jit(nopython=True)
def get_ijlst(x, limit):
    ret = []
    if 0 < x:
        ret.append(x-1)
    if x < limit-1:
        ret.append(x+1)
    ret.append(x)
    return ret

@nb.jit(nopython=True)
def update_cell(i, j, field, out):
    i_lst = get_ijlst(i, field_h)
    j_lst = get_ijlst(j, field_w)

    s = 0
    for ni in i_lst:
        for nj in j_lst:
            s += field[ni, nj]
    s -= field[i,j]

    if s < 2:
        out[i,j] = 0
    elif s == 2:
        out[i,j] = field[i,j]
    elif s == 3:
        out[i,j] = 1
    elif s >= 4:
        out[i,j] = 0
    else:
        raise Exception

def update_field(pair_lst):
    for i in range(field_h):
        for j in range(field_w):
            update_cell(i, j, pair_lst[0], pair_lst[1])
    pair_lst.append(pair_lst.pop(0))

def main():
    field = (np.random.random(size=(field_h, field_w)) > 0.9).astype(np.int16)
    out = np.zeros(shape=(field_h, field_w)).astype(np.int16)
    pair_lst = [field, out]

    img = plt.imshow(field)
    for i in range(1000):
        update_field(pair_lst)
        img.set_data(pair_lst[0])
        plt.pause(0.001)

if __name__ == "__main__":
    main()

計測

 描画処理をコメントアウトし、JITコンパイルを付けたときと外したときで200*200のフィールドを20ステップ進めるのにかかる時間を計測してみました。

  • JITコンパイルなし

 8.8秒

  • JITコンパイルあり

 1.4秒

 6倍強の高速化が達成されました。・・・ってちょっと微妙ですね。威張るほどでもない。

 numbaの型指定をしていないからかもしれないし、そもそもこんなものという可能性もあります。

画像

 50*50のフィールドで、グライダーが生まれたタイミングを見計らって一枚スクショしてみました。

結果
結果

 色合いが変なのはcmapをデフォルトのまま変えていないからです。

 動いているのが見たい方は、コードをコピペして手元環境で実行してください。

まとめ

 案外シンプルに書けたし、numbaでの高速化を試す良い機会にもなったと思います。

追記

 CUIでも実行できるようにしました。

【python】ターミナル上でCUIでライフゲーム - 静かなる名辞