sklearn 机器学习 Pipeline 模板

文章目录

    • 1. 导入工具包
    • 2. 读取数据
    • 3. 数字特征、文字特征分离
    • 4. 数据处理Pipeline
    • 5. 尝试不同的模型
    • 6. 参数搜索
    • 7. 特征重要性筛选
    • 8. 最终完整Pipeline

使用 sklearn 的 pipeline 搭建机器学习的流程
本文例子为 [Kesci] 新人赛 · 员工满意度预测
参考 [Hands On ML] 2. 一个完整的机器学习项目(加州房价预测)

1. 导入工具包

import numpy as np
import pandas as pd
%matplotlib inline
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import LabelBinarizer
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import FeatureUnion
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import cross_val_score

2. 读取数据

data = pd.read_csv("../competition/Employee_Satisfaction/train.csv")
test = pd.read_csv("../competition/Employee_Satisfaction/test.csv")
data.columns
Index(['id', 'last_evaluation', 'number_project', 'average_monthly_hours','time_spend_company', 'Work_accident', 'package','promotion_last_5years', 'division', 'salary', 'satisfaction_level'],dtype='object')
  • 训练数据,标签分离
y = data['satisfaction_level']
X = data.drop(['satisfaction_level'], axis=1)

3. 数字特征、文字特征分离

def num_cat_splitor(X):s = (X.dtypes == 'object')object_cols = list(s[s].index)# object_cols # ['package', 'division', 'salary']num_cols = list(set(X.columns) - set(object_cols))# num_cols# ['Work_accident', 'time_spend_company', 'promotion_last_5years', 'id',#  'average_monthly_hours',  'last_evaluation',  'number_project']return num_cols, object_cols
num_cols, object_cols = num_cat_splitor(X)
# print(num_cols)
# print(object_cols)
# X[object_cols].values
  • 特征数值筛选器
class DataFrameSelector(BaseEstimator, TransformerMixin):def __init__(self, attribute_names):self.attribute_names = attribute_namesdef fit(self, X, y=None):return selfdef transform(self, X):return X[self.attribute_names].values

4. 数据处理Pipeline

  • 数字特征
num_pipeline = Pipeline([('selector', DataFrameSelector(num_cols)),('imputer', SimpleImputer(strategy="median")),('std_scaler', StandardScaler()),])
  • 文字特征
cat_pipeline = Pipeline([('selector', DataFrameSelector(object_cols)),('cat_encoder', OneHotEncoder(sparse=False)),])
  • 组合数字和文字特征
full_pipeline = FeatureUnion(transformer_list=[("num_pipeline", num_pipeline),("cat_pipeline", cat_pipeline),])
X_prepared = full_pipeline.fit_transform(X)

5. 尝试不同的模型

from sklearn.ensemble import RandomForestRegressor
forest_reg = RandomForestRegressor()
forest_scores = cross_val_score(forest_reg,X_prepared,y,scoring='neg_mean_squared_error',cv=3)
forest_rmse_scores = np.sqrt(-forest_scores)
print(forest_rmse_scores)
print(forest_rmse_scores.mean())
print(forest_rmse_scores.std())

还可以尝试别的模型

6. 参数搜索

param_grid = [{'n_estimators' : [3,10,30,50,80],'max_features':[2,4,6,8]},{'bootstrap':[False], 'n_estimators' : [3,10],'max_features':[2,3,4]},
]
forest_reg = RandomForestRegressor()
grid_search = GridSearchCV(forest_reg, param_grid, cv=5,scoring='neg_mean_squared_error')
grid_search.fit(X_prepared,y)
  • 最佳参数
grid_search.best_params_
  • 最优模型
grid_search.best_estimator_
  • 搜索结果
cv_result = grid_search.cv_results_
for mean_score, params in zip(cv_result['mean_test_score'], cv_result['params']):print(np.sqrt(-mean_score), params)
0.2129252723367584 {'max_features': 2, 'n_estimators': 3}
0.19276874697889504 {'max_features': 2, 'n_estimators': 10}
0.1865548358477794 {'max_features': 2, 'n_estimators': 30}
.......

7. 特征重要性筛选

feature_importances = grid_search.best_estimator_.feature_importances_
  • 选择前 k 个最重要的特征
