(动手学习深度学习)第13章 实战kaggle竞赛:CIFAR-10

  1. 导入相关库
import collections
import math
import os
import shutil
import pandas as pd
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
  1. 下载数据集
d2l.DATA_HUB['cifar10_tiny'] = (d2l.DATA_URL + 'kaggle_cifar10_tiny.zip',
                                '2068874e4b9a9f0fb07ebe0ad2b29754449ccacd')

# 如果使用完整的Kaggle竞赛的数据集,设置demo为False
demo = True

if demo:
    data_dir = d2l.download_extract('cifar10_tiny')
else:
    data_dir = '../data/kaggle/cifar-10/'
  1. 整理数据集
# 查看数据集
def read_csv_labels(fname):
    """读取‘fname’来给标签字典返回一个文件名"""
    with open(fname, 'r') as f:
        lines = f.readlines()[1:]  # readlines(): 每次读文档的一行,以后还需要逐步循环
        tokens = [l.rstrip().split(',') for l in lines]  # rstrip(): 删除字符串后面(右面)的空格或特殊字符, 还有lstrip(左面)、strip(两面)
        return dict((name, label) for name, label in tokens)

labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))
print('训练样本:', len(labels))
print('类别:', len(set(labels.values())))  # set(): 集合,里面不能包含重复的元素,接受一个list作为参数

在这里插入图片描述
将验证集从原始的训练集钟拆分出来

# 拆分数据集:训练集、验证集
def copyfile(filename, target_dir):
    """将文件复制到目标目录"""
    os.makedirs(target_dir, exist_ok=True)  # 创建多层目录,exist_ok为True:在目标目录已存在的情况下不会触发FileExistsError异常。
    shutil.copy(filename, target_dir)  #拷贝文件,filename:要拷贝的文件;target_dir:目标文件夹

def reorg_train_valid(data_dir, labels, valid_ratio):
    """将验证集从原始训练集钟拆分出来"""
    # 训练数据集中样本数量最少的类别中的样本数
    # Counter: 计数器,返回一个字典,键为元素,值为元素个数;
    # .most_common(): 返回一个列表, 列表元素为(元素,出现次数),默认按出现频率排序
    # [-1]: 样本数量最少的类别(类别, 样本数),[-1][1]: 样本数数量最少的类别中的样本数
    n = collections.Counter(labels.values()).most_common()[-1][1]
    # 验证集中每个类别的样本数
    n_valid_per_label= max(1, math.floor((n * valid_ratio)))  # math.floor(): 向下取整  math.ceil(): 向上取整
    label_count = {}

    # 遍历原始训练集中的每个样本
    for train_file in os.listdir(os.path.join(data_dir, 'train')):
        label = labels[train_file.split('.')[0]]  # 从文件名中提取标签
        fname = os.path.join(data_dir, 'train', train_file)
        copyfile(fname, os.path.join(data_dir, 'train_valid_test', 'train_valid', label))
        # 如果该类别的样本数还未达到在验证集中的设定数量,则将样本复制到验证集中
        if label not in label_count or label_count[label] < n_valid_per_label:
            copyfile(fname, os.path.join(data_dir, 'train_valid_test', 'valid', label))
            label_count[label] = label_count.get(label, 0) + 1
        else:
            copyfile(fname, os.path.join(data_dir, 'train_valid_test', 'train', label))

    return n_valid_per_label

# reorg_test函数用来在预测期间整理测试集,以方便读取
def reorg_test(data_dir):
    """在预测期间整理测试集,以方便读取"""
    # 遍历测试集中的每个样本
    for test_file in os.listdir(os.path.join(data_dir, 'test')):
        # 将测试集中的样本复制到新的目录结构中的 'test' 子目录下,标签为 'unknown'
        copyfile(os.path.join(data_dir, 'test', test_file),
                 os.path.join(data_dir, 'train_valid_test', 'test', 'unknown'))
# 整个处理数据集函数
def reorg_cifar10_data(data_dir, valid_ratio):
    labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))
    reorg_train_valid(data_dir, labels, valid_ratio)
    reorg_test(data_dir)
  • 这个小规模数据集的批量大小是32,在实际的cifar-10数据集中,可以设为128
  • 将10%的训练样本作为调整超参数的验证集
