Efficient Data Statistics with Multiprocess Python: A Demo

Preface

The background: at work today the boss handed me a data-statistics task, along with a txt file where each line of data is a clip_adrn, organized roughly like this:

clip_adrn1,
clip_adrn2,
......

For each clip, a function called get_available_channels returns the names of all the channels it contains.

The boss wants to know what proportion of clips contain both of the channels ["/pilot/map/position", "/perception/frame/traffic_sign"].

When I first picked up the task it looked trivial, and I settled on this logic:

1. Read the txt file
2. Write a for loop that iterates over every clip_adrn
3. For each clip_adrn, fetch all of its channels; if both channels above are present, cnt += 1 (a set-based variant is sketched right after this list)
4. Print the final result
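
As a side note, the membership check in step 3 generalizes to any number of required channels with a set operation; a minimal sketch (the helper name is illustrative, not from the task code):

required = {"/pilot/map/position", "/perception/frame/traffic_sign"}

def has_all_required(channels_list):
    # True only if every required channel appears in the clip's channel list
    return required.issubset(channels_list)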

Brute Force

With the analysis done, I immediately wrote a first version. The code needs no explanation; the logic is utterly simple:

import os
import sys

from ad_cloud.adrn.data_seeker.dataclip import get_available_channels
from tqdm import tqdm
# adrn = "adr::dataclip:eng.bhd.0039::1683435549501586176:1683435551001586176"
# channels_list = get_available_channels(adrn)
# print(channels_list)

if __name__ == '__main__':
    file_path = sys.argv[1]

    if not os.path.exists(file_path):
        print(f"Error: File '{file_path}' does not exist")
        sys.exit(1)

    with open(file_path, 'r') as file:
        adrns = file.readlines()

    adrns = [adrn.strip() for adrn in adrns]
    adrns_count = len(adrns)
    print(f"Total Adrns: {adrns_count}")

    select_channels = ["/pilot/map/position", "/perception/frame/traffic_sign"]

    has_both_channels_count = 0
    has_map_channel_count = 0
    has_perception_channel_count = 0
    has_no_channel_count = 0
	
    # Iterate over every adrn and tally the counts
    for i in tqdm(range(adrns_count)):
        channels_list = get_available_channels(adrns[i])
        if select_channels[0] in channels_list and select_channels[1] in channels_list:
            has_both_channels_count += 1
        elif select_channels[0] in channels_list:
            has_map_channel_count += 1
        elif select_channels[1] in channels_list:
            has_perception_channel_count += 1
        else:
            has_no_channel_count += 1
	
    # Finally, print the counts for clips with both channels, with only one of them, and with neither
    print(f"Adrns with both channels: {has_both_channels_count}/{adrns_count}")
    print(f"Adrns with map channel: {has_map_channel_count}/{adrns_count}")
    print(f"Adrns with traffic_sign channel: {has_perception_channel_count}/{adrns_count}")
    print(f"Adrns with no map and traffic_sign channel: {has_no_channel_count}/{adrns_count}")

After writing it I fed in the file, ran it, and discovered things were not that simple.

The reason: there are 3,000,000 clip_adrns. At roughly 1 s per clip_adrn, that is about 3,000,000 seconds, i.e. 30-odd days to get through all the data.

What? For a task assigned today, am I supposed to say: "Boss, the data is a bit large; I'll have the result for you in a month"? The boss would probably reply: "Heh. Get lost."

Graceful Multiprocessing

Seeing that estimate, I realized this could not be brute-forced; some optimization was needed, and the fastest route is parallelization.

So I decided to tackle the task with multiple processes.

Going multiprocess raises a few questions:

  1. How many processes to use
  2. How to read this much data
  3. How to resume the job if the run is interrupted midway
  4. If it hangs, how to kill the program manually and then resume from the checkpoint
  5. How to actually write the multiprocessing code

With these questions in mind, and with help from an LLM, I gradually arrived at the code.

Below, each question is analyzed in turn with its code fragment, and the complete program that finishes the task comes at the end. The imports:

import os
import sys
import multiprocessing as mp
from functools import partial
from ad_cloud.adrn.data_seeker.dataclip import get_available_channels
from tqdm import tqdm
import time
import pandas as pd
import json
import signal
import hashlib
import uuid
import threading
import queue
import concurrent.futures

# Global variables
# stop_event notifies all threads to stop: when the program needs to shut down, stop_event.set() is called and every thread watching the event stops working.
stop_event = threading.Event()  # notifies all threads to stop
# result_queue collects results from the worker threads: threads put results into the queue, and the main program drains it for aggregation.
result_queue = queue.Queue()  # collects results
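
As a quick illustration of these two primitives working together (a standalone sketch, separate from the task code), worker threads watch the event while the main thread drains the queue:

import threading
import queue
import time

stop = threading.Event()
results = queue.Queue()

def worker(n):
    # Keep producing until the main thread signals stop
    while not stop.is_set():
        results.put(n * n)
        time.sleep(0.1)

threads = [threading.Thread(target=worker, args=(i,)) for i in range(4)]
for t in threads:
    t.start()
time.sleep(0.5)
stop.set()  # notify every worker to exit its loop
for t in threads:
    t.join()
print(f"Collected {results.qsize()} results")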

Letting the Program Exit Gracefully

Here I wrote a class that handles incoming termination signals, such as the familiar Ctrl+C.

Its main purpose is a graceful-exit mechanism. Through the global event and signal handling, the program can, on receiving a termination signal, tell all threads to stop working, and force-quit only when necessary. This helps shut a multithreaded program down safely, avoiding data loss and resource leaks.

class GracefulKiller:
    """Handle SIGINT and SIGTERM so the program can exit gracefully"""
    kill_now = False
    kill_count = 0

    def __init__(self):
        # Register the signal handlers
        # SIGINT: the signal generated when the user presses Ctrl+C
        signal.signal(signal.SIGINT, self.exit_gracefully)
        # SIGTERM: the termination signal usually sent by the system or from the command line
        signal.signal(signal.SIGTERM, self.exit_gracefully)

    def exit_gracefully(self, signum, frame):
        """Signal handler (handlers receive both the signal number and the current stack frame)"""
        self.kill_count += 1

        # On the first termination signal, stop gracefully
        # (though in my actual runs this alone did not always seem to stop things)
        if self.kill_count == 1:
            print(f"\nReceived signal {signum}. Initiating graceful shutdown...")
            print("Press Ctrl+C again to force quit (may cause data loss).")
            self.kill_now = True
            stop_event.set()  # set the stop event to notify all threads
        else:  # force quit without saving progress; data may be lost
            print("\nForcing immediate termination...")
            # Hard exit, no progress saved
            os._exit(1)

def worker_init():
    """Worker-process initializer.
    Make child processes ignore SIGINT so only the main process handles it
    (the terminal delivers SIGINT to the whole foreground process group).
    """
    signal.signal(signal.SIGINT, signal.SIG_IGN)

Checkpointing

With 3,000,000 records, the run is bound to take a long time. As with model training, if you worry that a crash midway forces a restart from scratch, you need checkpointing, so that an interrupted run can pick up where it left off.

The idea here: for multiprocessing, the file is split into chunks, and each chunk is finished quickly in parallel. After each chunk completes, a checkpoint records the key intermediate results to a shared file, so even a mid-run crash costs nothing.

def get_job_id(file_path):
    """生成基于输入文件的唯一作业ID"""
    return hashlib.md5(file_path.encode()).hexdigest()[:8]

def get_progress_dir():
    """获取进度文件目录,优先使用环境变量指定的目录"""
    progress_dir = os.environ.get('PROGRESS_DIR', '/mnt/output-data-storage/zhongqiang/adrn_progress')
    os.makedirs(progress_dir, exist_ok=True)
    return progress_dir

def get_progress_file(job_id):
    """获取进度文件路径"""
    return os.path.join(get_progress_dir(), f"{job_id}_progress.json")

def save_progress(job_id, current_chunk, processed_count, counters):
    """保存当前进度到文件"""
    progress = {
        'current_chunk': current_chunk,
        'processed_count': processed_count,
        'counters': counters,
        'timestamp': time.time()
    }
    
    progress_file = get_progress_file(job_id)
    with open(progress_file, 'w') as f:
        json.dump(progress, f)
    print(f"Progress saved to {progress_file}")

def load_progress(job_id):
    """从文件加载进度"""
    progress_file = get_progress_file(job_id)
    
    if not os.path.exists(progress_file):
        return None
    
    try:
        with open(progress_file, 'r') as f:
            progress = json.load(f)
        print(f"Loaded progress from {progress_file}")
        return progress
    except Exception as e:
        print(f"Error loading progress: {e}")
        return None

A unique job_id is derived from the input data file's name, and a {job_id}_progress.json named after it stores the checkpoint information, which is:

progress = {
        'current_chunk': current_chunk,
        'processed_count': processed_count,
        'counters': counters,
        'timestamp': time.time()
    }

That is: the index of the next chunk to process (after chunk i completes, the code saves i + 1; on interruption it saves the unfinished chunk's own index), the number of ADRNs processed so far, the running counters, and a timestamp. On resume, counting continues from exactly that chunk index.

When the program starts, if such a checkpoint file exists it is loaded and processing continues from there; otherwise the run starts from the beginning.
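
Concretely, with the default chunk size of 5000, a saved {job_id}_progress.json might look like this (all values illustrative):

{
    "current_chunk": 42,
    "processed_count": 210000,
    "counters": [55000, 30000, 25000, 100000],
    "timestamp": 1718000000.0
}

A restarted run then begins at chunk index 42 with the counters already warmed up.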

Multiprocessing

So how should the multiprocessing code be written?

First, a function that processes a single adrn: given one adrn, it tallies the channels.

def process_adrn(adrn, select_channels):
    """Process a single ADRN and return its tally.
    Return format: (both_channels, map_only, traffic_only, neither)
    """
    try:
        channels_list = get_available_channels(adrn)
        if select_channels[0] in channels_list and select_channels[1] in channels_list:
            return (1, 0, 0, 0)  # has both channels
        elif select_channels[0] in channels_list:
            return (0, 1, 0, 0)  # map channel only
        elif select_channels[1] in channels_list:
            return (0, 0, 1, 0)  # traffic-sign channel only
        else:
            return (0, 0, 0, 1)  # has neither channel
    except Exception as e:
        print(f"Error processing {adrn}: {e}")
        return (0, 0, 0, 0)  # errors are excluded from the tally

Next, a function that processes one chunk, i.e. a batch of clip adrns, using the process pool.

def process_chunk(pool, chunk, process_func, chunk_index, total_chunks, killer):
    """Process a single data chunk, with support for graceful termination.
    pool: the multiprocessing pool used to process the data in parallel
    chunk: the data chunk to process, a slice of the ADRN list
    process_func: the function that processes a single ADRN
    chunk_index: index of the current chunk
    total_chunks: total number of chunks
    killer: used to detect the termination signal
    """
    print(f"\nProcessing chunk {chunk_index+1}/{total_chunks}")

    # Create the progress bar
    pbar = tqdm(total=len(chunk), desc=f"Chunk {chunk_index+1}/{total_chunks}")

    # Collect results as they complete so the main process never blocks for long
    results = []
    try:
        # Fetch results with imap_unordered, which also lets us react to interrupts.
        # imap_unordered distributes the elements of an iterable (list, tuple, ...)
        # to the worker processes and collects their return values. Unlike imap,
        # results come back in completion order rather than input order. For
        # example, with pool.imap_unordered(process_func, chunk), several
        # processes work on the elements of chunk concurrently, and results may
        # arrive in any order. When you only care about each result and not its
        # order, this is more efficient: finished results are collected right
        # away instead of waiting for earlier tasks to complete first.
        for result in pool.imap_unordered(process_func, chunk):
            # Check whether we need to stop
            if killer.kill_now:
                pbar.close()
                return results, False

            if result is not None:
                results.append(result)
                result_queue.put(result)  # push the result onto the queue

            pbar.update(1)
    except Exception as e:
        print(f"Error processing chunk {chunk_index+1}: {e}")

    pbar.close()
    return results, True
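
To make imap_unordered concrete, here is a minimal, self-contained demo (separate from the task code) showing that results stream back in completion order rather than input order:

import multiprocessing as mp
import time

def slow_square(n):
    # Larger inputs sleep longer, so smaller inputs finish first
    time.sleep(n * 0.1)
    return n * n

if __name__ == '__main__':
    with mp.Pool(4) as pool:
        # Inputs arrive as 4, 3, 2, 1 but results print roughly as 1, 4, 9, 16
        for result in pool.imap_unordered(slow_square, [4, 3, 2, 1]):
            print(result)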

The Main Function

def main():
    """主函数:程序入口点"""
    # 参数检查和初始化
    if len(sys.argv) < 2:
        print("Usage: python script.py <file_path> [num_processes] [chunk_size]")
        sys.exit(1)
    
    # 这种添加参数的方法也可以参考下
    file_path = sys.argv[1]
    # My machine has 32 cores; 4 are left for the system, so 28 processes are used
    num_processes = int(sys.argv[2]) if len(sys.argv) > 2 else 28
    chunk_size = int(sys.argv[3]) if len(sys.argv) > 3 else 5000
    
    # Generate a unique job ID based on the input file
    job_id = get_job_id(file_path)
    print(f"Job ID: {job_id}")
    
    # Initialize the signal handler so the program can exit gracefully
    killer = GracefulKiller()
    
    # Check that the input file exists
    if not os.path.exists(file_path):
        print(f"Error: File '{file_path}' does not exist")
        sys.exit(1)

    # Read the input file
    # Note: this loads the entire file into memory at once
    print(f"Reading input file: {file_path}")
    with open(file_path, 'r') as file:
        adrns = file.readlines()

    adrns = [adrn.strip() for adrn in adrns]
    adrns_count = len(adrns)
    print(f"Total Adrns: {adrns_count}")
    print(f"Using {num_processes} processes")
    print(f"Chunk size: {chunk_size}")
    print(f"Progress directory: {get_progress_dir()}")

    select_channels = ["/pilot/map/position", "/perception/frame/traffic_sign"]
    
    # Split the ADRN list into fixed-size chunks
    # (chunking happens only after everything has been read in)
    chunks = [adrns[i:i+chunk_size] for i in range(0, adrns_count, chunk_size)]
    
    # Initialize the counters
    has_both_channels_count = 0
    has_map_channel_count = 0
    has_perception_channel_count = 0
    has_no_channel_count = 0
    
    # Check for saved progress
    progress = load_progress(job_id)
    # If checkpoint information already exists
    if progress:
        start_chunk = progress['current_chunk']
        processed_count = progress['processed_count']
        counters = progress['counters']
        
        # Restore the counters from the saved results
        has_both_channels_count = counters[0]
        has_map_channel_count = counters[1]
        has_perception_channel_count = counters[2]
        has_no_channel_count = counters[3]
        
        print(f"Resuming from chunk {start_chunk+1}/{len(chunks)}")
        print(f"Already processed {processed_count}/{adrns_count} ADRNs")
        print(f"Current counters: both={has_both_channels_count}, map={has_map_channel_count}, "
              f"traffic={has_perception_channel_count}, neither={has_no_channel_count}")
    else:
        start_chunk = 0
        processed_count = 0
        print("Starting from the beginning")
    
    start_time = time.time()
    
    # Create a partial function with the select_channels argument fixed
    process_func = partial(process_adrn, select_channels=select_channels)
    
    # Create the process pool
    pool = mp.Pool(processes=num_processes, initializer=worker_init)
    
    # Start from the chunk where the last run left off
    for i in range(start_chunk, len(chunks)):
        chunk = chunks[i]
        
        # Process the current chunk
        results, success = process_chunk(
            pool, chunk, process_func, i, len(chunks), killer
        )
        
        # If the chunk succeeded, aggregate its results
        if success:
            # Sum up this chunk's tallies
            chunk_both = sum(r[0] for r in results)
            chunk_map = sum(r[1] for r in results)
            chunk_traffic = sum(r[2] for r in results)
            chunk_neither = sum(r[3] for r in results)
            
            # Update the global counters
            has_both_channels_count += chunk_both
            has_map_channel_count += chunk_map
            has_perception_channel_count += chunk_traffic
            has_no_channel_count += chunk_neither
            
            # Update the number of ADRNs processed
            processed_count += len(chunk)
            if processed_count > adrns_count:
                processed_count = adrns_count
                
            # Save the current progress
            save_progress(
                job_id, 
                i + 1, 
                processed_count, 
                [has_both_channels_count, has_map_channel_count, has_perception_channel_count, has_no_channel_count]
            )
            
            # Compute elapsed time and estimated time remaining
            elapsed = time.time() - start_time
            if processed_count > 0:
                rate = processed_count / elapsed
                remaining = (adrns_count - processed_count) / rate
                print(f"Processed: {processed_count}/{adrns_count} ({processed_count/adrns_count*100:.2f}%)")
                print(f"Elapsed: {elapsed/60:.2f} minutes, Remaining: {remaining/60:.2f} minutes")
                print(f"Rate: {rate:.2f} ADRNs/second")
                
                # Save intermediate results every 10 chunks
                if (i + 1) % 10 == 0:
                    save_intermediate_results(
                        job_id, i + 1, processed_count, adrns_count,
                        has_both_channels_count, has_map_channel_count,
                        has_perception_channel_count, has_no_channel_count,
                        elapsed
                    )
        
        # Check whether we need to stop
        if killer.kill_now:
            print("Gracefully shutting down...")
            
            # Save the current progress
            save_progress(
                job_id, 
                i,  # save the current chunk's index, not the next one's
                processed_count, 
                [has_both_channels_count, has_map_channel_count, has_perception_channel_count, has_no_channel_count]
            )
            
            # Tear down the process pool
            pool.terminate()
            pool.join()
            
            print("Shutdown complete. Progress saved.")
            sys.exit(0)

    # Close the pool once all chunks are processed
    pool.close()
    pool.join()
    
    # Delete the progress file
    progress_file = get_progress_file(job_id)
    if os.path.exists(progress_file):
        os.remove(progress_file)
        print(f"Progress file {progress_file} deleted")

    # Print the final statistics
    print("\n===== Final Results =====")
    print(f"Adrns with both channels: {has_both_channels_count}/{adrns_count} ({has_both_channels_count/adrns_count*100:.2f}%)")
    print(f"Adrns with map channel: {has_map_channel_count}/{adrns_count} ({has_map_channel_count/adrns_count*100:.2f}%)")
    print(f"Adrns with traffic_sign channel: {has_perception_channel_count}/{adrns_count} ({has_perception_channel_count/adrns_count*100:.2f}%)")
    print(f"Adrns with no map and traffic_sign channel: {has_no_channel_count}/{adrns_count} ({has_no_channel_count/adrns_count*100:.2f}%)")
    
    # Compute and print the total runtime
    total_time = time.time() - start_time
    print(f"\nTotal processing time: {total_time/60:.2f} minutes")
    
    # Save the final results to CSV
    save_final_results(
        job_id, adrns_count, has_both_channels_count, has_map_channel_count,
        has_perception_channel_count, has_no_channel_count, total_time
    )

Final Multiprocessing Code

Finally, here is the complete runnable version (it also includes the save_intermediate_results and save_final_results helpers used above):

import os
import sys
import multiprocessing as mp
from functools import partial
from ad_cloud.adrn.data_seeker.dataclip import get_available_channels
from tqdm import tqdm
import time
import pandas as pd
import json
import signal
import hashlib
import uuid
import threading
import queue
import concurrent.futures

# Global variables
stop_event = threading.Event()  # notifies all threads to stop
result_queue = queue.Queue()  # collects results

class GracefulKiller:
    """Handle SIGINT and SIGTERM so the program can exit gracefully"""
    kill_now = False
    kill_count = 0
    
    def __init__(self):
        # Register the signal handlers
        signal.signal(signal.SIGINT, self.exit_gracefully)
        signal.signal(signal.SIGTERM, self.exit_gracefully)

    def exit_gracefully(self, signum, frame):
        """Signal handler"""
        self.kill_count += 1
        
        if self.kill_count == 1:
            print(f"\nReceived signal {signum}. Initiating graceful shutdown...")
            print("Press Ctrl+C again to force quit (may cause data loss).")
            self.kill_now = True
            stop_event.set()  # set the stop event to notify all threads
        else:
            print("\nForcing immediate termination...")
            # Hard exit without saving progress
            os._exit(1)

def process_adrn(adrn, select_channels):
    """Process a single ADRN and return its tally.
    Return format: (both_channels, map_only, traffic_only, neither)
    """
    try:
        channels_list = get_available_channels(adrn)
        if select_channels[0] in channels_list and select_channels[1] in channels_list:
            return (1, 0, 0, 0)  # has both channels
        elif select_channels[0] in channels_list:
            return (0, 1, 0, 0)  # map channel only
        elif select_channels[1] in channels_list:
            return (0, 0, 1, 0)  # traffic-sign channel only
        else:
            return (0, 0, 0, 1)  # has neither channel
    except Exception as e:
        print(f"Error processing {adrn}: {e}")
        return (0, 0, 0, 0)  # errors are excluded from the tally

def worker_init():
    """工作进程初始化函数
    设置子进程忽略SIGINT信号,由主进程处理
    """
    signal.signal(signal.SIGINT, signal.SIG_IGN)

def get_job_id(file_path):
    """生成基于输入文件的唯一作业ID"""
    return hashlib.md5(file_path.encode()).hexdigest()[:8]

def get_progress_dir():
    """获取进度文件目录,优先使用环境变量指定的目录"""
    progress_dir = os.environ.get('PROGRESS_DIR', '/tmp/adrn_progress')
    os.makedirs(progress_dir, exist_ok=True)
    return progress_dir

def get_progress_file(job_id):
    """获取进度文件路径"""
    return os.path.join(get_progress_dir(), f"{job_id}_progress.json")

def save_progress(job_id, current_chunk, processed_count, counters):
    """保存当前进度到文件"""
    progress = {
        'current_chunk': current_chunk,
        'processed_count': processed_count,
        'counters': counters,
        'timestamp': time.time()
    }
    
    progress_file = get_progress_file(job_id)
    with open(progress_file, 'w') as f:
        json.dump(progress, f)
    print(f"Progress saved to {progress_file}")

def load_progress(job_id):
    """从文件加载进度"""
    progress_file = get_progress_file(job_id)
    
    if not os.path.exists(progress_file):
        return None
    
    try:
        with open(progress_file, 'r') as f:
            progress = json.load(f)
        print(f"Loaded progress from {progress_file}")
        return progress
    except Exception as e:
        print(f"Error loading progress: {e}")
        return None

def save_intermediate_results(job_id, chunk_num, processed, total, both, map_only, traffic_only, neither, elapsed):
    """保存中间处理结果"""
    results = {
        'chunk_number': [chunk_num],
        'processed_adrns': [processed],
        'total_adrns': [total],
        'both_channels': [both],
        'map_only': [map_only],
        'traffic_only': [traffic_only],
        'neither': [neither],
        'elapsed_minutes': [elapsed/60],
        'rate_adrns_per_second': [processed/elapsed]
    }
    df = pd.DataFrame(results)
    
    csv_file = os.path.join(get_progress_dir(), f"{job_id}_intermediate.csv")
    
    # The first intermediate save happens at chunk 10 (saves occur every 10 chunks),
    # so write the header then; afterwards append without a header
    if chunk_num == 10:
        df.to_csv(csv_file, index=False)
    else:
        df.to_csv(csv_file, index=False, mode='a', header=False)

def save_final_results(job_id, total, both, map_only, traffic_only, neither, total_time):
    """保存最终结果到CSV"""
    results = {
        'total_adrns': [total],
        'both_channels': [both],
        'both_channels_percent': [both/total*100],
        'map_only': [map_only],
        'map_only_percent': [map_only/total*100],
        'traffic_only': [traffic_only],
        'traffic_only_percent': [traffic_only/total*100],
        'neither': [neither],
        'neither_percent': [neither/total*100],
        'total_time_minutes': [total_time/60],
        'rate_adrns_per_second': [total/total_time]
    }
    df = pd.DataFrame(results)
    
    csv_file = os.path.join(get_progress_dir(), f"{job_id}_final_results.csv")
    df.to_csv(csv_file, index=False)
    print(f"Final results saved to {csv_file}")

def process_chunk(pool, chunk, process_func, chunk_index, total_chunks, killer):
    """Process a single data chunk, with support for graceful termination"""
    print(f"\nProcessing chunk {chunk_index+1}/{total_chunks}")
    
    # Create the progress bar
    pbar = tqdm(total=len(chunk), desc=f"Chunk {chunk_index+1}/{total_chunks}")
    
    # Collect results as they complete so the main process never blocks for long
    results = []
    try:
        # Fetch results with imap_unordered, which also lets us react to interrupts
        for result in pool.imap_unordered(process_func, chunk):
            # Check whether we need to stop
            if killer.kill_now:
                pbar.close()
                return results, False
                
            if result is not None:
                results.append(result)
                result_queue.put(result)  # push the result onto the queue
                
            pbar.update(1)
    except Exception as e:
        print(f"Error processing chunk {chunk_index+1}: {e}")
    
    pbar.close()
    return results, True

def main():
    """主函数:程序入口点"""
    # 参数检查和初始化
    if len(sys.argv) < 2:
        print("Usage: python script.py <file_path> [num_processes] [chunk_size]")
        sys.exit(1)
        
    file_path = sys.argv[1]
    num_processes = int(sys.argv[2]) if len(sys.argv) > 2 else 28
    chunk_size = int(sys.argv[3]) if len(sys.argv) > 3 else 5000
    
    # Generate a unique job ID based on the input file
    job_id = get_job_id(file_path)
    print(f"Job ID: {job_id}")
    
    # Initialize the signal handler so the program can exit gracefully
    killer = GracefulKiller()
    
    # Check that the input file exists
    if not os.path.exists(file_path):
        print(f"Error: File '{file_path}' does not exist")
        sys.exit(1)

    # Read the input file
    print(f"Reading input file: {file_path}")
    with open(file_path, 'r') as file:
        adrns = file.readlines()

    adrns = [adrn.strip() for adrn in adrns]
    adrns_count = len(adrns)
    print(f"Total Adrns: {adrns_count}")
    print(f"Using {num_processes} processes")
    print(f"Chunk size: {chunk_size}")
    print(f"Progress directory: {get_progress_dir()}")

    select_channels = ["/pilot/map/position", "/perception/frame/traffic_sign"]
    
    # Split the ADRN list into fixed-size chunks
    chunks = [adrns[i:i+chunk_size] for i in range(0, adrns_count, chunk_size)]
    
    # Initialize the counters
    has_both_channels_count = 0
    has_map_channel_count = 0
    has_perception_channel_count = 0
    has_no_channel_count = 0
    
    # Check for saved progress
    progress = load_progress(job_id)
    if progress:
        start_chunk = progress['current_chunk']
        processed_count = progress['processed_count']
        counters = progress['counters']
        
        has_both_channels_count = counters[0]
        has_map_channel_count = counters[1]
        has_perception_channel_count = counters[2]
        has_no_channel_count = counters[3]
        
        print(f"Resuming from chunk {start_chunk+1}/{len(chunks)}")
        print(f"Already processed {processed_count}/{adrns_count} ADRNs")
        print(f"Current counters: both={has_both_channels_count}, map={has_map_channel_count}, "
              f"traffic={has_perception_channel_count}, neither={has_no_channel_count}")
    else:
        start_chunk = 0
        processed_count = 0
        print("Starting from the beginning")
    
    start_time = time.time()
    
    # Create a partial function with the select_channels argument fixed
    process_func = partial(process_adrn, select_channels=select_channels)
    
    # Create the process pool
    pool = mp.Pool(processes=num_processes, initializer=worker_init)
    
    # Start from the chunk where the last run left off
    for i in range(start_chunk, len(chunks)):
        chunk = chunks[i]
        
        # Process the current chunk
        results, success = process_chunk(
            pool, chunk, process_func, i, len(chunks), killer
        )
        
        # If the chunk succeeded, aggregate its results
        if success:
            # Sum up this chunk's tallies
            chunk_both = sum(r[0] for r in results)
            chunk_map = sum(r[1] for r in results)
            chunk_traffic = sum(r[2] for r in results)
            chunk_neither = sum(r[3] for r in results)
            
            # Update the global counters
            has_both_channels_count += chunk_both
            has_map_channel_count += chunk_map
            has_perception_channel_count += chunk_traffic
            has_no_channel_count += chunk_neither
            
            # Update the number of ADRNs processed
            processed_count += len(chunk)
            if processed_count > adrns_count:
                processed_count = adrns_count
                
            # Save the current progress
            save_progress(
                job_id, 
                i + 1, 
                processed_count, 
                [has_both_channels_count, has_map_channel_count, has_perception_channel_count, has_no_channel_count]
            )
            
            # Compute elapsed time and estimated time remaining
            elapsed = time.time() - start_time
            if processed_count > 0:
                rate = processed_count / elapsed
                remaining = (adrns_count - processed_count) / rate
                print(f"Processed: {processed_count}/{adrns_count} ({processed_count/adrns_count*100:.2f}%)")
                print(f"Elapsed: {elapsed/60:.2f} minutes, Remaining: {remaining/60:.2f} minutes")
                print(f"Rate: {rate:.2f} ADRNs/second")
                
                # Save intermediate results every 10 chunks
                if (i + 1) % 10 == 0:
                    save_intermediate_results(
                        job_id, i + 1, processed_count, adrns_count,
                        has_both_channels_count, has_map_channel_count,
                        has_perception_channel_count, has_no_channel_count,
                        elapsed
                    )
        
        # Check whether we need to stop
        if killer.kill_now:
            print("Gracefully shutting down...")
            
            # Save the current progress
            save_progress(
                job_id, 
                i,  # save the current chunk's index, not the next one's
                processed_count, 
                [has_both_channels_count, has_map_channel_count, has_perception_channel_count, has_no_channel_count]
            )
            
            # Tear down the process pool
            pool.terminate()
            pool.join()
            
            print("Shutdown complete. Progress saved.")
            sys.exit(0)

    # Close the pool once all chunks are processed
    pool.close()
    # pool.join() blocks the main process until every worker in the pool has
    # finished. Without it, the main process could exit early, killing workers
    # mid-task or leaving the program's logic in an inconsistent state.
    pool.join()
    
    # Delete the progress file
    progress_file = get_progress_file(job_id)
    if os.path.exists(progress_file):
        os.remove(progress_file)
        print(f"Progress file {progress_file} deleted")

    # Print the final statistics
    print("\n===== Final Results =====")
    print(f"Adrns with both channels: {has_both_channels_count}/{adrns_count} ({has_both_channels_count/adrns_count*100:.2f}%)")
    print(f"Adrns with map channel: {has_map_channel_count}/{adrns_count} ({has_map_channel_count/adrns_count*100:.2f}%)")
    print(f"Adrns with traffic_sign channel: {has_perception_channel_count}/{adrns_count} ({has_perception_channel_count/adrns_count*100:.2f}%)")
    print(f"Adrns with no map and traffic_sign channel: {has_no_channel_count}/{adrns_count} ({has_no_channel_count/adrns_count*100:.2f}%)")
    
    # Compute and print the total runtime
    total_time = time.time() - start_time
    print(f"\nTotal processing time: {total_time/60:.2f} minutes")
    
    # Save the final results to CSV
    save_final_results(
        job_id, adrns_count, has_both_channels_count, has_map_channel_count,
        has_perception_channel_count, has_no_channel_count, total_time
    )

if __name__ == '__main__':
    main()

One problem remains with the code above: it still reads all the data into memory at once and only then chunks it.

What if reading everything in one go blows up the memory? Then the file needs to be read in streaming fashion, so here is a reference implementation for that as well. The core chunk-generator idea is sketched below, and the full program follows.
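
The trick is to read lazily and hand out fixed-size chunks instead of calling readlines(). A minimal sketch of such a chunk generator (the full program below inlines the same logic rather than using this helper):

def iter_chunks(file_path, chunk_size):
    """Yield lists of up to chunk_size stripped, non-empty lines."""
    chunk = []
    with open(file_path, 'r') as f:
        for line in f:
            adrn = line.strip()
            if adrn:
                chunk.append(adrn)
            if len(chunk) >= chunk_size:
                yield chunk
                chunk = []
    if chunk:
        yield chunk  # final partial chunk

The full streaming program: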

import os
import sys
import multiprocessing as mp
from functools import partial
from ad_cloud.adrn.data_seeker.dataclip import get_available_channels
from tqdm import tqdm
import time
import pandas as pd
import json
import signal
import hashlib
import uuid
import threading
import queue
import mmap
import gc

# Global variables
stop_event = threading.Event()  # notifies all threads to stop
result_queue = queue.Queue()  # collects results
PROCESS_POOL = None  # global process pool

class GracefulKiller:
    """Handle SIGINT and SIGTERM so the program can exit gracefully"""
    kill_now = False
    kill_count = 0
    
    def __init__(self):
        # Register the signal handlers
        signal.signal(signal.SIGINT, self.exit_gracefully)
        signal.signal(signal.SIGTERM, self.exit_gracefully)

    def exit_gracefully(self, signum, frame):
        """Signal handler"""
        self.kill_count += 1
        
        if self.kill_count == 1:
            print(f"\nReceived signal {signum}. Initiating graceful shutdown...")
            print("Press Ctrl+C again to force quit (may cause data loss).")
            self.kill_now = True
            stop_event.set()  # set the stop event to notify all threads
        else:
            print("\nForcing immediate termination...")
            # Hard exit without saving progress
            if PROCESS_POOL:
                PROCESS_POOL.terminate()
                PROCESS_POOL.join()
            os._exit(1)

def process_adrn(adrn, select_channels):
    """Process a single ADRN and return its tally"""
    try:
        channels_list = get_available_channels(adrn)
        if select_channels[0] in channels_list and select_channels[1] in channels_list:
            return (1, 0, 0, 0)  # has both channels
        elif select_channels[0] in channels_list:
            return (0, 1, 0, 0)  # map channel only
        elif select_channels[1] in channels_list:
            return (0, 0, 1, 0)  # traffic-sign channel only
        else:
            return (0, 0, 0, 1)  # has neither channel
    except Exception as e:
        print(f"Error processing {adrn}: {e}")
        return (0, 0, 0, 0)  # errors are excluded from the tally

def worker_init():
    """工作进程初始化函数"""
    signal.signal(signal.SIGINT, signal.SIG_IGN)

def get_job_id(file_path):
    """生成基于输入文件的唯一作业ID"""
    return hashlib.md5(file_path.encode()).hexdigest()[:8]

def get_progress_dir():
    """获取进度文件目录"""
    progress_dir = os.environ.get('PROGRESS_DIR', '/tmp/adrn_progress')
    os.makedirs(progress_dir, exist_ok=True)
    return progress_dir

def get_progress_file(job_id):
    """获取进度文件路径"""
    return os.path.join(get_progress_dir(), f"{job_id}_progress.json")

def save_progress(job_id, current_position, processed_count, counters):
    """保存当前进度到文件"""
    progress = {
        'current_position': current_position,
        'processed_count': processed_count,
        'counters': counters,
        'timestamp': time.time()
    }
    
    progress_file = get_progress_file(job_id)
    with open(progress_file, 'w') as f:
        json.dump(progress, f)
    print(f"Progress saved to {progress_file}")

def load_progress(job_id):
    """从文件加载进度"""
    progress_file = get_progress_file(job_id)
    
    if not os.path.exists(progress_file):
        return None
    
    try:
        with open(progress_file, 'r') as f:
            progress = json.load(f)
        print(f"Loaded progress from {progress_file}")
        return progress
    except Exception as e:
        print(f"Error loading progress: {e}")
        return None

def save_intermediate_results(job_id, processed, total, both, map_only, traffic_only, neither, elapsed):
    """保存中间处理结果"""
    results = {
        'processed_adrns': [processed],
        'total_adrns': [total],
        'both_channels': [both],
        'map_only': [map_only],
        'traffic_only': [traffic_only],
        'neither': [neither],
        'elapsed_minutes': [elapsed/60],
        'rate_adrns_per_second': [processed/elapsed]
    }
    df = pd.DataFrame(results)
    
    csv_file = os.path.join(get_progress_dir(), f"{job_id}_intermediate.csv")
    
    # If the file does not exist yet, create it and write the header
    if not os.path.exists(csv_file):
        df.to_csv(csv_file, index=False)
    else:
        df.to_csv(csv_file, index=False, mode='a', header=False)

def save_final_results(job_id, total, both, map_only, traffic_only, neither, total_time):
    """保存最终结果到CSV"""
    results = {
        'total_adrns': [total],
        'both_channels': [both],
        'both_channels_percent': [both/total*100],
        'map_only': [map_only],
        'map_only_percent': [map_only/total*100],
        'traffic_only': [traffic_only],
        'traffic_only_percent': [traffic_only/total*100],
        'neither': [neither],
        'neither_percent': [neither/total*100],
        'total_time_minutes': [total_time/60],
        'rate_adrns_per_second': [total/total_time]
    }
    df = pd.DataFrame(results)
    
    csv_file = os.path.join(get_progress_dir(), f"{job_id}_final_results.csv")
    df.to_csv(csv_file, index=False)
    print(f"Final results saved to {csv_file}")

def count_total_lines(file_path):
    """Count the total number of lines in the file efficiently"""
    print("Counting total lines in file...")
    with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
        try:
            # Use mmap to speed up line counting on large files
            buf = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
            lines = 0
            read_size = 1024 * 1024  # 1 MB block size
            buf_size = len(buf)
            
            for i in range(0, buf_size, read_size):
                lines += buf[i:i+read_size].count(b'\n')
                if stop_event.is_set():
                    break
            return lines
        except Exception:
            # Fallback: slower but more compatible
            print("Falling back to slower line count method...")
            return sum(1 for _ in f)

def process_chunk(chunk, process_func, killer, pbar=None):
    """Process a single data chunk.
    Note: here process_func is expected to return an iterable of results for
    the whole chunk (e.g. a wrapper around pool.imap_unordered), unlike the
    per-ADRN process_func of the non-streaming version.
    """
    results = []
    
    try:
        for result in process_func(chunk):
            if killer.kill_now:
                return results, False
                
            if result is not None:
                results.append(result)
            
            if pbar:
                pbar.update(1)
    except Exception as e:
        print(f"Error processing chunk: {e}")
    
    return results, True

def stream_process_file(file_path, num_processes, chunk_size, job_id, killer):
    """流式处理文件,避免一次性加载"""
    global PROCESS_POOL
    
    select_channels = ["/pilot/map/position", "/perception/frame/traffic_sign"]
    process_func = partial(process_adrn, select_channels=select_channels)
    
    # Count the total number of lines
    total_lines = count_total_lines(file_path)
    print(f"Total lines in file: {total_lines}")
    
    # Initialize the counters
    has_both_channels_count = 0
    has_map_channel_count = 0
    has_perception_channel_count = 0
    has_no_channel_count = 0
    processed_count = 0
    current_position = 0
    
    # Check for saved progress
    progress = load_progress(job_id)
    if progress:
        current_position = progress['current_position']
        processed_count = progress['processed_count']
        counters = progress['counters']
        
        has_both_channels_count = counters[0]
        has_map_channel_count = counters[1]
        has_perception_channel_count = counters[2]
        has_no_channel_count = counters[3]
        
        print(f"Resuming from position {current_position}")
        print(f"Already processed {processed_count}/{total_lines} ADRNs")
    
    start_time = time.time()
    
    # Create the process pool
    PROCESS_POOL = mp.Pool(processes=num_processes, initializer=worker_init)
    
    try:
        chunk = []
        line_count = 0
        
        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
            # Seek to where the last run left off
            if current_position > 0:
                f.seek(current_position)
            
            # Create the progress bar
            pbar = tqdm(total=total_lines, initial=processed_count, desc="Processing ADRNs")
            
            # Read and process line by line.
            # A plain `for line in f` loop would break checkpointing here:
            # text files raise "telling position disabled by next() call"
            # if f.tell() is used during iteration, so readline() is used instead.
            while True:
                line = f.readline()
                if not line:
                    break
                if killer.kill_now:
                    break
                    
                adrn = line.strip()
                if adrn:  # skip blank lines
                    chunk.append(adrn)
                    line_count += 1
                    
                    # Process the chunk once it reaches the target size
                    if len(chunk) >= chunk_size:
                        # Process the chunk in parallel; note the argument order:
                        # the chunk first, then the wrapper that maps it through the pool
                        pool_process = lambda c: PROCESS_POOL.imap_unordered(process_func, c)
                        results, success = process_chunk(chunk, pool_process, killer, pbar)
                        
                        if not success:
                            break
                        
                        # Aggregate the results
                        chunk_both = sum(r[0] for r in results)
                        chunk_map = sum(r[1] for r in results)
                        chunk_traffic = sum(r[2] for r in results)
                        chunk_neither = sum(r[3] for r in results)
                        
                        has_both_channels_count += chunk_both
                        has_map_channel_count += chunk_map
                        has_perception_channel_count += chunk_traffic
                        has_no_channel_count += chunk_neither
                        
                        processed_count += len(chunk)
                        current_position = f.tell()  # record the current file offset
                        
                        # Save progress
                        save_progress(
                            job_id, 
                            current_position, 
                            processed_count, 
                            [has_both_channels_count, has_map_channel_count, has_perception_channel_count, has_no_channel_count]
                        )
                        
                        # Report status
                        elapsed = time.time() - start_time
                        if processed_count > 0:
                            rate = processed_count / elapsed
                            remaining = (total_lines - processed_count) / rate
                            print(f"\nProcessed: {processed_count}/{total_lines} ({processed_count/total_lines*100:.2f}%)")
                            print(f"Elapsed: {elapsed/60:.2f} minutes, Remaining: {remaining/60:.2f} minutes")
                            print(f"Rate: {rate:.2f} ADRNs/second")
                            
                            # Save intermediate results
                            if processed_count % (10 * chunk_size) == 0:
                                save_intermediate_results(
                                    job_id, processed_count, total_lines,
                                    has_both_channels_count, has_map_channel_count,
                                    has_perception_channel_count, has_no_channel_count,
                                    elapsed
                                )
                        
                        # Clear the chunk and trigger garbage collection
                        chunk = []
                        gc.collect()
            
            # Process whatever remains in the final partial chunk
            if chunk and not killer.kill_now:
                pool_process = lambda c: PROCESS_POOL.imap_unordered(process_func, c)
                results, _ = process_chunk(chunk, pool_process, killer, pbar)
                
                chunk_both = sum(r[0] for r in results)
                chunk_map = sum(r[1] for r in results)
                chunk_traffic = sum(r[2] for r in results)
                chunk_neither = sum(r[3] for r in results)
                
                has_both_channels_count += chunk_both
                has_map_channel_count += chunk_map
                has_perception_channel_count += chunk_traffic
                has_no_channel_count += chunk_neither
                
                processed_count += len(chunk)
                current_position = f.tell()
                
                save_progress(
                    job_id, 
                    current_position, 
                    processed_count, 
                    [has_both_channels_count, has_map_channel_count, has_perception_channel_count, has_no_channel_count]
                )
            
            pbar.close()
    
    finally:
        # Clean up the process pool
        if PROCESS_POOL:
            PROCESS_POOL.close()
            PROCESS_POOL.join()
            PROCESS_POOL = None
    
    # Check whether we were interrupted
    if killer.kill_now:
        print("Shutdown complete. Progress saved.")
        sys.exit(0)
    
    # Delete the progress file
    progress_file = get_progress_file(job_id)
    if os.path.exists(progress_file):
        os.remove(progress_file)
        print(f"Progress file {progress_file} deleted")
    
    # Print the results
    print("\n===== Final Results =====")
    print(f"Adrns with both channels: {has_both_channels_count}/{total_lines} ({has_both_channels_count/total_lines*100:.2f}%)")
    print(f"Adrns with map channel: {has_map_channel_count}/{total_lines} ({has_map_channel_count/total_lines*100:.2f}%)")
    print(f"Adrns with traffic_sign channel: {has_perception_channel_count}/{total_lines} ({has_perception_channel_count/total_lines*100:.2f}%)")
    print(f"Adrns with no map and traffic_sign channel: {has_no_channel_count}/{total_lines} ({has_no_channel_count/total_lines*100:.2f}%)")
    
    # Compute the runtime
    total_time = time.time() - start_time
    print(f"\nTotal processing time: {total_time/60:.2f} minutes")
    
    # Save the final results
    save_final_results(
        job_id, total_lines, has_both_channels_count, has_map_channel_count,
        has_perception_channel_count, has_no_channel_count, total_time
    )

def main():
    """主函数"""
    if len(sys.argv) < 2:
        print("Usage: python script.py <file_path> [num_processes] [chunk_size]")
        sys.exit(1)
        
    file_path = sys.argv[1]
    num_processes = int(sys.argv[2]) if len(sys.argv) > 2 else 28
    chunk_size = int(sys.argv[3]) if len(sys.argv) > 3 else 10000  # default: 10000 lines per chunk
    
    job_id = get_job_id(file_path)
    print(f"Job ID: {job_id}")
    
    killer = GracefulKiller()
    
    if not os.path.exists(file_path):
        print(f"Error: File '{file_path}' does not exist")
        sys.exit(1)
    
    print(f"Reading file in streaming mode, chunk size: {chunk_size}")
    print(f"Using {num_processes} processes")
    
    stream_process_file(file_path, num_processes, chunk_size, job_id, killer)

if __name__ == '__main__':
    main()
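
Assuming the streaming script is saved as, say, stream_stats.py (the filename is arbitrary), a run with 28 processes and 10000-line chunks would be invoked as python stream_stats.py adrns.txt 28 10000, and setting the PROGRESS_DIR environment variable beforehand redirects where checkpoints and CSV results are written. Re-running the same command after an interruption resumes from the saved file offset.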

Summary

Through this task I found that making good use of multiprocessing at work really pays off: the job went from an estimated 30-plus days down to about 10 hours, a huge improvement.

The design of this code is well worth studying, from termination-signal handling and checkpointing to chunking large data and multiprocessing. I learned a lot from it.

It could also be distilled into a small framework for tackling different tasks, as sketched below.
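
For instance, the checkpointed parallel-map pattern might be extracted into a reusable helper along these lines (a rough sketch with hypothetical names; signal handling and the persistence format are left to the caller):

import multiprocessing as mp

def checkpointed_parallel_map(items, worker, reduce_fn, state, save_fn,
                              start_chunk=0, chunk_size=5000, num_processes=8):
    """Apply worker to items chunk by chunk, folding each result into state
    with reduce_fn and checkpointing after every chunk via save_fn."""
    chunks = [items[i:i + chunk_size] for i in range(0, len(items), chunk_size)]
    with mp.Pool(processes=num_processes) as pool:
        for idx in range(start_chunk, len(chunks)):
            for result in pool.imap_unordered(worker, chunks[idx]):
                state = reduce_fn(state, result)
            save_fn(idx + 1, state)  # checkpoint: next chunk index + accumulated state
    return state

The task above then reduces to worker=partial(process_adrn, select_channels=select_channels), a 4-tuple of counters as state, element-wise addition as reduce_fn, and a small wrapper around save_progress as save_fn.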
