深度学习之生成唐诗案例(Pytorch版)

主要思路:

对于唐诗生成来说,我们定义一个"S" 和 "E"作为开始和结束。

 示例的唐诗大概有40000多首,

首先数据预处理,将唐诗加载到内存,生成对应的word2idx、idx2word、以及唐诗按顺序的字序列。

Dataset_Dataloader.py
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


def deal_tangshi():
    with open("poems.txt", "r", encoding="utf-8") as fr:
        lines = fr.read().strip().split("\n")

    tangshis = []
    for line in lines:
        splits = line.split(":")
        if len(splits) != 2:
            continue
        tangshis.append("S" + splits[1] + "E")

    word2idx = {"S": 0, "E": 1}
    word2idx_count = 2

    tangshi_ids = []

    for tangshi in tangshis:
        for word in tangshi:
            if word not in word2idx:
                word2idx[word] = word2idx_count
                word2idx_count += 1

    idx2word = {idx: w for w, idx in word2idx.items()}

    for tangshi in tangshis:
        tangshi_ids.extend([word2idx[w] for w in tangshi])

    return word2idx, idx2word, tangshis, word2idx_count, tangshi_ids


word2idx, idx2word, tangshis, word2idx_count, tangshi_ids = deal_tangshi()


class TangShiDataset(Dataset):
    def __init__(self, tangshi_ids, num_chars):
        # 语料数据
        self.tangshi_ids = tangshi_ids
        # 语料长度
        self.num_chars = num_chars
        # 词的数量
        self.word_count = len(self.tangshi_ids)
        # 句子数量
        self.number = self.word_count // self.num_chars

    def __len__(self):
        return self.number

    def __getitem__(self, idx):
        # 修正索引值到: [0, self.word_count - 1]
        start = min(max(idx, 0), self.word_count - self.num_chars - 2)

        x = self.tangshi_ids[start: start + self.num_chars]
        y = self.tangshi_ids[start + 1: start + 1 + self.num_chars]

        return torch.tensor(x), torch.tensor(y)


def __test_Dataset():
    dataset = TangShiDataset(tangshi_ids, 8)
    x, y = dataset[0]

    print(x, y)


if __name__ == '__main__':
    # deal_tangshi()
    __test_Dataset()
TangShiModel.py:唐诗的模型
import torch
import torch.nn as nn
from Dataset_Dataloader import *
import torch.nn.functional as F


class TangShiRNN(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # 初始化词嵌入层
        self.ebd = nn.Embedding(vocab_size, 128)
        # 循环网络层
        self.rnn = nn.RNN(128, 128, 1)
        # 输出层
        self.out = nn.Linear(128, vocab_size)

    def forward(self, inputs, hidden):

        embed = self.ebd(inputs)

        # 正则化层
        embed = F.dropout(embed, p=0.2)

        output, hidden = self.rnn(embed.transpose(0, 1), hidden)

        # 正则化层
        embed = F.dropout(output, p=0.2)

        output = self.out(output.squeeze())

        return output, hidden

    def init_hidden(self):
        return torch.zeros(1, 64, 128)

 main.py:

import time

import torch

from Dataset_Dataloader import *
from TangShiModel import *
import torch.optim as optim
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train():
    dataset = TangShiDataset(tangshi_ids, 128)
    epochs = 100
    model = TangShiRNN(word2idx_count).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    for idx in range(epochs):
        dataloader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)
        start_time = time.time()
        total_loss = 0
        total_num = 0
        total_correct = 0
        total_correct_num = 0
        hidden = model.init_hidden()

        for x, y in tqdm(dataloader):
            x = x.to(device)
            y = y.to(device)
            # 隐藏状态
            hidden = model.init_hidden()
            hidden = hidden.to(device)
            # 模型计算
            output, hidden = model(x, hidden)
            # print(output.shape)
            # print(y.shape)
            # 计算损失
            loss = criterion(output.permute(1, 2, 0), y)
            # 梯度清零
            optimizer.zero_grad()
            # 反向传播
            loss.backward()
            # 参数更新
            optimizer.step()

            total_loss += loss.sum().item()
            total_num += len(y)
            total_correct_num += y.shape[0] * y.shape[1]
            # print(output.shape)
            total_correct += (torch.argmax(output.permute(1, 0, 2), dim=-1) == y).sum().item()

        print("epoch : %d average_loss : %.3f average_correct : %.3f use_time : %ds" %
              (idx + 1, total_loss / total_num, total_correct / total_correct_num, time.time() - start_time))

        torch.save(model.state_dict(), f"./modules/tangshi_module_{idx + 1}.bin")


