Dataset
- 是否需要自己定义:如果你使用的数据集不是 PyTorch 提供的标准数据集(如 MNIST、CIFAR-10 等),那么你需要继承
torch.utils.data.Dataset
类并实现两个方法:__len__()
和__getitem__()
。 __len__()
应该返回数据集的总大小。__getitem__()
应该根据索引返回一个数据样本。
DataLoader
- 是否需要自己定义:
DataLoader
不需要自己定义,它是 PyTorch 提供的一个类,用于包装Dataset
并在数据集上提供迭代功能。它支持批量处理、打乱数据、多线程加载等。 - 使用
DataLoader
时,你可以指定批处理大小(batch_size
)、是否打乱数据(shuffle
)、数据加载的线程数(num_workers
)等。
model定义【继承nn.module父类】
forward:input--forward-->output
forward(self,x)中x表示输入,即x->卷积->relu->卷积-->relu-->输出
class HeightPredictor(nn.Module):def __init__(self):super(HeightPredictor, self).__init__()self.conv1 = nn.Conv2d(1,20,5)self.conv2 = nn.Conv2d(20,20,5)def forward(self, x):x = F.relu(self.conv1(x))return F.relu(self.conv2(x))
Dict
building_info = {}【dict,key--value】
这是一个字典(dictionary)的创建语句。在Python中,字典是一种可变的、无序的、键值对(key-value pairs)的集合。每个键(key)都是唯一的,且必须是不可变的类型(如字符串、数字或元组),而值(value)可以是任何类型的数据。字典通过键来访问对应的值,提供了快速查找和插入的能力。
特殊:defaultdict:defaultdict
是Python标准库collections
模块中的一个类。defaultdict
与普通字典类似,但它在创建时提供了一个默认工厂函数【比如defaultdict(list):当访问一个不存在的键时,defaultdict
会自动为该键创建一个空列表作为默认值。】,当尝试访问一个不存在的键时,defaultdict
会自动为该键创建一个默认值,而不会抛出KeyError
。
整理csv
df = pd.read_csv(file_path, encoding="utf-8")#读取csv
#根据某个属性分组
area_bins = [0, 100, 200, 300, 400, np.inf]
area_labels = [f"{left}-{right}" if right != np.inf else f">{left}"
for left, right in zip(area_bins[:-1], area_bins[1:])]
df['area_bins'] = pd.cut(df['area'], bins=area_bins, labels=area_labels)
methods = ["a","b"]
attributes = ['material', 'height_bin']
for attr in attributes:
results = []
for method in methods:
Num_col = f"{method}_Num"
predict_col = f"{method}_predict"
if Num_col not in df.columns or predict_col not in df.columns:
print(f"跳过 {method},缺少必要列")
continue
valid_data = df[['true_Num', Num_col, predict_col, attr]].dropna()
if valid_data.empty:
print(f"{method} 在属性 {attr} 下无有效数据")
continue
# 计算完整指标
grouped = valid_data.groupby(attr).apply(
lambda x: pd.Series({
'Ori_RMSE': np.sqrt(mean_squared_error(x['true_Num'], x[Num_col])),
'Pred_RMSE': np.sqrt(mean_squared_error(x['true_Num'], x[predict_col])),
'Ori_MAE': mean_absolute_error(x['true_Num'], x[Num_col]),
'Pred_MAE': mean_absolute_error(x['true_Num'], x[predict_col]),
'Group_Size': len(x),
'Sample_Optimized': np.sum(
np.abs(x[Num_col] - x['true_Num']) >
np.abs(x[predict_col] - x['true_height'])
)
})
).reset_index()
grouped['method'] = method
results.append(grouped)
if not results:
print(f"属性 {attr} 无数据,跳过")
continue
# 合并结果
combined_df = pd.concat(results, ignore_index=True)
# 生成透视表
pivot_df = combined_df.pivot(
index=attr,
columns='method',
values=['Ori_RMSE', 'Pred_RMSE', 'Ori_MAE', 'Pred_MAE']
)
# 扁平化列名并填充NaN
pivot_df.columns = [f"{method}_{metric}" for metric, method in pivot_df.columns]
pivot_df = pivot_df.fillna(0)
# 保存到CSV
csv_path = os.path.join(output_dir, f"{attr}.csv")
pivot_df.reset_index().to_csv(csv_path, index=False)
实现了分别对每个方法依据不同属性评估的功能