深度学习笔记_4、CNN卷积神经网络+全连接神经网络解决MNIST数据

1、首先,导入所需的库和模块,包括NumPy、PyTorch、MNIST数据集、数据处理工具、模型层、优化器、损失函数、混淆矩阵、绘图工具以及数据处理工具。

import numpy as np
import torch
from torchvision.datasets import mnist
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import csv
import pandas as pd

2、设置超参数,包括训练批次大小、测试批次大小、学习率和训练周期数。

# 设置超参数
train_batch_size = 64
test_batch_size = 64
learning_rate = 0.001
num_epochs = 10

3、创建数据转换管道,将图像数据转换为张量并进行标准化。

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5])
])

4、下载和预处理MNIST数据集,分为训练集和测试集。

# 下载和预处理数据集
train_dataset = mnist.MNIST('data', train=True, transform=transform, download=True)
test_dataset = mnist.MNIST('data', train=False, transform=transform)

5、创建用于训练和测试的数据加载器,以便有效地加载数据。

# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

 6、定义了一个简单的CNN模型,包括两个卷积层和两个全连接层。

# 定义CNN模型
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=5)self.conv2 = nn.Conv2d(32, 64, kernel_size=5)self.fc1 = nn.Linear(1024, 256)self.fc2 = nn.Linear(256, 10)def forward(self, x):x = F.relu(F.max_pool2d(self.conv1(x), 2))x = F.relu(F.max_pool2d(self.conv2(x), 2))x = x.view(x.size(0), -1)x = F.relu(self.fc1(x))x = self.fc2(x)return F.log_softmax(x, dim=1)

7、初始化模型、优化器和损失函数。

# 初始化模型、优化器和损失函数
model = CNN()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

8、准备用于记录训练和测试过程中损失和准确率的列表。

# 记录训练和测试过程中的损失和准确率
train_losses = []
test_losses = []
train_accuracies = []
test_accuracies = []

9、进入训练循环,遍历每个训练周期。在每个训练周期内,进入训练模式,遍历训练数据批次,计算损失、反向传播并更新模型参数,同时记录训练损失和准确率。

for epoch in range(num_epochs):model.train()train_loss = 0.0correct = 0total = 0for batch_idx, (data, target) in enumerate(train_loader):optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()train_loss += loss.item()# 计算训练准确率_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()# 计算平均训练损失和训练准确率train_loss /= len(train_loader)train_accuracy = 100. * correct / totaltrain_losses.append(train_loss)train_accuracies.append(train_accuracy)  # 记录训练准确率# 测试模型model.eval()test_loss = 0.0correct = 0all_labels = []all_preds = []with torch.no_grad():for data, target in test_loader:output = model(data)test_loss += criterion(output, target).item()pred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()all_labels.extend(target.numpy())all_preds.extend(pred.numpy())

10、在每个训练周期结束后,进入测试模式,遍历测试数据批次,计算测试损失和准确率,同时记录它们。打印每个周期的训练和测试损失以及准确率。

# 计算平均测试损失和测试准确率test_loss /= len(test_loader)test_accuracy = 100. * correct / len(test_loader.dataset)test_losses.append(test_loss)test_accuracies.append(test_accuracy)print(f'Epoch [{epoch + 1}/{num_epochs}] -> Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

11、losses、acces、eval_losses、eval_acces保存到TXT文件

# 保存训练结果
data = np.column_stack((train_losses,test_losses,train_accuracies, test_accuracies))
np.savetxt("results.txt", data)

12、绘制Loss、ACC图像

# 绘制Loss曲线图
plt.figure(figsize=(10, 2))
plt.plot(train_losses, label='Train Loss', color='blue')
plt.plot(test_losses, label='Test Loss', color='red')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss Curve')
plt.grid(True)
plt.savefig('loss_curve.png')
plt.show()# 绘制Accuracy曲线图
plt.figure(figsize=(10, 2))
plt.plot(train_accuracies, label='Train Accuracy', color='red')  # 绘制训练准确率曲线
plt.plot(test_accuracies, label='Test Accuracy', color='green')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Accuracy Curve')
plt.grid(True)
plt.savefig('accuracy_curve.png')
plt.show()

 

 13、绘制混淆矩阵图像

