当我们训练好一个网络模型后，必不可少的一步就是对模型做前向推理，看看模型的实际性能如何。Python 无疑是最简便的环境，所以本文写一个 Python 版本的前向测试脚本。
import os
import cv2
import sys
import caffe
import glob
import argparse
from PIL import Image
import numpy as np


def parse_args(argv=None):
    """Parse the command-line options of the Caffe forward-test script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]`` exactly as before; passing a list
            lets the parser be driven programmatically.

    Returns:
        An ``argparse.Namespace`` holding the options defined below.
    """
    parser = argparse.ArgumentParser(description='deblur arguments')
    # Directory holding the dataset; image paths in the list file are
    # appended to it.
    parser.add_argument('--image_root_dir', type=str, default='/.../data',
                        help='test image root dir')
    # Path of the generated list file: one "<image path> <integer label>"
    # entry per line.
    parser.add_argument('--test_txt', type=str, default='/.../test.txt',
                        help='test txt path')
    parser.add_argument('--caffe_model', type=str,
                        default='/.../xxx.caffemodel',
                        help='caffemodel path')
    parser.add_argument('--deploy', type=str, default='/.../deploy.prototxt',
                        help='deploy path')
    parser.add_argument('--num_cls', type=int, default=3, help='class number')
    parser.add_argument('--input_size', type=int, default=96,
                        help='net input size')
    parser.add_argument('--save_dir', type=str, default='./results',
                        help='test result dir')
    parser.add_argument('--roc_name', type=str, default='roc.txt',
                        help='test roc name')
    parser.add_argument('--saveimg_flag', type=int, default=1,
                        help='if 0, do not save img, else save img')
    # Minimum top-1 probability for an image to be copied into the
    # per-(label, prediction) result folders.
    parser.add_argument('--save_thresh', type=float, default=0.6,
                        help='only save images whose top-1 probability is '
                             '>= this value')
    return parser.parse_args(argv)


def main():
    """Entry point: parse the command-line arguments and run the test."""
    args = parse_args()
    Test(args)


def _parse_list_line(line):
    """Split one list-file line into ``(image_name, label)``.

    The image path may itself contain spaces, so only the last
    space-separated token is taken as the (integer) label. Returns ``None``
    for blank lines.

    Raises:
        ValueError: if the line has no label token or the label is not an int.
    """
    parts = line.strip().rsplit(' ', 1)
    if parts == ['']:
        return None
    if len(parts) != 2:
        raise ValueError('malformed test list line: %r' % line)
    # int() on the whole token (not just its first character) so that
    # multi-digit labels such as "12" are read correctly.
    return parts[0], int(parts[1])


def _build_transformer(net):
    """Build the preprocessing Transformer for the net's 'data' input blob.

    It transposes HxWxC images to CxHxW, scales values by 255 and swaps
    channels with (2, 1, 0), i.e. RGB -> BGR for images loaded with
    ``caffe.io.load_image``.
    """
    transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
    transformer.set_transpose('data', (2, 0, 1))
    transformer.set_raw_scale('data', 255)
    # No mean subtraction is applied. If the model was trained with one, add
    # e.g. transformer.set_mean('data', np.array([104, 117, 123])) here.
    transformer.set_channel_swap('data', (2, 1, 0))
    return transformer


def _save_image(img_path, imgname, save_dir, label, pred, score):
    """Copy the image into ``<save_dir>/<label>-<pred>/``, prefixed by its score."""
    out_dir = os.path.join(save_dir, str(label) + "-" + str(pred))
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    smallname = imgname.split('/')[-1]
    im = Image.open(img_path).convert('RGB')
    im.save(os.path.join(out_dir, str(score) + smallname))


def Test(args):
    """Run a forward pass over every image listed in ``args.test_txt``.

    For each entry the image is loaded, resized to ``args.input_size`` and
    fed through the network. Class probabilities are read from the output
    blob named 'prob' (presumably a softmax layer; confirm against
    deploy.prototxt).

    Per image, it:
      * prints the probabilities and the (prediction, label) pair;
      * appends "<p_0> <p_1> ... <p_{K-1}> <label>" to
        ``<save_dir>/<roc_name>``;
      * if ``args.saveimg_flag`` != 0 and the top-1 probability is
        >= ``args.save_thresh`` (0.6 when absent), copies the image into
        ``<save_dir>/<label>-<pred>/``.

    Per-class wrong/right counts are printed at the end.

    Raises:
        ValueError: on a malformed list line, or a label outside
            ``[0, args.num_cls)``.
    """
    if not os.path.isdir(args.save_dir):
        os.makedirs(args.save_dir)
    # getattr keeps callers that build their own Namespace (without the new
    # option) working with the previous hard-coded threshold.
    save_thresh = getattr(args, 'save_thresh', 0.6)

    net = caffe.Net(args.deploy, args.caffe_model, caffe.TEST)
    transformer = _build_transformer(net)

    with open(args.test_txt, 'r') as list_file:
        lines = list_file.readlines()

    wrong = np.zeros(args.num_cls)
    right = np.zeros(args.num_cls)
    with open(os.path.join(args.save_dir, args.roc_name), 'w') as roc_file:
        for line in lines:
            parsed = _parse_list_line(line)
            if parsed is None:
                continue  # tolerate blank lines, e.g. a trailing newline
            imgname, label = parsed
            # Without this check a negative label would silently be counted
            # against the last class, and a too-large one would raise
            # IndexError deep in the loop.
            if not 0 <= label < args.num_cls:
                raise ValueError('label %d of %r is outside [0, %d)'
                                 % (label, imgname, args.num_cls))
            print(imgname, label)

            img_path = args.image_root_dir + '/' + imgname
            im = caffe.io.load_image(img_path)
            im = cv2.resize(im, (args.input_size, args.input_size))
            net.blobs['data'].data[...] = transformer.preprocess('data', im)
            net.forward()
            prob = net.blobs['prob'].data[0].flatten()
            print(prob)
            pred = int(prob.argmax())  # top-1 class index
            print(pred, label)

            if args.saveimg_flag != 0 and prob[pred] >= save_thresh:
                _save_image(img_path, imgname, args.save_dir,
                            label, pred, prob[pred])

            # One ROC line: every class probability, then the true label.
            roc_file.write(''.join('%s ' % p for p in prob) + '%s\n' % label)

            if pred != label:
                wrong[label] += 1
            else:
                right[label] += 1
    print(wrong, right)


if __name__ == '__main__':
    main()
    print("done")
生成的 roc.txt（每行依次为各类别的预测概率和该图片的真实标签）可用于在下一篇文章中绘制 ROC 曲线。