[논문 리뷰] DiT, Scalable Diffusion Models with Transformers
ICCV 2023
- Paper: https://arxiv.org/abs/2212.09748
- Git: https://github.com/facebookresearch/DiT
- Page: https://www.wpeebles.com/DiT
Introduction
다양한 도메인에서 Neural architecture는 Transformer 기반으로 대체되었다. Image Generation의 경우, Autoregressive model에서는 transformer가 퍼져있지만 다른 Generative modeling framework에서는 많이 사용되고 있지 않다. 특히 Diffusion model은 최근 이미지 생성에 있어 선두에 있다고 할 수 있지만, backbone으로는 Convolutional U-Net 구조를 사용하고 있다.
다양한 U-Net 기반 diffusion backbone 연구를 통해 architecture 선택의 중요성을 깨닫고, 저자는 generative modeling 연구에 대한 baseline을 제공하는 것을 목표로 한다. Diffusion model에서 U-Net의 inductive bias가 성능에 중요하지 않으며, transformer 설계로 대체될 수 있음을 증명한다.
논문에서는 Transformer 기반 Diffusion model을 제안하고, 이를 Diffusion Transformer(DiT)라고 한다. ViT의 구조를 응용했으며, network complexity와 sample quality 측면에서 transformer의 scaling에 대해 연구한다. 또한 VAE Latent space에서 학습을 진행하는 LDM(Latent Diffusion Model) framework에서 DiT를 설계함으로써 U-Net backbone을 transformer로 성공적으로 대체할 수 있음을 증명했다.
Method
1. Preliminaries
Diffusion formulation.
논문에서는 Diffusion Model(DDPM)에 관한 기본 지식을 설명하고 있다. 이 부분은 아래 그림으로 대체한다. 자세한 내용은 블로그 내 [Generative model 기초 3. Diffusion 정리] 를 참조바란다.
Classifier-free guidance.
Conditional diffusion model은 class label을 입력으로 사용한다. 이 경우 Reverse process는 $p_\theta(x_{t-1}∣x_t, c)$와 같이 수정되고, 노이즈 $\epsilon_{\theta}$와 $\Sigma_\theta$는 c에 따라 condition이 지정된다. 이 때 classifier-free guidance를 사용하여 $\log(c∣x)$가 높은 $x$를 찾도록 장려할 수 있다. classifier-free guidance는 일반적인 sampling 기술에 비해 크게 향상된 sample을 생성하는 것으로 알려져 있으며, DiT 모델에서도 이러한 방법을 적용한다.
Latent diffusion models.
High-resolution Pixel space에서 Diffusion model을 직접 학습시키는 것은 계산적으로 쉽지 않다. LDM에서는 2-stage 방식으로 이러한 문제를 해결했다.
(1) 학습된 autoencoder를 사용하여 image를 더 작은 차원의 representation으로 변환
(2) encoder에서 압축된 representation을 디코딩하여 새로운 이미지를 생성하는 diffusion model 학습
본 논문에 대한 자세한 내용은 [Stable Diffusion 리뷰] 게시물을 참고하길 바란다.
2. Diffusion Transformer Design Space
본 논문에서 중점은 DDPM을 학습하는 것이므로 transformer 구조는 ViT(Vision Transformer) architecture를 기반으로 한다.
Patchify.
DiT의 입력은 spatial representation(VAE에서 나온 noised latent) $z$이다(image: 256 × 256 × 3 → $z$: 32 × 32 × 4). DiT의 첫번째 layer에서 spatial input을 patch로 나누고, patch들은 각각 linearly embedding되어 $T$개 token sequence로 patchify 된다. Patchify 이후 ViT의 frequency-based positional embedding(sine-cosine version)을 모든 입력 token에 적용한다(image가 latent feature로 변경된 점만 제외하고 여기까지 ViT와 동일하다).
Patchify에 의해 생성된 token $T$의 수는 patch size 하이퍼파라미터 $p$에 의해 결정된다(patch size가 작아지면 그만큼 token의 개수가 늘어난다). $p$가 작아지면 $T$가 커지고 이에따라 GFlops가 증가한다.
DiT block design.
Patchify 단계를 지나 입력 token들은 transfomer 입력으로 사용된다. 하지만 Diffusion model은 noised 이미지 input 외에도 noise timestep $t$, class label $c$, natural language 등과 같은 추가 condition정보를 입력으로 처리하는 경우가 있다. 저자는 conditional input을 다르게 처리하는 transformer block의 네가지 변형을 제시한다.
1) In-context conditioning
단순히 $t$와 $c$의 vector 임베딩을 입력 시퀀스에 두 개의 추가 token으로 추가하여 이미지 token과 다르지 않게 처리한다(ViT의 cls token과 유사). 이 방법은 ViT block을 별도로 수정할 필요 없다. 마지막 block 이후에 output 시퀀스에서 conditioning token을 제거한다.
2) Cross-attention block
$t$와 $c$ 임베딩을 이미지 token과 별도로 분리한다. $t$와 $c$는 concat하여 길이가 2인 시퀀스로 만든다. Transformer block은 multi-head self-attention block 이후 multi-head cross-attention layer가 추가된다. 이 방법은 class label을 condition으로 주기 위해 LDM에서 사용하는 것과 유사한 방식이다.
3) Adaptive layer norm (adaLN) block
GAN과 U-Net backbone을 가진 Diffusion model에서 많이 사용되는 Adaptive normalization layer를 적용하기 위해서, Transformer에서 사용하는 standard layer norm을 adaptive layer norm (adaLN)으로 대체한다. dimension-wise scale/shift parameter $\gamma$와 $\beta$를 직접 학습하는 대신 $t$와 $c$의 임베딩 vector를 이용해서 regression하게 된다(자세한 방법 아래 참조). 위의 방법들 중 adaLN은 Gflops가 가장 적으므로 계산 효율성이 가장 높다. 또한 모든 token에 동일한 feature를 적용하는 conditioning 방법이다.
Layer Normalization
\[\text{LN}(x)\; = \; \gamma\bigg(\frac{x-\mu(x)}{\sigma(x)}\bigg) \; + \; \beta\]
FiLM(FiLM: Visual Reasoning with a General Conditioning Layer📄)
FiLM 논문은 conditioning을 위한 방법을 제시한 논문이다. 해당 논문에서는 입력 이미지와 관련해서 conditioning할 정보를 인코딩한 뒤, 네트워크에서 이미지 feature map에 affine transformation 해줌으로서 adaptively 적용하는 방식이다.
방법은 다음과 같다. Condition input(ex. caption)을 임의의 function(neural network로 구현됨)을 이용해 scale vector $\gamma_{i,c}$, shift vetor $\beta_{i,c}$로 인코딩한다. 그 뒤에 이 vector를 이용해 이미지 feature에 affine transform을 진행한다.
본 논문에서 수행한 adaLN은 이러한 affine transform의 형태를 Layer Noramlization에 적용한 형태라고 할 수 있다.
먼저 기존의 Layer Normalization과 같이 Data sample 단위로 평균과 분산을 구한다. 여기서 learnable scale/shift parameter $\gamma, \beta$를 학습하는 대신 timestep $t$와 class label $c$를 shift, scale 으로 사용한다.
4) adaLN-Zero block
ResNet에서 Residual block은 일반적으로 identity function로 초기화된다. Diffusion U-Net 모델에서도 이와 유사한 초기화 전략을 사용하여 residual connection 이전 마지막 conv layer에 zero-initializing을 적용하는 것이다. 유사한 작업을 위해 저자는 adaLN을 위한 $\gamma, \beta$ 이외에 DiT block 내의 residual connection 전에 적용되는 dimension-wise scaling parameters $\alpha$를 도입했다.
모든 $\alpha$에 대해 Zero-vector를 출력하기 위해 MLP를 초기화하였으며, 이를 통해 DiT block을 identity function으로 초기화한다.
Model size.
각각의 hidden dimension size가 d개인 N개의 DiT block을 적용했다. ViT와 같이 표준 transformer 구조를 사용했다. 모델 구성에 대한 세부 정보는 아래 표와 같다.
Transformer decoder.
DiT block 이후 나온 이미지 token sequence를 output noise prediction()과 output diagonal covariance prediction으로 디코딩해야한다. 이를 위해 linear decoder를 사용한다. 마지막 layer norm을 적용하고 각 token을 $p\times p \times C → p\times p \times 2C$ tensor로 디코딩하고 reshape한다.
3. Experimental Setup
Training.
- ImageNet dataset의 256 × 256, 512 × 512 resolution 이미지 사용
- Data augmentation: horizontal flips
- ViT와 달리 learning rate warmup, regularization 없이도 안정적인 학습이 가능함
- EMA model(exponential moving average): decay 0.9999
Diffusion.
- Encoder: Stable Diffusion의 pretrained VAE 사용
- VAE 입력 RGB 이미지는 256×256×3, $z = E(x)$는 32 × 32 × 4
- Diffusion 모델에서 새로운 latent를 생성한 후 VAE Decoder$(x=D(z))$를 사용하여 pixel 이미지로 디코딩
Evaluation metrics.
- FID를 사용하여 scaling performance를 측정함
- Inception Score, sFID 및 Precision/Recall을 보조 측정항목으로 사용함
Experiments
DiT block design.
다양한 DiT block design에 대해 FID 성능을 비교했다. 결과는 아래 그림과 같다.
Scaling model size and patch size.
다양한 모델 config(S, B, L, XL)와 patch size(8, 4, 2)에 대한 FID 비교 결과이다.
다음은 DiT 모델의 다양한 전략에 대한 실험에 대한 결과이다.
5.1. State-of-the-Art Diffusion Models
다음으로 다양한 SOTA class-conditional generative model과 비교했다. Bubble area는 diffusion 모델의 flops를 나타낸다고 한다.
256×256 ImageNet, 512×512 ImageNet에서 SOTA class-conditional generative model과 수치적으로 비교했다.