基于pytorch的手写数字识别-训练+使用

import pandas as pd
import numpy as np
import torch
import matplotlib
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoadermatplotlib.use('tkAgg')# 设置图形配置
config = {"font.family": 'serif',"mathtext.fontset": 'stix',"font.serif": ['SimSun'],'axes.unicode_minus': False
}
matplotlib.rcParams.update(config)def mymap(labels):return np.where(labels < 10, labels, 0)# 数据加载
path = "d:\\JD\\Documents\\大学等等等\\自学部分\\机器学习自学画图\\手写数字识别\\ex3data1.xlsx"
data = pd.read_excel(path)
data = np.array(data, dtype=np.float32)
x = data[:, :-1]
labels = data[:, -1]
labels = mymap(labels)# 转换为Tensor
x = torch.tensor(x, dtype=torch.float32)
labels = torch.tensor(labels, dtype=torch.long)# 创建Dataset和Dataloader
dataset = TensorDataset(x, labels)
train_loader = DataLoader(dataset, batch_size=20, shuffle=True)# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 定义模型
my_nn = torch.nn.Sequential(torch.nn.Linear(400, 128),torch.nn.Sigmoid(),torch.nn.Linear(128, 256),torch.nn.Sigmoid(),torch.nn.Linear(256, 512),torch.nn.Sigmoid(),torch.nn.Linear(512, 10)
).to(device)# 加载预训练模型
my_nn.load_state_dict(torch.load('model.pth'))
my_nn.eval()  # 切换至评估模式# 准备选取数据进行预测
sample_indices = np.random.choice(len(dataset), 50, replace=False)  # 随机选择50个样本
sample_images = x[sample_indices].to(device)  # 选择样本并移动到GPU
sample_labels = labels[sample_indices].numpy()  # 真实标签# 进行预测
with torch.no_grad():  # 禁用梯度计算predictions = my_nn(sample_images)predicted_labels = torch.argmax(predictions, dim=1).cpu().numpy()  # 获取预测的标签# 绘制图像
plt.figure(figsize=(10, 10))
for i in range(50):plt.subplot(10, 5, i + 1)  # 10行5列的子图plt.imshow(sample_images[i].cpu().reshape(20, 20), cmap='gray')  # 还原为20x20图像plt.title(f'Predicted: {predicted_labels[i]}', fontsize=8)plt.axis('off')  # 关闭坐标轴plt.tight_layout()  # 调整子图间距
plt.show()

Iteration 0, Loss: 0.8472495079040527
Iteration 20, Loss: 0.014742681756615639
Iteration 40, Loss: 0.00011596851982176304
Iteration 60, Loss: 9.278443030780181e-05
Iteration 80, Loss: 1.3701709576707799e-05
Iteration 100, Loss: 5.019319928578625e-07
Iteration 120, Loss: 0.0
Iteration 140, Loss: 0.0
Iteration 160, Loss: 1.2548344585638915e-08
Iteration 180, Loss: 1.700657230685465e-05
预测准确率: 100.00%

下面使用已经训练好的模型,进行再次测试:

import pandas as pd
import numpy as np
import torch
import matplotlib
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoadermatplotlib.use('tkAgg')# 设置图形配置
config = {"font.family": 'serif',"mathtext.fontset": 'stix',"font.serif": ['SimSun'],'axes.unicode_minus': False
}
matplotlib.rcParams.update(config)def mymap(labels):return np.where(labels < 10, labels, 0)# 数据加载
path = "d:\\JD\\Documents\\大学等等等\\自学部分\\机器学习自学画图\\手写数字识别\\ex3data1.xlsx"
data = pd.read_excel(path)
data = np.array(data, dtype=np.float32)
x = data[:, :-1]
labels = data[:, -1]
labels = mymap(labels)# 转换为Tensor
x = torch.tensor(x, dtype=torch.float32)
labels = torch.tensor(labels, dtype=torch.long)# 创建Dataset和Dataloader
dataset = TensorDataset(x, labels)
train_loader = DataLoader(dataset, batch_size=20, shuffle=True)# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 定义模型
my_nn = torch.nn.Sequential(torch.nn.Linear(400, 128),torch.nn.Sigmoid(),torch.nn.Linear(128, 256),torch.nn.Sigmoid(),torch.nn.Linear(256, 512),torch.nn.Sigmoid(),torch.nn.Linear(512, 10)
).to(device)# 加载预训练模型
my_nn.load_state_dict(torch.load('model.pth'))
my_nn.eval()  # 切换至评估模式# 准备选取数据进行预测
sample_indices = np.random.choice(len(dataset), 50, replace=False)  # 随机选择50个样本
sample_images = x[sample_indices].to(device)  # 选择样本并移动到GPU
sample_labels = labels[sample_indices].numpy()  # 真实标签# 进行预测
with torch.no_grad():  # 禁用梯度计算predictions = my_nn(sample_images)predicted_labels = torch.argmax(predictions, dim=1).cpu().numpy()  # 获取预测的标签plt.figure(figsize=(16, 10))
for i in range(20):plt.subplot(4, 5, i + 1)  # 4行5列的子图plt.imshow(sample_images[i].cpu().reshape(20, 20), cmap='gray')  # 还原为20x20图像plt.title(f'True: {sample_labels[i]}, Pred: {predicted_labels[i]}', fontsize=12)  # 标题中显示真实值和预测值plt.axis('off')  # 关闭坐标轴plt.tight_layout()  # 调整子图间距
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/56020.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

