基于mxnet的Regression问题Kaggle比赛代码框架

一、概述

书中3.16节扩展一下可以作为kaggle比赛的框架,这个赛题的名字是House Prices: Advanced Regression Techniques,是一个Regression问题。

二、Deeplearning的一般流程

结合李航《统计学习方法》中对机器学习流程的总结,分为data、model、strategy、algorithm、training、prediction

1、 Data

1.1、read data

# read data
train_data = pd.read_csv('./d2l-zh-1.1/data/kaggle_house_pred_train.csv')
test_data = pd.read_csv('./d2l-zh-1.1/data/kaggle_house_pred_test.csv')
# print(train_data.shape)
# print(train_data.iloc[0:4, [0, 1, 2, -1, -2, -3]])

1.2、preprocess data

# standardization to numeric type
all_features = pd.concat((train_data.iloc[:, 1:-1], test_data.iloc[:, 1:]))
numeric_features = all_features.dtypes[all_features.dtypes != 'object'].index
all_features[numeric_features] = all_features[numeric_features].apply(lambda x: (x - x.mean()) / x.std())
# 标准化后,每个特征的均值变为0,所以可以直接用0来替换缺失值
all_features[numeric_features] = all_features[numeric_features].fillna(0)# convert discrete value to dummy variable
all_features = pd.get_dummies(all_features, dummy_na=True)# get train and test data
n_train = train_data.shape[0]
train_features = nd.array(all_features[:n_train].values)
test_features = nd.array(all_features[n_train:].values)
train_labels = nd.array(train_data['SalePrice'].values).reshape((-1, 1))

1.3、get_k_fold_data

# k folds validation
def get_k_fold_data(k, i, X, y):assert k > 1fold_size = X.shape[0] // kX_train, y_train, X_valid, y_valid = None, None, None, Nonefor j in range(k):idx = slice(j * fold_size, (j + 1) * fold_size)X_part, y_part = X[idx, :], y[idx]if j == i:X_valid, y_valid = X_part, y_partelif X_train is None:X_train, y_train = X_part, y_partelse:X_train = nd.concat(X_train, X_part, dim=0)y_train = nd.concat(y_train, y_part, dim=0)return X_train, y_train, X_valid, y_valid

2、Model

def get_net():net = nn.Sequential()net.add(nn.Dense(256, activation='relu'),nn.Dropout(0.5),nn.Dense(1))net.initialize()return net

3、Strategy

loss = gloss.L2Loss()

4、Algorithm

# loss = gloss.L2Loss()

5、Training

# training
def train(net, train_iter, train_features, train_labels, test_features, test_labels,loss, num_epochs, trainer, batch_size):train_ls, test_ls = [], []for epoch in range(num_epochs):for X, y in train_iter:with autograd.record():l = loss(net(X), y)l.backward()trainer.step(batch_size)train_ls.append(log_rmse(net, train_features, train_labels))if test_labels is not None:test_ls.append(log_rmse(net, test_features, test_labels))return train_ls, test_ls

6、Validation

def k_fold(k, X_train, y_train, num_epochs, learning_rate, weight_decay, batch_size):train_l_sum, valid_l_sum = 0.0, 0.0for i in range(k):# datadata = get_k_fold_data(k, i, X_train, y_train)train_features, train_labels, _, _ = datatrain_iter = gdata.DataLoader(gdata.ArrayDataset(train_features, train_labels), batch_size, shuffle=True)# modelnet = get_net()# strategyloss = gloss.L2Loss()# algorithmtrainer = gluon.Trainer(net.collect_params(), 'adam',{'learning_rate': learning_rate, 'wd': weight_decay})# trainingtrain_ls, valid_ls = train(net, train_iter, *data, loss, num_epochs, trainer, batch_size)train_l_sum += train_ls[-1]valid_l_sum += valid_ls[-1]if i == 0:d2l.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'rmse',range(1, num_epochs + 1), valid_ls, ['train', 'valid'])print('fold %d, train rmse %f, valid rmse %f' % (i, train_ls[-1], valid_ls[-1]))return train_l_sum / k, valid_l_sum / k# model selection
k, num_epochs, lr, weight_decay, batch_size = 5, 500, 0.01, 512, 64
train_l, valid_l = k_fold(k, train_features, train_labels, num_epochs, lr,weight_decay, batch_size)
print('%d-fold validation: avg train rmse %f, avg valid rmse %f' % (k, train_l, valid_l))

7、Prediction

train_and_pred(train_features, test_features, train_labels, test_data,num_epochs, lr, weight_decay, batch_size)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/420784.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

centos8安装

一. 下载centos centos下载 下载镜像版 mini版本 二,安装centos8 虚拟机安装 可 打开虚拟机安装centos 选择下载的镜像 配置磁盘大小 配置资源 配置虚拟机内存,处理器个数等. 安装成功后,也可配置

