Post

[논문 리뷰] Scalable Pre-training of Large Autoregressive Image Models (AIM)



Apple arXiv 2024

Introduction


Pre-training task agnostic model은 최근 NLP의 표준이 되었다. 이러한 모델은 복잡한 추론 작업을 해결하고 ChatGPT와 같이 AI assistant로 널리 사용되고 있다. 성공의 핵심요소로는 capacity(parameter 수), pre-training data의 증가에 따라 향상되는 능력으로 볼 수 있다.

이러한 모델의 확장은 두 가지 측면에서 중요하다.

  1. 모델은 과거를 고려하여 문장의 다음 단어를 예측하는 간단한 목표로 훈련되었지만 긴 context에 걸쳐 복잡한 패턴을 학습할 수 있다.
  2. Autoregressive objective의 scalability(확장성)는 특정 아키텍처, 특히 Transformer와 함께 사용될 때 주로 관찰되며 시너지를 낸다.

이러한 요소는 language modeling에만 국한되는 것이 아니다. 또한 최근 ViT(Vision Transfomer)의 성공은 Transformer architecture가 computer vision에서도 성공적으로 사용될 수 있음을 보여준다. 따라서 본 논문에서는 LLM의 결과를 일반화하기 위한 첫 번째 단계로 autoregressive objective를 사용하여 ViT 모델을 학습하여 경쟁력 있는 성능을 얻을 수 있는지 살펴본다.


AIM_1.png

본 논문에서는 visual feature를 위한 large-scale pre-training을 위해 autoregressive 접근 방법을 사용한 Autoregressive Image Models (AIM)을 제안했다. Vision transformer, large-scale web data, LLM pre-training과 같은 방법을 사용하여 기존의 iGPT와 같은 방법을 재검토했다. 또한 autoregressive pre-training을 visual feature에 적용하기 위한 두 가지 architecture 수정사항을 도입했다.

저자는 선별되지 않은(uncurated) 2B 이미지를 사용하여 600M - 7B의 parameter를 가지는 모델에 대해 연구했다. AIM은 이러한 이미지 대규모 모델에 대해 saturation 없이 지속적인 성능 향상을 이끌어냈다. 전반적인 결과는 large language model의 scaling 연구와 일치한다.





Pre-training


Data filtering networks📄 논문에서 소개한 DFN 데이터셋에 대해 pre-train을 진행했다. 데이터셋은 Common Crawl에서 필터링된 12.8B image-text pair로 구성되며 부적절한 콘텐츠 제거, 얼굴 blur, 중복 제거 등의 pre-process를 진행했다. Data filtering network에서 image와 caption 간의 alignment score를 측정하여 샘플 순위를 매긴 후, 12.8B 데이터내에서 상위 15% 샘플을 선정하여 DFN-2B(subset) 데이터셋이 추출된다. (Privacy 및 safety filter 외에 이미지 content를 기반으로 한 curation은 포함되지 않았다.)

Pre-train 중에 LLM에서 사용하는 기존 방법을 따라 p = 0.8의 확률로 DFN-2B에서 이미지를 샘플링하고 p = 0.2의 확률로 ImageNet-1k에서 이미지를 샘플링했다. 이러한 데이터셋을 DFN-2B+라고 한다.





Approach


1. Training Objective

논문의 training objective는 표준 autoregressive model을 따른다. 이미지 $x$가 주어지면 $K$개의 non-overlapping patch grid $x_k , k \in [1, K]$로 분할되어 token을 형성한다. 이 때 sequence 순서는 모든 이미지에서 고정되어있다고 가정하고 일반적으로 raster ordering(위→아래, 왼→오)을 사용한다. 이 때 이미지 하나의 확률은 다음과 같다.

\[P(x) = \prod^K_{k=1} P(x_k∣x_{<k})\]

여기서 $x_{<k}$는 $k-1$까지의 patch 집합을 나타내며 $k$번째 패치를 예측하는 데 사용되는 context이다(현재 순서까지의 patch를 사용해 다음 patch를 예측하는 것을 말한다). 그 다음 이미지 전체 $\mathcal{X}$에 대한 training loss은 negative log-likelihood(NLL)로 정의된다.

