yyhhyy's blog

yyhhyy

DB-GPT Text2SQL微调方法

656
2024-03-07

DB-GPT Text2SQL 微调实践

1.环境准备

centos 7.9

CUDA=11.7

DB-GPT-HUB 版本为 Feb 19, 2024 这个commit,commitID为 c329ba11c1aae9ad2f59318864896a30e7132a7c

git clone https://github.com/eosphoros-ai/DB-GPT-Hub.git
cd DB-GPT-Hub
conda create -n dbgpt_hub python=3.10 
conda activate dbgpt_hub
# 优先安装cuda11.7对应版本的torch
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
pip install poetry

因CUDA版本相对较低问题,如果直接执行 poetry install 将会安装cuda12等包 会导致没法微调 因此这边需要进行处理。

# 使用以下命令进行生成requirements.txt 文件
poetry export --without-hashes -f requirements.txt | tee requirements.txt

之后采用之前的方法,注释或者删除一切有关cuda12或者torch相关的依赖行

pip install -r requirements.txt

至此环境准备结束

2.数据处理

2.1 数据获取

参考官方文档,以 Spider 数据集为示例 :

  • 简介:Spider 数据集是一个跨域的复杂 text2sql 数据集,包含了自然语言问句和分布在 200 个独立数据库中的多条 SQL,内容覆盖了 138 个不同的领域。

  • 下载:下载数据集,并移动到dbgpt_hub/data目录下 也就是dbgpt_hub/data/spider

2.2 划分数据

DB-GPT-HUB项目使用的是信息匹配生成法进行数据准备,即结合表信息的 SQL + Repository 生成方式,这种方式结合了数据表信息,能够更好地理解数据表的结构和关系,适用于生成符合需求的 SQL 语句。

运行一键脚本,即可在 dbgpt_hub/data/目录中将得到生成的各类数据集 example_text2sql_train.jsonexample_text2sql_dev.jsonexample_text2sql_dev_one_shot.jsonexample_text2sql_train_one_shot.json

## 生成train数据 和dev(eval)数据,
poetry run sh dbgpt_hub/scripts/gen_train_eval_data.sh

其中训练集中为 8659 条,评估集为 1034 条。生成的训练集中数据格式形如:

{
  "db_id": "department_management",
  "instruction": "I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n\"\n##Instruction:\ndepartment_management contains tables such as department, head, management. Table department has columns such as Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees. Department_ID is the primary key.\nTable head has columns such as head_ID, name, born_state, age. head_ID is the primary key.\nTable management has columns such as department_ID, head_ID, temporary_acting. department_ID is the primary key.\nThe head_ID of management is the foreign key of head_ID of head.\nThe department_ID of management is the foreign key of Department_ID of department.\n\n",
  "input": "###Input:\nHow many heads of the departments are older than 56 ?\n\n###Response:",
  "output": "SELECT count(*) FROM head WHERE age  >  56",
  "history": []
}

dbgpt_hub/data/dataset_info.json 中配置训练的数据文件,json文件中对应的 key 的值默认为 example_text2sql,此值即在后续训练脚本 train_sft 中参数 --dataset 需要传入的值, json中的file_name 的值为训练集的文件名字。

2.3 构造数据逻辑

数据处理的核心代码主要在 dbgpt_hub/data_process/sql_data_process.py 中,核心处理 class 是 ProcessSqlData(),核心处理函数是 decode_json_file()

decode_json_file()首先将 Spider 数据中的 table 信息处理成为字典格式,key 和 value 分别是 db_id 和该 db_id 对应的 table、columns 信息处理成所需的格式,例如:


{
  "department_management": department_management contains tables such as department, head, management. Table department has columns such as Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees. Department_ID is the primary key.\nTable head has columns such as head_ID, name, born_state, age. head_ID is the primary key.\nTable management has columns such as department_ID, head_ID, temporary_acting. department_ID is the primary key.\nThe head_ID of management is the foreign key of head_ID of head.\nThe department_ID of management is the foreign key of Department_ID of department.
}

然后将上述文本填充于 config 文件中 INSTRUCTION_PROMPT 的 {} 部分,形成最终的 instruction, INSTRUCTION_PROMPT 如下所示:

INSTRUCTION_PROMPT = "I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n ##Instruction:\n{}\n"

最后将训练集和验证集中每一个 db_id 对应的 question 和 query 处理成模型 SFT 训练所需的格式,即上面数据处理代码执行部分所示的数据格式。

