价格分类(神经网络)

# 1.导入依赖包
import timeimport torch
import torch.nn as nn
import torch.optim as optimfrom torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_splitimport numpy as np
import pandas as pd
import matplotlib.pyplot as pltfrom torchsummary import summary# 2.构建数据集
def create_dataset():# 2.1 读取数据集data = pd.read_csv('dataset/手机价格预测.csv')# 2.2 获取特征值和目标值,类型转化  特征(Float)  标签(Long)x, y = data.iloc[:, :-1], data.iloc[:, -1]x, y = x.astype(np.float32), y.astype(np.int64)# 2.3 数据集划分x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2,random_state=2)# 2.4 数据转Tensortrain_dataset = TensorDataset(torch.from_numpy(x_train.values), torch.tensor(y_train.values))test_dataset = TensorDataset(torch.from_numpy(x_test.values), torch.tensor(y_test.values))return train_dataset, test_dataset, x_train.shape[1], len(np.unique(y))# 3. 构建模型
class PhonePriceModel(nn.Module):def __init__(self, input_dim, output_dim):super(PhonePriceModel, self).__init__()self.linear1 = nn.Linear(input_dim, 256)self.linear2 = nn.Linear(256, 1024)self.fc = nn.Linear(1024, output_dim)def forward(self, x):x = torch.relu(self.linear1(x))x = torch.relu(self.linear2(x))output = self.fc(x)# output = torch.softmax(self.fc(x), dim=-1)return output# 4.模型训练(225)
def train(model, train_dataset, num_epochs, batch_size):# 2 初始化参数  损失函数  优化器loss1 = nn.CrossEntropyLoss()# optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)optimizer = optim.Adam(model.parameters(), lr=1e-4, betas=(0.99, 0.99))start = time.time()# 2 2个遍历  epoch  dataloaderfor epoch in range(num_epochs):dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)total_num = 0total_loss = 0.0for x, y in dataloader:# 5 前向传播  损失计算 梯度归零  反向传播 参数更新output = model(x)loss = loss1(output, y)optimizer.zero_grad()loss.backward()optimizer.step()total_num += 1  # 批次total_loss += loss.item()epoch += 1print(f'epoch:{epoch + 1:4d},loss:{total_loss / (total_num * epoch):.4f}, time:{time.time() - start:.2f}s')# 模型持久化torch.save(model.state_dict(), 'model/phone2.pth')# 5.模型预测评估
def test(model, test_dataset, input_dim, output_dim):# 3.导入数据dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False)correct = 0# 4.遍历数据for x, y in dataloader:# 4.1 前向传播output = model(x)print(output)# 4.2 获取输出结果(类别)y_pred = torch.argmax(output, dim=1)# print(y_pred)  # 预测错误# 4.3 计算准确率Acccorrect += (y_pred == y).sum()print(correct.item())Acc = correct.item() / len(test_dataset)return Accif __name__ == '__main__':train_dataset, test_dataset, feature_num, label_num = create_dataset()# 1.实例化模型model = PhonePriceModel(feature_num, label_num)# 2.加载模型model.load_state_dict(torch.load('model/phone2.pth'))# 模型训练# train(model, train_dataset, num_epochs=50, batch_size=8)# 模型预测Acc = test(model, test_dataset, feature_num, label_num)print(f'Acc:{Acc:.5f}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/61984.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

视频 | Navicat 17 功能亮点

探索 Navicat 17 的数据库管理与开发功能 本视频,我们将主要从结构设计、数据操作、SQL 开发、数据建模和 BI 这五个方面,介绍 Navicat Premium 17 是如何助力数据库开发和管理人员简单高效地完成数据库工作的。 此外,Navicat 系列还涵盖了广…

「Chromeg谷歌浏览器/Edge浏览器」篡改猴Tempermongkey插件的安装与使用

