【Python】 剪辑法欠采样 CNN压缩近邻法欠采样

借鉴:关于K近邻(KNN),看这一篇就够了!算法原理,kd树,球树,KNN解决样本不平衡,剪辑法,压缩近邻法 - 知乎

但是不要看他里面的代码,因为作者把代码里的一些符号故意颠倒了 ,比如“==”改成“!=”,还有乱加“~”,看明白逻辑才能给他改过来

一、剪辑法

        当训练集数据中存在一部分不同类别数据的重叠时(在一部分程度上说明这部分数据的类别比较模糊),这部分数据会对模型造成一定的过拟合,那么一个简单的想法就是将这部分数据直接剔除掉即可,也就是剪辑法。

        剪辑法将训练集 D 随机分成两个部分,一部分作为新的训练集 Dtrain,一部分作为测试集 Dtest,然后基于 Dtrain,使用 KNN 的方法对 Dtest 进行分类,并将其中分类错误的样本从整体训练集 D 中剔除掉,得到 Dnew。

        由于对训练集 D 的划分是随机划分,难以保证数据重叠部分的样本在第一次剪辑时就被剔除,因此在得到 Dnew 后,可以对 Dnew 继续进行上述操作数次,这样可以得到一个比较清爽的类别分界。

        效果如下图:

        附上可直接运行的代码:

from sklearn import datasets
import matplotlib.pyplot as pyplot
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier as KNN
import numpy as np
from collections import Counter
from numpy import where# make_classification用于手动构造数据
# 1000个样本,分成4类
X, y = datasets.make_classification(n_samples=1000, n_features=2,n_informative=2, n_redundant=0, n_repeated=0,n_classes=4, n_clusters_per_class=1)# # # 画出二维散点图
# for label, _ in counter.items():
# 	row_ix = where(y == label)[0]
# 	pyplot.scatter(X[row_ix, 0], X[row_ix, 1], label=str(label))
# pyplot.legend()
# pyplot.show()# 剪辑10次
for i in range(10):x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.5)k = 5KNN_clf = KNN(n_neighbors=k)KNN_clf.fit(x_train, y_train)  # 用训练集训练KNNy_predict = KNN_clf.predict(x_test)  # 用测试集测试cond = y_predict == y_testx_test = x_test[cond]  # 把预测错误的从整体数据集中剔除掉y_test = y_test[cond]  # 把预测错误的从整体数据集中剔除掉X = np.vstack([x_train, x_test])  # 为下一次循环做准备(剔除掉本轮预测错误的y = np.hstack([y_train, y_test])  # 为下一次循环做准备(剔除掉本轮预测错误的# summarize the new class distribution
counter = Counter(y)
print(counter)# 画出二维散点图
for label, _ in counter.items():row_ix = where(y == label)[0]pyplot.scatter(X[row_ix, 0], X[row_ix, 1], label=str(label))
pyplot.legend()
pyplot.show()

        以上使用了k=20的参数进行剪辑的结果,循环了10次,一般而言,k越大,被抛弃的样本会越多,因为被分类的错误的概率更大。

二、CNN压缩近邻法欠采样

        

        压缩近邻法的想法是认为同一类型的样本大量集中在类簇的中心,而这些集中在中心的样本对分类没有起到太大的作用,因此可以舍弃掉这些样本。

        其做法是将训练集随机分为两个部分,第一个部分为 store,占所有样本的 10% 左右,第二个部分为 grabbag,占所有样本的 90% 左右,然后将 store 作为训练集训练 KNN 模型,grabbag 作为测试集,将分类错误的样本从 grabbag 中移动到 store 里,然后继续用增加了样本的 store 和减少了样本的 grabbag 再次训练和测试 KNN 模型,直到 grabbag 中所有样本被分类正确,或者 grabbag 中样本数为0。

        在压缩结束之后,store 中存储的是初始化时随机选择的 10% 左右的样本,以及在之后每一次循环中被分类错误的样本,这些被分类错误的样本集中在类簇的边缘,认为是对分类作用较大的样本。

        CNN欠采样已经有相应的Python实现库了,相应的方法是CondensedNearestNeighbour(),下面是可直接运行的代码。

