gdip-yolo项目解读:gdip模块 |mdgip模块 |GDIP regularizer模块的使用分析

gdip-yolo是2022年提出了一个端到端的图像自适应目标检测框架,其论文中的效果展示了良好的图像增强效果。其提出了gdip模块 |mdgip模块 |GDIP regularizer模块等模块,并表明这是效果提升的关键。为此对gdip-yolo的项目进行深入分析。
gdip-yolo的论文可以查阅:https://hpg123.blog.csdn.net/article/details/135658906

在这里插入图片描述

1、整体分析

gdip-yolo项目基于yolov3项目改进所实现,与原始代码相比,仅是删除了训练代码。这里的代码与核心部分gdip功能关联不是很强,其配置文件为常规yolov3训练配置文件。
在这里插入图片描述

1.1 配置文件

这里所展露的是gdip-yolo项目中基于py的方式编写配置文件,与新一代的配置文件格式yaml|yml相比存在一定不足。

# coding=utf-8
# project

DATA_PATH = "/scratch/data"
PROJECT_PATH = "/scratch/"
WEIGHT_PATH="/scratch/data/weights/darknet53_448.weights"

DATA = {"CLASSES":['person','bicycle','car','bus','motorbike'],
        "NUM":5}
#DATA = {"CLASSES":['bicycle','boat','bottle','bus','car','cat','chair','dog','motorbike','person'],
#        "NUM":10}
# model
MODEL = {"ANCHORS":[[(1.25, 1.625), (2.0, 3.75), (4.125, 2.875)],  # Anchors for small obj
            [(1.875, 3.8125), (3.875, 2.8125), (3.6875, 7.4375)],  # Anchors for medium obj
            [(3.625, 2.8125), (4.875, 6.1875), (11.65625, 10.1875)]] ,# Anchors for big obj
         "STRIDES":[8, 16, 32],
         "ANCHORS_PER_SCLAE":3
         }

# train
TRAIN = {
         "TRAIN_IMG_SIZE":448,
         "AUGMENT":True,
         "BATCH_SIZE":8,
         "MULTI_SCALE_TRAIN":False,
         "IOU_THRESHOLD_LOSS":0.5,
         "EPOCHS":80,
         "NUMBER_WORKERS":5,
         "MOMENTUM":0.9,
         "WEIGHT_DECAY":0.0005,
         "LR_INIT":1e-4,
         "LR_END":1e-6,
         "WARMUP_EPOCHS":2  # or None
         }


# test
TEST = {
        "TEST_IMG_SIZE":448,
        "BATCH_SIZE":1,
        "NUMBER_WORKERS":0,
        "CONF_THRESH":0.01,
        "NMS_THRESH":0.5,
        "MULTI_SCALE_TEST":False,
        "FLIP_TEST":False,
        "DATASET_PATH":"/scratch/data/RTTS",
        "DATASET_DIRECTORY":"JPEGImages"
        }

1.2 推理与测试代码

推理代码核心为eval目录下的各类evaluator文件,在项目外面的推理代码仅为for循环调用。通过对eval相关代码进行对比分析,发现针对于各类数据的测试代码处数据预处理部分有差异外其余结构都完全一致。
在这里插入图片描述

from torch.utils.data import DataLoader
import utils.gpu as gpu
from model.yolov3_multilevel_gdip import Yolov3
from tqdm import tqdm
from utils.tools import *
from eval.evaluator_RTTS_GDIP import Evaluator
import argparse
import os
import config.yolov3_config_RTTS as cfg
from utils.visualize import *
from tqdm import tqdm


# import os
# os.environ["CUDA_VISIBLE_DEVICES"]='0'


class Tester(object):
    def __init__(self,
                 weight_path=None,
                 gpu_id=0,
                 img_size=544,
                 visiual=None,
                 eval=False
                 ):
        self.img_size = img_size
        self.__num_class = cfg.DATA["NUM"]
        self.__conf_threshold = cfg.TEST["CONF_THRESH"]
        self.__nms_threshold = cfg.TEST["NMS_THRESH"]
        self.__device = gpu.select_device(gpu_id)
        self.__multi_scale_test = cfg.TEST["MULTI_SCALE_TEST"]
        self.__flip_test = cfg.TEST["FLIP_TEST"]

        self.__visiual = visiual
        self.__eval = eval
        self.__classes = cfg.DATA["CLASSES"]

        self.__model = Yolov3(cfg).to(self.__device)

        self.__load_model_weights(weight_path)

        self.__evalter = Evaluator(self.__model, visiual=False)


    def __load_model_weights(self, weight_path):
        print("loading weight file from : {}".format(weight_path))

        weight = os.path.join(weight_path)
        chkpt = torch.load(weight, map_location=self.__device)
        self.__model.load_state_dict(chkpt)
        # self.__model.load_state_dict(chkpt['model'])
        print("loading weight file is done")
        del chkpt


    def test(self):
        if self.__visiual:
            imgs = os.listdir(self.__visiual)
            for v in tqdm(imgs):
                path = os.path.join(self.__visiual, v)
                # print("test images : {}".format(path))

                img = cv2.imread(path)
                assert img is not None

                bboxes_prd = self.__evalter.get_bbox(img)
                if bboxes_prd.shape[0] != 0:
                    boxes = bboxes_prd[..., :4]
                    class_inds = bboxes_prd[..., 5].astype(np.int32)
                    scores = bboxes_prd[..., 4]

                    visualize_boxes(image=img, boxes=boxes, labels=class_inds, probs=scores, class_labels=self.__classes)
                    path = os.path.join(cfg.PROJECT_PATH, "results/rtts/{}".format(v))
                    cv2.imwrite(path, img)



