Gradio 是一个用于快速创建机器学习模型和用户界面之间交互的 Python 库。它允许你无需编写大量前端代码,就能将机器学习模型部署为可交互的网页应用。以下是一个基于 Gradio 可视化部署机器学习应用的基本步骤:
-
安装 Gradio:
首先,你需要安装 Gradio 库。你可以使用 pip 来安装:pip install gradio
-
导入 Gradio 并定义界面:
在你的 Python 脚本中,导入 Gradio,并定义输入和输出的组件。这些组件将构成你的交互界面的基础。 -
加载机器学习模型:
加载你已经训练好的机器学习模型。这可以是一个 scikit-learn 模型、TensorFlow 模型、PyTorch 模型等。 -
定义预测函数:
创建一个函数,该函数接受 Gradio 界面上的输入,使用加载的模型进行预测,并返回预测结果。 -
创建 Gradio 接口:
使用 Gradio 的Interface
类(或其简写形式gr.Interface
)来创建交互界面。你需要指定输入组件、输出组件以及预测函数。 -
启动 Gradio 应用:
调用launch()
方法来启动 Gradio 应用。默认情况下,它将在本地服务器上运行,并在浏览器中自动打开。
以下是一个简单的示例,展示了如何使用 Gradio 部署一个基于 scikit-learn 的鸢尾花分类模型:
import gradio as gr
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler# 加载数据集
iris = load_iris()
X, y = iris.data, iris.target# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 数据标准化
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)# 训练模型
model = LogisticRegression(max_iter=200)
model.fit(X_train, y_train)# 定义预测函数
def predict(sepal_length: float, sepal_width: float, petal_length: float, petal_width: float) -> str:X_new = [[sepal_length, sepal_width, petal_length, petal_width]]X_new_scaled = scaler.transform(X_new)prediction = model.predict(X_new_scaled)return iris.target_names[prediction[0]].capitalize()# 创建 Gradio 接口
iface = gr.Interface(fn=predict,inputs=[gr.inputs.NumberBox(label="Sepal Length"),gr.inputs.NumberBox(label="Sepal Width"),gr.inputs.NumberBox(label="Petal Length"),gr.inputs.NumberBox(label="Petal Width")],outputs=gr.outputs.Textbox(label="Predicted Iris Species"))# 启动 Gradio 应用
iface.launch()
在这个示例中,我们创建了一个简单的 Gradio 界面,用户可以通过输入鸢尾花的四个特征(花萼长度、花萼宽度、花瓣长度、花瓣宽度)来预测鸢尾花的种类。预测结果将以文本形式显示。
你可以根据自己的需求调整输入和输出组件,以及预测函数。Gradio 支持多种类型的输入和输出组件,如文本框、下拉菜单、图像上传、滑块等,使得创建复杂的交互界面变得非常容易。