[논문 리뷰] Scalable Pre-training of Large Autoregressive Image Models (AIM)
Apple
arXiv 2024
Introduction
Pre-training task agnostic model은 최근 NLP의 표준이 되었다. 이러한 모델은 복잡한 추론 작업을 해결하고 ChatGPT와 같이 AI assistant로 널리 사용되고 있다. 성공의 핵심요소로는 capacity(parameter 수), pre-training data의 증가에 따라 향상되는 능력으로 볼 수 있다.
이러한 모델의 확장은 두 가지 측면에서 중요하다.
- 모델은 과거를 고려하여 문장의 다음 단어를 예측하는 간단한 목표로 훈련되었지만 긴 context에 걸쳐 복잡한 패턴을 학습할 수 있다.
- Autoregressive objective의 scalability(확장성)는 특정 아키텍처, 특히 Transformer와 함께 사용될 때 주로 관찰되며 시너지를 낸다.
이러한 요소는 language modeling에만 국한되는 것이 아니다. 또한 최근 ViT(Vision Transfomer)의 성공은 Transformer architecture가 computer vision에서도 성공적으로 사용될 수 있음을 보여준다. 따라서 본 논문에서는 LLM의 결과를 일반화하기 위한 첫 번째 단계로 autoregressive objective를 사용하여 ViT 모델을 학습하여 경쟁력 있는 성능을 얻을 수 있는지 살펴본다.
본 논문에서는 visual feature를 위한 large-scale pre-training을 위해 autoregressive 접근 방법을 사용한 Autoregressive Image Models (AIM)을 제안했다. Vision transformer, large-scale web data, LLM pre-training과 같은 방법을 사용하여 기존의 iGPT와 같은 방법을 재검토했다. 또한 autoregressive pre-training을 visual feature에 적용하기 위한 두 가지 architecture 수정사항을 도입했다.
저자는 선별되지 않은(uncurated) 2B 이미지를 사용하여 600M - 7B의 parameter를 가지는 모델에 대해 연구했다. AIM은 이러한 이미지 대규모 모델에 대해 saturation 없이 지속적인 성능 향상을 이끌어냈다. 전반적인 결과는 large language model의 scaling 연구와 일치한다.
Pre-training
Data filtering networks📄 논문에서 소개한 DFN 데이터셋에 대해 pre-train을 진행했다. 데이터셋은 Common Crawl에서 필터링된 12.8B image-text pair로 구성되며 부적절한 콘텐츠 제거, 얼굴 blur, 중복 제거 등의 pre-process를 진행했다. Data filtering network에서 image와 caption 간의 alignment score를 측정하여 샘플 순위를 매긴 후, 12.8B 데이터내에서 상위 15% 샘플을 선정하여 DFN-2B(subset) 데이터셋이 추출된다. (Privacy 및 safety filter 외에 이미지 content를 기반으로 한 curation은 포함되지 않았다.)
Pre-train 중에 LLM에서 사용하는 기존 방법을 따라 p = 0.8의 확률로 DFN-2B에서 이미지를 샘플링하고 p = 0.2의 확률로 ImageNet-1k에서 이미지를 샘플링했다. 이러한 데이터셋을 DFN-2B+라고 한다.
Approach
1. Training Objective
논문의 training objective는 표준 autoregressive model을 따른다. 이미지 $x$가 주어지면 $K$개의 non-overlapping patch grid $x_k , k \in [1, K]$로 분할되어 token을 형성한다. 이 때 sequence 순서는 모든 이미지에서 고정되어있다고 가정하고 일반적으로 raster ordering(위→아래, 왼→오)을 사용한다. 이 때 이미지 하나의 확률은 다음과 같다.
\[P(x) = \prod^K_{k=1} P(x_k∣x_{<k})\]여기서 $x_{<k}$는 $k-1$까지의 patch 집합을 나타내며 $k$번째 패치를 예측하는 데 사용되는 context이다(현재 순서까지의 patch를 사용해 다음 patch를 예측하는 것을 말한다). 그 다음 이미지 전체 $\mathcal{X}$에 대한 training loss은 negative log-likelihood(NLL)로 정의된다.
\[\sum_{x \in \mathcal{X}} \sum^K_{k=1} -\log P(x_k∣x_{<k})\]Prediction loss
위의 training objective는 분포 $P(x_k ∣ x_{<k})$를 정의함에 따라 다양하게 변형된다. 논문에서는 기본적으로 Masked autoencoders are scalable vision learners📄 논문과 유사하게 normalized pixel-level regression loss를 사용한다.
이 Loss는 $P(x_k ∣ x_{<k})$를 일정한 variance를 갖는 Gaussian 분포로 가정한다.
즉, $\hat{x}_k(\theta)$가 $\theta$로 매개변수화된 네트워크의 $k$번째 patch prediction이고 $x_k$가 해당 ground-truth value인 경우, 목표는 prediction과 ground-truth 사이의 sum $\ell_2$ squared distance를 minimize하는 것이다.
2. Architecture
Backbone으로 ViT(Vision Transformer)를 사용했다. Model capacity scaling을 위해 language modeling의 일반적인 방법을 따르고 depth 확장보다는 width 확장을 우선시한다. AIM의 design parameter에 대한 overview는 아래 표와 같다.
Pre-training 중 이전 patch가 주어지면, Self-attention layer에 causal mask를 적용하여 다음 patch의 확률을 모델링한다. Patch $i$에 대한 embedding은 아래와 같이 계산된다.
\[y_i = \sum^K_{k=1} a_{ik}v_i\]여기서 $a_{ik}$는 attention weight이고 $v_k$는 value embedding이다. $k > i$에 대해 $a_{ik} = 0, \sum ^K_{k=1} a_{ik} = 1$로 설정하여 이전 sequence의 patch만을 보도록 하는 casual mask를 적용하였다(이후의 patch는 보지 않음). 즉, training 중 이미지는 single forward pass로 처리된다.
Prefix Transformer.
Pre-training 중 self-attention에는 causal mask를 사용했지만, 표준 ViT 모델의 down-stream task에서는 bidirectional self-attention을 필요로 한다. 이러한 불일치는 성능 저하로 이어지게 되므로 이 문제를 해결하기 위해 PrefixLM📄 논문과 같은 방법을 사용한다.
초기 일부분의 patch를 prefix로 간주하고, 나머지 patch에서 볼 수 있도록(나머지 patch를 예측하기 위한 context로 활용) casual mask를 제거한다(방법: prefix length size $S ∈ [1, K − 1],\; k < S$에 대해 $a_{i,k} > 0$). 이 방법을 통해 causal masking 없이도 모델이 동작할 수 있고 downstream task를 위한 추가작업 없이 성능을 올릴 수 있다.
MLP prediction heads.
Network가 pre-training objective에 특화되는 것을 방지하기 위해 일반적으로 pre-training 중에 특정한 prediction head를 추가한다. 본 논문에서는 transformer 위에 N block의 MLP를 사용하여 각 patch를 독립적으로 처리했다.
Straightforward implementation.
- AIM에서는 LayerScale, stochastic depth, QK-Norm, freezing과 같은 optimization 안정성 유도 작업이 필요하지 않음
- Transformer 앞, MLP head 앞에 sinusoidal positional embedding을 추가함
- Transformer, Head에 사용되는 MLP는 expansion ratio 4를 사용함
- 기존의 ViT와 달리 입력에 classification token을 사용하지 않음
Downstream adaptation.
본 논문에서는 down-stream task를 위해 model weight을 fix하고 classification head에 대해서만 training을 진행했다.
Pre-training 중 loss는 각 patch에 대해 독립적으로 계산되었고 Image-level token은 존재하지 않았다. Image-level prediction을 위한 global한 정보를 얻기 위해 일반적으로 patch feature에서 global average pooling을 사용한다. 하지만 AIM은 linear classifier 앞에 attention pooling operation를 사용했다.
patch features 집합 $P = \lbrace p_i ∣ 1 ≤ i ≤ K\rbrace$가 주어질 때, multi-head attention pooling을 사용한 global descriptor $\hat{p}$ 는 다음과 같이 정의된다.
각 attention head $h = \lbrace 1, …, H \rbrace$에 대해 $W^h_k, W^h_v ∈ R^{d_h \times d}$ 는 각각 key, value weight을 나타낸다. $q_h$ 는 learnable한 query vector이다.
위 식의 결과로 linear classifier 입력이 되는 pooled feature $\hat{p} = [p_1 , …, p_H ], \hat{p} \in R^d$ 를 얻을 수 있다. 이러한 attention pooling을 사용하면 전체 작업이 엄격하게 linear하게 되지 않으므로 이를 "Attentive Probe"라고 부른다. 그럼에도 불구하고 linear probing의 장점(예: 적은 추가 parsmeter 수 및 overfitting 위험 감소)은 이 probe에서도 그대로 유지된다.
Results
1. Impact of scaling
저자는 model parameter 및 training data 측면에서 scaling의 영향에 대해 조사했다. 실험은 IN-1k 데이터셋의 validation에 대해 수행되었다.
각 모델에 대해 training iteration에 따른 pre-training loss 값과 validation set의 classification accuracy를 측정하였다. 결과는 위의 그래프와 같이 전체 training 동안 개선되는 것을 볼 수 있다. 또한 모델의 capacity를 scaling함에 따라 down-stream 작업의 loss 값과 accuracy가 향상되는 것을 볼 수있다.
위의 그래프는 1M 이미지의 작은 curated 데이터셋(예: IN-1k) 또는 큰 2B 이미지 세트(예: DFN-2B+)를 pre-train할 때 validation loss의 진행을 나타낸다. IN-1k에 대한 학습은 빠르게 validation loss가 줄어들지만 overfitting을 보이고 있다. 반면 uncurated DFN-2B 데이터셋은 validation loss가 빨리 줄어들지는 않지만 overfitting이 일어나지 않았다.
위에서 언급한 대로 동일한 데이터셋이 규모가 작은 IN-1k 데이터로 augment되면 결국 IN-1k에 대한 pre-train을 진행한 것 보다 좋은 성능을 보이는 것을 알 수 있다.
2. Architecture and Design
- Targets and objective (a): target patch에 대한 다양한 representation에 대해 조사했다.
- Autoregression pattern (b): Autoregressive pre-training을 위한 patch 순서를 정하는 방식에 대해 조사했다. (연관 1.) 아래의 Figure 7.에서 각각의 방법에 대한 patch prediction difficulty를 조사했다.
- Cropping scale (c): cropping scale의 하한을 조정하여 각 patch의 정보 content가 미치는 영향에 대해 조사했다.
- Causal vs. Prefix Attention (d): 표준 causal attention를 사용하는 것과 반대로 pre-train중에 prefix attention를 통합하는 것의 영향을 측정했다.
- Head design (e): pixel level prediction을 위해 backbone의 top에 있는 다양한 유형의 head를 고려했다.
- Deeper vs. Wider architecture (f): Depth가 Width보다 더 빠르게 확장되는 ViT의 원래 디자인과 달리 Llama와 유사한 scaling 전략을 채택했다. 위의 표 3f에서 wide한 아키텍처의 효율성을 검증했다. 작은 규모의 AIM-0.6B 모델의 경우에도 wide한 아키텍처가 좋은 성능을 제공할 뿐만 아니라 훈련 안정성도 향상시키는 것으로 나타났다.
3. Pre-training objective
Autoregressive vs. Masking: autoregressive objective와 masking objective로 학습된 모델에 대해 조사했다. 위의 표는 AIM이 masking objective보다 autoregressive objective에서 더 나은 성능을 발휘한다는 것을 보여준다.
4. Comparison with other methods
15가지 benchmark에 대해 SOTA 방법들과 비교했다. 결과는 아래와 같다.