Code Lab - 2

pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.10.2+cu102.html
pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.10.2+cu102.html
pip install torch-geometric
pip install ogb

1. PyG Datasets

PyG有两个类,用于存储图以及将图转换为Tensor格式
        torch_geometry.datasets 包含各种常见的图形数据集
        torch_geometric.data 提供Tensor的图数据处理

1.1 从torch_geometric.datasets中读取数据集

# 每一个dataset都是多张图的list,单张图的类型为torch_geometric.data.Data
import torch
import os
from torch_geometric.datasets import TUDataset

root = './enzymes'
name = 'ENZYMES'

# The ENZYMES dataset
pyg_dataset= TUDataset('./enzymes', 'ENZYMES')

# You can find that there are 600 graphs in this dataset
print(pyg_dataset)
print(type(pyg_dataset))

# 对于第一张图
# print(pyg_dataset[0])
# print(pyg_dataset[0].num_nodes)
# print(pyg_dataset[0].edge_index)
# print(pyg_dataset[0].x)
# print(pyg_dataset[0].y)

1.2 ENZYMES数据集中类别数量和特征维度

# num_classes
def get_num_classes(pyg_dataset):
    num_classes = pyg_dataset.num_classes
    return num_classes
# num_features
def get_num_features(pyg_dataset):
    num_features = pyg_dataset.num_node_features
    return num_features

num_classes = get_num_classes(pyg_dataset)
num_features = get_num_features(pyg_dataset)
print("{} dataset has {} classes".format(name, num_classes))
print("{} dataset has {} features".format(name, num_features))

1.3 ENZYMES数据集中第idx张图的label和边的数量

def get_graph_class(pyg_dataset, idx):
    # y就是这张图的类别
    label = pyg_dataset[idx].y.item()
    return label

graph_0 = pyg_dataset[0]
print(graph_0)
idx = 100
label = get_graph_class(pyg_dataset, idx)
print('Graph with index {} has label {}'.format(idx, label))


def get_graph_num_edges(pyg_dataset, idx):
    num_edges = pyg_dataset[idx].num_edges/2 # 无向图
    return num_edges

idx = 200
num_edges = get_graph_num_edges(pyg_dataset, idx)
print('Graph with index {} has {} edges'.format(idx, num_edges))

2. Open Graph Benchmark (OGB)

OGB是基准数据集的集合,使用OGB数据加载器自动下载、处理和拆分,通过OGB Evaluator以统一的方式来评估模型性能

2.1 读取OBG的数据集(以ogbn-arxiv为例)

import ogb
import torch_geometric.transforms as T
from ogb.nodeproppred import PygNodePropPredDataset

dataset_name = 'ogbn-arxiv'
# 加载ogbn-arxiv数据集,并使用ToSparseTensor转换成Tensor格式
dataset = PygNodePropPredDataset(name=dataset_name,
                                 transform=T.ToSparseTensor())
print('The {} dataset has {} graph'.format(dataset_name, len(dataset)))

# 第一张图
data = dataset[0]
print(data)

def graph_num_features(data):
    num_features=data.num_features
    return num_features

num_features = graph_num_features(data)
print('The graph has {} features'.format(num_features))

3. GNN节点属性预测(节点分类)(以ogbn-arxiv数据集为例)

3.1 加载并预处理数据集

import torch
import pandas as pd
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import torch_geometric.transforms as T
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator

device = 'cuda' if torch.cuda.is_available() else 'cpu'
dataset_name = 'ogbn-arxiv'
# PygNodePropPredDataset读取节点分类的数据集
dataset = PygNodePropPredDataset(name=dataset_name,
                              transform=T.ToSparseTensor())
data = dataset[0]

# 转换成稀疏矩阵
data.adj_t = data.adj_t.to_symmetric()
print(type(data.adj_t))

data = data.to(device)
# 用get_idx_split划分数据集为train,valid,test三部分
split_idx = dataset.get_idx_split()
print(split_idx)
train_idx = split_idx['train'].to(device)

