神经网络基础-价格分类案例

文章目录

    • 1. 需求分析
    • 2. 导入所需工具包
    • 3. 构建数据集
    • 4. 构建分类网络模型
    • 5. 训练模型
    • 6. 模型训练
    • 7. 评估模型
    • 8. 模型优化

学习目标:

  1. 掌握构建分类模型流程
  2. 动手实践整个过程

1. 需求分析

小明创办了一家手机公司,他不知道如何估算手机产品的价格。为了解决这个问题,他收集了多家公司的手机销售数据。该数据为二手手机的各个性能的数据,最后根据这些性能得到4个价格区间,作为这些二手手机售出的价格区间。主要包括:

battery_power电池一次可存储的电量,单位:毫安/时
blue是否有蓝牙
clock_speed微处理器执行指令的速度
dual_sim是否支持双卡
fc前置摄像头百万像素
four_g是否有4G
int_memory内存(GB)
m_dep移动深度(cm)
mobile_wt手机重量
n_cores处理器内核数
pc主摄像头百万像素
px_height像素分辨率高度
px_width像素分辨率宽度
ram随机存储器(兆字节)
sc_h手机屏幕高度(cm)
sc_w手机屏幕宽度(cm)
talk_time一次充电持续时长
three_g是否有3G
touch_screen是否有触屏控制
wifi是否能连wifi
price_range价格区间(0,1,2,3)

我们需要帮助小明找出手机的功能(例如:RAM等)与其售价之间的某种关系。我们可以使用机器学习的方法来解决这个问题,也可以构建一个全连接的网络。

需要注意的是: 在这个问题中,我们不需要预测实际价格,而是一个价格范围,它的范围使用 0、1、2、3 来表示,所以该问题也是一个分类问题。接下来我们还是按照四个步骤来完成这个任务:

  • 准备训练集数据

  • 构建要使用的模型

  • 模型训练

  • 模型预测评估

2. 导入所需工具包

# 导入相关模块
import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import time

3. 构建数据集

数据共有 2000 条, 其中 1600 条数据作为训练集, 400 条数据用作测试集。 我们使用 sklearn 的数据集划分工作来完成。并使用 PyTorch 的 TensorDataset 来将数据集构建为 Dataset 对象,方便构造数据集加载对象。

#1. 导入相关模块
import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import time# 构建数据集
def load_dataset():# 使用pandas 读取数据data = pd.read_csv('data/手机价格预测.csv')# 特征值和目标值x,y = data.iloc[:,:-1],data.iloc[:,-1]# 类型转换:特征值,目标值x = x.astype(np.float32)y = y.astype(np.int64)# 划分训练集和测试集x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=88)# 构建数据集,转换为pytorch格式train_dataset = TensorDataset(torch.from_numpy(x_train.values), torch.from_numpy(y_train.values))test_dataset = TensorDataset(torch.from_numpy(x_test.values), torch.from_numpy(y_test.values))#返回结果return train_dataset, test_dataset,x_train.shape[1],len(np.unique(y))if __name__ == '__main__':train_dataset, test_dataset,input_dim,class_num = load_dataset()print("输入特征数:",input_dim)print("分类个数:",class_num)

输出结果为:

输入特征数: 20
分类个数: 4

4. 构建分类网络模型

构建全连接神经网络来进行手机价格分类,该网络主要由三个线性层来构建,使用relu激活函数。

网络共有 3 个全连接层, 具体信息如下:

  1. 第一层: 输入为维度为 20, 输出维度为: 128
  2. 第二层: 输入为维度为 128, 输出维度为: 256
  3. 第三层: 输入为维度为 256, 输出维度为: 4
