27、ResNet50处理STEW数据集,用于情感三分类+全备的代码

1、数据介绍

IEEE-Datasets-STEW:SIMULTANEOUS TASK EEG WORKLOAD DATASET :

该数据集由48名受试者的原始EEG数据组成,他们参加了利用SIMKAP多任务测试进行的多任务工作负荷实验。受试者在休息时的大脑活动也在测试前被记录下来,也包括在其中。Emotiv EPOC设备,采样频率为128Hz,有14个通道,用于获取数据,每个案例都有2.5分钟的EEG记录。受试者还被要求在每个阶段后以1到9的评分标准对其感知的心理工作量进行评分,评分结果在单独的文件中提供。

说明:每个受试者的数据遵循命名惯例:subno_task.txt。例如,sub01_lo.txt将是受试者1在休息时的原始脑电数据,而sub23_hi.txt将是受试者23在多任务测试中的原始脑电数据。每个数据文件的行对应于记录中的样本,列对应于EEG设备的14个通道: AF3, F7, F3, FC5, T7, P7, O1, O2, P8, T8, FC6, F4, F8, AF4。

数据说明、下载地址:

STEW: Simultaneous Task EEG Workload Data Set | IEEE Journals & Magazine | IEEE Xplore

2、代码

本次使用ResNet50,去做此情感数据的分类工作,数据导入+模型训练+测试代码如下:

import torch
import torchvision.datasets
from torch.utils.data import Dataset        # 继承Dataset类
import os
from PIL import Image
import numpy as np
from torchvision import transforms
 
 
# 预处理
data_transform = transforms.Compose([
    transforms.Resize((224,224)),           # 缩放图像
    transforms.ToTensor(),                  # 转为Tenso
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))       # 标准化
])


path =  r'C:\STEW\test'

for root,dirs,files in os.walk(path):
        print('root',root) #遍历到该目录地址
        print('dirs',dirs) #遍历到该目录下的子目录名 []
        print('files',files)  #遍历到该目录下的文件  []
 
def read_txt_files(path):
    # 创建文件名列表
    file_names = []
    # 遍历给定目录及其子目录下的所有文件
    for root, dirs, files in os.walk(path):
        # 遍历所有文件
        for file in files:
            # 如果是 .txt 文件,则加入文件名列表
            if file.endswith('.txt'): # endswith () 方法用于判断字符串是否以指定后缀结尾,如果以指定后缀结尾返回True,否则返回False。
                file_names.append(os.path.join(root, file))
    # 返回文件名列表
    return file_names

class DogCat(Dataset):      # 数据处理
    def __init__(self,root,transforms = None):                  # 初始化,指定路径,是否预处理等等
 
        #['cat.15454.jpg', 'cat.445.jpg', 'cat.46456.jpg', 'cat.656165.jpg', 'dog.123.jpg', 'dog.15564.jpg', 'dog.4545.jpg', 'dog.456465.jpg']
        imgs = os.listdir(root)
 
        self.imgs = [os.path.join(root,img) for img in imgs]    # 取出root下所有的文件
        self.transforms = data_transform                        # 图像预处理
 
    def __getitem__(self, index):       # 读取图片
        img_path = self.imgs[index]
        label = 1 if 'dog' in img_path.split('/')[-1] else 0 
        #然后,就可以根据每个路径的id去做label了。将img_path 路径按照 '/ '分割,-1代表取最后一个字符串,如果里面有dog就为1,cat就为0.
 
        data = Image.open(img_path)
 
        if self.transforms:     # 图像预处理
            data = self.transforms(data)
 
        return data,label
 
    def __len__(self):
        return len(self.imgs)
 
dataset = DogCat('./data/',transforms=True)
 
for img,label in dataset:
    print('img:',img.size(),'label:',label)
