Post

Generative model 기초 1. GAN 정리


Machine Learning Task는 크게 두 종류로 분류해볼 수 있다.

GAN_1 이미지 출처:https://www.turing.com/kb/generative-models-vs-discriminative-models-for-deep-learning

  • Discriminative model: 데이터 $X$가 주어졌을 때 라벨 $Y$가 나타날 조건부 확률 $p(Y|X$)를 directly 캡처하는 모델, Decision Boundary를 학습 → Logistic Regression, SVM, classification, detection
  • Generative model: 데이터 $X$가 주어졌을 때 확률분포 $p(Y)$와 $p(X|Y)$를 모델링하고, Bayes’ rule을 사용해 $p(Y|X)$를 indiretly 캡처하는 모델. 데이터의 probability distribution을 학습 → Bayesian Network, Autoregressive model, GAN


이 중 Generative model을 살펴보려고 한다. Ian Goodfellow는 NIPS 2016 Tutorial에서 Generative model을 아래와 같이 분류했다.

GAN_2

해당 그래프를 보면 알 수 있듯 GAN은 Implicit density, 즉 데이터 분포(likelihood)를 직접 모델링하지 않고 적대적 학습이라는 부차적인 방법을 사용하여 이미지를생성한다. 그럼 GAN에 대해 정리해보자.



GAN(Generative Adversarial Network)


GAN은 Generative Adversarial Network의 약자로, 말 그대로 적대적 신경망이다.

Generator가 어떤 데이터 샘플(가짜 데이터)을 생성하면, Discriminator는 주어진 샘플이 가짜인지, 실제 샘플인지를 검사한다. Generator는 완벽한 가짜 데이터 샘플을 생성하여 Discriminator를 속이고, Discriminator는 진짜와 가짜 데이터를 더 잘 구분하도록 학습이 진행된다. 이 방법을 적대적 학습 방법이라고 부르는 것이다.

원본 논문에서 GAN을 이해할 때 사용했던 지폐위조범과 경찰 이야기로 이해해보자.

GAN_3

여기서 Generator는 지폐 위조범이다. 지폐위조범(G)는 가짜 화폐(G(z))를 최대한 진짜와 비슷하게 만들어 경찰을 속이는 것이 목적이다. 반면 Discriminator, 경찰(D)은 화폐를 보고 지폐위조범이 만든 가짜 화폐(G(z))인지, 아니면 진짜 화폐(x)인지 구분해내야한다.

경찰이 진짜와 가짜 화폐를 구분하는 능력을 키워가면, 지폐위조범도 경찰을 속이기 위해 가짜 화폐를 더 진짜와 비슷하게 만들도록 위조 실력을 키워간다. 그럼 다시 경찰은 또 그에 맞도록 판별 능력을 키워가고, 이렇게 번갈아가며 학습을 진행하는 것이 적대적 학습이다.



GAN Architecture


GAN의 컨셉을 이해했으니 다음으로는 GAN의 구조를 살펴보자.

GAN_4

  • Generator($G_{\theta}$)
    • Random noise vector($z$)를 입력으로 받아 실제 데이터 분포와 유사한 이미지를 생성한다.
    • 실제 데이터셋 분포를 모델링하여 비슷하게 만드는 것이 학습의 목적
  • Discriminator($D_{\phi}$)
    • Generator가 생성한 가짜 이미지$(G_{\theta}(z))$, 실제 이미지($x$)를 입력으로 받는다.
    • 입력으로 받은 이미지가 실제 이미지(1)인지, 가짜 이미지(0)인지를 예측하는 확률을 출력한다.



GAN Loss Function


GAN_5

GAN의 Loss Function은 위와 같다. Generator는 loss를 최소화하는 것을 목적으로 하고, Discriminator는 loss를 최대화하는 것을 목표로 하는 미묘한 min-max problem이다.

Discriminator Loss

GAN_6

Discriminator는 $D(x)$는 최대한 1(Real)에 가깝게, $D(G(z))$는 최대한 0(Fake)에 가까워지게 학습해야한다. Real or Fake를 판별하는 문제이므로 Binary Cross Entropy Loss를 사용한다.

