Tensorflow代码转pytorch代码 函数的转换

tensoflow函数和pytorch函数之间的转换

tensorflowpytroch
tf.reshape(input, shape)input.view()
tf.expand_dims(input, dim)input.unsqueeze(dim) / input.view()
tf.squeeze(input, dim)torch.squeeze(dim)/ input.view()
tf.gather(input1, input2)input1[input2]
tf.tile(input, shape)input.repeat(shape)
tf.boolean_mask(input, mask)input[mask] #注意,mask是bool值,不是0,1的数值
tf.concat(input1, input2)torch.cat(input1, input2)
tf.matmul()torch.matmul()
tf.minium(input, min)torch.clamp(input, max=min)
tf.equal(input1, input2)torch.eq(input1, input2)/ input1 == input2
tf.logical_and(input1, input2)input1 & input2
tf.logical_not(input) ~input
tf.reduce_logsumexp(input, [dim])torch.logsumexp(input, dim=dim)
tf.reduce_any(input, dim)input.any(dim)
tf.reduce_mean(input)torch.mean(input)
tf.reduce_sum(input)input.sum()
tf.transpose(input)input.t()
tf.softmax_cross_entroy_with_logits(logits, labels)torch.nn.CrossEntropyLoss(logits, labels)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/535175.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

在服务器上远程使用tensorboard查看训练loss和准确率

本人使用的是vscode 很简单 from torch.utils.tensorboard import SummaryWriter writer SummaryWriter(./logs)writer.add_scalar(train_loss,loss.val(),iteration) # 名字,数据,迭代次数训练的过程中会产生一个./logs的文件夹,里面存放的…

python,pytorch:读取,保存,显示图片

文章目录一,Pytorch1. 直接保存Tensor2.Tensor 转CV2 保存二、python1. opencv2.matplotlib:3. PIL一,Pytorch 1. 直接保存Tensor #!/usr/bin/env python # _*_ coding:utf-8 _*_ import torch from torchvision import utils as vutilsdef save_image…

Python循环创建变量名

使用命名空间locals locals 中是当前程序段中的全部变量名是一个字典的形式 所以我们新增的话,直接和字典那样就行了 names locals() #获取当前程序段中的全体局部变量名 for i in np.arange(0,10):names[fname_{i}]i

pytorch:固定部分层参数,固定单个模型

文章目录固定部分层参数固定指定层的参数不同层设置不同的学习率固定部分层参数 class RESNET_attention(nn.Module):def __init__(self, model, pretrained):super(RESNET_attetnion, self).__init__()self.resnet model(pretrained) # 这个model被固定for p in self.parame…

图片拼接的几种方法

1. torch tensor 格式 from torchvision.utils import save_imageimg_train_list torch.cat([image_s,image_r,im_G[0],fake_A,mask_s[:, :, 0],mask_r[:, :, 0]])result_path self.save_path_dir/imageif not os.path.exists(result_path):os.makedirs(result_path)save_im…

Pytorch 各种模块:降低学习率,

