【从零开始实现意图识别】中文对话意图识别详解

前言

意图识别(Intent Recognition)是自然语言处理(NLP)中的一个重要任务,它旨在确定用户输入的语句中所表达的意图或目的。简单来说,意图识别就是对用户的话语进行语义理解,以便更好地回答用户的问题或提供相关的服务。

在NLP中,意图识别通常被视为一个分类问题,即通过将输入语句分类到预定义的意图类别中来识别其意图。这些类别可以是各种不同的任务、查询、请求等,例如搜索、购买、咨询、命令等。

下面是一个简单的例子来说明意图识别的概念:

用户输入: "我想订一张从北京到上海的机票。

意图识别:预订机票。

在这个例子中,通过将用户输入的语句分类到“预订机票”这个意图类别中,系统可以理解用户的意图并为其提供相关的服务。

意图识别是NLP中的一项重要任务,它可以帮助我们更好地理解用户的需求和意图,从而为用户提供更加智能和高效的服务。

在智能对话任务中,意图识别是一种非常重要的技术,它可以帮助系统理解用户的输入,从而提供更加准确和个性化的回答和服务。

模型

意图识别和槽位填充是对话系统中的基础任务。本仓库实现了一个基于BERT的意图(intent)和槽位(slots)联合预测模块。想法上实际与JoinBERT类似(GitHub:BERT for Joint Intent Classification and Slot Filling),利用 [CLS] token对应的last hidden state去预测整句话的intent,并利用句子tokens的last hidden states做序列标注,找出包含slot values的tokens。你可以自定义自己的意图和槽位标签,并提供自己的数据,通过下述流程训练自己的模型,并在JointIntentSlotDetector类中加载训练好的模型直接进行意图和槽值预测。

源GitHub:https://github.com/Linear95/bert-intent-slot-detector

在本文使用的模型中对数据进行了扩充、对代码进行注释、对部分代码进行了修改

Bert模型下载

Bert模型下载地址:https://huggingface.co/bert-base-chinese/tree/main

下载下方红框内的模型即可。

数据集介绍

训练数据以json格式给出,每条数据包括三个关键词:text表示待检测的文本,intent代表文本的类别标签,slots是文本中包括的所有槽位以及对应的槽值,以字典形式给出。

{

"text": "搜索西红柿的做法。",

"domain": "cookbook",

"intent": "QUERY",

"slots": {"ingredient": "西红柿"}

}

原始数据集:https://conference.cipsc.org.cn/smp2019/

本项目中在原始数据集中新增了部分数据,用来平衡数据。

模型训练

python train.py

# -----------training-------------
max_acc = 0
for epoch in range(args.train_epochs):
    total_loss = 0
    model.train()
    for step, batch in enumerate(train_dataloader):
        input_ids, intent_labels, slot_labels = batch

        outputs = model(
            input_ids=torch.tensor(input_ids).long().to(device),
            intent_labels=torch.tensor(intent_labels).long().to(device),
            slot_labels=torch.tensor(slot_labels).long().to(device)
        )

        loss = outputs['loss']
        total_loss += loss.item()

        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps

        loss.backward()

        if step % args.gradient_accumulation_steps == 0:
            # 用于对梯度进行裁剪,以防止在神经网络训练过程中出现梯度爆炸的问题。
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

            optimizer.step()
            scheduler.step()
            model.zero_grad()

    train_loss = total_loss / len(train_dataloader)

    dev_acc, intent_avg, slot_avg = dev(model, val_dataloader, device, slot_dict)

    flag = False
    if max_acc < dev_acc:
        max_acc = dev_acc
        flag = True
        save_module(model, model_save_dir)
    print(f"[{epoch}/{args.train_epochs}] train loss: {train_loss}  dev intent_avg: {intent_avg} "
          f"def slot_avg: {slot_avg} save best model: {'*' if flag else ''}")

dev_acc, intent_avg, slot_avg = dev(model, val_dataloader, device, slot_dict)
print("last model dev intent_avg: {} def slot_avg: {}".format(intent_avg, slot_avg))

