import torch as t
from torch import nn
from torch.autograd import Variable
from torch.optim import Adam
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.datasets import CIFAR10
import numpy as np
from torch import optim
import torchvision.utils as vutil
class Config:
    """Hyper-parameters for the DCGAN experiment (shared via the module-level `opt`)."""
    lr = 0.0002          # learning rate for both Adam optimizers
    nz = 100             # dimensionality of the input noise vector
    image_size = 64      # target image height after resizing
    image_size2 = 64     # target image width after resizing
    nc = 3               # number of image channels (RGB)
    ngf = 64             # generator feature-map base width
    ndf = 64             # discriminator feature-map base width
    gpuids = None        # device ids for data_parallel; None -> default device
    beta1 = 0.5          # Adam beta1 (DCGAN-recommended value)
    batch_size = 32
    max_epoch = 1        # set to 1 while debugging
    workers = 2          # DataLoader worker processes

opt = Config()
# Data loading & preprocessing: CIFAR-10, resized to 64x64, each channel
# normalised to [-1, 1] to match the generator's tanh output range.
# NOTE: dataset download / disk IO happens here at import time.
dataset=CIFAR10(root='cifar10/',\
transform=transforms.Compose(\
[transforms.Scale(opt.image_size) ,  # NOTE(review): transforms.Scale was renamed Resize in newer torchvision — confirm the installed version
transforms.ToTensor(),
transforms.Normalize([0.5]*3,[0.5]*3)
]))
# DataLoader gives lazy loading, prefetching, multi-worker reads and shuffling
# (positional args: dataset, batch_size, shuffle=True).
dataloader=t.utils.data.DataLoader(dataset,opt.batch_size,True,num_workers=opt.workers)
#模型定义
class ModelG(nn.Module):
    """DCGAN generator.

    Maps a noise Variable of shape (N, opt.nz, 1, 1) through a stack of
    strided transposed convolutions to an image of shape (N, opt.nc, 64, 64)
    with values in [-1, 1] (tanh output).
    """

    def __init__(self, ngpu):
        super(ModelG, self).__init__()
        self.ngpu = ngpu  # falsy -> run on the default device
        self.model = nn.Sequential()
        # 1x1 -> 4x4
        self.model.add_module('deconv1', nn.ConvTranspose2d(opt.nz, opt.ngf * 8, 4, 1, 0, bias=False))
        self.model.add_module('bnorm1', nn.BatchNorm2d(opt.ngf * 8))
        self.model.add_module('relu1', nn.ReLU(True))
        # 4x4 -> 8x8
        self.model.add_module('deconv2', nn.ConvTranspose2d(opt.ngf * 8, opt.ngf * 4, 4, 2, 1, bias=False))
        self.model.add_module('bnorm2', nn.BatchNorm2d(opt.ngf * 4))
        self.model.add_module('relu2', nn.ReLU(True))
        # 8x8 -> 16x16
        self.model.add_module('deconv3', nn.ConvTranspose2d(opt.ngf * 4, opt.ngf * 2, 4, 2, 1, bias=False))
        self.model.add_module('bnorm3', nn.BatchNorm2d(opt.ngf * 2))
        self.model.add_module('relu3', nn.ReLU(True))
        # 16x16 -> 32x32
        self.model.add_module('deconv4', nn.ConvTranspose2d(opt.ngf * 2, opt.ngf, 4, 2, 1, bias=False))
        self.model.add_module('bnorm4', nn.BatchNorm2d(opt.ngf))
        self.model.add_module('relu4', nn.ReLU(True))
        # 32x32 -> 64x64
        self.model.add_module('deconv5', nn.ConvTranspose2d(opt.ngf, opt.nc, 4, 2, 1, bias=False))
        self.model.add_module('tanh', nn.Tanh())

    def forward(self, input):
        gpuids = None
        if self.ngpu:
            # BUG FIX: original was range(gpuids) with gpuids=None -> TypeError
            # whenever ngpu is truthy; the intent is the first self.ngpu device ids.
            gpuids = range(self.ngpu)
        return nn.parallel.data_parallel(self.model, input, device_ids=gpuids)
