工具系列:TensorFlow决策森林_(6)模型预测

文章目录

    • 重要提示
    • 设置
    • `model.predict(...)` 和 `pd_dataframe_to_tf_dataset` 函数
    • `model.predict(...)` 和手动的TF数据集
    • `model.predict(...)`和`model.predict_on_batch()`在字典上的使用
    • 使用YDF格式进行推理

TensorFlow决策森林TF-DF)的 预测
在本文中,您将学习使用 Python API使用之前训练过的 TF-DF模型生成预测的不同方法。

备注:在本文中展示的Python API易于使用,非常适合实验。然而,其他API,如TensorFlow Serving和C++ API更适合生产系统,因为它们更快速和更稳定。所有Serving API的详尽列表可在这里找到。

在本文中,您将会:

  1. 使用model.predict()函数在使用pd_dataframe_to_tf_dataset创建的TensorFlow数据集上进行预测。
  2. 使用model.predict()函数在手动创建的TensorFlow数据集上进行预测。
  3. 使用model.predict()函数在Numpy数组上进行预测。
  4. 使用CLI API进行预测。
  5. 使用CLI API对模型的推理速度进行基准测试。

重要提示

用于预测的数据集应与用于训练的数据集具有相同的特征名称和类型。如果未能这样做,很可能会引发错误。

例如,使用两个特征f1f2训练模型,并尝试在没有f2的数据集上生成预测将失败。请注意,将(某些或全部)特征值设置为“缺失”是可以的。同样,如果训练一个f2是数值特征(例如,float32)的模型,并将该模型应用于f2是文本特征(例如,字符串)的数据集,将会失败。

尽管Keras API对其进行了抽象,但在Python中实例化的模型(例如,使用tfdf.keras.RandomForestModel())和从磁盘加载的模型(例如,使用tf.keras.models.load_model())可能会有不同的行为。值得注意的是,Python实例化的模型会自动应用必要的类型转换。例如,如果将float64特征提供给期望float32特征的模型,这种转换会隐式地执行。然而,对于从磁盘加载的模型,这种转换是不可能的。因此,训练数据和推断数据的类型始终要完全相同。

设置

首先,我们安装 TensorFlow Decision Forests…

# 安装tensorflow_decision_forests库
!pip install tensorflow_decision_forests
Collecting tensorflow_decision_forestsUsing cached tensorflow_decision_forests-1.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.2 MB)
Requirement already satisfied: wheel in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (0.37.1)
Requirement already satisfied: six in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (1.16.0)
Requirement already satisfied: absl-py in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (1.3.0)
Requirement already satisfied: tensorflow~=2.11.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (2.11.0)
Collecting wurlitzerUsing cached wurlitzer-3.0.3-py3-none-any.whl (7.3 kB)
Requirement already satisfied: numpy in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (1.24.0rc2)
Requirement already satisfied: pandas in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (1.5.2)
Requirement already satisfied: tensorflow-estimator<2.12,>=2.11.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (2.11.0)
Requirement already satisfied: h5py>=2.9.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (3.7.0)
Requirement already satisfied: wrapt>=1.11.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (1.14.1)
Requirement already satisfied: opt-einsum>=2.3.2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (3.3.0)
Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (0.28.0)
Requirement already satisfied: libclang>=13.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (14.0.6)
Requirement already satisfied: packaging in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (22.0)
Requirement already satisfied: grpcio<2.0,>=1.24.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (1.51.1)
Requirement already satisfied: typing-extensions>=3.6.6 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (4.4.0)
Requirement already satisfied: gast<=0.4.0,>=0.2.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (0.4.0)
Requirement already satisfied: protobuf<3.20,>=3.9.2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (3.19.6)
Requirement already satisfied: termcolor>=1.1.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (2.1.1)
Requirement already satisfied: setuptools in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (65.6.3)
Requirement already satisfied: astunparse>=1.6.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (1.6.3)
Requirement already satisfied: google-pasta>=0.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (0.2.0)
Requirement already satisfied: flatbuffers>=2.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (22.12.6)
Requirement already satisfied: keras<2.12,>=2.11.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (2.11.0)
Requirement already satisfied: tensorboard<2.12,>=2.11 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow~=2.11.0->tensorflow_decision_forests) (2.11.0)
Requirement already satisfied: python-dateutil>=2.8.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pandas->tensorflow_decision_forests) (2.8.2)
Requirement already satisfied: pytz>=2020.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pandas->tensorflow_decision_forests) (2022.6)
Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (0.4.6)
Requirement already satisfied: werkzeug>=1.0.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (2.2.2)
Requirement already satisfied: requests<3,>=2.21.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (2.28.1)
Requirement already satisfied: google-auth<3,>=1.6.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (2.15.0)
Requirement already satisfied: markdown>=2.6.8 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (3.4.1)
Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (0.6.1)
Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (1.8.1)
Requirement already satisfied: pyasn1-modules>=0.2.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (0.3.0rc1)
Requirement already satisfied: cachetools<6.0,>=2.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (5.2.0)
Requirement already satisfied: rsa<5,>=3.1.4 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (4.9)
Requirement already satisfied: requests-oauthlib>=0.7.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (1.3.1)
Requirement already satisfied: importlib-metadata>=4.4 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from markdown>=2.6.8->tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (5.1.0)
Requirement already satisfied: idna<4,>=2.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (3.4)
Requirement already satisfied: certifi>=2017.4.17 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (2022.12.7)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (1.26.13)
Requirement already satisfied: charset-normalizer<3,>=2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (2.1.1)
Requirement already satisfied: MarkupSafe>=2.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from werkzeug>=1.0.1->tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (2.1.1)
Requirement already satisfied: zipp>=0.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (3.11.0)
Requirement already satisfied: pyasn1<0.6.0,>=0.4.6 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (0.5.0rc2)
Requirement already satisfied: oauthlib>=3.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard<2.12,>=2.11->tensorflow~=2.11.0->tensorflow_decision_forests) (3.2.2)
Installing collected packages: wurlitzer, tensorflow_decision_forests
Successfully installed tensorflow_decision_forests-1.1.0 wurlitzer-3.0.3

