python-pytorch基础之神经网络分类

这里写目录标题

  • 生成数据函数
  • 定义数据集
  • 定义loader加载数据
  • 定义神经网络模型
    • 测试输出是否为2个
    • 输入数据,输出结果
  • 训练模型函数
    • 计算正确率
  • 训练数据并保存模型
  • 测试模型
    • 准备数据
    • 加载模型预测
    • 对比结果

生成数据函数

import random
def get_rectangle():width=random.random()height=random.random()# 如果width大于height就是1,否则就是0fat=int(width>=height)return width,height,fatget_rectangle()
(0.19267437580138802, 0.645061020860627, 0)
fat=int(0.8>=0.9)
fat
0

定义数据集

import torch
class Dataset(torch.utils.data.Dataset):def __init__(self):passdef __len__(self):return 1000def __getitem__(self,i):width,height,fat=get_rectangle()# 这里注意width和height列表形式再转为tensor,另外是floattensorx=torch.FloatTensor([width,height])y=fatreturn x,y
dataset=Dataset()
# 这里没执行一次都要变,有什么意义?
len(dataset),dataset[999]
(1000, (tensor([0.4756, 0.1713]), 1))

定义loader加载数据

# 定义了数据集,然后使用loader去加载,需要记住batch_size,shuffle,drop_last这个三个常用的
loader=torch.utils.data.DataLoader(dataset=dataset,batch_size=9,shuffle=True,drop_last=True)# 加载完了以后,就可以使用next去遍历了
len(loader),next(iter(loader))
(111,[tensor([[0.1897, 0.6766],[0.2460, 0.2725],[0.5871, 0.7739],[0.3035, 0.9607],[0.7006, 0.7421],[0.4279, 0.9501],[0.6750, 0.1704],[0.5777, 0.1154],[0.5512, 0.3933]]),tensor([0, 0, 0, 0, 0, 0, 1, 1, 1])])
# 输出结果:
# (111,
#  [tensor([[0.3577, 0.3401],
#           [0.0156, 0.7550],
#           [0.0435, 0.4984],
#           [0.1329, 0.5488],
#           [0.4330, 0.5362],
#           [0.1070, 0.8500],
#           [0.1073, 0.2496],
#           [0.1733, 0.0226],
#           [0.6790, 0.2119]]),
#   tensor([1, 0, 0, 0, 0, 0, 0, 1, 1])])  这里是torch自动将y组合成了一个tensor

定义神经网络模型

class Model(torch.nn.Module):def __init__(self):super().__init__()# 定义神经网络结构self.fb=torch.nn.Sequential(# 第一层输入两个,输出32个torch.nn.Linear(in_features=2,out_features=32),# 激活层的意思是小于0的数字变为0,即负数归零操作torch.nn.ReLU(),# 第二层,输入是32个,输出也是32个torch.nn.Linear(in_features=32,out_features=32),# 激活torch.nn.ReLU(),# 第三次,输入32个,输出2个torch.nn.Linear(in_features=32,out_features=2),# 激活,生成的两个数字相加等于1torch.nn.Softmax(dim=1))# 定义网络计算过程def forward(self,x):return self.fb(x)model=Model()

测试输出是否为2个

# 测试 8行2列的数据,让模型测试,看是否最后输出是否也是8行一列?model(torch.rand(8,2)).shape
torch.Size([8, 2])

输入数据,输出结果

model(torch.tensor([[0.3577, 0.3401]]))
tensor([[0.5228, 0.4772]], grad_fn=<SoftmaxBackward>)

训练模型函数

def train():#定义优化器,le-4表示0.0004,10的负四次方opitimizer=torch.optim.Adam(model.parameters(),lr=1e-4)#定义损失函数 ,一般回归使用mseloss,分类使用celossloss_fn=torch.nn.CrossEntropyLoss()#然后trainmodel.train()# 全量数据遍历100轮for epoch in range(100):# 遍历loader的数据,循环一次取9条数据,这里注意unpack时,x就是【width,height】for i,(x,y) in enumerate(loader):out=model(x)# 计算损失loss=loss_fn(out,y)# 计算损失的梯度loss.backward()# 优化参数opitimizer.step()# 梯度清零opitimizer.zero_grad()# 第二十轮的时候打印一下数据if epoch % 20 ==0:# 正确率# out.argmax(dim=1) 表示哪个值大就说明偏移哪一类,dim=1暂时可以看做是固定的# (out.argmax(dim=1)==y).sum() 表示的true的个数acc=(out.argmax(dim=1)==y).sum().item()/len(y)print(epoch,loss.item(),acc)torch.save(model,"4.model")

