基于深度学习神经网络的AI图片上色DDcolor系统源码

第一步:DDcolor介绍

        DDColor 是最新的 SOTA 图像上色算法,能够对输入的黑白图像生成自然生动的彩色结果,使用 UNet 结构的骨干网络和图像解码器分别实现图像特征提取和特征图上采样,并利用 Transformer 结构的颜色解码器完成基于视觉语义的颜色查询,最终聚合输出彩色通道预测结果。

        它甚至可以对动漫游戏中的风景进行着色/重新着色,将您的动画风景转变为逼真的现实生活风格!(图片来源:原神)

第二步:DDcolor网络结构

        算法整体流程如下图,使用 UNet 结构的骨干网络和图像解码器分别实现图像特征提取和特征图上采样,并利用 Transformer 结构的颜色解码器完成基于视觉语义的颜色查询,最终聚合输出彩色通道预测结果。

第三步:模型代码展示

import os
import torch
from collections import OrderedDict
from os import path as osp
from tqdm import tqdm
import numpy as np

from basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.metrics import calculate_metric
from basicsr.utils import get_root_logger, imwrite, tensor2img
from basicsr.utils.img_util import tensor_lab2rgb
from basicsr.utils.dist_util import master_only
from basicsr.utils.registry import MODEL_REGISTRY
from .base_model import BaseModel
from basicsr.metrics.custom_fid import INCEPTION_V3_FID, get_activations, calculate_activation_statistics, calculate_frechet_distance
from basicsr.utils.color_enhance import color_enhacne_blend


