tch-rs指南 - Tensor的基本操作

文章目录

    • 1 概述
    • 2 Tensor的基本操作
      • 2.1 Tensor的初始化
        • (1)通过数组创建
        • (2)通过默认方法创建
        • (3)通过其他的`tensor`创建
        • (4)通过`opencv::core::Mat`创建
      • 2.2 Tensor的属性
      • 2.3 Tensor的运算
        • (1)改变device
        • (2)获取值(indexing and slicing)
        • (3)合并tensors
        • (4)四则运算
    • 参考资料

1 概述

在使用rust进行torch模型部署时,不可避免地会用到tch-rs。但是tch-rs的文档太过简洁,和没有一样,网上的资料也少得可怜,很多操作需要我们自己去试。这些内容虽然简单,但是自己找起来很费时间。

这篇文章总结了如何使用tch-rs进行tensor的基本操作。讲述的内容参考了pytorch的tensor教程。

运行环境:

[dependencies]
tch = "0.7.0"
opencv = "0.63"

2 Tensor的基本操作

用到的库

use std::iter;use opencv::prelude::*;
use opencv::core::{Mat, Scalar};
use opencv::core::{CV_8UC3};
use tch::IndexOp;
use tch::{Device, Tensor};

2.1 Tensor的初始化

(1)通过数组创建

let t = Tensor::of_slice::<i32>(&[1, 2, 3, 4, 5]);
t.print();
// vector也是一样的
let v = vec![1,2,3];
let t = Tensor::of_slice::<i32>(&v);
t.print();
// 2d vector
let v = vec![[1.5,2.0,3.9,4.4], [3.1,4.3,5.1,6.9]];
let v:Vec<f32> = v.iter().flat_map(|array| array.iter()).cloned().collect();
let data = unsafe{std::slice::from_raw_parts(v.as_ptr() as *const u8, v.len() * std::mem::size_of::<f32>())
};
let t = Tensor::of_data_size(data, &[2,4], tch::Kind::Float);
t.print();

print的结果是

 12345
[ CPUIntType{5} ]123
[ CPUIntType{3} ]1.5000  2.0000  3.9000  4.40003.1000  4.3000  5.1000  6.9000
[ CPUFloatType{2,4} ]

(2)通过默认方法创建

let t = Tensor::randn(&[2, 3], (tch::Kind::Float, Device::Cpu));
t.print();
let t = Tensor::ones(&[2, 3], (tch::Kind::Float, Device::Cpu));
t.print();
let t = Tensor::zeros(&[2, 3], (tch::Kind::Float, Device::Cpu));
t.print();
let t = Tensor::arange_start(0, 2 * 3, (tch::Kind::Float, Device::Cpu)).view([2, 3]);
t.print();

print的结果是

 1.0522  0.6981  0.92360.2324 -1.1048 -2.5820
[ CPUFloatType{2,3} ]1  1  11  1  1
[ CPUFloatType{2,3} ]0  0  00  0  0
[ CPUFloatType{2,3} ]0  1  23  4  5
[ CPUFloatType{2,3} ]

(3)通过其他的tensor创建

let t = Tensor::randn(&[2, 3], (tch::Kind::Float, Device::Cpu));
let t = t.rand_like();
t.print();

print的结果是

 0.3376  0.1885  0.34150.5135  0.8321  0.4140
[ CPUFloatType{2,3} ]

(4)通过opencv::core::Mat创建

这可以用在opencv读取图像后,转为torch tensor。当然tch-rs本身也有各种读取图片的方式,可见tch::vision::image。这里介绍两种方法,一种通过tch::Tensor::f_of_blob,一种通过tch::Tensor::of_data_size

// 创建一个(row, col, channel)=(2, 3, 3)=(height, width, channel)的Mat
let mat = Mat::new_rows_cols_with_default(2, 3, CV_8UC3, Scalar::from((3.0, 2.0, 1.0))
).unwrap();
// 获取mat的size,这里的结果是[2, 3, 3]
let size: Vec<_> = mat.mat_size().iter().cloned().map(|dim| dim as i64).chain(iter::once(mat.channels() as i64)).collect();
// 获取每个dimension的stride,这里的结果是[9, 3, 1]
let strides = {let mut strides: Vec<_> = size.iter().rev().cloned().scan(1, |prev, dim| {let stride = *prev;*prev *= dim;Some(stride)}).collect();strides.reverse();strides
};
// 构建tensor
let t = unsafe {let ptr = mat.ptr(0).unwrap() as *const u8;tch::Tensor::f_of_blob(ptr, &size, &strides, tch::Kind::Uint8, tch::Device::Cpu).unwrap()
};
t.print();

print的结果是

