神经网络极简入门

神经网络是深度学习的基础,正是深度学习的兴起,让停滞不前的人工智能再一次的取得飞速的发展。

其实神经网络的理论由来已久,灵感来自仿生智能计算,只是以前限于硬件的计算能力,没有突出的表现,直至谷歌的AlphaGO的出现,才让大家再次看到神经网络相较于传统机器学习的优异表现。

本文主要介绍神经网络中的重要基础概念,然后基于这些概念手工实现一个简单的神经网络。希望通过理论结合实践的方式让大家更容易的理解神经网络。

1. 神经网络是什么

神经网络就像人脑一样,整体看上去非常复杂,但是其基础组成部分并不复杂。其组成部分中最重要的就是神经元neural),sigmod函数layer)。

1.1. 神经元

神经元(neural)是神经网络最基本的元素,一个神经元包含3个部分:

  • 获取输入:获取多个输入的数据

  • 数学处理:对输入的数据进行数学计算

  • 产生输出:计算后多个输入数据变成一个输出数据

image.png

从上图中可以看出,神经元中的处理有2个步骤。第一个步骤:从蓝色框变成红色框,是对输入的数据进行加权计算后合并为一个值(N)。N=x1w1+x2w2𝑁=𝑥1𝑤1+𝑥2𝑤2 其中,w1,w2𝑤1,𝑤2分别是输入数据x1,x2𝑥1,𝑥2的权重。一般在计算N𝑁的过程中,除了权重,还会加上一个偏移参数b𝑏,最终得到:N=x1w1+x2w2+b𝑁=𝑥1𝑤1+𝑥2𝑤2+𝑏

第二个步骤:从红色框变成绿色框,通过sigmoid函数是对N进一步加工得到的神经元的最终输出(M)。

1.2. sigmoid函数

sigmoid函数也被称为S函数,因为的形状类似S形

image.png

它是神经元中的重要函数,能够将输入数据的值映射到(0,1)(0,1)之间。最常用的sigmoid函数是 f(x)=11+e−x𝑓(𝑥)=11+𝑒−𝑥,当然,不是只有这一种sigmoid函数。

至此,神经元通过两个步骤,就把输入的多个数据,转换为一个(0,1)(0,1)之间的值。

1.3. 层

多个神经元可以组合成一层,一个神经网络一般包含一个输入层和一个输出层,以及多个隐藏层。

image.png

比如上图中,有2个隐藏层,每个隐藏层中分别有4个和2个神经元。实际的神经网络中,隐藏层数量和其中的神经元数量都是不固定的,根据模型实际的效果来进行调整。

1.4. 网络

通过神经元和层的组合就构成了一个网络,神经网络的名称由此而来。神经网络可大可小,可简单可复杂,不过,太过简单的神经网络模型效果一般不会太好。

因为一只果蝇就有10万个神经元,而人类的大脑则有大约1000亿个神经元,这就是为什么训练一个可用的神经网络模型需要庞大的算力,这也是为什么神经网络的理论1943年就提出了,但是基于深度学习的AlphaGO却诞生于2015年

2. 实现一个神经网络

了解上面的基本概念只能形成一个感性的认知。下面通过自己动手实现一个最简单的神经网络,来进一步认识神经元sigmoid函数以及隐藏层是如何发挥作用的。

2.1. 准备数据

数据使用sklearn库中经典的鸢尾花数据集,这个数据集中有3个分类的鸢尾花,每个分类50条数据。为了简化,只取其中前100条数据来使用,也就是取2个分类的鸢尾花数据。

from sklearn.datasets import load_iris

ds = load_iris(as_frame=True, return_X_y=True)
data = ds[0].iloc[:100]
target = ds[1][:100]

print(data)
print(target)

image.png

变量data100条数据,每条数据包含4个属性,分别是花萼的宽度和长度,花瓣的宽度和长度。

image.png

变量target中也是100条数据,只有0和1两种值,表示两种不同种类的鸢尾花。

2.2. 实现神经元

准备好了数据,下面开始逐步实现一个简单的神经网络。首先,实现最基本的单元--神经元。本文第一节中已经介绍了神经元中主要的2个步骤,分别计算出N𝑁和M𝑀。

image.png

计算N𝑁时,依据每个输入元素的权重(w1,w2𝑤1,𝑤2)和整体的偏移b𝑏;计算M𝑀时,通过sigmoid函数。

