半监督学习

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

目录

  • 介绍
  • 一、Self Training自训练
    • 1、介绍
    • 2、代码示例
    • 3、参数解释
  • 二、Label Propagation(标签传播)
    • 1、介绍
    • 2、代码示例
    • 3、参数解释
  • 三、Label Spreading(标签扩散)
    • 1、介绍
    • 2、代码示例
    • 3、参数解释


介绍

半监督学习(Semi-Supervised Learning,SSL)是机器学习领域中的一个重要分支,它结合了监督学习和无监督学习的思想,用于处理标签数据稀缺而无标签数据丰富的场景。
常用方法:

  • Self Training自训练
  • Label Propagation标签传播
  • Label Spreading标签扩散

一、Self Training自训练

1、介绍

Self Training自训练是一种简单的半监督学习方法,它首先使用已标记的数据训练一个监督学习模型。然后,该模型用于预测未标记数据的标签。预测最自信的标签可以被选择添加到训练集中,然后模型在新的、更大的训练集上重新训练。先训练一个小模型,再继续预测标签,类似于迁移学习。当无标签数据和有标签数据分布相同时,使用自训练方法效果最佳。

2、代码示例

  • 读入数据
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"import numpy as np
import pandas as pd
import warnings
warnings.filterwarnings("ignore")# 数据预处理计算函数
def preprocessing(df):from sklearn.impute import SimpleImputerfrom sklearn.preprocessing import StandardScalerfrom sklearn.pipeline import Pipelinefrom sklearn.preprocessing import OrdinalEncoderfrom sklearn.compose import ColumnTransformercat_cols= df.select_dtypes(include=["object"])   # 分类型变量num_cols= df.select_dtypes(include=["int", "float"])   # 数值型变量# 连续型数据num_imp= SimpleImputer(strategy='mean')  # 缺失值num_std= StandardScaler()  # 标准化num_pipeline= Pipeline(steps=[("num_imp", num_imp), ("num_std", num_std)])# 分类型数据cat_imp= SimpleImputer(strategy="most_frequent")  # 缺失值cat_encode= OrdinalEncoder()   # 数据编码cat_pipeline= Pipeline(steps=[("cat_imp", cat_imp), ("cat_encode", cat_encode)])col_trans= ColumnTransformer(transformers=[("num_pipeline", num_pipeline, num_cols.columns),("cat_pipeline", cat_pipeline, cat_cols.columns),])# 数据集处理的计算transfer= col_trans.fit(df)return transfer# 读入数据
raw_data= pd.read_csv('半监督学习.csv')
labels= raw_data.pop("resp_flag")  # 标签
  • 缺失数据对比
print("缺失值/总样本:"+str(labels.isnull().sum())+"/"+str(len(labels)))

在这里插入图片描述

  • 数据处理
    注意:切分的测试数据集一定是有标签的样本
# sklearn中的半监督学习算法要求,所有缺失的标签必须用-1填充
labels_fill= labels.fillna(-1)# 特征数据处理
transfer= preprocessing(raw_data)
data_trans= transfer.transform(raw_data)data_concat= pd.concat([labels_fill, pd.DataFrame(data_trans)], axis= 1)# 保存一部分有标签样本作为测试集
mask_labeled= (labels_fill != -1)
mask_unlabeled= (labels_fill == -1)data_labeled= data_concat[mask_labeled]
data_unlabeled= data_concat[mask_unlabeled]# 切分测试集
from sklearn.model_selection  import train_test_split
train, test= train_test_split(data_labeled, test_size=0.2, stratify= data_labeled["resp_flag"], random_state= 42)Xtrain= pd.concat([train, data_unlabeled])
Ytrain= Xtrain.pop("resp_flag")
  • 使用模型
from sklearn.ensemble import RandomForestClassifier
RF= RandomForestClassifier(oob_score=True)# Self Training
from sklearn.semi_supervised import SelfTrainingClassifier
RF_self_training= SelfTrainingClassifier(RF)
RF_self_training.fit(Xtrain, Ytrain)# 测试集模型评估
Xtest= test
Ytest= Xtest.pop("resp_flag")from sklearn.metrics import roc_auc_score
print("AUC: ", roc_auc_score(Ytest, RF_self_training.predict_proba(Xtest)[:, 1]))

