工厂注册类
利用模板形式注册类
#include <functional>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <string_view>
#include <unordered_map>
#include <utility>
namespace cyn {//自定义断言
// Custom assertion macros.
//
//   CHECK(cond)          With _DEBUG defined: if `cond` is false, writes the
//                        failed expression, function, file and line to
//                        std::cerr and then continues execution.
//   CHECK_EX(cond, msg)  With _DEBUG defined: same diagnostic, followed by
//                        `msg`, then throws std::runtime_error(msg).
//
// Without _DEBUG both macros expand to an empty statement and `cond` is NOT
// evaluated, so callers must not rely on them for release-mode validation.
// Every expansion is wrapped in do { ... } while (0) so the macros behave like
// a single statement (safe in an unbraced if/else).
//#ifndef _DEBUG // _RELEASE or _DEBUG, depending on your compiler/build system
#ifdef _DEBUG // _RELEASE or _DEBUG, depending on your compiler/build system
#define CHECK(cond) \
    do { \
        if (!(cond)) { \
            std::cerr << "Assertion failed: (" << #cond << "), function " << __FUNCTION__ \
                      << ", file " << __FILE__ << ", line " << __LINE__ << "." << std::endl; \
        } \
    } while (0)
#define CHECK_EX(cond, msg) \
    do { \
        if (!(cond)) { \
            std::cerr << "Assertion failed: (" << #cond << "), function " << __FUNCTION__ \
                      << ", file " << __FILE__ << ", line " << __LINE__ << "." << std::endl; \
            std::cerr << "Message: " << (msg) << std::endl; \
            throw std::runtime_error((msg)); \
        } \
    } while (0)
#else
#define CHECK(cond) \
    do { \
    } while (0)
#define CHECK_EX(cond, msg) \
    do { \
    } while (0)
#endif // _DEBUG

// Parameter base class
/// Polymorphic root of every layer-parameter type.
///
/// It carries no data; the virtual destructor makes the hierarchy polymorphic
/// so that a std::shared_ptr<LayerParams> can be down-cast to the concrete
/// parameter type with std::dynamic_pointer_cast.
class LayerParams
{
public:
    virtual ~LayerParams() = default;
};
// Concrete parameter type
/// Parameters of a dense layer: the unit count and the learning rate.
class DenseLayerParams : public LayerParams {
public:
    int units;            ///< Value passed as the first constructor argument.
    float learning_rate;  ///< Value passed as the second constructor argument.

    /// Stores both values unchanged.
    DenseLayerParams(int units, float learning_rate)
        : units{units}, learning_rate{learning_rate}
    {
    }
};

// Layer base class
/// Base class of all layers.
///
/// Holds the layer's name and a shared handle to its parameters, and exposes
/// the virtual `forward` hook that concrete layers override.
class Layer
{
public:
    /// @param name    Name stored in `name`.
    /// @param params  Parameter object stored in `layer_params` (shared ownership).
    Layer(const std::string& name, const std::shared_ptr<LayerParams>& params)
        : name(name), layer_params(params)
    {
    }

    // Layers are handled through base-class pointers, so destruction through
    // a Layer* must reach the derived destructor.
    virtual ~Layer() = default;

