bert 相似度任务训练简单版本,faiss 寻找相似 topk

目录

任务

代码

train.py

predit.py

faiss 最相似的 topk


任务

使用 bert-base-chinese 训练相似度任务,参考:微调BERT模型实现相似性判断 - 知乎

参考他上面代码,他使用的是 BertForNextSentencePrediction 模型,BertForNextSentencePrediction 原本是设计用于下一个句子预测任务的。在BERT的原始训练中,模型会接收到一对句子,并试图预测第二个句子是否紧跟在第一个句子之后;所以使用这个模型标签(label)只能是 0,1,相当于二分类任务了

但其实在相似度任务中,我们每一条数据都是【text1\ttext2\tlabel】的形式,其中 label 代表相似度,可以给两个文本打分表示相似度,也可以映射为分类任务,0 代表不相似,1 代表相似,他这篇文章利用了这种思想,对新手还挺有用的。

现在我搞了一个招聘数据,里面有办公区域列,处理过了,每一行代表【地址1\t地址2\t相似度】

只要两文本中有一个地址相似我就作为相似,标签为 1,否则 0

利用这数据微调,没有使用验证数据集,就最后使用测试集来看看效果。

代码

train.py

import json
import torch
from transformers import BertTokenizer, BertForNextSentencePrediction
from torch.utils.data import DataLoader, Dataset


# 能用gpu就用gpu
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

bacth_size = 32
epoch = 3
auto_save_batch = 5000
learning_rate = 2e-5