(1,.,.) = 3  2  13  2  13  2  1(2,.,.) = 3  2  13  2  13  2  1
[ CPUByteType{2,3,3} ]

还有一种比较简洁的转换方法

let mut mat = Mat::new_rows_cols_with_default(2, 3, CV_8UC3, Scalar::from((3.0, 2.0, 1.0))
).unwrap();
let h = mat.size().unwrap().height;
let w = mat.size().unwrap().width;   
let data = mat.data_bytes_mut().unwrap(); 
let t = tch::Tensor::of_data_size(data, &[h as i64, w as i64, 3], tch::Kind::Uint8);
t.print();

print的结果也是

(1,.,.) = 3  2  13  2  13  2  1(2,.,.) = 3  2  13  2  13  2  1
[ CPUByteType{2,3,3} ]
test tensor_ops::init_ops ... ok

2.2 Tensor的属性

用tch::Tensor的print()方法可打印出数据的所有属性,但是想要获取到这些属性,需要用其他的方法。

let t = Tensor::randn(&[2, 3], (tch::Kind::Float, Device::Cpu));
println!("size of the tensor: {:?}", t.size());
println!("kind of the tensor: {:?}", t.kind());
println!("device on which the tensor is located: {:?}", t.device());

打印的结果是

size of the tensor: [2, 3]
kind of the tensor: Float
device on which the tensor is located: Cpu

2.3 Tensor的运算

(1)改变device

.to().to_device()这两个方法都可以。

let mut t = Tensor::randn(&[2, 3], (tch::Kind::Float, Device::Cpu));
if tch::Cuda::is_available(){t = t.to(Device::Cuda(0));println!("change device to {:?}", t.device());
}
t = t.to_device(Device::Cpu);
println!("change device to {:?}", t.device());

如果是有cuda,且安装了cuda版本的tch-rs的话,就会打印出

change device to Cuda(0)
change device to Cpu

(2)获取值(indexing and slicing)

这个在tch-rs的例子中有很多,详见tests/tensor_indexing.rs。这里列几种常用的。

通过.i()进行索引

let tensor = Tensor::arange_start(0, 2 * 3, (tch::Kind::Float, Device::Cpu)).view([2, 3]);
println!("original tensor:");
tensor.print();
println!("tensor.i(0):");
tensor.i(0).print();
println!("tensor.i((1, 1)):");
tensor.i((1, 1)).print();
println!("tensor.i((.., 2)):");
tensor.i((.., 2)).print();
println!("tensor.i((.., -1)):");
tensor.i((.., -1)).print();
println!("tensor.i((.., [2, 0])):");
let index: &[_] = &[2, 0];
tensor.i((.., index)).print();

打印的结果是

original tensor:0  1  23  4  5
[ CPUFloatType{2,3} ]
tensor.i(0):012
[ CPUFloatType{3} ]
tensor.i((1, 1)):
4
[ CPUFloatType{} ]
tensor.i((.., 2)):25
[ CPUFloatType{2} ]
tensor.i((.., -1)):25
[ CPUFloatType{2} ]
tensor.i((.., [2, 0])):2  05  3
[ CPUFloatType{2,2} ]

通过.index()进行索引

let tensor = Tensor::arange(6, (tch::Kind::Int64, Device::Cpu)).view((2, 3));
println!("original tensor:");
tensor.print();
let rows_select = Tensor::of_slice(&[0i64, 1, 0]);
let column_select = Tensor::of_slice(&[1i64, 2, 2]);
let selected = tensor.index(&[Some(rows_select), Some(column_select)]);
println!("selecte by row and column:");
selected.print();

打印的结果是

original tensor:0  1  23  4  5
[ CPULongType{2,3} ]
selecte by row and column:152
[ CPULongType{3} ]

(3)合并tensors

Tensor::f_cat不会生成新的axis,而Tensor::stack会生成新的axis。

let t1 = Tensor::arange(6, (tch::Kind::Int64, Device::Cpu)).view((2, 3));
let t2 = Tensor::arange_start(6, 12, (tch::Kind::Int64, Device::Cpu)).view((2, 3));
let tensor = Tensor::f_cat(&[t1.copy(), t2.copy()], 1).unwrap();
println!("using Tensor::f_cat");
tensor.print();
let tensor = Tensor::stack(&[t1.copy(), t2.copy()], 1);
println!("using Tensor::stack");
tensor.print();

打印的结果是

using Tensor::f_cat0   1   2   6   7   83   4   5   9  10  11
[ CPULongType{2,6} ]
using Tensor::stack
(1,.,.) = 0  1  26  7  8(2,.,.) = 3   4   59  10  11
[ CPULongType{2,2,3} ]

(4)四则运算

