Post

[논문 리뷰] Med-PaLM M, Towards Generalist Biomedical AI



Google Research Google DeepMind NEJM AI 2024

📖 핵심 훑어보기 !!

  • MultiMedBench: 의료 영상, 임상 텍스트, 유전체학을 포함한 여러 modality에 걸쳐 있는 multimodal biomedical 벤치마크를 큐레이션. 다양한 modality 뿐 아니라 14가지 다양한 task를 포함.
  • Med-PaLM M: LLM 기반 multimodal 모델 PaLM-E를 활용, Instruction Tuning을 통해 모델 tuning. Generalist biomedical AI system 구축.



Introduction


의학은 multimodal 학문이다. Text, imaging, genomic 등 다양한 data modality를 포함하고 있다. 하지만 Biomedical AI의 상당한 발전에도 불구하고 대부분의 모델은 unimodal single task 시스템이다. 이러한 single-task, unimodal AI 시스템은 real-world application에서 매우 제한적이다.

본 논문에서는 여러 biomedical data modality를 해석하고 동일한 모델 weight set으로 다양한 downstream task를 처리할 수 있는 통합 모델(foundation model)을 개발하는 것을 목적으로 한다. 먼저 이러한 연구를 가능하게 하기 위한 포괄적인 multimodal 의료 벤치마크 MultiMedBench를 큐레이션 헀다. MultiMedBench는 언어, 의료 영상 및 유전체학(genomics) modality를 아우르는 오픈 소스 multimodal 벤치마크로서 질문 답변, 시각적 질문 답변, 의료 영상 분류, 방사선 보고서 생성 및 요약, 유전체 변이 호출을 포함한 14가지 다양한 biomedical 작업이 포함된다.

Med-PaLM_1.png

MultiMedBench를 활용하여 최근의 LLM, multimodal foundation model을 기반으로 하는 large-scale generalist biomedical AI 시스템, Med-PaLM Multimodal (Med-PaLM M)을 개발하였다. Med-PaLM M은 다양한 유형의 multimodal biomedical 정보를 통합하하는 유현한 sequence-to-sequence 구조이다. 또한 modality-agnostic한 language decoder를 사용하여, 통합된 학습 전략으로 다양한 biomedical task를 수행한다.

요약하면, 논문의 주요 기여는 다음과 같다.

  • Curation of MultiMedBench: 의료 영상, 임상 텍스트, 유전체학을 포함한 여러 modality에 걸쳐 있는 multimodal biomedical 벤치마크를 생성. 일반 biomedical AI 시스템을 학습하고 평가하기 위한 14가지 다양한 task를 포함.
  • Med-PaLM M, the first demonstration of a generalist biomedical AI system: 동일한 weight set으로 다양한 task를 수행하는 multitask, multimodal biomedical AI 시스템. Task-specific customization 없이 여러 task에서 SOTA 모델과 비슷하거나 능가하는 성능 달성.
  • Evidence of novel emergent capabilities in Med-PaLM M: 정량적인 수치적 평가를 넘어, zero-shot 의학 추론, 새로운 의학적 concept에 대한 일반화, task transfer에 대한 긍정적인 증거를 관찰함.
  • Human evaluation of Med-PaLM M outputs: 모델에 대한 정량적 metric을 넘어 radiologist 평가를 진행.





MultiMedBench: A Benchmark for Generalist Biomedical AI


논문에서는 multi-task, multimodal 데이터셋 MultiMedBench를 큐레이션 했다. MultiMedBench는 12개의 익명화된 오픈 소스 데이터 세트와 14개의 개별 task로 구성되어있다.

  • Task type: question answering, report generation / summarization, visual question answering, medical image classification, genomic variant calling.
  • Modality: text, radiology(CT, MRI, and X-ray), pathology(병리학), dermatology(피부과), mammography(유방조영술), genomics(유전체학).
  • Output format: classification을 포함한 모든 task에 대한 open-ended generation.

Med-PaLM_2.png

  • Language-only task: medical question answering, radiology report summarization
  • Multimodal task: medical visual question answering (VQA), medical image classification, chest X-ray report generation, genomic variant calling





