sklearn 机器学习 Pipeline 模板

文章目录

    • 1. 导入工具包
    • 2. 读取数据
    • 3. 数字特征、文字特征分离
    • 4. 数据处理Pipeline
    • 5. 尝试不同的模型
    • 6. 参数搜索
    • 7. 特征重要性筛选
    • 8. 最终完整Pipeline

使用 sklearn 的 pipeline 搭建机器学习的流程
本文例子为 [Kesci] 新人赛 · 员工满意度预测
参考 [Hands On ML] 2. 一个完整的机器学习项目(加州房价预测)

1. 导入工具包

import numpy as np
import pandas as pd
%matplotlib inline
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import LabelBinarizer
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import FeatureUnion
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import cross_val_score

2. 读取数据

data = pd.read_csv("../competition/Employee_Satisfaction/train.csv")
test = pd.read_csv("../competition/Employee_Satisfaction/test.csv")
data.columns
Index(['id', 'last_evaluation', 'number_project', 'average_monthly_hours','time_spend_company', 'Work_accident', 'package','promotion_last_5years', 'division', 'salary', 'satisfaction_level'],dtype='object')
  • 训练数据,标签分离
y = data['satisfaction_level']
X = data.drop(['satisfaction_level'], axis=1)

3. 数字特征、文字特征分离

def num_cat_splitor(X):s = (X.dtypes == 'object')object_cols = list(s[s].index)# object_cols # ['package', 'division', 'salary']num_cols = list(set(X.columns) - set(object_cols))# num_cols# ['Work_accident', 'time_spend_company', 'promotion_last_5years', 'id',#  'average_monthly_hours',  'last_evaluation',  'number_project']return num_cols, object_cols
num_cols, object_cols = num_cat_splitor(X)
# print(num_cols)
# print(object_cols)
# X[object_cols].values
  • 特征数值筛选器
class DataFrameSelector(BaseEstimator, TransformerMixin):def __init__(self, attribute_names):self.attribute_names = attribute_namesdef fit(self, X, y=None):return selfdef transform(self, X):return X[self.attribute_names].values

4. 数据处理Pipeline

  • 数字特征
num_pipeline = Pipeline([('selector', DataFrameSelector(num_cols)),('imputer', SimpleImputer(strategy="median")),('std_scaler', StandardScaler()),])
  • 文字特征
cat_pipeline = Pipeline([('selector', DataFrameSelector(object_cols)),('cat_encoder', OneHotEncoder(sparse=False)),])
  • 组合数字和文字特征
full_pipeline = FeatureUnion(transformer_list=[("num_pipeline", num_pipeline),("cat_pipeline", cat_pipeline),])
X_prepared = full_pipeline.fit_transform(X)

5. 尝试不同的模型

from sklearn.ensemble import RandomForestRegressor
forest_reg = RandomForestRegressor()
forest_scores = cross_val_score(forest_reg,X_prepared,y,scoring='neg_mean_squared_error',cv=3)
forest_rmse_scores = np.sqrt(-forest_scores)
print(forest_rmse_scores)
print(forest_rmse_scores.mean())
print(forest_rmse_scores.std())

还可以尝试别的模型

6. 参数搜索

param_grid = [{'n_estimators' : [3,10,30,50,80],'max_features':[2,4,6,8]},{'bootstrap':[False], 'n_estimators' : [3,10],'max_features':[2,3,4]},
]
forest_reg = RandomForestRegressor()
grid_search = GridSearchCV(forest_reg, param_grid, cv=5,scoring='neg_mean_squared_error')
grid_search.fit(X_prepared,y)
  • 最佳参数
grid_search.best_params_
  • 最优模型
grid_search.best_estimator_
  • 搜索结果
cv_result = grid_search.cv_results_
for mean_score, params in zip(cv_result['mean_test_score'], cv_result['params']):print(np.sqrt(-mean_score), params)
0.2129252723367584 {'max_features': 2, 'n_estimators': 3}
0.19276874697889504 {'max_features': 2, 'n_estimators': 10}
0.1865548358477794 {'max_features': 2, 'n_estimators': 30}
.......

7. 特征重要性筛选

feature_importances = grid_search.best_estimator_.feature_importances_
  • 选择前 k 个最重要的特征
