(二)Pytorch快速搭建神经网络模型实现气温预测回归(代码+详细注解)

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

文章目录

  • 前言
  • 一、数据集
  • 二、导入数据以及展示部分
    • 1.导入数据集以及对数据集进行处理
    • 2.展示数据(看看就好)
  • 三(1)、搭建网络进行预测(理解版)
  • 三(2)、搭建网络进行预测(应用版)
  • 四、 对预测结果进行一个展示,蓝色真实值,红色预测值
  • 总结


前言

深度学习pytorch系列第二篇,第一篇实现的是分类任务,这篇是回归任务,大差不差,重在理解,具体的理解内容我都以注释的形式放在了代码中,方便大家阅读


一、数据集

想要复现的可以下载
链接:网盘链接
提取码:k6a4

二、导入数据以及展示部分

1.导入数据集以及对数据集进行处理

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
# 过滤警告
import warnings
warnings.filterwarnings("ignore")
# 读取数据
features = pd.read_csv('data/temps.csv')
#
#看看数据长什么样子
# print(features.head())
# print('数据维度:', features.shape)
# 数据维度:(348, 9),348条数据,每条8个特征x,1个标签y
# 处理时间数据
import datetime
# 分别得到年,月,日
years = features['year']
months = features['month']
days = features['day']
#
# # datetime格式
dates = [str(int(year)) + '-' + str(int(month)) + '-' + str(int(day)) for year, month, day in zip(years, months, days)]
dates = [datetime.datetime.strptime(date, '%Y-%m-%d') for date in dates]
# 在打印的结果中,每个datetime.datetime对象的后面两个0表示小时和分钟,没有时默认为0
# print(dates[:5])
# 独热编码
# # 将字符串进行onehot
# # 周一 周二 周三 周四 周五 周六 周天
# # 如果是周一,编码就是
# # 1000000
# Pandas库中的get_dummies函数,是一种独热编码(One-Hot Encoding)的方法
features = pd.get_dummies(features)# print(features.head(5))
# print(features.shape)
# 此时的数据维度:(348, 15),多的7个是日期的七天
# 取标签
labels = np.array(features['actual'])
# 在特征中去掉标签,features.drop,去掉标签列
features= features.drop('actual', axis = 1)
# 名字单独保存一下,以备后患
feature_list = list(features.columns)
# 转换成合适的格式
features = np.array(features)
# print(features.shape)
# print(features)
"""
数据标准化
由于神经网络在训练的过程中具有倾向性,数值越大,认为越重要
# 但是在月份这种重要程度与数值无关的特征上,这种倾向性就会出错
# 因此进行标准化,使数据以零点为中心均匀分布
# (x-u)/σ
# x-u  去均值
# /σ  除以标准差:让离散数据更加收敛
标准化通常是针对特征而不是标签的。
标准化的目的是使特征具有相同的尺度,以便模型能够更好地学习权重并提高模型的性能。
标签(也称为目标变量)通常不需要标准化,因为它们是模型试图预测的值,而不是用于学习权重的输入。
"""
from sklearn import preprocessing
input_features = preprocessing.StandardScaler().fit_transform(features)
"""
[ 0.         -1.5678393  -1.65682171 -1.48452388 -1.49443549 -1.3470703-1.98891668  2.44131112 -0.40482045 -0.40961596 -0.40482045 -0.40482045-0.41913682 -0.40482045]标准化处理后的数据以零点为中心,均匀分布
"""

上述代码中的初始数据集为:
在这里插入图片描述
处理完成后的数据样貌:
在这里插入图片描述

2.展示数据(看看就好)

代码如下(示例):

# 该段是展示一下数据的样貌
plt.style.use('fivethirtyeight')
# 设置布局
# 4个子图,两行两列
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, figsize = (10,10))
# 坐标倾斜45度
fig.autofmt_xdate(rotation = 45)# 标签值
ax1.plot(dates, features['actual'])
ax1.set_xlabel(''); ax1.set_ylabel('Temperature'); ax1.set_title('Max Temp')
# 昨天
ax2.plot(dates, features['temp_1'])
ax2.set_xlabel(''); ax2.set_ylabel('Temperature'); ax2.set_title('Previous Max Temp')
#
# 前天
ax3.plot(dates, features['temp_2'])
ax3.set_xlabel('Date'); ax3.set_ylabel('Temperature'); ax3.set_title('Two Days Prior Max Temp')
#
# 朋友感觉的值
ax4.plot(dates, features['friend'])
ax4.set_xlabel('Date'); ax4.set_ylabel('Temperature'); ax4.set_title('Friend Estimate')
# 子图之间间隔多少
plt.tight_layout(pad=2)
plt.show()

