动手学深度学习——实战Kaggle比赛:预测房价

  1. 导入包,建立字典DATA_HUB包含数据集的url和验证文件完整性的sha-1密钥。
  2. 定义download()函数用来下载数据集, 将数据集缓存在本地目录(默认情况下为…/data)中, 并返回下载文件的名称。
  3. 定义download_extract()函数下载并解压缩一个zip或tar文件,定义download_all()将所有数据集从DATA_HUB下载到缓存目录中。
  4. 进行数据标准化得到all_features
  5. 定义损失函数loss和线性模型net
  6. 采用价格预测的对数log_rmse()来衡量差异
  7. 定义训练函数train(),采用Adam优化器
  8. 定义K折交叉验证get_k_fold_data(),并且在K折交叉验证中训练K次k_fold()
  9. 最后进行预测并保存预测文件train_and_pred()
#!/usr/bin/env python
# coding: utf-8# 导入所需要的包
import hashlib
import os
import tarfile
import zipfile
import requests#@save
# DATA_HUB为二元组:包含数据集的url和验证文件完整性的sha-1密钥
DATA_HUB = dict()# 数据集托管在DATA_URL的站点上
DATA_URL = 'http://d2l-data.s3-accelerate.amazonaws.com/'"""定义download函数:1、下载数据集,将数据集缓存在本地目录(../data),并返回下载文件的名称2、如果缓存目录中存在此数据集文件,并且与sha-1与存储在DATA_HUB中相匹配,则使用缓存的文件
"""
# download(文件名称,缓存目录)
def download(name, cache_dir=os.path.join('..', 'data')): #@save"""下载一个DATA_HUB中的文件,返回本地文件名"""assert name in DATA_HUB, f"{name} 不存在于 {DATA_HUB}"url, sha1_hash = DATA_HUB[name]os.makedirs(cache_dir, exist_ok=True)fname = os.path.join(cache_dir, url.split('/')[-1])if os.path.exists(fname):sha1 = hashlib.sha1()with open(fname, 'rb') as f:while True:data = f.read(1048576)if not data:breaksha1.update(data)if sha1.hexdigest() == sha1_hash:return fname # 命中缓存print(f'正在从{url}下载{fname}...')r = requests.get(url, stream=True, verify=True)with open(fname, 'wb') as f:f.write(r.content)return fname"""实现两个实用函数:1、一个将下载并解压缩一个zip/tar文件2、将使用的数据集从DATA_HUB下载到缓存目录
"""
def download_extract(name, folder=None): #@save"""下载并解压zip/tar文件"""fname = dowanload(name)base_dir = os.path.dirname(fname)data_dir = os.path.splitext(fname)if ext == '.zip':fp = zipfile.Zipfile(fname, 'r')elif ext in ('.tar', '.gz'):fp = tarfile.open(fname, 'r')else:assert False, '只有zip/tar文件可以解压'fp.extractall(base_dir)return os.path.join(base_dir, folder) if folder else data_dirdef download_all(): #@save"""下载DATA_HUB中的所有文件"""for name in DATA_HUB:download(name)# 如果没有安装pandas,请取消下一行的注释
# !pip install pandasget_ipython().run_line_magic('matplotlib', 'inline')
import numpy as np
import pandas as pd
import torch
from torch import nn
from d2l import torch as d2l# 使用上面定义的脚本下载并缓存Kaggle房屋数据集
DATA_HUB['kaggle_house_train'] = (  #@saveDATA_URL + 'kaggle_house_pred_train.csv','585e9cc93e70b39160e7921475f9bcd7d31219ce')DATA_HUB['kaggle_house_test'] = (  #@saveDATA_URL + 'kaggle_house_pred_test.csv','fa19780a7b011d9b009e8bff8e99922a8ee2eb90')# 使用pandas分布加载包含训练数据和测试数据的两个csv文件
train_data = pd.read_csv(download('kaggle_house_train'))
test_data = pd.read_csv(download('kaggle_house_test'))# 训练数据集包括1460个样本,每个样本80个特征和1个标签
# 测试数据集包括1459个样本,每个样本80个特征
print(train_data.shape)
print(test_data.shape)# 前四个和最后两个特征,以及相应标签(房价)
print(train_data.iloc[0:4, [0, 1, 2, 3, -3, -2, -1]])# 对于每个样本:删除第一个特征ID,因为其不携带任何用于预测的信息
all_features = pd.concat((train_data.iloc[:, 1:-1], test_data.iloc[:, 1:-1]))"""数据预处理:1、将所有缺失的值替换为相应特征的平均值2、为了将所有数据放在共同的尺度上,通过特征重新缩放到零均值和单位方差来标准化数据(1)方便优化(2)不知道哪些特征是相关的,所以不想让惩罚分配给一个特征的系数比分配给其他特征的系数更大3、处理离散值,用独热编码来替换。如"MSZoning_RL"为1,"MSZoning_RM"为0
"""
# 若无法获得测试数据,则可根据训练数据计算均值和标准差:x←(x-μ)/σ
# 获取无法获得测试数据的数量
numeric_features = all_features.dtypes[all_features.dtypes != 'object'].index
all_features[numeric_features] = all_features[numeric_features].apply(lambda x: (x - x.mean()) / (x.std())) 
# 在标准化数据之后,所有均值消失,因此我们可以将缺失值设置为0
all_features[numeric_features] = all_features[numeric_features].fillna(0)#“Dummy_na=True”将“na”(缺失值)视为有效的特征值,并为其创建指示符特征
# all_features是删除ID那一列之后,将每个样本中所有的特征连接起来
all_features = pd.get_dummies(all_features, dummy_na=True)
all_features.shape# 通过values属性将数据从pandas格式提取numpy格式,并将其转为张量用于训练
n_train = train_data.shape[0]
train_features = torch.tensor(all_features[:n_train].values, dtype=torch.float32)
test_features = torch.tensor(all_features[n_train:].values, dtype=torch.float32)
train_labels = torch.tensor(train_data.SalePrice.values.reshape(-1, 1), dtype=torch.float32)"""训练一个带有损失平方的线性模型:1、损失函数为损失平方2、线性模型作为基线模型
"""
loss = nn.MSELoss()
in_features = train_features.shape[1]def get_net():net = nn.Sequential(nn.Linear(in_features,1))return net# 采用价格预测的对数来衡量差异:√ ̄(1/n*(∑(logy - logy')^2))
def log_rmse(net, features, labels):# 为了在取对数时进一步稳定该值,将小于1的值设置为1clipped_preds = torch.clamp(net(features), 1, float('inf'))rmse = torch.sqrt(loss(torch.log(clipped_preds),torch.log(labels)))return rmse.item()# 优化器借助Adam优化器
"""定义训练函数:1、加载训练数据集2、使用Adam优化器(对初始学习率不那么敏感)3、进行训练:计算损失,进行梯度优化,返回训练损失和测试损失
"""
def train(net, train_features, train_labels, test_features, test_labels,num_epochs, learning_rate, weight_decay, batch_size):train_ls, test_ls = [], []train_iter = d2l.load_array((train_features, train_labels), batch_size)# 这里使用的是Adam优化算法optimizer = torch.optim.Adam(net.parameters(),lr = learning_rate,weight_decay = weight_decay)for epoch in range(num_epochs):for X,y in train_iter:optimizer.zero_grad()l = loss(net(X), y)l.backward()optimizer.step()train_ls.append(log_rmse(net, train_features, train_labels))if test_labels is not None:test_ls.append(log_rmse(net, test_features, test_labels))return train_ls,  test_ls"""定义K折交叉验证:1、当k > 1时,进行K折交叉验证,将数据集分为K份2、选择第i个切片作为验证集,其余部分作为训练数据3、第一片的训练数据直接填进去,之后的使用cat进行相连接
""" 
def get_k_fold_data(k, i, X, y):assert k > 1fold_size = X.shape[0] // k X_train, y_train = None, Nonefor j in range(k):idx = slice(j * fold_size, (j + 1) * fold_size)X_part, y_part = X[idx, :], y[idx]if j == i:X_valid, y_valid = X_part, y_partelif X_train is None:X_train, y_train = X_part, y_partelse:X_train = torch.cat([X_train, X_part], 0) y_train = torch.cat([y_train, y_part], 0)return X_train, y_train, X_valid, y_valid"""在K折交叉验证中训练K次:1、返回训练和验证误差的平均值2、可视化训练误差和验证误差
"""
def k_fold(k, X_train, y_train, num_epochs, learning_rate, weight_decay,batch_size):train_l_sum, valid_l_sum = 0, 0for i in range(k):data = get_k_fold_data(k, i, X_train, y_train)net = get_net()train_ls, valid_ls = train(net, *data, num_epochs, learning_rate,weight_decay, batch_size)train_l_sum += train_ls[-1]valid_l_sum += valid_ls[-1]if i == 0:d2l.plot(list(range(1, num_epochs + 1)), [train_ls, valid_ls],xlabel='epoch', ylabel='rmse', xlim=[1, num_epochs],legend=['train', 'valid'], yscale='log')print(f'折{i + 1}, 训练log rmse{float(train_ls[-1]):f},'f'验证log rmse{float(valid_ls[-1]):f}')return train_l_sum / k, valid_l_sum / k# 模型选择
k, num_epochs, lr, weight_decay, batch_size = 5, 100, 5, 0, 64
train_l, valid_l = k_fold(k, train_features, train_labels, num_epochs, lr, weight_decay,batch_size)
print(f'{k}-折验证:平均训练log rmse:{float(train_l):f},'f'平均验证log rmse:{float(valid_l):f}')"""提交Kaggle预测:1、使用所有数据进行训练,得到模型2、该模型可以应用到测试集上,将预测保存在csv文件
"""
def train_and_pred(train_features, test_features, train_labels, test_data,num_epochs, lr, weight_decay, batch_size):net = get_net()train_ls, _ = train(net, train_features, train_labels, None, None,num_epochs, lr, weight_decay, batch_size)d2l.plot(np.arange(1, num_epochs + 1), [train_ls], xlabel='epoch',ylabel='log rmse', xlim=[1, num_epochs], yscale='log')print(f'训练log rmse:{float(train_ls[-1]):f}')# 将网络应用与测试集。preds = net(test_features).detach().numpy()# 将其重新格式化以导出到Kaggletest_data['SalePrice'] = pd.Series(preds.reshape(1, -1)[0])submission = pd.concat([test_data['Id'], test_data['SalePrice']], axis=1)submission.to_csv('submission.csv', index=False)# 预测
train_and_pred(train_features, test_features, train_labels, test_data,num_epochs, lr, weight_decay, batch_size)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/7402.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数字孪生:未来科技的新前沿

数字孪生作为一项新兴的研究方向,正逐渐成为科技界的焦点。它是将现实世界中的实体、系统或过程通过数字化手段进行建模、仿真和分析,形成与实体相对应的数字化副本。数字孪生的发展为我们带来了无限的想象空间,以及解决现实问题的新途径。 在…

java.net.ConnectException: Connection refused: no further information

报错如下: java.net.ConnectException: Connection refused: no further informationat sun.nio.ch.SocketChannelImpl.checkConnect(Native Method) ~[na:1.8.0_181]at sun.nio.ch.SocketChannelImpl.finishConnect(SocketChannelImpl.java:717) ~[na:1.8.0_181]a…

Zabbix监控安装grafana并配置图形操作

第三阶段基础 时 间:2023年7月20日 参加人:全班人员 内 容: Zabbix监控安装grafana 目录 安装并配置grafana 一、安装Grafana 二、下载安装插件 三、配置grafana 四、Web访问并配置: 安装并配置grafana 一、安装Graf…

【团队协作开发】将Gitee项目导入到本地IDEA中出现根目录不完整的问题解决(已解决)

前言:在团队协作开发过程中,通常我们的Gitee完整项目中会包含很多内容:后端代码、前端代码、项目结构图、项目文档等一系列资产。 将Gitee项目导入到本地IDEA中,通常会出现根目录不完整的问题。这是因为项目里面包含了后端代码、前…

如何使用Webman框架实现日历和事件提醒功能?

如何使用Webman框架实现日历和事件提醒功能? 引言: 在现代社会中,时间管理变得越来越重要。作为开发者,我们可以利用Webman框架来构建一个功能强大的日历应用程序,帮助人们更好地管理自己的时间。本文将介绍如何使用W…

基于Java+SpringBoot+vue前后端分离甘肃非物质文化网站设计实现

博主介绍:✌全网粉丝30W,csdn特邀作者、博客专家、CSDN新星计划导师、Java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专…

DAY51:动态规划(十五)买卖股票最佳时机Ⅲ+买卖股票最佳时期Ⅳ

文章目录 123.买卖股票最佳时机Ⅲ(注意初始化)思路DP数组含义递推公式初始化遍历顺序最开始的写法:初始化全部写成0debug测试:解答错误,第0天实际上是对应prices[0]和dp[0] 完整版总结 188.买卖股票最佳时机Ⅳ思路DP数…

09.计算机网络——套接字编程

文章目录 网络字节序socket编程socket 常见APIsockaddr结构 UDP编程创建socket绑定socketsendto发送数据recvform接收数据关闭socket TCP编程创建socket绑定socketlisten监听套接字accept服务端接收连接套接字connect客户端连接套接字send发送数据recv接收数据关闭socket 工具n…

【flink】ColumnarRowData

列式存储 在调试flink读取parquet文件时,读出来的数据是ColumnarRowData,由于parquet是列式存储的文件格式,所以需要用一种列式存储的表示方式,ColumnarRowData就是用来表示列式存储的一行数据,它包含多个数组的数据结…

从电商指标洞察到运营归因,只需几句话?AI 数智助理准备好了!

Lily 是名入职不久的电商运营助理,最近她想要根据 2022 年的客单价情况,分析品牌 A 在不同电商渠道的用户行为和表现,并提供一些有价值的洞察和建议给客户。然而在向技术人员提报表需求后,技术人员以需求排满为借口拒绝了。 Lily …

5分钟,结合 LangChain 搭建自己的生成式智能问答系统

伴随大语言模型(LLM,Large Language Model)的涌现,人们发现生成式人工智能在非常多领域具有重要意义,如图像生成,书写文稿,信息搜索等。随着 LLM 场景的多样化,大家希望 LLM 能在垂直…

记一次容器环境下出现 Address not available

作者:郑明泉、余凯 困惑的源地址 pod 创建后一段时间一直是正常运行,突然有一天发现没有新的连接创建了,业务上是通过 pod A 访问 svc B 的 svc name 的方式,进入 pod 手动去 wget 一下,发现报错了 Address not avai…

jar 更新 jar包内的 class,以及如何修改class

一、提取Jar 内文件 #提取jar内的配置文件jar -xvf a.jar META-INF\plugin.xml-已解压: META-INF/plugin.xml#提取jar内的class文件, 提示:反编译为java文件,修改后再使用javac xxx.java编译为class,jar -xvf a.jar io.config.**…

TCP长连接和短连接

tcp长连接和短连接 1. TCP短连接2. TCP长连接3. TCP长/短连接操作过程3.1 短连接的操作步骤3.2 长连接的操作步骤 4. TCP长/短连接的优点和缺点5. TCP长/短连接的应用场景 TCP在真正的读写操作之前,server 与 client之间必须建立一个连接,当读写操作完成…

Android中的ImageView设置图片显示有哪几种模式,有什么区别?

Android中的ImageView设置图片显示有哪几种模式,有什么区别? 在 Android 中,ImageView 是显示图像的视图控件,提供了多种图片显示模式(ScaleType)来控制图片的展示方式。不同的图片显示模式适用于不同的场…

全面解析缓存应用经典问题

1、前言 随着互联网从简单的单向浏览请求,发展为基于用户个性信息的定制化以及社交化的请求,这要求产品需要做到以用户和关系为基础,对海量数据进行分析和计算。对于后端服务来说,意味着用户的每次请求都需要查询用户的个人信息和…

使用frp实现公网使用https访问exsi控制台

目录 背景方法esxi配置上传替换证书重启相关服务 frp配置frps配置frpc配置重启服务 完成 背景 esxi控制台默认是通过https登陆的,但是因为它默认的证书是自签名的,所以在浏览器会标记为红色的叉;同时这对于配置安全的公网访问来说也是必须要…

单例模式类设计|什么是饿汉模式和懒汉模式

前言 那么这里博主先安利一些干货满满的专栏了! 首先是博主的高质量博客的汇总,这个专栏里面的博客,都是博主最最用心写的一部分,干货满满,希望对大家有帮助。 高质量干货博客汇总https://blog.csdn.net/yu_cblog/c…

Started CityManagementApplication in 0.982 seconds (JVM running for 1.97)

在pom文件里&#xff0c;添加依赖&#xff1a; <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-web</artifactId> </dependency>参考了这个作者&#xff08;zhttp://t.csdn.cn/fo5J2&#xff0…

在Vue-Element中引入jQuery的方法

一、在终端窗口执行安装命令 npm install jquery --save执行完后&#xff0c;npm会自动在package.json中加上jquery 二、在main.js中引入&#xff08;或者在需要使用的页面中引入即可&#xff09; import $ from jquery三、使用jquery