Preface
The background: today at work my boss handed me a data-statistics task. I was given a txt file in which every line is one clip_adrn, organized roughly like this:
clip_adrn1,
clip_adrn2,
......
For each clip, a function called get_available_channels can be used to fetch all of the channel names it contains.
The boss wanted to know what fraction of the clips contain both of the channels
["/pilot/map/position", "/perception/frame/traffic_sign"].
When I first picked up the task I assumed it was trivial, and sketched the logic:
1. Read the txt file.
2. Write a for loop over every clip_adrn.
3. For each clip_adrn, fetch all of its channels; if both channels above are present, cnt += 1.
4. Print the final result.
Brute force
With the analysis done, I immediately wrote a first version. The code needs no explanation; the logic is as simple as it gets:
import os
import sys
from ad_cloud.adrn.data_seeker.dataclip import get_available_channels
from tqdm import tqdm

# adrn = "adr::dataclip:eng.bhd.0039::1683435549501586176:1683435551001586176"
# channels_list = get_available_channels(adrn)
# print(channels_list)

if __name__ == '__main__':
    file_path = sys.argv[1]
    if not os.path.exists(file_path):
        print(f"Error: File '{file_path}' does not exist")
        sys.exit(1)
    with open(file_path, 'r') as file:
        adrns = file.readlines()
    adrns = [adrn.strip() for adrn in adrns]
    adrns_count = len(adrns)
    print(f"Total Adrns: {adrns_count}")
    select_channels = ["/pilot/map/position", "/perception/frame/traffic_sign"]
    has_both_channels_count = 0
    has_map_channel_count = 0
    has_perception_channel_count = 0
    has_no_channel_count = 0
    # Walk every adrn and tally the channel combinations
    for i in tqdm(range(adrns_count)):
        channels_list = get_available_channels(adrns[i])
        if select_channels[0] in channels_list and select_channels[1] in channels_list:
            has_both_channels_count += 1
        elif select_channels[0] in channels_list:
            has_map_channel_count += 1
        elif select_channels[1] in channels_list:
            has_perception_channel_count += 1
        else:
            has_no_channel_count += 1
    # Finally, report how many adrns have both channels, only one of them, or neither
    print(f"Adrns with both channels: {has_both_channels_count}/{adrns_count}")
    print(f"Adrns with map channel: {has_map_channel_count}/{adrns_count}")
    print(f"Adrns with traffic_sign channel: {has_perception_channel_count}/{adrns_count}")
    print(f"Adrns with no map and traffic_sign channel: {has_no_channel_count}/{adrns_count}")
Once it was written, I fed it the file, ran it, and discovered that things were not so simple.
The reason: there are about 3,000,000 clip_adrns. At roughly 1 second per clip_adrn, processing all of them would take more than 30 days.
What? A task assigned today, and I am supposed to say: "Boss, there is a bit too much data, I'll have the result for you in a month"? The boss would probably say: "Heh. Get out."
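To make the scale concrete, the back-of-envelope arithmetic looks like this (the 1 second per clip figure is a rough guess rather than a measured benchmark):
total_clips = 3_000_000
seconds_per_clip = 1.0  # rough guess, not measured
print(f"Serial: ~{total_clips * seconds_per_clip / 86400:.1f} days")  # ~34.7 days
workers = 28
print(f"With {workers} workers: ~{total_clips * seconds_per_clip / workers / 3600:.1f} hours")  # ~29.8 hours, ignoring overhead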
Graceful multiprocessing
Seeing that estimate, it was clear this could not be brute-forced; some optimization was needed, and the quickest win was parallelization.
So I decided to tackle the task with multiprocessing.
Multiprocessing raises a few questions:
- How many processes to use (see the small sketch right after this list)
- How to read that much data
- How to resume the job if the run is interrupted partway through
- How to kill the program manually if it hangs, and then recover from a checkpoint
- How to actually structure the multiprocessing code
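For the first question, a reasonable default is to derive the worker count from the CPU count and keep a few cores free for the system. This is just a sketch; the right number also depends on how I/O-bound get_available_channels turns out to be:
import os

def default_num_processes(reserved_cores=4):
    """Use most of the cores, leaving a few for the OS."""
    total = os.cpu_count() or 8  # os.cpu_count() can return None
    return max(1, total - reserved_cores)
# On a 32-core machine this returns 28, which is the default used in the script below.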
With these questions in mind, and with the help of an LLM, the code gradually took shape.
Below I analyze each question, show the corresponding piece of code, and finish with the complete script. The imports:
import os
import sys
import multiprocessing as mp
from functools import partial
from ad_cloud.adrn.data_seeker.dataclip import get_available_channels
from tqdm import tqdm
import time
import pandas as pd
import json
import signal
import hashlib
import uuid
import threading
import queue
import concurrent.futures
# Global state
# stop_event tells every thread to stop: when the program needs to shut down,
# stop_event.set() is called and any thread watching the event winds down.
stop_event = threading.Event()  # notifies all threads to stop
# result_queue collects results from the workers so the main program can aggregate them later.
result_queue = queue.Queue()  # collects results
Letting the program exit gracefully
Here I wrote a class that handles termination signals such as the familiar Ctrl+C.
The goal is a graceful-shutdown mechanism: through a global event and signal handlers, the program can react to a termination signal by telling all threads to stop, and force-quit only if necessary. That makes it possible to shut the program down safely in a multi-threaded setting without losing data or leaking resources.
class GracefulKiller:
    """Handle SIGINT and SIGTERM so the program can exit gracefully."""
    kill_now = False
    kill_count = 0

    def __init__(self):
        # Register the signal handlers
        # SIGINT: what the user generates with Ctrl+C
        signal.signal(signal.SIGINT, self.exit_gracefully)
        # SIGTERM: the termination signal usually sent by the system or from the command line
        signal.signal(signal.SIGTERM, self.exit_gracefully)

    def exit_gracefully(self, signum, frame):
        """Signal handler (handlers receive both the signal number and the current stack frame)."""
        self.kill_count += 1
        # First termination signal: shut down gracefully
        # (in my actual runs this first stage did not always manage to stop everything)
        if self.kill_count == 1:
            print(f"\nReceived signal {signum}. Initiating graceful shutdown...")
            print("Press Ctrl+C again to force quit (may cause data loss).")
            self.kill_now = True
            stop_event.set()  # set the stop event so all threads are notified
        else:
            # Second signal: force-quit without saving progress; data may be lost
            print("\nForcing immediate termination...")
            os._exit(1)

def worker_init():
    """Worker-process initializer.
    Make child processes ignore SIGINT; the main process handles it.
    """
    signal.signal(signal.SIGINT, signal.SIG_IGN)
Checkpointing
With 3 million records the run is bound to take a long time. As with model training, if you worry about it dying halfway and having to restart from scratch, you need checkpoints so that an interrupted run can resume where it left off.
The idea: since we use multiprocessing, the input is split into chunks and each chunk is processed quickly by the pool. After a chunk finishes, we write a checkpoint that records the key intermediate results to a shared file, so even if the run dies in the middle, nothing is lost.
def get_job_id(file_path):
    """Derive a unique job ID from the input file path."""
    return hashlib.md5(file_path.encode()).hexdigest()[:8]

def get_progress_dir():
    """Directory for progress files; an environment variable takes precedence."""
    progress_dir = os.environ.get('PROGRESS_DIR', '/mnt/output-data-storage/zhongqiang/adrn_progress')
    os.makedirs(progress_dir, exist_ok=True)
    return progress_dir

def get_progress_file(job_id):
    """Path of the progress file for this job."""
    return os.path.join(get_progress_dir(), f"{job_id}_progress.json")

def save_progress(job_id, current_chunk, processed_count, counters):
    """Save the current progress to disk."""
    progress = {
        'current_chunk': current_chunk,
        'processed_count': processed_count,
        'counters': counters,
        'timestamp': time.time()
    }
    progress_file = get_progress_file(job_id)
    with open(progress_file, 'w') as f:
        json.dump(progress, f)
    print(f"Progress saved to {progress_file}")

def load_progress(job_id):
    """Load saved progress from disk, if any."""
    progress_file = get_progress_file(job_id)
    if not os.path.exists(progress_file):
        return None
    try:
        with open(progress_file, 'r') as f:
            progress = json.load(f)
        print(f"Loaded progress from {progress_file}")
        return progress
    except Exception as e:
        print(f"Error loading progress: {e}")
        return None
From the input file name we generate a unique job_id, and from it a {job_id}_progress.json file that stores the checkpoint. The checkpoint contains:
progress = {
    'current_chunk': current_chunk,
    'processed_count': processed_count,
    'counters': counters,
    'timestamp': time.time()
}
That is: the index of the current chunk, the number of ADRNs processed so far, the running counters, and a timestamp. On a restart we resume from the stored current_chunk (it already points at the next unprocessed chunk) and keep counting from there.
When the program starts, if such a checkpoint file exists we load it and continue; otherwise we start from the beginning.
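For example, a resume roughly looks like this (the file name, job ID and counter values are made up purely to illustrate the shape of the data):
job_id = get_job_id("adrn_list.txt")  # hypothetical input file; yields something like "a1b2c3d4"
save_progress(job_id, current_chunk=12, processed_count=60000,
              counters=[1500, 300, 200, 58000])
# Later, after a crash or a Ctrl+C:
progress = load_progress(job_id)
start_chunk = progress['current_chunk'] if progress else 0  # resume at chunk 12, or start over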
Multiprocessing
So what should the multiprocessing code look like?
First, a function that handles a single adrn: given one adrn, it checks which of the two channels are present.
def process_adrn(adrn, select_channels):
    """Process a single ADRN and return its tally.
    Return format: (both_channels, map_only, traffic_only, neither)
    """
    try:
        channels_list = get_available_channels(adrn)
        if select_channels[0] in channels_list and select_channels[1] in channels_list:
            return (1, 0, 0, 0)  # has both channels
        elif select_channels[0] in channels_list:
            return (0, 1, 0, 0)  # only the map channel
        elif select_channels[1] in channels_list:
            return (0, 0, 1, 0)  # only the traffic-sign channel
        else:
            return (0, 0, 0, 1)  # neither channel
    except Exception as e:
        print(f"Error processing {adrn}: {e}")
        return (0, 0, 0, 0)  # errors are not counted
Next, a function that handles one chunk, i.e. a batch of clip adrns, processed in parallel by the pool.
def process_chunk(pool, chunk, process_func, chunk_index, total_chunks, killer):
    """Process one chunk of data, with support for graceful termination.
    pool: the multiprocessing pool used to process items in parallel
    chunk: the slice of the ADRN list to process
    process_func: the function that handles a single ADRN
    chunk_index: index of the current chunk
    total_chunks: total number of chunks
    killer: used to detect a pending termination request
    """
    print(f"\nProcessing chunk {chunk_index+1}/{total_chunks}")
    # Progress bar for this chunk
    pbar = tqdm(total=len(chunk), desc=f"Chunk {chunk_index+1}/{total_chunks}")
    results = []
    try:
        # imap_unordered fans the items of `chunk` out to the worker processes and yields
        # each result as soon as it is ready, not in input order. Since we only care about
        # the per-item tallies and not their order, this lets us consume finished results
        # immediately (and check for a termination request between results) instead of
        # waiting for items to complete in sequence, which is what imap would do.
        for result in pool.imap_unordered(process_func, chunk):
            # Stop early if a termination signal was received
            if killer.kill_now:
                pbar.close()
                return results, False
            if result is not None:
                results.append(result)
                result_queue.put(result)  # also push the result onto the shared queue
            pbar.update(1)
    except Exception as e:
        print(f"Error processing chunk {chunk_index+1}: {e}")
    pbar.close()
    return results, True
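As a standalone illustration of the imap_unordered behaviour described in the comment above, here is a toy example unrelated to the ADRN task:
import multiprocessing as mp
import random
import time

def slow_square(x):
    time.sleep(random.random() * 0.1)  # simulate uneven per-item latency
    return x * x

if __name__ == '__main__':
    with mp.Pool(4) as pool:
        # imap preserves input order; imap_unordered yields results as they finish,
        # so the second list can come back in a different order on every run.
        print(list(pool.imap(slow_square, range(8))))
        print(list(pool.imap_unordered(slow_square, range(8))))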
The main function
def main():
    """Main entry point."""
    # Argument parsing and initialization
    if len(sys.argv) < 2:
        print("Usage: python script.py <file_path> [num_processes] [chunk_size]")
        sys.exit(1)
    # This pattern of optional positional arguments is worth noting
    file_path = sys.argv[1]
    # My machine has 32 cores; I leave 4 for the system and use 28 by default
    num_processes = int(sys.argv[2]) if len(sys.argv) > 2 else 28
    chunk_size = int(sys.argv[3]) if len(sys.argv) > 3 else 5000
    # Unique job ID derived from the input file
    job_id = get_job_id(file_path)
    print(f"Job ID: {job_id}")
    # Install the signal handlers for graceful shutdown
    killer = GracefulKiller()
    # Check that the input file exists
    if not os.path.exists(file_path):
        print(f"Error: File '{file_path}' does not exist")
        sys.exit(1)
    # Read the input file
    # Note: this still loads the whole file into memory
    print(f"Reading input file: {file_path}")
    with open(file_path, 'r') as file:
        adrns = file.readlines()
    adrns = [adrn.strip() for adrn in adrns]
    adrns_count = len(adrns)
    print(f"Total Adrns: {adrns_count}")
    print(f"Using {num_processes} processes")
    print(f"Chunk size: {chunk_size}")
    print(f"Progress directory: {get_progress_dir()}")
    select_channels = ["/pilot/map/position", "/perception/frame/traffic_sign"]
    # Split the ADRN list into fixed-size chunks (only after everything has been read in)
    chunks = [adrns[i:i+chunk_size] for i in range(0, adrns_count, chunk_size)]
    # Initialize the counters
    has_both_channels_count = 0
    has_map_channel_count = 0
    has_perception_channel_count = 0
    has_no_channel_count = 0
    # Check for saved progress
    progress = load_progress(job_id)
    # If a checkpoint exists, restore the counters from it
    if progress:
        start_chunk = progress['current_chunk']
        processed_count = progress['processed_count']
        counters = progress['counters']
        has_both_channels_count = counters[0]
        has_map_channel_count = counters[1]
        has_perception_channel_count = counters[2]
        has_no_channel_count = counters[3]
        print(f"Resuming from chunk {start_chunk+1}/{len(chunks)}")
        print(f"Already processed {processed_count}/{adrns_count} ADRNs")
        print(f"Current counters: both={has_both_channels_count}, map={has_map_channel_count}, "
              f"traffic={has_perception_channel_count}, neither={has_no_channel_count}")
    else:
        start_chunk = 0
        processed_count = 0
        print("Starting from the beginning")
    start_time = time.time()
    # Partially apply process_adrn so select_channels is fixed
    process_func = partial(process_adrn, select_channels=select_channels)
    # Create the process pool
    pool = mp.Pool(processes=num_processes, initializer=worker_init)
    # Continue from the chunk where the last run stopped
    for i in range(start_chunk, len(chunks)):
        chunk = chunks[i]
        # Process the current chunk
        results, success = process_chunk(
            pool, chunk, process_func, i, len(chunks), killer
        )
        # If the chunk finished, aggregate its results
        if success:
            # Per-chunk tallies
            chunk_both = sum(r[0] for r in results)
            chunk_map = sum(r[1] for r in results)
            chunk_traffic = sum(r[2] for r in results)
            chunk_neither = sum(r[3] for r in results)
            # Update the global counters
            has_both_channels_count += chunk_both
            has_map_channel_count += chunk_map
            has_perception_channel_count += chunk_traffic
            has_no_channel_count += chunk_neither
            # Update the number of processed ADRNs
            processed_count += len(chunk)
            if processed_count > adrns_count:
                processed_count = adrns_count
            # Save the checkpoint
            save_progress(
                job_id,
                i + 1,
                processed_count,
                [has_both_channels_count, has_map_channel_count, has_perception_channel_count, has_no_channel_count]
            )
            # Elapsed time and estimated time remaining
            elapsed = time.time() - start_time
            if processed_count > 0:
                rate = processed_count / elapsed
                remaining = (adrns_count - processed_count) / rate
                print(f"Processed: {processed_count}/{adrns_count} ({processed_count/adrns_count*100:.2f}%)")
                print(f"Elapsed: {elapsed/60:.2f} minutes, Remaining: {remaining/60:.2f} minutes")
                print(f"Rate: {rate:.2f} ADRNs/second")
            # Save intermediate results every 10 chunks
            if (i + 1) % 10 == 0:
                save_intermediate_results(
                    job_id, i + 1, processed_count, adrns_count,
                    has_both_channels_count, has_map_channel_count,
                    has_perception_channel_count, has_no_channel_count,
                    elapsed
                )
        # Check whether we need to stop
        if killer.kill_now:
            print("Gracefully shutting down...")
            # Save the checkpoint
            save_progress(
                job_id,
                i,  # save the index of the current chunk, not the next one
                processed_count,
                [has_both_channels_count, has_map_channel_count, has_perception_channel_count, has_no_channel_count]
            )
            # Tear down the process pool
            pool.terminate()
            pool.join()
            print("Shutdown complete. Progress saved.")
            sys.exit(0)
    # All chunks processed: close the pool
    pool.close()
    pool.join()
    # Remove the progress file
    progress_file = get_progress_file(job_id)
    if os.path.exists(progress_file):
        os.remove(progress_file)
        print(f"Progress file {progress_file} deleted")
    # Print the final statistics
    print("\n===== Final Results =====")
    print(f"Adrns with both channels: {has_both_channels_count}/{adrns_count} ({has_both_channels_count/adrns_count*100:.2f}%)")
    print(f"Adrns with map channel: {has_map_channel_count}/{adrns_count} ({has_map_channel_count/adrns_count*100:.2f}%)")
    print(f"Adrns with traffic_sign channel: {has_perception_channel_count}/{adrns_count} ({has_perception_channel_count/adrns_count*100:.2f}%)")
    print(f"Adrns with no map and traffic_sign channel: {has_no_channel_count}/{adrns_count} ({has_no_channel_count/adrns_count*100:.2f}%)")
    # Total processing time
    total_time = time.time() - start_time
    print(f"\nTotal processing time: {total_time/60:.2f} minutes")
    # Save the final results to CSV
    save_final_results(
        job_id, adrns_count, has_both_channels_count, has_map_channel_count,
        has_perception_channel_count, has_no_channel_count, total_time
    )
Complete multiprocessing code
Finally, the complete script that actually runs:
import os
import sys
import multiprocessing as mp
from functools import partial
from ad_cloud.adrn.data_seeker.dataclip import get_available_channels
from tqdm import tqdm
import time
import pandas as pd
import json
import signal
import hashlib
import uuid
import threading
import queue
import concurrent.futures
# Global state
stop_event = threading.Event()  # notifies all threads to stop
result_queue = queue.Queue()  # collects results
class GracefulKiller:
"""处理SIGINT和SIGTERM信号,允许程序优雅退出"""
kill_now = False
kill_count = 0
def __init__(self):
# 注册信号处理函数
signal.signal(signal.SIGINT, self.exit_gracefully)
signal.signal(signal.SIGTERM, self.exit_gracefully)
def exit_gracefully(self, signum, frame):
"""信号处理函数"""
self.kill_count += 1
if self.kill_count == 1:
print(f"\nReceived signal {signum}. Initiating graceful shutdown...")
print("Press Ctrl+C again to force quit (may cause data loss).")
self.kill_now = True
stop_event.set() # 设置停止事件,通知所有线程
else:
print("\nForcing immediate termination...")
# 强制退出,不保存进度
os._exit(1)
def process_adrn(adrn, select_channels):
"""处理单个ADRN并返回统计结果
返回格式: (both_channels, map_only, traffic_only, neither)
"""
try:
channels_list = get_available_channels(adrn)
if select_channels[0] in channels_list and select_channels[1] in channels_list:
return (1, 0, 0, 0) # 两个通道都有
elif select_channels[0] in channels_list:
return (0, 1, 0, 0) # 只有地图通道
elif select_channels[1] in channels_list:
return (0, 0, 1, 0) # 只有交通标志通道
else:
return (0, 0, 0, 1) # 两个通道都没有
except Exception as e:
print(f"Error processing {adrn}: {e}")
return (0, 0, 0, 0) # 出错情况不计入统计
def worker_init():
"""工作进程初始化函数
设置子进程忽略SIGINT信号,由主进程处理
"""
signal.signal(signal.SIGINT, signal.SIG_IGN)
def get_job_id(file_path):
"""生成基于输入文件的唯一作业ID"""
return hashlib.md5(file_path.encode()).hexdigest()[:8]
def get_progress_dir():
"""获取进度文件目录,优先使用环境变量指定的目录"""
progress_dir = os.environ.get('PROGRESS_DIR', '/tmp/adrn_progress')
os.makedirs(progress_dir, exist_ok=True)
return progress_dir
def get_progress_file(job_id):
"""获取进度文件路径"""
return os.path.join(get_progress_dir(), f"{job_id}_progress.json")
def save_progress(job_id, current_chunk, processed_count, counters):
"""保存当前进度到文件"""
progress = {
'current_chunk': current_chunk,
'processed_count': processed_count,
'counters': counters,
'timestamp': time.time()
}
progress_file = get_progress_file(job_id)
with open(progress_file, 'w') as f:
json.dump(progress, f)
print(f"Progress saved to {progress_file}")
def load_progress(job_id):
"""从文件加载进度"""
progress_file = get_progress_file(job_id)
if not os.path.exists(progress_file):
return None
try:
with open(progress_file, 'r') as f:
progress = json.load(f)
print(f"Loaded progress from {progress_file}")
return progress
except Exception as e:
print(f"Error loading progress: {e}")
return None
def save_intermediate_results(job_id, chunk_num, processed, total, both, map_only, traffic_only, neither, elapsed):
    """Save intermediate results to CSV."""
    results = {
        'chunk_number': [chunk_num],
        'processed_adrns': [processed],
        'total_adrns': [total],
        'both_channels': [both],
        'map_only': [map_only],
        'traffic_only': [traffic_only],
        'neither': [neither],
        'elapsed_minutes': [elapsed/60],
        'rate_adrns_per_second': [processed/elapsed]
    }
    df = pd.DataFrame(results)
    csv_file = os.path.join(get_progress_dir(), f"{job_id}_intermediate.csv")
    # Create the file with a header on the first write; append without a header afterwards,
    # which also works correctly when resuming from a checkpoint
    if not os.path.exists(csv_file):
        df.to_csv(csv_file, index=False)
    else:
        df.to_csv(csv_file, index=False, mode='a', header=False)

def save_final_results(job_id, total, both, map_only, traffic_only, neither, total_time):
    """Save the final results to CSV."""
    results = {
        'total_adrns': [total],
        'both_channels': [both],
        'both_channels_percent': [both/total*100],
        'map_only': [map_only],
        'map_only_percent': [map_only/total*100],
        'traffic_only': [traffic_only],
        'traffic_only_percent': [traffic_only/total*100],
        'neither': [neither],
        'neither_percent': [neither/total*100],
        'total_time_minutes': [total_time/60],
        'rate_adrns_per_second': [total/total_time]
    }
    df = pd.DataFrame(results)
    csv_file = os.path.join(get_progress_dir(), f"{job_id}_final_results.csv")
    df.to_csv(csv_file, index=False)
    print(f"Final results saved to {csv_file}")
def process_chunk(pool, chunk, process_func, chunk_index, total_chunks, killer):
    """Process one chunk of data, with support for graceful termination."""
    print(f"\nProcessing chunk {chunk_index+1}/{total_chunks}")
    # Progress bar for this chunk
    pbar = tqdm(total=len(chunk), desc=f"Chunk {chunk_index+1}/{total_chunks}")
    results = []
    try:
        # imap_unordered yields results as soon as they are ready, which also lets us
        # check for a termination request between results
        for result in pool.imap_unordered(process_func, chunk):
            if killer.kill_now:
                pbar.close()
                return results, False
            if result is not None:
                results.append(result)
                result_queue.put(result)  # also push the result onto the shared queue
            pbar.update(1)
    except Exception as e:
        print(f"Error processing chunk {chunk_index+1}: {e}")
    pbar.close()
    return results, True
def main():
"""主函数:程序入口点"""
# 参数检查和初始化
if len(sys.argv) < 2:
print("Usage: python script.py <file_path> [num_processes] [chunk_size]")
sys.exit(1)
file_path = sys.argv[1]
num_processes = int(sys.argv[2]) if len(sys.argv) > 2 else 28
chunk_size = int(sys.argv[3]) if len(sys.argv) > 3 else 5000
# 生成基于输入文件的唯一作业ID
job_id = get_job_id(file_path)
print(f"Job ID: {job_id}")
# 初始化信号处理器,允许优雅退出
killer = GracefulKiller()
# 检查输入文件是否存在
if not os.path.exists(file_path):
print(f"Error: File '{file_path}' does not exist")
sys.exit(1)
# 读取输入文件
print(f"Reading input file: {file_path}")
with open(file_path, 'r') as file:
adrns = file.readlines()
adrns = [adrn.strip() for adrn in adrns]
adrns_count = len(adrns)
print(f"Total Adrns: {adrns_count}")
print(f"Using {num_processes} processes")
print(f"Chunk size: {chunk_size}")
print(f"Progress directory: {get_progress_dir()}")
select_channels = ["/pilot/map/position", "/perception/frame/traffic_sign"]
# 将ADRN列表分割成固定大小的数据块
chunks = [adrns[i:i+chunk_size] for i in range(0, adrns_count, chunk_size)]
# 初始化统计计数器
has_both_channels_count = 0
has_map_channel_count = 0
has_perception_channel_count = 0
has_no_channel_count = 0
# 检查是否有保存的进度
progress = load_progress(job_id)
if progress:
start_chunk = progress['current_chunk']
processed_count = progress['processed_count']
counters = progress['counters']
has_both_channels_count = counters[0]
has_map_channel_count = counters[1]
has_perception_channel_count = counters[2]
has_no_channel_count = counters[3]
print(f"Resuming from chunk {start_chunk+1}/{len(chunks)}")
print(f"Already processed {processed_count}/{adrns_count} ADRNs")
print(f"Current counters: both={has_both_channels_count}, map={has_map_channel_count}, "
f"traffic={has_perception_channel_count}, neither={has_no_channel_count}")
else:
start_chunk = 0
processed_count = 0
print("Starting from the beginning")
start_time = time.time()
# 创建部分应用函数,固定select_channels参数
process_func = partial(process_adrn, select_channels=select_channels)
# 创建进程池
pool = mp.Pool(processes=num_processes, initializer=worker_init)
# 从上次中断的块开始处理
for i in range(start_chunk, len(chunks)):
chunk = chunks[i]
# 处理当前块
results, success = process_chunk(
pool, chunk, process_func, i, len(chunks), killer
)
# 如果处理成功,汇总结果
if success:
# 汇总当前块的统计结果
chunk_both = sum(r[0] for r in results)
chunk_map = sum(r[1] for r in results)
chunk_traffic = sum(r[2] for r in results)
chunk_neither = sum(r[3] for r in results)
# 更新全局计数器
has_both_channels_count += chunk_both
has_map_channel_count += chunk_map
has_perception_channel_count += chunk_traffic
has_no_channel_count += chunk_neither
# 更新已处理ADRN数量
processed_count += len(chunk)
if processed_count > adrns_count:
processed_count = adrns_count
# 保存当前进度
save_progress(
job_id,
i + 1,
processed_count,
[has_both_channels_count, has_map_channel_count, has_perception_channel_count, has_no_channel_count]
)
# 计算已用时间和预计剩余时间
elapsed = time.time() - start_time
if processed_count > 0:
rate = processed_count / elapsed
remaining = (adrns_count - processed_count) / rate
print(f"Processed: {processed_count}/{adrns_count} ({processed_count/adrns_count*100:.2f}%)")
print(f"Elapsed: {elapsed/60:.2f} minutes, Remaining: {remaining/60:.2f} minutes")
print(f"Rate: {rate:.2f} ADRNs/second")
# 每处理完10个块保存一次中间结果
if (i + 1) % 10 == 0:
save_intermediate_results(
job_id, i + 1, processed_count, adrns_count,
has_both_channels_count, has_map_channel_count,
has_perception_channel_count, has_no_channel_count,
elapsed
)
# 检查是否需要停止
if killer.kill_now:
print("Gracefully shutting down...")
# 保存当前进度
save_progress(
job_id,
i, # 保存当前块的索引,而不是下一个块
processed_count,
[has_both_channels_count, has_map_channel_count, has_perception_channel_count, has_no_channel_count]
)
# 清理进程池
pool.terminate()
pool.join()
print("Shutdown complete. Progress saved.")
sys.exit(0)
# 处理完成后关闭进程池
pool.close()
    # pool.join() blocks the main process until every worker in the pool has finished its tasks.
    # Without it, the main process could exit early, terminating workers mid-task or leaving
    # the program's logic in an inconsistent state.
pool.join()
# 删除进度文件
progress_file = get_progress_file(job_id)
if os.path.exists(progress_file):
os.remove(progress_file)
print(f"Progress file {progress_file} deleted")
# 打印最终统计结果
print("\n===== Final Results =====")
print(f"Adrns with both channels: {has_both_channels_count}/{adrns_count} ({has_both_channels_count/adrns_count*100:.2f}%)")
print(f"Adrns with map channel: {has_map_channel_count}/{adrns_count} ({has_map_channel_count/adrns_count*100:.2f}%)")
print(f"Adrns with traffic_sign channel: {has_perception_channel_count}/{adrns_count} ({has_perception_channel_count/adrns_count*100:.2f}%)")
print(f"Adrns with no map and traffic_sign channel: {has_no_channel_count}/{adrns_count} ({has_no_channel_count/adrns_count*100:.2f}%)")
# 计算并打印总耗时
total_time = time.time() - start_time
print(f"\nTotal processing time: {total_time/60:.2f} minutes")
# 保存最终结果到CSV
save_final_results(
job_id, adrns_count, has_both_channels_count, has_map_channel_count,
has_perception_channel_count, has_no_channel_count, total_time
)
if __name__ == '__main__':
main()
One remaining issue: the script above still reads the entire file into memory before splitting it into chunks.
What if reading everything at once blows up memory? Then the file has to be read as a stream; here is a reference implementation for that case as well.
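The core of the change is the read loop: instead of readlines(), read the file lazily and hand the pool one chunk at a time. A minimal sketch of just that idea (the full script below layers checkpointing and signal handling on top; iter_chunks is my own helper name):
def iter_chunks(file_path, chunk_size):
    """Lazily yield lists of at most chunk_size non-empty, stripped lines."""
    chunk = []
    with open(file_path, 'r') as f:
        for line in f:  # the file object is iterated line by line, never fully loaded
            adrn = line.strip()
            if adrn:
                chunk.append(adrn)
            if len(chunk) >= chunk_size:
                yield chunk
                chunk = []
    if chunk:  # whatever is left at the end of the file
        yield chunk
# Usage sketch: for chunk in iter_chunks(file_path, 10000): results = pool.map(process_func, chunk)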
import os
import sys
import multiprocessing as mp
from functools import partial
from ad_cloud.adrn.data_seeker.dataclip import get_available_channels
from tqdm import tqdm
import time
import pandas as pd
import json
import signal
import hashlib
import uuid
import threading
import queue
import mmap
import gc
# Global state
stop_event = threading.Event()  # notifies all threads to stop
result_queue = queue.Queue()  # collects results
PROCESS_POOL = None  # global process pool, so the signal handler can terminate it
class GracefulKiller:
"""处理SIGINT和SIGTERM信号,允许程序优雅退出"""
kill_now = False
kill_count = 0
def __init__(self):
# 注册信号处理函数
signal.signal(signal.SIGINT, self.exit_gracefully)
signal.signal(signal.SIGTERM, self.exit_gracefully)
def exit_gracefully(self, signum, frame):
"""信号处理函数"""
self.kill_count += 1
if self.kill_count == 1:
print(f"\nReceived signal {signum}. Initiating graceful shutdown...")
print("Press Ctrl+C again to force quit (may cause data loss).")
self.kill_now = True
stop_event.set() # 设置停止事件,通知所有线程
else:
print("\nForcing immediate termination...")
# 强制退出,不保存进度
if PROCESS_POOL:
PROCESS_POOL.terminate()
PROCESS_POOL.join()
os._exit(1)
def process_adrn(adrn, select_channels):
"""处理单个ADRN并返回统计结果"""
try:
channels_list = get_available_channels(adrn)
if select_channels[0] in channels_list and select_channels[1] in channels_list:
return (1, 0, 0, 0) # 两个通道都有
elif select_channels[0] in channels_list:
return (0, 1, 0, 0) # 只有地图通道
elif select_channels[1] in channels_list:
return (0, 0, 1, 0) # 只有交通标志通道
else:
return (0, 0, 0, 1) # 两个通道都没有
except Exception as e:
print(f"Error processing {adrn}: {e}")
return (0, 0, 0, 0) # 出错情况不计入统计
def worker_init():
"""工作进程初始化函数"""
signal.signal(signal.SIGINT, signal.SIG_IGN)
def get_job_id(file_path):
"""生成基于输入文件的唯一作业ID"""
return hashlib.md5(file_path.encode()).hexdigest()[:8]
def get_progress_dir():
"""获取进度文件目录"""
progress_dir = os.environ.get('PROGRESS_DIR', '/tmp/adrn_progress')
os.makedirs(progress_dir, exist_ok=True)
return progress_dir
def get_progress_file(job_id):
"""获取进度文件路径"""
return os.path.join(get_progress_dir(), f"{job_id}_progress.json")
def save_progress(job_id, current_position, processed_count, counters):
"""保存当前进度到文件"""
progress = {
'current_position': current_position,
'processed_count': processed_count,
'counters': counters,
'timestamp': time.time()
}
progress_file = get_progress_file(job_id)
with open(progress_file, 'w') as f:
json.dump(progress, f)
print(f"Progress saved to {progress_file}")
def load_progress(job_id):
"""从文件加载进度"""
progress_file = get_progress_file(job_id)
if not os.path.exists(progress_file):
return None
try:
with open(progress_file, 'r') as f:
progress = json.load(f)
print(f"Loaded progress from {progress_file}")
return progress
except Exception as e:
print(f"Error loading progress: {e}")
return None
def save_intermediate_results(job_id, processed, total, both, map_only, traffic_only, neither, elapsed):
"""保存中间处理结果"""
results = {
'processed_adrns': [processed],
'total_adrns': [total],
'both_channels': [both],
'map_only': [map_only],
'traffic_only': [traffic_only],
'neither': [neither],
'elapsed_minutes': [elapsed/60],
'rate_adrns_per_second': [processed/elapsed]
}
df = pd.DataFrame(results)
csv_file = os.path.join(get_progress_dir(), f"{job_id}_intermediate.csv")
# 检查文件是否存在,不存在则创建并写入表头
if not os.path.exists(csv_file):
df.to_csv(csv_file, index=False)
else:
df.to_csv(csv_file, index=False, mode='a', header=False)
def save_final_results(job_id, total, both, map_only, traffic_only, neither, total_time):
"""保存最终结果到CSV"""
results = {
'total_adrns': [total],
'both_channels': [both],
'both_channels_percent': [both/total*100],
'map_only': [map_only],
'map_only_percent': [map_only/total*100],
'traffic_only': [traffic_only],
'traffic_only_percent': [traffic_only/total*100],
'neither': [neither],
'neither_percent': [neither/total*100],
'total_time_minutes': [total_time/60],
'rate_adrns_per_second': [total/total_time]
}
df = pd.DataFrame(results)
csv_file = os.path.join(get_progress_dir(), f"{job_id}_final_results.csv")
df.to_csv(csv_file, index=False)
print(f"Final results saved to {csv_file}")
def count_total_lines(file_path):
    """Count the total number of lines in the file efficiently."""
    print("Counting total lines in file...")
    with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
        try:
            # mmap makes counting newlines in a large file much faster
            buf = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
            lines = 0
            read_size = 1024 * 1024  # 1 MB blocks
            buf_size = len(buf)
            for i in range(0, buf_size, read_size):
                lines += buf[i:i+read_size].count(b'\n')
                if stop_event.is_set():
                    break
            return lines
        except Exception:
            # Fallback: slower, but works where mmap is unavailable (e.g. empty files)
            print("Falling back to slower line count method...")
            return sum(1 for _ in f)
def process_chunk(chunk, process_func, killer, pbar=None):
    """Process one chunk; process_func maps the chunk to an iterable of per-item results."""
    results = []
    try:
        for result in process_func(chunk):
            if killer.kill_now:
                return results, False
            if result is not None:
                results.append(result)
            if pbar:
                pbar.update(1)
    except Exception as e:
        print(f"Error processing chunk: {e}")
    return results, True
def stream_process_file(file_path, num_processes, chunk_size, job_id, killer):
"""流式处理文件,避免一次性加载"""
global PROCESS_POOL
select_channels = ["/pilot/map/position", "/perception/frame/traffic_sign"]
process_func = partial(process_adrn, select_channels=select_channels)
# 计算总行数
total_lines = count_total_lines(file_path)
print(f"Total lines in file: {total_lines}")
# 初始化计数器
has_both_channels_count = 0
has_map_channel_count = 0
has_perception_channel_count = 0
has_no_channel_count = 0
processed_count = 0
current_position = 0
# 检查进度
progress = load_progress(job_id)
if progress:
current_position = progress['current_position']
processed_count = progress['processed_count']
counters = progress['counters']
has_both_channels_count = counters[0]
has_map_channel_count = counters[1]
has_perception_channel_count = counters[2]
has_no_channel_count = counters[3]
print(f"Resuming from position {current_position}")
print(f"Already processed {processed_count}/{total_lines} ADRNs")
start_time = time.time()
# 创建进程池
PROCESS_POOL = mp.Pool(processes=num_processes, initializer=worker_init)
try:
chunk = []
line_count = 0
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
# 定位到上次处理的位置
if current_position > 0:
f.seek(current_position)
# 创建进度条
pbar = tqdm(total=total_lines, initial=processed_count, desc="Processing ADRNs")
            # Read and process the file line by line. Use readline() rather than iterating
            # the file object directly: iterating a text file disables tell(), which is
            # needed further down to record the resume position.
            while True:
                line = f.readline()
                if not line:
                    break
                if killer.kill_now:
                    break
                adrn = line.strip()
                if adrn:  # skip empty lines
                    chunk.append(adrn)
                    line_count += 1
                # Process the chunk once it reaches the configured size
                if len(chunk) >= chunk_size:
                    # Fan the chunk out to the process pool; note the argument order:
                    # the chunk first, then the function that maps it through the pool
                    pool_process = lambda items: PROCESS_POOL.imap_unordered(process_func, items)
                    results, success = process_chunk(chunk, pool_process, killer, pbar)
if not success:
break
# 汇总结果
chunk_both = sum(r[0] for r in results)
chunk_map = sum(r[1] for r in results)
chunk_traffic = sum(r[2] for r in results)
chunk_neither = sum(r[3] for r in results)
has_both_channels_count += chunk_both
has_map_channel_count += chunk_map
has_perception_channel_count += chunk_traffic
has_no_channel_count += chunk_neither
processed_count += len(chunk)
current_position = f.tell() # 保存当前文件位置
# 保存进度
save_progress(
job_id,
current_position,
processed_count,
[has_both_channels_count, has_map_channel_count, has_perception_channel_count, has_no_channel_count]
)
# 计算状态
elapsed = time.time() - start_time
if processed_count > 0:
rate = processed_count / elapsed
remaining = (total_lines - processed_count) / rate
print(f"\nProcessed: {processed_count}/{total_lines} ({processed_count/total_lines*100:.2f}%)")
print(f"Elapsed: {elapsed/60:.2f} minutes, Remaining: {remaining/60:.2f} minutes")
print(f"Rate: {rate:.2f} ADRNs/second")
# 保存中间结果
if processed_count % (10 * chunk_size) == 0:
save_intermediate_results(
job_id, processed_count, total_lines,
has_both_channels_count, has_map_channel_count,
has_perception_channel_count, has_no_channel_count,
elapsed
)
# 清空块并触发垃圾回收
chunk = []
gc.collect()
            # Process whatever is left over at the end of the file
            if chunk and not killer.kill_now:
                pool_process = lambda items: PROCESS_POOL.imap_unordered(process_func, items)
                results, _ = process_chunk(chunk, pool_process, killer, pbar)
chunk_both = sum(r[0] for r in results)
chunk_map = sum(r[1] for r in results)
chunk_traffic = sum(r[2] for r in results)
chunk_neither = sum(r[3] for r in results)
has_both_channels_count += chunk_both
has_map_channel_count += chunk_map
has_perception_channel_count += chunk_traffic
has_no_channel_count += chunk_neither
processed_count += len(chunk)
current_position = f.tell()
save_progress(
job_id,
current_position,
processed_count,
[has_both_channels_count, has_map_channel_count, has_perception_channel_count, has_no_channel_count]
)
pbar.close()
finally:
# 清理进程池
if PROCESS_POOL:
PROCESS_POOL.close()
PROCESS_POOL.join()
PROCESS_POOL = None
# 检查是否被中断
if killer.kill_now:
print("Shutdown complete. Progress saved.")
sys.exit(0)
# 删除进度文件
progress_file = get_progress_file(job_id)
if os.path.exists(progress_file):
os.remove(progress_file)
print(f"Progress file {progress_file} deleted")
# 打印结果
print("\n===== Final Results =====")
print(f"Adrns with both channels: {has_both_channels_count}/{total_lines} ({has_both_channels_count/total_lines*100:.2f}%)")
print(f"Adrns with map channel: {has_map_channel_count}/{total_lines} ({has_map_channel_count/total_lines*100:.2f}%)")
print(f"Adrns with traffic_sign channel: {has_perception_channel_count}/{total_lines} ({has_perception_channel_count/total_lines*100:.2f}%)")
print(f"Adrns with no map and traffic_sign channel: {has_no_channel_count}/{total_lines} ({has_no_channel_count/total_lines*100:.2f}%)")
# 计算耗时
total_time = time.time() - start_time
print(f"\nTotal processing time: {total_time/60:.2f} minutes")
# 保存最终结果
save_final_results(
job_id, total_lines, has_both_channels_count, has_map_channel_count,
has_perception_channel_count, has_no_channel_count, total_time
)
def main():
"""主函数"""
if len(sys.argv) < 2:
print("Usage: python script.py <file_path> [num_processes] [chunk_size]")
sys.exit(1)
file_path = sys.argv[1]
num_processes = int(sys.argv[2]) if len(sys.argv) > 2 else 28
chunk_size = int(sys.argv[3]) if len(sys.argv) > 3 else 10000 # 默认每块10000行
job_id = get_job_id(file_path)
print(f"Job ID: {job_id}")
killer = GracefulKiller()
if not os.path.exists(file_path):
print(f"Error: File '{file_path}' does not exist")
sys.exit(1)
print(f"Reading file in streaming mode, chunk size: {chunk_size}")
print(f"Using {num_processes} processes")
stream_process_file(file_path, num_processes, chunk_size, job_id, killer)
if __name__ == '__main__':
main()
Wrapping up
The takeaway from this task: using multiprocessing well at work is a real productivity boost. The job went from an estimated 30-plus days to finishing in about 10 hours, which is a huge improvement.
I think the design above is well worth studying and borrowing from: termination-signal handling, checkpointing, chunking large inputs, and multiprocessing. I learned a lot.
It could also be distilled into a small framework for reuse on other tasks; a rough sketch of what that might look like follows.
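If I were to extract that framework, it might look roughly like this (a sketch only; the name run_chunked and its parameters are my own, and the checkpoint/aggregation logic is left to the caller via on_chunk_done):
import multiprocessing as mp
import signal

def _worker_init():
    # Let the main process own Ctrl+C handling, as in the scripts above.
    signal.signal(signal.SIGINT, signal.SIG_IGN)

def run_chunked(items, worker_fn, num_processes=8, chunk_size=5000, on_chunk_done=None):
    """Process items in chunks with a process pool; after each chunk, call
    on_chunk_done(chunk_index, results) so the caller can aggregate counters
    and write a checkpoint, much as main() does above."""
    chunks = [items[i:i + chunk_size] for i in range(0, len(items), chunk_size)]
    with mp.Pool(num_processes, initializer=_worker_init) as pool:
        for idx, chunk in enumerate(chunks):
            results = list(pool.imap_unordered(worker_fn, chunk))  # worker_fn must be a top-level function
            if on_chunk_done:
                on_chunk_done(idx, results)
# Usage sketch:
# run_chunked(adrns, process_func, num_processes=28,
#             on_chunk_done=lambda i, results: save_progress(job_id, i + 1, ...))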