风格迁移开发记录(DCT-Net)

1.DCT-Net部署

阿里旗下的 modelscope社区,丰富的开源风格迁移算法模型
image.png
DCT-Net GitHub链接

git clone https://github.com/menyifang/DCT-Net.git
cd DCT-Netpython run_sdk.py下载不同风格的模型

如下图每个文件夹代表一种风格,有cartoon_bg.pb, cartoon_h.pb两个模型,bg是全图风格模型,h是脸部风格模型:
image.png

模型转换

export_model.py

模型转换方式,不能将pb模型全部转换,要取中间节点,有些前后节点rknn或ncnn不支持需放在cpu处理
pb->tflite->rknn
pb->onnx->ncnn

"""
@File   : export_model.py
@Author : 
@Date   : 2023/12/13
@Desc   : 
"""
import os
import shutil
import tensorflow as tf
import cv2
import tf2onnx
import onnx
import time
import onnxruntime
import subprocess
import numpy as np# python -m tf2onnx.convert --graphdef .\damo\cv_unet_person-image-cartoon_compound-models\cartoon_bg.pb --output .\damo\cv_unet_person-image-cartoon_compound-models\cartoon_bg.onnx --
# inputs input_image:0 --outputs output_image:0 
def convert_pb2tflite(model_path, input_shape, model_dir, type_name):converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(graph_def_file = model_path + '.pb',input_arrays = ["strided_slice_1"],output_arrays = ["strided_slice_4"],input_shapes = {'strided_slice_1' : input_shape})# converter = tf.lite.TFLiteConverter.from_frozen_graph(model_path + '.pb', input_arrays=["input_image"], output_arrays=output_name, input_shapes={"input_image": input_shape})# converter.optimizations = [tf.lite.Optimize.DEFAULT]# converter.target_spec.supported_types = [tf.float16]tflite_model = converter.convert()bgh = model_path.split('/')[-1]save_path = os.path.join(model_dir, type_name + '_' + bgh + '.tflite')print('===> ', save_path, bgh)# exit()open(save_path, "wb").write(tflite_model)interpreter = tf.lite.Interpreter(model_content=tflite_model)input = interpreter.get_input_details()print('===> input: ', input)output = interpreter.get_output_details()print('===> output: ', output)def pb2tflite_2():pb_dir = './damo/'bg_model = 'cartoon_bg'h_model = 'cartoon_h'tflite_dir = './damo/tflite_model'# 自定义设置输入shape# bg_input_shape = [1, 720, 720, 3]# bg_input_shape = [1, 1920, 1080, 3]# bg_input_shape = [1, 2560, 1440, 3]bg_input_shape = [1, 1280, 720, 3]head_input_shape = [1, 288, 288, 3]tflite_dir = tflite_dir + str(bg_input_shape[1]) + "x" + str(bg_input_shape[2])if not os.path.exists(tflite_dir):os.makedirs(tflite_dir)for i in os.listdir(pb_dir):if not i.startswith('cv_unet'):continuemodel_dir = os.path.join(pb_dir, i)bg_path = os.path.join(model_dir, bg_model)h_path = os.path.join(model_dir, h_model)print('============', i)type_name = i.split('-')[-2]type_name = type_name.split('_')[0]print('===> ', i, type_name, bg_path)convert_pb2tflite(bg_path, bg_input_shape, tflite_dir, type_name)# convert_pb2tflite(h_path, head_input_shape, tflite_dir, type_name)# exit(0)def pb2onnx(bg_path, bg_input_shape, onnx_dir, type_name):# 定义要执行的命令行命令pb_path = bg_path + '.pb'onnx_name = type_name + '_' + bg_path.split('/')[-1] + '.onnx'onnx_path =  os.path.join(onnx_dir, onnx_name)  command = "python -m tf2onnx.convert --graphdef {pb_path} --output {onnx_path} --inputs strided_slice_1:0 --outputs add_1:0 --inputs-as-nchw strided_slice_1:0 --outputs-as-nchw add_1:0"# 使用字符串格式化将变量插入命令中formatted_command = command.format(pb_path=pb_path, onnx_path=onnx_path)# 使用 subprocess.Popen 执行命令p = subprocess.Popen(formatted_command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)# 获取命令行输出和错误信息output, error = p.communicate()# 将二进制输出转换为字符串并打印出来print(output.decode())time.sleep(5)# onnx-simonnx_sim, _ = os.path.splitext(onnx_path)onnx_sim_path = onnx_sim + '-sim.onnx'n = bg_input_shape[0]c = bg_input_shape[3]h = bg_input_shape[1]w = bg_input_shape[2]print(bg_input_shape, onnx_path, onnx_sim_path)command2 = 'python -m onnxsim {onnx_path} {onnx_sim_path} --overwrite-input-shape {n},{c},{h},{w}'# 使用字符串格式化将变量插入命令中formatted_command = command2.format(onnx_path=onnx_path, onnx_sim_path=onnx_sim_path, bg_input_shape=bg_input_shape, n=n, c=c, h=h, w=w)# 使用 subprocess.Popen 执行命令p = subprocess.Popen(formatted_command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)# 获取命令行输出和错误信息output, error = p.communicate()# 将二进制输出转换为字符串并打印出来print(output.decode())def tfpb2onnx():pb_dir = './damo/'bg_model = 'cartoon_bg'h_model = 'cartoon_h'onnx_dir = './damo/onnx_model'# bg_input_shape = [1, 1024, 1024, 3]# bg_input_shape = [1, 1920, 1080, 3]bg_input_shape = [1, 2560, 2560, 3]head_input_shape = [1, 288, 288, 3]onnx_dir = onnx_dir + str(bg_input_shape[1]) + "x" + str(bg_input_shape[2])if not os.path.exists(onnx_dir):os.makedirs(onnx_dir)for i in os.listdir(pb_dir):if not i.startswith('cv_unet'):continuemodel_dir = os.path.join(pb_dir, i)bg_path = os.path.join(model_dir, bg_model)h_path = os.path.join(model_dir, h_model)print(f'============>bg_path: {bg_path}, h_path: {h_path}')type_name = i.split('-')[-2].split('_')[0]# type_name = type_name.split('_')[0]print(f'type_name: {type_name}, i: {i}')# exit(0)pb2onnx(bg_path, bg_input_shape, onnx_dir, type_name)pb2onnx(h_path, head_input_shape, onnx_dir, type_name)# exit()# tf2onnx# python -m tf2onnx.convert --graphdef damo/cv_unet_person-image-cartoon_compound-models/cartoon_bg.pb --output damo/cv_unet_person-image-cartoon_compound-models/cartoon_bg.onnx  # --inputs strided_slice_1:0 --outputs add_1:0 --inputs-as-nchw strided_slice_1:0# simplifier onnx# python -m onnxsim cartoon_bg.onnx cartoon_bg-sim.onnx --overwrite-input-shape 1,3,1024,1024def onnx2ncnn():onnx_dir = './damo/onnx_model2560x2560'for i in os.listdir(onnx_dir):if not i.endswith('.onnx'):continueif 'h-sim' in i:continueonnx_path = os.path.join(onnx_dir, i)onnx_name, ext = os.path.splitext(onnx_path)# print(onnx_name)ncnn_param = onnx_name + '.param'ncnn_bin = onnx_name + '.bin'print(f'onnx_name: {onnx_name}, ncnn_param: {ncnn_param}, ncnn_bin {ncnn_bin}')command3 = "./ncnn-20231027-ubuntu-2204/bin/onnx2ncnn {onnx_path} {ncnn_param} {ncnn_bin}"# 使用字符串格式化将变量插入命令中formatted_command = command3.format(onnx_path=onnx_path, ncnn_param=ncnn_param, ncnn_bin=ncnn_bin)# 使用 subprocess.Popen 执行命令p = subprocess.Popen(formatted_command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)# 获取命令行输出和错误信息output, error = p.communicate()# 将二进制输出转换为字符串并打印出来print(output.decode())# exit(0)def ncnn_optimize():onnx_dir = './damo/onnx_model1024x1024'for i in os.listdir(onnx_dir):if not i.endswith('.param'):continueparam_path = os.path.join(onnx_dir, i)bin_path = param_path.replace('.param', '.bin')opt_param_path = param_path.replace('.param', '-opt.param')opt_bin_path = bin_path.replace('.bin', '-opt.bin')print(f'param_path: {param_path}, bin_path: {bin_path}, opt_param_path {opt_param_path}, opt_bin_path {opt_bin_path}')command3 = "./ncnn-20231027-ubuntu-2204/bin/ncnnoptimize {param_path} {bin_path} {opt_param_path} {opt_bin_path} 1"# 使用字符串格式化将变量插入命令中formatted_command = command3.format(param_path=param_path, bin_path=bin_path, opt_param_path=opt_param_path, opt_bin_path=opt_bin_path)# 使用 subprocess.Popen 执行命令p = subprocess.Popen(formatted_command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)# 获取命令行输出和错误信息output, error = p.communicate()# 将二进制输出转换为字符串并打印出来print(output.decode())if __name__ == '__main__':pb2tflite_2()# tfpb2onnx()# onnx2ncnn()# ncnn_optimize()