batch_size = 32 if demo else 128
valid_ratio = 0.1
reorg_cifar10_data(data_dir, valid_ratio)
结果会生成一个train_valid_test的文件夹,里面有:
- test文件夹---unknow文件夹:5张没有标签的测试照片
- train_valid文件夹---10个类被的文件夹:每个文件夹包含所属类别的全部照片
- train文件夹--10个类别的文件夹:每个文件夹下包含90%的照片用于训练
- valid文件夹--10个类别的文件夹:每个文件夹下包含10%的照片用于验证
  1. 图像增广
transform_train = torchvision.transforms.Compose([
    # 原本图像是32*32,先放大成40*40, 在随机裁剪为32*32,实现训练数据的增强
    torchvision.transforms.Resize(40),
    torchvision.transforms.RandomResizedCrop(32, scale=(0.64, 1.0), ratio=(1.0, 1.0)),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        [0.4914, 0.4822, 0.4465],[0.2023, 0.1994, 0.2010]
    )
])
transform_test = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    # 标准化图像的每个通道 : 消除评估结果中的随机性
    torchvision.transforms.Normalize(
        [0.4914, 0.4822, 0.4465],[0.2023, 0.1994, 0.2010]
    )
])
  1. 加载数据集
train_ds, train_valid_ds = [
    torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train_valid_test', folder),transform=transform_train
    ) for folder in ['train', 'train_valid']
]
valid_ds, test_ds = [
    torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train_valid_test', folder), transform=transform_test
    ) for folder in ['valid', 'test']
]
  1. 定义迭代器,方便快速迭代数据
train_iter, train_valid_iter = [
    torch.utils.data.DataLoader(
        dataset, batch_size, shuffle=True, drop_last=True
    ) for dataset in (train_ds, train_valid_ds)
]
valid_iter = torch.utils.data.DataLoader(
    valid_ds, batch_size, shuffle=False, drop_last=True
)
test_iter = torch.utils.data.DataLoader(
    test_ds, batch_size, shuffle=False, drop_last=False
)
  1. 定义模型与损失函数
# 对resnet18做微调,输入通道数为3, 输出类别数为10
def get_net():
    num_classes = 10
    net = d2l.resnet18(num_classes, in_channels=3)
    return net
# 查看网络模型
get_net()

在这里插入图片描述

# 使用交叉熵损失函数作为损失函数: 直接返回n分样本的loss
loss = nn.CrossEntropyLoss(reduction='none')
  1. 定义训练函数
# 定义训练函数
def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay):
    trainer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=wd)
    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)
    num_batches, timer = len(train_iter), d2l.Timer()
    legend = ['train loss', 'train acc']
    if valid_iter is not None:
        legend.append('valid acc')
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], legend=legend)
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    for epoch in range(num_epochs):
        net.train()
        metric = d2l.Accumulator(3)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            l, acc = d2l.train_batch_ch13(net, features, labels, loss, trainer, devices)
            metric.add(l, acc, labels.shape[0])
            timer.stop()
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (metric[0]/ metric[2], metric[1] / metric[2], None))
        if valid_iter is not None:
            valid_acc = d2l.evaluate_accuracy_gpu(net, valid_iter)
            animator.add(epoch+1, (None, None, valid_acc))
        scheduler.step()
    measures = (f'train loss {metric[0] / metric[2]:.3f},'
                f'train acc{metric[1] / metric[2]:.3f}')
    if valid_iter is not None:
        measures += f', valid acc {valid_acc:.3f}'
    print(measures + f'\n{metric[2] * num_epochs /timer.sum():.1f}'
                     f'example/sec on {str(devices)}')
  1. 训练模型
    • (数据集太小,导致精度不高)
import time

# 在开头设置开始时间
start = time.perf_counter()  # start = time.clock() python3.8之前可以

# 训练和验证模型
devices, num_epochs, lr, wd = d2l.try_all_gpus(), 20, 2e-4, 5e-4
lr_period, lr_decay, net = 4, 0.9, get_net()
train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay)

# 在程序运行结束的位置添加结束时间
end = time.perf_counter()  # end = time.clock()  python3.8之前可以

# 再将其进行打印,即可显示出程序完成的运行耗时
print(f'运行耗时{(end-start):.4f}')

在这里插入图片描述
10. 对测试集进行分类并提交结果

