实验算法设计

文章目录

  • Unet
  • transformer
    • 整体网络架构

视觉

Unet

U-Net
卷积和转置卷积
转置卷积

  • 可以用双线性差值替换,效果差不多,参数更少。
    边缘缩减问题解释

边缘权重更大

双线性差值版

from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F 

class DoubleConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, mid_channels= None):
        if mid_channels is None:
            mid_channels = out_channels
        
        super(DoubleConv, self).__init__(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3,padding=1,bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3,padding=1,bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

class Down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Down,self).__init__(
            nn.MaxPool2d(2,stride=2),
            DoubleConv(in_channels, out_channels)
        )

class Up(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=True):
        super(Up, self).__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor)-> torch.Tensor:
        x1 = self.up(x1)
        diff_x = x2.size()[2] - x1.size()[2]
        diff_y = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x

 class OutConv(nn.Sequential):
    def __init__(self, in_channels, num_classes):
        super(OutConv, self).__init__(
            nn.Conv2d(in_channels, num_classes, kernel_size=1),
        )

 class Unet(nn.Module):
    def __init__(self,
                 in_channels: int = 1, 
                 num_classes: int = 2, 
                 bilinear: bool = True,
                 base_c: int = 64
                 ):
        super(Unet, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.bilinear = bilinear
        self.in_conv = DoubleConv(in_channels, base_c)
        self.down1 = Down(base_c, base_c * 2)
        self.down2 = Down(base_c * 2, base_c * 4)
        self.down3 = Down(base_c * 4, base_c * 8)
        factor = 2 if bilinear else 1
        self.down4 = Down(base_c * 8, base_c * 16 // factor)
        self.up1 = Up(base_c * 16, base_c * 8 // factor, bilinear)
        self.up2 = Up(base_c * 8, base_c * 4 // factor, bilinear)
        self.up3 = Up(base_c * 4, base_c * 2 // factor, bilinear)
        self.up4 = Up(base_c * 2, base_c, bilinear)
        self.out_conv = OutConv(base_c, num_classes)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        x1 = self.in_conv(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1) 
        logits = self.out_conv(x)
        return {"out":logits}

数据集

  • test测试集

    • 1st_manual标签
    • 2nd_manual更加细致的标签
    • images原图
    • mask掩码,感兴趣的区域
  • training训练集

    • 1st_manual
    • images
    • mask
  • 数据处理部分

import os
from PIL import Image
import numpy as np 
from torch.utils.data import Dataset

class DriveDataset(Dataset):
    def __init__(self, root: str, train: bool, transforms=None):
        super(DriveDataset, self).__init__()
        self.flag = "training" if train else "test"
        data_root = os.path.join(root, "DRIVE", self.flag)
        assert os.path.exists(data_root), f"path '{data_root}' does not exist."
        self.transforms = transforms
        img_names = [i for i in os.listdir(os.path.join(data_root, "images")) if i.endswith(".tif")]
        self.img_list = [os.path.join(data_root, "images", i) for i in img_names]
        self.manual = [os.path.join(data_root, "1st_manual", i.split("_")[0] + "_manual1.gif")
                        for i in img_names]
        for i in self.manual:
            if not os.path.exists(i):
                raise FileNotFoundError(f"file {i} does not exist.")
        self.roi_mask = [os.path.join(data_root, "mask", i.split("_")[0] + f"_{self.flag}_mask.gif")
                        for i in img_names]
        for i in self.roi_mask:
            if not os.path.exists(i):
                raise FileNotFoundError(f"file {i} does not exist.")
        
    def __getitem__(self, idx):
        img = Image.open(self.img_list[idx]).convert("RGB")
        manual = Image.open(self.manual[idx]).convert("L")
        manual = np.array(manual) / 255
        roi_mask = Image.open(self.roi_mask[idx]).convert("L")  
        roi_mask = 255 - np.array(roi_mask)
        mask = np.clip(manual + roi_mask, a_min=0, a_max=255)
        mask = Image.fromarray(mask)
        if self.transforms is not None:
            img, mask = self.transforms(img, mask)
        return img, mask
    
    def __len__(self):
        return len(self.img_list)

    @staticmethod
    def collate_fn(batch):
        images, targets = list(zip(*batch))
        batched_imgs = DriveDataset.cat_list(images, fill_value=0)
        batched_targets = DriveDataset.cat_list(targets, fill_value=255)
        return batched_imgs, batched_targets

    @staticmethod
    def cat_list(images, fill_value=0):
        max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
        batch_shape = (len(images),) + max_size
        batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
        for img, pad_img in zip(images, batched_imgs):
            pad_img[..., : img.shape[-2], : img.shape[-1]].copy_(img)
        return batched_imgs
  • 代码文件不全,参考一下就行

transformer

整体网络架构

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
编码器

  1. Embedding
    位置编码
  • 它能为每个时间步输出一个独一无二的编码;
  • 不同长度的句子之间,任何两个时间步之间的距离应该保持一致;
  • 模型应该能毫不费力地泛化到更长的句子。
  • 它的值应该是有界的;它必须是确定性的

在这里插入图片描述
在这里插入图片描述
注意力机制:从众多信息中选出对当前任务目标更加关键的信息

  1. 基本的注意力机制
  2. 在TRM中怎么操作

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
前馈神经网络

  • 前馈网络(feed-forward network)是一种常见的神经网络结构,由一个或多个线性变换和非线性激活函数组成。它的输入是一个词向量,经过一系列线性变换和激活函数处理之后,输出另一个词向量。
  • 前面都是线性变化(矩阵乘法),表达能力不够
  • 更加深入的特征提取
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    编码器:将输入文本序列编码为一系列隐藏表示,通常使用多层自注意力机制和前馈神经网络。
    解码器:接收编码器的输出和部分已生成的序列,使用自注意力机制和注意力机制来生成下一个词或字符。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/335948.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

interpret,一个超酷的 Python 库

更多资料获取 📚 个人网站:ipengtao.com 大家好,今天为大家分享一个超酷的 Python 库 - interpret。 Github地址:https://github.com/interpretml/interpret Python Interpret 是一个强大的开源工具,它为 Python 开发…

Crow:设置网站的index.html

对于一个网展来说,index.html是其第一个页面,也是根页面,如何通过Crow来加载index.html呢。 Crow:静态资源使用举例-CSDN博客 讲述了静态资源的使用,也就是通常存饭html,css,jpg文件的地方 当然index.html也会放在这个目录,但通常是放在static的根目录,其他资源会根据…

2.1.4-相关性分析

跳转到根目录:知行合一:投资篇 已完成: 1、投资&技术   1.1.1 投资-编程基础-numpy   1.1.2 投资-编程基础-pandas   1.2 金融数据处理   1.3 金融数据可视化 2、投资方法论   2.1.1 预期年化收益率   2.1.2 一个关于yaxb的…

vuex-跨模块访问

1. 场景 案例:跨模块访问和退出登录 假设我们有一个Vuex store,其中包含user模块和cart模块。当用户点击退出登录按钮时,我们需要调用user模块中的方法来清除用户信息,同时还需要清除cart模块中的购物车数据。 2. 实现-跨模块访…

air001研究笔记.基于arduino快速开发简单项目

一、air001芯片简介 air001是厂商合宙推出的一款tssop封装的mcu芯片。支持swd与串口烧录,多面向简单的功能简单类别的电子产品,因为官方文档齐全上手简易,所以也特别适合非专业爱好者乃至于幼儿编程。芯片内置资源:AIR001芯片数据…

国产AI新篇章:书生·浦语2.0带来200K超长上下文解决方案

总览:大模型技术的快速演进 自2023年7月6日“书生浦语”(InternLM)在世界人工智能大会上正式开源以来,其在社区和业界的影响力日益扩大。在过去半年中,大模型技术体系经历了快速的演进,特别是100K级别的长…

用LED数码显示器循环显示数字0~9

#include<reg51.h> // 包含51单片机寄存器定义的头文件 /************************************************** 函数功能&#xff1a;延时函数&#xff0c;延时一段时间 ***************************************************/ void delay(void) { unsigned …

Docker项目部署()

1.创建文件夹tools mkdir tools 配置阿里云 Docker Yum 源 : yum install - y yum - utils device - mapper - persistent - data lvm2 yum - config - manager -- add - repo http://mirrors.aliyun.com/docker- ce/linux/centos/docker - ce.repo 更新 yum 缓存 yum makec…

视频剪辑技巧:一键批量制作画中画视频的方法,高效提升剪辑任务

在数字媒体时代&#xff0c;视频剪辑已成为一项重要的技能。无论是专业的影视制作&#xff0c;还是日常的社交媒体分享&#xff0c;掌握视频剪辑技巧都能为内容增色不少。下面来看云炫AI智剪如何高效的剪辑视频技巧&#xff1a;一键批量制作画中画视频的方法&#xff0c;帮助您…

Vue3前端开发,provide和enject的基础练习,跨层级传递数据

Vue3前端开发,provide和enject的基础练习,跨层级传递数据&#xff01; 声明:provide虽然可以跨层级传递&#xff0c;但是依旧是需要由上向下的方向传递。根传子的方向。 <script setup> import {onMounted, ref} from vue import Base from ./components/Base.vue impor…

ssrf漏洞代码审计之douphp解析(超详细)

1.进入douphp的安装界面 www.douphp.com/install/ 由此可知安装界面已经被锁定了&#xff0c;但是由于install.lock是可控的&#xff0c;删除了install.lock后即可进行安装&#xff0c;所以我们现在的目的就是找到怎么去删除install.lock的方法。 要删除目标网站的任意文件&a…

人工智能-机器学习-深度学习-分类与算法梳理

人工智能-机器学习-深度学习-分类与算法梳理 目前人工智能的概念层出不穷&#xff0c;容易搞混&#xff0c;理清脉络&#xff0c;有益新知识入脑。 为便于梳理&#xff0c;本文只有提纲&#xff0c;且笔者准备仓促&#xff0c;敬请勘误&#xff0c;不甚感激。 请看右边目录索引…

动态规划Day14(子序列第二天)

目录 1143.最长公共子序列 看到题目的第一想法 看到代码随想录之后的想法 自己实现过程中遇到的困难 1035.不相交的线 看到题目的第一想法 看到代码随想录之后的想法 自己实现过程中遇到的困难 53. 最大子序和 看到题目的第一想法 …

网络编程01 常见名词的一些解释

本文将讲解网络编程的一些常见名词以及含义 在这之前让我们先唠一唠网络的产生吧,其实网络的产生也拯救了全世界 网络发展史 网络的产生是在美苏争霸的期间,实际上双方都持有核武器,希望把对方搞垮的同时不希望自己和对方两败俱伤. 希望破坏对方的核武器发射,这就涉及到三个方面…

【Vue】vue项目中Uncaught runtime errors:怎样关闭

vue项目中Uncaught runtime errors:怎样关闭 一、背景描述二、报错原因三、解决方案3.1 只显示错误信息不全屏覆盖3.2 取消全屏覆盖 四、参考资料 一、背景描述 项目本来运行的好好&#xff0c;换了个新的浏览器&#xff0c;新的Chrome浏览器版本号是116.0.5845.97&#xff08…

【Linux】Linux进程间通信(四)

​ ​&#x1f4dd;个人主页&#xff1a;Sherry的成长之路 &#x1f3e0;学习社区&#xff1a;Sherry的成长之路&#xff08;个人社区&#xff09; &#x1f4d6;专栏链接&#xff1a;Linux &#x1f3af;长路漫漫浩浩&#xff0c;万事皆有期待 上一篇博客&#xff1a;【Linux】…

flask分页宏增加更多参数

背景&#xff1a;我正在开发一个博客&#xff0c;核心的两个model是文章和文章类别。 现在想要实现的功能是&#xff1a;点击一个文章类别&#xff0c;以分页的形式显示这个文章类别下的所有文章&#xff0c;类似这种效果。 参考的书中分页宏只接受页数这一个参数&#xff0c;…

NLP论文阅读记录 - 2021 | WOS MAPGN:用于序列到序列预训练的掩码指针生成器网络

文章目录 前言0、论文摘要一、Introduction1.1目标问题1.2相关的尝试1.3本文贡献 二.前提三.本文方法四 实验效果4.1数据集4.2 对比模型4.3实施细节4.4评估指标4.5 实验结果4.6 细粒度分析 五 总结思考 前言 MAPGN: MASKED POINTER-GENERATOR NETWORK FOR SEQUENCE-TO-SEQUENCE…

python常用库

常见模块解析 1. math库 数学函数 函数返回值 ( 描述 )abs(x)返回数字的绝对值&#xff0c;如abs(-10) 返回 10ceil(x)返回数字的上入整数&#xff0c;如math.ceil(4.1) 返回 5cmp(x, y)如果 x < y 返回 -1, 如果 x y 返回 0, 如果 x > y 返回 1。 **Python 3 已废弃…

Pandas.DataFrame.groupby() 数据分组(数据透视、分类汇总) 详解 含代码 含测试数据集 随Pandas版本持续更新

关于Pandas版本&#xff1a; 本文基于 pandas2.1.2 编写。 关于本文内容更新&#xff1a; 随着pandas的stable版本更迭&#xff0c;本文持续更新&#xff0c;不断完善补充。 Pandas稳定版更新及变动内容整合专题&#xff1a; Pandas稳定版更新及变动迭持续更新。 Pandas API参…
最新文章