if __name__ == '__main__':
    train()

predict.py:

import torch
import torch.nn as nn
from Dataset_Dataloader import *
from TangShiModel import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def predict():
    model = TangShiRNN(word2idx_count)
    model.load_state_dict(torch.load("./modules/tangshi_module_100.bin", map_location=torch.device('cpu')))

    model.eval()

    hidden = torch.zeros(1, 1, 128)

    start_word = input("输入第一个字:")

    flag = None

    tangshi_strs = []

    while True:
        if not flag:
            outputs, hidden = model(torch.tensor([[word2idx["S"]]], dtype=torch.long), hidden)
            tangshi_strs.append("S")
            flag = True
        else:
            tangshi_strs.append(start_word)
            outputs, hidden = model(torch.tensor([[word2idx[start_word]]], dtype=torch.long), hidden)
            top_i = torch.argmax(outputs, dim=-1)

            if top_i.item() == word2idx["E"]:
                break

            print(top_i)

            start_word = idx2word[top_i.item()]
        print(tangshi_strs)


if __name__ == '__main__':
    predict()

完整代码如下:

https://github.com/STZZ-1992/tangshi-generator.giticon-default.png?t=N7T8https://github.com/STZZ-1992/tangshi-generator.git

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/170774.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【HarmonyOS】低代码平台组件拖拽使用技巧之常用基础组件(上)

【关键字】 HarmonyOS、低代码平台、组件拖拽、常用基础组件、基础容器 1、写在前面 之前是花了一些时间介绍了在低代码平台中滚动容器、网格布局、页签容器、列表这几种容器的拖拽技巧及使用方法,今天我会继续来介绍咱们在应用开发中可能会经常用到的一些基础容器…

捷报连连!怿星科技荣获北京市科学技术进步奖一等奖

近期,北京市科学技术委员会、中关村科技园区管理委员会揭晓了2022年北京市科学技术奖的获奖名单。其中,由清华大学牵头、怿星科技参与开发的《电动汽车底盘运动控制与能量管理关键技术及应用》项目荣获“北京市科学技术进步奖一等奖”。 作为北京市政府设…

【开源】基于Vue.js的车险自助理赔系统的设计和实现

项目编号: S 018 ,文末获取源码。 \color{red}{项目编号:S018,文末获取源码。} 项目编号:S018,文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 数据中心模块2.2 角色管理模块2.3 车…

“玄学+社交+AI”最全解题思路,融云 AI 对话方案全力支持

“东北 I 人异于常人”成了 MBTI 最新热梗。互联网 Meme 在放过了“为 I 做 E”后,开始对 MBTI 做更精细的划分了。关注【融云全球互联网通信云】了解更多 一切皆可玄学,今年爆火的还有香灰琉璃和十八籽手串,作为年轻人“在上进与上班中选择了…

java系列之 页面打印出 [object Object],[object Object]

我 | 在这里 🕵️ 读书 | 长沙 ⭐软件工程 ⭐ 本科 🏠 工作 | 广州 ⭐ Java 全栈开发(软件工程师) 🎃 爱好 | 研究技术、旅游、阅读、运动、喜欢流行歌曲 🏷️ 标签 | 男 自律狂人 目标明确 责任心强 ✈️公…

深入探索 PaddlePaddle 中的计算图

**引言** 计算图是深度学习平台 PaddlePaddle 的核心组件之一,它提供了一种图形化的方式来表示和执行深度学习模型。通过了解和理解 PaddlePaddle 中的计算图,我们可以更好地理解深度学习的工作原理,并且能够更加灵活和高效地构建和训练复杂…

西米支付”:在游戏SDK中,提供了哪些支付渠道?SDK的用处?

在游戏SDK中,提供了哪些支付渠道? 常见的支付方式包括支付宝、微信支付、银联支付等。游戏SDK的支付功能可以方便玩家选择不同的支付渠道,以满足他们个性化的支付需求。 流行的支付应用:该应用集成了流行的支付应用支付接口&#…

如何解决requests库自动确定认证arded 类型

requests 库是一种非常强大的爬虫工具,可以用于快速构建高效和稳定的网络爬虫程序。对于经常使用爬虫IP用来网站爬虫反爬策略的我来说,下面遇到的问题应当值得我们思考一番。 问题背景 在使用requests库进行网络请求时,有时会遇到需要对目标服务进行认证…

