第十六周:机器学习笔记

第十六周周报

  • 摘要
  • Abstratc
  • 一、机器学习
    • 1. Pointer Network(指针网络)
    • 2. 生成式对抗网络(Generative Adversarial Networks | GAN)——(上)
      • 2.1 Generator(生成器)
      • 2.2 Discriminator(判别器)
      • 2.3 Generator与 Discriminator 的训练过程
      • 2.4 GAN 的应用案例
  • 二、Pytorch学习
    • 1. 完整的模型验证套路
    • 2. Github开源项目——pytorch-CycleGAN-and-pix2pix
  • 总结

摘要

本周周报在机器学习的理论内容中,简单的介绍了Pointer Network 的运作原理,此外还详细描述了GAN,剖析了其内部结构,然后对其中的Generator 和 Discriminator的运作原理以及它们的互动方式进行简要的描述,最后还介绍了GAN的应用案例。在Pytorch的代码实践中,周报详细讲述了如何对模型进行验证的方法。此外,周报在最后还介绍了如何在GitHub上通读项目。

Abstratc

The working principle of the Pointer Network was briefly introduced in the theoretical content of machine learning within this week’s report. Additionally, a detailed description of Generative Adversarial Networks (GANs) was provided, including an analysis of their internal structure. The operating principles of the Generator and Discriminator, as well as their interactive mechanisms, were also briefly explained. Application cases of GANs were presented. Methods for model validation were detailed in the PyTorch code practice section. Furthermore, instructions on how to read projects on GitHub were introduced at the end of the report.

一、机器学习

1. Pointer Network(指针网络)

