# 生成微调数据

# 准备工作

DISC-LawLLM 主要基于 DISC-Law-SFT-Pair 和 DISC-Law-SFT-Triplet 数据集进行训练,这两个数据集对应了法律专业人员助手、法律咨询服务、法律考试助手、法律专业人员助手和法律咨询服务的情景。

这次实践,使用 Self-Instruct 方法,利用 DISC-Law-SFT-Pair 和 DISC-Law-SFT-Triplet 数据集生成微调数据,以此微调 DISC-LawLLM。Self-Instruct 方法是怎么工作的?如下图:

seed tasks 是源数据,Self-Instruct 根据这些源数据生成微调数据。总共有四步:根据 seed tasks 生成指令,判断是否为分类问题,生成实例,微调准备工作。对应四个 bash 文件:

# 1. Generate instructions from the seed tasks
./scripts/generate_instructions.sh
# 2. Identify whether the instruction represents a classification task or not
./scripts/is_clf_or_not.sh
# 3. Generate instances for each instruction
./scripts/generate_instances.sh
# 4. Filtering, processing, and reformatting
./scripts/prepare_for_finetuning.sh

阅读 generate_instructions.sh 以及相关的代码,发现 Self-Instruct 生成数据的逻辑是借助大模型自身去生成数据。那么对于 DISC-Law-SFT-Pair 和 DISC-Law-SFT-Triplet 两个数据集,就是利用法律大模型 DISC-LawLLM 去生成新的法律数据。

观察 self-instruct 项目自带的 seed tasks 示例数据,以及 DISC-Law-SFT-Pair 和 DISC-Law-SFT-Triplet 两个数据集。数据集均是 jsonl 格式,seed tasks 示例数据的每一行都是类似如下的格式:

{
    "id": "seed_task_0", 
    "name": "breakfast_suggestion", 
    "instruction": "Is there anything I can eat for a breakfast that doesn't include eggs, yet includes protein, and has roughly 700-1000 calories?", 
    "instances": [
        {
        "input": "", 
        "output": "Yes, you can have 1 oatmeal banana protein shake and 4 strips of bacon. The oatmeal banana protein shake may contain 1/2 cup oatmeal, 60 grams whey protein powder, 1/2 medium banana, 1tbsp flaxseed oil and 1/2 cup watter, totalling about 550 calories. The 4 strips of bacon contains about 200 calories."
        }
    ], 
    "is_classification": false
}

DISC-Law-SFT-Pair 的数据每一行则是类似如下格式:

{
    "id": "jud_doc_sum-1", 
    "input": "请大致描述这篇文书的内容。\n\n唐鲜明与何伟华、深圳市华名威电汽车服务有限公司侵权责任纠纷一审民事判决书\n......\n没有证据或者证据不足以证明当事人的事实主张的,由负有举证责任的当事人承担不利后果。", 
    "output": "总结:原告与被告侵权责任纠纷一案。原告提出诉求:判决被告赔偿原告医疗费、住院伙食补助费、营养费、护理费、误工费、残疾赔偿金、精神损害抚慰金、后续治疗费、......根据《中华人民共和国民事诉讼法》第六十四条第一款、《最高人民法院关于民事诉讼证据的若干规定》第二条的规定,判决:驳回原告全部诉讼请求。"
}

DISC-Law-SFT-Triplet 的数据每一行则是类似如下格式:

{
    "id": "judgement_predit-1", 
    "reference": 
    [
        "《刑法》第一百一十四条:【放火罪】【决水罪】【爆炸罪】【投放危险物质罪】......尚未造成严重后果的,处三年以上十年以下有期徒刑。"
    ], 
    "input": "基于下列案件,推测可能的判决结果。\n经审理查明,2015年6月21日15时许,被告人白某某在大东区小河沿公交车站乘坐被害人张某某驾驶的133路公交车,......鉴定结论书,被告人白某某的供述与辩解等证据证实,足以认定。", 
    "output": "根据《刑法》第一百一十四条的规定,被告人白某某以危险方法危害公共安全,尚未造成严重后果。......\n需要注意的是,以上仅为根据提供的案件事实和相关法律条文进行的推理,最终的判决结果还需根据法庭审理的具体情况和法官的判断做出。"
}

