DB-GPT Text2SQL微调方法
编辑DB-GPT Text2SQL 微调实践
1.环境准备
centos 7.9
CUDA=11.7
DB-GPT-HUB 版本为 Feb 19, 2024
这个commit,commitID为 c329ba11c1aae9ad2f59318864896a30e7132a7c
git clone https://github.com/eosphoros-ai/DB-GPT-Hub.git
cd DB-GPT-Hub
conda create -n dbgpt_hub python=3.10
conda activate dbgpt_hub
# 优先安装cuda11.7对应版本的torch
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
pip install poetry
因CUDA版本相对较低问题,如果直接执行 poetry install 将会安装cuda12等包 会导致没法微调 因此这边需要进行处理。
# 使用以下命令进行生成requirements.txt 文件
poetry export --without-hashes -f requirements.txt | tee requirements.txt
之后采用之前的方法,注释或者删除一切有关cuda12或者torch相关的依赖行
pip install -r requirements.txt
至此环境准备结束
2.数据处理
2.1 数据获取
参考官方文档,以 Spider 数据集为示例 :
简介:Spider 数据集是一个跨域的复杂 text2sql 数据集,包含了自然语言问句和分布在 200 个独立数据库中的多条 SQL,内容覆盖了 138 个不同的领域。
下载:下载数据集,并移动到
dbgpt_hub/data
目录下 也就是dbgpt_hub/data/spider
。
2.2 划分数据
DB-GPT-HUB项目使用的是信息匹配生成法进行数据准备,即结合表信息的 SQL + Repository 生成方式,这种方式结合了数据表信息,能够更好地理解数据表的结构和关系,适用于生成符合需求的 SQL 语句。
运行一键脚本,即可在 dbgpt_hub/data/
目录中将得到生成的各类数据集 example_text2sql_train.json
、example_text2sql_dev.json
、 example_text2sql_dev_one_shot.json
、 example_text2sql_train_one_shot.json
## 生成train数据 和dev(eval)数据,
poetry run sh dbgpt_hub/scripts/gen_train_eval_data.sh
其中训练集中为 8659 条,评估集为 1034 条。生成的训练集中数据格式形如:
{
"db_id": "department_management",
"instruction": "I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n\"\n##Instruction:\ndepartment_management contains tables such as department, head, management. Table department has columns such as Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees. Department_ID is the primary key.\nTable head has columns such as head_ID, name, born_state, age. head_ID is the primary key.\nTable management has columns such as department_ID, head_ID, temporary_acting. department_ID is the primary key.\nThe head_ID of management is the foreign key of head_ID of head.\nThe department_ID of management is the foreign key of Department_ID of department.\n\n",
"input": "###Input:\nHow many heads of the departments are older than 56 ?\n\n###Response:",
"output": "SELECT count(*) FROM head WHERE age > 56",
"history": []
}
在 dbgpt_hub/data/dataset_info.json
中配置训练的数据文件,json文件中对应的 key 的值默认为 example_text2sql
,此值即在后续训练脚本 train_sft
中参数 --dataset
需要传入的值, json中的file_name 的值为训练集的文件名字。
2.3 构造数据逻辑
数据处理的核心代码主要在 dbgpt_hub/data_process/sql_data_process.py
中,核心处理 class 是 ProcessSqlData()
,核心处理函数是 decode_json_file()
。
decode_json_file()
首先将 Spider 数据中的 table 信息处理成为字典格式,key 和 value 分别是 db_id 和该 db_id 对应的 table、columns 信息处理成所需的格式,例如:
{
"department_management": department_management contains tables such as department, head, management. Table department has columns such as Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees. Department_ID is the primary key.\nTable head has columns such as head_ID, name, born_state, age. head_ID is the primary key.\nTable management has columns such as department_ID, head_ID, temporary_acting. department_ID is the primary key.\nThe head_ID of management is the foreign key of head_ID of head.\nThe department_ID of management is the foreign key of Department_ID of department.
}
然后将上述文本填充于 config 文件中 INSTRUCTION_PROMPT 的 {} 部分,形成最终的 instruction, INSTRUCTION_PROMPT 如下所示:
INSTRUCTION_PROMPT = "I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n ##Instruction:\n{}\n"
最后将训练集和验证集中每一个 db_id 对应的 question 和 query 处理成模型 SFT 训练所需的格式,即上面数据处理代码执行部分所示的数据格式。
3.模型微调
因暂时只训练小模型,只是用单卡就行,所以暂时只写单卡训练
poetry run sh dbgpt_hub/scripts/train_sft.sh
需要注意以下信息
# 确保是 dbgpt_hub/data/dataset_info.json 下的字典的头部key
dataset="example_text2sql_train"
model_name_or_path="填写微调模型的对应path"
# 输出权重文件路径
output_dir="dbgpt_hub/output/adapter/Qwen-7B-Chat-sql-lora"
微调参数
CUDA_VISIBLE_DEVICES=3 python dbgpt_hub/train/sft_train.py \
# 既基座模型文件路径
--model_name_or_path /DB-GPT/DB-GPT/models/Qwen-7B-Chat \
--do_train \
# 取值为训练数据集的配置名字,对应在dbgpt_hub/data/dataset_info.json 中外层key值,如example_text2sql。
--dataset example_text2sql_train \
# 输入模型的文本长度,如果计算资源支持,可以尽能设大,如1024或者2048
--max_source_length 2048 \
# 输出模型的sql内容长度,设置为512一般足够
--max_target_length 512 \
# 微调类型,取值为 [ ptuning、lora、freeze、full ] 等
--finetuning_type lora \
# LoRA 微调时的网络参数更改部分。
--lora_target c_attn \
# 项目设置的不同模型微调的 lora 部分,对于 Llama2 系列的模型均设置为 llama2。
--template chatml \
# LoRA 微调中的秩大小。
--lora_rank 64 \
# LoRA 微调中的缩放系数。
--lora_alpha 32 \
# SFT微调时Peft模块输出的路径,默认设置在dbgpt_hub/output/adapter/路径下。
--output_dir dbgpt_hub/output/adapter/Qwen-7B-Chat-2048_epoch8_lora \
--overwrite_cache \
--overwrite_output_dir \
# batch的大小,如果计算资源支持,可以设置为更大,默认为1。
--per_device_train_batch_size 1 \
# 梯度更新的累计steps值 save_steps : 模型保存的ckpt的steps大小值,默认可以设置为100。
--gradient_accumulation_steps 16 \
# 学习率类型。
--lr_scheduler_type cosine_with_restarts \
# 日志保存的 steps 间隔。
--logging_steps 50 \
# 模型保存的 ckpt 的 steps 大小值。
--save_steps 2000 \
# 学习率,推荐的学习率为 2e-4。
--learning_rate 2e-4 \
# 训练数据的epoch数
--num_train_epochs 8 \
--plot_loss \
--bf16
train_sft.sh 中关键参数与含义介绍(来自DBGPT官方文档):
model_name_or_path :所用 LLM 模型的路径。
dataset :取值为训练数据集的配置名字,对应在 dbgpt_hub/data/dataset_info.json 中外层 key 值,如 example_text2sql。
max_source_length :输入模型的文本长度,本教程的效果参数为 2048,为多次实验与分析后的最佳长度。
max_target_length :输出模型的 sql 内容长度,设置为 512。
template:项目设置的不同模型微调的 lora 部分,对于 Llama2 系列的模型均设置为 llama2。
lora_target :LoRA 微调时的网络参数更改部分。
finetuning_type : 微调类型,取值为 [ ptuning、lora、freeze、full ] 等。
lora_rank : LoRA 微调中的秩大小。
loran_alpha: LoRA 微调中的缩放系数。
output_dir :SFT 微调时 Peft 模块输出的路径,默认设置在 dbgpt_hub/output/adapter/路径下 。
per_device_train_batch_size :每张 gpu 上训练样本的批次,如果计算资源支持,可以设置为更大,默认为 1。
gradient_accumulation_steps :梯度更新的累计steps值。
lr_scheduler_type :学习率类型。
logging_steps :日志保存的 steps 间隔。
save_steps :模型保存的 ckpt 的 steps 大小值。
num_train_epochs :训练数据的 epoch 数。
learning_rate : 学习率,推荐的学习率为 2e-4。
脚本中微调时不同模型对应的关键参数lora_target 和 template,如下表(来自DBGPT-Hub):
4.模型预测
项目目录下./dbgpt_hub/
下的output/pred/
,此文件路径为关于模型预测结果默认输出的位置(如果没有则建上)。
预测运行命令:
poetry run sh ./dbgpt_hub/scripts/predict_sft.sh
需要注意以下几点:
CUDA_VISIBLE_DEVICES=0,1 python dbgpt_hub/predict/predict.py \
# LLMs模型目录目录
--model_name_or_path /DB-GPT/DB-GPT/models/Qwen-7B-Chat \
# 这边需要与上表对应(也就是模型微调时填写的
--template chatml \
--finetuning_type lora \
# 测试集目录
--predicted_input_filename dbgpt_hub/data/example_text2sql_dev.json \
# checkpoint文件路径
--checkpoint_dir dbgpt_hub/output/adapter/Qwen-7B-Chat-2048_epoch8_lora \
# 通过微调模型进行预测后的输出结果存放路径
--predicted_out_filename dbgpt_hub/output/pred/pred_codellama13b.sql >> ${pred_log}
5.模型评估
对于模型在数据集上的效果评估,默认为在spider
数据集上。 运行以下命令来:
# 最后的路径就是 通过微调模型进行预测后的输出结果存放路径 比如:dbgpt_hub/output/pred/pred_codellama13b.sql
poetry run python dbgpt_hub/eval/evaluation.py --plug_value --input Your_model_pred_file
于大模型生成的结果具有一定的随机性,和 temperature 等参数密切相关(可以在 /dbgpt_hub/configs/model_args.py
中的 GeneratingArguments 中进行调整)
5.1 NLTK报错解决
错误信息(例子):
LookupError:
Resource punkt not found.
Please use the NLTK Downloader to obtain the resource:
import nltk
nltk.download('punkt')
For more information see: https://www.nltk.org/data.html
Attempted to load tokenizers/punkt/PY3/english.pickle
Searched in:
- '/root/nltk_data'
- '/home/anaconda/envs/dbgpt_hub/nltk_data'
- '/home/anaconda/envs/dbgpt_hub/share/nltk_data'
- '/home/anaconda/envs/dbgpt_hub/lib/nltk_data'
- '/usr/share/nltk_data'
- '/usr/local/share/nltk_data'
- '/usr/lib/nltk_data'
- '/usr/local/lib/nltk_data'
- ''`
解决方案:
使用本地可以使用访问到NLTK官网的环境进行下载对应依赖
先安装NLTK
pip install NLTK
import nltk
nltk.download('punkt')
然后再将下载后的punkt文件放到指定路径 一般是在conda创建的环境中,如何找到环境对应目录
conda info --envs
# 找到对应环境目录就行
以 /home/anaconda/envs/dbgpt_hub/
为例, 需要在这下面继续创建 nltk_data/tokenizers
这两文件,然后将punkt放入到 tokenizers
下面即可
6.模型权重合并
将训练的基础模型和微调的Peft模块的权重合并,导出一个完整的模型。
poetry run sh ./dbgpt_hub/scripts/export_merge.sh
注意将脚本中的相关参数路径值替换为自己的项目所对应的路径。
# 以当前环境为例
python dbgpt_hub/train/export_model.py \
# 原始模型路径
--model_name_or_path /DB-GPT/DB-GPT/models/Qwen-7B-Chat \
# 参考模型微调的表格
--template chatml \
--finetuning_type lora \
# 模型权重路径
--checkpoint_dir dbgpt_hub/output/adapter/Qwen-7B-Chat-2048_epoch8_lora \
# 合并权重后的模型输出目录
--output_dir dbgpt_hub/output/Qwen-7B-Chat-sql-sft \
--fp16
- 0
- 0
-
分享