Gated KalmaNet: A Fading Memory Layer Through Test-Time Ridge Regression

Liangzu Peng1, Aditya Chattopadhyay2, Luca Zancato2, Elvis Nunez2, Wei Xia2, Stefano Soatto2
University of Pennsylvania1  AWS Agentic AI2
[email protected]
{achatto,zancato,elvisnun,wxia,soattos}@amazon.com
Work done during an internship at AWS Agentic AI. Correspondence to [email protected]
Abstract

As efficient alternatives to softmax Attention, linear state-space models (SSMs) achieve constant memory and linear compute, but maintain only a lossy, fading summary of the past, often leading to inferior performance in recall-oriented tasks. We propose Gated KalmaNet (GKA), a layer that reduces this gap by accounting for the full past when predicting the next token, while maintaining SSM-style efficiency. GKA achieves this by solving an online ridge regression problem at test time, with constant memory and linear compute cost in the sequence length. Drawing inspiration from the Kalman Filter, we iteratively solve the online ridge regression problem. However, a critical insight is that standard Kalman filter equations are numerically unstable in low-precision environments (like bfloat16) and difficult to parallelize on modern hardware. We address both challenges through two key innovations: (1) an adaptive regularization strategy with input-dependent gating that controls the condition number of the ridge regression, ensuring numerical stability while balancing memory retention; and (2) the use of Chebyshev Iteration instead of other conventional iterative solvers, which we demonstrate to be more stable in low-precision settings. To further improve scalability, we develop a hardware-aware chunk-wise implementation of Chebyshev Iteration along with custom kernels for backpropagating through our adaptive regularization and gating mechanisms. Empirically, GKA shows strong language understanding capabilities on short-context tasks, outperforming existing SSM layers (like Mamba2, GLA and Gated DeltaNet). On long-context tasks, GKA excels at real-world RAG and LongQA tasks up to 128k tokens, achieving more than 10% relative improvement over other fading memory baselines.

1 Introduction

Large Language Models (LLMs) powered by (softmax) Attention mechanisms [Vaswani-NeurIPS2017] have revolutionized sequence modeling through their ability to form rich associations within their context window. However, a fundamental challenge that LLMs face is that their time complexity scales quadratically and storage grows linearly with their input length.

Recent years have seen intense efforts to develop Attention alternatives. Among them, memory layers based on linear State-Space Models (SSMs) have grown popular for their linear-time computation and constant storage cost in the sequence length [Dao-ICML2024-mamba2, Yang-ICML2024-gla]. These SSMs draw inspiration from classic techniques in adaptive signal processing, and integrating such techniques into modern SSMs leads to principled layer design and enhanced performance [Liu-ICLR2025, Yang-NeurIPS2024, Zancato-NeurIPS2024-BMOJO]. However, pure SSM models still underperform Attention in many settings, especially on long-context tasks. This gap is a consequence of their different memory mechanisms: SSMs keep a fading, fixed-dimensional, lossy state of the past, while Attention keeps an eidetic, ever-growing KV-cache state [Zancato-NeurIPS2024-BMOJO].

To bridge this gap, we aim at designing a memory layer that enjoys the efficiency of linear SSMs while performing computation conditioned on the exact past. Towards this goal, we first draw insights from the Kalman filter (KF) [Kalman-1960]. In signal processing terms, KF computes the most recent state conditioned on all data seen thus far, and, under mild assumptions, KF is optimal in the Maximum A-Posteriori (MAP) sense. In the LLM context, we use KF to update the state of an SSM layer and predict its output based on all past inputs. However, integrating KF into such a layer is non-trivial and faces two challenges:

  • Parallelizable Training. KF is an online algorithm and needs to be parallelized to fully utilize modern hardware that is highly optimized for large-scale LLM training.

  • Numerical Stability. KF involves matrix inversion, which can be numerically unstable in low precision arithmetic.

In this work, we propose Gated KalmaNet (GKA), a memory layer that incorporates KF into its design and is both numerically stable and trainable on highly parallelizable hardware. We start by observing that the KF recursion solves a test-time ridge regression problem. Then, to solve such a regularized problem stably, we make the following choices:

  • At the modeling level, we adaptively choose the regularization strength of our test-time objective function based on the Frobenius norm of the regularized data covariance. With this choice we can easily upper bound the condition number of the optimization problem.

  • At the algorithmic level, we note that exact solvers (e.g., torch.linalg.solve) are hard to parallelize (in a chunk-wise manner), so we resort to the classic Chebyshev Iteration (CH), which we show has high numerical accuracy and fast convergence compared with alternatives such as (accelerated) gradient descent and conjugate gradient.

To make GKA scalable and efficient, we implement CH with adaptive regularization in Triton in a hardware-aware, chunk-wise manner. Our technical novelty here includes deriving a chunk-wise implementation that back-propagates through the Frobenius norm, where the difficulty is the presence of a nested recurrence. Furthermore, we combine CH with a gating mechanism that decides the regression residual weights in an input-aware and time-varying fashion, enhancing the contribution of recent inputs and smoothly fading out distant contexts. Overall, to the best of our knowledge, this is the first adoption of the CH method for training sequence modeling layers in LLMs stably at scale.

Finally, we demonstrate the efficacy of GKA on numerous LLM benchmarks. For example, on synthetic recall tasks (MQAR) [arora2023zoology], our method achieves the highest recall accuracy among state-of-the-art linear SSMs, including Mamba2 [Dao-ICML2024-mamba2] and (Gated) DeltaNet [Yang-NeurIPS2024, Yang-ICLR2025]. GKA also outperforms existing SSMs on several short-context tasks (from LM-Harness [eval-harness]) and long-context tasks (from RULER [hsieh2024ruler] and HELMET [yen2025helmet]). Specifically, GKA improves upon SSM baselines by at least 10% on real-world long-context tasks such as Retrieval-Augmented Generation and Long Question-Answering, up to 128k tokens.

2 Prior Work and Preliminaries

In this section we briefly review prior work and preliminaries that set the stage for our choice of designing an SSM layer based on the Kalman Filter. For a more detailed exposition of related work, refer to Appendix A.

(Softmax) Attention. At each time $t$, Attention [Vaswani-NeurIPS2017] linearly projects the $t$-th input token to obtain three vectors: the query $q_t$, the key $k_t$, and the value $v_t$. It then outputs a vector $y_t\in\mathbb{R}^{D}$ as a convex combination of all values seen so far, with coefficients $c_1,\dots,c_t$ given by inner products of the current query $q_t$ with all seen keys followed by a softmax mapping:

$$y_t=\sum_{i=1}^{t}c_iv_i,\qquad c_i:=\frac{\exp\left(\frac{k_i^\top q_t}{\sqrt{D}}\right)}{\sum_{j=1}^{t}\exp\left(\frac{k_j^\top q_t}{\sqrt{D}}\right)}. \tag{Attn}$$

From an optimization perspective, Eq. (Attn) can be viewed as solving the following regression objective (footnote 1: concretely, for keys and queries of unit norm, Attention is precisely the Nadaraya-Watson estimator [Nadaraya-1964, Watson-1964] with the Gaussian kernel, which approximates the conditional expectation of the value given a query; cf. [Chaudhari-TIST2021, Vidal-2022]):

$$y_t=\operatorname*{argmin}_{v}\sum_{i=1}^{t}\exp\left(\frac{k_i^\top q_t}{\sqrt{D}}\right)\cdot\|v-v_i\|_2^2. \tag{1}$$

The success of Eq. (Attn) is often attributed to its ability to perform verbatim retrieval of relevant context from the entire past. Here, the past refers to all key-value pairs observed thus far, also known as the KV-cache, which grows linearly with time $t$. Moreover, the computation is also linear at each time $t$, and doing so for all $t$ results in a quadratic time complexity. This high computation and storage cost of Attention makes its use prohibitive in long-context scenarios.
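The equivalence between Eq. (Attn) and Eq. (1) can be checked numerically. The following is a minimal NumPy sketch rather than the authors' code; the function names and the random test data are illustrative assumptions. It relies on the fact that the minimizer of a weighted sum of squared distances is the weighted mean of the values.

```python
import numpy as np

def attention_output(K, V, q):
    """Eq. (Attn): softmax-weighted combination of the values.

    K and V hold one key/value per row (shape (t, D)); q has shape (D,).
    """
    scores = K @ q / np.sqrt(q.shape[0])
    c = np.exp(scores - scores.max())  # subtracting the max leaves the softmax unchanged
    c /= c.sum()
    return c @ V

def regression_output(K, V, q):
    """Eq. (1): minimizer of sum_i w_i * ||v - v_i||^2, i.e. the weighted mean."""
    w = np.exp(K @ q / np.sqrt(q.shape[0]))
    return (w[:, None] * V).sum(axis=0) / w.sum()

rng = np.random.default_rng(0)
K = rng.normal(size=(6, 4))
V = rng.normal(size=(6, 4))
q = rng.normal(size=4)
y_attn = attention_output(K, V, q)
y_reg = regression_output(K, V, q)
```

Both functions need every stored key and value to answer a single query, which is the linear-per-step cost discussed above.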

Linear State-Space Models (SSMs). The high computation cost of Eq. (Attn) has motivated a flurry of work developing new LLM layers, like SSMs, with linear rather than quadratic cost. Most SSMs maintain a state matrix $S_t\in\mathbb{R}^{D\times D}$ and update it at each time step via a linear recursion of the form

$$S_t=\gamma_t\cdot S_{t-1}+\beta_t\cdot v_tk_t^\top,\qquad y_t=S_tq_t, \tag{Linear-SSM}$$

where $\gamma_t,\beta_t$ are typically in $[0,1]$. Unlike the verbatim lookup of Eq. (Attn), Eq. (Linear-SSM) essentially compresses the entire KV-cache into a fixed-dimensional representation $S_t$. Subsequent computation of the output $y_t$ relies on $S_t$ and no longer on the exact past. This results in a constant cost of storage and computation at every time step.
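To make the compression concrete, the sketch below runs Eq. (Linear-SSM) with a constant-size state and compares it against the same output written as an exponentially decayed sum over the whole past. It is an illustrative NumPy sketch; the function names and the scalar gates are assumptions, not a specific published layer.

```python
import numpy as np

def linear_ssm(K, V, Q, gamma, beta):
    """Run Eq. (Linear-SSM) sequentially; rows of K, V, Q are k_t, v_t, q_t."""
    T, D = K.shape
    S = np.zeros((D, D))
    Y = np.zeros((T, D))
    for t in range(T):
        S = gamma[t] * S + beta[t] * np.outer(V[t], K[t])  # constant-size state
        Y[t] = S @ Q[t]
    return Y

def unrolled_output(K, V, Q, gamma, beta, t):
    """The same y_t as a decayed sum over all past pairs (requires the full cache)."""
    y = np.zeros(K.shape[1])
    for i in range(t + 1):
        decay = np.prod(gamma[i + 1 : t + 1])  # product of the gates applied after step i
        y += decay * beta[i] * V[i] * (K[i] @ Q[t])
    return y
```

The recursion touches only $S_t$ at each step, whereas the unrolled form revisits every past pair; the two agree because the gates simply rescale older contributions.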

In many linear SSMs (e.g., RetNet [Sun-arXiv2023RetNet], Mamba2 [Dao-ICML2024-mamba2]), the use of $\gamma_t$ and $\beta_t$ is often heuristic and draws inspiration from nonlinear recurrent neural networks [Hochreiter-NC1997]; in that light, $\gamma_t$ and $\beta_t$ are called forgetting and input gates, respectively. This basic form of Eq. (Linear-SSM) has been generalized by replacing $\gamma_t$ with a diagonal matrix (GLA [Yang-ICML2024-gla], RWKV-6 [Peng-CoLM2024], Longhorn [Liu-ICLR2025]) or a low-rank-plus-identity matrix (Gated DeltaNet [schlag2021linear, Yang-NeurIPS2024, Yang-ICLR2025], DeltaProduct [Siems-arXiv2025], RWKV-7 [Peng-arXiv2025rwkv]).

As with Eq. (Attn), the case with low-rank-plus-identity matrices can often be justified from an optimization perspective. For example, Gated DeltaNet [schlag2021linear, Yang-NeurIPS2024] updates the state via ($I$ is the $D\times D$ identity matrix)

$$S_t=\gamma_t\cdot S_{t-1}\left(I-\beta_tk_tk_t^\top\right)+\beta_t\cdot v_tk_t^\top, \tag{GDN}$$

which can be viewed as applying one gradient descent step with stepsize $\beta_t$ and initialization $\gamma_tS_{t-1}$ to the objective

$$\min_{S}\|Sk_t-v_t\|_2^2. \tag{2}$$

The objective (2) of Eq. (GDN) and the objective (1) of Eq. (Attn) are prime examples that expose a general distinction between linear SSMs and Eq. (Attn): The former updates its state based on a regression objective that considers only the previous lossy state and the current time step, whereas the latter uses the entire, exact KV-cache to solve its regression objective Eq. (1).
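The gradient-step reading of Eq. (GDN) can be verified directly. In the sketch below (illustrative names, not the Gated DeltaNet implementation) the objective is scaled as $\frac{1}{2}\|Sk_t-v_t\|_2^2$, an assumption made so that the stepsize is exactly $\beta_t$; with the unscaled objective of Eq. (2) the same update corresponds to stepsize $\beta_t/2$.

```python
import numpy as np

def gdn_update(S_prev, k, v, gamma, beta):
    """Eq. (GDN): gated delta-rule update of the state."""
    D = k.shape[0]
    return gamma * S_prev @ (np.eye(D) - beta * np.outer(k, k)) + beta * np.outer(v, k)

def one_gradient_step(S_prev, k, v, gamma, beta):
    """One gradient step on 0.5 * ||S k - v||^2 starting from S0 = gamma * S_prev."""
    S0 = gamma * S_prev
    grad = np.outer(S0 @ k - v, k)  # gradient of 0.5 * ||S k - v||^2 at S0
    return S0 - beta * grad
```

Only the current pair $(k_t,v_t)$ enters the gradient, which is the myopia that the text contrasts with Eq. (1).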

We hypothesize that this myopic view of SSM objectives results in their lower performance and limited long-context abilities. We then ask: What is an objective or, equivalently, a recursion that considers the entire past as Eq. (Attn) does, while still being solvable in linear time as in Eq. (Linear-SSM)?

3 A Linear SSM Inspired by the Kalman Filter

In Section 3.1 we show how the Kalman Filter (KF) gives insights into a new linear SSM layer that takes all past time instants into account. In Section 3.2 we explain the numerical and efficiency challenges of building such a layer.

3.1 Motivation from Kalman Filter

KF is an established online approach that takes the exact past into account to optimally solve a weighted ridge regression objective (e.g., see [Peng-MoCL2025, Proposition 2 & Lemma 3]). In our context, this means that the optimal state

$$S_t=\operatorname*{argmin}_{S\in\mathbb{R}^{D\times D}}\ \lambda\cdot\|S\|_{\text{F}}^{2}+\sum_{i=1}^{t}\eta_i\cdot\|Sk_i-v_i\|_2^2 \tag{3}$$

can be computed by the KF recursion

$$S_t=S_{t-1}-\frac{(S_{t-1}k_t-v_t)\,k_t^\top\Phi_{t-1}}{1/\eta_t+k_t^\top\Phi_{t-1}k_t}, \tag{KF}$$

where $\eta_t$ is the weight for the $t$-th key-value pair, and $\Phi_{t-1}$ is the Hessian inverse of Eq. (3) at time $t-1$ ($\Phi_{t-1}$ itself can be continually updated via the Woodbury matrix identity). It is now clear that objective Eq. (3) takes the entire KV-cache into account, similarly to Eq. (Attn). It is also clear that Eq. (KF) is an efficient update scheme similarly to Eq. (Linear-SSM); indeed, Eq. (KF) is also of a low-rank-plus-identity form (cf. Eq. (GDN)).
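As a sanity check, the recursion can be compared with the closed-form ridge solution. The sketch below is a plain NumPy illustration with hypothetical function names; it assumes $S_0=0$ and takes $\Phi_t=(\lambda I+\sum_{i\le t}\eta_ik_ik_i^\top)^{-1}$, i.e. the Hessian inverse with the constant factor dropped, updated by the rank-one case of the Woodbury identity (Sherman-Morrison).

```python
import numpy as np

def kf_recursion(K, V, eta, lam):
    """Run Eq. (KF) over rows of K and V, updating Phi by a rank-one correction."""
    T, D = K.shape
    S = np.zeros((D, D))
    Phi = np.eye(D) / lam  # inverse of lam * I before any data is seen
    for t in range(T):
        k, v = K[t], V[t]
        Pk = Phi @ k                                 # Phi is symmetric, so k^T Phi = (Phi k)^T
        denom = 1.0 / eta[t] + k @ Pk
        S = S - np.outer(S @ k - v, Pk) / denom      # Eq. (KF), using Phi_{t-1}
        Phi = Phi - np.outer(Pk, Pk) / denom         # Phi_t from Phi_{t-1}
    return S

def ridge_closed_form(K, V, eta, lam):
    """Minimizer of Eq. (3): U (H + lam * I)^{-1} with weighted covariances U, H."""
    D = K.shape[1]
    U = (V * eta[:, None]).T @ K
    H = (K * eta[:, None]).T @ K
    return np.linalg.solve(H + lam * np.eye(D), U.T).T
```

In double precision the two agree to rounding error; the concern raised in Section 3.2 is about how this recursion behaves in low precision and over long sequences.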

A key difference from Eq. (Linear-SSM) is that Eq. (KF) leverages second-order information from $\Phi_{t-1}$ to solve Eq. (3) optimally, whereas Eq. (Linear-SSM) relies on instantaneous objectives akin to Eq. (2) (cf. [Yang-NeurIPS2024, Table 2]). It is in this sense that we say Eq. (KF) is more expressive than Eq. (Linear-SSM) or Eq. (GDN). We now detail the differences between the objectives of Eq. (KF) and Eq. (Attn):

  • Eq. (KF) computes a parametric linear estimator that enables a constant-sized memory, while Eq. (1) computes a non-parametric point estimate that entails storing the full cache.

  • In Eq. (1), the weight of the same residual varies over time as the queries differ at each time, while in Eq. (3) the $i$-th weight $\eta_i$ is constant once observed at time $i$. The former results in quadratically many weights (and thus a quadratic time complexity), the latter in linearly many.

  • In Eq. (3), the regularizer $\lambda\cdot\|S\|_{\text{F}}^{2}$ prevents overfitting our state to key-value pairs, since only a finite amount of "information" can be stored in a constant-sized memory, beyond which recall becomes "fuzzy". In this light, $\lambda$ can be thought of as controlling the memorization "capacity" of the state.

3.2 Hurdles Towards Scalable Kalman Filter SSMs

Despite its optimality and (sequential) computational efficiency, the Eq. (KF) recursion lacks a hardware-aware implementation that leverages the parallelism of modern Tensor Cores. Moreover, for long sequences it can lose numerical precision due to division (and, more significantly, due to how the Hessian inverse $\Phi_t$ is updated). The final hurdle is conceptual: Fixing the weights $\eta_i$ and the regularization $\lambda$ over time as in Eq. (3) might make a layer less expressive.

We are aware of the use of Eq. (KF) or Eq. (3) in neural network training three decades ago [Shah-1992] and in deep continual learning recently [Zeng-NMI2019, Mcdonnell-NeurIPS2023, Peng-ICLR2025]. We are also aware of recent mentions of Eq. (3), or efforts towards solving it, under the name test-time optimization [Wang-arXiv2025v3, Von-arXiv2025-mesa]. However, to the best of our knowledge, none of the prior work has fully addressed the above hurdles, which need to be solved to design an SSM layer that is trainable in parallel, numerically well-behaved, and sufficiently expressive. In particular, both [Von-arXiv2025-mesa] and [Wang-arXiv2025v3] have overlooked a basic numerical concern: The worst-case numerical error in solving Eq. (3) can be $\epsilon\cdot\kappa$ [Golub-2013], where $\kappa$ is the condition number of the Hessian in Eq. (3) and $\epsilon$ the machine precision; since $\epsilon\approx 0.007$ (bf16), Eq. (3) has to be regularized strongly for $\kappa$ and the worst-case error to be small, regardless of the algorithmic choices made to solve Eq. (3). Indeed, the regularization enforced in [Von-arXiv2025-mesa] sets $\lambda$ to be lower bounded by $0.25$, but this is not sufficient: Their $\kappa$ is as large as $500$ [Von-arXiv2025-mesa, Fig. 13], implying a worst-case error of $0.007\cdot 500=3.5$. (The implementation of [Von-arXiv2025-mesa] available on GitHub is numerically vulnerable; we failed to train it without NaNs in various settings.) Also, the regression objective in [Wang-arXiv2025v3] has no regularization, which makes it numerically ill-posed for low-precision training.

4 Gated KalmaNet (GKA)

We propose Gated KalmaNet (GKA) to address the above hurdles: We enhance numerical stability via adaptive regularization and the classic Chebyshev Iteration (CH), increase expressivity of KF via a standard gating mechanism, and improve parallelism via a hardware-friendly implementation.

4.1 CH with Adaptive Regularization & Weighting

Motivation. As alluded to earlier, solving Eq. (3) via Eq. (KF) is sequential in nature, and here we consider alternatives amenable to parallelizable training. Our first step is to write down a closed-form solution to Eq. (3) and compute the output

$$y_t=S_tq_t=\left(\sum_{i=1}^{t}\eta_iv_ik_i^\top\right)\left(\sum_{i=1}^{t}\eta_ik_ik_i^\top+\lambda I\right)^{-1}q_t.$$

With the weighted covariances $U_t:=\sum_{i=1}^{t}\eta_iv_ik_i^\top$ and $H_t:=\sum_{i=1}^{t}\eta_ik_ik_i^\top$, we note that $y_t$ can be computed by first solving $(H_t+\lambda I)x=q_t$ for $x$ and then left-multiplying by $U_t$. An exact solver (e.g., torch.linalg.solve) can do so with high accuracy, by parallelizing over the batch dimension. However, it is inefficient here for two reasons. First, it takes $O(D^3)$ time for every $t$. Second, it requires explicitly forming and materializing all $H_t$'s, which would entail a large I/O cost. In light of this, we resort to first-order iterative methods that admit a chunk-wise implementation without materializing all $H_t$'s, enabling parallelism over chunks and batches. Furthermore, they typically take $O(D^2)$ time per iteration and can converge in a few iterations. The iterative method we choose is the Chebyshev Iteration (CH); we proceed to describe its basic idea, with a justification of using CH deferred to Section 4.2.3.

Chebyshev Iteration (CH). CH can be seen as an accelerated gradient descent method (AGD) that applies the (grad descent) and (momentum) steps of Algorithm 1 to the strongly convex objective $\frac{1}{2}\xi^\top H\xi-\xi^\top q$, that is, it solves the optimality condition $H\xi=q$. Different from AGD, CH incorporates a (weight schedule) and makes specific choices of the parameters; these choices make CH optimal, with the fastest convergence among first-order methods [Pedregosa-Chebyshev2021].

We now replace the above exact solver with CH:

$$\hat{x}_t=\text{CH}(H_t+\lambda I,q_t,r),\qquad y_t=U_t\hat{x}_t.$$

Here, $\text{CH}(H_t+\lambda I,q_t,r)$ means $r$ iterations of CH to approximately solve $(H_t+\lambda I)x=q_t$. To improve stability and expressivity, we next allow the regularization $\lambda$ and the weights $\eta_i$ to be time-varying and chosen adaptively. We write $\lambda_t$ and $\eta_{t,i}$ to make their dependency on time $t$ explicit, with $\eta_{t,i}$ being the weight of the $i$-th token at time $t$.

1. Input: $H\in\mathbb{R}^{D\times D}$, $q\in\mathbb{R}^{D}$, eigenvalue bounds $L,\mu$ with $L\geq\mu>0$, number of iterations $r$;
2. Initialize: $\rho\leftarrow\frac{L-\mu}{L+\mu}$; $\omega_0=0$; the first two iterates $\xi_{-1}\leftarrow 0$, $\xi_0\leftarrow\frac{2q}{L+\mu}$;
3. For $i=1,\dots,r$:
$$\begin{aligned}\omega_i&\leftarrow\frac{4}{4-\rho^{2}\omega_{i-1}} &&\text{(weight schedule)}\\ \xi_i&\leftarrow\xi_{i-1}-\frac{2\cdot\omega_i}{L+\mu}(H\xi_{i-1}-q) &&\text{(grad descent)}\\ \xi_i&\leftarrow\xi_i+(\omega_i-1)(\xi_{i-1}-\xi_{i-2}) &&\text{(momentum)}\end{aligned}$$
4. Output: $\xi_r$
Algorithm 1: Chebyshev Iteration to solve $H\xi=q$
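Algorithm 1 translates almost line by line into code. The following is a minimal single-vector NumPy sketch that follows the initialization exactly as printed above (including $\omega_0=0$); it is not the chunk-wise Triton kernel described later, and the function name is an assumption.

```python
import numpy as np

def chebyshev_iteration(H, q, L, mu, r):
    """Approximately solve H xi = q for symmetric H with eigenvalues in [mu, L]."""
    rho = (L - mu) / (L + mu)
    omega = 0.0                     # omega_0
    xi_prev = np.zeros_like(q)      # xi_{-1}
    xi = 2.0 * q / (L + mu)         # xi_0
    for _ in range(r):
        omega = 4.0 / (4.0 - rho**2 * omega)                   # weight schedule
        xi_new = xi - 2.0 * omega / (L + mu) * (H @ xi - q)    # grad descent
        xi_new = xi_new + (omega - 1.0) * (xi - xi_prev)       # momentum
        xi_prev, xi = xi, xi_new
    return xi
```

Each iteration costs one matrix-vector product, in line with the $O(D^2)$ per-iteration cost mentioned in the motivation; only the two most recent iterates need to be kept.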

Adaptive Regularization. As mentioned, the condition number $\kappa_t$ of $H_t+\lambda_tI$ has to be controlled for any method to be numerically stable. We choose $\lambda_t$ to be proportional to the Frobenius norm $\|H_t\|_{\text{F}}$, that is, we set $\lambda_t=a\cdot\|H_t\|_{\text{F}}$ for some constant $a>0$. This ensures an upper bound on $\kappa_t$:

$$\kappa_t=\frac{\lambda_{\text{max}}(H_t)+\lambda_t}{\lambda_{\text{min}}(H_t)+\lambda_t}\leq\frac{\|H_t\|_{\text{F}}+\lambda_t}{\lambda_t}=\frac{a+1}{a}. \tag{4}$$

Here $\lambda_{\text{max}}(H_t)$ and $\lambda_{\text{min}}(H_t)$ are the maximum and minimum eigenvalues of $H_t$, respectively; the inequality uses $\lambda_{\text{max}}(H_t)\leq\|H_t\|_{\text{F}}$ and $\lambda_{\text{min}}(H_t)\geq 0$, which hold because $H_t$ is positive semidefinite. Given this choice of $\lambda_t$, we set $L=\|H_t\|_{\text{F}}+\lambda_t$ and $\mu=\lambda_t$ for Algorithm 1.
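A quick numerical illustration of the bound in Eq. (4), as a sketch with assumed names and random data: for $a=0.1$ the bound is $(a+1)/a=11$, and because $\lambda_t$ scales with $\|H_t\|_{\text{F}}$, rescaling $H_t$ leaves the condition number unchanged.

```python
import numpy as np

def regularized_condition_number(H, a):
    """Condition number of H + lam * I with lam = a * ||H||_F (H symmetric PSD)."""
    lam = a * np.linalg.norm(H, "fro")
    eigs = np.linalg.eigvalsh(H + lam * np.eye(H.shape[0]))  # ascending order
    return eigs[-1] / eigs[0]

rng = np.random.default_rng(0)
X = rng.normal(size=(16, 4))
H = X.T @ X                      # a positive semidefinite "key covariance"
a = 0.1
kappa = regularized_condition_number(H, a)
bound = (a + 1) / a
```

A fixed $\lambda$ would not have this scale invariance: the same $\lambda$ that suffices for small keys can leave $\kappa_t$ large once $\|H_t\|_{\text{F}}$ grows.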

Adaptive Weighting (Gating). We use weights $\eta_{t,i}$ that decay exponentially in time: For all $t\geq i$, we parameterize $\eta_{t,i}=\prod_{j=i+1}^{t}\gamma_j$, with each $\gamma_j\in[0,1]$ learnable. The fading weights encode the "prior" of recency bias that has been shown to exist in LLMs [fang2025large], without explicitly computing the weights from the query-key dot products as in Eq. (Attn). Similarly to Eq. (Attn), the weights on the residuals are now time-varying; differently from it, the exponential-decay parameterization allows for a linear-time implementation.

Forward Recurrence. We now summarize our recurrence, which arms CH with adaptive regularization and weighting:

$$\begin{aligned}H_t&=\gamma_t\cdot H_{t-1}+k_tk_t^\top,\qquad U_t=\gamma_t\cdot U_{t-1}+v_tk_t^\top,\\ y_t&=U_t\hat{x}_t,\qquad \hat{x}_t=\text{CH}(H_t+\lambda_tI,q_t,r).\end{aligned} \tag{CH}$$
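The recurrence can be written as a short sequential reference. In the sketch below, an exact solve stands in for CH (an assumption made to keep the example short; it corresponds to the limit of many CH iterations), and a second function recomputes the same output from the explicit weights $\eta_{t,i}=\prod_{j=i+1}^{t}\gamma_j$. Function names are illustrative, and $H_0=U_0=0$ is assumed.

```python
import numpy as np

def gka_forward_reference(K, V, Q, gamma, a):
    """Sequential reference for Eq. (CH); rows of K, V, Q are k_t, v_t, q_t."""
    T, D = K.shape
    H = np.zeros((D, D))
    U = np.zeros((D, D))
    Y = np.zeros((T, D))
    for t in range(T):
        H = gamma[t] * H + np.outer(K[t], K[t])
        U = gamma[t] * U + np.outer(V[t], K[t])
        lam = a * np.linalg.norm(H, "fro")               # adaptive regularization
        x = np.linalg.solve(H + lam * np.eye(D), Q[t])   # exact stand-in for CH
        Y[t] = U @ x
    return Y

def gka_output_from_weights(K, V, Q, gamma, a, t):
    """The same y_t built from eta_{t,i} = prod_{j=i+1}^{t} gamma_j over the full past."""
    D = K.shape[1]
    eta = np.array([np.prod(gamma[i + 1 : t + 1]) for i in range(t + 1)])
    H = (K[: t + 1] * eta[:, None]).T @ K[: t + 1]
    U = (V[: t + 1] * eta[:, None]).T @ K[: t + 1]
    lam = a * np.linalg.norm(H, "fro")
    return U @ np.linalg.solve(H + lam * np.eye(D), Q[t])
```

The first function keeps only two $D\times D$ states, while the second revisits every past token; their agreement is what lets the layer account for the whole (gated) past at constant memory.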

4.2 Chunk-wise Implementation

In this subsection, we describe our hardware-aware implementation of the forward and backward passes for Eq. (CH). More details can be found in Appendix B.

4.2.1 Forward Pass

Similarly to prior work [Dao-ICML2024-mamba2, Yang-ICML2024-gla, Yang-NeurIPS2024], we now describe a chunk-wise implementation for Eq. (CH). In Eq. (CH), given $U_t$ and $\hat{x}_t$, computing $y_t=U_t\hat{x}_t$ in a chunk-wise fashion is similar to that of Eq. (Linear-SSM); also similar is the calculation of $H_t\xi_{i-1}$ as needed in the (grad descent) step of Algorithm 1. For these we refer the reader to [Yang-ICML2024-gla, Yang-NeurIPS2024] for details. Our algorithmic novelty here is a chunk-wise computational formula for $\|H_t\|_{\text{F}}$, presented next.

Let $T$ be the sequence length and $C$ the chunk size such that $N:=T/C$ is an integer. For $t=0,\dots,N-1$, write $[t]:=tC$. The core idea of a chunk-wise implementation is as follows. First, we compute and store the initial state $H_{[t]}$ of every chunk. This gives us implicit access to $H_{[t]+c}$ via unrolling the recurrence of $H_t$ for $c$ steps and therefore allows us to carry out computation with $H_{[t]+c}$; for example, by the recurrence in Eq. (CH) we can compute the matrix-vector product $H_{[t]+1}\xi$ via $\gamma_{[t]+1}H_{[t]}\xi+k_{[t]+1}k_{[t]+1}^\top\xi$. This is done without forming $H_{[t]+1}$ explicitly, thereby reducing the number of states to materialize on chip. To implement such a scheme, we need to precompute all $H_{[t]}$'s sequentially, and then do the computation with parallelism over chunks and within each chunk.

We now make this idea precise for computing all $\|H_t\|_{\text{F}}$'s within a chunk. Since the computation of each chunk is the same, we simplify by working with the first one, where we have access to the initial state $H_0$, gates $\gamma_1,\dots,\gamma_C$, and keys $K_C=[k_1,\dots,k_C]\in\mathbb{R}^{D\times C}$, and we aim to compute $\|H_1\|_{\text{F}},\|H_2\|_{\text{F}},\dots,\|H_C\|_{\text{F}}$. With this notation, we first compute the $C$-dimensional vector $\zeta=[\zeta_1,\dots,\zeta_C]^\top$ of cumulative products of the $\gamma_i$'s, with $\zeta_c=\prod_{i=1}^{c}\gamma_i$. Then, we form the $C\times C$ upper triangular matrix $M$ whose $(j,c)$-th entry $M_{j,c}$ is $\zeta_c/\zeta_j$ for all $c\geq j$ (and zero otherwise). Now, unroll the recurrence of $H_c$:

$$H_c=\zeta_cH_0+\sum_{j=1}^{c}M_{j,c}k_jk_j^\top=\zeta_cH_0+\sum_{j=1}^{C}M_{j,c}k_jk_j^\top.$$

Expanding $\|H_c\|_{\text{F}}^{2}$ gives the following sum of three terms:

$$\zeta_c^{2}\|H_0\|_{\text{F}}^{2}+2\zeta_c\sum_{j=1}^{C}M_{j,c}k_j^\top H_0k_j+\Big\|\sum_{j=1}^{C}M_{j,c}k_jk_j^\top\Big\|_{\text{F}}^{2}.$$

With $\zeta$, the first term $\zeta_c^{2}\cdot\|H_0\|_{\text{F}}^{2}$ is easily computed in parallel for all $c$. For the second term, we first compute the vector of quadratic forms $k_j^\top H_0k_j$ for all $j$ in parallel, broadcast it and multiply it with $M$ element-wise, sum over each column, and multiply the result with $2\zeta$ element-wise. Finally, with the Gram matrix $G_C:=K_C^\top K_C$, one verifies that the third term can be computed in parallel for all $c$ via the following pseudocode:

$$\text{column-sum}\big(((G_C\odot G_C)M)\odot M\big). \tag{5}$$

Here $\odot$ denotes element-wise multiplication and the sum is over each column. Summing the three terms and taking the square root, we obtain $\|H_1\|_{\text{F}},\dots,\|H_C\|_{\text{F}}$, as desired.
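The three-term expansion can be checked against a sequential loop. The following NumPy sketch for a single chunk is an illustration, not the Triton kernel; the function names are hypothetical, and it assumes strictly positive gates so that the ratios $\zeta_c/\zeta_j$ are well defined.

```python
import numpy as np

def chunk_frobenius_norms(H0, K, gamma):
    """||H_1||_F, ..., ||H_C||_F for one chunk; K has shape (D, C), gamma shape (C,)."""
    zeta = np.cumprod(gamma)
    M = np.triu(zeta[None, :] / zeta[:, None])       # M[j, c] = zeta_c / zeta_j for c >= j
    term1 = zeta**2 * np.sum(H0 * H0)
    quad = np.einsum("dj,de,ej->j", K, H0, K)        # k_j^T H0 k_j for every j
    term2 = 2.0 * zeta * (quad[:, None] * M).sum(axis=0)
    G = K.T @ K                                      # Gram matrix of the keys
    term3 = (((G * G) @ M) * M).sum(axis=0)          # Eq. (5)
    return np.sqrt(term1 + term2 + term3)

def sequential_frobenius_norms(H0, K, gamma):
    """Reference: materialize every H_c through the recurrence."""
    H = H0
    norms = []
    for c in range(K.shape[1]):
        H = gamma[c] * H + np.outer(K[:, c], K[:, c])
        norms.append(np.linalg.norm(H, "fro"))
    return np.array(norms)
```

The parallel version touches $H_0$ only through $\|H_0\|_{\text{F}}$ and the quadratic forms $k_j^\top H_0k_j$, so no intermediate $H_c$ is ever formed.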

4.2.2 Backward Pass

Motivation. Typically, the backward pass is done automatically via torch.autograd. However, for iterative methods such as CH (Algorithm 1), torch.autograd would store activations or intermediate iterates, entailing a large storage cost. While in principle we can back-propagate through CH without storing any intermediate activations or iterates (by our trick of reverting the CH iterations, cf. Section C.1), under this trick it is difficult to compute all the gradients in a chunk-wise fashion. Therefore, we resort to the implicit differentiation trick, which is practically efficient and chunk-wise implementable, for backpropagation through the linear equations that CH approximately solves.

Implicit Differentiation. We derive the backward pass for our method with the standard implicit differentiation trick. It assumes we find an exact solution $x^*_t$ to the equations $(H_t+\lambda_tI)x=q_t$. In the backward pass, we are given the gradient $dx^*_t:=\frac{d\mathcal{L}}{dx^*_t}$ of some loss function $\mathcal{L}$, and need to compute the corresponding gradients at $q_t,k_t,\gamma_t$. For example, via the chain rule we obtain $dq_t:=\frac{d\mathcal{L}}{dq_t}$ via

$$dq_t=(H_t+\lambda_tI)^{-1}dx^*_t, \tag{6}$$

that is, by solving linear equations similarly to the forward pass. Since the forward pass computes an approximate solution $\hat{x}_t$ via CH, we receive an approximate upstream gradient $d\hat{x}_t$ (not exactly $dx_t^*$). Thus we employ CH to obtain an approximate gradient $d\hat{q}_t=\text{CH}(H_t+\lambda_tI,d\hat{x}_t,r)$; cf. Table 1.
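Eq. (6) can be checked against finite differences. In the sketch below the matrix $A$ plays the role of $H_t+\lambda_tI$ and is held fixed, the loss is an arbitrary scalar function of the solution, and an exact solve is used in place of CH; all names are illustrative assumptions.

```python
import numpy as np

def dq_implicit(A, dx):
    """Eq. (6): dL/dq = A^{-1} dL/dx for symmetric A = H_t + lambda_t * I."""
    return np.linalg.solve(A, dx)

def dq_finite_difference(A, q, loss, eps=1e-6):
    """Central differences of q -> loss(A^{-1} q); used only as a check."""
    grad = np.zeros_like(q)
    for j in range(q.shape[0]):
        e = np.zeros_like(q)
        e[j] = eps
        up = loss(np.linalg.solve(A, q + e))
        down = loss(np.linalg.solve(A, q - e))
        grad[j] = (up - down) / (2.0 * eps)
    return grad
```

The backward pass therefore needs only one more linear solve with the same matrix, which is why the same CH routine can be reused with $d\hat{x}_t$ as its right-hand side.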

Backward Recurrence. Besides $dq_t$, we need to compute $dH_t$, from which we obtain $dk_t$ and $d\gamma_t$ via the chain rule. We describe $d\gamma_t$ in the Appendix. Here we analyze $dk_t$:

Lemma 1.

With $\lambda_t=a\|H_t\|_{\text{F}}$ and $w_t=\frac{2a\cdot(x_t^*)^\top dq_t}{\|H_t\|_{\text{F}}}$, we have

$$dk_i=\sum_{t\geq i}M_{i,t}\left(-dq_t(x_t^*)^\top-x_t^*dq_t^\top-w_tH_t\right)k_i. \tag{7}$$

With $A_i:=\sum_{t\geq i}M_{i,t}\cdot dq_t(x_t^*)^\top$, we can compute the first two terms $-A_ik_i$ and $-A_i^\top k_i$ in Eq. (7), similarly to Eq. (Linear-SSM). Specifically, $A_i$ satisfies the recursion

$$A_i=\gamma_{i+1}A_{i+1}+dq_i(x_i^*)^\top, \tag{8}$$

thus calculating $A_ik_i$ amounts to calculating $U_tq_t$ as in Eq. (Linear-SSM); a difference is that the recursion here runs backwards.

Similarly, with $B_i=\sum_{t\geq i}M_{i,t}w_tH_t$, the third term in Eq. (7) can be written recursively as

$$B_i=\gamma_{i+1}B_{i+1}+w_iH_i,\qquad o_i=B_ik_i. \tag{9}$$
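Lemma 1 and the recursions (8) and (9) can be checked numerically on a short sequence. The sketch below is an assumption-laden illustration: it uses exact solves instead of CH, takes $H_0=0$, considers only the gradient path through $x_t^*$ (not through $U_t$), and uses hypothetical function names. The result can be compared against finite differences of a loss that is linear in the $x_t^*$.

```python
import numpy as np

def forward_states(K, Q, gamma, a):
    """x_t = (H_t + a * ||H_t||_F * I)^{-1} q_t with H_t = gamma_t H_{t-1} + k_t k_t^T."""
    T, D = K.shape
    H = np.zeros((D, D))
    xs, Hs = [], []
    for t in range(T):
        H = gamma[t] * H + np.outer(K[t], K[t])
        lam = a * np.linalg.norm(H, "fro")
        xs.append(np.linalg.solve(H + lam * np.eye(D), Q[t]))
        Hs.append(H)
    return np.array(xs), np.array(Hs)

def dk_backward(K, Q, gamma, a, dX):
    """Eq. (7) via the backward recursions (8) and (9); dX[t] is dL/dx_t."""
    T, D = K.shape
    xs, Hs = forward_states(K, Q, gamma, a)
    A = np.zeros((D, D))
    B = np.zeros((D, D))
    dK = np.zeros_like(K)
    g_next = 0.0  # no contribution from beyond the last time step
    for i in reversed(range(T)):
        norm_h = np.linalg.norm(Hs[i], "fro")
        dq = np.linalg.solve(Hs[i] + a * norm_h * np.eye(D), dX[i])  # Eq. (6)
        w = 2.0 * a * (xs[i] @ dq) / norm_h
        A = g_next * A + np.outer(dq, xs[i])   # Eq. (8)
        B = g_next * B + w * Hs[i]             # Eq. (9)
        dK[i] = -(A + A.T + B) @ K[i]          # Eq. (7)
        g_next = gamma[i]                      # acts as gamma_{i+1} for the earlier step
    return dK
```

This reference materializes every $H_i$ and runs strictly backwards in time, which is what the chunk-wise derivation that follows avoids.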

Chunk-wise Recurrence. As indicated, a chunk-wise implementation for computing $A_ik_i$ is known. On the other hand, computing $B_ik_i$ is more challenging than $A_ik_i$, as the additive term $w_iH_i$ in the backward recursion Eq. (9) is not necessarily rank-$1$; rather, $H_i$ itself is defined via the forward recursion in Eq. (CH). Our contribution here is a derivation for computing $B_ik_i$ efficiently in a chunk-wise manner.

We begin by unrolling the recursion of $B_i$ up to $B_{C+1}$:

$$\begin{aligned}B_i&=M_{i,C}\cdot\gamma_{C+1}B_{C+1}+B_i^{\text{intra}},\\ B_i^{\text{intra}}&:=\sum_{c=i}^{C}M_{i,c}\cdot w_cH_c.\end{aligned} \tag{10}$$

We next discuss the intra-chunk term $B_i^{\text{intra}}k_i$ and the cross-chunk term $M_{i,C}\cdot\gamma_{C+1}B_{C+1}$ in succession.

Intra-chunk Computation. We now unroll HcH_{c} and obtain an expression BiintraB_{i}^{\text{intra}} more amenable to parallelism:

Biintra\displaystyle B_{i}^{\text{intra}} =c=iCMi,cwcHc\displaystyle=\sum_{c=i}^{C}M_{i,c}\cdot w_{c}H_{c}
=c=1CMi,cwc(ζcH0+j=1CMj,ckjkj)\displaystyle=\sum_{c=1}^{C}M_{i,c}w_{c}\Big(\zeta_{c}H_{0}+\sum_{j=1}^{C}M_{j,c}k_{j}k_{j}^{\top}\Big)
=H0c=1CMi,cwcζc+j=1Ckjkjc=1CMi,cwcMj,c.\displaystyle=H_{0}\sum_{c=1}^{C}M_{i,c}w_{c}\zeta_{c}+\sum_{j=1}^{C}k_{j}k_{j}^{\top}\sum_{c=1}^{C}M_{i,c}w_{c}M_{j,c}.

The coefficients of H0H_{0}, written as bib_{i}, are easily computed in parallel for all ii via element-wise operations, broadcasting, and summing. The coefficient of kjkjk_{j}k_{j}^{\top} is precisely the (i,j)(i,j)-th entry of the matrix Mw:=Mdiag(w1,,wC)MM_{w}:=M\operatorname{diag}(w_{1},\dots,w_{C})M^{\top}. Thus [B1intrak1,,BCintrakC][B_{1}^{\text{intra}}k_{1},\dots,B_{C}^{\text{intra}}k_{C}] is equal to

H0(diag(b1,,bC)KC)+KC((KCKC)Mw).\displaystyle H_{0}(\operatorname{diag}(b_{1},\dots,b_{C})K_{C})+K_{C}\left((K_{C}^{\top}K_{C})\odot M_{w}\right).

Here the mask $M_{w}$ is in general a full matrix with no zero entries, as opposed to the triangular matrix in the case of Eq. Linear-SSM. While the triangular mask in the backward pass allows the error feedback from future tokens to be leveraged for learning past tokens, our full mask $M_{w}$ allows all tokens to interact with all other tokens in the backward pass, which facilitates information flow and learning.
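The masked-Gram form of the intra-chunk term can be checked numerically. The snippet below is a minimal sketch, not the paper's kernel: it assumes the keys are stored as the columns of a `d x C` array, and it uses random stand-ins for $H_{0}$, $M$, $w_{c}$ and $\zeta_{c}$ (the real quantities come from the gating and forward recursion). All variable names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
d, C = 4, 6                       # head dimension and chunk size (illustrative)
K = rng.standard_normal((d, C))   # keys k_1..k_C as columns (assumption)
H0 = rng.standard_normal((d, d))  # stand-in for the incoming state H_0
M = rng.standard_normal((C, C))   # stand-in for the decay mask M
w = rng.standard_normal(C)        # stand-in for the weights w_c
zeta = rng.standard_normal(C)     # stand-in for the scalars zeta_c

# Coefficients read off the last line of the derivation.
b = M @ (w * zeta)                # b_i = sum_c M[i, c] * w_c * zeta_c
Mw = M @ np.diag(w) @ M.T         # Mw[i, j] = sum_c M[i, c] * w_c * M[j, c]

# Reference: build each B_i^intra explicitly, then apply it to k_i.
ref = np.zeros((d, C))
for i in range(C):
    Bi = b[i] * H0
    for j in range(C):
        Bi = Bi + Mw[i, j] * np.outer(K[:, j], K[:, j])
    ref[:, i] = Bi @ K[:, i]

# Chunk-wise form: two matrix products plus one element-wise mask.
# Broadcasting with b scales column i of H0 @ K by b_i.
fast = (H0 @ K) * b + K @ ((K.T @ K) * Mw)

print(np.allclose(ref, fast))
```

The explicit loop materializes a `d x d` matrix per token, whereas the second form only needs the `C x C` Gram matrix of the chunk, which is what makes it friendly to matrix-multiplication hardware.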

[Figure 1 panels: (a) Empirical convergence; (b) CG as a single layer; (c) CG in 5-layer LLAMA; (d) CH as a single layer; (e) CH in 5-layer LLAMA]

Figure 1: CH converges with smaller errors than CG and is more numerically stable. Convergence of different methods in residual norms during the forward pass with batch size $8$, sequence length 2048, 8 heads, head dimension 128 (a), and relative gradient differences from the exact solver (torch.linalg.solve) to CG (b, c) or CH (d, e). The backward pass is via implicit differentiation (impl) or torch.autograd (auto); cf. Table 1. In (b, d) the gradients are those of $[q_{t},k_{t}]$; in (c, e) the gradients are those of network weights.

Cross-chunk Computation. In Eq. 10, both $\gamma_{C+1}$ and $B_{C+1}$ come from future chunks. We therefore rewrite Eq. 10 as a cross-chunk recursion on $\widetilde{B}_{C+1}:=\gamma_{C+1}B_{C+1}$, which lets us maintain a single term $\widetilde{B}_{C+1}$ from the future:

$$\widetilde{B}_{1}=\zeta_{C}\cdot\widetilde{B}_{C+1}+\widetilde{B}^{\text{intra}}_{1},\qquad\widetilde{B}^{\text{intra}}_{1}:=\sum_{c=1}^{C}\zeta_{c}\cdot w_{c}H_{c}.$$

In our intra-chunk computation, we store the intra-chunk term $\widetilde{B}^{\text{intra}}_{1}$ of all chunks, implement the above recursion with a simple for loop, and collect the terms $\zeta_{C}\cdot\widetilde{B}_{C+1}k_{i}$.

4.2.3 Comparison to Other Iterative Solvers

Here we validate our choice of Chebyshev Iteration (CH) by benchmarking it against other iterative methods.

Convergence in the Forward Pass. We generate random regression problems, which we solve via CH and three baselines: gradient descent (GD), accelerated GD with Nesterov’s momentum (AGD), and conjugate gradient (CG). GD and AGD are run with stepsizes that are optimal for regression problems. Fig. 1a shows that CG converges the fastest within the first few iterations, while CH reaches the same accuracy as CG at iteration 10 and eventually attains the smallest errors.

Table 1: Implicit differentiation for computing $dq_{t}$.

| | forward pass | backward pass |
|---|---|---|
| exact | $x^{*}_{t}=(H_{t}+\lambda_{t}I)^{-1}q_{t}$ | $dq_{t}^{*}=(H_{t}+\lambda_{t}I)^{-1}dx^{*}_{t}$ |
| CG | $\hat{x}_{t}=\text{CG}(H_{t}+\lambda_{t}I,q_{t},r)$ | $d\hat{q}_{t}=\text{CG}(H_{t}+\lambda_{t}I,d\hat{x}_{t},r)$ |
| CH | $\hat{x}_{t}=\text{CH}(H_{t}+\lambda_{t}I,q_{t},r)$ | $d\hat{q}_{t}=\text{CH}(H_{t}+\lambda_{t}I,d\hat{x}_{t},r)$ |

Stability of the Backward Pass. We then measure the gradient stability of CG and CH, whose backward passes are implemented either via implicit differentiation as per Table 1 (impl), or via torch.autograd (auto).

In Fig. 1b, CG (impl) as a standalone layer has its gradient close to that of the exact solver, up to a $10^{-3}$ relative difference. In Fig. 1c, this difference is amplified to almost $1$ in a 5-layer LLAMA where Eq. Attn is replaced with Eq. 3. This indicates that CG (impl) completely deviates from the reference gradient (exact), defeating its purpose of training the network from the regression feedback. In contrast, the gradients of CH (impl) and CH (auto) are eventually close to that of the exact solver, either as a single layer (Fig. 1d) or within multiple layers (Fig. 1e), up to a $10^{-6}$ difference. Moreover, the curves for CH (impl) and CH (auto) nearly overlap, suggesting that their gradients may be close. The following lemma confirms and formalizes this intuition (see Appendix C for a proof), thereby justifying our choice of CH over the alternatives:

Lemma 2.

Let $dq_{t}$ be the exact gradient of $q_{t}$ for CH, e.g., computed by CH (auto). Let $d\hat{q}_{t}$ be the gradient of CH (impl), computed as per Table 1. We have $dq_{t}=d\hat{q}_{t}$.
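The intuition behind the lemma can be illustrated on a small dense problem. The sketch below uses the textbook Chebyshev Iteration for a symmetric positive definite system and assumes the spectral bounds are known exactly (here taken from `np.linalg.eigvalsh`); how the paper's kernels obtain these bounds, and their chunk-wise layout, are not reproduced. Because the step sizes depend only on the bounds and the iteration count, the map from right-hand side to output is a fixed symmetric polynomial in the matrix, so its vector-Jacobian product is the same solver applied to the incoming gradient.

```python
import numpy as np

def chebyshev_solve(A, b, num_iters, lam_min, lam_max):
    """Textbook Chebyshev Iteration for A x = b, A symmetric positive definite.

    lam_min / lam_max bound the spectrum of A. The coefficients alpha, beta
    depend only on these bounds and the iteration index, never on b.
    """
    center = 0.5 * (lam_max + lam_min)
    half_width = 0.5 * (lam_max - lam_min)
    x = np.zeros_like(b)
    r = b.copy()
    p = np.zeros_like(b)
    alpha = 0.0
    for i in range(num_iters):
        if i == 0:
            p = r.copy()
            alpha = 1.0 / center
        else:
            scale = 0.5 if i == 1 else 0.25
            beta = scale * (half_width * alpha) ** 2
            alpha = 1.0 / (center - beta / alpha)
            p = r + beta * p
        x = x + alpha * p
        r = b - A @ x
    return x

rng = np.random.default_rng(0)
n, num_iters, lam = 8, 30, 0.5               # illustrative sizes and ridge strength
X = rng.standard_normal((n, n))
A = X @ X.T / n + lam * np.eye(n)            # plays the role of H_t + lambda_t I
eigs = np.linalg.eigvalsh(A)
lo, hi = eigs[0], eigs[-1]                   # exact bounds (assumption)

q = rng.standard_normal(n)
x_hat = chebyshev_solve(A, q, num_iters, lo, hi)

# The forward map q -> x_hat is linear; recover its matrix column by column.
P = np.column_stack([chebyshev_solve(A, e, num_iters, lo, hi) for e in np.eye(n)])

dx = rng.standard_normal(n)
dq_auto = P.T @ dx                                   # what autograd would return
dq_impl = chebyshev_solve(A, dx, num_iters, lo, hi)  # Table 1 rule for CH

print(np.allclose(dq_auto, dq_impl), np.allclose(P, P.T))
```

The same argument does not carry over to CG, whose step sizes depend on the right-hand side, which is consistent with the gap between CG (impl) and the exact gradient reported above.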

Figure 2: Our GKA block. Blue denotes established practices in the literature, with solid circles denoting $\ell_{2}$ normalization. Green components (CH and $\alpha$-connection) are our proposals.

[Figure 3 panels: (a) Accuracy vs. model dimension for different fading memory layers on MQAR; (b) Runtime of a single memory layer]

Figure 3: MQAR results (a): each plot corresponds to a particular sequence length and number of key-value pairs for the model to memorize. Runtime (b): runtimes are for a single forward + backward pass (8 heads, head dim $128$, batch size $4$, averaged over 20 runs).

4.3 Architectural Consideration

Our GKA layer in Fig. 2 includes two components (in green) on top of established practices (in blue). The CH component is described in Section 4.1, so here we introduce the $\alpha$-connection. First, the sigmoid activation ensures $\alpha_{t}\in[0,1]$, so the output of the $\alpha$-connection is a convex combination of the original query $q_{t}$ and the output $\hat{x}_{t}$ of CH. Second, it plays a role similar to a residual connection: it establishes a direct path that facilitates gradient flow and improves training; we show this is indeed the case in Section F.3. Finally, the full architecture for GKA is the standard Transformer, with its attention layer replaced by the GKA layer.
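As a rough illustration of the convex combination, the following sketch mixes a query with a solver output using a sigmoid gate. Which of the two inputs receives the weight $\alpha_{t}$, and how the gate logit is produced from the input, are assumptions made here for illustration only; the function name is hypothetical.

```python
import numpy as np

def alpha_connection(q, x_hat, gate_logit):
    """Convex combination of the query q and the solver output x_hat.

    Assumption: alpha weights x_hat and (1 - alpha) weights q.
    """
    alpha = 1.0 / (1.0 + np.exp(-gate_logit))  # sigmoid keeps alpha in (0, 1)
    return alpha * x_hat + (1.0 - alpha) * q

q = np.array([1.0, 0.0])
x_hat = np.array([0.0, 1.0])
print(alpha_connection(q, x_hat, 0.0))  # gate logit 0 gives an equal mix
```

With a very negative gate logit the layer passes the query through almost unchanged, which is the direct path that resembles a residual connection.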

5 Experiments

In this section, we empirically validate the efficacy of our approach. We first evaluate memorization ability on synthetic associative recall tasks (Section 5.1). We then report training throughput of GKA (Section 5.2). Finally, we examine performance on short-context language understanding benchmarks such as commonsense reasoning and long-context modeling abilities in Section 5.3. The Appendix details our experimental settings (Appendix D) and ablations of various modeling choices (Appendix F, Appendix G).

Baselines. All experiments consider the following state-of-the-art linear SSM-based fading memory layers as baselines: Mamba2 [Dao-ICML2024-mamba2], DeltaNet [Yang-NeurIPS2024], Gated DeltaNet (GDN) [Yang-ICLR2025], and Gated Linear Attention (GLA) [Yang-ICML2024-gla]. Each of these layers relies on instantaneous objectives that depend on the previous lossy state and current tokens (e.g., Eq. 2), as opposed to the entire history of tokens observed so far as in GKA. Finally, we contrast our results with (Softmax) Attention, which serves as our paragon. For our Attention-based model, we adopt the architecture proposed in Qwen3 models [yang2025qwen3].

5.1 GKA on Synthetic Associative Recall Tasks

We first assess the capability of our models to recall information on the Multi-Query Associative Recall (MQAR) task, a synthetic task introduced by Arora et al. [arora2023zoology]. This task presents the model with a sequence of key-value pairs to memorize, followed by a sequence of queries. For each query, the model must retrieve the corresponding key from memory and accurately recall its associated value. Attention-based layers perform the best in this task, while SSM-based memory layers are known to struggle, as their memory fades away as the context length grows.
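To make the task structure concrete, here is a toy generator for one associative-recall example. It is a hypothetical sketch of the idea (key-value pairs followed by queries), not the benchmark's actual data format; the function name, the vocabulary split, and the token layout are assumptions.

```python
import random

def make_mqar_example(num_pairs, num_queries, vocab_size, seed=0):
    """Return (context, queries, targets) for one toy recall example."""
    rng = random.Random(seed)
    half = vocab_size // 2
    # Keys come from the lower half of the vocabulary, values from the upper half.
    keys = rng.sample(range(0, half), num_pairs)               # distinct keys
    values = [rng.randrange(half, vocab_size) for _ in keys]
    table = dict(zip(keys, values))
    context = [tok for pair in zip(keys, values) for tok in pair]  # k1 v1 k2 v2 ...
    queries = [rng.choice(keys) for _ in range(num_queries)]
    targets = [table[k] for k in queries]                      # value to recall
    return context, queries, targets

context, queries, targets = make_mqar_example(num_pairs=8, num_queries=4, vocab_size=64)
print(context, queries, targets)
```

Increasing `num_pairs` or the distance between a pair and its query is what stresses a fixed-size, fading state.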

We compare GKA with Attention and other linear SSM baselines on this task. For each memory layer type, we train 2-layer models on MQAR training data and evaluate on a held-out test set. We repeat this experiment for $4$ different learning rates spanning from $10^{-4}$ to $10^{-2}$. As shown in Fig. 3a, GKA improves upon every other linear SSM baseline at all sequence lengths and model dimensions considered. Note that the complexity of the task increases with increasing sequence length and number of key-value pairs, while larger model dimensions improve memorization capacity through increased state size. The success of our layer can be attributed to our modeling choice: unlike other fading memory designs (like GDN or Mamba2), we construct states based on the optimal MAP estimate conditioned on the entire history, enabling better retention of remote information.

5.2 Training Throughput of GKA

In Fig. 3b we measure the running time (forward + backward) of a single GKA layer and compare it with FlashAttention [golden2024flash], DeltaNet, and GDN. Our layer achieves comparable running time to GDN, a state-of-the-art SSM layer, despite having a more computationally expensive state update equation (Eq. CH) than GDN (Eq. GDN). This demonstrates that our chunk-wise parallelization strategy effectively compensates for the additional computational cost.

Table 2: On average GKA improves upon all fading memory baselines across all tasks. We report results for zero-shot evaluation of 2.8B language models for short-context tasks. For each task, bold indicates the highest value, followed by underlined.

| Model | ARC-C | ARC-E | BoolQ | COPA | HellaSWAG | PIQA | SciQ | Winogrande | FDA | SWDE | Avg |
|---|---|---|---|---|---|---|---|---|---|---|---|
| (metric) | acc_n ↑ | acc_n ↑ | acc ↑ | acc ↑ | acc_n ↑ | acc_n ↑ | acc_n ↑ | acc ↑ | contains ↑ | contains ↑ | |
| Transformer | 32.25 | 56.10 | 64.28 | 80.00 | 60.96 | 73.56 | 79.50 | 61.72 | 58.53 | 72.28 | 63.92 |
| Gated Linear Attention | 27.82 | 50.80 | 52.57 | 78.00 | 48.83 | 70.13 | 69.60 | 54.54 | 2.81 | 20.43 | 47.55 |
| DeltaNet | 32.85 | 58.16 | 42.51 | 81.00 | 61.13 | 73.78 | 43.90 | 61.72 | 11.80 | 46.08 | 51.29 |
| Mamba2 | 32.24 | 59.64 | 58.72 | 82.00 | 62.23 | 73.78 | 79.80 | 62.19 | 7.71 | 41.13 | 55.94 |
| Gated DeltaNet | 32.59 | 60.02 | 62.75 | 82.00 | 62.80 | 74.32 | 80.60 | 62.35 | 8.26 | 44.28 | 57.00 |
| Gated KalmaNet (Ours) | 32.51 | 59.89 | 61.68 | 85.00 | 63.84 | 74.81 | 83.20 | 64.17 | 12.89 | 50.95 | 58.89 |

Figure 4: Long Context Performance up to 128k tokens. GKA achieves strong RAG and LongQA capabilities, outperforming all baselines by 10% in relative improvement. Interestingly, we observe that there is no clear winner on Synthetic Recall. All models struggle to perform better than random chance on ICL.

5.3 GKA on Language Modeling

5.3.1 Short-context Tasks

Setup. For this set of experiments, we construct 2.8B LLMs for each memory layer (GKA and the baselines described in Section 5) by cascading blocks of mem + Multi-Layer Perceptron (MLP). [Footnote 2: For the Mamba2 baseline, we cascade blocks of the Mamba2 layer alone, since a single Mamba2 layer has the Mamba2 SSM and MLP.] Hereafter, we refer to each 2.8B model by the name of the layer used to construct it. We then train each model on DCLM [li2024datacomp], a generic pre-training dataset, for $100$B tokens at $4$K context length using the AdamW optimizer with a peak Learning Rate (LR) of $10^{-3}$ and gradient clipping of $1.0$. We use the cosine LR scheduler with a warmup period of $5$B tokens and a global batch size of $2$M tokens. All models employ the GPT2 tokenizer with a vocabulary size of $50$K tokens.
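The token budgets above translate into optimizer steps as follows. This is a back-of-the-envelope sketch: it assumes "2M" and "100B" are decimal counts, a linear warmup, and a cosine decay to a minimum LR of zero, none of which are stated in the setup.

```python
import math

TOTAL_TOKENS = 100_000_000_000   # 100B training tokens
WARMUP_TOKENS = 5_000_000_000    # 5B warmup tokens
BATCH_TOKENS = 2_000_000         # 2M tokens per global batch (decimal, assumption)
PEAK_LR = 1e-3

total_steps = TOTAL_TOKENS // BATCH_TOKENS    # 50,000 optimizer steps
warmup_steps = WARMUP_TOKENS // BATCH_TOKENS  # 2,500 warmup steps

def learning_rate(step, min_lr=0.0):
    """Linear warmup followed by cosine decay (shape and min_lr are assumptions)."""
    if step < warmup_steps:
        return PEAK_LR * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return min_lr + 0.5 * (PEAK_LR - min_lr) * (1.0 + math.cos(math.pi * progress))

print(total_steps, warmup_steps, learning_rate(warmup_steps))
```

Under these assumptions warmup covers 5% of training, and each step processes roughly 488 sequences of 4K tokens.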

Tasks. Following prior works [Zancato-NeurIPS2024-BMOJO, Yang-ICML2024-gla, Yang-ICLR2025], to assess the language modeling capabilities of our model we perform zero-shot evaluation on the following eight common-sense reasoning tasks from LM-Harness [eval-harness]: ARC-E, ARC-C, BoolQ, COPA, HellaSWAG, PIQA, SciQ, Winogrande. We also evaluate models on FDA and SWDE, real-world recall-intensive tasks which focus on extracting structured information like tagged content from raw text (for example, HTML files). All these tasks are relatively short ($<2$K tokens).

Results. We report our results in Table 2. GKA outperforms all fading memory baselines on average across all tasks, owing to its ability to better manage its state via solving Eq. 3. In particular, GKA outperforms both GDN and Mamba2 on recall-intensive tasks (FDA and SWDE) by about $10\%$ (rel. improvement). We note that although GKA improves upon existing SSM layers, there is still a gap with the Attention-based Transformer, especially on recall tasks, owing to the eidetic capabilities of Attention. Nevertheless, as discussed in Section 1, this improvement comes at a quadratic cost at training time, whereas our layer’s computational complexity is still comparable to existing SSM layers (cf. Section 5.2). In Appendix I we extend our results to Hybrid models (stacks of SSM and Attention layers) and show that the gap with full Transformer models becomes negligible (while still benefiting from the SSM’s computational advantages). Finally, in Appendix E we show that GKA exhibits stronger scaling with compute than other SSM baseline models.

5.3.2 Long-context Tasks

Setup. To enable long-context capabilities of our models, as is common practice, we perform continued pre-training of our 2.8B models obtained in Section 5.3.1 on $25$B tokens of long documents at $128$K context length (cf. Appendix). To the best of our knowledge we are the first to train and evaluate SSM models up to $128$K context (e.g., previous work [Yang-ICLR2025] only considered up to $4$K/$8$K context).

Tasks. For long-context, we refrain from using perplexity as it is known to have limitations at assessing long-context performance of LLMs [nunez2024expansion, fang2024wrong, gao2025train]. Instead, we turn to recently proposed benchmarks that mix synthetic and real datasets comprising several long-context tasks: synthetic Recall, Retrieval-Augmented Generation (RAG), Many-shot In-Context Learning (ICL) and Long Question-Answering (LongQA). For synthetic Recall and LongQA we consider tasks from the RULER benchmark [hsieh2024ruler]. For RAG and ICL we consider tasks from HELMET [yen2025helmet].

Results. Fig. 4 reports our results. GKA shows strong RAG and LongQA capabilities, outperforming all fading memory baselines by at least $10$% (rel. improvement). Interestingly, on synthetic Recall tasks from RULER, GKA is competitive only at $4$K context length and starts to fall behind afterwards. We attribute this divergence to the fundamental differences between these task types. While both RAG and LongQA can be thought of as finding relevant information in long streams of text, they involve more realistic linguistic patterns and semantic relationships that align with natural text distributions seen during pretraining. In contrast, synthetic Recall tasks require models to find specific words, numbers, or UUIDs verbatim from long contexts filled with random distractors. This artificial setting does not reflect natural text distributions. Since GKA computes MAP estimates of the latent state based on learned representations of observed tokens, it relies on its pretrained weights to determine which information is important to retain. The synthetic and unnatural structure of Recall tasks makes it difficult for the model to identify what should be retained, as pretrained knowledge provides little signal about importance in these artificial contexts. This suggests that our approach excels in realistic scenarios where pretrained knowledge about natural language structure can guide information selection, but struggles when the signal-to-noise distinction is purely artificial.

6 Kalman Filter for Optimally Modelling Fading Memory

In this section, we show how the Kalman Filter (KF) provides a principled solution for constructing an optimal fading memory that accounts for the entire history. We begin by describing the standard Kalman Filter recurrence in the context of memory modeling. However, the KF has a fundamental limitation: its inherently sequential nature makes it impractical for large-scale training on modern hardware accelerators (Section 3.2). To address this, we make simplifying assumptions that make the KF amenable to parallelization on modern hardware accelerators. We then demonstrate that several recent state-space models (DeltaNet, Gated DeltaNet, and Kimi Delta Attention) can be viewed as approximations to the KF recurrence. Specifically, these methods approximate the “optimal” Kalman gain matrix while ignoring dependencies on the past. In contrast, GKA computes the exact Kalman gain by considering the full history. This theoretical advantage translates to improved empirical performance, as we demonstrate in Section 5.

6.1 A Dynamical System for Fading Memory

The Kalman filter is a classical algorithm for online optimal inference in linear Gaussian State-Space Models. It gives a principled way to maintain and update a compact state as new noisy observations arrive. The latent state serves as a compressed "memory" of the past. More formally, it is a minimal sufficient statistic that makes past observations conditionally independent of future ones given the state.

We begin by describing a linear Gaussian model for fading memory:

$$\begin{aligned}s_{t}&=A_{t}s_{t-1}+B_{t}u_{t}+w_{t},\quad&w_{t}&\sim\mathcal{N}(0,Q_{t})\\ v_{t}&=k_{t}^{\top}s_{t}+\mu_{t},\quad&\mu_{t}&\sim\mathcal{N}(0,r_{t}),\end{aligned}\tag{LGM}$$

where $s_{t}\in\mathbb{R}^{n}$ is a latent state that summarizes the past, $u_{t}\in\mathbb{R}^{n}$ is the control input that updates the state, and $v_{t}$ is the scalar measurement observed at time $t$. $A_{t},B_{t}\in\mathbb{R}^{n\times n}$ are the state transition and input selection matrices, and $k_{t}\in\mathbb{R}^{n}$ is the emission (readout) vector. Finally, $w_{t}$ and $\mu_{t}$ are Gaussian process and measurement noise, respectively.

Parameter interpretation. $A_{t}$ and $B_{t}$ control the forgetting (fading of the remote past) and input selectivity rates, respectively, determining how the state evolves over time. The measurement noise $\mu_{t}$ naturally gives rise to gating mechanisms commonly used in modern SSM layers, as we will show in Section 6.4.

Extension to multi-channel measurements. In attention mechanisms, the memory consists of verbatim key-value pairs that can be queried to retrieve past information [Zancato-NeurIPS2024-BMOJO]. Similarly, we want our state to reconstruct past values from their corresponding keys. To achieve this, we extend to a matrix-valued state $S_{t}\in\mathbb{R}^{n\times d}$, where each column independently follows the dynamics in Eq. LGM.

Specifically, for the $i^{\text{th}}$ channel:

$$\begin{aligned}s_{t,i}&=A_{t,i}s_{t-1,i}+B_{t,i}u_{t,i}+w_{t,i},\quad&w_{t,i}&\sim\mathcal{N}(0,Q_{t,i})\\ v_{t,i}&=k_{t}^{\top}s_{t,i}+\mu_{t,i},\quad&\mu_{t,i}&\sim\mathcal{N}(0,r_{t,i}),\end{aligned}$$

where $(k_{t},v_{t})$ is the key-value pair at time $t$ and $v_{t,i}$ is the $i^{\text{th}}$ element of $v_{t}$. In what follows, we focus on a single channel and drop the subscript $i$ from the state for notational clarity.

6.2 Kalman Filter for Optimal Inference

Given the model in Eq. LGM and a sequence of measurements $\{v_{1},v_{2},\ldots,v_{t}\}$, the Kalman Filter computes the Maximum A-Posteriori (MAP) estimate of the latent state at time $t$:

$$\hat{s}_{t}=\arg\max_{s}p(s\mid v_{1},v_{2},\ldots,v_{t}),\tag{11}$$

where $p$ is the probability density function. The MAP estimate is optimal in the sense that it minimizes the expected squared error between the true state and its estimate given all measurements up to time $t$.

The KF recursion. The Kalman Filter updates the state estimate recursively as new measurements arrive. At time $t$, the update is:

$$\hat{s}_{t}=\underbrace{A_{t}\hat{s}_{t-1}+B_{t}u_{t}}_{\text{Predicted state}}+G_{t}\Big(\overbrace{v_{t,i}-k_{t}^{\top}\underbrace{\big[A_{t}\hat{s}_{t-1}+B_{t}u_{t}\big]}_{\text{Predicted state}}}^{\text{Innovation}}\Big),\tag{12}$$

where the innovation measures the discrepancy between the actual measurement $v_{t}$ and the predicted measurement based on the predicted state estimate.

The Kalman gain $G_{t}$ determines how much to trust the new measurement versus the predicted state. It is computed as follows:

$$G_{t}=\frac{\big[A_{t}\Sigma_{t-1}A_{t}^{\top}+Q_{t}\big]k_{t}}{k_{t}^{\top}\big[A_{t}\Sigma_{t-1}A_{t}^{\top}+Q_{t}\big]k_{t}+r_{t}}.\tag{13}$$

The error covariance $\Sigma_{t}$ quantifies the uncertainty in the state estimate. It represents the covariance of the estimation error $(s_{t}-\hat{s}_{t})$ conditioned on all measurements up to time $t$. The covariance is updated as:

$$\Sigma_{t}=\Big(I-G_{t}k_{t}^{\top}\Big)\Big(A_{t}\Sigma_{t-1}A_{t}^{\top}+Q_{t}\Big)\tag{14}$$

Equations (12), (13) and (14) constitute the KF recursion. We initialize with $\hat{s}_{0}=0$ and $\Sigma_{0}=\sigma I_{n}$, where $I_{n}$ is the $n\times n$ identity matrix and $\sigma$ represents our prior uncertainty about the state before observing any measurements.
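The three equations translate almost line by line into code. The following is a minimal NumPy sketch of one KF step for a single channel with a scalar measurement; the function name, argument order, and the toy values in the usage loop are illustrative choices, not taken from the paper.

```python
import numpy as np

def kalman_step(s_prev, Sigma_prev, k, v, A, B, u, Q, r):
    """One step of the KF recursion for a scalar measurement v = k^T s + noise."""
    s_pred = A @ s_prev + B @ u                    # predicted state
    P_pred = A @ Sigma_prev @ A.T + Q              # predicted error covariance
    G = P_pred @ k / (k @ P_pred @ k + r)          # Kalman gain, Eq. (13)
    s_new = s_pred + G * (v - k @ s_pred)          # innovation update, Eq. (12)
    Sigma_new = (np.eye(len(k)) - np.outer(G, k)) @ P_pred  # Eq. (14)
    return s_new, Sigma_new, G

rng = np.random.default_rng(0)
n = 3
s, Sigma = np.zeros(n), 1.0 * np.eye(n)            # s_0 = 0, Sigma_0 = sigma * I
A, B, Q = 0.95 * np.eye(n), np.zeros((n, n)), 0.01 * np.eye(n)
for _ in range(10):
    k, v = rng.standard_normal(n), rng.standard_normal()
    s, Sigma, G = kalman_step(s, Sigma, k, v, A, B, np.zeros(n), Q, r=0.1)
print(s, np.trace(Sigma))
```

Each call consumes the covariance produced by the previous one, which is the sequential dependency that makes the general recursion hard to parallelize.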

6.3 Gated KalmaNet: A Steady-State Dynamical System for Large-Scale Training

Despite its optimality, the KF recursion in its most general form is inherently sequential; each update depends on the previous state estimate. This sequential dependency prevents the parallelization necessary for efficient large-scale training on modern hardware.

To enable parallelization, we make a key simplifying assumption: the underlying state remains static over time. This reduces the problem from tracking a dynamic state to estimating a fixed but unknown parameter from sequential noisy measurements. Formally, we assume a steady-state model:

$$\begin{aligned}s_{t}&=s_{t-1}\\ v_{t,i}&=k_{t}^{\top}s_{t}+\mu_{t},\quad&\mu_{t}&\sim\mathcal{N}(0,r_{t}),\end{aligned}\tag{15}$$

where $A_{t}=I_{n}$, $B_{t}=0$, and $w_{t}=0$ (i.e., no state evolution, no control input, and no process noise).

Adapting to evolving context. The steady-state assumption may initially seem restrictive, since contexts naturally evolve as topics change. GKA addresses this through adaptive weighting (Section 4.1): by assigning higher weights to recent measurements, older observations are naturally faded out over time, allowing the model to track shifting context despite the static formulation.

Under this simplification, the KF recursion reduces to:

$$\begin{aligned}\hat{s}_{t}&=\hat{s}_{t-1}+G_{t}(v_{t,i}-k_{t}^{\top}\hat{s}_{t-1}),\\ G_{t}&=\frac{\Sigma_{t-1}k_{t}}{k_{t}^{\top}\Sigma_{t-1}k_{t}+r_{t}},\\ \Sigma_{t}&=\Big(I-G_{t}k_{t}^{\top}\Big)\Sigma_{t-1}.\end{aligned}\tag{16}$$

Collecting all channels, these equations can be written compactly in matrix form as shown in (KF). [Footnote 3: with the columns of $S_{t}$ transposed to rows of $S_{t}$, to be consistent with the notation in (KF), and taking the noise variance $r_{t}=\frac{1}{\eta_{t}}$.] A key insight of this work is that the KF recursion for the steady-state model admits an efficient parallel implementation via chunked processing (detailed in Section 4) that results in Gated KalmaNet.

Critically, the KF recursion accounts for the entire history when computing state estimates. The Kalman gain $G_{t}$ at each step depends on all previous measurements through $\Sigma_{t-1}$. This contrasts with most existing SSMs, which we show next can be viewed as approximations that ignore historical dependencies when computing their gain matrices. This principled treatment of the full history is a key advantage of our approach.
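The link between the steady-state recursion and ridge regression can be verified numerically. The sketch below runs the recursion with $\Sigma_{0}=\sigma I_{n}$ and $r_{t}=1/\eta_{t}$ on random data and compares the final estimate with the closed-form weighted ridge solution $(\sum_{i}\eta_{i}k_{i}k_{i}^{\top}+\sigma^{-1}I)^{-1}\sum_{i}\eta_{i}k_{i}v_{i}$, which follows from the Sherman-Morrison identity. The data, sizes, and weight range are illustrative, and the additional gating of GKA is not modeled.

```python
import numpy as np

rng = np.random.default_rng(0)
n, T, sigma = 4, 20, 10.0
keys = rng.standard_normal((T, n))      # k_1..k_T as rows
values = rng.standard_normal(T)         # one channel of v_1..v_T
eta = rng.uniform(0.5, 2.0, size=T)     # per-step weights, r_t = 1 / eta_t

# Steady-state KF recursion (A = I, B = 0, no process noise).
s = np.zeros(n)
Sigma = sigma * np.eye(n)
for k, v, e in zip(keys, values, eta):
    G = Sigma @ k / (k @ Sigma @ k + 1.0 / e)
    s = s + G * (v - k @ s)
    Sigma = (np.eye(n) - np.outer(G, k)) @ Sigma

# Closed-form weighted ridge regression over the whole history.
H = (keys * eta[:, None]).T @ keys + np.eye(n) / sigma
s_ridge = np.linalg.solve(H, keys.T @ (eta * values))

print(np.allclose(s, s_ridge), np.allclose(Sigma, np.linalg.inv(H)))
```

The check also shows that $\Sigma_{t}$ is the inverse of the regularized Gram matrix, which is why the gain at step $t$ carries information about every earlier key.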

6.4 Connection with Existing SSM Layers

DeltaNet [Yang-DeltaNet] approximates the KF recursion in (16) by assuming fixed error covariance: $\Sigma_{t}=I_{n}$ for all $t$. This simplifies the Kalman gain to:

$$G_{t}=\frac{k_{t}}{k_{t}^{\top}k_{t}+r_{t}}=\frac{k_{t}}{1+r_{t}},\tag{17}$$

where the second equality assumes unit-normalized keys, a common assumption in practical instantiations of DeltaNet. Substituting (17) into the state update (16) and defining $\beta_{t}=(1+r_{t})^{-1}$ yields:

$$\hat{s}_{t}=(I-\beta_{t}k_{t}k_{t}^{\top})\hat{s}_{t-1}+\beta_{t}k_{t}v_{t,i},\tag{DeltaNet}$$

which is the DeltaNet recurrence. By fixing $\Sigma_{t}$, DeltaNet avoids tracking the evolving uncertainty in the state estimate, a key simplification that sacrifices optimality for computational efficiency. In contrast, GKA maintains the full error covariance $\Sigma_{t}$, allowing it to optimally weight measurements based on the entire history.
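A quick numerical check of this reduction, written as a sketch with arbitrary test values: one steady-state KF step with the covariance frozen at the identity should coincide with one delta-rule step using $\beta_{t}=(1+r_{t})^{-1}$, provided the key has unit norm.

```python
import numpy as np

rng = np.random.default_rng(0)
n, r, v = 5, 0.25, 0.7                  # illustrative size, noise variance, value
k = rng.standard_normal(n)
k /= np.linalg.norm(k)                  # unit-normalized key
s_prev = rng.standard_normal(n)

# Steady-state KF step with the covariance frozen at the identity.
Sigma = np.eye(n)
G = Sigma @ k / (k @ Sigma @ k + r)
s_kf = s_prev + G * (v - k @ s_prev)

# Delta-rule step with beta = 1 / (1 + r).
beta = 1.0 / (1.0 + r)
s_delta = (np.eye(n) - beta * np.outer(k, k)) @ s_prev + beta * k * v

print(np.allclose(s_kf, s_delta))
```

Because `Sigma` is reset to the identity at every step, nothing about earlier keys enters the gain, which is the loss of history described above.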

Gated DeltaNet (GDN) [Yang-ICLR2025] extends DeltaNet by incorporating explicit forgetting through a time-dependent decay factor $\alpha_{t}$. Like DeltaNet, GDN can be viewed as fixing $\Sigma_{t}=I_{n}$, but applying this approximation to the KF recursion for a fading dynamical system where the state decays over time.

Specifically, GDN assumes

$$\begin{aligned}s_{t}&=\alpha_{t}s_{t-1}+w_{t},\quad&w_{t}&\sim\mathcal{N}(0,I_{n})\\ v_{t,i}&=k_{t}^{\top}s_{t}+\mu_{t},\quad&\mu_{t}&\sim\mathcal{N}(0,r_{t}),\end{aligned}\tag{18}$$

where $\alpha_{t}\in[0,1]$ is a learned decay factor controlling how much past information to retain. This corresponds to setting $A_{t}=\alpha_{t}I_{n}$ in (LGM). When $\alpha_{t}\to 0$, the state “forgets” the past completely; when $\alpha_{t}\to 1$, the state is fully retained.

Under the identity covariance assumption $\Sigma_{t}=I_{n}$, the Kalman gain becomes:

$$G_{t}=\frac{(\alpha_{t}^{2}+1)k_{t}}{(\alpha_{t}^{2}+1)k_{t}^{\top}k_{t}+r_{t}}=\frac{k_{t}}{1+r_{t}/(\alpha_{t}^{2}+1)},\tag{19}$$

where the second equality again assumes unit-normalized keys (as in DeltaNet). Defining $\beta_{t}=\big(1+\frac{r_{t}}{\alpha_{t}^{2}+1}\big)^{-1}$ and substituting into the state update (12) yields:

$$\begin{aligned}\hat{s}_{t}&=\alpha_{t}\hat{s}_{t-1}+\beta_{t}k_{t}\big(v_{t,i}-k_{t}^{\top}\big[\alpha_{t}\hat{s}_{t-1}\big]\big)\\ &=\big[I_{n}-\beta_{t}k_{t}k_{t}^{\top}\big]\alpha_{t}\hat{s}_{t-1}+\beta_{t}k_{t}v_{t,i},\end{aligned}\tag{GDN}$$

which recovers the GDN recurrence. In practice, $\beta_{t}$ is an input-dependent learnable parameter.
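The gain in (19) and the resulting recurrence can be checked against the general formulas (12) and (13). The sketch below plugs $A_{t}=\alpha_{t}I_{n}$, $Q_{t}=I_{n}$, $B_{t}=0$ and $\Sigma_{t-1}=I_{n}$ into the general KF step and compares it with the closed forms; the numerical values are arbitrary test inputs.

```python
import numpy as np

rng = np.random.default_rng(1)
n, r, alpha, v = 5, 0.25, 0.9, -0.4     # illustrative test values
k = rng.standard_normal(n)
k /= np.linalg.norm(k)                  # unit-normalized key
s_prev = rng.standard_normal(n)

# General KF quantities with A = alpha * I, Q = I, B = 0, Sigma_{t-1} = I.
A, Q, Sigma_prev = alpha * np.eye(n), np.eye(n), np.eye(n)
P_pred = A @ Sigma_prev @ A.T + Q       # equals (alpha^2 + 1) * I
G = P_pred @ k / (k @ P_pred @ k + r)   # Eq. (13)
s_pred = A @ s_prev
s_kf = s_pred + G * (v - k @ s_pred)    # Eq. (12)

# Closed forms: Eq. (19) for the gain, then the gated delta-rule update.
beta = 1.0 / (1.0 + r / (alpha**2 + 1.0))
s_gdn = (np.eye(n) - beta * np.outer(k, k)) @ (alpha * s_prev) + beta * k * v

print(np.allclose(G, beta * k), np.allclose(s_kf, s_gdn))
```

The decay $\alpha_{t}$ enters both the state prediction and, through $\alpha_{t}^{2}+1$, the effective step size $\beta_{t}$, although practical layers learn $\beta_{t}$ directly as noted above.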

Like DeltaNet, GDN avoids tracking the evolving uncertainty $\Sigma_{t}$, trading optimality for computational simplicity. The key difference is that GDN’s explicit forgetting factor $\alpha_{t}$ provides additional control over the memory horizon. However, by fixing $\Sigma_{t}=I_{n}$, GDN still ignores how measurement history should optimally influence the Kalman gain, leading to suboptimal performance compared to GKA (see Section 5).

Kimi Delta Attention (KDA) [team2025kimi] further extends GDN by using channel-specific decay factors $\alpha_{t,i}$ in place of the global $\alpha_{t}$. This allows different channels to have independent memory horizons. In the KF framework, this corresponds to:

$$s_{t,i}=\alpha_{t,i}s_{t-1,i}+w_{t,i},\quad w_{t,i}\sim\mathcal{N}(0,I_{n}),\tag{20}$$

for each channel $i$. While this added flexibility can improve expressiveness, KDA still assumes $\Sigma_{t}=I_{n}$ and therefore does not optimally consider the entire past when computing its state update. Like DeltaNet and GDN, KDA sacrifices optimality for computational simplicity.

7 Discussions and Limitations

Thanks to its expressive test-time ridge regression objective, Gated KalmaNet extends previous fading memory layers like Mamba2, LongHorn and Gated DeltaNet, all of which depend only on an instantaneous test-time objective. However, GKA is only optimal among linear memory layers; solving our test-time objective with non-linear updates while still maintaining hardware efficiency and numerical stability is an interesting area for future research. Despite the efficient kernels we implemented, we believe even faster implementations of our idea are possible, e.g., via sketching (see Appendix H for preliminary results). Finally, while we have shown promising results in combining GKA with Attention layers into Hybrid models (Appendix I), further scaling beyond 3B-parameter models is required to validate GKA on more challenging real-world problems.

Appendix A Related Work

Since the introduction of Self-Attention [Vaswani-NeurIPS2017], significant research has been conducted to reduce its quadratic cost in processing long input sequences. As models and systems scale to million-token contexts, Attention’s bottlenecks have become critical blockers to frontier agentic applications in coding, information gathering, and scientific discovery [chen2024scienceagentbench, cui2025curie, jimenez2023swe]. Prior works have proposed various approximation schemes to overcome these limitations. For example, Reformer [kitaev2020reformer] uses locality-sensitive hashing to group tokens with similar embeddings. This enables the model to attend only to a subset of tokens rather than the entire sequence. Other works equip Transformer models with “compressed” memory tokens that are updated dynamically and causally over sliding windows on entire sequence chunks [dai2019transformerxl, munkhdalai2024leave, mohtashami2023landmark]. While much prior work has focused on reducing the quadratic complexity of Attention with sparse approximations [nunez2024expansion, yuan2025NSA], this work focuses on linear approximations of Attention.

A.1 Linear Attention

Linear Attention methods approximate the Attention mechanism with constant-size recurrent dynamical systems [Dao-ICML2024-mamba2, Yang-ICML2024-gla, beck2024xlstm, Yang-ICLR2025]. Numerous State-Space Model (SSM) variations have been proposed, ranging from those closely resembling Linear Attention [sun2023retentive] or Linear Time-Invariant dynamical systems [gu2021combining, zancato2022stacked], to those introducing novel adaptive or gated state updates [Yang-ICML2024-gla, Dao-ICML2024-mamba2, orvieto2023resurrecting].

Despite their differences, all SSMs follow the same basic working principle inspired by classical state-space models [Kalman-1960]: they process the input sequence by maintaining a fixed-size state that acts as a compressed (lossy) representation of all processed tokens. Moreover, when implemented in hardware, the state must have finite precision and “fades away the past" as more samples are processed. Successful SSM layers typically employ hardware-aware implementations that efficiently utilize modern matrix multiplication accelerators through highly parallelizable and scalable primitives, including associative scans [gu2023mamba, de2024griffin], chunking mechanisms [Dao-ICML2024-mamba2, Yang-ICML2024-gla], and techniques that avoid materializing the entire state in slow high-bandwidth memory [gu2023mamba].

From a modeling perspective, most Linear Attention implementations introduce data-dependent gating factors to control the speed of their “fading” memory, balancing expressivity with scalability. For example, the transition from Mamba to Mamba2 replaced channel-wise data-dependent gating with head-wise gating for better scalability and Tensor Cores utilization. Input-dependent Gating has been shown to empirically improve training stability [arora2023zoology, Yang-ICLR2025] and has driven the development of Linear Attention models (e.g., from S4 [alber_gu_s4] to Mamba [gu2023mamba] and from DeltaNet [Yang-DeltaNet] to Gated DeltaNet [Yang-ICLR2025]). In our work, we demonstrate that gating emerges naturally as a consequence of solving a weighted least squares objective function, establishing a connection to the favorable numerical properties classically described in the adaptive filtering literature [LJUNG_RLS_stability, sayed2003fundamentals, sayed2011adaptive].

A.2 Hybrid State Space Attention Models

While extending the recurrent state in SSM layers has yielded performant models, they typically underperform on tasks requiring recall of information from the distant past [waleffe2024empirical, jelassi2024repeat]. Hybrid State-Space Models address this limitation by complementing SSMs’ “fading" state with Attention layers [dao2024transformers, de2024griffin, lieber2024jamba, glorioso2024zamba]. Early architectures simply stacked SSMs and Attention layers with different blending ratios [waleffe2024empirical, gu2023mamba, Dao-ICML2024-mamba2] or replaced full Attention layers with Sliding Window Attention [de2024griffin]. More sophisticated designs have recently emerged [glorioso2024zamba, Zancato-NeurIPS2024-BMOJO].

Notably, B’MOJO [Zancato-NeurIPS2024-BMOJO] complements SSMs’ fading state with "eidetic" memory by combining SSMs with Sliding Window Attention (SWA) in a single layer. Within the window, tokens can attend to a selected set of past tokens that were deemed difficult to predict using an asynchronous causal selection mechanism. B’MOJO was the first hybrid model to propose a parallel fusion of SSM and SWA at the layer level. Subsequent works [dong2024hymba, bae2025hybrid] have shown this parallel fusion approach to be more performant (at equivalent compute) than the stacked approach of earlier works.

Thanks to their lower memory footprint and test-time scalability over long sequences, Hybrid architectures are expanding into long-range agentic tasks and have recently been trained with Reinforcement Learning at scale [chen2025minimax]. When coupled with system-level optimizations like prefix caching [pan2024marconi] and specialized inference engines [kwon2023efficient], Hybrid models can increase the number of rollouts (exploration), thereby improving end-to-end performance in Reinforcement Learning loops.

Appendix B Forward and Backward Passes of Chebyshev Iteration (Details)

In Section 4.2 we described our chunk-wise implementation of the CH method with adaptive regularization and gating. We now give the full details omitted there.

B.1 Forward Pass

CH in Detail. We begin by describing the CH method (Algorithm 1) in more detail. Assume we have a linear system of equations $H\xi=q$, where $H$ is a $D\times D$ positive definite matrix. We assume that all eigenvalues of $H$ lie in the interval $[\mu,L]$ and that the values of $\mu$ and $L$ are known. Note that solving this system is equivalent to solving the following quadratic problem:

$$\min_{\xi\in\mathbb{R}^{D}}\ \frac{1}{2}\xi^{\top}H\xi-\xi^{\top}q. \tag{21}$$

The classic Chebyshev Iteration in its standard form is presented in Algorithm 1. In the initialization phase, we set $\rho=\frac{L-\mu}{L+\mu}$, which is the typical convergence rate of gradient descent applied to the above quadratic problem with stepsize $\frac{2}{L+\mu}$; roughly speaking, this stepsize choice is optimal in this setting (i.e., it lets gradient descent converge as fast as possible). Algorithm 1 initializes two points, $\xi_{-1}$ and $\xi_{0}$. Here $\xi_{-1}$ is zero, and $\xi_{0}$ is a gradient step for Eq. (21) starting at $\xi_{-1}$ with stepsize $\frac{2}{L+\mu}$. The final component of the initialization is the weight $\omega_{0}=2$, which is the starting point for the weight schedule recursion of $\omega_{i}$ in Eq. (weight schedule). Similarly, $\xi_{-1},\xi_{0}$ are the starting points for computing $\xi_{i}$, whose update consists of Eq. (grad descent) and Eq. (momentum). Note that Eq. (grad descent) uses stepsize $2\cdot\omega_{i}/(L+\mu)$. Since $\omega_{i}>1$, this stepsize is strictly larger than $2/(L+\mu)$, the latter being the optimal stepsize for vanilla gradient descent. Such a large stepsize alone might not guarantee convergence, but it is balanced by the Eq. (momentum) term $\xi_{i-1}-\xi_{i-2}$ with positive weight $\omega_{i}-1$, so that the convergence of the Chebyshev iterative method is ensured.
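As a minimal sketch of the recursion described above (not the chunk-wise Triton implementation of the paper), the following pure-Python code runs the iteration on a small dense system. The helper names `matvec` and `chebyshev_solve`, the 2×2 matrix, and the bounds `mu` and `L` are illustrative assumptions; the update combines the gradient step with stepsize $2\omega_i/(L+\mu)$ and the momentum term with weight $\omega_i-1$.

```python
def matvec(H, v):
    """Dense matrix-vector product for a list-of-lists matrix."""
    return [sum(h * x for h, x in zip(row, v)) for row in H]


def chebyshev_solve(H, q, mu, L, r):
    """Approximate the solution of H xi = q with r Chebyshev steps.

    Assumes H is symmetric positive definite with eigenvalues in [mu, L].
    """
    rho = (L - mu) / (L + mu)
    step = 2.0 / (L + mu)
    xi_prev = [0.0] * len(q)            # xi_{-1} = 0
    xi = [step * qi for qi in q]        # xi_0: gradient step from xi_{-1}
    omega = 2.0                         # omega_0
    for _ in range(r):
        omega = 4.0 / (4.0 - rho ** 2 * omega)                # weight schedule
        grad = [hx - qi for hx, qi in zip(matvec(H, xi), q)]  # H xi - q
        xi_next = [
            x - step * omega * g + (omega - 1.0) * (x - xp)   # gradient step + momentum
            for x, xp, g in zip(xi, xi_prev, grad)
        ]
        xi_prev, xi = xi, xi_next
    return xi


# Illustrative 2x2 system; its eigenvalues (about 0.43 and 1.07) lie in [0.4, 1.1].
H = [[1.0, 0.2], [0.2, 0.5]]
q = [1.0, -1.0]
xi = chebyshev_solve(H, q, mu=0.4, L=1.1, r=30)
residual = max(abs(hx - qi) for hx, qi in zip(matvec(H, xi), q))
```

In this toy example the bounds `mu` and `L` are chosen by hand to bracket the spectrum; the iteration relies on such bounds being valid.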

Numerical Stability Considerations. We now analyze the numerical properties of the Chebyshev Iteration. The main computation consists of matrix-vector multiplications; in a batched parallel implementation, these become matrix-matrix multiplications, whose numerical accuracy is well controlled (e.g., in Triton we can specify the accuracy in tl.dot). The update of $\omega_{i}$ in Eq. (weight schedule) might raise numerical concerns as it involves division. That said, we show that this division operates in a numerically well-behaved range, as $\omega_{i}$ is decreasing in $i$ yet lower bounded by $1$:

Lemma 3.

For any $r$, we have $2=\omega_{0}\geq\cdots\geq\omega_{r}\geq\omega^{*}_{1}>1$, where $\omega^{*}_{1}$ is defined as

$$\omega^{*}_{1}:=\frac{2(1-\sqrt{1-\rho^{2}})}{\rho^{2}}.$$

As a consequence, we have $4-\rho^{2}\omega_{i}\in[2,4]$ for all $i=0,\dots,r$.

Proof.

If $L=\mu$, then $H$ is a scaled identity matrix and the algorithm simplifies considerably, so we assume $L>\mu$ in what follows. With $L>\mu>0$ we have $\rho\in(0,1)$. Since $\omega_{0}=2$, we have $4-\rho^{2}\omega_{0}\geq 2$ and therefore $0<\omega_{1}\leq 2$. Repeating this argument, we see $\omega_{i}\in(0,2]$ for all $i$. By the definition of $\omega_{i}$, showing $\omega_{i}\leq\omega_{i-1}$ amounts to showing

$$\frac{4}{4-\rho^{2}\omega_{i-1}}\leq\omega_{i-1}\ \Leftrightarrow\ g(\omega_{i-1})\leq 0,$$

where $g$ is defined as $g(\omega)=\rho^{2}\omega^{2}-4\omega+4$. Note that $g(\omega)$ has two roots, $\omega_{1}^{*}$, as defined earlier, and $\omega_{2}^{*}=\frac{2(1+\sqrt{1-\rho^{2}})}{\rho^{2}}$; $\omega_{1}^{*},\omega_{2}^{*}$ are the two fixed points of the update Eq. (weight schedule). Observe that $\omega_{0}=2$ lies in the interval $(\omega_{1}^{*},\omega_{2}^{*})$ and, moreover, for any $i\geq 1$, if $\omega_{i-1}>\omega_{1}^{*}$ we must have

$$\omega_{i}=\frac{4}{4-\rho^{2}\omega_{i-1}}>\frac{4}{4-\rho^{2}\omega_{1}^{*}}=\omega_{1}^{*}.$$

This proves $\omega_{i}>\omega_{1}^{*}$ for all $i=1,\dots,r$. Next, since $\omega_{0}=2$ lies in the interval $(\omega_{1}^{*},\omega_{2}^{*})$, on which $g(\omega)\leq 0$, we have $\omega_{1}\leq\omega_{0}$. Thus $\omega_{1}$ lies in $(\omega_{1}^{*},\omega_{2}^{*})$ again. We can then conclude inductively that $\omega_{1}^{*}<\omega_{i}\leq\omega_{i-1}$ for all $i=1,\dots,r$. ∎

From Lemma 3 we know that the update of $\omega_{i}$ in Eq. (weight schedule) does not create much numerical concern in a forward pass, as we have $\omega_{i}\in[1,2]$ for all $i$. Furthermore, we can bound the rate at which $\omega_{i}$ converges to $\omega_{1}^{*}$:

Lemma 4.

Define $\kappa:=\frac{L}{\mu}$. For any $i=1,\dots,r$, we have

$$(\omega_{i}-\omega_{1}^{*})\leq R^{i}\cdot(\omega_{0}-\omega_{1}^{*}),$$

where $R$ is defined as

$$R:=\frac{\kappa-1}{\kappa+1}\cdot\frac{\sqrt{\kappa}-1}{\sqrt{\kappa}+1}.$$

Proof.

From the update rule of $\omega_{i}$ in Eq. (weight schedule) and the fixed point property of $\omega_{1}^{*}$, we have

\begin{align}
\omega_{i}-\omega_{1}^{*} &=\frac{4}{4-\rho^{2}\omega_{i-1}}-\frac{4}{4-\rho^{2}\omega_{1}^{*}}\\
&=\frac{4\rho^{2}}{(4-\rho^{2}\omega_{i-1})(4-\rho^{2}\omega_{1}^{*})}\cdot(\omega_{i-1}-\omega_{1}^{*})\\
&\overset{\text{(i)}}{=}\frac{\rho^{2}\omega_{1}^{*}}{4-\rho^{2}\omega_{i-1}}\cdot(\omega_{i-1}-\omega_{1}^{*})\\
&\overset{\text{(ii)}}{\leq}\frac{\rho^{2}\omega_{i-1}\omega_{1}^{*}}{4}\cdot(\omega_{i-1}-\omega_{1}^{*})\\
&\overset{\text{(iii)}}{\leq}\left(1-\sqrt{1-\rho^{2}}\right)\cdot(\omega_{i-1}-\omega_{1}^{*})\\
&\overset{\text{(iv)}}{=}\left(\frac{\kappa-1}{\kappa+1}\cdot\frac{\sqrt{\kappa}-1}{\sqrt{\kappa}+1}\right)\cdot(\omega_{i-1}-\omega_{1}^{*}).
\end{align}

Here, (i) follows from the fact that $\omega_{1}^{*}$ is a fixed point, (ii) follows from the identity $\frac{1}{4-\rho^{2}\omega_{i-1}}=\frac{\omega_{i}}{4}$ together with $\omega_{i}\leq\omega_{i-1}$ from Lemma 3, (iii) follows from the definition of $\omega_{1}^{*}$ and the fact $\omega_{i-1}\leq 2$, and (iv) follows from the definitions of $\kappa$ and $\rho$. The proof is concluded by unrolling the above recurrence. ∎

Remark 1.

Here, we call $R$ the linear convergence rate (or contraction factor) of $\omega_{i}$ to $\omega_{1}^{*}$. First-order methods for solving $H\xi=q$ converge at most at a rate $R_{a}:=\frac{\sqrt{\kappa}-1}{\sqrt{\kappa}+1}$, and we see that $\omega_{i}$ converges at an even faster rate. Numerically, assuming $\kappa=\frac{L}{\mu}=\frac{1.02}{0.02}=51$, we have:

\begin{align}
R &\approx 0.7253,\quad R^{5}\approx 0.2,\quad R^{10}\approx 0.04,\quad R^{20}\approx 0.0016,\quad R^{30}\approx 6\times 10^{-5},\\
R_{a} &\approx 0.7543,\quad R_{a}^{5}\approx 0.244,\quad R_{a}^{10}\approx 0.0597,\quad R_{a}^{20}\approx 0.0036,\quad R_{a}^{30}\approx 0.0002.
\end{align}

Thus, with $\kappa=51$, the update of $\omega_{i}$ in Eq. (weight schedule) converges in at most 20 iterations up to bfloat16 precision.
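The statements of Lemma 3, Lemma 4, and Remark 1 can be checked numerically with a short script. This is only an illustrative sketch: the function name `omega_schedule` is hypothetical, and $\rho$ is computed from $\kappa$ via $\rho=\frac{\kappa-1}{\kappa+1}$, which follows from $\rho=\frac{L-\mu}{L+\mu}$ and $\kappa=\frac{L}{\mu}$.

```python
import math


def omega_schedule(kappa, r):
    """Return [omega_0, ..., omega_r] and the fixed point omega_1^*."""
    rho = (kappa - 1.0) / (kappa + 1.0)      # rho = (L - mu) / (L + mu)
    omegas = [2.0]                           # omega_0 = 2
    for _ in range(r):
        omegas.append(4.0 / (4.0 - rho ** 2 * omegas[-1]))
    omega_star = 2.0 * (1.0 - math.sqrt(1.0 - rho ** 2)) / rho ** 2
    return omegas, omega_star


kappa = 51.0
omegas, omega_star = omega_schedule(kappa, 30)
sqrt_kappa = math.sqrt(kappa)
R_a = (sqrt_kappa - 1.0) / (sqrt_kappa + 1.0)
R = (kappa - 1.0) / (kappa + 1.0) * R_a

# Compare the actual gap omega_i - omega_1^* with the bound of Lemma 4.
gap0 = omegas[0] - omega_star
for i in (5, 10, 20, 30):
    print(i, omegas[i] - omega_star, R ** i * gap0)
```

The printed gaps should sit below the corresponding bounds $R^{i}(\omega_{0}-\omega_{1}^{*})$, and the sequence should decrease monotonically while staying above $1$.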

B.2 Backward Pass

We now give details for backpropagation through the Chebyshev Iteration (Algorithm 1) via implicit differentiation.

Computing $\frac{dL}{dq_{t}}$ and $\frac{dL}{dk_{t}}$. First, we follow Table 1 and Lemma 2, and compute $dq_{t}$. Then, given the equation $(H_{t}+\lambda_{t}I)dq_{t}=dx_{t}^{*}$, we have that

$$d(H_{t}+\lambda_{t}I)=-dq_{t}(x_{t}^{*})^{\top}. \tag{22}$$

Therefore $d\lambda_{t}=\text{tr}(d(H_{t}+\lambda_{t}I))=-(x_{t}^{*})^{\top}dq_{t}$. Since we set $\lambda_{t}=a\cdot\|H_{t}\|_{\text{F}}$, this indicates

$$dH_{t}=-dq_{t}(x_{t}^{*})^{\top}-a\cdot\frac{H_{t}}{\|H_{t}\|_{\text{F}}}\cdot\left((x_{t}^{*})^{\top}dq_{t}\right). \tag{23}$$

Note that this expression for $dH_{t}$ is partial: it accounts only for the upstream gradient through $dq_{t}$, whereas all subsequent states also depend on $H_{t}$. We will accumulate these gradients later when needed.

Now, the recursion of $H_{t}$ in Eq. (CH) implies

\begin{align}
dk_{i} &=\sum_{t\geq i}\left(dH_{t}+(dH_{t})^{\top}\right)k_{i}\cdot\frac{\zeta_{t}}{\zeta_{i}} \tag{24}\\
&=\sum_{t\geq i}\frac{\zeta_{t}}{\zeta_{i}}\left(-dq_{t}\otimes(x_{t}^{*})^{\top}-x_{t}^{*}\otimes(dq_{t})^{\top}+w_{t}H_{t}\right)k_{i}, \tag{25}
\end{align}

which proves Lemma 1. We refer the reader to Section B.2.1 for more detailed derivations of $dq_{t}$ and $dk_{t}$.

Derivatives for Gating. In practice we often parameterize $\gamma_{t}$ in log space to ensure numerical stability, so let us first adapt our notation to this case. Let $g_{t}=\log\gamma_{t}$ and $G_{t}:=\sum_{i=1}^{t}g_{i}=\log\left(\prod_{i=1}^{t}\gamma_{i}\right)$. Then the mask matrix $M$ is

$$M_{i,j}=\begin{cases}\exp(G_{j}-G_{i})&j\geq i;\\ 0&\text{otherwise}.\end{cases} \tag{26}$$

Now, since for any $c=1,\dots,C$ we have

$$H_{c}=\exp(G_{c})\cdot H_{0}+\sum_{j=1}^{c}\exp(G_{c}-G_{j})\cdot k_{j}k_{j}^{\top}, \tag{27}$$

for any $G_{i}$ we have the following basic derivatives:

\begin{align}
\frac{dH_{c}}{dG_{i}} &=\begin{cases}0&c<i;\\ \exp(G_{i})\cdot H_{0}+\sum_{j=1}^{i-1}\exp(G_{i}-G_{j})\cdot k_{j}k_{j}^{\top}&c=i;\\ -\exp(G_{c}-G_{i})k_{i}k_{i}^{\top}&c>i,\end{cases}\qquad i=1,\dots,C,\quad c=1,\dots,C; \tag{28}\\
\frac{dH_{C+1}}{dG_{i}} &=-\exp(G_{C+1}-G_{i})k_{i}k_{i}^{\top},\qquad i=1,\dots,C. \tag{29}
\end{align}

With $dH_{C+1}$ being the aggregated gradient from the future, we have for $i=1,\dots,C$ that

\begin{align}
dG_{i} &=\sum_{c=i}^{C+1}\left\langle dH_{c},\frac{dH_{c}}{dG_{i}}\right\rangle \tag{30}\\
&=e^{G_{i}}\langle dH_{i},H_{0}\rangle+\sum_{j=1}^{i-1}e^{G_{i}-G_{j}}\langle dH_{i},k_{j}k_{j}^{\top}\rangle-\sum_{c=i+1}^{C}e^{G_{c}-G_{i}}\langle dH_{c},k_{i}k_{i}^{\top}\rangle-e^{G_{C+1}-G_{i}}\langle dH_{C+1},k_{i}k_{i}^{\top}\rangle \tag{31}\\
&=e^{G_{i}}\langle dH_{i},H_{0}\rangle+\sum_{j=1}^{i}e^{G_{i}-G_{j}}\langle dH_{i},k_{j}k_{j}^{\top}\rangle-\sum_{c=i}^{C}e^{G_{c}-G_{i}}\langle dH_{c},k_{i}k_{i}^{\top}\rangle-e^{G_{C+1}-G_{i}}\langle dH_{C+1},k_{i}k_{i}^{\top}\rangle, \tag{32}\\
dG_{C+1} &=e^{G_{C+1}}\langle dH_{C+1},H_{0}\rangle+\sum_{j=1}^{C}e^{G_{C+1}-G_{j}}\langle dH_{C+1},k_{j}k_{j}^{\top}\rangle. \tag{33}
\end{align}

Note that in passing from Eq. (31) to Eq. (32) we add and subtract the term $\langle dH_{i},k_{i}k_{i}^{\top}\rangle$, which simplifies the implementation.
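To make the log-space parameterization concrete, the sketch below compares the gated recurrence $H_{c}=\gamma_{c}H_{c-1}+k_{c}k_{c}^{\top}$ with the closed form in Eq. (27) built from cumulative log-gates $G_{c}$. For readability it assumes scalar keys ($D=1$), so that $k_{c}k_{c}^{\top}$ reduces to $k_{c}^{2}$; the function names and the numerical values are hypothetical.

```python
import math


def states_by_recurrence(H0, gammas, ks):
    """H_c = gamma_c * H_{c-1} + k_c^2 for scalar keys (D = 1)."""
    H, out = H0, []
    for gamma, k in zip(gammas, ks):
        H = gamma * H + k * k
        out.append(H)
    return out


def states_by_log_gates(H0, gammas, ks):
    """Closed form of Eq. (27) using cumulative log-gates G_c."""
    G, total = [], 0.0
    for gamma in gammas:
        total += math.log(gamma)         # g_t = log(gamma_t)
        G.append(total)                  # G_t = g_1 + ... + g_t
    out = []
    for c in range(len(ks)):
        H = math.exp(G[c]) * H0
        H += sum(math.exp(G[c] - G[j]) * ks[j] ** 2 for j in range(c + 1))
        out.append(H)
    return out


# Illustrative values (not from the paper).
H0, gammas, ks = 0.5, [0.9, 0.8, 0.95, 0.7], [1.0, -0.5, 0.3, 2.0]
rec = states_by_recurrence(H0, gammas, ks)
closed = states_by_log_gates(H0, gammas, ks)
```

Both routes should agree up to rounding; the closed form only ever exponentiates differences $G_{c}-G_{j}$, which is what the mask in Eq. (26) stores.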

Recall that $dH_{t}=-dq_{t}(x_{t}^{*})^{\top}-\frac{1}{2}\cdot w_{t}H_{t}$ with $w_{t}=\frac{2a\cdot(x_{t}^{*})^{\top}dq_{t}}{\|H_{t}\|_{\text{F}}}$. In computing the derivatives with respect to $G_{i}$, the first term $dq_{t}(x_{t}^{*})^{\top}$ is the standard term that also arises for Eq. (Linear-SSM), which we omit here. We now focus on the second term $\frac{1}{2}\cdot w_{t}H_{t}$. This implies the gradients $dG_{i}$ and $dG_{C+1}$ are partly given respectively by (using the notations in Section 4.2.2 and omitting some algebraic operations)

$$\frac{1}{2}\cdot\langle w_{i}H_{i},H_{i}\rangle-\frac{1}{2}\cdot k_{i}^{\top}B_{i}k_{i}\quad\text{and}\quad\frac{1}{2}\cdot e^{G_{C}}\langle\widetilde{B}_{C+1},H_{0}\rangle+\frac{1}{2}\cdot\sum_{j=1}^{C}e^{G_{C}-G_{j}}\cdot k_{j}^{\top}\widetilde{B}_{C+1}k_{j}. \tag{34}$$

Computing the first term $\langle w_{i}H_{i},H_{i}\rangle$ in parallel is easy by invoking the definition of $w_{i}$ and the Frobenius norm of $H_{i}$ that we stored during the forward pass. Computing the quadratic terms $k_{i}^{\top}B_{i}k_{i}$ and $k_{j}^{\top}\widetilde{B}_{C+1}k_{j}$ in parallel is easy and follows from our computation of $B_{i}k_{i}$ and $\widetilde{B}_{C+1}k_{i}$ for $dk_{i}$ in Section 4.2.2. Computing $\langle\widetilde{B}_{C+1},H_{0}\rangle$ is easy since we recompute the initial states $H_{0}$ of each chunk and have them available during the backward pass, while $\widetilde{B}_{C+1}$ is updated backwards in a for loop.

B.2.1 Computing $\frac{dL}{dq_{t}}$ and $\frac{dL}{dk_{t}}$

In the forward pass we solve $(H_{t}+\lambda_{t}I)x_{t}=q_{t}$, that is,

\begin{align}
x_{t} &=(H_{t}+\lambda_{t}I)^{-1}q_{t} \tag{35}\\
\implies dx_{t} &=\underbrace{(H_{t}+\lambda_{t}I)^{-1}}_{J_{q_{t}\to x_{t}}}dq_{t}.
\end{align}

Recall that the gradient is obtained by applying the transpose of the Jacobian to the upstream gradient; since $H_{t}+\lambda_{t}I$ is symmetric, we obtain

$$\frac{dL}{dq_{t}}=(H_{t}+\lambda_{t}I)^{-1}\frac{dL}{dx_{t}}. \tag{36}$$

Thus, we can obtain $\frac{dL}{dq_{t}}$ by running a Chebyshev iteration to solve (for $z$) the linear system of equations

$$(H_{t}+\lambda_{t}I)z=\frac{dL}{dx_{t}}.$$
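The identity in Eq. (36) can be sanity-checked on a toy problem. The sketch below is an assumption-laden illustration: it replaces the Chebyshev solver by a direct 2×2 solve (`solve2`, a hypothetical helper), uses a made-up matrix and a linear toy loss, and compares the implicit gradient against central finite differences. The value $a=0.02$ mirrors the experimental setup; everything else is illustrative.

```python
import math


def solve2(A, b):
    """Solve a 2x2 linear system A z = b by Cramer's rule."""
    det = A[0][0] * A[1][1] - A[0][1] * A[1][0]
    return [
        (b[0] * A[1][1] - A[0][1] * b[1]) / det,
        (A[0][0] * b[1] - A[1][0] * b[0]) / det,
    ]


# Illustrative data (not from the paper).
H = [[1.0, 0.2], [0.2, 0.5]]
a = 0.02
lam = a * math.sqrt(sum(h * h for row in H for h in row))   # a * ||H||_F
A = [[H[0][0] + lam, H[0][1]], [H[1][0], H[1][1] + lam]]

q = [1.0, -1.0]
dL_dx = [0.3, 0.7]          # upstream gradient of the toy loss L = <dL_dx, x>


def loss(q_vec):
    x = solve2(A, q_vec)    # forward pass: x = (H + lam I)^{-1} q
    return sum(g * xi for g, xi in zip(dL_dx, x))


# Implicit differentiation: dL/dq solves (H + lam I) z = dL/dx.
dL_dq = solve2(A, dL_dx)

# Central finite differences as an independent check.
eps = 1e-6
fd = []
for j in range(2):
    q_plus, q_minus = list(q), list(q)
    q_plus[j] += eps
    q_minus[j] -= eps
    fd.append((loss(q_plus) - loss(q_minus)) / (2 * eps))
```

Note that $\lambda_{t}$ depends only on $H_{t}$, so it is a constant with respect to $q_{t}$ in this check.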

Now we have

\begin{align}
dx_{t} &=d\big[(H_{t}+\lambda_{t}I)^{-1}\big]q_{t} \tag{37}\\
&=-(H_{t}+\lambda_{t}I)^{-1}d(H_{t}+\lambda_{t}I)(H_{t}+\lambda_{t}I)^{-1}q_{t}\\
&=-(H_{t}+\lambda_{t}I)^{-1}d(H_{t}+\lambda_{t}I)x_{t}\\
&=\big(x_{t}^{\top}\otimes-(H_{t}+\lambda_{t}I)^{-1}\big)\textrm{vec}(d(H_{t}+\lambda_{t}I)).
\end{align}

In the last equality we have used the identity $\textrm{vec}(ABC)=(C^{\top}\otimes A)\textrm{vec}(B)$.

Now we compute the Jacobian of $\lambda_{t}$ with respect to $H_{t}$:

\begin{align}
\lambda_{t} &=a\|H_{t}\|_{F} \tag{38}\\
&=a\sqrt{\textrm{Tr}(H_{t}^{\top}H_{t})}\\
\implies d\lambda_{t} &=a\,d\Big(\sqrt{\textrm{Tr}(H_{t}^{\top}H_{t})}\Big)\\
&=a\frac{1}{2\|H_{t}\|_{F}}\textrm{Tr}\big((dH_{t})^{\top}H_{t}+H_{t}^{\top}(dH_{t})\big)\\
&=a\frac{1}{\|H_{t}\|_{F}}\textrm{vec}(H_{t})^{\top}d\textrm{vec}(H_{t}).
\end{align}

Substituting (38) into (37) gives

$$dx_{t}=\big(x_{t}^{\top}\otimes-(H_{t}+\lambda_{t}I)^{-1}\big)\Big(\textrm{vec}(dH_{t})+\frac{a}{\|H_{t}\|_{F}}\textrm{vec}(I)\textrm{vec}(H_{t})^{\top}\textrm{vec}(dH_{t})\Big). \tag{39}$$

Thus, we can obtain $\textrm{vec}(\frac{dL}{dH_{t}})$ as

$$\textrm{vec}\Big(\frac{dL}{dH_{t}}\Big)=\big(x_{t}\otimes-(H_{t}+\lambda_{t}I)^{-1}\big)\frac{dL}{dx_{t}}+\frac{a}{\|H_{t}\|_{F}}\textrm{vec}(H_{t})\textrm{vec}(I)^{\top}\big(x_{t}\otimes-(H_{t}+\lambda_{t}I)^{-1}\big)\frac{dL}{dx_{t}}. \tag{40}$$

Substituting from (36),

\begin{align}
\textrm{vec}\Big(\frac{dL}{dH_{t}}\Big) &=-\Big(x_{t}\otimes\frac{dL}{dq_{t}}\Big)-\frac{a}{\|H_{t}\|_{F}}\textrm{vec}(H_{t})\textrm{vec}(I)^{\top}\Big(x_{t}\otimes\frac{dL}{dq_{t}}\Big) \tag{41}\\
&=-\Big(x_{t}\otimes\frac{dL}{dq_{t}}\Big)-\frac{a}{\|H_{t}\|_{F}}\textrm{vec}(H_{t})\Big\langle\frac{dL}{dq_{t}},x_{t}\Big\rangle.
\end{align}

Now, with gating, we have $H_{t}=\gamma_{t}H_{t-1}+k_{t}k_{t}^{\top}$, which can be unrolled as

$$H_{l}=\sum_{i=0}^{l}\Big(\prod_{k=i}^{l-1}\gamma_{k}\Big)k_{i}k_{i}^{\top}. \tag{42}$$

We will compute $\frac{dL}{dk_{l}}$ for some $l\leq t$:

$$\frac{dL}{dk_{l}}=\sum_{t\geq l}\frac{d\textrm{vec}(H_{t})}{dk_{l}}\textrm{vec}\Big(\frac{dL}{dH_{t}}\Big). \tag{43}$$

To compute $\frac{d\textrm{vec}(H_{t})}{dk_{l}}$ for some $t\geq l$, we write

$$H_{t}=\Big(\prod_{i=l}^{t-1}\gamma_{i}\Big)k_{l}k_{l}^{\top}+\textrm{terms independent of }k_{l}. \tag{44}$$

Taking differentials on both sides,

\begin{align}
dH_{t} &=\prod_{i=l}^{t-1}\gamma_{i}\Big[dk_{l}k_{l}^{\top}+k_{l}(dk_{l})^{\top}\Big] \tag{45}\\
d(\textrm{vec}(H_{t})) &=\prod_{i=l}^{t-1}\gamma_{i}\Big[\textrm{vec}(dk_{l}k_{l}^{\top})+\textrm{vec}(k_{l}(dk_{l})^{\top})\Big]\\
&=\underbrace{\Big(\prod_{i=l}^{t-1}\gamma_{i}\Big)\Big[(k_{l}\otimes I)+(I\otimes k_{l})\Big]}_{J_{k_{l}\to\textrm{vec}(H)}}dk_{l},
\end{align}

where in the last equality we used the identity $\textrm{vec}(k_{l}dk_{l}^{\top})=dk_{l}\otimes k_{l}=(I\otimes k_{l})dk_{l}$.

Substituting the Jacobian (transposed for gradients) from (45) into (43), we obtain

$$\frac{dL}{dk_{l}}=\sum_{t\geq l}\Big(\prod_{i=l}^{t-1}\gamma_{i}\Big)\Big[(k_{l}^{\top}\otimes I)+(I\otimes k_{l}^{\top})\Big]\textrm{vec}\Big(\frac{dL}{dH_{t}}\Big). \tag{46}$$

Substituting the expression for $\textrm{vec}(\frac{dL}{dH_{t}})$ from (41) into equation (46), we get:

\begin{align}
\frac{dL}{dk_{l}} &=-\sum_{t\geq l}\Big(\prod_{i=l}^{t-1}\gamma_{i}\Big)\Big[(k_{l}^{\top}\otimes I)+(I\otimes k_{l}^{\top})\Big]\Big[\Big(x_{t}\otimes\frac{dL}{dq_{t}}\Big)+\frac{a}{\|H_{t}\|_{F}}\textrm{vec}(H_{t})\Big\langle\frac{dL}{dq_{t}},x_{t}\Big\rangle\Big] \tag{47}\\
&=-\sum_{t\geq l}\Big(\prod_{i=l}^{t-1}\gamma_{i}\Big)\Big[(k_{l}^{\top}\otimes I)\Big(x_{t}\otimes\frac{dL}{dq_{t}}\Big)+\frac{a}{\|H_{t}\|_{F}}(k_{l}^{\top}\otimes I)\textrm{vec}(H_{t})\Big\langle\frac{dL}{dq_{t}},x_{t}\Big\rangle\\
&\qquad\qquad+(I\otimes k_{l}^{\top})\Big(x_{t}\otimes\frac{dL}{dq_{t}}\Big)+\frac{a}{\|H_{t}\|_{F}}(I\otimes k_{l}^{\top})\textrm{vec}(H_{t})\Big\langle\frac{dL}{dq_{t}},x_{t}\Big\rangle\Big].
\end{align}

Note that the following equations hold:

\begin{align}
(k_{l}^{\top}\otimes I)\Big(x_{t}\otimes\frac{dL}{dq_{t}}\Big) &=\Big(k_{l}^{\top}x_{t}\otimes\frac{dL}{dq_{t}}\Big)=\langle k_{l},x_{t}\rangle\frac{dL}{dq_{t}}, \tag{48}\\
(I\otimes k_{l}^{\top})\Big(x_{t}\otimes\frac{dL}{dq_{t}}\Big) &=x_{t}\otimes k_{l}^{\top}\frac{dL}{dq_{t}}=\Big\langle k_{l},\frac{dL}{dq_{t}}\Big\rangle x_{t},
\end{align}

since $(A\otimes B)(C\otimes D)=AC\otimes BD$ and each Kronecker product reduces, after simplification, to a scalar times a vector.

For the other terms it holds that

\begin{align}
(k_{l}^{\top}\otimes I)\textrm{vec}(H_{t}) &=\textrm{vec}(H_{t}k_{l}), \tag{49}\\
(I\otimes k_{l}^{\top})\textrm{vec}(H_{t}) &=\textrm{vec}(k_{l}^{\top}H_{t})=\textrm{vec}(H_{t}^{\top}k_{l}),
\end{align}

where we used the fact $\textrm{vec}(AXB)=(B^{\top}\otimes A)\textrm{vec}(X)$ and the fact that the vec operator applied to a row vector returns the same result as applying it to its transpose (so we go from $k_{l}^{\top}H_{t}$ to $H_{t}^{\top}k_{l}$). Since $H_{t}$ is symmetric, the two contributions are equal and sum to twice that amount.

Eventually we get:

$$\boxed{\frac{dL}{dk_{l}}=-\sum_{t\geq l}\left(\prod_{i=l}^{t-1}\gamma_{i}\right)\Bigg[\langle k_{l},x_{t}\rangle\frac{dL}{dq_{t}}+\Big\langle k_{l},\frac{dL}{dq_{t}}\Big\rangle x_{t}+\frac{2a}{\|H_{t}\|_{F}}\Big\langle x_{t},\frac{dL}{dq_{t}}\Big\rangle H_{t}k_{l}\Bigg]} \tag{50}$$

Or equivalently, collecting the terms that are linear in the gradient:

$$\boxed{\frac{dL}{dk_{l}}=-\sum_{t\geq l}\left(\prod_{i=l}^{t-1}\gamma_{i}\right)\left[\langle k_{l},x_{t}\rangle\frac{dL}{dq_{t}}+\Big\langle k_{l},\frac{dL}{dq_{t}}\Big\rangle x_{t}+\frac{2a\langle x_{t},\frac{dL}{dq_{t}}\rangle}{\|H_{t}\|_{F}}H_{t}k_{l}\right]} \tag{51}$$

Note: The last term creates a dependence on $k_{l}$ through $H_{t}k_{l}$, which is expected since the regularization $\lambda_{t}$ couples the gradient computation.

Appendix C Proof of Lemma 2

Lemma 2 describes an interesting phenomenon where, for the CH method (Algorithm 1), the gradient $d\hat{q}$ obtained from implicit differentiation coincides with the exact gradient $dq$ obtained via backpropagation (chain rule). One way to prove this result is to derive an analytic expression for $dq$ (Section C.1) and then inspect the recursions. However, this can be algebraically involved. Here, we present a clear proof based on some simple observations.

First, note that the output $\xi_{r}$ is linear in $q$ and, moreover, there is a matrix function $p_{r}(H)\in\mathbb{R}^{D\times D}$ such that

$$\xi_{r}=p_{r}(H)\cdot q. \tag{52}$$

Here $p_{r}(H)$ is a polynomial function of $H$ that encodes the Chebyshev iteration (Algorithm 1). Conversely, $p_{r}(H)\cdot q$ can be computed by applying the Chebyshev iteration with $H,q$ for $r$ iterations (together with other parameters such as $\mu,L$). Then, given the output gradient $d\xi_{r}$, we have

$$dq=p_{r}(H)^{\top}\cdot d\xi_{r}=p_{r}(H)\cdot d\xi_{r}, \tag{53}$$

where the last equality follows since $H$ is symmetric, which implies that $p_{r}(H)$ is symmetric. The proof is finished by observing that $p_{r}(H)\cdot d\xi_{r}$ can be computed via Algorithm 1 with inputs $H$ and $q=d\xi_{r}$ (and the other parameters unchanged), which gives us $dq$.
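The argument above can be illustrated numerically: even when the iteration is stopped early (so $\xi_{r}$ is not yet the exact solution), one more Chebyshev run on the upstream gradient reproduces the gradient of the truncated computation. The sketch below assumes a toy linear loss $\langle d\xi_{r},\xi_{r}\rangle$, a hand-picked 2×2 matrix, and hypothetical helper names; it compares the result against central finite differences.

```python
def matvec(H, v):
    """Dense matrix-vector product for a list-of-lists matrix."""
    return [sum(h * x for h, x in zip(row, v)) for row in H]


def chebyshev(H, q, mu, L, r):
    """r Chebyshev steps for H xi = q; the output is p_r(H) q."""
    rho, step = (L - mu) / (L + mu), 2.0 / (L + mu)
    xi_prev, xi, omega = [0.0] * len(q), [step * qi for qi in q], 2.0
    for _ in range(r):
        omega = 4.0 / (4.0 - rho ** 2 * omega)
        grad = [hx - qi for hx, qi in zip(matvec(H, xi), q)]
        xi_next = [
            x - step * omega * g + (omega - 1.0) * (x - xp)
            for x, xp, g in zip(xi, xi_prev, grad)
        ]
        xi_prev, xi = xi, xi_next
    return xi


# Illustrative data (not from the paper); r is deliberately small.
H = [[1.0, 0.2], [0.2, 0.5]]
mu, L, r = 0.4, 1.1, 3
q = [1.0, -1.0]
d_xi = [0.3, 0.7]            # upstream gradient with respect to xi_r


def loss(q_vec):
    return sum(g * x for g, x in zip(d_xi, chebyshev(H, q_vec, mu, L, r)))


# Lemma 2: dq = p_r(H) d_xi, i.e. the same iteration applied to d_xi.
dq = chebyshev(H, d_xi, mu, L, r)

# Central finite differences through the truncated forward pass.
eps = 1e-6
fd = []
for j in range(len(q)):
    q_plus, q_minus = list(q), list(q)
    q_plus[j] += eps
    q_minus[j] -= eps
    fd.append((loss(q_plus) - loss(q_minus)) / (2 * eps))
```

The agreement holds for any $r$ because the forward map is linear in $q$ with a symmetric matrix $p_{r}(H)$.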

C.1 The Exact Backward Pass for $dq$ and $dH$

Here we show how to obtain the exact gradients $dH$ and $dq$ in Algorithm 1 given the output gradient $d\xi_{r}$, which might be of independent interest. The key insight here is that the Chebyshev iteration can be reversed.

Backward Pass for dq. Let $I$ be the identity matrix of suitable size. To derive a backward pass of Algorithm 1, we first write down the update of $\xi_{i}$ concisely in the following recursion

$$\xi_{i}=A_{i}\xi_{i-1}+b_{i}\xi_{i-2}+c_{i}q, \tag{54}$$

where $A_{i},b_{i},c_{i}$ are defined as

$$A_{i}=\omega_{i}I-\frac{2\cdot\omega_{i}}{L+\mu}H,\quad b_{i}=-(\omega_{i}-1),\quad c_{i}=\frac{2\cdot\omega_{i}}{L+\mu}. \tag{55}$$

Note that $A_{i}$ is symmetric. Define $d\xi_{i}:=\frac{d\mathcal{L}}{d\xi_{i}}$ for every $i$. With some loss function $\mathcal{L}$, assume we are now given $d\xi_{r}$, and our goal is to compute $dq:=\frac{d\mathcal{L}}{dq}$. Since $q$ appears in Eq. (54) for every $i$, we know $\xi_{0},\xi_{1},\dots,\xi_{r}$ all depend on $q$. Therefore, with $c_{0}:=\frac{2}{L+\mu}$, we have

$$dq=\sum_{i=0}^{r}c_{i}\cdot d\xi_{i}.$$

It remains to compute $d\xi_{i}$ for every $i$. Applying the chain rule to Eq. (54), we obtain

\begin{align}
d\xi_{r-1} &=A_{r}\cdot d\xi_{r},\\
d\xi_{i-2} &=A_{i-1}\cdot d\xi_{i-1}+b_{i}\cdot d\xi_{i},\quad\forall i=r,\dots,2. \tag{56}
\end{align}

Note that $A_{i},b_{i},c_{i}$ depend on some constant terms and $\omega_{i}$. Thus, to compute them backward we assume access to $\omega_{r}$ and these constants. By reversing Eq. (weight schedule) we derive the following recursion:

\begin{align}
\nu_{r} &\leftarrow\omega_{r},\\
\nu_{i-1} &\leftarrow\frac{4}{\rho^{2}}\left(1-\frac{1}{\nu_{i}}\right),\quad\forall i=r,\dots,1. \tag{57}
\end{align}

Similarly to how $\omega_{i}$ decreases with $i$ and converges to $\omega_{1}^{*}$, one can prove that $\nu_{i}$ converges to the other fixed point, $\omega_{2}^{*}$, as $i$ decreases (if the iteration is continued past $i=1$).

Backward Pass for dA. From Eq. (54) and Eq. (55) we see that

$$dA_{i}=d\xi_{i}\otimes\xi_{i-1}^{\top},\quad dH=-\frac{2\cdot\omega_{i}}{L+\mu}\cdot dA_{i}, \tag{58}$$

where $\otimes$ denotes the Kronecker product; here it is the outer product of $d\xi_{i}$ and $\xi_{i-1}$, as $d\xi_{i}$ and $\xi_{i-1}$ are vectors.

Reverse Chebyshev Iteration. At first glance, computing $dA_{i}=d\xi_{i}\otimes\xi_{i-1}^{\top}$ requires storing $\xi_{i-1}$ in the forward pass, and the actual calculation of $dA_{i}$ is done after we run the backward pass for $d\xi_{i}$ in Eq. (56). However, storing all $\xi_{i}$'s would be memory-inefficient. To address this issue, a main insight here is that we can reverse Eq. (54) and write

$$\xi_{i-2}=\frac{1}{b_{i}}(\xi_{i}-A_{i}\xi_{i-1}-c_{i}q). \tag{59}$$

This implies that we can recover all the iterates $\xi_{r},\dots,\xi_{0}$ as soon as we have access to the last two, $\xi_{r},\xi_{r-1}$. Therefore, to obtain $dA_{i}$, we can run the two iteration schemes in Eq. (56) and Eq. (59) simultaneously.
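The reversal can be sketched in a few lines. The code below is an illustration under stated assumptions, not the paper's kernel: it stores all forward iterates only to verify the reconstruction, solves Eq. (54) for $\xi_{i-2}$, and regenerates the weights backward with the reversed weight schedule of Eq. (57). The function names, the 2×2 matrix, and the bounds `mu`, `L` are hypothetical.

```python
def matvec(H, v):
    """Dense matrix-vector product for a list-of-lists matrix."""
    return [sum(h * x for h, x in zip(row, v)) for row in H]


def forward_all(H, q, mu, L, r):
    """Run Chebyshev and keep every iterate (kept only to check the reversal)."""
    rho, step = (L - mu) / (L + mu), 2.0 / (L + mu)
    xis = [[0.0] * len(q), [step * qi for qi in q]]    # xi_{-1}, xi_0
    omega = 2.0
    for _ in range(r):
        omega = 4.0 / (4.0 - rho ** 2 * omega)
        Hx = matvec(H, xis[-1])
        # xi_i = A_i xi_{i-1} + b_i xi_{i-2} + c_i q
        xis.append([
            omega * x - step * omega * hx - (omega - 1.0) * xp + step * omega * qi
            for x, hx, xp, qi in zip(xis[-1], Hx, xis[-2], q)
        ])
    return xis, omega


def reverse_all(H, q, mu, L, r, xi_r, xi_rm1, omega_r):
    """Recover xi_{r-2}, ..., xi_{-1} from the last two iterates only."""
    rho, step = (L - mu) / (L + mu), 2.0 / (L + mu)
    cur, prev, nu = xi_r, xi_rm1, omega_r
    out = [cur, prev]
    for _ in range(r):                                  # i = r, ..., 1
        Hx = matvec(H, prev)
        b = -(nu - 1.0)
        # Solve Eq. (54) for xi_{i-2}.
        older = [
            (c - (nu * p - step * nu * hx) - step * nu * qi) / b
            for c, p, hx, qi in zip(cur, prev, Hx, q)
        ]
        out.append(older)
        cur, prev = prev, older
        nu = 4.0 / rho ** 2 * (1.0 - 1.0 / nu)          # reversed weight schedule
    return out[::-1]                                    # xi_{-1}, ..., xi_r


# Illustrative system with eigenvalues (about 0.15 and 0.75) inside [0.1, 1.0].
H = [[0.5, 0.3], [0.3, 0.4]]
q = [1.0, -1.0]
mu, L, r = 0.1, 1.0, 8
xis, omega_r = forward_all(H, q, mu, L, r)
recovered = reverse_all(H, q, mu, L, r, xis[-1], xis[-2], omega_r)
max_err = max(abs(u - v) for a, b in zip(xis, recovered) for u, v in zip(a, b))
```

In double precision the reconstruction error in this toy case stays small; the sketch does not by itself demonstrate behavior in bfloat16.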

Remark 2.

We find that being able to run the iterative update backward in a numerically stable fashion is a main feature of the Chebyshev iterative method (or, more generally, of gradient descent variants with momentum). Vanilla gradient descent cannot efficiently reverse its iterate $\xi_{i}=\xi_{i-1}-\gamma_{i}(H\xi_{i-1}-q)$ with stepsize $\gamma_{i}$, as this requires inverting $(I-\gamma_{i}H)$. Moreover, the reversal in Eq. (59) can be done stably, as $b_{i}$ is typically in a good numerical range, which means that division by $b_{i}$ in Eq. (59) is not an issue. To see this, first note that by Lemma 3 we have

$$1\geq-b_{i}=\omega_{i}-1\geq\omega_{1}^{*}-1.$$

Note that $\omega_{1}^{*}$ defined in Lemma 3 is an increasing function of $\rho$ and therefore of $\kappa$. We then have that $-b_{i}\in[0.25,1]$ for any $\kappa\geq 10$ (we do not consider the case $\kappa<10$, as this would require a very large regularization strength, which might harm the minimization of the regression loss). In comparison, if we were to reverse the CG iteration, we would need to divide by a quantity that is often numerically as small as $10^{-3}$ or as large as $10^{10}$ (see Fig. 5). This is why it is numerically unstable to reverse CG.
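The range of the divisor can be tabulated directly from the definition of $\omega_{1}^{*}$ in Lemma 3. This is a small illustrative script (the function name `min_neg_b` and the chosen values of $\kappa$ are assumptions), using $\rho=\frac{\kappa-1}{\kappa+1}$:

```python
import math


def min_neg_b(kappa):
    """Lower bound omega_1^* - 1 on -b_i from Lemma 3, as a function of kappa."""
    rho = (kappa - 1.0) / (kappa + 1.0)
    omega_star = 2.0 * (1.0 - math.sqrt(1.0 - rho ** 2)) / rho ** 2
    return omega_star - 1.0


kappas = [10.0, 51.0, 100.0, 1000.0]
bounds = [min_neg_b(k) for k in kappas]
for k, lo in zip(kappas, bounds):
    print(f"kappa={k:7.1f}  -b_i in [{lo:.4f}, 1]")
```

The lower bound grows with $\kappa$, so $\kappa=10$ is the worst case among the values considered in the remark.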

Figure 5: (a) The theoretical lower and upper bounds for the values of the divisor $b_{i}$ that arise in reversing Chebyshev Eq. (59); (b) The empirical lower and upper bounds for the divisor that arises in reversing CG.
Input: $H$, $d\xi_{r}$, $L$, $\mu$, number of iterations $r$, the final weight $\omega_{r}$;
Initialize $\rho\leftarrow\frac{L-\mu}{L+\mu}$, $d\xi_{r+1}\leftarrow 0$, $\nu_{r}\leftarrow\omega_{r}$, $\nu_{r+1}\leftarrow 0$, $\nu_{0}\leftarrow 1$, $dq\leftarrow\frac{2\nu_{r}}{L+\mu}\cdot d\xi_{r}$, $dH\leftarrow-\frac{2\nu_{r}}{L+\mu}d\xi_{r}\otimes\xi_{r-1}^{\top}$;
For $i=r,\dots,1$:
\begin{align}
\xi_{i-2} &\leftarrow-\frac{1}{\nu_{i}-1}\left(\xi_{i}-\left(\nu_{i}I-\frac{2\cdot\nu_{i}}{L+\mu}H\right)\xi_{i-1}-\frac{2\cdot\nu_{i}}{L+\mu}q\right) \tag{60}\\
d\xi_{i-1} &\leftarrow\left(\nu_{i}\cdot d\xi_{i}-\frac{2\cdot\nu_{i}}{L+\mu}\cdot H\cdot d\xi_{i}\right)-(\nu_{i+1}-1)\cdot d\xi_{i+1} \tag{61}\\
\nu_{i-1} &\leftarrow\frac{4}{\rho^{2}}\left(1-\frac{1}{\nu_{i}}\right) \tag{62}\\
dq &\leftarrow dq+\frac{2\nu_{i-1}}{L+\mu}\cdot d\xi_{i-1} \tag{63}\\
dH &\leftarrow dH-\frac{2\nu_{i-1}}{L+\mu}\cdot d\xi_{i-1}\otimes\xi_{i-2}^{\top} \tag{64}
\end{align}
Output: $dq$, $dH$;
Algorithm 2 Backward Pass of Chebyshev Iteration

Appendix D Experimental Setup

Model Configurations. We consider models of 3 different sizes: 440M, 1B, and 2.8B, as summarized in Table 3. All models use the GPT2 tokenizer, similarly to [Von-arXiv2025-mesa].

Table 3: Model sizes and the corresponding architectural configurations.
Model Size | Number of Layers | Number of Heads | Hidden Dimension
440M | 28 | 8 | 1024
1B | 28 | 12 | 1536
2.8B | 32 | 20 | 2560

Training Configurations. All models are trained with the AdamW optimizer with initial learning rate $10^{-3}$, $5\%$ warm-up steps, a cosine schedule, and gradient clipping with maximum norm $1$.

Table 4: Model sizes and the corresponding training configurations.
| Model Size | Global Batch Size | Total Number of Training Tokens | Sequence Length |
|---|---|---|---|
| 440M | 1M | 8B | 2048 |
| 1B | 2M | 20B | 2048 |
| 2.8B | 2M | 100B | 4096 |

Models of the same scale use the same training configurations. Specifically (see also Table 4):

  • For 440M models, we use sequence length 2048 and 8B DCLM tokens.

  • For 1B models, we use sequence length 2048 and 20B DCLM tokens.

  • For 2.8B models, we use sequence length 4096 and 100B DCLM tokens.
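To make the training configuration concrete, the sketch below derives the number of optimizer steps from the batch sizes and token budgets listed above and evaluates a learning-rate schedule. It assumes that "1M" and "B" denote $10^6$ and $10^9$ tokens, that the stated initial learning rate is the peak reached after a linear warm-up from zero, and that the cosine decays to zero. None of these details are stated above, and the names `CONFIGS` and `lr_at` are ours.

```python
import math

CONFIGS = {  # model size -> (global batch size in tokens, total training tokens)
    "440M": (1_000_000, 8_000_000_000),
    "1B": (2_000_000, 20_000_000_000),
    "2.8B": (2_000_000, 100_000_000_000),
}

def lr_at(step: int, total_steps: int, peak_lr: float = 1e-3, warmup_frac: float = 0.05) -> float:
    """Linear warm-up to peak_lr, then cosine decay to zero (assumed shape)."""
    warmup_steps = max(1, int(warmup_frac * total_steps))
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * progress))

for name, (batch_tokens, total_tokens) in CONFIGS.items():
    steps = total_tokens // batch_tokens
    print(name, steps, lr_at(0, steps), lr_at(steps // 2, steps), lr_at(steps - 1, steps))
```

Under these assumptions the three scales correspond to 8,000, 10,000, and 50,000 optimizer steps, respectively.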

Model Hyperparameters. We use the default parameters for all other models as given in the Flash-Linear-Attention v0.4.0 library (except the ones mentioned in Table 3). For our approach, we use $\lambda_t = 0.02 \cdot \|H_t\|_{\text{F}}$, with gating and $\alpha$-connection enabled by default, unless otherwise specified. We also run CH for $30$ iterations in all experiments.

Individual Experiments. We now describe the setups for each individual experiment.

In Fig. 1a, we randomly generate tensors $k \in \mathbb{R}^{B \times T \times H \times D}$ and $q$ and normalize them along the last dimension ($D$). Here $B, T, H, D$ simulate the batch size, sequence length, number of heads, and head dimension, respectively. Then we compute the covariance matrices $H \in \mathbb{R}^{B \times T \times H \times D \times D}$ of $k$ and normalize every $D \times D$ slice by its Frobenius norm. The code to generate the data is shown below.

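    import torch

    # B, T, H, D, dtype, and ridge_strength are assumed to be set by the caller.
    k = torch.randn(B, T, H, D).to(dtype).to('cuda')
    q = torch.randn(B, T, H, D).to(dtype).to('cuda')

    # Normalize queries and keys along the head dimension D.
    q = q / torch.linalg.vector_norm(q, dim=-1, keepdim=True).to(q)
    k = k / torch.linalg.vector_norm(k, dim=-1, keepdim=True).to(k)

    # Running covariance of the keys along the sequence axis, normalized per D x D slice.
    kk = torch.einsum('...i,...j->...ij', k, k).cumsum(1)
    kk = kk / torch.linalg.matrix_norm(kk, ord='fro')[..., None, None]

    # Add the ridge term to the diagonal in place.
    kk.diagonal(dim1=-2, dim2=-1).add_(ridge_strength)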

For Fig. 1c and Fig. 1e, we generate random input ids with vocabulary size 5000, sequence length 2048 within a 5-layer LLAMA; we set 2 heads and head dimension 128 for this architecture.

In the MQAR experiments of Fig. 3a, we follow the standard experimental setting but consider a strictly harder setting with smaller model dimension (or hidden dimension). Indeed, in the setting of [arora2023zoology], the model dimension is always larger than or equal to the number of KV pairs, while in the setting here, in some cases the model dimension is smaller than the number of KV pairs, in which case linear SSMs could not perfectly memorize all KV pairs.

In the main paper, the results in Fig. 3a are obtained without any gating or $\alpha$-connection.

In Fig. 4 we considered the following tasks for long-context evaluations. The reported result for each task is the average of the scores obtained on the individual datasets in that task.

  • Retrieval-Augmented Generation (RAG): These tasks consist of open-domain question answering, where the model is given a gold passage (a passage containing the answer) interspersed between many other passages retrieved from a corpus [petroni2021kilt, Wikipedia dump split into 100-word passages]. The model is tasked with answering the question based on the obtained passages. We consider the following datasets from HELMET [yen2025helmet] for this task: Natural Questions, TriviaQA, PopQA, HotpotQA.

  • Many-shot In-Context Learning (ICL): ICL tests an LLM's ability to learn new skills from a few examples. Here the task is to learn to classify between different concepts based on several in-context examples of said concepts. We consider the following datasets from HELMET [yen2025helmet] for this task: TREC Coarse, TREC Fine, NLU, BANKING77, CLINIC150.

  • Synthetic Recall: These tasks are variations of the "Needle-in-a-Haystack" task [needle2024haystack], where the goal is to retrieve an important piece of information (the "needle") from a long context of distractor tokens (the "haystack"). These variations also test the multi-hop tracing and aggregation capabilities of the model. We consider the following datasets from RULER [hsieh2024ruler] for this task: S-NIAH-1/2/3, MK-NIAH-1/2/3, MV-NIAH, MQ-NIAH, VT, CWE, FWE.

  • LongQA: These are long document based question-answering tasks. The documents are typically made long by randomly sampling different paragraphs from the same dataset along with the paragraph that contains the answer. We consider the following datasets from RULER [hsieh2024ruler] for this task: SQuAD, HotpotQA.

Appendix E How Does the Performance of GKA Scale with Compute?

We consider models at three different scales: 440M, 1B and 2.8B. For training configurations and architecture, refer to Appendix D. We use prototypical tasks from LM-Harness (see Section 5.3.1 for the list of tasks) to evaluate the language modeling capabilities of GKA and compare with baseline SSM/fading memory layers. Table 5 shows that at 440M scale, GKA is competitive with GDN and DeltaNet. However, differences emerge at larger scales, with GKA showing increasing benefits. In particular, the retrieval capabilities of our model, as measured by FDA and SWDE, consistently outperform all SSM baselines at 1B and 2.8B scale. We also report the results of an equal-sized Transformer for completeness, which serves as a performance ceiling at each scale.

Table 5: GKA shows stronger scaling with compute than other SSM baseline models. LM-Harness results for models at different scales: 440M, 1B and 2.8B. All models were trained from scratch. 440M and 1B models were trained on 8B and 20B tokens respectively, in accordance with the Chinchilla scaling laws [hoffmann2022empirical]. For the 2.8B model we trained on 100B tokens.
| Model | ARC-C (acc_n ↑) | ARC-E (acc_n ↑) | BoolQ (acc ↑) | COPA (acc ↑) | HellaSWAG (acc_n ↑) | PIQA (acc_n ↑) | SciQ (acc_n ↑) | Winogrande (acc ↑) | FDA (contains ↑) | SWDE (contains ↑) | Avg |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 440M Models | | | | | | | | | | | |
| Transformer | 24.40 | 42.26 | 59.88 | 70.00 | 36.19 | 64.15 | 61.50 | 51.70 | 5.17 | 35.64 | 45.09 |
| Gated Linear Attention | 24.06 | 40.28 | 56.57 | 71.00 | 32.70 | 62.24 | 57.80 | 50.67 | 1.00 | 9.18 | 40.55 |
| Gated DeltaNet | 25.17 | 41.96 | 58.23 | 72.00 | 36.96 | 64.69 | 63.6 | 51.7 | 1.91 | 11.88 | 42.81 |
| DeltaNet | 25.09 | 41.92 | 61.13 | 65.00 | 37.20 | 64.47 | 64.00 | 49.49 | 2.81 | 14.31 | 42.54 |
| Gated KalmaNet (Ours) | 24.57 | 43.22 | 56.94 | 71.00 | 37.22 | 64.47 | 62.8 | 50.83 | 1.45 | 14.04 | 42.65 |
| 1B Models | | | | | | | | | | | |
| Transformer | 26.62 | 46.42 | 59.94 | 77.00 | 44.01 | 67.14 | 68.30 | 54.06 | 8.35 | 45.18 | 49.70 |
| Mamba2 | 28.07 | 46.63 | 60.21 | 70.00 | 44.57 | 67.57 | 65.50 | 54.30 | 1.45 | 15.75 | 45.40 |
| Gated Linear Attention | 25.94 | 42.00 | 58.84 | 70.00 | 36.34 | 63.60 | 58.20 | 51.85 | 1.45 | 10.53 | 41.88 |
| Gated DeltaNet | 27.05 | 47.98 | 59.54 | 74.00 | 44.27 | 67.36 | 66.2 | 53.83 | 2.18 | 17.82 | 46.02 |
| DeltaNet | 27.56 | 46.25 | 59.97 | 71.00 | 43.18 | 67.74 | 65.90 | 55.41 | 3.09 | 20.61 | 46.07 |
| Gated KalmaNet (Ours) | 25.43 | 46.55 | 60.73 | 74.00 | 44.59 | 68.88 | 67.60 | 52.41 | 6.17 | 21.87 | 46.82 |
| 2.8B Models | | | | | | | | | | | |
| Transformer | 32.25 | 56.10 | 64.28 | 80.00 | 60.96 | 73.56 | 79.50 | 61.72 | 58.53 | 72.28 | 63.92 |
| Mamba2 | 32.24 | 59.64 | 58.72 | 82.00 | 62.23 | 73.78 | 79.80 | 62.19 | 7.71 | 41.13 | 55.94 |
| Gated Linear Attention | 27.82 | 50.80 | 52.57 | 78.00 | 48.83 | 70.13 | 69.60 | 54.54 | 2.81 | 20.43 | 47.55 |
| Gated DeltaNet | 32.59 | 60.02 | 62.75 | 82.00 | 62.8 | 74.32 | 80.6 | 62.35 | 8.26 | 44.28 | 57.00 |
| DeltaNet | 32.85 | 58.16 | 42.51 | 81.00 | 61.13 | 73.78 | 43.90 | 61.72 | 11.80 | 46.08 | 51.29 |
| Gated KalmaNet (Ours) | 32.51 | 59.89 | 61.68 | 85.00 | 63.84 | 74.81 | 83.2 | 64.17 | 12.89 | 50.95 | 58.89 |

Appendix F Ablations

In this section we consider ablations for various modeling choices made in arriving at our final GKA model. For all ablations, we consider 2.8B models trained on 100B tokens from DCLM at 4K context length (unless mentioned otherwise). We use the same architecture and training configurations for these ablations as mentioned in Appendix D.

F.1 Does Adaptive Regularization Help?

As discussed in Section 4.1, we introduced adaptive regularization to control the condition number of $H_t + \lambda_t I$ for numerical stability. Here we ablate this choice; specifically, we compare the following runs.

  1. Adaptive regularization. We train a model with $\lambda_t = a\|H_t\|_F$. We report results for $a = 0.02$ for this run.

  2. Constant regularization. We train the same model architecture (as above) with $\lambda_t = 0.25$ (a constant). This choice of $0.25$ is motivated by concurrent work [Von-arXiv2025-mesa], which explored a similar ridge regression objective for LLM training.

As shown in Fig. 6, without strict condition number control, gradient norms spike during training, leading to increased cross-entropy loss (compared to the run with adaptive regularization).
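The effect of the two choices on the condition number can be illustrated on synthetic data. The sketch below is not the training code: it builds a rank-deficient covariance $H = \sum_i k_i k_i^\top$ from unit-norm keys confined to a low-dimensional subspace (our assumption, chosen so that the smallest eigenvalue of $H$ is zero) and compares $\lambda = a\|H\|_F$ against the constant $\lambda = 0.25$. Since $H$ is positive semidefinite with $\lambda_{\max}(H) \leq \|H\|_F$, the adaptive choice keeps the condition number of $H + \lambda I$ at or below $(1 + a)/a$ regardless of the sequence length. The helper name `ridge_condition_number` is ours.

```python
import numpy as np

def ridge_condition_number(H: np.ndarray, lam: float) -> float:
    """2-norm condition number of H + lam * I."""
    return float(np.linalg.cond(H + lam * np.eye(H.shape[0])))

rng = np.random.default_rng(0)
D, rank, a = 16, 4, 0.02
basis = np.linalg.qr(rng.standard_normal((D, rank)))[0]     # orthonormal basis of a 4-dim subspace

for T in (64, 1024, 16384):
    k = rng.standard_normal((T, rank)) @ basis.T             # keys inside the subspace
    k /= np.linalg.norm(k, axis=-1, keepdims=True)           # unit-norm keys
    H = k.T @ k                                              # covariance after T steps
    adaptive = ridge_condition_number(H, a * np.linalg.norm(H, "fro"))
    constant = ridge_condition_number(H, 0.25)
    print(f"T={T:6d}  adaptive={adaptive:8.1f}  constant={constant:10.1f}  bound={(1 + a) / a:.1f}")
```

In this toy setting the condition number under constant regularization grows with $T$, while the adaptive one stays below the bound of $51$ for $a = 0.02$.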

Figure 6: Adaptive regularization results in smoother and better training curves. (a) Plots the training curve for 2.8B models on 100B tokens from DCLM. (b) Plots the corresponding gradient norm. The model with constant regularization (red curve) results in a higher loss that can be attributed to its non-smooth trajectory over the course of its training run (spiky gradient norms).

F.2 Does Adaptive Weighting Help?

In Section 4.1, we discussed increasing the expressivity of our layer by introducing adaptive weights $\eta_{t,i}$ which re-weigh the past to be exponentially decaying in time. Given constant-sized memory, we hypothesize this adaptive weighting (gating) allows GKA to learn an effective representation by incorporating recency bias into its computation. In this subsection we test this hypothesis. We carry out the following runs.

  1. Adaptive weighting (gating). We train a model with adaptive weights. Specifically, for all $t \geq i$, we parameterize the weight for the $i^{\text{th}}$ sample at time-step $t$ as $\eta_{t,i} = \prod_{j=i+1}^{t} \gamma_j$, with each $\gamma_j \in [0, 1]$ learnable.

  2. No weighting. We train the same model architecture as above, but with no weights. This essentially results in an unweighted ridge regression objective obtained by setting $\eta_i = 1$ in Eq. 3.

Table 6 shows clear benefits of adaptive weighting, with improvements across the board on all LM-Harness tasks considered, thereby validating our hypothesis.
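The cumulative-product parameterization means that the weighted statistics never need the individual weights $\eta_{t,i}$ explicitly. The sketch below assumes that the weighted covariance takes the form $H_t = \sum_{i \leq t} \eta_{t,i} k_i k_i^\top$ (Eq. 3 is not reproduced in this appendix, so this form is our assumption) and checks that it coincides with the constant-memory recursion $H_t = \gamma_t H_{t-1} + k_t k_t^\top$. The random gates stand in for learned ones, and the helper `eta` is ours.

```python
import numpy as np

rng = np.random.default_rng(0)
T, D = 6, 3
k = rng.standard_normal((T, D))
gamma = rng.uniform(0.0, 1.0, size=T)          # gamma[j] in [0, 1]; gamma[0] only multiplies zero

def eta(t: int, i: int) -> float:
    """eta_{t,i} = prod_{j=i+1}^{t} gamma_j (empty product = 1 when i == t)."""
    return float(np.prod(gamma[i + 1 : t + 1]))

# Explicit weighted sum at the last time step (0-based indices).
t = T - 1
H_explicit = sum(eta(t, i) * np.outer(k[i], k[i]) for i in range(t + 1))

# Equivalent recursion that only keeps a single D x D state.
H_recursive = np.zeros((D, D))
for s in range(T):
    H_recursive = gamma[s] * H_recursive + np.outer(k[s], k[s])

print(np.allclose(H_explicit, H_recursive))
```

Setting every $\gamma_j = 1$ in this sketch recovers the unweighted sum used in the "no weighting" run.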

Table 6: Adaptive weighting outperforms across the board on LM-Harness tasks. Results for 2.8B models trained on 100B tokens from DCLM with and without adaptive weights as introduced in Section 4.1.
| Adaptive Weights | ARC-C (acc_n ↑) | ARC-E (acc_n ↑) | BoolQ (acc ↑) | COPA (acc ↑) | HellaSWAG (acc_n ↑) | PIQA (acc_n ↑) | SciQ (acc_n ↑) | Winogrande (acc ↑) | FDA (contains ↑) | SWDE (contains ↑) | Avg |
|---|---|---|---|---|---|---|---|---|---|---|---|
| No | 28.24 | 51.73 | 57.68 | 76 | 53.87 | 71.87 | 71.6 | 54.38 | 6.08 | 33.03 | 50.45 |
| Yes | 32.51 | 59.89 | 61.68 | 85 | 63.84 | 74.81 | 83.2 | 64.17 | 12.89 | 50.95 | 58.89 |

F.3 Does $\alpha$-connection Improve Training of GKA?

In Section 4.3, we introduce the $\alpha$-connection as a residual connection that establishes a direct path for gradient flow through the GLA solution, improving training stability. This allows the model to fall back on the GLA solution when CH produces poor-quality results due to non-convergence of the iterative solver within the fixed iteration budget. To validate this design choice, we perform two runs.

  R1. With $\alpha$-connection. We train a model with the $\alpha$-connection as shown in our GKA block in Fig. 2.

  R2. Without $\alpha$-connection. We train the same model architecture as above, but with no $\alpha$-connection. This can be simply understood as setting $\alpha_t = 1$ for all time-steps $t$ in Fig. 2.

On LM-Harness, both models perform similarly, with R1 and R2 achieving aggregate scores of 58.89 and 58.39, respectively. However, clear differences emerge under long-context evaluation, where we trained both models on an additional 25B tokens from long documents at 128K context length. Fig. 7 shows that GKA without the $\alpha$-connection exhibits inferior long-context performance on average, with Synthetic Recall and LongQA showing major degradation.

Figure 7: GKA without the $\alpha$-connection severely underperforms on Synthetic Recall and LongQA. On ICL, all SSMs struggle to perform better than random chance (see Fig. 4). Interestingly, although R2 exhibits poorer long-context abilities in aggregate, it outperforms R1 on RAG by a few points.

Appendix G Effects of Different Regularization Strengths

Recall that we proposed setting adaptive regularization $\lambda_t = a \cdot \|H_t\|_{\text{F}}$. We now present experiments validating this choice.

Synthetic Experiments. First, we generate data as per Fig. 1a, where the covariance matrix is normalized by its Frobenius norm. In this case we set $\lambda_t = a$ for $a$ varying in $\{0.01, 0.02, 0.05, 0.1\}$. Fig. 8 shows that the maximum regularized residual norm (computed as the maximum of $\|(H_t + \lambda_t I)\xi_i - q\|_2$ over all dimensions, where $\xi_i$ is the estimate of CH at iteration $i$) decreases as we enlarge $\lambda_t$. This is because a large $\lambda_t$ reduces the condition number. The downside of a large $\lambda_t$, though, is that it reduces the memorization capacity; namely, it might enlarge $\|H_t \xi_i - q\|_2$, the true residual of interest.

Figure 8: Convergence for varying regularization strengths (batch size $8$, sequence length 2048, 8 heads, and head dimension 128). Panels: (a) $\lambda_t = 0.01$; (b) $\lambda_t = 0.02$; (c) $\lambda_t = 0.05$; (d) $\lambda_t = 0.1$.
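The faster convergence for larger $\lambda_t$ is consistent with the classical worst-case bound for Chebyshev iteration. As a rough sketch: when $H_t$ has unit Frobenius norm and $\lambda_t = a$, the eigenvalues of $H_t + \lambda_t I$ lie in $[a, 1 + a]$, so the condition number is at most $\kappa = (1 + a)/a$, and the textbook bound on the error reduction after $r$ iterations is $2\sigma^r / (1 + \sigma^{2r})$ with $\sigma = (\sqrt{\kappa} - 1)/(\sqrt{\kappa} + 1)$. This bound is a standard result rather than something derived in this appendix, it holds in exact arithmetic only, and the helper name `chebyshev_error_factor` is ours.

```python
import math

def chebyshev_error_factor(a: float, r: int = 30) -> float:
    """Worst-case error reduction after r Chebyshev iterations for kappa = (1 + a) / a."""
    kappa = (1.0 + a) / a
    sigma = (math.sqrt(kappa) - 1.0) / (math.sqrt(kappa) + 1.0)
    return 2.0 * sigma**r / (1.0 + sigma ** (2 * r))

for a in (0.01, 0.02, 0.05, 0.1):
    print(f"a={a:<5} kappa<={(1 + a) / a:6.1f}  factor after 30 iterations={chebyshev_error_factor(a):.2e}")
```

Under these assumptions the bound tightens by several orders of magnitude between $a = 0.01$ and $a = 0.1$.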

GKA with different regularization strengths. We train several 2.8B models with varying regularization strength by choosing $a \in \{0.01, 0.02, 0.05, 0.1\}$. While performance on LM-Harness (Table 7) shows little discrepancy, we observe noticeable differences in long-context performance, where memorization capacity matters most (Fig. 9). Specifically, the long-context performance of GKA improves initially as we decrease $a$ from $0.1 \to 0.05$. This is expected since this increases the memorization capacity of the model. However, decreasing further from $0.05 \to 0.02 \to 0.01$ causes performance to decrease. This can be attributed to the increasing condition number of the problem, which reduces the quality of the solution computed by CH (Fig. 8).

Figure 9: Long-context performance of GKA for different regularization strengths. The long-context performance of GKA improves initially as we decrease $a$ from $0.1 \to 0.05$. This is expected since this increases the memorization capacity of the model. However, decreasing further from $0.05 \to 0.02 \to 0.01$ causes performance to decrease. This can be attributed to the increasing condition number of the problem, which reduces the quality of the solution computed by CH (Fig. 8).
Table 7: Ablation over different choices of regularization strength $\lambda_t = a \cdot \|H_t\|_{\text{F}}$. Short-context performance on LM-Harness shows little discrepancy with different regularization strengths.
| $a$ | ARC-C (acc_n ↑) | ARC-E (acc_n ↑) | BoolQ (acc ↑) | COPA (acc ↑) | HellaSWAG (acc_n ↑) | PIQA (acc_n ↑) | SciQ (acc_n ↑) | Winogrande (acc ↑) | FDA (contains ↑) | SWDE (contains ↑) | Avg |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0.01 | 33.45 | 58.63 | 62.63 | 85.00 | 63.36 | 73.99 | 81.40 | 63.14 | 11.16 | 51.49 | 58.43 |
| 0.02 | 32.51 | 59.89 | 61.68 | 85.00 | 63.84 | 74.81 | 83.20 | 64.17 | 12.89 | 50.95 | 58.89 |
| 0.05 | 32.68 | 61.66 | 53.57 | 79.00 | 63.46 | 74.84 | 82.60 | 63.77 | 11.98 | 49.68 | 57.32 |
| 0.1 | 32.76 | 59.85 | 63.52 | 84.00 | 63.95 | 75.08 | 83.20 | 63.54 | 11.43 | 51.22 | 58.86 |
Table 8: Latent sketching increases training throughput (by up to 10%) while marginally reducing accuracy (< 1%). Training throughput is reported in billions of tokens/day/node. It is measured on a single H200 GPU with a batch size of 1M tokens. Our results indicate minimal regression on LM-Harness tasks but up to 10% improvement in training throughput (going from no-sketch to sketch dimension 32). However, long-context performance is adversely affected by sketching, with up to a 60% relative drop in performance. Future work will address this by exploring the use of sketching adaptively depending on the "complexity" of the task.
| Sketch dimension | LM-Harness avg. | Training throughput |
|---|---|---|
| 32 | 57.57 | 8.37 |
| no-sketch | 58.89 | 7.65 |

Appendix H Latent Sketching for Approximate Solutions

We introduce the idea of sketching from random matrix theory to further control the trade-off between FLOPs and accuracy in GKA. Sketching involves down-projecting the normal equations into a low-dimensional subspace, solving the equations in this subspace, and finally up-projecting the solution back to the original space. This reduces the worst-case computational complexity of our approach from $\mathcal{O}(D^2 r)$ to $\mathcal{O}(d^2 r)$, where $d \ll D$ and $r$ is the number of iterations in Algorithm 1. To the best of our knowledge, our work is the first to introduce sketching as a viable way to increase the efficiency of neural network layers that are defined implicitly by the solution to an optimization problem. Sketching can be thought of as analogous to the Multi Latent Attention idea introduced by DeepSeek, but applied to fading memory layers. Table 8 shows preliminary results of this idea applied to GKA. Both models (no-sketch and sketch dim 32) are trained from scratch at 2.8B scale on 100B tokens.
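The three steps (down-project, solve, up-project) can be sketched as follows. This is an illustration under our own assumptions, not the paper's implementation: the sketch matrix $S \in \mathbb{R}^{d \times D}$ has orthonormal rows, a direct solver replaces the iterative one, and the names `sketched_ridge_solve`, `S`, and `lam` are hypothetical.

```python
import numpy as np

def sketched_ridge_solve(H, q, S, lam):
    """Approximate (H + lam I)^{-1} q using the d x D sketch S with orthonormal rows."""
    H_small = S @ H @ S.T                                   # d x d projected system matrix
    z = np.linalg.solve(H_small + lam * np.eye(S.shape[0]), S @ q)
    return S.T @ z                                          # lift the solution back to D dims

rng = np.random.default_rng(0)
D, d, lam = 64, 8, 0.1
S = np.linalg.qr(rng.standard_normal((D, d)))[0].T          # d x D, rows orthonormal

# Easy case: keys and query lie in the row space of S, so the sketch loses nothing.
k = rng.standard_normal((100, d)) @ S
H, q = k.T @ k, rng.standard_normal(d) @ S
exact = np.linalg.solve(H + lam * np.eye(D), q)
approx = sketched_ridge_solve(H, q, S, lam)
print(np.max(np.abs(exact - approx)))
```

In this easy case the two solutions agree up to round-off. For general inputs the result is only an approximation, while each iteration of a solver on the projected system costs on the order of $d^2$ instead of $D^2$ operations.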

Appendix I Hybrid Gated KalmaNet

As discussed in Section A.2, augmenting SSM models with Attention layers has proven to be an effective way of improving performance on tasks that require recalling information from the distant past. In this section, we show that our Gated KalmaNet layer can be interleaved with Attention layers to yield even stronger models. Our Hybrid GKA model is based on the Qwen3 architecture [yang2025qwen3]. Namely, our Hybrid model consists of a stack of "decoder" blocks, each of which contains a sequence mixer (either Attention or GKA) followed by an MLP. Similar to Qwen3, our Attention layers use QK normalization layers. Our Hybrid model consists of 30 decoder blocks, 26 of which use GKA as the sequence mixer, and 4 that use Attention. The Attention decoder blocks are at indices 6, 14, 22, and 29. Our Hybrid models follow the same training procedure as our non-Hybrid models. Specifically, we pretrain our Hybrid model on 100B tokens with a 4K context size, followed by fine-tuning on 25B tokens at a 128K context size.
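The block layout described above can be written down as a small configuration sketch. It assumes that the stated indices are 0-based (so that index 29 is the last of the 30 blocks); the variable names are ours.

```python
NUM_BLOCKS = 30
ATTENTION_INDICES = {6, 14, 22, 29}   # assumed 0-based positions of the Attention blocks

# One sequence-mixer type per decoder block; every block is followed by an MLP.
layout = ["attention" if i in ATTENTION_INDICES else "gka" for i in range(NUM_BLOCKS)]
print(layout.count("gka"), layout.count("attention"))
```

The resulting layout contains 26 GKA blocks and 4 Attention blocks, matching the counts given above.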

When evaluating our pretrained Hybrid model on standard NLP benchmarks, we observe that it improves substantially on recall-oriented tasks (FDA and SWDE) compared to the non-Hybrid model, as shown in Table 9. (Our non-Hybrid model shares the same architecture as the Hybrid, with the distinction that all 4 Attention layers are replaced with GKA layers.) Further, when evaluating our fine-tuned long-context model on tasks that require effective modeling of long-range dependencies, we observe a significant improvement across all context lengths, as shown in Fig. 10.

Table 9: Our Hybrid GKA + Attention model improves language modeling performance. When interleaving Attention layers into our GKA models, we observe a significant improvement on recall-oriented tasks, such as FDA and SWDE, while preserving a similar performance on short-context tasks.
| Model | ARC-C (acc_n ↑) | ARC-E (acc_n ↑) | BoolQ (acc ↑) | COPA (acc ↑) | HellaSWAG (acc_n ↑) | PIQA (acc_n ↑) | SciQ (acc_n ↑) | Winogrande (acc ↑) | FDA (contains ↑) | SWDE (contains ↑) | Avg |
|---|---|---|---|---|---|---|---|---|---|---|---|
| Gated KalmaNet (Hybrid) | 33.02 | 59.47 | 64.07 | 80.00 | 62.74 | 74.59 | 81.40 | 64.64 | 53.18 | 72.46 | 64.56 |
| Gated KalmaNet | 32.51 | 59.89 | 61.68 | 85.00 | 63.84 | 74.81 | 83.20 | 64.17 | 12.89 | 50.95 | 58.89 |
Figure 10: Our Hybrid GKA + Attention model significantly improves performance across all long-context benchmarks compared to our non-Hybrid model. Adding a few Attention layers to our GKA model improves long-range dependency modeling, improving performance across all sequence lengths on RAG, ICL, Synthetic Recall, and Long-QA.