# 构建网络模型
class PhonePriceModel(nn.Module):def __init__(self,input_dim,output_dim):super(PhonePriceModel, self).__init__()# 1. 第一层: 输入为维度为 20, 输出维度为: 128self.linear1 = nn.Linear(input_dim, 128)# 2. 第二层: 输入为维度为 128, 输出维度为: 256self.linear2 = nn.Linear(128, 256)# 3. 第三层: 输入为维度为 256, 输出维度为: 4self.linear3 = nn.Linear(256, output_dim)def forward(self, x):# 前向传播过程x = torch.relu(self.linear1(x))x = torch.relu(self.linear2(x))output = self.linear3(x)# 获取数据结果return outputif __name__ == '__main__':train_dataset, test_dataset,input_dim,class_num = load_dataset()print("输入特征数:",input_dim)print("分类个数:",class_num)# 模型实例化model = PhonePriceModel(input_dim,class_num)

5. 训练模型

网络编写完成之后,我们需要编写训练函数。所谓的训练函数,指的是输入数据读取、送入网络、计算损失、更新参数的流程,该流程较为固定。我们使用的是多分类交叉生损失函数、使用 SGD 优化方法。最终,将训练好的模型持久化到磁盘中。

# 模型训练过程
def train(train_dataset,input_dim,class_num):# 固定随机数种子torch.manual_seed(0)# 初始化模型model = PhonePriceModel(input_dim,class_num)# 损失函数criterion = nn.CrossEntropyLoss()# 优化方法optimizer = optim.SGD(model.parameters(), lr=1e-3)# 训练轮数num_epochs = 50# 遍历轮数for epoch_idx in range(num_epochs):# 初始化数据加载器dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)# 训练时间start = time.time()# 计算损失total_loss = 0.0total_num = 1# 遍历每个batch数据进行处理for x,y in dataloader:# 将数据送入网络中进行预测output = model(x)# 计算损失loss = criterion(output, y)#梯度清零optimizer.zero_grad()# 方向传播loss.backward()# 参数更新optimizer.step()# 损失计算total_num += 1total_loss += loss.item()# 打印损失变换结果print('epoch: %4s loss: %.2f, time: %.2fs' % (epoch_idx + 1, total_loss / total_num, time.time() - start))# 保存模型torch.save(model.state_dict(), 'model/phone.ptn')

6. 模型训练

if __name__ == '__main__':train_dataset, test_dataset,input_dim,class_num = load_dataset()print("输入特征数:",input_dim)print("分类个数:",class_num)# 模型训练过程train(train_dataset,input_dim,class_num)

输出结果:

epoch:    1 loss: 13.31, time: 0.25s
epoch:    2 loss: 0.96, time: 0.24s
epoch:    3 loss: 0.90, time: 0.24s
epoch:    4 loss: 0.89, time: 0.25s
epoch:    5 loss: 0.86, time: 0.26s
...
epoch:   46 loss: 0.68, time: 0.25s
epoch:   47 loss: 0.69, time: 0.26s
epoch:   48 loss: 0.68, time: 0.28s
epoch:   49 loss: 0.69, time: 0.24s
epoch:   50 loss: 0.69, time: 0.24s

7. 评估模型

使用训练好的模型,对未知的样本的进行预测的过程。我们这里使用前面单独划分出来的验证集来进行评估。

# 4 评估模型
def test(test_dataset,input_dim,class_num):# 加载模型和训练好的网络参数model = PhonePriceModel(input_dim,class_num)model.load_state_dict(torch.load('model/phone.ptn',weights_only=False))# 构建加载器dataloader = DataLoader(test_dataset, batch_size=8, shuffle=True)# 评估测试集correct = 0# 遍历测试集中的数据for x,y in dataloader:# 将其送入网络中output = model(x)# 获取类别结果y_pred = torch.argmax(output, dim=1)# 获取预测正确的个数correct += (y_pred == y).sum().sum()# 求预测精度print('Acc: %.5f' % (correct.item() / len(test_dataset)))
if __name__ == '__main__':train_dataset, test_dataset,input_dim,class_num = load_dataset()print("输入特征数:",input_dim)print("分类个数:",class_num)# 评估模型test(test_dataset,input_dim,class_num)

输出结果:

Acc: 0.62500

