pytorch集智-2单车预测器

完整代码在个人主页简介链接pytorch路径下可找到

1 单车预测器1.0

1.1 人工神经元

对于sigmoid函数来说,w控制函数曲线的方向,b控制曲线水平方向位移,w'控制曲线在y方向的幅度

1.2 多个人工神经元

模型如下

数学上可证,有限神经元绘制的曲线可以逼近任意有限区间内的曲线(闭区间连续函数有界)

1.3 模型与代码

通过训练可得到逼近真实曲线的神经网络参数

通过梯度下降法寻找局部最优(如何寻找全局最优后面考虑)

思考 n个峰需在一个隐层要多少隐单元?材料说3个峰10个单元就够了,理论上算,最少需要5个,可能保险起见,加其他一些不平滑处,就弄了10个

初次代码如下

from os import path
import numpy as np
import pandas as pd
import torch
import torch.optim as optim
import matplotlib.pyplot as plotDATA_PATH = path.realpath('pytorch/jizhi/bike/data/hour.csv')class Bike():def exec(self):self.prepare_data_and_params()self.train()def prepare_data_and_params(self):self.data = pd.read_csv(DATA_PATH)counts = self.data['cnt'][:50]self.x = torch.FloatTensor(np.arange(len(counts)))self.y = torch.FloatTensor(np.array(counts, dtype=float))self.size = 10self.weights = torch.randn((1, self.size), requires_grad=True)self.biases = torch.randn((self.size), requires_grad=True)self.weights2 = torch.randn((self.size, 1), requires_grad=True)def train(self):rate = 0.001losses = []x, y = self.x.view(50, -1), self.y.view(50, -1) # reshapefor num in range(30000):hidden = x * self.weights + self.biaseshidden = torch.sigmoid(hidden)predictions = hidden.mm(self.weights2)loss = torch.mean((predictions - y) ** 2)losses.append(loss.data.numpy())if num % 3000 == 0:print(f'loss: {loss}')loss.backward()self.weights.data.add_(- rate * self.weights.grad.data)self.biases.data.add_(- rate * self.biases.grad.data)self.weights2.data.add_(- rate * self.weights2.grad.data)self.weights.grad.data.zero_()self.biases.grad.data.zero_()self.weights2.grad.data.zero_()# plot loss#plot.plot(losses)#plot.xlabel('epoch')#plot.ylabel('loss')#plot.show()# plot predictx_data = x.data.numpy()plot.figure(figsize=(10, 7))xplot, = plot.plot(x_data, y.data.numpy(), 'o')yplot, = plot.plot(x_data, predictions.data.numpy())plot.xlabel('x')plot.ylabel('y')plot.legend([xplot, yplot], ['Data', 'prediction with 30000 epoch'])plot.show()def main():Bike().exec()if __name__ == '__main__':main()

拟合有问题,原因是拟合次数不够,为啥不够?从sklearn学习了解到,神经网络对输入参数敏感,一般来说需要对数据做标准化处理。具体来说,第一个隐层输出范围变成-50-50,0.0001学习率情况下100000次也不够,可以对数据做预处理,减小x跨度,变为0-1,可加快训练速度,进行如下改动再次训练

self.x = torch.FloatTensor(np.arange(len(counts))) / len(counts)

正确了,再取50个点预测一下

    def predict_and_plot(self):counts_predict = self.data['cnt'][50:100]x = torch.FloatTensor((np.arange(len(counts_predict), dtype=float) + 50) / 100)y = torch.FloatTensor(np.array(counts_predict, dtype=float))# num multiply replace matrix multiplyhidden = x.expand(self.size, len(x)).t() * self.weights.expand(len(x), self.size)hidden = torch.sigmoid(hidden)predictions = hidden.mm(self.weights2)loss = torch.mean((predictions - y) ** 2)print(f'loss: {loss}')x_data = x.data.numpy()plot.figure(figsize=(10, 7))xplot, = plot.plot(x_data, y.data.numpy(), 'o')yplot, = plot.plot(x_data, predictions.data.numpy())plot.xlabel('x')plot.ylabel('y')plot.legend([xplot, yplot], ['data', 'prediction'])plot.show()

预测失败,可能是过拟合

2 单车预测器2.0

2.1 数据预处理

