【深度学习】2.单层感知机

目标:

实现一个简单的二分类模型的训练过程,通过模拟数据集进行训练和优化,训练目标是使模型能够根据输入特征正确分类数据。

演示:

1.通过PyTorch生成了一个模拟的二分类数据集,包括特征矩阵data_x和对应的标签数据data_y。标签数据通过基于特征的线性组合生成,并转换成独热编码的形式。

import torch
# 从torch库中导入神经网络模块nn,用于构建神经网络模型
from torch import nn
# 导入torch.nn模块中的functional子模块,可用于访问各种函数,例如激活函数
import torch.nn.functional as Fn_item = 1000
n_feature = 2
learning_rate = 0.01
epochs = 100# 生成一个模拟的数据集,其中包括一个随机生成的特征矩阵data_x和相应生成的标签数据data_y。标签数据通过基于特征的线性组合生成,并且转换成独热编码的形式。# 设置随机数生成器的种子为123,通过设置随机种子,我们可以确保在每次运行代码时生成的随机数相同,这对于结果的可重现性非常重要。
torch.manual_seed(123)
# 生成一个随机数矩阵data_x,其中包含n_item行和n_feature列。矩阵中的元素是从标准正态分布(均值为0,标准差为1)中随机采样的。
data_x = torch.randn(size=(n_item, n_feature)).float()
# torch.where(...): 根据条件返回两个张量中相应位置的值。如果条件成立,将为0,否则为1。  long(): 用于将张量转换为Long型数据类型。
data_y = torch.where(torch.subtract(data_x[:, 0]*0.5, data_x[:, 1]*1.5)+0.02 > 0, 0, 1).long()
# 将标签数据data_y转换为独热编码形式,即将每个标签转换为一个相应长度的独热向量
data_y = F.one_hot(data_y)# print(data_x)
# print(data_y)

2.定义了一个简单的二分类模型BinaryClassificationModel,包含一个单层感知器(Single Perceptron)结构,其中使用了一个线性层和sigmoid激活函数,用于将输入特征映射到概率空间。

# 定义了一个简单的二分类模型,采用单层感知器的结构,包含一个线性层和sigmoid激活函数,用于将输入特征映射到概率空间。这样的模型可以用来对数据集进行二分类任务的预测。# 定义了一个名为BinaryClassificationModel的类,其继承自nn.Module类,这意味着这个类是一个PyTorch模型。
class BinaryClassificationModel(nn.Module):def __init__(self, in_feature):# 调用了父类nn.Module的构造函数,确保正确初始化模型。super(BinaryClassificationModel, self).__init__()"""single perception"""# 这行代码定义了模型的第一层,是一个线性层(Fully Connected Layer)。in_features参数指定输入特征的数量,out_features指定输出特征的数量,这里设置为2表示二分类问题。bias=True表示该层包含偏置项。self.layer_1 = nn.Linear(in_features=in_feature, out_features=2, bias=True)# 定义模型前向传播的方法,即输入数据x通过模型前向计算得到输出。def forward(self, x):# 输入数据x首先通过定义的线性层self.layer_1进行线性变换,然后通过F.sigmoid()函数进行激活函数处理。return F.sigmoid(self.layer_1(x))

3.创建了该二分类模型的实例model、使用随机梯度下降(SGD)优化器opt、以及二分类问题常用的损失函数BCELoss(Binary Cross Entropy Loss)。

4.在训练过程中,通过多个epoch和每个样本的批处理(在这里是一次处理一个样本),计算模型预测输出和真实标签之间的损失值,进行反向传播计算梯度,并更新模型参数以最小化损失函数。

