---
abstract:
- |
  Large language models have shown remarkable reasoning abilities and scaling laws suggest that large parameter count, especially along the depth axis, is the primary driver. In this work, we make a stronger claim --- many reasoning problems require a large depth but not necessarily many parameters. This unlocks a novel application of *looped models for reasoning*. Firstly, we show that for many synthetic reasoning problems like addition, $p$-hop induction, and math problems, a $k$-layer transformer looped $L$ times nearly matches the performance of a $kL$-layer non-looped model, and is significantly better than a $k$-layer model. This is further corroborated by theoretical results showing that many such reasoning problems can be solved via iterative algorithms, and thus, can be solved effectively using looped models with nearly optimal depth. Perhaps surprisingly, these benefits also translate to practical settings of language modeling --- on many downstream reasoning tasks, a language model with $k$-layers looped $L$ times can be competitive to, if not better than, a $kL$-layer language model. In fact, our empirical analysis reveals an intriguing phenomenon: looped and non-looped models exhibit scaling behavior that depends on their effective depth, akin to the inference-time scaling of chain-of-thought (CoT) reasoning. We further elucidate the connection to CoT reasoning by proving that looped models implicitly generate *latent thoughts* and can simulate $T$ steps of CoT with $T$ loops. Inspired by these findings, we also present an interesting dichotomy between reasoning and memorization, and design a looping-based regularization that is effective on both fronts.
