BasicVSR++模型转JIT并用c++libtorch推理

BasicVSR++模型转JIT并用c++libtorch推理

文章目录

  • BasicVSR++模型转JIT并用c++libtorch推理
    • 安装BasicVSR++ 环境
      • 1.下载源码
      • 2. 新建一个conda环境
      • 3. 安装pytorch
      • 4. 安装 mim 和 mmcv-full
      • 5. 安装 mmedit
      • 6. 下载模型文件
      • 7. 测试一下能否正常运行
    • 转换为JIT模型
    • 用c++ libtorch推理
      • 效果

安装BasicVSR++ 环境

1.下载源码

git clone https://github.com/ckkelvinchan/BasicVSR_PlusPlus.git

2. 新建一个conda环境

conda create -n BasicVSRPLUSPLUS  python=3.8 -y
conda activate BasicVSRPLUSPLUS  

3. 安装pytorch

pytorch官网 安装合适的版本
我这里是CUDA11.6版本

conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia

4. 安装 mim 和 mmcv-full

pip install openmim
mim install mmcv-full

5. 安装 mmedit

pip install mmedit

6. 下载模型文件

下载模型文件放在这里chkpts/basicvsr_plusplus_reds4.pth

7. 测试一下能否正常运行

python demo/restoration_video_demo.py configs/basicvsr_plusplus_reds4.py chkpts/basicvsr_plusplus_reds4.pth data/demo_000 results/demo_000

在这里插入图片描述
在这里插入图片描述
OK ! 环境正常!下面开始转换工作

转换为JIT模型

在demo下新建一个转换脚本

import os
import cv2
import mmcv
import numpy as np
import torch
from mmedit.core import tensor2img
from mmedit.apis import init_modeldef main():# 加载模型并设置为评估模式model = init_model("configs/basicvsr_plusplus_reds4.py","chkpts/basicvsr_plusplus_reds4.pth", device=torch.device('cuda', 0))model.eval()# 准备一个示例输入src1 = cv2.imread("./data/img/00000000.png")src = cv2.cvtColor(src1, cv2.COLOR_BGR2RGB)src = torch.from_numpy(src / 255.).permute(2, 0, 1).float()src = src.unsqueeze(0)input_arg = torch.stack([src], dim=1)input_arg = input_arg.to(torch.device('cuda', 0))  # 确保输入在GPU上# # 执行模型推理# with torch.no_grad():  # 在推理时不需要计算梯度#     result = model(input_arg, test_mode=True)['output'].cpu()# output_i = tensor2img(result)# mmcv.imwrite(output_i, "./test.png")# 模型转换traced_model = torch.jit.trace(model.generator, input_arg)torch.jit.save(traced_model, "basicvsrPP.pt")# 测试res = traced_model(input_arg)out = tensor2img(res)mmcv.imwrite(out, "./testoo.png")if __name__ == '__main__':main()

用c++ libtorch推理