# 完成对模型的训练过程,每个epoch中通过优化器进行参数更新,计算损失,反向传播更新梯度。最终我们会得到训练过程中每个epoch的损失值,并可以观察损失的变化情况。# 创建了一个二分类模型实例model,参数n_feature表示输入特征的数量。
model = BinaryClassificationModel(n_feature)
# 创建了一个随机梯度下降(SGD)优化器opt,用于根据计算出的梯度更新模型参数。
opt = torch.optim.SGD(model.parameters(), lr=learning_rate)
# 创建了一个二分类问题常用的损失函数BCELoss(Binary Cross Entropy Loss),用于衡量模型输出与真实标签之间的差异。
criteria = nn.BCELoss()for epoch in range(epochs):# 对每个样本进行训练。for step in range(n_item):x = data_x[step]y = data_y[step]# 梯度清零,避免梯度累加影响优化结果。opt.zero_grad()# 将输入特征x通过模型前向传播得到预测输出y_hat。unsqueeze(0)是因为我们的模型期望输入是(batch_size, n_feature)的形式。y_hat = model(x.unsqueeze(0))# 计算预测输出y_hat和真实标签y之间的损失值。loss = criteria(y_hat, y.unsqueeze(0).float())# 反向传播计算梯度。loss.backward()# 根据计算出的梯度更新模型参数。opt.step()print("Epoch: %03d, Loss: %.3f" % (epoch, loss.item()))

5.打印出每个epoch的序号和损失值,用于监控训练过程中损失值的变化情况。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/15936.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

加密与安全_AES RSA 密钥对生成及PEM格式的代码实现

文章目录 RSA(非对称)和AES(对称)加密算法一、RSA(Rivest-Shamir-Adleman)二、AES(Advanced Encryption Standard) RSA加密三种填充模式一、RSA填充模式二、常见的RSA填充模式组合三…

新业务 新市场 | 灵途科技新品亮相马来西亚亚洲防务展

5月6日,灵途科技携新品模组与武汉长盈通光电(股票代码:688143)携手参加第18届马来西亚亚洲防务展。首次亮相海外,灵途科技便收获全球客户的广泛关注,为公司海外市场开拓打下坚实基础。 灵途科技与长盈通共同…

Dbs封装_连接池

1.Dbs封装 每一个数据库都对应着一个dao 每个dao势必存在公共部分 我们需要将公共部分抽取出来 封装成一个工具类 保留个性化代码即可 我们的工具类一般命名为xxxs 比如Strings 就是字符串相关的工具类 而工具类 我们将其放置于util包中我们以是否有<T>区分泛型方法和非泛…

Python并发编程学习记录

1、初识并发编程 1.1、串行&#xff0c;并行&#xff0c;并发 串行(serial)&#xff1a;一个cpu上按顺序完成多个任务&#xff1b; 并行(parallelism)&#xff1a;任务数小于或等于cup核数&#xff0c;多个任务是同时执行的&#xff1b; 并发(concurrency)&#xff1a;一个…

计算机SCI期刊,IF=8+,专业性强,潜力新刊!

一、期刊名称 Journal of Big data 二、期刊简介概况 期刊类型&#xff1a;SCI 学科领域&#xff1a;计算机科学 影响因子&#xff1a;8.1 中科院分区&#xff1a;2区 出版方式&#xff1a;开放出版 版面费&#xff1a;$1990 三、期刊征稿范围 《大数据杂志》发表了关于…

2024年【T电梯修理】考试内容及T电梯修理新版试题

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 2024年【T电梯修理】考试内容及T电梯修理新版试题&#xff0c;包含T电梯修理考试内容答案和解析及T电梯修理新版试题练习。安全生产模拟考试一点通结合国家T电梯修理考试最新大纲及T电梯修理考试真题汇总&#xff0c;…

线性dp合集,蓝桥杯

贸易航线 0贸易航线 - 蓝桥云课 (lanqiao.cn) n,m,kmap(int ,input().split()) #贪心的想&#xff0c;如果买某个东西利润最大&#xff0c;那我肯定直接拉满啊&#xff0c;所以买k个和买一个没区别 p[0] for i in range(n):p.append([-1]list(map(int,input().split())))dp[[…

(2024,SDE,对抗薛定谔桥匹配,离散时间迭代马尔可夫拟合,去噪扩散 GAN)

Adversarial Schrdinger Bridge Matching 公众号&#xff1a;EDPJ&#xff08;进 Q 交流群&#xff1a;922230617 或加 VX&#xff1a;CV_EDPJ 进 V 交流群&#xff09; 目录 0. 摘要 1. 简介 4. 实验 0. 摘要 薛定谔桥&#xff08;Schrdinger Bridge&#xff0c;SB&…

el-autocomplete后台远程搜索