if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--weight_path', type=str, default='best.pt', help='weight file path')
    parser.add_argument('--visiual', type=str, default='path/to/images', help='test data path or None')
    parser.add_argument('--eval', action='store_true', default=True, help='eval the mAP or not')
    parser.add_argument('--gpu_id', type=int, default=0, help='gpu id')
    opt = parser.parse_args()

    Tester( weight_path=opt.weight_path,
            gpu_id=opt.gpu_id,
            eval=opt.eval,
            visiual=opt.visiual).test()

2、数据加载器分析

在gdip-yolo论文中提到训练时无需专属loss,理论上参照原始的dataloader即可,但是在分析代码时发现针对带雾数据、低亮度数据都有单独的加载器。为此进行源码分析。

2.1 IA_datasets_foggy.py

代码在utils\IA_datasets_foggy.py中,其关键代码如下所示,非foggy相关代码部分被博主删除了。可以看到IA_datasets_foggy中返回了img 与adv_img 图像,adv_img 为img的带雾副本图像(使用getFog函数实现)。这里奇怪的是adv_img 在论文中没有利用,却被返回了。带雾图像的概率为0.5。这里可以看出gdip-yolo使用的是在线数据增强的策略。ia-yolo:使用ASM来生成10个不同级别的雾,以包括在我们的综合训练集中的方差。我们以类似的方式从PascalVOC 2007测试集准备4952张图像(称为V_F_Ts)合成测试集。我们采用了一种混合策略,即混合使用雾和清晰的图像(以2:1的比例),即带雾概率为0.66

class VocDataset(Dataset):
    def __getitem__(self, item):

        img_org,adv_img_org, bboxes_org = self.__parse_annotation(self.__annotations[item])
        img_org = img_org.transpose(2, 0, 1)  # HWC->CHW
        adv_img_org = adv_img_org.transpose(2, 0, 1)  # HWC->CHW


        img,adv_img, bboxes = dataAug.Mixup()(img_org,adv_img_org, bboxes_org)
        del img_org, bboxes_org,adv_img_org


        label_sbbox, label_mbbox, label_lbbox, sbboxes, mbboxes, lbboxes = self.__creat_label(bboxes)

        img = torch.from_numpy(img).float()
        adv_img = torch.from_numpy(adv_img).float()
        label_sbbox = torch.from_numpy(label_sbbox).float()
        label_mbbox = torch.from_numpy(label_mbbox).float()
        label_lbbox = torch.from_numpy(label_lbbox).float()
        sbboxes = torch.from_numpy(sbboxes).float()
        mbboxes = torch.from_numpy(mbboxes).float()
        lbboxes = torch.from_numpy(lbboxes).float()

        return img,adv_img, label_sbbox, label_mbbox, label_lbbox, sbboxes, mbboxes, lbboxes

    def __parse_annotation(self, annotation):
        """
        Data augument.
        :param annotation: Image' path and bboxes' coordinates, categories.
        ex. [image_path xmin,ymin,xmax,ymax,class_ind xmin,ymin,xmax,ymax,class_ind ...]
        :return: Return the enhanced image and bboxes. bbox'shape is [xmin, ymin, xmax, ymax, class_ind]
        """

        anno = annotation.strip().split(' ')

        img_path = anno[0]
        img = cv2.imread(img_path)  # H*W*C and C=BGR
        assert img is not None, 'File Not Found ' + img_path
        bboxes = np.array([list(map(float, box.split(','))) for box in anno[1:]])
        img, bboxes = dataAug.RandomHorizontalFilp()(np.copy(img), np.copy(bboxes))
        img, bboxes = dataAug.RandomCrop()(np.copy(img), np.copy(bboxes))
        img, bboxes = dataAug.RandomAffine()(np.copy(img), np.copy(bboxes))
        adv_img = img.copy()  # H*W*C and C=BGR
        if random.randint(0,2) > 0:
            adv_img = normalize(adv_img)
            fog_img = getFog(adv_img.copy())
            fog_img = fog_img.astype(np.uint8)
            adv_img = fog_img.copy()
        # assert adv_img is not None, 'File Not Found ' + adv_img_path


        
        img, bboxes = dataAug.Resize((self.img_size, self.img_size), True)(np.copy(img), np.copy(bboxes))
        adv_img,_ = dataAug.Resize((self.img_size, self.img_size), True)(np.copy(adv_img), np.copy(bboxes))

        return img,adv_img, bboxes

2.2 IA_datasets_lightning.py

IA_datasets_lightning的代码在utils\IA_datasets_lightning.py中,同样按照惯例删除非关键代码进行分析。IA_datasets_lightning的实现风格与上一份一样,多了一个adv_img,为原始图像的低亮度副本,基于getLightning函数实现(伽玛从1.5到5的范围内均匀采样)。图像低亮度的概率为0.5。这里可以看出gdip-yolo使用的是在线数据增强的策略。ia-yolo: 我们从ExDark中选择具有对象的PascalVOC中的图像,并应用伽玛变化来模拟低光照条件,伽玛从1.5到5的范围内均匀采样。在训练过程中,我们采用混合策略(类似于雾设置),使用黑暗和清晰图像的混合。即0.66的低亮度概率

