
# `with torch.no_grad():` Usage Examples

`with torch.no_grad():` is a PyTorch context manager that **temporarily disables gradient computation**. It is commonly used for inference, evaluation, or any code path that does not need backpropagation. Typical usage examples follow:

---

### 1. **Speeding up inference**

Disable gradient computation during model inference (prediction) to reduce memory usage and speed up the forward pass.

```python
import torch
from torch import nn

# Define a simple model
model = nn.Linear(10, 1)
x = torch.randn(1, 10)

# Disable gradient computation during inference
with torch.no_grad():
    output = model(x)
print("Output:", output)
```

> **Effect**: `output` records no autograd history, so memory consumption is lower.

---

### 2. **Freezing network layers**

Freeze part of the network during training (e.g., the feature extractor of a pretrained model) and update only specific parameters.

```python
import torch
from torch import nn
import torchvision.models as models

# Load a pretrained model (e.g., ResNet-18)
model = models.resnet18(pretrained=True)

# Freeze all parameters
for param in model.parameters():
    param.requires_grad = False

# Replace the last layer for a new 10-class task; the new layer requires grad
model.fc = nn.Linear(512, 10)

# Only the new final layer is updated during training
optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.01)

# The frozen backbone (everything except fc) needs no gradient tracking
backbone = nn.Sequential(*list(model.children())[:-1])
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    features = backbone(x).flatten(1)

# The new head runs outside no_grad so its gradients are tracked
output = model.fc(features)
loss = output.sum()
loss.backward()
optimizer.step()
```

> **Key point**: gradient computation is skipped for the frozen layers, and only the `fc` parameters are updated. The trainable head must run *outside* the `no_grad` block; otherwise `loss.backward()` has no graph to traverse.

---

### 3. **Manual parameter updates**

When implementing a custom optimizer or updating parameters by hand, wrap the update in `no_grad` so it does not interfere with gradient tracking.

```python
import torch

x = torch.tensor([2.0], requires_grad=True)

# Run a backward pass first so that x.grad is populated
y = (x ** 2).sum()
y.backward()  # x.grad == tensor([4.])

with torch.no_grad():
    x -= 0.1 * x.grad  # manual update, not recorded in the graph
print("Updated x:", x)
```

> **Note**: an in-place update such as `x -= ...` on a leaf tensor that requires grad must be performed inside `no_grad` (or via `x.data`); otherwise PyTorch raises a runtime error.

---

### 4. **Data pre- and post-processing**

Avoid gradient tracking during data augmentation or result processing.

```python
import torch

# Data preprocessing (e.g., normalization)
data = torch.randn(100, 3, 28, 28)
with torch.no_grad():
    normalized_data = (data - data.mean()) / data.std()
print("Normalized data tracks gradients:", normalized_data.requires_grad)  # False
```

---

### 5. **Evaluating model performance**

Disable gradient computation in the validation/test phase to speed up evaluation.

```python
model.eval()  # evaluation mode: disables Dropout, switches BatchNorm to running stats
with torch.no_grad():
    test_output = model(test_input)  # test_input, test_labels, compute_accuracy are placeholders
    accuracy = compute_accuracy(test_output, test_labels)
print("Test accuracy:", accuracy)
```

---

### Key notes

1. **Difference from `detach()`**
   - `no_grad` is a context manager and affects every operation inside the block.
   - `detach()` splits a gradient-free view off an existing tensor:
     ```python
     y = x.detach()  # y.requires_grad == False
     ```
2. **Memory savings**
   - Disabling gradient computation significantly reduces memory usage (especially during large-model inference), because intermediate activations are not kept for a backward pass.
3. **Gradient blocking**
   - Operations inside a `no_grad` block never trigger backpropagation, which is useful for fixed parameters or intermediate computations.

---
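To make the memory claim in note 2 concrete, here is a minimal sketch (an illustration, not part of the original answer) that compares peak CUDA memory for one forward pass with and without `no_grad`. It assumes a CUDA device is available; the layer sizes and batch size are arbitrary.

```python
import torch
from torch import nn

device = "cuda"  # this sketch assumes a CUDA device is available
model = nn.Sequential(*[nn.Linear(4096, 4096) for _ in range(8)]).to(device)
x = torch.randn(256, 4096, device=device)

def peak_forward_memory(use_no_grad: bool) -> int:
    """Peak allocated bytes during one forward pass."""
    torch.cuda.reset_peak_memory_stats(device)
    if use_no_grad:
        with torch.no_grad():
            model(x)
    else:
        model(x)  # the input of each Linear is saved for a potential backward pass
    return torch.cuda.max_memory_allocated(device)

print("grad enabled :", peak_forward_memory(False))
print("grad disabled:", peak_forward_memory(True))
```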
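Note 3 can be verified directly: a tensor produced inside `no_grad` is detached from the graph, so gradients flow only through paths built outside the block. A small self-contained check:

```python
import torch

a = torch.ones(3, requires_grad=True)
b = a * 2              # tracked by autograd
with torch.no_grad():
    c = b * 3          # treated as a constant: c.grad_fn is None

d = (b + c).sum()      # gradient flows only through the b-path
d.backward()
print(c.requires_grad)  # False
print(a.grad)           # tensor([2., 2., 2.]); the c-path contributes nothing
```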
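Finally, `torch.no_grad` also works as a decorator, which is convenient for dedicated inference functions; gradient mode is restored automatically when the call returns. A short sketch:

```python
import torch
from torch import nn

model = nn.Linear(10, 1)

@torch.no_grad()  # the whole function body runs with gradient tracking disabled
def predict(model, x):
    return model(x)

x = torch.randn(4, 10)
print(predict(model, x).requires_grad)  # False
print(model(x).requires_grad)           # True: grad mode is restored outside the call
```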

in name_p: p.requires_grad = True else: p.requires_grad = False try: model.head.weight.requires_grad = True model.head.bias.requires_grad = True except: model.fc.weight.requires_grad = True model.fc.bias.requires_grad = True try: model.pos_embed.requires_grad = True except: print('no position encoding') try: for p in model.patch_embed.parameters(): p.requires_grad = False except: print('no patch embed') model.to(device) model_ema = None if args.model_ema: # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper model_ema = ModelEma( model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else '', resume='') model_without_ddp = model if args.distributed: model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) else: model = torch.nn.DataParallel(model) model = model.to(device) model_without_ddp = model.module n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) print('number of params:', n_parameters) if not args.unscale_lr: linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0 args.lr = linear_scaled_lr # 加入以下代码 if args.use_acmix: for param_group in optimizer.param_groups: param_group['lr'] *= 0.8 # 将基础学习率降低20% # amp about #amp_autocast = suppress #loss_scaler = "none" if args.if_amp: amp_autocast = torch.cuda.amp.autocast loss_scaler = NativeScaler() else: amp_autocast = suppress loss_scaler = "none" if args.data_set in ['CIFAR10', 'CIFAR100']: args.warmup_epochs = max(3, args.warmup_epochs) # CIFAR最少3个epoch预热 #lr_scheduler, _ = create_scheduler(args, optimizer) criterion = LabelSmoothingCrossEntropy() if mixup_active: # smoothing is handled with mixup label transform criterion = SoftTargetCrossEntropy() elif args.smoothing: criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing, num_classes=args.nb_classes) #添加num_classes else: criterion = torch.nn.CrossEntropyLoss() if args.bce_loss: criterion = torch.nn.BCEWithLogitsLoss() teacher_model = None if args.distillation_type != 'none': assert args.teacher_path, 'need to specify teacher-path when using distillation' print(f"Creating teacher model: {args.teacher_model}") teacher_model = create_model( args.teacher_model, pretrained=False, num_classes=args.nb_classes, global_pool='avg', ) if args.teacher_path.startswith('https'): checkpoint = torch.hub.load_state_dict_from_url( args.teacher_path, map_location='cpu', check_hash=True) else: checkpoint = torch.load(args.teacher_path, map_location='cpu') teacher_model.load_state_dict(checkpoint['model']) teacher_model.to(device) teacher_model.eval() # wrap the criterion in our custom DistillationLoss, which # just dispatches to the original criterion if args.distillation_type is 'none' criterion = DistillationLoss( criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau ) output_dir = Path(args.output_dir) if args.resume: if args.resume.startswith('https'): checkpoint = torch.hub.load_state_dict_from_url( args.resume, map_location='cpu', check_hash=True) else: checkpoint = torch.load(args.resume, map_location='cpu') model_without_ddp.load_state_dict(checkpoint['model']) # add ema load if args.model_ema: utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) args.start_epoch = checkpoint['epoch'] 
+ 1 if args.model_ema: utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) if 'scaler' in checkpoint and args.if_amp: # change loss_scaler if not amp loss_scaler.load_state_dict(checkpoint['scaler']) elif 'scaler' in checkpoint and not args.if_amp: loss_scaler = 'none' #lr_scheduler.step(args.start_epoch) if args.eval: test_stats = evaluate(data_loader_val, model, device, amp_autocast) print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") test_stats = evaluate(data_loader_val, model_ema.ema, device, amp_autocast) print(f"Accuracy of the ema network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") return # log about if args.local_rank == 0 and args.gpu == 0: mlflow.log_param("n_parameters", n_parameters) print(f"Start training for {args.epochs} epochs") start_time = time.time() max_accuracy = 0.0 no_improve_epochs = 0 # 添加在epoch循环前 # 新增:初始化损失记录列表 train_loss_history = [] val_loss_history = [] best_val_acc = 0.0 no_improve_epochs = 0 patience = 10 # 连续15轮无提升则停止,早停机制 15 for epoch in range(args.start_epoch, args.epochs): ''' if args.distributed: data_loader_train.sampler.set_epoch(epoch) if test_stats["acc1"] > max_accuracy: max_accuracy = test_stats["acc1"] no_improve_epochs = 0 else: no_improve_epochs += 1 if no_improve_epochs >= 10: # 连续10轮无提升停止 print(f"No improvement for {no_improve_epochs} epochs, early stopping!") break # 在 train_one_epoch 调用前添加梯度裁剪逻辑 if args.clip_grad is not None: torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad) # 修改后的 train_one_epoch 调用 train_stats = train_one_epoch( model,criterion,data_loader_train, optimizer,device,epoch, loss_scaler=loss_scaler, amp_autocast=amp_autocast, # 传入混合精度上下文 max_norm=args.clip_grad, # 传入梯度裁剪阈值 model_ema=model_ema, mixup_fn=mixup_fn, set_training_mode=True, # 强制启用训练模式 args=args ) ''' train_stats = train_one_epoch( model, criterion, data_loader_train, optimizer, device, epoch, loss_scaler=loss_scaler, amp_autocast=amp_autocast, # 传入混合精度上下文 max_norm=args.clip_grad, # 传入梯度裁剪阈值 model_ema=model_ema, mixup_fn=mixup_fn, set_training_mode=True, # 强制启用训练模式 args=args, lr_scheduler=lr_scheduler ) # 修改后的梯度监控代码 if any(p.grad is not None for p in model.parameters()): grad_norm = torch.norm(torch.stack([ torch.norm(p.grad) for p in model.parameters() if p.grad is not None # 只处理有梯度的参数 ])) print(f"Gradient Norm: {grad_norm.item():.4f}") else: print("Warning: All gradients are None!") test_stats = evaluate(data_loader_val, model, device, amp_autocast) # 添加以下代码调试 if args.cosub: outputs = torch.split(outputs, outputs.shape[0] // 2, dim=0) # 检查每个分支的输出维度 assert outputs[0].shape[ 1] == args.nb_classes, f"Output dim mismatch! Model: {outputs[0].shape[1]}, Dataset: {args.nb_classes}" else: # 检查输出维度是否匹配 assert outputs.shape[ 1] == args.nb_classes, f"Output dim mismatch! 
Model: {outputs.shape[1]}, Dataset: {args.nb_classes}" # 每5个epoch检查一次模型输出维度 if epoch % 5 == 0: print(f"Epoch {epoch} - Model output dim: {model_without_ddp.head[-1].out_features}") # 动态调整mixup强度 if epoch > args.epochs // 2: mixup_fn = Mixup(mixup_alpha=0.5) # 后期减少mixup强度 # ============ 新增:记录损失值 ============ train_loss = train_stats['loss'] val_loss = test_stats['loss'] train_loss_history.append(train_loss) val_loss_history.append(val_loss) # ============ MLflow日志记录 ============ if args.local_rank == 0 and args.gpu == 0: mlflow.start_run(run_name=args.output_dir.split("/")[-1]) mlflow.log_metric("train_loss", train_loss, step=epoch) mlflow.log_metric("val_loss", val_loss, step=epoch) # 记录参数 for key, value in vars(args).items(): mlflow.log_param(key, value) # 记录文件(如损失曲线) #mlflow.log_artifact(os.path.join(args.output_dir, "loss_curves.png")) # 更新最大准确率和早停计数器 if test_stats["acc1"] > max_accuracy: max_accuracy = test_stats["acc1"] no_improve_epochs = 0 else: no_improve_epochs += 1 # 更新最佳准确率和早停计数器 current_val_acc = test_stats["acc1"] if current_val_acc > best_val_acc: best_val_acc = current_val_acc no_improve_epochs = 0 # 保存最佳模型 utils.save_on_master({...}, 'best_model.pth') else: no_improve_epochs += 1 if no_improve_epochs >= patience: print(f"早停在epoch {epoch},最佳准确率: {best_val_acc:.2f}%") break # 添加过拟合诊断 if val_loss != 0: # 避免除以零 overfit_index = train_loss / val_loss if overfit_index < 0.7: # 健康范围0.8-1.2 print(f"警告:过拟合风险高!指数={overfit_index:.2f}") # 早停判断 #if no_improve_epochs >= 15: # 给予更多收敛时间 # print("Early stopping!") # break #lr_scheduler.step() if args.output_dir: checkpoint_paths = [output_dir / 'checkpoint.pth'] for checkpoint_path in checkpoint_paths: utils.save_on_master({ 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), # 保存调度器状态 'epoch': epoch, 'model_ema': get_state_dict(model_ema), 'scaler': loss_scaler.state_dict() if loss_scaler != 'none' else loss_scaler, 'args': args, }, checkpoint_path) test_stats = evaluate(data_loader_val, model, device, amp_autocast) print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") # 添加过拟合诊断(在训练循环结束后) if val_loss != 0: # 避免除以零 overfit_index = train_loss / val_loss if overfit_index < 0.7: # 健康范围0.8-1.2 print(f"警告:过拟合风险高!指数={overfit_index:.2f}") if max_accuracy < test_stats["acc1"]: max_accuracy = test_stats["acc1"] if args.output_dir: checkpoint_paths = [output_dir / 'best_checkpoint.pth'] for checkpoint_path in checkpoint_paths: utils.save_on_master({ 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'epoch': epoch, 'model_ema': get_state_dict(model_ema), 'scaler': loss_scaler.state_dict() if loss_scaler != 'none' else loss_scaler, 'args': args, }, checkpoint_path) print(f'Max accuracy: {max_accuracy:.2f}%') log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, **{f'test_{k}': v for k, v in test_stats.items()}, 'epoch': epoch, 'n_parameters': n_parameters} # log about if args.local_rank == 0 and args.gpu == 0: for key, value in log_stats.items(): mlflow.log_metric(key, value, log_stats['epoch']) if args.output_dir and utils.is_main_process(): with (output_dir / "log.txt").open("a") as f: f.write(json.dumps(log_stats) + "\n") total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str)) # 新增调用 if utils.is_main_process() and args.output_dir: 
        plot_loss_curves(train_loss_history, val_loss_history, args.output_dir)


if __name__ == '__main__':
    parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()
    args.gpu = None
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)

How should the problems in the script above be fixed?
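Several of the failures are visible from the listing alone. Below is a minimal sketch of the fixes, reusing the variable names from the script above; it only shows the offending fragments, not a drop-in patch:

```python
# 1) The debug block after evaluate() references `outputs`, which is never
#    defined in this scope, so the first epoch dies with a NameError.
#    Delete that block, or run a one-off shape check like this instead:
with torch.no_grad():
    sample, _ = next(iter(data_loader_val))
    out = model(sample.to(device))
    assert out.shape[1] == args.nb_classes, \
        f"Output dim mismatch! Model: {out.shape[1]}, Dataset: {args.nb_classes}"

# 2) mlflow.start_run() is called again inside the epoch loop while a run is
#    already active, which raises an exception and re-logs every parameter.
#    Start the run once before the loop and only log metrics inside it:
if args.local_rank == 0 and args.gpu == 0:
    mlflow.log_metric("train_loss", train_loss, step=epoch)
    mlflow.log_metric("val_loss", val_loss, step=epoch)

# 3) utils.save_on_master({...}, 'best_model.pth') literally saves a set
#    containing Ellipsis. Save a real checkpoint dict:
utils.save_on_master({'model': model_without_ddp.state_dict(),
                      'epoch': epoch,
                      'acc1': best_val_acc},
                     output_dir / 'best_model.pth')

# 4) Re-creating Mixup mid-training as Mixup(mixup_alpha=0.5) silently falls
#    back to the default num_classes=1000, breaking the soft targets on CIFAR.
#    Rebuild it with the full argument set:
if epoch > args.epochs // 2 and mixup_active:
    mixup_fn = Mixup(mixup_alpha=0.5, cutmix_alpha=args.cutmix,
                     label_smoothing=args.smoothing, num_classes=args.nb_classes)
```

Two structural problems also stand out. First, when `args.model` starts with `vim`, the model is created a second time after the head adjustment and `model.to(device)`, discarding both; the second `create_model` call can simply be removed, since `model_kwargs` already carries the ACMix options. Second, `no_improve_epochs` is incremented twice per epoch because both the `max_accuracy` block and the `best_val_acc` block update it, so early stopping fires at roughly half the intended patience — keep only one of the two blocks.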

def forward(self, l, ab, y, idx=None):
    K = int(self.params[0].item())
    T = self.params[1].item()
    Z_l = self.params[2].item()
    Z_ab = self.params[3].item()

    momentum = self.params[4].item()
    batchSize = l.size(0)
    outputSize = self.memory_l.size(0)  # the number of samples in the memory bank
    inputSize = self.memory_l.size(1)   # the feature dimensionality

    # score computation
    if idx is None:
        # use AliasMethod to draw the indices of K (e.g. 4096) negatives for every sample in the batch
        idx = self.multinomial.draw(batchSize * (self.K + 1)).view(batchSize, -1)
        # column 0 holds the positive index for each sample, the rest are negatives
        idx.select(1, 0).copy_(y.data)

    # sample from the L memory bank and score against the ab features
    weight_l = torch.index_select(self.memory_l, 0, idx.view(-1)).detach()
    weight_l = weight_l.view(batchSize, K + 1, inputSize)
    out_ab = torch.bmm(weight_l, ab.view(batchSize, inputSize, 1))

    # sample from the ab memory bank and score against the L features
    weight_ab = torch.index_select(self.memory_ab, 0, idx.view(-1)).detach()
    weight_ab = weight_ab.view(batchSize, K + 1, inputSize)
    out_l = torch.bmm(weight_ab, l.view(batchSize, inputSize, 1))

    if self.use_softmax:
        out_ab = torch.div(out_ab, T)
        out_l = torch.div(out_l, T)
        out_l = out_l.contiguous()
        out_ab = out_ab.contiguous()
    else:
        out_ab = torch.exp(torch.div(out_ab, T))
        out_l = torch.exp(torch.div(out_l, T))
        # set Z_0 if it hasn't been set yet;
        # Z_0 is used as a constant approximation of Z, to scale the probs
        if Z_l < 0:
            self.params[2] = out_l.mean() * outputSize
            Z_l = self.params[2].clone().detach().item()
            print("normalization constant Z_l is set to {:.1f}".format(Z_l))
        if Z_ab < 0:
            self.params[3] = out_ab.mean() * outputSize
            Z_ab = self.params[3].clone().detach().item()
            print("normalization constant Z_ab is set to {:.1f}".format(Z_ab))
        # compute out_l, out_ab
        out_l = torch.div(out_l, Z_l).contiguous()
        out_ab = torch.div(out_ab, Z_ab).contiguous()

    # update the memory banks with a momentum-averaged, L2-normalized copy of the new features
    with torch.no_grad():
        l_pos = torch.index_select(self.memory_l, 0, y.view(-1))
        l_pos.mul_(momentum)
        l_pos.add_(torch.mul(l, 1 - momentum))
        l_norm = l_pos.pow(2).sum(1, keepdim=True).pow(0.5)
        updated_l = l_pos.div(l_norm)
        self.memory_l.index_copy_(0, y, updated_l)

        ab_pos = torch.index_select(self.memory_ab, 0, y.view(-1))
        ab_pos.mul_(momentum)
        ab_pos.add_(torch.mul(ab, 1 - momentum))
        ab_norm = ab_pos.pow(2).sum(1, keepdim=True).pow(0.5)
        updated_ab = ab_pos.div(ab_norm)
        self.memory_ab.index_copy_(0, y, updated_ab)

    return out_l, out_ab
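For context, this is the forward pass of a CMC-style NCE memory-bank module. A minimal usage sketch follows, assuming an `NCEAverage`-style constructor `NCEAverage(feat_dim, n_data, K, T, momentum)` that registers the `memory_l`/`memory_ab` buffers, the `params` tensor, and the `AliasMethod` sampler — none of which are shown in the snippet above:

```python
import torch

feat_dim, n_data, K = 128, 50000, 4096          # hypothetical sizes
contrast = NCEAverage(feat_dim, n_data, K, T=0.07, momentum=0.5)

l  = torch.randn(32, feat_dim)                  # L-channel features
ab = torch.randn(32, feat_dim)                  # ab-channel features
y  = torch.randint(0, n_data, (32,))            # dataset indices of the batch samples

out_l, out_ab = contrast(l, ab, y)
print(out_l.shape, out_ab.shape)                # (32, 4097, 1) each: column 0 is the
                                                # positive score, the remaining K are negatives
```

The `with torch.no_grad():` block at the end is what keeps the memory banks out of the autograd graph: the stored features are updated as an exponential moving average and re-normalized to unit length, so no gradient ever flows into `memory_l` or `memory_ab`.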

import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # 0=INFO, 1=WARNING, 2=ERROR, 3=FATAL os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0' # 禁用oneDNN日志 import sys import glob import time import json import torch import pickle import shutil import argparse import datetime import torchvision import numpy as np from tqdm import tqdm from PIL import Image import torch.nn as nn from packaging import version from functools import partial import pytorch_lightning as pl from omegaconf import OmegaConf, DictConfig import torch.distributed as dist from typing import List, Dict, Any, Optional, Union, Tuple from ldm.util import instantiate_from_config from pytorch_lightning import seed_everything from pytorch_lightning.trainer import Trainer from torch.utils.data import DataLoader, Dataset from ldm.data.base import Txt2ImgIterableBaseDataset from pytorch_lightning.plugins import DDPPlugin from pytorch_lightning.utilities import rank_zero_info from pytorch_lightning.utilities.distributed import rank_zero_only from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor from torch.cuda.amp import autocast, GradScaler # 模型路径 current_dir = os.path.dirname(os.path.abspath(__file__)) for path in ["download", "download/CLIP", "download/k-diffusion", "download/stable_diffusion", "download/taming-transformers"]: sys.path.append(os.path.join(current_dir, path)) class ConfigManager: """配置管理类,统一处理配置加载和解析""" def __init__(self, config_files: Union[str, List[str]], cli_args: Optional[List[str]] = None): # 将单个字符串路径转换为列表 if isinstance(config_files, str): config_files = [config_files] # 验证配置文件存在 self.configs = [] for cfg in config_files: if not os.path.exists(cfg): raise FileNotFoundError(f"配置文件不存在: {cfg}") self.configs.append(OmegaConf.load(cfg)) # 解析命令行参数 self.cli = OmegaConf.from_dotlist(cli_args) if cli_args else OmegaConf.create() # 合并所有配置 self.config = OmegaConf.merge(*self.configs, self.cli) def get_model_config(self) -> DictConfig: """获取模型配置""" if "model" not in self.config: raise KeyError("配置文件中缺少'model'部分") return self.config.model def get_data_config(self) -> DictConfig: """获取数据配置""" if "data" not in self.config: raise KeyError("配置文件中缺少'data'部分") return self.config.data def get_training_config(self) -> DictConfig: """获取训练配置,提供默认值""" training_config = self.config.get("training", OmegaConf.create()) # 设置默认值 defaults = { "max_epochs": 200, "gpus": torch.cuda.device_count(), "accumulate_grad_batches": 1, "learning_rate": 1e-4, "precision": 32 } for key, value in defaults.items(): if key not in training_config: training_config[key] = value return training_config def get_logging_config(self) -> DictConfig: """获取日志配置""" return self.config.get("logging", OmegaConf.create({"logdir": "logs"})) def get_callbacks_config(self) -> DictConfig: """获取回调函数配置""" return self.config.get("callbacks", OmegaConf.create()) def save_config(self, save_path: str) -> None: """保存配置到文件""" os.makedirs(os.path.dirname(save_path), exist_ok=True) OmegaConf.save(self.config, save_path) print(f"配置已保存到: {save_path}") class DataModuleFromConfig(pl.LightningDataModule): def __init__(self, batch_size, num_workers, train=None, validation=None, test=None): super().__init__() self.batch_size = batch_size self.num_workers = num_workers self.dataset_configs = dict() if train is not None: self.dataset_configs["train"] = train if validation is not None: self.dataset_configs["validation"] = validation if test is not None: self.dataset_configs["test"] = test def setup(self, stage=None): self.datasets = { k: instantiate_from_config(cfg) for 
k, cfg in self.dataset_configs.items() } def _get_dataloader(self, dataset_name, shuffle=False): dataset = self.datasets.get(dataset_name) if dataset is None: raise ValueError(f"数据集 {dataset_name} 未配置") return DataLoader( dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=shuffle, pin_memory=True ) def train_dataloader(self): return self._get_dataloader("train", shuffle=True) def val_dataloader(self): return self._get_dataloader("validation") def test_dataloader(self): return self._get_dataloader("test") def worker_init_fn(worker_id: int) -> None: """数据加载器工作进程初始化函数""" worker_info = torch.utils.data.get_worker_info() if worker_info is None: return dataset = worker_info.dataset worker_id = worker_info.id if isinstance(dataset, Txt2ImgIterableBaseDataset): # 对可迭代数据集进行分片 split_size = dataset.num_records // worker_info.num_workers dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size] # 设置随机种子 seed = torch.initial_seed() % 2**32 + worker_id np.random.seed(seed) torch.manual_seed(seed) class EnhancedImageLogger(Callback): """增强的图像日志记录器,支持多平台日志输出""" def __init__(self, batch_frequency: int, max_images: int, clamp: bool = True, rescale: bool = True, loggers: Optional[List] = None, log_first_step: bool = False, log_images_kwargs: Optional[Dict] = None): super().__init__() self.batch_frequency = max(1, batch_frequency) self.max_images = max_images self.clamp = clamp self.rescale = rescale self.loggers = loggers or [] self.log_first_step = log_first_step self.log_images_kwargs = log_images_kwargs or {} self.log_steps = [2 ** n for n in range(6, int(np.log2(self.batch_frequency)) + 1)] if self.batch_frequency > 1 else [] def check_frequency(self, step: int) -> bool: """检查是否达到记录频率""" if step == 0 and self.log_first_step: return True if step % self.batch_frequency == 0: return True if step in self.log_steps: if len(self.log_steps) > 0: self.log_steps.pop(0) return True return False def log_images(self, pl_module: pl.LightningModule, batch: Any, step: int, split: str = "train") -> None: """记录图像并发送到所有日志记录器""" if not self.check_frequency(step) or not hasattr(pl_module, "log_images"): return is_train = pl_module.training if is_train: pl_module.eval() # 切换到评估模式 with torch.no_grad(): try: images = pl_module.log_images(batch, split=split, **self.log_images_kwargs) except Exception as e: print(f"记录图像时出错: {e}") images = {} # 处理图像数据 for k in list(images.keys()): if not isinstance(images[k], torch.Tensor): continue N = min(images[k].shape[0], self.max_images) images[k] = images[k][:N] # 分布式环境下收集所有图像 if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1: images[k] = torch.cat(all_gather(images[k])) images[k] = images[k].detach().cpu() if self.clamp: images[k] = torch.clamp(images[k], -1., 1.) 
if self.rescale: images[k] = (images[k] + 1.0) / 2.0 # 缩放到[0,1] # 发送到所有日志记录器 for logger in self.loggers: if hasattr(logger, 'log_images'): try: logger.log_images(images, step, split) except Exception as e: print(f"日志记录器 {type(logger).__name__} 记录图像失败: {e}") if is_train: pl_module.train() # 恢复训练模式 def on_train_batch_end(self, trainer: Trainer, pl_module: pl.LightningModule, outputs: Any, batch: Any, batch_idx: int) -> None: """训练批次结束时记录图像""" if trainer.global_step % trainer.log_every_n_steps == 0: self.log_images(pl_module, batch, pl_module.global_step, "train") def on_validation_batch_end(self, trainer: Trainer, pl_module: pl.LightningModule, outputs: Any, batch: Any, batch_idx: int) -> None: """验证批次结束时记录图像""" if batch_idx == 0: # 只记录第一个验证批次 self.log_images(pl_module, batch, pl_module.global_step, "val") class TensorBoardLogger: """TensorBoard日志记录器,完整实现PyTorch Lightning日志记录器接口""" def __init__(self, save_dir: str): from torch.utils.tensorboard import SummaryWriter os.makedirs(save_dir, exist_ok=True) self.save_dir = save_dir self.writer = SummaryWriter(save_dir) self._name = "TensorBoard" # 日志记录器名称 self._version = "1.0" # 版本信息 self._experiment = self.writer # 实验对象 print(f"TensorBoard日志保存在: {save_dir}") @property def name(self) -> str: return self._name @property def version(self) -> str: return self._version @property def experiment(self) -> Any: return self._experiment def log_hyperparams(self, params: Dict) -> None: """记录超参数到TensorBoard""" try: # 将嵌套字典展平 flat_params = {} for key, value in params.items(): if isinstance(value, dict): for sub_key, sub_value in value.items(): flat_params[f"{key}/{sub_key}"] = sub_value else: flat_params[key] = value # 记录超参数 self.writer.add_hparams( {k: v for k, v in flat_params.items() if isinstance(v, (int, float, str))}, {}, run_name="." 
) print("已记录超参数到TensorBoard") except Exception as e: print(f"记录超参数失败: {e}") def log_graph(self, model: torch.nn.Module, input_array: Optional[torch.Tensor] = None) -> None: """记录模型计算图到TensorBoard""" try: # 扩散模型通常有复杂的前向传播,跳过图记录 print("跳过扩散模型的计算图记录") return except Exception as e: print(f"记录模型计算图失败: {e}") def log_metrics(self, metrics: Dict[str, float], step: int) -> None: """记录指标到TensorBoard""" for name, value in metrics.items(): try: self.writer.add_scalar(name, value, global_step=step) except Exception as e: print(f"添加标量失败: {name}, 错误: {e}") def log_images(self, images: Dict[str, torch.Tensor], step: int, split: str) -> None: """记录图像到TensorBoard""" for k, img in images.items(): if img.numel() == 0: continue try: grid = torchvision.utils.make_grid(img, nrow=min(8, img.shape[0])) self.writer.add_image(f"{split}/{k}", grid, global_step=step) except Exception as e: print(f"添加图像失败: {k}, 错误: {e}") def save(self) -> None: """保存日志(TensorBoard自动保存,这里无需额外操作)""" pass def finalize(self, status: str) -> None: """完成日志记录并关闭写入器""" self.close() def close(self) -> None: """关闭日志写入器""" if hasattr(self, 'writer') and self.writer is not None: self.writer.flush() self.writer.close() self.writer = None print(f"TensorBoard日志已关闭") class TQDMProgressBar(Callback): """使用tqdm显示训练进度,兼容不同版本的PyTorch Lightning""" def __init__(self): self.progress_bar = None self.epoch_bar = None def on_train_start(self, trainer: Trainer, pl_module: pl.LightningModule) -> None: """训练开始时初始化进度条""" # 兼容不同版本的步数估计 total_steps = self._get_total_steps(trainer) self.progress_bar = tqdm( total=total_steps, desc="Training Steps", position=0, leave=True, dynamic_ncols=True ) self.epoch_bar = tqdm( total=trainer.max_epochs, desc="Epochs", position=1, leave=True, dynamic_ncols=True ) def _get_total_steps(self, trainer: Trainer) -> int: """获取训练总步数,兼容不同版本的PyTorch Lightning""" # 尝试使用新版本属性 if hasattr(trainer, 'estimated_stepping_batches'): return trainer.estimated_stepping_batches # 尝试使用旧版本属性 if hasattr(trainer, 'estimated_steps'): return trainer.estimated_steps # 回退到手动计算 try: if hasattr(trainer, 'num_training_batches'): num_batches = trainer.num_training_batches else: num_batches = len(trainer.train_dataloader) if hasattr(trainer, 'accumulate_grad_batches'): accumulate = trainer.accumulate_grad_batches else: accumulate = 1 steps_per_epoch = num_batches // accumulate total_steps = trainer.max_epochs * steps_per_epoch print(f"回退计算训练总步数: {total_steps} = {trainer.max_epochs} epochs × {steps_per_epoch} steps/epoch") return total_steps except Exception as e: print(f"无法确定训练总步数: {e}, 使用默认值10000") return 10000 def on_train_batch_end(self, trainer: Trainer, pl_module: pl.LightningModule, outputs: Any, batch: Any, batch_idx: int) -> None: """每个训练批次结束时更新进度条""" if self.progress_bar: # 防止进度条超过总步数 if self.progress_bar.n < self.progress_bar.total: self.progress_bar.update(1) try: # 尝试从输出中获取损失 loss = outputs.get('loss') if loss is not None: if isinstance(loss, torch.Tensor): loss = loss.item() self.progress_bar.set_postfix({"loss": loss}) except Exception: pass def on_train_epoch_end(self, trainer: Trainer, pl_module: pl.LightningModule) -> None: """每个训练轮次结束时更新轮次进度条""" if self.epoch_bar: self.epoch_bar.update(1) self.epoch_bar.set_postfix({"epoch": trainer.current_epoch}) def on_train_end(self, trainer: Trainer, pl_module: pl.LightningModule) -> None: """训练结束时关闭进度条""" if self.progress_bar: self.progress_bar.close() if self.epoch_bar: self.epoch_bar.close() class PerformanceMonitor(Callback): """性能监控回调,记录内存使用和训练速度""" def __init__(self): self.epoch_start_time = 0 
self.batch_times = [] def on_train_epoch_start(self, trainer: Trainer, pl_module: pl.LightningModule) -> None: """每个训练轮次开始时记录时间和重置内存统计""" self.epoch_start_time = time.time() self.batch_times = [] if torch.cuda.is_available(): torch.cuda.reset_peak_memory_stats() torch.cuda.synchronize() # 修改1:添加dataloader_idx参数 def on_train_batch_start(self, trainer: Trainer, pl_module: pl.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: """每个训练批次开始时记录时间""" self.batch_start_time = time.time() # 修改2:添加dataloader_idx参数 def on_train_batch_end(self, trainer: Trainer, pl_module: pl.LightningModule, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: """每个训练批次结束时记录时间""" self.batch_times.append(time.time() - self.batch_start_time) def on_train_epoch_end(self, trainer: Trainer, pl_module: pl.LightningModule) -> None: """每个训练轮次结束时计算并记录性能指标""" epoch_time = time.time() - self.epoch_start_time if self.batch_times: avg_batch_time = sum(self.batch_times) / len(self.batch_times) batches_per_second = 1.0 / avg_batch_time else: avg_batch_time = 0 batches_per_second = 0 memory_info = "" if torch.cuda.is_available(): max_memory = torch.cuda.max_memory_allocated() / 2 ** 20 # MiB memory_info = f", 峰值显存: {max_memory:.2f} MiB" rank_zero_info( f"Epoch {trainer.current_epoch} | " f"耗时: {epoch_time:.2f}s | " f"Batch耗时: {avg_batch_time:.4f}s ({batches_per_second:.2f} batches/s)" f"{memory_info}" ) def get_world_size() -> int: """获取分布式训练中的总进程数""" if dist.is_initialized(): return dist.get_world_size() return 1 def all_gather(data: torch.Tensor) -> List[torch.Tensor]: """在分布式环境中收集所有进程的数据""" world_size = get_world_size() if world_size == 1: return [data] # 获取各进程的Tensor大小 local_size = torch.tensor([data.numel()], device=data.device) size_list = [torch.zeros_like(local_size) for _ in range(world_size)] dist.all_gather(size_list, local_size) size_list = [int(size.item()) for size in size_list] max_size = max(size_list) # 收集数据 tensor_list = [] for size in size_list: tensor_list.append(torch.empty((max_size,), dtype=data.dtype, device=data.device)) if local_size < max_size: padding = torch.zeros(max_size - local_size, dtype=data.dtype, device=data.device) data = torch.cat((data.view(-1), padding)) dist.all_gather(tensor_list, data.view(-1)) # 截断到实际大小 results = [] for tensor, size in zip(tensor_list, size_list): results.append(tensor[:size].reshape(data.shape)) return results def create_experiment_directories(logging_config: DictConfig, experiment_name: str) -> Tuple[str, str, str]: """创建实验目录结构""" now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") logdir = os.path.join(logging_config.logdir, f"{experiment_name}_{now}") ckptdir = os.path.join(logdir, "checkpoints") cfgdir = os.path.join(logdir, "configs") os.makedirs(ckptdir, exist_ok=True) os.makedirs(cfgdir, exist_ok=True) print(f"实验目录: {logdir}") print(f"检查点目录: {ckptdir}") print(f"配置目录: {cfgdir}") return logdir, ckptdir, cfgdir def setup_callbacks(config_manager: ConfigManager, ckptdir: str, tb_logger: TensorBoardLogger) -> List[Callback]: """设置训练回调函数""" callbacks = [] # 模型检查点 checkpoint_callback = ModelCheckpoint( dirpath=ckptdir, filename='{epoch}-{step}-{val_loss:.2f}', monitor='val_loss', save_top_k=3, mode='min', save_last=True, save_on_train_epoch_end=True, # 确保在epoch结束时保存完整状态 save_weights_only=False, # 明确设置为False,保存完整检查点 every_n_train_steps=1000 # 每1000步保存一次 ) callbacks.append(checkpoint_callback) # 学习率监控 lr_monitor = LearningRateMonitor(logging_interval="step") callbacks.append(lr_monitor) # 图像日志记录 image_logger_cfg = 
config_manager.get_callbacks_config().get("image_logger", {}) image_logger = EnhancedImageLogger( batch_frequency=image_logger_cfg.get("batch_frequency", 500), max_images=image_logger_cfg.get("max_images", 4), loggers=[tb_logger] ) callbacks.append(image_logger) # 进度条 progress_bar = TQDMProgressBar() callbacks.append(progress_bar) # 性能监控 perf_monitor = PerformanceMonitor() callbacks.append(perf_monitor) return callbacks def preprocess_checkpoint(checkpoint_path: str, model: pl.LightningModule) -> Dict[str, Any]: """预处理检查点文件,确保包含所有必要的键,并添加缺失的训练状态""" print(f"预处理检查点文件: {checkpoint_path}") # 加载检查点 try: checkpoint = torch.load(checkpoint_path, map_location="cpu") except Exception as e: print(f"加载检查点失败: {e}") raise # 强制重置训练状态 checkpoint['epoch'] = 0 checkpoint['global_step'] = 0 checkpoint['lr_schedulers'] = [] checkpoint['optimizer_states'] = [] print("已重置训练状态: epoch=0, global_step=0") # 检查是否缺少关键训练状态 required_keys = ['optimizer_states', 'lr_schedulers', 'epoch', 'global_step'] missing_keys = [k for k in required_keys if k not in checkpoint] if missing_keys: print(f"警告: 检查点缺少训练状态字段 {missing_keys},将创建伪训练状态") # 创建伪训练状态 checkpoint.setdefault('optimizer_states', []) checkpoint.setdefault('lr_schedulers', []) checkpoint.setdefault('epoch', 0) checkpoint.setdefault('global_step', 0) # 检查是否缺少 position_ids state_dict = checkpoint.get("state_dict", {}) if "cond_stage_model.transformer.text_model.embeddings.position_ids" not in state_dict: print("警告: 检查点缺少 'cond_stage_model.transformer.text_model.embeddings.position_ids' 键") # 获取模型中的 position_ids 形状 if hasattr(model, "cond_stage_model") and hasattr(model.cond_stage_model, "transformer"): try: max_position_embeddings = model.cond_stage_model.transformer.text_model.config.max_position_embeddings position_ids = torch.arange(max_position_embeddings).expand((1, -1)) state_dict["cond_stage_model.transformer.text_model.embeddings.position_ids"] = position_ids print("已添加 position_ids 到检查点") except Exception as e: print(f"无法添加 position_ids: {e}") # 确保有 state_dict if "state_dict" not in checkpoint: checkpoint["state_dict"] = state_dict return checkpoint # 正确继承原始模型类 from ldm.models.diffusion.ddpm import LatentDiffusion class CustomLatentDiffusion(LatentDiffusion): """自定义 LatentDiffusion 类,处理检查点加载问题""" def on_load_checkpoint(self, checkpoint): """在加载检查点时自动处理缺失的键""" state_dict = checkpoint["state_dict"] # 检查是否缺少 position_ids if "cond_stage_model.transformer.text_model.embeddings.position_ids" not in state_dict: print("警告: 检查点缺少 'cond_stage_model.transformer.text_model.embeddings.position_ids' 键") # 获取模型中的 position_ids 形状 max_position_embeddings = self.cond_stage_model.transformer.text_model.config.max_position_embeddings position_ids = torch.arange(max_position_embeddings).expand((1, -1)) state_dict["cond_stage_model.transformer.text_model.embeddings.position_ids"] = position_ids print("已添加 position_ids 到 state_dict") # 使用非严格模式加载 self.load_state_dict(state_dict, strict=False) print("模型权重加载完成") def filter_kwargs(cls, kwargs, log_prefix=""): # 关键参数白名单 - 这些参数必须保留 ESSENTIAL_PARAMS = { 'unet_config', 'first_stage_config', 'cond_stage_config', 'scheduler_config', 'ckpt_path', 'linear_start', 'linear_end' } # 特殊处理:允许所有包含"config"的参数 filtered_kwargs = {} for k, v in kwargs.items(): if k in ESSENTIAL_PARAMS or 'config' in k: filtered_kwargs[k] = v else: print(f"{log_prefix}过滤参数: {k}") print(f"{log_prefix}保留参数: {list(filtered_kwargs.keys())}") return filtered_kwargs def check_checkpoint_content(checkpoint_path): """打印检查点包含的键,确认是否有训练状态""" checkpoint = torch.load(checkpoint_path, 
map_location="cpu") print("检查点包含的键:", list(checkpoint.keys())) if "state_dict" in checkpoint: print("模型权重存在") if "optimizer_states" in checkpoint: print("优化器状态存在") if "epoch" in checkpoint: print(f"保存的epoch: {checkpoint['epoch']}") if "global_step" in checkpoint: print(f"保存的global_step: {checkpoint['global_step']}") def main() -> None: """主函数,训练和推理流程的入口点""" # 启用Tensor Core加速 torch.set_float32_matmul_precision('high') # 解析命令行参数 parser = argparse.ArgumentParser(description="扩散模型训练框架") parser.add_argument("--config", type=str, default="configs/train.yaml", help="配置文件路径") parser.add_argument("--name", type=str, default="experiment", help="实验名称") parser.add_argument("--resume", action="store_true", default=True, help="恢复训练") parser.add_argument("--debug", action="store_true", help="调试模式") parser.add_argument("--seed", type=int, default=42, help="随机种子") parser.add_argument("--scale_lr", action="store_true", help="根据GPU数量缩放学习率") parser.add_argument("--precision", type=str, default="32", choices=["16", "32", "bf16"], help="训练精度") args, unknown = parser.parse_known_args() # 设置随机种子 seed_everything(args.seed, workers=True) print(f"设置随机种子: {args.seed}") # 初始化配置管理器 try: config_manager = ConfigManager(args.config, unknown) config = config_manager.config except Exception as e: print(f"加载配置失败: {e}") sys.exit(1) # 创建日志目录 logging_config = config_manager.get_logging_config() logdir, ckptdir, cfgdir = create_experiment_directories(logging_config, args.name) # 保存配置 config_manager.save_config(os.path.join(cfgdir, "config.yaml")) # 配置日志记录器 tb_logger = TensorBoardLogger(os.path.join(logdir, "tensorboard")) # 配置回调函数 callbacks = setup_callbacks(config_manager, ckptdir, tb_logger) # 初始化数据模块 try: print("初始化数据模块...") data_config = config_manager.get_data_config() data_module = instantiate_from_config(data_config) data_module.setup() print("可用数据集:", list(data_module.datasets.keys())) except Exception as e: print(f"数据模块初始化失败: {str(e)}") return # 创建模型 try: model_config = config_manager.get_model_config() model_params = model_config.get("params", {}) # 创建模型实例 model = CustomLatentDiffusion(**model_config.get("params", {})) print("模型初始化成功") # 检查并转换预训练权重 ckpt_path = model_config.params.get("ckpt_path", "") if ckpt_path and os.path.exists(ckpt_path): print(f"加载预训练权重: {ckpt_path}") checkpoint = torch.load(ckpt_path, map_location="cpu") state_dict = checkpoint.get("state_dict", checkpoint) # 查找所有与conv_in.weight相关的键 conv_in_keys = [] for key in state_dict.keys(): if "conv_in.weight" in key and "first_stage_model" in key: conv_in_keys.append(key) # 转换找到的权重 for conv_in_key in conv_in_keys: if state_dict[conv_in_key].shape[1] == 3: # 原始是3通道 print(f"转换权重: {conv_in_key} 从3通道到1通道") # 取RGB三通道的平均值作为单通道权重 rgb_weights = state_dict[conv_in_key] ir_weights = rgb_weights.mean(dim=1, keepdim=True) state_dict[conv_in_key] = ir_weights print(f"转换前形状: {rgb_weights.shape}") print(f"转换后形状: {ir_weights.shape}") print(f"模型层形状: {model.first_stage_model.encoder.conv_in.weight.shape}") # 非严格模式加载(允许其他层不匹配) missing, unexpected = model.load_state_dict(state_dict, strict=False) print(f"权重加载完成: 缺失层 {len(missing)}, 不匹配层 {len(unexpected)}") if missing: print("缺失层:", missing) if unexpected: print("意外层:", unexpected) except Exception as e: print(f"模型初始化失败: {str(e)}") return print("VAE输入层形状:", model.first_stage_model.encoder.conv_in.weight.shape) # 权重转换 if ckpt_path and os.path.exists(ckpt_path): print(f"加载预训练权重: {ckpt_path}") checkpoint = torch.load(ckpt_path, map_location="cpu") state_dict = checkpoint.get("state_dict", checkpoint) # 增强:查找所有需要转换的层(包括可能的变体) 
conversion_keys = [] for key in state_dict.keys(): if "conv_in" in key or "conv_out" in key or "nin_shortcut" in key: if state_dict[key].ndim == 4 and state_dict[key].shape[1] == 3: conversion_keys.append(key) print(f"找到需要转换的层: {conversion_keys}") # 转换权重 for key in conversion_keys: print(f"转换权重: {key}") print(f"原始形状: {state_dict[key].shape}") # RGB权重 [out_c, in_c=3, kH, kW] rgb_weights = state_dict[key] # 转换为单通道权重 [out_c, 1, kH, kW] if rgb_weights.shape[1] == 3: ir_weights = rgb_weights.mean(dim=1, keepdim=True) state_dict[key] = ir_weights print(f"转换后形状: {state_dict[key].shape}") # 加载转换后的权重 try: # 使用非严格模式加载 missing, unexpected = model.load_state_dict(state_dict, strict=False) print(f"权重加载完成: 缺失层 {len(missing)}, 不匹配层 {len(unexpected)}") # 打印重要信息 if missing: print("缺失层:", missing[:5]) # 只显示前5个避免过多输出 if unexpected: print("意外层:", unexpected[:5]) # 特别检查conv_in层 if "first_stage_model.encoder.conv_in.weight" in missing: print("警告: conv_in.weight未加载,需要手动初始化") # 手动初始化单通道卷积层 with torch.no_grad(): model.first_stage_model.encoder.conv_in.weight.data.normal_(mean=0.0, std=0.02) print("已手动初始化conv_in.weight") except RuntimeError as e: print(f"加载权重时出错: {e}") print("尝试仅加载兼容的权重...") # 创建新的状态字典只包含兼容的键 model_state = model.state_dict() compatible_dict = {} for k, v in state_dict.items(): if k in model_state and v.shape == model_state[k].shape: compatible_dict[k] = v # 加载兼容的权重 model.load_state_dict(compatible_dict, strict=False) print(f"部分权重加载完成: {len(compatible_dict)}/{len(state_dict)}") # 配置学习率 training_config = config_manager.get_training_config() bs = data_config.params.batch_size base_lr = model_config.base_learning_rate ngpu = training_config.get("gpus", 1) accumulate_grad_batches = training_config.get("accumulate_grad_batches", 1) if args.scale_lr: model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr print(f"学习率缩放至: {model.learning_rate:.2e} = {accumulate_grad_batches} × {ngpu} × {bs} × {base_lr:.2e}") else: model.learning_rate = base_lr print(f"使用基础学习率: {model.learning_rate:.2e}") # 检查是否恢复训练 resume_from_checkpoint = None if args.resume: # 优先使用自动保存的last.ckpt last_ckpt = os.path.join(ckptdir, "last.ckpt") if os.path.exists(last_ckpt): print(f"恢复训练状态: {last_ckpt}") resume_from_checkpoint = last_ckpt else: # 回退到指定检查点 fallback_ckpt = os.path.join(current_dir, "checkpoints", "M3FD.ckpt") if os.path.exists(fallback_ckpt): print(f"警告: 使用仅含权重的检查点,训练状态将重置: {fallback_ckpt}") resume_from_checkpoint = fallback_ckpt else: print("未找到可用的检查点,从头开始训练") # 如果需要恢复训练,预处理检查点 if resume_from_checkpoint and os.path.exists(resume_from_checkpoint): try: # 预处理检查点 - 添加缺失的状态 checkpoint = preprocess_checkpoint(resume_from_checkpoint, model) # 创建新的完整检查点文件 fixed_ckpt_path = os.path.join(ckptdir, "fixed_checkpoint.ckpt") torch.save(checkpoint, fixed_ckpt_path) print(f"修复后的完整检查点已保存到: {fixed_ckpt_path}") # 使用修复后的检查点 resume_from_checkpoint = fixed_ckpt_path except Exception as e: print(f"预处理检查点失败: {e}") print("将尝试使用默认方式加载检查点") # 配置日志记录器 tb_logger = TensorBoardLogger(os.path.join(logdir, "tensorboard")) # 配置回调函数 callbacks = setup_callbacks(config_manager, ckptdir, tb_logger) # 检查是否有验证集 has_validation = hasattr(data_module, 'datasets') and 'validation' in data_module.datasets # 计算训练批次数 try: train_loader = data_module.train_dataloader() num_train_batches = len(train_loader) print(f"训练批次数: {num_train_batches}") except Exception as e: print(f"计算训练批次数失败: {e}") num_train_batches = 0 # 设置训练器参数(先设置基础参数) trainer_config = { "default_root_dir": logdir, "max_epochs": training_config.max_epochs, "gpus": ngpu, "distributed_backend": "ddp" 
if ngpu > 1 else None, "plugins": [DDPPlugin(find_unused_parameters=False)] if ngpu > 1 else None, "precision": 16, "accumulate_grad_batches": accumulate_grad_batches, "callbacks": callbacks, "logger": tb_logger, # 添加日志记录器 "resume_from_checkpoint": resume_from_checkpoint, "fast_dev_run": args.debug, "limit_val_batches": 0 if not has_validation else 1.0, "num_sanity_val_steps": 0, # 跳过初始验证加速恢复 "log_every_n_steps": 10 # 更频繁的日志记录 } # 动态调整验证配置 if has_validation: if num_train_batches < 50: # 小数据集:使用epoch验证 trainer_config["check_val_every_n_epoch"] = 1 # 确保移除步数验证参数 if "val_check_interval" in trainer_config: del trainer_config["val_check_interval"] else: # 大数据集:使用步数验证 val_check_interval = min(2000, num_train_batches) if num_train_batches < 100: val_check_interval = max(1, num_train_batches // 4) trainer_config["val_check_interval"] = val_check_interval # 创建训练器 try: print("最终训练器配置:") for k, v in trainer_config.items(): print(f" {k}: {v}") trainer = Trainer(**trainer_config) except Exception as e: print(f"创建训练器失败: {e}") tb_logger.close() sys.exit(1) # 执行训练 try: print("开始训练...") trainer.fit(model, data_module) print("训练完成!") except KeyboardInterrupt: print("训练被用户中断") if trainer.global_rank == 0 and trainer.model is not None: trainer.save_checkpoint(os.path.join(ckptdir, "interrupted.ckpt")) except Exception as e: print(f"训练出错: {e}") if trainer.global_rank == 0 and hasattr(trainer, 'model') and trainer.model is not None: trainer.save_checkpoint(os.path.join(ckptdir, "error.ckpt")) raise finally: # 关闭日志记录器 tb_logger.close() # 打印性能分析报告 if trainer.global_rank == 0 and hasattr(trainer, 'profiler'): print("训练摘要:") print(trainer.profiler.summary()) if __name__ == "__main__": main()运行报错:模型初始化失败: Error(s) in loading state_dict for CustomLatentDiffusion: size mismatch for first_stage_model.encoder.conv_in.weight: copying a param with shape torch.Size([128, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 1, 3, 3]).
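The traceback points at ordering rather than at the conversion logic itself: `ckpt_path` is still inside `model_config.params`, so `LatentDiffusion.__init__` loads the 3-channel checkpoint into the 1-channel model before any of the RGB→single-channel conversion code runs, and `load_state_dict` raises on a shape mismatch even with `strict=False`. A sketch of the likely fix, assuming the standard LDM `init_from_ckpt` behavior — pop `ckpt_path` before construction, convert the weights, then load manually:

```python
# build the model without letting __init__ load the checkpoint itself
model_params = dict(model_config.get("params", {}))
ckpt_path = model_params.pop("ckpt_path", "")
model = CustomLatentDiffusion(**model_params)

if ckpt_path and os.path.exists(ckpt_path):
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    state_dict = checkpoint.get("state_dict", checkpoint)

    # average the RGB kernels of every 3-channel input conv down to 1 channel
    for key, w in list(state_dict.items()):
        if isinstance(w, torch.Tensor) and w.ndim == 4 and w.shape[1] == 3 \
                and ("conv_in" in key or "nin_shortcut" in key):
            state_dict[key] = w.mean(dim=1, keepdim=True)

    # drop anything still shape-incompatible, then load non-strictly
    model_state = model.state_dict()
    state_dict = {k: v for k, v in state_dict.items()
                  if k in model_state and v.shape == model_state[k].shape}
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print(f"loaded {len(state_dict)} tensors; missing {len(missing)}, unexpected {len(unexpected)}")
```

With this ordering, the later duplicated weight-conversion passes in `main()` become redundant and can be removed.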
