NLP(4)--实现一个线性层

前言

仅记录学习过程,有问题欢迎讨论

感觉全连接层就像一个中间层转换数据的形态的,或者说预处理数据?

代码

里面有两个部分,一部分是自己实现的,一部分是利用模块的方法实现的。

import torch
import torch.nn as nn
import numpy as np

"""
numpy手动实现模拟一个线性层
"""


# 搭建一个2层的神经网络模型
# 每层都是线性层
# 继承 nn.Module
class TorchModel(nn.Module):
    def __init__(self, input_size, hidden_size1, hidden_size2):
        super(TorchModel, self).__init__()
        self.layer1 = nn.Linear(input_size, hidden_size1) # 5*3
        self.layer2 = nn.Linear(hidden_size1, hidden_size2)

    def forward(self, x):
        x = self.layer1(x)
        # 第二层输出就是y
        y_pred = self.layer2(x)
        return y_pred


# 自定义模型
class DiyModel:
    def __init__(self, w1, b1, w2, b2):
        self.w1 = w1
        self.b1 = b1
        self.w2 = w2
        self.b2 = b2

    def forward(self, x):
        # 点 * w1.T 是转置 2*3 * 3*5 === 2*5
        hidden = np.dot(x, self.w1.T) + self.b1  # 1*5
        # 2*5 * 5*2 ===  2* 2
        y_pred = np.dot(hidden, self.w2.T) + self.b2  # 1*2
        return y_pred


# 随便准备一个网络输入 2*3
x = np.array([[3.1, 1.3, 1.2],
              [2.1, 1.3, 13]])
# 建立torch 参数是隐藏层的维度
torch_model = TorchModel(3, 5, 2)
# 字典 包含了w b 因为隐藏层的函数为: y = w*x + b
print(torch_model.state_dict())
print("-----------")
#打印模型权重,权重为随机初始化
torch_model_w1 = torch_model.state_dict()["layer1.weight"].numpy()
torch_model_b1 = torch_model.state_dict()["layer1.bias"].numpy()
torch_model_w2 = torch_model.state_dict()["layer2.weight"].numpy()
torch_model_b2 = torch_model.state_dict()["layer2.bias"].numpy()
print(torch_model_w1, "torch w1 权重")
print(torch_model_b1, "torch b1 权重")
print("-----------")
print(torch_model_w2, "torch w2 权重")
print(torch_model_b2, "torch b2 权重")
print("-----------")

# 预测
torch_x = torch.FloatTensor(x)
y_pred = torch_model.forward(torch_x)
# 2*2 的矩阵
print("torch模型预测结果:", y_pred)


# #把torch模型权重拿过来自己实现计算过程
diy_model = DiyModel(torch_model_w1, torch_model_b1, torch_model_w2, torch_model_b2)
# #用自己的模型来预测
y_pred_diy = diy_model.forward(np.array(x))
print("diy模型预测结果:", y_pred_diy)




本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/558634.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

ELK日志采集系统

1.什么是ELK ELK 是一套流行的数据搜索、分析和可视化解决方案,由三个开源项目组成,每个项目的首字母合起来形成了“ELK”这一术语: Elasticsearch (ES): Elasticsearch 是一个基于 Apache Lucene 构建的分布式、实时搜索与分析引擎。它能够…

小程序AI智能名片S2B2C商城系统:做内容、造IP、玩社群打造私域流量的新营销秘籍

在数字化浪潮汹涌的新时代,小程序AI智能名片S2B2C商城系统正以其独特的魅力,引领着营销领域的新变革。这套系统不仅将人工智能与小程序技术完美结合,更通过创新的S2B2C模式,为企业打开了一扇通往成功的大门。 面对激烈的市场竞争&…

Jenkins 的构建时执行时间问题

我们希望我的项目能够在特定的时间自动执行,我们需要设定一个定时任务。 Jenkins 的定时任务是通过 Cron 任务来实现的,但是由有点不一样。 H/2 * * * * 比如说上面的设置就是每 2 分钟执行一次。 希望每分钟执行一次 Jenkins 的每分钟执行一次的设置…

c++头文件string函数的用法

目录 前言: 字符串截取 字符串插入与替换 字符串区间删除 字符串排序与相加和查找 如后续需文字描述,,请评论区告诉我,我看到后会进行添加一些文字描述。 前言: 因本人女朋友在学习c过程中在一些知识网页上学了st…

uni-app中页面生命周期与vue生命周期的执行顺序对比

应用生命周期 uni-app 支持如下应用生命周期函数: 函数名说明平台兼容onLaunch当uni-app 初始化完成时触发(全局只触发一次),参数为应用启动参数,同 uni.getLaunchOptionsSync 的返回值onShow当 uni-app 启动&#x…

09 MySQL--操作真题

1. 用一条 SQL 语句&#xff0c;查询出每门课程都大于 80 分的人。 分析&#xff1a; 去重查询出存在课程小于 80 分的人&#xff0c;设为集合A查询不在集合 A 中的人 # 第一步&#xff1a;找小于等于80分的学员姓名 select distinct name from t_student where fenshu <…

跨境电商指南:防关联浏览器和云主机有什么区别?

跨境电商的卖家分为独立站卖家和平台卖家。前者会自己开设独立站点&#xff0c;比如通过 shopify&#xff1b;后者则是入驻亚马逊或 Tiktok 等平台&#xff0c;开设商铺。其中平台卖家为了扩大收益&#xff0c;往往不止开一个店铺&#xff0c;或者有店铺代运营的供应商&#xf…