class VocDataset(Dataset):
    def __getitem__(self, item):

        img_org,adv_img_org, bboxes_org = self.__parse_annotation(self.__annotations[item])
        img_org = img_org.transpose(2, 0, 1)  # HWC->CHW
        adv_img_org = adv_img_org.transpose(2, 0, 1)  # HWC->CHW


        img,adv_img, bboxes = dataAug.Mixup()(img_org,adv_img_org, bboxes_org)
        del img_org, bboxes_org,adv_img_org


        label_sbbox, label_mbbox, label_lbbox, sbboxes, mbboxes, lbboxes = self.__creat_label(bboxes)

        img = torch.from_numpy(img).float()
        adv_img = torch.from_numpy(adv_img).float()
        label_sbbox = torch.from_numpy(label_sbbox).float()
        label_mbbox = torch.from_numpy(label_mbbox).float()
        label_lbbox = torch.from_numpy(label_lbbox).float()
        sbboxes = torch.from_numpy(sbboxes).float()
        mbboxes = torch.from_numpy(mbboxes).float()
        lbboxes = torch.from_numpy(lbboxes).float()

        return img,adv_img, label_sbbox, label_mbbox, label_lbbox, sbboxes, mbboxes, lbboxes



    def __parse_annotation(self, annotation):
        """
        Data augument.
        :param annotation: Image' path and bboxes' coordinates, categories.
        ex. [image_path xmin,ymin,xmax,ymax,class_ind xmin,ymin,xmax,ymax,class_ind ...]
        :return: Return the enhanced image and bboxes. bbox'shape is [xmin, ymin, xmax, ymax, class_ind]
        """

        anno = annotation.strip().split(' ')

        img_path = anno[0]
        img = cv2.imread(img_path)  # H*W*C and C=BGR
        assert img is not None, 'File Not Found ' + img_path
        bboxes = np.array([list(map(float, box.split(','))) for box in anno[1:]])
        img, bboxes = dataAug.RandomHorizontalFilp()(np.copy(img), np.copy(bboxes))
        img, bboxes = dataAug.RandomCrop()(np.copy(img), np.copy(bboxes))
        img, bboxes = dataAug.RandomAffine()(np.copy(img), np.copy(bboxes))
        adv_img = img.copy()  # H*W*C and C=BGR
        if random.randint(0,2) > 0:
            adv_img = normalize(adv_img)
            l_img = getLightning(adv_img.copy())
            l_img = l_img.astype(np.uint8)
            adv_img = l_img.copy()
        # assert adv_img is not None, 'File Not Found ' + adv_img_path


        
        img, bboxes = dataAug.Resize((self.img_size, self.img_size), True)(np.copy(img), np.copy(bboxes))
        adv_img,_ = dataAug.Resize((self.img_size, self.img_size), True)(np.copy(adv_img), np.copy(bboxes))

        return img,adv_img, bboxes

2.3 getFog与getLightning函数

getFog函数实现如下所示,相比于ia-yolo的实现代码行数更多

def getFog(img):
    h,w,c = img.shape
    x = np.linspace(0,w-1,w)
    y = np.linspace(0,h-1,h)
    xx,yy = np.meshgrid(x,y)
    x_c , y_c = w//2 , h//2
    transmission_map = np.zeros((h,w,1))
    c = np.random.uniform(0,9)
    beta = 0.01*c+0.05
    A = 0.5
    d = -0.04 * np.sqrt((yy-y_c)**2+(xx-x_c)**2)+np.sqrt(np.maximum(h,w))
    transmission_map[:,:,0] = np.exp(-beta*d)
    fog_img = img*transmission_map + (1-transmission_map)* A
    # fog_img = normalize(fog_img)
    fog_img = fog_img*255.
    fog_img = np.clip(fog_img,0,255)
    return fog_img

getLightning的实现如下,基于gamma变化实现

def getLightning(img):
    gamma = np.random.uniform(1.5,5)
    img = img**gamma
    img = img*255.
    img = np.clip(img,0,255)
    return img

3、GDIP-yolo关键模块

在GDIP-yolo论文中描述到,没有额外使用loss,故此所开源的loss代码与原始yolov3 loss一模一样。但是在GDIP regularizer模块中需要额外loss(与原始图像计算l1 loss与 mae loss作为正则项),但是没有找到相应实现。

3.1 GatedDIP

在gidp-yolo项目中有多个GatedDIP模块,这里以符合论文中描述的代码为参考。通过代码注释可以看到GatedDIP可以使用vgg16做编码器。这里的GatedDIP内带VisionEncoder。

import math
import torch
import torchvision
from model.vision_encoder import VisionEncoder

