# Variant 1 (不带标签, "without labels"): print per-class counts as plain lists
from collections import Counter

import numpy as np
import torch
import torchvision.transforms as transforms
from medmnist import DermaMNIST
from medmnist import INFO
from sklearn.metrics import classification_report
from torch import nn, optim
from torch.utils.data import DataLoader

# from vit_pytorch.cross_vit import CrossViT
import numpy as np
import os

# The relative ``root`` below resolves against the current working directory,
# so print it to make the actual download location obvious.
print(os.getcwd())

# Set an environment variable to specify the download path (disabled; the
# explicit ``root=`` argument passed to DermaMNIST below is used instead).
# os.environ['MEDMNIST_DATA_PATH'] = os.getcwd()

# Datasets: resize every image to 224x224 and convert it to a tensor.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# MedMNIST's dataset constructor raises RuntimeError when ``root`` does not
# already exist (it does not create it), so make sure the directory is there.
data_root = './dermamnist'
os.makedirs(data_root, exist_ok=True)

train_dataset = DermaMNIST(split="train", transform=transform, download=True, size=224, root=data_root)
val_dataset = DermaMNIST(split="val", transform=transform, download=True, size=224, root=data_root)
test_dataset = DermaMNIST(split="test", transform=transform, download=True, size=224, root=data_root)
# Flatten each split's label array into a 1-D list of Python ints.
train_labels = np.ravel(train_dataset.labels).tolist()
val_labels = np.ravel(val_dataset.labels).tolist()
test_labels = np.ravel(test_dataset.labels).tolist()

# Count how many samples of each class appear in every split.
train_counts = Counter(train_labels)
val_counts = Counter(val_labels)
test_counts = Counter(test_labels)

# Convert to plain count lists where position i holds the count of class i.
# These lists carry no labels, so a class absent from one split must still
# occupy its slot (as 0); otherwise every later count would shift one position
# left and be attributed to the wrong class. The class range is taken from all
# three splits together. Counter returns 0 for keys it has never seen.
# NOTE(review): assumes labels are integer class indices starting at 0.
num_classes = max(train_counts.keys() | val_counts.keys() | test_counts.keys(), default=-1) + 1
train_counts_list = [train_counts[label] for label in range(num_classes)]
val_counts_list = [val_counts[label] for label in range(num_classes)]
test_counts_list = [test_counts[label] for label in range(num_classes)]

# Print the per-class counts of each split.
print("Train Dataset Classes Count:", train_counts_list)
print("Validation Dataset Classes Count:", val_counts_list)
print("Test Dataset Classes Count:", test_counts_list)
# Variant 2 (带标签, "with labels"): print per-class counts as (label, count) pairs
import torch
from torch.utils.data import DataLoader
from medmnist import DermaMNIST
import torchvision.transforms as transforms
from collections import Counter
import numpy as np# 设置转换
import os

# Transform: resize every image to 224x224 and convert it to a tensor.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Load the datasets. The relative ``root`` resolves against the current working
# directory. MedMNIST's dataset constructor raises RuntimeError when ``root``
# does not already exist (it does not create it), so create the directory first.
data_root = '../../dermamnist'
os.makedirs(data_root, exist_ok=True)

train_dataset = DermaMNIST(split="train", transform=transform, download=True, size=224, root=data_root)
val_dataset = DermaMNIST(split="val", transform=transform, download=True, size=224, root=data_root)
test_dataset = DermaMNIST(split="test", transform=transform, download=True, size=224, root=data_root)
# Collapse each split's label array to one dimension and convert it to a
# list of Python ints.
train_labels = train_dataset.labels.ravel().tolist()
val_labels = val_dataset.labels.ravel().tolist()
test_labels = test_dataset.labels.ravel().tolist()

# Tally the number of samples per class label in each split.
train_counts = Counter(train_labels)
val_counts = Counter(val_labels)
test_counts = Counter(test_labels)

# Build (label, count) pairs ordered by ascending label; each pair keeps its
# label, so a class missing from a split simply has no entry.
train_counts_list = [(label, train_counts[label]) for label in sorted(train_counts)]
val_counts_list = [(label, val_counts[label]) for label in sorted(val_counts)]
test_counts_list = [(label, test_counts[label]) for label in sorted(test_counts)]

# Report the per-class counts of every split.
print("Train Dataset Classes Count:", train_counts_list)
print("Validation Dataset Classes Count:", val_counts_list)
print("Test Dataset Classes Count:", test_counts_list)