이안 굿펠로우(Ian Goodfellow)의 GANs(Generative Adversarial Networks) 페이퍼를 파이토치로 구현한 블로그가 있어서 재현해 보았습니다.(제목과는 달리 전체 코드는 100줄이 넘습니다) 이 블로그에 있는 파이토치 소스는 랜덤하게 발생시킨 균등 분포를 정규 분포 데이터로 만들어 주는 생성기(G, Generator)와 생성기가 만든 데이터와 진짜 정규 분포를 감별하는 분류기(D, Discriminator)를 학습시키는 것입니다.
분류기에는 감별하려는 분포의 데이터 한 벌이 모두 들어가서 참, 거짓의 이진 판별을 내 놓습니다. 생성기는 한번에 하나의 데이터 포인트가 들어가서 출력값도 하나입니다. 원하는 분포를 얻기 위해서 생성기에 필요한 만큼 데이터를 계속 주입하면 됩니다. 행렬로 생각하면 분류기에는 하나의 행이 들어가고, 생성기에는 배치 효과를 위해 하나의 열 벡터가 들어갑니다.
그런데 기대한 것처럼 정규분포와 비슷한 생성기를 학습시키기가 어려웠습니다. 하이퍼파라미터 세팅에 많이 민감하여 종종 에러가 폭주하기도 했습니다. 그리고 학습이 되어도 정규분포의 모양을 가진다기 보다 평균과 표준편차만 같은 분포를 만들어 내었습니다. 즉 평균과 표준편차를 학습한 것 같은 모습니다(물론 이 코드는 결과보다 예시를 위한 것이라고 합니다만..). 그래서 생성기의 출력을 한번에 전체 분포를 얻을 수 있도록 마지막 레이어의 노드 수를 늘렸습니다. 이렇게 했더니 손쉽게 정규 분포와 비슷한 결과물을 얻었습니다만 생성기의 목적을 제대로 수행한 것인지는 확신이 서지 않네요. 🙂
전체 노트북은 깃허브에 있습니다. 코드의 중요 부분은 아래와 같습니다.
학습시키고자 하는 정규 분포는 평균이 4, 표준 편차가 1.25 입니다. 분류기는 입력 유닛이 100개이고 히든 유닛의 수는 50, 출력 노드는 하나입니다. 생성기는 100개의 입출력 유닛을 가지고 히든 유닛 수는 50개 입니다. 이 예제의 분류기, 생성기는 모두 각각 두 개의 히든 레이어를 가지고 있는데 두 개의 히든 레이어의 노드 수는 동일합니다.
data_mean = 4 data_stddev = 1.25 d_input_size = 100 d_hidden_size = 50 d_output_size = 1 g_input_size = 100 g_hidden_size = 50 g_output_size = 100
분류기와 생성기의 클래스 정의입니다.
class Discriminator(nn.Module): def __init__(self, input_size, hidden_size, output_size): super(Discriminator, self).__init__() self.map1 = nn.Linear(input_size, hidden_size) self.map2 = nn.Linear(hidden_size, hidden_size) self.map3 = nn.Linear(hidden_size, output_size) def forward(self, x): x = F.relu(self.map1(x)) x = F.relu(self.map2(x)) return F.sigmoid(self.map3(x)) class Generator(nn.Module): def __init__(self, input_size, hidden_size, output_size): super(Generator, self).__init__() self.map1 = nn.Linear(input_size, hidden_size) self.map2 = nn.Linear(hidden_size, hidden_size) self.map3 = nn.Linear(hidden_size, output_size) def forward(self, x): x = F.relu(self.map1(x)) x = F.relu(self.map2(x)) return self.map3(x)
렐루(ReLU)를 사용해서 레이어를 감싸고 분류기의 출력은 시그모이드로, 생성기의 출력은 마지막 레이어의 값을 그대로 뽑았습니다. 다음은 분류기와 생성기가 서로 학습하는 과정입니다.
for epoch in range(num_epochs): # 분류기와 생성기의 학습 횟수가 서로 다를 수 있습니다. for d_index in range(d_steps): D.zero_grad() # 그래디언트를 비웁니다. real_data = Variable(real_sampler(d_input_size)) # 정규분포 real_decision = D(real_data) # 정규분포 판별 real_error = loss(real_decision, Variable(torch.ones(1))) # 무조건 참이어야 합니다 real_error.backward() # 역전파 fake_input = Variable(fake_sampler(g_input_size)) # 균등분포 fake_data = G(fake_input) # 생성기가 정규분포로 변조시킵니다 fake_decision = D(fake_data) # 가짜인지 구분해야 합니다 fake_error = loss(fake_decision, Variable(torch.zeros(1))) # 무조건 거짓이어야 합니다 fake_error.backward() # 역전파 d_optimizer.step() # 분류기 파라미터 업데이트 # 생성기 학습 차례 for g_index in range(g_steps): G.zero_grad() # 그래디언트를 비웁니다. fake_input = Variable(fake_sampler(g_input_size)) # 균등분포 fake_data = G(fake_input) # 생성기가 정규분포로 변조합니다 fake_decision = D(fake_data) # 가짜인지 구분합니다. fake_error = loss(fake_decision, Variable(torch.ones(1))) # 생성기 입장에서는 분류기가 참으로 속아야 합니다. fake_error.backward() # 역전파 g_optimizer.step() # 생성기 파라미터 업데이트
이렇게 학습시키게 되면 초기에는 평균과 표준편차가 잠시 널뛰지만 반복이 어느 시점을 넘어서면 거의 변동이 없는 상태가 됩니다. 아래 그림에서 파란색이 평균, 붉은 색이 표준편차 입니다.
에러는 총 세가지로 분류기가 학습할 때 계산하는 정답 분포에 대한 에러(파란)와 가짜 분포(녹색)에 대한 에러, 그리고 생성기가 학습할 때 계산하는 가짜 분포에 대한 에러(적색) 입니다. 아래 그래프에서 볼 수 있듯이 세 에러가 거의 변동이 없이 일치하고 있습니다.
학습이 끝난 후 만들어진 생성기로 랜덤한 균등 분포를 입력해서 정규 분포로 바꾸어 보았습니다. 눈으로 볼 때에도 구분할 수 없을 만큼 비슷했습니다. 깃허브에 있는 노트북 코드는 원본 코드와는 조금 다르며 코드를 간소화하기 위해 람다 함수 등을 수정했습니다.