【小样本命名实体识别】COPNER论文源码详解

COPNER: Contrastive Learning with Prompt Guiding for Few-shot Named Entity Recognition

原文与代码链接: https://github.com/AndrewHYC/COPNER

一、项目结构

在这里插入图片描述

二、代码分析

1.定义参数

配置训练环境

parser.add_argument('--gpu', default='0',
        help='the gpu number for traning')

parser.add_argument('--seed', type=int, default=42,
        help='random seed')

训练任务定义

parser.add_argument('--mode', default='inter',
        help='training mode, must be in [inter, intra, supervised, i2b2, conll, wnut, mit-movie]')
parser.add_argument('--task', default='cross-label-space',
        help='training task, must be in [cross-label-space, domain-transfer, in-label-space]')

parser.add_argument('--trainN', default=5, type=int,
        help='N in train')
parser.add_argument('--N', default=5, type=int,
        help='N way')
parser.add_argument('--K', default=1, type=int,
        help='K shot')
parser.add_argument('--Q', default=1, type=int,
        help='Num of query per class')

parser.add_argument('--support_num', default=0, type=int,
        help='the id number of support set')

parser.add_argument('--zero_shot', action='store_true',
        help='')

parser.add_argument('--only_test', action='store_true',
        help='only test')

parser.add_argument('--load_ckpt', default=None,
        help='load ckpt')
parser.add_argument('--ckpt_name', type=str, default='',
        help='checkpoint name.')

模型配置

parser.add_argument('--pretrain_ckpt', default='./premodel/roberta-wwm-ext-base',
       help='bert pre-trained checkpoint: bert-base-uncased / bert-base-cased')

parser.add_argument('--prompt', default=1, type=int, choices=[0,1,2],
        help='choice in [0,1,2]:\
                0: Continue Prompt\
                1: Partition Prompt\
                2: Queue Prompt')
parser.add_argument('--pseudo_token', default='[S]', type=str,
        help='pseudo_token')

parser.add_argument('--max_length', default=64, type=int,
        help='max length')

parser.add_argument('--ignore_index', type=int, default=-1,
        help='label index to ignore when calculating loss and metrics')

parser.add_argument('--struct', action='store_true',
        help='StructShot parameter to re-normalizes the transition probabilities')

parser.add_argument('--tau', default=1, type=float,
        help='the temperature rate for contrastive learning')

parser.add_argument('--struct_tau', default=0.32, type=float,
        help='the tau in the viterbi decode')

训练配置

parser.add_argument('--batch_size', default=16, type=int,
        help='batch size')
parser.add_argument('--test_bz', default=1, type=int,
        help='test or val batch size')

parser.add_argument('--train_iter', default=10000, type=int,
        help='num of iters in training')
parser.add_argument('--val_iter', default=200, type=int,
        help='num of iters in validation')
parser.add_argument('--test_iter', default=5000, type=int,
        help='num of iters in testing')
parser.add_argument('--val_step', default=200, type=int,
        help='val after training how many iters')

parser.add_argument('--adapt_step', default=5, type=int,
        help='adapting how many iters in validing or testing')
parser.add_argument('--adapt_auto', action='store_true',
        help='adapting how many iters in validing or testing')

parser.add_argument('--threshold_alpha', default=0.1, type=float,
        help='Gradient descent change threshold for early stopping')
parser.add_argument('--threshold_beta', default=0.5, type=float,
        help='loss threshold for early stopping')

parser.add_argument('--lr', default=1e-4, type=float,
        help='learning rate of Training')

parser.add_argument('--adapt_lr', default=None, type=float,
        help='learning rate of Adapting')

parser.add_argument('--grad_iter', default=1, type=int,
        help='accumulate gradient every x iterations')
parser.add_argument('--early_stopping', type=int, default=3000,
                    help='iteration numbers to stop without performance increasing')

parser.add_argument('--use_sgd_for_lm', action='store_true',
        help='use SGD instead of AdamW for BERT.')
2.主函数

调用参数,配置预训练模型

def main():
    trainN = opt.trainN if opt.trainN is not None else opt.N # opt.trainN = opt.N = 5
    N = opt.N # 5
    K = opt.K # 1
    Q = opt.Q # 1
    max_length = opt.max_length # 64
    
    if opt.adapt_lr is None and opt.lr: # opt.adapt_lr = None / opt.lr = 1e-4
        opt.adapt_lr = opt.lr

    print("{}-way-{}-shot Few-Shot NER".format(N, K))
    print('task: {}'.format(opt.task))
    print('mode: {}'.format(opt.mode))
    print('prompt: {}'.format(opt.prompt))
    print("support: {}".format(opt.support_num))
    print("max_length: {}".format(max_length))
    print("batch_size: {}".format(opt.test_bz if opt.only_test else opt.batch_size))

    set_seed(opt.seed)
    print('loading model and tokenizer...')
    pretrain_ckpt = opt.pretrain_ckpt or 'bert-base-uncased'

    config = BertConfig.from_pretrained(pretrain_ckpt)
    tokenizer = BertTokenizer.from_pretrained(pretrain_ckpt)
    opt.tokenizer = tokenizer
    word_encoder = BERTWordEncoder.from_pretrained(pretrain_ckpt, config=config, args=opt)

加载数据集