def weight_init(m):
    """DCGAN weight initialisation, applied via ``net.apply(weight_init)``.

    Conv / ConvTranspose weights ~ N(0, 0.02); BatchNorm scale ~ N(1, 0.02)
    with bias zeroed. (Could be upgraded to Xavier initialisation.)
    """
    class_name = m.__class__.__name__
    # BUG FIX: str.find is case-sensitive and torch layer class names are
    # 'Conv2d' / 'ConvTranspose2d' / 'BatchNorm2d', so the original lowercase
    # 'conv' / 'norm' patterns never matched and no layer was ever initialised.
    if class_name.find('Conv') != -1:
        m.weight.data.normal_(0, 0.02)
    if class_name.find('Norm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        if m.bias is not None:
            m.bias.data.fill_(0)
class ModelD(nn.Module):
    """DCGAN discriminator.

    Maps an image Variable of shape (N, opt.nc, 64, 64) through strided
    convolutions to a per-sample real/fake probability, returned as (N, 1).
    """

    def __init__(self, ngpu):
        super(ModelD, self).__init__()
        self.ngpu = ngpu  # falsy -> run on the default device
        self.model = nn.Sequential()
        # 64x64 -> 32x32 (no batch norm on the input layer, per DCGAN)
        self.model.add_module('conv1', nn.Conv2d(opt.nc, opt.ndf, 4, 2, 1, bias=False))
        self.model.add_module('relu1', nn.LeakyReLU(0.2, inplace=True))
        # 32x32 -> 16x16
        self.model.add_module('conv2', nn.Conv2d(opt.ndf, opt.ndf * 2, 4, 2, 1, bias=False))
        self.model.add_module('bnorm2', nn.BatchNorm2d(opt.ndf * 2))
        self.model.add_module('relu2', nn.LeakyReLU(0.2, inplace=True))
        # 16x16 -> 8x8
        self.model.add_module('conv3', nn.Conv2d(opt.ndf * 2, opt.ndf * 4, 4, 2, 1, bias=False))
        self.model.add_module('bnorm3', nn.BatchNorm2d(opt.ndf * 4))
        self.model.add_module('relu3', nn.LeakyReLU(0.2, inplace=True))
        # 8x8 -> 4x4
        self.model.add_module('conv4', nn.Conv2d(opt.ndf * 4, opt.ndf * 8, 4, 2, 1, bias=False))
        self.model.add_module('bnorm4', nn.BatchNorm2d(opt.ndf * 8))
        self.model.add_module('relu4', nn.LeakyReLU(0.2, inplace=True))
        # 4x4 -> 1x1 scalar logit, squashed to a probability
        self.model.add_module('conv5', nn.Conv2d(opt.ndf * 8, 1, 4, 1, 0, bias=False))
        self.model.add_module('sigmoid', nn.Sigmoid())

    def forward(self, input):
        gpuids = None
        if self.ngpu:
            # BUG FIX: original was range(gpuids) with gpuids=None -> TypeError
            # whenever ngpu is truthy; the intent is the first self.ngpu device ids.
            gpuids = range(self.ngpu)
        return nn.parallel.data_parallel(self.model, input, device_ids=gpuids).view(-1, 1)
# Instantiate both networks and apply the DCGAN weight initialisation.
netg=ModelG(opt.gpuids)
netg.apply(weight_init)
netd=ModelD(opt.gpuids)
netd.apply(weight_init)
# Pasted notebook output (repr of the discriminator), commented out so the file remains valid Python:
# ModelD ( (model): Sequential ( (conv1): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False) (relu1): LeakyReLU (0.2, inplace) (conv2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False) (bnorm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) (relu2): LeakyReLU (0.2, inplace) (conv3): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False) (bnorm3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) (relu3): LeakyReLU (0.2, inplace) (conv4): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False) (bnorm4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) (relu4): LeakyReLU (0.2, inplace) (conv5): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False) (sigmoid): Sigmoid () ) )
# Optimizers (Adam with the DCGAN-recommended beta1=0.5)
optimizerD=optim.Adam(netd.parameters(),lr=opt.lr,betas=(opt.beta1,0.999))
optimizerG=optim.Adam(netg.parameters(),lr=opt.lr,betas=(opt.beta1,0.999))
# Model inputs/targets: old-style autograd Variables, resized per batch in the loop
input=Variable(t.FloatTensor(opt.batch_size,opt.nc,opt.image_size,opt.image_size2))
label=Variable(t.FloatTensor(opt.batch_size))
noise=Variable(t.FloatTensor(opt.batch_size,opt.nz,1,1))
# Fixed noise is reused for the periodic sample images so they stay comparable over time
fixed_noise=Variable(t.FloatTensor(opt.batch_size,opt.nz,1,1).normal_(0,1))
real_label=1
fake_label=0
# Training criterion: binary cross-entropy on the discriminator's sigmoid output
criterion=nn.BCELoss()
# DCGAN training loop: one discriminator update per batch, a generator
# update on roughly every other batch (D:G ~ 2:1).
# Initialise generator-side stats so the log line is valid even before the
# first generator update (the original could hit a NameError on error_G /
# D_G_z2 when the random gate skipped G on early iterations).
error_G = None
D_G_z2 = 0.0
# BUG FIX: was a hard-coded `xrange(6)` — xrange is Python-2-only, and the
# epoch count should come from the config (opt.max_epoch).
for epoch in range(opt.max_epoch):
    for ii, data in enumerate(dataloader, 0):
        # ---- train D ----
        netd.zero_grad()
        # real batch
        real, _ = data
        input.data.resize_(real.size()).copy_(real)
        label.data.resize_(input.size()[0]).fill_(real_label)
        output = netd(input)
        error_real = criterion(output, label)
        error_real.backward()
        D_x = output.data.mean()
        # fake batch (detach so the D step sends no gradient into G)
        noise.data.resize_(input.size()[0], opt.nz, 1, 1).normal_(0, 1)
        fake_pic = netg(noise).detach()
        output2 = netd(fake_pic)
        label.data.fill_(fake_label)
        error_fake = criterion(output2, label)
        error_fake.backward()
        D_x2 = output2.data.mean()
        error_D = error_real + error_fake
        optimizerD.step()
        # ---- train G (random gate -> D is trained about twice as often) ----
        if t.rand(1)[0] > 0.5:
            netg.zero_grad()
            label.data.fill_(real_label)  # G wants D to call its fakes real
            noise.data.normal_(0, 1)
            fake_pic = netg(noise)
            output = netd(fake_pic)
            error_G = criterion(output, label)
            error_G.backward()
            optimizerG.step()
            D_G_z2 = output.data.mean()
        if error_G is not None:
            print('{ii}/{epoch} lossD:{error_D},lossG:{error_G},{D_x2},{D_G_z2},{D_x}'.format(
                ii=ii, epoch=epoch,
                error_D=error_D.data[0], error_G=error_G.data[0],
                D_x2=D_x2, D_G_z2=D_G_z2, D_x=D_x))
        # periodically save sample grids from the fixed noise plus a real batch
        if ii % 100 == 0 and ii > 0:
            fake_u = netg(fixed_noise)
            vutil.save_image(fake_u.data, 'fake%s.png' % ii)
            vutil.save_image(real, 'real%s.png' % ii)
# Pasted notebook training-log output, commented out so the file remains valid Python:
# 0/0 lossD:1.36829459667,lossG:0.894766509533,0.486646070145,0.411444487981,0.504189566709 1/0 lossD:1.27741086483,lossG:0.946315646172,0.443847945891,0.392864650115,0.51032573171 2/0 lossD:1.3073836565,lossG:0.946315646172,0.4495063629,0.392864650115,0.49758704938 3/0 lossD:1.13484323025,lossG:0.978585541248,0.415155397728,0.380839592777,0.554169878364 4/0 lossD:1.16920399666,lossG:0.955828249454,0.436097631231,0.391813380178,0.556802288629 5/0 lossD:1.1938097477,lossG:0.9689848423,0.471667434089,0.387160736136,0.583109146915 6/0 lossD:1.21595573425,lossG:0.9689848423,0.451340062544,0.387160736136,0.552521592937 7/0 lossD:1.10636794567,lossG:1.00885415077,0.392565631308,0.368035835214,0.551878349856 8/0 lossD:1.1512260437,lossG:1.00885415077,0.468347774819,0.368035835214,0.605279957876 9/0 lossD:1.02875673771,lossG:1.00885415077,0.363269736525,0.368035835214,0.569313969463 10/0 lossD:0.985520362854,lossG:1.17760324478,0.292472077999,0.314572186675,0.534999700263 11/0 lossD:1.20287775993,lossG:1.0366050005,0.521362030879,0.360163018573,0.63543565385 12/0 lossD:1.22502338886,lossG:1.03986740112,0.467908551916,0.357476376463,0.562215656973 13/0 lossD:1.33360874653,lossG:1.02025258541,0.459647334181,0.369051158894,0.496054288
# Persist trained weights (file names reflect the 1-epoch debug run).
t.save(netd.state_dict(),'1epoch_netd.pth')
t.save(netg.state_dict(),'1epoch_netg.pth')
import torch as t
from torch import nn
from torch.autograd import Variable
from torch.optim import Adam
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.datasets import CIFAR10
import numpy as np
from torch import optim
import torchvision.utils as vutil
#from tensorboard_logger import Logger
'''
https://zhuanlan.zhihu.com/p/25071913
WGAN 相比于DCGAN 的修改:
1. 判别器最后一层去掉sigmoid # 回归问题,而不是二分类概率
2. 生成器和判别器的loss不取log # Wasserstein 距离
3. 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c #W距离->L连续->数值稳定
4. 不要用基于动量的优化算法(包括momentum和Adam),推荐RMSProp,SGD也行 #->玄学
GAN 两大问题的解释:
collapse mode ->KL 散度不对称
数值不稳定 -> KL散度和JS散度优化方向不一样
'''
class Config:
    """Hyper-parameters for the WGAN experiment (shared via the module-level `opt`)."""
    lr = 0.0002           # RMSprop learning rate
    nz = 100              # dimensionality of the input noise vector
    image_size = 64       # target image height after resizing
    image_size2 = 64      # target image width after resizing
    nc = 3                # number of image channels (RGB)
    ngf = 64              # generator feature-map base width
    ndf = 64              # critic feature-map base width
    gpuids = None         # device ids for data_parallel; None -> default device
    beta1 = 0.5           # kept for parity with the DCGAN config (unused by RMSprop)
    batch_size = 32
    max_epoch = 12        # set to 1 while debugging
    workers = 2           # DataLoader worker processes
    clamp_num = 0.01      # WGAN weight-clipping bound for the critic

opt = Config()
# Load CIFAR-10; resize to 64x64 and normalise each channel to [-1, 1]
# (matches the generator's tanh output). Download / disk IO happens here.
dataset=CIFAR10(root='cifar10//',\
transform=transforms.Compose(\
[transforms.Scale(opt.image_size) ,  # NOTE(review): transforms.Scale was renamed Resize in newer torchvision — confirm the installed version
transforms.ToTensor(),
transforms.Normalize([0.5]*3,[0.5]*3)
]))
# DataLoader gives lazy loading, prefetching, multi-worker reads and shuffling
# (positional args: dataset, batch_size, shuffle=True).
dataloader=t.utils.data.DataLoader(dataset,opt.batch_size,True,num_workers=opt.workers)
# 网络结构
class ModelG(nn.Module):
    """WGAN generator (same architecture as the DCGAN generator above).

    Maps a noise Variable of shape (N, opt.nz, 1, 1) through strided
    transposed convolutions to an image of shape (N, opt.nc, 64, 64)
    with values in [-1, 1] (tanh output).
    """

    def __init__(self, ngpu):
        super(ModelG, self).__init__()
        self.ngpu = ngpu  # falsy -> run on the default device
        self.model = nn.Sequential()
        # 1x1 -> 4x4
        self.model.add_module('deconv1', nn.ConvTranspose2d(opt.nz, opt.ngf * 8, 4, 1, 0, bias=False))
        self.model.add_module('bnorm1', nn.BatchNorm2d(opt.ngf * 8))
        self.model.add_module('relu1', nn.ReLU(True))
        # 4x4 -> 8x8
        self.model.add_module('deconv2', nn.ConvTranspose2d(opt.ngf * 8, opt.ngf * 4, 4, 2, 1, bias=False))
        self.model.add_module('bnorm2', nn.BatchNorm2d(opt.ngf * 4))
        self.model.add_module('relu2', nn.ReLU(True))
        # 8x8 -> 16x16
        self.model.add_module('deconv3', nn.ConvTranspose2d(opt.ngf * 4, opt.ngf * 2, 4, 2, 1, bias=False))
        self.model.add_module('bnorm3', nn.BatchNorm2d(opt.ngf * 2))
        self.model.add_module('relu3', nn.ReLU(True))
        # 16x16 -> 32x32
        self.model.add_module('deconv4', nn.ConvTranspose2d(opt.ngf * 2, opt.ngf, 4, 2, 1, bias=False))
        self.model.add_module('bnorm4', nn.BatchNorm2d(opt.ngf))
        self.model.add_module('relu4', nn.ReLU(True))
        # 32x32 -> 64x64
        self.model.add_module('deconv5', nn.ConvTranspose2d(opt.ngf, opt.nc, 4, 2, 1, bias=False))
        self.model.add_module('tanh', nn.Tanh())

    def forward(self, input):
        gpuids = None
        if self.ngpu:
            # BUG FIX: original was range(gpuids) with gpuids=None -> TypeError
            # whenever ngpu is truthy; the intent is the first self.ngpu device ids.
            gpuids = range(self.ngpu)
        return nn.parallel.data_parallel(self.model, input, device_ids=gpuids)
def weight_init(m):
    """DCGAN-style weight initialisation, applied via ``net.apply(weight_init)``.

    Conv / ConvTranspose weights ~ N(0, 0.02); BatchNorm scale ~ N(1, 0.02)
    with bias zeroed. (Could be upgraded to Xavier initialisation.)
    """
    class_name = m.__class__.__name__
    # BUG FIX: str.find is case-sensitive and torch layer class names are
    # 'Conv2d' / 'ConvTranspose2d' / 'BatchNorm2d', so the original lowercase
    # 'conv' / 'norm' patterns never matched and no layer was ever initialised.
    if class_name.find('Conv') != -1:
        m.weight.data.normal_(0, 0.02)
    if class_name.find('Norm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        if m.bias is not None:
            m.bias.data.fill_(0)
class ModelD(nn.Module):
    """WGAN critic.

    Same convolutional stack as the DCGAN discriminator, but per the WGAN
    recipe the final sigmoid is removed: the critic outputs an unbounded
    score, and ``forward`` returns the batch-mean score as a 1-element tensor.
    """

    def __init__(self, ngpu):
        super(ModelD, self).__init__()
        self.ngpu = ngpu  # falsy -> run on the default device
        self.model = nn.Sequential()
        # 64x64 -> 32x32 (no batch norm on the input layer)
        self.model.add_module('conv1', nn.Conv2d(opt.nc, opt.ndf, 4, 2, 1, bias=False))
        self.model.add_module('relu1', nn.LeakyReLU(0.2, inplace=True))
        # 32x32 -> 16x16
        self.model.add_module('conv2', nn.Conv2d(opt.ndf, opt.ndf * 2, 4, 2, 1, bias=False))
        self.model.add_module('bnorm2', nn.BatchNorm2d(opt.ndf * 2))
        self.model.add_module('relu2', nn.LeakyReLU(0.2, inplace=True))
        # 16x16 -> 8x8
        self.model.add_module('conv3', nn.Conv2d(opt.ndf * 2, opt.ndf * 4, 4, 2, 1, bias=False))
        self.model.add_module('bnorm3', nn.BatchNorm2d(opt.ndf * 4))
        self.model.add_module('relu3', nn.LeakyReLU(0.2, inplace=True))
        # 8x8 -> 4x4
        self.model.add_module('conv4', nn.Conv2d(opt.ndf * 4, opt.ndf * 8, 4, 2, 1, bias=False))
        self.model.add_module('bnorm4', nn.BatchNorm2d(opt.ndf * 8))
        self.model.add_module('relu4', nn.LeakyReLU(0.2, inplace=True))
        # 4x4 -> 1x1 raw score
        self.model.add_module('conv5', nn.Conv2d(opt.ndf * 8, 1, 4, 1, 0, bias=False))
        # WGAN modification 1: no sigmoid — regression-style score, not a probability
        # self.model.add_module('sigmoid', nn.Sigmoid())

    def forward(self, input):
        gpuids = None
        if self.ngpu:
            # BUG FIX: original was range(gpuids) with gpuids=None -> TypeError
            # whenever ngpu is truthy; the intent is the first self.ngpu device ids.
            gpuids = range(self.ngpu)
        # Return the mean critic score over the batch, shape (1,).
        return nn.parallel.data_parallel(self.model, input, device_ids=gpuids).view(-1, 1).mean(0).view(1)
## no loss but score
# Instantiate generator and critic and apply the weight initialisation.
netg=ModelG(opt.gpuids)
netg.apply(weight_init)
netd=ModelD(opt.gpuids)
netd.apply(weight_init)
# Pasted notebook output (repr of the critic; note the missing sigmoid), commented out so the file remains valid Python:
# ModelD ( (model): Sequential ( (conv1): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False) (relu1): LeakyReLU (0.2, inplace) (conv2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False) (bnorm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) (relu2): LeakyReLU (0.2, inplace) (conv3): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False) (bnorm3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) (relu3): LeakyReLU (0.2, inplace) (conv4): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False) (bnorm4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) (relu4): LeakyReLU (0.2, inplace) (conv5): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False) ) )
# Optimizers — WGAN recipe: avoid momentum-based methods such as Adam
optimizerD=optim.RMSprop(netd.parameters(),lr=opt.lr ) # modified: momentum-free optimizer (WGAN point 4)
optimizerG=optim.RMSprop(netg.parameters(),lr=opt.lr )
# Inputs for the critic (D) and generator (G); resized per batch in the loop
input=Variable(t.FloatTensor(opt.batch_size,opt.nc,opt.image_size,opt.image_size2))
label=Variable(t.FloatTensor(opt.batch_size))
noise=Variable(t.FloatTensor(opt.batch_size,opt.nz,1,1))
# Fixed noise is reused for the periodic sample images so they stay comparable over time
fixed_noise=Variable(t.FloatTensor(opt.batch_size,opt.nz,1,1).normal_(0,1))
real_label=1
fake_label=0
#criterion=nn.BCELoss() # WGAN does not take the log: no cross-entropy loss (point 2)
# Constant gradient signs passed to Tensor.backward() for the critic scores
one=t.FloatTensor([1])
mone=-1*one
# WGAN training loop: raw critic scores with opposite backward signs for
# real/fake, weight clamping on the critic only, and a random gate so the
# critic is trained roughly 5x as often as the generator.
# NOTE(review): image saving assumes a 'wgan/' directory already exists.
# BUG FIX: xrange is Python-2-only; range works in both Python 2 and 3.
for epoch in range(opt.max_epoch):
    for ii, data in enumerate(dataloader, 0):
        #### train the critic D ####
        netd.zero_grad()  # required: gradients accumulate otherwise
        real, _ = data
        input.data.resize_(real.size()).copy_(real)
        label.data.resize_(input.size()[0]).fill_(real_label)
        output = netd(input)
        output.backward(one)  # WGAN: backprop the raw score, no log loss
        D_x = output.data.mean()
        noise.data.resize_(input.size()[0], opt.nz, 1, 1).normal_(0, 1)
        fake_pic = netg(noise).detach()  # detach: no gradient into G here
        output2 = netd(fake_pic)
        label.data.fill_(fake_label)
        output2.backward(mone)  # opposite sign for fake scores
        D_x2 = output2.data.mean()
        optimizerD.step()
        # clamp critic weights to [-c, c] (Lipschitz constraint, WGAN point 3)
        for parm in netd.parameters():
            parm.data.clamp_(-opt.clamp_num, opt.clamp_num)
        #### train the generator G ####
        if t.rand(1)[0] > 0.8:
            # random gate -> critic:generator training ratio is roughly 5:1
            netg.zero_grad()
            label.data.fill_(real_label)
            noise.data.normal_(0, 1)
            fake_pic = netg(noise)
            output = netd(fake_pic)
            output.backward(one)
            optimizerG.step()
            # only the critic is clamped — the generator's weights are NOT
            D_G_z2 = output.data.mean()
        # periodically save sample grids from the fixed noise plus a real batch
        if ii % 100 == 0 and ii > 0:
            fake_u = netg(fixed_noise)
            vutil.save_image(fake_u.data, 'wgan/fake%s_%s.png' % (epoch, ii))
            vutil.save_image(real, 'wgan/real%s_%s.png' % (epoch, ii))
# Persist the trained WGAN critic and generator weights.
t.save(netd.state_dict(),'epoch_netd.pth')
t.save(netg.state_dict(),'epoch_netg.pth')