使用pytorch利用神经网络原理进行图片的训练(持续学习中....)

1.做这件事的目的
语言只是工具,使用python训练图片数据,最终会得到.pth的训练文件,java有使用这个文件进行图片识别的工具,顺便整合,我觉得Neo4J正确率太低了,草莓都能识别成为苹果,而且速度慢,不能持续识别视频帧

2.什么是神经网络?(其实就是数学的排列组合最终得到统计结果的概率)

1.先把二维数组转为一维
2.通过公式得到节点个数和值
3…同2
4.通过节点得到概率(softmax归一化公式)
5.对比模型的和 差值=原始概率-目标结果概率
6.不断优化原来模型的概率
5.激活函数,激活某个节点的函数,可以引入非线性的(因为所有问题不可能是线性的比如 很少图片识别一定可以识别出绝对的正方形,他可能中间有一定弯曲或者线在中心短开了)

在这里插入图片描述
在这里插入图片描述

3.训练的代码
//环境python3.8 最好使用conda进行版本管理,不然每个版本都可能不兼容,到处碰壁

 #安装依赖
 pip install numpy torch torchvision matplotlib

#文件夹结构,图片一定要是28x28的
在这里插入图片描述

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
from torchvision.datasets.folder import ImageFolder

class Net(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(28 * 28, 64)
        self.fc2 = torch.nn.Linear(64, 64)
        self.fc3 = torch.nn.Linear(64, 64)
        self.fc4 = torch.nn.Linear(64, 10)

    def forward(self, x):
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = torch.nn.functional.relu(self.fc3(x))
        x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)
        return x

#导入数据
def get_data_loader(is_train):
     #张量,多维数组
    to_tensor = transforms.Compose([transforms.ToTensor()])
     # 下载数据集 下载目录
    data_set = MNIST("", is_train, transform=to_tensor, download=True)
     #一个批次15张,顺序打乱
    return DataLoader(data_set, batch_size=15, shuffle=True)

def get_image_loader(folder_path):
    to_tensor = transforms.Compose([transforms.ToTensor()])
    data_set = ImageFolder(folder_path, transform=to_tensor)
    return DataLoader(data_set, batch_size=1)

#评估准确率
def evaluate(test_data, net):
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        #按批次取数据
        for (x, y) in test_data:
            #计算神经网络预测值
            outputs = net.forward(x.view(-1, 28 * 28))

            for i, output in enumerate(outputs):
                #比较预测结果和测试集结果
                if torch.argmax(output) == y[i]:
                    #统计正确预测结果数
                    n_correct += 1
                #统计全部预测结果
                n_total += 1
        #返回准确率=正确/全部的
    return n_correct / n_total


def main():
    #加载训练集
    train_data = get_data_loader(is_train=True)
    #加载测试集
    test_data = get_data_loader(is_train=False)
    #初始化神经网络
    net = Net()
    #打印测试网络的准确率 0.1
    print("initial accuracy:", evaluate(test_data, net))
    #训练神经网络
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    #重复利用数据集 2次
    for epoch in range(100):
        for (x, y) in train_data:
            #初始化 固定写法
            net.zero_grad()
            #正向传播
            output = net.forward(x.view(-1, 28 * 28))
            #计算差值
            loss = torch.nn.functional.nll_loss(output, y)
            #反向误差传播
            loss.backward()
            #优化网络参数
            optimizer.step()
        print("epoch", epoch, "accuracy:", evaluate(test_data, net))
    # #使用3张图片进行预测
    # for (n, (x, _)) in enumerate(test_data):
    #     if n > 3:
    #         break
    #     predict = torch.argmax(net.forward(x[0].view(-1, 28 * 28)))
    #     plt.figure(n)
    #     plt.imshow(x[0].view(28, 28))
    #     plt.title("prediction: " + str(int(predict)))
    # plt.show()
    image_loader = get_image_loader("aa")

    for (n, (x, _)) in enumerate(image_loader):
        if n > 2:
            break
        predict = torch.argmax(net.forward(x.view(-1, 28 * 28)))
        plt.figure(n)
        plt.imshow(x[0].permute(1, 2, 0))
        plt.title("prediction: " + str(int(predict)))
    plt.show()


if __name__ == "__main__":
    main()


#运行结果 弹框出现图片和识别结果

4.测试电脑的cuda是否安装成功,不成功不能运行下面的代码

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print('CUDA version:', torch.version.cuda)
print('PyTorch version:', torch.__version__)

