神经网络案例实战

🔎我们通过一个案例详细使用PyTorch实战 ,案例背景:你创办了一家手机公司,不知道如何估算手机产品的价格。为了解决这个问题,收集了多家公司的手机销售数据:这些数据维度可以包括RAM、存储容量、屏幕尺寸、摄像头像素等。

在这个问题中,我们不需要预测实际价格,而是一个价格范围,它的范围使用 0、1、2、3 来表示,所以该问题也是一个分类问题。

🔎思路:

  1. 数据预处理:对收集到的数据进行清洗和预处理,确保数据的质量和一致性。这包括处理缺失值、异常值和重复数据等。

  2. 特征工程:从原始数据中提取有用的特征,以便用于建模。这可以包括对连续型特征进行归一化或标准化,对分类特征进行编码等。

  3. 模型选择:选择一个适合的机器学习算法来建立模型,这里我们使用神经网络模型。

  4. 模型训练:将收集到的数据划分为训练集和测试集。使用训练集来训练模型,通过调整模型的参数来最小化预测误差。

  5. 模型评估:使用测试集来评估模型的性能。常用的评估指标包括均方误差(MSE)、均方根误差(RMSE)和平均绝对误差(MAE)等。

  6. 价格预测:使用训练好的模型来预测新手机的价格。输入新手机的功能特征,模型将输出预测的价格。

构建数据集 

💬数据共有 2000 条, 其中 1600 条数据作为训练集, 400 条数据用作测试集。 我们使用 sklearn 的数据集划分工作来完成。并使用 PyTorch 的 TensorDataset 来将数据集构建为 Dataset 对象,方便后期构造数据集加载对象。

def create_dataset():data = pd.read_csv('predict.csv')# 特征值和目标值x, y = data.iloc[:, :-1], data.iloc[:, -1]x = x.astype(np.float32)y = y.astype(np.int64)# 数据集划分x_train, x_valid, y_train, y_valid = train_test_split(x, y, train_size=0.8, random_state=88, stratify=y)# 构建数据集train_dataset = TensorDataset(torch.tensor(x_train.values), torch.tensor(y_train.values))valid_dataset = TensorDataset(torch.tensor(x_valid.values), torch.tensor(y_valid.values))return train_dataset, valid_dataset, x_train.shape[1], len(np.unique(y))train_dataset, valid_dataset, input_dim, class_num = create_dataset()

💬其中 train_test_split 方法中的stratify=y参数的作用是在划分训练集和验证集时,保持类别的比例相同。这样可以确保在训练集和验证集中各类别的比例与原始数据集中的比例相同,有助于提高模型的泛化能力,防止出现一份中某个类别只有几个。

构建分类网络模型

# 构建网络模型
class PhonePriceModel(nn.Module):def __init__(self, input_dim, output_dim):super(PhonePriceModel, self).__init__()self.linear1 = nn.Linear(input_dim, 128) # 输入层到隐藏层的线性变换self.linear2 = nn.Linear(128, 256)  # 隐藏层到输出层的线性变换self.linear3 = nn.Linear(256, output_dim)  # 输出层到最终输出的线性变换def _activation(self, x):return torch.sigmoid(x)def forward(self, x):x = self._activation(self.linear1(x))# 通过第一层线性变换和激活函数x = self._activation(self.linear2(x))# 通过第二层线性变换和激活函数output = self.linear3(x)# 通过第三层线性变换得到最终输出return output
  • self.linear1self.linear2之间的线性变换将输入维度从input_dim映射到128个神经元,然后再将128个神经元映射到256个神经元。
  • 我们通过self._activation方法对输入数据进行激活函数处理。具体来说,它使用Sigmoid激活函数对输入数据进行非线性变换,将线性变换后的输出映射到0和1之间。
  • 第三层: 输入为维度为 256, 输出维度为: 4

编写训练函数 

