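This project fine-tunes Google's PaliGemma2 multimodal large language model with LoRA for a structured visual question answering (VQA) task in the xx scenario.
Project Background
Task Description
In the xx scenario, accurately identifying and describing relevant events matters for xx quality assessment, teaching and training, and safety monitoring. This project aims to build a multimodal AI system that understands xx images and answers structured questions about them.
Technical Challenges
- Multimodal understanding: visual information and text questions must be processed together
- Structured output: the model must produce answers in a specific format
- Domain expertise: the xx scenario is highly specialized and complex
- Real-time requirements: practical applications need fast responses
Solution
- Use PaliGemma2 as the base model for its strong multimodal understanding
- Fine-tune efficiently with LoRA to reduce compute requirements
- Build a specialized xx-scenario dataset for targeted training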
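Project Structure
PaliGemma2/
├── paligemma_finetune.py           # Fine-tuning (training) script
├── paligemma_infer.py              # Single-image inference script
├── paligemma_video_infer.py        # Video inference script
├── label_gui.py                    # Data annotation tool
├── paligemma_xxx_structured_vqa/   # Training output directory
│   ├── checkpoint-2400/            # Training checkpoint
│   └── runs/                       # TensorBoard logs
└── README.md                       # Project documentation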
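Dataset Construction
Data Format
The dataset is stored in JSONL format, one record per line. Each record contains an image path, a question, and the expected answer: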
{
  "image": "images/merge_000063.jpg",
  "question": "answer en Is any of the following xxx events occurring in the current image? Please indicate the phase for each event. You can choose from the following event types: A, B, C, D, E. Phase types: 1, 2, 3, 4.",
  "answer": "Based on the image, I can see A in the 2 phase, B in the 3 phase, and no other events are visible."
}
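Before training, it can help to check that every line of the label file has the three fields shown above. The following is a minimal sketch, not part of the project scripts; the helper names `parse_record` and `load_records` are hypothetical.

```python
import json

# Field names taken from the sample record above.
REQUIRED_KEYS = ("image", "question", "answer")


def parse_record(line):
    """Parse one JSONL line and check that the required string fields are present."""
    record = json.loads(line)
    if not isinstance(record, dict):
        raise ValueError("each JSONL line must be a JSON object")
    for key in REQUIRED_KEYS:
        value = record.get(key)
        if not isinstance(value, str) or not value:
            raise ValueError(f"missing or empty field: {key!r}")
    return record


def load_records(path):
    """Read a JSONL file into a list of validated records, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [parse_record(line) for line in f if line.strip()]
```

Calling `load_records` on the label file raises a `ValueError` at the first malformed record, so problems surface before training starts.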
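Annotation Tool
Use label_gui.py to annotate the data:
- Graphical interface with batch annotation support
- Preset question templates to keep annotations consistent
- Support for multiple answer formats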
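A preset template could be implemented roughly as below. This is a sketch under assumptions: it reproduces the question and answer wording of the sample record in this README, but `build_question` and `format_answer` are hypothetical helpers, not functions confirmed to exist in label_gui.py.

```python
EVENT_TYPES = ["A", "B", "C", "D", "E"]  # placeholder labels used in this README
PHASES = ["1", "2", "3", "4"]


def build_question(event_types, phases):
    """Fill the fixed question template with the allowed event and phase labels."""
    return (
        "answer en Is any of the following xxx events occurring in the current image? "
        "Please indicate the phase for each event. "
        f"You can choose from the following event types: {', '.join(event_types)}. "
        f"Phase types: {', '.join(phases)}."
    )


def format_answer(events):
    """Render (event, phase) pairs in the sentence style of the sample answer."""
    if not events:
        # The README does not show the wording for an image without events,
        # so this sketch refuses to guess one.
        raise ValueError("no-event wording is not specified")
    parts = ", ".join(f"{event} in the {phase} phase" for event, phase in events)
    return f"Based on the image, I can see {parts}, and no other events are visible."
```

Generating both strings from code rather than typing them keeps every record in the same format, which is what lets a fine-tuned model learn a stable output pattern.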
Dataset Characteristics
- Surgical event types: A, B, C, D, E
- Event phases: 1, 2, 3, 4
- Structured output: the model must produce answers in a specific format
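Model Training
Training Configuration
# Main configuration parameters
USE_LORA = True         # fine-tune with LoRA
USE_QLORA = False       # whether to use QLoRA (4-bit quantization)
FREEZE_VISION = False   # whether to freeze the vision encoder
MODEL_ID = "google/paligemma2-3b-pt-448"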
LoRA Configuration
lora_config = LoraConfig(
    r=8,  # LoRA rank
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj",
                    "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)
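To see why a rank of 8 keeps the trainable parameter count small: for a linear layer of shape d_out × d_in, LoRA trains two low-rank matrices, A (r × d_in) and B (d_out × r), which gives r × (d_in + d_out) parameters instead of d_in × d_out. The sketch below works through that arithmetic. The layer dimensions are hypothetical round numbers chosen for illustration, not PaliGemma2's actual sizes; the real figure for this project is whatever `model.print_trainable_parameters()` reports.

```python
def lora_params(d_in, d_out, r=8):
    """Trainable parameters LoRA adds to one linear layer: A is r x d_in, B is d_out x r."""
    return r * (d_in + d_out)


def full_params(d_in, d_out):
    """Parameters in the frozen weight matrix of the same layer (bias ignored)."""
    return d_in * d_out


# Hypothetical layer shapes, for illustration only.
example_layers = {
    "q_proj": (1024, 1024),
    "up_proj": (1024, 4096),
}

for name, (d_in, d_out) in example_layers.items():
    ratio = lora_params(d_in, d_out) / full_params(d_in, d_out)
    print(f"{name}: {lora_params(d_in, d_out)} LoRA params, {ratio:.2%} of the full matrix")
```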
Training Parameters
- Learning rate: 2e-5
- Batch size: 4 (per device)
- Gradient accumulation steps: 2
- Training epochs: 10
- Optimizer: AdamW
- Mixed precision: BF16
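The per-device batch size and gradient accumulation combine into the effective batch size seen by each optimizer step. The sketch below shows the arithmetic, assuming a single device; the sample count of 1000 is a hypothetical figure, and the step estimate is approximate because the exact count depends on how the Trainer version handles a final partial accumulation window.

```python
import math


def effective_batch_size(per_device, grad_accum, num_devices=1):
    """Samples contributing to one optimizer update."""
    return per_device * grad_accum * num_devices


def approx_optimizer_steps(num_samples, per_device, grad_accum, epochs, num_devices=1):
    """Rough number of optimizer updates over the whole run."""
    batches_per_epoch = math.ceil(num_samples / (per_device * num_devices))
    steps_per_epoch = math.ceil(batches_per_epoch / grad_accum)
    return steps_per_epoch * epochs


# Settings from the list above: batch size 4, accumulation 2, 10 epochs.
print(effective_batch_size(4, 2))              # 8
# 1000 training samples is a hypothetical dataset size.
print(approx_optimizer_steps(1000, 4, 2, 10))  # 1250
```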
Start Training
python paligemma_finetune.py
Training Script
import os
from PIL import Image
import torch
from datasets import load_dataset
from transformers import (
    PaliGemmaProcessor,
    PaliGemmaForConditionalGeneration,
    Trainer,
    TrainingArguments,
    BitsAndBytesConfig,
)
from peft import get_peft_model, LoraConfig

# ==== Settings ====
USE_LORA = True
USE_QLORA = False  # set True to load the base model in 4-bit (QLoRA)
FREEZE_VISION = False
DATA_PATH = ""  # dataset root that holds labels.jsonl and the image paths it references
MODEL_ID = "google/paligemma2-3b-pt-448"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ==== Load the processor ====
processor = PaliGemmaProcessor.from_pretrained(MODEL_ID)

# ==== Load the data ====
dataset = load_dataset("json", data_files=f"{DATA_PATH}/labels.jsonl", split="train")
dataset = dataset.train_test_split(test_size=0.2)  # 80% train, 20% eval
train_ds = dataset["train"]
eval_ds = dataset["test"]

# ==== Load the model ====
if USE_LORA or USE_QLORA:
    lora_config = LoraConfig(
        r=8,
        target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
        task_type="CAUSAL_LM",
    )
    if USE_QLORA:
        # QLoRA: load the base weights in 4-bit and compute in bfloat16
        bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
        model = PaliGemmaForConditionalGeneration.from_pretrained(
            MODEL_ID, device_map="auto", quantization_config=bnb_config
        )
    else:
        model = PaliGemmaForConditionalGeneration.from_pretrained(MODEL_ID, device_map="auto")
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
else:
    # Full fine-tuning, optionally with the vision tower and projector frozen
    model = PaliGemmaForConditionalGeneration.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16).to(DEVICE)
    if FREEZE_VISION:
        for param in model.vision_tower.parameters():
            param.requires_grad = False
        for param in model.multi_modal_projector.parameters():
            param.requires_grad = False
DTYPE = model.dtype
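
# ==== Collate function ====
def collate_fn(examples):
    # Prepend the image token and task prefix. The sample record's question already
    # starts with "answer en ", so drop the prefix here if your questions include it.
    texts = ["<image>answer en " + example["question"] for example in examples]
    labels = [example["answer"] for example in examples]
    images = [Image.open(os.path.join(DATA_PATH, example["image"])).convert("RGB") for example in examples]
    # Passing the answers as `suffix` makes the processor build the training labels
    tokens = processor(text=texts, images=images, suffix=labels,
                       return_tensors="pt", padding="longest")
    tokens = tokens.to(DTYPE).to(DEVICE)
    return tokens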

# ==== Training arguments ====
args = TrainingArguments(
    output_dir="paligemma_surgical_structured_vqa",
    num_train_epochs=10,
    remove_unused_columns=False,  # keep the raw columns that collate_fn reads
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    warmup_steps=2,
    learning_rate=2e-5,
    weight_decay=1e-6,
    adam_beta2=0.999,
    logging_steps=100,
    optim="adamw_torch",  # you can use paged optimizers like paged_adamw_8bit for QLoRA
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=1,
    bf16=True,
    report_to=["tensorboard"],
    dataloader_pin_memory=False,
)

# ==== Trainer ====
trainer = Trainer(
    model=model,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    data_collator=collate_fn,
    args=args,
)

# ==== Start training ====
trainer.train()
Model Inference
Single-Image Inference
python paligemma_infer.py \
--image_path /path/to/image.jpg \
--lora_path paligemma_xxx_structured_vqa/checkpoint-2400
Inference Script Features
- Loads the LoRA weights automatically
- Supports custom prompts
- Batch processing
- Result visualization
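The README does not list the contents of paligemma_infer.py, so the following is only a sketch of how loading a LoRA checkpoint and answering one question could look. The `--image_path` and `--lora_path` flags come from the command above; the `--prompt` flag, the default question, and all function names are assumptions made for illustration.

```python
import argparse

MODEL_ID = "google/paligemma2-3b-pt-448"
# Assumed default: the question body from the sample record, without the task prefix.
DEFAULT_QUESTION = (
    "Is any of the following xxx events occurring in the current image? "
    "Please indicate the phase for each event. "
    "You can choose from the following event types: A, B, C, D, E. "
    "Phase types: 1, 2, 3, 4."
)


def build_prompt(question):
    """Mirror the training collate function: image token, task prefix, question."""
    return "<image>answer en " + question


def parse_args(argv=None):
    parser = argparse.ArgumentParser(description="Single-image inference sketch")
    parser.add_argument("--image_path", required=True)
    parser.add_argument("--lora_path", required=True)
    parser.add_argument("--prompt", default=DEFAULT_QUESTION)  # hypothetical flag
    return parser.parse_args(argv)


def run_inference(image_path, lora_path, question, max_new_tokens=64):
    # Heavy imports stay local so the helpers above work without the GPU stack.
    import torch
    from PIL import Image
    from peft import PeftModel
    from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor

    processor = PaliGemmaProcessor.from_pretrained(MODEL_ID)
    base = PaliGemmaForConditionalGeneration.from_pretrained(
        MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto"
    )
    # Attach the fine-tuned LoRA adapter to the frozen base model.
    model = PeftModel.from_pretrained(base, lora_path).eval()

    image = Image.open(image_path).convert("RGB")
    inputs = processor(text=build_prompt(question), images=image, return_tensors="pt")
    inputs = inputs.to(torch.bfloat16).to(base.device)
    prompt_len = inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        output = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    # Decode only the newly generated tokens, not the echoed prompt.
    return processor.decode(output[0][prompt_len:], skip_special_tokens=True)
```

A caller would combine the pieces as `args = parse_args()` followed by `print(run_inference(args.image_path, args.lora_path, args.prompt))`. Keeping the prompt format identical to the one used in training matters, since the adapter was tuned on that exact prefix.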
Output Example
>> Model Response:
Based on the image, I can see A in the 2 phase,
c in the 3 phase, and no other events are visible.
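Because the response follows a fixed sentence pattern, it can be converted back into structured data for evaluation or downstream use. The sketch below assumes the "<event> in the <phase> phase" wording shown above; `parse_events` is a hypothetical helper, and it upper-cases event letters because the model may emit either case.

```python
import re

# One event letter (A-E, either case) followed by its phase number (1-4).
EVENT_PATTERN = re.compile(r"\b([A-Ea-e]) in the ([1-4]) phase")


def parse_events(response):
    """Map each reported event type (upper-cased) to its phase number."""
    return {event.upper(): int(phase) for event, phase in EVENT_PATTERN.findall(response)}


response = (
    "Based on the image, I can see A in the 2 phase, "
    "c in the 3 phase, and no other events are visible."
)
print(parse_events(response))  # {'A': 2, 'C': 3}
```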