1. 前言
本文以一名前端新手的视角,记录作者把自己的神经网络模型部署为后端、搭建个人网站并上线 AI 应用的实战经历。目前实现的网页应用有:
- AI 语音服务主页
- AI 语音识别
- AI 语音合成
- AI CP号码生成器
欢迎大家试用。本文以博客《基于GAN的序列号码预测》中训练的 PyTorch 模型为例,介绍后端和前端的搭建以及网站的构建,以下是最终成果展示。
网址:http://www.funsound.cn:5002
2. 相关内容
相关知识点和工具语言如下,希望读者有一定的了解
- 腾讯云服务器
- Html + JavaScript 进行UI设计
- pytorch 模型,onnx 模型导出
- python flask 后端
- 多进程服务实现并发访问
3. 后端工作
3.1 pytorch 模型转 onnx 模型
ONNX 是一种通用的神经网络模型交换格式。将 PyTorch 模型导出为 ONNX 格式后,借助 ONNX Runtime 在服务器 CPU 上推理通常会更快(实际加速效果取决于模型和硬件)。
# Instantiate the generator model.
generator = Generator(input_dim, output_dim)
# Load the trained generator weights. map_location='cpu' lets a checkpoint
# that was saved on a GPU machine also load on a CPU-only server.
generator.load_state_dict(torch.load('models/generator_model.pth', map_location='cpu'))
generator.eval()  # switch the generator to evaluation mode before exporting
# Export the model in ONNX format.
# NOTE(review): export_onnx is a method of the project's Generator class (it is
# not part of torch.nn.Module); presumably it wraps torch.onnx.export — confirm.
generator.export_onnx('models/generator_model.onnx', (batch_size, input_dim))
加载onnx模型进行推理
# Load the exported ONNX model on the CPU execution provider. Naming the
# provider explicitly avoids the ValueError that onnxruntime >= 1.9 raises
# when the installed build ships several providers (e.g. onnxruntime-gpu).
ort_session = ort.InferenceSession('models/generator_model.onnx',
                                   providers=['CPUExecutionProvider'])
input_name = ort_session.get_inputs()[0].name
output_name = ort_session.get_outputs()[0].name
# Random noise fed to the generator; cast to float32, presumably the dtype
# the model was exported with — confirm against the export.
input_noise = np.random.randn(batch_size, input_dim).astype(np.float32)
# run() returns one array per requested output; take the single output.
generated_numbers = ort_session.run([output_name], {input_name: input_noise})[0]
将基于 ONNX 推理的 CP 号码生成算法封装为 【generator.LOTTO_GENERATOR】。
3.2 多进程onnx服务
网站访问通常是多路并发的场景:面对众多用户的请求,需要先将请求排队等待处理,后端采用多进程进行调度。
if __name__ == "__main__":
    from generator import LOTTO_GENERATOR  # our GAN-based number generation algorithm

    # Number of worker processes; each worker gets its own backend instance.
    nj = 4
    backends = [LOTTO_GENERATOR() for _ in range(nj)]
    workers = init_workers(nj=nj, backends=backends)

    # Fetch and print the state of every worker.
    res = get_workers_state(workers)
    print(res)

    # Submit 100 demo tasks.
    worker_dir = "demo"
    for _ in range(100):
        task_id = generate_random_string(length=11)  # random 11-character string as the task_id
        task_dir = f"{worker_dir}/{task_id}"  # per-task directory
        task_inp = generate_random_number_string(length=8)  # random 8-digit string as task input
        task_prgs = f'{task_dir}/progress.txt'  # task progress file path
        task_rst = f'{task_dir}/result.txt'  # task result file path
        # Create the task directory in-process: portable (no POSIX shell
        # needed) and raises on failure instead of silently continuing.
        os.makedirs(task_dir, exist_ok=True)
        params = {
            'task_id': task_id,
            'task_inp': task_inp,
            'task_prgs': task_prgs,
            'task_rst': task_rst,
        }
        submit_task(workers=workers, params=params)  # queue the task
        time.sleep(0.01)  # wait 10 ms before submitting the next task
