- はじめに -
半教師あり学習のアルゴリズムの1種であるlabel propagationをRustで実装し、クレートとして公開した。
本記事は、label propationの実装と検証を行った際のメモである。
- label propagationとは -
label propagationは、transductive learningの枠組みの1つでもあり、グラフ構造を利用した機械学習アルゴリズムである。
ラベルがあるデータ、ラベルのないデータ、それらを繋ぐエッジがある状態で、ラベルのないデータに付くラベルを推定する事が解きたいタスクとなる。
最もシンプルな実タスクとして例示すると「文書データ等で一部のデータにはラベルがあるが一部欠損している所を推定したい」「ユーザとアイテム、それらを繋ぐPV等のエッジがあり、アイテムにのみラベルがある状態でユーザにもラベル付けを行いたい」といった状況が想定できる。
近年ではCVPR 2019でEmbeddingによる距離をノードとして画像ラベルを推定して利用する手法*1が採択されるなどしており、汎用的なアルゴリズムの1つである。近いワードとしては、tag recommendationなどがあり、PageRankアルゴリズムを利用した手法*2やCollaborative filteringを拡張する手法*3が提案されている他、Content baseな方法もまた考えられる。
実際エムスリーではtag propagationを利用したtag伝搬を用いてユーザのタグ付けを行い、様々な配信のセグメント分けや分析に利用している*4。ハイパーパラメータが少なく、グラフ生成部及び内部の行列計算手前までをオンライン化する事ができ、汎用性が高く安定した結果を得られる所が良いところである。
label propagationの問題設計は、 をラベル付きデータにのC個のラベルが付与されていた時、そこから観測できないのデータに紐付いたを推定する事にある。データ間の重みは、古典的にユークリッド距離とハイパーパラメータを用いて簡素に以下のように表現される。
これは最も簡素な例で、距離に関しても時に離散的な距離であったりDNNのEmbeddingから得られる距離であったりする。を作るためには、の確率遷移行列を作ってやればよい。
行列の最適化のためのアプローチは、いくつか方法があるが、概ね以下が詳しい。
- scikit-learn準拠で Label propagation とか実装した - でかいチーズをベーグルする
- Community detection with the Label Propagation Algorithm: A survey - ScienceDirect
ベースは、グラフ上で隣接するノードは同じラベルを持つ可能性が高い、という所に基づいて設計した目的関数を最小化することでweight行列を最適化する。「隣接ノードが同じラベル」の閾値をパラメータや推論によってコントロールする拡張が主である。
label propagationは、Pythonではsklearn内にも実装されており、簡易に呼び出す事ができる。
sklearn.semi_supervised.LabelPropagation — scikit-learn 0.24.2 documentation
よりグラフィカルな解説は以下が参考になる。オススメ。
- Label Propagation Demystified. A simple introduction to graph-based… | by Vijini Mallawaarachchi | Towards Data Science
- How to get away with few Labels: Label Propagation | by Dr. Robert Kübler | Towards Data Science
- Rustによる実装 -
先に示した通り、確率遷移行列を作って最小化できれば良いので、行列演算を行う事になる。
今回はndarrayを利用して実装している。rust/ndarrayのドキュメント内にnumpyからの移行のススメがあるので、基本的にはここを参照すると良い。
numpyにはadvanced-indexingという機能がある。
Indexing — NumPy v1.21 Manual
こういうやつ
x = np.array([0, 1]) y = np.array([[0, 0], [0, 0], [0, 0]]) y[x] = 1 # array([[1, 1], # [1, 1], # [0, 0]])
rustのndarrayでは、現状実装されていないのでslice_mutで指定インデックスごとにスライスを作ってfillterによる代入を行う必要がある。
for i in x { y.slice_mut(s![*i, ..]).fill(1); }
機械学習で行列を扱う時は大体スパースな事が多く、実装としてsparse matrixを使う事が多い。現状ndarrayにはsparse matrixに類似するものは実装されていなさそう。同じく行列演算を趣旨としたnalgebraにはnalgebra_sparse::csr::CsrMatrixがあるが、こちらはdot積などが実装されていない。
なのでArrayBaseで押し切る実装になってしまった。メモリに優しくない。linfaなど、一部ライブラリで独自にsparse matrixを実装しているものもあるが、クレート依存が激しい。
以下のクレートを試してみてはという助言を貰ったので検証中ではある。
github.com
この辺何か良い方法があるんだろうか。知っている人居れば教えて欲しい。
上記以外はdot積と行列変換が扱えれば良いのでndarrayで十分実装できる。
- 検証 -
irisデータセットを利用して、一部のラベルを欠損、各データのユークリッド距離をエッジと考えて、label propagationにより欠損ラベルを推論する。
公開したlabel-propagation-rsには、label propagationの派生アルゴリズムとして、LGCとCAMLPを実装しており、検証にはCAMLPを利用した。
Rustにおけるsklearnのような立ち位置になるライブラリであるsmartcoreよりirisデータセットを読み込んで行列を作る。閾値としてユークリッド距離の逆数が0.5以下になっている場合はエッジを繋がないものとする。
... let iris = iris::load_dataset(); let node = (0..iris.num_samples).collect::<Array<usize, _>>(); let mut label = Array::from_shape_vec(iris.num_samples, iris.target.iter().map(|x| *x as usize).collect())?; let mut graph = Array::<f32, _>::zeros((iris.num_samples, iris.num_samples)); let data = Array::from_shape_vec((iris.num_samples, iris.num_features), iris.data)?; for i in 0..iris.num_samples { for j in 0..iris.num_samples { if i != j { let weight = 1. / (*&data.slice(s![i, ..]).sq_l2_dist(&data.slice(s![j, ..]))? + 1.); // reciprocal if weight > 0.5 { graph[[i, j]] = weight; } } } } ...
ざっくり10個ターゲットを選んで、ノードに付与されたラベルを0にする。irisのラベルは0,1,2のどれかなので「ランダムにあるラベルが0になってしまった」という状況になる。
... let target_num = 10; let mut rng = thread_rng(); let target = (0..iris.num_samples).choose_multiple(&mut rng, target_num).iter().map(|x| *x).collect::<Array<usize, _>>(); for i in &target { label[*i] = 0; } ...
モデルを学習させて、上記で0にしたターゲットのlabelを推定する。
... let mut model = CAMLP::new(graph).iter(100).beta(0.1); model.fit(&node, &label)?; let result = model.predict_proba(&target); for (i, x) in target.iter().enumerate() { println!("node: {:?}, label: {:?}, result: {:?}", *x, iris.target[*x], result.slice(s![i, ..]).argmax()?); } ...
結果は以下のようになった。
node: 0, label: 0.0, result: 0 node: 14, label: 0.0, result: 0 node: 67, label: 1.0, result: 0 node: 118, label: 2.0, result: 2 node: 43, label: 0.0, result: 0 node: 144, label: 2.0, result: 2 node: 91, label: 1.0, result: 1 node: 137, label: 2.0, result: 2 node: 49, label: 0.0, result: 0 node: 62, label: 1.0, result: 1
node 67のみ、真のラベルが1に対して推論ラベルが0となってしまっているが、それ以外は正解している。良い感じ。実際どういったデータで各metricでどの程度の精度が出るかはこれから検証していく。
上記の検証コードはexample内にある。
label-propagation-rs/examples at main · vaaaaanquish/label-propagation-rs · GitHub
*1: A. Iscen, G. Tolias, Y. Avrithis, O. Chum. "Label Propagation for Deep Semi-supervised Learning", CVPR 2019 https://openaccess.thecvf.com/content_CVPR_2019/papers/Iscen_Label_Propagation_for_Deep_Semi-Supervised_Learning_CVPR_2019_paper.pdf, github: https://github.com/ahmetius/LP-DeepSSL
*2:Heung-Nam Kim and Abdulmotaleb El Saddik. 2011. Personalized PageRank vectors for tag recommendations: inside FolkRank. In Proceedings of the fifth ACM conference on Recommender systems (RecSys '11). Association for Computing Machinery, New York, NY, USA, 45–52. DOI:https://doi.org/10.1145/2043932.2043945
*3:Kim, Heung-Nam, et al. "Collaborative filtering based on collaborative tagging for enhancing the quality of recommendation." Electronic Commerce Research and Applications 9.1 (2010): 73-83. https://www.sciencedirect.com/science/article/pii/S1567422309000544 等