PyTorch实现逻辑回归

最终效果

先看下最终效果:
1
这里用一条直线把二维平面上不同的点分开。

生成随机数据

#创建训练数据
x = torch.rand(10,1)*10 #shape(10,1)
y = 2*x + (5 + torch.randn(10,1))#构建线性回归参数
w = torch.randn((1))#随机初始化w,要用到自动梯度求导
b = torch.zeros((1))#使用0初始化b,要用到自动梯度求导n_data = torch.ones(100, 2)
xy0 = torch.normal(2 * n_data, 1.5)  # 生成均值为2.标准差为1.5的随机数组成的矩阵
c0 = torch.zeros(100)
xy1 = torch.normal(-2 * n_data, 1.5)  # 生成均值为-2.标准差为1.5的随机数组成的矩阵
c1 = torch.ones(100)x,y = torch.cat((xy0,xy1),0).type(torch.FloatTensor).split(1, dim=1)
x = x.squeeze()
y = y.squeeze()
c = torch.cat((c0,c1),0).type(torch.FloatTensor)

数据可视化

def plot(x, y, c):ax = plt.gca()sc = ax.scatter(x, y, color='black')paths = []for i in range(len(x)):if c[i].item() == 0:marker_obj = mmarkers.MarkerStyle('o')else:marker_obj = mmarkers.MarkerStyle('x')path = marker_obj.get_path().transformed(marker_obj.get_transform())paths.append(path)sc.set_paths(paths)return sc
plot(x, y, c)
plt.show()

使用x和o来表示两种不同类别的数据。
1

定义模型和损失函数

#构建逻辑回归参数
w = torch.tensor([1.,],requires_grad=True)  # 随机初始化w
b = torch.zeros((1),requires_grad=True)  # 使用0初始化bwx = torch.mul(w,x) # w*x
y_pred = torch.add(wx,b) # y = w*x + b
loss = (0.5*(y-y_pred)**2).mean()

这里使用了平方损失函数来估算模型准确度。

训练模型

最多训练100次,每次都会更新模型参数,当损失值小于0.03时停止训练。

xx = torch.arange(-4, 5)
lr = 0.02 #学习率
for iteration in range(100):#前向传播loss = ((torch.sigmoid(x*w+b-y) - c)**2).mean()#反向传播loss.backward()#更新参数b.data.sub_(lr*b.grad) # b = b - lr*b.gradw.data.sub_(lr*w.grad) # w = w - lr*w.grad#绘图if iteration % 3 == 0:plot(x, y, c)yy = w*xx + bplt.plot(xx.data.numpy(),yy.data.numpy(),'r-',lw=5)plt.text(-4,2,'Loss=%.4f'%loss.data.numpy(),fontdict={'size':20,'color':'black'})plt.xlim(-4,4)plt.ylim(-4,4)plt.title("Iteration:{}\nw:{},b:{}".format(iteration,w.data.numpy(),b.data.numpy()))plt.show()if loss.data.numpy() < 0.03:  # 停止条件break

全部代码

import torch
import matplotlib.pyplot as plt
import matplotlib.markers as mmarkers#创建训练数据
x = torch.rand(10,1)*10 #shape(10,1)
y = 2*x + (5 + torch.randn(10,1))#构建线性回归参数
w = torch.randn((1))#随机初始化w,要用到自动梯度求导
b = torch.zeros((1))#使用0初始化b,要用到自动梯度求导wx = torch.mul(w,x) # w*x
y_pred = torch.add(wx,b) # y = w*x + bn_data = torch.ones(100, 2)
xy0 = torch.normal(2 * n_data, 1.5)  # 生成均值为2.标准差为1.5的随机数组成的矩阵
c0 = torch.zeros(100)
xy1 = torch.normal(-2 * n_data, 1.5)  # 生成均值为-2.标准差为1.5的随机数组成的矩阵
c1 = torch.ones(100)x,y = torch.cat((xy0,xy1),0).type(torch.FloatTensor).split(1, dim=1)
x = x.squeeze()
y = y.squeeze()
c = torch.cat((c0,c1),0).type(torch.FloatTensor)def plot(x, y, c):ax = plt.gca()sc = ax.scatter(x, y, color='black')paths = []for i in range(len(x)):if c[i].item() == 0:marker_obj = mmarkers.MarkerStyle('o')else:marker_obj = mmarkers.MarkerStyle('x')path = marker_obj.get_path().transformed(marker_obj.get_transform())paths.append(path)sc.set_paths(paths)return sc
plot(x, y, c)
plt.show()#构建逻辑回归参数
w = torch.tensor([1.,],requires_grad=True)#随机初始化w
b = torch.zeros((1),requires_grad=True)#使用0初始化bwx = torch.mul(w,x) # w*x
y_pred = torch.add(wx,b) # y = w*x + b
loss = (0.5*(y-y_pred)**2).mean()xx = torch.arange(-4, 5)
lr = 0.02 #学习率
for iteration in range(100):#前向传播loss = ((torch.sigmoid(x*w+b-y) - c)**2).mean()#反向传播loss.backward()#更新参数b.data.sub_(lr*b.grad) # b = b - lr*b.gradw.data.sub_(lr*w.grad) # w = w - lr*w.grad#绘图if iteration % 3 == 0:plot(x, y, c)yy = w*xx + bplt.plot(xx.data.numpy(),yy.data.numpy(),'r-',lw=5)plt.text(-4,2,'Loss=%.4f'%loss.data.numpy(),fontdict={'size':20,'color':'black'})plt.xlim(-4,4)plt.ylim(-4,4)plt.title("Iteration:{}\nw:{},b:{}".format(iteration,w.data.numpy(),b.data.numpy()))plt.show()if loss.data.numpy() < 0.03:#停止条件break

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/210885.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用 ROS 和 Geomagic Haptic 驱动 Franka 机械臂

