价格分类(神经网络)

# 1.导入依赖包
import timeimport torch
import torch.nn as nn
import torch.optim as optimfrom torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_splitimport numpy as np
import pandas as pd
import matplotlib.pyplot as pltfrom torchsummary import summary# 2.构建数据集
def create_dataset():# 2.1 读取数据集data = pd.read_csv('dataset/手机价格预测.csv')# 2.2 获取特征值和目标值,类型转化  特征(Float)  标签(Long)x, y = data.iloc[:, :-1], data.iloc[:, -1]x, y = x.astype(np.float32), y.astype(np.int64)# 2.3 数据集划分x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2,random_state=2)# 2.4 数据转Tensortrain_dataset = TensorDataset(torch.from_numpy(x_train.values), torch.tensor(y_train.values))test_dataset = TensorDataset(torch.from_numpy(x_test.values), torch.tensor(y_test.values))return train_dataset, test_dataset, x_train.shape[1], len(np.unique(y))# 3. 构建模型
class PhonePriceModel(nn.Module):def __init__(self, input_dim, output_dim):super(PhonePriceModel, self).__init__()self.linear1 = nn.Linear(input_dim, 256)self.linear2 = nn.Linear(256, 1024)self.fc = nn.Linear(1024, output_dim)def forward(self, x):x = torch.relu(self.linear1(x))x = torch.relu(self.linear2(x))output = self.fc(x)# output = torch.softmax(self.fc(x), dim=-1)return output# 4.模型训练(225)
def train(model, train_dataset, num_epochs, batch_size):# 2 初始化参数  损失函数  优化器loss1 = nn.CrossEntropyLoss()# optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)optimizer = optim.Adam(model.parameters(), lr=1e-4, betas=(0.99, 0.99))start = time.time()# 2 2个遍历  epoch  dataloaderfor epoch in range(num_epochs):dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)total_num = 0total_loss = 0.0for x, y in dataloader:# 5 前向传播  损失计算 梯度归零  反向传播 参数更新output = model(x)loss = loss1(output, y)optimizer.zero_grad()loss.backward()optimizer.step()total_num += 1  # 批次total_loss += loss.item()epoch += 1print(f'epoch:{epoch + 1:4d},loss:{total_loss / (total_num * epoch):.4f}, time:{time.time() - start:.2f}s')# 模型持久化torch.save(model.state_dict(), 'model/phone2.pth')# 5.模型预测评估
def test(model, test_dataset, input_dim, output_dim):# 3.导入数据dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False)correct = 0# 4.遍历数据for x, y in dataloader:# 4.1 前向传播output = model(x)print(output)# 4.2 获取输出结果(类别)y_pred = torch.argmax(output, dim=1)# print(y_pred)  # 预测错误# 4.3 计算准确率Acccorrect += (y_pred == y).sum()print(correct.item())Acc = correct.item() / len(test_dataset)return Accif __name__ == '__main__':train_dataset, test_dataset, feature_num, label_num = create_dataset()# 1.实例化模型model = PhonePriceModel(feature_num, label_num)# 2.加载模型model.load_state_dict(torch.load('model/phone2.pth'))# 模型训练# train(model, train_dataset, num_epochs=50, batch_size=8)# 模型预测Acc = test(model, test_dataset, feature_num, label_num)print(f'Acc:{Acc:.5f}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/61984.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

「Chromeg谷歌浏览器/Edge浏览器」篡改猴Tempermongkey插件的安装与使用

