RAM模型从数据准备到pretrain、finetune与推理全过程详细说明

提示:RAM++模型:环境安装、数据准备与说明、模型推理、模型finetune、模型pretrain等

文章目录

  • 前言
  • 一、环境安装
  • 二、数据准备与解读
    • 1.数据下载
    • 2.数据标签内容解读
    • 3.标签map内容解读
  • 三、finetune训练
    • 1.微调训练命令
    • 2.load载入参数问题
    • 3.权重载入
    • 4.数据加载问题
    • 5.设备不匹配报错
    • 6.运行结果
  • 四、pretrain预训练
    • 1.预训练命令
    • 2.swin_large_patch4_window12_384_22k.pth权重
      • a.下载
      • b.权重加载修改
    • 3.ram_plus_tag_embedding_class_4585_des_51.pth权重
      • a.下载
      • b.权重加载修改
    • 4.变量设备匹配问题
    • 5. 预训练成功显示
  • 五、数据加载源码简单解读
  • 六、推理


前言

随着SAM模型分割一切大火之后,又有RAM模型识别一切,RAM模型由来可有三篇模型构成,TAG2TEXT为首篇将tag引入VL模型中,由tagging、generation、alignment分支构成,随后才是RAM模型,主要借助CLIP模型辅助与annotation处理trick,由tagging、generation分支构成,最后才是RAM++模型,该模型引入semantic concepts到图像tagging训练框架,RAM++模型能够利用图像-标签-文本三者之间的关系,整合image-text alignment 和 image-tagging 到一个统一的交互框架里。作者也将三个模型整合成一套代码,本文将介绍RAM++模型,主要内容包含环境安装、数据准备与说明、模型推理、模型finetune、模型pretrain等内容,并逐过程解读,也帮读者踩完所有坑,只要按照我我步骤将会实现RAM流畅运行。


TAG2TEXT论文链接:点击这里
RAM论文链接:点击这里
RAM++论文链接:点击这里
github官网链接:点击这里

一、环境安装

说实话,环境安装按照官网来,没有报什么错,可直接推理运行,但是训练可能会缺一些东西,后续将介绍,环境安装如下:

Install recognize-anything as a package:

pip install git+https://github.com/xinyu1205/recognize-anything.git

Or, for development, you may build from source

git clone https://github.com/xinyu1205/recognize-anything.git
cd recognize-anything
pip install -e .

二、数据准备与解读

1.数据下载

图像数据需要根据相应内容去下载,而数据标签下载可以去github代码官网链接,点击下面红框即可。当然你也可转到下面网页链接。
数据标签下载:https://huggingface.co/datasets/xinyu1205/recognize-anything-dataset-14m/tree/main
在这里插入图片描述
当进入标签页面如下:
在这里插入图片描述

2.数据标签内容解读

当你下载了标签后,你能发现标签实际是列表,列表中每个数据又是一个字典,包含image_path、caption、union_label_id、parse_label_id字典,以vg_ram.json标签举列,我们取第一个元素,如下图所示:
在这里插入图片描述
我们进一步展开该数据,你会重点发现parse_label_id是一个二维列表,每一行是对对应caption描述取的tag而union_label_id是一维列表,parse_label_id中tag都能在union_label_id找到,反之不行。如下图:
在这里插入图片描述

3.标签map内容解读

我们在上面可看到parse_label_id与union_label_id是数字,那么这些数字如何得到,必然有一个映射表,该表是ram_tag_list_4585_llm_tag_descriptions.json文件中,该文件也是一个列表,列表中每个元素是一个字典,该字典key就是tag,value是一个列表,是对key的描述,我查看value的列表有50个描述。其中该文件列表位置(索引)就代表key(tag),这也是parse_label_id与union_label_id的数字。如下:

在这里插入图片描述

当然,RAM模型数据可以一个元素的一张图有多个描述如下左图,也可以多个元素表示同一张图,进行多个描述,如下:
在这里插入图片描述

三、finetune训练

1.微调训练命令

可看出训练使用finetune.py文件,参数配置是finetune.yaml文件,模型类型选择是ram_plus文件,如下:

python -m torch.distributed.run --nproc_per_node=8 finetune.py \   --model-type ram_plus \   --config ram/configs/finetune.yaml  \   --checkpoint outputs/ram_plus/checkpoint_04.pth \   --output-dir outputs/ram_plus_ft

我是直接运行finetune.py文件,使用远程链接方式运行的!

2.load载入参数问题

