【ICCV2023论文阅读】XNet(能跑通代码)

这里写目录标题

  • 论文阅读
    • 摘要
    • 介绍
    • 方法
      • overview
      • why use wavelet transform?
      • 融合方法
      • 用于全监督分割和半监督分割可行性分析
    • 效果
    • 局限性
    • 总结
  • 代码跑通
    • 去掉分布式训练
    • 生成低频和高频图片
    • 产生数据集
    • 改读取数据的位置
    • 损失函数
    • 添加自己数据集的信息
    • 结果

ps:我现在不知道自己研究方向是做什么的,就是分割也试试,医疗诊断也试试。然后之后更的尽量把代码跑通也写上。因为之前代码水平有限不能很好的跑通,然后我只是说我这个数据集怎么改,以及我这个硬件水平下,你们看着改就可以。
论文地址
代码地址

论文阅读

摘要

现状

  1. 把全监督分割和半监督当作两种训练方式,很少有把它们统一起来的。(本文就把这两个统一起来了,就是个创新点)
  2. 很少有完全监督的模型关注图像的固有的低频信息和高频信息去提升性能。
  3. 半监督学习的扰动是人为添加的,可能引入不利的学习偏见。
    方法
    提出了一种基于小波的LF和HF融合模型XNet,它支持全监督和半监督语义分割,并在这两个领域都优于最先进的模型。

介绍

启发:对于语义分割问题,HF信息通常表示图像细节,LF信息通常是抽象与一。提取和融合不同频率信息的策略可以帮助模型更好地关注LF予以和HF细节,以提高性能。模型使用小波变换生成LF和HF图像,用于基于一致性差分的半监督学习。这些一致性差异源于对LF和HF信息的不同关注,这缓解了人工设计造成的学习偏差。
contributions:

  1. 提出了低频和高频融合模型XNet,在有监督和半监督上实现了优异的性能。
  2. XNet使用小波变化生成LF和HF图片来进行一致性学习,可以减轻人为扰动引起的学习偏差。
  3. 在两个2D和两个3D公共生物医学数据集上进行的广泛基准测试证实了XNet的有效性。

方法

在这里插入图片描述

overview

获取相应的LF和HF图片。然后将它们输入到LF和HF编码器以分别生成LF(语义)和HF(边缘、纹理)特征。之后使用融合模块对他们的特征进行融合。然后把融合特征放到解码器中获得LF和HF分支的预测结果。全监督损失是监督损失(两个分支的预测和真实值之间的损失,记为 L s u p L_{sup} Lsup)和标记图像的一致性损失(记为 L u n s u p L_{unsup} Lunsup)。半监督训练,最大限度减少标记图像的监督损失和未标记图像的双重输出的一致性损失。都是dice loss。
L u n s u p L_{unsup} Lunsup是由交叉伪标签监督损失实现,用一个分支的预测作为伪标签去监督另一个分支。 L u n s u p = L u n s u p H ( p i L , p ^ i H ) + L u n s u p L ( p i H , p ^ i L ) L_{unsup}=L_{unsup}^{H}(p^L_{i},\hat{p}^H_{i})+L_{unsup}^{L}(p^H_{i},\hat{p}^L_{i}) Lunsup=LunsupH(piL,p^iH)+LunsupL(piH,p^iL)
我们选择在训练阶段表现更好的分支作为推理过程中的最终输出。

why use wavelet transform?

在这里插入图片描述
与其他方法(如傅立叶变换)相比,小波变换是生成L和H的有效方法。使用L作为输入,XNet可以更多地关注LF语义,因为L具有较少的噪声和细节。相比之下,H具有更多的噪声,但对象边界更清晰,这可以帮助模型更多地关注HF细节。此外,使用L和H进行半监督训练,一致性差异来自图像的固有LF和HF信息,这可以缓解人工扰动引起的学习偏差。

融合方法