5.在gpu上运行,需要去官网下载cuda安装
https://developer.nvidia.com/cuda-toolkit-archive
#并且需要安装和torch对应的版本,我的电脑是1660ti的所以安装了10.2的cuda
#安装torchgpu版本

pip install torch==1.9.0+cu102 -f
https://download.pytorch.org/whl/cu102/torch_stable.html

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
from torchvision.datasets.folder import ImageFolder

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Net(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(28 * 28, 64)
        self.fc2 = torch.nn.Linear(64, 64)
        self.fc3 = torch.nn.Linear(64, 64)
        self.fc4 = torch.nn.Linear(64, 10)

    def forward(self, x):
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = torch.nn.functional.relu(self.fc3(x))
        x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)
        return x

def get_data_loader(is_train):
    to_tensor = transforms.Compose([transforms.ToTensor()])
    data_set = MNIST("", is_train, transform=to_tensor, download=True)
    return DataLoader(data_set, batch_size=15, shuffle=True)

def get_image_loader(folder_path):
    to_tensor = transforms.Compose([transforms.ToTensor()])
    data_set = ImageFolder(folder_path, transform=to_tensor)
    return DataLoader(data_set, batch_size=1)

def evaluate(test_data, net):
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        for (x, y) in test_data:
            x, y = x.to(device), y.to(device)
            outputs = net.forward(x.view(-1, 28 * 28))
            for i, output in enumerate(outputs):
                if torch.argmax(output.cpu()) == y[i].cpu():
                    n_correct += 1
                n_total += 1
    return n_correct / n_total

def main():
    train_data = get_data_loader(is_train=True)
    test_data = get_data_loader(is_train=False)
    net = Net().to(device)
    print("initial accuracy:", evaluate(test_data, net))
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    for epoch in range(100):
        for (x, y) in train_data:
            x, y = x.to(device), y.to(device)
            net.zero_grad()
            output = net.forward(x.view(-1, 28 * 28))
            loss = torch.nn.functional.nll_loss(output, y)
            loss.backward()
            optimizer.step()
        print("epoch", epoch, "accuracy:", evaluate(test_data, net))
    image_loader = get_image_loader("aa")

    for (n, (x, _)) in enumerate(image_loader):
        if n > 2:
            break
        x = x.to(device)
        predict = torch.argmax(net.forward(x.view(-1, 28 * 28)).cpu())
        plt.figure(n)
        plt.imshow(x[0].permute(1, 2, 0).cpu())
        plt.title("prediction: " + str(int(predict)))
    plt.show()

if __name__ == "__main__":
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/172898.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

算法分析与设计课后练习23

求下面的0-1背包问题 (1)N5,M12,(p1,p2,…,p5)(10,15,6,8,4),(w1,w2,…,w5)(4,6,3,4,2) (2)N5,M15,(p1,p2,…,p5)(w1,w2,…,w5)(4,4,5,8,9)

深入理解JSON及其在Java中的应用

✅作者简介:大家好,我是Leo,热爱Java后端开发者,一个想要与大家共同进步的男人😉😉 🍎个人主页:Leo的博客 💞当前专栏:每天一个知识点 ✨特色专栏&#xff1a…

日常办公:批处理编写Word邮件合并获取图片全路径

大家在使用Word邮件合并这个功能,比如制作席卡、贺卡、准考证、员工档案、成绩单、邀请函、名片等等,那就需要对图片路径进行转换处理,此脚本就是直接将图片的路径提取出来,并把内容放到txt格式的文本文档里,打开Excel…

netty整合websocket(完美教程)

websocket的介绍: WebSocket是一种在网络通信中的协议,它是独立于HTTP协议的。该协议基于TCP/IP协议,可以提供双向通讯并保有状态。这意味着客户端和服务器可以进行实时响应,并且这种响应是双向的。WebSocket协议端口通常是80&am…

Redis:抢单预热

前言 在当今的互联网时代,抢单活动已经成为了电商平台、外卖平台等各种电子商务平台中常见的营销手段。通过抢单活动,商家可以吸引大量用户参与,从而提高销量和知名度。然而,抢单活动所带来的高并发请求往往会给系统带来巨大的压…

opencv-形态学处理

通过阈值化分割可以得到二值图,但往往会出现图像中物体形态不完整,变的残缺,可以通过形态学处理,使其变得丰满,或者去除掉多余的像素。常用的形态学处理算法包括:腐蚀,膨胀,开运算&a…

Spring-IOC-@Import的用法