当执行微调命令时,我遇到yaml载入问题,如下图:
在这里插入图片描述
当然这个是个小问题,与环境相关,可能你们不会遇到,若遇到可尝试我的解决方法:
导入包ruamel.yaml,更改原有代码

config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)

为以下代码即可:

import ruamel.yaml  yaml = ruamel.yaml.YAML(typ='rt') config = yaml.load(open(args.config, 'r'))

注:该问题pretrain可能也会遇到。

3.权重载入

第二个问题,模型权重 模型权重载入需要修改,根据你的需求可修改权重路径,如下图:  我使用ram++,将model_clip, _ = clip.load("/home/notebook/data/group/huangxinyu/clip/ViT-B-16.pt")中的地址替换即可。

权重下载地址如下:

_MODELS = {     
"RN50":"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",     
"RN101":"https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",     
"RN50x4":"https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",   
"RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",   
"RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",   
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",    
"ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",     
"ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",   
"ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 
}  

4.数据加载问题

需在finetune.yaml文件中设定image_path_root: “” 参数,使得该参数与下图image_path合并为图像绝对路径,我设定如下:

 image_path_root: "/home/Project/recognize-anything/datasets/train"  

图像路径如下图所示:
在这里插入图片描述

5.设备不匹配报错

运行预训练命令依然会报错,如下:
在这里插入图片描述
该问题也是小问题,就是变量设备不匹配问题,在finetune.py文件,为image_tag变量指定设备,添加一句代码:

image_tag = image_tag.to(device,non_blocking=True) 

修改后整体代码如下:

for i, (image, image_224, caption, image_tag, parse_tag) in enumerate(metric_logger.log_every(data_loader, 	  print_freq, header)):          
	optimizer.zero_grad()          
	batch_text_embed = build_text_embed(model_clip,caption)                  
	image = image.to(device,non_blocking=True)         
	image_224 = image_224.to(device,non_blocking=True)                  
	image_tag = image_tag.to(device,non_blocking=True)                  
	clip_image_feature = model_clip.encode_image(image_224)          
	loss_tag, loss_dis, loss_alignment = model(image, caption, image_tag, clip_image_feature, batch_text_embed)           
	loss = loss_tag + loss_dis + loss_alignment  

6.运行结果

之后运行结果如下: 在这里插入图片描述

我们进一步可发现使用一张3090显卡,batch为20即可满负载,如下:
在这里插入图片描述
注:以上微调内容某些在训练时候遇到,按其修改即可!

四、pretrain预训练

1.预训练命令

可看出预训练使用pretrain.py文件,参数配置是pretrain.yaml文件,模型类型选择是ram_plus文件,如下:

python -m torch.distributed.run --nproc_per_node=8 pretrain.py \
  --model-type ram_plus \
  --config ram/configs/pretrain.yaml  \
  --output-dir outputs/ram_plus

2.swin_large_patch4_window12_384_22k.pth权重

a.下载

但你直接使用该命令时候,会报如下错误:
在这里插入图片描述

以上报错是因为缺失相应权重swin_large_patch4_window12_384_22k.pth,我们只需通过下面链接点击这里,获得如下图权重下载即可,如下:
在这里插入图片描述

b.权重加载修改

对应权重下载实际是pretrain.yaml参数设置的vit: 'swin_l'image_size: 224共同决定,我们将其定位为config_swinl_224.json文件,如下图:
在这里插入图片描述
上面我们已知权重路径更改位置,我们将其下载权重绝对路径替换即可,如下代码示列:

{
{
    "ckpt": "绝对路径位置/swin_large_patch4_window12_384_22k.pth",
    "vision_width": 1536,
    "image_res": 224,
    "window_size": 7,
    "embed_dim": 192,
    "depths": [ 2, 2, 18, 2 ],
    "num_heads": [ 6, 12, 24, 48 ]
  }
  }

3.ram_plus_tag_embedding_class_4585_des_51.pth权重

a.下载

当你再次使用该命令时候,会报如下错误:
在这里插入图片描述

不要慌张,依然是权重问题,我们只需链接:点击这里,可在huggingface下载我们想要的权重文件。

b.权重加载修改

对应权重下载后有2种方法可实现权重正确加载,第一将下载权重放到指定路,第二将在源码ram_plus.py改成绝对路径,如下图:
在这里插入图片描述

4.变量设备匹配问题

