Beyond Two-Stage Training: Cooperative SFT and RL for LLM Reasoning
Beyond Two-Stage Training: Cooperative SFT and RL for LLM Reasoning
基本情報
- arXiv ID: 2509.06948v1 (https://arxiv.org/abs/2509.06948v1)
- 著者: Liang Chen, Xueting Han, Li Shen他
- 所属: The Chinese University of Hong Kong, Microsoft Research他
- 投稿日: 2025年09月10日
- カテゴリ: cs.CL, cs.LG, cs.AI
簡単に説明すると
この論文は、大規模言語モデル(LLM)の推論能力を向上させるために、従来の二段階訓練(SFT→RL)を超える新しい協調的訓練フレームワーク「BRIDGE」を提案しています。従来手法では教師あり微調整(SFT)と強化学習(RL)が独立して実行されるため、相互の情報交換が限定的で、破滅的忘却や非効率な探索といった問題が発生していました。
BRIDGEは二層最適化を用いて、SFTを上位レベル問題、RLを下位レベル問題として定式化します。これにより、SFTがRLの最適化プロセスを誘導する方法をメタ学習し、両手法間の緊密な協調を実現します。数学推論ベンチマークでの実験では、従来手法に比べて一貫した性能向上と訓練効率の改善を示しています。
関連リンクについて、以下の情報が公開予定です。
1. 研究概要
1.1 背景と動機
OpenAIのo1モデルやDeepSeek-R1の登場により、大規模言語モデル(LLM)の推論能力は显著に向上しました。これらの進歩の鍵となる技術は大規模なルールベース強化学習です。しかし、RLの本質的な試行錯誤の性質により、訓練プロセスは非常に非効率的です。代替アプローチとして、厳選された長い思考連鎖(CoT)データセットでの教師あり微調整(SFT)があり、模倣学習により効果的な推論パターンを迅速に習得できます。
サンプル効率は高いものの、SFTは通常RLに比べて性能と汎化能力で劣ります。実用的な生産規模の訓練パイプラインでは、RLの前の準備段階としてSFTを使用する多段階パラダイムが採用されることが多いです。しかし、この完全に分離された二段階設定は破滅的忘却と非効率な探索に悩まされ、両手法の強みを十分に活用できません。
1.2 主要な貢献
本研究では、SFTとRLの意味のある情報交換を可能にする訓練フレームワークを設計し、両者の協調が単独のRLを上回る性能を確保することを目指しています。主要な技術的貢献について、以下のような要素が含まれています。
第一に、推論訓練パラダイムの比較分析です。大規模推論モデルの訓練に用いられる3つの主要戦略を体系的に分析し、二段階パイプラインの相互作用不足が破滅的忘却と非効率な探索につながることを明らかにしました。これらの知見に基づき、優れた性能を達成するシンプルな交代ベースラインを導入しています。
第二に、SFTとRLを統合する二層最適化フレームワークです。SFTを上位レベル、RLを下位レベル問題として定式化するBRIDGE手法を提案しています。拡張モデルアーキテクチャ上に構築され、ペナルティベース緩和により解かれ、協調利得を明示的に最大化して共同訓練が単独RLを上回ることを保証しています。
第三に、数学推論ベンチマークでの実証検証です。3つのLLMと5つの数学推論ベンチマークで広範な実験を実施し、BRIDGEが精度と訓練効率の両面で5つのベースラインを一貫して上回り、緊密に統合されたSFT-RL最適化の実用的利益を実証しています。
2. 提案手法
2.1 手法の概要
BRIDGEは、SFTとRLを協調的メタ学習アプローチで緊密に結合するフレームワークです。SFTデータセットとRLデータセットを前提として、政策最適化と教師あり学習を統合する協調的メタ学習定式化を提案しています。
手法の核心は二層の最適化構造にあります。SFTが上位レベル(教師)として機能し、RLの下位レベル(生徒)の最適応答にアクセスして的確な指導を提供します。一方、RLはSFTからのwパラメータによる補助サポートを受けてベースパラメータθを最適化します。この構造により、従来の二段階アプローチの一方向情報フローとは対照的に、双方向情報フローが可能になります。
2.2 技術的詳細
拡張モデルアーキテクチャは、ベースモデルパラメータθとLoRAパラメータwから構成されています。ベースモデルは下位レベルのRL目的で最適化され、LoRAパラメータは上位レベルの教師あり目的で更新されます。
二層の最適化問題を解くために、高価な二階微分計算を避けるペナルティベース手法を採用しています。下位レベル問題の次善性を測定するペナルティ関数を定義し、ペナルティ重みλを用いてペナルティ化された再定式化を得ます。λは小さな値から開始して徐々に増加するアニーリングスケジュールに従い、教師ありデータでの準備段階から二層制約のより厳密な強制へと移行します。
実際の実装では、ダンスキン定理を適用してwに関する勾配を計算し、θ*(w)をRL目的に関する単一勾配上昇ステップで近似します。この近似により、効率的な一階最適化が可能になり、協調利得(共同SFT-RL訓練が単独RLを上回る性能優位性)を明示的に最大化します。
2.3 新規性
本提案手法の新規性は、従来の分離された二段階アプローチの根本的限界を克服する統合フレームワークにあります。既存手法では、SFTとRLが独立して実行されるため、情報交換が制限され、破滅的忘却と非効率な探索が発生していました。
技術的な差別化要因として、二層最適化による協調的メタ学習の定式化が挙げられます。SFTがRLの最適化プロセスを誘導する方法をメタ学習することで、適応的に最も有益な情報をSFTからRLに転移します。この設計により、すべてのSFT更新がRL最適化に有益とは限らないという問題を解決しています。
拡張モデルアーキテクチャも重要な革新です。ベースモデルとLoRAモジュールの分離により、上位・下位レベル目的の共適応が可能になります。この分離がなければ、定式化はMAMLスタイルの設定に縮退し、RLの学習が無効化されて協調が失われてしまいます。
3. 実験結果
3.1 実験設定
実験では、LIMR(1.3k問題)とMATH(8.5k問題)の2つのデータセットをRL訓練に使用しています。SFTデータセットについては、LIMRとMATHのクエリをDeepSeekMath-103kから抽出した対応する中間推論トレースとペアリングしています。評価は5つの数学推論ベンチマーク(MATH500、Minerva Math、OlympiadBench、AIME 2024、AMC 2023)で実施されています。
3つのLLM(Qwen2.5-3B、Llama-3.2-3B-Instruct、Qwen2-8B-Base)で手法の汎用性を実証しています。報酬関数は解答の正確性に基づく二値報酬(正解で+1、不正解で0)を採用し、形式ベースの報酬は除外しています。実装はVERLフレームワークを使用し、プロンプトバッチサイズ64、ミニバッチサイズ64、学習率5×10^-7で設定されています。
3.2 主要な結果
BRIDGEは五つの多様な数学推論ベンチマークで一貫した性能向上を達成しています。全体的に、BRIDGEはRL-zeroとCold-startに対して平均11.8%の改善を示し、様々な難易度のタスクにわたる効果と頑健性を実証しています。
ベースライン手法は比較的簡単なベンチマークでより大きな改善を示す傾向がありますが、より複雑な推論タスクでは汎化性能が劣ります。例えば、Cold-start手法はMinerva Math、Olympiad Bench、AMC23でRL-zeroを下回り、事前SFT段階での過適合が原因と考えられます。対照的に、BRIDGEはより困難なベンチマークで一貫した大幅な改善を達成しています。
訓練動態の分析では、BRIDGEが継続的なSFT指導により急速な報酬成長を実現し、Cold-startを上回って優れた収束を達成することが示されています。Cold-startは初期の長い応答による訓練の非効率性と、その後のRL段階での適切な指導不足により、RL-zeroと同様の収束に留まっています。
3.3 既存手法との比較
包括的な比較実験では、BRIDGEが5つのベースライン手法(Base/Instruct、SFT、RL-zero、Cold-start、Naive Alternating)をすべて上回る性能を示しています。Qwen3-8B-BaseではRL-zeroに対して16.3%、Cold-startに対して9.7%の改善を達成し、Llama3.2-3B-InstructではRL-zeroに対して13.5%、Cold-startに対して30.9%のより顕著な向上を実現しています。
コストパフォーマンス分析では、Cold-startがRL-zeroの約2倍の訓練時間を要するのに対し、BRIDGEは3Bモデルで44%、8Bモデルで14%の時間削減を達成しています。大型モデルでメモリ使用量が11%増加するものの、BRIDGEは一貫して優れた性能改善(3Bで13%、8Bで9.7%)を実現し、実用展開において有利なコストパフォーマンストレードオフを実証しています。
4. 実用性評価
4.1 実装の容易性
BRIDGEの実装は、既存のフレームワークと技術を活用することで実現されています。VERLフレームワークを基盤として、標準的なハイパーパラメータ設定(バッチサイズ64、学習率5×10^-7)で訓練が可能です。LoRAの設定(ランクとα値16)も標準的で、既存の微調整パイプラインに容易に統合できます。
二層最適化の実装は、ペナルティベース緩和により一階最適化に簡略化されており、高価な二階微分計算が不要です。ペナルティ重みのアニーリングスケジュール(0.5に設定)も実装が簡単で、訓練の安定性を確保しています。拡張モデルアーキテクチャは、既存のベースモデルにLoRAモジュールを追加するだけで実現できるため、実装の複雑性は最小限です。
4.2 計算効率
BRIDGEの計算効率は、実用的な展開要件を満たすよう最適化されています。ペナルティベース手法により、計算量の多い二階微分を回避し、効率的な一階最適化を実現しています。メモリ使用量の増加は、大型モデル(8B)で11%程度と適度な範囲に収まっています。
訓練時間の観点では、BRIDGEはCold-startに比べて大幅な時間短縮を実現し、3Bモデルで44%、8Bモデルで14%の改善を達成しています。GPU使用効率も良好で、3Bモデルで4×A100-80GB、8Bモデルで8×MI300-192GBの設定で安定した訓練が可能です。この効率性により、研究機関や企業での実用的な採用が促進されます。
4.3 応用可能性
BRIDGEの応用可能性は、多様なLLMと数学推論タスクでの実証により検証されています。3つの異なるモデルファミリー(Qwen、Llama)で一貫した性能向上を示し、手法の汎用性を実証しています。数学推論以外の領域への拡張も期待され、プログラム合成、定理証明、科学推論などの分野での応用可能性があります。
実用的展開の観点では、既存の訓練パイプラインとの統合が容易で、段階的な導入が可能です。コストパフォーマンスの改善により、限られた計算資源での効果的な推論モデル訓練が実現できます。オープンソース化予定のコード公開により、研究コミュニティでの採用と改良が促進される見込みです。
5. まとめと所感
5.1 論文の意義
本論文は、LLMの推論能力向上における従来の二段階訓練の限界を克服する重要な貢献を提供しています。SFTとRLの協調的統合により、破滅的忘却と非効率な探索という根本的問題を解決し、両手法の相乗効果を実現しています。二層最適化による定式化は理論的にも実用的にも優れたアプローチです。
技術的革新性も高く評価されます。協調的メタ学習フレームワークにより、SFTがRLの最適化プロセスを誘導する方法を学習し、適応的で効果的な知識転移を実現しています。拡張モデルアーキテクチャと一階最適化による実装は、実用性と理論的厳密性の両立を達成しています。
実験的検証の包括性も特筆すべきです。複数のLLM、多様なベンチマーク、詳細な分析により、手法の有効性と一般化能力が十分に実証されています。特に、コストパフォーマンス分析は実用展開への重要な示唆を提供しています。
5.2 今後の展望
本研究の発展方向として、より大規模なモデルと広範な領域への拡張が期待されます。現在の評価は限定的な応用シナリオに留まっているため、プログラム合成、定理証明、科学推論などの分野での検証が重要です。これにより手法の汎用性がさらに確認されることになります。
SFTデータセットの品質改善も重要な課題です。大規模推論モデルから蒸留された推論トレースにはノイズが含まれる可能性があり、自動検出ツールによるフィルタリングが性能向上に寄与する可能性があります。堅牢性の観点では、推論能力の頑健性不足を緩和する堅牢微調整技術の統合が有望です。
理論的発展として、二層最適化の収束保証や最適性理論の整備が挙げられます。また、より効率的な最適化アルゴリズムの開発により、さらなる計算効率の改善が期待されます。実用化に向けては、産業規模での展開とより多様なタスクファミリーでの検証が必要です。