tflite2rknn.py

import os
import time
import shutil
import numpy as np
import cv2
from rknn.api import RKNNdef show_outputs(outputs):output = outputs[0][0]index = sorted(range(len(output)), key=lambda k : output[k], reverse=True)fp = open('./labels.txt', 'r')labels = fp.readlines()top5_str = 'mobilenet_v1\n-----TOP 5-----\n'for i in range(5):value = output[index[i]]if value > 0:topi = '[{:>4d}] score:{:.6f} class:"{}"\n'.format(index[i], value, labels[index[i]].strip().split(':')[-1])else:topi = '[  -1]: 0.0\n'top5_str += topiprint(top5_str.strip())def dequantize(outputs, scale, zp):outputs[0] = (outputs[0] - zp) * scalereturn outputsdef letterbox(im, new_shape=(640, 640), color=(0, 0, 0)):# Resize and pad image while meeting stride-multiple constraintsshape = im.shape[:2]  # current shape [height, width]if isinstance(new_shape, int):new_shape = (new_shape, new_shape)# Scale ratio (new / old)r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])# Compute paddingratio = r, r  # width, height ratiosnew_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh paddingdw /= 2  # divide padding into 2 sidesdh /= 2if shape[::-1] != new_unpad:  # resizeim = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))left, right = int(round(dw - 0.1)), int(round(dw + 0.1))im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add borderreturn im, ratio, (dw, dh)def post_process(rknn_result):rknn_result = rknn_result.clip(-0.999999, 0.999999)rknn_result = (rknn_result + 1) * 127.5cartoon_img = rknn_result.astype('uint8')# onnx_result = cv2.cvtColor(onnx_result, cv2.COLOR_RGB2BGR)# cv2.imwrite('8_anime.jpg', rknn_result)return cartoon_imgdef export_rknn(tflite_model_path, QUANTIZE_ON, DATASET):# Create RKNN objectrknn = RKNN(verbose=True)# Pre-process configprint('--> Config model')# rknn.config(mean_values=[128, 128, 128], std_values=[128, 128, 128], target_platform='rk3566')rknn.config(target_platform='rk3588')print('done')# Load model (from https://www.tensorflow.org/lite/examples/image_classification/overview?hl=zh-cn)print('--> Loading model')ret = rknn.load_tflite(model=tflite_model_path)if ret != 0:print('Load model failed!')exit(ret)print('done')# Build modelprint('--> Building model')ret = rknn.build(do_quantization=QUANTIZE_ON, dataset=DATASET)if ret != 0:print('Build model failed!')exit(ret)print('done')# Export rknn modelprint('--> Export rknn model')ret = rknn.export_rknn(tflite_model_path.replace('.tflite', '.rknn'))if ret != 0:print('Export rknn model failed!')exit(ret)print('done')# Init runtime environmentprint('--> Init runtime environment')ret = rknn.init_runtime()if ret != 0:print('Init runtime environment failed!')exit(ret)print('done')# Set inputsIMG_PATH = './16.png'IMG_SIZE = (288, 288)  # w, himg = cv2.imread(IMG_PATH)# img, ratio, (dw, dh) = letterbox(img, new_shape=(IMG_SIZE[0], IMG_SIZE[1]))# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)img = cv2.resize(img, IMG_SIZE)# img = img.astype('float32')# img = img / 127.5 - 1img = np.expand_dims(img, 0)print(f'===> input shape: {img.shape}')# Inferenceprint('--> Running model')outputs = rknn.inference(inputs=[img], data_format=['nhwc'])print(f'===> output shape: {outputs[0].shape}')# np.save('./tflite_mobilenet_v1_qat_0.npy', outputs[0])# show_outputs(dequantize(outputs, scale=0.00390625, zp=0))cartoon_img = post_process(outputs[0])cv2.imwrite(model_path.replace('.tflite', '.jpg'), cartoon_img)print('done')rknn.release()if __name__ == '__main__':model_dir = './StyleTransfer/DCT-Net-main/damo/tflite_head'QUANTIZE_ON = FalseDATASET = './dataset.txt'for i in os.listdir(model_dir):if not i.endswith('.tflite'):continuemodel_path = os.path.join(model_dir, i)print(f'model path: {model_path}')export_rknn(model_path, QUANTIZE_ON, DATASET)

