前言: 使用在teachable machine训练的h5格式模型
tensorflow使用篇
1. 使用teachable machine训练模型
地址: 传送门, 需要梯子翻一下
训练后, 导出的时候可以选择三种类型
导出模型文件 converted_keras.zip (py版)
解压后得到
2. py项目中使用模型
根据你当时使用teachable machine的时间, 选择py项目中TensorFlow的版本
我现在使用的是必须是2.3.0版本及以上才行, 然后我直接升级到了2.10.0
如果版本不匹配会报错如下
ValueError: (‘Unrecognized keyword arguments:’, dict_keys([‘ragged’]))
解决的方法就是升级TensorFlow版本
pip install tensorflow==2.10.0 --upgrade
目录结构如下
第一种app.py, 判断项目本地的图片, 可以直接使用postman请求无参get, 可以得到卡类型
# -*- coding: utf-8 -*-
import flask as fk
from flask import jsonify, request
import tensorflow as tf
from PIL import Image
import numpy as npapp = fk.Flask(__name__)# 加载标签映射
class_label_map = {}
with open('labels.txt', 'r', encoding='utf-8') as f:for line in f.readlines():index, label = line.strip().split()class_label_map[int(index)] = labelprint(class_label_map)# 加载模型
global model
model = tf.keras.models.load_model('keras_model.h5')
print('模型加载成功')# 图片预处理方法
def preprocess_image(image_path):img = Image.open(image_path)# 调整大小、归一化等操作,具体取决于模型要求img_resized = img.resize((224, 224))img_array = np.array(img_resized) / 255.0 # 将像素值归一化到[0, 1]区间img_array = np.expand_dims(img_array, axis=0) # 添加批量维度(batch size = 1)return img_array# 预测方法
def load_model():# 准备输入数据input_data = preprocess_image("danka.jpg")# input_data = preprocess_image("duolianka.jpg")# 预测predictions = model.predict(input_data)# 获取预测结果predicted_class_index = np.argmax(predictions[0])# 获取预测的类名predicted_class_name = class_label_map[predicted_class_index]print(f"Predicted class: {predicted_class_name}")return predicted_class_name# 测试预测
@app.route('/api/hello', methods=['GET'])
def get_data():return load_model()# 假设我们要提供一个获取用户信息的API
@app.route('/api/user/<int:user_id>', methods=['GET'])
def get_user_info(user_id):# 这里模拟从数据库或其他服务获取用户信息user_data = {'id': user_id, 'name': 'John Doe', 'email': 'john.doe@example.com'}# 假设用户不存在,返回404# 返回JSON格式的用户信息return jsonify(user_data)# 定义一个接收POST请求的路由,假设该接口用于创建新用户
@app.route('/api/users', methods=['POST'])
def create_user():# 从请求体中获取JSON格式的数据data = request.get_json()# 检查必要的字段是否存在if not all(key in data for key in ('username', 'email', 'password')):return jsonify({"error": "Missing required fields"}), 400# 这里仅做示例,实际开发中应将数据保存至数据库等new_user = {'username': data['username'],'email': data['email'],'password': data['password']}# 模拟用户创建成功resultMap = {"message": "User created successfully", "user": new_user}# 返回201状态码表示已创建资源return jsonify(resultMap), 201if __name__ == '__main__':app.run(host='0.0.0.0', port=5000, debug=True)