在这里插入图片描述

3、参数解释

base_estimator: BaseEstimator,# 基学习器
threshold: Float = 0.75,# 默认阈值0.75,大于0.75,小于0.25会被打标签,该参数比k_best更为常用
criterion: Literal['threshold', 'k_best'] = "threshold",# 默认值threshold,为该值时和threshold参数相同,即设阈值,k_best超参数阈值,如为10,则不考虑预测概率,只取排名前10的打标签
k_best: Int = 10,# 超参数阈值,如为10,则不考虑预测概率,只取排名前10的打标签
max_iter: int | None = 10,# 最大迭代次数
verbose: bool = False

二、Label Propagation(标签传播)

在sklearn中,基于图算法的半监督学习有Label Propagation和Label Spreading两种。他们的主要区别是第二种方法带有正则化机制。

1、介绍

Label Propagation(标签传播)基本原理:Label Propagation算法基于图理论。算法首先构建一个图,其中每个节点代表一个数据点,无论是标记的还是未标记的。节点之间的边代表数据点之间的相似性。算法的目的是通过图传播标签信息,使未标记数据获得标签。

关键特点:
相似性度量:通常使用K近邻(KNN)或者基于核的方法来定义数据点之间的相似性。
标签传播:标签信息从标记数据点传播到未标记数据点,通过迭代过程实现。
适用场景:适合于数据量较大、标记数据稀缺的情况。

  • 以环形数据为例,绿色全是为打标签的数据:
    在这里插入图片描述
    打标签后数据结果如图:
from sklearn.semi_supervised import LabelPropagationlabel_propagation = LabelPropagation(kernel="knn")
label_propagation.fit(X, labels)output= np.asarray(label_propagation.transduction_)
outer_numbers = np.where(output == outer)[0]
inner_numbers = np.where(output == inner)[0]plt.figure(figsize=(4, 4))
plt.scatter(X[outer_numbers, 0], X[outer_numbers, 1],)
plt.scatter(X[inner_numbers, 0], X[inner_numbers, 1],);

在这里插入图片描述

2、代码示例

from sklearn.semi_supervised import LabelPropagationlabel_propagation = LabelPropagation(kernel="knn")
label_propagation.fit(Xtrain, Ytrain)Ytrain_propagation= label_propagation.transduction_from sklearn.ensemble import RandomForestClassifier
RF_propagation= RandomForestClassifier(oob_score=True)
RF_propagation.fit(Xtrain, Ytrain_propagation)print("AUC: ", roc_auc_score(Ytest, RF_propagation.predict_proba(Xtest)[:, 1]))

在这里插入图片描述

3、参数解释

    kernel: ((...) -> Any) | Literal['knn', 'rbf'] = "rbf",# knn:k近邻,RBF核用于计算图中节点之间的相似度。这些相似度值随后用于传播标签信息,从而根据相邻节点的标签来预测未知节点的标签,rbf函数和正态分布比较相似*,gamma: Float = 20, # rbf函数的系数,可以简单理解为正态分布的方差n_neighbors: Int = 7, # 附近的7个样本,哪个样本多,就打成哪个标签,为knn时生效max_iter: Int = 1000,# 迭代次数tol: float = 0.001,# 算法收敛的阈值n_jobs: Int | None = None

三、Label Spreading(标签扩散)

1、介绍

基本原理:Label Spreading和Label Propagation非常相似,但在处理标签信息和正则化方面有所不同。它同样基于构建图来传播标签。

关键特点:
正则化机制:Label Spreading引入了正则化参数,可以控制标签传播的过程,使算法更加健壮。
稳定性:由于正则化的存在,Label Spreading在面对噪声数据时通常比Label Propagation更稳定。
适用场景:同样适用于有大量未标记数据的情况,尤其当数据包含噪声或者不完全标记时。