@MODEL_REGISTRY.register()
class ColorModel(BaseModel):
    """Colorization model for single image colorization."""

    def __init__(self, opt):
        super(ColorModel, self).__init__(opt)

        # define network net_g
        self.net_g = build_network(opt['network_g'])
        self.net_g = self.model_to_device(self.net_g)
        self.print_network(self.net_g)
        
        # load pretrained model for net_g
        load_path = self.opt['path'].get('pretrain_network_g', None)
        if load_path is not None:
            param_key = self.opt['path'].get('param_key_g', 'params')
            self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)

        if self.is_train:
            self.init_training_settings()

    def init_training_settings(self):
        train_opt = self.opt['train']

        self.ema_decay = train_opt.get('ema_decay', 0)
        if self.ema_decay > 0:
            logger = get_root_logger()
            logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
            # define network net_g with Exponential Moving Average (EMA)
            # net_g_ema is used only for testing on one GPU and saving
            # There is no need to wrap with DistributedDataParallel
            self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
            # load pretrained model
            load_path = self.opt['path'].get('pretrain_network_g', None)
            if load_path is not None:
                self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
            else:
                self.model_ema(0)  # copy net_g weight
            self.net_g_ema.eval()

        # define network net_d
        self.net_d = build_network(self.opt['network_d'])
        self.net_d = self.model_to_device(self.net_d)
        self.print_network(self.net_d)

        # load pretrained model for net_d
        load_path = self.opt['path'].get('pretrain_network_d', None)
        if load_path is not None:
            param_key = self.opt['path'].get('param_key_d', 'params')
            self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key)

        self.net_g.train()
        self.net_d.train()

        # define losses
        if train_opt.get('pixel_opt'):
            self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
        else:
            self.cri_pix = None

        if train_opt.get('perceptual_opt'):
            self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
        else:
            self.cri_perceptual = None

        if train_opt.get('gan_opt'):
            self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
        else:
            self.cri_gan = None

        if self.cri_pix is None and self.cri_perceptual is None:
            raise ValueError('Both pixel and perceptual losses are None.')

        if train_opt.get('colorfulness_opt'):
            self.cri_colorfulness = build_loss(train_opt['colorfulness_opt']).to(self.device)
        else:
            self.cri_colorfulness = None

        # set up optimizers and schedulers
        self.setup_optimizers()
        self.setup_schedulers()

        # set real dataset cache for fid metric computing
        self.real_mu, self.real_sigma = None, None
        if self.opt['val'].get('metrics') is not None and self.opt['val']['metrics'].get('fid') is not None:
            self._prepare_inception_model_fid()

    def setup_optimizers(self):
        train_opt = self.opt['train']
        # optim_params_g = []
        # for k, v in self.net_g.named_parameters():
        #     if v.requires_grad:
        #         optim_params_g.append(v)
        #     else:
        #         logger = get_root_logger()
        #         logger.warning(f'Params {k} will not be optimized.')
        optim_params_g = self.net_g.parameters()

        # optimizer g
        optim_type = train_opt['optim_g'].pop('type')
        self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g'])
        self.optimizers.append(self.optimizer_g)

        # optimizer d
        optim_type = train_opt['optim_d'].pop('type')
        self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
        self.optimizers.append(self.optimizer_d)
    
    def feed_data(self, data):
        self.lq = data['lq'].to(self.device)
        self.lq_rgb = tensor_lab2rgb(torch.cat([self.lq, torch.zeros_like(self.lq), torch.zeros_like(self.lq)], dim=1))
        if 'gt' in data:
            self.gt = data['gt'].to(self.device)
            self.gt_lab = torch.cat([self.lq, self.gt], dim=1)
            self.gt_rgb = tensor_lab2rgb(self.gt_lab)

            if self.opt['train'].get('color_enhance', False):
                for i in range(self.gt_rgb.shape[0]):
                    self.gt_rgb[i] = color_enhacne_blend(self.gt_rgb[i], factor=self.opt['train'].get('color_enhance_factor'))

    def optimize_parameters(self, current_iter):
        # optimize net_g
        for p in self.net_d.parameters():
            p.requires_grad = False
        self.optimizer_g.zero_grad()
        
        self.output_ab = self.net_g(self.lq_rgb)
        self.output_lab = torch.cat([self.lq, self.output_ab], dim=1)
        self.output_rgb = tensor_lab2rgb(self.output_lab)

        l_g_total = 0
        loss_dict = OrderedDict()
        # pixel loss
        if self.cri_pix:
            l_g_pix = self.cri_pix(self.output_ab, self.gt)
            l_g_total += l_g_pix
            loss_dict['l_g_pix'] = l_g_pix

        # perceptual loss
        if self.cri_perceptual:
            l_g_percep, l_g_style = self.cri_perceptual(self.output_rgb, self.gt_rgb)
            if l_g_percep is not None:
                l_g_total += l_g_percep
                loss_dict['l_g_percep'] = l_g_percep
            if l_g_style is not None:
                l_g_total += l_g_style
                loss_dict['l_g_style'] = l_g_style
        # gan loss
        if self.cri_gan:
            fake_g_pred = self.net_d(self.output_rgb)
            l_g_gan = self.cri_gan(fake_g_pred, target_is_real=True, is_disc=False)
            l_g_total += l_g_gan
            loss_dict['l_g_gan'] = l_g_gan
        # colorfulness loss
        if self.cri_colorfulness:
            l_g_color = self.cri_colorfulness(self.output_rgb)
            l_g_total += l_g_color
            loss_dict['l_g_color'] = l_g_color

        l_g_total.backward()
        self.optimizer_g.step()

        # optimize net_d
        for p in self.net_d.parameters():
            p.requires_grad = True
        self.optimizer_d.zero_grad()

        real_d_pred = self.net_d(self.gt_rgb)
        fake_d_pred = self.net_d(self.output_rgb.detach())
        l_d = self.cri_gan(real_d_pred, target_is_real=True, is_disc=True) + self.cri_gan(fake_d_pred, target_is_real=False, is_disc=True)
        loss_dict['l_d'] = l_d
        loss_dict['real_score'] = real_d_pred.detach().mean()
        loss_dict['fake_score'] = fake_d_pred.detach().mean()

        l_d.backward()
        self.optimizer_d.step()

        self.log_dict = self.reduce_loss_dict(loss_dict)

        if self.ema_decay > 0:
            self.model_ema(decay=self.ema_decay)

    def get_current_visuals(self):
        out_dict = OrderedDict()
        out_dict['lq'] = self.lq_rgb.detach().cpu()
        out_dict['result'] = self.output_rgb.detach().cpu()
        if self.opt['logger'].get('save_snapshot_verbose', False):  # only for verbose
            self.output_lab_chroma = torch.cat([torch.ones_like(self.lq) * 50, self.output_ab], dim=1)
            self.output_rgb_chroma = tensor_lab2rgb(self.output_lab_chroma)
            out_dict['result_chroma'] = self.output_rgb_chroma.detach().cpu()

        if hasattr(self, 'gt'):
            out_dict['gt'] = self.gt_rgb.detach().cpu()
            if self.opt['logger'].get('save_snapshot_verbose', False):  # only for verbose
                self.gt_lab_chroma = torch.cat([torch.ones_like(self.lq) * 50, self.gt], dim=1)
                self.gt_rgb_chroma = tensor_lab2rgb(self.gt_lab_chroma)
                out_dict['gt_chroma'] = self.gt_rgb_chroma.detach().cpu()
        return out_dict

    def test(self):
        if hasattr(self, 'net_g_ema'):
            self.net_g_ema.eval()
            with torch.no_grad():
                self.output_ab = self.net_g_ema(self.lq_rgb)
                self.output_lab = torch.cat([self.lq, self.output_ab], dim=1)
                self.output_rgb = tensor_lab2rgb(self.output_lab)
        else:
            self.net_g.eval()
            with torch.no_grad():
                self.output_ab = self.net_g(self.lq_rgb)
                self.output_lab = torch.cat([self.lq, self.output_ab], dim=1)
                self.output_rgb = tensor_lab2rgb(self.output_lab)
            self.net_g.train()
    
    def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
        if self.opt['rank'] == 0:
            self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
    
    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        dataset_name = dataloader.dataset.opt['name']
        with_metrics = self.opt['val'].get('metrics') is not None
        use_pbar = self.opt['val'].get('pbar', False)

        if with_metrics and not hasattr(self, 'metric_results'):  # only execute in the first run
            self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
        # initialize the best metric results for each dataset_name (supporting multiple validation datasets)
        if with_metrics:
            self._initialize_best_metric_results(dataset_name)
        # zero self.metric_results
        if with_metrics:
            self.metric_results = {metric: 0 for metric in self.metric_results}

        metric_data = dict()
        if use_pbar:
            pbar = tqdm(total=len(dataloader), unit='image')
        
        if self.opt['val']['metrics'].get('fid') is not None:
            fake_acts_set, acts_set = [], []

        for idx, val_data in enumerate(dataloader):
            # if idx == 100:
            #     break
            img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
            if hasattr(self, 'gt'):
                del self.gt
            self.feed_data(val_data)
            self.test()

            visuals = self.get_current_visuals()
            sr_img = tensor2img([visuals['result']])
            metric_data['img'] = sr_img
            if 'gt' in visuals:
                gt_img = tensor2img([visuals['gt']])
                metric_data['img2'] = gt_img

            torch.cuda.empty_cache()

            if save_img:
                if self.opt['is_train']:
                    save_dir = osp.join(self.opt['path']['visualization'], img_name)
                    for key in visuals:
                        save_path = os.path.join(save_dir, '{}_{}.png'.format(current_iter, key))
                        img = tensor2img(visuals[key])
                        imwrite(img, save_path)
                else:
                    if self.opt['val']['suffix']:
                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                                 f'{img_name}_{self.opt["val"]["suffix"]}.png')
                    else:
                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                                 f'{img_name}_{self.opt["name"]}.png')
                    imwrite(sr_img, save_img_path)

            if with_metrics:
                # calculate metrics
                for name, opt_ in self.opt['val']['metrics'].items():
                    if name == 'fid':
                        pred, gt = visuals['result'].cuda(), visuals['gt'].cuda()
                        fake_act = get_activations(pred, self.inception_model_fid, 1)
                        fake_acts_set.append(fake_act)
                        if self.real_mu is None:
                            real_act = get_activations(gt, self.inception_model_fid, 1)
                            acts_set.append(real_act)
                    else:
                        self.metric_results[name] += calculate_metric(metric_data, opt_)
            if use_pbar:
                pbar.update(1)
                pbar.set_description(f'Test {img_name}')
        if use_pbar:
            pbar.close()

        if with_metrics:
            if self.opt['val']['metrics'].get('fid') is not None:
                if self.real_mu is None:
                    acts_set = np.concatenate(acts_set, 0)
                    self.real_mu, self.real_sigma = calculate_activation_statistics(acts_set)
                fake_acts_set = np.concatenate(fake_acts_set, 0)
                fake_mu, fake_sigma = calculate_activation_statistics(fake_acts_set)

                fid_score = calculate_frechet_distance(self.real_mu, self.real_sigma, fake_mu, fake_sigma)
                self.metric_results['fid'] = fid_score

            for metric in self.metric_results.keys():
                if metric != 'fid':
                    self.metric_results[metric] /= (idx + 1)
                # update the best metric result
                self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter)

            self._log_validation_metric_values(current_iter, dataset_name, tb_logger)

    def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
        log_str = f'Validation {dataset_name}\n'
        for metric, value in self.metric_results.items():
            log_str += f'\t # {metric}: {value:.4f}'
            if hasattr(self, 'best_metric_results'):
                log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
                            f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
            log_str += '\n'

        logger = get_root_logger()
        logger.info(log_str)
        if tb_logger:
            for metric, value in self.metric_results.items():
                tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter)

    def _prepare_inception_model_fid(self, path='pretrain/inception_v3_google-1a9a5a14.pth'):
        incep_state_dict = torch.load(path, map_location='cpu')
        block_idx = INCEPTION_V3_FID.BLOCK_INDEX_BY_DIM[2048]
        self.inception_model_fid = INCEPTION_V3_FID(incep_state_dict, [block_idx])
        self.inception_model_fid.cuda()
        self.inception_model_fid.eval()

    @master_only
    def save_training_images(self, current_iter):
        visuals = self.get_current_visuals()
        save_dir = osp.join(self.opt['root_path'], 'experiments', self.opt['name'], 'training_images_snapshot')
        os.makedirs(save_dir, exist_ok=True)

        for key in visuals:
            save_path = os.path.join(save_dir, '{}_{}.png'.format(current_iter, key))
            img = tensor2img(visuals[key])
            imwrite(img, save_path)

    def save(self, epoch, current_iter):
        if hasattr(self, 'net_g_ema'):
            self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
        else:
            self.save_network(self.net_g, 'net_g', current_iter)
        self.save_network(self.net_d, 'net_d', current_iter)
        self.save_training_state(epoch, current_iter)

