手机版 欢迎访问it开发者社区(www.mfbz.cn)网站

当前位置: > 开发

三步搞定使用Augmentor对训练数据集进行扩增

时间:2021/6/10 16:09:39|来源:|点击: 次

文章目录

    • 前言
    • 实现过程

前言

在训练模型时,有时在数据量较少情况下,避免过拟合,通常会采取人为进行数据增强来达到扩充数据集的目的,下面就介绍一种使用Augmentor来扩充数据集的方法。

实现过程

  • step 1 将源数据放在E盘
  • step 2 运行脚本
  • step 3 生成扩增后的数据

程序实现过程如下:

import numpy as np, Augmentor, cv2, sys
import os
import shutil
def del_file(path):
    ls = os.listdir(path) 
    for i in ls:
        c_path = os.path.join(path, i)
        if os.path.isdir(c_path):
            del_file(c_path)
        else:
            os.remove(c_path)

def Enhancement(filePath, rate):
    index = ngFilePath.rfind("\\")
    print(index)
    dataType = filePath.find("NG")
    dataType1 = filePath.find("OK")

    if(dataType > 0):
        enhancementDir = filePath[0:index] + "\\" + 'EnhanceImg' + '\\' + 'NG'
    if (dataType1 > 0):
        enhancementDir = filePath[0:index] + "\\" + 'EnhanceImg' + '\\' + 'OK'

    showDir = filePath[0:index] + "\\" + 'showImg'
    singleDir = filePath[0:index] + "\\" + 'sigleImg'

    isExist = os.path.exists(enhancementDir);
    if not isExist:
        os.makedirs(enhancementDir)
    else:
        del_file(enhancementDir)
    isExist = os.path.exists(showDir)

    if not isExist:
        os.makedirs(showDir)
    else:
        del_file(showDir)

    isExist = os.path.exists(singleDir)
    if not isExist:
        os.makedirs(singleDir)
    else:
        del_file(singleDir)

    sourceFiles = os.listdir(filePath)
    num = len(sourceFiles)
    sourceList = list(range(num))
    for i in sourceList:
        sourceFilesName = os.path.join(filePath,sourceFiles[i])
        src = cv2.imread(sourceFilesName, 0)
        shutil.copy2(sourceFilesName, singleDir)
        p = Augmentor.Pipeline(singleDir, showDir)
        p.random_brightness(probability= 0.7, min_factor = 0.5, max_factor= 1.2)
        # p.crop_centre(probability=0.5,160, 160)
        p.resize(probability=1, width=160, height=160)
        p.random_contrast(probability= 0.5, min_factor= 0.5, max_factor= 1.2)
        p.sample(rate)
        # shutil.copy2(sourceFilesName, enhancementDir)
        enhancedImg = os.listdir(showDir)
        enhanceImgList = list(range(len(enhancedImg)))
        sampleImgList =  []
        for j in enhanceImgList:
            fileName = enhancedImg[j]
            sampleImgList.append(fileName)
        numSampleImg = list(range(len(sampleImgList)))
        for k in numSampleImg:
            fileName = os.path.join(showDir, sampleImgList[k])
            shutil.copy2(fileName, enhancementDir)

        del_file(showDir)

if __name__ == '__main__':
    ngFilePath = "E:\\IMG\\NG"
    okFilePath = "E:\\IMG\\OK"
    rate = 10
    Enhancement(okFilePath, rate)

注:
由于这里是做二分类,所以将数据分为OK和NG,这里OK文件夹里随便放了5张图片,然后对这5张图片进行数据增强。

在这里插入图片描述
运行脚本后,后自动生成三个文件夹,数据增强后的数据会自动保存在EnhanceImg文件夹下
在这里插入图片描述在这里插入图片描述

Copyright © 2002-2019 某某自媒体运营 版权所有