TensorFlow 函数和 PyTorch 函数之间的转换
TensorFlow | PyTorch |
---|---|
tf.reshape(input, shape) | input.reshape(shape) / input.view(shape) #注意,view要求张量在内存中连续 |
tf.expand_dims(input, dim) | input.unsqueeze(dim) / input.view() |
tf.squeeze(input, dim) | torch.squeeze(input, dim) / input.squeeze(dim) / input.view() |
tf.gather(input1, input2) | input1[input2] |
tf.tile(input, multiples) | input.repeat(multiples) #参数是每个维度的重复次数,不是目标形状 |
tf.boolean_mask(input, mask) | input[mask] #注意,mask是bool值,不是0,1的数值 |
tf.concat([input1, input2], axis) | torch.cat([input1, input2], dim) |
tf.matmul() | torch.matmul() |
tf.minimum(input, value) | torch.clamp(input, max=value) #value为标量时 / torch.minimum(input, other) #other为张量时 |
tf.equal(input1, input2) | torch.eq(input1, input2)/ input1 == input2 |
tf.logical_and(input1, input2) | input1 & input2 |
tf.logical_not(input) | ~input / torch.logical_not(input) |
tf.reduce_logsumexp(input, [dim]) | torch.logsumexp(input, dim=dim) |
tf.reduce_any(input, dim) | input.any(dim) |
tf.reduce_mean(input) | torch.mean(input) |
tf.reduce_sum(input) | input.sum() |
tf.transpose(input) / tf.transpose(input, perm) | input.t() #仅适用于不超过2维的张量 / input.permute(perm) #多维张量 |
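tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits) | torch.nn.functional.cross_entropy(logits, labels, reduction='none') / torch.nn.CrossEntropyLoss(reduction='none')(logits, labels) #注意,TF的labels是one-hot/概率分布;PyTorch通常传类别索引(可用labels.argmax(dim=-1)转换),1.10及以上版本也支持直接传概率分布 |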