鸢尾花分类-pytorch实现

前言

本文用pytorch实现了鸢尾花分类,数据不多,只做代码展示用,后续有升级版本。

代码

'''
-*- coding: utf-8 -*-
@File  : main.py
@Author: Shanmh
@Time  : 2024/05/06 上午9:37
@Function:
'''
import torch
from sklearn import datasets
import torch.nn as nn#1.数据准备
dataset=datasets.load_iris()
print(dataset["data"][:10])
print(dataset["target"][:10])
i_data=torch.FloatTensor(dataset["data"])
i_target=torch.LongTensor(dataset["target"])#2.模型构建
class IrisModel(nn.Module):def __init__(self,input_n=4,hidden_n=20,output_n=3):super().__init__()self.line1=nn.Linear(input_n,hidden_n)self.line2=nn.Linear(hidden_n,output_n)self.relu=nn.ReLU()def forward(self,x):x=self.line1(x)x=self.relu(x)x=self.line2(x)return x#3.参数定义
epoch=500
lr=0.01model=IrisModel()
optimizer=torch.optim.SGD(model.parameters(),lr=lr) #定义优化器
loss_fun=torch.nn.CrossEntropyLoss() #多分类采用交叉熵损失函数for e in range(epoch):out=model(i_data)loss=loss_fun(out,i_target)optimizer.zero_grad()  # 梯度清零loss.backward()  # 前馈操作optimizer.step()# 5. 得出结果
out = model(i_data)
prediction = torch.max(out, 1)[1]
pred_y = prediction.data.numpy()
target_y = i_target.data.numpy()
result=pred_y==target_y
print(f"模型预测准确度,acc:{'{:.2f}'.format(len(result[result==True])/len(result))}%")

展望

1.还在考虑中怎么进行建模,建一个4维空间用来直接看出输入与输出的关系

2.有尝试过标签平滑,从结果上看不出什么区别,再想怎么可视化出来

3.怎么从结果倒推出可用的输入数据

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/9499.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Spring】Spring 整合 Junit、MyBatis

一、 Spring 整合 Junit <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation"http://maven.apache…

Sql Server 2016数据库定时备份

一、 配置备份计划任务 选中“维护计划“--右键--“维护计划向导” 完成

遥控挖掘机之ESP8266调试心得(1)

ESP8266调试心得 1. 前言2.遇到的问题2.1 ESP8266模块建立TCP连接时候报错2.2 指令异常问题 3. 更新ESP8266固件3. ESP8266的部分AT指令3. 连接步骤3.1 模块与电脑连接3.2.1 电脑上的设置3.2.2 ESP8266模块作为客户机&#xff08;TCP Cilent&#xff09;的设置步骤 3.2 模块与模…

从开发角度理解漏洞成因(02)

文章目录 文件上传类需求文件上传漏洞 文件下载类需求文件下载漏洞 扩展 留言板类&#xff08;XSS漏洞&#xff09;需求XSS漏洞 登录类需求cookie伪造漏洞万能密码登录 持续更新中… 文章中代码资源已上传资源&#xff0c;如需要打包好的请点击PHP开发漏洞环境&#xff08;SQL注…

贝塞尔曲线 java

参考文章&#xff1a; 理解贝塞尔曲线https://blog.csdn.net/weixin_42301220/article/details/125167672 代码实现参考 https://blog.csdn.net/yinhun2012/article/details/118653732 贝塞尔 一二三阶java代码实现,N阶段可以通过降阶递归实现 public class Test extends JPan…

java选择结构语句

文章目录 Java选择结构语句的几种形式1. **if 单选择结构**:2. **if-else 双选择结构**:3. **if-else if 多选择结构**:4. **switch 选择结构**: Java 12及更高版本的Switch Expressions返回值的Switch表达式yield关键字使用Switch作为语句或表达式 Pattern Matching for insta…

Final Draft 12 for Mac:高效专业剧本创作软件

对于剧本创作者来说&#xff0c;一款高效、专业的写作工具是不可或缺的。Final Draft 12 for Mac就是这样一款完美的选择。这款专为Mac用户设计的剧本创作软件&#xff0c;凭借其卓越的性能和丰富的功能&#xff0c;让您的剧本创作更加得心应手。 Final Draft 12支持多种剧本格…

【C++】CentOS环境搭建-编译安装Boost库(附CMAKE编译文件)

【C】环境搭建-编译安装Boost库 Boost库简介Boost库安装通过YUM安装&#xff08;版本较低 V1.53.0&#xff09;通过编译安装&#xff08;官网最新版本1.85.0&#xff09;1.安装相关依赖2.查询官网下载最新安装包并解压3.编译Boost4.安装Boost库到系统路径 Boost库验证 Boost库简…

(22.12.20)matlab2022+yalmip+cplex安装教程,win11 x64

