pytorch神经网络入门代码

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms# 定义神经网络结构
class SimpleNN(nn.Module):def __init__(self, input_size, hidden_size, num_classes):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(input_size, hidden_size)self.relu = nn.ReLU()self.fc2 = nn.Linear(hidden_size, num_classes)def forward(self, x):out = self.fc1(x)out = self.relu(out)out = self.fc2(out)return out# 设置超参数
input_size = 784  # MNIST数据集的输入大小是28x28=784
hidden_size = 784
num_classes = 10learning_rate = 0.01
num_epochs = 10# 加载MNIST数据集
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())# 数据加载器
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=100, shuffle=False)# 实例化模型
model = SimpleNN(input_size, hidden_size, num_classes)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)# 训练模型
total_step = len(train_loader)
for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):# 将输入数据转换为一维向量images = images.reshape(-1, 28*28)# 前向传播outputs = model(images)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()if (i+1) % 100 == 0:print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, total_step, loss.item()))# 测试模型
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.reshape(-1, 28*28)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the 10000 test images: {} %'.format(100 * correct / total))# 获取模型参数
params = model.parameters()# 打印每个参数的名称和值
for name, param in model.named_parameters():print(f'Parameter name: {name}')print(f'Parameter value: {param}')

以下代码测试正确率为:99.37%

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms# 定义适合MNIST数据集的CNN模型
class MNISTCNN(nn.Module):def __init__(self):super(MNISTCNN, self).__init__()# 卷积块 1self.conv_block1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2))# 卷积块 2self.conv_block2 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2))# 全连接层self.fc_layer = nn.Sequential(nn.Linear(64 * 7 * 7, 512),  # 假设经过前面的卷积和池化后特征图大小为7x7nn.ReLU(),nn.Dropout(p=0.5),nn.Linear(512, 10)  # MNIST有10个类别)def forward(self, x):x = self.conv_block1(x)x = self.conv_block2(x)# 将卷积层输出展平为一维向量x = x.view(x.size(0), -1)# 通过全连接层x = self.fc_layer(x)return x# 创建模型实例
model = MNISTCNN()# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 加载MNIST数据集并预处理
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)# 使用DataLoader加载批量数据
batch_size = 64
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# 开始训练
num_epochs = 10
for epoch in range(num_epochs):for inputs, labels in train_loader:# 前向传播outputs = model(inputs)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()  # 清空梯度缓存loss.backward()  # 计算梯度optimizer.step()  # 更新参数# 每个epoch结束时打印损失print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')# 测试模型
model.eval()  # 将模型切换到评估模式(禁用Dropout和BatchNorm等)
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Test Accuracy: {100 * correct / total}%')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/685982.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

css篇---移动端适配的方案有哪几种