注意代码中的多进程服务以异步方式处理用户请求:用户提交任务后获得 task_id,主进程不会阻塞;用户根据 task_id 追踪自己的任务进度(task_prgs)和结果(task_rst)。
其中,调度方式由子进程的忙碌程度决定:选取最空闲的子进程(即 queue.qsize() + working.value 最小者)来处理用户请求。
def submit_task(workers, params: dict):
    """Queue ``params`` on the least-loaded worker and return that worker.

    A worker's load is ``worker.queue.qsize() + worker.working.value``.
    When several workers tie, the first one in ``workers`` is chosen.

    NOTE(review): if ``queue`` is a multiprocessing.Queue, qsize() is only
    approximate and raises NotImplementedError on macOS — confirm the type.

    Args:
        workers: sequence of worker objects exposing ``wid``, ``queue`` and
            ``working``.
        params: task description put on the chosen worker's queue.

    Returns:
        The worker the task was assigned to.

    Raises:
        ValueError: if ``workers`` is empty.
    """
    if not workers:
        raise ValueError("submit_task() requires at least one worker")
    # Pick the worker with the fewest queued + in-progress tasks.
    min_task_worker = min(workers, key=lambda worker: worker.queue.qsize() + worker.working.value)
    min_task_worker.queue.put(params)  # hand the task to that worker's queue
    print(f'assign the task to worker-{min_task_worker.wid}')
    return min_task_worker
3.3 基于Flask搭建http访问接口
我们的后端代码如下。例如服务器 IP 是 100.100.100.123,端口使用 5002,则构建了以下 HTTP 访问接口:
HTTP 地址的一般格式: 【http://IP地址:端口/路由】
- http://100.100.100.123:5002/ 主页
- http://100.100.100.123:5002/lotto 提交任务 【输入:用户幸运数字(8位数字字符串或空字符串),输出:task_id】
- http://100.100.100.123:5002/get_worker_state 子进程负载状态 【输入:无,输出:各子进程的负载状态】
- http://100.100.100.123:5002/get_task_prgs 任务完成进度 【输入:task_id,输出:任务进度】
- http://100.100.100.123:5002/get_task_rst 任务结果 【输入:task_id,输出:任务结果】
from flask import Flask, jsonify,render_template,request
from generator import LOTTO_GENERATOR
from workers import *
import datetime
import json


def get_now_time():
    """Return the current local time as a 'YYYY-MM-DD HH:MM:SS' string."""
    now = datetime.datetime.now()
    return now.strftime('%Y-%m-%d %H:%M:%S')


def task_log(text, log_file="TASK.LOG"):
    """Append ``text`` as one line to ``log_file`` (TASK.LOG by default)."""
    with open(log_file, 'a+') as log:
        print(text, file=log)


app = Flask(__name__)

# Root directory for per-user task files: USER_DIR/<user_id>/<task_id>/.
USER_DIR = "user_data"
# In-memory registry: task_id -> {'task_dir', 'task_prgs', 'task_rst'}.
# It is a plain dict, so every entry is lost when the process restarts.
TASK_MAP = {}

# Home page
@app.route('/')
def index():
    """Render the 'index.html' template as the home page."""
    return render_template('index.html')


def _is_valid_luck_num(value):
    """Return True if ``value`` is '' or a string of exactly 8 ASCII digits."""
    return isinstance(value, str) and (
        value == "" or (len(value) == 8 and all(c in "0123456789" for c in value))
    )


@app.route('/lotto', methods=['POST'])
def predict():
    """Create a number-generation task for the caller and queue it on a worker.

    Expects a JSON body ``{"luck_num": "<8 digits>" | ""}``. A missing
    ``luck_num`` is treated as "". The task's directory, an initial progress
    file ``[0.0, 'start']`` and a TASK_MAP entry are created before the task
    is submitted.

    Returns:
        The new task_id as plain text, or a JSON error with HTTP 400 when
        ``luck_num`` is not "" or 8 digits.
    """
    # Client information; the caller's IP doubles as the user id.
    ip = request.remote_addr
    # silent=True: a missing or non-JSON body yields None instead of raising.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    task_inp = data.get('luck_num', "")  # 8-digit string or empty string
    # luck_num comes straight from the client: reject anything outside the
    # documented contract instead of forwarding it to a worker.
    if not _is_valid_luck_num(task_inp):
        return jsonify({'error': 'luck_num must be an empty string or 8 digits'}), 400

    task_id = ip + "_" + generate_random_string(20)
    user_id = ip
    task_dir = "%s/%s/%s" % (USER_DIR, user_id, task_id)
    task_prgs = f'{task_dir}/progress.txt'  # task progress file path
    task_rst = f'{task_dir}/result.txt'  # task result file path

    task_log(f"TIME:{get_now_time()}")
    task_log(f"TASK_ID:{task_id}")
    task_log("")

    # Seed the progress file so /get_task_prgs can be polled immediately.
    os.makedirs(task_dir, exist_ok=True)
    with open(task_prgs, 'wt') as f:
        json.dump([0.0, 'start'], f, indent=4)
    # NOTE(review): TASK_MAP entries are never removed, so it grows for the
    # lifetime of the process.
    TASK_MAP[task_id] = {
        'task_dir': task_dir,
        'task_prgs': task_prgs,
        'task_rst': task_rst,
    }

    params = {
        'task_id': task_id,
        'task_inp': task_inp,
        'task_prgs': task_prgs,
        'task_rst': task_rst,
    }
    submit_task(workers=workers, params=params)  # queue on the least-loaded worker
    return task_id