'''
img: torch.Size([3, 224, 224]) label: 0
img: torch.Size([3, 224, 224]) label: 0
img: torch.Size([3, 224, 224]) label: 0
img: torch.Size([3, 224, 224]) label: 0
img: torch.Size([3, 224, 224]) label: 1
img: torch.Size([3, 224, 224]) label: 1
img: torch.Size([3, 224, 224]) label: 1
img: torch.Size([3, 224, 224]) label: 1
'''

import os
 
# 获取file_path路径下的所有TXT文本内容和文件名
def get_text_list(file_path):
    files = os.listdir(file_path)
    text_list = []
    for file in files:
        with open(os.path.join(file_path, file), "r", encoding="UTF-8") as f:
            text_list.append(f.read())
    return text_list, files
 
class ImageFolderCustom(Dataset):

    # 2. Initialize with a targ_dir and transform (optional) parameter
    def __init__(self, targ_dir: str, transform=None) -> None:

        # 3. Create class attributes
        # Get all image paths
        self.paths = list(pathlib.Path(targ_dir).glob("*/*.jpg")) # note: you'd have to update this if you've got .png's or .jpeg's
        # Setup transforms
        self.transform = transform
        # Create classes and class_to_idx attributes
        self.classes, self.class_to_idx = find_classes(targ_dir)

    # 4. Make function to load images
    def load_image(self, index: int) -> Image.Image:
        "Opens an image via a path and returns it."
        image_path = self.paths[index]
        return Image.open(image_path) 

    # 5. Overwrite the __len__() method (optional but recommended for subclasses of torch.utils.data.Dataset)
    def __len__(self) -> int:
        "Returns the total number of samples."
        return len(self.paths)

    # 6. Overwrite the __getitem__() method (required for subclasses of torch.utils.data.Dataset)
    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        "Returns one sample of data, data and label (X, y)."
        img = self.load_image(index)
        class_name  = self.paths[index].parent.name # expects path in data_folder/class_name/image.jpeg
        class_idx = self.class_to_idx[class_name]

        # Transform if necessary
        if self.transform:
            return self.transform(img), class_idx # return data, label (X, y)
        else:
            return img, class_idx # return data, label (X, y)
                  
import torchvision as tv
import numpy as np
import torch
import time
import os
from torch import nn, optim
from torchvision.models import resnet50
from torchvision.transforms import transforms
 
os.environ["CUDA_VISIBLE_DEVICE"] = "0,1,2"
 
# cifar-10进行测验

class Cutout(object):
    """Randomly mask out one or more patches from an image.
    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int): The length (in pixels) of each square patch.
    """
    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length
 
    def __call__(self, img):
        """
        Args:
            img (Tensor): Tensor image of size (C, H, W).
        Returns:
            Tensor: Image with n_holes of dimension length x length cut out of it.
        """
        h = img.size(1)
        w = img.size(2)
 
        mask = np.ones((h, w), np.float32)
 
        for n in range(self.n_holes):
            y = np.random.randint(h)
            x = np.random.randint(w)
 
            y1 = np.clip(y - self.length // 2, 0, h)
            y2 = np.clip(y + self.length // 2, 0, h)
            x1 = np.clip(x - self.length // 2, 0, w)
            x2 = np.clip(x + self.length // 2, 0, w)
 
            mask[y1: y2, x1: x2] = 0.
 
        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        img = img * mask
 
        return img
 
def load_data_cifar10(batch_size=128,num_workers=2):
    # 操作合集
    # Data augmentation
    train_transform_1 = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),  # 随机水平翻转
        transforms.RandomRotation(degrees=(-80,80)),  # 随机角度翻转
        transforms.ToTensor(),
        transforms.Normalize(
            (0.491339968,0.48215827,0.44653124), (0.24703233,0.24348505,0.26158768)  # 两者分别为(mean,std)
        ),
        Cutout(1, 16),  # 务必放在ToTensor的后面
    ])
    train_transform_2 = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            (0.491339968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)  # 两者分别为(mean,std)
        )
    ])
    test_transform = transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize(
            (0.491339968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)  # 两者分别为(mean,std)
        )
    ])
    # 训练集1
    trainset1 = tv.datasets.CIFAR10(
        root='data',
        train=True,
        download=False,
        transform=train_transform_1,
    )
    # 训练集2
    trainset2 = tv.datasets.CIFAR10(
        root='data',
        train=True,
        download=False,
        transform=train_transform_2,
    )
    # 测试集
    testset = tv.datasets.CIFAR10(
        root='data',
        train=False,
        download=False,
        transform=test_transform,
    )
    # 训练数据加载器1
    trainloader1 = torch.utils.data.DataLoader(
        trainset1,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=(torch.cuda.is_available())
    )
    # 训练数据加载器2
    trainloader2 = torch.utils.data.DataLoader(
        trainset2,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=(torch.cuda.is_available())
    )
    # 测试数据加载器
    testloader = torch.utils.data.DataLoader(
        testset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=(torch.cuda.is_available())
    )
 
    return trainloader1,trainloader2,testloader
 
