カラー画像の分類

モノクロ画像とカラー画像の違い

これまで扱ってきた MNIST の画像は白黒のモノクロ画像でした。しかし世の中の画像の多くはカラー画像です。ここでは、カラー画像を実際に分類してみようと思います。モノクロ画像は白黒なので、各画素がどれだけ白いか(黒いか)の1つの情報しか持っていません。カラー画像は、聞いたことのある方もいらっしゃるかもしれませんが、R (赤)、G (緑)、B(青) の3色の強さで色を表しています。つまり、カラー画像はモノクロ画像の3倍の情報をもっています。この1色や3色のことをカラーチャネル(もしくは単にチャネル)と呼びます。

MNISTの画像は縦28px, 横28pxのモノクロ画像でしたので、そのデータは (28,28,1)の配列で表現されていました。カラー画像の場合は、(H, W, 3) の形をとります。Hは縦のピクセル数、Wは横のピクセル数です。

カラー画像データセット cifar10

ここでは cifar10 というカラー画像のデータセットを扱います。scikit-learn でデータをロードしたように、keras の関数で簡単にロードできます。

(X_train_cl,y_train_cl),(X_test_cl,y_test_cl) = tf.keras.datasets.cifar10.load_data()

実際に最初の画像を表示してみましょう。

plt.imshow(X_train_cl[0])

解像度が低くてよくわからないですね。y_train_cl[0] を表示すると、この画像のラベル(インデックスの数字)がわかります。 cifar10のページ で、オブジェクトが airplane, automobile, … と並んでいるので、これを 0, 1, 2…と数えていって、このインデックスに相当するものを探すと、何が写っているか明らかになります。

畳み込みニューラルネットワークの実装

さきほど利用した畳込みニューラルネットワークを利用してカラー画像を分類してみましょう。ただ、メモリの関係で、今回は全ての画像を読み込むのが難しいかもしれません。学習データを最初の10,000枚に絞り込みましょう。あわせて画像データが整数になっているので、keras で扱えるよう floatに変換します。

X_train_cl = X_train_cl[:10000].astype(float)
y_train_cl = y_train_cl[:10000]

それでは畳込みニューラルネットワークを実装しましょう。注意点は MNIST のデータ形式が (28,28,1) だったのに対して、cifar10は異なる形式となる点です。まずはその形式を shape で調べてみましょう。

X_train_cl[0].shape

Practice 14

shape がわかったら、以下の MNIST 用の畳込みニューラルネットワークのコードを変えて実行してみましょう。変更点は以下のとおりです。

  • input_shape を cifar10用にします。
  • fit に与えるデータ (X_train_cl と y_train_cl に cifar10のデータを読み込んでいます)
import tensorflow as tf

input_shape =(28,28,1)
model = tf.keras.models.Sequential([
    tf.keras.layers.Reshape(input_shape),
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10)
    ])

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

model.compile(optimizer='adam',
              loss=loss_fn,
              metrics=['accuracy'])

model.fit(X_train, y_train, epochs=5)

Practice 15

学習が終わったら、model.evaluate を利用して、テストデータで評価してみましょう。MNISTのデータでは以下のコードを利用したので、これを cifar10用に書き換えましょう。

model.evaluate(X_test,  y_test, verbose=2)

テストデータに対する精度と、学習時に表示されていた精度を見比べるとどうでしょうか? データが多様になればなるほど、テストデータに未知な情報が含まれて、学習時の精度を達成することが難しくなります。 これを克服するためには、より高度なアルゴリズムを利用したり、多くのデータを利用したりする方法があります。