역전파, 직접 짜봐야 하나요?

OpenAI의 안드레이 카패시(Andrej Karpathy)가 얼마전 ‘Yes you should understood backprop‘란 글을 미디엄 사이트에 올렸습니다. 안드레이는 OpenAI에 오기 전에 스탠포드 대학에서 PhD 학생으로 근무했고 CS231n 강의를 진행했습니다. 이 강의는 영상과 강의 노트 모두 인터넷에 공개되어 있어 인공지능에 관심있는 연구자나 학생들에게 인기가 많습니다. 그런데 학생들이 가끔 CS231n 의 숙제에 불평을 하는 경우가 있는 모양입니다. 텐서플로우 같은 라이브러리가 역전파 알고리즘을 모두 자동으로 처리해 주는 데 굳이 numpy 로 정방향(forward pass), 역방향(backward pass, backpropagation) 코드를 직접 구현하는 숙제를 할 필요가 있느냐고 말이죠.

역전파 코드를 직접 만들어 봐야할 이유가 지적 호기심이나 더 나은 역전파 알고리즘을 만들기 위해서가 아니라 역전파를 알아야 뉴럴 네트워크를 만들 때 오류를 범하지 않고 디버깅을 하는데 도움을 주기 때문입니다. 안드레이는 역전파가 불완전한 추상화(leaky abstraction)라고 말하고 있습니다. 불완전한 추상화 또는 누수 추상화는 조엘 스폴스키의 블로그를 통해서 알려졌었습니다. 한마디로 구멍이 많다는거죠!

시그모이드 함수의 예에서 역전파가 전혀 이루어지지 않고 그래디언트가 사라져 버리는 경우를 설명하고 있습니다. 흔히 이를 배니싱 그래디언트(vanishing gradient)라고 부릅니다. 즉 그래디언트가 사라져 버려서 역전파가 안되고 파라미터가 학습이 되지 않는다는 뜻입니다. 시그모이드 함수의 식과 도함수는 아래와 같습니다(편의상 바이어스는 입력 벡터에 포함되어 있다고 생각합니다):

z = \dfrac{1}{1 + e^{-t}}  ,  \dfrac{\partial z}{\partial t} = z(1 - z)  ,  t = W \cdot x

따라서 안드레이가 제시한 코드와 같이 정방향, 역방향 계산은 간단한 numpy 코드로 만들 수 있습니다.

z = 1/(1 + np.exp(-np.dot(W, x))) # 정방향, z는 W의 행 개수와 동일한 크기의 벡터
dx = np.dot(W.T, z*(1-z)) # 역방향: x를 위한 로컬 그래디언트
dW = np.outer(z*(1-z), x) # 역방향: W를 위한 로컬 그래디언트, outer로 원래 W 행렬 사이즈 복원

가중치 파라미터가 큰 값으로 초기화되면 np.dot(W, x) 값이 커지고 큰 양수나 큰 음수가 될 것입니다. 이는 시그모드이 함수의 출력(z)를 1 또는 0 에 가깝게 만듭니다. 그런데 이 때, 즉 z = 0 or 1, 시그모이드 도함수는 z*(1-z) 이므로 두 경우 모두 쉽게 0 의 값을 가질 수 있게 됩니다. 이 시그모이드 활성화 함수의 그래디언트가 0 이 되면 체인룰에 의해 이후 네트워크로 전달되는 모든 그래디언트도 자동으로 0 이 됩니다. 즉 W 파라미터가 학습되지 않을 것입니다. 이런 상황을 포화되었다고(saturated) 말하고 있습니다. 그래프로 보시면 조금 더 이해가 빠릅니다.

1-gkxi7lywygplu5dn6jb6bg

출처: 안드레이 카패시 미디엄 블로그

시그모이드 함수는 S 자 곡선을 그리며 0 과 1 로 수렴합니다. 시그모이드의 도함수는 t = 0 일 때 z = 0.5 가 되며 이 때 도함수 z(1 - z) 가 가장 큰 값 0.25가 됩니다. z = 0 or 1 일 땐 도 함수 값은 모두 0 으로 수렴합니다. 반대로 생각하면 시그모이드 함수를 쓰는 한 출력 z 가 아무리 큰 값을 가지더라도 역전파될 때 그래디언트는 최소한 1/4 만큼 줄어든다는 뜻입니다. 따라서 기본적인 SGD 방법에서 뒤쪽 레이어가 앞쪽 레이어보다 학습이 느리게 될 수 밖에 없습니다.

하이퍼볼릭 탄젠트(tanh) 활성화 함수의 경우 -1 ~ 1 사이의 결과 값을 가집니다. 이 경우에도 마찬가지로 쌍곡선함수의 도함수가 1 - z^2 이므로 모두 그래디언트가 0 이 됩니다. 그래서 가중치의 초기화, 규제(regularization) 또 입력 데이터의 정규화가 중요할 수 밖에 없습니다. CS231n 의 강좌를 참고로 알려 주었습니다.

렐루(ReLU) 활성화 함수는 음수 값은 버리고 양수 값은 그대로 바이패스하는 간단한 함수 입니다. 코드로 나타내면 아래와 같습니다.

z = np.maximum(0, np.dot(W, x)) # 정방향
dW = np.outer(z > 0, x) # 역방향: W의 로컬 그래디언트
# 렐루의 출력(z)가 음수일 땐 x에 0이 곱해지므로 가중치 파라미터에 업데이트 되는 값이 없습니다.

