PyTorch实现逻辑回归

最终效果

先看下最终效果:
1
这里用一条直线把二维平面上不同的点分开。

生成随机数据

#创建训练数据
x = torch.rand(10,1)*10 #shape(10,1)
y = 2*x + (5 + torch.randn(10,1))#构建线性回归参数
w = torch.randn((1))#随机初始化w,要用到自动梯度求导
b = torch.zeros((1))#使用0初始化b,要用到自动梯度求导n_data = torch.ones(100, 2)
xy0 = torch.normal(2 * n_data, 1.5)  # 生成均值为2.标准差为1.5的随机数组成的矩阵
c0 = torch.zeros(100)
xy1 = torch.normal(-2 * n_data, 1.5)  # 生成均值为-2.标准差为1.5的随机数组成的矩阵
c1 = torch.ones(100)x,y = torch.cat((xy0,xy1),0).type(torch.FloatTensor).split(1, dim=1)
x = x.squeeze()
y = y.squeeze()
c = torch.cat((c0,c1),0).type(torch.FloatTensor)

数据可视化

def plot(x, y, c):ax = plt.gca()sc = ax.scatter(x, y, color='black')paths = []for i in range(len(x)):if c[i].item() == 0:marker_obj = mmarkers.MarkerStyle('o')else:marker_obj = mmarkers.MarkerStyle('x')path = marker_obj.get_path().transformed(marker_obj.get_transform())paths.append(path)sc.set_paths(paths)return sc
plot(x, y, c)
plt.show()

使用x和o来表示两种不同类别的数据。
1

定义模型和损失函数

#构建逻辑回归参数
w = torch.tensor([1.,],requires_grad=True)  # 随机初始化w
b = torch.zeros((1),requires_grad=True)  # 使用0初始化bwx = torch.mul(w,x) # w*x
y_pred = torch.add(wx,b) # y = w*x + b
loss = (0.5*(y-y_pred)**2).mean()

这里使用了平方损失函数来估算模型准确度。

训练模型

最多训练100次,每次都会更新模型参数,当损失值小于0.03时停止训练。

xx = torch.arange(-4, 5)
lr = 0.02 #学习率
for iteration in range(100):#前向传播loss = ((torch.sigmoid(x*w+b-y) - c)**2).mean()#反向传播loss.backward()#更新参数b.data.sub_(lr*b.grad) # b = b - lr*b.gradw.data.sub_(lr*w.grad) # w = w - lr*w.grad#绘图if iteration % 3 == 0:plot(x, y, c)yy = w*xx + bplt.plot(xx.data.numpy(),yy.data.numpy(),'r-',lw=5)plt.text(-4,2,'Loss=%.4f'%loss.data.numpy(),fontdict={'size':20,'color':'black'})plt.xlim(-4,4)plt.ylim(-4,4)plt.title("Iteration:{}\nw:{},b:{}".format(iteration,w.data.numpy(),b.data.numpy()))plt.show()if loss.data.numpy() < 0.03:  # 停止条件break

全部代码

import torch
import matplotlib.pyplot as plt
import matplotlib.markers as mmarkers#创建训练数据
x = torch.rand(10,1)*10 #shape(10,1)
y = 2*x + (5 + torch.randn(10,1))#构建线性回归参数
w = torch.randn((1))#随机初始化w,要用到自动梯度求导
b = torch.zeros((1))#使用0初始化b,要用到自动梯度求导wx = torch.mul(w,x) # w*x
y_pred = torch.add(wx,b) # y = w*x + bn_data = torch.ones(100, 2)
xy0 = torch.normal(2 * n_data, 1.5)  # 生成均值为2.标准差为1.5的随机数组成的矩阵
c0 = torch.zeros(100)
xy1 = torch.normal(-2 * n_data, 1.5)  # 生成均值为-2.标准差为1.5的随机数组成的矩阵
c1 = torch.ones(100)x,y = torch.cat((xy0,xy1),0).type(torch.FloatTensor).split(1, dim=1)
x = x.squeeze()
y = y.squeeze()
c = torch.cat((c0,c1),0).type(torch.FloatTensor)def plot(x, y, c):ax = plt.gca()sc = ax.scatter(x, y, color='black')paths = []for i in range(len(x)):if c[i].item() == 0:marker_obj = mmarkers.MarkerStyle('o')else:marker_obj = mmarkers.MarkerStyle('x')path = marker_obj.get_path().transformed(marker_obj.get_transform())paths.append(path)sc.set_paths(paths)return sc
plot(x, y, c)
plt.show()#构建逻辑回归参数
w = torch.tensor([1.,],requires_grad=True)#随机初始化w
b = torch.zeros((1),requires_grad=True)#使用0初始化bwx = torch.mul(w,x) # w*x
y_pred = torch.add(wx,b) # y = w*x + b
loss = (0.5*(y-y_pred)**2).mean()xx = torch.arange(-4, 5)
lr = 0.02 #学习率
for iteration in range(100):#前向传播loss = ((torch.sigmoid(x*w+b-y) - c)**2).mean()#反向传播loss.backward()#更新参数b.data.sub_(lr*b.grad) # b = b - lr*b.gradw.data.sub_(lr*w.grad) # w = w - lr*w.grad#绘图if iteration % 3 == 0:plot(x, y, c)yy = w*xx + bplt.plot(xx.data.numpy(),yy.data.numpy(),'r-',lw=5)plt.text(-4,2,'Loss=%.4f'%loss.data.numpy(),fontdict={'size':20,'color':'black'})plt.xlim(-4,4)plt.ylim(-4,4)plt.title("Iteration:{}\nw:{},b:{}".format(iteration,w.data.numpy(),b.data.numpy()))plt.show()if loss.data.numpy() < 0.03:#停止条件break

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/210885.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