运行过程:

模型推理

python predict.py
 
def detect(self, text, str_lower_case=True):
    """
    text : list of string, each string is a utterance from user
    """
    list_input = True

    if isinstance(text, str):
        text = [text]
        list_input = False

    if str_lower_case:
        text = [t.lower() for t in text]

    batch_size = len(text)

    inputs = self.tokenizer(text, padding=True)

    with torch.no_grad():
        outputs = self.model(input_ids=torch.tensor(inputs['input_ids']).long().to(self.device))

    intent_logits = outputs['intent_logits']
    slot_logits = outputs['slot_logits']

    intent_probs = torch.softmax(intent_logits, dim=-1).detach().cpu().numpy()
    slot_probs = torch.softmax(slot_logits, dim=-1).detach().cpu().numpy()

    slot_labels = self._predict_slot_labels(slot_probs)
    intent_labels = self._predict_intent_labels(intent_probs)

    slot_values = self._extract_slots_from_labels(inputs['input_ids'], slot_labels, inputs['attention_mask'])

    outputs = [{'text': text[i], 'intent': intent_labels[i], 'slots': slot_values[i]}
               for i in range(batch_size)]

    if not list_input:
        return outputs[0]

    return outputs

推理结果:

模型检测相关代码

将概率值转换为实际标注值

def _predict_slot_labels(self, slot_probs):
    """
    slot_probs : probability of a batch of tokens into slot labels, [batch, seq_len, slot_label_num], numpy array
    """
    slot_ids = np.argmax(slot_probs, axis=-1)
    return self.slot_dict[slot_ids.tolist()]

def _predict_intent_labels(self, intent_probs):
    """
    intent_labels : probability of a batch of intent ids into intent labels, [batch, intent_label_num], numpy array
    """
    intent_ids = np.argmax(intent_probs, axis=-1)
    return self.intent_dict[intent_ids.tolist()]

槽位验证(确保检测结果的正确性)

def _extract_slots_from_labels_for_one_seq(self, input_ids, slot_labels, mask=None):
    results = {}
    unfinished_slots = {}  # dict of {slot_name: slot_value} pairs
    if mask is None:
        mask = [1 for _ in range(len(input_ids))]

    def add_new_slot_value(results, slot_name, slot_value):
        if slot_name == "" or slot_value == "":
            return results
        if slot_name in results:
            results[slot_name].append(slot_value)
        else:
            results[slot_name] = [slot_value]
        return results

    for i, slot_label in enumerate(slot_labels):
        if mask[i] == 0:
            continue
        # 检测槽位的第一字符(B_)开头
        if slot_label[:2] == 'B_':
            slot_name = slot_label[2:]  # 槽位名称 (B_ 后面)
            if slot_name in unfinished_slots:
                results = add_new_slot_value(results, slot_name, unfinished_slots[slot_name])
            unfinished_slots[slot_name] = self.tokenizer.decode(input_ids[i])
        # 检测槽位的后面字符(I_)开头
        elif slot_label[:2] == 'I_':
            slot_name = slot_label[2:]
            if slot_name in unfinished_slots and len(unfinished_slots[slot_name]) > 0:
                unfinished_slots[slot_name] += self.tokenizer.decode(input_ids[i])

    for slot_name, slot_value in unfinished_slots.items():
        if len(slot_value) > 0:
            results = add_new_slot_value(results, slot_name, slot_value)

    return results

源码获取

NLP/bert-intent-slot at main · mzc421/NLP (github.com)icon-default.png?t=N7T8https://github.com/mzc421/NLP/tree/main/bert-intent-slot

链接作者

欢迎关注我的公众号:@AI算法与电子竞赛

硬性的标准其实限制不了无限可能的我们,所以啊!少年们加油吧!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/182340.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

JS数组常用的20种方法详解(每一个方法都有例子,超全面,超好理解的教程,干货满满)

