当训练任务结束,常常需要评价函数(Metrics)来评估模型的好坏。不同的训练任务往往需要不同的Metrics函数。例如,对于二分类问题,常用的评价指标有precision(准确率)、recall(召回率)等,而对于多分类任务,可使用宏平均(Macro)和微平均(Micro)来评估。
MindSpore提供了大部分常见任务的评价函数,如Accuracy、Precision、MAE和MSE等,由于MindSpore提供的评价函数无法满足所有任务的需求,很多情况下用户需要针对具体的任务自定义Metrics来评估训练的模型。
本章主要介绍如何自定义Metrics以及如何在mindspore.train.Model中使用Metrics。
自定义Metrics
自定义Metrics函数需要继承mindspore.train.Metric父类,并重新实现父类中的clear方法、update方法和eval方法。
- clear:初始化相关的内部参数。
- update:接收网络预测输出和标签,计算误差,每次step后并更新内部评估结果。
- eval:计算最终评估结果,在每次epoch结束后计算最终的评估结果。
平均绝对误差(MAE)算法如式(1)所示:
下面以简单的MAE算法为例,介绍clear、update和eval三个函数及其使用方法。
模型训练中使用Metrics
mindspore.train.Model是用于训练和评估的高层API,可以将自定义或MindSpore已有的Metrics作为参数传入,Model能够自动调用传入的Metrics进行评估。
在网络模型训练后,需要使用评价指标,来评估网络模型的训练效果,因此在演示具体代码之前首先简单拟定数据集,对数据集进行加载和定义一个简单的线性回归网络模型:
使用内置评价指标
使用MindSpore内置的Metrics作为参数传入Model时,Metrics可以定义为一个字典类型,字典的key值为字符串类型,字典的value值为MindSpore内置的评价指标,如下示例使用train.Accuracy计算分类的准确率。
使用自定义评价指标
如下示例在Model中传入上述自定义的评估指标MAE(),将验证数据集传入model.fit()接口边训练边验证。
验证结果为一个字典类型,验证结果的key值与metrics的key值相同,验证结果的value值为预测值与实际值的平均绝对误差。