基于深度学习的图像去雨去雾

基于深度学习的图像去雨去雾


文末附有源码下载地址
b站视频地址: https://www.bilibili.com/video/BV1Jr421p7cT/

基于深度学习的图像去雨去雾,使用的网络为unet,
网络代码:

import torch
import torch.nn as nn
from torchsummary import summary
from torchvision import models
from torchvision.models.feature_extraction import create_feature_extractor
import torch.nn.functional as F
from torchstat import stat

class Resnet18(nn.Module):
    def __init__(self):
        super(Resnet18, self).__init__()
        self.resnet = models.resnet18(pretrained=False)
        # self.resnet = create_feature_extractor(self.resnet, {'relu': 'feat320', 'layer1': 'feat160', 'layer2': 'feat80',
        #                                                'layer3': 'feat40'})

    def forward(self,x):
        for name,m in self.resnet._modules.items():

            x=m(x)
            if name=='relu':
                x1=x
            elif name=='layer1':
                x2=x
            elif name=='layer2':
                x3=x
            elif name=='layer3':
                x4=x
                break
        # x=self.resnet(x)
        return x1,x2,x3,x4
class Linears(nn.Module):
    def __init__(self,a,b):
        super(Linears, self).__init__()
        self.linear1=nn.Linear(a,b)
        self.relu1=nn.LeakyReLU()
        self.linear2 = nn.Linear(b, a)
        self.sigmoid=nn.Sigmoid()
    def forward(self,x):
        x=self.linear1(x)
        x=self.relu1(x)
        x=self.linear2(x)
        x=self.sigmoid(x)
        return x
class DenseNetBlock(nn.Module):
    def __init__(self,inplanes=1,planes=1,stride=1):
        super(DenseNetBlock,self).__init__()
        self.conv1=nn.Conv2d(inplanes,planes,3,stride,1)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1=nn.LeakyReLU()

        self.conv2 = nn.Conv2d(inplanes, planes, 3,stride,1)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.LeakyReLU()

        self.conv3 = nn.Conv2d(inplanes, planes, 3,stride,1)
        self.bn3 = nn.BatchNorm2d(planes)
        self.relu3 = nn.LeakyReLU()
    def forward(self,x):
        ins=x
        x=self.conv1(x)
        x=self.bn1(x)
        x=self.relu1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x=x+ins

        x2=self.conv3(x)
        x2 = self.bn3(x2)
        x2=self.relu3(x2)

        out=ins+x+x2
        return out
