仅使用python标准库(不使用numpy)写一个小批量梯度下降的线性回归算法

看到一个有意思的题目:仅使用python的标准库,完成一个小批量梯度下降的线性回归算法

平常使用numpy这样的计算库习惯了,只允许使用标准库还有点不习惯,下面就使用这个过程来写一个。

import random
from typing import List# 生成测试数据
def generate_data(num_samples: int, weights: List[float], bias: float, noise=0.1) -> (List[List[float]], List[float]):X = [[random.uniform(-10, 10) for _ in range(len(weights))] for _ in range(num_samples)]y = [sum(w * x for w, x in zip(weights, x_i)) + bias + random.uniform(-noise, noise) for x_i in X]return X, y# 计算损失
def mse(y_true: List[float], y_pred: List[float]):return 0.5 * sum((yt - yp) for yt, yp in zip(y_true, y_pred)) ** 2# 将矩阵转置
def transpose(mat: List[List[float]]):row, col = len(mat), len(mat[0])# 固定列,访问行result = [[mat[r][c] for r in range(row)] for c in range(col)]return result# 计算矩阵乘法
def matmul(mat: List[List[float]], vec: List[float]):return [sum(r * c for r, c in zip(row, vec)) for row in mat]# 计算梯度
def compute_grad(y_true_batch: List[float], y_pred_batch: List[float], x_batch: List[List[float]]):batch_size = len(y_true_batch)residual = [yt - yp for yt, yp in zip(y_true_batch, y_pred_batch)]# 根据 y = x @ w + b# grad_w = -x.T @ residualgrad_w = matmul(transpose(x_batch), residual)grad_w = [-gw / batch_size for gw in grad_w]grad_b = -sum(residual) / batch_size# grad_w: List[float]# grad_b: floatreturn grad_w, grad_b# 开启训练
def train():lr = 0.01epochs = 50batch_size = 16dim_feat = 3num_samples = 500weights = [random.random() * 0.1 for _ in range(dim_feat)]bias = random.random() * 0.1print('original params')print('w:', weights)print('b:', bias)X, y = generate_data(num_samples, weights, bias, noise=0.1)for epoch in range(epochs):for i in range(0, num_samples, batch_size):x_batch = X[i:i+batch_size]y_batch = y[i:i+batch_size]y_pred = [item + bias for item in matmul(x_batch, weights)]loss = mse(y_batch, y_pred)grad_w, grad_b = compute_grad(y_batch, y_pred, x_batch)weights = [w - lr * gw for w, gw in zip(weights, grad_w)]bias -= lr * grad_bprint(f'Epoch: {epoch + 1}, Loss = {loss:.3f}')print('trained params')print('w:', weights)print('b:', bias)train()

输出结果如下

original params
w: [0.04845598598148951, 0.007741816562531545, 0.02436678108587098]
b: 0.01644073086522535
Epoch: 1, Loss = 0.000
Epoch: 2, Loss = 0.000
Epoch: 3, Loss = 0.000
Epoch: 4, Loss = 0.000
Epoch: 5, Loss = 0.000
Epoch: 6, Loss = 0.000
Epoch: 7, Loss = 0.000
Epoch: 8, Loss = 0.000
Epoch: 9, Loss = 0.000
Epoch: 10, Loss = 0.000
Epoch: 11, Loss = 0.000
Epoch: 12, Loss = 0.000
Epoch: 13, Loss = 0.000
Epoch: 14, Loss = 0.000
Epoch: 15, Loss = 0.000
Epoch: 16, Loss = 0.000
Epoch: 17, Loss = 0.000
Epoch: 18, Loss = 0.000
Epoch: 19, Loss = 0.000
Epoch: 20, Loss = 0.000
Epoch: 21, Loss = 0.000
Epoch: 22, Loss = 0.000
Epoch: 23, Loss = 0.000
Epoch: 24, Loss = 0.000
Epoch: 25, Loss = 0.000
Epoch: 26, Loss = 0.000
Epoch: 27, Loss = 0.000
Epoch: 28, Loss = 0.000
Epoch: 29, Loss = 0.000
Epoch: 30, Loss = 0.000
Epoch: 31, Loss = 0.000
Epoch: 32, Loss = 0.000
Epoch: 33, Loss = 0.000
Epoch: 34, Loss = 0.000
Epoch: 35, Loss = 0.000
Epoch: 36, Loss = 0.000
Epoch: 37, Loss = 0.000
Epoch: 38, Loss = 0.000
Epoch: 39, Loss = 0.000
Epoch: 40, Loss = 0.000
Epoch: 41, Loss = 0.000
Epoch: 42, Loss = 0.000
Epoch: 43, Loss = 0.000
Epoch: 44, Loss = 0.000
Epoch: 45, Loss = 0.000
Epoch: 46, Loss = 0.000
Epoch: 47, Loss = 0.000
Epoch: 48, Loss = 0.000
Epoch: 49, Loss = 0.000
Epoch: 50, Loss = 0.000
trained params
w: [0.05073234817652038, 0.007306286342947243, 0.023218625946243507]
b: 0.016648404245261664

