libtorch学习第六

构建卷积网络

#include<torch/torch.h>
#include<torch/script.h>
#include<iostream>

using std::cout; using std::endl;

class LinearBnReluImpl : public torch::nn::Module
{
private:
	torch::nn::Linear ln{ nullptr };
	torch::nn::BatchNorm1d bn{ nullptr };

public:
	LinearBnReluImpl(int input_features, int out_features);
	torch::Tensor forward(torch::Tensor x);
};
TORCH_MODULE(LinearBnRelu);

inline torch::nn::Conv2dOptions conv_options(
	int64_t in_planes, int64_t out_planes, int64_t kernel_size,
	int64_t stride = 1, int64_t padding = 0, bool with_bias = false
)
{
	torch::nn::Conv2dOptions conv_options = torch::nn::Conv2dOptions(in_planes, out_planes, kernel_size);
	conv_options.stride(stride);
	conv_options.padding(padding);
	conv_options.bias(with_bias);

	return conv_options;
}

class ConvReluBnImpl : public torch::nn::Module
{
private:
	torch::nn::Conv2d conv{ nullptr };
	torch::nn::BatchNorm2d bn{ nullptr };
	
public:
	ConvReluBnImpl(int input_channel, int output_channel, int kernel_size, int stride, int padding=1);
	torch::Tensor forward(torch::Tensor x);
};
TORCH_MODULE(ConvReluBn);


class MLP : public torch::nn::Module
{
private:
	int mid_features[3] = { 32, 64, 128 };
	LinearBnRelu ln1{ nullptr };
	LinearBnRelu ln2{ nullptr };
	LinearBnRelu ln3{ nullptr };
	torch::nn::Linear out_ln{ nullptr };

public:
	MLP(int in_features, int out_features);
	torch::Tensor forward(torch::Tensor x);
};

class plainCNN : public torch::nn::Module
{
private:
	int mid_channels[3]{ 32,64,128 };
	ConvReluBn conv1{ nullptr };
	ConvReluBn down1{ nullptr };
	ConvReluBn conv2{ nullptr };
	ConvReluBn down2{ nullptr };
	ConvReluBn conv3{ nullptr };
	ConvReluBn down3{ nullptr };
	torch::nn::Conv2d out_conv{ nullptr };

public:
	plainCNN(int in_channels, int out_channels);
	torch::Tensor forward(torch::Tensor x);
};


int main()
{
	plainCNN c(3, 2);

	auto x = torch::rand({ 1,3,224,224 }, torch::kFloat);
	//cout << x.sizes() << endl;
	
	auto a = c.forward(x);
	cout <<"[in Main]: "<< a.sizes() << endl;

	return 0;
}

LinearBnReluImpl::LinearBnReluImpl(int input_features, int out_features)
{
	ln = register_module("ln", torch::nn::Linear(torch::nn::LinearOptions(input_features, out_features)));
	bn = register_module("bn", torch::nn::BatchNorm1d(out_features));
}

torch::Tensor LinearBnReluImpl::forward(torch::Tensor x)
{
	x = torch::relu(ln->forward(x));
	x = bn(x);
	return x;
}

ConvReluBnImpl::ConvReluBnImpl(int input_channel, int output_channel, int kernel_size, int stride, int padding)
{
	conv = register_module("conv", torch::nn::Conv2d(conv_options(input_channel, output_channel, kernel_size, stride, padding)));
	bn = register_module("bn", torch::nn::BatchNorm2d(output_channel));
}

torch::Tensor ConvReluBnImpl::forward(torch::Tensor x)
{
	x = torch::relu(conv->forward(x));
	x = bn(x);
	return x;
}

MLP::MLP(int in_features, int out_features)
{
	ln1 = LinearBnRelu(in_features, mid_features[0]);
	ln2 = LinearBnRelu(mid_features[0], mid_features[1]);
	ln3 = LinearBnRelu(mid_features[1], mid_features[2]);
	out_ln = torch::nn::Linear(mid_features[2], out_features);

	ln1 = register_module("ln1", ln1);
	ln2 = register_module("ln2", ln2);
	ln3 = register_module("ln3", ln3);
	out_ln = register_module("out_ln", out_ln);
}

torch::Tensor MLP::forward(torch::Tensor x)
{
	x = ln1->forward(x);
	x = ln2->forward(x);
	x = ln3->forward(x);
	x = out_ln->forward(x);
	return x;
}

