pytorch神经网络因素预测_实战:使用PyTorch构建神经网络进行房价预测

微信公号:ilulaoshi / 个人网站:lulaoshi.info

本文将学习一下如何使用PyTorch创建一个前馈神经网络(或者叫做多层感知机,Multiple-Layer Perceptron,MLP),文中会使用PyTorch提供的自动求导功能,训练一个神经网络。

本文的数据集来自Kaggle竞赛:房价预测(https://www.kaggle.com/c/house-prices-advanced-regression-techniques/)。这份数据分为训练数据集和测试数据集。两个数据集都包括每栋房子的特征,如建造年份、地下室状况等特征值。这些特征中,有连续的数值型(Numerical)特征,有离散的分类(Categorical)特征。这些特征中,有些特征值是缺失值“na”。训练数据集包括了每栋房子的价格,也就是需要预测的目标值(Label)。我们应该用训练数据集训练一个模型,并对测试数据集进行预测,然后将结果提交到Kaggle。

数据探索和预处理

首先,我们下载并加载数据集:

train_data_path ='./dataset/train.csv'

train = pd.read_csv(train_data_path)

num_of_train_data = train.shape[0]

test_data_path ='./dataset/test.csv'

test = pd.read_csv(test_data_path)

训练数据集共1460个样本,81个维度,其中,Id是每个样本的唯一编号,SalePrice是房价,也是我们要拟合的目标值。其他维度(列)有数值类特征,也有非数值列,或者叫分类特征。

先查看训练数据集的维度:

train.shape

输出为:

(1460, 81)

或者通过train.describe()来查看整个数据集各个特征的一些统计情况。

接着,我们要把训练数据集和测试数据集合并。将训练数据集和测试数据集合并主要是为了统一特征处理的流程,或者说对训练数据集和测试数据集使用同样的方法,进行同样的特征工程处理。

# 房价,要拟合的目标值

target = train.SalePrice

# 输入特征,可以将SalePrice列扔掉

train.drop(['SalePrice'],axis = 1 , inplace = True)

# 将train和test合并到一起,一块进行特征工程,方便预测test的房价

combined = train.append(test)

combined.reset_index(inplace=True)

combined.drop(['index', 'Id'], inplace=True, axis=1)

接着就要开始进行特征工程了。本文没有进行任何复杂的特征工程,只做了两件事:1、过滤掉了含有缺失值的列;2、对分类特征进行了One-Hot编码。缺失值会在一定程度上影响算法的预测效果,一般可以使用一些默认值或者一些临近值来填充缺失值。对于MLP模型,分类特征必须经过编码,转换成数值才能进行模型训练,One-Hot编码是一种最常见的分类特征处理的方法。

我们用下面的函数过滤非空列:

# 选出非空列

def get_cols_with_no_nans(df,col_type):

'''

Arguments :

df : The dataframe to process

col_type :

num : to only get numerical columns with no nans

no_num : to only get nun-numerical columns with no nans

all : to get any columns with no nans

'''

if (col_type == 'num'):

predictors = df.select_dtypes(exclude=['object'])

elif (col_type == 'no_num'):

predictors = df.select_dtypes(include=['object'])

elif (col_type == 'all'):

predictors = df

else :

print('Error : choose a type (num, no_num, all)')

return 0

cols_with_no_nans = []

for col in predictors.columns:

if not df[col].isnull().any():

cols_with_no_nans.append(col)

return cols_with_no_nans

分别对数值特征和分类特征进行处理:

num_cols = get_cols_with_no_nans(combined, 'num')

cat_cols = get_cols_with_no_nans(combined, 'no_num')

# 过滤掉含有缺失值的特征

combined = combined[num_cols + cat_cols]

print(num_cols[:5])

print ('Number of numerical columns with no nan values: ',len(num_cols))

print(cat_cols[:5])

print ('Number of non-numerical columns with no nan values: ',len(cat_cols))

经过过滤,数值特征共有25列,分类特征共有20列,共45列。

# 对分类特征进行One-Hot编码

def oneHotEncode(df,colNames):

for col in colNames:

if( df[col].dtype == np.dtype('object')):

# pandas.get_dummies 可以对分类特征进行One-Hot编码

dummies = pd.get_dummies(df[col],prefix=col)

df = pd.concat([df,dummies],axis=1)

# drop the encoded column

df.drop([col],axis = 1 , inplace=True)

return df

对于分类特征,还需要进行One-Hot编码,pandas.get_dummies可以帮我们自动完成One-Hot编码过程。经过One-Hot编码后,数据增加了很多列,共有149列。

至此,我们完成了一次非常简单的特征工程,将这些数据转化为PyTorch模型所能接受的Tensor形式:

# 训练数据集特征

train_features = torch.tensor(combined[:num_of_train_data].values, dtype=torch.float)

# 训练数据集目标

train_labels = torch.tensor(target.values, dtype=torch.float).view(-1, 1)

# 测试数据集特征

test_features = torch.tensor(combined[num_of_train_data:].values, dtype=torch.float)

print("train data size: ", train_features.shape)

print("label data size: ", train_labels.shape)

print("test data size: ", test_features.shape)

构建神经网络

接着,我们开始构建神经网络。

在PyTorch中构建神经网络有两种方式。比较简单的前馈网络,可以使用nn.Sequential。nn.Sequential是一个存放神经网络的容器,直接在nn.Sequential里面添加我们需要的层即可。整个模型的输入为特征数,输出为一个标量。模型的隐藏层使用了ReLU激活函数,最后一层是一个线性层,得到的是一个预测的房价值。

model_sequential = nn.Sequential(

nn.Linear(train_features.shape[1], 128),

nn.ReLU(),

nn.Linear(128, 256),

nn.ReLU(),

nn.Linear(256, 256),

nn.ReLU(),

nn.Linear(256, 256),

nn.ReLU(),

nn.Linear(256, 1)

)

另一种构建神经网络的方式是继承nn.Module类,我们将子类起名为Net类。__init__()方法为Net类的构造函数,用来初始化神经网络各层的参数;forward()也是我们必须实现的方法,主要用来实现神经网络的前向传播过程。

class Net(nn.Module):

def __init__(self, features):

super(Net, self).__init__()

self.linear_relu1 = nn.Linear(features, 128)

self.linear_relu2 = nn.Linear(128, 256)

self.linear_relu3 = nn.Linear(256, 256)

self.linear_relu4 = nn.Linear(256, 256)

self.linear5 = nn.Linear(256, 1)

def forward(self, x):

y_pred = self.linear_relu1(x)

y_pred = nn.functional.relu(y_pred)

y_pred = self.linear_relu2(y_pred)

y_pred = nn.functional.relu(y_pred)

y_pred = self.linear_relu3(y_pred)

y_pred = nn.functional.relu(y_pred)

y_pred = self.linear_relu4(y_pred)

y_pred = nn.functional.relu(y_pred)

y_pred = self.linear5(y_pred)

return y_pred

我们已经定义好了一个神经网络的Net类,还要初始化一个Net类的对象实例model,表示某个具体的模型。然后定义损失函数,这里使用MSELoss,MSELoss使用了均方误差(Mean Square Error)来衡量损失函数。对于模型model的训练过程,这里使用Adam算法。Adam是优化算法中的一种,在很多场景中效率要优于SGD。

model = Net(features=train_features.shape[1])

# 使用均方误差作为损失函数

criterion = nn.MSELoss(reduction='mean')

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

训练模型

接着,我们使用Adam算法进行多轮的迭代,更新模型model中的参数。这里对模型进行500轮的迭代。

losses = []

# 训练500轮

for t in range(500):

y_pred = model(train_features)

loss = criterion(y_pred, train_labels)

# print(t, loss.item())

losses.append(loss.item())

if torch.isnan(loss):

break

# 将模型中各参数的梯度清零。

# PyTorch的backward()方法计算梯度会默认将本次计算的梯度与缓存中已有的梯度加和。

# 必须在反向传播前先清零。

optimizer.zero_grad()

# 反向传播,计算各参数对于损失loss的梯度

loss.backward()

# 根据刚刚反向传播得到的梯度更新模型参数

optimizer.step()

每次迭代使用训练数据集中的所有样本train_features。model(train_features)实际是执行的model.forward(train_features),即forward()方法中定义的前向传播逻辑,输入数据在神经网络模型中前向传播,得到预测值y_pred。criterion(y_pred, train_labels)方法计算了预测值y_pred和目标值train_labels之间的损失。

每次迭代时,我们要先对模型中各参数的梯度清零:optimizer.zero_grad()。PyTorch中的backward()默认是把本次计算的梯度和缓存中已有的梯度加和,因此必须在反向传播前先将梯度清零。接着执行backward()方法,完成反向传播过程,PyTorch会帮我们计算各参数对于损失函数的梯度。optimizer.step()会根据刚刚反向传播得到的梯度,更新模型参数。

至此,一个简单的预测房价的模型就训练好了。

测试模型

我们可以使用模型对测试数据集进行预测,将得到的预测值保存成文件,提交到Kaggle上。

predictions = model(test_features).detach().numpy()

my_submission = pd.DataFrame({'Id':pd.read_csv('./dataset/test.csv').Id,'SalePrice': predictions[:, 0]})

my_submission.to_csv('{}.csv'.format('./dataset/submission'), index=False)

参考资料

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/457549.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SQL基本操作

SQL 操作 检索数据 SELECT 检索数据 -- 检索单个列 SELECT 列名 FROM table_name;-- 检索多个列 SELECT 列1, 列2 FROM table_name;-- 检索所有列 SELECT * FROM table_name;-- 检索不同的值 SELECT DISTINCT 列名 FROM table_name;限制检索结果 -- SQL Server / Access SE…

git 忽略 部分文件夹_git提交忽略某些文件或文件夹

记得第一次用 github 提交代码,node_modules 目录死活传不上去,哈哈哈,后来才知道在 .gitignore 文件里设置了忽略 node_modules 目录上传。是的, .gitignore 文件就是设置那些你不想用 git 一起上传的文件和文件夹。比如刚接触到…

Ajax实现原理详解

Ajax:Asynchronous javascript and xml,实现了客户端与服务器进行数据交流过程。使用技术的好处是:不用页面刷新,并且在等待页面传输数据的同时可以进行其他操作。 这就是异步调用的很好体现。首先得了解什么是异步和同步的概念。…

SpringJDBC解析3-回调函数(update为例)

PreparedStatementCallback作为一个接口,其中只有一个函数doInPrepatedStatement,这个函数是用于调用通用方法execute的时候无法处理的一些个性化处理方法,在update中的函数实现: protected int update(final PreparedStatementCr…

python上下文管理器

DAY 23. python上下文管理器 Python 的 with 语句支持通过上下文管理器所定义的运行时上下文这一概念。 此对象的实现使用了一对专门方法,允许用户自定义类来定义运行时上下文,在语句体被执行前进入该上下文,并在语句执行完毕时退出该上下文&…

勾股定理python思路_趣叮咚编程数学揭秘:为什么勾股定理a+b=c?

我们都知道:三角形3个外角之和360度可是谁知道为什么等于360度呢?其实利用编程制作动图演绎了解啦:那勾股定理abc又是为什么呢?还有很多有趣的数学公式都可以演绎:圆的面积公式、圆周长...通过动图演绎原来晦涩难懂的定…

System.InvalidOperationException : 不应有 Response xmlns=''。

xml如下&#xff1a; <?xml version"1.0" encoding"UTF-8"?> <Response version"2"><datacash_reference>4700203048783633</datacash_reference><information>Failed to identify the card scheme of the supp…

Navicat Premium连接SQL Server

Navicat Premium连接SQL Server 步骤&#xff1a; 激活SQL Server 服务配置SQL Server网络配置连接SQL Server 激活SQLServer服务 直接搜索 计算机管理 点 服务和应用程序&#xff0c; 点 SQL Server配置管理器&#xff0c; 双击第一个SQL Server服务 不出意外的话&#xf…

mysql 单标递归_MySql8 WITH RECURSIVE递归查询父子集的方法

背景开发过程中遇到类似评论的功能是&#xff0c;需要时用查询所有评论的子集。不同数据库中实现方式也不同&#xff0c;本文使用Mysql数据库&#xff0c;版本为8.0Oracle数据库中可使用START [Param] CONNECT BY PRIORMysql 中需要使用 WITH RECURSIVE需求找到name为张三的孩子…

processon完全装逼指南

一、引言 作为一名IT从业者&#xff0c;不仅要有扎实的知识储备&#xff0c;出色的业务能力&#xff0c;还需要具备一定的软实力。软实力体现在具体事务的处理能力&#xff0c;包括沟通&#xff0c;协作&#xff0c;团队领导&#xff0c;问题的解决方案等&#xff0c;这些能力在…

mysql在空闲8小时之后会断开连接(默认情况)

调试程序的过程发现&#xff0c;在mysql连接空闲一定时间&#xff08;默认8小时&#xff09;之后会断开连接&#xff0c;需要重新连接&#xff0c;也引发我对重连机制的思考。转载于:https://www.cnblogs.com/ppzbty/p/5707576.html

selector多路复用_多路复用器Selector

Unix系统有五种IO模型分别是阻塞IO(blocking IO)&#xff0c;非阻塞IO( non-blocking IO)&#xff0c;IO多路复用(IO multiplexing)&#xff0c;信号驱动(SIGIO/Signal IO)和异步IO(Asynchronous IO)。而IO多路复用通常有select&#xff0c;poll&#xff0c;epoll&#xff0c;k…

解决svn log显示no author,no date的方法之一

只要把svnserve.conf中的anon-access read 的read 改为none&#xff0c;也不需要重启svnserve就行 sh-4.1# grep "none" /var/www/html/svn/pro/conf/svnserve.conf ### and "none". The sample settings below are the defaults. anon-access none转载…

REST framework 权限管理源码分析

REST framework 权限管理源码分析 同认证一样&#xff0c;dispatch()作为入口&#xff0c;从self.initial(request, *args, **kwargs)进入initial() def initial(self, request, *args, **kwargs):# .......# 用户认证self.perform_authentication(request)# 权限控制self.che…

解决larave-dompdf中文字体显示问题

0、使用MPDF dompdf个人感觉没有那么好用&#xff0c;最终的生产环境使用的是MPDF&#xff0c;github上有文档说明。如果你坚持使用&#xff0c;下面是解决办法。可以明确的说&#xff0c;中文乱码是可以解决的。 1、安装laravel-dompdf依赖。 Packagist&#xff1a;https://pa…

mfc程序转化为qt_小峰的QT学习笔记

我的专业是输电线路&#xff0c;上个学期&#xff0c;我们开了一门架空线路设计基础的课&#xff0c;当时有一个大作业是计算线路的比载&#xff0c;临界档距&#xff0c;弧垂最低点和安装曲线。恰逢一门结课考试结束&#xff0c;大作业ddl快到&#xff0c;我和另外两个同专业的…

MS SQL的存储过程

-- -- Author: -- Create date: 2016-07-01 -- Description: 注册信息 -- ALTER PROCEDURE [dbo].[sp_MebUser_Register]( UserType INT, MobileNumber VARCHAR(11), MobileCode VARCHAR(50), LoginPwd VARCHAR(50), PayPwd VARCHAR(50), PlateNumber VARCHAR(20), UserTr…

mysql 中 all any some 用法

-- 建表语句 CREATE TABLE score(id INT PRIMARY KEY AUTO_INCREMENT,NAME VARCHAR(20),SUBJECT VARCHAR(20),score INT);-- 添加数据 INSERT INTO score VALUES (NULL,张三,语文,81), (NULL,张三,数学,75), (NULL,李四,语文,76), (NULL,李四,数学,90), (NULL,王五,语文,81), (…

REST framework 用户认证源码

REST 用户认证源码 在Django中&#xff0c;从URL调度器中过来的HTTPRequest会传递给disatch(),使用REST后也一样 # REST的dispatch def dispatch(self, request, *args, **kwargs):""".dispatch() is pretty much the same as Djangos regular dispatch,but w…

scrapyd部署_如何通过 Scrapyd + ScrapydWeb 简单高效地部署和监控分布式爬虫项目

来自 Scrapy 官方账号的推荐需求分析初级用户&#xff1a;只有一台开发主机能够通过 Scrapyd-client 打包和部署 Scrapy 爬虫项目&#xff0c;以及通过 Scrapyd JSON API 来控制爬虫&#xff0c;感觉 命令行操作太麻烦 &#xff0c;希望能够通过浏览器直接部署和运行项目专业用…