基于word2vec 和 fast-pytorch-kmeans 的文本聚类实现,利用GPU加速提高聚类速度

文章目录

    • 简介
      • GPU加速
    • 代码实现
    • kmeans
    • 聚类结果
    • kmeans 绘图函数
    • 相关资料参考

简介

本文使用text2vec模型,把文本转成向量。使用text2vec提供的训练好的模型权重进行文本编码,不重新训练word2vec模型。

直接用训练好的模型权重,方便又快捷

完整可运行代码如下:
https://github.com/JieShenAI/csdn/blob/main/machine_learning/kmeans_pytorch.ipynb

GPU加速

传统sklearn的TF-IDF文本转向量,在CPU上计算速度较慢。使用text2vec通过cuda加速,加快文本转向量的速度。
传统使用sklearn的kmeans聚类算法在CPU上计算,如遇到大批量的数据,计算耗时太长。
故本文使用kmeans_pytorch包,基于pytorch在GPU上计算,提高聚类速度。

代码实现

装包

pip install fast-pytorch-kmeans text2vec
import torch
import numpy as npfrom text2vec import SentenceModel

不使用SentenceModel模型也可以,在 text2vec 中,还有很多其他的向量编码模型供选择。

文本编码模型

embedder = SentenceModel()

异常情况说明,该模型需要从huggingface下载模型权重,目前被墙了。(请想办法解决,或者尝试其他的编码模型)
在这里插入图片描述

语料库如下:

# Corpus with example sentences
corpus = ['花呗更改绑定银行卡','我什么时候开通了花呗','A man is eating food.','A man is eating a piece of bread.','The girl is carrying a baby.','A man is riding a horse.','A woman is playing violin.','Two men pushed carts through the woods.','A man is riding a white horse on an enclosed ground.',
]
corpus_embeddings = embedder.encode(corpus)
# numpy 转成 pytorch, 并转移到GPU显存中
corpus_embeddings = torch.from_numpy(corpus_embeddings).to('cuda')

如下图所示,编码的向量是768纬;

type(corpus_embeddings), corpus_embeddings.shape

在这里插入图片描述

kmeans

kmeans_pytorch vs fast-pytorch-kmeans:
在实验过程中,利用kmeans_pytorch 针对30万个词进行聚类的时候,发现显存炸了,程序崩溃退出。30万个词的词向量,占用显存还不到2G,但是运行kmeans_pytorch后,显存就炸了。

fast-pytorch-kmeans不存在上述显存崩溃的问题。本以为词向量很多会跑很长时间,但fast-pytorch-kmeans在非常短的时间内就完成了kmeans聚类。

# kmeans
# from kmeans_pytorch import kmeans
from fast_pytorch_kmeans import KMeansnum_class = 3 # 分类类别数
kmeans = KMeans(n_clusters=num_class, mode='euclidean', verbose=1)# 模型预测结果
labels = kmeans.fit_predict(corpus_embeddings)

聚类程序运行如下:

used 2 iterations (0.3682s) to cluster 9 items into 3 clusters

模型中心点坐标:

kmeans.centroids

在这里插入图片描述

聚类结果

class_data = {i:[]for i in range(3)
}for text,cls in zip(corpus, labels):class_data[cls.item()].append(text)class_data

文本聚类结果如下:
0: 女
1:男
2: 花呗
在这里插入图片描述

kmeans 绘图函数

封装了KMeansPlot 绘图类,方便聚类结果可视化

from sklearn.decomposition import PCA
import matplotlib.pyplot as pltclass KMeansPlot:def __init__(self, numClass=4, func_type='PCA'):if func_type == 'PCA':self.func_plot = PCA(n_components=2)elif func_type == 'TSNE':from sklearn.manifold import TSNEself.func_plot = TSNE(2)self.numClass = numClassdef plot_cluster(self, result, pos, cluster_centers=None):plt.figure(2)Lab = [[] for i in range(self.numClass)]index = 0for labi in result:Lab[labi].append(index)index += 1color = ['oy', 'ob', 'og', 'cs', 'ms', 'bs', 'ks', 'ys', 'yv', 'mv', 'bv', 'kv', 'gv', 'y^', 'm^', 'b^', 'k^','g^'] * 3for i in range(self.numClass):x1 = []y1 = []for ind1 in pos[Lab[i]]:# print ind1try:y1.append(ind1[1])x1.append(ind1[0])except:passplt.plot(x1, y1, color[i])if cluster_centers is not None:#绘制初始中心点x1 = []y1 = []for ind1 in cluster_centers:try:y1.append(ind1[1])x1.append(ind1[0])except:passplt.plot(x1, y1, "rv") #绘制中心plt.show()def plot(self, weight, label, cluster_centers=None):pos = self.func_plot.fit_transform(weight)# 高纬的中心点坐标,也经过降纬处理cluster_centers = self.func_plot.fit_transform(cluster_centers)self.plot_cluster(list(label), pos, cluster_centers)