plainCNN::plainCNN(int in_channels, int out_channels)
{
	conv1 = ConvReluBn(in_channels, mid_channels[0], 3, 1);
	down1 = ConvReluBn(mid_channels[0], mid_channels[0], 3, 2);
	conv2 = ConvReluBn(mid_channels[0], mid_channels[1], 3,1);
	down2 = ConvReluBn(mid_channels[1], mid_channels[1], 3, 2);
	conv3 = ConvReluBn(mid_channels[1], mid_channels[2], 3,1);
	down3 = ConvReluBn(mid_channels[2], mid_channels[2], 3, 2);
	out_conv = torch::nn::Conv2d(conv_options(mid_channels[2], out_channels, 3));

	conv1 = register_module("conv1", conv1);
	down1 = register_module("down1", down1);
	conv2 = register_module("conv2", conv2);
	down2 = register_module("down2", down2);
	conv3 = register_module("conv3", conv3);
	down3 = register_module("down3", down3);
	out_conv = register_module("out_conv", out_conv);
}

torch::Tensor plainCNN::forward(torch::Tensor x)
{
	x = conv1->forward(x);
	cout << x.sizes() << endl;
	x = down1->forward(x);
	cout << x.sizes() << endl;
	x = conv2->forward(x);
	cout << x.sizes() << endl;
	x = down2->forward(x);
	cout << x.sizes() << endl;
	x = conv3->forward(x);
	cout << x.sizes() << endl;
	x = down3->forward(x);
	cout << x.sizes() << endl;
	x = out_conv->forward(x);
	cout << x.sizes() << endl;
	return x;
}


结果

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/345443.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

[Python] glob内置模块介绍和使用场景(案例)

Unix glob是一种用于匹配文件路径的模式&#xff0c;它可以帮助我们快速地找到符合特定规则的文件。在本文中&#xff0c;我们将介绍glob的基本概念、使用方法以及一些实际应用案例。 glob介绍 Glob(Global Match)是Unix和类Unix系统中的一种文件名扩展功能&#xff0c;它可以…

分布式锁的实现方式

分布式锁是指分布式环境下&#xff0c;系统部署在多个机器中&#xff0c;实现多进程分布式互斥的一种锁。实现分布式锁有三种主流方式&#xff0c;接下来一一盘点。 盘点之前要说说选择时的优缺点 数据库实现的锁表完全不推荐。 Redis分布式锁性能优于ZooKeeper&#xff0c;因…

01、领域驱动设计:微服务设计为什么要选择DDD总结

目录 1、前言 2、软件架构模式的演进 3、微服务设计和拆分的困境 4、为什么 DDD适合微服务 5、DDD与微服务的关系 6、总结 1、前言 我们知道&#xff0c;微服务设计过程中往往会面临边界如何划定的问题&#xff0c;不同的人会根据自己对微服务的理 解而拆分出不同的微服…

通过 GScan 工具自动排查后门

一、简介 GScan 是一款为安全应急响应提供便利的工具&#xff0c;自动化监测系统中常见位置。 工具运行环境&#xff1a;CentOS (6、7) python (2.x、3.x) 工具检查项目&#xff1a; 1、主机信息获取 2、系统初始化 alias 检查 3、文件类安全扫描 3.1、系统重要文件完整行…

JS进阶-深入对象(二)

拓展&#xff1a;深入对象主要介绍的是Js的构造函数&#xff0c;实例成员&#xff0c;静态成员&#xff0c;其中构造函数和Java种的构造函数用法相似&#xff0c;思想是一样的&#xff0c;但静态成员和实例成员和java种的有比较大的差别&#xff0c;需要认真理解 • 创建对象三…

Switch用法以及新特性-最全总结版

本篇文章参考了大佬文章&#xff0c;感谢大佬无私分享&#xff1a; http://t.csdnimg.cn/MjZnX http://t.csdnimg.cn/QFg0x 目录 一、Switch用法&#xff1a;JDK7及以前 1.1、举例一&#xff1a; 1.2、举例二&#xff1a; 二、Switch穿透&#xff1a; 2.1、举例&#xf…

三极管的奥秘:如何用小电流控制大电流

双极性晶体管&#xff08;英语&#xff1a;bipolar transistor&#xff09;&#xff0c;全称双极性结型晶体管&#xff08;bipolar junction transistor, BJT&#xff09;&#xff0c;俗称三极管&#xff0c;是一种具有三个引脚的电子元器件。 本文是讲述的是三极管的基础知识…

基于openssl v3搭建ssl安全加固的c++ tcpserver

1 概述 tcp server和tcp client同时使用openssl库&#xff0c;可对通信双方流通的字节序列进行加解密&#xff0c;保障通信的安全。本文以c编写的tcp server和tcp client为例子&#xff0c;openssl的版本为v3。 2 安装openssl v3 2.1 安装 perl-IPC-Cmd openssl项目中的co…

企业内部知识库搭建教程,赶紧收藏起来

