【信号处理】基于变分自编码器(VAE)的脑电信号增强典型方法实现(tensorflow)

关于

在脑电信号分析处理任务中,数据不均衡是一个常见的问题。针对数据不均衡,传统方法有过采样和欠采样方法来应对,但是效果有限。本项目通过变分自编码器对脑电信号进行生成增强,提高增强样本的多样性,从而提高最终的后端分析性能。

EEG数据增强方法参考:https://dlib.phenikaa-uni.edu.vn/bitstream/PNK/8319/1/Data%20Augmentation%20techniques%20in%20time%20series%20domain%20a%20survey%20and%20taxonomy-2023.pdf

工具

数据集下载地址: BCI Competition IV

 

方法实现

加载必要的库函数和数据

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import glob

from sklearn.model_selection import train_test_split

from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, BatchNormalization, LeakyReLU, Dense, Lambda, Reshape, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.losses import mse
from tensorflow.keras.optimizers import Adam

from tensorflow.keras.callbacks import EarlyStopping

from tensorflow.keras import backend as K



direc = r'bci_iv_2a_data/A01/train/0/'     #data directory

train_dataset = []
train_label = []

test_dataset = []
test_label = []

files = os.listdir(direc)
for j, name in enumerate(files):
    filename = glob.glob(direc + '/'+ name)
    df = pd.read_csv(filename[0], index_col=None, header=None)
    df = df.drop(0, axis=1)     #dropping column of channel names
    df = df.iloc[:,0:1000]      #taking 1000 timesteps
    train_dataset.append(np.array(df))
            


train_dataset = np.array(train_dataset)
train_data = np.expand_dims(train_dataset,axis=-1)

 VAE模型>编码器定义

# VAE model
input_shape=(X_train.shape[1:])
batch_size = 32
kernel_size = 5
filters = 16
latent_dim = 2
epochs = 1000

# reparameterization
def sampling(args): 
    z_mean, z_log_var = args
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    epsilon = K.random_normal(shape=(batch, dim))
    return z_mean + K.exp(0.5 * z_log_var) * epsilon




# encoder
inputs = Input(shape=input_shape, name='encoder_input')
x = inputs

filters = filters* 2
x = Conv2D(filters=filters,kernel_size=(1, 50),strides=(1,25),)(x)
x = BatchNormalization()(x)
x = LeakyReLU(alpha=0.2)(x)


filters = filters* 2
x = Conv2D(filters=filters,kernel_size=(22, 1),)(x)
x = BatchNormalization()(x)
x = LeakyReLU(alpha=0.2)(x)

shape = K.int_shape(x)

x = Flatten()(x)
x = Dense(16, activation='relu')(x)
z_mean = Dense(latent_dim, name='z_mean')(x)
z_log_var = Dense(latent_dim, name='z_log_var')(x)
z_log_var = z_log_var + 1e-8 

# reparameterization
z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var]) 

encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')
encoder.summary()

  VAE模型>解码器定义

# decoder 
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
x = Dense(shape[1] * shape[2] * shape[3], activation='relu')(latent_inputs)
x = Reshape((shape[1], shape[2], shape[3]))(x)

x = Conv2DTranspose(filters=filters,kernel_size=(22, 1),activation='relu',)(x)
x = BatchNormalization()(x)

filters = filters// 2
x = Conv2DTranspose(filters=filters,kernel_size=(1, 50),activation='relu',strides=(1,25))(x)
x = BatchNormalization()(x)

filters = filters// 2
outputs = Conv2DTranspose(filters=1,kernel_size=kernel_size,padding='same',name='decoder_output')(x)

decoder = Model(latent_inputs, outputs, name='decoder')
decoder.summary()
# VAE model (merging encoder and decoder)
outputs = decoder(encoder(inputs)[2])
vae = Model(inputs, outputs, name='vae')
vae.summary()

 定义损失函数

# defining Custom loss function 
reconstruction_loss = mse(K.flatten(inputs), K.flatten(outputs))

reconstruction_loss *= input_shape[0] * input_shape[1]
kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
kl_loss = K.sum(kl_loss, axis=-1)
kl_loss *= -0.5
vae_loss = K.mean(reconstruction_loss + kl_loss)
vae.add_loss(vae_loss)

#optimizer
optimizer = Adam(learning_rate=0.001, beta_1=0.5, beta_2=0.999)

# compiling vae
vae.compile(optimizer=optimizer, loss=None)
vae.summary()

 模型配置和训练

# early stopping callback
callbacks = EarlyStopping(monitor = 'val_loss',
                          mode='min',
                          patience =50,
                          verbose = 1,
                          restore_best_weights = True)


