Please review the following code and track down its bugs, paying particular attention to the reinforcement learning part. Based on your reasoning, give a correct and sensible set of solution steps (as a text description), and provide the complete code with the corrected logic.
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import random
import argparse
from collections import deque
from torch.distributions import Normal, Categorical
from torch.nn.parallel import DistributedDataParallel as DDP
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
from mmengine.registry import MODELS, DATASETS
from mmengine.config import Config
from rl_seg.datasets.build_dataloader import init_dist_pytorch, build_dataloader
from rl_seg.datasets import load_data_to_gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"Using device: {device}")
# PPO 代理(Actor-Critic 网络)
class PPOAgent(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=256):
super(PPOAgent, self).__init__()
self.state_dim = state_dim
self.action_dim = action_dim
# 共享特征提取层
self.shared_layers = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
# nn.ReLU(),
nn.LayerNorm(hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, hidden_dim),
# nn.ReLU()
nn.LayerNorm(hidden_dim),
nn.GELU(),
)
# Actor 网络 (策略)
self.actor = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
# nn.ReLU(),
nn.GELU(),
nn.Linear(hidden_dim, action_dim),
nn.Tanh() # 输出在[-1,1]范围内
)
# Critic 网络 (值函数)
self.critic = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
# nn.ReLU(),
nn.GELU(),
nn.Linear(hidden_dim, 1)
)
# 动作标准差 (可学习参数)
self.log_std = nn.Parameter(torch.zeros(1, action_dim))
# 初始化权重
self.apply(self._init_weights)
def _init_weights(self, module):
"""初始化网络权重"""
if isinstance(module, nn.Linear):
nn.init.orthogonal_(module.weight, gain=0.01)
nn.init.constant_(module.bias, 0.0)
def forward(self, state):
features = self.shared_layers(state)
action_mean = self.actor(features)
value = self.critic(features)
return action_mean, value
def act(self, state):
"""与环境交互时选择动作"""
state = torch.FloatTensor(state).unsqueeze(0).to(device) # 确保是 [1, state_dim]
        # print(state.shape)  # debug
with torch.no_grad():
action_mean, value = self.forward(state)
# 创建动作分布 (添加最小标准差确保稳定性)
action_std = torch.clamp(self.log_std.exp(), min=0.01, max=1.0)
dist = Normal(action_mean, action_std)
# 采样动作
action = dist.sample()
log_prob = dist.log_prob(action).sum(-1)
return action, log_prob, value
def evaluate(self, state, action):
"""评估动作的概率和值"""
# 添加维度检查
if len(state.shape) == 1:
state = state.unsqueeze(0)
if len(action.shape) == 1:
action = action.unsqueeze(0)
action_mean, value = self.forward(state)
# 创建动作分布
action_std = torch.clamp(self.log_std.exp(), min=0.01, max=1.0)
dist = Normal(action_mean, action_std)
# 计算对数概率和熵
log_prob = dist.log_prob(action).sum(-1)
entropy = dist.entropy().sum(-1)
return log_prob, entropy, value
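
# --- Usage sketch (added for clarity; not part of the original training flow) ---
# A minimal smoke test of PPOAgent.act / PPOAgent.evaluate on random inputs, assuming
# state_dim = 170 (2*64 feature stats + 21 class distribution + 21 per-class IoUs) and
# action_dim = 6. Call it manually to sanity-check tensor shapes before training.
def _ppo_agent_smoke_test(state_dim=170, action_dim=6):
    agent = PPOAgent(state_dim, action_dim).to(device)
    dummy_state = np.random.randn(state_dim).astype(np.float32)
    action, log_prob, value = agent.act(dummy_state)       # action: [1, action_dim]
    states = torch.randn(4, state_dim, device=device)      # fake batch of 4 states
    actions = torch.randn(4, action_dim, device=device)    # fake batch of 4 actions
    log_probs, entropy, values = agent.evaluate(states, actions)
    assert log_probs.shape == (4,) and values.shape == (4, 1)
    return action, log_prob, value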
# 强化学习优化器
class PPOTrainer:
"""PPO训练器,整合了策略优化和模型微调"""
def __init__(self, seg_net, agent, cfg):
"""
Args:
seg_net: 预训练的分割网络
agent: PPO智能体
cfg: 配置对象,包含以下属性:
- lr: 学习率
- clip_param: PPO裁剪参数
- ppo_epochs: PPO更新轮数
- gamma: 折扣因子
- tau: GAE参数
- value_coef: 值函数损失权重
- entropy_coef: 熵正则化权重
- max_grad_norm: 梯度裁剪阈值
"""
self.seg_net = seg_net
self._base_seg_net = seg_net.module if isinstance(seg_net, DDP) else seg_net
        self._base_seg_net.device = next(self.seg_net.parameters()).device  # robust even when seg_net is not DDP-wrapped
self.agent = agent
self.cfg = cfg
self.writer = SummaryWriter(log_dir='runs/ppo_trainer')
# 使用分离的优化器
self.optimizer_seg = optim.AdamW(
self.seg_net.parameters(),
lr=cfg.lr,
weight_decay=1e-4
)
self.optimizer_agent = optim.AdamW(
self.agent.parameters(),
lr=cfg.lr,
weight_decay=1e-4
)
# 训练记录
self.best_miou = 0.0
self.metrics = {
'loss': [],
'reward': [],
'miou': [],
'class_ious': [],
'lr': []
}
def compute_state(self, features, pred, gt_seg):
"""
计算强化学习状态向量
Args:
features: 从extract_features获取的字典包含:
- spatial_features: [B, C1, H, W]
- bev_features: [B, C2, H, W]
- neck_features: [B, C3, H, W]
pred: 网络预测的分割结果 [B, num_classes, H, W]
gt_seg: 真实分割标签 [B, H, W]
Returns:
state: 状态向量 [state_dim]
"""
# 主要使用neck_features作为代表特征 torch.Size([4, 64, 496, 432])
feats = features["neck_features"] # [B, C, H, W]
        # print(feats.shape)  # debug
B, C, H, W = feats.shape
# 初始化状态列表
states = []
# 为批次中每个样本单独计算状态
for i in range(B):
# 特征统计
feat_mean = feats[i].mean(dim=(1, 2)) # [C]
feat_std = feats[i].std(dim=(1, 2)) # [C]
# 预测类别分布
pred_classes = pred[i].argmax(dim=0) # [H, W]
class_dist = torch.bincount(
pred_classes.flatten(),
minlength=21
).float() / (H * W)
            # 各类IoU: 复用模块级 get_miou 对单样本切片计算
            # (原代码调用的 self.compute_sample_iou 并未实现; 这里假定 gt_seg 中的 grid_ind/labels_ori 可按批次索引切片)
            sample_miou, sample_cls_iou = get_miou(
                pred[i:i+1],
                {k: gt_seg[k][i:i+1] for k in ("grid_ind", "labels_ori")},
                classes=range(21)
            )
            sample_cls_iou = torch.FloatTensor(sample_cls_iou).to(feats.device)
# 组合状态
state = torch.cat([
feat_mean,
feat_std,
class_dist,
sample_cls_iou
])
states.append(state)
        states = torch.stack(states)  # [B, state_dim]
        # 必须与PPOAgent的state_dim完全一致
        assert states.shape[1] == self.agent.state_dim, \
            f"State dim mismatch: {states.shape[1]} != {self.agent.state_dim}"
        return states
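    # Dimension note (added): with the neck_features shape quoted above ([4, 64, 496, 432], i.e. C = 64),
    # each per-sample state is feat_mean (64) + feat_std (64) + class_dist (21) + per-class IoU (21) = 170,
    # so PPOAgent must be built with state_dim = 2*C + 21 + 21 for the assert above to pass.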
def compute_reward(self, miou, prev_miou, class_ious, prev_class_ious):
"""
计算复合奖励函数
Args:
miou: 当前mIoU
prev_miou: 前一次mIoU
class_ious: 当前各类IoU [num_classes]
prev_class_ious: 前一次各类IoU [num_classes]
Returns:
reward: 综合奖励值
"""
# 基础奖励: mIoU提升
miou_reward = 10.0 * (miou - prev_miou)
# 类别平衡奖励: 鼓励所有类别均衡提升
class_reward = 0.0
for cls, (iou, prev_iou) in enumerate(zip(class_ious, prev_class_ious)):
if iou > prev_iou:
# 对稀有类别给予更高奖励
weight = 1.0 + (1.0 - prev_iou) # 性能越差的类权重越高
class_reward += weight * (iou - prev_iou)
# 惩罚项: 防止某些类别性能严重下降
penalty = 0.0
# for cls in range(21):
# if class_ious[cls] < prev_class_ious[cls] * 0.8:
# penalty += 5.0 * (prev_class_ious[cls] - class_ious[cls])
for cls, (iou, prev_iou) in enumerate(zip(class_ious, prev_class_ious)):
if iou < prev_iou * 0.9: # 性能下降超过10%
penalty += 5.0 * (prev_iou - iou)
total_reward = miou_reward + class_reward - penalty
return np.clip(total_reward, -5.0, 10.0) # 限制奖励范围
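    # Worked example (hypothetical numbers, added for clarity): if mIoU goes 0.62 -> 0.65, one class
    # improves 0.30 -> 0.40 and another drops 0.50 -> 0.40 (more than a 10% relative drop), then
    #   miou_reward  = 10.0 * (0.65 - 0.62)             = 0.30
    #   class_reward = (1 + (1 - 0.30)) * (0.40 - 0.30) = 0.17
    #   penalty      = 5.0 * (0.50 - 0.40)              = 0.50
    # total ≈ 0.30 + 0.17 - 0.50 = -0.03, which is then clipped to [-5.0, 10.0].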
def apply_action(self, action):
"""
应用智能体动作调整模型参数
Args:
action: [6] 连续动作向量,范围[-1, 1]
"""
# 动作0-1: 调整学习率
lr_scale = 0.1 + 0.9 * (action[0] + 1) / 2 # 映射到[0.1, 1.0]
        for param_group in self.optimizer_seg.param_groups:
param_group['lr'] *= lr_scale
# 动作2-3: 调整特征提取层权重 (范围[0.8, 1.2])
backbone_scale = 0.8 + 0.2 * (action[2] + 1) / 2
with torch.no_grad():
for param in self.seg_net.module.backbone_2d.parameters():
param.data *= backbone_scale # (0.9 + 0.1 * action[2]) # 调整范围[0.9,1.1]
# 动作4-5: 调整分类头权重
head_scale = 0.8 + 0.2 * (action[4] + 1) / 2
with torch.no_grad():
for param in self.seg_net.module.at_seg_head.parameters():
param.data *= head_scale # (0.9 + 0.1 * action[4]) # 调整范围[0.9,1.1]
def train_epoch(self, train_loader, epoch):
"""执行一个训练周期"""
epoch_metrics = {
'seg_loss': 0.0,
'reward': 0.0,
'miou': 0.0,
'class_ious': np.zeros(21),
'policy_loss': 0.0,
'value_loss': 0.0,
'entropy_loss': 0.0,
'batch_count': 0
}
self.seg_net.train()
self.agent.train()
for data_dicts in tqdm(train_loader, desc=f"RL Epoch {epoch+1}/{self.cfg.num_epochs_rl}"):
load_data_to_gpu(data_dicts)
# 初始预测和特征
with torch.no_grad():
initial_pred = self.seg_net(data_dicts)
initial_miou, initial_class_ious = get_miou(
initial_pred,
data_dicts,
classes=range(21)
)
features = self.seg_net.module.extract_features(data_dicts) # DDP包装了
# features = self._base_seg_net.extract_features(data_dicts)
# 计算初始状态
states = self.compute_state(features, initial_pred, data_dicts)
# 为批次中每个样本选择动作
actions, log_probs, values = [], [], []
            for state in states:
                action, log_prob, value = self.agent.act(state.cpu().numpy())
                actions.append(action.squeeze(0))   # [action_dim]
                log_probs.append(log_prob)
                values.append(value.squeeze())
            # 应用第一个样本的动作 (简化处理); apply_action 接收长度为6的一维数组
            self.apply_action(actions[0].cpu().numpy())
# 调整后的预测
adjusted_pred = self.seg_net(data_dicts)
adjusted_miou, adjusted_class_ious = get_miou(
adjusted_pred,
data_dicts,
classes=range(21)
)
# 计算奖励 (使用整个批次的平均改进)
reward = self.compute_reward(
adjusted_miou,
initial_miou,
adjusted_class_ious,
initial_class_ious
)
            # 计算优势 (单步优势: reward - V(s))
            advantages = [reward - v.item() for v in values]
            # 存储经验 (统一转成numpy数组, 便于ppo_update重新构造张量)
            experience = {
                'states': states.cpu().numpy(),
                'actions': torch.stack(actions).cpu().numpy(),
                'rewards': [reward] * len(actions),
                'old_log_probs': torch.stack(log_probs).flatten().cpu().numpy(),
                'old_values': torch.stack(values).flatten().cpu().numpy(),
                'advantages': advantages,
            }
# PPO优化
policy_loss, value_loss, entropy_loss = self.ppo_update(experience)
# 分割网络损失
seg_loss = self.seg_net.module.at_seg_head.get_loss(
adjusted_pred,
data_dicts
)
# 分割网络更新 (使用单独优化器)
self.optimizer_seg.zero_grad()
seg_loss.backward()
torch.nn.utils.clip_grad_norm_(
self.seg_net.parameters(),
self.cfg.max_grad_norm
)
self.optimizer_seg.step()
# 记录指标
epoch_metrics['seg_loss'] += seg_loss.item()
epoch_metrics['reward'] += reward
epoch_metrics['miou'] += adjusted_miou
epoch_metrics['class_ious'] += adjusted_class_ious
epoch_metrics['policy_loss'] += policy_loss
epoch_metrics['value_loss'] += value_loss
epoch_metrics['entropy_loss'] += entropy_loss
epoch_metrics['batch_count'] += 1
# 计算平均指标
avg_metrics = {}
for k in epoch_metrics:
if k != 'batch_count':
avg_metrics[k] = epoch_metrics[k] / epoch_metrics['batch_count']
# 记录到TensorBoard
self.writer.add_scalar('Loss/seg_loss', avg_metrics['seg_loss'], epoch)
self.writer.add_scalar('Reward/total', avg_metrics['reward'], epoch)
self.writer.add_scalar('mIoU/train', avg_metrics['miou'], epoch)
self.writer.add_scalar('Loss/policy', avg_metrics['policy_loss'], epoch)
self.writer.add_scalar('Loss/value', avg_metrics['value_loss'], epoch)
self.writer.add_scalar('Loss/entropy', avg_metrics['entropy_loss'], epoch)
return avg_metrics
def ppo_update(self, experience):
"""
PPO策略优化步骤
Args:
            experience: 包含以下键的字典:
- states: [batch_size, state_dim]
- actions: [batch_size, action_dim]
- old_log_probs: [batch_size]
- old_values: [batch_size]
- rewards: [batch_size]
- advantages: [batch_size]
Returns:
policy_loss: 策略损失值
value_loss: 值函数损失值
entropy_loss: 熵损失值
"""
        states = torch.FloatTensor(experience['states']).to(device)                # [B, state_dim]
        actions = torch.FloatTensor(experience['actions']).to(device)              # [B, action_dim]
        old_log_probs = torch.FloatTensor(experience['old_log_probs']).to(device)  # [B]
        old_values = torch.FloatTensor(experience['old_values']).to(device)        # [B]
        rewards = torch.FloatTensor(experience['rewards']).to(device)              # [B]
        advantages = torch.FloatTensor(experience['advantages']).to(device)        # [B] 单步优势 (当前实现并非GAE)
policy_losses, value_losses, entropy_losses = [], [], []
for _ in range(self.cfg.ppo_epochs):
# 评估当前策略
log_probs, entropy, values = self.agent.evaluate(states, actions)
# 比率
ratios = torch.exp(log_probs - old_log_probs)
# 裁剪目标
surr1 = ratios * advantages
surr2 = torch.clamp(ratios,
1.0 - self.cfg.clip_param,
1.0 + self.cfg.clip_param) * advantages
# 策略损失
policy_loss = -torch.min(surr1, surr2).mean()
# 值函数损失
            value_loss = 0.5 * (values.squeeze(-1) - rewards).pow(2).mean()  # critic输出[B,1], 先压成[B]再与rewards对齐
# 熵损失
entropy_loss = -entropy.mean()
# 总损失
loss = policy_loss + self.cfg.value_coef * value_loss + self.cfg.entropy_coef * entropy_loss
# 智能体参数更新
self.optimizer_agent.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(
self.agent.parameters(),
self.cfg.max_grad_norm
)
self.optimizer_agent.step()
policy_losses.append(policy_loss.item())
value_losses.append(value_loss.item())
entropy_losses.append(entropy_loss.item())
return (
np.mean(policy_losses),
np.mean(value_losses),
np.mean(entropy_losses)
)
def close(self):
"""关闭资源"""
self.writer.close()
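
# --- Hedged configuration sketch (added; field names mirror the PPOTrainer docstring, values are
# illustrative only). PPOTrainer only reads plain attributes from cfg, so a SimpleNamespace (or the
# mmengine Config used elsewhere in this file) carrying these hyperparameters is enough for a dry run.
from types import SimpleNamespace

ppo_cfg_example = SimpleNamespace(
    lr=1e-4,            # learning rate shared by the seg-net and agent optimizers
    clip_param=0.2,     # PPO clipping epsilon
    ppo_epochs=4,       # PPO update passes per batch of experience
    gamma=0.99,         # discount factor (unused by the current single-step advantage)
    tau=0.95,           # GAE lambda (likewise unused unless GAE is added)
    value_coef=0.5,     # weight of the value-function loss
    entropy_coef=0.01,  # weight of the entropy bonus
    max_grad_norm=1.0,  # gradient clipping threshold
    num_epochs_rl=10,   # RL fine-tuning epochs (read by train_epoch's progress bar)
)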
# 监督学习预训练
def supervised_pretrain(cfg):
seg_net = MODELS.build(cfg.model).to('cuda')
seg_head = MODELS.build(cfg.model.at_seg_head).to('cuda')
if cfg.pretrained_path:
ckpt = torch.load(cfg.pretrained_path)
print(ckpt.keys())
seg_net.load_state_dict(ckpt['state_dict'])
print(f'Load pretrained ckpt: {cfg.pretrained_path}')
seg_net = DDP(seg_net, device_ids=[cfg.local_rank])
print(seg_net)
return seg_net
optimizer = optim.Adam(seg_net.parameters(), lr=cfg.lr)
writer = SummaryWriter(log_dir='runs/pretrain')
train_losses = []
train_mious = []
train_class_ious = [] # 存储每个epoch的各类IoU
for epoch in range(cfg.num_epochs):
cfg.sampler.set_epoch(epoch)
epoch_loss = 0.0
epoch_miou = 0.0
epoch_class_ious = np.zeros(21) # 初始化各类IoU累加器
batch_count = 0
seg_net.train()
for data_dicts in tqdm(cfg.train_loader, desc=f"Pretrain Epoch {epoch+1}/{cfg.num_epochs}"):
optimizer.zero_grad()
pred = seg_net(data_dicts)
device = pred.device
seg_head = seg_head.to(device)
loss = seg_head.get_loss(pred, data_dicts["gt_seg"].to(device))
loss.backward()
optimizer.step()
epoch_loss += loss.item()
# import pdb;pdb.set_trace()
# 计算mIoU
class_ious = []
batch_miou, cls_iou = get_miou(pred, data_dicts, classes=[i for i in range(21)])
# for cls in range(5):
# pred_mask = (pred.argmax(dim=1) == cls)
# true_mask = (labels == cls)
# intersection = (pred_mask & true_mask).sum().float()
# union = (pred_mask | true_mask).sum().float()
# iou = intersection / (union + 1e-8)
# class_ious.append(iou.item())
epoch_miou += batch_miou
epoch_class_ious += np.array(cls_iou) # 累加各类IoU
batch_count += 1
# avg_loss = epoch_loss / len(cfg.dataloader)
# avg_miou = epoch_miou / len(cfg.dataloader)
# 计算epoch平均指标
avg_loss = epoch_loss / batch_count if batch_count > 0 else 0.0
avg_miou = epoch_miou / batch_count if batch_count > 0 else 0.0
avg_class_ious = epoch_class_ious / batch_count if batch_count > 0 else np.zeros(21)
train_losses.append(avg_loss)
train_mious.append(avg_miou)
train_class_ious.append(avg_class_ious) # 存储各类IoU
# 记录到TensorBoard
writer.add_scalar('Loss/train', avg_loss, epoch)
writer.add_scalar('mIoU/train', avg_miou, epoch)
for cls, iou in enumerate(avg_class_ious):
writer.add_scalar(f'IoU/class_{cls}', iou, epoch)
print(f"Epoch {epoch+1}/{cfg.num_epochs} - Loss: {avg_loss:.3f}, mIoU: {avg_miou*100:.3f}")
print("Class IoUs:")
for cls, iou in enumerate(avg_class_ious):
print(f" {cfg.class_names[cls]}: {iou*100:.3f}")
# # 保存预训练模型
torch.save(seg_net.state_dict(), "polarnet_pretrained.pth")
writer.close()
# 绘制训练曲线
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(train_losses)
plt.title("Supervised Training Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.subplot(1, 2, 2)
plt.plot(train_mious)
plt.title("Supervised Training mIoU")
plt.xlabel("Epoch")
plt.ylabel("mIoU")
plt.tight_layout()
plt.savefig("supervised_training.png")
return seg_net
# 强化学习微调
def rl_finetune(cfg):
    # 状态维度 = 特征统计(2*C) + 类别分布(21) + 各类IoU(21)
    # compute_state 按 neck_features 的通道数统计均值/方差 (注释示例 [4, 64, 496, 432] 即 C=64),
    # 若主干网络的 neck 通道数不同, 这里需同步修改
    state_dim = 64 * 2 + 21 + 21
    action_dim = 6  # 6个连续动作; action[0]调整学习率, action[2]调整特征提取层权重, action[4]调整分类头权重
# 初始化PPO智能体
agent = PPOAgent(state_dim, action_dim).to(device)
if cfg.agent_path:
agent.load_state_dict(torch.load(cfg.agent_path))
trainer = PPOTrainer(cfg.seg_net, agent, cfg)
train_losses = []
train_rewards = []
train_mious = []
# 训练循环
for epoch in range(cfg.num_epochs_rl):
avg_metrics = trainer.train_epoch(cfg.train_loader, epoch)
# 记录指标
train_losses.append(avg_metrics['seg_loss'])
train_rewards.append(avg_metrics['reward'])
train_mious.append(avg_metrics['miou'])
# trainer.metrics['loss'].append(avg_metrics['seg_loss'])
# trainer.metrics['reward'].append(avg_metrics['reward'])
# trainer.metrics['miou'].append(avg_metrics['miou'])
# trainer.metrics['class_ious'].append(avg_metrics['class_ious'])
# trainer.metrics['lr'].append(trainer.optimizer.param_groups[0]['lr'])
# 保存最佳模型
if avg_metrics['miou'] > trainer.best_miou:
trainer.best_miou = avg_metrics['miou']
torch.save(cfg.seg_net.state_dict(), "polarnet_rl_best.pth")
torch.save(agent.state_dict(), "ppo_agent_best.pth")
np.savetxt("best_class_ious.txt", avg_metrics['class_ious'])
# 打印日志
print(f"\nRL Epoch {epoch+1}/{cfg.num_epochs_rl} Results:")
print(f" Seg Loss: {avg_metrics['seg_loss']:.4f}")
print(f" Reward: {avg_metrics['reward']:.4f}")
print(f" mIoU: {avg_metrics['miou']*100:.3f} (Best: {trainer.best_miou*100:.3f})")
print(f" Policy Loss: {avg_metrics['policy_loss']:.4f}")
print(f" Value Loss: {avg_metrics['value_loss']:.4f}")
print(f" Entropy Loss: {avg_metrics['entropy_loss']:.4f}")
print(f" Learning Rate: {trainer.optimizer.param_groups[0]['lr']:.2e}")
print(" Class IoUs:")
for cls, iou in enumerate(avg_metrics['class_ious']):
print(f" {cfg.class_names[cls]}: {iou:.4f}")
# 保存最终模型和训练记录
torch.save(cfg.seg_net.state_dict(), "polarnet_rl_final.pth")
torch.save(agent.state_dict(), "ppo_agent_final.pth")
np.savetxt("training_metrics.txt", **trainer.metrics)
print(f"\nTraining completed. Best mIoU: {trainer.best_miou:.4f}")
trainer.close()
# 绘制训练曲线
plt.figure(figsize=(15, 10))
plt.subplot(2, 2, 1)
plt.plot(train_losses)
plt.title("RL Training Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.subplot(2, 2, 2)
plt.plot(train_rewards)
plt.title("Average Reward")
plt.xlabel("Epoch")
plt.ylabel("Reward")
plt.subplot(2, 2, 3)
plt.plot(train_mious)
plt.title("RL Training mIoU")
plt.xlabel("Epoch")
plt.ylabel("mIoU")
plt.subplot(2, 2, 4)
plt.plot(train_losses, label='Loss')
plt.plot(train_mious, label='mIoU')
plt.title("Loss vs mIoU")
plt.xlabel("Epoch")
plt.legend()
plt.tight_layout()
plt.savefig("rl_training.png")
return cfg.seg_net, agent
# 模型评估
def evaluate_model(cfg):
cfg.seg_net.eval()
avg_miou = 0.0
total_miou = 0.0
class_ious = np.zeros(21)
batch_count = 0 # 记录实际处理的batch数量
with torch.no_grad():
for data_dicts in tqdm(cfg.val_loader, desc="Evaluating"):
pred = cfg.seg_net(data_dicts)
batch_miou, cls_iou = get_miou(pred, data_dicts, classes=[i for i in range(21)])
total_miou += batch_miou
class_ious += cls_iou
batch_count += 1
# avg_miou = total_miou / len(cfg.dataloader)
# class_ious /= len(cfg.dataloader)
# 计算平均值
avg_miou = total_miou / batch_count if batch_count > 0 else 0.0
class_ious = class_ious / batch_count if batch_count > 0 else np.zeros(21)
print("\nEvaluation Results:")
print(f"Overall mIoU: {avg_miou*100:.3f}")
for cls, iou in enumerate(class_ious):
print(f" {cfg.class_names[cls]}: {iou*100:.3f}")
return avg_miou, class_ious
def fast_hist(pred, label, n):
k = (label >= 0) & (label < n)
bin_count = np.bincount(n * label[k].astype(int) + pred[k], minlength=n**2)
return bin_count[: n**2].reshape(n, n)
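# Toy example (added): with n=3 classes, pred=[0,1,1,2] and label=[0,1,2,2], fast_hist returns
#   [[1, 0, 0],
#    [0, 1, 0],
#    [0, 1, 1]]
# (rows = ground-truth class, columns = predicted class). get_miou's per-class IoU then follows as
# diag / (row_sum + col_sum - diag) = [1.0, 0.5, 0.5].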
def fast_hist_crop(output, target, unique_label):
hist = fast_hist(
output.flatten(), target.flatten(), np.max(unique_label) + 1
)
hist = hist[unique_label, :]
hist = hist[:, unique_label]
return hist
def compute_miou_test(y_true, y_pred):
from sklearn.metrics import confusion_matrix
current = confusion_matrix(y_true, y_pred)
intersection = np.diag(current)
gt = current.sum(axis=1)
pred = current.sum(axis=0)
union = gt + pred - intersection
    iou_list = intersection / (union.astype(np.float32) + 1e-8)
return np.mean(iou_list), iou_list
def get_miou(pred, target, classes=[i for i in range(21)]):
# import pdb;pdb.set_trace()
gt_val_grid_ind = target["grid_ind"]
gt_val_pt_labs = target["labels_ori"]
pred_labels = torch.argmax(pred, dim=1).cpu().detach().numpy()
metric_data = []
miou_list = []
for bs, i_val_grid in enumerate(gt_val_grid_ind):
val_grid_idx = pred_labels[
bs, i_val_grid[:, 1], i_val_grid[:, 0], i_val_grid[:, 2]
] # (N,)
gt_val_pt_lab_idx = gt_val_pt_labs[bs] #(N,1)
hist = fast_hist_crop(
val_grid_idx, gt_val_pt_lab_idx, classes
) # (21, 21)
hist_tensor = torch.from_numpy(hist).to(pred.device)
metric_data.append(hist_tensor)
# miou, iou_dict = compute_miou_test(gt_val_pt_lab_idx, val_grid_idx)
# miou_list.append(miou)
hist = sum(metric_data).cpu().numpy()
iou_overall = np.diag(hist) / ((hist.sum(1) + hist.sum(0) - np.diag(hist)) + 1e-6)
miou = np.nanmean(iou_overall)
# print(metric_data)
# print(iou_overall)
# print(miou)
# print(miou_list, np.nanmean(miou_list))
# import pdb;pdb.set_trace()
return miou, iou_overall
# 主函数
def main(args):
# 第一阶段:监督学习预训练
print("="*50)
print("Starting Supervised Pretraining...")
print("="*50)
cfg_file = "rl_seg/configs/rl_seg_leap.py"
cfg = Config.fromfile(cfg_file)
    print(cfg.keys())
total_gpus, LOCAL_RANK = init_dist_pytorch(
tcp_port=18888, local_rank=0, backend='nccl'
)
cfg.local_rank = LOCAL_RANK
dist_train = True
train_dataset, train_dataloader, sampler = build_dataloader(dataset_cfg=cfg,
data_path=cfg.train_data_path,
workers=cfg.num_workers,
samples_per_gpu=cfg.batch_size,
num_gpus=cfg.num_gpus,
dist=dist_train,
pipeline=cfg.train_pipeline,
training=True)
cfg.train_loader = train_dataloader
cfg.sampler = sampler
seg_net = supervised_pretrain(cfg)
val_dataset, val_dataloader, sampler = build_dataloader(dataset_cfg=cfg,
data_path=cfg.val_data_path,
workers=cfg.num_workers,
samples_per_gpu=cfg.batch_size,
num_gpus=cfg.num_gpus,
dist=True,
pipeline=cfg.val_pipeline,
training=False)
cfg.val_loader = val_dataloader
cfg.sampler = sampler
cfg.seg_net = seg_net
# 评估预训练模型
print("\nEvaluating Pretrained Model...")
pretrain_miou, pretrain_class_ious = evaluate_model(cfg)
# 第二阶段:强化学习微调
print("\n" + "="*50)
print("Starting RL Finetuning...")
print("="*50)
seg_net, ppo_agent = rl_finetune(cfg)
# 评估强化学习优化后的模型
print("\nEvaluating RL Optimized Model...")
rl_miou, rl_class_ious = evaluate_model(cfg)
# 结果对比
print("\nPerformance Comparison:")
print(f"Pretrained mIoU: {pretrain_miou*100:.3f}")
print(f"RL Optimized mIoU: {rl_miou*100:.3f}")
print(f"Improvement: {(rl_miou - pretrain_miou)*100:.3f} ({((rl_miou - pretrain_miou)/pretrain_miou)*100:.2f}%)")
# 绘制各类别IoU对比
plt.figure(figsize=(10, 6))
    x = np.arange(len(cfg.class_names))  # 21个类别
width = 0.35
plt.bar(x - width/2, pretrain_class_ious, width, label='Pretrained')
plt.bar(x + width/2, rl_class_ious, width, label='RL Optimized')
plt.xticks(x, cfg.class_names)
plt.ylabel("IoU")
plt.title("Per-Class IoU Comparison")
plt.legend()
plt.tight_layout()
plt.savefig("class_iou_comparison.png")
print("\nTraining completed successfully!")
if __name__ == "__main__":
def args_config():
parser = argparse.ArgumentParser(description='arg parser')
parser.add_argument('--cfg_file', type=str, default="rl_seg/configs/rl_seg_leap.py",
help='specify the config for training')
parser.add_argument('--batch_size', type=int, default=16, required=False, help='batch size for training')
parser.add_argument('--epochs', type=int, default=20, required=False, help='number of epochs to train for')
parser.add_argument('--workers', type=int, default=10, help='number of workers for dataloader')
parser.add_argument('--extra_tag', type=str, default='default', help='extra tag for this experiment')
parser.add_argument('--ckpt', type=str, default=None, help='checkpoint to start from')
parser.add_argument('--pretrained_model', type=str, default=None, help='pretrained_model')
return parser.parse_args()
args = args_config()
main(args)