if opt.task == 'cross-label-space':
        opt.train = f'data/few-nerd/{opt.mode}/train.txt'
        opt.dev = f'data/few-nerd/{opt.mode}/dev.txt'
        opt.test = f'data/few-nerd/{opt.mode}/test.txt'

        opt.train_word_map = opt.dev_word_map = opt.test_word_map = FEWNERD_WORD_MAP

        print(f'loading train data: {opt.train}')
        train_data_loader = get_loader(opt.train, tokenizer, word_map = opt.train_word_map,
                N=trainN, K=1, Q=Q, batch_size=opt.batch_size, max_length=max_length, # K=1 for training
                ignore_index=opt.ignore_index, args=opt, train=True)
        print(f'loading eval data: {opt.dev}')
        val_data_loader = get_loader(opt.dev, tokenizer, word_map = opt.dev_word_map,
                N=N, K=K, Q=Q, batch_size=opt.test_bz, max_length=max_length, 
                ignore_index=opt.ignore_index, args=opt)
        print(f'loading test data: {opt.test}')
        test_data_loader = get_loader(opt.test, tokenizer, word_map = opt.test_word_map,
                N=N, K=K, Q=Q, batch_size=opt.test_bz, max_length=max_length, 
                ignore_index=opt.ignore_index, args=opt)
3.get_loader

N=5 K=1 for training, Q=1 batch_size=16 ignore_index=-1 opt.train_word_map = opt.dev_word_map = opt.test_word_map = FEWNERD_WORD_MAP

FEWNERD_WORD_MAP
先初始化定义一个OrderedDict,然后按照键值对插入,此时dict可以记录插入字典的顺序

from collections import OrderedDict
# # Few-NERD
FEWNERD_WORD_MAP = OrderedDict()

FEWNERD_WORD_MAP['O'] = 'none'

FEWNERD_WORD_MAP['location-GPE'] = 'nation'
FEWNERD_WORD_MAP['location-bodiesofwater'] = 'water'
FEWNERD_WORD_MAP['location-island'] = 'island'
FEWNERD_WORD_MAP['location-mountain'] = 'mountain'
FEWNERD_WORD_MAP['location-park'] = 'parks'
FEWNERD_WORD_MAP['location-road/railway/highway/transit'] = 'road'
FEWNERD_WORD_MAP['location-other'] = 'location'

FEWNERD_WORD_MAP['person-actor'] = 'actor'
FEWNERD_WORD_MAP['person-artist/author'] = 'artist'
FEWNERD_WORD_MAP['person-athlete'] = 'athlete'
FEWNERD_WORD_MAP['person-director'] = 'director'
FEWNERD_WORD_MAP['person-politician'] = 'politician'
FEWNERD_WORD_MAP['person-scholar'] = 'scholar'
FEWNERD_WORD_MAP['person-soldier'] = 'soldier'
FEWNERD_WORD_MAP['person-other'] = 'person'

FEWNERD_WORD_MAP['organization-company'] = 'company'
FEWNERD_WORD_MAP['organization-education'] = 'education'
FEWNERD_WORD_MAP['organization-government/governmentagency'] = 'government'
FEWNERD_WORD_MAP['organization-media/newspaper'] = 'media'
FEWNERD_WORD_MAP['organization-politicalparty'] = 'parties'
FEWNERD_WORD_MAP['organization-religion'] = 'religion'
FEWNERD_WORD_MAP['organization-showorganization'] = 'show'
FEWNERD_WORD_MAP['organization-sportsleague'] = 'league'
FEWNERD_WORD_MAP['organization-sportsteam'] = 'team'
FEWNERD_WORD_MAP['organization-other'] = 'organization'

FEWNERD_WORD_MAP['building-airport'] = 'airport'
FEWNERD_WORD_MAP['building-hospital'] = 'hospital'
FEWNERD_WORD_MAP['building-hotel'] = 'hotel'
FEWNERD_WORD_MAP['building-library'] = 'library'
FEWNERD_WORD_MAP['building-restaurant'] = 'restaurant'
FEWNERD_WORD_MAP['building-sportsfacility'] = 'facility'
FEWNERD_WORD_MAP['building-theater'] = 'theater'
FEWNERD_WORD_MAP['building-other'] = 'building'

FEWNERD_WORD_MAP['art-broadcastprogram'] = 'broadcast'
FEWNERD_WORD_MAP['art-film'] = 'film'
FEWNERD_WORD_MAP['art-music'] = 'music'
FEWNERD_WORD_MAP['art-painting'] = 'painting'
FEWNERD_WORD_MAP['art-writtenart'] = 'writing'
FEWNERD_WORD_MAP['art-other'] = 'art'

FEWNERD_WORD_MAP['product-airplane'] = 'airplane'
FEWNERD_WORD_MAP['product-car'] = 'car'
FEWNERD_WORD_MAP['product-food'] = 'food'
FEWNERD_WORD_MAP['product-game'] = 'game'
FEWNERD_WORD_MAP['product-ship'] = 'ship'
FEWNERD_WORD_MAP['product-software'] = 'software'
FEWNERD_WORD_MAP['product-train'] = 'train'
FEWNERD_WORD_MAP['product-weapon'] = 'weapon'
FEWNERD_WORD_MAP['product-other'] = 'product'

FEWNERD_WORD_MAP['event-attack/battle/war/militaryconflict'] = 'war'
FEWNERD_WORD_MAP['event-disaster'] = 'disaster'
FEWNERD_WORD_MAP['event-election'] = 'election'
FEWNERD_WORD_MAP['event-protest'] = 'protest'
FEWNERD_WORD_MAP['event-sportsevent'] = 'sport'
FEWNERD_WORD_MAP['event-other'] = 'event'

