梯度上升和随机梯度上升

目录

梯度上升算法:

代码:

随机梯度上升算法:

代码:

实验:

做图代码:

疑问:

1.梯度上升算法不适应大的数据集,改用随机梯度上升更合适。

2.改进过的随机梯度算法,w vs epoch曲线出现波动。

代码实现时遇到的问题

1.对随机的理解,随机过的样本不再参与随机?

 2.数组越界

实验结果: 


Logistic回归实现二分类:http://t.csdnimg.cn/eEEjF

学习资料:Peter Harrington 机器学习实战

梯度上升算法:

如下,每更新一次权重需要计算所有样本(train_X)和权重乘积的sigmoid值。

对于m行n列特征矩阵和n行权重,迭代次数epoch,计算复杂度为O(n*m*epoch),w迭代次数为epoch.

代码:

def grad(train_X,train_y):# 100*3m,n = len(train_X[:,0]),len(train_X[0])#3x1 weight=np.ones((n,1))#迭代系数epoch=500for i in range(epoch):# mxn nx1 ->m*1y_=sigmoid(np.dot(train_X,weight))# m*1loss = train_y -y_a = 0.01# 3*1 weight = weight - np.dot(a*train_X.transpose(),loss)return weight

随机梯度上升算法:

每计算一次样本更新一次权重。

代码:

def grad(train_X,train_y):# 100*3m,n = len(train_X[:,0]),len(train_X[0])#3x1 weight=np.ones((n,1))# 每个样本更新权重一次,故使用mfor i in range(m):# mxn nx1 ->m*1# y_=sigmoid(np.dot(train_X,weight))# 1xn nx1 -->1x1#(3,)表示只有一个维度,在这个维度上有三个数字y_=sigmoid(sum(np.dot(train_X[i].reshape(1,n),weight)))# 1x1loss = train_y[i] -y_a = 0.01# nx1 1xn 1x1 nx1 1x1weight = weight + a*np.dot(train_X[i].transpose().reshape(n,1),loss.reshape(1,1))return weight

使用这种方法实现的分类效果相较之前效果变差了,应该是w迭代次数不够,为了查看何时收敛,查看w和epoch的变化关系图。

梯度上升随机梯度上升

实验:

使用随机梯度算法:研究w和epoch的关系,epoch表示迭代数据集的次数.

ln(\frac{y}{1-y})=wx+b=w_{1}x1+w_{2}x2+b

w1,b,w2分别为权重为w[0],w[1],w[2]。

可以看出w0收敛很快,w1和w2需要更多时间才能实现收敛.

 

 

发现w大概到epoch=200处收敛,设置遍历数据集200次得到新图像基本和梯度算法基本一致:

做图代码:

def grad(train_X,train_y):# 100*3m,n = len(train_X[:,0]),len(train_X[0])#3x1 weight=np.ones((n,1))#迭代系数epoch=500w=[]# lineArr.append(float(currLine[i]))#100x100次for j in range(epoch):for i in range(m):# mxn nx1 ->m*1# y_=sigmoid(np.dot(train_X,weight))# 1xn nx1 -->1x1#(3,)表示只有一个维度,在这个维度上有三个数字y_=sigmoid(sum(np.dot(train_X[i].reshape(1,n),weight)))# 1x1loss = train_y[i] -y_a = 0.01# nx1 1xn 1x1 nx1 1x1weight = weight + a*np.dot(train_X[i].transpose().reshape(n,1),loss.reshape(1,1))w.append(weight)w0=[row[0] for row in w]w1=[row[1] for row in w]w2=[row[2] for row in w]epochlist =list(range(1,epoch+1,1))plt.plot(epochlist,w0)    plt.xlabel('gradw')plt.ylabel('epoch')plt.legend()plt.title("w[0] vs epoch")plt.show()plt.plot(epochlist,w1)    plt.xlabel('gradw')plt.ylabel('epoch')plt.legend()plt.title("w[1] vs epoch")plt.show()plt.plot(epochlist,w2)    plt.xlabel('gradw')plt.ylabel('epoch')plt.legend()plt.title("w[2] vs epoch")plt.show()return weight

疑问:

1.梯度上升算法不适应大的数据集,改用随机梯度上升更合适。

书中提到:数据量大的话不太方便,所以想到每计算一次样本更新一次权重,也就是随机梯度上升算法。但是计算每个样本迭代的权重,再遍历全部样本,不就是把矩阵乘法拆开算吗,我不理解,感觉没有提升运算效率的作用。

2.改进过的随机梯度算法,w vs epoch曲线出现波动。

我感觉这个第一个随机梯度上升法完全没有体现随机性,随机应该随机抽取训练集的的子集来更新回归系数吧。

 书里有提到随机梯度算法的改进:

2.1.学习率随着迭代次数增加应该减小,能缓解高频波动。(不过我的实验在第一个梯度上升没有出现高频波动,改进后反而出现了。)

2.2.随机抽取样本进行类别预测,loss计算和更新权重。

randIndex是0-len(dataIndex)的随机值;去掉dataIndex[randIndex]的值

但是这个randIndex和dataIndex为什么这么写,不太理解.

def grad(train_X,train_y):m,n = len(train_X[:,0]),len(train_X[0])weight=np.ones((n,1))epoch=50w=[]for j in range(epoch):#[0,1,2,,,m-1]对应train_X的索引dataIndex = list(range(m))for i in range(m):#随着迭代次数增加,减小学习率alpha = 4/(1.0+j+i)+0.0001 #在dataIndex中取随机样本的索引randIndex = int(random.uniform(0,len(dataIndex)))y_=sigmoid(sum(np.dot(train_X[randIndex].reshape(1,n),weight)))loss = train_y[i] -y_weight = weight + alpha*np.dot(train_X[i].transpose().reshape(n,1),loss.reshape(1,1))#去掉计算过的样本del(dataIndex[randIndex])return weight

代码实现时遇到的问题

1.对随机的理解,随机过的样本不再参与随机?

从dataIndex中随机取值,作为随机数;去掉dataIndex中被选过的值,即随机过的样本不再参与随机了。

randIndex = random.choice(dataIndex)
dataIndex.remove(randIndex)

 2.数组越界

代码错误:取0-len(dataIndex)的值作为索引,拿到randIndex的随机值;但是去掉索引为randIndex对应的值

dataIndex=[0,1,2,,,80] len=81 假设去掉了dataIndex[79]=79,

dataIndex=[0,1,2,,,78,80] len=80 取79 randIndex=dataIndex[79]=80 去掉randIndex[80]数组越界

就可能出现越界:

randIndex = dataIndex[int(random.uniform(0,len(dataIndex)))]
del(dataIndex[randIndex])

实验结果: 

w和epoch关系:

w1,w2收敛速度更慢了,w0的波动非常明显,感觉效果比改进前还差,这算法我学不明白。
 

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/197828.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Android Edittext进阶版(Textfieids)

一、Text fieids 允许用户在 UI 中输入文本,TextInputLayout TextInputEditText。 在 Text fieids 没出来(我不知道)前,想实现这个功能就需要自己自定义控件来实现这个功能。 几年前做个上面这种样式(filled 填充型)。需要多个控件组合 动画才能实现&a…

游戏开发增笑-扣扣死-Editor的脚本属性自定义定制-还写的挺详细的,旧版本反而更好

2012年在官方论坛注册的一个号,居然被禁言了,不知道官方现在是什么辣鸡,算了,大人不记狗子过 ”后来提交问题给CEO了,结果CEO百忙之中居然回复了,也是很低调的一个人,毕竟做技术的有什么坏心思呢…

基于SSM的老年公寓信息管理的设计与实现

末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:采用JSP技术开发 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目&#x…

Clickhouse遇到密码错误如何修改密码

输入错误密码报错 rootDAILACHDBUD001:/var/log# clickhouse-client ClickHouse client version 23.4.2.11 (official build). Connecting to localhost:9000 as user default. Password for user (default): Connecting to localhost:9000 as user default. Code: 516. DB::E…

2023年JetBrains开发调查:Java 8仍广泛使用

