深入学习 torch.distributions

0. 引言

前几天分几篇博文精细地讲述了《von Mises-Fisher 分布》, 以及相应的 PyTorch 实现《von Mises-Fisher Distribution (代码解析)》, 其中以 Uniform 分布为例简要介绍了 torch.distributions 包的用法. 本以为已经可以了, 但这两天看到论文 The Power Spherical distribution 的代码, 又被其实现分布的方式所吸引.

Power Spherical 分布与 von Mises Fisher 分布类似, 只不过将后者概率密度函数中的指数函数换成了多项式函数: f p ( x ; μ , κ ) ∝ e x p ( κ μ ⊺ x ) ⇓ f p ( x ; μ , κ ) ∝ ( 1 + μ ⊺ x ) κ \begin{aligned} f_p(\bm{x}; \bm{\mu}, \kappa) &\propto exp(\kappa \bm{\mu}^\intercal \bm{x}) \\ &\Downarrow\\ f_p(\bm{x}; \bm{\mu}, \kappa) &\propto (1+\bm{\mu}^\intercal \bm{x})^\kappa \\ \end{aligned} fp(x;μ,κ)fp(x;μ,κ)exp(κμx)(1+μx)κ 采样框架基本一致, 且这么做可以使边缘 t t t 的线性变换 t + 1 2 ∼ B e t a ( p − 1 2 + κ , p − 1 2 ) \frac{t+1}{2} \sim Beta(\frac{p-1}{2}+\kappa, \frac{p-1}{2}) 2t+1Beta(2p1+κ,2p1), 从而避免了接受-拒绝采样过程.

当然, 按照之前的 VonMisesFisher 的写法, 这个 t 的采样大概是这样:

z = beta.sample(sample_shape)
t = 2 * z - 1

但现在我遇到了这种写法:

class MarginalTDistribution(tds.TransformedDistribution):arg_constraints = {'dim': constraints.positive_integer,'scale': constraints.positive,}has_rsample = Truedef __init__(self, dim, scale, validate_args=None):self.dim = dimself.scale = scalesuper().__init__(tds.Beta(  # 用 Beta 分布转换, z 服从 Beta(α+κ,β)(dim - 1) / 2 + scale, (dim - 1) / 2, validate_args=validate_args),transforms=tds.AffineTransform(loc=-1, scale=2),  # t=2z-1 是想要的边缘分布随机数)

然后就可以进行对 t t t 的采样了.

架构大概是这样的: 一个基本分布类 distributions.Beta 和一个转换 transforms.AffineTransform, 输入到 TransformedDistribution 的子类 MarginalTDistribution 中, 通过对一个 B e t a Beta Beta 的线性转换, 实现边缘分布 t t t.

图1: 上述代码的解构图. 浅蓝色代表抽象基类, 绿色代表实类; 虚线代表继承, 实线代表参数输入

我们可以看到其基本架构, 本文将详细解析其内部的具体细节, 包括:

1. Distribution

在之前的 <von Mises-Fisher Distribution (代码解析)> 中, 已经通过 Uniform 简单介绍了 Distribution 的用法. 它是实现各种分布的抽象基类. 本文将以解析源码的方式详细介绍.

1.1 参数验证 validate_args

打开源码, 首先映入眼帘的是关于参数验证的代码:

# true if Python was not started with an -O option. See also the assert statement.
_validate_args = __debug__@staticmethod
def set_default_validate_args(value: bool) -> None:"""设置 validation 是否开启.validation 通常是耗时的, 所以最好在模型 work 后关闭它."""if value not in [True, False]:raise ValueErrorDistribution._validate_args = value

Distribution 有一个类属性叫 _validate_args, 默认值是 __debug__(见附录1), 可以通过类静态方法 set_default_validate_args(value: bool) 来修改此值.

构造方法 __init__(...) 中的验证逻辑:

def __init__(self, ..., validate_args: Optional[bool]=None):...if validate_args is not None:self._validate_args = validate_args

也就是说, 你可以在创建 Distribution 实例的时候设置是否进行参数验证. 如果不设置, 则按照类的属性 Distribution._validate_args.

if self._validate_args:  # validate_args=False 就不用设置 arg_constraints 了try:  # 尝试获取字典 arg_constraintsarg_constraints = self.arg_constraintsexcept NotImplementedError:  # 如果没设置, 则设置为 {}, 抛出警告arg_constraints = {}warnings.warn(...)

如果需要验证参数, 那么首先要获取一个叫 arg_constraints 的参数验证字典, 它列出了需要验证哪些参数. 这个抽象类里面并没有给出, 需要用户继承该类时写在子类中. 以 Uniform 为例:

class Uniform(Distribution):...arg_constraints = {"low": constraints.dependent(is_discrete=False, event_dim=0),"high": constraints.dependent(is_discrete=False, event_dim=0),}...

至于 constraints.dependent 是啥, 后面会详细介绍. 值得注意的是, 如果你在创建实例时指定 validate_args=False, 那么所有关于参数验证的事就都不用管了.

for param, constraint in arg_constraints.items():if constraints.is_dependent(constraint):continue  # skip constraints that cannot be checkedif param not in self.__dict__ and isinstance(getattr(type(self), param), lazy_property):continue  # skip checking lazily-constructed argsvalue = getattr(self, param)  # 从当前对象获取参数 valuevalid = constraint.check(value)  # 检查参数值if not valid.all():  # 检查不通过raise ValueError(...)

这一段就是验证过程了, 包括:

  • skip constraints that cannot be checked, 由 constraints.is_dependent(constraint) 判断是否可验证;
  • skip checking lazily-constructed args, 即参数名不在 self.__dict__ 中, 并属于 lazy_property 的跳过;
  • 获得参数, 进行验证;

具体的验证细节将在后面介绍.

1.2 batch_shape & event_shape

除了 validate_args 参数, __init__(...) 方法中的另外两个参数就是:

def __init__(self,batch_shape: torch.Size = torch.Size(),event_shape: torch.Size = torch.Size(),
):self._batch_shape = batch_shapeself._event_shape = event_shape...

这两个参数是啥? 在这个抽象类中, 我们看不到太多信息, 甚至 Uniform 中也只有 batch_shape = self.low.size() 的信息, 大概意思同时进行着一批的均匀分布, 如 low = torch.tensor([0.0, 1.0]) 时, batch_shape = torch.Size([2]), 表示一个二元的均匀分布. 看 MultivariateNormal, 里面信息量较大:

batch_shape = torch.broadcast_shapes(covariance_matrix.shape[:-2],  # [:-2]是去掉了协方差矩阵的维度, 剩下的可能是 batch 的维度loc.shape[:-1]  # [:-1]是去掉了 envent 的维度, 剩下的可能是 batch 的维度
)  # broadcast_shapes 意思是进行了广播, 如果 matrix 的 batch_shape 是 [2,1], loc 的 batch_shape 是 [1,2], 那么整个的 batch_shape 是广播后的 [2,2]
self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1))  # 之后 covariance_matrix 都被 expand 了
...
event_shape = self.loc.shape[-1:]  # 看来就是样本的 shape