文章目录 前言一、安装 franka_ros二、安装 OpenHaptics for Linux三、安装 3D Systems Geomagic Touch ROS Driver四、安装 franka_interactive_controllers五、使用 Geomagic Haptic 驱动 Franka 机械臂 前言 本文为在双系统上使用 ROS 和 Geomagic Haptic 驱动 Franka 机械…

滑动窗口(单调队列)

154. 滑动窗口 - AcWing题库 给定一个大小为 n≤10^6≤10^6 的数组。 有一个大小为 k 的滑动窗口&#xff0c;它从数组的最左边移动到最右边。 你只能在窗口中看到 k 个数字。 每次滑动窗口向右移动一个位置。 以下是一个例子&#xff1a; 该数组为 [1 3 -1 -3 5 3 6 7]&…

HashMap的那些事

一、HashMap与HashTable的区别 1.来历 HashTable是一种键值映射的数据结构&#xff0c;自从java发布就存在&#xff0c;而HashMap是jdk1.2后才出现的&#xff0c;虽然说HashTable出现得早且线程安全&#xff0c;但是效率很低已经弃用了&#xff0c;现在HashMap逐渐成为主流 …

Nmap脚本未来的发展趋势

Nmap脚本技术的发展趋势和前景 Nmap脚本是一种基于Lua语言开发的脚本&#xff0c;可以扩展Nmap的功能&#xff0c;用于自动化扫描、漏洞检测、服务探测、设备管理等方面。随着网络安全的不断发展和漏洞的不断出现&#xff0c;Nmap脚本技术也在不断发展和壮大。在本文中&#xf…

小米手机锁屏时间设置为永不休眠_手机不息屏_保持亮屏

环境&#xff1a;打开手机自带的锁屏时间设置发现没有 永不息屏的选项 原因&#xff1a;采用了三星OLED屏幕&#xff0c;所以根据OLED屏幕特性&#xff0c;这个是为了防止烧屏而特意设计的。非OLED机型支持设置“永不” 解决方案1&#xff1a;原生系统是支持永不锁屏的&#…

Android 13 - Media框架(20)- ACodec(二)

这一节开始我们就来学习 ACodec 的实现 1、创建 ACodec ACodec 是在 MediaCodec 中创建的&#xff0c;这里先贴出创建部分的代码&#xff1a; mCodec mGetCodecBase(name, owner);if (mCodec NULL) {ALOGE("Getting codec base with name %s (owner%s) failed", n…

ES 如何将国际标准时间格式进行格式化与调整时区

需求&#xff0c;日志收集的时候&#xff0c;时间格式是国际标准时间格式。形如yyyy-MM-ddTHH:mm:ss.SSS。 &#xff08;2023-12-05T02:45:50.282Z&#xff09;这个时区也不对&#xff0c;那如何将此类型的时间&#xff0c;进行格式化呢&#xff1f; 本篇文章体统一个案例&…

Other -- ChatGPT 原理

本文为个人理解&#xff0c;帮助小白&#xff08;本人就是&#xff09;了解正在创建新时代的 AI 产品&#xff0c;如文中理解有误欢迎留言。 [参考链接--](https://baijiahao.baidu.com/s?id1765556782543603120&wfrspider&forpc) 1. 了解一些基本概念 大语言模型&a…