# Undersample and plot imbalanced dataset with the Condensed Nearest Neighbor Rule
from collections import Counter
from sklearn.datasets import make_classification
from imblearn.under_sampling import CondensedNearestNeighbour
from matplotlib import pyplot
from numpy import where# make_classification方法用于生成分类任务的人造数据集
# X是数据,几维都可以,n_features=4表示4维
# y用0/1表示类别,weights调整0和1的占比
X, y = make_classification(n_samples=500, n_classes=2, n_features=3, n_redundant=0,# n_clusters_per_class表示每个类别多少簇  # flip_y噪声,增加分类难度n_clusters_per_class=2, weights=[0.5], flip_y=0, random_state=1)# summarize class distribution
counter = Counter(y)  # {0: 990, 1: 10} counter是一个字典,value存储类别,key存储类别个数
print(counter)# ==================CNN有直接可以调用的包  n_neighbors设置k值,k值越小越省时间,就设置为1吧
undersample = CondensedNearestNeighbour(n_neighbors=1)
# transform the dataset
X, y = undersample.fit_resample(X, y)# summarize the new class distribution
counter = Counter(y)
print(counter)# scatter plot of examples by class label
for label, _ in counter.items():row_ix = where(y == label)[0]pyplot.scatter(X[row_ix, 0], X[row_ix, 1], label=str(label))
pyplot.legend()
pyplot.show()

        但是我觉得这个CondensedNearestNeighbour()方法的可操作性太低,所以没用这个方法,而是根据CNN的原理(CNN底层是训练KNN)去写的

from sklearn import datasets
import matplotlib.pyplot as pyplot
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier as KNN
import numpy as np
from collections import Counter
from numpy import where# make_classification用于手动构造数据
# 1000个样本,分成4类
X, y = datasets.make_classification(n_samples=1000, n_features=2,n_informative=2, n_redundant=0, n_repeated=0,n_classes=4, n_clusters_per_class=1, random_state=1)
counter = Counter(y)
# 画出二维散点图
for label, _ in counter.items():row_ix = where(y == label)[0]pyplot.scatter(X[row_ix, 0], X[row_ix, 1], label=str(label))
pyplot.legend()
pyplot.show()# 10%作为训练集,90%作为测试集
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.9)while True:k = 1KNN_clf = KNN(n_neighbors=k)KNN_clf.fit(x_train, y_train)y_predict = KNN_clf.predict(x_test)cond = y_predict == y_test  # cond记录分类的对与错,分类错是False,正确是True# 都分类正确,退出if  cond.all():print('所有测试集都分类正确,CNN正常结束')breakx_train = np.vstack([x_train, x_test[~cond]])  # 把分类错误(cond的值是False)的移动到训练集里y_train = np.hstack([y_train, y_test[~cond]])x_test = x_test[cond]  # 把分类对的继续作为下一轮的测试集y_test = y_test[cond]if len(x_test) == 0:print("所有样本都能做到分类错误,也就是结果集=原始数据集,一般不会出现这种情况")break# summarize the new class distribution
counter = Counter(y_train)
print(counter)# 画出二维散点图
for label, _ in counter.items():row_ix = where(y_train == label)[0]pyplot.scatter(x_train[row_ix, 0], x_train[row_ix, 1], label=str(label))
pyplot.legend()
pyplot.show()

2.1 改进版——指定压缩后样本大小的CNN

在如下代码中,用sampleNum指定全体样本数量,用endNum指定压缩后样本数量

from sklearn import datasets
import matplotlib.pyplot as pyplot
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier as KNN
import numpy as np
from collections import Counter
from numpy import wheresampleNum = 1000
endNum = 500
k = 1  # KNN算法的K值
# make_classification用于手动构造数据
# 1000个样本,分成4类
X, y = datasets.make_classification(n_samples=sampleNum, n_features=2,n_informative=2, n_redundant=0, n_repeated=0,n_classes=4, n_clusters_per_class=1, random_state=1)
# counter = Counter(y)
# # 画出二维散点图
# for label, _ in counter.items():
# 	row_ix = where(y == label)[0]
# 	pyplot.scatter(X[row_ix, 0], X[row_ix, 1], label=str(label))
# pyplot.legend()
# pyplot.show()# 10%作为训练集,90%作为测试集
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.9)
# print(x_train.shape[0])  # 100nowNum = x_train.shape[0]  # 用来控制 训练集/筛选后的样本数 满足resultNum就停下, 初始有x_train这么多个while True:KNN_clf = KNN(n_neighbors=k)KNN_clf.fit(x_train, y_train)y_predict = KNN_clf.predict(x_test)cond = y_predict == y_test  # cond记录分类的对与错,分类错是False,正确是True# 都分类正确,退出if cond.all():print('所有测试集都分类正确,CNN自动结束,但是结果集没凑够呢!')break# 如果结果集数量不够要求的endNum,继续下一轮if nowNum+y_test[~cond].shape[0] < endNum:nowNum = nowNum+y_test[~cond].shape[0]print("目前结果集数量:", nowNum)x_train = np.vstack([x_train, x_test[~cond]])  # 把分类错误(cond的值是False)的移动到训练集里y_train = np.hstack([y_train, y_test[~cond]])x_test = x_test[cond]  # 把分类对的继续作为下一轮的测试集y_test = y_test[cond]# 如果结果集数量超过endNum,我们只要测试集里分类错误的前endNum-nowNum个else:# 记录前endNum-nowNum个的位置(截取位置condCut = 0  # 记录截取位置for i in range(cond.shape[0]):if not cond[i]:nowNum = nowNum + 1if nowNum == endNum:condCut = i  # 在cond[condCut]处刚好是我们要的第endNum个结果集样本break# 把cond[condCut]后面的都设置成Truecond[condCut+1:] = Truex_train = np.vstack([x_train, x_test[~cond]])  # 把分类错误(cond的值是False)的移动到训练集里y_train = np.hstack([y_train, y_test[~cond]])print("结果集的数量为", x_train.shape[0], "满足endNum=", endNum)breakif len(x_test) == 0:print("所有样本都能做到分类错误,也就是结果集=原始数据集,一般不会出现这种情况")break# summarize the new class distribution
counter = Counter(y_train)
print(counter)# 画出二维散点图
for label, _ in counter.items():row_ix = where(y_train == label)[0]pyplot.scatter(x_train[row_ix, 0], x_train[row_ix, 1], label=str(label))
pyplot.legend()
pyplot.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/693746.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深入探索STM32的存储选项:片内RAM、片内Flash与SDRAM

