静かなる名辞

pythonとプログラミングのこと



【python】numbaを使ってライフゲームを書いてみた

概要

 ライフゲームを書きました。

 素のpythonだと何をやっても激遅だったので、numbaで高速化しました。

方針

 まず実装の方針を決めます。主要な関数としては以下のものがあればできると思いました。

  • update_cell

 1セルの状態を更新する

  • update_field

 フィールド全体を更新する

  • main

 メインループ、描画など

 最初からnumbaを使ってみるつもりでしたが、numbaは割と制約が多いので、基本的にpython的なコードにするとJITコンパイルに失敗します。それを意識してコーディングしました。
(nopython=Trueオプションを付けてコンパイルできる状態でないと、まったく速くなりません。みなさんも注意してください)

実装の説明

 実装の詳細について説明します。

グローバル変数

 グローバル変数として以下の2つを定義しました。

field_w = 200
field_h = 200

 フィールドのサイズはグローバル変数で書いておいた方が楽だろう、という判断です。なお、とりあえず200*200を指定していますが、私のマシンでは600*600くらいまでは1ステップ1秒未満で計算できます。見てて楽しいのはもっと小さいフィールドですが。

get_ijlst関数

 ライフゲームを書こうとしたとき、誰もが思うのは「周囲8セルの座標を出すのが面倒くさい」ということでしょう。(i, j+1), (i, j-1), (i+1, j+1),...みたいにやっていけば良いことはわかりますが、フィールドからのはみ出しなどを考慮すると大変そうです。

 そこで、その部分を簡略化するべく関数を1つ作りました。

@nb.jit(nopython=True)
def get_ijlst(x, limit):
    ret = []
    if 0 < x:
        ret.append(x-1)
    if x < limit-1:
        ret.append(x+1)
    ret.append(x)
    return ret

 基本的には[x-1, x+1, x]のlistを返しますが、0 <= x < limitの範囲に収まらない要素は返り値のlistの中に含めないような処理をするための関数です。なお、これは次に説明するupdate_cell関数から呼ぶため、jitコンパイルしています。

update_cell関数

 先にコードを示します。

@nb.jit(nopython=True)
def cell_update(i, j, field, out):
    i_lst = get_ijlst(i, field_h)
    j_lst = get_ijlst(j, field_w)

    s = 0
    for ni in i_lst:
        for nj in j_lst:
            s += field[ni, nj]
    s -= field[i,j]

    if s < 2:
        out[i,j] = 0
    elif s == 2:
        out[i,j] = field[i,j]
    elif s == 3:
        out[i,j] = 1
    elif s >= 4:
        out[i,j] = 0
    else:
        raise Exception

 座標値のi,jとnumpy配列のfield, outを受け取り、fieldに従って計算した次の状態をoutに書き込みます。

 上のforループのあたりのコードは周囲8マスの総和の計算ですが、実は中心の(i,j)の値もループ対象にして総和を格納する変数sに加算し、後から中心の値をsから引いています。ループの中にifなどを入れて判定するより処理速度的に安上がりだろうという判断です。

 その下にあるif文はライフゲームのルールを実装しています。周囲8マスの総和をsとおくと、

  • sが2未満なら死(過疎)
  • sが2なら元と同じ値
  • sが3なら誕生する。元の生死にかかわらず1
  • sが4以上なら死(過密)

 と表せます。なお、これ以外のパターンはルール上ありえないので、万が一へんな値が来たときに備えてelse節で例外を投げています(限りなくデバッグ用に近い)。

update_field関数

 こちらはシンプルです。

def update_field(pair_lst):
    for i in range(field_h):
        for j in range(field_w):
            cell_update(i, j, pair_lst[0], pair_lst[1])
    pair_lst.append(pair_lst.pop(0))

 工夫したのはpair_lstでしょうか。これは同じサイズ(shape=(field_h, field_w))の2つのnumpy配列を要素に持つlistを受け取ることを想定しています。このlistは呼び出し元(main)で定義します。

 最後の行が何をしているのか、初見では理解できないと思いますが、

