k-近傍法

k-近傍法 (k-nearest neighbor) とは

k-近傍法は、分類したいデータに似ている既存のデータをk個見て、そのk個のなかで最も多いラベルを多数決的に採用する方法です。似ている似ていないの判断基準が必要ですが、単純なものとしてはデータ間のユークリッド距離を使います。既存のデータが多い場合、最も近いk個を選ぶのに探索が必要で、データ分類に時間がかかります。

以下の図では、? のデータを分類するために、k=3個のデータを見ています。△のラベルが2個、□のラベルが1個なので、? のデータを△と推定します。

k-近傍法の実装

scikit-learn を利用してk-近傍法を実装しましょう。k-近傍法の使い方は scikit-learn のページの Examples に書かれています。ノートブックで使えるように転記したものが以下になります。

X = [[0], [1], [2], [3]]
y = [0, 0, 1, 1]
from sklearn.neighbors import KNeighborsClassifier
neigh = KNeighborsClassifier(n_neighbors=3)
neigh.fit(X, y)
print(neigh.predict([[1.1]]))
print(neigh.predict_proba([[0.9]]))

主に以下の4ステップで構成されています。

  1. この例では以下が既存のデータとして与えられていて、未知のデータのラベルの推定に利用します。
    • データ X: 0, 1, 2, 3
    • ラベル y: 0, 0, 1, 1
  2. 機械学習のアルゴリズムのうち、k-近傍法 KNeighborsClassifier を呼び出します。n_neighbors=3 は、近い3個のデータでラベルを判断する設定です。
  3. fit を呼び出すと与えられたデータに対して学習を行います。k-近傍法の場合、近いデータを探してラベルを決めるモデルが出来上がります。
  4. 未知データを推定します。ラベルを推定する関数は predict です。predict_proba はラベルごとに推定の確率を返します。確率が大きいものをラベルとして推定することもできます。

Practice 4

  • 上記のコードを Jupyter Notebook で実行してみましょう。
  • 0.9 に対して、関数 predict_proba を実行したとき、どのような確率が返ってきたか確認しましょう。
  • X のなかで、0.9 に近い3つの数字はどれでしょう。そのうち、0のラベルと1のラベルは何個含まれているでしょうか。
  • k=3 に含まれるラベルの数の割合と、確率の関係はどういう関係でしょうか。

k-近傍法の MNIST への適用

上の例を少し書き換えて、前の章で準備した MNIST のデータに k-近傍法を適用してみましょう。 データが変わっているだけなので、学習に使う X, y や推論したい 0.9 や 1.1 を MNISTのものに変えます。 具体的には以下の通りに置き換えます。

  • [[0], [1], [2], [3]] ← X_train
  • [0, 0, 1, 1] ← y_train
  • 0.9, 1.1 ← X_test[0]

※ 0.9 や 1.1 の代わりに、テストデータ X_testの最初の画像 X_test[0] を使います。

書き換えたコードは以下の通りです。

X = X_train
y = y_train
from sklearn.neighbors import KNeighborsClassifier
neigh = KNeighborsClassifier(n_neighbors=3)
neigh.fit(X, y)
print(neigh.predict([X_test[0]]))
print(neigh.predict_proba([X_test[0]]))

Practice 5

上のコードでは、X_test[0]、つまりテストデータの最初の画像に対する数字を予測しました。これが実際に正しいかを確認するために、最初の画像に対するラベルを表示して比較しましょう。

ヒント: テストデータの最初の画像のラベルは y_test[0] でアクセス可能です。

推論時間の測定

冒頭の説明で述べたようにk-近傍法は膨大なデータを探索するので予測に時間がかかります。ここでは100枚の画像に対する予測の時間を計測してみましょう。 セルの最初に %%time といれるとそのセルの実行時間を計測することができます。以下では1枚目の画像のみを予測して、その時間を測ります。 Wall time が実行時間に相当します。CPU times は CPU が利用された時間です。

%%time
neigh.predict(X_test[0:100])