def sigmoid(x):
    return 1 / (1 + np.exp(-1 * x))

@dataclass
class Neuron:
    weights: list[float] = field(default_factory=lambda: [])
    bias: float = 0.0
    N: float = 0.0
    M: float = 0.0

    def compute(self, inputs):
        self.N = np.dot(self.weights, inputs) + self.bias
        self.M = sigmoid(self.N)
        return self.M

上面的代码中,Neuron类表示神经元,这个类有4个属性:其中属性weightsbias是计算N𝑁时的权重和偏移;属性NM分别是神经元中两步计算的结果。

Neuron类的compute方法根据输入的数据计算神经元的输出。

2.3. 实现神经网络

神经元实现之后,下面就是构建神经网络。我们的输入数据是带有4个属性(花萼的宽度和长度,花瓣的宽度和长度)的鸢尾花数据,所以神经网络的输入层有4个值(x1,x2,x3,x4𝑥1,𝑥2,𝑥3,𝑥4)。

为了简单起见,我们的神经网络只构建一个隐藏层,其中包含3个神经元。最后就是输出层,输出层最后输出一个值,表示鸢尾花的种类。

由此形成的简单神经网络如下图所示:

image.png

实现的代码:

@dataclass
class MyNeuronNetwork:
    HL1: Neuron = field(init=False)
    HL2: Neuron = field(init=False)
    HL3: Neuron = field(init=False)
    O1: Neuron = field(init=False)

    def __post_init__(self):
        self.HL1 = Neuron()
        self.HL1.weights = np.random.dirichlet(np.ones(4))
        self.HL1.bias = np.random.normal()

        self.HL2 = Neuron()
        self.HL2.weights = np.random.dirichlet(np.ones(4))
        self.HL2.bias = np.random.normal()

        self.HL3 = Neuron()
        self.HL3.weights = np.random.dirichlet(np.ones(4))
        self.HL3.bias = np.random.normal()

        self.O1 = Neuron()
        self.O1.weights = np.random.dirichlet(np.ones(3))
        self.O1.bias = np.random.normal()

    def compute(self, inputs):
        m1 = self.HL1.compute(inputs)
        m2 = self.HL2.compute(inputs)
        m3 = self.HL3.compute(inputs)

        output = self.O1.compute([m1, m2, m3])
        return output

MyNeuronNetwork类是自定义的神经网络,其中的属性是4个神经元HL1HL2HL3隐藏层的3个神经元;O1输出层的神经元。

__post__init__函数是为了初始化各个神经元。因为输入层是4个属性(x1,x2,x3,x4𝑥1,𝑥2,𝑥3,𝑥4),所以神经元HL1HL2HL3weights初始化为4个随机数组成的列表,偏移(bias)用一个随时数来初始化。

对于神经元O1,它的输入是隐藏层的3个神经元,所以它的weights初始化为3个随机数组成的列表,偏移(bias)还是用一个随时数来初始化。

最后还有一个compute函数,这个函数描述的就是整个神经网络的计算过程。首先,根据输入层(x1,x2,x3,x4𝑥1,𝑥2,𝑥3,𝑥4)的数据计算隐藏层的神经元(HL1HL2HL3);然后,以隐藏层的神经元(HL1HL2HL3)的输出作为输出层的神经元(O1)的输入,并将O1的计算结果作为整个神经网络的输出。

2.4. 训练模型

上面的神经网络中各个神经元的中的参数(主要是weightsbias)都是随机生成的,所以直接使用这个神经网络,效果一定不会很好。所以,我们需要给神经网络(MyNeuronNetwork类)加一个训练函数,用来训练神经网络中各个神经元的参数(也就是个各个神经元中的weightsbias)。

@dataclass
class MyNeuronNetwork:
    # 略...

    def train(self, data: pd.DataFrame, target: pd.Series):
        ## 使用 随机梯度下降算法来训练
        pass

上面的train函数有两个参数data(训练数据)和target(训练数据的标签)。我们使用随机梯度下降算法来训练模型的参数。这里略去了具体的代码,完整的代码可以在文章的末尾下载。

此外,再实现一个预测函数predict,传入测试数据集,然后用我们训练好的神经网络模型来预测测试数据集的标签。