def main():
    start = time.time()
    batch_size = 128
    cifar_train1,cifar_train2,cifar_test = load_data_cifar10(batch_size=batch_size)
    model = resnet50().cuda()
    # model.load_state_dict(torch.load('_ResNet50.pth'))
    # 存在已保存的参数文件
    # model = nn.DataParallel(model,device_ids=[0,])  # 又套一层
    model = nn.DataParallel(model,device_ids=[0,1,2])
    loss = nn.CrossEntropyLoss().cuda()
    optimizer = optim.Adam(model.parameters(),lr=0.001)
    for epoch in range(50):
        model.train()  # 训练时务必写
        loss_=0.0
        num=0.0
        # train on trainloader1(data augmentation) and trainloader2
        for i,data in enumerate(cifar_train1,0):
            x, label = data
            x, label = x.cuda(),label.cuda()
            # x
            p = model(x) #output
            l = loss(p,label) #loss
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            loss_ += float(l.mean().item())
            num+=1
        for i, data in enumerate(cifar_train2, 0):
            x, label = data
            x, label = x.cuda(), label.cuda()
            # x
            p = model(x)
            l = loss(p, label)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            loss_ += float(l.mean().item())
            num += 1
        model.eval()  # 评估时务必写
        print("loss:",float(loss_)/num)
        # test on trainloader2,testloader
        with torch.no_grad():
            total_correct = 0
            total_num = 0
            for x, label in cifar_train2:
                # [b, 3, 32, 32]
                # [b]
                x, label = x.cuda(), label.cuda()
                # [b, 10]
                logits = model(x)
                # [b]
                pred = logits.argmax(dim=1)
                # [b] vs [b] => scalar tensor
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)
                # print(correct)
            acc_1 = total_correct / total_num
        # Test
        with torch.no_grad():
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                # [b, 3, 32, 32]
                # [b]
                x, label = x.cuda(), label.cuda()
                # [b, 10]
                logits = model(x) #output
                # [b]
                pred = logits.argmax(dim=1)
                # [b] vs [b] => scalar tensor
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)
                # print(correct)
            acc_2 = total_correct / total_num
            print(epoch+1,'train acc',acc_1,'|','test acc:', acc_2)
    # 保存时只保存model.module
    torch.save(model.module.state_dict(),'resnet50.pth')
    print("The interval is :",time.time() - start)
 
 
if __name__ == '__main__':
    main()

3、对你有帮助的话,给个关注吧~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/263233.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

测试——VS断点调试

1、问题 本文使用VS对c程序执行简单的断点调试。调试的c程序&#xff1a; #define _CRT_SECURE_NO_WARNINGS #include <stdio.h> #include <Windows.h> #include <string.h>using namespace std;int main(void) {float r0;float s0;printf("请输入圆的…

【Java】Mybatis

