大模型微调-损失收敛&动态调整 Epochs(THS)

一、数据集准备

1.1 数据集介绍

本次微调采用 Hugging Face 平台上的 fahmiaziz/text2sql-dataset (https://huggingface.co/datasets/fahmiaziz/text2sql-dataset)数据集。该数据集包含 126,642 条样本,每条样本包含三列:

  • prompt: 用户用自然语言描述的查询需求。
  • context: 查询目标数据库的表结构语句(DDL)。
  • answer: 对应的标准 SQL 查询语句。

1.2 数据集多样性分析

为评估数据集任务覆盖范围,分析统计了以下五类 SQL 任务的出现频率:

  • 聚合统计 (Aggregation): 查询包含 SUMAVGCOUNTMINMAX 等聚合函数。
  • 多表关联 (Multi-table)FROM 子句后包含多个表名(通过识别表名数量 >= 2 判断)。
  • 子查询 (Subquery): 查询包含子查询(通过检测 FROM (SELECT ...) 或 WITH 子句判断)。
  • 时间范围筛选 (Time Filter): 查询包含 DATEMONTHYEARBETWEEN 等与时间筛选相关的关键词。
  • 唯一值提取 (Distinct): 查询包含 SELECT DISTINCT ... 。

分析方法: 编写 Python 脚本(关键逻辑见下方代码框)遍历数据集中的每条 answer (SQL 语句),利用正则表达式识别上述特征并计数。

import pandas as pd

import re

import matplotlib.pyplot as plt

# 读取数据

df = pd.read_parquet('text2sql-dataset/data/train-00000-of-00001.parquet', engine='pyarrow')

# 定义检测函数

def detect_sql_features(sql_query):

    sql_query = sql_query.lower()

    # 聚合统计:SUM, AVG, COUNT, MIN, MAX

    aggregation = bool(re.search(r'\b(sum|avg|count|min|max)\b', sql_query))

    # 多表关联:通过识别 FROM 后面出现多个表名

    multi_table = False

    from_match = re.search(r'from\s+([a-zA-Z0-9_,\s]+)', sql_query)

    if from_match:

        tables_str = from_match.group(1)

        tables = re.split(r'[,\s]+', tables_str.strip())

        tables = [t for in tables if and not t.startswith('on'and not t.isdigit()]

        if len(set(tables)) >= 2:

            multi_table = True

    # 子查询:SELECT 在括号内或 WITH

    subquery = bool(re.search(r'from\s+$$select|with', sql_query, re.IGNORECASE))

    # 时间筛选:DATE、MONTH、YEAR、BETWEEN 等关键词

    time_filter = bool(re.search(r'\b(date|month|year|between|where.*\d{4}-\d{2}-\d{2})\b', sql_query))

    # 唯一值提取:DISTINCT 或 UNIQUE

    distinct = bool(re.search(r'\b(distinct|unique)\b', sql_query))

    return {

        'aggregation': aggregation,

        'multi_table': multi_table,

        'subquery': subquery,

        'time_filter': time_filter,

        'distinct': distinct

    }

# 初始化计数器

task_counter = {'aggregation'0'multi_table'0'subquery'0'time_filter'0'distinct'0}

# 遍历每一行的 SQL 查询并更新计数器

for index, row in df.iterrows():

    sql_query = row['answer']  # 使用 answer 字段中的 SQL 查询

    features = detect_sql_features(sql_query)

    for key in task_counter:

        if features[key]:

            task_counter[key] += 1

# 可视化

task_names = ['Aggregation''Multi-table''Subquery''Time Filter''Distinct']

counts = [

    task_counter['aggregation'],

    task_counter['multi_table'],

    task_counter['subquery'],

    task_counter['time_filter'],

    task_counter['distinct']

]

plt.figure(figsize=(106))

bars = plt.bar(task_names, counts, color='skyblue')

plt.title('Frequency of SQL Task Types in Answer Data', fontsize=16)

plt.xlabel('Task Type', fontsize=14)

plt.ylabel('Frequency', fontsize=14)

plt.grid(axis='y', linestyle='--', alpha=0.7)

# 显示数值标签

for bar in bars:

    yval = bar.get_height()

    plt.text(bar.get_x() + bar.get_width() / 2, yval + 0.5int(yval), ha='center', va='bottom')

plt.tight_layout()

plt.show()

分析结果:

  • 聚合统计 (Aggregation): 68,533 条
  • 多表关联 (Multi-table): 119,092 条
  • 子查询 (Subquery): 370 条
  • 时间范围筛选 (Time Filter): 22,953 条
  • 唯一值提取 (Distinct): 3,997 条

结论: 数据集在聚合、多表关联和时间筛选任务上分布广泛。子查询样本量显著偏少,推测可能由于多数子查询逻辑可通过多表连接实现转换。总体而言,该数据集具备任务多样性特征,适合用于 Text2SQL 模型微调。

1.3 数据预处理(适配百炼平台格式)

百炼平台微调要求输入数据为两列 Excel 文件:

  1. Prompt: 输入提示词。
  2. Completion: 期望的模型输出(目标 SQL)。

原始数据集包含 promptcontextanswer 三列,需进行转换:

Step1:提示词拼接: 将 prompt (用户问题) 和 context (表结构) 按以下模板组合成完整的 Prompt

"Based on the following table structure:

{context}

Please convert the following natural language question into an SQL query:

{prompt}

"

Step2:列重命名与导出: 保留拼接后的 Prompt 和原始 answer 列,将 answer 重命名为 Completion,导出为 Excel 文件。

完整代码如下:

import pandas as pd

df = pd.read_parquet('text2sql-dataset/data/train-00000-of-00001.parquet', engine='pyarrow')

# 定义拼接模板

def build_prompt_template(prompt, context):

    return f"""Based on the following table structure:

{context}

Please convert the following natural language question into an SQL query:

{prompt}

"""

# 拼接 Prompt 字段

df['Prompt'= df.apply(lambda row: build_prompt_template(row['prompt'], row['context']), axis=1)

# 准备导出的 DataFrame

export_df = df[['Prompt''answer']].rename(columns={'answer''Completion'})

# 导出到 Excel

output_path = 'text2sql-dataset/data/train_text2sql_prompt_with_context_en.xlsx'

export_df.to_excel(output_path, index=False)

print(f"Data has been successfully exported to {output_path}")

预处理后的数据格式符合百炼平台要求。

1.4 导入数据集

将生成的 train_text2sql_prompt_with_context_en.xlsx 文件导入百炼平台数据集管理模块,并完成发布操作,使其可用于模型训练。

二、模型微调任务配置

2.1 模型选择

  • 基础模型: Qwen2.5-7B
  • 训练方式: 高效训练

2.2 建立验证集

采用百炼平台提供的自动切分功能,由平台在训练时自动从导入的数据集中划分出一部分作为验证集。

2.3 参数调整

  • 关键调整: 将默认的训练轮数 (Epochs) 由 3 调整为 1

调整依据: 在首次尝试训练(Epochs=3)过程中,观察到训练损失 (Training Loss) 在早期即迅速收敛。为避免后续轮次带来的不必要计算开销和成本,手动终止了该次训练任务,并在新任务中减少了 Epochs。

三、训练过程与结果

3.1 训练资源消耗

  • 训练时长: 15 小时 02 分 55 秒
  • Token 消耗: 21,292,839 Token
  • 费用估算: 按平台单价 0.006 元/千 Token 计算,费用约为 127.76 元 (21,292,839 / 1000 * 0.006 ≈ 127.76)。

3.2 训练效果评估

  • 训练损失 (Training Loss): 如损失曲线图所示,损失值在训练早期即快速下降并趋于稳定,最终收敛至 0.09左右。这表明模型较快地学习到了数据中的模式。
    • 注:平台未提供手动停止训练功能,导致训练持续完成整个设定的轮数(1 Epoch),尽管损失已提前收敛。

    

  • 验证损失 (Validation Loss): 损失曲线显示验证损失持续稳定下降,最终达到 0.093 左右,表明模型在未见数据上具备良好的泛化能力
  • 验证集准确率 (Validation Accuracy): 达到 0.96,表明模型在验证集上生成正确 SQL 的比例很高。

       

四、心得体会

  1. 监控损失收敛: 在模型微调过程中,密切监控训练损失曲线的收敛情况至关重要。本次训练在早期轮次即表现出显著收敛迹象。
  2. 动态调整 Epochs: 基于实时的损失收敛趋势,适时调整训练轮数 (Epochs) 是优化成本效益的关键策略。本次通过将 Epochs 从默认值 3 降至 1,有效避免了资源浪费(尽管平台未支持提前停止导致首个任务部分轮次浪费)。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值