# 准备数据集
class MyDataset(Dataset):
    def __init__(self, data_file_paths):
        self.texts = []
        self.labels = []
        # 分词器用默认的
        self.tokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')
        # 自己实现对数据集的解析
        with open(data_file_paths, 'r', encoding='utf-8') as f:
            for line in f:
                text1, text2, label = line.split('\t')
                self.texts.append((text1, text2))
                self.labels.append(int(label))

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text1, text2 = self.texts[idx]
        label = self.labels[idx]
        encoded_text = self.tokenizer(text1, text2, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
        return encoded_text, label


# 训练数据文件路径
train_dataset = MyDataset('../data/train.txt')

# 定义模型
# num_labels=5 定义相似度评分有几个
model = BertForNextSentencePrediction.from_pretrained('../bert-base-chinese', num_labels=6)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# 训练模型
train_loader = DataLoader(train_dataset, batch_size=bacth_size, shuffle=True)
trained_data = 0
batch_after_last_save = 0
total_batch = 0
total_epoch = 0

for epoch in range(epoch):
    trained_data = 0
    for batch in train_loader:
        inputs, labels = batch
        # 不知道为啥,出来的数据维度是 (batch_size, 1, 128),需要把第二维去掉
        inputs['input_ids'] = inputs['input_ids'].squeeze(1)
        inputs['token_type_ids'] = inputs['token_type_ids'].squeeze(1)
        inputs['attention_mask'] = inputs['attention_mask'].squeeze(1)
        # 因为要用GPU,将数据传输到gpu上
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(**inputs, labels=labels)
        loss, logits = outputs[:2]
        loss.backward()
        optimizer.step()
        trained_data += len(labels)
        trained_process = float(trained_data) / len(train_dataset)
        batch_after_last_save += 1
        total_batch += 1
        # 每训练 auto_save_batch 个 batch,保存一次模型
        if batch_after_last_save >= auto_save_batch:
            batch_after_last_save = 0
            model.save_pretrained(f'../output/cn_equal_model_{total_epoch}_{total_batch}.pth')
            print("保存模型:cn_equal_model_{}_{}.pth".format(total_epoch, total_batch))
        print("训练进度:{:.2f}%, loss={:.4f}".format(trained_process * 100, loss.item()))
    total_epoch += 1
    model.save_pretrained(f'../output/cn_equal_model_{total_epoch}_{total_batch}.pth')
    print("保存模型:cn_equal_model_{}_{}.pth".format(total_epoch, total_batch))

训练好后的文件,输出的最后一个文件夹才是效果最好的模型:

predit.py

import torch
from transformers import BertTokenizer, BertForNextSentencePrediction


tokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')
model = BertForNextSentencePrediction.from_pretrained('../output/cn_equal_model_3_171.pth')

with torch.no_grad():
    with open('../data/test.txt', 'r', encoding='utf8') as f:
        lines = f.readlines()
        correct = 0
        for i, line in enumerate(lines):
            text1, text2, label = line.split('\t')
            encoded_text = tokenizer(text1, text2, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
            outputs = model(**encoded_text)
            res = torch.argmax(outputs.logits, dim=1).item()
            print(text1, text2, label, res)
            if str(res) == label.strip('\n'):
                correct += 1
            print(f'{i + 1}/{len(lines)}')
        print(f'acc:{correct / len(lines)}')

可以看到还是较好的学习了我数据特征:只要两文本中有一个地址相似我就作为相似,标签为 1,否则 0

faiss 最相似的 topk

使用 faiss 寻找 topk 相似的,从结果上看最相似的基本都还是找到排到较为靠前的位置

import torch
import faiss
import pandas as pd
import numpy as np
from transformers import BertTokenizer, BertModel


# 假设有一个数据集df,其中包含'index'列和'text'列
df = pd.read_csv('../data/DataAnalyst.csv', encoding='gbk')  # 根据实际情况加载数据集
df = df.dropna().drop_duplicates().reset_index()
df['index'] = df.index
df = df[['index', '公司所在商区']]  # 保留所需列
df['公司所在商区'] = df['公司所在商区'].map(lambda row: ','.join(eval(row)))

# device = torch.device('gpu' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')

# 加载微调好的模型和tokenizer
tokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')
model = BertModel.from_pretrained('../output/cn_equal_model_3_171.pth')
model.eval()


# 将数据集转化为模型所需的格式并计算所有样本的向量表示
def encode_texts(df):
    text_vectors = []
    for index, row in df.iterrows():
        text = row['公司所在商区']
        inputs = tokenizer(text, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
        with torch.no_grad():
            embeddings = model(**inputs.to(device))['last_hidden_state'][:, 0]
            text_vectors.append(embeddings.cpu().numpy())
        print(f'{index + 1}/{len(df)}')
    return np.vstack(text_vectors)

# 加载数据集并计算所有样本的向量
print('enbedding all data...')
all_embeddings = encode_texts(df)

# 初始化Faiss索引
print('init faiss all embedding...')
index = faiss.IndexFlatIP(all_embeddings.shape[1])  # 使用内积空间,适用于余弦相似度
index.add(all_embeddings)
print('init faiss all embedding finish~~~')


# 定义查找最相似样本的函数
def find_top_k_similar(query_text, k=100):
    print('当前 query_text embedding.')
    query_embedding = encode_single_text(query_text)
    print('begin to search topk....')
    D, I = index.search(query_embedding, k)  # 返回距离和索引
    top_k_indices = df.iloc[I[0]].index.tolist()  # 将索引转换为原始数据集的索引
    return top_k_indices


# 编码单个文本的函数
def encode_single_text(text):
    inputs = tokenizer(text, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
    with torch.no_grad():
        embedding = model(**inputs.to(device))['last_hidden_state'][:, 0].cpu().numpy()
    print('当前 query_text embedding finish!')
    return embedding


# 示例:找一个query_text的top10相似样本
query_text = "左家庄,国展,西坝河"
top10_indices = find_top_k_similar(query_text)
# 获取与查询文本最相似的前10条原始文本
top10_texts = [df.loc[index, '公司所在商区'] for index in top10_indices]

print(f"与'{query_text}'最相似的前100条样本及其文本:")
for i, (idx, text) in enumerate(zip(top10_indices, top10_texts)):
    print(f"{i+1}. 索引:{idx},文本:{text}")

数据

链接:https://pan.baidu.com/s/1Cpr-ZD9Neakt73naGdsVTw 
提取码:eryw 
链接:https://pan.baidu.com/s/1qHYjXC7UCeUsXVnYTQIPCg 
提取码:o8py 
链接:https://pan.baidu.com/s/1CTntG1Z6AIhiPt6i8Ad97Q 
提取码:x6sz 
 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/429782.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

固定资产管理系统包括哪些

固定资产管理是企业经营过程中一项非常重要的任务。它涉及到公司的核心资产,包括土地、建筑物、设备、车辆等。为了有效地管理这些资产,许多企业选择使用固定资产管理系统。那么,固定资产管理系统的内容是什么呢?本文将为您进行全…

O2OA(翱途)通过服务来调用接口实现单点登录案例

本文介绍O2OA服务管理中,接口的权限设定和调用方式。 创建接口 具有服务管理设计权限的用户(具有ServiceManager角色或Manager角色)打开“服务管理平台”,进入接口配置视图,点击左上角的新建按钮,可创建一…

webpack基础配置及使用

webpack是什么 是一个现代 JavaScript 应用程序的静态模块打包器。当webpack 处理应用程序时,它会递归地构建一个依赖关系图 ,其中包含应用程序需要的每个模块,然后将所有这些模块打包成一个或多个 bundle 。主要有 五个核心概念&#xff1a…

11. Nginx进阶-HTTPS

简介 基本概述 SSL SSL是安全套接层。 主要用于认证用户和服务器,确保数据发送到正确的客户机和服务器上。 SSL可以加密数据,防止数据中途被窃取。 SSL也可以维护数据的完整性,确保数据在传输过程中不被改变。 HTTPS HTTPS就是基于SSL来…

vue中使用echarts实现人体动态图

最近一直处于开发大屏的项目,在开发中遇到了一个小知识点,在大屏中如何实现人体动态图。然后看了下echarts官方文档,根据文档中的示例调整出来自己想要的效果。 根据文档上发现 series 中 type 类型设置为 象形柱形图,象形柱图是…

Gitlab 安装部署

目录 1、Jenkins 结合 Gitlab 构建 CI/CD 环境 CI/CD 介绍 CI/CD 流程 Jenkins 简介 GitLab 简介 项目部署方式 CI系统的工作流程 2、搭建 GitLab 安装 GitLab 配置 GitLab 修改root密码 访问 GitLab 开机自启 3、使用 GitLab 管理 GitLab 关闭 GitLab 注册功能…

Conda笔记--移动Conda环境后pip使用异常的解决

1--概述 由于各种原因,需要将Anaconda转变为Minicoda,为了保留之前安装的所有环境,直接将anaconda3/envs的所有环境拷贝到Miniconda/envs中,但在使用移动后环境时会出现pip的错误:bad interpreter: No such file or di…

AWS的RDS数据库开启慢查询日志

#开启慢日志两个参数 slow_query_log 1 设置为1,来启用慢查询日志 long_query_time 5 (单位秒) sql执行多长时间被定义为慢日志1. 点击RDS然后点击参数组,选择slow_query_log,设置为1【表示开启慢日志】点击保存…

力扣hot9---滑动窗口

题目: 先记录一下(没想到有生之年,还能):其实还能优化,后面会讲述优化思路 思路: 滑动窗口的大小就是固定的,就是len_p。那么依次将窗口从s的最左端向右滑动。在当下的窗口中&#x…

python概率分析:为什么葫芦娃救爷爷是一个一个地救成功率最高?

关键词: Python 、葫芦娃 、 概率计算 、 数学 、 建模 前言 过完年了返工后想起了小孩子们爱看的葫芦娃救爷爷的动画片,葫芦娃为什么是一个一个前去救爷爷,为什么不等着七个一起去救爷爷。带着这个疑问,我决定今天用数学的角度…

微信小程序用户隐私保护指引设置

场景:开发小程序时,有时候需要获取用户隐私信息,在提交小程序审核时,需要填写一份隐私保护协议,经常由于填写不规范导致审核不通过,在网上找到了一份模块可供参考 步骤:小程序后台-》设置-》服…

MySQL 学习笔记(基础篇 Day1)

「写在前面」 本文为黑马程序员 MySQL 教程的学习笔记。本着自己学习、分享他人的态度,分享学习笔记,希望能对大家有所帮助。 目录 0 课程介绍 1 MySQL 概述 1.1 数据库相关概念 1.2 MySQL 数据库 2 SQL 2.1 SQL 通用语法 2.2 SQL 分类 2.3 DDL 2.4 图形…

周最佳:詹姆斯场均30.3分8.7助 杰伦-布朗场均28.3分分别当选

直播吧指定地址:www.bdky.cn 3月5日讯 今日NBA官方公布了本赛季第19周周最佳球员,湖人球星勒布朗-詹姆斯和绿军球星杰伦-布朗分别当选。 上周詹姆斯场均可以得到30.3分4.7篮板8.7助攻,湖人取得2胜1负战绩。 布朗场均可以得到28.3分5.3篮板…

Linux中断实验:定时器实现按键消抖处理实验一

一. 简介 前面文章学习了Linux驱动按键中断实验,文章地址如下: Linux驱动按键中断实验:按键中断功能的实现-CSDN博客 本文在Linux驱动按键中断实现的基础上,使用定时器实现按键消抖处理。 二. Linux中断实验:定时器…

java:String和StringBuilder 的相互转换实现字符串拼接

public class StringDemo {/* 练习题:字符串拼接升级版1.定义一个int类型的数组,用静态初始化完成数组元素的初始化2.定义一个方法,用于把int数组中的数据按照指定格式拼接成一个字符串返回3.在方法中用StringBuilder按照要求进行拼接&#x…

新生儿放屁的温馨小贴士:呵护宝宝舒适健康成长

引言 新生儿的生活充满了各种令人惊喜和可爱的瞬间,其中包括他们放臭屁的时刻。尽管这看似简单的行为可能引发父母的担忧,但实际上,它通常是宝宝健康发展的自然表现。在这篇文章中,我们将分享一些关于新生儿放臭屁的注意事项&…

面试经典150题 -- 回溯 (总结)

总的链接 : 面试经典 150 题 - 学习计划 - 力扣(LeetCode)全球极客挚爱的技术成长平台 17 . 电话号码的字母组合 1 . 先创建一个下标 与 对应字符串映射的数组,这里使用hash表进行映射也是可以的 ; 2 . 对于回溯 ,…

MySQL性能优化-Mysql索引篇(1)

什么是索引? 数据库中的索引,就好比一本书的目录,它可以帮我们快速进行特定值的定位与查找,从而加快数据查询的效率。索引就是帮助数据库管理系统高效获取数据的数据结构。如果我们不使用索引,就必须从第 1 条记录开始…

什么台灯护眼效果好?一文搞懂如何正确挑选护眼台灯

现在的孩子学习状态可以用四个字来形容,“学业繁重”,不少孩子从上小学开始,晚上完成功课到八九点都是在正常不过的事情了,因此室内的光线环境是非常重要的,直接影响了视力健康尤其是书桌上的那一盏台灯,有…

012 Linux_线程控制

前言 本文将会向你介绍线程控制(创建(请见上文),终止,等待,分离) 线程控制 线程终止 pthread_t pthread_self(void); 获取线程自身的ID 如果需要只终止某个线程而不终止整个进程,可以有三种…
最新文章