从这一段来看, batch_shape 是指创建的实例在进行多少个平行的基本分布, 而 event_shape 是指基本分布的事件(支撑点)维度. 如:

locs = torch.randn(2, 3)
matrixs = torch.randn(2, 3, 3)
covariance_matrixs = torch.bmm(matrixs, matrixs.transpose(1, 2))normal = distributions.MultivariateNormal(loc=locs, covariance_matrix=covariance_matrixs)
print(normal.batch_shape)  # 2
print(normal.event_shape)  # 3
print(normal.sample())
##### output #####
torch.Size([2])
torch.Size([3])
tensor([[ 1.8972, -0.3961, -0.1530],[-0.5018, -2.5110,  0.1293]])

batch 的意思还是那个 batch, 不过这里是指分布的 batch, 而不是数据的 batch. 采样时, 得到一批 samples, 对应每个分布.

还有一个 method 和这两个参数有关: expand, 因为它是一个抽象 method, 基类中并没有实现, 那就直接看 MultivariateNormal 中的:

def expand(self, batch_shape: torch.Size, _instance=None):"""Args:batch_shape (torch.Size): the desired expanded size._instance: new instance provided by subclasses that need to override `.expand`.Returns:New distribution instance with batch dimensions expanded to `batch_size`."""new = self._get_checked_instance(MultivariateNormal, _instance)batch_shape = torch.Size(batch_shape)loc_shape = batch_shape + self.event_shapecov_shape = batch_shape + self.event_shape + self.event_shapenew.loc = self.loc.expand(loc_shape)new._unbroadcasted_scale_tril = self._unbroadcasted_scale_trilif "covariance_matrix" in self.__dict__:new.covariance_matrix = self.covariance_matrix.expand(cov_shape)if "scale_tril" in self.__dict__:new.scale_tril = self.scale_tril.expand(cov_shape)if "precision_matrix" in self.__dict__:new.precision_matrix = self.precision_matrix.expand(cov_shape)super(MultivariateNormal, new).__init__(batch_shape, self.event_shape, validate_args=False)new._validate_args = self._validate_argsreturn new

