Post

[논문 리뷰] Segment Anything (SAM)



Meta AI Research, FAIR arXiv 2023

Introduction


NLP에서는 web-scale 데이터셋에 대해 학습된 Large language model이 강력한 zero-shot/few-shot generalization 성능을 보여주고 있다. 이러한 Foundation model은 학습에 사용되는 task/data를 넘어서는 generalization이 가능하다.

Computer vision에서도 Foundation model에 대한 연구가 진행되었다. Text-image pair를 alignment하는 방법을 사용한 CLIP 및 ALIGN과 같은 연구는 contrastive learning을 사용하여 text 및 image encoder를 학습시켰다. 이러한 encoder는 Image Generation과 같은 down-stream task를 위해 다른 모듈과 함께 사용된다. Vision, language encoder에 대한 많은 진전이 있었지만 computer vision에는 이를 넘어서는 광범위한 문제가 포함되어 있으며, 풍부한 데이터셋 또한 존재하지 않는다.

SAM_1.png

본 논문의 목표는 Image Segmentation을 위한 Foundation model을 구축하는 것이다. 저자는 강력한 generalization을 위한 promptable model을 개발하고 광범위한 데이터셋에 대해 이 모델을 pre-train한다. 이 모델을 통해 prompt engineering을 사용하여 학습하지 않은 새로운 데이터 분포에 대한 문제(segmentation)를 해결하는 것을 목표로 한다.

이를 위해 저자는 문제를 세 가지 구성 요소로 구분하였다. Image Segmentation에 대한 다음 질문을 해결하는 것을 목표로 한다.

  1. What task will enable zero-shot generalization?
  2. What is the corresponding model architecture?
  3. What data can power this task and model?


이러한 질문은 서로 연관되어있으므로 포괄적인 솔루션이 필요하다. 강력한 pre-training objective를 제공하고 넓은 범위의 downstream application이 가능하게 하는 promptable segmentation task를 정의하는 것 부터 시작한다. 이를 위해서는 flexible prompting을 지원하고 prompt를 사용하여 interactive하게 real-time으로 sementation output을 출력할 수 있는 model이 필요하다. 이 모델을 학습시키려면 다양하고 대규모의 data source가 필요하다.

하지만 segmentaion을 위한 web-scale data source는 존재하지 않는다. 따라서 저자는 효율적인 모델을 사용하여 데이터 수집을 지원하고 새로 수집된 데이터를 사용하여 모델을 개선하는 “data engine”을 구축했다.





Segment Anything Task


Task.

먼저 Task를 정의하기 위해 Prompt를 NLP에서 Segmetation에 맞게 변환이 필요하다. 저자는 points(foreground / background point 집합), boxes, mask, text등을 prompt로 사용했다. Promptable segmentation task는 이러한 다양한 형식의 prompt에 대해 유효한 segmentation mask를 생성하는 것을 목표로 한다.

SAM_2.png

“유효한” mask란 promt가 모호하고 여러 객체를 의미할 수 있는 경우에도 출력이 합리적이야함을 의미한다. 예를 들어, 위의 그림에서 point는 하얀 벽면 자체를 의미할 수 도 있고, ‘ZURICH’라는 단어를 의미할 수도 있다. 혹은 ‘Z’ 알파벳 하나를 의미할 수도 있다. 이렇게 다양한 객체를 의미할 수 있는 경우에도 출력이 합리적이어야 한다는 것이다.




Pre-training.

Training sample에 대한 prompt(e.g. points, boxes, masks) sequence가 주어지면 이를 사용하여 mask를 예측하고, ground truth와 비교한다. 본 논문에서는 interactive segmentation 방법을 사용하였다.

차이점은 Interactive segmentation 방법과 달리 prompt가 모호한 경우에도 항상 모든 prompt에 대해 유효한 mask를 예측하도록한다.





Segment Anything Model


SAM_3.png