# Worker (engine) state
@app.route('/get_worker_state', methods=['GET'])
def get_worker_state():
    """Report every worker's current load; takes no input.

    Returns a JSON object mapping ``worker.wid`` to
    ``worker.queue.qsize() + worker.working.value``.
    """
    # `workers` is the module-level list created in the __main__ section.
    return {worker.wid: worker.queue.qsize() + worker.working.value for worker in workers}


# Task progress
@app.route('/get_task_prgs', methods=['POST'])
def get_task_prgs():
    """Return a task's progress record as read from its progress file.

    Expects a JSON body ``{"task_id": ...}``.

    Returns:
        The JSON content of the task's progress file (``[progress, status]``).
        For an unknown or missing task_id it returns ``[-1, 'no such task id']``.
        If the file cannot be read or parsed right now, it returns
        ``[0.0, 'pending']``: a non-terminal value, so callers that poll until
        progress is 1 or -1 simply try again.
    """
    # silent=True: a missing or non-JSON body yields None instead of raising.
    data = request.get_json(silent=True)
    task_id = data.get('task_id') if isinstance(data, dict) else None
    task = TASK_MAP.get(task_id) if isinstance(task_id, str) else None
    if task is None:
        return [-1, 'no such task id']
    try:
        with open(task['task_prgs'], 'rt') as f:
            return json.load(f)
    except (OSError, ValueError):
        # NOTE(review): presumably the worker is rewriting the file at this
        # moment (a non-atomic write can leave it empty or truncated) —
        # confirm how workers write progress.
        return [0.0, 'pending']


# Task result
@app.route('/get_task_rst', methods=['POST'])
def get_task_rst():
    """Return a finished task's result as read from its result file.

    Expects a JSON body ``{"task_id": ...}``.

    Returns:
        The JSON content of the task's result file. An unknown or missing
        task_id yields ``{}`` with HTTP 200. If the result file does not
        exist yet or cannot be parsed, it returns ``{}`` with HTTP 404
        instead of an unhandled server error.
    """
    # silent=True: a missing or non-JSON body yields None instead of raising.
    data = request.get_json(silent=True)
    task_id = data.get('task_id') if isinstance(data, dict) else None
    task = TASK_MAP.get(task_id) if isinstance(task_id, str) else None
    if task is None:
        return {}
    try:
        with open(task['task_rst'], 'rt') as f:
            return json.load(f)
    except (OSError, ValueError):
        return {}, 404


if __name__ == '__main__':
    # Number of worker processes; one LOTTO_GENERATOR backend per worker.
    nj = 4
    backends = [LOTTO_GENERATOR() for _ in range(nj)]
    # NOTE(review): `workers` only exists when this file is run directly, so
    # the routes that use it fail if the app is served by importing the module.
    workers = init_workers(nj=nj, backends=backends)
    app.run(host='0.0.0.0', port=5002)