因此,需要将法律数据整理成 seed tasks 的格式。观察数据,很明显,无论是 DISC-Law-SFT-Pair 还是 DISC-Law-SFT-Triplet,input 里都包含指令,也就是 seed tasks 里的 instruction,所以需要对 input 进行数据提取。

当数据量相当大的时候,毫无疑问不能人工处理,我们需要用一个部署好的中文大模型进行数据提取,并处理成 seed tasks 的格式。

# 数据提取与预处理

import json
import asyncio
import os
from openai import AsyncOpenAI, APIConnectionError
from tqdm.asyncio import tqdm
# ================= 配置区域 =================
INPUT_FILE = "data/law/DISC-Law-SFT-Triplet-released.jsonl"            # 原始文件路径
OUTPUT_FILE = "data/law/DISC-Law-SFT-Triplet-formatted_data.jsonl" # 结果文件路径
BASE_URL = "http://localhost:8000/v1" # 本地 vLLM 地址
API_KEY = "EMPTY"                     # 本地不需要 Key
MODEL_NAME = "qwen"            # vLLM 启动时的 served-model-name
MAX_CONCURRENCY = 100                 # 并发数
API_TIMEOUT = 60                      # 单次请求超时时间 (秒)
# ===========================================
# 初始化客户端,设置超时
client = AsyncOpenAI(api_key=API_KEY, base_url=BASE_URL, timeout=API_TIMEOUT)
SYSTEM_PROMPT = """
你是一个文本提取助手。
任务:从用户提供的文本中,精准提取出“指令”(Instruction)。
指令通常位于文本的【开头】或【结尾】,是类似“请总结文书”、“请描述内容”、“请根据案件中的事实,推理出可能的判决”的要求。
约束:
1. 只输出指令原文,不要加任何标点、前缀或解释。
2. 必须是从原文中摘录的连续片段,不要改写,以便我后续进行文本匹配删除。
"""
async def process_line(sem, line, index):
    """处理单行数据:提取指令 -> 删除指令 -> 格式化"""
    async with sem:
        line_id = index
        try:
            record = json.loads(line)
            original_input = record.get("input", "")
            original_output = record.get("output", "")
            # 使用原有的 id,如果没有则生成一个
            record_id = record.get("id", f"task_{line_id}")
            if not original_input:
                return None
            # --- 1. 预处理:截断文本 (优化速度) ---
            # 只取前 800 和后 800 字符,跳过中间无关内容
            if len(original_input) > 2000:
                prompt_input = original_input[:800] + "\n...[OMITTED]...\n" + original_input[-800:]
            else:
                prompt_input = original_input
            # --- 2. 调用 LLM ---
            response = await client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": prompt_input}
                ],
                temperature=0.0, # 确保输出稳定
                max_tokens=200
            )
            
            extracted_instruction = response.choices[0].message.content.strip()
            # --- 3. 后处理:在原文中删除提取的指令 ---
            clean_input = original_input
            
            # 尝试查找并删除(仅替换第一次出现)
            if extracted_instruction in original_input:
                clean_input = original_input.replace(extracted_instruction, "", 1).strip()
            
            # --- 4. 构造新格式 ---
            new_record = {
                "id": record_id,
                "name": "doc_processing",
                "instruction": extracted_instruction,
                "instances": [
                    {
                        "input": clean_input,
                        "output": original_output
                    }
                ],
                "is_classification": False
            }
            return json.dumps(new_record, ensure_ascii=False)
        except APIConnectionError:
            # 连接错误通常意味着 vLLM 没开或者挂了
            print(f"\n[Error] Line {line_id}: 无法连接到本地模型 API。请检查 vLLM 是否运行在端口 8000。")
            return None
        except Exception as e:
            # 其他错误(格式错误、超长等)
            # print (f"\n [Error] Line {line_id}: {e}") # 调试时可打开
            return None
        