开发者生态系统调查是查找和分析实际情况的好方法,而实际情况通常与看似流行或趋势的情况相反。 排名前三: Java8采用率:50%Java17采用率:45%Java11采用率:38% 看到这么多人仍在使用 Java 8(及更早版本&…

如何使用cpolar内网穿透工具实现公网SSH远程访问Deepin

文章目录 前言1. 开启SSH服务2. Deppin安装Cpolar3. 配置ssh公网地址4. 公网远程SSH连接5. 固定连接SSH公网地址6. SSH固定地址连接测试 前言 Deepin操作系统是一个基于Debian的Linux操作系统,专注于使用者对日常办公、学习、生活和娱乐的操作体验的极致&#xff0…

【C语言】深入理解指针(1)

目录 前言 (一)内存与地址 从实际生活出发 地址 内存 内存与地址关系密切 (二)指针变量 指针变量与取地址操作符 指针变量与解引用操作符 指针的大小 指针的运算 指针 - 整数 指针-指针 指针的关系运算 指针的类型的…

C++ 数组

目录 一维数组 一维数组的创建 一维数组的初始化 一维数组的使用 一维数组在内存中的存储 二维数组 二维数组的创建 二维数组的初始化 二维数组的使用 二维数组在内存中的存储 数组越界 一维数组 数组是一组形同类型的集合。 一维数组的创建 数组的创建方式&…

❀dialog命令运用于linux❀

目录 ❀dialog命令运用于linux❀ msgbox部件(消息框) yesno部件(yesno框) inputbox部件(输入文本框) textbox部件(文本框) menu部件(菜单框) fselect部…

哈希与哈希表

哈希表的概念 哈希表又名散列表,官话一点讲就是: 散列表(Hash table,也叫哈希表),是根据关键码值(Key value)而直接进行访问的数据结构。也就是说,它通过把关键码值映射到表中一个位置来访问记…

SpringBoot集成Redis

引入依赖 建议使用 Lettuce 连接驱动 <!--一&#xff1a;Jedis连接驱动缺点&#xff1a;Jedis基于TCP的阻塞性的连接方式1. 阻塞性IO2. 不能异步3. 线程不安全的Lettuce连接驱动优点&#xff1a;Lettuce基于Netty的多路复用的异步非阻塞的连接方式&#xff0c;1. 线程安全2…

公开Java框架开源到Maven中央仓库(避坑)

前言: gpg下载地址&#xff1a;http://www.gnupg.org/download 安装勾选 kleopatra 下载完成验证 gpg --version 当时为了开源Java框架&#xff0c;真的是绞尽脑汁&#xff0c;耗费很多精力查了很多资料&#xff0c;躺了很多坑&#xff0c;最终的结果无不是以发布失败而告终&am…

线程变量引发的session混乱问题

最近不是在救火&#xff0c;就是在救火的路上。 也没什么特别可写的&#xff0c;今天记录下最近遇到的一个问题&#xff0c;个人觉得挺有意思&#xff0c; 待有缘人阅读 言归正传&#xff0c;售后反馈&#xff1a; 营业查询中付款方式为第三方支付的几条银行缴费&#xff0c;创…

ai绘画Midjourney绘画提示词Prompt教程

一、Midjourney绘画工具 SparkAi【无需魔法使用】&#xff1a; SparkAi是目前国内一款的ChatGPT对接OpenAI软件系统。那么如何搭建部署AI创作ChatGPT&#xff1f;小编这里写一个详细图文教程吧&#xff01;本系统使用NestjsVueTypescript框架技术&#xff0c;持续集成AI能力到…

成为AI产品经理——模型稳定性评估(PSI)

一、PSI作用 稳定性是指模型性能的稳定程度。 上线前需要进行模型的稳定性评估&#xff0c;是否达到上线标准。 上线后需要进行模型的稳定性的观测&#xff0c;判断模型是否需要迭代。 稳定度指标(population stability index ,PSI)。通过PSI指标&#xff0c;我们可以获得不…

chatgpt、百度、讯飞、阿里写一小段SQL对比

问题&#xff1a;有一张表pay&#xff0c;表中只有一个字段url&#xff0c;字段类型为text&#xff0c;没有其它字段。请写一段sql脚本&#xff0c;删除重复的url行记录&#xff0c;只保留一条记录。 通义千问的回答&#xff1a; DELETE FROM pay WHERE url IN (SELECT url F…

Windows使用Redis

Windows使用Redis 前言一、安装wsl2&#xff08;Windows Subsystem for Linux&#xff09;二、在wsl中下载并安装Redis一主二仆哨兵模式 前言 主要是记录一下&#xff0c;免得自己忘了。 一、安装wsl2&#xff08;Windows Subsystem for Linux&#xff09; Redis官网中说&…

GitHub上1.5K标星的QA和软件测试学习路线图

​最近在GitHub上发现一个项目&#xff0c;项目描述了作为QA工程师&#xff0c;进行软件测试技能提升时的&#xff0c;建议的软件测试学习顺序图​。 虽然2021年起就不再更新了&#xff0c;但是居然有1.5K的​星。 整个项目有两个部分​&#xff1a; ​1.QA和软件测试学习顺序…

嵌入式面试题

1. new和malloc 做嵌入式&#xff0c;对于内存是十分在意的&#xff0c;因为可用内存有限&#xff0c;所以嵌入式笔试面试题目&#xff0c;内存的题目高频。 1&#xff09;malloc和free是c/c语言的库函数&#xff0c;需要头文件支持stdlib.h&#xff1b;new和delete是C的关键…

craco + webpack 4 升 5

craco webpack 4 升 5 更新包版本尝试build升级其他依赖库使用process插件打印进度信息到底需要多少内存分析构建产出添加 splitChunk总结记录一些好文章&#xff1a; 我的项目使用 craco react 开发 我的 package.json {// ......"dependencies": {"ant-desi…