基于Java封装继承多态实现的一个简单图书系统

首先我们大概了解下图书系统的需求 1.要有两种身份 管理员和普通用户。普通用户和管理员分别对应的功能不一样,需要分开实现 2. 图书系统肯定要有图书,和存放图书的地方,存放就用数组来实现 3.实现对应用户的功能 接下来我们第一步&#xf…

数字化转型背景下,企业如何做好知识管理?

在当今数字化转型的时代,企业面临着日益复杂和快速变化的商业环境。知识管理成为了企业成功的关键之一。有效地管理和利用知识资源可以提升企业的创新能力、决策质量和竞争力。以下我列了一些关键的点,讲讲在数字化转型背景下,企业如何可以做…

Qt程序打包成.exe可执行文件

1.使用Release进行编译 2.找到编译成功的地址: 找到对应的目录 3.把SerialTool.exe文件单独复制到一个文件夹,这里我直接在桌面创建一个SerialTool文件夹,这时候直接运行是不行的,我们需要把库都导进去 4. 在安装目录找到如下这个文件,点击打开,找到你电脑对应的版本即可,我这…

印刷企业实施MES管理系统需要哪些硬件设施

随着科技的飞速发展,印刷行业正面临着前所未有的挑战和机遇。为了提高生产效率,降低成本,并增强市场竞争力,越来越多的印刷企业开始实施制造执行系统(MES)管理系统。本文将重点讨论印刷企业在实施MES管理系…

Java 多线程进阶

1 方法执行与进程执行 GetMapping("/demo1")public void demo1(){//方法调用new ThreadTest1("run1").run();//线程调用new ThreadTest1("run2").start();} 下断点调试信息,可以看到run()方法当前线程是“main1” 继续运行到run里面&…

软件测试/测试开发/人工智能丨​Python运算符解析,小白也能轻松get

什么是运算符 运算符是用于进行各种运算操作的符号或关键词。 在数学和计算机编程中,运算符被用来表示不同的运算操作,例如加法、减法、乘法、除法等。 比如: 4 5,其中,4和5为操作数,为运算符。a 10,…

HandBrake :MacOS专业视频转码工具

handbrake 俗称大菠萝,是一款免费开源的视频转换、压缩软件,它几乎支持目前市面上所能见到的所有视频格式,并且支持电脑硬件压缩,是一款不可多得的优秀软件 优点 ∙Windows, Linux, Mac 三平台支持 ∙开源、免费、无广告 ∙支…

uni-app - 弹出框

目录 1.基本介绍 2.原生uinapp 通过uni.showActionSheet实现 3.使用组件 Popup 弹出层 ③效果展示 1.基本介绍 弹出框让我们在需要时在屏幕底部弹出一个菜单,它通常用于在各种应用程序中进行选择操作。Uniapp为我们提供了基本的底部弹出框组件,但它也有…

全国见!飞桨星河社区五周年,邀你共赴大模型盛宴!

自2018年对外发布以来,飞桨星河社区已汇集660万AI开发者。感谢大家一路见证了飞桨星河社区的成长, 也很荣幸飞桨星河社区陪伴了大家的AI开发旅程。 在这个大模型时代, 飞桨星河社区期待可以帮助开发者们实现自我价值, 获得更多成长…

[JDK工具-3] javac编译器生成class文件 java执行器运行class文件

位置:jdk\bin 语法:javac 源文件 -d class文件输出路径 -encoding utf-8 javac HelloWorld.java -d D:\project1\java8\java8\xin-javademo\src\main\java\com\xin\demo\hutooldemo\ -encoding utf-8 语法:java 类文件完全限定名(…

CRM系统的销售预测是什么?怎么做?

简单来说,销售预测可以通过销售关键信息为团队预测收入,分配目标。CRM中的销售预测可以帮助企业制定合理的销售目标和策略,并通过实时数据发现瓶颈所在,提高团队绩效。下面说说CRM中销售预测是什么?如何销售预测&#…

数据中心走向绿色低碳,液冷存储舍我其谁

引言:没有最冷,只有更冷,绿色低碳早已成为行业关键词。 【全球存储观察 | 科技热点关注】 每一次存储行业的创新,其根源离不开行业端的用户需求驱动。 近些年从数据中心建设的整体发展情况来看,从风冷到…
最新文章