\[\sum_{x \in \mathcal{X}} \sum^K_{k=1} -\log P(x_k∣x_{<k})\]



Prediction loss

위의 training objective는 분포 $P(x_k ∣ x_{<k})$를 정의함에 따라 다양하게 변형된다. 논문에서는 기본적으로 Masked autoencoders are scalable vision learners📄 논문과 유사하게 normalized pixel-level regression loss를 사용한다.

이 Loss는 $P(x_k ∣ x_{<k})$를 일정한 variance를 갖는 Gaussian 분포로 가정한다.
즉, $\hat{x}_k(\theta)$가 $\theta$로 매개변수화된 네트워크의 $k$번째 patch prediction이고 $x_k$가 해당 ground-truth value인 경우, 목표는 prediction과 ground-truth 사이의 sum $\ell_2$ squared distance를 minimize하는 것이다.

\[\min_\theta \frac{1}{K} \sum^K_{k=1} \Vert \hat{x}_k(\theta) - x_k \Vert^2_2.\]






2. Architecture

AIM_2.png

Backbone으로 ViT(Vision Transformer)를 사용했다. Model capacity scaling을 위해 language modeling의 일반적인 방법을 따르고 depth 확장보다는 width 확장을 우선시한다. AIM의 design parameter에 대한 overview는 아래 표와 같다.

AIM_3.png



Pre-training 중 이전 patch가 주어지면, Self-attention layer에 causal mask를 적용하여 다음 patch의 확률을 모델링한다. Patch $i$에 대한 embedding은 아래와 같이 계산된다.

\[y_i = \sum^K_{k=1} a_{ik}v_i\]

여기서 $a_{ik}$는 attention weight이고 $v_k$는 value embedding이다. $k > i$에 대해 $a_{ik} = 0, \sum ^K_{k=1} a_{ik} = 1$로 설정하여 이전 sequence의 patch만을 보도록 하는 casual mask를 적용하였다(이후의 patch는 보지 않음). 즉, training 중 이미지는 single forward pass로 처리된다.



Prefix Transformer.

AIM_4.png

Pre-training 중 self-attention에는 causal mask를 사용했지만, 표준 ViT 모델의 down-stream task에서는 bidirectional self-attention을 필요로 한다. 이러한 불일치는 성능 저하로 이어지게 되므로 이 문제를 해결하기 위해 PrefixLM📄 논문과 같은 방법을 사용한다.

초기 일부분의 patch를 prefix로 간주하고, 나머지 patch에서 볼 수 있도록(나머지 patch를 예측하기 위한 context로 활용) casual mask를 제거한다(방법: prefix length size $S ∈ [1, K − 1],\; k < S$에 대해 $a_{i,k} > 0$). 이 방법을 통해 causal masking 없이도 모델이 동작할 수 있고 downstream task를 위한 추가작업 없이 성능을 올릴 수 있다.



MLP prediction heads.

Network가 pre-training objective에 특화되는 것을 방지하기 위해 일반적으로 pre-training 중에 특정한 prediction head를 추가한다. 본 논문에서는 transformer 위에 N block의 MLP를 사용하여 각 patch를 독립적으로 처리했다.



Straightforward implementation.

  • AIM에서는 LayerScale, stochastic depth, QK-Norm, freezing과 같은 optimization 안정성 유도 작업이 필요하지 않음
  • Transformer 앞, MLP head 앞에 sinusoidal positional embedding을 추가함
  • Transformer, Head에 사용되는 MLP는 expansion ratio 4를 사용함
  • 기존의 ViT와 달리 입력에 classification token을 사용하지 않음



Downstream adaptation.

본 논문에서는 down-stream task를 위해 model weight을 fix하고 classification head에 대해서만 training을 진행했다.

Pre-training 중 loss는 각 patch에 대해 독립적으로 계산되었고 Image-level token은 존재하지 않았다. Image-level prediction을 위한 global한 정보를 얻기 위해 일반적으로 patch feature에서 global average pooling을 사용한다. 하지만 AIM은 linear classifier 앞에 attention pooling operation를 사용했다.

patch features 집합 $P = \lbrace p_i ∣ 1 ≤ i ≤ K\rbrace$가 주어질 때, multi-head attention pooling을 사용한 global descriptor $\hat{p}$ 는 다음과 같이 정의된다.