通过上节学习和之前写的sklearn博客发现,神经网络训练前需要预处理数据,主要有1数值型变量需要范围标准化2数值型类型变量需处理为onehot。标准化可用sklearn的scaler,也可手动标准化,类型变量可用pd.get_dummies操作。直接开始操作

    def prepare_data_and_params_2(self):# type columns to dummyself.data = pd.read_csv(DATA_PATH)dummy_fields = ['season', 'weathersit', 'mnth', 'hr', 'weekday']for each in dummy_fields:dummies = pd.get_dummies(self.data[each], prefix=each, drop_first=False)self.data = pd.concat([self.data], dummies)drop_fields = ['season', 'weathersit', 'mnth', 'hr', 'weekday', 'instant', 'dteday', 'workingday', 'atemp']self.data = self.data.drop(drop_fields, axis=1)# decimal columns to scalerquant_features = ['cnt', 'temp', 'hum', 'windspeed']scaled_features = {}for each in quant_features:mean, std = self.data[each].mean(), self.data[each].std()scaled_features[each] = [mean, std]self.data.loc[:, each] = (self.data[each] - mean) / stdself.tr, self.te = self.data[:-21 * 24], self.data[-21 * 24:]target_fields = ['cnt', 'casual', 'registered']self.xtr, self.ytr = self.tr.drop(self.tr.drop[target_fields], axis=1), self.tr[target_fields]self.xte, self.yte = self.te.drop(self.te.drop[target_fields], axis=1), self.te[target_fields]self.x = self.xtr.valuesy = self.ytr.values.astype(float)self.y = np.reshape(y, [len(y), 1])        self.loss = []

2.2 构造神经网络

    def train_and_plot2(self):input_size = self.xtr.shape[1]hidden_size=10output_size=1batch_size=128neu = torch.nn.Sequential(torch.nn.Linear(input_size, hidden_size),torch.nn.Sigmoid(),torch.nn.Linear(hidden_size, output_size))cost = torch.nn.MSELoss()optimizer = torch.optim.SGD(neu.parameters(), lr=0.01)

2.3 数据批处理

为啥要批处理?如果数据太多,每个iter直接处理所有数据会比较慢

        for i in range(1000):batch_loss = []for start in range(0, len(self.x), batch_size):end = start + batch_size if start + batch_size < len(self.x) else len(self.x)xx = torch.FloatTensor(self.x[start:end])yy = torch.FloatTensor(self.y[start:end])predictions = neu(xx)loss = cost(predictions, yy)optimizer.zero_grad()loss.backward()optimizer.step()batch_loss.append(loss.data.numpy())if i % 100 == 0:self.loss.append(np.mean(batch_loss))print(i, np.mean(batch_loss))plot.plot(np.arange(len(self.loss)) * 100, self.loss)plot.xlabel('epoch')plot.ylabel('MSE')plot.show()

2.4 测试神经网络

原始数据是从2011-2012两个完整年,按教材,取2012最后21天作测试集预测

    def predict_and_plot2(self):targets = self.yte['cnt']targets = targets.values.reshape([len(targets), 1]).astype(float)x = torch.FloatTensor(self.xte.values.astype(float))y = torch.FloatTensor(targets)predict = self.neu(x)predict = predict.data.numpy()fig, ax = plot.subplots(figsize=(10, 7))mean, std = self.scaled_features['cnt']ax.plot(predict * std + mean, label='prediction')ax.plot(targets * std + mean, label='data')ax.legend()ax.set_xlabel('date-time')ax.set_ylabel('counts')dates = pd.to_datetime(self.rides.loc[self.te.index]['dteday'])dates = dates.apply(lambda d: d.strftime('%b %d'))ax.set_xticks(np.arange(len(dates))[12::24])ax.set_xticklabels(dates[12::24], rotation=45)plot.show()

发现2012最后21天前半段还行,后半段有差异,看日历发现临近圣诞节,可能不能用正常日程预测

2.5 改进与分析(重要)

这节有啥用?上节圣诞节预测不准,为啥?这节可以通过分析神经网络回答这个问题

怎么分析?本节主要通过分析神经网络参数来在底层寻找原因,帮助分析问题

