线性回归csv数据集_用mxnet的gluon线性回归训练只有两个特征的数据集

前言

自从上次试着用最基础的线性回归训练一个有80个特征的数据集,梯度爆炸之后,今天拿一个简单到不能再简单的数据集试试能不能成功收敛。途中我们又会遇到什么问题?

数据集

来自吴恩达机器学习课程第二周的课后练习。原本是txt文件,我通过下面三行代码把数据集另存为了csv,可以在这里下载。

import pandas as pd
df = pd.read_csv("ex1data2.txt",delimiter=',')
df.columns=['size','bedroom','price']
df.to_csv('house_simple.csv')

读取数据集

数据没有分训练集和测试集,房子的特征只有面积和房间数两个。 我们将通过pandas库读取并处理数据

导入这里需要的包

%matplotlib inline
import d2lzh as d2l
from mxnet import autograd, gluon, init, nd
from mxnet.gluon import data as gdata, loss as gloss, nn
import numpy as np
import pandas as pd
data = pd.read_csv('data/house/house_2_features.csv' ,index_col=0)
data.head()

size bedroom price 0 1600 3 329900 1 2400 3 369000 2 1416 2 232000 3 3000 4 539900 4 1985 4 299900

data.shape
(46, 3)

预处理数据集

我们对连续数值的特征做标准化(standardization):设该特征在整个数据集上的均值为$mu$,标准差为$sigma$。那么,我们可以将该特征的每个值先减去$mu$再除以$sigma$得到标准化后的每个特征值。对于缺失的特征值,我们将其替换成该特征的均值。

data = data.apply(lambda x: (x - x.mean()) / (x.std()))data.fillna(0);

标准化后,每个特征的均值变为0,所以可以直接用0来替换缺失值。

data.head()

size bedroom price 0 -0.495977 -0.226166 -0.073110 1 0.499874 -0.226166 0.236953 2 -0.725023 -1.526618 -0.849457 3 1.246762 1.074287 1.592190 4 -0.016724 1.074287 -0.311010

把数据集分成两部分,训练集和测试集,并通过values属性得到NumPy格式的数据,并转成NDArray方便后面的训练。

n_train=36
train_features = nd.array(data[['size','bedroom']][:n_train].values)
test_features = nd.array(data[['size','bedroom']][n_train:].values)
train_labels = nd.array(data.price[:n_train].values).reshape((-1, 1))
train_features.shape
(36, 2)
train_features[:3]
[[-0.4959771  -0.22616564][ 0.4998739  -0.22616564][-0.72502285 -1.526618  ]]
<NDArray 3x2 @cpu(0)>

定义模型

我们使用一个基本的线性回归模型和平方损失函数来训练模型。 关于更多gluon使用的步骤请参考这里

net = nn.Sequential()
net.add(nn.Dense(1))

初始化模型参数

net.initialize(init.Normal(sigma=0.01))

定义损失函数

loss = gloss.L2Loss()

定义优化算法

创建一个Trainer实例,并指定学习率为0.03的小批量随机梯度下降(sgd)为优化算法。该优化算法将用来迭代net实例所有通过add函数嵌套的层所包含的全部参数。这些参数可以通过collect_params函数获取。

trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.03})

训练模型

随机读取包含batch_size个数据样本的小批量

batch_size=4
train_iter = gdata.DataLoader(gdata.ArrayDataset(train_features, train_labels), batch_size, shuffle=True)
num_epochs = 10
for epoch in range(1, num_epochs + 1):for X, y in train_iter:with autograd.record():l = loss(net(X), y)l.backward()trainer.step(batch_size)l = loss(net(train_features), train_labels)print('epoch %d, loss: %f' % (epoch, l.mean().asnumpy()))
epoch 1, loss: 0.349735
epoch 2, loss: 0.255017
epoch 3, loss: 0.207258
epoch 4, loss: 0.180886
epoch 5, loss: 0.166463
epoch 6, loss: 0.156838
epoch 7, loss: 0.150244
epoch 8, loss: 0.145748
epoch 9, loss: 0.142224
epoch 10, loss: 0.139501

后记

暂时看训练是能收敛的,损失也比上次少很多很多。下次我们再看几个问题: + 怎么算测试集的房价 + 有没有过拟 + 损失函数的结果怎么看,是大还是小

新手村的小伙伴们,你们有什么看法呢?

