项目四:TensorRT编程

TensorRT编程 - 单层感知机结构

build.cpp

/*
一般先是运行 build.cpp, 在运行 runtime.cu
TensorRT命名空间是 nvinfer1;
1.创建builder
2.创建网络定义: builder --> network
3.配置参数: builder --> config
4.生成engine: builder --> engine(network, config)
5.序列化保存: engine --> serialize
6.释放资源: delete
*/#include<iostream>
#include<NvInfer.h>// 创建 logger 用来管控打印日志级别
// TRTLogger 继承自 nvinfer1 :: ILogger
class TRTLogger : public nvinfer1 :: ILogger
{void log(Severity severity, const char *msg) noexcept override{// 屏蔽 INFO 级别的日志if (severity == Severity::kINFO){std :: cout << msg << std :: endl;}}
}gLogger;// 保存权重
void saveWeights(const std :: string &filename, const float *data, int size)
{std :: ofstream outfile(filename, std :: ios :: binary);assert(outfile.is_open() && "save weights failed");                       // assert 断言, 如果条件不满足就会报错outfile.write((char *)(&size), sizeof(int));                              // 写入 权重size(大小: 代表数据的多少), 传入地址outfile.write((char *)(data), size * sizeof(float));                      // 写入 权重data, 传入地址outfile.close();}// 读取权重, vector 进行读取
std :: vector<float> loadWeights(const std :: string &filename)
{// 读取文件std :: ifstream infile(filename, std :: ios :: binary);assert(infile.is_open() && "load weights failed");                          // assert 断言, 如果条件不满足就会报错int size;infile.read((char *)(&size), sizeof(int));                                  // 读取 权重size(大小: 代表数据的多少), 传入地址std :: vector<float> data(size);                                            // 权重data, 传入大小infile.read((char *)(data.data()), size * sizeof(float));                   // 读取 权重data, 传入地址infile.close();return data;                                                                // 返回形式是: vector }int main()
{// ====================== 1. 创建 builder ======================TRTLogger logger;nvinfer1 :: IBuilder* builder = nvinfer1 :: createInferBuilder(logger);// ====================== 2. 创建网络定义 ======================// 调用createNetworkV2()创建网络结构, 1: 代表显性 batchnvinfer1 :: INetworkDefinition *network = builder -> createNetworkV2(1);// 定义网络结构// mlp 多层感知机: input(1, 3, 1, 1) --> fc1 --> sigmoid --> output(2)// =========== 创建一个 input tensor ===========const int input_size = 3;nvinfer1 :: ITensor *input = network -> addInput("data", nvinfer1 :: DataType :: kFLOAT, nvinfer1 :: Dims4(1, input_size, 1, 1));// =========== 创建全连接层 fc1 ===========// weight 和 bias(堆上分配内存), 直接定义const float *fc1_weight_data = new float [input_size * 2] {0.1, 0.2, 0.3, 0.4, 0.5, 0.6};const float *fc1_bias_data = new float [2] {0.1, 0.5};// 将 weight 和 bias 转为 nvinfer 模型, 参数分别是: data_type, data(指针获取数据传过来是地址), size(大小, 个数)// nvinfer1 :: Weights fc1_weights {nvinfer1 :: DataType :: kFLOAT, fc1_weight_data, input_size * 2};// nvinfer1 :: Weights fc1_bias {nvinfer1 :: DataType :: kFLOAT, fc1_bias_data, 2};// ====== 将 weight 和 bias 保存权重 ======// 将权重保存到文件中, 演示从别的来源加载权重, saveWeights("./model/fc1.wts", fc1_weight_data, 6);saveWeights("./model/fc1.bias", fc1_bias_data, 2);// 读取对应的权重, auto 进行读取, loadWeights()函数返回的是 vecauto fc1_weights_vec = loadWeights("./model/fc1.wts");auto fc1_bias_vec = loadWeights("./model/fc1.bias");// ====== 从保存权重获取weight和bias ======// 将 weight 和 bias 转为 nvinfer 模型, 参数分别是: data_type, data(vec获取数据是 data), size(大小, 个数)nvinfer1 :: Weights fc1_weight {nvinfer1 :: DataType :: kFLOAT, fc1_weights_vec.data(), fc1_weights_vec.size()};nvinfer1 :: Weights fc1_bias {nvinfer1 :: DataType :: kFLOAT, fc1_bias_vec.data(), fc1_bias_vec.size()};const int output_size = 2;// 调用addFullyConnected()创建全连接层, 参数分别是: input tensor, output size, weight, biasnvinfer1 :: IFullyConnectedLayer *fc1 = network -> addFullyConnected(*input, output_size, fc1_weight, fc1_bias);//  =========== 创建 sigmoid ===========// 添加 sigmoid 激活层, 参数分别是: input_tensor, activation_type(激活函数类型),nvinfer1 :: IActivationLayer *sigmoid = network -> addActivation(*fc1->getOutput(0), nvinfer1 :: ActivationType :: kSIGMOID);// =========== 设置输出名字 ===========sigmoid -> getOutput(0) -> setName("output");// 标记输出, 没有标记会被当成顺时针优化掉network -> markOutput(*sigmoid->getOutput(0));// 设定最大的batch_sizebuilder -> setMaxBatchSize(1);// ====================== 3. 配置参数: builer ---> config ======================// 添加配置参数, 告诉 tensorRT 应该如何优化网络nvinfer1 :: IBuilderConfig *config = builder -> createBuilderConfig();// 设置最大工作空间, 单位是字节config -> setMaxWorkspaceSize(1 << 28);                                  // 1 << 28: 256MiB// ====================== 4. 生成engine:builder --> network --> config ======================nvinfer1 :: ICudaEngine *engine = builder -> buildEngineWithConfig(*network, *config);// 如果没有生成 engine, 就简单的报错一下if(!engine){std :: cout << "Failed to create engine!" << std :: endl;return -1;}// ====================== 5. 序列化engine, 保存 ======================nvinfer1 :: IHostMemory *serialized_engine = engine->serialize();// 存入文件std :: ofstream outfile("./model/mlp.engine", std :: ios :: binary);// 进行断言assert(outfile.is_open() && "Failed to open file for writing");outfile.write((char *)serialized_engine->data(), serialized_engine->size());// ====================== 6. 释放资源 ======================// 理论上申请的资源都要释放, 这里只是释放部分资源outfile.close();delete serialized_engine;delete engine;delete config;delete network;delete builder;std :: cout << "engine 文件生成成功" << std :: endl;return 0;}

runtime.cu

// TensorRT运行时的最高层级接口是Runtime
// 执行推理的部分/*
使用 .cu 是希望使用 CUDA 的编译器 NVCC, 会自动连接 cuda 库TensorRT runtime 的推理过程
1. 创建一个 runtime 对象
2. 反序列化生成 engine: runtime --> engine
3. 创建一个执行上下文 (ExecutionContext调用enqueueV2()来运行Inference) ExecutionContext: engine --> context4. 填充数据5. 执行推理: context --> enqueueV26. 释放资源: delete*/#include<iostream>
#include<vector>
#include<cassert>
#include<fstream>#include<NvInfer.h>
#include<cuda_runtime.h>// 创建 logger 用来管控打印日志级别
// TRTLogger 继承自 nvinfer1 :: ILogger
class TRTLogger : public nvinfer1 :: ILogger
{void log(Severity severity, const char *msg) noexcept override{// 屏蔽 INFO 级别的日志if (severity == Severity::kINFO){std :: cout << msg << std :: endl;}}
}gLogger;// 加载模型: mlp.engine
std :: vector<unsigned char> loadEngineModel(const std :: string &filename)
{std :: ifstream infile(filename, std :: ios :: binary);                       // 以二进制方式读取文件// 断言assert(infile.is_open() && "load weights failed");file.seekg(0, std :: ios :: binary);                            // 定位到文件末尾size_t size = file.tellg();                                     // 获取文件大小std :: vector<unsigned char> data(size);                        // 创建一个 vector, 大小为 sizefile.seekg(0, std :: ios :: beg);file.read((char *)(data.data()), size);                         //file.close();return data;}int main()
{// ========================== 1. 创建一个 runtime 对象 ==========================// TRTLogger 实例化TRTLogger logger;nvinfer1 :: IRuntime *runtime = nvinfer1 :: createInferRuntime(logger);// ========================== 2. 反序列化生成 engine: runtime --> engine ==========================// 读取文件auto engineModel = loadEngineModel("./model/mlp.engine");// 调用 runtime 的反序列方法, 生成 engine, 参数: 模型数据地址, 模型大小, pluginFactorynvinfer1 :: ICudaEngine *engine = runtime -> deserializeCudaEngine(engineModel.data(), engineModel.size(), nullprt);if(!engine){std :: cout << "deserialize engine failed" << std :: endl;return -1;}// ========================== 3. 创建一个执行上下文 ==========================nvinfer1 :: IExecutionContext *context = engine -> createExecutionContext();// ========================== 4. 填充数据 ==========================//  设置 stream 流cudaStream_t stream = nullptr;cudaStreamCreate(&stream);// cpu: host  GPU: device// 数据流转: host --> device --> inference --> host// 输入数据float *host_input_data = new float[3] {2, 4, 8};                           // host 输入数据int input_data_size = 3 * sizeof(float);                                   // 输入数据大小float *device_input_data = nullptr;                                        // device 输入数据// 输出数据float *host_output_data = new float[3] {0, 0,};                            // host 输出数据, 初始化0, 0int output_data_size = 2 * sizeof(float);                                  // 输出数据大小float *device_output_data = nullptr;                                       // device 输出数据// 申请 device 内存cudaMalloc((void **)&device_input_data, input_data_size);cudaMalloc((void **)&device_output_data, output_data_size);// host --> device(异步拷贝)// 参数分别是: 目标地址, 源地址, 数据大小, 拷贝方向cudaMemcpyAsync(device_input_data, host_input_data, input_data_size, cudaMemcpyHostToDevice, stream);// bindings 告诉 Context 输入输出数据的位置float *bindings[] = {device_input_data, device_output_data};// ========================== 5. 执行推理 ==========================bool success = context -> enqueueV2((void **)bindings, stream, nullptr);// 数据从 device --> hostcudaMemcpyAsync(host_output_data, device_output_data, output_data_size, cudaMemcpyDeviceToHost, stream);// 等待流执行完毕cudaStreamSynchronize(stream);// 输出结果std :: cout << "输出结果: " << host_output_data[0] << " " << host_output_data[1] << std :: endl;// 释放资源cudaStreamDestroy(stream);cudaFree(device_input_data);cudaFree(device_output_data);delete host_input_data;delete host_output_data;delete context;delete engine;delete runtime;return 0;}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/29422.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Altair 助力优化摩托车空气动力学性能,实现最佳的整流罩设计

案例简介 整流罩是绝大多数摩托车的重要组成部分&#xff0c;旨在提高车辆的空气动力学性能和稳定性。Altair 与 KTM 公司员工组成的项目团队&#xff0c;针对摩托车整流罩空气动力学方面的学生项目&#xff0c;展开了密切合作。 项目任务主要是对摩托车整流罩设计进行比较&…

山体滑坡监测利器:传感器与智能监测平台的应用

山体滑坡&#xff0c;这一地质灾害的代名词&#xff0c;指的是山坡上的土体或岩体在重力作用下&#xff0c;因自然或人为因素而向下滑动的现象。滑坡具有突发性、隐蔽性、危害性和破坏性等特征&#xff0c;因此&#xff0c;对于山体滑坡的监测工作显得尤为重要。本文将探讨山体…

豆包高质量声音有望复现-Seed-TTS

我们介绍了 Seed-TTS&#xff0c;这是一个大规模自回归文本转语音 &#xff08;TTS&#xff09; 模型系列&#xff0c;能够生成与人类语音几乎没有区别的语音。Seed-TTS 作为语音生成的基础模型&#xff0c;在语音上下文学习方面表现出色&#xff0c;在说话人的相似性和自然性方…

Vitis HLS 学习笔记--Stream Chain Matrix Multiplication

目录 1. 简介 2. 示例解析 2.1 示例功能说明 2.2 函数说明 2.2.1 mmult 函数 2.2.2 mm2s 函数 2.2.3 s2mm 函数 2.2.4 总示意图 3. 总结 1. 简介 这是一个包含使用数据流的级联矩阵乘法的内核。该内核启用了 ap_ctrl_chain&#xff0c;以展示如何重叠多个内核调用队…

防火墙规则来阻止攻击者的 IP 地址

1. iptables 要禁止服务器与特定 IP 地址的通信&#xff0c;可以使用防火墙来设置规则。在 Ubuntu 上&#xff0c;iptables 是一个常用的防火墙工具。以下是使用 iptables 设置禁止与特定 IP 通信的步骤&#xff1a; 阻止所有进出的通信 如果你想阻止服务器与特定 IP 地址的…

AES加解密工具类

文章目录 前言一、AES加解密工具类总结 前言 当涉及到数据的安全性和保密性时&#xff0c;加密是一种关键的技术手段。AES&#xff08;Advanced Encryption Standard&#xff09;是一种广泛使用的对称加密算法&#xff0c;被认为是目前最安全和最常用的加密算法之一。 一、AES…

2024年最好用的精简系统推荐!旧电脑福音!

精简版电脑系统经过精心优化&#xff0c;去除了冗余功能&#xff0c;保留了核心功能&#xff0c;让用户的操作更加便捷高效&#xff0c;同时也具备强大的兼容性和稳定性&#xff0c;整体操作体验感很好。但是&#xff0c;许多新手用户不知道在哪里才可以找到好用的精简版系统&a…

Mojo崛起:AI-first 的编程语言能否成为新流行?

眨眼之间&#xff0c;你可能会错过又一种编程语言的发明。 有个笑话说&#xff0c;程序员花费20%的时间编写代码&#xff0c;80%的时间决定使用什么语言。 事实上&#xff0c;编程语言如此之多&#xff0c;以至于我们不确定实际有多少种。据估计&#xff0c;至少有700种编程语…

【Android 11】AOSP Settings添加屏幕旋转按钮

前言 这里是客户要求添加按钮以实现屏幕旋转。屏幕旋转使用adb的命令很容易实现&#xff1a; #屏幕翻转 adb shell settings put system user_rotation 1 #屏幕正常模式 adb shell settings put system user_rotation 0这里的值可以是0&#xff0c;1&#xff0c;2&#xff0c…

中国天辰×蓝卓丨共创行业级工业操作系统,加速培育新质生产力!

6月17日&#xff0c;中国天辰工程有限公司&#xff08;以下简称“中国天辰”&#xff09;党委委员、总经理梁军湘一行莅临蓝卓&#xff0c;双方就工业互联网平台合作进行座谈交流。蓝卓总经理谭彰、副总经理蓝照斌、总经理助理俞益标&#xff0c;以及中控技术副总裁吴才宝、大客…

原生dom操作快速写入html渲染(insertAdjacentHTML)

// 旧方法 const btn document.createElement(div) btn.id material-btn-id btn.className material-btn btn.textContent 素材库 document.body.appendChild(btn) btn.addEventListener(click, () > {// 点击事件 }) // 新方法 const btn document.createElement(div…

软件开发小程序正规公司流程是什么样的?

正规软件开发的流程可以清晰地分为以下几个阶段&#xff0c;每个阶段都有其特定的目标和产出&#xff1a; 项目开发目的分析与确定&#xff1a; 此阶段主要是在软件开发商将开发项目确定下来之后&#xff0c;与需求方进行讨论&#xff0c;明确软件开发的目标及其具体需要实现…

NumPy 切片和索引

NumPy 切片和索引 NumPy 是 Python 中用于科学计算的核心库之一&#xff0c;它提供了一个强大的 N 维数组对象和许多用于操作这些数组的函数。在数据处理和数值计算中&#xff0c;切片和索引是常用的操作&#xff0c;它们允许我们有效地访问和修改数组的部分数据。本文将详细介…

调试的时候给打印加颜色

在 C 中&#xff0c;打印紫色文本通常涉及使用控制台的特定颜色输出。在大多数操作系统中&#xff0c;控制台颜色是通过特殊的转义序列来实现的。这些转义序列可以在输出文本之前插入&#xff0c;以改变文本的颜色、样式或其他属性。 使用 ANSI 转义序列 在 POSIX 兼容的系统&…

Altair 人工智能技术助力MABE预测消费者行为,实现设备性能优化

主要看点 行业&#xff1a; 家电行业 挑战&#xff1a; 企业面临的挑战是如何利用已收集的大量数据&#xff0c;深入了解消费者在产品使用过程中对某些保鲜程序的影响。 Altair 解决方案&#xff1a; Altair采用了Altair RapidMiner人工智能平台来解决问题&#xff0c;特别是…

QML Controls模块-标准对话框用法说明

文章目录 颜色对话框文件对话框字体对话框自定义对话框通知对话框在QML中,Qt提供了一个名为 QtQuick.Controls的模块,其中包含了一系列用户界面控件,可以用于创建现代化、响应式的用户界面。在QtQuick.Controls模块中,一些控件可以用来调用标准对话框,包括文件对话框、字体…

Java进阶示例

使用DataFrame和SQL查询处理数据 在Spark中&#xff0c;DataFrame是一种以结构化方式处理数据的强大工具&#xff0c;它允许用户以类似于SQL的方式操作数据&#xff0c;提供了比RDD更高的抽象层次和更好的性能。下面的示例将展示如何使用Spark SQL的DataFrame API来读取CSV数据…

docker安装消息队列mq中的rabbit服务

在现代化的分布式系统中&#xff0c;消息队列&#xff08;Message Queue, MQ&#xff09;已经成为了一种不可或缺的组件。RabbitMQ作为一款高性能、开源的消息队列软件&#xff0c;因其高可用性、可扩展性和易用性而广受欢迎。本文将详细介绍如何在Docker环境中安装RabbitMQ服务…

2024.6.18 刷题总结

2024.6.18 **每日一题** 2288.价格减免&#xff0c;这是一道纯字符串的题目&#xff0c;我们的目标是识别出字符串中的价格并将它替换为折扣后的数字。这道题利用了一些字符串的关键字&#xff1a; stringstream 是C标准库中的一个类&#xff0c;属于 <sstream> 头文件…