1.训练过程中学习率衰减 if (self.e1) > (self.num_epochs - self.num_epochs_decay):g_lr - (self.g_lr / float(self.num_epochs_decay))d_lr - (self.d_lr / float(self.num_epochs_decay))self.update_lr(g_lr, d_lr)print(Decay learning rate to g_lr: {}, d_lr:{}..…

cudnn.deterministic = True 固定随机种子

随机数种子seed确定时,模型的训练结果将始终保持一致。 随机数种子seed确定时使用相同的网络结构,跑出来的效果完全不同,用的学习率,迭代次数,batch size 都是一样。 torch.backends.cudnn.deterministic是啥&#x…

torch.backends.cudnn.benchmark 加速训练

设置 torch.backends.cudnn.benchmarkTrue 将会让程序在开始时花费一点额外时间,为整个网络的每个卷积层搜索最适合它的卷积实现算法,进而实现网络的加速。适用场景是网络结构固定(不是动态变化的),网络的输入形状&…

各种损失损失函数的使用场景和使用方法:KL散度

KL 散度的使用场景 KL散度( Kullback–Leibler divergence),又称相对熵,是描述两个概率分布 P 和 Q 差异的一种方法 torch.nn.functional.kl_div(input, target, size_averageNone, reduceNone, reductionmean) torch.nn.KLDivLoss(input, target, si…

RNN,LSTM,GRU的理解

RNN x 为当前状态下数据的输入, h 表示接收到的上一个节点的输入。 y为当前节点状态下的输出,而h′h^\primeh′为传递到下一个节点的输出. LSTM #定义网络 lstm nn.LSTM(input_size20,hidden_size50,num_layers2) #输入变量 input_data Variable(tor…

常用的loss函数,以及在训练中的使用

文章目录KL 散度L2 loss做标准化处理CElossCTCLossAdaptiveAvgPool2dKL 散度 算KL散度的时候要注意前后顺序以及加log import torhch.nn as nn d_loss nn.KLDivLoss(reductionreduction_kd)(F.log_softmax(y / T, dim1),F.softmax(teacher_scores / T, dim1)) * T * T蒸馏lo…

Shell 在训练模型的时候自动保存训练文件和模型到指定文件夹

在进行深度学习训练的过程中,往往会跑很多实验,这就导致有的实验设置会忘记或者记混淆,我们最好把train test model 的代码都copy一遍到指定文件夹中,这样后面检查也方便。 用shell指令保存文件 #!/bin/sh GRUB_CMDLINE_LINUX&qu…

Pytorch:数据并行和模型并行,解决训练过程中内存分配不均衡的问题

文章目录数据并行单机多卡训练,即并行训练。并行训练又分为数据并行 (Data Parallelism) 和模型并行两种。 数据并行指的是,多张 GPU 使用相同的模型副本,但是使用不同的数据批进行训练。而模型并行指的是,多张GPU 分别训练模型的…

DataParallel 和 DistributedDataParallel 的区别和使用方法

1.DataParallel DataParallel更易于使用(只需简单包装单GPU模型)。 model nn.DataParallel(model)它使用一个进程来计算模型参数,然后在每个批处理期间将分发到每个GPU,然后每个GPU计算各自的梯度,然后汇总到GPU0中…

torch.cuda.is_available(),torch.cuda.device_count(),torch.cuda.get_device_name(0)

torch.cuda.is_available() cuda是否可用; torch.cuda.device_count() 返回gpu数量; torch.cuda.get_device_name(0) 返回gpu名字,设备索引默认从0开始; torch.cuda.current_device() 返回当前设备索引;

windows, 放方向键设置为vim格式,autohotkey-windows

安装 Autohotkey https://www.autohotkey.com/download/ 设置快捷键 随便找个目录,鼠标右键新建一个autohotkey的脚本。 映射一个键——上左下右 经常打字的人都知道,我们编辑文本时要上下左右移动光标,难免要将手移到方向键再移回来打字。对我这样的懒癌后期患者,这简直不能…

window设置快捷键左右方向键

autohotkey-windows快捷键设置神器 使用方法 地址

Hbase数据模型及Hbase Shell

目录 1 数据模型 1.1 相关名词概念 1.2 模型分析 2 Hbase Shell操作 2.1 命名空间 2.2 表操作 2.2.1 创建表 2.2.2 更改表结构 2.2.3 表的其他操作 2.3 数据操作 2.3.1 添加数据(put) 2.3.2 删除数据(delete) 2.3.3 获取数据(get|scan) 3 过滤器 3.1 比较运算符…

非关型数据库之Hbase

目录 1 Hbase简介 1.1 初识Hbase 1.2 Hbase的特性 2 HDFS专项模块 2.1 HDFS的基本架构 2.1.1 HDFS各组件的功能: 2.2 HFDFS多种机制 2.2.1 分块机制 2.2.2 副本机制 2.2.3 容错机制 2.2.4 读写机制 3 Hbase组件及其功能 3.1 客户端 3.2 Zookeeper 3.3 …

MongoDB Shell操作

目录 1 数据库操作 2 集合操作 3 文档操作 3.1 插入文档(insert|insertOne|insertMany) 3.2插入、删除的循环操作 3.2 删除文档(remove|deleteOne|deleteMany) 3.3 更新文档(update|save) 3.4 查询文档(find) 4 游标 5 索引 6 聚合 1 数据库操作 当新创建的数据库里…