def train():# 固定随机数种子torch.manual_seed(66)model = PhonePriceModel(input_dim, class_num)# 损失函数criterion = nn.CrossEntropyLoss()# 优化方法optimizer = optim.SGD(model.parameters(), lr=1e-3)num_epoch = 50for epoch_idx in range(num_epoch):# 初始化数据加载器dataloader = DataLoader(train_dataset, shuffle=True, batch_size=8)# 训练时间start = time.time()# 计算损失total_loss = 0.0total_num = 1# 准确率correct = 0for x, y in dataloader:output = model(x)# 计算损失loss = criterion(output, y)# 梯度清零optimizer.zero_grad()# 反向传播loss.backward()# 参数更新optimizer.step()total_num += len(y)total_loss += loss.item() * len(y)print('epoch: %4s loss: %.2f, time: %.2fs' %(epoch_idx + 1, total_loss / total_num, time.time() - start))# 模型保存torch.save(model.state_dict(), 'price-model.bin')
  • 💬要在PyTorch中查看随机数种子,可以使用torch.random.initial_seed()函数。这个函数会返回当前设置的随机数种子。 

评估函数

def test():# 加载模型model = PhonePriceModel(input_dim, class_num)model.load_state_dict(torch.load('price-model.bin'))# 构建加载器dataloader = DataLoader(valid_dataset, batch_size=8, shuffle=False)# 评估测试集correct = 0for x, y in dataloader:output = model(x)y_pred = torch.argmax(output, dim=1)correct += (y_pred == y).sum()print('Acc: %.5f' % (correct.item() / len(valid_dataset)))

网络性能调优

  1. 对输入数据进行标准化
  2. 调整优化方法
  3. 调整学习率
  4. 增加批量归一化层
  5. 增加网络层数、神经元个数
  6. 增加训练轮数

🔎我们可以自行将优化方法由 SGD 调整为 Adam , 学习率由 1e-3 调整为 1e-4 ,对数据数据进行标准化 ,增加网络深度 。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/8678.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

# 光标变为下划线怎么办?

光标变为下划线怎么办? 光标变为下划线通常表示您处于覆盖模式。在这种模式下,您键入的任何内容都将覆盖光标位置处的文本。如果想要恢复光标为正常显示,您可以尝试以下方法: 1、在桌面或文档编辑界面,按键盘上的 【I…

Object类

Object类 概念:Object类是所有类的父类,也就是说任何一个类在定义时候如果没有明确的指定继承一个父类的话,那么它就都默认继承Object类,因此Object类被称为所有类的父类,也叫做基类/超类。 常用方法 方法类型描述eq…

每日OJ题_记忆化搜索①_力扣509. 斐波那契数(四种解法)

目录 记忆化搜索概念和使用场景 力扣509. 斐波那契数 解析代码1_循环 解析代码2_暴搜递归 解析代码3_记忆化搜索 解析代码4_动态规划 记忆化搜索概念和使用场景 记忆化搜索是一种典型的空间换时间的思想,可以看成带备忘录的爆搜递归。 搜索的低效在于没有能够…

《手把手教你怎么上手做一个小程序》

准备工作: 硬件准备: 装有微信的手机一台。 账号注册: 进入https://mp.weixin.qq.com/cgi-bin/registermidpage?actionindex&langzh_CN&token注册一个微信小程序账号。 然后输入邮箱注册账号。一个邮箱只能注册一个微信公众平台…

【面试经典 150 | 数组】找出字符串中第一个匹配项的下标

文章目录 写在前面Tag题目来源解题思路方法一:find方法二:暴力匹配方法三:KMP 写在最后 写在前面 本专栏专注于分析与讲解【面试经典150】算法,两到三天更新一篇文章,欢迎催更…… 专栏内容以分析题目为主,…

算法学习006-瓷砖总数 广度优先算法BFS 中小学算法思维学习 信奥算法解析 c++实现

目录 C瓷砖总数 一、题目要求 1、编程实现 2、输入输出 二、算法分析 三、程序编写 四、程序说明 五、运行结果 六、考点分析 七、推荐资料 C瓷砖总数 一、题目要求 1、编程实现 在一个长方形房间,铺着不同颜色的的瓷砖,有红色和黑色&#…

AR人脸道具SDK解决方案,实现道具与人脸的自然融合

AR人脸道具SDK解决方案,实现道具与人脸的自然融合美摄科技以其卓越的技术实力和创新能力,为企业带来了全新的AR人脸道具SDK解决方案。这一解决方案将为企业打开全新的市场机会,为用户带来前所未有的互动体验。 颠覆传统,开启AR人…

AI预测福彩3D第10套算法实战化赚米验证第1弹2024年5月5日第1次测试