当你很开心再次使用预训练命令时,会报如下错误(该错误在finetune也会出现):
在这里插入图片描述
该问题也是小问题,就是变量设备不匹配问题,从上图报错地方可追述到ram_plus.py文件除了问题,实际决定该问题是在pretrain.py文件调用那里,主要是image_tag是一个传入参数未能给定device,我们在pretrain.py下面代码给定即可,我也建议在pretrain.py修改,而不要动报错地方修改,你只需添加image_tag = image_tag.to(device, non_blocking=True)指定设备,修改如下:

    for i, (image, caption, image_tag, parse_tag) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        
        if epoch==0:
            warmup_lr_schedule(optimizer, i, config['warmup_steps'], config['warmup_lr'], config['init_lr'])
            
        optimizer.zero_grad()

        batch_text_embed = build_text_embed(model_clip,caption)
        
        image = image.to(device,non_blocking=True)

        image_tag = image_tag.to(device, non_blocking=True) #


5. 预训练成功显示

如出现下图表示预训练成功,如下:
在这里插入图片描述
我们进一步可发现使用一张3090显卡,batch为20即可满负载,如下:
在这里插入图片描述

五、数据加载源码简单解读

标签源码如下,可看到图像做了2次加工一次该模型本身使用image,一次为图像特征提取swin模型使用image_224,而caption为一句话(若为多句随机选择一句),该句话直接通过clip的文本编码获得特征,image_tag 是union_label_id, parse_tag是parse_label_id,具体如下代码:


    def __getitem__(self, index):    
        
        ann = self.ann[index]   

        image_path_use = os.path.join(self.root, ann['image_path'])
        image = Image.open(image_path_use).convert('RGB')   
        image = self.transform(image)

        image_224 = Image.open(image_path_use).convert('RGB')  
        image_224 = self.transform_224(image_224)
        # image_tag 是union_label_id
        num = ann['union_label_id']
        image_tag = np.zeros([self.class_num])
        image_tag[num] = 1
        image_tag = torch.tensor(image_tag, dtype = torch.long)

        caption_index = np.random.randint(0, len(ann['caption']))  # 有的数据集有多个描述

        caption = pre_caption(ann['caption'][caption_index],30)
        # parse_tag是parse_label_id
        num = ann['parse_label_id'][caption_index]
        parse_tag = np.zeros([self.class_num])
        parse_tag[num] = 1
        parse_tag = torch.tensor(parse_tag, dtype = torch.long)

        return image, image_224, caption, image_tag, parse_tag

六、推理

推理可直接使用命令,指定权重我在pretrain已给出链接,可自行下载:

python batch_inference.py \   --model-type ram_plus \   --checkpoint pretrained/ram_plus_swin_large_14m.pth \   --dataset openimages_common_214 \   --output-dir outputs/ram_plus

当然,你也可以使用我的代码,我是将一个文件夹循环推理,并将推理结果打印于图上便于查看,如下:

'''
 * The Recognize Anything Plus Model (RAM++)
 * Written by Xinyu Huang
'''
import argparse
import os

import numpy as np
import random

import torch

from PIL import Image
from ram.models import ram_plus
from ram import inference_ram as inference
from ram import get_transform


parser = argparse.ArgumentParser(
    description='Tag2Text inferece for tagging and captioning')
parser.add_argument('--image',

                    help='path to dataset',
                    default='images/demo/demo1.jpg')
parser.add_argument('--pretrained',

                    help='path to pretrained model',
                    default='路径位置/ram_plus_swin_large_14m.pth')
parser.add_argument('--image-size',
                    default=384,
                    type=int,
                    metavar='N',
                    help='input image size (default: 448)')




import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont

def cv2ImgAddText(img, text, left, top, textColor=(0, 255, 0), textSize=20):
    if (isinstance(img, np.ndarray)):  # 判断是否OpenCV图片类型
        img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    # 创建一个可以在给定图像上绘图的对象
    draw = ImageDraw.Draw(img)
    # 字体的格式
    fontStyle = ImageFont.truetype("/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc", textSize, encoding="utf-8") 
    # 绘制文本
    draw.text((left, top), text, textColor, font=fontStyle)
    # 转换回OpenCV格式
    return cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)

def build_dir(out_dir):
    if not os.path.exists(out_dir):
        os.makedirs(out_dir,exist_ok=True)
    return out_dir