class GatedDIP(torch.nn.Module):
    """_summary_

    Args:
        torch (_type_): _description_
    """
    def __init__(self,
                encoder_output_dim : int = 256,
                num_of_gates : int = 7):
        """_summary_

        Args:
            encoder_output_dim (int, optional): _description_. Defaults to 256.
            num_of_gates (int, optional): _description_. Defaults to 7.
        """
        super(GatedDIP,self).__init__()
        print("GatedDIP with custom Encoder!!")

        # Encoder Model
        # self.encoder = torchvision.models.vgg16(pretrained=False)
        self.encoder = VisionEncoder(encoder_output_dim=encoder_output_dim)

        # Gating Module
        self.gate_module = torch.nn.Sequential(torch.nn.Linear(encoder_output_dim,num_of_gates,bias=True))

        # White-Balance Module
        self.wb_module = torch.nn.Sequential(torch.nn.Linear(encoder_output_dim,3,bias=True))

        # Gamma Module
        self.gamma_module = torch.nn.Sequential(torch.nn.Linear(encoder_output_dim,1,bias=True))

        # Sharpning Module
        self.gaussian_blur = torchvision.transforms.GaussianBlur(13, sigma=(0.1, 5.0))
        self.sharpning_module = torch.nn.Sequential(torch.nn.Linear(encoder_output_dim,1,bias=True))

        # De-Fogging Module
        self.defogging_module = torch.nn.Sequential(torch.nn.Linear(encoder_output_dim,1,bias=True))

        # Contrast Module
        self.contrast_module = torch.nn.Sequential(torch.nn.Linear(encoder_output_dim,1,bias=True))

        # Contrast Module
        self.tone_module = torch.nn.Sequential(torch.nn.Linear(encoder_output_dim,8,bias=True))

    def rgb2lum(self,img: torch.tensor):
        """_summary_
        Args:
            img (torch.tensor): _description_
        Returns:
            _type_: _description_
        """
        img = 0.27 * img[:, 0, :, :] + 0.67 * img[:, 1, :,:] + 0.06 * img[:, 2, :, :]
        return img
   
    def lerp(self ,a : int , b : int , l : torch.tensor):
        return (1 - l.unsqueeze(2).unsqueeze(3)) * a + l.unsqueeze(2).unsqueeze(3) * b


    def dark_channel(self,x : torch.tensor):
        """_summary_

        Args:
            x (torch.tensor): _description_

        Returns:
            _type_: _description_
        """
        z = x.min(dim=1)[0].unsqueeze(1)
        return z
   
    def atmospheric_light(self,x : torch.tensor,dark : torch.tensor ,top_k : int=1000):
        """_summary_

        Args:
            x (torch.tensor): _description_
            top_k (int, optional): _description_. Defaults to 1000.

        Returns:
            _type_: _description_
        """
        h,w = x.shape[2],x.shape[3]
        imsz = h * w
        numpx = int(max(math.floor(imsz/top_k),1))
        darkvec = dark.reshape(x.shape[0],imsz,1)
        imvec = x.reshape(x.shape[0],3,imsz).transpose(1,2)
        indices = darkvec.argsort(1)
        indices = indices[:,imsz-numpx:imsz]
        atmsum = torch.zeros([x.shape[0],1,3]).cuda()
        # print(imvec[:,indices[0,0]].shape)
        for b in range(x.shape[0]):
            for ind in range(1,numpx):
                atmsum[b,:,:] = atmsum[b,:,:] + imvec[b,indices[b,ind],:]
        a = atmsum/numpx
        a = a.squeeze(1).unsqueeze(2).unsqueeze(3)
        return a
   
    def blur(self,x : torch.tensor):
        """_summary_

        Args:
            x (torch.tensor): _description_

        Returns:
            _type_: _description_
        """
        return self.gaussian_blur(x)


    def defog(self,x:torch.tensor ,latent_out : torch.tensor ,fog_gate : torch.tensor):
        """Defogging module is used for removing the fog from the image using ASM
        (Atmospheric Scattering Model).
        I(X) = (1-T(X)) * J(X) + T(X) * A(X)
        I(X) => image containing the fog.
        T(X) => Transmission map of the image.
        J(X) => True image Radiance.
        A(X) => Atmospheric scattering factor.

        Args:
            x (torch.tensor): Input image I(X)
            latent_out (torch.tensor): Feature representation from DIP Module.
            fog_gate (torch.tensor): Gate value raning from (0. - 1.) which enables defog module.

        Returns:
            torch.tensor : Returns defogged image with true image radiance.
        """
        omega = self.defogging_module(latent_out).unsqueeze(2).unsqueeze(3)
        omega = self.tanh_range(omega,torch.tensor(0.1),torch.tensor(1.))
        dark_i = self.dark_channel(x)
        a = self.atmospheric_light(x,dark_i)
        i = x/a
        i = self.dark_channel(i)
        t = 1. - (omega*i)
        j = ((x-a)/(torch.maximum(t,torch.tensor(0.01))))+a
        j = (j - j.min())/(j.max()-j.min())
        # j = j* fog_gate.unsqueeze(1).unsqueeze(2).unsqueeze(3)
        return j
       
    def white_balance(self,x : torch.tensor,latent_out : torch.tensor ,wb_gate: torch.tensor):
        """ White balance of the image is predicted using latent output of an encoder.

        Args:
            x (torch.tensor): Input RGB image.
            latent_out (torch.tensor): Output from the last layer of an encoder.
            wb_gate (torch.tensor): White-balance gate used to change the influence of color scaled image.

        Returns:
            torch.tensor: returns White-Balanced image.
        """
        log_wb_range = 0.5
        wb = self.wb_module(latent_out)
        wb = torch.exp(self.tanh_range(wb,-log_wb_range,log_wb_range))
       
        color_scaling = 1./(1e-5 + 0.27 * wb[:, 0] + 0.67 * wb[:, 1] +
        0.06 * wb[:, 2])
        wb = color_scaling.unsqueeze(1)*wb
        wb_out = wb.unsqueeze(2).unsqueeze(3)*x
        wb_out = (wb_out-wb_out.min())/(wb_out.max()-wb_out.min())
        # wb_out = wb_gate.unsqueeze(1).unsqueeze(2).unsqueeze(3)*wb_out
        return wb_out

    def tanh01(self,x : torch.tensor):
        """_summary_

        Args:
            x (torch.tensor): _description_

        Returns:
            _type_: _description_
        """
        return torch.tanh(x)*0.5+0.5

    def tanh_range(self,x : torch.tensor,left : float,right : float):
        """_summary_

        Args:
            x (torch.tensor): _description_
            left (float): _description_
            right (float): _description_

        Returns:
            _type_: _description_
        """
        return self.tanh01(x)*(right-left)+ left

    def gamma_balance(self,x : torch.tensor,latent_out : torch.tensor,gamma_gate : torch.tensor):
        """_summary_

        Args:
            x (torch.tensor): _description_
            latent_out (torch.tensor): _description_
            gamma_gate (torch.tensor): _description_

        Returns:
            _type_: _description_
        """
        log_gamma = torch.log(torch.tensor(2.5))
        gamma = self.gamma_module(latent_out).unsqueeze(2).unsqueeze(3)
        gamma = torch.exp(self.tanh_range(gamma,-log_gamma,log_gamma))
        g = torch.pow(torch.maximum(x,torch.tensor(1e-4)),gamma)
        g = (g-g.min())/(g.max()-g.min())
        # g = g*gamma_gate.unsqueeze(1).unsqueeze(2).unsqueeze(3)
        return g
   
    def sharpning(self,x : torch.tensor,latent_out: torch.tensor,sharpning_gate : torch.tensor):
        """_summary_

        Args:
            x (torch.tensor): _description_
            latent_out (torch.tensor): _description_
            sharpning_gate (torch.tensor): _description_

        Returns:
            _type_: _description_
        """
        out_x = self.blur(x)
        y = self.sharpning_module(latent_out).unsqueeze(2).unsqueeze(3)
        y = self.tanh_range(y,torch.tensor(0.1),torch.tensor(1.))
        s = x + (y*(x-out_x))
        s = (s-s.min())/(s.max()-s.min())
        # s = s * (sharpning_gate.unsqueeze(1).unsqueeze(2).unsqueeze(3))
        return s
   
    def identity(self,x : torch.tensor,identity_gate : torch.tensor):
        """_summary_

        Args:
            x (torch.tensor): _description_
            identity_gate (torch.tensor): _description_

        Returns:
            _type_: _description_
        """
        # x = x*identity_gate.unsqueeze(1).unsqueeze(2).unsqueeze(3)
        return x
   
    def contrast(self,x : torch.tensor,latent_out : torch.tensor,contrast_gate : torch.tensor):
        """_summary_

        Args:
            x (torch.tensor): _description_
            latent_out (torch.tensor): _description_
            contrast_gate (torch.tensor): _description_

        Returns:
            _type_: _description_
        """
        alpha = torch.tanh(self.contrast_module(latent_out))
        luminance = torch.minimum(torch.maximum(self.rgb2lum(x), torch.tensor(0.0)), torch.tensor(1.0)).unsqueeze(1)
        contrast_lum = -torch.cos(math.pi * luminance) * 0.5 + 0.5
        contrast_image = x / (luminance + 1e-6) * contrast_lum
        contrast_image = self.lerp(x, contrast_image, alpha)
        contrast_image = (contrast_image-contrast_image.min())/(contrast_image.max()-contrast_image.min())
        # contrast_image = contrast_image * contrast_gate.unsqueeze(1).unsqueeze(2).unsqueeze(3)
        return contrast_image
   
    def tone(self,x : torch.tensor,latent_out : torch.tensor,tone_gate : torch.tensor):
        """_summary_

        Args:
            x (torch.tensor): _description_
            latent_out (torch.tensor): _description_
            tone_gate (torch.tensor): _description_

        Returns:
            _type_: _description_
        """
        curve_steps = 8
        tone_curve = self.tone_module(latent_out).reshape(-1,1,curve_steps)
        tone_curve = self.tanh_range(tone_curve,0.5, 2)
        tone_curve_sum = torch.sum(tone_curve, dim=2) + 1e-30
        total_image = x * 0
        for i in range(curve_steps):
            total_image += torch.clamp(x - 1.0 * i /curve_steps, 0, 1.0 /curve_steps) \
                            * tone_curve[:,:,i].unsqueeze(2).unsqueeze(3)
        total_image *= curve_steps / tone_curve_sum.unsqueeze(2).unsqueeze(3)
        total_image = (total_image-total_image.min())/(total_image.max()-total_image.min())
        # total_image = total_image * tone_gate.unsqueeze(1).unsqueeze(2).unsqueeze(3)
        return total_image


   
    def forward(self, x : torch.Tensor):
        """_summary_

        Args:
            x (torch.Tensor): _description_

        Returns:
            _type_: _description_
        """
        # latent_out = torch.nn.functional.relu_(self.encoder(x))
        latent_out = self.encoder(x)
        gate = self.tanh_range(self.gate_module(latent_out),0.01,1.0)
        out_idx = gate.argmax(dim=1)

        if out_idx == 0:
            wb_out = self.white_balance(x,latent_out,gate[:,0])
            return wb_out,gate
        elif out_idx == 1:
            gamma_out = self.gamma_balance(x,latent_out,gate[:,1])
            return gamma_out,gate
        elif out_idx == 2:
            identity_out = self.identity(x,gate[:,2])
            return identity_out,gate
        elif out_idx == 3:
            sharpning_out = self.sharpning(x,latent_out,gate[:,3])
            return sharpning_out, gate
        elif out_idx == 4:
            fog_out = self.defog(x,latent_out,gate[:,4])
            return fog_out,gate
        elif out_idx == 5:
            contrast_out = self.contrast(x,latent_out,gate[:,5])
            return contrast_out, gate
        else:
            tone_out = self.tone(x,latent_out,gate[:,6])
            return tone_out,gate