MyBatis JavaEE三层框架&#xff1a;表现层、业务层、持久层。 现在开始学习持久层。持久层就是负责与数据库打交道的代码。 框架&#xff1a;就是一个半成品软件。在框架的基础上&#xff0c;可以更加高效地写出代码。 1、MyBatis快速入门 1、准备工作&#xff08;创建sp…

Spring Environment 注入引起NPE问题排查

文章目录 背景原因分析1&#xff09;Spring Aware Bean 是什么&#xff1f;2&#xff09;从 Spring Bean 的生命周期入手 解决方案 背景 写业务代码遇到使用 Spring Environment 注入为 null 的情况&#xff0c;示例代码有以下两种写法&#xff0c;Environment 实例都无法注入…

华为OD机试 - 多段线数据压缩(Java JS Python C)

题目描述 下图中,每个方块代表一个像素,每个像素用其行号和列号表示。 为简化处理,多线段的走向只能是水平、竖直、斜向45度。 上图中的多线段可以用下面的坐标串表示:(2,8),(3,7),(3,6),(3,5),(4,4),(5,3),(6,2),(7,3),(8,4),(7,5)。 但可以发现,这种表示不是最简的,…

RK3568驱动指南|第八篇 设备树插件-第80章 注册attribute实验

瑞芯微RK3568芯片是一款定位中高端的通用型SOC&#xff0c;采用22nm制程工艺&#xff0c;搭载一颗四核Cortex-A55处理器和Mali G52 2EE 图形处理器。RK3568 支持4K 解码和 1080P 编码&#xff0c;支持SATA/PCIE/USB3.0 外围接口。RK3568内置独立NPU&#xff0c;可用于轻量级人工…

【C++入门到精通】互斥锁 (Mutex) C++11 [ C++入门 ]

阅读导航 引言一、Mutex的简介二、Mutex的种类1. std::mutex &#xff08;基本互斥锁&#xff09;2. std::recursive_mutex &#xff08;递归互斥锁&#xff09;3. std::timed_mutex &#xff08;限时等待互斥锁&#xff09;4. std::recursive_timed_mutex &#xff08;限时等待…

【数字图像处理】实验二 图像变换

图像变换 一、实验内容&#xff1a; 1&#xff0e; 熟悉和掌握利用Matlab工具进行数字图像的读、写、显示等数字图像处理基本步骤。 2&#xff0e; 熟练掌握各种图像变换的基本原理及方法。 3&#xff0e; 能够从深刻理解图像变换&#xff0c;并能够思考拓展到一定的应用领域。…

多层负载均衡实现

1、单节点负载均衡 1&#xff09;站点层与浏览器层之间加入了一个反向代理层&#xff0c;利用高性能的nginx来做反向代理 2&#xff09;nginx将http请求分发给后端多个web-server 优点&#xff1a; 1&#xff09;DNS-server不需要动 2&#xff09;负载均衡&#xff1a;通过ngi…

OGG同步异构数据库-表字段变更重新读取异构文件测试验证

OGG同步异构数据库-表字段变更重新读取异构文件测试验证 删除前源和目标端的同步情况&#xff1a; 配置文件信息&#xff1a; 源端&#xff1a; GGSCI (ITSMdoc-236-63) 4> view param etest extract etest setenv (MYSQL_HOME“/data/mysql-5.7.26”) tranlogoptions al…

研究生课程 |《矩阵论》复习

文章目录 【一(18)】填空题【二(10)】范数证明【三(15)】矩阵函数1 计算 e A t e^{At} eAt2 求微分方程的解 【四(10)】QR分解【五(10)】Gerschgorim隔离特征值【六(15)】 A A^ A计算及求解线性方程组1 计算满秩分解2 计算 A A^ A3 判断线性方程组解是否存在 【七(15)】线性…

快速开发教务管理应用,课程表微信小程序源码