3.2 GCN Model

class GCN(torch.nn.Module):
    
## Note:
## 1. You should use torch.nn.ModuleList for self.convs and self.bns
## 2. self.convs has num_layers GCNConv layers
## 3. self.bns has num_layers - 1 BatchNorm1d layers
## 4. You should use torch.nn.LogSoftmax for self.softmax
## 5. The parameters you can set for GCNConv include 'in_channels' and 
## 'out_channels'. For more information please refer to the documentation:
## https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GCNConv
## 6. The only parameter you need to set for BatchNorm1d is 'num_features'

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers,
                 dropout, return_embeds=False):
        
        super(GCN, self).__init__()

        # A list of GCNConv layers
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(GCNConv(input_dim, hidden_dim))
            input_dim = hidden_dim
        self.convs.append(GCNConv(hidden_dim, output_dim))
        
        # A list of 1D batch normalization layers
        self.bns = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(torch.nn.BatchNorm1d(hidden_dim))

        # The log softmax layer
        self.softmax = torch.nn.LogSoftmax()

        # Probability of an element getting zeroed
        self.dropout = dropout

        # Skip classification layer and return node embeddings
        self.return_embeds = return_embeds

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()

## Note:
## 1. Construct the network as shown in the figure
## 2. torch.nn.functional.relu and torch.nn.functional.dropout are useful
## For more information please refer to the documentation:
## https://pytorch.org/docs/stable/nn.functional.html
## 3. Don't forget to set F.dropout training to self.training
## 4. If return_embeds is True, then skip the last softmax layer

    def forward(self, x, adj_t):
        for layer in range(len(self.convs)-1):
            x=self.convs[layer](x,adj_t)
            x=self.bns[layer](x)
            x=F.relu(x)
            x=F.dropout(x,self.dropout,self.training)
        out=self.convs[-1](x,adj_t)
        if not self.return_embeds:
            out=self.softmax(out)
        return out

3.3 训练和评估 

def train(model, data, train_idx, optimizer, loss_fn):
    model.train()
    optimizer.zero_grad()
    out = model(data.x,data.adj_t)
    # 计算训练部分的loss
    train_output = out[train_idx]
    train_label = data.y[train_idx,0]
    loss = loss_fn(train_output,train_label)
    loss.backward()
    optimizer.step()

    return loss.item()

# 注意data.y[train_idx]和data.y[train_idx,0]的区别
print(data.y[train_idx])
print(data.y[train_idx,0])
def test(model, data, split_idx, evaluator, save_model_results=False):
    model.eval()
    out = model(data.x,data.adj_t)
    y_pred = out.argmax(dim=-1, keepdim=True)
    # 使用OGB Evaluator进行评估
    train_acc = evaluator.eval({
        'y_true': data.y[split_idx['train']],
        'y_pred': y_pred[split_idx['train']],
    })['acc']
    valid_acc = evaluator.eval({
        'y_true': data.y[split_idx['valid']],
        'y_pred': y_pred[split_idx['valid']],
    })['acc']
    test_acc = evaluator.eval({
        'y_true': data.y[split_idx['test']],
        'y_pred': y_pred[split_idx['test']],
    })['acc']

    return train_acc, valid_acc, test_acc
args = {
    'device': device,
    'num_layers': 3,
    'hidden_dim': 256,
    'dropout': 0.5,
    'lr': 0.01,
    'epochs': 100,
}
model = GCN(data.num_features, args['hidden_dim'],
            dataset.num_classes, args['num_layers'],
            args['dropout']).to(device)
evaluator = Evaluator(name='ogbn-arxiv')


# 初始化模型参数
model.reset_parameters()

optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'])
loss_fn = F.nll_loss

for epoch in range(1, args["epochs"]+1):
    loss = train(model, data, train_idx, optimizer, loss_fn)
    train_acc, valid_acc, test_acc = test(model, data, split_idx, evaluator)
    print(f'Epoch: {epoch:02d}, '
        f'Loss: {loss:.4f}, '
        f'Train: {100 * train_acc:.2f}%, '
        f'Valid: {100 * valid_acc:.2f}% '
        f'Test: {100 * test_acc:.2f}%')