\[\hat{p_h} = \sum^K_{i=1}\frac{\exp(q^T_h W^k_hp_i)}{\sum^K_{j=1}\exp(q^T_h W^k_hp_j)}W^v_h p_i\]



각 attention head $h = \lbrace 1, …, H \rbrace$에 대해 $W^h_k, W^h_v ∈ R^{d_h \times d}$ 는 각각 key, value weight을 나타낸다. $q_h$ 는 learnable한 query vector이다.

위 식의 결과로 linear classifier 입력이 되는 pooled feature $\hat{p} = [p_1 , …, p_H ], \hat{p} \in R^d$ 를 얻을 수 있다. 이러한 attention pooling을 사용하면 전체 작업이 엄격하게 linear하게 되지 않으므로 이를 "Attentive Probe"라고 부른다. 그럼에도 불구하고 linear probing의 장점(예: 적은 추가 parsmeter 수 및 overfitting 위험 감소)은 이 probe에서도 그대로 유지된다.





Results


1. Impact of scaling

저자는 model parameter 및 training data 측면에서 scaling의 영향에 대해 조사했다. 실험은 IN-1k 데이터셋의 validation에 대해 수행되었다.

AIM_7.png

각 모델에 대해 training iteration에 따른 pre-training loss 값과 validation set의 classification accuracy를 측정하였다. 결과는 위의 그래프와 같이 전체 training 동안 개선되는 것을 볼 수 있다. 또한 모델의 capacity를 scaling함에 따라 down-stream 작업의 loss 값과 accuracy가 향상되는 것을 볼 수있다.



AIM_8.png

위의 그래프는 1M 이미지의 작은 curated 데이터셋(예: IN-1k) 또는 큰 2B 이미지 세트(예: DFN-2B+)를 pre-train할 때 validation loss의 진행을 나타낸다. IN-1k에 대한 학습은 빠르게 validation loss가 줄어들지만 overfitting을 보이고 있다. 반면 uncurated DFN-2B 데이터셋은 validation loss가 빨리 줄어들지는 않지만 overfitting이 일어나지 않았다.

위에서 언급한 대로 동일한 데이터셋이 규모가 작은 IN-1k 데이터로 augment되면 결국 IN-1k에 대한 pre-train을 진행한 것 보다 좋은 성능을 보이는 것을 알 수 있다.

AIM_9.png




2. Architecture and Design

AIM_10.png

  • Targets and objective (a): target patch에 대한 다양한 representation에 대해 조사했다.
  • Autoregression pattern (b): Autoregressive pre-training을 위한 patch 순서를 정하는 방식에 대해 조사했다. (연관 1.) 아래의 Figure 7.에서 각각의 방법에 대한 patch prediction difficulty를 조사했다.
  • Cropping scale (c): cropping scale의 하한을 조정하여 각 patch의 정보 content가 미치는 영향에 대해 조사했다.
  • Causal vs. Prefix Attention (d): 표준 causal attention를 사용하는 것과 반대로 pre-train중에 prefix attention를 통합하는 것의 영향을 측정했다.
  • Head design (e): pixel level prediction을 위해 backbone의 top에 있는 다양한 유형의 head를 고려했다.
  • Deeper vs. Wider architecture (f): Depth가 Width보다 더 빠르게 확장되는 ViT의 원래 디자인과 달리 Llama와 유사한 scaling 전략을 채택했다. 위의 표 3f에서 wide한 아키텍처의 효율성을 검증했다. 작은 규모의 AIM-0.6B 모델의 경우에도 wide한 아키텍처가 좋은 성능을 제공할 뿐만 아니라 훈련 안정성도 향상시키는 것으로 나타났다.

AIM_11.png




3. Pre-training objective

AIM_12.png

Autoregressive vs. Masking: autoregressive objective와 masking objective로 학습된 모델에 대해 조사했다. 위의 표는 AIM이 masking objective보다 autoregressive objective에서 더 나은 성능을 발휘한다는 것을 보여준다.




4. Comparison with other methods

15가지 benchmark에 대해 SOTA 방법들과 비교했다. 결과는 아래와 같다.

AIM_13.png




This post is licensed under CC BY 4.0 by the author.