在这里插入图片描述
LF和HF融合模块的架构。相同大小Conv表示输出和输入特征具有相同大小。下采样Conv将输出特征的大小减少一半。上采样Conv使输出特征的大小加倍。Transition Conv使用信道级联特征作为输入和输出融合特征。
就是LF Feature1是第4层的feature,它进行一次不改变大小的卷积得到第一个有花纹的蓝色块也就是特征,它进行一次下采样得到下面那个小的特征。LF Feature2是第5层的feature,它进行一次上采样得到横杠的看色特征,进行一次不改变大小的卷积得到方块特征。其他同理,结合之后进行卷积获得和原来LF Feature1相同大小的特征图,然后进行U-Net那个skip connect即可。

用于全监督分割和半监督分割可行性分析

对于生物医学图像,我们假设原始图像I由LF特征FL、HF特征FH、LF加性噪声NL和HF加性噪声NH组成。因此, I I I被定义为:
I = F L + F H + N L + N H I = F_L+F_H+N_L+N_H I=FL+FH+NL+NH
因为生物医学图像中的噪声通常是加性的。对于语义分割问题,准确的分割需要LF语义(如形状、颜色等)和HF细节(如边缘、纹理等)。
对于监督学习,对完整信息进行解码可以获得分割预测。对于半监督学习,由于每个解码分支对LF和HF信息的关注程度不同,因此双分支解码器的预测在LF语义和HF细节方面存在差异。这些差异可用于基于一致性规则的半监督训练。
总之,XNet既可以用于全监督学习,也可以用于半监督学习。图显示了XNet分割过程的拓扑流程图。
在这里插入图片描述

效果

在这里插入图片描述

局限性

由于XNet强调HF信息,当图像几乎没有HF信息时,XNet的性能会受到负面影响。

总结

我们提出了一种基于小波的低频和高频融合模型XNet,该模型在生物医学图像的全监督和半监督语义分割方面都取得了最先进的性能。在2D和3D数据集上进行的大量实验证明了我们提出的模型的有效性。然而,XNet的局限性在于,当高频信息不可用时,其性能可能会受到负面影响。我们认为,完全监督和半监督的语义分割模型可以而且应该是统一的。我们希望我们的研究能为它们的统一提供一些例证和思考。

代码跑通

完全可以按作者的那个readme对自己数据集进行修改跑通。
我不了解分布式训练,所以一直报错。下面展示不用分布式训练的代码。

去掉分布式训练

下面是我把有分布式训练的地方都删了。(应该问题不大吧。(lll¬ω¬))

