cuda从零开始手搓PB神经网络

cuda实现PB神经网络


基于上一篇的矩阵点乘,实现了矩阵的加减乘除、函数调用等。并且复用之前元编程里面写的梯度下降、Adam、NAdam优化方法。实现PB神经网络如下:

#ifndef __BP_NETWORK_HPP__
#define __BP_NETWORK_HPP__
#include "matrix.hpp"
#include "mat.hpp"
#include "update_methods.hpp"template<typename activate_type, typename val_type_, template<typename> class update_type_tpl, typename init_type, int input_num_, int output_num_, int ... remain_layer>
struct bp_network
{constexpr static int input_num = input_num_;constexpr static int output_num = output_num_;using val_type = val_type_;using input_type = mat<input_num, 1, val_type>;using input_t_type = mat<1, input_num, val_type>;using output_type = mat<output_num, 1, val_type>;using weight_type = mat<output_num, input_num, val_type>;using forward_func = typename func_pair<activate_type>::forward_func;using backward_func = typename func_pair<activate_type>::backward_func;using next_node_type = typename bp_network<activate_type, val_type, update_type_tpl, init_type, output_num, remain_layer...>;using term_output_type = typename next_node_type::term_output_type;weight_type weight;update_type_tpl<weight_type> weight_update_method;output_type bias;update_type_tpl<output_type> bias_update_method;input_type pre_input;output_type pre_func_input;next_node_type next_node;bp_network():weight_update_method(), bias_update_method(){weight.template reset<init_type>();bias.template reset<init_type>();next_node = bp_network<activate_type, val_type, update_type_tpl, init_type, output_num, remain_layer...>();}auto forward(input_type& input){output_type curr_output;pre_input = input;auto temp = weight.dot(input);pre_func_input = temp + bias;curr_output = pre_func_input.template activate<forward_func>();return next_node.forward(curr_output);}auto backward(term_output_type& delta, val_type lr){output_type curr_delta = next_node.backward(delta, lr);curr_delta = pre_func_input.template activate<backward_func>() * curr_delta;auto ret = weight.t_dot(curr_delta);// 更新参数weight_type delta_weight = curr_delta.dot(pre_input.t());weight = weight_update_method.update(weight, delta_weight);bias = bias_update_method.update(bias, curr_delta);return ret;}   // 更新惯性量void update_inert(){weight_update_method.update_inert();bias_update_method.update_inert();next_node.update_inert();}void print(){weight.print();printf("-----------------\n");bias.print();printf("=================\n");next_node.print();}
};template<typename activate_type, typename val_type_, template<typename> class update_type_tpl, typename init_type, int input_num_, int output_num_>
struct bp_network<activate_type, val_type_, update_type_tpl, init_type, input_num_, output_num_>
{constexpr static int input_num = input_num_;constexpr static int output_num = output_num_;using val_type = val_type_;using input_type = mat<input_num, 1, val_type>;using input_t_type = mat<1, input_num, val_type>;using output_type = mat<output_num, 1, val_type>;using weight_type = mat<output_num, input_num, val_type>;using forward_func = typename func_pair<activate_type>::forward_func;using backward_func = typename func_pair<activate_type>::backward_func;using term_output_type = typename output_type;using weight_update_type = typename update_type_tpl<weight_type>;using bias_update_type = typename update_type_tpl<output_type>;weight_type weight;weight_update_type weight_update;output_type bias;bias_update_type bias_update;output_type pre_func_input;input_type pre_input;bp_network():weight_update(), bias_update(){weight.template reset<init_type>();bias.template reset<init_type>();}auto forward(input_type& input){pre_input = input;auto temp = weight.dot(input);pre_func_input = temp + bias;return pre_func_input.template activate<forward_func>();}auto backward(output_type& delta, val_type lr){output_type curr_delta = pre_func_input.template activate<backward_func>() * delta;auto ret = weight.t_dot(curr_delta);// 更新参数weight_type delta_weight = curr_delta.dot(pre_input.t());weight = weight_update.update(weight, delta_weight);bias = bias_update.update(bias, curr_delta);return ret;}void update_inert(){weight_update.update_inert();bias_update.update_inert();}void print(){weight.print();printf("-----------------\n");bias.print();printf("*****************\n");}
};#endif

下面实验一下我们的bp神经网络。

#include <chrono>
#include <thread>
#include "matrix.hpp"
#include "bp_network.hpp"
int main()
{constexpr int row_num = 32;constexpr int adj_num = 32;constexpr int col_num = 32;/*matrix_device_proxy<row_num, adj_num, double> A;eyes(A(), 2.0);matrix_device_proxy<adj_num, col_num, double> B;eyes(B(), 1.0);matrix_device_proxy<row_num, col_num, double> C;mat_dot<sigmoid>(A(), B(), C());print(type_cast(C()));auto A = mat<row_num, adj_num, double>::eyes(2.0);auto B = mat<adj_num, col_num, double>::eyes(1.0);auto C = A.dot(B);C = C + 1.0;C = sqrtl(C);C = C - 2.0;C = C * 3.0;C = C / 4.0;C.print();std::cout << "---------- D ----------" << std::endl;auto D = mat<row_num, col_num, double>::xavier_gaussian();D.print();std::cout << "---------- E ----------" << std::endl;auto E = mat<row_num, col_num, double>::xavier_mean();E.print();std::cout << "---------- F ----------" << std::endl;auto F = mat<row_num, col_num, double>::he_gaussian();F.print();std::cout << "---------- G ----------" << std::endl;auto G = mat<row_num, col_num, double>::he_mean();G.print();*/bp_network<sigmoid, double, nadam, xavier_gaussian_type, row_num, adj_num, col_num> node;auto input = mat<row_num, 1, double>::ones(0.2);auto expect = mat<col_num, 1, double>::ones(0.4);int times = 8000;int update_inert_times = 100;int step = times / update_inert_times;// 计时开始auto start = std::chrono::high_resolution_clock::now();for (int i = 0; i < times; ++i){auto output = node.forward(input);auto delta = (output - expect);node.backward(delta, 0.001);if (i == times - 1){output.t().print();}if (i % step == 0 && i != 0){node.update_inert();}}// 计时结束// 获取结束时间点auto end = std::chrono::high_resolution_clock::now();// 计算持续时间std::chrono::duration<double> duration = end - start;// 输出执行时间std::cout << "Execution time: " << duration.count() << " seconds" << std::endl;//node.print();cudaDeviceReset();return 0;
}

以上代码有个学习率lr没有地方设置哈,将来优化,见谅。执行结果如下:
在这里插入图片描述
可以看出,经过8000次的训练,这个使用sigmoid激活函数、NAdam优化、Xavier-Gaussian初始化的323232的PB能够将误差缩减到0.0001这个量级,而训练时间仅为8.54秒。还是相当给力的。
虽然这对于我的工作没有任何关系,但是我还是想搞一下。毕竟“越是没用的知识就越有用,越是有用的东西就越没用”。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/66381.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

我的世界-与门、或门、非门等基本门电路实现

一、红石比较器 (1) 红石比较器结构 红石比较器有前端单火把、后端双火把以及两个侧端 其中后端和侧端是输入信号,前端是输出信号 (2) 红石比较器的两种模式 比较模式 前端火把未点亮时处于比较模式 侧端>后端 → 0 当任一侧端强度大于后端强度时,输出…

项目开发实践——基于SpringBoot+Vue3实现的在线考试系统(七)

文章目录 一、题库管理模块实现1、新增题目功能实现1.1 页面设计1.2 前端功能实现1.3 后端功能实现1.4 效果展示2、题目列表功能实现2.1 页面设计2.2 前端功能实现2.3 后端功能实现2.3.1 后端查询题目列表接口实现2.3.2 后端编辑试题接口实现2.4 效果展示二、代码下载一、题库管…

打破编程“鄙视链”:探索行业发展新路径

前言&#xff1a;哈喽&#xff0c;大家好&#xff0c;今天给大家分享一篇文章&#xff01;并提供具体代码帮助大家深入理解&#xff0c;彻底掌握&#xff01;创作不易&#xff0c;如果能帮助到大家或者给大家一些灵感和启发&#xff0c;欢迎收藏关注哦 &#x1f495; 目录 打破…

【统计的思想】假设检验(一)

假设检验是统计学里的重要方法&#xff0c;同时也是一种“在理想与现实之间观察求索”的测试活动。假设检验从概率的角度去考察理想与现实之间的关系&#xff0c;籍此来缓解测试可信性问题。 我们先来看一个例子。民航旅客服务系统&#xff0c;简称PSS系统&#xff0c;有一种业…

SpringBoot2 + Flowable(UI)

文章目录 引言I 技术栈软件架构基于 Vue.js 和 Element UI 的后台管理系统工程结构II 依赖rest,logic,conf 的依赖工作流flowable jar包flowable-ui所需jar包III 配置jdbc 配置 nullCatalogMeansCurrent = true引言 I 技术栈 软件架构 前端基于vue 、element-ui框架分模块设…

Linux探秘坊-------3.开发工具详解(1)

1 初识vim编辑器 创建第一个vim编辑的代码 1.新建文件 2.使用vim打开 3.打开默认是命令模式&#xff0c;写代码需要在屏幕上输出“i”字符 1.写完代码后要按Esc键退出到指令模式2.再按shift:wq即可保存并退出vim &#xff08;因为不支持鼠标&#xff0c;通常 使用键盘上的箭…

基于海思soc的智能产品开发(高、中、低soc、以及和fpga的搭配)

【 声明&#xff1a;版权所有&#xff0c;欢迎转载&#xff0c;请勿用于商业用途。 联系信箱&#xff1a;feixiaoxing 163.com】 市场上关于图像、音频的soc其实非常多&#xff0c;这里面有高、中、低档&#xff0c;开发方式也不相同。之所以会这样&#xff0c;有价格的因素&am…

51单片机——DS18B20温度传感器

由于DS18B20数字温度传感器是单总线接口&#xff0c;所以需要使用51单片机的一个IO口模拟单总线时序与DS18B20通信&#xff0c;将检测的环境温度读取出来 1、DS18B20模块电路 传感器接口的单总线管脚接至单片机P3.7IO口上 2、DS18B20介绍 2.1 DS18B20外观实物图 管脚1为GN…

STL容器-- list的模拟实现(附源码)

STL容器-- list的模拟实现&#xff08;附源码&#xff09; List的实现主要时考察我们对list这一容器的理解&#xff0c;和代码的编写能力&#xff0c;通过上节对list容器的使用&#xff0c;我们对list容器已经有了一些基本的了解&#xff0c;接下来就让我们来实现一些list容器常…

Lynx TiDB 慢日志收集工具

作者&#xff1a; 小龙虾爱大龙虾 原文来源&#xff1a; https://tidb.net/blog/7247e68f 简介 lynx 工具可以定时将 TiDB 集群的慢查询收集并持久化到后端数据库中&#xff0c;然后通过 grafana 查询展示出来&#xff0c;这可以帮助我们更好的分析慢查询日志。 背景 尽管…

Gin 源码概览 - 路由

本文基于gin 1.1 源码解读 https://github.com/gin-gonic/gin/archive/refs/tags/v1.1.zip 1. 注册路由 我们先来看一段gin代码&#xff0c;来看看最终得到的一颗路由树长啥样 func TestGinDocExp(t *testing.T) {engine : gin.Default()engine.GET("/api/user", f…

【逆境中绽放:万字回顾2024我在挑战中突破自我】

&#x1f308;个人主页: Aileen_0v0 &#x1f525;热门专栏: 华为鸿蒙系统学习|计算机网络|数据结构与算法 ​&#x1f4ab;个人格言:“没有罗马,那就自己创造罗马~” 文章目录 一、引言二、个人成长与盘点情感与心理成长学习与技能提升其它荣誉 三、年度创作历程回顾创作内容概…

【Linux 重装】Ubuntu 启动盘 U盘无法被识别,如何处理?

背景 U盘烧录了 Ubuntu 系统作为启动盘&#xff0c;再次插入电脑后无法被识别 解决方案&#xff08;Mac 适用&#xff09; &#xff08;1&#xff09;查找 USB&#xff0c;&#xff08;2&#xff09;格式化&#xff08;1&#xff09;在 terminal 中通过 diskutil list 查看是…

中职网络建设与运维ansible服务

ansible服务 填写hosts指定主机范围和控制节点后创建一个脚本&#xff0c;可以利用简化脚本 1. 在linux1上安装系统自带的ansible-core,作为ansible控制节点,linux2-linux7作为ansible的受控节点 Linux1 Linux1-7 Yum install ansible-core -y Vi /etc/ansible/hosts 添加…

数据库服务体系结构

1. 数据库服务应用配置 服务进行配置有什么作用&#xff1f; 实现服务运行启动 实现某些功能 应用配置有三种方式&#xff1f; 利用编译安装进行配置 编写配置文件信息 ,.默认的配置文件: /etc/my.cnf 利用启动命令参数配置信息&#xff0c;mysqld_safe --skip-grant-tables --…

Langchain+FastApi+Vue前后端Ai对话(超详细)

一、引入 首先可以先看下作者的文章 FastApi相关文章&#xff1a;创建最简单FastApi的项目Vue相关文章&#xff1a;最简单的aixos二次封装Langchain相关文章&#xff1a;如何使用LangSmith跟踪deepseek模型 二、后端搭建 1 项目文件结构 routers&#xff1a;存放api接口se…

简历_使用优化的Redis自增ID策略生成分布式环境下全局唯一ID,用于用户上传数据的命名以及多种ID的生成

系列博客目录 文章目录 系列博客目录WhyRedis自增ID策略 Why 我们需要设置全局唯一ID。原因&#xff1a;当用户抢购时&#xff0c;就会生成订单并保存到tb_voucher_order这张表中&#xff0c;而订单表如果使用数据库自增ID就存在一些问题。 问题&#xff1a;id的规律性太明显、…

Jira中bug的流转流程

Jira中bug的状态 1. 处理Bug的流程2. bug状态流转详述bug的状态通常包括 1. 处理Bug的流程 2. bug状态流转详述 bug的状态通常包括 未解决 1. 测试人员创建一个bug&#xff0c;填写bug的详细信息&#xff0c;如概要、bug级别、复现步骤、现状、预期结果等 2. 定位bug&#x…

Linux的几个基本指令

文章目录 一、几个基本指令1、ls 指令注意&#xff01; 2、pwd命令3、touch 指令4、mkdir 指令注意&#xff01;注意&#xff01; 5、cd 指令注意&#xff01; 6、cp 指令 今天我们学习Linux下的几个基本指令&#xff0c;本篇是在Xshell环境下执行的。 一、几个基本指令 1、…

数据库基础练习1(创建表,设置外键,检查,不为空,主键等约束)安装mysql详细步骤

安装MySQL详细步骤 1. 下载 MySQL 安装程序 访问 MySQL 官方网站&#xff1a;MySQL Downloads。在下载页面&#xff0c;选择 "MySQL Community (GPL) Downloads"。在 "MySQL Community Server" 部分&#xff0c;根据你的操作系统&#xff08;Windows&…