Per-Pixel Classification is Not All You Need for Semantic Segmentation

1 minute read

Facebook AI 논문, MaskFormer (2021) 를 읽고 작성한 글입니다.

Semantic segmentation 은 픽셀 단위로 classification 을 하는 것이라고 보통 설명을 시작하는데요, 이 논문의 제목만 봐도 반대되어 매우 흥미롭게 읽기 시작한 논문입니다. 방법 위주로 설명드리도록 하겠습니다.

우선, 이 논문은 semantic 과 instance-level segmentation 이 모두 가능하면서 성능도 좋은 transformer-based 방법을 제안하고 있습니다. 논문을 읽다보면 작년에 나온 논문 DETR 가 많이 떠오릅니다. 기본적인 transformer 구조를 가져오기도 했고, segmentation output을 ground truth와 매칭할 때 bipartite matching 방법을 쓰고 있습니다. 그래서 읽기 전 DETR 논문에 대한 간단한 이해를 하고 계시다면 더 이해하기 쉬울 것 같습니다.

스크린샷 2021-07-25 오후 10.20.18

위 그림은 semantic segmentation task에서 기존의 per-pixel classification (좌) 방법과 이 논문에서 제안하는 mask classification (우) 방법을 비교한 것입니다. 기존 방법은 각 위치마다 같은 classification loss를 적용하는데 반해, 논문에서 제안한 방법은 binary mask들과 해당 mask의 class들을 예측해서 mask loss와 classification loss를 각각 적용합니다.

위 방법을 통해 기존 per-pixel classification 방법만큼의 성능도 내면서 instance segmentation으로 확장될 수도 있도록 하였습니다. 그리고 기존 instance segmentation 방법들은 bounding box를 필요로 하거나 여러 auxiliary loss들을 필요로 한다는 단점이 있습니다. 이 논문에서는 오로지 mask loss와 classification loss만 사용하고 있습니다.

MaskFormer

MaskFormer 에서는 mask classification을 하는데, 우선 image를 N개의 영역으로 구분하여 최종적으로 N개의 binary mask를 output으로 합니다. 다음, 각 N개의 segment들을 K개의 class들 중 하나로 분류합니다. 그렇다면 output은 \(z = \{(p_i, m_i)\}_{i=1}^N\) 가 될 것입니다. p는 class probability, m은 binary mask 입니다. 이제 ground truth와 매칭하여 classification loss, mask loss를 구해주어야 하는데, 이때 bipartite matching-based 방법을 사용하여 gt와 매칭했을 때 가장 cost가 적은 매칭으로 설정하여 loss를 구해줍니다. 참고로 이때 매칭이 되지 않는 output들도 있을 수 있기 때문에 “no object” 라는 클래스도 필요로 합니다.

스크린샷 2021-07-25 오후 10.41.34

구조는 위 그림과 같습니다. 크게 세 부분, pixel-level module, transformer module, segmentation module로 구성되어 있습니다.

pixel-level module은 일반적인 segmentation network (encoder-decoder)를 사용해서 image feature와 per-pixel embeddings를 뽑아줍니다. transformer module은 DETR 논문과 굉장히 비슷한데, image feature를 벡터 형태로 만들어 중간에 넣어주고, positional embedding 역할을 하는 N개의 query들을 input으로 합니다. 다음 decoder를 통과해 N개의 per-segment embedding을 뽑아줍니다. 여기서 N query들은 learnable parameter 입니다.

마지막으로 segmentation module에서는 N개의 per-segment embdding들을 한쪽은 linear classifier와 softmax를 통과시켜 class probability prediction을 구하고, 다른 한쪽은 MLP를 통과시켜서 per-pixel embedding과 dot product를 하여 N개의 mask prediction을 구합니다. 이것들이 합해저 최종적인 segmentation mask 결과가 나오게 됩니다. 이때 최종 mask를 구하기 위해 argmax를 사용해 구해주는데, class probability 와 mask probability 모두 높아야 선택될 수 있는 방법을 쓰고 있습니다.


Reference:
[1] MaskFormer

Leave a comment