[논문 리뷰] MaskGIT: Masked Generative Image Transformer
2022년 Google Research에서 발표한 Image Generation 논문이다.
Google Research;
CVPR 2022;
- Paper: https://arxiv.org/abs/2202.04200
- Git: https://github.com/google-research/maskgit
- Project Page: https://masked-generative-image-transformer.github.io
Introduction
Image synthesize task에서 GAN이 좋은 성능을 보여줬지만, training 불안정성 및 mode collapse와 같은 문제로 인해 sample diversity가 부족한 문제를 겪고 있다. GAN에서 사용되는 미묘한 min-max optimization과 달리 Generative transformer model은 Maximum likelihood estimation에 의해 학습되며 안정적인 training 및 향상된 distribution coverage와 diversity를 제공한다는 점에서 GAN보다 장점이 있다.
Generative transformer 접근 방식은 이미지를 일련의 개별 토큰(or visual words)로 양자화한 뒤, Autoregressive model(예: transformer)이 이전에 생성된 결과(즉 autoregressive decoding)를 기반으로 이미지 토큰을 순차적으로 생성하도록 학습된다. 기존 작업은 대부분 1단계, 즉 information loss를 최소화하도록 이미지를 quantization하는 방법에 초점을 맞추고, NLP에서 차용한 동일한 2단계를 사용하므로 SOTA Generative transformer 조차도 이미지를 단순하게 시퀀스로 취급하고 raster-scan oder(즉 라인 by 라인, 왼→오)에 따라 1D 토큰 시퀀스로 평면화한다. 본 논문에서는 텍스트와 달리 이미지는 순차적이지 않으므로 이러한 방법이 이미지에 효율적이지 않다고 주장한다. 이미지를 Flat 시퀀스로 취급한다는 것은 Autoregressive sequence length가 증가하여 자연어 문장보다 훨씬 긴 매우 긴 시퀀스를 형성한다는 것을 의미하고, Long-term correlation 모델링 뿐 아니라 디코딩을 다루기 어렵게 만든다.
본 논문에서는 MaskGIT이라고 하는 이미지 합성을 위한 새로운 Bidirectional transformer를 소개하고 있다. MaskGIT은 training 중에 BERT의 mask prediction과 유사한 proxy task에 의해 훈련되며, inference에서는 새로운 non-autoregressive decoding method을 채택하여 일정한 단계 수로 이미지를 합성하는 방법을 사용한다. 이미지를 생성할때 기존처럼 256단계가 아닌 8단계만 거치고, 각 단계 내의 예측이 병렬화 가능하기 떄문에 Autoregresive decoding보다 훨씬 빠르다. 또한, 래스터 스캔 순서대로 이전 토큰에 대해서만 조건을 지정하는 대신 양방향 셀프 어텐션을 통해 모델이 생성된 토큰에서 모든 방향으로 새 토큰을 생성할 수 있다. 또한 본 논문의 마스크 스케줄링 방식이 생성 품질에 상당한 영향을 미친다는 것을 발견했다.
Method
본 논문의 목표는 병렬 Decoding, Bi-directional Generation을 활용한 새로운 이미지 합성 패러다임 설계하는 것이다.
이전의 연구와 동일하게 2-stage 방법을 사용했고, 특히 두번째 stage를 개선하는 것이 목표이므로 첫번째 stage는 VQGAN과 동일한 구조를 사용한다.
두번째 단계에서는 MVTM(Masked Visual Token Modeling)을 통한 Bi-directional transformer 학습을 제안한다. 기존에 Visual token을 flatten해서 raster-scan order으로 진행하던 방식과 달리, 전체를 Bi-directional transformer에 넣어줌으로서 이미지의 전체를 볼 수 있도록 했다. 기존에 BERT에서 사용했던 방식과 유사하게 이미지 토큰을 마스킹처리하고, 마스킹 된 위치에 어떤게 있었어야 했는지를 예측한다.
< 2-stage design >
- 기존의 Generative Transformer 구조
- Quantize(or Tokenization) image: 시퀀스 길이 축소, high-resolution image를 저차원 discretized된 code로 압축시킴.
- Auto-regressive model: 압축된 코드를 flatten해서 시퀀스와 같은 형태로 만들어준다. Auto-regressive 모델에 넣어서, 시퀀스를 하나씩 prediction, 이전 token을 보고 다음 토큰 예측(Next token prediction)
VQGAN: Vector Quantization
이미지가 인코더에 들어가게 되면 압축된 feature가 생성된다. 이 압축된 feature를 그대로 사용하는게 아니라, 우리가 가지고 있는Codebook이라는 벡터의 집합에서 이 feature의 HW의 각각의 위치 position wise, 하나의 벡터에 가장 비슷한 코드북 내의 벡터를 찾는 것이 목적.
코드북의 구성을 살펴보면, 코드(코드 인덱스)와 그 코드에 해당하는 코드 임베딩으로 이루어져 있다. 코드북 내에는 코드와 코드 임베딩 벡터 페어가 총 K개 만큼 존재한다. Continuous한 벡터 z를 코드북 내의 가장 가까운 코드(위 그림의 파란색 벡터)로 대체하는 방식으로 descitized를 진행한다. (단, 이 과정에서 continous 한 코드를 가지고 있는 정보만으로 표현하기 때문에 정보 손실이 발생할 수 있다.)
한마디로 정리하면, Continuous한 Embeded vector를 codebook 안에 가지고 있는 code embedding으로 estimation 하는 과정이다.
1. MVTM in Training
학습 과정 중, VQ-encoder를 통해 얻은 token의 일부를 [MASK] token으로 대체한다. 이 샘플링 과정은 Mask scheduling function $\gamma(r) \in (0,1]$ 에 의해 파라미터화된다. 일단 0-1사이의 ratio를 샘플링 한 뒤, Y에서 마스크로 바꾸기 위한 token을 균일하게 $[\gamma(r) \cdot N]$개 선정하여 배치한다. 여기서 Mask shceduling은 이미지 생성 품질에 큰 영향을 미친다.
마스킹된 Y$_M$을 Multi-layer bi-directional transformer에 공급하여 각 마스킹된 토큰에 대한 확률을 예측한다. 여기서 Negetive log-likelihood는 ground-truth와 one-hot token 간의 Cross entropy로 계산된다.
2. Iterative Decoding
Auto-regressive decoding에서 토큰은 이전에 생성된 출력을 기반으로 순차적으로 생성된다. 이 프로세스는 병렬화할 수 없으며 이미지 토큰 길이 때문에 매우 느리다. 본 논문에서는 Bi-directional transformer를 활용해 이미지 토큰을 병렬도 동시에 생성하는 decoding 방식을 제안한다.
이론상으로는 모든 토큰을 동시에 생성하는 것이 가능하지만, 학습 과정과의 불일치로 인해 동시에 생성하는 것은 어렵다. Inference time에서 이미지를 생성하기 위해서는 모든 토큰이 마스킹 된 blank 이미지에서 시작하여 아래와 같은 과정을 반복한다.
하나의 Step에서 Decoding 단계는 4단계로 구성된다.
- Predict: 현재 iteration에서 마스킹된 토큰이 주어지면 모델은 모든 마스킹된 위치에 대한 확률을 병렬로 예측한다.
- Sample: 모든 마스킹된 위치 $i$에서, Code book의 가능한 모든 토큰에 대한 예측 확률 $p_i^{(t)}$를 기반으로 토큰을 샘플링한다(이 때, 확률이 가장 높은 코드를 고르는 것이 아니라 multinomial sampling을 통해 코드를 선택. 이 과정으로 인해 GAN등의 다른 방법보다 Diversity를 높일 수 있었던 것 같다). 토큰이 샘플링 된 후 해당 예측 점수는 이 예측에 대한 모델의 믿음을 나타내는 “Confidence” score로 사용된다. 마스킹되지 않은 위치는 confidence score를 1.0으로 설정한다.
- Mask Schedule: Mask scheduling function $\gamma$ 에 따라 마스크할 토큰의 개수를 계산한다.
- Mask: $Y_M^{(t)}$에서 n개의 토큰을 마스킹하여 $Y_M^{(t+1)}$을 얻는다. Iteration t+1에 대한 마스크 $M^{(t+1)}$은 다음과 같이 계산된다.
Decoding 과정은 T step으로 이미지를 생성하며, 모델은 매 iteration마다 모든 토큰에 대해 동시에 예측하지만 가장 확실한 토큰만 유지한다. 남은 토큰들은 다시 마스킹되며 다음 iteration에서 재 예측된다. 모든 토큰이 생성될 때 까지 마스크 비율이감소한다. 본 논문에서는 T=8 iteration으로 이미지를 생성했다.
3. Masking Design
Masking design에 따라 생성된 이미지의 품질은 크게 영향을 받는다. 논문에선 주어진 latent 토큰에 대한 마스크 비율을 계산하는 Mask scheduling fuction $\gamma(\cdot)$ 를 모델링한다. $\gamma$ 는 training, inference 과정에 모두 사용된다.
- $\gamma(\cdot)$ → training: randomly sample a ratio $r$ in $(0, 1]$ (다양한 디코딩 시나리오를 시뮬레이션하기 위함) → inference: $0/T, 1/T, … , (T-1)/T$ 의 입력을 받아 디코딩 진행 상황을 나타냄
본 논문에서는 $\gamma$로 아래와 같은 간단한 변환을 수행했다.
- Linear function: 매번 동일한 양의 마스크를 생성
- Concave function: 이미지 생성이 less>more 정보 흐름을 따른다는 직관을 따른다. 처음에는 대부분의 토큰이 마스킹되어 모델이 확신할 수 있는 몇 개의 정확한 예측만 하면 된다. 끝으로 갈수록 마스크 비율이 급격히 떨어지므로 모델이 훨씬 정확한 예측을 해야한다. 끝에서는 많은 예측된 토큰이 있으므로 효과적인 정보가 증가했다. (ex. cosine, square, cubic, exponential)
- Convex function: 모델은 처음 몇 번의 반복 내에서 대다수의 토큰을 예측해야한다. (ex. square root, logarithmic)
본 논문에서는 실험적으로 위의 함수들을 비교하고 cosine 함수가 가장 좋은 성능을 보인다고 한다.
Experiments
1. Experimental Setup
- Single autoencoder, decoder 학습
- codebook 1024 token (256X256 image), 이미지는 16배 압축
- Model: 24개 layer, 8개 Attention heads, 768 embedding dimensions 및 3072개 hidden dimensions
- label smoothing=0.1, dropout rate=0.1, Adam optimizer
- Data Augmentation: RandomResizeAndCrop
2. Class-conditional Image Synthesis
ImageNet 256X256, 512X512 이미지에 대한 class-conditional image synthesis 성능에 대한 평가이다. 이미지 quality(FID), 속도(step), 다양성(CAS, Precision/Recall) 측면에서 MaskGIT이 뛰어난 성능을 보여주고 있다.
3. Image Editing Applications
세가지 작업에 MaskGIT 적용: class-conditional image editing, image inpainting, outpainting
Class-conditional image editing
Image inpainting, outpainting
그림 7의 예에서 볼 수 있듯이 MaskGIT는 동일한 입력에 대해 서로 다른 시드를 사용하여 다양한 결과를 합성할 수 있다. MaskGIT은 특히 객체와 전역 구조를 특히 잘 완성하는 것을 볼 수 있는데, 이것이 Transformer의 Global attention을 통해 유용한 표현을 학습하기 때문이라고 한다.
4. Ablation Study
Mask Scheduling
MaskGIT의 학습과 Iterative decoding에 사용되는 Mask scheduling function. 일반적으로 오목 함수가 선형보다 더 나은 score를 얻는다.
- Concave Function: 처음에는 대부분의 토큰이 마스킹되어 모델이 확신할 수 있는 몇 가지 정확한 예측만 진행한다. 학습이 진행되고, 효과적인 정보가 증가하면, 많은 수의 토큰을 결정한다. 직관적으로 이해하기에도 concave function이 효율적임을 알 수 있다.
- Convex Function: 모델은 초기 iteration에서 대다수의 토큰을 정해야한다.
Iteration Number
그림 8을 통해 알 수 있듯이, 동일한 설정에서 더 많은 iteration이 반드시 좋은 것은 아니다. 대부분 최적 지점 T가 존재한다. 저자는 너무 많은 iteration으로 인해 모델이 confidence가 낮은 예측을 유지하지 못하게 되어 토큰 다양성이 악화되기 때문에 최적 지점이 존재한다고 한다.