论文:Glow: Generative Flow with Invertible 1x1 Convolutions
代码:pytorch版本:rosinality/glow-pytorch: PyTorch implementation of Glow (github.com)
正版是TensorFlow版本 openai的
参考csdn文章:Glow-pytorch复现github项目_pytorch glow-CSDN博客
(pytorch进阶之路)NormalizingFlow标准流_normalizing flow-CSDN博客
需要先看一下b站的Flow的讲解Flow-based Generative Model_哔哩哔哩_bilibili P59
本csdn文的目标:跑通代码+理解原理(不包含论文结果部分解读)
目录
1 引言
2 背景:
3 Generative Flow
Glow模块的整体代码:
Block模块:
Flow模块:
3.1 Actnorm: scale and bias layer with data dependent initialization
3.2 Invertible 1x1 convolution 可逆1x1卷积
3.3 Affine Coupling Layers 仿射耦合层
train部分
1 引言
基于flow模型改进,提出Glow
2 背景:
之前是基于flow的生成模型,我们的目标是从z(一个普通的分布)拟合到x(真实的分布),理解为从图A变为图B,而且要求这个过程是可逆的。
模型为G(x),目标最大化极大似然(最大似然理解为当参数为变量时,X=x的概率最大化):
也就是最后的这个。即最小化:
其中,flow的意思就是多个G连起来:
最终最大化下面这个,即:
其中,z的分布一般选取为标准正态分布(均值为0)。函数G为双射可逆函数,可以从z可逆地映射回去。
在计算方面,最后可以等于雅可比行列式的对角线。
3 Generative Flow
flow的每一步都由actnorm(3.1)、一个可逆的1x1卷积(3.2)和一个耦合层(3.3)组成。flow的深度为K,层数为L,下图。
Glow模块的整体代码:
class Glow(nn.Module):
    """Full Glow model: a stack of `n_block` Blocks (L in the paper), each
    containing `n_flow` Flow steps (K in the paper)."""

    def __init__(self, in_channel, n_flow, n_block, affine=True, conv_lu=True):
        super().__init__()
        self.blocks = nn.ModuleList()  # stack of Blocks (figure b in the paper)
        n_channel = in_channel
        for _ in range(n_block - 1):
            self.blocks.append(Block(n_channel, n_flow, affine=affine, conv_lu=conv_lu))
            n_channel *= 2  # each split Block doubles the channel count downstream
        # The final Block does not split its output.
        self.blocks.append(Block(n_channel, n_flow, split=False, affine=affine))

    def forward(self, input):
        log_p_sum = 0
        logdet = 0
        out = input
        z_outs = []  # latent z emitted by each Block
        for block in self.blocks:
            out, det, log_p, z_new = block(out)
            z_outs.append(z_new)
            logdet = logdet + det  # accumulate log-determinants
            if log_p is not None:
                log_p_sum = log_p_sum + log_p  # accumulate prior log-probs
        return log_p_sum, logdet, z_outs

    def reverse(self, z_list, reconstruct=False):
        # Walk the Blocks in reverse order, feeding each its latent slice.
        for i, block in enumerate(self.blocks[::-1]):
            if i == 0:
                # Deepest Block: its running output and its latent coincide.
                input = block.reverse(z_list[-1], z_list[-1], reconstruct=reconstruct)
            else:
                input = block.reverse(input, z_list[-(i + 1)], reconstruct=reconstruct)
        return input