from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import os
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.backends import cudnn
import random
from config.dataset_config.dataset_cfg import dataset_cfgfrom config.train_test_config.train_test_config import print_train_loss_XNet, print_val_loss, print_train_eval_XNet, print_val_eval, save_val_best_2d, draw_pred_XNet, print_best
from config.visdom_config.visual_visdom import visdom_initialization_XNet, visualization_XNet, visual_image_XNet
from config.warmup_config.warmup import GradualWarmupScheduler
from config.augmentation.online_aug import data_transform_2d, data_normalize_2d
from loss.loss_function import segmentation_loss
from models.getnetwork import get_network
from dataload.dataset_2d import imagefloder_iitnn
from warnings import simplefiltersimplefilter(action='ignore', category=FutureWarning)if __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--path_trained_models', default='./checkpoints/sup_xnet')parser.add_argument('--path_seg_results', default='./seg_pred/sup_xnet')parser.add_argument('--path_dataset', default='自己数据集的根目录')parser.add_argument('--dataset_name', default='自己数据集的名称', help='CREMI, ISIC-2017, GlaS')parser.add_argument('--input1', default='L')parser.add_argument('--input2', default='H')parser.add_argument('--sup_mark', default='100')parser.add_argument('-b', '--batch_size', default=4, type=int)parser.add_argument('-e', '--num_epochs', default=200, type=int)parser.add_argument('-s', '--step_size', default=50, type=int)parser.add_argument('-l', '--lr', default=0.5, type=float)parser.add_argument('-g', '--gamma', default=0.5, type=float)parser.add_argument('-u', '--unsup_weight', default=5, type=float)parser.add_argument('--loss', default='dice', type=str)parser.add_argument('-w', '--warm_up_duration', default=20)parser.add_argument('--momentum', default=0.9, type=float)parser.add_argument('--wd', default=-5, type=float, help='weight decay pow')parser.add_argument('-i', '--display_iter', default=5, type=int)parser.add_argument('-n', '--network', default='xnet', type=str)parser.add_argument('--local_rank', default=-1, type=int)args = parser.parse_args()dataset_name = args.dataset_namecfg = dataset_cfg(dataset_name)print_num = 77 + (cfg['NUM_CLASSES'] - 3) * 14print_num_minus = print_num - 2print_num_half = int(print_num / 2 - 1)path_trained_models = args.path_trained_models + '/' + str(os.path.split(args.path_dataset)[1])path_seg_results = args.path_seg_results + '/' + str(os.path.split(args.path_dataset)[1])# Datasetif args.input1 == 'image':input1_mean = 'MEAN'input1_std = 'STD'else:input1_mean = 'MEAN_' + args.input1input1_std = 'STD_' + args.input1if args.input2 == 'image':input2_mean = 'MEAN'input2_std = 'STD'else:input2_mean = 'MEAN_' + args.input2input2_std = 'STD_' + args.input2data_transforms = data_transform_2d()data_normalize_1 = data_normalize_2d(cfg[input1_mean], cfg[input1_std])data_normalize_2 = data_normalize_2d(cfg[input2_mean], cfg[input2_std])dataset_train = imagefloder_iitnn(data_dir=args.path_dataset+'/train',input1=args.input1,input2=args.input2,data_transform_1=data_transforms['train'],data_normalize_1=data_normalize_1,data_normalize_2=data_normalize_2,sup=True,num_images=None,)dataset_val = imagefloder_iitnn(data_dir=args.path_dataset + '/val',input1=args.input1,input2=args.input2,data_transform_1=data_transforms['val'],data_normalize_1=data_normalize_1,data_normalize_2=data_normalize_2,sup=True,num_images=None,)dataloaders = dict()dataloaders['train'] = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8)dataloaders['val'] = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8)num_batches = {'train_sup': len(dataloaders['train']), 'val': len(dataloaders['val'])}# Modelmodel = get_network(args.network, cfg['IN_CHANNELS'], cfg['NUM_CLASSES'])model = model.cuda()# Training Strategycriterion = segmentation_loss(args.loss, False).cuda()optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5*10**args.wd)exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=args.warm_up_duration, after_scheduler=exp_lr_scheduler)# Train & Valsince = time.time()count_iter = 0best_model = modelbest_result = 'Result1'best_val_eval_list = [0 for i in range(4)]for epoch in range(args.num_epochs):count_iter += 1if (count_iter - 1) % args.display_iter == 0:begin_time = time.time()model.train()train_loss_sup_1 = 0.0train_loss_sup_2 = 0.0train_loss_unsup = 0.0train_loss = 0.0val_loss_sup_1 = 0.0val_loss_sup_2 = 0.0unsup_weight = args.unsup_weight * (epoch + 1) / args.num_epochs# dist.barrier()for i, data in enumerate(dataloaders['train']):inputs_train_1 = Variable(data['image'].cuda())inputs_train_2 = Variable(data['image_2'].cuda())mask_train = Variable(data['mask'].cuda())optimizer.zero_grad()outputs_train1, outputs_train2 = model(inputs_train_1, inputs_train_2)torch.cuda.empty_cache()if count_iter % args.display_iter == 0:if i == 0:score_list_train1 = outputs_train1score_list_train2 = outputs_train2mask_list_train = mask_train# else:elif 0 < i <= num_batches['train_sup'] / 4:score_list_train1 = torch.cat((score_list_train1, outputs_train1), dim=0)score_list_train2 = torch.cat((score_list_train2, outputs_train2), dim=0)mask_list_train = torch.cat((mask_list_train, mask_train), dim=0)max_train1 = torch.max(outputs_train1, dim=1)[1]max_train2 = torch.max(outputs_train2, dim=1)[1]max_train1 = max_train1.long()max_train2 = max_train2.long()loss_train_sup1 = criterion(outputs_train1, mask_train)loss_train_sup2 = criterion(outputs_train2, mask_train)loss_train_unsup = criterion(outputs_train1, max_train2) + criterion(outputs_train2, max_train1)loss_train_unsup = loss_train_unsup * unsup_weightloss_train = loss_train_sup1 + loss_train_sup2 + loss_train_unsuploss_train.backward()optimizer.step()train_loss_sup_1 += loss_train_sup1.item()train_loss_sup_2 += loss_train_sup2.item()train_loss_unsup += loss_train_unsup.item()train_loss += loss_train.item()scheduler_warmup.step()# torch.cuda.empty_cache()if count_iter % args.display_iter == 0:print('=' * print_num)print('| Epoch {}/{}'.format(epoch + 1, args.num_epochs).ljust(print_num_minus, ' '), '|')train_epoch_loss_sup1, train_epoch_loss_sup2, train_epoch_loss_cps, train_epoch_loss = print_train_loss_XNet(train_loss_sup_1, train_loss_sup_2, train_loss_unsup, train_loss, num_batches, print_num,print_num_half)# print(score_list_train1)# print(score_list_train2)train_eval_list1, train_eval_list2, train_m_jc1, train_m_jc2 = print_train_eval_XNet(cfg['NUM_CLASSES'], score_list_train1, score_list_train2, mask_list_train, print_num_half)torch.cuda.empty_cache()with torch.no_grad():model.eval()for i, data in enumerate(dataloaders['val']):# if 0 <= i <= num_batches['val']:inputs_val = Variable(data['image'].cuda())inputs_val_wavelet = Variable(data['image_2'].cuda())mask_val = Variable(data['mask'].cuda())name_val = data['ID']optimizer.zero_grad()outputs_val1, outputs_val2 = model(inputs_val, inputs_val_wavelet)torch.cuda.empty_cache()if i == 0:score_list_val1 = outputs_val1score_list_val2 = outputs_val2mask_list_val = mask_valname_list_val = name_valelse:score_list_val1 = torch.cat((score_list_val1, outputs_val1), dim=0)score_list_val2 = torch.cat((score_list_val2, outputs_val2), dim=0)mask_list_val = torch.cat((mask_list_val, mask_val), dim=0)name_list_val = np.append(name_list_val, name_val, axis=0)loss_val_sup1 = criterion(outputs_val1, mask_val)loss_val_sup2 = criterion(outputs_val2, mask_val)val_loss_sup_1 += loss_val_sup1.item()val_loss_sup_2 += loss_val_sup2.item()torch.cuda.empty_cache()val_epoch_loss_sup1, val_epoch_loss_sup2 = print_val_loss(val_loss_sup_1, val_loss_sup_2,num_batches, print_num, print_num_half)val_eval_list1, val_eval_list2, val_m_jc1, val_m_jc2 = print_val_eval(cfg['NUM_CLASSES'],score_list_val1,score_list_val2,mask_list_val, print_num_half)best_val_eval_list, best_model, best_result = save_val_best_2d(cfg['NUM_CLASSES'], best_model,best_val_eval_list, best_result,model, model, score_list_val1,score_list_val2, name_list_val,val_eval_list1, val_eval_list2,path_trained_models,path_seg_results, cfg['PALETTE'])torch.cuda.empty_cache()torch.cuda.empty_cache()