第四步:运行

第五步:整个工程的内容

代码的下载路径(新窗口打开链接)基于深度学习神经网络的AI图片上色DDcolor系统源码

有问题可以私信或者留言,有问必答

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/579284.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

循环单链表的介绍与操作

定义 区别 链表合并 整合代码 typedef struct node{int data;node* next;; }lnode,*linklist; lnode* n; linklist l;//定义 void init(linklist &l){lnode lnew lnode;l->nextl;lnode *rl; } //单循环链表的合并 linklist merge(linklist &a,linklist b){//存头结…

注意力机制:SENet详解

SENet(Squeeze-and-Excitation Networks)是2017年提出的一种经典的通道注意力机制,这种注意力可以让网络更加专注于一些重要的featuremap,它通过对特征通道间的相关性进行建模,把重要的特征图进行强化来提升模型的性能…

ZISUOJ 高级语言程序设计实训-基础C(部分题)

说明&#xff1a; 有几个题是不会讲的&#xff0c;我只能保证大家拿保底分。 题目列表&#xff1a; 问题 A: 求平均数1 思路&#xff1a; 送分题…… 参考题解&#xff1a; #include <iostream> #include <iomanip> using std::cin; using std::cout;int main(…

C++项目在Linux下编译动态库

一、说明 最近在Windows下开发了一个C线程池项目&#xff0c;准备移植到Linux下&#xff0c;并且编译成动态库进行使用。现将具体过程在此记录。 二、准备 1、项目文件 我的项目文件如下&#xff0c;其中除main.cpp是测试文件之外&#xff0c;其他都是线程池项目相关的 将C…