tch-rs对[+, -, *, /]都进行了重载,可以实现和标量的直接运算。涉及到dim的复杂运算可以用tensor来处理。下面以加法为例,其他与f_add对应的分别是f_subf_mulf_div

let tensor = Tensor::ones(&[2, 4, 3], (tch::Kind::Float, Device::Cpu));
tensor.print();
// add with scalar
let add_tensor = &tensor + 0.5;
add_tensor.print();
// add with tensor
let add_tensor = Tensor::of_slice::<f32>(&[1.0,2.0,3.0]).view((1,1,3));
let add_tensor = &tensor.f_add(&add_tensor).unwrap();
add_tensor.print();

打印的结果为

original tensor:
(1,.,.) = 1  1  11  1  11  1  11  1  1(2,.,.) = 1  1  11  1  11  1  11  1  1
[ CPUFloatType{2,4,3} ]
add with scalar:
(1,.,.) = 1.5000  1.5000  1.50001.5000  1.5000  1.50001.5000  1.5000  1.50001.5000  1.5000  1.5000(2,.,.) = 1.5000  1.5000  1.50001.5000  1.5000  1.50001.5000  1.5000  1.50001.5000  1.5000  1.5000
[ CPUFloatType{2,4,3} ]
add with tensor:
(1,.,.) = 2  3  42  3  42  3  42  3  4(2,.,.) = 2  3  42  3  42  3  42  3  4
[ CPUFloatType{2,4,3} ]

参考资料

[1] https://github.com/LaurentMazare/tch-rs
[2] https://pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html#

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/470507.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

命令行运行jmeter脚本

1、通过gui界面的jmeter创建一份脚本&#xff1b;2、打开cmd,切换到jmeter程序的Bin目录&#xff1b;3、执行jmeter.bat -n -t bookair_0613.jmx -l log_3.jtl&#xff1b;4、使用gui界面添加一个监听器&#xff0c;打开log_3.jtl文件&#xff0c;来分析测试结果。转载于:https…

bootstrap table 分页_Java入门007~springboot+freemarker+bootstrap快速实现分页功能

本节是建立在上节的基础上&#xff0c;上一节给大家讲了管理后台表格如何展示数据&#xff0c;但是当我们的数据比较多的时候我们就需要做分页处理了。这一节给大家讲解如何实现表格数据的分页显示。准备工作1&#xff0c;项目要引入freemarker和bootstrap&#xff0c;如果不知…

Rust小技巧 - 通过FFI编程运行tensorrt模型

文章目录1 概述2 使用说明2.1 配置说明2.2 修改c头文件2.3 编写build.rs2.4 测试参考资料1 概述 shouxieai/tensorRT_Pro是一个文档完善&#xff0c;效果也很不错的tensorrt库&#xff0c;里面有对yolov5&#xff0c;yolox&#xff0c;unet&#xff0c;bert&#xff0c;retina…

1+X web中级 Laravel学习笔记——查询构造器简介及新增、更新、删除、查询数据

一、新增数据 插入多条数据&#xff1a; 二、更新数据 更新某条数据&#xff1a; 自增某字段的值&#xff1a; 自减某字段的值&#xff1a; 自增的同时改变其他字段的值&#xff1a; 三、删除数据 四、查询 查面构造器查面数据 有以下几种方法 get&#xff08;&…

【HTML5】Canvas画布

什么是 Canvas&#xff1f; HTML5 的 canvas 元素使用 JavaScript 在网页上绘制图像。 画布是一个矩形区域&#xff0c;您可以控制其每一像素。 canvas 拥有多种绘制路径、矩形、圆形、字符以及添加图像的方法。 * 添加 canvas 元素。规定元素的 id、宽度和高度&#xff1a; &l…

SynthText流程解读 - 不看代码不知道的那些事

文章目录1 概述2 流程解读2.1 生成文字mask2.2 plane2xyz的bug2.3 文字上色2.4 图像融合参考资料1 概述 SynthText是OCR领域生成数据集非常经典&#xff0c;且至今看来无人超越的方法。整体可以分为三个大的步骤&#xff0c;分别是生成文字的mask&#xff0c;这里用到了图像的…

python if name main 的作用_Python中if __name__ == '__main__':的作用和原理

if __name__ __main__:的作用 一个python文件通常有两种使用方法&#xff0c;第一是作为脚本直接执行&#xff0c;第二是 import 到其他的 python 脚本中被调用&#xff08;模块重用&#xff09;执行。因此 if __name__ main: 的作用就是控制这两种情况执行代码的过程&#x…

1+X web中级 Laravel学习笔记——Eloquent ORM查询、更新、删除、新增