FEWNERD_WORD_MAP['other-astronomything'] = 'astronomy'
FEWNERD_WORD_MAP['other-award'] = 'award'
FEWNERD_WORD_MAP['other-biologything'] = 'biology'
FEWNERD_WORD_MAP['other-chemicalthing'] = 'chemistry'
FEWNERD_WORD_MAP['other-currency'] = 'currency'
FEWNERD_WORD_MAP['other-disease'] = 'disease'
FEWNERD_WORD_MAP['other-educationaldegree'] = 'degree'
FEWNERD_WORD_MAP['other-god'] = 'god'
FEWNERD_WORD_MAP['other-language'] = 'language'
FEWNERD_WORD_MAP['other-law'] = 'law'
FEWNERD_WORD_MAP['other-livingthing'] = 'organism'
FEWNERD_WORD_MAP['other-medical'] = 'medical'
def get_loader(filepath, tokenizer, N, K, Q, batch_size, max_length, word_map,
        ignore_index=-1, args=None, num_workers=4, support_file_path=None, train=False):
    if train:
        dataset = SingleDatasetwithEpisodeSample(N, 1, filepath, tokenizer, max_length, 
                                                        ignore_label_id=ignore_index, 
                                                        args=args, word_map=word_map)
        return data.DataLoader(dataset=dataset,
                                batch_size=batch_size,
                                shuffle=True,
                                pin_memory=True,
                                num_workers=num_workers,
                                collate_fn=single_collate_fn)
    else:
        if args.task in ['cross-label-space']:
            dataset = PairDatasetwithEpisodeSample(N, K, Q, filepath, tokenizer, max_length, 
                                                        ignore_label_id=ignore_index, 
                                                        args=args, word_map=word_map)
            return data.DataLoader(dataset=dataset,
                                    batch_size=1,
                                    shuffle=True,
                                    pin_memory=True,
                                    num_workers=num_workers,
                                    collate_fn=pair_collate_fn)
        elif args.task in ['domain-transfer']:
            dataset = PairDatasetwithFixedSupport(N, filepath, support_file_path, tokenizer, max_length,
                                                        ignore_label_id=ignore_index,
                                                        args=args, word_map=word_map)
            return data.DataLoader(dataset=dataset,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    pin_memory=True,
                                    num_workers=num_workers,
                                    collate_fn=pair_collate_fn)
        elif args.task in ['in-label-space']:
            dataset = SingleDatasetwithRamdonSample(filepath, tokenizer, max_length, 
                                                        ignore_label_id=ignore_index, 
                                                        args=args, word_map=word_map)
        
            return data.DataLoader(dataset=dataset,
                                        batch_size=batch_size,
                                        shuffle=False,
                                        pin_memory=True,
                                        num_workers=num_workers,
                                        collate_fn=single_collate_fn)
4.SingleDatasetwithEpisodeSample

继承自 PairDatasetwithEpisodeSample 类,该类用于处理单数据集的示例采样。

class SingleDatasetwithEpisodeSample(PairDatasetwithEpisodeSample):

    def __init__(self, N, K, filepath, tokenizer, max_length, word_map, ignore_label_id=-1, args=None):
        if not os.path.exists(filepath):
            print("[ERROR] Data file does not exist!")
            assert(0)
        self.class2sampleid = {}
        self.word_map = word_map
        self.word2class = OrderedDict()
        for key, value in self.word_map.items():
            self.word2class[value] = key

        self.BOS = '[CLS]'
        self.EOS = '[SEP]'

        self.max_length = max_length
        self.ignore_label_id = ignore_label_id

        self.samples, self.classes = self.__load_data_from_file__(filepath)
        
        self.sampler = SingleFewshotSampler(N, K, self.samples, classes=self.classes)

        self.prompt = args.prompt
        self.tokenizer = tokenizer
        self.pseudo_token = args.pseudo_token
        self.tokenizer.add_special_tokens({'additional_special_tokens': [args.pseudo_token]})


    def __getitem__(self, index):
        target_classes, support_idx = self.sampler.__next__()
        # add 'none' and make sure 'none' is labeled 0
        distinct_tags = [self.word_map['O']] + target_classes
        prompt_tags = distinct_tags.copy()
        random.shuffle(prompt_tags)
        self.tag2label = {tag:idx for idx, tag in enumerate(distinct_tags)}
        self.label2tag = {idx:self.word2class[tag] for idx, tag in enumerate(distinct_tags)}
        support_set = self.__populate__(support_idx, distinct_tags, prompt_tags, savelabeldic=True)

        return support_set
    
    def __len__(self):
        return 1000000
  1. init方法:初始化类的实例。参数包括 N、K、filepath、tokenizer、max_length、word_map、ignore_label_id 和 args。在初始化过程中,首先检查给定的文件路径是否存在,然后设置一些实例变量,如 word_map、BOS、EOS、max_length、ignore_label_id 等。随后从文件中加载数据,并使用 SingleFewshotSampler 对象创建一个采样器。最后设置一些额外变量,如 prompt、tokenizer、pseudo_token 等;
  2. getitem方法:根据给定的索引,获取采样数据。通过采样器获取目标类别和支持集索引,然后创建不同标签序列和随机标签序列。接着根据标签生成一些支持集数据,并返回支持集数据。

在这里插入图片描述
在这里插入图片描述

load_data_from_file