4. GNN图属性预测(图分类)(以ogbn-arxiv数据集为例)

4.1 加载并预处理数据集

from ogb.graphproppred import PygGraphPropPredDataset, Evaluator
from torch_geometric.data import DataLoader
from tqdm.notebook import tqdm

# PygGraphPropPredDataset读取图分类的数据集
dataset = PygGraphPropPredDataset(name='ogbg-molhiv')
split_idx = dataset.get_idx_split()

# DataLoader
train_loader = DataLoader(dataset[split_idx["train"]], batch_size=32, shuffle=True, num_workers=0)
valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=32, shuffle=False, num_workers=0)
test_loader = DataLoader(dataset[split_idx["test"]], batch_size=32, shuffle=False, num_workers=0)

4.2 GCN Model

from ogb.graphproppred.mol_encoder import AtomEncoder
from torch_geometric.nn import global_add_pool, global_mean_pool

class GCN_Graph(torch.nn.Module):
    def __init__(self, hidden_dim, output_dim, num_layers, dropout):
        super(GCN_Graph, self).__init__()

        # encoders
        self.node_encoder = AtomEncoder(hidden_dim)

        # 通过GCN
        self.gnn_node = GCN(hidden_dim, hidden_dim,
            hidden_dim, num_layers, dropout, return_embeds=True)
        
        # 全局池化
        self.pool = global_mean_pool
        self.linear = torch.nn.Linear(hidden_dim, output_dim)


    def reset_parameters(self):
        self.gnn_node.reset_parameters()
        self.linear.reset_parameters()

    def forward(self, batched_data):
        x, edge_index, batch = batched_data.x, batched_data.edge_index, batched_data.batch
        embed = self.node_encoder(x)
        
        out = self.gnn_node(embed,edge_index)
        out = self.pool(out,batch)
        out = self.linear(out)
        return out

4.3 训练和评估

def train(model, device, data_loader, optimizer, loss_fn):
    model.train()
    for step, batch in enumerate(tqdm(data_loader, desc="Iteration")):
        batch = batch.to(device)
        # 跳过不完整的batch
        if batch.x.shape[0] == 1 or batch.batch[-1] == 0:
            pass
        else:
            # 过滤掉无标签的数据
            is_labeled = (batch.y == batch.y)
            optimizer.zero_grad()
            op = model(batch)
            train_op = op[is_labeled]
            train_labels = batch.y[is_labeled].view(-1)
            loss = loss_fn(train_op.float(),train_labels.float())
            loss.backward()
            optimizer.step()

    return loss.item()
def eval(model, device, loader, evaluator):
    model.eval()
    y_true = []
    y_pred = []
    for step, batch in enumerate(tqdm(loader, desc="Iteration")):
        batch = batch.to(device)
        if batch.x.shape[0] == 1:
            pass
        else:
            with torch.no_grad():
                pred = model(batch)
            y_true.append(batch.y.view(pred.shape).detach().cpu())
            y_pred.append(pred.detach().cpu())

    y_true = torch.cat(y_true, dim = 0).numpy()
    y_pred = torch.cat(y_pred, dim = 0).numpy()
    input_dict = {"y_true": y_true, "y_pred": y_pred}
    result = evaluator.eval(input_dict)
    return result
args = {
    'device': device,
    'num_layers': 5,
    'hidden_dim': 256,
    'dropout': 0.5,
    'lr': 0.001,
    'epochs': 30,
}

model = GCN_Graph(args['hidden_dim'],
            dataset.num_tasks, args['num_layers'],
            args['dropout']).to(device)
evaluator = Evaluator(name='ogbg-molhiv')

model.reset_parameters()
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'])
loss_fn = torch.nn.BCEWithLogitsLoss()