    /// Default forward pass: ignores its argument and returns 0.
    virtual int forward(int) { return 0; }

public:
    int input_ = 0;  // value-initialised so reading it before assignment is well defined
    std::string name;
    std::shared_ptr<LayerParams> layer_params;
};

// Get the type name
/// @brief Returns the compiler's spelling of type `T`, with a leading
///        `class `/`struct `/`union `/`enum ` keyword and a leading `cyn::`
///        qualifier removed.
///
/// The name is cut out of the compiler-generated function signature by
/// searching for the text surrounding the template argument, rather than by
/// counting the characters of a hard-coded prefix. Example signatures:
///   GCC   "constexpr auto cyn::class_type_name() [with T = cyn::Foo]"
///   Clang "auto cyn::class_type_name() [T = cyn::Foo]"
///   MSVC  "auto __cdecl cyn::class_type_name<class cyn::Foo>(void) noexcept"
///
/// @tparam T  Type whose name is requested.
/// @return    A std::string_view into the signature string literal.
template <typename T>
constexpr auto class_type_name() noexcept
{
#if defined(__clang__) || defined(__GNUC__)
    // Clang defines __GNUC__ as well and words its signature differently from
    // GCC, so both are handled by locating "T = " instead of a fixed prefix.
    std::string_view name = __PRETTY_FUNCTION__;
    const std::string_view opening = "T = ";
    const std::string_view closing = "]";
#elif defined(_MSC_VER)
    std::string_view name = __FUNCSIG__;
    const std::string_view opening = "class_type_name<";
    const std::string_view closing = ">(void)";
#else
#error "class_type_name(): unsupported compiler (needs __PRETTY_FUNCTION__ or __FUNCSIG__)"
#endif

    // Drop everything up to and including the opening marker.
    const auto first = name.find(opening);
    if (first != std::string_view::npos) {
        name.remove_prefix(first + opening.size());
    }

    // Drop the closing marker and everything after it. rfind is used so that
    // a '>' or ']' inside the type name itself is not mistaken for the end.
    const auto last = name.rfind(closing);
    if (last != std::string_view::npos) {
        name.remove_suffix(name.size() - last);
    }

    // MSVC prefixes class types with an elaborated-type keyword.
    constexpr std::string_view keywords[] = {"class ", "struct ", "union ", "enum "};
    for (const std::string_view keyword : keywords) {
        if (name.compare(0, keyword.size(), keyword) == 0) {
            name.remove_prefix(keyword.size());
            break;
        }
    }

    // Types of this library live in namespace cyn; registry keys omit it.
    const std::string_view scope = "cyn::";
    if (name.compare(0, scope.size(), scope) == 0) {
        name.remove_prefix(scope.size());
    }
    return name;
}

// Define a registry (factory) base class template
template <typename T, typename... Args>
class Factory
{using CreateFunction = std::function<std::shared_ptr<T>(Args...)>;Factory(){}public:virtual ~Factory() = default;static Factory& Instance(){static Factory instance;return instance;}void Register(const std::string & name,const CreateFunction& create){std::cout <<"Factory Register name:"<< std::string(name) << std::endl;registry[std::string(name)] = create;}std::shared_ptr<T> Create(const std::string& name,Args... args){if (registry.find(name) != registry.end()){std::cout <<"Create succeed:"<< std::string(name) << std::endl;return registry[name](args...);}std::cout <<"Create fail:"<< std::string(name) << std::endl;return nullptr;}
private:std::unordered_map<std::string, CreateFunction> registry;
};
//注册类 统一接口
template <typename Base, typename Impl, typename... Args>
class Register
{
public:explicit Register(){std::cout<<"Register name: "<<std::string(class_type_name<Impl>())<<std::endl;//流程6Factory<Base, Args...>::Instance().Register(std::string(class_type_name<Impl>()), [](Args... args){return std::shared_ptr<Base>(new Impl(args...));});}
};// LayerSub 在这很关键,实现了继承LayerSub (也就是继承基类)自动注册
//注册Layer子类,实现继承自动注册 T 为Layer子类类型
/** 继承LayerSub-构造函数->(void)registered->trigger()->_register(str)*/
template <typename T>
class LayerSub : public Layer
{//流程4static bool trigger(){//流程5Register<Layer,T,std::shared_ptr<LayerParams>> _register;return true;}
protected://流程2LayerSub(const std::shared_ptr<LayerParams>& params) : Layer(std::string(class_type_name<T>()),params){std::cout<<"LayerSub:"<<std::string(class_type_name<T>())<<std::endl;//流程3(void)registered;}public:static bool registered;};
template <typename T>
bool LayerSub<T>::registered = LayerSub<T>::trigger();//定义Layer类 创建接口
typedef Factory<Layer, std::shared_ptr<LayerParams>> LayerFactory;// 具体层类型class DenseLayer : public LayerSub<DenseLayer>
{
public:DenseLayer(const std::shared_ptr<LayerParams>& params) : LayerSub(params){std::cout<<"DenseLayer constructor!"<<std::endl;std::shared_ptr<DenseLayerParams> p = std::dynamic_pointer_cast<DenseLayerParams>(params);//检查参数类型 - 调试用CHECK_EX(p,"error");//打印参数值std::cout<<p->units<<std::endl;std::cout<<p->learning_rate<<std::endl;}int forward(int s) override {std::cout<<"wo shi zi lei!"<<s<<std::endl;return s;}
};}int main(int argc, char *argv[])
{QCoreApplication a(argc, argv);std::shared_ptr<cyn::DenseLayerParams> param = std::make_shared<cyn::DenseLayerParams>(100,1);auto ss = cyn::LayerFactory::Instance().Create(std::string("DenseLayer"), param);ss->forward(1);return a.exec();
}
输出结果:
Register name: DenseLayer
Factory Register name:DenseLayer
Create succeed:DenseLayer
LayerSub:DenseLayer
DenseLayer constructor!
100
1
wo shi zi lei!1