在异常处将多个神经源绘制独自的曲线,绘制其图像,分析找原因,比如趋势相同,趋势相反这种曲线,重点分析对象。适用于神经元较少,可以一个一个神经元看,多了就不行了

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/602486.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux查找命令@which、find

目录 which概念语法作用 find概念语法按文件名查找按文件大小查找 作用演示一演示二演示三 通配符 which 概念 which 是一个常用的 Linux/Unix 命令&#xff0c;用于查找并显示指定命令的绝对路径。 语法 which 要查找的命令 》无参数。 》 which后面&#xff0c;跟要查找绝对…

【CentOS 7.9】死机卡住如何处理

一、解决办法 1.打开tty2 按下组合键&#xff1a;ctrl alt F2 进入 tty2 2.进入 root 权限 su root3.杀死该用户的所有进程&#xff08;相当于 windows 里面的注销用户&#xff09; 请注意&#xff0c;用户名应该全部使用小写字母&#xff0c;如我的用户名叫 Ragdoll&am…

摄像头视频录制程序使用教程(Win10)

摄像头视频录制程序-Win10 &#x1f957;介绍&#x1f35b;使用说明&#x1f6a9;config.json 说明&#x1f6a9;启动&#x1f6a9;关闭&#x1f6a9;什么时候开始录制&#xff1f;&#x1f6a9;什么时候触发录制&#xff1f;&#x1f6a9;调参 &#x1f957;介绍 检测画面变化…

Javaweb之Mybatis的基础操作之查询操作的详细解析

1.6 查询 1.6.1 根据ID查询 在员工管理的页面中&#xff0c;当我们进行更新数据时&#xff0c;会点击 “编辑” 按钮&#xff0c;然后此时会发送一个请求到服务端&#xff0c;会根据Id查询该员工信息&#xff0c;并将员工数据回显在页面上。 SQL语句&#xff1a; select id,…

大型语言模型的幻觉问题

1.什么是大模型幻觉&#xff1f; 在语言模型的背景下&#xff0c;幻觉指的是一本正经的胡说八道&#xff1a;看似流畅自然的表述&#xff0c;实则不符合事实或者是错误的。 幻觉现象的存在严重影响LLM应用的可靠性&#xff0c;本文将探讨大型语言模型(LLMs)的幻觉问题&#x…

求两个数之间的最小公约数

目录 前言 方法&#xff1a;求两个数之间的最小公约数 1.欧几里得算法 2.枚举法 3.公共因子积 4.更相减损术 5.Stein算法 解题&#xff1a;在链表中插入最大公约数 总结 前言 今天刷每日一题&#xff1a;2807. 在链表中插入最大公约数 - 力扣&#xff08;LeetCode&#xff09;…

基于X86的助力智慧船载监控系统

船载综合监控系统结合雷达、AIS、CCTV、GPS等探测技术&#xff0c;以及高度融合的实时态势与认知技术&#xff0c;实现对本船以及范围内船舶的有效监控&#xff0c;延伸岸基监控中心监管范围&#xff0c;保障行船安全&#xff0c;为船舶安全管理部门实现岸基可控的数据通信和动…

第 121 场 LeetCode 双周赛题解

A 大于等于顺序前缀和的最小缺失整数 模拟&#xff1a;先求最长顺序前缀的和 s s s &#xff0c;然后从 s s s 开始找没有出现在 n u m s nums nums 中的最小整数 class Solution { public:int missingInteger(vector<int> &nums) {unordered_set<int> vis(…

如何批量自定义视频画面尺寸

在视频制作和编辑过程中&#xff0c;对于视频画面尺寸的调整是一项常见的需求。有时候&#xff0c;为了适应不同的播放平台或满足特定的展示需求&#xff0c;我们需要对视频尺寸进行批量调整。那么&#xff0c;如何实现批量自定义视频画面尺寸呢&#xff1f;本文将为您揭示这一…

LLM之RAG实战(十三)| 利用MongoDB矢量搜索实现RAG高级检索

想象一下&#xff0c;你是一名侦探&#xff0c;身处庞大的信息世界&#xff0c;试图在堆积如山的数据中找到隐藏的一条重要线索&#xff0c;这就是检索增强生成&#xff08;RAG&#xff09;发挥作用的地方&#xff0c;它就像你在人工智能和语言模型世界中的可靠助手。但即使是最…

小心JDK20 ZipOutputStream