el-complete可以实现后台远程搜索功能&#xff0c;但有时传入数据为空时&#xff0c;接口可能会报错。此时可在querySearchAsync方法中&#xff0c;根据queryString判断&#xff0c;若为空&#xff0c;则不掉用接口&#xff0c;直接callback([])&#xff0c;反之则调用接口&…

浮点型比较大小

浮点数的存储形式 浮点数按照在内存中所占字节数和数值范围&#xff0c;可以分为浮点型&#xff0c;双精度浮点型和长双浮点型数。 代码&#xff1a; printf("lgn:%e \n", pow(exp(1), 100));printf("lgn:%f ", pow(exp(1), 100));输出结果&#xff1a; …

Stanford斯坦福 CS 224R: 深度强化学习 (5)

离线强化学习:第一部分 强化学习(RL)旨在让智能体通过与环境交互来学习最优策略,从而最大化累积奖励。传统的RL训练都是在线(online)进行的,即智能体在训练过程中不断与环境交互,实时生成新的状态-动作数据,并基于新数据来更新策略。这种在线学习虽然简单直观,但也存在一些局限…

【Could not find Chrome This can occur if either】

爬虫练习中遇到的问题 使用puppeteer执行是提示一下错误 Error: Could not find Chrome (ver. 125.0.6422.78). This can occur if either you did not perform an installation before running the script (e.g. npx puppeteer browsers install chrome) oryour cache path…

CLIP 论文的关键内容

CLIP 论文整体架构 该论文总共有 48 页&#xff0c;除去最后的补充材料十页去掉&#xff0c;正文也还有三十多页&#xff0c;其中大部分篇幅都留给了实验和响应的一些分析。 从头开始的话&#xff0c;第一页就是摘要&#xff0c;接下来一页多是引言&#xff0c;接下来的两页就…

常用 CSS 写法

不是最后一个 :not(:last-child)渐变色 background: linear-gradient(270deg, #15aaff 0%, #02396a 100%);文字渐变色 background-image: linear-gradient(to right, #ff7e5f, #feb47b); -webkit-background-clip: text; background-clip: text; color: transparent;

python文件IO基础知识

目录 1.open函数打开文件 2.文件对象读写数据和关闭 3.文本文件和二进制文件的区别 4.编码和解码 读写文本文件时 读写二进制文件时 5.文件指针位置 6.文件缓存区与flush()方法 1.open函数打开文件 使用 open 函数创建一个文件对象&#xff0c;read 方法来读取数据&…

谈谈磁盘的那些操作

磁盘格式化 是指把一张空白的盘划分成一个个小区域并编号&#xff0c;以供计算机存储和读取数据。格式化是一种纯物理操作&#xff0c;是在磁盘的所有数据区上写零的操作过程&#xff0c;同时对硬盘介质做一致性检测&#xff0c;并且标记出不可读和坏的扇区。由于大部分硬盘在…

电子技术学习路线

在小破站上看到大佬李皆宁的技术路线分析&#xff0c;再结合自己这几年的工作。发现的确是这样&#xff0c;跟着大佬的技术路线去学习是会轻松很多&#xff0c;现在想想&#xff0c;这路线其实跟大学四年的学习顺序是很像的。 本期记录学习路线&#xff0c;方便日后查看。 传统…

python 深度图生成点云(方法二)

深度图生成点云 一、介绍1.1 概念1.2 思路1.3 函数讲解二、代码示例三、结果示例接上篇:深度图生成点云(方法1) 一、介绍 1.1 概念 深度图生成点云:根据深度图像(depth image)和相机内参(camera intrinsics)生成点云(PointCloud)。 1.2 思路 点云坐标的计算公式如…

pillow学习7

绘制验证码 from PIL import Image,ImageFilter,ImageFont,ImageDraw import random width100 hight100 imImage.new(RGB,(width,hight),(255,255,255)) drawImageDraw.Draw(im) #获取颜色 def get_color1():return (random.randint(200, 255), random.randint(200, 255), ran…

京东Java社招面试题真题,最新面试题

Java中接口与抽象类的区别是什么&#xff1f; 1、定义方式&#xff1a; 接口是完全抽象的&#xff0c;只能定义抽象方法和常量&#xff0c;不能有实现&#xff1b;而抽象类可以有抽象方法和具体实现的方法&#xff0c;也可以定义成员变量。 2、实现与继承&#xff1a; 一个类…