1、Car.java package com.atguigu.ioc; import lombok.Data; Data public class Car {private String cname; }2、 MySpringConfiguration2.java package com.atguigu.ioc; import org.springframework.context.annotation.Bean; import org.springframework.context.annotatio…

一、防火墙-基础知识

学习防火墙之前,对路由交换应要有一定的认识 1、什么是防火墙2、防火墙的发展史3、安全区域3.1.接口、网络和安全区域的关系3.2.报文在安全区域之间流动方向3.3.安全区域的配置安全区域小实验 3.4.状态检测和会话机制3.4.1.状态检测3.4.2.会话 3.5.状态检测和会话机…

c语言-数据结构-链式二叉树

目录 1、二叉树的概念及结构 2、二叉树的遍历概念 2.1 二叉树的前序遍历 2.2 二叉树的中序遍历 2.3 二叉树的后序遍历 2.4 二叉树的层序遍历 3、创建一颗二叉树 4、递归方法实现二叉树前、中、后遍历 4.1 实现前序遍历 4.2 实现中序遍历 4.3 实现后序遍历 5、…

《算法通关村——最长公共前缀问题解析》

《算法通关村——最长公共前缀问题解析》 14. 最长公共前缀 编写一个函数来查找字符串数组中的最长公共前缀。 如果不存在公共前缀,返回空字符串 ""。 示例 1: 输入:strs ["flower","flow","flight…

腾讯云代金券怎么领取(腾讯云代金券在哪领取)

腾讯云代金券是可抵扣费用的优惠券,领券之后新购、续费、升级腾讯云相关云产品可以直接抵扣订单金额,节省购买腾讯云的费用,本文将详细介绍腾讯云代金券的领取方法和使用教程。 一、腾讯云代金券领取 1、新用户代金券【点此领取】 2、老用户…

Unity中Shader的PBR的基础知识与理论

文章目录 前言一、什么是PBR二、什么是PBS在这里插入图片描述 三、PBS的核心理论1、物质的光学特性(Substance Optical Properties)2、微平面理论(Microfacet Theory)3、能量守恒(Energy Conservation)4、菲…

90%的测试工程师是这样使用Postman做接口测试的...

一:接口测试前准备 接口测试是基于协议的功能黑盒测试,在进行接口测试之前,我们要了解接口的信息,然后才知道怎么来测试一个接口,如何完整的校验接口的响应值。 那么问题来了,那接口信息从哪里获取呢&…

金山云2023年Q3财报:持续向好!

11月21日,金山云公布了2023年第三季度业绩。 财报显示,金山云Q3营收16.3亿元,调整后毛利率达12.1%再创历史新高,调整后毛利额同比上涨57.5%。今年第三季度,公有云实现收入10.2亿元,毛利率达到4.7%&#xf…

STM32出现 Invalid Rom Table 芯片锁死解决方案

出现该现象的原因为板子外部晶振为25M,而程序软件上以8M为输入晶振频率,导致芯片超频锁死,无法连接、下载。 解决方案 断电,将芯片原来通过10k电阻接地的BOOT0引脚直接接3.3V,硬件上置1上电,连接目标板&am…

Redis跳跃表

前言 跳跃表(skiplist)是一种有序数据结构,它通过在每一个节点中维持多个指向其他节点的指针,从而达到快速访问节点的目的。 跳跃表支持平均O(logN),最坏O(N),复杂度的节点查找,还可以通过顺序性来批量处理节点…

ROS2中Executors对比和优化

目录 SingleThreadExecutorEventExecutor SingleThreadExecutor 执行流程 EventExecutor 通信图

局域网文件共享神器:Landrop

文章目录 前言解决方案Landrop软件界面手机打开效果 软件操作 前言 平常为了方便传文件,我们都是使用微信或者QQ等聊天软件,互传文件。这样传输有两个问题: 必须登录微信或者QQ聊天软件。手机传电脑还有网页版微信,电脑传手机比…

MT8735/MTK8735安卓核心板规格参数介绍

MT8735核心板是一款高性能的64位Cortex-A53四核处理器,设计用于在4G智能设备上运行安卓操作系统。这款多功能核心板支持LTE-FDD/LTE-TDD/WCDMA/TD-SCDMA/EVDO/CDMA/GSM等多种网络标准,同时还具备WiFi 802.11a/b/g/n和BT4.0LE等无线通信功能。此外&#x…

pipeline传参给job

场景:pipeline实现自动部署,job实现自动测试,但是只有部署dddd环境时,才调自动测试的job,所以需要在调自动测试job时,把参数传给测试job 上一个任务会显示下一步调谁 ------------------------------------…