# 计算混淆矩阵
confusion_mat = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(10, 8))
sns.heatmap(confusion_mat, annot=True, fmt='d', cmap='Blues', cbar=False)
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix')
plt.savefig('confusion_matrix.png')
plt.show()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/95013.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

用向量数据库Milvus Cloud 搭建AI聊天机器人

加入大语言模型(LLM) 接着,需要在聊天机器人中加入 LLM。这样,用户就可以和聊天机器人开展对话了。本示例中,我们将使用 OpenAI ChatGPT 背后的模型服务:GPT-3.5。 聊天记录 为了使 LLM 回答更准确,我们需要存储用户和机器人的聊天记录,并在查询时调用这些记录,可以用…

redis的持久化消息队列

Redis Stream Redis Stream 是 Redis 5.0 版本新增加的数据结构。 Redis Stream 主要用于消息队列(MQ,Message Queue),Redis 本身是有一个 Redis 发布订阅 (pub/sub) 来实现消息队列的功能,但它有个缺点就是消息无法…

一键AI高清换脸——基于InsightFace、CodeFormer实现高清换脸与验证换脸后效果能否通过人脸比对、人脸识别算法

前言 1、项目简介 AI换脸是指利用基于深度学习和计算机视觉来替换或合成图像或视频中的人脸。可以将一个人的脸替换为另一个人的脸,或者将一个人的表情合成到另一个人的照片或视频中。算法常常被用在娱乐目上,例如在社交媒体上创建有趣的照片或视频,也有用于电影制作、特效…

全屋灯具选购指南,如何选择合适的灯具。福州中宅装饰,福州装修

灯具装修指南 灯具就像我们家里的星星,在黑暗中带给我们明亮,可是灯具如果选择的不好,这个效果不仅体现不出来,还会让人觉得烦躁。 灯具到底该怎么选呢?装修灯具有哪些注意事项呢?给大家做了一个总结&#…

C++设计模式-抽象工厂(Abstract Factory)

目录 C设计模式-抽象工厂(Abstract Factory) 一、意图 二、适用性 三、结构 四、参与者 五、代码 C设计模式-抽象工厂(Abstract Factory) 一、意图 提供一个创建一系列相关或相互依赖对象的接口,而无需指定它们…

笔试编程ACM模式JS(V8)、JS(Node)框架、输入输出初始化处理、常用方法、技巧

