【单层神经网络】基于MXNet的线性回归实现(底层实现)

写在前面

  1. 刚开始先从普通的寻优算法开始,熟悉一下学习训练过程
  2. 下面将使用梯度下降法寻优,但这大概只能是局部最优,它并不是一个十分优秀的寻优算法

整体流程

  1. 生成训练数据集(实际工程中,需要从实际对象身上采集数据)
  2. 确定模型及其参数(输入输出个数、阶次,偏置等)
  3. 确定学习方式(损失函数、优化算法,学习率,训练次数,终止条件等)
  4. 读取数据集(不同的读取方式会影响最终的训练效果)
  5. 训练模型

完整程序及注释

from IPython import display
from matplotlib import pyplot as plt
from mxnet import autograd, nd
import random'''
获取(生成)训练集
'''
input_num = 2				# 输入个数
examples_num = 1000			# 生成样本个数
# 确定真实模型参数
real_W = [10.9, -8.7]		
real_bias = 6.5	features = nd.random.normal(scale=1, shape=(examples_num, input_num))       # 标准差=1,均值缺省=0
labels = real_W[0]*features[:,0] + real_W[1]*features[:,1] + real_bias		# 根据特征和参数生成对应标签
labels_noise = labels + nd.random.normal(scale=0.1, shape=labels.shape)		# 为标签附加噪声,模拟真实情况# 绘制标签和特征的散点图(矢量图)
# def use_svg_display():
#     display.set_matplotlib_formats('svg')# def set_figure_size(figsize=(3.5,2.5)):
#     use_svg_display()
#     plt.rcParams['figure.figsize'] = figsize# set_figure_size()
# plt.scatter(features[:,0].asnumpy(), labels_noise.asnumpy(), 1)
# plt.scatter(features[:,1].asnumpy(), labels_noise.asnumpy(), 1)
# plt.show()# 创建一个迭代器(确定从数据集获取数据的方式)
def data_iter(batch_size, features, labels):num = len(features)indices = list(range(num))                                  # 生成索引数组random.shuffle(indices)                                     # 打乱indices# 该遍历方式同时确保了随机采样和无遗漏for i in range(0, num, batch_size):j = nd.array(indices[i: min(i+batch_size, num)])        # 对indices从i开始取,取batch_size个样本,并转换为列表yield features.take(j), labels.take(j)                  # take方法使用索引数组,从features和labels提取所需数据"""
训练的基础准备
"""
# 声明训练变量,并赋高斯随机初始值
w = nd.random.normal(scale=0.01, shape=(input_num))
b = nd.zeros(shape=(1,))
# b = nd.zeros(1)       # 不同写法,等价于上面的
w.attach_grad()         # 为需要迭代的参数申请求梯度空间
b.attach_grad()# 定义模型
def linreg(X, w, b):return nd.dot(X,w)+b# 定义损失函数
def squared_loss(y_hat, y):return (y_hat - y.reshape(y_hat.shape)) **2 /2# 定义寻优算法
def sgd(params, learning_rate, batch_size):for param in params:# 新参数 = 原参数 - 学习率*当前批量的参数梯度/当前批量的大小param[:] = param - learning_rate * param.grad / batch_size# 确定超参数和学习方式
lr = 0.03
num_iterations = 5
net = linreg				# 目标模型
loss = squared_loss			# 代价函数(损失函数)
batch_size = 10				# 每次随机小批量的大小'''
开始训练
'''
for iteration in range(num_iterations):		# 确定迭代次数for x, y in data_iter(batch_size, features, labels):with autograd.record():l = loss(net(x,w,b), y)			# 求当前小批量的总损失l.backward()						# 求梯度sgd([w,b], lr, batch_size)			# 梯度更新参数train_l = loss(net(features,w,b), labels)print("iteration %d, loss %f" % (iteration+1, train_l.mean().asnumpy()))
# 打印比较真实参数和训练得到的参数
print("real_w " + str(real_W) + "\n train_w " + str(w))
print("real_w " + str(real_bias) + "\n train_b " + str(b))

具体程序解释

param[:] = param - learning_rate * param.grad / batch_size
将batch_size与参数调整相关联的原因,是为了使得每次更新的步长不受批次大小的影响
具体来说,当计算一批数据的损失函数的梯度时,实际上是将这批数据中每个样本对损失函数的贡献累加起来。这意味着如果批次较大,梯度的模也会相应增大
故更新权值时,使用的是数据集的平均梯度,而不是总和

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/894648.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

本地快速部署DeepSeek-R1模型——2025新年贺岁

一晃年初六了,春节长假余额马上归零了。今天下午在我的电脑上成功部署了DeepSeek-R1模型,抽个时间和大家简单分享一下过程: 概述 DeepSeek模型 是一家由中国知名量化私募巨头幻方量化创立的人工智能公司,致力于开发高效、高性能…

C++11详解(一) -- 列表初始化,右值引用和移动语义

文章目录 1.列表初始化1.1 C98传统的{}1.2 C11中的{}1.3 C11中的std::initializer_list 2.右值引用和移动语义2.1左值和右值2.2左值引用和右值引用2.3 引用延长生命周期2.4左值和右值的参数匹配问题2.5右值引用和移动语义的使用场景2.5.1左值引用主要使用场景2.5.2移动构造和移…

手写MVVM框架-构建虚拟dom树

MVVM的核心之一就是虚拟dom树,我们这一章节就先构建一个虚拟dom树 首先我们需要创建一个VNode的类 // 当前类的位置是src/vnode/index.js export default class VNode{constructor(tag, // 标签名称(英文大写)ele, // 对应真实节点children,…

【大数据技术】教程03:本机PyCharm远程连接虚拟机Python

本机PyCharm远程连接虚拟机Python 注意:本文需要使用PyCharm专业版。 pycharm-professional-2024.1.4VMware Workstation Pro 16CentOS-Stream-10-latest-x86_64-dvd1.iso写在前面 本文主要介绍如何使用本地PyCharm远程连接虚拟机,运行Python脚本,提高编程效率。 注意: …

pytorch实现门控循环单元 (GRU)

人工智能例子汇总:AI常见的算法和例子-CSDN博客 特性GRULSTM计算效率更快,参数更少相对较慢,参数更多结构复杂度只有两个门(更新门和重置门)三个门(输入门、遗忘门、输出门)处理长时依赖一般适…

PAT甲级1032、sharing

题目 To store English words, one method is to use linked lists and store a word letter by letter. To save some space, we may let the words share the same sublist if they share the same suffix. For example, loading and being are stored as showed in Figure …

最小生成树kruskal算法

文章目录 kruskal算法的思想模板 kruskal算法的思想 模板 #include <bits/stdc.h> #define lowbit(x) ((x)&(-x)) #define int long long #define endl \n #define PII pair<int,int> #define IOS ios::sync_with_stdio(0),cin.tie(0),cout.tie(0); using na…

为何在Kubernetes容器中以root身份运行存在风险?

作者&#xff1a;马辛瓦西奥内克&#xff08;Marcin Wasiucionek&#xff09; 引言 在Kubernetes安全领域&#xff0c;一个常见的建议是让容器以非root用户身份运行。但是&#xff0c;在容器中以root身份运行&#xff0c;实际会带来哪些安全隐患呢&#xff1f;在Docker镜像和…

ConcurrentHashMap线程安全:分段锁 到 synchronized + CAS

专栏系列文章地址&#xff1a;https://blog.csdn.net/qq_26437925/article/details/145290162 本文目标&#xff1a; 理解ConcurrentHashMap为什么线程安全&#xff1b;ConcurrentHashMap的具体细节还需要进一步研究 目录 ConcurrentHashMap介绍JDK7的分段锁实现JDK8的synchr…

[ESP32:Vscode+PlatformIO]新建工程 常用配置与设置

2025-1-29 一、新建工程 选择一个要创建工程文件夹的地方&#xff0c;在空白处鼠标右键选择通过Code打开 打开Vscode&#xff0c;点击platformIO图标&#xff0c;选择PIO Home下的open&#xff0c;最后点击new project 按照下图进行设置 第一个是工程文件夹的名称 第二个是…

述评:如果抗拒特朗普的“普征关税”

题 记 美国总统特朗普宣布对美国三大贸易夥伴——中国、墨西哥和加拿大&#xff0c;分别征收10%、25%的关税。 他威胁说&#xff0c;如果这三个国家不解决他对非法移民和毒品走私的担忧&#xff0c;他就要征收进口税。 去年&#xff0c;中国、墨西哥和加拿大这三个国家&#…

九. Redis 持久化-AOF(详细讲解说明,一个配置一个说明分析,步步讲解到位 2)

九. Redis 持久化-AOF(详细讲解说明&#xff0c;一个配置一个说明分析&#xff0c;步步讲解到位 2) 文章目录 九. Redis 持久化-AOF(详细讲解说明&#xff0c;一个配置一个说明分析&#xff0c;步步讲解到位 2)1. Redis 持久化 AOF 概述2. AOF 持久化流程3. AOF 的配置4. AOF 启…

基于Springboot框架的学术期刊遴选服务-项目演示

项目介绍 本课程演示的是一款 基于Javaweb的水果超市管理系统&#xff0c;主要针对计算机相关专业的正在做毕设的学生与需要项目实战练习的 Java 学习者。 1.包含&#xff1a;项目源码、项目文档、数据库脚本、软件工具等所有资料 2.带你从零开始部署运行本套系统 3.该项目附…

新版231普通阿里滑块 自动化和逆向实现 分析

声明: 本文章中所有内容仅供学习交流使用&#xff0c;不用于其他任何目的&#xff0c;抓包内容、敏感网址、数据接口等均已做脱敏处理&#xff0c;严禁用于商业用途和非法用途&#xff0c;否则由此产生的一切后果均与作者无关&#xff01; 逆向过程 补环境逆向 部分补环境 …

java-(Oracle)-Oracle,plsqldev,Sql语法,Oracle函数

卸载好注册表,然后安装11g 每次在执行orderby的时候相当于是做了全排序,思考全排序的效率 会比较耗费系统的资源,因此选择在业务不太繁忙的时候进行 --给表添加注释 comment on table emp is 雇员表 --给列添加注释; comment on column emp.empno is 雇员工号;select empno,en…

泰山派Linux环境下自动烧录脚本(EMMC 2+16G)

脚本名字&#xff1a; download.sh 输入./download -h获取帮助信息 &#xff0c;其中各个IMG/TXT烧录的地址和路径都在前几行修改即可 #!/bin/bash# # DownLoad.sh 多镜像烧录脚本 # 版本&#xff1a;1.1 # 作者&#xff1a;zhangqi # 功能&#xff1a;通过参数选择烧录指定镜…

正大杯攻略|分层抽样+不等概率三阶段抽样

首先&#xff0c;先进行分层抽样&#xff0c;确定主城区和郊区的比例 然后对主城区分别进行不等概率三阶段抽样 第一阶段&#xff0c;使用PPS抽样&#xff0c;确定行政区&#xff08;根据分层抽样比例合理确定主城区和郊区行政区数量&#xff09; 第二阶段&#xff0c;使用分…

开源智慧园区管理系统对比其他十种管理软件的优势与应用前景分析

内容概要 在当今数字化快速发展的时代&#xff0c;园区管理软件的选择显得尤为重要。而开源智慧园区管理系统凭借其独特的优势&#xff0c;逐渐成为用户的新宠。与传统管理软件相比&#xff0c;它不仅灵活性高&#xff0c;而且具有更强的可定制性&#xff0c;让各类园区&#…

计算机网络 应用层 笔记1(C/S模型,P2P模型,FTP协议)

应用层概述&#xff1a; 功能&#xff1a; 常见协议 应用层与其他层的关系 网络应用模型 C/S模型&#xff1a; 优点 缺点 P2P模型&#xff1a; 优点 缺点 DNS系统&#xff1a; 基本功能 系统架构 域名空间&#xff1a; DNS 服务器 根服务器&#xff1a; 顶级域…

人类心智逆向工程:AGI的认知科学基础

文章目录 引言:为何需要逆向工程人类心智?一、逆向工程的定义与目标1.1 什么是逆向工程?1.2 AGI逆向工程的核心目标二、认知科学的四大支柱与AGI2.1 神经科学:大脑的硬件解剖2.2 心理学:心智的行为建模2.3 语言学:符号与意义的桥梁2.4 哲学:意识与自我模型的争议三、逆向…