【深度学习实验】线性模型(三):使用Pytorch实现简单线性模型:搭建、构造损失函数、计算损失值

目录

一、实验介绍

 二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

0. 导入库

1. 定义线性模型linear_model

2. 定义损失函数loss_function

3. 定义数据

4. 调用模型

5. 完整代码


一、实验介绍

  • 使用Pytorch实现
    • 线性模型搭建
    • 构造损失函数
    • 计算损失值

 二、实验环境

        本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

ChatGPT:

        线性模型是一种基本的机器学习模型,用于建立输入特征与输出之间的线性关系。它是一种线性组合模型,通过对输入特征进行加权求和,再加上一个偏置项,来预测输出值。

        线性模型的一般形式可以表示为:y = w1x1 + w2x2 + ... + wnxn + b,其中y是输出变量,x1, x2, ..., xn是输入特征,w1, w2, ..., wn是特征的权重,b是偏置项。模型的目标是通过调整权重和偏置项,使预测值与真实值之间的差异最小化。

线性模型有几种常见的应用形式:

  1. 线性回归(Linear Regression):用于建立输入特征与连续输出之间的线性关系。它通过最小化预测值与真实值的平方差来拟合最佳的回归直线。

  2. 逻辑回归(Logistic Regression):用于建立输入特征与二分类或多分类输出之间的线性关系。它通过使用逻辑函数(如sigmoid函数)将线性组合的结果映射到概率值,从而进行分类预测。

  3. 支持向量机(Support Vector Machines,SVM):用于二分类和多分类问题。SVM通过找到一个最优的超平面,将不同类别的样本分隔开。它可以使用不同的核函数来处理非线性问题。

  4. 岭回归(Ridge Regression)和Lasso回归(Lasso Regression):用于处理具有多重共线性(multicollinearity)的回归问题。它们通过对权重引入正则化项,可以减小特征的影响,提高模型的泛化能力。

        线性模型的优点包括简单、易于解释和计算效率高。它们在许多实际问题中都有广泛的应用。然而,线性模型也有一些限制,例如对非线性关系的建模能力较弱。在处理复杂的问题时,可以通过引入非线性特征转换或使用核函数进行扩展,以提高线性模型的性能。

本系列为实验内容,对理论知识不进行详细阐释

(咳咳,其实是没时间整理,待有缘之时,回来填坑)

0. 导入库

import torch

1. 定义线性模型linear_model

        该函数接受输入数据x,使用随机生成的权重w和偏置b,计算输出值output。这里的线性模型的形式为 output = x * w + b

def linear_model(x):w = torch.rand(1, 1, requires_grad=True)b = torch.randn(1, requires_grad=True)return torch.matmul(x, w) + b

2. 定义损失函数loss_function

      这里使用的是均方误差(MSE)作为损失函数,计算预测值与真实值之间的差的平方。

def loss_function(y_true, y_pred):loss = (y_pred - y_true) ** 2return loss

3. 定义数据

  • 生成一个随机的输入张量 x,形状为 (5, 1),表示有 5 个样本,每个样本的特征维度为 1。

  • 生成一个目标张量 y,形状为 (5, 1),表示对应的真实标签。

  • 打印数据的信息,包括每个样本的输入值x和目标值y
x = torch.rand(5, 1)
y = torch.tensor([1, -1, 1, -1, 1], dtype=torch.float32).view(-1, 1)
print("The data is as follows:")
for i in range(x.shape[0]):print("Item " + str(i), "x:", x[i][0], "y:", y[i])

4. 调用模型

  • 使用 linear_model 函数对输入 x 进行预测,得到预测结果 prediction

  • 使用 loss_function 计算预测结果与真实标签之间的损失,得到损失张量 loss

  • 打印了每个样本的损失值。
prediction = linear_model(x)
loss = loss_function(y, prediction)
print("The all loss value is:")
for i in range(len(loss)):print("Item ", str(i), "Loss:", loss[i])

5. 完整代码