8. 模型优化

我们前面的网络模型在测试集的准确率为: 0.54750, 我们可以通过以下方面进行调优:

  1. 优化方法由 SGD 调整为 Adam
  2. 学习率由 1e-3 调整为 1e-4
  3. 对数据数据进行标准化
  4. Dropout 正则化
  5. 调整训练轮次
# 使用Adam方法优化网络
#1. 导入相关模块
import torch
from tensorboard import summary
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import time#2. 构建数据集
def load_dataset():# 使用pandas 读取数据data = pd.read_csv('data/手机价格预测.csv')# 特征值和目标值x,y = data.iloc[:,:-1],data.iloc[:,-1]# 类型转换:特征值,目标值x = x.astype(np.float32)y = y.astype(np.int64)# 划分训练集和测试集x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=88)# 数据标准化scaler = StandardScaler()x_train = scaler.fit_transform(x_train)x_test = scaler.fit_transform(x_test)x_train = torch.tensor(x_train,dtype=torch.float32)x_test = torch.tensor(x_test,dtype=torch.float32)# 构建数据集,转换为pytorch格式train_dataset = TensorDataset(x_train, torch.from_numpy(y_train.values))test_dataset = TensorDataset(x_test, torch.from_numpy(y_test.values))#返回结果return train_dataset, test_dataset,x_train.shape[1],len(np.unique(y))
#2. 构建网络模型
class PhonePriceModel(nn.Module):def __init__(self,input_dim,output_dim,p_dropout=0.4):super(PhonePriceModel, self).__init__()# 第一层: 输入为维度为 20, 输出维度为: 128self.linear1 = nn.Linear(input_dim, 128)# Dropout优化self.dropout = nn.Dropout(p_dropout)# 第二层: 输入为维度为 128, 输出维度为: 256self.linear2 = nn.Linear(128, 256)# Dropout优化self.dropout = nn.Dropout(p_dropout)# 第三层: 输入为维度为 256, 输出维度为: 4self.linear3 = nn.Linear(256, output_dim)def forward(self, x):# 前向传播过程x = torch.relu(self.linear1(x))x = torch.relu(self.linear2(x))output = self.linear3(x)# 获取数据结果return output# 3. 模型训练过程
def train(train_dataset,input_dim,class_num):# 固定随机数种子torch.manual_seed(0)# 初始化模型model = PhonePriceModel(input_dim,class_num)# 损失函数criterion = nn.CrossEntropyLoss()# 优化方法# optimizer = optim.SGD(model.parameters(), lr=1e-3)# Adam优化方法 调整学习率为 lr=1e-4optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.99))# 训练轮数 100 - 0.9075 50 - 0.9125num_epochs = 50# 遍历轮数for epoch_idx in range(num_epochs):# 初始化数据加载器dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)# 训练时间start = time.time()# 计算损失total_loss = 0.0total_num = 1# 遍历每个batch数据进行处理for x,y in dataloader:# 将数据送入网络中进行预测output = model(x)# 计算损失loss = criterion(output, y)#梯度清零optimizer.zero_grad()# 方向传播loss.backward()# 参数更新optimizer.step()# 损失计算total_num += 1total_loss += loss.item()# 打印损失变换结果print('epoch: %4s loss: %.2f, time: %.2fs' % (epoch_idx + 1, total_loss / total_num, time.time() - start))# 保存模型torch.save(model.state_dict(), 'model/phone2.ptn')
# 4 评估模型
def test(test_dataset,input_dim,class_num):# 加载模型和训练好的网络参数model = PhonePriceModel(input_dim,class_num)model.load_state_dict(torch.load('model/phone2.ptn',weights_only=False))# 构建加载器dataloader = DataLoader(test_dataset, batch_size=8, shuffle=True)# 评估测试集correct = 0# 遍历测试集中的数据for x,y in dataloader:# 将其送入网络中output = model(x)# 获取类别结果y_pred = torch.argmax(output, dim=1)# 获取预测正确的个数correct += (y_pred == y).sum().sum()# 求预测精度print('Acc: %.5f' % (correct.item() / len(test_dataset)))
if __name__ == '__main__':train_dataset, test_dataset,input_dim,class_num = load_dataset()print("输入特征数:",input_dim)print("分类个数:",class_num)# 模型训练train(train_dataset,input_dim,class_num)test(test_dataset,input_dim,class_num)

