PyTorch实现逻辑回归

最终效果

先看下最终效果:
1
这里用一条直线把二维平面上不同的点分开。

生成随机数据

#创建训练数据
x = torch.rand(10,1)*10 #shape(10,1)
y = 2*x + (5 + torch.randn(10,1))


#构建线性回归参数
w = torch.randn((1))#随机初始化w,要用到自动梯度求导
b = torch.zeros((1))#使用0初始化b,要用到自动梯度求导

n_data = torch.ones(100, 2)
xy0 = torch.normal(2 * n_data, 1.5)  # 生成均值为2.标准差为1.5的随机数组成的矩阵
c0 = torch.zeros(100)
xy1 = torch.normal(-2 * n_data, 1.5)  # 生成均值为-2.标准差为1.5的随机数组成的矩阵
c1 = torch.ones(100)

x,y = torch.cat((xy0,xy1),0).type(torch.FloatTensor).split(1, dim=1)
x = x.squeeze()
y = y.squeeze()
c = torch.cat((c0,c1),0).type(torch.FloatTensor)

数据可视化

def plot(x, y, c):
    ax = plt.gca()
    sc = ax.scatter(x, y, color='black')
    paths = []
    for i in range(len(x)):
        if c[i].item() == 0:
            marker_obj = mmarkers.MarkerStyle('o')
        else:
            marker_obj = mmarkers.MarkerStyle('x')
        path = marker_obj.get_path().transformed(marker_obj.get_transform())
        paths.append(path)
    sc.set_paths(paths)
    return sc
plot(x, y, c)
plt.show()

使用x和o来表示两种不同类别的数据。
1

定义模型和损失函数

#构建逻辑回归参数
w = torch.tensor([1.,],requires_grad=True)  # 随机初始化w
b = torch.zeros((1),requires_grad=True)  # 使用0初始化b

wx = torch.mul(w,x) # w*x
y_pred = torch.add(wx,b) # y = w*x + b
loss = (0.5*(y-y_pred)**2).mean()

这里使用了平方损失函数来估算模型准确度。

训练模型

最多训练100次,每次都会更新模型参数,当损失值小于0.03时停止训练。

xx = torch.arange(-4, 5)
lr = 0.02 #学习率
for iteration in range(100):
    #前向传播
    loss = ((torch.sigmoid(x*w+b-y) - c)**2).mean()
    #反向传播
    loss.backward()
    #更新参数
    b.data.sub_(lr*b.grad) # b = b - lr*b.grad
    w.data.sub_(lr*w.grad) # w = w - lr*w.grad
    #绘图
    if iteration % 3 == 0:
        plot(x, y, c)
        yy = w*xx + b
        plt.plot(xx.data.numpy(),yy.data.numpy(),'r-',lw=5)
        plt.text(-4,2,'Loss=%.4f'%loss.data.numpy(),fontdict={'size':20,'color':'black'})
        plt.xlim(-4,4)
        plt.ylim(-4,4)
        plt.title("Iteration:{}\nw:{},b:{}".format(iteration,w.data.numpy(),b.data.numpy()))
        plt.show()

        if loss.data.numpy() < 0.03:  # 停止条件
            break

全部代码

import torch
import matplotlib.pyplot as plt
import matplotlib.markers as mmarkers

#创建训练数据
x = torch.rand(10,1)*10 #shape(10,1)
y = 2*x + (5 + torch.randn(10,1))


#构建线性回归参数
w = torch.randn((1))#随机初始化w,要用到自动梯度求导
b = torch.zeros((1))#使用0初始化b,要用到自动梯度求导

wx = torch.mul(w,x) # w*x
y_pred = torch.add(wx,b) # y = w*x + b


n_data = torch.ones(100, 2)
xy0 = torch.normal(2 * n_data, 1.5)  # 生成均值为2.标准差为1.5的随机数组成的矩阵
c0 = torch.zeros(100)
xy1 = torch.normal(-2 * n_data, 1.5)  # 生成均值为-2.标准差为1.5的随机数组成的矩阵
c1 = torch.ones(100)

x,y = torch.cat((xy0,xy1),0).type(torch.FloatTensor).split(1, dim=1)
x = x.squeeze()
y = y.squeeze()
c = torch.cat((c0,c1),0).type(torch.FloatTensor)


def plot(x, y, c):
    ax = plt.gca()
    sc = ax.scatter(x, y, color='black')
    paths = []
    for i in range(len(x)):
        if c[i].item() == 0:
            marker_obj = mmarkers.MarkerStyle('o')
        else:
            marker_obj = mmarkers.MarkerStyle('x')
        path = marker_obj.get_path().transformed(marker_obj.get_transform())
        paths.append(path)
    sc.set_paths(paths)
    return sc
