---
abstract: |
  `\looseness -1`{=latex} Language models with recurrent depth, also referred to as universal or looped when considering transformers, are defined by the capacity to increase their computation through the repetition of layers. Recent efforts in pretraining have demonstrated that these architectures can scale to modern language modeling tasks while exhibiting advantages in reasoning tasks. In this work, we examine the relationship between recurrent-depth models and diffusion language models. Building on their similarities, we develop a new diffusion forcing sampler for these models to accelerate generation. The sampler advances by decoding new tokens at every forward pass of the model, while the latent states of these tokens can be further refined in parallel through recurrence. Theoretically, generation with our sampler is strictly more expressive than the baseline autoregressive generation using the same time budget on modern hardware. Moreover, this sampler, based on principles from the diffusion literature, can be directly applied to existing 3.5B-parameter recurrent-depth transformers without any tuning, leading to up to a **5**x speedup. Consequently, our findings not only provide an efficient mechanism for parallelizing the extra computation in recurrent-depth models at inference, but also suggest that such models can be naturally viewed as strong continuous, though causal, diffusion language models.
author:
- |
  Jonas Geiping\
  ELLIS Institute Tübingen &\
  Max-Planck Institute for Intelligent Systems,\
  Tübingen AI Center\
  `jonas@tue.ellis.eu`\
  `\And`{=latex} Xinyu Yang\
  Electrical and Computer Engineering\
  Carnegie Mellon University\
  `\AND`{=latex} Guinan Su\
  ELLIS Institute Tübingen &\
  Max-Planck Institute for Intelligent Systems,\
  Tübingen AI Center\
bibliography:
- nlp\_auto\_references\_do\_not\_edit.bib
- manual\_references.bib
title: 'Efficient Parallel Samplers for Recurrent-Depth Models and Their Connection to Diffusion Language Models'
---

```{=latex}
\newcommand{\figleft}{{\em (Left)}}
```
```{=latex}
\newcommand{\figcenter}{{\em (Center)}}
```
```{=latex}
\newcommand{\figright}{{\em (Right)}}
```
```{=latex}
\newcommand{\figtop}{{\em (Top)}}
```
```{=latex}
\newcommand{\figbottom}{{\em (Bottom)}}
```
```{=latex}
\newcommand{\captiona}{{\em (a)}}
```
```{=latex}
\newcommand{\captionb}{{\em (b)}}
```
```{=latex}
\newcommand{\captionc}{{\em (c)}}
```
```{=latex}
\newcommand{\captiond}{{\em (d)}}
```
```{=latex}
\newcommand{\newterm}[1]{{\bf #1}}
```
```{=latex}
\def\figref#1{figure~\ref{#1}}
```
```{=latex}
\def\Figref#1{Figure~\ref{#1}}
```
```{=latex}
\def\twofigref#1#2{figures \ref{#1} and \ref{#2}}
```
```{=latex}
\def\quadfigref#1#2#3#4{figures \ref{#1}, \ref{#2}, \ref{#3} and \ref{#4}}
```
```{=latex}
\def\secref#1{section~\ref{#1}}
```
```{=latex}
\def\Secref#1{Section~\ref{#1}}
```
```{=latex}
\def\twosecrefs#1#2{sections \ref{#1} and \ref{#2}}
```
```{=latex}
\def\secrefs#1#2#3{sections \ref{#1}, \ref{#2} and \ref{#3}}
```
```{=latex}
\def\eqref#1{equation~\ref{#1}}
```
```{=latex}
\def\Eqref#1{Equation~\ref{#1}}
```
```{=latex}
\def\plaineqref#1{\ref{#1}}
```
```{=latex}
\def\chapref#1{chapter~\ref{#1}}
```
```{=latex}
\def\Chapref#1{Chapter~\ref{#1}}
```
```{=latex}
\def\rangechapref#1#2{chapters~\ref{#1}--\ref{#2}}
```
```{=latex}
\def\algref#1{algorithm~\ref{#1}}
```
```{=latex}
\def\Algref#1{Algorithm~\ref{#1}}
```
```{=latex}
\def\twoalgref#1#2{algorithms \ref{#1} and \ref{#2}}
```
```{=latex}
\def\Twoalgref#1#2{Algorithms \ref{#1} and \ref{#2}}
```
```{=latex}
\def\partref#1{part~\ref{#1}}
```
```{=latex}
\def\Partref#1{Part~\ref{#1}}
```
```{=latex}
\def\twopartref#1#2{parts \ref{#1} and \ref{#2}}
```
```{=latex}
\def\ceil#1{\lceil #1 \rceil}
```
```{=latex}
\def\floor#1{\lfloor #1 \rfloor}
```
```{=latex}
\def\1{\bm{1}}
```
```{=latex}
\newcommand{\train}{\mathcal{D}}
```
```{=latex}
\newcommand{\valid}{\mathcal{D_{\mathrm{valid}}}}
```
```{=latex}
\newcommand{\test}{\mathcal{D_{\mathrm{test}}}}
```
```{=latex}
\def\eps{{\epsilon}}
```
```{=latex}
\def\reta{{\textnormal{$\eta$}}}
```
```{=latex}
\def\ra{{\textnormal{a}}}
```
```{=latex}
\def\rb{{\textnormal{b}}}
```
```{=latex}
\def\rc{{\textnormal{c}}}
```
```{=latex}
\def\rd{{\textnormal{d}}}
```
```{=latex}
\def\re{{\textnormal{e}}}
```
```{=latex}
\def\rf{{\textnormal{f}}}
```
```{=latex}
\def\rg{{\textnormal{g}}}
```
```{=latex}
\def\rh{{\textnormal{h}}}
```
```{=latex}
\def\ri{{\textnormal{i}}}
```
```{=latex}
\def\rj{{\textnormal{j}}}
```
```{=latex}
\def\rk{{\textnormal{k}}}
```
```{=latex}
\def\rl{{\textnormal{l}}}
```
```{=latex}
\def\rn{{\textnormal{n}}}
```
```{=latex}
\def\ro{{\textnormal{o}}}
```
```{=latex}
\def\rp{{\textnormal{p}}}
```
```{=latex}
\def\rq{{\textnormal{q}}}
```
```{=latex}
\def\rr{{\textnormal{r}}}
```
```{=latex}
\def\rs{{\textnormal{s}}}
```
```{=latex}
\def\rt{{\textnormal{t}}}
```
```{=latex}
\def\ru{{\textnormal{u}}}
```
```{=latex}
\def\rv{{\textnormal{v}}}
```
```{=latex}
\def\rw{{\textnormal{w}}}
```
```{=latex}
\def\rx{{\textnormal{x}}}
```
```{=latex}
\def\ry{{\textnormal{y}}}
```
```{=latex}
\def\rz{{\textnormal{z}}}
```
```{=latex}
\def\rvepsilon{{\mathbf{\epsilon}}}
```
```{=latex}
\def\rvtheta{{\mathbf{\theta}}}
```
```{=latex}
\def\rva{{\mathbf{a}}}
```
```{=latex}
\def\rvb{{\mathbf{b}}}
```
```{=latex}
\def\rvc{{\mathbf{c}}}
```
```{=latex}
\def\rvd{{\mathbf{d}}}
```
```{=latex}
\def\rve{{\mathbf{e}}}
```
```{=latex}
\def\rvf{{\mathbf{f}}}
```
```{=latex}
\def\rvg{{\mathbf{g}}}
```
```{=latex}
\def\rvh{{\mathbf{h}}}
```
```{=latex}
\def\rvi{{\mathbf{i}}}
```
```{=latex}
\def\rvj{{\mathbf{j}}}
```
```{=latex}
\def\rvk{{\mathbf{k}}}
```
```{=latex}
\def\rvl{{\mathbf{l}}}
```
```{=latex}
\def\rvm{{\mathbf{m}}}
```
```{=latex}
\def\rvn{{\mathbf{n}}}
```
```{=latex}
\def\rvo{{\mathbf{o}}}
```
```{=latex}
\def\rvp{{\mathbf{p}}}
```
```{=latex}
\def\rvq{{\mathbf{q}}}
```
```{=latex}
\def\rvr{{\mathbf{r}}}
```
```{=latex}
\def\rvs{{\mathbf{s}}}
```
```{=latex}
\def\rvt{{\mathbf{t}}}
```
```{=latex}
\def\rvu{{\mathbf{u}}}
```
```{=latex}
\def\rvv{{\mathbf{v}}}
```
```{=latex}
\def\rvw{{\mathbf{w}}}
```
```{=latex}
\def\rvx{{\mathbf{x}}}
```
```{=latex}
\def\rvy{{\mathbf{y}}}
```
```{=latex}
\def\rvz{{\mathbf{z}}}
```
```{=latex}
\def\erva{{\textnormal{a}}}
```
```{=latex}
\def\ervb{{\textnormal{b}}}
```
```{=latex}
\def\ervc{{\textnormal{c}}}
```
```{=latex}
\def\ervd{{\textnormal{d}}}
```
```{=latex}
\def\erve{{\textnormal{e}}}
```
```{=latex}
\def\ervf{{\textnormal{f}}}
```
```{=latex}
\def\ervg{{\textnormal{g}}}
```
```{=latex}
\def\ervh{{\textnormal{h}}}
```
```{=latex}
\def\ervi{{\textnormal{i}}}
```
```{=latex}
\def\ervj{{\textnormal{j}}}
```
```{=latex}
\def\ervk{{\textnormal{k}}}
```
```{=latex}
\def\ervl{{\textnormal{l}}}
```
```{=latex}
\def\ervm{{\textnormal{m}}}
```
```{=latex}
\def\ervn{{\textnormal{n}}}
```
```{=latex}
\def\ervo{{\textnormal{o}}}
```
```{=latex}
\def\ervp{{\textnormal{p}}}
```
```{=latex}
\def\ervq{{\textnormal{q}}}
```
```{=latex}
\def\ervr{{\textnormal{r}}}
```
```{=latex}
\def\ervs{{\textnormal{s}}}
```
```{=latex}
\def\ervt{{\textnormal{t}}}
```
```{=latex}
\def\ervu{{\textnormal{u}}}
```
```{=latex}
\def\ervv{{\textnormal{v}}}
```
```{=latex}
\def\ervw{{\textnormal{w}}}
```
```{=latex}
\def\ervx{{\textnormal{x}}}
```
```{=latex}
\def\ervy{{\textnormal{y}}}
```
```{=latex}
\def\ervz{{\textnormal{z}}}
```
```{=latex}
\def\rmA{{\mathbf{A}}}
```
```{=latex}
\def\rmB{{\mathbf{B}}}
```
```{=latex}
\def\rmC{{\mathbf{C}}}
```
```{=latex}
\def\rmD{{\mathbf{D}}}
```
```{=latex}
\def\rmE{{\mathbf{E}}}
```
```{=latex}
\def\rmF{{\mathbf{F}}}
```
```{=latex}
\def\rmG{{\mathbf{G}}}
```
```{=latex}
\def\rmH{{\mathbf{H}}}
```
```{=latex}
\def\rmI{{\mathbf{I}}}
```
```{=latex}
\def\rmJ{{\mathbf{J}}}
```
```{=latex}
\def\rmK{{\mathbf{K}}}
```
```{=latex}
\def\rmL{{\mathbf{L}}}
```
```{=latex}
\def\rmM{{\mathbf{M}}}
```
```{=latex}
\def\rmN{{\mathbf{N}}}
```
```{=latex}
\def\rmO{{\mathbf{O}}}
```
```{=latex}
\def\rmP{{\mathbf{P}}}
```
```{=latex}
\def\rmQ{{\mathbf{Q}}}
```
```{=latex}
\def\rmR{{\mathbf{R}}}
```
```{=latex}
\def\rmS{{\mathbf{S}}}
```
```{=latex}
\def\rmT{{\mathbf{T}}}
```
```{=latex}
\def\rmU{{\mathbf{U}}}
```
```{=latex}
\def\rmV{{\mathbf{V}}}
```
```{=latex}
\def\rmW{{\mathbf{W}}}
```
```{=latex}
\def\rmX{{\mathbf{X}}}
```
```{=latex}
\def\rmY{{\mathbf{Y}}}
```
```{=latex}
\def\rmZ{{\mathbf{Z}}}
```
```{=latex}
\def\ermA{{\textnormal{A}}}
```
```{=latex}
\def\ermB{{\textnormal{B}}}
```
```{=latex}
\def\ermC{{\textnormal{C}}}
```
```{=latex}
\def\ermD{{\textnormal{D}}}
```
```{=latex}
\def\ermE{{\textnormal{E}}}
```
```{=latex}
\def\ermF{{\textnormal{F}}}
```
```{=latex}
\def\ermG{{\textnormal{G}}}
```
```{=latex}
\def\ermH{{\textnormal{H}}}
```
```{=latex}
\def\ermI{{\textnormal{I}}}
```
```{=latex}
\def\ermJ{{\textnormal{J}}}
```
```{=latex}
\def\ermK{{\textnormal{K}}}
```
```{=latex}
\def\ermL{{\textnormal{L}}}
```
```{=latex}
\def\ermM{{\textnormal{M}}}
```
```{=latex}
\def\ermN{{\textnormal{N}}}
```
```{=latex}
\def\ermO{{\textnormal{O}}}
```
```{=latex}
\def\ermP{{\textnormal{P}}}
```
```{=latex}
\def\ermQ{{\textnormal{Q}}}
```
```{=latex}
\def\ermR{{\textnormal{R}}}
```
```{=latex}
\def\ermS{{\textnormal{S}}}
```
```{=latex}
\def\ermT{{\textnormal{T}}}
```
```{=latex}
\def\ermU{{\textnormal{U}}}
```
```{=latex}
\def\ermV{{\textnormal{V}}}
```
```{=latex}
\def\ermW{{\textnormal{W}}}
```
```{=latex}
\def\ermX{{\textnormal{X}}}
```
```{=latex}
\def\ermY{{\textnormal{Y}}}
```
```{=latex}
\def\ermZ{{\textnormal{Z}}}
```
```{=latex}
\def\vzero{{\bm{0}}}
```
```{=latex}
\def\vone{{\bm{1}}}
```
```{=latex}
\def\vmu{{\bm{\mu}}}
```
```{=latex}
\def\vtheta{{\bm{\theta}}}
```
```{=latex}
\def\va{{\bm{a}}}
```
```{=latex}
\def\vb{{\bm{b}}}
```
```{=latex}
\def\vc{{\bm{c}}}
```
```{=latex}
\def\vd{{\bm{d}}}
```
```{=latex}
\def\ve{{\bm{e}}}
```
```{=latex}
\def\vf{{\bm{f}}}
```
```{=latex}
\def\vg{{\bm{g}}}
```
```{=latex}
\def\vh{{\bm{h}}}
```
```{=latex}
\def\vi{{\bm{i}}}
```
```{=latex}
\def\vj{{\bm{j}}}
```
```{=latex}
\def\vk{{\bm{k}}}
```
```{=latex}
\def\vl{{\bm{l}}}
```
```{=latex}
\def\vm{{\bm{m}}}
```
```{=latex}
\def\vn{{\bm{n}}}
```
```{=latex}
\def\vo{{\bm{o}}}
```
```{=latex}
\def\vp{{\bm{p}}}
```
```{=latex}
\def\vq{{\bm{q}}}
```
```{=latex}
\def\vr{{\bm{r}}}
```
```{=latex}
\def\vs{{\bm{s}}}
```
```{=latex}
\def\vt{{\bm{t}}}
```
```{=latex}
\def\vu{{\bm{u}}}
```
```{=latex}
\def\vv{{\bm{v}}}
```
```{=latex}
\def\vw{{\bm{w}}}
```
```{=latex}
\def\vx{{\bm{x}}}
```
```{=latex}
\def\vy{{\bm{y}}}
```
```{=latex}
\def\vz{{\bm{z}}}
```
```{=latex}
\def\evalpha{{\alpha}}
```
```{=latex}
\def\evbeta{{\beta}}
```
```{=latex}
\def\evepsilon{{\epsilon}}
```
```{=latex}
\def\evlambda{{\lambda}}
```
```{=latex}
\def\evomega{{\omega}}
```
```{=latex}
\def\evmu{{\mu}}
```
```{=latex}
\def\evpsi{{\psi}}
```
```{=latex}
\def\evsigma{{\sigma}}
```
```{=latex}
\def\evtheta{{\theta}}
```
```{=latex}
\def\eva{{a}}
```
```{=latex}
\def\evb{{b}}
```
```{=latex}
\def\evc{{c}}
```
```{=latex}
\def\evd{{d}}
```
```{=latex}
\def\eve{{e}}
```
```{=latex}
\def\evf{{f}}
```
```{=latex}
\def\evg{{g}}
```
```{=latex}
\def\evh{{h}}
```
```{=latex}
\def\evi{{i}}
```
```{=latex}
\def\evj{{j}}
```
```{=latex}
\def\evk{{k}}
```
```{=latex}
\def\evl{{l}}
```
```{=latex}
\def\evm{{m}}
```
```{=latex}
\def\evn{{n}}
```
```{=latex}
\def\evo{{o}}
```
```{=latex}
\def\evp{{p}}
```
```{=latex}
\def\evq{{q}}
```
```{=latex}
\def\evr{{r}}
```
```{=latex}
\def\evs{{s}}
```
```{=latex}
\def\evt{{t}}
```
```{=latex}
\def\evu{{u}}
```
```{=latex}
\def\evv{{v}}
```
```{=latex}
\def\evw{{w}}
```
```{=latex}
\def\evx{{x}}
```
```{=latex}
\def\evy{{y}}
```
```{=latex}
\def\evz{{z}}
```
```{=latex}
\def\mA{{\bm{A}}}
```
```{=latex}
\def\mB{{\bm{B}}}
```
```{=latex}
\def\mC{{\bm{C}}}
```
```{=latex}
\def\mD{{\bm{D}}}
```
```{=latex}
\def\mE{{\bm{E}}}
```
```{=latex}
\def\mF{{\bm{F}}}
```
```{=latex}
\def\mG{{\bm{G}}}
```
```{=latex}
\def\mH{{\bm{H}}}
```
```{=latex}
\def\mI{{\bm{I}}}
```
```{=latex}
\def\mJ{{\bm{J}}}
```
```{=latex}
\def\mK{{\bm{K}}}
```
```{=latex}
\def\mL{{\bm{L}}}
```
```{=latex}
\def\mM{{\bm{M}}}
```
```{=latex}
\def\mN{{\bm{N}}}
```
```{=latex}
\def\mO{{\bm{O}}}
```
```{=latex}
\def\mP{{\bm{P}}}
```
```{=latex}
\def\mQ{{\bm{Q}}}
```
```{=latex}
\def\mR{{\bm{R}}}
```
```{=latex}
\def\mS{{\bm{S}}}
```
```{=latex}
\def\mT{{\bm{T}}}
```
```{=latex}
\def\mU{{\bm{U}}}
```
```{=latex}
\def\mV{{\bm{V}}}
```
```{=latex}
\def\mW{{\bm{W}}}
```
```{=latex}
\def\mX{{\bm{X}}}
```
```{=latex}
\def\mY{{\bm{Y}}}
```
```{=latex}
\def\mZ{{\bm{Z}}}
```
```{=latex}
\def\mBeta{{\bm{\beta}}}
```
```{=latex}
\def\mPhi{{\bm{\Phi}}}
```
```{=latex}
\def\mLambda{{\bm{\Lambda}}}
```
```{=latex}
\def\mSigma{{\bm{\Sigma}}}
```
```{=latex}
\newcommand{\tens}[1]{\bm{\mathsfit{#1}}}
```
```{=latex}
\def\tA{{\tens{A}}}
```
```{=latex}
\def\tB{{\tens{B}}}
```
```{=latex}
\def\tC{{\tens{C}}}
```
```{=latex}
\def\tD{{\tens{D}}}
```
```{=latex}
\def\tE{{\tens{E}}}
```
```{=latex}
\def\tF{{\tens{F}}}
```
```{=latex}
\def\tG{{\tens{G}}}
```
```{=latex}
\def\tH{{\tens{H}}}
```
```{=latex}
\def\tI{{\tens{I}}}
```
```{=latex}
\def\tJ{{\tens{J}}}
```
```{=latex}
\def\tK{{\tens{K}}}
```
```{=latex}
\def\tL{{\tens{L}}}
```
```{=latex}
\def\tM{{\tens{M}}}
```
```{=latex}
\def\tN{{\tens{N}}}
```
```{=latex}
\def\tO{{\tens{O}}}
```
```{=latex}
\def\tP{{\tens{P}}}
```
```{=latex}
\def\tQ{{\tens{Q}}}
```
```{=latex}
\def\tR{{\tens{R}}}
```
```{=latex}
\def\tS{{\tens{S}}}
```
```{=latex}
\def\tT{{\tens{T}}}
```
```{=latex}
\def\tU{{\tens{U}}}
```
```{=latex}
\def\tV{{\tens{V}}}
```
```{=latex}
\def\tW{{\tens{W}}}
```
```{=latex}
\def\tX{{\tens{X}}}
```
```{=latex}
\def\tY{{\tens{Y}}}
```
```{=latex}
\def\tZ{{\tens{Z}}}
```
```{=latex}
\def\gA{{\mathcal{A}}}
```
```{=latex}
\def\gB{{\mathcal{B}}}
```
```{=latex}
\def\gC{{\mathcal{C}}}
```
```{=latex}
\def\gD{{\mathcal{D}}}
```
```{=latex}
\def\gE{{\mathcal{E}}}
```
```{=latex}
\def\gF{{\mathcal{F}}}
```
```{=latex}
\def\gG{{\mathcal{G}}}
```
```{=latex}
\def\gH{{\mathcal{H}}}
```
```{=latex}
\def\gI{{\mathcal{I}}}
```
```{=latex}
\def\gJ{{\mathcal{J}}}
```
```{=latex}
\def\gK{{\mathcal{K}}}
```
```{=latex}
\def\gL{{\mathcal{L}}}
```
```{=latex}
\def\gM{{\mathcal{M}}}
```
```{=latex}
\def\gN{{\mathcal{N}}}
```
```{=latex}
\def\gO{{\mathcal{O}}}
```
```{=latex}
\def\gP{{\mathcal{P}}}
```
```{=latex}
\def\gQ{{\mathcal{Q}}}
```
```{=latex}
\def\gR{{\mathcal{R}}}
```
```{=latex}
\def\gS{{\mathcal{S}}}
```
```{=latex}
\def\gT{{\mathcal{T}}}
```
```{=latex}
\def\gU{{\mathcal{U}}}
```
```{=latex}
\def\gV{{\mathcal{V}}}
```
```{=latex}
\def\gW{{\mathcal{W}}}
```
```{=latex}
\def\gX{{\mathcal{X}}}
```
```{=latex}
\def\gY{{\mathcal{Y}}}
```
```{=latex}
\def\gZ{{\mathcal{Z}}}
```
```{=latex}
\def\sA{{\mathbb{A}}}
```
```{=latex}
\def\sB{{\mathbb{B}}}
```
```{=latex}
\def\sC{{\mathbb{C}}}
```
```{=latex}
\def\sD{{\mathbb{D}}}
```
```{=latex}
\def\sF{{\mathbb{F}}}
```
```{=latex}
\def\sG{{\mathbb{G}}}
```
```{=latex}
\def\sH{{\mathbb{H}}}
```
```{=latex}
\def\sI{{\mathbb{I}}}
```
```{=latex}
\def\sJ{{\mathbb{J}}}
```
```{=latex}
\def\sK{{\mathbb{K}}}
```
```{=latex}
\def\sL{{\mathbb{L}}}
```
```{=latex}
\def\sM{{\mathbb{M}}}
```
```{=latex}
\def\sN{{\mathbb{N}}}
```
```{=latex}
\def\sO{{\mathbb{O}}}
```
```{=latex}
\def\sP{{\mathbb{P}}}
```
```{=latex}
\def\sQ{{\mathbb{Q}}}
```
```{=latex}
\def\sR{{\mathbb{R}}}
```
```{=latex}
\def\sS{{\mathbb{S}}}
```
```{=latex}
\def\sT{{\mathbb{T}}}
```
```{=latex}
\def\sU{{\mathbb{U}}}
```
```{=latex}
\def\sV{{\mathbb{V}}}
```
```{=latex}
\def\sW{{\mathbb{W}}}
```
```{=latex}
\def\sX{{\mathbb{X}}}
```
```{=latex}
\def\sY{{\mathbb{Y}}}
```
```{=latex}
\def\sZ{{\mathbb{Z}}}
```
```{=latex}
\def\emLambda{{\Lambda}}
```
```{=latex}
\def\emA{{A}}
```
```{=latex}
\def\emB{{B}}
```
```{=latex}
\def\emC{{C}}
```
```{=latex}
\def\emD{{D}}
```
```{=latex}
\def\emE{{E}}
```
```{=latex}
\def\emF{{F}}
```
```{=latex}
\def\emG{{G}}
```
```{=latex}
\def\emH{{H}}
```
```{=latex}
\def\emI{{I}}
```
```{=latex}
\def\emJ{{J}}
```
```{=latex}
\def\emK{{K}}
```
```{=latex}
\def\emL{{L}}
```
```{=latex}
\def\emM{{M}}
```
```{=latex}
\def\emN{{N}}
```
```{=latex}
\def\emO{{O}}
```
```{=latex}
\def\emP{{P}}
```
```{=latex}
\def\emQ{{Q}}
```
```{=latex}
\def\emR{{R}}
```
```{=latex}
\def\emS{{S}}
```
```{=latex}
\def\emT{{T}}
```
```{=latex}
\def\emU{{U}}
```
```{=latex}
\def\emV{{V}}
```
```{=latex}
\def\emW{{W}}
```
```{=latex}
\def\emX{{X}}
```
```{=latex}
\def\emY{{Y}}
```
```{=latex}
\def\emZ{{Z}}
```
```{=latex}
\def\emSigma{{\Sigma}}
```
```{=latex}
\newcommand{\etens}[1]{\mathsfit{#1}}
```
```{=latex}
\def\etLambda{{\etens{\Lambda}}}
```
```{=latex}
\def\etA{{\etens{A}}}
```
```{=latex}
\def\etB{{\etens{B}}}
```
```{=latex}
\def\etC{{\etens{C}}}
```
```{=latex}
\def\etD{{\etens{D}}}
```
```{=latex}
\def\etE{{\etens{E}}}
```
```{=latex}
\def\etF{{\etens{F}}}
```
```{=latex}
\def\etG{{\etens{G}}}
```
```{=latex}
\def\etH{{\etens{H}}}
```
```{=latex}
\def\etI{{\etens{I}}}
```
```{=latex}
\def\etJ{{\etens{J}}}
```
```{=latex}
\def\etK{{\etens{K}}}
```
```{=latex}
\def\etL{{\etens{L}}}
```
```{=latex}
\def\etM{{\etens{M}}}
```
```{=latex}
\def\etN{{\etens{N}}}
```
```{=latex}
\def\etO{{\etens{O}}}
```
```{=latex}
\def\etP{{\etens{P}}}
```
```{=latex}
\def\etQ{{\etens{Q}}}
```
```{=latex}
\def\etR{{\etens{R}}}
```
```{=latex}
\def\etS{{\etens{S}}}
```
```{=latex}
\def\etT{{\etens{T}}}
```
```{=latex}
\def\etU{{\etens{U}}}
```
```{=latex}
\def\etV{{\etens{V}}}
```
```{=latex}
\def\etW{{\etens{W}}}
```
```{=latex}
\def\etX{{\etens{X}}}
```
```{=latex}
\def\etY{{\etens{Y}}}
```
```{=latex}
\def\etZ{{\etens{Z}}}
```
```{=latex}
\newcommand{\pdata}{p_{\rm{data}}}
```
```{=latex}
\newcommand{\ptrain}{\hat{p}_{\rm{data}}}
```
```{=latex}
\newcommand{\Ptrain}{\hat{P}_{\rm{data}}}
```
```{=latex}
\newcommand{\pmodel}{p_{\rm{model}}}
```
```{=latex}
\newcommand{\Pmodel}{P_{\rm{model}}}
```
```{=latex}
\newcommand{\ptildemodel}{\tilde{p}_{\rm{model}}}
```
```{=latex}
\newcommand{\pencode}{p_{\rm{encoder}}}
```
```{=latex}
\newcommand{\pdecode}{p_{\rm{decoder}}}
```
```{=latex}
\newcommand{\precons}{p_{\rm{reconstruct}}}
```
```{=latex}
\newcommand{\laplace}{\mathrm{Laplace}}
```
```{=latex}
\newcommand{\E}{\mathbb{E}}
```
```{=latex}
\newcommand{\Ls}{\mathcal{L}}
```
```{=latex}
\newcommand{\R}{\mathbb{R}}
```
```{=latex}
\newcommand{\emp}{\tilde{p}}
```
```{=latex}
\newcommand{\lr}{\alpha}
```
```{=latex}
\newcommand{\reg}{\lambda}
```
```{=latex}
\newcommand{\rect}{\mathrm{rectifier}}
```
```{=latex}
\newcommand{\softmax}{\mathrm{softmax}}
```
```{=latex}
\newcommand{\sigmoid}{\sigma}
```
```{=latex}
\newcommand{\softplus}{\zeta}
```
```{=latex}
\newcommand{\KL}{D_{\mathrm{KL}}}
```
```{=latex}
\newcommand{\Var}{\mathrm{Var}}
```
```{=latex}
\newcommand{\standarderror}{\mathrm{SE}}
```
```{=latex}
\newcommand{\Cov}{\mathrm{Cov}}
```
```{=latex}
\newcommand{\normlzero}{L^0}
```
```{=latex}
\newcommand{\normlone}{L^1}
```
```{=latex}
\newcommand{\normltwo}{L^2}
```
```{=latex}
\newcommand{\normlp}{L^p}
```
```{=latex}
\newcommand{\normmax}{L^\infty}
```
```{=latex}
\newcommand{\parents}{Pa}
```
```{=latex}
\DeclareMathOperator*{\argmax}{arg\,max}
```
```{=latex}
\DeclareMathOperator*{\argmin}{arg\,min}
```
```{=latex}
\DeclareMathOperator{\sign}{sign}
```
```{=latex}
\DeclareMathOperator{\Tr}{Tr}
```
```{=latex}
\let\ab\allowbreak
```
```{=latex}
\newcommand{\cmark}{\ding{51}}
```
```{=latex}
\newcommand{\xmark}{\ding{55}}
```
```{=latex}
\DeclareMathOperator{\Embed}{Embed}
```
```{=latex}
\DeclareMathOperator{\Sample}{Sample}
```
```{=latex}
\DeclareMathOperator{\InitState}{InitState}
```
```{=latex}
\DeclareMathOperator{\Iterate}{Iterate}
```
```{=latex}
\DeclareMathOperator{\Predict}{Predict}
```
```{=latex}
\newcommand{\jog}[1]{{\textcolor{teal}{Jonas: \footnotesize\sf[#1]}}}
```
```{=latex}
\newcommand{\fix}{\marginpar{FIX}}
```
```{=latex}
\newcommand{\new}{\marginpar{NEW}}
```
```{=latex}
\maketitle
```
Introduction
============

`\looseness -1`{=latex} Conventional large language models (LLMs) are constructed as fixed-depth neural networks with a predetermined number of layers (often merely a two-digit count), a design that not only allows these models to be trained efficiently but in practice also appears sufficient for many tasks [@radford_language_2019]. However, more challenging tasks in mathematics and programming often require conceptual leaps over multiple steps in a logical chain, which are hard for these models to learn robustly. More formally, fixed-depth transformers fall within the complexity class $\mathsf{TC}^{\mathsf{0}}$ [@merrill_parallelism_2023]. To work around this limitation, recent efforts have focused on training models to "verbalize" their internal reasoning into chains-of-thought composed of small sub-steps, each of which the model is capable of learning.

An alternative to fixed-depth models are models with *recurrent depth* [@dehghani_universal_2019; @schwarzschild_can_2021], which can repeat layers. These models are also referred to as *looped* transformers [@giannou_looped_2023], or as *universal* transformers [@dehghani_universal_2019] when highlighting the motivation for these systems to represent universal Turing machines [@graves_neural_2014; @graves_adaptive_2017]. `\Citet{merrill_little_2025}`{=latex} show that, in contrast to fixed-depth models, models with arbitrary recurrence are indeed capable of representing a larger complexity class.

However, generation with autoregressive recurrent-depth models is typically slow, given that every repetition of the model layers must be executed sequentially before the next token can be produced.

![Different generation schemes for autoregressive, recurrent-depth models. **Left:** Standard sequential generation, which proceeds one token and one recurrence step at a time (time steps denoted by integers). **Right:** A diffusion forcing sampler applied to the same model can parallelize generation "diagonally" by computing one step of the recurrence per token position, iteratively refining its estimate of the generated sequence.](plots/neurips_teaser.png){#fig:teaser width="0.75\\linewidth"}
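
To make the difference between the two schemes concrete, the following minimal sketch counts forward passes under the simplifying assumption that every token receives a fixed budget of `r` recurrence steps and that nothing is frozen early. The function names are hypothetical illustrations, not taken from any released implementation. Each inner list is one forward pass, and each `(position, step)` pair is one application of the recurrent block.

```python
from typing import List, Tuple

Pass = List[Tuple[int, int]]  # (token position, recurrence step) pairs computed in one forward pass


def sequential_schedule(num_tokens: int, r: int) -> List[Pass]:
    """Standard generation: one recurrence step of one token per forward pass."""
    return [[(n, step)] for n in range(num_tokens) for step in range(1, r + 1)]


def diagonal_schedule(num_tokens: int, r: int) -> List[Pass]:
    """Diagonal (wavefront) schedule: position n joins the wave at pass n, and every
    active position advances by one recurrence step per pass until it has done r steps."""
    passes: List[Pass] = []
    for t in range(num_tokens + r - 1):
        passes.append([(n, t - n + 1) for n in range(num_tokens) if 0 <= t - n < r])
    return passes


seq = sequential_schedule(num_tokens=4, r=3)
diag = diagonal_schedule(num_tokens=4, r=3)
print(len(seq), len(diag))  # 12 6
print(diag[2])              # [(0, 3), (1, 2), (2, 1)]
```

Both schedules perform the same total work of $N \cdot r$ recurrence steps, so FLOPs are unchanged, but the diagonal schedule needs only $N + r - 1$ sequential passes. On memory-bound hardware, this reduction in sequential passes is where a speedup can come from.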

In this work, we discuss how generation from recurrent-depth models can be efficiently parallelized by connecting this architecture to diffusion models. Both architectures "recur" in a related sense, and even though they are trained with different objectives, we show that samplers adapted from the diffusion literature, namely *diffusion forcing* [@chen_diffusion_2024-1], can be directly applied to parallelize generation from existing recurrent-depth models [@geiping_scaling_2025].

We discuss how to adapt diffusion forcing sampling to recurrent-depth models, identifying the essential architectural components and strategies required to ensure both stability of the iterates and bounded memory usage. As illustrated in `\cref{fig:teaser}`{=latex}, rather than waiting for the recurrence at sequence position $n$ to fully converge before generating the next token, our sampler immediately produces token drafts from intermediate iterates. It then advances to position $n+1$, where the subsequent forward pass simultaneously refines the drafts for positions $n$ and $n+1$, while also decoding an initial draft for position $n+2$. In this way, the sampler achieves parallelism along the sequence dimension, akin to speculative decoding. Importantly, because the underlying model is trained as a causal language model, information still propagates strictly from left to right, and the output sequence is iteratively refined across recurrences. While this approach does not reduce FLOPs, it effectively exploits modern GPU architectures by unlocking additional opportunities for parallelization. Overall, in this work, we

-   Clarify the connection between recurrent-depth models and diffusion models via diffusion forcing and block or wave-based inference strategies for sequence-based diffusion models.

-   Describe how to apply principles from diffusion forcing to efficiently parallelize the inference of models with recurrent depth.

-   Verify that recurrent-depth models equipped with diffusion-forcing samplers achieve the strongest balance between practical efficiency and theoretical expressiveness in both prefilling and decoding.

-   `\looseness -1`{=latex} Show that diffusion forcing sampling outperforms even well-tuned speculative decoding baselines for the same model with speed gains that can be smoothly traded off against accuracy.

Related Work
============

We briefly introduce both recurrent models and diffusion models, focusing on language applications.

**Recurrent Models.** Models with recurrent computations have long been central to machine learning [@amari_learning_1972; @hopfield_neural_1982; @braitenberg_vehicles_1986; @gers_recurrent_2000; @sutskever_recurrent_2008], not only due to significant inspiration from recurrent firing patterns found in neuroscience [@hopfield_neural_1982; @lamme_distinct_2000; @douglas_neuronal_2004], but also because early successes in language modeling centered on recurrent neural networks [@mikolov_recurrent_2010; @sutskever_generating_2011]. With the advent of transformer models, these architectures were considered less scalable, yet recurrence, now as *recurrence in depth*, was swiftly reintroduced in *universal transformers* [@dehghani_universal_2019], motivated by the idea that these models could be capable of modeling universal Turing machines [@graves_neural_2014]. Other work showed that recurrent models are capable of learning algorithms [@schwarzschild_can_2021; @bansal_end--end_2022; @bear_rethinking_2024]. @giannou_looped_2023 explicitly constructed transformers whose recurrence can represent universal computation, and follow-up work on *looped* transformers has shown that these models are capable learners [@giannou_looped_2023; @gatmiry_can_2024; @yang_looped_2024; @mcleish_transformers_2024; @fan_looped_2025]. These findings have led to a wave of work training larger, general-purpose recurrent-depth language models [@tan_sparse_2023; @abnar_adaptivity_2023; @mathur_mind_2024; @csordas_moeut_2024; @geiping_scaling_2025], as well as work retrofitting recurrence into trained models [@li_deep_2020; @bae_relaxed_2024; @hay_dynamic_2023; @liu_mobilellm_2024]. Several of these works also highlight the possibility of implementing *latent reasoning* via recurrence, that is, complementing or replacing verbalized chains-of-thought with recurrence.
Examples of this line of thinking are *Coconut* [@hao_training_2024], as well as @liu_deliberation_2024 and @cheng_compressed_2024.

![An example of a text sequence being generated with the proposed diffusion forcing sampler from a depth-recurrent model. While the original recurrent-depth model requires 32 recurrence steps to produce a single token (the default for this model), the diffusion sampler has already produced and committed 8 new tokens (green). As described, the sampler advances by at least one token per step of the recurrence. Decoded candidate tokens initially spell out incoherent text, but map to the right concepts and quickly improve with more steps. Note that the "freeze" decision is dynamic, based on the distance to the previous state in latent space (not pictured).](plots/text_crystallization_long.png){#fig:illustration width="0.8\\linewidth"}
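
The caption describes the freeze decision only qualitatively. One plausible way to implement such a rule is sketched below under our own assumptions: a relative L2 distance between consecutive latent states, a fixed threshold (the value `0.03` is a hypothetical placeholder), and freezing only a left-to-right prefix so that committed tokens respect causality. This is not presented as the exact rule used to produce the figure.

```python
import math
from typing import List, Sequence


def relative_change(prev: Sequence[float], curr: Sequence[float]) -> float:
    """Relative L2 distance ||curr - prev|| / ||curr|| between consecutive latent states."""
    diff = math.sqrt(sum((c - p) ** 2 for p, c in zip(prev, curr)))
    norm = math.sqrt(sum(c * c for c in curr))
    return diff / max(norm, 1e-12)


def frozen_prefix_length(prev_states: List[Sequence[float]],
                         curr_states: List[Sequence[float]],
                         threshold: float = 0.03) -> int:
    """Count how many leading positions can be frozen (committed).

    Positions are scanned left to right, and scanning stops at the first position whose
    state still moves by at least `threshold`, so committed tokens always form a prefix.
    """
    count = 0
    for prev, curr in zip(prev_states, curr_states):
        if relative_change(prev, curr) >= threshold:
            break
        count += 1
    return count


prev = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
curr = [[1.0, 0.001], [0.0, 2.0], [1.0, 1.0]]
print(frozen_prefix_length(prev, curr))  # 1: the second position is still changing
```

The third position in the example has stopped moving but is not frozen, because the position before it is still changing.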

In this work, we propose a generic sampling algorithm for depth-recurrent models, which we evaluate with the openly accessible models developed in @geiping_scaling_2025. These models have 3.5B parameters and were trained on 800B tokens for general language understanding and reasoning.

`\looseness -1`{=latex} **Diffusion Language Models.** Diffusion models are general-purpose generative models, with early applications focusing on continuous domains such as images [@song_generative_2019; @rombach_high-resolution_2022; @peebles_scalable_2023], which led to substantial interest in extending diffusion to discrete domains such as text [@austin_structured_2021; @hoogeboom_argmax_2021]. Approaches to language diffusion are split between diffusion processes on a continuous variable that is then projected into discrete space [@chen_analog_2022; @dieleman_continuous_2022; @han_ssd-lm_2023; @karimi_mahabadi_tess_2024; @jo_continuous_2025; @graves_bayesian_2025] and diffusion processes that act directly on discrete variables [@lou_discrete_2024-1; @richemond_categorical_2023]. The latter, especially with *masking* as the discrete forward diffusion step, is currently the most scalable approach and is employed in large-scale efforts to train language diffusion models that are competitive with autoregressive models [@gong_scaling_2025; @gong_diffucoder_2025; @google_deepmind_gemini_2025; @nie_large_2025; @wang_diffusion_2025; @xie_dream-coder_2025; @ye_dream_2025].
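
For readers less familiar with masked diffusion, the forward (noising) step of the masking approach can be sketched in a few lines. This is a generic illustration with a hypothetical `<mask>` token, not the training code of any of the cited models; a denoiser would then be trained to predict the original tokens at the masked positions.

```python
import random
from typing import List, Optional

MASK = "<mask>"  # hypothetical mask token


def mask_forward(tokens: List[str], t: float, rng: Optional[random.Random] = None) -> List[str]:
    """Masking (absorbing-state) forward process at noise level t in [0, 1]:
    every token is independently replaced by MASK with probability t."""
    if not 0.0 <= t <= 1.0:
        raise ValueError("t must lie in [0, 1]")
    rng = rng or random.Random()
    return [MASK if rng.random() < t else tok for tok in tokens]


print(mask_forward(["the", "cat", "sat"], t=0.5, rng=random.Random(0)))
```

At $t = 0$ the sequence is untouched and at $t = 1$ every token is masked, so sampling proceeds by gradually unmasking from a fully masked sequence.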

**Inference Strategies for Diffusion Language Models.** Making diffusion tractable for arbitrarily long sequences requires techniques such as block diffusion [@arriola_block_2025], where a chunk of text is modified by the diffusion model, then frozen and its KV entries cached, after which the sampler moves to the next chunk. A more free-form approach to sequence-based diffusion is *diffusion forcing* [@chen_diffusion_2024-1], a hybrid model in which noise is added to future tokens in a sequence relative to the position of the current token, allowing the sampler to move along both the sequence dimension and the diffusion time dimension.
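
The two inference strategies differ mainly in how noise levels are assigned across positions. The sketch below is a simplified illustration under our own assumptions (noise in $[0, 1]$ with 0 meaning clean, and a linear ramp for diffusion forcing); it is not the exact schedule of @arriola_block_2025 or @chen_diffusion_2024-1, and all names are hypothetical.

```python
from typing import List


def block_noise_levels(num_positions: int, block_start: int, block_size: int, t: float) -> List[float]:
    """Block diffusion: committed positions are clean (0), the active block shares the
    noise level t, and positions after the block have not been generated yet (1)."""
    levels = []
    for k in range(num_positions):
        if k < block_start:
            levels.append(0.0)
        elif k < block_start + block_size:
            levels.append(t)
        else:
            levels.append(1.0)
    return levels


def forcing_noise_levels(num_positions: int, num_committed: int, horizon: int) -> List[float]:
    """Diffusion-forcing-style ramp: each uncommitted position gets its own noise level,
    growing linearly with its distance from the last committed token and saturating at 1."""
    return [min(1.0, max(0.0, (k - num_committed + 1) / horizon)) for k in range(num_positions)]


print(block_noise_levels(6, block_start=2, block_size=2, t=0.5))  # [0.0, 0.0, 0.5, 0.5, 1.0, 1.0]
print(forcing_noise_levels(6, num_committed=2, horizon=4))        # [0.0, 0.0, 0.25, 0.5, 0.75, 1.0]
```

Advancing the sampler means shifting `block_start` by a whole block in the first case, but shifting `num_committed` by as little as one token in the second, while all positions on the ramp are denoised in parallel.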

**Inference Acceleration for Fixed-Depth Transformers.** Inference in transformers, in particular in small-batch settings, is memory-bound, meaning that transferring data (by default, the model parameters) to and from the L1 cache of the accelerator is the dominant cost. This allows algorithms such as speculative decoding [@leviathan_fast_2023] and its follow-ups [@cai_medusa_2024; @miao2024specinfer; @chen2024sequoia] to improve inference speed through speculative parallelization. Using smaller draft models, these algorithms draft text several tokens into the future, which the original model can then verify, as verifying the entire drafted sequence takes a single parallel forward pass and is hence fast.
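
The draft-then-verify loop can be sketched as follows for greedy decoding. This is the simplified greedy-acceptance variant, not the rejection-sampling scheme of @leviathan_fast_2023 that preserves the target distribution under sampling. `target_argmax` and `toy_target` are hypothetical stand-ins for a single parallel forward pass of the target model.

```python
from typing import Callable, List


def greedy_verify(prefix: List[int], draft: List[int],
                  target_argmax: Callable[[List[int]], List[int]]) -> List[int]:
    """Return the tokens accepted after verifying `draft` against the target model.

    target_argmax(seq)[i] is the target model's greedy prediction for position i + 1
    given seq[: i + 1]; all predictions come from a single forward pass.
    """
    if not prefix:
        raise ValueError("prefix must contain at least one token")
    preds = target_argmax(prefix + draft)
    accepted: List[int] = []
    for i, tok in enumerate(draft):
        expected = preds[len(prefix) + i - 1]  # target's prediction for this draft slot
        if tok != expected:
            accepted.append(expected)  # first mismatch: keep the target's token and stop
            return accepted
        accepted.append(tok)
    accepted.append(preds[-1])  # every draft token matched: the same pass yields a bonus token
    return accepted


def toy_target(seq: List[int]) -> List[int]:
    """Hypothetical target model: the next token is always the previous token plus one."""
    return [x + 1 for x in seq]


print(greedy_verify([1, 2], [3, 4, 9], toy_target))  # [3, 4, 5]
print(greedy_verify([1, 2], [3, 4, 5], toy_target))  # [3, 4, 5, 6]
```

Each call advances the output by between one and `len(draft) + 1` tokens for the cost of one memory-bound target pass, plus the cost of drafting.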

Applying Diffusion Forcing to Recurrent-Depth Models
====================================================

In this section, we present our diffusion forcing sampler for recurrent-depth models, which accelerates text generation by advancing at least one token in each recurrence step, as illustrated in `\cref{fig:illustration}`{=latex}.

Background on Recurrent-Depth Models
------------------------------------

`\looseness -1`{=latex} Before detailing the diffusion forcing sampler, we briefly describe the particular recurrent-depth architecture proposed by @geiping_scaling_2025, emphasizing the features of the model that are pertinent to the sampler's functionality. We use the checkpoint name *Huginn-0125* when referring to the trained model. The architecture of this model contains three main blocks, each composed of multiple transformer layers: (i) a prelude block $P$, projecting the embedded input tokens into a latent space; (ii) a recurrent block $R$, iterating $r$ times in this latent space by refining a state vector $\mathbf{s}$; and (iii) a coda block $C$ that processes the latent state and produces the model's probabilities for the next token, formally $$\begin{aligned}
    \mathbf{e} &= P(\mathbf{x}) \\
    \mathbf{s}_0 &\sim \mathcal{N}(\mathbf{0}, \sigma^2 I) \\
    \mathbf{s}_i &= R(\mathbf{e}, \mathbf{s}_{i-1}) \qquad \textnormal{for} \quad i \in \lbrace 1, \dots, r \rbrace \\
    \mathbf{p} &= C(\mathbf{s}_r).\end{aligned}$$ Notably, while this architecture is derived from looping the middle layers of fixed-depth transformer models [@skean_does_2024; @sun_transformer_2024; @kaplan_tokens_2024], with features such as input injection and random state initialization from the literature on recurrent-depth models [@bansal_end--end_2022; @anil_path_2022], it can also be interpreted as a *latent-space diffusion model* following the formulation of @rombach_high-resolution_2022: starting from an initial random state $\mathbf{s}_0$, the model iteratively refines this state conditioned on the embedded input sequence $\mathbf{e}$, until the state is assumed to be completely denoised at the end of the process, at which point it is decoded into the next token by $C$.
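
The four equations above translate into a short loop. The sketch below is a toy, pure-Python rendering under stated simplifications: the whole input sequence is collapsed into a single latent vector, and `toy_prelude`, `toy_recur`, and `toy_coda` are hypothetical stand-ins for the transformer blocks $P$, $R$, and $C$ of Huginn-0125, not the real model.

```python
import math
import random
from typing import Callable, List, Optional

Vector = List[float]


def recurrent_depth_forward(x: List[int],
                            prelude: Callable[[List[int]], Vector],
                            recur: Callable[[Vector, Vector], Vector],
                            coda: Callable[[Vector], Vector],
                            r: int, dim: int, sigma: float = 1.0,
                            rng: Optional[random.Random] = None) -> Vector:
    rng = rng or random.Random(0)
    e = prelude(x)                                   # e = P(x)
    s = [rng.gauss(0.0, sigma) for _ in range(dim)]  # s_0 ~ N(0, sigma^2 I)
    for _ in range(r):                               # s_i = R(e, s_{i-1}), i = 1..r
        s = recur(e, s)
    return coda(s)                                   # p = C(s_r)


# Hypothetical toy components (dim = 2, vocabulary of 3 tokens).
EMBED = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]


def toy_prelude(x: List[int]) -> Vector:
    """Sum of token embeddings as a stand-in for the prelude block."""
    return [sum(EMBED[tok][d] for tok in x) for d in range(2)]


def toy_recur(e: Vector, s: Vector) -> Vector:
    """Contraction toward e; the input injection enters at every step."""
    return [0.5 * si + 0.5 * ei for si, ei in zip(s, e)]


def toy_coda(s: Vector) -> Vector:
    """Softmax over logits given by dot products with the embedding table."""
    logits = [sum(w * v for w, v in zip(row, s)) for row in EMBED]
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [z / total for z in exps]


p = recurrent_depth_forward([0, 0], toy_prelude, toy_recur, toy_coda, r=32, dim=2)
print([round(v, 3) for v in p])
```

Because the toy recurrence is a contraction, the output after many steps barely depends on the random initial state $\mathbf{s}_0$; in this toy setting, the random state is gradually replaced by a state determined by $\mathbf{e}$, loosely mirroring the denoising interpretation.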

In @geiping_scaling_2025, this model is trained using randomized unrolling with truncated backpropagation, i.e., a random number of iterations $r$ is sampled (from a Poisson-lognormal distribution), and the entire current batch of training sequences is then iterated $r$ times. This objective is not directly related to diffusion language modeling, which is most effectively trained through randomized masking and adaptation from autoregressive models [@nie_large_2025; @xie_dream-coder_2025; @ye_dream_2025; @gong_scaling_2025].
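For intuition about randomized unrolling, the sketch below draws the number of iterations for one training step from a Poisson-lognormal distribution. The exact parameterization (a log-normal rate with mean `mean_r` and spread `sigma`, shifted by one so that $r \geq 1$) and the default values are our assumptions for illustration, not necessarily the settings of @geiping_scaling_2025; truncated backpropagation is not modeled.

```python
import math
import random
from typing import Optional

def sample_poisson(rate: float, rng: random.Random) -> int:
    # Knuth's multiplication method (adequate for the moderate rates used here).
    threshold = math.exp(-rate)
    k, prod = 0, rng.random()
    while prod > threshold:
        k += 1
        prod *= rng.random()
    return k

def sample_unroll_depth(mean_r: float = 32.0, sigma: float = 0.5,
                        rng: Optional[random.Random] = None) -> int:
    # Assumed form: tau ~ N(log(mean_r) - sigma^2 / 2, sigma), r ~ Poisson(exp(tau)) + 1.
    if rng is None:
        rng = random.Random()
    tau = rng.gauss(math.log(mean_r) - 0.5 * sigma ** 2, sigma)
    return sample_poisson(math.exp(tau), rng) + 1

rng = random.Random(0)
depths = [sample_unroll_depth(rng=rng) for _ in range(2000)]
mean_depth = sum(depths) / len(depths)   # close to mean_r + 1 under this parameterization
```

Each training step would iterate the whole batch `depths[step]` times, so the model sees both very shallow and very deep unrolls.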

The Ingredients for Diffusion Forcing Sampling
----------------------------------------------

While we will describe experiments using this particular recurrent-depth model, the sampler can be applied to all recurrent-depth models that fulfill the following requirements.

**Input Injection.** The first necessary component, aside from the recurrence over layers itself, is input injection, i.e., the conditioning of the recurrence on $\mathbf{e}$. This allows the sampler to \`\`course-correct" if the conditioning changes, without having to discard a partially computed state $\mathbf{s}$. The other component that may strengthen the connection to diffusion modeling is the random initialization of states; while we speculate that this is beneficial, it is not architecturally necessary. As such, recurrent-depth models trained in @csordas_moeut_2024 [@schone_implicit_2025; @mohtashami_cotformer_2024] or @wang_hierarchical_2025 could also benefit from this sampler. However, looped architectures such as *Coconut* [@hao_training_2024], which are trained to feed the outputs of a transformer back in as inputs, are not immediately supported and require retraining to incorporate input injection, which separates their recurrent state from their input data.

**Robust Recurrence.** The second necessary property is that the intermediate state at every step of the recurrence must be decodable to approximately correct solutions. While this property is generally satisfied, it may fail in models trained exclusively with a fixed number of recurrences $r$, where decoding from earlier steps can yield nonsensical outputs rather than approximate versions of the intended result.

```{=latex}
\begin{wrapfigure}[18]{r}{0.5\textwidth}     % 
    
    
    \includegraphics[width=0.48\textwidth]{plots/kv_cache_sharing_gsm8k.pdf}
    
    \caption{The \textit{Huginn-0125} recurrent-depth model can match the baseline performance on the GSM8k dataset when enabling KV cache sharing (with a minimal cache size of 1), using $r$-times less memory for KV states.}
    \label{fig:cache_sharing}
    % % latex is overrating the gap
\end{wrapfigure}
```
`\looseness -1`{=latex} **KV Cache Sharing.** The third property, which is not strictly required but highly beneficial for diffusion forcing samplers, is the ability of different recurrent depths to share their KV cache across iterations during generation. Without fungible KV states, all KV states from previous recurrences and tokens must be retained in memory, causing the cache to grow with both sequence length and recurrence depth. As shown in `\cref{fig:cache_sharing}`{=latex}, the trained *Huginn-0125* model inherently supports KV cache sharing, allowing us to store only the KV state of the most recent recurrence for each token position[^1].
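A back-of-the-envelope count shows why sharing matters. The sketch below counts KV-cache bytes for the recurrent block only, assuming that without sharing every recurrence step at every position keeps its own keys and values, while with sharing only the most recent one is kept. The layer count, hidden size, and 2-byte values are hypothetical, chosen for round numbers.

```python
def recurrent_kv_bytes(seq_len: int, r: int, n_layers: int, hidden_dim: int,
                       bytes_per_value: int = 2, shared: bool = True) -> int:
    # Keys and values for every layer of the recurrent block, per position and per kept depth.
    per_entry = 2 * n_layers * hidden_dim * bytes_per_value
    depths_kept = 1 if shared else r     # sharing keeps only the most recent recurrence
    return seq_len * depths_kept * per_entry

# Hypothetical sizes, for illustration only.
cfg = dict(seq_len=4096, r=32, n_layers=4, hidden_dim=4096)
no_share = recurrent_kv_bytes(**cfg, shared=False)
share = recurrent_kv_bytes(**cfg, shared=True)
print(no_share / 2**30, "GiB vs", share / 2**30, "GiB")  # 8.0 GiB vs 0.25 GiB
```

The ratio is exactly $r$, matching the $r$-fold memory reduction reported for cache sharing.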

A Simplified Version of the Sampling Algorithm
----------------------------------------------

Next, we present the algorithm for our sampler. Given a prompt $\mathbf{x}$, `\cref{alg:simple_diffusion}`{=latex} describes a simplified version that directly adapts diffusion forcing principles to parallelize generation across the sequence dimension. This approach improves throughput in tokens per second while keeping the total FLOP requirements roughly comparable to autoregressive sampling. An example of the sampler's behavior is illustrated in Figure `\ref{fig:illustration}`{=latex}.

We emphasize several important aspects. First, the number of inner recurrences $r'$ may be chosen to exceed one. These additional iterations are relatively inexpensive, since the broader logic of the sampler is not yet invoked. More importantly, they serve to stabilize the recurrence. Because the conditioning on the input embedding $\mathbf{e}$ may vary across successive steps of the sampler, the model risks becoming trapped in oscillatory behavior unless sufficient steps are allowed to adapt the current state to the evolving conditioning. This mechanism closely parallels practices in the diffusion literature, such as the use of supplementary diffusion steps in @bansal_universal_2023-1 to incorporate complex guidance signals into image diffusion models.

Second, we employ this sampler only during the generation phase, since the prefill phase is already parallelizable in the sequence dimension: the recurrence can be computed on all token positions of the prompt simultaneously.

Further, in terms of efficiency, we do not want the states of all tokens to keep changing indefinitely, as doing so would slow down generation again and increase memory usage dramatically. As such, similar to block-diffusion samplers [@arriola_block_2025], we look for rules that decide when each position is \`\`finished". In the simplified version of the sampler, we freeze a token once its position has received a predetermined number of recurrence steps, which naturally happens $\lceil r/r' \rceil$ positions behind the current maximal extent of the sequence. Frozen tokens are removed from the state vector and their KV states are added to the cache, so that, as in block diffusion models [@arriola_block_2025], only a small subset of tokens is being modified at any point in time, and the full generation runs like a wave over the generating sequence. Finally, note that with this simplified exit rule, $r'=r$ exactly recovers the original autoregressive sampler.
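The wave-like schedule of this exit rule can be traced without any model. The following sketch assumes that one new position enters per sampler step, that every active position receives $r'$ inner recurrences per step, and that positions freeze from the left once they have accumulated at least $r$ recurrences. It reports the number of sampler steps, the largest number of simultaneously active positions, and the recurrences each token ends up with.

```python
def simulate_wave(num_tokens: int, r: int, r_inner: int):
    """Trace the simplified sampler's wavefront; returns (steps, max_active, recurrences)."""
    recurrences = [0] * num_tokens
    frozen, entered, step, max_active = 0, 0, 0, 0
    while frozen < num_tokens:
        step += 1
        if entered < num_tokens:
            entered += 1                      # one new position joins the wavefront per step
        active = range(frozen, entered)
        max_active = max(max_active, len(active))
        for i in active:
            recurrences[i] += r_inner         # every active position gets r' inner recurrences
        while frozen < entered and recurrences[frozen] >= r:
            frozen += 1                       # freeze completed positions from the left
    return step, max_active, recurrences

print(simulate_wave(10, r=32, r_inner=4)[:2])   # (17, 8)
print(simulate_wave(10, r=32, r_inner=32)[:2])  # (10, 1): one token per step, as in AR
```

With $r=32$ and $r'=4$, at most $\lceil r/r' \rceil = 8$ positions are active at once, and generating 10 tokens takes $10 + 8 - 1 = 17$ sampler steps instead of 10 steps of 32 recurrences each.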

Stabilizing components based on Diffusion Principles
----------------------------------------------------

First, we experiment with adding momentum to the input conditioning $\mathbf{e}$, setting $$\mathbf{e} = \eta~\mathbf{e}_{\text{prev}} + (1-\eta) \mathcal{P}(\mathbf{y}_{\mathrm{current}}),$$ which we find can stabilize the recurrence on challenging sequences, providing a small but robust gain on average.
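A minimal sketch of this momentum update, treating the conditioning as a list of per-position vectors. Since the sequence grows by one position per step, `e_prev` is shorter than the fresh embeddings; how such new positions are handled is not specified above, so here we assume they simply take the fresh embedding.

```python
Vector = list[float]

def momentum_conditioning(e_prev: list[Vector], e_fresh: list[Vector],
                          eta: float = 0.1) -> list[Vector]:
    """e = eta * e_prev + (1 - eta) * P(y_current), applied position by position."""
    mixed = []
    for i, fresh in enumerate(e_fresh):
        if i < len(e_prev):
            mixed.append([eta * p + (1.0 - eta) * f for p, f in zip(e_prev[i], fresh)])
        else:
            mixed.append(list(fresh))   # assumption: a newly appended position has no history
    return mixed

e = momentum_conditioning([[1.0, 0.0]], [[0.0, 1.0], [2.0, 2.0]], eta=0.1)
```

A small $\eta$ (the default of 0.1 in our experiments) damps abrupt changes in the conditioning when draft tokens flip between steps.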

Second, and surprisingly, we find that even though these models are never trained with noise injected into intermediate states, artificially adding noise to the state in each step of the sampler, in analogy to sampling from continuous diffusion models, i.e. $$\mathbf{z'} =  (1-\beta_t)\mathbf{z} + \beta_t\,\mathbf{z}_{\text{noise}} \qquad \qquad \textnormal{where}  \quad \mathbf{z}_{\text{noise}}=\InitState(1,\alpha),$$ can stabilize the iterative process, leading to gains in both accuracy and throughput when $r'$ is small. In practice, we schedule $\beta_t$ linearly as a function of the number of steps $t$ taken at each position, so that later steps are naturally less noisy [@chen_diffusion_2024]; we find this to outperform both scaling $\beta_t$ by the square root of the number of recurrences at each position and keeping it constant. However, the optimal value of $\beta_t$ depends on $r'$.
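The noise step can be sketched as follows. We read the linear schedule as $\beta_t$ starting at a value `beta_0` on a position's first step and decaying linearly with the number of steps that position has taken; this exact form, and drawing $\mathbf{z}_{\text{noise}}$ from a Gaussian with scale $\alpha$ as a stand-in for $\InitState(1,\alpha)$, are our assumptions.

```python
import random

def linear_beta(step_at_position: int, beta_0: float, steps_per_position: int) -> float:
    # Assumed schedule: beta_0 at a position's first step, decaying linearly afterwards.
    return beta_0 * max(0.0, 1.0 - (step_at_position - 1) / steps_per_position)

def add_state_noise(z, steps, beta_0, steps_per_position, alpha, rng):
    """z' = (1 - beta_t) z + beta_t z_noise, with beta_t chosen per position from its own step count."""
    out = []
    for z_i, t_i in zip(z, steps):
        beta = linear_beta(t_i, beta_0, steps_per_position)
        noise = [rng.gauss(0.0, alpha) for _ in z_i]   # stand-in for InitState(1, alpha)
        out.append([(1.0 - beta) * v + beta * n for v, n in zip(z_i, noise)])
    return out

# The older position (8th step) receives far less noise than the newest one (1st step).
noisy = add_state_noise([[1.0, -1.0], [0.5, 0.5]], steps=[8, 1], beta_0=0.5,
                        steps_per_position=8, alpha=1.0, rng=random.Random(0))
```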

```{=latex}
\begin{algorithm}[t]\caption{Diffusion-forcing-style generation, simplified version (Full Version in \cref{alg:latent_diff_freeze})}
\label{alg:simple_diffusion}
\begin{algorithmic}[1]
\small
\Require current text context $\mathbf{x}$, max new tokens $N$, inner recurrence $r'$, total recurrences per token $r$, diffusion steps $T$, init scale $\alpha$
% \Ensure generated sequence $\mathbf{y}$
\State \(\mathbf{y}_{\mathrm{frozen}} \leftarrow \mathbf{x}\)
\State \(\mathbf{y}_{\mathrm{current}} \leftarrow \mathbf{x}\)
\State \(\mathbf{z} \leftarrow \InitState(1,\alpha)\)
%\State \(\mathbf{z} \leftarrow \InitState(\,|\mathbf{x}|,\,\text{scale}=\alpha)\) 
\For{step $t=1,\dots,T$}
    \State \(\mathbf{e}\leftarrow \mathcal{P}(\mathbf{y}_{\mathrm{current}})\) 
    \State \(\mathbf{z}_{\text{noise}} \leftarrow \InitState(1,\alpha)\)
    \State \(\mathbf{z}\leftarrow (1-\beta_t)\mathbf{z} + \beta_t\,\mathbf{z}_{\text{noise}}\)
    \For{$j=1,\dots,r'$} 
        \State \(\mathbf{z}\leftarrow \mathcal{R}(\mathbf{z},\mathbf{e})\) \Comment{Inner recurrence}
    \EndFor
    \State \(\mathbf{p}\leftarrow \mathcal{C}(\mathbf{z})\) \Comment{project latent states to logits}
    \State \(\hat{\mathbf{y}}\leftarrow \Sample(\mathbf{p})\) 
    \State $\mathbf{y}_{\mathrm{current}} \leftarrow [\mathbf{y}_{\mathrm{frozen}}, \hat{\mathbf{y}}]$
    \State \(\mathbf{y}_{\mathrm{frozen}} \leftarrow \) Assign $\mathbf{y}_{\mathrm{current}}$ up to the last $\lceil{\frac{r}{r'}}\rceil$ entries \Comment{Freeze completed tokens}
    \If{$|\mathbf{y}_{\mathrm{frozen}}|-|\mathbf{x}|\ge N$} \textbf{break} \EndIf
    \State \(\mathbf{z}\leftarrow [\mathbf{z},\,\InitState(1,\alpha)]\) \Comment{Append a new latent state for the next position}
\EndFor
\State \Return \(\mathbf{y}_{\mathrm{frozen}}\) % actually y_frozen, or y
\end{algorithmic}
\end{algorithm}
```
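The pseudocode above can be turned into a small, runnable sketch. Everything model-specific is a hypothetical toy: `prelude`, `recurrent`, and `coda` stand in for $\mathcal{P}$, $\mathcal{R}$, and $\mathcal{C}$ (with greedy sampling), and a causal mean over embeddings plays the role of attention to the KV cache. Noise uses a constant `beta` rather than a schedule, momentum is omitted, and a position is frozen once it has received at least $r$ recurrences, which corresponds to the $\lceil r/r' \rceil$-step lag in the pseudocode.

```python
import math
import random

V, DIM = 7, 3  # hypothetical toy vocabulary size and latent width

def prelude(token: int) -> list[float]:
    # Toy stand-in for P: a fixed embedding per token.
    return [math.cos(token * (k + 1)) for k in range(DIM)]

def context(embeds: list[list[float]], i: int) -> list[float]:
    # Causal stand-in for attention: mean embedding over positions 0..i.
    return [sum(e[k] for e in embeds[: i + 1]) / (i + 1) for k in range(DIM)]

def recurrent(z: list[float], e: list[float], ctx: list[float]) -> list[float]:
    # Toy stand-in for R with input injection; contracts in z with factor 0.5.
    return [0.5 * zk + ek + 0.25 * ck for zk, ek, ck in zip(z, e, ctx)]

def coda(z: list[float]) -> int:
    # Toy stand-in for C followed by greedy sampling: map a state to a token id.
    return int(abs(sum((k + 1) * v for k, v in enumerate(z))) * 10) % V

def init_state(pos: int, alpha: float, seed: int) -> list[float]:
    # Random initial state, seeded per position so both samplers start from the same s_0.
    rng = random.Random(seed * 100_003 + pos)
    return [rng.gauss(0.0, alpha) for _ in range(DIM)]

def autoregressive(x, n_new, r, alpha=1.0, seed=0):
    y = list(x)
    for _ in range(n_new):
        i = len(y) - 1                      # the latent at position i predicts token i + 1
        embeds = [prelude(t) for t in y]
        z = init_state(i, alpha, seed)
        for _ in range(r):
            z = recurrent(z, embeds[i], context(embeds, i))
        y.append(coda(z))
    return y

def diffusion_forcing(x, n_new, r, r_inner, beta=0.0, alpha=1.0, seed=0):
    y = list(x)                             # y_current: frozen tokens followed by draft tokens
    target = len(x) + n_new
    n_frozen = len(x)                       # y[:n_frozen] is final
    z = {len(x) - 1: init_state(len(x) - 1, alpha, seed)}  # active latents, keyed by position
    spent = {len(x) - 1: 0}                 # recurrences accumulated per active position
    noise_rng = random.Random(seed + 1)
    while n_frozen < target:
        embeds = [prelude(t) for t in y]    # e <- P(y_current), fixed for this step
        for i in sorted(z):
            if beta > 0:                    # z <- (1 - beta) z + beta z_noise
                z[i] = [(1 - beta) * v + beta * noise_rng.gauss(0.0, alpha) for v in z[i]]
            for _ in range(r_inner):        # inner recurrence
                z[i] = recurrent(z[i], embeds[i], context(embeds, i))
            spent[i] += r_inner
        for i in sorted(z):                 # decode all active positions in parallel
            tok = coda(z[i])
            if i + 1 < len(y):
                y[i + 1] = tok
            else:
                y.append(tok)
        for i in sorted(z):                 # freeze finished positions, left to right
            if spent[i] < r:
                break
            del z[i], spent[i]
            n_frozen = i + 2
        last = len(y) - 1
        if last not in z and last <= target - 2:   # new latent for the next position
            z[last] = init_state(last, alpha, seed)
            spent[last] = 0
    return y

x = [1, 2, 3]
assert diffusion_forcing(x, 12, r=16, r_inner=16) == autoregressive(x, 12, r=16)
```

The final assertion is the sanity check from the text: with $r'=r$, every position is frozen after a single step, so the sampler reproduces autoregressive decoding token for token. With $r'<r$, several draft positions are refined in parallel and may be rewritten before they freeze.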
Adaptive Exits
--------------

`\looseness -1`{=latex} However, the fixed exit scheme of the simplified sampler can run into issues. The recurrent-depth model is causal, and how quickly states converge depends on the complexity of the query. This leads to two failure cases: either compute is wasted because the states at certain positions have already converged in fewer than $r$ steps, or, more problematically, the states at some positions have not converged in time because the conditioning from prior tokens changed late. Freezing such unfinished states worsens generation, in the worst case leading to a spiral in which each incorrectly frozen token slows down convergence further, producing a response that becomes more incorrect with each token.

However, we can remedy both cases through adaptive compute. We pick the simplest adaptive exit criterion, the normalized distance $\delta_i$ in latent space, compute it for each position, and freeze the longest prefix of positions whose distances all fall below a threshold $\varepsilon$: $$\begin{aligned}
\delta_i = \frac{\Vert\mathbf{z}_i - \mathbf{z}_{\text{prev},i}\Vert_2}{\Vert\mathbf{z}_i\Vert_2}, \label{eq:conv} \qquad k^* = \max\{k : \delta_j < \varepsilon \text{ for all } j \leq k\}\end{aligned}$$ We combine this with a limit on the maximum length of the wavefront of the algorithm to guarantee both that 1) the number of states currently being modified, and hence the maximum memory footprint, is bounded, and 2) only positions with converged states are frozen. The full algorithm is given in `\cref{alg:latent_diff_freeze}`{=latex} in the appendix. With these rules in place, setting the maximum wavefront size to 1 token exactly recovers the token-per-token adaptive compute sampler of [@geiping_scaling_2025].
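The exit rule can be sketched directly from the equation above. Positions are ordered left to right within the wavefront, and the returned count equals $k^*$ for 1-indexed positions (0 if the first position has not converged). The default $\varepsilon=0.03$ matches the value used in our experiments; the wavefront-size limiter is not modeled here.

```python
import math

def relative_change(z: list[float], z_prev: list[float]) -> float:
    """delta_i = ||z_i - z_prev,i||_2 / ||z_i||_2."""
    diff = math.sqrt(sum((a - b) ** 2 for a, b in zip(z, z_prev)))
    norm = math.sqrt(sum(a * a for a in z))
    return diff / norm if norm > 0 else math.inf

def freeze_prefix(states, prev_states, eps: float = 0.03) -> int:
    """Return k*: how many leading wavefront positions have all converged."""
    k = 0
    for z, z_prev in zip(states, prev_states):
        if relative_change(z, z_prev) >= eps:
            break                      # later positions cannot freeze past an unconverged one
        k += 1
    return k

states = [[1.0, 0.0], [1.0, 1.0], [2.0, 0.0]]
prev = [[1.0, 0.01], [0.0, 1.0], [2.0, 0.0]]
print(freeze_prefix(states, prev))   # 1: the third position converged, but the second did not
```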

We show the practical outcome of this sampler for a challenging input sequence from GSM8k in a series of heatmaps in the appendix, see `\cref{fig:token_stability}`{=latex}. The heatmap shows the development of the sequence as a function of generation steps and tokens. We see that the wave first advances quickly, then halts for a few steps, before resuming the advance.

![**Examples of adaptive sampler behavior**. Each color represents a token id in the vocabulary of the model, showing the development of the generated sequence (running left to right) as a function of sampler steps (running top to bottom) for *different hyperparameter choices*. The leftmost example uses $r'=4$, and tokens are frozen quickly, whereas the middle and right examples use $r'<4$ and require more adaptive computation; in both cases the sampler stalls after hitting the maximal length of the wavefront (set to 32 here for visualization), before resolving the sequence and advancing again.](plots/gsm8k_claire_example_3.png "fig:"){#fig:qualitative_examples width="0.32\\linewidth"} ![**Examples of adaptive sampler behavior**. Each color represents a token id in the vocabulary of the model, showing the development of the generated sequence (running left to right) as a function of sampler steps (running top to bottom) for *different hyperparameter choices*. The leftmost example uses $r'=4$, and tokens are frozen quickly, whereas the middle and right examples use $r'<4$ and require more adaptive computation; in both cases the sampler stalls after hitting the maximal length of the wavefront (set to 32 here for visualization), before resolving the sequence and advancing again.](plots/gsm8k_claire_example_4.png "fig:"){#fig:qualitative_examples width="0.32\\linewidth"} ![**Examples of adaptive sampler behavior**. Each color represents a token id in the vocabulary of the model, showing the development of the generated sequence (running left to right) as a function of sampler steps (running top to bottom) for *different hyperparameter choices*. 
The leftmost example uses $r'=4$, and tokens are frozen quickly, whereas the middle and right examples use $r'<4$ and require more adaptive computation; in both cases the sampler stalls after hitting the maximal length of the wavefront (set to 32 here for visualization), before resolving the sequence and advancing again.](plots/gsm8k_claire_example.png "fig:"){#fig:qualitative_examples width="0.32\\linewidth"}

```{=latex}
\begin{remark}[Convergence of the Adaptive Diffusion Sampler]With this algorithm, we can, in principle, guarantee convergence to the same solution as when sampling autoregressively, if we assume that the recurrent block $R$ is a contraction. First, convergence of iterates, i.e. \cref{eq:conv}, then implies convergence to the fixed point of the operator. Second, because the model is causal, the state at the first token position does not depend on any other position and will converge at some step $t$. At this step, the conditioning of the subsequent token is frozen, so its state will also converge, proving convergence of the full sequence to the autoregressive solution by induction. However, in practice, large-scale recurrent-depth models are not easily proven to be contractive, even if they are approximately path-independent \citep{anil_path_2022}.
\end{remark}
```
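The induction argument can be checked on a toy causal system. In the sketch below, each position's update is an assumed affine map that contracts in its own state (factor 0.5) and depends only on the states of earlier positions. Updating all positions in parallel from the previous iterate, as the sampler does, reaches the same fixed point as converging positions one at a time, as autoregressive sampling does.

```python
def update(z_i: float, earlier: list[float]) -> float:
    # Toy causal operator: contraction in z_i, conditioned only on states of positions < i.
    return 0.5 * z_i + 1.0 + 0.3 * sum(earlier) / max(len(earlier), 1)

def sequential_solution(n: int, iters: int = 200) -> list[float]:
    z: list[float] = []
    for _ in range(n):                   # "autoregressive": converge one position, then move on
        z_i = 0.0
        for _ in range(iters):
            z_i = update(z_i, z)
        z.append(z_i)
    return z

def parallel_solution(n: int, tol: float = 1e-12, max_steps: int = 10_000):
    z = [0.0] * n
    for step in range(1, max_steps + 1):
        new = [update(z[i], z[:i]) for i in range(n)]   # all positions updated at once
        delta = max(abs(a - b) for a, b in zip(new, z))
        z = new
        if delta < tol:
            return z, step
    return z, max_steps

seq = sequential_solution(8)
par, steps = parallel_solution(8)
```

For a real model, the contraction assumption is exactly what cannot be guaranteed, as the remark notes.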
Finally, we remark on practical back-of-the-envelope estimates of runtime cost.

```{=latex}
\begin{remark}[Computational Cost]In comparison to the baseline autoregressive sampling algorithm, where the recurrence is computed one token at a time, there are two additional sources of computational cost: the cost to encode and decode latent states using $\mathcal{P}$ and $\mathcal{C}$, and, if the adaptive version is used, the potential cost incurred when convergence is slower than in the baseline due to cascading effects of tokens changing late, as seen in \cref{fig:qualitative_examples}. The first cost depends on the size of the recurrent block $\mathcal{R}$ relative to prelude and coda. For the model we study in this work, this ratio is unfavorable, as prelude and coda together cost as many FLOPs as one pass through the recurrent block. We define the FLOP cost of one pass through $\mathcal{R}$ as $f$, ignoring attention, so that one iteration of the sampler costs roughly $(r'+1)f$ per token in the wavefront. Then, the total FLOP cost of running the baseline algorithm for $w$ tokens is $(r+1)fw$, compared to $(r+\frac{r}{r'})fw$ for the non-adaptive diffusion sampler, since each token remains in the wavefront for $\frac{r}{r'}$ iterations. However, as we will see, this FLOP inefficiency is counteracted in practice by the parallelization gains obtained from the sampler.
\end{remark}
```
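Plugging numbers into these estimates: for $r=32$ and $r'=4$, the non-adaptive sampler spends about 21% more FLOPs than the baseline ($40f$ vs. $33f$ per token), and $r'=r$ recovers the baseline cost. The sketch below simply encodes the two expressions from the remark, so it inherits their simplifications (attention ignored, prelude plus coda costing one pass of $\mathcal{R}$).

```python
def baseline_flops(w: int, r: int, f: float = 1.0) -> float:
    # r passes through R per token, plus prelude + coda (together about one pass of R).
    return (r + 1) * f * w

def diffusion_flops(w: int, r: int, r_inner: int, f: float = 1.0) -> float:
    # Each token stays in the wavefront for r / r' iterations, costing (r' + 1) f each.
    return (r + r / r_inner) * f * w

ratio = diffusion_flops(1, 32, 4) / baseline_flops(1, 32)   # 40 / 33, about 1.21
```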
Theoretical Analysis
====================

This section develops a theoretical framework to justify the optimality of our design in balancing efficiency and expressiveness with two research questions (RQs): **(i)** Why should models prioritize recurrence, i.e. *depth scaling*, during prefilling? and **(ii)** Why should models prioritize parallelizing decoding from a larger wavefront of tokens using the sampler described in the previous section, i.e. *width scaling* during decoding?

Problem Formulation
-------------------

Before answering these RQs, we formalize the notions of depth and width within our framework, which limits our analysis to Transformer-based autoregressive LLMs. In particular, we focus exclusively on the comparison between depth and width, without considering length (i.e., CoT) scaling.

```{=latex}
\begin{definition}[Depth and Width in Recurrent-Depth Models, informal]For recurrent-depth models, we define \emph{depth} $d_t$ and \emph{width} $w_t$ at each time step $t \in \mathbb{N}$, with initial conditions $d_0 = 0$ and $w_0 = L_0$ (where $L_0$ denotes the input sequence length). The corresponding update rules are given as follows:
\begin{enumerate}[itemsep=1pt,topsep=0pt,leftmargin=*]
\item \textbf{Depth Update:} At each step $t$, $d_{t+1} = d_t + 1$ with $d_0 = 0$, therefore $d_t = t$ for all $t \in \mathbb{N}$.
\item \textbf{Width Update:} At each step $t$, width changes only through token exits and token entries:
\[
\delta^{(t)} =
\begin{cases}
-1, & \text{if a hidden state is decoded and exits the model (exit event)},\\
+1, & \text{if the latest token is encoded and enters the model (entry event)}.
\end{cases}
\]
\end{enumerate}
\end{definition}
```
LLMs should prioritize depth scaling during prefilling.
-------------------------------------------------------

To establish this, we first define a width scaling architecture without increasing model parameters following [@wu2025efficient]. Concretely, we repeat each token along the sequence dimension. Note that during prefilling, increasing the number of such repeated tokens is equivalent to width scaling under our definition, since this expands the input sequence length. Here, we introduce two variants:

-   Width Scaling without KV Sharing (Width-NoShare): For the $j$-th copy of token $i$, attention is allowed to all copies of tokens $0,\dots,i-1$, as well as the first $j-1$ copies of token $i$.

-   Width Scaling with KV Sharing (Width-KVShare): For the $j$-th copy of token $i$, attention is limited to (i) the last copy of tokens $0,\dots,i-1$, and (ii) the first $j-1$ copies of token $i$.
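The two masks can be written down directly. The sketch below indexes copies from 1 and, as an assumption not stated in the definitions, lets each copy also attend to itself, as in standard causal attention (otherwise the first copy of the first token would attend to nothing). Counting attended query-key pairs gives a quick check of the attention part of the complexity ordering in the theorem below, with depth scaling modeled as $s$ passes of ordinary causal attention over the $n$ original tokens.

```python
def can_attend_noshare(q, k) -> bool:
    """q, k = (token index i, copy index j), copies 1-indexed. Width-NoShare rule."""
    (qi, qj), (ki, kj) = q, k
    return ki < qi or (ki == qi and kj <= qj)

def can_attend_kvshare(q, k, s: int) -> bool:
    """Width-KVShare: only the last copy (j = s) of earlier tokens is visible."""
    (qi, qj), (ki, kj) = q, k
    return (ki < qi and kj == s) or (ki == qi and kj <= qj)

def attended_pairs(n: int, s: int, rule) -> int:
    slots = [(i, j) for i in range(n) for j in range(1, s + 1)]
    return sum(rule(q, k) for q in slots for k in slots)

def depth_pairs(n: int, s: int) -> int:
    # Depth scaling: s recurrent passes of ordinary causal attention over n tokens.
    return s * n * (n + 1) // 2

n, s = 3, 2
print(depth_pairs(n, s),
      attended_pairs(n, s, lambda q, k: can_attend_kvshare(q, k, s)),
      attended_pairs(n, s, can_attend_noshare))   # 12 15 21
```

Width-NoShare turns out to be plain causal attention over the $s\cdot n$ repeated tokens. The counts cover attention only; all three variants perform the same $s\cdot n$ token passes through the linear layers.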

Based on the above definitions, we state the importance of depth scaling during the prefilling stage.

```{=latex}
\begin{theorem}[Depth vs. Width Scaling in Prefilling, informal]Consider the width-scaling architectures above and our recurrent-depth model with the same scaling factor $s$. Then the following hold:
\begin{enumerate}[itemsep=2pt,topsep=0pt,leftmargin=*]
\item \textbf{Expressiveness.} Under equal scaling factors,  depth scaling is more expressive than width scaling.
\item \textbf{Complexity.} For asymptotic prefill cost (including both attention and linear layers), we have
\[
E_{\textnormal{Depth}} \;\leq\; E_{\textnormal{Width\!-\!KVShare}} \;<\; E_{\textnormal{Width\!-\!NoShare}}.
\]
\item \textbf{Parallelism.} There exists a threshold $L_\star$ such that for $L<L_\star$, width scaling provides $s^{2}$ times the parallelism of depth scaling, while for $L\geq L_\star$ both saturate with similar parallelism. 
\end{enumerate}
\end{theorem}
```
```{=latex}
\begin{remark}
Let $L$ be a random variable for prompt length with distribution $\mathcal{D}$. Then the probability that depth scaling is more efficient than width scaling equals $\Pr_{L \sim \mathcal{D}}[L \geq L_\star]$. Since $L_\star$ on modern GPUs typically lies between a few hundred and a few thousand tokens while empirical input length distributions place substantial mass above this range, the probability is indeed close to $1$ in practice.
\end{remark}
```
LLMs should prioritize width scaling during decoding.
-----------------------------------------------------

Next, we prove that recurrent-depth models should use diffusion forcing samplers during decoding.

```{=latex}
\begin{theorem}[Depth vs. Width Scaling in Decoding, informal]For recurrent-depth models with $r > 1$ inner recurrences, if diffusion forcing sampling and KV-cache sharing are employed with wavefront size $W \leq L_\star$, then diffusion forcing decoding achieves equal depth and strictly greater width compared to standard autoregressive decoding under the same runtime constraints. Mathematically, this relationship can be expressed as: $$d_{\text{DF}}(T) = d_{\text{AR}}(T) \quad \text{and} \quad w_{\text{DF}}(T) > w_{\text{AR}}(T),$$ where $T$ is the runtime budget, and $\text{DF}$ and $\text{AR}$ denote diffusion forcing and autoregressive decoding.
\end{theorem}
```
```{=latex}
\begin{remark}
Since model parameters and KV states are shared, the I/O cost of processing multiple tokens is asymptotically equivalent to that of processing a single token, enabling more tokens to be generated under identical runtime constraints. At each decoding step, an expanded wavefront enables greater width scaling, providing superior expressiveness compared to autoregressive decoding. Empirically, the condition $W \leq L_\star$ typically holds: the wavefront of the simplified sampler spans at most $\lceil r/r' \rceil \leq r$ tokens, and the maximum recurrence depth rarely exceeds $r \approx 100$, well below typical values of $L_\star$.
\end{remark}
```
Experimental Evaluation {#sec:experiments}
=======================

To assess whether our method actually accelerates generation, we compare our sampler against an equally optimized implementation of standard autoregressive sampling, both evaluated with a batch size of 1. Extensions to larger batch sizes are conceivable but fall outside the scope of this study; see the additional discussion in `\cref{app:batch_engine}`{=latex}.

```{=latex}
\small 
```
```{=latex}
\resizebox{\linewidth}{!}{\setlength{\tabcolsep}{1mm}{
\begin{tabular}{l*{8}{c}}
\toprule
\small 
Sampler & \multicolumn{2}{c}{GSM8K} & \multicolumn{2}{c}{MATH500} & \multicolumn{2}{c}{HumanEval} & \multicolumn{2}{c}{MBPP}  \\
\cmidrule(lr){2-3} \cmidrule(lr){4-5} \cmidrule(lr){6-7} \cmidrule(lr){8-9}
& Acc & t/s & Acc & t/s & Acc & t/s & Acc & t/s \\
\midrule
Static AR ({\tiny$r=32$})  & 41.77\% & 36.1 & 17.60\% & 6.4 & 22.56\% & 13.5 & 31.60\% & 15.3 \\ % sampler:Base_32
\midrule
Static AR ({\tiny$r=4$})  & 1.59\% & 312.9 & 3.20\% & 18.6 & 0.61\% & 244.1 & 1.40\% & 49.6 \\ % cache_share:False_sampler:Base_4
Static AR ({\tiny$r=8$})  & 31.61\% & 137.5 & 14.80\% & 23.1 & 21.34\% & 61.7 & 27.40\% & 57.2 \\ % sampler:Base_8
Static AR ({\tiny$r=64$})  & 42.15\% & 18.2 & 18.60\% & 3.4 & 22.56\% & 7.3 & 30.20\% & 7.6 \\ % sampler:Base_64
\midrule
Adaptive Compute AR & 42.23\% & 66.9 & 18.20\% & 12.2 & 21.95\% & 26.1 & 30.20\% & 29.5 \\ % criterion:latent-diff_exit_threshold:0.03_sampler:Adaptive Comp.
Speculative Decoding AR  & 42.76\% & 69.5 & 17.80\% & 13.4 & 20.12\% & 27.5 & 30.60\% & 31.6 \\ % draft_steps:4_lookahead_for_draft:4_sampler:Spec. Decoding_1.0_verification_threshold:1.0
\midrule 
Diff. Sampler ({\tiny$r'=2,\beta_t=0.5$})  & 40.71\% & 182.2 & 17.60\% & 35.9  & 20.12\% & 67.4 & 27.80\% & 92.3 \\ % ema_embeds:0.1_exit_t:0.03_freeze_adaptive:latent-diff_headway:1_inner_recurrence:2_sampler:Diff. Sampler_2_0.5_state_noise_mixing:0.5

Diff. Sampler ({\tiny$r'=4,\beta_t=0$}) & 42.08\% & 157.3 & 18.00\% & 30.3 & 20.12\% & 64.9 & 31.00\% & 70.2 \\ % ema_embeds:0.1_exit_t:0.03_freeze_adaptive:latent-diff_headway:1_inner_recurrence:4_sampler:Diff. Sampler_4_0.0_state_noise_mixing:0.0
\midrule
Relative Diff to AR ({\tiny$r=32$})& +0.31 & \textbf{4.36x} & +0.40 & \textbf{4.73x} & -2.44 & \textbf{4.81x} & -0.60 & \textbf{4.59x} \\
\bottomrule
\end{tabular}}}
```
![Trade-off between accuracy and speed on GSM8k under different hyperparameter choices. **Left:** Effect of increasing inner recurrence $r'$. Inner recurrence stabilizes the sampling, increasing accuracy at the cost of throughput. **Right:** Effect of varying the exit threshold $\varepsilon$. Modulating the exit threshold most directly trades off throughput and accuracy.](ablations/ablation_dual_axis_inner_recurrenceno_noise_dual.png "fig:"){#fig:hyperparam_sweeps1 width="0.49\\linewidth"} ![Trade-off between accuracy and speed on GSM8k under different hyperparameter choices. **Left:** Effect of increasing inner recurrence $r'$. Inner recurrence stabilizes the sampling, increasing accuracy at the cost of throughput. **Right:** Effect of varying the exit threshold $\varepsilon$. Modulating the exit threshold most directly trades off throughput and accuracy.](ablations/ablation_dual_axis_exit_tno_noise_dual.png "fig:"){#fig:hyperparam_sweeps1 width="0.49\\linewidth"}

We evaluate the four generative benchmarks (GSM8K, MATH500, HumanEval, and MBPP) that were also evaluated in [@geiping_scaling_2025], rerunning them with our sampler and comparing against a number of baselines. Aside from the **static, autoregressive baseline** (static AR) at different numbers of recurrence steps, we also compare against the **adaptive compute** sampler of the original work, which still samples token-by-token but exits the recurrence at each token once the difference in latent space is small enough. We tune this sampler and find that its tuned hyperparameter, the threshold $\varepsilon$, is similar to that of the diffusion sampler.

Finally, we also compare against a heavily tuned **self-speculative decoding baseline**. It was observed in @geiping_scaling_2025 that recurrent-depth models can be natively used as their own draft models, using fewer steps to draft. We find that drafting 4 tokens into the future, each with 4 draft steps is optimal for the *Huginn-0125* checkpoint on GSM8k.
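For readers unfamiliar with self-speculative decoding, the following is a minimal sketch of the idea with greedy acceptance. `predict` is a hypothetical greedy next-token function taking a prefix and a number of recurrence steps; acceptance by exact greedy match is our assumption, since the verification rule of the tuned baseline is not spelled out here. For clarity, drafts are verified one by one; an efficient implementation would verify all drafted positions in a single parallel forward pass with the full recurrence.

```python
from typing import Callable, List

# Hypothetical interface: (prefix, number of recurrence steps) -> greedy next token.
Predictor = Callable[[List[int], int], int]

def greedy_ar(predict: Predictor, x: List[int], n_new: int, r: int) -> List[int]:
    y = list(x)
    for _ in range(n_new):
        y.append(predict(y, r))
    return y

def self_speculative(predict: Predictor, x: List[int], n_new: int, r: int,
                     draft_steps: int = 4, lookahead: int = 4) -> List[int]:
    y = list(x)
    target = len(x) + n_new
    while len(y) < target:
        draft: List[int] = []
        for _ in range(min(lookahead, target - len(y))):   # cheap drafts with few recurrences
            draft.append(predict(y + draft, draft_steps))
        for tok in draft:                                   # verify with the full r recurrences
            full = predict(y, r)
            y.append(full)                                  # the verified token is always kept
            if full != tok:                                 # first mismatch: discard remaining drafts
                break
    return y

def toy_predict(prefix: List[int], steps: int) -> int:
    # Toy model whose cheap drafts are wrong whenever the prefix length is divisible by 3.
    base = (sum(prefix) * 31 + len(prefix)) % 11
    return base if steps >= 8 or len(prefix) % 3 else (base + 1) % 11

out = self_speculative(toy_predict, [1, 2], n_new=20, r=32)
```

With greedy acceptance, the output is identical to greedy autoregressive decoding with $r$ steps; the speedup comes only from how many drafts are accepted per verification pass.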

We implement all samplers in comparable Hugging Face `transformers` implementations with dynamic KV caching, and we measure mean accuracy and median tokens per second over the queries of each benchmark. All timings are obtained from CUDA event measurements on sandboxed A100-40GB GPUs. Unless otherwise mentioned, we default to conservative settings for the sampler: an exit threshold of $\varepsilon=0.03$, $\beta_t=0$, $\eta=0.1$, $r'=4$, and a maximum wavefront size of $128$.

Benchmark Results.
------------------

We summarize our findings in `\cref{tab:sampler_comparison}`{=latex}. On all benchmarks, the parallelized sampler yields significant speedups of roughly 4 to 5x, with only minor trade-offs in generation quality, typically within about 1 percentage point and at most 2.44 points (HumanEval), owing to the trade-off set by our default hyperparameters. In `\cref{tab:sampler_comparison_2}`{=latex} we repeat all benchmarks for two additional model checkpoints: the SWA model also released in @geiping_scaling_2025, and a math variant that we finetuned on the MetaMath dataset [@yu_metamath_2023]. Even though these model variants differ noticeably in their benchmark scores, they show similar gains and trade-offs when using the diffusion sampler.

```{=latex}
\small 
```
::: {#tab:sampler_comparison_2}
  ------------------------------------------------------------------------------------------------------------- -------- -------------- ----------- ------- -------- ------ -------- ------
  Sampler                                                                                                        GSM8K    Minerva Math   HumanEval   MBPP                            
  `\cmidrule`{=latex}(lr)2-3 `\cmidrule`{=latex}(lr)4-5 `\cmidrule`{=latex}(lr)6-7 `\cmidrule`{=latex}(lr)8-9     Acc          t/s          Acc       t/s     Acc      t/s    Acc      t/s
  Huginn-0125                                                                                                                                                                        
  Static AR (`\tiny`{=latex}$r=32$)                                                                              41.77%       36.1        12.98%     21.0    22.56%   13.5   31.60%   15.3
  Diff. Sampler (`\tiny`{=latex}$r'=4,\beta_t=0$)                                                                42.08%      157.3        13.06%     96.0    20.12%   64.9   31.00%   70.2
  SWA Model Variant                                                                                                                                                                  
  Static AR (`\tiny`{=latex}$r=32$)                                                                              47.99%       36.2        14.86%     22.1    23.78%   14.9   31.20%   11.8
  Diff. Sampler (`\tiny`{=latex}$r'=4,\beta_t=0$)                                                                47.08%      143.1        14.52%     101.4   23.78%   71.2   29.20%   59.7
  Math-Finetuned Model                                                                                                                                                               
  Static AR (`\tiny`{=latex}$r=32$)                                                                              58.91%       29.8        22.20%      7.9    17.07%   11.5   28.80%   11.2
  Diff. Sampler (`\tiny`{=latex}$r'=4,\beta_t=0$)                                                                58.45%      144.1        21.40%     39.8    15.24%   47.9   27.60%   57.1
  ------------------------------------------------------------------------------------------------------------- -------- -------------- ----------- ------- -------- ------ -------- ------

  : Hyperparameters remain stable across different model variants. For example, both the weight-averaged checkpoint from the original work and the model finetuned on MetaMath for this study exhibit consistent speed gains of roughly 4--5× and accuracy deviations of at most 2 percentage points, even when baseline values change. Throughput is reported in tokens per second (t/s).
:::

![**Left:** Scaling the amount of momentum $\eta$ in the conditioning, showing that small but non-zero $\eta$ values are optimal. **Right:** Scaling the amount of noise added during inference for $r'=4$, scheduled linearly in the number of recurrence steps, also measured on GSM8k. At $r'=4$, adding noise is not optimal. We plot the full sweep over $(r', \beta_t)$ in `\cref{fig:pareto}`{=latex}.](ablations/ablation_dual_axis_ema_embeds_dual.png "fig:"){#fig:hyperparam_sweeps3 width="0.49\\linewidth"} ![**Left:** Scaling the amount of momentum $\eta$ in the conditioning, showing that small but non-zero $\eta$ values are optimal. **Right:** Scaling the amount of noise added during inference for $r'=4$, scheduled linearly in the number of recurrence steps, also measured on GSM8k. At $r'=4$, adding noise is not optimal. We plot the full sweep over $(r', \beta_t)$ in `\cref{fig:pareto}`{=latex}.](ablations/ablation_dual_axis_state_noise_mixing_dual.png "fig:"){#fig:hyperparam_sweeps3 width="0.49\\linewidth"}

Variants and Hyperparameters
----------------------------

**Hyperparameter Choices.** We show the trade-off curves arising when varying the inner recurrence $r'$ and the exit threshold $\varepsilon$ in `\cref{fig:hyperparam_sweeps1}`{=latex} for two settings of noise $\beta_t$, finding that we can effectively trade additional generation speed against minor losses in accuracy. We further vary the embedding EMA $\eta$ and the noise schedule in `\cref{fig:hyperparam_sweeps3}`{=latex}, showing that the sampler is robust to a broad range of settings for both options, although the gains from tuning them are also limited.

In `\cref{fig:pareto}`{=latex}, we sweep a range of values for $r'$ and $\beta_t$, showing that, on average, more noise is helpful if the model takes fewer inner recurrence steps. In `\cref{fig:hyperparam_sweeps5}`{=latex} (left), we confirm that larger maximum wavefront sizes (i.e., the number of tokens that are modified at once in the adaptive sampler) allow for better parallelization. For the tested A100 GPU, the optimal maximal wavefront size is between 64 and 128, although this is likely accelerator-specific.

#### Moving Forward Multiple Steps.

In principle, nothing limits the sampler to advancing only one token at a time, so we can consider *headways* greater than 1. However, these additional positions have no prior position to decode from, so we can only fill them with random tokens or a particular padding token. And, given that the model is still causal, it takes several steps for sequential dependencies to be resolved, even if we sample a large headway in every step. We experiment with headways greater than one; while interestingly stable, this only marginally accelerates the sampler, at a cost to accuracy, see `\cref{fig:hyperparam_sweeps5}`{=latex}, right.

![The Pareto curve of accuracy and throughput on GSM8k, spanned by varying pairs of inner recurrence and noise hyperparameters $(r', \beta_t)$. Adding moderate amounts of noise, e.g. $\beta_t=0.2$, dominates runs without added noise. Note also the scale of the y-axis: even at the rightmost part of the frontier, accuracy losses are only $2\%$.](ablations/pareto_chart_final.png){#fig:pareto width="\\linewidth"}
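The frontier in the figure above consists of the configurations that no other configuration beats on both throughput and accuracy at once. The following is a minimal sketch of extracting such a frontier from sweep results; the function name is hypothetical, and the `(throughput, accuracy)` pairs in the usage example are made-up placeholders, not measurements from this work.

```python
from typing import List, Tuple

Point = Tuple[float, float]  # (throughput in tokens/s, accuracy)


def pareto_frontier(points: List[Point]) -> List[Point]:
    """Return the points not dominated in both throughput and accuracy.

    A point p is dominated if another point q is at least as fast and at least
    as accurate as p, and strictly better in one of the two.
    """
    # Visit points from fastest to slowest, ties broken by higher accuracy.
    ordered = sorted(points, key=lambda p: (-p[0], -p[1]))
    frontier: List[Point] = []
    best_acc = float("-inf")
    for thr, acc in ordered:
        # Every point visited so far is at least as fast, so this point
        # survives only if it is strictly more accurate than all of them.
        if acc > best_acc:
            frontier.append((thr, acc))
            best_acc = acc
    # Ascending throughput order, convenient for plotting the curve.
    return frontier[::-1]


# Hypothetical sweep results, for illustration only.
sweep = [(100, 0.80), (150, 0.79), (120, 0.75), (200, 0.70), (180, 0.72), (90, 0.81)]
print(pareto_frontier(sweep))
```

Here the configuration at `(120, 0.75)` is dropped because `(150, 0.79)` is both faster and more accurate; sorting once makes the extraction $O(n \log n)$.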

![Impact of Additional Hyperparameter Choices on GSM8k. **Left:** Size of the wavefront. Increasing the wavefront size up to around 64-128 appears optimal. We note that the optimal wavefront size is also likely to be accelerator-specific. **Right:** Amount of headway. Headways larger than 1, i.e. advancing the sampler by more than 1 token per step, do not yield practical speedups for the studied model.](ablations/appendix_ablation_dual_axis_max_wavefront_dual.png "fig:"){#fig:hyperparam_sweeps5 width="0.49\\linewidth"} ![Impact of Additional Hyperparameter Choices on GSM8k. **Left:** Size of the wavefront. Increasing the wavefront size up to around 64-128 appears optimal. We note that the optimal wavefront size is also likely to be accelerator-specific. **Right:** Amount of headway. Headways larger than 1, i.e. advancing the sampler by more than 1 token per step, do not yield practical speedups for the studied model.](ablations/appendix_ablation_dual_axis_headway_dual.png "fig:"){#fig:hyperparam_sweeps5 width="0.49\\linewidth"}

Conclusions: Are Recurrent-depth Transformers secretly continuous language diffusion models?
============================================================================================

We have shown that, surprisingly, diffusion forcing samplers can be directly applied to parallelize the inference of existing recurrent-depth language models. We justify this theoretically and implement it in practice, obtaining up to five times faster single-sequence inference, even on reasoning and coding benchmark questions. Interestingly, this relationship can also be read in the opposite direction: the recurrent-depth models of @geiping_scaling_2025 *are* effectively continuous latent language diffusion models, just trained with an unusual objective, namely truncated unrolling. This would suggest that unrolling objectives could be competitive for training future language diffusion models. However, while this comparison is possible, recurrent models like *Huginn-0125* are still causal, at least without further training, so a key advantage of diffusion modeling, non-causal generation, remains elusive for them.

### Acknowledgments {#acknowledgments .unnumbered}

JG acknowledges the support of the Hector foundation and the Max Planck Computing and Data Facility (MPCDF), especially the compute cluster Raven. We are especially thankful that the MPCDF team was able to address the overheating issues that coincided with the large-scale deployment of the evaluation of this sampling algorithm to the Raven compute cluster. GS acknowledges the support of the International Max Planck Research School for Intelligent Systems (IMPRS-IS).

### Reproducibility Statement {#reproducibility-statement .unnumbered}

We provide the complete sampling algorithm described in this work, including all options, at `\small `{=latex}<https://github.com/seal-rg/recurrent-pretraining>. We provide experimental details in `\cref{sec:experiments}`{=latex} and further ablations and variants in the appendix. Unless otherwise mentioned, all measured values are based on at least 5 repeated experiments. All timings are measured using CUDA events on GPUs of equal power and are comparable to timings within the same table or figure.

```{=latex}
\small
```
```{=latex}
\bibliographystyle{acl_natbib}
```
```{=latex}
\appendix
```
Appendix
========

Additional Algorithm Details
----------------------------

We provide the full algorithm, including adaptive exiting, in `\cref{alg:latent_diff_freeze}`{=latex}.

```{=latex}
\begin{algorithm}[h]\caption{Diffusion-style generation with latent-difference-based freezing}
\label{alg:latent_diff_freeze}
\begin{algorithmic}[1]
\Require prompt $\mathbf{x}$, max new tokens $N$, inner recurrence $r$, diffusion steps $T$, init scale $\alpha$, noise mixing weight $\beta_r$, noise scale $\sigma$, exit threshold $\varepsilon$
% \Ensure generated frozen sequence $\mathbf{y}_{\mathrm{frozen}}$
\State \(\mathbf{y}_{\mathrm{frozen}}\leftarrow \mathbf{x}\), \(\mathbf{y}_{\mathrm{current}}\leftarrow \mathbf{x}\)
\State \(\mathbf{z}\leftarrow \InitState(|\mathbf{x}|,\alpha)\)
\State \(\mathbf{z}_{\mathrm{prev}}\leftarrow \mathbf{z}\) 
\For{step $t=1,\dots,T$}
    \State \(\mathbf{e}\leftarrow \mathcal{P}(\mathbf{y}_{\mathrm{current}})\)
    \State \(\mathbf{z}_{\text{noise}}\sim\mathcal{N}(0,\sigma^2 I)\)
    \State \(\mathbf{z}\leftarrow (1-\beta_r)\mathbf{z}+\beta_r\mathbf{z}_{\text{noise}}\)
    \For{$j=1,\dots,r$}
        \State \(\mathbf{z}\leftarrow \mathcal{R}(\mathbf{z},\mathbf{e})\)
    \EndFor
    \State \(\mathbf{p}\leftarrow \mathcal{C}(\mathbf{z})\)
    \State \(\hat{\mathbf{y}}\leftarrow \Sample(\mathbf{p})\)
    \State $\mathbf{y}_{\mathrm{current}}\leftarrow [\mathbf{y}_{\mathrm{frozen}},\hat{\mathbf{y}}]$
    % \State \(\mathbf{z}_{\mathrm{match}} \leftarrow \mathbf{z}_{1:|\mathbf{z}_{\mathrm{prev}}|}\)
    \State \(\delta_i \leftarrow \|\mathbf{z}_i-\mathbf{z}_{\mathrm{prev},i}\|_2 / \|\mathbf{z}_i\|_2\)
    \Comment{Compute relative changes in latents at each position.}
    \If{exists position $i$ with $\delta_i < \varepsilon$}
        \State let $k^* \leftarrow$ index of the last such freezable position where $\delta_i < \varepsilon$ \Comment{freeze up to $k^*$}
        \State \(\mathbf{y}_{\mathrm{frozen}}\leftarrow \mathbf{y}_{\mathrm{current}}[1{:}k^*]\)
        \State keep only unfrozen tail of latents: \(\mathbf{z}\leftarrow \mathbf{z}[k^* - \ell{:}]\)
    \Else
        \State no tokens frozen this step
    \EndIf
    \If{$|\mathbf{y}_{\mathrm{frozen}}|-|\mathbf{x}|\ge N$} \textbf{break} \EndIf
    \State \(\mathbf{z}\leftarrow [\mathbf{z},\,\InitState(1,\alpha)]\) \Comment{Append a new latent state for the next position}
    \State \(\mathbf{z}_{\mathrm{prev}}\leftarrow \mathbf{z}\)

\EndFor
\State \Return \(\mathbf{y}_{\mathrm{frozen}}\)
\end{algorithmic}
\end{algorithm}
```
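Two per-step operations of the algorithm above go beyond a plain forward pass: the noise mixing $\mathbf{z}\leftarrow(1-\beta_r)\mathbf{z}+\beta_r\mathbf{z}_{\text{noise}}$ and the relative-change test that selects $k^*$. They can be sketched in plain Python as follows. This is an illustrative sketch, not the released implementation: latents are assumed to be lists of vectors (one per position), positions are 1-indexed as in the algorithm, and all function names are hypothetical.

```python
import math
import random
from typing import List

Vector = List[float]


def mix_noise(z: List[Vector], beta: float, sigma: float,
              rng: random.Random) -> List[Vector]:
    """z <- (1 - beta) * z + beta * z_noise, with z_noise ~ N(0, sigma^2 I)."""
    return [[(1.0 - beta) * v + beta * rng.gauss(0.0, sigma) for v in z_i]
            for z_i in z]


def relative_changes(z: List[Vector], z_prev: List[Vector]) -> List[float]:
    """delta_i = ||z_i - z_prev_i||_2 / ||z_i||_2 for every position i."""
    deltas = []
    for z_i, p_i in zip(z, z_prev):
        diff = math.sqrt(sum((a - b) ** 2 for a, b in zip(z_i, p_i)))
        norm = math.sqrt(sum(a * a for a in z_i))
        deltas.append(diff / max(norm, 1e-12))  # guard against a zero latent
    return deltas


def freeze_index(deltas: List[float], eps: float) -> int:
    """Return k*, the last 1-indexed position with delta_i < eps, or 0 if none.

    Following the algorithm, all tokens up to and including k* are frozen;
    a return value of 0 means no tokens are frozen this step.
    """
    k_star = 0
    for i, d in enumerate(deltas, start=1):
        if d < eps:
            k_star = i
    return k_star


z_prev = [[1.0, 0.0], [3.0, 4.0], [0.0, 2.0]]
z = [[1.0, 0.001], [3.0, 4.0], [0.0, 1.0]]
deltas = relative_changes(z, z_prev)
print(freeze_index(deltas, eps=0.01))
```

In this toy example, the latents at positions 1 and 2 have barely moved while position 3 changed by 100%, so tokens up to $k^*=2$ would be frozen.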
Additional Variants {#app:batch_engine}
-------------------

#### Larger Batch Sizes.

The sampler discussed in this work could, in principle, also be deployed in batched or continuously-batched inference settings. In that scenario, similar to a paged KV cache, the sampler would reserve a number of slots for hidden states, up to an occupancy multiplier of the maximum wavefront size, and would schedule recurrent updates in tandem with sequence updates. For larger models, an efficient implementation would actually simplify deployment, as recurrent states are fungible: states could, e.g., be evicted from one device and then bundled into the next forward call of the model on a different device, since the slots of the model's hidden states do not have to correspond to contiguous sequences in either the sequence or the recurrence dimension. However, due to the inherent complexity of such an inference engine, we refrained from pursuing this direction in this work and focus only on bringing the general idea of diffusion sampling to recurrent-depth models. We leave a batched inference engine as a limitation that may motivate future work.

Additional Information.
-----------------------

#### Finetuned Math Model.

To verify that our findings are not limited to the particular model checkpoint we evaluate and its capabilities, we finetune the original checkpoint for one epoch on the MetaMath dataset [@yu_metamath_2023], using a trapezoidal learning rate schedule with a peak learning rate of $5 \times 10^{-7}$. As suggested in the original work, we train the model with randomized unrolling, sampling $r$ from an exponential distribution with mean $32$. As a side note, while we do train the full model, most of the gains can also be achieved by finetuning only the adapter component of the model that maps inputs and states into the recurrent block.
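The two finetuning ingredients named above can be sketched as follows: a trapezoidal learning-rate schedule (linear warmup, constant plateau, linear decay) with peak $5\times10^{-7}$, and randomized unrolling with $r$ drawn from an exponential distribution with mean 32. This is a sketch under stated assumptions: the warmup and decay fractions, and rounding the continuous draw to an integer of at least 1, are illustrative choices that the text does not specify, and the function names are hypothetical.

```python
import random


def trapezoidal_lr(step: int, total_steps: int, peak_lr: float = 5e-7,
                   warmup_frac: float = 0.1, decay_frac: float = 0.1) -> float:
    """Linear warmup to peak_lr, constant plateau, then linear decay to zero.

    warmup_frac and decay_frac are assumed values, not taken from the paper.
    """
    warmup = max(1, int(warmup_frac * total_steps))
    decay = max(1, int(decay_frac * total_steps))
    if step < warmup:
        return peak_lr * step / warmup
    if step >= total_steps - decay:
        return peak_lr * max(0, total_steps - step) / decay
    return peak_lr


def sample_recurrence(rng: random.Random, mean_r: float = 32.0) -> int:
    """Draw the number of unrolled recurrences r ~ Exponential(mean=mean_r).

    Rounding to the nearest integer and clamping to at least 1 are assumptions.
    """
    return max(1, round(rng.expovariate(1.0 / mean_r)))


rng = random.Random(0)
draws = [sample_recurrence(rng) for _ in range(20000)]
print(trapezoidal_lr(500, 1000), sum(draws) / len(draws))
```

Because the exponential distribution is heavy on small values, most training steps use short unrolls, while occasional long unrolls still expose the model to deep recurrence.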

#### Dataset Details.

When evaluating GSM8k, we always refer to the CoT version of the dataset, which we provide to the model with the 8 few-shot examples associated with this variant, as in @touvron_llama_2023-1. We always score GSM8k using the flexible-extract metric, i.e. by matching the last number in the model response against the reference answer. For MATH500, we follow the format of @deepseek-ai_deepseek-r1_2025, while for Minerva Math, we follow the updated format established in the lm-eval harness. For both, we grade answers using *math-verify*. MBPP and HumanEval are graded with their standard procedures. During inference, we sample with a temperature of 0.2 and a top-p factor of $0.95$, as in @geiping_scaling_2025.
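The flexible-extract rule and the sampling settings described above can be sketched as follows. The regular expression is a simplified stand-in chosen for illustration, not the exact pattern used by the lm-eval harness, and `sample_top_p` is a generic nucleus-sampling sketch over a list of logits, not the sampling code used in this work. All names are hypothetical.

```python
import math
import random
import re
from typing import List, Optional

# Simplified number pattern (assumption): optional sign, digits with
# thousands separators, and an optional decimal part.
_NUMBER = re.compile(r"-?\d[\d,]*(?:\.\d+)?")


def last_number(response: str) -> Optional[str]:
    """Return the last number in the response, with commas removed."""
    matches = _NUMBER.findall(response)
    return matches[-1].replace(",", "") if matches else None


def is_correct(response: str, reference: str) -> bool:
    """Flexible extract: compare the last number against the reference."""
    pred = last_number(response)
    return pred is not None and float(pred) == float(reference.replace(",", ""))


def sample_top_p(logits: List[float], temperature: float = 0.2,
                 top_p: float = 0.95, rng: Optional[random.Random] = None) -> int:
    """Temperature scaling followed by nucleus (top-p) sampling."""
    rng = rng or random.Random()
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]  # numerically stable softmax
    total = sum(weights)
    probs = sorted(((w / total, i) for i, w in enumerate(weights)), reverse=True)
    # Keep the smallest set of most likely tokens whose mass reaches top_p.
    kept, cum = [], 0.0
    for p, i in probs:
        kept.append((p, i))
        cum += p
        if cum >= top_p:
            break
    # Sample from the kept tokens, renormalized by their total mass.
    r = rng.random() * cum
    for p, i in kept:
        r -= p
        if r < 0:
            return i
    return kept[-1][1]
```

With a temperature as low as 0.2, the distribution becomes very peaked, so the nucleus often contains a single token and sampling behaves almost greedily.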

![Impact of Additional Hyperparameter Choices, also on GSM8k. **Left:** Initialization scale of new states, which has only a minor effect on the result. **Right:** Continuous compute, i.e. initializing new states with previously computed states (we initialize each new state with the latest state from the position one step to the left). This is less effective for our sampler, since the state at the position one step to the left is itself only the result of $r'$ recurrences.](ablations/appendix_ablation_dual_axis_init_scale_dual.png "fig:"){#fig:hyperparam_sweeps2 width="0.49\\linewidth"} ![Impact of Additional Hyperparameter Choices, also on GSM8k. **Left:** Initialization scale of new states, which has only a minor effect on the result. **Right:** Continuous compute, i.e. initializing new states with previously computed states (we initialize each new state with the latest state from the position one step to the left). This is less effective for our sampler, since the state at the position one step to the left is itself only the result of $r'$ recurrences.](ablations/appendix_ablation_dual_axis_continuous_compute_dual.png "fig:"){#fig:hyperparam_sweeps2 width="0.49\\linewidth"}

![A heatmap of accuracy and throughput measurements spanned by varying noise and inner recurrence.](ablations/sweep_dual_metric_combined_heatmap.png){#fig:heatmap width="\\linewidth"}

![Additional visualizations of the trade-off between noise and inner recurrence shown in `\cref{fig:pareto}`{=latex}.](ablations/trade_off_chart.png){#fig:tradeoff width="\\linewidth"}

Qualitative Evaluation
----------------------

To visualize the progress (or temporary lack thereof) of the sampler on a challenging sequence from the GSM8k validation set, we provide a few additional visualizations in `\cref{fig:token_stability}`{=latex}.

![A full example of a sampler hyperparameter failure. As in `\cref{fig:qualitative_examples}`{=latex}, this figure shows the token ids on the left as they change during successive steps of the sampler (running from top to bottom) over the sequence dimension (running left to right). The model tries various configurations for the current tokens before they are gradually frozen as their latent states converge. Due to a few hard decisions (from the perspective of the model) early in the sequence, as seen in the stability charts on the right, progress stalls until these tokens are decided, but then picks up speed again. However, large parts of the wavefront all decode into the whitespace token (dark blue), so no useful state information is computed until the earlier tokens are resolved.](plots/modeling01_token_stability_heatmaps.png){#fig:token_stability width="90%"}

![Hyperparameter Robustness for the finetuned math model on GSM8k. These figures repeat the hyperparameter-robustness ablations from the main body for the finetuned math model, showing largely similar behavior even though the model's capabilities have noticeably changed.](ablations/metamath_model_ablation_dual_axis_ema_embeds_dual.png "fig:"){#fig:metamath_ablations width="0.49\\linewidth"} ![Hyperparameter Robustness for the finetuned math model on GSM8k. These figures repeat the hyperparameter-robustness ablations from the main body for the finetuned math model, showing largely similar behavior even though the model's capabilities have noticeably changed.](ablations/metamath_model_ablation_dual_axis_exit_t_dual.png "fig:"){#fig:metamath_ablations width="0.49\\linewidth"} ![Hyperparameter Robustness for the finetuned math model on GSM8k. These figures repeat the hyperparameter-robustness ablations from the main body for the finetuned math model, showing largely similar behavior even though the model's capabilities have noticeably changed.](ablations/metamath_model_ablation_dual_axis_inner_recurrence_dual.png "fig:"){#fig:metamath_ablations width="0.49\\linewidth"} ![Hyperparameter Robustness for the finetuned math model on GSM8k. These figures repeat the hyperparameter-robustness ablations from the main body for the finetuned math model, showing largely similar behavior even though the model's capabilities have noticeably changed.](ablations/metamath_model_ablation_dual_axis_inner_recurrenceno_noise_dual.png "fig:"){#fig:metamath_ablations width="0.49\\linewidth"} ![Hyperparameter Robustness for the finetuned math model on GSM8k. These figures repeat the hyperparameter-robustness ablations from the main body for the finetuned math model, showing largely similar behavior even though the model's capabilities have noticeably changed.](ablations/metamath_model_ablation_dual_axis_state_noise_mixing_dual.png "fig:"){#fig:metamath_ablations width="0.49\\linewidth"}

Theoretical Analysis
====================

Problem Formulations
--------------------

```{=latex}
\begin{definition}[Depth and Width in Recurrent-Depth Models]Consider a recurrent-depth model $\mathcal{M}_{d}$ that processes an input sequence $\mathbf{x} \in \mathbb{R}^{L_0 \times h}$, where $L_0 \in \mathbb{N}$ is the sequence length and $h \in \mathbb{N}$ is the hidden dimension. At each generation step $t \in \mathbb{N}$, we define a \emph{hidden state} as the $h$-dimensional output vector produced by a Transformer block for an input token. Let $\mathbf{H}_t \in \mathbb{R}^{w_t \times h}$ denote the 2D-matrix containing all hidden states at step $t$. We define the following two associated quantities:

\begin{itemize}[itemsep=1pt,topsep=0pt,leftmargin=*]
\item the \emph{depth} $d_t \in \mathbb{N}$, defined as the number of \emph{serial} 
Transformer block forward passes used to obtain $\mathbf{H}_t$ from the initial $L_0$ input tokens 
(i.e., the generation step), while ignoring any discretization;
\item the \emph{width} $w_t \in \mathbb{N}$, defined as the cardinality of the active hidden-state set 
$\mathbf{H}_t$ (i.e., the number of $h$-dimensional hidden states that are processed in \emph{parallel}
at generation step $t$).
\end{itemize}
These quantities evolve according to the following rules:
\begin{enumerate}[itemsep=1pt,topsep=0pt,leftmargin=*]
\item \textbf{Initialization.} At time $t=0$, we set
\[
d_0 = 0, \qquad w_0 = L_0.
\]
\item \textbf{Depth update.} At each step $t \ge 0$, one additional Transformer block is applied, hence
\[
d_{t+1} = d_t + 1,
\]
so that $d_t = t$ for all $t \in \mathbb{N}$.
\item \textbf{Width update.} 
At each step $t \ge 0$, the width changes only due to two types of events:
\begin{itemize}[itemsep=0pt,topsep=0pt,leftmargin=*]
  \item \emph{Token entry:} let $e^{(t)} \in \mathbb{N}_0$ denote the number of new tokens encoded 
  into the model at step $t$, each contributing a new hidden state;
  \item \emph{Hidden-state exit:} let $x^{(t)} \in \mathbb{N}_0$ denote the number of hidden states 
  removed from the model at step $t$ due to decoding.
\end{itemize}
Then the width evolves as
\[
w_{t+1} = w_t + e^{(t)} - x^{(t)}.
\]
Equivalently, the net change can be written as $\delta^{(t)} = e^{(t)} - x^{(t)}$, 
so that $\delta^{(t)} > 0$ corresponds to net entries (more tokens encoded than hidden states decoded), 
and $\delta^{(t)} < 0$ corresponds to net exits (more hidden states decoded than tokens encoded).
\end{enumerate}
\end{definition}
```
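The update rules of the definition above can be traced with a few lines of bookkeeping. The entry and exit pattern in the usage example is hypothetical: a prompt of $L_0=4$ tokens, one new token entering per step, and one decoded hidden state exiting per step from $t=2$ on. It only illustrates that $d_t=t$ while $w_t$ follows $w_{t+1}=w_t+e^{(t)}-x^{(t)}$; the function name is likewise hypothetical.

```python
from typing import List, Tuple


def trace_depth_width(L0: int, entries: List[int],
                      exits: List[int]) -> List[Tuple[int, int]]:
    """Return [(d_t, w_t) for t = 0..T] under the definition's update rules."""
    assert len(entries) == len(exits)
    d, w = 0, L0                      # initialization: d_0 = 0, w_0 = L_0
    history = [(d, w)]
    for e_t, x_t in zip(entries, exits):
        assert 0 <= x_t <= w + e_t, "cannot remove more states than are active"
        d = d + 1                     # depth update: one more serial block pass
        w = w + e_t - x_t             # width update: entries minus exits
        history.append((d, w))
    return history


trace = trace_depth_width(4, entries=[1, 1, 1, 1, 1], exits=[0, 0, 1, 1, 1])
print(trace)
```

The width first grows while new states accumulate and then stays constant once exactly one state exits per step, which corresponds to a sampler that advances by one token per step with a fixed-size wavefront.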
```{=latex}
\begin{remark}
At any generation step $t$, all hidden states in $\mathbf{H}_t$ share the same depth $d_t$, 
since each step corresponds to one additional serial forward pass through the Transformer block.
\end{remark}
```
LLMs should prioritize depth scaling during prefilling.
-------------------------------------------------------

```{=latex}
\begin{definition}[Width Scaling Variants]Fix a width scaling factor $s \in \mathbb{N}$. 
Given an input sequence of length $L$, for each token $i \in \{1,\dots,L\}$ 
we create $s$ copies indexed by $j \in \{1,\dots,s\}$. 
The replicated sequence therefore has length $L\cdot s$, with elements denoted by $(i,j)$, 
the $j$-th copy of token $i$. 
The width-scaling model is obtained by applying a Transformer block 
(with parameters unchanged) to this replicated sequence under a customized attention mask, 
followed by a reduction step that maps the $L\cdot s$ outputs back to length $L$ 
(e.g., by selecting the last copy or averaging over copies). 

We define two variants according to how each copy may attend:
\begin{itemize}[itemsep=0pt,topsep=0pt,leftmargin=*]
\item \textbf{Width-NoShare.} The $j$-th copy of token $i$ may attend to all copies of tokens $1,\dots,i-1$, 
as well as the first $j-1$ copies of token $i$.  

\item \textbf{Width-KVShare.} The $j$-th copy of token $i$ may attend only to the last copy of tokens $1,\dots,i-1$, 
together with the first $j-1$ copies of token $i$.
\end{itemize}
\end{definition}
```
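The two attention patterns in the definition above can be made concrete as boolean masks over the replicated sequence, where copy $j$ of token $i$ sits at the flattened position $(i-1)s + (j-1)$. This sketch rests on one assumption that the definition leaves implicit: every copy also attends to itself, so that no query row is empty. Under that assumption, Width-NoShare coincides with an ordinary causal mask on the replicated sequence. The function name is hypothetical.

```python
from typing import List


def width_mask(L: int, s: int, kv_share: bool) -> List[List[bool]]:
    """mask[q][k] is True if query position q may attend to key position k.

    Position (i - 1) * s + (j - 1) holds copy j of token i (1-indexed i, j).
    Self-attention is assumed to be allowed (not stated in the definition).
    """
    n = L * s
    mask = [[False] * n for _ in range(n)]
    for qi in range(1, L + 1):
        for qj in range(1, s + 1):
            q = (qi - 1) * s + (qj - 1)
            for ki in range(1, L + 1):
                for kj in range(1, s + 1):
                    k = (ki - 1) * s + (kj - 1)
                    if ki < qi:
                        # Earlier tokens: all copies (NoShare), last copy only (KVShare).
                        allowed = (kj == s) if kv_share else True
                    elif ki == qi:
                        # Same token: the first j - 1 copies, plus the copy itself.
                        allowed = kj <= qj
                    else:
                        # Later tokens are never visible (causal model).
                        allowed = False
                    mask[q][k] = allowed
    return mask


no_share = width_mask(2, 2, kv_share=False)
kv_share = width_mask(2, 2, kv_share=True)
print(sum(map(sum, no_share)), sum(map(sum, kv_share)))
```

For $L=2$ and $s=2$, the NoShare mask admits 10 query-key pairs and the KVShare mask admits 8; in both cases the block still computes $L\cdot s = 4$ hidden states in parallel, which is what the width count in the proposition below depends on.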
```{=latex}
\begin{proposition}
During prefilling, both \textsf{Width-NoShare} and \textsf{Width-KVShare} are valid 
width-scaling architectures with factor $s$.
\end{proposition}
```
```{=latex}
\begin{proof}
\textbf{Depth.}  
At any generation step, each variant performs exactly one Transformer block forward pass 
on the replicated sequence. 
Therefore the number of serial block forward passes needed to produce the hidden states is unchanged, 
so the depth satisfies $\tilde d_t = d_t$.

\textbf{Width.}  
By definition, the width $w_t$ is the number of hidden states produced in parallel at step $t$.  
In the original model, prefilling a sequence of length $L$ produces $L$ hidden states per step.  
In both variants, we replicate each token $s$ times, so the block computes hidden states for 
all pairs $(i,j)$ with $i \in \{1,\dots,L\}$ and $j \in \{1,\dots,s\}$.  
Hence the total number of hidden states produced in that step is
\[
\tilde w_t = Ls = s \cdot w_t.
\]
The difference between \textsf{NoShare} and \textsf{KVShare} lies only in the 
attention pattern (which copies each query may attend to). 
This affects information flow but not the number of hidden states computed.  
The optional reduction back to length $L$ occurs \emph{after} the parallel computation 
and thus does not change the measured width.

\textbf{Conclusion.}  
Both variants keep serial depth fixed and enlarge width by a factor of $s$, 
which is precisely our notion of width scaling.
\end{proof}
```

[^1]: With this form of KV sharing, the cache requires no more memory than that of a parameter-matched fixed-depth transformer.
