基于word2vec 和 fast-pytorch-kmeans 的文本聚类实现,利用GPU加速提高聚类速度

文章目录

    • 简介
      • GPU加速
    • 代码实现
    • kmeans
    • 聚类结果
    • kmeans 绘图函数
    • 相关资料参考

简介

本文使用text2vec模型,把文本转成向量。使用text2vec提供的训练好的模型权重进行文本编码,不重新训练word2vec模型。

直接用训练好的模型权重,方便又快捷

完整可运行代码如下:
https://github.com/JieShenAI/csdn/blob/main/machine_learning/kmeans_pytorch.ipynb

GPU加速

传统sklearn的TF-IDF文本转向量,在CPU上计算速度较慢。使用text2vec通过cuda加速,加快文本转向量的速度。
传统使用sklearn的kmeans聚类算法在CPU上计算,如遇到大批量的数据,计算耗时太长。
故本文使用kmeans_pytorch包,基于pytorch在GPU上计算,提高聚类速度。

代码实现

装包

pip install fast-pytorch-kmeans text2vec
import torch
import numpy as np

from text2vec import SentenceModel

不使用SentenceModel模型也可以,在 text2vec 中,还有很多其他的向量编码模型供选择。

文本编码模型

embedder = SentenceModel()

异常情况说明,该模型需要从huggingface下载模型权重,目前被墙了。(请想办法解决,或者尝试其他的编码模型)
在这里插入图片描述

语料库如下:

# Corpus with example sentences
corpus = [
    '花呗更改绑定银行卡',
    '我什么时候开通了花呗',
    'A man is eating food.',
    'A man is eating a piece of bread.',
    'The girl is carrying a baby.',
    'A man is riding a horse.',
    'A woman is playing violin.',
    'Two men pushed carts through the woods.',
    'A man is riding a white horse on an enclosed ground.',
]
corpus_embeddings = embedder.encode(corpus)
# numpy 转成 pytorch, 并转移到GPU显存中
corpus_embeddings = torch.from_numpy(corpus_embeddings).to('cuda')

如下图所示,编码的向量是768纬;

type(corpus_embeddings), corpus_embeddings.shape

在这里插入图片描述

kmeans

kmeans_pytorch vs fast-pytorch-kmeans:
在实验过程中,利用kmeans_pytorch 针对30万个词进行聚类的时候,发现显存炸了,程序崩溃退出。30万个词的词向量,占用显存还不到2G,但是运行kmeans_pytorch后,显存就炸了。

fast-pytorch-kmeans不存在上述显存崩溃的问题。本以为词向量很多会跑很长时间,但fast-pytorch-kmeans在非常短的时间内就完成了kmeans聚类。

# kmeans
# from kmeans_pytorch import kmeans
from fast_pytorch_kmeans import KMeans

num_class = 3 # 分类类别数
kmeans = KMeans(n_clusters=num_class, mode='euclidean', verbose=1)

# 模型预测结果
labels = kmeans.fit_predict(corpus_embeddings)

聚类程序运行如下:

used 2 iterations (0.3682s) to cluster 9 items into 3 clusters

模型中心点坐标:

kmeans.centroids

在这里插入图片描述

聚类结果

class_data = {
    i:[]
    for i in range(3)
}

for text,cls in zip(corpus, labels):
    class_data[cls.item()].append(text)

class_data

文本聚类结果如下:
0: 女
1:男
2: 花呗
在这里插入图片描述

kmeans 绘图函数

封装了KMeansPlot 绘图类,方便聚类结果可视化

from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

class KMeansPlot:

    def __init__(self, numClass=4, func_type='PCA'):
        if func_type == 'PCA':
            self.func_plot = PCA(n_components=2)
        elif func_type == 'TSNE':
            from sklearn.manifold import TSNE
            self.func_plot = TSNE(2)
        self.numClass = numClass

    def plot_cluster(self, result, pos, cluster_centers=None):
        plt.figure(2)
        Lab = [[] for i in range(self.numClass)]
        index = 0
        for labi in result:
            Lab[labi].append(index)
            index += 1
        color = ['oy', 'ob', 'og', 'cs', 'ms', 'bs', 'ks', 'ys', 'yv', 'mv', 'bv', 'kv', 'gv', 'y^', 'm^', 'b^', 'k^',
                    'g^'] * 3

        for i in range(self.numClass):
            x1 = []
            y1 = []
            for ind1 in pos[Lab[i]]:
                # print ind1
                try:
                    y1.append(ind1[1])
                    x1.append(ind1[0])
                except:
                    pass
            plt.plot(x1, y1, color[i])

        if cluster_centers is not None:
            #绘制初始中心点
            x1 = []
            y1 = []

            for ind1 in cluster_centers:
                try:
                    y1.append(ind1[1])
                    x1.append(ind1[0])
                except:
                    pass

            plt.plot(x1, y1, "rv") #绘制中心

        plt.show()

    def plot(self, weight, label, cluster_centers=None):
        pos = self.func_plot.fit_transform(weight)
        # 高纬的中心点坐标,也经过降纬处理
        cluster_centers = self.func_plot.fit_transform(cluster_centers)
        self.plot_cluster(list(label), pos, cluster_centers)