plot(x, y, c)
plt.show()


#构建逻辑回归参数
w = torch.tensor([1.,],requires_grad=True)#随机初始化w
b = torch.zeros((1),requires_grad=True)#使用0初始化b

wx = torch.mul(w,x) # w*x
y_pred = torch.add(wx,b) # y = w*x + b
loss = (0.5*(y-y_pred)**2).mean()

xx = torch.arange(-4, 5)
lr = 0.02 #学习率
for iteration in range(100):
    #前向传播
    loss = ((torch.sigmoid(x*w+b-y) - c)**2).mean()
    #反向传播
    loss.backward()
    #更新参数
    b.data.sub_(lr*b.grad) # b = b - lr*b.grad
    w.data.sub_(lr*w.grad) # w = w - lr*w.grad
    #绘图
    if iteration % 3 == 0:
        plot(x, y, c)
        yy = w*xx + b
        plt.plot(xx.data.numpy(),yy.data.numpy(),'r-',lw=5)
        plt.text(-4,2,'Loss=%.4f'%loss.data.numpy(),fontdict={'size':20,'color':'black'})
        plt.xlim(-4,4)
        plt.ylim(-4,4)
        plt.title("Iteration:{}\nw:{},b:{}".format(iteration,w.data.numpy(),b.data.numpy()))
        plt.show()

        if loss.data.numpy() < 0.03:#停止条件
            break

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/232025.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【开源】基于Vue和SpringBoot的衣物搭配系统

项目编号&#xff1a; S 016 &#xff0c;文末获取源码。 \color{red}{项目编号&#xff1a;S016&#xff0c;文末获取源码。} 项目编号&#xff1a;S016&#xff0c;文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、研究内容2.1 衣物档案模块2.2 衣物搭配模块2.3 衣…

深度模型训练时CPU或GPU的使用model.to(device)

一、使用device控制使用CPU还是GPU device torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # 单GPU或者CPU.先判断机器上是否存在GPU&#xff0c;没有则使用CPU训练 model model.to(device) data data.to(device)#或者在确定有GPU的…

python+pytest接口自动化之参数关联

什么是参数关联&#xff1f; 参数关联&#xff0c;也叫接口关联&#xff0c;即接口之间存在参数的联系或依赖。在完成某一功能业务时&#xff0c;有时需要按顺序请求多个接口&#xff0c;此时在某些接口之间可能会存在关联关系。比如&#xff1a;B接口的某个或某些请求参数是通…

TA-Lib学习研究笔记(九)——Pattern Recognition (1)

TA-Lib学习研究笔记&#xff08;九&#xff09;——Pattern Recognition &#xff08;1&#xff09; 0.程序代码 形态识别的函数的应用&#xff0c;通过使用A股实际的数据&#xff0c;验证形态识别函数&#xff0c;用K线显示出现标志的形态走势&#xff0c;由于入口参数基本上…

大学生有担当,乡村振兴新亮点“艺术点亮乡村,创意引领未来”

12月7日上午&#xff0c;由花都区文化馆&#xff08;区非物质文化遗产保护中心&#xff09;指导&#xff0c;广州工商学院主办&#xff0c;广州工商学院国际教育学院承办&#xff0c;花都区文化馆炭步分馆、广州盛美文化传播有限公司协办的广州工商学院国际教育学院视觉传达设计…

C++新经典模板与泛型编程:策略类模板

策略类模板 在前面的博文中&#xff0c;策略类SumPolicy和MinPolicy都是普通的类&#xff0c;其中包含的是一个静态成员函数模板algorithm()&#xff0c;该函数模板包含两个类型模板参数。其实&#xff0c;也可以把SumPolicy和MinPolicy类写成类模板—直接把algorithm()中的两…

C/C++,树算法——二叉树的插入、移除、合并及遍历算法之源代码

1 文本格式 #include<iostream>; using namespace std; // A BTree node class BTreeNode { int* keys; // An array of keys int t; // Minimum degree (defines the range for number of keys) BTreeNode** C; // An array of child pointers …

SAP FICO S_ALR_87013611 报表列宽度的调整

如何去调整&#xff1f; 选中对应的列 菜单-设置-列属性 连起来

十一、了解分布式计算