RKNN和NCNN推理代码

GitHub

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/875093.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C++STL详解(一)——String接口详解(上)!!!

目录 一.string类介绍 二.string类的构造赋值 2.1string类的拷贝和构造函数 2.2深拷贝 三.string类的插入 3.1push_back 3.2append 3.3操作符 3.4insert 四.string的删除 4.1pop_back 4.2erase 五.string的查找 5.1find 5.2rfind 六.string的比较 6.1compare函…

独家|二十年国货羊奶粉老品牌发力成人奶粉,瞄准低GI、特医食品

前言 中国羊奶看陕西。 作为陕西省农业产业化重点企业以及陕西省专精特新企业,成立于2004年的羊奶粉品牌雅泰乳业正在不断进行深入布局。 雅泰乳业成人粉部门销售总监于维涛近日向AgeFood表示,雅泰成人奶粉业务主要分为两部分。一部分是以雅泰牧歌、龙…

深入浅出WebRTC—Pacer

平滑发包(Pacer)是 WebRTC 实现高质量实时通信不可或缺的一部分。在视频通信中,单帧视频可能包含大量的数据,如果未经控制地立即发送,可能瞬间对网络造成巨大压力。Pacer 能够根据网络条件动态调整发送速率&#xff0c…

SpringBoot事务管理、任务调度、Mail整合。