展示图如下:
在这里插入图片描述


三(1)、搭建网络进行预测(理解版)

该过程是一步一步构建网络,促进理解,后边会附上更为简单的网络结构


x = torch.tensor(input_features, dtype=float)
y = torch.tensor(labels, dtype=float)
# # 权重参数初始化
# (14, 128),将14个特征转成128个神经元,可以理解为转成128个特征
# requires_grad = True,是否求导,也就是是否记录梯度
weights = torch.randn((14, 128), dtype=float, requires_grad=True)
biases = torch.randn(128, dtype=float, requires_grad=True)
weights2 = torch.randn((128, 1), dtype=float, requires_grad=True)
biases2 = torch.randn(1, dtype=float, requires_grad=True)
# 学习率  :决定梯度更新幅度的大小,计算出来的梯度只能确定方向
# 这个幅度不能太大
learning_rate = 0.001
losses = []
# 迭代次数,每次算梯度,然后更新
for i in range(1000):# 计算隐层hidden = x.mm(weights) + biases# 加入激活函数,非线性映射hidden = torch.relu(hidden)# 预测结果  :h1*w2+b2=预测值predictions = hidden.mm(weights2) + biases2# 通计算损失loss = torch.mean((predictions - y) ** 2)losses.append(loss.data.numpy())# 打印损失值if i % 100 == 0:print('loss:', loss)# 返向传播计算loss.backward()# 更新参数#     grad.data  取梯度,然后乘以学习率,应该沿着梯度的反方向更新weights.data.add_(- learning_rate * weights.grad.data)biases.data.add_(- learning_rate * biases.grad.data)weights2.data.add_(- learning_rate * weights2.grad.data)biases2.data.add_(- learning_rate * biases2.grad.data)# 每次迭代都得记得清空#     每次迭代过程都是独立的,之前计算的梯度要清零# 在torch中,如果不清零,梯度就会累加weights.grad.data.zero_()biases.grad.data.zero_()weights2.grad.data.zero_()biases2.grad.data.zero_()
print(predictions.shape)
print(predictions)

三(2)、搭建网络进行预测(应用版)

实际应用中,往往会这样实现

# 更简单的构建网络模型
# 取特征个数
# 0是样本数;1是特征数
input_size = input_features.shape[1]
# print(input_size)  14 有14个特征
# 隐层个数
hidden_size = 128
output_size = 1
batch_size = 16
# Sequential序列模块,按顺序执行
my_nn = torch.nn.Sequential(# 计算隐层,相当于wx+b,参数是自动更新的torch.nn.Linear(input_size, hidden_size),
#     激活函数torch.nn.Sigmoid(),
#     预测结果  :h1*w2+b2=预测值torch.nn.Linear(hidden_size, output_size),
)
# 计算损失
# reduction='mean  平均损失
cost = torch.nn.MSELoss(reduction='mean')
# 优化器
# my_nn.parameters() 更新nn中所有参数
optimizer = torch.optim.Adam(my_nn.parameters(), lr = 0.001)
# ADM优化器,比SGD(梯度下降)效果好,效率高
# 训练网络
losses = []
# 迭代1000次
for i in range(1000):#     每次取一个batch的数据,每次只取一批数据batch_loss = []# MINI-Batch方法来进行训练#   for start in range(0, len(input_features), batch_size):# 从0开始,到整个数据结束,取batch,间隔是一个batch_size大小for start in range(0, len(input_features), batch_size):end = start + batch_size if start + batch_size < len(input_features) else len(input_features)  # 判断索引越界xx = torch.tensor(input_features[start:end], dtype=torch.float, requires_grad=True)yy = torch.tensor(labels[start:end], dtype=torch.float, requires_grad=True)prediction = my_nn(xx)loss = cost(prediction, yy)#         通过优化器进行梯度清零optimizer.zero_grad()#     反向传播loss.backward(retain_graph=True)#     更新参数optimizer.step()#     将每一个batch的损失相加batch_loss.append(loss.data.numpy())# 打印损失if i % 100 == 0:losses.append(np.mean(batch_loss))print(i, np.mean(batch_loss))
x = torch.tensor(input_features, dtype = torch.float)
# 所有的数据进行预测,得到结果,进行画图
predict = my_nn(x).data.numpy()

四、 对预测结果进行一个展示,蓝色真实值,红色预测值