if __name__ == '__main__':
    batch_size = 2
    encoder_out_dim = 256
    x = torch.randn(batch_size,3,448,448)
    x = (x-x.min())/(x.max()-x.min())
    model = GatedDIP(encoder_output_dim = encoder_out_dim)
    print(model)
    out,gate= model(x)
    print('out shape:',out.shape)
    print('gate shape:',gate.shape)

作者论文中所提出的视觉编码器实现如下所示:

import torch 

class VisionEncoder(torch.nn.Module):
    def __init__(self,encoder_output_dim=256):
        super(VisionEncoder,self).__init__()
        # conv_1
        self.conv_1 = torch.nn.Sequential(torch.nn.Conv2d(3,64,kernel_size = 3 , stride = 1),
                                        torch.nn.ReLU(True))
        self.max_pool_1 = torch.nn.AvgPool2d((3,3),(2,2))
        
        # conv_2
        self.conv_2 = torch.nn.Sequential(torch.nn.Conv2d(64,128,kernel_size = 3 , stride = 1),
                                        torch.nn.ReLU(True))
        self.max_pool_2 = torch.nn.AvgPool2d((3,3),(2,2))
        # conv_3
        self.conv_3 = torch.nn.Sequential(torch.nn.Conv2d(128,256,kernel_size = 3 , stride = 1),
                                        torch.nn.ReLU(True))
        self.max_pool_3 = torch.nn.AvgPool2d((3,3),(2,2))
        
        # conv_4
        self.conv_4 = torch.nn.Sequential(torch.nn.Conv2d(256,512,kernel_size = 3 , stride = 1),
                                        torch.nn.ReLU(True))
        self.max_pool_4 = torch.nn.AvgPool2d((3,3),(2,2))
        
        # conv_5
        self.conv_5 = torch.nn.Sequential(torch.nn.Conv2d(512,1024,kernel_size = 3 , stride = 1),
                                        torch.nn.ReLU(True))
        self.adp_pool_5 = torch.nn.AdaptiveAvgPool2d((1,1))
        self.linear_proj_5 = torch.nn.Sequential(torch.nn.Linear(1024,encoder_output_dim),
                                                torch.nn.ReLU(True))
        

    def forward(self,x):
        out_x = self.conv_1(x)
        max_pool_1 = self.max_pool_1(out_x)
        
        out_x = self.conv_2(max_pool_1)
        max_pool_2 = self.max_pool_2(out_x)
        
        out_x = self.conv_3(max_pool_2)
        max_pool_3 = self.max_pool_3(out_x)
        
        out_x = self.conv_4(max_pool_3)
        max_pool_4 = self.max_pool_4(out_x)
        
        out_x = self.conv_5(max_pool_4)
        adp_pool_5 = self.adp_pool_5(out_x)
        linear_proj_5 = self.linear_proj_5(adp_pool_5.view(adp_pool_5.shape[0],-1))


        return linear_proj_5

