1 Overview
what: given a sentence, or a set of requirements, generate the image those requirements describe.
This summary mainly covers deconvolution and the GAN (generative adversarial network, GAN).
2 Deconvolution and image generation
what: deconvolution can be viewed as the reverse operation of convolution, but not literally so: it is not obtained by simply running convolution backwards. Given features, it generates an input-like image in the reverse direction, and the kernel used in the deconvolution differs from the one used in the convolution.
effect: convolution makes a large image progressively smaller, while deconvolution can make an image progressively larger.
2.1 The deconvolution operation
Different kernel: rotating the convolution kernel by 180 degrees gives the kernel used in the deconvolution operation.
padding: if the image size should stay unchanged after deconvolution, the required padding must be computed and added to the input image.
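Both points can be checked directly in PyTorch. The sketch below is only an illustration (the 7x7 input, 3x3 kernel, and single channel are arbitrary choices): it verifies that a transposed convolution equals an ordinary convolution with the 180-degree-rotated kernel plus full padding, and shows which padding keeps the output size unchanged.

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 7, 7)
w = torch.randn(1, 1, 3, 3)

# deconvolution kernel = convolution kernel rotated 180 degrees:
# a transposed convolution equals an ordinary convolution with the flipped kernel
# and "full" padding of kernel_size - 1 (single channel here; with several channels
# the in/out channel axes of the kernel also swap)
deconv = F.conv_transpose2d(x, w)
conv_flipped = F.conv2d(x, w.flip([-1, -2]), padding=2)
print(torch.allclose(deconv, conv_flipped, atol=1e-5))  # True

# with no padding the 7x7 input grows to (7 - 1) + 3 = 9
print(deconv.shape)                                      # torch.Size([1, 1, 9, 9])
# padding=1 keeps a 7x7 input at 7x7 for a 3x3 kernel with stride 1
print(F.conv_transpose2d(x, w, padding=1).shape)         # torch.Size([1, 1, 7, 7])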
2.2 Unpooling
There are many ways to unpool. One convolution-based trick makes the pooling step approximately unnecessary (because the effect is very similar): give the convolution a stride. The stride is the number of positions a kernel moves from the current window to the next one after finishing its computation; the default stride is 1. A convolution with a stride greater than 1 behaves much like a stride-1 convolution followed by pooling, so such a convolution can be seen as stride-1 convolution plus pooling with the pooling step omitted.
Effect of a convolution with a stride greater than 1: the output image is smaller than with a smaller stride, so deconvolution must handle this case as well.
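A quick shape check of this equivalence (the 28x28 input, 3x3 kernel, and 2x2 pooling window are arbitrary illustration choices, not values from the text):

import torch
import torch.nn as nn

x = torch.randn(1, 1, 28, 28)

# stride-1 convolution followed by 2x2 max pooling ...
conv_pool = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3, stride=1, padding=1),
    nn.MaxPool2d(2),
)
# ... produces the same output size as a single stride-2 convolution
strided = nn.Conv2d(1, 4, kernel_size=3, stride=2, padding=1)

print(conv_pool(x).shape)  # torch.Size([1, 4, 14, 14])
print(strided(x).shape)    # torch.Size([1, 4, 14, 14])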
2.3 Deconvolution and fractional strides
A convolution with a stride greater than 1 can be recovered by a deconvolution with a fractional stride: blank (zero) points are inserted between the pixels of the input image, and the larger the stride of the original convolution, the more blank points the deconvolution inserts between neighbouring pixels.
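In PyTorch this is what ConvTranspose2d with stride > 1 does. A small sketch (the 14x14 input is an arbitrary choice; kernel 4, stride 2, padding 1 is the usual size-doubling block):

import torch
import torch.nn as nn

x = torch.randn(1, 1, 14, 14)

# stride-2 transposed convolution: conceptually, one blank row/column is inserted
# between neighbouring input pixels before a stride-1 convolution, doubling the size
up2 = nn.ConvTranspose2d(1, 1, kernel_size=4, stride=2, padding=1)
print(up2(x).shape)  # torch.Size([1, 1, 28, 28])

# a larger stride inserts more blanks and gives a larger output:
# output = (in - 1) * stride - 2 * padding + kernel_size
up3 = nn.ConvTranspose2d(1, 1, kernel_size=4, stride=3, padding=1)
print(up3(x).shape)  # torch.Size([1, 1, 41, 41])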
2.4 Batch normalization
Concept: a linear layer inserted between each network layer and the nonlinearity that follows it, computing y = a*x + b, where a and b are learnable parameters and x is the input normalized over the batch: (x - mean(x)) / std(x).
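PyTorch's nn.BatchNorm2d applies exactly this formula per channel. A small check (with a and b still at their initial values 1 and 0, and an arbitrary random batch):

import torch
import torch.nn as nn

x = torch.randn(8, 3, 4, 4)      # a batch of 8 three-channel feature maps
bn = nn.BatchNorm2d(3)           # a (weight) is initialised to 1, b (bias) to 0

# the same formula by hand: y = a * (x - mean(x)) / std(x) + b, computed per channel
mean = x.mean(dim=(0, 2, 3), keepdim=True)
std = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True).add(bn.eps).sqrt()
manual = (x - mean) / std

print(torch.allclose(bn(x), manual, atol=1e-5))  # True

At evaluation time BatchNorm2d switches from batch statistics to its running averages, so the check above only matches in training mode.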
3 Image generation - a minimum mean-squared-error model
3.1 Idea
The input is a digit and the output is a handwritten image of that digit; a deconvolution network implements this mapping from input to output.
3.2 Code implementation
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as util
import matplotlib.pyplot as pyplot
import numpy as np
import os

output_img_size = 28
input_dim = 100
channel_num = 1
features_num = 64
batch_size = 64

print(f'prepare datasets begin')
use_cuda = torch.cuda.is_available()
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
itype = torch.cuda.LongTensor if use_cuda else torch.LongTensor

train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
index_verify = range(len(test_dataset))[:5000]
index_test = range(len(test_dataset))[5000:]
sampler_verify = torch.utils.data.sampler.SubsetRandomSampler(index_verify)
sampler_test = torch.utils.data.sampler.SubsetRandomSampler(index_test)
verify_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, sampler=sampler_verify)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, sampler=sampler_test)


class AntiCNN(nn.Module):
    def __init__(self):
        super(AntiCNN, self).__init__()
        # three transposed-convolution (deconvolution) blocks: 1x1 -> 5x5 -> 13x13 -> 28x28
        self.model = nn.Sequential()
        self.model.add_module('deconv1', nn.ConvTranspose2d(input_dim, features_num * 2, 5, 2, 0, bias=False))
        self.model.add_module('batch_norm1', nn.BatchNorm2d(features_num * 2))
        self.model.add_module('relu1', nn.ReLU(True))
        self.model.add_module('deconv2', nn.ConvTranspose2d(features_num * 2, features_num, 5, 2, 0, bias=False))
        self.model.add_module('batch_norm2', nn.BatchNorm2d(features_num))
        self.model.add_module('relu2', nn.ReLU(True))
        self.model.add_module('deconv3', nn.ConvTranspose2d(features_num, channel_num, 4, 2, 0, bias=False))
        self.model.add_module('sigmoid', nn.Sigmoid())

    def forward(self, input):
        output = input
        for _, module in self.model.named_children():
            output = module(output)
        return output


def weight_init(module):
    # optional DCGAN-style initialisation (not applied in main below)
    class_name = module.__class__.__name__
    if class_name.find('Conv') != -1:
        module.weight.data.normal_(0, 0.02)  # mean and std
    if class_name.find('Norm') != -1:
        module.weight.data.normal_(1, 0.02)


def resize_to_img(img):
    # expand the single-channel output to three channels for saving
    return img.data.expand(batch_size, 3, output_img_size, output_img_size)


def imgshow(input, title=None):
    if input.size()[0] > 1:
        input = input.numpy().transpose((1, 2, 0))
    else:
        input = input[0].numpy()
    min_val, max_val = np.amin(input), np.amax(input)
    if max_val > min_val:
        input = (input - min_val) / (max_val - min_val)
    pyplot.imshow(input)
    if title:
        pyplot.title(title)
    pyplot.pause(0.001)


def main():
    net = AntiCNN()
    net = net.cuda() if use_cuda else net
    criterion = nn.MSELoss()
    optimizer = optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)

    # a fixed batch of random digits, used to render sample images after each epoch
    samples = np.random.choice(10, batch_size)
    samples = torch.from_numpy(samples).type(dtype)
    samples = samples.view(batch_size, 1, 1, 1).expand(batch_size, input_dim, 1, 1)

    step = 0
    num_epoch = 2
    record = []
    print('train begin')
    for epoch in range(num_epoch):
        print(f'the no.{epoch} epoch')
        train_loss = []
        for batch_index, (data, target) in enumerate(train_loader):
            # the loader yields (image, label); swap them so the digit label is the
            # network input and the image is the regression target
            target, data = data.clone().detach(), target.clone().detach()
            if use_cuda:
                target, data = target.cuda(), data.cuda()
            data = data.type(dtype)
            data = data.view(data.size(0), 1, 1, 1)
            data = data.expand(data.size(0), input_dim, 1, 1)
            net.train()
            output = net(data)
            loss = criterion(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            step += 1
            train_loss.append(loss.item())
            if batch_index % 300 == 0:
                net.eval()
                verify_loss = []
                for data, target in verify_loader:
                    target, data = data.clone().detach(), target.clone().detach()
                    if use_cuda:
                        target, data = target.cuda(), data.cuda()
                    data = data.type(dtype)
                    data = data.view(data.size(0), 1, 1, 1)
                    data = data.expand(data.size(0), input_dim, 1, 1)
                    output = net(data)
                    loss = criterion(output, target)
                    verify_loss.append(loss.item())
                print(f'now no.{batch_index} batch. train loss:{np.mean(train_loss):.4f}, verify loss:{np.mean(verify_loss):.4f}')
                record.append([np.mean(train_loss), np.mean(verify_loss)])
        # render and save the fixed sample digits at the end of every epoch
        with torch.no_grad():
            fake_u = net(samples)
            img = resize_to_img(fake_u)
            os.makedirs(os.path.realpath('./pytorch/jizhi/image_generate/temp1'), exist_ok=True)
            util.save_image(img, os.path.realpath(f'./pytorch/jizhi/image_generate/temp1/fake{epoch}.png'))
    pyplot.show()


if __name__ == '__main__':
    main()
The generated images come out very blurry. A likely reason is that the mean squared error pushes the network toward the average of all the handwritten digit images, and because the individual images share no sharp common pattern, that average is itself blurry. What can be done about it? One option is to bring in the handwritten-digit recognizer built earlier to help correct the MSE.
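A minimal sketch of that idea, assuming `classifier` is the previously trained recognizer and that the weighting factor `alpha` is an arbitrary choice, not something from the text:

import torch.nn as nn

mse = nn.MSELoss()
ce = nn.CrossEntropyLoss()

def combined_loss(classifier, fake_img, real_img, digit_labels, alpha=0.1):
    # the pixel-level term keeps the generated image close to real samples ...
    pixel_term = mse(fake_img, real_img)
    # ... while the recognizer pushes it toward being classified as the right digit
    class_term = ce(classifier(fake_img), digit_labels)
    return pixel_term + alpha * class_term

During training the recognizer's parameters would be kept frozen (requires_grad_(False)) so that only the generator is updated by this loss.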