>>> lst = [0,1]
>>> lst.append(lst.pop(0))
>>> lst
[1, 0]

 このように値を入れ替えられるというアイデアです。つまり、2つの配列を最初に作り、ずっと同じ2つを新旧を入れ替えながら使うということです。これによりオーバーヘッドの削減を狙っています。

main

 必要な配列を定義し、更新・描画のループを回しているだけです。手抜きによりmatplotlibでアニメーション描画しています。

def main():
    field = (np.random.random(size=(field_h, field_w)) > 0.9).astype(np.int16)
    out = np.zeros(shape=(field_h, field_w)).astype(np.int16)
    pair_lst = [field, out]

    img = plt.imshow(field)
    for i in range(1000):
        update_field(pair_lst)
        img.set_data(pair_lst[0])
        plt.pause(0.001)

コード全文

 コードの全体を以下に示します。

import numpy as np
import numba as nb
import matplotlib.pyplot as plt

field_w = 40
field_h = 60

@nb.jit(nopython=True)
def get_ijlst(x, limit):
    ret = []
    if 0 < x:
        ret.append(x-1)
    if x < limit-1:
        ret.append(x+1)
    ret.append(x)
    return ret

@nb.jit(nopython=True)
def update_cell(i, j, field, out):
    i_lst = get_ijlst(i, field_h)
    j_lst = get_ijlst(j, field_w)

    s = 0
    for ni in i_lst:
        for nj in j_lst:
            s += field[ni, nj]
    s -= field[i,j]

    if s < 2:
        out[i,j] = 0
    elif s == 2:
        out[i,j] = field[i,j]
    elif s == 3:
        out[i,j] = 1
    elif s >= 4:
        out[i,j] = 0
    else:
        raise Exception

def update_field(pair_lst):
    for i in range(field_h):
        for j in range(field_w):
            update_cell(i, j, pair_lst[0], pair_lst[1])
    pair_lst.append(pair_lst.pop(0))

def main():
    field = (np.random.random(size=(field_h, field_w)) > 0.9).astype(np.int16)
    out = np.zeros(shape=(field_h, field_w)).astype(np.int16)
    pair_lst = [field, out]

    img = plt.imshow(field)
    for i in range(1000):
        update_field(pair_lst)
        img.set_data(pair_lst[0])
        plt.pause(0.001)

if __name__ == "__main__":
    main()

計測

 描画処理をコメントアウトし、JITコンパイルを付けたときと外したときで200*200のフィールドを20ステップ進めるのにかかる時間を計測してみました。

  • JITコンパイルなし

 8.8秒

  • JITコンパイルあり

 1.4秒

 6倍強の高速化が達成されました。・・・ってちょっと微妙ですね。威張るほどでもない。

 numbaの型指定をしていないからかもしれないし、そもそもこんなものという可能性もあります。

画像

 50*50のフィールドで、グライダーが生まれたタイミングを見計らって一枚スクショしてみました。

結果
結果

 色合いが変なのはcmapをデフォルトのまま変えていないからです。

 動いているのが見たい方は、コードをコピペして手元環境で実行してください。

まとめ

 案外シンプルに書けたし、numbaでの高速化を試す良い機会にもなったと思います。

【python】sklearnのOneClassSVMを使って外れ値検知してみる

はじめに

 OneClassSVMというものがあると知ったので使ってみます。

 「1クラスSVM?」と思われると思いますが、要するに異常検知・外れ値検出などで使う手法です。信頼区間を出すのに似ていますが、複雑な分布だったりそもそも分布が想定できないようなデータでも計算してくれるので、シチュエーションによっては役に立ちそうです。

 なお、わかりやすい記事があったので先に紹介しておきます。

qiita.com