import torchdef linear_model(x):w = torch.rand(1, 1, requires_grad=True)b = torch.randn(1, requires_grad=True)return torch.matmul(x, w) + bdef loss_function(y_true, y_pred):loss = (y_pred - y_true) ** 2return lossx = torch.rand(5, 1)
y = torch.tensor([1, -1, 1, -1, 1], dtype=torch.float32).view(-1, 1)
print("The data is as follows:")
for i in range(x.shape[0]):print("Item " + str(i), "x:", x[i][0], "y:", y[i])prediction = linear_model(x)
loss = loss_function(y, prediction)
print("The all loss value is:")
for i in range(len(loss)):print("Item ", str(i), "Loss:", loss[i])


注意:

        本实验的线性模型仅简单地使用随机权重和偏置,计算了模型在训练集上的均方误差损失,没有使用优化算法进行模型参数的更新。

        通常情况下会使用梯度下降等优化算法来最小化损失函数,并根据训练数据不断更新模型的参数,具体内容请听下回分解。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/81226.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

idea创建springboot项+集成阿里连接池druid

创建项目并集成流程 1:前提准备2:创建springboot项目流程3:集成阿里连接池步骤4:集成swagger方便测试5:书写增删改查进行测试6:项目gitee地址 1:前提准备 准备开发工具:idea java环…

nginx知识点详解:反向代理+负载均衡+动静分离+高可用集群

一、nginx基本概念 1. nginx是什么,做什么事情? Nginx是一个高性能的HTTP和反向代理服务器,特点是占有内存少,并发能力强。Nginx转为性能优化而开发,能经受高负载考验。支持热部署,启动容易,运…

linux内核分析:线程和进程创建,内存管理

lec18-19:进程与线程创建 lec20-21虚拟内存管理 内核代码,全局变量这些只有一份,但是内核栈有多份,这可能就是linux线程模型1对1模式的由来。通过栈来做的 x86 CPU支持分段和分页(平坦内存模式)两种 分段,选择子那里就有特权标记了

Linux多线程【线程控制】

✨个人主页: 北 海 🎉所属专栏: Linux学习之旅 🎃操作环境: CentOS 7.6 阿里云远程服务器 文章目录 🌇前言🏙️正文1、线程知识补充1.2、线程私有资源1.3、线程共享资源1.4、原生线程库 2、线程…

GIS跟踪监管系统单元信息更新

GIS跟踪监管系统单元信息更新 单元信息更新。① 新增单元。② 编辑单元。③ 删除单元。物资查询(1)物资查询与展示。① 几何查询。• 单击查询:• 拉框查询:• 多边形查询:② 物资定位。• 多个物资定位: 单…

Visual Studio2019报错

1- Visual Studio2019报错 错误 MSB8036 找不到 Windows SDK 版本 10.0.19041.0的解决方法 小伙伴们在更新到Visual Studio2019后编译项目时可能遇到过这个错误:“ 错误 MSB8036 找不到 Windows SDK 版本 10.0.19041.0的解决方法”,但是我们明明安装了该…

智慧公厕:不放过任何“卫生死角”,为公共厕所装上“净化系统”。

#智慧公厕[话题]# #智慧公厕系统[话题]# #智慧公厕管理系统[话题]# #智慧公厕设备[话题]# #智慧公厕厂家[话题]# 在社会活动中,公共厕所是我们经常使用和停留的场所。然而,由于传统公共厕所的粗放式管理,导致卫生情况差、设备不齐全、配置破…

MySql中分割字符串

MySql中分割字符串 在MySql中分割字符串可以用到SUBSTRING_INDEX(str, delim, count) 参数解说       解释 str         需要拆分的字符串 delim         分隔符,通过某字符进行拆分 count          当 count 为正数&…

Learn Prompt-ChatGPT 精选案例:内容总结

ChatGPT 可以通过分析内容并生成一个浓缩版本来总结文本。这对节省时间和精力很有帮助,特别是在阅读长篇文章、研究论文或报告时。 通用总结​ 你所要做的就是把具体的文字复制并粘贴到提示中,并要求ChatGPT对所选文本进行简化总结。这里我们参考opena…

Linux的调试工具 - gdb(超详细)

