文章目录
- 简介
- 数据集
- 划分数据集
简介
记录深度学习编写程序过程中的一些工具函数
数据集
划分数据集
数据集划分思路:
- 若数据集很小,直接随机打乱
import random random.shuffle(data)
- 若数据集很大,选择随机打乱下标,根据下标实现数据集划分
-
get_dataset_split_num
无需输入训练集,只输入验证集和测试集的比例或具体数量
def get_dataset_split_num(n, valid=0, test=0):"""n: 数据集数量valid, test: 可为比例和具体数值"""if valid < 1:assert test < 1assert valid + test > 0valid_num = int(n * valid)test_num = int(n * test)train_num = n - valid_num - test_numelse:valid_num = validtest_num = testtrain_num = n - valid_num - test_numreturn train_num, valid_num, test_num
运行:
train_num, valid_num, test_num = get_dataset_split_num(100, valid=0.2, test=0.31) train_num, valid_num, test_num = get_dataset_split_num(100, valid=20, test=31)
-
cut_datasets
数据集打乱def cut_datasets(arr, valid=0, test=0):"""arr: 为下标数组"""train_num, valid_num, _ = get_dataset_split_num(len(arr), valid, test)a1 = arr[:train_num]a2 = arr[train_num:train_num + valid_num]a3 = arr[train_num + valid_num:]return a1, a2, a3