2、代码示例

from sklearn.semi_supervised import LabelSpreadinglabel_spreading = LabelSpreading(kernel="knn", alpha= 0.2)
label_spreading.fit(Xtrain, Ytrain)Ytrain_spreading= label_spreading.transduction_from sklearn.ensemble import RandomForestClassifier
RF_spreading= RandomForestClassifier(oob_score=True)
RF_spreading.fit(Xtrain, Ytrain_spreading)print("AUC: ", roc_auc_score(Ytest, RF_spreading.predict_proba(Xtest)[:, 1]))

在这里插入图片描述

3、参数解释

	kernel: ((...) -> Any) | Literal['rbf', 'knn'] = "rbf",*,gamma: Float = 20,n_neighbors: Int = 7,alpha: Float = 0.2, # 正则化参数,用于控制算法对标签平滑的程度,值较小时,会更强调邻居节点信息,值较大时,更倾向于保持原始标签max_iter: Int = 30,tol: Float = 0.001, # 算法收敛的阈值n_jobs: Int | None = None

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/28935.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

618狂欢日,美味产品齐上阵,超值优惠等你享

这个充满激情与活力的6月,我们带着满满的诚意与惊喜,为广大美食爱好者们开启一场独特的618狂欢之旅。 当我们提及甘肃,那丰富多样的甘肃传统美食便是不得不说的瑰宝。烤馍,油饼,锅盔、擀面皮、浆水等每一种美食都…

你知道花洒其实起源于中国古代吗?

花洒作为日常生活中不可或缺的一部分,其发展历程不仅见证了人类文明的进步,也反映了生活美学的演变。从最初的简单构想到现代的智能化设计,花洒的变迁历程是一部生动的人类生活史。 早在隋朝时期,我们的祖先就已经有了花洒的初步构…

《纪元 1800》好玩吗? 苹果电脑能玩《纪元 1800》吗?

《纪元1800》是一款不错的策略游戏,这款游戏因为画面和玩法独特深受玩家们的喜爱。下面我们来看看《纪元 1800》好玩吗,苹果电脑能玩《纪元 1800》吗的相关内容。 一、《纪元1800》好玩吗 《纪元1800》是一款备受瞩目的策略游戏。下面让我们来看看这款…

初探工厂抽象模式

设计模式的-工厂模式 1.定义一个约定的规则抽象类 class ETFactory {createStore() {throw new Error(抽象方法,不允许直接调用,需重写)}createUser(){throw new Error(抽象方法,不允许直接调用,需重写)} } 案例:…

eNSP学习——OSPF在帧中继网络中的配置

目录 主要命令 原理概述 实验目的 实验场景 实验拓扑 实验编址 实验步骤 1、基本配置 2、在帧中继上搭建OSPF网络 主要命令 //检查帧中继的虚电路状态 display fr pvc-info//检查帧中继的映射表 display fr map-info//手工指定OSPF邻居,采用单播方式发送报文 [R1]os…

Android Compose 文本输入框TextField使用详解

一、 TextField介绍 TextField 允许用户输入和修改文本,也就是文本输入框。 TextField 分为三种: TextField是默认样式OutlinedTextField 是轮廓样式版本BasicTextField 允许用户通过硬件或软件键盘修改文本,但不提供提示或占位符等装饰&a…

youlai-boot项目的学习—本地数据库安装与配置

数据库脚本 在项目代码的路径下,有两个版本的mysql数据库脚本,使用对应的脚本就安装对应的数据库版本,本文件选择了5 数据库安装 这里在iterm2下使用homebrew安装mysql5 brew install mysql5.7注:记得配置端终下的科学上网&a…

实时工业数据采集分析系统高效处理产线信息!

对于大部分制造业企业,测量仪器的自动数据采集一直是个令人烦恼的事情,即使仪器已经具有RS232/485等接口,但仍然在使用一边测量,一边手工记录到纸张,再输入到PC中处理的方式,不但工作繁重,同时也…

try catch return语句情况分析