Med-PaLM M: A Proof of Concept for Generalist Biomedical AI


1. Model preliminaries

Pathways Language Model (PaLM)

PaLM은 2022년 Google이 발표한 autoregressive transformer 구조의 LLM으로, Pathways라고 하는 새로운 ML 시스템을 이용해 두 대의 TPU Pods에서 효율적으로 학습을 진행했다. 대표적인 특징은 아래와 같다.

  • 540B parameter, pipeline-free training (cf. GPT3:175B / GPT4: 1.76T)
  • 780B token으로 구성: 웹 페이지, 위키피디아 기사, 소스 코드, 소셜 미디어 대화, 뉴스 기사, 책에서 학습 데이터 수집
  • GPT와 동일하게 transformer decoder 구조
  • 3가지 PaLM 모델 variant


Vision Transformer (ViT)

ViT는 2021 Google이 발표한 논문으로 Transformer 구조를 Image에 도입하여 현재까지도 많은 논문들의 baseline 구조로 사용되고 있는 모델이다.

본 논문에서는 4B paremeter, 22B paremeter로 pre-train된 두 가지 ViT 모델을 vision encoder로 사용한다. 두 모델 모두 4B 이미지 classification 데이터셋(JFT-300M 등)에 대해 supervised learning으로 pre-train 되었다.


PaLM-E

PaLM-E는 2023 Google에서 발표된 논문으로, text, vision, sensor signal 등의 multimodal input sequence를 처리할 수 있는 multimodal language 모델이다. 기본 PaLM-E 모델은 pre-train된 PaLM과 ViT를 사용하며, 단일 prompt에서 이미지, text 및 sensor 신호를 섞어 넣을 수 있는 유연성을 제공하여 모델이 완전히 multimodal context에서 예측을 수행할 수 있도록 한다.

자세한 논문 내용은 PaLM-E 논문 리뷰를 참조하자.




2. Putting it all together: Med-PaLM M

Dataset and preprocessing

  • MultiMedBench의 모든 image는 224×224×3으로 resize 됨. W,H 비율 유지를 위해 padding 사용.
  • Gray-scale 이미지는 channel 차원으로 동일한 이미지를 3-channel로 쌓아서 사용.
  • Task-specific prepossessing 적용(class balancing, image data augmentation 등)



Instruction task prompting and one-shot exemplar

Med-PaLM_3.png

Multimodal 입력에 대해 다양한 task를 수행할 수 있는 generalist biomedical AI 모델을 학습시키기 위해 instruction tuning*을 사용했다. 모델이 다양한 유형의 task를 수행하도록 task-specific한 instruction을 입력으로 사용했다.

구체적으로 task prompt는 instruction, 관련 context information, question으로 구성된다. 예시는 위와 같다. 첫번째 chest X-ray report generation task에서는 이미지 방향 정보와 연구의 목적을, 두 번째 피부 질환 classification task에 대해서는 환자의 병력을 추가 context information으로 포함했다. Classification task의 경우는 답변 옵션이 제공되는 객관식 질문으로 공식화했다.

모델이 instruction을 더 잘 따를 수 있도록 task prompt에 one-shot exemplar를 추가했다. One-shot exempler는 모델이 원하는 형식의 응답을 생성하도록 prompt하는 데 효과적이다. 이 때 multimodal task의 경우 실제 이미지를 넣어주지 않고 text placeholder(text 문자열 <img>)로 대체했다. 실제 이미지를 예시로 넣지 않아 계산 효율성을 높였다.


In-Context Learning vs Fine-tuning
In-Context Learning 이란, Meta Learning의 일종으로, 별도의 모델을 학습을 거치지 않고(weight 변경 없음), inference 단계에서 prompt를 잘 생성하여줌으로서 맥락적인 의미를 모델이 파악하게 하여 답변을 생성하게 하는 것을 의미한다. 이 접근 방식은 모델의 기존 지식과 generalization 능력을 활용하여 주어진 맥락적 단서를 기반으로 특정 task을 이해하고 수행한다.
반면 Fine-tuning은 pre-train된 모델을 특정 task나 domain에 대해 추가로 학습을 진행하는 것을 말한다. 모델의 weight을 데이터에 더 잘 맞게 조정하여 특정 task 또는 domain에 대한 모델의 성능을 향상시킨다.
두 방법에 대한 비교는 아래와 같다.