一.Spring Boot中的事务管理 编程式事务 : 在代码中硬编码(不推荐使用):通过 TransactionTemplate 或者 TransactionManager 手动管理事务,实际应用中很少使用,用于理解Spring 事务管理。 声明式事务:在 XML 配置文件或者基于注解 Transactional(推荐使…

Java实现汉字转拼音工具类的编写与应用

前言 在处理中文数据时,经常需要将汉字转换为拼音,无论是为了搜索优化、数据分析还是提升用户体验。本文将详细介绍如何编写一个实用的Java工具类来实现这一功能,并通过一个完整的示例来展示其使用方法。我们将使用Apache Commons Lang库中的…

缓慢变化维

缓慢变化维 缓慢变化维(Slowly Changing Dimensions,简称SCD)是数据仓库中的一个重要概念,用于处理维度表中数据随时间发生的变化。以下是一个具体的例子来描述缓慢变化维: 假设我们有一个销售数据仓库,其…

AWS全服务历史年表:发布日期、GA和服务概述一览(四)

我一直在尝试从各种角度撰写关于Amazon Web Services(AWS)的信息和魅力。由于我喜欢技术历史,这次我总结了AWS服务发布的历史年表。 虽然AWS官方也通过“Whats New”发布了官方公告,但我一直希望能有一篇文章将公告日期、GA日期&…

python库(14):Arrow库简化时间处理

1 Arrow简介 Arrow 是一个被称为程序员的时间处理利器的 Python 库。 从诞生起,它就是为了填补 Python 的 datetime 类型的功能空白而生的。为程序员提供了一种更简单、更直观的方式来处理日期和时间。 2 安装Arrow库 pip install arrow -i https://pypi.tuna.ts…

什么是设备运维管理系统?有什么作用?(6款设备运维管理系统推荐)

一、什么是设备运维管理系统? 设备运维管理系统是一种集成了监控、管理、维护和优化设备性能的软件平台。它旨在通过自动化的手段,提高设备运行的可靠性和效率,降低运维成本,并优化资源利用。 设备运维管理系统能够实时监控设备…

【1】Python机器学习之基础概念

1、什么是机器学习 最早的机器学习应用——垃圾邮件分辨 传统的计算机解决问题思路: 编写规则,定义“垃圾邮件”,让计算机执行对于很多问题,规则很难定义规则不断变化 机器学习在图像识别领域的重要应用: 人脸识别…

带您详细了解安全漏洞的产生和防护

什么是漏洞? 漏洞是 IT、网络、云、Web 或移动应用程序系统中的弱点或缺陷,可能使其容易受到成功的外部攻击。攻击者经常试图寻找网络安全中的各种类型的漏洞来组合和利用系统。 一些最常见的漏洞: 1.SQL注入 注入诸如 SQL 查询之类的小代…

c# Math.Round()四舍五入取整数

可以使用Math.Round()方法进行四舍五入取整数的操作。 以下是使用Math.Round()方法的实现方法: 将浮点数直接作为参数传递给Math.Round()方法,并指定要保留的小数位数。此方法将返回最接近的整数值。 double number 3.89; int roundedNumber (int)Mat…

react-scripts 这个包的作用是什么

react-scripts 是 Create React App 项目中的一个核心包,它的主要作用包括: 封装和简化项目配置。react-scripts 封装了 Webpack、Babel、ESLint 等工具的配置,使开发者无需手动配置这些复杂的构建工具[1][3]. 提供开发和构建脚本。它包含了 start、bui…

milvus的批量向量搜索

批量向量搜索允许在单个请求中进行多个向量相似性搜索。这种类型的搜索非常适合需要为一组查询向量查找相似向量的场景,可显著减少所需的时间和计算资源。 即:一次查询多个向量,吞吐。 系统会并行处理这些向量,为每个查询向量返回一个单独的…

旋转目标检测:FCOS: Fully Convolutional One-Stage Object Detection【方法解读】

FCOS: 全卷积单阶段目标检测 我们提出了一种全卷积单阶段目标检测器(FCOS),以逐像素预测的方式解决目标检测问题,类似于语义分割。目前几乎所有的最先进目标检测器,如RetinaNet、SSD、YOLOv3和Faster R-CNN,都依赖于预定义的锚框。相反,我们提出的FCOS检测器是无锚框的…

静态解析activiti文本,不入库操作流程

说明&#xff1a; activiti本身状态存库&#xff0c;导致效率太低&#xff0c;把中间状态封装成一个载荷类&#xff0c;返回给上游&#xff0c;下次请求时给带着载荷类即可。 1.pom依赖 <dependency><groupId>net.sf.json-lib</groupId><artifactId>js…

BUU [PASECA2019]honey_shop

BUU [PASECA2019]honey_shop 技术栈&#xff1a;任意文件读取、session伪造 开启靶机&#xff0c;我有1336金币&#xff0c;买flag需要1337金币 点击上面的大图&#xff0c;会直接下载图片 抓包看看&#xff0c;感觉是任意文件读取 修改下路径读一下 读到了session密钥是Kv8i…

Springboot validated JSR303校验

1.导入依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-validation</artifactId></dependency> 2.测试类 package com.jmj.gulimall.product.testC;import lombok.Data;import javax.val…

C++《类和对象》(中)

一、 类的默认成员函数介绍二、构造函数 构造函数名与类同名内置类型与自定义类型析构函数拷贝构造函数 C《类和对象》(中) 一、 类的默认成员函数介绍 默认成员函数就是⽤⼾没有显式实现&#xff0c;编译器会⾃动⽣成的成员函数称为默认成员函数。 那么我们主要学习的是1&…

Linux环境docker部署Firefox结合内网穿透远程使用浏览器测试

文章目录 前言1. 部署Firefox2. 本地访问Firefox3. Linux安装Cpolar4. 配置Firefox公网地址5. 远程访问Firefox6. 固定Firefox公网地址7. 固定地址访问Firefox 前言 本次实践部署环境为本地Linux环境&#xff0c;使用Docker部署Firefox浏览器后&#xff0c;并结合cpolar内网穿…