Oracle 團隊竟然這麽粗心&#xff0c;編譯JDK 20 時ZipOutputStream沒有編譯成功就發佈了。 所以這個20版本不可以使用ZipOutputStream。 GZIPInputStream 只能做最後的壓縮&#xff0c;不能添加多個附件ZipEntry。 下一個版本21不存在這個問題。 try(var zipOut new ZipOu…

数据分析——火车信息

任务目标 任务 1、整理火车发车信息数据&#xff0c;结果的表格形式为&#xff1a; 2、并输出最终的发车信息表 难点 1、多文件 一个文件夹&#xff0c;多个月的发车信息&#xff0c;一个excel&#xff0c;放一天的发车情况 2、数据表的格式特殊 如何分析表是一个难点 数…

案例102:基于微信小程序的旅游社交管理系统设计与实现

文末获取源码 开发语言&#xff1a;Java 框架&#xff1a;SSM JDK版本&#xff1a;JDK1.8 数据库&#xff1a;mysql 5.7 开发软件&#xff1a;eclipse/myeclipse/idea Maven包&#xff1a;Maven3.5.4 小程序框架&#xff1a;uniapp 小程序开发软件&#xff1a;HBuilder X 小程序…

解决VMware 虚拟机 ubuntu 20.04 异常关闭导致虚拟网卡 ens33 无法工作问题

问题描述 由于经常使用 SSH 远程链接 VMware 中的虚拟机 ubuntu&#xff0c;每次关闭都是挂起&#xff0c;时间久了&#xff0c;虚拟机运行有些卡顿了&#xff0c;此时可以通过 Linux 命令重启或者关闭 ubuntu&#xff0c;也可以之间使用 VMWare 中的【虚拟机】-- 【电源】-&g…

SiC电机控制器(逆变器)发展概况及技术方向

SiC电机控制器&#xff08;逆变器&#xff09;发展概况及技术方向 1.概述2.电动汽车动力系统设计趋势3.栅极驱动器和驱动电源配置4.结论 tips&#xff1a;资料来自网上搜集&#xff0c;仅供学习使用。 1.概述 2022年到2023年&#xff0c;第三代半导体碳化硅被推上了新的热潮。…

前端uniapp的tab选项卡for循环切换、开通VIP实战案例【带源码/最新】

目录 效果图图1图2 源码最后 这个案例是uniapp&#xff0c;同样也适用Vue项目&#xff0c;语法一样for循环&#xff0c;点击切换 效果图 图1 图2 源码 直接代码复制查看效果 <template><view class"my-helper-service-pass"><view class"tab…

第14课 利用openCV快速数豆豆

除了检测运动&#xff0c;openCV还能做许多有趣且实用的事情。其实openCV和FFmpeg一样都是宝藏开源项目&#xff0c;貌似简单的几行代码功能实现背后其实是复杂的算法在支撑。有志于深入学习的同学可以在入门后进一步研究算法的实现&#xff0c;一定会受益匪浅。 这节课&#…

(Python + Selenium4)Web自动化测试自学Day1

目录 文章声明⭐⭐⭐让我们开始今天的学习吧&#xff01;自动打开Chrome浏览器实现自动搜索元素定位常用的元素定位方式By.IDBy.CLASS_NAMEBy.TAG_NAMEBy.NAMEBy.LINK_TEXTBy.PARTIAL_LINK_TEXTBy.CSS_SELECTOR根据id定位根据class定位根据属性定位组合定位 By.XPATH 文章声明⭐…

#error 在C语言中的作用

1、#error命令是C/C语言的预处理命令之一 #error 是C语言中的预处理指令之一&#xff0c;用于在编译时生成一个错误消息。当编译器遇到 #error 指令时&#xff0c;会立即停止编译&#xff0c;并将指定的错误消息输出到编译器的错误信息中。 在给定的代码中&#xff0c;#error…

玩转Mysql 二(MySQL的目录结构与表结构)

一路走来,所有遇到的人,帮助过我的、伤害过我的都是朋友,没有一个是敌人。 一、MYSQL目录结构及命令存放路径 1、查看MYSQL数据文件存放路径 mysql> show variables like datadir; 注意:生成环境要提前规划好数据存放目录,存储一般以T为单位闪盘。 2、MYSQL命令存放…