[논문리뷰] Collaborative Learning of Semi-Supervised Segmentation and Classification for Medical Images (CVPR 2019)
1. Introduction
Medical imaging 분야에서 disease grading과 lesion segmentation은 두 가지 근본적인 문제입니다. 저는 정확한 lesion segmentation은 질병이 있는 부분은 주의 깊게 보고 나머지 부분은 무시할 수 있기 때문에 disease grading의 성능을 높일 수 있고, 명확한 disease grading class 별 정보는 정확한 classification을 수행하는데 큰 영향을 끼친 feature와 그 영역을 주의 깊게 볼 수 있기 때문에 lesion segmentation의 성능을 높일 수 있다고 생각합니다. 그러나 medical data는 labeled image data를 얻기 쉽지 않은데, 특히 pixel-level annotated data는 얻는 과정에서 전문적인 지식과 많은 시간이 소요됩니다. 이러한 이유로 unlabeled data는 labeled data보다 얻기 쉽기 때문에 unsupervised learning을 활용한 상황이 많이 존재하는데, 이 경우 모델의 정확도가 굉장히 낮을 수 있습니다. 또한 일반적인 이미지 데이터 학습 방법인 supervised learning을 사용하는 것도 labeled data의 양이 매우 적은 medical imaging 상황에서는 적합하지 않다고 생각합니다. 따라서 이를 해결하기 위해 본 논문은 아주 적은 양의 labeled data(pixel-level annotated data)와 방대한 양의 unlabeled data(image-level annotated data)를 함께 사용하는 semi-supervised learning를 제안합니다.
이 논문은 Diabetic retinopathy(이하 DR)라는 당뇨망막병증 이미지를 실험 데이터로 사용하고 해당 질병은 5개의 grade로 나뉘어집니다.
2. Proposed Methods
- Overall process
그림 1은 이 논문이 제안하는 collaborative learning method를 나타내는 전체 process입니다. 먼저 아주 소량의 pixel-level annotated data(labeled segmentation data)로 segmentation model을 supervised learning 방법으로 pre-train 합니다. 그리고 방대한 양의 disease grade label만을 가지는 image-level annotated data(unlabeled data)를 pre-train된 segmentation model에 input image로 넣어서 weak mask(성능이 좋지 않아 정확하지 않은 mask)를 만듭니다. 이 때 segmentation model은 GAN method를 통한 adversarial learning을 통해 discriminator와 generator의 계속되는 경쟁으로 더 정확하고 새로운 weak mask를 만들어냅니다. 그 후, weak mask와 원본 이미지를 lesion attention classification model에 input image로 넣어서 attention map(weak mask보다 더 강력하고 정확한 pseudo label)를 output으로 만듭니다. 해당 pseudo label은 disease grading을 하는 classification 작업의 정확성과 효율을 올려주고, classification의 output으로 나오는 더 정확한 disease grade는 segmentation의 성능을 올려주면서 서로 상호보완적으로 성능 향상에 도움을 주게 됩니다.
- Contributions
1) GAN의 adversarial loss를 이용한 함수 활용
2) 극소량의 labeled data와 방대한 unlabeled data를 이용한 semi-supervised learning
3) End to end model에서 최적화 되는 method
- Problem Formulation
각각 X^P는 pixel-level annotated data, X^I는 image-level annotated data, G(x)는 lesion segmentation model, C(x)는 disease grading model을 의미합니다. 식 (1)은 pre-train된 segmentation model을 통해 ground-truth mask(s_l^P)와 예측된 mask(s ̃_l^I) 간의 차이를 최소화하여 model을 학습시키는 것을 의미합니다. 또한 식 (2)에서의 y는 image-level annotated data의 질병의 중증도인 classification label이고, att(x)는 lesion attention model을 의미하며 식(1)의 s ̃_l^I와 식(2)의 att(G(X^I))는 같습니다.
- Adversarial Multi-Lesion Masks Generator
본 논문은 아주 적은 양의 labeled data와 방대한 양의 unlabeled data를 함께 사용하는 semi-supervised learning을 제안합니다. 이 때, U-shape GAN method를 이용해 generator를 통하여 pseudo mask를 만들어내고(fake image) discriminator가 이를 ground truth(real image)와 비교하면서 generator가 더 좋은 mask를 만들어내기 위해 학습됩니다. 저는 이 방법이 극소량의 labeled data는 supervised learning으로, 나머지 unlabeled data는 unsupervised learning으로 학습하는 방법보다 훨씬 정확한 mask를 만들어내도록 모델을 학습시킬 수 있다고 생각합니다. GAN을 이용하는 해당 알고리즘 외에 labeled data가 적은 medical imaging에 적용할 수 있는 interpolation 혹은 extrapolation이 함께 적용되면 더 좋은 결과를 낼 수 있을 것 같습니다. 현재 generator는 labeled data가 너무 적기 때문에 flip, rotation와 같은 augmentation 방법을 사용하는데 이는 아주 기본적인 기법으로 더 높은 성능 향상을 기대하기 힘듭니다. 이 때, feature들끼리 pairing을 통한 extrapolation으로 새로운 labeled data를 생성한다면 weak mask를 생성하는 generator의 성능을 더 높일 수 있을 것 같습니다.
그림2은 adversarial training loss를 활용하는 전체 pipeline입니다. 첫 번째, 극소량의 데이터 X^P를 input image로 넣어 multi-lesion masks generator를 supervised learning으로 pre-train 시킵니다. 그 다음 pre-train된 generator에 많은 양의 X^I를 통과시켜 fake image(완전하지 않은 weak mask)를 생성합니다. Discriminator는 이 때 생성된 정답이 없는 fake image와 정답이 있는 real image를 각각 0과 1로 학습합니다. 두 번째, 초기에 예측된 lesion map과 X^I를 multi-lesion attentive model을 학습하는 데이터로 활용합니다. 이 때, 모두 lesion grading의 정답을 가진 데이터를 이용해 attentive model의 성능을 향상시킵니다. 이 후 attentive model로 생성한 attention map은 위의 generator가 생성한 weak mask보다 더욱 강력하고 정확한 pseudo mask로 사용되며 이는 segmentation model을 더 잘 fine-tuning 시킬 수 있습니다.
L_Adv는 일반적인 GAN에서 사용되는 adversarial loss이고 L_CE는 cross entropy 즉, classification loss입니다. L_Seg는 segmentation model을 최적화하기 위한 loss이고 식 (3)과 같이 discriminator에서 발생되는 adversarial loss와 classification loss의 weight 합으로 이루어집니다.
- Lesion Attentive Disease Grading
수동으로 질병의 중증도를 grading 하는 작업은 전문적인 지식이 굉장히 많이 필요하고 시간 또한 매우 많이 소요됩니다. 이 논문은 총 5단계로 나뉘는 질병인 DR을 데이터로 사용하며 grading 작업을 더욱 효율적이게 하기 위해 앞서 말했던 모델들과 attentive model의 결합을 제안합니다. Attentive model을 사용하면 중요한 영역만을 주의 깊게 보고 관련이 없는 부분은 무시할 수 있습니다. 이 때문에 output으로 나오는 attention map을 pseudo mask로 사용하여 segmentation 성능을 더 향상시킬 수 있습니다.
그림 3은 본 논문에서 제안하는 lesion attentive disease 의 구조이며 아래의 수식에 따라 output이 결정됩니다.
그림 3에서 가장 위의 파란색 convolutional layer를 통과한 원본 이미지와 segmentation generator를 통과하여 추출된 4개의 lesion들을 각각 하나씩 concat 하여 f^(low_att)을 얻을 수 있습니다. a_l 또한 같은 방법으로 sigmoid 연산을 통해 얻을 수 있는 attention map입니다. 우측 상단의 multi-lesion attentive features는 원본 이미지와 attention map인 pseudo mask들이 합쳐진 데이터입니다. 해당 데이터를 동일한 가중치를 가지는 classification model이 concat한 후 fc layer를 통해 질병의 중증도를 grading합니다. 이 결과는 segmentation에 활용할 pseudo labeling을 위해 활용되며 이는 더 강력하고 정확한 mask label을 만듦으로써 성능을 향상시킬 수 있습니다.
Attention 메커니즘을 pseudo mask를 생성하는 작업 외에 UNet에 적용하는 attention UNet으로 generator를 구성하는 시도를 해 볼 수 있을 것 같습니다. 현재, UNet의 skip-connection에 attention을 추가한 attention UNet이 제안된 논문이 있고 UNet보다 높은 성능을 보여주고 있습니다. 따라서 이 네트워크를 generator에 적용시켜도 좋을 것 같습니다.
3. Experimental Results
그림 4는 multi-lesion에 대한 segmentation 결과를 나타냅니다. 극소량의 labeled data로 pre-train된 generator로 생성한 mask는 semi-supervised 작업을 거친 후의 mask보다 훨씬 낮은 성능을 볼 수 있습니다. 우측 하단 Pre-train model 결과의 빨간색 box는 전혀 잘못된 detection의 결과를 보여줍니다.
위의 표들은 각각 IDRID, EyePACS, Messidor dataset에 대한 결과를 나타냅니다. 본 논문에서 제안하는 method를 모두 사용했을 때 accuracy, kappa, ROC와 PR에 대한 AUC 등의 지표들이 더 좋은 결과를 나타냄을 알 수 있습니다.