k = 3
def indices_of_top_k(arr, k):return np.sort(np.argpartition(np.array(arr), -k)[-k:])class TopFeatureSelector(BaseEstimator, TransformerMixin):def __init__(self, feature_importances, k):self.feature_importances = feature_importancesself.k = kdef fit(self, X, y=None):self.feature_indices_ = indices_of_top_k(self.feature_importances, self.k)return selfdef transform(self, X):return X[:, self.feature_indices_]

8. 最终完整Pipeline

prepare_select_and_predict_pipeline = Pipeline([('preparation', full_pipeline),('feature_selection', TopFeatureSelector(feature_importances, k)),('forst_reg', RandomForestRegressor())
])
  • 参数搜索
param_grid = [{'preparation__num_pipeline__imputer__strategy': ['mean', 'median', 'most_frequent'],'feature_selection__k': list(range(5, len(feature_importances) + 1)),'forst_reg__n_estimators' : [200,250,300,310,330],'forst_reg__max_features':[2,4,6,8]
}]grid_search_prep = GridSearchCV(prepare_select_and_predict_pipeline, param_grid, cv=10,scoring='neg_mean_squared_error', verbose=2, n_jobs=-1)
  • 训练
grid_search_prep.fit(X,y)
grid_search_prep.best_params_
final_model = grid_search_prep.best_estimator_
  • 预测
y_pred_test = final_model.predict(test)
result = pd.DataFrame()
result['id'] = test['id']
result['satisfaction_level'] = y_pred_test
result.to_csv('rf_ML_pipeline.csv',index=False)

以上只是粗略的大体框架,还有很多细节,大家多指教!


我的CSDN博客地址 https://michael.blog.csdn.net/

长按或扫码关注我的公众号(Michael阿明),一起加油、一起学习进步!
Michael阿明

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/474627.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SQL Server 批量插入数据的两种方法