Linux的调试工具 - gdb 1. 背景2. 开始使用指令的使用都用下面这个C语言简单小代码来进行演示:1. list或l 行号:显示文件源代码,接着上次的位置往下列,每次列10行。2. list或l 函数名:列出某个函数的源代码。3. r或run: 运行程序。…

Java进化史:从Java 8到Java 17的语言特性全解析

文章目录 Java 8:引入Lambda表达式和Stream APILambda表达式Stream API Java 9:模块化系统模块Jigsaw项目 Java 10:局部变量类型推断Java 11:引入HTTP客户端HTTP客户端 Java 12:引入Switch表达式Switch表达式 Java 13到…

微服务架构介绍

系统架构的演变 1、技术架构发展历史时间轴 ①单机垂直拆分:应用间进行了解耦,系统容错提高了,也解决了独立应用发布的问题,存在单机计算能力瓶颈。 ②集群化负载均衡可有效解决单机情况下并发量不足瓶颈。 ③服务改造架构 虽然系…

Spring修炼之路--基础知识

一、核心概念 1.1软件模块化 软件模块化是一种软件开发的设计模式,它将一个大型的软件系统划分成多个独立的模块,每个模块都有自己的功能和接口,并且能够与其他模块独立地工作1. 软件模块化设计可以使软件不至于随着逐渐变大而变得不可控&am…

【ICASSP 2023】ST-MVDNET++论文阅读分析与总结

主要是数据增强的提点方式。并不能带来idea启发,但对模型性能有帮助 Challenge: 少有作品应用一些全局数据增强,利用ST-MVDNet自训练的师生框架,集成了更常见的数据增强,如全局旋转、平移、缩放和翻转。 Contributi…

Vulnhub实战-DC9

前言 本次的实验靶场是Vulnhub上面的DC-9,其中的渗透测试过程比较多,最终的目的是要找到其中的flag。 一、信息收集 对目标网络进行扫描 arp-scan -l 对目标进行端口扫描 nmap -sC -sV -oA dc-9 192.168.1.131 扫描出目标开放了22和80两个端口&a…

Python 之利用matplotlib.pyplot 生成图形和图表

文章目录 介绍运用 介绍 matplotlib.pyplot是Matplotlib库的一个子模块,它提供了一个简单的界面来创建各种类型的图形和图表。使用pyplot,您可以轻松创建、定制和显示图形,而无需编写大量的底层代码。以下是matplotlib.pyplot的一些常见用法…

天然气跟踪监管系统具体实现

物资跟踪监管系统具体实现 系统开发环境框架设计(1)在VS2017中创建一个项目工程(2)在web目录下新建一个index.htm页面,② 与前端界面和操作相关框架文件③ 自定义文件。 物资跟踪监管系统基于Leaflet开发库实现&#x…

虚拟机已经启动 但是xshell连接不上服务器

目录 一:关于ping的问题二.网络的问题--找到控制面板三:防火墙的问题 一:关于ping的问题 1.虚拟机ping百度 观察虚拟机是否有网络 2.windows下ping linux的ip地址 ping 虚拟机地址 3.linux下ping windows 二.网络的问题–找到控制面板 三:防火墙的问题…

21天学会C++:Day9----初识类与对象

CSDN的uu们,大家好。这里是C入门的第九讲。 座右铭:前路坎坷,披荆斩棘,扶摇直上。 博客主页: 姬如祎 收录专栏:C专题 目录 1. 面向过程与面向对象 2. 类的定义 3. 类中的访问限定符 3.1 访问限定符的…

基于Java的新能源充电系统的设计与实现(亮点:完整合理的充电流程,举报反馈机制、余额充值、在线支付、在线聊天)

新能源充电系统 一、前言二、我的优势2.1 自己的网站2.2 自己的小程序(小蔡coding)2.3 有保障的售后2.4 福利 三、开发环境与技术3.1 MySQL数据库3.2 Vue前端技术3.3 Spring Boot框架3.4 微信小程序 四、功能设计4.1 主要功能描述 五、系统实现5.1 完整充…