import argparse
import math
import os
import random
import sys
import time
from copy import deepcopy
from datetime import datetime
from pathlib import Path
import numpy as np
import yaml
from tqdm import tqdm

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD, Adam, AdamW, lr_scheduler
FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # project root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH so local packages are importable
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative path from the current working directory

import val  # for end-of-epoch mAP
from models.experimental import attempt_load
from models.yolo import Model
from utils.autoanchor import check_anchors
from utils.autobatch import check_train_batch_size
from utils.callbacks import Callbacks
from utils.datasets import create_dataloader
from utils.downloads import attempt_download
from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
                           check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds,
                           intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods, one_cycle,
                           print_args, print_mutation, strip_optimizer)
from utils.loggers import Loggers
from utils.loggers.wandb.wandb_utils import check_wandb_resume
from utils.loss import ComputeLoss, ComputeLossOTA
from utils.metrics import fitness
from utils.plots import plot_evolve, plot_labels, plot_lr_scheduler
from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, select_device, torch_distributed_zero_first
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))
RANK = int(os.getenv('RANK', -1))
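# LOCAL_RANK, RANK and WORLD_SIZE are the standard environment variables set by the
# torch.distributed launchers (torchrun / torch.distributed.run): RANK is the global
# process index, LOCAL_RANK is the GPU index on the current node, and WORLD_SIZE is the
# total number of processes. The -1 / 1 defaults mean single-GPU or CPU training.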

def train(hyp, opt, device, callbacks):
    version = opt.version
    save_dir = Path(opt.save_dir)
    epochs = opt.epochs
    batch_size = opt.batch_size
    weights = opt.weights
    single_cls = opt.single_cls
    evolve = opt.evolve
    data = opt.data
    cfg = opt.cfg
    resume = opt.resume
    noval = opt.noval
    nosave = opt.nosave
    workers = opt.workers
    freeze = opt.freeze

    # Directories: where training weights are saved
    w = save_dir / 'weights'
    (w.parent if evolve else w).mkdir(parents=True, exist_ok=True)
    last = w / 'last.pt'
    best = w / 'best.pt'

    # Hyperparameters
    if isinstance(hyp, str):
        with open(hyp, errors='ignore') as f:
            hyp = yaml.safe_load(f)
    LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))

    # Save run settings
    if not evolve:
        with open(save_dir / 'hyp.yaml', 'w') as f:
            yaml.safe_dump(hyp, f, sort_keys=False)
        with open(save_dir / 'opt.yaml', 'w') as f:
            yaml.safe_dump(vars(opt), f, sort_keys=False)

    # Loggers
    data_dict = None
    if RANK in [-1, 0]:
        loggers = Loggers(save_dir, weights, opt, hyp, LOGGER)
        if loggers.wandb:
            data_dict = loggers.wandb.data_dict
            if resume:
                weights, epochs, hyp, batch_size = opt.weights, opt.epochs, opt.hyp, opt.batch_size
        # Register actions
        for k in methods(loggers):
            callbacks.register_action(k, callback=getattr(loggers, k))

    # Config
    plots = not evolve
    cuda = device.type != 'cpu'
    init_seeds(1 + RANK)

    # Dataset
    with torch_distributed_zero_first(LOCAL_RANK):
        data_dict = data_dict or check_dataset(data)
    train_path = data_dict['train']
    val_path = data_dict['val']
    nc = 1 if single_cls else int(data_dict['nc'])  # number of classes
    names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
    assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}'
    is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt')

    # Model
    if version == 1:  # YOLO-style checkpoint: ckpt['model'] holds an nn.Module
        check_suffix(weights, '.pt')
        pretrained = weights.endswith('.pt')
        if pretrained:
            with torch_distributed_zero_first(LOCAL_RANK):
                weights = attempt_download(weights)  # download if not found locally
            ckpt = torch.load(weights, map_location='cpu')
            model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)
            exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else []
            csd = ckpt['model'].float().state_dict()
            csd = intersect_dicts(csd, model.state_dict(), exclude=exclude)  # keep matching keys only
            model.load_state_dict(csd, strict=False)
            LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}')
        else:
            model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)
    elif version == 2:  # plain state_dict checkpoint, optionally backbone-only
        check_suffix(weights, ['.pt', '.pth'])
        pretrained = weights.endswith('.pt') or weights.endswith('.pth')
        if pretrained:
            ckpt = torch.load(weights, map_location=device)
            model = Model(cfg).to(device)
            if resume:
                exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else []
                csd = ckpt['model'].float().state_dict()
                csd = intersect_dicts(csd, model.state_dict(), exclude=exclude)
                model.load_state_dict(csd, strict=False)
            else:
                for k in list(ckpt.keys()):  # drop detection-head weights, keep the backbone only
                    if 'head' in k:
                        del ckpt[k]
                csd = ckpt['model'] if 'model' in ckpt.keys() else ckpt
                model.load_state_dict(csd, strict=False)
                miss_key, unexpected_key = model.backbone.load_state_dict(csd, strict=False)
                print('Pretrained weight loading result:\n', miss_key, unexpected_key)
        else:
            model = Model(cfg).to(device)
    else:
        raise ValueError(f'unsupported model version: {version}')  # avoid an undefined model later

    # Freeze: parameter-name prefixes to freeze (an explicit list, or the first N layers)
    freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))]
    for k, v in model.named_parameters():
        v.requires_grad = True  # train all layers
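    # Example: --freeze 10 yields freeze = ['model.0.', 'model.1.', ..., 'model.9.'];
    # parameters whose names match one of these prefixes then get requires_grad = False
    # in the freezing step that follows (as in upstream YOLOv5).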