1. DCT-Net Deployment
DCT-Net comes from Alibaba's ModelScope community, which hosts a rich collection of open-source style-transfer models.
DCT-Net GitHub link:
git clone https://github.com/menyifang/DCT-Net.git
cd DCT-Net
python run_sdk.py
run_sdk.py downloads the models for each of the available styles.
As shown in the figure below, each folder represents one style and contains two models, cartoon_bg.pb and cartoon_h.pb: bg is the full-image style model and h is the face style model.
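To check that the download finished, you can list the style folders. This is a minimal sketch; it assumes run_sdk.py placed them under ./damo/, which is also the directory export_model.py below reads from.

import os

pb_dir = './damo/'
for style_dir in sorted(os.listdir(pb_dir)):
    if not style_dir.startswith('cv_unet'):
        continue
    # each style folder should contain cartoon_bg.pb (full image) and cartoon_h.pb (face)
    models = [f for f in os.listdir(os.path.join(pb_dir, style_dir)) if f.endswith('.pb')]
    print(style_dir, '->', models)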
Model Conversion
export_model.py
Conversion approach: the pb model cannot be converted end to end. The graph has to be cut at intermediate nodes, because some of the leading and trailing nodes are not supported by RKNN or NCNN and must run on the CPU instead (see the node-dump sketch after the list below):
pb->tflite->rknn
pb->onnx->ncnn
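To pick those intermediate nodes (export_model.py below cuts the graph at strided_slice_1/strided_slice_4 for TFLite and at strided_slice_1/add_1 for ONNX), you can first dump the node names of the frozen graph. A minimal sketch, assuming TensorFlow 2.x with the tf.compat.v1 API as used in the script:

import tensorflow as tf

pb_path = './damo/cv_unet_person-image-cartoon_compound-models/cartoon_bg.pb'
graph_def = tf.compat.v1.GraphDef()
with tf.io.gfile.GFile(pb_path, 'rb') as f:
    graph_def.ParseFromString(f.read())

# print every node so candidates such as strided_slice_1 / strided_slice_4 / add_1 can be located
for node in graph_def.node:
    print(node.name, node.op)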
"""
@File : export_model.py
@Author :
@Date : 2023/12/13
@Desc :
"""
import os
import shutil
import tensorflow as tf
import cv2
import tf2onnx
import onnx
import time
import onnxruntime
import subprocess
import numpy as np

# python -m tf2onnx.convert --graphdef .\damo\cv_unet_person-image-cartoon_compound-models\cartoon_bg.pb \
#     --output .\damo\cv_unet_person-image-cartoon_compound-models\cartoon_bg.onnx \
#     --inputs input_image:0 --outputs output_image:0
def convert_pb2tflite(model_path, input_shape, model_dir, type_name):
    converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
        graph_def_file=model_path + '.pb',
        input_arrays=["strided_slice_1"],
        output_arrays=["strided_slice_4"],
        input_shapes={'strided_slice_1': input_shape})
    # converter = tf.lite.TFLiteConverter.from_frozen_graph(model_path + '.pb', input_arrays=["input_image"], output_arrays=output_name, input_shapes={"input_image": input_shape})
    # converter.optimizations = [tf.lite.Optimize.DEFAULT]
    # converter.target_spec.supported_types = [tf.float16]
    tflite_model = converter.convert()

    bgh = model_path.split('/')[-1]
    save_path = os.path.join(model_dir, type_name + '_' + bgh + '.tflite')
    print('===> ', save_path, bgh)
    # exit()
    open(save_path, "wb").write(tflite_model)

    # sanity-check the exported tflite model
    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    input = interpreter.get_input_details()
    print('===> input: ', input)
    output = interpreter.get_output_details()
    print('===> output: ', output)


def pb2tflite_2():
    pb_dir = './damo/'
    bg_model = 'cartoon_bg'
    h_model = 'cartoon_h'
    tflite_dir = './damo/tflite_model'
    # custom input shape settings
    # bg_input_shape = [1, 720, 720, 3]
    # bg_input_shape = [1, 1920, 1080, 3]
    # bg_input_shape = [1, 2560, 1440, 3]
    bg_input_shape = [1, 1280, 720, 3]
    head_input_shape = [1, 288, 288, 3]

    tflite_dir = tflite_dir + str(bg_input_shape[1]) + "x" + str(bg_input_shape[2])
    if not os.path.exists(tflite_dir):
        os.makedirs(tflite_dir)

    for i in os.listdir(pb_dir):
        if not i.startswith('cv_unet'):
            continue
        model_dir = os.path.join(pb_dir, i)
        bg_path = os.path.join(model_dir, bg_model)
        h_path = os.path.join(model_dir, h_model)
        print('============', i)
        type_name = i.split('-')[-2]
        type_name = type_name.split('_')[0]
        print('===> ', i, type_name, bg_path)
        convert_pb2tflite(bg_path, bg_input_shape, tflite_dir, type_name)
        # convert_pb2tflite(h_path, head_input_shape, tflite_dir, type_name)
        # exit(0)


def pb2onnx(bg_path, bg_input_shape, onnx_dir, type_name):
    # build the command line to run
    pb_path = bg_path + '.pb'
    onnx_name = type_name + '_' + bg_path.split('/')[-1] + '.onnx'
    onnx_path = os.path.join(onnx_dir, onnx_name)
    command = "python -m tf2onnx.convert --graphdef {pb_path} --output {onnx_path} --inputs strided_slice_1:0 --outputs add_1:0 --inputs-as-nchw strided_slice_1:0 --outputs-as-nchw add_1:0"
    # insert the variables into the command with string formatting
    formatted_command = command.format(pb_path=pb_path, onnx_path=onnx_path)
    # run the command with subprocess.Popen
    p = subprocess.Popen(formatted_command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    # collect stdout and stderr
    output, error = p.communicate()
    # decode the binary output and print it
    print(output.decode())
    time.sleep(5)

    # onnx-sim
    onnx_sim, _ = os.path.splitext(onnx_path)
    onnx_sim_path = onnx_sim + '-sim.onnx'
    n = bg_input_shape[0]
    c = bg_input_shape[3]
    h = bg_input_shape[1]
    w = bg_input_shape[2]
    print(bg_input_shape, onnx_path, onnx_sim_path)
    command2 = 'python -m onnxsim {onnx_path} {onnx_sim_path} --overwrite-input-shape {n},{c},{h},{w}'
    formatted_command = command2.format(onnx_path=onnx_path, onnx_sim_path=onnx_sim_path, n=n, c=c, h=h, w=w)
    p = subprocess.Popen(formatted_command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    output, error = p.communicate()
    print(output.decode())


def tfpb2onnx():
    pb_dir = './damo/'
    bg_model = 'cartoon_bg'
    h_model = 'cartoon_h'
    onnx_dir = './damo/onnx_model'
    # bg_input_shape = [1, 1024, 1024, 3]
    # bg_input_shape = [1, 1920, 1080, 3]
    bg_input_shape = [1, 2560, 2560, 3]
    head_input_shape = [1, 288, 288, 3]

    onnx_dir = onnx_dir + str(bg_input_shape[1]) + "x" + str(bg_input_shape[2])
    if not os.path.exists(onnx_dir):
        os.makedirs(onnx_dir)

    for i in os.listdir(pb_dir):
        if not i.startswith('cv_unet'):
            continue
        model_dir = os.path.join(pb_dir, i)
        bg_path = os.path.join(model_dir, bg_model)
        h_path = os.path.join(model_dir, h_model)
        print(f'============>bg_path: {bg_path}, h_path: {h_path}')
        type_name = i.split('-')[-2].split('_')[0]
        print(f'type_name: {type_name}, i: {i}')
        # exit(0)
        pb2onnx(bg_path, bg_input_shape, onnx_dir, type_name)
        pb2onnx(h_path, head_input_shape, onnx_dir, type_name)
        # exit()


# tf2onnx
# python -m tf2onnx.convert --graphdef damo/cv_unet_person-image-cartoon_compound-models/cartoon_bg.pb --output damo/cv_unet_person-image-cartoon_compound-models/cartoon_bg.onnx \
#     --inputs strided_slice_1:0 --outputs add_1:0 --inputs-as-nchw strided_slice_1:0
# simplify onnx
# python -m onnxsim cartoon_bg.onnx cartoon_bg-sim.onnx --overwrite-input-shape 1,3,1024,1024


def onnx2ncnn():
    onnx_dir = './damo/onnx_model2560x2560'
    for i in os.listdir(onnx_dir):
        if not i.endswith('.onnx'):
            continue
        if 'h-sim' in i:
            continue
        onnx_path = os.path.join(onnx_dir, i)
        onnx_name, ext = os.path.splitext(onnx_path)
        # print(onnx_name)
        ncnn_param = onnx_name + '.param'
        ncnn_bin = onnx_name + '.bin'
        print(f'onnx_name: {onnx_name}, ncnn_param: {ncnn_param}, ncnn_bin {ncnn_bin}')
        command3 = "./ncnn-20231027-ubuntu-2204/bin/onnx2ncnn {onnx_path} {ncnn_param} {ncnn_bin}"
        formatted_command = command3.format(onnx_path=onnx_path, ncnn_param=ncnn_param, ncnn_bin=ncnn_bin)
        # run onnx2ncnn with subprocess.Popen and print its output
        p = subprocess.Popen(formatted_command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        output, error = p.communicate()
        print(output.decode())
        # exit(0)


def ncnn_optimize():
    onnx_dir = './damo/onnx_model1024x1024'
    for i in os.listdir(onnx_dir):
        if not i.endswith('.param'):
            continue
        param_path = os.path.join(onnx_dir, i)
        bin_path = param_path.replace('.param', '.bin')
        opt_param_path = param_path.replace('.param', '-opt.param')
        opt_bin_path = bin_path.replace('.bin', '-opt.bin')
        print(f'param_path: {param_path}, bin_path: {bin_path}, opt_param_path {opt_param_path}, opt_bin_path {opt_bin_path}')
        command3 = "./ncnn-20231027-ubuntu-2204/bin/ncnnoptimize {param_path} {bin_path} {opt_param_path} {opt_bin_path} 1"
        formatted_command = command3.format(param_path=param_path, bin_path=bin_path, opt_param_path=opt_param_path, opt_bin_path=opt_bin_path)
        # run ncnnoptimize with subprocess.Popen and print its output
        p = subprocess.Popen(formatted_command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        output, error = p.communicate()
        print(output.decode())


if __name__ == '__main__':
    pb2tflite_2()
    # tfpb2onnx()
    # onnx2ncnn()
    # ncnn_optimize()
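export_model.py imports onnxruntime but never calls it; before running onnx2ncnn it can be useful to spot-check the simplified ONNX model. The sketch below is only an outline: the file name just follows the type_name + '_' + model + '-sim.onnx' naming scheme of the script above, and the [-1, 1] input normalization mirrors the commented-out preprocessing in tflite2rknn.py, so both are assumptions rather than verified values.

import cv2
import numpy as np
import onnxruntime

# example path built from the naming scheme above (not a guaranteed file name)
onnx_path = './damo/onnx_model2560x2560/cartoon_cartoon_bg-sim.onnx'
sess = onnxruntime.InferenceSession(onnx_path)
input_name = sess.get_inputs()[0].name
_, c, h, w = sess.get_inputs()[0].shape                          # NCHW after --inputs-as-nchw

img = cv2.imread('./16.png')
img = cv2.resize(img, (w, h)).astype('float32') / 127.5 - 1.0    # assumed [-1, 1] input range
img = np.transpose(img, (2, 0, 1))[None]                         # HWC -> NCHW

out = sess.run(None, {input_name: img})[0]
out = ((out.clip(-1, 1) + 1) * 127.5).astype('uint8')            # same mapping as post_process below
cv2.imwrite('onnx_check.jpg', np.squeeze(out).transpose(1, 2, 0))
print('output shape:', out.shape)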
tflite2rknn.py
import os
import time
import shutil
import numpy as np
import cv2
from rknn.api import RKNN


def show_outputs(outputs):
    output = outputs[0][0]
    index = sorted(range(len(output)), key=lambda k: output[k], reverse=True)
    fp = open('./labels.txt', 'r')
    labels = fp.readlines()
    top5_str = 'mobilenet_v1\n-----TOP 5-----\n'
    for i in range(5):
        value = output[index[i]]
        if value > 0:
            topi = '[{:>4d}] score:{:.6f} class:"{}"\n'.format(index[i], value, labels[index[i]].strip().split(':')[-1])
        else:
            topi = '[ -1]: 0.0\n'
        top5_str += topi
    print(top5_str.strip())


def dequantize(outputs, scale, zp):
    outputs[0] = (outputs[0] - zp) * scale
    return outputs


def letterbox(im, new_shape=(640, 640), color=(0, 0, 0)):
    # Resize and pad image while meeting stride-multiple constraints
    shape = im.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return im, ratio, (dw, dh)


def post_process(rknn_result):
    # map the [-1, 1] model output back to an 8-bit image
    rknn_result = rknn_result.clip(-0.999999, 0.999999)
    rknn_result = (rknn_result + 1) * 127.5
    cartoon_img = rknn_result.astype('uint8')
    # onnx_result = cv2.cvtColor(onnx_result, cv2.COLOR_RGB2BGR)
    # cv2.imwrite('8_anime.jpg', rknn_result)
    return cartoon_img


def export_rknn(tflite_model_path, QUANTIZE_ON, DATASET):
    # Create RKNN object
    rknn = RKNN(verbose=True)

    # Pre-process config
    print('--> Config model')
    # rknn.config(mean_values=[128, 128, 128], std_values=[128, 128, 128], target_platform='rk3566')
    rknn.config(target_platform='rk3588')
    print('done')

    # Load model (from https://www.tensorflow.org/lite/examples/image_classification/overview?hl=zh-cn)
    print('--> Loading model')
    ret = rknn.load_tflite(model=tflite_model_path)
    if ret != 0:
        print('Load model failed!')
        exit(ret)
    print('done')

    # Build model
    print('--> Building model')
    ret = rknn.build(do_quantization=QUANTIZE_ON, dataset=DATASET)
    if ret != 0:
        print('Build model failed!')
        exit(ret)
    print('done')

    # Export rknn model
    print('--> Export rknn model')
    ret = rknn.export_rknn(tflite_model_path.replace('.tflite', '.rknn'))
    if ret != 0:
        print('Export rknn model failed!')
        exit(ret)
    print('done')

    # Init runtime environment
    print('--> Init runtime environment')
    ret = rknn.init_runtime()
    if ret != 0:
        print('Init runtime environment failed!')
        exit(ret)
    print('done')

    # Set inputs
    IMG_PATH = './16.png'
    IMG_SIZE = (288, 288)  # w, h
    img = cv2.imread(IMG_PATH)
    # img, ratio, (dw, dh) = letterbox(img, new_shape=(IMG_SIZE[0], IMG_SIZE[1]))
    # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, IMG_SIZE)
    # img = img.astype('float32')
    # img = img / 127.5 - 1
    img = np.expand_dims(img, 0)
    print(f'===> input shape: {img.shape}')

    # Inference
    print('--> Running model')
    outputs = rknn.inference(inputs=[img], data_format=['nhwc'])
    print(f'===> output shape: {outputs[0].shape}')
    # np.save('./tflite_mobilenet_v1_qat_0.npy', outputs[0])
    # show_outputs(dequantize(outputs, scale=0.00390625, zp=0))
    cartoon_img = post_process(outputs[0])
    # save the stylized test image next to the tflite model
    cv2.imwrite(tflite_model_path.replace('.tflite', '.jpg'), cartoon_img)
    print('done')

    rknn.release()


if __name__ == '__main__':
    model_dir = './StyleTransfer/DCT-Net-main/damo/tflite_head'
    QUANTIZE_ON = False
    DATASET = './dataset.txt'
    for i in os.listdir(model_dir):
        if not i.endswith('.tflite'):
            continue
        model_path = os.path.join(model_dir, i)
        print(f'model path: {model_path}')
        export_rknn(model_path, QUANTIZE_ON, DATASET)
RKNN and NCNN inference code
GitHub
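The full RKNN/NCNN deployment code is in the GitHub repo linked above. As a rough board-side outline only, a sketch of RKNN inference is shown below; it assumes rknn_toolkit_lite2 is installed on the RK3588, reuses the 288x288 head-model preprocessing and the post-processing from tflite2rknn.py, and the .rknn file name is just an example.

import cv2
import numpy as np
from rknnlite.api import RKNNLite

rknn_lite = RKNNLite()
rknn_lite.load_rknn('./cartoon_cartoon_h.rknn')      # example name from the export step
rknn_lite.init_runtime()

img = cv2.imread('./16.png')
img = cv2.resize(img, (288, 288))                    # head model input size
outputs = rknn_lite.inference(inputs=[np.expand_dims(img, 0)])

# map the [-1, 1] output back to [0, 255]; transpose first if the output turns out to be NCHW
cartoon = (outputs[0].clip(-0.999999, 0.999999) + 1) * 127.5
cv2.imwrite('cartoon_out.jpg', np.squeeze(cartoon).astype('uint8'))
rknn_lite.release()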