洗衣店订单管理:Spring Boot技术革新

3系统分析 3.1可行性分析 通过对本洗衣店订单管理系统实行的目的初步调查和分析&#xff0c;提出可行性方案并对其一一进行论证。我们在这里主要从技术可行性、经济可行性、操作可行性等方面进行分析。 3.1.1技术可行性 本洗衣店订单管理系统采用JAVA作为开发语言&#xff0c;S…

pytest(六)——allure-pytest的基础使用

前言 一、allure-pytest的基础使用 二、需要掌握的allure特性 2.1 Allure报告结构 2.2 Environment 2.3 Categories 2.4 Flaky test 三、allure的特性&#xff0c;allure.step()、allure.attach的详细使用 3.1 allure.step 3.2 allure.attach&#xff08;挺有用的&a…

如何利用wsl-Ubuntu里conda用来给Windows的PyCharm开发

前提&#xff1a;咱们在wsl-Ubuntu上&#xff0c;有conda的虚拟环境 咱们直接打开PyCharm,打开Settings 更换Python Interpreter即可 当然一开始可能没有下面的选项&#xff0c;需要我们点击右边的Add Interpreter 这里选择wsl 点击next 将这两步进行修改 可以看出来&#xff0…

kubernetes中微服务部署

微服务 问&#xff1a;用控制器来完成集群的工作负载&#xff0c;那么应用如何暴漏出去&#xff1f; 答&#xff1a;需要通过微服务暴漏出去后才能被访问 Service 是一组提供相同服务的Pod对外开放的接口借助Service&#xff0c;应用可以实现服务发现和负载均衡Service 默认只…

智谱开放平台API调用解析

一、什么是智谱AI 智谱AI成立于2019年&#xff0c;由‌清华大学计算机系知识工程实验室的技术成果转化而来&#xff0c;是一家致力于人工智能技术研发和应用的公司。智谱致力于打造新一代认知智能大模型&#xff0c;专注于做大模型的中国创新。 二、智谱开放平台API调用 官方文…

【LeetCode】动态规划—673. 最长递增子序列的个数(附完整Python/C++代码)

动态规划—673. 最长递增子序列的个数 前言题目描述基本思路1. 问题定义2. 理解问题和递推关系3. 解决方法3.1 动态规划方法3.2 优化方法 4. 进一步优化5. 小总结 代码实现PythonPython3代码实现Python 代码解释 CC代码实现C 代码解释1. 初始化&#xff1a;2. 动态规划过程&…

FiBiNET模型实现推荐算法

1. 项目简介 A031-FiBiNET模型项目是一个基于深度学习的推荐系统算法实现&#xff0c;旨在提升推荐系统的性能和精度。该项目的背景源于当今互联网平台中&#xff0c;推荐算法在电商、社交、内容分发等领域的广泛应用。推荐系统通过分析用户的历史行为和兴趣偏好&#xff0c;预…

Django学习笔记十三:优秀案例学习

Django CMS 是一个基于 Django 框架的开源内容管理系统&#xff0c;它允许开发者轻松地创建和管理网站内容。Django CMS 提供了一个易于使用的界面来实现动态网站的快速开发&#xff0c;并且具有丰富的内容管理功能和多种插件扩展。以下是 Django CMS 的一些核心特性和如何开始…

opencv的相机标定与姿态解算

首先我们要知道四个重要的坐标系 世界坐标系相机坐标系图像成像坐标系图像像素坐标系 坐标系之间的转换 世界坐标系——相机坐标系 从世界坐标系到相机坐标系&#xff0c;涉及到旋转和平移&#xff08;其实所有的运动也可以用旋转矩阵和平移向量来描述&#xff09;。绕着不…

最新Prompt预设词指令教程大全ChatGPT、AI智能体(300+预设词应用)