博客&#xff1a;深入探索STM32的存储选项&#xff1a;片内RAM、片内Flash与SDRAM 在嵌入式系统设计中&#xff0c;存储管理是一个至关重要的方面&#xff0c;尤其是对于基于STM32这类强大的微控制器来说。STM32系列微控制器因其高性能、低功耗以及灵活的存储选项而广受欢迎。本…

RabbitMQ集群架构

1.RabbitMQ集群模式介绍 普通集群 默认的集群模式&#xff0c;比如有节点node1、node2和node3&#xff0c;三个节点是普通集群&#xff0c;但是他们仅有相同的元数据&#xff0c;即交换机、队列的结构消息只存在其中的一个节点里面&#xff0c;假如消息A存储在node1节点&#x…

Python中HTTP请求的安全性考虑与实践:安全帽下的网络舞者

在Python的HTTP请求世界里&#xff0c;安全性就像是一个必不可少的舞伴&#xff0c;时刻陪伴着你的网络舞步。想象一下&#xff0c;你正在举办一场网络舞会&#xff0c;而安全性则是那个穿着防弹舞衣&#xff0c;戴着安全帽的忠诚舞伴&#xff0c;确保你在舞池中尽情舞动而不必…

数据结构---字典树(Tire)

字典树是一种能够快速插入和查询字符串的多叉树结构&#xff0c;节点的编号各不相同&#xff0c;根节点编号为0 Trie树&#xff0c;即字典树&#xff0c;又称单词查找树或键树&#xff0c;是一种树形结构&#xff0c;是一种哈希树的变种。 核心思想也是通过空间来换取时间上的…

C#写的一个计算DCI-P3色域和SRGB的小工具

文章最后附带分享链接与提取码 方便需要测试屏幕的小伙伴&#xff0c;只需要输入RGB就能得到覆盖率与比率&#xff0c;W计算色温&#xff0c;不测也要写上&#xff0c;不然会报错 链接&#xff1a;https://pan.baidu.com/s/1wdmAwmwiXjNvn1tGsvy0HA 提取码&#xff1a;1234

安卓学习笔记之五:Android Studio_骰子案例3(Kotlin搭配 Jetpack Compose实现)

使用 Compose 创建一款交互式 Dice Roller Android 应用。 完成&#xff1a; 定义可组合函数。使用组合创建布局。使用 Button 可组合项创建按钮。导入 drawable 资源。使用 Image 可组合项显示图片。使用可组合项构建交互式界面。使用 remember 可组合项将组合中的对象存储到…

【Docker】有用的命令

文章目录 DockerDocker 镜像与容器的差异Docker的好处Hypervisor运维 一、安装docker二、启动docker三、获取docker镜像四、创建镜像使用命令行创建镜像使用dockerfile创建镜像 五、docker报错 Docker docker镜像&#xff08;Image&#xff09; docker镜像类似于虚拟机镜像&…

linux 安装anaconda踩坑——哈希值对不上

下载安装包时执行命令 curl -O https://repo.anaconda.com/archive/Anaconda3-<INSTALLER_VERSION>-Linux-x86_64.sh 其中的<INSTALLER_VERSION>需要填写下载的anaconda版本号&#xff0c;于是我就点开官网提供的版本号链接&#xff0c;将我要下载的版本号copy了一…

pom.xml常见依赖及其作用