记一次中间件宕机以后持续请求导致应用OOM的排查思路(server.max-http-header-size属性配置不当的严重后果)

一、背景 最近有一次在系统并发比较高的时候&#xff0c;数据库突然发生了故障&#xff0c;导致大量请求失败&#xff0c;在数据库宕机不久&#xff0c;通过应用日志可以看到系统发生了OOM。 二、排查 初次看到这个现象的时候&#xff0c;我还是有点懵逼的&#xff0c;数据库…

解决方案ImportError: cannot import name ‘BertTokenizerFast‘ from ‘transformers‘

文章目录 一、现象二、解决方案 一、现象 从transformers 库调用该包的时候 from transformers import BertTokenizer, AdamW, BertTokenizerFast报错显示 ImportError: cannot import name ‘BertTokenizerFast’ from ‘transformers’ 二、解决方案 追溯查看transforme…

人工智能论文GPT-3(1):2020.5 Language Models are Few-Shot Learners;摘要;引言;scaling-law

摘要 近期的工作表明&#xff0c;在大量文本语料库上进行预训练&#xff0c;然后针对特定任务进行微调&#xff0c;可以在许多NLP任务和基准测试中取得实质性进展。虽然这种方法在架构上通常是与任务无关的&#xff0c;但仍然需要包含数千或数万示例的针对特定任务的微调数据集…

【解决】Caused by: javax.net.ssl.SSLHandshakeException: PKIX path building failed

问题原因&#xff1a; 在Java8及高版本以上的版本在源应用程序不信任目标应用程序的证书&#xff0c;因为在源应用程序的JVM信任库中找不到该证书或证书链。也就是目标站点启用了HTTPS 而缺少安全证书时出现的异常 解决方案&#xff1a; 我使用的是忽略证书验证 public clas…

面试算法-173-二叉树的直径

题目 给你一棵二叉树的根节点&#xff0c;返回该树的 直径 。 二叉树的 直径 是指树中任意两个节点之间最长路径的 长度 。这条路径可能经过也可能不经过根节点 root 。 两节点之间路径的 长度 由它们之间边数表示。 示例 1&#xff1a; 输入&#xff1a;root [1,2,3,4,…

水电预付费系统多少钱?

一、水电预付费系统的定义与优势 水电预付费系统是一种现代化的管理方式&#xff0c;它颠覆了传统的后付费模式&#xff0c;让用户在使用水电前先进行支付。这种系统通常包括智能电表、充值终端、后台管理系统等组成部分&#xff0c;通过自动化处理&#xff0c;实现费用的预先…

MATLAB实现蚁群算法优化柔性车间调度(ACO-fjsp)

蚁群算法优化车间调度的步骤可以分为以下几个主要阶段&#xff1a; 1.初始化阶段&#xff1a; 设置算法参数&#xff0c;如信息素浓度、启发式因子等。这些参数将影响蚂蚁在选择路径时的决策过程。 确定车间调度的具体问题规模&#xff0c;包括工件数量、机器数量以及每个工件…

通过Docker新建并使用MySQL数据库

1. 安装Docker 确保您的系统上已经安装了Docker。可以通过以下命令检查Docker是否安装并运行&#xff1a; systemctl status docker如果没有安装或运行&#xff0c;请按照官方文档进行安装和启动。 2. 拉取MySQL镜像 从Docker Hub拉取MySQL官方镜像。这里以MySQL 5.7版本为…

【数学】深度学习中的概率基础知识记录

基于 Deep Learning (2017, MIT) 书总结了必要的概率知识 原blog 以及用到的Ipython notebook 文章目录 1 概述2 知识2.1 离散变量和概率质量函数&#xff08;PMF&#xff09;2.2 连续变量和概率密度函数&#xff08;PDF&#xff09;2.3 边缘概率2.4 条件概率2.5 条件概率的链式…

Qt gsl库配置踩坑记录

想求解非线性方程组&#xff0c;之前使用拟牛顿法写过相关的matlab代码&#xff0c;这次想移植到C代码&#xff0c;网上说gsl库挺好用的&#xff0c;于是我也想试一下。相关参考&#xff1a; 【C】GSL(GNU Scientific Library) 的安装及在 Visual Studio 2017 中的使用 QT5使用…

k8s部署Eureka集群

部署有状态负载 镜像配置&#xff1a; 环境变量如下&#xff1a; AUTHENTICATE_ENABLEtrue JAVA_OPTS-Dauth.userName账号 -Dauth.password密码 MY_POD_NAMEmetadata.name BOOL_REGISTERtrue BOOL_FETCHtrue APPLICATION_NAME负载名称 EUREKA_INSTANCE_HOSTNAME${MY_POD_NA…

单臂路由实验

单臂路由是一种在单个物理接口上配置多个逻辑接口&#xff0c;以实现不同VLAN间通信的技术。它通过在路由器接口上划分子接口&#xff0c;每个子接口对应一个VLAN网段&#xff0c;从而实现了VLAN间的互联互通。单臂路由能够重新封装MAC地址&#xff0c;转换VLAN标签&#xff0c…

1.微服务介绍

完整的微服务架构图 注册中心 配置中心 服务集群 服务网关 分布式缓存 分布式搜索 数据库集群 消息队列 分布式日志服务 系统监控链路追踪 Jenkins docker k8s 技术栈 微服务治理&#xff1a; 注册发现、远程调用、负载均衡、配置管理、网关路由、系统保护、流量…
最新文章