深度学习笔记_4、CNN卷积神经网络+全连接神经网络解决MNIST数据

1、首先,导入所需的库和模块,包括NumPy、PyTorch、MNIST数据集、数据处理工具、模型层、优化器、损失函数、混淆矩阵、绘图工具以及数据处理工具。

import numpy as np
import torch
from torchvision.datasets import mnist
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import csv
import pandas as pd

2、设置超参数,包括训练批次大小、测试批次大小、学习率和训练周期数。

# 设置超参数
train_batch_size = 64
test_batch_size = 64
learning_rate = 0.001
num_epochs = 10

3、创建数据转换管道,将图像数据转换为张量并进行标准化。

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5])
])

4、下载和预处理MNIST数据集,分为训练集和测试集。

# 下载和预处理数据集
train_dataset = mnist.MNIST('data', train=True, transform=transform, download=True)
test_dataset = mnist.MNIST('data', train=False, transform=transform)

5、创建用于训练和测试的数据加载器,以便有效地加载数据。

# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

 6、定义了一个简单的CNN模型,包括两个卷积层和两个全连接层。

# 定义CNN模型
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=5)self.conv2 = nn.Conv2d(32, 64, kernel_size=5)self.fc1 = nn.Linear(1024, 256)self.fc2 = nn.Linear(256, 10)def forward(self, x):x = F.relu(F.max_pool2d(self.conv1(x), 2))x = F.relu(F.max_pool2d(self.conv2(x), 2))x = x.view(x.size(0), -1)x = F.relu(self.fc1(x))x = self.fc2(x)return F.log_softmax(x, dim=1)

7、初始化模型、优化器和损失函数。

# 初始化模型、优化器和损失函数
model = CNN()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

8、准备用于记录训练和测试过程中损失和准确率的列表。

# 记录训练和测试过程中的损失和准确率
train_losses = []
test_losses = []
train_accuracies = []
test_accuracies = []

9、进入训练循环,遍历每个训练周期。在每个训练周期内,进入训练模式,遍历训练数据批次,计算损失、反向传播并更新模型参数,同时记录训练损失和准确率。

for epoch in range(num_epochs):model.train()train_loss = 0.0correct = 0total = 0for batch_idx, (data, target) in enumerate(train_loader):optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()train_loss += loss.item()# 计算训练准确率_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()# 计算平均训练损失和训练准确率train_loss /= len(train_loader)train_accuracy = 100. * correct / totaltrain_losses.append(train_loss)train_accuracies.append(train_accuracy)  # 记录训练准确率# 测试模型model.eval()test_loss = 0.0correct = 0all_labels = []all_preds = []with torch.no_grad():for data, target in test_loader:output = model(data)test_loss += criterion(output, target).item()pred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()all_labels.extend(target.numpy())all_preds.extend(pred.numpy())

10、在每个训练周期结束后,进入测试模式,遍历测试数据批次,计算测试损失和准确率,同时记录它们。打印每个周期的训练和测试损失以及准确率。

# 计算平均测试损失和测试准确率test_loss /= len(test_loader)test_accuracy = 100. * correct / len(test_loader.dataset)test_losses.append(test_loss)test_accuracies.append(test_accuracy)print(f'Epoch [{epoch + 1}/{num_epochs}] -> Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

11、losses、acces、eval_losses、eval_acces保存到TXT文件

# 保存训练结果
data = np.column_stack((train_losses,test_losses,train_accuracies, test_accuracies))
np.savetxt("results.txt", data)

12、绘制Loss、ACC图像

# 绘制Loss曲线图
plt.figure(figsize=(10, 2))
plt.plot(train_losses, label='Train Loss', color='blue')
plt.plot(test_losses, label='Test Loss', color='red')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss Curve')
plt.grid(True)
plt.savefig('loss_curve.png')
plt.show()# 绘制Accuracy曲线图
plt.figure(figsize=(10, 2))
plt.plot(train_accuracies, label='Train Accuracy', color='red')  # 绘制训练准确率曲线
plt.plot(test_accuracies, label='Test Accuracy', color='green')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Accuracy Curve')
plt.grid(True)
plt.savefig('accuracy_curve.png')
plt.show()

 

 13、绘制混淆矩阵图像