kmeans.centroids :是一个高纬空间的中心点坐标,故在plot函数中,将其降纬到2D平面上;

k_plot = KMeansPlot(num_class)
k_plot.plot(corpus_embeddings.to('cpu'),labels.to('cpu'),kmeans.centroids.to('cpu')
)

在这里插入图片描述

完整可运行代码如下:
https://github.com/JieShenAI/csdn/blob/main/machine_learning/kmeans_pytorch.ipynb

相关资料参考

  • 动手实战基于 ML 的中文短文本聚类
  • tfidf和word2vec构建文本词向量并做文本聚类
    提到训练word2vec模型,silhouette_score_show(word2vec, 'word2vec') 轮廓系数,判断分几个类别最好。
  • 机器学习:Kmeans聚类算法总结及GPU配置加速demo
    PyTorch kmeans 加速。from scratch 实现;
  • KMeans算法全面解析与应用案例 通俗易懂的原理讲解
  • pytorch K-means算法的实现 底层代码实现
  • 【pytorch】Kmeans_pytorch用于一般聚类任务的代码模板 使用pytorch封装的kmeans包实现,包括训练和预测;
  • text2vec 包

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/745710.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

软考高级:遗留系统演化策略(集成、淘汰、改造、继承)概念和例题

作者:明明如月学长, CSDN 博客专家,大厂高级 Java 工程师,《性能优化方法论》作者、《解锁大厂思维:剖析《阿里巴巴Java开发手册》》、《再学经典:《Effective Java》独家解析》专栏作者。 热门文章推荐&am…

【刷题训练】Leetcode415.字符串相加

字符串相加 题目要求 示例 1: 输入:num1 “11”, num2 “123” 输出:“134” 示例 2: 输入:num1 “456”, num2 “77” 输出:“533” 示例 3: 输入:num1 “0”, num2 “0”…

【计算机视觉】一、计算机视觉概述

文章目录 一、计算机视觉二、计算机视觉与其它学科领域的关系1、图像处理2、计算机图形学3、模式识别4、人工智能(AI)5、神经生理学与认知科学 三、计算机视觉的应用1. 人脸识别2. 目标检测3. 图像生成4. 城市建模5. 电影特效6. 体感游戏动作捕捉7. 虚拟…

java数据结构与算法刷题-----LeetCode47. 全排列 II

java数据结构与算法刷题目录(剑指Offer、LeetCode、ACM)-----主目录-----持续更新(进不去说明我没写完):https://blog.csdn.net/grd_java/article/details/123063846 文章目录 1. 暴力回溯2. 分区法回溯 此题为46题的衍生题,在46题…

PHP极简网盘系统源码 轻量级文件管理与共享系统网站源码

PHP极简网盘系统源码 轻量级文件管理与共享系统网站源码 极简网盘是一个轻量级文件管理与共享系统,支持多用户,可充当网盘程序,程序无需数据库 安装步骤: 1.建议安装在apache环境下,并确保.htaccess可用 2.解压文件…

PHP序列化基础知识储备

一、序列化与反序列化 1、概念 PHP中的序列化是指将复杂的数据类型转换为可存储或可传输的字符串,而反序列化则是将这些字符串重新转换回原来的数据类型。 序列化通常使用 serialize() 函数完成,它可以将数组、对象、字符串等复杂数据类型压缩到一个字…

Infineon_TC264智能车代码初探及C语言深度学习(二)

本篇文章记录我在智能车竞赛中,对 Infineon_TC264 这款芯片的底层库函数的学习分析。通过深入地对其库函数进行分析,C语言深入的知识得以再次在编程中呈现和运用。故觉得很有必要在此进行记录分享一下。 目录 ​编辑 一、代码段分析 NO.1 指向结构体…

CSDN 编辑器设置图片缩放和居中