if __name__ == "__main__":

    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    transform = get_transform(image_size=args.image_size)

    #######load model
    model = ram_plus(pretrained=args.pretrained,
                             image_size=args.image_size,
                             vit='swin_l')
    model.eval()

    model = model.to(device)


    total = sum(p.numel() for p in model.parameters())  # 统计个数
    print("模型参数总量: %.2f million\t" % (total / 1e6), " 以float32模型内存占用:%.2f M" % (total * 4 / 1e6))

    
    
    # 下面是推理
    
    file_root='/推理文件路径/sam_test' # 这个是多个文件夹路径
    save_file_path=build_dir('runs')

    for file_name in os.listdir(file_root):
        save_path=os.path.join(save_file_path,file_name)
        img_root=os.path.join(file_root,file_name)
        for img_name in os.listdir(img_root):
            img_path=os.path.join(img_root,img_name)
            image = transform(Image.open(img_path)).unsqueeze(0).to(device)
            res = inference(image, model)
            # print("Image Tags: ", res[0])

            # print("图像标签: ", res[1])
            img = cv2.imread(img_path)
            N=int(len(res[1])/2)
            r1 = res[1][:N]
            r2 = res[1][N:]
            # r3 = res[1][2*N:]
            img = cv2ImgAddText(img, r1, 40, 50, textColor=(255, 0, 0), textSize=20)
            img = cv2ImgAddText(img, r2, 40, 200, textColor=(255, 0, 0), textSize=20)
            # img = cv2ImgAddText(img, r2, 40, 300, textColor=(255, 0, 0), textSize=40)

            build_dir(save_path)
            cv2.imwrite(os.path.join(save_path, img_name), img)



本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/185987.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

大数据技术之数据安全与网络安全——CMS靶场实训

大数据技术之数据安全与网络安全——CMS靶场实训 在当今数字化时代,大数据技术的迅猛发展带来了前所未有的数据增长,同时也催生了对数据安全和网络安全的更为迫切的需求。本篇博客将聚焦于大数据技术背景下的数据安全与网络安全,并通过CMS&a…

4.操作系统常见面试题(2)

3.4 虚拟内存 直接使⽤物理内存会产⽣⼀些问题 1. 内存空间利⽤率的问题:各个进程对内存的使⽤会导致内存碎⽚化,当要⽤ malloc 分配⼀块很⼤的内存空间时,可能会出现虽然有⾜够多的空闲物理内存,却没有⾜够⼤的连续空闲内存这种…

点大商城V2.5.3分包小程序端+小程序上传提示限制分包制作教程

这几天很多播播资源会员反馈点大商城V2.5.3小程序端上传时提示大小超限,官方默认单个包都不能超过2M,总分包不能超20M。如下图提示超了93KB,如果出现超的不多情况下可采用手动删除一些images目录下不使用的图片,只要删除超过100KB…

82基于matlab GUI的图像处理