# 计算混淆矩阵
confusion_mat = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(10, 8))
sns.heatmap(confusion_mat, annot=True, fmt='d', cmap='Blues', cbar=False)
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix')
plt.savefig('confusion_matrix.png')
plt.show()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/95013.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

用向量数据库Milvus Cloud 搭建AI聊天机器人

加入大语言模型(LLM) 接着,需要在聊天机器人中加入 LLM。这样,用户就可以和聊天机器人开展对话了。本示例中,我们将使用 OpenAI ChatGPT 背后的模型服务:GPT-3.5。 聊天记录 为了使 LLM 回答更准确,我们需要存储用户和机器人的聊天记录,并在查询时调用这些记录,可以用…

Java练习 day4

一、存在重复元素 II 1、题目链接 点击跳转到题目位置 2、代码 class Solution {public boolean containsNearbyDuplicate(int[] nums, int k) {Map<Integer, Integer> mp new HashMap<Integer, Integer>();int n nums.length;for(int i 0; i < n; i){if…

C#WPF框架MvvMLight应用实例

本文实例演示C#WPF框架MvvMLight应用实例。 目录 一、MVVM概述 二、MVVMLight概述 三、使用MVMLight框架 一、MVVM概述 MVVM概述MVVM是Model-View-ViewModel的简写,主要目的是为了解耦视图(View)和模型(Model)。

redis的持久化消息队列

Redis Stream Redis Stream 是 Redis 5.0 版本新增加的数据结构。 Redis Stream 主要用于消息队列&#xff08;MQ&#xff0c;Message Queue&#xff09;&#xff0c;Redis 本身是有一个 Redis 发布订阅 (pub/sub) 来实现消息队列的功能&#xff0c;但它有个缺点就是消息无法…

一键AI高清换脸——基于InsightFace、CodeFormer实现高清换脸与验证换脸后效果能否通过人脸比对、人脸识别算法

前言 1、项目简介 AI换脸是指利用基于深度学习和计算机视觉来替换或合成图像或视频中的人脸。可以将一个人的脸替换为另一个人的脸,或者将一个人的表情合成到另一个人的照片或视频中。算法常常被用在娱乐目上,例如在社交媒体上创建有趣的照片或视频,也有用于电影制作、特效…

Python基础之生成器

&#xff08;一&#xff09;什么是生成器 在python中&#xff0c;想要了解什么是生成器(generator)&#xff0c;首先就需要了解什么是yield关键字。yield表达式只能函数内部或者lambda函数中使用&#xff0c;使用了yield表达式的函数即为生成器函数&#xff0c;而生成器函数返…

全屋灯具选购指南,如何选择合适的灯具。福州中宅装饰,福州装修

灯具装修指南 灯具就像我们家里的星星&#xff0c;在黑暗中带给我们明亮&#xff0c;可是灯具如果选择的不好&#xff0c;这个效果不仅体现不出来&#xff0c;还会让人觉得烦躁。 灯具到底该怎么选呢&#xff1f;装修灯具有哪些注意事项呢&#xff1f;给大家做了一个总结&#…

链式法则(Chain Rule)

定义 链式法则&#xff08;Chain Rule&#xff09;是概率论和统计学中的一个基本原理&#xff0c;用于计算联合概率分布或条件概率分布的乘积。它可以用于分解一个复杂的概率分布为多个较简单的条件概率分布的乘积&#xff0c;从而简化概率分析问题。 链式法则有两种常见的形…

Map声明、元素访问及遍历、⼯⼚模式、实现 Set - GO语言从入门到实战

Map声明、元素访问及遍历 - GO语言从入门到实战 Map 声明的方式 m := map[string]int{"one": 1, "two": 2, "three": 3} //m初始化时就已经设置了3个键值对,所以它的初始长度len(m)是3。m1 := map[string]int{} //m1被初始化为一个空的m…

C++设计模式-抽象工厂(Abstract Factory)

目录 C设计模式-抽象工厂&#xff08;Abstract Factory&#xff09; 一、意图 二、适用性 三、结构 四、参与者 五、代码 C设计模式-抽象工厂&#xff08;Abstract Factory&#xff09; 一、意图 提供一个创建一系列相关或相互依赖对象的接口&#xff0c;而无需指定它们…

笔试编程ACM模式JS(V8)、JS(Node)框架、输入输出初始化处理、常用方法、技巧

目录 考试注意事项 先审完题意&#xff0c;再动手 在本地编辑器&#xff08;有提示&#xff09; 简单题515min 通过率0%&#xff0c;有额外log 常见输入处理 str-> num arr&#xff1a;line.split( ).map(val>Number(val)) 初始化数组 new Array(length).fill(v…

国庆中秋特辑(七)Java软件工程师常见20道编程面试题

以下是中高级Java软件工程师常见编程面试题&#xff0c;共有20道。 如何判断一个数组是否为有序数组&#xff1f; 答案&#xff1a;可以通过一次遍历&#xff0c;比较相邻元素的大小。如果发现相邻元素的大小顺序不对&#xff0c;则数组不是有序数组。 public boolean isSort…

Windows下Tensorflow docker python开发环境搭建

前置条件 windows10 更新到较新的版本&#xff0c;硬件支持Hyper-V。 参考&#xff1a;https://learn.microsoft.com/zh-cn/windows/wsl/install 启用WSL 在Powershell中输入如下指令&#xff1a; dism.exe /online /enable-feature /featurename:Microsoft-Windows-Subsys…

01-工具篇-windows与linux文件共享

一般来说绝大部分PC上装的系统均是windows&#xff0c;为了开发linux程序&#xff0c;会在PC上安装一个Vmware的虚拟机&#xff0c;在虚拟机上安装ubuntu18.04&#xff0c;由于windows上的代码查看软件、浏览器&#xff0c;通信软件更全&#xff0c;我们想只用ubuntu进行编译&a…

哨兵(Sentinel-1、2)数据下载

哨兵&#xff08;Sentinel-1、2&#xff09;数据下载 一、登陆欧空局网站 二、检索 先下载2号为光学数据 分为S2A和S2B&#xff0c;产品种类有1C和2A&#xff0c;区别就是2A是做好大气校正的影像&#xff0c;当然数量也会少一些&#xff0c;云量检索条件中记得要按格式&#x…

LeetCode 251:展开二维向量

题目 Implement an iterator to flatten a 2d vector. Example: [1,2,3,4,5,6] [1,2,3,4,5,6] Follow up: As an added challenge, try to code it using only iterators in C++ or iterators in Java. 题解: 用两个index 分别记录list 的 index 和当前 list的element index. …

【Linux基础】Linux发展史

&#x1f449;系列专栏&#xff1a;【Linux基础】 &#x1f648;个人主页&#xff1a;sunny-ll 一、前言 本篇主要介绍Linux的发展历史&#xff0c;这里并不需要我们掌握&#xff0c;但是作为一个合格的Linux学习者与操作者&#xff0c;这些东西是需要了解的&#xff0c;而且…

深入理解浏览器渲染原理

文章目录 浏览器是如何渲染页面的渲染流程解析HTML&#xff08;构建DOM树&#xff09;解析过程中遇到JS代码 样式计算1. 解析CSS代码2. 转换样式表中的属性值&#xff0c;使其标准化3. 计算DOM树中每个节点的具体样式CSS继承规则CSS层叠规则 布局分层分层update layer tree 绘制…

WebGL 响应上下文丢失解决方案

目录 响应上下文丢失 如何响应上下文丢失 上下文事件 示例程序&#xff08;RotatingTriangle_contextLost.js&#xff09; 响应上下文丢失 WebGL使用了计算机的图形硬件&#xff0c;而这部分资源是被操作系统管理&#xff0c;由包括浏览器在内的多个应用程序共享。在某些特…

【Java-LangChain:使用 ChatGPT API 搭建系统-5】处理输入-思维链推理

第五章&#xff0c;处理输入-思维链推理 在本章中&#xff0c;我们将专注于处理输入&#xff0c;即通过一系列步骤生成有用地输出。 有时&#xff0c;模型在回答特定问题之前需要进行详细地推理。如果您参加过我们之前的课程&#xff0c;您将看到许多这样的例子。有时&#xf…