async def main():
    # 1. 确保输出目录存在
    output_dir = os.path.dirname(OUTPUT_FILE)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print(f"已创建输出目录: {output_dir}")
    # 2. 断点续传逻辑
    processed_count = 0
    if os.path.exists(OUTPUT_FILE):
        with open(OUTPUT_FILE, 'r', encoding='utf-8') as f:
            try:
                processed_count = sum(1 for _ in f)
            except Exception:
                pass
        print(f"发现断点,已处理 {processed_count} 行,将从此处继续...")
    # 3. 读取输入文件
    if not os.path.exists(INPUT_FILE):
        print(f"错误:找不到输入文件 {INPUT_FILE}")
        return
    with open(INPUT_FILE, 'r', encoding='utf-8') as f:
        all_lines = f.readlines()
    
    lines_to_process = all_lines[processed_count:]
    if not lines_to_process:
        print("所有数据已处理完毕。")
        return
    # 4. 准备任务
    sem = asyncio.Semaphore(MAX_CONCURRENCY)
    tasks = []
    
    # 创建任务列表
    for i, line in enumerate(lines_to_process):
        real_index = i + processed_count
        tasks.append(process_line(sem, line, real_index))
    print(f"开始处理剩余 {len(tasks)} 行数据...")
    print(f"配置: 并发={MAX_CONCURRENCY}, 超时={API_TIMEOUT}秒")
    print("提示: 你可以随时按 Ctrl+C 中止,数据会自动保存。")
    # 5. 执行循环 (关键修改部分)
    # 使用 as_completed 确保完成一个写入一个
    
    completed_in_session = 0
    
    with open(OUTPUT_FILE, 'a', encoding='utf-8') as f_out:
        # 将 tasks 转换为一个即时返回的迭代器
        futures = asyncio.as_completed(tasks)
        
        try:
            # 使用 tqdm 包装 futures 迭代器来显示进度
            for future in tqdm(futures, total=len(tasks), unit="lines"):
                result = await future
                if result:
                    f_out.write(result + "\n")
                    f_out.flush() # 关键:每写一行立即刷入硬盘
                completed_in_session += 1
                
        except KeyboardInterrupt:
            print("\n" + "="*50)
            print(f"检测到 Ctrl+C 中断信号!")
            print(f"正在停止... (本轮已安全写入 {completed_in_session} 条数据)")
            print(f"你可以随时重新运行脚本,程序会自动从第 {processed_count + completed_in_session} 行继续。")
            print("="*50)
            # 这里不需要做额外操作,with open 块结束会自动关闭文件句柄
            return
    print(f"\n处理完成!总共新增处理 {completed_in_session} 行。")
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        pass

这份代码,将 DISC-Law-SFT-Pair 和 DISC-Law-SFT-Triplet 处理成 seed tasks 类似的格式,同时提取 input 中的指令内容,有这些内容需要注意:

  • 文件读取与断点续传是必要的,当程序被中断后继续执行,可以从上次处理位置继续处理。
  • 对数据进行观察,发现指令大多出现在 input 字段的首尾,那么可以考虑只截取 input 的前后若干长度段。
  • 使用 openai 的 AsyncOpenAI 对大模型进行异步访问,其 chat.completions.create 方法是用大模型进行推理的核心,model 参数根据模型名找到大模型接口进行访问,messages 为系统(大模型)和用户指定提示,一般 system 指定为提示,用户 user 指定为给大模型输入的文本和任务。