CSDN 编辑器设置图片缩放和居中 文章目录 CSDN 编辑器设置图片缩放和居中对齐方式比例缩放 对齐方式 Markdown 编辑器插入图片的代码格式为 ![图片描述](图片路径)CSDN 的 Markdown 编辑器中插入图片,默认都是左对齐,需要设置居中对齐的话,…

QTextToSpeech的使用——Qt

前言 之前随便看了几眼QTextToSpeech的帮助就封装使用了,达到了效果就没再管了,最近需要在上面加功能(变换语速),就写了个小Demo后,发现不对劲了。 出现的问题 场景 写了个队列添加到语音播放子线程中&a…

HTTPS基础

目录 HTTPS简介 HTTP与HTTPS的区别 CA证书 案例 服务器生成私钥与证书 查看证书和私钥存放路径 Cockpit(图像化服务管理工具) HTTPS简介 超文本传输协议HTTP协议被用于在Web浏览器和网站服务器之间传递信息。HTTP协议以明文方式发送内容,不提供任何方式的数据加密&…

C++——类和对象(1)

1. 面向对象和面向过程对比 当涉及到编程范式时,两个主要的方法是面向对象编程(Object-Oriented Programming,OOP)和面向过程编程(Procedural Programming)。这两种编程范式在解决问题和组织代码时有着不同…

COX回归影响因素分析的基本过程与方法

在科学研究中,经常遇到分类的结局,主要是二分类结局(阴性/阳性;生存/死亡),研究者可以通过logistic回归来探讨影响结局的因素,但很多时候logistic回归方法无法使用。如比较两种手段治疗新冠肺炎…

Annaconda环境下ChromeDriver配置及爬虫编写

Anaconda环境的chromedriver安装配置_anaconda 配置chromedriver-CSDN博客 Chromedriver驱动( 121.0.6167.85 ) - 知乎 下载好的驱动文件解压,将exe程序复制到Annaconda/Scripts目录以及Chrome/Application目录下 注意要提前pip install selenium包才能运行成功&a…

BEV系列一:BEV介绍和常用BEV算法简介

BEV系列一:BEV介绍和常用BEV算法简介 自动驾驶最全学习资料获取:链接

Linux操作系统——线程概念

1.什么是线程? 在一个程序里的一个执行路线就叫做线程(thread)。更准确的定义是:线程是“一个进程内部的控制序列”一切进程至少都有一个执行线程线程在进程内部运行,本质是在进程地址空间内运行在Linux系统中&#x…

openGauss学习笔记-242 openGauss性能调优-SQL调优-典型SQL调优点-SQL自诊断

文章目录 openGauss学习笔记-242 openGauss性能调优-SQL调优-典型SQL调优点-SQL自诊断242.1 SQL自诊断242.1.1 告警场景242.1.2 规格约束 openGauss学习笔记-242 openGauss性能调优-SQL调优-典型SQL调优点-SQL自诊断 SQL调优是一个不断分析与尝试的过程:试跑Query&…

【Qt】常用控件或属性(1)

需要云服务器等云产品来学习Linux可以移步/-->腾讯云<--/官网&#xff0c;轻量型云服务器低至112元/年&#xff0c;新用户首次下单享超低折扣。 目录 一、QWidget属性一览 二、控件button、属性enabled(可用状态) 三、属性geometry(修改位置和尺寸) 1、QRect类型的结…

微信小程序之tabBar

1、tabBar 如果小程序是一个多 tab 应用&#xff08;客户端窗口的底部或顶部有 tab 栏可以切换页面&#xff09;&#xff0c;可以通过 tabBar 配置项指定 tab 栏的表现&#xff0c;以及 tab 切换时显示的对应页面。 属性类型必填默认值描述colorHexColor是tab 上的文字默认颜色…

Leetcode 3.14

Leetcode hot100 二叉树1.二叉树的层序遍历2.验证二叉搜索树3.二叉树的右视图 二叉树 1.二叉树的层序遍历 二叉树的层序遍历 二叉树的层序遍历可以用先进先出的队列来实现。 将每一层的所有node都添加到队列中&#xff0c;记录下当前队列的长度&#xff0c;即该层的元素数量&…

『 Linux 』进程替换( Process replacement ) 及 简单Shell的实现(万字)

文章目录 &#x1f984; 进程替换&#x1f9a9; execl()函数&#x1f9a9; execlp()函数&#x1f9a9; execle()函数&#x1f9a9; execv()函数&#x1f9a9; execvp()函数&#x1f9a9; execvpe()函数&#x1f9a9; execve()函数 &#x1f984; 简单Shell命令行解释器的实现&a…