# fit vae model
history = vae.fit(X_train,X_train,
            epochs=epochs,
            batch_size=batch_size,
            validation_data=(X_test, X_test),callbacks=callbacks)

训练流程可视化

# loss curves
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('loss curves')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.show()

 中间隐空间特征2D可视化

# 2D plot of the classes in latent space
z_m, _, _ = encoder.predict(X_test,batch_size=batch_size)
plt.figure(figsize=(12, 10))
plt.scatter(z_m[:, 0], z_m[:, 1], c=X_test[:,0,0,0])
plt.xlabel("z[0]")
plt.ylabel("z[1]")
plt.show()

 数据合成

# predicting on validation data
pred=vae.predict(X_test)

代码获取

附文章底部;

相关项目开发,问题咨询,欢迎交流沟通。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/514401.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Layui】------ layui实现table表格拖拽行、列位置的示例代码

一、完整的示例代码&#xff1a;&#xff08;请使用layui v2.8.3的版本&#xff09;看懂就能用、不要照搬、照搬会出错误、拷贝重要代码改改符合你自己的需求。 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><…

Camtasia Studio2024汉化版下载(功能强大的屏幕录制和视频编辑软件)

Camtasia Studio 2024是一款功能强大的屏幕录制和视频编辑软件&#xff0c;由TechSmith公司开发。这款软件不仅能够帮助用户轻松地记录电脑屏幕上的任何操作&#xff0c;还可以将录制的视频进行专业的编辑和制作&#xff0c;最终输出高质量的视频教程、演示文稿、培训课程等。 …

如何在本地使用Ollama运行开源LLMs

本文将指导您下载并使用Ollama&#xff0c;在您的本地设备上与开源大型语言模型&#xff08;LLMs&#xff09;进行交互的强大工具。 与像ChatGPT这样的闭源模型不同&#xff0c;Ollama提供透明度和定制性&#xff0c;使其成为开发人员和爱好者的宝贵资源。 我们将探索如何下载…

【大数据存储】实验4 NoSQL数据库

实验4 NoSQL数据库 NoSQL数据库的安装和使用实验环境&#xff1a; Ubuntu 22.04.3 Jdk 1.8.0_341 Hadoop 3.2.3 Hbase 2.4.17 Redis 6.0.6 mongdb 6.0.12 mogosh 2.1.0 Redis 安装redis完成 新建终端启动redisredis-server新建一个终端redis-cli 建表操作 尝…

Vue项目中 安装及使用Sass(scss)

普通方法 一、安装使用scss 1. 安装 scss npm install scss --save2. 安装 node-sass 和 sass-loader sass-loader&#xff1a;把 sass编译成css node-sass&#xff1a;nodejs环境中将sass转css 提示&#xff1a;限制 node-sass&#xff0c;sass-loader 版本号&#xff0c;…

Jmeter02-1:参数化组件CVS

目录 1、Jmeter组件&#xff1a;参数化概述 1.1 是什么&#xff1f; 1.2 为什么&#xff1f; 1.3 怎么用&#xff1f; 2、Jmeter组件&#xff1a;参数化实现之CSV Data Set Config(重点中重点) 2.1 是什么&#xff1f; 2.2 为什么&#xff1f; 2.3 怎么用&#xff1f; …

高斯消元详解

算法概述 高斯消元法是一个用来求解线性方程组的算法 那么什么是线性方程组呢? 线性:每个未知数次数都为1次方程组:多个方程&#xff0c;多个未知数。 &#xff08;a1x1a2x2..anxnbn&#xff09;x为一次的 当x是平方的时候就不是线性 简而言之就是有多个未知数&#xff…

docker版Elasticsearch安装,ik分词器安装,用户名密码配置,kibana安装

1、安装es和ik分词器 创建映射目录并赋予权限&#xff1a; mkdir -p /docker_data/elasticsearch/conf mkdir -p /docker_data/elasticsearch/data mkdir -p /docker_data/elasticsearch/plugins chmod -R 777 /docker_data/elasticsearch编写配置文件&#xff1a; vi /dock…

水果销售(源码+文档)

水果销售管理系统&#xff08;小程序、ios、安卓都可部署&#xff09; 文件包含内容程序简要说明含有功能项目截图客户端添加地址首页商品详细意见反馈待发货商品分类我的代付款我的地址搜索防骗指南资料修改登录注册 后端管理分类管理反馈管理订单管理商品管理用户管理 文件包…

医疗器械5G智能制造工厂数字孪生可视化平台,推进行业数字化转型

