Embedding例子:简单NN网络、迁移学习例子

一、简单例子:构造简单NN网络生成Embedding

1、pytorch例子

2、tensorflow例子

# 1导入模块
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding
import numpy as np

# 2构建语料库
corpus=[
  ["The", "weather", "will", "be", "nice", "tomorrow"],
  ["How", "are", "you", "doing", "today"],
  ["Hello", "world", "!"]
]

# 3生成字典
#获取语料不同单词,并过滤掉一些字符如"!"
word_set=set([i for item in corpus for i in item if i!='!']) 
word_dicts={}

#索引从1开始,0用来填充
j=1
for i in word_set:
    word_dicts[i]=j 
    j=j+1
   
# 4用索引表示语料
raw_inputs=[]
for i in range(len(corpus)):
    raw_inputs.append([word_dicts[j]  for j in corpus[i] if j!="!"])
    padded_inputs = tf.keras.preprocessing.sequence.pad_sequences(raw_inputs,padding='post')
    print(padded_inputs)

# 5构建网络
model = Sequential()
model.add(Embedding(20, 4, input_length=6,mask_zero=True))
model.compile('rmsprop', 'mse')
output_array = model.predict(padded_inputs)
output_array.shape
# 6 查看结果
output_array[1]

输出结果:

二、迁移学习: 使用预训练模型生成Embedding

1、使用Glove预训练数据集迁移学习

import os

imdb_dir = './aclImdb' # 电影评论数据集
train_dir = os.path.join(imdb_dir, 'train')

labels = []
texts = []

for label_type in ['neg', 'pos']:
    dir_name = os.path.join(train_dir, label_type)
    for fname in os.listdir(dir_name):
        if fname[-4:] == '.txt':
            f = open(os.path.join(dir_name, fname))
            texts.append(f.read())
            f.close()
            if label_type == 'neg':
                labels.append(0)
            else:
                labels.append(1)

from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import numpy as np

maxlen = 100  # 只保留前100单词的评论
training_samples = 200  # 在200个样本上训练
validation_samples = 10000  # W对10000个样品进行验证
max_words = 10000  # 只考虑数据集中最常见的10000 个单词

tokenizer = Tokenizer(num_words=max_words)
tokenizer.fit_on_texts(texts)
sequences = tokenizer.texts_to_sequences(texts)

word_index = tokenizer.word_index
print('Found %s unique tokens.' % len(word_index))

data = pad_sequences(sequences, maxlen=maxlen)

labels = np.asarray(labels)
print('Shape of data tensor:', data.shape)
print('Shape of label tensor:', labels.shape)

# 将数据划分为训练集和验证集
# 首先打乱数据, 因一开始数据集是排序好的
# 负面评论在前, 正面评论在后
indices = np.arange(data.shape[0])
np.random.shuffle(indices)
data = data[indices]
labels = labels[indices]

x_train = data[:training_samples]
y_train = labels[:training_samples]
x_val = data[training_samples: training_samples + validation_samples]
y_val = labels[training_samples: training_samples + validation_samples]

glove_dir = './glove.6B/'

embeddings_index = {}
f = open(os.path.join(glove_dir, 'glove.6B.100d.txt'))
for line in f:
    values = line.split()
    word = values[0]
    coefs = np.asarray(values[1:], dtype='float32')
    embeddings_index[word] = coefs
f.close()

print('Found %s word vectors.' % len(embeddings_index))

for key,value in embeddings_index.items():
    print(key,value)
    break

embedding_dim = 100

embedding_matrix = np.zeros((max_words, embedding_dim))
for word, i in word_index.items():
    embedding_vector = embeddings_index.get(word)
    if i < max_words:
        if embedding_vector is not None:
            # 在嵌入索引(embedding index)找不到的词,其嵌入向量都设为0
            embedding_matrix[i] = embedding_vector
            
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, Flatten, Dense