在十四周的周报中(链接如下:https://blog.csdn.net/Zcymatics/article/details/142406494?spm=1001.2014.3001.5501)
在Seq2Seq的tips的Copy Mechanism,我们了解到了pointer network可以实现copy这个功能,因此这一小节我们来了解一下pointer network。
在这里插入图片描述
pointer network最早的最早的呢是被用在想要解一系列演算法的问题

‌‌演算法是解决特定问题的一种有限步骤程序。‌ 
演算法是对特定问题求解方法和步骤的一种描述,它是指令的有限序列,其中每个指令表示一个或多个操作。
演化算法(Evolutionary Algorithms)是另一种与演算法相关的概念。
演化算法是一种基于群体的元启发式最优化算法,通过模拟生物演化机制如繁衍、变异、重组和选择来寻找最优解。

如下图所示
有如下data point。
我们要用这个自动的方法找出的data point来通通连在一起以后,他可以把其余的点包括进去
input十个都是2维的vector到neural network里面去,然后输出是 4 2 7 6 5 3(正好能把所有点囊括进去)
训练这个network 也是需要喂大量的数据,训练出来。
在这里插入图片描述
那这个point network要怎么解呢?
我们可以尝试以下用Seq 2 Sqe 去解决
总共有四个data point input,就是input为四个data point的坐标。
每一个output的选择是这四个point ,为1到4。
过程如下图所示:
encoder把这四个data point读进来,变成一个红色的vector。然后把红色的vector丢到decoder里面,
decoder就得到一个distribution,然后根据这个distribution做sample(比如使用argmax())
在这里插入图片描述
根据argmax 第一个输出token 1,第二个输出token 4,然后再下一个时间点输出token 2,直到输出end的时候就结束
看看硬train下去行不行得通
在这里插入图片描述
结果是行不通的
因为encoder是个RNN,所以今天input sequence的长短不一样的时候,encoder可以处理。
比如在训训练的时候都只input 50个点,但是在testing的时候input 100个点,encoder也可以处理的。
但是decoder就不一样
假设原来设定好的output就是50个点,那意味着说output这个vector就是1到50加上end。如果testing的时候变成input 100个点,那它就只有1到50加上end的,它永远选不了其他51到编号100的点了
在这里插入图片描述
所以Seq2Seq的方法是行不通的。
那要怎么办呢?
对attention的机制做了一下改造以后,让network可以动态的决定输出的set
将sequence这个第一个点到第四个点读进来。再加一个特别的符号x0和y0代表end
在这里插入图片描述
让我们来复习一下attention based model
在这里插入图片描述
attention其实就是一个当前的输入与输出的匹配度。
在上图中,即为h1和z0的匹配度
h1 为当前时刻RNN的hidden layer输出向量,而不是原始输入的词向量,z0为初始化向量,如rnn中的initial memory
其中的match为计算这两个向量的匹配度的模块,出来的α01即为由match算出来的相似度。

代入到下图
先产生一个key,这个key是z0
我们用这个z0对input去做attention,对每一个input就会产生一个attention的weight。
算出来之后,我们不需要加权求和。直接用soft-max之后的的结果作为distribution,使用argmax选择最大值,作为输出。
如下图所示
在这里插入图片描述
然后h1作为输入,选出h4,然后再用h4作为输入,以此类推…直到end出现
end出现代表end的attention的最大。整个process就结束了
在这里插入图片描述point network最经典是用在summarization上面(总结全文)
输入一个文章的文件,输出对这篇文章的总结
很多summary里面放的词汇,其实都是人名地名。而且我们的词汇长度有限,没可能把这些人名地名都包含在内。如果用Seq2Seq那么那些人名地名(即没有出现在词汇表上)被output的几率几乎为0
summary跟input document的关系,其实就是把input document取一些重要的词汇出来,拼接起来就是很好的summary了,因此我们需要使用point network。
summary加上point network如下图所示:
加上point network概念,会直接从attention distribution里面直接去做sample,即用这个distribution来决定summary里面应该产生什么样的word
然后再learn另外一个weight (这个weight为:P generation,这个P generation决定了要走右边的Pgen,还是走左边的(1-Pgen)。

即我们要用多少的传统的方法 加上 point network的output。
举例:那如果是pointer network这个部分,Argentina的几率是最大的
假设dataset里面也有Argentina,那就把它们的几率加起来。
这两个distribution的几率加起来后,再从总的distribution去决定summarization的decoder要产生什么样的word
在这里插入图片描述
其实point network的核心就是,不要weighted sum,直接拿weight score经过 soft-max 后作为distribution 最后直接在这个distribution上sample出几率最大的值作为output
这样我们就可以做到复制的效果,不必担心要分类的内容中(词汇表)是否包含输入的内容(人名、地名)

2. 生成式对抗网络(Generative Adversarial Networks | GAN)——(上)

到目前为止,我们学习到的网络本质上都是一个函数,即提供一个输入 x,网络就可以输出一个结果 y并可以应对不同类型的输入 x 和输出 y。
例如,当输入 x 是一张图片时,可以使用例如卷积神经网络等模型进行处理
当输入 x 是序列数据时,可以使用基于循环神经网络架构的模型进行处理,其中输出 y 既可以是数值、类别,也可以是一个序列。

接下来我们将学习一个新的知识——生成模型

2.1 Generator(生成器)

生成模型中的核心就是生成器,所谓生成器就是Network。
在模型输入时会将一个随机变量 z 与原始输入 x 一并输入到模型中,这个变量z是从随机分布中得到的。
那么 x 与 z要怎么同时作为input 输入到 network中呢?
有两种方法
①:向量拼接的方式将 x 和 z 一并输入。
②:或在 x、z 长度一样时,将二者的加和作为输入。

我们在这里可以看到跟前面所学的模型上的区别,在input上多了个 z ,但是这个 z 是特别的
z 特别之处在于其非固定性,即每一次我们使用网络时都会从一个随机分布中采样得到一个新的 z。
对于该随机分布的要求是其足够简单,可以较为容易地进行采样,或者可以直接写出该随机分布的函数,例如高斯分布(Gaussian distribution)、均匀分布(uniform distribution)等等。
每次我们输入一个 x 的同时,就会在随机分布中采样到 z ,然后经过Network得到最终的结果y。
同理,对于网络来说,其输出也不再固定,而变成了一个复杂的分布。
我们也将这种可以输出一个复杂分布的网络称为生成器。
在这里插入图片描述
为什么我们要输入一个 概率分布采样出来的 z 呢?
下面介绍一个视频预测的例子,即给模型一段的视频短片,然后让它预测接下来发生的事情。
视频环境是小精灵游戏,预测下一帧的游戏画面。
当然在实践中,我们为了保证高效训练,我们会将每一帧画面分割为很多块作为输入,并行分别进行预测。
在这里插入图片描述
在实践中,我们为了保证高效训练,我们会将每一帧画面分割为很多块作为输入,并行分别进行预测。
为了简化,假设网络是一次性输入的整个画面。
如果我们使用前几章介绍的基于监督学习的训练方法,我们得到的结果可能会是的十分模糊的甚至游戏中的角色消失、出现残影的。
如下图所示:

监督学习是根据已知输入和输出训练模型,用于预测或分类。
无监督学习没有标签的训练数据集,目标是探索数据中的结构和关系,例如聚类、降维和异常检测等。

在这里插入图片描述
为什么会出现这种情况呢?
因为我们提供数据的时候“喂”给generator的图片中面对转向问题,同时具备了向左或者向右移动的图片,所以在训练的时候,我们为了提升正确率会与原来的图片进行对比,不断逼近真实的图片,减少误差。
当我们在训练的时候,对于一条向左转的训练数据,网络得到的指示就是要学会游戏角色向左转的输出。同理,对于一条向右转的训练数据,网络得到的指示就是学会角色向右转的输出。
但是实际上这两种数据可能会被同时训练,所以当这个输出同时距离向左转和向右转最近,网络就会得到一个错误的结果————向左转是对的,向右转也是对的。 因此会产生残影或者消失等结果。

所以我们 input 的 z 这个概率分布就是为了防止这种错误输出的产生
其通过产生distribution,对不同的情况赋予不同的概率,从而让其通过generator(生成器)后产生一个复杂的distribution (这个distribution其实可以理解为输出下一帧所有可能的情况,例如下一帧预测向左移动,或者下一帧预测向右移动等等)。
举例来说,假设我们选择的 z 服从一个二项分布,即就只有 0 和 1 并且各占 50%。那么我们的网络就可以学到 z 采样到 1 的时候就向左转,采样到 0 的时候就向右转,这样就可以解决了。
在这里插入图片描述

其实说白了input 的 z 就是让答案是让网络有概率的输出一切可能的结果,或者说输出一个概率的分布,而不是原来的单一的输出。

为什么需要generator呢?
我们之所以需要generator,就是因为我们需要一些工具来帮我们创造一些东西。
可以比作让很多人一起处理一个问题,大家的回答都是有自己的想法的,回答的答案都是正确的,只是思路不同,就像头脑风暴一样。
所以生成模型也可以被理解为让模型自己拥有了创造的能力。
举个具体的例子,比如我们对机器人说,你知道有哪些童话故事吗?聊天机器人会回答会有很多,比如:安徒生童话、格林童话。这些回答都是机器人发挥想象力得到的,且都是正确的,但是就是没有一个标准的答案,所以生成模型的任务就是需要能够输出一个分布,或者说多个答案。

在生成模型中,最著名的就是生成式对抗网络(generative adversarial network),我们通常缩写为 GAN。
我们通过让机器生成动漫人物的脸来形象地介绍 GAN
GAN又分为以下两种:
1、无限制生成(un-conditional generation),也就是我们不需要原始输入 x。
2、条件型生成(conditional generation),即需要原始输入 x 。

对于无限制的 GAN
它的唯一输入就是 z,这里假设 z 为正态分布采样出的向量。其通常是一个低维的向量,例如 50、100 的维度
在这里插入图片描述
我们首先从正态分布中采样得到一个向量 z,并输入到生成器中,生成器会给我们一个对应的输出——一个动漫人物的脸。
我们分析一下生成器输出一个动漫人物面部的过程。
一张图片就是一个高维的向量,所以生成器实际上做的事情就是输出一个高维的向量,比如是一个 64×64 的图片(如果是彩色图片那么输出就是 64×64×3)。
当输入的向量 z 不同的时候,生成器的输出就会跟着改变,所以我们从正态分布中采样出不同的 z,得到的输出 y 也就会不同,动漫人脸照片也不同。当然,我们也可以选择其他的分布
但是其实不同分布之间的差距可能不会非常大,所以我们选择最常见的分布——正态分布(高斯分布)

2.2 Discriminator(判别器)

在GAN中,除了生成器外,还有一个非常重要的模块。叫做判别器即Discriminator,其实判别器是一个神经网络。
Discriminator会输入一张图片,输出一个值(scalar),其数值越大就证明这张图片约接近动漫人物。
如下图所示:
这里假设 1 是最大的值,画得很好的动漫图像输出就是 1,画的差的就输出0.5,再差一些就输出 0.1 等等。
判别器从本质来说与生成器一样也是神经网络,可以用卷积神经网络,也可以用 Transformer,只要能够产生出我们要的输入输出即可。
当然对于这个例子,因为输入是一张图片,所以选择卷积神经网络,因为其在处理图像上有非常大的优势。
在这里插入图片描述
生成器生成动漫图片的过程,如下图所示:
首先,对生成器进行初始化,因为初始化的生成器(记为版本v1)参数都是随机的,因此其画出来的效果其实跟马赛克差不多。
然后些图片就会给判别器识别,判别器学习的目标是成功分辨生成器输出的动漫图片。
在图中Generator为V1生成的图片对Discriminator V1的识别来说很容易,它只要看图片中是否有两个黑黑的眼睛即可。
接下来生成器就要通过训练调整里面的参数来骗过判别器。
假设判别器判断一张图片GAN中的Discriminator判别是不是真实图片的依据是看图片有没有眼睛。那Generator V2就需要输出有眼睛的图片可以“骗”过Discriminator V1
但是Discriminator也是会进化和升级的
Discriminator会试图分辨新的生成器与真实图片之间的差异。例如, Discriminator V2通过有没有嘴巴来识别真假。
所以Generator V3就会想办法去骗 Discriminator V2 ,比如把嘴巴加上去。
这个过程实际上就是Generator 与 Discriminator 的反复“博弈”,来“逼迫” Generator 产生出来的图片越来越像动漫的人物。
所以生成器和判别器彼此之间是一直的互动、促进关系,和我们所说的“内卷”一样。最终,生成器会学会画出动漫人物的脸,而判别器也会学会分辨真假图片,这就是 GAN 的训练过程。

在这里插入图片描述

2.3 Generator与 Discriminator 的训练过程

下面,我们从算法角度来解释Generator与 Discriminator是如何进行train的
Generator与 Discriminator 实际上是两个网络
其训练步骤如下:
Step 1 :初始化 Generator 与 Discriminator 的参数

Step 2:固定 Generator,只训练 Discriminator
因为Generator的初始参数是随机初始化的,所以它什么都没有学习到,输入一系列采样得到的向量给它,它的输出肯定都是些随机、混乱的图片(跟马赛克差不多),与真实的动漫头像完全不同。
同时,我们**会有一个很多动漫人物头像的图像Database(**可以通过爬虫等方法得到)。
我们会从这个Database中采样一些动漫人物头像图片出来,来与Generator产生出来的结果对比从而训练 Discriminator。

Discriminator的训练目标是要分辨真正的动漫人物与 Generator 产生出来的动漫人物间的差异。
假如把真正的图片都标 1, Generator 产生出来的图片都标 0。
对于 Discriminator 来说,这就是一个分类或回归的问题。

1、如果是分类的问题,我们就把真正的人脸当作类别 1, Generator生成出来的图片当作类别 2
然后训练一个分类器,将其分类出来。

2、如果当作回归的问题, Discriminator 看到真实图片就要输出 1,Generator生成的图片就要输出 0
然后进行 0-1 之间的打分。
在这里插入图片描述

Step 3:固定 Discriminator ,只训练 Generator
训练 Generator 的目的就是让生成器想办法去骗过 Discriminator因为在 Step 2 中 Discriminator 已经学会分辨真图和假图间的差异。
Generator 如果可以骗过 Discriminator ,那 Generator 产生出来的图片可能就可以以假乱真了。
具体的操作如下:
首先 Generator 输入一个向量(来源于之前讲到的Normal distribution(正态分布)中采样数据),并产生一个图片。
接着我们将这个图片输入到 Discriminator 中, Discriminator 会给这个图片一个打分。
此时 Discriminator 是固定的,它只需要给更“真”的图片更高的分数即可, Generator 训练的目标就是让图片更加真实,即骗取高分数。
在这里插入图片描述
真实场景中 Generator 和 Discriminator 都是有很多层的神经网络
我们通常将将其结合在一起,当作一个大的网络来看待
但是不会调整 Discriminator 部分的模型参数

因为假设要输出的分数越大越好,那完全可以直接调整最后的输出层,改变一下偏差值设为很大的值,那输出的得分就会很高,但是完全达不到我们想要的效果。
所以我们只能训练生成的部分,训练方法与前几章介绍的网络训练方法基本一致,只是我们希望优化目标越大越好,这个与之前我们希望损失越小越好不同。
修改的方式如下:
可以直接在优化目标前加“负号”,就当作损失看待也可以,这样就变为了让损失变小。
另一种方法,我们可以使用梯度上升进行优化,而取代之前的梯度下降优化算法。

总结一下,GAN 算法的三个步骤。
步骤一:初始化参数
步骤二,固定 Generator 训练 Discriminator
步骤三,固定 Discriminator 训练 Generator

接下来就是重复以上的训练,训练完 Discriminator 固定判别器训练 Generator 。然后训练完 Generator 以后再用 Generator 去产生更多的新图片再给 Discriminator 做训练。训练完 Discriminator 后再训练 Generator
反覆地去执行,当其中一个进行训练的时候,另外一个就固定住,期待它们都可以在自己的目标处达到最优。
在这里插入图片描述

2.4 GAN 的应用案例

GAN 生成动画人物人脸
如下图
这些分别是训练 100 轮、1000 轮、2000 轮、5000 轮、10000 轮、20000 轮和 50000 轮的结果。
可以看到训练 100 轮时,生成的图片还是马赛克;
训练到 1000 轮的时候,产生了眼睛;
训练到 2000 轮的时候,产生了嘴巴;
训练到 5000 轮的时候,已经开始有一点人脸的轮廓了,并且有动漫人物大眼睛的特点;
训练到 10000 轮以后,还有些模糊
训练 20000 轮后生成的图片完全跟真的很接近
训练 50000 轮后生成的图片已经跟真的没什么区别了
在这里插入图片描述
除此之外,GAN也可以产生真实的人脸。
如图所示:
产生高清人脸的技术,叫做渐进式 GAN(progressive GAN),上下两排都是由机器产生的人脸。
同样,我们可以用 GAN 产生我们从没有看过的人脸
在这里插入图片描述
先前介绍的 GAN 中的Generator,就是输入一个向量,输出一张图片。
但是,我们还可以把输入的向量做内差,在输出部分我们就会看到两张图片之间连续的变化。

如图所示:
在这里插入图片描述
比如我们输入一个向量通过 GAN 产生一个看起来非常严肃的男人,同时输入另一个向量通过 GAN 产生一个微笑着的女人。
那我们输入这两个向量中间的数值向量,就可以看到这个男人逐渐地笑了起来。
在这里插入图片描述
另一个例子,输入一个向量产生一个往左看的人,同时输入一个向量产生一个往右看的人,我们在之间做内差,机器并不会傻傻地将两张图片叠在一起,而是生成一张正面的脸。
神奇的是,我们在训练的时候其实并没有真的输入正面的人脸,但机器可以自己学到把这两张左右脸做内差,应该会得到一个往正面看的人脸。
在这里插入图片描述
不过如果我们不加约束,GAN 会产生一些很奇怪的图片
如图所示
比如我们使用 BigGAN 算法,会产生一个左右不对称的玻璃杯子,甚至产生一个网球狗。
在这里插入图片描述

二、Pytorch学习

1. 完整的模型验证套路

当我们训练好模型之后,需要对模型进行验证来判断它的好坏
验证模型的核心:利用已经训练好的模型,给它提供输入。
我们需要预检验的图片是一只狗。在这里插入图片描述
接下来,直接上代码演示。

import torch
import torchvision.transforms
from PIL import Image
from torch import nn
from torchvision import transforms
# 图片路径
image_path = "./images/dog.jpg"
# 用PIL打开图片,因为transform是基于PIL的基础上处理的
image = Image.open(image_path)
print(image)# 有些图片格式可能不是3通道,用这个不管是几通道都可以转换成3通道,保险起见还是加上
image = image.convert('RGB')
# 把图片转化为 32*32大小,然后再toTensor
transform = transforms.Compose([torchvision.transforms.Resize((32, 32)),torchvision.transforms.ToTensor()])
image = transform(image)
# 查看图片的大小是否正确
print(image.shape)# 网络模型,因为是自己编写的,所以需要在这里编写出来,否则load会报错
class MCifar(torch.nn.Module):def __init__(self):super(MCifar, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2, 2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2, 2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2, 2),nn.Flatten(),nn.Linear(1024, 64),nn.Linear(64, 10),)def forward(self, x):x = self.model(x)return x# 导入训练好的模型参数。
# 注意:如果是在GPU上训练,在加载模型时,要切换为cpu
model = torch.load("cifar10_99.pth", map_location=torch.device('cpu'))
print(model)
# batch_size = 1(单张图片);通道数为3(RGB);高度宽度为32*32
image = torch.reshape(image, (1, 3, 32, 32))
# 表面为验证模块
model.eval()
# 梯度始终为0,减少不必要的运算消耗资源
with torch.no_grad():output = model(image)
# 输出预测结果下标
print(output.argmax(1))

结果如下:
预测为dog,结果正确。
在这里插入图片描述

2. Github开源项目——pytorch-CycleGAN-and-pix2pix

链接如下:
pytorch-CycleGAN-and-pix2pix
在GitHub上看开源项目的时候,我们要看 README ,看看这个项目具体是干什么的。
从图片上,我们很明显可以看出来GAN就是负责在原来图片基础上进行“创造”
例如:
给出的是一匹普通的马,经过CycleGAN-and-pix2pix后就变成了一批斑马
在这里插入图片描述
往下面滑动,README中会叫我们如何安装项目
例如:克隆存储库、下载 CycleGAN 数据集(例如地图)、测试模型等等
所以在github找项目首先看README是很重要的
在这里插入图片描述
然后我们点击进入项目中的 train.py 查看具体情况。
在这里插入图片描述
以下是train.py的具体内容
其实就和我们训练的套路差不多
先获取数据集
然后导入模型
然后设置epoch(训练轮数),开始训练,然后训练多少次打印结果、训练多少次保存模型,到最后训练多少轮保存我们最新的模型

import time
from options.train_options import TrainOptions
from data import create_dataset
from models import create_model
from util.visualizer import Visualizerif __name__ == '__main__':opt = TrainOptions().parse()   # 获取训练选项dataset = create_dataset(opt)  # 根据 opt.dataset_mode 和其他选项创建数据集dataset_size = len(dataset)    # 获取数据集中的图像数量。print('训练图像的数量 = %d' % dataset_size)model = create_model(opt)      # 根据 opt.model 和其他选项创建模型model.setup(opt)               # 常规设置:加载并打印网络;创建调度器visualizer = Visualizer(opt)   # 创建一个可视化工具,用于显示/保存图像和图表total_iters = 0                # 训练迭代的总次数for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1):    # 每一轮用于不同的周期;我们通过 <epoch_count>, <epoch_count>+<save_latest_freq> 保存模型epoch_start_time = time.time()  # 训练一轮的计时器iter_data_time = time.time()    # 每一轮迭代数据加载的计时器epoch_iter = 0                  # 当前轮数的训练次数,每一轮重置为0visualizer.reset()              # 重置可视化工具:确保它至少每一轮保存结果到HTML一次model.update_learning_rate()    # 在每一轮的开始更新学习率。for i, data in enumerate(dataset):  # 开始一轮中的每一次的内部循环iter_start_time = time.time()  # 每次计算的计时器if total_iters % opt.print_freq == 0:t_data = iter_start_time - iter_data_timetotal_iters += opt.batch_sizeepoch_iter += opt.batch_sizemodel.set_input(data)         # 从数据集解包数据并应用预处理model.optimize_parameters()   # 计算损失函数,获取梯度,更新网络权重if total_iters % opt.display_freq == 0:   # 每训练<opt.display_freq>次在 visdom 上显示图像并将图像保存到 HTML 文件save_result = total_iters % opt.update_html_freq == 0model.compute_visuals()visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)if total_iters % opt.print_freq == 0:    # 每训练<opt.print_freq>次打印训练损失并将日志信息保存到磁盘losses = model.get_current_losses()t_comp = (time.time() - iter_start_time) / opt.batch_sizevisualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)if opt.display_id > 0:visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)if total_iters % opt.save_latest_freq == 0:   # 每训练<save_latest_freq> 次迭代保存我们的最新模型print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'model.save_networks(save_suffix)iter_data_time = time.time()if epoch % opt.save_epoch_freq == 0:              # 每训练<save_epoch_freq> 轮保存我们的模型print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))model.save_networks('latest')model.save_networks(epoch)print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time))

引入项目的过程中调参是我们最麻烦的一个问题,下面我们就来了解一下如何调参
train.py 有一个TrainOptions()
我们点击进入TrainOptions()就可以查看参数的设置
我们可以多观察后面的注释来了解具体的参数是干什么的
在这里插入图片描述
然后在README中,我们可以看到官网给了我么如何训练和测试的方法
就是通过控制台上输入代码测试,但是往往在这个过程的参数调整也是我们的一大困扰,下面就让我们来解读一下如何修改参数
在这里插入图片描述
首先我们要找到dataroot这个单词
就在base_options.py (是TrainOptions的父类)
注意事项如图所示:
在这里插入图片描述
从上图中我们可以看到,
1、有些属性只有 require 没有 default 的参数;
(这类属性,说明我们必须要赋予其一个值)

2、有些属性只有default 没有 require 参数;
(这类属性,如果我们不赋值,它就使用默认值,即default里面的内容)

一般我们运行项目的时候,先看看这个文件中有没有 require = True 的,有的话就把他们替换为default就直接运行即可。(一般数据集正确或者路径无错误,都可以正常运行)
在这里插入图片描述

总结

本周的需要处理的事情很多,所进度比较缓慢,希望下一周加快进度。
本周在机器学习中,我学习了Pointer Network(指针网络),指针网络是一种序列到序列的学习模型,它通过指针机制来预测序列中元素的位置。这种网络特别适合于处理具有明显顺序性的任务,如文本摘要、机器翻译等。此外还学习了生成式对抗网络(Generative Adversarial Networks | GAN),GAN由两部分组成:生成器(Generator)和判别器(Discriminator)。生成器的目标是生成尽可能真实的数据,而判别器则尝试区分真实数据和生成器生成的假数据。 Generator(生成器)是一个神经网络,它接收随机分布 z 作为输入,并生成新的随机分布。这些样本旨在模仿训练数据的分布。Discriminator(判别器),也是一个神经网络,它的任务是判断输入的数据是真实的还是由生成器生成的。它通过输出一个概率值来表示其判断。在训练GAN时,生成器和判别器会交替进行训练,生成器试图生成越来越难以被判别器识别的数据,而判别器则不断学习以更好地区分真假数据。这个过程可以类比为两名对手之间的对抗游戏。最后还学习了GAN的应用案例, GAN在图像生成、风格迁移、数据增强等领域有着广泛的应用,例如,它可以用于生成逼真的人脸图像,或者将一张图片的风格应用到另一张图片上。
在PyTorch学习中,我学习了完整的模型验证套路,模型验证是评估模型性能的重要步骤。我让模型在未见过的数据(一张狗的图片)进行测试。最后还学习了如何阅读Github开源项目——CycleGAN-and-pix2pix,这是基于PyTorch的深度学习模型,用于图像到图像的转换任务。CycleGAN能够实现不同图像域之间的转换,而不需要成对的训练数据。pix2pix则是一种条件GAN,它需要成对的输入和输出图像来训练模型。
因为最近临近期末需要复习,所以学习速度会有所减缓,数学拓展模块展示搁置,后面考完试再重启。下一周计划继续学习GAN,然后可以手撕一些深度学习中的基础模型,例如,感知机,深度理解一下激活函数的作用以及运作原理。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/882930.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ssm企业库存管理微信小程序-计算机毕业设计源码82704

摘 要 本文基于SSM框架&#xff0c;设计与实现了一个企业库存管理微信小程序。该小程序主要包括用户登录、库存查询、入库操作、出库操作等功能模块。在设计过程中&#xff0c;采用了前后端分离的架构&#xff0c;前端使用了微信小程序原生开发工具进行开发&#xff0c;后端使用…

【C++篇】探索STL之美:熟悉使用String类

CSDN 文章目录 前言 &#x1f4ac; 欢迎讨论&#xff1a;如果你在学习过程中有任何问题或想法&#xff0c;欢迎在评论区留言&#xff0c;我们一起交流学习。你的支持是我继续创作的动力&#xff01; &#x1f44d; 点赞、收藏与分享&#xff1a;觉得这篇文章对你有帮助吗&…

ApacheShiro反序列化 550 721漏洞

Apache Shiro是一个强大且易用的Java安全框架,执行身份验证、授权、密码和会话管理个漏洞被称为 Shiro550 是因为在Apache Shiro的GitHub问题跟踪器中&#xff0c;该漏洞最初被标记为第550个问题,721漏洞名称也是由此而来 Shiro-550 CVE-2016-4437 Shiro反序列化Docker复现 …

Android GPU Inspector分析帧数据快速入门

使用 谷歌官方工具Android GPU Inspector (AGI) 可以对Android 应用进行深入和全面的系统性能分析和帧性能分析 。AGI 是一个非常强大的分析工具&#xff0c;尤其是在需要诊断 GPU 性能问题和优化应用时&#xff0c;可以帮助你精准找到性能瓶颈。本文介绍如何使用该工具对帧数据…

HTTP Proxy环境下部署Microsoft Entra Connect和Health Agents

在企业环境中&#xff0c;时常需要通过使用HTTP Proxy访问Internet&#xff0c;在使用HTTP Proxy访问Internet的环境中部署Microsoft Entra Connect和Microsoft Entra Connect Health Agents可能会遇到一些额外的配置步骤&#xff0c;以便这些服务能够正常连接到Internet。 一…

Windows系统PyCharm右键运行.sh文件

在参考了Windows系统下pycharm运行.sh文件&#xff0c;执行shell命令_shell在pycharm-CSDN博客 和深度学习&#xff1a;PyCharm中运行Bash脚本_pycharm bash-CSDN博客 配置了右键执行.sh文件之后&#xff0c;发现在Windows的PyCharm中直接右键运行sh文件&#xff0c;存在如下…

【MyBatis】MyBatis-config标签详解

目录 MyBatis配置文件标签详解configuration标签properties标签typeAliases标签environments标签environment标签transactionManager标签dataSource标签mappers标签 MyBatis配置文件标签详解 我们在使用MyBatis框架的时候需要一个配置文件——MyBatis-config.xml来告诉MyBatis…

Android按钮Button

Button是程序用于和用户进行交互的一个重要控件。Button也是继承自TextView&#xff0c;既可以显示文本&#xff0c;又可以显示图片&#xff0c;二者在UI上的区别主要是 Button 控件有个按钮外观&#xff0c;提示用户单击。 图1 Button示意图 Button最主要的功能是通过单击来执…

K折交叉验证代码实现——详细注释版

正常方法 #---------------------------------Torch Modules -------------------------------------------------------- from __future__ import print_function import numpy as np import pandas as pd import torch.nn as nn import math import torch.nn.functional as …

基于潜空间搜索的策略自适应组合优化(NeurIPS2023)(未完)

文章目录 Abstract1 Introduction2 Related work3 Methods3.1 预备知识3.2 COMPASS4 Experiments4.1 TSP、CVRP和JSSP的标准基准测试4.2 对泛化的鲁棒性:解决变异实例4.3 搜索策略分析5 ConclusionAbstract 组合优化是许多现实应用的基础,但设计高效算法以解决这些复杂的、通…

MongoDB Shell 基本命令(三)生成学生脚本信息和简单查询

一、生成学生信息脚本 利用该脚本可以生成任意个学生信息&#xff0c;包括学号、姓名、班级、年级、专业、课程名称、课程成绩等信息&#xff0c;此处生成2万名学生&#xff0c;学生所有信息都是给定范围后随机生成。 生成学生信息后&#xff0c;再来对学生信息进行简单查询。…

关于武汉芯景科技有限公司的限流开关芯片XJ6241开发指南(兼容LTC4411)

一、芯片引脚介绍 1.芯片引脚 二、系统结构图 三、功能描述 1.CTL引脚控制VIN和VOUT的通断 2.CTL引脚控制STAT引脚的状态 3.输出电压高于输入电压加上–VRTO的值&#xff0c;芯片处于关断状态

Artistic Oil Paint 艺术油画着色器插件

只需轻轻一点&#xff0c;即可将您的视频游戏转化为艺术品&#xff01;&#xff08;也许更多…&#xff09;。 ✓ 整个商店中最可配置的选项。 ✓ 六种先进算法。 ✓ 细节增强算法。 ✓ 完整的源代码&#xff08;脚本和着色器&#xff09;。 ✓ 包含在“艺术包”中。 &#x1f…

【数组知识的扩展①】

&#x1f308;个人主页: Aileen_0v0 &#x1f525;热门专栏: 华为鸿蒙系统学习|计算机网络|数据结构与算法 ​&#x1f4ab;个人格言:“没有罗马,那就自己创造罗马~” ArrayList在Java数组中的使用技巧 这篇博客灵感来源于某一天Aileen(&#x1f92b;)遇到了一道数组合并的题&…

python 文件防感染扫描

一、安装 首先&#xff0c;你需要安装 secplugs-python-client 库。你可以通过 pip 命令来安装&#xff1a; pip install secplugs-python-client确保你的 Python 环境已经正确设置&#xff0c;并且网络连接畅通&#xff0c;以便能够顺利安装。 二、基本用法 1. 初始化客户…

【记录】Windows|Windows 修改字体大全(Windows 桌面、VSCode、浏览器)

【记录】Windows&#xff5c;Windows 修改字体大全&#xff08;Windows 桌面、VSCode、浏览器&#xff09; 前言 最近从学长那里发现了一款非常美观的衡水体字体——Maple Mono SC NF。您可以通过以下链接下载该字体&#xff1a;https://github.com/subframe7536/maple-font/…

TiDB替换Starrocks:业务综合宽表迁移的性能评估与降本增效决策

作者&#xff1a; 我是人间不清醒 原文来源&#xff1a; https://tidb.net/blog/6638f594 1、 场景 业务综合宽表是报表生成、大屏幕展示和数据计算处理的核心数据结构。目前&#xff0c;这些宽表存储在Starrocks系统中&#xff0c;但该系统存在显著的性能瓶颈。例如&#…

Vue组件开发的属性

组件开发的属性&#xff1a; 1.ref属性&#xff1a; 如果在vue里&#xff0c;想要获取DOM对象&#xff0c;并且不想使用JS的原生语法&#xff0c;那么就可以使用ref属性 ref属性的用法&#xff1a; 1&#xff09;在HTML元素的开始标记中&#xff0c;或者在Vue子组件中的开始…

JVM、字节码文件介绍

目录 初识JVM 什么是JVM JVM的三大核心功能 JVM的组成 字节码文件的组成 基础信息 Magic魔数 主副版本号 其它基础信息 常量池 字段 方法 属性 字节码常用工具 javap jclasslib插件 阿里Arthas 初识JVM 什么是JVM JVM的三大核心功能 1. 解释和运行虚拟机指…

我的世界之合成

合成&#xff08;Crafting&#xff09;是一种在Minecraft中获得多种方块、工具和其他资源的方法。合成时&#xff0c;玩家必须先把物品从物品栏移入合成方格中。22的简易合成方格可以直接在物品栏中找到&#xff0c;而33的合成方格需要使用工作台或合成器来打开。 目录 1合成系…