if __name__ == '__main__':
    img = torch.randn(4,3,448,448).cuda()
    encoder = VisionEncoder(encoder_output_dim=256).cuda()
    print('output shape:',encoder(img).shape) # output should be [4,256]

gdip模块的用法如下所示,可以看到兼容原始yolo框架代码,但多返回了一个增强后的图像与各个DIP操作的权重。

import torch 
from model.gdip_model import GatedDIP
from model.yolov3 import Yolov3

class Yolov3GatedDIP(torch.nn.Module):
    def __init__(self):
        super(Yolov3GatedDIP,self).__init__()
        self.gated_dip = GatedDIP(256)
        self.yolov3 = Yolov3()
        #self.yolov3.load_darknet_weights(weights_path)
    
    def forward(self,x):
        out_x,gates = self.gated_dip(x)
        p,p_d = self.yolov3(out_x)
        return out_x,gates,p,p_d

3.2 MultiLevelGDIP

代码在model\mgdip.py中。在MultiLevelGDIP中又单独实现了gdip,这里可以看到gdip-yolo项目代码比较混乱。在这里的GatedDIP中,没有内置视觉编码器,而是在MultiLevelGDIP中内置视觉编码器,将GatedDIP作为MultiLevelGDIP中的一个部件。同时,mgdip相关的VisionEncoder与gdip中的返回值不一样,为了实现多尺度VisionEncoder返回的是一个dict,其中包含了各个尺度的特征图。

class GatedDIP(torch.nn.Module):
    '''这里删除了与上一份代码类似的部分'''
    def forward(self,x,linear_proj):
        gate = self.tanh_range(self.gate_module(linear_proj),0.01,1.0)
        wb_out = self.white_balance(x,linear_proj,gate[:,0])
        gamma_out = self.gamma_balance(x,linear_proj,gate[:,1])
        identity_out = self.identity(x,gate[:,2])
        sharpning_out = self.sharpning(x,linear_proj,gate[:,3])
        fog_out = self.defog(x,linear_proj,gate[:,4])
        contrast_out = self.contrast(x,linear_proj,gate[:,5])
        tone_out = self.tone(x,linear_proj,gate[:,6])
        x = wb_out + gamma_out   + fog_out + sharpning_out + contrast_out + tone_out + identity_out
        x = (x-x.min())/(x.max()-x.min())
        return x,gate


class MultiLevelGDIP(torch.nn.Module):

    def __init__(self,
                encoder_output_dim : int = 256,
                num_of_gates : int = 7):

        super(MultiLevelGDIP,self).__init__()
        self.vision_encoder = VisionEncoder(encoder_output_dim,base_channel=32)
        self.gdip1 = GatedDIP(encoder_output_dim,num_of_gates)
        self.gdip2 = GatedDIP(encoder_output_dim,num_of_gates)
        self.gdip3 = GatedDIP(encoder_output_dim,num_of_gates)
        self.gdip4 = GatedDIP(encoder_output_dim,num_of_gates)
        self.gdip5 = GatedDIP(encoder_output_dim,num_of_gates)
        self.gdip6 = GatedDIP(encoder_output_dim,num_of_gates)

    
    def forward(self, x : torch.Tensor):
        """_summary_

        Args:
            x (torch.Tensor): _description_

        Returns:
            _type_: _description_
        """
        out_image = list()
        gates_list = list()
        
        output_dict = self.vision_encoder(x)

        x,gate_6 = self.gdip6(x,output_dict['linear_proj_6'])
        out_image.append(x)
        gates_list.append(gate_6)

        x,gate_5 = self.gdip5(x,output_dict['linear_proj_5'])
        out_image.append(x)
        gates_list.append(gate_5)

        x,gate_4 = self.gdip4(x,output_dict['linear_proj_4'])
        out_image.append(x)
        gates_list.append(gate_4)

        x,gate_3 = self.gdip3(x,output_dict['linear_proj_3'])
        out_image.append(x)
        gates_list.append(gate_3)

        x,gate_2 = self.gdip2(x,output_dict['linear_proj_2'])
        out_image.append(x)
        gates_list.append(gate_2)

        x,gate_1 = self.gdip1(x,output_dict['linear_proj_1'])
        out_image.append(x)
        gates_list.append(gate_1)

        return x,out_image,gates_list