1.org.mybatis.spring.boot下的mybatis-spring-boot-starter&#xff1a;这个依赖是mybatis和springboot的集成库&#xff0c;简化了springboot项目中使用mybatis进行持久化操作的配置和管理 2.org.projectlombok下的lombok&#xff1a;常用注解Data、NoArgsConstructor、AllA…

如何在Ubuntu部署Emlog,并将本地博客发布至公网可远程访问

文章目录 前言1. 网站搭建1.1 Emolog网页下载和安装1.2 网页测试1.3 cpolar的安装和注册 2. 本地网页发布2.1 Cpolar临时数据隧道2.2.Cpolar稳定隧道&#xff08;云端设置&#xff09;2.3.Cpolar稳定隧道&#xff08;本地设置&#xff09; 3. 公网访问测试总结 前言 博客作为使…

2.20学习总结

1.【模板】单源最短路径&#xff08;弱化版&#xff09; 2.【模板】单源最短路径&#xff08;标准版&#xff09; 3.无线通讯网 4.子串简写 5.整数删除 6.拆地毯 【模板】单源最短路径&#xff08;标准版&#xff09;https://www.luogu.com.cn/problem/P4779 题目描述 给定一个…

Excel SUMPRODUCT函数用法(乘积求和,分组排序)

SUMPRODUCT函数是Excel中功能比较强大的一个函数&#xff0c;可以实现sum,count等函数的功能&#xff0c;也可以实现一些基础函数无法直接实现的功能&#xff0c;常用来进行分类汇总&#xff0c;分组排序等 SUMPRODUCT 函数基础 SUMPRODUCT函数先计算多个数组的元素之间的乘积…

Kubernetes安装nginx-controller作为统一网关

nginx-controller是什么呢? 它是一个能调度nginx的一个kubernetes operator,它能监听用户创建,更新,删除NginxConf对象,来调度本地的nginx实现配置的动态更新。如添加新的代理(http,https,tcp,udp),缓存(浏览器缓存,本地缓存),ssl证书(配置本身,ConfigMap,Secret),更新,删除等…

c语言结构体与共用体

前面我们介绍了基本的数据类型 在c语言中 有一种特殊的数据类型 由程序员来定义类型 目录 一结构体 1.1概述 1.2定义结构体 1.3 结构体变量的初始化 1.4 访问结构体的成员 1.5结构体作为函数的参数 1.6指向结构的指针 1.7结构体大小的计算 二共用体 2.1概述 2.2 访…

04 Aras Innovator二次开发-客户端方法

客户端方法为JS方法。 系统提供了很多触发点&#xff0c;可以嵌入客户端方法&#xff0c;如下&#xff1a; 1 对象类的客户端事件页签&#xff1a; 2 窗体的Form Event和Filed Event 3.关系类的网格事件&#xff1a; 4 属性事件&#xff1a; 5.可自定义Action,触发客户端事件…

数据结构与算法:栈

朋友们大家好啊&#xff0c;在链表的讲解过后&#xff0c;我们本节内容来介绍一个特殊的线性表&#xff1a;栈&#xff0c;在讲解后也会以例题来加深对本节内容的理解 栈 栈的介绍栈进出栈的变化形式 栈的顺序存储结构的有关操作栈的结构定义与初始化压栈操作出栈操作获取栈顶元…

VR全景开启线上卖房新渠道,助力房企改变营销方式

当下房产行业&#xff0c;还在依靠传统线下发传单、跑客户、做地推吗&#xff1f;在短视频和直播火热的今天&#xff0c;房产行业也开启了线上卖房的新渠道&#xff0c;通过VR全景技术&#xff0c;可以为各个小区的线上宣传增加趣味性和互动性。 一、VR全景漫游可以彰显房源真实…

如何更换过期的SSL证书?

SSL证书是保护网站安全的重要组成部分&#xff0c;它能在客户端和服务器之间建立数据传输加密通道&#xff0c;防止数据在传输过程中被泄露、劫持和窃听。但SSL证书也有有效期限&#xff0c;当SSL证书到期时&#xff0c;您需要及时更换它&#xff0c;以确保网站的安全性和可信度…

Git基本操作(2)

Git基本操作&#xff08;2&#xff09; 上交文件之后&#xff0c;git文件的变化git cat-file HEAD指针里面有啥文件被修改git statusgit diff 文件名 版本回退&#xff08;git reset&#xff09;撤销回退git reflog 撤销的三种情况还没有addgit checkout -- [file] 已经add还没…

Pandas快问快答16-30题

16. 如何对一个Pandas数据框进行聚合操作? 聚合操作是数据处理中的一种重要方式&#xff0c;主要用于对一组数据进行汇总和计算&#xff0c;以得到单一的结果。在聚合操作中&#xff0c;可以执行诸如求和、平均值、最大值、最小值、计数等统计操作。这些操作通常用于从大量数…