神经网络案例实战

🔎我们通过一个案例详细使用PyTorch实战 ,案例背景:你创办了一家手机公司,不知道如何估算手机产品的价格。为了解决这个问题,收集了多家公司的手机销售数据:这些数据维度可以包括RAM、存储容量、屏幕尺寸、摄像头像素等。

在这个问题中,我们不需要预测实际价格,而是一个价格范围,它的范围使用 0、1、2、3 来表示,所以该问题也是一个分类问题。

🔎思路:

  1. 数据预处理:对收集到的数据进行清洗和预处理,确保数据的质量和一致性。这包括处理缺失值、异常值和重复数据等。

  2. 特征工程:从原始数据中提取有用的特征,以便用于建模。这可以包括对连续型特征进行归一化或标准化,对分类特征进行编码等。

  3. 模型选择:选择一个适合的机器学习算法来建立模型,这里我们使用神经网络模型。

  4. 模型训练:将收集到的数据划分为训练集和测试集。使用训练集来训练模型,通过调整模型的参数来最小化预测误差。

  5. 模型评估:使用测试集来评估模型的性能。常用的评估指标包括均方误差(MSE)、均方根误差(RMSE)和平均绝对误差(MAE)等。

  6. 价格预测:使用训练好的模型来预测新手机的价格。输入新手机的功能特征,模型将输出预测的价格。

构建数据集 

💬数据共有 2000 条, 其中 1600 条数据作为训练集, 400 条数据用作测试集。 我们使用 sklearn 的数据集划分工作来完成。并使用 PyTorch 的 TensorDataset 来将数据集构建为 Dataset 对象,方便后期构造数据集加载对象。

def create_dataset():

    data = pd.read_csv('predict.csv')

    # 特征值和目标值
    x, y = data.iloc[:, :-1], data.iloc[:, -1]
    x = x.astype(np.float32)
    y = y.astype(np.int64)

    # 数据集划分
    x_train, x_valid, y_train, y_valid = train_test_split(x, y, train_size=0.8, random_state=88, stratify=y)

    # 构建数据集
    train_dataset = TensorDataset(torch.tensor(x_train.values), torch.tensor(y_train.values))
    valid_dataset = TensorDataset(torch.tensor(x_valid.values), torch.tensor(y_valid.values))

    return train_dataset, valid_dataset, x_train.shape[1], len(np.unique(y))


train_dataset, valid_dataset, input_dim, class_num = create_dataset()

💬其中 train_test_split 方法中的stratify=y参数的作用是在划分训练集和验证集时,保持类别的比例相同。这样可以确保在训练集和验证集中各类别的比例与原始数据集中的比例相同,有助于提高模型的泛化能力,防止出现一份中某个类别只有几个。

构建分类网络模型

# 构建网络模型
class PhonePriceModel(nn.Module):

    def __init__(self, input_dim, output_dim):
        super(PhonePriceModel, self).__init__()

        self.linear1 = nn.Linear(input_dim, 128) # 输入层到隐藏层的线性变换
        self.linear2 = nn.Linear(128, 256)  # 隐藏层到输出层的线性变换
        self.linear3 = nn.Linear(256, output_dim)  # 输出层到最终输出的线性变换

    def _activation(self, x):
        return torch.sigmoid(x)

    def forward(self, x):

        x = self._activation(self.linear1(x))
        # 通过第一层线性变换和激活函数
        x = self._activation(self.linear2(x))
        # 通过第二层线性变换和激活函数
        output = self.linear3(x)
        # 通过第三层线性变换得到最终输出

        return output
  • self.linear1self.linear2之间的线性变换将输入维度从input_dim映射到128个神经元,然后再将128个神经元映射到256个神经元。
  • 我们通过self._activation方法对输入数据进行激活函数处理。具体来说,它使用Sigmoid激活函数对输入数据进行非线性变换,将线性变换后的输出映射到0和1之间。
  • 第三层: 输入为维度为 256, 输出维度为: 4

编写训练函数 

