YOLOV8逐步分解(2)_DetectionTrainer类初始化过程

 接上篇文章yolov8逐步分解(1)--默认参数&超参配置文件加载继续讲解。

 1. 默认配置文件加载完成后,创建对象trainer时,需要从默认配置中获取类DetectionTrainer初始化所需的参数args,如下所示

def train(cfg=DEFAULT_CFG, use_python=False):
    """Train and optimize YOLO model given training data and device."""
    model = cfg.model or 'yolov8n.pt'
    data = cfg.data or 'coco128.yaml'  # or yolo.ClassificationDataset("mnist")
    device = cfg.device if cfg.device is not None else ''
    args = dict(model=model, data=data, device=device)
    if use_python:
        from ultralytics import YOLO
        YOLO(model).train(**args)
    else:
        trainer = DetectionTrainer(overrides=args)  #初始化训练器
        trainer.train()

        通过debug可以看到,如下所示,args值为指定模型和数据集

 2. 使用上一步中获取的参数args,创建并初始化一个目标检测训练器trainer

trainer = DetectionTrainer(overrides=args)

3. DetectionTrainer类的初始化代码如下,下面我们将逐步讲解。

def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initializes the BaseTrainer class.
        Args:
            cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
            overrides (dict, optional): Configuration overrides. Defaults to None.
            对配置文件/训练数据文件参数进行加载,关键信息判断处理解析,保证文件存在,不存在则下载等合法性检测,及值的初始化化操作
        """
        self.args = get_cfg(cfg, overrides)  #将overrides中的配置与cfg中的配置融合,返回SimpleNameSpace类型
        self.device = select_device(self.args.device, self.args.batch) #选择运行在CPU/GPU还是苹果推出的MPS库上
        self.check_resume() #判断是否基于之前的断点继续训练,如果是,则加载之前保存的数据参数
        self.validator = None
        self.model = None
        self.metrics = None
        self.plots = {}
        init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic) #初始化随机数
        # Dirs 创建运行结果保存额目录及文件:创建本次训练的目录/ weights保存目录 /保存运行参数
        project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task #project: runs/detect
        name = self.args.name or f'{self.args.mode}'  #name: 'train'
        if hasattr(self.args, 'save_dir'):  #判断是否设置保存路径 ,如果没有则根据项目和任务名穿件保存目录
            self.save_dir = Path(self.args.save_dir)
        else:
           self.save_dir = Path(
                increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in (-1, 0) else True))
        self.wdir = self.save_dir / 'weights'  # weights dir #runs/detect/train72/weighhts
        if RANK in (-1, 0):
            self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
            self.args.save_dir = str(self.save_dir)
            yaml_save(self.save_dir / 'args.yaml', vars(self.args))  # save run args  #保存运行参数
        self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt'  # checkpoint paths
        self.save_period = self.args.save_period   #保存周期
        #设置 epoch次数 和 batch的大小
        self.batch_size = self.args.batch
        self.epochs = self.args.epochs
        self.start_epoch = 0
        if RANK == -1:
            print_args(vars(self.args))
        # Device
        if self.device.type == 'cpu':
            self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading
        # Model and Dataset 初始化模型文件 和数据集
        self.model = self.args.model  #yolov8n.pt
        try:
            if self.args.task == 'classify':   #分类任务
                self.data = check_cls_dataset(self.args.data)
            elif self.args.data.endswith('.yaml') or self.args.task in ('detect', 'segment'):  #检测和分割任务
                self.data = check_det_dataset(self.args.data) #加载数据yaml文件,进行关键属性值检测,并进行路径转换,确保数据集文件存在,不存在则下载
                if 'yaml_file' in self.data:
                    self.args.data = self.data['yaml_file']  # for validating 'yolo train data=url.zip' usage
        except Exception as e:
            raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
        self.trainset, self.testset = self.get_dataset(self.data) #初始化训练集测试集参数 获取路径
        self.ema = None
        # Optimization utils init
        self.lf = None   #损失函数
        self.scheduler = None  #学习率调整策略
        # Epoch level metrics 指标
        self.best_fitness = None
        self.fitness = None
        self.loss = None   #当前损失值
        self.tloss = None  #总损失值
        self.loss_names = ['Loss']
        self.csv = self.save_dir / 'results.csv'
        self.plot_idx = [0, 1, 2]
        # Callbacks
        self.callbacks = _callbacks or callbacks.get_default_callbacks()
        if RANK in (-1, 0):
            callbacks.add_integration_callbacks(self)

3.1  self.args = get_cfg(cfg, overrides) 该行主要实现功能为:

        将默认配置参数从Simplenamespace转为字典后与overrides中的参数合并更新,进行一些参数的合法性检测后,再转换为Simplenamespace格式输出。

        overrides该参数主要是用于更新默认加载的配置文件中model和data的值,默认配置中上述值均为None,如下图所示:

更新后的配置如下图所示:

3.2 self.device = select_device(self.args.device, self.args.batch) 功能为:

        选择算法运行在CPU还是GPU上,参数batch用于检测设置的batch数值是否是GPU个数的整数倍,若不是整数倍则报错。

3.3  self.check_resume() :判断是否基于之前的断点继续训练,如果是,则加载之前保存的数据参数,本次默认配置参数该值为False.

3.4 接下来创建运行时的文件保存目录,包括本次训练的权重文件保存目录,并保存训练使用的参数以及checkPoint路径等。

# Dirs 创建运行结果保存目录及文件:创建本次训练的目录/ weights保存目录 /保存运行参数
        project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task #project: runs/detect
        name = self.args.name or f'{self.args.mode}'  #name: 'train'
        if hasattr(self.args, 'save_dir'):  #判断是否设置保存路径 ,如果没有则根据项目和任务名创建保存目录
            self.save_dir = Path(self.args.save_dir)
        else:
            self.save_dir = Path(
                increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in (-1, 0) else True))
        self.wdir = self.save_dir / 'weights'  # weights dir #runs/detect/train72/weighhts
        if RANK in (-1, 0):
            self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
            self.args.save_dir = str(self.save_dir)
            yaml_save(self.save_dir / 'args.yaml', vars(self.args))  # save run args  #保存运行参数
        self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt'  # checkpoint paths
        self.save_period = self.args.save_period   #保存周期

3.5 初始化batch/epoch等参数,这个一目了然,不在解释

3.6  初始化数据集(coco128.yaml),步骤如下: 

        3.6.1 检测传入的数据集参数’dataset’是否是yaml结尾文件

        3.6.2 若是路径并且是压缩格式,则下载数据集配置文件

        3.6.3  加载coco128.yaml,通过函数yaml_load()加载

def check_det_dataset(dataset, autodownload=True):
    """Download, check and/or unzip dataset if not found locally."""
    data = check_file(dataset)  #dataset: coco128.yaml #判断文件是否合法,如果不存在在下载,或者从本地搜索
    # Download (optional)
    extract_dir = ''
    if isinstance(data, (str, Path)) and (zipfile.is_zipfile(data) or is_tarfile(data)): #判断数据集是否时zip or tar压缩格式 #
        new_dir = safe_download(data, dir=DATASETS_DIR, unzip=True, delete=False, curl=False)
        data = next((DATASETS_DIR / new_dir).rglob('*.yaml'))
        extract_dir, autodownload = data.parent, False
    # Read yaml (optional)
    if isinstance(data, (str, Path)):
        data = yaml_load(data, append_filename=True)  # dictionary #读取数据集yam文件 simplenamespace格式
    # Checks 必要参数检测
    for k in 'train', 'val':
        if k not in data: #如果数据中既不包含 train也不包含 val,则报错
            raise SyntaxError(
                emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs."))
    if 'names' not in data and 'nc' not in data:
        raise SyntaxError(emojis(f"{dataset} key missing ❌.\n either 'names' or 'nc' are required in all data YAMLs."))
    if 'names' in data and 'nc' in data and len(data['names']) != data['nc']:
        raise SyntaxError(emojis(f"{dataset} 'names' length {len(data['names'])} and 'nc: {data['nc']}' must match."))
    if 'names' not in data: #如果没有names则,用数字代替
        data['names'] = [f'class_{i}' for i in range(data['nc'])]
    else:
        data['nc'] = len(data['names'])
    data['names'] = check_class_names(data['names']) #检测data['names']是否是dict,以及将key转换为数字
    # Resolve paths
    path = Path(extract_dir or data.get('path') or Path(data.get('yaml_file', '')).parent)  # dataset root
    if not path.is_absolute():
        path = (DATASETS_DIR / path).resolve() #转化为绝对路径
    data['path'] = path  # download scripts
    for k in 'train', 'val', 'test':  #全部转换为绝对路径
        if data.get(k):  # prepend path
            if isinstance(data[k], str):
                x = (path / data[k]).resolve()
                if not x.exists() and data[k].startswith('../'):
                    x = (path / data[k][3:]).resolve()
                data[k] = str(x)
            else:
                data[k] = [str((path / x).resolve()) for x in data[k]]
    # Parse yaml
    train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
    if val:
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):  #不存在则下载
            name = clean_url(dataset)  # dataset name with URL auth stripped
            m = f"\nDataset '{name}' images not found ⚠️, missing paths %s" % [str(x) for x in val if not x.exists()]
            if s and autodownload:
                LOGGER.warning(m)
            else:
                m += f"\nNote dataset download directory is '{DATASETS_DIR}'. You can update this in '{SETTINGS_YAML}'"
                raise FileNotFoundError(m)
            t = time.time()
            if s.startswith('http') and s.endswith('.zip'):  # URL
                safe_download(url=s, dir=DATASETS_DIR, delete=True)
                r = None  # success
            elif s.startswith('bash '):  # bash script
                LOGGER.info(f'Running {s} ...')
                r = os.system(s)
            else:  # python script
                r = exec(s, {'yaml': data})  # return None
            dt = f'({round(time.time() - t, 1)}s)'
            s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f'failure {dt} ❌'
            LOGGER.info(f'Dataset download {s}\n')
    check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf')  # download fonts
    return data  # dictionary

        其中,data = yaml_load(data, append_filename=True)加载完成后,data内容如下:

注意:’nc’:80 是通过 data['nc'] = len(data['names']) 后添加的。

   3.6.4 将data中的路径全部转换为绝对路径

 for k in 'train', 'val', 'test':  #全部转换为绝对路径
        if data.get(k):  # prepend path
            if isinstance(data[k], str):
                x = (path / data[k]).resolve()
                if not x.exists() and data[k].startswith('../'):
                    x = (path / data[k][3:]).resolve()
                data[k] = str(x)
            else:
                data[k] = [str((path / x).resolve()) for x in data[k]]

        转换完成并更新data后,data的内容如下,其中train,val,test等键的值变为了绝对路径:

        3.6.5 获取训练集、测试集、验证集、以及下载路径

train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))

        3.6.6 最終返回data,数据类型为字典,完成对coco128.yaml文件的加载解析及校验工作。

3.7  获取训练集和验证集的路径

self.trainset, self.testset = self.get_dataset(self.data) #初始化训练集测试集参数 获取路径

        其中,获取路径方法函数实现过程如下:

def get_dataset(data):
        """
        Get train, val path from data dict if it exists. Returns None if data format is not recognized.
        """
        return data['train'], data.get('val') or data.get('test')

3.8 其他学习率、损失函数等都设置为None

        self.ema = None
        # Optimization utils init
        self.lf = None   #损失函数
        self.scheduler = None  #学习率调整策略
        # Epoch level metrics 指标
        self.best_fitness = None
        self.fitness = None
        self.loss = None   #当前损失值
        self.tloss = None  #总损失值
        self.loss_names = ['Loss']
        self.csv = self.save_dir / 'results.csv'
        self.plot_idx = [0, 1, 2]

3.9 设置用于结果展示获取的一些回调函数

        # Callbacks
        self.callbacks = _callbacks or callbacks.get_default_callbacks()
        if RANK in (-1, 0):
            callbacks.add_integration_callbacks(self)

        至此,trainer的初始化过程解析完成。

        总结,本章详细介绍了yolov8训练器trainer的初始化过程,讲解参数的加载替换过程,着重讲解了coco128数据集的加载解析及校验,最后介绍了损失函数学习率的初始化。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/496810.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

17.注释和关键字

文章目录 一、 注释二、关键字class关键字 我们之前写的HelloWorld案例写的比较简单,但随着课程渐渐深入,当我们写一些比较难的代码时,在刚开始写完时,你知道这段代码是什么意思,但是等过了几天,再次看这段…

图片标注编辑平台搭建系列教程(3)——画布拖拽、缩放实现

简介 标注平台很关键的一点,对于整个图片为底图的画布,需要支持缩放、拖拽,并且无论画布位置在哪里,大小如何,所有绘制的点、线、面的坐标都是相对于图片左上角的,并且,拖拽、缩放,…

从零开始学习在VUE3中使用canvas(六):线条样式(线条宽度lineWidth,线条端点样式lineCap)

一、线条宽度lineWidth 1.1简介 值为一个数字 const ctx canvas.getContext("2d"); ctx.lineWidth 6; 1.2效果展示 1.3全部代码 <template><div class"canvasPage"><!-- 写一个canvas标签 --><canvas class"main" r…

图像处理与视觉感知---期末复习重点(5)

文章目录 一、膨胀与腐蚀1.1 膨胀1.2 腐蚀 二、开操作与闭操作 一、膨胀与腐蚀 1.1 膨胀 1. 集合 A A A 被集合 B B B 膨胀&#xff0c;定义式如下。其中集合 B B B 也称为结构元素&#xff1b; ( B ^ ) z (\hat{B})z (B^)z 表示 B B B 的反射平移 z z z 后得到的新集合。…

冥想打坐睡觉功法

睡觉把手机放远一点&#xff0c;有电磁辐射&#xff0c;我把睡觉功法交给你&#xff0c;这样就可以睡好了。

55、Qt/事件机制相关学习20240326

一、代码实现设置闹钟&#xff0c;到时间后语音提醒用户。示意图如下&#xff1a; 代码&#xff1a; #include "widget.h" #include "ui_widget.h"Widget::Widget(QWidget *parent): QWidget(parent), ui(new Ui::Widget), speecher(new QTextToSpeech(t…

C++超市商品管理系统

一、简要介绍 1.本项目为面向对象程序设计的大作业&#xff0c;基于Qt creator进行开发&#xff0c;Qt框架版本6.4.1&#xff0c;编译环境MINGW 11.2.0。 2.项目结构简介&#xff1a;关于系统逻辑部分的代码的头文件在head文件夹中&#xff0c;源文件在s文件夹中。与图形界面…

权限提升-Win系统权限提升篇AD内网域控NetLogonADCSPACKDCCVE漏洞

知识点 1、WIN-域内用户到AD域控-CVE-2014-6324 2、WIN-域内用户到AD域控-CVE-2020-1472 3、WIN-域内用户到AD域控-CVE-2021-42287 4、WIN-域内用户到AD域控-CVE-2022-26923 章节点&#xff1a; 1、Web权限提升及转移 2、系统权限提升及转移 3、宿主权限提升及转移 4、域控权…

Git命令上传本地项目至github

记录如何创建个人仓库并上传已有代码至github in MacOS环境 0. 首先下载git 方法很多 这里就不介绍了 1. Github Create a new repository 先在github上创建一个空仓库&#xff0c;用于一会儿链接项目文件&#xff0c;按照自己的需求设置name和是否private 2.push an exis…

指针数组的有趣程序【C语言】

文章目录 指针数组的有趣程序指针数组是什么&#xff1f;指针数组的魅力指针数组的应用示例&#xff1a;命令行计算器有趣的颜色打印 结语 指针数组的有趣程序 在C语言的世界里&#xff0c;指针是一种强大的工具&#xff0c;它不仅能够指向变量&#xff0c;还能指向数组&#…

如何利用OpenCV4.9离散傅里叶变换

返回&#xff1a;OpenCV系列文章目录&#xff08;持续更新中......&#xff09; 上一篇:如何利用OpenCV4.9 更改图像的对比度和亮度 下一篇:OpenCV 如何使用 XML 和 YAML 文件的文件输入和输出 目标 我们将寻求以下问题的答案&#xff1a; 什么是傅里叶变换&#xff0c;为什…

《数据结构学习笔记---第五篇》---链表OJ练习下

step1:思路分析 1.实现复制&#xff0c;且是两个独立的复制&#xff0c;我们必须要理清指针之间的逻辑&#xff0c;注意random的新指针要链接到复制体的后面。 2.我们先完成对于结点的复制&#xff0c;并将复制后的结点放在原节点的后面&#xff0c;并链接。 3.完成random结点…

黑马鸿蒙笔记1

这里与前端类似。

斜率优化dp 笔记

任务安排1 有 N 个任务排成一个序列在一台机器上等待执行&#xff0c;它们的顺序不得改变。 机器会把这 N 个任务分成若干批&#xff0c;每一批包含连续的若干个任务。 从时刻 00 开始&#xff0c;任务被分批加工&#xff0c;执行第 i 个任务所需的时间是 Ti。 另外&#x…

PHP开发全新29网课交单平台源码修复全开源版本,支持聚合登陆易支付

这是一套最新版本的PHP开发的网课交单平台源代码&#xff0c;已进行全开源修复&#xff0c;支持聚合登录和易支付功能。 项目 地 址 &#xff1a; runruncode.com/php/19721.html 以下是对该套代码的主要更新和修复&#xff1a; 1. 移除了论文编辑功能。 2. 移除了强国接码…

linux之进程

一、背景 冯.诺依曼体系结构 输入设备键盘、鼠标、摄像头、话筒、磁盘、网卡...输出设备显示器、声卡、磁盘、网卡...CPU运算器、控制器存储器一般就是内存 数据在计算机的体系结构进行流动&#xff0c;流动过程中&#xff0c;进行数据的加工处理&#xff0c;从一个设备到另一…

网上兼职赚钱攻略:六种方式让你轻松上手

在互联网时代&#xff0c;网上兼职已经成为一种非常流行的赚钱方式。对于许多想要在家里挣钱的人来说&#xff0c;网上兼职不仅可以提供灵活的工作时间&#xff0c;还可以让他们在自己的兴趣领域中寻求机会&#xff0c;实现自己的财务自由。 在这里&#xff0c;我将为您介绍六…

OpenGL 实现“人像背景虚化“效果

手机上的人像模式,也被人们称作“背景虚化”或 ”双摄虚化“ 模式,也称为 Bokeh 模式,能够在保持画面中指定的人或物体清晰的同时,将其他的背景模糊掉。突出画面的主体部分,主观上美感更强烈。 人像模式的一般实现原理是,利用双摄系统获取景深信息,并通过深度传感器和图…

带流量主功能的外卖菜谱小程序源码:助你轻松领取优惠券,个人使用也可通过审查

外卖菜谱小程序源码-带流量主功能-外卖领劵个人也可过审 这套小程序优点就带很多菜谱&#xff0c;各种你爱吃菜的做法与各类食材介绍营养搭配&#xff0c;相信很多小姐姐会感兴趣。 宝妈宝爸这个小程序肯定能留的住这个群体的人脉流量&#xff0c;这是小程序最大的亮点&#…

深圳区块链交易所app系统开发,撮合交易系统开发

随着区块链技术的迅速发展和数字资产市场的蓬勃发展&#xff0c;区块链交易所成为了数字资产交易的核心场所之一。在这个快速发展的领域中&#xff0c;区块链交易所App系统的开发和撮合交易系统的建设至关重要。本文将探讨区块链交易所App系统开发及撮合交易系统的重要性&#…
最新文章