configs

configs 部分

```python
import os  # 导入os模块,用于系统级操作

emotion = ["Valence"]  # 定义情绪列表,只包含情绪维度"Valence"

# 配置参数字典
config = {
    "extract_class_label": 1,  # 是否提取类别标签
    "extract_continuous_label": 1,  # 是否提取连续标签
    "extract_eeg": 1,  # 是否提取EEG数据
    "eeg_folder": "eeg",  # 存放EEG数据的文件夹名称
    "eeg_config": {  # EEG数据处理的详细配置
        "sampling_frequency": 256,  # 采样频率
        "window_sec": 2,  # 窗口长度(秒)
        "hop_sec": 0.25,  # 跳跃长度(秒)
        "buffer_sec": 5,  # 缓冲区长度(秒)
        "num_electrodes": 32,  # 电极数量
        "interest_bands": [(0.3, 4), (4, 8), (8, 12), (12, 18), (18, 30), (30, 45)],  # 感兴趣频段
        "f_trans_interest_bands": [(0.1, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 2)],  # 感兴趣频段的过渡频率
        "channel_slice": {'eeg': slice(0, 32), 'ecg': slice(32, 35), 'misc': slice(35, -1)},  # 通道切片
        "features": ["eeg_bandpower"],  # 特征
        "filter_type": 'cheby2',  # 滤波器类型
        "filter_order": 4  # 滤波器阶数
    },
    "save_npy": 1,  # 是否保存为.npy格式的数据
    "npy_folder": "compacted_48",  # 存放.npy数据的文件夹名称
    "dataset_name": "mahnob",  # 数据集的名称
    "emotion_list": emotion,  # 情绪列表
    "root_directory": r"D:\DingYi\Dataset\MAHNOB-O",  # 原始数据集的根目录路径
    "output_root_directory": r"D:\DingYi\Dataset\MAHNOB-P-R",  # 处理后数据的输出根目录路径
    "raw_data_folder": "Sessions",  # 原始数据存放的文件夹名称
    "multiplier": {  # 不同数据类型的倍增因子
        "video": 16,
        "eeg_raw": 1,
        "eeg_bandpower": 1,
        "eeg_DE": 1,
        "eeg_RP": 1,
        "eeg_Hjorth": 1,
        "continuous_label": 1
    },
    "feature_dimension": {  # 不同特征的维度信息
        "eeg_raw": (16384,),
        "eeg_bandpower": (192,),
        "eeg_DE": (192,),
        "eeg_RP": (192,),
        "eeg_Hjorth": (96,),
        "continuous_label": (1,),
        "class_label": (1,)
    },
    "max_epoch": 15,  # 最大的训练周期数
    "min_epoch": 0,  # 最小的训练周期数
    "model_name": "2d1d",  # 模型的名称
    "backbone": {  # 模型的骨干网络配置
        "state_dict": "res50_ir_0.887",
        "mode": "ir"
    },
    "early_stopping": 10,  # 提前停止训练的步数
    "load_best_at_each_epoch": 1,  # 是否在每个周期加载最佳模型
    "time_delay": 0,  # 时间延迟
    "metrics": ["rmse", "pcc", "ccc"],  # 评估指标
    "save_plot": 0  # 是否保存图形结果
}
```

这段代码是一个Python字典,包含了各种配置参数,用于处理和分析一个名为MAHNOB的数据集,主要用于情绪识别研究。以下是每行代码的解释:

1. `import os`: 导入Python的os模块,用于操作文件路径等系统级操作。

2. `emotion = ["Valence"]`: 定义一个情绪列表,只包含情绪维度"Valence"。

3. `config = { ... }`: 定义一个名为config的字典,包含了各种配置参数。

4. `"extract_class_label": 1`: 是否提取类别标签,这里设为1表示是。

5. `"extract_continuous_label": 1`: 是否提取连续标签,这里设为1表示是。

6. `"extract_eeg": 1`: 是否提取EEG数据,这里设为1表示是。

7. `"eeg_folder": "eeg"`: 存放EEG数据的文件夹名称。

8. `"eeg_config": { ... }`: EEG数据处理的详细配置,包括采样频率、窗口长度、跳跃长度、通道数量等参数。

9. `"save_npy": 1`: 是否保存为.npy格式的数据,这里设为1表示是。

10. `"npy_folder": "compacted_48"`: 存放.npy数据的文件夹名称。