実験

 異常検知・外れ値検出系で使える手法なので、センサデータの処理とか、為替や株価のアルゴリズム取引用の処理なんかをやると適当だと思いますが、私はそんなカッコいいデータは持っていません。

 なので、例によって例のごとく、irisをPCAで二次元に落としたデータを使います。

 使い方は簡単で、nuに異常値の割合を指定すれば良いようです。なんかドキュメントには意味深なことが書いてありますが、この理解で良さそうです。

 ちなみにデフォルトはnu=0.5なので、データの半数が異常値扱いになります。最初は、一体何事かと思いました。あと、predictすると正常値=1,異常値=-1という予測になります。

 ドキュメント

 sklearn.svm.OneClassSVM — scikit-learn 0.20.1 documentation

 コードは以下のとおりです。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.svm import OneClassSVM
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA

def main():
    iris = load_iris()
    pca = PCA(n_components=2)
    data = pca.fit_transform(iris.data)

    x = np.linspace(-5, 5, 500)
    y = np.linspace(-1.5, 1.5, 250)    
    X, Y = np.meshgrid(x, y)
    
    ocsvm = OneClassSVM(nu=0.1, gamma="auto")
    ocsvm.fit(data)
    df = ocsvm.decision_function(
        np.array([X.ravel(), Y.ravel()]).T).reshape(X.shape)
    preds = ocsvm.predict(data)

    plt.scatter(data[:,0], data[:,1], c=preds,
                cmap=plt.cm.RdBu, alpha=0.8)
    r = max([abs(df.min()), abs(df.max())])
    plt.contourf(X, Y, df, 10, vmin=-r, vmax=r,
                 cmap=plt.cm.RdBu, alpha=.5)
    plt.savefig("result.png")

if __name__ == "__main__":
    main()

 予測と決定関数を見るだけという手抜き。雰囲気はこれでわかると思うので、勘弁してください。

 なんかcontourfあたりでごちゃごちゃやっていますが、決定境界がcmapの中心と一致するように配慮しています。こうすることで、白色のあたり(というか青と赤の境界)が決定境界になります。

 余談ですが、このコードのためにlevelsをキーワード引数で指定しようとしたら、matplotlibのバグを踏みました。ひどい。

plt.contour levels parameter don't work as intended if receive a single int · Issue #11913 · matplotlib/matplotlib · GitHub

結果

 プロットされる図を示します。

result.png
result.png

 このように、お手軽に良さげな結果が得られます。分布の形状が複雑でもうまく推定できる訳です。良いですね。

まとめ

 SVMなので使いやすくて、うまく動くようです。手軽に良好な異常検知ができる手法としては、かなり便利だと思います。

scipy.interpolate.griddataの内挿方法による違いを比較

はじめに

 以前、3次元のサンプルデータを内挿してmatplotlibでうまくプロットする方法について記事にしました。

www.haya-programming.com

 この記事では内挿のアルゴリズムをデフォルトのlinearにして使いましたが、他の方法ではどうなるのか気になったので実験してみました。

使えるアルゴリズム

 選択肢は3つだけです。

method : {‘linear’, ‘nearest’, ‘cubic’}, optional
Method of interpolation. One of

nearest
return the value at the data point closest to the point of interpolation. See NearestNDInterpolator for more details.

linear
tessellate the input point set to n-dimensional simplices, and interpolate linearly on each simplex. See LinearNDInterpolator for more details.

cubic (1-D)
return the value determined from a cubic spline.

cubic (2-D)
return the value determined from a piecewise cubic, continuously differentiable (C1), and approximately curvature-minimizing polynomial surface. See CloughTocher2DInterpolator for more details.