介绍 课程表微信小程序源码 快速开发教务管理应用 对接微信公众号每日课表推送 三种导入课表方式可供选择 班级课表导入 爬虫导入课表 学号导入课表

with torch.no_grad()在Pytorch中的应用

with torch.no_grad()在Pytorch中的应用 参考&#xff1a; https://blog.csdn.net/qq_24761287/article/details/129773333 https://blog.csdn.net/sazass/article/details/116668755 在学习Pytorch时&#xff0c;老遇到 with torch.no_grad()&#xff0c;搞不清其作用&#…

P2 H264码流结构分析——Annexb与MP4格式的区别 (中)

前言 从本章开始我们将要学习嵌入式音视频的学习了 &#xff0c;使用的瑞芯微的开发板 &#x1f3ac; 个人主页&#xff1a;ChenPi &#x1f43b;推荐专栏1: 《C_ChenPi的博客-CSDN博客》✨✨✨ &#x1f525; 推荐专栏2: 《Linux C应用编程&#xff08;概念类&#xff09;_Ch…

如何使用Java的GeoTools地理库计算WGS84坐标下的两个经纬度之间得距离

介绍 本章讲解如何使用Java的GeoTools地理库计算基于WGS84坐标的两点之间的距离。适用于后台服务的距离计算。 GeoTools介绍 GeoTools是开源的Java地理信息计算库。GeoServer地图引擎就是基于GeoTools库构建得地图服务,可以说非常强大。 官网地址:https://docs.geotools.o…

python/C 生成beta分布的随机数

python/C 生成beta分布的随机数 文章目录 python/C 生成beta分布的随机数前言一、beta分布理论知识二、python 生成服从beta分布的随机数三、C语言生成服从beta分布的随机数 前言 想把一个算法用C语言实现&#xff0c;其中涉及到了beta分布取随机数&#xff0c;记录一下结果 一…

【Java】网络编程-TCP回显服务器代码编写

前面我们讲了基于UDP的网络编程 UDP回显服务器 UDP字典服务器 下面我们来讲基于TCP的回显服务编写 1、服务器 import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.io.PrintWriter; import java.net.ServerSocket; impo…

25、新加坡南洋理工、新加坡国立大学提出FBCNet:完美融合FBCSP的CNN,EEG解码SOTA水准![抱歉老师,我太想进步了!]

前言&#xff1a; 阴阳差错&#xff0c;因工作需要&#xff0c;需要查阅有关如何将FBCSP融入CNN中的文献&#xff0c;查阅全网&#xff0c;发现只此一篇文章&#xff0c;心中大喜&#xff0c;心想作者哪家单位&#xff0c;读之&#xff0c;原来是自己大导&#xff08;新加坡工…

冬天天冷早安问候语关心话,愿我的每句话都能带给你温馨

1、送你一声问候&#xff0c;为你驱走冬日严寒&#xff0c;送你一份关怀&#xff0c;为你增添丝丝温暖&#xff0c;送你一句祝福&#xff0c;为你驱走所有不快&#xff0c;送你一份关爱&#xff0c;为你增添幸福无限&#xff0c;天虽寒了&#xff0c;我的关心犹在&#xff0c;愿…

Centos安装Docker及使用

文章目录 配置要求Centos安装Docker卸载docker&#xff08;可选&#xff09;安装docker首先需要大家虚拟机联网&#xff0c;安装yum工具然后更新本地镜像源&#xff1a;然后输入安装docker命令&#xff1a;查看docker的版本 启动docker关闭防火墙接着通过命令启动docker 配置镜…

vscode debug c++代码

需要提前写好CMakeLists.txt 在tasks.json中写好编译的步骤&#xff0c;即tasks&#xff0c;如cmake … 和make -j 在lauch.json中配置可执行文件的路径和需要执行tasks中的哪一个任务 具体步骤&#xff1a; 1.写好c代码和CMakeLists.txt 2.配置tasks.json 终端–>配置任务…