那么怎么使用这份代码呢?首先你需要下载一个中文大模型,这里我用 Qwen2.5-7B。我选择在魔塔社区下载 Qwen2.5-B,毕竟能在国内下载肯定优于到国外 huggingface 下载,速度肯定是更快的。借助魔塔社区 Qwen2.5-7B 页面的方法,执行下面的 python 文件即可下载,下载路径可以在 cache_dir 修改:

pip install modelscope
from modelscope import snapshot_download
# 下载到当前目录下的 qwen_models 文件夹
model_dir = snapshot_download('Qwen/Qwen2.5-7B-Instruct', cache_dir='./qwen_models')
print(f"模型已下载路径: {model_dir}")

第二步是安装 vllm,在自己的 anaconda 环境下执行:

pip install vllm

第三步,打开 Qwen 大模型的推理接口。新建一个 bash 文件,内容如下:

python -m vllm.entrypoints.openai.api_server --model qwen_models/Qwen/Qwen2___5-7B-Instruct --served-model-name qwen --max-model-len 8192 --gpu-memory-utilization 0.8 --port 8000

打开一个终端窗口,执行这个 bash 文件即可打开推理接口,之后使用 Qwen 需要保持终端窗口的打开(实际是保证 vllm 程序的运行)。model 参数是本地的大模型文件夹路径,其文件夹下的内容大致类似如下:

served-model-name 参数将大模型赋予一个 id,后面的访问根据 id 找到对应模型。port 参数指定了大模型接口的端口号,后面会根据端口号访问大模型。

最后,在代码中找到如下区域,修改 INPUT_FILE 为源数据路径,OUTPUT_FILE 为处理结果的数据路径。BASE_URL 格式一般为 http://localhost:端口号/v1 。API_KEY 由于是本地部署,所以随意指定即可。MODEL_NAME 是模型 id。

# ================= 配置区域 =================
				 ..........
# ===========================================

也可以使用 DISC-LawLLM 替代 Qwen2.5-7B 去做数据提取和预处理,说不定效果会更好。

首先执行上面那个 bash 文件,保持大模型接口的打开,然后再进行数据处理。最后,我对所有的法律数据集都进行了处理,汇总得到了一份格式与 seed tasks 匹配的、样本量足够多的人工标注数据。接下来可以开始用 self-instruct 方法生成数据了。

# 得到合成数据

得到正确格式的人工标注数据后,就可以按照四个步骤,一步步执行来得到合成数据了。观察其实不难发现,四个步骤的脚本其实十分相似,以第一步 —— 根据人工标注数据生成指令的代码为例:

import os
import json
import random
import re
import string
import tqdm
import argparse
import numpy as np
import pandas as pd
from multiprocessing import Pool
from functools import partial
from rouge_score import rouge_scorer
from gpt3_api import make_requests as make_gpt3_requests
from openai import AsyncOpenAI, APIConnectionError
# ================= 配置区域 =================
BASE_URL = "http://localhost:8000/v1" # 本地 vLLM 地址
API_KEY = "EMPTY"                     # 本地不需要 Key
MODEL_NAME = "lawllm"            # vLLM 启动时的 served-model-name
MAX_CONCURRENCY = 100                 # 并发数
API_TIMEOUT = 180                      # 单次请求超时时间 (秒)
# ===========================================
# 初始化客户端,设置超时
client = AsyncOpenAI(api_key=API_KEY, base_url=BASE_URL, timeout=API_TIMEOUT)
random.seed(42)
def encode_prompt(prompt_instructions, classification=False):
    if classification:
        prompt = "设计一系列分类任务。尽可能明确可能的输出标签。\n"
    else:
        prompt = "制定一系列任务:\n"
    for idx, instruction in enumerate(prompt_instructions):
        instruction = re.sub(r"\s+", " ", instruction).strip().rstrip(":")
        prompt += f"{idx+1}. {instruction}\n"
    prompt += f"{len(prompt_instructions) + 1}."
    return prompt
