http://alanse7en.github.io/caffedai-ma-jie-xi-4/
三.
从一个比较宏观的层面上去了解caffe怎么去完成一些初始化的工作和使用Solver
的接口函数,本文将主要分为四部分的内容:
- Google Flags的使用
- Register Brew Function的宏的定义和使用
train()
函数的具体实现SolverParameter
的具体解析过程
Google Flags的使用
从Caffe官网中可以看到,caffe的Command Line Interfaces一共提供了四个功能:train/test/time/device_query,而Interfaces的输入除了这四种功能还可以输入诸如-solver/-weights/-snapshot/-gpu等参数。这些参数的解析是通过Google Flags这个工具来完成的。
在caffe.cpp(位于/CAFFE_ROOT/tools/caffe.cpp)的开头,我们可以看到很多这样的宏:
DEFINE_string(gpu, "","Optional; run in GPU mode on given device IDs separated by ','.""Use '-gpu all' to run on all available GPUs. The effective training ""batch size is multiplied by the number of devices.");
这个宏的使用方式为DEFINE_xxx(name, default_value, instruction);
,这样就定义了一个xxx类型名为FLAGS_name的标志,如果用户没有在Command Line中提供其值,那么会默认为default_value
,instruction
是这个标志含义的说明。因此,上面的代码定义了一个string类型的名为FLAGS_gpu的标志,如果在Command Line中用户没有提供值,那么会默认为空字符串,根据说明可以得知这个标志是提供给用户来指定caffe将使用的GPU的。其余的定义也是类似的理解方式就不一一列举了。
解析这些标志的代码在caffe.cpp中的main()
中调用了/CAFFE_ROOT/src/common.cpp中的GlobalInit(&argc, &argv)
函数:
1 void GlobalInit(int* pargc, char*** pargv) {
2 // Google flags.
3 ::gflags::ParseCommandLineFlags(pargc, pargv, true);
4 // Google logging.
5 ::google::InitGoogleLogging(*(pargv)[0]);
6 // Provide a backtrace on segfault.
7 ::google::InstallFailureSignalHandler();
8 }
第三行的函数就是Google Flags用来解析输入的参数的,前两个参数分别是指向main()
的argc
和argv
的指针,第三个参数为true
,表示在解析完所有的标志之后将这些标志从argv
中清除,因此在解析完成之后,argc
的值为2,argv[0]
为main,argv[1]
为train/test/time/device_query中的一个。
Register Brew Function的宏的定义和使用
Caffe在Command Line Interfaces中一共提供了4种功能:train/test/time/device_query,分别对应着四个函数,这四个函数的调用是通过一个叫做g_brew_map
的全局变量来完成的:
1 // A simple registry for caffe commands.
2 typedef int (*BrewFunction)();
3 typedef std::map<caffe::string, BrewFunction> BrewMap;
4 BrewMap g_brew_map;
g_brew_map
是一个key为string类型,value为BrewFunction类型的一个map类型的全局变量,BrewFunction是一个函数指针类型,指向的是参数为空,返回值为int的函数,也就是train/test/time/device_query这四个函数的类型。在train等四个函数实现的后面都紧跟着这样一句宏的调用:RegisterBrewFunction(train)
;
其中使用的宏的具体定义为:
1 \#define RegisterBrewFunction(func) \
2 namespace { \
3 class __Registerer_##func { \
4 public: /* NOLINT */ \
5 __Registerer_##func() { \
6 g_brew_map[#func] = &func; \
7 } \
8 }; \
9 __Registerer_##func g_registerer_##func; \
10 }
以train函数为例子,RegisterBrewFunction(train)
这个宏的作用是定义了一个名为__Register_train
的类,在定义完这个类之后,定义了一个这个类的变量,会调用构造函数,这个类的构造函数在前面提到的g_brew_map
中添加了key为”train”,value为指向train函数的指针的一个元素。
然后函数的调用在main()
函数中是通过下面的这段代码实现的,在完成初始化(GlobalInit)之后,有这样一句代码:
1 // main()中的调用代码
2 return GetBrewFunction(caffe::string(argv[1]))();
3 // BrewFunction的具体实现
4 static BrewFunction GetBrewFunction(const caffe::string& name) {
5 if (g_brew_map.count(name)) {
6 return g_brew_map[name];
7 } else {
8 LOG(ERROR) << "Available caffe actions:";
9 for (BrewMap::iterator it = g_brew_map.begin();
10 it != g_brew_map.end(); ++it) {
11 LOG(ERROR) << "\t" << it->first;
12 }
13 LOG(FATAL) << "Unknown action: " << name;
14 return NULL; // not reachable, just to suppress old compiler warnings.
15 }
16 }
还是以train函数为例子,如果我们在Command Line中输入了caffe train <args>
,经过Google Flags的解析argv[1]=train,因此,在GetBrewFunction
中会通过g_brew_map
返回一个指向train函数的函数指针,最后在main函数中就通过这个返回的函数指针完成了对train函数的调用。
总结一下:RegisterBrewFunction
这个宏在每一个实现主要功能的函数之后将这个函数的名字和其对应的函数指针添加到了g_brew_map
中,然后在main函数中,通过GetBrewFunction
得到了我们需要调用的那个函数的函数指针,并完成了调用。
train()
函数的具体实现
接下来我们仔细地分析一下在train()
的具体实现。
首先是这样的一段代码:
1 CHECK_GT(FLAGS_solver.size(), 0) << "Need a solver definition to train.";
2 CHECK(!FLAGS_snapshot.size() || !FLAGS_weights.size())
3 << "Give a snapshot to resume training or weights to finetune "
4 "but not both.";
这段代码的第一行使用了glog的CHECK_GT
宏(含义为check greater than),检查FLAGS_solver
的size是否大于0,如果小于或等于0则输出提示:”Need a solver definition to train”。FLAGS_solver
是最开始通过DEFINE_string
定义的标志,如果我们希望训练一个模型,那么自然应该应该提供对应的solver定义文件的路径,这一句话正是在确保我们提供了这样的路径。这样的检查语句在后续的代码中会经常出现,将不再一一详细解释,如果有不清楚含义的glog宏可以去看看文档。 与第一行代码类似,第二行代码是确保用户没有同时提供snapshot和weights参数,这两个参数都是继续之前的训练或者进行fine-tuning的,如果同时指明了这两个标志,则不知道到底应该从哪个路径的文件去读入模型的相关参数更为合适。
然后出现了SolverParameter solver_param
的声明和解析的代码:
1 caffe::SolverParameter solver_param;
2 caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param);
SolverParameter
是通过Google Protocol Buffer
自动生成的一个类,如果有不清楚的可以参考上一篇文章。而具体的解析函数将在下一部分具体解释。
接下来这一部分的代码是根据用户的设置来选择caffe工作的模式(GPU或CPU)以及使用哪些GPU(caffe已经支持了多GPU同时工作!具体参考:官网tutorial的Parallelism部分):
1 // If the gpus flag is not provided, allow the mode and device to be set
2 // in the solver prototxt.
3 if (FLAGS_gpu.size() == 0
4 && solver_param.solver_mode() == caffe::SolverParameter_SolverMode_GPU) {
5 if (solver_param.has_device_id()) {
6 FLAGS_gpu = "" +
7 boost::lexical_cast<string>(solver_param.device_id());
8 } else { // Set default GPU if unspecified
9 FLAGS_gpu = "" + boost::lexical_cast<string>(0);
10 }
11 }
12 vector<int> gpus;
13 get_gpus(&gpus);
14 if (gpus.size() == 0) {
15 LOG(INFO) << "Use CPU.";
16 Caffe::set_mode(Caffe::CPU);
17 } else {
18 ostringstream s;
19 for (int i = 0; i < gpus.size(); ++i) {
20 s << (i ? ", " : "") << gpus[i];
21 }
22 LOG(INFO) << "Using GPUs " << s.str();
23
24 solver_param.set_device_id(gpus[0]);
25 Caffe::SetDevice(gpus[0]);
26 Caffe::set_mode(Caffe::GPU);
27 Caffe::set_solver_count(gpus.size());
28 }
首先是判断用户在Command Line中是否输入了gpu相关的参数,如果没有(FLAGS_gpu.size()==0)但是用户在solver的prototxt定义中提供了相关的参数,那就把相关的参数放到FLAGS_gpu中,如果用户仅仅是选择了在solver的prototxt定义中选择了GPU模式,但是没有指明具体的gpu_id,那么就默认设置为0。
接下来的代码则通过一个get_gpus的函数,将存放在FLAGS_gpu中的string转成了一个vector,并完成了具体的设置。
下面的代码声明并通过SolverRegistry
初始化了一个指向Solver
类型的shared_ptr。并通过这个shared_ptr指明了在遇到系统信号(用户按了ctrl+c或者关闭了当前的terminal)时的处理方式。
1 caffe::SignalHandler signal_handler(
2 GetRequestedAction(FLAGS_sigint_effect),
3 GetRequestedAction(FLAGS_sighup_effect));
4
5 shared_ptr<caffe::Solver<float> >
6 solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));
7
8 solver->SetActionFunction(signal_handler.GetActionFunction());
接下来判断了一下用户是否定义了snapshot或者weights这两个参数中的一个,如果定义了则需要通过Solver
提供的接口从snapshot或者weights文件中去读取已经训练好的网络的参数:
1 if (FLAGS_snapshot.size()) {
2 LOG(INFO) << "Resuming from " << FLAGS_snapshot;
3 solver->Restore(FLAGS_snapshot.c_str());
4 } else if (FLAGS_weights.size()) {
5 CopyLayers(solver.get(), FLAGS_weights);
6 }
最后,如果用户设置了要使用多个gpu,那么要声明一个P2PSync
类型的对象,并通过这个对象来完成多gpu的计算,这一部分的代码,这一系列的文章会暂时先不涉及。而如果是只使用单个gpu,那么就通过Solver
的Solve()
开始具体的优化过程。在优化结束之后,函数将0值返回给main函数,整个train过程到这里也就结束了:
1 if (gpus.size() > 1) {
2 caffe::P2PSync<float> sync(solver, NULL, solver->param());
3 sync.run(gpus);
4 } else {
5 LOG(INFO) << "Starting Optimization";
6 solver->Solve();
7 }
8 LOG(INFO) << "Optimization Done.";
9 return 0;
上面的代码中涉及了很多Solver
这个类的接口,这些内容都将在下一篇文章中进行具体的分析。
SolverParameter
的具体解析过程
前面提到了SolverParameter
是通过ReadSolverParamsFromTextFileOrDie
来完成解析的,这个函数的实现在/CAFFE_ROOT/src/caffe/util/upgrade_proto.cpp里,我们来看一下具体的过程:
1 // Read parameters from a file into a SolverParameter proto message.
2 void ReadSolverParamsFromTextFileOrDie(const string& param_file,
3 SolverParameter* param) {
4 CHECK(ReadProtoFromTextFile(param_file, param))
5 << "Failed to parse SolverParameter file: " << param_file;
6 UpgradeSolverAsNeeded(param_file, param);
7 }
这里调用了先后调用了两个函数,首先是ReadProtoFromTextFile
,这个函数的作用是从param_file这个路径去读取solver的定义,并将文件中的内容解析存到param这个指针指向的对象,具体的实现在/CAFFE_ROOT/src/caffe/util/io.cpp的开始:
1 bool ReadProtoFromTextFile(const char* filename, Message* proto) {
2 int fd = open(filename, O_RDONLY);
3 CHECK_NE(fd, -1) << "File not found: " << filename;
4 FileInputStream* input = new FileInputStream(fd);
5 bool success = google::protobuf::TextFormat::Parse(input, proto);
6 delete input;
7 close(fd);
8 return success;
9 }
这段代码首先是打开了文件,并且读取到了一个FileInputStream
的指针中,然后通过protobuf
的TextFormat::Parse
函数完成了解析。
然后UpgradeSolverAsNeeded
完成了新老版本caffe.proto的兼容处理:
1 // Check for deprecations and upgrade the SolverParameter as needed.
2 bool UpgradeSolverAsNeeded(const string& param_file, SolverParameter* param) {
3 bool success = true;
4 // Try to upgrade old style solver_type enum fields into new string type
5 if (SolverNeedsTypeUpgrade(*param)) {
6 LOG(INFO) << "Attempting to upgrade input file specified using deprecated "
7 << "'solver_type' field (enum)': " << param_file;
8 if (!UpgradeSolverType(param)) {
9 success = false;
10 LOG(ERROR) << "Warning: had one or more problems upgrading "
11 << "SolverType (see above).";
12 } else {
13 LOG(INFO) << "Successfully upgraded file specified using deprecated "
14 << "'solver_type' field (enum) to 'type' field (string).";
15 LOG(WARNING) << "Note that future Caffe releases will only support "
16 << "'type' field (string) for a solver's type.";
17 }
18 }
19 return success;
20 }
主要的问题就是在旧版本中Solver
的type是enum类型,而新版本的变为了string。
总结
本文从主要分析了caffe.cpp中实现各种具体功能的函数的调用的机制,以及在Command Line中用户输入的各种参数是怎么解析的,以及最常用的train函数的具体代码。通过这些分析,我们对Solver
类型的接口有了一个初步的认识和了解,在下一篇文章中,我们将去具体地分析Solver
的实现。
四.
在上文对Command Line Interfaces进行了简单的介绍之后,本文将对caffe的Solver相关的代码进行分析。
本文将主要分为四部分的内容:
Solver
的初始化(Register宏和构造函数)SIGINT
和SIGHUP
信号的处理Solver::Solve()
具体实现SGDSolver::ApplyUpdate
具体实现
Solver的初始化(Register宏和构造函数)
shared_ptr<caffe::Solver<float> >solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));
caffe.cpp中的train函数中通过上面的代码定义了一个指向Solver<float>
的shared_ptr。其中主要是通过调用SolverRegistry
这个类的静态成员函数CreateSolver
得到一个指向Solver
的指针来构造shared_ptr类型的solver
。而且由于C++多态的特性,尽管solver
是一个指向基类Solver
类型的指针,通过solver
这个智能指针来调用各个成员函数会调用到各个子类(SGDSolver
等)的函数。具体的过程如下面的流程图所示:
下面我们就来具体看一下SolverRegistry
这个类的代码,以便理解是如何通过同一个函数得到不同类型的Solver:
1 class SolverRegistry {
2 public:
3 typedef Solver<Dtype>* (*Creator)(const SolverParameter&);
4 typedef std::map<string, Creator> CreatorRegistry;
5 static CreatorRegistry& Registry() {
6 static CreatorRegistry* g_registry_ = new CreatorRegistry();
7 return *g_registry_;
8 }
9 static void AddCreator(const string& type, Creator creator) {
10 CreatorRegistry& registry = Registry();
11 CHECK_EQ(registry.count(type), 0)
12 << "Solver type " << type << " already registered.";
13 registry[type] = creator;
14 }
15 static Solver<Dtype>* CreateSolver(const SolverParameter& param) {
16 const string& type = param.type();
17 CreatorRegistry& registry = Registry();
18 CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
19 << " (known types: " << SolverTypeListString() << ")";
20 return registry[type](param);
21 }
22 static vector<string> SolverTypeList() {
23 CreatorRegistry& registry = Registry();
24 vector<string> solver_types;
25 for (typename CreatorRegistry::iterator iter = registry.begin();
26 iter != registry.end(); ++iter) {
27 solver_types.push_back(iter->first);
28 }
29 return solver_types;
30 }
31 private:
32 SolverRegistry() {}
33 static string SolverTypeListString() {
34 vector<string> solver_types = SolverTypeList();
35 string solver_types_str;
36 for (vector<string>::iterator iter = solver_types.begin();
37 iter != solver_types.end(); ++iter) {
38 if (iter != solver_types.begin()) {
39 solver_types_str += ", ";
40 }
41 solver_types_str += *iter;
42 }
43 return solver_types_str;
44 }
45 };
首先需要注意的是这个类的构造函数是private的,也就是用我们没有办法去构造一个这个类型的变量,这个类也没有数据成员,所有的成员函数也都是static的,可以直接调用。
我们首先从CreateSolver
函数(第15行)入手,这个函数先定义了string类型的变量type,表示Solver的类型(‘SGD’/’Nestrov’等),然后定义了一个key类型为string,value类型为Creator
的map:registry,其中Creator
是一个函数指针类型,指向的函数的参数为SolverParameter
类型,返回类型为Solver<Dtype>*
(见第2行和第3行)。如果是一个已经register过的Solver类型,那么registry.count(type)
应该为1,然后通过registry这个map返回了我们需要类型的Solver的creator,并调用这个creator函数,将creator返回的Solver<Dtype>*
返回。
上面的代码中,Registry
这个函数(第5行)中定义了一个static的变量g_registry,这个变量是一个指向CreatorRegistry
这个map类型的指针,然后直接返回,因为这个变量是static的,所以即使多次调用这个函数,也只会定义一个g_registry,而且在其他地方修改这个map里的内容,是存储在这个map中的。事实上各个Solver的register的过程正是往g_registry指向的那个map里添加以Solver的type为key,对应的Creator函数指针为value的内容。Register的过程如流程图所示:
下面我们具体来看一下Solver的register的过程:
1 template <typename Dtype>
2 class SolverRegisterer {
3 public:
4 SolverRegisterer(const string& type,
5 Solver<Dtype>* (*creator)(const SolverParameter&)) {
6 // LOG(INFO) << "Registering solver type: " << type;
7 SolverRegistry<Dtype>::AddCreator(type, creator);
8 }
9 };
10 #define REGISTER_SOLVER_CREATOR(type, creator) \
11 static SolverRegisterer<float> g_creator_f_##type(#type, creator<float>); \
12 static SolverRegisterer<double> g_creator_d_##type(#type, creator<double>) \
13
14 #define REGISTER_SOLVER_CLASS(type) \
15 template <typename Dtype> \
16 Solver<Dtype>* Creator_##type##Solver( \
17 const SolverParameter& param) \
18 { \
19 return new type##Solver<Dtype>(param); \
20 } \
21 REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver)
22 }
23 // register SGD Solver
24 REGISTER_SOLVER_CLASS(SGD);
在sgd_solver.cpp(SGD Solver对应的cpp文件)末尾有上面第24行的代码,使用了REGISTER_SOLVER_CLASS
这个宏,这个宏会定义一个名为Creator_SGDSolver
的函数,这个函数即为Creator
类型的指针指向的函数,在这个函数中调用了SGDSolver
的构造函数,并将构造的这个变量得到的指针返回,这也就是Creator类型函数的作用:构造一个对应类型的Solver对象,将其指针返回。然后在这个宏里又调用了REGISTER_SOLVER_CREATOR
这个宏,这里分别定义了SolverRegisterer
这个模板类的float和double类型的static变量,这会去调用各自的构造函数,而在SolverRegisterer
的构造函数中调用了之前提到的SolverRegistry
类的AddCreator
函数,这个函数就是将刚才定义的Creator_SGDSolver
这个函数的指针存到g_registry指向的map里面。类似地,所有的Solver对应的cpp文件的末尾都调用了这个宏来完成注册,在所有的Solver都注册之后,我们就可以通过之前描述的方式,通过g_registry得到对应的Creator函数的指针,并通过调用这个Creator函数来构造对应的Solver。Register和Create对应的流程图如下所示:
SIGINT
和SIGHUP
信号的处理
Caffe在train或者test的过程中都有可能会遇到系统信号(用户按下ctrl+c或者关掉了控制的terminal),我们可以通过对sigint_effect
和sighup_effect
来设置遇到系统信号的时候希望进行的处理方式:
caffe train –solver=/path/to/solver.prototxt –sigint_effect=EFFECT –sighup_effect=EFFECT
在caffe.cpp中定义了一个GetRequesedAction函数来将设置的string类型的标志转变为枚举类型的变量:
1 caffe::SolverAction::Enum GetRequestedAction(
2 const std::string& flag_value) {
3 if (flag_value == "stop") {
4 return caffe::SolverAction::STOP;
5 }
6 if (flag_value == "snapshot") {
7 return caffe::SolverAction::SNAPSHOT;
8 }
9 if (flag_value == "none") {
10 return caffe::SolverAction::NONE;
11 }
12 LOG(FATAL) << "Invalid signal effect \""<< flag_value << "\" was specified";
13 }
14 // SolverAction::Enum的定义
15 namespace SolverAction {
16 enum Enum {
17 NONE = 0, // Take no special action.
18 STOP = 1, // Stop training. snapshot_after_train controls whether a
19 // snapshot is created.
20 SNAPSHOT = 2 // Take a snapshot, and keep training.
21 };
22 }
其中SolverAction::Enum的定义在solver.hpp中,这是一个定义为枚举类型的数据类型,只有三个可能的值,分别对应了三种处理系统信号的方式:NONE(忽略信号什么都不做)/STOP(停止训练)/SNAPSHOT(保存当前的训练状态,继续训练)。在caffe.cpp中的train函数里Solver设置如何处理系统信号的代码为:
1 caffe::SignalHandler signal_handler(
2 GetRequestedAction(FLAGS_sigint_effect),
3 GetRequestedAction(FLAGS_sighup_effect));
4
5 solver->SetActionFunction(signal_handler.GetActionFunction());
FLAGS_sigint_effect和FLAGS_sighup_effect是通过gflags定义和解析的两个Command Line Interface的输入参数,分别对应遇到sigint和sighup信号的处理方式,如果用户不设定(大部分时候我自己就没设定),sigint的默认值为”stop”,sighup的默认值为”snapshot”。GetRequestedAction
函数会将string类型的FLAGS_xx转为SolverAction::Enum类型,并用来定义一个SignalHandler
类型的对象signal_handler。我们可以看到这部分代码都依赖于SignalHandler
这个类的接口,我们先来看看这个类都做了些什么:
1 // header file
2 class SignalHandler {
3 public:
4 // Contructor. Specify what action to take when a signal is received.
5 SignalHandler(SolverAction::Enum SIGINT_action,
6 SolverAction::Enum SIGHUP_action);
7 ~SignalHandler();
8 ActionCallback GetActionFunction();
9 private:
10 SolverAction::Enum CheckForSignals() const;
11 SolverAction::Enum SIGINT_action_;
12 SolverAction::Enum SIGHUP_action_;
13 };
14 // source file
15 SignalHandler::SignalHandler(SolverAction::Enum SIGINT_action,
16 SolverAction::Enum SIGHUP_action):
17 SIGINT_action_(SIGINT_action),
18 SIGHUP_action_(SIGHUP_action) {
19 HookupHandler();
20 }
21 void HookupHandler() {
22 if (already_hooked_up) {
23 LOG(FATAL) << "Tried to hookup signal handlers more than once.";
24 }
25 already_hooked_up = true;
26 struct sigaction sa;
27 sa.sa_handler = &handle_signal;
28 // ...
29 }
30 static volatile sig_atomic_t got_sigint = false;
31 static volatile sig_atomic_t got_sighup = false;
32 void handle_signal(int signal) {
33 switch (signal) {
34 case SIGHUP:
35 got_sighup = true;
36 break;
37 case SIGINT:
38 got_sigint = true;
39 break;
40 }
41 }
42 ActionCallback SignalHandler::GetActionFunction() {
43 return boost::bind(&SignalHandler::CheckForSignals, this);
44 }
45 SolverAction::Enum SignalHandler::CheckForSignals() const {
46 if (GotSIGHUP()) {
47 return SIGHUP_action_;
48 }
49 if (GotSIGINT()) {
50 return SIGINT_action_;
51 }
52 return SolverAction::NONE;
53 }
54 bool GotSIGINT() {
55 bool result = got_sigint;
56 got_sigint = false;
57 return result;
58 }
59 bool GotSIGHUP() {
60 bool result = got_sighup;
61 got_sighup = false;
62 return result;
63 }
64 // ActionCallback的含义
65 typedef boost::function<SolverAction::Enum()> ActionCallback;
SignalHandler
这个类有两个数据成员,都是SolverAction::Enum
类型的,分别对应sigint和sighup信号,在构造函数中,用解析FLAGS_xx得到的结果分别给两个成员赋值,然后调用了HookupHandler
函数,这个函数的主要作用是定义了一个sigaction
类型(应该是系统级别的代码)的对象sa,然后通过sa.sa_handler = &handle_signal来设置,当有遇到系统信号时,调用handle_signal
函数来处理,而我们可以看到这个函数的处理很简单,就是判断一下当前的信号是什么类型,如果是sigint就将全局的static变量got_sigint变为true,sighup的处理类似。
在根据用户设置(或者默认值)的参数定义了signal_handler之后,solver通过SetActionFunction
来设置了如何处理系统信号。这个函数的输入为signal_handler的GetActionFunction
的返回值,根据上面的代码我们可以看到,GetActionFunction
会返回signal_handler这个对象的CheckForSignals函数的地址(boost::bind的具体使用请参考boost官方文档)。而在Solver
的SetActionFunction
函数中只是简单的把Solver
的一个成员action_request_function_赋值为输入参数的值,以当前的例子来说就是,solver对象的action_request_function_指向了signal_handler对象的CheckForSignals函数的地址。其中的ActionCallback是一个函数指针类型,指向了参数为空,返回值为SolverAction::Enum类型的函数(boost::function具体用法参考官方文档)。
总结起来,我们通过定义一个SignalHandler
类型的对象,告知系统在遇到系统信号的时候回调handle_signal
函数来改变全局变量got_sigint和got_sighup的值,然后通过Solver
的接口设置了其遇到系统函数将调用signal_handler的Check函数,这个函数实际上就是去判断当前是否遇到了系统信号,如果遇到某个类型的信号,就返回我们之前设置的处理方式(SolverAction::Enum
类型)。剩余的具体处理再交给Solver
的其它函数,后面会具体分析。
Solver::Solve()
具体实现
Solve
函数实现了具体的网络的优化过程,下面我们来具体分析一下这部分的代码,分析见注释:
1 void Solver<Dtype>::Solve(const char* resume_file) {
2 // 检查当前是否是root_solver(多GPU模式下,只有root_solver才运行这一部分的代码)
3 CHECK(Caffe::root_solver());
4 // 然后输出learning policy(更新学习率的策略)
5 LOG(INFO) << "Solving " << net_->name();
6 LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();
7 // requested_early_exit_`一开始被赋值为false,也就是现在没有要求在优化结束前退出
8 requested_early_exit_ = false;
9 // 判断`resume_file`这个指针是否NULL,如果不是则需要从resume_file存储的路径里读取之前训练的状态
10 if (resume_file) {
11 LOG(INFO) << "Restoring previous solver status from " << resume_file;
12 Restore(resume_file);
13 }
14 // 然后调用了'Step'函数,这个函数执行了实际的逐步的迭代过程
15 Step(param_.max_iter() - iter_);
16 // 迭代结束或者遇到系统信号提前结束后,判断是否需要在训练结束之后snapshot
17 // 这个可以在solver.prototxt里设置
18 if (param_.snapshot_after_train()
19 && (!param_.snapshot() || iter_ % param_.snapshot() != 0)) {
20 Snapshot();
21 }
22 // 如果在`Step`函数的迭代过程中遇到了系统信号,且我们的处理方式设置为`STOP`,
23 // 那么`requested_early_exit_`会被修改为true,迭代提前结束,输出相关信息
24 if (requested_early_exit_) {
25 LOG(INFO) << "Optimization stopped early.";
26 return;
27 }
28 // 判断是否需要输出最后的loss
29 if (param_.display() && iter_ % param_.display() == 0) {
30 Dtype loss;
31 net_->ForwardPrefilled(&loss);
32 LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss;
33 }
34 // 判断是否需要最后Test
35 if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
36 TestAll();
37 }
38 LOG(INFO) << "Optimization Done.";
39 }
下面继续分析具体的迭代过程发生的Step
函数:
1 template <typename Dtype>
2 void Solver<Dtype>::Step(int iters) {
3 vector<Blob<Dtype>*> bottom_vec;
4 // 设置开始的迭代次数(如果是从之前的snapshot恢复的,那iter_等于snapshot时的迭代次数)和结束的迭代次数
5 const int start_iter = iter_;
6 const int stop_iter = iter_ + iters;
7 // 输出的loss为前average_loss次loss的平均值,在solver.prototxt里设置,默认为1,
8 // losses存储之前的average_loss个loss,smoothed_loss为最后要输出的均值
9 int average_loss = this->param_.average_loss();
10 vector<Dtype> losses;
11 Dtype smoothed_loss = 0;
12 // 迭代
13 while (iter_ < stop_iter) {
14 // 清空上一次所有参数的梯度
15 net_->ClearParamDiffs();
16 // 判断是否需要测试
17 if (param_.test_interval() && iter_ % param_.test_interval() == 0
18 && (iter_ > 0 || param_.test_initialization())
19 && Caffe::root_solver()) {
20 TestAll();
21 // 判断是否需要提前结束迭代
22 if (requested_early_exit_) {
23 break;
24 }
25 }
26 for (int i = 0; i < callbacks_.size(); ++i) {
27 callbacks_[i]->on_start();
28 }
29 // 判断当前迭代次数是否需要显示loss等信息
30 const bool display = param_.display() && iter_ % param_.display() == 0;
31 net_->set_debug_info(display && param_.debug_info());
32 Dtype loss = 0;
33 // iter_size也是在solver.prototxt里设置,实际上的batch_size=iter_size*网络定义里的batch_size,
34 // 因此每一次迭代的loss是iter_size次迭代的和,再除以iter_size,这个loss是通过调用`Net::ForwardBackward`函数得到的
35 // 这个设置我的理解是在GPU的显存不够的时候使用,比如我本来想把batch_size设置为128,但是会out_of_memory,
36 // 借助这个方法,可以设置batch_size=32,iter_size=4,那实际上每次迭代还是处理了128个数据
37 for (int i = 0; i < param_.iter_size(); ++i) {
38 loss += net_->ForwardBackward(bottom_vec);
39 }
40 loss /= param_.iter_size();
41 // 计算要输出的smoothed_loss,如果losses里还没有存够average_loss个loss则将当前的loss插入,如果已经存够了,则将之前的替换掉
42 if (losses.size() < average_loss) {
43 losses.push_back(loss);
44 int size = losses.size();
45 smoothed_loss = (smoothed_loss * (size - 1) + loss) / size;
46 } else {
47 int idx = (iter_ - start_iter) % average_loss;
48 smoothed_loss += (loss - losses[idx]) / average_loss;
49 losses[idx] = loss;
50 }
51 // 输出当前迭代的信息
52 if (display) {
53 LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_
54 << ", loss = " << smoothed_loss;
55 const vector<Blob<Dtype>*>& result = net_->output_blobs();
56 int score_index = 0;
57 for (int j = 0; j < result.size(); ++j) {
58 const Dtype* result_vec = result[j]->cpu_data();
59 const string& output_name =
60 net_->blob_names()[net_->output_blob_indices()[j]];
61 const Dtype loss_weight =
62 net_->blob_loss_weights()[net_->output_blob_indices()[j]];
63 for (int k = 0; k < result[j]->count(); ++k) {
64 ostringstream loss_msg_stream;
65 if (loss_weight) {
66 loss_msg_stream << " (* " << loss_weight
67 << " = " << loss_weight * result_vec[k] << " loss)";
68 }
69 LOG_IF(INFO, Caffe::root_solver()) << " Train net output #"
70 << score_index++ << ": " << output_name << " = "
71 << result_vec[k] << loss_msg_stream.str();
72 }
73 }
74 }
75 for (int i = 0; i < callbacks_.size(); ++i) {
76 callbacks_[i]->on_gradients_ready();
77 }
78 // 执行梯度的更新,这个函数在基类`Solver`中没有实现,会调用每个子类自己的实现,后面具体分析`SGDSolver`的实现
79 ApplyUpdate();
80 // 迭代次数加1
81 ++iter_;
82 // 调用GetRequestedAction,实际是通过action_request_function_函数指针调用之前设置好(通过`SetRequestedAction`)的
83 // signal_handler的`CheckForSignals`函数,这个函数的作用是
84 // 会根据之前是否遇到系统信号以及信号的类型和我们设置(或者默认)的方式返回处理的方式
85 SolverAction::Enum request = GetRequestedAction();
86 // 判断当前迭代是否需要snapshot,如果request等于`SNAPSHOT`则也需要
87 if ((param_.snapshot()
88 && iter_ % param_.snapshot() == 0
89 && Caffe::root_solver()) ||
90 (request == SolverAction::SNAPSHOT)) {
91 Snapshot();
92 }
93 // 如果request为`STOP`则修改`requested_early_exit_`为true,之后就会提前结束迭代
94 if (SolverAction::STOP == request) {
95 requested_early_exit_ = true;
96 break;
97 }
98 }
99 }
SGDSolver::ApplyUpdate
具体实现
每一组网络中的参数的更新都是在不同类型的Solver自己实现的ApplyUpdate
函数中完成的,下面我们就以最常用的SGD为例子来分析这个函数具体的功能:
1 template <typename Dtype>
2 void SGDSolver<Dtype>::ApplyUpdate() {
3 CHECK(Caffe::root_solver());
4 // GetLearningRate根据设置的lr_policy来计算当前迭代的learning rate的值
5 Dtype rate = GetLearningRate();
6 // 判断是否需要输出当前的learning rate
7 if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
8 LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
9 }
10 // 避免梯度爆炸,如果梯度的二范数超过了某个数值则进行scale操作,将梯度减小
11 ClipGradients();
12 // 对所有可更新的网络参数进行操作
13 for (int param_id = 0; param_id < this->net_->learnable_params().size();
14 ++param_id) {
15 // 将第param_id个参数的梯度除以iter_size,这一步的作用是保证实际的batch_size=iter_size*设置的batch_size
16 Normalize(param_id);
17 // 将正则化部分的梯度降入到每个参数的梯度中
18 Regularize(param_id);
19 // 计算SGD算法的梯度(momentum等)
20 ComputeUpdateValue(param_id, rate);
21 }
22 // 调用`Net::Update`更新所有的参数
23 this->net_->Update();
24 }
下面我们继续具体分析一下Normalize
/Regularize
/ComputeUpdateValue
的实现,我们均以CPU的代码为例子,GPU部分的处理原理是一样的:
Normalize
1 template <typename Dtype>
2 void SGDSolver<Dtype>::Normalize(int param_id) {
3 // 如果iter_size的值为1,则不需要任何处理直接return
4 if (this->param_.iter_size() == 1) { return; }
5 // 通过net_返回所有可以学习的参数,是一个vector<shared_ptr<Blob<Dtype> > >
6 const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
7 // 要乘以的系数等于1/iter_size
8 const Dtype accum_normalization = Dtype(1.) / this->param_.iter_size();
9 switch (Caffe::mode()) {
10 case Caffe::CPU: {
11 // caffe_scal在/CAFFE_ROOT/src/caffe/util/math_functions.cpp中
12 // 是blas的scale函数的一个封装,第一个参数是数据的个数,第二个参数是乘以的系数,
13 // 第三个参数是数据的指针
14 caffe_scal(net_params[param_id]->count(), accum_normalization,
15 net_params[param_id]->mutable_cpu_diff());
16 break;
17 }
18 case Caffe::GPU: {
19 // GPU代码略
20 }
21 }
Regularize
1 template <typename Dtype>
2 void SGDSolver<Dtype>::Regularize(int param_id) {
3 // 获取所有可以学习的参数的vector
4 const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
5 // 获取所有的参数对应的weight_decay的vector
6 const vector<float>& net_params_weight_decay =
7 this->net_->params_weight_decay();
8 // 模型整体的weight_decay数值
9 Dtype weight_decay = this->param_.weight_decay();
10 // 获取正则化的类型:L1 或 L2
11 string regularization_type = this->param_.regularization_type();
12 // 实际的weight_decay等于整体模型的数值乘以具体每个参数的数值
13 Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
14 switch (Caffe::mode()) {
15 case Caffe::CPU: {
16 // 如果weight_decay不为0,则计算
17 if (local_decay) {
18 if (regularization_type == "L2") {
19 // L2的梯度为diff_ = weight_decay*data_ + diff_
20 // caffe_axpy的功能是 y = a*x + y
21 // 第一个参数是数据的个数,第二个是上式的a,第三个是x的指针,第四个是y的指针
22 caffe_axpy(net_params[param_id]->count(),
23 local_decay,
24 net_params[param_id]->cpu_data(),
25 net_params[param_id]->mutable_cpu_diff());
26 } else if (regularization_type == "L1") {
27 // L1的梯度为diff_ = diff_ + sign(data_)
28 // temp_ = sign(data_)
29 caffe_cpu_sign(net_params[param_id]->count(),
30 net_params[param_id]->cpu_data(),
31 temp_[param_id]->mutable_cpu_data());
32 // 将temp_加到diff_中 diff_ = weight_decay*temp_ + diff_
33 caffe_axpy(net_params[param_id]->count(),
34 local_decay,
35 temp_[param_id]->cpu_data(),
36 net_params[param_id]->mutable_cpu_diff());
37 } else {
38 LOG(FATAL) << "Unknown regularization type: " << regularization_type;
39 }
40 }
41 break;
42 }
43 // GPU代码略
44 }
ComputeUpdatedValue
1 template <typename Dtype>
2 void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
3 // 获取所有可以更新的参数的vector
4 const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
5 // 获取所有参数对应的learning_rate的vector
6 const vector<float>& net_params_lr = this->net_->params_lr();
7 // 获取momentum数值
8 Dtype momentum = this->param_.momentum();
9 // 实际的learning_rate为全局的learning_rate乘以每个参数对应的learning_rate
10 Dtype local_rate = rate * net_params_lr[param_id];
11 switch (Caffe::mode()) {
12 case Caffe::CPU: {
13 // 关于SGD的公式参考caffe官网tutorial的Solver部分
14 // history_存储了上一次的梯度,下面这个函数:
15 // history_ = learning_rate*diff_ + momentum*history
16 caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
17 net_params[param_id]->cpu_diff(), momentum,
18 history_[param_id]->mutable_cpu_data());
19 // 把当前的梯度拷贝给参数Blob的diff_
20 caffe_copy(net_params[param_id]->count(),
21 history_[param_id]->cpu_data(),
22 net_params[param_id]->mutable_cpu_diff());
23 break;
24 }
25 case Caffe::GPU: {
26 // GPU代码略
27 }
28 }
至此Solver
主要的代码都已经分析完了,总结起来主要有:(1)solver_factory的register和create不同类型Solver的机制,(2)通过signal_handler来获取系统信号,并根据用户或默认的设置进行相应的处理,(3)Solver::Solve
函数的具体实现的分析,(4)SGDSolver::ApplyUpdate
函数的具体实现。前面三个部分都属于基类的,最后一个是SGDSolver这个子类的,如果用户想要实现自己的Solver类,也应该类似地去继承基类,并实现自己的ApplyUpdate
函数,在代码的末尾通过register宏完成注册,便可以被成功的调用。
在train()中的solver->Solve()中的Step()中的ForwardBackward()中进行的各个layers的计算