# 转换日期格式
dates = [str(int(year)) + '-' + str(int(month)) + '-' + str(int(day)) for year, month, day in zip(years, months, days)]
dates = [datetime.datetime.strptime(date, '%Y-%m-%d') for date in dates]# 创建一个表格来存日期和其对应的标签数值
true_data = pd.DataFrame(data = {'date': dates, 'actual': labels})# 同理,再创建一个来存日期和其对应的模型预测值
months = features[:, feature_list.index('month')]
days = features[:, feature_list.index('day')]
years = features[:, feature_list.index('year')]test_dates = [str(int(year)) + '-' + str(int(month)) + '-' + str(int(day)) for year, month, day in zip(years, months, days)]test_dates = [datetime.datetime.strptime(date, '%Y-%m-%d') for date in test_dates]predictions_data = pd.DataFrame(data = {'date': test_dates, 'prediction': predict.reshape(-1)})
# 真实值
plt.plot(true_data['date'], true_data['actual'], 'b-', label = 'actual')# 预测值
plt.plot(predictions_data['date'], predictions_data['prediction'], 'ro', label = 'prediction')
plt.xticks(rotation = '60');
plt.legend()
plt.show()
# 图名
plt.xlabel('Date'); plt.ylabel('Maximum Temperature (F)'); plt.title('Actual and Predicted Values');
# 层数越来越对,就会过拟合
# 什么是过拟合?过拟合(Overfitting)是指机器学习模型在训练数据上表现得很好,但在未见过的新数据上表现较差的现象。

在这里插入图片描述

总结

pytorch学习的第二篇啦,慢慢更新ing

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/149092.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

004 OpenCV akaze特征点检测匹配

目录 一、环境 二、akaze特征点算法 2.1、基本原理 2.2、实现过程 2.3、实际应用 2.4、优点与不足 三、代码 3.1、数据准备 3.2、完整代码 一、环境 本文使用环境为&#xff1a; Windows10Python 3.9.17opencv-python 4.8.0.74 二、akaze特征点算法 特征点检测算法…

【数字图像处理】Gamma 变换

在数字图像处理中&#xff0c;Gamma 变换是一种重要的灰度变换方法&#xff0c;可以用于图像增强与 Gamma 校正。本文主要介绍数字图像 Gamma 变换的基本原理&#xff0c;并记录在紫光同创 PGL22G FPGA 平台的布署与实现过程。 目录 1. Gamma 变换原理 2. FPGA 布署与实现 2…

JSP 四大域对象

我们来说说JSP的四大域对象 首先 我们要了解他们是四种保存范围 第一种 是 Page范围 只作用于当前界面 只要页面跳转了 其他页面就拿不到了 第二种 request范围 在一次请求中有效 就是 我们服务端指向某个界面 并传递数据给他 那么 如果你是客户端跳转就不生效了 第三种 sessi…

经典ctf ping题目详解 青少年CTF-WEB-PingMe02

题目环境&#xff1a; 根据题目名称可知 这是一道CTF-WEB方向常考的知识点&#xff1a;ping地址 随便ping一个地址查看接受的数据包?ip0.0.0.0 有回显数据&#xff0c;尝试列出目录文件 堆叠命令使用’;作为命令之间的连接符&#xff0c;当上一个命令完成后&#xff0c;继续执…

深度学习二维码识别 计算机竞赛

文章目录 0 前言2 二维码基础概念2.1 二维码介绍2.2 QRCode2.3 QRCode 特点 3 机器视觉二维码识别技术3.1 二维码的识别流程3.2 二维码定位3.3 常用的扫描方法 4 深度学习二维码识别4.1 部分关键代码 5 测试结果6 最后 0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天…

windows中运行项目中.sh和kaggle安装与配置

在git bash中运行 命令如下&#xff1a; bash download_data.sh 或者 ./download_data.sh如果使用kaggle的数据集&#xff0c;会要求输入用户名和API。 API在这个文件里面&#xff0c;复制过来即可。 安装kaggle pip install kaggle去kaggle官网&#xff0c;点击这里&…

[一周AI简讯]OpenAI宫斗;微软Bing Chat更名Copilot;Youtube测试音乐AI

OpenAI宫斗&#xff0c;奥特曼被解雇&#xff0c;董事会内讧 Sam Altman被解雇&#xff0c;不再担任CEO&#xff0c;董事会的理由是奥特曼在与董事会的沟通中始终不坦诚&#xff0c;阻碍了董事会履行职责的能力。原首席技术官Mira Murati担任新CEO。OpenAI宫斗剧远未结束&…

【数据结构】图的简介(图的逻辑结构)

一.引例&#xff08;哥尼斯堡七桥问题&#xff09; 哥尼斯堡七桥问题是指在哥尼斯堡市&#xff08;今属俄罗斯&#xff09;的普雷格尔河&#xff08;Pregel River&#xff09;中&#xff0c;是否可以走遍每座桥一次且仅一次&#xff0c;最后回到起点的问题。这个问题被认为是图…

达梦数据库常用参数查询