이 함수의 도함수는 z > 0 일 땐 1 이 되고 z = 0 일 땐 0 이 됩니다.

1-g0yxlk8kebw8ua1f82xqda

출처: 안드레이 카패시 미디엄 블로그

렐루 함수의 그래디언트가  0 이라는 것은 앞서 말한 대로 가중치 파라미터 W 가 업데이트 되지 않는다는 뜻입니다. 만약 배치 그래디언트 디센트에서 전체 데이터를 주입한 후 일부 뉴런의 그래디언트가 이와 같이 0 이 된다면 이 뉴런은 훈련 세트가 변경되지 않는 한 계속 죽은 상태 즉 학습이 되지 않는 상태로 남게 될 것입니다. 학습 속도(learning rate)가 너무 커서 파라미터를 급격히 변경시킨 것이 오히려 상황을 악화시켜 경우 이런 상황에 빠질 수도 있습니다. 어떤 경우에라도 네트워크의 뉴런 일부가 학습되지 않는 멍텅구리가 되는 건 누구라도 바라는 상황은 아닙니다. 관련된 CS231n 강의 영상도 함께 제시하였습니다.

RNN 의 경우 히든 상태를 저장하고 있는 hs[t] 를 이전 히든 상태 hs[t-1] 로 부터 만들어 냅니다. 안드레이가 제시한 코드에서는 입력 값 x 를 고려하지 않았습니다.

for t in xrange(T):
    ss[t] = np.dot(Whh, hs[t-1])
    hs[t] = np.maximum(0, ss[t])
...
for t in reversed(xrange(T)):
    dss[t] = (hs[t] > 0) * dhs[t]
    dhs[t-1] = np.dot(Whh.T, dss[t])

이 예에서는 렐루 활성화 함수를 사용하여 hs[t] 를 만들었고 뉴런의 출력 값은 ss[t] 입니다. 역전파가 이루어지게 되면 각 시간 스텝에서 hs[t] 의 값이 0 보다 컸는지를 확인하고 네트워크의 앞에서 전달 받은 그래디언트 dhs[t] 를 다음 시간 스텝으로 전달할지 결정합니다. 그래디언트가 전달될 수 있는 상황이라고 가정하면 이전 스텝의 히든 상태는 시그모이드의 예와 마찬가지로 가중치 매트릭스와의 곱이 됩니다. RNN 에서 히든 상태에 대한 가중치 매트릭스는 하나이므로 시간의 역순으로 그래디언트를 계속 전달할 때에도 곱해지는 가중치 매트릭스트는 변하지 않고 계속 앞의 그래디언트에 반복적으로 곱해집니다. 그러므로 가중치가 1 보다 큰 값을 가지고 있을 때 그래디언트가 폭주(exploding)할 수 있습니다. RNN 에서 그래디언트 클리핑(clipping)을 신경써야 하는 이유입니다. 역시 CS231n 강의를 참고로 소개하고 있습니다.

그런데 안드레이가 이 글을 쓴 이유는 한 코드를 우연히 발견해서입니다. 딥 큐 러닝(Deep Q Learning)을 구현한 한 텐서플로우 코드에서 아래와 같이 델타(delta)를 클리핑하고 있습니다.

self.clipped_delta = tf.clip_by_value(self.delta, -1, 1, name='clipped_delta')
...
self.loss = tf.reduce_mean(tf.square(self.clipped_delta), name='loss')

-1 ~ 1 사이로 델타 값을 클리핑할 경우 이 범위 밖에서는 렐루 함수의 경우와 마찬가지로 로컬 그래디언트가  0 이 되기 때문에 손실 함수의 학습시킬 수 없게 됩니다. 안드레이는 후버(Huber) 손실 함수를 대신 사용하도록 권고하였습니다. 정방향에서는 델타의 범위가 -1 ~ 1 을 벗어날 경우 델타의 절대값에서 0.5 를 빼게 됩니다. 역방향 계산에서는 L1 손실의 도함수처럼 x 의 부호에 따라 1 또는 -1 이 됩니다.

self.loss = tf.reduce_mean(clipped_error(self.delta), name='loss')
...
def clipped_error(x):
    return tf.select(tf.abs(x) < 1.0, 0.5 * tf.square(x), tf.abs(x) - 0.5)

그러니 뉴럴 네트워크를 잘 만들려면 직접 정방향, 역방향 코드를 만들어 보라고 권합니다. 또 CS231n의 역전파 강의도 참고하라고 홍보하네요.

사실 이 코드는 데브시스터즈의 깃허브에 있는 것으로 강화학습과 텐서플로우에 정통하신 김태훈님이 만들었습니다. 그리고 안드레이의 이슈는 즉각 수정되었습니다. 역시 고수는 고수를 알아보는 것이죠! 🙂

답글 남기기

아래 항목을 채우거나 오른쪽 아이콘 중 하나를 클릭하여 로그 인 하세요:

WordPress.com 로고

WordPress.com의 계정을 사용하여 댓글을 남깁니다. 로그아웃 / 변경 )

Twitter 사진

Twitter의 계정을 사용하여 댓글을 남깁니다. 로그아웃 / 변경 )

Facebook 사진

Facebook의 계정을 사용하여 댓글을 남깁니다. 로그아웃 / 변경 )

Google+ photo

Google+의 계정을 사용하여 댓글을 남깁니다. 로그아웃 / 변경 )

%s에 연결하는 중