3.模型微调

因暂时只训练小模型,只是用单卡就行,所以暂时只写单卡训练

poetry run sh dbgpt_hub/scripts/train_sft.sh

需要注意以下信息

# 确保是 dbgpt_hub/data/dataset_info.json 下的字典的头部key
dataset="example_text2sql_train"
model_name_or_path="填写微调模型的对应path"
# 输出权重文件路径
output_dir="dbgpt_hub/output/adapter/Qwen-7B-Chat-sql-lora"

微调参数

CUDA_VISIBLE_DEVICES=3 python dbgpt_hub/train/sft_train.py \
    # 既基座模型文件路径
    --model_name_or_path /DB-GPT/DB-GPT/models/Qwen-7B-Chat \
    --do_train \
    # 取值为训练数据集的配置名字,对应在dbgpt_hub/data/dataset_info.json 中外层key值,如example_text2sql。
    --dataset example_text2sql_train \
    #  输入模型的文本长度,如果计算资源支持,可以尽能设大,如1024或者2048
    --max_source_length 2048 \
    # 输出模型的sql内容长度,设置为512一般足够
    --max_target_length 512 \
    #  微调类型,取值为 [ ptuning、lora、freeze、full ] 等
    --finetuning_type lora \
    # LoRA 微调时的网络参数更改部分。
    --lora_target c_attn \
    # 项目设置的不同模型微调的 lora 部分,对于 Llama2 系列的模型均设置为 llama2。
    --template chatml \
    # LoRA 微调中的秩大小。
    --lora_rank 64 \
    # LoRA 微调中的缩放系数。
    --lora_alpha 32 \
    # SFT微调时Peft模块输出的路径,默认设置在dbgpt_hub/output/adapter/路径下。
    --output_dir dbgpt_hub/output/adapter/Qwen-7B-Chat-2048_epoch8_lora \
    --overwrite_cache \
    --overwrite_output_dir \
    # batch的大小,如果计算资源支持,可以设置为更大,默认为1。
    --per_device_train_batch_size 1 \
    # 梯度更新的累计steps值 save_steps : 模型保存的ckpt的steps大小值,默认可以设置为100。
    --gradient_accumulation_steps 16 \
    # 学习率类型。
    --lr_scheduler_type cosine_with_restarts \
    # 日志保存的 steps 间隔。
    --logging_steps 50 \
    # 模型保存的 ckpt 的 steps 大小值。
    --save_steps 2000 \
    # 学习率,推荐的学习率为 2e-4。
    --learning_rate 2e-4 \
    # 训练数据的epoch数
    --num_train_epochs 8 \
    --plot_loss \
    --bf16

train_sft.sh 中关键参数与含义介绍(来自DBGPT官方文档):

  • model_name_or_path :所用 LLM 模型的路径。

  • dataset :取值为训练数据集的配置名字,对应在 dbgpt_hub/data/dataset_info.json 中外层 key 值,如 example_text2sql。

  • max_source_length :输入模型的文本长度,本教程的效果参数为 2048,为多次实验与分析后的最佳长度。

  • max_target_length :输出模型的 sql 内容长度,设置为 512。

  • template:项目设置的不同模型微调的 lora 部分,对于 Llama2 系列的模型均设置为 llama2。

  • lora_target :LoRA 微调时的网络参数更改部分。

  • finetuning_type : 微调类型,取值为 [ ptuning、lora、freeze、full ] 等。

  • lora_rank : LoRA 微调中的秩大小。

  • loran_alpha: LoRA 微调中的缩放系数。

  • output_dir :SFT 微调时 Peft 模块输出的路径,默认设置在 dbgpt_hub/output/adapter/路径下 。

  • per_device_train_batch_size :每张 gpu 上训练样本的批次,如果计算资源支持,可以设置为更大,默认为 1。

  • gradient_accumulation_steps :梯度更新的累计steps值。

  • lr_scheduler_type :学习率类型。

  • logging_steps :日志保存的 steps 间隔。

  • save_steps :模型保存的 ckpt 的 steps 大小值。

  • num_train_epochs :训练数据的 epoch 数。

  • learning_rate : 学习率,推荐的学习率为 2e-4。

脚本中微调时不同模型对应的关键参数lora_target 和 template,如下表(来自DBGPT-Hub):

模型名