def __load_data_from_file__(self, filepath):
        samples = [] # 存储样本
        classes = [] # 存储类别
        with open(filepath, 'r', encoding='utf-8')as f:
            lines = f.readlines()
        samplelines = []
        index = 0
        for line in lines:
            line = line.strip()
            if len(line.split('\t'))>1: # 若一行中包含制表符'\t',则将改行添加到samplelines列表中,表示这一行是样本数据的一部分
                samplelines.append(line)
            else:
                # 若不包含制表符,则表示当前行是样本的结束,开始处理新的样本,将samplelines列表中的数据用于创建一个Sample对象
                sample = Sample(samplelines, self.word_map)
                samples.append(sample)
                # 从Sample对象中获取标签类别,通过get_tag_class方法获取,并将这些类别添加到classes列表中
                sample_classes = sample.get_tag_class()
                self.__insert_sample__(index, sample_classes)
                classes += sample_classes
                samplelines = [] # 清空samplelines列表
                index += 1 # 将index加1
        classes = list(set(classes)) # 遍历完成后将classes列表转换为集合,去除重复的类别
        return samples, classes

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

SingleFewshotSampler

class SingleFewshotSampler(PairFewshotSampler):
    def __init__(self, N, K, samples, classes=None, random_state=0):
        '''
        N: int, how many types in each set
        K: int, how many instances for each type in data set
        samples: List[Sample], Sample class must have `get_class_count` attribute
        classes[Optional]: List[any], all unique classes in samples. If not given, the classes will be got from samples.get_class_count()
        random_state[Optional]: int, the random seed
        '''
        self.K = K
        self.N = N
        self.samples = samples
        self.__check__() # check if samples have correct types
        if classes:
            self.classes = classes
        else:
            self.classes = self.__get_all_classes__()
        random.seed(random_state)

    def __next__(self):
        '''
        randomly sample one episode set
        '''
        episode_class = {'k':self.K}
        episode_idx = []
        target_classes = random.sample(self.classes, self.N)
        candidates = self.__get_candidates__(target_classes)
        while not candidates:
            target_classes = random.sample(self.classes, self.N)
            candidates = self.__get_candidates__(target_classes)

        # greedy search for episode set
        while not self.__finish__(episode_class):
            index = random.choice(candidates)
            
            if index not in episode_idx:
                if self.__valid_sample__(self.samples[index], episode_class, target_classes):
                    
                    self.__additem__(index, episode_class)
                    episode_idx.append(index)

        return target_classes, episode_idx

这段代码定义了一个名为 SingleFewshotSampler 的类,它继承自 PairFewshotSampler。SingleFewshotSampler 的目的是从一个包含多种类别(types)的数据集中采样少数样本(few-shot),以用于训练或测试。

  1. 初始化函数 init: N: 每个集合中类型的数量。 K: 每个类型在数据集中的样本数量。 samples: 一个样本列表,每个样本必须有一个 get_class_count 属性。classes: 样本中所有独特类别的列表。如果没有提供,则从样本的 get_class_count() 中获取。
  2. random_state: 随机种子,用于保证可重复性。
  3. check 方法: 检查 samples 是否具有正确的类型。
  4. get_all_classes 方法: 如果没有提供 classes,则通过调用每个样本的 get_class_count 方法来获取所有独特的类别。
  5. next 方法: 随机采样一个样本集(episode set)。
    episode_class: 存储采样的类别的字典。
    episode_idx: 存储被采样的样本索引的列表。
    target_classes: 从所有类别中随机选取的类别列表,数量为 N。
    candidates: 根据 target_classes 获取的可选样本索引列表。
    如果 candidates 为空,会重新随机选择类别,直到找到有候选样本的类别。使用贪心搜索(greedy search)构建一个样本集,直到满足某个条件(由 finish 方法确定)。
    finish 方法: 判断是否已经完成一个样本集的构建。具体的完成条件在 finish 方法中定义,但代码中这个方法没有给出。
    additem 方法: 向 episode_class 中添加一个样本。
    get_candidates 方法: 根据目标类别 target_classes 获取候选样本索引列表。
    valid_sample 方法: 判断给定的样本是否有效,即是否满足采样器对于样本的要求。
    整体来看,这个类是为了实现一种特定类型的少样本学习(few-shot learning)策略,其中每个类别只随机选择少数样本进行训练。代码中的某些方法(如 finishvalid_sample)没有给出具体实现,所以无法完全确定这个采样器的所有行为。