双系统下删除ubuntu

絮絮叨叨 由于我在安装Ubuntu的时候没有自定义安装位置&#xff0c;而是使用与window共存的方式让Ubuntu自己选择安装位置&#xff0c;导致卸载时我不知道去格式化哪个分区&#xff0c;查阅多方资料后无果&#xff0c;后在大佬帮助下找到解决方案 解决步骤 1、 插上Ubuntu安…

【IR 论文】DPR — 最早提出使用嵌入向量来检索文档的模型

论文&#xff1a;Dense Passage Retrieval for Open-Domain Question Answering ⭐⭐⭐⭐⭐ EMNLP 2020, Facebook Research Code: github.com/facebookresearch/DPR 文章目录 一、论文速读二、DPR 的训练2.1 正样本和负样本的选取2.2 In-batch negatives 技巧 三、实验3.1 数据…

医学影像增强:空间域方法与频域方法等

医学影像图像增强是一项关键技术,旨在改善图像质量,以便更好地进行疾病诊断和评估。增强方法通常分为两大类:空间域方法和频域方法。 一、 空间域方法 空间域方法涉及直接对医学影像的像素值进行操作,以提高图像的视觉质量。以下是一些常用的空间域方法: 对比度调整:通过…

双向链表的介绍

引入 特点 操作 定义 插入 删除 小结 整合代码&#xff1a; //定义 typedef struct node{int data;node* next,prior; }Dlnode,*Dlinklist;//初始 void init(Dlinklist &l){Dlnode lnew Dlnode;l->nextNULL;l->priorNULL; }//插入&#xff1a;插入指定位置&#…

buuctf-misc-rar

13.rar 题目&#xff1a;直接破解就可以了 这个借用工具Aparch进行4位纯数字暴力破解&#xff0c;根据题目的提示是4位的纯数字&#xff0c;那我们就选择数字的破解 在长度这里&#xff0c;把最小口令长度和最大口令长度都选择为4 获得密码后进行解压。