目录 1.会改变原数组的方法&#xff08;7种&#xff09; 1.push() 2.pop() 3.unshift() 4.shift() 5.reverse() 6.sort() 7.splice() 2.不改变原数组的方法&#xff08;13种&#xff0c;返回的新数组是从原数组浅拷贝来的&#xff09; 1.concat() 2.join() 3.slice…

九、ffmpeg命令转封装

开了几天小差&#xff0c;今天继续学习ffmpeg。 准备测试使用的视频&#xff0c;并查看其信息 # 查看视频信息。使用Mediainfo也可以 ffprobe test.mp4 视频格式的信息如下。 保持编码格式&#xff1a;ffmpeg -i test.mp4 -vcodec copy -acodec copy test_copy.tsffmpeg -i…

2015年2月4日 Go生态洞察:Go语言中的包命名艺术

&#x1f337;&#x1f341; 博主猫头虎&#xff08;&#x1f405;&#x1f43e;&#xff09;带您 Go to New World✨&#x1f341; &#x1f984; 博客首页——&#x1f405;&#x1f43e;猫头虎的博客&#x1f390; &#x1f433; 《面试题大全专栏》 &#x1f995; 文章图文…

王道p150 20.将给定的表达式树转化为等价的中缀表达式(通过括号反应操作符的计算次序)