@dataclass
class MyNeuronNetwork:
    # 略...
    
    def predict(self, data: pd.DataFrame):

        results = []
        for idx, row in enumerate(data.values):
            pred = self.compute(row)
            results.append(round(pred))

        return results

2.5. 验证模型效果

最后就是验证模型的效果。

def main():
    # 加载数据
    ds = load_iris(as_frame=True, return_X_y=True)

    # 只用前100条数据
    data = ds[0].iloc[:100]
    target = ds[1][:100]

    # 划分训练数据,测试数据
    # test_size=0.2 表示80%作为训练数据,20%作为测试数据
    X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.2)

    # 创建神经网络
    nn = MyNeuronNetwork()

    # 用训练数据集来训练模型
    nn.train(X_train, y_train)

    # 检验模型
    # 用训练过的模型来预测测试数据的标签
    results = nn.predict(X_test)
    df = pd.DataFrame()
    df["预测值"] = results
    df["实际值"] = y_test.values
    print(df)

运行结果可以看出,模型的效果还不错,20条测试数据的预测结果都正确。

image.png

3. 总结

本文中的的神经网络示例是为了介绍神经网络的一些基本概念,所以对神经网络做了尽可能的简化,为了方便去手工实现。

而实际环境中的神经网络,不仅神经元的个数,隐藏层的数量极其庞大,而且其计算和训练的方式也很复杂,手工去实现不太可能,一般会借助TensorFlowKerasPyTorch等等知名的python深度学习库来帮助我们实现。

文章转载自:wang_yb

原文链接:https://www.cnblogs.com/wang_yb/p/18176563

体验地址:引迈 - JNPF快速开发平台_低代码开发平台_零代码开发平台_流程设计器_表单引擎_工作流引擎_软件架构

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/602125.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

响应式编程Spring Reactor探索

一,介绍 响应式编程(Reactive Programming),简单来说是一种生产者只负责生成并发出数据/事件,消费者来监听并负责定义如何处理数据/事件的变化传递方式的编程思想。 响应式编程借鉴了Reactor设计模式,我们…

神器:jQuery一键转换为纯净JavaScript代码

我的新书《Android App开发入门与实战》已于2020年8月由人民邮电出版社出版,欢迎购买。点击进入详情 该工具将 jQuery 代码转换为现代、高效的 JavaScript。它允许您用纯 JavaScript 替换 jQuery,同时保持原始代码不变。 虽然 jQuery 一直是 Web 开发中…

防火墙技术基础篇:什么是包过滤技术

什么是防火墙包过滤技术 当数据在网络中传输时,它们被分割成小的单元,称为数据包。防火墙的包过滤是一种基本的网络安全技术,用于检查这些数据包并根据预定义的规则决定是否允许它们通过防火墙。 防火墙包过滤是一种关键的网络安全技术&am…

在下游市场需求带动下 轮胎模具市场规模逐渐扩大

在下游市场需求带动下 轮胎模具市场规模逐渐扩大 轮胎模具是通过硫化、成型等工序生产各种轮胎的一种工具。轮胎模具是生产轮胎的关键设备之一,其性能直接影响到轮胎的耐用性和安全性。根据花纹加工工艺不同,轮胎模具加工工艺可分为精密铸造工艺、数控雕…

炒美股怎么开户?

近年来,随着国内投资者对境外投资需求的不断增长,炒美股逐渐成为许多投资者的选择。然而,随着监管政策的不断完善,传统的互联网券商开户方式已经不再适用。那么,对于想要入场美股市场的投资者来说,该如何开…

实现左上角的固定视口但是网格以图片中心放大缩小

仅仅修改了showbk() 函数部分,增加bkv4 直接采样,然后粘贴到左上角,实现多余部分裁剪,形成视口内放大缩小 // 程序:2D RPG 地图编辑器与摄像机追随 // 作者:bilibili 民用级脑的研发…

windows10打印机共享完美解决方案

提到文件共享大家并不陌生,相关的还有打印机共享,这个多见于单位、复印部,在一个区域网里多台电脑共用一台打印机,打印资料非常方便,就包括在家里,我们现在一般都会有多台电脑或设备,通过家庭网络联接,如果共享一台打印机的话也是件便捷的事。 但是随着操作系统的更新…

Win11任务栏通知很不明显的解决方案