计算正确率

执行的命令:

print(out,“======”,y,“+++++”,(out.argmax(dim=1)==y).sum().item())

print(out.argmax(dim=1),“------------”,(out.argmax(dim=1)==y),“~~~~~~~~~~~”,(out.argmax(dim=1)==y).sum())

输出的结果:

tensor([[9.9999e-01, 1.4671e-05],

    [4.6179e-14, 1.0000e+00],        [3.2289e-02, 9.6771e-01],        [1.1237e-22, 1.0000e+00],[9.9993e-01, 7.0015e-05],[8.6740e-02, 9.1326e-01],[1.1458e-18, 1.0000e+00],[5.2558e-01, 4.7442e-01],[9.7923e-01, 2.0772e-02]],         grad_fn=<SoftmaxBackward>) ====== tensor([0, 1, 1, 1, 0, 1, 1, 1, 0]) +++++ 8

tensor([0, 1, 1, 1, 0, 1, 1, 0, 0]) ------------ tensor([ True, True, True, True, True, True, True, False, True]) ~~~~~~~~~~~ tensor(8)

解释:

out.argmax(dim=1) 表示哪个值大就说明偏移哪一类,dim=1暂时可以看做是固定的

(out.argmax(dim=1)==y).sum() 表示的true的个数

 a = torch.tensor([[1,2,3],[4,7,6]])d = a.argmax(dim=1)
print(d)
tensor([2, 1])

训练数据并保存模型

train()
80 0.3329530954360962 1.0
80 0.31511250138282776 1.0
80 0.33394935727119446 1.0
80 0.3242819309234619 1.0
80 0.3188716471195221 1.0
80 0.3405844569206238 1.0
80 0.32696405053138733 1.0
80 0.3540787696838379 1.0
80 0.3390745222568512 1.0
80 0.3645476996898651 0.8888888888888888
80 0.3371085822582245 1.0
80 0.31789034605026245 1.0
80 0.31553390622138977 1.0
80 0.3162603974342346 1.0
80 0.35249051451683044 1.0
80 0.3582523465156555 1.0
80 0.3162645995616913 1.0
80 0.37988030910491943 1.0
80 0.34384390711784363 1.0
80 0.31773826479911804 1.0
80 0.3145104646682739 1.0
80 0.31753242015838623 1.0
80 0.3222736120223999 1.0
80 0.38612237572669983 1.0
80 0.35490038990974426 1.0
80 0.34469687938690186 1.0
80 0.34534531831741333 1.0
80 0.31800928711891174 1.0
80 0.34892910718917847 1.0
80 0.33424195647239685 1.0
80 0.37350085377693176 1.0
80 0.3298128843307495 1.0
80 0.3715909719467163 1.0
80 0.3507140874862671 1.0
80 0.33337005972862244 1.0
80 0.3134789764881134 1.0
80 0.35244104266166687 1.0
80 0.3148314654827118 1.0
80 0.3376845419406891 1.0
80 0.3315282464027405 1.0
80 0.3450225591659546 1.0
80 0.3139556646347046 1.0
80 0.34932857751846313 1.0
80 0.3512738049030304 1.0
80 0.3258627951145172 1.0
80 0.3197799324989319 1.0
80 0.358166366815567 0.8888888888888888
80 0.3716268837451935 1.0
80 0.31426626443862915 1.0
80 0.32130196690559387 1.0
80 0.3207002282142639 1.0
80 0.3891155421733856 1.0
80 0.35045987367630005 1.0
80 0.32332736253738403 1.0
80 0.31951677799224854 1.0
80 0.3184094727039337 1.0
80 0.3341224491596222 1.0
80 0.3408585786819458 1.0
80 0.3139263093471527 1.0
80 0.33058592677116394 1.0
80 0.3134475648403168 1.0
80 0.3281571567058563 1.0
80 0.33370518684387207 1.0
80 0.33172252774238586 1.0
80 0.32849007844924927 1.0
80 0.3604048788547516 1.0
80 0.3651810884475708 1.0

测试模型

准备数据

# 准备数据
x,fat=next(iter(loader))
x,fat
(tensor([[0.6733, 0.4044],[0.6503, 0.0303],[0.9353, 0.9518],[0.4145, 0.6948],[0.9560, 0.8009],[0.6331, 0.0852],[0.5510, 0.8283],[0.1402, 0.2726],[0.3257, 0.8351]]),tensor([1, 1, 0, 0, 1, 1, 0, 0, 0]))