生成低频和高频图片

里面有一个wavelet2D.py(我是2D图片)。运行即可。

import numpy as np
from PIL import Image
import pywt
import argparse
import osif __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--image_path', default='自己数据的位置')# parser.add_argument('--mask_path', default='')parser.add_argument('--L_path', default='自己保存低频图片的位置')parser.add_argument('--H_path', default='自己保存高频图片的位置')parser.add_argument('--wavelet_type', default='db2', help='haar, db2, bior1.5, bior2.4, coif1, dmey')parser.add_argument('--if_RGB', default=False)args = parser.parse_args()if not os.path.exists(args.L_path):os.mkdir(args.L_path)if not os.path.exists(args.H_path):os.mkdir(args.H_path)for i in os.listdir(args.image_path):image_path = os.path.join(args.image_path, i)L_path = os.path.join(args.L_path, i)H_path = os.path.join(args.H_path, i)if args.if_RGB:image = Image.open(image_path).convert('L')else:image = Image.open(image_path)image = np.array(image)LL, (LH, HL, HH) = pywt.dwt2(image, args.wavelet_type)LL = (LL - LL.min()) / (LL.max() - LL.min()) * 255LL = Image.fromarray(LL.astype(np.uint8))LL.save(L_path)LH = (LH - LH.min()) / (LH.max() - LH.min()) * 255HL = (HL - HL.min()) / (HL.max() - HL.min()) * 255HH = (HH - HH.min()) / (HH.max() - HH.min()) * 255merge1 = HH + HL + LHmerge1 = (merge1-merge1.min()) / (merge1.max()-merge1.min()) * 255merge1 = Image.fromarray(merge1.astype(np.uint8))merge1.save(H_path)