- |
  Large language models have shown promising abilities in reasoning problems and scaling laws suggest that parameter count is a key driver. Recent works [@chen2024can; @ye2024physics] argue that, for reasoning, depth plays a very important role in addition to parameter count. In this work, we make a more fine-grained claim --- many reasoning problems require large depth but not necessarily many parameters, in the sense that they can be solved via *looped models*. This unlocks a novel application of looped models for reasoning. For various synthetic reasoning problems like addition, $p$-hop induction, and math problems, we find that a $k$-layer transformer looped $L$ times nearly matches the quality of a $kL$-layer non-looped model, and is much better than a $k$-layer model. Thus, using a small model and providing depth via looping can suffice for such reasoning problems. This is corroborated by theoretical results that show many such reasoning problems can be solved via iterative algorithms, and thus, can be solved with looped models. We further elucidate connections to chain-of-thought (CoT) reasoning by proving that looped models implicitly generate \`\`latent thoughts" and can simulate $T$ steps of CoT with $T$ loops. Motivated by these findings, we train autoregressive models on general language modeling datasets with looping and compare a $k$-layer model looped $L$ times to a $kL$-layer model. While the looped model is understandably worse on perplexity and memorization tasks, it surprisingly does very well on tasks that require reasoning, like open book QA, math word problems and reasoning primitives. Despite having significantly fewer parameters, it can even match or outperform the non-looped $kL$-layer model on some of these tasks and have a very robust scaling behavior with respect to depth. These results suggest a novel inductive bias of looped models towards enhanced reasoning. We provide further evidence for this inductive bias by visualizing perplexity vs downstream isoplots, and design a looping-inspired regularization that solidifies this hypothesis.
author:
- |
  Nikunj Saunshi$^{1}$, Nishanth Dikkala$^{1}$, Zhiyuan Li$^{1,2}$, Sanjiv Kumar$^{1}$, Sashank J. Reddi$^{1}$\
  `{nsaunshi, nishanthd, lizhiyuan, sanjivk, sashank}@google.com`\
  $^{1}$Google Research, $^{2}$Toyota Technological Institute at Chicago\
  \
  Submitted on Oct. 1, 2024
bibliography:
- references.bib
title: 'Reasoning with Latent Thoughts: On the Power of Looped Transformers'
---

```{=latex}
\newcommand{\sashank}[1]{{\color{red} [SR: #1]}}
```
```{=latex}
\renewcommand{\algorithmicrequire}{\textbf{Input:}}
```
```{=latex}
\renewcommand{\algorithmicensure}{\textbf{Output:}}
```
```{=latex}
\newcommand{\RR}{\mathbb{R}}
```
```{=latex}
\newcommand{\method}{\textsc{LoOPT}\xspace}
```
```{=latex}
\newcommand{\ns}[1]{{\color{orange}NS: #1}}
```
```{=latex}
\newcommand{\highlight}[1]{{\color{purple}#1}}
```
```{=latex}
\newcommand{\sjr}[1]{{\color{red}SJ: #1}}
```
```{=latex}
\newcommand{\nishanth}[1]{{\color{teal} [Nishanth: #1]}}
```
```{=latex}
\newcommand{\zhiyuan}[1]{{\color{green}ZL: #1}}
```
```{=latex}
\newcommand{\new}[1]{#1}
```
```{=latex}
\newcommand{\loopy}[2]{$(#1\otimes #2)$}
```
```{=latex}
\newcommand*{\MinNumber}{0.0}
```
```{=latex}
\newcommand*{\MidNumber}{60}
```
```{=latex}
\newcommand*{\MaxNumber}{150}
```
```{=latex}
\newcommand{\ApplyGradient}[1]{%
        \ifdim #1 pt > \MidNumber pt
            \pgfmathsetmacro{\PercentColor}{max(min(100.0*(min(#1, \MaxNumber) - \MidNumber)/(\MaxNumber-\MidNumber),100.0),0.00)} %
            
            % {\pgfmathsetmacro{#1}{min(max(#1,\MinNumber),\MaxNumber)}
            \colorbox{green!\PercentColor!yellow}{#1}
        \else
            \pgfmathsetmacro{\PercentColor}{max(min(100.0*(\MidNumber - #1)/(\MidNumber-\MinNumber),100.0),0.00)} %
            \colorbox{red!\PercentColor!yellow}{#1}
        \fi
}
```
```{=latex}
\newcommand{\lreg}{\lambda_{\text{reg}}}
```
```{=latex}
\newcommand{\figleft}{{\em (Left)}}
```
```{=latex}
\newcommand{\figcenter}{{\em (Center)}}
```
```{=latex}
\newcommand{\figright}{{\em (Right)}}
```
```{=latex}
\newcommand{\figtop}{{\em (Top)}}
```
```{=latex}
\newcommand{\figbottom}{{\em (Bottom)}}
```
```{=latex}
\newcommand{\captiona}{{\em (a)}}
```
```{=latex}
\newcommand{\captionb}{{\em (b)}}
```
```{=latex}
\newcommand{\captionc}{{\em (c)}}
```
```{=latex}
\newcommand{\captiond}{{\em (d)}}
```
```{=latex}
\newcommand{\newterm}[1]{{\bf #1}}
```
```{=latex}
\def\figref#1{figure~\ref{#1}}
```
```{=latex}
\def\Figref#1{Figure~\ref{#1}}
```
```{=latex}
\def\twofigref#1#2{figures \ref{#1} and \ref{#2}}
```
```{=latex}
\def\quadfigref#1#2#3#4{figures \ref{#1}, \ref{#2}, \ref{#3} and \ref{#4}}
```
```{=latex}
\def\secref#1{section~\ref{#1}}
```
```{=latex}
\def\Secref#1{Section~\ref{#1}}
```
```{=latex}
\def\twosecrefs#1#2{sections \ref{#1} and \ref{#2}}
```
```{=latex}
\def\secrefs#1#2#3{sections \ref{#1}, \ref{#2} and \ref{#3}}
```
```{=latex}
\def\eqref#1{equation~\ref{#1}}
```
```{=latex}
\def\Eqref#1{Equation~\ref{#1}}
```
```{=latex}
\def\plaineqref#1{\ref{#1}}
```
```{=latex}
\def\chapref#1{chapter~\ref{#1}}
```
```{=latex}
\def\Chapref#1{Chapter~\ref{#1}}
```
```{=latex}
\def\rangechapref#1#2{chapters\ref{#1}--\ref{#2}}
```
```{=latex}
\def\algref#1{algorithm~\ref{#1}}
```
```{=latex}
\def\Algref#1{Algorithm~\ref{#1}}
```
```{=latex}
\def\twoalgref#1#2{algorithms \ref{#1} and \ref{#2}}
```
```{=latex}
\def\Twoalgref#1#2{Algorithms \ref{#1} and \ref{#2}}
```
```{=latex}
\def\partref#1{part~\ref{#1}}
```
```{=latex}
\def\Partref#1{Part~\ref{#1}}
```
```{=latex}
\def\twopartref#1#2{parts \ref{#1} and \ref{#2}}
```
```{=latex}
\def\ceil#1{\lceil #1 \rceil}
```
```{=latex}
\def\floor#1{\lfloor #1 \rfloor}
```
```{=latex}
\def\1{\bm{1}}
```
```{=latex}
\newcommand{\train}{\mathcal{D}}
```
```{=latex}
\newcommand{\valid}{\mathcal{D_{\mathrm{valid}}}}
```
```{=latex}
\newcommand{\test}{\mathcal{D_{\mathrm{test}}}}
```
```{=latex}
\def\eps{{\epsilon}}
```
```{=latex}
\def\reta{{\textnormal{$\eta$}}}
```
```{=latex}
\def\ra{{\textnormal{a}}}
```
```{=latex}
\def\rb{{\textnormal{b}}}
```
```{=latex}
\def\rc{{\textnormal{c}}}
```
```{=latex}
\def\rd{{\textnormal{d}}}
```
```{=latex}
\def\re{{\textnormal{e}}}
```
```{=latex}
\def\rf{{\textnormal{f}}}
```
```{=latex}
\def\rg{{\textnormal{g}}}
```
```{=latex}
\def\rh{{\textnormal{h}}}
```
```{=latex}
\def\ri{{\textnormal{i}}}
```
```{=latex}
\def\rj{{\textnormal{j}}}
```
```{=latex}
\def\rk{{\textnormal{k}}}
```
```{=latex}
\def\rl{{\textnormal{l}}}
```
```{=latex}
\def\rn{{\textnormal{n}}}
```
```{=latex}
\def\ro{{\textnormal{o}}}
```
```{=latex}
\def\rp{{\textnormal{p}}}
```
```{=latex}
\def\rq{{\textnormal{q}}}
```
```{=latex}
\def\rr{{\textnormal{r}}}
```
```{=latex}
\def\rs{{\textnormal{s}}}
```
```{=latex}
\def\rt{{\textnormal{t}}}
```
```{=latex}
\def\ru{{\textnormal{u}}}
```
```{=latex}
\def\rv{{\textnormal{v}}}
```
```{=latex}
\def\rw{{\textnormal{w}}}
```
```{=latex}
\def\rx{{\textnormal{x}}}
```
```{=latex}
\def\ry{{\textnormal{y}}}
```
```{=latex}
\def\rz{{\textnormal{z}}}
```
```{=latex}
\def\rvepsilon{{\mathbf{\epsilon}}}
```
```{=latex}
\def\rvtheta{{\mathbf{\theta}}}
```
```{=latex}
\def\rva{{\mathbf{a}}}
```
```{=latex}
\def\rvb{{\mathbf{b}}}
```
```{=latex}
\def\rvc{{\mathbf{c}}}
```
```{=latex}
\def\rvd{{\mathbf{d}}}
```
```{=latex}
\def\rve{{\mathbf{e}}}
```
```{=latex}
\def\rvf{{\mathbf{f}}}
```
```{=latex}
\def\rvg{{\mathbf{g}}}
```
```{=latex}
\def\rvh{{\mathbf{h}}}
```
```{=latex}
\def\rvu{{\mathbf{i}}}
```
```{=latex}
\def\rvj{{\mathbf{j}}}
```
```{=latex}
\def\rvk{{\mathbf{k}}}
```
```{=latex}
\def\rvl{{\mathbf{l}}}
```
```{=latex}
\def\rvm{{\mathbf{m}}}
```
```{=latex}
\def\rvn{{\mathbf{n}}}
```
```{=latex}
\def\rvo{{\mathbf{o}}}
```
```{=latex}
\def\rvp{{\mathbf{p}}}
```
```{=latex}
\def\rvq{{\mathbf{q}}}
```
```{=latex}
\def\rvr{{\mathbf{r}}}
```
```{=latex}
\def\rvs{{\mathbf{s}}}
```
```{=latex}
\def\rvt{{\mathbf{t}}}
```
```{=latex}
\def\rvu{{\mathbf{u}}}
```
```{=latex}
\def\rvv{{\mathbf{v}}}
```
```{=latex}
\def\rvw{{\mathbf{w}}}
```
```{=latex}
\def\rvx{{\mathbf{x}}}
```
```{=latex}
\def\rvy{{\mathbf{y}}}
```
```{=latex}
\def\rvz{{\mathbf{z}}}
```
```{=latex}
\def\erva{{\textnormal{a}}}
```
```{=latex}
\def\ervb{{\textnormal{b}}}
```
```{=latex}
\def\ervc{{\textnormal{c}}}
```
```{=latex}
\def\ervd{{\textnormal{d}}}
```
```{=latex}
\def\erve{{\textnormal{e}}}
```
```{=latex}
\def\ervf{{\textnormal{f}}}
```
```{=latex}
\def\ervg{{\textnormal{g}}}
```
```{=latex}
\def\ervh{{\textnormal{h}}}
```
```{=latex}
\def\ervi{{\textnormal{i}}}
```
```{=latex}
\def\ervj{{\textnormal{j}}}
```
```{=latex}
\def\ervk{{\textnormal{k}}}
```
```{=latex}
\def\ervl{{\textnormal{l}}}
```
```{=latex}
\def\ervm{{\textnormal{m}}}
```
```{=latex}
\def\ervn{{\textnormal{n}}}
```
```{=latex}
\def\ervo{{\textnormal{o}}}
```
```{=latex}
\def\ervp{{\textnormal{p}}}
```
```{=latex}
\def\ervq{{\textnormal{q}}}
```
```{=latex}
\def\ervr{{\textnormal{r}}}
```
```{=latex}
\def\ervs{{\textnormal{s}}}
```
```{=latex}
\def\ervt{{\textnormal{t}}}
```
```{=latex}
\def\ervu{{\textnormal{u}}}
```
```{=latex}
\def\ervv{{\textnormal{v}}}
```
```{=latex}
\def\ervw{{\textnormal{w}}}
```
```{=latex}
\def\ervx{{\textnormal{x}}}
```
```{=latex}
\def\ervy{{\textnormal{y}}}
```
```{=latex}
\def\ervz{{\textnormal{z}}}
```
```{=latex}
\def\rmA{{\mathbf{A}}}
```
```{=latex}
\def\rmB{{\mathbf{B}}}
```
```{=latex}
\def\rmC{{\mathbf{C}}}
```
```{=latex}
\def\rmD{{\mathbf{D}}}
```
```{=latex}
\def\rmE{{\mathbf{E}}}
```
```{=latex}
\def\rmF{{\mathbf{F}}}
```
```{=latex}
\def\rmG{{\mathbf{G}}}
```
```{=latex}
\def\rmH{{\mathbf{H}}}
```
```{=latex}
\def\rmI{{\mathbf{I}}}
```
```{=latex}
\def\rmJ{{\mathbf{J}}}
```
```{=latex}
\def\rmK{{\mathbf{K}}}
```
```{=latex}
\def\rmL{{\mathbf{L}}}
```
```{=latex}
\def\rmM{{\mathbf{M}}}
```
```{=latex}
\def\rmN{{\mathbf{N}}}
```
```{=latex}
\def\rmO{{\mathbf{O}}}
```
```{=latex}
\def\rmP{{\mathbf{P}}}
```
```{=latex}
\def\rmQ{{\mathbf{Q}}}
```
```{=latex}
\def\rmR{{\mathbf{R}}}
```
```{=latex}
\def\rmS{{\mathbf{S}}}
```
```{=latex}
\def\rmT{{\mathbf{T}}}
```
```{=latex}
\def\rmU{{\mathbf{U}}}
```
```{=latex}
\def\rmV{{\mathbf{V}}}
```
```{=latex}
\def\rmW{{\mathbf{W}}}
```
```{=latex}
\def\rmX{{\mathbf{X}}}
```
```{=latex}
\def\rmY{{\mathbf{Y}}}
```
```{=latex}
\def\rmZ{{\mathbf{Z}}}
```
```{=latex}
\def\ermA{{\textnormal{A}}}
```
```{=latex}
\def\ermB{{\textnormal{B}}}
```
```{=latex}
\def\ermC{{\textnormal{C}}}
```
```{=latex}
\def\ermD{{\textnormal{D}}}
```
```{=latex}
\def\ermE{{\textnormal{E}}}
```
```{=latex}
\def\ermF{{\textnormal{F}}}
```
```{=latex}
\def\ermG{{\textnormal{G}}}
```
```{=latex}
\def\ermH{{\textnormal{H}}}
```
```{=latex}
\def\ermI{{\textnormal{I}}}
```
```{=latex}
\def\ermJ{{\textnormal{J}}}
```
```{=latex}
\def\ermK{{\textnormal{K}}}
```
```{=latex}
\def\ermL{{\textnormal{L}}}
```
```{=latex}
\def\ermM{{\textnormal{M}}}
```
```{=latex}
\def\ermN{{\textnormal{N}}}
```
```{=latex}
\def\ermO{{\textnormal{O}}}
```
```{=latex}
\def\ermP{{\textnormal{P}}}
```
```{=latex}
\def\ermQ{{\textnormal{Q}}}
```
```{=latex}
\def\ermR{{\textnormal{R}}}
```
```{=latex}
\def\ermS{{\textnormal{S}}}
```
```{=latex}
\def\ermT{{\textnormal{T}}}
```
```{=latex}
\def\ermU{{\textnormal{U}}}
```
```{=latex}
\def\ermV{{\textnormal{V}}}
```
```{=latex}
\def\ermW{{\textnormal{W}}}
```
```{=latex}
\def\ermX{{\textnormal{X}}}
```
```{=latex}
\def\ermY{{\textnormal{Y}}}
```
```{=latex}
\def\ermZ{{\textnormal{Z}}}
```
```{=latex}
\def\vzero{{\bm{0}}}
```
```{=latex}
\def\vone{{\bm{1}}}
```
```{=latex}
\def\vmu{{\bm{\mu}}}
```
```{=latex}
\def\vtheta{{\bm{\theta}}}
```
```{=latex}
\def\va{{\bm{a}}}
```
```{=latex}
\def\vb{{\bm{b}}}
```
```{=latex}
\def\vc{{\bm{c}}}
```
```{=latex}
\def\vd{{\bm{d}}}
```
```{=latex}
\def\ve{{\bm{e}}}
```
```{=latex}
\def\vf{{\bm{f}}}
```
```{=latex}
\def\vg{{\bm{g}}}
```
```{=latex}
\def\vh{{\bm{h}}}
```
```{=latex}
\def\vi{{\bm{i}}}
```
```{=latex}
\def\vj{{\bm{j}}}
```
```{=latex}
\def\vk{{\bm{k}}}
```
```{=latex}
\def\vl{{\bm{l}}}
```
```{=latex}
\def\vm{{\bm{m}}}
```
```{=latex}
\def\vn{{\bm{n}}}
```
```{=latex}
\def\vo{{\bm{o}}}
```
```{=latex}
\def\vp{{\bm{p}}}
```
```{=latex}
\def\vq{{\bm{q}}}
```
```{=latex}
\def\vr{{\bm{r}}}
```
```{=latex}
\def\vs{{\bm{s}}}
```
```{=latex}
\def\vt{{\bm{t}}}
```
```{=latex}
\def\vu{{\bm{u}}}
```
```{=latex}
\def\vv{{\bm{v}}}
```
```{=latex}
\def\vw{{\bm{w}}}
```
```{=latex}
\def\vx{{\bm{x}}}
```
```{=latex}
\def\vy{{\bm{y}}}
```
```{=latex}
\def\vz{{\bm{z}}}
```
```{=latex}
\def\evalpha{{\alpha}}
```
```{=latex}
\def\evbeta{{\beta}}
```
```{=latex}
\def\evepsilon{{\epsilon}}
```
```{=latex}
\def\evlambda{{\lambda}}
```
```{=latex}
\def\evomega{{\omega}}
```
```{=latex}
\def\evmu{{\mu}}
```
```{=latex}
\def\evpsi{{\psi}}
```
```{=latex}
\def\evsigma{{\sigma}}
```
```{=latex}
\def\evtheta{{\theta}}
```
```{=latex}
\def\eva{{a}}
```
```{=latex}
\def\evb{{b}}
```
```{=latex}
\def\evc{{c}}
```
```{=latex}
\def\evd{{d}}
```
```{=latex}
\def\eve{{e}}
```
```{=latex}
\def\evf{{f}}
```
```{=latex}
\def\evg{{g}}
```
```{=latex}
\def\evh{{h}}
```
```{=latex}
\def\evi{{i}}
```
```{=latex}
\def\evj{{j}}
```
```{=latex}
\def\evk{{k}}
```
```{=latex}
\def\evl{{l}}
```
```{=latex}
\def\evm{{m}}
```
```{=latex}
\def\evn{{n}}
```
```{=latex}
\def\evo{{o}}
```
```{=latex}
\def\evp{{p}}
```
```{=latex}
\def\evq{{q}}
```
```{=latex}
\def\evr{{r}}
```
```{=latex}
\def\evs{{s}}
```
```{=latex}
\def\evt{{t}}
```
```{=latex}
\def\evu{{u}}
```
```{=latex}
\def\evv{{v}}
```
```{=latex}
\def\evw{{w}}
```
```{=latex}
\def\evx{{x}}
```
```{=latex}
\def\evy{{y}}
```
```{=latex}
\def\evz{{z}}
```
```{=latex}
\def\mA{{\bm{A}}}
```
```{=latex}
\def\mB{{\bm{B}}}
```
```{=latex}
\def\mC{{\bm{C}}}
```
```{=latex}
\def\mD{{\bm{D}}}
```
```{=latex}
\def\mE{{\bm{E}}}
```
```{=latex}
\def\mF{{\bm{F}}}
```
```{=latex}
\def\mG{{\bm{G}}}
```
```{=latex}
\def\mH{{\bm{H}}}
```
```{=latex}
\def\mI{{\bm{I}}}
```
```{=latex}
\def\mJ{{\bm{J}}}
```
```{=latex}
\def\mK{{\bm{K}}}
```
```{=latex}
\def\mL{{\bm{L}}}
```
```{=latex}
\def\mM{{\bm{M}}}
```
```{=latex}
\def\mN{{\bm{N}}}
```
```{=latex}
\def\mO{{\bm{O}}}
```
```{=latex}
\def\mP{{\bm{P}}}
```
```{=latex}
\def\mQ{{\bm{Q}}}
```
```{=latex}
\def\mR{{\bm{R}}}
```
```{=latex}
\def\mS{{\bm{S}}}
```
```{=latex}
\def\mT{{\bm{T}}}
```
```{=latex}
\def\mU{{\bm{U}}}
```
```{=latex}
\def\mV{{\bm{V}}}
```
```{=latex}
\def\mW{{\bm{W}}}
```
```{=latex}
\def\mX{{\bm{X}}}
```
```{=latex}
\def\mY{{\bm{Y}}}
```
```{=latex}
\def\mZ{{\bm{Z}}}
```
```{=latex}
\def\mBeta{{\bm{\beta}}}
```
```{=latex}
\def\mPhi{{\bm{\Phi}}}
```
```{=latex}
\def\mLambda{{\bm{\Lambda}}}
```
```{=latex}
\def\mSigma{{\bm{\Sigma}}}
```
```{=latex}
\newcommand{\tens}[1]{\bm{\mathsfit{#1}}}
```
```{=latex}
\def\tA{{\tens{A}}}
```
```{=latex}
\def\tB{{\tens{B}}}
```
```{=latex}
\def\tC{{\tens{C}}}
```
```{=latex}
\def\tD{{\tens{D}}}
```
```{=latex}
\def\tE{{\tens{E}}}
```
```{=latex}
\def\tF{{\tens{F}}}
```
```{=latex}
\def\tG{{\tens{G}}}
```
```{=latex}
\def\tH{{\tens{H}}}
```
```{=latex}
\def\tI{{\tens{I}}}
```
```{=latex}
\def\tJ{{\tens{J}}}
```
```{=latex}
\def\tK{{\tens{K}}}
```
```{=latex}
\def\tL{{\tens{L}}}
```
```{=latex}
\def\tM{{\tens{M}}}
```
```{=latex}
\def\tN{{\tens{N}}}
```
```{=latex}
\def\tO{{\tens{O}}}
```
```{=latex}
\def\tP{{\tens{P}}}
```
```{=latex}
\def\tQ{{\tens{Q}}}
```
```{=latex}
\def\tR{{\tens{R}}}
```
```{=latex}
\def\tS{{\tens{S}}}
```
```{=latex}
\def\tT{{\tens{T}}}
```
```{=latex}
\def\tU{{\tens{U}}}
```
```{=latex}
\def\tV{{\tens{V}}}
```
```{=latex}
\def\tW{{\tens{W}}}
```
```{=latex}
\def\tX{{\tens{X}}}
```
```{=latex}
\def\tY{{\tens{Y}}}
```
```{=latex}
\def\tZ{{\tens{Z}}}
```
```{=latex}
\def\gA{{\mathcal{A}}}
```
```{=latex}
\def\gB{{\mathcal{B}}}
```
```{=latex}
\def\gC{{\mathcal{C}}}
```
```{=latex}
\def\gD{{\mathcal{D}}}
```
```{=latex}
\def\gE{{\mathcal{E}}}
```
```{=latex}
\def\gF{{\mathcal{F}}}
```
```{=latex}
\def\gG{{\mathcal{G}}}
```
```{=latex}
\def\gH{{\mathcal{H}}}
```
```{=latex}
\def\gI{{\mathcal{I}}}
```
```{=latex}
\def\gJ{{\mathcal{J}}}
```
```{=latex}
\def\gK{{\mathcal{K}}}
```
```{=latex}
\def\gL{{\mathcal{L}}}
```
```{=latex}
\def\gM{{\mathcal{M}}}
```
```{=latex}
\def\gN{{\mathcal{N}}}
```
```{=latex}
\def\gO{{\mathcal{O}}}
```
```{=latex}
\def\gP{{\mathcal{P}}}
```
```{=latex}
\def\gQ{{\mathcal{Q}}}
```
```{=latex}
\def\gR{{\mathcal{R}}}
```
```{=latex}
\def\gS{{\mathcal{S}}}
```
```{=latex}
\def\gT{{\mathcal{T}}}
```
```{=latex}
\def\gU{{\mathcal{U}}}
```
```{=latex}
\def\gV{{\mathcal{V}}}
```
```{=latex}
\def\gW{{\mathcal{W}}}
```
```{=latex}
\def\gX{{\mathcal{X}}}
```
```{=latex}
\def\gY{{\mathcal{Y}}}
```
```{=latex}
\def\gZ{{\mathcal{Z}}}
```
```{=latex}
\def\sA{{\mathbb{A}}}
```
```{=latex}
\def\sB{{\mathbb{B}}}
```
```{=latex}
\def\sC{{\mathbb{C}}}
```
```{=latex}
\def\sD{{\mathbb{D}}}
```
```{=latex}
\def\sF{{\mathbb{F}}}
```
```{=latex}
\def\sG{{\mathbb{G}}}
```
```{=latex}
\def\sH{{\mathbb{H}}}
```
```{=latex}
\def\sI{{\mathbb{I}}}
```
```{=latex}
\def\sJ{{\mathbb{J}}}
```
```{=latex}
\def\sK{{\mathbb{K}}}
```
```{=latex}
\def\sL{{\mathbb{L}}}
```
```{=latex}
\def\sM{{\mathbb{M}}}
```
```{=latex}
\def\sN{{\mathbb{N}}}
```
```{=latex}
\def\sO{{\mathbb{O}}}
```
```{=latex}
\def\sP{{\mathbb{P}}}
```
```{=latex}
\def\sQ{{\mathbb{Q}}}
```
```{=latex}
\def\sR{{\mathbb{R}}}
```
```{=latex}
\def\sS{{\mathbb{S}}}
```
```{=latex}
\def\sT{{\mathbb{T}}}
```
```{=latex}
\def\sU{{\mathbb{U}}}
```
```{=latex}
\def\sV{{\mathbb{V}}}
```
```{=latex}
\def\sW{{\mathbb{W}}}
```
```{=latex}
\def\sX{{\mathbb{X}}}
```
```{=latex}
\def\sY{{\mathbb{Y}}}
```
```{=latex}
\def\sZ{{\mathbb{Z}}}
```
```{=latex}
\def\emLambda{{\Lambda}}
```
```{=latex}
\def\emA{{A}}
```
```{=latex}
\def\emB{{B}}
```
```{=latex}
\def\emC{{C}}
```
```{=latex}
\def\emD{{D}}
```
```{=latex}
\def\emE{{E}}
```
```{=latex}
\def\emF{{F}}
```
```{=latex}
\def\emG{{G}}
```
```{=latex}
\def\emH{{H}}
```
```{=latex}
\def\emI{{I}}
```
```{=latex}
\def\emJ{{J}}
```
```{=latex}
\def\emK{{K}}
```
```{=latex}
\def\emL{{L}}
```
```{=latex}
\def\emM{{M}}
```
```{=latex}
\def\emN{{N}}
```
```{=latex}
\def\emO{{O}}
```
```{=latex}
\def\emP{{P}}
```
```{=latex}
\def\emQ{{Q}}
```
```{=latex}
\def\emR{{R}}
```
```{=latex}
\def\emS{{S}}
```
```{=latex}
\def\emT{{T}}
```
```{=latex}
\def\emU{{U}}
```
```{=latex}
\def\emV{{V}}
```
```{=latex}
\def\emW{{W}}
```
```{=latex}
\def\emX{{X}}
```
```{=latex}
\def\emY{{Y}}
```
```{=latex}
\def\emZ{{Z}}
```
```{=latex}
\def\emSigma{{\Sigma}}
```
```{=latex}
\newcommand{\etens}[1]{\mathsfit{#1}}
```
```{=latex}
\def\etLambda{{\etens{\Lambda}}}
```
```{=latex}
\def\etA{{\etens{A}}}
```
```{=latex}
\def\etB{{\etens{B}}}
```
```{=latex}
\def\etC{{\etens{C}}}
```
```{=latex}
\def\etD{{\etens{D}}}
```
```{=latex}
\def\etE{{\etens{E}}}
```
```{=latex}
\def\etF{{\etens{F}}}
```
```{=latex}
\def\etG{{\etens{G}}}
```
```{=latex}
\def\etH{{\etens{H}}}
```
```{=latex}
\def\etI{{\etens{I}}}
```
```{=latex}
\def\etJ{{\etens{J}}}
```
```{=latex}
\def\etK{{\etens{K}}}
```
```{=latex}
\def\etL{{\etens{L}}}
```
```{=latex}
\def\etM{{\etens{M}}}
```
```{=latex}
\def\etN{{\etens{N}}}
```
```{=latex}
\def\etO{{\etens{O}}}
```
```{=latex}
\def\etP{{\etens{P}}}
```
```{=latex}
\def\etQ{{\etens{Q}}}
```
```{=latex}
\def\etR{{\etens{R}}}
```
```{=latex}
\def\etS{{\etens{S}}}
```
```{=latex}
\def\etT{{\etens{T}}}
```
```{=latex}
\def\etU{{\etens{U}}}
```
```{=latex}
\def\etV{{\etens{V}}}
```
```{=latex}
\def\etW{{\etens{W}}}
```
```{=latex}
\def\etX{{\etens{X}}}
```
```{=latex}
\def\etY{{\etens{Y}}}
```
```{=latex}
\def\etZ{{\etens{Z}}}
```
```{=latex}
\newcommand{\pdata}{p_{\rm{data}}}
```
```{=latex}
\newcommand{\ptrain}{\hat{p}_{\rm{data}}}
```
```{=latex}
\newcommand{\Ptrain}{\hat{P}_{\rm{data}}}
```
```{=latex}
\newcommand{\pmodel}{p_{\rm{model}}}
```
```{=latex}
\newcommand{\Pmodel}{P_{\rm{model}}}
```
```{=latex}
\newcommand{\ptildemodel}{\tilde{p}_{\rm{model}}}
```
```{=latex}
\newcommand{\pencode}{p_{\rm{encoder}}}
```
```{=latex}
\newcommand{\pdecode}{p_{\rm{decoder}}}
```
```{=latex}
\newcommand{\precons}{p_{\rm{reconstruct}}}
```
```{=latex}
\newcommand{\laplace}{\mathrm{Laplace}}
```
```{=latex}
\newcommand{\E}{\mathbb{E}}
```
```{=latex}
\newcommand{\Ls}{\mathcal{L}}
```
```{=latex}
\newcommand{\R}{\mathbb{R}}
```
```{=latex}
\newcommand{\emp}{\tilde{p}}
```
```{=latex}
\newcommand{\lr}{\alpha}
```
```{=latex}
\newcommand{\reg}{\lambda}
```
```{=latex}
\newcommand{\rect}{\mathrm{rectifier}}
```
```{=latex}
\newcommand{\softmax}{\mathrm{softmax}}
```
```{=latex}
\newcommand{\sigmoid}{\sigma}
```
```{=latex}
\newcommand{\softplus}{\zeta}
```
```{=latex}
\newcommand{\KL}{D_{\mathrm{KL}}}
```
```{=latex}
\newcommand{\Var}{\mathrm{Var}}
```
```{=latex}
\newcommand{\standarderror}{\mathrm{SE}}
```
```{=latex}
\newcommand{\Cov}{\mathrm{Cov}}
```
```{=latex}
\newcommand{\inner}[2]{\left\langle #1,#2 \right\rangle}
```
```{=latex}
\newcommand{\normlzero}{L^0}
```
```{=latex}
\newcommand{\normlone}{L^1}
```
```{=latex}
\newcommand{\normltwo}{L^2}
```
```{=latex}
\newcommand{\normlp}{L^p}
```
```{=latex}
\newcommand{\normmax}{L^\infty}
```
```{=latex}
\newcommand{\parents}{Pa}
```
```{=latex}
\DeclareMathOperator*{\argmax}{arg\,max}
```
```{=latex}
\DeclareMathOperator*{\argmin}{arg\,min}
```
```{=latex}
\DeclareMathOperator{\sign}{sign}
```
```{=latex}
\DeclareMathOperator{\Tr}{Tr}
```
```{=latex}
\let\ab\allowbreak
```
```{=latex}
\renewcommand{\P}{\mathsf{P}}
```
```{=latex}
\newcommand{\NC}{\mathsf{NC}}
```
```{=latex}
\newcommand{\TC}{\mathsf{TC}}
```
```{=latex}
\newcommand{\round}{\mathsf{round}}
```
```{=latex}
\newcommand{\relu}{\mathsf{relu}}
```
```{=latex}
\newcommand{\hop}{\mathsf{hop}}
```
```{=latex}
\newcommand{\abs}[1]{|#1|}
```
```{=latex}
\newcommand{\find}{\mathsf{find}}
```
```{=latex}
\newcommand{\id}{\mathsf{id}}
```
```{=latex}
\newcommand{\transformer}{\mathsf{TF}}
```
```{=latex}
\newcommand{\tfblock}{\mathsf{TB}}
```
```{=latex}
\newcommand{\posencoding}{\mathsf{PE}}
```
```{=latex}
\newcommand{\tokenembedding}{\mathsf{TE}}
```
```{=latex}
\newcommand{\transoutput}{\mathsf{OUTPUT}}
```
```{=latex}
\newcommand{\embed}{\mathsf{EMBED}}
```
```{=latex}
\newcommand{\decoder}{\mathsf{DECODER}}
```
```{=latex}
\newcommand{\loopedtransformer}{\mathsf{LTF}}
```
```{=latex}
\newcommand{\stransformer}{{\overrightarrow{\mathsf{Trans}}}}
```
```{=latex}
\newcommand{\attn}{\mathsf{ATTN}}
```
```{=latex}
\newcommand{\shift}{\mathsf{SHIFT}}
```
```{=latex}
\newcommand{\mha}{\mathsf{MHA}}
```
```{=latex}
\newcommand{\ff}{\mathsf{FF}}
```
```{=latex}
\newcommand{\NOT}{\mathsf{NOT}}
```
```{=latex}
\newcommand{\AND}{\mathsf{AND}}
```
```{=latex}
\newcommand{\OR}{\mathsf{OR}}
```
```{=latex}
\newcommand{\MAJORITY}{\mathsf{MAJORITY}}
```
```{=latex}
\newcommand{\bin}{\mathsf{bin}}
```
```{=latex}
\newcommand{\sbin}{\mathsf{sbin}}
```
```{=latex}
\newcommand{\pos}{\mathsf{pos}}
```
```{=latex}
\newcommand{\inp}{\mathsf{input}}
```
```{=latex}
\newcommand{\Adam}{\mathsf{Adam}}
```
```{=latex}
\newcommand{\Hop}{\mathsf{Hop}}
```
```{=latex}
\newcommand{\base}{{\sf{base}}}
```
```{=latex}
\newcommand{\hint}{{\sf{hint}}}
```
```{=latex}
\renewcommand{\cot}{{\sf{cot}}}
```
```{=latex}
\newcommand{\lossce}{\ell_{\mathsf{ce}}}
```
```{=latex}
\newcommand{\indct}[1]{\bm{1}\left[#1\right]}
```
```{=latex}
\newcommand{\True}{\mathsf{TRUE}}
```
```{=latex}
\newcommand{\False}{\mathsf{FALSE}}
```
```{=latex}
\newcommand{\rds}[1]{\left[#1\right]_s}
```
```{=latex}
\newcommand{\rdes}[1]{\left[#1\right]_{e,s}}
```
```{=latex}
\newcommand{\float}{\mathsf{float}}
```
```{=latex}
\newcommand{\sums}{\mathsf{sum}_s}
```
```{=latex}
\newcommand{\sumes}{\mathsf{sum}_{e,s}}
```
```{=latex}
\newcommand{\significand}{\mathrm{significand}}
```
```{=latex}
\newcommand{\exponent}{\mathrm{exponent}}
```
```{=latex}
\newcommand{\Lbar}{\bar{\mathcal L}}
```
```{=latex}
\newcommand{\wematrix}{M(\mathbf W)}
```
```{=latex}
\newcommand{\sharpness}{\mathcal{S}}
```
```{=latex}
\newcommand{\rada}{\mathcal{R}}
```
```{=latex}
\newcommand{\model}{f}
```
```{=latex}
\newcommand{\para}{\theta}
```
```{=latex}
\newcommand{\dataset}{\mathbf{D}}
```
```{=latex}
\newcommand{\distr}{\mathcal{D}}
```
```{=latex}
\newcommand{\feature}{x}
```
```{=latex}
\newcommand{\target}{y}
```
```{=latex}
\newcommand{\loss}{\ell}
```
```{=latex}
\newcommand{\Loss}{\mathcal{L}}
```
```{=latex}
\newcommand{\Trace}{\mathrm{Tr}}
```
```{=latex}
\newcommand{\ReLU}{\mathrm{ReLU}}
```
```{=latex}
\newcommand{\mseloss}{\ell_{\mathrm{mse}}}
```
```{=latex}
\newcommand{\logistic}{\ell_{\mathrm{logistic,p}}}
```
```{=latex}
\newcommand{\zeroone}{\ell_{\mathrm{01}}}
```
```{=latex}
\newcommand{\jacobian}{\mathcal{J}}
```
```{=latex}
\newcommand{\bn}{\mathrm{BN}}
```
```{=latex}
\newcommand{\simplebn}{\mathrm{BN}_{\mathrm{simple}}}
```
```{=latex}
\newcommand{\lan}{\mathrm{LN}}
```
```{=latex}
\newcommand{\simpleln}{\mathrm{LN}_{\mathrm{simple}}}
```
```{=latex}
\newcommand{\affine}{\mathcal{A}}
```
```{=latex}
\newcommand{\weight}{\mathcal{W}}
```
```{=latex}
\newcommand{\bias}{\mathbf{b}}
```
```{=latex}
\newcommand{\dimension}{\mathrm{D}}
```
```{=latex}
\newcommand{\Para}{\Theta}
```
```{=latex}
\newcommand{\sgn}{\mathrm{sign}}
```
```{=latex}
\newcommand{\Floating}{\mathbb{F}}
```
```{=latex}
\newcommand{\interleave}[2]{{#1}^\frown{#2}}
```
```{=latex}
\newcommand{\fix}{\marginpar{FIX}}
```
```{=latex}
\maketitle
```
```{=latex}
\iffalse 
```
```{=latex}
\fi
```
Introduction {#sec:intro}
============

```{=latex}
\iffalse
```
Outline

-   Scaling laws suggest that more parameters are needed to solve more and harder problems. Recent work (igsm) argues that not just params, but depth is very important for reasoning

-   In this work we argue further -- reasoning needs depth, but not necessarily many parameters. Intuitive you want to learn the right \`\`algorithm" to solve the problem, which can be concise

-   Show many synthetic examples highlighting this point. 3L-4loop is much closer to 12L than to 3L on reasoning problems.

-   Some theory? (b) parameter sharing can still solve reasoning problems of interesting: e.g. k-hop (and other problems like graph algorithms), (c) general simulation result with mild width expansion, (d) something about graph algorithms, RASP, (e) separation loop(n+1) vs loop(n), (a) CoT is one way to increase depth in autoregressive models. Loop can simulate CoT, so we inherit its expressivity.

-   Inspired by this, we pretrain LLMs with block looping. Performance is lagging on most tasks compared to non-looped, especially on perplexity, but the most improvements are on reasoning -- \`\`inductive bias"

-   Is there a way to utilize this inductive bias? Train with regularization on layers. better on everything, especially reasoning. Bonus: make it looped post-hoc

TODOs

-   Synthetic expts on **var-assignment, k-hop, addition, i-gsm**

-   Theory for (a), (b) and (c) at least

-   Real experiments

    -   (Running) Get all evals for current looping and regularization experiments and include in the paper

    -   (Running) Rerun experiments on PILE

```{=latex}
\fi
```
`\looseness`{=latex}-1Language models have shown a lot of promise in solving problems that require strong reasoning abilities like math, coding, common sense reasoning and logical puzzles [@brown2020language; @team2023gemini]. This has sparked interest in developing techniques to improve reasoning on harder problems [@wei2022chain] and has inspired theoretical studies on how Transformers are able to perform reasoning [@feng2024towards; @sanford2024understanding]. Reasoning abilities are often emergent in larger language models [@wei2022emergent] -- this aligns with various scaling laws [@kaplan2020scaling; @hoffmann2022training; @allen2024physics] that show that the performance of language models is very strongly dependent on the model size (i.e., number of parameters) and much lesser on other architectural design choices. However, recent works have started to question this view. @ye2024physics argue that scaling laws for reasoning are more subtle, and *depth is very important* in addition to parameter count -- at the same parameter count, deeper but shallower models are better. This is a deviation from the conventional scaling law wisdom, but it intuitively makes sense because reasoning problems often requires multi-step compositional thinking, and thus, depth can play a crucial role.

`\looseness`{=latex}-1In this work, we make a stronger claim -- while depth is important, many reasoning problems do not necessarily require a lot of parameters. How does one solve reasoning problems with large depth but few parameters? We argue that *looped models* are perfectly suited for this, where the same function, parameterized with few parameters, is iteratively applied on the input. This leads us to our first important claim:

*Claim 1: Many reasoning problems require depth but not necessarily parameters. That is, they can be solved via looped models*

Looped models have been studied in the literature for parameter efficiency [@lan2019albert], adaptive compute [@dehghani2018universal], equilibrium models [@bai2019deep] and for in-context learning [@yang2023looped; @gatmiry2024can]. In this work, we **initiate the study of looped models in the context of reasoning**. Admittedly, reasoning is not very well-defined and can be of various forms [@sun2023survey]. Acknowledging this hurdle, we focus on a non-exhaustive list of problems that intuitively require reasoning and that are inspired by reasoning benchmarks. Throughout the paper, we use the notation `\loopy{k}{L}`{=latex} to denote a $k$-layer model looped $L$ times (precise definition in `\Cref{sec:synthetic_reasoning}`{=latex}), which has the same number of parameters as a `\loopy{k}{1}`{=latex} model and same flops as a `\loopy{kL}{1}`{=latex} non-looped model (see `\Cref{fig:looping_illustration}`{=latex}). As a first step towards connecting looped models and reasoning, we empirically evaluate looped models on several simple reasoning tasks in the literature (e.g. `\Cref{sec:synthetic_reasoning}`{=latex}). Perhaps surprisingly, we find that a `\loopy{k}{L}`{=latex} looped models does almost as well as, if not better than, a non-looped model `\loopy{kL}{1}`{=latex} that has the same effective depth but $L$ times more parameters on these reasoning tasks. The looped model is also significantly better than a `\loopy{k}{1}`{=latex} model which has the same number of parameters. Our theoretical results on the expressiveness of looped models in representing iterative algorithms with short description further corroborate these empirical findings and provide strong support for our claim. This naturally raises an important question: do looped models benefit language modeling in a similar manner?

![`\looseness-1`{=latex}`\new{Illustration of the simple and architecture agnostic looping mechanism that we consider. A $k$-layer block looped $L$ times ({\bf middle}) is denoted by \loopy{k}{L}, which can essentially be viewed as a weighted shared model. The iso-param baseline, \loopy{k}{1}, is a $k$-layer model with the same number of \emph{distinct} parameters. The iso-FLOP baseline, \loopy{kL}{1}, is a $kL$-layer model with the same depth but $L$ times more parameters. Middle looping is a strategy that is inspired from prior works on model stacking (e.g. \citep{saunshi2024inductive}).}`{=latex}](Media/looping_illustration2.png){#fig:looping_illustration width="80%"}

```{=latex}
\hfill
```
*Claim 2: For language modeling, looped models have an inductive bias towards good reasoning despite having worse perplexity and memorization to an iso-flop non-looped model*

`\looseness`{=latex}-1For the above claim, we again train a `\loopy{k}{L}`{=latex} looped model on causal language modeling and compare it to the iso-param `\loopy{k}{1}`{=latex} and iso-flop `\loopy{kL}{1}`{=latex} non-looped baselines. While the looped model improves over the iso-param baseline, perhaps unsurprisingly, it ends up with worse perplexity than iso-flop baseline, since perplexity depends strongly on number of parameters. However, the downstream evaluations reveal an intriguing trend: looped models have a tendency to improve tasks that require reasoning a lot more than memorization tasks. Specifically, the looped model has reasoning performance much closer to the iso-flop baseline, sometimes even exceeding it despite having $L$ times fewer parameters and worse perplexity. This contrasting behavior between the pretraining and downstream metrics has been a subject of study lately [@saunshi22understanding; @liu2023same] and is attributed to the *inductive biases* introduced due to different architectures and training algorithms. Our empirical analysis also uncovers an interesting phenomenon: accuracy on downstream tasks scale as logarithm of the effective depth. In particular, more loops enhances performance, and the relative benefit of loops is higher for tasks that require more reasoning. This is conceptually similar to inference time scaling discovered for CoT, but with looping as a central component. To further elucidate this interesting relationship of looped models with CoT, we present the following claim.

*Claim 3: Looped models generate latent thoughts and can, in theory, simulate CoT reasoning*

Note that CoT reasoning gives the model more time and compute by generating multiple thought tokens before the answer, and it has powered the recent paradigm of inference-time scaling for \`\`thinking" models like O1 and DeepSeek's R1 [@guo2025deepseek]. We make an observation about CoT reasoning -- it is essentially a looped model that generates 1 thought token in each iteration. However, looped models seem to be much more powerful, since they can generate multiple *latent thoughts* in each iteration. We translate this intuition into a theoretical result about how looped models can simulate CoT reasoning.

```{=latex}
\iffalse 
```
As a first step towards connecting looped models and reasoning, we empirically evaluate looped models on simple (non-exhaustive) set of problems that have been used to study reasoning in the literature -- addition [@nye2021show; @nogueira2021investigating; @li2024chain], $p$-hop induction [@sanford2024transformers], variable assignment [@saunshi2024inductive] and math word problems [@wei2022chain; @ye2024physics]. For each of these reasoning problems, rather surprisingly, we find that a `\loopy{k}{L}`{=latex} looped models does almost as well as (or sometimes even better than) a non-looped model `\loopy{kL}{1}`{=latex} that has the same effective depth but $L$ times more parameters. Additionally, the looped model is significantly better than a `\loopy{k}{1}`{=latex} model which has the same number of parameters. This suggests that these problems require depth but not necessarily more parameters, and that looped models are very good for independently solving such simple reasoning problems. We support these findings with theoretical results on the expressiveness of looped models -- many such reasoning problems can be solved by iterative algorithms with a short description, and that such algorithms are implementable by looped models with few parameters.

These results are surprising given the prevalent notion that model size is important for good reasoning. `\new{What about problems that do not obviously have an iterative algorithm? We broaden the scope of looped models by establishing a connection to chain-of-thought (CoT) reasoning \citep{wei2022chain}.}`{=latex}

![`\looseness-1`{=latex}`\new{Illustration of the simple and architecture agnostic looping mechanism that we consider. A $k$-layer block looped $L$ times ({\bf middle}) is denoted by \loopy{k}{L}, which can essentially be viewed as a weighted shared model. The iso-param baseline, \loopy{k}{1}, is a $k$-layer model with the same number of \emph{distinct} parameters. The iso-FLOP baseline, \loopy{kL}{1}, is a $kL$-layer model with the same depth but $L$ times more parameters. Middle looping is a strategy that is inspired from prior works on model stacking (e.g. \citep{saunshi2024inductive}).}`{=latex}](Media/looping_illustration2.png){#fig:looping_illustration width="80%"}

```{=latex}
\hfill
```
*Claim 2: Looped models generate latent thoughts and can, in theory, simulate CoT reasoning*

```{=latex}
\new{CoT reasoning gives the model more time and compute by generating multiple thought tokens before the answer, and it has powered the recent paradigm of inference-time scaling for ``thinking'' models like O1 and DeepSeek's R1 \citep{guo2025deepseek}. We make an observation about CoT reasoning -- it is essentially a looped model that generates 1 thought token in each iteration. However, looped models seem to be much more powerful, since they can generate multiple ``latent thoughts'' in each iteration. 
We translate this intuition into a theoretical result about how looped models can simulate CoT reasoning.
Given this backdrop, a natural question arises: are looped models relevant for language modeling?}
```
*Claim 3: For language modeling, looped models have an inductive bias towards good reasoning despite having worse perplexity*

For the above claim, we again train a `\loopy{k}{L}`{=latex} looped model on causal language modeling and compare it to the iso-param `\loopy{k}{1}`{=latex} and iso-flop `\loopy{kL}{1}`{=latex} non-looped baselines. While the looped model improves over the iso-param baseline, perhaps unsurprisingly, it ends up with worse perplexity than iso-flop baseline, since perplexity depends strongly on number of parameters. However, the downstream evaluations reveal an intriguing trend: looped models have a tendency to improve tasks that require reasoning a lot more than memorization tasks. Specifically, the looped model has reasoning performance much closer to the iso-flop baseline, sometimes even exceeding it despite having $L$ times fewer parameters and worse perplexity. This contrasting behavior between the pretraining and downstream metrics has been a subject of study lately [@saunshi22understanding; @liu2023same] and is attributed to the *inductive biases* introduced due to different architectures and training algorithms. `\ns{Add something about scaling behavior}`{=latex}

```{=latex}
\fi
```
Motivated by these findings, we propose a regularization scheme that aims to tap into the inductive bias of looped models towards reasoning. This leads us to our final claim:

*Claim 4: Looping-inspired regularization can leverage this inductive bias towards better reasoning*

With the backdrop of these claims, we concretely present the contributions of the paper below:

-   In this paper we study looped models -- multilayer models with weight sharing -- and their role in reasoning. In particular, we compare a $k$-layer model looped $L$ times, denoted by `\loopy{k}{L}`{=latex}, with an *iso-param* `\loopy{k}{1}`{=latex} non-looped model with $k$ layers and an *iso-flop* `\loopy{kL}{1}`{=latex} model with $kL$ layers and $L$ times more parameters.

-   We conduct experiments on synthetic reasoning tasks like addition, $p$-hop induction and GSM-style math word problems in `\Cref{sec:synthetic_reasoning}`{=latex}. For these tasks, we surprisingly find that iso-flop looped models, despite having way fewer parameters, can nearly match or outperform a non-looped model. Supporting these experiments, in `\Cref{sec:theory}`{=latex} we present theoretical results for why looped models can solve such problems with almost optimal depth.

-   In `\Cref{sec:language_modeling}`{=latex}, we train looped models on causal language modeling at 1B parameter scale. Here, we show that looped models have an inductive bias towards doing well on reasoning benchmarks, despite having much worse perplexity. This finding is novel, since most prior work on looping focused more on perplexity metrics rather than downstream reasoning tasks. We validate this inductive bias by visualizing the perplexity vs downstream performance plots as training process. `\new{Additionally, we show that looped models demonstrate good scaling behavior on various benchmarks as the number of loops are increased, akin to CoT reasoning.  
        Finally, we show that looped models, along with a scratchpad, can simulate chain-of-thought reasoning.}`{=latex}

-   Inspired by this inductive bias, in `\Cref{sec:regularization}`{=latex}, we propose a regularization that encourages layers to be more similar to each other. We find that training with such a regularization inherits the inductive bias of looped models towards reasoning without affecting perplexity.

Looped models on simple reasoning tasks {#sec:synthetic_reasoning}
=======================================

We first explore our hypothesis of looped models helping reasoning tasks on a set of tasks constructed in a procedural manner. The illustrative reasoning tasks we consider are: $n$-ary addition, $p$-hop induction head that tests the model's ability to track back for $p$ steps, and i-GSM which consists of synthetically constructed grade-school math problems. While these obviously do not cover the whole spectrum of reasoning problems, they provide useful insights into looped models and provide a basis for the theoretical results in `\Cref{sec:theory}`{=latex}.

**Looped models.** While many variants of looped model have been proposed [@lan2019albert; @dehghani2018universal; @giannou2023looped; @yang2023looped; @mohtashami2023cotformer], we use the vanilla version for simplicity of our exploration. For any sequence-to-sequence function $f$, we denote $f^{(L)} = f \circ f \dots \circ f$ to be the function that is $f$ looped $L$ times. In general, the looping mechanism is independent of architectural choices for $f$. For the rest of the paper, we typically use $f$ to denote a $k$-layer Transformer backbone of a model. Thus, $f$ looped $L$ times is the same as a $kL$ layer model with weight sharing between all $L$ blocks of $k$ consecutive layers. We denote such looped models with the notation `\loopy{k}{L}`{=latex}. Please refer to `\Cref{fig:looping_illustration}`{=latex} for a succinct illustration of the looping mechanism. `\Cref{subsec:prelim}`{=latex} provides a more formal definition of looped transformers that is used for theoretical analysis.

Experiments with simple reasoning problems {#sec:addition}
------------------------------------------

**$n$-ary addition.** `\label{sec:addition}`{=latex} We consider the problem of adding $n$ numbers with 3 digits each. Addition is popular in the literature to study aspects of reasoning with Transformers, such as use of scratchpad [@nye2021show], chain of thought reasoning [@lee2024teaching; @li2024chain] and length generalization [@cho2024position]. One reason for its popularity is that addition can have algorithmic solutions, which is a feature of many reasoning problems. For our experiments, we train on a uniform mixture on numbers of operands $n\in\{2, 4, 8, 16, 32\}$ and sample each 3-digit operand uniformly at random between $[0, 999]$. We train all models directly on the input-output pair, without any chain-of-thought steps. Following is an example for $n=4$:

Input: \`\`$315 + 120 + 045 + 824 =$" ; Output = \`\`$1304$".

We train a standard Transformer-based baseline `\loopy{12}{1}`{=latex} model with 12 layers. Please refer to `\Cref{sec:apx_addition}`{=latex} for details on the training setup. We also train `\loopy{k}{12/k}`{=latex} looped model and an iso-param `\loopy{k}{1}`{=latex} baseline models for comparison, and vary $k \in \{2, 3, 4, 6\}$. All trained models are finally evaluated separately on each of $n\in\{8, 16, 24, 32\}$ to measure accuracy on increasingly difficult problems. Results are presented in `\Cref{table:addition_phop_results}`{=latex}.

```{=latex}
\scalebox{0.77}{
\begin{tabular}{lc|cccc}
\toprule
\textbf{Addition of $n$ numbers}\\
\midrule
 & Params /  & $n=8$ & $n=16$ & $n=24$ & $n=32$\\
 & FLOPs   &  &  & & \\
\midrule
\midrule
Base \loopy{12}{1} & 12x / 12x & 100.0 & 100.0 & 100.0 & 100.0\\
\hline
\hline
\rowcolor{lightgray}
1 layer model  \\
\hline
Base \loopy{1}{1} & 1x / 1x & 0.1 & 0.1 & 0.1 & 0.0 \\
Loop \loopy{1}{12} & 1x / 12x & 99.9 & 100.0 & 99.9 & 99.6\\
\hline
\hline
\rowcolor{lightgray}
2 layer model  \\
\hline
Base \loopy{2}{1} & 2x / 2x & 85.8 & 71.5 & 49.3 & 38.8 \\
Loop \loopy{2}{6} & 2x / 12x & 100.0 & 99.8 & 99.7 & 99.5\\
\hline
\hline
\rowcolor{lightgray}
3 layer model  \\
\hline
Base \loopy{3}{1} & 3x / 3x & 97.2 & 78.5 & 69.2 & 60.7 \\
Loop \loopy{3}{4} & 3x / 12x & 100.0 & 99.1 & 97.0 & 96.6 \\
\hline
\end{tabular}
}
```
```{=latex}
\scalebox{0.77}{
\begin{tabular}{lc|cc}
\toprule
\textbf{$p$-hop with $n$ tokens}\\
\midrule
 & Params / & $p=16$ & $p=32$ \\
  & FLOPs & $n=256$ & $n=256$ \\
\midrule
\midrule
Base \loopy{6}{1} & 6x / 6x & \textbf{99.9} & \textbf{99.6} \\
\hline
\hline
\rowcolor{lightgray}
1 layer model  \\
\hline
Base \loopy{1}{1} & 1x / 1x &  48.9 & 49.0\\
Loop \loopy{1}{6} & 1x / 6x & \textbf{99.9} & \textbf{99.5}\\
\hline
\hline
\rowcolor{lightgray}
2 layer model  \\
\hline
Base \loopy{2}{1} & 2x / 2x & 68.8 & 59.4\\
Loop \loopy{2}{3} & 2x / 6x & \textbf{99.9} & \textbf{99.8}\\
\hline
\hline
\rowcolor{lightgray}
3 layer model  \\
\hline
Base \loopy{3}{1} & 3x / 3x & 97.2 & 73.0\\
Loop \loopy{3}{2} & 3x / 6x & \textbf{99.9} & \textbf{99.5}\\
\hline
\end{tabular}
}
```
We find that, while the shallower baselines `\loopy{k}{1}`{=latex} degrade with lower $k$, the looped model `\loopy{k}{12/k}`{=latex} performs very well, and nearly matches the iso-flop `\loopy{12}{1}`{=latex} baseline. In fact, even a 1-layer network looped 12 times is able to solve this, despite using merely $1/12^{th}$ of the parameters of the baseline. This suggests the addition problem primarily requires depth, but not necessarily more parameters.

#### $p$-hop induction. {#sec:khop_desc}

`\looseness`{=latex}-1The $p$-hop problem is a synthetic induction task studied in [@sanford2024transformers], who were inspired by the analysis of induction heads from [@elhage2021mathematical]. Specifically, given a sequence of letters $\vv = (v_1 \ldots v_n)$ from an alphabet $\Sigma$, an induction head tries to find the penultimate occurrence of $v_n$ (from left to right) in the sequence and output the character immediately succeeding it. The $p$-hop problem generalizes this idea to sequentially hop $p$ times. Intuitively, the $p$-hop problem tests a model's ability to recursively backtrack and retrieve the answer for a given query. This is reminiscent of the reasoning abilities required to solve reading comprehension kind of problems. We present the formal definition of the $p$-hop problem in `\Cref{defi:khop}`{=latex}. We perform experiments with looped and non-looped models on the $p$-hop problem, with alphabet size set to 4 and sequences of length $256$, and vary $p$ between 16 and 32 to control the problem difficulty. Our observations are presented in Table `\ref{table:addition_phop_results}`{=latex}. Similarly to our findings on the addition task, reasonably deep looped models perform as well as the baseline using much fewer parameters.

```{=latex}
\iffalse
```
::: {#table:khop-results-old}
            **Setting**            **Accuracy (%)**
  ------------------------------- ------------------
              2 layer                   58.14
             10 layer                   99.69
   Looping with 2 layer, 5 loops      **99.94**

  : Comparison of looped vs non-looped models for the $p$-hop problem.
:::

```{=latex}
\fi
```
#### i-GSM (Synthetic Grade School Math Problems).

Inspired by  [@ye2024physics], we built our own version of grade-school level math word problems. While we follow many of the design guidelines of [@ye2024physics], we make a few simplifying changes. We generate the math problem as a DAG of arithmetic computations modulo 7, and restrict the depth of the graph to 4. For simplicity, we retain the problems in the symbolic form and do not map them to English (e.g., see Table `\ref{fig:sample-eigsm-problem}`{=latex}) [^1] We train models of depths 1, 2, 4 and 8 and compare them with different looped variants in Table `\ref{table:igsm-results}`{=latex}. The answer is computed modulo 7. Hence, a random guessing baseline would get at least 14% accuracy. `\Cref{sec:apx_igsm}`{=latex} has more details. Remarkably, we again observe that a depth $k$ model looped $L$ times matches or outperforms a depth $kL$ model and far outperforms a depth $k$ non-looped model. This suggests that even a more complex and realistic looking reasoning problem does not need too many parameters and can benefit from depth via looping.

```{=latex}
\scalebox{0.8}{
    
        
        \small
        \parbox{\linewidth}{
        \paragraph{Question.} \textit{E\#I := 4. E\#J := E\#I. K\#N := I\#N + J\#O + F\#K. F\#K := E\#J. J\#O := F\#K + K\#O + E\#J. H\#J := E\#J + F\#K. I\#P := L\#M + I\#N + K\#O. I\#M := J\#O + J\#P + F\#K. J\#P := H\#J - F\#K. L\#M := I\#N + J\#P + F\#K. I\#N := 2 * J\#P + H\#J + E\#I. K\#O := J\#P + I\#N + E\#J. I\#P?}\\
        \paragraph{Answer with CoT.} \textit{E\#I = 4. $\implies$ E\#I = 4. E\#J = E\#I. $\implies$ E\#J = 4. F\#K = E\#J. $\implies$ F\#K = 4. H\#J = E\#J+F\#K. $\implies$ H\#J = 1. J\#P = H\#J-F\#K. $\implies$ J\#P = 4. I\#N = 2J\#P+2H\#J+2E\#I. $\implies$ I\#N = 4. L\#M = I\#N+J\#P+F\#K. $\implies$ L\#M = 5. K\#O = J\#P+I\#N+E\#J. $\implies$ K\#O = 5. I\#P = L\#M+I\#N+K\#O. $\implies$ I\#P = 0.}
        }
        % \normalsize
        % \caption{Symbolic i-GSM problem and its solution.}
        \label{fig:sample-eigsm-problem}
    %
    }
```
```{=latex}
\scalebox{0.8}{
\begin{tabular}{lc|c}
\toprule
 & Params / FLOPs & Accuracy \\
\midrule
\midrule
Base \loopy{8}{1} & 8x / 8x & \textbf{73.2} \\
\hline
\hline
\rowcolor{lightgray}
1 layer model  \\
\hline
Base \loopy{1}{1} & 1x / 1x &  24.5 \\
Loop \loopy{1}{2} & 1x / 2x & 52.3\\
Loop \loopy{1}{4} & 1x / 4x & 69.9\\
Loop \loopy{1}{8} & 1x / 8x & \textbf{73.2}\\
\hline
\hline
\rowcolor{lightgray}
2 layer model  \\
\hline
Base \loopy{2}{1} & 2x / 2x & 54.0 \\
Loop \loopy{2}{2} & 2x / 4x & 66.9 \\
Loop \loopy{2}{4} & 2x / 8x & \textbf{73.6} \\
\hline
\hline
\rowcolor{lightgray}
4 layer model  \\
\hline
Base \loopy{4}{1} & 4x / 4x & 71.3 \\
Loop \loopy{4}{2} & 4x / 8x & \textbf{71.6} \\
\hline
\end{tabular}
}
```
`\label{table:igsm-results}`{=latex}

Language modeling with looped models {#sec:language_modeling}
====================================

In this section, we pretrain and evaluate looped models for causal language models. We train models on 250B tokens of the Pile dataset [@gao2020pile] and use a 24-layer 1B parameter model for most experiments, motivated by the setting in @tay2022ul2 (refer to `\Cref{sec:apx_language_modeling}`{=latex} for more details).

Experiments with 1B language modeling {#sec:expt_1B}
-------------------------------------

For causal language modeling, we pretrain various looped models on the standard GPT2-style next token prediction objective [@radford2019language]. We train models with different parameter budgets to make sure that the findings are robust. We remind the reader that the notation `\loopy{k}{L}`{=latex} corresponds to a $k$ layer model looped $L$ times. For each setting, we compare 3 models: **(a)** `\loopy{24}{1}`{=latex}: 24-layer 1B model, **(b)** `\loopy{k}{1}`{=latex}: $k$-layer model with the same configuration as the 24-layer model for other dimensions, **(c)** `\loopy{k}{24/k}`{=latex}: $k$-layer model looped $24/k$ times to match the parameter count of (b) and match the effective depth/FLOPs of (a). We run experiments for $k\in\{4, 6, 8, 12\}$ to ensure that the findings are robust. After pretraining on Pile, we evaluate the models on validation perplexity and on downstream benchmarks using $k$-shot evaluations. Results are summarized in `\Cref{table:language_modeling_results}`{=latex}

```{=latex}
\scalebox{0.7}{
% \begin{tabular}{lll|r|rrr|r}
\begin{tabular}{lcc|cccc|c}
\toprule
 & Params / & Perplexity ($\downarrow$) & Closed & Open   & Math Word & All Tasks & Reasoning  \\
  & FLOPs & \tiny{(validation)} & Book QA ($\uparrow$) & Book QA ($\uparrow$)  &  Problems ($\uparrow$) & Average ($\uparrow$) & Primitives ($\uparrow$) \\
  & &  & \tiny{(4 tasks)} & \tiny{(5 tasks)}  &  \tiny{(6 tasks)} & \tiny{(15 tasks)}  & \tiny{(4 tasks)}  \\
\midrule
\hline
\rowcolor{lightgray}
24 layers  \\
\hline
Baseline & 24x / 24x & 7.40 & 11.2 & 33.9  & 29.3 & 26.0 & 47.5 \\
\hline
\hline
\rowcolor{lightgray}
12 layers  \\
\hline
Base \loopy{12}{1} & 12x / 12x & 8.16 & 8.2 & 26.9  & 26.7 & 21.8 & 35.7 \\
Loop \loopy{12}{2} & 12x / 24x & 7.90 & 9.3 & 30.8  & 34.3 & 26.5 & 51.2\\
% \cline{1-3}
\hline
\% Gap &  & \ApplyGradient{34}\% & \ApplyGradient{37}\% & \ApplyGradient{56}\%  & \ApplyGradient{282}\% & \ApplyGradient{110}\% & \ApplyGradient{131}\%\\
%%
%%
\hline
% \new{
Middle Loop & 12x / 24x & 7.81 & 11.0 & 32.3  & 28.3 & 25.0 & 56.5\\
\loopy{4}{1,4,1} & &  &  &   &  &  & \\

% \cline{1-3}
\hline
\% Gap &  & \ApplyGradient{46}\% & \ApplyGradient{94}\% & \ApplyGradient{78}\%  & \ApplyGradient{62}\% & \ApplyGradient{95}\% & \ApplyGradient{176}\%\\
%%
%%
% \% Gap &  & 34\% & 37\% & 56\%  & 282\% & 110\% & 131\%\\
% }
\hline
\hline
\rowcolor{lightgray}
8 layers  \\
\hline
Base \loopy{8}{1} & 8x / 8x & 8.75 & 6.3 & 22.7 & 17.1 & 16.1 & 33.0 \\
Loop \loopy{8}{3} & 8x / 24x & 8.19 & 8.5 & 30.8 & 28.4 & 23.9 & 55.3 \\
\% Gap &  & \ApplyGradient{41}\% & \ApplyGradient{44}\% & \ApplyGradient{72}\% & \ApplyGradient{92}\% & \ApplyGradient{78}\% & \ApplyGradient{153}\% \\
% \% Gap &  & 41\% & 44\% & 72\% & 92\% & 78\% & 153\% \\
\hline
\hline
\rowcolor{lightgray}
6 layers  \\
\hline
Base \loopy{6}{1} & 6x / 6x & 9.25 & 4.0 & 19.3 & 17.7 & 14.6 & 24.1 \\
Loop \loopy{6}{4} & 6x / 24x & 8.42 & 8.2 & 28.7 & 29.8 & 23.7 & 56.1 \\
\% Gap &  & \ApplyGradient{44}\% & \ApplyGradient{58}\% & \ApplyGradient{64}\% & \ApplyGradient{104}\% & \ApplyGradient{80}\% & \ApplyGradient{136}\% \\
% \% Gap &  & 44\% & 58\% & 64\% & 104\% & 80\% & 136\% \\
\hline
\hline
\rowcolor{lightgray}
4 layers  \\
\hline
Base \loopy{4}{1} & 4x / 4x & 10.12 & 1.8 & 13.8 & 9.7 & 9.0 & 19.4 \\
Loop \loopy{4}{6} & 4x / 24x & 8.79 & 6.7 & 26.2 & 24.8 & 20.4 & 56.9 \\
\% Gap &  & \ApplyGradient{48}\% & \ApplyGradient{52}\% & \ApplyGradient{61}\% & \ApplyGradient{77}\% & \ApplyGradient{67}\% & \ApplyGradient{133}\% \\
% \% Gap &  & 48\% & 52\% & 61\% & 77\% & 67\% & 133\% \\
\hline
\end{tabular}
}
```
`\looseness-1`{=latex}**Evaluation metrics.** We evaluate the models on perplexity metric after training is completed. Since there is growing evidence that perplexity, although very useful for training, is a narrow measure of model quality, we also track more holistic downstream evaluations [@liang2023holistic]. Thus, we evaluate the model on 4 important slices: closed book QA, open book QA, math word problems and reasoning primitives. These comprise of **19 different tasks** in total.

-   **Closed book QA**: This includes tasks like TriviaQA [@joshi2017triviaqa], TydiQA-NoContext [@clark2020tydi], Natural Questions [@kwiatkowski2019natural] and Web Questions [@talmor2018web] that test the model's ability to answer questions without any context, and thus, primarily measure the memorization abilities of language models.

-   **Open book QA**: This includes tasks like TydiQA-GoldP [@clark2020tydi], SquadV2 [@rajpurkar2018know], Drop [@dua2019drop], QuAC [@choi2018quac], CoQA [@reddy2019coqa] that evaluate the model's ability to infer the answer to a question from the extra context that is provided, akin to reading comprehension.

-   **Math word problems**: To evaluate the model's ability to reason, we test them on math word problems considered in [@wei2022chain]. This includes tasks like SVAMP [@patel2021nlp], ASDiv [@miao2020diverse], the MAWPS benchmark [@koncel2016mawps]. We report 5-shot evaluation for the pretrained model on these tasks.

-   **Reasoning primitives**: @saunshi2024inductive introduced these datasets to study the inductive bias of stacking towards improving reasoning, by isolating simple reasoning abilities. One such primitive is depth-$k$ variable assignment that requires the model to resolve a chain of assignments of length $k$. An example of depth-0 var-assign is *a=1, b=2, c=6, b=?*, and example for depth-1 var-assign is *a=1, b=2, c=a, d=b, d=?*. We evaluate on the math and coding variants of the depth-0 and depth-1 problems using 5-shot evaluation.

For each task group $G$ from above, in `\Cref{table:language_modeling_results}`{=latex} we report the average accuracy for that task group, denoted by $\text{Avg}_{G}$. Furthermore, for each layer budget $k$, we report the **% gap** between the iso-param and iso-flop models that is covered by the looped model. More specifically $$\begin{aligned}
    \text{\% Gap} = \frac{\text{Avg}_{G}\text{\loopy{k}{24/k}} - \text{Avg}_{G}\text{\loopy{k}{1}}}{\text{Avg}_{G}\text{\loopy{24}{1}} - \text{Avg}_{G}\text{\loopy{k}{1}}}\label{eq:gap_pct}.\end{aligned}$$ `\looseness`{=latex}-1This measures how effectively looping can bridge the gap between iso-param and iso-flops baselines. Implicitly, it measures how different task groups behave when a model only has a few parameters but is given depth through looping. `\text{\% Gap}`{=latex} being closer to 0% means that providing depth via looping does not benefit the task group, and number of parameters is the most important factor. On the other hand, `\text{\% Gap}`{=latex} closer to 100% means parameter count matters much less for the task group, and that depth via looping is more essential.

**Perplexity results.** Firstly we notice that all `\loopy{k}{24/k}`{=latex} looped models have better perplexity compared to the iso-param `\loopy{k}{1}`{=latex} baseline, but worse perplexity compared to the non-looped 24-layer baseline. The looped models only covers up roughly $34-50\%$ of the perplexity gap between the iso-param and iso-flop baselines for various values of $k$. This perplexity gap is not too surprising since the looped model has $24/k$ times fewer parameters, and thus, lower capacity than the 24-layer baseline. This was also been observed in prior works [@lan2019albert; @mohtashami2023cotformer] and is the primary reason looped models have been ignored for language modeling. However, as we shall see shortly, the downstream metrics paint a more interesting and favorable picture.

**Results on QA tasks.** We first consider closed book and open book QA categories in `\Cref{table:language_modeling_results}`{=latex}. Closed book QA tasks are primarily testing the model's memorization abilities. Open book QA on the other hand tests the model's ability to infer the answer from the additional context that is provided. Thus, intuitively, open book QA tasks require more reasoning. Firstly we notice that the % Gap for closed book QA (memorization) is very similar to % Gap for perplexity. The % Gap for open book QA, on the other hand, is much higher for all parameter budgets. This suggests that looped models relatively improve contextual QA much more than memorization based QA.

```{=latex}
\begin{figure*}[!tbp]

    \begin{subfigure}{0.23\textwidth}
        \includegraphics[width=\textwidth]{Media/12L_closedQA.png}
    % \caption{}
    \label{fig:12L_closedQA}
    \end{subfigure}\hfill
    
    \begin{subfigure}{0.23\textwidth}
    
    \includegraphics[width=\textwidth]{Media/12L_openQA.png}
    % \caption{}
    \label{fig:12L_openQA}
    \end{subfigure}\hfill

    \begin{subfigure}{0.23\textwidth}
        \includegraphics[width=\textwidth]{Media/12L_mwp.png}
    % \caption{}
    \label{fig:12L_mwp}
    \end{subfigure}\hfill

    \begin{subfigure}{0.23\textwidth}
        \includegraphics[width=\textwidth]{Media/12L_primitives.png}
    % \caption{}
    \label{fig:12L_primitives}
    \end{subfigure}\hfill
    
    \begin{subfigure}{0.24\textwidth}
        \includegraphics[width=\textwidth]{Media/24L_closedQA.png}
    % \caption{}
    \label{fig:24L_closedQA}
    \end{subfigure}\hfill
    
    \begin{subfigure}{0.23\textwidth}
    
    \includegraphics[width=\textwidth]{Media/24L_openQA.png}
    % \caption{}
    \label{fig:24L_openQA}
    \end{subfigure}\hfill

    \begin{subfigure}{0.23\textwidth}
        \includegraphics[width=\textwidth]{Media/24L_mwp.png}
    % \caption{}
    \label{fig:24L_mwp}
    \end{subfigure}\hfill

    \begin{subfigure}{0.23\textwidth}
        \includegraphics[width=\textwidth]{Media/24L_primitives.png}
    % \caption{}
    \label{fig:24L_primitives}
    \end{subfigure}
    
    \caption{\looseness-1Downstream evaluation for various task groups on the x-axis, vs validation log perplexity on the y-axis (reversed), as training proceeds. The top plots compare a 12-layer baseline model \loopy{12}{1} and the looped model \loopy{12}{2}.
    The second row compares the baseline 24-layer model and the 24-layer model trained with regularization using block size $k=4$ and $\lreg=10$ (See \Cref{eq:regularization}). For both comparisons we have similar observations. For closed book QA (memorization) tasks looping has very similar trends to baseline. For open book QA tasks and math word problems, looping has much better downstream performance at an equivalent log perplexity. This verifies the inductive bias of looping and regularization towards better reasoning abilities.}
    \label{fig:iso_plots_taskgroups_12L}
\end{figure*}
```
**Math problems and reasoning primitves.** We also present the % Gap for the math word problem in `\Cref{table:language_modeling_results}`{=latex}. Surprisingly, we find that `\loopy{k}{24/k}`{=latex} looped model can almost match the baseline `\loopy{24}{1}`{=latex} model for $k\ge6$, despite having $k$ times fewer parameters. In fact, the 12 layer model looped twice is even signficantly better ($34.3$) than the 24 layer baseline ($29.3$), despite having 50% of parameters and worse perplexity; suggesting that looping disproportionately improves mathematical reasoning.

To better understand the effect on reasoning, we direct the readers attention to the evaluations for reasoning primitives in `\Cref{table:language_modeling_results}`{=latex}. The results are quite remarkable: **`\loopy{k}{24/k}`{=latex} looped models are better than the iso-flop baseline `\loopy{24}{1}`{=latex} at reasoning primitives, for all values of $k$**. This is a priori very surprising, since these are synthetic generated tasks and have nothing to do with the pretraining data or the model architecture. Thus, solving these tasks necessarily requires reasoning from context, and memorization abilities will not help here. These results clearly suggest that looped models have a bias towards improving reasoning, despite having worse perplexity and memorization abilities. Next, we formalize the *inductive bias* towards reasoning via isoplots.

Inductive bias towards reasoning {#sec:inductive_bias}
--------------------------------

In this section, we formalize the inductive bias by plotting the perplexity vs downstream metric iso-plots, as introduced in @saunshi22understanding. `\Cref{sec:expt_1B}`{=latex} showed that looped models have *higher than expected* performance on reasoning problems. However, since looped models are worse on perplexity, it is hard to make a direct comparison between various models. One way to bring parity between models is to look at their downstream performances at the same validation pretraining loss [@liu2023same]. @saunshi2024inductive proposed plotting pretraining loss vs downstream metrics as training proceeds, as a way to study the inductive bias of various methods. For each model, we evaluate the log perplexity and downstream metrics at every 20k steps, starting from 120k steps. We plot these values in a scatter plot and fit a linear function with log perplexity and the corresponding downstream metric being input and output respectively. Please refer to `\Cref{fig:iso_plots_taskgroups_12L,fig:iso_plots_taskgroups_8L}`{=latex} for two sets of isoplots.

**Findings.** For all values of $k$, we observe the following:

-   The isoplots for `\loopy{k}{L}`{=latex} looped model and `\loopy{k}{1}`{=latex} baseline are very aligned for closed book QA tasks (if extrapolated). This suggests that log perplexity is a very strong indicator of downstream performance on memorization based tasks.

-   For open book QA and math word problems, the isoplot line for the looped model is always higher than the baseline model. This suggests that at the same log perplexity, looped models will tend to have higher evaluation on these tasks that require more reasoning.

-   For reasoning primitives, there is a stark difference between looped and baseline models. The looped model seems to have good performance at most points in training.

Overall this suggests a strong inductive bias of looped models towards improving reasoning.

```{=latex}
\new{

\subsection{Middle looping variant and relationship with Gradual stacking}
\label{sec:stacking_middle}

\looseness-1Recently \citet{saunshi2024inductive} introduced a gradual stacking \citep{gong2019efficient,reddi2023efficient} approach for training language models called MidAS. This approach gradually grows the model depth as training proceeds by duplicating certain layers of the model in each stacking operation. Surprisingly, they found that MidAS not only speeds up pretraining, but also improves reasoning in the same sense as \Cref{fig:iso_plots_taskgroups_12L} -- better reasoning at the same perplexity.
Furthermore, the paper established a strong connection between stacking via MidAS and looped models, owing to the layer duplication operation, and conjectured that this is the reason for such an inductive bias.
Our results from the previous section provides a compelling evidence for this conjecture by showing that looped models also show a very similar inductive bias, thus, further strengthening the connection between stacking and looped models.
Why such an inductive bias occurs is still an open question, and we believe that understanding this is an important future direction.


Furthermore, inspired by their findings, we explore {\bf middle looping} (see \Cref{fig:looping_illustration} for an illustration) --- a variant of looping which maintains independent layers at the start and the end of the network, and perform looping on the middle block of layers. 
The high-level intuition from \citet{saunshi2024inductive} is that the first and last layers play a special role in the model and thus, should be treated differently from the middle layers.
In \Cref{table:language_modeling_results}, we report results for a version of middle looping that is iso-param with a \loopy{12}{1} baseline and iso-flop with a \loopy{24}{1} baseline, just like the \loopy{12}{2} model.
Overall, we find that middle looping has better perplexity and more uniform improvements than the default looping of \loopy{12}{2} (except for math word problems), and thus, might be a more practical looping approach.
We leave the exploration of the best looping strategies for future work.

% \ns{Present basic results}



\subsection{Scaling behavior of looping}
\label{sec:depth_scaling}



\begin{figure*}[!tbp]% [H] is so declass\'e!

    \begin{subfigure}{0.33\textwidth}
        \includegraphics[width=0.85\textwidth]{Media/depth_scaling_eval_perplexity.png}
    % \caption{}
    \label{fig:depth_scaling_eval_perplexity}
    \end{subfigure}\hfill
    
    \begin{subfigure}{0.33\textwidth}
        \includegraphics[width=0.85\textwidth]{Media/depth_scaling_closedQA.png}
    % \caption{}
    \label{fig:depth_scaling_closedQA}
    \end{subfigure}\hfill
    
    \begin{subfigure}{0.33\textwidth}
    
    \includegraphics[width=0.85\textwidth]{Media/depth_scaling_openQA.png}
    % \caption{}
    \label{fig:depth_scaling_openQA}
    \end{subfigure}\hfill

    \begin{subfigure}{0.33\textwidth}
        \includegraphics[width=0.85\textwidth]{Media/depth_scaling_eval_token_acc.png}
    % \caption{}
    \label{fig:depth_scaling_eval_token_acc}
    \end{subfigure}\hfill

    \begin{subfigure}{0.33\textwidth}
        \includegraphics[width=0.9\textwidth]{Media/depth_scaling_mwp.png}
    % \caption{}
    \label{fig:depth_scaling_mwp}
    \end{subfigure}\hfill

    \begin{subfigure}{0.33\textwidth}
        \includegraphics[width=0.9\textwidth]{Media/depth_scaling_primitives.png}
    % \caption{}
    \label{fig:depth_scaling_primitives}
    \end{subfigure}
    
    \caption{\looseness-1\new{Scaling behavior for various task group as the effective depth increases. The blue curve shows how performance scales as the number of loops increases, without increasing parameters, using models of the form \loopy{4}{D/4} for various values of $D$. The orange curve visualizes the scaling behavior of \loopy{D}{1} which increases the depth by adding fresh parameters. For reasoning primitives, the looped model scales as well, or even better, than the baseline despite having $D/4$ fewer parameters.}}
    \label{fig:depth_scaling}
\end{figure*}



In this section, we discuss an intriguing scaling behavior of looping, specifically the impact of the number of loops on various evaluation metrics. In particular, we are interested in scaling behavior of: (a) accuracy as a function of number of loops and (b) comparison of looped and non-looped baseline of the same effective depth.
To this end, we pretrain various looped language models of the form \loopy{4}{L}, i.e., 4-layer model looped $L$ times, for $L \in \{1, 2, 3, 6, 9, 12\}$.
To enable iso-FLOPs comparison, we also train baseline models of the form \loopy{4L}{1}, which has $L$ times more parameters.
For each task group, we plot the average accuracy as a function of the \emph{effective depth}, i.e. $D = 4L$.
From the results presented in \Cref{fig:depth_scaling}, we observe the following.
% For both the looped and baseline models, we find that the accuracy can be well approximated by a linear function of the log of the depth.
\begin{enumerate}[leftmargin=.15in, itemsep=1ex]
    \item In both cases, we find that the accuracies for all task groups continue to increase with more loops/depth, although, unsurprisingly, the returns are diminishing with depth for both looped and non-looped models. Interestingly, for both looped and non-looped models, we found that one can fit a simple scaling law of the following form:
\begin{align}
    \text{Acc} = \alpha \log(D) + \beta,
    \label{eq:depth_scaling_law}
\end{align}
where $D$ is the effective depth and $\alpha$ measures the impact of depth on downstream performance.
\item Furthermore, for each task group, we compute $\alpha_{loop}/\alpha_{base}$ to assess the relative impact of ``depth via looping'' compared to ``depth via additional parameters''.
We find that more loops continue to help, and the relative benefit of loops is higher for reasoning tasks like open book QA and math problems.
Remarkably, the impact of loops is even higher (1.19x) than impact of depth for reasoning primitives, which further consolidates the benefit of looped models for reasoning.

\end{enumerate}

% A natural question to ask is -- why do looped models show a nice linear scaling?


\paragraph{Latent thoughts and connections to CoT reasoning.}
\label{sec:latent_thoughts}
We end this discussion with an important question: \emph{why should we expect this interesting scaling behavior of looped models?} We argue that looped models have a strong connection to chain-of-thought (CoT) reasoning. Such a connection is insightful because recent works on thinking have shown CoT can demonstrate inference-time scaling behavior, where the accuracy on reasoning datasets continues to improve as the length of the models chain of thought increases; i.e., generating more thoughts leads to better reasoning. 

To establish this connection, we make a simple observation about CoT reasoning -- it is essentially a looped model that generates a single thought token in each iteration.
However, looped models can be more powerful since they can generate multiple ``latent thoughts'' in each iteration. 
This can be visualized in \Cref{fig:latent_thoughts}.
We further translate this intuition into a theoretical result (see \Cref{sec:looped_simulate_cot}) about how looped models can in fact simulate CoT reasoning.
Given this connection to CoT reasoning, it is believable that looping can scale well for harder reasoning.
We believe that leveraging looping explicitly for inference-time scaling is a very promising future direction.
}
```
Looping-inspired regularization {#sec:regularization}
===============================

In the previous section, we observed the looped models can improve reasoning with worse perplexity. Can we leverage this observation to improve reasoning without affecting perplexity? Here, we propose a simple approach: regularize the weights of the model to encourage them to be close to a looped model. This could have two advantages, (a) the model still has free parameters to improve perplexity, (b) the closeness to looped model can inherit the desirable inductive bias. In particular, if an $L$-layer model is denoted as $f_{0}\circ f_{1} \dots, \circ f_{L/k-1}$, where each $f_{i}$ is a block of $k$ layers, we add a regularization term that makes all the weights of the block $f_{i}$ close to $f_{i+1}$ in terms of cosine similarity. For a parameter group $G$ (e.g. first feed-forward layer, or query matrix in Transformer), we use $\theta_{G}^{(0)}, \theta_{G}^{(1)}, \dots, \theta_{G}^{(L-1)}$ to denotes the weights in all layers. Then the regularization term is $$\begin{aligned}
    \mathcal{R}_{G}(k) = \frac{1}{L-k}\sum_{i=0}^{\frac{L}{k}-2} \sum_{j=0}^{k-1} \text{Cosine}\left(\theta_{G}^{(i k + j)}, \theta_{G}^{((i+1) k + j)}\right)\end{aligned}$$ The final loss function is a sum of the standard cross-entropy loss and the regularization term averaged over all groups, multiplied by a scalar hyperparameter. Let $\mathcal{G}$ denote the set of all parameter groups; $\mathcal{G} = \left\{\text{Attn-Q}, \text{Attn-K}, \dots, \text{FFN-W2}\right\}$ $$\begin{aligned}
    \mathcal{L} = \mathcal{L_{\text{xent}}} + \lreg|\mathcal{G}|^{-1}  \sum_{G \in \mathcal{G}} \mathcal{R}_{G}(k)
    \label{eq:regularization}\end{aligned}$$ `\looseness`{=latex}-1In the above formulation, $\lreg=0$ would recover standard training and $\lreg \rightarrow \infty$ would converge to a fully looped model. Intermediate values of $\lreg$ will lead to \`\`approximately" looped models. For instance, to emulate the `\loopy{4}{6}`{=latex} looped model setting, we use pick $k=4$, $L=24$ and a large regularization strength like $\lreg=10$. All other hyperparameters are kept the same as baseline training for a fair comparison. We tried options other than cosine similarity, like $\ell_2$ norm, to bring the weights closer but found that cosine was more robust and effective.

**Cosine similarities.** Firstly, we check if the regularization had the right effect by measuring the cosine similarities between the successive blocks of $k$ layers at the end of training. We, indeed, find that for all parameter groups, the cosine similarity around 0.98 or higher (see `\Cref{fig:cosines_cosreg}`{=latex}).

```{=latex}
\iffalse
```
**Results.** `\Cref{table:regularization_results}`{=latex} summarizes the results with regularization. Overall, we find that regularized models improve over standard training. In particular, for block size of $k=4$, we find that higher regularization strength of $\lreg=10$ was more effective than $\lreg=1$. Thus, we use $\lreg=10$ for all other experiments. Not only does regularization slightly improve perplexity ($7.40 \rightarrow 7.38$) and closed book QA performance ($11.2 \rightarrow 12.5$), it significantly improves open book QA ($33.9 \rightarrow 36.2$), math word problems ($29.3 \rightarrow 36.4$) and reasoning primitives ($47.5 \rightarrow 57.5$). Thus, we are able to achieve the best of both worlds -- better memorization than baseline and better reasoning than looped models. `\fi`{=latex}

**Inductive bias.** To confirm the inductive bias of the regularizer, we visualize the log perplexity vs downstream isoplots for $\lreg=10$ and baseline models in `\Cref{fig:iso_plots_taskgroups_12L}`{=latex}. While the plots are similar for closed book QA, a strong inductive bias shows up for open book QA and reasoning problems. Crucially, the regularized model does well on reasoning without hurting perplexity (see `\Cref{table:regularization_results}`{=latex}).

```{=latex}
\scalebox{0.75}{
% \begin{tabular}{lll|r|rrr|r}
\begin{tabular}{lc|cccc|c}
\toprule
 &  Perplexity ($\downarrow$) & Closed & Open   & Math Word & All Tasks & Reasoning  \\
 % &  & Perplexity ($\downarrow$) & Closed & Open   & Math Word & All Tasks & Reasoning  \\
  & \tiny{(validation)} & Book QA ($\uparrow$) & Book QA ($\uparrow$)  &  Problems ($\uparrow$) & Average ($\uparrow$) & Primitives ($\uparrow$) \\
  % & Params / FLOPs & \tiny{(validation)} & Book QA ($\uparrow$) & Book QA ($\uparrow$)  &  Problems ($\uparrow$) & Average ($\uparrow$) & Primitives ($\uparrow$) \\
  &  & \tiny{(4 tasks)} & \tiny{(5 tasks)}  &  \tiny{(6 tasks)} & \tiny{(15 tasks)}  & \tiny{(4 tasks)}  \\
  % &  &  & \tiny{(4 tasks)} & \tiny{(5 tasks)}  &  \tiny{(6 tasks)} & \tiny{(15 tasks)}  & \tiny{(4 tasks)}  \\
\midrule
\hline
Baseline & 7.40 & 11.2 & 33.9  & 29.3 & 26.0 & 47.5 \\
% Baseline & 24x / 24x & 7.40 & 11.2 & 33.9  & 29.3 & 26.0 & 47.5 \\
% \hline
\hline
% \rowcolor{lightgray}
Regularized ($k=4,\lreg=1$) & 7.41 & 11.2 & 34.8  & 31.6 & 27.2 & 42.5 \\
Regularized ($k=4,\lreg=10$) & 7.38 & 12.5 & 36.2  & 36.4 & 30.0 & 57.2\\
Regularized ($k=6,\lreg=10$) & 7.40 & 12.0 & 35.8  & 31.0 & 27.5 & 55.8\\
Regularized ($k=8,\lreg=10$) & 7.43 & 11.3 & 34.4 & 32.8 & 27.6 & 56.3\\
Regularized ($k=12,\lreg=10$) & 7.51 & 10.1 & 34.1 & 32.3 & 27.0 & 50.7\\
% \hline
% \hline
% \rowcolor{lightgray}
% 8 layer model  \\
\hline
\end{tabular}
}
```
Theoretical analysis for looped models {#sec:theory}
======================================

`\looseness`{=latex}-1In this section, we present theoretical results to understand the phenomenon from the previous sections -- *why can looped model with few parameters match an iso-flops non-looped baseline on reasoning problems?* While a complete theory is challenging, since \`\`reasoning" is a very broad concept, the goal is to provide some intuition and formalization for the expressive power of looped models. First, we show that looped Transformers can effectively solve group composition (a generalization of the addition problem). Then we show a very general result on how a non-looped model with very few distinct layers can be simulated by a looped model with a small blowup in model size. This result is then used to solve the $p$-hop problem using a one-layer looped transformer. Our construction for group composition and $p$-hop problem are nearly optimal in terms of depth and much more efficient in terms of parameters compared to non-looped models.

Preliminaries and Notations {#subsec:prelim}
---------------------------

We first define the standard transformer architecture. Throughout the paper we will fix the dimension of the embedding to be $d\in\mathbb{N}^+$, the vocabulary to be $\mathcal{V}$ and maximum sequence length to be $n_{\max}$. We will use $\id$ to denote the identity mapping. Here, we describe the high-level notation. Please refer to `\Cref{sec:apx_detailed_notation}`{=latex} for detailed notations. We use $\ff$ and $\mha$ to denote the feed-forward and attention layers respectively, and $\theta_{\mha}$ and $\theta_{\ff}$ denote the parameters in these layers.

```{=latex}
\begin{definition}[Transformer Block]\label{defi:transformer_block}
Given number of layers $L\in\mathbb{N}^+$ and parameter $\theta_\tfblock = (\theta^{(l)}_\mha,\theta^{(l)}_\ff )_{l=0}^{L-1}$, $L$-layer transformer block $\tfblock_{\theta_\tfblock}:(\mathbb{R}^d)^n\to (\mathbb{R}^d)^n$ for any $n\in\mathbb{N}^+$ is defined as \begin{equation}
    \tfblock_{\theta_\tfblock} \triangleq (\id+ \ff_{\theta^{(L-1)}_\ff})\circ (\id+ \mha_{\theta^{(L-1)}_\mha})\circ \cdots (\id+ \ff_{\theta^{(0)}_\ff})\circ (\id+ \mha_{\theta^{(0)}_\mha}),
\end{equation}
\end{definition}
```
We also denote $\embed$ and $\transoutput$ to be the input embedding and output softmax layers respectively. Please refer to `\Cref{sec:apx_detailed_notation}`{=latex} for precise definitions. Finally, we define the entire transformer model that maps a sequence of tokens to a distribution over tokens: $p_{\theta}: \cup_{n\le n_{\max}} \mathcal{V}^n\to \Delta^{|\mathcal{V}|-1}$. $$\begin{aligned}
p_{\theta}\triangleq \transoutput_{\theta_\transoutput}\circ \tfblock_{\theta_\tfblock}\circ \embed_{\theta_\tokenembedding,\theta_\posencoding}\end{aligned}$$ where $\theta = (\theta_\tfblock,\theta_\tokenembedding,\theta_\posencoding,\theta_\transoutput)$ denote all the transformer parameter. In particular, we use $\transformer_{\theta}(v_1,\ldots,v_n) \triangleq \argmax_{v\in\mathcal{V}} p_{\theta}(v|v_1,\ldots,v_n)$ to denote the deterministic version of the transformer model. We now define a looped Transformer model that also subsumes a non-looped model.

```{=latex}
\begin{definition}[\loopy{L}{T} Looped Transformer]\label{defi:looped_transformer}
Given the number of loops $T\in\mathbb{N}^+$,  parameters $\theta = (\theta_{\tfblock},\theta_\tokenembedding,\theta_\posencoding,\theta_\transoutput)$, where $\theta_{\transformer} = (\theta^{(l)}_\mha,\theta^{(l)}_\ff )_{l=0}^{L-1}$, we define a \loopy{L}{T} \emph{looped Transformer} as  $p_{\theta,T}\triangleq \transoutput_{\theta_\transoutput}\circ \left(\tfblock_{\theta_\tfblock}\right)^T\circ \embed_{\theta_\tokenembedding,\theta_\posencoding}$.
\end{definition}
```
Group composition problem
-------------------------

We consider the problem of composing $n$ elements from a group, and prove that a standard 1-layer transformer looped $\mathcal{O}(\log(n))$ times can solve this problem. This is a generalization of the modular addition problem and has a long history (see `\Cref{sec:apx_group_composition}`{=latex}). Recently, @liu2022transformers show that transformers with $\log_2 n$ depth can compute composition over $n$ group elements, regardless of whether the group is solvable or not. However, the construction in @liu2022transformers uses different attention parameter for each layer. Here, we provide a more parameter efficient construction where we solve this problem by looping a one-layer transformer $\log(n)$ times.

```{=latex}
\begin{theorem}\label{thm:group_composition_log_depth}
    For any finite group $G$ and every $n\in\mathbb{N}^+$, there exists a constant-precision looped transformer $\transformer_{\theta,T}$ computing the composition of $n$ elements from $G$ with a $1$-layer transformer block, $T=\lceil\log _2 n\rceil$ loops, $G\cup\{\#\}$ being the vocabulary, $d= 3\left(\lceil\log_2 |G|\rceil +\lceil\log_2 n+1\rceil\right)$ embedding size, $d_\ff=|G|^2 + 6\lceil\log_2 |G|\rceil$ hidden dimension in MLP, $d_\attn=\lceil\log_2 n\rceil$ hidden attention dimension, and $2$ attention heads. More specifically, for any $g_1,\ldots,g_n\in G$, $\transformer_{\theta}(\#,g_1,\ldots, g_n) = g_1\circ\cdots \circ g_n$.
\end{theorem}
```
The above result matches the depth upper bound shown for non-looped models by @liu2022transformers, and is very close to the super constant lower bound shown there. Thus looped models can solve the problem with best known depth.

Looped models can simulate non-looped models
--------------------------------------------

Our second theoretical result shows that a non-looped transformer with repeated layers can be simulated by a looped transformer with fewer parameters and same depth.

```{=latex}
\begin{theorem}\label{thm:main}
    For any transformer $p_{\theta}$ with $L$ layers, $d$ embedding size, $d_\ff$ hidden dimension for MLP, $H$ attention heads with $d_\attn$ hidden dimension, at most $R$ distinct transformer layers and bounded activation value, there is a looped transformer $p_{\theta',L}$ working on the same vocabulary $\mathcal{V}$ plus a dummy token $\#$, 
which loops a 1-layer transformer block for $L$ times, with $d+R+2$ embedding size, $R h_\ff +O(L)$ hidden dimension for MLP, $RH$ attention heads with $d_\attn$ hidden dimension, such that for any string $v_1,\ldots, v_n\in\mathcal{V}$, $p_{\theta}(v_1,\ldots,v_n) = p_{\theta',L}(\#, v_1,\ldots,v_n)$.
\end{theorem}
```
We can also use the above theorem to convert the $O(\log_2 n)$ depth transformer that simulates group composition into a 1-layer looped transformer, although it is less parameter efficient than the result from the previous section. Please refer to `\Cref{sec:apx_simulate_nonlooped}`{=latex} for the full proof.

**Theory for $p$-hop.** The experiments on $p$-hop induction from `\Cref{sec:khop_desc}`{=latex} surprisingly show that a small model looped multiple times can solve it very effectively. We establish a theoretical basis for this finding. More specifically, we show that a constant layer transformer with $\log(p)$ loops suffices to solve $p$-hop induction problem. This result, in fact, matches the lower bound for layers required for non-looped models proved in @sanford2024transformers. The result follows from `\Cref{thm:main}`{=latex} and the construction for non-looped models from @sanford2024transformers (restated in `\Cref{thm:khop_log_depth}`{=latex}).

```{=latex}
\begin{corollary}
    $p$-hop problem (\Cref{defi:khop}) can be solved by looping a 1-layer transformer $\lfloor\log_2 p\rfloor+2$ times, which has $O(\log n)$ bits of precision, $d=d_{\ff}=d_{\attn}=O(1)$ embedding size, hidden dimension for MLP and attention, and at most $3$ attention heads.
\end{corollary}
```
```{=latex}
\new{

\subsection{Looped models can simulate CoT reasoning}
\label{sec:looped_simulate_cot}

% \ns{Briefly describe how looping is providing latent thoughts and it differs from CoT reasoning.}

In \Cref{sec:latent_thoughts}, we discussed how looped models can be viewed as generating multiple latent thoughts in each iteration. 
The following theorem shows that one can use looped transformer with $L$ loops to simulate $L$ steps CoT of another transformer with similar sizes. 
\begin{theorem}[Looped transformer simulates CoT]\label{thm:cot_informal}
    For any $L$-layer non-looped transformer $\transformer_\theta$ with fixed input length $n$ and the number of CoT steps $m$, there exists a looped transformer $\transformer_{\theta'}$ with $L+\mathcal{O}(1)$ layers, $\Omega(\log (n+m))$, more  embedding dimension and constantly many more attention heads, such that for any input $\vv= (v_i)_{i=1}^n$, the output of non-looped transformer after $m$ steps of CoT, $\transformer^m_\theta(\vv)$, is the same as that of the looped transformer on input $x$ concatenated by $m$ dummy tokens with $m$ loops, $\transformer_{\theta',m}(\vv,\#^m)$.
\end{theorem}

Below we sketch the high-level idea behind \Cref{thm:cot_informal}; the full proof can be found in \Cref{sec:apx_cot_connection}. 
\begin{itemize}
    \item (Masking all-but-one) We maintain a counter $t$ (\Cref{lem:mlp_addition}) at each position for the current number of loop and for the $i$th position, we will only "update" the embedding if $i-n\ge t$ and reset the embedding to the input to the looped transformer layer otherwise. This can be implemented by MLP. So similar to CoT, the first $n+i-1$ embedding won't be changed in the $i$th loop.
    \item (Shift by one) We can use attention to obtain the output of the current loop at the previous position and use that as the input of the next loop at current position. (\Cref{lem:shifting_layer})
    \item (Token decoding) We show that we can use MLP with ReLU activation to simulate the encoding and decoding process of CoT. (\Cref{lem:simulating_decoding_embedding})
\end{itemize}


}
```
```{=latex}
\begin{figure*}[!tbp]

    \begin{subfigure}{0.5\textwidth}
        \includegraphics[width=0.55\textwidth]{Media/cot_iteration.png}
    % \caption{}
    \label{fig:cot_iteration}
    \end{subfigure}\hfill
    
    \begin{subfigure}{0.5\textwidth}
        \includegraphics[width=0.55\textwidth]{Media/latent_thoughts_iteration.png}
    % \caption{}
    \label{fig:latent_thoughts_iteration}
    \end{subfigure}
    
    \caption{\looseness-1\new{{\bf Left.} Chain-of-thought reasoning can be viewed as a looped model, where each iteration produces one new thoughts token. The new tokens are highlighted in red. {\bf Right.} A looped model can instead generate multiple {\em latent thoughts} in parallel and, in theory, can simulate CoT reasoning my masking the updates appropriately (see \Cref{thm:cot_informal})}}
    \label{fig:latent_thoughts}
\end{figure*}
```
Related work {#sec:related}
============

`\looseness`{=latex}-1Reasoning is recognized as a core ability for intelligent and robustly model and has thus seen significant focus over the last few years. The synthetic reasoning problems we consider in this work have all been used in the prior works of [@sanford2024transformers; @ye2024physics; @sanford2024understanding; @nogueira2021investigating] to theoretically analyze the strengths and limitations of Transformers. There is also interest in the representation power of Transformers for computational problems [@liu2022transformers; @strobl2023transformers] and for chain-of-thought reasoning [@merrill2023expresssive; @feng2023towards; @li2024chain]. The necessity of model depth for performance has been remarked upon, for small models [@liu2024mobilellm] and reasoning [@chen2024can; @ye2024physics; @petty2023impact]. In this work, we make a finer observation that albeit larger depth is necessary, this can be achieved with a limited parameter budget via looping.

Looping in transformer models has been studied since the works [@dehghani2018universal; @lan2019albert] where they showed the benefits of looping for supervised learning tasks and BERT pretraining respectively. Looping also appears in [@schwarzschild2021can; @bansal2022end] that study the extrapolation properties of looping for certain algorithmic tasks. More recently @giannou2023looped [@de2024simulation] have studied the theoretical properties of looped decoder models and show that looping can simulate arbitrary Turing machines. In addition [@yang2023looped; @gao2024expressive; @gatmiry2024can; @gatmiry2024role] study looped models as a way to simulate iterative algorithms for in-context learning. Recently, @mohtashami2023cotformer introduced CoTFormer, which tries to improve the perplexity of looped language models and [@liu2024mobilellm] explore latency-efficiency parameter sharing for on-device LLMs. In contrast, our work focuses on the surprising inductive bias of looping to improve downstream reasoning tasks, and goes beyond algorithmic and in-context learning tasks.

`\looseness`{=latex}-1Different training algorithms (e.g. gradient descent [@soudry2018implicit]) and architectural choices (e.g. attention [@edelman2022inductive]) have been shown to have certain implicit biases. There is increasing interest in such inductive biases during pretraining [@saunshi22understanding; @liu2023same]. More recently, @saunshi2024inductive showed an inductive bias of stacking [@reddi2023efficient] towards improving reasoning and hypothesize that a connection of stacking to looped models could be responsible for this. Our results provide further verification for this hypothesis.

Conclusions, limitations and future work
========================================

`\looseness`{=latex}-1This work explores a new direction of \`\`looped models for reasoning". Not only are looped models able to solve many reasoning problems with very fewer parameters, they also have an inductive bias towards disproportionately improving the reasoning performance of language models, over memorization. The theoretical results on the expressivity of looped models start to provide some hints into their depth-optimality. While we test looped models on a subset of reasoning problems, a natural question is whether the results hold for many other forms of reasoning (e.g. multimodal and common-sense reasoning). In particular, a succinct formalization of reasoning problems itself is an interesting future direction. Furthermore, the inductive bias towards improved reasoning performance at the same perplexity is very intriguing and deserves further exploration. We find the scaling behavior of looped models very fascinating, and the connections to *latent thoughts* and CoT reasoning start to provide hints into this behavior. We hope this inspires future exploration on using looped models for more efficient inference-time scaling to aid with deeper reasoning.

```{=latex}
\bibliographystyle{iclr2025_conference}
```
```{=latex}
\appendix
```
Experiments
===========

Simple reasoning setup details
------------------------------

#### $n$-ary addition. {#sec:apx_addition}

All experiments are run on a standard Transformer architecture with input dimension of 256, 8 heads and 1024 hidden dimension in the feed-forward layers. We train using Adafactor [@shazeer2018adafactor] employing a linear warmup coupled with a cosine decay schedule for the learning rate. All runs use a batch size of 1024, learning rate of 0.005 and run for 200k steps. This corresponds to 200M examples, which is insignificant compared to the total possible examples ($> 10^{320}$). Thus memorization of the answer is not an issue. Since training is a bit noisy, for each setting, we run 3 different random seeds and pick the run with the maximum average accuracy. We pick maximum instead of average because we care about expressivity power of these models.

#### $p$-hop induction. {#sec:apx_khop}

We formally define the $p$-hop problem below:

```{=latex}
\begin{definition}[$p$-$\hop$,~\cite{sanford2024transformers}]\label{defi:khop}
For a finite alphabet $\Sigma$, define the map $\hop_p: \Sigma^n \to \Sigma \cup \{\perp\}$ as $\hop_p(\vv) = v_{\find_p(\vv, n)}$ if $\find_p(\vv, n) \neq 0$ else $\perp$, where
\begin{align*}
    \find_1(\vv, i) &= \max\left(\{0\} \cup \{j \le i, v_{j-1} = v_i \}\right) \\
    \find_p(\vv, i) &= \find_1(\vv, \find_{p-1}(\vv, i)) \text{ for } p \ge 2.
\end{align*}
\end{definition}
```
```{=latex}
\begin{figure*}[tbp]

    
    \begin{subfigure}{0.33\textwidth}
        \includegraphics[scale=0.36]{Media/Heatmap.Model_Pile_CosReg10_blk4.VariableKind_self_attention.post.w}
    \caption{Attn:Q}
    \label{fig:cosines_cosreg_attn_q}
    \end{subfigure}\hfill

    \begin{subfigure}{0.33\textwidth}
        \includegraphics[scale=0.36]{Media/Heatmap.Model_Pile_CosReg10_blk4.VariableKind_ff_layer.ffn_layer1.linear.w.png}
    \caption{FFN:W1}
    \label{fig:cosines_cosreg_ffn_w1}
    \end{subfigure}\hfill

    \begin{subfigure}{0.33\textwidth}
       
    % width=\textwidth
     \includegraphics[scale=0.36]{Media/Heatmap.Model_Pile_CosReg10_blk4.VariableKind_self_attention.post.w}
    \caption{Attn:PostNorm}
    \label{fig:cosines_cosreg_postnorm}
    \end{subfigure}\hfill
    \caption{\looseness-1Cosine similarities for different layers in the model trained with the regularization strength $\lreg=10$ for block size $k=4$ (see \Cref{sec:regularization} for details). The 24 layer model will have 6 such blocks of size 4. The heatmap above shows the cosine similarities between weights for all pairs of blocks. Overall we find the final cosine similarity to be very high, thus suggesting a strong connection to looped models.}
    \label{fig:cosines_cosreg}
\end{figure*}
```
```{=latex}
\begin{figure*}[tbp]

    
    \begin{subfigure}{0.33\textwidth}
        \includegraphics[scale=0.36]{Media/Heatmap.Model_Pile_Baseline_blk4.VariableKind_self_attention.post.w.png}
    \caption{Attn:Q}
    \label{fig:cosines_baseline_attn_q}
    \end{subfigure}\hfill

    \begin{subfigure}{0.33\textwidth}
        \includegraphics[scale=0.36]{Media/Heatmap.Model_Pile_Baseline_blk4.VariableKind_ff_layer.ffn_layer1.linear.w.png}
    \caption{FFN:W1}
    \label{fig:cosines_baseline_ffn_w1}
    \end{subfigure}\hfill

    \begin{subfigure}{0.33\textwidth}
       
    % width=\textwidth
     \includegraphics[scale=0.36]{Media/Heatmap.Model_Pile_Baseline_blk4.VariableKind_self_attention.post.w.png}
    \caption{Attn:PostNorm}
    \label{fig:cosines_baseline_postnorm}
    \end{subfigure}\hfill
    \caption{\looseness-1Cosine similarities for different layers in the baseline model trained without any regularization strength. Overall the cosine similarities are very low for large matrices, as expected for high dimensions matrices.}
    \label{fig:cosine_similarity}
\end{figure*}
```
For the $p$-hop problem we sample instances randomly while enforcing that there always exists $p$-hops present in the input sequence. We do this by first picking the sequence of $p$-hops randomly and then shuffling them around in a sequence with filler tokens to be filled by the remaining characters. After the shuffle, we sample the remaining characters to occur in place of the filler tokens while respecting the $p$-hop order. Our train set consists of 4M examples and our test and validation sets consist of 262k examples each. For all models we train on this dataset, the model dimension is 128, hidden dimension is 512 and 8 attention heads are used. Rotary positional encodings are used as well. We train using Adafactor for 200,000 steps with a batch size of 256 using a base learning rate of $10^{-3}$ and use a linear warmup coupled with a cosine decay schedule for the learning rate.

#### i-GSM. {#sec:apx_igsm}

We describe how the i-GSM dataset is generated in more detail here. We start with a hierarchy of entities of depth 4 from which we build a randomly sampled structure graph where directed edges connect entities in level $i$ to those in level $i+1$. Each edge in the structure graph defines a instance parameter which is an integer (for e.g. an edge between city center and car parks denotes the number of car parks in city center). We then construct a randomly sampled mathematical dependency graph which is a DAG over all the instance parameters by relating each to the others. Finally we pick one of the nodes of the dependency graph to query and the goal is to compute the numerical value of this node modulo some prime number $P$. For more details on the sampling process for the structure and dependency graphs, we refer the reader to [@ye2024physics]. We make 3 simplifications compared to [@ye2024physics]: we phrase our problems in a symbolic language without the English construction of sentences (see `\Cref{fig:sample-eigsm-problem}`{=latex}); we do not allow abstract parameters in our problems; we perform arithmetic modulo $7$ as opposed to $23$. Our train dataset consists of around 4 million examples and we test on around 50k examples. Given the significantly larger number of unique solution templates possible, train-test overlap in the problem template space is going to be limited with high probability. For all models we train on this dataset, the model dimension is 128, hidden dimension is 512 and 8 attention heads are used. Rotary positional encodings are used as well. We train using Adafactor for 200,000 steps with a batch size of 256 using a base learning rate of $10^{-3}$ and use a linear warmup coupled with a cosine decay schedule for the learning rate.

```{=latex}
\begin{figure*}[!tbp]

    \begin{subfigure}{0.24\textwidth}
        \includegraphics[width=\textwidth]{Media/8L_closedQA.png}
    % \caption{}
    \label{fig:8L_closedQA}
    \end{subfigure}\hfill
    
    \begin{subfigure}{0.23\textwidth}
    
    \includegraphics[width=\textwidth]{Media/8L_openQA.png}
    % \caption{}
    \label{fig:8L_openQA}
    \end{subfigure}\hfill

    \begin{subfigure}{0.23\textwidth}
        \includegraphics[width=\textwidth]{Media/8L_mwp.png}
    % \caption{}
    \label{fig:8L_mwp}
    \end{subfigure}\hfill

    \begin{subfigure}{0.23\textwidth}
        \includegraphics[width=\textwidth]{Media/8L_primitives.png}
    % \caption{}
    \label{fig:8L_primitives}
    \end{subfigure}\hfill
    \caption{\looseness-1Downstream evaluation for various task groups on the x-axis, vs validation log perplexity on the y-axis (reversed), as training proceeds. The top plots compare a 8-layer baseline model \loopy{8}{1} and the looped model \loopy{8}{3}.
    Similarly to \Cref{fig:iso_plots_taskgroups_12L}, for closed book QA (memorization) tasks looping has very similar trends to baseline. For open book QA tasks and math word problems, looping has much better downstream performance at an equivalent log perplexity.}
    \label{fig:iso_plots_taskgroups_8L}
\end{figure*}
```
Language modeling setup {#sec:apx_language_modeling}
-----------------------

We train on the Pile data using causal language modeling objective. The dataset is pre-processed and cached to ensure that all models are trained on exactly the same tokens in the same order. For all experiments, we use a batch size of 512 and sequence length of 1280. We use a cosine learning rate schedule decaying over 400k steps with a peak learning rate of 0.01, tuned based on the baseline model. The base model is a 1.5B parameter decoder only Transformer model, with 24 layers, model dimensions of 2048, hidden dimension 5120 and 32 heads. For the shallower baselines and looped models, we only change the number of layers and keep all other hyperparameters the same.

Results for each task group
---------------------------

In `\Cref{sec:language_modeling}`{=latex} we discussed results for various task groups like closed book QA, math word problems etc. Here we present results for each individual task for completeness. Detailed results for closed book QA are in `\Cref{table:main_closedqa_results}`{=latex}, open book QA in `\Cref{table:main_openqa_results}`{=latex}, math word problems in `\Cref{table:main_mwp_results}`{=latex} and reasoning primitives in `\Cref{table:main_primitives_results}`{=latex}. These tables include, both, looped models from `\Cref{table:language_modeling_results}`{=latex} and the models trained with regularization from `\Cref{table:regularization_results}`{=latex}.

```{=latex}
\scalebox{0.75}{
\begin{tabular}{lc|cccc|c}
\toprule
 & Params / FLOPs & TriviaQA & TydiQA- & Natural & Web & Average\\
 &  &  & NoContext & Questions & Questions & Average\\
\midrule
\hline
Baseline & 24x / 24x & 24.5 & 10.6 & 4.3 & 5.6 & 11.2 \\
\hline
Base \loopy{12}{1} & 12x / 12x & 15.9 & 9.3 & 2.4 & 5.2 & 8.2 \\
Loop \loopy{12}{2} & 12x / 24x & 18.6 & 9.6 & 3.4 & 5.8 & 9.3 \\
Regularized $(k=12)$ & 24x / 24x & 20.9 & 10.9 & 3.6 & 5.0 & 10.1 \\
\hline
Base \loopy{8}{1} & 8x / 8x & 12.3 & 7.4 & 1.8 & 3.9 & 6.3 \\
Loop \loopy{8}{3} & 8x / 24x & 17.3 & 8.8 & 2.5 & 5.2 & 8.5 \\
Regularized $(k=8)$ & 24x / 24x & 23.6 & 10.6 & 4.1 & 6.7 & 11.3 \\
\hline
Base \loopy{6}{1} & 6x / 6x & 7.4 & 4.8 & 1.1 & 2.7 & 4.0 \\
Loop \loopy{6}{4} & 6x / 24x & 16.0 & 9.3 & 2.7 & 4.9 & 8.2 \\
Regularized $(k=6)$ & 24x / 24x & 25.8 & 12.2 & 4.5 & 5.6 & 12.0 \\
\hline
Base \loopy{4}{1} & 4x / 4x & 3.3 & 1.9 & 0.6 & 1.5 & 1.8 \\
Loop \loopy{4}{6} & 4x / 24x & 11.8 & 8.8 & 2.4 & 3.5 & 6.7 \\
Regularized $(k=4)$ & 24x / 24x & 26.1 & 13.3 & 4.5 & 6.2 & 12.5 \\
\end{tabular}
}
```
```{=latex}
\scalebox{0.75}{
\begin{tabular}{lc|ccccc|c}
\toprule
 & Params / FLOPs & TydiQA-NoContext & SquadV2 & DROP & QuAC & CoQA & Average\\
\midrule
\hline
Baseline & 24x / 24x & 33.4 & 41.3 & 23.4 & 18.6 & 52.7 & 33.9 \\
\hline
Base \loopy{12}{1} & 12x / 12x & 21.8 & 34.6 & 20.0 & 17.3 & 41.1 & 26.9 \\
Loop \loopy{12}{2} & 12x / 24x & 27.7 & 39.5 & 22.3 & 17.6 & 46.8 & 30.8 \\
Regularized $(k=12)$ & 24x / 24x & 33.0 & 44.3 & 23.8 & 17.7 & 51.9 & 34.1 \\
\hline
Base \loopy{8}{1} & 8x / 8x & 12.7 & 33.8 & 16.0 & 15.3 & 35.9 & 22.7 \\
Loop \loopy{8}{3} & 8x / 24x & 29.8 & 38.1 & 21.6 & 17.6 & 46.6 & 30.8 \\
Regularized $(k=8)$ & 24x / 24x & 33.0 & 41.7 & 25.2 & 19.3 & 52.9 & 34.4 \\
\hline
Base \loopy{6}{1} & 6x / 6x & 8.2 & 26.8 & 14.8 & 14.4 & 32.5 & 19.3 \\
Loop \loopy{6}{4} & 6x / 24x & 26.6 & 34.6 & 20.8 & 18.1 & 43.7 & 28.7 \\
Regularized $(k=6)$ & 24x / 24x & 34.5 & 46.1 & 26.2 & 18.6 & 53.4 & 35.8 \\
\hline
Base \loopy{4}{1} & 4x / 4x & 3.4 & 19.0 & 11.1 & 13.1 & 22.3 & 13.8 \\
Loop \loopy{4}{6} & 4x / 24x & 22.0 & 32.4 & 20.1 & 15.8 & 40.7 & 26.2 \\
Regularized $(k=4)$ & 24x / 24x & 33.6 & 49.4 & 24.1 & 19.8 & 54.0 & 36.2 \\

\end{tabular}
}
```
```{=latex}
\scalebox{0.75}{
\begin{tabular}{lc|cccccc|c}
\toprule
 & Params / FLOPs & ASDiv & MAWPS- & MAWPS- & MAWPS- & MAWPS- & SVAMP & Average\\
 &  &  & AddSub & MultiArith & SingleEq & SingleOp &  & \\
\midrule
\hline
Baseline & 24x / 24x & 26.9 & 49.9 & 2.7 & 38.6 & 40.2 & 17.9 & 29.3 \\
\hline
Base \loopy{12}{1} & 12x / 12x & 23.1 & 44.8 & 2.0 & 36.8 & 37.9 & 15.3 & 26.7 \\
Loop \loopy{12}{2} & 12x / 24x & 31.5 & 55.9 & 2.0 & 47.4 & 50.2 & 18.5 & 34.3 \\
Regularized $(k=12)$ & 24x / 24x & 27.4 & 54.4 & 2.7 & 43.7 & 45.9 & 19.8 & 32.3 \\
\hline
Base \loopy{8}{1} & 8x / 8x & 12.9 & 32.7 & 2.0 & 19.9 & 21.9 & 13.4 & 17.1 \\
Loop \loopy{8}{3} & 8x / 24x & 23.2 & 54.9 & 2.3 & 36.2 & 38.3 & 15.6 & 28.4 \\
Regularized $(k=8)$ & 24x / 24x & 31.0 & 56.7 & 3.5 & 40.7 & 47.5 & 17.3 & 32.8 \\
\hline
Base \loopy{6}{1} & 6x / 6x & 15.4 & 31.1 & 1.3 & 21.3 & 26.3 & 10.9 & 17.7 \\
Loop \loopy{6}{4} & 6x / 24x & 24.5 & 51.6 & 0.7 & 38.0 & 47.7 & 16.3 & 29.8 \\
Regularized $(k=6)$ & 24x / 24x & 32.0 & 47.1 & 3.5 & 44.9 & 42.0 & 16.8 & 31.0 \\
\hline
Base \loopy{4}{1} & 4x / 4x & 7.9 & 10.4 & 1.5 & 11.0 & 19.0 & 8.3 & 9.7 \\
Loop \loopy{4}{6} & 4x / 24x & 19.9 & 49.1 & 1.7 & 29.3 & 35.9 & 12.6 & 24.8 \\
Regularized $(k=4)$ & 24x / 24x & 35.0 & 53.9 & 3.2 & 52.2 & 50.9 & 23.0 & 36.4 \\

\end{tabular}
}
```
```{=latex}
\scalebox{0.75}{
\begin{tabular}{lc|cccc|c}
\toprule
 & Params / FLOPs & Code & Math & Code & Math & Average\\
 &  &  Depth-0 & Depth-0 & Depth-1 & Depth-1  & \\
\midrule
\hline
Baseline & 24x / 24x & 71.1 & 72.5 & 24.2 & 22.3 & 47.5 \\
\hline
Base \loopy{12}{1} & 12x / 12x & 52.2 & 51.6 & 20.7 & 18.3 & 35.7 \\
Loop \loopy{12}{2} & 12x / 24x & 73.5 & 86.5 & 21.3 & 23.6 & 51.2 \\
Regularized $(k=12)$ & 24x / 24x & 75.1 & 74.9 & 27.0 & 25.7 & 50.7 \\
\hline
Base \loopy{8}{1} & 8x / 8x & 48.7 & 42.0 & 21.4 & 19.8 & 33.0 \\
Loop \loopy{8}{3} & 8x / 24x & 88.4 & 86.9 & 23.1 & 22.9 & 55.3 \\
Regularized $(k=8)$ & 24x / 24x & 92.2 & 86.7 & 23.1 & 23.3 & 56.3 \\
\hline
Base \loopy{6}{1} & 6x / 6x & 26.3 & 29.7 & 20.2 & 20.1 & 24.1 \\
Loop \loopy{6}{4} & 6x / 24x & 90.6 & 88.1 & 24.1 & 21.7 & 56.1 \\
Regularized $(k=6)$ & 24x / 24x & 84.0 & 79.7 & 31.4 & 28.3 & 55.8 \\
\hline
Base \loopy{4}{1} & 4x / 4x & 19.3 & 23.3 & 17.3 & 17.6 & 19.4 \\
Loop \loopy{4}{6} & 4x / 24x & 87.9 & 90.0 & 24.8 & 24.8 & 56.9 \\
Regularized $(k=4)$ & 24x / 24x & 88.0 & 86.9 & 27.9 & 25.8 & 57.2 \\




\end{tabular}
}
```
Theoretical results {#sec:appendix_proof}
===================

Detailed notations {#sec:apx_detailed_notation}
------------------

```{=latex}
\begin{definition}[Embedding Layer]\label{defi:embedding_layer}
Given a finite vocabulary $\mathcal{V}$, embedding dimension $d\in\mathbb{N}^+$, token embedding parameter $\theta_{\tokenembedding}\in\mathbb{R}^{d\times \abs{\mathcal{V}}}$ and position embedding parameter $\theta_{\posencoding}\in \mathbb{R}^{d\times n_{\max}}$, we define the \emph{embedding layer} as a sequence-to-sequence map, denoted by $\embed_{\theta_\tokenembedding,\theta_\posencoding}:\mathcal{V}^n\to (\mathbb{R}^d)^n$ for any $1\le n\le n_{\max}$, where
\begin{equation}
    \embed_{\theta_\tokenembedding,\theta_\posencoding}(v_1,\ldots,v_n) = \left(\theta_\tokenembedding(v_1)+ \theta_\posencoding(1), \ldots, \theta_\tokenembedding(v_n)+ \theta_\posencoding(n)  \right). 
\end{equation}
\end{definition}
```
#### Multi-Head Self-Attention Mechanism:

Given attention parameters $\theta_\attn = \{W_Q, W_K, W_V, W_O\}$, where each $W_Q^m, W_K^m, W_V^m, W_O^m \in \mathbb{R}^{d_{\attn} \times d }$, we define the *Self-Attention* layer with a causal mask for a decoder-only transformer in `\Cref{alg:defi_attn}`{=latex}. We also define a *Multi-Head Attention* layer as a collection of self-attention layer with non-shared parameters $\theta_\mha = \{\theta_{\attn}^{(h)}\}_{h=1}^H$, and its output is the sum of the outputs from each head. That is, $\mha_{\theta_\mha} = \sum_{h=1}^H \attn_{\theta_\attn^{(h)}}$.[^2]

```{=latex}
\begin{algorithm}
    \caption{Causal Self-Attention, $\attn$}\label{alg:defi_attn}
    \begin{algorithmic}[1]
    \Require  Parameter $\theta_\attn = (W_Q,W_K,W_V,W_O)$, Input embedding $ x_1,\ldots, x_n)\in \left(\mathbb{R}^{d}\right)^n$.
    \Ensure Output embedding $x'= (x'_1,\ldots, x'_n) \triangleq \attn_{\theta_\attn}(x_1,\ldots, x_n)$.
    \State $q_i \triangleq W_Q  x_i, k_i \triangleq W_K x_i, v_i \triangleq W_V x_i, \forall i\in[n]$
    \State $s_i \triangleq \softmax(\inner{q_i}{k_1},\ldots,\inner{q_i}{k_i}) \| (0,\ldots, 0) $. 
    \State $h'_i \triangleq W_O^\top \sum_{j=1}^n (s_i)_j v_j$.
    \end{algorithmic}
    \end{algorithm}
```
#### Feed-Forward Network:

Given the parameters of the fully-connected feedforward network layer $\theta_\ff = (W_1, b_1, W_2, b_2) \in \mathbb{R}^{x_{\ff} \times d} \times \mathbb{R}^{d_{\ff}} \times \mathbb{R}^{d \times d_{\ff}} \times \mathbb{R}^{d_{\ff}}$, we define the feedforward layer $\ff_{\theta_\ff}: \mathbb{R}^{d} \to \mathbb{R}^{d}$ as $\ff_{\theta_{\ff}}(h) \triangleq W_2, \relu(W_1 h + b_1) + b_2$.

```{=latex}
\iffalse
```
```{=latex}
\begin{definition}[Transformer Block]\label{defi:transformer_block}
Given number of layers $L\in\mathbb{N}^+$ and parameter $\theta_\tfblock = (\theta^{(l)}_\mha,\theta^{(l)}_\ff )_{l=0}^{L-1}$, we define the $L$-layer transformer block $\tfblock_{\theta_\tfblock}:(\mathbb{R}^d)^n\to (\mathbb{R}^d)^n$ for any $n\in\mathbb{N}^+$ as \begin{equation}
    \tfblock_{\theta_\tfblock} \triangleq (\id+ \ff_{\theta^{(L-1)}_\ff})\circ (\id+ \mha_{\theta^{(L-1)}_\mha})\circ \cdots (\id+ \ff_{\theta^{(0)}_\ff})\circ (\id+ \mha_{\theta^{(0)}_\mha})
\end{equation}
\end{definition}
```
```{=latex}
\fi
```
```{=latex}
\begin{definition}[Output Layer]\label{defi:output_layer}
    Given parameter $\theta_\transoutput \in\mathbb{R}^{d\times |\mathcal{V}|}$, we denote the output layer as $\transoutput_{\theta_\transoutput}:(\mathbb{R}^d)^n\to \Delta^{|\mathcal{V}|-1}$, where 
    \begin{align}
        \transoutput_{\theta_\transoutput}(x_1,\ldots,x_n) \triangleq \softmax(x_n^\top \theta_{\transoutput})
    \end{align}
\end{definition}
```
Finally, we define the entire transformer model $p_{\theta}: \cup_{n\le n_{\max}} \mathcal{V}^n\to \Delta^{|\mathcal{V}|-1}$ as $$\begin{aligned}
p_{\theta}\triangleq \transoutput_{\theta_\transoutput}\circ \tfblock_{\theta_\tfblock}\circ \embed_{\theta_\tokenembedding,\theta_\posencoding}\end{aligned}$$ for any $\theta = (\theta_\tfblock,\theta_\tokenembedding,\theta_\posencoding,\theta_\transoutput)$.For convenience, we also write $\left[p_{\theta}(v_1,\ldots,v_n)\right](v)$ as $p_{\theta}(v\mid v_1,\ldots,v_n)$. In particular, we use $\transformer_{\theta}(v_1,\ldots,v_n) \triangleq \argmax_{v\in\mathcal{V}} p_{\theta}(v|v_1,\ldots,v_n)$ to denote the deterministic version of the transformer model.

```{=latex}
\iffalse
```
```{=latex}
\begin{definition}[Looped Transformer]\label{defi:looped_transformer}
Given number of loops $T\in\mathbb{N}^+$,  parameters $\theta = (\theta_{\tfblock},\theta_\tokenembedding,\theta_\posencoding,\theta_\transoutput)$, where $\theta_{\transformer} = (\theta^{(l)}_\mha,\theta^{(l)}_\ff )_{l=0}^{L-1}$, we define the \emph{looped transformer} as  $p_{\theta,T}\triangleq \transoutput_{\theta_\transoutput}\circ \left(\tfblock_{\theta_\tfblock}\right)^T\circ \embed_{\theta_\tokenembedding,\theta_\posencoding}$ and the corresponding deterministic version as $\transformer_{\theta,T}(v_1,\ldots,v_n) \triangleq \argmax_{v\in\mathcal{V}} p_{\theta,T}(v|v_1,\ldots,v_n)$.
\end{definition}
```
```{=latex}
\fi
```
#### Finite-precision Modeling:

In this paper we assume the transformer is of finite precision. More specifically, we follow the setting in @li2024chain and use the shorthand $\Floating_{s}\triangleq \{c\cdot k\cdot 2^{-s}\mid c\in \{-1,1\}, 0\le k\le 2^{2s}-1, k\in\mathbb{N}\}$ to denote fixed-point numbers of constant precision $s$ and rounding operation $\rds{\cdot}:\mathbb{R}\to \Floating_s$ to denote the correcting rounding, namely the mapping from $\mathbb{R}$ to the closest representable number in $\Floating_s$. (We break the tie by picking the number with smaller absolute value). We assume that (1). all the parameters of the transformer are in $\Floating_{s}$ and (2). correct rounding is performed after every binary operation in the forward pass of the transformer. We will refer the readers to @li2024chain for detailed discussion on such finite-precision modeling and only list important notations and lemmas that will be used in this paper below.

We use $1_s$ to denote all-one vectors of length $s$. Similarly we define $\inner{\cdot}{\cdot}_s$, $\times_s$, and $\softmax_s$. We recall that for any $s\in \mathbb{N}^+$ and integer $0\le x\le 2^s-1$, we use $\bin_s(x)\in\{0,1\}^s$ to denote the usual binary encoding of integer $x$ using $s$ binary bits in the sense that $x = \sum_{i=1}^s 2^i (\bin_s(x))_i$ and $\sbin_s(x)\in\{-1,1\}^s$ to denote the signed binary encoding, which is $2\bin_s(x)-(1,\ldots,1)$. Finally we define $B_s = \max \Floating_s = 2^s-2^{-s}$.

```{=latex}
\begin{lemma}\label{lem:exp_rounding}[Lemma E.1, \citep{li2024chain}]
    For any $s\in\mathbb{N}^+$, it holds that $\rds{\exp(-B_s)} = 0$.
\end{lemma}
```
```{=latex}
\begin{lemma}\label{lem:exp_rounding_up}[Lemma E.2, \citep{li2024chain}]
    For any $s\in\mathbb{N}^+$, it holds that $\rds{\exp(B_s)} = B_s$.
\end{lemma}
```
```{=latex}
\begin{lemma}\label{lem:FB_simulating_boolean_gates}[Lemma E.5, \citep{li2024chain}]
Unlimited-fanin $\AND,\OR$ (resp. $\MAJORITY) :\{0,1\}^n\to \{0,1\}$ can be simulated by some 2-layer feedforward ReLU network with constant (resp. $\log n$) bits of precision constant hidden dimension and additional $n$ constant inputs of value 1. 

Mathematically, let $\ff[s(n)]$ be the set of functions $C:\{0,1\}^n\to\{0,1\}$ which can be a two-layer feedforward ReLU network with at most $s(n)$ bits of precision and constant hidden dimension $\ff_{\theta}:\{0,1\}^{2n}\to \{0,1\}, \ff_{\theta}(x') = W_2\times_s\relu(\rds{W_1\times_s x'+b_1})$, where $\theta = (W_2,W_1,b_1)$, such that for any $x\in\{0,1\}^n$, 
\begin{align}
\ff_\theta(x_1,1,x_2,1,\ldots,x_n,1) = C(x).    
\end{align}
We have unlimited-fanin $\AND,\OR\in \ff[1]$ and $\MAJORITY\in \ff[\log n]$.  
\end{lemma}
```
Given two vectors $x,y$ of the same length $s$, we use $\interleave{x}{y}$ to denote their interleaving, that is, $(\interleave{x}{y})_{2i-1} = x_i, (\interleave{x}{y})_{2i} = y_i$ for all $i\in [e]$.

```{=latex}
\begin{lemma}\label{lem:attention_rounding}[Lemma E.3, \citep{li2024chain}]
    For any $s\in\mathbb{N}^+$, let $q_i = \interleave{\sbin_s(i)}{1_s}$ and $k_i = B_s\cdot (\interleave{\sbin_s(i)}{(-1_s)})$ for all $i\in [2^s-1]$, it holds that $\rds{\exp(\inner{q_i}{k_j}_s)}=\indct{i=j}$ for all $i,j\in [2^s-1]$.
\end{lemma}
```
Proofs
------

### Looped models can simulate non-looped models {#sec:apx_simulate_nonlooped}

```{=latex}
\begin{proof}[Proof of \Cref{thm:main}]

We start by introduce some more notations. We will proof the theorem for any fixed sequence of $\vv = (v_1,\ldots, v_n)$. We use $\vx^{(l)}=(x_{i}^{(l)})_{i=1}^n$ to denote the intermediate embedding of $p_{\theta}$ in the $l$th layer. More specifically, we define 

\begin{align}
    \vx^{l} = (\id+ \ff_{\theta^{(l-1)}_\ff})\circ (\id+ \mha_{\theta^{(l-1)}_\mha})\circ \cdots (\id+ \ff_{\theta^{(0)}_\ff})\circ (\id+ \mha_{\theta^{(0)}_\mha})\circ \embed_{\theta_{\embed}}(\vv).
\end{align}

We also use $\vx^{(l+0.5)}=(x_{i}^{(l+0.5)})_{i=1}^n$ to denote the intermediate embedding of $p_{\theta}$ in the $l$th layer after the attention layer.
\begin{align}
    \vx^{l+0.5} =  (\id+ \mha_{\theta^{(l-1)}_\mha})(\vx^{l}).
\end{align}

Similarly, for the constructed looped transformer $p_{\theta,T}$, we use $\vv'=(\#,\vv_1,\ldots,\vv_n)$ to denote its input. For simplicity, we use the convention that $\#$ is at position $0$. The proof still works if $\#$ starts at position $1$ because we can just transfer the later tokens by $1$ position. We define $\vx'^{(l)}=(x'^{(l)}_0,x'^{(l)}_1,\ldots,x'^{(l)}_n)$ as the intermediate embedding of $p_{\theta}$ in the $l$th layer and $\vx'^{(l+0.5)}=(x'^{(l+0.5)}_0,x'^{(l+0.5)}_1,\ldots,x'^{(l+0.5)}_n)$ as the intermediate embedding of $p_{\theta}$ in the $l$th layer.

Below we first state the key properties that our construction will satisfy, which imply the correctness of the theorem and then we state our construction of $p_{\theta',T}$ and show the properties are actually satisfied:
\begin{itemize}
    \item $x'^{(l)}_{i} = (x^{(l)}_i, \bm{1}_{R} - e_{r(l)},l, \indct{i=0})$.
    \item $x'^{(l+0.5)}_{i} = (x^{(l+0.5)}_i, \bm{1}_{R} - e_{r(l)},l, \indct{i=0})$.
    \item $x^{(l)}_0 = x^{(l+0.5)}_0 = \bm{0}$.\footnote{Here we abuse the notation for simplicity of presentation. $x^{(l)}_0 = x^{(l+0.5)}_0$ are not defined in the original non-looped transformer. The key point here is that they are 0 vectors throughout the forward pass.}
\end{itemize}

% \begin{itemize}
%     \item $\theta'_{\ff} $ is essentially concatenation of $\theta^{(i)}_\ff$ for $i\in[r]$ in the hidden dimension of $\ff$. The goal is to make $\ff_{\theta'_{\ff}} = \sum_{i=1}^r \ff_{\theta^{(i)}_\ff}\indct{r(l)=i}$ at $l$th layer. 
%     \item  $\theta'_{\mha} $ is essentially concatenation of $\theta^{(i)}_\mha$ for $i\in[r]$ in the hidden dimension of $\mha$. The goal is to make $\mha_{\theta'_{\mha}} = \sum_{i=1}^r \mha_{\theta^{(i)}_\ff}\indct{r(l)=i}$ at $l$th layer.
%     \item Use extra dimension to count the depth $l$ and maintain $\bm{1} - e_{r(l)}$ using $R+1$ dimension on top of the original embedding size $d$ and use it to do switch. Such piecewise linear function can be approximated by a neural network with $O(L)$ hidden dimension.
%     \item use $\bm{1} - e_{r(l)}$ together with ReLU activation to do the select in MLP. Use $\indct - e_{r(l)}$ in attention to force all attention heads to attend to the dummy token $\#$ in the beginning of the sentence.
% \end{itemize}

To get the last coordinate $l$, which is a depth counter, we just need to add $1$ more hidden dimension in MLP. 

Next we show we can use two-layer with $L+2$ MLP to get the mapping from $\ell\mapsto \bm{1}_{R} - e_{r(l)}$. Let $(\theta^{(i)}_\ff,\theta^{(i)}_\mha)_{i=1}^R$ be the parameters of the $R$ distinct layers in $\theta$. We assume in the $l$th layer, $r(l)$'s parameters are used. This is because $e_{r(l)} = \sum_{i=1}^L e_{r(i)}0.5*( [l-i+1 ]_+-2[l-i]_+ + [l-i-1]_+)$.

Now we explain how to deactivate the undesired MLP neurons. In other words, our construction of $\theta'_{\ff} $ is essentially concatenation of $\theta^{(i)}_\ff$ for $i\in[r]$ in the hidden dimension of $\ff$, with the additional control that  $\ff_{\theta'_{\ff}}((x^{(l)}_i, \bm{1}_{R} - e_{r(l)},l),\indct{i=0}) = \sum_{i=1}^R \ff_{\theta^{(i)}_\ff}(x^{(l)}_i)\indct{r(l)=i,i\neq 0}$ at $l$th layer. This control can be done by subtracting $\bm{1} - e_{r(l)} + \indct{i=0}$ by a constant which is larger than the maximum pre-activation in the hidden layer.

Finally we explain how to deactivate the undesired attention. We will only use attention to update the first part of the embedding, which is $x^{(l+0.5)}_i$. A crucial step here is that we set the token embedding of $\#$ as $0$
We construct keys and queries as follows: 
\begin{enumerate}
    \item ${W'}_Q^{(r')}(x'^{(l)}_{i}) = (W_Q^{(r')}x^{(l)}_{i}, 1- \indct{r'=r(l)} $ for $r' \in [R]$ and $i=0,\ldots,n$
    \item ${W'}_K^{(r')}(x'^{(l)}_{i}) = (W_K^{(r')}x^{(l)}_{i}, -B\indct{i=0}) $ for $r' \in [R]$ and $i=0,\ldots,n$, where $B$ is some constant larger than the maximum previous inner product in attention, $\max_{l\in[L],i,j} \inner{(W_K x^{(l)}_{i}}{(W_Qx^{(l)}_{i}}$. 
    \item ${W'}_O^{(r')}{W'}_V^{(r')}(x'^{(l)}_{i}) = (W_O^{(r')}W_V^{(r')}x^{(l)}_{i}, \bm{0}, 0,0)$.
\end{enumerate}
This construction works because only the `desired' attention head $r=r(l)$ will be activated and behave as in the non-looped case, because otherwise all position in that attention head will be completely attended to position $0$ and returns a zero vector. (We can choose $B$ to be large enough and distribution calculated by the attention score is delta distribution) at position $0$, which yields a zero vector as its value. This completes the proof.
\end{proof}
```
### Group composition. {#sec:apx_group_composition}

```{=latex}
\begin{algorithm}[t]\caption{Group Composition}\label{alg:group_composition}
\begin{algorithmic}[1]
\Require  Group elements $g_0,g_1,\ldots,g_n\in G$, where $g_0=e$.
\Ensure $g_0\circ g_1\circ \ldots g_n$.
\State $g^{(0)}_i = g_i$, $\forall0\le i\le n$.
\For{$l=1\to \lceil \log_2 n\rceil$}
\State $a^{(l)}_i = g^{(l-1)}_{[2i-n-1]_+},b^{(l)}_i = g^{(l-1)}_{[2i-n]_+}$  $\forall0\le i\le n$.
\State $g^{(l)}_i = a^{(l)}_i\circ b^{(l)}_i$, $\forall0\le i\le n$.
\EndFor
\State \Return $g^{(\lceil\log _2 n\rceil)}_{n}$.
\end{algorithmic}
\end{algorithm}
```
The landmark result in automata theory, Krohn-Rhodes Decompotision Theorem [@krohn1965algebraic], shows that all semi-automaton with solvable transformation group (which includes composition problem of solvable groups) can be simulated by a cascade of permutation-reset automata, which can be simulated by $\TC^0$ circuits. [@liu2022transformers] further showed that such automaton with solvable transformation group can be continuously simulated by constant-depth transformers. However, it is also shown [@barrington1986bounded] that the composition problem of unsolvable groups are $\NC^1$-complete, for example, the composition of permutation group over $5$ elements, $S_5$. Under the common hardness assumption that $\NC^1\neq \TC^0)$, constant depth transformer cannot solve composition of $S_5$ using a single forward pass [@merrill2023parallelism; @liu2022transformers; @li2024chain]. But with CoT, very shallow transformers (depth equal to one or two) can simulate the composition problem of any group[@li2024chain; @merrill2023expresssive].

```{=latex}
\begin{proof}[Proof of \Cref{thm:group_composition_log_depth}]We will set the token embedding of $\#$ the same as that of $e$, which is the identity of $G$. In the following proof, we will just treat $\#$ as $e$. We will construct the transformer simulating group composition following the following algorithm~\Cref{alg:group_composition}, which gives the high-level idea of the construction. The correctness of \Cref{alg:group_composition} follows from the associativity of group composition. More concretely, we can verify by induction that $g^{l}_0\circ g^{l}_1\circ\ldots g^{l}_n$ is the same for all $l=0,\ldots, \lceil \log_2 n\rceil$ and in the final round, i.e., when $l=\lceil \log_2 n\rceil$, $g^{(l)}_i=e$  for all $i<n$. 



Below we show how to construct a transformer of the given sizes to simulate the above \Cref{alg:group_composition}.
We will embed each $g\in G$ as a different vector  $\overline{g} \in \{-1,1\}^{\lceil\log_2 |G|\rceil}$ and each position $0\le i\le n$ as its binary representation in $\overline{i} \in\{-1,1\}^{\lceil\log_2 n+1\rceil}$, which is a shorthand for $\sbin_s(i)$ with $s=\lceil\log_2 n+1\rceil$.
We concatenate them to get $\{x^{(0)}_i\}_{i=0}^n$, that is, $x^{(0)}_i = (\overline{g_i},\overline{i}, \overline{[2i-n-1]_+}, \overline{[2i-n-1]_+},0^{\lceil \log_2 |G|\rceil},0^{\lceil \log_2 |G|\rceil})$.
For convenience, we will drop the 0's in the end (also in the other proofs of the paper) and write it as $x^{(0)}_i = (\overline{g_i},\overline{i} , \overline{[2i-n-1]_+}, \overline{[2i-n-1]_+})$.
Below we show we can construct 1-layer transformer block with parameter $(\theta_\mha,\theta_\ff)$ satisfying that
\begin{enumerate}
    \item $\left[ \mha_{\theta_\mha}\left((\overline{g_i},\overline{i},\overline{[2i-n-1]_+}, \overline{[2i-n-1]_+})_{i=0}^n\right)\right]_k = (0^{\lceil\log_2 |G|\rceil +3\lceil\log_2 n+1\rceil}, \overline{g_{[2k-n-1]_+}}, \overline{g_{[2k-n]_+}})$ \\for all $g_0=e,g_i\in G \forall i\in[n]$, $k=0,\ldots,n$;
    \item $\ff_{\theta_\ff}(\overline{g},\overline{i}, \overline{j}, \overline{k},\overline{g'}, \overline{g''}) = (\overline{g'\circ g''}-\overline{g}, 0^{3\lceil\log_2 n+1\rceil},-\overline{g'}, -\overline{g''})$, for all $i,j,k=0,\ldots,n$, $g,g',g''\in G$. 
\end{enumerate}

The first claim is because we can use two attention heads to retrieve $ \overline{g_{[2k-n-1]_+}}$ and $ \overline{g_{[2k-n]_+}}$ respectively, where both of them use $\overline{k}$ as the key and use $-\overline{[2k-n-1]_+}$ and $-\overline{[2k-n]_+}$ as queries respectively. This is possible because all the required information are already in $x_i$. We further make attention temperature low enough so the probability returned by attention is a one-hot distribution at the position whose key is equal to the negative query after rounding.

Now we turn to the second claim about MLP. We will use $|G|^2$ neurons with ReLU activation and bias to simulate the product of $g'$ and $g''$. We can index each neuron 
by $(h,h')$ for $h,h'\in G$ and set its incoming weight $[W_1]_{(h,h'),:} = (\overline{h},\overline{h'})$ and set bias $(b_1)_{(h,h')} = - 2\lceil\log_2 |G|\rceil+1$, which ensures that the activation of neuron $(h,h')$ will only be $1$ when $g'=h,g''=h'$ and be $0$ otherwise. Then setting the outgoing weight of neuron $(h,h')$ as $\overline{h\circ h'}$ and the bias in the second layer to be $0$ finishes the construction for simulating the group composition. Finally we use the remaining $6\lceil\log_2 |G|\rceil$ to simulate negative identity mapping $x\to-x$ for the remaining $3\lceil\log_2 |G|\rceil$ embedding dimension. This completes the proof.
\end{proof}
```
Connection to chain-of-thought reasoning {#sec:apx_cot_connection}
----------------------------------------

In this section, we establish a connection betwee looped models and CoT reasoning. We first define the recursion for CoT reasoning as follows: $$\transformer^{i}_\theta(v_1,\ldots,v_n)\triangleq \transformer^{i-1}_\theta(v_1,\ldots,v_n, \transformer_\theta(v_1,\ldots,v_n)),$$ for $i, n \in \mathbb{N}^+$ satisfying $i+n\le n_{\max}-1$ along with the base case of $\transformer^{1}_\theta(v_1,\ldots,v_n)\triangleq\transformer_\theta(v_1,\ldots,v_n)$. For all $0\le i\le n_{\max} - n-1$, the output with $i$ steps of CoT is $v_{n+i+1}  = \transformer^{i+1}_\theta(v_1,\ldots,v_n) = \transformer_\theta(v_1,\ldots,v_n,v_{n+1},\ldots,v_{n+i})$.

```{=latex}
\iffalse
```
```{=latex}
\begin{theorem}[Informal]\label{thm:cot_informal}
    For any $L$-layer non-looped transformer $\transformer_\theta$, there exists a looped transformer with parameter $\theta'$ and $L+\mathcal{O}(1)$ layers, such that for any input $\vx$ and integer $m$,  the output of non-looped transformer after $m$ steps of CoT is the same as that of the looped transformer on input $x$ concatenated by $m$ dummy tokens with $m$ loops.
\end{theorem}
```
```{=latex}
\fi
```
We first give the formal statement below.

```{=latex}
\begin{theorem}[Looped transformer simulates CoT]\label{thm:cot_formal}
    For any $L$-layer non-looped transformer $\transformer_\theta$, there exists a looped transformer $\transformer_{\theta'}$ with $L+\mathcal{O}(1)$ layers, constantly many more dimensions in embedding, MLP and attention layers and constantly many more attention heads, such that for any input $\vv= (v_i)_{i=1}^n$ and integer $m$, the output of non-looped transformer after $m$ steps of CoT, $\transformer^m_\theta(\vv)$, is the same as that of the looped transformer on input $x$ concatenated by $m$ dummy tokens with $m$ loops, $\transformer_{\theta',m}(\vv,\#^m)$.
\end{theorem}
```
Below are some helping lemmas towards showing `\Cref{thm:cot_formal}`{=latex} is at least as powerful as CoT.

```{=latex}
\begin{lemma}[Simulating $\argmax$ using MLP]\label{lem:simulating_hard_argmax}
    For every $d\in \mathbb{N}$ and precision $s\in\mathbb{N}^+$, there exists a 3-layer network with $\relu$ activation and $d^2$ width $f$ with $s$ precision, such that for any $x\in\Floating_s^d$, if there is $k\in [d]$, such that $x_k >\max_{j\neq k,j\in [d]}x_j$, $f(x) = e_k$. 
\end{lemma}
```
```{=latex}
\begin{proof}[Proof of \Cref{lem:simulating_hard_argmax}]Define $g_i = 2^s\cdot \relu(2^{-s} - \sum_{j\neq i}\relu(x_j-x_i))$ for each $i\in [n]$. We claim that if there is $k\in [d]$, such that $x_k -\max_{j\neq k,j\in [d]}x_j\ge 2^{-s}$, $g_i = 1$ iff $i=k$ for all $i\in [d]$. First $g_k =2^s\cdot \relu(2^{-s}) =1 $. Next for $i\neq k$, it clearly holds that  $\sum_{j\neq i}\relu(x_j-x_i)\ge 2^{-s}$ and thus $g_i\le 0$. 
    This construction can clearly be implemented by a 3-layer $\relu$ network with $s$ precision.
\end{proof}
```
```{=latex}
\begin{lemma}[Simulating Decoding and Embedding using MLP]\label{lem:simulating_decoding_embedding}
    Given any $s$-precision $\theta_{\tokenembedding}\in\mathbb{R}^{d\times \Sigma}$ and $\theta_{\transoutput}$, there is a 5-layer network $f:\mathbb{R}^d\to \mathbb{R}^d$ with $\relu$ activation and $\max(|\Sigma|^2)$ width with $s$-precision, such that for all $s$-precision $x\in\mathbb{R}^d$ which admits unique $\argmax$ for $v\triangleq \argmax_{o\in\Sigma} (x^\top\theta_{\transoutput})(o)$, it holds that
    \begin{align*}
        f(x) =  \theta_{\tokenembedding}(v).
    \end{align*}
\end{lemma}
```
```{=latex}
\begin{proof}[Proof of \Cref{lem:simulating_decoding_embedding}]This is a simple application of \Cref{lem:simulating_hard_argmax}.
\end{proof}
```
```{=latex}
\begin{lemma}[Control Gate]\label{lem:control_gate}
    A 2-layer $\relu$ network with precision $s$ can implement $F:\Floating_s\times \Floating_s\times \{0,1\}, F(x,y,M) = Mx + (1-M)y $. 
\end{lemma}
```
```{=latex}
\begin{proof}[Proof of \Cref{lem:control_gate}]Note that $F(x,y,M) = \relu(x-2^s\cdot (1-M)) - \relu(-x-2^s\cdot (1-M)) + \relu(y-2^s\cdot M) - \relu(-y-2^s\cdot M)$. The proof is completed.
\end{proof}
```
```{=latex}
\begin{definition}[Transformer Block with Mask]\label{defi:embedding_layer_w_mask}
Given number of layers $L\in\mathbb{N}^+$, parameter $\theta_\tfblock = (\theta^{(l)}_\mha,\theta^{(l)}_\ff )_{l=0}^{L-1}$, and mask function $M:\mathbb{N}\to\{0,1\}$, we define the $L$-layer \emph{transformer block with mask}$\tfblock_{\theta_\tfblock,M}:(\mathbb{R}^d)^n\to (\mathbb{R}^d)^n$ for any $n\in\mathbb{N}^+$ as 
\begin{equation}
    [\tfblock_{\theta_\tfblock,M}(\vx)]_i \triangleq (1-M(i))x_i + M(i)[\tfblock_{\theta_\tfblock}(\vx)]_i
\end{equation}
\end{definition}
```
```{=latex}
\begin{definition}[Looped Transformer with Mask]\label{defi:looped_transformer_w_mask}
Given number of loops $T\in\mathbb{N}^+$,  parameters $\theta = (\theta_{\tfblock},\theta_\tokenembedding,\theta_\posencoding,\theta_\transoutput)$, and mask functions $\{M^t\}_{t=1}^T$, where $\theta_{\transformer} = (\theta^{(l)}_\mha,\theta^{(l)}_\ff )_{l=0}^{L-1}$, we define the \emph{looped transformer with mask} as  $p_{\theta,T,M}\triangleq \transoutput_{\theta_\transoutput}\circ \tfblock_{\theta_\tfblock,M^T}\circ \cdots \tfblock_{\theta_\tfblock,M^1}\circ \embed_{\theta_\tokenembedding,\theta_\posencoding}$ and the corresponding deterministic version as $\transformer_{\theta,T,M}(v_1,\ldots,v_n) \triangleq \argmax_{v\in\mathcal{V}} p_{\theta,T,M}(v|v_1,\ldots,v_n)$.
\end{definition}
```
```{=latex}
\begin{definition}[Shifting Layer]\label{defi:shifting_layer}
    We define the \emph{shifting layer} $\shift:(\mathbb{R}^d)^n\to (\mathbb{R}^d)^n$ as the following for any $d,n\in\mathbb{N}^+$ and $x_1,\ldots, x_n\in\mathbb{R}^{d}$:
    \begin{align}
        \shift(x_1,x_2,x_3, \ldots,x_n) = (x_1,x_1,x_2,x_3,\ldots,x_{n-1}).
    \end{align}
\end{definition}
```
```{=latex}
\begin{lemma}\label{lem:shifting_layer}
    For input sequence length up to some integer $n$, $\shift$ could be implemented by a attention layer by concatenating each embedding $x_i$ with $(\sbin_s(i), \sbin_s(f(i)))$, where $n= \lceil \log_2 n +1\rceil$.
\end{lemma}
```
```{=latex}
\begin{proof}[Proof of \Cref{lem:shifting_layer}]It is equivalent to show we can construct an attention heads which computes $x_{f(i)}$ at each position $i$. 

    To do so, we just need to invoke \Cref{lem:simulating_hard_argmax} and use that to set key and query, so position $i$ attends to position $f(i)$. We set value at position $i$ to be $x_i$. This completes the proof. 
\end{proof}
```
```{=latex}
\begin{lemma}\label{lem:mlp_addition}
    For any positive integer $s>0$, there is a constant-depth MLP $F$ with $O(s)$ hidden neurons per layer and parameters in $\Floating_s$, such that for any input $\bin(x)\in\{-1,+1\}^s$, where $0\le x\le 2^s-1$, it holds that  
    \begin{align*}
        F(x) = \bin(x+1).
    \end{align*}
\end{lemma}
```
```{=latex}
\begin{proof}[Proof of \Cref{lem:mlp_addition}]By \Cref{lem:FB_simulating_boolean_gates}, it suffices to show that we can simulate $\bin(x)\mapsto \bin(x)+1$ using $O(s)$ wide, constant-depth boolean circuits with $\AND,\OR,\NOT$ gates with unlimited fan-ins. This is immediate by noting that 
    \begin{align}
        [\bin(x+1)]_i = [\bin(x+1)]_i \oplus \bigwedge_{j=1}^{i-1}[\bin(x+1)]_j
    \end{align}
\end{proof}
```
```{=latex}
\begin{lemma}\label{lem:mlp_comparison}
    For any positive integer $s>0$, there is a constant-depth MLP $F$ with $O(s)$ hidden neurons per layer and parameters in $\Floating_s$, such that for any input $(\bin(x),\bin(y))\in\{-1,+1\}^s$, where $0\le x,y\le 2^s-1$, it holds that  
    \begin{align*}
        F(x,y) = \indct{x > y}.
    \end{align*}
\end{lemma}
```
```{=latex}
\begin{proof}[Proof of \Cref{lem:mlp_comparison}]By \Cref{lem:FB_simulating_boolean_gates}, it suffices to show that we can simulate $\bin(x),\bin(y)\mapsto \indct{x> y}$ using $O(s)$ wide, constant-depth boolean circuits with $\AND,\OR,\NOT$ gates with unlimited fan-ins. This is immediate by noting that 
    \begin{align}
\indct{x > y} = \bigvee\limits_{i\in[s]} \left( ([\bin(x)]_i = 1) \wedge ([\bin(y)]_i = 0) \wedge \bigwedge\limits_{1\le j< i} \left([\bin(x)]_j = [\bin(y)]_j\right) \right).
    \end{align}
\end{proof}
```
```{=latex}
\begin{proof}[Proof of \Cref{thm:cot_formal}]

We consider mask $M^t(i) = \indct{i-t\ge n}$, which we call it CoT masking. Let $s = \lfloor \log (n+m) +1 \rfloor$, we use $\overline i$ to denote $\sbin_s(i)$ for $1\le i\le n+m$ for convenience. The embedding size of our newly constructed transformer is larger than the target transformer to be simulated by an additive constant of $3s$. We denote the new embedding at position $i$ after $t$th loop by $(x_i^{(t)}, p_i^{(t)})$, where $x_i^{(t)}$ will be the original embedding of the transformer to be simulated, and $p_i^{(t)}$ is of dimension $3s$ and only depends on $i$ and $t$. In particular, we can show that we can set $p_i^{(t)} \triangleq (\overline i,\overline{[i-2]_+ +1}, \overline{n+t})$ to save information about position and the number of loops --- $p_i^{(0)}$ is from the positional embedding and the update is due to \Cref{lem:mlp_addition}. The size of hidden dimension of MLP and attention (key, query, value) will also be increased by $O(s)$.

The proof contains two steps:
\begin{enumerate}
    \item To show that there is a transformer with CoT masks simulating the target transformer with $m$ steps of CoT and $L$ layers by looping its own $L+O(1)$ layer block $m$ times.
    \item To show that the above looped transformer with CoT masks can be simulated by a standard looped transformer without mask and with constantly many more layers.
\end{enumerate}

For the first claim, starting from the same parameter of the transformers with CoT $\theta$, we build a new looped model with parameter $\theta'$ with constantly many more layers in each transformer block and at most constantly many more heads per attention layer. First, we can add constantly many more layers to use MLP to simulate the decoding-encoding process using \Cref{lem:simulating_decoding_embedding}.
Next, we can add one more transformer layer in each block and use the attention layer to simulate the shifting layer by~\Cref{lem:shifting_layer}, since we have the additional position embedding $p_i^(t)$.
In particular, the embedding we get at position $n+t$ after $t$ loops, $x_{n+t}^{(t)}$,  now simulates the token embedding of $n+t$ of the CoT transformer.
By the way we define CoT mask $M$, for every $t\ge -n+1$, the embedding $\hat{x}_{n+t}^{(t')}$ will keep the same for all $t'\ge \max(t,0)$. 
In $t$th loop, the only embedding update that matters happens at $n+t$th position, because no updates happen at earlier positions, and updates at later positions $n+t'$ for some $t'>t$ will be overwritten eventually in the future loops $t'$, by some value which is independent of their value at the current loop $t$. In the end, we know the embedding $x_i^{(T)}$ in \Cref{defi:looped_transformer_w_mask} is exactly equal to that in CoT transformer, and so does the final output.

For the second claim, because CoT mask can be computed by a $O(\log(n+m))$ wide, constant-depth MLP~(\Cref{lem:mlp_comparison}), together with \Cref{lem:control_gate}, we know it suffices to increase the number of layers per transformer block and embedding size and hidden dimensions by constant to simulate the transformer with mask by a transformer without mask. 
\end{proof}
```
```{=latex}
\begin{theorem}[Restatement of Theorem 4.2, \citep{sanford2024transformers}]\label{thm:khop_log_depth}
    $p$-hop problem (\Cref{defi:khop}) can be solved by $\lfloor\log_2 p\rfloor+2$-layer non-looped transformer with $\log n$ bits of precision, at most $3$ different layers, $d=d_{\ff}=d_{\attn}=O(1)$ embedding size, hidden dimension for MLP and attention, $1$ attention head.
\end{theorem}
```

[^1]: Similar to [@ye2024physics], the simplified setting still allows for over 7 billion unique solution templates.

[^2]: Though in this paper we focus on attention with casual mask, our definition of looped transformer generalizes to the cases with other attention masks.
