#!/usr/bin/env python # coding: utf-8 # # **2장 사전 훈련된 신경망** # ## 2 사이클 Gen # - (`horse2zebra_0.4.0.pth`) 1,068 개의 말, 1,335 개의 얼룩말 사진 훈련셋 데이터 # In[1]: import torch import torch.nn as nn class ResNetBlock(nn.Module): # <1> def __init__(self, dim): super(ResNetBlock, self).__init__() self.conv_block = self.build_conv_block(dim) def build_conv_block(self, dim): conv_block = [] conv_block += [nn.ReflectionPad2d(1)] conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True), nn.InstanceNorm2d(dim), nn.ReLU(True)] conv_block += [nn.ReflectionPad2d(1)] conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True), nn.InstanceNorm2d(dim)] return nn.Sequential(*conv_block) def forward(self, x): out = x + self.conv_block(x) # <2> return out # In[2]: class ResNetGenerator(nn.Module): def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9): # <3> assert(n_blocks >= 0) super(ResNetGenerator, self).__init__() self.input_nc = input_nc self.output_nc = output_nc self.ngf = ngf model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True), nn.InstanceNorm2d(ngf), nn.ReLU(True)] n_downsampling = 2 for i in range(n_downsampling): mult = 2**i model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=True), nn.InstanceNorm2d(ngf * mult * 2), nn.ReLU(True)] mult = 2**n_downsampling for i in range(n_blocks): model += [ResNetBlock(ngf * mult)] for i in range(n_downsampling): mult = 2**(n_downsampling - i) model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), nn.InstanceNorm2d(int(ngf * mult / 2)), nn.ReLU(True)] model += [nn.ReflectionPad2d(3)] model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] model += [nn.Tanh()] self.model = nn.Sequential(*model) def forward(self, input): # <3> return self.model(input) # In[3]: # 학습된 데이터를 호출 netG = ResNetGenerator() model_path = 'data/ch02/horse2zebra_0.4.0.pth' model_data = torch.load(model_path) netG.load_state_dict(model_data) netG.eval() # eval 모드에서 실행 # In[8]: from PIL import Image img = Image.open("data/ch02/horse.jpg") img.resize((300,200)) # In[5]: from torchvision import transforms preprocess = transforms.Compose([ transforms.Resize(256), transforms.ToTensor()]) img_t = preprocess(img) # 사진 이미지를 전처리 작업 batch_t = torch.unsqueeze(img_t, 0) batch_out = netG(batch_t) out_t = (batch_out.data.squeeze() + 1.0) / 2.0 out_img = transforms.ToPILImage()(out_t) # out_img.save('data/ch02/zebra.jpg') out_img