机器学习复习(2)——线性回归SGD优化算法

目录

线性回归代码

线性回归理论

SGD算法

手撕线性回归算法

模型初始化

定义模型主体部分

定义线性回归模型训练过程

数据demo准备

模型训练与权重参数

定义线性回归预测函数

定义R2系数计算

可视化展示 

预测结果

训练过程 

sklearn进行机器学习

线性回归代码

class My_Model(nn.Module):def __init__(self, input_dim):super(My_Model, self).__init__()# 矩阵的维度(dimensions) self.layers = nn.Sequential(nn.Linear(input_dim, 16),nn.ReLU(),nn.Linear(16, 8),nn.ReLU(),nn.Linear(8, 1))def forward(self, x):x = self.layers(x)x = x.squeeze(1) # (B, 1) -> (B)return x

线性回归理论

回归算法是相对分类算法而言的,与我们想要预测的目标变量y的值类型有关。

如果目标变量y是分类型变量,如预测用户的性别(男、女),预测月季花的颜色(红、白、黄……),那我们就需要用分类算法去拟合训练数据并做出预测;

如果y是连续型变量,如预测用户的收入(4千,2万,10万……),预测患肺癌的概率(1%,50%,99%……),我们则需要用回归模型。

有时分类问题也可以转化为回归问题。可以用回归模型先预测出患肺癌的概率,然后再给定一个阈值,例如50%,概率值在50%以下为A类,50%以上为B类。

一元线性回归公式:

 具象化含义:

SGD算法

手撕线性回归算法

模型初始化

### 初始化模型参数
def initialize_params(dims):'''输入:dims:训练数据变量维度输出:w:初始化权重参数值b:初始化偏差参数值'''# 初始化权重参数为零矩阵w = np.zeros((dims, 1))# 初始化偏差参数为零b = 0return w, b
w,b=initialize_params(3)#用于测试
print("w初始化是",w)
print("b初始化是",b)

运行结果:

定义模型主体部分

包括线性回归公式、均方损失和参数偏导三部分
def linear_loss(X, y, w, b):'''输入:X:输入变量矩阵y:输出标签向量w:变量参数权重矩阵b:偏差项输出:y_hat:线性模型预测输出loss:均方损失值dw:权重参数一阶偏导db:偏差项一阶偏导'''# 训练样本数量num_train = X.shape[0]# 训练特征数量num_feature = X.shape[1]# 线性回归预测输出y_hat = np.dot(X, w) + b# 计算预测输出与实际标签之间的均方损失loss = np.sum((y_hat-y)**2)/num_train# 基于均方损失对权重参数的一阶偏导数dw = np.dot(X.T, (y_hat-y)) /num_train# 基于均方损失对偏差项的一阶偏导数db = np.sum((y_hat-y)) /num_trainreturn y_hat, loss, dw, db

定义线性回归模型训练过程

### 定义线性回归模型训练过程
def linear_train(X, y, learning_rate=0.01, epochs=10000):'''输入:X:输入变量矩阵y:输出标签向量learning_rate:学习率epochs:训练迭代次数输出:loss_his:每次迭代的均方损失params:优化后的参数字典grads:优化后的参数梯度字典'''# 记录训练损失的空列表loss_his = []# 初始化模型参数w, b = initialize_params(X.shape[1])# 迭代训练for i in range(1, epochs):# 计算当前迭代的预测值、损失和梯度y_hat, loss, dw, db = linear_loss(X, y, w, b)
#y_hat是预测值,loss是损失,dw是权重参数一阶偏导,db是偏差项一阶偏导# 基于梯度下降的参数更新w += -learning_rate * dwb += -learning_rate * db# 记录当前迭代的损失loss_his.append(loss)# 每1000次迭代打印当前损失信息if i % 10000 == 0:print('epoch %d loss %f' % (i, loss))# 将当前迭代步优化后的参数保存到字典params = {'w': w,'b': b}# 将当前迭代步的梯度保存到字典grads = {'dw': dw,'db': db}     return loss_his, params, grads

其中的shape操作说明:

import numpy as np
# 创建一个示例的训练数据集 X
X = np.array([[1, 2, 3],[4, 5, 6],[7, 8, 9],[10, 11, 12],[13, 14, 15]])
# 计算训练样本数量
shape0 = X.shape[0]
shape1 = X.shape[1]
print("shape0是",shape0)
print("shape1是",shape1)

运行结果:

数据demo准备

from sklearn.datasets import load_diabetes
diabetes = load_diabetes()
data = diabetes.data
target = diabetes.target 
print(data.shape)
print(target.shape)
print(data[:5])
print(target[:5])
###########################################
# 导入sklearn diabetes数据接口
from sklearn.datasets import load_diabetes
# 导入sklearn打乱数据函数
from sklearn.utils import shuffle
# 获取diabetes数据集
diabetes = load_diabetes()
# 获取输入和标签
data, target = diabetes.data, diabetes.target 
# 打乱数据集
X, y = shuffle(data, target, random_state=13)
# 按照8/2划分训练集和测试集
offset = int(X.shape[0] * 0.8)
# 训练集
X_train, y_train = X[:offset], y[:offset]
# 测试集
X_test, y_test = X[offset:], y[offset:]
# 将训练集改为列向量的形式
y_train = y_train.reshape((-1,1))
# 将验证集改为列向量的形式
y_test = y_test.reshape((-1,1))
# 打印训练集和测试集维度
print("X_train's shape: ", X_train.shape)
print("X_test's shape: ", X_test.shape)
print("y_train's shape: ", y_train.shape)
print("y_test's shape: ", y_test.shape)

模型训练与权重参数

# 线性回归模型训练
loss_his, params, grads = linear_train(X_train, y_train, 0.01, 200000)
# 打印训练后得到模型参数
print(params)

定义线性回归预测函数

### 定义线性回归预测函数
def predict(X, params):'''输入:X:测试数据集params:模型训练参数输出:y_pred:模型预测结果'''# 获取模型参数w = params['w']b = params['b']# 预测y_pred = np.dot(X, w) + breturn y_pred
# 基于测试集的预测
y_pred = predict(X_test, params)
# 打印前五个预测值
y_pred[:5]

定义R2系数计算

R2系数,也称为决定系数(Coefficient of Determination),是一种用于评估回归模型拟合优度的统计指标。它表示模型对观测数据的方差解释比例,通常用于衡量回归模型的拟合程度。

R2系数的取值范围在0到1之间,具体含义如下:

  • 如果R2等于0,表示模型未能解释目标变量的任何方差,即模型无法拟合数据。
  • 如果R2等于1,表示模型完美拟合了数据,能够解释目标变量的所有方差。
  • 如果R2在0和1之间,表示模型能够解释一部分目标变量的方差,数值越接近1,说明模型的拟合程度越好。

计算公式如下:

其中:

  • SSR(Sum of Squares of Residuals)表示模型的残差平方和,即实际观测值与模型预测值之间的差异的平方和。
  • SST(Total Sum of Squares)表示总平方和,即实际观测值与观测值的均值之间的差异的平方和。

R2系数越接近1,说明模型对数据的拟合越好,而越接近0则表示模型的拟合效果较差。这个指标对于评估回归模型的性能非常有用,帮助我们了解模型解释数据方差的程度。

### 定义R2系数函数
def r2_score(y_test, y_pred):'''输入:y_test:测试集标签值y_pred:测试集预测值输出:r2:R2系数'''# 测试标签均值y_avg = np.mean(y_test)# 总离差平方和ss_tot = np.sum((y_test - y_avg)**2)# 残差平方和ss_res = np.sum((y_test - y_pred)**2)# R2计算r2 = 1 - (ss_res/ss_tot)return r2

可视化展示 

预测结果

import matplotlib.pyplot as plt
f = X_test.dot(params['w']) + params['b']plt.scatter(range(X_test.shape[0]), y_test)
plt.plot(f, color = 'darkorange')
plt.xlabel('X_test')
plt.ylabel('y_test')
plt.show();

运行结果:

训练过程 