医疗设备5G智能制造工厂数字孪生可视化平台&#xff0c;推进行业数字化转型。在数字化浪潮的推动下&#xff0c;医疗设备行业正迎来一场深刻的变革。5G技术的崛起&#xff0c;智能制造工厂的兴起&#xff0c;以及数字孪生可视化平台的出现&#xff0c;正在共同推动医疗设备行业…

C# WPF编程-命令

C# WPF编程-命令 概述WPF命令模型ICommand接口RoutedCommand类RoutedUICommand类命令库 概述 使用路由事件可以响应广泛的鼠标和键盘事件&#xff0c;这些事件是低级的元素。在实际应用程序中&#xff0c;功能被划分成一些高级的任务。这些任务可通过各种不同的动作和用户界面…

[StartingPoint][Tier0]Preignition

Task 1 Directory Brute-forcing is a technique used to check a lot of paths on a web server to find hidden pages. Which is another name for this? (i) Local File Inclusion, (ii) dir busting, (iii) hash cracking. (目录暴力破解是一种用于检查 Web 服务器上的大…

文献速递:深度学习胰腺癌诊断--螺旋变换与模型驱动的多模态深度学习方案相结合,用于自动预测胰腺癌中TP53突变麦田医学

Title 题目 Combined Spiral Transformation and Model-Driven Multi-Modal Deep Learning Scheme for Automatic Prediction of TP53 Mutation in Pancreatic Cancer 螺旋变换与模型驱动的多模态深度学习方案相结合&#xff0c;用于自动预测胰腺癌中TP53突变 01 文献速递介…

计算机视觉——图像金字塔理解与代码示例

图像金字塔 有时为了在图像中检测一个物体&#xff08;例如人脸、汽车或其他类似的物体&#xff09;&#xff0c;需要调整图像的大小或对图像进行子采样&#xff0c;并进行进一步的分析。在这种情况下&#xff0c;会保持一组具有不同分辨率的同一图像。称这种集合为图像金字塔…

【数据分析实战】印尼雅加达咖啡市场分析:品牌排名与市场趋势解读

目录 背景介绍数据展示数据分析可视化1. 各市咖啡店占比&#xff1a;1.1 可视化代码1.2 可视化结果1.3 浅薄解读 2. 品牌市场份额排名&#xff1a;2.1 可视化结果1.2 浅薄解读 3. 品牌消费者满意指数&#xff1a;3.1 可视化代码3.2 可视化结果3.3 浅薄解读 写在最后 背景介绍 …

03 Python进阶:MySQL - mysql-connector

mysql-connector安装 要在 Python 中使用 MySQL 数据库&#xff0c;你需要安装 MySQL 官方提供的 MySQL Connector/Python。下面是安装 MySQL Connector/Python 的步骤&#xff1a; 首先&#xff0c;确保你已经安装了 Python&#xff0c;如果没有安装&#xff0c;可以在 Python…

Flutter应用版本管理与更新策略:在苹果商店上架后的持续优化

引言 Flutter是一款由Google推出的跨平台移动应用开发框架&#xff0c;其强大的性能和流畅的用户体验使其备受开发者青睐。然而&#xff0c;开发一款应用只是第一步&#xff0c;将其成功上架到苹果商店才是实现商业目标的关键一步。本文将详细介绍如何使用Flutter将应用程序上…

Linux下Nginx详解

1、概念 1.1 介绍 Nginx是一个高性能的开源Web服务器软件&#xff0c;也可以用作反向代理服务器、负载均衡器和HTTP缓存。它的设计目标是高性能、稳定性、丰富的功能和低资源消耗。 1.2 应用场景 Web服务器&#xff1a;Nginx可以作为静态资源服务器&#xff0c;处理静态文件…

uniapp选择退出到指定页面

方法一&#xff1a;返回上n层页面 onUnload(){uni.navigateBack({delta:5,//返回上5层})},方法二&#xff1a;关闭当前页面&#xff0c;跳转到应用内的某个页面。 uni.redirectTo({url: "../home/index"//页面地址}) 方法三&#xff1a;关闭所有页面&#xff0c;打…

(2024)Ubuntu源码安装多个版本的opencv并切换使用

本人工作会用到x86_64的opencv和aarch64的opencv&#xff0c;所以写下来备忘自用 一、源码编译安装 依赖库安装&#xff1a; sudo apt-get install build-essential libgtk2.0-dev libgtk-3-dev libavcodec-dev libavformat-dev libjpeg-dev libswscale-dev libtiff5-dev o…