卡尔曼滤波实现车辆精准跟踪

import cv2
import numpy as np
import os
import glob
from collections import defaultdict
import random


class KalmanTracker:
    """基于卡尔曼滤波的单车辆跟踪器"""

    def __init__(self, initial_bbox, track_id):
        # 初始化卡尔曼滤波器
        self.kalman = cv2.KalmanFilter(8, 4)  # 8个状态变量,4个观测变量
        self.track_id = track_id  # 唯一的跟踪ID

        # 状态转移矩阵 (x, y, w, h, dx, dy, dw, dh)
        self.kalman.transitionMatrix = np.array([
            [1, 0, 0, 0, 1, 0, 0, 0],
            [0, 1, 0, 0, 0, 1, 0, 0],
            [0, 0, 1, 0, 0, 0, 1, 0],
            [0, 0, 0, 1, 0, 0, 0, 1],
            [0, 0, 0, 0, 1, 0, 0, 0],
            [0, 0, 0, 0, 0, 1, 0, 0],
            [0, 0, 0, 0, 0, 0, 1, 0],
            [0, 0, 0, 0, 0, 0, 0, 1]
        ], dtype=np.float32)

        # 观测矩阵
        self.kalman.measurementMatrix = np.array([
            [1, 0, 0, 0, 0, 0, 0, 0],
            [0, 1, 0, 0, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0, 0, 0],
            [0, 0, 0, 1, 0, 0, 0, 0]
        ], dtype=np.float32)

        # 过程噪声协方差矩阵 - 针对车辆运动优化
        self.kalman.processNoiseCov = np.eye(8, dtype=np.float32) * 0.03

        # 测量噪声协方差矩阵 - 降低噪声,提高平滑度
        self.kalman.measurementNoiseCov = np.eye(4, dtype=np.float32) * 0.05

        # 初始状态
        x, y, w, h = initial_bbox
        self.kalman.statePre = np.array([x, y, w, h, 0, 0, 0, 0], dtype=np.float32)
        self.kalman.statePost = np.array([x, y, w, h, 0, 0, 0, 0], dtype=np.float32)

        # 初始协方差矩阵
        self.kalman.errorCovPre = np.eye(8, dtype=np.float32)
        self.kalman.errorCovPost = np.eye(8, dtype=np.float32)

        self.age = 0
        self.total_visible_count = 0
        self.consecutive_invisible_count = 0

    def predict(self):
        """预测下一帧的位置"""
        prediction = self.kalman.predict()
        return prediction[:4]  # 返回 (x, y, w, h)

    def update(self, bbox):
        """使用观测值更新卡尔曼滤波器"""
        measurement = np.array(bbox, dtype=np.float32)
        self.kalman.correct(measurement)
        self.age += 1
        self.total_visible_count += 1
        self.consecutive_invisible_count = 0

    def mark_missed(self):
        """标记未检测到目标"""
        self.age += 1
        self.consecutive_invisible_count += 1