产生数据集

我是这样的所以我读取数据集的时候还得改。大家也可以按照这个项目的readme中的那个文件对我下面这个产生数据集代码中的路径进行修改。
dataset
├── train
├── L
├── 1.png
├── 2.png
└── …
├── H
├── 1.png
├── 2.png
└── …
└── mask
├── 1.png
├── 2.png
└── …
└── val
├── L
├── H
└── mask

import os
import argparse
import random
import shutil
from shutil import copyfiledef rm_mkdir(dir_path):if os.path.exists(dir_path):shutil.rmtree(dir_path)print('Remove path - %s' % dir_path)os.makedirs(dir_path)print('Create path - %s' % dir_path)def main(config):rm_mkdir(os.path.join(config.train_path,'H'))rm_mkdir(os.path.join(config.train_path,'L'))rm_mkdir(os.path.join(config.train_path, 'mask'))rm_mkdir(os.path.join(config.valid_path, 'H'))rm_mkdir(os.path.join(config.valid_path, 'L'))rm_mkdir(os.path.join(config.valid_path, 'mask'))H_path = os.path.join(config.origin_data_path, 'H')H_filenames = os.listdir(H_path)data_list = []for filename in H_filenames:ext = os.path.splitext(filename)[-1]if ext == '.png':filename = os.path.basename(filename)data_list.append(filename)num_total = len(data_list)num_train = int((config.train_ratio / (config.train_ratio + config.valid_ratio )) * num_total)num_valid = int((config.valid_ratio / (config.train_ratio + config.valid_ratio )) * num_total)print('\nNum of train set : ', num_train)print('\nNum of valid set : ', num_valid)Arange = list(range(num_total))random.shuffle(Arange)for i in range(num_train):idx = Arange.pop()src = os.path.join(config.origin_data_path,'H', data_list[idx])dst = os.path.join(config.train_path,'H', data_list[idx])copyfile(src, dst)src = os.path.join(config.origin_data_path, 'L', data_list[idx])dst = os.path.join(config.train_path, 'L', data_list[idx])copyfile(src, dst)src = os.path.join(config.origin_data_path, 'mask_', data_list[idx])dst = os.path.join(config.train_path, 'mask', data_list[idx])copyfile(src, dst)for i in range(num_valid):idx = Arange.pop()src = os.path.join(config.origin_data_path, 'H', data_list[idx])dst = os.path.join(config.valid_path, 'H', data_list[idx])copyfile(src, dst)src = os.path.join(config.origin_data_path, 'L', data_list[idx])dst = os.path.join(config.valid_path, 'L', data_list[idx])copyfile(src, dst)src = os.path.join(config.origin_data_path, 'mask', data_list[idx])dst = os.path.join(config.valid_path, 'mask', data_list[idx])copyfile(src, dst)if __name__ == '__main__':parser = argparse.ArgumentParser()# model hyper-parametersparser.add_argument('--train_ratio', type=float, default=0.8)#训练集和测试集的比例parser.add_argument('--valid_ratio', type=float, default=0.2)# data pathparser.add_argument('--origin_data_path', type=str, default='自己数据的位置')parser.add_argument('--train_path', type=str, default='./train/')#自己要保存的训练集和测试集的位置←↓parser.add_argument('--valid_path', type=str, default='./val/')config = parser.parse_args()print(config)main(config)

改读取数据的位置

main.py中他原来是train_sup100,我用的是train文件夹。所以dataloader的参数要改。

dataset_train = imagefloder_iitnn(data_dir=args.path_dataset+'/train',input1=args.input1,input2=args.input2,data_transform_1=data_transforms['train'],data_normalize_1=data_normalize_1,data_normalize_2=data_normalize_2,sup=True,num_images=None,)dataset_val = imagefloder_iitnn(data_dir=args.path_dataset + '/val',input1=args.input1,input2=args.input2,data_transform_1=data_transforms['val'],data_normalize_1=data_normalize_1,data_normalize_2=data_normalize_2,sup=True,num_images=None,)

