overlap_filter
overlap_filter
函数的通俗含义是根据给定的掩码(filter_mask
),逐层地过滤掉主掩码(mask
)中的某些值。具体来说,该函数会从最后一个通道开始,逐层检查并根据对应层的过滤掩码,将前面的通道中对应位置的值设置为零。
结果没有完全看懂
import numpy as npdef overlap_filter(mask, filter_mask):# 获取 mask 的通道数、高度和宽度C, _, _ = mask.shape# 从最后一个通道开始,逐个通道向前遍历for c in range(C - 1, -1, -1):# 创建一个过滤器,根据 filter_mask 的当前通道是否不为 0 来确定filter = np.repeat((filter_mask[c] != 0)[None, :], c, axis=0)# 将 mask 中前 c 个通道对应位置的值设置为 0mask[:c][filter] = 0# 返回修改后的 maskreturn mask# 创建整齐的示例数据
C, H, W = 4, 2, 2 # 通道数、高度和宽度
mask = np.zeros((C, H, W), dtype=np.uint8)
filter_mask = np.zeros((C, H, W), dtype=np.uint8)# 为 mask 填充整齐的数据
for c in range(C):mask[c] = c + 1 # 每个通道填充相同的值,方便观察# 为 filter_mask 填充整齐的数据
filter_mask[0, 0, 0] = 1 # 仅在第一个通道的左上角设置为 1
filter_mask[1, :, 1] = 1 # 在第二个通道的右边一列设置为 1
filter_mask[2, 1, :] = 1 # 在第三个通道的底边一行设置为 1print("Original mask:")
print(mask)
print("\nFilter mask:")
print(filter_mask)# 调用 overlap_filter 函数
filtered_mask = overlap_filter(mask.copy(), filter_mask)print("\nFiltered mask:")
print(filtered_mask)
结果:
Original mask:
[[[1 1]
[1 1]]
[[2 2]
[2 2]]
[[3 3]
[3 3]]
[[4 4]
[4 4]]]
Filter mask:
[[[1 0]
[0 0]]
[[0 1]
[0 1]]
[[0 0]
[1 1]]
[[0 0]
[0 0]]]
Filtered mask:
[[[1 0]
[0 0]]
[[2 2]
[0 0]]
[[3 3]
[3 3]]
[[4 4]
[4 4]]]