教程 | 如何利用C++搭建个人专属的TensorFlow

在开始之前,首先看一下最终成型的代码:

  1. 分支与特征后端(https://github.com/OneRaynyDay/autodiff/tree/eigen)
  2. 仅支持标量的分支(https://github.com/OneRaynyDay/autodiff/tree/master)

这个项目是我与 Minh Le 一起完成的。

为什么?

如果你修习的是计算机科学(CS)的人的话,你可能听说过这个短语「不要自己动手____」几千次了。它包含了加密、标准库、解析器等等。我想到现在为止,它也应该包含了机器学习库(ML library)。

不管现实是怎么样的,这个震撼的课程都值得我们去学习。人们现在把 TensorFlow 和类似的库当作理所当然了。他们把它看作黑盒子并让它运行起来,但是并没有多少人知道在这背后的运行原理。这只是一个非凸(Non-convex)的优化问题!请停止对代码无意义的胡搞——仅仅只是为了让代码看上去像是正确的。
创一个小群,供大家学习交流聊天
如果有对学C++方面有什么疑惑问题的,或者有什么想说的想聊的大家可以一起交流学习一起进步呀。
也希望大家对学C++能够持之以恒
C++爱好群,
如果你想要学好C++最好加入一个组织,这样大家学习的话就比较方便,还能够共同交流和分享资料,给你推荐一个学习的组织:快乐学习C++组织 可以点击组织二字,可以直达请添加链接描述

教程 | 如何利用C++搭建个人专属的TensorFlow

TensorFlow

在 TensorFlow 的代码里,有一个重要的组件,允许你将计算串在一起,形成一个称为「计算图」的东西。这个计算图是一个有向图 G=(V,E),其中在某些节点处 u1,u2,…,un,v∈V,和 e1,e2,…,en∈E,ei=(ui,v)。我们知道,存在某种计算图将 u1,…,un 映射到 vv。

举个例子,如果我们有 x + y = z,那么 (x,z),(y,z)∈E。

这对于评估算术表达式非常有用,我们能够在计算图的汇点下找到结果。汇点是类似 v∈V,∄e=(v,u) 这样的顶点。从另一方面来说,这些顶点从自身到其他顶点并没有定向边界。同样的,输入源是 v∈V,∄e=(u,v)。

对于我们来说,我们总是把值放在输入源上,而值也将传播到汇点上。

反向模式求微分

如果你觉得我的解释不正确,可以参考下这些幻灯片的说明。

微分是 Tensorflow 中许多模型的核心需求,因为我们需要它来运行梯度下降。每一个从高中毕业的人都应该知道微分的意思。如果是基于基础函数组成的复杂函数,则只需要求出函数的导数,然后应用链式法则。

超级简洁的概述

如果我们有一个像这样的函数:

对 x 求导:

对 y 求导:

其它的例子:

其导数是:

所以其梯度是:

链式法则,例如应用于 f(g(h(x))):

在 5 分钟内倒转模式

所以现在请记住我们运行计算图时用的是有向无环结构(DAG/Directed Acyclic Graph),还有上一个例子用到的链式法则。正如下方所示的形式:

x -> h -> g -> f

作为一个图,我们能够在 f 获得答案,然而,也可以反过来:

dx <- dh <- dg <- df

这样它看起来就像链式法则了!我们需要沿着路径把导数相乘以得到最终的结果。这是一个计算图的例子:

这就将其简化为一个图的遍历问题。有谁察觉到了这就是拓扑排序和深度优先搜索/宽度优先搜索?

没错,为了在两种路径都支持拓扑排序,我们需要包含一套父组一套子组,而汇点是另一个方向的来源。反之亦然。

执行

在开学前,Minh Le 和我开始设计这个项目。我们决定使用特征库后端(Eigen library backend)进行线性代数运算,这个库有一个叫做 MatrixXd 的矩阵类,用在我们的项目中:

class var {// Forward declarationstruct impl;public:
// For initialization of new vars by ptr var(std::shared_ptr<impl>);

var(double);
var(const MatrixXd&);
var(op_type, const std::vector<var>&);    
...// Access/Modify the current node value    MatrixXd getValue() const;
void setValue(const MatrixXd&);
op_type getOp() const;
void setOp(op_type);// Access internals (no modify)    std::vector<var>& getChildren() const;
std::vector<var> getParents() const;
...private: 
// PImpl idiom requires forward declaration of the class:    std::shared_ptr<impl> pimpl;};struct var::impl{public:
impl(const MatrixXd&);
impl(op_type, const std::vector<var>&);
MatrixXd val;
op_type op; 
std::vector<var> children;
std::vector<std::weak_ptr<impl>> parents;};

在这里,我们使用了一个叫「pImpl」的语法,意思是「执行的指针」。它有很多用途,比如接口的解耦实现,以及当在堆栈上有一个本地接口时实例化内存堆上的东西。「pImpl」的一些副作用是微弱的减慢运行时间,但是编译时间缩短了很多。这允许我们通过多个函数调用/返回来保持数据结构的持久性。像这样的树形数据结构应该是持久的。

我们有一些枚举来告诉我们目前正在进行哪些操作:

enum class op_type {
plus,
minus,
multiply,
divide,
exponent,
log,
polynomial,
dot,
...
none // no operators. leaf.};

执行此树的评估的实际类称为 expression:

class expression {public:
expression(var);
...
// Recursively evaluates the tree. double propagate();
...
// Computes the derivative for the entire graph. // Performs a top-down evaluation of the tree. void backpropagate(std::unordered_map<var, double>& leaves);
... private:
var root;};

在反向传播里,我们的代码能做类似以下所示的事情:

backpropagate(node, dprev):
derivative = differentiate(node)*dprev
for child in node.children:
backpropagate(child, derivative)

这几乎就是在做一个深度优先搜索(DFS),你发现了吗?

为什么是 C++?

在实际过程中,C++可能并不适合做这类事情。我们可以在像「Oaml」这样的函数式语言中花费更少的时间开发。现在我明白为什么「Scala」被用于机器学习中,主要就是因为「Spark」。然而,使用 C++有很多好处。

Eigen(库名)

举例来说,我们可以直接使用一个叫「Eigen」的 TensorFlow 的线性代数库。这是一个不假思索就被人用烂了的线性代数库。有一种类似于我们的表达式树的味道,我们构建表达式,它只会在我们真正需要的时候进行评估。然而,使用「Eigen」在编译的时间内就能决定什么时候使用模版,这意味着运行的时间减少了。我对写出「Eigen」的人抱有很大的敬意,因为查看模版的错误几乎让我眼瞎!

他们的代码看起来类似这样的:

Matrix A(...), B(...);
auto lazy_multiply = A.dot(B);
typeid(lazy_multiply).name(); // the class name is something like Dot_Matrix_Matrix.
Matrix(lazy_multiply); // functional-style casting forces evaluation of this matrix.

这个特征库非常的强大,这就是它作为 TensortFlow 主要后端之一的原因,即除了这个慵懒的评估技术之外还有其它的优化。

运算符重载

在 Java 中开发这个库很不错——因为没有 shared_ptrs、unique_ptrs、weak_ptrs;我们得到了一个真实的,有用的图形计算器(GC=Graphing Calculator)。这大大节省了开发时间,更不必说更快的执行速度。然而,Java 不允许操作符重载,因此它们不能这样:

// These 3 lines code up an entire neural network!
var sigm1 = 1 / (1 + exp(-1 dot(X, w1)));
var sigm2 = 1 / (1 + exp(-1
dot(sigm1, w2)));
var loss = sum(-1 (y log(sigm2) + (1-y) * log(1-sigm2)));

顺便说一下,上面是实际使用的代码。是不是非常的漂亮?我想说的是这甚至比 TensorFlow 里的 Python 封装还更优美!我只是想表明,它们也是矩阵。

在 Java 中,有一连串的 add(), divide() 等等是非常难看的。更重要的是,这将让用户更多的关注在「PEMDAS」上,而 C++的操作符则有非常好的表现。

特征,而不是一连串的故障

在这个库中,可以确定的是,TensorFlow 没有定义清晰的 API,或者有但我不知道。例如,如果我们只想训练一个特定子集的权重,我们可以只对我们感兴趣的特定来源做反向传播。这对于卷积神经网络的迁移学习非常有用,因为很多时候,像 VGG19 这样的大型网络可以被截断,然后附加一些额外的层,这些层的权重使用新领域的样本来训练。

基准

在 Python 的 TensorFlow 库中,对虹膜数据集进行 10000 个「Epochs」的训练以进行分类,并使用相同的超参数,我们有:

1.TensorFlow 的神经网络: 23812.5 ms
2.「Scikit」的神经网络:22412.2 ms
3.「Autodiff」的神经网络,迭代,优化:25397.2 ms
4.「Autodiff」的神经网络,迭代,无优化:29052.4 ms
5.「Autodiff」的神经网络,带有递归,无优化:28121.5 ms

令人惊讶的是,Scikit 是所有这些中最快的。这可能是因为我们没有做庞大的矩阵乘法。也可能是 TensorFlow 需要额外的编译步骤,如变量初始化等等。或者,也许我们不得不在 python 中运行循环,而不是在 C 中(Python 循环真的非常糟糕!)我自己也不是很确定。我完全明白这绝不是一种全面的基准测试,因为它只在特定的情况下应用了单个数据点。然而,这个库的表现并不能代表当前最佳,所以希望各位读者和我们共同完善

转载于:https://blog.51cto.com/14209412/2354021

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/536906.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

docker kali安装mysql_kali安装docker(有效详细的教程) ——vulhub漏洞复现 001

前记&#xff1a;博主有着多次安装docker的丰富经验&#xff0c;曾经为了在kali成功安装docker花费不少时间。在kali2016.3一直到最新的kali2019.4都通吃&#xff01;所以跟着下面的步骤走&#xff0c;绝对不会出错。(该机子此前没装过docker&#xff0c;并且配置好了kali更新源…

PDF文件如何转成markdown格式

百度上根据pdf转makrdown为关键字进行搜索&#xff0c;结果大多数是反过来的转换&#xff0c;即markdown文本转PDF格式。 但是PDF转markdown的解决方案很少。 正好我工作上有这个需求&#xff0c;所以自己实现了一个解决方案。 下图是一个用PDF XChange Editor打开的PDF文件&am…

kangle支不支持PHP_【转载】PHP调用kangle的API

摘要&#xff1a;根据管理的API公布写了一个类封装了一个操作集合&#xff0c;这是一个kangleAPI的一个封...根据管理的API公布写了一个类封装了一个操作集合&#xff0c;这是一个kangleAPI的一个封装吧&#xff0c;是在其他地方看到的&#xff0c;接口包含获取easypanel的信息…

ES6 学习笔记(一)let,const和解构赋值

let和const let和const是es6新增的两个变量声明关键字&#xff0c;与var的不同点在于&#xff1a; &#xff08;1&#xff09;let和const都是块级作用域&#xff0c;在{}内有效&#xff0c;这点在for循环中非常有用&#xff0c;只在循环体内有效。var为函数作用域。 &#xff0…

mysql数据库容量和性能_新品速递丨容量盘性能提升超 300%,数据库支持 MySQL 8.0...

2关系型数据库 MySQL Plus支持 MySQL 8.0 内核及 XtraBackup 物理在线迁移方式关系型数据库服务 MySQL Plus 发布新版本 1.0.6 &#xff0c; 新增多项功能&#xff0c;提升了集群自动化运维能力。主要升级有&#xff1a;- 支持 MySQL 8.0 内核&#xff1a;根据官方测试&#xf…

10. Python面向对象

Python从设计之初就已经是一门面向对象的语言&#xff0c;正因为如此&#xff0c;在Python中创建一个类和对象是很容易的。如果接触过java语言同学应该都知道&#xff0c;Java面向对象三大特征是&#xff1a;封装、继承、多态。Python面向对象也有一些特征&#xff0c;接下来我…

mysql聚簇索引 和主键的区别_[MySQL] innoDB引擎的主键与聚簇索引

MysqL的innodb引擎本身存储的形式就必须是聚簇索引的形式,在磁盘上树状存储的,但是不一定是根据主键聚簇的,有三种情形:1. 有主键的情况下,主键就是聚簇索引2. 没有主键的情况下,第一个非空null的唯一索引就是聚簇索引3. 如果上面都没有,那么就是有一个隐藏的row-id作为聚簇索引…

前端页面:一直报Cannot set property 'height' of undefined

1、出现错误的例子&#xff0c;只拷贝了项目中关键出现问题的部分 例子中明明写了styleheight:16px这个属性&#xff0c;但是为什么还说height未定义呢 通过打印发现&#xff1a;cks.each(function () { autoTextAreaHeight($(this)); });中的$(this)取出来…

mysql表在线转成分区表_11g普通表在线转换分区表

本帖最后由 灯和树 于 2016-5-4 14:58 编辑由于业务系统数据量增大&#xff0c;对其用户表在线完成分区表转换过程&#xff0c;记录如下&#xff0c;11g数据库支持。创建过渡分区表根据USER_ID创建分区表CREATE TABLE SDP_SMECD.TEST_T_USER_ID(USER_ID NUMBER(13) …

tiger4444/rabbit4444后缀勒索病毒怎么删除 能否百分百恢复

上海某客户中了tiger4444的勒索病毒&#xff0c;找到我们后&#xff0c;一天内全部恢复完成。说了很多关于勒索病毒的事情&#xff0c;也提醒过大家&#xff0c;可总是有人疏忽&#xff0c;致使中招后&#xff0c;丢钱丢面子&#xff0c;还丢工作。 那么要怎样预防呢与处理呢&a…

mysql远程一会不用卡住_连接远程MySQL数据库项目启动时,不报错但是卡住不继续启动的,...

连接远程MySQL数据库项目启动时&#xff0c;不报错但是卡住不继续启动的&#xff0c;2018-03-12 17:08:52.532DEBUG[localhost-startStop-1]o.s.beans.factory.support.DefaultListableBeanFactory.doGetBean():251 -Returning cached instance of singleton bean ‘org.spring…

GPT-5、开源、更强的ChatGPT!

年终岁尾&#xff0c;正值圣诞节热闹气氛的OpenAI写下了2024年的发展清单。 OpenAI联合创始人兼首席执行官Sam Altman在社交平台公布&#xff0c;AGI&#xff08;稍晚一些&#xff09;、GPT-5、更好的语音模型、更高的费率限制&#xff1b; 更好的GPTs&#xff1b;更好的推理…

CentOS_7 安装MySql5.7

2019独角兽企业重金招聘Python工程师标准>>> 下载mysql的源 wget http://dev.mysql.com/get/mysql57-community-release-el7-7.noarch.rpm 安装yum库 yum localinstall -y mysql57-community-release-el7-7.noarch.rpm 安装MySQL yum install -y mysql-community-…

python查询mysql decimal报错_python读取MySQL数据表时,使用ast模块转换decimal格式数据的坑...

概述MySQL中常用的数据格式有tinyint()、int()、float()、double()、decimal() 、varchar、enum()、datetime;小数格式中decimal比较常用&#xff0c;因为更加精确&#xff0c;这里就以decimal为例。从MySQL中读取了一行数据&#xff0c;内容为&#xff1a;(17479, datetime.da…

性能测试总结(一)---基础理论篇(转载)

随着软件行业的快速发展&#xff0c;现代的软件系统越来越复杂&#xff0c;功能越来越多&#xff0c;测试人员除了需要保证基本的功能测试质量&#xff0c;性能也随越来越受到人们的关注。但是一提到性能测试&#xff0c;很多人就直接连想到Loadrunner。认为LR就等于性能测试&a…

java listen_JavaWeb之Filter、Listener

昨天和大家介绍了一下JSON的用法&#xff0c;其实JSON中主要是用来和数据库交互数据的。今天给大家讲解的是Filter和Listener的用法。一、Listenner监听器1.1、定义Javaweb中的监听器是用于监听web常见对象HttpServletRequest,HttpSession,ServletContext。1.2、监听器的作用监…

BFC的概念及作用

在了解什么是BFC之前&#xff0c;首先得明白什么是Box , Formatting Context &#xff08;一个决定如何渲染文档的容器&#xff09;的概念 Box: CSS布局的基本单位 Box是 CSS 布局的对象和基本单位&#xff0c; 直观点来说&#xff0c; 就是一个页面是由很多个 Box组成的&#…

bitcount java_java-Long.bitCount()如何找到设置的位数?

让我们以255为例.我们将这些位组合在一起.首先,我们从255开始,为0b1111.1111(二进制为8 1)第一行代码是&#xff1a;i i - ((i > > > 1) & 0x5555555555555555L);这条线正在梳理每对1.由于我们有8个1,所以我们希望组合成对,并得到2,2,2,2之类的东西.由于它是二进…

Luogu P2463 [SDOI2008]Sandy的卡片

题目链接 \(Click\) \(Here\) 真的好麻烦啊。。事实证明&#xff0c;理解是理解&#xff0c;一定要认认真真把板子打牢&#xff0c;不然调锅的时候真的会很痛苦。。&#xff08;最好是八分钟能无脑把\(SA\)码对的程度\(QAQ\)&#xff09; 这个题最开始我想的是\(RMQ\)遍历每一个…

java log输出到文件路径_Java - 配置log4j的日志文件路径 (附-获取当前类路径的多种方法)...

1 日志路径带来的痛点Java 项目中少不了要和log4j等日志框架打交道, 开发环境和生产环境下日志文件的输出路径总是不一致, 设置为绝对路径的方式缺少了灵活性, 每次变更项目路径都要修改文件, 目前想到的最佳实现方式是: 根据项目位置自动加载并配置文件路径.本文借鉴 Tomcat 的…