小米手机锁屏时间设置为永不休眠_手机不息屏_保持亮屏

环境&#xff1a;打开手机自带的锁屏时间设置发现没有 永不息屏的选项 原因&#xff1a;采用了三星OLED屏幕&#xff0c;所以根据OLED屏幕特性&#xff0c;这个是为了防止烧屏而特意设计的。非OLED机型支持设置“永不” 解决方案1&#xff1a;原生系统是支持永不锁屏的&#…

ES 如何将国际标准时间格式进行格式化与调整时区

需求&#xff0c;日志收集的时候&#xff0c;时间格式是国际标准时间格式。形如yyyy-MM-ddTHH:mm:ss.SSS。 &#xff08;2023-12-05T02:45:50.282Z&#xff09;这个时区也不对&#xff0c;那如何将此类型的时间&#xff0c;进行格式化呢&#xff1f; 本篇文章体统一个案例&…

Other -- ChatGPT 原理

本文为个人理解&#xff0c;帮助小白&#xff08;本人就是&#xff09;了解正在创建新时代的 AI 产品&#xff0c;如文中理解有误欢迎留言。 [参考链接--](https://baijiahao.baidu.com/s?id1765556782543603120&wfrspider&forpc) 1. 了解一些基本概念 大语言模型&a…

【面试】测试/测开(ING)

63. APP端特有的测试 参考&#xff1a;APP专项测试、APP应用测试 crash和anr的区别 1&#xff09;网络测试 2&#xff09;中断测试 3&#xff09;安装、卸载测试 4&#xff09;兼容测试 5&#xff09;性能测试&#xff08;耗电量、流量、内存、服务器端&#xff09; 6&#xf…

画对比折线图【Python】

出这一期想必是我做某个课程作业遇到了。 由于去各个官网下载对比图要钱&#xff0c;我还是不想花钱的&#xff01;真讨厌&#xff01;浅浅水一期。 以下是要做的对比图的数据&#xff1a; 代码&#xff1a; from matplotlib import pyplot as plt#设置中文显示plt.rcParams[…

CSS新手入门笔记整理:CSS浮动布局

文档流概述 正常文档流 “文档流”指元素在页面中出现的先后顺序。正常文档流&#xff0c;又称为“普通文档流”或“普通流”&#xff0c;也就是W3C标准所说的“normal flow”。正常文档流&#xff0c;将一个页面从上到下分为一行一行&#xff0c;其中块元素独占一行&#xf…

ChatGPT OpenAI API请求限制 尝试解决

1. OpenAI API请求限制 Retrying langchain.chat_models.openai.ChatOpenAI.completion_with_retry.._completion_with_retry in 4.0 seconds as it raised RateLimitError: Rate limit reached for gpt-3.5-turbo-16k in organization org-U7I2eKpAo6xA7RUa2Nq307ae on reques…

让内存无处可逃:智能指针[C++11]

智能指针 文章目录 智能指针前言RAII什么是智能指针智能指针的应用示例 C98的auto_ptr共享型智能指针&#xff1a;shared_ptrshared_ptr的使用初始化获取原生指针指定删除器默认删除器default_delete指定删除器指定删除器管理动态数组 shared_ptr的伪实现shared_ptr的注意事项避…

【Docker】进阶之路:(五)Docker引擎

