KaVa: Latent Reasoning via Compressed KV-Cache Distillation

著者 Anna Kuzina, Maciej Pioro, Paul N. Whatmough, Babak Ehteshami Bejnordi
所属 Qualcomm AI Research, IDEAS NCBR / IPPT PAN
投稿日 2025年10月03日
カテゴリ cs.AI, cs.LG

KaVa: Latent Reasoning via Compressed KV-Cache Distillation

基本情報

  • arXiv ID: 2510.02312v1 (https://arxiv.org/abs/2510.02312)
  • 著者: Anna Kuzina, Maciej Pioro, Paul N. Whatmough, Babak Ehteshami Bejnordi
  • 所属: Qualcomm AI Research, IDEAS NCBR / IPPT PAN
  • 投稿日: 2025年10月03日
  • カテゴリ: cs.AI, cs.LG

簡単に説明すると

大規模言語モデル(LLM)の推論において、従来の明示的なChain-of-Thought(CoT)手法は高精度を実現します。しかし冗長な推論トレースが計算コストとメモリ使用量を著しく増加させる問題がありました。この研究では、推論過程を内部化する「潜在推論」を効率的に学習させるため、教師モデルの圧縮されたKV-cacheから知識を蒸留する新しいフレームワーク「KaVa」を提案している。

KaVaは、教師モデルが生成する完全なCoTの圧縮されたKV-cacheを監督信号として活用します。学生モデルの連続的な潜在トークンの軌道を直接監督することで、トークンレベルの対応を失った圧縮キャッシュからでも効果的な知識蒸留を実現します。これにより、自然言語による推論トレースでの性能劣化を効果的に抑制し、推論時の効率性を保持しながら高精度な潜在推論モデルの訓練を可能にしています。

1. 研究概要

1.1 背景と動機

大規模言語モデルの推論能力向上において、Chain-of-Thought(CoT)は数学や科学、コード生成などの複雑な問題解決において重要な役割を果たしている。
DeepSeek-R1のような最新モデルは、中間ステップを明示的に生成することで長期的な推論問題の精度を向上させることを実証している。

しかし、明示的なCoTには重大な課題が存在する。
まず、冗長で詳細な推論トレースが推論時のコストを著しく増加させます。さらに、メモリやコンピューティングリソースが制約されたデバイスでの展開を困難にしています。
さらに、大規模モデルから蒸留されたCoTトレースには、バイアスや偽の論理が含まれる可能性があり、信頼性の問題が存在する。

近年の研究では、CoTの基盤となるKV-cacheが高度に冗長であり、精度をほとんど失うことなく圧縮可能であることが示されている。
R-KVやKeyDiffなどの手法により、推論の本質的な動的構造は圧縮可能な構造に存在し、不可欠なテキストではないことが明らかになっている。
この観察により、モデルを推論時において冗長なトレースなしで動的構造を内部化するよう訓練できることが示されています。

1.2 主要な貢献

本研究は、圧縮されたKV-cacheからの知識蒸留という革新的なアプローチを通じて、潜在推論における監督の欠如という根本的な課題に取り組んでいる。

主要な技術的貢献として、層ごと・ヘッドごとの独立した刈り込み決定により、トークンとの直接的な対応を失った圧縮KV-cacheから自己蒸留を通じて知識を抽出することを初めて実証しています。
この技術革新は、従来のトークンレベルの蒸留手法では対処できない問題を解決している。

さらに、圧縮されたKV-cacheを段階的な監督信号として活用することで、従来手法が困難としていた自然言語推論トレースから効果的に学習する潜在推論器の訓練を実現している。
従来の潜在推論手法は、テンプレート的な短いトレースでは成功を収めるものの、実世界の推論ワークロードをより適切に反映する長い自然言語推論シーケンスでは性能が劣化する問題があった。

実証的な評価においては、自然言語設定での強力な潜在ベースラインに対する一貫した優位性を示し、方程式のみから自然言語トレースへの移行における性能劣化の削減を実現し、より大規模なバックボーンへのスケーラビリティを実証している。

2. 提案手法

2.1 手法の概要

KaVaフレームワークは、教師・学生両方の役割を持つ単一モデルによる自己監督学習を採用している。
システムは3つの主要コンポーネントから構成されます。

第一に、教師モードと学生モードを交互に実行するバックボーンモデルが存在する。
教師モードでは、完全なCoTを消費して層ごと・ヘッドごとのKV-cacheを構築する。
学生モードでは、連続的な潜在思考を生成し、これらの潜在トークンは同じ自己回帰モデルによって生成されるが、埋め込みのハードトークンへのマッピングをバイパスし、訓練可能な投影層を通じて次のトークン予測に使用される。

第二のコンポーネントは、冗長性と重要度を考慮した排除モジュールで、教師キャッシュを潜在バジェットに圧縮する。
この圧縮プロセスでは、各層・各ヘッドで独立して重要でないKVペアを排除し、結果として得られる圧縮キャッシュは元のトークン対応を失うものの、推論の本質的な構造を保持している。

第三のコンポーネントは、KVマッチング損失で、学生の段階的な潜在KとVを、スタック全体を通じて圧縮されたターゲットに整合させる。
これにより、学生が明示的な推論の圧縮キャッシュのように「思考」することを学習させる強力な段階的内部監督信号を提供している。

2.2 技術的詳細

潜在推論の訓練目的において、潜在推論は観測されない中間ステップ Z = {z_i}_{i=1}^M を導入し、これが明示的推論トレース C の代替として機能する。
潜在推論シーケンスは特別なトークン で始まり、M個の連続トークンが続き、 で推論段階の終了を示す。

KV-cache圧縮には、重要度スコアリング機能が組み込まれている。
各レイヤーlとヘッドhにおいて、キーと値のペア (K^{l,h}, V^{l,h}) は重要度スコア s^{l,h}_i によってランク付けされ、上位k個のペアのみが保持される。
この圧縮プロセスは層間で独立して実行されるため、異なる層で異なるトークンインデックスが保持される可能性がある。

蒸留損失は、学生の潜在表現と教師の圧縮KV-cacheの間のL2距離として定義される。
具体的には、各ステップt、レイヤーl、ヘッドhにおいて、学生の潜在キー・値表現 (K^{l,h}{student,t}, V^{l,h}{student,t}) を、対応する圧縮された教師の表現 (K^{l,h}{teacher,compressed,t}, V^{l,h}{teacher,compressed,t}) に整合させる損失が計算される。

2.3 新規性

従来の知識蒸留手法は、主にトークンレベルの活性化や層レベルの隠れ状態のマッチングに依存していたが、KaVaは圧縮KV-cacheの抽象的で非構造化された知識を直接監督信号として活用する点で革新的である。

既存の潜在推論手法であるCODIやPCCoTと比較して、KaVaは終点レベルやトークンレベルの監督ではなく、圧縮されたKV軌道の段階的監督を提供する点で差別化されている。
CODIは終点の監督に焦点を当て、PCCoTは並列的な潜在更新を採用するが、KaVaは内部KV空間での段階的な軌道整合により、より豊富な監督信号を提供している。

また、圧縮KV-cacheのトークン対応の喪失という課題に対し、連続潜在トークンの表現力の柔軟性を活用して段階的KV軌道を整合させるアプローチは、従来手法では対処できない問題を解決している。

3. 実験結果

3.1 実験設定

実験は複数のモデルサイズと推論データセットで実施されている。
ベースラインモデルとして、Qwen2.5-0.5B-Instruct、Llama3.2-1B-Instruct、Llama3.2-3B-Instructが使用されている。

評価データセットには、GSM8k、GSM8k-Hard、SVAMPの数学的推論タスクが含まれ、さらにGSM8k-AUGとGSM8k-AUG-NLの二つの設定で評価されている。
GSM8k-AUGは方程式スタイルの短い推論トレースを含み、GSM8k-AUG-NLは自然言語スタイルの長い推論トレースを含んでいる。

比較手法として、Full CoT(上限性能)、No-CoT(ベースライン)、既存の潜在推論手法であるCODI、PCCoT、iCoT、Coconutが使用されている。
評価指標はテスト精度で、分布内テストデータセットでの性能と、分布外データセットでのゼロショット評価の両方が実施されている。

3.2 主要な結果

Qwen2.5-0.5Bモデルでの結果では、KaVaはGSM8k-AUGにおいてGSM8kで46.9%(±1.4%)の精度を達成し、CODIの37.5%を9.4ポイント上回っている。
GSM8k-Hardでは10.8%(±0.1%)、SVAMPでは50.6%(±0.4%)と、全ての評価において最良の性能を示している。

特に注目すべきは、GSM8k-AUG-NLにおける性能である。
自然言語推論トレースの設定では、KaVaはGSM8kで44.4%(±1.8%)を達成し、CODIの20.2%を24.2ポイント上回っている。
この結果は、従来の潜在推論手法が自然言語トレースで大幅な性能劣化を示すのに対し、KaVaの圧縮KV-cache蒸留が効果的に機能していることを示している。

Llama3.2-1BおよびLlama3.2-3Bでも一貫した改善が観察されており、スケーラビリティが確認されている。
Llama3.2-1BでのGSM8k性能は56.5%(±0.4%)に達し、CODIの55.6%を上回り、自然言語設定では55.7%(±0.4%)とCODIやPCCoTを大きく上回っている。

3.3 既存手法との比較

従来の潜在推論手法との比較において、KaVaは特に自然言語推論トレースでの優位性が顕著である。
CODIは方程式スタイルのトレースでは良好な性能を示すが、自然言語トレースでは大幅な性能低下を示している(GSM8kで37.5%から20.2%へ)。

PCCoTも同様の傾向を示し、方程式トレースから自然言語トレースへの移行で性能が劣化している。
一方、KaVaは方程式スタイル(46.9%)から自然言語スタイル(44.4%)への移行での性能低下が2.5ポイントと最小限に抑えられている。

効率性の観点では、KaVaは潜在推論の利点を保持しながら、Full CoTに近い性能を達成している。
推論時のトークン生成数とKV-cacheフットプリントは潜在推論により削減され、計算効率とメモリ効率の両方で利益を提供している。

4. 実用性評価

4.1 実装の容易性

KaVaの実装は既存の自己回帰言語モデルのフレームワーク上に構築可能で、大幅なアーキテクチャ変更を必要としない。
教師・学生の切り替え機構と圧縮モジュールは、標準的な深層学習フレームワークで実装可能であり、実装の障壁は比較的低いと評価される。

ただし、KV-cache圧縮の重要度スコアリングと排除戦略の最適化には、特定のドメインや推論タスクに応じた調整が必要となる可能性がある。
また、教師と学生の訓練を適切にバランスさせるためのハイパーパラメータ調整が重要となる。

4.2 計算効率

推論時においては、潜在推論によりトークン生成数が削減され、KV-cacheのメモリフットプリントも減少する。
これにより、メモリ制約のあるデバイスでの展開が容易になり、推論スループットの向上が期待される。

しかし、訓練時には教師モードと学生モードの両方を実行する必要があり、訓練コストは単純な潜在推論手法よりも高くなる可能性がある。
KV-cache圧縮の計算オーバーヘッドも考慮する必要があるが、これは推論時の効率向上によって相殺されると考えられる。

4.3 応用可能性

数学的推論以外のドメインへの拡張可能性は高く、科学的推論、論理的推理、複雑な問題解決タスクなどへの適用が期待される。
特に、長い推論チェーンが必要なタスクや、推論過程の効率化が重要なリアルタイムアプリケーションでの価値が高い。

また、エッジデバイスやモバイルデバイスでの大規模言語モデルの展開において、メモリとコンピューティングリソースの制約を克服する技術として重要な役割を果たす可能性がある。
ただし、異なるモデルアーキテクチャやタスクドメインでの性能検証がさらに必要である。

5. まとめと所感

5.1 論文の意義

本研究は、潜在推論における監督の欠如という根本的な課題に対する革新的な解決策を提供している。
圧縮KV-cacheからの知識蒸留という新しいパラダイムは、効率性と精度のトレードオフを改善し、実用的な潜在推論システムの実現に向けた重要な一歩である。

特に、自然言語推論トレースでの性能劣化の問題を解決したことは、実世界のアプリケーションにおける潜在推論の適用可能性を大幅に向上させている。
従来手法では困難だった長い自然言語推論シーケンスからの効果的な学習を実現し、潜在推論の実用性を実証している。

技術的な観点では、トークン対応を失った圧縮表現からの知識蒸留という困難な問題に対し、連続潜在表現の柔軟性を活用した解決策を提示したことは、知識蒸留技術の新しい方向性を示している。

5.2 今後の展望

今後の研究方向として、より多様なドメインでの評価と最適化が必要である。
現在の評価は主に数学的推論に焦点を当てているが、科学的推理、コード生成、複雑な論理的推論などでの性能検証が重要となる。

また、より大規模なモデルでのスケーラビリティの検証と、異なるアーキテクチャ(Transformer以外)での適用可能性の探索も価値がある。
KV-cache圧縮戦略のさらなる最適化により、効率性と性能のバランスをさらに改善できる可能性がある。

長期的には、この技術がエッジコンピューティングやリアルタイム推論アプリケーションでの大規模言語モデルの実用的展開を促進し、AI技術のアクセシビリティ向上に貢献することが期待される。