文章目录
- 前言
- 一、热力图主函数代码
- 1、正规N图热力图运行代码
- 2、重新迭代循环求解方法
- 二、中断重启继续推理或训练
- 1、封装含参主函数
- 2、终止重启进程管理方法
- 1、终止启动源码
- 2、源码解读
- 三、终止启动主函数源码解读
- 1、终止启动源码
- 2、源码解读
- 关键点解析
- 四、完整代码Demo
- 1、完整源码
- 2、总结实现步骤
- 3、关键点解析
- 4、运行结果
- 结论
前言
在进行机器学习或深度学习模型训练时,经常会遇到由于各种原因导致的异常情况,例如内存不足、数据预处理错误或其他不可预见的问题。这些问题可能导致整个训练过程中断,迫使开发者手动重启训练脚本。为了提高训练过程的健壮性和自动化程度,我们可以设计一种机制,在每次遇到异常时自动终止当前训练任务,并重新启动一个新的训练实例,继续下一次迭代。本文将介绍如何通过 Python 的 multiprocessing
模块实现这一目标,并提供一个高度概括的异常处理方法——异常安全重启机制
。当然,每个人动机不一样,我的动机:我对N个图循环求解热力图grad CAM时候,每次迭代都增加显存,尝试很多方法,最终采用投机取巧方式(本篇文章技术)完成N张图热力图图求解。
一、热力图主函数代码
1、正规N图热力图运行代码
如果正规方式应该是如下的方法来调用,然而每次只能成功运行2张图片,显存就炸了。我还是给出主函数,以便后续解说明异常代码,其源码如下:
if __name__ == '__main__':opt = parse_opt()cam_model = sam_yolov5_heatmap(opt)cam_model.heatmap_main(opt.source, opt.save_dir)
2、重新迭代循环求解方法
我想到,既然只能运行2张图片,那我就大不了用一个循环来解决该问题,于是就有了下面修改代码,然而该方法依然不能摆脱显存增加,我也觉得很神奇,估摸是py文件未停止,显存存了之前变量或泄露原因导致。我们先看修改源码:
if __name__ == '__main__':for i in range(21):opt = parse_opt()cam_model = yolov5_heatmap(opt)cam_model.heatmap_main(opt.source, opt.save_dir,N=i)
我也给出heatmap_main修改后的源码,如下:
def heatmap_main(self, img_path, save_path,N=None):self.build_dir(save_path)img_names_lst = np.sort(np.array([n for n in os.listdir(img_path) if n[-3:] in ['jpg','png','PNG']]))if N is not None:i_idx,j_idx=int(2*N+1),int(2*N+3)img_names_lst=img_names_lst[i_idx:j_idx]# print('i_idx-->j_idx:{}-->{}'.format(i_idx,j_idx))for img_name in tqdm(img_names_lst):# try:# self.computer_heatmap(f'{img_path}/{img_name}', f'{save_path}/{img_name}',rescale=True)# except:# print('未通过图片:',img_name)self.computer_heatmap(f'{img_path}/{img_name}', f'{save_path}/{img_name}',rescale=True)if N is not None:print('i_idx-->j_idx:{}-->{}'.format(i_idx,j_idx))
我以为能解决问题,结果意想不到,显存依然增加。
二、中断重启继续推理或训练
异常安全重启机制,是指在一个循环中运行多个独立的任务(如模型训练),当某个任务遇到异常时,能够安全地终止该任务而不影响主程序的执行,并且可以在主循环中继续尝试新的迭代。这种方法不仅提高了系统的稳定性,还减少了人工干预的需求。
1、封装含参主函数
首先得有个函数将for i in range(21): opt = parse_opt() cam_model = yolov5_heatmap(opt) cam_model.heatmap_main(opt.source, opt.save_dir,N=i)
我们这个函数封装,再通过multiprocessing方法来实现异常中断再重启。这里,我给出封装主函数方法,后续读者有需求可以按照我的模板自己进行封装。其代码如下:
# 假设 parse_opt, yolov5_heatmap 和 heatmap_main 是你已定义的函数/类
def run_model_with_param(opt, source, save_dir, N):try:cam_model = yolov5_heatmap(opt)cam_model.heatmap_main(source, save_dir, N=N)print(f"Model process {os.getpid()} completed successfully for iteration {N}.")except Exception as e:print(f"Exception caught in iteration {N}: {e}")sys.exit(1) # 异常终止模型进程
2、终止重启进程管理方法
1、终止启动源码
我们定义了一个上下文管理器 process_context
,它用于启动和管理一个子进程,并确保在退出上下文时正确清理资源。
@contextmanager
def process_context(target, args):p = multiprocessing.Process(target=target, args=args)p.start()try:yield pfinally:if p.is_alive():p.terminate()p.join()
2、源码解读
@contextmanager
def process_context(target, args):
- 装饰器
@contextmanager
:这是 Python 标准库contextlib
模块中的一个装饰器,用于简化上下文管理器的创建。使用这个装饰器后,函数可以像with
语句一样被使用。 - 函数定义
process_context
:该函数接受两个参数:target
: 这是要在子进程中执行的目标函数。args
: 这是一个元组,包含了传递给目标函数的参数。
p = multiprocessing.Process(target=target, args=args)
- 创建进程对象
p
:这里使用multiprocessing.Process
创建了一个新的进程实例。target
参数指定了要在新进程中运行的函数,而args
参数则是传递给该函数的参数列表。
p.start()
- 启动进程:调用
start()
方法来启动子进程。这将使目标函数在一个独立的进程中开始执行。
try:yield p
- 进入上下文管理器的主体部分:
yield
语句是上下文管理器的关键。当使用with process_context(...) as p:
语法时,yield
前面的代码会在进入with
语句块之前执行,而yield
后面的代码会在离开with
语句块之后执行。 - 返回进程对象
p
:yield p
将进程对象p
返回给with
语句,使得可以在with
语句块中访问和操作这个进程对象。
finally:if p.is_alive():p.terminate()p.join()
- 确保资源释放:
finally
块中的代码无论如何都会被执行,即使在try
块中发生了异常。这里的作用是确保子进程在退出上下文时被正确终止和清理。- 检查进程是否存活:
if p.is_alive()
检查进程是否仍然在运行。 - 终止进程:如果进程还在运行,则调用
terminate()
方法发送终止信号给进程。 - 等待进程结束:
p.join()
确保主程序会等待子进程完全终止后再继续执行。这一步很重要,因为它保证了所有资源都被正确释放。
- 检查进程是否存活:
通过 @contextmanager
和 process_context
的结合,我们可以方便地管理子进程的生命周期,确保它们在不再需要时被正确终止,从而避免潜在的资源泄漏问题。这种方法非常适合那些需要频繁启动和终止子进程的任务,如模型训练、批处理作业等。
三、终止启动主函数源码解读
1、终止启动源码
我们定义了一个名为 main_loop
的函数,该函数实现了模型训练的主循环逻辑。它使用了之前定义的 process_context
上下文管理器来确保每个迭代中的模型进程能够被安全地启动和终止,源码如下:
def main_loop():for i in range(21):print(f"Starting iteration {i}...")opt = parse_opt()source = opt.sourcesave_dir = opt.save_dirwith process_context(run_model_with_param, (opt, source, save_dir, i)) as model_process:model_process.join(timeout=6) # 设置适当的超时时间if model_process.exitcode != 0:print(f"Model process exited with an error on iteration {i}. Restarting...")else:print(f"Model process completed successfully for iteration {i}.")
下面是对这段代码的详细解读:
2、源码解读
def main_loop():
- 定义主循环函数:
main_loop
是整个程序的核心逻辑所在,负责管理和控制模型训练的多次迭代。
for i in range(21):
- 迭代循环:这里使用
for
循环进行 21 次迭代(从 0 到 20),每次迭代都会尝试启动一个新的模型训练进程,并传递一个唯一的参数i
给模型。
print(f"Starting iteration {i}...")
- 打印当前迭代信息:在每次迭代开始时,打印一条消息以标识当前是第几次迭代,便于跟踪和调试。
opt = parse_opt()
source = opt.source
save_dir = opt.save_dir
- 解析配置选项:调用
parse_opt()
函数获取命令行或其他来源的配置选项,并从中提取出source
和save_dir
参数,这些参数将用于初始化模型并指定数据源和保存目录。
with process_context(run_model_with_param, (opt, source, save_dir, i)) as model_process:
- 启动子进程:使用
process_context
上下文管理器启动一个新的子进程来运行run_model_with_param
函数。这个函数接收四个参数:opt
、source
、save_dir
和i
。with
语句确保即使在子进程中发生异常或错误,资源也会被正确清理。
model_process.join(timeout=6) # 设置适当的超时时间
- 等待子进程完成:调用
join()
方法等待子进程结束,同时设置了一个 6 秒的超时时间。这意味着如果子进程在这段时间内没有完成,主程序将继续执行而不等待其完成。你可以根据实际情况调整这个超时值。
if model_process.exitcode != 0:print(f"Model process exited with an error on iteration {i}. Restarting...")
else:print(f"Model process completed successfully for iteration {i}.")
- 检查子进程退出状态:
- 非零退出码:如果子进程以非零状态码退出(即发生了异常或错误),则打印一条错误消息,并继续下一次迭代。
- 零退出码:如果子进程成功完成(以零状态码退出),则打印一条成功消息。根据需要,可以在这里添加逻辑来决定是否继续下一次迭代或者提前终止循环。
关键点解析
- 迭代控制:通过
for
循环实现对多个模型训练任务的控制,确保每个任务都能独立启动和终止。 - 配置解析:每次迭代前都重新解析配置选项,确保使用最新的配置参数。
- 多进程管理:利用
multiprocessing.Process
和上下文管理器process_context
来管理子进程的生命周期,保证资源的安全释放。 - 超时处理:为每个子进程设置了合理的超时时间,防止某个任务长时间挂起影响整体进度。
- 异常处理:通过检查子进程的退出状态码,可以在遇到异常时及时做出响应,并决定是否重启新的迭代。
四、完整代码Demo
这样,我就解决了热力图显存问题,但这不是正规方法,但这个技术可以应用到其它方式中。
1、完整源码
完整源码如下:
import multiprocessing
import sys
import os
from contextlib import contextmanager# 假设 parse_opt, yolov5_heatmap 和 heatmap_main 是你已定义的函数/类
def run_model_with_param(opt, source, save_dir, N):try:cam_model = yolov5_heatmap(opt)cam_model.heatmap_main(source, save_dir, N=N)print(f"Model process {os.getpid()} completed successfully for iteration {N}.")except Exception as e:print(f"Exception caught in iteration {N}: {e}")sys.exit(1) # 异常终止模型进程@contextmanager
def process_context(target, args):p = multiprocessing.Process(target=target, args=args)p.start()try:yield pfinally:if p.is_alive():p.terminate()p.join()def main_loop():for i in range(21):print(f"Starting iteration {i}...")opt = parse_opt()source = opt.sourcesave_dir = opt.save_dirwith process_context(run_model_with_param, (opt, source, save_dir, i)) as model_process:model_process.join(timeout=6) # 设置适当的超时时间if model_process.exitcode != 0:print(f"Model process exited with an error on iteration {i}. Restarting...")else:print(f"Model process completed successfully for iteration {i}.")if __name__ == "__main__":main_loop()
2、总结实现步骤
- 封装任务逻辑:将每个任务的执行逻辑封装到一个函数中,该函数接收必要的参数并返回结果。
- 使用多进程:利用
multiprocessing.Process
来创建和管理子进程,确保每个任务都在独立的进程中运行。 - 异常捕获与处理:在子进程中捕获所有可能发生的异常,并在遇到异常时调用
sys.exit(1)
以非零状态码退出。 - 上下文管理器:定义一个上下文管理器来启动和监控子进程,确保即使在发生异常的情况下也能正确清理资源。
- 主循环控制:在外层循环中不断尝试启动新的子进程,直到满足特定条件为止。
3、关键点解析
- 封装任务逻辑:通过
run_model_with_param
函数封装了模型训练的具体逻辑,并允许传递额外的参数N
。 - 使用多进程:
multiprocessing.Process
创建了一个新的进程来运行模型,确保每个任务都是独立的。 - 异常捕获与处理:在
run_model_with_param
中捕获所有异常,并通过sys.exit(1)
终止进程。 - 上下文管理器:
process_context
确保了即使发生异常,子进程也会被正确终止。 - 主循环控制:
main_loop
函数实现了外层循环,持续尝试启动新的模型进程,直到完成所有迭代或达到其他终止条件。
4、运行结果
运行效果如下:
我解决问题效果:
结论
通过上述方法,我们构建了一个异常安全重启机制,它能够在遇到异常时自动终止当前任务并在主循环中继续下一次迭代。这种方法不仅提高了模型训练过程的稳定性和自动化水平,还减少了因意外错误而导致的训练中断问题。