文章目录
- 前言
- 一、说明
- 二、示例
- 1.步骤
- 2.示例代码
- 总结
前言
介绍如何利用PyTorch中Softmax 分类器实现多分类问题。
一、说明
1.多分类问题的输出是一个概率分布:每个类别的概率都不小于0,且所有类别的概率之和为1。
2.Softmax 分类器
3.损失函数:交叉熵损失
torch.nn.CrossEntropyLoss()(内部已包含 LogSoftmax 与 NLLLoss,因此模型最后一层直接输出未经激活的 logits,不需要再手动做 Softmax)
二、示例
1.步骤
1.建立模型
2.定义训练函数
3.定义测试函数
4.主函数:定义训练集和测试集,定义损失函数和优化器,进行训练,存储结果,绘图
2.示例代码
代码如下(示例):
import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import pickle
# prepare dataset
# batch_size = 64
# transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) # 归一化,均值和方差
#
# train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)
# train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
# test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False, download=True, transform=transform)
# test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)


# design model using class
class Net(torch.nn.Module):
    """Fully connected classifier mapping 784 input features to 10 class scores.

    The network is a stack of five linear layers
    (784 -> 512 -> 256 -> 128 -> 64 -> 10) with ReLU after every layer
    except the last one.
    """

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(784, 512)
        self.l2 = torch.nn.Linear(512, 256)
        self.l3 = torch.nn.Linear(256, 128)
        self.l4 = torch.nn.Linear(128, 64)
        self.l5 = torch.nn.Linear(64, 10)

    def forward(self, x):
        """Return the raw class scores (logits) for a batch of inputs.

        Args:
            x: Tensor whose elements can be viewed as rows of 784 values.

        Returns:
            Tensor with one row of 10 unnormalized scores per input row.
        """
        # -1 lets view() infer the mini-batch size from the remaining elements.
        x = x.view(-1, 784)
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        x = F.relu(self.l3(x))
        x = F.relu(self.l4(x))
        # No activation on the last layer: the scores are returned without
        # any non-linear transform.
        return self.l5(x)


# model = Net()
#
# # construct loss and optimizer
# criterion = torch.nn.CrossEntropyLoss()
# optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)


# training cycle: forward, backward, update
def train(epoch):
    """Run one pass over ``train_loader`` and return the mean batch loss.

    Relies on the module-level ``train_loader``, ``model``, ``criterion`` and
    ``optimizer`` created in the ``__main__`` section.

    Args:
        epoch: Zero-based epoch index; only used in the progress message.

    Returns:
        Sum of the per-batch losses divided by the number of batches.
    """
    running_loss = 0.0
    loss_s = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        # One mini-batch of inputs and labels.
        inputs, target = data
        optimizer.zero_grad()

        # Forward pass, loss, backward pass, parameter update.
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        loss_s += batch_loss
        # Report the average loss of every 300 batches, then reset the window.
        if batch_idx % 300 == 299:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
            running_loss = 0.0
    return loss_s / len(train_loader)


def test():
    """Evaluate ``model`` on ``test_loader`` and return the accuracy in percent.

    Relies on the module-level ``test_loader`` and ``model`` created in the
    ``__main__`` section. Prints the accuracy as well.

    Returns:
        ``100 * correct / total`` over all samples of ``test_loader``.
    """
    correct = 0
    total = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for data in test_loader:
            images, labels = data
            outputs = model(images)
            # dim=1 takes the maximum along each row, i.e. picks the index of
            # the highest score for every sample in the batch.
            _, predicted = torch.max(outputs, dim=1)
            total += labels.size(0)
            # Element-wise comparison yields a boolean tensor; sum counts hits.
            correct += (predicted == labels).sum().item()
    print('accuracy on test set: %d %% ' % (100 * correct / total))
    return 100 * correct / total


if __name__ == '__main__':
    import os

    # Create the output directory before training so that the results can
    # actually be written once the (long) training loop has finished.
    os.makedirs('8', exist_ok=True)

    batch_size = 64
    # Convert images to tensors and normalize with the given mean and std.
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

    train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
    test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False, download=True, transform=transform)
    test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)

    model = Net()

    # construct loss and optimizer
    criterion = torch.nn.CrossEntropyLoss(reduction='mean')
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

    epoch_list = []
    loss_list = []
    accuracy_list = []
    for epoch in range(10):
        epoch_list.append(epoch)
        loss_list.append(train(epoch))
        accuracy_list.append(test())

    # Persist the curves so the plotting script can load them later.
    with open('8/epoch_list.pkl', 'wb') as f:
        pickle.dump(epoch_list, f)
    with open('8/loss_list.pkl', 'wb') as f:
        pickle.dump(loss_list, f)
    with open('8/accuracy_list.pkl', 'wb') as f:
        pickle.dump(accuracy_list, f)
画图程序如下:
import pickle

import matplotlib.pyplot as plt

# Load the curves written by the training script.
# NOTE: pickle.load can execute arbitrary code; only load files that were
# produced by the training script itself.
with open('8/epoch_list.pkl', 'rb') as f:
    loaded_epoch_list = pickle.load(f)
with open('8/loss_list.pkl', 'rb') as f:
    loaded_loss_list = pickle.load(f)
with open('8/accuracy_list.pkl', 'rb') as f:
    loaded_acc_list = pickle.load(f)

plt.subplot(2, 1, 1)  # 2 rows x 1 column grid, first subplot: loss curve
plt.plot(loaded_epoch_list, loaded_loss_list)
plt.xlabel('epoch')
plt.ylabel('loss 1')

plt.subplot(2, 1, 2)  # 2 rows x 1 column grid, second subplot: accuracy curve
plt.plot(loaded_epoch_list, loaded_acc_list, 'r')
plt.xlabel('epoch')
plt.ylabel('acc 1')
plt.show()
得到如下结果:
总结
PyTorch学习8:多分类问题