12.2深度学习_项目实战

十、项目实战

鲍勃开了自己的手机公司。他想与苹果、三星等大公司展开硬仗。
他不知道如何估算自己公司生产的手机的价格。在这个竞争激烈的手机市场,你不能简单地假设事情。为了解决这个问题,他收集了各个公司的手机销售数据。
鲍勃想找出手机的特性(例如:RAM、内存等)和售价之间的关系。但他不太擅长机器学习。所以他需要你帮他解决这个问题。
在这个问题中,你不需要预测实际价格,而是要预测一个价格区间,表明价格多高。

需要注意的是: 在这个问题中,我们不需要预测实际价格,而是一个价格范围,它的范围使用 0、1、2、3 来表示,所以该问题也是一个分类问题。

数据说明:https://tianchi.aliyun.com/dataset/157241 Mobile Price Classification

推荐专业的数据集平台:https://www.kaggle.com/datasets/iabhishekofficial/mobile-price-classification

字段说明
battery_power电池容量(mAh)
blue是否支持蓝牙
clock_speed微处理器执行指令的速度
dual_sim是否支持双卡
fc前置摄像头分辨率(百万像素)
four_g是否支持4G
int_memory存储内存(GB)
m_dep手机厚度(厘米)
mobile_wt重量
n_cores核心数
pc主摄像头分辨率(百万像素)
px_height屏幕分辨率高度(像素)
px_width屏幕分辨率宽度(像素)
ram运行内存(MB)
sc_h屏幕长度(厘米)
sc_w屏幕宽度(厘米)
talk_time单次充电最长通话时间
three_g是否支持3G
touch_screen是否是触摸屏
wifi是否支持WIFI
price_range价格区间

1. 构建数据集

数据共有 2000 条, 其中 1600 条数据作为训练集, 400 条数据用作测试集。 我们使用 sklearn 的数据集划分工作来完成。并使用 PyTorch 的 TensorDataset 来将数据集构建为 Dataset 对象,方便构造数据集加载对象。

# 构建数据集
def create_dataset():data = pd.read_csv('data/手机价格预测.csv')# 特征值和目标值x, y = data.iloc[:, :-1], data.iloc[:, -1]x = x.astype(np.float32)y = y.astype(np.int64)# 数据集划分x_train, x_valid, y_train, y_valid=train_test_split(x, y, train_size=0.8, random_state=88, stratify=y)# 构建数据集train_dataset = TensorDataset(torch.from_numpy(x_train.values), torch.tensor(y_train.values))valid_dataset = TensorDataset(torch.from_numpy(x_valid.values), torch.tensor(y_valid.values))return train_dataset, valid_dataset, x_train.shape[1], len(np.unique(y))train_dataset, valid_dataset, input_dim, class_num = create_dataset()

2. 构建分类网络模型

我们构建的用于手机价格分类的模型叫做全连接神经网络。它主要由三个线性层来构建,在每个线性层后,我们使用的时 sigmoid 激活函数。

# 构建网络模型
class PhonePriceModel(nn.Module):def __init__(self, input_dim, output_dim):super(PhonePriceModel, self).__init__()self.linear1 = nn.Linear(input_dim, 128)self.linear2 = nn.Linear(128, 256)self.linear3 = nn.Linear(256, output_dim)def _activation(self, x):return torch.sigmoid(x)def forward(self, x):x = self._activation(self.linear1(x))x = self._activation(self.linear2(x))output = self.linear3(x)return output

3. 编写训练函数

网络编写完成之后,我们需要编写训练函数。所谓的训练函数,指的是输入数据读取、送入网络、计算损失、更新参数的流程,该流程较为固定。我们使用的是多分类交叉生损失函数、使用 SGD 优化方法。最终,将训练好的模型持久化到磁盘中。

def train():# 固定随机数种子torch.manual_seed(0)# 初始化模型model = PhonePriceModel(input_dim, class_num)# 损失函数criterion = nn.CrossEntropyLoss()# 优化方法optimizer = optim.SGD(model.parameters(), lr=1e-3)# 训练轮数num_epoch = 50for epoch_idx in range(num_epoch):# 初始化数据加载器dataloader = DataLoader(train_dataset, shuffle=True, batch_size=8)# 训练时间start = time.time()# 计算损失total_loss = 0.0total_num = 1# 准确率correct = 0for x, y in dataloader:output = model(x)# 计算损失loss = criterion(output, y)# 梯度清零optimizer.zero_grad()# 反向传播loss.backward()# 参数更新optimizer.step()total_num += len(y)total_loss += loss.item() * len(y)print('epoch: %4s loss: %.2f, time: %.2fs' %(epoch_idx + 1, total_loss / total_num, time.time() - start))# 模型保存torch.save(model.state_dict(), 'model/phone-price-model.bin')

