首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >基于LSTM的脑电信号分类体系结构

基于LSTM的脑电信号分类体系结构
EN

Stack Overflow用户
提问于 2021-01-20 21:54:09
回答 1查看 849关注 0票数 2

我有一个多类分类问题,我在python3.6中使用了keras & tensorflow。我基于本文中提到的“叠层 LSTM层(a)”实现了高精度的分类:深入学习人类思维实现视觉自动分类

有些事情是这样的:

代码语言:javascript
复制
model.add(LSTM(256,input_shape=(32, 15360), return_sequences=True))
model.add(LSTM(128), return_sequences=True)
model.add(LSTM(64), return_sequences=False)

model.add(Dense(6, activation='softmax'))

设32为脑电通道#,15360为信号长度为160 Hz的96秒记录。

我想实现上面提到的“通道LSTM和通用LSTM (b)”策略,但我不知道该如何通过这个新策略来建立我的模型。

请帮帮我。Thx

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-01-21 04:09:18

首先,在使用Common 实现编码器时遇到了问题,默认情况下,角膜LSTM层采用形状为(batch, timesteps, channel)的输入,因此如果设置input_shape=(32, 15360),则模型将读入为timesteps=32channel=15360,这与您的意图相反。

因为使用的第一层编码器通用LSTM描述为:

在每个时间步骤t中,第一层接受输入s(·,t)(从这个意义上说,“公共”意味着所有脑电通道最初都是fed8进入同一层)。

因此,使用通用LSTM实现编码器的正确实现应该是:

代码语言:javascript
复制
import tensorflow as tf
from tensorflow.keras import layers, models

timesteps = 15360
channels_num = 32

model = models.Sequential()
model.add(layers.LSTM(256,input_shape=(timesteps, channels_num), return_sequences=True))
model.add(layers.LSTM(128, return_sequences=True))
model.add(layers.LSTM(64, return_sequences=False))
model.add(layers.Dense(6, activation='softmax'))

model.summary()

哪个输出(PS:您可以总结您最初的实现,然后您将看到Total params要大得多):

代码语言:javascript
复制
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
lstm (LSTM)                  (None, 15360, 256)        295936
_________________________________________________________________
lstm_1 (LSTM)                (None, 15360, 128)        197120
_________________________________________________________________
lstm_2 (LSTM)                (None, 64)                49408
_________________________________________________________________
dense (Dense)                (None, 6)                 390
=================================================================
Total params: 542,854
Trainable params: 542,854
Non-trainable params: 0
_________________________________________________________________

第二,因为使用信道LSTM和通用LSTM的编码器被描述为:

第一编码层由几个LSTM组成,每个LSTM仅连接到一个输入信道:例如,第一LSTM处理输入数据(1,·),第二LSTM进程(2,·)等等。这样,每个“通道LSTM”的输出就是单个通道数据的汇总。然后,第二编码层通过接收所有信道LSTM的级联输出向量作为输入来执行信道间分析。如上所述,在最后一步使用最深的LSTM输出作为编码器的输出矢量。

由于第一层中的每个LSTM只处理一个信道,所以我们需要在第一层中使用等于信道数的LSTM数,下面的代码将演示如何使用信道LSTM和通用LSTM构建一个编码器。

代码语言:javascript
复制
import tensorflow as tf
from tensorflow.keras import layers, models

timesteps = 15360
channels_num = 32

first_layer_inputs = []
second_layer_inputs = []
for i in range(channels_num):
    l_input = layers.Input(shape=(timesteps, 1))
    first_layer_inputs.append(l_input)
    l_output = layers.LSTM(1, return_sequences=True)(l_input)
    second_layer_inputs.append(l_output)

x = layers.Concatenate()(second_layer_inputs)
x = layers.LSTM(128, return_sequences=True)(x)
x = layers.LSTM(64, return_sequences=False)(x)
outputs = layers.Dense(6, activation='softmax')(x)

model = models.Model(inputs=first_layer_inputs, outputs=outputs)

model.summary()

产出:

代码语言:javascript
复制
Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_3 (InputLayer)            [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_4 (InputLayer)            [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_5 (InputLayer)            [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_6 (InputLayer)            [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_7 (InputLayer)            [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_8 (InputLayer)            [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_9 (InputLayer)            [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_10 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_11 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_12 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_13 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_14 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_15 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_16 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_17 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_18 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_19 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_20 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_21 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_22 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_23 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_24 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_25 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_26 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_27 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_28 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_29 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_30 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_31 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_32 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
lstm (LSTM)                     (None, 15360, 1)     12          input_1[0][0]
__________________________________________________________________________________________________
lstm_1 (LSTM)                   (None, 15360, 1)     12          input_2[0][0]
__________________________________________________________________________________________________
lstm_2 (LSTM)                   (None, 15360, 1)     12          input_3[0][0]
__________________________________________________________________________________________________
lstm_3 (LSTM)                   (None, 15360, 1)     12          input_4[0][0]
__________________________________________________________________________________________________
lstm_4 (LSTM)                   (None, 15360, 1)     12          input_5[0][0]
__________________________________________________________________________________________________
lstm_5 (LSTM)                   (None, 15360, 1)     12          input_6[0][0]
__________________________________________________________________________________________________
lstm_6 (LSTM)                   (None, 15360, 1)     12          input_7[0][0]
__________________________________________________________________________________________________
lstm_7 (LSTM)                   (None, 15360, 1)     12          input_8[0][0]
__________________________________________________________________________________________________
lstm_8 (LSTM)                   (None, 15360, 1)     12          input_9[0][0]
__________________________________________________________________________________________________
lstm_9 (LSTM)                   (None, 15360, 1)     12          input_10[0][0]
__________________________________________________________________________________________________
lstm_10 (LSTM)                  (None, 15360, 1)     12          input_11[0][0]
__________________________________________________________________________________________________
lstm_11 (LSTM)                  (None, 15360, 1)     12          input_12[0][0]
__________________________________________________________________________________________________
lstm_12 (LSTM)                  (None, 15360, 1)     12          input_13[0][0]
__________________________________________________________________________________________________
lstm_13 (LSTM)                  (None, 15360, 1)     12          input_14[0][0]
__________________________________________________________________________________________________
lstm_14 (LSTM)                  (None, 15360, 1)     12          input_15[0][0]
__________________________________________________________________________________________________
lstm_15 (LSTM)                  (None, 15360, 1)     12          input_16[0][0]
__________________________________________________________________________________________________
lstm_16 (LSTM)                  (None, 15360, 1)     12          input_17[0][0]
__________________________________________________________________________________________________
lstm_17 (LSTM)                  (None, 15360, 1)     12          input_18[0][0]
__________________________________________________________________________________________________
lstm_18 (LSTM)                  (None, 15360, 1)     12          input_19[0][0]
__________________________________________________________________________________________________
lstm_19 (LSTM)                  (None, 15360, 1)     12          input_20[0][0]
__________________________________________________________________________________________________
lstm_20 (LSTM)                  (None, 15360, 1)     12          input_21[0][0]
__________________________________________________________________________________________________
lstm_21 (LSTM)                  (None, 15360, 1)     12          input_22[0][0]
__________________________________________________________________________________________________
lstm_22 (LSTM)                  (None, 15360, 1)     12          input_23[0][0]
__________________________________________________________________________________________________
lstm_23 (LSTM)                  (None, 15360, 1)     12          input_24[0][0]
__________________________________________________________________________________________________
lstm_24 (LSTM)                  (None, 15360, 1)     12          input_25[0][0]
__________________________________________________________________________________________________
lstm_25 (LSTM)                  (None, 15360, 1)     12          input_26[0][0]
__________________________________________________________________________________________________
lstm_26 (LSTM)                  (None, 15360, 1)     12          input_27[0][0]
__________________________________________________________________________________________________
lstm_27 (LSTM)                  (None, 15360, 1)     12          input_28[0][0]
__________________________________________________________________________________________________
lstm_28 (LSTM)                  (None, 15360, 1)     12          input_29[0][0]
__________________________________________________________________________________________________
lstm_29 (LSTM)                  (None, 15360, 1)     12          input_30[0][0]
__________________________________________________________________________________________________
lstm_30 (LSTM)                  (None, 15360, 1)     12          input_31[0][0]
__________________________________________________________________________________________________
lstm_31 (LSTM)                  (None, 15360, 1)     12          input_32[0][0]
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 15360, 32)    0           lstm[0][0]
                                                                 lstm_1[0][0]
                                                                 lstm_2[0][0]
                                                                 lstm_3[0][0]
                                                                 lstm_4[0][0]
                                                                 lstm_5[0][0]
                                                                 lstm_6[0][0]
                                                                 lstm_7[0][0]
                                                                 lstm_8[0][0]
                                                                 lstm_9[0][0]
                                                                 lstm_10[0][0]
                                                                 lstm_11[0][0]
                                                                 lstm_12[0][0]
                                                                 lstm_13[0][0]
                                                                 lstm_14[0][0]
                                                                 lstm_15[0][0]
                                                                 lstm_16[0][0]
                                                                 lstm_17[0][0]
                                                                 lstm_18[0][0]
                                                                 lstm_19[0][0]
                                                                 lstm_20[0][0]
                                                                 lstm_21[0][0]
                                                                 lstm_22[0][0]
                                                                 lstm_23[0][0]
                                                                 lstm_24[0][0]
                                                                 lstm_25[0][0]
                                                                 lstm_26[0][0]
                                                                 lstm_27[0][0]
                                                                 lstm_28[0][0]
                                                                 lstm_29[0][0]
                                                                 lstm_30[0][0]
                                                                 lstm_31[0][0]
__________________________________________________________________________________________________
lstm_32 (LSTM)                  (None, 15360, 128)   82432       concatenate[0][0]
__________________________________________________________________________________________________
lstm_33 (LSTM)                  (None, 64)           49408       lstm_32[0][0]
__________________________________________________________________________________________________
dense (Dense)                   (None, 6)            390         lstm_33[0][0]
==================================================================================================
Total params: 132,614
Trainable params: 132,614
Non-trainable params: 0
__________________________________________________________________________________________________

现在,由于模型期望输入形状为(channel, batch, timesteps, 1),所以在输入到模型之前,我们必须重新排序数据集的轴,下面的示例代码将向您展示如何重新排序从(batch, timesteps, channel)(channel, batch, timesteps, 1)的轴。

代码语言:javascript
复制
import numpy as np

batch_size = 64
timesteps = 15360
channels_num = 32

x = np.random.rand(batch_size, timesteps, channels_num)
print(x.shape)
x = np.moveaxis(x, -1, 0)[..., np.newaxis]
print(x.shape)
x = [i for i in x]
print(x[0].shape)

产出:

代码语言:javascript
复制
(64, 15360, 32)
(32, 64, 15360, 1)
(64, 15360, 1)
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/65818241

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档