def sample_machine_instructions(machine_instructions, similarities, n):
    return random.sample(machine_instructions, min(n, len(machine_instructions)))
def post_process_gpt3_response(response):
    if response is None or response["choices"][0]["finish_reason"] == "length":
        return []
    raw_instructions = re.split(r"\n\d+\s?\. ", response["choices"][0]["text"])
    instructions = []
    for inst in raw_instructions:
        inst = re.sub(r"\s+", " ", inst).strip()
        # inst = inst.strip().capitalize()
        inst = inst.strip()
        
        if inst == "":
            continue
        if inst and inst[0] in string.punctuation:
             continue
        
        instructions.append(inst)
    return instructions
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--batch_dir",
        type=str,
        required=True,
        default="data/gpt3_generations/",
        help="The directory where the batch is stored.",
    )
    parser.add_argument(
        "--seed_tasks_path",
        type=str,
        required=True,
        default="data/seed_tasks.jsonl",
        help="The path to the human written data.",
    )
    parser.add_argument(
        "--num_instructions_to_generate",
        type=int,
        default=100,
        help="th",
    )
    parser.add_argument(
        "--use_clf_seed_tasks_only",
        action="store_true",
        help="If specified, we will only use the classification seed tasks to prompt new instructions. This will lead to more classification instructions.",
    )
    parser.add_argument(
        "--engine",
        type=str,
        default="davinci",
        help="The engine to use."
    )
    parser.add_argument(
        "--num_prompt_instructions",
        type=int,
        default=8,
        help="The number of instructions to use in the prompt."
    )
    parser.add_argument(
        "--request_batch_size",
        type=int,
        default=5,
        help="The number of requests to send to GPT3 at a time."
    )
    parser.add_argument(
        "--api_key",
        type=str,
        help="The API key to use. If not specified, the key will be read from the environment variable OPENAI_API_KEY."
    )
    parser.add_argument(
        "--organization",
        type=str,
        help="The organization to use. If not specified, the default organization id will be used."
    )
    return parser.parse_args()