k = 3
def indices_of_top_k(arr, k):return np.sort(np.argpartition(np.array(arr), -k)[-k:])class TopFeatureSelector(BaseEstimator, TransformerMixin):def __init__(self, feature_importances, k):self.feature_importances = feature_importancesself.k = kdef fit(self, X, y=None):self.feature_indices_ = indices_of_top_k(self.feature_importances, self.k)return selfdef transform(self, X):return X[:, self.feature_indices_]

8. 最终完整Pipeline

prepare_select_and_predict_pipeline = Pipeline([('preparation', full_pipeline),('feature_selection', TopFeatureSelector(feature_importances, k)),('forst_reg', RandomForestRegressor())
])
  • 参数搜索
param_grid = [{'preparation__num_pipeline__imputer__strategy': ['mean', 'median', 'most_frequent'],'feature_selection__k': list(range(5, len(feature_importances) + 1)),'forst_reg__n_estimators' : [200,250,300,310,330],'forst_reg__max_features':[2,4,6,8]
}]grid_search_prep = GridSearchCV(prepare_select_and_predict_pipeline, param_grid, cv=10,scoring='neg_mean_squared_error', verbose=2, n_jobs=-1)
  • 训练
grid_search_prep.fit(X,y)
grid_search_prep.best_params_
final_model = grid_search_prep.best_estimator_
  • 预测
y_pred_test = final_model.predict(test)
result = pd.DataFrame()
result['id'] = test['id']
result['satisfaction_level'] = y_pred_test
result.to_csv('rf_ML_pipeline.csv',index=False)

以上只是粗略的大体框架,还有很多细节,大家多指教!


我的CSDN博客地址 https://michael.blog.csdn.net/

长按或扫码关注我的公众号(Michael阿明),一起加油、一起学习进步!
Michael阿明

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/474627.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

php获取表单信息的代码_PHP获取HTML文件名表单数据等

1、PHP获取表单各项数据 --- 与表单提交的方式有关GET方式,格式:$_GET[“formelement”]POST方式,格式:$_POST[“formelement”]REQUEST方式,格式:$_REQUEST[“formelement”]2、表单中上传文件的数据数组&…

SQL Server 批量插入数据的两种方法

