Dropout is a widely used regularization technique for reducing overfitting in neural networks. The basic idea is to randomly set the outputs of a subset of neurons to zero during training, so that the effective network structure varies from step to step and the network is forced to learn more robust feature representations.
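To make the mask-and-scale mechanics concrete, here is a minimal sketch of "inverted" dropout in plain JAX (the variable names are illustrative, not part of Haiku): each element is kept with probability 1 - rate, and the survivors are scaled by 1 / (1 - rate) so the expected value of the output matches the input.

import jax
import jax.numpy as jnp

rate = 0.5
key = jax.random.PRNGKey(0)
x = jnp.ones((3, 3))
# Keep each element with probability 1 - rate; dropped entries become 0
mask = jax.random.bernoulli(key, 1.0 - rate, shape=x.shape)
# Rescale the survivors so the expected value of the input is preserved
print(x * mask / (1.0 - rate))  # kept entries are 2.0, dropped entries are 0.0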
Haiku ships with a built-in dropout function, used as follows:
haiku.dropout(rng, rate, x, broadcast_dims=())
Parameters
- rng (PRNGKey) – A JAX random key.
- rate (float) – Probability that each element of x is discarded. Must be a scalar in the range [0, 1).
- x (jax.Array) – The value to be dropped out.
- broadcast_dims (Sequence[int]) – Specifies dimensions that will share the same dropout mask.

Return type: jax.Array
Returns: x, but dropped out and scaled by 1 / (1 - rate).
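A minimal usage sketch of this API (the shapes and keys here are illustrative assumptions, not from the docs): inside a transformed function, hk.next_rng_key() supplies the PRNGKey that hk.dropout requires.

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    # hk.next_rng_key() draws a fresh key from the rng passed to apply()
    return hk.dropout(hk.next_rng_key(), 0.5, x)

transform = hk.transform(forward)
x = jnp.ones((4, 4))
params = transform.init(jax.random.PRNGKey(0), x)
print(transform.apply(params, jax.random.PRNGKey(1), x))  # survivors scaled by 2.0

# broadcast_dims shares one mask along the given axes; with broadcast_dims=(0,)
# every row of x is masked by the same (1, 4)-shaped mask:
# hk.dropout(jax.random.PRNGKey(2), 0.5, x, broadcast_dims=(0,))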
A custom dropout module can be written as follows:
import haiku as hk
import jax
import jax.numpy as jnp

### 1. Define the dropout module
class MyDropout(hk.Module):
    def __init__(self, rate=0.5, name=None):
        super().__init__(name=name)
        self.rate = rate

    def __call__(self, x, is_training=True):
        if is_training:
            # Use hk.next_rng_key() to draw a fresh random key
            key = hk.next_rng_key()
            # Keep each element with probability 1 - rate
            # (bernoulli takes the probability of True, i.e. of keeping)
            mask = jax.random.bernoulli(key, 1.0 - self.rate, shape=x.shape)
            # Scale the survivors so the expected value of the input is preserved
            return x * mask / (1.0 - self.rate)
        else:
            return x

### 2. Define the function that runs the module
def forward(x, is_training):
    dropout = MyDropout(name="my_dropout")
    return dropout(x, is_training=is_training)

### 3. Turn the function into an init/apply pair with hk.transform
transform = hk.transform(forward)

### 4. Initialize the module parameters
x_train = jnp.ones((10, 10))
rng = jax.random.PRNGKey(123)
params = transform.init(rng, x_train, is_training=True)

### 5. Apply the module to get the output
x_train_dropout = transform.apply(params, rng, x_train, is_training=True)
print("Training Mode:")
print(x_train_dropout)
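As a follow-up sketch (an assumption about usage, not in the original): in inference mode the module returns x unchanged and never calls hk.next_rng_key(), so rng can be None.

# Inference mode: dropout is a no-op, so no rng key is consumed
x_eval = transform.apply(params, None, x_train, is_training=False)
print("Eval Mode:")
print(x_eval)  # identical to x_train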
Reference:
https://dm-haiku.readthedocs.io/en/latest/api.html?highlight=dropout#dropout