政安晨:【使用 TensorFlow 和 Keras 为结构化数据构建和训练神经网络】(一)—— 单个神经元

政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏政安晨的机器学习笔记

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

咱们在这篇文章中将了解与练习深度学习的构建模块--线性单元。

开始深度学习的入门练习

利用这个系列的文章,您即将学习到构建自己的深度神经网络所需的一切。

通过使用Keras和Tensorflow,您将学习以下内容:

× 创建一个全连接的神经网络架构

× 将神经网络应用于两个经典的机器学习问题:回归和分类

× 使用随机梯度下降训练神经网络

× 通过使用dropout、批归一化(batch normalization)和其他技术来提高性能。

咱们这个系列文章将通过完整的实例向您介绍这些主题,然后在练习中,您将更深入地探索这些主题,并将它们应用于真实世界的数据集中。

现在让我们开始!


深度学习究竟是什么?

最近几年人工智能领域最令人印象深刻的进展之一是深度学习领域。

自然语言翻译、图像识别和游戏玩耍等任务,深度学习模型已经接近甚至超过了人类水平的表现。

那么什么是深度学习呢?

深度学习是一种机器学习方法,其特点是深层次的计算堆叠。

这种计算的深度使得深度学习模型能够解开最具挑战性的真实世界数据集中的复杂和层次化模式。

通过其强大的能力和可伸缩性,神经网络已成为深度学习的定义模型。

神经网络由神经元组成,每个神经元仅执行简单的计算。

神经网络的强大之处在于这些神经元之间可以形成复杂的连接。

线性单元

让我们从神经网络的基本组件开始:单个神经元。

作为一个图示,具有一个输入的神经元(或单元)如下所示:

输入是x。它与神经元的连接具有权重w。

当一个值通过连接流动时,你将该值乘以连接的权重对于输入x,到达神经元的值为w * x。神经网络通过修改它的权重来"学习"。

b是一种特殊的重量,我们称之为偏置偏置没有与之关联的输入数据

相反,我们在图中放置一个1,以使到达神经元的值只是b(因为1 * b = b)偏置使神经元能够独立于其输入来修改输出。

y是神经元最终输出的值为了得到输出,神经元将通过其连接接收到的所有值相加起来。

这个神经元的激活函数y = w * x + b,也可以用公式表示为 y=wx+b。

这个公式 y=wx+b 看起来熟悉吗?

这是一条直线的方程这是斜率截距方程,其中 w 表示斜率,b 表示 y 轴截距。

例子 - 以线性单元为模型

线性单元是一种常见的机器学习模型,用于解决回归和分类问题。它是一种简单的数学模型,可以表示为输入特征的加权和加上一个偏置项。

线性单元的数学表达式可以表示为

y = w1x1 + w2x2 + ... + wn*xn + b

其中,y是输出变量,x1, x2, ..., xn是输入特征,w1, w2, ..., wn是对应特征的权重,b是偏置项。

线性单元的训练目标是通过调整权重和偏置项,使得模型的输出与训练样本的真实值之间的差距尽可能小。

通过训练数据集,可以使用不同的优化算法(如梯度下降)来调整权重和偏置项,使得模型尽可能拟合训练数据,并在测试数据上有较好的泛化能力。

线性单元的优点是简单、易解释,并且在某些问题上具有很好的表现。然而,它也有一些限制,例如不能处理非线性关系、对异常值敏感等。

总之,线性单元是机器学习中常用的模型之一,它在许多问题上都能够提供良好的性能和解释能力。

尽管个体神经元通常只能作为更大网络的一部分发挥作用,但从单个神经元模型作为基线开始通常是有用的。单个神经元模型是线性模型。

让我们思考一下如何在一个像80种谷物的数据集上进行训练。

我们以'糖分'(每份的克数)作为输入,以'卡路里'(每份的卡路里)作为输出,可能会发现偏差是b=90,权重是w=2.5。我们可以用这个模型来估计每份含有5克糖分的谷物的卡路里含量,方法如下:

而且,根据我们的公式检验,我们有卡路里=2.5×5+90=102.5,正如我们所预期的一样。

多重输入

80种麦片数据集不仅仅含有“糖分”这一特征。如果我们想要扩展我们的模型,包括像纤维或蛋白质含量这样的特征,这很容易实现。我们只需为每个额外的特征添加更多的输入连接到神经元。为了找到输出,我们需要将每个输入与其连接权重相乘,然后将它们全部相加。

这个神经元的公式是 y=w0x0+w1x1+w2x2+b具有两个输入的线性单元可以拟合一个平面,而具有更多输入的单元可以拟合一个超平面。

