Pytorch上手使用
近期学习了另一个深度学习框架库Pytorch,对学习进行一些总结,方便自己回顾。
Pytorch是torch的python版本,是由Facebook开源的神经网络框架。与Tensorflow的静态计算图不同,pytorch的计算图是动态的,可以根据计算需要实时改变计算图。
1 安装
如果已经安装了cuda8,则使用pip来安装pytorch会十分简单。若使用其他版本的cuda,则需要下载官方释放出来对应的安装包。具体安装地址参见官网的首页。
目前最新稳定版本为0.4.0。上个版本0.3.0的文档有中文版,见中文文档。
pip install torch torchvision # for python2.7
pip3 install torch torchvision # for python3
2 概述
理解pytorch的基础主要从以下三个方面
- Numpy风格的Tensor操作。pytorch中tensor提供的API参考了Numpy的设计,因此熟悉Numpy的用户基本上可以无缝理解,并创建和操作tensor,同时torch中的数组和Numpy数组对象可以无缝的对接。
- 变量自动求导。在一序列计算过程形成的计算图中,参与的变量可以方便的计算自己对目标函数的梯度。这样就可以方便的实现神经网络的后向传播过程。
- 神经网络层与损失函数优化等高层封装。网络层的封装存在于torch.nn模块,损失函数由torch.nn.functional模块提供,优化函数由torch.optim模块提供。
因此下面的内容也主要围绕这三个方面来介绍。第3节介绍张量的操作,第4节介绍自动求导,第5节介绍神经网络层等的封装,第6,7节简单介绍损失函数与优化方法。这三部分相对重要。后续的第8节介绍介绍数据集及torchvision,第9节介绍训练过程可视的工具,第10节通过相对完整的示例代码展示pytorch中如何解决MNIST与CIFAR10的分类。
3 Tensor(张量)
Tensor是神经网络框架中重要的基础数据类型,可以简单理解为N维数组的容器对象。tensor之间的通过运算进行连接,从而形成计算图。
3.1 Tensor类型
Torch 定义了七种 CPU tensor 类型和八种 GPU tensor 类型:
Data type | CPU tensor | GPU tensor |
---|---|---|
32-bit floating point | torch.FloatTensor | torch.cuda.FloatTensor |
64-bit floating point | torch.DoubleTensor | torch.cuda.DoubleTensor |
16-bit floating point | torch.HalfTensor | torch.cuda.HalfTensor |
8-bit integer (unsigned) | torch.ByteTensor | torch.cuda.ByteTensor |
8-bit integer (signed) | torch.CharTensor | torch.cuda.CharTensor |
16-bit integer (signed) | torch.ShortTensor | torch.cuda.ShortTensor |
32-bit integer (signed) | torch.IntTensor | torch.cuda.IntTensor |
64-bit integer (signed) | torch.LongTensor | torch.cuda.LongTensor |
通常情况下使用Tensor类的构造函数返回的是FloatTensor类型对象,可通过在对象上调用cuda()返回一个新的cuda.FloatTensor类型的对象。
torch模块内提供了操作tensor的接口,而Tensor类型的对象上也设计了对应了接口。例如torch.add()与tensor.add()等价。需要注意的是这些接口都采用创建一个新对象返回的形式。如果想就地修改一个tensor对象,需要使用加后缀下划线的方法。例如x.add_(y),将修改x。Tensor类的构建函数支持从列表或ndarray等类型进行构建。默认tensor为FloatTensor。
下面的几节简单的描述重要的操作tensor的方法。
3.1 tensor的常见创建接口
方法名 | 说明 |
---|---|
Tensor() | 直接从参数构造一个的张量,参数支持list,numpy数组 |
eye(row, column) | 创建指定行数,列数的二维单位tensor |
linspace(start,end,count) | 在区间[s,e]上创建c个tensor |
logspace(s,e,c) | 在区间[10^s, 10^e]上创建c个tensor |
ones(*size) | 返回指定shape的张量,元素初始为1 |
zeros(*size) | 返回指定shape的张量,元素初始为0 |
ones_like(t) | 返回与t的shape相同的张量,且元素初始为1 |
zeros_like(t) | 返回与t的shape相同的张量,且元素初始为0 |
arange(s,e,sep) | 在区间[s,e)上以间隔sep生成一个序列张量 |
3.2 随机采样
方法名 | 说明 |
---|---|
rand(*size) | 在区间[0,1)返回一个均匀分布的随机数张量 |
uniform(s,e) | 在指定区间[s,e]上生成一个均匀分布的张量 |
randn(*size) | 返回正态分布N(0,1)取样的随机数张量 |
normal(means, std) | 返回一个正态分布N(means, std) |
3.3 序列化
方法名 | 说明 |
---|---|
save(obj, path) | 张量对象的保存,通过pickle进行 |
load(path) | 从文件中反序列化一个张量对象 |
3.4 数学操作
这些方法均为逐元素处理方法
方法名 | 说明 |
---|---|
abs | 绝对值 |
add | 加法 |
addcdiv(t, v, t1, t2) | t1与t2的按元素除后,乘v加t |
addcmul(t, v, t1, t2) | t1与t2的按元素乘后,乘v加t |
ceil | 向上取整 |
floor | 向下取整 |
clamp(t, min, max) | 将张量元素限制在指定区间 |
exp | 指数 |
log | 对数 |
pow | 幂 |
mul | 逐元素乘法 |
neg | 取反 |
sigmoid | |
sign | 取符号 |
sqrt | 开根号 |
tanh |
注:这些操作均创建新的tensor,如果需要就地操作,可以使用这些方法的下划线版本,例如abs_。
3.5 归约方法
方法名 | 说明 |
---|---|
cumprod(t, axis) | 在指定维度对t进行累积 |
cumsum | 在指定维度对t进行累加 |
dist(a,b,p=2) | 返回a,b之间的p阶范数 |
mean | 均值 |
median | 中位数 |
std | 标准差 |
var | 方差 |
norm(t,p=2) | 返回t的p阶范数 |
prod(t) | 返回t所有元素的积 |
sum(t) | 返回t所有元素的和 |
3.6 比较方法
方法名 | 说明 |
---|---|
eq | 比较tensor是否相等,支持broadcast |
equal | 比较tensor是否有相同的shape与值 |
ge/le | 大于/小于比较 |
gt/lt | 大于等于/小于等于比较 |
max/min(t,axis) | 返回最值,若指定axis,则额外返回下标 |
topk(t,k,axis) | 在指定的axis维上取最高的K个值 |
3.7 其他操作
方法名 | 说明 |
---|---|
cat(iterable, axis) | 在指定的维度上拼接序列 |
chunk(tensor, c, axis) | 在指定的维度上分割tensor |
squeeze(input,dim) | 将张量维度为1的dim进行压缩,不指定dim则压缩所有维度为1的维 |
unsqueeze(dim) | squeeze操作的逆操作 |
transpose(t) | 计算矩阵的转置换 |
cross(a, b, axis) | 在指定维度上计算向量积 |
diag | 返回对角线元素 |
hist(t, bins) | 计算直方图 |
trace | 返回迹 |
3.8 矩阵操作
方法名 | 说明 |
---|---|
dot(t1, t2) | 计算张量的内积 |
mm(t1, t2) | 计算矩阵乘法 |
mv(t1, v1) | 计算矩阵与向量乘法 |
qr(t) | 计算t的QR分解 |
svd(t) | 计算t的SVD分解 |
3.9 tensor对象的方法
方法名 | 作用 |
---|---|
size() | 返回张量的shape属性值 |
numel(input) | 计算tensor的元素个数 |
view(*shape) | 修改tensor的shape,与np.reshape类似,view返回的对象共享内存 |
resize | 类似于view,但在size超出时会重新分配内存空间 |
item | 若为单元素tensor,则返回pyton的scalar |
from_numpy | 从numpy数据填充 |
numpy | 返回ndarray类型 |
3.10 tensor内部
tensor对象由两部分组成,tensor的信息与存储,storage封装了真正的data,可以由多个tensor共享。大多数操作只是修改tensor的信息,而不修改storage部分。这样达到效率与性能的提升。
3.11 使用pytorch进行线性回归
import torch
import torch.optim as optim
import matplotlib.pyplot as pltdef get_fake_data(batch_size=32):''' y=x*2+3 '''x = torch.randn(batch_size, 1) * 20y = x * 2 + 3 + torch.randn(batch_size, 1)return x, yx, y = get_fake_data()class LinerRegress(torch.nn.Module):def __init__(self):super(LinerRegress, self).__init__()self.fc1 = torch.nn.Linear(1, 1)def forward(self, x):return self.fc1(x)net = LinerRegress()
loss_func = torch.nn.MSELoss()
optimzer = optim.SGD(net.parameters())for i in range(40000):optimzer.zero_grad()out = net(x)loss = loss_func(out, y)loss.backward()optimzer.step()w, b = [param.item() for param in net.parameters()]
print w, b # 2.01146, 3.184525# 显示原始点与拟合直线
plt.scatter(x.squeeze().numpy(), y.squeeze().numpy())
plt.plot(x.squeeze().numpy(), (x*w + b).squeeze().numpy())
plt.show()
从这里的代码可以发现,pytorch需要我们自己实现各轮更新,并手动调用反向传播以及更新参数,此外也没有提供评估及预测功能。相对于Keras这种高层的封装,pytorch需要我们了解更多的低层细节。
4 自动求导
tensor对象通过一系列的运算可以组成动态图,对于每个tensor对象,有下面几个变量控制求导的属性。
变量 | 作用 |
---|---|
requirs_grad | 默认为False,表示变量是否需要计算导数 |
grad_fn | 变量的梯度函数 |
grad | 变量对应的梯度 |
在0.3.0版本中,自动求导还需要借助于Variable类来完成,在0.4.0版本中,Variable已经被废除了,tensor自身即可完成这一过程。
import torchx = torch.randn((4,4), requires_grad=True)
y = 2*x
z = y.sum()print z.requires_grad # Truez.backward()print x.grad
'''
tensor([[ 2., 2., 2., 2.],[ 2., 2., 2., 2.],[ 2., 2., 2., 2.],[ 2., 2., 2., 2.]])
'''
5 创建神经网络
5.1 神经网络层
torch.nn模块提供了创建神经网络的基础构件,这些层都继承自Module类。下面我们简单看下如何实现Liner层。
class Liner(torch.nn.Module):def __init__(self,in_features, out_features, bias=True):super(Liner, self).__init__()self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))if bias:self.bias = torch.nn.Parameter(torch.randn(out_features))def forward(self, x):x = x.mm(self.weight)if self.bias:x = x + self.bias.expand_as(x)return x
下面表格中列出了比较重要的神经网络层组件。对应的在nn.functional模块中,提供这些层对应的函数实现。通常对于可训练参数的层使用module,而对于不需要训练参数的层如softmax这些,可以使用functional中的函数。
Layer对应的类 | 功能说明 |
---|---|
Linear(in_dim, out_dim, bias=True) | 提供了进行线性变换操作的功能 |
Dropout(p) | Dropout层,有2D,3D的类型 |
Conv2d(in_c, out_c, filter_size, stride, padding) | 二维卷积层,类似的有Conv1d,Conv3d |
ConvTranspose2d() | |
MaxPool2d(filter_size, stride, padding) | 二维最大池化层 |
MaxUnpool2d(filter, stride, padding) | 逆过程 |
AvgPool2d(filter_size, stride, padding) | 二维平均池化层 |
FractionalMaxPool2d | 分数最大池化 |
AdaptiveMaxPool2d([h,w]) | 自适应最大池化 |
AdaptiveAvgPool2d([h,w]) | 自自应平均池化 |
ZeroPad2d(padding_size) | 零填充边界 |
ConstantPad2d(padding_size,const) | 常量填充边界 |
ReplicationPad2d(ps) | 复制填充边界 |
BatchNorm1d() | 对2维或3维小批量数据进行标准化操作 |
RNN(in_dim, hidden_dim, num_layers, activation, dropout, bidi, bias) | 构建RNN层 |
RNNCell(in_dim, hidden_dim, bias, activation) | RNN单元 |
LSTM(in_dim, hidden_dim, num_layers, activation, dropout, bidi, bias) | 构建LSTM层 |
LSTMCell(in_dim, hidden_dim, bias, activation) | LSTM单元 |
GRU(in_dim, hidden_dim, num_layers, activation, dropout, bidi, bias) | 构建GRU层 |
GRUCell(in_dim, hidden_dim, bias, activation) | GRU单元 |
5.2 非线性激活层
激活层类名 | 作用 |
---|---|
ReLU(inplace=False) | Relu激活层 |
Sigmoid | Sigmoid激活层 |
Tanh | Tanh激活层 |
Softmax | Softmax激活层 |
Softmax2d | |
LogSoftmax | LogSoftmax激活层 |
5.3 容器类型
容器类型 | 功能 |
---|---|
Module | 神经网络模块的基类 |
Sequential | 序列模型,类似于keras,用于构建序列型神经网络 |
ModuleList | 用于存储层,不接受输入 |
Parameters(t) | 模块的属性,用于保存其训练参数 |
ParameterList | 参数列表 |
下面的代码演示了使用容器型模块的方式。
# 方法1
model1 = nn.Sequential()
model.add_module('fc1', nn.Linear(3,4))
model.add_module('fc2', nn.Linear(4,2))
model.add_module('output', nn.Softmax(2))# 方法2
model2 = nn.Sequential(nn.Conv2d(1,20,5),nn.ReLU(),nn.Conv2d(20,64,5),nn.ReLU())
# 方法3
model3 = nn.ModuleList([nn.Linear(3,4), nn.ReLU(), nn.Linear(4,2)])
5.4 其他层
容器类型 | 功能 |
---|---|
Embedding(vocab_size, feature_dim) | 词向量层 |
Embeddingbag |
5.5 模型的保存
前面我们知道tensor可以通过save与load方法实现序列化与反序列化。由tensor组成的网络同样也可以方便的保存。不过通常没有必要完全保存网络模块对象,只需要保存各层的权重数据即可,这些数据保存在模块的state_dict字典中,因此只需要序列化这个词典。
# 模型的保存
torch.save(model.state_dict, 'path')
# 模型的加载
model.load_state_dict('path)
5.6 实现LeNet神经网络
torch.nn.Module提供了神经网络的基类,当实现神经网络时需要继承自此模块,并在初始化函数中创建网络需要包含的层,并实现forward函数完成前向计算,网络的反向计算会由自动求导机制处理。
下面的示例代码创建了LeNet的卷积神经网络。通常将需要训练的层写在init函数中,将参数不需要训练的层在forward方法里调用对应的函数来实现相应的层。
import torch.nn as nn
import torch.nn.functional as Fclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.max_pool2d(F.relu(self.conv1(x)), 2)x = F.max_pool2d(F.relu(self.conv2(x)), 2)x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x
6 损失函数与优化方法
6.1 损失函数
torch.nn模块中提供了许多损失函数类,这里列出几种相对常见的。
类名 | 功能 |
---|---|
MSELoss | 均方差损失 |
CrossEntropyLoss | 交叉熵损失 |
NLLLoss | 负对数似然损失 |
PoissonNLLLoss | 带泊松分布的负对数似然损失 |
6.2 优化方法
由torch.optim模块提供支持
类名 | 功能 |
---|---|
SGD(params, lr=0.1, momentum=0, dampening=0, weight_decay=0, nesterov=False) | 随机梯度下降法 |
Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False) | Adam |
RMSprop(params, lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False) | RMSprop |
Adadelta(params, lr=1.0, rho=0.9, eps=1e-06, weight_decay=0) | Adadelta |
Adagrad(params, lr=0.01, lr_decay=0, weight_decay=0, initial_accumulator_value=0) | Adagrad |
lr_scheduler.ReduceLROnPlateau(optimizer, mode=’min’, factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode=’rel’, cooldown=0, min_lr=0, eps=1e-08) | 学习率的控制 |
在神经网络的性能调优中,常见的作法是对不对层的网络设置不同的学习率。
class model(nn.Module):def __init__():super(model,self).__init__()self.base = Sequencial()# code for base sub moduleself.classifier = Sequencial()# code for classifier sub moduleoptim.SGD([{'params': model.base.parameters()},{'params': model.classifier.parameters(), 'lr': 1e-3}], lr=1e-2, momentum=0.9)
6.3 参数初始化
良好的初始化可以让模型快速收敛,有时甚至可以决定模型是否能训练成功。Pytorch中的参数通常有默认的初始化策略,不需要我们自己指定,但框架仍然留有相应的接口供我们来调整初始化方法。
初始化方法 | 说明 |
---|---|
xavier_uniform_ | |
xavier_normal_ | |
kaiming_uniform_ |
from torch.nn import init# net的类定义
...# 初始化各层权重
for name, params in net.named_parameters():init.xavier_normal(param[0])init.xavier_normal(param[1])
7 数据集与数据加载器
7.1 DataSet与DataLoader
torch.util.data模块提供了DataSet类用于描述一个数据集。定义自己的数据集需要继承自DataSet类,且实现__getitem__()与__len__()方法。__getitem__方法返回指定索引处的tensor与其对应的label。
为了支持数据的批量及随机化操作,可以使用data模块下的DataLoader类型来返回一个加载器:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0)
7.2 torchvision简介
torchvision是配合pytorch的独立计算机视觉数据集的工具库,下面介绍其中常用的数据集类型。
torchvision.datasets.ImageFolder(dir, transform, label_map,loader)
提供了从一个目录初始化出来一个图片数据集的便捷方法。
要求目录下的图片分类存放,每一类的图片存储在以类名为目录名的目录下,方法会将每个类名映射到唯一的数字上,如果你对数字有要求,可以用label_map来定义目录名到数字的映射。
torchvision.datasets.DatasetFolder(dir,transform, label_map, loader, extensions)
提供了从一个目录初始化一般数据集的便捷方法。目录下的数据分类存放,每类数据存储在class_xxx命名的目录下。
此外torchvision.datasets下实现了常用的数据集,如CIFAR-10/100, ImageNet, COCO, MNIST, LSUN等。
除了数据集,torchvision的model模块提供了常见的模型实现,如Alex-Net, VGG,Inception, Resnet等。
7.3 torchvision提供的图像变换工具
torchvision的transforms模块提供了对PIL.Image对象和Tensor对象的常见操作。如果需要连续应用多个变换,可以使用Compose对象组装多个变换。
转换操作 | 说明 |
---|---|
Scale | PIL图片进行缩放 |
CenterCrop | PIL图片从中心位置剪切 |
Pad | PIL图片填充 |
ToTensor | PIL图片转换为Tensor且归一化到[0,1] |
Normalize | Tensor标准化 |
ToPILImage | 将Tensor转为PIL表示 |
import torchvision.tranforms as Trans
tranform = Trans.Compose([T.Scale(28*28), T.ToTensor(), T.Normalize([0.5],[0.5])])
8 训练过程可视化
8.1 使用Tensorboard
通过使用第三方库tensorboard_logger,将训练过程中的数据保存为日志,然后便可以通过Tensorboard来查看这些数据了。其功能支持相对有限,这里不做过多介绍。
8.2 使用visdom
visdom是facebook开源的一个可视工具,可以用来完成pytorch训练过程的可视化。
安装可以使用pip install visdom
启动类似于tb,在命令行上执行:python -m visdom.server
服务启动后可以使用浏览器打开http://127.0.0.1:8097/即可看到主面板。
visdom的绘图API类似于plot,通过API将绘图数据发送到基于tornado的web服务器上并显示在浏览器中。更详细内容参见visdom的github主页
9 GPU及并行支持
为了能在GPU上运行,Tensor与Module都需要转换到cuda模式下。
import torch
import torchvisiont = torch.Tensor(3,4)
print t.is_cuda #False
t = t.cuda(0)
print t.is_cuda #Truenet = torchvision.model.AlexNet()
net.cuda(0)
如果有多块显卡,可以通过cuda(device_id)来将tensor分到不同的GPU上以达到负载的均衡。
另一种比较省事的做法是调用torch.set_default_tensor_type使程序默认使用某种cuda的tensor。或者使用torch.cuda.set_device(id)指定使用某个GPU。
10 示例:Pytorch实现CIFAR10与MNIST分类
关于cifar10与mnist数据集不再进行解释了。这里的Model类实现的二者的共同的任务,借鉴了keras的接口方式,Model类提供了train与evaluat方法,并没有实现序列模型的添加方法以及predict方法。此外设定损失函数与优化函数时,也只是简单的全部实例化出来,根据参数选择其中的一个,这里完全可以根据参数动态创建。
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transformsclass Model:def __init__(self, net, cost, optimist):self.net = netself.cost = self.create_cost(cost)self.optimizer = self.create_optimizer(optimist)passdef create_cost(self, cost):support_cost = {'CROSS_ENTROPY': nn.CrossEntropyLoss(),'MSE': nn.MSELoss()}return support_cost[cost]def create_optimizer(self, optimist, **rests):support_optim = {'SGD': optim.SGD(self.net.parameters(), lr=0.1, **rests),'ADAM': optim.Adam(self.net.parameters(), lr=0.01, **rests),'RMSP':optim.RMSprop(self.net.parameters(), lr=0.001, **rest)}return support_optim[optimist]def train(self, train_loader, epoches=3):for epoch in range(epoches):running_loss = 0.0for i, data in enumerate(train_loader, 0):inputs, labels = dataself.optimizer.zero_grad()# forward + backward + optimizeoutputs = self.net(inputs)loss = self.cost(outputs, labels)loss.backward()self.optimizer.step()running_loss += loss.item()if i % 100 == 0:print('[epoch %d, %.2f%%] loss: %.3f' %(epoch + 1, (i + 1)*1./len(train_loader), running_loss / 100))running_loss = 0.0print('Finished Training')def evaluate(self, test_loader):print('Evaluating ...')correct = 0total = 0with torch.no_grad(): # no grad when test and predictfor data in test_loader:images, labels = dataoutputs = self.net(images)predicted = torch.argmax(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')def cifar_load_data():transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)testloader = torch.utils.data.DataLoader(testset, batch_size=4,shuffle=False, num_workers=2)return trainloader, testloaderclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.max_pool2d(F.relu(self.conv1(x)), 2)x = F.max_pool2d(F.relu(self.conv2(x)), 2)x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xdef mnist_load_data():transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0,], [1,])])trainset = torchvision.datasets.MNIST(root='./data', train=True,download=True, transform=transform)trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,shuffle=True, num_workers=2)testset = torchvision.datasets.MNIST(root='./data', train=False,download=True, transform=transform)testloader = torch.utils.data.DataLoader(testset, batch_size=32,shuffle=True, num_workers=2)return trainloader, testloaderclass MnistNet(torch.nn.Module):def __init__(self):super(MnistNet, self).__init__()self.fc1 = torch.nn.Linear(28*28, 512)self.fc2 = torch.nn.Linear(512, 512)self.fc3 = torch.nn.Linear(512, 10)def forward(self, x):x = x.view(-1, 28*28)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = F.softmax(self.fc3(x), dim=1)return xif __name__ == '__main__':# train for mnistnet = MnistNet()model = Model(net, 'CROSS_ENTROPY', 'RMSP')train_loader, test_loader = mnist_load_data()model.train(train_loader)model.evaluate(test_loader)# train for cifarnet = LeNet()model = Model(net, 'CROSS_ENTROPY', 'RMSP')train_loader, test_loader = cifar_load_data()model.train(train_loader)model.evaluate(test_loader)
---------------------
作者:zzulp
来源:CSDN
原文:https://blog.csdn.net/zzulp/article/details/80573331
版权声明:本文为作者原创文章,转载请附上博文链接!
内容解析By:CSDN,CNBLOG博客文章一键转载插件