yolov8训练进阶:自定义训练脚本,从配置文件载入训练超参数

yolov8官方教程提供了2种训练方式,一种是通过命令行启动训练,一种是通过写代码启动。
image.png
image.png
命令行的方式启动方便,通过传入参数可以方便的调整训练参数,但这种方式不方便记录训练参数和调试训练代码。
自行写训练代码的方式更灵活,也比较方便调试,但官方的示例各种参数都是在代码中写死的方式,失去了灵活性。
其实我们可以结合这两种方法的优势,既能够通过命令行参数修改很容易变化的参数(如batch size, epoch, imgsz等),然后用配置文件保存很少需要变化的参数,或者这些变化需要保存下来方便对比(如各类增强比例)。

代码分析

首先我们需要知道我们能够设置哪些参数,尽管官方文档列出了命令行能够传入的参数列表,但每次设置大量参数还是不方便,而不设置的时候默认参数是多少我们也不知道,所以还是有必要分析一下代码。
通过模型的train接口我们会知道所有的Trainer均继承自BaseTrainer(yolo/engine/trainer.py),该类的构造函数如下:

def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initializes the BaseTrainer class.

        Args:
            cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
            overrides (dict, optional): Configuration overrides. Defaults to None.
        """
        self.args = get_cfg(cfg, overrides)
        self.device = select_device(self.args.device, self.args.batch)
        self.check_resume()
        ...

其中overrides就是我们设置的参数,我们未设置的参数则来源于DEFAULT_CFG,继续跟踪我们会发现这个DEFAULT_CFG实际来源于yolo/cfg/default.yaml:

# Ultralytics YOLO 🚀, AGPL-3.0 license
# Default training settings and hyperparameters for medium-augmentation COCO training

task: detect  # YOLO task, i.e. detect, segment, classify, pose
mode: train  # YOLO mode, i.e. train, val, predict, export, track, benchmark

# Train settings -------------------------------------------------------------------------------------------------------
model:  # path to model file, i.e. yolov8n.pt, yolov8n.yaml
data:  # path to data file, i.e. coco128.yaml
epochs: 100  # number of epochs to train for
start_epoch: 0  # start epoch
patience: 50  # epochs to wait for no observable improvement for early stopping of training
batch: 16  # number of images per batch (-1 for AutoBatch)
imgsz: 640  # size of input images as integer or w,h
save: True  # save train checkpoints and predict results
save_period: -1 # Save checkpoint every x epochs (disabled if < 1)
cache: False  # True/ram, disk or False. Use cache for data loading
device:  # device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu
workers: 8  # number of worker threads for data loading (per RANK if DDP)
project:  # project name
name:  # experiment name, results saved to 'project/name' directory
exist_ok: False  # whether to overwrite existing experiment
pretrained: False  # whether to use a pretrained model
optimizer: SGD  # optimizer to use, choices=[SGD, Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, auto]
verbose: True  # whether to print verbose output
seed: 0  # random seed for reproducibility
deterministic: True  # whether to enable deterministic mode
single_cls: False  # train multi-class data as single-class
rect: False  # rectangular training if mode='train' or rectangular validation if mode='val'
cos_lr: False  # use cosine learning rate scheduler
close_mosaic: 0  # (int) disable mosaic augmentation for final epochs
resume: False  # resume training from last checkpoint
amp: True  # Automatic Mixed Precision (AMP) training, choices=[True, False], True runs AMP check
fraction: 1.0  # dataset fraction to train on (default is 1.0, all images in train set)
profile: False  # profile ONNX and TensorRT speeds during training for loggers
# Segmentation
overlap_mask: True  # masks should overlap during training (segment train only)
mask_ratio: 4  # mask downsample ratio (segment train only)
# Classification
dropout: 0.0  # use dropout regularization (classify train only)
...

我们所有能设置的参数就在这个文件中,如果我们设置了不在其中的参数则会报错(下一篇介绍怎么增加参数)。

自定义参数配置文件

我们可以将训练会调整的参数单独保存到一个yaml文件,如hyp.scratch.yaml作为从头训练的配置,进行多次实验时,就可以建立不同的配置参数文件:

lr0: 0.01  # initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
lrf: 0.001  # final learning rate (lr0 * lrf)
momentum: 0.937  # SGD momentum/Adam beta1
weight_decay: 0.0005  # optimizer weight decay 5e-4
warmup_epochs: 3.0  # warmup epochs (fractions ok)
warmup_momentum: 0.8  # warmup initial momentum
warmup_bias_lr: 0.1  # warmup initial bias lr
box: 7.5  # box loss gain
cls: 0.5  # cls loss gain (scale with pixels)
dfl: 1.5  # dfl loss gain
pose: 12.0  # pose loss gain
kobj: 1.0  # keypoint obj loss gain
label_smoothing: 0.0  # label smoothing (fraction)
nbs: 64  # nominal batch size
hsv_h: 0.015  # image HSV-Hue augmentation (fraction)
hsv_s: 0.7  # image HSV-Saturation augmentation (fraction)
hsv_v: 0.4  # image HSV-Value augmentation (fraction)
degrees: 0.0  # image rotation (+/- deg)
translate: 0.1  # image translation (+/- fraction)
scale: 0.5  # image scale (+/- gain)
shear: 0.0  # image shear (+/- deg)
perspective: 0.0  # image perspective (+/- fraction), range 0-0.001
flipud: 0.0  # image flip up-down (probability)
fliplr: 0.5  # image flip left-right (probability)
mosaic: 0.1  # image mosaic (probability)
mixup: 0.05  # image mixup (probability)
copy_paste: 0.0  # segment copy-paste (probability)

workers: 12  # number of workers
# cache: disk

自定义训练脚本

建立了自定义参数文件,我们还要建立自己的训练脚本来载入配置文件,并且还有一些经常变化的参数需要通过命令行传入, 新建train.py:

from ultralytics import YOLO
import yaml
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, default='configs/data/phd.yaml', help='dataset.yaml path')
parser.add_argument('--epochs', type=int, default=300, help='number of epochs')
parser.add_argument('--hyp', type=str, default='configs/hyp.yaml', help='size of each image batch')
parser.add_argument('--model', type=str, default='weights/yolov8n.pt', help='pretrained weights or model.config path')
parser.add_argument('--batch-size', type=int, default=64, help='size of each image batch')
parser.add_argument('--img-size', type=int, default=320, help='size of each image dimension')
parser.add_argument('--device', type=str, default='0', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--project', type=str, default='yolo', help='project name')
parser.add_argument('--name', type=str, default='pretrain', help='exp name')
parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')

args = parser.parse_args()

assert args.data, 'argument --data path is required'
assert args.model, 'argument --model path is required'

if __name__ == '__main__':
    # Initialize
    model = YOLO(args.model)
    hyperparams = yaml.safe_load(open(args.hyp))
    hyperparams['epochs'] = args.epochs
    hyperparams['batch'] = args.batch_size
    hyperparams['imgsz'] = args.img_size
    hyperparams['device'] = args.device
    hyperparams['project'] = args.project
    hyperparams['name'] = args.name
    hyperparams['resume'] = args.resume

    model.train(data= args.data, **hyperparams)

该脚本通过argparse来接受命令行参数,并设置到超参数字典,和yolov5的启动脚本类似。
主要有以下几个参数(可以根据个人需要增删):

  • data: 数据集配置文件
  • hyp: 参数配置文件(上一节我们建立的)
  • model: 模型权重或者模型结构配置文件
    其他参数根据名字就显而易见了。

模型训练(单卡)

python train.py --model weights/yolov8n.pt --data
configs/data/objects365.yaml --hyp configs/hyp.yaml --batch-size 512 --img-size 416 --device
0 --project object365 --name yolov8n

模型训练(多卡DDP)

理论上,我们只需要将device设置为多张卡就可以进行多卡并行了,但我们直接运行会发生一下错误:

assert args.model, 'argument --model path is required'

也就是我们设置的参数并没有接收到,进一步分析,DDP情况下,实际运行的命令是:

DDP command: ['/root/miniconda3/bin/python', '-m', 'torch.distributed.run', '--nproc_per_node', '4', '--master_port', '39083', 'xxx/code/yolov8/train.py']
WARNING:__main__:

也就是yolov8实际是用pytorch的ddp脚本启动了我们写得train.py脚本,但是却没有把我们设置的参数传过来(应该算是个bug吧···),这个过程发生在BaseTrainer的train接口中:
image.png
我们对generate_ddp_command进行修改,将命令行参数增加到train.py后(file后增加*sys.argv[1:]):

cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file, *sys.argv[1:]]

完整的函数:

def generate_ddp_command(world_size, trainer):
    """Generates and returns command for distributed training."""
    import __main__  # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
    if not trainer.resume:
        shutil.rmtree(trainer.save_dir)  # remove the save_dir
    file = str(Path(sys.argv[0]).resolve())

    safe_pattern = re.compile(r'^[a-zA-Z0-9_. /\\-]{1,128}$')  # allowed characters and maximum of 100 characters
    if not (safe_pattern.match(file) and Path(file).exists() and file.endswith('.py')):  # using CLI
        file = generate_ddp_file(trainer)
    dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
    port = find_free_network_port()
    cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file, *sys.argv[1:]]
    return cmd, file

修改后,device设置多卡则能正常开启训练。

结语

本文介绍了如何使用自定义训练脚本的方式启动yolov8的训练,有效的结合命令行和配置文件的优点,即可以灵活的修改训练参数,又可以用配置文件来管理我们的训练超参数。并通过修改文件,支持了DDP训练。

f77d79a3b79d6d9849231e64c8e1cdfa~tplv-dy-resize-origshort-autoq-75_330.jpeg

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/75596.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

logstash 原理(含部署)

1、ES原理 原理 使⽤filebeat来上传⽇志数据&#xff0c;logstash进⾏⽇志收集与处理&#xff0c;elasticsearch作为⽇志存储与搜索引擎&#xff0c;最后使⽤kibana展现⽇志的可视化输出。所以不难发现&#xff0c;⽇志解析主要还 是logstash做的事情 从上图中可以看到&#x…

将CNKI知网文献条目导出,并导入到Endnote内

将CNKI知网文献条目导出&#xff0c;并导入到Endnote内 目录 将CNKI知网文献条目导出&#xff0c;并导入到Endnote内一、从知网上导出参考文献二、将知网导出的参考文献导入到Endnote 一、从知网上导出参考文献 从知网上导出参考文献过程和步骤如图1所示。 图1 导出的参考文献…

码银送书第五期《互联网广告系统:架构、算法与智能化》

广告平台的建设和完善是一项长期工程。例如&#xff0c;谷歌早于2003年通过收购Applied Semantics开展Google AdSense 项目&#xff0c;而直到20年后的今天&#xff0c;谷歌展示广告平台仍在持续创新和提升。广告平台是负有营收责任的复杂在线平台&#xff0c;对其进行任何改动…

Memory Allocators 101 - Write a simple memory allocator

Memory Allocators 101 - Write a simple memory allocator - Arjun Sreedharan BlogAboutContactPosts GoogleLinkedInGithubFacebookTwitterUMass Amherst 1:11 AM 9th 八月 20160 notes Memory Allocators 101 - Write a simple memory allocator Code related to this…

Grafana展示k8s中pod的jvm监控面板/actuator/prometheus

场景 为保障java服务正常运行&#xff0c;对服务的jvm进行监控&#xff0c;通过使用actuator组件监控jvm情况&#xff0c;使用prometheus对数据进行采集&#xff0c;并在Grafana展现。 基于k8s场景 prometheus数据收集 配置service的lable&#xff0c;便于prometheus使用labl…

Python Flask+Echarts+sklearn+MySQL(评论情感分析、用户推荐、BI报表)项目分享

Python FlaskEchartssklearnMySQL(评论情感分析、用户推荐、BI报表)项目分享 项目背景&#xff1a; 随着互联网的快速发展和智能手机的普及&#xff0c;人们越来越倾向于在网上查找餐厅、购物中心、酒店和旅游景点等商户的点评和评分信息&#xff0c;以便做出更好的消费决策。…

Android 广播发送流程分析

在上一篇文章中Android 广播阻塞、延迟问题分析方法讲了广播阻塞的分析方法&#xff0c;但是分析完这个问题&#xff0c;自己还是有一些疑问&#xff1a; 广播为啥会阻塞呢&#xff1f;发送给接收器就行了&#xff0c;为啥还要等着接收器处理完才处理下一个&#xff1f;由普通…

【不限于联想Y9000P电脑关盖再打开时黑屏的解决办法】

不限于联想Y9000P电脑关盖再打开时黑屏的解决办法 问题的前言问题的出现问题拟解决 问题的前言 事情发生在昨天&#xff0c;更新了Win11系统后&#xff1a; 最惹人注目的三处地方就是&#xff1a; 1.可以查看时间的秒数了&#xff1b; 2.右键展示的内容变窄了&#xff1b; 3.按…

205、仿真-51单片机直流数字电流表多档位切换Proteus仿真设计(程序+Proteus仿真+原理图+流程图+元器件清单+配套资料等)

毕设帮助、开题指导、技术解答(有偿)见文未 目录 一、硬件设计 二、设计功能 三、Proteus仿真图 四、原理图 五、程序源码 资料包括&#xff1a; 方案选择 单片机的选择 方案一&#xff1a;STM32系列单片机控制&#xff0c;该型号单片机为LQFP44封装&#xff0c;内部资源…

等保案例 1

用户简介 吉林省人力资源和社会保障厅&#xff08;简称“吉林省人社厅”&#xff09;响应《网络安全法》的建设要求&#xff0c;为了向吉林省人民提供更好、更快、更稳定的信息化服务&#xff0c;根据《网络安全法》和等级保护2.0相关标准&#xff0c;落实网络安全与信息化建设…

【1572. 矩阵对角线元素的和】

来源&#xff1a;力扣&#xff08;LeetCode&#xff09; 描述&#xff1a; 给你一个正方形矩阵 mat&#xff0c;请你返回矩阵对角线元素的和。 请你返回在矩阵主对角线上的元素和副对角线上且不在主对角线上元素的和。 示例 1&#xff1a; 输入&#xff1a;mat [[1,2,3]…

uniapp 官方扩展组件 uni-combox 实现:只能选择不能手写(输入中支持过滤显示下拉列表)

uniapp 官方扩展组件 uni-combox 实现&#xff1a;只能选择不能手写&#xff08;输入中支持过滤显示下拉列表&#xff09; uni-comboxuni-combox 原本支持&#xff1a;问题&#xff1a; 改造源码参考资料 uni-combox uni-combox 原本支持&#xff1a; 下拉选择。输入关键字&am…

24届近3年南京信息工程大学自动化考研院校分析

今天给大家带来的是南京信息工程大学控制考研分析 满满干货&#xff5e;还不快快点赞收藏 一、南京信息工程大学 学校简介 南京信息工程大学位于南京江北新区&#xff0c;是一所以大气科学为特色的全国重点大学&#xff0c;由江苏省人民政府、中华人民共和国教育部、中国气…

轻量级自动化测试框架WebZ

一、什么是WebZ WebZ是我用Python写的“关键字驱动”的自动化测试框架&#xff0c;基于WebDriver。 设计该框架的初衷是&#xff1a;用自动化测试让测试人员从一些简单却重复的测试中解放出来。之所以用“关键字驱动”模式是因为我觉得这样能让测试人员&#xff08;测试执行人员…

7-2 成绩转换

分数 15 全屏浏览题目 切换布局 作者 沈睿 单位 浙江大学 本题要求编写程序将一个百分制成绩转换为五分制成绩。转换规则&#xff1a; 大于等于90分为A&#xff1b;小于90且大于等于80为B&#xff1b;小于80且大于等于70为C&#xff1b;小于70且大于等于60为D&#xff1b;小…

gitlab修改远程仓库地址

目录 背景&#xff1a; 解决&#xff1a; 1.删除本地仓库关联的远程地址&#xff0c;添加新的远程仓库地址 2.直接修改本地仓库关联的远程仓库地址 3.打开.git隐藏文件修改远程仓库地址 4.拉取代码报错(git host key verification failed) 背景&#xff1a; 公司搬家&#…

数据可视化工具的三大类报表制作流程分享

电脑&#xff08;pc&#xff09;、移动、大屏三大类型的BI数据可视化报表制作步骤基本相同&#xff0c;差别就在于尺寸调整和具体的报表布局。这对于采用点击、拖拉拽方式来制作报表的奥威BI数据可视化工具来说就显得特别简单。接下来&#xff0c;我们就一起看看不这三大类型的…

【第三阶段】kotlin语言中的先决条件函数

用于函数内部判断异常&#xff0c;节省开发 1.checkNotNull&#xff08;&#xff09;如果传入为null则抛出异常 fun main() {var name:String?nullcheckNotNull(name) }执行结果 2.requireNotNull ()如果传入为null则抛出异常 fun main() {var name:String?nullrequireNot…

【图像分类】理论篇(4)图像增强opencv实现

随机旋转 随机旋转是一种图像增强技术&#xff0c;它通过将图像以随机角度进行旋转来增加数据的多样性&#xff0c;从而帮助改善模型的鲁棒性和泛化能力。这在训练深度学习模型时尤其有用&#xff0c;可以使模型更好地适应各种角度的输入。 原图像&#xff1a; 旋转后的图像&…

快手商品详情数据API 抓取快手商品价格、销量、库存、sku信息

快手商品详情数据API是用来获取快手商品详情页数据的接口&#xff0c;请求参数为商品ID&#xff0c;这是每个商品唯一性的标识。返回参数有商品标题、商品标题、商品简介、价格、掌柜昵称、库存、宝贝链接、宝贝图片、商品SKU等。 接口名称&#xff1a;item_get 公共参数 名…
最新文章