基于matlab GUI的图像处理,功能包括图像一般处理(灰度图像、二值图);图像几何变换(旋转可输入旋转角度、平移、镜像)、图像边缘检测(拉普拉斯算子、sobel算子、wallis算子、roberts算子&#xf…

unordered_map 与 unordered_set 的模拟实现

unordered_map 与 unordred_set 的模拟实现与 map 与 set 的模拟实现差不多。map 与 set 的模拟实现中,底层的数据结构是红黑树。unordered_map 与 unordered_set 的底层数据结构是哈希表。因此,在模拟实现 unordered_map 与 unordred_set 之前你必须确保…

nodejs微信小程序+python+PHP-青云商场管理系统的设计与实现-安卓-计算机毕业设计

目 录 摘 要 I ABSTRACT II 目 录 II 第1章 绪论 1 1.1背景及意义 1 1.2 国内外研究概况 1 1.3 研究的内容 1 第2章 相关技术 3 2.1 nodejs简介 4 2.2 express框架介绍 6 2.4 MySQL数据库 4 第3章 系统分析 5 3.1 需求分析 5 3.2 系统可行性分析 5 3.2.1技术可行性:…

org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder

密码,加密,解密 spring-security-crypto-5.7.3.jar /** Copyright 2002-2011 the original author or authors.** Licensed under the Apache License, Version 2.0 (the "License");* you may not use this file except in compliance with t…

HTML新特性【缩放图像、图像切片、平移、旋转、缩放、变形、裁切路径、时钟、运动的小球】(二)-全面详解(学习总结---从入门到深化)

目录 绘制图像_缩放图像 绘制图像_图像切片 Canvas状态的保存和恢复 图形变形_平移 图形变形_旋转 图形变形_缩放 图形变形_变形 裁切路径 动画_时钟 动画_运动的小球 引入外部SVG 绘制图像_缩放图像 ctx.drawImage(img, x, y, width, height) img &#xf…

开源与闭源

我的观点: 开源与闭源软件都有各自的优势和劣势,没有绝对的对错之分。.. 一、开源和闭源的优劣势比较 开源的好处与劣处 优势: 创新与合作:开源软件能够吸引更多的开发者参与到项目中来,促进创新和合作。开放的源代码…

【网易云商】构建高效 SaaS 系统的技术要点与最佳实践

SaaS 是什么 定义 相信大家都对云服务中的 IaaS、PaaS、SaaS 早就有所耳闻,现在更是衍生出了 aPaaS、iPaaS、DaaS 等等的类似概念。对于 SaaS 也有各种各样的定义,本文给出的定义是: SaaS 是一种基于互联网提供服务和软件的交付模式&#xf…

一文彻底看懂Python切片,Python切片理解与操作

1.什么是切片 切片是Python中一种用于操作序列类型(如列表、字符串和元组)的方法。它通过指定起始索引和结束索引来截取出序列的一部分,形成一个新的序列。切片是访问特定范围内的元素,就是一个Area。 说个笑话:切片不是切片,而是切片,但是又是切片。大家理解下呢(末…

80C51单片机----数据传送类指令

目录 一.一般传送指令,即mov指令 1.16位传送(仅1条) 2.8位传送 (1)目的字节为A(累加器) (2)目的字节为Rn(工作寄存器) (3)目的字节为direct…

java中的String.format()方法详解

介绍 String.format() 是 Java 中的一个字符串格式化方法,它用于生成指定格式的字符串。这个方法可以接受一个或多个参数,并将它们按照指定的格式插入到字符串中。它使用了类似于 C 语言中的 printf 函数的语法。 String.format() 方法的使用格式如下&…

Tars框架 Tars-Go 学习

Tars 框架安装 网上安装教程比较多,官方可以参数这个 TARS官方文档 (tarsyun.com) 本文主要介绍部署应用。 安装完成后Tars 界面 增加应用amc 部署申请 amc.GoTestServer.GoTestObj 名称不知道的可以参考自己创建的app config 点击刷新可以看到自己部署的应用 服…

微机原理_3

一、单项选择题(本大题共15小题,每小题3分,共45分。在每小题给出的四个备选项中,选出一个正确的答案,请将选定的答案填涂在答题纸的相应位置上。) 在 8086 微机系统中,完成对指令译码操作功能的部件是()。 A. EU B. BIU C. SRAM D. DRAM 使计算机执行某…

【Rust日报】2023-11-22 Floneum -- 基于 Rust 的一款用于 AI 工作流程的图形编辑器

Floneum -- 基于 Rust 的一款用于 AI 工作流程的图形编辑器 Floneum 是一款用于 AI 工作流程的图形编辑器,专注于社区制作的插件、本地 AI 和安全性。 Floneum 有哪些特性: 可视化界面:您无需任何编程知识即可使用Floneum。可视化图形编辑器可…

2023年金融信创行业研究报告

第一章 行业概况 1.1 定义 金融信创是指在金融行业中应用的信息技术,特别是那些涉及到金融IT基础设施、基础软件、应用软件和信息安全等方面的技术和产品。这一概念源于更广泛的“信创 (信息技术应用创新)”,即通过中国国产信息技术替换海外信息技术&a…

某60区块链安全之未初始化的存储指针实战二学习记录

系列文章目录 文章目录 系列文章目录未初始化的存储指针实战二实验目的实验环境实验工具实验原理实验内容实验过程EXP利用 未初始化的存储指针实战二 实验目的 学会使用python3的web3模块 学会分析以太坊智能合约未初始化的存储指针漏洞 找到合约漏洞进行分析并形成利用 实验…

【Vue】图片切换

上一篇&#xff1a; vue的指令 https://blog.csdn.net/m0_67930426/article/details/134599378?spm1001.2014.3001.5502 本篇所需要的指令有&#xff1a; v-on v-bind v-show <!DOCTYPE html> <html lang"en"> <head><meta charset"…

【微服务专题】SpringBoot自动配置源码解析

目录 前言阅读对象阅读导航前置知识笔记正文0、什么是自动配置0.1 基本概念0.2 SpringBoot中的【约定大于配置】0.3 从SpringMVC看【约定大于配置】0.4 从Redis看【约定大于配置】 一、EnableAutoConfiguration源码解析二、SpringBoot常用条件注解源码解析2.1 自定义条件注解2.…