还有dataset_2d.py中的,也有train_sup100好像也改了。具体的忘了。

class dataset_iitnn(Dataset):def __init__(self, data_dir, input1, input2, augmentation1, normalize_1, normalize_2, sup=True,num_images=None, **kwargs):super(dataset_iitnn, self).__init__()img_paths_1 = []img_paths_2 = []mask_paths = []image_dir_1 = data_dir + '/' + input1image_dir_2 = data_dir + '/' + input2if sup:mask_dir = data_dir + '/mask'

损失函数

我数据集是只有一个类别。

class DiceLoss(nn.Module):"""Dice loss, need one hot encode input"""def __init__(self, weight=None, aux=False, aux_weight=0.4, ignore_index=-1, **kwargs):super(DiceLoss, self).__init__()self.kwargs = kwargsself.weight = weightself.ignore_index = ignore_indexself.aux = auxself.aux_weight = aux_weightdef _base_forward(self, predict, target, valid_mask):dice = BinaryDiceLoss(**self.kwargs)total_loss = 0predict = F.softmax(predict, dim=1)for i in range(target.shape[-1]):if i != self.ignore_index:dice_loss = dice(predict, target, valid_mask)#这里只有一个类别的把[i,:]删了,不然会报错因为超出范围if self.weight is not None:assert self.weight.shape[0] == target.shape[1], \'Expect weight shape [{}], get[{}]'.format(target.shape[1], self.weight.shape[0])dice_loss *= self.weights[i]total_loss += dice_lossreturn total_loss / target.shape[-1]def _aux_forward(self, output, target, **kwargs):# *preds, target = tuple(inputs)valid_mask = (target != self.ignore_index).long()target_one_hot = F.one_hot(torch.clamp_min(target, 0))loss = self._base_forward(output[0], target_one_hot, valid_mask)for i in range(1, len(output)):aux_loss = self._base_forward(output[i], target_one_hot, valid_mask)loss += self.aux_weight * aux_lossreturn lossdef forward(self, output, target):# preds, target = tuple(inputs)# inputs = tuple(list(preds) + [target])if self.aux:return self._aux_forward(output, target)else:valid_mask = (target != self.ignore_index).long()# target_one_hot = F.one_hot(torch.clamp_min(target, 0))# target_one_hot = F.one_hot(torch.clamp_min(target, 0))#这个注释掉return self._base_forward(output, target, valid_mask)#把target_one_hot改成target

添加自己数据集的信息

在/config/dataset_config/dataset_cfg.py中。
添加自己数据集的信息。我的理解。下面是求相关数据的程序。

