深度学习 -- 逻辑回归 PyTorch实现逻辑回归

前言

线性回归解决的是回归问题,而逻辑回归解决的是分类问题,这两种问题的区别是前者的目标属性是连续的数值类型,而后者的目标属性是离散的标称类型。

可以将逻辑回归视为神经网络的一个神经元,因此学习逻辑回归能帮助理解神经网络的工作原理。

什么是逻辑回归?

逻辑回归是一种广义的线性回归分析模型,是监督学习的一种重要方法,主要用于二分类问题,但也可以用于多分类问题。

逻辑回归的主要思想是,对于一个二分类问题,先根据样本数据计算出每个特征的概率,然后根据这些概率计算出每个样本属于每个类别的概率,最后根据这些概率来预测测试集数据属于每个类别的概率。

逻辑回归介绍

逻辑回归的推导过程与计算方式类似于回归的过程,但实际上主要是用来解决二分类问题。 在逻辑回归中,输入数据集D被分成两个部分:一类是训练集D_train,一类是测试集D_test。在每次训练时,我们使用一部分数据来训练模型,然后使用另一部分数据来评估模型的性能。 在测试时,我们使用所有的数据来评估模型的性能。

在实际使用中,逻辑回归可以使用各种不同的损失函数来最小化训练数据集和测试数据集之间的均方误差。常见的逻辑回归损失函数包括均方误差损失函数、交叉熵损失函数、对数损失函数等。

Sigmoid函数

Sigmoid函数是一个在生物学中常见的S型函数,也称为S型生长曲线。在信息科学中,由于其单增以及反函数单增等性质,Sigmoid函数常被用作神经网络的激活函数,将变量映射到0,1之间。

在神经网络中经常使用Sigmoid函数作为激活函数,因为它能够有效的输出0-1之间的概率。

代价函数

代价函数(Cost Function)是深度学习模型中用于评估模型性能的函数,它是优化算法的目标函数。代价函数通常定义为损失函数(Loss Function)的平方,这样可以简单地通过计算损失函数值来评估模型的性能。

在深度学习中,代价函数通常是指均方误差(MSE)损失函数,因为均方误差是深度学习中最常用的损失函数之一。均方误差损失函数定义为:

J(y_true, y_pred) = 1/N - ∑i=1N(yi - y_i)^2

其中,y_true是真实标签,y_pred是模型预测的标签,N是样本数量,yi是真实标签对应的样本值。

代价函数的作用是评估模型的性能,其中J(y_true, y_pred)表示真实标签和模型预测标签之间的均方误差。优化算法会在代价函数上进行最小化操作,以最小化损失函数值。

除了均方误差损失函数,还有其他类型的损失函数,如交叉熵损失函数、对数损失函数等,它们在不同的场景下可能更有效或更适合。

逻辑回归在PyTorch中的实现

1 从头开始实现一个逻辑回归

  • 首先定义一个逻辑回归模型
import torch


def sigmoid(z):
    '''s型激活函数'''
    g = 1 / (1+torch.exp(-z))
    return g

def model(x,w,b):
    '''逻辑回归模型'''
    return sigmoid(x.mv(w)+b)

# w是向量,b是标量,而x是矩阵,使用x.mv(w) 可以实现矩阵x与向量w的相乘

注意:这里w是向量,b是标量,而x是矩阵,使用x.mv(w) 可以实现矩阵x与向量w的相乘

  • 然后定义损失函数和损失函数求导
# 定义损失函数
def loss_fn(y_pred,y):
    '''损失函数'''
    loss = - y.mul(y_pred.view_as(y)) - (1-y).mul(1-y_pred.view_as(y))
    return loss.mean()

# 损失函数求导
def grad_loss_fn(y_pred,y):
    '''损失函数求导'''
    return y_pred.view_as(y)-y

  • 接着定义一个梯度函数
# 定义梯度函数
def grad_fn(x,y,y_pred):
    '''梯度函数'''
    grad_w = grad_loss_fn(y_pred,y)*x
    grad_b = grad_loss_fn(y_pred,y)
    return torch.cat((grad_w.mean(dim=0),grad_b.mean().unsqueeze(0)),0)

  • 模型训练函数
