【Pytroch】基于支持向量机算法的数据分类预测(Excel可直接替换数据)

【Pytroch】基于支持向量机算法的数据分类预测(Excel可直接替换数据)

  • 1.模型原理
  • 2.数学公式
  • 3.文件结构
  • 4.Excel数据
  • 5.下载地址
  • 6.完整代码
  • 7.运行结果

1.模型原理

支持向量机(Support Vector Machine,SVM)是一种强大的监督学习算法,用于二分类和多分类问题。它的主要思想是找到一个最优的超平面,可以在特征空间中将不同类别的数据点分隔开。

下面是使用PyTorch实现支持向量机算法的基本步骤和原理:

  1. 数据准备: 首先,你需要准备你的训练数据。每个数据点应该具有特征(Feature)和对应的标签(Label)。特征是用于描述数据点的属性,标签是数据点所属的类别。

  2. 数据预处理: 根据SVM的原理,数据点需要线性可分。因此,你可能需要进行一些数据预处理,如特征缩放或标准化,以确保数据线性可分。

  3. 定义模型: 在PyTorch中,你可以定义一个支持向量机模型作为一个线性模型,例如使用nn.Linear

  4. 定义损失函数: SVM的目标是最大化支持向量到超平面的距离,即最大化间隔(Margin)。这可以通过最小化损失函数来实现,通常使用hinge loss(合页损失)。PyTorch提供了nn.MultiMarginLoss损失函数,它可以用于SVM训练。

  5. 定义优化器: 选择一个优化器,如torch.optim.SGD,来更新模型的参数以最小化损失函数。

  6. 训练模型: 使用训练数据对模型进行训练。在每个训练步骤中,计算损失并通过优化器更新模型参数。

  7. 预测: 训练完成后,你可以使用训练好的模型对新的数据点进行分类预测。对于二分类问题,可以使用模型的输出值来判断数据点所属的类别。

2.数学公式

当使用支持向量机(SVM)进行数据分类预测时,目标是找到一个超平面(或者在高维空间中是一个超曲面),可以将不同类别的数据点有效地分隔开。以下是SVM的数学原理:

  1. 超平面方程: 在二维情况下,超平面可以表示为

    w 1 x 1 + w 2 x 2 + b = 0 w_1 x_1 + w_2 x_2 + b = 0 w1x1+w2x2+b=0

  2. 决策函数: 数据点 (x) 被分为两个类别的决策函数为

    f ( x ) = w T x + b f(x) = w^T x + b f(x)=wTx+b

  3. 间隔(Margin): 对于一个给定的超平面,数据点到超平面的距离被称为间隔。支持向量机的目标是找到能最大化间隔的超平面。间隔可以用下面的公式计算:

    间隔 = 2 ∥ w ∥ \text{间隔} = \frac{2}{\|w\|} 间隔=w2

  4. 支持向量: 支持向量是离超平面最近的那些数据点。这些点对于确定超平面的位置和间隔非常重要。支持向量到超平面的距离等于间隔。

  5. 最大化间隔: SVM 的目标是找到一个超平面,使得所有支持向量到该超平面的距离(即间隔)都最大化。这等价于最小化法向量的范数 (|w|),即:

    最小化 1 2 ∥ w ∥ 2 \text{最小化} \quad \frac{1}{2}\|w\|^2 最小化21w2

  6. 对偶问题和核函数: 对偶问题的解决方法涉及到拉格朗日乘子,可以得到一个关于训练数据点的内积的表达式。这样,如果直接在高维空间中计算内积是非常昂贵的,可以使用核函数来避免高维空间的计算。核函数将数据映射到更高维的空间,并在计算内积时使用高维空间的投影,从而实现了在高维空间中的计算,但在计算上却更加高效。

总之,SVM利用线性超平面来分隔不同类别的数据点,通过最大化支持向量到超平面的距离来实现分类。对偶问题和核函数使SVM能够处理非线性问题,并在高维空间中进行计算。以上是SVM的基本数学原理。

3.文件结构

在这里插入图片描述

iris.xlsx						% 可替换数据集
Main.py							% 主函数

4.Excel数据

在这里插入图片描述

5.下载地址

- 资源下载地址

6.完整代码

import torch
import torch.nn as nn
import pandas as pd
import numpy as np  # Don't forget to import numpy for the functions using it
import matplotlib.pyplot as plt  # Import matplotlib for plotting
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import confusion_matrix

class SVM(nn.Module):
    def __init__(self, input_size, num_classes):
        super(SVM, self).__init__()
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, x):
        return self.linear(x)


def train(model, X, y, num_epochs, learning_rate):
    criterion = nn.CrossEntropyLoss()  # Use CrossEntropyLoss for multi-class classification
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        inputs = torch.tensor(X, dtype=torch.float32)
        labels = torch.tensor(y, dtype=torch.long)  # Use long for class indices

        optimizer.zero_grad()
        outputs = model(inputs)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        if (epoch + 1) % 100 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')