다음으로는 promptable segmentation을 위한 Segment Anything Model (SAM)에 대해 설명한다. SAM은 Image Encoder, Prompt Encoder, Mask Decoder 세 가지 요소로 구성된다.


Image encoder.

High-resolution 이미지를 처리하기 위해 Masked AutoEncoder(MAE)로 pre-train된 ViT(Vision Transfomer)를 사용한다.


Prompt encoder.

Prompt encoder는 flexible하게 다양한 형태의 입력에 대해 처리한다. 크게 sparse(points, boxes, text)와 dense(masks) 두 가지 prompt를 처리한다. 입력은 각각 아래의 방법으로 256차원 vectorial embedding으로 변환된다.

  • points: point의 위치를 나타내는 positional encoding학습된 embdding의 합으로 표현된다. 여기서 positional encoding은 coordinate space에서 frequency space로embedding을 진행하는 positional encoding을 의미한다(논문: 📄). 학습된 embedding은 point가 foreground 인지, background 인지를 나타내는 지에 대해 학습된 embedding이다.
  • boxes: embedding pair로 표현됨. (1) box의 top-left corner에 해당하는 positional encoding + “top-left”를 의미하는 학습된 embedding(point와 유사하다고 보면 됨), (2) box의 bottom-left corner에 해당하는 positional encoding + “bottom-left”를 의미하는 학습된 embedding
  • free-form text: CLIP의 text encoder를 사용하여 embedding 된다.
  • masks: mask는 이미지와 spatially 동일. convolution을 사용하여 임베딩되고 image 임베딩과 element-wise로 더해진다.


Mask decoder.

Mask decoder는 image 임베딩과 prompt 임베딩을 효율적으로 매핑하여 output mask를 생성한다. Decoder를 적용하기 전에, ViT의 [cls] token과 유사하게 learned output token embedding을 prompt embedding에 추가한다(합쳐서 token).

Decoder 구조는 표준 Transformer decoder를 수정했으며, 4개의 단계로 구성된다.

SAM_4.png

  1. token에 대해 self-attention
  2. Token을 query로, Image embedding과 cross-attention
  3. point-wise MLP 적용
  4. Image embedding을 query로, Token과 cross-attention

2개의 decoder layer를 적용하는데 각 decoder block 전에는 image embedding에 positional encoding을 더해준다. 다음 decoder layer에는 업데이트 된 image embedding과 token(정확히는 업데이트된 token + original prompt token)을 사용한다.

Decoder 실행 후 2개의 Transposed convolutional layer를 사용하여 image embedding을 4배로 upsample한다. 추가로 업데이트 된 token에 대해, token을 query로 image embedding과 cross-attention을 한번 수행한다. 그리고 3개의 MLP layer를 거쳐서 upscale된 image embedding과 spatially point-wise product를 수행하여 mask를 예측한다.


Resolving ambiguity.

모호한 prompt가 주어지면 모델은 여러 개의 유효한 mask의 평균을 출력한다. 이러한 모호함을 해결하기 위하여 하나의 prompt에 대해 3개의 출력 마스크를 생성하도록 했다. 학습 중에는 각각의 mask에 대해 minimum loss에 대해서만 backprop한다. Mask에 대한 순위를 매기기 위해 모델은 각 mask에 대한 confidence score(estimated IoU)를 계산한다.


Losses and training.

Focal loss와 Dice loss를 20:1 로 조합하여 mask prediction을 수행한다.





Segment Anything Data Engine


Segmentation mask data는 인터넷에 충분히 존재하지 않기 때문에, 데이터 엔진을 구축하여 1.1B 규모의 mask dataset SA-1B를 수집했다. 데이터 엔진은 크게 4 단계로 구성된다.


1. Assisted-manual stage.

첫 번째 단계로, 기존 interactive segmentation과 유사하게 model-assisted annotation을 수행한다. 전문 annotator들은 interactive segmentation tool을 사용하여 foreground / background를 지정하고, SAM model이 대략적인 mask를 제공한다.