1. 谷歌浏览器安装及使用流程 1.1 准备篡改猴扩展程序包。 因为谷歌浏览器的扩展商城打不开,所以需要准备一个篡改猴压缩包。 其他浏览器只需打开扩展商城搜索篡改猴即可。 没有压缩包的可以进我主页下载。 也可直接点击下载:Chrome浏览器篡改猴(油猴…

git base 下载$ git clone 失败解决方法

$ git clone https://github.com/hjsdjko/hangkongdingpiao.git Cloning into hangkongdingpiao... fatal: unable to access https://github.com/hjsdjko/hangkongdingpiao.git/: SSL certificate problem: unable to get local issuer certificate 使用git config --global …

STM32F103C8T6实时时钟RTC

目录 前言 一、RTC基本硬件结构 二、Unix时间戳 2.1 unix时间戳定义 2.2 时间戳与日历日期时间的转换 2.3 指针函数使用注意事项 ​三、RTC和BKP硬件结构 四、驱动代码解析 前言 STM32F103C8T6外部低速时钟LSE(一般为32.768KHz)用的引脚是PC14和PC…

【JavaEE初阶】多线程初阶下部

文章目录 前言一、volatile关键字volatile 能保证内存可见性 二、wait 和 notify2.1 wait()方法2.2 notify()方法2.3 notifyAll()方法2.4 wait 和 sleep 的对比(面试题) 三、多线程案例单例模式 四、总结-保证线程安全的思路五、对比线程和进程总结 前言…

【人工智能】Python在机器学习与人工智能中的应用

Python因其简洁易用、丰富的库支持以及强大的社区,被广泛应用于机器学习与人工智能(AI)领域。本教程通过实用的代码示例和讲解,带你从零开始掌握Python在机器学习与人工智能中的基本用法。 1. 机器学习与AI的Python生态系统 Pyth…

“iOS profile文件与私钥证书文件不匹配”总结打ipa包出现的问题

目录 文件和证书未加载或特殊字符问题 证书过期或Profile文件错误 确认开发者证书和私钥是否匹配 创建证书选择错误问题 申请苹果 AppId时勾选服务不全问题 ​总结 在上线ios平台的时候,在Hbuilder中打包遇见了问题,生成ipa文件时候,一…

element-ui 中el-calendar 日历插件获取显示的第一天和最后一天【原创】

需要获取el-calendar 日历组件上的第1天和最后一天。可以通过document.querySelector()方法进行获取dom元素中的值,这样避免计算问题。 获取的过程中主要有两个难点,第1个是处理上1月和下1月的数据,第2个是跨年的数据。 直接贴代码&#xff…

抓住鸿蒙生态崛起的机遇,拥抱未来开发挑战

随着华为鸿蒙(HarmonyOS)的持续发展,鸿蒙生态正在迅速崛起,逐步在智能手机、智能穿戴、车载、家居等领域形成完整闭环。它不仅为开发者带来了新的机遇,还带来了技术上的挑战。如何抓住这些机遇并应对挑战,是…

高标准农田智慧农业系统建设方案

1 项目概述 1.1 建设背景 我国是农业大国,近30年来农田高产量主要依靠农药化肥的大量投入,大部分化肥和水资源没有被有效利用而随地弃置,导致大量养分损失并造成环境污染。我国农业生产仍然以传统生产模式为主,传统耕种只能凭经验施肥灌溉,不仅浪费大量的人力物力,也对环…

基于Angular+BootStrap+SpringBoot简单的购物网站

目录 一、项目结构图 二、目录结构解析 后端 (Spring Boot) 前端 (Angular) 三、技术栈 四、具体功能实现 五、数据库设计 六、后端实现 1. 设置Spring Boot项目 2. 数据库实体类 3. 创建Repository 4. 创建Service层 5. 创建Controller层 七、前端实现&#xff0…

JavaScript的基础数据类型

一、JavaScript中的数组 定义 数组是一种特殊的对象,用于存储多个值。在JavaScript中,数组可以包含不同的数据类型,如数字、字符串、对象、甚至其他数组。数组的创建有两种常见方式: 字面量表示法:let fruits [apple…

5.5 W5500 TCP服务端与客户端

文章目录 1、TCP介绍2、W5500简介2.1 关键函数socketlistensendgetSn_RX_RSRrecv自动心跳包检测getSn_SR 1、TCP介绍 TCP 服务端: 创建套接字[socket]:服务器首先创建一个套接字,这是网络通信的端点。绑定套接字[bind]:服务器将…

PostGres命令【常用维护,增删改查】

文章目录 连接数据库列出数据库列出表增删改查操作基本的维护命令其他常用命令 PostgreSQL 中常用的 psql 命令,包括连接数据库、列出数据库、列出表、增删改查操作以及一些基本的维护命令。 连接数据库 启动 psql 客户端: psql -U your_username -d yo…

Android 15 版本更新及功能介绍

Android 15版本时间戳 Android 15,代号Vanilla Ice Cream(香草冰淇淋),是当下 Android 移动操作系统的最新主要版本。 开发者预览阶段:2024年2月,谷歌发布了Android 15的第一个开发者预览版本(DP1),这标志着新系统开发的正式启动。随后,在3月和4月,谷歌又相继推出了D…

第02章_MySQL环境搭建(基础)

1. MySQL 的卸载 1.1 步骤1:停止 MySQL 服务 在卸载之前,先停止 MySQL8.0 的服务。按键盘上的 “Ctrl Alt Delete” 组合键,打开“任务管理器”对话 框,可以在“服务”列表找到“MySQL8.0” 的服务,如果现在“正在…

Vue开发05:Vue中Ant-design主要控件用法demo(js为主)

Ant-design主要控件事件总结 在线测试网站:在线运行Vue组件 (rscl.cc) 以下demo全部基于ant-design-vue组件(版本1.7.8) 一、下拉框 1.选项直接赋值($event) 用下面这个技巧,可以不写methods&#xff0…

红队笔记--W1R3S、JARBAS、SickOS、Prime打靶练习记录

W1R3S(思路为主) 信息收集 首先使用nmap探测主机,得到192.168.190.147 接下来扫描端口,可以看到ports文件保存了三种格式 其中.nmap和屏幕输出的一样;xml这种的适合机器 nmap -sT --min-rate 10000 -p- 192.168.190.147 -oA nmapscan/ports…

深入理解 MyBatis 的缓存机制:一级缓存与二级缓存

MyBatis 是目前 Java 开发中常用的一种 ORM(对象关系映射)框架,它不仅简化了 SQL 语句的编写和管理,还提供了强大的缓存机制,用以提高数据库访问的性能。MyBatis 的缓存分为一级缓存和二级缓存,分别应用于不…

使用nvm下载多个版本node后提示vue不是内部或外部命令,执行vue create报.vuerc错误

一、使用nvm后执行含vue的相关命令提示vue不是内部或外部命令 前言:之前有项目需要切换node版本,我把node卸载了然后使用nvm下载多个版本的node。现在想通过vue create搭建vue2的项目时提示vue不是内部或外部命令,执行npm i vue/cli后仍然无…