【Docker】进阶之路&#xff1a;&#xff08;五&#xff09;Docker引擎 Docker引擎简介Docker引擎的组件构成runccontainerd Docker引擎简介 Docker引擎是用来运行和管理容器的核心部分。Docker首次发布时&#xff0c;Docker 引擎由LXC 和 Docker daemon 两个核心组件构成。 …

linux驱动开发——内核调试技术

目录 一、前言 二、内核调试方法 2.1 内核调试概述 2.2 学会分析内核源程序 2.3调试方法介绍 三、内核打印函数 3.1内核镜像解压前的串口输出函数 3.2 内核镜像解压后的串口输出函数 3.3 内核打印函数 四、获取内核信息 4.1系统请求键 4.2 通过/proc 接口 4.3 通过…

算法:有效的括号(入栈出栈)

时间复杂度 O(n) 空间复杂度 O(n∣Σ∣)&#xff0c;其中 Σ 表示字符集&#xff0c;本题中字符串只包含 6 种括号 /*** param {string} s* return {boolean}*/ var isValid function(s) {const map {"(":")","{":"}","["…

python格式化内容

1.字符串格式化: 定义列表 [{"姓名": "张三", "年龄": 18, "性别": "男"}, {"姓名": "里斯李四李斯", "年龄": 18, "性别": "男"}, {"姓名": "斯托夫斯基…

上位机与PLC:ModbusTCP通讯之数据类型转换

前请提要: 从PLC读取的数值,不管是读正负整数还是正负浮点数,读取过来后都会变成UInt16,也就是Ushort类型 一、ushort(UInt16)转成 Int32 源代码方法: //ushort类型转Int32类型的方法private int ushortToInt32(ushort[] date, int start){//先进行判断,长度是否正确…

模型评价指标

用训练好的模型结果进行预测&#xff0c;需要采用一些评价指标来进行评价&#xff0c;才可以得到最优的模型 常用的指标&#xff1a; 1.分类任务 ConfusionMatrix 混淆矩阵Accuracy 准确率Precision 精确率Recall 召回率F1 score H-mean值ROC Curve ROC曲线PR …

天池SQL训练营(三)-复杂查询方法-视图、子查询、函数等

-天池龙珠计划SQL训练营 SQL训练营页面地址&#xff1a;https://tianchi.aliyun.com/specials/promotion/aicampsql 3.1 视图 我们先来看一个查询语句&#xff08;仅做示例&#xff0c;未提供相关数据&#xff09; SELECT stu_name FROM view_students_info;单从表面上看起来…

rancher harvester deploy demo 【部署 harvester v1.2.1】

简介 Harvester 是一个现代的、开放的、可互操作的、基于Kubernetes的超融合基础设施(HCI)解决方案。它是一种开源替代方案&#xff0c;专为寻求云原生HCI解决方案的运营商而设计。Harvester运行在裸机服务器上&#xff0c;提供集成的虚拟化和分布式存储功能。除了传统的虚拟机…

SQL SELECT 语句

SELECT 语句用于从数据库中选取数据。 SQL SELECT 语句 SELECT 语句用于从数据库中选取数据。 结果被存储在一个结果表中&#xff0c;称为结果集。 SQL SELECT 语法 SELECT column1, column2, ... FROM table_name; 与 SELECT * FROM table_name; 参数说明&#xff1a; …

二分查找|双指针:LeetCode:2398.预算内的最多机器人数目

作者推荐 本文涉及的基础知识点 二分查找算法合集 滑动窗口 单调队列&#xff1a;计算最大值时&#xff0c;如果前面的数小&#xff0c;则必定被淘汰&#xff0c;前面的数早出队。 题目 你有 n 个机器人&#xff0c;给你两个下标从 0 开始的整数数组 chargeTimes 和 runnin…

算法:最长公共前缀(横向扫描和纵向扫描)

横向扫描 时间复杂度 O(m * n)&#xff0c;空间复杂度O(1) /*** param {string[]} strs* return {string}*/ var longestCommonPrefix function(strs) {// 先把第一个字符串拿出来let str strs[0]// 用 startsWith 检查数组中每个字符串是否以当前字符串为前缀while(!strs.e…

听GPT 讲Rust源代码--src/tools(11)

File: rust/src/tools/rust-analyzer/crates/hir/src/lib.rs 在Rust源代码中&#xff0c;rust/src/tools/rust-analyzer/crates/hir/src/lib.rs文件的作用是定义了Rust语言的高级抽象层次&#xff08;Higher-level IR&#xff0c;HIR&#xff09;。它包含了Rust语言的各种结构和…