移动端适配 移动端适配是指同一个页面可以在不同的移动端设备上都有合理的布局。主流实现的方案有 响应式布局通过rem或者vw,vh 等实现不同设备有相同的比例而实现适配 首先需要了解viewport 【视口】 视口代表了一个可看见的多边形区域(通常来说是矩形&#xff0…

【c++】vector的增删查改

1.先定义一个类对象vector 为了防止和库里面发生冲突&#xff0c;定义一个命名空间&#xff0c;将类对象放在命名空间 里面 #include<iostream> using namespace std; namespace zjw {class vector {public:private:}; }2.定义变量&#xff0c;需要一个迭代器&#xff…

Android---DslTabLayout实现底部导航栏

1. 在 Android 项目中引用 JitPack 库 AGP 8. 根目录的 settings.gradle dependencyResolutionManagement {...repositories {...maven { url https://jitpack.io }} } AGP 8. 根目录如果是 settings.gradle.kts 文件 dependencyResolutionManagement {...repositories {...…

openclash无法访问 https://raw.githubusercontent.com

个人尝试&#xff1a; 1&#xff1a;尝试覆写设置“ -> ”规则设置“ -> ”自定义规则“ -> ”规则设置仍无法访问&#xff1b; 2&#xff1a;配置管理“ -> ”配置文件编辑“ -> ”左边区域添加规则&#xff0c;尝试后也无效。 解决方案&#xff1a; 1&#x…

单片机学习笔记---LED呼吸灯直流电机调速

目录 LED呼吸灯 直流电机调速 模型结构 波形 定时器初始化函数 中断函数 主程序 上一节讲了电机的工作原理&#xff0c;这一节开始代码演示&#xff01; 我们上一篇说Ton的时间长Toff时间短电机会快&#xff0c;Ton的时间短Toff时间长电机会慢 并且我们还要保证无论Ton和…

element-UI 组件 dialog 中 ref 获取不到元素

项目场景&#xff1a; vue3集成bpmn.js 渲染过程中&#xff0c;进行流程图查看 问题描述 dialog弹窗加载获取canvas中 加载不到&#xff0c;导致偶尔流程展示加载失败 原因分析&#xff1a; 提示&#xff1a;官方解释如下&#xff0c;主要就是获取的时候&#xff0c;组件没有…

【微服安全】OpenID Connect 简介:现代应用程序的身份验证

OpenID Connect (OIDC) 是一个建立在 OAuth 2.0 之上的开放身份验证协议。它简化了应用程序以一种标准化和可互操作的方式验证用户身份并获取其基本个人资料信息的方式。可以将其视为应用程序“知道你是谁”的一种安全方式&#xff0c;而无需你创建单独的帐户或透露你的密码。 …

【Git】常用命令

命令作用注意git -v查看 git 版本git init初始化 git 仓库git add 文件标识暂存某个文件文件标识以终端为起始的相对路径git add .暂存所有文件git commit -m ‘说明注释’提交产生版本记录每次提交&#xff0c;把暂存区内容快照一份git status查看文件状态 - 详细信息git stat…

qml之Control类型布局讲解,padding属性和Inset属性细讲

1、Control布局图 2、如何理解&#xff1f; *padding和*Inset参数如何理解呢&#xff1f; //main.qml import QtQuick 2.0 import QtQuick.Controls 2.12 import QtQuick.Layouts 1.12 import QtQuick.Controls 1.4 import QtQml 2.12ApplicationWindow {id: windowvisible: …

Leetcode - 周赛384

目录 一&#xff0c;3033. 修改矩阵 二&#xff0c;3035. 回文字符串的最大数量 三&#xff0c;3036. 匹配模式数组的子数组数目 II 一&#xff0c;3033. 修改矩阵 这道题直接暴力求解&#xff0c;先算出每一列的最大值&#xff0c;再将所有为-1的区域替换成该列的最大值&am…

算法刷题:无重复字符的最长字串

无重复字符的最长字串 .题目链接题目详情算法原理题目解析滑动窗口定义指针进窗口判断出窗口更新结果 我的答案 . 题目链接 无重复字符的最长字串 题目详情 算法原理 题目解析 首先,为了使字符串遍历的更加方便,我们选择将字符串转换为数组 题目要求子串中不能有重复的字符…

知识图谱:py2neo将csv文件导入neo4j

文章目录 安装py2neo创建节点-连线关系图导入csv文件删除重复节点并连接边 安装py2neo 安装python中的neo4j操作库&#xff1a;pip install py2neo 安装py2neo后我们可以使用其中的函数对neo4j进行操作。 图数据库Neo4j中最重要的就是结点和边&#xff08;关系&#xff09;&a…

隐函数的求导【高数笔记】

1. 什么是隐函数&#xff1f; 2. 隐函数的做题步骤&#xff1f; 3. 隐函数中的复合函数求解法&#xff0c;与求导中复合函数求解法有什么不同&#xff1f; 4. 隐函数求导的过程中需要注意什么&#xff1f;

【Linux网络编程五】Tcp套接字编程(四个版本服务器编写)

【Linux网络编程五】Tcp套接字编程(四个版本服务器编写&#xff09; [Tcp套接字编程]一.服务器端进程&#xff1a;1.创建套接字2.绑定网络信息3.设置监听状态4.获取新连接5.根据新连接进行通信 二.客户端进程&#xff1a;1.创建套接字2.连接服务器套接字3.连接成功后进行通信 三…

深度学习的基本原理和算法

深度学习的基本原理、算法和神经网络的基本概念是人工智能领域中非常重要的部分&#xff0c;下面我将分别深入解释这些内容。 深度学习的基本原理 深度学习的基本原理在于使用深层神经网络来模拟人脑神经元的连接方式&#xff0c;从而实现对复杂数据的分析和处理。它依赖于大…

爱上JVM——常见问题(一):JVM组成

1 JVM组成 1.1 JVM由那些部分组成&#xff0c;运行流程是什么&#xff1f; 难易程度&#xff1a;☆☆☆ 出现频率&#xff1a;☆☆☆☆ JVM是什么 Java Virtual Machine Java程序的运行环境&#xff08;java二进制字节码的运行环境&#xff09; 好处&#xff1a; 一次编写&…

mysql存储范式简记

范式与反范式&#xff0c;范式追求不同表中的不重复存储&#xff0c;反范式可以重复存储一张表里。每张表一定要有一个主键&#xff0c;自增主键只推荐用在非核心业&#xff0c;核心业务表推荐使用 UUID 或业务自定义主键&#xff0c;可以通过 JSON 数据类型进行反范式设计。 …

第三十三回 镇三山大闹青州道 霹雳火夜走瓦砾场-python分割字符串

黄信和刘知寨押解宋江和花荣向青州走&#xff0c;碰到了燕顺等三人来劫囚车&#xff0c;黄信逃走了&#xff0c;刘知寨被抓住&#xff0c;被花荣一刀杀了。 黄信把情况报给青州知府&#xff0c;派来了青州兵马秦统制&#xff0c;人称霹雳火的秦明。秦明与花荣打&#xff0c;花…

计算机二级之sql语言的学习(数据模型—概念模型)

概念模型 含义: 概念模型用于信息世界&#xff08;作用对象&#xff09;的建模&#xff0c;是实现现实世界到信息世界&#xff08;所以万丈高楼平地起&#xff0c;不断地学习相关的基础知识&#xff0c;保持不断地重复才能掌握最为基础的基础知识&#xff09;的概念抽象&#…

c++STL系列——(十三)总结

目录 引言 正文 结语 引言 经过一段时间的学习和探索&#xff0c;我的C STL系列博客终于迎来了结束的时刻。在这个系列中&#xff0c;我们深入了解了C标准模板库&#xff08;STL&#xff09;的核心组件、常用容器和算法&#xff0c;并探讨了它们的特性和适用场景。通过这一…