[oneAPI] 手写数字识别-LSTM

[oneAPI] 手写数字识别-LSTM

  • 手写数字识别
    • 参数与包
    • 加载数据
    • 模型
    • 训练过程
    • 结果
  • oneAPI

比赛:https://marketing.csdn.net/p/f3e44fbfe46c465f4d9d6c23e38e0517
Intel® DevCloud for oneAPI:https://devcloud.intel.com/oneapi/get_started/aiAnalyticsToolkitSamples/

手写数字识别

使用了pytorch以及Intel® Optimization for PyTorch,通过优化扩展了 PyTorch,使英特尔硬件的性能进一步提升,让手写数字识别问题更加的快速高效
在这里插入图片描述

使用MNIST数据集,该数据集包含了一系列以黑白图像表示的手写数字,每个图像的大小为28x28像素,数据集组成如下:

  • 训练集:包含60,000个图像和标签,用于训练模型。
  • 测试集:包含10,000个图像和标签,用于测试模型的性能。

每个图像都被标记为0到9之间的一个数字,表示图像中显示的手写数字。这个数据集常常被用来验证图像分类模型的性能,特别是在计算机视觉领域。

参数与包

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

import intel_extension_for_pytorch as ipex

# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters
sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 2
learning_rate = 0.01

加载数据

# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='../../data/',
                                           train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='../../data/',
                                          train=False,
                                          transform=transforms.ToTensor())

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

模型

# Recurrent neural network (many-to-one)
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Set initial hidden and cell states 
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)

        # Forward propagate LSTM
        out, _ = self.lstm(x, (h0, c0))  # out: tensor of shape (batch_size, seq_length, hidden_size)

        # Decode the hidden state of the last time step
        out = self.fc(out[:, -1, :])
        return out

训练过程

model = RNN(input_size, hidden_size, num_layers, num_classes).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
model, optimizer = ipex.optimize(model, optimizer=optimizer)

# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))

# Test the model
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

结果

在这里插入图片描述

oneAPI

import intel_extension_for_pytorch as ipex

# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')

# 模型
model = ConvNet(num_classes).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
model, optimizer = ipex.optimize(model, optimizer=optimizer)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/82918.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

八种架构演进

日升时奋斗,日落时自省 目录 1、单机架构 2、应用数据分离架构 3、应用服务集群架构 4、读写分离/主从分离架构 5、冷热分离架构 6、垂直分库架构 7、微服务架构 8、容器编排架构 9、小结 1、单机架构 特征:应用服务和数据库服务器公用一台服务…

LVS负载均衡群-DR模式

