---
abstract: |
  The quadratic cost of attention in transformers motivated the development of efficient approaches: namely sparse and sliding window attention, convolutions and linear attention. Although these approaches result in impressive reductions in compute and memory, they often trade off quality, specifically in-context recall performance. Moreover, fixing this quality-compute trade-off a priori in an architecture means being suboptimal from the get-go: some downstream applications require more memory for in-context recall, while others require lower latency and memory. Further, these approaches rely on heuristic choices that artificially restrict attention, or require handcrafted and complex recurrent state update rules, or they must be carefully composed with attention at specific layers to form a hybrid architecture that complicates the design process, especially at scale.

  To address the above issues, we propose the **C**ompress & **A**ttend **T**ransformer (`\cat`{=latex}), a conceptually simple architecture employing only two simple ingredients: dense attention and compression. `\cat `{=latex}decodes chunks of tokens by attending to compressed chunks of the sequence so far. Compression results in decoding from a reduced sequence length that yields compute and memory savings, while choosing a particular chunk size trades off quality for efficiency. Moreover, `\cat `{=latex}can be trained with multiple chunk sizes at once, unlocking control of quality-compute trade-offs directly at test-time without any retraining, all in a single adaptive architecture.

  In exhaustive evaluations on language modeling and common-sense reasoning, in-context recall, and long-context understanding, a single adaptive `\cat `{=latex}model outperforms many existing efficient baselines, including hybrid architectures, across different compute-memory budgets. Further, a single `\cat `{=latex}matches the dense transformer in language modeling across different model scales while being $1.4-3\times$ faster and requiring $2-9\times$ lower total memory usage.

  Play with `\cats `{=latex}at: [`github.com/rajesh-lab/cat-transformer`](https://github.com/rajesh-lab/cat-transformer)
author:
- |
  Jatin Prakash`\quad`{=latex} Aahlad Puli `\quad`{=latex} Rajesh Ranganath\
  New York University `\quad`{=latex}\
  `jatin.prakash@nyu.edu`
bibliography:
- iclr2026_conference.bib
title: Attention and Compression is all you need for Controllably Efficient Language Models
---

\DeclareMathOperator*{\argmax}{arg\,max}
\DeclareMathOperator*{\argmin}{arg\,min}
\newcommand{\DKL}{D_{\mathrm{KL}}}
\newcommand{\KL}{\mathrm{KL}}
\DeclareRobustCommand{\wass}[2]{\ensuremath{\mathcal{W}_1\left(#1\;\left.\right\|\;#2\right)}}
\DeclareRobustCommand{\dho}[2]{\ensuremath{\frac{\partial #1}{\partial #2}}}
\renewcommand{\mid}{~\vert~}
\newcommand{\g}{\,|\,}
\newcommand{\indep}{\,\rotatebox[origin=c]{90}{$\models$}\,}
\newcommand{\nindep}{\,\rotatebox[origin=c]{90}{$\not\models$}\,}
\newcommand{\dif}{\mathop{}\!\mathrm{d}}
\newcommand{\diag}{\textrm{diag}}
\newcommand{\supp}{\textrm{supp}}
\newcommand{\Gam}{\textrm{Gam}}
\newcommand{\InvGam}{\textrm{InvGam}}
\newcommand{\norm}[1]{\left\lVert#1\right\rVert}
\newcommand{\Func}{\mathcal{F}}
\newcommand{\Ker}{\mathcal{K}}
\newcommand{\St}{\mathcal{S}}
\newcommand{\Kernel}{\mathcal{K}_{k}}
\newcommand{\euclid}{\mathbb{R}}
\newcommand{\score}{s^{p}}
\newcommand{\DO}{\text{do}}
\newcommand{\ve}{{\mathtt{\epsilon}}}
\newcommand{\va}{\mathtt{a}}
\newcommand{\vz}{\mathtt{z}}
\newcommand{\vx}{\mathtt{x}}
\newcommand{\vy}{\mathtt{y}}
\newcommand{\vg}{\mathtt{g}}
\newcommand{\vt}{\mathtt{t}}
\newcommand{\gvt}{g(\mathtt{t})}
\newcommand{\Hb}{\mathbf{H}}
\newcommand{\MI}{\mathbf{I}}
\newcommand{\mbeps}{\mbepsilon}
\DeclareRobustCommand{\doot}[1]{\text{do}(\mbt=#1)}
\newcommand{\dinv}{d^{-1}}
\newcommand{\partialder}[2]{\frac{\partial #1}{\partial #2}}
\newcommand{\ind}{\mathbf{1}}
\newcommand{\mbxp}{{\mbx^\prime}}
\newcommand{\xp}{{x^\prime}}
\newcommand{\ince}{\mbI_{NCE}}
\newcommand{\ptr}{{p_{tr}}}
\newcommand{\pte}{{p_{te}}}
\newcommand{\pind}{{p_{\scaleto{\indep}{4pt}}}}
\newcommand{\pindprime}{{p^\prime_{\scaleto{\indep}{4pt}}}}
\newcommand{\varqa}{\sigma^2_{q_a}}
\newcommand{\varqb}{\sigma^2_{q_b}}
\newcommand{\meanqa}{\E_{q_a}}
\newcommand{\meanqb}{\E_{q_b}}
\newcommand{\var}{\text{var}}
\newcommand{\aprime}{{ ( 1 + a)}}
\newcommand{\aminus}{{ ( 1-a)}}
\newcommand{\bprime}{{(1 + b)}}
\newcommand{\bminus}{{ (1 - b)}}
\newcommand{\ruv}{r_{u,v}}
\newcommand{\pinda}{{p_{\scaleto{\indep}{4pt},1}}}
\newcommand{\pindb}{{p_{\scaleto{\indep}{4pt},2}}}
\newcommand{\perf}{\mathtt{Perf}}
\newcommand{\perfte}{\perf_\pte}
\newcommand{\sumybin}{\sum_{y\in \{0,1\}}}
\newcommand{\epsdiff}{\mbeps_2 -\mbeps_1}
\newcommand{\epssum}{\mbeps_1 + \mbeps_2}
\newcommand{\variance}{\sigma^2}
\DeclareRobustCommand{\mb}[1]{{\ensuremath{\boldsymbol{\mathbf{#1}}}}}
\newcommand{\mba}{\mathbf{a}}
\newcommand{\mbb}{\mathbf{b}}
\newcommand{\mbc}{\mathbf{c}}
\newcommand{\mbd}{\mathbf{d}}
\newcommand{\mbe}{\mathbf{e}}
\newcommand{\mbg}{\mathbf{g}}
\newcommand{\mbh}{\mathbf{h}}
\newcommand{\mbi}{\mathbf{i}}
\newcommand{\mbj}{\mathbf{j}}
\newcommand{\mbk}{\mathbf{k}}
\newcommand{\mbl}{\mathbf{l}}
\newcommand{\mbm}{\mathbf{m}}
\newcommand{\mbn}{\mathbf{n}}
\newcommand{\mbo}{\mathbf{o}}
\newcommand{\mbp}{\mathbf{p}}
\newcommand{\mbq}{\mathbf{q}}
\newcommand{\mbr}{\mathbf{r}}
\newcommand{\mbs}{\mathbf{s}}
\newcommand{\mbt}{\mathbf{t}}
\newcommand{\mbu}{\mathbf{u}}
\newcommand{\mbv}{\mathbf{v}}
\newcommand{\mbw}{\mathbf{w}}
\newcommand{\mbx}{\mathbf{x}}
\newcommand{\mby}{\mathbf{y}}
\newcommand{\mbz}{\mathbf{z}}
\newcommand{\mbA}{\mathbf{A}}
\newcommand{\mbB}{\mathbf{B}}
\newcommand{\mbC}{\mathbf{C}}
\newcommand{\mbD}{\mathbf{D}}
\newcommand{\mbE}{\mathbf{E}}
\newcommand{\mbF}{\mathbf{F}}
\newcommand{\mbG}{\mathbf{G}}
\newcommand{\mbH}{\mathbf{H}}
\newcommand{\mbI}{\mathbf{I}}
\newcommand{\mbJ}{\mathbf{J}}
\newcommand{\mbK}{\mathbf{K}}
\newcommand{\mbL}{\mathbf{L}}
\newcommand{\mbM}{\mathbf{M}}
\newcommand{\mbN}{\mathbf{N}}
\newcommand{\mbO}{\mathbf{O}}
\newcommand{\mbP}{\mathbf{P}}
\newcommand{\mbQ}{\mathbf{Q}}
\newcommand{\mbR}{\mathbf{R}}
\newcommand{\mbS}{\mathbf{S}}
\newcommand{\mbT}{\mathbf{T}}
\newcommand{\mbU}{\mathbf{U}}
\newcommand{\mbV}{\mathbf{V}}
\newcommand{\mbW}{\mathbf{W}}
\newcommand{\mbX}{\mathbf{X}}
\newcommand{\mbY}{\mathbf{Y}}
\newcommand{\mbZ}{\mathbf{Z}}
\newcommand{\mbalpha}{\mb{\alpha}}
\newcommand{\mbbeta}{\mb{\beta}}
\newcommand{\mbdelta}{\mb{\delta}}
\newcommand{\mbepsilon}{\mb{\epsilon}}
\newcommand{\mbchi}{\mb{\chi}}
\newcommand{\mbeta}{\mb{\eta}}
\newcommand{\mbgamma}{\mb{\gamma}}
\newcommand{\mbiota}{\mb{\iota}}
\newcommand{\mbkappa}{\mb{\kappa}}
\newcommand{\mblambda}{\mb{\lambda}}
\newcommand{\mbmu}{\mb{\mu}}
\newcommand{\mbnu}{\mb{\nu}}
\newcommand{\mbomega}{\mb{\omega}}
\newcommand{\mbphi}{\mb{\phi}}
\newcommand{\mbpi}{\mb{\pi}}
\newcommand{\mbpsi}{\mb{\psi}}
\newcommand{\mbrho}{\mb{\rho}}
\newcommand{\mbsigma}{\mb{\sigma}}
\newcommand{\mbtau}{\mb{\tau}}
\newcommand{\mbtheta}{\mb{\theta}}
\newcommand{\mbupsilon}{\mb{\upsilon}}
\newcommand{\mbvarepsilon}{\mb{\varepsilon}}
\newcommand{\mbvarphi}{\mb{\varphi}}
\newcommand{\mbvartheta}{\mb{\vartheta}}
\newcommand{\mbvarrho}{\mb{\varrho}}
\newcommand{\mbxi}{\mb{\xi}}
\newcommand{\mbzeta}{\mb{\zeta}}
\newcommand{\mbDelta}{\mb{\Delta}}
\newcommand{\mbGamma}{\mb{\Gamma}}
\newcommand{\mbLambda}{\mb{\Lambda}}
\newcommand{\mbOmega}{\mb{\Omega}}
\newcommand{\mbPhi}{\mb{\Phi}}
\newcommand{\mbPi}{\mb{\Pi}}
\newcommand{\mbPsi}{\mb{\Psi}}
\newcommand{\mbSigma}{\mb{\Sigma}}
\newcommand{\mbTheta}{\mb{\Theta}}
\newcommand{\mbUpsilon}{\mb{\Upsilon}}
\newcommand{\mbXi}{\mb{\Xi}}
\newcommand{\cA}{\mathcal{A}}
\newcommand{\cB}{\mathcal{B}}
\newcommand{\cD}{\mathcal{D}}
\newcommand{\cF}{\mathcal{F}}
\newcommand{\cG}{\mathcal{G}}
\newcommand{\cL}{\mathcal{L}}
\newcommand{\cN}{\mathcal{N}}
\newcommand{\cO}{\mathcal{O}}
\newcommand{\cP}{\mathcal{P}}
\newcommand{\cQ}{\mathcal{Q}}
\newcommand{\cR}{\mathcal{R}}
\newcommand{\cS}{\mathcal{S}}
\newcommand{\cT}{\mathcal{T}}
\newcommand{\cU}{\mathcal{U}}
\newcommand{\cX}{\mathcal{X}}
\newcommand{\cY}{\mathcal{Y}}
\newcommand{\E}{\mathop{\mathbb{E}}}
\newcommand{\G}{\mathbb{G}}
\newcommand{\bbH}{\mathbb{H}}
\newcommand{\bbN}{\mathbb{N}}
\newcommand{\bbQ}{\mathbb{Q}}
\newcommand{\bbR}{\mathbb{R}}
\newcommand{\bbS}{\mathbb{S}}
\newcommand{\bbZ}{\mathbb{Z}}
\newcommand{\Var}{\mathbb{V}\textrm{ar}}
\newcommand{\Cov}{\mathbb{C}\textrm{ov}}
\newcommand{\cmark}{\textcolor{green!60!black}{\ding{51}}}
\newcommand{\xmark}{\textcolor{red!70!black}{\ding{55}}}
\newcommand{\qmark}{\textcolor{orange!80!black}{\textbf{?}}}
\newcommand{\fixme}[1]{\textcolor{red}{\textbf{FIXME:} #1}}
\newcommand{\jats}[1]{{\textcolor{red}{{(#1)}}}}
\newcommand{\ud}[1]{\underline{#1}}
\newcommand{\cat}{\textsc{cat}\xspace}
\newcommand{\cats}{\textsc{cat}s\xspace}
\newcommand{\Cat}{\textsc{Cat}\xspace}
\newcommand{\Cats}{\textsc{Cat}s\xspace}
\renewcommand{\paragraph}{%
  \@startsection{paragraph}{4}%
  {\z@}{0.25ex \@plus 0ex \@minus -.3ex}{-1em}%
  {\normalfont\normalsize\bfseries}%
}
\maketitle
\definecolor{dgreen}{rgb}{0.0, 0.52, 0.34}
\newcommand{\dgreen}[1]{\color{dgreen}{#1}}
\definecolor{dblue}{rgb}{0.0, 0.32, 0.52}
\newcommand{\dblue}[1]{\color{dblue}{#1}}
\newcommand{\rr}[1]{{\dgreen{RR: #1}}}

# Introduction

<figure id="fig:pareto_frontier" data-latex-placement="h">
<img src="figures/pareto_frontier_v2.png" style="width:80.0%" />
<figcaption><strong>CAT unlocks test-time control of quality-efficiency trade-offs</strong>, where a <strong><u>single</u> adaptive model</strong> (all <strong><span style="color: red">red</span></strong> dots come from a single model) outperforms nearly every popular efficient architecture on real-world in-context recall tasks across varying compute-memory budgets.</figcaption>
</figure>

Transformers [@vaswani2017attention] are the default architectures for large language models (LLMs) and rely on the powerful self-attention mechanism [@bahdanau2014neural]. However, the compute required for decoding with self-attention grows quadratically with the sequence length, with memory costs growing linearly, making transformers expensive to deploy.

Given the cost of self-attention, there has been interest in designing efficient alternatives. Approaches like sparse and sliding window attention [@child2019generating; @zaheer2020big; @jiang2023mistral7b] heuristically restrict the tokens being attended to, while approaches like linear attention [@katharopoulos2020transformers; @arora2024simple; @dao2024transformers; @yang2025gated] rely on fixed-size recurrent states with complex state update rules to enable constant compute and memory costs. However, restricting attention to tokens a priori, or using fixed-size recurrent states that struggle to manage information over long sequences, hurts in-context recall performance [@arora2024simple; @jelassi2024repeat; @wen2024rnnstransformersyetkey]. To be performant, these approaches require careful composition with dense attention at specific layers, making the design process cumbersome, especially at scale [@waleffe2024empirical; @wang2025systematic]. Enabling efficiency by recursively compressing the sequence can avoid fixed-memory bottlenecks and heuristic restrictions [@Rae2020Compressive; @chevalier2023adapting], but the sequential computations in these approaches make training slow and optimization difficult [@geiping2025scaling].

\begin{wrapfigure}{r}{0.55\linewidth}
     %

    \includegraphics[width=\linewidth]{figures/cats_teaser_v2.pdf}
    \caption{
    \textbf{The Compress and Attend Transformer (\textsc{cat}) architecture.}
    \cat chunks up a sequence of length $N$ into $N/C$ chunks of $C$ tokens (illustrated for $C=3$).
    Each chunk is compressed in parallel into a chunk representation.
    \cat then decodes each chunk by attending to past chunk representations.
    Compression results in a reduced sequence length enabling compute and memory savings during decoding.
    Chunk size in \cat acts as a knob, offering test-time control of quality-efficiency trade-offs, where larger chunk sizes result in increased efficiency.
    }
    \label{fig:cat_figure}

\end{wrapfigure}

Additionally, existing approaches do not account for the differences in compute and memory requirements across diverse downstream tasks. For example, writing short email replies does not require strong in-context recall performance, so linear attention can be sufficient; but code auto-completion demands accurate recall of function names from the entire code repository in the context, where the higher memory and compute of dense attention may be preferred. The existing approaches for efficiency *fix* the compute-memory usage before training. This means that if a problem demands a higher budget for better performance at test time, a whole new model needs to be trained. Training multiple models with different trade-offs is one way to tackle this problem, but repeating this for every downstream task can quickly become prohibitive. Even if such models were available, routing between these models based on the context requires holding all of them in memory.

To address the issues raised above, we propose a conceptually simple architecture, the **C**ompress & **A**ttend **T**ransformer (`\cat`{=latex}), which employs two simple well-known ingredients, namely dense attention and compression. `\cat `{=latex}compresses chunks of tokens in parallel into a shorter sequence using a *compressor*, which a *decoder* then attends to while autoregressively modeling the tokens in the latest chunk; both compressor and decoder are simple dense transformers themselves (see `\Cref{fig:cat_figure}`{=latex}). With the compression and decoding being parallel over tokens during training, there is no recurrence along the sequence dimension, which enables end-to-end **scalable training**. Decoding from the reduced sequence length due to compression enables **compute and total memory savings**. This reduction allows the use of more parameters in `\cat`{=latex}, thereby improving model quality.[^1] Choosing a particular chunk size in `\cat `{=latex}trades off quality for compute and memory (see `\Cref{fig:pareto_frontier}`{=latex}). At the same time, the **memory grows *gracefully***, linearly with sequence length but at a significantly slower rate, to enable in-context recall performance at long sequence lengths. Importantly, training `\cats `{=latex}across multiple chunk sizes at once **unlocks control of quality-compute trade-offs directly at test-time** without any retraining, all in a single adaptive architecture.

To summarize, this paper:

- Introduces the `\cat `{=latex}architecture to efficiently model sequences by decoding each chunk of tokens given representations of the past chunks that are compressed in parallel. Adjusting a single knob (chunk size) at test-time controls quality-efficiency trade-offs, allowing a single `\cat `{=latex}model to *interpolate* between the dense transformer and efficient alternatives without any retraining.

- Provides a parallel and scalable implementation for training `\cats `{=latex}(we scale from 90M to 1B parameters) and an efficient pure PyTorch implementation for generation that does not require any custom CUDA or Triton kernels, unlike most efficient baselines.

  We release code at: [`github.com/rajesh-lab/cat-transformer`](https://github.com/rajesh-lab/cat-transformer)

- Demonstrates that a **single adaptive `\cat `{=latex}model**

  - outperforms many popular efficient baselines including hybrid architectures on language modeling, common-sense reasoning, long-context understanding, in-context recall, and needle-in-haystack tasks, across different compute and memory budgets.

  - matches or outperforms the dense transformer on language modeling at multiple model scales while being $\boldsymbol{1.4-3\times}$ faster and using a $\boldsymbol{2-9\times}$ smaller total memory footprint.

  - surpasses even the dense transformer on real-world in-context recall tasks using the least efficient setting (`\cat-4`{=latex}) while still being at least $1.5\times$ faster and $2\times$ more memory efficient, akin to MoEs [@shazeer2017outrageously].

*Brief outline of the paper:* `\Cref{sec:cats}`{=latex} first lays out `\cats`{=latex}, including the overall architecture design, training objective and implementation details regarding scalable training and efficient generation. This is followed by a brief discussion of related work, highlighting conceptual differences between `\cats `{=latex}and other efficient architectures in `\Cref{sec:related_work,tab:related_work}`{=latex}. Then, `\Cref{sec:experiments}`{=latex} exhaustively compares `\cats `{=latex}with efficient baselines on a plethora of downstream tasks. Finally, `\Cref{sec:discussion}`{=latex} ends with a discussion on the practical utility of `\cats `{=latex}and some future work.

# **C**ompress and **A**ttend **T**ransformers (`\cats`{=latex}) {#sec:cats}

In this section, we describe the components of the `\cat `{=latex}architecture and how it is trained for test-time control of trade-offs between quality and compute. We then discuss an efficient implementation of `\cats `{=latex}and the resulting compute and memory savings.

#### Compression and decoding.

Given a sequence $\mbx=(x_1, x_2, \dots, x_N)$ of $N$ tokens, we split the sequence into $N_C = N/C$ chunks $(\mbc_1, \mbc_2, \dots, \mbc_{N_C})$ containing $C$ tokens each, such that $\mbc_i = (x_{C \cdot (i-1)+1}, \dots, x_{C \cdot (i-1) + C})=(\mbx_{i,1}, \dots, \mbx_{i,C})=\mbx_{i,:}$, where $\mbx_{i,:}$ indexes the $i$-th chunk of $C$ consecutive tokens (in `numpy` array-slicing notation).
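The chunking step can be illustrated with a minimal sketch. It assumes the sequence length is an exact multiple of the chunk size; the helper name `chunk_sequence` and the toy integer tokens are illustrative and not taken from the released code. Python slices are 0-based, so chunk number `k` (counting from zero) covers `tokens[k*C:(k+1)*C]`.

```python
def chunk_sequence(tokens, chunk_size):
    """Split `tokens` into consecutive chunks of `chunk_size` tokens each."""
    if len(tokens) % chunk_size != 0:
        # Assumption of this sketch: N is a multiple of C.
        raise ValueError("sequence length must be a multiple of chunk_size")
    num_chunks = len(tokens) // chunk_size
    return [tokens[k * chunk_size:(k + 1) * chunk_size] for k in range(num_chunks)]


tokens = list(range(1, 13))  # stand-in for x_1, ..., x_12
chunks = chunk_sequence(tokens, chunk_size=3)
print(chunks)  # [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
```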

Next, `\cat `{=latex}compresses each chunk $\mbc_i$ into a chunk representation using the *compressor* $f_\theta$. The *compressor* $f_\theta$ is a dense bidirectional transformer with hidden size $D_{f}$, followed by a linear projection to $D_g$. This leads to a *compressed* chunk representation $f_\theta(\mbc_i)\in \mathbb{R}^{D_g}$. That is: $$\mbx = \{x_1, \cdots, x_N\}  \;\xrightarrow{\text{chunking}}\; \{\mbc_1, \cdots, \mbc_{N_C}\} \;\xrightarrow{\text{compress}}\; \{f_\theta(\mbc_1),  \cdots, f_\theta(\mbc_{N_C})\} .$$ After compression, `\cat `{=latex}decodes the original sequence $\mbx$ from the compressed chunk representations $\{f_\theta(\mbc_i)\}_{i=1}^{N_C}$ using a *decoder* $g_\theta$, which is a causal dense transformer with hidden size $D_g$, matching the linear projection from the *compressor*. `\cat `{=latex}decodes chunks autoregressively: to decode each token $\mbx_{i,j}$ in a chunk $\mbc_i$, the decoder takes as input the previous tokens $\{\mbx_{i,<j}\}$ in chunk $\mbc_i$ and the past chunk representations $\{f_\theta(\mbc_1), \dots, f_\theta(\mbc_{i-1})\}$. Formally, the predictive distribution $p_\theta$ for the tokens in chunk $\mbc_i$ is defined as: $$\begin{align}
p_\theta(\mbc_i\mid \mbc_{i-1} \cdots \mbc_{1})
&= \prod_{j=1}^C g_\theta \big( \underbrace{\mbx_{i,j}}_{j^{\text{th}}\ \text{token in chunk}\ \mbc_i} \mid \underbrace{ \mbx_{i,j-1} , \dots \mbx_{i,1}}_{\text{previous tokens in chunk } \mbc_i}, \underbrace{f_\theta(\mbc_{i-1}) \cdots f_\theta(\mbc_{1})}_{\text{past chunk representations}} \big)
\end{align}$$ By using compressed chunk representations, `\cat `{=latex}reduces the amount of compute and memory required for decoding; the larger the chunk size the larger the reduction in memory and compute.
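As a small illustration of the factorization above, the sketch below lists what the decoder conditions on when predicting the $j$-th token of chunk $i$, and compares the number of inputs with the number of past tokens a dense transformer would attend to at the same position. The function name `decoder_context` and the string labels are hypothetical and used only for illustration.

```python
def decoder_context(i, j):
    """Inputs visible when predicting the j-th token of chunk i (both 1-based)."""
    past_chunk_reps = [f"f(c_{k})" for k in range(1, i)]   # f(c_1), ..., f(c_{i-1})
    previous_tokens = [f"x_{i},{t}" for t in range(1, j)]  # x_{i,1}, ..., x_{i,j-1}
    return past_chunk_reps + previous_tokens


C = 3
context = decoder_context(i=3, j=2)
dense_context_size = C * (3 - 1) + (2 - 1)  # all past raw tokens at the same position
print(context)                              # ['f(c_1)', 'f(c_2)', 'x_3,1']
print(len(context), dense_context_size)     # 3 7
```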

During training, compression and decoding happen in parallel for all tokens in the sequence because compression of a chunk does not depend on earlier chunks. This choice allows the entire `\cat `{=latex}model to be efficiently trained end-to-end with the standard next-token prediction loss. The end-to-end training ensures that `\cats `{=latex}*learn what to retain* in their compressed chunk representations rather than relying on fixed attention patterns or complex state update rules.

#### Training for test-time control in compute and memory. {#sec:adaptive_cats}

Varying the chunk size in `\cats `{=latex}trades off quality for compute and memory efficiency. Training `\cat `{=latex}with multiple chunk sizes yields a single adaptive model whose compute-memory budget can be adjusted directly at test-time without any retraining.

To build such a controllably efficient `\cat `{=latex}model, we uniformly sample a chunk size $C$ at each training iteration, and pass in a *learnable* indicator token to `\cat `{=latex}to indicate which chunk size it is currently operating at. The compressed tokens are separated from the uncompressed ones in the decoder using a marker token shared across different chunk sizes. After training, one can use the same `\cat `{=latex}model at different compute/memory budgets at test-time by just changing the indicator token. `\Cref{app:adaptive_cat_train_details}`{=latex} provides further detail.
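The sampling scheme can be sketched as follows. The set of chunk sizes and the mapping from chunk size to an indicator-token index are assumptions made for illustration (the text only states that a chunk size is sampled uniformly per iteration and signalled by a learnable indicator token); the actual model and optimizer step are omitted.

```python
import random

# Hypothetical set of chunk sizes; the values used in the paper may differ.
CHUNK_SIZES = [4, 8, 16, 32]
# Hypothetical mapping from chunk size to the index of its learnable indicator token.
INDICATOR_INDEX = {c: idx for idx, c in enumerate(CHUNK_SIZES)}


def sample_iteration_config(rng):
    """Pick the chunk size (uniformly) and its indicator index for one training iteration."""
    chunk_size = rng.choice(CHUNK_SIZES)
    return chunk_size, INDICATOR_INDEX[chunk_size]


rng = random.Random(0)
counts = {c: 0 for c in CHUNK_SIZES}
for _ in range(10_000):
    chunk_size, indicator = sample_iteration_config(rng)
    counts[chunk_size] += 1  # a real loop would run one training step here
print(counts)  # each chunk size is drawn roughly 2,500 times
```

At test time, the same idea would apply in reverse: choosing a chunk size fixes the indicator index passed to the model.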

## How to implement fast and scalable `\cats`{=latex} {#sec:implement_cats}

As both components of `\cat `{=latex}are transformers, `\cat `{=latex}admits a pure PyTorch implementation for scalable training and fast generation, without requiring custom CUDA or Triton kernels. We describe the approach here.

**Fast and Parallel Compression.** Compression of chunks of tokens is efficient and can be executed in parallel, for instance by using `torch.vmap` [@functorch2021], to produce $\{f_\theta(\mbc_i)\}$ for all chunks $\mbc_i$. This costs a total of $O(\frac{N}{C}\cdot C^2)=O(NC)$ in self-attention compute, rather than $O(N^2)$.
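To make the $O(NC)$ count concrete, the sketch below tallies pairwise attention scores for chunk-wise compression and for dense attention over the full sequence. It counts query-key pairs only and ignores constants, heads and layers; the function names and the chosen chunk sizes are illustrative.

```python
def compressor_attention_cost(n, c):
    """Attention scores for n/c chunks, each attending bidirectionally within c tokens."""
    return (n // c) * c * c


def dense_attention_cost(n):
    """Attention scores for full attention over n tokens (mask ignored)."""
    return n * n


n = 4096
for c in (4, 8, 16, 32):
    cost = compressor_attention_cost(n, c)
    print(c, cost, dense_attention_cost(n) // cost)  # the saving factor equals n / c
```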

**Naive and Slow Training.** For training the decoder, a naive implementation can lead to slow training. Computing the logits for the tokens in every chunk $\mbc_i$ in parallel, that is computing $g_\theta (\mbc_i \mid f_\theta(\mbc_1) \cdots f_\theta(\mbc_{i-1}))$, is non-trivial: the number of past chunks varies with $i$, so tensor shapes vary across chunks and the computation of logits is harder to parallelize. One could employ a Python loop and compute logits for every chunk sequentially, but that would be slow and would not scale. Padding to make shapes constant would allow parallelism, but would make things worse by adding wasteful computation. In fact, even if one bypasses the varying-shapes problem and manages to compute logits for every chunk in parallel, the total self-attention operations in the decoder would scale as $O(\sum_{i=1}^{N_C}(i+C)^2)=O((\frac{N}{C})^3)$, that is, cubic in sequence length for a fixed chunk size. Thus, even the perfect parallel approach for training will not scale, despite `\cats `{=latex}being a simple architecture.
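The cubic growth can be checked numerically by evaluating the sum from the text directly. This is a sketch that counts attention scores only, with a fixed chunk size of 4 chosen for illustration: doubling the sequence length multiplies the naive cost by close to $2^3 = 8$.

```python
def naive_decoder_cost(n, c):
    """Evaluate sum_{i=1}^{n/c} (i + c)^2, the naive per-chunk decoder cost from the text."""
    num_chunks = n // c
    return sum((i + c) ** 2 for i in range(1, num_chunks + 1))


c = 4
for n in (1024, 2048, 4096):
    ratio = naive_decoder_cost(2 * n, c) / naive_decoder_cost(n, c)
    print(n, naive_decoder_cost(n, c), round(ratio, 2))  # ratio approaches 8
```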

#### Fast and Scalable Training.

To make training scalable in `\cats`{=latex}, we observe that in computing logits for every chunk $\mbc_i$, one calculates exactly the same key-value vectors for the representation $f_\theta(\mbc_j)$ in the decoder transformer, where $j<i$. This means that computation is duplicated. We exploit this observation in training `\cats`{=latex}.

At a high level, we implement this observation by modifying the original chunked sequence $\mbx=\{\mbc_1,\dots\mbc_i\dots\}$ to $\{\mbc_1,f_\theta(\mbc_1),\mbc_2,f_\theta(\mbc_2),\dots \mbc_i,f_\theta(\mbc_i)\dots\}$, that is, we insert the compressed representation of each chunk after the chunk of tokens itself. We then pass this sequence into the decoder during training, with a custom attention mask (App. Figure `\ref{fig:cat_attention_mask}`{=latex}) that allows a token in chunk $\mbc_i$ to attend to previous tokens within that chunk and *only* to previous chunk representations, which would be $f_\theta(\mbc_{i-1}), f_\theta(\mbc_{i-2})\dots f_\theta(\mbc_1)$. A token in chunk $\mbc_i$ does not attend to raw tokens outside this chunk. This implementation allows re-use of key-values for chunk representations $f_\theta(\mbc_i)$ in the decoder for computing logits of a future chunk $\mbc_j$, where $j>i$. This way of computing logits is still quadratic in sequence length, but it is a constant factor cheaper: $O(\frac{N^2}{C})$ vs. the $O(N^2)$ complexity of the dense transformer, allowing for potentially faster pre-training (see `\Cref{sec:discussion,app:training_details}`{=latex} for a discussion).
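The mask described above can be sketched in plain Python for a toy configuration. This is an illustration under stated assumptions, not the released implementation: the interleaved layout places the $C$ tokens of a chunk followed by one representation slot, `True` means "may attend", and the rule for representation rows (each representation attends to itself and to earlier representations only) is an assumption, since the text specifies the rule for token rows only. The helper names `build_cat_mask` and `render` are hypothetical.

```python
def build_cat_mask(num_chunks, chunk_size):
    """Boolean mask over the interleaved sequence c_1, f(c_1), c_2, f(c_2), ..."""
    layout = []  # one (chunk_index, is_representation) entry per position
    for i in range(num_chunks):
        layout.extend([(i, False)] * chunk_size)
        layout.append((i, True))

    mask = []
    for q, (q_chunk, q_is_rep) in enumerate(layout):
        row = []
        for k, (k_chunk, k_is_rep) in enumerate(layout):
            if k > q:
                allowed = False  # causal: never attend to later positions
            elif k_is_rep:
                # Representations of strictly earlier chunks, or the position itself.
                allowed = k_chunk < q_chunk or k == q
            else:
                # Raw tokens are visible only to token queries of the same chunk.
                allowed = (not q_is_rep) and k_chunk == q_chunk
            row.append(allowed)
        mask.append(row)
    return mask


def render(mask):
    return ["".join("1" if a else "." for a in row) for row in mask]


mask = build_cat_mask(num_chunks=3, chunk_size=2)
print("\n".join(render(mask)))
```

In PyTorch, a boolean mask of this form could be passed as `attn_mask` to `torch.nn.functional.scaled_dot_product_attention`, where `True` marks positions that take part in attention.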

**Fast and Efficient Generation.** Due to compression, `\cats `{=latex}can throw away past chunks of tokens and keep only their compressed chunk representations in memory. This immediately results in a large reduction in memory: the KV cache shrinks by a factor of $C$, which is substantial even for a modest chunk size of 4 (see `\Cref{fig:generation_throughput}`{=latex}). This reduction in memory is accompanied by fewer memory accesses by the decoder in `\cats`{=latex}; memory access is the major bottleneck during generation. The decoder attends to at most $N_C+C$ tokens during generation, significantly reducing the compute required in self-attention.
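A quick calculation shows how the $N_C + C$ bound compares with a dense KV cache of $N$ entries. The sketch counts cache entries per layer only; it does not account for hidden-size differences between models, so it is not a statement about bytes. The function name and the chunk sizes are illustrative.

```python
def cat_cache_entries(n, c):
    """Upper bound on decoder cache entries: n/c chunk representations plus c current tokens."""
    return n // c + c


n = 4096  # a dense transformer would keep n entries per layer
for c in (4, 8, 16, 32):
    entries = cat_cache_entries(n, c)
    print(c, entries, round(n / entries, 2))  # reduction factor relative to n
```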

Implementing generation is simpler than training and very similar to how it occurs for a dense transformer. In fact, a pure PyTorch implementation[^2] for `\cats `{=latex}is on par with efficient architectures that utilize custom kernels. Given a sequence, `\cats `{=latex}first compute representations for each chunk in parallel and use them to prefill the decoder's KV cache. Then generation proceeds chunk by chunk: each new chunk is decoded token by token in the decoder, and once a chunk is complete, the chunk is compressed and its representation is prefilled in the KV cache for the generation of the next chunk. This loop continues until the sequence is fully generated.
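The control flow of this loop can be sketched with toy stand-ins for the two networks. Everything here is hypothetical scaffolding: `toy_compress` and `toy_next_token` replace the compressor and a decoder step with trivial arithmetic, a Python list plays the role of the KV cache, and the prompt length is assumed to be a multiple of the chunk size. Only the bookkeeping (prefill, per-token decoding, compress-on-completion, dropping raw tokens) mirrors the description above.

```python
def toy_compress(chunk):
    """Hypothetical stand-in for the compressor f_theta: one summary per chunk."""
    return sum(chunk)


def toy_next_token(chunk_reps, current_tokens):
    """Hypothetical stand-in for one decoder step over the cached context."""
    return (sum(chunk_reps) + sum(current_tokens) + 1) % 50


def generate(prompt, num_new_tokens, chunk_size):
    if len(prompt) % chunk_size != 0:
        raise ValueError("this sketch assumes the prompt is a multiple of chunk_size")
    # Prefill: compress every prompt chunk (done in parallel in a real implementation).
    chunk_reps = [toy_compress(prompt[s:s + chunk_size])
                  for s in range(0, len(prompt), chunk_size)]
    current, output, max_context = [], [], 0
    for _ in range(num_new_tokens):
        max_context = max(max_context, len(chunk_reps) + len(current))
        token = toy_next_token(chunk_reps, current)
        current.append(token)
        output.append(token)
        if len(current) == chunk_size:
            # Chunk complete: keep only its representation and drop the raw tokens.
            chunk_reps.append(toy_compress(current))
            current = []
    return output, chunk_reps, max_context


output, chunk_reps, max_context = generate(list(range(8)), num_new_tokens=8, chunk_size=4)
print(len(output), len(chunk_reps), max_context)  # 8 4 6
```

The largest context seen by a decoder step stays below the bound of $N_C + C$ entries, here $16/4 + 4 = 8$.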

The full implementation details, along with PyTorch-style pseudocode, are in `\Cref{app:pseudo_code,app:generation_details}`{=latex}.

# Related Work {#sec:related_work}

We provide a brief summary of the most relevant related work in this section. `\Cref{tab:related_work}`{=latex} highlights key properties and conceptual differences between `\cats `{=latex}and other methods. For an extended discussion of related work, refer to Appendix `\ref{app:extended_related_work}`{=latex}.

`\footnotesize   `{=latex} `\setlength{\tabcolsep}{4pt}`{=latex} `\renewcommand{\arraystretch}{1.3}`{=latex}

\begin{tabular}{|p{0.22\linewidth}|p{0.12\linewidth}:p{0.11\linewidth}:p{0.11\linewidth}:p{0.18\linewidth}:p{0.12\linewidth}|}
\hline
\textbf{Method} & \textbf{Unrestricted Access to Memory?} & \textbf{Flexible memory?} & \textbf{Scalable training?} & \textbf{Both compute \& memory efficient?} & \textbf{Adaptive?} \\
\hline
\textbf{\textit{Dense Attention}}: \cite{vaswani2017attention}
  & \cmark{} %
  & \cmark
  & \cmark
  & \xmark
  & \xmark \\
\hline
\textbf{\textit{Sparse Attention}}: \cite{child2019generating}
  & \xmark{} %
  & \cmark
  & \cmark
  & \cmark
  & \xmark \\
\hline
\textbf{\textit{NSA}}: \cite{yuan2025native}
  & \cmark{} %
  & \cmark
  & \cmark
  & \xmark
  & \xmark \\
\hline
\textbf{\textit{Sliding window Attn.}}: \cite{jiang2023mistral7b}
  & \xmark
  & \xmark
  & \cmark
  & \cmark
  & \xmark \\
\hline
\textbf{\textit{Linear Attention}}: \cite{dao2024transformers}
  & \cmark{} %
  & \xmark
  & \cmark
  & \cmark
  & \xmark \\
\hline
\textbf{\textit{Recursive compression}}: \cite{chevalier2023adapting}
  & \cmark
  & \cmark
  & \xmark{} %
  & \cmark
  & \xmark \\
\hline
\textbf{\textit{MegaByte/Block Transformer}}: \cite{ho2024block, yu2023megabyte}
  & \cmark
  & \xmark
  & \cmark
  & \cmark
  & \xmark \\
\hline
\textbf{\textit{CATs}}
  & \cmark %
  & \cmark %
  & \cmark %
  & \cmark
  & \cmark \\
\hline
\end{tabular}

#### Efficient self-attention using custom masks:

These techniques include *heuristically* defined *fixed* sparse or stratified attention masks [@child2019generating; @zaheer2020big] or local sliding window masks [@jiang2023mistral7b] that artificially restrict the tokens being attended to in self-attention. The compute required (and in some attention masks, memory) for attention goes down during generation, but if the *wrong* attention mask is chosen for the task, these methods will be less performant or would require more depth [@arora2024simple]. To match the quality of a dense transformer, these models either require big window sizes (making their memory costs large again) or need to be composed with dense attention at specific layers [@arora2024simple; @agarwal2025gpt].

#### Compressing past context:

[@Rae2020Compressive; @chevalier2023adapting] explored recurrent formulations of a transformer to enable generation of longer sequences on limited compute and memory by compressing past context. But sequential training is slow and memory intensive, making these approaches hard to scale on modern hardware that favors parallel computations. Moreover, training models in a recurrent fashion has optimization challenges, back-propagation through time (BPTT) being the most important one. More recently, [@geiping2025scaling] had to use a very careful recipe to train a large recurrent architecture in a stable manner and prevent optimization collapse.

Alternatively, Native Sparse Attention (NSA) [@yuan2025native] reduces attention compute by attending to compressed chunks of tokens as well as to specific chunks of uncompressed tokens in the past. These past tokens are compressed in parallel in every layer. This is similar in spirit to our work; however, there are no memory savings during inference since the KV cache needs to be retained for the entire past context, so there are only compute savings.

#### Linear attention:

[@arora2024simple; @katharopoulos2020transformers] linearize self-attention by replacing softmax-based attention with kernelized dot-product-based linear attention, which further admits a linear recurrence form. Recent enhancements incorporate data-dependent gating mechanisms in the recurrence [@dao2024transformers; @yang2025gated]; all require handcrafted and complicated recurrent state update rules. These architectures show impressive reductions in compute and memory, but the fixed-size recurrent state struggles to manage information over long sequences, which hurts in-context recall performance [@arora2024simple; @jelassi2024repeat; @wen2024rnnstransformersyetkey]. To make these mixers competitive, they are usually composed with long sliding window attention at specific layers [@yang2025gated]. How to perform such a composition is unclear and requires careful *trial-and-error* [@waleffe2024empirical; @qwen_blog_2025], making the design process for an efficient architecture cumbersome, especially at scale [@wang2025systematic].

#### Hierarchical transformers:

[@nawrot2021hierarchical; @nawrot2022efficient; @slagle2024spacebyte] explored a downsample-then-upsample approach (an *hour-glass*-like structure), where the sequence is downsampled into *coarse* tokens followed by upsampling into *fine-grained* tokens before being decoded. Due to the *hour-glass* structure, there are compute savings during training; but the architecture must maintain a cache for all the past tokens, leading to significant memory accesses (especially for *fine-grained* ones), which is the main bottleneck during generation.

Unlike the above, [@ho2024block; @yu2023megabyte] break up the modeling of a sequence into independent chunks/patches, given a single compressed representation of the entire past. While compression helps efficiency, the requirement to decode each chunk from a single, fixed-size compressed representation results in poor in-context recall even on simple toy tasks (Fig. `\ref{fig:block_transformer_fails}`{=latex}). Further, unlike the original encoder-decoder architectures that attend directly to past tokens [@raffel2020exploring; @vaswani2017attention], the decoder in `\cat `{=latex}attends to the compressed representations of chunks of tokens in the past.

`\cats `{=latex}sidestep many limitations of the existing efficient baselines described above. Firstly, `\cats `{=latex}do not require any heuristic choices for attention masks, complex state update rules, or careful composition with attention layers to have competitive performance; `\cats `{=latex}directly build on simple dense transformer abstractions. Secondly, `\cats `{=latex}alleviate the fixed memory bottleneck by having flexible but efficient memory usage: their memory grows *gracefully* as sequence length increases, resulting in superior in-context recall performance, despite using similar memory overall compared to fixed memory baselines (Table `\ref{tab:swde_fda_results}`{=latex}). Thirdly, `\cats `{=latex}admit scalable training where compression and decoding can happen in parallel. Finally, `\cats `{=latex}enable control of quality-compute trade-offs at test-time, allowing them to cater to diverse downstream tasks with different compute-memory budgets. This is similar in spirit to [@kusupati2022matryoshka; @devvrit2023matformer; @beyer2023flexivit]. `\Cref{tab:related_work}`{=latex} provides a brief summary.

Refer to `\Cref{app:extended_related_work}`{=latex} for an extended related work.

# Experiments {#sec:experiments}

**Baselines:** Our experiments provide a comprehensive comparison of recent architectures, including (i) attention-based baselines: the standard Dense Transformer [@touvron2023llama] and the Sparse Transformer [@child2019generating], (ii) Linear Transformers such as Mamba2 [@dao2024transformers] and GatedDeltaNet ([gdn]{.smallcaps}) [@yang2025gated], as well as (iii) hybrid architectures such as [GDN]{.smallcaps}-Hybrid, the hybrid variant of [gdn]{.smallcaps} in which alternate layers use long sliding-window attention.

All baselines use $L=12$ layers with a hidden size of $D=1024$, keeping their parameter count at no more than $\sim300$M, except the Sparse Transformer, which uses $\sim800$M parameters due to a hidden size of $2D=2048$ for a fair comparison with `\cats `{=latex}(as we will see below). [GDN]{.smallcaps}-Hybrid employs a sliding window of $2$K, following [@yang2025gated]. Refer to Appendix `\ref{app:training_details}`{=latex} for more details regarding the hyperparameters used for each baseline.

**Training setup:** All models were trained on 15B tokens of FineWeb-Edu [@penedo2024fineweb] which is $2.5\times$ the Chinchilla optimal, with a context length of 4K following [@behrouz2024titans; @yang2025gated]. We use the AdamW optimizer [@loshchilov2017decoupled] with a peak learning rate of 8e-4, weight decay of 0.1, gradient clipping of 1.0, batch-size of 0.5M tokens, employing the GPT2 tokenizer (see Appendix `\ref{app:training_details}`{=latex} for more details).
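The budget figures above can be sanity-checked with a few lines of arithmetic. The sketch below is only illustrative: the 20-tokens-per-parameter rule of thumb for Chinchilla optimality, the reading of "0.5M tokens" as $2^{19}$, and the reading of "4K" as 4096 are assumptions that the text does not state, and the function names are hypothetical.

```python
import math

# Figures quoted in the text.
N_PARAMS = 300e6        # "~300M" baseline parameter count
TRAIN_TOKENS = 15e9     # 15B training tokens

# Assumptions made for this sketch (not stated in the text).
TOKENS_PER_PARAM = 20   # common Chinchilla rule of thumb
BATCH_TOKENS = 2**19    # "0.5M tokens" read as 524,288
CONTEXT_LEN = 4096      # "4K" read as 4096


def chinchilla_multiple(train_tokens, n_params, tokens_per_param=TOKENS_PER_PARAM):
    """Training tokens expressed as a multiple of the rule-of-thumb optimum."""
    return train_tokens / (tokens_per_param * n_params)


def sequences_per_batch(batch_tokens, context_len):
    """Number of full-length sequences that fit in one batch."""
    return batch_tokens // context_len


def optimizer_steps(train_tokens, batch_tokens):
    """Number of optimizer steps needed to consume the token budget."""
    return math.ceil(train_tokens / batch_tokens)


if __name__ == "__main__":
    print(chinchilla_multiple(TRAIN_TOKENS, N_PARAMS))     # multiple of the optimum
    print(sequences_per_batch(BATCH_TOKENS, CONTEXT_LEN))  # sequences per batch
    print(optimizer_steps(TRAIN_TOKENS, BATCH_TOKENS))     # total steps
```

Under these assumptions the token budget comes out at $2.5\times$ the rule-of-thumb optimum for a $\sim300$M-parameter model, in line with the figure quoted above.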

\footnotesize

`\setlength{\tabcolsep}{3pt}`{=latex}

  **Model**              **LMB**$\downarrow$   **Wiki**$\downarrow$   **FW**$\downarrow$    **HS**$\uparrow$   **PQ**$\uparrow$   **AE**$\uparrow$   **AC**$\uparrow$   **WG**$\uparrow$   **OQA**$\uparrow$   **Avg.**$\uparrow$
  --------------------- --------------------- ---------------------- --------------------- ------------------ ------------------ ------------------ ------------------ ------------------ ------------------- --------------------
  Dense                         38.7                   19.6                  17.1                 34.8               65.6               56.7               24.4               51.1               20.0                 42.1
  Sparse                        37.2                   18.5           `\ud{16.0}`{=latex}         35.6               66.8               57.3               25.4               51.1               22.8                 43.2
  Mamba2                        36.1                   19.5                  16.7                 36.1               67.0               59.2               26.5               51.9               21.6                 43.7
  [GDN]{.smallcaps}           **35.7**                 18.8                  16.3                 36.1               66.8               58.7               25.2               51.6               22.8                 43.5
  [GDN]{.smallcaps}-H           36.6                   18.5                  16.2               **36.8**             66.3               56.4               25.8               52.1               20.4                 43.0
  CAT-4                         38.0                 **18.1**         `\ud{16.0}`{=latex}         35.6               66.4               59.5             **27.1**             51.5               23.4                 43.9
  CAT-8                         37.2                 **18.1**              **15.8**               35.4               66.8               60.1             **27.4**             51.3               23.6                 44.1
  CAT-16                        36.8           `\ud{18.4}`{=latex}    `\ud{16.0}`{=latex}         35.5               67.3               60.2               27.0               52.0               23.8                 44.3
  CAT-32                        36.8                   19.1                  16.4                 35.9             **68.2**           **61.0**             27.0             **53.6**           **25.0**             **45.1**

  : Zero-shot perplexity and accuracy on language modeling and common-sense reasoning benchmarks. Note that `\cat-4`{=latex}/8/16/32 reported here is a single model. {#tab:lm_eval}

\captionsetup{font=footnotesize,labelfont=footnotesize}
\begin{wrapfigure}[18]{r}{0.6\textwidth}

        \includegraphics[width=\linewidth]{figures/benchmark.png}
        \captionsetup{font=footnotesize,labelfont=footnotesize}
        \caption{\textbf{A single \cat model generates $1.4-3.2\times$ faster than the dense transformer while showcasing up to $2.2-9.5\times$ lower memory usage}.
        Per \cref{tab:swde_fda_results}, \Cat-8 outperforms \textsc{gdn}-Hybrid in real-world recall tasks while being faster and requiring similar memory; \cat-16 outperforms Mamba2 and \textsc{gdn} and is $1.15\times$ faster but costs slightly ($\sim15\%$) more memory.
        }
        \label{fig:generation_throughput}
\end{wrapfigure}
\begin{wraptable}[4]{r}{0.38\textwidth}


        \setlength{\tabcolsep}{3pt}
        \begin{tabular}{l|ccc}
        \toprule
        \textbf{Model} & \textbf{SWDE} & \textbf{FDA} & \textbf{Avg.} \\
        \midrule
        Dense     & 43.4 & 19.7 & 32.0 \\
        Sparse    & 20.9 & 6.0 & 13.0 \\
        Mamba2    & 13.5 & 4.5 & 9.0 \\
        \textsc{gdn}       & 18.0 & 6.8 & 12.0 \\
        \textsc{gdn}-Hybrid      & \ud{44.0} & 17.8 & 31.0 \\
        \midrule
        \cat-4   & \textbf{49.1} & \textbf{45.1} & \textbf{47.1} \\
        \cat-8   & 38.2 & \ud{34.8} & \ud{36.5} \\
        \cat-16  & 27.5 & 15.4 & 21.5 \\
        \cat-32  & 13.2 & 3.2  & 8.2  \\
        \bottomrule
        \end{tabular}
        \captionsetup{font=footnotesize,labelfont=footnotesize}
        \caption{Zero-shot performance on real-world in-context recall tasks measured at $2$K sequence lengths.
        We report results on SWDE and FDA here, which have longer sequences among the datasets in the suite (others have an average length of $\leq300$ tokens~\citep{arora2024just}). \Cref{app:more_experiments} shows evaluations on all datasets.
        All \cats reported here are a single model.
        }
        \label{tab:swde_fda_results}
\end{wraptable}

**What makes `\cats `{=latex}purr?** To match dense-transformer perplexity, we empirically find that a more *expressive* decoder helps: that is, the decoder uses $2\times$ the hidden size. *This suggests that accurate decoding from compressed representations needs extra compute*, with similar observations in recent works [@ho2024block; @yu2023megabyte]. Refer to App. `\ref{app:expressive_decoder}`{=latex} for a comparison. Further, we find that the depth of the compressor does not have a major effect on perplexity (App. `\ref{app:cat_ablation}`{=latex}). Given these findings, to instantiate `\cats `{=latex}that compete with a dense transformer of depth $L$ and hidden size $D$, `\cats `{=latex}use a decoder of depth $L$ and hidden size $2D$, and a compressor of depth $L/4$ and hidden size $D$. While this increases parameters, `\cats `{=latex}are still significantly faster and more memory-efficient (see `\Cref{fig:generation_throughput}`{=latex}) than the corresponding dense transformer. Thus, for `\cats `{=latex}we use $L=12$ layers, same as the baselines, but a wider hidden size of $D_g=2D=2048$ for the decoder. The compressor uses $L/4=3$ layers and a hidden size of $D_f=D=1024$. This brings the parameter count for `\cats `{=latex}close to $1$B. We train `\cats `{=latex}simultaneously on chunk sizes $C\in\{4,8,16,32\}$. Note that this `\cat `{=latex}is a single model that can work with different chunk sizes at once, offering different compute-quality trade-offs at test-time.
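A rough parameter estimate makes the configuration above concrete. This is a back-of-envelope sketch, not the released implementation: it assumes roughly $12D^2$ parameters per transformer layer (attention plus MLP), the GPT2 vocabulary size of 50257, separate input and output embedding tables for the decoder and the dense baseline, and a single input embedding table for the compressor. It ignores norms, biases, and any projection between compressor and decoder. The function name and its arguments are hypothetical.

```python
GPT2_VOCAB = 50257  # vocabulary size of the GPT2 tokenizer


def approx_params(depth, width, vocab=GPT2_VOCAB, embedding_tables=2):
    """Rough parameter count: ~12 * width^2 per layer plus embedding tables.

    The 12 * width^2 per-layer figure and the number of embedding tables are
    assumptions of this sketch, not values taken from the paper.
    """
    layer_params = 12 * depth * width**2
    embedding_params = embedding_tables * vocab * width
    return layer_params + embedding_params


L, D = 12, 1024

# Dense baseline: depth L, hidden size D.
dense = approx_params(L, D)

# CAT: decoder of depth L and hidden size 2D, compressor of depth L/4 and
# hidden size D (compressor assumed to need only an input embedding table).
cat = approx_params(L, 2 * D) + approx_params(L // 4, D, embedding_tables=1)

if __name__ == "__main__":
    print(f"dense ~ {dense / 1e6:.0f}M parameters")
    print(f"CAT   ~ {cat / 1e6:.0f}M parameters")
```

Under these assumptions the estimate lands at roughly a quarter of a billion parameters for the dense baseline and roughly 0.9B for the `\cat `{=latex}configuration, the same ballpark as the figures quoted in the text.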

\begin{wraptable}[15]{r}{0.58\textwidth} %
 %
\captionsetup{justification=RaggedRight,singlelinecheck=false,
              font=footnotesize,labelfont=footnotesize}

\setlength{\tabcolsep}{2pt}
\scriptsize
\begin{tabular}{l|cc|cc|cc|l}
\toprule
& \multicolumn{2}{c|}{\textbf{Single-doc QA}}
& \multicolumn{2}{c|}{\textbf{Multi-doc QA}}
& \multicolumn{2}{c|}{\textbf{Few Shot}}
& \textbf{Avg.} \\
\cmidrule(lr){2-3} \cmidrule(lr){4-5} \cmidrule(lr){6-7}
\textbf{Model} & \texttt{QAS} & \texttt{MQA} & \texttt{HQA} & \texttt{2WMQ} & \texttt{TQA} & \texttt{TREC} &  \\
\midrule
Dense                       & 3.9 & 12.2 & 6.9 & \textbf{10.8} & 11.2 & 10.6 & 9.3 \\
Sparse                      & 5.1 & 11.0 & 7.0 & \ud{10.6} & 10.5 & 5.6 & 9.3 \\
Mamba2                      & 4.1 & 11.9 & \textbf{7.6} & 7.6  & 9.0  & 7.6  & 8.0 \\
\textsc{gdn}                & \textbf{8.3} & \textbf{15.5} & 6.0 & 7.9  & 7.4  & 8.3  & 8.9 \\
\textsc{gdn}-Hybrid         & 4.2 & 13.3 & 6.6 & 11.6 & 11.8 & 6.5  & 9.0 \\
\midrule
\cat-4                      & \ud{5.6} & 12.7 & \ud{7.4} & 9.9  & \ud{12.1} & \textbf{35.6} & \textbf{13.9} \\
\cat-8                      & 5.5 & 11.0 & 6.1 & 8.0  & \textbf{12.4} & \ud{29.5} & \ud{12.1} \\
\cat-16                     & 4.3 & \ud{14.1} & 6.1 & 5.6  & 10.5 & 16.6 & 9.5 \\
\cat-32                     & 4.7 & 11.0 & 7.0 & 6.6  & 10.0 & 8.3  & 7.9 \\
\bottomrule
\end{tabular}

\caption{Zero-shot evaluation of baselines on a suite of tasks from LongBench \cite{bai2023longbench} up to $4$K tokens. Refer to \Cref{app:datasets}. \cat-4/8/16/32 are a single model.}
\label{tab:longbench}
\end{wraptable}

**Language modeling and understanding benchmarks:** `\Cref{tab:lm_eval}`{=latex} reports the zero-shot perplexity on LAMBADA (LMB) [@paperno2016lambada], WikiText (Wiki) [@merity2016pointer], and a held-out test set of FineWeb-Edu (FW), and the zero-shot accuracies on key common-sense reasoning benchmarks; `\Cref{app:datasets}`{=latex} expands the acronyms in `\cref{tab:lm_eval}`{=latex}. All `\cat `{=latex}variants outperform existing efficient baselines on common-sense reasoning benchmarks on average. `\Cats-4`{=latex}/8/16 match or outperform all the baselines on the language modeling tasks except LMB. Note that `\cat-32`{=latex} outperforms other `\cat `{=latex}variants on common-sense reasoning and language understanding benchmarks since these evaluations only consider short sequences ($\leq30$ tokens on average) -- this means the decoder in the case of `\cat-32`{=latex} directly consumes the raw sequence without any compression. We test language understanding on longer contexts in `\cref{tab:longbench}`{=latex} on a suite of tasks from LongBench [@bai2023longbench], where `\cats-4`{=latex}/8/16 outperform all the baselines, and we observe the expected trends between different `\cats`{=latex}.

**Real world in-context recall:** Table `\ref{tab:swde_fda_results}`{=latex} reports results on real-world in-context recall tasks from [@arora2024simple]. Linear models (Mamba2, GatedDeltaNet) lag far behind dense attention, while [gdn]{.smallcaps}-Hybrid reduces the gap. `\Cat `{=latex}surpasses nearly all efficient baselines, benefiting from the gracefully growing memory. Interestingly, `\cat `{=latex}outperforms even the dense transformer at moderate chunk sizes ($C=4,8$), while being at least $1.4\times$ faster and $2.2\times$ more memory efficient. `\Cref{fig:pareto_frontier}`{=latex} reports these results, with more details in App. `\ref{app:figure_details}`{=latex}. More importantly, `\cat `{=latex}achieves these strong results at varying compute-memory budgets using a single adaptive model.

\footnotesize
\captionsetup{font=footnotesize,labelfont=footnotesize}

                                                                             **S-NIAH-N**          **S-NIAH-U**          **BabiLong**
  ---------------------------------------------------------------------- --------------------- --------------------- --------------------- --------------------- --------------------- --------------------- --------------------- --------------------- --------------------- --------------------
  2-4 `\cmidrule`{=latex}(lr)5-7 `\cmidrule`{=latex}(lr)8-11 **Model**          **1K**                **2K**                **4K**                **1K**                **2K**                **4K**                **0K**                **1K**                **2K**                **4K**
  Dense                                                                          96.0                  92.0                  43.0                **93.6**                55.7                  19.8                **49.0**                14.0                  12.0                  1.0
  Sparse                                                                         51.2                  46.2                   5.0                  12.8                   1.4                   0.8                  29.0           `\ud{22.0}`{=latex}           6.0                  4.0
  Mamba2                                                                  `\ud{97.7}`{=latex}          81.1                  18.6                  46.7                   4.6                   1.0                  30.0                  18.0           `\ud{19.0}`{=latex}          0.0
  [gdn]{.smallcaps}                                                              84.7                  69.1                  13.6                  38.9                   2.6                   2.0           `\ud{48.0}`{=latex}        **36.0**              **31.0**              **6.0**
  [gdn]{.smallcaps}-Hybrid                                                     **99.0**              **97.0**                44.0                  50.9                   5.6                   2.6                  35.0                  10.0                   2.0                  1.0
  `\cat-4`{=latex}                                                               96.0                **97.0**              **96.0**         `\ud{79.6}`{=latex}        **59.3**         `\ud{46.5}`{=latex}          46.0           `\ud{22.0}`{=latex}           9.0                  1.0
  `\cat-8`{=latex}                                                               90.0           `\ud{93.0}`{=latex}   `\ud{91.0}`{=latex}          68.1           `\ud{57.5}`{=latex}        **47.3**                46.0                  19.0                   9.0           `\ud{5.0}`{=latex}
  `\cat-16`{=latex}                                                              76.0                  72.0                  70.0                  10.0                   6.6                   3.8                  31.0                   5.0                   8.0           `\ud{5.0}`{=latex}
  `\cat-32`{=latex}                                                              60.0                  37.0                  31.0                   0.0                   0.0                   0.0                  17.0                  10.0                   7.0           `\ud{5.0}`{=latex}

  : Accuracy on RULER [@hsieh2024ruler] and BabiLong [@kuratov2024babilong] benchmarks. All `\cats `{=latex}reported are a single model. {#tab:s_niah}

**Synthetic needle-in-haystack & State-tracking:** Table `\ref{tab:s_niah}`{=latex} reports results on RULER [@hsieh2024ruler] synthetic single-needle tasks: S-NIAH-N (recall a number worth 7 tokens from the long context) and the harder variant S-NIAH-U (recall a long alpha-numeric string or UUID containing 32 tokens from the long context).

Linear recurrent models (Mamba2, [gdn]{.smallcaps}) struggle at longer contexts, and while [gdn]{.smallcaps}-Hybrid narrows the gap with dense transformers, performance still drops at longer contexts. `\Cats-4`{=latex}/8/16 outperform the efficient baselines as context length increases, showing slower degradation with length, even compared to the dense transformer. This slow degradation can possibly be attributed to the reduced sequence length in `\cat`{=latex}, which leads to fewer distractions for attention [@barbero2024transformers; @chiang2022overcoming; @golovneva2025multi]. Further, large-chunk `\cat `{=latex}underperforms at short contexts but interestingly surpasses baselines at long ones. One reason why large-chunk `\cat `{=latex}underperforms could be ineffective compression -- with larger chunks, the compressor in `\cat `{=latex}is not always able to surface the *right* information in the chunk representation for accurate retrieval. More pre-training or finetuning on specific task data alleviates this problem for large-chunk `\cat `{=latex}(see App. `\ref{app:finetuning_cats}`{=latex}). That being said, note that there is an upper limit to how much information fixed-size chunk representations can practically learn to hold for large token chunks. On the BabiLong state-tracking and retrieval task (`qa1` subset), all models decline as context grows, although linear recurrent models (Mamba2, [gdn]{.smallcaps}) perform better, in accordance with observations in [@kuratov2024babilong].

**Benchmarking generation:** Figure `\ref{fig:generation_throughput}`{=latex} compares architectures as one scales the sequence length, with a fixed batch size of 320. `\Cat `{=latex}generates sequences **$1.4-3.2\times$ faster** than the dense transformer while showcasing **up to $2.2-9.5\times$ lower total memory usage** as one increases chunk sizes, despite using more parameters due to the wider decoder and the additional compressor. This is not surprising since the major bottlenecks during generation are: (a) the KV cache size, which drives the main memory requirement during generation rather than the parameter count (see discussion in `\Cref{sec:discussion}`{=latex}), (b) the memory accesses required for a token, and (c) the FLOPs used per token, determined by the past tokens being attended to. `\Cats `{=latex}reduce all the above factors significantly despite carrying more parameters overall. Refer to `\cref{app:generation_details}`{=latex} for benchmark details.
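To see why the cache can shrink even though the decoder is wider, the following sketch counts cached key/value entries under a deliberately simplified model. It assumes that a dense transformer caches one key and one value vector of width $D$ per token and layer, while a `\cat `{=latex}decoder caches one key/value pair of width $2D$ per completed chunk plus at most one chunk of raw tokens currently being decoded. These accounting rules and all function names are assumptions of this sketch; it ignores model weights and activations, so it should not be expected to reproduce the measured numbers in the figure.

```python
def dense_cache_entries(n_tokens, depth, width):
    """Cached key + value entries for a dense transformer (simplified)."""
    return 2 * depth * n_tokens * width


def cat_cache_entries(n_tokens, chunk_size, depth, decoder_width):
    """Cached key + value entries for a CAT-style decoder (simplified).

    Assumes one compressed vector per completed chunk plus, as an upper
    bound, one full chunk of raw tokens that is currently being decoded.
    """
    past_chunks = n_tokens // chunk_size
    in_flight_tokens = chunk_size
    return 2 * depth * (past_chunks + in_flight_tokens) * decoder_width


def cache_reduction(n_tokens, chunk_size, depth=12, width=1024):
    """Ratio of dense cache size to CAT cache size, with a 2x wider decoder."""
    dense = dense_cache_entries(n_tokens, depth, width)
    cat = cat_cache_entries(n_tokens, chunk_size, depth, 2 * width)
    return dense / cat


if __name__ == "__main__":
    for chunk_size in (4, 8, 16, 32):
        ratio = cache_reduction(4096, chunk_size)
        print(f"C={chunk_size:>2}: cache is ~{ratio:.1f}x smaller than dense")
```

In this simplified accounting the reduction grows with the chunk size, which matches the qualitative trend reported above: the cache scales with $N/C$ rather than $N$, and that outweighs the factor of two from the wider decoder.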

**`\cats `{=latex}scale similarly to their dense counterparts:** `\Cref{fig:scaling_law}`{=latex} demonstrates that `\cats `{=latex}scale similarly to their dense transformer equivalents. We evaluate against three dense transformer scales $\{31M, 92M, 260M\}$, with their `\cat `{=latex}equivalents containing $\{95M, 326M, 1000M\}$ parameters. All models were trained for 15B tokens, under the identical setup described in `\cref{sec:experiments}`{=latex}.

<figure data-latex-placement="h">
<p><img src="figures/mqar.png" alt="image" />   </p>
<p><img src="figures/scaling_law.png" alt="image" />   </p>
</figure>

**Ablations:** We investigate how different choices affect performance of `\cats `{=latex}in App. `\ref{app:cat_ablation}`{=latex}. We further provide performance of `\cats `{=latex}when trained independently for a single chunk size in App. `\ref{app:indepedent_cats}`{=latex}.

**Instantiating `\cat `{=latex}as a layer:** While the `\cat `{=latex}presented in the paper is a separate architecture, one can take the core concepts and instantiate `\cat `{=latex}as a layer that can be swapped into any sequence model, replacing dense attention. This can unlock many interesting possibilities, starting with hybrid as well as adaptive architectures that mix `\cat `{=latex}layers with dense attention, or perhaps even linear attention. We leave this open for future work. App. `\ref{app:cat_as_layer}`{=latex} reports results when instantiating `\cat `{=latex}as a layer on the MQAR task, along with more details. We provide a scalable and efficient implementation of `\cat `{=latex}as a layer in the released code.

## More comparisons with baselines

**`\cats `{=latex}outperform baselines when memory matched at every sequence length:** To rule out any memory discrepancy (Fig. `\ref{fig:generation_throughput}`{=latex}), we evaluate `\cat `{=latex}and baselines on the challenging MQAR task [@arora2023zoology], matching memory budgets down to the level of bytes, and stress-test up to $1$K sequence length ($4\times$ standard); `\Cref{fig:mqar_cat_gdn_mamba}`{=latex} reports these results. Baselines are grid-searched over learning rates. Linear models show degradation at longer contexts, while `\cats `{=latex}remain near-perfect, thanks to the flexible yet efficient memory. We use the setup described in App. `\ref{app:sparse_fails_mqar}`{=latex}.

**MegaByte/Block Transformer struggle at in-context recall:** The MegaByte/Block Transformer [@ho2024block; @yu2023megabyte] has elements similar to `\cat `{=latex}but fails to solve a simple in-context recall task in `\cref{fig:block_transformer_fails}`{=latex} across different hyperparameters and architecture configurations due to the fixed memory bottleneck. In fact, the Block Transformer overfits on the task. `\Cats `{=latex}alleviate the memory bottleneck, allowing them to solve the task with even lower memory requirements. See `\cref{app:block_transformer_fails}`{=latex} for details.

# Discussion and Conclusion {#sec:discussion}

We introduce **C**ompress & **A**ttend **T**ransformers (`\cats`{=latex}), a simple *controllably* efficient alternative to the standard transformer architecture. On language modeling tasks, common-sense reasoning, in-context recall and long-context understanding, `\cat `{=latex}outperforms various existing efficient baselines, across different compute-memory budgets. Notably, `\cat-4`{=latex} (the least efficient setting) outperforms the dense transformer in both language modeling and recall tasks while being $1.5\times$ faster and requiring $2\times$ less memory. We discuss a possible explanation for this observation, followed by the practical utility of `\cats `{=latex}and some future directions.

#### Parameters and Efficiency.

Despite the larger parameter count in `\cats`{=latex}, working with compressed sequences ensures that `\cats `{=latex}are faster and more memory-efficient than their dense transformer equivalents. In spirit, `\cats `{=latex}are similar to flagship models that are all parameterized as Mixture-of-Experts (MoEs) [@shazeer2017outrageously], where increased parameter counts (up to $10\times$ more[^3]) do not mean higher computational costs during inference. In fact, more parameters bring improvements in quality while using the same compute [@muennighoff2024olmoe]. Due to compression, `\cats `{=latex}utilize their increased model parameters *smartly*, bringing improvements in quality (in-context recall) while still being efficient.

#### Are `\cats `{=latex}adoptable?

The current implementation for training `\cats `{=latex}takes twice as much time due to inefficient compilation (using the recent PyTorch FlexAttention API [@dong2024flex]) of the attention mask employed in `\cats `{=latex}(see `\Cref{app:cat_training_throughput}`{=latex} for a discussion); custom kernels can be developed in the future to mitigate this difference and potentially realize the compute savings during pre-training discussed in `\Cref{sec:implement_cats}`{=latex}. However, training is a one-time cost, and the service life of models dictates profits, making *serving costs the more important consideration*. Deploying language models at scale is often constrained *not* by model weights but by the memory footprint of their KV cache. For instance, `Qwen3-14B` at a modest batch size of 16, which is common in chat/code completion, requires an *order of magnitude more memory* for the KV cache than the model weights themselves: $28$GB for the weights vs. $\sim670$GB for the KV cache at maximum context length. In contrast, a `\cat `{=latex}variant of the same model could reduce total memory usage by up to $\sim4\times$ despite having more model parameters overall[^4], and lead to higher generation throughput. As most modern GPU workloads are increasingly memory-bound rather than compute-bound, memory reductions play an even more critical role [@gholami2024ai]. The reduction in memory and increase in throughput are more pronounced at larger batch sizes for `\cats`{=latex}, which are critical for workloads such as synthetic data generation [@maini2025beyondweb] and large-scale rollouts in reinforcement learning (RL) post-training pipelines for math/code reasoning or alignment for faster RL training [@noukhovitch2024asynchronous] or better RL optimization [@zhang2025preference].
Further, note that `\cats `{=latex}serve as multiple models in one, enabling reduced compute during high traffic, longer shelf-life under smaller budgets, and deployment on cheaper hardware -- **all from a single training run**.

**Future work:** `\Cats `{=latex}currently rely on dense transformer abstractions, but the architecture is general and could incorporate other sequence mixers directly, e.g., linear attention as the *compressor* with dense attention *decoders* for long-range interactions between the compressed sequence. A different direction is data-dependent adaptivity. `\Cats`{=latex}, as they stand, require users to choose a chunk size appropriate for their compute and memory budgets. Instead, one could post-train with RL to allow `\cats `{=latex}to learn to allocate budget themselves based on the context and the task. Such post-training would enable truly adaptive efficiency. Further, as illustrated before, `\cats `{=latex}could possibly be instantiated as a layer instead of as a full architecture, where every layer has a separate compressor and decoder. This could enable interesting hybrid and adaptive architectures that mix dense attention layers with `\cat `{=latex}layers. Additionally, one can train a `\cat `{=latex}where compression kicks in only after 1K tokens (say), resulting in parallel compression of older context only. Finally, scaling up `\cats `{=latex}to larger model scales and longer training would enable further insights and better comparisons.

# Reproducibility Statement

We release code for `\cats `{=latex}at: [`github.com/rajesh-lab/cat-transformer`](https://github.com/rajesh-lab/cat-transformer). We provide exhaustive implementation details for `\cats `{=latex}in `\Cref{sec:implement_cats}`{=latex} and pseudo-code in Appendix `\ref{app:pseudo_code}`{=latex}. Further, we provide training details and hyperparameters for baselines in Appendix `\ref{app:training_details}`{=latex}. We directly use the official code for implementing and benchmarking baselines.

# Acknowledgments

We would like to thank Neelabh Madan, Saksham Rastogi (`pogs`), Aastha Jain, Raghav Singhal, Zhixuan Lin, Mark Goldstein, Ethan Barron, Anirudh Buvanesh, Atharv Sonwane, Daman Arora and Anshuk Uppal for super useful discussions and feedback. This work was partly supported by the NIH/NHLBI Award R01HL148248, NSF Award 1922658 NRT-HDR: FUTURE Foundations, Translation, and Responsibility for Data Science, NSF CAREER Award 2145542, NSF Award 2404476, ONR N00014-23-1-2634, Optum, and Apple. We would also like to thank the support by IITP with a grant funded by the MSIT of the Republic of Korea in connection with the Global AI Frontier Lab International Collaborative Research.

\bibliographystyle{iclr2026_conference}
\appendix
\newpage
\startcontents[appendix] %

# Table of Contents for the Appendices {#table-of-contents-for-the-appendices .unnumbered}

\printcontents[appendix]{l}{1}{\setcounter{tocdepth}{2}}
\newpage

# More experiments {#app:more_experiments}

## Recall evaluation {#app:recall_evaluation}

Here, we evaluate all baselines on all datasets from the EVAPORATE suite of tasks that tests for real-world in-context recall.

  **Model**                    **SWDE**               **FDA**              **Squad**           **TriviaQA**            **Drop**              **Avg.**
  ---------------------- --------------------- --------------------- --------------------- --------------------- --------------------- ---------------------
  Dense                          43.4                  19.7           `\ud{31.0}`{=latex}          15.0           `\ud{19.4}`{=latex}   `\ud{26.7}`{=latex}
  Sparse                         20.9                   6.0                  20.7                  15.2                  19.3                  16.4
  Mamba2                         13.5                   4.5                  24.9                  13.9                  17.8                  14.9
  [gdn]{.smallcaps}              18.0                   6.8                  25.5                **15.5**                17.2                  16.6
  [gdn]{.smallcaps}-H1           44.0                  17.8                **32.9**         `\ud{15.4}`{=latex}        **19.8**                26.0
  `\cat-4`{=latex}             **49.1**              **45.1**                28.3                  15.0                  17.9                **31.1**
  `\cat-8`{=latex}        `\ud{38.2}`{=latex}   `\ud{34.8}`{=latex}          25.9                  14.0                  18.3                  26.2
  `\cat-16`{=latex}              27.5                  15.4                  20.4                  14.8                  16.9                  18.9
  `\cat-32`{=latex}              13.2                   3.2                  15.8                  13.0                  14.3                  11.9

  : Zero-shot performance on real-world in-context recall tasks from EVAPORATE suite, measured at $2$K sequence lengths. Note that only SWDE and FDA have long token sequences among the datasets in the suite (others have an average length of $\leq300$ tokens [@arora2024just]). {#tab:all_evaporate_results}

## Comparison with MegaByte/Block Transformer {#app:block_transformer_fails}

In Figure `\ref{fig:block_transformer_fails}`{=latex}, we evaluate the in-context recall ability of Block Transformer architectures [@ho2024block; @yu2023megabyte], which model chunks of tokens similarly to `\cats `{=latex}but with a subtle yet salient difference in the architecture circuit (explained below). For this experiment, we test on the MQAR task (a synthetic needle-in-haystack task [@arora2023zoology]) at a modest sequence length of 256. We test the accuracy of retrieving just 4 needles. We parametrize the components of the Block Transformer as follows: the global model and the local model are transformers, and the embedder is either a look-up table or a transformer. We keep the patch/chunk size at 4, the same as `\cat`{=latex}. We keep the training setup identical for both architectures. We grid search over hyper-parameters (`lr`, `hidden_size`, and embedder parameterization), even **using more memory** than the `\cat `{=latex}baseline in its global decoder. Even in this simple setting and with this added advantage, the Block Transformer [@ho2024block; @yu2023megabyte] fails to solve the task (Fig. `\ref{fig:block_transformer_fails}`{=latex}); instead, the model starts to memorize the training points: train loss and train accuracy keep improving, while test metrics suffer.

`\Cats `{=latex}pass all the "local" patch/chunk representations directly to the decoder, unlike the Block Transformer, which forces the history to be compressed into a fixed-dimensional representation. This design choice helps `\cat `{=latex}*alleviate the memory bottleneck* that [@ho2024block] suffers from, where the architecture must compress everything from the past into a single "global" representation to generate the next chunk. Note that this design choice in `\cats `{=latex}does not introduce any memory/compute overhead compared to the Block Transformer [@ho2024block]; it just changes the circuit of the architecture. In fact, `\cats `{=latex}do not need three different components (embedder, global decoder, local decoder): they use only a compressor and a decoder, further reducing the design space and (significantly) the parameter requirements.
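The difference in circuit can be illustrated with a toy sketch of what the decoder is allowed to see when generating chunk $i$. This is not the architecture from either paper: mean pooling stands in for both the learned compressor and the Block Transformer's global model, and all function names are hypothetical. It only shows that one design exposes a context that grows with the history while the other exposes a single fixed-size vector.

```python
def mean_pool(vectors):
    """Average a non-empty list of equal-length vectors (stand-in compressor)."""
    dim = len(vectors[0])
    return [sum(v[j] for v in vectors) / len(vectors) for j in range(dim)]


def compress_chunks(tokens, chunk_size):
    """Split token embeddings into chunks and compress each chunk to one vector."""
    if len(tokens) % chunk_size != 0:
        raise ValueError("sequence length must be a multiple of chunk_size")
    return [mean_pool(tokens[i:i + chunk_size])
            for i in range(0, len(tokens), chunk_size)]


def cat_context(chunk_reprs, i):
    """CAT-style: the decoder for chunk i sees every past chunk vector."""
    return chunk_reprs[:i]


def block_context(chunk_reprs, i):
    """Block-Transformer-style: the whole past is collapsed into one vector."""
    past = chunk_reprs[:i]
    return [mean_pool(past)] if past else []


if __name__ == "__main__":
    # 16 toy token embeddings of dimension 2, chunk size 4 -> 4 chunk vectors.
    tokens = [[float(t), 1.0] for t in range(16)]
    reprs = compress_chunks(tokens, chunk_size=4)
    print(len(cat_context(reprs, 3)))    # number of vectors visible to a CAT decoder
    print(len(block_context(reprs, 3)))  # number of vectors visible to a block decoder
```

With three past chunks, the `\cat`{=latex}-style context holds three vectors and keeps growing with the sequence, whereas the block-style context always holds one, which is the fixed memory bottleneck discussed above.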

## `\cats `{=latex}trained with fixed chunk size {#app:indepedent_cats}

`\Cref{tab:indepedent_cats}`{=latex} reports results for different `\cats `{=latex}when trained with a fixed chunk size. We observe fixed chunk `\cats `{=latex}exhibit better in-context recall performance compared to the single adaptive `\cat `{=latex}model. This drop in performance in single adaptive `\cat `{=latex}model could be attributed to the decoder losing some of its capacity to model multiple chunk sizes at once. More pre-training can alleviate some of this performance drop. Nevertheless, despite this drop, single adaptive `\cat `{=latex}model still outperforms all baselines.

  **Model**                         **SWDE**             **FDA**             **Squad**           **TriviaQA**           **Drop**             **Avg.**
  --------------------------- -------------------- -------------------- -------------------- -------------------- -------------------- --------------------
  `\cat-4`{=latex}             [49.1]{.underline}          45.1          [28.3]{.underline}          15.0                 17.9                 31.1
  `\cat-8`{=latex}                    38.2                 34.8                 25.9                 14.0                 18.3                 26.2
  `\cat-16`{=latex}                   27.5                 15.4                 20.4                 14.8                 16.9                 18.9
  `\cat-32`{=latex}                   13.2                 3.2                  15.8                 13.0                 14.3                 11.9
  `\cat-4`{=latex} (fixed)          **50.9**             **55.5**             **30.2**        [15.6]{.underline}   [20.4]{.underline}        **34.5**
  `\cat-8`{=latex} (fixed)            43.3          [51.6]{.underline}          26.9                 15.4               **21.4**        [31.7]{.underline}
  `\cat-16`{=latex} (fixed)           32.7                 23.3                 22.6               **15.8**               16.4                 22.2
  `\cat-32`{=latex} (fixed)           10.7                 5.4                  15.6                 13.4                 16.6                 12.3

  : In-context recall performance comparison of single adaptive `\cat `{=latex}model and independently trained `\cats `{=latex}with fixed chunk size. First four rows are a single `\cat `{=latex}model, whereas the last four rows are separate `\cat `{=latex}models, trained with a fixed chunk size. Fixed chunk size `\cats `{=latex}outperform their adaptive versions. {#tab:indepedent_cats}

## Sparse or Sliding Window Attention struggle at recall {#app:sparse_fails_mqar}

We evaluate models on the synthetic multi-query associative recall (MQAR) task, proposed in [@arora2023zoology] and further popularized in [@arora2024simple]. All models use a depth of 2 layers and are trained and tested on sequence lengths up to 256 with a varying number of key-value pairs. `\cat `{=latex}models use a 1-layer compressor followed by a 2-layer decoder, with a chunk size of 4, both using a model dimension of $D=D_g=64$ in this case. Note that the state size for `\cat `{=latex}is $\frac{N}{C}\cdot D = 4096$ for this particular sequence length and model dimension. Sparse attention uses a chunk size of $4$ (for a fair comparison with `\cat`{=latex}); sliding window attention uses a window size of $64$.

  **Method**            **Solves?**       **State Size**
  ----------------- ------------------- ----------------
  Dense              `\cmark `{=latex}           $16384$
  Sparse             `\xmark `{=latex}            $4096$
  Sliding Window     `\xmark `{=latex}            $4096$
  `\cat `{=latex}    `\cmark `{=latex}            $4096$

  : For each method, we report the state size at which the particular method was trained for the MQAR task. Each method was grid searched for best possible hyper-parameters. We use the state size calculations provided in [@arora2024simple; @arora2023zoology]. {#tab:sparse_fails_mqar}

In table `\ref{tab:sparse_fails_mqar}`{=latex}, `\cat `{=latex}is able to solve the MQAR task. Notably, we find the sparse attention as well as sliding window attention fail to solve the task at 2 layers, highlighting their dependence on depth.
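As a quick sanity check on the state sizes reported above, the following sketch recomputes them from the stated setup ($N=256$, $D=64$, $C=4$, window size $64$). The function names are illustrative and not taken from the released code; the sparse-attention entry is not recomputed because its formula is not spelled out in this section.

```python
def dense_state_size(seq_len, dim):
    # Dense attention keeps one D-dimensional entry per token.
    return seq_len * dim


def sliding_window_state_size(window, dim):
    # Sliding window attention keeps only the last `window` tokens.
    return window * dim


def cat_state_size(seq_len, chunk_size, dim):
    # CAT keeps one D-dimensional representation per chunk: (N / C) * D.
    return (seq_len // chunk_size) * dim


N, D, C, W = 256, 64, 4, 64
print(dense_state_size(N, D))           # 16384
print(sliding_window_state_size(W, D))  # 4096
print(cat_state_size(N, C, D))          # 4096
```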

## `\cat `{=latex}as a layer {#app:cat_as_layer}

To instantiate `\cat `{=latex}as a separate layer in itself, we parameterize the *compressor* as a simple linear projection. We use the dense attention mechanism itself as the *decoder*. Before applying the compression and decoding from compressed chunk representations, we artificially up-project the input embeddings in the layer. This follows the observation in the main paper that decoding from compressed representations requires more compute. Please find the actual implementation in the released code.

`\Cref{tab:cat_as_layer}`{=latex} reports MQAR accuracy when `\cat `{=latex}is used as a layer. We use a fixed chunk size of 4 in this experiment. We use 2 layers of `\cat`{=latex}. We follow the same setup described in `\Cref{app:sparse_fails_mqar}`{=latex}.

  **Method**                   **Solves?**       **State Size**
  ------------------------ ------------------- ----------------
  Dense                     `\cmark `{=latex}           $16384$
  `\cat `{=latex}           `\cmark `{=latex}            $4096$
  `\cat `{=latex}(layer)    `\cmark `{=latex}            $4096$

  : `\cat `{=latex}instantiated as a separate layer solves the MQAR task again. {#tab:cat_as_layer}

## Finetuning `\cats `{=latex}on S-NIAH-U {#app:finetuning_cats}

S-NIAH-U is a task where the model needs to recall 32-token-long UUID strings from a long context. This section reports the performance of `\cats `{=latex}after task-specific finetuning on samples from S-NIAH-U. We only apply the loss on tokens that appear in the answer span. `\Cref{tab:finetuning_cats}`{=latex} reports these results, and `\Cref{fig:finetuning_cats}`{=latex} shows the corresponding loss curves on this task for `\cats `{=latex}with different chunk sizes.
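One common way to restrict the loss to the answer span is to replace every label outside the span with an ignore index. The sketch below assumes the PyTorch convention of `-100` (the default `ignore_index` of `torch.nn.functional.cross_entropy`); the helper name and the span arguments are hypothetical and may differ from what the released code does.

```python
IGNORE_INDEX = -100  # default ignore_index in torch.nn.functional.cross_entropy


def mask_labels_outside_answer(token_ids, answer_start, answer_end):
    """Keep labels inside [answer_start, answer_end); ignore all others."""
    return [
        tok if answer_start <= pos < answer_end else IGNORE_INDEX
        for pos, tok in enumerate(token_ids)
    ]


# Hypothetical example: the answer occupies positions 3 and 4.
labels = mask_labels_outside_answer([11, 12, 13, 14, 15, 16], 3, 5)
print(labels)  # [-100, -100, -100, 14, 15, -100]
```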

We observe two things. (i) After finetuning, performance goes up significantly for all chunk sizes. This suggests that, as the chunk size increased, the compressor in `\cats`{=latex} was not surfacing the *right* information in the chunk representation before finetuning. (ii) The loss curves during finetuning indicate the same; however, the loss still does not go completely to zero, especially for `\cat-32`{=latex}. This indicates that there are limits to what information a fixed-size chunk representation can practically learn to surface, which explains its sub-par accuracy on the task.

This problem of not surfacing the *right* information in the chunk representation could be alleviated by more and longer pre-training, or choosing smaller chunk sizes for tasks that require accurate recall.

<figure id="fig:finetuning_cats" data-latex-placement="h">
<img src="figures/cat_finetuning.png" style="width:50.0%" />
<figcaption>Loss curves when finetuning different CAT variants on samples from the S-NIAH-U task.</figcaption>
</figure>

  **Model**            **Before**   **After**
  ------------------- ------------ -----------
  `\cat-4`{=latex}        46.3        97.1
  `\cat-8`{=latex}        47.3        97.0
  `\cat-16`{=latex}       3.8         94.2
  `\cat-32`{=latex}       0.0         64.3

  : Performance on 4K sequence length before and after finetuning for different `\cat `{=latex}variants. {#tab:finetuning_cats}

## How are `\cats `{=latex}different from Block Transformer/MegaByte? {#app:block_diagram_different}

<figure id="fig:block_transformer_fails" data-latex-placement="h">
<p><img src="figures/block_fails.png" style="width:75.0%" alt="image" /> </p>
<figcaption> Block Transformer <span class="citation" data-cites="ho2024block yu2023megabyte"></span> (across different configurations and hyperparameters) fails to solve a simple MQAR task with only 4 key-value pairs, tested on a modest sequence length of 256 tokens, possibly due to its fixed memory. CAT solves the task with ease due to its flexible yet efficient memory. Note that training of CAT stops when it solves the task perfectly. </figcaption>
</figure>

Works like [@ho2024block; @yu2023megabyte] break up the modeling of a sequence into chunks/patches, where each chunk is modeled independently of each other given the previous "global" chunk embedding. An embedder first compresses each chunk independently, then these "local" chunk embeddings are passed to a "global" model where each "local" chunk embedding attends to past "local" chunk embeddings, forming a "global" chunk embedding. Each "global" chunk embedding is then passed to a decoder that is responsible for generating the next chunk.

At first glance, **`\cats `{=latex}might appear similar to the above works**, specifically Block Transformer [@ho2024block]/MegaByte [@yu2023megabyte]. However, the subtle but salient difference is that in `\cat`{=latex} one feeds **all the previous "local" chunk/patch representations** directly to the decoder, whereas in works like [@ho2024block] one feeds in just the previous "global" chunk representation output by a "global" model (refer to the comparative figure `\ref{fig:block_and_cat_different}`{=latex}).

This architectural choice of passing *all* the compressed local chunks from the past directly to the decoder allows `\cats `{=latex}to solve long-range recall tasks with ease while maintaining efficiency, whereas **Block Transformer/MegaByte is plagued by *learnability* problems** (even in toy recall tasks) due to constant-size compression of history (see Figure `\ref{fig:block_transformer_fails}`{=latex}). Additionally, `\cats `{=latex}do not use three different components (embedder, global decoder, local decoder). They use only a compressor and a decoder, which further reduces the design space and (significantly) the parameter requirements.

\newpage

# Implementation details {#app:pseudo_code}

In this section, we discuss some implementation details regarding `\cats`{=latex}. We repeat some text presented in the main paper to be self-contained below.

## Training

While `\cats `{=latex}are simple and build on dense transformer abstractions, their naive PyTorch training implementation is very inefficient.

Note that compression of chunks of tokens is efficient since it can be done in parallel, specifically using `torch.vmap`($f_\theta(\mbc_i)$) for all chunks $\mbc_i$. This costs a total of $O(\frac{N}{C}\cdot C^2)=O(NC)$ in self-attention compute, which is much better than $O(N^2)$.

But, computing logits for tokens in chunk $\mbc_i$, that is computing $g_\theta (\mbc_i \mid f_\theta(\mbc_1) \cdots f_\theta(\mbc_{i-1}))$ can be non-trivial since for chunk $\mbc_i$, we have $i-1$ past chunk representations $\{f_\theta(\mbc_1),f_\theta(\mbc_2)\dots f_\theta(\mbc_{i-1})\}$. In other words, there are different number of past chunk representations for every chunk, making shapes variable and as a result, harder to parallelize computation of logits. One could employ a python loop and compute logits for every chunk sequentially, but that would be slow and won't scale. In fact, even if one manages to compute logits for every chunk in parallel, the total self-attention operations in the decoder would be $O(\sum_{i=1}^{\frac{N}{C}}(i+C)^2)=O((\frac{N}{C})^3)$, that is cubic in sequence length. Padding to make shapes constant would make things worse. Thus, naive techniques will not scale.
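To make the scaling of the naive approach concrete, the sketch below evaluates the sum $\sum_{i=1}^{N/C}(i+C)^2$ from the paragraph above and compares it with $N^2$. The function names are illustrative, and the counts are abstract attention operations rather than measured FLOPs.

```python
def naive_decoder_attention_cost(seq_len, chunk_size):
    """Sum of (i + C)^2 over chunks i = 1..N/C, as in the expression above."""
    num_chunks = seq_len // chunk_size
    return sum((i + chunk_size) ** 2 for i in range(1, num_chunks + 1))


def dense_attention_cost(seq_len):
    """Quadratic reference cost N^2."""
    return seq_len ** 2


# For a fixed chunk size, the ratio to N^2 keeps growing with N,
# which reflects the extra factor of sequence length in the naive scheme.
for n in (1024, 2048, 4096):
    naive = naive_decoder_attention_cost(n, 4)
    dense = dense_attention_cost(n)
    print(n, naive, dense, round(naive / dense, 2))
```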

*With such difficulties in making the training scalable, it may not be surprising that despite the simplicity of `\cats`{=latex}, it was not attempted in the community.* Note that unlike `\cats`{=latex}, similar architectures [@ho2024block; @yu2023megabyte] do not have this problem: computing logits can be naively parallelized due to fixed shapes and self-attention operations scale quadratically due to a single compressed representation for the past.

In `\cats`{=latex}, observe that in computing logits for chunks $\mbc_i,\mbc_{i+1} \dots \mbc_{\frac{N}{C}}$, one calculates the same key-values for chunk representations $f_\theta(\mbc_j)$ in the decoder, where $j<i$. This points to repeated and identical computations. To exploit this observation, we use a custom attention mask in the decoder to calculate logits for all chunks in parallel, reusing the computations done for a past chunk representation when computing logits for a future chunk. To be concrete, once we calculate all chunk representations $f_\theta(\textbf{c}_i)$ in parallel using `torch.vmap`, we insert the $f_\theta(\textbf{c}_i)$s at particular positions in the original sequence: after every chunk $\mbc_i$, we attach its chunk representation. That is, the sequence would look like: $\{\mbc_1,f_\theta(\mbc_1),\mbc_2,f_\theta(\mbc_2),\dots \mbc_i,f_\theta(\mbc_i)\dots\}$. Now, we pass this sequence into the decoder during training, with a custom attention mask (see Figure `\ref{fig:cat_attention_mask}`{=latex}) that allows a token in chunk $\mbc_i$ to attend only to previous tokens within that chunk and to previous chunk representations, which would be $f_\theta(\mbc_{i-1}), f_\theta(\mbc_{i-2})\dots f_\theta(\mbc_1)$. Any token in chunk $\mbc_i$ does not attend to raw tokens outside this chunk. This implementation allows re-use of key-values for chunk representations $f_\theta(\mbc_i)$ when calculating logits of future chunks, in parallel, making the training of `\cats `{=latex}efficient and scalable. We utilize the FlexAttention API [@dong2024flex] to automatically create a custom kernel for the custom mask (Figure `\ref{fig:cat_attention_mask}`{=latex}).
Note that this way of computing logits is still quadratic in sequence length, but better by a constant factor: concretely it is $O(\frac{N}{C}\cdot N + \frac{N}{C}\cdot C^2)=O(\frac{N^2}{C})$, **which is $C\times$ better than $O(N^2)$** (the yellow dots in figure `\ref{fig:cat_attention_mask}`{=latex} provide a visual proof of this cost; the number of yellow dots is significantly lower than $\frac{N^2}{2}$). Mathematically, the cost of attention in the `\cats `{=latex}decoder is $\sum_{i=1}^{N}\left(\left\lfloor\frac{i}{C}\right\rfloor+(i \bmod C)+1\right)=O(\frac{N^2}{C})$, where $\lfloor\cdot\rfloor$ is the floor function and $\bmod$ is the modulo operator.
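The masking rule and its cost can be checked with a small pure-Python sketch. It assumes the interleaved layout described above (each chunk of $C$ tokens followed by one representation slot), indexes tokens from 0, and lets a representation slot attend only to itself since its decoder output is unused; that last choice and all function names are assumptions for illustration, not the released FlexAttention mask.

```python
def is_rep(pos, chunk_size):
    # In the interleaved sequence, the last slot of every block of C + 1
    # positions holds the chunk representation f(c_i).
    return pos % (chunk_size + 1) == chunk_size


def chunk_of(pos, chunk_size):
    return pos // (chunk_size + 1)


def cat_mask(q, kv, chunk_size):
    """True if query position q may attend to key/value position kv."""
    if is_rep(q, chunk_size):
        # Assumption: representation slots only see themselves.
        return q == kv
    if is_rep(kv, chunk_size):
        # Raw tokens see representations of strictly earlier chunks.
        return chunk_of(kv, chunk_size) < chunk_of(q, chunk_size)
    # Raw tokens see earlier raw tokens of their own chunk (causal).
    return chunk_of(kv, chunk_size) == chunk_of(q, chunk_size) and kv <= q


def count_attended(seq_len, chunk_size):
    """Count allowed (query, key) pairs over raw-token queries."""
    total_len = seq_len + seq_len // chunk_size
    return sum(
        cat_mask(q, kv, chunk_size)
        for q in range(total_len)
        if not is_rep(q, chunk_size)
        for kv in range(total_len)
    )


def closed_form(seq_len, chunk_size):
    # 0-indexed token i sees floor(i / C) representations
    # and (i mod C) + 1 tokens of its own chunk.
    return sum(i // chunk_size + i % chunk_size + 1 for i in range(seq_len))


print(count_attended(128, 16), closed_form(128, 16))  # 1536 1536
print(128 * 129 // 2)  # 8256 pairs for a dense causal mask
```

For the setting of the mask figure ($N=128$, $C=16$), the mask admits 1536 query-key pairs versus 8256 for a dense causal mask.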

For a discussion of training throughput, refer to Appendix `\ref{app:cat_training_throughput}`{=latex}.

``` {style="pytorch" caption="Pseudocode for training step"}
import einops
import torch
import torch.nn.functional as F

# f: compressor, maps one chunk of tokens to a chunk representation
# phi: decoder, maps (chunk tokens, past chunk representations) to logits
def forward(input_ids, targets):

    # split the sequence into chunks: (b, N) -> (b, k, c)
    input_ids = einops.rearrange(input_ids, "b (k c) -> b k c", k=num_chunks, c=chunk_size)

    # calculate f(x)
    # shape of fx: (b, k, D_d)
    fx = torch.vmap(f)(input_ids)

    output_logits = list()
    for i in range(num_chunks): # note that this loop is done in parallel with the custom attention mask presented in the appendix
        # use the previous i chunk representations fx[:, :i] to predict the current chunk
        # (chunk i must not see its own representation)
        # shape of cur_chunk_logits: (b, 1, c, V)
        cur_chunk_logits = phi(input_ids[:, i, :], fx[:, :i, :])
        output_logits.append(cur_chunk_logits)
    output_logits = torch.cat(output_logits, dim=1) # shape: (b, k, c, V)
    output_logits = einops.rearrange(output_logits, "b k c v -> b (k c) v") # arrange all chunks logits together (or flatten)
    # cross_entropy expects (num_tokens, V) logits and (num_tokens,) targets
    return F.cross_entropy(output_logits.flatten(0, 1), targets.flatten()) # return the loss
```

## Attention mask {#app:cat_attention_mask}

<figure id="fig:cat_attention_mask" data-latex-placement="h">
<img src="figures/cat_attention_mask.png" style="width:50.0%" />
<figcaption>Sequence length is 128, and the chunk size that we use in this particular attention mask is <span class="math inline"><em>C</em> = 16</span>.</figcaption>
</figure>

Note that the attention mask in figure `\ref{fig:cat_attention_mask}`{=latex} looks very similar to the attention mask defined in [@child2019generating]. However, in `\cat`{=latex}'s case: (a) it is not a heuristic choice, and (b) tokens in a particular chunk attend to the past $f_\theta(\textbf{c}_i)$ representations obtained by the compressor, rather than to the past token embeddings at that position as done in [@child2019generating].

\newpage

## Generation

The decoder during generation attends to at most $\frac{N}{C}+C$ tokens. Due to compression, `\cats `{=latex}can throw away past chunks of tokens and only keep their compressed chunk representations in memory. This immediately results in a big reduction of memory; the KV cache is slashed by a factor of $C$. Even for a moderate chunk size of 4, this results in big reductions in memory during generation (Figure `\ref{fig:generation_throughput}`{=latex}) compared to a dense transformer. This reduction in memory is accompanied by fewer memory accesses by the decoder in `\cats`{=latex}, which are the major bottleneck during generation.
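The $\frac{N}{C}+C$ bound can be illustrated by counting cache positions. This sketch only counts positions; it ignores hidden sizes and parameter memory, so it is not comparable to the measured memory numbers, and the function names are illustrative.

```python
def dense_cache_entries(seq_len):
    # A dense transformer caches keys/values for every past token.
    return seq_len


def cat_cache_entries(seq_len, chunk_size):
    # One slot per compressed chunk plus the raw tokens
    # of the chunk currently being generated.
    return seq_len // chunk_size + chunk_size


for c in (4, 8, 16, 32):
    print(c, cat_cache_entries(4096, c), dense_cache_entries(4096))
# 4 1028 4096
# 8 520 4096
# 16 272 4096
# 32 160 4096
```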

Implementing generation is simpler than training and very similar to how it occurs for a dense transformer. In fact, a pure PyTorch implementation for `\cats `{=latex}is on par with efficient architectures that utilize custom kernels. Our implementation is inspired by <https://github.com/meta-pytorch/gpt-fast>. Given $i$ chunks of tokens: first, `torch.vmap` over chunks independently to calculate $f_\theta(\mbc_i)$ in parallel. Then prefill the decoder's KV cache in parallel with the obtained $f_\theta({\mbc}_i)$s. Now generate the next chunk ${\mbc}_{i+1}$ autoregressively, one token at a time. Note that this uses a simple causal mask since the previous positions are already prefilled with the $f_\theta({\mbc}_i)$s, which are required to decode chunk $\mbc_{i+1}$. Once all the tokens of the chunk ${\mbc}_{i+1}$ are generated, calculate $f_\theta({\mbc}_{i+1})$ and prefill the decoder's KV cache just after the position where $f_\theta({\mbc}_{i})$ was cached. Now the KV cache is ready for generation of the next chunk ${\mbc}_{i+2}$, and this process continues.

This simple implementation enables `\cats `{=latex}to be **$1.4-3.2\times$ faster** than the dense transformer while showcasing **$2.2-9.5\times$ lower total memory usage** as one increases chunk sizes.

``` {style="pytorch" caption="Pseudocode for generation"}
import torch

# f: compressor; prefill / generate_chunk: helpers of the decoder g that read and write its static KV cache
# https://github.com/pytorch-labs/gpt-fast/blob/7dd5661e2adf2edd6a1042a2732dcd3a94064ad8/generate.py#L154
def generate_chunk_by_chunk(
    input_ids,
    num_chunks,
):
    # assume input_ids.shape == (batch_size, 1, chunk_size)

    # declare/reset static KV cache, shape: [batch_size, num_chunks + chunk_size, 2, D_d]

    input_pos = 0

    # compress the first chunk (batch_size, 1, chunk_size) -> (batch_size, 1, D_d)
    # get fx for the very first chunk
    fx = f(input_ids) # shape of fx: (batch_size, 1, D_d)
    next_token = prefill(fx, input_pos) # prefill at idx 0 with fx in g

    new_chunks = list()

    for i in range(num_chunks - 1):

        # generate entire chunk using fx that was prefilled earlier in g
        next_chunk = generate_chunk(next_token)
        new_chunks.append(next_chunk.clone())

        # get new fx
        # compress the new obtained chunk
        fx = f(next_chunk) # (batch_size, 1, chunk_size) -> (batch_size, 1, D_d)

        # prefill again at input_pos
        input_pos += 1
        next_token = prefill(fx, input_pos) # prefill fx at idx `input_pos` in g

    # concatenate along the chunk axis: (batch_size, num_chunks - 1, chunk_size)
    new_chunks = torch.cat(new_chunks, dim=1)
    return new_chunks
```

## Adaptive `\cats`{=latex} {#app:adaptive_cat_train_details}

To enable training of adaptive `\cats`{=latex}, we made some choices that we now describe. In every training iteration, we sample a chunk size uniformly at random and perform loss computation. Further, because the chunk size varies across training iterations, one cannot keep a single projection matrix that projects processed token embeddings in the compressor to a single chunk representation (the shape of the projection matrix would differ for different chunk sizes). One could tackle this by keeping an independent projection matrix for every chunk size, but we found that this did not work well empirically, possibly due to fewer updates for every chunk size's projection weights (only one chunk size's projection weights are updated per iteration; this is not the case for the compressor or the decoder, which are updated every iteration). Instead, we took inspiration from [@beyer2023flexivit], where the authors declared a single projection matrix for all chunk sizes and then linearly interpolated the matrix to the desired shape depending on the current chunk size. This means the linear interpolation is also under `torch.autograd` and is optimized so that the final linearly interpolated projection matrix gives a *good* chunk representation for every chunk size.
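The idea of sharing one projection across chunk sizes can be sketched as follows. This is a simplified, non-differentiable illustration with plain lists: the layout (one weight row per position in the chunk), the endpoint-aligned interpolation, and all names are assumptions. In actual training the interpolation would be a differentiable tensor operation so that gradients reach the shared matrix.

```python
import random

CHUNK_SIZES = (4, 8, 16, 32)  # chunk sizes of the adaptive model in this paper


def resize_projection(base, new_len):
    """Linearly interpolate `base` (a list of rows) to `new_len` rows."""
    old_len = len(base)
    resized = []
    for j in range(new_len):
        # Map target slot j onto the base axis with aligned endpoints.
        if new_len > 1:
            pos = j * (old_len - 1) / (new_len - 1)
        else:
            pos = (old_len - 1) / 2
        lo = int(pos)
        hi = min(lo + 1, old_len - 1)
        frac = pos - lo
        resized.append(
            [(1 - frac) * a + frac * b for a, b in zip(base[lo], base[hi])]
        )
    return resized


def sample_projection(base, rng=random):
    """Sample a chunk size uniformly and resize the shared projection to it."""
    chunk_size = rng.choice(CHUNK_SIZES)
    return chunk_size, resize_projection(base, chunk_size)


base = [[0.0, 0.0], [2.0, 4.0], [4.0, 8.0]]
print(resize_projection(base, 5))
# [[0.0, 0.0], [1.0, 2.0], [2.0, 4.0], [3.0, 6.0], [4.0, 8.0]]
```

Because every sampled chunk size reads from the same `base`, every training iteration would update the shared weights, unlike the per-chunk-size matrices discussed above.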

## Training throughput analysis {#app:cat_training_throughput}

We make use of the FlexAttention API to obtain a custom self-attention kernel specifically for the masking scheme in `\Cref{fig:cat_attention_mask}`{=latex}. This fused kernel gives a significant boost in training throughput in self-attention costs compared to using a naive PyTorch masked implementation.

That being said, a more efficient training kernel could be developed in the future. In our experiments, using FlexAttention did not give significant boosts in training speed compared to using Flash Attention on a dense transformer. This could be because speeding up attention maps like the one in figure `\ref{fig:cat_attention_mask}`{=latex} may require different principles than the Flash Attention-like optimizations that FlexAttention might be using under the hood.

As a result, due to the unavailability of an efficient training kernel, theoretical speed-ups due to the reduction in attention FLOPs in the `\cat `{=latex}architecture do not appear in training wall-clock times. Additionally, MLPs in a transformer drive the majority of the FLOPs budget during training at smaller sequence lengths [@scaling-book]. At a sequence length of 4096, `\cats `{=latex}take at most $2.35\times$ as long to train as a dense transformer (measured on a batch size of $8$ with a compressor depth of $3$, a decoder depth of $6$, hidden size for compressor $D_f=1024$ and hidden size for decoder $D_g=2D=2048$ for `\cat`{=latex}, compared against a dense transformer having a depth of $6$ and $D=1024$, on an A100 80 GB PCIe).

Developing an efficient attention kernel for training `\cats `{=latex}is left as future work.

\newpage

# Ablations {#app:cat_ablation}

## Ablation on hidden size of compressor

With this ablation, we show that increasing hidden size of the compressor does not help in improving perplexity. We fix $D_g=1536$ for these experiments. For this ablation, we use a smaller WikiText-103 dataset. Both compressor and decoder use the same depth $L=6$.

  **Chunk Size $C$**   **Size of $D_f$**    **Perplexity**
  -------------------- ------------------- ----------------
  16                   $768$                     17.6
                       $1536$                    17.6

  : Comparison of choices of hidden size of compressor on WikiText-103 perplexity.

There is no effect of increasing the hidden size of the compressor. The performance before and after remains the same.

## Ablation on hidden size of decoder {#app:expressive_decoder}

We ablate on different choices of $D_g$ along with different chunk sizes in `\cat`{=latex}. In this setup, we fix $D_f$ in the compressor, and only vary $D_g$ or $C$ (chunk size). We use WikiText-103 for these experiments. In this setup, $D=768$. Both compressor and decoder use the same depth of $L=6$.

  **Chunk Size $C$**   **Size of $D_g$**    **Perplexity**
  -------------------- ------------------- ----------------
  4                    $D$                       19.8
                       $2D$                      17.4
  8                    $D$                       20.4
                       $2D$                      17.7
  16                   $D$                       20.2
                       $2D$                      17.6

  : Comparison on choices of chunk sizes and sizes of $D_g$ on WikiText-103 perplexity.

We observe that we obtain the best perplexities when we set $D_g=2D$, for each chunk size we tried. Based on this observation, we used $D_g=2D$ as our *default* configuration for the FineWeb-Edu experiments.

  **Model**           **$D_f$** **$D_g$** **Perplexity**   **Avg. recall**
  ------------------- --------- --------- ---------------- -----------------
  Dense               $--$      $D$       21.2             23.8
  [cat]{.smallcaps}   $D$       $D$       23.8             13.7
  [cat]{.smallcaps}   $D$       $2D$      **20.7**         **19.8**

  : Impact on perplexity and average recall performance of [cat]{.smallcaps} when varying $D_g$. For dense, $D_g$ implies hidden size for itself. Here, $D=1024$. $D_g=2D$ gives better perplexity and average recall. We train `\cat `{=latex}only at chunk size $C=8$ for these experiments. All models were trained for 5B tokens with 1K sequence length. Rest of the setup follows Sec. `\ref{sec:experiments}`{=latex}. {#tab:decoder_power}

## Ablation on depth of the compressor

We ablate on the depth of the compressor. For a fixed chunk-size, $D_f=768$ (compressor embedding size), $D_g=1536$ (decoder hidden size), and a fixed depth of the decoder, we vary the compressor depth.

  **Chunk Size $C$**   **Depth of Compressor**    **Perplexity**
  -------------------- ------------------------- ----------------
  8                    $6$                             17.4
                       $3$                             17.4
  16                   $6$                             17.8
                       $3$                             17.7

  : Comparison on choices of depth of the compressor across different chunk sizes $C$ on WikiText-103.

Interestingly, one can reduce the depth of the compressor without sacrificing downstream perplexity. This could mean that one can compress small chunks of tokens without requiring high capacity. In our generation benchmarks, we observed that compressor depth plays less of a role in latency than decoder depth (since we compress tokens in parallel using one transformer call). That being said, compressor depth does play a significant role in training costs (due to the MLP training costs in the compressor). Therefore, reducing compressor depth is an overall advantage for the `\cat`{=latex}`\space `{=latex}architecture.

However, what the limit is, and whether one can go down to even a 1-layer compressor, is an interesting question. There might be some lower bound on the compressor depth needed to start compressing chunks of tokens, but we leave this to future work.

\newpage

# More experiment details {#app:training_details}

Here we provide more details about the experiments done in the main text.

## Baselines

  **Model**                   **Total (M)**   **Embedding (M)**   **Non-Embedding (M)**
  -------------------------- --------------- ------------------- -----------------------
  Dense                            260               50                    210
  Mamba2                           260               50                    210
  GDN                              310               50                    260
  GDN-Hybrid                       280               50                    230
  Sparse                           820               100                   720
  `\cat-4`{=latex}/8/16/32      150 + 820         50 + 100              100 + 720

  : Model parameter sizes in millions, separated into embedding and non-embedding parameters. Parameters for `\cats `{=latex}are reported as compressor + decoder. {#tab:model_sizes}

1.  Dense transformer (or Transformer++) [@vaswani2017attention; @touvron2023llama]: We use rotary position embeddings along with the FlashAttention kernel to perform self-attention. The MLP is a SwiGLU MLP [@touvron2023llama].

2.  Sparse transformer [@child2019generating]: Follows the Dense transformer configuration, except for the attention mask used. Moreover, we used $D=2\cdot1024=2048$ for this baseline for a fair comparison with `\cats`{=latex}. We used the FlexAttention API to create an optimized Flash Attention-like kernel for this.

3.  [mamba2]{.smallcaps} [@dao2024transformers]: The model uses 2 Mamba mixers per layer. All layers use the [mamba2]{.smallcaps} block without mixing in any attention. The `expand` is set to 2, $d_{state}=128$, and convolution $k=4$. Activations used are `SiLU`. We use the official codebase for [mamba2]{.smallcaps} generation throughput and memory benchmarking: <https://github.com/state-spaces/mamba> and code from: <https://github.com/fla-org/flash-linear-attention> for training.

4.  Gated Delta Net [@yang2025gated]: We use the implementation provided at <https://github.com/fla-org/flash-linear-attention> for training. We use `head_dim` as 128 and `num_heads` as 8 (same as [mamba2]{.smallcaps} above). For the hybrid version, we use sliding window layers at every other layer with a sliding window size of $2048$.

## Datasets {#app:datasets}

Following common practice in [@gu2023mamba; @dao2024transformers; @arora2024simple; @yang2025gated], we evaluate all models on multiple common sense reasoning benchmarks: PIQA [@bisk2020piqa], HellaSwag [@zellers2019hellaswag], ARC-challenge [@arc_challenge], WinoGrande [@sakaguchi2021winogrande] and measure perplexity on WikiText-103 [@merity2016pointer] and LAMBADA [@paperno2016lambada]. In Table `\ref{tab:lm_eval}`{=latex}, HS denotes HellaSwag, PQ denotes PIQA, AE denotes ARC-Easy, AC denotes ARC-Challenge, WG denotes Winogrande, OQA denotes OpenBookQA, LMB denotes LAMBADA, Wiki denotes WikiText, and FW denotes FineWeb-Edu.

We evaluate long-context understanding on tasks from LongBench [@bai2023longbench], where each abbreviation in table `\ref{tab:longbench}`{=latex} stands for: QAS: `qasper`, MQA: `multifieldqa_en`, HQA: `hotpotqa`, 2WMQ: `2wikimqa`, TQA: `triviaqa`, TREC: `trec` split of LongBench.

To measure real-world recall accuracy, we use the datasets used in [@arora2024simple; @arora2024just]. Namely, these consist of SWDE [@lockard2019openceres] for structured HTML relation extraction and several question answering datasets including SQuAD [@rajpurkar2018know], TriviaQA [@joshi2017triviaqa], DROP [@dua2019drop] and FDA [@arora2023language]. Since our pretrained models are small, we use the Cloze Completion Formatting prompts provided by [@arora2024just].

We evaluate on tasks from the needle-in-haystack benchmark RULER [@hsieh2024ruler].

Finally, to evaluate baselines on state-tracking tasks, we used the BabiLong benchmark [@kuratov2024babilong]. Due to the relatively small scale of our setup, we were only able to evaluate on the `qa1` subset, since all baselines failed on the other, more complex subsets.

## Generation {#app:generation_details}

Both the dense transformer and `\cat`{=latex}`\space `{=latex}use FlexAttention API causal dot product kernels. We use the script provided in [@dao2024transformers] to benchmark[^5] Mamba2, GatedDeltaNet and GatedDeltaNet-Hybrid. All benchmarks used a prefill of $8$ tokens. All benchmarks were run using a single NVIDIA A100 80GB PCIe, and use CUDA graphs for the next-token prediction.

## Main figure details {#app:figure_details}

`\Cref{fig:pareto_frontier}`{=latex} reports memory usage at 2K sequence length since both SWDE and FDA datasets have queries with context length $\leq$ 2K. The latencies are reported at maximum sequence length of 4K.

\newpage

# Extended Related Work {#app:extended_related_work}

#### Reducing self-attention costs:

Reducing the cost of self-attention enables scaling transformers to large contexts and has been the focus of much work [@child2019generating; @parmar2018image; @beltagy2020longformer; @jiang2023mistral7b]. Common techniques include *heuristically* defined sparse attention maps [@child2019generating; @zaheer2020big] or a sliding window [@jiang2023mistral7b] in order to reduce the number of tokens being attended to. The compute (and in some cases, memory) required for attention goes down, but at the cost of the model's expressivity. In turn, to achieve performance similar to that of full attention, efficient models require either big window sizes (making their memory costs large again) [@arora2024simple] or more layers (in the case of sparse or sliding window attention, see App. `\ref{app:sparse_fails_mqar}`{=latex} and Tab. `\ref{tab:swde_fda_results}`{=latex}).

[@shazeer2019fast] proposes using a single or a reduced number of key and value heads in the self-attention block, more commonly known as Multi-Query Attention (only one key/value head) or Grouped-Query Attention (reduced key/value heads). This reduces memory with seemingly no loss in downstream performance, making it a popular choice in recent model releases [@yang2025qwen3]. That being said, one could use the same technique inside `\cat`{=latex}'s decoder (and compressor) self-attention block, making it complementary.
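
As a rough illustration of why fewer key/value heads save memory, the sketch below counts KV-cache bytes for a hypothetical configuration. The function `kv_cache_bytes` and all sizes (24 layers, 16 query heads, head dimension 64, 2 bytes per element) are assumptions chosen for the example, not the models used in this paper.

```python
def kv_cache_bytes(seq_len, n_layers, n_kv_heads, head_dim, bytes_per_elem=2):
    """Bytes needed to cache keys and values for one sequence.

    The leading factor 2 accounts for storing both keys and values.
    Hypothetical helper, for illustration only.
    """
    return 2 * seq_len * n_layers * n_kv_heads * head_dim * bytes_per_elem


# Assumed configuration; only the number of key/value heads changes.
cfg = dict(seq_len=4096, n_layers=24, head_dim=64)
mha = kv_cache_bytes(n_kv_heads=16, **cfg)  # one KV head per query head
gqa = kv_cache_bytes(n_kv_heads=4, **cfg)   # reduced number of KV heads
mqa = kv_cache_bytes(n_kv_heads=1, **cfg)   # a single shared KV head

for name, size in [("MHA", mha), ("GQA", gqa), ("MQA", mqa)]:
    print(f"{name}: {size / 2**20:.0f} MiB")
```

The cache scales linearly with the number of key/value heads, independently of how many query heads share them.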

Concurrent works like [@yuan2025native] reduce attention compute by attending to compressed past tokens as well as to specific blocks of uncompressed tokens in the past. This is similar in spirit to our work; however, [@yuan2025native] yields no memory savings during inference.

Some works [@Rae2020Compressive; @chevalier2023adapting] explored recurrent formulations of a transformer to enable processing of longer sequences on limited compute by compressing past context. However, training sequence models in a recurrent fashion has its own challenges, back-propagation through time (BPTT) being the most important one. More recently, [@geiping2025scaling] had to use very careful weight initialization, truncated gradients, small learning rates and careful placement and tuning of norms to train a large-scale recurrent architecture in a stable manner and prevent optimization collapse. Nevertheless, these techniques are complementary to `\cat`{=latex}.

Alternatively, one can optimize the computation of full-attention to directly reduce wall-clock time and memory by leveraging hardware advancements. For example, @dao2022flashattention compute attention in a blockwise manner and exploit online softmax [@milakov2018online], which removes the need to instantiate the entire $QK^T$ matrix and reduces accesses to the slower part of GPU memory. As we utilize the attention mechanism as is, any hardware-level optimizations that reduce the cost of attention also proportionally reduce the cost of `\cat `{=latex}models.
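
The idea behind online softmax can be sketched in a few lines of plain Python for a single query vector. This is a didactic sketch, not the fused GPU kernel: it keeps a running maximum, a running denominator and a running unnormalized output, and rescales them whenever a new block raises the maximum. The function names are illustrative, and the usual $1/\sqrt{d}$ scaling is omitted for brevity.

```python
import math
import random


def naive_attention(q, keys, values):
    """Reference: materialize all scores, then apply softmax."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / z
            for d in range(dim)]


def blockwise_attention(q, keys, values, block_size):
    """Same result, but only one block of scores exists at a time."""
    dim = len(values[0])
    m = float("-inf")   # running maximum of the scores seen so far
    l = 0.0             # running softmax denominator
    acc = [0.0] * dim   # running unnormalized output
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [sum(a * b for a, b in zip(q, k)) for k in k_blk]
        m_new = max(m, max(scores))
        # Rescale old statistics to the new maximum (factor is 0 on the first block).
        scale = math.exp(m - m_new)
        l *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - m_new)
            l += w
            for d in range(dim):
                acc[d] += w * v[d]
        m = m_new
    return [a / l for a in acc]


rng = random.Random(0)
q = [rng.gauss(0, 1) for _ in range(4)]
keys = [[rng.gauss(0, 1) for _ in range(4)] for _ in range(10)]
values = [[rng.gauss(0, 1) for _ in range(3)] for _ in range(10)]

out_naive = naive_attention(q, keys, values)
out_block = blockwise_attention(q, keys, values, block_size=3)
```

Both functions compute the same output up to floating-point error, while the blockwise version never holds more than `block_size` scores at once.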

Finally, a plethora of works have tackled reducing the compute and memory requirements of a transformer in a *post-hoc* manner, i.e. after it has been trained using full-attention (also called *training-free* sparse attention [@nawrot2025sparse]). Common techniques include prefill-time sparsification (vertical/slash/block; adaptive) and decode-time KV-cache selection/eviction (e.g. [@li2024snapkv; @tang2024quest]). Further, [@lee2024training; @lee2025infinitehipextendinglanguagemodel] explored offloading the key-value cache to the CPU along with sparsifying attention masks. However, because models are trained dense but run sparse, train-test mismatch can hurt downstream performance [@nawrot2025sparse]. Still, these works can be directly applied to the decoder in `\cat`{=latex}, or to any other efficient alternative that makes use of dense attention, including hybrid architectures, making them complementary. Concurrently, [@lancucki2025inference] explored delayed token eviction strategies to enable effective KV cache compression [@nawrot2024dynamic].

#### Linear attention and state-space models:

A different line of work reduces the generation cost of transformers by limiting the recurrent state, which is the vector required to decode each token. Self-attention keeps track of the entire context (the KV cache), meaning that the recurrent state increases in size with each decoded token. Works like [@arora2024simple; @katharopoulos2020transformers] linearize attention to obtain a fixed-size recurrent state that can be updated via simple averaging; the technique is to approximate self-attention with linear operations of query, key, and value vectors transformed through a feature map. The choice of the feature map falls to the user, and approximating attention well requires the feature map to be large in size, which can counteract the gains in computational costs achieved by the linearization.
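
A minimal sketch of this idea, assuming the $\mathrm{elu}(x)+1$ feature map of [@katharopoulos2020transformers], is given below. The recurrent form keeps only a fixed-size state (a matrix `S` and a normalizer `z`, maintained here as running sums), yet produces the same outputs as the quadratic causal form. The function names are illustrative and the code is written for clarity, not speed.

```python
import math


def phi(x):
    """Feature map elu(x) + 1, which keeps all features positive."""
    return [xi + 1.0 if xi > 0 else math.exp(xi) for xi in x]


def linear_attention_recurrent(queries, keys, values):
    """Causal linear attention with a fixed-size state (S, z)."""
    d_k, d_v = len(keys[0]), len(values[0])
    S = [[0.0] * d_v for _ in range(d_k)]  # running sum of phi(k) v^T
    z = [0.0] * d_k                        # running sum of phi(k)
    outputs = []
    for q, k, v in zip(queries, keys, values):
        fq, fk = phi(q), phi(k)
        for i in range(d_k):
            z[i] += fk[i]
            for j in range(d_v):
                S[i][j] += fk[i] * v[j]
        denom = sum(fq[i] * z[i] for i in range(d_k))
        outputs.append([sum(fq[i] * S[i][j] for i in range(d_k)) / denom
                        for j in range(d_v)])
    return outputs


def linear_attention_quadratic(queries, keys, values):
    """Same computation, written as explicit attention over the prefix."""
    d_v = len(values[0])
    outputs = []
    for t, q in enumerate(queries):
        fq = phi(q)
        w = [sum(a * b for a, b in zip(fq, phi(k))) for k in keys[:t + 1]]
        total = sum(w)
        outputs.append([sum(wi * v[j] for wi, v in zip(w, values[:t + 1])) / total
                        for j in range(d_v)])
    return outputs


queries = [[0.5, -1.0], [1.5, 0.2], [-0.3, 0.8], [0.0, -0.4]]
keys = [[1.0, 0.0], [-0.5, 0.7], [0.3, 0.3], [2.0, -1.0]]
values = [[1.0, 2.0, 3.0], [0.0, -1.0, 1.0], [4.0, 0.5, -2.0], [1.0, 1.0, 1.0]]

out_rec = linear_attention_recurrent(queries, keys, values)
out_quad = linear_attention_quadratic(queries, keys, values)
```

The state size depends only on the feature and value dimensions, not on the sequence length, which is where both the savings and the recall limitations discussed in this section come from.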

Alternatively, one can replace attention with linear or pseudo-linear sequence mixers such as state-space models (SSMs) [@gu2021efficiently; @sun2023retentive], gated convolutions [@fu2022hungry; @poli2023hyena], input-dependent recurrences [@peng2023rwkv; @gu2023mamba] and, more recently, [@yang2025gated].

Typical implementations of linear attention and state-space models do achieve impressive reductions in generation costs and memory, but restrict the expressivity to the extent that these models do not solve in-context recall tasks without large recurrent state sizes [@arora2024simple; @arora2023zoology], or without composing with other sequence mixers, such as local sliding window attention [@arora2024simple; @yang2025gated]. Choosing such a composition again falls back to the user, complicating the design process. Additionally, this process trades off computation costs for performance because the attention layers that improve recall performance also come with larger time and memory costs.

Unlike the works discussed above, `\cats `{=latex}require no complicated changes to the attention mechanism itself. Instead of relying on manual approximations of history or heuristic choices of feature maps, we let the model and optimization decide what the history should be using learned compression. Moreover, it is unclear how much memory and compute a downstream task requires, making the adaptive property of `\cats `{=latex}highly desirable; no other baseline provides it.

#### Hierarchical transformers:

Many previous works [@pappagari2019hierarchical; @han2021transformer; @dai2020funnel] have explored employing hierarchy in transformers for creating representations of documents/images, where a *local* encoder transformer processes parts of the document/image independently. Later works [@nawrot2021hierarchical; @nawrot2022efficient; @slagle2024spacebyte] explored a downsample-then-upsample approach (an *hour-glass*-like structure), where the sequence is downsampled into *coarse* tokens followed by upsampling into *fine-grained* tokens before being decoded. Due to the *hour-glass* structure, there are compute savings during training, but during generation, the architecture must maintain a cache for all the past tokens, leading to significant memory accesses. Concurrently, [@hwang2025dynamic] explored a dynamic and end-to-end learned strategy for chunking in *hour-glass*-like architectures.

Different from the above, works like [@ho2024block; @yu2023megabyte] break up the modeling of a sequence into chunks/patches, where each chunk is modeled independently of the others given the previous "global" chunk embedding. An embedder first compresses each chunk independently, then these "local" chunk embeddings are passed to a "global" model where each "local" chunk embedding attends to past "local" chunk embeddings, forming a "global" chunk embedding. Each "global" chunk embedding is then passed to a decoder that is responsible for generating the next chunk.

At first glance, `\cats `{=latex}might appear similar to the above works, specifically [@ho2024block; @yu2023megabyte]. The subtle but salient difference is that in `\cat`{=latex}, one feeds **all the previous "local" chunk/patch representations** directly to the decoder, whereas in works like [@ho2024block], one feeds the decoder just the previous "global" chunk representation output by a "global" model. This architectural choice of passing *all* the compressed local chunks from the past directly to the decoder allows `\cats `{=latex}to solve long-range recall tasks with ease while maintaining efficiency, whereas [@ho2024block] is plagued by *learnability* problems (even in toy recall tasks) due to constant-size compression of history. Additionally, `\cats `{=latex}do not use three different components (embedder, global decoder, local decoder); they use only a compressor and a decoder, further reducing the design space and (significantly) the parameter requirements.
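
The difference in what the decoder can see can be illustrated with a rough count of context entries. The sketch below is a simplification under stated assumptions: one compressed representation per completed chunk, raw tokens only inside the current chunk, and hypothetical scheme names. It is bookkeeping for intuition, not the actual memory accounting of either architecture.

```python
def decoder_context_entries(position, chunk_size, scheme):
    """Number of entries the decoder attends to when decoding `position`.

    Assumes one compressed entry per completed chunk and raw tokens only
    within the current chunk. Hypothetical helper, for illustration only.
    """
    done_chunks, offset = divmod(position, chunk_size)
    if scheme == "dense":
        # Every past token is kept.
        return position
    if scheme == "all_compressed_chunks":
        # All past chunk representations plus the current partial chunk.
        return done_chunks + offset
    if scheme == "previous_global_only":
        # A single constant-size summary plus the current partial chunk.
        return min(done_chunks, 1) + offset
    raise ValueError(f"unknown scheme: {scheme}")


for scheme in ("dense", "all_compressed_chunks", "previous_global_only"):
    print(scheme, decoder_context_entries(position=4010, chunk_size=32, scheme=scheme))
```

Under these assumptions, feeding all compressed chunks grows the context with the number of chunks rather than the number of tokens, while a single global summary stays constant in size regardless of how much history must be recalled.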

Additionally, [@yen2024long] extend the cache by using a modified encoder-decoder architecture, where the decoder attends directly to the final activations of a smaller frozen encoder, without any compression.

Finally, [@barrault2024large] suggest learning "concepts" instead of tokens by modeling the latent representation of language produced by pushing the token sequence through a large sentence embedder. The focus of this work is to decouple the modeling of the low-level details in each language, like tense and grammar, from the larger concept space that is shared across languages. In contrast, the goal with `\cat `{=latex}is to reduce the cost of modeling sequences, and it can be used as a plug-and-play replacement for the latent concept model. Moreover, the encoder in [@barrault2024large] is an auto-encoder, which might keep irrelevant information in the chunk representation. The compressor in `\cats `{=latex}only keeps information that is predictive of the future chunks.

#### Adaptive architectures:

[@kusupati2022matryoshka; @devvrit2023matformer] learn representations during training that can work at different granularities at test-time, making the learned architecture adaptive. However, [@devvrit2023matformer] applies the *Matryoshka* technique to the feed-forward layer only, and not to the attention weights, which yields compute savings but not memory savings. That being said, one could apply similar approaches to `\cats`{=latex}, making them complementary. `\cats `{=latex}use the same high-level approach described in [@beyer2023flexivit]: learn a single model that can work for various patch sizes at once depending on the downstream use-case at test-time. However, [@beyer2023flexivit] worked with image tasks; `\cats `{=latex}deal with language modeling and generation.

[^1]: similar in vein to flagship models [@yang2025qwen3; @agarwal2025gpt] that are all parameterized as Mixture-of-Experts (MoEs) [@shazeer2017outrageously], where additional model parameters do not mean more compute, and bring improvements in quality.

[^2]: Our implementation is inspired from: [github.com/meta-pytorch/gpt-fast](https://github.com/meta-pytorch/gpt-fast)

[^3]: [`Qwen3-30B-A3B`](https://huggingface.co/Qwen/Qwen3-30B-A3B): only 3B parameters are *active* during inference out of 30B parameters in total

[^4]: Total memory usage for `\cats`{=latex}: $28\cdot(4+\frac{1}{4})+\frac{670\cdot2}{32}\approx161$GB, which is $\sim4\times$ better at chunk size $C=32$

[^5]: [github.com/state-spaces/mamba](https://github.com/state-spaces/mamba/blob/main/benchmarks/benchmark_generation_mamba_simple.py)
