rueki

Generative Adversarial Nets (GAN) 논문, 코드 리뷰 본문

DL

Generative Adversarial Nets (GAN) 논문, 코드 리뷰

륵기 2020. 12. 3. 22:39
728x90
반응형

14년에 이안 굿펠로우가 처음 제안한 개념으로써 한글로 흔히 생성 및 적대적 모델이라고 얘기를 한다.

여기서는 크게 두 가지 모델을 제시하는데 Generative model과 Discriminator model 이다.

생성 모델과 판별모델인데 이는 아래의 그림으로 설명을 하고자 한다.

논문에서는 경찰과 위조 지폐범으로 개념을 소개하였다.

 

지폐 위조범은 지폐를 위조해서 세상에 내놓으려고 하는데, 지폐위조범에게 가장 좋은 상황은 무엇일까?

이는 실제 지폐와 위조지폐가 사람들이 구분을 못하는 것이다. 즉 실제 data와 유사한 data를 생성하는 것이 Generator의 목적인 것이다. 이는 역전파가 성능에 큰 역할을 하였다고 한다.

경찰은 여기서 지폐를 판별하는 Discriminator의 역할을 하는데 얘의 목적은 지폐를 정확히 구분하고자 하는 것이다.

 

정규분포 형태를 띄는 것이 실제 데이터고 초록색 실선이 생성 함수 분포, 점선이 Discriminator에 대한 분포인데 A에서 D로 가면서 생성함수와 실제 데이터 분포 함수가 일치하게 되는 것이 학습의 목적이며 판별 함수는 d에서 보면 일직선으로, 즉 1/2로 수렴한 형태인데 이는 둘중에 뭐가 맞는지 모르는 상황, 50 대 50의 확률을 나타낸다.

 

Gan의 Value function

logD(x)에서 D(x)는 1로 maximize하고, log(1-D(G(z)))에서 D(G(z))는 0으로 minimize해야한다.

+를 중심으로 앞에는 real data의 분포, 뒷부분은 생성 data의 분포에 대한 내용인데, log(1-D(G(z)))에서 D(G(z)),

즉 gaussian 분포 z로 생성한 데이터를 판별하는 것이 0이되게 되면 log(1)이 나오는데 이는 결과값이 결국에 0이 된다.

D의 입장에서 보면 결과적으로 D(x) = 1이 되기에 잘 판별한다고 얘기할 수가 있다.

 

Discriminator는 손실함수를 통해 minimize하려하고, Generator는 maximize하려고 하는데, 처음에는 discriminator가 잘 판별을 못 하기에 loss가 크게 나오지만, 학습할 수록 0에 가까워지게 된다고 한다.

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

class Discriminator(nn.Module):
    def __init__(self, in_features):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(in_features, 128),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.disc(x)


class Generator(nn.Module):
    def __init__(self, z_dim, img_dim):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.01),
            nn.Linear(256, img_dim),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.gen(x)


#hyper parameters
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4 #learning rate
z_dim = 64
image_dim = 28 * 28 * 1
batch_size = 32
num_epochs = 50

disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)
fixed_noise = torch.randn((batch_size, z_dim)).to(device)
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.1307,),(0.3081,))]
)

dataset = datasets.MNIST(root='dataset/', transform=transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

#optimizer, 손실함수
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
criterion = nn.BCELoss()

#tensorboard
writer_fake = SummaryWriter(f'runs/GAN_MNIST/fake')
writer_real = SummaryWriter(f'runs/GAN_MNIST/real')

step = 0

#훈련
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.view(-1,784).to(device)
        batch_size = real.shape[0]

        ### train discriminator : max log(D(real)) + log(1-D(G(z)))
        #latent vector 선언
        noise = torch.randn(batch_size, z_dim).to(device)
        #가짜 생성
        fake = gen(noise)

        #real , real의 손실
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))

        #fake, fake 손실
        disc_fake = disc(fake).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

        #전체 손실
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward(retain_graph=True)
        opt_disc.step()

        #train generator min log(1- D(G(z)))
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx ==0:
            print(f'epoch : {epoch}/{num_epochs} // Loss D : {lossD : .4f}, loss G : {lossG :.4f}')


            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1,1,28,28)
                data = real.reshape(-1,1,28,28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake images", img_grid_fake, global_step=step
                )

                writer_real.add_image(
                    "Mnist Fake images", img_grid_real, global_step= step
                )
                step += 1
728x90
반응형
Comments