1. Basic Math Operations
1.1 Square Roots and Powers
import torch

x = torch.tensor([4.0, 9.0, 16.0])
sqrt_x = torch.sqrt(x)      # element-wise square root
square_x = torch.square(x)  # element-wise square
pow_x = torch.pow(x, 3)     # element-wise cube
sqrt_x_alt = x ** 0.5       # operator form of sqrt
square_x_alt = x ** 2       # operator form of square
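One caveat worth a quick sketch: torch.sqrt silently returns NaN for negative inputs rather than raising an error, so clamping first (if zero is an acceptable substitute) keeps NaNs from propagating:

neg = torch.tensor([-1.0, 4.0])
torch.sqrt(neg)                      # tensor([nan, 2.]) -- no error is raised
torch.sqrt(torch.clamp(neg, min=0))  # tensor([0., 2.])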
1.2 Exponentials and Logarithms
exp_x = torch.exp(x)            # e^x
log_x = torch.log(x)            # natural log
log10_x = torch.log10(x)        # base-10 log
safe_log = torch.log(x + 1e-8)  # epsilon guards against log(0) = -inf
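For values near zero, torch.log1p(x) computes log(1 + x) more accurately than adding one explicitly, because the addition itself loses precision in floating point; a small illustration:

tiny = torch.tensor([1e-10])
torch.log1p(tiny)      # ~1e-10, accurate
torch.log(tiny + 1.0)  # 0.0: 1 + 1e-10 rounds to 1.0 in float32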
2. Statistical Operations
2.1 Sum and Mean
x = torch.randn(3, 4)
total = torch.sum(x)            # sum over all elements (scalar)
sum_dim0 = torch.sum(x, dim=0)  # column sums, shape (4,)
sum_dim1 = torch.sum(x, dim=1)  # row sums, shape (3,)
mean_val = torch.mean(x)
mean_dim0 = torch.mean(x, dim=0)
2.2 Extrema and Sorting
max_val = torch.max(x)  # global maximum (scalar)
min_val = torch.min(x)
max_vals, max_indices = torch.max(x, dim=1)  # per-row max and argmax
min_vals, min_indices = torch.min(x, dim=0)  # per-column min and argmin
sorted_vals, sorted_indices = torch.sort(x, dim=1, descending=True)
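When only the k largest entries are needed, torch.topk avoids sorting the full tensor; a quick sketch with k = 2 (an arbitrary choice here):

top_vals, top_idx = torch.topk(x, k=2, dim=1)  # two largest values per row, shape (3, 2)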
2.3 Variance and Standard Deviation
var_x = torch.var(x, unbiased=True)  # Bessel's correction (divide by N-1); newer releases spell this correction=1
var_dim0 = torch.var(x, dim=0)
std_x = torch.std(x)
std_dim1 = torch.std(x, dim=1)
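When both statistics are needed, torch.std_mean computes them in a single pass instead of two separate reductions:

std_all, mean_all = torch.std_mean(x)          # scalar std and mean
std_rows, mean_rows = torch.std_mean(x, dim=1) # per-row std and mean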
3. Matrix Operations
3.1 Basic Matrix Operations
A = torch.randn(3, 4)
B = torch.randn(4, 5)
matmul = torch.matmul(A, B)  # (3, 4) @ (4, 5) -> (3, 5)
matmul_alt = A @ B           # operator form
v1 = torch.randn(3)
v2 = torch.randn(3)
dot_product = torch.dot(v1, v2)  # 1-D dot product (scalar)
batch_A = torch.randn(5, 3, 4)
batch_B = torch.randn(5, 4, 5)
batch_matmul = torch.bmm(batch_A, batch_B)  # batched: (5, 3, 4) @ (5, 4, 5) -> (5, 3, 5)
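torch.matmul also broadcasts over leading batch dimensions, so the same multiply works without bmm when one operand has no batch axis:

single_B = torch.randn(4, 5)
broadcast_matmul = torch.matmul(batch_A, single_B)  # (5, 3, 4) @ (4, 5) -> (5, 3, 5)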
3.2 Matrix Decompositions
sym_matrix = torch.randn(3, 3)
sym_matrix = sym_matrix @ sym_matrix.T  # make it symmetric (positive semi-definite)
eigenvals, eigenvecs = torch.linalg.eigh(sym_matrix)  # eigh requires a symmetric/Hermitian input
U, S, Vh = torch.linalg.svd(A)  # note: the third factor is V^H (transposed), not V
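A quick self-check on the SVD: with full_matrices=False the factors have reduced shapes and multiply straight back to A, up to floating-point error:

U, S, Vh = torch.linalg.svd(A, full_matrices=False)  # U: (3, 3), S: (3,), Vh: (3, 4)
A_rec = U @ torch.diag(S) @ Vh
assert torch.allclose(A, A_rec, atol=1e-5)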
4. Comparison Operations
4.1 Element-wise Comparisons
a = torch.tensor([1, 2, 3])
b = torch.tensor([3, 2, 1])
eq = torch.eq(a, b)  # tensor([False, True, False])
gt = torch.gt(a, b)  # tensor([False, False, True])
lt = torch.lt(a, b)  # tensor([True, False, False])
eq_alt = a == b      # operator forms of the same comparisons
gt_alt = a > b
4.2 Reduction Comparisons
all_true = torch.all(eq)  # True only if every element is True
any_true = torch.any(gt)  # True if any element is True
tensors_equal = torch.equal(a, b)  # single Python bool: same shape and identical values
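For floating-point tensors, exact equality is fragile; torch.allclose compares within tolerances instead:

t1 = torch.tensor([1.0, 2.0])
t2 = t1 + 1e-7
torch.equal(t1, t2)     # False: bit-for-bit comparison
torch.allclose(t1, t2)  # True: within the default rtol/atol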
5. Reduction Operations
5.1 Common Reductions
x = torch.randn(2, 3)
sum_all = x.sum()         # scalar
sum_dim = x.sum(dim=1)    # shape (2,)
cumsum = x.cumsum(dim=0)  # running sum down each column
prod_all = x.prod()       # product of all elements
5.2 Advanced Reductions
weights = torch.softmax(torch.randn(3), dim=0)  # weights sum to 1
weighted_mean = torch.sum(x * weights, dim=1)   # weights broadcast across rows; shape (2,)
logsumexp = torch.logsumexp(x, dim=1)           # numerically stable log(sum(exp(x)))
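The stability matters once the inputs are large: the naive formula overflows to inf while logsumexp stays finite:

big = torch.tensor([[1000.0, 1000.0]])
torch.log(torch.exp(big).sum(dim=1))  # tensor([inf]): exp(1000) overflows in float32
torch.logsumexp(big, dim=1)           # tensor([1000.6931]) = 1000 + log(2)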
6. Engineering Practice Tips
6.1 Understand broadcasting: make sure operand shapes are compatible
a = torch.randn(3, 1)
b = torch.randn(1, 3)
c = a + b  # shapes (3, 1) and (1, 3) broadcast to (3, 3)
6.2 In-place operations: the _ suffix saves memory
x.sqrt_()  # modifies x in place (negative entries become NaN)
x.add_(1)  # x += 1 without allocating a new tensor
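A minimal sketch of the caveat: in-place ops can conflict with autograd, and PyTorch refuses them outright on leaf tensors that require gradients:

t = torch.tensor([2.0], requires_grad=True)
u = t * 3
u.add_(1)   # allowed: u is not a leaf
# t.add_(1) # RuntimeError: a leaf Variable that requires grad is used in an in-place operation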
6.3 Device consistency: keep operands on the same device
y = torch.randn(2, 3)
if torch.cuda.is_available():
    x = x.cuda()
    y = y.cuda()
z = x + y  # raises a device-mismatch error if x and y live on different devices
6.4 Gradient tracking: mind how operations affect the computation graph
x = torch.tensor(2.0, requires_grad=True)
y = x ** 2
y.backward()  # populates x.grad with dy/dx = 2x = 4.0
6.5 Numerical stability: use the stable implementations
x = torch.randn(2, 3)
unstable = torch.exp(x) / torch.exp(x).sum(dim=1, keepdim=True)  # overflows for large x
stable = torch.softmax(x, dim=1)  # shifts by the row max internally
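When the downstream computation needs log-probabilities (e.g. for an NLL loss), torch.log_softmax is likewise more stable than taking log(softmax(x)) in two steps:

log_probs = torch.log_softmax(x, dim=1)  # fused, stable log(softmax(x))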
7. Performance Optimization Tips
7.1 Vectorize: avoid Python loops
# Slow: element-by-element Python loop
result = torch.zeros_like(x)
for i in range(x.size(0)):
    result[i] = x[i] * 2
# Fast: one vectorized kernel
result = x * 2
7.2 Fuse operations: reduce intermediate results
y = torch.randn(2, 3)
z = torch.randn(2, 3)
# Two statements materialize a named intermediate tensor
temp = x + y
result = temp * z
# One expression keeps the computation in a single step and is easier to fuse
result = (x + y) * z
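For actual kernel fusion, a hedged sketch assuming PyTorch 2.x, where torch.compile can fuse the elementwise add and multiply into one kernel:

def fused(x, y, z):
    return (x + y) * z

compiled_fused = torch.compile(fused)  # PyTorch 2.x only
result = compiled_fused(x, y, z)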
7.3 Use built-in functions: leverage optimized implementations
custom_norm = torch.sqrt(torch.sum(x ** 2))  # hand-rolled L2 norm
optimized_norm = torch.norm(x)               # single optimized call
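The built-in also reduces along a chosen dimension, and newer code tends to prefer the torch.linalg spelling:

row_norms = torch.norm(x, dim=1)                    # L2 norm of each row
row_norms_alt = torch.linalg.vector_norm(x, dim=1)  # modern equivalent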