Description
Hi everyone,
I have the following scenario.
I have a machine with 2 GPUs and a running service that keeps two pipelines loaded, each on its corresponding device. I also have a list of LoRAs (say 10). On each request I split the batch into 2 parts (the request also carries the information about which LoRAs to use), load the LoRAs, and run the forward pass.
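For context, the setup looks roughly like this (the model id and dtype here are placeholders, not my actual config):

```python
import torch
from diffusers import DiffusionPipeline

devices = ["cuda:0", "cuda:1"]
pipes = [
    DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",  # placeholder model
        torch_dtype=torch.float16,
    ).to(device)
    for device in devices
]
```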
The problem I encounter is that, whatever parallelization method I have tried (threading, multiprocessing), the best I have achieved is pre-loading the LoRA state dicts on the CPU, then moving them to the GPU, and only after that calling `load_lora_weights` from the state_dict.
Whenever I try to parallelize the chunk where the loading happens by running it in threads, the pipe starts to produce either complete noise or a black image.
Where I would really appreciate help is:
- Advice on elegantly loading multiple LoRAs into one pipe at once (all the examples in the documentation indicate that one needs to load them one by one)
- If I have 2 pipes on 2 different devices, how to parallelize loading a LoRA into each pipe on its corresponding device (see the sketch right after this list for what I have in mind)
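Conceptually, since the two pipes are fully independent, I would expect one worker thread per device to be safe. This is a minimal sketch of that idea, reusing `pipes`/`devices` from the setup above; `apply_loras_on_device` is a hypothetical helper I would like to write, not an existing API:

```python
from concurrent.futures import ThreadPoolExecutor

def apply_loras_on_device(pipe, device, cpu_state_dicts):
    # Hypothetical per-device helper: all LoRA work for one pipe
    # happens in a single thread dedicated to that device.
    for adapter_name, state_dict in cpu_state_dicts.items():
        gpu_state_dict = {k: v.to(device) for k, v in state_dict.items()}
        pipe.load_lora_weights(gpu_state_dict, adapter_name=adapter_name)

# adapter name -> CPU state dict, pre-loaded at service start
cpu_state_dicts = {}  # placeholder

with ThreadPoolExecutor(max_workers=len(pipes)) as executor:
    futures = [
        executor.submit(apply_loras_on_device, pipe, device, cpu_state_dicts)
        for pipe, device in zip(pipes, devices)
    ]
    for future in futures:
        future.result()  # surface exceptions from worker threads
```

My current full implementation is below. It works, but the per-device loop runs strictly sequentially: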
```python
import logging
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

import torch

logger = logging.getLogger(__name__)


def apply_multiple_loras_from_cache(pipes, adapter_names, lora_cache, lora_names, lora_strengths, devices):
    for device_index, pipe in enumerate(pipes):
        logger.info(f"Starting setup for device {devices[device_index]}")

        # Step 1: Unload any previously applied LoRAs
        start = time.time()
        pipe.unload_lora_weights(reset_to_overwritten_params=False)
        logger.info(f"[Device {device_index}] Unload time: {time.time() - start:.3f}s")

        # Step 2: Parallelize the CPU -> GPU state_dict move
        def move_to_device(name):
            return name, {
                k: v.to(devices[device_index], non_blocking=True).to(pipe.dtype)
                for k, v in lora_cache[name]['state_dict'].items()
            }

        start = time.time()
        with ThreadPoolExecutor() as executor:
            future_to_name = {executor.submit(move_to_device, name): name for name in adapter_names}
            results = [future.result() for future in as_completed(future_to_name)]
        logger.info(f"[Device {device_index}] State dict move + dtype conversion time: {time.time() - start:.3f}s")

        # Step 3: Load the adapters one by one from the already-on-GPU state dicts
        start = time.time()
        for adapter_name, state_dict in results:
            pipe.load_lora_weights(
                pretrained_model_name_or_path_or_dict=state_dict,
                adapter_name=adapter_name,
            )
        logger.info(f"[Device {device_index}] Load adapter weights time: {time.time() - start:.3f}s")

        # Step 4: Activate the adapters with their strengths
        start = time.time()
        pipe.set_adapters(lora_names, adapter_weights=lora_strengths)
        logger.info(f"[Device {device_index}] Set adapter weights time: {time.time() - start:.3f}s")

    torch.cuda.empty_cache()
    logger.info("All LoRAs applied and GPU cache cleared.")
```
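For reference, this is roughly how I call it. The cache layout matches what the function expects (adapter name -> `{"state_dict": ...}`), but the file paths, adapter names, and strengths here are illustrative, not my real values:

```python
# Illustrative call; the real cache is populated once at service start
lora_cache = {
    "style_a": {"state_dict": torch.load("loras/style_a.pt", map_location="cpu")},
    "style_b": {"state_dict": torch.load("loras/style_b.pt", map_location="cpu")},
}

apply_multiple_loras_from_cache(
    pipes=pipes,
    adapter_names=["style_a", "style_b"],
    lora_cache=lora_cache,
    lora_names=["style_a", "style_b"],
    lora_strengths=[0.8, 0.5],
    devices=["cuda:0", "cuda:1"],
)
```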