1. 谷歌浏览器安装及使用流程 1.1 准备篡改猴扩展程序包。 因为谷歌浏览器的扩展商城打不开,所以需要准备一个篡改猴压缩包。 其他浏览器只需打开扩展商城搜索篡改猴即可。 没有压缩包的可以进我主页下载。 也可直接点击下载:Chrome浏览器篡改猴(油猴…

STM32F103C8T6实时时钟RTC

目录 前言 一、RTC基本硬件结构 二、Unix时间戳 2.1 unix时间戳定义 2.2 时间戳与日历日期时间的转换 2.3 指针函数使用注意事项 ​三、RTC和BKP硬件结构 四、驱动代码解析 前言 STM32F103C8T6外部低速时钟LSE(一般为32.768KHz)用的引脚是PC14和PC…

【JavaEE初阶】多线程初阶下部

文章目录 前言一、volatile关键字volatile 能保证内存可见性 二、wait 和 notify2.1 wait()方法2.2 notify()方法2.3 notifyAll()方法2.4 wait 和 sleep 的对比(面试题) 三、多线程案例单例模式 四、总结-保证线程安全的思路五、对比线程和进程总结 前言…

【人工智能】Python在机器学习与人工智能中的应用

Python因其简洁易用、丰富的库支持以及强大的社区,被广泛应用于机器学习与人工智能(AI)领域。本教程通过实用的代码示例和讲解,带你从零开始掌握Python在机器学习与人工智能中的基本用法。 1. 机器学习与AI的Python生态系统 Pyth…

“iOS profile文件与私钥证书文件不匹配”总结打ipa包出现的问题

目录 文件和证书未加载或特殊字符问题 证书过期或Profile文件错误 确认开发者证书和私钥是否匹配 创建证书选择错误问题 申请苹果 AppId时勾选服务不全问题 ​总结 在上线ios平台的时候,在Hbuilder中打包遇见了问题,生成ipa文件时候,一…

element-ui 中el-calendar 日历插件获取显示的第一天和最后一天【原创】

需要获取el-calendar 日历组件上的第1天和最后一天。可以通过document.querySelector()方法进行获取dom元素中的值,这样避免计算问题。 获取的过程中主要有两个难点,第1个是处理上1月和下1月的数据,第2个是跨年的数据。 直接贴代码&#xff…

JavaScript的基础数据类型

一、JavaScript中的数组 定义 数组是一种特殊的对象,用于存储多个值。在JavaScript中,数组可以包含不同的数据类型,如数字、字符串、对象、甚至其他数组。数组的创建有两种常见方式: 字面量表示法:let fruits [apple…

5.5 W5500 TCP服务端与客户端

文章目录 1、TCP介绍2、W5500简介2.1 关键函数socketlistensendgetSn_RX_RSRrecv自动心跳包检测getSn_SR 1、TCP介绍 TCP 服务端: 创建套接字[socket]:服务器首先创建一个套接字,这是网络通信的端点。绑定套接字[bind]:服务器将…

Android 15 版本更新及功能介绍

Android 15版本时间戳 Android 15,代号Vanilla Ice Cream(香草冰淇淋),是当下 Android 移动操作系统的最新主要版本。 开发者预览阶段:2024年2月,谷歌发布了Android 15的第一个开发者预览版本(DP1),这标志着新系统开发的正式启动。随后,在3月和4月,谷歌又相继推出了D…

第02章_MySQL环境搭建(基础)

1. MySQL 的卸载 1.1 步骤1:停止 MySQL 服务 在卸载之前,先停止 MySQL8.0 的服务。按键盘上的 “Ctrl Alt Delete” 组合键,打开“任务管理器”对话 框,可以在“服务”列表找到“MySQL8.0” 的服务,如果现在“正在…

红队笔记--W1R3S、JARBAS、SickOS、Prime打靶练习记录

W1R3S(思路为主) 信息收集 首先使用nmap探测主机,得到192.168.190.147 接下来扫描端口,可以看到ports文件保存了三种格式 其中.nmap和屏幕输出的一样;xml这种的适合机器 nmap -sT --min-rate 10000 -p- 192.168.190.147 -oA nmapscan/ports…

学习笔记|MaxKB对接本地大模型时,选择Ollma还是vLLM?

在使用MaxKB开源知识库问答系统的过程中,除了对接在线大模型,一些用户出于资源配置、长期使用成本、安全性等多方面考虑,还在积极尝试通过Ollama、vLLM等模型推理框架对接本地离线大模型。而在用户实践的过程中,经常会对候选的模型…

计算机网络八股整理(一)

计算机网络八股文整理 一:网络模型 1:网络osi模型和tcp/ip模型分别介绍一下 osi模型是国际标准的网络模型,它由七层组成,从上到下分别是:应用层,表示层,会话层,传输层,…

Spring Boot教程之五:在 IntelliJ IDEA 中运行第一个 Spring Boot 应用程序

在 IntelliJ IDEA 中运行第一个 Spring Boot 应用程序 IntelliJ IDEA 是一个用 Java 编写的集成开发环境 (IDE)。它用于开发计算机软件。此 IDE 由 Jetbrains 开发,提供 Apache 2 许可社区版和商业版。它是一种智能的上下文感知 IDE,可用于在各种应用程序…

单片机学习笔记 9. 8×8LED点阵屏

更多单片机学习笔记:单片机学习笔记 1. 点亮一个LED灯单片机学习笔记 2. LED灯闪烁单片机学习笔记 3. LED灯流水灯单片机学习笔记 4. 蜂鸣器滴~滴~滴~单片机学习笔记 5. 数码管静态显示单片机学习笔记 6. 数码管动态显示单片机学习笔记 7. 独立键盘单片机学习笔记 8…

vue 预览pdf 【@sunsetglow/vue-pdf-viewer】开箱即用,无需开发

sunsetglow/vue-pdf-viewer 开箱即用的pdf插件sunsetglow/vue-pdf-viewer, vue3 版本 无需多余开发,操作简单,支持大文件 pdf 滚动加载,缩放,左侧导航,下载,页码,打印,文本复制&…

Css—实现3D导航栏

一、背景 最近在其他的网页中看到了一个很有趣的3d效果,这个效果就是使用css3中的3D转换实现的,所以今天的内容就是3D的导航栏效果。那么话不多说,直接开始主要内容的讲解。 二、效果展示 三、思路解析 1、首先我们需要将这个导航使用一个大…

重新定义社媒引流:AI社媒引流王如何为品牌赋能?

在社交媒体高度竞争的时代,引流已经不再是单纯追求流量的数字游戏,而是要找到“对的用户”,并与他们建立真实的连接。AI社媒引流王通过技术创新和智能策略,重新定义了社媒引流的方式,帮助品牌在精准触达和高效互动中脱…

Docker1:认识docker、在Linux中安装docker

欢迎来到“雪碧聊技术”CSDN博客! 在这里,您将踏入一个专注于Java开发技术的知识殿堂。无论您是Java编程的初学者,还是具有一定经验的开发者,相信我的博客都能为您提供宝贵的学习资源和实用技巧。作为您的技术向导,我将…

Centos 8, add repo

Centos repo前言 Centos 8更换在线阿里云创建一键更换repo 自动化脚本 华为Centos 源 , 阿里云Centos 源 华为epel 源 , 阿里云epel 源vim /centos8_repo.sh #!/bin/bash # -*- coding: utf-8 -*- # Author: make.han