这个 method 会创建一个新的 instance 或调用的时候用户提供, 并设置 batch_shape 为参数提供的形状, 然后把参数 expand 到新的 batch_shape. 用法:

mean = torch.randn(3)
matrix = torch.randn(3, 3)
covariance_matrix = torch.mm(matrix, matrix.t())mvn = MultivariateNormal(mean, covariance_matrix)
bmvn = mvn.expand(torch.Size([2]))print(bmvn.batch_shape)
print(bmvn.event_shape)
print(bmvn.sample())##### output #####
torch.Size([2])
torch.Size([3])
tensor([[-4.0891, -4.2424,  6.2574],[ 0.7656, -0.2199, -0.9836]])
1.3 一些属性

包括: m e a n mean mean, m o d e mode mode, s t d std std, v a r i a n c e variance variance, e n t r o p y entropy entropy 等基本属性, 都需要用户在子类中自己实现. 还有一些相关的函数:

  • cumulative density/mass function cdf(value);
  • inverse cumulative density/mass function icdf(value);
    这个函数非常有用, Inverse Transform Sampling 中用其进行采样. 从 U ( 0 , 1 ) U(0,1) U(0,1) 中采样一个 u u u, 然后令 x = F − 1 ( u ) x = F^{-1}(u) x=F1(u) 就是所求随机变量 X X X 的一个采样.
  • log of the probability density/mass function log_prob(value), 对数概率.

注意, 目前看到的只有 log_prob, 并没有 prob, 一些示例要么只算 log_prob, 要么计算后通过 exp(log_prob) 得到 prob.

2. constraints.Constraint

前面在1.1参数验证中已经遇到 constraints.dependent(is_discrete=False, event_dim=0)constraint.check(value), 但没有讲具体细节. 本节将详细剖析.

2.1 抽象基类 Constraint

先看源码:

class Constraint:"""一个 constraint 对象, 表示变量在某区域内有效, 即变量可优化的范围."""is_discrete = False  # Default to continuous.event_dim = 0  # Default to univariate.def check(self, value):"""结果的形状为"sample_shape + batch_shape", 指示 each event 值是否满足此限制."""raise NotImplementedError

这是抽象基类 Constraint, 比较简单, 只有两个类属性和一个 method check(value). is_discrete 表示待验证值是否为离散; 联想前面的 event_shape, 大概可以知道 event_dim 是指 len(event_shape).(不过目前看只是为了验证参数, 还能验证采样的 event?)

2.2 _Dependent() 不被验证

这个基类信息太少, 对我们理解前面的内容毫无用处, 还是直接观察一些子类吧. 从 dependent = _Dependent() 开始, 它是 constraints.py 中定义好的 placeholder(这个倒是可以学一学):

class _Dependent(Constraint):  # 看"_", 应该是不希望用户直接创建实例"""Placeholder for variables whose support depends on other variables.These variables obey no simple coordinate-wise constraints."""def __init__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):self._is_discrete = is_discreteself._event_dim = event_dimsuper().__init__()def __call__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):"""Support for syntax to customize static attributes::constraints.dependent(is_discrete=True, event_dim=1)"""if is_discrete is NotImplemented:  # 未提供就是默认is_discrete = self._is_discreteif event_dim is NotImplemented:event_dim = self._event_dimreturn _Dependent(is_discrete=is_discrete, event_dim=event_dim)def check(self, x):raise ValueError("Cannot determine validity of dependent constraint")

闹了半天, 我们并不能看到 constraints.dependent(is_discrete=False, event_dim=0) 有什么卵用, 只知道 “Cannot determine validity of dependent constraint”, 这也呼应了前面的:

if constraints.is_dependent(constraint):continue  # skip constraints that cannot be checked

也就是说, dependent 类型的限制是不会执行参数验证的. 那这个 _Dependent 到底有何用处? 先不管了.

2.3 _IndependentConstraint 重新解释 event_dim

我们看点复杂的, MultivariateNormal.arg_constraints:

arg_constraints = {"loc": constraints.real_vector,"covariance_matrix": constraints.positive_definite,"precision_matrix": constraints.positive_definite,"scale_tril": constraints.lower_cholesky,
}

这些都是 constraints.py 中定义好的实例, 对于大多情况, 这些预定义好的实例已经够用, 但如果需要, 你也可以自定义. 先看 real_vector:

independent = _IndependentConstraint
real_vector = independent(real, 1)
class _IndependentConstraint(Constraint):"""封装一个 constraint,  通过 aggregating over ``reinterpreted_batch_ndims``-many dims in :meth:`check`,an event is valid 当且仅当它依赖的所有 entries 是 valid 的."""def __init__(self, base_constraint, reinterpreted_batch_ndims):self.base_constraint = base_constraintself.reinterpreted_batch_ndims = reinterpreted_batch_ndimssuper().__init__()@propertydef event_dim(self):# real.event_dim 是 0, + real_vector(reinterpreted_batch_ndims=1) = 1return self.base_constraint.event_dim + self.reinterpreted_batch_ndimsdef check(self, value):result = self.base_constraint.check(value)  # 首先要符合 base.checkif result.dim() < self.reinterpreted_batch_ndims:# 给 batch 留够 dimexpected = self.base_constraint.event_dim + self.reinterpreted_batch_ndimsraise ValueError(f"Expected value.dim() >= {expected} but got {value.dim()}")result = result.reshape(  # 减掉 eventresult.shape[: result.dim() - self.reinterpreted_batch_ndims] + (-1,))result = result.all(-1)  # 减少一个 dimreturn result

意思很明了了, real_vector 是依赖于 real(base_constraint) 的, reinterpreted_batch_ndims=1 是说把原来 valuebatch_dim 重新解释, 分出 n 个给 event_dim: 加上 reinterpreted_batch_ndims, 比如

value = [[1, 2, 3],[4, 5, 6]]

本来 realevent_dim=0, 验证结果为(sample_shape + batch_shape = (2,2)):

value = [[True, True, True],[True, True, True]]

现在重新解释为 event_dim=1, 验证结果为:

result = result.reshape(  # 减掉 eventresult.shape[: result.dim() - self.reinterpreted_batch_ndims] + (-1,)  # (-1,) 表示新 event 内的所有 entries 展平
)
result = result.all(-1)  # 新 event 内的所有 entries 为 True, 则新 event 为 True
================>
value = [True, True]

3. Transform & _InverseTransform

上一节介绍了 constraints.Constraint, 明白了在构建 Distribution 实例时进行的参数验证, 以保证用户提供的参数符合要求. 但还留下了一个疑问: Constraint 中的 event_dim 是指 len(event_shape), 难道还能验证采样的 event? 再者, check(value) 返回值的形状是 sample_shape + batch_shape, 进一步说明它是会被用于采样结果检查的. 让我们看一看能否在 Transform 中找到答案.

Transform & _InverseTransform 是一对互逆的操作, 实现从一个分布到另一个分布的转换. 这很有用, 因为 distributions 包已经实现了很多常见分布和转换, 自由组合威力巨大. 本节将详细介绍它是如何实现对分布的转换的.
[注] 从 _InverseTransform_ 来看, 是不需要用户了解它的.

3.1 抽象类 Transform 的基本信息
class Transform:"""变换的抽象基类, 子类应该实现 one or both of `_call` or `_inverse`.如果 `bijective=True`, 则必须实现 `log_abs_det_jacobian`.Args:cache_size (int): If one, the latest single value is cached.Only 0 and 1 are supported."""bijective = False  # Transform 是否双射, 默认 Falsedomain: constraints.Constraint  # 有效输入范围codomain: constraints.Constraint  # 有效输出范围def __init__(self, cache_size=0):self._cache_size = cache_sizeself._inv = Noneif cache_size == 0:pass  # default behaviorelif cache_size == 1:self._cached_x_y = None, Noneelse:raise ValueError("cache_size must be 0 or 1")super().__init__()

果然, Transform 中有 Constraint 的, 分别是 domaincodomain, 用于其检查输入输出是否符合要求. 此外, 还有 bijectivecache_size 这两个信息, 等一下看后面怎么说.

3.2 AffineTransform

抽象类的基本信息不多, 还是要看一个简单的例子: AffineTransform, 线性变换.

class AffineTransform(Transform):bijective = Truedef __init__(self, loc, scale, event_dim=0, cache_size=0):super().__init__(cache_size=cache_size)self.loc = locself.scale = scaleself._event_dim = event_dim

线性变换是可逆的, 可以看到它的 bijective = True. 参数是 y = l o c + s c a l e × x y = loc + scale × x y=loc+scale×x 中的 locscale; event_dim 则是用于构建 domaincodomain:

@constraints.dependent_property(is_discrete=False)
def domain(self):if self.event_dim == 0:return constraints.realreturn constraints.independent(constraints.real, self.event_dim)@constraints.dependent_property(is_discrete=False)
def codomain(self):if self.event_dim == 0:return constraints.realreturn constraints.independent(constraints.real, self.event_dim)

即, domaincodomain 被限制为 event_dim 维向量, 默认是 0, 输入输出皆为标量.

变换过程
def _call(self, x):"""Method to compute forward transformation."""return self.loc + self.scale * xdef _inverse(self, y):"""Method to compute inverse transformation."""return (y - self.loc) / self.scale

由于是双射, 还要实现:

def log_abs_det_jacobian(self, x, y):shape = x.shapescale = self.scaleif isinstance(scale, numbers.Real):result = torch.full_like(x, math.log(abs(scale)))else:result = torch.abs(scale).log()if self.event_dim:result_size = result.size()[: -self.event_dim] + (-1,)result = result.view(result_size).sum(-1)shape = shape[: -self.event_dim]return result.expand(shape)

计算结果的形状调整为 x 中除 event_dim 以外的形状, 即 sample_shape + batch_shape. 至于为什么要这么做, 还需要看 TransformedDistribution 中具体的转换流程.

这里有个问题, 假设 event_dim=1, 输入的 x.shape=(2,3), 而 scale=2.0scale=torch.tensor(2.0)计算结果是不一致的:

====================== scale=2.0 ==========================
result = torch.full_like(x, math.log(abs(2.0)))
[[log(2), log(2), log(2)],[log(2), log(2), log(2)]]
result_size = (2,3)[: -1] + (-1,) = (2,3)
result = [3log(2), 3log(2)].expand([2]) = [3log(2), 3log(2)]
================== scale=tensor(2.0) =======================
result = torch.abs(scale).log() = log(2)
result_size = ()[: -1] + (-1,) = (-1,)
result = log(2).expand([2]) = [log(2), log(2)]

类似的, 只要 scaletensor, 并出现了计算广播, 就会出现这种情况. 不知道会不会造成计算错误, 看了后面的 TransformedDistribution 就能知道. 现在只能暂时不管了.

3.3 TransformedDistribution
3.3.1 基本信息
class TransformedDistribution(Distribution):"""Extension of the Distribution class, which applies a sequence of Transformsto a base distribution."""arg_constraints: Dict[str, constraints.Constraint] = {}def __init__(self, base_distribution, transforms, validate_args=None):>>> 单 transfrom 变成 [transfrom], 再检查是否符合 transforms: List[Transform] <<<

它是对 Distribution 的扩展, 对一个 base distribution 实施一连串的 Transforms:

X ~ BaseDistribution
Y = f(X) ~ TransformedDistribution(BaseDistribution, f)
log p(Y) = log p(X) + log |det (dX/dY)|

一个简单的例子:

# #################################
# Building a Logistic Distribution
# X ~ Uniform(0, 1)
# f = a + b * logit(X)
# Y ~ f(X) ~ Logistic(a, b)
# #################################
base_distribution = Uniform(0, 1)
transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
logistic = TransformedDistribution(base_distribution, transforms)

其中 l o g i t ( x ) = l o g x 1 − x logit(x) = log\frac{x}{1-x} logit(x)=log1xx s i g m o i d sigmoid sigmoid 函数的逆.

下面是 TransformedDistribution__init__(...) 内容(省略了开头将单 Transform 转换为列表以及检查类型的代码):

