ViT (Vision Transformer) in Detail: Network Construction, Visualization, and Data Preprocessing, with a Complete End-to-End Example

Many of the ViT analyses and tutorials online are vague and hollow. This article walks through ViT in full using a concrete working example.

It covers three parts: network construction; using torchview.draw_graph to visualize each layer's structure and input/output shapes; and data preprocessing. Complete code is attached at the end.

Network construction

Create a model.py that implements the ViT network:

import torch.nn as nn
import torch
import torch.optim as optim
import torch.nn.functional as F
import lightning as L


class AttentionBlock(nn.Module):
    def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
        """
        Inputs:
            embed_dim - Dimensionality of input and attention feature vectors
            hidden_dim - Dimensionality of hidden layer in feed-forward network
                         (usually 2-4x larger than embed_dim)
            num_heads - Number of heads to use in the Multi-Head Attention block
            dropout - Amount of dropout to apply in the feed-forward network
        """
        super().__init__()

        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        self.linear = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        inp_x = self.layer_norm_1(x)
        x = x + self.attn(inp_x, inp_x, inp_x)[0]
        x = x + self.linear(self.layer_norm_2(x))
        return x


class VisionTransformer(nn.Module):
    def __init__(
        self,
        embed_dim,
        hidden_dim,
        num_channels,
        num_heads,
        num_layers,
        num_classes,
        patch_size,
        num_patches,
        dropout=0.0,
    ):
        """
        Inputs:
            embed_dim - Dimensionality of the input feature vectors to the Transformer
            hidden_dim - Dimensionality of the hidden layer in the feed-forward networks
                         within the Transformer
            num_channels - Number of channels of the input (3 for RGB)
            num_heads - Number of heads to use in the Multi-Head Attention block
            num_layers - Number of layers to use in the Transformer
            num_classes - Number of classes to predict
            patch_size - Number of pixels that the patches have per dimension
            num_patches - Maximum number of patches an image can have
            dropout - Amount of dropout to apply in the feed-forward network and
                      on the input encoding
        """
        super().__init__()

        self.patch_size = patch_size

        # Layers/Networks
        self.input_layer = nn.Linear(num_channels * (patch_size**2), embed_dim)
        self.transformer = nn.Sequential(
            *(AttentionBlock(embed_dim, hidden_dim, num_heads, dropout=dropout) for _ in range(num_layers))
        )
        self.mlp_head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, num_classes))
        self.dropout = nn.Dropout(dropout)

        # Parameters/Embeddings
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, 1 + num_patches, embed_dim))

    def img_to_patch(self, x, patch_size, flatten_channels=True):
        """
        Inputs:
            x - Tensor representing the image of shape [B, C, H, W]
            patch_size - Number of pixels per dimension of the patches (integer)
            flatten_channels - If True, the patches will be returned in a flattened format
                               as a feature vector instead of an image grid.
        """
        B, C, H, W = x.shape
        x = x.reshape(B, C, H // patch_size, patch_size, W // patch_size, patch_size)
        x = x.permute(0, 2, 4, 1, 3, 5)  # [B, H', W', C, p_H, p_W]
        x = x.flatten(1, 2)  # [B, H'*W', C, p_H, p_W]
        if flatten_channels:
            x = x.flatten(2, 4)  # [B, H'*W', C*p_H*p_W]
        return x

    def forward(self, x):
        # Preprocess input
        x = self.img_to_patch(x, self.patch_size)
        B, T, _ = x.shape
        x = self.input_layer(x)

        # Add CLS token and positional encoding
        cls_token = self.cls_token.repeat(B, 1, 1)
        x = torch.cat([cls_token, x], dim=1)
        x = x + self.pos_embedding[:, : T + 1]

        # Apply Transformer (nn.MultiheadAttention expects sequence-first input by default)
        x = self.dropout(x)
        x = x.transpose(0, 1)  # [B, T+1, D] -> [T+1, B, D]
        x = self.transformer(x)

        # Perform classification prediction
        cls = x[0]
        out = self.mlp_head(cls)
        return out


class ViT(L.LightningModule):
    def __init__(self, model_kwargs, lr):
        super().__init__()
        self.save_hyperparameters()
        self.model = VisionTransformer(**model_kwargs)

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr)
        lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)
        return [optimizer], [lr_scheduler]

    def _calculate_loss(self, batch, mode="train"):
        imgs, labels = batch
        preds = self.model(imgs)
        loss = F.cross_entropy(preds, labels)
        acc = (preds.argmax(dim=-1) == labels).float().mean()

        self.log("%s_loss" % mode, loss)
        self.log("%s_acc" % mode, acc)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._calculate_loss(batch, mode="train")
        return loss

    def validation_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="val")

    def test_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="test")

Import model.py in another file to build the network:

from model import ViT

model = ViT(model_kwargs={
        "embed_dim": 256,
        "hidden_dim": 512,
        "num_heads": 8,
        "num_layers": 6,
        "patch_size": 4,
        "num_channels": 3,
        "num_patches": 64,
        "num_classes": 10,
        "dropout": 0.2,
        },
        lr=3e-4,
    )
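As a quick sanity check (a minimal sketch, assuming the model was instantiated as shown above), you can push a dummy batch through the network and confirm the shapes: a 16x16 RGB image with patch_size=4 yields (16/4)^2 = 16 patches, well under the num_patches=64 budget, and the head emits one logit per class.

import torch

dummy = torch.randn(1, 3, 16, 16)                       # [B, C, H, W]
patches = model.model.img_to_patch(dummy, patch_size=4)
print(patches.shape)   # torch.Size([1, 16, 48]): 16 patches of 3*4*4 values each
logits = model(dummy)
print(logits.shape)    # torch.Size([1, 10]): one logit per class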

Alternatively, you can download a pretrained model:

import os
import urllib.request
from urllib.error import HTTPError

# Files to download
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/"
CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "saved_models/VisionTransformers/")
pretrained_files = [
    "tutorial15/ViT.ckpt",
    "tutorial15/tensorboards/ViT/events.out.tfevents.ViT",
    "tutorial5/tensorboards/ResNet/events.out.tfevents.resnet",
]
# Create checkpoint path if it doesn't exist yet
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# For each file, check whether it already exists. If not, try downloading it.
for file_name in pretrained_files:
    file_path = os.path.join(CHECKPOINT_PATH, file_name.split("/", 1)[1])
    if "/" in file_name.split("/", 1)[1]:
        os.makedirs(file_path.rsplit("/", 1)[0], exist_ok=True)
    if not os.path.isfile(file_path):
        file_url = base_url + file_name
        print("Downloading %s..." % file_url)
        try:
            urllib.request.urlretrieve(file_url, file_path)
        except HTTPError as e:
            print(
                "Something went wrong. Please try to download the file from the GDrive folder, or contact the author with the full output including the following error:\n",
                e,
            )

pretrained_filename = os.path.join(CHECKPOINT_PATH, "ViT.ckpt")
model = ViT.load_from_checkpoint(pretrained_filename)

Network visualization with torchview.draw_graph

from torchview import draw_graph

model_graph = draw_graph(model, input_size=(1, 3, 16, 16))
model_graph.resize_graph(scale=5.0)
model_graph.visual_graph.render(format='svg')

Running this code produces an SVG image showing the network structure and each layer's input and output shapes.
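If the default graph is too coarse, torchview exposes options that control how deeply it descends into submodules. A hedged sketch (expand_nested and depth are torchview parameters; the output filename is illustrative):

from torchview import draw_graph

model_graph = draw_graph(
    model,
    input_size=(1, 3, 16, 16),
    expand_nested=True,  # draw the layers inside each AttentionBlock
    depth=4,             # recurse deeper than the default
)
model_graph.visual_graph.render("vit_graph", format="svg")  # writes vit_graph.svg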

Preparing the training data

Create a new prepare_data.py:

import os
import json
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms


class CustomDataset(Dataset):
    def __init__(self, image_dir, names, labels, transform=None):
        self.image_dir = image_dir
        self.names = names
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        name_ = self.names[idx]
        img_name = os.path.join(self.image_dir, name_)
        image = Image.open(img_name)

        if self.transform:
            image = self.transform(image)

        label = self.labels[idx]

        return image, label


def load_img_ann(ann_path):
    """return [{img_name, [ (x, y, h, w, label), ... ]}]"""
    with open(ann_path) as fp:
        root = json.load(fp)
    img_dict = {}
    img_label_dict = {}
    for img_info in root['images']:
        img_id = img_info['id']
        img_name = img_info['file_name']
        img_dict[img_id] = {'name': img_name}
    for ann_info in root['annotations']:
        img_id = ann_info['image_id']
        img_category_id = ann_info['category_id']
        img_name = img_dict[img_id]['name']

        img_label_dict[img_id] = {'name': img_name, 'category_id': img_category_id}

    return img_label_dict


def get_dataloader():
    annota_dir = '/home/robotics/Downloads/coco_dataset/annotations/instances_val2017.json'
    img_dir = "/home/robotics/Downloads/coco_dataset/val2017"
    img_dict = load_img_ann(annota_dir)
    values = list(img_dict.values())
    img_names = []
    labels = []
    for item in values:
        category_id = item['category_id']
        labels.append(category_id)
        img_name = item['name']
        img_names.append(img_name)

    # Filter out grayscale (single-channel) images
    img_names_rgb = []
    labels_rgb = []
    for i in range(len(img_names)):
        # Build the full path to the image file
        file_path = os.path.join(img_dir, img_names[i])

        # Open the image file
        img = Image.open(file_path)

        # img.mode is a mode string such as "RGB" or "L" (grayscale)
        num_channels = img.mode
        if num_channels == "RGB" and labels[i] < 10:
            img_names_rgb.append(img_names[i])
            labels_rgb.append(labels[i])

    # Define a series of image transforms
    transform = transforms.Compose([
        transforms.Resize((16, 16)),  # resize the image
        transforms.RandomHorizontalFlip(),  # random horizontal flip
        transforms.ToTensor(),  # convert the image to a tensor
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # normalize with ImageNet statistics
    ])

    # image_dir is the folder containing all image files; labels is the label list.
    # NOTE: these slices overlap (the last 500 images also contain the val and test
    # images); kept as in the original demo, but fix this for a real experiment.
    train_set = CustomDataset(img_dir, img_names_rgb[-500:], labels_rgb[-500:], transform=transform)
    val_set = CustomDataset(img_dir, img_names_rgb[-500:-100], labels_rgb[-500:-100], transform=transform)
    test_set = CustomDataset(img_dir, img_names_rgb[-100:], labels_rgb[-100:], transform=transform)

    # Create the DataLoaders
    train_loader = DataLoader(train_set, batch_size=32, shuffle=True, drop_last=False)
    val_loader = DataLoader(val_set, batch_size=32, shuffle=True, drop_last=False, num_workers=4)
    test_loader = DataLoader(test_set, batch_size=32, shuffle=True, drop_last=False, num_workers=4)

    return train_loader, val_loader, test_loader


if __name__ == "__main__":
    train_loader, val_loader, test_loader = get_dataloader()
    for batch in train_loader:
        print(batch)

A few notes on the code above:

This example uses the COCO 2017 dataset, which you can download from the official website; after downloading, the annotations folder contains several annotation files.

Here we use instances_val2017.json. For proper training you should use train2017, but that file is very large and slow to process; since this article is for illustration rather than training quality, val2017 is used instead. The instances files are the annotations for object recognition and contain each image's labels and boxes; the ViT built here outputs only a class, not boxes. The function

def load_img_ann(ann_path):

links each image's id (the unique key of every image) with its name and category_id (which class it belongs to, i.e. its label). Note that when an image carries several annotations, the loop keeps only the last one, so each image ends up with a single label.
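For reference, the relevant parts of instances_val2017.json are laid out roughly as follows (an abbreviated sketch with illustrative values; real entries carry many more fields such as segmentation and iscrowd). load_img_ann joins entries in images by id against entries in annotations by image_id:

coco_root = {
    "images": [
        {"id": 397133, "file_name": "000000397133.jpg", "height": 427, "width": 640},
    ],
    "annotations": [
        {"image_id": 397133, "category_id": 1, "bbox": [102.4, 118.4, 200.0, 300.0]},
    ],
    "categories": [
        {"id": 1, "name": "person", "supercategory": "person"},
    ],
}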

        # img.mode is a mode string such as "RGB" or "L" (grayscale)
        num_channels = img.mode
        if num_channels == "RGB" and labels[i] < 10:
            img_names_rgb.append(img_names[i])
            labels_rgb.append(labels[i])

Note that COCO contains single-channel grayscale images, which must be removed. Because the ViT in this article is simple and can only output 10 classes, preprocessing also keeps just the images whose category_id is below 10.
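As an aside, instead of discarding grayscale images you could convert them to three channels and keep the extra data. A hedged sketch (not what the code above does), reusing the file_path variable from the loop:

from PIL import Image

img = Image.open(file_path)
if img.mode != "RGB":
    img = img.convert("RGB")  # replicate the single channel into R, G, B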

Define the transforms

    # Define a series of image transforms
    transform = transforms.Compose([
        transforms.Resize((16, 16)),  # resize the image
        transforms.RandomHorizontalFlip(),  # random horizontal flip
        transforms.ToTensor(),  # convert the image to a tensor
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # normalize with ImageNet statistics
    ])
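Applied to a single PIL image, this pipeline yields a normalized float tensor in exactly the shape img_to_patch expects. A minimal sketch (the filename is illustrative, and an RGB input is assumed):

from PIL import Image

img = Image.open("/home/robotics/Downloads/coco_dataset/val2017/000000397133.jpg")
x = transform(img)
print(x.shape, x.dtype)  # torch.Size([3, 16, 16]) torch.float32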

Create a custom Dataset class inheriting from torch.utils.data.Dataset:

class CustomDataset(Dataset):
    def __init__(self, image_dir, names, labels, transform=None):
        self.image_dir = image_dir
        self.names = names
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        name_ = self.names[idx]
        img_name = os.path.join(self.image_dir, name_)
        image = Image.open(img_name)

        if self.transform:
            image = self.transform(image)

        label = self.labels[idx]

        return image, label

First create the Datasets, then create the DataLoaders that draw minibatches from them.

    # image_dir is the folder containing all image files; labels is the label list.
    # NOTE: these slices overlap (the last 500 images also contain the val and test
    # images); kept as in the original demo, but fix this for a real experiment.
    train_set = CustomDataset(img_dir, img_names_rgb[-500:], labels_rgb[-500:], transform=transform)
    val_set = CustomDataset(img_dir, img_names_rgb[-500:-100], labels_rgb[-500:-100], transform=transform)
    test_set = CustomDataset(img_dir, img_names_rgb[-100:], labels_rgb[-100:], transform=transform)

    # Create the DataLoaders
    train_loader = DataLoader(train_set, batch_size=32, shuffle=True, drop_last=False)
    val_loader = DataLoader(val_set, batch_size=32, shuffle=True, drop_last=False, num_workers=4)
    test_loader = DataLoader(test_set, batch_size=32, shuffle=True, drop_last=False, num_workers=4)
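Drawing one batch confirms the shapes the model will consume: 32 images of [3, 16, 16] plus 32 integer labels (a quick check, assuming get_dataloader() ran successfully):

imgs, labels = next(iter(train_loader))
print(imgs.shape)    # torch.Size([32, 3, 16, 16])
print(labels.shape)  # torch.Size([32])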

With data preparation complete, we can train the model:

    trainer = L.Trainer(
            default_root_dir=os.path.join(CHECKPOINT_PATH, "ViT"),
            accelerator="auto",
            devices=1,
            max_epochs=180,
            callbacks=[
                ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc"),
                LearningRateMonitor("epoch"),
            ],
        )
    trainer.logger._log_graph = True  # If True, we plot the computation graph in tensorboard
    trainer.logger._default_hp_metric = None  # Optional logging argument that we don't need

    trainer.fit(model, train_loader, val_loader)
    # Load best checkpoint after training
    model = ViT.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

    # Test best model on validation and test set. trainer.test runs test_step,
    # which logs with the "test_" prefix regardless of which loader is passed,
    # so both lookups below read the "test_acc" key.
    val_result = trainer.test(model, dataloaders=val_loader, verbose=False)
    test_result = trainer.test(model, dataloaders=test_loader, verbose=False)
    result = {"test": test_result[0]["test_acc"], "val": val_result[0]["test_acc"]}

Complete code:

There are three files in total: model.py (network construction), prepare_data.py (data preprocessing), and main.py (training).

model.py:

import torch.nn as nn
import torch
import torch.optim as optim
import torch.nn.functional as F
import lightning as L


class AttentionBlock(nn.Module):
    def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
        """
        Inputs:
            embed_dim - Dimensionality of input and attention feature vectors
            hidden_dim - Dimensionality of hidden layer in feed-forward network
                         (usually 2-4x larger than embed_dim)
            num_heads - Number of heads to use in the Multi-Head Attention block
            dropout - Amount of dropout to apply in the feed-forward network
        """
        super().__init__()

        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        self.linear = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        inp_x = self.layer_norm_1(x)
        x = x + self.attn(inp_x, inp_x, inp_x)[0]
        x = x + self.linear(self.layer_norm_2(x))
        return x


class VisionTransformer(nn.Module):
    def __init__(
        self,
        embed_dim,
        hidden_dim,
        num_channels,
        num_heads,
        num_layers,
        num_classes,
        patch_size,
        num_patches,
        dropout=0.0,
    ):
        """
        Inputs:
            embed_dim - Dimensionality of the input feature vectors to the Transformer
            hidden_dim - Dimensionality of the hidden layer in the feed-forward networks
                         within the Transformer
            num_channels - Number of channels of the input (3 for RGB)
            num_heads - Number of heads to use in the Multi-Head Attention block
            num_layers - Number of layers to use in the Transformer
            num_classes - Number of classes to predict
            patch_size - Number of pixels that the patches have per dimension
            num_patches - Maximum number of patches an image can have
            dropout - Amount of dropout to apply in the feed-forward network and
                      on the input encoding
        """
        super().__init__()

        self.patch_size = patch_size

        # Layers/Networks
        self.input_layer = nn.Linear(num_channels * (patch_size**2), embed_dim)
        self.transformer = nn.Sequential(
            *(AttentionBlock(embed_dim, hidden_dim, num_heads, dropout=dropout) for _ in range(num_layers))
        )
        self.mlp_head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, num_classes))
        self.dropout = nn.Dropout(dropout)

        # Parameters/Embeddings
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, 1 + num_patches, embed_dim))

    def img_to_patch(self, x, patch_size, flatten_channels=True):
        """
        Inputs:
            x - Tensor representing the image of shape [B, C, H, W]
            patch_size - Number of pixels per dimension of the patches (integer)
            flatten_channels - If True, the patches will be returned in a flattened format
                               as a feature vector instead of an image grid.
        """
        B, C, H, W = x.shape
        x = x.reshape(B, C, H // patch_size, patch_size, W // patch_size, patch_size)
        x = x.permute(0, 2, 4, 1, 3, 5)  # [B, H', W', C, p_H, p_W]
        x = x.flatten(1, 2)  # [B, H'*W', C, p_H, p_W]
        if flatten_channels:
            x = x.flatten(2, 4)  # [B, H'*W', C*p_H*p_W]
        return x

    def forward(self, x):
        # Preprocess input
        x = self.img_to_patch(x, self.patch_size)
        B, T, _ = x.shape
        x = self.input_layer(x)

        # Add CLS token and positional encoding
        cls_token = self.cls_token.repeat(B, 1, 1)
        x = torch.cat([cls_token, x], dim=1)
        x = x + self.pos_embedding[:, : T + 1]

        # Apply Transformer (nn.MultiheadAttention expects sequence-first input by default)
        x = self.dropout(x)
        x = x.transpose(0, 1)  # [B, T+1, D] -> [T+1, B, D]
        x = self.transformer(x)

        # Perform classification prediction
        cls = x[0]
        out = self.mlp_head(cls)
        return out


class ViT(L.LightningModule):
    def __init__(self, model_kwargs, lr):
        super().__init__()
        self.save_hyperparameters()
        self.model = VisionTransformer(**model_kwargs)

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr)
        lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)
        return [optimizer], [lr_scheduler]

    def _calculate_loss(self, batch, mode="train"):
        imgs, labels = batch
        preds = self.model(imgs)
        loss = F.cross_entropy(preds, labels)
        acc = (preds.argmax(dim=-1) == labels).float().mean()

        self.log("%s_loss" % mode, loss)
        self.log("%s_acc" % mode, acc)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._calculate_loss(batch, mode="train")
        return loss

    def validation_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="val")

    def test_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="test")

prepare_data.py:

import os
import json
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms


class CustomDataset(Dataset):
    def __init__(self, image_dir, names, labels, transform=None):
        self.image_dir = image_dir
        self.names = names
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        name_ = self.names[idx]
        img_name = os.path.join(self.image_dir, name_)
        image = Image.open(img_name)

        if self.transform:
            image = self.transform(image)

        label = self.labels[idx]

        return image, label


def load_img_ann(ann_path):
    """return [{img_name, [ (x, y, h, w, label), ... ]}]"""
    with open(ann_path) as fp:
        root = json.load(fp)
    img_dict = {}
    img_label_dict = {}
    for img_info in root['images']:
        img_id = img_info['id']
        img_name = img_info['file_name']
        img_dict[img_id] = {'name': img_name}
    for ann_info in root['annotations']:
        img_id = ann_info['image_id']
        img_category_id = ann_info['category_id']
        img_name = img_dict[img_id]['name']

        img_label_dict[img_id] = {'name': img_name, 'category_id': img_category_id}

    return img_label_dict


def get_dataloader():
    annota_dir = '/home/robotics/Downloads/coco_dataset/annotations/instances_val2017.json'
    img_dir = "/home/robotics/Downloads/coco_dataset/val2017"
    img_dict = load_img_ann(annota_dir)
    values = list(img_dict.values())
    img_names = []
    labels = []
    for item in values:
        category_id = item['category_id']
        labels.append(category_id)
        img_name = item['name']
        img_names.append(img_name)

    # Filter out grayscale (single-channel) images
    img_names_rgb = []
    labels_rgb = []
    for i in range(len(img_names)):
        # Build the full path to the image file
        file_path = os.path.join(img_dir, img_names[i])

        # Open the image file
        img = Image.open(file_path)

        # img.mode is a mode string such as "RGB" or "L" (grayscale)
        num_channels = img.mode
        if num_channels == "RGB" and labels[i] < 10:
            img_names_rgb.append(img_names[i])
            labels_rgb.append(labels[i])

    # Define a series of image transforms
    transform = transforms.Compose([
        transforms.Resize((16, 16)),  # resize the image
        transforms.RandomHorizontalFlip(),  # random horizontal flip
        transforms.ToTensor(),  # convert the image to a tensor
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # normalize with ImageNet statistics
    ])

    # image_dir is the folder containing all image files; labels is the label list.
    # NOTE: these slices overlap (the last 500 images also contain the val and test
    # images); kept as in the original demo, but fix this for a real experiment.
    train_set = CustomDataset(img_dir, img_names_rgb[-500:], labels_rgb[-500:], transform=transform)
    val_set = CustomDataset(img_dir, img_names_rgb[-500:-100], labels_rgb[-500:-100], transform=transform)
    test_set = CustomDataset(img_dir, img_names_rgb[-100:], labels_rgb[-100:], transform=transform)

    # Create the DataLoaders
    train_loader = DataLoader(train_set, batch_size=32, shuffle=True, drop_last=False)
    val_loader = DataLoader(val_set, batch_size=32, shuffle=True, drop_last=False, num_workers=4)
    test_loader = DataLoader(test_set, batch_size=32, shuffle=True, drop_last=False, num_workers=4)

    return train_loader, val_loader, test_loader


if __name__ == "__main__":
    train_loader, val_loader, test_loader = get_dataloader()
    for batch in train_loader:
        print(batch)

main.py:

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # make CUDA errors synchronous; helps debug shape-mismatch errors
import urllib.request
from urllib.error import HTTPError
import lightning as L
from model import ViT
from torchview import draw_graph
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint

from prepare_data import get_dataloader


# Load the model
# Files to download
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/"
CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "saved_models/VisionTransformers/")
pretrained_files = [
    "tutorial15/ViT.ckpt",
    "tutorial15/tensorboards/ViT/events.out.tfevents.ViT",
    "tutorial5/tensorboards/ResNet/events.out.tfevents.resnet",
]
# Create checkpoint path if it doesn't exist yet
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# For each file, check whether it already exists. If not, try downloading it.
for file_name in pretrained_files:
    file_path = os.path.join(CHECKPOINT_PATH, file_name.split("/", 1)[1])
    if "/" in file_name.split("/", 1)[1]:
        os.makedirs(file_path.rsplit("/", 1)[0], exist_ok=True)
    if not os.path.isfile(file_path):
        file_url = base_url + file_name
        print("Downloading %s..." % file_url)
        try:
            urllib.request.urlretrieve(file_url, file_path)
        except HTTPError as e:
            print(
                "Something went wrong. Please try to download the file from the GDrive folder, or contact the author with the full output including the following error:\n",
                e,
            )

pretrained_filename = os.path.join(CHECKPOINT_PATH, "ViT.ckpt")
needTrain = False
if os.path.isfile(pretrained_filename):
    print("Found pretrained model at %s, loading..." % pretrained_filename)
    # Automatically loads the model with the saved hyperparameters
    model = ViT.load_from_checkpoint(pretrained_filename)
else:
    L.seed_everything(42)  # To be reproducible
    model = ViT(model_kwargs={
        "embed_dim": 256,
        "hidden_dim": 512,
        "num_heads": 8,
        "num_layers": 6,
        "patch_size": 4,
        "num_channels": 3,
        "num_patches": 64,
        "num_classes": 10,
        "dropout": 0.2,
        },
        lr=3e-4,
    )
    needTrain = True

# Visualize the network structure
model_graph = draw_graph(model, input_size=(1, 3, 16, 16))
model_graph.resize_graph(scale=5.0)
model_graph.visual_graph.render(format='svg')

# Prepare the training data
train_loader, val_loader, test_loader = get_dataloader()

if needTrain:
    trainer = L.Trainer(
            default_root_dir=os.path.join(CHECKPOINT_PATH, "ViT"),
            accelerator="auto",
            devices=1,
            max_epochs=180,
            callbacks=[
                ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc"),
                LearningRateMonitor("epoch"),
            ],
        )
    trainer.logger._log_graph = True  # If True, we plot the computation graph in tensorboard
    trainer.logger._default_hp_metric = None  # Optional logging argument that we don't need

    trainer.fit(model, train_loader, val_loader)
    # Load best checkpoint after training
    model = ViT.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

    # Test best model on validation and test set. trainer.test runs test_step,
    # which logs with the "test_" prefix regardless of which loader is passed,
    # so both lookups below read the "test_acc" key.
    val_result = trainer.test(model, dataloaders=val_loader, verbose=False)
    test_result = trainer.test(model, dataloaders=test_loader, verbose=False)
    result = {"test": test_result[0]["test_acc"], "val": val_result[0]["test_acc"]}
