ViT (Vision Transformer) Explained in Detail: Network Construction, Visualization, and Data Preprocessing, with a Full Worked Example

Many ViT analyses and tutorials online are vague and insubstantial. This article works through ViT end to end with a concrete example.

It covers three parts: network construction; visualizing every layer's structure and input/output shapes with torchview.draw_graph; and data preprocessing. The full code is included at the end.

Network Construction

Create a model.py that implements the ViT network:

import torch.nn as nn
import torch
import torch.optim as optim
import torch.nn.functional as F
import lightning as L


class AttentionBlock(nn.Module):
    def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
        """
        Inputs:
            embed_dim - Dimensionality of input and attention feature vectors
            hidden_dim - Dimensionality of hidden layer in feed-forward network
                         (usually 2-4x larger than embed_dim)
            num_heads - Number of heads to use in the Multi-Head Attention block
            dropout - Amount of dropout to apply in the feed-forward network
        """
        super().__init__()

        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        self.linear = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        inp_x = self.layer_norm_1(x)
        x = x + self.attn(inp_x, inp_x, inp_x)[0]
        x = x + self.linear(self.layer_norm_2(x))
        return x


class VisionTransformer(nn.Module):
    def __init__(
        self,
        embed_dim,
        hidden_dim,
        num_channels,
        num_heads,
        num_layers,
        num_classes,
        patch_size,
        num_patches,
        dropout=0.0,
    ):
        """
        Inputs:
            embed_dim - Dimensionality of the input feature vectors to the Transformer
            hidden_dim - Dimensionality of the hidden layer in the feed-forward networks
                         within the Transformer
            num_channels - Number of channels of the input (3 for RGB)
            num_heads - Number of heads to use in the Multi-Head Attention block
            num_layers - Number of layers to use in the Transformer
            num_classes - Number of classes to predict
            patch_size - Number of pixels that the patches have per dimension
            num_patches - Maximum number of patches an image can have
            dropout - Amount of dropout to apply in the feed-forward network and
                      on the input encoding
        """
        super().__init__()

        self.patch_size = patch_size

        # Layers/Networks
        self.input_layer = nn.Linear(num_channels * (patch_size**2), embed_dim)
        self.transformer = nn.Sequential(
            *(AttentionBlock(embed_dim, hidden_dim, num_heads, dropout=dropout) for _ in range(num_layers))
        )
        self.mlp_head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, num_classes))
        self.dropout = nn.Dropout(dropout)

        # Parameters/Embeddings
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, 1 + num_patches, embed_dim))

    def img_to_patch(self, x, patch_size, flatten_channels=True):
        """
        Inputs:
            x - Tensor representing the image of shape [B, C, H, W]
            patch_size - Number of pixels per dimension of the patches (integer)
            flatten_channels - If True, the patches will be returned in a flattened format
                               as a feature vector instead of an image grid.
        """
        B, C, H, W = x.shape
        x = x.reshape(B, C, H // patch_size, patch_size, W // patch_size, patch_size)
        x = x.permute(0, 2, 4, 1, 3, 5)  # [B, H', W', C, p_H, p_W]
        x = x.flatten(1, 2)  # [B, H'*W', C, p_H, p_W]
        if flatten_channels:
            x = x.flatten(2, 4)  # [B, H'*W', C*p_H*p_W]
        return x

    def forward(self, x):
        # Preprocess input
        x = self.img_to_patch(x, self.patch_size)
        B, T, _ = x.shape
        x = self.input_layer(x)

        # Add CLS token and positional encoding
        cls_token = self.cls_token.repeat(B, 1, 1)
        x = torch.cat([cls_token, x], dim=1)
        x = x + self.pos_embedding[:, : T + 1]

        # Apply Transformer
        x = self.dropout(x)
        x = x.transpose(0, 1)
        x = self.transformer(x)

        # Perform classification prediction
        cls = x[0]
        out = self.mlp_head(cls)
        return out


class ViT(L.LightningModule):
    def __init__(self, model_kwargs, lr):
        super().__init__()
        self.save_hyperparameters()
        self.model = VisionTransformer(**model_kwargs)

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr)
        lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)
        return [optimizer], [lr_scheduler]

    def _calculate_loss(self, batch, mode="train"):
        imgs, labels = batch
        preds = self.model(imgs)
        loss = F.cross_entropy(preds, labels)
        acc = (preds.argmax(dim=-1) == labels).float().mean()

        self.log("%s_loss" % mode, loss)
        self.log("%s_acc" % mode, acc)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._calculate_loss(batch, mode="train")
        return loss

    def validation_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="val")

    def test_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="test")
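A quick way to verify img_to_patch's shape bookkeeping is to run it on a dummy batch. This is a minimal sketch assuming the hyperparameters used later in this article (16x16 RGB inputs with patch_size=4, so 4x4=16 patches of 3*4*4=48 values each):

import torch
from model import VisionTransformer

# Instantiate with the same kwargs used later in the article
vit = VisionTransformer(embed_dim=256, hidden_dim=512, num_channels=3, num_heads=8,
                        num_layers=6, num_classes=10, patch_size=4, num_patches=64)
patches = vit.img_to_patch(torch.randn(2, 3, 16, 16), patch_size=4)
print(patches.shape)  # torch.Size([2, 16, 48]), i.e. [B, H'*W', C*p_H*p_W]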

Import model.py from another file to build the network:

from model import ViT

model = ViT(
    model_kwargs={
        "embed_dim": 256,
        "hidden_dim": 512,
        "num_heads": 8,
        "num_layers": 6,
        "patch_size": 4,
        "num_channels": 3,
        "num_patches": 64,
        "num_classes": 10,
        "dropout": 0.2,
    },
    lr=3e-4,
)
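As a sanity check (a minimal sketch; the expected shapes follow from the kwargs above), a forward pass on a dummy batch of two 16x16 RGB images should produce one logit per class:

import torch

dummy = torch.randn(2, 3, 16, 16)  # [B, C, H, W] consistent with patch_size=4, num_patches=64
out = model(dummy)
print(out.shape)  # torch.Size([2, 10]): one row of 10 class logits per image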

Alternatively, download a pretrained model:

# Imports needed when running this snippet standalone
import os
import urllib.request
from urllib.error import HTTPError
from model import ViT

# Files to download
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/"
CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "saved_models/VisionTransformers/")
pretrained_files = [
    "tutorial15/ViT.ckpt",
    "tutorial15/tensorboards/ViT/events.out.tfevents.ViT",
    "tutorial5/tensorboards/ResNet/events.out.tfevents.resnet",
]
# Create checkpoint path if it doesn't exist yet
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# For each file, check whether it already exists. If not, try downloading it.
for file_name in pretrained_files:
    file_path = os.path.join(CHECKPOINT_PATH, file_name.split("/", 1)[1])
    if "/" in file_name.split("/", 1)[1]:
        os.makedirs(file_path.rsplit("/", 1)[0], exist_ok=True)
    if not os.path.isfile(file_path):
        file_url = base_url + file_name
        print("Downloading %s..." % file_url)
        try:
            urllib.request.urlretrieve(file_url, file_path)
        except HTTPError as e:
            print(
                "Something went wrong. Please try to download the file from the GDrive folder,"
                " or contact the author with the full output including the following error:\n",
                e,
            )

pretrained_filename = os.path.join(CHECKPOINT_PATH, "ViT.ckpt")
model = ViT.load_from_checkpoint(pretrained_filename)
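Once loaded, the checkpoint can be exercised directly. A minimal inference sketch, assuming the download above succeeded and the checkpoint matches the ViT class in model.py:

import torch

model.eval()
device = next(model.parameters()).device  # load_from_checkpoint may restore to CPU or GPU
with torch.no_grad():
    logits = model(torch.randn(1, 3, 16, 16, device=device))
print(logits.argmax(dim=-1))  # predicted class index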

Visualizing the Network with torchview.draw_graph

from torchview import draw_graph

model_graph = draw_graph(model, input_size=(1, 3, 16, 16))
model_graph.resize_graph(scale=5.0)
model_graph.visual_graph.render(format='svg')

Running this code generates an SVG image showing the network structure and the input/output shapes of every layer.
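If the rendered graph is too coarse or too busy, torchview also accepts a depth argument that controls how deeply nested modules are expanded (a hedged aside: verify the parameter and its default against your installed torchview version):

model_graph = draw_graph(model, input_size=(1, 3, 16, 16), depth=5)  # expand submodules further
model_graph.visual_graph.render(filename='vit_depth5', format='svg')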

Preparing the Training Data

Create a new prepare_data.py:

import os
import json
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms


class CustomDataset(Dataset):
    def __init__(self, image_dir, names, labels, transform=None):
        self.image_dir = image_dir
        self.names = names
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        name_ = self.names[idx]
        img_name = os.path.join(self.image_dir, name_)
        image = Image.open(img_name)
        if self.transform:
            image = self.transform(image)
        label = self.labels[idx]
        return image, label


def load_img_ann(ann_path):
    """Return {img_id: {'name': file_name, 'category_id': label}} from a COCO instances JSON."""
    with open(ann_path) as fp:
        root = json.load(fp)
    img_dict = {}
    img_label_dict = {}
    for img_info in root['images']:
        img_id = img_info['id']
        img_name = img_info['file_name']
        img_dict[img_id] = {'name': img_name}
    for ann_info in root['annotations']:
        img_id = ann_info['image_id']
        img_category_id = ann_info['category_id']
        img_name = img_dict[img_id]['name']
        img_label_dict[img_id] = {'name': img_name, 'category_id': img_category_id}
    return img_label_dict


def get_dataloader():
    annota_dir = '/home/robotics/Downloads/coco_dataset/annotations/instances_val2017.json'
    img_dir = "/home/robotics/Downloads/coco_dataset/val2017"
    img_dict = load_img_ann(annota_dir)
    values = list(img_dict.values())
    img_names = []
    labels = []
    for item in values:
        category_id = item['category_id']
        labels.append(category_id)
        img_name = item['name']
        img_names.append(img_name)

    # Filter out grayscale images, and keep only the first 10 classes
    img_names_rgb = []
    labels_rgb = []
    for i in range(len(img_names)):
        file_path = os.path.join(img_dir, img_names[i])
        # Open the image file
        img = Image.open(file_path)
        # Get the channel layout (PIL mode): "RGB" for color, "L" for grayscale
        num_channels = img.mode
        if num_channels == "RGB" and labels[i] < 10:
            img_names_rgb.append(img_names[i])
            labels_rgb.append(labels[i])

    # Define a chain of image transformations
    transform = transforms.Compose([
        transforms.Resize((16, 16)),  # resize the image
        transforms.RandomHorizontalFlip(),  # random horizontal flip
        transforms.ToTensor(),  # convert the image to a tensor
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # normalize
    ])

    # image_dir is the folder containing all image files; labels is the list of labels
    # NOTE: these slices overlap (val/test are subsets of train); fine for a demo,
    # but use disjoint splits for real training
    train_set = CustomDataset(img_dir, img_names_rgb[-500:], labels_rgb[-500:], transform=transform)
    val_set = CustomDataset(img_dir, img_names_rgb[-500:-100], labels_rgb[-500:-100], transform=transform)
    test_set = CustomDataset(img_dir, img_names_rgb[-100:], labels_rgb[-100:], transform=transform)

    # Create the DataLoaders
    train_loader = DataLoader(train_set, batch_size=32, shuffle=True, drop_last=False)
    val_loader = DataLoader(val_set, batch_size=32, shuffle=True, drop_last=False, num_workers=4)
    test_loader = DataLoader(test_set, batch_size=32, shuffle=True, drop_last=False, num_workers=4)
    return train_loader, val_loader, test_loader


if __name__ == "__main__":
    train_loader, val_loader, test_loader = get_dataloader()
    for batch in train_loader:
        print(batch)

A walkthrough of the code above:

This uses the COCO 2017 dataset, which you can download from the official website. After downloading, the annotations folder contains the annotation JSON files (instances, captions, and person_keypoints, each in train and val variants).

We use instances_val2017.json here. For serious training you should use train2017, but that file is very large and slow to process; since this article is a walkthrough rather than a bid for accuracy, val2017 suffices. The instances files are the annotations used for recognition: they contain each image's labels and bounding boxes. The ViT built in this article outputs only the class, not the box. The function

def load_img_ann(ann_path):

links together each image's id (its unique primary key), its name, and its category_id (which class it belongs to, i.e. its label).
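For reference, here is an abridged sketch of the structure load_img_ann walks (the field values are illustrative; real entries carry many more fields such as segmentation, area, and iscrowd):

root = {
    'images': [
        {'id': 397133, 'file_name': '000000397133.jpg', 'height': 427, 'width': 640},
    ],
    'annotations': [
        {'image_id': 397133, 'category_id': 1, 'bbox': [102.49, 118.47, 7.9, 17.31]},
    ],
}
# load_img_ann collapses this into:
# {397133: {'name': '000000397133.jpg', 'category_id': 1}}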

        # Get the channel layout (PIL mode)
        num_channels = img.mode
        if num_channels == "RGB" and labels[i] < 10:
            img_names_rgb.append(img_names[i])
            labels_rgb.append(labels[i])

Note that COCO contains single-channel grayscale images, which must be filtered out. And because the ViT in this article is small and outputs only 10 classes, preprocessing keeps just the images whose label is below 10.
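PIL exposes the channel layout through Image.mode, which is what the filter keys on (the file names below are hypothetical placeholders):

from PIL import Image

color = Image.open('some_color_photo.jpg')     # hypothetical path
print(color.mode)  # 'RGB' for 3-channel color images

gray = Image.open('some_grayscale_photo.jpg')  # hypothetical path
print(gray.mode)   # 'L' for single-channel grayscale, rejected by the filter above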

Next, define the image transforms:

    # Define a chain of image transformations
    transform = transforms.Compose([
        transforms.Resize((16, 16)),  # resize the image
        transforms.RandomHorizontalFlip(),  # random horizontal flip
        transforms.ToTensor(),  # convert the image to a tensor
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # normalize
    ])
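Applied to a single PIL image, the composed transform yields a normalized 3x16x16 float tensor. A quick standalone check (the path is a hypothetical placeholder for any RGB image):

from PIL import Image

img = Image.open('some_rgb_image.jpg')  # hypothetical path
x = transform(img)
print(x.shape)  # torch.Size([3, 16, 16])
print(x.dtype)  # torch.float32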

Create a custom Dataset class that inherits from torch.utils.data.Dataset:

class CustomDataset(Dataset):
    def __init__(self, image_dir, names, labels, transform=None):
        self.image_dir = image_dir
        self.names = names
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        name_ = self.names[idx]
        img_name = os.path.join(self.image_dir, name_)
        image = Image.open(img_name)
        if self.transform:
            image = self.transform(image)
        label = self.labels[idx]
        return image, label

Create the Datasets first, then build DataLoaders on top of them to draw mini-batches:

    # image_dir is the folder containing all image files; labels is the list of labels
    train_set = CustomDataset(img_dir, img_names_rgb[-500:], labels_rgb[-500:], transform=transform)
    val_set = CustomDataset(img_dir, img_names_rgb[-500:-100], labels_rgb[-500:-100], transform=transform)
    test_set = CustomDataset(img_dir, img_names_rgb[-100:], labels_rgb[-100:], transform=transform)

    # Create the DataLoaders
    train_loader = DataLoader(train_set, batch_size=32, shuffle=True, drop_last=False)
    val_loader = DataLoader(val_set, batch_size=32, shuffle=True, drop_last=False, num_workers=4)
    test_loader = DataLoader(test_set, batch_size=32, shuffle=True, drop_last=False, num_workers=4)
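Pulling one batch from the loader confirms the shapes the model expects (assuming the loaders returned by get_dataloader()):

imgs, labels = next(iter(train_loader))
print(imgs.shape)   # torch.Size([32, 3, 16, 16])
print(labels.shape) # torch.Size([32])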

With that, data preparation is done; now train the model:

    trainer = L.Trainer(
        default_root_dir=os.path.join(CHECKPOINT_PATH, "ViT"),
        accelerator="auto",
        devices=1,
        max_epochs=180,
        callbacks=[
            ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc"),
            LearningRateMonitor("epoch"),
        ],
    )
    trainer.logger._log_graph = True  # If True, we plot the computation graph in tensorboard
    trainer.logger._default_hp_metric = None  # Optional logging argument that we don't need

    trainer.fit(model, train_loader, val_loader)

    # Load best checkpoint after training
    model = ViT.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

    # Test best model on validation and test set
    val_result = trainer.test(model, dataloaders=val_loader, verbose=False)
    test_result = trainer.test(model, dataloaders=test_loader, verbose=False)
    result = {"test": test_result[0]["test_acc"], "val": val_result[0]["test_acc"]}

Full code:

There are three files in total: model.py builds the network, prepare_data.py handles data preprocessing, and main.py trains the network.

model.py:

import torch.nn as nn
import torch
import torch.optim as optim
import torch.nn.functional as F
import lightning as L


class AttentionBlock(nn.Module):
    def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
        """
        Inputs:
            embed_dim - Dimensionality of input and attention feature vectors
            hidden_dim - Dimensionality of hidden layer in feed-forward network
                         (usually 2-4x larger than embed_dim)
            num_heads - Number of heads to use in the Multi-Head Attention block
            dropout - Amount of dropout to apply in the feed-forward network
        """
        super().__init__()

        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        self.linear = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        inp_x = self.layer_norm_1(x)
        x = x + self.attn(inp_x, inp_x, inp_x)[0]
        x = x + self.linear(self.layer_norm_2(x))
        return x


class VisionTransformer(nn.Module):
    def __init__(
        self,
        embed_dim,
        hidden_dim,
        num_channels,
        num_heads,
        num_layers,
        num_classes,
        patch_size,
        num_patches,
        dropout=0.0,
    ):
        """
        Inputs:
            embed_dim - Dimensionality of the input feature vectors to the Transformer
            hidden_dim - Dimensionality of the hidden layer in the feed-forward networks
                         within the Transformer
            num_channels - Number of channels of the input (3 for RGB)
            num_heads - Number of heads to use in the Multi-Head Attention block
            num_layers - Number of layers to use in the Transformer
            num_classes - Number of classes to predict
            patch_size - Number of pixels that the patches have per dimension
            num_patches - Maximum number of patches an image can have
            dropout - Amount of dropout to apply in the feed-forward network and
                      on the input encoding
        """
        super().__init__()

        self.patch_size = patch_size

        # Layers/Networks
        self.input_layer = nn.Linear(num_channels * (patch_size**2), embed_dim)
        self.transformer = nn.Sequential(
            *(AttentionBlock(embed_dim, hidden_dim, num_heads, dropout=dropout) for _ in range(num_layers))
        )
        self.mlp_head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, num_classes))
        self.dropout = nn.Dropout(dropout)

        # Parameters/Embeddings
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, 1 + num_patches, embed_dim))

    def img_to_patch(self, x, patch_size, flatten_channels=True):
        """
        Inputs:
            x - Tensor representing the image of shape [B, C, H, W]
            patch_size - Number of pixels per dimension of the patches (integer)
            flatten_channels - If True, the patches will be returned in a flattened format
                               as a feature vector instead of an image grid.
        """
        B, C, H, W = x.shape
        x = x.reshape(B, C, H // patch_size, patch_size, W // patch_size, patch_size)
        x = x.permute(0, 2, 4, 1, 3, 5)  # [B, H', W', C, p_H, p_W]
        x = x.flatten(1, 2)  # [B, H'*W', C, p_H, p_W]
        if flatten_channels:
            x = x.flatten(2, 4)  # [B, H'*W', C*p_H*p_W]
        return x

    def forward(self, x):
        # Preprocess input
        x = self.img_to_patch(x, self.patch_size)
        B, T, _ = x.shape
        x = self.input_layer(x)

        # Add CLS token and positional encoding
        cls_token = self.cls_token.repeat(B, 1, 1)
        x = torch.cat([cls_token, x], dim=1)
        x = x + self.pos_embedding[:, : T + 1]

        # Apply Transformer
        x = self.dropout(x)
        x = x.transpose(0, 1)
        x = self.transformer(x)

        # Perform classification prediction
        cls = x[0]
        out = self.mlp_head(cls)
        return out


class ViT(L.LightningModule):
    def __init__(self, model_kwargs, lr):
        super().__init__()
        self.save_hyperparameters()
        self.model = VisionTransformer(**model_kwargs)

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr)
        lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)
        return [optimizer], [lr_scheduler]

    def _calculate_loss(self, batch, mode="train"):
        imgs, labels = batch
        preds = self.model(imgs)
        loss = F.cross_entropy(preds, labels)
        acc = (preds.argmax(dim=-1) == labels).float().mean()

        self.log("%s_loss" % mode, loss)
        self.log("%s_acc" % mode, acc)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._calculate_loss(batch, mode="train")
        return loss

    def validation_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="val")

    def test_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="test")

prepare_data.py:

import os
import json
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms


class CustomDataset(Dataset):
    def __init__(self, image_dir, names, labels, transform=None):
        self.image_dir = image_dir
        self.names = names
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        name_ = self.names[idx]
        img_name = os.path.join(self.image_dir, name_)
        image = Image.open(img_name)
        if self.transform:
            image = self.transform(image)
        label = self.labels[idx]
        return image, label


def load_img_ann(ann_path):
    """Return {img_id: {'name': file_name, 'category_id': label}} from a COCO instances JSON."""
    with open(ann_path) as fp:
        root = json.load(fp)
    img_dict = {}
    img_label_dict = {}
    for img_info in root['images']:
        img_id = img_info['id']
        img_name = img_info['file_name']
        img_dict[img_id] = {'name': img_name}
    for ann_info in root['annotations']:
        img_id = ann_info['image_id']
        img_category_id = ann_info['category_id']
        img_name = img_dict[img_id]['name']
        img_label_dict[img_id] = {'name': img_name, 'category_id': img_category_id}
    return img_label_dict


def get_dataloader():
    annota_dir = '/home/robotics/Downloads/coco_dataset/annotations/instances_val2017.json'
    img_dir = "/home/robotics/Downloads/coco_dataset/val2017"
    img_dict = load_img_ann(annota_dir)
    values = list(img_dict.values())
    img_names = []
    labels = []
    for item in values:
        category_id = item['category_id']
        labels.append(category_id)
        img_name = item['name']
        img_names.append(img_name)

    # Filter out grayscale images, and keep only the first 10 classes
    img_names_rgb = []
    labels_rgb = []
    for i in range(len(img_names)):
        file_path = os.path.join(img_dir, img_names[i])
        # Open the image file
        img = Image.open(file_path)
        # Get the channel layout (PIL mode): "RGB" for color, "L" for grayscale
        num_channels = img.mode
        if num_channels == "RGB" and labels[i] < 10:
            img_names_rgb.append(img_names[i])
            labels_rgb.append(labels[i])

    # Define a chain of image transformations
    transform = transforms.Compose([
        transforms.Resize((16, 16)),  # resize the image
        transforms.RandomHorizontalFlip(),  # random horizontal flip
        transforms.ToTensor(),  # convert the image to a tensor
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # normalize
    ])

    # image_dir is the folder containing all image files; labels is the list of labels
    # NOTE: these slices overlap (val/test are subsets of train); fine for a demo,
    # but use disjoint splits for real training
    train_set = CustomDataset(img_dir, img_names_rgb[-500:], labels_rgb[-500:], transform=transform)
    val_set = CustomDataset(img_dir, img_names_rgb[-500:-100], labels_rgb[-500:-100], transform=transform)
    test_set = CustomDataset(img_dir, img_names_rgb[-100:], labels_rgb[-100:], transform=transform)

    # Create the DataLoaders
    train_loader = DataLoader(train_set, batch_size=32, shuffle=True, drop_last=False)
    val_loader = DataLoader(val_set, batch_size=32, shuffle=True, drop_last=False, num_workers=4)
    test_loader = DataLoader(test_set, batch_size=32, shuffle=True, drop_last=False, num_workers=4)
    return train_loader, val_loader, test_loader


if __name__ == "__main__":
    train_loader, val_loader, test_loader = get_dataloader()
    for batch in train_loader:
        print(batch)

main.py:

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # makes the recurring shape-mismatch errors below easier to trace
import urllib.request
from urllib.error import HTTPError
import lightning as L
from model import ViT
from torchview import draw_graph
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from prepare_data import get_dataloader

# Load the model
# Files to download
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/"
CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "saved_models/VisionTransformers/")
pretrained_files = [
    "tutorial15/ViT.ckpt",
    "tutorial15/tensorboards/ViT/events.out.tfevents.ViT",
    "tutorial5/tensorboards/ResNet/events.out.tfevents.resnet",
]
# Create checkpoint path if it doesn't exist yet
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# For each file, check whether it already exists. If not, try downloading it.
for file_name in pretrained_files:
    file_path = os.path.join(CHECKPOINT_PATH, file_name.split("/", 1)[1])
    if "/" in file_name.split("/", 1)[1]:
        os.makedirs(file_path.rsplit("/", 1)[0], exist_ok=True)
    if not os.path.isfile(file_path):
        file_url = base_url + file_name
        print("Downloading %s..." % file_url)
        try:
            urllib.request.urlretrieve(file_url, file_path)
        except HTTPError as e:
            print(
                "Something went wrong. Please try to download the file from the GDrive folder,"
                " or contact the author with the full output including the following error:\n",
                e,
            )

pretrained_filename = os.path.join(CHECKPOINT_PATH, "ViT.ckpt")
needTrain = False
if os.path.isfile(pretrained_filename):  # fixed: the original had "if not", which would try to load a missing file
    print("Found pretrained model at %s, loading..." % pretrained_filename)
    # Automatically loads the model with the saved hyperparameters
    model = ViT.load_from_checkpoint(pretrained_filename)
else:
    L.seed_everything(42)  # To be reproducible
    model = ViT(
        model_kwargs={
            "embed_dim": 256,
            "hidden_dim": 512,
            "num_heads": 8,
            "num_layers": 6,
            "patch_size": 4,
            "num_channels": 3,
            "num_patches": 64,
            "num_classes": 10,
            "dropout": 0.2,
        },
        lr=3e-4,
    )
    needTrain = True

# Visualize the network structure
model_graph = draw_graph(model, input_size=(1, 3, 16, 16))
model_graph.resize_graph(scale=5.0)
model_graph.visual_graph.render(format='svg')

# Prepare the training data
train_loader, val_loader, test_loader = get_dataloader()

if needTrain:
    trainer = L.Trainer(
        default_root_dir=os.path.join(CHECKPOINT_PATH, "ViT"),
        accelerator="auto",
        devices=1,
        max_epochs=180,
        callbacks=[
            ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc"),
            LearningRateMonitor("epoch"),
        ],
    )
    trainer.logger._log_graph = True  # If True, we plot the computation graph in tensorboard
    trainer.logger._default_hp_metric = None  # Optional logging argument that we don't need

    trainer.fit(model, train_loader, val_loader)

    # Load best checkpoint after training
    model = ViT.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

    # Test best model on validation and test set
    val_result = trainer.test(model, dataloaders=val_loader, verbose=False)
    test_result = trainer.test(model, dataloaders=test_loader, verbose=False)
    result = {"test": test_result[0]["test_acc"], "val": val_result[0]["test_acc"]}

