Table of Contents
- I. MNN Resources
- II. Usage Example
- III. Source Code Analysis
- 1. createFromFile / createFromBuffer
- 1.1 Content
- 1.2 createFromBufferInternal
- 1.3 Net
- 1.4 Interpreter
- 1.5 Interpreter::Interpreter
- 2. createRuntime
- 2.1 RuntimeInfo
- 2.2 Schedule::getApprociateType
- 2.2.1 MNNGetExtraRuntimeCreator
- 2.2.1.1 registerBackend
- 2.2.1.2 GetExtraCreator
- 2.3 RuntimeFactory::create
- 2.3.1 VulkanRuntimeCreator
- 2.3.2 CPURuntimeCreator
- createSession
- runSession
I. MNN Resources
MNN GitHub
Chinese documentation
II. Usage Example
// Create the interpreter (Interpreter)
auto net_ = Interpreter::createFromFile(file);
// Create the runtime (Runtime)
ScheduleConfig config;
config.numThread = 4;
auto runtimeInfo = Interpreter::createRuntime({config});
// Create the session (Session)
auto session = net_->createSession(config, runtimeInfo);
// Run inference
net_->runSession(session);
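For context, here is a slightly more complete sketch of the same flow, assuming a single-input, single-output float model at the hypothetical path model.mnn; the host-copy pattern follows MNN's public Tensor API.

#include <MNN/Interpreter.hpp>
#include <MNN/Tensor.hpp>

int main() {
    // Load the model and create the interpreter.
    auto net = MNN::Interpreter::createFromFile("model.mnn"); // hypothetical path
    if (nullptr == net) {
        return -1;
    }

    // One ScheduleConfig per compute path; here a 4-thread CPU runtime.
    MNN::ScheduleConfig config;
    config.type      = MNN_FORWARD_CPU;
    config.numThread = 4;
    auto runtimeInfo = MNN::Interpreter::createRuntime({config});

    // Create a session that uses the prepared runtime.
    auto session = net->createSession(config, runtimeInfo);

    // Fill the input through a host-side copy of the input tensor.
    auto input = net->getSessionInput(session, nullptr);
    MNN::Tensor inputHost(input, input->getDimensionType());
    for (int i = 0; i < inputHost.elementSize(); ++i) {
        inputHost.host<float>()[i] = 0.0f; // fill with real data here
    }
    input->copyFromHostTensor(&inputHost);

    // Run inference and read the output back to the host.
    net->runSession(session);
    auto output = net->getSessionOutput(session, nullptr);
    MNN::Tensor outputHost(output, output->getDimensionType());
    output->copyToHostTensor(&outputHost);

    net->releaseSession(session);
    delete net;
    return 0;
}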
III. Source Code Analysis
1. createFromFile / createFromBuffer
createFromFile and createFromBuffer read the model and place it in the buffer field of the Content struct.
// source/core/Interpreter.cpp
Interpreter* Interpreter::createFromFile(const char* file) {
    Content* net = loadModelFile(file);
    if (nullptr == net) {
        return nullptr;
    }
    return createFromBufferInternal(net, true);
}

Interpreter* Interpreter::createFromBuffer(const void* buffer, size_t size) {
    if (nullptr == buffer || 0 == size) {
        MNN_PRINT("Buffer is null for create interpreter\n");
        return nullptr;
    }
    auto net = new Content;
    net->buffer.reset((int)size);
    if (nullptr == net->buffer.get()) {
        MNN_ERROR("Memory not enought!\n");
        return nullptr;
    }
    ::memcpy(net->buffer.get(), buffer, size);
    return createFromBufferInternal(net, true);
}
1.1 Content
// source/core/Interpreter.cpp
struct Content {
    AutoStorage<uint8_t> buffer;
    const Net* net = nullptr;
    std::vector<std::unique_ptr<Session>> sessions;
    std::map<Tensor*, const Session*> tensorMap;
    Session::ModeGroup modes;
    AutoStorage<uint8_t> cacheBuffer;
    std::string cacheFile;
    std::mutex lock;
    size_t lastCacheSize = 0;
    std::string bizCode;
    std::string uuid;
    std::string externalFile;
#ifdef MNN_INTERNAL_ENABLED
    std::map<std::string, std::string> basicLogginData;
    std::map<const Session*, std::tuple<int, int>> sessionInfo;
#endif
};
1.2 createFromBufferInternal
// source/core/Interpreter.cpp
Interpreter* Interpreter::createFromBufferInternal(Content* net, bool enforceAuth) {
    if (nullptr == net) {
        MNN_PRINT("Buffer is null for create interpreter\n");
        return nullptr;
    }
#ifndef MNN_BUILD_MINI
    // Verify the model buffer
    flatbuffers::Verifier verify((const uint8_t*)(net->buffer.get()), net->buffer.size());
    if (false == VerifyNetBuffer(verify)) {
        MNN_PRINT("Invalidate buffer to create interpreter\n");
        delete net;
        return nullptr;
    }
#endif
    // Get the network
    net->net = GetNet(net->buffer.get());
    if (nullptr == net->net->oplists()) {
        MNN_ERROR("Model has no oplist\n");
        delete net;
        return nullptr;
    }
    // Verify the model's operators
    int opSize = net->net->oplists()->size();
    for (int i = 0; i < opSize; ++i) {
        auto op = net->net->oplists()->GetAs<Op>(i);
        if (nullptr == op || nullptr == op->outputIndexes()) {
            MNN_ERROR("Invalid Model, the %d op is empty\n", i);
            delete net;
            return nullptr;
        }
    }
    // Create the interpreter
    return new Interpreter(net);
}
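The same check can be reproduced outside the interpreter, which is handy when debugging a suspect model file. A minimal sketch, assuming the generated schema header is on the include path (as it is inside the MNN source tree); checkModel is a hypothetical helper:

#include <cstdio>
#include <fstream>
#include <vector>
#include "flatbuffers/flatbuffers.h"
#include "MNN_generated.h" // schema/current/MNN_generated.h

bool checkModel(const char* path) {
    std::ifstream file(path, std::ios::binary | std::ios::ate);
    if (!file.is_open()) {
        return false;
    }
    std::vector<uint8_t> bytes(static_cast<size_t>(file.tellg()));
    file.seekg(0, std::ios::beg);
    file.read(reinterpret_cast<char*>(bytes.data()), bytes.size());

    // Same check as createFromBufferInternal: run the FlatBuffers verifier over the buffer.
    flatbuffers::Verifier verifier(bytes.data(), bytes.size());
    if (!MNN::VerifyNetBuffer(verifier)) {
        printf("not a valid MNN model: %s\n", path);
        return false;
    }
    auto net = MNN::GetNet(bytes.data());
    printf("ops: %d\n", net->oplists() ? (int)net->oplists()->size() : 0);
    return true;
}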
1.3 Net
// schema/current/MNN_generated.h
struct Net FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
    typedef NetT NativeTableType;
    static const flatbuffers::TypeTable *MiniReflectTypeTable() {
        return NetTypeTable();
    }
    const flatbuffers::String *bizCode() const {
        return GetPointer<const flatbuffers::String *>(4);
    }
    const flatbuffers::Vector<flatbuffers::Offset<TensorDescribe>> *extraTensorDescribe() const {
        return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<TensorDescribe>> *>(6);
    }
    const ExtraInfo *extraInfo() const {
        return GetPointer<const ExtraInfo *>(8);
    }
    const flatbuffers::Vector<flatbuffers::Offset<Op>> *oplists() const {
        return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<Op>> *>(10);
    }
    const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *outputName() const {
        return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(12);
    }
    ForwardType preferForwardType() const {
        return static_cast<ForwardType>(GetField<int8_t>(14, 0));
    }
    NetSource sourceType() const {
        return static_cast<NetSource>(GetField<int8_t>(16, 0));
    }
    const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *tensorName() const {
        return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(18);
    }
    int32_t tensorNumber() const {
        return GetField<int32_t>(20, 0);
    }
    Usage usage() const {
        return static_cast<Usage>(GetField<int8_t>(22, 0));
    }
    const flatbuffers::Vector<flatbuffers::Offset<SubGraphProto>> *subgraphs() const {
        return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<SubGraphProto>> *>(24);
    }
    const flatbuffers::String *mnn_uuid() const {
        return GetPointer<const flatbuffers::String *>(26);
    }
    bool Verify(flatbuffers::Verifier &verifier) const {
        return VerifyTableStart(verifier) &&
               VerifyOffset(verifier, 4) &&
               verifier.VerifyString(bizCode()) &&
               VerifyOffset(verifier, 6) &&
               verifier.VerifyVector(extraTensorDescribe()) &&
               verifier.VerifyVectorOfTables(extraTensorDescribe()) &&
               VerifyOffset(verifier, 8) &&
               verifier.VerifyTable(extraInfo()) &&
               VerifyOffset(verifier, 10) &&
               verifier.VerifyVector(oplists()) &&
               verifier.VerifyVectorOfTables(oplists()) &&
               VerifyOffset(verifier, 12) &&
               verifier.VerifyVector(outputName()) &&
               verifier.VerifyVectorOfStrings(outputName()) &&
               VerifyField<int8_t>(verifier, 14) &&
               VerifyField<int8_t>(verifier, 16) &&
               VerifyOffset(verifier, 18) &&
               verifier.VerifyVector(tensorName()) &&
               verifier.VerifyVectorOfStrings(tensorName()) &&
               VerifyField<int32_t>(verifier, 20) &&
               VerifyField<int8_t>(verifier, 22) &&
               VerifyOffset(verifier, 24) &&
               verifier.VerifyVector(subgraphs()) &&
               verifier.VerifyVectorOfTables(subgraphs()) &&
               VerifyOffset(verifier, 26) &&
               verifier.VerifyString(mnn_uuid()) &&
               verifier.EndTable();
    }
    NetT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
    void UnPackTo(NetT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
    static flatbuffers::Offset<Net> Pack(flatbuffers::FlatBufferBuilder &_fbb, const NetT* _o,
                                         const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};
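Since Net is a plain FlatBuffers table, its fields can be read in place from the model buffer without unpacking to NetT. A small sketch, assuming buffer already holds a verified model and that Op exposes the usual generated name()/type() accessors:

#include <cstdio>
#include "MNN_generated.h" // schema/current/MNN_generated.h

void dumpNet(const void* buffer) {
    // `buffer` must already hold a verified MNN model.
    const MNN::Net* net = MNN::GetNet(buffer);
    printf("bizCode: %s\n", net->bizCode() ? net->bizCode()->c_str() : "(none)");
    printf("tensorNumber: %d, sourceType: %d, usage: %d\n",
           net->tensorNumber(), (int)net->sourceType(), (int)net->usage());

    // Walk the op list exactly as createFromBufferInternal does.
    auto oplists = net->oplists();
    for (int i = 0; i < (int)oplists->size(); ++i) {
        auto op = oplists->GetAs<MNN::Op>(i);
        printf("op %d: type=%d, name=%s\n", i, (int)op->type(),
               op->name() ? op->name()->c_str() : "");
    }
}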
1.4 Interpreter
/** net data holder. multiple sessions could share same net. */
class MNN_PUBLIC Interpreter {
public:
    /**
     * @brief create net from file.
     * @param file given file.
     * @return created net if success, NULL otherwise.
     */
    static Interpreter* createFromFile(const char* file);
    /**
     * @brief create net from buffer.
     * @param buffer given data buffer.
     * @param size size of data buffer.
     * @return created net if success, NULL otherwise.
     */
    static Interpreter* createFromBuffer(const void* buffer, size_t size);
    ~Interpreter();

public:
    /**
     * @brief create session with schedule config. created session will be managed in net.
     * @param config session schedule config.
     * @return created session if success, NULL otherwise.
     */
    Session* createSession(const ScheduleConfig& config);
    /**
     * @brief create multi-path session with schedule configs. created session will be managed in net.
     * @param configs session schedule configs.
     * @return created session if success, NULL otherwise.
     */
    Session* createMultiPathSession(const std::vector<ScheduleConfig>& configs);
    /**
     * @brief release session.
     * @param session given session.
     * @return true if given session is held by net and is freed.
     */
    bool releaseSession(Session* session);
    /**
     * @brief call this function to get tensors ready. output tensor buffer (host or deviceId) should be retrieved
     *        after resize of any input tensor.
     * @param session given session.
     */
    void resizeSession(Session* session);
    /**
     * @brief call this function if don't need resize or create session any more, it will save a few memory that equal
     *        to the size of model buffer
     */
    void releaseModel();
    /**
     * @brief Get the model buffer for user to save
     * @return std::make_pair(modelBuffer, modelSize).
     * @example:
     * std::ofstream output("trainResult.alinn")
     * auto buffer = net->getModelBuffer();
     * output.write((const char*)buffer.first, buffer.second);
     */
    std::pair<const void*, size_t> getModelBuffer() const;
    /**
     * @brief update Session's Tensor to model's Const Op
     * @param session given session.
     * @return result of running.
     */
    ErrorCode updateSessionToModel(Session* session);
    /**
     * @brief run session.
     * @param session given session.
     * @return result of running.
     */
    ErrorCode runSession(Session* session) const;
    /**
     * @brief run session.
     * @param session given session.
     * @param before callback before each op. return true to run the op; return false to skip the op.
     * @param after callback after each op. return true to continue running; return false to interrupt the session.
     * @param sync synchronously wait for finish of execution or not.
     * @return result of running.
     */
    ErrorCode runSessionWithCallBack(const Session* session, const TensorCallBack& before, const TensorCallBack& end,
                                     bool sync = false) const;
    /**
     * @brief run session.
     * @param session given session.
     * @param before callback before each op. return true to run the op; return false to skip the op.
     * @param after callback after each op. return true to continue running; return false to interrupt the session.
     * @param sync synchronously wait for finish of execution or not.
     * @return result of running.
     */
    ErrorCode runSessionWithCallBackInfo(const Session* session, const TensorCallBackWithInfo& before,
                                         const TensorCallBackWithInfo& end, bool sync = false) const;
    /**
     * @brief get input tensor for given name.
     * @param session given session.
     * @param name given name. if NULL, return first input.
     * @return tensor if found, NULL otherwise.
     */
    Tensor* getSessionInput(const Session* session, const char* name);
    /**
     * @brief get output tensor for given name.
     * @param session given session.
     * @param name given name. if NULL, return first output.
     * @return tensor if found, NULL otherwise.
     */
    Tensor* getSessionOutput(const Session* session, const char* name);
    /**
     * @brief get all output tensors.
     * @param session given session.
     * @return all output tensors mapped with name.
     */
    const std::map<std::string, Tensor*>& getSessionOutputAll(const Session* session) const;
    /**
     * @brief get all input tensors.
     * @param session given session.
     * @return all input tensors mapped with name.
     */
    const std::map<std::string, Tensor*>& getSessionInputAll(const Session* session) const;

public:
    /**
     * @brief resize given tensor.
     * @param tensor given tensor.
     * @param dims new dims. at most 6 dims.
     */
    void resizeTensor(Tensor* tensor, const std::vector<int>& dims);
    /**
     * @brief resize given tensor by nchw.
     * @param batch / N.
     * @param channel / C.
     * @param height / H.
     * @param width / W
     */
    void resizeTensor(Tensor* tensor, int batch, int channel, int height, int width);
    /**
     * @brief get backend used to create given tensor.
     * @param session given session.
     * @param tensor given tensor.
     * @return backend used to create given tensor, may be NULL.
     */
    const Backend* getBackend(const Session* session, const Tensor* tensor) const;
    /**
     * @brief get business code (model identifier).
     * @return business code.
     */
    const char* bizCode() const;

private:
    static Interpreter* createFromBufferInternal(Content* net);
    Content* mNet = nullptr;
    Interpreter(Content* net);
    Interpreter(const Interpreter&)  = delete;
    Interpreter(const Interpreter&&) = delete;
    Interpreter& operator=(const Interpreter&) = delete;
    Interpreter& operator=(const Interpreter&&) = delete;
};
} // namespace MNN
1.5 Interpreter::Interpreter
The constructor stores the Content object in the Interpreter.
Interpreter::Interpreter(Content* net) {
    MNN_ASSERT(nullptr != net);
    mNet = net;
    // Store bizcode and uuid because we need them even after `releaseModel` is called.
    mNet->bizCode = std::string(mNet->net->bizCode() ? mNet->net->bizCode()->c_str() : "");
    mNet->uuid    = std::string(mNet->net->mnn_uuid() ? mNet->net->mnn_uuid()->c_str() : "");
#ifdef MNN_INTERNAL_ENABLED
    mNet->basicLogginData = getBasicLoggingData();
    mNet->basicLogginData.emplace("ModelVersion", getModelVersion());
#endif
}
2. createRuntime
Creates the Runtime from the ScheduleConfig list. The definition of RuntimeInfo is given in 2.1: its first member stores the Runtimes created from the configs (e.g. VulkanRuntime, CUDARuntime), while its second member stores the default Runtime, normally the CPURuntime.
RuntimeInfo Interpreter::createRuntime(const std::vector<ScheduleConfig>& configs) {
    RuntimeInfo res;
    // Runtimes created from the configs are stored here
    auto& mRuntimes = res.first;
    for (auto& config : configs) {
        Backend::Info compute;
        compute.type      = Schedule::getApprociateType(config);
        compute.numThread = config.numThread;
        if (config.type == MNN_FORWARD_AUTO) {
            if (compute.type == MNN_FORWARD_OPENCL || compute.type == MNN_FORWARD_METAL) {
                // AUTO set default gpu-mode MNN_GPU_TUNING_FAST
                compute.numThread = 16;
            }
        }
        compute.user = config.backendConfig;
        if (mRuntimes.find(compute.type) == mRuntimes.end()) {
            auto newBn = RuntimeFactory::create(compute);
            if (nullptr == newBn) {
                MNN_ERROR("Can't create Runtime: %s\n", EnumNameForwardType((ForwardType)compute.type));
                continue;
            }
            mRuntimes[compute.type].reset(newBn);
        }
    }
    _getDefaultBackend(res);
    return res;
}
2.1 RuntimeInfo
typedef std::pair<std::map<MNNForwardType, std::shared_ptr<Runtime>>, std::shared_ptr<Runtime>> RuntimeInfo;
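In other words, the pair separates per-config runtimes from the fallback runtime. A small sketch of inspecting the result of createRuntime (the printf labels and the helper name are ours, not MNN's):

#include <cstdio>
#include <vector>
#include <MNN/Interpreter.hpp>

void inspectRuntimes() {
    // Request two compute paths; runtimes are deduplicated by forward type.
    std::vector<MNN::ScheduleConfig> configs(2);
    configs[0].type = MNN_FORWARD_VULKAN;
    configs[1].type = MNN_FORWARD_CPU;

    MNN::RuntimeInfo runtimeInfo = MNN::Interpreter::createRuntime(configs);

    // first: one Runtime per successfully created forward type.
    for (const auto& it : runtimeInfo.first) {
        printf("runtime created for forward type %d\n", (int)it.first);
    }
    // second: the default runtime filled in by _getDefaultBackend (normally CPU).
    printf("default runtime present: %s\n", runtimeInfo.second ? "yes" : "no");
}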
2.2 Schedule::getApprociateType
// source/core/Schedule.cpp
MNNForwardType Schedule::getApprociateType(const ScheduleConfig& config) {
    MNNForwardType type = config.type;
    // FIXME: Support Auto determine
    // Handling of MNN_FORWARD_AUTO
    if (MNN_FORWARD_AUTO == config.type) {
        // Define Auto choose priority
        std::vector<MNNForwardType> priorityList;
        priorityList.push_back(MNN_FORWARD_USER_0); // HIAI
        priorityList.push_back(MNN_FORWARD_NN);     // CoreML
        priorityList.push_back(MNN_FORWARD_USER_1); // TensorRT
        priorityList.push_back(MNN_FORWARD_CUDA);   // CUDA
        priorityList.push_back(MNN_FORWARD_OPENCL); // OpenCL
        priorityList.push_back(MNN_FORWARD_METAL);  // Metal
        priorityList.push_back(MNN_FORWARD_VULKAN); // Vulkan
        priorityList.push_back(MNN_FORWARD_CPU);    // CPU

        for (auto bn : priorityList) {
            if (MNNGetExtraRuntimeCreator(bn) != nullptr) {
                type = (MNNForwardType)bn;
                break;
            }
        }
    }
    auto creator = MNNGetExtraRuntimeCreator(type);
    if (nullptr == creator) {
        MNN_PRINT("Can't Find type=%d backend, use %d instead\n", type, config.backupType);
        type = config.backupType;
    } else {
        // TODO : Not Limited to opencl
        if (type == MNN_FORWARD_OPENCL && config.backendConfig != nullptr) {
            if (config.backendConfig->power == BackendConfig::Power_Low) {
                Backend::Info info;
                info.type = type;
                std::shared_ptr<Runtime> bn(creator->onCreate(info));
                bool isSupportLowPower = bn->onGetRuntimeStatus(RuntimeStatus::STATUS_SUPPORT_POWER_LOW);
                if (!isSupportLowPower) {
                    MNN_PRINT("type=%d backend don't Support Low Power, use %d instead\n", type, config.backupType);
                    type = config.backupType;
                }
            }
        }
    }
    return type;
}
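Seen from the caller's side, this is how type and backupType interact. A hedged sketch (which backend AUTO actually resolves to depends on how MNN was built; the helper name is ours):

#include <MNN/Interpreter.hpp>

MNN::RuntimeInfo createAutoRuntime() {
    // Let getApprociateType walk the priority list; fall back to CPU when the
    // chosen backend is missing or rejects the request.
    MNN::ScheduleConfig config;
    config.type       = MNN_FORWARD_AUTO;
    config.backupType = MNN_FORWARD_CPU;

    // Optional: with OpenCL, this Power_Low request triggers the capability check above.
    static MNN::BackendConfig backendConfig; // static so the pointer stays valid after the call
    backendConfig.power  = MNN::BackendConfig::Power_Low;
    config.backendConfig = &backendConfig;

    return MNN::Interpreter::createRuntime({config});
}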
2.2.1 MNNGetExtraRuntimeCreator
// source/core/Backend.cpp
const RuntimeCreator* MNNGetExtraRuntimeCreator(MNNForwardType type) {
    registerBackend();
    // Get the runtime creator registry
    // (type: std::map<MNNForwardType, std::pair<const RuntimeCreator*, bool>>)
    auto& gExtraCreator = GetExtraCreator();
    // Look up the runtime creator for the given forward type
    auto iter = gExtraCreator.find(type);
    if (iter == gExtraCreator.end()) {
        return nullptr;
    }
    // iter->second is a std::pair<const RuntimeCreator* creator, bool needCheck>
    if (!iter->second.second) {
        return iter->second.first;
    }
    Backend::Info info;
    info.type = type;
    std::shared_ptr<Runtime> bn(iter->second.first->onCreate(info));
    if (nullptr != bn.get()) {
        return iter->second.first;
    }
    return nullptr;
}
2.2.1.1 registerBackend
static std::once_flag s_flag;
void registerBackend() {
    std::call_once(s_flag, [&]() {
#ifdef MNN_INTERNAL_ENABLED
        LogInit();
#endif
        // Register the CPU runtime creator and core functions
        registerCPURuntimeCreator();
#ifndef MNN_BUILD_MINI
        SizeComputerSuite::init();
        // Geometry computation initialization
        GeometryComputer::init();
#endif
#if MNN_COREML_ENABLED
        registerCoreMLRuntimeCreator();
#endif
#ifdef MNN_NNAPI_ENABLED
        registerNNAPIRuntimeCreator();
#endif
#if MNN_OPENCL_ENABLED
        OpenCL::registerOpenCLRuntimeCreator();
#endif
#if MNN_METAL_ENABLED
        registerMetalRuntimeCreator();
#endif
    });
}
2.2.1.2 GetExtraCreator
static std::map<MNNForwardType, std::pair<const RuntimeCreator*, bool>>& GetExtraCreator() {
    static std::once_flag gInitFlag;
    static std::map<MNNForwardType, std::pair<const RuntimeCreator*, bool>>* gExtraCreator;
    std::call_once(gInitFlag,
                   [&]() { gExtraCreator = new std::map<MNNForwardType, std::pair<const RuntimeCreator*, bool>>; });
    return *gExtraCreator;
}
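The call_once plus heap-allocated map is a common way to get a lazily initialized, never-destroyed registry, so the static initialization and destruction order across translation units does not matter. A standalone sketch of the same idiom (all names are ours):

#include <map>
#include <mutex>
#include <string>

struct Creator { /* ... creator interface ... */ };

// Created on first use and intentionally never deleted, so creators registered
// from static initializers in other translation units are always safe to insert.
static std::map<std::string, const Creator*>& GetRegistry() {
    static std::once_flag gInitFlag;
    static std::map<std::string, const Creator*>* gRegistry = nullptr;
    std::call_once(gInitFlag, [&]() {
        gRegistry = new std::map<std::string, const Creator*>;
    });
    return *gRegistry;
}

static bool RegisterCreator(const std::string& name, const Creator* creator) {
    return GetRegistry().insert(std::make_pair(name, creator)).second;
}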
This returns the RuntimeCreator registry, from which the Runtime for a given type is created. gExtraCreator is a map whose entries are registered through MNNInsertExtraRuntimeCreator; each backend registers itself as the examples below show, and a sketch of registering a custom backend the same way follows the Metal example.
// source/core/Backend.cpp
bool MNNInsertExtraRuntimeCreator(MNNForwardType type, const RuntimeCreator* creator, bool needCheck) {
    auto& gExtraCreator = GetExtraCreator();
    if (gExtraCreator.find(type) != gExtraCreator.end()) {
        MNN_ASSERT(false && "duplicate type");
        return false;
    }
    gExtraCreator.insert(std::make_pair(type, std::make_pair(creator, needCheck)));
    return true;
}
- Vulkan registration
// source/backend/vulkan/runtime/VulkanRuntime.cpp
static bool gResistor = []() {
    MNNInsertExtraRuntimeCreator(MNN_FORWARD_VULKAN, new VulkanRuntimeCreator, false);
    return false;
}();
- CUDA registration
// source/backend/cuda/Register.cpp
static const auto __cuda_global_initializer = []() {
    MNNInsertExtraRuntimeCreator(MNN_FORWARD_CUDA, new CUDARuntimeCreator, false);
    return true;
}();
- OpenGL registration
// source/backend/opengl/GLBackend.cpp
bool placeholder = []() {
    static std::once_flag createOnce;
    std::call_once(createOnce, []() {
        MNNInsertExtraRuntimeCreator(MNN_FORWARD_OPENGL, new GLRuntimeCreator, false);
    });
    return true;
}();
- Metal registration, invoked explicitly from registerBackend (section 2.2.1.1)
// source/backend/metal/MetalBackend.mm
void registerMetalRuntimeCreator() {
    // according to
    // https://developer.apple.com/library/archive/documentation/DeviceInformation/Reference/iOSDeviceCompatibility/HardwareGPUInformation/HardwareGPUInformation.html
    // not all device with iOS 8+ supports metal.
    id<MTLDevice> device = MTLCreateSystemDefaultDevice();
    if (nil != device) {
        registerMetalOps();
#ifdef MNN_SUPPORT_RENDER
        registerMetalRenderOps();
#endif
        MNNInsertExtraRuntimeCreator(MNN_FORWARD_METAL, new MetalRuntimeCreator(device), false);
    } else {
        MNN_ERROR("Init Metal Error\n");
    }
}
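Putting the pieces together, a third-party backend can register itself with the same static-initializer idiom. A hypothetical sketch (MyRuntimeCreator is a placeholder, MNN_FORWARD_USER_2 is assumed to be an unused user slot, and the creator signatures mirror the VulkanRuntimeCreator shown in 2.3.1 below):

#include "core/Backend.hpp" // internal MNN header

namespace {

class MyRuntimeCreator : public MNN::RuntimeCreator {
public:
    virtual MNN::Runtime* onCreate(const MNN::Backend::Info& info) const override {
        // Build and return a concrete Runtime here; nullptr signals failure.
        return nullptr;
    }
    virtual bool onValid(MNN::Backend::Info& info) const override {
        return true;
    }
};

// Runs before main(): the creator is inserted into the registry so that
// MNNGetExtraRuntimeCreator(MNN_FORWARD_USER_2) can find it later.
static const bool gRegistered = []() {
    return MNN::MNNInsertExtraRuntimeCreator(MNN_FORWARD_USER_2, new MyRuntimeCreator, /*needCheck=*/false);
}();

} // namespace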
2.3 RuntimeFactory::create
Runtime* RuntimeFactory::create(const Backend::Info& info) {
    auto creator = MNNGetExtraRuntimeCreator(info.type);
    if (nullptr == creator) {
        MNN_PRINT("Create Runtime Failed because no creator for %d\n", info.type);
        return nullptr;
    }
    // Dispatch to the concrete RuntimeCreator, e.g. VulkanRuntimeCreator, MetalRuntimeCreator, GLRuntimeCreator
    auto runtime = creator->onCreate(info);
    if (nullptr == runtime) {
        MNN_PRINT("Create Runtime failed, the creator return nullptr, type = %d\n", info.type);
    }
    return runtime;
}
2.3.1 VulkanRuntimeCreator
If info.type in RuntimeFactory::create is MNN_FORWARD_VULKAN, then creator->onCreate(info) actually calls VulkanRuntimeCreator::onCreate.
// source/backend/vulkan/runtime/VulkanRuntime.cpp
class VulkanRuntimeCreator : public RuntimeCreator {
public:
    virtual Runtime* onCreate(const Backend::Info& info) const {
        // Initialize the Vulkan library and resolve the API functions
        if (InitVulkan()) {
            if (_testVulkan()) {
                // Create the Vulkan runtime
                return new VulkanRuntime(info);
            }
        }
        return nullptr;
    }
    virtual bool onValid(Backend::Info& info) const {
        return true;
    }
};

static bool gResistor = []() {
    MNNInsertExtraRuntimeCreator(MNN_FORWARD_VULKAN, new VulkanRuntimeCreator, false);
    return false;
}();
}
2.3.2 CPURuntimeCreator
If info.type in RuntimeFactory::create is MNN_FORWARD_CPU, then creator->onCreate(info) actually calls CPURuntimeCreator::onCreate.
class CPURuntimeCreator : public RuntimeCreator {
public:
    virtual Runtime* onCreate(const Backend::Info& info) const override {
        return new CPURuntime(info);
    }
};

#ifdef MNN_SUPPORT_BF16
extern void registerBF16Backend();
#endif
#ifdef ENABLE_ARMV82
extern void registerArm82RuntimeCreator();
#endif
void registerCPURuntimeCreator() {
    CPUBackend::initCreatorMap();
    registerCPUOps();
#ifdef MNN_SUPPORT_BF16
    registerBF16Backend();
#endif
#ifdef MNN_USE_ARMV82
    registerArm82RuntimeCreator();
#endif
    // TODO: Merge _initCoreFunction MNNFunctionInit and cpuinfo_arm_init
    MNNCoreFunctionInit();
    MNNInsertExtraRuntimeCreator(MNN_FORWARD_CPU, new CPURuntimeCreator);
};
createSession
runSession