深度学习-2.6在MINST-FASHION上实现神经网络的学习流程

文章目录

  • 在MINST-FASHION上实现神经网络的学习流程
    • 1. 导库
    • 2. 导入数据,分割小批量
    • 3. 定义神经网络
    • 4.定义训练函数
    • 5.进行训练与评估

在MINST-FASHION上实现神经网络的学习流程

现在我们要整合本节课中所有的代码实现一个完整的训练流程。
首先要梳理一下整个流程:

  • 1)设置步长lr,动量值 g a m m a gamma gamma ,迭代次数 e p o c h s epochs epochs , b a t c h _ s i z e batch\_size batch_size等信息,(如果需要)设置初始权重 w 0 w_0 w0

  • 2)导入数据,将数据切分成 b a t c h _ s i z e batch\_size batch_size

  • 3)定义神经网络架构

  • 4)定义损失函数 L ( w ) L(w) L(w),如果需要的话,将损失函数调整成凸函数,以便求解最小值

  • 5)定义所使用的优化算法

  • 6)开始在 e p o c h e s epoches epoches b a t c h batch batch上循环,执行优化算法:

    • 6.1)调整数据结构,确定数据能够在神经网络、损失函数和优化算法中顺利运行;
    • 6.2)完成向前传播,计算初始损失
    • 6.3)利用反向传播,在损失函数 L ( w ) L(w) L(w)上对每一个 w w w求偏导数
    • 6.4)迭代当前权重
    • 6.5)清空本轮梯度
    • 6.6)完成模型进度与效果监控
  • 7)输出结果

1. 导库

这次我们要使用PyTorch中自带的数据,MINST-FATION。

import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
#确定数据、确定优先需要设置的值
lr = 0.15
gamma = 0
epochs = 10
bs = 128

2. 导入数据,分割小批量