字符集 字符是各种文字和符号的统称&#xff0c;包括各个国家文字、标点符号、表情、数字等等。 字符集 就是一系列字符的集合。字符集的种类较多&#xff0c;每个字符集可以表示的字符范围通常不同&#xff0c;就比如说有些字符集是无法表示汉字的。 常见的字符集有 ASCII、G…

开发知识点-前端-webpack

webpack技术笔记 一、 介绍二、 下载使用 一、 介绍 Webpack是一个现代 JavaScript 应用程序的静态模块打包器 打包&#xff1a;可以把js、css等资源按模块的方式进行处理然后再统一打包输出 静态&#xff1a;最终产出的静态资源都可以直接部署到静态资源服务器上进行使用 模…

C#开发的OpenRA游戏之属性QuantizeFacingsFromSequence(7)

C#开发的OpenRA游戏之属性QuantizeFacingsFromSequence(7) 前面分析了身体的方向,在这里继续QuantizeFacingsFromSequence属性,这个属性就是通过序列定义文件里获取身体的方向。 根据前面分析可知,同样有一个信息类QuantizeFacingsFromSequenceInfo: [Desc("Deriv…

组件插槽,生命周期,轮播图组件的封装,自定义指令的封装等详解以及axios的卖座案例

3.组件插槽 3-1组件插槽 注意 插槽内容可以访问到父组件的数据作用域,因为插槽内容本身就是在父组件模版中定义的 插槽内容无法访问子组件的数据.vue模版中的表达式只能访问其定义时所处的作用域,这和JavaScript的词法作用域是一致的,换言之: 父组件模版的表达式只能访问父组…

金属压块液压打包机比例阀放大器

液压打包机是机电一体化产品&#xff0c;主要由机械系统、液压控制系统、上料系统与动力系统等组成。整个打包过程由压包、回程、提箱、转箱、出包上行、出包下行、接包等辅助时间组成。市场上液压打包机主要分为卧式与立式两种&#xff0c;立式废纸打包机的体积比较小&#xf…

CI/CD -gitlab

目录 一、常用命令 二、部署 一、常用命令 官网&#xff1a;https://about.gitlab.com/install/ gitlab-ctl start # 启动所有 gitlab 组件 gitlab-ctl stop # 停止所有 gitlab 组件 gitlab-ctl restart # 重启所有 gitlab 组件 gitlab-ctl statu…

Deque继承ArrayDeque和继承LinkedList区别在哪里

在Java中&#xff0c;ArrayDeque和LinkedList都是Deque接口的实现类&#xff0c;但它们的内部实现和性能特性有一些不同。 ArrayDeque&#xff1a; 内部实现&#xff1a;ArrayDeque使用动态数组&#xff08;resizable array&#xff09;来实现&#xff0c;它允许在两端高效地进…

大白话解释什么类加载机制

大家好&#xff0c;我是伍六七。 今天我们来聊聊一个 Java 面试必考基础题目&#xff1a;类加载机制和双亲委派机制。 Java 类的加载机制是 Java 虚拟机&#xff08;JVM&#xff09;中类加载&#xff08;Class Loading&#xff09;和链接&#xff08;Linking&#xff09;的过…

LeetCode27.移除元素(暴力法、快慢指针法)

每日一题&#xff1a;LeetCode27.移除元素 1.问题描述2.解题思路3.代码 1.问题描述 问题描述&#xff1a;给你一个数组 nums 和一个值 val&#xff0c;你需要 原地 移除所有数值等于 val 的元素&#xff0c;并返回移除后数组的新长度。不要使用额外的数组空间&#xff0c;你必…

有趣的按钮分享

前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。点击跳转到网站。 广告打完&#xff0c;我们进入正题&#xff0c;先看效果&#xff1a; 废话不多&#xff0c;上源码&#xff1a; <button class&quo…

租赁小程序|租赁系统一种新型的商业模式

租赁市场是一个庞大的市场&#xff0c;它由出租人和承租人组成&#xff0c;以及相关的中介机构和供应商等。随着经济的发展和人们对灵活性的需求增加&#xff0c;租赁市场也在不断发展和壮大。特别是在共享经济时代&#xff0c;租赁市场得到了进一步的推动和发展。租赁系统是一…

如何在Docker部署Draw.io绘图工具并远程访问

文章目录 前言1. 使用Docker本地部署Drawio2. 安装cpolar内网穿透工具3. 配置Draw.io公网访问地址4. 公网远程访问Draw.io 前言 提到流程图&#xff0c;大家第一时间可能会想到Visio&#xff0c;不可否认&#xff0c;VIsio确实是功能强大&#xff0c;但是软件为收费&#xff0…