try catch return语句情况分析 try catch无finally语句写在最后 try catch try catch语法是一种对应于异常处理的语句,其中try语句内用于编写有异常存在可能的语句,而catch语句内用于编写捕获到异常的类型以及对异常对象的处理方法,本文主要…

鸿蒙: 基础认证

先贴鸿蒙认证 官网10个类别总结如下 https://developer.huawei.com/consumer/cn/training/dev-cert-detail/101666948302721398 10节课学习完考试 考试 90分合格 3次机会 1个小时 不能切屏 运行hello world hvigorfile.ts是工程级编译构建任务脚本 build-profile.json5是工程…

IPA清洁棉签 IPA清洁擦拭棒:打印机头、电子设备等清洁的有力工具!

在数字化快速发展的今天,打印机头、电子设备等已经成为了我们日常生活和工作中不可或缺的一部分。然而,随着使用时间的增长,这些设备往往会因为灰尘、油渍等污染物的积累而影响其性能。此时,一款高效、便捷的清洁工具就显得尤为重…

数据预处理之基于聚类的TOD异常值检测#matlab

1.基于聚类的异常值检测方法 物以类聚——相似的对象聚合在一起,基于聚类的异常点检测方法有两个共同特点: (1)先采用特殊的聚类算法处理输入数据而得到聚类,再在聚类的基础上来检测异常。 (2)只需要扫描数据集若干次,效率较高…

3D Gaussian Splatting Windows安装

1.下载源码 git clone https://github.com/graphdeco-inria/gaussian-splatting --recursive 2.安装cuda NVIDIA GPU Computing Toolkit CUDA Toolkit Archive | NVIDIA Developer 3.安装COLMAP https://github.com/colmap/colmap/releases/tag/3.9.1 下载完成需要添加环…

基于Springboot框架班级综合测评管理系统的设计与实现

开头语:你好呀,我是计算机学姐码农小野!如果有相关需求,可以私信联系我。 开发语言:Java 数据库:MySQL 技术:Springboot框架,B/S模式 工具:MyEclipse 系统展示 首页…

Go 并发控制:RWMutex 实战指南

💝💝💝欢迎莅临我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:「stormsha的主页」…

《地下城与勇士》新手攻略,开荒必备!云手机多开教程!

《地下城与勇士》(DNF)是一款广受欢迎的多人在线动作角色扮演游戏。玩家将在游戏中扮演不同职业的角色,通过打怪、做任务、PK等方式不断提升自己,探索广阔的阿拉德大陆。游戏中设有丰富的副本、装备、技能系统,玩家可以…

重磅!草料模板库更新,新增签到报名和旅游模板

本次共更新5个签到报名场景模板,以及6个旅游场景模板。 所有模板内容均可自定义修改,并可免费使用。 签到报名场景 签到报名场景更新了 活动报名、大型活动会议报名、展会邀请函、专题讲座活动报名和技能培训邀约报名 5个模板,基于不同的会…

6.13.1 使用残差神经网络堆叠集成进行乳腺肿块分类和诊断的综合框架

计算机辅助诊断 (CAD) 系统需要将肿瘤检测、分割和分类的自动化阶段按顺序集成到一个框架中,以协助放射科医生做出最终诊断决定。 介绍了使用堆叠的残差神经网络 (ResNet) 模型(即 ResNet50V2、ResNet101V2 和 ResNet152V2)进行乳腺肿块分类…

基于自编码器的心电图信号异常检测(Python)

使用的数据集来自PTB心电图数据库,包括14552个心电图记录,包括两类:正常心跳和异常心跳,采样频率为125Hz。 import numpy as np np.set_printoptions(suppressTrue) import pandas as pd import matplotlib.pyplot as plt import…

reverse-android-淘最热点so

资源 1. com.maihan.tredian 2021版 淘最热点 2. 该 app 没有加壳 ,也没混淆。 登录抓包 POST: https://api.taozuiredian.com/api/v1/auth/login/sms POST /api/v1/auth/login/sms HTTP/1.1 Content-Type: application/json Connection: close Charset: UTF-8 User-Agen…