11. `"dataset_name": "mahnob"`: 数据集的名称。

12. `"emotion_list": emotion`: 情绪列表,使用了之前定义的emotion变量。

13. `"root_directory": r"D:\DingYi\Dataset\MAHNOB-O"`: 原始数据集的根目录路径。

14. `"output_root_directory": r"D:\DingYi\Dataset\MAHNOB-P-R"`: 处理后数据的输出根目录路径。

15. `"raw_data_folder": "Sessions"`: 原始数据存放的文件夹名称。

16. `"multiplier": { ... }`: 不同数据类型的倍增因子,用于数据增强或者调整数据量。

17. `"feature_dimension": { ... }`: 不同特征的维度信息,用于数据处理和模型输入。

18. `"max_epoch": 15`: 最大的训练周期数。

19. `"min_epoch": 0`: 最小的训练周期数。

20. `"model_name": "2d1d"`: 模型的名称,这里只是命名用途,实际上没有使用。

21. `"backbone": { ... }`: 模型的骨干网络配置,包括状态字典和模式。

22. `"early_stopping": 10`: 提前停止训练的步数。

23. `"load_best_at_each_epoch": 1`: 是否在每个周期加载最佳模型。

24. `"time_delay": 0`: 时间延迟,用于连续标签在数据点中的移动。

25. `"metrics": ["rmse", "pcc", "ccc"]`: 评估指标,包括均方根误差、皮尔逊相关系数和一致性相关系数。

26. `"save_plot": 0`: 是否保存图形结果,这里设为0表示否。

这些配置参数用于设置数据预处理、模型训练和评估过程中的各种选项和参数,确保流程能够顺利进行和有效执行。

from base.preprocessing import GenericDataPreprocessing  # 导入基础数据预处理类
from base.utils import expand_index_by_multiplier, load_pickle, save_to_pickle, get_filename_from_a_folder_given_extension, ensure_dir  # 导入一些辅助函数和工具
from base.label_config import *  # 导入标签配置

import os  # 导入os模块,用于系统级操作
import scipy.io as sio  # 导入scipy.io模块,用于读取.mat文件

import pandas as pd  # 导入pandas库,用于数据处理和分析
import numpy as np  # 导入numpy库,用于数值计算

import xml.etree.ElementTree as et  # 导入xml.etree.ElementTree模块,用于解析XML文件

generate_dataset.py

from base.preprocessing import GenericDataPreprocessing  # 导入基础数据预处理类
from base.utils import expand_index_by_multiplier, load_pickle, save_to_pickle, get_filename_from_a_folder_given_extension, ensure_dir  # 导入一些辅助函数和工具
from base.label_config import *  # 导入标签配置

import os  # 导入os模块,用于系统级操作
import scipy.io as sio  # 导入scipy.io模块,用于读取.mat文件

import pandas as pd  # 导入pandas库,用于数据处理和分析
import numpy as np  # 导入numpy库,用于数值计算

import xml.etree.ElementTree as et  # 导入xml.etree.ElementTree模块,用于解析XML文件