4. 编写评估函数

评估函数、也叫预测函数、推理函数,主要使用训练好的模型,对未知的样本的进行预测的过程。我们这里使用前面单独划分出来的测试集来进行评估。

def test():# 加载模型model = PhonePriceModel(input_dim, class_num)model.load_state_dict(torch.load('model/phone-price-model.bin'))# 构建加载器dataloader = DataLoader(valid_dataset, batch_size=8, shuffle=False)# 评估测试集correct = 0for x, y in dataloader:output = model(x)y_pred = torch.argmax(output, dim=1)correct += (y_pred == y).sum()print('Acc: %.5f' % (correct.item() / len(valid_dataset)))

5. 网络性能调优

我们前面的网络模型在测试集的准确率为: 0.54750, 我们可以通过以下方面进行调优:

  1. 对输入数据进行标准化
  2. 调整优化方法
  3. 调整学习率
  4. 增加批量归一化层
  5. 增加网络层数、神经元个数
  6. 增加训练轮数
  7. 等等…

我进行下如下调整:

  1. 优化方法由 SGD 调整为 Adam
  2. 学习率由 1e-3 调整为 1e-4
  3. 对数据数据进行标准化
  4. 增加网络深度, 即: 增加网络参数量

网络模型在测试集的准确率由 0.5475 上升到 0.9625

import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import torch.optim as optim
import numpy as np
import time
from sklearn.preprocessing import StandardScaler# 构建数据集
def create_dataset():data = pd.read_csv('data/手机价格预测.csv')# 特征值和目标值x, y = data.iloc[:, :-1], data.iloc[:, -1]x = x.astype(np.float32)y = y.astype(np.int64)# 数据集划分x_train, x_valid, y_train, y_valid = \train_test_split(x, y, train_size=0.8, random_state=88, stratify=y)# 数据标准化transfer = StandardScaler()x_train = transfer.fit_transform(x_train)x_valid = transfer.transform(x_valid)# 构建数据集train_dataset = TensorDataset(torch.from_numpy(x_train), torch.tensor(y_train.values))valid_dataset = TensorDataset(torch.from_numpy(x_valid), torch.tensor(y_valid.values))return train_dataset, valid_dataset, x_train.shape[1], len(np.unique(y))train_dataset, valid_dataset, input_dim, class_num = create_dataset()# 构建网络模型
class PhonePriceModel(nn.Module):def __init__(self, input_dim, output_dim):super(PhonePriceModel, self).__init__()self.linear1 = nn.Linear(input_dim, 128)self.linear2 = nn.Linear(128, 256)self.linear3 = nn.Linear(256, 512)self.linear4 = nn.Linear(512, 128)self.linear5 = nn.Linear(128, output_dim)def _activation(self, x):return torch.sigmoid(x)def forward(self, x):x = self._activation(self.linear1(x))x = self._activation(self.linear2(x))x = self._activation(self.linear3(x))x = self._activation(self.linear4(x))output = self.linear5(x)return output# 编写训练函数
def train():# 固定随机数种子torch.manual_seed(0)# 初始化模型model = PhonePriceModel(input_dim, class_num)# 损失函数criterion = nn.CrossEntropyLoss()# 优化方法optimizer = optim.Adam(model.parameters(), lr=1e-4)# 训练轮数num_epoch = 50for epoch_idx in range(num_epoch):# 初始化数据加载器dataloader = DataLoader(train_dataset, shuffle=True, batch_size=8)# 训练时间start = time.time()# 计算损失total_loss = 0.0total_num = 1# 准确率correct = 0for x, y in dataloader:output = model(x)# 计算损失loss = criterion(output, y)# 梯度清零optimizer.zero_grad()# 反向传播loss.backward()# 参数更新optimizer.step()total_num += len(y)total_loss += loss.item() * len(y)print('epoch: %4s loss: %.2f, time: %.2fs' %(epoch_idx + 1, total_loss / total_num, time.time() - start))# 模型保存torch.save(model.state_dict(), 'model/phone-price-model.bin')def test():# 加载模型model = PhonePriceModel(input_dim, class_num)model.load_state_dict(torch.load('model/phone-price-model.bin'))# 构建加载器dataloader = DataLoader(valid_dataset, batch_size=8, shuffle=False)# 评估测试集correct = 0for x, y in dataloader:output = model(x)y_pred = torch.argmax(output, dim=1)correct += (y_pred == y).sum()print('Acc: %.5f' % (correct.item() / len(valid_dataset)))if __name__ == '__main__':train()test()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/61619.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

