reprod_log复现精度对比小工具
主要用于对比和记录模型复现过程中的各个步骤精度对齐情况
pip 安装
pip3 install reprod_log --force-reinstall
提供的类和方法
论文复现赛
在论文复现赛中,主要用到的类如下所示。
- ReprodLogger
- 功能:记录和保存复现过程中的中间变量,用于后续的diff排查
- 初始化参数:无
- 方法
- add(key, val)
- 功能:向logger中添加key-val pair
- 输入
- key (str) : PaddlePaddle中的key与参考代码中保存的key应该完全相同,否则会提示报错
- value (numpy.ndarray) : key对应的值
- 返回: None
- remove(key)
- 功能:移除logger中的关键字段key及其value
- 输入
- key (str) : 关键字段
- value (numpy.ndarray) : key对应的值
- 返回: None
- clear()
- 功能:清空logger中的关键字段key及其value
- 输入: None
- 返回: None
- save(path)
- 功能:将logger中的所有的key-value信息保存到文件中
- 输入:
- path (str): 路径
- 返回: None
- add(key, val)
- ReprodDiffHelper
- 功能:对
ReprodLogger
保存的日志文件进行解析,打印与记录diff - 初始化参数:无
- 方法
- load_info(path)
- 功能:加载
- 输入:
- path (str): 日志文件路径
- 返回: dict信息,key为str,value为numpy.ndarray
- compare_info(info1, info2)
- 功能:计算两个字典对于相同key的value的diff,具体计算方法为
diff = np.abs(info1[key] - info2[key])
- 输入:
- info1/info2 (dict): PaddlePaddle与参考代码保存的文件信息
- 返回: diff的dict信息
- 功能:计算两个字典对于相同key的value的diff,具体计算方法为
- report(diff_method="mean", diff_threshold=1e-6, path="./diff.txt")
- 功能:可视化diff,保存到文件或者到屏幕
- 参数
- diff_method (str): diff计算方法,包括
mean
、min
、max
、all
,默认为mean
- diff_threshold (float): 阈值,如果diff大于该阈值,则核验失败,默认为
1e-6
- path (str): 日志保存的路径,默认为
./diff.txt
- diff_method (str): diff计算方法,包括
- load_info(path)
- 功能:对