其使用代码如下所示:

import torch 
from model.mgdip import MultiLevelGDIP
from model.yolov3 import Yolov3

class Yolov3MGatedDIP(torch.nn.Module):
    def __init__(self):
        super(Yolov3MGatedDIP,self).__init__()
        self.mgdip = MultiLevelGDIP(256,7)
        self.yolov3 = Yolov3()
    
    def forward(self,x):
        out_x,_,gates_list = self.mgdip(x)
        p,p_d = self.yolov3(out_x)
        return out_x,gates_list,p,p_d

3.3 GDIP regularizer

代码在model\yolov3_multilevel_gdip.py,这里是mgdip regularizer模块。同样与上一份代码中的MultiLevelGDIP有所差别,这里的MultiLevelGDIP没有内置视觉编码器,而是获取Yolov3 backbone的3个尺度的输出+ 原始输入作为特征图输入MultiLevelGDIP(即使用Yolov3 backbone作为特征提取器)。其关键代码如下所示,同时发现MultiLevelGDIP没有使用预训练模型,而论文中提到MultiLevelGDIP可以作为正则化器。正则化器的loss通过训练后使Yolov3 backbone提取的特征与MultiLevelGDIP提取的一样。在这里MultiLevelGDIP参与训练,但其输出的值又不参与前向传播。

class Yolov3(nn.Module):
    """
    Note : int the __init__(), to define the modules should be in order, because of the weight file is order
    """
    def __init__(self, cfg, init_weights=True):
        super(Yolov3, self).__init__()

        self.__anchors = torch.FloatTensor(cfg.MODEL["ANCHORS"])
        self.__strides = torch.FloatTensor(cfg.MODEL["STRIDES"])
        self.__nC = cfg.DATA["NUM"]
        self.__out_channel = cfg.MODEL["ANCHORS_PER_SCLAE"] * (self.__nC + 5)

        self.__backnone = Darknet53()
        self.__fpn = FPN_YOLOV3(fileters_in=[1024, 512, 256],
                                fileters_out=[self.__out_channel, self.__out_channel, self.__out_channel])

        # small
        self.__head_s = Yolo_head(nC=self.__nC, anchors=self.__anchors[0], stride=self.__strides[0])
        
        # medium
        self.__head_m = Yolo_head(nC=self.__nC, anchors=self.__anchors[1], stride=self.__strides[1])

        # large
        self.__head_l = Yolo_head(nC=self.__nC, anchors=self.__anchors[2], stride=self.__strides[2])

        # multilevel gdip
        self.__multilevel_gdip = MultiLevelGDIP()


        if init_weights:
            self.__init_weights()


    def forward(self, x):
        out = []

        x_s, x_m, x_l = self.__backnone(x)
        out_x,img_list,gates_list = self.__multilevel_gdip(x,x_s,x_m,x_l)
        x_s, x_m, x_l = self.__fpn(x_l, x_m, x_s)

        out.append(self.__head_s(x_s))
        out.append(self.__head_m(x_m))
        out.append(self.__head_l(x_l))

        if self.training:
            p, p_d = list(zip(*out))
            return out_x,gates_list[-1],p, p_d  # smalll, medium, large
        else:
            p, p_d = list(zip(*out))
            return out_x,gates_list[-1],p, torch.cat(p_d, 0)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/340606.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

ARM 驱动 1.22

linux内核等待队列wait_queue_head_t 头文件 include <linux/wait.h> 定义并初始化 wait_queue_head_t r_wait; init_waitqueue_head(&cm_dev->r_wait); wait_queue_head_t 表示等待队列头&#xff0c;等待队列wait时&#xff0c;会导致进程或线程被休眠&…

springsecurity集成kaptcha功能

前端代码 本次采用简单的html静态页面作为演示&#xff0c;也可结合vue前后端分离开发&#xff0c;复制就可运行测试 项目目录 登录界面 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>Title</…

详谈c++智能指针!!!

文章目录 前言一、智能指针的发展历史1.C 98/03 的尝试——std::auto_ptr2.std::unique_ptr3.std::shared_ptr4.std::weak_ptr5.智能指针的大小6.智能指针使用注意事项 二、智能指针的模拟实现三、C11和boost中智能指针的关系 前言 C/C 语言最为人所诟病的特性之一就是存在内存…

Quartus II使用小技巧

工程结构&#xff1a; 在建立完某项设计的文件后&#xff0c;依次在其里面新建四个文件夹&#xff0c;分别为&#xff1a;rtl、qprj、msim、doc。 rtl文件夹用于存放设计的源文件。 doc文件夹用于存放设计的一些文档性的资料。 qprj文件夹用于存放quaruts 工程以及quartus生…

陪玩系统:最新商业版游戏陪玩语音聊天系统3.0商业升级独立版本源码

首发价值29800元的最新商业版游戏陪玩语音聊天系统3.0商业升级独立版本源码 &#xff08;价值29800&#xff09;最新陪玩3.0独立版本 &#xff0c;文件截图 结尾将会附上此系统源码以及详细搭建教程包含素材图仅用于学习使用 陪玩系统3.0独立升级版正式发布&#xff0c;此版本…

项目管理中如何有效沟通?项目管理有效沟通指南

无论是少数人的小型企业还是拥有数十名员工的大公司&#xff0c;有效的沟通对于确保每个人都参与并准备好在项目中实现相同的目标至关重要。 然而&#xff0c;由于沟通不畅&#xff0c;似乎在翻译中总是丢失一些东西。事实上&#xff0c;根据布兰迪斯大学的一项研究&#xff0c…

【复现】SpringBlade SQL 注入漏洞_22

