ビジネスパーソン・ガジェット置場 empty lot for business

営業や仕事、それに伴う生活を便利に楽にするツール、ガジェットを作ります。既にあるツールも自分用にカスタマイズ。

ディープラーニング: フレームワークDeZeroで画像認識をする方法②

初めてのディープラーニングで画像認識を行っています。最初の検証結果は訓練データに関しては良い結果となるも、テストデータでその半分程度の正解率となってしまい、完全に過学習が起きてしまいました。今回は、その過学習の抑制を目指し画像を増やしてみた結果の記事です。

ディープラーニングK-POPのiveのメンバーの画像を認識させる企画の2回目です。前回の記事はこちらになります。

 

ディープラーニング: フレームワークDeZeroで画像認識をする方法① - ビジネスパーソン・ガジェット置場 empty lot for business

 

過学習の抑制

前回の実証結果では下記のように、テストデータに対する損失率と正解率が悪く、結果としてiveのメンバーが正確に認識できていないという結果になりました。原因としては過学習が起きているということが考えられます。

 

過学習が起きる原因

過学習が起きる原因として考えられるのが、下記などが挙げられます。

1. 訓練データが少ないこと

2. モデルの表現力が高すぎること

(ゼロから作るDeep Learning③より)

 

今回はその内の1. 訓練データが少ないことについて焦点を当てて過学習の抑制ができるか実験していきます。

 

訓練データ数を増やす

前回の実装では、訓練データはIveのメンバー1人につき画像を50枚ずつ用意しました。メンバーが6人なので50x6で300枚の画像になります。

 

今回は各メンバーの訓練画像を30枚ずつ増やし、80枚にしてみました。

 

増やし方は、新しい画像ではなく、既存の画像をPILで変形させたものになります。

  • 角度変更 +10枚
  • ミラー +10枚
  • 白黒 + 10枚

※画像を増やすコードは記事下部に記載いたします。

 

今回の実装結果

訓練データの枚数を増やした以外、モデルやコードは前回と同じです。画像を増やした分今回の学習時間は15分程度かかりました。

 

そして、その結果はというと。。

 

 

 

結論で言うとやはり訓練データに比べてテストデータの正解率や損失率は変わらず悪い結果になりました。

可視化すると以下になります。

 

 

実際に画像認識してみます。

こちらの画像が誰なのか認識できるか(答えはユジンです)

認識できました!!

 

しかし、前回のリベンジリズが認識できるのか

だめでした。。。。

 

やはり、過学習を抑えて正解率をもっと上げていかないとだめそう。。

次回に続きます。

 

 

今回のコード(画像を増やす)

from PIL import Image, ImageOps
import glob
import numpy as np

d_name = ['イソ', 'ウォニョン', 'ガウル', 'ユジン', 'リズ', 'レイ']

for name in d_name:
    files = glob.glob(f'data/train/{name}/*')
    for i in range(10):
        img = Image.open(files[i])
        img = img.convert('RGB')
        new_img = img.copy().rotate(np.random.randint(-60, 60))
        new_img.save(f'data/train/{name}/{name}_r_{i}.jpg', quality=95)
    for i in range(10, 20):
        img = Image.open(files[i])
        img = img.convert('RGB')
        new_img = img.copy()
        new_img = ImageOps.mirror(new_img)
        new_img.save(f'data/train/{name}/{name}_m_{i}.jpg', quality=95)
    for i in range(20, 30):
        img = Image.open(files[i])
        img = img.convert('RGB')
        new_img = img.copy().convert('L')
        new_img.save(f'data/train/{name}/{name}_l_{i}.jpg', quality=95)
                                    

png画像が含まれているので'RGB'に一旦コンバートしています。