async def main():
    args = parse_args()
    seed_tasks = [json.loads(l) for l in open(args.seed_tasks_path, "r")]
    if args.use_clf_seed_tasks_only:
        seed_tasks = [t for t in seed_tasks if t["is_classification"]]
    seed_instructions = [t["instruction"] for t in seed_tasks]
    print(f"Loaded {len(seed_instructions)} human-written seed instructions")
    
    os.makedirs(args.batch_dir, exist_ok=True)
    request_idx = 0
    # load the LM-generated instructions
    machine_instructions = []
    if os.path.exists(os.path.join(args.batch_dir, "machine_generated_instructions.jsonl")):
        with open(os.path.join(args.batch_dir, "machine_generated_instructions.jsonl"), "r") as fin:
            for line in fin:
                instruction_info = json.loads(line)
                machine_instructions.append(instruction_info["instruction"])
                request_idx = instruction_info["request_idx"] + 1
        print(f"Loaded {len(machine_instructions)} machine-generated instructions")
    # similarities = {}
    scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=False)
    
    # now let's generate new instructions!
    progress_bar = tqdm.tqdm(total=args.num_instructions_to_generate)
    if machine_instructions:
        progress_bar.update(len(machine_instructions))
    with open(os.path.join(args.batch_dir, "machine_generated_instructions.jsonl"), "a") as fout:
        while len(machine_instructions) < args.num_instructions_to_generate:
            batch_inputs = []
            for _ in range(args.request_batch_size):
                # sample machine instructions from the pool
                prompt_instructions = sample_machine_instructions(
                    machine_instructions, 
                    similarities=None,
                    n=2)
                # sample human instructions from the pool
                prompt_instructions += random.sample(seed_instructions, args.num_prompt_instructions - len(prompt_instructions))
                random.shuffle(prompt_instructions)
                prompt = encode_prompt(prompt_instructions, classification=args.use_clf_seed_tasks_only)
                batch_inputs.append(prompt)
            
            tasks = []
            for prompt_text in batch_inputs:
                task = client.chat.completions.create(
                    model=MODEL_NAME,
                    messages=[
                        {"role": "system", "content": "你是LawLLM,一个由复旦大学DISC实验室创造的法律助手。"},
                        {"role": "user", "content": prompt_text}
                    ],
                    # 您可以根据需要调整这些参数
                    temperature=0.1, 
                    top_p=0.9,
                    max_tokens=256,
                    stop=["\n\n", "\n16", "16.", "16 ."],
                )
                tasks.append(task)
                
            # --- 并发执行所有请求 ---
            try:
                responses = await asyncio.gather(*tasks, return_exceptions=True)
            except APIConnectionError as e:
                print(f"API Connection Error: {e}")
                continue # 跳过当前批次
            except Exception as e:
                print(f"An unexpected error occurred: {e}")
                continue
            
            # --- 格式化结果以兼容 post_process_gpt3_response ---
            results = []
            for response in responses:
                if isinstance(response, Exception):
                    # 处理单个请求的异常(如超时、API 错误等)
                    print(f"Individual Request Error: {response}")
                    results.append(None)
                    continue
                
                # 将 Chat Completions API 的响应转换为 GPT-3 Completions API 的旧格式,以兼容 post_process_gpt3_response
                if not response.choices:
                    results.append(None)
                    continue
                message_content = response.choices[0].message.content
                finish_reason = response.choices[0].finish_reason
                
                results.append({
                    "response": {
                        "choices": [{
                            "text": message_content,
                            "finish_reason": finish_reason
                        }]
                    }
                })
                
            instructions = []
            all_metadata = []
            for result in results:
                if result is None:
                    continue
                new_instructions = post_process_gpt3_response(result["response"])
                instructions += new_instructions
                all_metadata += [result] * len(new_instructions)
            for inst, metadata in zip(instructions, all_metadata):
                most_similar_instructions = {}
                rouge_scores = []
                machine_instructions.append(inst)
                fout.write(json.dumps({
                    "instruction": inst,
                    "metadata": metadata,
                    "request_idx": request_idx
                }, ensure_ascii=False) + "\n")
                progress_bar.update(1)
            request_idx += 1
if __name__ == "__main__":
    import asyncio
    asyncio.run(main())

因为四个步骤使用的是 gpt3.5 做的数据生成,而当前的数据是中文文本,为了更好的适配,我们要使用中文大模型,因此需要做点代码上的调整,上面的代码已经是调整后的结果。

保持大模型接口的打开,然后运行脚本处理数据,经过四步处理得到合成数据。

四个步骤的脚本有相似的地方,需要关注的点有:

  • 不同步骤,给大模型的提示不同,需要对应修改提示。
  • 输出文件的路径写在了 main 函数中,处理完后可以对应找到输出文件。
  • 用一个列表 tasks 接收大模型输出的结果,然后最后用 await 与 gather 收集异步的每个结果,这种做法能比串行访问大模型更快地处理完得到结果。
  • client.chat.completions.create 方法的输出结果不是单纯的模型输出文本,只有用.choices [0].message.content 的方式可以获取输出文本,这值得注意。

最后一个步骤做完后,对应得到了用于微调的合成数据,可以发现这个数据比原来少了很多,因为最后一个步骤会对合成数据进行去重和过滤。如果嫌数据量不够,可以考虑再生成多些数据。

得到了合成数据,数据格式如下。尽管这是由 DISC-LawLLM 自己生成的,质量可能不怎么样,但可以尝试微调了。

{
	"instruction": "请为下面的法律文书的案情描述片段贴上标签", 
	"input": "文书:本院在审理中还查明,原、被告于婚姻关系存续期间有夫妻共同债务。", 
	"output": "有夫妻共同债务"
}