…,并导入此示例中使用的库。

# 导入所需的库
import tensorflow_decision_forests as tfdf  # 导入决策森林库
import os  # 导入操作系统库
import numpy as np  # 导入numpy库,用于数值计算
import pandas as pd  # 导入pandas库,用于数据处理
import tensorflow as tf  # 导入tensorflow库,用于构建和训练模型
import math  # 导入math库,用于数学计算
2022-12-14 12:06:51.603857: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-14 12:06:51.603946: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-14 12:06:51.603955: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

model.predict(...)pd_dataframe_to_tf_dataset 函数

TensorFlow Decision Forests 实现了 Keras 模型 API。
因此,TF-DF 模型具有 predict 函数用于进行预测。该函数以 TensorFlow Dataset 作为输入,并输出一个预测数组。
创建 TensorFlow dataset 的最简单方法是使用 Pandas 和 tfdf.keras.pd_dataframe_to_tf_dataset(...) 函数。

下面的示例展示了如何使用 pd_dataframe_to_tf_dataset 创建一个 TensorFlow dataset。

# 创建一个名为pd_dataset的DataFrame对象
pd_dataset = pd.DataFrame({"feature_1": [1,2,3],  # 创建一个名为feature_1的列,包含值1,2,3"feature_2": ["a", "b", "c"],  # 创建一个名为feature_2的列,包含值"a","b","c""label": [0, 1, 0],  # 创建一个名为label的列,包含值0,1,0
})
feature_1feature_2label
01a0
12b1
23c0
# 将Pandas数据集转换为TensorFlow数据集
tf_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(pd_dataset, label="label")# 遍历TensorFlow数据集中的每个样本
for features, label in tf_dataset:# 打印特征print("Features:", features)# 打印标签print("label:", label)
Features: {'feature_1': <tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 2, 3])>, 'feature_2': <tf.Tensor: shape=(3,), dtype=string, numpy=array([b'a', b'b', b'c'], dtype=object)>}
label: tf.Tensor([0 1 0], shape=(3,), dtype=int64)

注意:“pd_”代表“pandas”。 “tf_”代表“TensorFlow”。

TensorFlow数据集是一个输出值序列的函数。这些值可以是简单的数组(称为张量),也可以是组织成结构的数组(例如,组织在字典中的数组)。

以下示例展示了在一个玩具数据集上进行训练和推断(使用predict)的过程:

# 创建一个Pandas的训练数据集
pd_train_dataset = pd.DataFrame({"feature_1": np.random.rand(1000),  # 创建一个包含1000个随机数的特征1列"feature_2": np.random.rand(1000),  # 创建一个包含1000个随机数的特征2列
})# 添加一个标签列,标签值为特征1是否大于特征2的布尔值
pd_train_dataset["label"] = pd_train_dataset["feature_1"] > pd_train_dataset["feature_2"] # 返回创建的训练数据集
pd_train_dataset
feature_1feature_2label
00.6830350.952359False
10.4866410.669202False
20.6855800.967570False
30.2338150.725952False
40.2501870.503956False
............
9950.6766690.043817True
9960.5648270.605345False
9970.9969680.488901True
9980.9873900.097840True
9990.6921320.738431False

1000 rows × 3 columns

# 创建一个包含两个特征的数据集
pd_serving_dataset = pd.DataFrame({"feature_1": np.random.rand(500),  # 创建一个包含500个随机数的特征1列"feature_2": np.random.rand(500),  # 创建一个包含500个随机数的特征2列
})# 输出数据集
pd_serving_dataset
feature_1feature_2
00.3264670.689151
10.8074470.075198
20.0950110.947676
30.8513190.819100
40.4883050.274047
.........
4950.4808030.238047
4960.6335650.722966
4970.9452470.128379
4980.2679380.503427
4990.1858480.901847

500 rows × 2 columns

让我们将Pandas数据框转换为TensorFlow数据集:

# 将Pandas数据集转换为TensorFlow数据集
tf_train_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(pd_train_dataset, label="label")# 将Pandas数据集转换为用于模型服务的TensorFlow数据集
tf_serving_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(pd_serving_dataset)

我们现在可以在tf_train_dataset上训练一个模型:

# 创建一个RandomForestModel对象,并设置verbose参数为0(不显示训练过程的详细信息)
model = tfdf.keras.RandomForestModel(verbose=0)# 使用tf_train_dataset数据集对模型进行训练
model.fit(tf_train_dataset)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089[INFO 2022-12-14T12:06:58.981628493+00:00 kernel.cc:1175] Loading model from path /tmpfs/tmp/tmp0b3hukdi/model/ with prefix 0234a68d9d6c49ee
[INFO 2022-12-14T12:06:59.017961685+00:00 abstract_model.cc:1306] Engine "RandomForestOptPred" built
[INFO 2022-12-14T12:06:59.017993244+00:00 kernel.cc:1021] Use fast generic engineWARNING:tensorflow:AutoGraph could not transform <function simple_ml_inference_op_with_handle at 0x7f76793294c0> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: could not get source code
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING: AutoGraph could not transform <function simple_ml_inference_op_with_handle at 0x7f76793294c0> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: could not get source code
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert<keras.callbacks.History at 0x7f76701969d0>

然后在tf_serving_dataset上生成预测结果:

# 使用模型对tf_serving_dataset进行预测,并打印出前10个预测结果
predictions = model.predict(tf_serving_dataset, verbose=0)[:10]
print(predictions)
array([[0.        ],[0.99999917],[0.        ],[0.29666647],[0.99999917],[0.        ],[0.99999917],[0.99999917],[0.99999917],[0.        ]], dtype=float32)

model.predict(...) 和手动的TF数据集

在前一节中,我们展示了如何使用pd_dataframe_to_tf_dataset函数创建一个TF数据集。这个选项简单但不适用于大型数据集。相反,TensorFlow提供了几个选项来创建一个TensorFlow数据集。
下面的例子展示了如何使用tf.data.Dataset.from_tensor_slices()函数创建一个数据集。

# 创建一个数据集对象,使用tf.data.Dataset.from_tensor_slices()方法,将一个列表[1,2,3,4,5]转换为数据集
dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5])# 遍历数据集中的每个元素
for value in dataset:# 打印当前元素的值,使用value.numpy()方法将Tensor对象转换为NumPy数组print("value:", value.numpy())
value: 1
value: 2
value: 3
value: 4
value: 5

TensorFlow 模型的训练采用小批量训练方式:而不是逐个输入,样本被分组成“批次”。对于神经网络,批次大小会影响模型的质量,最佳值需要在训练过程中由用户确定。对于决策森林,批次大小对模型没有影响。然而,为了兼容性的原因,TensorFlow 决策森林要求数据集被分批处理。可以使用 batch() 函数进行分批处理。