目录 1、lvs-dr数据包流向分析 2、lvs-dr中ARP问题 3、lvs-dr特性 4、lvs-dr集群-构建 1、lvs-dr数据包流向分析 (1)客户端发送请求到 Director Server(负载均衡器),请求的数据报文(源 IP 是 CIP,目标 IP …

【校招VIP】CSS校招考点之选择器优先级

考点介绍: 选择器是CSS的基础,也是校招中的高频考点,特别是复合选择器的执行优先级,同时也是实战中样式不生效的跟踪依据。 因为选择器的种类较多,很难直接记忆,可以考虑选择一个相对值,比如id类…

回归预测 | MATLAB实现BO-SVM贝叶斯优化支持向量机多输入单输出回归预测(多指标,多图)

回归预测 | MATLAB实现BO-SVM贝叶斯优化支持向量机多输入单输出回归预测(多指标,多图) 目录 回归预测 | MATLAB实现BO-SVM贝叶斯优化支持向量机多输入单输出回归预测(多指标,多图)效果一览基本介绍程序设计…

Java自学到什么程度就可以去找工作了?

引言 Java作为一门广泛应用于软件开发领域的编程语言,对于初学者来说,了解到什么程度才能开始寻找实习和入职机会是一个常见的问题。 本文将从实习和入职这两个方面,分点详细介绍Java学习到什么程度才能够开始进入职场。并在文章末尾给大家安…

k8s集群部署vmalert和prometheusalert实现钉钉告警

先决条件 安装以下软件包:git, kubectl, helm, helm-docs,请参阅本教程。 1、安装 helm wget https://xxx-xx.oss-cn-xxx.aliyuncs.com/helm-v3.8.1-linux-amd64.tar.gz tar xvzf helm-v3.8.1-linux-amd64.tar.gz mv linux-amd64/helm /usr/local/bin…

网络编程(TCP和UDP的基础模型)

一、TCP基础模型&#xff1a; tcp Ser&#xff1a; #include <stdio.h> #include <sys/types.h> #include <sys/socket.h> #include <arpa/inet.h> #include <netinet/in.h> #include <string.h> #include <head.h>#define PORT 88…

无涯教程-Perl - wait函数

描述 该函数等待子进程终止,返回已故进程的进程ID。进程的退出状态包含在$?中。 语法 以下是此函数的简单语法- wait返回值 如果没有子进程,则此函数返回-1,否则将显示已故进程的进程ID Perl 中的 wait函数 - 无涯教程网无涯教程网提供描述该函数等待子进程终止,返回已故…

Java 项目日志实例:LogBack

点击下方关注我&#xff0c;然后右上角点击...“设为星标”&#xff0c;就能第一时间收到更新推送啦~~~ LogBack 和 Log4j 都是开源日记工具库&#xff0c;LogBack 是 Log4j 的改良版本&#xff0c;比 Log4j 拥有更多的特性&#xff0c;同时也带来很大性能提升。LogBack 官方建…

Pytorch建立MyDataLoader过程详解

简介 torch.utils.data.DataLoader(dataset, batch_size1, shuffleNone, samplerNone, batch_samplerNone, num_workers0, collate_fnNone, pin_memoryFalse, drop_lastFalse, timeout0, worker_init_fnNone, multiprocessing_contextNone, generatorNone, *, prefetch_factorN…

TensorFlow2.1 模型训练使用

文章目录 1、环境安装搭建2、神经网络2.1、解决线性问题2.2、FAshion MNIST数据集使用 3、卷积神经网络3.1、卷积神经网络使用3.2、ImageDataGenerator使用3.3、猫狗识别案例3.4、参数优化 3.5、石头剪刀布案例4、词条化4.1、讽刺数据集的词条化和序列化4.2、词嵌入 1、环境安装…

[保研/考研机试] KY11 二叉树遍历 清华大学复试上机题 C++实现

题目链接&#xff1a; 二叉树遍历_牛客题霸_牛客网编一个程序&#xff0c;读入用户输入的一串先序遍历字符串&#xff0c;根据此字符串建立一个二叉树&#xff08;以指针方式存储&#xff09;。题目来自【牛客题霸】https://www.nowcoder.com/share/jump/43719512169254700747…

实数信号的傅里叶级数研究(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

(2)、将SpringCache扩展功能封装为starter

(2)、将SpringCache扩展功能封装为starter 1、准备工作 前面我们写了一个common-cache模块,尽可能的将自定义的RedisConnectionFactory, RedisTemplate, RedisCacheManager等Bean封装了起来。 就是为了方便我们将其封装为一个Starter。 我们这里直接《SpringCache+Redis实…

微信个人号开发,实现机器人辅助社群操作

微信 iPad 协议是指用于在 iPad 设备上使用微信应用的技术协议。一般来说&#xff0c;通过该协议可以将微信账号同步到 iPad 设备上&#xff0c;并且可以在 iPad 上发送和接收微信消息&#xff0c;查看好友列表、聊天记录等功能。微信 iPad 协议是通过私有API实现的。 需要一定…

嵌入式:ARM Day6

作业:完成cortex-A7核UART总线实验 目的&#xff1a;1.输入a,显示b&#xff0c;将输入的字符的ASCII码下一位字符输出 2.原样输出输入的字符串 源码&#xff1a; uart4.h #ifndef __UART4_H__ #define __UART4_H__#include "stm32mp1xx_rcc.h" #incl…

九、Linux下,如何在命令行进入文本编辑页面?

1、文本编辑基础 说到文本编辑页面&#xff0c;那就必须提到vi和vim&#xff0c;两者都是Linux系统中&#xff0c;常用的文本编辑器 2、三种工作模式 3、使用方法 &#xff08;1&#xff09;在进入Linux系统&#xff0c;在输入vim text.txt之后&#xff0c;会进入文本编辑中&…

Python多组数据三维绘图系统

文章目录 增添和删除坐标数据更改绘图逻辑源代码 Python绘图系统&#xff1a; 基础&#xff1a;将matplotlib嵌入到tkinter &#x1f4c8;简单的绘图系统 &#x1f4c8;数据导入&#x1f4c8;三维绘图系统自定义控件&#xff1a;坐标设置控件&#x1f4c9;坐标列表控件 增添和…

【力扣】84. 柱状图中最大的矩形 <模拟、双指针、单调栈>

目录 【力扣】84. 柱状图中最大的矩形题解暴力求解双指针单调栈 【力扣】84. 柱状图中最大的矩形 给定 n 个非负整数&#xff0c;用来表示柱状图中各个柱子的高度。每个柱子彼此相邻&#xff0c;且宽度为 1 。求在该柱状图中&#xff0c;能够勾勒出来的矩形的最大面积。 示例…

使用ChatGPT进行创意写作的缺点

Open AI警告ChatGPT的使用者要明白此工具的局限性&#xff0c;更不应完全依赖。作为一位创作者&#xff0c;这一点非常重要&#xff0c;应尽可能地避免让版权问题或不必要的文体问题出现在自己的作品中。[1] 毕竟使用ChatGPT进行创意写作目前还有以下种种局限或缺点[2]&#xf…