参考:论文《Scaling Relationship on Learning Mathematical Reasoning with Large Language Models》、https://www.zhihu.com/tardis/bd/art/703848627https://zhuanlan.zhihu.com/p/507830576

  RFT(Rejection sampling Fine-Tuning,拒绝采样微调)的整体思路是使用多个模型生成推理路径,经过质量筛选和多样性筛选之后,获得增强的数据集,其中每一个问题都对应了多种解析,用作训练集训练模型。RFT 的步骤如下:

1、训练一轮小模型和大模型:首先利用数据集(如 GSM8K)D={qi,ri,ai}iD=\{q_i,r_i,a_i\}_{i} 将预训练大语言模型 ρ\rho 通过监督微调获得 SFT 模型π\piqiq_i 其中表示问题,rir_i 表示推理,aia_i 表示答案;

2、选择推理路径:对于每一个问题qiq_i,使用 SFT 模型 π\pi 来生成kk 个候选推理路径rr 和答案aa,过滤掉模型生成答案和标准答案不一致的推理路径,并删除具有相同方程列表的其他推理路径以降低推理路径的重复,增加数据集的多样性;

3、获得增强数据集定义D=D{qi,ri,j,ai}i,jD'=D\cup\{qi,r_{i,j},a_i\}_{i,j} 作为增强数据集。使用数据集DD' 在预训练模型ρ\rho 上进行微调得到πRFT\pi_{RFT},得到 RFT 模型。
使用新的数据集微调模型:使用新的推理数据集 R^s 微调一轮小模型,并将其用于微调右边的 llama2-70B;

  从单个 SFT 模型中采样的推理路径可能在逻辑上是非多样化的。因此,可以利用从不同模型聚合的拒绝采样推理路径来进一步提高数学推理性能,实现多模型拒绝采样。

  RFT 的关键点是步骤 2 中选择推理路径。具体做法如下图所示:

  面对重复路径时,利用 Levenstein Distance(莱文斯坦距离算法)获得与所有已经保存路径最不相似的推理路径,从而增加增强数据集的多样性。

  莱文斯坦距离指的是将一个字符串变为另一个字符串需要进行编辑操作最少的次数。其中,允许的编辑操作有以下三种:

  • 「替换」:将一个字符替换成另一个字符
  • 「插入」:插入一个字符
  • 「删除」:删除一个字符

  莱文斯坦距离用于衡量两个字符串之间的差异,被广泛应用于拼写纠错检查、DNA 分析、语音识别等领域。

更新于 阅读次数