这里我们调整 Adam方法优化梯度下降,学习率调整为1e-4,样本数据采用标准化处理。采用Dropout正则化。最后输出结果:

Acc: 0.91250

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/66224.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SAP 固定资产常用的数据表有哪些,他们是怎么记录数据的?

在SAP系统中,固定资产管理(FI-AA)涉及多个核心数据表,用于记录资产主数据、折旧、交易等。以下是常用的数据表及其记录数据的逻辑: 1. ANKT - 资产主数据表 功能:存储资产主数据的文本描述。 字段&#x…

光伏储能电解水制氢仿真模型Matlab/Simulink

今天更新的内容为光伏储能制氢技术,这个方向我之前在21年就系统研究并发表过相关文章,经过这几年的发展,绿色制氢技术也受到更多高校的注意,本篇博客也是在原先文章的基础上进行更新。 首先让大家熟悉一下绿氢制取技术这个概念&a…

Redis 3.2.1在Win10系统上的安装教程

诸神缄默不语-个人CSDN博文目录 这个文件可以跟我要,也可以从官网下载:https://github.com/MicrosoftArchive/redis/releases 这个是微软以前维护的Windows版Redis安装包,如果想要比较新的版本可以从别人维护的项目里下(https://…

基于springboot+vue.js+uniapp技术开发的一套大型企业MES生产管理系统源码,支持多端管理

企业级智能制造MES系统源码,技术架构:springboot vue-element-plus-admin 企业级云MES全套源码,支持app、小程序、H5、台后管理端 MES指的是制造企业生产过程执行系统,是一套面向制造企业车间执行层的生产信息化管理系统。MES系…

【Redis】Redis事务和Lua脚本的区别

Redis事务 概念 事务:Redis事务是一组命令的集合,这些命令会被序列化地执行,中间不会被其他命令插入。 MULTI/EXEC:Redis事务通过MULTI命令开始,通过EXEC命令执行所有已入队的命令。 特点 原子性: 事务…

frameworks 之 AMS与ActivityThread交互

frameworks 之 AMS与ActivityThread交互 1. 类关系2. 流程2.1 AMS流程2.1 ActivityThread流程 3. 堆栈 讲解AMS 如何和 ActivityThread 生命周期调用流程 涉及到的类如下 frameworks/base/core/java/android/app/servertransaction/ResumeActivityItem.javaframeworks/base/cor…

Jmeter 简单使用、生成测试报告(一)