各省财政涉农支出统计数据集(2001-2022年)

01、数据简介 财政涉农支出是指在政府预算中&#xff0c;用于支持农业、农村和农民发展的财政支出。这些支出旨在促进农村经济的发展&#xff0c;提高农民收入&#xff0c;改善农村生产生活条件&#xff0c;推进农业现代化。 在未来的发展中&#xff0c;各省将继续加大财政涉…

【研发管理】产品经理知识体系-产品设计与开发工具

导读&#xff1a;产品设计与开发工具的重要性体现在多个方面&#xff0c;它们对于产品的成功开发、质量提升以及市场竞争力都具有至关重要的影响。产品设计工具可以帮助设计师更高效地创建和优化产品原型。开发工具在产品开发过程中发挥着至关重要的作用。产品设计与开发工具还…

这三个AI导航站,你绝对用得到!!

Hi&#xff0c;这里是前端后花园&#xff0c;专注分享前端软件网站、工具资源。不得不说AI正在慢慢改变我们学习和工作方式&#xff0c;今天带来这三个AI导航站&#xff0c;总有一个你用得到&#xff0c;记得收藏啊&#xff01; AI工具集 https://ai-bot.cn/ 网站汇集了700 …

Linux多线程(三) 线程池C++实现

一、线程池原理 我们使用线程的时候就去创建一个线程&#xff0c;这样实现起来非常简便。但如果并发的线程数量很多&#xff0c;并且每个线程都是执行一个时间很短的任务就结束了&#xff0c;这样频繁创建线程就会大大降低系统的效率&#xff0c;因为频繁创建线程和销毁线程需…

华为MRS服务使用记录

背景&#xff1a;公司的业务需求是使用华为的这一套成品来进行开发&#xff0c;使用中发现&#xff0c;这个产品跟原生的Hadoop的那一套的使用&#xff0c;还是有很大的区别的&#xff0c;现记录一下&#xff0c;避免以后忘了 一、原始代码的下载 下载地址&#xff1a;MRS样例…

STM32HAL库++ESP8266+cJSON连接阿里云物联网平台

实验使用资源&#xff1a;正点原子F1 USART1&#xff1a;PA9P、A10&#xff08;串口打印调试&#xff09; USART3&#xff1a;PB10、PB11&#xff08;WiFi模块&#xff09; DHT11&#xff1a;PG11&#xff08;采集数据、上报&#xff09; LED0、1&#xff1a;PB5、PE5&#xff…

input框添加验证(如只允许输入数字)中文输入导致显示问题的解决方案

文章目录 input框添加验证(如只允许输入数字)中文输入导致显示问题的解决方案问题描述解决办法 onCompositionStart与onCompositionEnd input框添加验证(如只允许输入数字)中文输入导致显示问题的解决方案 问题描述 测试环境&#xff1a;react antd input (react的事件与原生…

浅谈在Java代码中创建线程的多种方式

文章目录 一、Thread 类1.1 跨平台性 二、Thread 类里的常用方法三、创建线程的方法1、自定义一个类&#xff0c;继承Thread类&#xff0c;重写run方法1.1、调用 start() 方法与调用 run() 方法来创建线程&#xff0c;有什么区别&#xff1f;1.2、sleep()方法 2、自定义一个类&…

嵌入式常见存储器

阅读引言&#xff1a; 在看一款芯片的数据手册的时候&#xff0c; 无意间翻到了它的启动模式(Boot Mode), 发现这种这么多种ROM&#xff0c;所以就写下了这篇文章。 目录 一、存储器汇总 二、易失性存储器(RAM) 1. SRAM 1.1 单口SRAM 1.2 双口SRAM 2. DRAM 2.1 SDRAM 2…

Fast-DetectGPT 无需训练的快速文本检测

本文提出了一种新的文本检测方法 ——Fast-DetectGPT&#xff0c;无需训练&#xff0c;直接使用开源小语言模型检测各种大语言模型&#xff0c;如GPT等生成的文本内容。 Fast-DetectGPT 将检测速度提高了 340 倍&#xff0c;将检测准确率相对提升了 75%&#xff0c;超过商用系…

有哪些好用电脑端时间定时软件?桌面日程安排软件推荐 桌面备忘录

随着现代生活节奏的加快&#xff0c;人们对于时间管理和任务提醒的需求越来越大。为了满足这一需求&#xff0c;市场上涌现出了众多桌面便签备忘录软件&#xff0c;它们不仅可以帮助我们记录待办事项&#xff0c;还能定时提醒我们完成任务。在这篇文章中&#xff0c;我将为大家…
最新文章