目录 一.概述 二 .漏洞影响 三.漏洞复现 1. 漏洞一&#xff1a; 四.修复建议&#xff1a; 五. 搜索语法&#xff1a; 六.免责声明 一.概述 SpringBlade 是由一个商业级项目升级优化而来的SpringCloud微服务架构&#xff0c;采用Java8 API重构了业务代码&#xff0c;完全…

一文梳理Windows自启动位置

不同版本的Windows开机自启动的位置略有出入&#xff0c;一般来说&#xff0c;Windows自启动的位置有&#xff1a;自启动文件夹、注册表子键、自动批处理文件、系统配置文件等。如果计算机感染了木马&#xff0c;很有可能就潜伏于其中&#xff01;本文将说明这些常见的Windows开…

GitHub README-Template.md - README.md 模板

GitHub README-Template.md - README.md 模板 1. README-Template.md 预览模式2. README-Template.md 编辑模式References A template to make good README.md. https://gist.github.com/PurpleBooth/109311bb0361f32d87a2 1. README-Template.md 预览模式 2. README-Templat…

CHS_02.2.2.2+调度的目标 调度算法的评价指标

CHS_02.2.2.2调度的目标 调度算法的评价指标 知识总览CPU利用率系统吞吐量周转时间等待时间响应时间 知识回顾 在这个小节中 我们会学习一系列用于评价一个调度算法好坏的一些评价指标 知识总览 包括cpu利用率 系统吞吐量 周转时间 等待时间和响应时间 那在学习的过程中 要注意…

20240122在WIN10+GTX1080下使用字幕小工具V1.2的使用总结(whisper)

20240122在WIN10GTX1080下使用字幕小工具V1.2的使用总结 2024/1/22 19:52 结论&#xff1a;这个软件如果是习作&#xff0c;可以打101分&#xff0c;功能都实现了。 如果作为商业软件/共享软件&#xff0c;在易用性等方面&#xff0c;可能就只能有70分了。 【百分制】 可选的改…

makefile 编译动态链接库使用(.so库文件)

makefile 编译动态链接库使用&#xff08;.so库文件&#xff09; 动态链接库:不会把代码编译到二进制文件中&#xff0c;而是在运行时才去加载&#xff0c; 好处是程序可以和库文件分离&#xff0c;可以分别发版&#xff0c;然后库文件可以被多处共享 动态链接库 动态&#…

macbookpro怎么恢复出厂设置2024最新恢复方法汇总

可能你的MacBook曾经是高性能的代表&#xff0c;但是现在它正慢慢地逝去了自己的光芒&#xff1f;随着逐年的使用以及文件的添加和程序的安装&#xff0c;你的MacBook可能会开始变得迟缓卡顿&#xff0c;或者失却了以往的光彩。如果你发现你的Mac开始出现这些严重问题&#xff…

牛客周赛 Round 20 解题报告 | 珂学家 | 状压DP/矩阵幂优化 + 前缀和的前缀和

前言 整体评价 这场比赛很特别&#xff0c;是牛客周赛的第20场&#xff0c;后两题难度直线飙升了。 前四题相对简单&#xff0c;E题是道状压题&#xff0c;历来状压题都难&#xff0c;F题压轴难题了&#xff0c;感觉学到了不少。 A. 赝品 先求的最大值 然后统计非最大值的个…

Haar小波下采样模块

论文原址&#xff1a;Haar wavelet downsampling: A simple but effective downsampling module for semantic segmentation - ScienceDirect 原文代码&#xff1a;HWD/HWD.py at main apple1986/HWD (github.com) 介绍 深度卷积神经网络 &#xff08;DCNN&#xff09; 通…

CPMS靶场练习

关键&#xff1a;找到文件上传点&#xff0c;分析对方验证的手段 首先查看前端发现没有任何上传的位置&#xff0c;找到网站的后台&#xff0c;通过弱口令admin 123456可以进入 通过查看网站内容发现只有文章列表可以进行文件上传&#xff1b;有两个图片上传点 图片验证很严格…

《WebKit 技术内幕》学习之六(2): CSS解释器和样式布局

2 CSS解释器和规则匹配 在了解了CSS的基本概念之后&#xff0c;下面来理解WebKit如何来解释CSS代码并选择相应的规则。通过介绍WebKit的主要设施帮助理解WebKit的内部工作原理和机制。 2.1 样式的WebKit表示类 在DOM树中&#xff0c;CSS样式可以包含在“style”元素中或者使…

SpringBoot异常处理和单元测试

学习目标 Spring Boot 异常处理Spring Boot 单元测试 1.SpringBoot异常处理 1.1.自定义错误页面 SpringBoot默认的处理异常的机制&#xff1a;SpringBoot 默认的已经提供了一套处理异常的机制。一旦程序中出现了异常 SpringBoot 会向/error 的 url 发送请求。在 springBoot…

移动开发行业——鸿蒙OS NEXT开出繁花

1月18日&#xff0c;华为宣布HarmonyOS NEXT开发者预览版开放申请&#xff0c;根据官方注解&#xff0c;这个版本的鸿蒙系统有个更通俗易懂的名字——“星河版”&#xff0c;也被称为“纯血”鸿蒙。 根据官方解释&#xff0c;之所以取名星河版&#xff0c;寓意鸿蒙OS NEXT就像…

28、web攻防——通用漏洞SQL注入HTTP头XFFCOOKIEPOST请求

文章目录 $_GET&#xff1a;接收get请求&#xff0c;传输少量数据&#xff0c;URL是有长度限制的&#xff1b; $_POST&#xff1a;接收post请求&#xff1b; $_COOKIE&#xff1a;接收cookie&#xff0c;用于身份验证&#xff1b; $_REQUEST&#xff1a;收集通过 GET 、POST和C…
最新文章