이 때 초기의 SAM은 공개된 segmentation 데이터셋을 사용하여 학습된 상태이다. Annotation이 어느정도 진행되어 데이터가 추가되면, 추가된 mask 데이터를 사용하여 SAM을 다시 학습한다. 더 많은 mask 데이터가 수집됨에 따라 이미지 인코더를 ViT-B에서 ViT-H로 확장하였으며, 이와 같은 방법으로 모델을 총 6번 재학습했다. 전체적으로 이 단계에서 120K 이미지에서 4.3M mask를 생성하였다.


2. Semi-automatic stage.

다음 단계로는 “Segment Anything” 능력을 향상시키기 위해 mask의 다양성을 높이는 것을 목표로 했다. 이를 위해 모델에서 Confident가 높은 mask를 제외하고, confident가 낮은 눈에 잘 띄지 않는 mask object에 대해 annotation을 지시했다.

Confident mask를 detect하기 위해서 첫 번째 단계에서 수집한 mask를 사용하여 bounding box detector를 학습했다. 이 단계에서 저자는 180K 이미지에 대해 5.9M mask를 추가로 수집했다(총 10.2M mask). 첫 번째 단계와 마찬가지로 5회에 걸쳐 모델을 재학습했다.


3. Fully automatic stage.

앞서 2단계를 통해 충분한 mask를 수집했고, 모호한 경우에 대해서도 valid한 mask를 예측할 수 있게 되었으므로(ambiguity-aware model) 마지막 단계에서는 완전 자동화된 방식으로 annotation을 생성했다. 아래 과정을 통해 11M 이미지에 대해 1.1B high-quality mask를 생성했다.

  1. 32×32 regular grid에 대한 point 집합을 prompt로 입력하고, 각각 point에 대해 valid object에 대한 mask를 예측한다.
  2. ambiguity-aware model을 사용하면 point에 대해 object의 subpart, part, whole object를 반환한다.
  3. 모델의 IoU prediction module은 confident mask를 선택하는데 사용된다.
  4. 이렇게 선택된 mask중에서 stable mask만 선택하였는데, 해당 mask의 probability map에서 0.5 − δ, 0.5 + δ로 threshold를 설정했을 때도 mask의 모양이 유사하면 stable하고 판단하였다고 한다.
  5. 마지막으로 NMS(Non-Maximal Suppression)를 적용하여 중복 항목을 filtering했다.

++ 추가로, 더 작은 mask의 quality를 향상시키기 위해 이미지 crop을 확대하여 multiple overlapping을 수행하기도 했다.





Segment Anything Dataset


데이터 엔진을 통해 생성한 데이터셋 SA-1B는 11M개의 다양한 이미지, 1.1B의 high-quality segmentation mask로 구성되어있다.

사진작가와 협력하는 업체로부터 새로운 11M개의 이미지에 대한 라이센스를 받았으며, 고해상도에 이미지에 대해 downsampling을 수행하였지만 기존의 데이터셋보다는 높은 해상도를 가지고 있다. 데이터 엔진으로 생성된 마스크 품질은 전체 mask 데이터의 94%가 IoU 90%이상을 달성할 만큼 우수한 품질을 가지고 있다(전문 annotator들이 mask 품질을 개선한 후 IoU 비교 결과).


SAM_5.png

위의 그림은 SA-1B의 Mask 속성에 대한 조사를 위해 mask에서 object center의 spatial distribution을 기존의 데이터셋과 비교한 것이다. SA-1B는 다른 데이터셋과 비교하여 비슷하거나 더 넓게 분포되어있는 것을 알 수 있다.


SAM_6.png

위의 그래프는 데이터셋 크기별로 비교를 수행한 결과이다. 각각 mask-image 분포, mage-relative mask size(mask 영역을 이미지 영역으로 나눈 제곱근), shape complexity 분석을 위한 mask concavity([1 - mask area]를 mask의 convex hull 면적으로 나눈 값)에 대한 그래프이다.