def train():

    # 固定随机数种子
    torch.manual_seed(66)

    
    model = PhonePriceModel(input_dim, class_num)
    # 损失函数
    criterion = nn.CrossEntropyLoss()
    # 优化方法
    optimizer = optim.SGD(model.parameters(), lr=1e-3)
    
    num_epoch = 50

    for epoch_idx in range(num_epoch):

        # 初始化数据加载器
        dataloader = DataLoader(train_dataset, shuffle=True, batch_size=8)
        # 训练时间
        start = time.time()
        # 计算损失
        total_loss = 0.0
        total_num = 1
        # 准确率
        correct = 0

        for x, y in dataloader:

            output = model(x)
            # 计算损失
            loss = criterion(output, y)
            # 梯度清零
            optimizer.zero_grad()
            # 反向传播
            loss.backward()
            # 参数更新
            optimizer.step()

            total_num += len(y)
            total_loss += loss.item() * len(y)

        print('epoch: %4s loss: %.2f, time: %.2fs' %
              (epoch_idx + 1, total_loss / total_num, time.time() - start))

    # 模型保存
    torch.save(model.state_dict(), 'price-model.bin')
  • 💬要在PyTorch中查看随机数种子,可以使用torch.random.initial_seed()函数。这个函数会返回当前设置的随机数种子。 

评估函数

def test():

    # 加载模型
    model = PhonePriceModel(input_dim, class_num)
    model.load_state_dict(torch.load('price-model.bin'))

    # 构建加载器
    dataloader = DataLoader(valid_dataset, batch_size=8, shuffle=False)

    # 评估测试集
    correct = 0
    for x, y in dataloader:

        output = model(x)
        y_pred = torch.argmax(output, dim=1)
        correct += (y_pred == y).sum()

    print('Acc: %.5f' % (correct.item() / len(valid_dataset)))

网络性能调优

  1. 对输入数据进行标准化
  2. 调整优化方法
  3. 调整学习率
  4. 增加批量归一化层
  5. 增加网络层数、神经元个数
  6. 增加训练轮数

🔎我们可以自行将优化方法由 SGD 调整为 Adam , 学习率由 1e-3 调整为 1e-4 ,对数据数据进行标准化 ,增加网络深度 。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/604538.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

JavaScript数字分隔符

● 如果现在我们用一个很大的数字,例如2300000000,这样真的不便于我们进行阅读,我们希望用千位分隔符来隔开它,例如230,000,000; ● 下面我们使用_当作分隔符来尝试一下 const diameter 287_266_000_000; console.log(diameter)…

论文分享[cvpr2018]Non-local Neural Networks非局部神经网络

论文 https://arxiv.org/abs/1711.07971 代码https://github.com/facebookresearch/video-nonlocal-net 非局部神经网络 motivation:受计算机视觉中经典的非局部均值方法[4]的启发,非局部操作将位置的响应计算为所有位置的特征的加权和。 非局部均值方法 NLM&#…

Python实现Chiikawa

写在前面 哈?呀哈!本期小编给大家素描版Chiikawa! 主人公当然是我们可爱的吉伊、小八以及乌萨奇啦~ Chiikawa小小可爱 《Chiikawa》是一部来自日本的超萌治愈系漫画与动画作品,由作者秋田祯信创作。"Chiikawa"这个名字…

【Kolmogorov-Arnold网络 替代多层感知机MLPs】KAN: Kolmogorov-Arnold Networks

KAN: Kolmogorov-Arnold Networks 论文地址 代码地址 知乎上的讨论(看一下评论区更正) Abstract Inspired by the Kolmogorov-Arnold representation theorem, we propose Kolmogorov-Arnold Networks (KANs) as promising alternatives to Multi-Layer…

支持LLM的Markdown笔记;ComfyUI-HiDiffusion图片生成和对图像进行高质量编辑

✨ 1: ComfyUI-HiDiffusion ComfyUI-HiDiffusion是一个为HiDiffusion技术使用而定制的节点。HiDiffusion技术是专门用于在计算机视觉和图像处理中生成和改进图片质量的先进算法。该技术通常应用于图像的超分辨率、去噪、风格转换等方面。 ComfyUI-HiDiffusion的主要特点包含提…

Julia 语言环境安装与使用

1、Julia 语言环境安装 安装教程:https://www.runoob.com/julia/julia-environment.html Julia 安装包下载地址为:https://julialang.org/downloads/。 安装步骤:注意(勾选 Add Julia To PATH 自动将 Julia 添加到环境变量&…

(五)JSP教程——response对象

response对象主要用于动态响应客户端请求(request),然后将JSP处理后的结果返回给客户端浏览器。JSP容器根据客户端的请求建立一个默认的response对象,然后使用response对象动态地创建Web页面、改变HTTP标头、返回服务器端地状态码…

C++string续