'Data_one':{'IN_CHANNELS': 1,#单通道的'NUM_CLASSES': 1,'MEAN': [0.1612872],'STD': [0.1612872],'MEAN_H': [0.44275072],'STD_H': [0.44275072],'MEAN_L': [0.21374299],'STD_L': [0.22170983],'PALETTE': list(np.array([[255, 255, 255],]).flatten())},
import cv2
import numpy as np
import osdef compute_mean_std(dataset_path):# 初始化累积器mean_accumulator = np.zeros(3)std_accumulator = np.zeros(3)total_samples = 0# 遍历数据集for image_file in os.listdir(dataset_path):if image_file.endswith(".jpg") or image_file.endswith(".png"):image_path = os.path.join(dataset_path, image_file)# 读取图像image = cv2.imread(image_path)image = image / 255.0  # 将像素值缩放到 [0, 1]# 计算均值和标准差mean_accumulator += np.mean(image, axis=(0, 1))std_accumulator += np.std(image, axis=(0, 1))total_samples += 1# 计算平均值mean_values = mean_accumulator / total_samples# 计算标准差std_values = std_accumulator / total_samplesreturn mean_values, std_values# 示例用法
dataset_path = "自己要求平均值、方差的数据的位置"#L,H,
#mask_path = ""
mean_values, std_values = compute_mean_std(dataset_path)print("MEAN:", mean_values)
print("STD:", std_values)

结果

暂时跑出来是这样的。要是有问题之后会更新。大家也可以调调错误。感谢。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/188137.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Codeforces Round 911 (Div. 2)(C~E)(DFS、数论(容斥)、SCC缩点 + DAG图上DP)

​​​​​​1900C - Anjis Binary Tree 题意&#xff1a; 凯克西奇一直被安吉冷落。通过一个共同的朋友&#xff0c;他发现安吉非常喜欢二叉树&#xff0c;于是决定解决她的问题&#xff0c;以引起她的注意。Anji 给了 Keksic 一棵有 n个顶点的二叉树。顶点 1 是根&#xff…

利用异或、取反、自增bypass_webshell_waf

目录 引言 利用异或 介绍 eval与assert 蚁剑连接 进阶题目 利用取反 利用自增 引言 有这样一个waf用于防御我们上传的文件&#xff1a; function fun($var): bool{$blacklist ["\$_", "eval","copy" ,"assert","usort…

Vue+SpringBoot解决session跨域问题

做了一个前后端分离&#xff0c;因为前后端的 session id不一致&#xff0c;导致前端请求时&#xff0c;后端的session读取不到对应的值&#xff0c;造成登录问题。 解决方法&#xff1a; SpringBoot项目: 添加一个跨域配置 代码如下: 或者controller使用CrossOrigin Conf…

No matching version found for @babel/compat-data@^7.23.5 处理

npm ERR! notarget No matching version found for babel/compat-data^7.23.5 处理 报错信息 npm WARN ERESOLVE overriding peer dependency npm ERR! code ETARGET npm ERR! notarget No matching version found for babel/compat-data^7.23.5. npm ERR! notarget In most …

【java】编译时bug 项目启动前bug合集

文章目录 1. jdk8中 Optional orElseThrow 编译时报错java: 未报告的异常错误X; 必须对其进行捕获或声明以便抛出2. 启动项目时提示 Error running Application: Command line is too long. Shorten command line for Application or also for Spring Boot default configurati…

算法学习—排序

排序算法 一、选择排序 1.算法简介 选择排序是一个简单直观的排序方法&#xff0c;它的工作原理很简单&#xff0c;首先从未排序序列中找到最大的元素&#xff0c;放到已排序序列的末尾&#xff0c;重复上述步骤&#xff0c;直到所有元素排序完毕。 2.算法描述 1&#xff…

万宾科技监测设备,可燃气体监测仪特点一览

万宾科技的监测设备种类繁多&#xff0c;包括可燃气体监测仪、管网水位监测仪、内涝积水监测仪等。其中可燃气体监测仪是万宾科技的核心产品之一&#xff0c;用于监测环境中可燃气体的浓度&#xff0c;适用于对甲烷气体浓度进行实时监测&#xff0c;应用于燃气管网、排水管网、…

从cmd登录mysql

说明 先看看mysql.exe文件在哪个目录下&#xff0c;为了后面的操作方便&#xff0c;可以将该文件所在的路径增加到环境变量path中。 如果不增加到path环境变量中&#xff0c;那么在cmd窗口就要切换到mysql.exe文件所在的目录下执行。 在cmd窗口查看mysql命令的帮助信息 在cm…

编译原理:设计与实现一个简单词法分析器

设计与实现一个简单词法分析。具体内容是产生一个二元式文本文件&#xff0c;扩展名为dyd&#xff0c;可将Java或C程序(测试程序)分解成为一个一个的单词及类型。 &#xff08;选做&#xff1a;并查“单词符号与种别对照表”得出其种别&#xff0c;用一数字表示。&#xff09;…

CSS 多主题切换思路

前言 本篇仅提供多主题切换思路&#xff0c;示例简单且清晰。 实现 步骤一&#xff1a;多主题(颜色)定义 定义根伪类 :root&#xff0c;代码第 2 和 7 行。分别定义了默认和带参数的伪类&#xff1b;定义 CSS 变量&#xff0c;注意变量名需要以两个减号&#xff08;--&…

adb修改android系统时间 adb shell date必须要root权限

adb Command adb root //需要root权限 adb shell setprop persist.sys.timezone GMT //校准时区 adb shell date MMDDhhmmYY.ss set //修改系统时间这里是GMT格林尼治时间&#xff0c;北京时间得转换一下 8小时 adb shell hwclock -w //同步硬件时间adb shell date 0201030422…

初识Linux:权限

目录 提示&#xff1a;以下指令均在Xshell 7 中进行 Linux 的权限 内核&#xff1a; 查看操作系统版本 查看cpu信息 查看内存信息 外部程序&#xff1a; 用户&#xff1a; 普通用户变为超级用户&#xff1a; su 和 su-的区别&#xff1a; root用户变成普通用户&#…

机器人最优控制开源库 Model-based Optimization for Robotics

系列文章目录 文章目录 系列文章目录前言一、开源的库和工具箱1.1 ACADO1.2 CasADi1.3 Control Toolbox1.4 Crocoddyl1.5 Ipopt1.6 Manopt1.7 LexLS1.8 NLOpt1.9 qpOASES1.10 qpSWIFT1.11 Roboptim 二、其他库和工具箱2.1 MUSCOD2.2 OCPID-DAE12.3 SNOPT 前言 机器人&#xff…

【论文阅读】1 SkyChain:一个深度强化学习的动态区块链分片系统

SkyChain 一、文献简介二、引言及重要信息2.1 研究背景2.2 研究目的和意义2.3 文献的创新点 三、研究内容3.1模型3.2自适应分类账协议3.2.1状态块创建3.2.2合并过程3.2.3拆分过程 3.3评价框架3.3.1性能3.3.1.1共识延迟3.3.1.2重新分片延迟3.3.1.3处理事务数3.3.1.4 约束 3.3.2 …

【代码】基于麻雀搜索优化Kmeans图像分割算法

程序名称&#xff1a;基于麻雀搜索优化Kmeans图像分割算法 实现平台&#xff1a;matlab 代码简介&#xff1a;首先使用麻雀搜索优化算法来确定 K-means 算法的初始质心位置&#xff0c;然后进行传统的 K-means 聚类。这样做的目的是为了避免 K-means 算法陷入局部最优解&…

使用Docker安装部署Swagger Editor并远程访问编辑API文档

文章目录 Swagger Editor本地接口文档公网远程访问1. 部署Swagger Editor2. Linux安装Cpolar3. 配置Swagger Editor公网地址4. 远程访问Swagger Editor5. 固定Swagger Editor公网地址 Swagger Editor本地接口文档公网远程访问 Swagger Editor是一个用于编写OpenAPI规范的开源编…

还得是字节出来的,太秀了...

前段时间公司缺人&#xff0c;也面了许多测试&#xff0c;一开始瞄准的就是中级水准&#xff0c;当然也没指望能来大牛&#xff0c;提供的薪资在15-20k这个范围&#xff0c;来面试的人有很多&#xff0c;但是平均水平真的让人很失望。看了简历很多上面都是写有4年工作经验&…

MatrixOne Meetup回顾 | 深圳站

11月11日&#xff0c;MatrixOne 社区在深圳成功举办了第二次 MatrixOne Meetup。活动当天&#xff0c;数十位外部小伙伴到场参与&#xff0c;一同分享云原生数据库相关知识内容。此次活动&#xff0c;我们也邀请了来自深圳素问智能的外部讲师&#xff0c;分享了目前火爆的大模型…

【每日OJ —— 144. 二叉树的前序遍历】

每日OJ —— 144. 二叉树的前序遍历 1.题目&#xff1a;144. 二叉树的前序遍历2.方法讲解2.1.算法讲解2.2.代码实现2.3.提交通过展示 1.题目&#xff1a;144. 二叉树的前序遍历 2.方法讲解 2.1.算法讲解 1.首先如果在每次每个节点遍历的时候都去为数组开辟空间&#xff0c;这样…

steam搬砖项目到底能不能做?

在Steam的世界里&#xff0c;有一款引人注目的游戏——《CSGO》游戏设备和配件的搬运项目。它就像一个桥梁&#xff0c;连接着全球的游戏世界&#xff0c;为玩家们提供着精彩的设备和配件。原则上&#xff0c;这是一个低买高卖的过程&#xff0c;通过汇率差来赚取利润。比如&am…