# 用 LLama-Factory 进行微调

下一步需要下载模型权重以及微调工具 LLama-Factory。我使用的配置是 4090D(24G)进行微调,因此选择 7B 模型会合适稳妥。后面实测,7B 模型在 batch 为 4 的时候就已经跑到了 20G,在 batch 为 8 的时候就爆显存了。

# 准备工作

首先下载 LLama-Factory,并做好环境配置:

git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory
pip install -e ".[torch,metrics]" --no-build-isolation

然后下载 DISC-LawLLM-7B,可以考虑到魔塔社区下载。确保安装了 modelscope:

pip install modelscope

然后在终端虚拟环境执行:

modelscope download --model luohuan02/ShengbinYue-LawLLM-7B --local_dir 'lawllm'

--local_dir 指定模型的本地目录。

# 数据与配置文件

将微调数据移动到 LLaMA-Factory/data 中,修改名字为 lawllm_data.jsonl,然后打开 dataset_info.json 文件,加入字段:

"lawllm": {
    "file_name": "lawllm_data.jsonl",
    "columns": {
      "prompt": "instruction",
      "query": "input",
      "response": "output"
    }
},

DISC-LawLLM 的项目里,提供了用 Lora 微调的配置文件 lawllm_lora_sft.yaml,将这个 yaml 文件移动到 LLaMA-Factory/examples/train_lora 中,修改里面的 model_name_or_path(模型下载目录),output_dir(训练内容输出路径),per_device_train_batch_size(批次)。

# 微调

在终端执行下面命令即可开始微调训练:

llamafactory-cli train examples/train_lora/lawllm_lora_sft.yaml

用下面命令监控 GPU 情况:

watch nvidia-smi

一开始发现 GPU 占用率极低,大概在 10%~30% 左右波动。

GPU 占用率低会导致训练效率很低,因为 GPU 是并行地训练,所以需要花更多的时间才能结束训练。

查找原因,发现是 deepspeed 的问题。我们看到 lawllm_lora_sft.yaml,发现在 deepspeed 有三个选择:

ds_z0_config.json, ds_z2_config.json, ds_z3_config.json

而 yaml 文件中默认是用 ds_z3_config.json。大模型从零开始的训练,一般需要多机多卡并行,这个时候就需要机器之间有足够强的通信,因此 z3 能提供很强的机器通信,但是同样需要消耗很多 CPU 资源。因此在我用 z3 进行单机微调时,看到 GPU 的利用率在 10% 波动一会后,突升到 30% 左右然后又快速降到 10% 波动。

关于ZeRO-3

DeepSpeed ZeRO-3 的设计初衷是为了在多卡甚至多机环境下,训练超大模型(比如显存放不下的模型)。

  • ZeRO-3 做了什么? 它会将模型参数、梯度、优化器状态全部切片,分散存储。如果显存不够,它还会开启 CPU Offload(将参数卸载到 CPU 内存中)。
  • 在单卡上发生的事情:当使用 ds_z3_config.json(通常默认开启了 CPU Offload)时,GPU 在计算每一层时,都需要从 CPU 内存中通过 PCIe 通道把参数 “拉” 过来,计算完再 “扔” 掉。如果 PCIe 的传输速度(特别是如果是 PCIe 3.0 或 4.0 x8 等情况)低于 GPU 内部显存的读写速度,那么 GPU 绝大部分时间都在等待数据传输(IO Bound),只有极少时间在做真正的矩阵乘法。这就是为什么看到利用率只有 10%-30%。

因此,我修改为使用 ds_z2_config.json 进行微调(ds_z0_config.json 为纯 DPO),一下子占用率高了很多,训练所需的时间一下子缩小了很多。

除了 deepspeed 能影响 GPU 占用率,batch、preprocessing_num_workers、dataloader_num_workers 也会影响占用率,具体需要根据实际排查。