2021-4-30

numpyのchoice関数を読む

javascriptは標準で重み付きのchoice関数が無いのでライブラリを探すか自分で作る必要がある。折角なのでpythonの勉強がてらnumpyのchoice関数を読んだ。 該当箇所のソースコードを引用しつつ説明を加えていく。

概観

ざっくり2つの引数(replace, p)によって挙動が変わっている。replaceは復元抽出か非復元抽出かを表していてTrueなら復元抽出になる。 pは各要素の選択確率を表していて指定されなかった場合には等確率で選択される。 「replaceがTrueかFalseか」×「pが指定されたかされてないか」の4パターンで異なるロジックが用意されていた。

復元抽出の場合

python
Copied!
if replace:
    if p is not None:
        cdf = p.cumsum()
        cdf /= cdf[-1]
        uniform_samples = self.random(shape)
        idx = cdf.searchsorted(uniform_samples, side='right')
        # searchsorted returns a scalar
        idx = np.array(idx, copy=False, dtype=np.int64)
    else:
        idx = self.integers(0, pop_size, size=shape, dtype=np.int64)

まずは復元抽出の場合pは各要素の抽出確率を表す1次元配列。pが指定された場合=各要素の抽出確率が指定された場合には逆関数法的な発想で抽出を行っている。 cdfが累積確率、uniform_samplesが一様分布から生成された乱数を表している。2分探索(searchsorted)を使って生成したuniform_samplesに対応する要素のインデックスをcdfから探しているだけのシンプルなロジック。 連続の場合なら逆関数法そのものだけど離散的な場合にも逆関数法と呼ぶのだろうか?少し調べるとmultinominal resampling(多項リサンプリング)という単語がヒットしたが一般的かどうかわからない。

pが指定されなかった場合には何も考えずにランダムな整数の配列を生成して、それをインデックスとして利用することで全ての要素を等しい確率で抽出する。

非復元抽出の場合

python
Copied!
if p is not None:
    if np.count_nonzero(p > 0) < size:
        raise ValueError("Fewer non-zero entries in p than size")
    n_uniq = 0
    p = p.copy()
    found = np.zeros(shape, dtype=np.int64)
    flat_found = found.ravel()
    while n_uniq < size:
        x = self.random((size - n_uniq,))
        if n_uniq > 0:
            p[flat_found[0:n_uniq]] = 0
        cdf = np.cumsum(p)
        cdf /= cdf[-1]
        new = cdf.searchsorted(x, side='right')
        _, unique_indices = np.unique(new, return_index=True)
        unique_indices.sort()
        new = new.take(unique_indices)
        flat_found[n_uniq:n_uniq + new.size] = new
        n_uniq += new.size
    idx = found

次に非復元抽出の場合。多次元のインプットに対応しているせいでかなりわかりにくい。詳しくは分からないがベースのアイディアは復元抽出と同じみたい。確率の累積和を計算⇒一様分布から生成された乱数を2分探索とベースは同じ。違うのは重複した結果が含まれていた場合。newのユニークな要素を抽出して結果に足し込む+次回以降サンプルされないように当該要素がサンプルされる確率を0にする、という処理を結果が指定された要素数になるまで繰り返している様子。

非復元抽出×pが指定されなかった場合の処理は読み解けなかった。コメント欄にFloyd's Algorithmというメモがあったが検索かけてもワーシャルフロイド法しか見つからない。関係があるのだろうか。

その他

issueを調べると「もっと効率いい手法あるのでは?」というコメントが入っている。それに対して「何回も言ってるけど効率の良い手法は殆ど入力が一次元の場合にしか対応してないだろ!numpyは多次元の入力にも対応する必要があるんじゃ!」とレスが付いてて、なるほどとなった。効率が良い手法を知りたいのであれば別のソースを読むべきだったかもしれない。