在SQL Server 中插入一条数据使用Insert语句,但是如果想要批量插入一堆数据的话,循环使用Insert不仅效率低,而且会导致SQL一系统性能问题。下面介绍SQL Server支持的两种批量数据插入方法:Bulk和表值参数(Table-Valued Parameters…

LeetCode MySQL 1532. The Most Recent Three Orders(dense_rank + over窗口函数)

文章目录1. 题目2. 解题1. 题目 Table: Customers ------------------------ | Column Name | Type | ------------------------ | customer_id | int | | name | varchar | ------------------------ customer_id is the primary key for this table. T…

php 负载监控_php记录服务器负载、内存、cpu状态的代码

通过调用系统命令top,然后借助函数explode,实现记录服务器负载、内存使用情况、cpu当前状态等信息。代码如下:/*** 记录服务器负载、内存使用、cpu状态* 每10秒检测一次* edit by www.jbxue.com*/while(1){exec(top -b -n 1 -d 3,$out);$Cpu …

Dota改键

利用全局钩子 制作一个个性化的dota游戏改键&#xff01; dll部分&#xff1a; // FileName: add.cpp#include <Windows.h>/* 定义全局变量 */ HWND g_hwnd NULL; HHOOK g_hKeyboard NULL;// 设置数据段 #pragma data_seg("MySec") static WORD g_keyNum[6]{…

LeetCode MySQL 1501. 可以放心投资的国家

文章目录1. 题目2. 解题1. 题目 表 Person: ------------------------- | Column Name | Type | ------------------------- | id | int | | name | varchar | | phone_number | varchar | ------------------------- id 是该表主键. 该表…

php 小数末尾进1,PHP小数点最后一位加1、减1

比如我有几个数字(小数点后面的位数不固定)&#xff1a;1、155.0552、122.1963、0.9631我怎么做才能让这些数字的小数点最后一位1&#xff0c;或者-1&#xff1f;比如1的话希望得到&#xff1a;1、155.0562、122.1973、0.9632回复内容&#xff1a;比如我有几个数字(小数点后面的…

ARM汇编Hello,World

1. 编译运行环境见http://www.cnblogs.com/linucos/archive/2013/03/01/2938517.htm2. 汇编例子.data msg: .asciz "hello, world\n" .text .global main …

LeetCode MySQL 1270. 向公司CEO汇报工作的所有人

文章目录1. 题目2. 解题1. 题目 员工表&#xff1a;Employees ------------------------ | Column Name | Type | ------------------------ | employee_id | int | | employee_name | varchar | | manager_id | int | ------------------------ employee_…

php 正则 尖括号,php使用正则表达式提取字符串中尖括号、小括号、中括号、大括号中的字符串...

$str"你好(爱)[北京]{天安门}";echo f1($str); //返回你好echo f2($str); //返回我echo f3($str); //返回爱echo f4($str); //返回北京echo f5($str); //返回天安门function f1($str){$result array();preg_match_all("/^(.*)(?:return $result[1][0];}functi…

经济学经典书籍

I&#xff1a;入门阶段&#xff1a; 中文版名称&#xff1a;《经济学原理》 曼昆 英文版名称&#xff1a;principle of economics by Mankiw,N.G.II&#xff1a;基础阶段&#xff1a; 《微观经济学》 周惠中 《微观经济学&#xff1a;现代观点》 哈尔.R.范里安&#xff08;Hal …

LeetCode MySQL 570. 至少有5名直接下属的经理

文章目录1. 题目2. 解题1. 题目 Employee 表包含所有员工和他们的经理。 每个员工都有一个 Id&#xff0c;并且还有一列是经理的 Id。 ------------------------------------- |Id |Name |Department |ManagerId | ------------------------------------- |101 |John…

php 数据接口,初识 php 接口

这次的这篇文章介绍的是PHP接口的内容&#xff0c;现在分享给大家&#xff0c;也给有需要帮助的朋友一个参考&#xff0c;大家一起过来看一看吧一. 接口按请求人可以分为两种&#xff1a;一种是被其他内部项目调用的接口(包括js异步请求的接口和定时程序)。另一种是对外的接口&…

SYSU每周一赛(13.03.16)1003

给定起点终点的无向图&#xff0c;出发时速度为1&#xff0c;到达时速度也为1&#xff0c;在每个点可以进行速度1&#xff0c;不变&#xff0c;-1的操作&#xff0c;在每条边都有限速&#xff0c;到达一城市后不能直接走反向边&#xff0c;求最短时间。 SPFA作松弛操作的典型例…

LeetCode MySQL 1132. 报告的记录 II

文章目录1. 题目2. 解题1. 题目 动作表&#xff1a; Actions ------------------------ | Column Name | Type | ------------------------ | user_id | int | | post_id | int | | action_date | date | | action | enum | | extra…

java封装省市区三级json格式,微信开发 使用picker封装省市区三级联动模板

目前学习小程序更多的是看看能否二次封装其它组件&#xff0c;利于以后能快速开发各种小程序应用。目前发现picker的selector模式只有一级下拉&#xff0c;那么我们是否可以通过3个picker来实现三级联动模板的形式来引入其它页面中呢&#xff1f;答案是肯定可以的。那么我的思路…

LeetCode MySQL 1126. 查询活跃业务

文章目录1. 题目2. 解题1. 题目 事件表&#xff1a;Events ------------------------ | Column Name | Type | ------------------------ | business_id | int | | event_type | varchar | | occurences | int | ------------------------ 此表的主键是…

php linux 删除文件夹,linux下如何删除文件夹

linux下删除文件夹的方法&#xff1a;可以使用【rm -rf 目录名】命令进行删除&#xff0c;如【rm -rf /var/log/httpd/access】&#xff0c;表示删除/var/log/httpd/access目录及其下的所有文件、文件夹。直接rm就可以了&#xff0c;不过要加两个参数-rf 即&#xff1a;rm -rf …

Too many fragmentation in LMT?

这周和同事讨论技术问题时&#xff0c;他告诉我客户的一套11.1.0.6的数据库中某个本地管理表空间上存在大量的Extents Fragment区间碎片&#xff0c;这些连续的Extents没有正常合并为一个大的Extent&#xff0c;他怀疑这是由于11.1.0.6上的bug造成了LMT上存在大量碎片。 同事判…

LeetCode 1533. Find the Index of the Large Integer(二分查找)

文章目录1. 题目2. 解题1. 题目 We have an integer array arr, where all the integers in arr are equal except for one integer which is larger than the rest of the integers. You will not be given direct access to the array, instead, you will have an API Array…