可以看到,结果还是不错的

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/848329.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

层出不穷的大模型产品,你怎么选?【模板】

层出不穷的大模型产品,你怎么选? 随着近日腾讯元宝APP的正式上线,国内大模型产品又添一员。关于接连出现的“全能“大模型AIGC产品,你都用过哪些呢?不妨来分享一下你的使用体验吧!在这些大模型产品中&…

使用Qt对word文档进行读写

目录 开发环境原理使用的QT库搭建开发环境准备word模板测试用例结果Gitee地址 开发环境 vs2022 Qt 5.9.1 msvc2017_x64,在文章最后提供了源码。 原理 Qt对于word文档的操作都是在书签位置进行插入文本、图片或表格的操作。 使用的QT库 除了基本的gui、core、…

鲁教版八年级数学上册-笔记

文章目录 第一章 因式分解1 因式分解2 提公因式法3 公式法 第二章 分式与分式方程1 认识分式2 分式的乘除法3 分式的加减法4 分式方程 第三章 数据的分析1 平均数2 中位数与众数3 从统计图分析数据的集中趋势4 数据的离散程度 第四章 图形的平移与旋转1 图形的平移2 图形的旋转…

解决 @Scope 注解失效问题:深入理解与排查方法

在使用 Spring 框架时,你可能遇到过 Scope 注解失效的情况。这个注解是用来定义 Bean 的作用域的,比如 singleton、prototype、request、session 等。当 Scope 注解失效时,意味着 Bean 的作用域没有被正确地设置,这可能会导致 Bea…

JavaWeb1 Json+BOM+DOM+事件监听

JS对象-Json //Json 字符串转JS对象 var jsObject Json.parse(userStr); //JS对象转JSON字符串 var jsonStr JSON.stringify(jsObject);JS对象-BOM BOM是浏览器对象模型,允许JS与浏览器对话 它包括5个对象:window、document、navigator、screen、hi…

力扣hot100:138. 随机链表的复制(技巧,数据结构)

LeetCode:138. 随机链表的复制 这是一个经典的数据结构题,当做数据结构来学习。 1、哈希映射 需要注意的是,指针也能够当做unordered_map的键值,指针实际上是一个地址值,在unordered_map中,使用指针的实…

VXLAN技术

