---
abstract: |
  Training recurrent neural networks (RNNs) requires assigning credit across long sequences of computations. Standard backpropagation through time (BPTT) addresses this problem poorly: it is sequential in time, limiting parallelism, and suffers from vanishing or exploding gradients, making long-range associations difficult to learn. We propose *Supervised Memory Training* (SMT), a method for training nonlinear RNNs that sidesteps recurrent credit propagation entirely by reducing RNN training to supervised learning on one-step memory transition labels $(m_t, x_{t+1}) \rightarrow m_{t+1}$. SMT acquires these memory labels by training a Transformer-based encoder on a predictive state objective---retaining only information from the past necessary to predict the future. By decoupling *what to remember* from *how to update memory*, SMT enables time-parallel RNN training with a stable $\mathcal{O}(1)$ length gradient path between any two tokens---without ever unrolling the RNN. We find that SMT outperforms BPTT when pretraining various RNN architectures on tasks like language modeling and pixel sequence modeling. SMT enables nonlinear RNNs to better capture long-range dependencies and train in parallel, potentially unlocking the scaling of models that build temporal abstractions of past experience.

  ```{=latex}
  \begin{tcolorbox}[
      enhanced,
      boxrule=0pt,
      frame hidden,
      left=5pt,
      right=5pt,
      top=4pt,
      bottom=4pt
  ]\textsf{\textbf{Project Page:}} \href{https://akarshkumar.com/smt}{\texttt{akarshkumar.com/smt}} \\
  \textsf{\textbf{Source Code:}} \href{https://github.com/akarshkumar0101/smt}{\texttt{github.com/akarshkumar0101/smt}}
  \end{tcolorbox}
  ```
author:
- |
  Akarsh Kumar `\qquad `{=latex}`\qquad`{=latex} Phillip Isola\
  \
  `\quad `{=latex}`\quad `{=latex}MIT
bibliography:
- references.bib
title: Pretraining Recurrent Networks without Recurrence
---