for epoch in range(1, 1 + args["epochs"]):
    print('Training...')
    loss = train(model, device, train_loader, optimizer, loss_fn)

    print('Evaluating...')
    train_result = eval(model, device, train_loader, evaluator)
    val_result = eval(model, device, valid_loader, evaluator)
    test_result = eval(model, device, test_loader, evaluator)

    train_acc, valid_acc, test_acc = train_result[dataset.eval_metric], val_result[dataset.eval_metric], test_result[dataset.eval_metric]
    print(f'Epoch: {epoch:02d}, '
        f'Loss: {loss:.4f}, '
        f'Train: {100 * train_acc:.2f}%, '
        f'Valid: {100 * valid_acc:.2f}% '
        f'Test: {100 * test_acc:.2f}%')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/86155.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

学Python静不下来,看了一堆资料还是很迷茫是为什么

一、前言 最近发现,身边很多的小伙伴学Python都会遇到一个问题,就是资料也看了很多,也花了很多时间去学习但还是很迷茫,时间长了又发现之前学的知识点很多都忘了,都萌生出了想半路放弃的想法。 让我们看看蚂蚁金服的大…

IDEA:Error running,Command line is too long. 解决方法

报错如下: Error running SendSmsUtil. Command line is too long. Shorten the command line via JAR manifest or via a classpath file and rerun.原因是启动命令过长。 解决方法: 1、打开Edit Configurations 2、点击Modify options设置&#x…

关于android studio 几个简单的问题说明

自信是成功的第一步。——爱迪生 1. android studio 如何运行不同项目是否要更换不同的sdk 和 gradle 2.编译Gradle总是错误为什么 3.如何清理android studio 的缓存 4. 关于android Studio中的build 下面的rebuild project

Kafka基本使用

查看Kafka的进程是否在运行 #命令行终端中运行如下命令 ps -ef | grep kafkafind / -iname kafka-server-start.shcd /usr/local/kafka/bin/#启动kafka ./kafka-server-start.sh -daemon /usr/local/kafka/config/server.propertiesKafka默认使用9092端口提供服务&#xf…

使用opencv-python在图片上显示中文

测试图像如下: 核心代码如下: import cv2 import numpy as np from PIL import Image, ImageDraw, ImageFontdef cv2ImgAddText(img, text, left, top, textColor(0, 255, 0), textSize20):if (isinstance(img, np.ndarray)): #判断是否OpenCV图片类型…

javaScript:七夕特辑-对象的认识与应用(包含日期对象及相关案例)

目录 一.前言 二.认识对象 在js中声明对象的方法 1.直接使用{}声明对象 2.使用构造函数创建对象 获取属性的值 执行对象方法 解释 三.对象的应用 代码 效果图 ​编辑 四.日期对象 1.Date 日期对象 2. getFullYear() 获取当前年份 3.getMonth() 获取当前日期对象…

记一次由于整型参数错误导致的任意文件上传

当时误打误撞发现的,觉得挺奇葩的,记录下 一个正常的图片上传的点,文件类型白名单 但是比较巧的是当时刚对上面的id进行过注入测试,有一些遗留的测试 payload 没删,然后在测试上传的时候就发现.php的后缀可以上传了&a…

初识 JVM 01

JVM JRE JDK的关系 JVM 的内存机构 程序计数器 java指令的执行流程: 1 右侧的java源代码编译为左侧的java字节码(右侧第一个方块对应左侧第一个方块) 2 字节码 经过解释器 变为机器码 3 机器码就可以被cpu来执行 程序计数器的作用就…

C语言<自定义类型>结构体、枚举、联合

✨Blog:🥰不会敲代码的小张:)🥰 🉑推荐专栏:C语言🤪、Cpp😶‍🌫️、数据结构初阶💀 💽座右铭:“記住,每一天都是一個新的開始&#x1…

壁仞科技与百度飞桨完成II级兼容性测试

