【深度学习与神经网络】MNIST手写数字识别1

简单的全连接层

导入相应库

import torch
import numpy as np
from torch import nn,optim
from torch.autograd import Variable
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

读入数据并转为tensor向量

# 训练集
# 转为tensor数据
train_dataset = datasets.MNIST(root='./',train=True, transform = transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./',train=False, transform = transforms.ToTensor(), download=True)

装载数据集

# 批次大小
batch_size = 64# 装载训练集
train_loader = DataLoader(dataset = train_dataset, batch_size=batch_size, shuffle = True)
test_loader = DataLoader(dataset = test_dataset, batch_size=batch_size, shuffle = True)

定义网络结构
一层全连接网络,最后使用softmax转概率值输出

# 定义网络结构
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(784, 10)self.softmax = nn.Softmax(dim =1)def forward(self, x):# [64,1,28,28] ——> [64, 784]x = x.view(x.size()[0], -1)x = self.fc1(x)x = self.softmax(x)return x   

定义模型
使用均方误差损失函数,梯度下降优化

# 定义模型
model = Net()
mes_loss = nn.MSELoss()
optimizer = optim.SGD(model.parameters(),0.5)

训练并测试网络:
训练时注意最后输出(64,10)
标签是(64) ,需要将其转为one-hot编码(64,10)

def train():for i,data in enumerate(train_loader):# 获得一个批次的数据和标签inputs, labels = data# 获得模型结果 (64,10)out = model(inputs)# to one-hot 把数据标签变为独热编码labels = labels.reshape(-1,1)one_hot = torch.zeros(inputs.shape[0],10).scatter(1, labels, 1)# 计算lossloss = mes_loss(out, one_hot)# 梯度清0optimizer.zero_grad()# 计算梯度loss.backward()# 修改权值optimizer.step()def test():correct = 0for i,data in enumerate(test_loader):# 获得一个批次的数据和标签inputs, labels = data# 获得模型结果 (64,10)out = model(inputs)# 获取最大值和最大值所在位置_,predicted = torch.max(out,1)# 预测正确数量correct += (predicted == labels).sum()print("test ac:{0}".format(correct.item()/len(test_dataset)))

调用模型 训练10次

# 使用mse损失函数 
for epoch in range(10):print("epoch:",epoch)train()test()

训练结果:
在这里插入图片描述
准确率不够

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/757337.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[Python人工智能] 四十三.命名实体识别 (4)利用bert4keras构建Bert+BiLSTM-CRF实体识别模型

从本专栏开始,作者正式研究Python深度学习、神经网络及人工智能相关知识。前文讲解如何实现中文命名实体识别研究,构建BiGRU-CRF模型实现。这篇文章将继续以中文语料为主,介绍融合Bert的实体识别研究,使用bert4keras和kears包来构建Bert+BiLSTM-CRF模型。然而,该代码最终结…

vue2 实战:模板模式与渲染模式代码互切

显示效果 模板模式 <template><tr ><td class"my-td" v-if"element.isInsert1"><el-button type"danger" circle size"mini" class"delete-btn" title"删除" click"deleteItem()&quo…

KMM初探

什么是KMM&#xff1f; 在开始使用 KMM 之前&#xff0c;您需要了解 Kotlin。 KMM 全称&#xff1a;Kotlin Multiplatform Mobile&#xff09;是一个用于跨平台移动开发的 SDK,相比于其他跨平台框架&#xff0c;KMM是原生UI逻辑共享的理念,由KMM封装成Android(Kotlin/JVM)的aar…

AI大模型智能大气科学探索之:ChatGPT在大气科学领域建模、数据分析、可视化与资源评估中的高效应用及论文写作

本文深度探讨人工智能在大气科学中的应用&#xff0c;特别是如何结合最新AI模型与Python技术处理和分析气候数据。课程介绍包括GPT-4等先进AI工具&#xff0c;旨在帮助大家掌握这些工具的功能及应用范围。本文内容覆盖使用GPT处理数据、生成论文摘要、文献综述、技术方法分析等…

Java 学习和实践笔记(41):API 文档以及String类的常用方法

JDK 8用到的全部类的文档在这里下载&#xff1a; Java Development Kit 8 文档 | Oracle 中国

docker入门(一)—— docker概述

docker 概述 docker 官网&#xff1a;http://www.docker.com 官网文档&#xff1a; https://docs.docker.com/get-docker/ Docker Hub官网&#xff1a;https://hub.docker.com &#xff08;仓库&#xff09; 什么是 docker docker 是一个开源的容器化平台&#xff0c;可以…

FSP40罗德与施瓦茨FSP40频谱分析仪

181/2461/8938产品概述&#xff1a; 频率范围:9千赫至40千兆赫 分辨率带宽:1赫兹至10兆赫 显示的平均噪音水平:-155分贝&#xff08;1赫兹&#xff09; 相位噪声:10 kHz时为-113 dB&#xff08;1Hz&#xff09; 附加滤波器:100 Hz至5 MHz的通道滤波器和RRC滤波器、1 Hz至3…

数据仓库系列总结

一、数据仓库架构 1、数据仓库的概念 数据仓库&#xff08;Data Warehouse&#xff09;是一个面向主题的、集成的、相对稳定的、反映历史变化的数据集合&#xff0c;用于支持管理决策。 数据仓库通常包含多个来源的数据&#xff0c;这些数据按照主题进行组织和存储&#x…

Springboot+vue的仓库管理系统(有报告)。Javaee项目,springboot vue前后端分离项目。

演示视频&#xff1a; Springbootvue的仓库管理系统&#xff08;有报告&#xff09;。Javaee项目&#xff0c;springboot vue前后端分离项目。 项目介绍&#xff1a; 采用M&#xff08;model&#xff09;V&#xff08;view&#xff09;C&#xff08;controller&#xff09;三层…

Leetcode 62. 不同路径

心路历程&#xff1a; 这道题基本就是Q-learning经典迷宫问题的简化版本&#xff0c;所以肯定是用动态规划了&#xff0c;毕竟RL中的时序差分估计法的本身也是来自于MC和动态规划的结合。如果正常正向思维思考的话&#xff0c;首先看不到问题明显的循环结构&#xff0c;考虑递…

秒级生图,大模型 SDXL-turbo、LCM-SDXL 实战案例来了

最近一个月&#xff0c;快速生图成为文生图领域的热点&#xff0c;其中比较典型的两种方式的代表模型分别为SDXL-turbo 和 LCM-SDXL。 SDXL-turbo 模型是 SDXL 1.0 的蒸馏版本&#xff0c;SDXL-Turbo 基于一种称之为对抗扩散蒸馏&#xff08;ADD&#xff09;的新颖的训练方法&…

Go 1.22 - 更加强大的 Go 执行跟踪

原文&#xff1a;Michael Knyszek - 2024.03.14 runtime/trace 包含了一款强大的工具&#xff0c;用于理解和排查 Go 程序。这个功能可以生成一段时间内每个 goroutine 的执行追踪。然后&#xff0c;你可以使用 go tool trace 命令&#xff08;或者优秀的开源工具 gotraceui&a…

Spring Cloud 整合 GateWay

目录 第一章 微服务架构图 第二章 Spring Cloud整合Nacos集群 第三章 Spring Cloud GateWay 第四章 Spring Cloud Alibaba 整合Sentinel 第五章 Spring Cloud Alibaba 整合SkyWalking链路跟踪 第六章 Spring Cloud Alibaba 整合Seata分布式事务 第七章 Spring Cloud 集成Auth用…

[Qt学习笔记]Release后的exe程序在新的电脑上出现“找不到MSVCP140.dll”的错误

1、背景介绍 我们在打包程序的时候一般都会把相关依赖库整体打包&#xff0c;这样程序在新的电脑和环境下就不需要再去配置对应的环境&#xff0c;但是有时候新程序在一台新的电脑运行时会出现“找不到MSVCP140.dll”这种错误&#xff0c;其原因就是在新电脑的操作系统中缺少一…

倒计时 7 天 | 立即加入 GDE 成长计划,飞跃成为谷歌开发者专家

谷歌开发者专家 (Google Developer Experts&#xff0c;GDE)&#xff0c;又称谷歌开发者专家项目&#xff0c;是由一群经验丰富的技术专家、具有社交影响力的开发者和思想领袖组成的全球性社区。通过在各项活动演讲以及各个平台上发布优质内容来积极助力开发者、企业和技术社区…

AI助手 - 月之暗面 Kimi.ai

前言 这是 AI工具专栏 下的第四篇&#xff0c;这一篇所介绍的AI&#xff0c;也许是截至今天&#xff08;204-03-19&#xff09;国内可访问的实用性最强的一款。 今年年初&#xff0c;一直看到有人推荐 Kimi&#xff0c;不过面对雨后春笋般的各类品质的AI&#xff0c;说实话也有…

windows 多网卡情况dns解析超时问题的排查

最近遇到一个问题 多网卡&#xff0c;多网络环境下&#xff0c;dns解析总是超时。 排查之后发现是dns配置的问题&#xff0c;一个有线网络配置的内网dns&#xff0c;一个无线网络配置的公网dns 访问公网时莫名的时不时出现超时现象 初步排查是dns解析的耗时太长&#xff0c;…

【Go语言】Go语言中的函数

Go语言中的函数 Go语言中&#xff0c;函数主要有三种类型&#xff1a; 普通函数 匿名函数&#xff08;闭包&#xff09; 类方法 1 函数定义 Go语言函数的基本组成包括&#xff1a;关键字func、函数名、参数列表、返回值、函数体和返回语句。Go语言是强类型语言&#xff0…

【MySQL | 第五篇】MySQL事务总结

文章目录 5.MySQL事务5.1什么是事务&#xff1f;5.2什么是数据库事务&#xff1f;5.3数据库事务四大特性5.4并发事务带来的问题及解决方案&#xff1f;5.4.1脏读/不可重复读/幻读5.4.2不可重复读和幻读有什么区别&#xff1f;5.4.3解决并发事务带来的问题&#xff08;1&#xf…

springboot实战笔记

springboot实战笔记 用户模块开发用户登录接口实现根据token获取用户信息检查账号是否可用用户注册接口实现 首页模块开发查询首页分类分页查询首页头条信息查询头条详情 头条模块开发登陆检查头条发布和登录保护拦截器头条根据id回显头条修改头条删除 用户模块开发 用户登录接…