使用指南 直接复制在AI工具助手中使用&#xff08;提问前&#xff09; 可以前往已经添加好Prompt预设的AI系统测试使用&#xff08;可自定义添加使用&#xff09; SparkAi系统现已支持自定义添加官方GPTs&#xff08;对专业领域更加专业&#xff0c;支持多模态文档&#xff0…

同三维T80001EHK 4K超高清HDMI编码器

【系列介绍】 同三维T80001EHK 4K超高清HDMI编码器 4K超高清编码器&#xff08;采集盒&#xff09;是专业的高清音视频编码产品&#xff0c;只需要占用较小的带宽&#xff0c;即可获得高清晰度的视频信号。该产品采用H.265编码格式&#xff0c;可同时对视频音频进行编码。输出…

【万字长文】Word2Vec计算详解(二)Skip-gram模型

【万字长文】Word2Vec计算详解&#xff08;二&#xff09;Skip-gram模型 写在前面 本篇介绍Word2Vec中的第二个模型Skip-gram模型 【万字长文】Word2Vec计算详解&#xff08;一&#xff09;CBOW模型 markdown行 9000 【万字长文】Word2Vec计算详解&#xff08;二&#xff09;S…

<Project-8.1 pdf2tx-MM> Python Flask 用浏览器翻译PDF内容 2个翻译引擎 繁简中文结果 从P8更改

更新 Project Name&#xff1a;pdf2tx (P6) Date: 5oct.24 Function: 在浏览器中翻译PDF文件 Code:https://blog.csdn.net/davenian/article/details/142723144 升级 Project Name: pdf2tx-mm (P8) 7oct.24 加入多线程&#xff0c;分页OCR识别&#xff0c;提高性能与速度 使…

5G NR UE初始接入信令流程

文章目录 5G NR UE初始接入信令流程 5G NR UE初始接入信令流程 用户设备向gNB-DU发送RRCSetupRequest消息。gNB-DU 包含 RRC 消息&#xff0c;如果 UE 被接纳&#xff0c;则在 INITIAL UL RRC MESSAGE TRANSFER 消息中包括为 UE 分配的低层配置&#xff0c;并将其传输到 gNB-CU…

【OpenCV】基础操作学习--实现原理理解

读取和显示图像 基本操作 cv2.imread(filename , flags)&#xff1a;文件中读取图像&#xff0c;从指定路径中读取图像&#xff0c;返回一个图像数组&#xff08;NumPy数组&#xff09; filename&#xff1a;图像文件的路径flags&#xff1a;指定读取图像的方式 cv2.IMREAD_COL…

linux线程 | 线程的概念

前言:本篇讲述linux里面线程的相关概念。 线程在我们的教材中的定义通常是这样的——线程是进程的一个执行分支。 线程的执行粒度&#xff0c; 要比进程要细。 我们在读完这句话后其实并不能很好的理解什么是线程。 所以&#xff0c; 本节内容博主将会带友友们理解什么是线程&a…

代码随想录算法训练营第四十六天 | 647. 回文子串,516.最长回文子序列

四十六天打卡&#xff0c;今天用动态规划解决回文问题&#xff0c;回文问题需要用二维dp解决 647.回文子串 题目链接 解题思路 没做出来&#xff0c;布尔类型的dp[i][j]&#xff1a;表示区间范围[i,j] &#xff08;注意是左闭右闭&#xff09;的子串是否是回文子串&#xff0…

2024.10月7~10日 进一步完善《电信资费管理系统》

一、新增的模块&#xff1a; 在原项目基础上&#xff0c;新增加了以下功能&#xff1a; 1、增加AspectJ 框架的AOP 异常记录和事务管理模块。 2、增加SpringMVC的拦截器&#xff0c;实现登录 控制页面访问权限。 3、增加 Logback日志框架&#xff0c;记录日志。 4、增加动态验…

Hunuan-DiT代码阅读

一 整体架构 该模型是以SD为基础的文生图模型&#xff0c;具体扩散模型原理参考https://zhouyifan.net/2023/07/07/20230330-diffusion-model/&#xff0c;代码地址https://github.com/Tencent/HunyuanDiT&#xff0c;这里介绍 Full-parameter Training 二 输入数据处理 这里…

netdata保姆级面板介绍

netdata保姆级面板介绍 基本介绍部署流程下载安装指令选择设置KSM为什么要启用 KSM&#xff1f;如何启用 KSM&#xff1f;验证 KSM 是否启用注意事项 检查端口启动状态 netdata和grafana的区别NetdataGrafananetdata各指标介绍总览system overview栏仪表盘1. CPU2. Load3. Disk…