plt.plot(loss_his, color='blue')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.show()

运行结果:

sklearn进行机器学习

 和torch.nn类似:封装好了linear函数,直接掉包

### sklearn版本为1.0.2
# 导入线性回归模块
from sklearn import linear_model
from sklearn.metrics import mean_squared_error, r2_score
# 创建模型实例
regr = linear_model.LinearRegression()
# 模型拟合
regr.fit(X_train, y_train)
# 模型预测
y_pred = regr.predict(X_test)
# 打印模型均方误差
print("Mean squared error: %.2f" % mean_squared_error(y_test, y_pred))
# 打印R2
print('R2 score: %.2f' % r2_score(y_test, y_pred))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/665107.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

电商小程序01需求分析

目录 1 电商用例分析2 功能架构3 原型开发3.1 首页3.2 店铺页面3.3 配货单3.4 配货单有货3.5 我的应用3.6 商品详情3.7 订单确认3.8 收货地址3.9 店铺详情3.10 店铺分类3.11 商品分类 总结 低代码学习的时候最高效的方法就是带着问题去学习,一般可以先从电商小程序开…

【大数据】Flink SQL 语法篇(三):窗口聚合(TUMBLE、HOP、SESSION、CUMULATE)

Flink SQL 语法篇(三):窗口聚合 1.滚动窗口(TUMBLE)1.1 Group Window Aggregation 方案(支持 Batch / Streaming 任务)1.2 Windowing TVF 方案(1.13 只支持 Streaming 任务&#xff…

配置实例—交换机VLAN聚合配置实例

一、组网需求 某公司拥有多个部门且位于同一网段,为了提升业务安全性,将不同部门的用户划分到不同VLAN中。现由于业务需要,不同部门间的用户需要互通。如图1所示,VLAN2和VLAN3为不同部门,现需要实现不同VLAN间的用户可…

浪漫的通讯录(顺序表篇)

本篇会加入个人的所谓‘鱼式疯言’ ❤️❤️❤️鱼式疯言:❤️❤️❤️此疯言非彼疯言 而是理解过并总结出来通俗易懂的大白话, 我会尽可能的在每个概念后插入鱼式疯言,帮助大家理解的. 🤭🤭🤭可能说的不是那么严谨.但小编初心是能让更多人能…

代码随想录算法训练营第39天 | 62.不同路径 + 63.不同路径 II

今日任务 62.不同路径 63. 不同路径 II 62.不同路径 - Medium 题目链接:力扣(LeetCode)官网 - 全球极客挚爱的技术成长平台 一个机器人位于一个 m x n 网格的左上角 (起始点在下图中标记为 “Start” )。 机器人每次只…

flutter如何实现省市区选择器

前言 当我们需要用户填写地址时,稳妥的做法是让用户通过“滚轮”来滑动选择省份,市,区,此文采用flutter的第三方库来实现这一功能,比调用高德地图api简单一些。 流程 选择库 这里我选择了一个最近更新且支持中国的…

Acwing 141 周赛 解题报告 | 珂学家 | 逆序数+奇偶性分析

前言 整体评价 很普通的一场比赛,t2思维题,初做时愣了下,幸好反应过来了。t3猜猜乐,感觉和逆序数有关,和奇偶性有关。不过要注意int溢出。 欢迎关注: 珂朵莉的天空之城 A. 客人数量 题型: 签到 累加和即可 import…

Three.js学习3:第一个Three.js页面

一、一图看懂Three.js 坐标 这个没什么好说的,只是需要注意颜色。在 Three.js 提供的编辑器中,各种物体的坐标也这样的色彩: 红色:x 轴 绿色:y 轴 蓝色:z 轴 Three.js 提供的编辑器可以在本地 Three.js …

常用git指令

一.安装配置git&&利用SSH完成Git与GitHub的绑定 1.参考知乎网址:还不会使用 GitHub ? GitHub 教程来了!万字图文详解 二.在git上更新仓库步骤 1.在新建文件夹下,右键选择“git bash here” 2.把项目下载到本地&#xf…

AI应用开发-git开源项目的一些问题及镜像解决办法

AI应用开发相关目录 本专栏包括AI应用开发相关内容分享,包括不限于AI算法部署实施细节、AI应用后端分析服务相关概念及开发技巧、AI应用后端应用服务相关概念及开发技巧、AI应用前端实现路径及开发技巧 适用于具备一定算法及Python使用基础的人群 AI应用开发流程概…

微信小程序(三十一)本地同步存储API

注释很详细&#xff0c;直接上代码 上一篇 新增内容&#xff1a; 1.存储数据 2.读取数据 3.删除数据 4.清空数据 源码&#xff1a; index.wxml <!-- 列表渲染基础写法&#xff0c;不明白的看上一篇 --> <view class"students"><view class"item…

使用 git 将本地文件上传到 gitee 远程仓库中,推送失败

项目场景&#xff1a; 背景&#xff1a; 使用 git 想要push 本地文件 到 另一个远程仓库&#xff0c;执行 git push origin master后此时报错 问题描述 问题&#xff1a; git push 本地文件 到 另一个远程仓库时&#xff0c;运行 git push origin master ,push文件失败&…

老版本labelme如何不保存imagedata

我的版本是3.16&#xff0c;默认英文且不带取消保存imagedata的选项。 最简单粗暴的方法就是在json文件保存时把传递过来的imagedata数据设定为None&#xff0c;方法如下&#xff1a; 找到labelme的源文件&#xff0c;例如&#xff1a;D:\conda\envs\deeplab\Lib\site-packages…

vue 适配大屏 页面 整体缩放

正常应该放在app.vue 里面。我这里因为用到element-ui 弹框无法缩放&#xff0c;所以加在body上面 (function (doc, win) {var docEl doc.documentElement,resizeEvt orientationchange in window ? orientationchange : resize,recalc function () {var clientWidth docE…

基于协同过滤的个性化电影推荐系统分析设计python+flask

本系统为用户而设计制作个性化电影推荐管理&#xff0c;旨在实现个性化电影推荐智能化、现代化管理。本个性化电影推荐自动化系统的开发和研制的最终目的是将个性化电影推荐的运作模式从手工记录数据转变为网络信息查询管理&#xff0c;从而为现代管理人员的使用提供更多的便利…

PPT录屏功能在哪?一键快速找到它!

在现代办公环境中&#xff0c;ppt的录屏功能日益受到关注&#xff0c;它不仅能帮助我们记录演示文稿的播放过程&#xff0c;还能将操作过程、游戏等内容完美录制下来。可是很多人不知道ppt录屏功能在哪&#xff0c;本文将为您介绍ppt录屏的打开方法&#xff0c;以帮助读者更好地…

如何计算两个指定日期相差几年几月几日

一、题目要求 假定给出两个日期&#xff0c;让你计算两个日期之间相差多少年&#xff0c;多少月&#xff0c;多少天&#xff0c;应该如何操作呢&#xff1f; 本文提供网页、ChatGPT法、VBA法和Python法等四种不同的解法。 二、解决办法 1. 网页计算法 这种方法是利用网站给…

时间回显+选择(年月日时分秒

一、获取某个时间 1、Date获取Date类型 <el-form-item label"时间" name"endTime"><el-date-picker type"datetime" v-model"editForm.endTime"></el-date-picker> </el-form-item> 效果如图&#xff1a; …

Java学习笔记2024/1/29

1. 流程控制语句 笔记地点 1.1 流程控制语句基础概念 package com.angus.processControlStatement.processControlStatement;public class processControlStatementNote {public static void main(String[] args) {// 本章知识点: 流程控制语句// 流程控制语句: 通过一些语句…

基于SpringBoot Vue超市管理系统

大家好✌&#xff01;我是Dwzun。很高兴你能来阅读我&#xff0c;我会陆续更新Java后端、前端、数据库、项目案例等相关知识点总结&#xff0c;还为大家分享优质的实战项目&#xff0c;本人在Java项目开发领域有多年的经验&#xff0c;陆续会更新更多优质的Java实战项目&#x…