BasicVSR++模型转JIT并用c++libtorch推理

BasicVSR++模型转JIT并用c++libtorch推理

文章目录

  • BasicVSR++模型转JIT并用c++libtorch推理
    • 安装BasicVSR++ 环境
      • 1.下载源码
      • 2. 新建一个conda环境
      • 3. 安装pytorch
      • 4. 安装 mim 和 mmcv-full
      • 5. 安装 mmedit
      • 6. 下载模型文件
      • 7. 测试一下能否正常运行
    • 转换为JIT模型
    • 用c++ libtorch推理
      • 效果

安装BasicVSR++ 环境

1.下载源码

git clone https://github.com/ckkelvinchan/BasicVSR_PlusPlus.git

2. 新建一个conda环境

conda create -n BasicVSRPLUSPLUS  python=3.8 -y
conda activate BasicVSRPLUSPLUS  

3. 安装pytorch

pytorch官网 安装合适的版本
我这里是CUDA11.6版本

conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia

4. 安装 mim 和 mmcv-full

pip install openmim
mim install mmcv-full

5. 安装 mmedit

pip install mmedit

6. 下载模型文件

下载模型文件放在这里chkpts/basicvsr_plusplus_reds4.pth

7. 测试一下能否正常运行

python demo/restoration_video_demo.py configs/basicvsr_plusplus_reds4.py chkpts/basicvsr_plusplus_reds4.pth data/demo_000 results/demo_000

在这里插入图片描述
在这里插入图片描述
OK ! 环境正常!下面开始转换工作

转换为JIT模型

在demo下新建一个转换脚本

import os
import cv2
import mmcv
import numpy as np
import torch
from mmedit.core import tensor2img
from mmedit.apis import init_modeldef main():# 加载模型并设置为评估模式model = init_model("configs/basicvsr_plusplus_reds4.py","chkpts/basicvsr_plusplus_reds4.pth", device=torch.device('cuda', 0))model.eval()# 准备一个示例输入src1 = cv2.imread("./data/img/00000000.png")src = cv2.cvtColor(src1, cv2.COLOR_BGR2RGB)src = torch.from_numpy(src / 255.).permute(2, 0, 1).float()src = src.unsqueeze(0)input_arg = torch.stack([src], dim=1)input_arg = input_arg.to(torch.device('cuda', 0))  # 确保输入在GPU上# # 执行模型推理# with torch.no_grad():  # 在推理时不需要计算梯度#     result = model(input_arg, test_mode=True)['output'].cpu()# output_i = tensor2img(result)# mmcv.imwrite(output_i, "./test.png")# 模型转换traced_model = torch.jit.trace(model.generator, input_arg)torch.jit.save(traced_model, "basicvsrPP.pt")# 测试res = traced_model(input_arg)out = tensor2img(res)mmcv.imwrite(out, "./testoo.png")if __name__ == '__main__':main()

用c++ libtorch推理