# 模型训练函数
def model_training(x,y,n_epochs,learning_rate,params,print_params=True):
    '''训练'''
    for epoch in range(1,n_epochs+1):
        w,b = params[:-1],params[-1]
        
        # 前向传播
        y_pred = model(x,w,b)
        # 计算损失
        loss = loss_fn(y_pred,y)
        # 梯度
        grad = grad_fn(x,y,y_pred)
        # 更新参数
        params -= learning_rate*grad
        
        if epoch == 1 or epoch%10 == 1:
            print('轮次:%d,\t损失:%f'%(epoch,float(loss)))
            if print_params:
                print(f'参数:{params.detach().numpy()}')
                print(f'梯度:{grad.detach().numpy()}')
                
    return params
  • 最后定义main函数
if __name__ == '__main__':
    # 随机生成数据
    x = torch.randn(2,2)
    y = torch.tensor([[1.,0.],
                      [0.,1.]])
    # 模型参数初始化
    w = torch.zeros(2)  # tensor([0., 0.])
    b = torch.zeros(1)  # tensor([0.])

    params = model_training(x=x,y=y,n_epochs=500,learning_rate=0.1,params=torch.tensor([0.0,0.0,0.0]))

    print(params.numpy())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/17153.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Python使用AI photo2cartoon制作属于你的漫画头像

Python使用AI photo2cartoon制作属于你的漫画头像 1. 效果图2. 原理3. 源码参考 git clone https://github.com/minivision-ai/photo2cartoon.git cd ./photo2cartoon python test.py --photo_path images/photo_test.jpg --save_path images/cartoon_result.png1. 效果图 官方…

(22)目标检测算法之 yolov8模型导出总结

yolov8模型导出总结 不断更新中… 几种部署情况: onnxxmlengine官网说明:https://github.com/ultralytics/ultralytics/blob/main/docs/modes/export.md导出参数: onnx 参数解析format: 导出的模型形式:onnx xml engine ... imgsz: 设置模型的输入尺寸大小,默认640*640 ke…

磁盘和固态磁盘

磁盘和固态磁盘 磁盘的物理结构 ​ 磁盘的表面由一些磁性的物质组成,可以用这些磁性物质来记录二进制数据。磁盘的盘面被划分成一个个磁道,这样一个“圈”就是一个磁道。同一磁盘上不同磁道上记录的信息量相同,因此内侧磁道上的数据密度较大…

STM32F429移植microPython笔记

目录 一、microPython下载。二、安装开发环境。三、编译开发板源码。四、下载验证。 一、microPython下载。 https://micropython.org/download/官网 下载后放在linux中。 解压命令: tar -xvf micropython-1.19.1.tar.xz 二、安装开发环境。 sudo apt-get inst…

【Java笔试强训 14】

🎉🎉🎉点进来你就是我的人了博主主页:🙈🙈🙈戳一戳,欢迎大佬指点! 欢迎志同道合的朋友一起加油喔🤺🤺🤺 目录 一、选择题 二、编程题 🔥计算日期…

玩着3dmax把Python学了-01