class VehicleTracker:
    """车辆多目标跟踪器"""

    def __init__(self):
        self.trackers = {}  # 使用字典存储,key为track_id
        self.next_id = 1  # 从1开始分配ID
        self.trajectories = {}  # 每个车辆的独立轨迹,key为track_id
        self.colors = {}  # 每个车辆的颜色,key为track_id
        self.max_disappeared = 30  # 最大消失帧数
        self.min_hits = 3  # 最小命中次数

    def generate_color(self):
        """生成随机颜色"""
        return (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255))

    def is_valid_vehicle_movement(self, old_center, new_center, max_speed=100):
        """验证车辆移动是否合理,过滤异常跳跃"""
        if old_center is None:
            return True

        distance = np.sqrt((new_center[0] - old_center[0]) ** 2 +
                           (new_center[1] - old_center[1]) ** 2)
        return distance <= max_speed

    def calculate_iou(self, box1, box2):
        """计算两个边界框的IoU"""
        x1, y1, w1, h1 = box1
        x2, y2, w2, h2 = box2

        # 计算交集
        xi1 = max(x1, x2)
        yi1 = max(y1, y2)
        xi2 = min(x1 + w1, x2 + w2)
        yi2 = min(y1 + h1, y2 + h2)

        if xi2 <= xi1 or yi2 <= yi1:
            return 0

        inter_area = (xi2 - xi1) * (yi2 - yi1)
        box1_area = w1 * h1
        box2_area = w2 * h2
        union_area = box1_area + box2_area - inter_area

        return inter_area / union_area if union_area > 0 else 0

    def update(self, detections):
        """更新跟踪器"""
        # 预测所有跟踪器的位置
        predictions = {}
        for track_id, tracker in self.trackers.items():
            pred = tracker.predict()
            predictions[track_id] = pred

        # 计算IoU矩阵
        track_ids = list(predictions.keys())
        iou_matrix = np.zeros((len(track_ids), len(detections)))

        for i, track_id in enumerate(track_ids):
            for j, det in enumerate(detections):
                iou_matrix[i, j] = self.calculate_iou(predictions[track_id], det)

        # 简单的贪婪匹配
        matched_pairs = []
        iou_threshold = 0.3
        used_detections = set()
        used_trackers = set()

        # 按IoU从大到小排序进行匹配
        matches = []
        for i in range(len(track_ids)):
            for j in range(len(detections)):
                if iou_matrix[i, j] > iou_threshold:
                    matches.append((iou_matrix[i, j], i, j))

        matches.sort(reverse=True)  # 按IoU降序排列

        for iou, i, j in matches:
            if i not in used_trackers and j not in used_detections:
                track_id = track_ids[i]
                matched_pairs.append((track_id, j))
                used_trackers.add(i)
                used_detections.add(j)

        # 更新匹配的跟踪器
        for track_id, detection_idx in matched_pairs:
            self.trackers[track_id].update(detections[detection_idx])

            # 计算车辆中心点
            center_x = detections[detection_idx][0] + detections[detection_idx][2] // 2
            center_y = detections[detection_idx][1] + detections[detection_idx][3] // 2
            new_center = (center_x, center_y)

            # 确保该车辆有独立的轨迹列表
            if track_id not in self.trajectories:
                self.trajectories[track_id] = []

            # 验证移动是否合理(避免异常跳跃)
            old_center = self.trajectories[track_id][-1] if self.trajectories[track_id] else None
            if self.is_valid_vehicle_movement(old_center, new_center):
                self.trajectories[track_id].append(new_center)
            else:
                # 如果移动异常,使用预测位置
                pred_bbox = self.trackers[track_id].predict()
                pred_center = (int(pred_bbox[0] + pred_bbox[2] // 2),
                               int(pred_bbox[1] + pred_bbox[3] // 2))
                self.trajectories[track_id].append(pred_center)

            # 限制轨迹长度(保留最近60个点,足够显示流畅轨迹)
            if len(self.trajectories[track_id]) > 60:
                self.trajectories[track_id] = self.trajectories[track_id][-60:]

        # 标记未匹配的跟踪器
        for i, track_id in enumerate(track_ids):
            if i not in used_trackers:
                self.trackers[track_id].mark_missed()

        # 为未匹配的检测创建新的跟踪器
        for j in range(len(detections)):
            if j not in used_detections:
                # 创建新的车辆跟踪器
                new_id = self.next_id
                self.next_id += 1

                new_tracker = KalmanTracker(detections[j], new_id)
                self.trackers[new_id] = new_tracker

                # 为新车辆分配颜色
                self.colors[new_id] = self.generate_color()

                # 初始化新车辆的轨迹
                center_x = detections[j][0] + detections[j][2] // 2
                center_y = detections[j][1] + detections[j][3] // 2
                self.trajectories[new_id] = [(center_x, center_y)]

        # 删除长时间未检测到的跟踪器
        trackers_to_remove = []
        for track_id, tracker in self.trackers.items():
            if tracker.consecutive_invisible_count >= self.max_disappeared:
                trackers_to_remove.append(track_id)

        # 清理失效的跟踪器及其相关数据
        for track_id in trackers_to_remove:
            del self.trackers[track_id]
            if track_id in self.trajectories:
                del self.trajectories[track_id]
            if track_id in self.colors:
                del self.colors[track_id]

    def get_active_tracks(self):
        """获取当前活跃的跟踪结果"""
        active_tracks = []
        for track_id, tracker in self.trackers.items():
            if (tracker.total_visible_count >= self.min_hits and
                    tracker.consecutive_invisible_count <= 1):
                pred = tracker.predict()
                active_tracks.append((track_id, pred))
        return active_tracks

    def smooth_trajectory(self, trajectory, window_size=5):
        """对轨迹进行平滑处理,使车辆轨迹更流畅"""
        if len(trajectory) < window_size:
            return trajectory

        smoothed = []
        for i in range(len(trajectory)):
            # 计算滑动窗口的平均位置
            start_idx = max(0, i - window_size // 2)
            end_idx = min(len(trajectory), i + window_size // 2 + 1)

            window_points = trajectory[start_idx:end_idx]
            avg_x = sum(p[0] for p in window_points) / len(window_points)
            avg_y = sum(p[1] for p in window_points) / len(window_points)

            smoothed.append((int(avg_x), int(avg_y)))

        return smoothed


def detect_vehicles(frame, background_subtractor):
    """检测车辆,返回检测结果和前景掩码"""
    # 高斯模糊预处理,减少噪声
    blurred = cv2.GaussianBlur(frame, (5, 5), 0)

    # 背景减除
    fg_mask = background_subtractor.apply(blurred)

    # 去除阴影
    fg_mask[fg_mask == 127] = 0  # 127是阴影像素值

    # 形态学操作去噪 - 针对车辆优化
    kernel_open = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    fg_mask = cv2.morphologyEx(fg_mask, cv2.MORPH_OPEN, kernel_open)

    kernel_close = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
    fg_mask = cv2.morphologyEx(fg_mask, cv2.MORPH_CLOSE, kernel_close)

    # 填充空洞
    kernel_fill = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    fg_mask = cv2.morphologyEx(fg_mask, cv2.MORPH_CLOSE, kernel_fill)

    # 查找轮廓
    contours, _ = cv2.findContours(fg_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    detections = []
    min_area = 1000  # 车辆最小面积
    max_area = 40000  # 车辆最大面积
    min_aspect_ratio = 0.4  # 最小宽高比
    max_aspect_ratio = 3.5  # 最大宽高比

    for contour in contours:
        area = cv2.contourArea(contour)
        if min_area < area < max_area:
            x, y, w, h = cv2.boundingRect(contour)
            aspect_ratio = w / h if h > 0 else 0

            # 过滤掉不像车辆的检测
            if min_aspect_ratio < aspect_ratio < max_aspect_ratio:
                # 添加边界检查,避免边缘检测
                if x > 10 and y > 10 and x + w < frame.shape[1] - 10 and y + h < frame.shape[0] - 10:
                    detections.append((x, y, w, h))

    return detections, fg_mask


def draw_vehicle_info(frame, track_id, bbox, color, trajectory, vehicle_tracker):
    """绘制单个车辆的信息"""
    x, y, w, h = [int(v) for v in bbox]

    # 画边界框
    cv2.rectangle(frame, (x, y), (x + w, y + h), color, 2)

    # 画车辆ID
    label = f'Car {track_id}'
    label_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
    cv2.rectangle(frame, (x, y - label_size[1] - 10),
                  (x + label_size[0], y), color, -1)
    cv2.putText(frame, label, (x, y - 5),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

    # 绘制该车辆的平滑轨迹
    if len(trajectory) > 1:
        # 对轨迹进行平滑处理
        smoothed_trajectory = vehicle_tracker.smooth_trajectory(trajectory)

        # 绘制平滑轨迹线 - 只连接该车辆自己的轨迹点
        for i in range(1, len(smoothed_trajectory)):
            # 根据轨迹点的新旧程度调整线条粗细
            line_thickness = max(1, int(3 * (i / len(smoothed_trajectory))))
            cv2.line(frame, smoothed_trajectory[i - 1], smoothed_trajectory[i], color, line_thickness)

        # 绘制关键轨迹点
        key_points = smoothed_trajectory[-20::3]  # 每3个点取一个,显示最近20个点
        for i, point in enumerate(key_points):
            alpha = 0.4 + 0.6 * (i / len(key_points))
            point_color = tuple(int(c * alpha) for c in color)
            radius = 2 + int(2 * alpha)  # 越新的点越大
            cv2.circle(frame, point, radius, point_color, -1)


def main():
    # 视频文件路径
    video_dir = r"D:\BaiduNetdiskDownload"

    # 查找视频文件
    video_extensions = ['*.mp4', '*.avi', '*.mov', '*.mkv', '*.flv', '*.wmv']
    video_files = []

    for ext in video_extensions:
        video_files.extend(glob.glob(os.path.join(video_dir, ext)))

    if not video_files:
        print(f"在 {video_dir} 中未找到视频文件")
        print("支持的格式: mp4, avi, mov, mkv, flv, wmv")
        return

    print("找到的视频文件:")
    for i, video_file in enumerate(video_files):
        print(f"{i}: {os.path.basename(video_file)}")

    # 选择视频文件
    if len(video_files) == 1:
        selected_video = video_files[0]
    else:
        try:
            choice = int(input("请选择视频文件编号: "))
            selected_video = video_files[choice]
        except (ValueError, IndexError):
            print("无效选择,使用第一个视频文件")
            selected_video = video_files[0]

    print(f"正在处理车流视频: {os.path.basename(selected_video)}")

    # 打开视频
    cap = cv2.VideoCapture(selected_video)
    if not cap.isOpened():
        print("无法打开视频文件")
        return

    # 初始化背景减除器和车辆跟踪器
    background_subtractor = cv2.createBackgroundSubtractorMOG2(
        detectShadows=True, varThreshold=50, history=500)
    vehicle_tracker = VehicleTracker()

    # 获取视频信息
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # 设置输出视频
    output_path = os.path.join(video_dir, 'vehicle_tracking_output.mp4')
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    frame_count = 0
    print("开始处理车流视频... 按 'q' 退出预览")
    print(f"总帧数: {total_frames}, FPS: {fps}")

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        frame_count += 1

        # 检测车辆
        detections, fg_mask = detect_vehicles(frame, background_subtractor)

        # 更新车辆跟踪器
        vehicle_tracker.update(detections)

        # 获取活跃的车辆跟踪结果
        active_tracks = vehicle_tracker.get_active_tracks()

        # 为每辆车绘制独立的跟踪信息
        for track_id, bbox in active_tracks:
            color = vehicle_tracker.colors.get(track_id, (0, 255, 0))
            trajectory = vehicle_tracker.trajectories.get(track_id, [])
            draw_vehicle_info(frame, track_id, bbox, color, trajectory, vehicle_tracker)

        # 显示统计信息
        info_text = [
            f'Frame: {frame_count}/{total_frames}',
            f'Active Cars: {len(active_tracks)}',
            f'Total Detections: {len(detections)}',
            f'Progress: {frame_count / total_frames * 100:.1f}%'
        ]

        for i, text in enumerate(info_text):
            cv2.putText(frame, text, (10, 30 + i * 25),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

        # 写入输出视频
        out.write(frame)

        # 显示结果窗口
        display_frame = cv2.resize(frame, (1000, 700))
        display_mask = cv2.resize(fg_mask, (400, 280))

        # 将前景掩码转换为3通道图像以便显示
        fg_mask_colored = cv2.cvtColor(display_mask, cv2.COLOR_GRAY2BGR)

        # 在前景掩码上绘制检测框
        mask_scale_x = display_mask.shape[1] / width
        mask_scale_y = display_mask.shape[0] / height

        for detection in detections:
            x, y, w, h = detection
            scaled_x = int(x * mask_scale_x)
            scaled_y = int(y * mask_scale_y)
            scaled_w = int(w * mask_scale_x)
            scaled_h = int(h * mask_scale_y)
            cv2.rectangle(fg_mask_colored, (scaled_x, scaled_y),
                          (scaled_x + scaled_w, scaled_y + scaled_h), (0, 255, 0), 1)

        cv2.imshow('Vehicle Tracking - Main View', display_frame)
        cv2.imshow('Vehicle Detection - Foreground Mask', fg_mask_colored)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

        # 显示进度
        if frame_count % 100 == 0:
            print(f"已处理 {frame_count}/{total_frames} 帧 ({frame_count / total_frames * 100:.1f}%)")

    # 清理
    cap.release()
    out.release()
    cv2.destroyAllWindows()

    print(f"\n车辆跟踪完成!")
    print(f"输出视频保存在: {output_path}")
    print(f"总共处理了 {frame_count} 帧")
    print(f"检测到的最大车辆数: {len(vehicle_tracker.trajectories)}")

    # 统计信息
    if vehicle_tracker.trajectories:
        avg_trajectory_length = sum(len(traj) for traj in vehicle_tracker.trajectories.values()) / len(
            vehicle_tracker.trajectories)
        print(f"平均轨迹长度: {avg_trajectory_length:.1f} 个点")


if __name__ == "__main__":
    main()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值