项目结构
假设我们的项目结构如下:
my_project/
|-- dataset/
| |-- __init__.py
| |-- imbalance_cifar.py
| |-- balance_cifar.py
|-- main.py
代码示例
1. dataset/imbalance_cifar.py
# dataset/imbalance_cifar.py


class IMBALANCECIFAR10:
    """Demo stand-in for an imbalanced CIFAR-10 dataset.

    It only records the split mode and config it was built with and
    returns fixed placeholder data; no real data is loaded.
    """

    def __init__(self, mode, cfg) -> None:
        """Store the split mode and config, then print an init message.

        Args:
            mode: Split name; the example main.py passes "train" and "valid".
            cfg: Configuration object, stored unchanged (the example passes
                a nested dict).
        """
        self.mode = mode
        self.cfg = cfg
        print(f"Initialized IMBALANCECIFAR10 with mode: {mode} and cfg: {cfg}")

    def get_annotations(self):
        """Return a new list of placeholder annotation strings."""
        return ["annotation1", "annotation2"]

    def get_num_classes(self) -> int:
        """Return the fixed number of classes, 10."""
        return 10
2. dataset/balance_cifar.py
# dataset/balance_cifar.py


class BALANCECIFAR10:
    """Demo stand-in for a balanced CIFAR-10 dataset.

    It only records the split mode and config it was built with and
    returns fixed placeholder data; no real data is loaded.
    """

    def __init__(self, mode, cfg) -> None:
        """Store the split mode and config, then print an init message.

        Args:
            mode: Split name; the example main.py passes "train" and "valid".
            cfg: Configuration object, stored unchanged (the example passes
                a nested dict).
        """
        self.mode = mode
        self.cfg = cfg
        print(f"Initialized BALANCECIFAR10 with mode: {mode} and cfg: {cfg}")

    def get_annotations(self):
        """Return a new list of placeholder annotation strings."""
        return ["annotation3", "annotation4"]

    def get_num_classes(self) -> int:
        """Return the fixed number of classes, 10."""
        return 10
3. dataset/__init__.py
# dataset/__init__.py
"""Re-export the dataset classes so callers can import them from `dataset`."""

# Relative imports name the module, not the file: `.imbalance_cifar`, never
# `.imbalance_cifar.py` (that would look for a submodule called `py`).
from .imbalance_cifar import IMBALANCECIFAR10
from .balance_cifar import BALANCECIFAR10

# Public names exported by `from dataset import *`.
__all__ = ["IMBALANCECIFAR10", "BALANCECIFAR10"]
4. main.py
# main.py
from dataset import BALANCECIFAR10, IMBALANCECIFAR10

# Simulated config: the dataset class is named by a string, as it would be
# when read from a config file.
cfg = {"DATASET": {"DATASET": "IMBALANCECIFAR10"}}

# Only classes listed here may be selected by name. This replaces eval(),
# which would execute arbitrary code if the config came from an untrusted file.
DATASET_REGISTRY = {cls.__name__: cls for cls in (IMBALANCECIFAR10, BALANCECIFAR10)}

# Dynamically instantiate the configured class.
dataset_name = cfg["DATASET"]["DATASET"]
try:
    dataset_class = DATASET_REGISTRY[dataset_name]
except KeyError:
    raise ValueError(
        f"Unknown dataset {dataset_name!r}; expected one of {sorted(DATASET_REGISTRY)}"
    ) from None
train_set = dataset_class("train", cfg)
valid_set = dataset_class("valid", cfg)

# Call the dataset methods.
annotations = train_set.get_annotations()
num_classes = train_set.get_num_classes()
print("Annotations:", annotations)
print("Number of classes:", num_classes)
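注意:如果缺少 `dataset/__init__.py` 文件,`dataset` 目录会被当作命名空间包(Python 3.3+ 的行为)。此时 `from dataset import *` 不会导入任何类,而 `from dataset import IMBALANCECIFAR10` 会抛出 ImportError;但直接从子模块导入(例如 `from dataset.imbalance_cifar import IMBALANCECIFAR10`)仍然可用。`__init__.py` 的作用就是把子模块中的类提升到包这一层,使 `from dataset import ...` 能够找到它们。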