视频流媒体服务解决方案之Liveweb视频汇聚平台

一,Liveweb视频汇聚平台简介: LiveWeb是深圳市好游科技有限公司开发的一套综合视频汇聚管理平台,可提供多协议(RTSP/RTMP/GB28181/海康Ehome/大华,海康SDK等)的视频设备接入,支持GB/T28181上下级联&#xf…

ES中间件学习书籍阅读笔记

ES的书籍学习阅读笔记 一本书讲透ElasticSearch:原理、进阶与工程实践ES的基础知识搜索引擎基础知识倒排索引全文检索 ES的核心概念ES集群知识介绍基础知识 一本书讲透ElasticSearch:原理、进阶与工程实践 ES的基础知识 搜索引擎基础知识 倒排索引 面…

物联网——WatchDog(监听器)

看门狗简介 独立看门狗框图 看门狗原理:定时器溢出,产生系统复位信号;若定时‘喂狗’则不产生系统复位信号 定时中断基本结构(对比) IWDG键寄存器 独立看门狗超时时间 WWDG(窗口看门狗) WWDG特性 WWDG超时时间 由于…

LabVIEW将TXT文本转换为CSV格式(多行多列)

在LabVIEW中,将TXT格式的文本文件内容转换为Excel格式(即CSV文件)是一项常见的数据处理任务,适用于将以制表符、空格或其他分隔符分隔的数据格式化为可用于电子表格分析的形式。以下是将TXT文件转换为Excel(CSV&#x…

我的第一个创作纪念日 —— 梦开始的地方

前言 时光荏苒,转眼间,我已经在CSDN这片技术沃土上耕耘了365天 今天,我迎来了自己在CSDN的第1个创作纪念日,这个特殊的日子不仅是对我过去努力的肯定,更是对未来持续创作的激励 机缘 回想起初次接触CSDN,那…

Android:生成Excel表格并保存到本地

提醒 本文实例是使用Kotlin进行开发演示的。 一、技术方案 org.apache.poi:poiorg.apache.poi:poi-ooxml 二、添加依赖 [versions]poi "5.2.3" log4j "2.24.2"[libraries]#https://mvnrepository.com/artifact/org.apache.poi/poi apache-poi { module…

Leecode刷题C语言之N皇后②

执行结果:通过 执行用时和内存消耗如下: struct hashTable {int key;UT_hash_handle hh; };struct hashTable* find(struct hashTable** hashtable, int ikey) {struct hashTable* tmp NULL;HASH_FIND_INT(*hashtable, &ikey, tmp);return tmp; }void insert(…

vue跳转以及传参

1.跳转页面的三种方法 <template><button click"twopage">跳转</button> </template><script setup> import { useRouter } from "vue-router"; const router useRouter(); // 获取 router 实例const twopage () > {r…

EasyMedia播放rtsprtmp视频流(flvhls)

学习链接 MisterZhang/EasyMedia - gitee地址 EasyMedia转码rtsp视频流flv格式&#xff0c;hls格式&#xff0c;H5页面播放flv流视频 EasyMedia播放rtsp视频流&#xff08;vue2、vue3皆可用&#xff09; EasyMedia转码rtsp视频流flv格式&#xff0c;hls格式&#xff0c;H5页…

用 React 编写一个笔记应用程序

这篇文章会教大家用 React 编写一个笔记应用程序。用户可以创建、编辑、和切换 Markdown 笔记。 1. nanoid nanoid 是一个轻量级和安全的唯一字符串ID生成器&#xff0c;常用于JavaScript环境中生成随机、唯一的字符串ID&#xff0c;如数据库主键、会话ID、文件名等场景。 …

利用 Redis 与 Lua 脚本解决秒杀系统中的高并发与库存超卖问题

1. 前言 1.1 秒杀系统中的库存超卖问题 在电商平台上&#xff0c;秒杀活动是吸引用户参与并提升销量的一种常见方式。秒杀通常会以极低的价格限量出售某些商品&#xff0c;目的是制造紧迫感&#xff0c;吸引大量用户参与。然而&#xff0c;这种活动的特殊性也带来了许多技术挑…

机器学习8-决策树CART原理与GBDT原理

CART决策树算法流程举例 该篇文章对于CART的算法举例讲解&#xff0c;一看就懂。 决策树(Decision Tree)—CART算法 同时也可以观看视频 分类树 GBDT原理举例 可以看如下示例可以理解GBDT的计算原理 用通俗易懂的方式讲解&#xff1a; GBDT算法及案例&#xff08;Python 代…

【Linux探索学习】第十八弹——进程等待:深入解析操作系统中的进程等待机制

Linux学习笔记&#xff1a;https://blog.csdn.net/2301_80220607/category_12805278.html?spm1001.2014.3001.5482 前言&#xff1a; 在Linux操作系统中&#xff0c;进程是资源的管理和执行单元&#xff0c;每个进程都有其自己的生命周期。在进程的执行过程中&#xff0c;进程…

大数据技术Kafka详解 ② | Kafka基础与架构介绍

目录 1、kafka的基本介绍 2、kafka的好处 3、分布式发布与订阅系统 4、kafka的主要应用场景 4.1、指标分析 4.2、日志聚合解决方法 4.3、流式处理 5、kafka架构 6、kafka主要组件 6.1、producer(生产者) 6.2、topic(主题) 6.3、partition(分区) 6.4、consumer(消费…

嵌入式Linux之wifi配网脚本分析

嵌入式Linux系统,一般都支持wifi联网,可以通过sh脚本或其它语言代码编程来实现wifi联网。 本篇来介绍一种通过sh脚本来配置wifi的脚本执行原理。 1 sh脚本wifi联网介绍 这里以飞凌开发板中的wifi启动脚本为例来介绍。 在飞凌开发板的串口中,执行如下命令(调用fltest_wif…

如何通过 JWT 来解决登录认证问题

1. 问题引入 在登录功能的实现中 传统思路&#xff1a; 登录页面时把用户名和密码提交给服务器服务器验证用户名和密码&#xff0c;并把检验结果返回给后端如果密码正确&#xff0c;则在服务器端创建 session&#xff0c;通过 cookie 把 session id 返回给浏览器 但是正常情…

【Linux】文件操作的艺术——从基础到精通

&#x1f3ac; 个人主页&#xff1a;谁在夜里看海. &#x1f4d6; 个人专栏&#xff1a;《C系列》《Linux系列》《算法系列》 ⛰️ 道阻且长&#xff0c;行则将至 目录 &#x1f4da;前言&#xff1a;一切皆文件 &#x1f4da;一、C语言的文件接口 &#x1f4d6;1.文件打…

Redis数据库详解

文章目录 Redis数据库1 Redis概述1.1 Redis介绍1.2 Redis特点1.3 应用场景1.4 Redis版本1.5 Redis附加功能1.6 Redis安装1.6.1 Rocky Linux操作系统1.6.2 Windows操作系统1.6.3 Mac操作系统 2 配置文件详解3 数据类型3.1 字符串String3.1.1 字符串3.1.2 数值 3.2 列表List3.3 H…

FFmpeg 4.3 音视频-多路H265监控录放C++开发十八,ffmpeg解封装

为啥要封装和解封装呢&#xff1f; 1.封装就相当于将 h264 和aac 包裹在一起。既能播放声音&#xff0c;也能播放视频 2.在封装的时候没指定编码格式&#xff0c;帧率&#xff0c;时长&#xff0c;等参数&#xff1b;特别是视频&#xff0c;可以将视频帧索引存储&#xff0c;…

JAVA |日常开发中常见问题归纳讲解

JAVA &#xff5c;日常开发中常见问题归纳讲解 前言一、语法错误相关问题1.1 分号缺失或多余1.2 括号不匹配1.3 变量未定义或重复定义 二、数据类型相关问题2.1 数据类型不匹配2.2 整数溢出和浮点数精度问题 三、面向对象编程相关问题3.1 空指针异常&#xff08;NullPointerExc…