/** @Author: Liangbaikai* @LastEditTime: 2024-03-29 11:28:42* @Description: 视频超分* Copyright (c) 2024 by Liangbaikai, All Rights Reserved.*/#pragma once
#include <iostream>
#include <torch/script.h>
#include <torch/torch.h>
#include <opencv2/opencv.hpp>
#include <vector>
#include <c10/cuda/CUDACachingAllocator.h>namespace LIANGBAIKAI_BASE_MODEL_NAME
{class lbk_video_super_resolution_basicPP{public:lbk_video_super_resolution_basicPP() = default;virtual ~lbk_video_super_resolution_basicPP(){c10::cuda::CUDACachingAllocator::emptyCache();// cudaDeviceReset();}/*** @description: 初始化* @param {string} &modelpath 模型文件* @param {int} gpuid GPU的id* @return {*}成功返回0,失败返回-1*/int init(const std::string &modelpath, int gpuid = 0){try{_mymodule = std::make_unique<torch::jit::script::Module>(torch::jit::load(modelpath));}catch (const c10::Error &e){std::cerr << "Error loading the model " << modelpath << std::endl;std::cerr << "Error " << e.what() << std::endl;return -1;}_gpuid = gpuid;if ((_gpuid < 0) || (!torch::cuda::is_available())){_device = std::make_unique<torch::Device>(torch::kCPU);_mymodule->to(at::kCPU);}else{_device = std::make_unique<torch::Device>(torch::kCUDA, _gpuid);_mymodule->to(at::kCUDA, _gpuid);}_mymodule->eval();_modelsuccess = true;return 0;}/*** @description: 推理* @param {Mat} &inputpic 输入图片* @param {Mat} &outputpic 输出结果* @param {bool} showlog  是否打印日志* @return {*} 成功返回0,失败返回-1*/int inference(cv::Mat &inputpic, cv::Mat &outputpic, bool showlog = false){if (inputpic.empty() || (inputpic.channels() != 3)){std::cout << "input data ERROR" << std::endl;return -1;}if (!_modelsuccess){std::cout << "model has not been inited!" << std::endl;return -1;}// torch::DeviceGuard 是一个类,它的作用是确保在使用完设备(如CPU或GPU)后,能够正确地将设备恢复到使用前的状态。torch::DeviceGuard device_guard(*_device); // 作用域内所有操作都在指定设备上运行,离开此作用域恢复cv::transpose(inputpic, inputpic); // 顺时针旋转// 将图片转换为tensorcv::Mat img_float;inputpic.convertTo(img_float, CV_32FC3, 1.0 / 255);torch::Tensor img_tensor = torch::from_blob(img_float.data, {img_float.rows, img_float.cols, 3}, torch::kFloat32).permute({2, 1, 0});img_tensor = (img_tensor - 0.5) / 0.5;img_tensor = (img_tensor + 1) / 2;img_tensor = torch::clamp(img_tensor, 0, 1);torch::Tensor src_unsqueezed = img_tensor.unsqueeze(0).to(*_device); // 将tensor转移到GPU上std::vector<torch::Tensor> tensors_to_stack = {src_unsqueezed}; // 创建一个包含 src 的 vectortorch::Tensor input_arg = torch::stack(tensors_to_stack, 1); // 沿着维度1堆叠tensorsif (showlog){std::cout << input_arg.sizes() << std::endl;}torch::NoGradGuard no_grad; // 暂时禁用梯度计算auto output_dict = _mymodule->forward({input_arg});torch::Tensor output_data;if (output_dict.isTensor()){output_data = output_dict.toTensor().to(at::kCPU); // 如果是Tensor,则通过toTensor()方法获取它if (showlog){std::cout << "out shape: " << output_data.sizes() << std::endl;}}else{if (showlog){std::cerr << "The IValue does not contain a Tensor." << std::endl;}}float *f = output_data.data_ptr<float>();int output_width = output_data.size(3);int output_height = output_data.size(4);int size_pic = output_width * output_height;std::vector<cv::Mat> rgbChannels(3);rgbChannels[0] = cv::Mat(output_width, output_height, CV_32FC1, f);rgbChannels[1] = cv::Mat(output_width, output_height, CV_32FC1, f + size_pic);rgbChannels[2] = cv::Mat(output_width, output_height, CV_32FC1, f + size_pic + size_pic);rgbChannels[0].convertTo(rgbChannels[0], CV_8UC1, 255);rgbChannels[1].convertTo(rgbChannels[1], CV_8UC1, 255);rgbChannels[2].convertTo(rgbChannels[2], CV_8UC1, 255);merge(rgbChannels, outputpic);return 0;}private:bool _modelsuccess = false;int _gpuid = 0;std::unique_ptr<torch::Device> _device;std::unique_ptr<torch::jit::script::Module> _mymodule;};}
#include <unistd.h>
#include"lbk_video_super_resolution.hpp"
using namespace LIANGBAIKAI_BASE_MODEL_NAME;
int main(int argc,char *argv[])
{if(argc < 5){std::cout << "./test 模型  GPUid(cpu传-1) 输入图片 输出图片" << std::endl;return -1;}std::string modelfile = argv[1];int gpuid = atoi(argv[2]);std::string imgfile = argv[3];std::string outfile = argv[4];cv::Mat src = cv::imread(imgfile);lbk_video_super_resolution_basicPP test;if(0 > test.init(modelfile,gpuid)){std::cout << "init failed" << std::endl;return -1;}cv::Mat out;int rec = test.inference(src,out,true);if(rec >= 0){cv::imwrite(outfile, out);}return 0;
}

效果

在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/780916.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用docker 安装oracle 11g 挂载数据目录并修改SID centos-7

建议&#xff1a;建议使用其它系统去装ubuntu或Rocky&#xff08;因为centos已经停止维护&#xff09; 1、安装docker 这里就不细写了&#xff0c;可以查看清华镜像源或者阿里镜像源 清华&#xff1a;https://mirrors.tuna.tsinghua.edu.cn/help/docker-ce/ 阿里&#xff1a;ht…

2434. 使用机器人打印字典序最小的字符串

点击跳转题目 本题学到两点&#xff1a; 1.初始化数组&#xff0c;全部为0的简单写法。之前都是 int arr[26]; memset(arr,0,sizeof(arr));2.if条件中的&&部分左右顺序不能颠倒。颠倒报错&#xff0c;之前一直没重视。 思路&#xff1a; 遍历s&#xff0c;push当前字…

[c++]类和对象常见题目详解

本专栏内容为&#xff1a;C学习专栏&#xff0c;分为初阶和进阶两部分。 通过本专栏的深入学习&#xff0c;你可以了解并掌握C。 &#x1f493;博主csdn个人主页&#xff1a;小小unicorn ⏩专栏分类&#xff1a;C &#x1f69a;代码仓库&#xff1a;小小unicorn的代码仓库&…

2024.03.19 校招 实习 内推 面经

绿*泡*泡VX&#xff1a; neituijunsir 交流*裙 &#xff0c;内推/实习/校招汇总表格 1、校招 | RoboSense 速腾聚创2024届春招启动&#xff08;内推&#xff09; 校招 | RoboSense 速腾聚创2024届春招启动&#xff08;内推&#xff09; 2、实习 | 百度智能驾驶事业群组 202…

kanzi 3d知识点

整理学习资料 名字链接Kanzi视频合集中科创达-智能座舱视频专辑-中科创达-智能座舱视频合集-哔哩哔哩视频 (bilibili.com)Kanzi在线文档Working with … - Kanzi framework 3.9.7 documentationThe Book of ShadersThe Book of Shaders着色器语言Shader_着色语言Shading Langua…

Vim - 文本编辑器 Vi vs Vim

你是否应该在学习 Vim 之前先学习 Vi&#xff0c;这完全取决于您自己、您的要求以及您的具体目标和需求。Vim 是 Vi 的扩展、增强和改进版本&#xff0c;它包括 Vi 的所有功能以及许多附加功能。 简约&#xff1a; Vi 设计简约。先学习 Vi 可以让你对基础知识有扎实的了解&…

malloc是如何分配内存|malloc(1)分配多大内存|free释放内存,会还给操作系统吗?

前言 大家好&#xff0c; 我jiantaoyab&#xff0c;这篇文章给大家介绍mallo和free面试中常问到的问题。 malloc是如何分配内存的&#xff1f; 如果用户分配的内存小于128KB&#xff0c;则通过brk()申请内存 如果用户分配的内存大于128KB&#xff0c;则通过mmap()申请内存 简…

数据分析之POWER Piovt的KPI设置

内容总结&#xff1a; 1.两个表格关联不上&#xff1a;需要添加辅助列&#xff0c;建立关联 2.添加辅助列后还关联不上&#xff1a;将虚线变为实线 3.根据需求要增加一些度量值 4.设置KPI后&#xff0c;绝对值选1后设定百分比 5.在透视表里面加入KPI状态 导入所关联的数据后建立…

游戏领域AI智能视频剪辑解决方案

游戏行业作为文化创意产业的重要组成部分&#xff0c;其发展和创新速度令人瞩目。然而&#xff0c;随着游戏内容的日益丰富和直播文化的兴起&#xff0c;传统的视频剪辑方式已难以满足玩家和观众日益增长的需求。美摄科技&#xff0c;凭借其在AI智能视频剪辑领域的深厚积累和创…

SQLBolt,一个练习SQL的宝藏网站

知乎上有人问学SQL有什么好的网站&#xff0c;这可太多了。 我之前学习SQL买了本SQL学习指南&#xff0c;把语法从头到尾看了个遍&#xff0c;但仅仅是心里有数的程度&#xff0c;后来进公司大量的写代码跑数&#xff0c;才算真真摸透了SQL&#xff0c;知道怎么调优才能最大化…

数据可视化之折线图plot

import matplotlib.pyplot as plt plt.rcParams[font.family] [SimHei]# 查看matplotlibde文件地址# import matplotlib # print(matplotlib.matplotlib_fname()) # plt.rcParams[font.sans-serif] [SimHei] # 准备数据time [20200401,20200402,20200403,20200404,20200405…

SpringBoot登录校验(三)

​​​​​​​SpringBoot 登录认证&#xff08;一&#xff09;-CSDN博客 SpringBoot 登录认证&#xff08;二&#xff09;-CSDN博客 SpringBoot登录校验&#xff08;三&#xff09;-CSDN博客 前面我们介绍了传统的会话跟踪技术cookie和sesstion&#xff0c;本节讲解令牌技术…

Ubuntu20.04LTS+uhd3.15+gnuradio3.8.1源码编译及安装

文章目录 前言一、卸载本地 gnuradio二、安装 UHD 驱动三、编译及安装 gnuradio四、验证 前言 本地 Ubuntu 环境的 gnuradio 是按照官方指导使用 ppa 的方式安装 uhd 和 gnuradio 的&#xff0c;也是最方便的方法&#xff0c;但是存在着一个问题&#xff0c;就是我无法修改底层…

Spel 表达式

模板占位替换&#xff0c;在项目开发中&#xff0c;还是很常用的。比如在代码中获取参数&#xff0c;消息推送可以使用变量占位&#xff0c;我比较推荐使用 SPEL 表达式。 在注解中&#xff0c;获取方法的参数 public class SpElParser {private static final ExpressionPars…

基于机器视觉的智能物流机器人的设计与开发

标题&#xff1a;基于机器视觉的智能物流机器人的设计与开发 摘要&#xff1a; 随着电子商务和物流行业的快速发展&#xff0c;智能物流机器人作为一种高效、准确的自动化解决方案&#xff0c;正逐渐受到广泛关注。本文围绕基于机器视觉技术的智能物流机器人的设计与研发展开&…

HarmonyOS实战开发-如何实现一个简单的电子相册应用开发

介绍 本篇Codelab介绍了如何实现一个简单的电子相册应用的开发&#xff0c;主要功能包括&#xff1a; 实现首页顶部的轮播效果。实现页面跳转时共享元素的转场动画效果。实现通过手势控制图片的放大、缩小、左右滑动查看细节等效果。 相关概念 Swiper&#xff1a;滑块视图容…

java多线程中的阻塞队列

一、普通不阻塞队列 还记得队列我们如何实现吗&#xff1f;我们用的是循环队列的方式&#xff0c;回一下&#xff1a; 描述&#xff1a;开始tail和head指针都指向最开始位置&#xff0c;往里面添加元素tail&#xff0c;出元素head 初始状态&#xff1a; put元素后状态 take…

账号微服务短信验证码发送工具单元测试

账号微服务短信验证码发送工具单元测试 注意sms的 app-code #----------sms短信配置-------------- sms:app-code: dd7829bedfaf4373875aa91abba82523template-id: JM1000372package net.xdclass.config;import org.springframework.context.annotation.Bean; import org.spri…

ROS 2边学边练(4)-- 何为主题(topics)

概念 主题是一种节点间的通信方式&#xff0c;某个节点充当发布特定&#xff08;主题&#xff09;消息&#xff08;数据&#xff09;的角色&#xff0c;另外一些节点则可以订阅接收该特定&#xff08;主题&#xff09;消息&#xff08;数据&#xff09;。两者&#xff0…

在ubuntu上搭建系统监控系统

大纲 数据生产方安装和运行验证 数据收集、存储和分发方下载和解压修改配置运行验证 数据消费方下载和运行验证新增数据源新增看板关联看板和数据源效果展现 参考资料 在一个监控系统中&#xff0c;一定会有“数据生产方”和“数据消费方”存在。“数据生产方”用于产出需要监控…