kmeans.centroids :是一个高纬空间的中心点坐标,故在plot函数中,将其降纬到2D平面上;

k_plot = KMeansPlot(num_class)
k_plot.plot(
    corpus_embeddings.to('cpu'),
    labels.to('cpu'),
    kmeans.centroids.to('cpu')
)

在这里插入图片描述

完整可运行代码如下:
https://github.com/JieShenAI/csdn/blob/main/machine_learning/kmeans_pytorch.ipynb

相关资料参考

  • 动手实战基于 ML 的中文短文本聚类
  • tfidf和word2vec构建文本词向量并做文本聚类
    提到训练word2vec模型,silhouette_score_show(word2vec, 'word2vec') 轮廓系数,判断分几个类别最好。
  • 机器学习:Kmeans聚类算法总结及GPU配置加速demo
    PyTorch kmeans 加速。from scratch 实现;
  • KMeans算法全面解析与应用案例 通俗易懂的原理讲解
  • pytorch K-means算法的实现 底层代码实现
  • 【pytorch】Kmeans_pytorch用于一般聚类任务的代码模板 使用pytorch封装的kmeans包实现,包括训练和预测;
  • text2vec 包

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/457863.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

19C 19.22 RAC 2节点一键安装演示

Oracle 一键安装脚本,演示 2 节点 RAC 一键安装过程(全程无需人工干预):(脚本包括 GRID/ORALCE PSU/OJVM 补丁自动安装) ⭐️ 脚本下载地址:Shell脚本安装Oracle数据库 脚本第三代支持 N 节点…

CompletableFuture原理与实践-外卖商家端API的异步化

背景 随着订单量的持续上升,美团外卖各系统服务面临的压力也越来越大。作为外卖链路的核心环节,商家端提供了商家接单、配送等一系列核心功能,业务对系统吞吐量的要求也越来越高。而商家端API服务是流量入口,所有商家端流量都会由…

畅捷通T+ Ufida.T.DI.UIP.RRA.RRATableController 反序列化RCE漏洞复现

0x01 产品简介 畅捷通 T+ 是一款灵动,智慧,时尚的基于互联网时代开发的管理软件,主要针对中小型工贸与商贸企业,尤其适合有异地多组织机构(多工厂,多仓库,多办事处,多经销商)的企业,涵盖了财务,业务,生产等领域的应用,产品应用功能包括:采购管理、库存管理、销售…

Python基于大数据的豆瓣电影分析,豆瓣电影可视化系统,附源码

博主介绍:✌程序员徐师兄、7年大厂程序员经历。全网粉丝12W、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇…

linux 安装gradle7.4.2环境

1.下载gradle7.4.2工程 百度网盘 请输入提取码百度网盘为您提供文件的网络备份、同步和分享服务。空间大、速度快、安全稳固,支持教育网加速,支持手机端。注册使用百度网盘即可享受免费存储空间https://pan.baidu.com/s/1hoNEFkBJPHAgs9ITAEh3Zg?pwdGJ…

活动图高阶讲解-03

1 00:00:00,000 --> 00:00:06,260 刚才我们讲了活动图的历史 2 00:00:06,260 --> 00:00:11,460 那我们来看这个活动图 3 00:00:11,460 --> 00:00:15,260 如果用来建模的话怎么用 4 00:00:15,260 --> 00:00:20,100 按照我们前面讲的软件方法的工作流 5 00:00:20…

mysql的语法总结3

查询表 精确查找 举例 去除重复行 假设您有一个名为 students 的表,其中包含 name 和 age 两列,您想要查询所有不重复的年龄,可以使用以下查询: 详细匹配 查询emp表中在部门10工作、工资高于1000或岗位是CLERK的所有雇员的姓名、…

C++ 优先级队列(大小根堆)OJ

目录 1、 1046. 最后一块石头的重量 2、 703. 数据流中的第 K 大元素 为什么小根堆可以解决TopK问题? 3、 692. 前K个高频单词 4、 295. 数据流的中位数 1、 1046. 最后一块石头的重量 思路:根据示例发现可以用大根堆(降序)模拟这个过程。 class So…