在企业运营中&#xff0c;内部知识库搭建是一项重要的挑战&#xff0c;并需要合理的规划与管理。尤其对于中大型企业&#xff0c;内部知识库能够提高工作效率&#xff0c;减轻员工工作压力与突发事件的处理的困扰。下面给大家提供一份完整的内部知识库搭建教程&#xff0c;快看…

UE4运用C++和框架开发坦克大战教程笔记(十五)(第46~48集)

UE4运用C和框架开发坦克大战教程笔记&#xff08;十五&#xff09;&#xff08;第46~48集&#xff09; 46. 批量加载 UClass 功能测试批量加载多个同类 UClass 资源 47. 创建单个资源对象测试加载并创建单个 UClass 资源对象 48. 创建同类资源对象 46. 批量加载 UClass 功能 逻…

Leetcode1143. 最长公共子序列

解题思路 求两个数组或者字符串的最长公共子序列问题&#xff0c;肯定是要用动态规划的。下面的题解并不难&#xff0c;你肯定能看懂。 首先&#xff0c;区分两个概念&#xff1a;子序列可以是不连续的&#xff1b;子数组&#xff08;子字符串&#xff09;需要是连续的&#xf…

rabbitmq基础-java-3、Fanout交换机

1、简介 Fanout&#xff0c;英文翻译是扇出。 2、 特点 1&#xff09; 可以有多个队列 2&#xff09; 每个队列都要绑定到Exchange&#xff08;交换机&#xff09; 3&#xff09; 生产者发送的消息&#xff0c;只能发送到交换机 4&#xff09; 交换机把消息发送给绑定过的…

3d模型怎么分辨材质?--模大狮模型网

在3D模型中&#xff0c;通常可以通过以下几种方式来分辨材质&#xff1a; 视觉检查&#xff1a;在3D渲染视图或预览窗口中&#xff0c;您可以直接观察模型的外观来区分不同的材质。不同的材质可能具有不同的颜色、纹理、反射率等特征&#xff0c;因此通过直观的视觉检查&#x…

网络通信课程总结(小飞有点东西)

27集 局域网通信&#xff1a;用MAC地址 跨局域网通信&#xff1a;用IP地址&#xff08;MAC地址的作用只是让我们找到网关&#xff09; 又因为arp技术&#xff0c;可以通过MAC地址找到IP地址&#xff0c;所以我们可以通过IP地址定位到全世界任意一台计算机。 28集 在数据链路…

C语言每日一题(47)两数相加II

力扣 445 两数相加II 题目描述 给你两个 非空 链表来代表两个非负整数。数字最高位位于链表开始位置。它们的每个节点只存储一位数字。将这两数相加会返回一个新的链表。 你可以假设除了数字 0 之外&#xff0c;这两个数字都不会以零开头。 示例1&#xff1a; 输入&#xff…

了解WPF控件:RadioButton和RepeatButton常用属性与用法(九)

掌握WPF控件&#xff1a;熟练常用属性&#xff08;九&#xff09; RadioButton 一种允许用户在一组选项中单选一个的控件。通常用于提供一组互斥的选项供用户选择。 常用属性描述Content用于设置 RadioButton 显示的文本内容。GroupName用于将多个 RadioButton 控件组合到一…

船的最小载重量-算法

说明&#xff1a;题解完全是从leetCode上拉下来的&#xff0c;在这里只是作为一个备份&#xff0c;怕之后找不着了。同时也分享给大家&#xff0c;这个题目用了一个我之前从未遇到的思路。 原题&#xff1a;船的最小载重量-leetCode1101 题目&#xff08;看懂题目了吗&#xff…

python批量处理修改pdf内容

将PDF转换为Word&#xff1a; 使用pdf2docx库中的Converter类来进行PDF转换。convert_pdf_to_docx函数接受PDF文件路径和输出的Word文档路径作为参数。通过调用Converter对象的convert方法将PDF转换为Docx格式。最后调用close方法关闭Converter对象并保存转换后的文档。 将Word…

QT下载、安装详细教程[Qt5.15及Qt6在线安装,附带下载链接]

QT5.15及QT6的下载和安装 1.下载1.1官网下载1.2国内镜像网站下载 2.安装3.软件启动及测试程序运行3.1Qt Creator&#xff08;Community&#xff09; 1.下载 QT自Qt5.15版本后不在支持离线安装包下载(非商业版本&#xff0c;开源)&#xff0c;故Qt5.15及Qt6需要使用在线安装程序…

Zephyr 源码调试

背景 调试环境对于学习源码非常重要&#xff0c;但嵌入式系统的调试环境搭建稍微有点复杂&#xff0c;需要的条件略多。本文章介绍如何在 Zephyr 提供的 qemu 上调试 Zephyr 源码&#xff0c;为后续分析 Zephyr OS 相关原理做铺垫。 环境 我的开发环境为 wsl ubuntu&#xf…
最新文章