政安晨: [Keras Machine Learning Examples in Practice] (53): Classification with TensorFlow Decision Forests

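Contents

Introduction

Setup

Prepare the data

Define dataset metadata

Configure the hyperparameters

Implement the training and evaluation procedure

Experiment 1: Decision Forests with raw features

Inspect the model

Experiment 2: Decision Forests with target encoding

Create model inputs

Implement feature encoding with target encoding

Create a Gradient Boosted Trees model with a preprocessor

Train and evaluate the model

Experiment 3: Decision Forests with trained embeddings

Concluding remarks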


Author's homepage: 政安晨

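Part of the column: TensorFlow and Keras Machine Learning in Practice

I hope this blog is useful to you. If you spot any shortcomings, corrections in the comments are welcome.

Goal of this article: classify structured data with TensorFlow Decision Forests.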

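Introduction

TensorFlow Decision Forests is a collection of state-of-the-art decision forest algorithms that are compatible with the Keras API. The models include Random Forests, Gradient Boosted Trees, and CART, and can be used for regression, classification, and ranking tasks.

This example uses Gradient Boosted Trees models for binary classification of structured data and covers the following scenarios:

1. Build a decision forest model by specifying the input feature usage.
2. Implement a custom binary target encoder as a Keras preprocessing layer that encodes the categorical features by their co-occurrence with the target value, then build a decision forest model on the encoded features.
3. Encode the categorical features as embeddings, train those embeddings in a simple NN model, then build a decision forest model that takes the trained embeddings as input.

This example uses TensorFlow 2.7 or higher together with TensorFlow Decision Forests, which you can install with the following command: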

pip install -U tensorflow_decision_forests
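
Before running the rest of the example, it can help to confirm that the installed TensorFlow meets the 2.7 minimum stated above. The following is a minimal sketch using only the standard library; the helper names `installed_version` and `meets_minimum` are hypothetical and not part of TensorFlow or TensorFlow Decision Forests.

```python
from importlib import metadata


def installed_version(package_name):
    """Return the installed version string of a package, or None if absent."""
    try:
        return metadata.version(package_name)
    except metadata.PackageNotFoundError:
        return None


def meets_minimum(version_string, minimum=(2, 7)):
    """Compare the major.minor part of a version string against a minimum."""
    major, minor = (int(part) for part in version_string.split(".")[:2])
    return (major, minor) >= minimum


tf_version = installed_version("tensorflow")
if tf_version is None:
    print("tensorflow is not installed")
else:
    print(f"tensorflow {tf_version}, meets minimum: {meets_minimum(tf_version)}")
```

Comparing `(major, minor)` as a tuple of integers avoids the string-comparison trap where "2.10" would sort before "2.7".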

Setup

import math
import urllib.request
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_decision_forests as tfdf

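Prepare the data

This example uses the United States Census Income Dataset provided by the UC Irvine Machine Learning Repository.
The dataset contains about 300,000 instances with 41 input features: 7 numerical features and 34 categorical features.

First, we load the data from the UCI Machine Learning Repository into Pandas DataFrames.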

BASE_PATH = "https://kdd.ics.uci.edu/databases/census-income/census-income"
# Build the column names from the .names file, skipping "|" comment lines
# and the first two non-comment lines.
CSV_HEADER = [
    l.decode("utf-8").split(":")[0].replace(" ", "_")
    for l in urllib.request.urlopen(f"{BASE_PATH}.names")
    if not l.startswith(b"|")
][2:]
CSV_HEADER.append("income_level")

train_data = pd.read_csv(f"{BASE_PATH}.data.gz", header=None, names=CSV_HEADER,)
test_data = pd.read_csv(f"{BASE_PATH}.test.gz", header=None, names=CSV_HEADER,)
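
To see what the header-parsing comprehension does without downloading anything, here is a small sketch that runs the same logic on a few in-memory lines. The sample lines are hypothetical stand-ins for the `.names` file, written only to illustrate the filtering and slicing steps; the real file's contents may differ.

```python
# Hypothetical sample of a ".names"-style file, as bytes lines.
sample_lines = [
    b"| This is a comment line and is skipped.\n",
    b"- 50000, 50000+.\n",  # first non-comment line (dropped by [2:])
    b"\n",                  # second non-comment line (dropped by [2:])
    b"age: continuous.\n",
    b"class of worker: Not in universe, Private.\n",
]

header = [
    line.decode("utf-8").split(":")[0].replace(" ", "_")
    for line in sample_lines
    if not line.startswith(b"|")  # drop comment lines
][2:]                             # drop the first two remaining lines

print(header)  # ['age', 'class_of_worker']
```

Each kept line contributes the text before its first colon, with spaces replaced by underscores, which is why the column names in this article look like `class_of_worker`.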

Define dataset metadata

Here we define the dataset metadata that will be used to encode the input features according to their types.

# Target column name.
TARGET_COLUMN_NAME = "income_level"
# The labels of the target columns.
TARGET_LABELS = [" - 50000.", " 50000+."]
# Weight column name.
WEIGHT_COLUMN_NAME = "instance_weight"
# Numeric feature names.
NUMERIC_FEATURE_NAMES = [
    "age",
    "wage_per_hour",
    "capital_gains",
    "capital_losses",
    "dividends_from_stocks",
    "num_persons_worked_for_employer",
    "weeks_worked_in_year",
]
# Categorical features and their vocabulary lists.
CATEGORICAL_FEATURE_NAMES = [
    "class_of_worker",
    "detailed_industry_recode",
    "detailed_occupation_recode",
    "education",
    "enroll_in_edu_inst_last_wk",
    "marital_stat",
    "major_industry_code",
    "major_occupation_code",
    "race",
    "hispanic_origin",
    "sex",
    "member_of_a_labor_union",
    "reason_for_unemployment",
    "full_or_part_time_employment_stat",
    "tax_filer_stat",
    "region_of_previous_residence",
    "state_of_previous_residence",
    "detailed_household_and_family_stat",
    "detailed_household_summary_in_household",
    "migration_code-change_in_msa",
    "migration_code-change_in_reg",
    "migration_code-move_within_reg",
    "live_in_this_house_1_year_ago",
    "migration_prev_res_in_sunbelt",
    "family_members_under_18",
    "country_of_birth_father",
    "country_of_birth_mother",
    "country_of_birth_self",
    "citizenship",
    "own_business_or_self_employed",
    "fill_inc_questionnaire_for_veteran's_admin",
    "veterans_benefits",
    "year",
]
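
Hand-written feature lists are easy to get out of sync with the actual columns. The sketch below shows one way to cross-check them; `check_feature_lists` is a hypothetical helper, and the short column list is illustrative data rather than the real census header.

```python
def check_feature_lists(columns, numeric, categorical, target, weight):
    """Report mismatches between the data columns and the declared feature lists."""
    assigned = list(numeric) + list(categorical) + [target, weight]
    return {
        # Declared in the metadata but absent from the data.
        "missing_from_data": sorted(set(assigned) - set(columns)),
        # Present in the data but not declared anywhere.
        "unassigned_columns": sorted(set(columns) - set(assigned)),
        # Declared more than once (e.g. both numeric and categorical).
        "assigned_twice": sorted({n for n in assigned if assigned.count(n) > 1}),
    }


# Illustrative columns only; "year" is deliberately left undeclared.
columns = ["age", "education", "sex", "instance_weight", "income_level", "year"]
report = check_feature_lists(
    columns,
    numeric=["age"],
    categorical=["education", "sex"],
    target="income_level",
    weight="instance_weight",
)
print(report)
```

With the real metadata you would pass `list(train_data.columns)` together with `NUMERIC_FEATURE_NAMES`, `CATEGORICAL_FEATURE_NAMES`, `TARGET_COLUMN_NAME`, and `WEIGHT_COLUMN_NAME`; all three report entries should then come back empty.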

Next, we perform some basic data preparation.

def prepare_dataframe(dataframe):
    # Convert the target labels from string to integer.
    dataframe[TARGET_COLUMN_NAME] = dataframe[TARGET_COLUMN_NAME].map(
        TARGET_LABELS.index
    )
    # Cast the categorical features to string.
    for feature_name in CATEGORICAL_FEATURE_NAMES:
        dataframe[feature_name] = dataframe[feature_name].astype(str)


prepare_dataframe(train_data)
prepare_dataframe(test_data)
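
The label conversion relies on `list.index`, which needs an exact string match, including the leading space that the raw labels carry. The following plain-Python sketch illustrates the mapping on a few made-up rows; `encode_labels` is a hypothetical helper standing in for the `.map(TARGET_LABELS.index)` call.

```python
TARGET_LABELS = [" - 50000.", " 50000+."]


def encode_labels(raw_labels, label_vocabulary):
    """Map each label string to its position in the vocabulary."""
    return [label_vocabulary.index(label) for label in raw_labels]


encoded = encode_labels([" - 50000.", " 50000+.", " - 50000."], TARGET_LABELS)
print(encoded)  # [0, 1, 0]

# A label without the leading space is not in the vocabulary.
try:
    encode_labels(["50000+."], TARGET_LABELS)
    stripped_label_ok = True
except ValueError:
    stripped_label_ok = False
print(stripped_label_ok)  # False
```

This is why `TARGET_LABELS` keeps the leading spaces: stripping whitespace from the labels without also updating the vocabulary would make the lookup fail.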

Now let's show the shapes of the training and test DataFrames and display a few instances.

print(f"Train data shape: {train_data.shape}")
print(f"Test data shape: {test_data.shape}")
print(train_data.head().T)
Train data shape: (199523, 42)
Test data shape: (99762, 42)
                                                                                    0  \
age                                                                                73   
class_of_worker                                                       Not in universe   
detailed_industry_recode                                                            0   
detailed_occupation_recode                                                          0   
education                                                        High school graduate   
wage_per_hour                                                                       0   
enroll_in_edu_inst_last_wk                                            Not in universe   
marital_stat                                                                  Widowed   
major_industry_code                                       Not in universe or children   
major_occupation_code                                                 Not in universe   
race                                                                            White   
hispanic_origin                                                             All other   
sex                                                                            Female   
member_of_a_labor_union                                               Not in universe   
reason_for_unemployment                                               Not in universe   
full_or_part_time_employment_stat                                  Not in labor force   
capital_gains                                                                       0   
capital_losses                                                                      0   
dividends_from_stocks                                                               0   
tax_filer_stat                                                               Nonfiler   
region_of_previous_residence                                          Not in universe   
state_of_previous_residence                                           Not in universe   
detailed_household_and_family_stat           Other Rel 18+ ever marr not in subfamily   
detailed_household_summary_in_household                 Other relative of householder   
instance_weight                                                               1700.09   
migration_code-change_in_msa                                                        ?   
migration_code-change_in_reg                                                        ?   
migration_code-move_within_reg                                                      ?   
live_in_this_house_1_year_ago                        Not in universe under 1 year old   
migration_prev_res_in_sunbelt                                                       ?   
num_persons_worked_for_employer                                                     0   
family_members_under_18                                               Not in universe   
country_of_birth_father                                                 United-States   
country_of_birth_mother                                                 United-States   
country_of_birth_self                                                   United-States   
citizenship                                         Native- Born in the United States   
own_business_or_self_employed                                                       0   
fill_inc_questionnaire_for_veteran's_admin                            Not in universe   
veterans_benefits                                                                   2   
weeks_worked_in_year                                                                0   
year                                                                               95   
income_level                                                                        0   
                                                                               1  \
age                                                                           58   
class_of_worker                                   Self-employed-not incorporated   
detailed_industry_recode                                                       4   
detailed_occupation_recode                                                    34   
education                                             Some college but no degree   
wage_per_hour                                                                  0   
enroll_in_edu_inst_last_wk                                       Not in universe   
marital_stat                                                            Divorced   
major_industry_code                                                 Construction   
major_occupation_code                        Precision production craft & repair   
race                                                                       White   
hispanic_origin                                                        All other   
sex                                                                         Male   
member_of_a_labor_union                                          Not in universe   
reason_for_unemployment                                          Not in universe   
full_or_part_time_employment_stat                       Children or Armed Forces   
capital_gains                                                                  0   
capital_losses                                                                 0   
dividends_from_stocks                                                          0   
tax_filer_stat                                                 Head of household   
region_of_previous_residence                                               South   
state_of_previous_residence                                             Arkansas   
detailed_household_and_family_stat                                   Householder   
detailed_household_summary_in_household                              Householder   
instance_weight                                                          1053.55   
migration_code-change_in_msa                                          MSA to MSA   
migration_code-change_in_reg                                         Same county   
migration_code-move_within_reg                                       Same county   
live_in_this_house_1_year_ago                                                 No   
migration_prev_res_in_sunbelt                                                Yes   
num_persons_worked_for_employer                                                1   
family_members_under_18                                          Not in universe   
country_of_birth_father                                            United-States   
country_of_birth_mother                                            United-States   
country_of_birth_self                                              United-States   
citizenship                                    Native- Born in the United States   
own_business_or_self_employed                                                  0   
fill_inc_questionnaire_for_veteran's_admin                       Not in universe   
veterans_benefits                                                              2   
weeks_worked_in_year                                                          52   
year                                                                          94   
income_level                                                                   0   
                                                                                  2  \
age                                                                               18   
class_of_worker                                                      Not in universe   
detailed_industry_recode                                                           0   
detailed_occupation_recode                                                         0   
education                                                                 10th grade   
wage_per_hour                                                                      0   
enroll_in_edu_inst_last_wk                                               High school   
marital_stat                                                           Never married   
major_industry_code                                      Not in universe or children   
major_occupation_code                                                Not in universe   
race                                                       Asian or Pacific Islander   
hispanic_origin                                                            All other   
sex                                                                           Female   
member_of_a_labor_union                                              Not in universe   
reason_for_unemployment                                              Not in universe   
full_or_part_time_employment_stat                                 Not in labor force   
capital_gains                                                                      0   
capital_losses                                                                     0   
dividends_from_stocks                                                              0   
tax_filer_stat                                                              Nonfiler   
region_of_previous_residence                                         Not in universe   
state_of_previous_residence                                          Not in universe   
detailed_household_and_family_stat           Child 18+ never marr Not in a subfamily   
detailed_household_summary_in_household                            Child 18 or older   
instance_weight                                                               991.95   
migration_code-change_in_msa                                                       ?   
migration_code-change_in_reg                                                       ?   
migration_code-move_within_reg                                                     ?   
live_in_this_house_1_year_ago                       Not in universe under 1 year old   
migration_prev_res_in_sunbelt                                                      ?   
num_persons_worked_for_employer                                                    0   
family_members_under_18                                              Not in universe   
country_of_birth_father                                                      Vietnam   
country_of_birth_mother                                                      Vietnam   
country_of_birth_self                                                        Vietnam   
citizenship                                      Foreign born- Not a citizen of U S    
own_business_or_self_employed                                                      0   
fill_inc_questionnaire_for_veteran's_admin                           Not in universe   
veterans_benefits                                                                  2   
weeks_worked_in_year                                                               0   
year                                                                              95   
income_level                                                                       0   
                                                                                 3  \
age                                                                              9   
class_of_worker                                                    Not in universe   
detailed_industry_recode                                                         0   
detailed_occupation_recode                                                       0   
education                                                                 Children   
wage_per_hour                                                                    0   
enroll_in_edu_inst_last_wk                                         Not in universe   
marital_stat                                                         Never married   
major_industry_code                                    Not in universe or children   
major_occupation_code                                              Not in universe   
race                                                                         White   
hispanic_origin                                                          All other   
sex                                                                         Female   
member_of_a_labor_union                                            Not in universe   
reason_for_unemployment                                            Not in universe   
full_or_part_time_employment_stat                         Children or Armed Forces   
capital_gains                                                                    0   
capital_losses                                                                   0   
dividends_from_stocks                                                            0   
tax_filer_stat                                                            Nonfiler   
region_of_previous_residence                                       Not in universe   
state_of_previous_residence                                        Not in universe   
detailed_household_and_family_stat           Child <18 never marr not in subfamily   
detailed_household_summary_in_household               Child under 18 never married   
instance_weight                                                            1758.14   
migration_code-change_in_msa                                              Nonmover   
migration_code-change_in_reg                                              Nonmover   
migration_code-move_within_reg                                            Nonmover   
live_in_this_house_1_year_ago                                                  Yes   
migration_prev_res_in_sunbelt                                      Not in universe   
num_persons_worked_for_employer                                                  0   
family_members_under_18                                       Both parents present   
country_of_birth_father                                              United-States   
country_of_birth_mother                                              United-States   
country_of_birth_self                                                United-States   
citizenship                                      Native- Born in the United States   
own_business_or_self_employed                                                    0   
fill_inc_questionnaire_for_veteran's_admin                         Not in universe   
veterans_benefits                                                                0   
weeks_worked_in_year                                                             0   
year                                                                            94   
income_level                                                                     0   
                                                                                 4  
age                                                                             10  
class_of_worker                                                    Not in universe  
detailed_industry_recode                                                         0  
detailed_occupation_recode                                                       0  
education                                                                 Children  
wage_per_hour                                                                    0  
enroll_in_edu_inst_last_wk                                         Not in universe  
marital_stat                                                         Never married  
major_industry_code                                    Not in universe or children  
major_occupation_code                                              Not in universe  
race                                                                         White  
hispanic_origin                                                          All other  
sex                                                                         Female  
member_of_a_labor_union                                            Not in universe  
reason_for_unemployment                                            Not in universe  
full_or_part_time_employment_stat                         Children or Armed Forces  
capital_gains                                                                    0  
capital_losses                                                                   0  
dividends_from_stocks                                                            0  
tax_filer_stat                                                            Nonfiler  
region_of_previous_residence                                       Not in universe  
state_of_previous_residence                                        Not in universe  
detailed_household_and_family_stat           Child <18 never marr not in subfamily  
detailed_household_summary_in_household               Child under 18 never married  
instance_weight                                                            1069.16  
migration_code-change_in_msa                                              Nonmover  
migration_code-change_in_reg                                              Nonmover  
migration_code-move_within_reg                                            Nonmover  
live_in_this_house_1_year_ago                                                  Yes  
migration_prev_res_in_sunbelt                                      Not in universe  
num_persons_worked_for_employer                                                  0  
family_members_under_18                                       Both parents present  
country_of_birth_father                                              United-States  
country_of_birth_mother                                              United-States  
country_of_birth_self                                                United-States  
citizenship                                      Native- Born in the United States  
own_business_or_self_employed                                                    0  
fill_inc_questionnaire_for_veteran's_admin                         Not in universe  
veterans_benefits                                                                0  
weeks_worked_in_year                                                             0  
year                                                                            94  
income_level                                                                     0  

Configure the hyperparameters

You can find all the parameters of the Gradient Boosted Trees model in the documentation.

# Maximum number of decision trees. The effective number of trained trees can be smaller if early stopping is enabled.
NUM_TREES = 250
# Minimum number of examples in a node.
MIN_EXAMPLES = 6
# Maximum depth of the tree. max_depth=1 means that all trees will be roots.
MAX_DEPTH = 5
# Ratio of the dataset (sampling without replacement) used to train individual trees for the random sampling method.
SUBSAMPLE = 0.65
# Control the sampling of the datasets used to train individual trees.
SAMPLING_METHOD = "RANDOM"
# Ratio of the training dataset used to monitor the training. Require to be >0 if early stopping is enabled.
VALIDATION_RATIO = 0.1
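
To make the `SUBSAMPLE` comment concrete, here is a minimal sketch of sampling without replacement at a ratio of 0.65. It only illustrates the idea of a per-tree subsample; it is not how TensorFlow Decision Forests implements sampling internally, and `sample_without_replacement` is a hypothetical helper.

```python
import random


def sample_without_replacement(num_examples, ratio, seed):
    """Pick round(num_examples * ratio) distinct example indices."""
    rng = random.Random(seed)
    sample_size = round(num_examples * ratio)
    return sorted(rng.sample(range(num_examples), sample_size))


indices = sample_without_replacement(num_examples=100, ratio=0.65, seed=0)
print(len(indices))               # 65
print(len(set(indices)) == 65)    # True: no index is drawn twice
```

Because no example is drawn twice, each tree in this sketch sees 65 distinct examples out of 100, and a different seed gives a different subset.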

Implement the training and evaluation procedure

The run_experiment() method is responsible for loading the training and test datasets, training the given model, and evaluating the trained model. Note that a decision forest model needs only one epoch to read the full dataset. Any extra epochs only slow training down unnecessarily, so run_experiment() uses num_epochs=1 by default.

def run_experiment(model, train_data, test_data, num_epochs=1, batch_size=None):
    # Convert the DataFrames to tf.data datasets, keeping the label and weight columns.
    train_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(
        train_data, label=TARGET_COLUMN_NAME, weight=WEIGHT_COLUMN_NAME
    )
    test_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(
        test_data, label=TARGET_COLUMN_NAME, weight=WEIGHT_COLUMN_NAME
    )

    model.fit(train_dataset, epochs=num_epochs, batch_size=batch_size)
    _, accuracy = model.evaluate(test_dataset, verbose=0)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")

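Experiment 1: Decision Forests with raw features

Specify model input feature usages

You can attach semantics to each feature to control how the model uses it.

If no semantics are specified, they are inferred from the representation type. Specifying the feature usages explicitly is recommended, because inferred semantics can be wrong. For example, a categorical identifier stored as an integer would be inferred as numerical even though it is semantically categorical. For numerical features, you can set the discretized parameter to the number of buckets into which the feature should be discretized. This speeds up training but may lower model quality.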
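
As an illustration of what discretizing a numerical feature into buckets means, the sketch below places cut points at evenly spaced ranks of the sorted values and maps each value to a bucket id. This is a generic quantile-style scheme written for this article; the bucketing that TensorFlow Decision Forests performs internally may differ, and both helper names are hypothetical.

```python
from bisect import bisect_right


def make_bucket_boundaries(values, num_buckets):
    """Choose num_buckets - 1 cut points at evenly spaced ranks of the sorted values."""
    ordered = sorted(values)
    return [ordered[len(ordered) * i // num_buckets] for i in range(1, num_buckets)]


def discretize(value, boundaries):
    """Return the bucket id of a value: the number of cut points that are <= value."""
    return bisect_right(boundaries, value)


ages = [9, 10, 18, 25, 33, 41, 58, 73]
boundaries = make_bucket_boundaries(ages, num_buckets=4)
bucket_ids = [discretize(age, boundaries) for age in ages]

print(boundaries)  # [18, 33, 58]
print(bucket_ids)  # [0, 0, 1, 1, 2, 2, 3, 3]
```

With only four distinct bucket ids, a tree has at most three candidate split points to evaluate for this feature instead of one per distinct value, which is where the speed-up comes from; the price is that nearby values such as 18 and 25 become indistinguishable.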

def specify_feature_usages():
    feature_usages = []

    # Numeric features are declared with NUMERICAL semantics.
    for feature_name in NUMERIC_FEATURE_NAMES:
        feature_usage = tfdf.keras.FeatureUsage(
            name=feature_name, semantic=tfdf.keras.FeatureSemantic.NUMERICAL
        )
        feature_usages.append(feature_usage)

    # Categorical features are declared with CATEGORICAL semantics.
    for feature_name in CATEGORICAL_FEATURE_NAMES:
        feature_usage = tfdf.keras.FeatureUsage(
            name=feature_name, semantic=tfdf.keras.FeatureSemantic.CATEGORICAL
        )
        feature_usages.append(feature_usage)

    return feature_usages

Create a Gradient Boosted Trees model

When compiling a decision forest model, you can only provide extra evaluation metrics. The loss is specified when the model is constructed, and the optimizer is irrelevant to decision forest models.

def create_gbt_model():
    # See all the model parameters in https://www.tensorflow.org/decision_forests/api_docs/python/tfdf/keras/GradientBoostedTreesModel
    gbt_model = tfdf.keras.GradientBoostedTreesModel(
        features=specify_feature_usages(),
        exclude_non_specified_features=True,
        num_trees=NUM_TREES,
        max_depth=MAX_DEPTH,
        min_examples=MIN_EXAMPLES,
        subsample=SUBSAMPLE,
        validation_ratio=VALIDATION_RATIO,
        task=tfdf.keras.Task.CLASSIFICATION,
    )

    gbt_model.compile(metrics=[keras.metrics.BinaryAccuracy(name="accuracy")])
    return gbt_model

Train and evaluate the model

gbt_model = create_gbt_model()
run_experiment(gbt_model, train_data, test_data)
Starting reading the dataset
200/200 [==============================] - ETA: 0s
Dataset read in 0:00:08.829036
Training model
Model trained in 0:00:48.639771
Compiling model
200/200 [==============================] - 58s 268ms/step
Test accuracy: 95.79%

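Inspect the model

The model.summary() method displays several kinds of information about the decision trees model: the model type, the task, the input features, and the feature importances.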
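
Two of the importance measures in the summary, NUM_AS_ROOT and NUM_NODES, are simple counts: how many trees use a feature at the root, and how many nodes across all trees split on it. The sketch below computes both on a hand-made toy forest. The nested-dict tree representation and the three trees are invented for illustration and have nothing to do with how TensorFlow Decision Forests stores its trees.

```python
from collections import Counter

# Hypothetical toy forest. A split node is a dict; a leaf is None.
toy_forest = [
    {"feature": "education",
     "left": {"feature": "age", "left": None, "right": None},
     "right": None},
    {"feature": "capital_gains",
     "left": {"feature": "education", "left": None, "right": None},
     "right": {"feature": "age", "left": None, "right": None}},
    {"feature": "education", "left": None, "right": None},
]


def count_nodes(node, counter):
    """Walk a tree and count how often each feature is used for a split."""
    if node is None:
        return
    counter[node["feature"]] += 1
    count_nodes(node["left"], counter)
    count_nodes(node["right"], counter)


# NUM_AS_ROOT: number of trees whose root splits on the feature.
num_as_root = Counter(tree["feature"] for tree in toy_forest)

# NUM_NODES: number of split nodes, over all trees, that use the feature.
num_nodes = Counter()
for tree in toy_forest:
    count_nodes(tree, num_nodes)

print(dict(num_as_root))  # {'education': 2, 'capital_gains': 1}
print(dict(num_nodes))    # {'education': 3, 'age': 2, 'capital_gains': 1}
```

A feature can rank high on one measure and low on another: in the toy forest `age` never appears at a root but still accounts for two split nodes, which is why the summary reports several importance measures side by side.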

# summary() prints the report itself and returns None, so no print() wrapper is needed.
gbt_model.summary()
Model: "gradient_boosted_trees_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
=================================================================
Total params: 1
Trainable params: 0
Non-trainable params: 1
_________________________________________________________________
Type: "GRADIENT_BOOSTED_TREES"
Task: CLASSIFICATION
Label: "__LABEL"
Input Features (40):
    age
    capital_gains
    capital_losses
    citizenship
    class_of_worker
    country_of_birth_father
    country_of_birth_mother
    country_of_birth_self
    detailed_household_and_family_stat
    detailed_household_summary_in_household
    detailed_industry_recode
    detailed_occupation_recode
    dividends_from_stocks
    education
    enroll_in_edu_inst_last_wk
    family_members_under_18
    fill_inc_questionnaire_for_veteran's_admin
    full_or_part_time_employment_stat
    hispanic_origin
    live_in_this_house_1_year_ago
    major_industry_code
    major_occupation_code
    marital_stat
    member_of_a_labor_union
    migration_code-change_in_msa
    migration_code-change_in_reg
    migration_code-move_within_reg
    migration_prev_res_in_sunbelt
    num_persons_worked_for_employer
    own_business_or_self_employed
    race
    reason_for_unemployment
    region_of_previous_residence
    sex
    state_of_previous_residence
    tax_filer_stat
    veterans_benefits
    wage_per_hour
    weeks_worked_in_year
    year
Trained with weights
Variable Importance: MEAN_MIN_DEPTH:
    1. "enroll_in_edu_inst_last_wk"  3.942647 ################
    2. "family_members_under_18"  3.942647 ################
    3. "live_in_this_house_1_year_ago"  3.942647 ################
    4. "migration_code-change_in_msa"  3.942647 ################
    5. "migration_code-move_within_reg"  3.942647 ################
    6. "year"  3.942647 ################
    7. "__LABEL"  3.942647 ################
    8. "__WEIGHTS"  3.942647 ################
    9. "citizenship"  3.942137 ###############
    10. "detailed_household_summary_in_household"  3.942137 ###############
    11. "region_of_previous_residence"  3.942137 ###############
    12. "veterans_benefits"  3.942137 ###############
    13. "migration_prev_res_in_sunbelt"  3.940135 ###############
    14. "migration_code-change_in_reg"  3.939926 ###############
    15. "major_occupation_code"  3.937681 ###############
    16. "major_industry_code"  3.933687 ###############
    17. "reason_for_unemployment"  3.926320 ###############
    18. "hispanic_origin"  3.900776 ###############
    19. "member_of_a_labor_union"  3.894843 ###############
    20. "race"  3.878617 ###############
    21. "num_persons_worked_for_employer"  3.818566 ##############
    22. "marital_stat"  3.795667 ##############
    23. "full_or_part_time_employment_stat"  3.795431 ##############
    24. "country_of_birth_mother"  3.787967 ##############
    25. "tax_filer_stat"  3.784505 ##############
    26. "fill_inc_questionnaire_for_veteran's_admin"  3.783607 ##############
    27. "own_business_or_self_employed"  3.776398 ##############
    28. "country_of_birth_father"  3.715252 #############
    29. "sex"  3.708745 #############
    30. "class_of_worker"  3.688424 #############
    31. "weeks_worked_in_year"  3.665290 #############
    32. "state_of_previous_residence"  3.657234 #############
    33. "country_of_birth_self"  3.654377 #############
    34. "age"  3.634295 ############
    35. "wage_per_hour"  3.617817 ############
    36. "detailed_household_and_family_stat"  3.594743 ############
    37. "capital_losses"  3.439298 ##########
    38. "dividends_from_stocks"  3.423652 ##########
    39. "capital_gains"  3.222753 ########
    40. "education"  3.158698 ########
    41. "detailed_industry_recode"  2.981471 ######
    42. "detailed_occupation_recode"  2.364817 

Variable Importance: NUM_AS_ROOT:
    1. "education" 33.000000 ################
    2. "capital_gains" 29.000000 ##############
    3. "capital_losses" 24.000000 ###########
    4. "detailed_household_and_family_stat" 14.000000 ######
    5. "dividends_from_stocks" 14.000000 ######
    6. "wage_per_hour" 12.000000 #####
    7. "country_of_birth_self" 11.000000 #####
    8. "detailed_occupation_recode" 11.000000 #####
    9. "weeks_worked_in_year" 11.000000 #####
    10. "age" 10.000000 ####
    11. "state_of_previous_residence" 10.000000 ####
    12. "fill_inc_questionnaire_for_veteran's_admin"  9.000000 ####
    13. "class_of_worker"  8.000000 ###
    14. "full_or_part_time_employment_stat"  8.000000 ###
    15. "marital_stat"  8.000000 ###
    16. "own_business_or_self_employed"  8.000000 ###
    17. "sex"  6.000000 ##
    18. "tax_filer_stat"  5.000000 ##
    19. "country_of_birth_father"  4.000000 #
    20. "race"  3.000000 #
    21. "detailed_industry_recode"  2.000000 
    22. "hispanic_origin"  2.000000 
    23. "country_of_birth_mother"  1.000000 
    24. "num_persons_worked_for_employer"  1.000000 
    25. "reason_for_unemployment"  1.000000 

Variable Importance: NUM_NODES:
    1. "detailed_occupation_recode" 785.000000 ################
    2. "detailed_industry_recode" 668.000000 #############
    3. "capital_gains" 275.000000 #####
    4. "dividends_from_stocks" 220.000000 ####
    5. "capital_losses" 197.000000 ####
    6. "education" 178.000000 ###
    7. "country_of_birth_mother" 128.000000 ##
    8. "country_of_birth_father" 116.000000 ##
    9. "age" 114.000000 ##
    10. "wage_per_hour" 98.000000 #
    11. "state_of_previous_residence" 95.000000 #
    12. "detailed_household_and_family_stat" 78.000000 #
    13. "class_of_worker" 67.000000 #
    14. "country_of_birth_self" 65.000000 #
    15. "sex" 65.000000 #
    16. "weeks_worked_in_year" 60.000000 #
    17. "tax_filer_stat" 57.000000 #
    18. "num_persons_worked_for_employer" 54.000000 #
    19. "own_business_or_self_employed" 30.000000 
    20. "marital_stat" 26.000000 
    21. "member_of_a_labor_union" 16.000000 
    22. "fill_inc_questionnaire_for_veteran's_admin" 15.000000 
    23. "full_or_part_time_employment_stat" 15.000000 
    24. "major_industry_code" 15.000000 
    25. "hispanic_origin"  9.000000 
    26. "major_occupation_code"  7.000000 
    27. "race"  7.000000 
    28. "citizenship"  1.000000 
    29. "detailed_household_summary_in_household"  1.000000 
    30. "migration_code-change_in_reg"  1.000000 
    31. "migration_prev_res_in_sunbelt"  1.000000 
    32. "reason_for_unemployment"  1.000000 
    33. "region_of_previous_residence"  1.000000 
    34. "veterans_benefits"  1.000000 

Variable Importance: SUM_SCORE:1.                 "detailed_occupation_recode" 15392441.075369 ################2.                              "capital_gains" 5277826.822514 #####3.                                  "education" 4751749.289550 ####4.                      "dividends_from_stocks" 3792002.951255 ###5.                   "detailed_industry_recode" 2882200.882109 ##6.                                        "sex" 2559417.877325 ##7.                                        "age" 2042990.944829 ##8.                             "capital_losses" 1735728.772551 #9.                       "weeks_worked_in_year" 1272820.203971 #10.                             "tax_filer_stat" 697890.160846 11.            "num_persons_worked_for_employer" 671351.905595 12.         "detailed_household_and_family_stat" 444620.829557 13.                            "class_of_worker" 362250.565331 14.                    "country_of_birth_mother" 296311.574426 15.                    "country_of_birth_father" 258198.889206 16.                              "wage_per_hour" 239764.219048 17.                "state_of_previous_residence" 237687.602572 18.                      "country_of_birth_self" 103002.168158 19.                               "marital_stat" 102449.735314 20.              "own_business_or_self_employed" 82938.893541 21. "fill_inc_questionnaire_for_veteran's_admin" 22692.700206 22.          "full_or_part_time_employment_stat" 19078.398837 23.                        "major_industry_code" 18450.345505 24.                    "member_of_a_labor_union" 14905.360879 25.                            "hispanic_origin" 12602.867902 26.                      "major_occupation_code" 8709.665989 27.                                       "race" 6116.282065 28.                                "citizenship" 3291.490393 29.    "detailed_household_summary_in_household" 2733.439375 30.                          "veterans_benefits" 1230.940488 31.               
"region_of_previous_residence" 1139.240981 32.                    "reason_for_unemployment" 219.245124 33.               "migration_code-change_in_reg" 55.806436 34.              "migration_prev_res_in_sunbelt" 37.780635 
Loss: BINOMIAL_LOG_LIKELIHOOD
Validation loss value: 0.228983
Number of trees per iteration: 1
Node format: NOT_SET
Number of trees: 245
Total number of nodes: 7179
Number of nodes by tree:
Count: 245 Average: 29.302 StdDev: 2.96211
Min: 17 Max: 31 Ignored: 0
----------------------------------------------
[ 17, 18)   2   0.82%   0.82%
[ 18, 19)   0   0.00%   0.82%
[ 19, 20)   3   1.22%   2.04%
[ 20, 21)   0   0.00%   2.04%
[ 21, 22)   4   1.63%   3.67%
[ 22, 23)   0   0.00%   3.67%
[ 23, 24)  15   6.12%   9.80% #
[ 24, 25)   0   0.00%   9.80%
[ 25, 26)   5   2.04%  11.84%
[ 26, 27)   0   0.00%  11.84%
[ 27, 28)  21   8.57%  20.41% #
[ 28, 29)   0   0.00%  20.41%
[ 29, 30)  39  15.92%  36.33% ###
[ 30, 31)   0   0.00%  36.33%
[ 31, 31] 156  63.67% 100.00% ##########
Depth by leafs:
Count: 3712 Average: 3.95259 StdDev: 0.249814
Min: 2 Max: 4 Ignored: 0
----------------------------------------------
[ 2, 3)   32   0.86%   0.86%
[ 3, 4)  112   3.02%   3.88%
[ 4, 4] 3568  96.12% 100.00% ##########
Number of training obs by leaf:
Count: 3712 Average: 11849.3 StdDev: 33719.3
Min: 6 Max: 179360 Ignored: 0
----------------------------------------------
[      6,   8973) 3100  83.51%  83.51% ##########
[   8973,  17941)  148   3.99%  87.50%
[  17941,  26909)   79   2.13%  89.63%
[  26909,  35877)   36   0.97%  90.60%
[  35877,  44844)   44   1.19%  91.78%
[  44844,  53812)   17   0.46%  92.24%
[  53812,  62780)   20   0.54%  92.78%
[  62780,  71748)   39   1.05%  93.83%
[  71748,  80715)   24   0.65%  94.48%
[  80715,  89683)   12   0.32%  94.80%
[  89683,  98651)   22   0.59%  95.39%
[  98651, 107619)   21   0.57%  95.96%
[ 107619, 116586)   17   0.46%  96.42%
[ 116586, 125554)   17   0.46%  96.88%
[ 125554, 134522)   13   0.35%  97.23%
[ 134522, 143490)    8   0.22%  97.44%
[ 143490, 152457)    5   0.13%  97.58%
[ 152457, 161425)    6   0.16%  97.74%
[ 161425, 170393)   15   0.40%  98.14%
[ 170393, 179360]   69   1.86% 100.00%
Attribute in nodes:
    785 : detailed_occupation_recode [CATEGORICAL]
    668 : detailed_industry_recode [CATEGORICAL]
    275 : capital_gains [NUMERICAL]
    220 : dividends_from_stocks [NUMERICAL]
    197 : capital_losses [NUMERICAL]
    178 : education [CATEGORICAL]
    128 : country_of_birth_mother [CATEGORICAL]
    116 : country_of_birth_father [CATEGORICAL]
    114 : age [NUMERICAL]
    98 : wage_per_hour [NUMERICAL]
    95 : state_of_previous_residence [CATEGORICAL]
    78 : detailed_household_and_family_stat [CATEGORICAL]
    67 : class_of_worker [CATEGORICAL]
    65 : sex [CATEGORICAL]
    65 : country_of_birth_self [CATEGORICAL]
    60 : weeks_worked_in_year [NUMERICAL]
    57 : tax_filer_stat [CATEGORICAL]
    54 : num_persons_worked_for_employer [NUMERICAL]
    30 : own_business_or_self_employed [CATEGORICAL]
    26 : marital_stat [CATEGORICAL]
    16 : member_of_a_labor_union [CATEGORICAL]
    15 : major_industry_code [CATEGORICAL]
    15 : full_or_part_time_employment_stat [CATEGORICAL]
    15 : fill_inc_questionnaire_for_veteran's_admin [CATEGORICAL]
    9 : hispanic_origin [CATEGORICAL]
    7 : race [CATEGORICAL]
    7 : major_occupation_code [CATEGORICAL]
    1 : veterans_benefits [CATEGORICAL]
    1 : region_of_previous_residence [CATEGORICAL]
    1 : reason_for_unemployment [CATEGORICAL]
    1 : migration_prev_res_in_sunbelt [CATEGORICAL]
    1 : migration_code-change_in_reg [CATEGORICAL]
    1 : detailed_household_summary_in_household [CATEGORICAL]
    1 : citizenship [CATEGORICAL]
Attribute in nodes with depth <= 0:
    33 : education [CATEGORICAL]
    29 : capital_gains [NUMERICAL]
    24 : capital_losses [NUMERICAL]
    14 : dividends_from_stocks [NUMERICAL]
    14 : detailed_household_and_family_stat [CATEGORICAL]
    12 : wage_per_hour [NUMERICAL]
    11 : weeks_worked_in_year [NUMERICAL]
    11 : detailed_occupation_recode [CATEGORICAL]
    11 : country_of_birth_self [CATEGORICAL]
    10 : state_of_previous_residence [CATEGORICAL]
    10 : age [NUMERICAL]
    9 : fill_inc_questionnaire_for_veteran's_admin [CATEGORICAL]
    8 : own_business_or_self_employed [CATEGORICAL]
    8 : marital_stat [CATEGORICAL]
    8 : full_or_part_time_employment_stat [CATEGORICAL]
    8 : class_of_worker [CATEGORICAL]
    6 : sex [CATEGORICAL]
    5 : tax_filer_stat [CATEGORICAL]
    4 : country_of_birth_father [CATEGORICAL]
    3 : race [CATEGORICAL]
    2 : hispanic_origin [CATEGORICAL]
    2 : detailed_industry_recode [CATEGORICAL]
    1 : reason_for_unemployment [CATEGORICAL]
    1 : num_persons_worked_for_employer [NUMERICAL]
    1 : country_of_birth_mother [CATEGORICAL]
Attribute in nodes with depth <= 1:
    140 : detailed_occupation_recode [CATEGORICAL]
    82 : capital_gains [NUMERICAL]
    65 : capital_losses [NUMERICAL]
    62 : education [CATEGORICAL]
    59 : detailed_industry_recode [CATEGORICAL]
    47 : dividends_from_stocks [NUMERICAL]
    31 : wage_per_hour [NUMERICAL]
    26 : detailed_household_and_family_stat [CATEGORICAL]
    23 : age [NUMERICAL]
    22 : state_of_previous_residence [CATEGORICAL]
    21 : country_of_birth_self [CATEGORICAL]
    21 : class_of_worker [CATEGORICAL]
    20 : weeks_worked_in_year [NUMERICAL]
    20 : sex [CATEGORICAL]
    15 : country_of_birth_father [CATEGORICAL]
    12 : own_business_or_self_employed [CATEGORICAL]
    11 : fill_inc_questionnaire_for_veteran's_admin [CATEGORICAL]
    10 : num_persons_worked_for_employer [NUMERICAL]
    9 : tax_filer_stat [CATEGORICAL]
    9 : full_or_part_time_employment_stat [CATEGORICAL]
    8 : marital_stat [CATEGORICAL]
    8 : country_of_birth_mother [CATEGORICAL]
    6 : member_of_a_labor_union [CATEGORICAL]
    5 : race [CATEGORICAL]
    2 : hispanic_origin [CATEGORICAL]
    1 : reason_for_unemployment [CATEGORICAL]
Attribute in nodes with depth <= 2:
    399 : detailed_occupation_recode [CATEGORICAL]
    249 : detailed_industry_recode [CATEGORICAL]
    170 : capital_gains [NUMERICAL]
    117 : dividends_from_stocks [NUMERICAL]
    116 : capital_losses [NUMERICAL]
    87 : education [CATEGORICAL]
    59 : wage_per_hour [NUMERICAL]
    45 : detailed_household_and_family_stat [CATEGORICAL]
    43 : country_of_birth_father [CATEGORICAL]
    43 : age [NUMERICAL]
    40 : country_of_birth_self [CATEGORICAL]
    38 : state_of_previous_residence [CATEGORICAL]
    38 : class_of_worker [CATEGORICAL]
    37 : sex [CATEGORICAL]
    36 : weeks_worked_in_year [NUMERICAL]
    33 : country_of_birth_mother [CATEGORICAL]
    28 : num_persons_worked_for_employer [NUMERICAL]
    26 : tax_filer_stat [CATEGORICAL]
    14 : own_business_or_self_employed [CATEGORICAL]
    14 : marital_stat [CATEGORICAL]
    12 : full_or_part_time_employment_stat [CATEGORICAL]
    12 : fill_inc_questionnaire_for_veteran's_admin [CATEGORICAL]
    8 : member_of_a_labor_union [CATEGORICAL]
    6 : race [CATEGORICAL]
    6 : hispanic_origin [CATEGORICAL]
    2 : major_occupation_code [CATEGORICAL]
    2 : major_industry_code [CATEGORICAL]
    1 : reason_for_unemployment [CATEGORICAL]
    1 : migration_prev_res_in_sunbelt [CATEGORICAL]
    1 : migration_code-change_in_reg [CATEGORICAL]
Attribute in nodes with depth <= 3:
    785 : detailed_occupation_recode [CATEGORICAL]
    668 : detailed_industry_recode [CATEGORICAL]
    275 : capital_gains [NUMERICAL]
    220 : dividends_from_stocks [NUMERICAL]
    197 : capital_losses [NUMERICAL]
    178 : education [CATEGORICAL]
    128 : country_of_birth_mother [CATEGORICAL]
    116 : country_of_birth_father [CATEGORICAL]
    114 : age [NUMERICAL]
    98 : wage_per_hour [NUMERICAL]
    95 : state_of_previous_residence [CATEGORICAL]
    78 : detailed_household_and_family_stat [CATEGORICAL]
    67 : class_of_worker [CATEGORICAL]
    65 : sex [CATEGORICAL]
    65 : country_of_birth_self [CATEGORICAL]
    60 : weeks_worked_in_year [NUMERICAL]
    57 : tax_filer_stat [CATEGORICAL]
    54 : num_persons_worked_for_employer [NUMERICAL]
    30 : own_business_or_self_employed [CATEGORICAL]
    26 : marital_stat [CATEGORICAL]
    16 : member_of_a_labor_union [CATEGORICAL]
    15 : major_industry_code [CATEGORICAL]
    15 : full_or_part_time_employment_stat [CATEGORICAL]
    15 : fill_inc_questionnaire_for_veteran's_admin [CATEGORICAL]
    9 : hispanic_origin [CATEGORICAL]
    7 : race [CATEGORICAL]
    7 : major_occupation_code [CATEGORICAL]
    1 : veterans_benefits [CATEGORICAL]
    1 : region_of_previous_residence [CATEGORICAL]
    1 : reason_for_unemployment [CATEGORICAL]
    1 : migration_prev_res_in_sunbelt [CATEGORICAL]
    1 : migration_code-change_in_reg [CATEGORICAL]
    1 : detailed_household_summary_in_household [CATEGORICAL]
    1 : citizenship [CATEGORICAL]
Attribute in nodes with depth <= 5:
    785 : detailed_occupation_recode [CATEGORICAL]
    668 : detailed_industry_recode [CATEGORICAL]
    275 : capital_gains [NUMERICAL]
    220 : dividends_from_stocks [NUMERICAL]
    197 : capital_losses [NUMERICAL]
    178 : education [CATEGORICAL]
    128 : country_of_birth_mother [CATEGORICAL]
    116 : country_of_birth_father [CATEGORICAL]
    114 : age [NUMERICAL]
    98 : wage_per_hour [NUMERICAL]
    95 : state_of_previous_residence [CATEGORICAL]
    78 : detailed_household_and_family_stat [CATEGORICAL]
    67 : class_of_worker [CATEGORICAL]
    65 : sex [CATEGORICAL]
    65 : country_of_birth_self [CATEGORICAL]
    60 : weeks_worked_in_year [NUMERICAL]
    57 : tax_filer_stat [CATEGORICAL]
    54 : num_persons_worked_for_employer [NUMERICAL]
    30 : own_business_or_self_employed [CATEGORICAL]
    26 : marital_stat [CATEGORICAL]
    16 : member_of_a_labor_union [CATEGORICAL]
    15 : major_industry_code [CATEGORICAL]
    15 : full_or_part_time_employment_stat [CATEGORICAL]
    15 : fill_inc_questionnaire_for_veteran's_admin [CATEGORICAL]
    9 : hispanic_origin [CATEGORICAL]
    7 : race [CATEGORICAL]
    7 : major_occupation_code [CATEGORICAL]
    1 : veterans_benefits [CATEGORICAL]
    1 : region_of_previous_residence [CATEGORICAL]
    1 : reason_for_unemployment [CATEGORICAL]
    1 : migration_prev_res_in_sunbelt [CATEGORICAL]
    1 : migration_code-change_in_reg [CATEGORICAL]
    1 : detailed_household_summary_in_household [CATEGORICAL]
    1 : citizenship [CATEGORICAL]
Condition type in nodes:
    2418 : ContainsBitmapCondition
    1018 : HigherCondition
    31 : ContainsCondition
Condition type in nodes with depth <= 0:
    137 : ContainsBitmapCondition
    101 : HigherCondition
    7 : ContainsCondition
Condition type in nodes with depth <= 1:
    448 : ContainsBitmapCondition
    278 : HigherCondition
    9 : ContainsCondition
Condition type in nodes with depth <= 2:
    1097 : ContainsBitmapCondition
    569 : HigherCondition
    17 : ContainsCondition
Condition type in nodes with depth <= 3:
    2418 : ContainsBitmapCondition
    1018 : HigherCondition
    31 : ContainsCondition
Condition type in nodes with depth <= 5:
    2418 : ContainsBitmapCondition
    1018 : HigherCondition
    31 : ContainsCondition
None

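Experiment 2: Decision Forests with target encoding

Target encoding is a common preprocessing technique for converting categorical features into numerical features. Using categorical features with high cardinality as-is may lead to overfitting. Target encoding replaces each categorical feature value with one or more numerical values that represent how often it co-occurs with the target labels.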

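More precisely, given a categorical feature, the binary target encoder in this example produces three new numerical features:

1. positive_frequency: how many times each feature value occurred with a positive target label.
2. negative_frequency: how many times each feature value occurred with a negative target label.
3. positive_probability: the probability that the target label is positive given the feature value, computed as positive_frequency / (positive_frequency + negative_frequency + correction).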

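The correction term is added to make the division more stable for rare categorical values. Its default value is 1.0. Note that target encoding is effective with models that cannot automatically learn dense representations of categorical features, such as decision forests or kernel methods. If you use a neural network model, it is recommended to encode categorical features as embeddings instead.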
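As a rough illustration of the three statistics described above, the following plain-Python sketch computes them for a toy feature. The function name `binary_target_encode` and the sample data are made up for this illustration and are not part of the original example, which works on integer indices with TensorFlow ops.

```python
from collections import Counter


def binary_target_encode(feature_values, target_values, correction=1.0):
    """Map each feature value to
    (positive_frequency, negative_frequency, positive_probability)."""
    pairs = list(zip(feature_values, target_values))
    positive = Counter(value for value, target in pairs if target == 1)
    negative = Counter(value for value, target in pairs if target == 0)
    encoding = {}
    for value in sorted(set(feature_values)):
        pos, neg = positive[value], negative[value]
        # The correction term keeps the estimate away from 0 or 1
        # when a value has only been seen a handful of times.
        encoding[value] = (pos, neg, pos / (pos + neg + correction))
    return encoding


# Hypothetical toy data: one categorical feature and a binary target.
features = ["a", "a", "a", "b", "b", "c", "d"]
targets = [1, 1, 0, 0, 1, 0, 1]

encoding = binary_target_encode(features, targets)
for value, stats in encoding.items():
    print(value, stats)
```

The value "d" appears only once, with a positive label. Without the correction its positive probability would be 1.0; with the default correction of 1.0 it becomes 1 / (1 + 0 + 1) = 0.5, which is a more cautious estimate for such a rare value.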

Implement the binary target encoder

For simplicity, we assume that the inputs to the adapt and call methods have the expected data types and shapes, so no validation logic is added. It is recommended to pass the vocabulary_size of the categorical feature to the BinaryTargetEncoding constructor. If it is not specified, it is computed when adapt() runs.

class BinaryTargetEncoding(layers.Layer):
    def __init__(self, vocabulary_size=None, correction=1.0, **kwargs):
        super().__init__(**kwargs)
        self.vocabulary_size = vocabulary_size
        self.correction = correction
        # Filled in by adapt(); stays None until the statistics are computed.
        self.target_encoding_statistics = None

    def adapt(self, data):
        # data is expected to be an integer NumPy array or Tensor of shape
        # [num_examples, 2]. Column 0 holds the feature values of a given
        # feature in the dataset, column 1 holds the target values.

        # Convert the data to a tensor.
        data = tf.convert_to_tensor(data)
        # Separate the feature values and target values.
        feature_values = tf.cast(data[:, 0], tf.dtypes.int32)
        target_values = tf.cast(data[:, 1], tf.dtypes.bool)

        # Compute the vocabulary_size if not specified.
        if self.vocabulary_size is None:
            self.vocabulary_size = tf.unique(feature_values).y.shape[0]

        # Filter the data where the target label is positive.
        positive_indices = tf.where(condition=target_values)
        positive_feature_values = tf.gather_nd(
            params=feature_values, indices=positive_indices
        )
        # Compute how many times each feature value occurred with a positive target label.
        positive_frequency = tf.math.unsorted_segment_sum(
            data=tf.ones(
                shape=(positive_feature_values.shape[0], 1), dtype=tf.dtypes.float64
            ),
            segment_ids=positive_feature_values,
            num_segments=self.vocabulary_size,
        )

        # Filter the data where the target label is negative.
        negative_indices = tf.where(condition=tf.math.logical_not(target_values))
        negative_feature_values = tf.gather_nd(
            params=feature_values, indices=negative_indices
        )
        # Compute how many times each feature value occurred with a negative target label.
        negative_frequency = tf.math.unsorted_segment_sum(
            data=tf.ones(
                shape=(negative_feature_values.shape[0], 1), dtype=tf.dtypes.float64
            ),
            segment_ids=negative_feature_values,
            num_segments=self.vocabulary_size,
        )
        # Compute the positive probability for the input feature values.
        positive_probability = positive_frequency / (
            positive_frequency + negative_frequency + self.correction
        )
        # Concatenate the computed statistics for target encoding.
        target_encoding_statistics = tf.cast(
            tf.concat(
                [positive_frequency, negative_frequency, positive_probability], axis=1
            ),
            dtype=tf.dtypes.float32,
        )
        self.target_encoding_statistics = tf.constant(target_encoding_statistics)

    def call(self, inputs):
        # inputs is expected to be an integer NumPy array or Tensor of shape
        # [num_examples, 1]. It holds the feature values of a given feature.

        # Raise an error if the target encoding statistics are not computed.
        if self.target_encoding_statistics is None:
            raise ValueError(
                "You need to call the adapt method to compute target encoding statistics."
            )

        # Convert the inputs to a tensor.
        inputs = tf.convert_to_tensor(inputs)
        # Cast the inputs to an int64 tensor.
        inputs = tf.cast(inputs, tf.dtypes.int64)
        # Look up the target encoding statistics for the input feature values.
        target_encoding_statistics = tf.cast(
            tf.gather_nd(self.target_encoding_statistics, inputs),
            dtype=tf.dtypes.float32,
        )
        return target_encoding_statistics

Let's test the binary target encoder

data = tf.constant(
    [
        [0, 1],
        [2, 0],
        [0, 1],
        [1, 1],
        [1, 1],
        [2, 0],
        [1, 0],
        [0, 1],
        [2, 1],
        [1, 0],
        [0, 1],
        [2, 0],
        [0, 1],
        [1, 1],
        [1, 1],
        [2, 0],
        [1, 0],
        [0, 1],
        [2, 0],
    ]
)

binary_target_encoder = BinaryTargetEncoding()
binary_target_encoder.adapt(data)
print(binary_target_encoder([[0], [1], [2]]))

tf.Tensor(
[[6.         0.         0.85714287]
 [4.         3.         0.5       ]
 [1.         5.         0.14285715]], shape=(3, 3), dtype=float32)

Create model inputs

def create_model_inputs():
    inputs = {}
    for feature_name in NUMERIC_FEATURE_NAMES:
        inputs[feature_name] = layers.Input(
            name=feature_name, shape=(), dtype=tf.float32
        )
    for feature_name in CATEGORICAL_FEATURE_NAMES:
        inputs[feature_name] = layers.Input(
            name=feature_name, shape=(), dtype=tf.string
        )
    return inputs

Implement feature encoding with target encoding

def create_target_encoder():
    inputs = create_model_inputs()
    target_values = train_data[[TARGET_COLUMN_NAME]].to_numpy()
    encoded_features = []
    for feature_name in inputs:
        if feature_name in CATEGORICAL_FEATURE_NAMES:
            # Get the vocabulary of the categorical feature.
            vocabulary = sorted(
                [str(value) for value in list(train_data[feature_name].unique())]
            )
            # Create a lookup to convert string values to integer indices.
            # Since we are not using a mask token nor expecting any out-of-vocabulary
            # (OOV) token, we set mask_token to None and num_oov_indices to 0.
            lookup = layers.StringLookup(
                vocabulary=vocabulary, mask_token=None, num_oov_indices=0
            )
            # Convert the string input values into integer indices.
            value_indices = lookup(inputs[feature_name])
            # Prepare the data to adapt the target encoding.
            print("### Adapting target encoding for:", feature_name)
            feature_values = train_data[[feature_name]].to_numpy().astype(str)
            feature_value_indices = lookup(feature_values)
            data = tf.concat([feature_value_indices, target_values], axis=1)
            feature_encoder = BinaryTargetEncoding()
            feature_encoder.adapt(data)
            # Convert the feature value indices to target encoding representations.
            encoded_feature = feature_encoder(tf.expand_dims(value_indices, -1))
        else:
            # Expand the dimensions of the numerical input feature and use it as-is.
            encoded_feature = tf.expand_dims(inputs[feature_name], -1)
        # Add the encoded feature to the list.
        encoded_features.append(encoded_feature)
    # Concatenate all the encoded features.
    encoded_features = tf.concat(encoded_features, axis=1)
    # Create and return a Keras model with encoded features as outputs.
    return keras.Model(inputs=inputs, outputs=encoded_features)

Create a Gradient Boosted Trees model with a preprocessor

In this scenario, we use the target encoder as a preprocessor for the Gradient Boosted Trees model and let the model infer the semantics of the input features.

def create_gbt_with_preprocessor(preprocessor):
    gbt_model = tfdf.keras.GradientBoostedTreesModel(
        preprocessing=preprocessor,
        num_trees=NUM_TREES,
        max_depth=MAX_DEPTH,
        min_examples=MIN_EXAMPLES,
        subsample=SUBSAMPLE,
        validation_ratio=VALIDATION_RATIO,
        task=tfdf.keras.Task.CLASSIFICATION,
    )
    gbt_model.compile(metrics=[keras.metrics.BinaryAccuracy(name="accuracy")])
    return gbt_model

Train and evaluate the model

gbt_model = create_gbt_with_preprocessor(create_target_encoder())
run_experiment(gbt_model, train_data, test_data)
### Adapting target encoding for: class_of_worker
### Adapting target encoding for: detailed_industry_recode
### Adapting target encoding for: detailed_occupation_recode
### Adapting target encoding for: education
### Adapting target encoding for: enroll_in_edu_inst_last_wk
### Adapting target encoding for: marital_stat
### Adapting target encoding for: major_industry_code
### Adapting target encoding for: major_occupation_code
### Adapting target encoding for: race
### Adapting target encoding for: hispanic_origin
### Adapting target encoding for: sex
### Adapting target encoding for: member_of_a_labor_union
### Adapting target encoding for: reason_for_unemployment
### Adapting target encoding for: full_or_part_time_employment_stat
### Adapting target encoding for: tax_filer_stat
### Adapting target encoding for: region_of_previous_residence
### Adapting target encoding for: state_of_previous_residence
### Adapting target encoding for: detailed_household_and_family_stat
### Adapting target encoding for: detailed_household_summary_in_household
### Adapting target encoding for: migration_code-change_in_msa
### Adapting target encoding for: migration_code-change_in_reg
### Adapting target encoding for: migration_code-move_within_reg
### Adapting target encoding for: live_in_this_house_1_year_ago
### Adapting target encoding for: migration_prev_res_in_sunbelt
### Adapting target encoding for: family_members_under_18
### Adapting target encoding for: country_of_birth_father
### Adapting target encoding for: country_of_birth_mother
### Adapting target encoding for: country_of_birth_self
### Adapting target encoding for: citizenship
### Adapting target encoding for: own_business_or_self_employed
### Adapting target encoding for: fill_inc_questionnaire_for_veteran's_admin
### Adapting target encoding for: veterans_benefits
### Adapting target encoding for: year
Use /tmp/tmpj_0h78ld as temporary training directory
Starting reading the dataset
198/200 [============================>.] - ETA: 0s
Dataset read in 0:00:06.793717
Training model
Model trained in 0:04:32.752691
Compiling model
200/200 [==============================] - 280s 1s/step
Test accuracy: 95.81%

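Experiment 3: Decision Forests with trained embeddings

In this scenario, we build an encoder model that encodes the categorical features as embeddings, where the embedding size for a given categorical feature is the square root of its vocabulary size. We train these embeddings in a simple NN model through backpropagation.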
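To make the sizing rule concrete, here is a minimal sketch of the square-root heuristic in plain Python. The helper name `embedding_size` and the vocabulary sizes are hypothetical and chosen only for illustration; the encoder in this example applies the same `int(math.sqrt(...))` rule to the actual vocabulary of each feature.

```python
import math


def embedding_size(vocabulary_size):
    # Square root of the vocabulary size, truncated to an integer.
    return int(math.sqrt(vocabulary_size))


# Hypothetical vocabulary sizes, not measured from the dataset.
example_vocabulary_sizes = {"feature_a": 2, "feature_b": 17, "feature_c": 52}

sizes = {
    name: embedding_size(vocabulary_size)
    for name, vocabulary_size in example_vocabulary_sizes.items()
}
print(sizes)
```

With this rule, low-cardinality features get very small embeddings (a single dimension for a two-value feature), while the embedding width grows slowly as the vocabulary gets larger.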

After the embedding encoder is trained, we use it as a preprocessor for the input features of a Gradient Boosted Trees model.

Note that the embeddings and the decision forest model cannot be trained jointly in a single phase, because decision forest models are not trained with backpropagation. Instead, the embeddings have to be trained in an initial phase and then used as static inputs to the decision forest model.

Implement feature encoding with embeddings

def create_embedding_encoder(size=None):
    inputs = create_model_inputs()
    encoded_features = []
    for feature_name in inputs:
        if feature_name in CATEGORICAL_FEATURE_NAMES:
            # Get the vocabulary of the categorical feature.
            vocabulary = sorted(
                [str(value) for value in list(train_data[feature_name].unique())]
            )
            # Create a lookup to convert string values to integer indices.
            # Since we are not using a mask token nor expecting any out-of-vocabulary
            # (OOV) token, we set mask_token to None and num_oov_indices to 0.
            lookup = layers.StringLookup(
                vocabulary=vocabulary, mask_token=None, num_oov_indices=0
            )
            # Convert the string input values into integer indices.
            value_index = lookup(inputs[feature_name])
            # Create an embedding layer with the specified dimensions.
            vocabulary_size = len(vocabulary)
            embedding_size = int(math.sqrt(vocabulary_size))
            feature_encoder = layers.Embedding(
                input_dim=len(vocabulary), output_dim=embedding_size
            )
            # Convert the index values to embedding representations.
            encoded_feature = feature_encoder(value_index)
        else:
            # Expand the dimensions of the numerical input feature and use it as-is.
            encoded_feature = tf.expand_dims(inputs[feature_name], -1)
        # Add the encoded feature to the list.
        encoded_features.append(encoded_feature)
    # Concatenate all the encoded features.
    encoded_features = layers.concatenate(encoded_features, axis=1)
    # Apply dropout.
    encoded_features = layers.Dropout(rate=0.25)(encoded_features)
    # Apply a non-linear projection.
    encoded_features = layers.Dense(
        units=size if size else encoded_features.shape[-1], activation="gelu"
    )(encoded_features)
    # Create and return a Keras model with encoded features as outputs.
    return keras.Model(inputs=inputs, outputs=encoded_features)

Build an NN model to train the embeddings

def create_nn_model(encoder):
    inputs = create_model_inputs()
    embeddings = encoder(inputs)
    output = layers.Dense(units=1, activation="sigmoid")(embeddings)

    nn_model = keras.Model(inputs=inputs, outputs=output)
    nn_model.compile(
        optimizer=keras.optimizers.Adam(),
        loss=keras.losses.BinaryCrossentropy(),
        metrics=[keras.metrics.BinaryAccuracy("accuracy")],
    )
    return nn_model


embedding_encoder = create_embedding_encoder(size=64)
run_experiment(
    create_nn_model(embedding_encoder),
    train_data,
    test_data,
    num_epochs=5,
    batch_size=256,
)
Epoch 1/5
200/200 [==============================] - 10s 27ms/step - loss: 8303.1455 - accuracy: 0.9193
Epoch 2/5
200/200 [==============================] - 5s 27ms/step - loss: 1019.4900 - accuracy: 0.9371
Epoch 3/5
200/200 [==============================] - 5s 27ms/step - loss: 612.2844 - accuracy: 0.9416
Epoch 4/5
200/200 [==============================] - 5s 27ms/step - loss: 858.9774 - accuracy: 0.9397
Epoch 5/5
200/200 [==============================] - 5s 26ms/step - loss: 842.3922 - accuracy: 0.9421
Test accuracy: 95.0%

Train and evaluate a Gradient Boosted Trees model with embeddings

gbt_model = create_gbt_with_preprocessor(embedding_encoder)
run_experiment(gbt_model, train_data, test_data)
Use /tmp/tmpao5o88p6 as temporary training directory
Starting reading the dataset
199/200 [============================>.] - ETA: 0s
Dataset read in 0:00:06.722677
Training model
Model trained in 0:05:18.350298
Compiling model
200/200 [==============================] - 325s 2s/step
Test accuracy: 95.82%

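Concluding remarks

TensorFlow Decision Forests provides powerful models, especially for structured data. In our experiments, the Gradient Boosted Trees model achieved 95.79% test accuracy. With target encoding applied to the categorical features, the same model achieved 95.81% test accuracy. With pretrained embeddings as inputs to the Gradient Boosted Trees model, we achieved 95.82% test accuracy.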

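Decision forests can be used together with neural networks, either by:

1) using a neural network to learn useful representations of the input data, and then using a decision forest for the supervised learning task; or

2) creating an ensemble of decision forest and neural network models.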
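The second option can be sketched in several ways. One simple possibility, shown here only as an assumption and not something the experiments above ran, is to blend the positive-class probabilities predicted by the two models with a weighted average. The function name `average_ensemble`, the `forest_weight` parameter, and the probabilities are all hypothetical.

```python
def average_ensemble(forest_probs, nn_probs, forest_weight=0.5):
    """Blend two lists of positive-class probabilities with a weighted average."""
    if len(forest_probs) != len(nn_probs):
        raise ValueError("Both models must score the same examples.")
    if not 0.0 <= forest_weight <= 1.0:
        raise ValueError("forest_weight must be between 0 and 1.")
    return [
        forest_weight * forest_p + (1.0 - forest_weight) * nn_p
        for forest_p, nn_p in zip(forest_probs, nn_probs)
    ]


# Hypothetical predicted probabilities for three test examples.
forest_probs = [0.9, 0.2, 0.6]
nn_probs = [0.7, 0.4, 0.3]

blended = average_ensemble(forest_probs, nn_probs)
# Turn the blended probabilities into class labels with a 0.5 threshold.
labels = [int(p >= 0.5) for p in blended]
print(blended, labels)
```

In practice the weight would be tuned on a validation set, and the inputs would come from the `predict` output of each trained model.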

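Note that TensorFlow Decision Forests does not (yet) support hardware accelerators. All training and inference is done on the CPU. In addition, the training procedure of decision forests requires a finite dataset that fits in memory. However, increasing the size of the dataset yields diminishing returns, and decision forest algorithms arguably need fewer examples to converge than large neural network models.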