目录 考试注意事项 先审完题意,再动手 在本地编辑器(有提示) 简单题515min 通过率0%,有额外log 常见输入处理 str-> num arr:line.split( ).map(val>Number(val)) 初始化数组 new Array(length).fill(v…

国庆中秋特辑(七)Java软件工程师常见20道编程面试题

以下是中高级Java软件工程师常见编程面试题,共有20道。 如何判断一个数组是否为有序数组? 答案:可以通过一次遍历,比较相邻元素的大小。如果发现相邻元素的大小顺序不对,则数组不是有序数组。 public boolean isSort…

Windows下Tensorflow docker python开发环境搭建

前置条件 windows10 更新到较新的版本,硬件支持Hyper-V。 参考:https://learn.microsoft.com/zh-cn/windows/wsl/install 启用WSL 在Powershell中输入如下指令: dism.exe /online /enable-feature /featurename:Microsoft-Windows-Subsys…

01-工具篇-windows与linux文件共享

一般来说绝大部分PC上装的系统均是windows,为了开发linux程序,会在PC上安装一个Vmware的虚拟机,在虚拟机上安装ubuntu18.04,由于windows上的代码查看软件、浏览器,通信软件更全,我们想只用ubuntu进行编译&a…

哨兵(Sentinel-1、2)数据下载

哨兵(Sentinel-1、2)数据下载 一、登陆欧空局网站 二、检索 先下载2号为光学数据 分为S2A和S2B,产品种类有1C和2A,区别就是2A是做好大气校正的影像,当然数量也会少一些,云量检索条件中记得要按格式&#x…

【Linux基础】Linux发展史

👉系列专栏:【Linux基础】 🙈个人主页:sunny-ll 一、前言 本篇主要介绍Linux的发展历史,这里并不需要我们掌握,但是作为一个合格的Linux学习者与操作者,这些东西是需要了解的,而且…

深入理解浏览器渲染原理

文章目录 浏览器是如何渲染页面的渲染流程解析HTML(构建DOM树)解析过程中遇到JS代码 样式计算1. 解析CSS代码2. 转换样式表中的属性值,使其标准化3. 计算DOM树中每个节点的具体样式CSS继承规则CSS层叠规则 布局分层分层update layer tree 绘制…

WebGL 响应上下文丢失解决方案

目录 响应上下文丢失 如何响应上下文丢失 上下文事件 示例程序(RotatingTriangle_contextLost.js) 响应上下文丢失 WebGL使用了计算机的图形硬件,而这部分资源是被操作系统管理,由包括浏览器在内的多个应用程序共享。在某些特…

逐步解决Could not find artifact com:ojdbc8:jar:12

Could not find artifact com:ojdbc8:jar:12 in central (https://repo.maven.apache.org/maven2) 原因: ojdbc8:jar:12 属于Oracle 数据库链接的一个程序集,缺失的话很有可能会影响数据库链接,蝴蝶效应产生不可预测的BUG!但是版…

OpenGLES:绘制一个混色旋转的3D立方体

效果展示 混色旋转的3D立方体 一.概述 之前关于OpenGLES实战开发的博文,不论是实现相机滤镜还是绘制图形,都是在2D纬度 这篇博文开始,将会使用OpenGLES进入3D世界 本篇博文会实现一个颜色渐变、旋转的3D立方体 动态3D图形的绘制&#xf…

mybatise-plus的id过长问题

一、问题情景 笔者在做mp插入数据库(id已设置为自增)操作时,发现新增数据的id过长,结果导致前端JS拿到的数据出现了精度丢失问题,原因是后端id的类型是Long。在网上查了一下,只要在该属性上加上如下注解就可以 TableId(value &q…

进程调度的时机,切换与过程以及方式

1.进程调度的时机 进程调度(低级调度〉,就是按照某种算法从就绪队列中选择一个进程为其分配处理机。 1.需要进行进程调度与切换的情况 1.当前运行的进程主动放弃处理机 进程正常终止运行过程中发生异常而终止进程主动请求阻塞(如等待l/O)…

大模型部署手记(1)ChatGLM2+Windows GPU

1.简介: 组织机构:智谱/清华 代码仓:https://github.com/THUDM/ChatGLM2-6B 模型:THUDM/chatglm2-6b 下载:https://huggingface.co/THUDM/chatglm2-6b 镜像下载:https://aliendao.cn/models/THUDM/chat…

很普通的四非生,保研破局经验贴

推免之路 个人情况简介夏令营深圳大学情况机试面试结果 预推免湖南师范大学面试结果 安徽大学面试结果 北京科技大学笔试面试结果 合肥工业大学南京航空航天大学面试结果 暨南大学东北大学 最终结果一些建议写在后面 个人情况简介 教育水平:某中医药院校的医学信息…

Discuz!X 3.4任意文件删除漏洞

复现过程: 1.访问http://x.x.x/robots.txt(文件存在) 2.登录弱口令 账号:admin密码:admin 3.来到个人设置页面找到自己的formhash: 4.点击保存,抓包 来到这个参数:birthprovin…