SA-1B는 다른 데이터셋과 비교하여 많은 이미지와 mask를 포함하고 있으며, 상대적으로 작은 크기의 mask를 많이 포함하고 있는 것을 살펴볼 수 있다. 또한 mask의 concavity 또한 다른 데이터셋과 유사함을 알 수 있다.





Segment Anything RAI Analysis


SA-1B 데이터셋과 SAM에 대해 공정성과 편견에 대해 조사하는 RAI(Responsible AI) 분석을 수행하였다.

SAM_7.png

SAM_8.png

위 그림은 국가별 이미지 수에 대한 시각화이다. 상위 3개 국가가 서로 다른 대륙이라는 점에서 다양한 국가에서 이미지가 수집되었음을 알 수 있다. 또한 표1, 2를 통해 SAM이 국가, 지역 및 소득, 성별, 연령 및 인종에 대해 공정하고 일관됨을 나타내고 있다.





Experiments (Zero-Shot Transfer Experiments)


SAM을 사용한 zero-shot transfer 실험을 진행헀다. 5가지의 down stream task에 대해 비교했다. mIoU 평가를 위해 23개의 데이터셋을 사용된 데이터셋 예시이다.

SAM_9.png



1. Zero-Shot Single Point Valid Mask Evaluation

먼저 single foreground point로 object를 segmentation하는 task에 대해 평가를 진행헀다. 저자는 표준 mIoU metric(예: 예측 마스크와 실제 마스크 사이의 모든 IoU의 평균) 뿐만 아니라 annotator가 mask quality를 1(nonsense)에서 10(pixel-perfect)까지 평가하는 human study를 추가적으로 수행했다. 저자는 interactive segmentation 모델인 RITM과 비교하였다. 결과는 아래와 같다.

SAM_10.png




2. Zero-Shot Edge Detection

BSDS500 데이터셋을 사용하여 edge detection 작업을 평가했다. 16×16 regular grid의 foreground point로 구성된 prompt에서, 768 predicted mask를 생성한다(point 당 3개). 중복 mask는 NMS에 의해 제거되고, mask probability map에 Sobel filtering과 edge NMS등의 postprocessing을 적용하여 edge map을 생성했다.

대표 edge map에 대한 시각화는 아래와 같다. 저자는 SAM이 edge detection을 위해 학습되지 않았음에도 불구하고 합리적인 edge map을 생성한다는 것을 확인했다. 정량적인 비교는 아래 표와 같다.

SAM_11.png




3. Zero-Shot Object Proposals

다음으로는 object proposal generation의 중간 수준 task에서 SAM을 평가했다. object proposal을 생성하기 위해 automatic mask generation pipeline의 수정된 버전을 사용하여 proposal을 mask output으로 출력했다. 카테고리가 많은 LVIS v1 데이터셋을 사용하여 average recall (AR) metric에 대해 계산하였고 ViTDet detector와 비교하였다. 결과는 아래와 같다.

SAM_12.png




4. Zero-Shot Instance Segmentation

더 높은 수준으로 SAM을 instance segmenter의 segmentation module로 사용하여 평가를 수행했다. COCO 및 LVIS에 대해 SAM 및 ViTDet이 예측한 mask를 비교했다. 결과 이미지 시각화에서 SAM 마스크가 경계가 더 명확하고 ViTDet의 마스크보다 질적으로 더 나은 경우가 많다는 것을 관찰했다.

SAM_13.png




5. Zero-Shot Text-to-Mask

더 높은 수준의 task로 free-form text에 대해 object segmentation을 수행했다. 아래 그림은 정성적 결과이다. SAM은 “a wheel”과 같은 간단한 text prompt와 “beaver tooth grille”과 같은 문구를 기반으로 객체를 분할할 수 있다.

SAM_14.png






This post is licensed under CC BY 4.0 by the author.