加载模型预测

# 加载模型
modell=torch.load("4.model")
# 使用模型
out=modell(x)
out
tensor([[1.5850e-04, 9.9984e-01],[1.4121e-06, 1.0000e+00],[7.1068e-01, 2.8932e-01],[9.9994e-01, 5.5789e-05],[4.1401e-03, 9.9586e-01],[3.3441e-06, 1.0000e+00],[9.9995e-01, 4.7039e-05],[9.9111e-01, 8.8864e-03],[1.0000e+00, 1.5224e-06]], grad_fn=<SoftmaxBackward>)
out.argmax(dim=1)
tensor([1, 1, 0, 0, 1, 1, 0, 0, 0])

对比结果

fat
tensor([1, 1, 0, 0, 1, 1, 0, 0, 0])
查看上面out的结果和fat的结论一致,不错

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/14116.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Pytorch气温预测实战

数据集 数据有8个特征&#xff0c;一个标签值 自变量因变量yearactual:当天的真实最高温度monthdayweek:星期几temp_1:昨天的最高温度temp_2:前天的最高温度值average:在历史中&#xff0c;每年这一天的平均最高温度friend:朋友猜测的温度 year,month,day,week,temp_2,temp_…

WPF基础知识

WPF布局基础原则 1.一个窗口中只能包含一个元素 2. 不应该显示设置元素尺寸 3. 不应使用坐标设置元素的位置 4.可以嵌套布局容器WPF布局容器 StackPanel:水平或垂直排列元素&#xff0c;Orientation属性分别为&#xff1a;Horizontal/VerticalWrapPanel:水平或垂直排列元素、…

安科瑞能源物联网以能源供应、能源管理、设备管理、能耗分析的能源流向为主线-安科瑞黄安南

摘要&#xff1a;随着科学技术的发展&#xff0c;我国的物联网技术有了很大进展。为了提升电力抄表服务的稳定性&#xff0c;保障电力抄表数据的可靠性&#xff0c;本文提出并实现了基于物联网的智能电力抄表服务平台&#xff0c;结合云计算、大数据等技术&#xff0c;提供电力…

Codeforces Round 888 (Div. 3)(视频讲解全部题目)