一.find_first_of与find 相同:都是从string里面找字符,传参格式一样(都可以从某个位置开始找) 不同:find_first_of只能找字符,find可以找字符串 find_first_of参数里面的string与char*是每个字符的集合,指找出string…

ETL工具中JSON格式的转换方式

JSON的用处 JSON(JavaScript Object Notation)是一种轻量级的数据交换格式,其设计初衷是为了提升网络应用中数据的传输效率及简化数据结构的解析过程。自其诞生以来,JSON 已成为Web开发乃至众多软件开发领域中不可或缺的一部分&a…

【大模型认识】警惕AI幻觉,利用插件+微调来增强GPT模型

文章目录 一. 大模型的局限1. 大模型不会计算2. 甚至明目张胆的欺骗 二. 使用插件和微调来增强GPT模型1. 模型的局限性2. 插件来增强大模型的能力3. 微调技术-提高特定任务的准确性 一. 大模型的局限 1. 大模型不会计算 LLM根据给定的输入提示词逐个预测下一个词(…

Leaflet在WGS84 Web墨卡托投影与WGS84经纬度投影下空间信息变形问题及修正-以圆为例

目录 前言 一、投影的相关知识 1、经纬度投影 2、Web墨卡托投影 二、经纬度投影下的空间信息展示 1、空间信息展示 2、效果展示 3、经纬度投影下的圆修正 三、Web墨卡托投影下空间信息展示 1、底图引用 2、自定义生成圆 总结 前言 在GIS的知识海洋中,对…

Redis集群分片

什么是集群 集群是由多个复制集组成的,能提供在多个redis节点间共享数据的程序集 简而言之就是将原来的单master主机拆分为多个master主机,将整个数据集分配到各主机上 集群的作用 集群中可以存在多个master,而每个master可以挂载多个slave自带哨兵的故障转移机制,不需要再去…

【Android】源码解析Activity的结构分析

源码解析Activity的结构分析 目录 1、Activity、View、Window有什么关联?2、Activity的结构构建流程3 源码解析Activity的构成 3.1 Activity的Attach方法3.2 Activity的OnCreate 4、WindowManager与View的关系总结 1、一个Activity对应几个WindowManage&#xff0…

【论文阅读笔记】关于“二进制函数相似性检测”的调研(Security 22)

个人博客链接 注:部分内容参考自GPT生成的内容 [Security 22] 关于”二进制函数相似性检测“的调研(个人阅读笔记) 论文:《How Machine Learning Is Solving the Binary Function Similarity Problem》(Usenix Securi…

C++ 模拟实现 priority_queue(优先队列)

目录 一,优先队列简介 二,priority_queue 的内部实现原理 三,模拟实现 priority_queue 1,模板参数与数据结构 2,构造 3,辅助功能(堆的有序化,建立堆) 4&#xff0…

嵌入式学习69-C++(Opencv)

知识零碎: QT的两种编译模式 1.debug 调试模式 …

springboot整合rabbitmq的不同工作模式详解

前提是已经安装并启动了rabbitmq,并且项目已经引入rabbitmq,完成了配置。 不同模式所需参数不同,生产者可以根据参数不同使用重载的convertAndSend方法。而消费者均是直接监听某个队列。 不同的交换机是实现不同工作模式的关键组件.每种交换…

泛微E9开发 选择项目类型,自动带出该类项目的预计金额(即下拉框联动浮点型数据)

1、功能背景 在用户进行项目类型选择时,自动带出其余的标准数据(样例中的预计金额),如对员工进行表彰奖励时,不同的表彰有不同的奖励金额,那么我们就可以使用以下的方式来进行操作。 2、展示效果 3、实现…

WiFine通信与Wi-sun通信对比

调制速率 WiFine通信:(G)FSK 50Kbps~500Kbps ;LoRa 5Kbps~37.5Kbps Wi-Sun通信:(G)FSK 50Kbps~300Kbps ;QPSK/OFDM 计划中… 2、协议简介 WiFine通信:为低成本、低功耗、移动设备倾力打造 的轻量级、分布式无线移动…

英语新概念2-回译法-lesson13

The Greenwood Boys 绿林少年是一组流行歌手们。现在他们正在参观城市里的所有公园,他们明天就要到这。他们将坐火车到并且大多数小镇上的年轻人将要欢迎他们,明天晚上他们将要在工人俱乐部唱歌。绿林少年将在这待五天,在这期间,…
最新文章