Med-PaLM_4.png

Instruction Tuning (Instruction Finetuning)*
Instruction Tuning은 위의 두 방법을 결합하여 모델의 유연성을 키우고 정확도를 향상시키기 위한 방법이다. Fine-tuning 처럼 추가적인 학습을 진행하는데 이 때 학습하는 데이터셋이 사용자의 구체적인 지시(instruction)와 이에 대한 응답(output)으로 이루어져 있다. 이러한 pair dataset을 통해 모델은 질문에 대해 더 유연하고 정확한 답변을 도출하도록 학습된다.
Dataset에서 지시 내용에 추가적인 설명이 필요하다면 In-Context learning에서 few-shot 예시를 주는 것 처럼 Instruction에 덧붙여 줄 수도 있다.
참고: https://devocean.sk.com/blog/techBoardDetail.do?ID=165806&boardType=techBlog





Model training

  • Pretrain된 12B, 84B, 562B parameter PaLM-E에 대해 fine-tuning 됨
  • 전체 MultiMedBench task에 대해 아래의 표와 같은 비율로 혼합하여 학습됨
  • 전체 모델에 대해 End-to-End training 진행
  • ViT는 JFT-300M에 대해 supervised 방식으로 pre-train 됨

Med-PaLM_5.png





Experiments


1. Med-PaLM M performs near or exceeding SOTA on all MultiMedBench tasks

두 가지 방법으로 비교 평가를 수행했다.

  • MultiMedBench task에 대해 SOTA 모델과 비교(단일 task)
  • MultiMedBench data에 대해 fine-tuning을 진행하지 않은 generalist 모델 (PaLM-E)

결과는 아래 표와 같다. Med-PaLM M은 SOTA 모델과 비슷하거나 더 나은 결과를 보여주고 있다.

Med-PaLM_6.png



2. Med-PaLM M demonstrates zero-shot generalization to novel medical tasks and concepts

저자는 새로운 concept에 대한 generalization 성능 평가를 위해 Montgomery County(MC) 데이터셋의 chest X-ray 이미지에서 결핵(TB) 이상 감지에 대해 평가를 진행했다. 아래 표를 확인해보면, unseen 데이터에 대한 zero-shot generalization 기능이 해당 데이터셋에 최적화된 SOTA와 비교하여 경쟁력있는 성능을 보임을 확인할 수 있다.

Med-PaLM_7.png



3. Med-PaLM M performs encouragingly on radiology report generation across model scales

4명의 임상의 평가자가 보고서의 품질을 평가하기 위해, MIMIC-CXR 데이터 세트에서 방사선과 의사가 제공한 reference report를 다른 Med-PaLM M 모델 scale(12B, 84B 및 562B)에서 생성된 보고서와 비교했다.

그림 4a는 각 평가자가 3개의 Med-PaLM M variant 중 하나에서 생성된 보고서 또는 reference report 중에서 가장 좋다고 평가한 빈도를 요약한 것이다. 또한 각 Med-PaLM M 모델에서 생성된 보고서를 방사선과 의사가 제공한 reference report와 직접 비교할 수 있도록 1:1 비교를 수행했다. 결과는 그림 4b와 같다.

Med-PaLM_8.png


또한 Human evaluation(방사선과 의사 확인)을 통해 누락 및 오류율에 대해 조사했다. 그림 5는 모델 variant(12B, 84B, 562B)에 따른 결과이다. 주목할 점은 이 오류율이 이전 연구에서 MIMIC-CXR 데이터 세트에 대한 인간 방사선과 의사 baseline으로 보고된 오류율과 비슷하다는 점이라고 한다.

Med-PaLM_9.png





This post is licensed under CC BY 4.0 by the author.