前言 Hi,你好&#xff01;最近刚刚更换新的电脑设备&#xff0c;安装软件时尽量选择最新版本&#xff0c;但也遇到了大大小小的安装问题&#xff0c;这里把踩到的坑一并总结出来&#xff0c;给出一份还算合理的MATLAByalmipCPLEX安装教程&#xff08;win11&#xff09;。 MAT…

从零入门激光SLAM(十三)——LeGo-LOAM源码超详细解析4

大家好呀&#xff0c;我是一个SLAM方向的在读博士&#xff0c;深知SLAM学习过程一路走来的坎坷&#xff0c;也十分感谢各位大佬的优质文章和源码。随着知识的越来越多&#xff0c;越来越细&#xff0c;我准备整理一个自己的激光SLAM学习笔记专栏&#xff0c;从0带大家快速上手激…

OBS插件--视频回放

视频回放 视频回放是一款源插件&#xff0c;它可以将指定源的视频缓存一段时间&#xff08;时间可以设定&#xff09;&#xff0c;将缓存中的视频添加到当前场景中后&#xff0c;可以快速或慢速不限次数的回放。这个功能在类似体育比赛的直播中非常有用&#xff0c;可以捕获指…

【快讯】山东省第四批软件产业高质量发展重点项目开始申报

为加快落实《山东省高端软件“铸魂”工程实施方案&#xff08;2023-2025&#xff09;》&#xff0c;提高软件产业规模能级&#xff0c;提升关键软件技术创新和供给能力&#xff0c;塑强数字经济发展核心竞争力&#xff0c;确定开展第四批软件产业高质量发展重点项目申报工作&am…

CTF-Web Exploitation(持续更新)

CTF-Web Exploitation 1. GET aHEAD Find the flag being held on this server to get ahead of the competition Hints Check out tools like Burpsuite to modify your requests and look at the responses 根据提示使用不同的请求方式得到response可能会得到结果 使用…

如何通过汽车制造供应商协同平台,提高供应链的效率与稳定性?

汽车制造供应商协同是指在汽车制造过程中&#xff0c;整车制造商与其零部件供应商之间建立的一种紧密合作的关系。这种协同关系旨在优化整个供应链的效率&#xff0c;降低成本&#xff0c;提高产品质量&#xff0c;加快创新速度&#xff0c;并最终提升整个汽车产业的竞争力。以…

面试笔记——JVM组成

基本介绍 JVM: Java Virtual Machine Java程序的运行环境&#xff08;java二进制字节码的运行环境&#xff09; 使用JVM的好处&#xff1a; 一次编写&#xff0c;到处运行自动内存管理&#xff0c;垃圾回收机制 JVM的组成及运行流程&#xff1a; 程序计数器 程序计数器&a…

Zabbix5.0——安装与部署

目录 一、zabbix-server(192.168.206.134) 监控方 1. 环境准备 2.安装zabbix 2.1 准备zabbix-repo 2.2清理缓存 2.3安装zabbix主包&#xff08;服务器和代理&#xff09; 2.4安装zabbix前端包 3. 数据库安装 3.1 授权zabbix账号 3.2导入数据库&#xff08;初始化zabbix&#x…

人工智能驱动的设计工具的兴起:彻底改变创意产业

人工智能驱动的设计工具的兴起&#xff1a;彻底改变创意产业 概述 人工智能 (AI) 正在改变创意产业&#xff0c;设计也不例外。人工智能驱动的设计工具正在彻底改变设计师的工作方式&#xff0c;提供无与伦比的效率、创造力和创新水平。从生成图像和设计到自动化日常任务&…

基于Opencv的车牌识别系统(毕业设计可用)

系统架构 图像采集&#xff1a;首先&#xff0c;通过摄像头等设备捕捉车辆图像。图像质量直接影响后续处理的准确性&#xff0c;因此高质量的图像采集是基础。 预处理&#xff1a;对获取的原始图像进行预处理&#xff0c;包括灰度化、降噪、对比度增强和边缘检测等。这些操作旨…

RS3236-3.3YF5 封装SOT-23-5 线性稳压器 带过温保护

RS3236-3.3YF5 是一款由Runic&#xff08;润石&#xff09;公司生产的线性稳压器&#xff08;LDO&#xff09;&#xff0c;以下是该器件的一些功能和参数介绍&#xff1a; 品牌: Runic 产品类型: 线性稳压器 (LDO) 输入电压范围: 最大 7.5V 输出电压: 固定 3.3V 输出电流: 最大…

基于FPGA的去雾算法

去雾算法的原理是基于图像去模糊的原理&#xff0c;通过对图像中的散射光进行估计和去除来消除图像中的雾霾效果。 去雾算法通常分为以下几个步骤&#xff1a; 1. 导引滤波&#xff1a;首先使用导引滤波器对图像进行滤波&#xff0c;目的是估计图像中散射光的强度。导引滤波器…