Win11流行起来后,不少用户抱怨Win11的任务栏通知闪烁的颜色很不明显,经常微信来消息了看不到。虽然有右下角的微信图标会闪烁,但是提醒舒适度还是觉得不如Win10舒服显眼。 默认的颜色是这样子的,可以看得出Win11的任务栏提醒颜…

QGraphicsItem的prepareGeometryChange 和 update方法区别

prepareGeometryChange 这个函数用于为图形的几何形状变化做准备。在改变一个项目的边界矩形之前调用此函数,以保持 QGraphicsScene 的索引是最新的。如果必要的话,prepareGeometryChange() 会调用 update()。QGraphicsScene认为所有图元的boundingRect…

Python数据分析常用模块的介绍与使用

Python数据分析模块 前言一、Numpy模块Numpy介绍Numpy的使用Numpy生成数组ndarrayarray生成数组arange生成数组random生成数组其他示例 关于randint示例1示例2 关于rand Numpy数组统计方法示例 二、Pandas模块pandas介绍Series示例 DataFrame示例 三、其他模块Matplotlib/Seabo…

Apache Knox 2.0.0使用

目录 介绍 使用 gateway-site.xml users.ldif my_hdfs.xml my_yarn.xml 其它 介绍 The Apache Knox Gateway is a system that provides a single point of authentication and access for Apache Hadoop services in a cluster. The goal is to simplify Hadoop securit…

【Qt】Qt开发中常用命名规范、快捷键和窗口坐标体系详解

Qt是一款强大的跨平台C应用程序开发框架,为了提高代码的可读性和可维护性,遵循一定的命名规范是非常重要的。此外,Qt Creator提供了许多快捷键和便捷功能,能够提高开发效率。本文将介绍Qt开发中常用的命名规范、快捷键以及窗口坐标…

来聊聊Java项目分层规范

写在文章开头 近期和读者交流聊到项目规范,借着这个机会我们不妨聊聊主流Java项目是如何进行分层的。 Hi,我是 sharkChili ,是个不断在硬核技术上作死的 java coder ,是 CSDN的博客专家 ,也是开源项目 Java Guide 的维…

[华为OD]C卷 运输时间 200 动态规划

题目: M辆车需要在一条不能超车的单行道到达终点,起点到终点的距离为N。速度快的车追上前车 后,只能以前车的速度继续行驶,求最后一车辆到达目的地花费的时间。 注意: 每辆车固定间隔1小时出发,比如第…

静态NAT

哈喽!各位小伙伴们好久不见,最近由于工作的原因断更了一段时间,不过最近我都会把这些给补上,今天我们来学习一个简单的知识——静态NAT转换。 第一章 什么是NAT技术? 网络地址转换技术NAT(Networ…

红帽发布Red Hat Enterprise Linux AI(RHEL AI)

红帽 2024 峰会正在科罗拉多州丹佛市举行…鉴于当前的时代背景,人工智能(AI)在此次峰会上占据了重要位置,因此红帽公司(Red Hat)也不甘人后宣布推出 RHEL AI。 红帽公司今天发布了 Red Hat Enterprise Lin…

优化电脑空间清理电脑占用磁盘空间垃圾

1. 清理磁盘 右下角放大镜,搜索 此电脑 点击要清理的磁盘 ,比如点击C盘,右键属性,常规选项卡,点击清理磁盘, 和点击清理系统文件 1.1 优化磁盘 右下角放大镜,搜索 此电脑 点击要清理的磁盘 &…

RUST 编程语言使构建更安全的软件变得更加容易。RUST ALL THE THINGS 需要什么?

人不走空 🌈个人主页:人不走空 💖系列专栏:算法专题 ⏰诗词歌赋:斯是陋室,惟吾德馨 目录 🌈个人主页:人不走空 💖系列专栏:算法专题 ⏰诗词歌…

基于Spring Ai 快速创建一个AI会话

文章目录 1、创建SpringBoot项目2、引入依赖3、修改配置文件4、一个简单的会话 前期准备 在OpenAI 注册页面创建帐户并在API 密钥页面生成令牌。 Spring AI 项目定义了一个配置属性,您应该将其设置为从 openai.com 获取的spring.ai.openai.api-key值 代码托管于gite…

sql查询数据语句

select * from 表名 where 列名 某个数据名字 查询某个表名中的某列是否有某个数据
最新文章