lora_target

template

LLaMA-2

q_proj,v_proj

llama2

CodeLlama-2

q_proj,v_proj

llama2

Baichuan2

W_pack

baichuan2

Qwen

c_attn

chatml

sqlcoder-7b

q_proj,v_proj

mistral

sqlcoder2-15b

c_attn

default

InternLM

q_proj,v_proj

intern

XVERSE

q_proj,v_proj

xverse

ChatGLM2

query_key_value

chatglm2

LLaMA

q_proj,v_proj

-

BLOOM

query_key_value

-

BLOOMZ

query_key_value

-

Baichuan

W_pack

baichuan

Falcon

query_key_value

-

4.模型预测

项目目录下./dbgpt_hub/下的output/pred/,此文件路径为关于模型预测结果默认输出的位置(如果没有则建上)。
预测运行命令:

poetry run sh ./dbgpt_hub/scripts/predict_sft.sh

需要注意以下几点:

CUDA_VISIBLE_DEVICES=0,1  python dbgpt_hub/predict/predict.py \
    # LLMs模型目录目录
    --model_name_or_path /DB-GPT/DB-GPT/models/Qwen-7B-Chat \
    # 这边需要与上表对应(也就是模型微调时填写的
    --template chatml \
    --finetuning_type lora \
    # 测试集目录
    --predicted_input_filename dbgpt_hub/data/example_text2sql_dev.json \
    # checkpoint文件路径
    --checkpoint_dir dbgpt_hub/output/adapter/Qwen-7B-Chat-2048_epoch8_lora \
    # 通过微调模型进行预测后的输出结果存放路径
    --predicted_out_filename dbgpt_hub/output/pred/pred_codellama13b.sql >> ${pred_log}

5.模型评估

对于模型在数据集上的效果评估,默认为在spider数据集上。 运行以下命令来:

# 最后的路径就是 通过微调模型进行预测后的输出结果存放路径 比如:dbgpt_hub/output/pred/pred_codellama13b.sql
poetry run python dbgpt_hub/eval/evaluation.py --plug_value --input  Your_model_pred_file

于大模型生成的结果具有一定的随机性,和 temperature 等参数密切相关(可以在 /dbgpt_hub/configs/model_args.py 中的 GeneratingArguments 中进行调整)

5.1 NLTK报错解决

错误信息(例子):

LookupError:

Resource punkt not found.
Please use the NLTK Downloader to obtain the resource:

import nltk
nltk.download('punkt')

For more information see: https://www.nltk.org/data.html

Attempted to load tokenizers/punkt/PY3/english.pickle

Searched in:
- '/root/nltk_data'
- '/home/anaconda/envs/dbgpt_hub/nltk_data'
- '/home/anaconda/envs/dbgpt_hub/share/nltk_data'
- '/home/anaconda/envs/dbgpt_hub/lib/nltk_data'
- '/usr/share/nltk_data'
- '/usr/local/share/nltk_data'
- '/usr/lib/nltk_data'
- '/usr/local/lib/nltk_data'
- ''`

解决方案:

使用本地可以使用访问到NLTK官网的环境进行下载对应依赖

先安装NLTK

pip install NLTK
import nltk
nltk.download('punkt')

然后再将下载后的punkt文件放到指定路径 一般是在conda创建的环境中,如何找到环境对应目录

conda info --envs
# 找到对应环境目录就行

/home/anaconda/envs/dbgpt_hub/ 为例, 需要在这下面继续创建 nltk_data/tokenizers 这两文件,然后将punkt放入到 tokenizers 下面即可

6.模型权重合并

将训练的基础模型和微调的Peft模块的权重合并,导出一个完整的模型。

poetry run sh ./dbgpt_hub/scripts/export_merge.sh

注意将脚本中的相关参数路径值替换为自己的项目所对应的路径。

# 以当前环境为例
python dbgpt_hub/train/export_model.py \
    # 原始模型路径
    --model_name_or_path /DB-GPT/DB-GPT/models/Qwen-7B-Chat \
    # 参考模型微调的表格
    --template chatml \
    --finetuning_type lora \
    # 模型权重路径
    --checkpoint_dir dbgpt_hub/output/adapter/Qwen-7B-Chat-2048_epoch8_lora \
    # 合并权重后的模型输出目录
    --output_dir dbgpt_hub/output/Qwen-7B-Chat-sql-sft \
    --fp16