4c62e12ab32db1400b26d7b87eef4be7.png

此处围观我的github 博客,这里下载本文代码

续集

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/245057.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java中include标签的用法_原 ng-include用法分析以及多标签页面的简单实现方式

在平时的项目开发中&#xff0c;应该会经常遇到上图所示的需求&#xff0c;就是在一个页面中有多个标签&#xff0c;被选中的标签颜色会高亮显示&#xff0c;切换不同标签显示相应的不同内容。如果内容代码过多则写在同一个html文件就会显得特别乱&#xff0c;所以这里我们最好…

禅道项目管理_禅道 11.6.1 版本发布,完善细节,修复 Bug

禅道项目管理软件集产品管理、项目管理、质量管理、文档管理、组织管理和事务管理于一体&#xff0c;是一款功能完备的项目管理软件&#xff0c;完美地覆盖了项目管理的核心流程。禅道官网&#xff1a;www.zentao.net。大家好&#xff0c;禅道项目管理软件11.6.1发布&#xff0…

mendeley引用参考文献不显示_免费文献管理器Mendeley

June 2020有机合成化学文献检索今天小编给大家分享一款免费又好用的文献管理器——Mendeley&#xff0c;另外晶体cif文件下载—Materialsproject和COD数据库可在菜单栏的文献检索[文献管理/资源]中查看Mendeley是什么Mendeley是一款免费的跨平台文献管理软件&#xff0c;同时也…

停车场管理系统代码_jsp19109商场商铺停车场服务系统-SSM-Mysql

jsp19109商场商铺停车场服务系统-SSM-Mysql该设计有演示视频    100%能运行买重包换  保密发送  一校一份编号&#xff1a;jsp19109语言数据库&#xff1a;jspMysql论文字数&#xff1a;12032字摘 要随着社会的发展&#xff0c;社会的方方面面都在利用信息化时代的优势。计…

qregexp限制数字范围_数字系统实现电压电流控制的必经之路数模转换器

《芯势力》系列接上一篇文章&#xff0c;我们了解到了模数转换器&#xff0c;本文将带你了解数模转换器。看名字就能知道&#xff0c;如果模数转换器实现了模拟信号到数字信号的转换&#xff0c;那么&#xff0c;数模转换器就是模数转换器的逆过程&#xff0c;即把数字信号转换…

js方式调用php_js如何调用php函数

js调用php函数的方法&#xff1a;jQuery.ajax({type: "POST",url: your_functions_address.php,dataType: json,data: {functionname: add, arguments: [1, 2]},success: function (obj, textstatus) {if( !(error in obj) ) {yourVariable obj.result;}else {conso…

最大子序列求和_算法——求最大子段和

一、问题描述给定由n个整数组成的序列(a_1,a_2,…,a_n)&#xff0c;最大子段和问题要求该序列形如 的最大值(1≤i≤j≤n)&#xff0c;当序列中所有整数均为负整数时&#xff0c;其最大子段和为0。例如&#xff0c;序列(-20, 11, -4, 13, -5, -2)的最大子段和为&#xff1a; 注意…

seo黑帽劫持用的php,黑帽seo 论坛:黑帽seo防止网站被k的js劫持跳转代码

由于目前百度搜索百度搜索引擎对于js代码还没有办法完全辨别&#xff0c;因此也就出现了运用js代码跳转的黑帽优化提升手法。现如今在网络上有关js跳转代码不计其数&#xff0c;但是作为黑帽优化提升的seo手法之一&#xff0c;如何确保有效降低跳转的网址被k危害性&#xff0c;…

oracle 同义词_【干货7】Oracle知识关键代码摘要

&#xff08;如果我分享的干货内容对你有帮助&#xff0c;可以通过赞或者评论的方式告诉我&#xff0c;我会持续分享&#xff1b;或者留言你想要的IT方面的支持&#xff0c;我将分享大家感兴趣的IT类技术干货&#xff1b;如果没有收到大家的反馈&#xff0c;10天后我将停止技术…

qt做的接收串口数据并显示曲线_QT无人机地面站设计与制作

近年来&#xff0c;无人机可谓是大火。无论是军事&#xff0c;还是民用&#xff0c;它的地位更是不用说。但&#xff0c;如何利用利用现有技术对无人机的信息进行操作&#xff0c;实现人、机合一呢&#xff1f;“无人机地面站”应运而生&#xff0c;结合仿真系统为地面工作人员…