import torchvision
import torchvision.transforms as transforms#初次运行时会下载,需要等待较长时间
mnist = torchvision.datasets.FashionMNIST(root='C:\Pythonwork\DEEP LEARNING\Datasets\FashionMNIST',train=True, download=True, transform=transforms.ToTensor())len(mnist)#查看特征张量mnist.data#这个张量结构看起来非常常规,可惜的是它与我们要输入到模型的数据结构有差异#查看标签
mnist.targets#查看标签的类别
mnist.classes
#查看图像的模样
import matplotlib.pyplot as plt
plt.imshow(mnist[0][0].view((28, 28)).numpy());plt.imshow(mnist[1][0].view((28, 28)).numpy());#分割batch
batchdata = DataLoader(mnist,batch_size=bs, shuffle = True)
#总共多少个batch?
len(batchdata)
#查看会放入进行迭代的数据结构
for x,y in batchdata:print(x.shape)print(y.shape)breakinput_ = mnist.data[0].numel() #特征的数目,一般是第一维之外的所有维度相乘的数
output_ = len(mnist.targets.unique()) #分类的数目#最好确认一下没有错误input_output_#========================
import torchvision
import torchvision.transforms as transforms
mnist = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=False, transform=transforms.ToTensor())
batchdata = DataLoader(mnist,batch_size=bs, shuffle = True)
input_ = mnist.data[0].numel()
output_ = len(mnist.targets.unique()

3. 定义神经网络

class Model(nn.Module):def __init__(self,in_features=10,out_features=2):super().__init__()#self.normalize = nn.BatchNorm2d(num_features=1)self.linear1 = nn.Linear(in_features,128,bias=False)self.output = nn.Linear(128,out_features,bias=False)def forward(self, x):#x = self.normalize(x)x = x.view(-1, 28*28)#需要对数据的结构进行一个改变,这里的“-1”代表,我不想算,请pytorch帮我计算sigma1 = torch.relu(self.linear1(x))z2 = self.output(sigma1)sigma2 = F.log_softmax(z2,dim=1)return sigma2

4.定义训练函数

def fit(net,batchdata,lr=0.01,epochs=5,gamma=0):criterion = nn.NLLLoss() #定义损失函数opt = optim.SGD(net.parameters(), lr=lr,momentum=gamma) #定义优化算法correct = 0samples = 0for epoch in range(epochs):for batch_idx, (x,y) in enumerate(batchdata):y = y.view(x.shape[0])sigma = net.forward(x)loss = criterion(sigma,y)loss.backward()opt.step()opt.zero_grad()#求解准确率yhat = torch.max(sigma,1)[1]correct + torch. sum Cyhat == y)samples + = x. shape [ o]if (batch_ idx+ 1) % 125 o or batch_ idx = len (batchdata)-1:print( Epocht: [ / (:of] % ) ] tLoss : 6ft Accuracy::.3f].format(epoch+1 , samples ,len( batchdata. dataset) * epochs,100* samples/ ( len (batchdata. dataset)epochs),loss.data.item(),float(correct*100)/samples))

5.进行训练与评估

#实例化神经网络,调用优化算法需要的参数
torch. manualseed(420)
net = Mode ( in_ features= input_ out_features=output_)
fit( net, batchdata, lr= lr, epochs= epochs, gamma=gamma)

我们现在已经完成了一个最基本的、神经网络训练并查看训练结果的代码。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/746935.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

用try...catch进行判断

在写一些提交数据的判断上,有时候会写下面的ifelse的判断方法,少一点还好,多的话就很难受也不好看。 if(!that.driverObj.contrary){this.__utils.showToast(请先上传驾驶证副页图片);return false } if(!this.driverObj.start){this.__util…

鸿蒙Harmony应用开发—ArkTS声明式开发(容器组件:Flex)

以弹性方式布局子组件的容器组件。 说明: 该组件从API Version 7开始支持。后续版本如有新增内容,则采用上角标单独标记该内容的起始版本。Flex组件在渲染时存在二次布局过程,因此在对性能有严格要求的场景下建议使用Column、Row代替。Flex组…

Vue3基础笔记(1)模版语法 属性绑定 渲染

Vue全称Vue.js是一种渐进式的JavaScript框架,采用自底向上增量开发的设计,核心库只关注视图层。性能丰富,完全有能力驱动采用单文件组件和Vue生态系统支持的库开发的复杂单页应用,适用于场景丰富的web前端框架。灵活性和可逐步集成…

149.乐理基础-七和弦的第一转位、第二转位、第三转位

内容参考于:三分钟音乐社 上一个内容:148.常用的7个七和弦结构与简称 上一个内容里练习的答案: 前置内容:必须看过 140.音程的转位 和 146.三和弦的第一转位、第二转位这两个 现在还是狭义上、理论上的转位,下面用C…

深度学习专家学习计划

深度学习专家学习计划 一、学习背景与目标 作为一名有6年工作经验的Java开发人员,您已具备基本的编程能力和数据处理经验。现计划转岗至深度学习领域,成为深度学习专家。本计划将结合您的工作背景和现有知识,为您制定详细且精确的学习计划,帮助您逐步达到专家水平。 二、…

高校实验室科研仪器开放共享存在的问题及对策建议

随着科技的迅速发展和高校科研水平的提高,高校实验室科研仪器的开放共享已经成为推动科研进步和创新发展的重要手段。然而,在实际操作中,我们也面临着诸多问题和挑战。本文将分析高校实验室科研仪器开放共享存在的问题,并提出相应…

【前言】神经网络与深度学习简介

如果您已经了解过神经网络与深度学习,请直接跳转到第一章学习 概念: 神经网络,一种基于生物启发式编程范式,它使计算机能够从观测数据中学习 深度学习,一套用于神经网络学习的强大技术集合 简介 神经网络和深度学习…

c# MD5加密函数

/// <summary> /// 对字符串进行MD5运算 /// </summary> /// <param name"str"></param> /// <returns></returns> public static string GetMd5String(string str) { …

杂七杂八111

MQ 用处 一、异步。可提高性能和吞吐量 二、解耦 三、削峰 四、可靠。常用消息队列可以保证消息不丢失、不重复消费、消息顺序、消息幂等 选型 一Kafak:吞吐量最大&#xff0c;性能最好&#xff0c;集群高可用。缺点&#xff1a;会丢数据&#xff0c;功能较单一。 二Ra…

javaEE7

1. <% page pageEncoding"UTF-8"%><% page import"java.io.*"%> <% page import"java.util.*"%> <% page import"java.math.*"%> <html> <head><title>网站计数器</title></head&…

【软件测试基础篇】第二节.黑盒测试中常见方法

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言⼀、等价类法&#xff08;解决穷举问题&#xff09;二、边界值法&#xff08;解决边界限制问题&#xff09;三、正交表法&#xff08;解决多条件依赖问题&#…

媒体发稿:澳门媒体发稿7个流程

推广平台澳门是一个重要的度假旅游娱乐终点&#xff0c;都是媒体领域热议的话题。对于澳门的媒体发稿营销推广要求&#xff0c;大家提供了一个简单易用的套餐系统软件&#xff0c;帮助大家在澳门媒体上发表推广文章。下面我们就根据7个阶段&#xff0c;详解构建这一套餐推广平台…

Python如何处理拥塞控制

拥塞控制是计算机网络中用于防止网络拥塞&#xff08;即过多的数据导致网络性能下降&#xff09;的一系列技术和算法。在Python中&#xff0c;处理拥塞控制通常不直接涉及到代码层面的实现&#xff0c;因为拥塞控制主要是在网络协议栈&#xff08;如TCP/IP&#xff09;和操作系…

echarts tooltip提示组件框自定义浮窗内容

echarts tooltip提示组件框自定义浮窗内容 tooltip提示组件框 有三种浮窗展示方法 第一种&#xff1a;默认展示 第二种&#xff1a;字符串模板 第三种&#xff1a;回调函数 第二种 formatter&#xff08;字符串模板&#xff09; 模板变量有 {a}, {b}&#xff0c;{c}&#xff0…

C++ 作业 24/3/14

1、成员函数版本实现算术运算符的重载&#xff1b;全局函数版本实现算术运算符的重载 #include <iostream>using namespace std;class Test {friend const Test operator-(const Test &L,const Test &R); private:int c;int n; public:Test(){}Test(int c,int n…

LeetCode 热题 100 | 回溯(二)

目录 1 39. 组合总和 2 22. 括号生成 3 79. 单词搜索 菜鸟做题&#xff0c;语言是 C&#xff0c;感冒快好版 关于对回溯算法的理解请参照我的上一篇博客&#xff1b; 在之后的博客中&#xff0c;我将只分析回溯算法中的 for 循环。 1 39. 组合总和 题眼&#xff1a;c…

VBA_MF系列技术资料1-400

MF系列VBA技术资料1-400 为了让广大学员在VBA编程中有切实可行的思路及有效的提高自己的编程技巧&#xff0c;我参考大量的资料&#xff0c;并结合自己的经验总结了这份MF系列VBA技术综合资料&#xff0c;而且开放源码&#xff08;MF04除外&#xff09;&#xff0c;其中MF01-0…

python爬虫(11)之BeautifulSoup模块

1、模块介绍 所谓BeautifulSoup模块是通过html源代码进行筛选类似于正则表达式那种类型 2、代码 import os import requests from bs4 import BeautifulSoup from PIL import Image from io import BytesIOheaders {Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit…

Java复习02 IO流

Java复习02 IO流 首先&#xff0c;“IO”在计算机里面代表的是“输入/输出”&#xff08;Input / Output&#xff09;&#xff0c;简单来说&#xff0c;就是计算机与外部世界进行数据交流的过程。比如&#xff0c;你在键盘上敲字&#xff0c;数据就输入到计算机里了&#xff0…

深入理解Spring的ApplicationContext:案例详解与应用

深入理解Spring的ApplicationContext&#xff1a;案例详解与应用 在Spring框架的丰富生态中&#xff0c;ApplicationContext扮演着至关重要的角色。作为BeanFactory的扩展&#xff0c;ApplicationContext不仅继承了其所有功能&#xff0c;还引入了更多高级特性&#xff0c;使得…