scipy.interpolate.griddata — SciPy v1.1.0 Reference Guide

 なんとなくcubicには1-Dと2-Dの2つがあって「1次キュービック補間と2次キュービック補間? そんなのあったっけ」と思いがちですが、データが1次元か2次元かで使い分けられるだけで、ユーザが指定できるのは{‘linear’, ‘nearest’, ‘cubic’}のいずれかです。

 それぞれ

  • 線形補間
  • 最近傍補間
  • キュービック補間

 です。詳しい中身は知らなくても、いずれも名前くらいは聞いたことがあると思います。

実験

 二次元正規分布でサンプル数=128,512とし、それぞれの補間アルゴリズムで内挿します。結果をプロットして確認します。

 また、回帰とみなしてRMSEを出してみました。

 コードを以下に示します。

import numpy as np
import matplotlib.pyplot as plt

from scipy import stats
from scipy import interpolate
from sklearn.metrics import mean_squared_error

def rmse(true, pred):
    return mean_squared_error(true.ravel(), pred.ravel())**(1/2)

def main():
    norm = stats.multivariate_normal(mean=[2.0, 3.0], cov=[[4, 2],[2,4]])

    # samples
    xy128 = np.random.uniform(low=-10, high=10, size=(128, 2))
    z128 = norm.pdf(xy128)
    xy512 = np.random.uniform(low=-10, high=10, size=(512, 2))
    z512 = norm.pdf(xy512)

    # xy meshgrid
    x = y = np.linspace(-10, 10, 500)
    X, Y = np.meshgrid(x, y)
    Z = norm.pdf(np.vstack([X.ravel(), Y.ravel()]).T).reshape(X.shape)

    # plot
    fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(10,5))
    plt.subplots_adjust(hspace=0.6, wspace=0.4)

    axes[0,0].pcolormesh(X, Y, Z, cmap="jet")
    axes[0,0].set_title("true data")
    axes[1,0].pcolormesh(X, Y, Z, cmap="jet")
    axes[1,0].set_title("true data")

    for i, (n_samples, xy, z) in enumerate(
            zip([128, 512], [xy128, xy512], [z128, z512])):
        axes[i,1].scatter(xy[:,0], xy[:,1], c=z, cmap="jet")
        axes[i,1].set_title("samples {}".format(n_samples))

        for j, i_method in enumerate(["nearest", "linear", "cubic"]):
            i_Z = interpolate.griddata(xy, z, (X, Y), method=i_method, 
                                       fill_value=0.0)
            axes[i,j+2].pcolormesh(X, Y, i_Z, cmap="jet")
            axes[i,j+2].set_title("{} {}\nrmse={:.5f}".format(
                i_method, str(n_samples), rmse(Z, i_Z)))
        
    plt.savefig("result.png")

if __name__ == "__main__":
    main()

 なお、RMSEを計算する都合上、fill_value=0.0としています。デフォルトはnanですが、それだと計算できないので……。一応実際にnanの状態でも確認し、nanになるのはグラフの端(このデータではほぼ0.0)の領域だけであることを確認して以上の判断をしました。

結果と考察

 プロットされた結果を示します。

result.png
result.png

 見ての通り、ダメダメな最近傍補間、まあまあな線形補間、群を抜いて良いキュービック補間という関係です。cubicにしておけば良いのでは?

 ただ、今回は真値そのものが補間が効きやすいなめらかなデータですが、実データはもう少しノイズが乗ったりして暴れることがあると思います。キュービック補間はオーバーシュート・アンダーシュートがあるらしいので、そういう場合でも対応できるように保険としてデフォルトがlinearになっているのかもしれません。まあ、無難なのはそっちでしょう。

 実用的には、両方やってみて大丈夫そうな方を選ぶことになるでしょう。

まとめ

 cubicがよかったです。

【python】rangeではin演算子が使える。速度は微妙かも

はじめに

 今日コードを書いていて、rangeでもinが使えることに気づきました。

>>> 10 in range(20)
True

 ドキュメントを見るとシーケンス型としての機能は一通り備えているようです。

