Unet简单结构概述

在这里插入图片描述

总体结构代码

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = (DoubleConv(n_channels, 64))
        self.down1 = (Down(64, 128))
        self.down2 = (Down(128, 256))
        self.down3 = (Down(256, 512))
        factor = 2 if bilinear else 1
        #bilinear表示是否使用双线性插值 如果使用 输出通道数量需要减半 factor=2
        #bilinear为False时 输出通道不需要减半 factor=1
        self.down4 = (Down(512, 1024 // factor))
        self.up1 = (Up(1024, 512 // factor, bilinear))
        self.up2 = (Up(512, 256 // factor, bilinear))
        self.up3 = (Up(256, 128 // factor, bilinear))
        self.up4 = (Up(128, 64, bilinear))
        self.outc = (OutConv(64, n_classes))

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

各个部分实现细节

DoubleConv每块内部的卷积层

构造的时候传入输入通道输出通道,使用的时候直接传入向量x

class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

下采样

下采样压缩高度和宽度,增加通道数

class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

上采样

上采样是一个反卷积的过程

class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        #这里的x1是从下面往上采样的结果 x2是encoder部分的结果

        #上采样反卷积过程
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        
        #对x1进行填充,为什么在两边填充?
        #这种padding方式是为了保持特征图片中心不变 对小图片进行上下左右均匀的填充
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

输出卷积层

输出最后的结果

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/599324.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

软件设计师-应用技术-数据结构及算法题4

考题形式: 第一题:代码填空 4-5空 8-10第二题:时间复杂度 / 代码策略第三题:拓展,跟一组数据,把数据带入代码中,求解 基础知识及技巧: 1. 分治法: 基础知识&#xff1…

取消vscode go保存时自动格式化代码

go:v1.22.0 vscode go 插件:v0.41.4 setting.json formatOnSave: 保存文件时,是否执行格式化 codeActiosnOnSave:保存文件时,是否执行某些操作 organizeImports: 不再改动import()里面的包

分类规则挖掘(三)

目录 四、贝叶斯分类方法(一)贝叶斯定理(二)朴素贝叶斯分类器(三)朴素贝叶斯分类方法的改进 五、其它分类方法 四、贝叶斯分类方法 贝叶斯 (Bayes) 分类方法是以贝叶斯定理为基础的一系列分类算法的总称。贝…

python中numpy库使用

array数组 生成array数组 将list转化为array数组 import numpy as np np.array([1,2],typenp.int32)其中dtype定义的是元素类型,np.int32指32位的整形 如果直接定义dtypeint 默认的是32位整形。 zeors和ones方法 zeros()方法,该方法和ones()类似&a…

Qt——入门基础

目录 Qt入门第一个应用程序 main.cpp widget.h widget.cpp widget.ui .pro Hello World程序 对象树 编辑框 按钮 Qt 窗口坐标系 Qt入门第一个应用程序 main.cpp 这就像一开始学语言时都会打印一个“Hello World”一样,我们先来看看创建好一个项目后&…

ModuleNotFoundError: No module named ‘PyQt5‘

运行python程序的时候报错:ModuleNotFoundError: No module named ‘PyQt5‘ 这是因为没有安装pyqt5依赖包导致的,安装一下即可解决该问题。 安装依赖 pip install PyQt5 -i https://pypi.tuna.tsinghua.edu.cn/simple 这里是使用的清华镜像源进行安装…

数据库系统原理实验报告5 | 数据查询

整理自博主本科《数据库系统原理》专业课自己完成的实验报告,以便各位学习数据库系统概论的小伙伴们参考、学习。 专业课本: ———— 本次实验使用到的图形化工具:Heidisql 目录 一、实验目的 二、实验内容 1.找出读者所在城市是“shangh…

STM32G0存储器和总线架构

文章目录 前言一、系统架构二、存储器构成三、存储器地址映射四、存储器边界地址五、外设寄存器边界地址 前言 此文章是STM32G0 MCU的学习记录,并非权威,请谨慎参考。 STM32G0主流微控制器基于工作频率可达64 MHz的高性能Arm Cortex-M0 32位RISC内核。该…

GEE数据集——DeltaDTM 全球沿海数字地形模型数据集

DeltaDTM 全球沿海数字地形模型产品 简介 DeltaDTM 是全球沿岸数字地形模型(DTM),水平空间分辨率为 1 弧秒(∼30 米),垂直平均绝对误差(MAE)为 0.45 米。它利用 ICESat-2 和 GEDI …

内容安全(IPS入侵检测)

入侵检测系统( IDS )---- 网络摄像头,侧重于风险管理,存在于滞后性,只能够进行风险发现,不能及时制止。而且早期的IDS误报率较高。优点则是可以多点进行部署,比较灵活,在网络中可以进…

【java9】java9新特性之改进JavaDocs

Java9在JavaDocs方面的主要新特性是,其输出现在符合兼容HTML5标准。在之前的版本中,默认的HTML版本是 HTML4.01,但在Java9及之后的版本中,JavaDocs命令行工具将默认使用HTML5作为输出标记语言。这意味着,使用JavaDocs工…

Markdown 精简教程(胎教级教程)

文章目录 一、关于 Markdown1. 什么是 Markdown?2. 为什么要用 Markdown?3. 怎么用 Markdown?(编辑软件) 二、标题1. 常用标题写法2. 可选标题写法3. 自定义标题 ID4. 注意事项 三、段落四、换行五、字体选项1. 粗体2.…

跨境电商行业分析-商品出海的四大路径

1. 跨境电子商务模式和国内电子商务模式【区别】 最大的不同点有3个: 达成交易的双方是属于不同【关境】的交易主体商品通过众多电子商务平台/独立站等,进行支付结算通过国际物流的方式(海运/铁路/空运/卡车)进行报关、清关、派…

anconda创建虚拟环境,使用虚拟环境(基于win平台)

假设已经安装了anconda,打开anaconda的 shell。 查看已存在的虚拟环境,base是默认的,不用理会,后面的yolov5就是用户创建的 #查看有那些虚拟环境 (base) PS C:\Users\x> conda info -e # conda environments: # base …

如何判断代理IP质量?

由于各种原因(从匿名性和安全性到绕过地理限制),代理 IP 的使用变得越来越普遍。然而,并非所有代理 IP 都是一样的,区分高质量和低质量的代理 IP 对于确保流畅、安全的浏览体验至关重要。以下是评估代理 IP 质量时需要…

计划订单转采购申请的增强点和可以增强的内容

MD15 MD14 计划订单转采购申请,涉及的增强点和增强内容 对于外协的采购申请,有时候需要对组件的内容做一些特殊的处理,但是处理组件清单的增强ME_COMPONENTS_UPDATE的增强点(这个增强点对于手工创建的外协PR、外协PO,外协pr转外协…

Day21 代码随想录打卡|字符串篇---右旋转字符串

题目(卡码网 T55): 字符串的右旋转操作是把字符串尾部的若干个字符转移到字符串的前面。给定一个字符串 s 和一个正整数 k,请编写一个函数,将字符串中的后面 k 个字符移到字符串的前面,实现字符串的右旋转…

Pycharm无法链接服务器环境(host is unresponsived)

困扰了很久的一个问题,一开始是在服务器ubuntu20.04上安装pycharm community,直接运行服务器上的pycharm community就识别不了anaconda中的环境 后来改用pycharm professional也无法远程连接上服务器的环境,识别不了服务器上的环境&#xff…

[力扣题解]102.二叉树的层序遍历

题目&#xff1a;102. 二叉树的层序遍历 代码 迭代法 class Solution { public:vector<vector<int>> levelOrder(TreeNode* root) {queue<TreeNode*> que;TreeNode* cur;int i, size;vector<vector<int>> result;if(root ! NULL){que.push(ro…

Pycharm导入自定义模块报红

文章目录 Pycharm导入自定义模块报红1.问题描述2.解决办法 Pycharm导入自定义模块报红 1.问题描述 Pycharm 导入自定义模块报红&#xff0c;出现红色下划线。 2.解决办法 打开【File】->【Setting】->【Build,Execution,Deployment】->【Console】->【Python Con…
最新文章