从今天开始,准备启用第10套算法,来验证下本算法的可行性。因为本算法通过近三十期的内测(内测版没有公开预测结果),发现本算法的预测结果优于其他所有算法的效果。彩票预测只有实战才能检验是否有效,只有真…

产业项目招商活动会议课程报名签到h5小程序pc开源版开发

产业项目招商活动会议课程报名签到h5小程序pc开源版开发 一个集PC和移动端功能于一体的解决方案,线上线下进行服务,围绕 活动报名、在线课程、项目大厅、线下签到、会员系统等。为商会提供了更加便捷高效的管理方式,提升了商会活动和项目的组…

qt day 3

优化登录框,点击登录按钮,如果账号和密码匹配,则弹出 信息对话框 给出提示信息“登录成功”,并给出一个 ok 按钮,当用户点击 ok 后,关闭当前界面,跳转到另一个界面;如果账号和密码不…

labview技术交流-将时间字符串转换成时间格式

应用场景 我们在数据库中设计了datetime类型的字段,比如字段名就叫“保存时间”,当我们使用labview将表中数据读取出来后datetime类型的数据是以字符串的格式显示的。而我们想计算两条数据“保存时间”的间隔时间时,用字符串类型自然是没法计…

OpenMV 图像串口传输示例

注意:本程序根据 OpenMV采集图片通过串口发送,PC接收并保存为图片 更改。 一、例程说明 这个例程主要实现了以下功能: 1. OpenMV 端采集图像:使用OpenMV开发板上的摄像头采集实时图像数据。 2. 通过串口传输图像数据:将采集到的图像数据打包成字节流,…

idea 项目 修改项目文件名 教程

文章目录 目录 文章目录 修改流程 小结 概要流程技术细节小结 概要 原项目名 修改流程 关掉当前项目的idea页面 修改之后的文件名 重新打开idea。选择项目打开项目页面 技术细节 出现下面这个问题,可以参考作者新的一编文章idea开发工具 项目使用Spring框架开发解…

【Linux】25. 网络基础(一)

网络基础(一) 计算机网络背景 网络发展 独立模式: 计算机之间相互独立; 网络互联: 多台计算机连接在一起, 完成数据共享; 其实本质上一台计算机内部也是一个小型网络结构(如果我们将计算机内部某个硬件不存放在电脑中,而是拉根长长的线进行连接。这其实也就是网…

在做题中学习(51): x的平方根

69. x 的平方根 - 力扣(LeetCode)​​​​​​ 解法:二分查找 思路:看示例2: 可以看到8的平方根是2.82,在2^2和3^2之间,所以可以把数组分为两部分,(具有二段性) 而2.82去掉小数部分…

【Kolmogorov-Arnold网络 替代多层感知机MLPs】KAN: Kolmogorov-Arnold Networks

KAN: Kolmogorov-Arnold Networks 论文地址 代码地址 知乎上的讨论(看一下评论区更正) Abstract Inspired by the Kolmogorov-Arnold representation theorem, we propose Kolmogorov-Arnold Networks (KANs) as promising alternatives to Multi-Layer…

【OceanBase 系列】—— OceanBase v4.3 特性解读:查询性能提升之利器列存储引擎

原文链接:OceanBase 社区 对于分析类查询,列存可以极大地提升查询性能,也是 OceanBase 做好 HTAP 和 OLAP 的一项不可缺少的特性。本文介绍 OceanBase 列存的实现特色。 OceanBase从诞生起就一直坚持LSM-Tree架构,不断打磨功能支…

2020 年国考【计算机专业】真题及答案

真题及答案 第一部分数学基础课程 一、(共 5 分)用逻辑符号表达下列语句(论域为包含一切事物的合集) (2 分)确诊者并不都有症状(注:需给出两种形式表达, 一种用存在量词, 一种用全称…

python魔法方法是什么

魔法方法是python内置方法,不需要主动调用,存在的目的是为了给python的解释器进行调用,几乎每个魔法方法都有一个对应的内置函数,或者运算符,当我们对这个对象使用这些函数或者运算符时就会调用类中的对应魔法方法&…

MVC与MVVM架构模式

1、MVC MVC:Model-View-Controller,即模型-视图-控制器 MVC模式是一种非常经典的软件架构模式。从设计模式的角度来看,MVC模式是一种复合模式,它将多个设计模式结合在一种解决方案中,从而可以解决许多设计问题。 MV…