[논문 리뷰] Rich Human Feedback for Text-to-Image Generation
Google Research
CVPR 2024
Best paper
- Paper: https://arxiv.org/abs/2312.10240
- Git: https://github.com/google-research/google-research/tree/master/richhf_18k
📖 핵심 훑어보기 !!
- T2I Generation 모델에서 artifact, misalignment등의 문제를 해결하기 위한 데이터셋과 human annotation을 예측하는 모델 제안
- RichHF-18K: 18K 이미지로 구성된 Rich Human Feedback 데이터셋 수집, 이미지의 비현실성/결함 및 text-image 정렬 오류를 강조하는 annotation 포함
- RAHF 모델: 멀티모달 트랜스포머 모델을 사용하여 이미지와 관련된 text prompt에 대한 human annotation을 예측
Introduction
Text-to-Image(T2I) Generation 모델은 최근 상당한 발전을 이루었고, 다양한 도메인에서 사용되고 있다. 하지만 여전히 artifacts/implausibility(비현실성), text와의 misalignment(정렬 오류)와 같은 문제를 겪고 있다. 실제로 Pick-a-Pic 데이터셋에서 생성된 이미지의 약 10%만이 artifact와 implausibility가 없는 것으로 나타났다고 한다.
그렇다면 기존에 사용한 방법들은 어땠을까?
- 기존의 IS, FID와 같은 metric → 이미지의 전체 분포에 따라 계산되며 개별 이미지의 문제를 반영하기 쉽지않음
- 인간 선호도/평가 반영 → 이미지의 품질을 단일 숫자 점수로 요약함
- CLIPScore와 같은 prompt-image alignment → 이미지의 misalignment 영역을 localize하지 못함, 모델이 복잡함
빨간 점: artifact/implausibility region, 파란 점: misaligned region
본 논문에서는 다면적 평가가 가능하고 해석가능한 모델과 데이터셋을 제안한다.
- RichHF-18K: 18K 이미지로 구성된 Rich Human Feedback 데이터셋 수집
- 이미지 내의 implausibility/artifact 및 text-image misalignment 영역을 강조하는 point annotation 포함
- 생성된 이미지에서 누락되거나 잘못 표현된 개념을 지칭하는 부분(단어)에 대해 prompt에 라벨링 됨
- 이미지 타당성, text-image alignment, aesthetics(미학) 및 전반적인 평가에 대한 4가지 유형의 세분화된 점수 포함
- Rich Automatic Human Feedback (RAHF): multimodal transformer 모델, 생성된 이미지와 관련 text prompt에 대한 human annotation을 예측하는 방법을 학습.
논문의 저자는 RAHF에 의한 human feedback이 이미지 생성을 개선하는 데 유용하다는 것을 보여준다. 예측된 heatmap을 마스크로 사용하여 문제가 있는 이미지 영역을 다시 칠하거나, 예측된 점수를 사용하여 생성 모델을 fine-tuning하는데 도움이 된다고 한다.
Collecting rich human feedback
1. Data collection process
먼저, 데이터셋에는 다음과 같은 내용이 포함된다.
- 2개의 heatmap: artifact/implausibility, misalignment
- 4개의 세분화된 점수: plausibility(타당성), alignment(정렬), aesthetics(미학), 전체 점수
- 1개의 text sequence: misaligned keywords
데이터셋 수집 과정은 다음과 같다.
- 각 생성된 이미지에 대해 먼저 annotator는 이미지를 검사하고 이를 생성하는데 사용된 prompt를 읽는다.
- Text prompt와 관련하여 implausibility/artifact 또는 misalignment의 위치를 나타내기 위해 이미지에 포인트를 표시한다.
- 표시된 각 포인트에는 “effective radius”(유효 반경, 이미지 H의 1/20)가 있어서, 포인트를 중심으로 특정 공간(디스크)만큼 유효함. 이런식으로 비교적 적은 포인트를 사용하여 결함이 있는 이미지 영역을 커버할 수 있음.
- 마지막으로 annotator는 misaligned keyword와 4가지 점수(plausibility, image-text alignment, aesthetic, 전체적인quality) 대해 5-point Likert scale로 평가를 진행
2. Human feedback consolidation
Human feedback의 신뢰성을 위해 각 image-text pair에 대해 3명의 annotator가 주석을 달았다. 3명의 annotation을 아래와 같은 방법으로 통합했다.
- 점수는 평균하여 최종 점수로 사용
- misaligned keyword annotation의 경우 다수결 투표를 통해 aligned/misaligned를 정함
- point annotation의 경우, 각 포인트에 대해 heatmap 디스크 영역으로 변환한 다음, 평균 heatmap을 계산
(명백히 비현실적인 영역은 모든 annotator가 표시할 가능성이 높고, 최종 heatmap에서 높은 값을 가짐)
3. RichHF-18K: a dataset of rich human feedback
Pick-a-Pic 데이터셋에서 data annotation을 위해 image-text pair의 subset을 추출했다. 기준은 아래와 같다. Pick-a-Pic에서 총 18,000개의 image-text pair에 대한 rich human feedback을 수집했다.
- 광범위한 응용 프로그램을 위해 대부분을 사실적인 이미지(photo-realistic)로 선택
- balanced category를 위해 PaLI visual question answering (VQA) 모델을 사용하여 basic feature를 추출
- 각 image-text pair에 대해 1) 이미지가 사실적인지, 2) 이미지의 category가 무엇인지(‘인간’, ‘동물’, ‘사물’, ‘실내 장면’, ‘실외 장면’ 중 하나)를 추출
- 각 image-text pair에 대해 1) 이미지가 사실적인지, 2) 이미지의 category가 무엇인지(‘인간’, ‘동물’, ‘사물’, ‘실내 장면’, ‘실외 장면’ 중 하나)를 추출
4. Data statistics of RichHF-18K
데이터셋의 점수에 대한 통계조사를 수행했다. 점수의 histogram은 위와 같다. 점수는 Gaussian 분포와 유사하다.
추가로 image-text pair에 대한 annotator들의 평가 일치를 분석하기 위해 점수 간의 최대 차이를 계산했다. 위의 그림은 차이에 대한 histogram이다. 약 25%의 샘플이 완벽한 일치를 보이고 약 85%의 샘플이 양호한 일치를 보이는 것을 볼 수 있다.
Predicting rich human feedback
1. Model
Architecture.
ViT와 T5X 모델을 기반으로, Spotlight 모델 아키텍처에서 영감을 받아 수정한 형태의 구조를 사용했다고 한다.
- ViT
- 생성된 이미지를 입력으로 사용하여 high-level representation 출력 token 생성
- 생성된 이미지를 입력으로 사용하여 high-level representation 출력 token 생성
- Text Embed
- text prompt token을 입력으로 받아 dense vector로 embed
- text prompt token을 입력으로 받아 dense vector로 embed
- Self-attention module (T5X self-attention encoder)
- 이미지 token과 text token을 concat하여 self-attention module에 입력으로 사용
- 이미지 token은 score prediction 및 heatmap prediction에 주로 사용되는데(그림의 구조, output 참고), 이 때 self attention module을 통해 text information을 전달받아 사용함. 반대로, Text token은 text misalignment keyword에 주로 사용되는데, self-attention module을 통해 vision information을 전달받아 vision-aware text encoding이 가능하게 된다.
- self-attention module을 통해 인코딩된 fused text and image token은 세 종류의 predictor를 사용하여 각각 다른 output(Heatmap, Score, Text)을 예측
- Predictor
- Heatmap prediction: 이미지 token은 reshape되고 convolution, deconvolution layer, sigmoid를 거쳐 implausibility(비현실성) 및 heatmap을 출력함
- Score prediction: convolution, linear layer, sigmoid를 거쳐 세분화된 점수를 출력
- Keyword misalignment sequence(Text) prediction: 이미지 생성을 위해 사용된 prompt를 모델에 대한 입력으로 사용한다. 수정된 prompt는 T5X Decoder의 prediction target이 된다. 수정된 prompt에서 misaligned token에 대해 특수 접미사 (’_0’)를 붙여 표시한다.
Model variants.
Heatmap, Score prediction head에 대한 두 가지 model variant에 대해 조사했다.
- Multi-head: 여러 개의 heatmap과 score를 예측하기 위해, multiple prediction head를 사용할 수 있다. 각 score와 heatmap type에 대해 각각의 head를 사용하여야 하므로 총 7개의 head가 필요하다.
- Augmented prompt: 다른 방식으로는 각 prediction 유형에 대해 하나의 head를 사용하는 것이다. 즉, heatmap, score, misalignment sequence에 대해 하나씩, 총 3개를 사용하는 것이다. 세부적인 type 분류를 위해서는 prompt에 각 task string을 추가하는 방식으로 augmentation을 수행한다.(e.g. ‘implausibility heatmap’)
Model optimization
- Heatmap prediction: pixel-wise mean squared error (MSE) loss
- Score prediction: MSE loss
- Keyword misalignment sequence(Text) prediction: teacher-forcing cross-entropy loss
각 prediction module에 대한 loss function은 위와 같다. 전체 loss는 위의 세 가지 loss function에 대한 weighted combination이다.
2. Experiments
Evaluation metrics
- Score prediction task: Score prediction을 위한 일반적인 평가 지표인 Pearson linear correlation coefficient (PLCC)와 Spearman rank correlation coefficient (SRCC)사용
- Heatmap prediction task: Empty ground truth가 있을 수 있으므로 모든 샘플에 대해 MSE를 측정하고, non-empty ground truth 샘플에 대해서는 NSS/KLD/AUC-Judd/SIM/CC 지표를 사용
- Misaligned keyword sequence prediction task: Token-level precision, recall, and F1-score
Prediction result on RichHF-18K test set
RichHF-18K 테스트셋에서 네 가지 세분화된 Score와 implausibility(비현실적) heatmap, misalignment heatmap, and misalignment keyword sequence에 대한 예측 실험 결과는 위 표와 같다.
그림 5,6의 결과는 implausibility(비현실적) heatmap prediction 예제를 나타낸다. 여기서 논문의 모델은 artifact/implausibility가 있는 영역을 식별하고, misalignment heatmap에 대한 모델에서 prediction 예시를 보여준다. 또한 prompt에 해당하지 않는 객체를 식별한다.
그림 7은 몇 가지 이미지 예와 실제 score 및 prediction score를 보여준다.
Learning from rich human feedback
다음으로는 human feedback을 사용하여 이미지 생성을 개선할 수 있는지 조사했다. RAHF 모델로 생성된 데이터셋에 대해 fine-tuning 하기 위해 masked transformer 구조 기반 Muse 모델을 대상 모델로 사용했다.
Finetuning generative models with predicted scores
먼저 pretrain된 Muse 모델을 사용하여 12,564개의 prompt에 대해 각각 8개의 이미지를 생성했다. 각 이미지에 대한 RAHF score를 예측하고 각 prompt의 이미지가 threshold를 넘으면 finetuning 데이터셋으로 선정한다. 그 다음 이 데이터셋으로 Muse를 finetuning한다.
위 그림에서 예측된 plausibility score로 Muse를 finetuning하는 예시를 볼 수 있다. Muse finetuning의 이득을 정량화하기 위해 100개의 새로운 prompt를 사용하여 이미지를 생성하고 6명의 annotator에게 원래 Muse와 finetuning된 Muse의 두 이미지를 나란히 비교(for plausibility, 타당성)하도록 요청했다. Annotator는 이미지 A/B를 생성하는 데 사용된 모델을 알지 못한 채, 다섯 가지 응답(이미지 A는 이미지 B보다 상당히/약간 더 좋음, 거의 동일, 이미지 B는 이미지 A보다 약간/상당히 좋음) 중에서 선택한다.
표 5의 결과는 RAHF plausibility score가 있는 finetuning된 Muse가 원래 Muse보다 훨씬 적은 artifacts/implausibility을 가지고 있음을 보여준다.
또한 그림 8(c)-(d)에서 RAHF aesthetic score를 Latent Diffusion 모델에 대한 Classifier Guidance으로 사용하는 예를 보여준다. 이는 각 세분화된 score가 생성 모델/결과의 다양한 측면을 개선할 수 있음을 보여준다.
Region inpainting with predicted heatmaps and scores
다음으로는 모델의 예측된 heatmap과 score를 사용하여 생성된 이미지의 품질을 개선하기 위한 region inpainting을 수행할 수 있음을 보여준다. 각각의 이미지에 대해 먼저 implausibility(비현실적) heatmap을 예측한 다음, threshold 설정 및 dilating을 사용하여 heatmap을 처리하여 mask를 만든다. Mask된 영역 내에서 Muse inpating을 적용하여 text prompt와 일치하는 새 이미지를 생성한다. 여러 이미지가 생성되고, 최종 이미지는 RAHF에서 예측된 가장 높은 plausibility score로 선택된다.
위의 그림에서 이러한 inpainting 결과를 확인할 수 있다. 그림을 보면 알 수 있듯, inpainting 작업 후 artifact가 적은 그럴듯한 이미지가 생성된다. 이는 다시 말해, RAHF가 RAHF를 학습하는 데 사용된 이미지와는 매우 다른 생성 모델의 이미지에 잘 일반화됨을 보여준다.