Model Architecture in Medical Image Segmentation

3 minute read

Medical image segmentation 문제를 풀기 위해 제안된 여러 model architecture에 대해 정리한 글입니다.

전에 이 논문에 참여하면서 Medical Image (MI) Segmentation의 여러 모델들을 실험해본 경험이 있는데, 그 모델들을 정리할 겸 최근에 나온 모델 연구들도 같이 소개해드리겠습니다.

Base Architecture

Natural image에서는 Deeplab v3+ 가 많은 segmentation task의 default 모델로 쓰이는 것처럼 MI segmentation task에서 가장 default로 많이 쓰이는 모델들은 V-Net3D U-Net 입니다.

스크린샷 2021-07-03 오전 11.07.06

UNet 과 같이 encoder-decoder (혹은 contracting-expanding path) 구조에 가운데 skip connection으로 구성되어 있습니다. 대신 medical image에서 쓰는 데이터는 대부분 3D 볼륨데이터이기 때문에 이 모델들에서는 2D convolution 대신 3D convolution을 사용합니다.

V-Net과 3D U-Net의 구조는 위에서 보는 것과 같이 거의 비슷합니다. 사소한 차이점은 down/up-sampling에 convolution을 썼는지, residual connection을 사용했는지, skip connection 에서 합쳐주는 방식 등입니다.

이 모델들을 base model로 해서 attention module을 새롭게 제안하는 연구도 있었고, 이 base model에 contrastive loss 를 추가해 pre-training 하는 방법을 제안한 연구도 있었습니다. 그만큼 여러 연구에 base model로 굉장히 많이 쓰이는 모델들입니다.

Hybrid Architecture

위에 base network들은 모두 3D convolution으로 이루어져있다고 했는데, 3D convolution은 2D convolution에 비해서 무조건 high computation complexity 문제가 있을 수 밖에 없습니다.

아래는 이 문제를 해결하기 위해 제안된 hybrid 모델 구조들입니다. 2D convolution과 3D convolution을 섞어서 사용하자! 가 핵심입니다.

H-DenseUNet

Framework to explore hybrid (intra-slice and inter-slice) features for segmentation

스크린샷 2021-07-03 오전 11.01.36

논문에서는 2D/3D DenseUNet을 활용하여 inter-slice 와 intra-slice의 feature를 추출하고, HFF layer를 통해 2D/3D feature fusion을 해줍니다. 여기서 DenseUNet은 기존 UNet 구조에 DenseNet의 layer 연결구조를 사용한 것입니다.

이 hybrid feature learning 구조를 통해서 volumetric context를 고려할 수 없는 2D network와 computational cost가 큰 3D network의 문제들을 서로 보완할 수 있다고 합니다. 그리고 기존 2D network들은 pre-trained 모델이 많기 때문에 그것을 잘 활용할 수 있다는 장점이 있습니다.

LW-HCN

논문에서도 3D convolution의 computational complexity를 줄이기 위해 Depthwise Spatial Temporal Separate (DSTS) convolution 을 제안하고 있습니다.

스크린샷 2021-07-03 오전 11.02.12

H-DenseUNet 같은 경우에는 2D convolution을 사용했다고 하더라도 3D DenseUNet의 encoder-decoder 형태를 완전히 사용하였기 때문에 computational cost가 줄어든게 맞나? 라는 생각을 했고 실제로 실험을 했을때도 굉장히 training 시간이 오래결렸습니다. LW-HCN 같은 경우에는 앞부분의 feature를 추출할 때는 2D depthwise conv를 뒷부분에는 3D DSTS conv를 조화롭게 사용하여 확실히 cost를 줄였습니다.

앞부분에는 2D convolution을 사용하고, 중간에 Z라는 transformation layer로 합쳐주고 그 다음 3D DSTS Convolution을 통과합니다. 이때 DSTS 구조를 보시면, spatial branch와 temporal branch로 구분을 하여 한 slice안에서의 특징 그리고 cross-slice간의 특징을 모두 잘 잡을 수 있도록 하였습니다.

맨 뒷단에 ASPP를 사용하고 중간 feature들을 upsampling하여 합쳐주는 구조를 보면 Deeplab v3+와 매우 유사합니다.

VA Mask-RCNN

마지막 논문은 전에 attention 관련 글에서도 마지막에 소개해드린 적이 있는 논문입니다.

자세한 그림은 이전 글을 참고해주시고, 간단히 설명하자면 2.5D network로 z 방향으로 context information을 잘 고려할 수 있도록 하는 Mask-RCNN base의 네트워크를 제안하였습니다.

2.5D image는 3개의 인접한 슬라이스들로 구성을 하였고, Mask-RCNN의 피라미드 구조로 각 feature를 뽑는데, 중간에 Volumetric Attention (VA) module을 통해 전체 볼륨에 대한 feature를 고려한 attention map을 만들어 adaptive feature refinement를 가능하게 해줬습니다.

Transformer-based Architecture

최근 글에서 vision transformer 흐름에 대해 다룬 적이 있었는데, medical image segmentation 쪽에서도 올해 초에 활발하게 transformer 적용한 network들이 연구되었습니다.

TransUNet

Framework to preserve the advantage of Transformers and also benefit medical image segmentation

비젼 분야에 transformer를 적용하는 대표적인 이유로 long-range relation을 정의하기 하기 좋다는 점이 있습니다. 이 논문에서도 transformer를 활용하여 global context를 모델링 하였고, large-scale pre-training으로 downstream task에서도 효과적이라는 것을 보여주었습니다.

스크린샷 2021-07-03 오후 2.19.14

위 그림과 같이 convolution으로 feature를 뽑는데, decoder로 가기 전에 transformer를 활용하여 feature 사이의 관계를 학습할 수 있도록 합니다.

CoTr

스크린샷 2021-07-03 오후 2.26.35

다음 논문은 TransUNet 후에 나온 논문 CoTr 입니다. 위 그림을 보시면 SETR 모델은 단순히 encoder 부분에 transformer를 적용하였고, TransUNet에서는 encoder에서 뽑은 feature를 transformer의 input으로 사용했습니다. CoTr 논문에서는 encoder에서 multi-scale and high-resolution feature를 뽑아서 transformer의 input으로 넣습니다.

단순히 그냥 multi-layer에서 뽑아서 넣으면 되는거 아니야? 라고 생각하실 수 있는데, transformer 특성상 computational/spatial complexity가 크기 때문에 이 논문에서는 효율적인 DeTrans Layer를 새롭게 제안합니다.

스크린샷 2021-07-03 오후 2.29.45

여기서는 key를 전체를 넣는 것이 아니라, reference point를 기준으로 key sampling을 해서 MS-DMSA layer가 작은 key set에만 집중할 수 있도록 합니다.


Reference:
[1] 3D U-Net
[2] V-Net
[3] Attention U-Net
[4] Contrastive learning
[5] H-DenseUNet
[6] LW-HCN
[7] VA Mask-RCNN
[8] TransUNet
[9] CoTr

Leave a comment