这样后端就搭建起来了,共有 4 个 ONNX 模型实例(每个子进程一个)在后台监听。
3.4 python客户端测试
import requests
import time
import json# 定义服务端地址
# Base URL of the backend — replace with your own server's IP and port.
server_url = "http://110.110.123:5002"
# JSON content-type header (requests already sets it when json= is used).
headers = {"Content-Type": "application/json"}
def check_worker_status(timeout=10):
    """Return True if the server reports at least one idle worker.

    Sends ``GET {server_url}/get_worker_state``; a worker counts as idle
    when its reported load is 0.

    Args:
        timeout: seconds to wait for the server. Without a timeout, requests
            can block indefinitely on an unresponsive server.

    Returns:
        True if an idle worker exists. False if none is idle or the response
        is not HTTP 200.
    """
    response = requests.get(f'{server_url}/get_worker_state', timeout=timeout)
    if response.status_code != 200:
        print("Failed to get worker status.")
        return False
    worker_status = response.json()
    idle_workers = [wid for wid, status in worker_status.items() if status == 0]
    if idle_workers:
        print("Idle workers available:", idle_workers)
        return True
    print("No idle workers available.")
    return False
def submit_task(json_data, timeout=10):
    """Submit a lotto task and return its task_id, or None on failure.

    Submission is skipped when check_worker_status() reports no idle worker.
    NOTE(review): the server already queues tasks on the least-loaded
    worker, so this idle check is stricter than needed. It is also racy,
    because a worker can become busy between the check and the POST.

    Args:
        json_data: request body, e.g. ``{'luck_num': ""}``.
        timeout: seconds to wait for the /lotto response.
    """
    if not check_worker_status():
        print("No idle workers available. Task submission failed.")
        return None
    response = requests.post(f'{server_url}/lotto', json=json_data, timeout=timeout)
    if response.status_code != 200:
        print("Failed to submit task.")
        return None
    task_id = response.text  # the server replies with the bare task_id
    print(f"Task submitted successfully. Task ID: {task_id}")
    return task_id
def poll_task_progress(task_id, interval=3, timeout=10):
    """Poll a task's progress until it completes or fails.

    Sends ``POST {server_url}/get_task_prgs`` repeatedly. The server answers
    ``[progress, status]``, where progress 1 means done and -1 means failed
    (it returns -1 for an unknown task id). There is no overall deadline.

    Args:
        task_id: id returned by submit_task().
        interval: seconds to sleep between polls (default 3).
        timeout: per-request timeout in seconds, so one stalled request
            cannot hang the loop forever.

    Returns:
        True when the task completed. False when it failed or the server
        replied with a non-200 status.
    """
    json_data = {'task_id': task_id}
    while True:
        response = requests.post(f'{server_url}/get_task_prgs', json=json_data, timeout=timeout)
        if response.status_code != 200:
            print("Failed to get task progress.")
            return False
        progress, text = response.json()
        print(f"Progress: {progress}, Status: {text}")
        if progress == 1:
            print("Task completed successfully.")
            return True
        if progress == -1:
            print(f"Task failed: {text}")
            return False
        time.sleep(interval)  # wait before the next poll
def get_task_result(task_id, timeout=10):
    """Fetch and print a finished task's JSON result.

    Args:
        task_id: id returned by submit_task().
        timeout: seconds to wait for the /get_task_rst response.

    Returns:
        The decoded JSON result, or None if the server replied non-200.
    """
    json_data = {'task_id': task_id}
    response = requests.post(f'{server_url}/get_task_rst', json=json_data, timeout=timeout)
    if response.status_code != 200:
        print("Failed to get task result.")
        return None
    result = response.json()
    print("Task result:", result)
    return result
def main():
    """Submit one lotto task, wait for it to finish, then fetch its result."""
    # An empty luck_num means no lucky number is used; pass e.g.
    # {'luck_num': "12345678"} to supply an 8-digit lucky number instead.
    json_data = {'luck_num': ""}

    task_id = submit_task(json_data)
    if not task_id:
        return

    # Only fetch the result once polling reports successful completion.
    if not poll_task_progress(task_id):
        return
    get_task_result(task_id)


if __name__ == "__main__":
    main()
访问成功
4. 前端工作
4.1 JavaScript 访问 http 函数
JavaScript 调用 http端口如下:
<script>
/* Restore the generate button after a task finishes or fails. */
function resetGenerateButton() {
  var button = document.querySelector("button");
  button.disabled = false;
  button.innerText = "生成";
}

/* Submit a task: POST /lotto with the optional lucky number.
   On success the response body is the task_id, which is then polled. */