VXLAN技术 一、VXLAN简介 1、定义 VXLAN(Virtual eXtensible Local Area Network):采用MAC in UDP(User Datagram Protocol)封装方式,是NVO3(Network Virtualization over Layer 3&#xff09…

使用 Logback.xml 配置文件输出日志信息

官方链接:Chapter 3: Configurationhttps://logback.qos.ch/manual/configuration.html 配置使用 logback 的方式有很多种,而使用配置文件是较为简单的一种方式,下述就是简单描述一个 logback 配置文件基本的配置项: 由于 logba…

Vuforia AR篇(七)— 二维码识别

目录 前言一、什么是Barcode ?二、使用步骤三、点击二维码显示信息四、效果 前言 在数字化时代,条形码和二维码已成为连接现实世界与数字信息的重要桥梁。Vuforia作为领先的AR开发平台,提供了Barcode Scanner功能,使得在Unity中实…

ros常用环境变量

RMW层DDS实现 rti dds export RMW_IMPLEMENTATIONrmw_connextdds //rti dds 或者 RMW_IMPLEMENTATIONrmw_connextdds ros2 run ... export NDDS_QOS_PROFILES/qos.xml //配置qos文件fastdds export RMW_IMPLEMENTATIONrmw_fastrtps_cpp 或者 RMW_IMPLEMENTATIONrmw_fas…

提供全面的网络监控和管理功能,帮助客户实时了解网络状态和优化网络性能

联通IP Transit产品依托中国联通在全球范围内的AS4837/AS10099网络平台,采用BGP对接技术,为客户自有的IP地址段提供全球互联网络穿透服务。通过这一产品,客户可以享受到专属带宽带来的优质访问体验,快速、高效地将网络数据内容接入…

力扣1438.绝对差不超过限制的最长连续子数组

力扣1438.绝对差不超过限制的最长连续子数组 难点&#xff1a;保存数组缩小后的最大最小值 用两个单调队列分别处理最大值和最小值 class Solution {public:int longestSubarray(vector<int>& nums, int limit) {deque<int> quemax,quemin;int n nums.size…

Http和Socks的区别?

HTTP和SOCKS都是用于网络通信的协议&#xff0c;但它们在设计目标和应用场景上有显著的区别。 一、HTTP (HyperText Transfer Protocol) HTTP是用于分布式、协作和超媒体信息系统的应用层协议。主要特点包括&#xff1a; 用途&#xff1a;HTTP主要用于万维网&#xff0c;通过…

json和axion结合

目录 java中使用JSON对象 在pom.xml中导入依赖 使用 public static String toJSONString(Object object)把自定义对象变成JSON对象 json和axios综合案例 使用的过滤器 前端代码 响应和请求都是普通字符串 和 请求时普通字符串&#xff0c;响应是json字符串 响应的数据是…

MySQL换路径(文件夹)

#MySQL作为免费数据库很受欢迎&#xff0c;即使公司没有使用&#xff0c;自己也可以用。它是一个服务&#xff0c;在点击CtrlAltDelete选择任务管理器后&#xff0c;它在服务那个归类里。 经常整理计算机磁盘分类的小伙伴&#xff0c;如果你们安装了MySQL&#xff0c;并且想移…

[Vue3] 滚动条自动滚动到底部

需求 在一个区域会依次打印log&#xff0c;随着log的加长&#xff0c;出现滚动条&#xff0c;而滚动条应该始终保持在最下方。 点击回到顶部按钮&#xff0c;可以使滚动条回到最上方 方案 在滚动区域添加reflog为一个数组&#xff0c;对其添加watch在watch函数中&#xff0c…

actuator/env;.js 漏洞修复

该问题是指Spring Boot Actuator中的一个漏洞&#xff0c;它涉及到暴露了Spring Boot应用的环境信息。Spring Boot Actuator是一个用于监控和管理Spring Boot应用的组件&#xff0c;它提供了多个端点&#xff08;endpoints&#xff09;&#xff0c;如健康检查、度量收集、环境信…

插件:Plugins

一、安装网格插件

重大变化,2024软考!

根据官方发布的2024年度计算机技术与软件专业技术资格&#xff08;水平&#xff09;考试安排&#xff0c;2024年软考上、下半年开考科目有着巨大变化&#xff0c;我为大家整理了相关信息&#xff0c;大家可以看看&#xff01; &#x1f3af;2024年上半年&#xff1a;5月25日&am…

业务安全蓝军测评标准解读—业务安全体系化

目录 1.前言 2.业务蓝军测评标准 2.1 业务安全脆弱性评分(ISVS) 2.2 ISVS评分的参考意义<