5.加载模型类CopNER
model = COPNER(word_encoder, opt, opt.train_word_map if not opt.only_test else opt.test_word_map)
class COPNER(FewShotNERModel):
    
    def __init__(self, word_encoder, args, word_map):
        FewShotNERModel.__init__(self, word_encoder, ignore_index=args.ignore_index)
        self.tokenizer = args.tokenizer
        self.tau = args.tau
        # 初始化损失函数loss_fct为CrossEntropyLoss,用于分类问题,并设置忽略索引
        self.loss_fct = CrossEntropyLoss(ignore_index=args.ignore_index)
        self.method = 'euclidean'

        self.class2word = word_map
        self.word2class = OrderedDict()
        for key, value in self.class2word.items():
            self.word2class[value] = key

    def __dist__(self, x, y, dim, normalize=False):
        if normalize: # 对向量进行归一化处理
            x = F.normalize(x, dim=-1)         
            y = F.normalize(y, dim=-1)
        if self.method == 'dot': # 点积
            sim = (x * y).sum(dim)
        elif self.method == 'euclidean': # 欧氏距离
            sim = -(torch.pow(x - y, 2)).sum(dim)
        elif self.method == 'cosine': # 余弦相似度
            sim = F.cosine_similarity(x, y, dim=dim)
        return sim / self.tau
    
    def get_contrastive_logits(self, hidden_states, inputs, valid_mask, target_classes): # 获取对比损失
        class_indexs = [self.tokenizer.get_vocab()[tclass] for tclass in target_classes] # 获取目标类别的索引列表class_indexs

        class_rep = [] 
        for iclass in class_indexs:
            class_rep.append(torch.mean(hidden_states[inputs.eq(iclass), :].view(-1, hidden_states.size(-1)), 0))
        
        class_rep = torch.stack(class_rep).unsqueeze(0) # 计算每个类别的代表性向量class_rep
        token_rep = hidden_states[valid_mask != self.tokenizer.pad_token_id, :].view(-1, hidden_states.size(-1)).unsqueeze(1)

        logits = self.__dist__(class_rep, token_rep, -1)

        return logits.view(-1, len(target_classes))

    def forward(self,
                input_ids,
                labels,
                valid_masks,
                target_classes,
                sentence_num,
                ):
        # 验证输入数据的尺寸是否一致
        assert input_ids.size(0) == labels.size(0) == valid_masks.size(0), \
                print('[ERROR] inputs and labels must have same batch size.')
        assert len(sentence_num) == len(target_classes)
        # 通过词编码器获得隐藏状态hidden_states
        hidden_states = self.word_encoder(input_ids) # logits, (encoder_hs, decoder_hs)
        
        loss = None
        logits = []
        current_num = 0
        # 对于每个句子,计算对比损失,若处于训练状态,累加损失
        for i, num in enumerate(sentence_num):
            current_hs = hidden_states[current_num: current_num+num]
            current_input_ids = input_ids[current_num: current_num+num]
            current_labels = labels[current_num: current_num+num]
            current_valid_masks = valid_masks[current_num: current_num+num]
            current_target_classes = target_classes[i]

            current_num += num

            contrastive_logits = self.get_contrastive_logits(current_hs, 
                                                        current_input_ids, 
                                                        current_valid_masks, 
                                                        current_target_classes)
            
            current_logits = F.softmax(contrastive_logits, -1)

            if self.training:
                contrastive_loss = self.loss_fct(contrastive_logits, current_labels[current_valid_masks != self.tokenizer.pad_token_id].view(-1))
                loss = contrastive_loss if loss is None else loss + contrastive_loss

            current_logits = current_logits.view(-1, current_logits.size(-1))

            logits.append(current_logits)
        # 计算每个句子的logits,并将其堆叠起来
        logits = torch.cat(logits, 0)
        _, preds = torch.max(logits, 1) # 预测结果
        
        # 返回平均损失
        if loss:
            loss /= len(sentence_num)

        return logits, preds, loss
6.实现少样本命名实体识别(NER)的框架
framework = FewShotNERFramework(opt, train_data_loader, val_data_loader, test_data_loader,
                                        train_fname=opt.train if opt.struct else None, 
                                        viterbi=True if opt.struct else False)

FewShotNERFramework

class FewShotNERFramework:

    def __init__(self, args, train_data_loader, val_data_loader, test_data_loader, viterbi=False, train_fname=None):
        '''
        train_data_loader: DataLoader for training.
        val_data_loader: DataLoader for validating.
        test_data_loader: DataLoader for testing.
        viterbi: Whether to use Viterbi decoding.
        train_fname: Path of the data file to get abstract transitions.
        '''
        self.args = args
        self.train_data_loader = train_data_loader
        self.val_data_loader = val_data_loader
        self.test_data_loader = test_data_loader
        self.viterbi = viterbi
        if viterbi: # 是否使用维特比解码器来进行序列标注任务的解码
            abstract_transitions = get_abstract_transitions(train_fname, args)
            self.viterbi_decoder = ViterbiDecoder(self.args.N+2, abstract_transitions, tau=args.struct_tau)

get_abstract_transitions

def get_abstract_transitions(train_fname, args):
    """
    Compute abstract transitions on the training dataset for StructShot
    """
    samples = SingleDatasetwithRamdonSample(train_fname, None, None, word_map=args.train_word_map, args=args).samples
    tag_lists = [sample.tags for sample in samples]

    s_o, s_i = 0., 0.
    o_o, o_i = 0., 0.
    i_o, i_i, x_y = 0., 0., 0.
    for tags in tag_lists:
        if tags[0] == 'O': s_o += 1
        else: s_i += 1
        for i in range(len(tags)-1):
            p, n = tags[i], tags[i+1]
            if p == 'O':
                if n == 'O': o_o += 1
                else: o_i += 1
            else:
                if n == 'O':
                    i_o += 1
                elif p != n:
                    x_y += 1
                else:
                    i_i += 1

    trans = []
    trans.append(s_o / (s_o + s_i))
    trans.append(s_i / (s_o + s_i))
    trans.append(o_o / (o_o + o_i))
    trans.append(o_i / (o_o + o_i))
    trans.append(i_o / (i_o + i_i + x_y))
    trans.append(i_i / (i_o + i_i + x_y))
    trans.append(x_y / (i_o + i_i + x_y))
    return trans
  1. 首先,函数根据数据加载方式(小样本数据或完整数据集)获取样本列表 samples;
  2. 然后,根据样本列表生成标签列表 tag_lists。对于小样本数据加载方式,直接从样本中提取支持集和查询集的标签。对于完整数据集加载方式,遍历所有样本,从中提取每个样本的标签;
  3. 接着,函数初始化并更新用于计算抽象转移概率的统计变量。具体地,对于每个标签序列:统计标签序列起始为 O 和 I 的次数;统计标签序列从 O 到 O 和从 O 到 I 的次数;统计标签序列从 I 到 O、从 I 到 I 和标签序列中不同标签相邻的次数;
  4. 最后,函数计算并返回标签序列的抽象转移概率列表 trans。其中,trans 列表中的每个元素表示一个抽象转移概率。