class SEnet(nn.Module):
    def __init__(self,chs,reduction=4):
        super(SEnet,self).__init__()
        self.average_pooling = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.fc = nn.Sequential(
            # First reduce dimension, then raise dimension.
            # Add nonlinear processing to fit the correlation between channels
            nn.Linear(chs, chs // reduction),
            nn.LeakyReLU(inplace=True),
            nn.Linear(chs // reduction, chs)
        )
        self.activation = nn.Sigmoid()
    def forward(self,x):
        ins=x
        batch_size, chs, h, w = x.shape
        x=self.average_pooling(x)
        x = x.view(batch_size, chs)
        x=self.fc(x)
        x = x.view(batch_size,chs,1,1)
        return x*ins
class UAFM(nn.Module):
    def __init__(self):
        super(UAFM, self).__init__()
        # self.meanPool_C=torch.max()

        self.attention=nn.Sequential(
            nn.Conv2d(4, 8, 3, 1,1),
            nn.LeakyReLU(),
            nn.Conv2d(8, 1, 1, 1),
            nn.Sigmoid()
        )


    def forward(self,x1,x2):
        x1_mean_pool=torch.mean(x1,dim=1)
        x1_max_pool,_=torch.max(x1,dim=1)
        x2_mean_pool = torch.mean(x2, dim=1)
        x2_max_pool,_ = torch.max(x2, dim=1)

        x1_mean_pool=torch.unsqueeze(x1_mean_pool,dim=1)
        x1_max_pool=torch.unsqueeze(x1_max_pool,dim=1)
        x2_mean_pool=torch.unsqueeze(x2_mean_pool,dim=1)
        x2_max_pool=torch.unsqueeze(x2_max_pool,dim=1)

        cat=torch.cat((x1_mean_pool,x1_max_pool,x2_mean_pool,x2_max_pool),dim=1)
        a=self.attention(cat)
        out=x1*a+x2*(1-a)
        return out

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.resnet18=Resnet18()
        self.SENet=SEnet(chs=256)
        self.UAFM=UAFM()
        self.DenseNet1=DenseNetBlock(inplanes=256,planes=256)
        self.transConv1=nn.ConvTranspose2d(256,128,3,2,1,output_padding=1)

        self.DenseNet2 = DenseNetBlock(inplanes=128, planes=128)
        self.transConv2 = nn.ConvTranspose2d(128, 64, 3, 2, 1, output_padding=1)

        self.DenseNet3 = DenseNetBlock(inplanes=64, planes=64)
        self.transConv3 = nn.ConvTranspose2d(64, 64, 3, 2, 1, output_padding=1)

        self.transConv4 = nn.ConvTranspose2d(64, 32, 3, 2, 1, output_padding=1)
        self.DenseNet4=DenseNetBlock(inplanes=32,planes=32)
        self.out=nn.Sequential(
            nn.Conv2d(32,3,1,1),
            nn.Sigmoid()
        )

    def forward(self,x):
        """
        下采样部分
        """
        x1,x2,x3,x4=self.resnet18(x)
        # feat320=features['feat320']
        # feat160=features['feat160']
        # feat80=features['feat80']
        # feat40=features['feat40']
        feat320=x1
        feat160=x2
        feat80=x3
        feat40=x4
        """
        上采样部分
        """
        x=self.SENet(feat40)
        x=self.DenseNet1(x)
        x=self.transConv1(x)
        x=self.UAFM(x,feat80)

        x=self.DenseNet2(x)
        x=self.transConv2(x)
        x=self.UAFM(x,feat160)

        x = self.DenseNet3(x)
        x = self.transConv3(x)
        x = self.UAFM(x, feat320)

        x=self.transConv4(x)
        x=self.DenseNet4(x)
        out=self.out(x)

        # out=torch.concat((out,out,out),dim=1)*255.

        return out

    def freeze_backbone(self):
        for param in self.resnet18.parameters():
            param.requires_grad = False

    def unfreeze_backbone(self):
        for param in self.resnet18.parameters():
            param.requires_grad = True


if __name__ == '__main__':

    net=Net()
    print(net)
    # stat(net,(3,640,640))

    summary(net,input_size=(3,512,512),device='cpu')

    aa=torch.ones((6,3,512,512))
    out=net(aa)
    print(out.shape)
    # ii=torch.zeros((1,3,640,640))
    # outs=net(ii)
    # print(outs.shape)






主题界面显示及代码:
在这里插入图片描述

from PyQt5.QtGui import *
from PyQt5.QtWidgets import *
from untitled import Ui_Form
import sys
import cv2 as cv
from PyQt5.QtCore import QCoreApplication
import numpy as np
from PyQt5 import QtCore,QtGui
from PIL import Image
from predict import *

class My(QMainWindow,Ui_Form):
    def __init__(self):
        super(My,self).__init__()
        self.setupUi(self)
        self.setWindowTitle('图像去雨去雾')
        self.setIcon()
        self.pushButton.clicked.connect(self.pic)
        self.pushButton_2.clicked.connect(self.pre)
        self.pushButton_3.clicked.connect(self.pre2)
    def setIcon(self):
       palette1 = QPalette()
       # palette1.setColor(self.backgroundRole(), QColor(192,253,123))   # 设置背景颜色
       palette1.setBrush(self.backgroundRole(), QBrush(QPixmap('back.png')))  # 设置背景图片
       self.setPalette(palette1)
    def pre(self):
        out=pre(self.img,0)
        out=self.cv_qt(out)
        self.label_2.setPixmap(QPixmap.fromImage(out).scaled(self.label.width(),self.label.height(),QtCore.Qt.KeepAspectRatio))
    def pre2(self):
        out=pre(self.img,1)
        out=self.cv_qt(out)
        self.label_2.setPixmap(QPixmap.fromImage(out).scaled(self.label.width(),self.label.height(),QtCore.Qt.KeepAspectRatio))

    def pic(self):
        imgName, imgType = QFileDialog.getOpenFileName(self,
                                                       "打开图片",
                                                       "",
                                                       " *.png;;*.jpg;;*.jpeg;;*.bmp;;All Files (*)")
        #KeepAspectRatio
        png = QtGui.QPixmap(imgName).scaled(self.label.width(),self.label.height(),QtCore.Qt.KeepAspectRatio)  # 适应设计label时的大小
        self.label.setPixmap(png)

        self.img=Image.open(imgName)
        self.img=np.array(self.img)
    def cv_qt(self, src):
        #src必须为bgr格式图像
        #src必须为bgr格式图像
        #src必须为bgr格式图像
        if len(src.shape)==2:
            src=np.expand_dims(src,axis=-1)
            src=np.tile(src,(1,1,3))
            h, w, d = src.shape
        else:h, w, d = src.shape



        bytesperline = d * w
        # self.src=cv.cvtColor(self.src,cv.COLOR_BGR2RGB)
        qt_image = QImage(src.data, w, h, bytesperline, QImage.Format_RGB888).rgbSwapped()
        return qt_image

if __name__ == '__main__':
    QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)
    app=QApplication(sys.argv)
    my=My()
    my.show()
    sys.exit(app.exec_())

项目结构:
在这里插入图片描述
直接运行main.py即可弹出交互界面。
项目下载地址:下载地址-列表第19

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/455629.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Prompt提示工程上手指南:基础原理及实践(二)-Prompt主流策略

前言 上篇文章将Prompt提示工程大体概念和具体工作流程阐述清楚了,我们知道Prompt工程是指人们向生成性人工智能(AI)服务输入提示以生成文本或图像的过程中,对这些提示进行精炼的过程。生成人工智能是一个根据人类和机器产生的数…

【Python】使用plt库绘制动态曲线图,并导出为GIF或MP4

一、绘制初始图像 正常使用plt进行绘图,这里举例一个正弦函数: 二、绘制动态图的每一帧 思路: 根据横坐标点数绘制每一帧画面每次在当前坐标处,绘制一个点和垂直的线,来表示当前点可以在点上加个坐标等样式来增加…

Gut Microbes | 新生儿微生物组研究的方法学挑战

摘要 新生儿出生后,肠道菌群的定植对新生儿的健康发育起着至关重要的作用,并影响其日后的健康和疾病。了解新生儿肠道菌群的发育以及其与新生儿宿主的相互作用是一个重要的研究领域。然而,该领域的研究必须解决影响研究方法设计和实施的一系…

【Java系列】OOM 时,JVM 堆栈信息保存和分析

一、前言 在日常开发中,即使代码写得再谨慎,免不了还是会发生各种意外的事件,比如服务器内存突然飙高,又或者发生内存溢出(OOM)。当发生这种情况时,我们怎么去排查,怎么去分析原因呢? 一般遇到…

展厅设计中灯光的要点都是什么

1、白炽灯 白炽灯也就是普通普通白炽灯泡白炽灯有显色性强,开灯即亮,明暗可调,结构简单,造价低等优点,但缺点是使用寿命短,光效较低展厅设计中常使用于走道和其他部位。 2、卤钨灯 充气白炽灯填充气体中含有…

代码随想录day20(1)二叉树:二叉树的最小深度(leetcode111)

题目要求:求出一棵二叉树的最小深度 思路:最小深度指的是从根节点到最近叶子节点的最短路径上的节点数量(左右孩子必须都为空!)。思路类似于求二叉树的最大深度,仍然采用后序遍历,增加判断只有…

Docker启动安装nacos(踩过坑版)

1、Docker 拉取镜像 docker pull nacos/nacos-server:v2.1.0 2、创建宿主机挂载目录 mkdir -p /mydata/nacos/logs/ mkdir -p /mydata/nacos/conf/ 3、启动nacos并复制文件到宿主机,关闭容器 启动容器 docker run -p 8848:8848 --name nacos -d nacos/nacos-se…

通过Maven创建Web工程

通过Maven创建Web工程 方式一方式二 方式一 1.先创建一个Maven工程 2.把该Maven模块的pom文件里添加一个war 3.选中该Maven模块 点击项目架构 4.手动添加一个Web架构 方式二 1.也是new一个模块 但是直接配置好Web 2.这里就是我IDEA对Maven的设置 3.第一次创建 可能…

【C++】STL--String

这一节主要总结string类的常见接口,以及完成了string类的模拟实现。 目录 标准库的String类 string类常见接口 string类对象的常见构造 string析构函数:~string string类对象的容量操作 string类对象的访问及遍历操作 string类对象的修改操作 s…

jvm题库详解

1、JVM内存模型 注意:这个是基于jdk1.8之前的虚拟机,在jdk1.8后 已经没有方法区,一并合并到堆中的元空间了 JVM内存区域总共分为两种类型 线程私有区域:程序计数器、本地方法栈和虚拟机栈 线程共享区域:堆&#xff08…

让若依生成的service、mapper继承mybatisPlus的基类

前言:若依继承mybatisPlus后,生成代码都要手动去service、serviceImpl、mapper文件去继承mybatisplus的基类,繁琐死了。这里通过修改若依生成模版从而达到生成文件后直接使用mybatisPlus的方法。 一、首先找到若依生成模版文件位置&#xff…

顶顶通呼叫中心中间件-群集模式配置

文章目录 群集模式介绍联系我们配置流程群集模式下呼叫线路配置 群集模式介绍 在大规模的外呼或者呼入系统,比如整个系统需要1万并发,单机最高也就3000-5000并发,这时候就需要多机群集了。顶顶通呼叫中心中间件使用的是 redis 数据库&#x…

你选的Six Sigma咨询公司靠谱吗?保姆级避坑指南

近年来,企业为了追求更高的运营效率和产品质量,纷纷寻求Six Sigma这样的先进管理方法。然而,市场上的咨询公司琳琅满目,如何选择一家真正靠谱、能带来实际效益的咨询公司呢? 一、了解公司背景和实力 在选择Six Sigma咨…

msdn我告诉你itellyou做一个安静的工具站,各种windows镜像下载,iso体积都是很小的那种

官网地址:MSDN, 我告诉你 - 做一个安静的工具站 可以看到里面集成了各种操作系统,可以下载使用。 或者在他的新站点:登录 里面有最新的windows11系统可以下载,但是需要登陆之后才可以,随便第三方账号登陆即可&…

【Devin AI】全球首位AI程序员登场,程序员该如何保住饭碗?编程新纪元的革命已到来!

程序员们,警惕!我们的饭碗要被砸了! 一觉醒来,全球首位AI程序员 Devin 上线了!直接引爆整个科技圈。 Devin被介绍为世界首个完全自主的AI软件工程师。只需一句指令,它可端到端地构建和部署整个开发项目。 …

Ajax学习笔记(一):原生AJAX、HTTP协议、AJAX案例准备工作、发送AJAX请求、AJAX 请求状态

目录 一、原生AJAX 1.1AJAX 简介 1.2 XML 简介 1.3 AJAX的特点 二、HTTP协议 三、AJAX案例准备工作 四、发送AJAX请求 1.发送GET请求 2.发送POST请求 3.JSON响应 IE缓存问题: 五、AJAX 请求状态 一、原生AJAX 1.1AJAX 简介 AJAX 全称为 Asynchronous …

使用Nginx进行负载均衡

什么是负载均衡 Nginx是一个高性能的开源反向代理服务器,也可以用作负载均衡器。通过Nginx的负载均衡功能,可以将流量分发到多台后端服务器上,实现负载均衡,提高系统的性能、可用性和稳定性。 如下图所示: Nginx负…

shell控制多线程并发处理

一、前言 我们在用shell编程时,当用到循环语句时,如果循环的对象数量比较多,则代码一条一条处理,时间消耗会特别慢。如果此时机器资源充足,不妨学会多线程并发处理这招,帮助你提前打卡完成工作。 二、控制…

MySQL数据库自动备份(Linux操作系统)

方式一 参考:https://blog.csdn.net/qq_48157004/article/details/126683610?spm1001.2014.3001.5506 1.MySQL备份脚本 在/home/backups/下建立.sh文件,文件名称 mysql_backup.sh ,内容如下 #!/bin/bash #备份路径 BACKUP/home/backups/mysqlBackup…

酷开系统走在前列,品牌重启增长,酷开科技成为品牌商合作目标

区别于火热的移动端,手机屏作为私密屏,往往面向的是用户个体,而电视作为家庭连接的重要枢纽,不仅仅定位于公共屏,同时也面向客厅场景发挥着其大屏传播的作用,这里不仅牵扯到大屏营销,也关联着大…
最新文章