```{=latex}
\DeclareMathOperator*{\argmax}{arg\,max}
```
```{=latex}
\DeclareMathOperator*{\argmin}{arg\,min}
```
```{=latex}
\maketitle
```
![ **BPTT vs SMT.** **Left:** BPTT trains an RNN by recurrently unrolling the \`\`updater" network in time, and backpropagating gradients through the entire graph. **Right:** Supervised Memory Training (SMT) trains an RNN with supervised learning on one-step memory transition labels, which are generated by a Transformer encoder-decoder model pair trained to produce *predictive states*. SMT is fully time-parallel. In SMT, the longest gradient path between tokens is $\mathcal O(1)$ (compared to $\mathcal O(T)$ in BPTT), which stabilizes gradients, making learning long-range dependencies qualitatively easier. ](figures/teaser_methods.png){#fig:teaser width="1.0\\linewidth"}

```{=latex}
\newpage
```
# Introduction {#sec:intro}

Recurrent neural networks (RNNs) store information about the past that will only become useful in the future. The core training challenge is that the utility of a memory may be delayed: many intermediate computations intervene between writing information and eventually using it. These intervening steps confound learning the correct associations, a problem known as *credit assignment* [@minsky1961steps].

The standard approach, backpropagation through time (BPTT), assigns credit across a sequence by unrolling the RNN in time and propagating gradients backward through the resulting computation graph [@rumelhart1986learning; @werbos1990backpropagation]. Although conceptually well-motivated, BPTT is sequential in time and suffers from unstable high variance gradients that may vanish or explode [@pascanu2013difficulty]. The lack of time-parallelism makes BPTT scale poorly, while its gradient instability makes learning long-range associations difficult, as credit must propagate across up to $\mathcal{O}(T)$ steps [@bengio1994learning]. Is recurrent credit propagation unavoidable?

In this paper, we propose *Supervised Memory Training* (SMT), a method to train nonlinear RNNs that sidesteps recurrent credit propagation by reducing the problem to supervised learning. Suppose we had access to the optimal memory state at each timestep, $m^*_t$. Then, RNN training reduces to learning the one-step update $(m^*_t, x_{t+1})\rightarrow m^*_{t+1}$ using standard supervised objectives.

The challenge, of course, is how to actually obtain such memory labels. In this paper, we assert that an effective memory is a sufficient statistic of the past for predicting the future, i.e., a predictive state [@littman2001predictive]. The past is typically viewed as a sequence, suggesting that memory must be computed sequentially over time. Our key insight is that, by augmenting each observation with its timestamp, the past can instead be losslessly represented as a *set* of timestamped events, rather than a sequence. Under this reparameterization, the optimal memory becomes a permutation-invariant function of this set, and can therefore be estimated using models that operate in parallel over time. This reframing allows us to train memory representations *without recurrently propagating credit through time*.

In practice, we train a Transformer encoder model to embed the past context into a memory that a separate decoder can use to predict the future. This objective operationalizes the notion of a predictive state: a representation of the past that retains only the information needed to predict the future and nothing more. Once this teacher encoder has learned to construct such memory representations, the RNN can then focus on learning the now much simpler task of updating that memory over time.

In essence, SMT decouples learning *what to remember* (memory representation), which is a non-sequential problem, from learning *how to update memory* (memory dynamics), which is a sequential process but can be supervised one-step at a time. This decoupling enables time-parallel training of nonlinear RNNs *without unrolling*, and creates a stable $\mathcal{O}(1)$ gradient path for long-range associations.

Indeed, Transformers solved time-parallelism and credit assignment in the same way [@vaswani2017attention], and have since revolutionized sequence modeling [@brown2020language]. However, Transformers do not possess a compressed memory of the past in the way RNNs or human brains do [@gu2025tradeoffs; @kandel2000principles]. Instead, Transformers store the *entire history* of past token representations and attend to *all of them* when processing each new token. As a result, their memory size grows with sequence length, leading to prohibitive computational costs for unbounded sequences, such as a human lifetime of experience [@tay2022efficient]. Sliding-window transformers mitigate this issue by storing only the most recent tokens, but have the severe drawback that they lose access to information before the context window [@dai2019transformer]. In contrast, no known biological intelligence operates in this manner---accessing its entire experiential history for every new decision---but instead constructs a temporally compressed abstraction of past experience, like an RNN [@bennett2023brief].

Linear attention RNN models also exhibit time-parallel training and relatively stable credit assignment, while maintaining a fixed memory size [@katharopoulos2020transformers; @gu2021efficiently; @gu2023mamba; @dao2024transformers]. But, because their transition function is linear, the class of functions they can represent is fundamentally constrained [@merrill2026linear], which can lead to failure on important sequential tasks such as state tracking [@merrill2024illusion; @liu2025serial].

SMT aims to combine the best of all worlds: time-parallel training, stable $\mathcal O(1)$ long-range credit assignment, fixed-memory inference, and maximal expressivity via nonlinear dynamics. Our results confirm that, on language modeling and pixel sequence modeling tasks, SMT outperforms BPTT in learning long-range dependencies while requiring less sequential computation. SMT should primarily be used for pretraining RNNs, followed by some lightweight post-training to mitigate drift from the teacher memory trajectories and adapt to specific downstream tasks. In fact, post-training is necessary to go beyond the limitations of the teacher encoder [@merrill2023parallelism]. Beyond its role as a training approach for RNNs, SMT can also be seen as a new method for learning representations (mappings from data to latent variables) and for learning world models (transitions from state at time $t$ to state at time $t+1$).

# Methods {#sec:methods}

## Background {#sec:methods_background}

#### Causal Conditional Sequence Modeling

Let $\mathbf{x} = [x_0, \dots, x_T]$ and $\mathbf{y} = [y_0, \dots, y_T]$ denote input and output sequences. The objective is to learn a model of the conditional distribution $p(\mathbf{y} \mid \mathbf{x})$. We assume each output $y_t$ depends only on $x_0,\dots,x_t$. Formally we model this distribution with $\prod_{t=0}^{T} p_\theta(y_t \mid \mathbf{x}_{\leq t})$. Autoregressive sequence modeling is a special case when $x_t = y_{t-1}$.

#### Recurrent Neural Networks (RNNs)

An RNN models this problem using a fixed-size latent state, $m_t$, that summarizes past inputs. At each timestep, this state is updated according to: $$\label{eq:rnn_update}
m_{t+1} = f_\theta(m_t, x_{t+1})$$ where $f_\theta$ is the transition function. The predicted output token distribution is then $p_\theta(y_{t} \mid \mathbf{x}_{\leq t}) = \text{softmax}(g_\theta(m_t))$, where $g_\theta$ is the readout function. Ideally, $m_t$ \`\`remembers" important information from the past and intentionally \`\`forgets" unimportant information, i.e., $m_t$ is a *memory*.

#### Backpropagation Through Time (BPTT)

Traditionally, RNNs are trained with BPTT [@rumelhart1986learning; @werbos1990backpropagation]. In the forward pass, $f_\theta$ is recurrently unrolled over the sequence. The input sequence $x_t$ is provided via teacher forcing, while the memory sequence $m_t$ is generated by the RNN's transition and is used to compute the output predictions. Conceptually, the computation graph takes the form: $$m_t = f_\theta(\ldots f_\theta(f_\theta(m_{\emptyset}, x_0), x_1), \ldots , x_t) \quad\text{with}\quad m_{\emptyset}=\mathbf{0}$$ Gradients are then computed end-to-end on this unrolled computation graph, propagating from the output prediction losses backward through the trajectory of the nonlinear dynamical system. Thus, this gradient credit assignment signal may have to travel for a path length of up to $\mathcal{O}(T)$ steps. Depending on the singular values of the Jacobian of $f_\theta$, gradients may vanish or explode in time.

BPTT has two well-known limitations:

1.  Equation `\ref{eq:rnn_update}`{=latex} is usually implemented with a recurrent for-loop, preventing parallelization [@pascanu2013construct].

2.  BPTT often produces unstable high variance gradients [@bengio1994learning]. When gradients vanish, the RNN experiences a recency bias, hindering the learning of long-range associations [@ravfogel2019studying]. When gradients explode, the induced dynamical system is chaotic, causing training instability [@pascanu2013difficulty].

## Supervised Memory Training (SMT) {#sec:smt}

We propose *Supervised Memory Training* (SMT) for pretraining nonlinear RNNs without BPTT. The core idea is to decouple the learning of memory representation from memory dynamics.

#### Motivation

Consider a hypothetical oracle memory-encoding model $\mathcal{Q}$, that takes as input the sequence of tokens up to timestep $t$, $\mathbf{x}_t^\text{ctx}=[x_0,x_1,\dots,x_t]$, and outputs an effective compressed memory for that timestep $m_t^*=\mathcal{Q}(\mathbf{x}_t^\text{ctx})$. This memory retains all information from the past input that is relevant for predicting the future output, $\mathbf{y}_t^\text{fut}=[y_{t},\dots,y_T]$, while deliberately discarding unimportant details. For example, the oracle would remember the personalities of characters in a story, but discard details of what they were wearing on a specific day, just as humans do. Running $\mathcal Q$ at different points along the sequence produces a corresponding sequence of memory labels $[m_0^*, m_1^*, \dots, m_T^*]$. With $\mathcal Q$, the RNN's problem of learning a temporal update collapses to standard supervised learning on oracle memory transitions labels $(m_t^*, x_{t+1})\rightarrow m_{t+1}^*$. Our key insight is that $\mathcal{Q}$ does *not* have to be a recurrent function over $[x_0, x_1, \dots, x_t]$, but can instead be represented as a permutation-invariant function over the *set* $\{(x_0, 0), (x_1, 1), \dots, (x_t, t)\}$ (details in Appendix `\ref{sec:set_reframing}`{=latex}).

In practice, SMT approximates $\mathcal Q$ by training a time-parallel model (e.g. a Transformer) to compress the past input into a memory representation that a separate decoder model can use to predict the future output. This future-predicting objective operationalizes the notion of a predictive state [@littman2001predictive].

#### Formulation

Formally, we have the RNN $f_\theta$, bidirectional encoder $\mathcal{E}_\phi$, and causal decoder $\mathcal{D}_\psi$. $\mathcal E_\phi$ and $\mathcal D_\psi$ are time-parallel Transformer architectures. Given $\mathbf{x} = [x_0, x_1, \dots, x_T]$ and $\mathbf{y} = [y_0, y_1, \dots, y_T]$, we consider, for each timestep $t$, a decomposition into the past and future: $$\mathbf{x}_t^{\text{ctx}} = [x_0, \dots, x_t] \qquad
\mathbf{x}_t^{\text{fut}} = [x_{t+1}, \dots, x_T] \qquad
\mathbf{y}_t^{\text{fut}} = [y_{t}, \dots, y_T]$$ The encoder maps each context to a memory state with $m_t = \mathcal E_\phi(\mathbf{x}_t^{\text{ctx}})$. Then, the decoder predicts the future output distribution using the memory of the past and teacher forced *future inputs*: $$p_{\phi,\psi}(\mathbf{y}_t^\text{fut} \mid \mathbf{x}_t^\text{ctx}, \mathbf{x}_t^\text{fut})
= \prod_{\tau=t}^{T} p_\psi(y_{\tau} \mid m_t, \mathbf{x}_{t+1:\tau})
= \mathcal D_\psi(m_t, \mathbf{x}_t^\text{fut})$$ The future decoding loss for timestep $t$ is ($\mathrm{CE}$ denotes the sequence level cross-entropy loss): $$\label{eq:loss_dec}
\mathcal{L}_t^\text{dec} =
\mathrm{CE}\left(
\mathbf{y}_t^\text{fut},
p_{\phi,\psi}(\mathbf{y}_t^\text{fut} \mid \mathbf{x}_t^\text{ctx}, \mathbf{x}_t^\text{fut})
\right)$$

We have the RNN predict the next memory given the current memory and the next input with $\hat{m}_{t+1} = f_\theta(m_t, x_{t+1})$. This prediction is supervised with the next timestep's memory: $$\label{eq:loss_dyn}
\mathcal{L}^\text{dyn}_t = \mathrm{MSE} (\hat{m}_{t+1}, m_{t+1})$$ This dynamics loss has two distinct purposes: 1) to train the RNN and 2) to explicitly shape the encoder memory representations to be *Markovian* (i.e. $m_{t+1}$ is predictable solely from $(m_t, x_{t+1})$).

We add a uniformity loss [@wang2020understanding] to prevent the memory space from collapsing: $$\mathcal{L}^\text{unif} = \log \mathbb{E}_{t_a, t_b \sim [0,\dots,T]} \exp(-2\| m_{t_a} - m_{t_b}\|_2^2)$$

The full objective is a weighted sum of all three losses: $$\label{eq:smt}
\mathcal{L}^\text{smt} =
\lambda_\text{dec} \mathbb{E}_t\left[ \mathcal{L}^{\text{dec}}_t \right]
+ \lambda_\text{dyn} \mathbb{E}_t\left[ \mathcal{L}^{\text{dyn}}_t \right]
+ \lambda_\text{unif} \mathcal{L}^\text{unif}$$ where the $\lambda$ terms control the trade-off between memory representation, dynamics, and collapse.

#### Practice

Theoretically, it should be enough to train $\mathcal{E}_\phi$ and $\mathcal D_\psi$ with only $\mathcal{L}^{\text{dec}}$, and separately train $f_\theta$ with only $\mathcal{L}^{\text{dyn}}$ (proof in Appendix `\ref{sec:markovian_encoder_proof}`{=latex}). However, in practice we find it beneficial to jointly train all models in one stage with $\mathcal{L}^{\text{smt}}$, since that explicitly optimizes $m_t$ to be Markovian, and provides additional temporal credit propagation benefits described in Section `\ref{sec:detach_rnn_experiments}`{=latex}.

For experiments, we truncate $\mathbf{x}_t^\text{ctx}$ to a context length $T_c$ and $\mathbf{y}_t^\text{fut}$ to a future length $T_f$. For computational efficiency, we estimate the expectation in $\mathcal{L}^{\text{smt}}$ by randomly sampling a single timestep $t$, rather than computing all timesteps in the sequence. This yields SMT a smaller training memory footprint than BPTT: $\mathcal{O}(M+T)$ instead of $\mathcal{O}(MT)$, where $M$ is the memory size.

#### Properties of SMT

In SMT, the encoder model constructs appropriate memory representations of the past, while the RNN is responsible for learning the now much simpler task of updating that memory in one-step, thereby decoupling memory representation from memory dynamics. In contrast, under BPTT training, the RNN must learn both tasks simultaneously. Since the memory labels are acquired with a \`\`teacher" encoder-decoder pair, SMT inherits all of its properties, such as time-parallelism, $\mathcal{O}(1)$ credit path for long-range associations, and gradient stability.

```{=latex}
\begin{wrapfigure}{r}{0.45\textwidth}


    \includegraphics[width=0.44\textwidth]{figures/smt_dmt_path.pdf}
    \caption{
    \label{fig:dmt}
    \textbf{SMT vs DMT.}
    SMT trains the RNN with behavior cloning on the encoder-generated memory states (off-policy imitation learning).
    DMT unrolls the RNN with its own memory states and then imitates the encoder trajectory (on-policy imitation learning).
    Figure design inspired by~\citet{jacobs2025block}.
    }
\end{wrapfigure}
```
## DAgger Memory Training (DMT) {#sec:dmt}

After SMT, the RNN achieves low one-step error in predicting $(m_t, x_{t+1})\rightarrow m_{t+1}$ when $m_t$ comes from the encoder. However, at evaluation time, the model is unrolled autoregressively, using its own predicted memories rather than the encoder memories as input. This train--test mismatch causes small prediction errors to accumulate over time, leading to a growing drift between the RNN-generated memory trajectory $[\hat{m}_0, \dots, \hat{m}_T]$ and the encoder trajectory $[m_0, \dots, m_T]$, even with teacher forced input tokens. This drift is quantified as $\delta_t = \text{MSE}(\hat m_t, m_t)$.

We introduce *DAgger Memory Training* (DMT), a finetuning phase that corrects this drift via *on-policy imitation learning* [@ross2011reductionimitationlearningstructured]. By exposing the RNN to its own induced memory state distribution, DMT trains the RNN to autocorrect its errors to stay aligned with the encoder trajectory (Figure `\ref{fig:dmt}`{=latex}).

Concretely, given $\mathbf{x}$, we first compute the encoder trajectory $[m_0, \dots, m_T]$ using only $\mathcal E_\phi$ and then the RNN trajectory $[\hat m_0, \dots, \hat m_T]$ using $f_\theta$. Instead of training on SMT labels $(m_t, x_{t+1}) \rightarrow m_{t+1}$, we train on DMT labels $(\hat m_t, x_{t+1}) \rightarrow m_{t+1}$. Equivalently, the training loss is: $$\mathcal{L}^\text{dmt}=\mathbb{E}_t \left[ \mathrm{MSE}(\hat{m}_t, m_t) \right]$$

During DMT, we freeze the encoder and decoder and only train the RNN with a small learning rate. Note that DMT unrolls the RNN memories, but still uses teacher forced $x_t$ inputs. Although DMT unrolls the RNN and gradients may optionally propagate through time, its objective is fundamentally different than standard BPTT, since long-range credit is already assigned in the encoder memory labels, $m_t$. DMT is not time-parallel. That said, DMT should primarily be viewed as a lightweight fine-tuning phase following SMT. Table `\ref{tab:resources_bigo}`{=latex} shows the resource requirements for the different methods.

```{=latex}
\resizebox{\textwidth}{!}{%
\begin{tabular}{lccccc|cc|c}
\toprule
& \shortstack{Training ($T$-Length Sequence)} & & \shortstack{Inference (One-Step)} & \shortstack{Complexity Class} \\
\cmidrule(lr){2-5} \cmidrule(lr){7-8} \cmidrule(lr){9-9}
Method & Memory & Compute & \shortstack{\textbf{Sequential} \\ \textbf{Operations}} & \shortstack{\textbf{Credit Path} \\ \textbf{Length}} & & Memory & Compute & \\
\midrule
Transformer      & $\mathcal{O}(T)$      & $\mathcal{O}(T^2)$    & $\mathcal{O}(1)$ & $\mathcal{O}(1)$ & & $\mathcal{O}(T)$ & $\mathcal{O}(T)$ & $\mathrm{TC}^0$ \\
LA (parallel) & $\mathcal{O}(T)$       & $\mathcal{O}(T^2)$       & $\mathcal{O}(1)$ & $\mathcal{O}(1)$ & & $\mathcal{O}(M)$ & $\mathcal{O}(1)$ & $\mathrm{PNC}^1$ \\
LA (recurrent) & $\mathcal{O}(M)$       & $\mathcal{O}(T)$       & $\mathcal{O}(1)$ & $\mathcal{O}(1)$ & & $\mathcal{O}(M)$ & $\mathcal{O}(1)$ & $\mathrm{PNC}^1$ \\
BPTT (RNN)       & $\mathcal{O}(MT)$     & $\mathcal{O}(T)$      & $\mathcal{O}(T)$ & $\mathcal{O}(T)$ & & $\mathcal{O}(M)$ & $\mathcal{O}(1)$ & $\mathrm{L/P}$ \\
\textbf{SMT} (RNN)        & $\mathcal{O}(M+T_c)$ & $\mathcal{O}(T_c^2)$ & $\mathcal{O}(1)$ & $\mathcal{O}(1)$ & & $\mathcal{O}(M)$ & $\mathcal{O}(1)$ & $\mathrm{L/P}$ \\
\textbf{DMT} (RNN)        & $\mathcal{O}(MT)$  & $\mathcal{O}(T_c^2T)$  & $\mathcal{O}(T)$  & $\mathcal{O}(1)$ & & $\mathcal{O}(M)$ & $\mathcal{O}(1)$ & $\mathrm{L/P}$ \\
\bottomrule
\end{tabular}%
}
```
# Experiments {#sec:experiments}

We study the properties of SMT and compare against BPTT, the standard RNN training algorithm. We restrict our analysis to nonlinear RNNs, the primarily setting BPTT is applied. Transformers and linear RNNs are excluded as they are qualitatively distinct model classes [@merrill2026linear; @gu2025tradeoffs].

\`\`BPTT RNN" denotes the BPTT baseline. \`\`SMT Encoder$^*$" generates memories $m_t$ with the SMT-trained encoder and predicts next tokens using the decoder. This method is essentially a Transformer baseline with the same memory bottleneck as our RNNs. Since it serves as the teacher during SMT and DMT, it provides a reference upper bound on RNN performance. \`\`SMT$\rightarrow$DMT RNN" denotes the RNN pretrained with SMT and finetuned with DMT, which constitutes our full method.

#### Architectures

We use RNN architectures based on a Transformer, MLP, and GRU [@cho2014properties] backbone.

#### Datasets

We consider character-level language modeling on TinyStories [@eldan2023tinystories] as a naturalistic task requiring long-range memory [@sutskever2011generating]. As a more challenging problem, we test our method on raster-scan order pixel sequence modeling of sparse images from MNIST [@lecun1998mnist] and Sketchy [@sangkloy2016sketchy]. This is a hard problem for RNNs [@lamb2016professor; @van2016pixel]. Imagine you are an ant traversing an image pixel by pixel, row by row. When you see a new white pixel, in order to recognize the shape and slope of the stroke it belongs to, you must remember the white pixels you saw in the previous rows, which may be hundreds of timesteps ago, buried among black pixels. RNNs must achieve this with finite memory, meaning no direct attention to earlier pixels, and thus forcing long-range memory to emerge. We term this \`\`Attneave's task", based on classic work from perceptual psychology [@attneave1954some].

More details on architectures, datasets, and experiments are in Appendix `\ref{sec:appendix_experiment_details}`{=latex}.

## Synthetic Task Experiments {#sec:synthetic_experiments}

![ **Synthetic Task Experiments.** We evaluate BPTT, SMT, and SMT$\rightarrow$DMT using five synthetic tasks with various settings to probe different properties of the algorithms. $^*$ signifies that the SMT Encoder is the teacher Transformer (*not* an RNN) and is used only as a reference. Across all tasks and task settings, SMT$\rightarrow$DMT outperforms BPTT, signaling that SMT has better gradient properties, memory utilization, state tracking, associative recall, and in-context learning than BPTT. ](figures/synthetic.png){#fig:synthetic_results width="1.0\\linewidth"}

We first evaluate BPTT and SMT on synthetic tasks designed to isolate and probe specific properties of the training algorithm. The RNN architecture with the Transformer backbone is used for these experiments. For these synthetic experiments, we set $T_c=T_f=T$ and train all timesteps in the $\mathcal{L}^\text{smt}$ expected value. Our tasks include the following (details of tasks are in Appendix `\ref{sec:synth_task_desc}`{=latex}):

1.  Retrieval to test **Gradient Stability** (sweep sequence length and noise level).

2.  String Copy to test **Memory Capacity** (sweep sequence length and memory state size).

3.  Stack Operations to test **State Tracking** (sweep sequence length and state complexity).

4.  Keys-Values to test **Associative Recall** (sweep number of and complexity of associations).

5.  Modular Arithmetic to test **In-Context Learning** (sweep difficulty and number of examples).

Figure `\ref{fig:synthetic_results}`{=latex} shows that SMT$\rightarrow$DMT outperforms BPTT in all settings of all tasks. BPTT struggles to learn as sequences get longer, even when the task is simple, e.g. retrieval. It also struggles to utilize memory capacity fully, do associative recall, and perform in-context learning, all of which require solving long-range credit assignment. In contrast, SMT seems agnostic to the sequence length, and is able to solve all of the harder credit assignment problems except associative recall. We attribute these differences to BPTT's $\mathcal{O}(T)$ credit path length, compared to SMT's $\mathcal{O}(1)$. Further analysis in Section `\ref{sec:vanishing_grads}`{=latex} confirms the difference in gradient stability in both methods.

```{=latex}
\newpage
```
## Attneave's Pixel Sequence Modeling {#sec:attneaves_experiments}

```{=latex}
\begin{wrapfigure}{r}{0.5\textwidth}




    \includegraphics[width=\linewidth]{figures/mnist_generations.pdf}
    \captionof{figure}{%
      \textbf{Attneave's MNIST Generation.}
      BPTT fails to effectively capture the long-range dependencies required for pixel sequence modeling, even with a GRU.
      SMT$\rightarrow$DMT captures these dependencies with a non-gated RNN architecture.
      More samples are in Appendix Figure~\ref{fig:mnist_generations_extra}.
    }
    \label{fig:mnist_generations}






    \includegraphics[width=\linewidth]{figures/sketchy_generations.pdf}
    \captionof{figure}{%
      \textbf{Attneave's Sketchy Generation.}
      SMT$\rightarrow$DMT captures the stroke structure of human-drawn sketches through only pixel sequence modeling on sparse images.
      More samples are in Appendix Figure~\ref{fig:sketchy_generations_extra}.
    }
    \label{fig:sketchy_generations}


\end{wrapfigure}
```
We now evaluate on Attneave's tasks. Figure `\ref{fig:mnist_generations}`{=latex} shows the stark difference between MNIST samples generated by RNNs trained with BPTT and SMT$\rightarrow$DMT. Figure `\ref{fig:sketchy_generations}`{=latex} shows images generated by an SMT$\rightarrow$DMT RNN trained on Sketchy.

Along with the synthetic experiments, these results confirm that SMT doesn't suffer from a recency bias like BPTT, allowing it to properly attribute credit across long sequences.

## Sequential Compute and Data {#sec:seqflops_experiments}

We now evaluate BPTT and SMT across real domains and various RNN architectures. Each method is allowed $N$ optimization steps on token batches of shape $B\times T$ (number of sequences $\times$ sequence length). We sweep $N$, $B$, and $T$ for each method to profile how much sequential compute and data each method uses to achieve a target performance. Sequential compute, measured in sequential FLOPs, is a metric proportional to the amount of inherently serial steps required to do the computation ($\sim$ time it would take on an infinitely parallel computer). Sequential compute is a useful quantity because modern hardware is highly parallel, making it the primary constraint in large-scale model training [@hooker2021hardware]. Data is measured by the number of tokens processed by the model during training. We elaborate on how sequential FLOPs and data is calculated for each method in Appendix `\ref{sec:appendix_definitions}`{=latex}.

Figure `\ref{fig:seqflops}`{=latex} shows the results. In **sequential compute**, SMT Encoder and SMT$\rightarrow$DMT RNN are significantly more efficient than BPTT with the Transformer and MLP backbones. In **data**, SMT Encoder and SMT$\rightarrow$DMT RNN has approximately the same data efficiency as BPTT with the Transformer and MLP backbones on TinyStories. However on MNIST, SMT Encoder and SMT$\rightarrow$DMT RNN shows significantly better data efficiency. This result is explained by the short vs long range memory information requirements of natural language [@fang2025wrongperplexitylongcontextlanguage] vs pixel sequence modeling [@tay2020longrangearenabenchmark]. SMT$\rightarrow$DMT is unable to train GRU RNNs, because the GRU architecture induces memory space collapse during SMT training, degrading RNN rollout.

![ **Sequential Compute and Data Efficiency.** We sweep training hyperparameters for BPTT, SMT, and SMT$\rightarrow$DMT and plot the resulting runs' performance along sequential compute (SeqFLOPs) used and data processed (Tokens), across different RNN architectures and datasets. Runs are capped at one day on an H200 GPU. $^*$ signifies that the SMT Encoder is the teacher Transformer (*not* an RNN) and is used only as a reference. Generally, SMT and SMT$\rightarrow$DMT are more efficient than BPTT in sequential compute, and around the same or better efficiency in data. ](figures/seqflops.png){#fig:seqflops width="1.0\\linewidth"}

## Scaling Laws {#sec:scalinglaws_experiments}

<figure id="fig:model_scaling">
<img src="figures/memory_Tc_scaling.png" />
<img src="figures/model_scaling.png" />
<figcaption> <strong>Scaling Model Size.</strong> Sweeping the width and depth of the RNN and teacher shows smooth performance improvements in TinyStories. The RNN imitates the teacher performance better at larger scale. </figcaption>
</figure>

We evaluate the scaling behavior of SMT$\rightarrow$DMT along three axes: context length, memory state size, and model parameter count. For the first two, we logarithmically sweep $T_c$ and the number of memory tokens in the Transformer-based RNN (Figure `\ref{fig:architecture}`{=latex}). For model scaling, we vary the width and depth of the RNN, encoder, and decoder. We use the TinyStories domain for these experiments.

Figure `\ref{fig:memory_Tc_scaling}`{=latex} shows that SMT$\rightarrow$DMT exhibits smooth, predictable performance improvements with larger context length and bigger memory state size. Together with the previous experiments, these results reaffirm that SMT effectively leverages long contexts and large memory states. Figure `\ref{fig:model_scaling}`{=latex} presents the parameter scaling results. The SMT encoder follows a standard power-law-like scaling trend. The SMT$\rightarrow$DMT RNN also improves smoothly with scale, albeit with a differently shaped scaling curve. Interestingly, the RNN appears to more closely match the encoder's performance at larger scales.

## Compression as a Scaling Axis {#sec:compression_experiments}

```{=latex}
\begin{wrapfigure}{r}{0.45\textwidth}


    \includegraphics[width=1.0\linewidth]{figures/compression_scaling.pdf}

    \caption{
    \label{fig:compression_scaling}
    \textbf{Scaling Laws for Compression.}
    We plot iso-loss contours for SMT-trained encoder models across a range of memory state sizes and training compute budgets.
    For a fixed target performance, SMT can achieve higher compression (smaller memory size) using additional compute.
    This result suggests a new property to scale when given more training compute: memory state compression.
    }

\end{wrapfigure}
```
Neural scaling laws predict the relationship between a resource (e.g. compute, data) and a desired property (e.g. validation loss, benchmark accuracy) [@kaplan2020scaling; @hoffmann2022training]. Can the desired property instead be *compression* [@hutter2005universal]? For RNNs, compression can be interpreted as achieving the same performance with a smaller memory state size. Thus, to answer this question, we train a set of SMT models on TinyStories across a sweep of memory state sizes and training compute budgets.

Figure `\ref{fig:compression_scaling}`{=latex} shows the scaling curve, confirming that SMT can achieve more compression when allocated more compute. Since compression is often speculated as being a core property of intelligent systems [@solomonoff1964formal; @kolmogorov1968three], scaling along this compression axis may be a desired direction forward for future sequence models. Notably, Transformers perform no compression of the past [@gu2025tradeoffs], which may explain their training efficiency.

## Ablations {#sec:ablations_experiments}

#### Predictive State and Detached RNN {#sec:detach_rnn_experiments}

The impact of the predictive state objective (Equation `\ref{eq:loss_dec}`{=latex}) is evaluated by sweeping the future length $T_f$, while keeping $T_c$ large enough to see the whole sequence. The impact of the dynamics objective (Equation `\ref{eq:loss_dyn}`{=latex}) on memory representation is tested by detaching the model computation graph with stop grads at two locations such that the gradients from $\mathcal{L}_t^\text{dyn}$ flow to the RNN, but *not* the encoder (**detached**); the non-detached SMT baseline is referred to as **joint**. This ablation isolates the contribution of explicitly training $m_t$ to be a Markovian representation.

Figure `\ref{fig:needle_tf_sweep}`{=latex} shows the results on the needle retrieval task. To solve the task, and thus have proper credit assignment, SMT requires either large enough $T_f$ *or* joint training. When $T_f$ is large enough, there is a $\mathcal O(1)$ credit path length between the needle and the answer at all timesteps. Interestingly, when $T_f$ is small, there exists no credit path to learn early timestep memories, yet joint training still learns effectively, even when $T_f=1$. Credit must be propagating through the RNN dynamics from $m_{T}$ to $m_{T-1}$, and so on, to $m_0$. But because the RNN is never unrolled, there is no computation graph for credit to propagate directly. The only explanation is that **credit is being amortized into gradient optimization steps**. Each optimization step sends information from $m_t$ to $m_{t-1}$ through $f_\theta$; $T$ such gradients steps sends information $T$ steps back in the sequence. This implies that solving $T$ sequence length credit assignment task when $T_f=1$, requires at least $T$ *gradient optimization steps*. This credit amortization phenomenon is reminiscent of value bootstrapping in RL [@sutton1998reinforcement].

<figure id="fig:gradient_vanishing">
<img src="figures/needle_Tf_sweep.png" />
<img src="figures/vanishing_gradients.png" />
<figcaption> <strong>Gradient Properties of BPTT and SMT.</strong> In the needle retrieval task, the loss is applied at the last timestep. BPTT propagates gradients backward through all timesteps, risking vanishing/exploding gradients for each <span class="math inline"><em>m</em><sub><em>t</em></sub></span>, depending on the weight initialization. SMT is non-recurrent and has a <span class="math inline">𝒪(1)</span> credit path length, making its gradients agnostic to initialization and time-horizon. </figcaption>
</figure>

#### $\lambda$ Coefficients

The values of $\lambda_\text{dyn}$ and $\lambda_\text{unif}$ are swept here to check their effects. Figure `\ref{fig:lambda_ablation}`{=latex} shows the results. The best RNNs require $\lambda_\text{dyn}= 0.1$, and $\lambda_\text{unif}=0.001$. When $\lambda_\text{unif}=0$, although the RNN performance is preserved, the memory space is collapsed, as indicated by $\mathcal L^\text{unif}$.

#### Drift and DMT

As described in Section `\ref{sec:dmt}`{=latex}, RNN suffers from drift post-SMT. Figure `\ref{fig:dmt_analysis}`{=latex} shows an analysis of drift and DMT's mitigation of it. From a dynamical systems perspective, DMT seems to discover RNNs which have an initially higher drift, but which *plateau* at a much lower equilibrium drift. Interestingly, this equilibrium drift value is not fully predicted by the one-step drift, inviting future investigations into predicting and mitigating rollout drift in one-step during SMT.

![ **Impact of DMT** across many runs with different SMT $\lambda_\text{dec}$ and $\lambda_\text{dyn}$ hyperparameters. **Left:** Applying DMT reduces the drift of the RNN rollout (measured with $1-R^2$ of RNN memory prediction $\hat m_t$ of encoder ground truth $m_t$). **Middle:** DMT significantly improves RNN performance across settings. **Right:** The one-step drift of the RNN only partially correlates with the rollout drift. ](figures/drift_analysis.png){#fig:dmt_analysis width="1.0\\linewidth"}

## Analysis {#sec:analysis_experiments}

#### Gradient Properties of BPTT and SMT {#sec:vanishing_grads}

The fundamental difference between BPTT and SMT in long-range credit assignment is dictated by their gradients. Figure `\ref{fig:gradient_vanishing}`{=latex} shows the gradient magnitude of $m_t$, $\|\frac{\partial L}{\partial m_t}\|$, at different $t$ for both methods with different model weight initializations on the needle retrieval task. In BPTT, gradients vanish or explode over time, due to BPTT's gradient propagation through recurrent modules. In SMT, gradient magnitudes are independent of $t$, because the credit path length between tokens is independent of the sequence length. This result explains why SMT does not suffer a recency bias and is able to do stably perform long-horizon credit assignment.

```{=latex}
\begin{wrapfigure}{r}{0.45\textwidth}


    \includegraphics[width=1.0\linewidth]{figures/seqlen_generalization.pdf}
    \caption{
    \label{fig:seqlen_generalization}
    \textbf{Sequence Length Generalization.}
    An SMT$\rightarrow$DMT trained RNN generalizes better than its Transformer teacher when evaluated on sequence lengths longer than training.
    The task is synthetic state tracking.
    }

\end{wrapfigure}
```
#### Benefit of RNNs over Transformers

SMT trains an RNN to mimic a Transformer encoder model, raising the question of why an RNN is needed at all, given the Transformer. RNNs are qualitatively more efficient than Transformers at inference, requiring $O(1)$ rather than $O(T)$ memory and compute per generated token (Table `\ref{tab:resources_bigo}`{=latex}). RNNs also constitute a more expressive class of models [@merrill2026linear; @liu2025serial].

Here, we compare an SMT$\rightarrow$DMT RNN against a Transformer on the synthetic stack state tracking task. For a fair comparison, we use the SMT encoder as the Transformer baseline, since it imposes the same memory-information bottleneck as the RNN.

Figure `\ref{fig:seqlen_generalization}`{=latex} shows the Transformer outperforms the RNN on training sequence lengths, but significantly underperforms the RNN on sequence lengths longer than training. Prior work on length generalization reports similar findings [@press2021train]. This result reflects the distinct inductive biases of the architectures: Transformers behave like growing lookup tables in context, while RNNs update finite states [@gu2025tradeoffs]. The latter is a better inductive bias for generalization.

#### Memory Space

To better understand what SMT is learning, we train smaller SMT models that have a 2D memory state and directly visualize their memory space across three synthetic tasks in Figure `\ref{fig:memory_space}`{=latex}. In the retrieval tasks, SMT learns to collapse many sequence states into only a few effective memory states: an initial state, a state indicating the next token is the needle, and states corresponding to the needle value. Then, the RNN learns finite-state machine behavior to transition between these states. In contrast, string copying requires lossless sequence compression and thus SMT cannot alias distinct memory states together. It learns to create a tree-like memory geometry to store all possible sequences, matching the tree structure of all possible strings. Figure `\ref{fig:mnist_memory_pca}`{=latex} and Figure `\ref{fig:mnist_memory_tsne}`{=latex} show memory visualizations for models trained on MNIST. These results indicate SMT memories form effective temporal *abstractions* of the past depending on what the future requires.

<figure id="fig:memory_space">
<img src="figures/memory_space.png" />
<figcaption> <strong>Memory Space Visualization.</strong> The encoder learns different memory geometries for different tasks. In Retrieval, the encoder collapses many sequence states into a few memory states, creating finite-state machine like behavior. In String Copy, the encoder constructs a tree-like memory geometry to compress all possible sequences. Some geometries induce more complex RNN transition fields. </figcaption>
</figure>

# Related Works {#sec:related_works}

#### Recurrent Neural Networks (RNNs)

RNNs were studied extensively early in AI because their recurrence mechanism resembles biological brains [@mcculloch1943logical] and can be applied to any sequential task [@elman1990finding]. Many different algorithms were proposed for learning, including random guessing [@schmidhuber2001evaluating], evolutionary algorithms [@miller1989designing; @angeline1994evolutionary; @stanley2002evolving; @salimans2017evolution; @sarkar2025evolution], hebbian learning [@hebb1949organization; @hopfield1982neural; @miconi2018differentiableplasticitytrainingplastic; @najarro2022metalearninghebbianplasticityrandom], real-time recurrent learning [@williams1989learning], and other algorithms [@ollivier2015training; @bengio2015scheduledsamplingsequenceprediction; @lamb2016professor; @bai2019deep; @kag2021training]. BPTT is the only widely adopted algorithm [@werbos1990backpropagation].

However, it has repeatedly been shown that BPTT produces unstable gradients that vanish, explode, or exhibit high variance [@bengio1994learning; @hochreiter2001gradient; @pascanu2013difficulty; @bai2018empirical]. Several directions address this issue. One direction focuses on architectural modifications, including residual connections [@srivastava2015highway; @he2016deep] and gating mechanisms [@chung2014empirical], culminating in the development of the LSTM [@hochreiter1997long] and GRU [@cho2014properties]. A parallel direction addressed gradient instability through orthogonal weight parameterizations to prevent exponential growth or decay across time [@saxe2013exact; @arjovsky2016unitary; @wisdom2016full; @mhammedi2017efficient; @vorontsov2017orthogonality; @helfrich2018orthogonal]. Others explored external memory [@peshkin2001learning; @graves2014neural; @hansen2017longtimescalecreditassignment], hierarchical modeling [@hihi1995hierarchical; @chung2016hierarchical; @wang2025hierarchicalreasoningmodel; @jolicoeurmartineau2025morerecursivereasoningtiny], other unique directions [@jaeger2001echo; @lukovsevivcius2009reservoir; @miller2018stable].

Recently, there has been renewed interest in RNNs in the form of linear state space models [@gu2021efficiently; @smith2022simplified; @gu2023mamba], linear attention models [@katharopoulos2020transformers; @sun2023retentive; @dao2024transformers; @yang2025parallelizinglineartransformersdelta; @yang2025gateddeltanetworksimproving], and even nonlinear RNN models [@beck2024xlstm; @bulatov2022recurrent; @mishra2026m; @oncescu2026recurrent]. Recurrent computation more generally has been reappearing across paradigms including in diffusion [@ho2020denoising], looped Transformers [@giannou2023looped], and reasoning [@hao2025traininglargelanguagemodels; @geiping2025scaling].

#### Time-Parallel Training

Transformers revolutionized sequence modeling [@brown2020language] largely because they have time-parallel training [@vaswani2017attention] (unlike prior attention methods [@bahdanau2016neuralmachinetranslationjointly]), which is crucial for leveraging modern hardware [@hillis1986data; @hooker2021hardware] to scale performance [@kaplan2020scaling; @hoffmann2022training]. Linear RNNs gained popularity [@katharopoulos2020transformers; @dao2024transformers; @yang2025parallelizinglineartransformersdelta; @feng2024were] after it was realized they can be parallelized with the associative scan algorithm [@blelloch1990prefix; @martin2018parallelizinglinearrecurrentneural; @wang2020bppsascalingbackpropagationparallel].

A recent line of work attempts to parallelize nonlinear RNNs as well [@lim2024parallelizingnonlinearsequentialmodels; @danieli2023deeppcr]. Rather than computing $[m_0,\dots,m_T]$ with $m_{t+1}=f_\theta(m_t)$, they formulate the forward pass as an iterative optimization procedure. Starting with an initial guess $[m_0^0,\dots,m_T^0]$, they construct a system of $T$ equations, $\{m_{t+1} - f_\theta(m_t)= 0\}_{t=0}^T$, and solve this system with Newton's method [@ortega2000iterative]. Many works have further built on this approach [@danieli2025pararnnunlockingparalleltraining; @gonzalez2024towards]. Although appealing, this approach approximates BPTT and hence will suffer from its $\mathcal{O}(T)$ credit path length and corresponding gradient instability, along with the added convergence worries of Newton's method [@gonzalez2025predictability]. In contrast, SMT uses an encoder to train $m_t$ to be a predictive state while satisfying $m_{t+1}\approx f_\theta(m_t)$ and providing an $\mathcal{O}(1)$ credit path length.

#### Computation Complexity Class of Models

A model's architecture determines the problems it can theoretically solve [@hornik1989multilayer]. Some tasks are inherently sequential and cannot be efficiently parallelized [@amdahl1967validity]; the **circuit depth** of a task is the minimum number of sequential steps required to solve it on an infinitely parallel computer [@ruzzo1981uniform; @cook1985taxonomy]. Every neural network has a corresponding **sequential depth**---the longest nonlinear computation path from input to output---which bounds the class of problems it can solve [@merrill2024expressivepowertransformerschain]. Models with constant or logarithmic sequential depth per layer, such as Transformers and linear RNNs, are provably limited to tasks with equivalently low circuit depth [@merrill2023parallelism; @merrill2024expressivepowertransformerschain; @yau2026sequentialparalleldualityprefixscannable]. While such models succeed on tasks amenable to parallelization (e.g. parity tracking via associative scan [@liu2022transformers; @li2025howlanguagemodelstrack]), they systematically fail on tasks requiring deep sequential computation (e.g. tracking a chess board [@merrill2024illusion]). Interestingly, the aspect that makes models parallelizable, limits their performance on harder problems [@merrill2023parallelism]. Nonlinear RNNs are one of the few classes of models where its sequential depth grows with the input sequence length [@merrill2026linear]. Although these constraints were seen as theoretical, there is growing evidence they affect models in practice as well [@liu2025serial].

In SMT, we train a nonlinear RNN (which is fully expressive), using a time-parallel teacher Transformer (which has limits). We note this limitation but argue that SMT is a pretraining algorithm, which should be used with a lightweight post-training algorithm to solve downstream tasks [@gan2026neural].

#### Predictive State Representations (PSRs)

A PSR is a way of modeling a partially observed dynamical system by representing its state only in terms of predictions about future observations [@littman2001predictive; @singh2012predictive], a representation that is sufficient for optimal decision making [@singh2003learning]. Early works interpreted PSRs as a literal vector of probabilities of future events, but have since been generalized [@downey2017practical]. Belief states are a related concept, which also defines a sufficient statistic of the past [@kaelbling1998planning].

PSRs have been previously incorporated into RNNs [@downey2017predictivestaterecurrentneural; @hefny2018recurrentpredictivestatepolicy]. @venkatraman2017predictivestatedecodersencodingfuture introduce an auxiliary objective for RNNs that trains hidden states to predict statistics of future observations using a decoder. However, these works still unroll the RNN and use BPTT, and thus are not time-parallelizable and have a $\mathcal{O}(T)$ credit path.

#### Other Related Work

Our work is related to the literature on cross-architecture teacher-student distillation [@kasai2021finetuningpretrainedtransformersrnns; @wang2025mamballamadistillingaccelerating; @hauzenberger2026effectivedistillationhybridxlstm; @chen2026hybridlinearattentionright; @moudgil2026attention], but these works do not address the challenges of training nonlinear RNNs. In concurrent work, @teoh2025nextlatentpredictiontransformerslearn introduced Next-Latent Prediction (NextLat), which trains an RNN with memory state supervision from a Transformer. With a particular setting of the hyperparameters, SMT and NextLat are equivalent. However, @teoh2025nextlatentpredictiontransformerslearn focuses its experiments on the case where the latent representations are still trained with BPTT, which can be optionally truncated to $T=1$, whereas we focus primarily on the $T=1$ case. Additionally, NextLat's goal is to regularize the Transformer to learn compact world models, rather than training the RNN without BPTT. The Recurrent Transformer is an RNN architecture that attends to all past hidden states, creating an $\mathcal{O}(1)$ gradient path that stabilizes credit assignment [@oncescu2026recurrent]. However, because it retains all past hidden states, its memory grows unboundedly during inference---making it more akin to a Transformer than a fixed-memory RNN. Crucially, training still requires sequential unrolling and BPTT. SMT, by contrast, replaces BPTT and supports arbitrary fixed-memory RNN architectures and enables time-parallel training by never unrolling the RNN. Other works similarly combine Transformers with recurrent processing but also train with sequential unrolling and BPTT [@fan2020addressing; @bulatov2022recurrent; @wu2022memformer]. A new line of work uses principles from diffusion models to train blocks of a feed-forward network in parallel, avoiding global backpropagation [@li2025noprop; @shing2025diffusionblocks].

# Discussion {#sec:discussion}

In SMT, the teacher model is time-parallel, and is thus constrained in expressivity [@merrill2023parallelism], implying that SMT-trained RNNs may suffer the same problem. Therefore, BPTT finetuning might be required to achieve expressivity beyond the teacher. Additionally, while SMT is useful for learning how to encode sequences, it is not necessarily to be used for learning reasoning since intermediate steps are not supervised. The same limitation applies to Transformers yet post-training allows them to effectively solve longer-horizon tasks than the training horizon; the same might be true for SMT-trained RNNs.

The current SMT variant computes and trains only a single $m_t$ within a sequence. We found that training on all memories $[m_0, \dots, m_T]$ offered no improvement in our settings, but this may not hold at larger scales. After SMT, the RNN experiences drift away from the teacher memory trajectory. DMT provides one solution but is not time-parallel; however, it may be parallelized via DEER [@lim2024parallelizingnonlinearsequentialmodels].

RNNs have the promise of solving problems that extend over unbounded horizons, such as the entire lifetime of an agent. However, training methods for RNN have been hindered by the inability of BPTT to assign credit effectively over such a long horizon. Our method circumvents the credit assignment issue with an $\mathcal O(1)$ connection path. In the regimes we studied, this effectively allows for learning memories that are only useful many steps later, an ability that is crucial for lifelong learning.

```{=latex}
\newpage
```
```{=latex}
\begin{ack}

This work was supported by an NSF GRFP Fellowship to A.K., a Packard Fellowship and Sloan Research Fellowship to P.I., and ONR MURI grant N00014-22-1-2740.
This work was also supported under project ID 43 as part of the Swiss AI Initiative, through a grant from the ETH Domain and computational resources provided by the Swiss National Supercomputing Centre (CSCS) under the Alps infrastructure.
\\
We thank Alyosha Efros for suggesting the Attneave framing for pixel sequence modeling and recommending the Sketchy dataset.
We thank Alexander Huth for initially motivating A.K. to work on memory many years ago.
We thank Assaf Ben-Kish for reviewing an earlier draft of this manuscript.
We thank Han Guo and Oliver Sieberling for technical advice on algorithmic complexity.
\end{ack}
```
```{=latex}
\bibliographystyle{plainnat}
```
```{=latex}
\appendix
```
```{=latex}
\newpage
```
# Definitions {#sec:appendix_definitions}

#### Credit Assignment Path Length

For any differentiable computational graph, backpropagation propagates gradients from the scalar loss backward through the graph to each leaf node (typically model weights). We define the **credit assignment path length** as the maximum distance between any two nodes (e.g. tokens) in the computation graph. Distance is measured as the number of intervening non-identity operations that modify gradients (e.g. matrix multiplications or nonlinearities). The longer this path, the less effective backpropagation is for properly learning associations between distant nodes and assigning credit [@bengio1994learning]. Under this definition, BPTT has $\mathcal O(T)$ credit assignment path length, whereas Transformers and SMT have $\mathcal O(1)$ path length between any two tokens.

#### Sequential Computation (measured in SeqFLOPs)

Sequential computation is the amount of serial (non-parallelizable) work required to complete a computation. Some computations may require substantial total work but little sequential work (e.g. matrix multiplication). As parallel hardware such as GPUs continues to scale, total work matters less than the amount of inherently sequential work required [@hooker2021hardware].

Sequential compute is measured by analyzing the computation graph required to execute an algorithm and computing the graph's *critical path*: the number of floating point operations that must be executed sequentially on an infinitely parallel computer. We refer to this quantity as **sequential FLOPs** (SeqFLOPs).

For simplicity, we estimate SeqFLOPs by counting the number of sequential atomic deep learning operations (e.g., Linear, LayerNorm) executed over the course of the algorithm. The true SeqFLOPs, measured in floating-point operations, is proportional to this estimate.

We compute SeqFLOPs for BPTT, SMT, and DMT. SMT is fully parallelizable in time, incurring $\mathcal{O}(1)$ SeqFLOPs per optimization step, independent of the sequence length $T$. In contrast, BPTT and DMT require unrolling the RNN, which increases SeqFLOPs to $\mathcal{O}(T)$ per optimization step.

#### Data Processed (measured in Tokens)

Data Processed is defined as the total number of tokens processed during training, including repeated tokens in multi-epoch settings.

# Experiment Details {#sec:appendix_experiment_details}

## Architectures

Our primary architecture used for most experiments is shown in Figure `\ref{fig:architecture}`{=latex}.

![ **Model Architecture for SMT.** **Left:** The encoder reads the input context tokens and a set of learned register tokens, and outputs the memory, $m_t$, which is a set of memory tokens. The decoder takes in this memory and the *future input* tokens and predicts the *future output* tokens, using a causal mask. This setup forces information from the context to be compressed into a memory that is useful for predicting the future outputs, given future inputs. **Middle:** Our RNN maps $(m_t, x_{t+1})$ to $m_{t+1}$ using a Transformer-backbone. Since the memory is a list of tokens and the input is a token, we simply use a full attention Transformer to transform the current memory into the next timestep's memory. **Right:** Readout is performed by a full attention Transformer over the memory tokens. ](figures/architecture.png){#fig:architecture width="1.0\\linewidth"}

#### Encoder Architecture

We use the same encoder model architecture across all experiments. The model begins with an embedding layer to embed input tokens, and a list of learned memory token registers. The input and register tokens are concatenated and processed by a stack of *bidirectional* (full attention mask) Transformer blocks. We use bidirectional model because the goal is to create a holistic representation of the entire input sequence. The register tokens are then interpreted as memory tokens at the output. Note that a single memory consists of a list of memory tokens, $m_t=[m_t^1,\dots,m_t^M]$. Within a Transformer block, we use rotary position encodings [@su2024roformer] and RMSNorm instead of LayerNorm. We perform RMSNorm on the output memory tokens for stability. Figure `\ref{fig:architecture}`{=latex} shows the encoder architecture.

#### Decoder Architecture

We use the same decoder model architecture across all experiments. The decoder has an embedding layer to embed input tokens, which is weight shared with the encoder model. The memory tokens from the encoder and the embedded future input tokens are concatenated, and then processed by a stack of *causally masked* Transformer blocks. We use a causal mask because the goal is to learn a generative model of the output sequence. Within the each Transformer block, we use rotary position encodings [@su2024roformer] and RMSNorm instead of LayerNorm. The output predictions are read out at token positions such that $\hat y_{t+k}$ is a function of only $m_t$ and $x_{t+1},\dots,x_{t+k}$. Figure `\ref{fig:architecture}`{=latex} shows the decoder architecture.

#### Transformer-based RNN Architecture

The Transformer-based RNN is our primary RNN architecture and is used for most experiments. It begins with an embedding layer for the current timestep's input token. The memory tokens are concatenated with the input token and then processed by a stack of *bidirectional* (full attention mask) Transformer blocks to produce the output memory tokens. We perform RMSNorm on the output memory tokens.

#### MLP-based RNN Architecture

The MLP-based RNN flattens the list of memory tokens into a single vector, concatenates it with the input token embedding, and passes them through an MLP. At the output, the model then unflattens them to be a list of memory tokens again. We perform RMSNorm on the output memory tokens.

#### GRU-based RNN Architecture

The GRU-based RNN processes the $M$ memory tokens with a $M$ layer stacked GRU. Layer $l$ reads a single memory token, $m_t^l$, and outputs a single memory token for the next timestep, $m_t^{l+1}$. We do *not* RMSNorm on the output memory tokens, since that would undermine the GRU's residual structure.

#### RNN Readout Architecture

We use the same readout model architecture for all RNNs. The readout architecture takes in the memory tokens and processes them through a stack of *bidirectional* (full attention mask) Transformer blocks Figure `\ref{fig:architecture}`{=latex} shows the readout architecture.

## Datasets

### Synthetic Tasks {#sec:synth_task_desc}

#### Retrieval to test *Gradient Stability*

The retrieval task requires the model to remember and reproduce the token immediately following a designated identifier (token $0$). For example, given $\mathbf{x}=[3,4,0,2,1,3,1,0]$ the target is $\mathbf{y}=[\emptyset,\emptyset,\emptyset,\emptyset,\emptyset,\emptyset,\emptyset,2]$, where $\emptyset$ denotes no prediction target. With probability $p$, the label is corrupted to a random token. By varying **sequence length** and **noise level**, this task probes the algorithm's capacity for stable gradient credit assignment.

#### String Copy to test *Memory Capacity*

The string copy task requires the model to reproduce a sequence in reverse order after a delimiter (token $0$). For example, given $\mathbf{x}=[3,4,1,2,0,0,0,0]$, the target is $\mathbf{y}=[\emptyset,\emptyset,\emptyset,\emptyset,2,1,4,3]$. By varying the **sequence length** and the **memory state size**, this task measures the algorithm's ability to leverage the RNN's memory capacity for memorization.

#### Stack Operations to test *State Tracking*

The stack operations task requires the model to track the top element of a stack through a sequence of push and pop operations (denoted by token $0$). For example, given $\mathbf{x}=[1,0,2,3,0,1,0,0]$, the target is $\mathbf{y}=[\emptyset,1,\emptyset,\emptyset,3,\emptyset,1,2]$. By varying the **sequence length** and **state complexity** (maximum stack depth), this task evaluates the algorithm's capacity for state tracking.

#### Keys and Values to test *Associative Recall*

The keys and values task requires the model to store and retrieve associations between keys and values, then recall the value corresponding to a queried key [@willshaw1969non; @hinton2014parallel]. For example, given $\mathbf{x}=[b,1,a,3,d,2,4,a]$, the target is $\mathbf{y}=[\emptyset,\emptyset,\emptyset,\emptyset,\emptyset,\emptyset,\emptyset,3]$. By varying the **number of associations** and **association complexity** (string length of keys and values), this task evaluates the algorithm's capacity for associative recall.

#### Modular Arithmetic to test *In-Context Learning*

The modular arithmetic task requires the model to infer a latent linear rule from in-context examples and apply it to novel inputs [@akyurek2022learning; @kirsch2022general]. For each sequence, parameters $a$ and $b$ are sampled, then the sequence is presented as $\mathbf{x}=[x_0,y_0,x_1,y_1,x_2,y_2,x_3,y_3,]$, where $y_i=(a x_i + b) \mod V$, where $V$ is the vocabulary size. Then, the target is $\mathbf{y}=[y_0,\emptyset,y_1,\emptyset,y_2,\emptyset,y_3,\emptyset]$. By varying the **difficulty** (range of values $a$, $b$ can take on) and the number of **in-context examples**, this tasks evaluates the algorithm's ability to induce in-context learning.

### Natural Tasks

#### TinyStories

TinyStories is a curated dataset of short stories generated by OpenAI's GPT-4 [@eldan2023tinystories]. We use ASCII character-level tokenization, yielding a vocabulary of 256 tokens. Under this tokenization, the training and test sets contain 1.9B and 19.2M tokens, respectively.

#### MNIST

MNIST is a classic image dataset consisting of handwritten digits [@lecun1998mnist]. Rather than performing classification or 2D image generation, we consider the problem of 1D pixel-sequence modeling. The original 28$\times$`<!-- -->`{=html}28 images are flattened into sequences of length 784 using raster-scan ordering. Each image is represented as a sequence of raw grayscale pixel intensities (0--255), yielding a vocabulary of 256 tokens. The training and test sets contain 47M and 7.8M tokens, respectively.

#### Sketchy

Sketchy is an image dataset of human-drawn sketches [@sangkloy2016sketchy]. Rather than performing classification or 2D image generation, we consider the problem of 1D pixel-sequence modeling. The original images are resized to 64$\times$`<!-- -->`{=html}64 using Lanczos resampling, and the pixels are binarized. Non-overlapping 2$\times$`<!-- -->`{=html}2 patches are tokenized, yielding a vocabulary of $2^{2\times2}=16$ tokens. The resulting image is flattened in raster-scan order to form sequences of length 1024. The training and test sets contain 69.5M and 7.7M tokens, respectively.

## Algorithms {#sec:appendix_algorithm_details}

For all experiments, we use the AdamW optimizer with a weight decay of 0.01 and learning rates tuned separately for each algorithm. For all methods, gradients are clipped to a maximum global norm of $1$. Gradient clipping is expected to be particularly beneficial for BPTT.

After SMT, we transfer the decoder weights to the RNN readout module. During DMT, this readout head is further finetuned to optimize the next-token prediction loss using RNN-generated memory states. Importantly, this task loss updates only the readout head and not the RNN dynamics function, and therefore does not constitute temporal credit assignment for the RNN itself. Instead, finetuning serves to adapt the readout head to imperfections in the memory states generated by the RNN.

#### Synthetic Experiments

We set $T_c = T_f = T$. To evaluate the expectation in $\mathcal{L}^\text{smt}$, we compute the loss terms at all timesteps $t\in[0, \dots, T]$. For earlier timesteps, where the available past context is shorter than the required context length, we pad the sequence and modify the attention mask so that padding tokens are ignored. The same procedure is applied to the future context. In these synthetic experiments, the prediction loss is applied only at positions where target output tokens are defined (e.g. the answer token in the needle task). We use batch sizes of 32 sequences during optimization, but this gets expanded to $32\times T$ input contexts to the encoder.

#### Other Experiments

To evaluate the expectation in $\mathcal{L}^\text{smt}$, we compute the loss term at a single timestep sampled uniformly from $t \sim \mathcal U[0,\dots,T]$. The dataset is represented as one long sequence, meaning padding is not required, as both the past and future contexts extend indefinitely. We compute the $\mathcal{L}^\text{unif}$ over batches of memories from different sequences.

By default for SMT, we set $T_c = 256$, $T_f = 64$, $\lambda_\text{dec}=1.0$, $\lambda_\text{dyn}=0.1$, $\lambda_\text{unif}=0.001$ and train for $150000$ SGD iterations. Unless otherwise specified, models use a hidden dimension of 256 with 16 memory tokens. The encoder is 8 layers deep, while the decoder is 4 layers deep. The RNN is also 8 layers deep, and its readout function is 4 layers deep. We use a batch size of 128 sequences.

# Additional Experiments

Figure `\ref{fig:lambda_ablation}`{=latex} shows the results of ablating the $\lambda_\text{dyn}$ and $\lambda_\text{unif}$. Results show the optimal RNN performance requires a moderate dynamics loss, paired with a very low uniformity loss. However, a little uniformity is critical for avoiding memory space collapse.

Figure `\ref{fig:mnist_generations_extra}`{=latex} shows more samples of generations from Figure `\ref{fig:mnist_generations}`{=latex}. Samples generated by the BPTT RNN (Transformer backbone) seem to only pick up on short range context and act accordingly: either output large streaks of white or black based on the current row. BPTT RNN (GRU backbone) improves this significantly, but still fails to capture the nuanced structure of digits. SMT$\rightarrow$DMT RNN (Transformer backbone) is able to capture this structure quite well.

Figure `\ref{fig:sketchy_generations_extra}`{=latex} shows more samples of generations from Figure `\ref{fig:sketchy_generations}`{=latex}. These generations are often not fully interpretable, but do capture the stroke structure of human-drawn sketches. Capturing this stroke structure is itself a difficult problem, given the long-horizon nature of pixel sequence modeling.

Figure `\ref{fig:cat_read}`{=latex} provides an analysis of our Sketchy RNN as it \`\`reads" a sequence corresponding to the classic Attneave's cat image [@attneave1954some]. The memory sequence does not seem to be fully interpretable, but does show significant structure. Figure `\ref{fig:cat_write}`{=latex} provides samples of generations when conditioned on partial context of Attneave's cat image. Figure `\ref{fig:mnist_memory_pca}`{=latex}, `\ref{fig:mnist_memory_tsne}`{=latex}, `\ref{fig:sketchy_memory_tsne}`{=latex}, show additional analysis of the RNNs on MNIST and Sketchy data.

# Compute Resources Used {#sec:appendix_compute_resources}

All individual training runs were conducted on one H200 GPU within 48 hours. The synthetic experiments comprised more than 375 small-scale training runs, while the real-data experiments required 144 large-scale runs.

![ **Sweep of $\lambda_\text{dyn}$ and $\lambda_\text{unif}$.** Cell color indicates the RNN test loss for each setting. Top number in each cell is the RNN test loss. Bottom number in each cell shows the $\mathcal L^\text{unif}$. $\mathcal L^\text{unif}$ varies from $0$ (collapsed latent space) to $-4$ (fully uniform latent space). ](figures/lambda_ablation.png){#fig:lambda_ablation width="1.0\\linewidth"}

<figure id="fig:mnist_generations_extra">
<p><img src="figures/mnist_generations_all_dataset.png" /><br />
<img src="figures/mnist_generations_all_bptt.png" /><br />
<img src="figures/mnist_generations_all_bptt_gru.png" /><br />
<img src="figures/mnist_generations_all_smtdmt.png" /></p>
<figcaption> <strong>Additional MNIST Samples.</strong> Here we give more examples of samples of MNIST images generated by the various methods. SMT<span class="math inline">→</span>DMT RNN outperforms BPTT, even when BPTT is applied on a GRU architecture, in processing long-horizon information, which is required for pixel modeling. </figcaption>
</figure>

<figure id="fig:sketchy_generations_extra">
<p><img src="figures/sketchy_generations_all_dataset.png" /><br />
<img src="figures/sketchy_generations_all_smtdmt.png" /></p>
<figcaption> <strong>Additional Sketchy Samples.</strong> Here we give more examples of samples of Sketchy images from the dataset and generated by SMT<span class="math inline">→</span>DMT. Even in this hard sparse domain, SMT<span class="math inline">→</span>DMT can capture the overall stroke structure, which requires integrating information over hundreds of pixels. </figcaption>
</figure>

![ **Analysis on Attneave's Cat.** We apply the SMT$\rightarrow$DMT-trained RNN on Sketchy and evaluate it on the classic image of Attneave's cat. The RNN reads the image pixel-by-pixel in raster scan order. **Top Left:** Input image presented in its original 2D form. **Top Middle:** 3D t-SNE projection of the RNN memory state, visualized as RGB values over time, showing the evolution of memory throughout sequence processing. **Top Right:** 2D t-SNE projection of the memory state trajectory over time. **Middle:** The same image presented as a flat token sequence. From the RNN's perspective, the task resembles modeling a barcode-like sequence, requiring long-range associations between distant tokens and highlighting the difficulty of pixel sequence modeling. **Bottom:** 3D t-SNE projection of the memory state visualized along the flattened sequence. ](figures/cat_read.png){#fig:cat_read width="1.0\\linewidth"}

![ **Generations of Attneave's Cat.** We apply the SMT$\rightarrow$DMT-trained RNN on Sketchy and apply it to generate part of the image of Attneave's cat. Given more of the image context, the RNN seems to understand the image better and make somewhat more plausible predictions. ](figures/cat_write.png){#fig:cat_write width="1.0\\linewidth"}

```{=latex}
\newpage
```
![ **RNN Memory Evolution on MNIST (PCA)**. We analyze the memory evolution of our SMT$\rightarrow$DMT MNIST RNN. **Left**: Input image presented in its original 2D form. **Middle**: 3D PCA projection of the RNN memory state, visualized as RGB values over time, showing the evolution of memory during processing. **Right**: 2D PCA projection of the memory state trajectory over time. ](figures/mnist_memory_pca.png){#fig:mnist_memory_pca width="0.7\\linewidth"}

![ **RNN Memory Evolution on MNIST (t-SNE)**. We analyze the memory evolution of our SMT$\rightarrow$DMT MNIST RNN. **Left**: Input image presented in its original 2D form. **Middle**: 3D t-SNE projection of the RNN memory state, visualized as RGB values over time, showing the evolution of memory during processing. **Right**: 2D t-SNE projection of the memory state trajectory over time. ](figures/mnist_memory_tsne.png){#fig:mnist_memory_tsne width="0.7\\linewidth"}

![ **RNN Memory Evolution on Sketchy (t-SNE)**. We analyze the memory evolution of our SMT$\rightarrow$DMT Sketchy RNN. **Left**: Input image presented in its original 2D form. **Middle**: 3D t-SNE projection of the RNN memory state, visualized as RGB values over time, showing the evolution of memory during processing. **Right**: 2D t-SNE projection of the memory state trajectory over time. ](figures/sketchy_memory_tsne.png){#fig:sketchy_memory_tsne width="1.0\\linewidth"}

```{=latex}
\newpage
```
# Sequence to Set Reframing {#sec:set_reframing}

As described in Section `\ref{sec:smt}`{=latex}, consider a hypothetical oracle memory-encoding model $\mathcal{Q}$ that takes as input the sequence of tokens and outputs an effective compressed memory. Here we show that $\mathcal{Q}$ does *not* have to be a recurrent function over the sequence of tokens, but can instead be represented as a permutation-invariant function over a *set* of timestamped tokens.

#### Claim

Let $\mathbf{x}_\text{seq}=[x_0, x_1, \dots, x_t]$ be the original sequence of tokens. We define the set $\mathbf{x}_\text{set}=\{(x_0, 0), (x_1, 1), \dots, (x_t, t)\}$. Assume that $\mathcal{Q}$ is a recurrent function over $\mathbf{x}_\text{seq}$. In other words, $$m = \mathcal{Q}(\mathbf{x}_\text{seq})= f(\ldots f(f(m_{\emptyset}, x_0), x_1), \ldots , x_t)$$ with $m_{\emptyset}=\mathbf{0}$ for some function $f$. For any such $\mathcal Q$, $\exists$ a function $g$ such that $g(\mathbf{x}_\text{set}) = \mathcal{Q}(\mathbf{x}_\text{seq})$.

#### Proof

We construct $g$ explicitly. Define $g(\mathbf{x}_\text{set})$ as follows: given the input set $\mathbf{x}_\text{set} = \{(x_0, 0), (x_1, 1), \dots, (x_t, t)\}$, sort the elements in ascending order of their timestamp to recover the sequence $\mathbf{x}_\text{seq} = [x_0, x_1, \dots, x_t]$, then apply $\mathcal{Q}$ to this sequence.

This is well-defined because the timestamps $\{0, 1, \dots, t\}$ are distinct integers, so the sort order is unique. The resulting sequence is identical to the original $\mathbf{x}_\text{seq}$, and therefore $g(\mathbf{x}_\text{set}) = \mathcal{Q}(\mathbf{x}_\text{seq}) = m$.

Moreover, $g$ is permutation-invariant: any permutation of the elements of $\mathbf{x}_\text{set}$ yields the same sorted sequence and thus the same output.

Since $\mathcal{Q}$ was arbitrary, this construction applies to every recurrent $\mathcal{Q}$, completing the proof. $\blacksquare$

#### Implication

This result implies that any sufficiently expressive permutation-invariant set model can in principle exactly model a recurrent memory function. Because sets are unordered, time-parallel processing naturally follows. In particular, Transformer-based architectures can be interpreted as operating over sets of timestamped tokens rather than strictly ordered sequences.

Notably, the proof is constructive: $g$ recovers the sequential computation by sorting the timestamps and implicitly applying the recurrent update rule $f$ up to $t$ times. Consequently, when implemented with bounded-depth architectures such as Transformers, the required depth may need to scale with sequence length, consistent with prior work on sequential depth and time-parallel training discussed in Section `\ref{sec:related_works}`{=latex}. Scaling depth with sequence length seems to present a major theoretical limitation.

However, our empirical results suggest that even relatively shallow Transformer encoders can learn highly effective memory representations for both synthetic and natural tasks. Thus, despite lacking full theoretical expressivity, this sequence-to-set reframing may still provide a practical strategy for memory *pretraining*. For full expressivity, some light-weight post-training may be required.

```{=latex}
\newpage
```
# Encoder Markovian Training {#sec:markovian_encoder_proof}

SMT consists of two primary objectives: future predicting with $\mathcal{L}^{\text{dec}}$ and dynamics modeling with $\mathcal{L}^{\text{dyn}}$. The dynamics objective serves two purposes: (1) training the RNN to predict the next memory state from the current one, and (2) encouraging the encoder to produce memory states that are predictable from one another, i.e. approximately Markovian. In this section, we show that the predictive state objective $\mathcal{L}^{\text{dec}}$ alone is sufficient for learning Markovian memories, implying that $\mathcal{L}^{\text{dyn}}$ is theoretically unnecessary, though still practically useful.

#### Claim

Let $\mathbf{x} = [x_0, x_1, \dots, x_T]$ and $\mathbf{y} = [y_0, y_1, \dots, y_T]$ be input and output sequences. For each timestep $t$, define $$\mathbf{x}_t^{\text{ctx}} = [x_0, \dots, x_t], \qquad
  \mathbf{x}_t^{\text{fut}} = [x_{t+1}, \dots, x_T], \qquad
  \mathbf{y}_t^{\text{fut}} = [y_t, \dots, y_T],$$ with memory state $m_t = \mathcal{E}_\phi(\mathbf{x}_t^{\text{ctx}})$ and reconstructed future $\hat{\mathbf{y}}_t^{\text{fut}} = \mathcal{D}_\psi(m_t, \mathbf{x}_t^{\text{fut}})$. If $m_t$ is an *optimal* minimal sufficient statistic of $\mathbf{x}_t^{\text{ctx}}$ for predicting $\mathbf{y}_t^{\text{fut}}$ given $\mathbf{x}_t^{\text{fut}}$ at *every* $t$, then the memory sequence $(m_t)$ is **Markovian**: $$m_{t+1} \perp\!\!\!\perp \mathbf{x}_t^{\text{ctx}} \;\Big|\; m_t,\, x_{t+1}.$$

#### Proof

By optimality of $m_t$, it is a minimal sufficient statistic of $\mathbf{x}_t^{\text{ctx}}$ for predicting $\mathbf{y}_t^{\text{fut}}$ given $\mathbf{x}_t^{\text{fut}}$: $$H\!\bigl(\mathbf{y}_t^{\text{fut}} \mid m_t, \mathbf{x}_t^{\text{fut}}\bigr)
  = H\!\bigl(\mathbf{y}_t^{\text{fut}} \mid \mathbf{x}_t^{\text{ctx}}, \mathbf{x}_t^{\text{fut}}\bigr).$$ Note that $\mathbf{y}_{t+1}^{\text{fut}} \subseteq \mathbf{y}_t^{\text{fut}}$ and $\mathbf{x}_{t+1}^{\text{fut}} \subseteq \mathbf{x}_t^{\text{fut}}$, so optimality of $m_t$ at time $t$ implies it is also sufficient for $\mathbf{y}_{t+1}^{\text{fut}}$ given $\mathbf{x}_{t+1}^{\text{fut}}$: $$H\!\bigl(\mathbf{y}_{t+1}^{\text{fut}} \mid m_t, x_{t+1}, \mathbf{x}_{t+1}^{\text{fut}}\bigr)
  = H\!\bigl(\mathbf{y}_{t+1}^{\text{fut}} \mid \mathbf{x}_t^{\text{ctx}}, x_{t+1}, \mathbf{x}_{t+1}^{\text{fut}}\bigr)
  = H\!\bigl(\mathbf{y}_{t+1}^{\text{fut}} \mid \mathbf{x}_{t+1}^{\text{ctx}}, \mathbf{x}_{t+1}^{\text{fut}}\bigr).$$ By optimality of $m_{t+1}$, it is a minimal sufficient statistic of $\mathbf{x}_{t+1}^{\text{ctx}}$ for predicting $\mathbf{y}_{t+1}^{\text{fut}}$ given $\mathbf{x}_{t+1}^{\text{fut}}$. Minimality means $m_{t+1}$ retains no information from $\mathbf{x}_{t+1}^{\text{ctx}}$ beyond what is predictively necessary. Since $(m_t, x_{t+1})$ already constitutes a sufficient statistic for this same prediction task---as shown above---minimality of $m_{t+1}$ forces it to be a function of $(m_t, x_{t+1})$: $$m_{t+1} = f(m_t,\, x_{t+1})$$ for some measurable $f$. Therefore $m_{t+1}$ is determined entirely by $(m_t, x_{t+1})$, and conditioning on these renders it independent of all earlier context: $$H(m_{t+1} \mid m_t,\, x_{t+1}) = 0,$$ which is equivalent to $m_{t+1} \perp\!\!\!\perp \mathbf{x}_t^{\text{ctx}} \mid m_t, x_{t+1}$. Hence $(m_t)$ is Markovian. $\blacksquare$

#### Implication

This result establishes that under ideal conditions---sufficient encoder and decoder capacity, infinite future horizon, and exact optimization---the memory states learned by $\mathcal{E}_\phi$ form a Markov chain driven only by the previous state and the incoming token. In other words, the encoder *implicitly* learns a Markovian memory representation: $m_{t+1}$ can be predicted from only $(m_t, x_{t+1})$.

In practice, finite capacity and approximate optimization relax this property, leaving $m_{t+1}$ with residual dependence on $\mathbf{x}_t^{\text{ctx}}$ beyond $(m_t, x_{t+1})$. This gap motivates jointly training the dynamics loss $\mathcal{L}^{\text{dyn}}$ alongside $\mathcal{L}^{\text{dec}}$ to explicitly encourage Markovian structure in the learned memory sequence.

Relatedly, @teoh2025nextlatentpredictiontransformerslearn provide a proof that one-step RNN dynamics with a $T_f=1$ encoder also induce a predictive-state memory representation.