range オブジェクトは collections.abc.Sequence ABC を実装し、包含判定、要素インデックス検索、スライシングのような機能を提供し、負のインデックスをサポートします (シーケンス型 — list, tuple, range を参照):

4. 組み込み型 — Python 3.6.5 ドキュメント

 ちなみにfloatでもなんとなくTrueになりましたが、あくまでも離散値の包含で比較される雰囲気です。

>>> 10.0 in range(20)
True
>>> 10.1 in range(20)
False

測ってみる

 こういうものがあると、速度が気になります。

>>> r = range(1000)
>>> l = list(range(1000))
>>> s = set(l)
>>> import timeit
>>> timeit.timeit(lambda : 500 in r)
0.2108146829996258
>>> timeit.timeit(lambda : 500 in l)
8.18245309899794
>>> timeit.timeit(lambda : 500 in s)
0.13419233300010092
>>> timeit.timeit(lambda : 0 <= 500 < 1000)
0.12016064700219431

 同じ長さのlistよりは圧倒的に速いものの、setに負けるという結果に。恐らくシーケンスに展開して線形探索をする訳ではないものの、内部処理がそこそこ複雑なのでしょう。また、単に値の区間だけ確認したいのなら、不等式による比較が最速のようです(setが異様に速いと言うべきか・・・)。

使いどころ

 rangeということはstepも入れられるので、複雑な条件のときにはmodで書くより可読性が良いかもしれません。

>>> [x in r for x in range(10)]
[True, False, False, True, False, False, True, False, False, True]

 他に積極的に使う理由は思いつきません。

xyzの点データを内挿してmeshgridにしmatplotlibでプロットする

はじめに

 pythonでmatplotlibを使って作図するとき、三次元のデータでpcolormeshとかcontourでやるような等高線プロットを作りたいんだけど、手持ちのデータはxyzが紐付いた点のバラバラな離散データだけ……ということがままあります。

 散布図ならそれでも良いのですが、等高線などをプロットしようとするとmeshgrid的な形式のデータが要ると思います。困ります。

 なので、内挿してなんとかしてみます。

方針

 matplotlibのドキュメントを読んでいると、それらしいものを見つけました。

mlab — Matplotlib 3.0.2 documentation

matplotlib.mlab.griddata(x, y, z, xi, yi, interp='nn')[source]
Deprecated since version 2.2: The griddata function was deprecated in Matplotlib 2.2 and will be removed in 3.1. Use scipy.interpolate.griddata instead.

 scipyのscipy.interpolate.griddataを使え、ということらしいです。

scipy.interpolate.griddata — SciPy v0.18.1 Reference Guide

 ちょっと使い方がわかりづらいですが、仮に

  • xy:shape=(n_samples, 2)のxy座標のデータ
  • z:shape=(n_samples,)のz座標のデータ
  • X, Y:xy空間のmeshgrid

 という変数を置くとすると

Z = griddata(xy, z, (X, Y))

 でZ(zのmeshgrid)が得られる、ということのようです。詳細はドキュメントのサンプルを見て確認してみてください。

実験

 二次元の正規分布を作り、真値と内挿で得られた値をpcolormesh, contourでプロットしました。

 コードを以下に示します。

import numpy as np
from scipy import stats
from scipy import interpolate
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def main():
    # data (samples)
    xy = np.random.uniform(low=-10, high=10, size=(500, 2))
    norm = stats.multivariate_normal(mean=[2.0, 3.0], cov=[[3, 1],[1,3]])
    z = norm.pdf(xy)
    
    # 3d scatter of samples
    fig = plt.figure()
    ax = Axes3D(fig)
    ax.scatter(xy[:,0], xy[:,1], z)
    plt.savefig("scatter_3d.png")

    # mesh x-y
    x = y = np.linspace(-10, 10, 500)
    X, Y = np.meshgrid(x, y)

    # interpolated Z
    i_Z = interpolate.griddata(xy, z, (X, Y))

    # true Z
    t_Z = norm.pdf(np.vstack([X.ravel(), Y.ravel()]).T).reshape(X.shape)

    # plot
    fig, axes = plt.subplots(ncols=2)
    plt.subplots_adjust(wspace=0.4)

    # plot true Z
    im = axes[0].pcolormesh(X, Y, t_Z, cmap="jet")
    plt.colorbar(im, ax=axes[0])
    axes[0].contour(X, Y, t_Z, colors=["black"])
    axes[0].set_title("true")

    # plot interpolation Z
    im = axes[1].pcolormesh(X, Y, i_Z, cmap="jet")
    plt.colorbar(im, ax=axes[1])
    axes[1].contour(X, Y, i_Z, colors=["black"])
    axes[1].set_title("interpolation")

    plt.savefig("result.png")

