MENU

【Python 機械学習】GANでMNISTデータセットを使い数字を学習

目次

概要

GAN(Generative Adversarial Network、生成的対向ネットワーク)は、2014年にイアン・グッドフェローらによって提案された、データを生成するための機械学習モデルの一種です。GANは「生成モデル」と「識別モデル」という2つのモデルが互いに競争しながら学習する仕組みで、以下のように動作します。

生成モデル(ジェネレーター)
ジェネレーターは、ランダムなノイズから「偽のデータ」を生成します。この偽データは、例えば画像生成のGANでは本物の画像のように見えることを目指して作成されます。

識別モデル(ディスクリミネーター)
ディスクリミネーターは、本物のデータとジェネレーターが作った偽のデータを見分ける役割を持ちます。ディスクリミネーターは「本物」と「偽物」を区別できるように学習します。

本記事では実際に手書き数字のMNISTデータセットを学習させてどのようになるか確認します。

コード

# 必要なライブラリのインポート
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os  # フォルダ作成用

# デバイスの設定 (GPUが利用可能な場合はGPU、そうでない場合はCPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# データ変換の定義
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# MNISTデータセットのダウンロードとデータローダーの設定
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Generatorの定義
class Generator(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, output_dim),
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)

# Discriminatorの定義
class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

# ハイパーパラメータの設定
z_dim = 100  # ノイズの次元数
img_dim = 28 * 28  # MNIST画像の次元数(28x28=784)
learning_rate = 0.0002

# GeneratorとDiscriminatorのインスタンス化、デバイスへの移動
generator = Generator(input_dim=z_dim, output_dim=img_dim).to(device)
discriminator = Discriminator(input_dim=img_dim).to(device)

# 最適化アルゴリズム
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate)

# 損失関数
criterion = nn.BCELoss()

# 画像保存関数(各エポックごとに10枚保存)
def save_generated_images(epoch, num_images=10, folder='generated_images'):
    os.makedirs(folder, exist_ok=True)  # フォルダが存在しない場合は作成
    z = torch.randn(num_images, z_dim).to(device)  # num_images個のノイズベクトルを生成
    fake_imgs = generator(z).view(-1, 1, 28, 28).to("cpu")  # 生成画像を整形し、CPUに戻す
    fake_imgs = (fake_imgs + 1) / 2  # ピクセル値を[0, 1]にスケール
    
    grid = torchvision.utils.make_grid(fake_imgs, nrow=5)  # 5x2のグリッドで表示
    file_path = os.path.join(folder, f"epoch_{epoch+1}.png")  # ファイルパスを設定
    plt.figure(figsize=(10, 4))
    plt.imshow(grid.permute(1, 2, 0).detach().numpy())
    plt.axis('off')
    plt.savefig(file_path)  # 画像をファイルに保存
    plt.close()  # プロットを閉じる

# トレーニングの設定
num_epochs = 50

for epoch in range(num_epochs):
    for real_imgs, _ in train_loader:
        # データをデバイスに移動
        real_imgs = real_imgs.view(real_imgs.size(0), -1).to(device)
        real_labels = torch.ones(real_imgs.size(0), 1).to(device)
        fake_labels = torch.zeros(real_imgs.size(0), 1).to(device)

        # Discriminatorのトレーニング(本物と偽物を識別)
        optimizer_D.zero_grad()
        
        # 本物の画像の損失
        outputs = discriminator(real_imgs)
        d_loss_real = criterion(outputs, real_labels)
        
        # 偽物の画像の損失
        z = torch.randn(real_imgs.size(0), z_dim).to(device)  # ノイズもデバイスに移動
        fake_imgs = generator(z)
        outputs = discriminator(fake_imgs.detach())  # detachでGeneratorの勾配は更新しない
        d_loss_fake = criterion(outputs, fake_labels)
        
        # Discriminatorの合計損失と更新
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_D.step()
        
        # Generatorのトレーニング(偽物を本物と誤認させる)
        optimizer_G.zero_grad()
        outputs = discriminator(fake_imgs)
        g_loss = criterion(outputs, real_labels)
        
        # Generatorの損失と更新
        g_loss.backward()
        optimizer_G.step()
        
    print(f"Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}")
    
    # 各エポックごとに10枚の生成画像を保存
    save_generated_images(epoch, num_images=10, folder='generated_images')

結果

上記のとおりepochを50まで回した結果です。