1、什么是&#xff08;数据&#xff09;计算&#xff1f; 2、分布式(数据)计算 &#xff08;1&#xff09;概念 顾名思义&#xff0c;分布式计算&#xff0c;即以分布式的形式完成数据的统计&#xff0c;得到需要的结果。 分布式数据计算&#xff0c;顾名思义&#xff0c;就是…

idea开发环境配置

idea重新安装后&#xff0c;配置的东西还挺多的&#xff0c;这里简单记录一下。 1、基础配置 1.1、主题、背景、主题字体大小 1.2、默认字体设置 控制台默认编码设置&#xff1a; 全局文件默认编码设置&#xff1a; 2、构建、编译、部署配置 说明&#xff1a;本地装了JD…

【Java基础篇 | 面向对象】—— 聊聊什么是接口(下篇)

个人主页&#xff1a;兜里有颗棉花糖 欢迎 点赞&#x1f44d; 收藏✨ 留言✉ 加关注&#x1f493;本文由 兜里有颗棉花糖 原创 收录于专栏【JavaSE_primary】 本专栏旨在分享学习JavaSE的一点学习心得&#xff0c;欢迎大家在评论区交流讨论&#x1f48c; 上篇&#xff08;【Ja…

学习Linux(1)-开始前的准备

一、Linux介绍 如图,“Linux的发行版说简单点就是将Linux内核与应用软件做一个打包”&#xff0c;所以&#xff0c;我们要学习Linux&#xff0c;就要选择一个趁手的应用软件&#xff0c;通常使用较多的有centerOs、Ubuntu。本文将基于centerOs6进行学习。 二、安装环境 使用Li…

认识线程和创建线程

目录 1.认识多线程 1.1线程的概念 1.2进程和线程 1.2.1进程和线程用图描述关系 1.2.2进程和线程的区别 1.3Java 的线程和操作系统线程的关系 2.创建线程 2.1继承 Thread 类 2.2实现 Runnable 接口 2.3匿名内部类创建 Thread 子类对象 2.4匿名内部类创建 Runnable 子类对…

SAP UI5 walkthrough step7 JSON Model

这个章节&#xff0c;帮助我们理解MVC架构中的M 我们将会在APP中新增一个输入框&#xff0c;并将输入的值绑定到model&#xff0c;然后将其作为描述&#xff0c;直接显示在输入框的右边 首先修改App.controllers.js webapp/controller/App.controller.js sap.ui.define([&…

教师需要什么技能?

作为一名老师&#xff0c;需要掌握许多技能&#xff0c;以便能够成功地教育和指导学生。以下是一些关键技能&#xff1a; 1.教学技能&#xff1a;老师需要有深入的学科知识和教学经验&#xff0c;以便能够有效地传授知识。教师应该了解如何设计和执行教学计划&#xff0c;制定课…

Java、JDK、JRE、JVM

Java、JDK、JRE、JVM 一、 Java 广义上看&#xff0c;Kotlin、JRuby等运行于Java虚拟机上的编程语言以及相关的程序都属于Java体系的一员。从传统意义上看&#xff0c;Java社区规定的Java技术体系包括以下几个部分&#xff1a; Java程序设计语言各种硬件平台上的Java虚拟机实…

JFrog----基于Docker方式部署JFrog

文章目录 1 下载镜像2 创建数据挂载目录3 启动 JFrog服务4 浏览器登录5 重置密码6 设置 license7 设置 Base URL8 设置代理9 选择仓库类型10 预览11 查看结果 1 下载镜像 免费版 docker pull docker.bintray.io/jfrog/artifactory-oss体验版&#xff1a; docker pull releas…

论文导读|10月MSOM文章精选:智慧医疗

编者按 在“10月MSOM文章精选&#xff1a;智慧医疗”中&#xff0c;我们有主题、有针对性地选择了MSOM期刊杂志中一些有关智慧医疗领域的有趣文章&#xff0c;不但对文章的内容进行了概括与点评&#xff0c;而且也对文章的结构进行了梳理&#xff0c;旨在激发广大读者的阅读兴趣…

vue预览pdf,放大缩小拖动,dialog拖动,父页面滚动

公共组件部分代码 main.js import draggable from /directive/drag/index Vue.use(draggable) pdf组件部分代码

1-3、Java反编译

语雀原文链接 文章目录 1、JD-GUI反编译下载1-1、打开class文件无反应 1、JD-GUI反编译下载 http://java-decompiler.github.io jd-gui-windows-1.6.6.zip 1-1、打开class文件无反应 目前是可以正常打jar包文件&#xff0c;但是在直接打开.class文件时软件会卡住。首先将要…
最新文章