[PyTorch] Freeze Network: no_grad, requires_grad 차이
모델 Freeze 하는 방법에 대해 정리한 글입니다.
때때로 모델 일부는 freeze 시키고 모델 일부에 대해서만 parameter update를 하고 싶은 경우가 있습니다. 대표적으로 transfer learning 나 generative adversarial network 의 경우가 있는데요. 이럴 경우 어떻게 파이토치에서 코드를 작성해야 할지 여러 방법들을 정리해보았습니다.
첫번째 경우, 위 그림과 같이 모델 A, B가 있는데, 모델 A는 freeze 시키고 모델 B만 학습시키는 방법입니다. 쉽게 생각하면 B에서 A로 가는 gradient를 막아버리면 되겠죠.
# 첫번째 경우_1
z = A(x)
y = B(z.detach())
위와 같이 detach 함수를 사용하면 z를 static tensor로 만들어 update 할때, z 이전 파라미터에는 gradient가 흐르지 않습니다.
# 첫번째 경우_2
with torch.no_grad():
z = A(x)
y = B(z)
위와 같이 no_grad 함수를 쓰는 방법도 있습니다. no_grad 사용하면 내부에 있는 operation들의 gradient는 계산이 되지 않게 됩니다. gradient가 계산되지 않기 때문에 당연히 모델 A의 파라미터 업데이트가 될 수 없겠죠.
# 첫번째 경우_3
for p in A.parameter():
p.requires_grad = False
z = A(x)
y = B(z)
마지막 방법은 requires_grad를 쓰는 방법입니다. A의 파라미터를 하나씩 불러와서 gradient를 꺼주는 것입니다. 이렇게 하면 A의 parameter들을 상수 취급해주어 업데이트도 되지 않습니다. 이후에 다시 A의 파라미터를 업데이트 해주고 싶다면 requires_grad=True 로 gradient를 켜주면 됩니다.
두번째 경우는, 위 그림처럼 A만 update 시키고 B를 freeze 하는 것입니다. 위에 경우랑 똑같은거 아닌가 싶을 수 있지만 주의할 사항이 있습니다. 위 상황처럼 B로 가는 gradient를 끊어버리면 안된다는 것입니다. 우선 detach 함수는 이 경우 쓰기 힘들고 2, 3번째 경우를 비교해 보도록 하겠습니다.
# 두번째 경우_1
z = A(x)
with torch.no_grad():
y = B(z)
# 두번째 경우_2
for p in B.parameter():
p.requires_grad = False
z = A(x)
y = B(z)
두 방법 모두 똑같아 보이지만, 이 중 하나는 작동하고 하나는 작동하지 않습니다. 첫번째 경우는 오류가 발생하는데요, no_grad 와 requires_grad 차이에 있습니다.
no_grad는 아예 gradient 자체를 계산하지 않겠다는 것입니다. 따라서 B 모델의 gradient 자체를 계산하지 않는다면 자연스럽게 A는 당연히 gradient가 계산될 수 없겠죠. 반면에 requires_grad는 B의 상수 취급을 해서 B 모델의 파라미터가 업데이트 되지는 않지만, 여전히 gradient는 흐를 수 있는 상태입니다. 따라서 두번째 경우는 A 모델이 업데이트 될 수 있습니다.
이 외에도 여러가지 방법이 있을 수 있지만 이번 글에서 no_grad와 requires_grad의 차이를 설명드리고 싶어서 두 방법 위주로 설명했습니다. 다른 방법으로는 zero_grad() 또는 optimizer를 다르게 쓰는 방법 등이 있을 것 같습니다.
Leave a comment