Keras中的线性单元

在Keras中创建模型的最简单方法是通过keras.Sequential,它将神经网络以一系列层的堆叠形式创建。

我们可以使用一个密集层(在以后的文章中咱们将学习更多)来创建像上面那样的模型。

我们可以这样定义一个线性模型,接收三个输入特征('sugars','fiber'和'protein'),并产生一个输出('calories'):

from tensorflow import keras
from tensorflow.keras import layers

# Create a network with 1 linear unit
model = keras.Sequential([
    layers.Dense(units=1, input_shape=[3])
])

使用第一个参数units我们定义了我们想要的输出数量。在这种情况下,我们只是预测“卡路里”,所以我们将使用units=1。

使用第二个参数input_shape,我们告诉Keras输入的维度。设置input_shape = [3]确保模型将接受三个特征作为输入('sugars','fiber'和'protein')。

这个模型现在已经准备好拟合训练数据了!

为什么input_shape是一个Python列表?

在这个系列的文章中,我们将使用表格数据,类似于Pandas dataframe中的数据

我们将为数据集中的每个特征设置一个输入。特征以列的形式排列,所以我们将始终有input_shape=[num_columns]。

Keras在这里使用列表的原因是为了允许使用更复杂的数据集。例如,图像数据可能需要三个维度:[高度,宽度,通道]。

做个练习

在本文中,我们学习了神经网络的构建块:线性单元

我们看到,只有一个线性单元的模型可以将线性函数拟合到一个数据集上(等同于线性回归)

在这个练习中,你将构建一个线性模型,并在Keras中进行一些实践工作。

这个练习我们是在Kaggle中执行的,使用的数据已经由Kaggle准备好

# Setup plotting
import matplotlib.pyplot as plt

plt.style.use('seaborn-whitegrid')
# Set Matplotlib defaults
plt.rc('figure', autolayout=True)
plt.rc('axes', labelweight='bold', labelsize='large',
       titleweight='bold', titlesize=18, titlepad=10)

# Setup feedback system
from learntools.core import binder
binder.bind(globals())
from learntools.deep_learning_intro.ex1 import *

背景介绍

红葡萄酒质量数据集包括大约1600瓶葡萄牙红葡萄酒的理化测量数据。此外,还包括每瓶酒的品质评分,评分是通过盲品测试得出的。

首先,运行下一个单元格以显示该数据集的前几行数据。

import pandas as pd

red_wine = pd.read_csv('../input/dl-course-data/red-wine.csv')
red_wine.head()

您可以使用shape属性获取数据帧(或Numpy数组)的行数和列数。

red_wine.shape # (rows, columns)

1. 输入形状

我们能够从理化测量中准确预测葡萄酒的品质吗?目标是“品质 - quality”,其余的列是特征。你会如何设置Keras模型在此任务中的input_shape参数?

# YOUR CODE HERE
input_shape = ____

# Check your answer
q_1.check()

检查当您更新了初始代码后check()将告诉您您的代码是否正确。您需要更新创建变量input_shape的代码。

# Lines below will give you a hint or solution code
#q_1.hint()
#q_1.solution()

2. 定义一个线性模型

线性模型是一种描述变量之间线性关系的数学模型。它假设因变量与自变量之间存在线性关系,并试图使用线性方程来表示这种关系。在线性模型中,因变量被假设为自变量的线性组合,通过调整模型的系数,可以找到最佳拟合于数据的线性关系。线性模型适用于许多领域,如统计学、经济学和机器学习等,它提供了一种简单而有效的方法来理解和预测变量之间的关系。

现在定义一个适合这个任务的线性模型。请注意模型应该有多少个输入和输出。

from tensorflow import keras
from tensorflow.keras import layers

# YOUR CODE HERE
model = ____

# Check your answer
q_2.check()

检查当你更新了起始代码后,check()函数将告诉你你的代码是否正确。你需要更新创建模型变量的代码。

# Lines below will give you a hint or solution code
#q_2.hint()
#q_2.solution()

3.查看权重

在内部,Keras使用张量来表示神经网络的权重

张量基本上是TensorFlow版本的Numpy数组,但有一些差异使其更适合深度学习。其中最重要的是,张量与GPU和TPU(张量处理器)加速器兼容。实际上,TPU是专门设计用于张量计算的。

一个模型的权重以张量列表的形式保存在其权重属性中。

获取您上面定义的模型的权重。(如果您愿意,您可以使用类似以下的代码显示权重: print("Weights\n{}\n\nBias\n{}".format(w, b)))。