if __name__ == "__main__":
    main()

 stats.multivariate_normalについてはこちらの記事もご覧くだい。

www.haya-programming.com

 それほど難しいことはやっていないので、詳細はコードを読んでいただければわかると思います。出力される画像を以下に示します。

 まず、参考のために出力したサンプルの三次元散布図です。

scatter_3d.png
scatter_3d.png

 けっこうスカスカな印象を受けます。

 次に、真値と内挿で得られた値をプロットした結果です。

result.png
result.png

 完璧とは言えませんが、そこそこ良さそうな結果が得られています。なお、内挿のグラフの外側の白い部分は内挿できなくてnanになっている領域です。griddataのfill_valueオプションなどで挙動を変えられるので、検討してみてください。

まとめ

 このように内挿してプロットすることができます。いまいちなデータしか手持ちにないとき威力を発揮します。

 ただし、元のデータの性質を歪めている側面があるので、ある程度注意が必要になります。それさえ気をつければ十分使えると思います。

追記

 内挿方法による比較を行いました。よろしければこちらも御覧ください。

www.haya-programming.com

【python】sklearnのFeatureAgglomerationを使ってみる

はじめに

 FeatureAgglomerationは階層的クラスタリングを用いた教師なし次元削減のモデルです。特徴量に対して階層的クラスタリングを行い(つまり通常のサンプルに対するクラスタリングと縦横の向きが入れ替わる)、似ている特徴量同士をマージします。マージの方法はデフォルトでは平均のようです。

 使用例をあまり見かけませんが、直感的な次元削減方法なので何かしらの役に立つかもしれないと思って使ってみました。

sklearn.cluster.FeatureAgglomeration — scikit-learn 0.20.1 documentation

使い方

 パラメータは以下の通り。

class sklearn.cluster.FeatureAgglomeration(
    n_clusters=2, affinity=’euclidean’, memory=None, connectivity=None, 
    compute_full_tree=’auto’, linkage=’ward’, pooling_func=<function mean>)

 色々いじれるように見えますが、主要パラメータは2つだけです。

  • n_clusters

 PCAでいうところのn_componentsです。変換先の次元数を表します。

  • pooling_func

 似ている特徴量をマージする方法。callableが渡せます。何もしなければ平均が使われるので、平均より気の利いた方法を思いつく人以外はそのままで大丈夫です。

 あとは階層的クラスタリングのオプションが色々あります。それはそれで大切なものだと思いますが、今回は無視することにします。

実験

 もう何番煎じかわかりませんが、irisの2次元写像で試します。

import matplotlib.pyplot as plt

from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import FeatureAgglomeration

def main():
    iris = load_iris()

    pca = PCA(n_components=4)
    ss = StandardScaler()
    agg = FeatureAgglomeration(n_clusters=2)

    pca_X = pca.fit_transform(iris.data)
    agg_X = agg.fit_transform(
        ss.fit_transform(iris.data))

    print(pca.components_)
    print(agg.labels_)

    fig, axes = plt.subplots(nrows=1, ncols=2)
    axes[0].scatter(pca_X[:,0], pca_X[:,1], c=iris.target)
    axes[0].set_title("PCA")
    axes[1].scatter(agg_X[:,0], agg_X[:,1], c=iris.target)
    axes[1].set_title("FeatureAgglomeration\n{}".format(agg.labels_))
    plt.savefig("result.png")