修改 Ganglia 监控 Grid Report timezone 时区 为 东八区 +8 PRC

Ganglia 监控 Grid Report timezone 默认时区 为 零时区 0 现在要修改为 东八区 8 具体操作如下 modify ganglia-web report timezone 0 --> 8 vim /apps/svr/httpd-2.4.48/htdocs/ganglia/header.php // add timezone GMT8 ini_set(date.timezone, PRC);详细记录&#x…

【面试】测试/测开(ING)

63. APP端特有的测试 参考&#xff1a;APP专项测试、APP应用测试 crash和anr的区别 1&#xff09;网络测试 2&#xff09;中断测试 3&#xff09;安装、卸载测试 4&#xff09;兼容测试 5&#xff09;性能测试&#xff08;耗电量、流量、内存、服务器端&#xff09; 6&#xf…

画对比折线图【Python】

出这一期想必是我做某个课程作业遇到了。 由于去各个官网下载对比图要钱&#xff0c;我还是不想花钱的&#xff01;真讨厌&#xff01;浅浅水一期。 以下是要做的对比图的数据&#xff1a; 代码&#xff1a; from matplotlib import pyplot as plt#设置中文显示plt.rcParams[…

CSS新手入门笔记整理:CSS浮动布局

文档流概述 正常文档流 “文档流”指元素在页面中出现的先后顺序。正常文档流&#xff0c;又称为“普通文档流”或“普通流”&#xff0c;也就是W3C标准所说的“normal flow”。正常文档流&#xff0c;将一个页面从上到下分为一行一行&#xff0c;其中块元素独占一行&#xf…

ChatGPT OpenAI API请求限制 尝试解决

1. OpenAI API请求限制 Retrying langchain.chat_models.openai.ChatOpenAI.completion_with_retry.._completion_with_retry in 4.0 seconds as it raised RateLimitError: Rate limit reached for gpt-3.5-turbo-16k in organization org-U7I2eKpAo6xA7RUa2Nq307ae on reques…

让内存无处可逃:智能指针[C++11]

智能指针 文章目录 智能指针前言RAII什么是智能指针智能指针的应用示例 C98的auto_ptr共享型智能指针&#xff1a;shared_ptrshared_ptr的使用初始化获取原生指针指定删除器默认删除器default_delete指定删除器指定删除器管理动态数组 shared_ptr的伪实现shared_ptr的注意事项避…

【Docker】进阶之路:(五)Docker引擎

【Docker】进阶之路&#xff1a;&#xff08;五&#xff09;Docker引擎 Docker引擎简介Docker引擎的组件构成runccontainerd Docker引擎简介 Docker引擎是用来运行和管理容器的核心部分。Docker首次发布时&#xff0c;Docker 引擎由LXC 和 Docker daemon 两个核心组件构成。 …

linux驱动开发——内核调试技术

目录 一、前言 二、内核调试方法 2.1 内核调试概述 2.2 学会分析内核源程序 2.3调试方法介绍 三、内核打印函数 3.1内核镜像解压前的串口输出函数 3.2 内核镜像解压后的串口输出函数 3.3 内核打印函数 四、获取内核信息 4.1系统请求键 4.2 通过/proc 接口 4.3 通过…

算法:有效的括号(入栈出栈)

时间复杂度 O(n) 空间复杂度 O(n∣Σ∣)&#xff0c;其中 Σ 表示字符集&#xff0c;本题中字符串只包含 6 种括号 /*** param {string} s* return {boolean}*/ var isValid function(s) {const map {"(":")","{":"}","["…

List截取指定长度(java截取拼接URL)

场景&#xff1a; N多个参数&#xff0c;截取指定个数&#xff0c;拼接URL public static void main(final String[] args) {int count 0;//每页数量final int pageSize 5;final List<Integer> memberNos ListUtil.toList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13…

python格式化内容

1.字符串格式化: 定义列表 [{"姓名": "张三", "年龄": 18, "性别": "男"}, {"姓名": "里斯李四李斯", "年龄": 18, "性别": "男"}, {"姓名": "斯托夫斯基…

C++知识 抽象基类

抽象基类通常包含至少一个纯虚函数&#xff0c;即一个没有具体实现的虚函数&#xff0c;通过在基类中声明纯虚函数&#xff0c;它强制派生类提供这个函数的具体实现。 通过在类的声明中使用 virtual 关键字和 0 初始化来创建纯虚函数&#xff0c;这样的类就成为抽象基类。以下…