class Preprocessing(GenericDataPreprocessing):
    def __init__(self, config):
        super().__init__(config)

    def generate_iterator(self):
        # 生成迭代器,返回按照文件名排序的文件路径列表
        path = os.path.join(self.config['root_directory'], self.config['raw_data_folder'])
        iterator = [os.path.join(path, file) for file in sorted(os.listdir(path), key=float)]
        return iterator

    def generate_per_trial_info_dict(self):
        # 生成每个试验的信息字典
        per_trial_info_path = os.path.join(self.config['output_root_directory'], "processing_records.pkl")
        if os.path.isfile(per_trial_info_path):
            per_trial_info = load_pickle(per_trial_info_path)
        else:
            per_trial_info = {}
            pointer = 0

            sub_trial_having_continuous_label = self.get_sub_trial_info_for_continuously_labeled()
            all_continuous_labels = self.read_all_continuous_label()

            iterator = self.generate_iterator()

            for idx, file in enumerate(iterator):
                kwargs = {}
                this_trial = {}
                print(file)

                time_stamp_file = get_filename_from_a_folder_given_extension(file, "tsv", "All-Data")[0]
                video_trim_range = self.read_start_end_from_mahnob_tsv(time_stamp_file)
                if video_trim_range is not None:
                    this_trial['video_trim_range'] = self.read_start_end_from_mahnob_tsv(time_stamp_file)
                else:
                    this_trial['discard'] = 1
                    continue

                this_trial['has_continuous_label'] = 0
                session = int(file.split(os.sep)[-1])
                subject_no, trial_no = session // 130 + 1, session % 130

                if subject_no == sub_trial_having_continuous_label[pointer][0] and trial_no == sub_trial_having_continuous_label[pointer][1]:
                    this_trial['has_continuous_label'] = 1

                this_trial['continuous_label'] = None
                this_trial['annotated_index'] = None
                annotated_index = np.arange(this_trial['video_trim_range'][0][1])
                if this_trial['has_continuous_label']:
                    raw_continuous_label = all_continuous_labels[pointer]
                    this_trial['continuous_label'] = raw_continuous_label
                    annotated_index = self.process_continuous_label(raw_continuous_label)
                    this_trial['annotated_index'] = annotated_index
                    pointer += 1

                this_trial['has_eeg'] = 1
                eeg_path = get_filename_from_a_folder_given_extension(file, "bdf")
                if len(eeg_path) == 1:
                    this_trial['eeg_path'] = eeg_path[0].split(os.sep)
                else:
                    this_trial['eeg_path'] = None
                    this_trial['has_eeg'] = 0

                this_trial['audio_path'] = ""

                this_trial['subject_no'] = subject_no
                this_trial['trial_no'] = trial_no
                this_trial['trial'] = "P{}-T{}".format(str(subject_no), str(trial_no))

                this_trial['target_fps'] = 64

                kwargs['feature'] = "video"
                kwargs['has_continuous_label'] = this_trial['has_continuous_label']
                this_trial['video_annotated_index'] = self.get_annotated_index(annotated_index, **kwargs)

                this_trial['class_label'] = get_filename_from_a_folder_given_extension(file, "xml")[0]
                per_trial_info[idx] = this_trial

        ensure_dir(per_trial_info_path)
        save_to_pickle(per_trial_info_path, per_trial_info)
        self.per_trial_info = per_trial_info

    def generate_dataset_info(self):
        # 生成数据集信息
        class_label = {}
        for idx, record in self.per_trial_info.items():
            self.dataset_info['trial'].append(record['processing_record']['trial'])
            self.dataset_info['trial_no'].append(record['trial_no'])
            self.dataset_info['subject_no'].append(record['subject_no'])
            self.dataset_info['has_continuous_label'].append(record['has_continuous_label'])
            self.dataset_info['has_eeg'].append(record['has_eeg'])

            if record['has_continuous_label']:
                self.dataset_info['length'].append(len(record['continuous_label']))
            else:
                self.dataset_info['length'].append(len(record['video_annotated_index']) // 16)

            if self.config['extract_class_label']:
                class_label.update({record['processing_record']['trial']: self.extract_class_label_fn(record)})

        self.dataset_info['multiplier'] = self.config['multiplier']
        self.dataset_info['data_folder'] = self.config['npy_folder']

        path = os.path.join(self.config['output_root_directory'], 'dataset_info.pkl')
        save_to_pickle(path, self.dataset_info)

        if self.config['extract_class_label']:
            path = os.path.join(self.config['output_root_directory'], 'class_label.pkl')
            save_to_pickle(path, class_label)

    def extract_class_label_fn(self, record):
        # 提取类别标签
        class_label = {}
        if record['has_eeg']:
            xml_file = et.parse(record['class_label']).getroot()
            felt_emotion = xml_file.find('.').attrib['feltEmo']
            felt_arousal = xml_file.find('.').attrib['feltArsl']
            felt_valence = xml_file.find('.').attrib['feltVlnc']

            arousal = 0 if float(felt_arousal) <= 5 else 1
            valence = 0 if float(felt_valence) <= 5 else 1

            class_label = {
                "Arousal": arousal,
                "Valence": valence,
                "Arousal_3cls": arousal_class_to_number[emotion_tag_to_arousal_class[number_to_emotion_tag_dict[felt_emotion]]],
                "Valence_3cls": valence_class_to_number[emotion_tag_to_valence_class[number_to_emotion_tag_dict[felt_emotion]]]
            }

        return class_label

    def extract_continuous_label_fn(self, idx, npy_folder):
        # 提取连续标签
        if self.per_trial_info[idx]["has_continuous_label"]:
            raw_continuous_label = self.per_trial_info[idx]['continuous_label']

            if self.config['save_npy']:
                filename = os.path.join(npy_folder, "continuous_label.npy")
                if not os.path.isfile(filename):
                    ensure_dir(filename)
                    np.save(filename, raw_continuous_label)

    def load_continuous_label(self, path, **kwargs):
        # 加载连续标签
        cols = [emotion.lower() for emotion in self.config['emotion_list']]

        if os.path.isfile(path):
            continuous_label = pd.read_csv(path, sep=";",
                                           skipinitialspace=True, usecols=cols,
                                           index_col=False).values.squeeze()
        else:
            continuous_label = 0

        return continuous_label

    def get_annotated_index(self, annotated_index, **kwargs):
        # 获取标注索引
        feature = kwargs['feature']
        multiplier = self.config['multiplier'][feature]

        if kwargs['has_continuous_label']:
            annotated_index = expand_index_by_multiplier(annotated_index, multiplier)
        else:
            pass

        return annotated_index

    def get_sub_trial_info_for_continuously_labeled(self):
        # 获取连续标签的子试验信息
        label_file = os.path.join(self.config['root_directory'], "lable_continous_Mahnob.mat")
        mat_content = sio.loadmat(label_file)
        sub_trial_having_continuous_label = mat_content['trials_included']

        return sub_trial_having_continuous_label

    @staticmethod
    def read_start_end_from_mahnob_tsv(tsv_file):
        # 从Mahnob的tsv文件中读取起始和结束时间
        if os.path.isfile(tsv_file):
            data = pd.read_csv(tsv_file, sep='\t', skiprows=23)
            end = data[data['Event'] == 'MovieEnd'].index[0]
            start_end = [(0, end)]
        else:
            start_end = None
        return start_end

    def read_all_continuous_label(self):
        # 读取所有连续标签
        label_file = os.path.join(self.config['root_directory'], "lable_continous_Mahnob.mat")
        mat_content = sio.loadmat(label_file)
        annotation_cell = np.squeeze(mat_content['labels'])

        label_list = []
        for index in range(len(annotation_cell)):
            label_list.append(annotation_cell[index].T)
        return label_list

    @staticmethod
    def init_dataset_info():
        # 初始化数据集信息
        dataset_info = {
            "trial": [],
            "subject_no": [],
            "trial_no": [],
            "length": [],
            "has_continuous_label": [],
            "has_eeg": [],
        }
        return dataset_info


if __name__ == "__main__":
    from configs import config

    pre = Preprocessing(config)
    pre.generate_per_trial_info_dict()
    pre.prepare_data()

这段代码定义了一个名为Preprocessing的类,继承自GenericDataPreprocessing类,用于数据预处理。它包含了一些方法和函数,用于生成每个试验的信息字典、生成数据集信息、提取类别标签、提取连续标签等操作。在if __name__ == "__main__":部分,创建了Preprocessing对象,并调用了相关方法进行数据预处理。

main.py

from base.preprocessing import GenericDataPreprocessing  # 导入自定义的GenericDataPreprocessing类
from base.utils import expand_index_by_multiplier, load_pickle, save_to_pickle, get_filename_from_a_folder_given_extension, ensure_dir  # 导入一些辅助函数和工具
from base.label_config import *  # 导入标签配置

import os  # 导入os模块,用于文件和目录操作
import scipy.io as sio  # 导入scipy.io模块,用于读取MATLAB文件

import pandas as pd  # 导入pandas库,用于数据处理
import numpy as np  # 导入numpy库,用于数值计算

import xml.etree.ElementTree as et  # 导入xml.etree.ElementTree模块,用于解析XML文件


class Preprocessing(GenericDataPreprocessing):
    def __init__(self, config):
        super().__init__(config)

    def generate_iterator(self):
        path = os.path.join(self.config['root_directory'], self.config['raw_data_folder'])
        iterator = [os.path.join(path, file) for file in sorted(os.listdir(path), key=float)]
        return iterator

    def generate_per_trial_info_dict(self):
        # 生成每个试验的信息字典

        per_trial_info_path = os.path.join(self.config['output_root_directory'], "processing_records.pkl")
        if os.path.isfile(per_trial_info_path):
            per_trial_info = load_pickle(per_trial_info_path)
        else:
            per_trial_info = {}
            pointer = 0

            sub_trial_having_continuous_label = self.get_sub_trial_info_for_continuously_labeled()
            all_continuous_labels = self.read_all_continuous_label()

            iterator = self.generate_iterator()

            for idx, file in enumerate(iterator):
                kwargs = {}
                this_trial = {}
                print(file)

                time_stamp_file = get_filename_from_a_folder_given_extension(file, "tsv", "All-Data")[0]
                video_trim_range = self.read_start_end_from_mahnob_tsv(time_stamp_file)
                if video_trim_range is not None:
                    this_trial['video_trim_range'] = self.read_start_end_from_mahnob_tsv(time_stamp_file)
                else:
                    this_trial['discard'] = 1
                    continue

                this_trial['has_continuous_label'] = 0
                session = int(file.split(os.sep)[-1])
                subject_no, trial_no = session // 130 + 1, session % 130

                if subject_no == sub_trial_having_continuous_label[pointer][0] and trial_no == sub_trial_having_continuous_label[pointer][1]:
                    this_trial['has_continuous_label'] = 1

                this_trial['continuous_label'] = None
                this_trial['annotated_index'] = None
                annotated_index = np.arange(this_trial['video_trim_range'][0][1])
                if this_trial['has_continuous_label']:
                    raw_continuous_label = all_continuous_labels[pointer]
                    this_trial['continuous_label'] = raw_continuous_label
                    annotated_index = self.process_continuous_label(raw_continuous_label)
                    this_trial['annotated_index'] = annotated_index
                    pointer += 1

                this_trial['has_eeg'] =  1
                eeg_path = get_filename_from_a_folder_given_extension(file, "bdf")
                if len(eeg_path) == 1:
                    this_trial['eeg_path'] = eeg_path[0].split(os.sep)
                else:
                    this_trial['eeg_path'] = None
                    this_trial['has_eeg'] = 0

                this_trial['audio_path'] = ""

                this_trial['subject_no'] = subject_no
                this_trial['trial_no'] = trial_no
                this_trial['trial'] = "P{}-T{}".format(str(subject_no), str(trial_no))

                this_trial['target_fps'] = 64

                kwargs['feature'] = "video"
                kwargs['has_continuous_label'] = this_trial['has_continuous_label']
                this_trial['video_annotated_index'] = self.get_annotated_index(annotated_index, **kwargs)

                this_trial['class_label'] = get_filename_from_a_folder_given_extension(file, "xml")[0]
                per_trial_info[idx] = this_trial

        ensure_dir(per_trial_info_path)
        save_to_pickle(per_trial_info_path, per_trial_info)
        self.per_trial_info = per_trial_info

    def generate_dataset_info(self):
        # 生成数据集信息

        class_label = {}
        for idx, record in self.per_trial_info.items():
            self.dataset_info['trial'].append(record['processing_record']['trial'])
            self.dataset_info['trial_no'].append(record['trial_no'])
            self.dataset_info['subject_no'].append(record['subject_no'])
            self.dataset_info['has_continuous_label'].append(record['has_continuous_label'])
            self.dataset_info['has_eeg'].append(record['has_eeg'])

            if record['has_continuous_label']:
                self.dataset_info['length'].append(len(record['continuous_label']))
            else:
                self.dataset_info['length'].append(len(record['video_annotated_index']) // 16)

            if self.config['extract_class_label']:
                class_label.update({record['processing_record']['trial']: self.extract_class_label_fn(record)})

        self.dataset_info['multiplier'] = self.config['multiplier']
        self.dataset_info['data_folder'] = self.config['npy_folder']

        path = os.path.join(self.config['output_root_directory'], 'dataset_info.pkl')
        save_to_pickle(path, self.dataset_info)

        if self.config['extract_class_label']:
            path = os.path.join(self.config['output_root_directory'], 'class_label.pkl')
            save_to_pickle(path, class_label)

    def extract_class_label_fn(self, record):
        # 提取类别标签的函数

        class_label = {}
        if record['has_eeg']:
            xml_file = et.parse(record['class_label']).getroot()
            felt_emotion = xml_file.find('.').attrib['feltEmo']
            felt_arousal = xml_file.find('.').attrib['feltArsl']
            felt_valence = xml_file.find('.').attrib['feltVlnc']

            arousal = 0 if float(felt_arousal) <= 5 else 1
            valence = 0 if float(felt_valence) <= 5 else 1

            class_label = {
                "Arousal": arousal,
                "Valence": valence,
                "Arousal_3cls": arousal_class_to_number[emotion_tag_to_arousal_class[number_to_emotion_tag_dict[felt_emotion]]],
                "Valence_3cls": valence_class_to_number[emotion_tag_to_valence_class[number_to_emotion_tag_dict[felt_emotion]]]
            }

        return class_label

    def extract_continuous_label_fn(self, idx, npy_folder):
        # 提取连续标签的函数

        if self.per_trial_info[idx]["has_continuous_label"]:
            raw_continuous_label = self.per_trial_info[idx]['continuous_label']

            if self.config['save_npy']:
                filename = os.path.join(npy_folder, "continuous_label.npy")
                if not os.path.isfile(filename):
                    ensure_dir(filename)
                    np.save(filename, raw_continuous_label)

    def load_continuous_label(self, path, **kwargs):
        # 加载连续标签

        cols = [emotion.lower() for emotion in self.config['emotion_list']]

        if os.path.isfile(path):
            continuous_label = pd.read_csv(path, sep=";",
                                           skipinitialspace=True, usecols=cols,
                                           index_col=False).values.squeeze()
        else:
            continuous_label = 0

        return continuous_label

    def get_annotated_index(self, annotated_index, **kwargs):
        # 获取标注索引

        feature = kwargs['feature']
        multiplier = self.config['multiplier'][feature]

        if kwargs['has_continuous_label']:
            annotated_index = expand_index_by_multiplier(annotated_index, multiplier)
        else:
            pass

        return annotated_index

    def get_sub_trial_info_for_continuously_labeled(self):
        # 获取具有连续标签的子试验信息

        label_file = os.path.join(self.config['root_directory'], "lable_continous_Mahnob.mat")
        mat_content = sio.loadmat(label_file)
        sub_trial_having_continuous_label = mat_content['trials_included']

        return sub_trial_having_continuous_label

    @staticmethod
    def read_start_end_from_mahnob_tsv(tsv_file):
        # 从Mahnob TSV文件中读取起始和结束时间

        if os.path.isfile(tsv_file):
            data = pd.read_csv(tsv_file, sep='\t', skiprows=23)
            end = data[data['Event'] == 'MovieEnd'].index[0]
            start_end = [(0, end)]
        else:
            start_end = None
        return start_end

    def read_all_continuous_label(self):
        # 读取所有连续标签

        label_file = os.path.join(self.config['root_directory'], "lable_continous_Mahnob.mat")
        mat_content = sio.loadmat(label_file)
        annotation_cell = np.squeeze(mat_content['labels'])

        label_list = []
        for index in range(len(annotation_cell)):
            label_list.append(annotation_cell[index].T)
        return label_list

    @staticmethod
    def init_dataset_info():
        # 初始化数据集信息

        dataset_info = {
            "trial": [],
            "subject_no": [],
            "trial_no": [],
            "length": [],
            "has_continuous_label": [],
            "has_eeg": [],
        }
        return dataset_info


if __name__ == "__main__":
    from configs import config  # 导入配置文件

    pre = Preprocessing(config)  # 创建Preprocessing对象,传入配置文件
    pre.generate_per_trial_info_dict()  # 生成每个试验的信息字典
    pre.prepare_data()  # 准备数据
这段代码是一个数据预处理的类Preprocessing,继承自GenericDataPreprocessing。它包含了一些方法用于生成每个试验的信息字典、生成数据集信息、提取类别标签和连续标签等操作。在__main__函数中,创建了一个Preprocessing对象,并调用了相关方法进行数据预处理。

加入其他数据集中

```python
from model import MASA_TCN  # 从model模块中导入MASA_TCN模型

data = torch.randn(1, 1, 192, 96)  # 生成一个随机张量作为输入数据,形状为(batch_size=1, cnn_channel=1, EEG_channel*feature=32*6, data_sequence=96)

# 对于回归任务,输出形状为(batch_size, data_sequence, 1)。
net = MASA_TCN(
        cnn1d_channels=[128, 128, 128],  # 1维卷积层的通道数列表
        cnn1d_kernel_size=[3, 5, 15],  # 1维卷积层的核大小列表
        cnn1d_dropout_rate=0.1,  # 1维卷积层的dropout率
        num_eeg_chan=32,  # EEG通道数
        freq=6,  # 特征频率
        output_dim=1,  # 输出维度
        early_fusion=True,  # 是否使用早期融合
        model_type='reg')  # 模型类型为回归
preds = net(data)  # 对输入数据进行预测

# 对于分类任务,输出形状为(batch_size, num_classes)。注意:output_dim应该是类别的数量。
net = MASA_TCN(
        cnn1d_channels=[128, 128, 128],  # 1维卷积层的通道数列表
        cnn1d_kernel_size=[3, 5, 15],  # 1维卷积层的核大小列表
        cnn1d_dropout_rate=0.1,  # 1维卷积层的dropout率
        num_eeg_chan=32,  # EEG通道数
        freq=6,  # 特征频率
        output_dim=2,  # 输出维度
        early_fusion=True,  # 是否使用早期融合
        model_type='cls')  # 模型类型为分类
preds = net(data)  # 对输入数据进行预测
```

这段代码首先导入了MASA_TCN模型,然后创建了一个随机输入数据,并使用MASA_TCN模型进行了回归和分类任务的预测。注释已经在代码中添加了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/28772.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

鸿蒙开发:【线程模型】

线程模型 线程类型 Stage模型下的线程主要有如下三类&#xff1a; 主线程 执行UI绘制。管理主线程的ArkTS引擎实例&#xff0c;使多个UIAbility组件能够运行在其之上。管理其他线程的ArkTS引擎实例&#xff0c;例如使用TaskPool&#xff08;任务池&#xff09;创建任务或取消…

数据持久层与 SQL 注入

使用 ORM&#xff08;Object/Relation Mapping&#xff09;框架对 SQL 注入是有积极意义的。我们知道对抗 SQL 注入的最佳方式就是使用“预编译绑定变量”。在实际解决 SQL 注入时&#xff0c;还有一个难点就 是应用复杂后&#xff0c;代码数量庞大&#xff0c;难以把可能存在 …

示例:WPF中应用MarkupExtention自定义IValueConverter

一、目的&#xff1a;应用MarkupExtention定义IValueConverter&#xff0c;使得应用起来更简单和高效 二、实现 public abstract class MarkupValueConverterBase : MarkupExtension, IValueConverter{public abstract object Convert(object value, Type targetType, object …

数字化转型对企业有什么价值?

数字化转型对企业有什么价值&#xff1f; 1. 信息共享 很多业务设计和管理规划&#xff0c;通常需要综合多个业务部门和业务专题的数据。 如果企业的数据和信息在位置分布上非常分散&#xff0c;就很难充分利用企业积累的数据资源&#xff0c;并将其用于有效的管理决策和业务…

《java 编程基础》试题

2023-2024 学年第二学期《java 编程基础》试题 系 班 姓名 学号 &#xff08;说明&#xff1a;本试卷共五大题&#xff0c;共 2 页&#xff0c;满分100分&#xff0c;答题时间90分钟。&#xff09; 开卷考试&#xff1a;要求书写答案在纸上。 一、单…

零基础直接上手java跨平台桌面程序,使用javafx(七)用户操作界面探讨,这个系列结束

GUI&#xff0c;我们还是喜欢web。如果javafx有像wpf的WebView2差不多的功能&#xff0c;我们就开心了scene builder中拖出一个webview&#xff0c;然后再回到代码中。发现<?import javafx.scene.web.*?>是红色的&#xff0c;我们缺少配置。于是在pom.xml中添加JavaFX依…

Spring (63)CORS,如何在Spring中配置它

CORS简介 CORS&#xff08;Cross-Origin Resource Sharing&#xff0c;跨源资源共享&#xff09;是一种机制&#xff0c;它使用额外的HTTP头来告诉浏览器让运行在一个origin&#xff08;源&#xff09;上的Web应用被准许访问来自不同源服务器上的指定资源。当一个资源从与该资…

Google ghOSt 调度器分析(2)

调度器分析 *ghOSt* 调度干预过程1. 内核相关介绍2. 干预过程ghOSt 调度干预过程 1. 内核相关介绍 下面先来介绍以下 ghost 调度类相关的内容。 ghost 调度类 ghost 调度器在内核中新建了两个调度类: ghost_agent 调度类 ghost 调度类 ghost 调度类与其他调度类的优先级关系…

Nodejs 第七十七章(MQ高级)

MQ介绍和基本使用在75章介绍过了&#xff0c;不再重复 MQ高级用法-延时消息 什么是延时消息? Producer 将消息发送到 MQ 服务端&#xff0c;但并不期望这条消息立马投递&#xff0c;而是延迟一定时间后才投递到 Consumer 进行消费&#xff0c;该消息即延时消息 插件安装 R…

【康复学习--LeetCode每日一题】521. 最长特殊序列 Ⅰ

题目&#xff1a; 给你两个字符串 a 和 b&#xff0c;请返回 这两个字符串中 最长的特殊序列 的长度。如果不存在&#xff0c;则返回 -1 。 「最长特殊序列」 定义如下&#xff1a;该序列为 某字符串独有的最长 子序列 &#xff08;即不能是其他字符串的子序列&#xff09; 。…

[C++] 从零实现一个ping服务

&#x1f4bb;文章目录 前言ICMP概念报文格式 Ping服务实现系统调用函数具体实现运行测试 总结 前言 ping命令&#xff0c;因为其简单、易用等特点&#xff0c;几乎所有的操作系统都内置了一个ping命令。如果你是一名C初学者&#xff0c;对网络编程、系统编程有所了解&#xff…

徐州BGP服务器租用的好处有哪些?

BGP是一种路径矢量协议&#xff0c;能够维护不同主机、网络和网关的路由器的路径&#xff0c;并且可以根据BGP做出路由决定&#xff0c;将电信和联通等线路通过BGP互连技术&#xff0c;把不同的线路融合在一起。其中BGP服务器则是一种用于不同主机和互联网之间传输数据和信息的…

ijkplayer编译 android版本

ijkplayer源码下载地址如下&#xff1a;https://github.com/bilibili/ijkplayer 下载代码&#xff0c;直接执行如下命令即可&#xff1a; $cd /data/project/ijkplayer/ $git clone https://github.com/bilibili/ijkplayer.git $git checkout -B latest k0.8.8 1 环境安装 …

学会python——读取大文本文件(python实例六)

目录 1、认识Python 2、环境与工具 2.1 python环境 2.2 Visual Studio Code编译 3、读取大文本文件 3.1 代码构思 3.2 代码示例 3.3 运行结果 4、总结 1、认识Python Python 是一个高层次的结合了解释性、编译性、互动性和面向对象的脚本语言。 Python 的设计具有很强…

了解JS递归

在JavaScript中&#xff0c;递归是一个非常重要的概念&#xff0c;它允许函数在其定义内部调用自身。递归在处理许多类型的问题时非常有用&#xff0c;尤其是那些可以通过分解成更小、更简单的子问题来解决的问题。然而&#xff0c;递归也需要谨慎使用&#xff0c;因为它可能导…

电脑内存怎么看?5个秘诀,轻松查看内存!

“新买了一台电脑&#xff0c;想查看一下我电脑的内存&#xff0c;大家可以分享一下查看方法吗&#xff1f;” 当我们谈论电脑的性能时&#xff0c;内存无疑是一个不容忽视的关键组件。然而&#xff0c;对于许多普通用户来说&#xff0c;如何查看电脑内存的大小、类型以及使用情…

跳舞电动机器人单片机方案

这款机器人形状智能电子玩具是一款集娱乐、教育和互动于一身的高科技产品。它的主要功能包括&#xff1a; 1、智能对话&#xff1a;机器人可以进行简单的对话&#xff0c;回答用户的问题&#xff0c;提供有趣的互动体验。 2、前进、后退、左转、右转、滑行&#xff1a;机器人…

企业级-封装Java对内卷PDF利用关键字分页导出标题

提供 PDF 文件 File入参&#xff0c;根据需要将其中内卷文件需要分页利用关键字读取分页&#xff0c;转成 XML。 使用 依赖&#xff1a;itextpdf、pdfbox 1、导入依赖 <dependency><groupId>org.apache.pdfbox</groupId><artifactId>pdfbox</arti…

BERT报错记录

一、加载数据集下载失败 报错&#xff1a; TimeoutError: [WinError 10060] 由于连接方在一段时间后没有正确答复或连接的主机没有反应&#xff0c;连接尝试失败。urllib3.exceptions.NewConnectionError: <urllib3.connection.HTTPSConnection object at 0x00000241F9AD4…

Element UI 一键校验多表单(v-for循环表单,异步校验规则,v-for 中的 ref 属性,避坑 forEach 不支持异步 await )

需求描述 表单为数组 v-for 循环得到的多表单&#xff0c;如可自由增删的动态表单表单中存在异步校验规则&#xff0c;如姓名需访问接口校验是否已存在点击提交按钮&#xff0c;需一键校验所有表单&#xff0c;仅当所有表单都通过校验&#xff0c;才能最终提交到后台 效果预览 …