[논문읽기] Distillation-Based Training for Multi-Exit Architectures

1 minute read

이 글은 2019 ICCV Oral 논문, Distillation-Based Training for Multi-Exit Architectures를 참고하여 작성하였습니다.

이번 논문도 이전 포스팅한 글과 마찬가지로 ICCV에서 발표를 듣고 감명받아 읽게 된 논문입니다. 일반적으로 knowledge distillation에서 많이 쓰이는 Distillation loss를 Multi-exit architecture 라는 다른 task 에 적용한 논문입니다.

그럼 가볍게 컨셉 위주로 논문 리뷰 해보도록 하겠습니다:)

Multi-exit architectures

저는 처음에 이 Multi-exit architecture 용어 자체가 매우 낯설었는데요, 이것은 convolutional network를 이용한 image classification task에서 anytime prediction을 하기 위해 사용되는 것입니다.

anytime prediction 은 test time에 단일 모델과 예제별로 accuracy와 computation을 trade off 할 수 있는 기능을 말합니다. 가능한 computation budget 안에서 최적의 accuracy를 낼 수 있는 모델을 구현하는 것이 중요하죠.

이렇게 나온것이 multi-exit architecture 입니다. 연속적인 feature layer (convolutional) 에서 각각 다른 깊이의 early exits 들로 확장되는 것이죠. 전형적으로는 multi-task와 같은 방법으로 training을 했습니다. (각각의 exit마다 loss를 구하고 이를 평균내어 적용하는 형태)

하지만, 이러한 방법은 prior knowledge에 대한 것을 무시하는 것이라고 지적하며 knowledge distillation based method를 제안합니다. 이 방법으로는 exits 사이에 information을 공유할 수 있습니다.

Method

Distillation Training Objective

adads

이 논문에서 제안한 방법을 위 그림과 같이 나타내고 있습니다. 각각의 exits 들에서 output이 나오고 이를 distillation loss와 classification loss를 이용해 training 시키는 것이죠.

사용되는 loss 식은 다음과 같습니다.

\[{1 \over N} \sum_{n=1}^{N}{[L_{cls}(x_n, y_n) + L_{dist}(x_n)]}\]

여기서 \(L_{dist}\)는

\[L_{dist}(x_n) = {1 \over M}\sum_{m=1}^M{1 \over | T(m) |} \sum_{t \in T(m)} \mathfrak{l}^\tau(p_t(x_n),p_m(x_n))\]

입니다. M은 exits 들을 나타내고 \(T(m)\)은 해당 exit의 teacher exit을 말합니다. 여기서 \(\tau\)는 temperature scale 입니다. 이 parameter가 커질수록 teacher의 prediction은 더 soften 해집니다. (논문 참고)

Optimization

f

논문에서 제시한 알고리즘은 다음과 같은데요, 여기서 두 가지를 강조합니다.

일단 detach로 꼭 teacher 로부터 student 가 학습되도록 하기위해 teacher prediction을 상수처럼 취급한다고 합니다.

또 다른 하나는 알고리즘 그림 맨 밑에 있는 \(\tau\) 업데이트 부분인데요, 여기서 \(\tau\) 값을 조정하여(증가시켜) 모델이 계속해서 confident 하게 학습되는것을 막는다고 합니다.

Results

각 exit 마다 imageNet을 갖고 test 한 결과 입니다.

asdf

ImageNet(100) 부분은 각 클래스 당 100개의 이미지 씩만 학습시켰다는 의미인데요, 그림에서 볼 수 있듯이 ImageNet(100)의 Exit1 결과가 Exit-wise loss(기존 연구) Exit5 만큼의 성능을 보이고 있는 것을 알 수 있습니다. 하지만, ImageNet (full) 에서는 성능차이가 많지 않은 것으로 봐서 한정된 데이터셋에서 학습 시킬 경우에 논문에서 제시한 방법이 효과적이라는 것을 알 수 있습니다.


글에 이 논문의 내용 전체를 담진 못했지만, 간략하게 컨셉위주로 리뷰해보았습니다. 그 외에도 budget mode, anytime mode, temperature scale 관련 내용이 자세히 논문에 나와있습니다.

Reference:
[1] 논문

Leave a comment