[TOC](Codeforces Round 888 (Div. 3)&#xff08;视频讲解全部题目&#xff09;) Codeforces Round 888 (Div. 3)&#xff08;A–G&#xff09;全部题目详解 A Escalator Conversations #include<bits/stdc.h> #define endl \n #define INF 0x3f3f3f3f using namesp…

mars3d绘制区域范围(面+边框)

1、图例&#xff08;绿色面区域白色边框&#xff09; 2、代码 1&#xff09;、绘制区域ts文件 import { mapLayerCollection } from /hooks/cesium-map-init /*** 安全防護目標* param map*/ export const addSafetyProtection async (map) > {const coverDatas await m…

20230729 git github gitee

1.gitee与gitHub概念&#xff1f; Gitee&#xff08;码云&#xff09;是开源中国社区推出的代码托管协作开发平台&#xff0c;支持Git和SVN&#xff0c;提供免费的私有仓库托管。Gitee专为开发者提供稳定、高效、安全的云端软件开发协作平台&#xff0c;无论是个人、团队、或是…

6个ChatGPT4的最佳用途

文章目录 ChatGPT 4’s Current Limitations ChatGPT 4 的当前限制1. Crafting Complex Prompts 制作复杂的提示2. Logic Problems 逻辑问题3. Verifying GPT 3.5 Text 验证 GPT 3.5 文本4. Complex Coding 复杂编码5.Nuanced Text Transformation 细微的文本转换6. Complex Kn…

web前端如何判断自己当前的职业等级

初级前端开发工程师 熟练使用HTML标签&#xff0c;对HTML标签特性有一定理解。对HTML语义话有一定了解&#xff1b;熟练使用CSS属性及选择器&#xff0c;能使用一些CSShack。对模块化和栅格化布局有一定的了解。能独立使用JS完成一些简单的需求&#xff0c;至少能使用一种前端框…

P5727 【深基5.例3】冰雹猜想

【深基5.例3】冰雹猜想 题目描述 给出一个正整数 n n n&#xff0c;然后对这个数字一直进行下面的操作&#xff1a;如果这个数字是奇数&#xff0c;那么将其乘 3 3 3 再加 1 1 1&#xff0c;否则除以 2 2 2。经过若干次循环后&#xff0c;最终都会回到 1 1 1。经过验证很…

Windows下安装HBase

Windows下安装HBase 一、HBase简介二、HBase下载安装包三、环境准备3.1、 JDK的安装3.2、 Hadoop的安装 四、HBase安装4.1、压缩包解压为文件夹4.2、配置环境变量4.3、%HBASE_HOME%目录下新建临时文件夹4.4、修改配置文件 hbase-env.cmd4.4.1、配置JAVA环境4.4.2、set HBASE_MA…

使用Hutool工具类中的BeanUtil.fillBeanWithMap方法报错`DateException`

使用Hutool工具类中的BeanUtil.fillBeanWithMap方法报错DateException 问题背景 在实现登录功能时&#xff0c;我先将用户信息存入Redis中&#xff0c;然后再获取用户信息的时候&#xff0c;又取出来。我存入Redis的用户信息是Hash格式的&#xff0c;所以取出来的时候&#xff…

Ansible的脚本 --- playbook 剧本

文章目录 一、playbook剧本的组成创建剧本运行playbook二、定义、引用变量三、指定远程主机sudo切换用户四、when条件判断五、迭代Templates 模块tags 模块 一、playbook剧本的组成 playbooks 本身由以下各部分组成 &#xff08;1&#xff09;Tasks&#xff1a;任务&#xff0…

kubernetes 证书更新

参考&#xff1a; https://kubernetes.io/zh-cn/docs/tasks/administer-cluster/kubeadm/kubeadm-certs/https://kubernetes.io/zh-cn/docs/tasks/tls/certificate-rotation/ 查看证书 查看 kubelet是否支持证书自动轮换&#xff0c;默认轮换的证书位于目录 /var/lib/kubele…

vscode 打开文件时如何在资源管理器中展开文件所在的整个目录树(包含node_modules)

如题。去 首选项 --> 设置 中 搜索 “Auto Reveal”&#xff0c;然后选true&#xff0c;注意把下面的Auto Reveal Exclude排除项中的node_modules去掉&#xff0c;这样才能定位到node_modules中的文件。 **/node_modules

正则,JS:this,同步异步,原型链笔记整理

一 正则表达式 正则表达式&#xff08;regular expression&#xff09;是一种表达文本模式&#xff08;即字符串结构&#xff09;的方法&#xff0c;有点像字符串的模板&#xff0c;常常用来按照“给定模式”匹配文本 正则表达式可以用于以下常见操作&#xff1a; 匹配&…

Leetcode刷题---C语言实现初阶数据结构---单链表

1 删除链表中等于给定值 val 的所有节点 删除链表中等于给定值 val 的所有节点 给你一个链表的头节点head和一个整数val&#xff0c;请你删除链表中所有满足Node.valval的节点&#xff0c;并返回新的头节点 输入&#xff1a;head [1,2,6,3,4,5,6], val 6 输出&#xff1a;[…

基于 ThinkPHP 5.1(稳定版本) 开发wms 进销存系统源码

基于ThinkPHP 5.1&#xff08;LTS版本&#xff09;开发的WMS进销存系统源码 管理员账号密码&#xff1a;admin 一、项目简介 这个系统是一个基于ThinkPHP框架的WMS进销存系统。 二、实现功能 控制台 – 权限管理&#xff08;用户管理、角色管理、节点管理&#xff09; – 订…

Docker 入门终极指南[详细]

前言 富 Web 时代&#xff0c;应用变得越来越强大&#xff0c;与此同时也越来越复杂。集群部署、隔离环境、灰度发布以及动态扩容缺一不可&#xff0c;而容器化则成为中间的必要桥梁。 本节我们就来探索一下 Docker 的神秘世界&#xff0c;从零到一掌握 Docker 的基本原理与实…

【Golang】Golang进阶系列教程--Go 语言 new 和 make 关键字的区别

文章目录 前言new源码使用 make源码使用 总结 前言 本篇文章来介绍一道非常常见的面试题&#xff0c;到底有多常见呢&#xff1f;可能很多面试的开场白就是由此开始的。那就是 new 和 make 这两个内置函数的区别。 在 Go 语言中&#xff0c;有两个比较雷同的内置函数&#xf…

忽略nan值,沿指定轴计算标准(偏)差numpy.nanstd()

【小白从小学Python、C、Java】 【计算机等考500强证书考研】 【Python-数据分析】 沿指定轴方向 计算标准(偏)差 numpy.nanstd() [太阳]选择题 import numpy as np a np.array([[1,2],[np.nan,3]]) print("【显示】a ") print(a) print("【执行】np.std(a)&qu…