model = Sequential()
model.add(Embedding(max_words, embedding_dim, input_length=maxlen))
model.add(Flatten())
model.add(Dense(32, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
model.summary()

model.layers[0].set_weights([embedding_matrix])
model.layers[0].trainable = False

model.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['acc'])
history = model.fit(x_train, y_train,
                    epochs=10,
                    batch_size=32, 
                    validation_data=(x_val, y_val))
model.save_weights('pre_trained_glove_model.h5')
import matplotlib.pyplot as plt
%matplotlib inline

acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(1, len(acc) + 1)

plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()

输出结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/556608.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

前端框架深度技术革新历程:从原生DOM操作到数据双向绑定、虚拟DOM等框架原理深度解析,Web开发与用户体验的共赢

前端的发展与前端框架的发展相辅相成&#xff0c;形成了相互驱动、共同演进的关系。前端技术的进步不仅催生了前端框架的产生&#xff0c;也为其发展提供了源源不断的动力。 前端的发展 前端&#xff0c;即Web前端&#xff0c;是指在创建Web应用程序或网站过程中负责用户界面…

爱普生无源晶体MC-146特点,应用介绍

爱普生的MC-146系列产品&#xff0c;应用广泛&#xff0c;如小的通讯社本&#xff0c;工业控制等等&#xff0c;几乎涉及各个领域。属于现阶段性价比非常不错的一个系列。晶体振荡器有很多种类&#xff0c;无源晶体其中最简单的一个类。在每个设计中&#xff0c;要用到非常多的…

LabVIEW供热管道泄漏监测与定位

LabVIEW供热管道泄漏监测与定位 在现代城市的基础设施建设中&#xff0c;供热管道系统起着极其重要的作用。然而&#xff0c;管道泄漏问题不仅导致巨大的经济损失&#xff0c;还对公共安全构成威胁。因此&#xff0c;开发一种高效、准确的管道泄漏监测与定位技术显得尤为关键。…

Mac 部署 GPT-2 预训练模型 gpt2-chinese-cluecorpussmall

文章目录 下载 GPT-2 模型快速开始 GPT-2 下载 GPT-2 模型 https://huggingface.co/uer/gpt2-chinese-cluecorpussmall git clone https://huggingface.co/uer/gpt2-chinese-cluecorpussmall # 或单独下载 LFS GIT_LFS_SKIP_SMUDGE1 git clone https://huggingface.co/uer/gpt…

清洗机什么牌子好质量过硬、四大公认最好用的超声波清洗机

现在十个人中有九个人都是戴眼镜的&#xff0c;眼镜已成为我们生活中不可或缺的一部分。无论是用于视力矫正&#xff0c;还是作为时尚配饰&#xff0c;眼镜都承载着重要的角色。然而&#xff0c;很多人在享受眼镜带来便利的同时&#xff0c;却忽视了对眼镜的适当清洁和维护。殊…

Trivy离线扫描:容器安全实践指南

一、Trivy简介 1.1 Trivy 概述 Trivy 是一款全面多功能的安全扫描器。Trivy具有寻找安全问题和目标的扫描器。现已经被 Github Action、Harbor 等主流工具集成&#xff0c;Trivy支持大多数流行的编程语言、操作系统和平台的扫描&#xff0c;应该是该领域目前目前采用最广的开…

数据可视化插件echarts【前端】

数据可视化插件echarts【前端】 前言版权开源推荐数据可视化插件echarts一、如何使用1.1 下载1.2 找到js文件1.3 入门使用1.4 我的使用 二、前后端交互&#xff1a;入门demo2.1 前端htmljs 2.2 后端entitycontrollerservicemapper 三、前后端交互&#xff1a;动态数据3.1 前端j…

ChatGPT AI 教我用python实现工作久坐定时提醒工具,防猝死!

日常工作学习久坐的危害很大&#xff0c;非常伤害颈椎和腰椎&#xff0c;严重危害上班族的身体健康&#xff0c;强烈建议久坐后间隔一小时活动一下&#xff0c;最好是能够调整好自己坐姿&#xff0c;行为举止一定要正确&#xff0c;为了您的老腰&#xff01; 久坐一族&#xff…

Linux——日志的编写与线程池

目录 前言 一、日志的编写 二、线程池 1.线程池基本原理 2.线程池作用 3.线程池的实现 前言 学了很多线程相关的知识点&#xff0c;线程控制、线程互斥、线程同步&#xff0c;今天我们将他们做一个总结&#xff0c;运用所学知识写一个较为完整的线程池&#xff0c;同时…

mac: docker安装及其Command not found: docker

已经安装了docker desktop&#xff0c;没安装的 点击安装 傻瓜式安装即可 接着打开终端&#xff1a;好一个 Comand not found:docker 看我不把你整顿&#xff0c;解决如下&#xff1a; 如果你在 macOS 上安装了 Docker Desktop&#xff0c;但是终端无法识别 docker 命令&…

目标检测——多模态人体动作数据集

一、重要性及意义 连续多模态人体动作识别检测的重要性及意义主要体现在以下几个方面&#xff1a; 首先&#xff0c;它极大地提升了人体动作识别的准确性和稳定性。由于人体动作具有复杂性和多样性&#xff0c;单一模态的数据往往难以全面、准确地描述动作的特征。而连续多模…

深度学习数据处理——对比标签文件与图像文件,把没有打标签的图像文件标记并删除

要对比目录下的jpg文件与json文件&#xff0c;并删除那些没有对应json文件的jpg文件&#xff0c;这个在深度学习或者机器学习时常会遇到。比如对一个数据集做处理时&#xff0c;往往会有些图像不用标注&#xff0c;那么这张图像是没有对应的标签文件的&#xff0c;这个时候又不…

python-django企业设备配件检修系统flask+vue

本课题使用Python语言进行开发。代码层面的操作主要在PyCharm中进行&#xff0c;将系统所使用到的表以及数据存储到MySQL数据库中&#xff0c;方便对数据进行操作本课题基于WEB的开发平台&#xff0c;设计的基本思路是&#xff1a; 前端&#xff1a;vue.jselementui 框架&#…

玄子Share-LVM与磁盘配额

玄子Share-LVM与磁盘配额 LVM概述 Logical Volume Manager&#xff0c;逻辑卷管理 Linux系统中对磁盘分区进行管理的一种逻辑机制&#xff0c;是建立在硬盘和分区之上的一个逻辑层动态调整磁盘容量&#xff0c;从而提高磁盘管理的灵活性 /boot分区用于存放引导文件&#xff…

服务器中查看CPU/GPU使用情况的常用命令

1、查看显卡 nvidia-smi2、间隔查看GPU使用情况 间隔5s刷新信息 watch -n 5 nvidia-smiCtrlC退出 参考博文&#xff1a;https://mbd.baidu.com/ug_share/mbox/4a83aa9e65/share?productsmartapp&tk6ff15196d305c4dd3daab94b4abb81a4&share_urlhttps%3A%2F%2Fyebd1h…

JavaSE备忘录(未完)

文章目录 基本数据类型println 小知识除法( / ) 和 Infinity(无穷) 小知识除法InfinityInfinity 在除法中正负判断 求余(%) 小知识 基本数据类型 除 int、char 的包装类分别为 Integer、Character 外&#xff0c;其余基本数据类型的第一个字母大写就是它的包装类。 println 小…

vscode自动生成返回值的快捷键

vscode中类似idea的altenter功能&#xff0c;可以添加返回值 idea中是Introduce local variable&#xff0c; vscode中按下command.(句号) 然后选extract to local variable或者 Assign statement to new local variable都行&#xff0c; 光标在分号前如图&#xff1a; 光标在…

Redis快速入门操作

启动Redis 进入命令行客户端 字符串命令常用操作&#xff08;redis默认使用字符串来存储数据&#xff09; 列表&#xff08;Lists&#xff09;常用操作 集合&#xff08;Sets&#xff09;常用操作 &#xff08;无序集合且元素不可重复&#xff09; 有序集合&#xff08;So…

云原生虚拟数仓 PieCloudDB Database 4月更新盘点

第一部分 PieCloudDB Database 最新动态 增强本地缓存文件生命周期管理 PieCloudDB 在最新版本中增强了本地缓存文件生命周期管理&#xff0c;执行器节点重启之后可以继续使用之前缓存在本地的数据文件&#xff0c;从而节约重新从远端下载数据文件的带宽资源&#xff0c;提升…

DFS之剪枝(上交考研题目--正方形数组的数目)

题目 给定一个非负整数数组 A A A&#xff0c;如果该数组每对相邻元素之和是一个完全平方数&#xff0c;则称这一数组为正方形数组。 返回 A A A 的正方形排列的数目。 两个排列 A 1 A1 A1 和 A 2 A2 A2 不同的充要条件是存在某个索引 i i i&#xff0c;使得 A 1 [ i …