[논문읽기] A Comprehensive Overhaul of Feature Distillation

2 minute read

이 글은 2019 ICCV 논문, A Comprehensive Overhaul of Feature Distillationclova-ai blog를 참고하여 작성하였습니다.

knowledge distillation 분야에 관심있어서 막 공부를 시작하는 사람들이 읽었으면 좋겠다고 생각하여 소개하게 된 논문입니다.

저도 knowledge distillation에 대한 대략적인 개념만 알고 최근 어떤 연구들이 있는지 찾아보던 중이었는데, 이 논문에서 지금까지 이루어진 distillation 연구들을 잘 정리해놓아서 공부하는데 큰 도움이 되었습니다.

기존 knowledge distillation 연구들과 이 논문에서 제안한 컨셉 위주로 논문을 리뷰해보도록 하겠습니다:>

Background

스마트폰과 같은 가벼운 기기에서 딥러닝 모델이 돌아갈 수 있도록 하기 위해서는 모델 사이즈가 작아야합니다. 이를 위해 모델을 더 작고 비용효율적으로 만들어야 하는데, 이를 위한 연구가 Knowledge Distillation (KD) 입니다.

KD에서는 smaller network (student) 와 larger network (teacher)를 이용해 teacher network의 knowledge를 student network에 잘 옮겨서 비슷한 성능이 나오게 합니다.

이때, “어떠한 방법을 이용하는가” 이것이 매우 중요합니다. 이 논문에서는 기존 연구들에서 사용했던 “teacher와 student network의 차이를 줄이는 방법”을 매우 잘 정리해놓았습니다.

teacher network와 student network의 output의 차이를 좁히는 방법 만으로는 부족하여 “feature distillation”이 있습니다. hidden feature들도 비슷하게끔 하는거죠.

structure

general training scheme

위에 그림처럼 중간중간 hidden feature 사이의 distance를 구하여 학습하는 방법이 일반적입니다. 이때 각 T와 d를 무엇을 쓰는지에 따라 성능이 달라집니다. 식으로 나타내면 다음과 같습니다.

\[L_{distill} = d(T_t(F_t), T_s(F_s))\]

이 loss를 최소화하는 것이 목표입니다.

difference

기존 논문과 이 논문에서 \(T_t, T_s\)와 distance 등을 구하기 위해 어떠한 방법을 썼는지 정리해놓은 표입니다. 이 논문에서는 Teacher transform, disillation feature position, distance 등에 다른 방법 썼습니다.

Proposed Method

“Transfer all feature information without missing any import information from the teacher”

이 논문에서는 최대한 정보손실이 일어나지 않는 방향으로 distillation loss를 설계했습니다. 논문에서 제안한 방법은 다음 그림과 같이 정리할 수 있습니다.

enter image description here

propsed method

- Margin ReLU

margin ReLU를 사용함으로써 도움이 되는 정보들(positive information)은 변형 없이 사용되고, 도움이 되지 않는 정보들(negative information)은 사용되지 않습니다. 이에 따라 beneficial information 중에 빠뜨린 것 없이 모두 distillation 될 수 있게 됩니다.

\[\sigma_m(x) = max(x, m).\]

따라서 다음과 같은 식을 사용하는데요, 이때 m은 음수로, positive value들은 student가 teacher과 같은 값을 가질 수 있도록 하고, negative value에 대해서는 0보다 작은 값을 나오게 하여 student value도 음수로 만들게 됩니다.

\[m_c = E[F_t^i\mid F_t^i<0, i \in C].\]

이때, margin m을 구하는 방법은 위 식과 같습니다. m은 channel-by-channel 음수 값들의 기대값입니다.

- Partial L2

\[d_p(T,S) = \sum_i^{WHC} \begin{cases} 0 & {if \ \ } {S_i\le T_i\le 0 }\\ (T_i-S_i)^2 & {otherwise} \end{cases}\]

student가 teacher 보다 작고, 음수인 경우를 제외하고 L2를 구해줍니다. 이는, teacher의 negative response에 대해서는 student value가 target value 보다 크면 반드시 감소해야하고, 그렇지 않은 경우에는 그럴 필요가 없기 때문입니다.

- Distillation Loss

따라서 위를 적용한 최종적인 distillation loss 식은 다음과 같습니다.

\[L_{distill} = d_p(\sigma_{m_c}(F_t), r(F_s))\]

오늘은 기존 distillation 연구들과 이 논문에서 제안한 방법 위주로 간단하게 글을 작성해보았습니다. 매우 짧게 끝난 것 같지만, 요즘 시간이 없어서ㅠㅠ 짧게나마 리뷰해봅니다.

Reference:
[1] paper

Leave a comment