一、下载Jmter 去官网下载,我下载的是apache-jmeter-5.6.3.zip,解压后就能用。 二、安装java环境 JMeter是基于Java开发的,运行JMeter需要Java环境。 1.下载JDK、安装Jdk 2.配置java环境变量 3.验证安装是否成功(java -versio…

如何使用淘宝URL采集商品详情数据及销量

一、通过淘宝开放平台(如果有资质) 注册成为淘宝开发者 访问淘宝开放平台官方网站,按照要求填写开发者信息,包括企业或个人身份验证等步骤。这一步是为了获取合法的 API 使用权限。 了解商品详情 API 淘宝开放平台提供了一系列…

Unity3D中的Lua、ILRuntime与HybridCLR/huatuo热更对比分析详解

前言 在游戏开发中,热更新技术是一项重要的功能,它允许开发者在不重新发布游戏客户端的情况下,更新游戏内容。Unity3D作为广泛使用的游戏引擎,支持多种热更新方案,包括Lua、ILRuntime和HybridCLR/huatuo。本文将详细介…

QT加载Ui文件信息方法(python)

在 PyQt 或 PySide 中,加载 Qt Designer 生成的 .ui 文件有两种常见方法: 使用 pyuic 将 .ui 文件转换为 Python 代码。动态加载 .ui 文件。 以下是两种方法的详细说明和示例代码。 方法 1:使用 pyuic 将 .ui 文件转换为 Python 代码 步骤…

javascript基础从小白到高手系列一十二:JSON

本章内容  理解JSON 语法  解析JSON  JSON 序列化 正如上一章所说,XML 曾经一度成为互联网上传输数据的事实标准。第一代Web 服务很大程度上 是以XML 为基础的,以服务器间通信为主要特征。可是,XML 也并非没有批评者。有的人认为XML 过…

网络编程 - - TCP套接字通信及编程实现

概述 TCP(Transmission Control Protocol,传输控制协议)是一种面向连接的、可靠的传输层协议。在网络编程中,TCP常用于实现客户端和服务器之间的可靠数据传输。本文将基于C语言实现TCP服务端和客户端建立通信的过程。 三次握手 在…

2023-2024 学年 广东省职业院校技能大赛(高职组)“信息安全管理与评估”赛题一

2023-2024 学年 广东省职业院校技能大赛(高职组“信息安全管理与评估”赛题一) 模块一:网络平台搭建与设备安全防护第一阶段任务书任务 1:网络平台搭建任务 2:网络安全设备配置与防护DCRS:DCFW:DCWS:DCBC:WAF: 模块二:网络安全事件…

thinkphp6 + redis实现大数据导出excel超时或内存溢出问题解决方案

redis下载安装(window版本) 参考地址:https://blog.csdn.net/Ci1693840306/article/details/144214215 php安装redis扩展 参考链接:https://blog.csdn.net/jianchenn/article/details/106144313 解决思路:&#xff0…

PT8M2302 触控 A/D 型 8-Bit MCU

1. 产品概述 PT8M2302 是一款可多次编程( MTP ) A/D 型 8 位 MCU ,其包括 2K*16bit MTP ROM 、 256*8bit SRAM、 ADC 、 PWM 、 Touch 等功能,具有高性能精简指令集、低工作电压、低功耗特性且完全集 成触控按键功能。为…

如何使用策略模式并让spring管理

1、策略模式公共接口类 BankFileStrategy public interface BankFileStrategy {String getBankFile(String bankType) throws Exception; } 2、策略模式业务实现类 Slf4j Component public class ConcreteStrategy implements BankFileStrategy {Overridepublic String ge…

前端开发:盒子模型、块元素

1.border边框 *{box-sizing:border-box; } //使所有边框不再撑大盒子模型 粗细 : border-width 样式 : border-style, 默认没边框 . solid 实线边框 dashed 虚线边框 dotted 点线边框 颜色 : border-color div { width : 200px ; height : 200px ; border : …

Nvidia Blackwell架构深度剖析:深入了解RTX 50系列GPU的升级

在CES 2025上,英伟达推出了基于Blackwell架构的GeForce RTX 50系列显卡,包括RTX 5090、RTX 5080、RTX 5070 Ti和RTX 5070。一段时间以来,我们已经知晓了该架构的各种细节,其中许多此前还只是传闻。不过,英伟达近日在20…

计算机网络 (45)动态主机配置协议DHCP

前言 计算机网络中的动态主机配置协议(DHCP,Dynamic Host Configuration Protocol)是一种网络管理协议,主要用于自动分配IP地址和其他网络配置参数给连接到网络的设备。 一、基本概念 定义:DHCP是一种网络协议&#xf…

“扣子”开发之四:与千帆AppBuilder比较

上一个专题——“扣子”开发——未能落地,开始抱着极大的热情进入,但迅速被稚嫩的架构模型折磨打击,硬着头皮坚持了两周,终究还是感觉不实用不趁手放弃了。今天询问了下豆包,看看还有哪些比较好的AI开发平台&#xff0…