# >>> Reshape base_distribution according to transforms. >>>
# >>> 获取 base_distribution 的 batch_shape 和 event_shape 以及 event_dim >>>
base_shape = base_distribution.batch_shape + base_distribution.event_shape
base_event_dim = len(base_distribution.event_shape)  # 的基本 shape
# <<< 获取 base_distribution 的 batch_shape 和 event_shape 以及 event_dim <<<
# 将 transforms 组合成一个 transform
transform = ComposeTransform(self.transforms)
# 先正向传播 shape, 再反向传播 shape, 一来一回 shape 不一致, 说明途中发生了广播
# 具体例子可为: 线性转换中的 [1,2,3] * [[2],[3]], 输入向量输出矩阵(再反向也是矩阵)
forward_shape = transform.forward_shape(base_shape)
expanded_base_shape = transform.inverse_shape(forward_shape)
if base_shape != expanded_base_shape:  # 不一致说明发生了广播 (AffineTransform为例)base_batch_shape = expanded_base_shape[: len(expanded_base_shape) - base_event_dim]  # 干脆先把 base_distribution 给 expand 了# 如 base_shape = batch_shape + event_shape = (,) + (,3) = (3,)# expanded_base_shape = (2,3), 则 base_batch_shape = (2,)base_distribution = base_distribution.expand(base_batch_shape)  # 结果 base_shape = (2,3)
# transform.domain.event_dim 是指所有 transforms 中最大的 domain.event_dim (这个 domain.event_dim 可能就只是为了检查 dim 是否够用)
reinterpreted_batch_ndims = transform.domain.event_dim - base_event_dim
if reinterpreted_batch_ndims > 0:base_distribution = Independent(  # 但却实实在在地调整了 base_distribution 的 event_dimbase_distribution, reinterpreted_batch_ndims)  # 参考前面讲的 _IndependentConstraint
self.base_dist = base_distribution

这一部分的主旋律是 Reshape base_distribution according to transforms. 也就是说, self.base_dist 被赋予的是调整过的 base_distribution. 主要包括:

  • 调整 batch_shape, by base_distribution.expand(base_batch_shape), 前面讲过 expand;
  • 调整 event_shape, by Independent, 这个类似前面讲的 _IndependentConstraint, 只不过这里是对 Distribution 操作;

具体过程看注释. 所以, 使用这种方式建立新的 Distribution 时, 要同时注意 base_distributiontransformsevent_dim, 这对 log_prob 的计算有影响, 且 base_distributionevent_dim 可能被更改.

安排好 self.base_dist 后, 开始计算本 TransformedDistributionbatch_shapeevent_shape.

# Compute shapes.
transform_change_in_event_dim = (  # transform 导致的 event_dim 变化transform.codomain.event_dim - transform.domain.event_dim
)
event_dim = max(transform.codomain.event_dim,  # the transform is coupledbase_event_dim + transform_change_in_event_dim,  # the base dist is coupled
)
assert len(forward_shape) >= event_dim
cut = len(forward_shape) - event_dim  # forward_shape 劈开
batch_shape = forward_shape[:cut]
event_shape = forward_shape[cut:]
super().__init__(batch_shape, event_shape, validate_args=validate_args)
3.3.2 采样
def sample(self, sample_shape=torch.Size()):with torch.no_grad():x = self.base_dist.sample(sample_shape)for transform in self.transforms:x = transform(x)return xdef rsample(self, sample_shape=torch.Size()):x = self.base_dist.rsample(sample_shape)for transform in self.transforms:x = transform(x)return x
3.3.3 log_prob 需要 log_abs_det
def log_prob(self, value):if self._validate_args:  # 验证样本的就在此处了self._validate_sample(value)event_dim = len(self.event_shape)log_prob = 0.0y = valuefor transform in reversed(self.transforms):  # 倒着来x = transform.inv(y)  # 逆变换得到 x, 想计算 `log_prob`, 逆变换就得实现.event_dim += transform.domain.event_dim - transform.codomain.event_dimlog_prob = log_prob - _sum_rightmost(transform.log_abs_det_jacobian(x, y),event_dim - transform.domain.event_dim,)y = xlog_prob = log_prob + _sum_rightmost(self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape))return log_probdef _monotonize_cdf(self, value):"""This conditionally flips ``value -> 1-value`` to ensure :meth:`cdf` ismonotone increasing."""sign = 1for transform in self.transforms:sign = sign * transform.signif isinstance(sign, int) and sign == 1:return valuereturn sign * (value - 0.5) + 0.5def cdf(self, value):"""Computes the cumulative distribution function by inverting thetransform(s) and computing the score of the base distribution."""for transform in self.transforms[::-1]:value = transform.inv(value)if self._validate_args:self.base_dist._validate_sample(value)value = self.base_dist.cdf(value)value = self._monotonize_cdf(value)return valuedef icdf(self, value):"""Computes the inverse cumulative distribution function usingtransform(s) and computing the score of the base distribution."""value = self._monotonize_cdf(value)value = self.base_dist.icdf(value)for transform in self.transforms:value = transform(value)return value
@property
def inv(self):"""Returns the inverse :class:`Transform` of this transform.This should satisfy ``t.inv.inv is t``."""inv = Noneif self._inv is not None:inv = self._inv()if inv is None:inv = _InverseTransform(self)self._inv = weakref.ref(inv)return inv

附录

1. __debug__assert (来自 Kimi)

__debug__ 是一个内置变量,用于指示 Python 解释器是否处于调试模式。当 Python 以调试模式运行时,__debug__ 被设置为 True;否则,在优化模式下运行时,它被设置为 False

__debug__ 可以用于条件性地执行调试代码,例如:

if __debug__:print("Debug mode is on, performing extra checks...")# 这里可以放一些只在调试模式下运行的代码,比如详细的日志记录# 或者复杂的验证逻辑
else:print("Debug mode is off.")

在上面的例子中,如果命令行执行:

python -O myscript.py
##### output #####
Debug mode is off.
------------------------------------------------------
python myscript.py
##### output #####
Debug mode is on, performing extra checks...

assert 语句受 __debug__ 影响:

def calculate(a, b):# 这个 assert 在 __debug__ 为 True 时执行assert a > 0 and b > 0, "Both inputs must be positive."# 正常的函数逻辑return a * b# 在这里,assert 会检查输入是否为正数
result = calculate(5, 3)
print(result)# 如果我们改变条件使 assert 失败
# result = calculate(-1, 3)  # 这会触发 AssertionError,除非运行时 __debug__ 为 False

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/18932.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

microk8s 报错tls: failed to verify certificate: x509:

问题&#xff1a; ssh命令出现如下图所示 输入任何microk8s的容器命令几乎都是x509报错 kubectl get pods -ALL 原因&#xff1a; 证书过期 相关文档&#xff1a; MicroK8s - 服务和端口 Microk8S v1.24 - refresh-certs 似乎无法刷新证书 问题 #3241 规范/microk8s Git…

【Linux系统】--- Linux内核日志等级详解

在编程的艺术世界里&#xff0c;代码和灵感需要寻找到最佳的交融点&#xff0c;才能打造出令人为之惊叹的作品。而在这座秋知叶i博客的殿堂里&#xff0c;我们将共同追寻这种完美结合&#xff0c;为未来的世界留下属于我们的独特印记。 【Linux系统】--- Linux内核日志等级详解…

小白跟做江科大32单片机之LED流水灯

1.复制下面地址新建的工程&#xff0c;改名为3-2 LED流水灯 小白跟做江科大32单片机之LED闪烁-CSDN博客https://blog.csdn.net/weixin_58051657/article/details/139295351?csdn_share_tail%7B%22type%22%3A%22blog%22%2C%22rType%22%3A%22article%22%2C%22rId%22%3A%2213929…

Opencv图像处理技术(图像轮廓)

1图像轮廓概念&#xff1a; 图像轮廓是指图像中连续的像素边界&#xff0c;这些边界通常代表了图像中的物体或者物体的边缘。在数字图像处理中&#xff0c;轮廓是由相同像素值组成的曲线&#xff0c;它们连接相同的颜色或灰度值&#xff0c;并且具有连续性。轮廓可以用来描述和…

CAD石墨烯生成器 V1.0 渊鱼

插件介绍 CAD石墨烯生成器插件可用于在AutoCAD软件内参数化建立石墨烯几何模型。插件建立石墨烯的球棍模型&#xff0c;可控制模型的尺寸、碳原子环的尺寸、原子直径、化学键直径&#xff0c;并可控制模型的起伏形态。插件生成的实体模型可进行修改或绘图渲染&#xff0c;用于…

做视频号小店和达人对接的好,爆单少不了!

大家好&#xff0c;我是喷火龙。 目前&#xff0c;视频号是没有什么自然流量的&#xff0c;所以&#xff0c;想要出单、爆单的话&#xff0c;靠达人带货的方式才是最可靠的&#xff0c;靠达人带货是肯定要对接达人&#xff0c;并和达人沟通带货的。 下面给大家讲一讲应该怎么…

python Z-score标准化

python Z-score标准化 Zscore标准化sklearn库实现Z-score标准化手动实现Z-score标准化 Zscore标准化 Z-score标准化&#xff08;也称为标准差标准化&#xff09;是一种常见的数据标准化方法&#xff0c;它将数据集中的每个特征的值转换为一个新的尺度&#xff0c;使得转化后的…

三十五岁零基础转行成为AI大模型开发者怎么样呢?

以下从3个方面帮大家分析&#xff1a; 35岁转行会不会太晚&#xff1f;零基础学习AI大模型开发能不能学会&#xff1f;AI大模型开发行业前景如何&#xff0c;学完后能不能找到好工作&#xff1f; 一、35岁转行会不会太晚&#xff1f; 35岁正处于人生的黄金时期&#xff0c;拥…

今日选题.

诱导读者点开文章的9引真经&#xff08;二&#xff09; 标题重要么&#xff1f;新媒体、博客文通常在手机上阅读。首先所有的内容不同于纸媒&#xff0c;手机只展现标题&#xff0c;而内容都是折叠。其次读者能像看内容一样看4、5条或者7、8条标题&#xff08;区别于不同的主流…

代码助手之-百度Comate智能体验

简介 越来越多的厂商提供了智能代码助手&#xff0c;百度也不例外。Baidu Comate&#xff08;智能代码助手&#xff09;是基于文心大模型&#xff0c;Comate取自Coding Mate&#xff0c;寓意大家的AI编码伙伴。Comate融合了百度内部多年积累的编程现场大数据和外部开源代码和知…

如何顺利通过软考中级系统集成项目管理工程师?

中级资格的软考专业包括"信息系统"&#xff0c;属于软考的中级级别。熟悉软考的人都知道&#xff0c;软考分为初级、中级和高级三个级别&#xff0c;涵盖计算机软件、计算机网络、计算机应用技术、信息系统和信息服务五个专业&#xff0c;共设立了27个资格。本文将详…

全程曝光 计算机领域顶会投稿后会经历哪些关键环节?

会议之眼 快讯 亲爱的计算机领域大牛们&#xff0c;当你挥洒汗水&#xff0c;精心打磨一篇科研论文&#xff0c;终于怀着激动的心情投稿至顶会——&#xff08;如&#xff08;ACM MM 、ACL、AAAI&#xff09;时&#xff0c;你是否想知道接下来这篇论文会经历怎样的旅程&#x…

Mybatis进阶——动态SQL(1)

目录 一、 <if> 标签 二、<trim> 标签 三、<where> 标签 四、<set> 标签 五、<foreach> 标签 六、<include> 标签 动态SQL 是Mybatis的强大特性之一&#xff0c;能够完成不同条件下的不同SQL拼接&#xff0c;可以参考官方文档&#…

pyQt处理任务等待动画

写了一个显示Qt正在处理内容的等待动画&#xff0c;任务另开一个线程执行&#xff0c;执行完后自动关闭动画 from PyQt5 import QtCore, QtWidgets from PyQt5.QtWidgets import QApplication, QMessageBox, QDialog, QVBoxLayout from PyQt5.QtCore import pyqtSignal, QTime…

springboot 作为客户端接收服务端的 tcp 长连接数据,并实现自定义结束符,解决 粘包 半包 问题

博主最近的项目对接了部分硬件设备&#xff0c;其中有的设备只支持tcp长连接方式传输数据&#xff0c;博主项目系统平台作为客户端发起tcp请求到设备&#xff0c;设备接收到请求后作为服务端保持连接并持续发送数据到系统平台。 1.依赖引入 连接使用了netty&#xff0c;如果项…

CPU占用率很高,相应很慢排查思路

获取线程状态 通过top -c命令可以动态显示进程及其占用资源的排行榜 可以看到&#xff0c;CPU占用率100%的PID是80972&#xff0c;定位到该进程之后&#xff0c;我们再从线程的dump日志中去定位. 使用top -H -p 80972命令查找到该进程中消耗CPU最多的线程&#xff0c;从下面的…

Apose.Words 常用对象详解

系列文章目录 文章目录 系列文章目录前言一、基础对象1. moveToBookmark 前言 本文介绍 Apose.Words 的常用对象的含义及使用方法。 一、基础对象 1. moveToBookmark 将指针移动到书签位置。 moveToBookmark(String bookmarkName, boolean isStart, boolean isAfter) book…

国产可视化爬虫助力AI大模型训练:精准爬取汉语词典

大语言模型&#xff0c;可以生成流畅对话的会话聊天机器人、通畅起草文章的内容生成器。在炫酷技术的背后&#xff0c;数据、算力、算法&#xff0c;被视作生成式AI的三个核心要素。由此可见&#xff0c;高质量的训练数据对于AI算法的准确性至关重要。 如何获得高质量的训练数…

【方法】如何禁止查看压缩包里的内容?

使用压缩文件&#xff0c;可以让文件更方便存储和传输&#xff0c;那对于重要的文件&#xff0c;如何防止随意查看压缩包的内容呢&#xff1f;我们可以试试以下两个方法。 方法1&#xff1a; 最常见的便是给压缩包设置“打开密码”&#xff0c;这样只有通过密码才能查看文件内…

外汇天眼:PayPoint投资100万英镑,深化与Aperidata开放银行合作

PayPoint今日宣布对Aperidata Ltd进行100万英镑的投资&#xff0c;Aperidata是一家创新的消费者和商业信用报告及开放银行平台。 此交易将使PayPoint集团在两家公司之间现有的商业合作基础上更进一步&#xff0c;为包括政府、地方当局、慈善机构和住房协会在内的多个领域的客户…