def test(model, X, y):
    inputs = torch.tensor(X, dtype=torch.float32)
    labels = torch.tensor(y, dtype=torch.long)  # Use long for class indices

    with torch.no_grad():
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)
        accuracy = (predicted == labels).float().mean()
        print("真实值:", labels)
        print("预测值:", predicted)
        print(f'Accuracy on test set: {accuracy:.2f}')

# Define the plot functions
def plot_confusion_matrix(conf_matrix, classes):
    plt.figure(figsize=(8, 6))
    plt.imshow(conf_matrix, cmap=plt.cm.Blues, interpolation='nearest')
    plt.title("Confusion Matrix")
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes)
    plt.yticks(tick_marks, classes)
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.tight_layout()
    plt.show()

def plot_predictions_vs_true(y_true, y_pred):
    plt.figure(figsize=(10, 6))
    plt.plot(y_true, 'go', label='True Labels')
    plt.plot(y_pred, 'rx', label='Predicted Labels')
    plt.title("True Labels vs Predicted Labels")
    plt.xlabel("Sample Index")
    plt.ylabel("Class Label")
    plt.legend()
    plt.show()


def main():
    data = pd.read_excel('iris.xlsx')
    X = data.iloc[:, :-1].values
    y = data.iloc[:, -1].values

    label_encoder = LabelEncoder()
    y = label_encoder.fit_transform(y)

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    num_classes = len(label_encoder.classes_)
    model = SVM(X_train.shape[1], num_classes)
    num_epochs = 1000
    learning_rate = 0.001

    train(model, X_train, y_train, num_epochs, learning_rate)

    # Call the test function to get predictions
    inputs = torch.tensor(X_test, dtype=torch.float32)
    labels = torch.tensor(y_test, dtype=torch.long)
    with torch.no_grad():
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)

    # Convert torch tensors back to numpy arrays
    y_true = labels.numpy()
    y_pred = predicted.numpy()

    test(model, X_test, y_test)

    # Call the plot functions
    conf_matrix = confusion_matrix(y_true, y_pred)
    plot_confusion_matrix(conf_matrix, label_encoder.classes_)
    plot_predictions_vs_true(y_true, y_pred)


if __name__ == '__main__':
    main()




7.运行结果

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/72867.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

换架 3D 飞机,继续飞呀飞

相信大多数图扑 HT 用户都曾见过这个飞机的 Demo,在图扑发展的这十年,这个 Demo 是许多学习 HT 用户一定会参考的经典 Demo 之一。 这个 Demo 用简洁的代码生动地展示了 OBJ 模型加载、数据绑定、动画和漫游等功能的实现。许多用户参考这个简单的 Demo 后…

Redis进阶(4)——结合redis.conf配置文件深入理解 Redis两种数据持久化方案:RDB和AOF

目录 引出持久化方案RDBAOF Redis的持久化方案RDB如果采用docker stop关闭如果采用强制关闭 AOF参数设置混编方式的加载让aof进行重写 两种持久化方案的优缺点AOF优缺点RDB优势和劣势 总结 引出 1.Redis数据持久化的两种方式,RDB和AOF; 2.RDB采用二进制存储&#xf…

QT报表Limereport v1.5.35编译及使用

1、编译说明 下载后QT CREATER中打开limereport.pro然后直接编译就可以了。编译后结果如下图: 一次编译可以得到库文件和DEMO执行程序。 2、使用说明 拷贝如下图编译后的lib目录到自己的工程目录中。 release版本的重新命名为librelease. PRO文件中配置 QT …

抖音小程序实现less语言编译样式

1.在抖音开发工具中搜索扩展less 2. 然后点击小齿轮选择扩展设置 3. 然后在扩展设置中选择在settings.json中编辑# 4. 在settings.json中加入以下这段代码即可 // Easy LESS配置"less.compile": {"compress": false,//是否压缩"sourceMap": fal…

virtualBox桥接模式下openEuler镜像修改IP地址、openEule修改IP地址、openEule设置IP地址

安装好openEuler后,设置远程登入前,必不可少的一步,主机与虚拟机之间的通信要解决,下面给出详细步骤: 第一步:检查虚拟机适配器模式:桥接模式 第二步:登入虚拟机修改IP cd /etc/sysconfig/network-scripts vim ifcfg-enpgs3 没有vim的安装或者用vi代替:sudo dnf …

优化开发体验:掌握VSCode配置Vue模板的技巧与方法

前言 当你使用了 VSCode 配置 vue 模版时,你会发现它就像是一位贴心的助手,本文将为大家介绍如何使用 VSCode 配置 vue 模板,让你的代码更加高效、美观。下面让我们一起来看看吧。 一、打开 VSCode 控制台 文件 --> 首选项 --> 用户片段…

HOT92-最小路径和

leetcode原题链接:最小路径和 题目描述 给定一个包含非负整数的 m x n 网格 grid ,请找出一条从左上角到右下角的路径,使得路径上的数字总和为最小。 说明:每次只能向下或者向右移动一步。 示例 1: 输入:…