net, preds = get_net(), []
train(net ,train_valid_iter, None, num_epochs, lr, wd, devices, lr_period, lr_decay)
for X, _ in test_iter:
    y_hat = net(X.to(devices[0]))
    preds.extend(y_hat.argmax(dim=1).type(torch.int32).cpu().numpy())
sorted_ids = list(range(1, len(test_ds) + 1))
sorted_ids.sort(key=lambda x: str(x))
df = pd.DataFrame({'id' : sorted_ids, 'label': preds})
df['label'] = df['label'].apply(lambda x: train_valid_ds.classes[x])
df.to_csv('submission.csv', index=False)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/169895.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

解决证书加密问题:OpenSSL与urllib3的兼容性与优化

在使用客户端证书进行加密通信时&#xff0c;用户可能会遇到一些问题。特别是当客户端证书被加密并需要密码保护时&#xff0c;OpenSSL会要求用户输入密码。这对于包含多个调用的大型会话来说并不方便&#xff0c;因为密码无法在连接的多个调用之间进行缓存和重复使用。用户希望…

【mediasoup】TransportCongestionControlClient 1: 代码走读

TransportCongestionControlClient 基于m77版本的libwebrtc ,但是TransportCongestionControlClient 并不是libwebrt中的,是mediasoup自己封装实现:TransportCongestionControlClient 用于发送端D:\XTRANS\soup\mediasoup-sfu-cpp\worker\src\RTC\TransportCongestionContro…

HarmonyOS开发(四):UIAbility组件

1、UIAbility概述 UIAbility 一种包含用户界面的应用组件用于与用户进行交互系统调度的单元为应用提供窗口在其中绘制界同 注&#xff1a;每一个UIAbility实例&#xff0c;都对应一个最近任务列表中的任务。 一个应用可以有一个UIAbility也可以有多个UIAbility。 如一般的…

BLIP-2:冻结现有视觉模型和大语言模型的预训练模型

Li J, Li D, Savarese S, et al. Blip-2: Bootstrapping language-image pre-training with frozen image encoders and large language models[J]. arXiv preprint arXiv:2301.12597, 2023. BLIP-2&#xff0c;是 BLIP 系列的第二篇&#xff0c;同样出自 Salesforce 公司&…

力扣贪心——跳跃游戏I和II

1 跳跃游戏 利用边界进行判断&#xff0c;核心就是判定边界&#xff0c;边界内所有步数一定是最小的&#xff0c;然后在这个边界里找能到达的最远地方。 1.1 跳跃游戏I class Solution {public boolean canJump(int[] nums) {int len nums.length;int maxDistance 0;int te…

C/C++多级指针与多维数组

使用指针访问数组 指针类型的加减运算可以使指针内保存的首地址移动。 指针类型加n后。首地址向后移动 n * 步长 字节。 指针类型减n后。首地址向前移动 n * 步长 字节。 步长为指针所指向的类型所占空间大小。 例如&#xff1a; int *p (int *)100;p 1&#xff0c;结果为首…

[机缘参悟-119] :反者道之动与阴阳太极

目录 一、阴阳对立、二元对立的规律 1.1 二元对立 1.2 矛盾的对立与统一 二、阴阳互转、阴阳变化、变化无常 》无序变化和有序趋势的规律 三、阴阳合一、佛魔一体、善恶同源 四、看到积极的一面 五、反者道之动 5.1 概述 5.2 "否极泰来" 5.3 “乐极生悲”…

科大讯飞 vue.js 语音听写流式实现 全网首发

组件下载 还是最近的需求&#xff0c;页面表单输入元素过多&#xff0c;需要实现语音识别来由用户通过朗读的方式向表单中填写数据&#xff0c;尽量快的、高效的完成表单数据采集及输入。 国内科大讯飞在语音识别方面的建树还是有目共睹&#xff0c;于是还是选择了科大讯飞的平…

让别人访问电脑本地

查看本地IP地址&#xff1a; 使用ipconfig&#xff08;Windows&#xff09;或ifconfig&#xff08;Linux/macOS&#xff09;命令来查看你的计算机本地网络的IP地址。确保*****是你的本地IP地址。 防火墙设置&#xff1a; 确保你的防火墙允许从外部访问*****。你可能需要在防火…

leetcode:504. 七进制数