【Jenkins】data stream error|Error cloning remote repo ‘origin‘ 错误解决【亲测有效】

错误构建日志 17:39:09 ERROR: Error cloning remote repo origin 17:39:09 hudson.plugins.git.GitException: Command "git fetch --tags --progress http://domain/xxx.git refs/heads/*:refs/remotes/origin/*" returned status code 128: 17:39:09 stdout: 17…

unity报错出现Asset database transaction committed twice!

错误描述: 运行时报错 Assertion failed on expression: ‘m_ErrorCode MDB_MAP_RESIZED || !HasAbortingErrors()’Asset database transaction committed twice!Assertion failed on expression: ‘errors MDB_SUCCESS || errors MDB_NOTFOUND’ 解决办法&…

基于springboot实现驾校信息管理系统项目【项目源码+论文说明】计算机毕业设计

基于springboot实现驾校信息管理系统演示 摘要 随着人们生活水平的不断提高,出行方式多样化,也以私家车为主,那么既然私家车的需求不断增长,那么基于驾校的考核管理也就不断增强,那么业务系统也就慢慢的随之加大。信息…

mac删除带锁标识的app

一 、我们这里要删除FortiClient.app 带锁 常规方式删除不掉带锁的 app【如下图】 二、删除命令,依次执行即可。 /bin/ls -dleO /Applications/FortiClient.app sudo /usr/bin/chflags -R noschg /Applications/FortiClient.app /bin/ls -dleO /Applications/Forti…

C语言【典型算法编程题】总结

以下最全总结! 一,分支结构 1,if 编写程序,从键盘上输入三角形的三个边长(实数),判断这三个边能否构成三角形(构成三角形的条件为:任意两边之和大于第三边),如果能构成三角形,则计算三角形的面积并输出(保留2位小数);如果不能构成三角形,则输出“Flase”字符…

idea如何使用,从激活开始

idea到期后激活使用 如何使用 点击阅读 idea分享

【LinuxC】C语言线程(pthread)

文章目录 一、 POSIX 线程库1.1 POSIX标准1.2 Pthreads1.2 数据类型、函数、宏1.21 数据类型1.22 函数1.23 宏 二、创建线程三、线程同步四、线程销毁五、示例5.1 完整示例5.2 信号量示例 本专栏上一篇文章是Windows下(MSVC)的线程编程,需要的…

ELK日志管理实现的3种常见方法

ELK日志管理实现的3种常见方法 1. 日志收集方法 1.1 使用DaemonSet方式日志收集 通过将node节点的/var/log/pods目录挂载给以DaemonSet方式部署的logstash来读取容器日志,并将日志吐给kafka并分布写入Zookeeper数据库.再使用logstash将Zookeeper中的数据写入ES,并通过kibana…

如何利用百度SEO优化技巧将排到首页

拥有一个成功的网站对于企业和个人来说是至关重要的,在当今数字化的时代。在互联网上获得高流量和优质的访问者可能并不是一件容易的事情,然而。一个成功的SEO战略可以帮助你实现这一目标。需要一些特定的技巧和策略、但要在百度搜索引擎中获得较高排名。…

某狗网翻译接口逆向之webpack扣取

​​​​​逆向网址 aHR0cHM6Ly9mYW55aS5zb2dvdS5jb20 逆向链接 aHR0cHM6Ly9mYW55aS5zb2dvdS5jb20vdGV4dA 逆向接口 aHR0cHM6Ly9mYW55aS5zb2dvdS5jb20vYXBpL3RyYW5zcGMvdGV4dC9yZXN1bHQ 逆向过程 请求方式:POST 参数构成: 【s】 1b921dbefaa8d939afca…

了解常用测试模型 -- 瀑布模型、螺旋模型、增量与迭代、敏捷开发

目录 瀑布模型 开发流程 开发特征 优缺点 适用场景 螺旋模型 开发流程 开发特征 优缺点 适用场景 增量与迭代开发 什么是增量开发?什么是迭代开发? 敏捷开发 什么是敏捷开发四原则(敏捷宣言)? 什么是 s…

Pytorch从零开始实战21

Pytorch从零开始实战——Pix2Pix理论与实战 本系列来源于365天深度学习训练营 原作者K同学 文章目录 Pytorch从零开始实战——Pix2Pix理论与实战内容介绍数据集加载模型实现开始训练总结 内容介绍 Pix2Pix是一种用于用于图像翻译的通用框架,即图像到图像的转换。…