anomalib 1.0 Learning Notes, Part 3: Sorting Out the Flow with PyTorch Lightning

I. PyTorch Lightning

PyTorch Lightning is a great tool, but it is not necessarily all that friendly to newcomers.

GPT-4 walked me through how it is used.
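
The gist of it: you subclass LightningModule, implement a few hook methods, and let the Trainer drive the loop. Here is a minimal sketch of my own (not from anomalib; it assumes lightning >= 2.0, and my_dataloader is a placeholder):

import torch
from torch import nn
import lightning.pytorch as pl


class LitClassifier(pl.LightningModule):
    """Toy example: a linear classifier over flattened 28x28 images."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(28 * 28, 10)

    def training_step(self, batch, batch_idx):
        # Called by the Trainer once per batch; return the loss to minimize.
        x, y = batch
        return nn.functional.cross_entropy(self.layer(x.flatten(1)), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


# The Trainer owns the loop: devices, epochs, logging, checkpointing.
# trainer = pl.Trainer(max_epochs=3)
# trainer.fit(LitClassifier(), train_dataloaders=my_dataloader)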


II. anomalib's approach

1. Create a LightningModule.

First, look at src\anomalib\models\components\base\anomaly_module.py:

class AnomalyModule(pl.LightningModule, ABC):
    """AnomalyModule to train, validate, predict and test images.

    Acts as a base class for all the Anomaly Modules in the library.
    """

    def __init__(self) -> None:
        super().__init__()
        logger.info("Initializing %s model.", self.__class__.__name__)

        self.save_hyperparameters()
        self.model: nn.Module
        self.loss: nn.Module
        self.callbacks: list[Callback]

        self.image_threshold: BaseThreshold
        self.pixel_threshold: BaseThreshold

        self.normalization_metrics: Metric

        self.image_metrics: AnomalibMetricCollection
        self.pixel_metrics: AnomalibMetricCollection

    def forward(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> Any:  # noqa: ANN401
        """Perform the forward-pass by passing input tensor to the module.

        Args:
            batch (dict[str, str | torch.Tensor]): Input batch.
            *args: Arguments.
            **kwargs: Keyword arguments.

        Returns:
            Tensor: Output tensor from the model.
        """
        del args, kwargs  # These variables are not used.

        return self.model(batch)

    def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT:
        """To be implemented in the subclasses."""
        raise NotImplementedError

    def predict_step(
        self,
        batch: dict[str, str | torch.Tensor],
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> STEP_OUTPUT:
        """Step function called during :meth:`~lightning.pytorch.trainer.Trainer.predict`.

        By default, it calls :meth:`~lightning.pytorch.core.lightning.LightningModule.forward`.
        Override to add any processing logic.

        Args:
            batch (Any): Current batch
            batch_idx (int): Index of current batch
            dataloader_idx (int): Index of the current dataloader

        Return:
            Predicted output
        """
        del batch_idx, dataloader_idx  # These variables are not used.

        return self.validation_step(batch)
... (rest omitted)

It defines a set of virtual-function-style methods such as forward, all of which are meant to be implemented by its subclasses.

This raises a question: does Python have an equivalent of C++'s virtual functions?

GPT's answer: yes. The difference is that in C++ a virtual function must be declared explicitly, whereas Python gets the same behavior through method overriding. When a subclass overrides a parent-class method, the subclass's version is what actually runs, whether the method is called directly on the object or through the parent-class interface. This lets the subclass change or extend behavior defined in the parent class, which is exactly what virtual functions are for in C++.
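
A tiny illustration of this (my own example, not from anomalib):

class Base:
    def forward(self):
        raise NotImplementedError  # plays the role of a C++ pure virtual


class Child(Base):
    def forward(self):
        return "child forward"


obj: Base = Child()
print(obj.forward())  # prints "child forward": the override wins

This is why AnomalyModule can declare validation_step with raise NotImplementedError and rely on each concrete model class to supply the real implementation.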

So let's look at how the Ddad class we wrote ourselves overrides some of AnomalyModule's methods.

In src\anomalib\models\image\ddad\lightning_model.py:

class Ddad(MemoryBankMixin, AnomalyModule):
    """Ddad: a Patch Distribution Modeling Framework for Anomaly Detection and Localization.

    

... (omitted)

    @staticmethod
    def configure_optimizers() -> None:
        """Ddad doesn't require optimization, therefore returns no optimizers."""
        return
    

    def on_train_epoch_start(self) -> None:
        print("----------------------------------------on_train_epoch_start")


    def prepare_data(self) -> None:
        print("----------------------------------------prepare_data")

    def training_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> None:
        """Perform the training step of Ddad. For each batch, hierarchical features are extracted from the CNN.

        Args:
            batch (dict[str, str | torch.Tensor]): Batch containing image filename, image, label and mask
            args: Additional arguments.
            kwargs: Additional keyword arguments.

        Returns:
            Hierarchical feature map
        """
        print("---------------training_step")
        del args, kwargs  # These variables are not used.

        self.model.feature_extractor.eval()
        embedding = self.model(batch["image"])

        self.embeddings.append(embedding.cpu())
        

    def fit(self) -> None:
        """Fit a Gaussian to the embedding collected from the training set."""
        logger.info("Aggregating the embedding extracted from the training set.")
        embeddings = torch.vstack(self.embeddings)

        logger.info("Fitting a Gaussian to the embedding collected from the training set.")
        self.stats = self.model.gaussian.fit(embeddings)

    def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT:
        """Perform a validation step of Ddad.

        Similar to the training step, hierarchical features are extracted from the CNN for each batch.

        Args:
            batch (dict[str, str | torch.Tensor]): Input batch
            args: Additional arguments.
            kwargs: Additional keyword arguments.

        Returns:
            Dictionary containing images, features, true labels and masks.
            These are required in `validation_epoch_end` for feature concatenation.
        """
        del args, kwargs  # These variables are not used.

        batch["anomaly_maps"] = self.model(batch["image"])
        return batch

As you can see, it overrides key methods such as training_step, fit, and validation_step.
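
For orientation, here is a rough sketch, my own simplification rather than Lightning's actual source, of the order in which these overrides fire during training. To my reading of anomalib, it is the MemoryBankMixin base that arranges for fit() to run once the training embeddings have been collected (around the start of validation), so treat that detail as an assumption:

def sketch_of_training_flow(model, train_loader, max_epochs):
    # Simplified stand-in for what Trainer.fit does with the hooks above.
    model.prepare_data()                      # fires once, before training
    for epoch in range(max_epochs):
        model.on_train_epoch_start()          # fires at each epoch start
        for batch in train_loader:
            model.training_step(batch)        # appends one embedding batch
    # Assumption: MemoryBankMixin triggers fit() before validation begins,
    # so the Gaussian is estimated over all collected embeddings at once.
    model.fit()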

2. Prepare the data

In src\anomalib\data\base\dataset.py, AnomalibDataset, a subclass of Dataset, is defined as follows:

class AnomalibDataset(Dataset, ABC):
    """Anomalib dataset.

    Args:
        task (str): Task type, either 'classification' or 'segmentation'
        transform (A.Compose): Albumentations Compose object describing the transforms that are applied to the inputs.
    """

    def __init__(self, task: TaskType, transform: A.Compose) -> None:
        super().__init__()
        self.task = task
        self.transform = transform
        self._samples: DataFrame

    def __len__(self) -> int:
        """Get length of the dataset."""
        return len(self.samples)

    def subsample(self, indices: Sequence[int], inplace: bool = False) -> "AnomalibDataset":
        """Subsamples the dataset at the provided indices.

        Args:
            indices (Sequence[int]): Indices at which the dataset is to be subsampled.
            inplace (bool): When true, the subsampling will be performed on the instance itself.
                Defaults to ``False``.
        """
        assert len(set(indices)) == len(indices), "No duplicates allowed in indices."
        dataset = self if inplace else copy.deepcopy(self)
        dataset.samples = self.samples.iloc[indices].reset_index(drop=True)
        return dataset

    @property
    def is_setup(self) -> bool:
        """Checks if setup() been called."""
        return hasattr(self, "_samples")

    @property
    def samples(self) -> DataFrame:
        """Get the samples dataframe."""
        if not self.is_setup:
            msg = "Dataset is not setup yet. Call setup() first."
            raise RuntimeError(msg)
        return self._samples

    @samples.setter
    def samples(self, samples: DataFrame) -> None:
        """Overwrite the samples with a new dataframe.

        Args:
            samples (DataFrame): DataFrame with new samples.
        """
        # Validate the passed samples by checking the type, expected columns, and file paths.
        assert isinstance(samples, DataFrame), f"samples must be a pandas.DataFrame, found {type(samples)}"
        expected_columns = _EXPECTED_COLUMNS_PERTASK[self.task]
        assert all(
            col in samples.columns for col in expected_columns
        ), f"samples must have (at least) columns {expected_columns}, found {samples.columns}"
        assert samples["image_path"].apply(lambda p: Path(p).exists()).all(), "missing file path(s) in samples"

        self._samples = samples.sort_values(by="image_path", ignore_index=True)

    @property
    def has_normal(self) -> bool:
        """Check if the dataset contains any normal samples."""
        return 0 in list(self.samples.label_index)

    @property
    def has_anomalous(self) -> bool:
        """Check if the dataset contains any anomalous samples."""
        return 1 in list(self.samples.label_index)

    def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]:
        """Get dataset item for the index ``index``.

        Args:
            index (int): Index to get the item.

        Returns:
            dict[str, str | torch.Tensor]: Dict of image tensor during training. Otherwise, Dict containing image path,
                target path, image tensor, label and transformed bounding box.
        """
        image_path = self._samples.iloc[index].image_path
        mask_path = self._samples.iloc[index].mask_path
        label_index = self._samples.iloc[index].label_index

        image = read_image(image_path)
        item = {"image_path": image_path, "label": label_index}

        if self.task == TaskType.CLASSIFICATION:
            transformed = self.transform(image=image)
            item["image"] = transformed["image"]
        elif self.task in (TaskType.DETECTION, TaskType.SEGMENTATION):
            # Only Anomalous (1) images have masks in anomaly datasets
            # Therefore, create empty mask for Normal (0) images.

            mask = np.zeros(shape=image.shape[:2]) if label_index == 0 else cv2.imread(mask_path, flags=0) / 255.0
            mask = mask.astype(np.single)

            transformed = self.transform(image=image, mask=mask)

            item["image"] = transformed["image"]
            item["mask_path"] = mask_path
            item["mask"] = transformed["mask"]

            if self.task == TaskType.DETECTION:
                # create boxes from masks for detection task
                boxes, _ = masks_to_boxes(item["mask"])
                item["boxes"] = boxes[0]
        else:
            msg = f"Unknown task type: {self.task}"
            raise ValueError(msg)

        return item

    def __add__(self, other_dataset: "AnomalibDataset") -> "AnomalibDataset":
        """Concatenate this dataset with another dataset.

        Args:
            other_dataset (AnomalibDataset): Dataset to concatenate with.

        Returns:
            AnomalibDataset: Concatenated dataset.
        """
        assert isinstance(other_dataset, self.__class__), "Cannot concatenate datasets that are not of the same type."
        assert self.is_setup, "Cannot concatenate uninitialized datasets. Call setup first."
        assert other_dataset.is_setup, "Cannot concatenate uninitialized datasets. Call setup first."
        dataset = copy.deepcopy(self)
        dataset.samples = pd.concat([self.samples, other_dataset.samples], ignore_index=True)
        return dataset

    def setup(self) -> None:
        """Load data/metadata into memory."""
        if not self.is_setup:
            self._setup()
        assert self.is_setup, "setup() should set self._samples"

    @abstractmethod
    def _setup(self) -> DataFrame:
        """Set up the data module.

        This method should return a dataframe that contains the information needed by the dataloader to load each of
        the dataset items into memory.

        The DataFrame must, at least, include the following columns:
            - `split` (str): The subset to which the dataset item is assigned (e.g., 'train', 'test').
            - `image_path` (str): Path to the file system location where the image is stored.
            - `label_index` (int): Index of the anomaly label, typically 0 for 'normal' and 1 for 'anomalous'.
            - `mask_path` (str, optional): Path to the ground truth masks (for the anomalous images only).
            Required if task is 'segmentation'.

        Example DataFrame:
            +---+-------------------+-----------+-------------+------------------+-------+
            |   | image_path        | label     | label_index | mask_path        | split |
            +---+-------------------+-----------+-------------+------------------+-------+
            | 0 | path/to/image.png | anomalous | 1           | path/to/mask.png | train |
            +---+-------------------+-----------+-------------+------------------+-------+

        Note:
            The example above is illustrative and may need to be adjusted based on the specific dataset structure.
        """
        raise NotImplementedError

It overrides the key methods __len__ and __getitem__, which is what lets a PyTorch DataLoader iterate over the dataset. Finally, _setup() produces the samples DataFrame, in the format shown in the docstring example above.
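
For concreteness, a toy samples DataFrame in that schema could be built like this (illustrative only; the paths are placeholders, and the real samples setter would reject them because it verifies that every image_path exists on disk):

import pandas as pd

# Illustrative only: the column layout _setup() is expected to return.
samples = pd.DataFrame(
    {
        "image_path": ["path/to/image.png"],
        "label": ["anomalous"],
        "label_index": [1],
        "mask_path": ["path/to/mask.png"],
        "split": ["train"],
    }
)
print(samples)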

Finally, in the AnomalibDataModule class in src\anomalib\data\base\datamodule.py, a train_dataloader() method is defined. Later, this train_dataloader gets called automatically; exactly how that happens, I haven't figured out yet.
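
My current understanding, offered as a hedged sketch of Lightning's behavior rather than its actual source: when you pass a datamodule to trainer.fit, the Trainer itself calls the data hooks, roughly like this:

def sketch_of_fit(trainer, model, datamodule):
    # Simplified stand-in for trainer.fit(model, datamodule=datamodule).
    datamodule.prepare_data()                     # download / extract, once
    datamodule.setup(stage="fit")                 # build train/val datasets
    train_loader = datamodule.train_dataloader()  # called for you, here
    for batch_idx, batch in enumerate(train_loader):
        model.training_step(batch, batch_idx)

As far as I can tell, anomalib 1.0's Engine wraps the Trainer and forwards the datamodule to trainer.fit in the same way.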

Comments and pointers are welcome.