# YOUR CODE HERE
w, b = ____

# Check your answer
q_3.check()

# Lines below will give you a hint or solution code
#q_3.hint()
#q_3.solution()

顺便说一下,Keras使用张量来表示权重,并且也使用张量来表示数据。

当你设置input_shape参数时,你告诉Keras训练数据中每个示例的数组应该具有的维度。

设置input_shape=[3]会创建一个接受长度为3的向量(例如[0.2, 0.4, 0.6])的网络。

可选:绘制一个未经训练的线性模型的输出

我们将通过接下来的文章解决的问题是回归问题,目标是预测某个数值目标。回归问题类似于“曲线拟合”问题:我们试图找到最适合数据的曲线。让我们来看一下线性模型产生的“曲线”。(你可能已经猜到了它是一条直线!

我们在之前提到过,在训练模型之前,模型的权重是随机设置的。运行下面的单元格几次,看看不同的随机初始化产生的线条。(这个练习没有编码,只是一个演示。

import tensorflow as tf
import matplotlib.pyplot as plt

model = keras.Sequential([
    layers.Dense(1, input_shape=[1]),
])

x = tf.linspace(-1.0, 1.0, 100)
y = model.predict(x)

plt.figure(dpi=100)
plt.plot(x, y, 'k')
plt.xlim(-1, 1)
plt.ylim(-1, 1)
plt.xlabel("Input: x")
plt.ylabel("Target y")
w, b = model.weights # you could also use model.get_weights() here
plt.title("Weight: {:0.2f}\nBias: {:0.2f}".format(w[0][0], b[0]))
plt.show()

演绎如下:


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/470367.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

云手机在海外电商中的应用优势

随着海外市场的不断拓展,电商行业对于高效、安全的工具需求日益增长。在这一背景下,云手机作为一种新型服务,为海外电商提供了强大的支持和便利。云手机对传统物理手机起到了非常好的延展和补充作用,拓展了更广泛的应用场景&#…

海外社交营销为什么用云手机?不用普通手机?

海外社交营销作为企业拓展海外市场的重要手段,正日益受到企业的青睐。云手机以其成本效益和全球性特征,成为海外社交营销领域的得力助手。那么,究竟是什么特性使得越来越多的企业选择利用云手机进行海外社交营销呢?下文将对此进行…

js检测数据类型方式(typeof instanceof Object.prototype.toString.call())

typeof 使用 typeof 检测数据类型,首先返回的都是一个字符串,其次字符串中包含了对应的数据类型; 缺点: typeof null "object"不能具体细分是数组、正则还是对象中其他值,使用 typeof 检测数据类型对于对…

图论02-并查集的实现(Java)

2.并查集理论基础 并查集的作用 将两个元素添加到一个集合中。 判断两个元素在不在同一个集合并查集的实现 1.DSU 类定义:DSU 类中包含一个整型数组 s 用来存储元素的父节点信息。2.DSU 构造函数: 构造函数 DSU(int size) 接受一个参数 size&#xff0…

简述JVM

文章目录 什么是JVMJVM的功能解释和运行内存管理即时编译 常见的JVM字节码文件的组成基本信息常量池字段方法属性 JVM的组成类加载器定义分类类加载器的双亲委派机制类的生命周期 运行时数据区域(JVM管理的内存)执行引擎(主要介绍垃圾回收器&…

百度交易中台之系统对账篇

作者 | 天空 导读 introduction 百度交易中台作为集团移动生态战略的基础设施,面向收银交易与清分结算场景,赋能业务、提供高效交易生态搭建。目前支持百度体系内多个产品线,主要包括:度小店、小程序、地图打车、文心一言等。本文…

webpack中常见的Loader?解决了什么问题?

一、是什么 loader 用于对模块的"源代码"进行转换,在 import 或"加载"模块时预处理文件 webpack做的事情,仅仅是分析出各种模块的依赖关系,然后形成资源列表,最终打包生成到指定的文件中。如下图所示&#…

MFC界面美化第四篇----自绘list列表(重绘列表)

1.前言 最近发现读者对我的mfc美化的专栏比较感兴趣,因此在这里进行续写,这里我会计划写几个连续的篇章,包括对MFC按钮的美化,菜单栏的美化,标题栏的美化,list列表的美化,直到最后形成一个完整…

R语言:ggplot2做柱状图,随机生成颜色。

#加载包 > library(ggplot2) > library(tidyverse) > library(openxlsx) > library(reshape2) > library(RColorBrewer) > library(randomcoloR) > library(viridis) > set.seed(1233) #设立种子数。 > palette <- distinctColorPalette(30) …

Weblogic 弱口令 后台getshell漏洞 SSRF漏洞复现 SSRF漏洞绕过IP限制 任意文件上传漏洞复现(附源码)

Weblogic 弱口令 && 后台getshell漏洞 && SSRF漏洞复现 && SSRF漏洞绕过IP限制 && 任意文件上传漏洞复现(附源码)。 利用docker环境模拟了一个真实的weblogic环境,其后台存在一个弱口令,并且前台存在任意文件读取漏洞。 分别通过这两种漏…

Elasticsearch 悬挂索引解析与管理指南

在 Elasticsearch 的实战中&#xff0c;悬挂索引是一个既常见又容易引起困扰的概念。 今天&#xff0c;我将分享一次处理集群状态为RED&#xff0c;原因为DANGLING_INDEX_IMPORTED 的实战经验&#xff0c;深入探讨悬挂索引的定义、产生原因、管理方法&#xff0c;以及如何有效处…

【JVM】如何判断堆上的对象没有被引用?

如何判断堆上的对象没有被引用&#xff1f; 常见的有两种判断方法&#xff1a;引用计数法和可达性分析法。 引用计数法会为每个对象维护一个引用计数器&#xff0c;当对象被引用时加1&#xff0c;取消引用时减1。 引用计数法的缺点-循环引用 引用计数法的优点是实现简单&…

IText5填充PDF表单使用自定义字体中文生效而英文和数字不生效?

为什么使用IText5填充PDF时&#xff0c;使用自定义字体&#xff08;特别是某些新兴的字体&#xff09;时中文生效&#xff0c;英文和数字不生效&#xff1f; 查了相关资料&#xff0c;发现无果&#xff0c;或者都不生效。 看了api接口文档&#xff0c;发现有解决方案&#xf…

pytest ui自动化

chromedriver.exe 要对应已安装的chrome版本号

C# 使用OpenCvSharp4将Bitmap合成为MP4视频的环境

环境安装步骤&#xff1a; 在VS中选中项目或者解决方案&#xff0c;鼠标右键&#xff0c;选择“管理Nuget包”&#xff0c;在浏览窗口中搜索OpenCVSharp4 1.搜索OpenCvSharp4,选择4.8.0版本&#xff0c;点击安装 2.搜索OpenCvSharp4.runtime.win,选择4.8.0版本&#xff0c;点…

单例模式——对象创建型模式

引入——任务管理器&#xff1a; 动机&#xff1a; 对于一个软件系统的某些类而言&#xff0c;我们无须创建多个实例。 举个大家都熟知的例子——Windows任务管理器&#xff1a;通常情况下&#xff0c;无论我们启动任务管理多少次&#xff0c;Windows系统始终只能弹出一个任…

Sora没体验资格?开源项目:Open-Sora,复现类Sora视频生成方案

项目简介 Open-Sora项目是一项高效制作高质量视频的工作&#xff0c;明确所有权使用其模型、工具和内容的计划。通过采用开源原则&#xff0c;Open-Sora 不仅实现了先进的视频生成技术的普及&#xff0c;还提供了一个专业且用户界面的方案&#xff0c;简化了视频制作的复杂性。…

基于Springcloud+Vue校园招聘系统 Eureka分布式微服务

以行动研究为主&#xff0c;辅以文献法、教育实验法和个案研究法等方法相结合的研究方法。在研究方法&#xff0c;遵循软件工程中软件生命周期的规则。概括来讲可以划分成三大步&#xff1a;系统规划、系统开发和系统运行维护。将其上述步骤细分下来&#xff0c;可以分为以下8小…

59、服务攻防——中间件安全CVE复现IISApacheTomcatNginx

文章目录 中间件——IIS漏洞中间件——Nginx漏洞中间件——Apache中间件——Tomcat 中间件&#xff1a;IIS、Apache、Nginx、Tomcat、Docker、Weblogic、JBoss、WebSphere、Jenkinsphp框架&#xff1a;Laravel、Thinkphppythonl框架&#xff1a;Flaskjs框架&#xff1a;jQueryj…

[ESP32]:基于HTTP实现百度AI识图

[ESP32]&#xff1a;基于HTTP实现百度AI识图 测试环境&#xff1a; esp32-s3esp idf 5.1 首先&#xff0c;先配置sdk&#xff0c;可以写入到sdkconfig.defaults CONFIG_IDF_TARGET"esp32s3" CONFIG_IDF_TARGET_ESP32S3yCONFIG_PARTITION_TABLE_CUSTOMy CONFIG_PA…
最新文章