function submitTask() {
  var button = document.querySelector("button");
  button.disabled = true;
  button.innerText = "正在生成...";
  var useLuckyNumber = document.getElementById("use_lucky_number").checked;
  var luckInput = document.getElementById("luck_input");
  var luckNum = useLuckyNumber ? luckInput.value : "";
  var xhr = new XMLHttpRequest();
  xhr.open("POST", "/lotto", true);
  xhr.setRequestHeader("Content-Type", "application/json;charset=UTF-8");
  xhr.onreadystatechange = function () {
    if (xhr.readyState != 4) return;
    if (xhr.status == 200) {
      checkProgress(xhr.responseText);
    } else {
      resetGenerateButton();
      alert("任务提交失败,请重试。");
    }
  };
  xhr.send(JSON.stringify({luck_num: luckNum}));
}

/* Poll POST /get_task_prgs every 3 s. The server replies [progress, status]:
   1 = done (fetch the result), -1 = failed, anything else = keep polling. */
function checkProgress(taskId) {
  var xhr = new XMLHttpRequest();
  xhr.open("POST", "/get_task_prgs", true);
  xhr.setRequestHeader("Content-Type", "application/json;charset=UTF-8");
  xhr.onreadystatechange = function () {
    if (xhr.readyState != 4) return;
    if (xhr.status != 200) {
      // Without this branch a failed poll would silently stop polling
      // and leave the button disabled for good.
      resetGenerateButton();
      alert("查询任务进度失败,请重试。");
      return;
    }
    var response = JSON.parse(xhr.responseText);
    var progress = response[0];
    var status = response[1];
    if (progress == 1) {
      getResult(taskId);
    } else if (progress == -1) {
      resetGenerateButton();
      alert("任务失败: " + status);
    } else {
      setTimeout(function () { checkProgress(taskId); }, 3000);
    }
  };
  xhr.send(JSON.stringify({task_id: taskId}));
}

/* Fetch the finished task's result via POST /get_task_rst and render it. */
function getResult(taskId) {
  var xhr = new XMLHttpRequest();
  xhr.open("POST", "/get_task_rst", true);
  xhr.setRequestHeader("Content-Type", "application/json;charset=UTF-8");
  xhr.onreadystatechange = function () {
    if (xhr.readyState != 4) return;
    try {
      if (xhr.status == 200) {
        displayResult(JSON.parse(xhr.responseText));
      } else {
        alert("获取任务结果失败,请重试。");
      }
    } finally {
      // Re-enable the button whatever happened, so the user can retry.
      resetGenerateButton();
    }
  };
  xhr.send(JSON.stringify({task_id: taskId}));
}

/* Render each generated set as one row: front-ball numbers, then back-ball numbers.
   Expects {front_numbers: [[...], ...], back_numbers: [[...], ...]};
   missing fields (e.g. an empty {} result) render nothing instead of throwing. */
function displayResult(response) {
  var frontNumbers = response.front_numbers || [];
  var backNumbers = response.back_numbers || [];
  var resultContainer = document.getElementById("result");
  resultContainer.innerHTML = "";  // clear previous results
  for (var i = 0; i < frontNumbers.length; i++) {
    var lotterySet = document.createElement("div");
    lotterySet.className = "lottery-set";
    frontNumbers[i].forEach(function (number) {
      var numberBall = document.createElement("div");
      numberBall.className = "number-ball front-ball";
      numberBall.innerText = number;
      lotterySet.appendChild(numberBall);
    });
    (backNumbers[i] || []).forEach(function (number) {
      var numberBall = document.createElement("div");
      numberBall.className = "number-ball back-ball";
      numberBall.innerText = number;
      lotterySet.appendChild(numberBall);
    });
    resultContainer.appendChild(lotterySet);
  }
}
</script>
4.2 制作网页index.html
注意到Flask提供了网页渲染功能,这样我们可以设计我们的主页
@app.route('/')
def index():
    """Serve the home page by rendering the 'index.html' template."""
    page = render_template('index.html')
    return page
把上述 JS 脚本放入 index.html 就可以访问后端服务了。具体的 HTML 界面代码量较大,这里不予展示,感兴趣的同学可以参照上述 Python 客户端的访问逻辑,使用 GPT 帮你编写 index.html。手机端访问效果如下:
5. 最后
上述是个人搭建自己网站部署AI应用的简单过程,完整源码后期整理上传,欢迎大家留言关注~