if __name__ == "__main__":
    main()

 動作原理、目的と用途を考えると、事前にスケーリングしておいた方が恐らく無難です。

 printされた出力。

[[ 0.36138659 -0.08452251  0.85667061  0.3582892 ]
 [ 0.65658877  0.73016143 -0.17337266 -0.07548102]]
[0 1 0 0]

 FeatureAgglomerationは圧倒的に結果の解釈性が良いことがわかります。写像先の0次元目は元の0,2,3次元目の平均で*1、写像先の1次元目は元の1次元目ですね。こういうのはシチュエーション次第ですが、ちょっと嬉しいかもしれません。

 出力される画像。

プロットの結果
プロットの結果

 概ねPCAと同等に使えています。うまく言葉で表現はできませんが、FeatureAgglomerationの方はなんとなくギザギザ感?みたいなものがあります。平均するとそうなる、というのがなんとなくわかる気もするし、わからない気もする。

考察

 結果の解釈性が良いのと、まがりなりにすべての特徴量の情報が結果に反映されるので、PCAより使いやすいシチュエーションはあると思います。分類前の次元削減とかで使ったときの性能とかは今回検討していませんが、たぶんそんなに良いということはないはず。

 あとドキュメントをあさっていたら、こんなページがあったので、

Feature agglomeration — scikit-learn 0.20.1 documentation

 真似してPCAでも同じものを出してみたら(コードはほとんど書き換えていないので省略。agglo = の行で代入するモデルをコメントアウトで切り替えて、あとlabels_の出力を外しただけです)、やっぱりFeatureAgglomerationはヘボかった(低次元で元の情報を保持することに関しては性能が低かった)です。

 10次元に落として元の情報をどこまで復元できるかという実験。

PCA
PCA

FeatureAgglomeration
FeatureAgglomeration

 まあ、これは仕方ないか。

まとめ

 とにかく結果の解釈性の良さを活かしたい、とか、なにか特別な理由があって使う分には良いと思います。

*1:厳密にはどれか2つが先に平均されて、更に残りと平均されるはず。つまり3つの比重が違う順番はチェックしていないのでわかりませんが、children_属性をちゃんと読み取ればわかると思います

【python】複数の条件を総なめするときの簡略化

 たとえば、こういうものを書きたいとする。

def f(a, b):
    if a == "0" and b == "0":
        print("a:0, b:0")
    elif a == "0" and b == "1":
        print("a:0, b:1")
    elif a == "1" and b == "0":
        print("a:1, b:0")
    elif a == "1" and b == "1":
        print("a:1, b:1")

 条件式がいかにも冗長。
 (ifをネストすれば良い、printも冗長、フォーマット文字列使えば良い、あるいは辞書に入れて結果の表示を分岐させれば良いというツッコミは受け付けないことにする。実際には処理内容の分岐が入ることを想定する。)

 こういうときは、シーケンスにして値比較するとスマート。文字列だと最も簡単なのはこれ。

def f(a, b):
    c = a + b
    if c == "00":
        print("a:0, b:0")
    elif c == "01":
        print("a:0, b:1")
    elif c == "10":
        print("a:1, b:0")
    elif c == "11":
        print("a:1, b:1")
||< 

 もう少し汎用的な型を想定するなら、listかtupleに入れると良い。

>|python|
def f(a, b):
    c = [a, b]
    if c == ["0", "0"]:
        print("a:0, b:0")
    elif c == ["0", "1"]:
        print("a:0, b:1")
    elif c == ["1", "0"]:
        print("a:1, b:0")
    elif c == ["1", "1"]:
        print("a:1, b:1")

 コードがシンプルになり、可読性も向上した。