1. 首先看一下代码
以下是官方源码(出处:PyG 文档中 torch_geometric.nn.conv.transformer_conv 模块的源码页面,pytorch_geometric documentation):
import math
import typing
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import (
    Adj,
    NoneType,
    OptTensor,
    PairTensor,
    SparseTensor,
)
from torch_geometric.utils import softmax

if typing.TYPE_CHECKING:
    from typing import overload
else:
    from torch.jit import _overload_method as overload
class TransformerConv(MessagePassing):
    r"""The graph transformer operator from the `"Masked Label Prediction:
    Unified Message Passing Model for Semi-Supervised Classification"
    <https://arxiv.org/abs/2009.03509>`_ paper.

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +
        \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_{j},

    where the attention coefficients :math:`\alpha_{i,j}` are computed via
    multi-head dot product attention:

    .. math::
        \alpha_{i,j} = \textrm{softmax} \left(
        \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} (\mathbf{W}_4\mathbf{x}_j)}
        {\sqrt{d}} \right)

    with :math:`d` equal to :obj:`out_channels` (the per-head width). The
    softmax is normalized over the incoming edges of each target node
    :math:`i`, separately for every head.

    Args:
        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
            derive the size from the first input(s) to the forward method.
            A tuple corresponds to the sizes of source and target
            dimensionalities: source features feed the keys and values,
            target features feed the queries and the root (skip) projection.
        out_channels (int): Size of each output sample per attention head.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        concat (bool, optional): If set to :obj:`False`, the multi-head
            attentions are averaged instead of concatenated, so the output
            has :obj:`out_channels` instead of :obj:`heads * out_channels`
            features. (default: :obj:`True`)
        beta (bool, optional): If set, will combine aggregation and
            skip information via

            .. math::
                \mathbf{x}^{\prime}_i = \beta_i \mathbf{W}_1 \mathbf{x}_i +
                (1 - \beta_i) \underbrace{\left(\sum_{j \in \mathcal{N}(i)}
                \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_j \right)}_{=\mathbf{m}_i}

            with :math:`\beta_i = \textrm{sigmoid}(\mathbf{w}_5^{\top}
            [ \mathbf{m}_i, \mathbf{W}_1 \mathbf{x}_i, \mathbf{m}_i -
            \mathbf{W}_1 \mathbf{x}_i ])`. Has no effect when
            :obj:`root_weight` is :obj:`False`. (default: :obj:`False`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a stochastically
            sampled neighborhood during training. (default: :obj:`0`)
        edge_dim (int, optional): Edge feature dimensionality (in case
            there are any). Edge features are added to the keys after
            linear transformation, that is, prior to computing the
            attention dot product. They are also added to final values
            after the same linear transformation. The model is:

            .. math::
                \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +
                \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \left(
                \mathbf{W}_2 \mathbf{x}_{j} + \mathbf{W}_6 \mathbf{e}_{ij}
                \right),

            where the attention coefficients :math:`\alpha_{i,j}` are now
            computed via:

            .. math::
                \alpha_{i,j} = \textrm{softmax} \left(
                \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top}
                (\mathbf{W}_4\mathbf{x}_j + \mathbf{W}_6 \mathbf{e}_{ij})}
                {\sqrt{d}} \right)

            If :obj:`None`, any :obj:`edge_attr` passed to :meth:`forward`
            is ignored. (default :obj:`None`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add the transformed root node features to the output and the
            option :attr:`beta` is set to :obj:`False`. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    # Attention coefficients written by :meth:`message` during propagation
    # so that :meth:`forward` can optionally return them; cleared after use.
    _alpha: OptTensor

    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        heads: int = 1,
        concat: bool = True,
        beta: bool = False,
        dropout: float = 0.,
        edge_dim: Optional[int] = None,
        bias: bool = True,
        root_weight: bool = True,
        **kwargs,
    ):
        kwargs.setdefault('aggr', 'add')
        # node_dim=0: per-node tensors are laid out as [num_nodes, heads, C],
        # so message passing gathers/scatters along dimension 0.
        super().__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        # The gated combination needs a root term to gate against.
        self.beta = beta and root_weight
        self.root_weight = root_weight
        self.concat = concat
        self.dropout = dropout
        self.edge_dim = edge_dim
        self._alpha = None

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        # Keys/values are computed from source features (x[0]), queries from
        # target features (x[1]); each projects to all heads at once.
        self.lin_key = Linear(in_channels[0], heads * out_channels)
        self.lin_query = Linear(in_channels[1], heads * out_channels)
        self.lin_value = Linear(in_channels[0], heads * out_channels)
        if edge_dim is not None:
            self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False)
        else:
            self.lin_edge = self.register_parameter('lin_edge', None)

        # The skip projection must match the width of the aggregated output:
        # heads * out_channels when concatenating, out_channels when averaging.
        skip_channels = heads * out_channels if concat else out_channels
        self.lin_skip = Linear(in_channels[1], skip_channels, bias=bias)
        if self.beta:
            # Gate input is [out, x_r, out - x_r], i.e. three times as wide.
            self.lin_beta = Linear(3 * skip_channels, 1, bias=False)
        else:
            self.lin_beta = self.register_parameter('lin_beta', None)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        super().reset_parameters()
        self.lin_key.reset_parameters()
        self.lin_query.reset_parameters()
        self.lin_value.reset_parameters()
        # Test the submodules themselves rather than the config flags, so the
        # check always agrees with what __init__ actually created.
        if self.lin_edge is not None:
            self.lin_edge.reset_parameters()
        self.lin_skip.reset_parameters()
        if self.lin_beta is not None:
            self.lin_beta.reset_parameters()

    @overload
    def forward(
        self,
        x: Union[Tensor, PairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        return_attention_weights: NoneType = None,
    ) -> Tensor:
        pass

    @overload
    def forward(  # noqa: F811
        self,
        x: Union[Tensor, PairTensor],
        edge_index: Tensor,
        edge_attr: OptTensor = None,
        return_attention_weights: bool = None,
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        pass

    @overload
    def forward(  # noqa: F811
        self,
        x: Union[Tensor, PairTensor],
        edge_index: SparseTensor,
        edge_attr: OptTensor = None,
        return_attention_weights: bool = None,
    ) -> Tuple[Tensor, SparseTensor]:
        pass

    def forward(  # noqa: F811
        self,
        x: Union[Tensor, PairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        return_attention_weights: Optional[bool] = None,
    ) -> Union[
            Tensor,
            Tuple[Tensor, Tuple[Tensor, Tensor]],
            Tuple[Tensor, SparseTensor],
    ]:
        r"""Runs the forward pass of the module.

        Args:
            x (torch.Tensor or (torch.Tensor, torch.Tensor)): The input node
                features. A single tensor is used as both source and target
                features; a pair is interpreted as (source, target).
            edge_index (torch.Tensor or SparseTensor): The edge indices.
            edge_attr (torch.Tensor, optional): The edge features. Only used
                when the layer was built with :obj:`edge_dim`.
                (default: :obj:`None`)
            return_attention_weights (bool, optional): If set to a
                :obj:`bool`, will additionally return the tuple
                :obj:`(edge_index, attention_weights)`, holding the computed
                attention weights for each edge. Note that the check is
                ``isinstance(..., bool)``, so :obj:`False` behaves like
                :obj:`True`; pass :obj:`None` to get only the output.
                (default: :obj:`None`)

        Returns:
            The updated target-node features with
            :obj:`heads * out_channels` columns if :obj:`concat` is set,
            otherwise :obj:`out_channels` columns; optionally paired with
            the attention weights as described above.
        """
        H, C = self.heads, self.out_channels

        # A single tensor means source and target nodes share features.
        if isinstance(x, Tensor):
            x = (x, x)

        # Project every node once, then split the channel axis into heads.
        query = self.lin_query(x[1]).view(-1, H, C)
        key = self.lin_key(x[0]).view(-1, H, C)
        value = self.lin_value(x[0]).view(-1, H, C)

        # propagate_type: (query: Tensor, key:Tensor, value: Tensor,
        # edge_attr: OptTensor)
        out = self.propagate(edge_index, query=query, key=key, value=value,
                             edge_attr=edge_attr)

        # Read and clear the coefficients cached by message() so they are not
        # kept alive (or reused) across calls.
        alpha = self._alpha
        self._alpha = None

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.root_weight:
            x_r = self.lin_skip(x[1])
            if self.lin_beta is not None:
                # Learned per-node gate between the skip and aggregated terms.
                beta = self.lin_beta(torch.cat([out, x_r, out - x_r], dim=-1))
                beta = beta.sigmoid()
                out = beta * x_r + (1 - beta) * out
            else:
                out = out + x_r

        if isinstance(return_attention_weights, bool):
            assert alpha is not None
            if isinstance(edge_index, Tensor):
                return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
        else:
            return out

    def message(self, query_i: Tensor, key_j: Tensor, value_j: Tensor,
                edge_attr: OptTensor, index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
        r"""Computes the attention-weighted message sent along each edge.

        Following PyG's :class:`MessagePassing` naming convention, ``*_i``
        tensors are gathered at each edge's target node and ``*_j`` tensors
        at its source node; here they have shape ``[num_edges, heads, C]``.

        Edge features are only used when the layer was built with
        :obj:`edge_dim`; they are then projected and added to both the keys
        and the values. Otherwise :obj:`edge_attr` is ignored.

        Returns:
            Messages of shape ``[num_edges, heads, out_channels]``.
        """
        # Projected edge embedding; stays None when edge features are unused.
        edge_emb: OptTensor = None
        if self.lin_edge is not None:
            assert edge_attr is not None
            projected = self.lin_edge(edge_attr).view(-1, self.heads,
                                                      self.out_channels)
            key_j = key_j + projected
            edge_emb = projected

        # Scaled dot-product score per edge and head: [num_edges, heads].
        alpha = (query_i * key_j).sum(dim=-1) / math.sqrt(self.out_channels)
        alpha = softmax(alpha, index, ptr, size_i)
        # Cache the pre-dropout weights for return_attention_weights.
        self._alpha = alpha
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        out = value_j
        # Only the *projected* edge features match value_j's [E, H, C] shape;
        # adding raw edge_attr here would mis-broadcast or crash.
        if edge_emb is not None:
            out = out + edge_emb
        out = out * alpha.view(-1, self.heads, 1)
        return out

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, heads={self.heads})')
2. 详细解释
几个重要的参数:
in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.
out_channels (int): Size of each output sample.
heads (int, optional): Number of multi-head-attentions. (default: :obj:`1`)
怎么理解这几个参数?
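in_channels:表示每个输入节点特征的维度。如果设置为整数,则源节点和目标节点的特征维度相同;如果设置为 -1,则维度会在第一次调用 forward 方法时根据输入自动推断(延迟初始化);如果设置为元组 (源节点维度, 目标节点维度),则源节点和目标节点可以有不同的特征维度(例如二部图)。其中源节点特征用于计算键(key)和值(value),目标节点特征用于计算查询(query)和跳跃连接(skip)。
out_channels:表示每个注意力头输出的特征维度,即经过卷积操作后每个头产生的特征向量的维度。当 concat=True(默认)时,最终输出维度为 heads * out_channels;当 concat=False 时,各头的输出取平均,最终输出维度为 out_channels。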
使用 tg.nn.TransformerConv 时,可以通过下面的例子理解 in_channels 和 out_channels。假设我们有一个图数据集,每个节点都有一个 10 维的特征向量。
如果我们想把每个节点的特征向量作为输入,用 tg.nn.TransformerConv 进行卷积操作,那么 in_channels 应设置为 10,表示每个输入节点的特征维度为 10。假设我们希望每个注意力头把节点特征转换为 16 维,那么 out_channels 应设置为 16。在默认的 heads=1 下,经过卷积后每个节点的特征向量就是 16 维;若 heads>1 且 concat=True(默认),输出维度为 heads * 16。
在 tg.nn.TransformerConv 中,heads 参数表示多头注意力的数量。举个例子,如果 heads 设置为 4,那么模型将学习 4 组独立的查询/键/值投影和注意力权重,每个头在各自投影得到的 out_channels 维子空间中计算注意力;然后将这些头的输出拼接(concat=True,默认)或取平均(concat=False),得到最终输出。
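举个整体的例子:
注意,TransformerConv 是图卷积层,它的输入不是形如 (batch_size, seq_length, input_dim) 的序列张量,而是两部分:形状为 (num_nodes, in_channels) 的节点特征矩阵 x,以及形状为 (2, num_edges) 的 edge_index(描述图结构)。调用方式为 conv(x, edge_index),缺少 edge_index 会直接报错。
如果手头的数据是 batch_size 个样本、每个样本有 seq_length 个位置、每个位置有 input_dim 维特征,可以把每个位置当作一个节点:先把所有样本的节点展平成一张大图,形状为 (batch_size * seq_length, input_dim);再为每个样本构造边,不同样本的节点编号要加上偏移量,保证样本之间互不相连。
下面的例子设置 heads=2、out_channels=64,模型将学习两组注意力权重,每组用于计算不同的注意力。由于默认 concat=True,两个头的输出会被拼接,所以每个节点的输出维度是 2 * 64 = 128,而不是 64。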
import torch
import torch_geometric.nn as tg

# TransformerConv is a graph layer: it expects node features of shape
# [num_nodes, in_channels] plus an edge_index of shape [2, num_edges] -- not a
# [batch_size, seq_length, input_dim] sequence tensor. To process a batch of
# sequences, treat every time step as a node and flatten the whole batch into
# one large graph made of disconnected per-sample subgraphs.
batch_size, seq_length, input_dim = 32, 10, 128
# 32 samples, 10 time steps each, 128 features per time step.
x = torch.randn(batch_size, seq_length, input_dim)

# Per-sample connectivity (just an example): a directed ring
# t -> (t + 1) % seq_length.
src = torch.arange(seq_length)
dst = (src + 1) % seq_length
ring = torch.stack([src, dst], dim=0)  # [2, seq_length]

# Replicate the ring for every sample and shift node ids by
# sample_idx * seq_length so different samples never share nodes.
offsets = torch.arange(batch_size).repeat_interleave(ring.size(1)) * seq_length
edge_index = ring.repeat(1, batch_size) + offsets  # [2, batch_size * seq_length]

x_nodes = x.reshape(batch_size * seq_length, input_dim)  # [320, 128]

# heads=2, out_channels=64. With the default concat=True the two heads are
# concatenated, so every node gets 2 * 64 = 128 output features.
conv_layer = tg.TransformerConv(in_channels=input_dim, out_channels=64, heads=2)

# Forward pass: the layer needs both node features and graph connectivity.
output = conv_layer(x_nodes, edge_index)  # [320, 128]
output = output.view(batch_size, seq_length, -1)  # back to [32, 10, 128]
print("输出张量的形状:", output.shape)  # torch.Size([32, 10, 128])
2.1 将特征映射为查询、键和值
在这里,通过线性变换层 Linear,输入特征被分别转换成键(key)、查询(query)和值(value)的表示形式,以便用于多头注意力机制。
具体来说:
self.lin_key 把源节点特征(维度为 in_channels[0])映射为键;self.lin_query 把目标节点特征(维度为 in_channels[1])映射为查询;self.lin_value 把源节点特征(维度为 in_channels[0])映射为值。当 in_channels 是整数时,源节点和目标节点是同一组节点,使用同一个特征矩阵。
具体地,在 PyG 中节点特征矩阵的形状是 (num_nodes, in_channels):没有单独的 batch 维度,一个 batch 中的多张图会被拼成一张大图。其中 num_nodes 是节点数,in_channels 是输入特征的通道数。以键为例,这个线性变换在数学上等价于乘以一个 (in_channels, heads * out_channels) 的矩阵;在 PyTorch 中,weight 参数实际以转置形式 (heads * out_channels, in_channels) 存储。这里 heads 是注意力头的数量,out_channels 是每个头的输出通道数。线性变换的结果形状为 (num_nodes, heads * out_channels),随后 forward 中的 .view(-1, H, C) 把它重塑为 (num_nodes, heads, out_channels)。这样,每个节点的每个头都有一个 out_channels 维的键表示。