【深度学习与神经网络】MNIST手写数字识别1

简单的全连接层

导入相应库

import torch
import numpy as np
from torch import nn,optim
from torch.autograd import Variable
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

读入数据并转为tensor向量

# 训练集
# 转为tensor数据
train_dataset = datasets.MNIST(root='./',train=True, transform = transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./',train=False, transform = transforms.ToTensor(), download=True)

装载数据集

# 批次大小
batch_size = 64# 装载训练集
train_loader = DataLoader(dataset = train_dataset, batch_size=batch_size, shuffle = True)
test_loader = DataLoader(dataset = test_dataset, batch_size=batch_size, shuffle = True)

定义网络结构
一层全连接网络,最后使用softmax转概率值输出

# 定义网络结构
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(784, 10)self.softmax = nn.Softmax(dim =1)def forward(self, x):# [64,1,28,28] ——> [64, 784]x = x.view(x.size()[0], -1)x = self.fc1(x)x = self.softmax(x)return x   

定义模型
使用均方误差损失函数,梯度下降优化

# 定义模型
model = Net()
mes_loss = nn.MSELoss()
optimizer = optim.SGD(model.parameters(),0.5)

训练并测试网络:
训练时注意最后输出(64,10)
标签是(64) ,需要将其转为one-hot编码(64,10)

def train():for i,data in enumerate(train_loader):# 获得一个批次的数据和标签inputs, labels = data# 获得模型结果 (64,10)out = model(inputs)# to one-hot 把数据标签变为独热编码labels = labels.reshape(-1,1)one_hot = torch.zeros(inputs.shape[0],10).scatter(1, labels, 1)# 计算lossloss = mes_loss(out, one_hot)# 梯度清0optimizer.zero_grad()# 计算梯度loss.backward()# 修改权值optimizer.step()def test():correct = 0for i,data in enumerate(test_loader):# 获得一个批次的数据和标签inputs, labels = data# 获得模型结果 (64,10)out = model(inputs)# 获取最大值和最大值所在位置_,predicted = torch.max(out,1)# 预测正确数量correct += (predicted == labels).sum()print("test ac:{0}".format(correct.item()/len(test_dataset)))

调用模型 训练10次

# 使用mse损失函数 
for epoch in range(10):print("epoch:",epoch)train()test()

训练结果:
在这里插入图片描述
准确率不够

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/757337.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

mysql主从及备份

1.由于centOS7中默认安装了MariaDB,需要先进行卸载 rpm -qa | grep -i mariadb rpm -e --nodeps mariadb-libs-5.5.64-1.el7.x86_64查询下本机mysql是否卸载干净,若有残留也需要卸载 rpm -qa | grep mysql2.下载MySQL仓库并安装 wget https://repo.mysql.com//my…

C++: 多态实现原理解析

文章目录 1 静态多态实现原理 2 动态多态实现原理 code 1 静态多态 实现 函数重载,在编译器确定 函数重载的条件: 函数名相同参数个数不同,参数的类型不同,参数顺序不同返回值类型,不作为重载的标准 原理 函数名…

Android 开机启动的核心系统服务:你了解了吗?

Android 核心系统服务 目录 Android 核心系统服务 1. ActivityManagerService(活动管理服务,简称AMS) 2. PackageManagerService(包管理服务,简称PMS) 3. WindowManagerService(窗口管理服务…

[Python人工智能] 四十三.命名实体识别 (4)利用bert4keras构建Bert+BiLSTM-CRF实体识别模型

从本专栏开始,作者正式研究Python深度学习、神经网络及人工智能相关知识。前文讲解如何实现中文命名实体识别研究,构建BiGRU-CRF模型实现。这篇文章将继续以中文语料为主,介绍融合Bert的实体识别研究,使用bert4keras和kears包来构建Bert+BiLSTM-CRF模型。然而,该代码最终结…

vue2 实战:模板模式与渲染模式代码互切

显示效果 模板模式 <template><tr ><td class"my-td" v-if"element.isInsert1"><el-button type"danger" circle size"mini" class"delete-btn" title"删除" click"deleteItem()&quo…

KMM初探

什么是KMM&#xff1f; 在开始使用 KMM 之前&#xff0c;您需要了解 Kotlin。 KMM 全称&#xff1a;Kotlin Multiplatform Mobile&#xff09;是一个用于跨平台移动开发的 SDK,相比于其他跨平台框架&#xff0c;KMM是原生UI逻辑共享的理念,由KMM封装成Android(Kotlin/JVM)的aar…

AI大模型智能大气科学探索之:ChatGPT在大气科学领域建模、数据分析、可视化与资源评估中的高效应用及论文写作

本文深度探讨人工智能在大气科学中的应用&#xff0c;特别是如何结合最新AI模型与Python技术处理和分析气候数据。课程介绍包括GPT-4等先进AI工具&#xff0c;旨在帮助大家掌握这些工具的功能及应用范围。本文内容覆盖使用GPT处理数据、生成论文摘要、文献综述、技术方法分析等…

Http的缓存有哪些

HTTP 缓存可以通过多种 HTTP 头部字段来控制&#xff0c;主要包括以下几种&#xff1a; 1.Expires&#xff1a;这个字段定义了响应的过期时间。如果当前时间小于 Expires 的时间&#xff0c;那么就可以直接使用缓存。 2.Cache-Control&#xff1a;这个字段是一个指令&#xff…

Java 学习和实践笔记(41):API 文档以及String类的常用方法

JDK 8用到的全部类的文档在这里下载&#xff1a; Java Development Kit 8 文档 | Oracle 中国

docker入门(一)—— docker概述

docker 概述 docker 官网&#xff1a;http://www.docker.com 官网文档&#xff1a; https://docs.docker.com/get-docker/ Docker Hub官网&#xff1a;https://hub.docker.com &#xff08;仓库&#xff09; 什么是 docker docker 是一个开源的容器化平台&#xff0c;可以…

C语言经典面试题目(十六)

1、什么是C语言中的指针常量和指针变量&#xff1f;它们有什么区别&#xff1f; 在C语言中&#xff0c;指针常量和指针变量是指针的两种不同类型。它们的区别在于指针的指向和指针本身是否可以被修改。 指针常量&#xff1a;指针指向的内存地址不可变&#xff0c;但指针本身的…

FSP40罗德与施瓦茨FSP40频谱分析仪

181/2461/8938产品概述&#xff1a; 频率范围:9千赫至40千兆赫 分辨率带宽:1赫兹至10兆赫 显示的平均噪音水平:-155分贝&#xff08;1赫兹&#xff09; 相位噪声:10 kHz时为-113 dB&#xff08;1Hz&#xff09; 附加滤波器:100 Hz至5 MHz的通道滤波器和RRC滤波器、1 Hz至3…

数据仓库系列总结

一、数据仓库架构 1、数据仓库的概念 数据仓库&#xff08;Data Warehouse&#xff09;是一个面向主题的、集成的、相对稳定的、反映历史变化的数据集合&#xff0c;用于支持管理决策。 数据仓库通常包含多个来源的数据&#xff0c;这些数据按照主题进行组织和存储&#x…

在Qt中使用线程类QThread

说明 QThread是qt中的一个线程类。目前我了解到的共有两种用法&#xff0c;一种是作为普通的线程&#xff0c;就像c标准库中的std::thread一样&#xff0c;另一种就是作为信号槽的容器&#xff0c;负责调用qt的事件循环。 作为普通线程 重载QThread::run()这个虚函数&#x…

深度学习基础之《TensorFlow框架(7)—变量》

一、什么是变量 1、TensorFlow变量是表示程序处理的共享持久状态的最佳方法。变量通过tf.Variable OP类进行操作 这里的变量和传统认知里存储值或者返回值不一样&#xff0c;他是TensorFlow里的一个组件 2、变量的特点 &#xff08;1&#xff09;存储持久化 把程序中定义的数…

Springboot+vue的仓库管理系统(有报告)。Javaee项目,springboot vue前后端分离项目。

演示视频&#xff1a; Springbootvue的仓库管理系统&#xff08;有报告&#xff09;。Javaee项目&#xff0c;springboot vue前后端分离项目。 项目介绍&#xff1a; 采用M&#xff08;model&#xff09;V&#xff08;view&#xff09;C&#xff08;controller&#xff09;三层…

Leetcode 62. 不同路径

心路历程&#xff1a; 这道题基本就是Q-learning经典迷宫问题的简化版本&#xff0c;所以肯定是用动态规划了&#xff0c;毕竟RL中的时序差分估计法的本身也是来自于MC和动态规划的结合。如果正常正向思维思考的话&#xff0c;首先看不到问题明显的循环结构&#xff0c;考虑递…

秒级生图,大模型 SDXL-turbo、LCM-SDXL 实战案例来了

最近一个月&#xff0c;快速生图成为文生图领域的热点&#xff0c;其中比较典型的两种方式的代表模型分别为SDXL-turbo 和 LCM-SDXL。 SDXL-turbo 模型是 SDXL 1.0 的蒸馏版本&#xff0c;SDXL-Turbo 基于一种称之为对抗扩散蒸馏&#xff08;ADD&#xff09;的新颖的训练方法&…

Go 1.22 - 更加强大的 Go 执行跟踪

原文&#xff1a;Michael Knyszek - 2024.03.14 runtime/trace 包含了一款强大的工具&#xff0c;用于理解和排查 Go 程序。这个功能可以生成一段时间内每个 goroutine 的执行追踪。然后&#xff0c;你可以使用 go tool trace 命令&#xff08;或者优秀的开源工具 gotraceui&a…

c++11 标准模板(STL)本地化库 - std::iscntrl(std::locale) 检查字符是否被本地环境分类为控制字符

本地化库 本地环境设施包含字符分类和字符串校对、数值、货币及日期/时间格式化和分析&#xff0c;以及消息取得的国际化支持。本地环境设置控制流 I/O 、正则表达式库和 C 标准库的其他组件的行为。 检查字符是否被本地环境分类为控制字符 std::iscntrl(std::locale) templa…