Block模块:
class Block(nn.Module):
    """One Glow Block: squeeze -> n_flow Flow steps -> (optional) split.

    With split=True, half the channels become a latent z scored against a
    learned Gaussian prior; the other half continue to the next Block.
    """

    def __init__(self, in_channel, n_flow, split=True, affine=True, conv_lu=True):
        super().__init__()
        squeeze_dim = in_channel * 4  # squeeze multiplies channels by 4
        self.flows = nn.ModuleList()
        for _ in range(n_flow):
            self.flows.append(Flow(squeeze_dim, affine=affine, conv_lu=conv_lu))
        self.split = split
        # The prior net outputs (mean, log_sd); channel counts differ per mode.
        if split:
            self.prior = ZeroConv2d(in_channel * 2, in_channel * 4)
        else:
            self.prior = ZeroConv2d(in_channel * 4, in_channel * 8)

    def forward(self, input):
        b_size, n_channel, height, width = input.shape
        # Squeeze: [B, C, H, W] -> [B, 4C, H/2, W/2] by folding 2x2 patches.
        squeezed = input.view(b_size, n_channel, height // 2, 2, width // 2, 2)
        squeezed = squeezed.permute(0, 1, 3, 5, 2, 4)
        out = squeezed.contiguous().view(b_size, n_channel * 4, height // 2, width // 2)
        logdet = 0
        for flow in self.flows:
            out, det = flow(out)
            logdet = logdet + det
        if self.split:
            # Half the channels become z, the rest parameterize its prior.
            out, z_new = out.chunk(2, 1)
            mean, log_sd = self.prior(out).chunk(2, 1)
            log_p = gaussian_log_p(z_new, mean, log_sd)
            log_p = log_p.view(b_size, -1).sum(1)
        else:
            # Last Block: the prior is conditioned on a zero tensor.
            zero = torch.zeros_like(out)
            mean, log_sd = self.prior(zero).chunk(2, 1)
            log_p = gaussian_log_p(out, mean, log_sd)
            log_p = log_p.view(b_size, -1).sum(1)
            z_new = out
        return out, logdet, log_p, z_new

    def reverse(self, output, eps=None, reconstruct=False):
        # For the deepest Block, `output` and `eps` are the same latent;
        # otherwise `output` comes from the next Block and `eps` is this
        # Block's latent slice.
        input = output
        if reconstruct:
            if self.split:
                input = torch.cat([output, eps], 1)
            else:
                input = eps
        else:
            if self.split:
                mean, log_sd = self.prior(input).chunk(2, 1)
                z = gaussian_sample(eps, mean, log_sd)
                input = torch.cat([output, z], 1)
            else:
                zero = torch.zeros_like(input)
                mean, log_sd = self.prior(zero).chunk(2, 1)
                z = gaussian_sample(eps, mean, log_sd)
                input = z
        for flow in self.flows[::-1]:
            input = flow.reverse(input)
        # Unsqueeze: invert the 2x2 patch folding done in forward().
        b_size, n_channel, height, width = input.shape
        unsqueezed = input.view(b_size, n_channel // 4, 2, 2, height, width)
        unsqueezed = unsqueezed.permute(0, 1, 4, 2, 5, 3)
        unsqueezed = unsqueezed.contiguous().view(
            b_size, n_channel // 4, height * 2, width * 2
        )
        return unsqueezed
Flow模块:
class Flow(nn.Module):
    """One flow step: ActNorm -> invertible 1x1 convolution -> coupling."""

    def __init__(self, in_channel, affine=True, conv_lu=True):
        super().__init__()
        self.actnorm = ActNorm(in_channel)
        # LU-parameterized 1x1 conv makes the log-determinant cheap.
        if conv_lu:
            self.invconv = InvConv2dLU(in_channel)
        else:
            self.invconv = InvConv2d(in_channel)
        self.coupling = AffineCoupling(in_channel, affine=affine)

    def forward(self, input):
        out, logdet = self.actnorm(input)
        out, det1 = self.invconv(out)
        out, det2 = self.coupling(out)
        logdet = logdet + det1
        if det2 is not None:  # additive coupling returns logdet=None
            logdet = logdet + det2
        return out, logdet

    def reverse(self, output):
        # Invert the three sub-layers in the opposite order.
        input = self.coupling.reverse(output)
        input = self.invconv.reverse(input)
        return self.actnorm.reverse(input)
3.1 Actnorm: scale and bias layer with data dependent initialization
之前提出批归一化来缓解训练深度模型时遇到的问题。然而,由于批处理归一化(batch normalization)所增加的激活噪声的方差与每个GPU或其他处理单元(PU)的小批(minibatch)大小成反比,因此已知每个PU的小批大小会降低性能。因此,minibatch=1. 我们提出了一个actnorm层(用于激活归一化),它使用每个通道的尺度和偏置参数执行激活的仿射变换,类似于批量归一化。这些参数被初始化,使得每个通道的事后激活具有零均值和给定初始小批量数据的单位方差。这是数据依赖初始化的一种形式(Salimans and Kingma 2016)。初始化后,尺度和偏差被视为独立于数据的常规可训练参数。(没怎么懂,看代码吧)
在Flow模块中的第一层就是ActNorm。这一步其实就是一个标准化,对于input(经过squeezed)【batch,12,32,32 】进行每个通道的标准化,用每个通道,例如3通道计算batch*h*w的均值,【1,12,1,1】,标准差也同样,然后进行(x-均值)/(标准差+1e-6) 标准化。因为要可逆,需要计算det,为系数的log求和,其实就是1/(标准差+1e-6)的log求和。
class ActNorm(nn.Module):
    """Per-channel affine normalization with data-dependent initialization.

    On the first forward pass, `loc` and `scale` are set from that batch so
    the output has (approximately) zero mean and unit variance per channel;
    afterwards they are ordinary trainable parameters.
    """

    def __init__(self, in_channel, logdet=True):
        super().__init__()
        self.loc = nn.Parameter(torch.zeros(1, in_channel, 1, 1))   # per-channel bias
        self.scale = nn.Parameter(torch.ones(1, in_channel, 1, 1))  # per-channel scale
        # Flag buffer: 0 until data-dependent init has run (persists in state_dict).
        self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
        self.logdet = logdet  # whether forward() also returns the log-determinant

    def initialize(self, input):
        with torch.no_grad():
            # [B, C, H, W] -> [C, B*H*W]: per-channel statistics over the batch.
            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            mean = (
                flatten.mean(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3)
            )
            std = (
                flatten.std(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3)
            )
            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input):
        _, _, height, width = input.shape
        if self.initialized.item() == 0:
            self.initialize(input)
            self.initialized.fill_(1)
        # log|det| = H * W * sum_c log|scale_c| (scale applies at every pixel).
        log_abs = logabs(self.scale)
        logdet = height * width * torch.sum(log_abs)
        if self.logdet:
            return self.scale * (input + self.loc), logdet
        else:
            return self.scale * (input + self.loc)

    def reverse(self, output):
        return output / self.scale - self.loc
3.2 Invertible 1 1 convolution 可逆1*1卷积
在Flow模块中的第二层,根据是否LU,选择是否带LU操作的1*1可逆卷积:
if conv_lu:self.invconv = InvConv2dLU(in_channel)else:self.invconv = InvConv2d(in_channel)
class InvConv2dLU(nn.Module):
    """Invertible 1x1 convolution with an LU-parameterized weight.

    W = P @ L @ (U + diag(sign(s) * exp(log|s|))), where P is a fixed
    permutation, L is unit lower-triangular, U is strictly upper-triangular,
    and L, U, log|s| are trainable.  Then log|det W| = sum(log|s|), which
    makes the log-determinant O(C) instead of O(C^3).
    """

    def __init__(self, in_channel):
        super().__init__()
        weight = np.random.randn(in_channel, in_channel)
        q, _ = la.qr(weight)  # random orthogonal init => logdet starts near 0
        # PLU-decompose the orthogonal matrix: q = P @ L @ U.
        w_p, w_l, w_u = la.lu(q.astype(np.float32))
        w_s = np.diag(w_u)     # diagonal of U
        w_u = np.triu(w_u, 1)  # strictly upper-triangular part of U
        u_mask = np.triu(np.ones_like(w_u), 1)
        l_mask = u_mask.T
        w_p = torch.from_numpy(w_p)
        w_l = torch.from_numpy(w_l)
        w_s = torch.from_numpy(w_s.copy())
        w_u = torch.from_numpy(w_u)
        # Fixed (non-trainable) pieces.
        self.register_buffer("w_p", w_p)
        self.register_buffer("u_mask", torch.from_numpy(u_mask))
        self.register_buffer("l_mask", torch.from_numpy(l_mask))
        self.register_buffer("s_sign", torch.sign(w_s))
        self.register_buffer("l_eye", torch.eye(l_mask.shape[0]))
        # Trainable pieces.
        self.w_l = nn.Parameter(w_l)
        self.w_s = nn.Parameter(logabs(w_s))
        self.w_u = nn.Parameter(w_u)

    def forward(self, input):
        _, _, height, width = input.shape
        weight = self.calc_weight()  # [C, C, 1, 1] kernel for the 1x1 conv
        out = F.conv2d(input, weight)
        logdet = height * width * torch.sum(self.w_s)
        return out, logdet

    def calc_weight(self):
        # Reassemble W from its factors; masks keep L/U strictly triangular.
        weight = (
            self.w_p
            @ (self.w_l * self.l_mask + self.l_eye)  # @ is matrix multiply
            @ ((self.w_u * self.u_mask) + torch.diag(self.s_sign * torch.exp(self.w_s)))
        )
        return weight.unsqueeze(2).unsqueeze(3)

    def reverse(self, output):
        # Invert the same (current) weight used by forward().
        weight = self.calc_weight()
        return F.conv2d(output, weight.squeeze().inverse().unsqueeze(2).unsqueeze(3))
自己定义的权重W,(cxc),与输入的tensor h (h x w x c)之间进行卷积计算,因此,log_det的计算为:
但是,detW的计算复杂,为了简化计算复杂度,提出使用LU分解,把W参数化:
P为置换矩阵(不参与更新),L为下三角矩阵(对角线为0),U为上三角矩阵(对角线为0),diag(s)为分解时候的上三角矩阵plu的u的对角线,U仅仅只是u的对角线变为0,这样才符合plu分解,即,W=p*l*u。这样,log_det可以简化为:
对于较大的通道数c,可以大大节省。并且,除了P不参与更新外,L、U、s都参与更新。
也提供了不进行PLU分解的版本:
class InvConv2d(nn.Module):
    """Invertible 1x1 convolution without LU decomposition.

    The weight is a full `in_channel x in_channel` matrix applied as a 1x1
    conv, initialized to a random orthogonal matrix (so the initial
    log-determinant is ~0).  logdet = h * w * log|det W|.
    """

    def __init__(self, in_channel):
        super().__init__()
        weight = torch.randn(in_channel, in_channel)
        # Fix: torch.qr is deprecated (and removed in recent PyTorch);
        # torch.linalg.qr is the supported equivalent (reduced mode).
        q, _ = torch.linalg.qr(weight)
        weight = q.unsqueeze(2).unsqueeze(3)  # [C, C] -> [C, C, 1, 1] kernel
        self.weight = nn.Parameter(weight)

    def forward(self, input):
        _, _, height, width = input.shape
        out = F.conv2d(input, self.weight)
        # slogdet in double precision for numerical stability.
        logdet = (
            height * width * torch.slogdet(self.weight.squeeze().double())[1].float()
        )
        return out, logdet

    def reverse(self, output):
        # Apply the inverse weight matrix as a 1x1 conv.
        return F.conv2d(output, self.weight.squeeze().inverse().unsqueeze(2).unsqueeze(3))
3.3 Affine Coupling Layers 仿射耦合层
这一层在flow模块中的第三层
仿射耦合层是一种强大的可逆变换,其中正向函数、逆函数和对数行列式的计算效率很高。加性耦合层是s=1和log_det=0的特殊情况。
还是看代码吧。
Zero initialization:零初始化最后一个卷积。这样每个仿射耦合层最初执行一个恒等函数,这有助于训练非常深的网络。也就是说,网络一开始输入近似等于输出,因为零初始化使网络输出的F和H初始都接近于0,耦合层近似为恒等变换。
class ZeroConv2d(nn.Module):
    """3x3 convolution whose weights, bias and learned output scale all start
    at zero, so the layer initially outputs zeros (identity-friendly init for
    the coupling layers and the prior)."""

    def __init__(self, in_channel, out_channel, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channel, out_channel, 3, padding=0)
        self.conv.weight.data.zero_()
        self.conv.bias.data.zero_()
        # Trainable per-channel log-scale; output is multiplied by exp(3 * scale).
        self.scale = nn.Parameter(torch.zeros(1, out_channel, 1, 1))

    def forward(self, input):
        # Pad with ones (not zeros), keeping spatial size through the 3x3 conv.
        padded = F.pad(input, [1, 1, 1, 1], value=1)
        hidden = self.conv(padded)  # all-zero at init because weights are zero
        return hidden * torch.exp(self.scale * 3)
class AffineCoupling(nn.Module):
    """Affine (or additive) coupling layer.

    Splits channels into (a, b); a passes through unchanged and drives a
    small conv net whose output parameterizes an elementwise affine (or
    additive) transform of b.  Additive coupling is the s=1, logdet=0 case.
    """

    def __init__(self, in_channel, filter_size=512, affine=True):
        super().__init__()
        self.affine = affine
        self.net = nn.Sequential(
            nn.Conv2d(in_channel // 2, filter_size, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(filter_size, filter_size, 1),
            nn.ReLU(inplace=True),
            # Zero-initialized last conv -> coupling starts near the identity.
            # Affine mode needs 2x channels (for s and t).
            ZeroConv2d(filter_size, in_channel if self.affine else in_channel // 2),
        )
        self.net[0].weight.data.normal_(0, 0.05)
        self.net[0].bias.data.zero_()
        self.net[2].weight.data.normal_(0, 0.05)
        self.net[2].bias.data.zero_()

    def forward(self, input):
        in_a, in_b = input.chunk(2, 1)  # split along the channel dim
        if self.affine:
            log_s, t = self.net(in_a).chunk(2, 1)
            # sigmoid(log_s + 2) instead of exp(log_s), as in the OpenAI code.
            s = torch.sigmoid(log_s + 2)
            # NOTE: shift-then-scale (in_b + t) * s, which differs from the
            # paper's s * in_b + t but is equally invertible.
            out_b = (in_b + t) * s
            logdet = torch.sum(torch.log(s).view(input.shape[0], -1), 1)
        else:
            net_out = self.net(in_a)
            out_b = in_b + net_out  # additive coupling: logdet contribution is 0
            logdet = None
        # The `a` half is passed through unchanged.
        return torch.cat([in_a, out_b], 1), logdet

    def reverse(self, output):
        out_a, out_b = output.chunk(2, 1)
        if self.affine:
            # out_a was unchanged in forward(), so the net reproduces s and t.
            log_s, t = self.net(out_a).chunk(2, 1)
            s = torch.sigmoid(log_s + 2)
            in_b = out_b / s - t  # invert: divide by s first, then subtract t
        else:
            net_out = self.net(out_a)
            in_b = out_b - net_out
        return torch.cat([out_a, in_b], 1)
代码实现部分,关于s的生成注释掉的部分与视频中讲解的一致,属于标准形式,后面用sigmoid生成openai代码中也是如此。
至此,Flow模块已经完成。论文方法部分也结束了。
在Flow模块外部还有squeezed操作,把图像切分为4块后,拼起来,通道变为12后再送入Flow块。后面还有一个split操作。这形成一个Block块。如果需要split操作,则输出的一半作为z,另一半作为out送到下游。看代码
首先,高斯分布的概率密度函数:
对于此概率密度函数取对数log,以e为底:注意下面的输入,log_sd是对标准差取对数,其中mean和log_sd都是可以训练的。
def gaussian_log_p(x, mean, log_sd):
    """Elementwise log-density of x under N(mean, exp(log_sd)^2).

    `log_sd` is the log of the standard deviation; both `mean` and `log_sd`
    are typically trainable outputs of the prior network.
    """
    quadratic = 0.5 * (x - mean) ** 2 / torch.exp(2 * log_sd)
    return -0.5 * log(2 * pi) - log_sd - quadratic
class Block(nn.Module):
    """One Glow Block: squeeze -> n_flow Flow steps -> (optional) split.

    With split=True, half the channels become a latent z scored against a
    learned Gaussian prior; the other half continue to the next Block.
    """

    def __init__(self, in_channel, n_flow, split=True, affine=True, conv_lu=True):
        super().__init__()
        squeeze_dim = in_channel * 4  # squeeze multiplies channels by 4
        self.flows = nn.ModuleList()
        for _ in range(n_flow):
            self.flows.append(Flow(squeeze_dim, affine=affine, conv_lu=conv_lu))
        self.split = split
        # The prior net outputs (mean, log_sd); channel counts differ per mode.
        if split:
            self.prior = ZeroConv2d(in_channel * 2, in_channel * 4)
        else:
            self.prior = ZeroConv2d(in_channel * 4, in_channel * 8)

    def forward(self, input):
        b_size, n_channel, height, width = input.shape
        # Squeeze: [B, C, H, W] -> [B, 4C, H/2, W/2] by folding 2x2 patches.
        squeezed = input.view(b_size, n_channel, height // 2, 2, width // 2, 2)
        squeezed = squeezed.permute(0, 1, 3, 5, 2, 4)
        out = squeezed.contiguous().view(b_size, n_channel * 4, height // 2, width // 2)
        logdet = 0
        for flow in self.flows:
            out, det = flow(out)
            logdet = logdet + det
        if self.split:
            # Half the channels become z, the rest parameterize its prior.
            out, z_new = out.chunk(2, 1)
            mean, log_sd = self.prior(out).chunk(2, 1)
            log_p = gaussian_log_p(z_new, mean, log_sd)
            log_p = log_p.view(b_size, -1).sum(1)
        else:
            # Last Block: the prior is conditioned on a zero tensor.
            zero = torch.zeros_like(out)
            mean, log_sd = self.prior(zero).chunk(2, 1)
            log_p = gaussian_log_p(out, mean, log_sd)
            log_p = log_p.view(b_size, -1).sum(1)
            z_new = out
        return out, logdet, log_p, z_new

    def reverse(self, output, eps=None, reconstruct=False):
        # For the deepest Block, `output` and `eps` are the same latent;
        # otherwise `output` comes from the next Block and `eps` is this
        # Block's latent slice.
        input = output
        if reconstruct:
            if self.split:
                input = torch.cat([output, eps], 1)
            else:
                input = eps
        else:
            if self.split:
                mean, log_sd = self.prior(input).chunk(2, 1)
                z = gaussian_sample(eps, mean, log_sd)
                input = torch.cat([output, z], 1)
            else:
                zero = torch.zeros_like(input)
                mean, log_sd = self.prior(zero).chunk(2, 1)
                z = gaussian_sample(eps, mean, log_sd)
                input = z
        for flow in self.flows[::-1]:
            input = flow.reverse(input)
        # Unsqueeze: invert the 2x2 patch folding done in forward().
        b_size, n_channel, height, width = input.shape
        unsqueezed = input.view(b_size, n_channel // 4, 2, 2, height, width)
        unsqueezed = unsqueezed.permute(0, 1, 4, 2, 5, 3)
        unsqueezed = unsqueezed.contiguous().view(
            b_size, n_channel // 4, height * 2, width * 2
        )
        return unsqueezed
最后Glow模型:
class Glow(nn.Module):
    """Full Glow model: a stack of `n_block` Blocks (L in the paper), each
    containing `n_flow` Flow steps (K in the paper)."""

    def __init__(self, in_channel, n_flow, n_block, affine=True, conv_lu=True):
        super().__init__()
        self.blocks = nn.ModuleList()  # stack of Blocks (figure b in the paper)
        n_channel = in_channel
        for _ in range(n_block - 1):
            self.blocks.append(Block(n_channel, n_flow, affine=affine, conv_lu=conv_lu))
            n_channel *= 2  # each split Block doubles the channel count downstream
        # The final Block does not split its output.
        self.blocks.append(Block(n_channel, n_flow, split=False, affine=affine))

    def forward(self, input):
        log_p_sum = 0
        logdet = 0
        out = input
        z_outs = []  # latent z emitted by each Block
        for block in self.blocks:
            out, det, log_p, z_new = block(out)
            z_outs.append(z_new)
            logdet = logdet + det  # accumulate log-determinants
            if log_p is not None:
                log_p_sum = log_p_sum + log_p  # accumulate prior log-probs
        return log_p_sum, logdet, z_outs

    def reverse(self, z_list, reconstruct=False):
        # Walk the Blocks in reverse order, feeding each its latent slice.
        for i, block in enumerate(self.blocks[::-1]):
            if i == 0:
                # Deepest Block: its running output and its latent coincide.
                input = block.reverse(z_list[-1], z_list[-1], reconstruct=reconstruct)
            else:
                input = block.reverse(input, z_list[-(i + 1)], reconstruct=reconstruct)
        return input
train部分
from tqdm import tqdm
import numpy as np
from PIL import Image
from math import log, sqrt, pi
import argparse

import torch
from torch import nn, optim
from torch.autograd import Variable, grad
from torch.utils.data import DataLoader
import torch.utils.data
from torchvision.datasets import CIFAR10
from torchvision import datasets, transforms, utils

from model import Glow

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

parser = argparse.ArgumentParser(description="Glow trainer")
parser.add_argument("--iter1", default=200000, type=int, help="maximum iterations")
parser.add_argument("--n_flow", default=32, type=int, help="number of flows in each block")
parser.add_argument("--n_block", default=4, type=int, help="number of blocks")
parser.add_argument("--no_lu", action="store_true", help="use plain convolution instead of LU decomposed version")
parser.add_argument("--affine", action="store_true", help="use affine coupling instead of additive")
parser.add_argument("--n_bits", default=5, type=int, help="number of bits")
parser.add_argument("--lr", default=1e-4, type=float, help="learning rate")
parser.add_argument("--temp", default=0.7, type=float, help="temperature of sampling")
parser.add_argument("--n_sample", default=20, type=int, help="number of samples")


def data_tr_1(x):
    # Resize a PIL image to 64x64 and normalize to [-1, 1] as a CHW tensor.
    x = x.resize((64, 64))
    x = np.array(x, dtype='float32') / 255
    x = (x - 0.5) / 0.5
    x = x.transpose((2, 0, 1))
    x = torch.from_numpy(x)
    return x


def sample_data():
    """Infinite generator over shuffled CIFAR-10 training batches."""
    transform = transforms.Compose([
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    dataset = CIFAR10('./data', train=True, transform=transform, download=True)
    loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
    loader = iter(loader)
    while True:
        try:
            yield next(loader)
        except StopIteration:
            # Epoch exhausted: rebuild the loader and keep yielding.
            loader = DataLoader(dataset, shuffle=True, batch_size=64, num_workers=4)
            loader = iter(loader)
            yield next(loader)


def calc_z_shapes(n_channel, input_size, n_block):
    """Shapes of the latent z emitted by each Block.

    For input (3, 64, 64) and n_block=4:
    [(6, 32, 32), (12, 16, 16), (24, 8, 8), (96, 4, 4)]
    """
    z_shapes = []
    for _ in range(n_block - 1):
        input_size //= 2  # each Block halves the spatial size
        n_channel *= 2    # and doubles the channels (after split)
        z_shapes.append((n_channel, input_size, input_size))
    input_size //= 2
    z_shapes.append((n_channel * 4, input_size, input_size))  # last Block: no split
    return z_shapes


def calc_loss(log_p, logdet, image_size, n_bins):
    # Bits-per-dim negative log-likelihood with the dequantization constant.
    n_pixel = image_size * image_size * 3
    loss = -log(n_bins) * n_pixel
    loss = loss + logdet + log_p
    return (
        (-loss / (log(2) * n_pixel)).mean(),
        (log_p / (log(2) * n_pixel)).mean(),
        (logdet / (log(2) * n_pixel)).mean(),
    )


def train(args, model, optimizer):
    dataset = iter(sample_data())
    n_bins = 2.0 ** args.n_bits  # number of quantization levels (32 for 5 bits)

    # Fixed latent samples used to visualize progress during training.
    z_sample = []
    z_shapes = calc_z_shapes(3, image_size, n_block)
    for z in z_shapes:
        z_new = torch.randn(n_sample, *z) * temp  # temperature-scaled noise
        z_sample.append(z_new.to(device))

    with tqdm(range(iter1)) as pbar:
        for i in pbar:
            image, _ = next(dataset)
            image = image.to(device)
            image = image * 255  # back to [0, 255]
            if args.n_bits < 8:
                # Quantize down to n_bits levels.
                image = torch.floor(image / 2 ** (8 - args.n_bits))
            image = image / n_bins - 0.5  # roughly [-0.5, 0.5)

            if i == 0:
                # First batch only triggers ActNorm's data-dependent init;
                # no gradient step is taken.
                with torch.no_grad():
                    log_p, logdet, _ = model.module(
                        image + torch.rand_like(image) / n_bins
                    )
                    continue
            else:
                # Uniform dequantization noise in [0, 1/n_bins).
                log_p, logdet, _ = model(image + torch.rand_like(image) / n_bins)

            logdet = logdet.mean()
            loss, log_p, log_det = calc_loss(log_p, logdet, image_size, n_bins)
            model.zero_grad()
            loss.backward()
            # Learning-rate warmup is disabled; a constant lr is used.
            warmup_lr = args.lr
            optimizer.param_groups[0]["lr"] = warmup_lr
            optimizer.step()

            pbar.set_description(
                f"Loss: {loss.item():.5f}; logP: {log_p.item():.5f}; logdet: {log_det.item():.5f}; lr: {warmup_lr:.7f}"
            )

            if i % 100 == 0:
                # Periodically decode the fixed latents into sample images.
                with torch.no_grad():
                    utils.save_image(
                        model_single.reverse(z_sample).cpu().data,
                        f"sample/{str(i + 1).zfill(6)}.png",
                        normalize=True,
                        nrow=10,
                        range=(-0.5, 0.5),
                    )
            if i % 10000 == 0:
                torch.save(model.state_dict(), f"checkpoint/model_{str(i + 1).zfill(6)}.pt")
                torch.save(optimizer.state_dict(), f"checkpoint/optim_{str(i + 1).zfill(6)}.pt")


if __name__ == "__main__":
    args = parser.parse_args()
    print(args)
    image_size = 64
    n_flow = args.n_flow
    n_block = args.n_block
    n_sample = args.n_sample
    temp = args.temp
    iter1 = args.iter1
    model_single = Glow(3, n_flow, n_block, affine=args.affine, conv_lu=not args.no_lu)
    model = nn.DataParallel(model_single)
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    train(args, model, optimizer)
数据集我选用的是cifar10,batch size设置为64,其余都是原本的默认值。loss为log_p与logdet相加后取负,也就是目标为最大化log_p,使输出逐渐接近高斯分布;而logdet项保证变换可逆,能够从z恢复回原始图像。
最后生成的结果
由于我的数据并没有分类存放,导致学习到的特征比较混乱,而且我也只是跑通代码理解原理而已。