LoRA Fine-Tuning of PaliGemma2 on a Custom VQA Dataset

  This project fine-tunes Google's PaliGemma2 multimodal vision-language model with LoRA for a structured visual question answering (VQA) task in an xx scenario.

Project Background

Task Description

  In xx scenarios, accurately recognizing and describing relevant events matters for xx quality assessment, teaching and training, and safety monitoring. This project aims to build a multimodal AI system that can understand xx images and answer structured questions about them.

Technical Challenges

  • Multimodal understanding: the model must process visual information and textual questions jointly
  • Structured output: the model must produce answers in a specific format
  • Domain expertise: xx scenarios are highly specialized and complex
  • Latency: real-world deployment requires fast responses

Solution

  • Use PaliGemma2 as the base model for its strong multimodal understanding
  • Fine-tune efficiently with LoRA to reduce compute requirements
  • Build a specialized xx-scenario dataset for targeted training

Project Structure

PaliGemma2/
├── paligemma_finetune.py      # model fine-tuning script
├── paligemma_infer.py         # image inference script
├── paligemma_video_infer.py   # video inference script
├── label_gui.py               # data labeling tool
├── paligemma_xxx_structured_vqa/  # training output directory
│   ├── checkpoint-2400/       # training checkpoint
│   └── runs/                  # TensorBoard logs
└── README.md                  # project documentation

Dataset Construction

Data Format

The dataset uses the JSONL format; each record contains:

{
    "image": "images/merge_000063.jpg",
    "question": "answer en Is any of the following xxx events occurring in the current image? Please indicate the phase for each event. You can choose from the following event types: A, B, C, D, E. Phase types: 1, 2, 3, 4.",
    "answer": "Based on the image, I can see A in the 2 phase, B in the 3 phase, and no other events are visible."
}
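
A minimal sketch of writing and validating such records with the standard library (the file name labels.jsonl matches the training script below; the truncated strings are placeholders):

import json

# Append one JSON object per line -- the JSONL convention this dataset uses
record = {
    "image": "images/merge_000063.jpg",
    "question": "answer en Is any of the following xxx events occurring in the current image? ...",
    "answer": "Based on the image, I can see A in the 2 phase, ...",
}
with open("labels.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(record, ensure_ascii=False) + "\n")

# Read the file back and check that each record has the expected keys
with open("labels.jsonl", encoding="utf-8") as f:
    for line in f:
        rec = json.loads(line)
        assert {"image", "question", "answer"} <= rec.keys()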

Labeling Tool

Data is annotated with label_gui.py:

  • Graphical interface with batch labeling support
  • Preset question templates to keep annotations consistent
  • Supports multiple answer formats (a sketch of the annotation flow follows this list)
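
label_gui.py itself is not reproduced here; the following is a hypothetical minimal sketch (Tkinter; the dataset root, window layout, and preset question are assumptions) of the annotate-and-append flow described above:

import glob
import json
import os
import tkinter as tk
from PIL import Image, ImageTk

DATA_PATH = "dataset"  # assumed dataset root containing images/ and labels.jsonl
QUESTION = "answer en Is any of the following xxx events occurring in the current image? ..."  # preset template

class LabelGUI:
    def __init__(self, root):
        self.root = root
        self.images = sorted(glob.glob(os.path.join(DATA_PATH, "images", "*.jpg")))
        self.idx = 0
        self.view = tk.Label(root)             # image display area
        self.view.pack()
        self.entry = tk.Entry(root, width=80)  # answer input field
        self.entry.pack()
        tk.Button(root, text="Save & Next", command=self.save).pack()
        self.show()

    def show(self):
        if self.idx >= len(self.images):
            self.root.quit()
            return
        img = Image.open(self.images[self.idx]).resize((448, 448))
        self.tk_img = ImageTk.PhotoImage(img)  # keep a reference so Tk doesn't drop it
        self.view.configure(image=self.tk_img)

    def save(self):
        # Append one JSONL record in the dataset format shown earlier
        rec = {
            "image": os.path.relpath(self.images[self.idx], DATA_PATH),
            "question": QUESTION,
            "answer": self.entry.get(),
        }
        with open(os.path.join(DATA_PATH, "labels.jsonl"), "a", encoding="utf-8") as f:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")
        self.idx += 1
        self.entry.delete(0, tk.END)
        self.show()

root = tk.Tk()
LabelGUI(root)
root.mainloop()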

Dataset Characteristics

  • Surgical event types: A, B, C, D, E
  • Event phases: 1, 2, 3, 4
  • Structured output: the model must answer in a fixed format (a parsing sketch follows this list)
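
Because the answers follow the fixed pattern "<event> in the <phase> phase", they can be parsed back into (event, phase) pairs. A minimal sketch using the placeholder vocabularies A-E and 1-4:

import re

answer = ("Based on the image, I can see A in the 2 phase, "
          "B in the 3 phase, and no other events are visible.")
# Extract (event, phase) pairs from the structured answer
events = re.findall(r"\b([A-E]) in the ([1-4]) phase\b", answer)
print(events)  # [('A', '2'), ('B', '3')]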

Model Training

Training Configuration

# Main configuration parameters
USE_LORA = True          # use LoRA fine-tuning
USE_QLORA = False        # whether to use QLoRA (4-bit quantization)
FREEZE_VISION = False    # whether to freeze the vision encoder
MODEL_ID = "google/paligemma2-3b-pt-448"

LoRA Configuration

lora_config = LoraConfig(
    r=8,  # LoRA rank
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj",
                    "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)

Training Parameters

  • Learning rate: 2e-5
  • Batch size: 4 (per device)
  • Gradient accumulation: 2 steps (effective batch size: 4 × 2 = 8 per device)
  • Epochs: 10
  • Optimizer: AdamW
  • Mixed precision: BF16

Start Training

python paligemma_finetune.py
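
Since the training arguments set report_to=["tensorboard"] (see the script below), progress can be monitored by pointing TensorBoard at the runs/ directory from the project structure:

tensorboard --logdir paligemma_xxx_structured_vqa/runs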

Training Script

import os
from PIL import Image
import torch
from datasets import load_dataset
from transformers import (
    PaliGemmaProcessor,
    PaliGemmaForConditionalGeneration,
    Trainer,
    TrainingArguments,
    BitsAndBytesConfig
)
from peft import get_peft_model, LoraConfig

# ==== Settings ====
USE_LORA = True
USE_QLORA = False
FREEZE_VISION = False
DATA_PATH = ""  # dataset root containing images/ and labels.jsonl
MODEL_ID = "google/paligemma2-3b-pt-448"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ==== Load processor ====
processor = PaliGemmaProcessor.from_pretrained(MODEL_ID)

# ==== Load data ====
dataset = load_dataset("json", data_files=f"{DATA_PATH}/labels.jsonl", split="train")
dataset = dataset.train_test_split(test_size=0.2)
train_ds = dataset["train"]
eval_ds = dataset["test"]

# ==== Load model ====
if USE_LORA or USE_QLORA:
    lora_config = LoraConfig(
        r=8,
        target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
        task_type="CAUSAL_LM",
    )
    
    if USE_QLORA:
        # QLoRA: load the base model in 4-bit with bitsandbytes
        bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
        model = PaliGemmaForConditionalGeneration.from_pretrained(
            MODEL_ID, device_map="auto", quantization_config=bnb_config
        )
    else:
        model = PaliGemmaForConditionalGeneration.from_pretrained(MODEL_ID, device_map="auto")
    
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
else:
    model = PaliGemmaForConditionalGeneration.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16).to(DEVICE)
    if FREEZE_VISION:
        for param in model.vision_tower.parameters():
            param.requires_grad = False
        for param in model.multi_modal_projector.parameters():
            param.requires_grad = False
DTYPE = model.dtype

# ==== Collate function ====
def collate_fn(examples):
    # Dataset questions already carry the "answer en" prefix (see the data
    # format above), so only the <image> token is prepended here.
    texts = ["<image>" + example["question"] for example in examples]
    labels = [example["answer"] for example in examples]
    images = [Image.open(os.path.join(DATA_PATH, example["image"])).convert("RGB") for example in examples]
    # Passing suffix= makes the processor build training labels, with the
    # prompt portion masked out of the loss.
    tokens = processor(text=texts, images=images, suffix=labels,
                       return_tensors="pt", padding="longest")
    tokens = tokens.to(DTYPE).to(DEVICE)  # casts only floating-point tensors to DTYPE
    return tokens



# ==== Training arguments ====
args = TrainingArguments(
    output_dir="paligemma_xxx_structured_vqa",
    num_train_epochs=10,
    remove_unused_columns=False,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    warmup_steps=2,
    learning_rate=2e-5,
    weight_decay=1e-6,
    adam_beta2=0.999,
    logging_steps=100,
    optim="adamw_torch", # you can use paged optimizers like paged_adamw_8bit for QLoRA
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=1,
    bf16=True,
    report_to=["tensorboard"],
    dataloader_pin_memory=False
)

# ==== Trainer ====
trainer = Trainer(
    model=model,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    data_collator=collate_fn,
    args=args
)

# ==== Start training ====
trainer.train()
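
After training, the LoRA adapter can also be saved explicitly for inference (a minimal sketch continuing the script above; the lora_adapter directory name is an assumption, and the Trainer already writes checkpoints such as checkpoint-2400 on its own):

# ==== Save the LoRA adapter ====
# For a PEFT-wrapped model, save_pretrained stores only the adapter weights,
# which can later be loaded on top of the base model for inference.
model.save_pretrained("paligemma_xxx_structured_vqa/lora_adapter")
processor.save_pretrained("paligemma_xxx_structured_vqa/lora_adapter")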

Model Inference

Single-Image Inference

python paligemma_infer.py \
    --image_path /path/to/image.jpg \
    --lora_path paligemma_xxx_structured_vqa/checkpoint-2400

Inference Script Features

  • Automatically loads LoRA weights
  • Supports custom prompts
  • Batch processing
  • Result visualization (a minimal load-and-generate sketch follows this list)
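
paligemma_infer.py is not reproduced here; the following is a minimal sketch of the load-and-generate flow it implies, assuming the checkpoint path from the command above:

import torch
from PIL import Image
from peft import PeftModel
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor

MODEL_ID = "google/paligemma2-3b-pt-448"
LORA_PATH = "paligemma_xxx_structured_vqa/checkpoint-2400"

processor = PaliGemmaProcessor.from_pretrained(MODEL_ID)
model = PaliGemmaForConditionalGeneration.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto"
)
model = PeftModel.from_pretrained(model, LORA_PATH)  # attach the LoRA adapter
model.eval()

image = Image.open("/path/to/image.jpg").convert("RGB")
question = ("answer en Is any of the following xxx events occurring in the current image? "
            "Please indicate the phase for each event. You can choose from the following "
            "event types: A, B, C, D, E. Phase types: 1, 2, 3, 4.")
prompt = "<image>" + question  # dataset questions already carry the "answer en" prefix
inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)

with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=128, do_sample=False)

# Decode only the newly generated tokens (skip the prompt)
response = processor.decode(output[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
print(">> Model Response:\n" + response)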

Example Output

>> Model Response:
Based on the image, I can see A in the 2 phase,
C in the 3 phase, and no other events are visible.