Follow my CSDN blog: https://spike.blog.csdn.net/
Article URL: https://spike.blog.csdn.net/article/details/133673820
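When training a model with the LightningModule framework, errors caused by bad data can seriously hurt training stability, so they should be caught promptly with try-except. Concretely, when an error occurs, training_step returns None. on_before_zero_grad also needs exception handling so that it tolerates training_step returning None.
validation_step can be handled in the same way.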
The source code is as follows:
```python
import logging

import pytorch_lightning as pl

logger = logging.getLogger(__name__)


class MyObject(pl.LightningModule):
    def __init__(self, config, args):
        super().__init__()  # required by LightningModule
        # ...

    def training_step_wrapper(self, batch, batch_idx, log_interval=10):
        # train key process
        ...

    def training_step(self, batch, batch_idx, log_interval=10):
        """
        typically, each step costs 50 seconds
        Reference: https://github.com/Lightning-AI/lightning/pull/3566
        """
        try:
            res = self.training_step_wrapper(batch, batch_idx, log_interval)
            return res
        except Exception as e:
            logger.info(f"[CL] training_step, exception: {e}")
            return None  # skip this batch

    def on_before_zero_grad(self, *args, **kwargs):
        try:
            self.ema.update(self.model)
        except Exception as e:
            # tolerate training_step having returned None
            logger.info(f"[CL] on_before_zero_grad, exception: {e}")
            return

    def validation_step_wrapper(self, batch, batch_idx):
        # val key process
        ...

    def validation_step(self, batch, batch_idx):
        try:
            self.validation_step_wrapper(batch, batch_idx)
        except Exception as e:
            logger.info(f"[CL] validation_step, exception: {e}")
            return
```
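Because the same try/except/log/return-None pattern repeats in every step method, it can be factored into a decorator. The following is a minimal sketch under that idea: `skip_on_exception`, `skip_counter`, and `DummyModule` are hypothetical names, not Lightning APIs, and `DummyModule` only stands in for a LightningModule so the code runs without Lightning installed. The sketch also tallies skipped calls by exception type, which helps judge how many batches are being lost.

```python
import functools
import logging
from collections import Counter

logger = logging.getLogger(__name__)

# Hypothetical tally of skipped step calls, keyed by exception type name.
skip_counter = Counter()


def skip_on_exception(step_fn):
    """Wrap a step method: log any exception, count it, and return None instead."""

    @functools.wraps(step_fn)
    def wrapper(self, *args, **kwargs):
        try:
            return step_fn(self, *args, **kwargs)
        except Exception as e:
            skip_counter[type(e).__name__] += 1
            logger.info(f"[CL] {step_fn.__name__}, exception: {e}")
            return None

    return wrapper


class DummyModule:
    """Stand-in for a LightningModule (hypothetical, for illustration only)."""

    @skip_on_exception
    def training_step(self, batch, batch_idx):
        # An empty batch raises IndexError, mimicking a bad sample.
        return {"loss": batch[0] * 2}


m = DummyModule()
print(m.training_step([3], 0))  # {'loss': 6}
print(m.training_step([], 1))   # None
print(dict(skip_counter))       # {'IndexError': 1}
```

In a real LightningModule the decorator would sit directly on `training_step` or `validation_step`. One caveat: the Lightning documentation for `training_step` notes that returning None to skip a batch is only supported with automatic optimization and not for multi-GPU, TPU, IPU, or DeepSpeed setups; in distributed training, a rank that skips while the others do not may hang during gradient synchronization.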
Common errors include:
Array index out of bounds:
index 0 is out of bounds for dimension 0 with size 0
Missing dictionary key:
num_res = int(np_example["seq_length"])
KeyError: 'seq_length'
Empty input to a numerical computation:
V, _, W = torch.linalg.svd(C)
free() error:
free(): invalid next size (fast)
munmap_chunk()
Invalid pointer:
munmap_chunk(): invalid pointer