一、题目&#xff1a; 链接&#xff1a; 504. 七进制数 - 力扣&#xff08;LeetCode&#xff09; 函数原型&#xff1a; char* convertToBase7(int num) 二、思路 本题要将十进制数转换为二进制数&#xff0c;只要将十进制num数模7再除7&#xff0c;直到num等于0 每次将模7的结…

React整理总结(五、Redux)

1.Redux核心概念 纯函数 确定的输入&#xff0c;一定会产生确定的输出&#xff1b;函数在执行过程中&#xff0c;不能产生副作用 store 存储数据 action 更改数据 reducer 连接store和action的纯函数 将传入的state和action结合&#xff0c;生成一个新的state dispatc…

【算法】二分查找-20231121

这里写目录标题 一、344. 反转字符串二、392. 判断子序列三、581. 最短无序连续子数组四、680. 验证回文串 II 一、344. 反转字符串 提示 简单 865 相关企业 编写一个函数&#xff0c;其作用是将输入的字符串反转过来。输入字符串以字符数组 s 的形式给出。 不要给另外的数组…

数据结构--串的基本概念

目录 串的基本概念 串的定义 串与线性表对比 ​串的基本操作​ 串的比较 字符集编码 乱码问题​编辑 总结 ​串的存储结构 ​串的顺序存储​编辑 串的链式存储 串的基本操作 1、求字串 2、比较 3、定位操作 总结 串的基本概念 串的定义 串与线性表对比 串的…

飞翔的小鸟

运行游戏如下&#xff1a; 碰到柱子就结束游戏 App GameApp类 package App;import main.GameFrame;public class GameApp {public static void main(String[] args) {//游戏的入口new GameFrame();} } main Barrier 类 package main;import util.Constant; import util.Ga…

C/C++最大质因子 2021年12月电子学会中小学生软件编程(C/C++)等级考试一级真题答案解析

目录 C/C最大质因子 一、题目要求 1、编程实现 2、输入输出 二、算法分析 三、程序编写 四、程序说明 五、运行结果 六、考点分析 C/C最大质因子 一、题目要求 1、编程实现 质因子是指能整除给定正整数的质数。而最大质因子是指一个整数的所有质因子中最大的那个。…

〖大前端 - 基础入门三大核心之JS篇㊴〗- DOM节点的关系

说明&#xff1a;该文属于 大前端全栈架构白宝书专栏&#xff0c;目前阶段免费&#xff0c;如需要项目实战或者是体系化资源&#xff0c;文末名片加V&#xff01;作者&#xff1a;不渴望力量的哈士奇(哈哥)&#xff0c;十余年工作经验, 从事过全栈研发、产品经理等工作&#xf…

电子学会C/C++编程等级考试2022年06月(一级)真题解析

C/C++等级考试(1~8级)全部真题・点这里 第1题:倒序输出 依次输入4个整数a、b、c、d,将他们倒序输出,即依次输出d、c、b、a这4个数。 时间限制:1000 内存限制:65536输入 一行4个整数a、b、c、d,以空格分隔。 0 < a,b,c,d < 108输出 一行4个整数d、c、b、a,整数之…

机器学习-笔记

绪论 参考期刊 ICCV 偏向视觉CVPR 偏向MLIAAA AI原理ICML 参考链接 CSDN 机器学习知识点全面总结 课堂内容学习-0912-N1 对于特征提取&#xff0c;简而言之就是同类聚得紧&#xff0c;异类分得开&#xff1b;   detection研究的是样本二分类问题&#xff0c;即分为正样本…

C语言之sizeof 和 strlen 详细介绍

C语言之sizeof 和 strlen 文章目录 C语言之sizeof 和 strlen1. sizeof 和 strlen 的比较1.1 sizeof1.2 strlen1.3 sizeof 和 strlen 的对比 2. 练习2.1.1 一维数组2.1.2 字符数组 1. sizeof 和 strlen 的比较 1.1 sizeof sizeof是C语言中的一个关键字&#xff0c;计算的是变量…

数字化文化的守护之星:十八数藏的非遗创新之道

在数字时代的浪潮中&#xff0c;十八数藏犹如一颗璀璨的守护之星&#xff0c;为传统文化注入了新的生命力。这个非遗创新项目以数字化为工具&#xff0c;以守护为使命&#xff0c;开辟了文化传承的新航道。 十八数藏是文化数字守护的引领者&#xff0c;通过数字技术&#xff0…