运维工程师常见面试题

1、http常见返回码 2、mysql的同步方式 1)异步复制 MySQL默认的复制即是异步的,主库在执行完客户端提交的事务后会立即将结果返给给客户端,并不关心从库是否已经接收并处理,这样就会有一个问题,主如果crash掉了&a…

c++11 标准模板(STL)(std::basic_stringbuf)(二)

定义于头文件 <sstream> template< class CharT, class Traits std::char_traits<CharT>, class Allocator std::allocator<CharT> > class basic_stringbuf : public std::basic_streambuf<CharT, Traits> std::basic_stringbuf…

拒绝摆烂!C语言练习打卡第一天

&#x1f525;博客主页&#xff1a;小王又困了 &#x1f4da;系列专栏&#xff1a;每日一练 &#x1f31f;人之为学&#xff0c;不日近则日退 ❤️感谢大家点赞&#x1f44d;收藏⭐评论✍️ &#x1f5d2;️前言&#xff1a; 在前面我们学习完C语言的所以知识&#xff0c;当…

PostgreSql 备份恢复

一、概述 数据库备份一般可分为物理备份和逻辑备份&#xff0c;其中物理备份又可分为物理冷备和物理热备&#xff0c;下面就各种备份方式进行详细说明&#xff08;一般情况下&#xff0c;生产环境采取的定时物理热备逻辑备份的方式&#xff0c;均是以下述方式为基础进一步研发编…

pytorch单机多卡后台运行

nohup sh ./train_chat.sh > train_chat20230814.log 2>1&参考资料 Pytorch单机多卡后台运行的解决办法

学C的第三十三天【C语言文件操作】

相关代码gitee自取&#xff1a; C语言学习日记: 加油努力 (gitee.com) 接上期&#xff1a; 学C的第三十二天【动态内存管理】_高高的胖子的博客-CSDN博客 1 . 为什么要使用文件 以前面写的通讯录为例&#xff0c;当通讯录运行起来的时候&#xff0c;可以给通讯录中增加、删…

DolphinDB 入选 Gartner《中国数据库市场指南》代表厂商

近日&#xff0c;国际知名研究机构 Gartner 发布2023年《中国 DBMS 市场指南&#xff08;Market Guide for DBMS, China&#xff09;》研究报告&#xff0c;在中国范围内评估并重点推荐了36家极具实力的企业&#xff0c;DolphinDB 以领先的技术和商业能力顺势入榜。 DolphinDB …

Python批量给excel文件加密

有时候我们需要定期给公司外部发邮件&#xff0c;在自动化发邮件的时候需要对文件进行加密传输。本文和你一起来探索用python给单个文件和批量文件加密。    python自动化发邮件可参考【干货】用Python每天定时发送监控邮件。 文章目录 一、安装pypiwin32包二、定义给excel加…

Vue.js 生命周期详解

Vue.js 是一款流行的 JavaScript 框架&#xff0c;它采用了组件化的开发方式&#xff0c;使得前端开发更加简单和高效。在 Vue.js 的开发过程中&#xff0c;了解和理解 Vue 的生命周期非常重要。本文将详细介绍 Vue 生命周期的四个阶段&#xff1a;创建、挂载、更新和销毁。 …

“中国软件杯”飞桨赛道晋级决赛现场名单公布

“中国软件杯”大学生软件设计大赛是由国家工业和信息化部、教育部、江苏省人民政府共同主办&#xff0c;是全国软件行业规格最高、最具影响力的国家级一类赛事&#xff0c;为《全国普通高校竞赛排行榜》榜单内赛事。今年&#xff0c;组委会联合百度飞桨共同设立了“智能系统设…

Profibus-DP转modbus RTU网关modbus rtu和tcp的区别

捷米JM-DPM-RTU网关在Profibus总线侧实现主站功能&#xff0c;在Modbus串口侧实现从站功能。可将ProfibusDP协议的设备&#xff08;如&#xff1a;EH流量计、倍福编码器等&#xff09;接入到Modbus网络中&#xff1b;通过增加DP/PA耦合器&#xff0c;也可将Profibus PA从站接入…

探究使用HTTP代理ip后无法访问网站的原因与解决方案

目录 访问网站的原理是什么 1. DNS解析 2. 建立TCP连接 3. 发送HTTP请求&#xff1a; 4. 服务器响应&#xff1a; 5. 浏览器渲染&#xff1a; 6. 页面展示&#xff1a; 使用代理IP后访问不了网站&#xff0c;有哪些方面的原因 1. 代理IP的可用性&#xff1a; 2. 代理…

VVIC-商品详情

一、接口参数说明&#xff1a; item_get-根据ID取商品详情&#xff0c;点击更多API调试&#xff0c;请移步注册API账号点击获取测试key和secret 公共参数 请求地址: https://api-gw.onebound.cn/vvic/item_get 名称类型必须描述keyString是调用key&#xff08;点击获取测试k…