# 创建一个数据集对象,使用tf.data.Dataset.from_tensor_slices()方法,将一个列表[1,2,3,4,5]转换为数据集
# 使用batch()方法将数据集分成大小为2的批次
dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5]).batch(2)# 遍历数据集中的每个批次
for value in dataset:# 打印当前批次的值,使用numpy()方法将张量转换为numpy数组print("value:", value.numpy())
value: [1 2]
value: [3 4]
value: [5]

TensorFlow决策森林期望数据集具有以下两种结构之一:

  • 特征,标签
  • 特征,标签,权重

特征可以是一个二维数组(其中每列是一个特征,每行是一个示例),也可以是一个数组字典。

以下是一个与TensorFlow决策森林兼容的数据集示例:

# 创建一个包含单个2D数组的数据集
tf_dataset = tf.data.Dataset.from_tensor_slices(([[1,2],[3,4],[5,6]], # 特征[0,1,0], # 标签)).batch(2)# 遍历数据集中的每个批次
for features, label in tf_dataset:print("features:", features) # 打印特征print("label:", label) # 打印标签
features: tf.Tensor(
[[1 2][3 4]], shape=(2, 2), dtype=int32)
label: tf.Tensor([0 1], shape=(2,), dtype=int32)
features: tf.Tensor([[5 6]], shape=(1, 2), dtype=int32)
label: tf.Tensor([0], shape=(1,), dtype=int32)
# 创建一个包含特征字典的数据集
tf_dataset = tf.data.Dataset.from_tensor_slices(({"feature_1": [1,2,3], # 特征1"feature_2": [4,5,6], # 特征2},[0,1,0], # 标签)).batch(2) # 批量大小为2# 遍历数据集中的每个批次
for features, label in tf_dataset:print("features:", features) # 打印特征字典print("label:", label) # 打印标签
features: {'feature_1': <tf.Tensor: shape=(2,), dtype=int32, numpy=array([1, 2], dtype=int32)>, 'feature_2': <tf.Tensor: shape=(2,), dtype=int32, numpy=array([4, 5], dtype=int32)>}
label: tf.Tensor([0 1], shape=(2,), dtype=int32)
features: {'feature_1': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([3], dtype=int32)>, 'feature_2': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([6], dtype=int32)>}
label: tf.Tensor([0], shape=(1,), dtype=int32)

让我们使用第二个选项来训练一个模型。

# 导入必要的库已经完成,不需要再添加import语句
# 生成一个包含两个特征和一个标签的数据集
# 特征1和特征2都是100个随机数
# 标签是一个100个元素的布尔型数组,每个元素都是随机生成的,大于等于0.5为True,小于0.5为False
tf_dataset = tf.data.Dataset.from_tensor_slices(({"feature_1": np.random.rand(100),"feature_2": np.random.rand(100),},np.random.rand(100) >= 0.5, # Label)).batch(2)# 创建一个随机森林模型
# verbose=0表示不输出训练过程中的详细信息
model = tfdf.keras.RandomForestModel(verbose=0)# 使用生成的数据集进行训练
model.fit(tf_dataset)
[INFO 2022-12-14T12:07:00.416575763+00:00 kernel.cc:1175] Loading model from path /tmpfs/tmp/tmpvzrrxxmw/model/ with prefix 0bc6f955d2d1456e
[INFO 2022-12-14T12:07:00.440516186+00:00 kernel.cc:1021] Use fast generic engine<keras.callbacks.History at 0x7f75f016e220>

predict函数可以直接在训练数据集上使用:

# 使用模型对tf_dataset进行预测,verbose=0表示不显示进度条
# 返回结果为前10个预测值
model.predict(tf_dataset, verbose=0)[:10]
array([[0.43666634],[0.58999956],[0.42999968],[0.73333275],[0.75666606],[0.20666654],[0.67666614],[0.66666615],[0.82333267],[0.3999997 ]], dtype=float32)

model.predict(...)model.predict_on_batch()在字典上的使用

在某些情况下,可以使用数组(或数组字典)而不是TensorFlow数据集来使用predict函数。

以下示例使用先前训练过的模型和一个NumPy数组字典。

# 使用模型对输入数据进行预测,返回前10个预测结果
model.predict({"feature_1": np.random.rand(100),"feature_2": np.random.rand(100),}, verbose=0)[:10]
array([[0.6533328 ],[0.5399996 ],[0.2133332 ],[0.22999986],[0.16333325],[0.18333323],[0.3766664 ],[0.5066663 ],[0.20333321],[0.8633326 ]], dtype=float32)

在前面的示例中,数组会自动分批处理。或者,可以使用predict_on_batch函数来确保所有的示例都在同一个批次中运行。

# 获取前10个预测结果
model.predict_on_batch({"feature_1": np.random.rand(100),"feature_2": np.random.rand(100),})[:10]
array([[0.54666626],[0.21666653],[0.18333323],[0.5299996 ],[0.5499996 ],[0.12666662],[0.6299995 ],[0.06000001],[0.33999977],[0.08999998]], dtype=float32)

**注意:**如果predict在原始数据上无法工作,例如上面的示例,请尝试使用predict_on_batch函数或将原始数据转换为TensorFlow数据集。

使用YDF格式进行推理

这个例子展示了如何使用CLI API(其他Serving APIs之一)运行一个经过训练的TF-DF模型。我们还将使用Benchmark工具来测量模型的推理速度。

让我们先训练并保存一个模型:

# 创建一个梯度提升树模型对象,verbose参数设置为0表示不输出训练过程的详细信息
model = tfdf.keras.GradientBoostedTreesModel(verbose=0)# 将pandas的训练数据集转换为TensorFlow的数据集,并指定"label"列作为标签
train_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(pd_train_dataset, label="label")# 使用转换后的训练数据集来训练模型
model.fit(train_dataset)# 将训练好的模型保存到文件中
model.save("my_model")
2022-12-14 12:07:00.950798: W external/ydf/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.cc:1765] Subsample hyperparameter given but sampling method does not match.
2022-12-14 12:07:00.950839: W external/ydf/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.cc:1778] GOSS alpha hyperparameter given but GOSS is disabled.
2022-12-14 12:07:00.950846: W external/ydf/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.cc:1787] GOSS beta hyperparameter given but GOSS is disabled.
2022-12-14 12:07:00.950852: W external/ydf/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.cc:1799] SelGB ratio hyperparameter given but SelGB is disabled.
[INFO 2022-12-14T12:07:01.160357659+00:00 kernel.cc:1175] Loading model from path /tmpfs/tmp/tmpo37712qo/model/ with prefix 391746915b7842cb
[INFO 2022-12-14T12:07:01.164736847+00:00 kernel.cc:1021] Use fast generic engine
WARNING:absl:Found untraced functions such as call_get_leaves, _update_step_xla while saving (showing 2 of 2). These functions will not be directly callable after loading.INFO:tensorflow:Assets written to: my_model/assetsINFO:tensorflow:Assets written to: my_model/assets

让我们也将数据集导出为一个csv文件:

# 将pd_serving_dataset保存为dataset.csv文件
pd_serving_dataset.to_csv("dataset.csv")

让我们下载并提取Yggdrasil Decision Forests的CLI工具。

# 下载 Yggdrasil Decision Forests 的命令行工具
!wget https://github.com/google/yggdrasil-decision-forests/releases/download/1.0.0/cli_linux.zip# 解压缩下载的文件
!unzip cli_linux.zip
--2022-12-14 12:07:01--  https://github.com/google/yggdrasil-decision-forests/releases/download/1.0.0/cli_linux.zip
Resolving github.com (github.com)... 140.82.114.3
Connecting to github.com (github.com)|140.82.114.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/360444739/bfcd0b9d-5cbc-42a8-be0a-02131875f9a6?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20221214%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20221214T120701Z&X-Amz-Expires=300&X-Amz-Signature=94e7b8fd2c219cbe6305222b34f566360eb9fea8ea35e8303519f09b04744b93&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=360444739&response-content-disposition=attachment%3B%20filename%3Dcli_linux.zip&response-content-type=application%2Foctet-stream [following]
--2022-12-14 12:07:01--  https://objects.githubusercontent.com/github-production-release-asset-2e65be/360444739/bfcd0b9d-5cbc-42a8-be0a-02131875f9a6?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20221214%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20221214T120701Z&X-Amz-Expires=300&X-Amz-Signature=94e7b8fd2c219cbe6305222b34f566360eb9fea8ea35e8303519f09b04744b93&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=360444739&response-content-disposition=attachment%3B%20filename%3Dcli_linux.zip&response-content-type=application%2Foctet-stream
Resolving objects.githubusercontent.com (objects.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.109.133, ...
Connecting to objects.githubusercontent.com (objects.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 31516027 (30M) [application/octet-stream]
Saving to: ‘cli_linux.zip’cli_linux.zip         0%[                    ]       0  --.-KB/s               
cli_linux.zip         2%[                    ] 727.40K  3.47MB/s               
cli_linux.zip        13%[=>                  ]   4.01M  9.90MB/s               
cli_linux.zip        53%[=========>          ]  16.01M  26.1MB/s               
cli_linux.zip       100%[===================>]  30.06M  38.2MB/s    in 0.8s    2022-12-14 12:07:03 (38.2 MB/s) - ‘cli_linux.zip’ saved [31516027/31516027]Archive:  cli_linux.zipinflating: README                  inflating: cli.txt                 inflating: train                   inflating: show_model              inflating: show_dataspec           inflating: predict                 inflating: infer_dataspec          inflating: evaluate                inflating: convert_dataset         inflating: benchmark_inference     inflating: edit_model              inflating: synthetic_dataset       inflating: grpc_worker_main        inflating: LICENSE                 inflating: CHANGELOG.md            

最后,让我们进行预测:

备注:

  • TensorFlow决策森林(TF-DF)基于Yggdrasil决策森林(YDF)库,并且TF-DF模型始终在内部包含一个YDF模型。将TF-DF模型保存到磁盘时,TF-DF模型目录包含一个assets子目录,其中包含YDF模型。此YDF模型可与所有YDF工具一起使用。在下一个示例中,我们将使用predictbenchmark_inference工具。有关更多详细信息,请参阅模型格式文档。
  • YDF工具假定数据集的类型是使用前缀指定的,例如csv:。有关更多详细信息,请参阅YDF用户手册。
# 该代码是用于执行预测的脚本# 导入必要的库# 执行预测
# 使用"./predict"命令来执行预测
# "--model=my_model/assets"参数指定了模型的路径
# "--dataset=csv:dataset.csv"参数指定了数据集的路径和格式
# "--output=csv:predictions.csv"参数指定了预测结果的输出路径和格式
!./predict --model=my_model/assets --dataset=csv:dataset.csv --output=csv:predictions.csv
[INFO abstract_model.cc:1296] Engine "GradientBoostedTreesQuickScorerExtended" built
[INFO predict.cc:133] Run predictions with semi-fast engine

我们现在可以看一下预测结果:

# 读取CSV文件"predictions.csv"并将其存储为一个DataFrame对象
data = pd.read_csv("predictions.csv")
12
00.9667790.033221
10.0317730.968227
20.9667790.033221
30.6000730.399927
40.0308850.969115
.........
4950.0308850.969115
4960.9482520.051748
4970.0317730.968227
4980.9669960.033004
4990.9667790.033221

500 rows × 2 columns

模型的推理速度可以使用基准推理工具来测量。

**注意:**在YDF版本1.1.0之前,基准推理中使用的数据集需要有一个__LABEL列。

# 创建一个空的标签列
pd_serving_dataset["__LABEL"] = 0# 将数据集保存为csv文件
pd_serving_dataset.to_csv("dataset.csv")
# 运行benchmark_inference脚本进行推理性能测试# 参数说明:
# --model:指定模型的路径,这里是my_model/assets
# --dataset:指定数据集的路径和格式,这里是csv:dataset.csv,表示数据集是以csv格式存储在dataset.csv文件中
# --batch_size:指定每个推理批次的大小,这里是100
# --warmup_runs:指定预热运行的次数,用于消除冷启动的影响,这里是10次
# --num_runs:指定总共运行的次数,用于统计平均推理性能,这里是50次
!./benchmark_inference \--model=my_model/assets \--dataset=csv:dataset.csv \--batch_size=100 \--warmup_runs=10 \--num_runs=50
[INFO benchmark_inference.cc:245] Loading model
[INFO benchmark_inference.cc:248] The model is of type: GRADIENT_BOOSTED_TREES
[INFO benchmark_inference.cc:250] Loading dataset
[INFO benchmark_inference.cc:259] Found 3 compatible fast engines.
[INFO benchmark_inference.cc:262] Running GradientBoostedTreesGeneric
[INFO decision_forest.cc:639] Model loaded with 27 root(s), 1471 node(s), and 2 input feature(s).
[INFO benchmark_inference.cc:262] Running GradientBoostedTreesQuickScorerExtended
[INFO benchmark_inference.cc:262] Running GradientBoostedTreesOptPred
[INFO decision_forest.cc:639] Model loaded with 27 root(s), 1471 node(s), and 2 input feature(s).
[INFO benchmark_inference.cc:268] Running the slow generic engine
batch_size : 100  num_runs : 50
time/example(us)  time/batch(us)  method
----------------------------------------0.22425          22.425  GradientBoostedTreesOptPred [virtual interface]0.2465           24.65  GradientBoostedTreesQuickScorerExtended [virtual interface]0.6875           68.75  GradientBoostedTreesGeneric [virtual interface]1.825           182.5  Generic slow engine
----------------------------------------

在这个基准测试中,我们可以看到不同推理引擎的推理速度。例如,“time/example(us) = 0.6315”(在不同运行中可能会有所变化)表示一个示例的推理需要0.63微秒。也就是说,模型每秒可以运行约160万次。

**注意:**TF-DF和其他API总是会自动选择可用的最快推理引擎。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/578432.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

kubernetes(k8s) Yaml 文件详解

YAML格式&#xff1a;用于配置和管理&#xff0c;YAML是一种简洁的非标记性语言&#xff0c;内容格式人性化&#xff0c;较易读。 1、查看API 资源版本标签 kubectl api-versions 2、编写资源配置清单 kubectl create -f nginx-test.yaml --validatefalse 2.3 查看创建的po…

氢燃料电池商用车系统架构开发与集成技术

一、国家及不同地区对氢能发展支持政策 近三年国家对氢能及燃料电池产业的支持政策 近年来22个省市的发展规划中提到了大力支持氢能源产业发展 二、燃料电池客车架构分解及国内外已有车型 未来燃料电池客车发展方向 未来燃料电池客车新增加的燃料电池堆产业链及供应商 国内外差…

Java毕业设计——vue+springboot音乐网站音乐播放器,歌曲管理系统

1&#xff0c;项目背景 随着计算机技术的发展&#xff0c;网络技术对我们生活和工作显得越来越重要&#xff0c;特别是现在信息高度发达的今天&#xff0c;人们对最新信息的需求和发布迫切的需要及时性。为了满足不同人们对网络需求&#xff0c;各种特色&#xff0c;各种主题的…

spring初始化bean之后执行某个方法

这个问题可以分两种解释&#xff1a; 1. 某个bean初始化执行? 2. 所有bean初始化后执行? 第一个问题可以在spring bean的生命周期中找到答案&#xff1a; bean定义-实例化-初始化-销毁。注意&#xff1a; 这里的bean定义是指所有的bean定义完成&#xff0c;然后才继续执…

1.Linux是什么与如何学习

第 1 章 Linux 是什么与如何学习 历史部分略过。 1.2.5 Linux的内核版本 Linux的内核版本编号有点类似如下的样子&#xff1a; 3.10.0-123.el7.x86_64 主版本.次版本.发布版本-修改版本虽然编号就是如上的方式来编的&#xff0c;不过依据 Linux 内核的发展期程&#xff0c;…

使用代码生成器生成代码 mybatis-plus-generator

1、将相关依赖导入到项目中 <dependency><groupId>com.baomidou</groupId><artifactId>mybatis-plus-generator</artifactId><version>3.4.1</version></dependency><dependency><groupId>org.apache.velocity<…

猫头虎博主第六期赠书活动:《手机摄影短视频和后期从小白到高手》

&#x1f337;&#x1f341; 博主猫头虎 带您 Go to New World.✨&#x1f341; &#x1f984; 博客首页——猫头虎的博客&#x1f390; &#x1f433;《面试题大全专栏》 文章图文并茂&#x1f995;生动形象&#x1f996;简单易学&#xff01;欢迎大家来踩踩~&#x1f33a; &a…

移动应用专项测试:确保用户体验的全方位保障

引言&#xff1a; 随着移动应用的普及和功能的不断增加&#xff0c;开发者需要对应用进行全面的测试&#xff0c;以确保其在不同场景下的稳定性、安全性和用户体验。本文将介绍一些常见的移动应用专项测试&#xff0c;包括安装测试、卸载测试、特殊操作测试、交互测试、通知测试…

MySQL中varchar和int隐式转换的注意事项!

一、前言 在一个阳光明媚的下午&#xff0c;我们的测试在运行SQL是发现了一个灵异事件。 别着急&#xff0c;等我慢慢说来&#xff0c;是一个查询库存的SQL&#xff0c;控制台打印了&#xff0c;查询为0条记录。 想着不太信&#xff0c;自己把SQL粘出来执行一下&#xff0c;刚…

【美团大数据面试】Java面试题附答案

目录 1.多线程代码示例 2.单例代码示例 3.LinkedBlockingQueue原理解析 4.模板设计模式讲解 5.生产者-消费者队列设计方法 6.堆内存和栈内存的区别 7.ThreadLocal底层机制 8.synchronized原理&#xff0c;存在的问题&#xff0c;解决方案 9.volatile使用场景和原理&am…

解析动态规划

本文由 简悦 SimpRead 转码&#xff0c; 原文地址 juejin.cn 前言 我们刷 leetcode 的时候&#xff0c;经常会遇到动态规划类型题目。动态规划问题非常非常经典&#xff0c;也很有技巧性&#xff0c;一般大厂都非常喜欢问。今天跟大家一起来学习动态规划的套路&#xff0c;文章…

突破PHP disable_functions方法

1. 利用 LD_PRELOAD 环境变量 知识扫盲 LD_PRELOAD&#xff1a;是Linux系统的一个环境变量&#xff0c;它指定的*.so文件会在程序本身的*.so文件之前被加载。putenv()&#xff1a;PHP函数&#xff0c;可以设置环境变量mail()&#xff0c;error_log()&#xff1a;PHP函数&…

Tekton

一. 概念 Tekton 官网 Github Tekton 是一种用于构建 CI/CD 管道的云原生解决方案&#xff0c;它由提供构建块的 Tekton Pipelines&#xff0c;Tekton 作为 Kubernetes 集群上的扩展安装和运行&#xff0c;包含一组 Kubernetes 自定义资源&#xff0c;这些资源定义了您可以为…

redis-连接数占满解决

作者 马文斌 时间 2023-12-12 标签 redis 连接风暴 连接数占满 背景 近期有redis 数据库连不上&#xff0c;起初以为是redis的连接数满了&#xff0c;排查到后面发现问题不简单啊&#xff0c;下面看看具体的排查过程。 连不上的原有有哪些 密码不对 网络不好,丢包 原来…

Google模拟面试【面试】

Google模拟面试【面试】 2023-12-25 16:00:42 Google代码面试 Prompt #1 给一个二叉树&#xff0c;定义深度为结点到根&#xff1b;所要遍历的边的数量。 示例二叉树中8的深度为3&#xff0c;1的深度为0。 编写函数返回这个二叉树的所有结点的深度和。 示例二叉树答案是16 …

Openstack开启虚拟化嵌套

好久没写东西了&#xff0c;前两天我准备在虚机上装一个vmware 的虚机&#xff0c;结果失败了&#xff0c;提示如下&#xff0c;由于我是虚机上安装虚机&#xff0c;我的宿主机肯定是开启了vt-x和vt-d的 查了一些资料&#xff0c;这个需要打开nested,先看看nested返回是否为Y&a…

Unity向量按照某一点进行旋转

Unity向量按照某一点进行旋转 一、unity的旋转二、向量按照原点进行旋转注意案例 三、向量按照指定位置进行旋转案例 一、unity的旋转 首先要知道一点就是在Unity的旋转中使用过四元数进行旋转的&#xff0c;如果对一个物体的rotation直接赋值你会发现结果不是你最终想要的结果…

1111111111111111111

11111111111111111111111111

Astro学习使用记录

Astro学习使用记录 前言Astro是什么&#xff1f;问题记录1. 使用组件库2. pages 目录下不要放 除 .astro 文件以外的文件 总结 前言 Astro的出现 为了追求前端应用的性能与速度&#xff0c;近年前端界涌现出许多的解决方案&#xff0c;像SSR、SSG解决方案再到今天的island架构…

迎新辞旧,欢度元旦

迎新辞旧&#xff0c;欢度元旦 新年钟声即将敲响&#xff0c;欢度元旦的时刻即将来临。在这个美好的时刻&#xff0c;我们纷纷辞旧迎新&#xff0c;放飞自我追逐梦想的翅膀。让羊大师带大家一起来庆祝新年的到来&#xff0c;共同创造美好的开始&#xff01; 一、迎新辞旧&…