Real 이미지를 정답으로 잘 예측했다면 $D(x) = 1$, $log(D(x)) = 0$ 이되고, 예측에 실패하여 Real을 Fake로 예측했다면 $D(x) =$ 0, $log(D(x)) = - ∞$ 가 된다.

반대로 Fake 이미지를 정답으로 잘 예측했다면 $log(1 - D(G(z))) = 0$ 이 되고, 잘못 예측하여 Fake를 Real로 예측했다면 $log(1 - D(G(z))) = - ∞$ 가 된다.


Generator Loss

GAN_7

Genertor는 전체 loss의 첫 번째 term과 관련이 없으므로 위와 같이 나타낼 수 있다. Generator는 $D(G(z))$를 1과 가깝게 만들어 loss를 최소화하는 것이 목표이다.

학습은 G와 D가 이렇게 한번씩 번갈아가면서 update 하는 형식으로 진행된다.



Limitation: Mode Collapse


Mode collapse는 GAN에서 가장 두드러지게 나타나는 한계점이다. Mode collapse란, Generator가 생성가능한 전체 범위의 다양한 이미지를 생성하지 못하고 제한된 출력을 생성하는 것을 말한다. Generator는 Discriminator를 속이기만 하면 되기 때문에 일부 샘플만 생성하여 Discriminator를 속이도록 학습하는 문제이다.

mode collapse가 발생하는 주된 원인 다양하지만, 가장 큰 요인은 학습의 불균형에 있다. Discriminator와 Generator라는 독립된 두 모델을 학습시켜야 하기 때문에 두 모델 중 한쪽만 먼저 학습이 잘 되는 경우에 특히 이 현상이 발생한다.

Generator 입장에서는 이미지 생성이라는 비교적 어려운 Task를 수행해야하고, Discriminator의 입장에서는 이미지의 Fake/Real만 판단하면 되기 때문에 Discriminator overfitting이 쉽게 일어난다. 이렇다 보니 학습 도중 학습 불균형이 발생한다.(물론 학습에 따라 반대의 경우도 발생한다)

추가적으로 Gradient vanishing의 원인이 있다. Generator를 업데이트하는 데 사용되는 gradient가 매우 작거나 0이 되어 효과적인 학습을 방해하는 경우, Generator가 전체 데이터 분포를 탐색하는 것을 막고 mode collapse가 발생한다.

예시를 통해 좀 더 자세히 살펴보자.
Mode란 통계학적으로 최빈값(가장 빈도가 높은 값), 관측치가 가장 높은 구역을 의미한다. 예를 들어 MNIST가 0-3까지의 class가 있다고 가정했을 때, 아래 그림과 같이 데이터의 분포를 표현 할 수 있다.

GAN_8 이미지 출처:http://dl-ai.blogspot.com/2017/08/gan-problems.html

위의 그림처럼 대부분의 데이터는 multi-modal형식(여러개의 mode가 존재)으로 존재하게 되는데, 이렇게 mode가 여러개인 분포에서 네개의 분포 중 하나의 분포로 치우쳐서 변환하는 경우 mode collapse가 발생한다.

GAN_9 이미지 출처:http://dl-ai.blogspot.com/2017/08/gan-problems.html

Generator가 Discriminator를 속일 수 있는 하나의 모드만을 계속해서 생성해내는 문제가 발생한다. 실제로 GAN을 이용하여 MNIST를 학습시킬 때 하나의 class(하나의 숫자)만을 생성하는 현상이 일어난다. Generator는 하나의 class만을 올바르게 생성해도 Discriminator의 관점에서 잘못된 이미지라고 말할 수 없다. 이렇듯 Discriminator만 속이면 되는 적대적 학습 방법은 mode collapse에 취약한 한계점을 드러낸다.

GAN_10 이미지 출처:http://dl-ai.blogspot.com/2017/08/gan-problems.html






Reference

[1] Generative Adversarial Net(paper): https://arxiv.org/abs/1406.2661
[2] NIPS 2016 Tutorial: https://arxiv.org/pdf/1701.00160.pdf
[3] Lean.AI: https://dl-ai.blogspot.com/2017/08/gan-problems.html

This post is licensed under CC BY 4.0 by the author.