/** @Author: Liangbaikai* @LastEditTime: 2024-03-29 11:28:42* @Description: 视频超分* Copyright (c) 2024 by Liangbaikai, All Rights Reserved.*/#pragma once
#include <iostream>
#include <torch/script.h>
#include <torch/torch.h>
#include <opencv2/opencv.hpp>
#include <vector>
#include <c10/cuda/CUDACachingAllocator.h>namespace LIANGBAIKAI_BASE_MODEL_NAME
{class lbk_video_super_resolution_basicPP{public:lbk_video_super_resolution_basicPP() = default;virtual ~lbk_video_super_resolution_basicPP(){c10::cuda::CUDACachingAllocator::emptyCache();// cudaDeviceReset();}/*** @description: 初始化* @param {string} &modelpath 模型文件* @param {int} gpuid GPU的id* @return {*}成功返回0,失败返回-1*/int init(const std::string &modelpath, int gpuid = 0){try{_mymodule = std::make_unique<torch::jit::script::Module>(torch::jit::load(modelpath));}catch (const c10::Error &e){std::cerr << "Error loading the model " << modelpath << std::endl;std::cerr << "Error " << e.what() << std::endl;return -1;}_gpuid = gpuid;if ((_gpuid < 0) || (!torch::cuda::is_available())){_device = std::make_unique<torch::Device>(torch::kCPU);_mymodule->to(at::kCPU);}else{_device = std::make_unique<torch::Device>(torch::kCUDA, _gpuid);_mymodule->to(at::kCUDA, _gpuid);}_mymodule->eval();_modelsuccess = true;return 0;}/*** @description: 推理* @param {Mat} &inputpic 输入图片* @param {Mat} &outputpic 输出结果* @param {bool} showlog  是否打印日志* @return {*} 成功返回0,失败返回-1*/int inference(cv::Mat &inputpic, cv::Mat &outputpic, bool showlog = false){if (inputpic.empty() || (inputpic.channels() != 3)){std::cout << "input data ERROR" << std::endl;return -1;}if (!_modelsuccess){std::cout << "model has not been inited!" << std::endl;return -1;}// torch::DeviceGuard 是一个类,它的作用是确保在使用完设备(如CPU或GPU)后,能够正确地将设备恢复到使用前的状态。torch::DeviceGuard device_guard(*_device); // 作用域内所有操作都在指定设备上运行,离开此作用域恢复cv::transpose(inputpic, inputpic); // 顺时针旋转// 将图片转换为tensorcv::Mat img_float;inputpic.convertTo(img_float, CV_32FC3, 1.0 / 255);torch::Tensor img_tensor = torch::from_blob(img_float.data, {img_float.rows, img_float.cols, 3}, torch::kFloat32).permute({2, 1, 0});img_tensor = (img_tensor - 0.5) / 0.5;img_tensor = (img_tensor + 1) / 2;img_tensor = torch::clamp(img_tensor, 0, 1);torch::Tensor src_unsqueezed = img_tensor.unsqueeze(0).to(*_device); // 将tensor转移到GPU上std::vector<torch::Tensor> tensors_to_stack = {src_unsqueezed}; // 创建一个包含 src 的 vectortorch::Tensor input_arg = torch::stack(tensors_to_stack, 1); // 沿着维度1堆叠tensorsif (showlog){std::cout << input_arg.sizes() << std::endl;}torch::NoGradGuard no_grad; // 暂时禁用梯度计算auto output_dict = _mymodule->forward({input_arg});torch::Tensor output_data;if (output_dict.isTensor()){output_data = output_dict.toTensor().to(at::kCPU); // 如果是Tensor,则通过toTensor()方法获取它if (showlog){std::cout << "out shape: " << output_data.sizes() << std::endl;}}else{if (showlog){std::cerr << "The IValue does not contain a Tensor." << std::endl;}}float *f = output_data.data_ptr<float>();int output_width = output_data.size(3);int output_height = output_data.size(4);int size_pic = output_width * output_height;std::vector<cv::Mat> rgbChannels(3);rgbChannels[0] = cv::Mat(output_width, output_height, CV_32FC1, f);rgbChannels[1] = cv::Mat(output_width, output_height, CV_32FC1, f + size_pic);rgbChannels[2] = cv::Mat(output_width, output_height, CV_32FC1, f + size_pic + size_pic);rgbChannels[0].convertTo(rgbChannels[0], CV_8UC1, 255);rgbChannels[1].convertTo(rgbChannels[1], CV_8UC1, 255);rgbChannels[2].convertTo(rgbChannels[2], CV_8UC1, 255);merge(rgbChannels, outputpic);return 0;}private:bool _modelsuccess = false;int _gpuid = 0;std::unique_ptr<torch::Device> _device;std::unique_ptr<torch::jit::script::Module> _mymodule;};}
#include <unistd.h>
#include"lbk_video_super_resolution.hpp"
using namespace LIANGBAIKAI_BASE_MODEL_NAME;
int main(int argc,char *argv[])
{if(argc < 5){std::cout << "./test 模型  GPUid(cpu传-1) 输入图片 输出图片" << std::endl;return -1;}std::string modelfile = argv[1];int gpuid = atoi(argv[2]);std::string imgfile = argv[3];std::string outfile = argv[4];cv::Mat src = cv::imread(imgfile);lbk_video_super_resolution_basicPP test;if(0 > test.init(modelfile,gpuid)){std::cout << "init failed" << std::endl;return -1;}cv::Mat out;int rec = test.inference(src,out,true);if(rec >= 0){cv::imwrite(outfile, out);}return 0;
}

效果

在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/780916.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

malloc是如何分配内存|malloc(1)分配多大内存|free释放内存,会还给操作系统吗?

前言 大家好&#xff0c; 我jiantaoyab&#xff0c;这篇文章给大家介绍mallo和free面试中常问到的问题。 malloc是如何分配内存的&#xff1f; 如果用户分配的内存小于128KB&#xff0c;则通过brk()申请内存 如果用户分配的内存大于128KB&#xff0c;则通过mmap()申请内存 简…

数据分析之POWER Piovt的KPI设置

内容总结&#xff1a; 1.两个表格关联不上&#xff1a;需要添加辅助列&#xff0c;建立关联 2.添加辅助列后还关联不上&#xff1a;将虚线变为实线 3.根据需求要增加一些度量值 4.设置KPI后&#xff0c;绝对值选1后设定百分比 5.在透视表里面加入KPI状态 导入所关联的数据后建立…

游戏领域AI智能视频剪辑解决方案

游戏行业作为文化创意产业的重要组成部分&#xff0c;其发展和创新速度令人瞩目。然而&#xff0c;随着游戏内容的日益丰富和直播文化的兴起&#xff0c;传统的视频剪辑方式已难以满足玩家和观众日益增长的需求。美摄科技&#xff0c;凭借其在AI智能视频剪辑领域的深厚积累和创…

SQLBolt,一个练习SQL的宝藏网站

知乎上有人问学SQL有什么好的网站&#xff0c;这可太多了。 我之前学习SQL买了本SQL学习指南&#xff0c;把语法从头到尾看了个遍&#xff0c;但仅仅是心里有数的程度&#xff0c;后来进公司大量的写代码跑数&#xff0c;才算真真摸透了SQL&#xff0c;知道怎么调优才能最大化…

SpringBoot登录校验(三)

​​​​​​​SpringBoot 登录认证&#xff08;一&#xff09;-CSDN博客 SpringBoot 登录认证&#xff08;二&#xff09;-CSDN博客 SpringBoot登录校验&#xff08;三&#xff09;-CSDN博客 前面我们介绍了传统的会话跟踪技术cookie和sesstion&#xff0c;本节讲解令牌技术…

Ubuntu20.04LTS+uhd3.15+gnuradio3.8.1源码编译及安装

文章目录 前言一、卸载本地 gnuradio二、安装 UHD 驱动三、编译及安装 gnuradio四、验证 前言 本地 Ubuntu 环境的 gnuradio 是按照官方指导使用 ppa 的方式安装 uhd 和 gnuradio 的&#xff0c;也是最方便的方法&#xff0c;但是存在着一个问题&#xff0c;就是我无法修改底层…

HarmonyOS实战开发-如何实现一个简单的电子相册应用开发

介绍 本篇Codelab介绍了如何实现一个简单的电子相册应用的开发&#xff0c;主要功能包括&#xff1a; 实现首页顶部的轮播效果。实现页面跳转时共享元素的转场动画效果。实现通过手势控制图片的放大、缩小、左右滑动查看细节等效果。 相关概念 Swiper&#xff1a;滑块视图容…

java多线程中的阻塞队列

一、普通不阻塞队列 还记得队列我们如何实现吗&#xff1f;我们用的是循环队列的方式&#xff0c;回一下&#xff1a; 描述&#xff1a;开始tail和head指针都指向最开始位置&#xff0c;往里面添加元素tail&#xff0c;出元素head 初始状态&#xff1a; put元素后状态 take…

账号微服务短信验证码发送工具单元测试

账号微服务短信验证码发送工具单元测试 注意sms的 app-code #----------sms短信配置-------------- sms:app-code: dd7829bedfaf4373875aa91abba82523template-id: JM1000372package net.xdclass.config;import org.springframework.context.annotation.Bean; import org.spri…

ROS 2边学边练(4)-- 何为主题(topics)

概念 主题是一种节点间的通信方式&#xff0c;某个节点充当发布特定&#xff08;主题&#xff09;消息&#xff08;数据&#xff09;的角色&#xff0c;另外一些节点则可以订阅接收该特定&#xff08;主题&#xff09;消息&#xff08;数据&#xff09;。两者&#xff0…

在ubuntu上搭建系统监控系统

大纲 数据生产方安装和运行验证 数据收集、存储和分发方下载和解压修改配置运行验证 数据消费方下载和运行验证新增数据源新增看板关联看板和数据源效果展现 参考资料 在一个监控系统中&#xff0c;一定会有“数据生产方”和“数据消费方”存在。“数据生产方”用于产出需要监控…

Android MediaRecorder

AndroidManifest.xml中添加权限标记 <uses-permission android:name"android.permission.RECORD_AUDIO"/> 动态添加权限MainActivity requestPermissions(new String[]{Manifest.permission.CAMERA,Manifest.permission.RECORD_AUDIO},100); 创建MediaReco…

Flask学习(五):session相关流程

流程图如下图所示&#xff1a; 调用相关类如下图所示&#xff1a; 相关代码如下&#xff1a; from flask import Flask, sessionapp Flask(__name__)1. 加密会话数据&#xff1a;在 Flask 中&#xff0c;会话数据存储在客户端的 cookie 中。设置 app.secret_key 可以加密会话…

OLED模块

OLED模块 综述&#xff1a;本篇文章简要讲述了oled的定义&#xff0c;两种oled的引脚和接线情况、iic通讯协议、spi通讯协议、OLED代码引用和注意事项。 1.定义 OLED&#xff08;Organic Light-Emitting Diode&#xff09;模块是一种使用有机发光二极管作为显示元素的显示模…

DFS:二叉树的深搜与回溯

一、计算布尔二叉树的值 . - 力扣&#xff08;LeetCode&#xff09; class Solution { public:bool evaluateTree(TreeNode* root) {if(root->leftnullptr) return root->val0?false:true; bool left evaluateTree(root->left);bool rightevaluateTree(root->rig…

1.1 单片机的概念

一,单片机的概念 单片机(Single-Chip Microcomputer),也被称为单片微控制器,是一种集成电路芯片。它采用超大规模集成电路技术,将具有数据处理能力的中央处理器CPU、随机存储器RAM、只读存储器ROM、多种I/O口和中断系统、定时器/计数器等功能(可能还包括显示驱动电路、…

springcloud基本使用(搭建eureka服务端)

创建springbootmaven项目 next next finish创建成功 删除项目下所有文件目录&#xff0c;只保留pox.xml文件 父项目中的依赖&#xff1a; springboot依赖&#xff1a; <parent><groupId>org.springframework.boot</groupId><artifactId>spring-boot-s…

Swift:“逻辑运算子“与“比较运算符“

1. 逻辑非 ! 逻辑非运算符 ! 是用于对布尔值取反的。当操作数为 true 时&#xff0c;! 将返回 false&#xff0c;而当操作数为 false 时&#xff0c;! 将返回 true。 let isTrue true let isFalse !isTrue // isFalse 现在是 false 2. 逻辑与 && 逻辑与运算符 &a…

爬取b站音频和视频数据,未合成一个视频

一、首先找到含有音频和视频的url地址 打开一个视频&#xff0c;刷新后&#xff0c;找到这个包&#xff0c;里面有我们所需要的数据 访问这个数据包后&#xff0c;获取字符串数据&#xff0c;用正则提取&#xff0c;再转为json字符串方便提取。 二、获得标题和音频数据后&…

linux基础命令篇:Linux基础命令讲解——文件浏览(cat、less、head、tail和grep)

Linux基础命令讲解——文件浏览&#xff08;cat、less、head、tail和grep&#xff09; 本文详细介绍Linux中的cat、less、head、tail和grep命令&#xff0c;这些命令在日常工作中非常实用&#xff0c;以下是关于这些命令的详细介绍&#xff1a; 1. cat命令&#xff1a;用于查看…