get_emmissions将模型输出的logits(即未归一化的得分)根据输入的标签列表进行分割,形成与标签对应的 emissions(发射概率)。

	def __get_emmissions__(self, logits, tags_list):
        # split [num_of_query_tokens, num_class] into [[num_of_token_in_sent, num_class], ...]
        emmissions = []
        current_idx = 0
        for tags in tags_list:
            emmissions.append(logits[current_idx:current_idx+len(tags)])
            current_idx += len(tags)
        assert current_idx == logits.size()[0]
        return emmissions

viterbi_decode

    def viterbi_decode(self, logits, query_tags):
        emissions_list = self.__get_emmissions__(logits, query_tags)
        pred = []
        for i in range(len(query_tags)):
            sent_scores = emissions_list[i].cpu()
            sent_len, n_label = sent_scores.shape
            sent_probs = F.softmax(sent_scores, dim=1)
            start_probs = torch.zeros(sent_len) + 1e-6
            sent_probs = torch.cat((start_probs.view(sent_len, 1), sent_probs), 1)
            feats = self.viterbi_decoder.forward(torch.log(sent_probs).view(1, sent_len, n_label+1))
            vit_labels = self.viterbi_decoder.viterbi(feats)
            vit_labels = vit_labels.view(sent_len)
            vit_labels = vit_labels.detach().cpu().numpy().tolist()
            for label in vit_labels:
                pred.append(label-1)
        return torch.tensor(pred).cuda()

使用维特比解码器来对序列标签进行解码。首先,它将 logits 分割成与查询标签对应的 emissions。然后,对于每个句子,计算发射概率,并且结合转移概率使用维特比算法找出最有可能的标签序列。最后,将解码得到的标签序列转换为张量并返回。

7.调用训练方法
framework.train(model, prefix,
	load_ckpt=opt.load_ckpt, 
	save_ckpt=ckpt,
	val_step=opt.val_step, 
	train_iter=opt.train_iter, 
	warmup_step=int(opt.train_iter * 0.05), 
	val_iter=opt.val_iter, 
	learning_rate=opt.lr, 
	use_sgd_for_lm=opt.use_sgd_for_lm)