Epoch [1/50], d_loss: 0.0566, g_loss: 10.4132
Epoch [2/50], d_loss: 0.0672, g_loss: 7.2213
Epoch [3/50], d_loss: 0.2609, g_loss: 3.6743
Epoch [4/50], d_loss: 0.3130, g_loss: 10.1963
Epoch [5/50], d_loss: 0.0203, g_loss: 8.1622
Epoch [6/50], d_loss: 0.1389, g_loss: 6.9700
Epoch [7/50], d_loss: 0.0043, g_loss: 10.8207
Epoch [8/50], d_loss: 0.0278, g_loss: 6.0416
Epoch [9/50], d_loss: 0.1329, g_loss: 7.9730
Epoch [10/50], d_loss: 0.2260, g_loss: 4.7434
Epoch [11/50], d_loss: 0.2017, g_loss: 6.4528
Epoch [12/50], d_loss: 0.1051, g_loss: 3.3526
Epoch [13/50], d_loss: 0.1734, g_loss: 5.9209
Epoch [14/50], d_loss: 0.3976, g_loss: 5.4625
Epoch [15/50], d_loss: 0.1002, g_loss: 6.7051
Epoch [16/50], d_loss: 0.1129, g_loss: 3.9182
Epoch [17/50], d_loss: 0.3493, g_loss: 3.2061
Epoch [18/50], d_loss: 0.1807, g_loss: 3.5808
Epoch [19/50], d_loss: 0.2168, g_loss: 3.3772
Epoch [20/50], d_loss: 0.5591, g_loss: 2.6476
Epoch [21/50], d_loss: 0.3900, g_loss: 3.8372
Epoch [22/50], d_loss: 0.3020, g_loss: 3.0853
Epoch [23/50], d_loss: 0.4162, g_loss: 3.6674
Epoch [24/50], d_loss: 0.1737, g_loss: 2.6148
Epoch [25/50], d_loss: 0.5323, g_loss: 2.9985
Epoch [26/50], d_loss: 0.6900, g_loss: 3.8382
Epoch [27/50], d_loss: 0.6474, g_loss: 2.1668
Epoch [28/50], d_loss: 0.3230, g_loss: 2.4905
Epoch [29/50], d_loss: 0.9136, g_loss: 1.9325
Epoch [30/50], d_loss: 0.2749, g_loss: 3.3031
Epoch [31/50], d_loss: 0.8904, g_loss: 2.5821
Epoch [32/50], d_loss: 0.5742, g_loss: 1.7155
Epoch [33/50], d_loss: 0.7564, g_loss: 2.0241
Epoch [34/50], d_loss: 0.5066, g_loss: 2.5349
Epoch [35/50], d_loss: 0.6034, g_loss: 1.8526
Epoch [36/50], d_loss: 0.8636, g_loss: 2.2440
Epoch [37/50], d_loss: 0.5999, g_loss: 2.2265
Epoch [38/50], d_loss: 0.7234, g_loss: 2.2666
Epoch [39/50], d_loss: 0.5331, g_loss: 1.8005
Epoch [40/50], d_loss: 1.0194, g_loss: 2.7950
Epoch [41/50], d_loss: 0.7151, g_loss: 1.6714
Epoch [42/50], d_loss: 0.6679, g_loss: 1.5930
Epoch [43/50], d_loss: 1.0183, g_loss: 2.0514
Epoch [44/50], d_loss: 1.1098, g_loss: 1.1370
Epoch [45/50], d_loss: 0.9919, g_loss: 1.5936
Epoch [46/50], d_loss: 0.9348, g_loss: 1.6760
Epoch [47/50], d_loss: 0.7383, g_loss: 1.7330
Epoch [48/50], d_loss: 0.8078, g_loss: 1.9343
Epoch [49/50], d_loss: 0.7770, g_loss: 1.9625
Epoch [50/50], d_loss: 0.5050, g_loss: 2.1323

epochがすすむごとに数字がだんだんと認識できるようになり、無事手書き数字を生成できました。

また、今回はテンソルボードを起動させていないです。そのためわかりづらいですがGANのlossは生成モデルと識別モデルが拮抗する関係で変動します。

まとめ

今回はGANについてご紹介させていただきました。

GANは非常にリアルな画像生成や音声生成に成功し、現在ではアートの自動生成やフェイク画像・動画の作成、医療や自動運転など多くの分野で利用されています。

画像、テキスト、音声など汎用性が高い学習方法のため、引き続き理解していき覚えたいです。

よかったらシェアしてね!
  • URLをコピーしました!
  • URLをコピーしました!

この記事を書いた人

プログラミングをそれとなく続けてきて歴だけは10年。
コーディングは基本的な命令文とクラスの概念は理解。
あとはライブラリなどを使ってそれとなく。
最近はAI関連を触ってます。

目次