本题代码如下 void btreetoexp(tree t, char deep) {if (t NULL)return;else if (t->lchild NULL && t->rchild NULL)printf("%c", t->data);//输出操作数&#xff0c;不加括号else {if (deep > 1)printf("(");btreetoexp(t->l…

231123 刷题日报-动态规划

今天主要看了DP&#xff0c;前几天频繁遇到DP打击有点大。。 1. 0-1背包问题 要点&#xff1a; a. 三部曲&#xff1a; 1. 状态和选择 状态&#xff1a;物品序号、背包容量 选择&#xff1a;放、不放 2. dp数组定义、base case dp[i][w] 对于前i个物品&#xff0c;当前背包…

UNETR:用于三维医学图像分割的Transformer

论文链接&#xff1a;https://arxiv.org/abs/2103.10504 代码链接&#xff1a; https://monai.io/research/unetr 机构&#xff1a;Vanderbilt University, NVIDIA 最近琢磨不出来怎么把3d体数据和文本在cnn中融合&#xff0c;因为确实存在在2d里面用的transformer用在3d里面…

leetcode刷题之用栈实现队列(C语言版)

leetcode刷题之用栈实现队列&#xff08;C语言版&#xff09; 一、题目描述二、题目要求三、题目解析Ⅰ、typedef structⅡ、MyQueue* myQueueCreateⅢ、void myQueuePush(MyQueue* obj, int x)Ⅳ、int myQueuePeek(MyQueue* obj)Ⅴ、int myQueuePop(MyQueue* obj)Ⅶ、bool myQ…

编译器核心技术概览

编译技术是一门庞大的学科&#xff0c;我们无法对其做完善的讲解。但不同用途的编译器或编译技术的难度可能相差很大&#xff0c;对知识的掌握要求也会相差很多。如果你要实现诸如 C、JavaScript 这类通用用途语言&#xff08;general purpose language&#xff09;&#xff0c…

[shader] 光照入门(未完结。。。

反射 漫反射&#xff1a;而当物体表面粗糙时&#xff0c;我们把物体表面看作无数不同方向的微小镜面&#xff0c;则这些镜面反射出的光方向均不相同&#xff0c;这就是漫反射。 高光反射&#xff1a;我们假定物体表面光滑&#xff0c;只有一个镜面&#xff0c;那么所有的光都…

微信小程序前端环境搭建

搭建微信小程序前端环境 申请小程序测试账号 访问路径 使用微信扫描二维码进行申请&#xff0c;申请成功之后&#xff0c;进入界面&#xff0c;获取小程序ID(AppID)和秘钥(AppSecret) 安装微信web开发者工具 访问路径 选择稳定开发的版本 需要在小程序的设置中将默认关闭…

深入理解JVM 类加载机制

深入理解JVM 类加载机制 虚拟机如何加载Class文件&#xff1f; Class文件中的信息进入到虚拟机后会发生什么变化&#xff1f; 类加载机制就是Java虚拟机把描述类的数据从Class文件加载到内存&#xff0c;并对数据进行校验、转换解析和初始化&#xff0c;最终形成可以被虚拟机…

AMEYA360:瑞萨面向高端工业传感器系统推出高精度模拟前端的32位RX MCU

全球半导体解决方案供应商瑞萨电子&#xff08;TSE&#xff1a;6723&#xff09;宣布面向高端工业传感器系统推出一款全新RX产品——RX23E-B&#xff0c;扩展32位微控制器&#xff08;MCU&#xff09;产品线。新产品作为广受欢迎的RX产品家族的一员&#xff0c;具有高精度模拟前…

3D火山图绘制教程

一边学习&#xff0c;一边总结&#xff0c;一边分享&#xff01; 本期教程内容 **注&#xff1a;**本教程详细内容 Volcano3D绘制3D火山图 一、前言 火山图是做差异分析中最常用到的图形&#xff0c;在前面的推文中&#xff0c;我们也推出了好几期火山图的绘制教程&#xff0…

如何通过宝塔面板搭建一个本地MySQL数据库服务并实现远程访问

宝塔安装MySQL数据库&#xff0c;并内网穿透实现公网远程访问 文章目录 宝塔安装MySQL数据库&#xff0c;并内网穿透实现公网远程访问前言1.Mysql服务安装2.创建数据库3.安装cpolar3.2 创建HTTP隧道 4.远程连接5.固定TCP地址5.1 保留一个固定的公网TCP端口地址5.2 配置固定公网…

Axios 通过a标签下载文件 跨域下载

<!-- a标签占位 --><a ref"down" ></a>getTest() {this.$axios.request({url: https://cnv13.55.la/download?file_key3695fa9461a0ae59cf3148581e4fe339&handle_typeexcel2pdf,method: get,responseType: blob, // 切记类型 blob}).then(re…

java:CommandLineRunner命令行操作

背景 CommandLineRunner是一个SpringBoot提供的接口&#xff0c;这个接口可以让我们在SpringBoot启动之后&#xff0c;执行一些特定的命令行操作。 实现CommandLineRunner接口后&#xff0c;SpringBoot在启动的时候会自动执行run方法。通常&#xff0c;我们可以在run方法中进…

TableStructureRec: 表格结构识别推理库来了

目录 引言lineless_table_rec: 无线表格识别库安装使用结果 wired_table_rec&#xff1a;有线表格识别库安装使用结果 写在最后 引言 TableStructureRec 仓库是用来对文档中表格做结构化识别的推理库&#xff0c;包括来自 PaddleOCR 的表格结构识别算法模型、来自阿里读光有线…

Python监控服务进程及自启动服务方法与实践

1. 需求概述 当我们在Windows Server环境中部署XX系统的实际应用中&#xff0c;往往会遇到一些运维管理的挑战。为了确保系统的持续稳定运行&#xff0c;特别是在服务程序因各种原因突然关闭的情况下&#xff0c;我们可以借助Python的强大生态系统来构建一个监控与自动重启的管…

Linux下载工具XDM下载安装与使用

Windows上IDM多线程下载非常强大&#xff0c;即能捕捉页面上的视频、图片、音频&#xff0c;又能作为浏览器下载器使用&#xff0c;但是IDM无法在Linux下使用&#xff0c;除非使用wine。不过我们可以在Linux中用XDM(Xtreme Download Manager)代替IDM。 1、XDM下载 Xtreme Dow…

线性回归中的函数求导

在线性回归中&#xff0c;函数求导是一个重要的数学工具&#xff0c;用于计算损失函数关于模型参数的导数。通过求导&#xff0c;我们可以找到最优的参数值&#xff0c;以实现更好的线性回归拟合。 本文将介绍线性回归的基本原理&#xff0c;以及如何通过函数求导来优化线性回…