def train(self,
              model,
              model_name,
              learning_rate=1e-4,
              train_iter=30000,
              val_iter=1000,
              val_step=2000,
              load_ckpt=None,
              save_ckpt=None,
              warmup_step=300,
              grad_iter=1,
              use_sgd_for_lm=False):
        '''
        model: a FewShotREModel instance
        model_name: Name of the model
        learning_rate: Initial learning rate
        train_iter: Num of iterations of training
        val_iter: Num of iterations of validating
        val_step: Validate every val_step steps
        load_ckpt: Path of the checkpoint to load
        save_ckpt: Path of the checkpoint to save
        warmup_step: Num of warmup steps
        grad_iter: Accumulate gradients for grad_iter steps
        use_sgd_for_lm: Whether to use SGD for the language model
        '''
        # Init optimizer
        print('Use bert optim!')
        parameters_to_optimize = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        parameters_to_optimize = [
            {'params': [p for n, p in parameters_to_optimize 
                if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in parameters_to_optimize
                if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
        if use_sgd_for_lm:
            optimizer = torch.optim.SGD(parameters_to_optimize, lr=learning_rate)
        else:
            optimizer = AdamW(parameters_to_optimize, lr=learning_rate)
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=train_iter) 
        
        # load model
        if load_ckpt:
            state_dict = self.__load_model__(load_ckpt)['state_dict']
            own_state = model.state_dict()
            for name, param in state_dict.items():
                if name not in own_state:
                    print('ignore {}'.format(name))
                    continue
                print('load {} from {}'.format(name, load_ckpt))
                own_state[name].copy_(param)

        model.train()

        # Training
        iter_loss = 0.0
        best_precision = 0.0
        best_recall = 0.0
        best_f1 = 0.0
        iter_sample = 0
        pred_cnt = 1e-9
        label_cnt = 1e-9
        correct_cnt = 0
        last_step = 0

        print("Start training...")
        with tqdm(self.train_data_loader, total=train_iter, disable=False, desc="Training") as tbar:

            for it, batch in enumerate(tbar):

                if torch.cuda.is_available():
                    for k in batch:
                        if k != 'target_classes' and \
                            k != 'sentence_num' and \
                            k != 'labels' and \
                            k != 'label2tag':
                                batch[k] = batch[k].cuda()

                    label = torch.cat(batch['labels'], 0)
                    label = label.cuda()

                logits, pred, loss = model(batch['inputs'], 
                                            batch['batch_labels'],
                                            batch['valid_masks'],
                                            batch['target_classes'],
                                            batch['sentence_num'])

                loss.backward()
                
                if it % grad_iter == 0:
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                
                # Calculate metrics
                tmp_pred_cnt, tmp_label_cnt, correct = model.metrics_by_entity(pred, label)
                
                iter_loss += self.item(loss.data)
                pred_cnt += tmp_pred_cnt
                label_cnt += tmp_label_cnt
                correct_cnt += correct
                iter_sample += 1
                precision = correct_cnt / pred_cnt
                recall = correct_cnt / label_cnt
                f1 = 2 * precision * recall / (precision + recall + 1e-9) # 1e-9 for error'float division by zero'
                
                tbar.set_postfix_str("loss: {:2.6f} | F1: {:3.4f}, P: {:3.4f}, R: {:3.4f}, Correct:{}"\
                                            .format(self.item(loss.data), f1, precision, recall, correct_cnt))
                
                if (it + 1) % val_step == 0:
                    precision, recall, f1, _, _, _, _ = self.eval(model, val_iter, word_map=self.args.dev_word_map)

                    model.train()
                    if f1 > best_f1:
                        # print(f'Best checkpoint! Saving to: {save_ckpt}\n')
                        # torch.save({'state_dict': model.state_dict()}, save_ckpt)
                        best_f1 = f1
                        best_precision = precision
                        best_recall = recall
                        last_step = it
                    else:
                        if it - last_step >= self.args.early_stopping:
                            print('\nEarly Stop by {} steps, best f1: {:.4f}%'.format(self.args.early_stopping, best_f1))
                            raise KeyboardInterrupt
                
                if (it + 1) % 100 == 0:
                    iter_loss = 0.
                    iter_sample = 0.
                    pred_cnt = 1e-9
                    label_cnt = 1e-9
                    correct_cnt = 0

                if (it + 1)  >= train_iter:
                    break

        print("\n####################\n")
        print("Finish training {}, best f1: {:.4f}%".format(model_name, best_f1))
  1. 初始化变量: iter_loss 用于累计损失,best_precision、best_recall 和 best_f1 用于记录最佳精确度、召回率和F1分数。iter_sample、pred_cnt、label_cnt 和 correct_cnt 用于计算每个迭代步骤的样本数、预测数和正确预测数;
  2. 训练循环: 使用 tqdm 库来显示训练进度条,它提供了一个动态更新的进度条,显示当前迭代的进度和总迭代次数;
  3. 数据处理: 如果使用了GPU,则将除了标签和其他特定字段之外的所有批量数据移动到GPU上;
  4. 前向传播: model 通过输入数据 batch[‘inputs’] 产生 logits,然后通过softmax或其他激活函数得到 pred(预测)。同时计算损失 loss;
  5. 反向传播和优化: 通过调用 loss.backward() 执行反向传播,然后如果迭代次数 it % grad_iter 为0,则执行一步优化器更新 optimizer.step(),并更新学习率 scheduler.step()。之后,清空梯度 optimizer.zero_grad();
  6. 计算指标: 使用 model.metrics_by_entity 方法计算每个实体的精确度、召回率和F1分数;
  7. 更新进度条: 使用 tbar.set_postfix_str 更新进度条,显示当前的损失和F1分数等信息;
  8. 验证循环: 如果当前迭代次数模 val_step 为0,则进行一次验证,计算验证集上的精确度、召回率和F1分数;
  9. 保存最佳模型: 如果验证F1分数比当前最佳F1分数更高,则保存当前模型状态到 save_ckpt 指定的路径;
  10. 早停机制: 如果连续 self.args.early_stopping 次迭代验证F1分数没有提升,则提前停止训练;
  11. 重置变量: 每100次迭代重置损失和样本计数器;
  12. 训练结束: 当达到预定的训练迭代次数 train_iter 时,训练结束。打印最终结果: 打印模型名称和训练结束时的最佳F1分数。

三、模型训练

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/398688.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Java基于SSM的羽毛球馆管理系统,附源码

博主介绍:✌程序员徐师兄、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇…

halide package cmake的设置方式

1 先找一个例程。里面用到halide。 这时会提示找不到package。 按照那个提示做就行。 2 把提前下载好的halide放到一个位置 3 然后设置一下那个Halide_DIR就可以了 set(Halide_DIR "${CMAKE_SOURCE_DIR}/your_path/Halide/") list(APPEND CMAKE_PREFIX_PATH ${Ha…

认识ansible,了解常用的模块

ansible的概念 Ansible是一个基于Python开发的配置管理和应用部署工具,现在也在自动化管理领域大放异彩。它融合了众多老牌运维工具的优点,Pubbet和Saltstack能实现的功能,Ansible基本上都可以实现。 Ansible能批量配置、部署、管理上千台主…

Tuxera NTFS2024最新中文版支持M1/M2/M3苹果全系机型

Tuxera NTFS的传输速度会受到多种因素的影响,包括硬件配置、文件大小、存储设备的性能等。因此,无法给出具体的传输速度数值。 不过,根据一些用户的使用经验和测试数据,Tuxera NTFS的传输速度通常都非常快,能够满足大…

深度解析Sora的核心技术

Sora要解决的核心问题 Sora面临的挑战是将不同类型的视觉信息,如视频、文本、图像和声音等,整合为一种共同的表征形式。这种转换是实现统一训练过程的关键,旨在将各类数据集中到一个训练框架中,以便于进行大规模的统一学习。简而…

计算机视觉的应用24-ResNet网络与DenseNet网络的对比学习,我们该如何选择。

大家好,我是微学AI,今天给大家介绍一下计算机视觉的应用24-ResNet网络与DenseNet网络的对比学习,我们该如何选择。在计算机视觉领域,ResNet(残差网络)和DenseNet(密集网络)都是深度学…

白银交易新手指南:怎样选择可靠的现货交易平台?

在投资市场上,白银作为一种贵金属,具有较高的投资价值和风险防范功能。对于白银交易新手来说,选择一个可靠的现货交易平台是至关重要的。那么,如何挑选一个适合自己的现货交易平台呢? 1. 平台资质 一个正规的现货交易…

【洛谷题解】B2034 计算 2 的幂

题目链接&#xff1a;计算 2 的幂 - 洛谷 题目难度&#xff1a;入门 涉及知识点&#xff1a;pow函数返回值 题意&#xff1a; 分析&#xff1a;用pow计算再强制转换即可 AC代码&#xff1a; #include<bits/stdc.h> using namespace std; int main(){int a;ios::syn…

在线图片生成工具:定制化占位图片的利器

title: 在线图片生成工具&#xff1a;定制化占位图片的利器 date: 2024/2/20 14:08:16 updated: 2024/2/20 14:08:16 tags: 占位图片网页布局样式展示性能测试响应式设计在线生成开发工具 在现代的网页设计和开发中&#xff0c;占位图片扮演着重要的角色。占位图片是指在开发过…

数据结构---字典树(Tire)

字典树是一种能够快速插入和查询字符串的多叉树结构&#xff0c;节点的编号各不相同&#xff0c;根节点编号为0 Trie树&#xff0c;即字典树&#xff0c;又称单词查找树或键树&#xff0c;是一种树形结构&#xff0c;是一种哈希树的变种。 核心思想也是通过空间来换取时间上的…

AFNetWorking源码

套话 AFNetworking是iOS最常用的网络框架&#xff0c;虽然系统也有NSURLSession&#xff0c;但是我们一般不会直接用它。AFNetworking经过了三个大版本&#xff0c;现在用的大多数都是3.x的版本。 AFNetworking经历了下面三个阶段的发展&#xff1a; 1.0版本 : 基于NSURLConn…

opencv鼠标操作与响应

//鼠标事件 Point sp(-1, -1); Point ep(-1, -1); Mat temp; static void on_draw(int event, int x, int y, int flags, void *userdata) {Mat image *((Mat*)userdata);if (event EVENT_LBUTTONDOWN) {sp.x x;sp.y y;std::cout << "start point:"<<…

CTR之行为序列建模用户兴趣:DIN

在前面的文章中&#xff0c;已经介绍了很多关于推荐系统中CTR预估的相关技术&#xff0c;今天这篇文章也是延续这个主题。但不同的&#xff0c;重点是关于用户行为序列建模&#xff0c;阿里出品。 概要 论文&#xff1a;Deep Interest Network for Click-Through Rate Predict…

C#写的一个计算DCI-P3色域和SRGB的小工具

文章最后附带分享链接与提取码 方便需要测试屏幕的小伙伴&#xff0c;只需要输入RGB就能得到覆盖率与比率&#xff0c;W计算色温&#xff0c;不测也要写上&#xff0c;不然会报错 链接&#xff1a;https://pan.baidu.com/s/1wdmAwmwiXjNvn1tGsvy0HA 提取码&#xff1a;1234

【力扣hot100】刷题笔记Day8

前言 到了大章节【链表】了&#xff0c;争取两三天给它搞定&#xff01;&#xff01; 160. 相交链表 - 力扣&#xff08;LeetCode&#xff09;】 双指针 参考题解&#xff0c;相比于求长度右对齐再一起出发的方法简洁多了 class Solution:def getIntersectionNode(self, head…

【安卓基础2】简单控件

&#x1f3c6;作者简介&#xff1a;|康有为| &#xff0c;大四在读&#xff0c;目前在小米安卓实习&#xff0c;毕业入职。 &#x1f3c6;安卓学习资料推荐&#xff1a; 视频&#xff1a;b站搜动脑学院 视频链接 &#xff08;他们的视频后面一部分没再更新&#xff0c;看看前面…

机器人内部传感器阅读笔记及心得-位置传感器-光电编码器

目前&#xff0c;机器人系统中应用的位置传感器一般为光电编码器。光电编码器是一种应用广泛的位置传感器&#xff0c;其分辨率完全能满足机器人的技术要求&#xff0c;这种非接触型位置传感器可分为绝对型光电编码器和相对型光电编码器。前者只要将电源加到用这种传感器的机电…

9、使用 ChatGPT 的 GPT 制作自己的 GPT!

使用 ChatGPT 的 GPT 制作自己的 GPT! 想用自己的 GPT 超越 GPT ChatGPT 吗?那么让我们 GPT GPT 吧! 山姆 奥特曼利用这个机会在推特上宣传 GPTs 的同时还猛烈抨击了埃隆的格罗克。 GPTs概览 他们来了! 在上周刚刚宣布之后,OpenAI 现在推出了其雄心勃勃的新 ChatGPT…

微服务-Alibaba微服务nacos实战

1. Nacos配置中心 1.1 微服务为什么需要配置中心 在微服务架构中&#xff0c;当系统从一个单体应用&#xff0c;被拆分成分布式系统上一个个服务节点后&#xff0c;配置文件也必须跟着迁移&#xff08;分割&#xff09;&#xff0c;这样配置就分散了&#xff0c;不仅如此&…

Sora给中国AI带来的真实变化

OpenAI的最新技术成果——文生视频模型Sora&#xff0c;在春节假期炸裂登场&#xff0c;令海内外的AI从业者、投资人彻夜难眠。 如果你还没有关注到这个新闻&#xff0c;简单介绍一下&#xff1a;Sora是OpenAI使用超大规模视频数据&#xff0c;训练出的一个通用视觉模型&#x…
最新文章