近日,壁仞科技BR104通用GPU与百度飞桨已完成II级兼容性测试。测试结果显示,双方兼容性表现良好,整体运行稳定。这是壁仞科技加入飞桨“硬件生态共创计划”后的阶段性成果。产品兼容性证明本次II级兼容性测试完成了涵盖自然语言处理、计算机视…

七月 NFT 行业解读:游戏和音乐 NFT 引领增长,Opepen 掀起热潮

作者:lesleyfootprint.network 2023 年 7 月,NFT 市场的波动性持续存在,交易量呈下降趋势。然而,游戏和音乐 NFT 等领域的增长引人注目。参与这些细分领域的独立用户数量不断增加,反映了这些领域的复苏。 本综合报告…

opencv 进阶17-使用K最近邻和比率检验过滤匹配(图像匹配)

K最近邻(K-Nearest Neighbors,简称KNN)和比率检验(Ratio Test)是在计算机视觉中用于特征匹配的常见技术。它们通常与特征描述子(例如SIFT、SURF、ORB等)一起使用,以在图像中找到相似…

分布式锁系列之zookeeper分布式锁和mysql分布式锁

目录 介绍 下载安装 基本指令​编辑 java集成zookeeper 官方提供版 永久节点 临时节点​编辑 永久序列化节点 判断当前节点是否存在 获取当前节点中的数据内容 获取当前节点的子节点 更新节点内容 删除节点 zookeeper实现分布式锁 Mysql实现分布式锁 总结 介绍 ZooK…

【Spring框架】Spring事务的介绍与使用方法

⚠️ 再提醒一次:Spring 本身并不实现事务,Spring事务 的本质还是底层数据库对事务的支持。你的程序是否支持事务首先取决于数据库 ,比如使用 MySQL 的话,如果你选择的是 innodb 引擎,那么恭喜你,是可以支持…

2023年计算机设计大赛国三 数据可视化 (源码可分享)

2023年暑假参加了全国大学生计算机设计大赛,并获得了国家三等奖(国赛答辩出了点小插曲)。在此分享和记录本次比赛的经验。 目录 一、作品简介二、作品效果图三、设计思路四、项目特色 一、作品简介 本项目实现对农产品近期发展、电商销售、灾…

基于AVR128单片机世界电子时钟的设计

一、系统方案 上电初始化完成系统初始化,液晶滚动显示北京、莫斯科、东京、伦敦、巴黎、纽约等六个城市的标准时间,显示的内容包括地区名及相应地区的年、月、日、星期、时、分、秒。 使用K1按键控制滚动显示或稳定显示某个地区的时间。 使用K3、K4、K5按…

Linux操作系统面试题汇总

Linux操作系统 1.Linux操作命令 Linux 目录结构及常用命令详细介绍参考 2.在Linux中find和grep的区别? 在Linux中,find命令用于按照指定条件搜索文件或目录,而grep命令则用于在文件中搜索指定的文本字符串。具体来说,find命令可…

使用扩展函数方式,在Winform界面中快捷的绑定树形列表TreeList控件和

在一些字典绑定中,往往为了方便展示详细数据,需要把一些结构树展现在树列表TreeList控件中或者下拉列表的树形控件TreeListLookUpEdit控件中,为了快速的处理数据的绑定操作,比较每次使用涉及太多细节的操作,我们可以把…

深入了解Git:介绍及常用命令指南

当今软件开发领域中,版本控制是一个至关重要的概念,而Git作为最流行的分布式版本控制系统,发挥着不可替代的作用。本文将介绍Git的基本概念以及常用命令,帮助你更好地理解和使用这一强大的工具。 Git简介 Git是一种分布式版本管…

RK3399平台开发系列讲解(内核调试篇)网络调试工具

🚀返回专栏总目录 文章目录 一、dstat 工具介绍二、例如dstat 进行网络问题调试三、ss 命令查看 TCP 详细信息四、netstat 查看TCP详细信息沉淀、分享、成长,让自己和他人都能有所收获!😄 📢 本篇将介绍网络的相关工具。 一、dstat 工具介绍 当设备产生问题,而我们又…
最新文章