在SQL Server 中插入一条数据使用Insert语句,但是如果想要批量插入一堆数据的话,循环使用Insert不仅效率低,而且会导致SQL一系统性能问题。下面介绍SQL Server支持的两种批量数据插入方法:Bulk和表值参数(Table-Valued Parameters…

LeetCode MySQL 1532. The Most Recent Three Orders(dense_rank + over窗口函数)

文章目录1. 题目2. 解题1. 题目 Table: Customers ------------------------ | Column Name | Type | ------------------------ | customer_id | int | | name | varchar | ------------------------ customer_id is the primary key for this table. T…

Dota改键

利用全局钩子 制作一个个性化的dota游戏改键&#xff01; dll部分&#xff1a; // FileName: add.cpp#include <Windows.h>/* 定义全局变量 */ HWND g_hwnd NULL; HHOOK g_hKeyboard NULL;// 设置数据段 #pragma data_seg("MySec") static WORD g_keyNum[6]{…

LeetCode MySQL 1501. 可以放心投资的国家

文章目录1. 题目2. 解题1. 题目 表 Person: ------------------------- | Column Name | Type | ------------------------- | id | int | | name | varchar | | phone_number | varchar | ------------------------- id 是该表主键. 该表…

LeetCode MySQL 1270. 向公司CEO汇报工作的所有人

文章目录1. 题目2. 解题1. 题目 员工表&#xff1a;Employees ------------------------ | Column Name | Type | ------------------------ | employee_id | int | | employee_name | varchar | | manager_id | int | ------------------------ employee_…

LeetCode MySQL 570. 至少有5名直接下属的经理

文章目录1. 题目2. 解题1. 题目 Employee 表包含所有员工和他们的经理。 每个员工都有一个 Id&#xff0c;并且还有一列是经理的 Id。 ------------------------------------- |Id |Name |Department |ManagerId | ------------------------------------- |101 |John…

LeetCode MySQL 1132. 报告的记录 II

文章目录1. 题目2. 解题1. 题目 动作表&#xff1a; Actions ------------------------ | Column Name | Type | ------------------------ | user_id | int | | post_id | int | | action_date | date | | action | enum | | extra…

java封装省市区三级json格式,微信开发 使用picker封装省市区三级联动模板

目前学习小程序更多的是看看能否二次封装其它组件&#xff0c;利于以后能快速开发各种小程序应用。目前发现picker的selector模式只有一级下拉&#xff0c;那么我们是否可以通过3个picker来实现三级联动模板的形式来引入其它页面中呢&#xff1f;答案是肯定可以的。那么我的思路…

LeetCode MySQL 1126. 查询活跃业务

文章目录1. 题目2. 解题1. 题目 事件表&#xff1a;Events ------------------------ | Column Name | Type | ------------------------ | business_id | int | | event_type | varchar | | occurences | int | ------------------------ 此表的主键是…

php linux 删除文件夹,linux下如何删除文件夹

linux下删除文件夹的方法&#xff1a;可以使用【rm -rf 目录名】命令进行删除&#xff0c;如【rm -rf /var/log/httpd/access】&#xff0c;表示删除/var/log/httpd/access目录及其下的所有文件、文件夹。直接rm就可以了&#xff0c;不过要加两个参数-rf 即&#xff1a;rm -rf …

LeetCode 1533. Find the Index of the Large Integer(二分查找)

文章目录1. 题目2. 解题1. 题目 We have an integer array arr, where all the integers in arr are equal except for one integer which is larger than the rest of the integers. You will not be given direct access to the array, instead, you will have an API Array…

MySQL Server Architecture

MySQL 服务器架构&#xff1a; 转载于:https://www.cnblogs.com/macleanoracle/archive/2013/03/19/2968212.html

LeetCode MySQL 1479. 周内每天的销售情况(dayname星期几)

文章目录1. 题目2. 解题1. 题目 表&#xff1a;Orders ------------------------ | Column Name | Type | ------------------------ | order_id | int | | customer_id | int | | order_date | date | | item_id | varchar | | quantity …

php的swoole教程,PHP + Swoole2.0 初体验(swoole入门教程)

PHP Swoole2.0 初体验(swoole入门教程)环境&#xff1a;centos7 PHP7.1 swoole2.0准备工作&#xff1a;一、 swoole 扩展安装1 、下载swoolecd/usr/localwget -c https://github.com/swoole/swoole-src/archive/v2.0.8.tar.gztar -zxvf v2.0.8.tar.gzcdswoole-src-2.0.8/2 编…

Git常用命令解说

http://zensheno.blog.51cto.com/2712776/490748 1. Git概念 1.1. Git库中由三部分组成 Git 仓库就是那个.git 目录&#xff0c;其中存放的是我们所提交的文档索引内容&#xff0c;Git 可基于文档索引内容对其所管理的文档进行内容追踪&#xff0c;从而实现文档的版本控…

LeetCode MySQL 1412. 查找成绩处于中游的学生

文章目录1. 题目2. 解题1. 题目 表: Student ------------------------------ | Column Name | Type | ------------------------------ | student_id | int | | student_name | varchar | ------------------------------ student_id 是该表…

LeetCode MySQL 618. 学生地理信息报告(row_number)

文章目录1. 题目2. 解题1. 题目 一所美国大学有来自亚洲、欧洲和美洲的学生&#xff0c;他们的地理信息存放在如下 student 表中。 | name | continent | |--------|-----------| | Jack | America | | Pascal | Europe | | Xi | Asia | | Jane | Americ…

java非必填字段跳过校验,avalon2表单验证,非必填字段在不填写的时候不能通过验证...

avalon2表单验证,非必填字段在不填写的时候不能通过验证代码var vm avalon.define({$id: "validate1",aaa : "",validate: {onError: function(reasons) {reasons.forEach(function(reason) {console.log(reason.getMessage())})},onValidateAll: functio…

jQuery心得5--jQuery深入了解串讲1

1.CSS-DOM 操作 获取和设置元素的样式属性: css()。 获取和设置元素透明度: opacity 属性(css 的一个属性)。 获取和设置元素高度, 宽度: height(), width(). 在设置值时, 若只传递数字, 则默认单位是 px. 如需要使用其他单位则需传递一个字符串, 例如 $(“p:first”).height(“…

LeetCode MySQL 1225. 报告系统状态的连续日期(date_sub + over)

文章目录1. 题目2. 解题1. 题目 Table: Failed ----------------------- | Column Name | Type | ----------------------- | fail_date | date | ----------------------- 该表主键为 fail_date。 该表包含失败任务的天数.Table: Succeeded --------------------…