php直接读取csv文件,php实现的读取CSV文件函数示例

本文实例讲述了php实现的读取CSV文件函数。分享给大家供大家参考&#xff0c;具体如下&#xff1a;function read_csv($cvs) {$shuang false;$str file_get_contents($cvs);for ($i0;$iif($str{$i}") {if($shuang) {if($str{$i1}") {$str{$i} *;$str{$i1} *;} el…

系统背景描述_【计算机论文】管件加工管理系统和数据库的结构探析

摘 要:结合"中国制造2025"及德国"工业4.0"的发展趋势,概述目前国内管件生产加工流程的现状和不足,基于对管件加工过程中管件之间的差别、管件加工批次的混合等特点导致的管理难点分析,介绍管件生产加工管理系统的设计思路和工作流程,并对该系统未来可进一步…

shell tr 替换 空格_Shell 字符串分隔符!!!(全网最详细总结)

前言&#xff1a;在shell脚本编程中&#xff0c;我们经常会用到切割字符串&#xff0c;类似于python中的split。但shell中的命令比较五花八门&#xff0c;小编也是苦扰了很久&#xff0c;终于下定决心对它做一个总结。方法一&#xff1a;字符串替换法#/bin/bashstring"Hel…

本机用域名不能访问_域名注册申请网站域名注意事项

互联网用户越来越多&#xff0c;也有越来越多人搭建网站&#xff0c;做个人博客也好、搭建企业官网也好&#xff0c;数量都在逐步上升。做网站的数量在上升&#xff0c;域名注册量肯定也在上升。有的朋友头一次注册域名&#xff0c;对域名不了解也不知道申请网站域名该注意哪些…

电脑微信不用手机确认_不用安装第三方软件,手机投屏到电脑就这么简单

在头条上收到网友的提问&#xff0c;如果想把手机的内容投影到电脑上&#xff0c;该怎么做&#xff1f;为此我做一个简单的教程&#xff0c;不用安装第三方软件&#xff0c;就用Windows 10自带的无线显示功能和安卓手机的自带无线显示功能来实现。前提条件&#xff1a;1. 电脑是…

tomcat7 https 拒绝连接_物与网怎么连接呢?物联网架构及五大通信协议

消息触达能力是物联网(internet ofthings, IOT)的重要支撑&#xff0c;而物联网很多技术都源于移动互联网。柳猫将阐述移动互联网消息推送技术在物联网中的应用和演进。一、物联网架构和关键技术从开发的角度&#xff0c;无线接入是物联网设备端的核心技术&#xff0c;身份设备…

安卓手机浏览器排行_5g时代已来临!五月安卓手机性价比排行:两千元以上5G手机屠榜...

5月已经过去&#xff0c;同时也标志着今年上半年手机的发布已经告一段落。那么在这段时间里&#xff0c;智能手机的性价比如何呢&#xff1f;考虑到现在已经开始步入5G时代&#xff0c;所以智能手机的价格也是普遍上涨&#xff0c;想要找到一款性价比不错的手机似乎有些难度。现…

tp3.2 不能提交到action方法_什么是死锁,如何避免死锁(4种方法)

当两个线程相互等待对方释放资源时&#xff0c;就会发生死锁。Python 解释器没有监测&#xff0c;也不会主动采取措施来处理死锁情况&#xff0c;所以在进行多线程编程时应该采取措施避免出现死锁。一旦出现死锁&#xff0c;整个程序既不会发生任何异常&#xff0c;也不会给出任…

虚拟局域网软件开源_ZeroTier虚拟局域网免费远程桌面体验--替代TeamViewer

本文主要是关于使用ZeroTier创建虚拟局域网来实现免费远程桌面的方案的体验&#xff0c;包含了一些对不同方案的优缺点的描述。最近因为疫情在家不得不通过远程连接实验室电脑&#xff0c;有两种基本的解决方案&#xff1a;连接学校VPN&#xff0c;然后使用系统自带的远程桌面连…

java复制的函数会报错,2 面试题之面向对象

大纲&#xff1a;一、两个重要概念① 请说明类和对象的区别类是对某一类实物的描述&#xff0c;是抽象的&#xff1b;对象是一个实实在在的个体&#xff0c;是类的一个实例&#xff1b;② 解释一下什么是类加载机制、双亲委派模型&#xff0c;好处是什么&#xff1f;类加载机制…