3ds Max 2022以前的版本要借助Python的api来实现Python编程达到编辑绘图脚本的功能,但是好消息来了,3ds Max 2022 起,MaxPlus 不再作为 3ds Max 的 Python API 包含在内。而是3ds Max 将 Python 3.7 的标准版本包涵其中了,位于 [3…

Filter 过滤器

Filter过滤器介绍 这里我们讲解Filter的执行流程,从下图可以大致了解到,当客户端发送请求的时候,会经过过滤器,然后才能到我们的servlet,当我们的servlet处理完请求之后,我们的response还是先经过过滤器才…

基于SpringBoot的线上日志阅读器

软件特点 部署后能通过浏览器查看线上日志。支持Linux、Windows服务器。采用随机读取的方式,支持大文件的读取。支持实时打印新增的日志(类终端)。支持日志搜索。 使用手册 基本页面 配置路径 配置日志所在的目录,配置后按回车…

2023亚马逊云科技研究,数字化技能为中国企业和员工带来经济效益

在中国,信息技术在个人、企业和宏观经济层面都推动着重大变革。为了研究这些变化所带来的影响,盖洛普咨询公司(Gallup)和亚马逊云科技开展了关于数字化技能的调研。 研究表明,数字化技能正在为中国企业和在职人员带来巨大的经济价值&#x…

一文带你入门C++类和对象【十万字详解,一篇足够了】

本文字数较多,建议电脑端访问。不多废话,正文开始 文章目录 ———————————————【类和对象 筑基篇】 ———————————————一、前言二、面向过程与面向对象三、结构体与类1、C中结构体的变化2、C中结构体的具体使用3、结构体 --&…

程序环境和预处理

目录 一 程序的翻译环境和执行环境 二 详解编译链接 2.1 翻译环境 2.2 编译本身也分为几个阶段 2.3 运行环境 三 预处理详解 3.1 预定义符号 3.2 #define 3.2.1 #define 定义标识符 3.2.2 #define定义宏 3.2.3 #define 替换规则 3.2.4 #和## 3.2.5 带副作用的宏参…

告别被拒,如何提升iOS审核通过率(上篇)

iOS审核一直是每款移动产品上架苹果商店时面对的一座大山,每次提审都像是一次漫长而又悲壮的旅行,经常被苹果拒之门外,无比煎熬。那么问题来了,我们有没有什么办法准确把握苹果审核准则,从而提升审核的通过率呢&#x…

Centos7快速安装Kibana并连接ES使用

Elasticsearch 提供了一个名为 Kibana 的官方可视化界面。Kibana 是一个开源的数据可视化和管理工具,用于 Elasticsearch。它提供了丰富的功能,如仪表板、图表、地图等,帮助您更好地理解、搜索和可视化存储在 Elasticsearch 中的数据。 在 C…

【软考备战·希赛网每日一练】2023年5月5日

文章目录 一、今日成绩二、错题总结第一题 三、知识查缺 题目及解析来源:2023年05月05日软件设计师每日一练 一、今日成绩 二、错题总结 第一题 解析: 有返回消息的就是同步消息;不需要等待返回消息就可以去做其他事情的请求消息就是异步消息…

从零基础到网络安全专家:全网最全的网络安全学习路线

前言 网络安全知识体系非常广泛,涉及的领域也非常复杂,有时候即使有想法和热情,也不知道从何入手。 为了帮助那些想要进入网络安全行业的小伙伴们更快、更系统地学习网络安全知识,我制定了这份学习路线。本路线覆盖了网络安全的…

网络协议与攻击模拟-03-ARP协议

ARP 协议(地址解析协议) 一、 ARP 协议 将一个已知的 IP 地址解析为 MAC 地址,从而进行二层数据交互 是一个三层的协议,但是工作在二层,是一个2.5层协议 二、工作流程 1、两个阶段 ARP 请求 ARP 相应 2、 ARP 协议…

Java 基础入门篇(三)—— 数组的定义与内存分配

文章目录 一、数组的定义1.1 静态初始化数组1.2 动态初始化数组1.3 数组的访问 二、数组的遍历三、数组的内存图 ★3.1 Java 的内存分配3.2 数组的内存图3.3 两个数组变量指向同一个数组对象 四、数组使用的常见问题补充:随机数 Random 类 一、数组的定义 数组就是…

黑盒测试过程中【测试方法】详解2-正交实验

在黑盒测试过程中,有9种常用的方法:1.等价类划分 2.边界值分析 3.判定表法 4.正交实验法 5.流程图分析 6.因果图法 7.输入域覆盖法 8.输出域覆盖法 9.猜错法 前面我们已经讲解过了等价类划分、边界值、判定表。 可以参考我之前的文章&#xff…

MySQL 常用命令

#--------------------------- #----cmd命令行连接MySql--------- cd C:\Program Files\MySQL\MySQL Server 5.5\bin # 启动mysql服务器 net start mysql # 关闭mysql服务器 net stop mysql # 进入mysql命令行 mysql -h localhost -u root -p 或mysql -u root -p #---------…

SPSS如何进行回归分析之案例实训?

文章目录 0.引言1.线性回归分析2.曲线回归分析3.非线性回归分析4.Logistic回归分析5.有序回归分析6.概率回归分析7.加权回归分析 0.引言 因科研等多场景需要进行数据统计分析,笔者对SPSS进行了学习,本文通过《SPSS统计分析从入门到精通》及其配套素材结合…
最新文章