Eloquent ORM简介 larave1所自带的Eloquent oRM是一个非常优美简洁的ActiveRecord实现&#xff0c;用来实现数据库的操作他的每个数据的表都有对应的模型&#xff08;model&#xff09;用于数据表的交互模型的建立 一、Eloquent ORM的查询 二、Eloquent ORM新增 通过模型新增…

使用复合设计模式扩展持久化的CURD,Select能力

大家可能会经常遇到接口需要经常增加新的方法和实现&#xff0c;可是我们原则上是不建议平凡的增加修改删除接口方法&#xff0c;熟不知这样使用接口是不是正确的接口用法&#xff0c;比如我见到很多的项目分层都是IDAL&#xff0c;DAL&#xff0c;IBLL&#xff0c;BLL&#xf…

python脚本加密_教你如何基于python实现脚本加密

这篇文章主要介绍了如何基于python实现 脚本加密,文中通过示例代码介绍的非常详细&#xff0c;对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下 from pathlib import Path import python_minifier import compileall import sys def get_save_path(from_dir,…

1+X web中级 Laravel学习笔记——blade模版

一、blade模版简介 Blade是larave1提供的一个既简单又强大的模版引擎和其他的流行的php模版引擎不一样&#xff0c;blade并不限制你在视图&#xff08;view&#xff09;中使用原生php代码 二、 模版继承 section yield extents parent 三、基本语法以及include的使用 1…

matplotlib 散点图_matplotlib画图 绘制散点图案例

假设通过爬虫你获取到了北京2016年3,10月份 每天白天的最高气温(分别位于列表a,b), 那么此时如何寻找出气温和随时间(天)变化的某种规律? a [11,17,16,11,12,11,12,6,6,7,8,9,12,15,14,17,18,21,16,17,20,14,15,15,15,19,21,22,22,22,23] b [26,26,28,19,21,17,16,19,18,20,…

深度学习基础-1

文章目录0 前言1 图像分类简介1.1 什么是图像分类1.2 图像分类任务的难点1.3 分类任务的评价指标1.3.1 Accuracy1.3.2 Precision和Recall1.3.3 F1 Score1.4 分类图像模型总体框架2 线性分类器2.1 图像的表示方法2.2 Cifar10数据集介绍2.3 分类算法输入2.4 线性分类器3 损失函数…

一、PHP基础——表单传值、上传文件

表单传值 概念: 表单传值即浏览器通过表单元素将用户的选择或者输入的数据提交给后台服务器语言。 为什么使用表单传值? 动态网站&#xff08;Web2.0&#xff09;的特点就是后台根据用户的需求定制数据&#xff0c;所谓的“需求”就是用户通过当前的选择或者输入的数据信息&a…

python dataframe 列_python pandas库中DataFrame对行和列的操作实例讲解

用pandas中的DataFrame时选取行或列&#xff1a; import numpy as np import pandas as pd from pandas import Sereis, DataFrame ser Series(np.arange(3.)) data DataFrame(np.arange(16).reshape(4,4),indexlist(abcd),columnslist(wxyz)) data[w] #选择表格中的w列&…

利用微信搜索抓取公众号文章(转载)

来源&#xff1a;http://www.shareditor.com/blogshow/44 自动收集我关注的微信公众号文章 2016.7.14 更新 搜狐微信增加对referer验证 var page require(webpage).create();page.customHeaders{"referer":"http://weixin.sogou.com/weixin?oq&query关键词…

二、PHP基础——连接msql数据库进行增删改查操作 实战:新闻管理项目

Mysql扩展 PHP针对MySQL数据库操作提供的扩展&#xff1a;允许PHP当做MySQL的一个客户端连接服务器进行操作。 连库基本操作 连接数据库服务器 1&#xff09;资源 mysql_connect(服务器地址&#xff0c;用户名&#xff0c;密码) 连接资源默认也是超全局的&#xff0c;任何地方都…

深度学习基础-2

文章目录0 前言1 全连接神经网络2 激活函数2.1 Sigmoid2.2 Tanh2.3 ReLU2.4 Leaky ReLU3 交叉熵损失4 计算图与反向传播4.1 计算图4.2 梯度消失与梯度爆炸4.3 动量法5 权重初始化5.1 全零初始化5.2 标准随机初始化5.3 Xavier初始化5.4 Kaming初始化6 批归一化7 参考资料0 前言 …

三、PHP基础——HTTP协议 文件编程

一、HTTP协议初步认识 HTTP协议概念 HTTP协议&#xff0c;即超文本传输协议(Hypertext transfer protocol)。是一种详细规定了浏览器和万维网(WWW World Wide Web)服务器之间互相通信的规则&#xff0c;通过因特网传送万维网文档的数据传送协议。 HTTP协议是用于从WWW服务器传…