alert,confirm和prompt

1.警告消息框alertalert 方法有一个参数,即希望对用户显示的文本字符串。该字符串不是 HTML 格式。该消息框提供了一个“确定”按钮让用户关闭该消息框,并且该消息框是模式对话框,也就是说,用户必须先关闭该消息框然后才能继续进行…

(一)卷积网络之基础要点

一、提出问题 对于生活生产中的表格数据,至多也就上百维,而且表格数据的行与行之间没有序列和位置上的关系,所以用传统的机器学习算法就可轻松的解决这些问题。但是到了图片数据,传统机器学习就非常吃力了,一个普通的…

Windows Phone本地数据库(SQLCE):3、[table]attribute(翻译) (转)

这是“windows phone mango本地数据库(sqlce)”系列短片文章的第三篇。 为了让你开始在Windows Phone Mango中使用数据库,这一系列短片文章将覆盖所有你需要知道的知识点。这个时候我将谈谈有关你使用windows phone mango本地数据库时使用[ta…

Java代理模式——静态代理动态代理

proxy mode1. 什么是代理1.1 例子解释1.2 作用2. 静态代理2.1 优缺点分析2.2 以厂家卖u盘用代码说明3. 动态代理3.1 什么是动态代理3.2 jdk实现原理3.3 代码描述1. 什么是代理 1.1 例子解释 1. 生活中的例子,常见的商家卖东西, 商家就是代理&#xff0…

一、Insertion sort

1. 问题 2. 算法 2.1 伪代码 2.2 算法思想 2.3 手工演示 2.4 Python实现 《算法导论》一书数组默认从111开始,这种方式适合算法分析,从000开始适合程序实现,为了能和伪代码一致便于对比,后边所有的Python实现中数组均从111开始。…

windows 2502 2503 错误解决

1. 错误原因 1. c盘下temp文件夹权限问题 2. c盘temp文件夹环境变量配置错误,或者更改了2. 造成的问题 每次安装msi文件或者卸载msi程序包时,都会弹出此恶心的错误...3. 解决 1. 针对问题一,解决,以管理员身份安装或者卸载 win…

Hibernate学习笔记

Hibernate是什么: Hibernate 架构: 下载、安装、必要的 jar包、环境CLASSPAST的设置(此步骤省略) Hibernate框架的使用步骤:1、创建Hibernate的配置文件(hibernate.cfg.xml)2、创建持久化类&…

二、Merge sort

1 问题 2 算法 2.1 伪代码 2.2 算法思想 2.3 手工演示 2.4 Python实现 # -*- coding: utf-8 -*- import sysdef merge(A, p, q, r):n1 q - p 1n2 r - qL [0] * (n1 2)R [0] * (n2 2)for i in range(1, n11):L[i] A[pi-1]for j in range(1, n21):R[j] A[qj]L[n11] 6…

cglib实现动态代理

对目标方法实现前置或者后置增强, 是在程序动态运行时加入增强方法的。 1. 目标类 package com.lovely.proxy.cglib;/*** 目标类* author echo lovely* date 2020/7/26 15:20*/ public class Target {public void save() {System.out.println("sve running..…

fragment嵌套,viewpager嵌套 不能正确显示

转帖:http://blog.csdn.net/mybook1122/article/details/24003343 通常为 viewPager.setAdapter(new MyFragmentPagerAdapter(getSupportFragmentManager(), fragmentsList)); 替换为 mPager.setAdapter(new MyFragmentPagerAdapter(getChildFragmentManager(), fra…

三、递归树分析法

1 问题 2 解决思路 使用递归树猜想一个上界,使用归纳法证明上界也是下界。 2.1 使用递归树(recursion tree)猜想结论(不严谨) 使用递归树两点:1⃣️逐行展开;2⃣️逐行相加; 逐行…

Linux文件查看/编辑方法介绍

转载:https://www.centos.bz/2011/10/linux-file-view-edit/ cat 命令介绍 cat 命令的原含义为连接(concatenate), 用于连接多个文件内容并输出到标准输出流中(标准输出流默认为屏幕)。实际运用过程中,我们常使用它来显示文件内容…

html5input表单标签新属性

初探h5一,h5 新增表单类型二,新增表单属性三,code demo一,h5 新增表单类型 •email 邮箱地址•url 网络地址•number 数字框•range 滑块•Date pickers (date, month, week, time, datetime, datetime-local) 日期时间框•search…

关于java的JIT知识

1.JIT的工作原理图 工作原理 当JIT编译启用时(默认是启用的),JVM读入.class文件解释后,将其发给JIT编译器。JIT编译器将字节码编译成本机机器代码。 通常javac将程序源码编译,转换成java字节码,JVM通过解释…