Gated KalmaNet: A Fading Memory Layer Through Test-Time Ridge Regression
Abstract
As efficient alternatives to softmax Attention, linear state-space models (SSMs) achieve constant memory and linear compute, but maintain only a lossy, fading summary of the past, often leading to inferior performance in recall-oriented tasks. We propose Gated KalmaNet (GKA), a layer that reduces this gap by accounting for the full past when predicting the next token, while maintaining SSM-style efficiency. GKA achieves this by solving an online ridge regression problem at test time, with constant memory and linear compute cost in the sequence length. Drawing inspiration from the Kalman Filter, we solve the online ridge regression problem iteratively. However, a critical insight is that the standard Kalman filter equations are numerically unstable in low-precision environments (like bfloat16) and difficult to parallelize on modern hardware. We address both challenges through two key innovations: (1) an adaptive regularization strategy with input-dependent gating that controls the condition number of the ridge regression, ensuring numerical stability while balancing memory retention; and (2) the use of Chebyshev Iteration instead of other conventional iterative solvers, which we demonstrate to be more stable in low-precision settings. To further improve scalability, we develop a hardware-aware chunk-wise implementation of Chebyshev Iteration along with custom kernels for backpropagating through our adaptive regularization and gating mechanisms. Empirically, GKA shows strong language understanding capabilities on short-context tasks, outperforming existing SSM layers (like Mamba2, GLA and Gated DeltaNet). On long-context tasks, GKA excels at real-world RAG and LongQA tasks up to 128k tokens, achieving more than % relative improvement over other fading memory baselines.
1 Introduction
Large Language Models (LLMs) powered by (softmax) Attention mechanisms [Vaswani-NeurIPS2017] have revolutionized sequence modeling through their ability to form rich associations within their context window. However, a fundamental challenge that LLMs face is that their time complexity scales quadratically and storage grows linearly with their input length.
Recent years have seen intense efforts to develop Attention alternatives. Among them, memory layers based on linear State-Space Models (SSMs) have grown popular for their linear-time computation and constant storage cost in the sequence length [Dao-ICML2024-mamba2, Yang-ICML2024-gla]. These SSMs draw inspiration from classic techniques in adaptive signal processing, and integrating such techniques into modern SSMs leads to principled layer design and enhanced performance [Liu-ICLR2025, Yang-NeurIPS2024, Zancato-NeurIPS2024-BMOJO]. However, pure SSM models still underperform Attention in many settings, especially on long-context tasks. This gap is a consequence of their different memory mechanisms: SSMs have a fading, fixed-dimensional, lossy state of the past, while Attention has an eidetic, ever-growing KV-cache state [Zancato-NeurIPS2024-BMOJO].
To bridge this gap, we aim to design a memory layer that enjoys the efficiency of linear SSMs while performing computation conditioned on the exact past. Towards this goal, we first draw insights from the Kalman filter (KF) [Kalman-1960]. In signal processing terms, KF computes the most recent state conditioned on all data seen thus far, and, under mild assumptions, KF is optimal in the Maximum A-Posteriori (MAP) sense. In the LLM context, we use KF to update the state of an SSM layer and predict its output based on all past inputs. However, integrating KF into such a layer is non-trivial and faces two challenges:
- Parallelizable Training. KF is an online algorithm and needs to be parallelized to fully utilize modern hardware that is highly optimized for large-scale LLM training.
- Numerical Stability. KF involves matrix inversion, which can be numerically unstable in low precision arithmetic.
In this work, we propose Gated KalmaNet (GKA), a memory layer that incorporates KF into its design and is both numerically stable and trainable on highly parallelizable hardware. We start by observing that the KF recursion solves a test-time ridge regression problem. Then, to solve such a regularized problem stably, we make the following choices:
- At the modeling level, we adaptively choose the regularization strength of our test-time objective function based on the Frobenius norm of the regularized data covariance. With this choice, we can easily upper bound the condition number of the optimization problem.
- At the algorithmic level, we note that exact solvers (e.g., torch.linalg.solve) are hard to parallelize in a chunk-wise manner, so we resort to the classic Chebyshev Iteration (CH), which we show has high numerical accuracy and fast convergence compared with alternatives such as (accelerated) gradient descent and conjugate gradient.
To make GKA scalable and efficient, we implement CH with adaptive regularization in Triton in a hardware-aware, chunk-wise manner. Our technical novelty here includes deriving a chunk-wise implementation that back-propagates through the Frobenius norm, where the main difficulty is the presence of a nested recurrence. Furthermore, we combine CH with a gating mechanism that decides the regression residual weights in an input-aware and time-varying fashion, enhancing the contribution of recent inputs and smoothly fading out distant context. Overall, to the best of our knowledge, this is the first adoption of the CH method for training sequence modeling layers in LLMs stably at scale.
Finally, we demonstrate the efficacy of GKA on numerous LLM benchmarks. For example, on synthetic recall tasks (MQAR) [arora2023zoology], our method achieves the highest recall accuracy among state-of-the-art linear SSMs, including Mamba2 [Dao-ICML2024-mamba2] and (Gated) DeltaNet [Yang-NeurIPS2024, Yang-ICLR2025]. GKA also outperforms existing SSMs on several short-context tasks (from LM-Harness [eval-harness]) and long-context tasks (from RULER [hsieh2024ruler] and HELMET [yen2025helmet]). Specifically, GKA improves upon SSM baselines by at least 10% on real-world long-context tasks such as Retrieval-Augmented Generation and Long Question-Answering with up to 128k tokens.
2 Prior Work and Preliminaries
In this section, we briefly review prior work and preliminaries that set the stage for motivating our choice of designing an SSM layer based on the Kalman Filter. For a more detailed exposition of related work, refer to Appendix A.
(Softmax) Attention. At each time $t$, Attention [Vaswani-NeurIPS2017] linearly projects the $t$-th input token to obtain three vectors, named query $q_t$, key $k_t$, and value $v_t$, respectively. Then, it outputs a vector $o_t$ as a convex combination of all values seen so far, with coefficients given by inner products of the current query with all seen keys and a softmax mapping:
$$o_t = \sum_{i=1}^{t} \frac{\exp(k_i^\top q_t)}{\sum_{j=1}^{t} \exp(k_j^\top q_t)}\, v_i \tag{Attn}$$
From an optimization perspective, Eq. (Attn) can be viewed as solving the following regression objective:¹
$$o_t = \arg\min_{o} \sum_{i=1}^{t} \exp(k_i^\top q_t)\, \| o - v_i \|_2^2 \tag{1}$$
¹ Concretely, for keys and queries of unit norm, Attention is precisely the Nadaraya-Watson estimator [Nadaraya-1964, Watson-1964] with the Gaussian kernel to approximate the conditional expectation of the value given a query; cf. [Chaudhari-TIST2021, Vidal-2022].
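The regression view of Eq. (Attn) can be checked numerically. The NumPy sketch below is illustrative only (single query, no $1/\sqrt{d}$ scaling, hypothetical helper names): it confirms that the softmax output equals the exp-weighted mean of the values, i.e., the point where the gradient of objective (1) vanishes.

```python
import numpy as np

def softmax_attention(q, K, V):
    """Softmax Attention output for one query q, keys K (t x d_k), values V (t x d_v)."""
    scores = K @ q                        # inner products k_i^T q
    w = np.exp(scores - scores.max())     # shift by the max for numerical stability
    w /= w.sum()                          # softmax weights form a convex combination
    return w @ V                          # sum_i w_i v_i

def weighted_least_squares(q, K, V):
    """Minimizer of sum_i exp(k_i^T q) * ||o - v_i||^2: the exp-weighted mean of the values."""
    w = np.exp(K @ q)
    return (w[:, None] * V).sum(axis=0) / w.sum()

rng = np.random.default_rng(0)
t, d_k, d_v = 6, 4, 3
K = rng.standard_normal((t, d_k))
V = rng.standard_normal((t, d_v))
q = rng.standard_normal(d_k)

o_attn = softmax_attention(q, K, V)
o_ls = weighted_least_squares(q, K, V)
# Gradient of objective (1) evaluated at the Attention output; it should be ~0.
grad = 2.0 * (np.exp(K @ q)[:, None] * (o_attn - V)).sum(axis=0)
print(np.allclose(o_attn, o_ls), np.allclose(grad, 0.0, atol=1e-9))
```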
The success of Eq. (Attn) is often attributed to its ability to perform verbatim retrieval of relevant context from the entire past. Here, the past refers to all key-value pairs observed thus far, also known as the KV-cache, which grows linearly with time $t$. Moreover, the computation at each time $t$ is also linear in $t$, and repeating it for all $t = 1, \dots, T$ results in a quadratic time complexity in the sequence length $T$. This high computation and storage cost makes Attention prohibitive in long-context scenarios.
Linear State-Space Models (SSMs). The high computation cost of Eq. (Attn) has motivated a flurry of work developing new LLM layers, like SSMs, with linear rather than quadratic cost. Most SSMs maintain a state matrix $S_t$ and update it at each time step via a linear recursion of the form
$$S_t = \alpha_t S_{t-1} + \beta_t\, v_t k_t^\top, \qquad o_t = S_t q_t, \tag{Linear-SSM}$$
where $\alpha_t, \beta_t$ are typically in $[0, 1]$. Unlike the verbatim lookup of Eq. (Attn), Eq. (Linear-SSM) essentially compresses the entire KV-cache into a fixed-dimensional representation $S_t$. Subsequent computation of the output $o_t$ relies on $S_t$ and no longer on the exact past. This results in a constant cost of storage and computation at every time step.
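A minimal sketch of the generic Eq. (Linear-SSM) recursion, assuming scalar gates $\alpha_t, \beta_t$ and the readout $o_t = S_t q_t$ (the helper name `linear_ssm` is hypothetical). The loop keeps a single $d_v \times d_k$ state regardless of the sequence length, and the final state matches the unrolled sum of gated outer products, which makes the "compressed KV-cache" interpretation explicit.

```python
import numpy as np

def linear_ssm(Q, K, V, alpha, beta):
    """Eq. (Linear-SSM): S_t = alpha_t S_{t-1} + beta_t v_t k_t^T, read out o_t = S_t q_t."""
    T, d_k = K.shape
    S = np.zeros((V.shape[1], d_k))        # one d_v x d_k state, independent of T
    outputs = np.empty((T, V.shape[1]))
    for t in range(T):
        S = alpha[t] * S + beta[t] * np.outer(V[t], K[t])
        outputs[t] = S @ Q[t]
    return outputs, S

rng = np.random.default_rng(0)
T, d_k, d_v = 8, 4, 3
Q, K = rng.standard_normal((T, d_k)), rng.standard_normal((T, d_k))
V = rng.standard_normal((T, d_v))
alpha, beta = rng.uniform(0.5, 1.0, T), rng.uniform(0.0, 1.0, T)
O, S_T = linear_ssm(Q, K, V, alpha, beta)

# Unrolled form: S_T = sum_i (prod_{j > i} alpha_j) beta_i v_i k_i^T
S_unrolled = sum(np.prod(alpha[i + 1:]) * beta[i] * np.outer(V[i], K[i]) for i in range(T))
print(np.allclose(S_T, S_unrolled))
```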
In many linear SSMs (e.g., RetNet [Sun-arXiv2023RetNet], Mamba2 [Dao-ICML2024-mamba2]), the choice of $\alpha_t$ and $\beta_t$ is often heuristic and draws inspiration from nonlinear recurrent neural networks [Hochreiter-NC1997]; in that light, $\alpha_t$ and $\beta_t$ are called the forget and input gates, respectively. This basic form of Eq. (Linear-SSM) has been generalized by replacing $\alpha_t$ with a diagonal matrix (GLA [Yang-ICML2024-gla], RWKV-6 [Peng-CoLM2024], Longhorn [Liu-ICLR2025]) or a low-rank-plus-identity matrix (Gated DeltaNet [schlag2021linear, Yang-NeurIPS2024, Yang-ICLR2025], DeltaProduct [Siems-arXiv2025], RWKV-7 [Peng-arXiv2025rwkv]).
Similarly to Eq. (Attn), the case with low-rank-plus-identity matrices can often be justified from an optimization perspective. For example, Gated DeltaNet [schlag2021linear, Yang-NeurIPS2024] updates the state via ($I$ is the identity matrix)
$$S_t = \alpha_t S_{t-1}\big(I - \beta_t k_t k_t^\top\big) + \beta_t\, v_t k_t^\top, \tag{GDN}$$
which can be viewed as applying one gradient descent step with stepsize $\beta_t$ and initialization $\alpha_t S_{t-1}$ to the objective
$$\min_{S}\ \tfrac{1}{2}\, \| S k_t - v_t \|_2^2. \tag{2}$$
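To see the gradient-descent reading of Eq. (GDN), the following sketch (illustrative names, assuming the objective is $\frac{1}{2}\|S k_t - v_t\|_2^2$) compares the GDN update with one explicit gradient step of size $\beta_t$ started from $\alpha_t S_{t-1}$. Note that the step only looks at the previous state and the current pair $(k_t, v_t)$.

```python
import numpy as np

def gdn_update(S_prev, k, v, alpha, beta):
    """Eq. (GDN): S_t = alpha_t S_{t-1} (I - beta_t k_t k_t^T) + beta_t v_t k_t^T."""
    eye = np.eye(k.shape[0])
    return alpha * S_prev @ (eye - beta * np.outer(k, k)) + beta * np.outer(v, k)

def one_gd_step(S_prev, k, v, alpha, beta):
    """One gradient step on 0.5 * ||S k - v||^2 with stepsize beta, starting at alpha * S_prev."""
    S_init = alpha * S_prev
    grad = np.outer(S_init @ k - v, k)    # d/dS of 0.5 * ||S k - v||^2
    return S_init - beta * grad

rng = np.random.default_rng(0)
d_k, d_v = 5, 3
S_prev = rng.standard_normal((d_v, d_k))
k, v = rng.standard_normal(d_k), rng.standard_normal(d_v)
S_gdn = gdn_update(S_prev, k, v, alpha=0.9, beta=0.3)
S_gd = one_gd_step(S_prev, k, v, alpha=0.9, beta=0.3)
print(np.allclose(S_gdn, S_gd))
```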
The objectives (2) of Eq. (GDN) and (1) of Eq. (Attn) are prime examples that expose a general distinction between linear SSMs and Attention: the former updates its state based on a regression objective that considers only the previous lossy state and the current time step, whereas the latter uses the entire, exact KV-cache to solve its regression objective (1).
We hypothesize that this myopic view of SSM objectives results in their lower performance and limited long-context abilities. We then ask: What is an objective or, equivalently, a recursion that considers the entire past as Eq. (Attn) does, while still being solvable in linear time as in Eq. (Linear-SSM)?
3 A Linear SSM Inspired by the Kalman Filter
In Section 3.1, we show how the Kalman Filter (KF) gives insights into a new linear SSM layer that takes all past time instants into account. In Section 3.2, we explain the numerical and efficiency challenges of building such a layer.
3.1 Motivation from Kalman Filter
KF is an established online approach that takes the exact past into account to optimally solve a weighted ridge regression objective (e.g., see [Peng-MoCL2025, Proposition 2 & Lemma 3]). In our context, this means that the optimal state
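$$S_t = \arg\min_{S}\ \frac{\lambda}{2}\,\|S\|_F^2 + \frac{1}{2}\sum_{i=1}^{t} \eta_i\, \| S k_i - v_i \|_2^2 \tag{3}$$
can be computed by the KF recursion
$$P_t = P_{t-1} - \frac{\eta_t\, P_{t-1} k_t k_t^\top P_{t-1}}{1 + \eta_t\, k_t^\top P_{t-1} k_t}, \qquad S_t = S_{t-1}\big(I - \eta_t k_t k_t^\top P_t\big) + \eta_t\, v_t k_t^\top P_t, \tag{KF}$$
where $\eta_i \ge 0$ is the weight for the $i$-th key-value pair, and $P_t = \big(\lambda I + \sum_{i=1}^{t}\eta_i k_i k_i^\top\big)^{-1}$ is the Hessian inverse of (3) at time $t$ ($P_t$ itself can be continually updated via the Woodbury matrix identity, as in the first equation, starting from $P_0 = \lambda^{-1} I$). It is now clear that objective (3) takes the entire KV-cache into account, similarly to Eq. (Attn). It is also clear that Eq. (KF) is an efficient update scheme similarly to Eq. (Linear-SSM); indeed, Eq. (KF) is also of low-rank-plus-identity form (cf. Eq. (GDN)).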
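The claim that the KF recursion solves the weighted ridge regression (3) exactly can be checked on random data. This sketch assumes the $\frac{1}{2}$-scaled form of (3) with fixed weights $\eta_i$ and a fixed $\lambda$, and it uses the Sherman-Morrison formula (the rank-one case of Woodbury) to update the Hessian inverse; function names are hypothetical. It compares the final sequential state with the closed-form ridge solution.

```python
import numpy as np

def kf_ridge(K, V, eta, lam):
    """Sequential KF/RLS recursion; returns S_T for objective (3) after all T pairs."""
    T, d_k = K.shape
    S = np.zeros((V.shape[1], d_k))
    P = np.eye(d_k) / lam                  # P_0 = (lam * I)^{-1}
    for t in range(T):
        k, v = K[t], V[t]
        Pk = P @ k
        # Sherman-Morrison (rank-one Woodbury) update of the Hessian inverse P_t
        P = P - eta[t] * np.outer(Pk, Pk) / (1.0 + eta[t] * (k @ Pk))
        # S_t = S_{t-1} + eta_t (v_t - S_{t-1} k_t) (P_t k_t)^T, an expanded form of the KF update
        S = S + eta[t] * np.outer(v - S @ k, P @ k)
    return S

def ridge_closed_form(K, V, eta, lam):
    """S_T = U_T H_T^{-1}, H_T = lam I + sum_i eta_i k_i k_i^T, U_T = sum_i eta_i v_i k_i^T."""
    H = lam * np.eye(K.shape[1]) + (K.T * eta) @ K
    U = (V.T * eta) @ K
    return np.linalg.solve(H, U.T).T       # H is symmetric, so this equals U H^{-1}

rng = np.random.default_rng(0)
T, d_k, d_v = 20, 4, 3
K, V = rng.standard_normal((T, d_k)), rng.standard_normal((T, d_v))
eta = rng.uniform(0.5, 1.0, T)
S_kf = kf_ridge(K, V, eta, lam=1.0)
S_cf = ridge_closed_form(K, V, eta, lam=1.0)
print(np.allclose(S_kf, S_cf))
```

The division by $1 + \eta_t k_t^\top P_{t-1} k_t$ and the repeated subtraction in the $P_t$ update are exactly the operations whose precision loss Section 3.2 worries about; in float64 on a short sequence they are harmless.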
A key difference from Eq. (Linear-SSM) is that Eq. (KF) leverages second-order information from $P_t$ to solve (3) optimally, whereas Eq. (Linear-SSM) relies on instantaneous objectives akin to (2) (cf. [Yang-NeurIPS2024, Table 2]). It is in this sense that we say Eq. (KF) is more expressive than Eq. (Linear-SSM) or Eq. (GDN). We now detail the differences in the objectives of Eq. (KF) and Eq. (Attn):
- In (3), the regularizer $\frac{\lambda}{2}\|S\|_F^2$ prevents overfitting the state to the key-value pairs: only a finite amount of "information" can be stored in a constant-sized memory, and storing more than that results in "fuzzy" recall. In this light, $\lambda$ can be thought of as controlling the memorization "capacity" of the state.
3.2 Hurdles Towards Scalable Kalman Filter SSMs
Despite its optimality and (sequential) computational efficiency, the Eq. (KF) recursion lacks a hardware-aware implementation that leverages parallelism in modern Tensor Cores. Moreover, for long sequences it can lose numerical precision due to division (and, more significantly, due to how the Hessian inverse is updated). The final hurdle is conceptual: fixing the weights and regularization over time as in (3) might make a layer less expressive.
We are aware of the use of Eq. (KF) or (3) in neural network training three decades ago [Shah-1992] and in deep continual learning more recently [Zeng-NMI2019, Mcdonnell-NeurIPS2023, Peng-ICLR2025]. We are also aware of recent mentions of (3), or efforts towards solving it, which go by the name test-time optimization [Wang-arXiv2025v3, Von-arXiv2025-mesa]. However, to the best of our knowledge, none of the prior work has fully addressed the above hurdles that need to be solved to design an SSM layer that is trainable in parallel, numerically well-behaved, and sufficiently expressive. In particular, both [Von-arXiv2025-mesa] and [Wang-arXiv2025v3] have overlooked a basic numerical concern: the worst-case numerical error in solving (3) can be on the order of $\kappa\,\epsilon$ [Golub-2013], where $\kappa$ is the condition number of the Hessian in (3) and $\epsilon$ the machine precision; since $\epsilon$ is large in bf16, (3) has to be regularized strongly for $\kappa$ and the worst-case error to be small, regardless of the algorithm used to solve (3). Indeed, the regularization enforced in [Von-arXiv2025-mesa] lower bounds $\lambda$, but this is not sufficient: their $\kappa$ can still be large [Von-arXiv2025-mesa, Fig. 13], implying a large worst-case error. (The implementation of [Von-arXiv2025-mesa] available on GitHub is numerically vulnerable; we failed to train it without NaNs in various settings.) Also, the regression objective in [Wang-arXiv2025v3] has no regularization, which makes it numerically ill-posed for low-precision training.
4 Gated KalmaNet (GKA)
We propose Gated KalmaNet (GKA) to address the above hurdles: We enhance numerical stability via adaptive regularization and the classic Chebyshev Iteration (CH), increase expressivity of KF via a standard gating mechanism, and improve parallelism via a hardware-friendly implementation.
4.1 CH with Adaptive Regularization & Weighting
Motivation. As alluded to earlier, solving (3) via Eq. (KF) is sequential in nature, and here we consider alternatives amenable to parallelizable training. Our first step towards this is to write down a closed-form solution to (3) and compute the output
$$o_t = S_t q_t = U_t H_t^{-1} q_t, \qquad H_t := \lambda I + \sum_{i=1}^{t} \eta_i k_i k_i^\top, \qquad U_t := \sum_{i=1}^{t} \eta_i v_i k_i^\top.$$
With the weighted covariances $H_t$ and $U_t$, we note that $o_t$ can be computed by first solving $H_t x_t = q_t$ for $x_t$ and then left-multiplying by $U_t$. An exact solver (e.g., torch.linalg.solve) can do so with high accuracy by parallelizing over the batch dimension. However, it is inefficient here for two reasons. First, it takes $\mathcal{O}(d^3)$ time for every $t$, where $d$ is the key dimension. Second, it requires explicitly forming and materializing all $H_t$'s, which would entail a large I/O cost. In light of this, we resort to first-order iterative methods that admit a chunk-wise implementation without materializing all $H_t$'s, enabling parallelism over chunks and batches. Furthermore, they typically take $\mathcal{O}(d^2)$ time per iteration and can converge within a few iterations. The iterative method we choose is the Chebyshev Iteration (CH); we proceed to describe its basic idea, deferring the justification for using CH to Section 4.2.3.
Chebyshev Iteration (CH). CH can be seen as an accelerated gradient descent method (AGD) that applies (grad descent) and (momentum) steps to the strongly convex objective $\frac{1}{2}x^\top H_t x - q_t^\top x$, that is, to solve the optimality condition $H_t x = q_t$ (Algorithm 1). Different from AGD, CH incorporates a (weight schedule) and makes specific choices of its parameters; these choices make CH optimal, with the fastest worst-case convergence among first-order methods [Pedregosa-Chebyshev2021].
We now replace the above exact solver with CH:
$$o_t = U_t\, \mathrm{CH}_r(H_t, q_t).$$
Here, $\mathrm{CH}_r(H_t, q_t)$ denotes $r$ iterations of CH to approximately solve $H_t x = q_t$. To improve stability and expressivity, we next allow the regularization $\lambda$ and the weights $\eta_i$ to be time-varying and chosen adaptively. We write $\lambda_t$ and $\eta_{t,i}$ to make their dependency on time explicit, with $\eta_{t,i}$ being the weight of the $i$-th token at time $t$.
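Algorithm 1 (Chebyshev Iteration). Input: $H$, $q$, eigenvalue bounds $0 < \mu \le \lambda_{\min}(H)$ and $L \ge \lambda_{\max}(H)$, number of iterations $r$. Set $\rho = \frac{L-\mu}{L+\mu}$, $x_0 = 0$, $x_1 = x_0 - \frac{2}{L+\mu}(H x_0 - q)$, and $\omega_1 = 2$. For $j = 1, \dots, r-1$:
$$\omega_{j+1} = \Big(1 - \frac{\rho^2}{4}\,\omega_j\Big)^{-1} \tag{weight schedule}$$
$$y_j = x_j - \frac{2}{L+\mu}\,(H x_j - q) \tag{grad descent}$$
$$x_{j+1} = y_j + (\omega_{j+1} - 1)\,(y_j - x_{j-1}) \tag{momentum}$$
Output: $x_r$.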
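A minimal NumPy sketch of a standard Chebyshev Iteration (weight schedule, gradient step, momentum), assuming a symmetric positive definite $H$ whose spectrum lies in $[\mu, L]$ and a zero initialization. It is a plain reference loop, not the chunk-wise Triton kernel of Section 4.2, and the function name is hypothetical.

```python
import numpy as np

def chebyshev_solve(H, q, mu, L, num_iters):
    """Approximately solve H x = q for SPD H with eigenvalues in [mu, L]."""
    rho = (L - mu) / (L + mu)
    step = 2.0 / (L + mu)
    x_prev = np.zeros_like(q)
    x = x_prev - step * (H @ x_prev - q)       # first iterate: plain gradient step
    omega = 2.0                                # omega_1
    for _ in range(num_iters - 1):
        omega = 1.0 / (1.0 - 0.25 * rho**2 * omega)         # weight schedule
        y = x - step * (H @ x - q)                          # gradient descent step
        x, x_prev = y + (omega - 1.0) * (y - x_prev), x     # momentum (uses the old x_prev)
    return x

rng = np.random.default_rng(0)
d = 16
Qm, _ = np.linalg.qr(rng.standard_normal((d, d)))
H = Qm @ np.diag(np.linspace(1.0, 50.0, d)) @ Qm.T    # SPD with spectrum in [1, 50]
q = rng.standard_normal(d)
x = chebyshev_solve(H, q, mu=1.0, L=50.0, num_iters=80)
x_exact = np.linalg.solve(H, q)
rel_err = np.linalg.norm(x - x_exact) / np.linalg.norm(x_exact)
print(rel_err)
```

Each iteration costs one matrix-vector product with $H$, and the only extra state is the previous iterate, which is what makes the method attractive for a chunk-wise implementation.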
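Adaptive Regularization. As mentioned, the condition number of $H_t = C_t + \lambda_t I$, where $C_t := \sum_{i=1}^{t} \eta_{t,i} k_i k_i^\top$, has to be controlled for any method to be numerically stable. We choose $\lambda_t$ to be proportional to the Frobenius norm $\|C_t\|_F$, that is, we set $\lambda_t = a\,\|C_t\|_F$ for some constant $a > 0$. The following upper bound on $\kappa(H_t)$ now ensues:
$$\kappa(H_t) = \frac{\lambda_{\max}(C_t) + \lambda_t}{\lambda_{\min}(C_t) + \lambda_t} \le \frac{\|C_t\|_F + a\|C_t\|_F}{a\|C_t\|_F} = 1 + \frac{1}{a}. \tag{4}$$
Here $\lambda_{\max}(C_t)$ and $\lambda_{\min}(C_t)$ are the maximum and minimum eigenvalues of $C_t$, respectively, and the inequality uses $\lambda_{\max}(C_t) \le \|C_t\|_F$ and $\lambda_{\min}(C_t) \ge 0$. Given this choice of $\lambda_t$, we set $\mu = \lambda_t$ and $L = \lambda_t + \|C_t\|_F$ for Algorithm 1.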
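A quick numerical check of bound (4), assuming $\lambda_t = a\|C_t\|_F$ with a hypothetical $a = 0.05$ and the weighted key covariance $C_t = \sum_i \eta_{t,i} k_i k_i^\top$. Even when $C_t$ is rank-deficient (fewer keys than dimensions, so the unregularized problem is singular), the regularized matrix has $\kappa \le 1 + 1/a$, and its eigenvalues lie in $[\lambda_t, \lambda_t + \|C_t\|_F]$, which are the bounds one can hand to a Chebyshev solver.

```python
import numpy as np

rng = np.random.default_rng(0)
d, t, a = 8, 4, 0.05                      # fewer keys than dimensions: C_t is singular
K = rng.standard_normal((t, d))
eta = rng.uniform(0.1, 1.0, t)            # hypothetical positive weights eta_{t,i}
C = (K.T * eta) @ K                       # weighted key covariance C_t
fro = np.linalg.norm(C, "fro")
lam = a * fro                             # adaptive regularization lambda_t = a * ||C_t||_F
eigs = np.linalg.eigvalsh(C + lam * np.eye(d))
kappa = eigs.max() / eigs.min()
mu, L = lam, lam + fro                    # spectrum bounds handed to the Chebyshev solver
print(np.linalg.matrix_rank(C), kappa, 1 + 1 / a)
```

Because the bound depends only on $a$, the worst-case error of any solver is controlled no matter how the scale of the keys drifts over the sequence.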
Adaptive Weighting (Gating). We use weights that decay exponentially in time: for all $i \le t$, we parameterize $\eta_{t,i} = \prod_{j=i+1}^{t} \gamma_j$, with each $\gamma_j \in [0, 1]$ learnable. The fading weights encode the "prior" of recency bias that has been shown to exist in LLMs [fang2025large], without explicitly computing the weights from query-key dot products as in Eq. (Attn). Similarly to Eq. (Attn), the weights on the residuals are now time-varying; differently, the exponentially decaying parameterization allows for a linear-time implementation.
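Forward Recurrence. We now summarize our recurrence, which arms CH with adaptive regularization and weighting:
$$C_t = \gamma_t C_{t-1} + k_t k_t^\top, \quad U_t = \gamma_t U_{t-1} + v_t k_t^\top, \quad \lambda_t = a\,\|C_t\|_F, \quad x_t = \mathrm{CH}_r\big(C_t + \lambda_t I,\ q_t\big), \quad o_t = U_t x_t. \tag{CH}$$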
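The sketch below is a sequential reference for the forward recurrence under stated assumptions: scalar gates $\gamma_t$, $\lambda_t = a\|C_t\|_F$ with a hypothetical $a = 0.1$, and `np.linalg.solve` standing in for the $r$-step Chebyshev solve so that only the gating and regularization logic is exercised. It checks that the gated recursion reproduces the explicit weights $\eta_{t,i} = \prod_{j=i+1}^{t} \gamma_j$. Function names are illustrative.

```python
import numpy as np

def gka_forward_recurrent(Q, K, V, gamma, a):
    """Token-by-token recurrence; an exact solve replaces CH_r for checking purposes."""
    T, d_k = K.shape
    C = np.zeros((d_k, d_k))
    U = np.zeros((V.shape[1], d_k))
    out = np.empty((T, V.shape[1]))
    for t in range(T):
        C = gamma[t] * C + np.outer(K[t], K[t])        # gated key covariance C_t
        U = gamma[t] * U + np.outer(V[t], K[t])        # gated value-key covariance U_t
        lam = a * np.linalg.norm(C, "fro")             # adaptive regularization lambda_t
        x = np.linalg.solve(C + lam * np.eye(d_k), Q[t])
        out[t] = U @ x
    return out

def gka_forward_explicit(Q, K, V, gamma, a):
    """Same outputs from explicit weights eta_{t,i} = prod_{j=i+1}^{t} gamma_j."""
    T, d_k = K.shape
    out = np.empty((T, V.shape[1]))
    for t in range(T):
        eta = np.array([np.prod(gamma[i + 1:t + 1]) for i in range(t + 1)])
        C = (K[:t + 1].T * eta) @ K[:t + 1]
        U = (V[:t + 1].T * eta) @ K[:t + 1]
        lam = a * np.linalg.norm(C, "fro")
        out[t] = U @ np.linalg.solve(C + lam * np.eye(d_k), Q[t])
    return out

rng = np.random.default_rng(0)
T, d_k, d_v = 12, 4, 3
Q, K = rng.standard_normal((T, d_k)), rng.standard_normal((T, d_k))
V = rng.standard_normal((T, d_v))
gamma = rng.uniform(0.8, 1.0, T)      # hypothetical gate values
out_rec = gka_forward_recurrent(Q, K, V, gamma, a=0.1)
out_exp = gka_forward_explicit(Q, K, V, gamma, a=0.1)
print(np.allclose(out_rec, out_exp))
```

The recurrent version keeps two fixed-size matrices per step, while the explicit version revisits the whole prefix; both give the same outputs, which is the point of the exponentially decaying parameterization.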
4.2 Chunk-wise Implementation
In this subsection, we describe our hardware-aware implementation of the forward and backward passes for Eq. (CH). More details can be found in Appendix B.
4.2.1 Forward Pass
Similarly to prior work [Dao-ICML2024-mamba2, Yang-ICML2024-gla, Yang-NeurIPS2024], we now describe a chunk-wise implementation for Eq. (CH). In Eq. (CH), given the gates $\gamma_t$ and the solutions $x_t$, computing $o_t = U_t x_t$ in a chunk-wise fashion is similar to computing the outputs of Eq. (Linear-SSM); also similar is the calculation of the matrix-vector products with $C_t$ needed in (grad descent). For these, we refer the reader to [Yang-ICML2024-gla, Yang-NeurIPS2024] for details. Our algorithmic novelty here is a chunk-wise formula for computing $\|C_t\|_F$, presented next.
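Let $T$ be the sequence length and $c$ the chunk size such that $T/c$ is an integer. For $n = 0, \dots, T/c - 1$, write $C_{[n]} := C_{nc}$ for the state at the beginning of the $(n+1)$-th chunk. The core idea of a chunk-wise implementation is as follows. First, we compute and store the initial state $C_{[n]}$ of every chunk. This gives us implicit access to every $C_t$ by unrolling the recurrence of $C_t$ for at most $c$ steps, and therefore allows us to carry out computation with $C_t$; for example, for $t$ in the first chunk we can compute the matrix-vector product $C_t x$ via $C_t x = \big(\prod_{j=1}^{t}\gamma_j\big) C_0 x + \sum_{i=1}^{t} \big(\prod_{j=i+1}^{t}\gamma_j\big) k_i (k_i^\top x)$. This avoids forming $C_t$ explicitly, thereby reducing the number of states to materialize on chip. To implement such a scheme, we need to precompute all chunk-initial states $C_{[n]}$ sequentially, and then do the computation with parallelism over chunks and within each chunk.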
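To illustrate the chunk-wise idea, here is a NumPy sketch (hypothetical names, scalar gates) that computes $C_t x_t$ for every position of one chunk from only the chunk-initial state, the chunk's keys, and cumulative gate products, never forming any $C_t$. It divides by cumulative products $b_i$, which can underflow for long chunks or small gates; a practical kernel would need to guard against that (for example by working in log space), and this sketch does not claim to mirror the authors' Triton code.

```python
import numpy as np

def chunk_matvecs(C0, K_chunk, gamma_chunk, X):
    """Return rows C_t x_t for t = 1..c of one chunk without materializing any C_t."""
    b = np.cumprod(gamma_chunk)                   # b_t = gamma_1 * ... * gamma_t
    M = np.triu(b[None, :] / b[:, None])          # M[i, t] = b_t / b_i for i <= t, else 0
    G = K_chunk @ X.T                             # G[i, t] = k_i^T x_t
    from_init = b[:, None] * (X @ C0.T)           # row t: b_t * (C_0 x_t)
    from_keys = (M * G).T @ K_chunk               # row t: sum_i M[i, t] (k_i^T x_t) k_i
    return from_init + from_keys

rng = np.random.default_rng(0)
c, d = 8, 4
A = rng.standard_normal((d, d))
C0 = A @ A.T                                      # hypothetical chunk-initial state
K_chunk = rng.standard_normal((c, d))
gamma_chunk = rng.uniform(0.8, 1.0, c)
X = rng.standard_normal((c, d))                   # one vector x_t per position

# Reference: materialize every C_t sequentially.
C, ref = C0.copy(), np.empty((c, d))
for t in range(c):
    C = gamma_chunk[t] * C + np.outer(K_chunk[t], K_chunk[t])
    ref[t] = C @ X[t]
fast = chunk_matvecs(C0, K_chunk, gamma_chunk, X)
print(np.allclose(fast, ref))
```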
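We now make this idea precise for computing $\|C_t\|_F$ for all $t$ within a chunk. Since the computation of each chunk is the same, we simplify by working with the first one, where we have access to the initial state $C_0$, gates $\gamma_1, \dots, \gamma_c$, and keys $k_1, \dots, k_c$, and we aim to compute $\|C_1\|_F, \dots, \|C_c\|_F$. With these notations, we first compute the $c$-dimensional vector $b$ of cumulative products of the $\gamma_j$'s, with $b_t = \prod_{j=1}^{t} \gamma_j$. Then, form the upper triangular matrix $M \in \mathbb{R}^{c \times c}$ whose $(i,t)$-th entry is $b_t / b_i$ ($i \le t$) and zero otherwise. Now, unroll the recurrence of $C_t$:
$$C_t = b_t C_0 + \sum_{i=1}^{c} M_{it}\, k_i k_i^\top.$$
Expanding $\|C_t\|_F^2$ gives the following sum of three terms:
$$\|C_t\|_F^2 = b_t^2\, \|C_0\|_F^2 + 2\, b_t \sum_{i=1}^{c} M_{it}\, k_i^\top C_0 k_i + \sum_{i=1}^{c} \sum_{j=1}^{c} M_{it} M_{jt}\, (k_i^\top k_j)^2.$$
With $\|C_0\|_F$, the first term is easily computed in parallel for all $t$. For the second term, we first compute the vector of quadratic forms $k_i^\top C_0 k_i$ for all $i$ in parallel, broadcast it and multiply it with $M$ element-wise, sum over each column, and multiply the result with $2b$ element-wise. Finally, with the Gram matrix $G \in \mathbb{R}^{c \times c}$, $G_{ij} = k_i^\top k_j$, one verifies that the third term can be computed in parallel for all $t$ via the following pseudocode:
$$\mathrm{sum}\big(M \odot \big((G \odot G)\, M\big)\big) \tag{5}$$
Here $\odot$ denotes element-wise multiplication and the sum is over each column. Summing the three terms and taking the square root, we obtain $\|C_t\|_F$ for all $t$, as desired.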
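The three-term expansion and pseudocode (5) can be checked against brute force. This NumPy sketch (hypothetical function name, symmetric positive semidefinite initial state $C_0$) evaluates the three terms with the masked, broadcast, and column-sum operations described above, then compares the result with the norms of explicitly materialized $C_t$.

```python
import numpy as np

def chunk_frobenius_norms(C0, K_chunk, gamma_chunk):
    """||C_t||_F for t = 1..c of one chunk via the three-term expansion."""
    b = np.cumprod(gamma_chunk)                          # b_t = gamma_1 * ... * gamma_t
    M = np.triu(b[None, :] / b[:, None])                 # M[i, t] = b_t / b_i for i <= t, else 0
    term1 = b**2 * np.sum(C0 * C0)                       # b_t^2 * ||C_0||_F^2
    quad = np.einsum("id,de,ie->i", K_chunk, C0, K_chunk)  # k_i^T C_0 k_i for every i
    term2 = 2.0 * b * (quad[:, None] * M).sum(axis=0)    # broadcast, mask, column-sum, scale by 2 b_t
    G = K_chunk @ K_chunk.T                              # Gram matrix G[i, j] = k_i^T k_j
    term3 = (M * ((G * G) @ M)).sum(axis=0)              # pseudocode (5)
    return np.sqrt(term1 + term2 + term3)

rng = np.random.default_rng(0)
c, d = 8, 4
A = rng.standard_normal((d, d))
C0 = A @ A.T                                             # hypothetical chunk-initial state
K_chunk = rng.standard_normal((c, d))
gamma_chunk = rng.uniform(0.8, 1.0, c)

# Reference: materialize each C_t and take its Frobenius norm directly.
C, ref = C0.copy(), np.empty(c)
for t in range(c):
    C = gamma_chunk[t] * C + np.outer(K_chunk[t], K_chunk[t])
    ref[t] = np.linalg.norm(C, "fro")
norms = chunk_frobenius_norms(C0, K_chunk, gamma_chunk)
print(np.allclose(norms, ref))
```

All operations inside `chunk_frobenius_norms` are $c \times c$ or $c \times d$ matrix products and element-wise ops, which is what makes the formula friendly to Tensor Cores.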
4.2.2 Backward Pass
Motivation. Typically, the backward pass is done automatically via torch.autograd. However, for iterative methods such as CH (Algorithm 1), torch.autograd would store activations or intermediate iterates, entailing a large storage cost. While in principle we can back-propagate through CH without storing any intermediate activations or iterates (by our trick of reverting the CH iterations, cf. Section C.1), under this trick it is difficult to compute all the gradients in a chunk-wise fashion. Therefore, we resort to the implicit differentiation trick, which is practically efficient and chunk-wise implementable, for backpropagation through the linear equations that CH approximately solves.
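Implicit Differentiation. We derive the backward pass for our method with the standard implicit differentiation trick. It assumes we find an exact solution $x_t$ to the equations $H_t x_t = q_t$. In the backward pass, we are given the gradient $\partial\ell/\partial x_t$ of some loss function $\ell$, and need to compute the corresponding gradients with respect to $H_t$ and $q_t$. For example, via the chain rule we obtain $\partial\ell/\partial q_t$ via
$$H_t\, \frac{\partial \ell}{\partial q_t} = \frac{\partial \ell}{\partial x_t}, \tag{6}$$
that is, by solving linear equations similarly to the forward pass (recall that $H_t$ is symmetric). Since the forward pass computes only an approximate solution via CH, we receive an approximate upstream gradient $\widehat{\partial\ell/\partial x_t}$ (not exactly $\partial\ell/\partial x_t$). Thus, we employ CH to obtain an approximate gradient $\mathrm{CH}\big(H_t, \widehat{\partial\ell/\partial x_t}\big)$; cf. Table 1.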
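A small numerical check of the implicit-differentiation rule, assuming a symmetric positive definite $H_t$ and a hypothetical linear loss $\ell(x) = w^\top x$. Solving (6) gives $\partial\ell/\partial q_t$, and the standard identity $\partial\ell/\partial H_t = -(\partial\ell/\partial q_t)\, x_t^\top$ (a textbook result, stated here as an assumption rather than quoted from this section) gives the gradient with respect to the system matrix. Central finite differences serve as the reference.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 5
A = rng.standard_normal((d, d))
H = A @ A.T + d * np.eye(d)            # symmetric positive definite, plays the role of H_t
q = rng.standard_normal(d)
w = rng.standard_normal(d)             # hypothetical loss ell(x) = w^T x, so dell/dx = w

def loss(H_mat, q_vec):
    return w @ np.linalg.solve(H_mat, q_vec)

x = np.linalg.solve(H, q)              # forward pass: H x = q
grad_q = np.linalg.solve(H, w)         # Eq. (6): H (dell/dq) = dell/dx, using H = H^T
grad_H = -np.outer(grad_q, x)          # implicit-function gradient w.r.t. H

# Central finite differences as an independent check.
eps = 1e-6
fd_q = np.array([(loss(H, q + eps * e) - loss(H, q - eps * e)) / (2 * eps) for e in np.eye(d)])
E = rng.standard_normal((d, d))        # random perturbation direction for H
fd_H = (loss(H + eps * E, q) - loss(H - eps * E, q)) / (2 * eps)
print(np.allclose(fd_q, grad_q, atol=1e-6), np.isclose(fd_H, np.sum(grad_H * E), atol=1e-6))
```

The backward pass therefore needs only another linear solve with the same matrix, so the same solver (here exact, in GKA an approximate CH) can be reused.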
Backward Recurrence. Besides , we need to compute from which we obtain and via the chain rule. We describe in the Appendix. Here we analyze :
Lemma 1.
With , , we have
| (7) |
With , we can compute the first two terms and in Eq. (7), similarly to Eq. (Linear-SSM). Specifically, satisfies the recursion
| (8) |
thus calculating amounts to calculating in Eq. (Linear-SSM); a difference is that the recursion here runs backwards.
Similarly, with , the third term in Eq. (7) can be written recursively as
| (9) |
Chunk-wise Recurrence. As indicated, a chunk-wise implementation for computing is known. On the other hand, computing is more challenging than , as the additive term in the backward recursion Eq. (9) is not necessarily rank- ; rather, itself is defined via the forward recursion in Eq. (CH). Our contribution here is a derivation for computing efficiently in a chunk-wise manner.
We begin by unrolling to :
| (10) |
We next discuss the intra-chunk term and cross-chunk term in succession.
Intra-chunk Computation. We now unroll and obtain an expression more amenable to parallelism:
The coefficients of , written as , are easily computed in parallel for all via element-wise operations, broadcasting, and summing. The coefficient of is precisely the -th entry of the matrix . Thus is equal to
Here the mask is in general a full matrix with no zero entries, as opposed to the triangular matrix in the case of Eq. (Linear-SSM). While the triangular mask in the backward pass allows error feedback from future tokens to be leveraged for learning past tokens, our full mask allows all tokens to interact with all other tokens in the backward pass, which facilitates information flow and learning.
Figure 1: (a) Empirical convergence; (b) CG as a single layer; (c) CG in 5-layer LLAMA; (d) CH as a single layer; (e) CH in 5-layer LLAMA.
Cross-chunk Computation. In Eq. (10), both and are from the future chunks; thus we revise Eq. (10) into the cross-chunk recursion of , which allows us to maintain a single term from the future:
In our intra-chunk computation, we store the intra-chunk term of all chunks, implement the above with a simple for loop, and collect the terms .
4.2.3 Comparison to Other Iterative Solvers
Here we validate our choice of Chebyshev Iteration (CH) by benchmarking it against other iterative methods.
Convergence in the Forward Pass. We generate random regression problems, which we solve via CH and three other baselines: gradient descent (GD), accelerated GD with Nesterov's momentum (AGD), and conjugate gradient (CG). GD and AGD are run with stepsizes that are optimal for regression problems. Fig. 1a shows that CG converges the fastest within the first few iterations, while CH reaches the same accuracy as CG at iteration 10 and eventually attains the smallest errors.
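The qualitative gap between GD and CH can be reproduced on a toy problem. This sketch is not the paper's benchmark: the problem size, the condition number $\kappa = 100$, the choice of a solution with equal weight on every eigenvector, and the seed are arbitrary assumptions. It runs GD with stepsize $2/(L+\mu)$ and a standard three-term Chebyshev recursion on the same SPD system and prints relative errors for a few iteration counts.

```python
import numpy as np

def gd_solve(H, q, mu, L, num_iters):
    """Gradient descent on 0.5 x^T H x - q^T x with the fixed stepsize 2 / (L + mu)."""
    x = np.zeros_like(q)
    for _ in range(num_iters):
        x = x - (2.0 / (L + mu)) * (H @ x - q)
    return x

def chebyshev_solve(H, q, mu, L, num_iters):
    """Standard three-term Chebyshev recursion for SPD H with spectrum in [mu, L]."""
    rho, step = (L - mu) / (L + mu), 2.0 / (L + mu)
    x_prev = np.zeros_like(q)
    x = x_prev - step * (H @ x_prev - q)
    omega = 2.0
    for _ in range(num_iters - 1):
        omega = 1.0 / (1.0 - 0.25 * rho**2 * omega)
        y = x - step * (H @ x - q)
        x, x_prev = y + (omega - 1.0) * (y - x_prev), x
    return x

rng = np.random.default_rng(0)
d, mu, L = 32, 1.0, 100.0                          # condition number 100
Qm, _ = np.linalg.qr(rng.standard_normal((d, d)))
H = Qm @ np.diag(np.linspace(mu, L, d)) @ Qm.T     # SPD matrix with spectrum in [mu, L]
x_star = Qm @ np.ones(d)                           # equal weight on every eigenvector
q = H @ x_star

def rel_err(x):
    return np.linalg.norm(x - x_star) / np.linalg.norm(x_star)

for n in (5, 10, 20, 40):
    print(n, rel_err(gd_solve(H, q, mu, L, n)), rel_err(chebyshev_solve(H, q, mu, L, n)))
```

GD's error along the smallest eigenvalue contracts by roughly $(\kappa-1)/(\kappa+1)$ per step, whereas CH contracts at roughly $(\sqrt{\kappa}-1)/(\sqrt{\kappa}+1)$, which is why the gap widens with the iteration count.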
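| | forward pass | backward pass |
| exact | $x_t = H_t^{-1} q_t$ | $\partial\ell/\partial q_t = H_t^{-1}\, \partial\ell/\partial x_t$ |
| CG | $x_t \approx \mathrm{CG}(H_t, q_t)$ | $\partial\ell/\partial q_t \approx \mathrm{CG}(H_t, \partial\ell/\partial x_t)$ |
| CH | $x_t \approx \mathrm{CH}(H_t, q_t)$ | $\partial\ell/\partial q_t \approx \mathrm{CH}(H_t, \partial\ell/\partial x_t)$ |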
Stability of the Backward Pass. We then measure the gradient stability of CG and CH, whose backward passes are implemented either via implicit differentiation as per Table 1 (impl) or via torch.autograd (auto).
In Fig. 1b, the gradient of CG (impl) as a standalone layer is close to that of the exact solver, up to a small relative difference. In Fig. 1c, this difference is greatly amplified in a 5-layer LLAMA where Attention is replaced with a layer solving (3). This indicates that CG (impl) completely deviates from the reference gradient (exact), defeating its purpose of training the network from the regression feedback. In contrast, the gradients of CH (impl) and CH (auto) eventually become close to that of the exact solver, both as a single layer (Fig. 1d) and within multiple layers (Fig. 1e), up to a small difference. Moreover, the curves for CH (impl) and CH (auto) nearly overlap, suggesting that their gradients may be close. The following lemma formalizes this intuition (see Appendix C for a proof), thereby justifying our choice of CH over the alternatives:
Lemma 2.
Let be the exact gradient of for CH, e.g., computed by CH (auto). Let be the gradient of CH (impl), computed as per Table 1. We have .
Figure 3: (a) Accuracy vs. model dimension for different fading memory layers on MQAR. (b) Runtime of a single memory layer.
4.3 Architectural Consideration
Our GKA layer in Fig. 2 includes two components (in green) on top of established practices (in blue). The CH component is described in Section 4.1, so here we introduce the -connection. First, the sigmoid activation ensures , so the output of the -connection is a convex combination of the original query and the output of CH. Second, it plays a role similar to a residual connection, establishing a direct path that facilitates gradient flow and improves training; we show this is indeed the case in Section F.3. Finally, the full architecture of GKA is the standard Transformer, with its attention layer replaced by the GKA layer.
5 Experiments
In this section, we empirically validate the efficacy of our approach. We first evaluate memorization ability on synthetic associative recall tasks (Section 5.1). We then report the training throughput of GKA (Section 5.2). Finally, in Section 5.3 we examine performance on short-context language understanding benchmarks, such as commonsense reasoning, and on long-context modeling. The Appendix details our experimental settings (Appendix D) and ablations of various modeling choices (Appendices F and G).
Baselines. All experiments consider the following state-of-the-art linear SSM-based fading memory layers as baselines: Mamba2 [Dao-ICML2024-mamba2], DeltaNet [Yang-NeurIPS2024], Gated DeltaNet (GDN) [Yang-ICLR2025], and Gated Linear Attention (GLA) [Yang-ICML2024-gla]. Each of these layers relies on instantaneous objectives that depend on the previous lossy state and the current tokens (e.g., (2)), as opposed to the entire history of tokens observed so far as in GKA. Finally, we contrast our results with (Softmax) Attention, which serves as our paragon. For our Attention-based model, we adopt the architecture proposed in Qwen3 models [yang2025qwen3].
5.1 GKA on Synthetic Associative Recall Tasks
We first assess the ability of our models to recall information on the Multi-Query Associative Recall (MQAR) task, a synthetic task introduced by Arora et al. [arora2023zoology]. This task presents the model with a sequence of key-value pairs to memorize, followed by a sequence of queries. For each query, the model must retrieve the corresponding key from memory and accurately recall its associated value. Attention-based layers perform best on this task, while SSM-based memory layers are known to struggle, as their memory fades as the context length grows.
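For readers unfamiliar with MQAR, here is a toy generator for a single example. It is a hypothetical simplification, not the generator of Arora et al.: keys and values are drawn from disjoint halves of the vocabulary, the context interleaves key-value pairs, and each query is a previously seen key whose associated value is the target.

```python
import random

def make_mqar_example(num_pairs, num_queries, vocab_size, seed=0):
    """Toy MQAR-style sample: interleaved key-value context, then queries over seen keys."""
    rng = random.Random(seed)
    half = vocab_size // 2
    keys = rng.sample(range(half), num_pairs)                  # distinct keys, lower half
    values = [rng.randrange(half, vocab_size) for _ in keys]   # values from the upper half
    memory = dict(zip(keys, values))
    context = [tok for pair in zip(keys, values) for tok in pair]   # k1 v1 k2 v2 ...
    queries = rng.sample(keys, num_queries)                    # keys to look up later
    targets = [memory[k] for k in queries]                     # value to recall per query
    return context, queries, targets

context, queries, targets = make_mqar_example(num_pairs=8, num_queries=4, vocab_size=64)
print(context, queries, targets)
```

Harder variants simply increase `num_pairs` and the distance between a pair and its query, which is what stresses a fading memory.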
We compare GKA with Attention and other linear SSM baselines on this task. For each memory layer type, we train 2-layer models on MQAR training data and evaluate on a held-out test set. We repeat this experiment for a range of learning rates. As shown in Fig. 3a, GKA improves upon every other linear SSM baseline at all sequence lengths and model dimensions considered. Note that the difficulty of the task increases with the sequence length and the number of key-value pairs, while larger model dimensions improve memorization capacity through increased state size. We attribute the success of our layer to our modeling choice: unlike other fading memory designs (like GDN or Mamba2), we construct states based on the optimal MAP estimate conditioned on the entire history, enabling better retention of remote information.
5.2 Training Throughput of GKA
In Fig. 3b, we measure the running time (forward + backward) of a single GKA layer and compare it with FlashAttention [golden2024flash], DeltaNet, and GDN. Our layer achieves running time comparable to GDN, a state-of-the-art SSM layer, despite having a more computationally expensive state update, Eq. (CH), than Eq. (GDN). This demonstrates that our chunk-wise parallelization strategy effectively compensates for the additional computational cost.
| Model | Arc-C | Arc-E | BoolQ | COPA | HellaSWAG | PIQA | SciQ | Winogrande | FDA | SWDE | Avg |
| | acc_n | acc_n | acc | acc | acc_n | acc_n | acc_n | acc | contains | contains | |
| Transformer | 32.25 | 56.10 | 64.28 | 80.00 | 60.96 | 73.56 | 79.50 | 61.72 | 58.53 | 72.28 | 63.92 |
| Gated Linear Attention | 27.82 | 50.80 | 52.57 | 78.00 | 48.83 | 70.13 | 69.60 | 54.54 | 2.81 | 20.43 | 47.55 |
| DeltaNet | 32.85 | 58.16 | 42.51 | 81.00 | 61.13 | 73.78 | 43.90 | 61.72 | 11.80 | 46.08 | 51.29 |
| Mamba2 | 32.24 | 59.64 | 58.72 | 82.00 | 62.23 | 73.78 | 79.80 | 62.19 | 7.71 | 41.13 | 55.94 |
| Gated DeltaNet | 32.59 | 60.02 | 62.75 | 82.00 | 62.80 | 74.32 | 80.60 | 62.35 | 8.26 | 44.28 | 57.00 |
| Gated KalmaNet (Ours) | 32.51 | 59.89 | 61.68 | 85.00 | 63.84 | 74.81 | 83.20 | 64.17 | 12.89 | 50.95 | 58.89 |
5.3 GKA on Language Modeling
5.3.1 Short-context Tasks
Setup. For this set of experiments, we construct 2.8B LLMs for each memory layer (GKA and the baselines described in Section 5) by cascading blocks of mem + Multi-Layer Perceptron (MLP) blocks.² Hereafter, we refer to each 2.8B model by the name of the layer used to construct it. We then train each model on DCLM [li2024datacomp], a generic pre-training dataset, for B tokens at K context length using the AdamW optimizer with a peak Learning Rate (LR) of and gradient clipping of . We use the cosine LR scheduler with a warmup period of B tokens and a global batch size of M tokens. All models employ the GPT2 tokenizer with a vocabulary size of K tokens.
² For the Mamba2 baseline, we cascade Mamba2 layers alone, since a single Mamba2 layer already contains the Mamba2 SSM and an MLP.
Tasks. Following prior work [Zancato-NeurIPS2024-BMOJO, Yang-ICML2024-gla, Yang-ICLR2025], to assess the language modeling capabilities of our model, we perform zero-shot evaluation on the following eight common-sense reasoning tasks from LM-Harness [eval-harness]: Arc-E, Arc-C, BoolQ, COPA, HellaSWAG, PIQA, SciQ, Winogrande. We also evaluate models on FDA and SWDE, real-world recall-intensive tasks that focus on extracting structured information, such as tagged content, from raw text (for example, HTML files). All these tasks are relatively short (K tokens).
Results. We report our results in Table 2. GKA outperforms all fading memory baselines on average across all tasks, owing to its ability to better manage its state by solving (3). In particular, GKA outperforms both GDN and Mamba2 on recall-intensive tasks (FDA and SWDE) by about (rel. improvement). We note that although GKA improves upon existing SSM layers, there is still a gap with the Attention-based Transformer, especially on recall tasks, owing to the eidetic capabilities of Attention. Nevertheless, as discussed in Section 1, Attention's advantage comes at a quadratic cost at training time, whereas our layer's computational complexity remains comparable to existing SSM layers (cf. Section 5.2). In Appendix I, we extend our results to hybrid models (stacks of SSM and Attention layers) and show that the gap with full Transformer models becomes negligible, while still benefiting from the SSMs' computational advantages. Finally, in Appendix E we show that GKA exhibits stronger scaling with compute than other SSM baseline models.
5.3.2 Long-context Tasks
Setup. To enable long-context capabilities of our models, as is common practice, we perform continued pre-training of our 2.8B models obtained in Section 5.3.1 on B tokens of long documents at K context length (cf. Appendix). To the best of our knowledge, we are the first to train and evaluate SSM models up to K context (e.g., previous work [Yang-ICLR2025] only considered up to K/K context).
Tasks. For long-context evaluation, we refrain from using perplexity, as it is known to have limitations in assessing the long-context performance of LLMs [nunez2024expansion, fang2024wrong, gao2025train]. Instead, we turn to recently proposed benchmarks that mix synthetic and real datasets spanning several long-context task types: synthetic Recall, Retrieval-Augmented Generation (RAG), Many-shot In-Context Learning (ICL), and Long Question-Answering (LongQA). For synthetic Recall and LongQA we consider tasks from the RULER benchmark [hsieh2024ruler]. For RAG and ICL we consider tasks from HELMET [yen2025helmet].
Results. Fig. 4 reports our results. GKA shows strong RAG and LongQA capabilities, outperforming all fading memory baselines by at least % (rel. improvement). Interestingly, on the synthetic Recall tasks from RULER, GKA is competitive only at K context length and falls behind at longer contexts. We attribute this divergence to fundamental differences between these task types. While both RAG and LongQA can be thought of as finding relevant information in long streams of text, they involve realistic linguistic patterns and semantic relationships that align with the natural text distributions seen during pretraining. In contrast, synthetic Recall tasks require models to retrieve specific words, numbers, or UUIDs verbatim from long contexts filled with random distractors, an artificial setting that does not reflect natural text distributions. Since GKA computes MAP estimates of the latent state based on learned representations of the observed tokens, it relies on its pretrained weights to determine which information is important to retain. In synthetic Recall tasks, pretrained knowledge provides little signal about what is important, which makes it difficult for the model to decide what to retain. This suggests that our approach excels in realistic scenarios where pretrained knowledge about natural language structure can guide information selection, but struggles when the distinction between signal and noise is purely artificial.
6 Kalman Filter for Optimally Modelling Fading Memory
In this section, we show how the Kalman Filter (KF) provides a principled solution for constructing an optimal fading memory that accounts for the entire history. We begin by describing the standard Kalman Filter recurrence in the context of memory modeling. However, the KF has a fundamental limitation: its inherently sequential nature makes it impractical for large-scale training on modern hardware accelerators (Section 3.2). To address this, we make simplifying assumptions that make the KF amenable to parallelization on such hardware. We then demonstrate that several recent state-space models (DeltaNet, Gated DeltaNet, and Kimi Delta Attention) can be viewed as approximations to the KF recurrence. Specifically, these methods approximate the “optimal” Kalman gain while ignoring its dependence on the past. In contrast, GKA computes the exact Kalman gain by considering the full history. This theoretical advantage translates into improved empirical performance, as we demonstrate in Section 5.
6.1 A Dynamical System for Fading Memory
The Kalman filter is a classical algorithm for online optimal inference in linear Gaussian State-Space Models. It gives a principled way to maintain and update a compact state as new noisy observations arrive. The latent state serves as a compressed "memory" of the past. More formally, it is a minimal sufficient statistic that makes past observations conditionally independent of future ones given the state.
We begin by describing a linear Gaussian model for fading memory.
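$$
\begin{aligned}
s_t &= A_t s_{t-1} + B_t u_t + w_t, & w_t &\sim \mathcal{N}(0, Q_t),\\
y_t &= c_t^\top s_t + \varepsilon_t, & \varepsilon_t &\sim \mathcal{N}(0, r_t),
\end{aligned} \tag{LGM}
$$
where $s_t \in \mathbb{R}^{n}$ is a latent state that summarizes the past, $u_t$ is the control input that updates the state, and $y_t \in \mathbb{R}$ is the scalar measurement observed at time $t$. $A_t \in \mathbb{R}^{n \times n}$ and $B_t$ are the state transition and input selection matrices, and $c_t \in \mathbb{R}^{n}$ is the emission (readout) vector. Finally, $w_t$ and $\varepsilon_t$ are the Gaussian process and measurement noise, with covariance $Q_t$ and variance $r_t$, respectively.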
Parameter interpretation. $A_t$ and $B_t$ control the forgetting (fading of the remote past) and the input selectivity, respectively, determining how the state evolves over time. The measurement noise variance $r_t$ naturally gives rise to the gating mechanisms commonly used in modern SSM layers, as we show in Section 6.4.
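Extension to multi-channel measurements. In attention mechanisms, the memory consists of verbatim key-value pairs that can be queried to retrieve past information [Zancato-NeurIPS2024-BMOJO]. Similarly, we want our state to reconstruct past values from their corresponding keys. To achieve this, we extend the state to a matrix $S_t = [s_t^{(1)}, \dots, s_t^{(d_v)}] \in \mathbb{R}^{d_k \times d_v}$, where each column independently follows the dynamics in Eq. (LGM).
Specifically, for the $i$-th channel:
$$
s_t^{(i)} = A_t s_{t-1}^{(i)} + B_t u_t^{(i)} + w_t^{(i)}, \qquad v_{t,i} = k_t^\top s_t^{(i)} + \varepsilon_{t,i},
$$
where $(k_t, v_t) \in \mathbb{R}^{d_k} \times \mathbb{R}^{d_v}$ is the key-value pair at time $t$ and $v_{t,i}$ is the $i$-th element of $v_t$. In other words, the key plays the role of the emission vector ($c_t = k_t$, with $n = d_k$) and each entry of the value is a scalar measurement ($y_t = v_{t,i}$). In what follows, we focus on a single channel and drop the superscript $(i)$ from the state for notational clarity.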
6.2 Kalman Filter for Optimal Inference
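Given the model in Eq. (LGM) and a sequence of measurements $y_1, \dots, y_t$, the Kalman Filter computes the Maximum A-Posteriori (MAP) estimate of the latent state at time $t$:
$$
\hat{s}_t = \arg\max_{s} \; p(s_t = s \mid y_1, \dots, y_t), \tag{11}
$$
where $p(\cdot \mid \cdot)$ denotes the conditional probability density function. Because the model is linear and Gaussian, this posterior is Gaussian, so its mode coincides with its mean; the MAP estimate is therefore optimal in the sense that it minimizes the expected squared error between the true state and its estimate given all measurements up to time $t$.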
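The KF recursion. The Kalman Filter updates the state estimate recursively as new measurements arrive. At time $t$, the update is:
$$
\hat{s}_{t|t-1} = A_t \hat{s}_{t-1} + B_t u_t, \qquad \hat{s}_t = \hat{s}_{t|t-1} + g_t e_t, \qquad e_t = y_t - c_t^\top \hat{s}_{t|t-1}, \tag{12}
$$
where the innovation $e_t$ measures the discrepancy between the actual measurement $y_t$ and the measurement $c_t^\top \hat{s}_{t|t-1}$ predicted from the predicted state estimate $\hat{s}_{t|t-1}$.
The Kalman gain $g_t \in \mathbb{R}^{n}$ determines how much to trust the new measurement versus the predicted state. It is computed as follows:
$$
g_t = \frac{P_{t|t-1} c_t}{c_t^\top P_{t|t-1} c_t + r_t}, \qquad P_{t|t-1} = A_t P_{t-1} A_t^\top + Q_t. \tag{13}
$$
The error covariance $P_t$ quantifies the uncertainty in the state estimate: it is the covariance of the estimation error $s_t - \hat{s}_t$ conditioned on all measurements up to time $t$, and $P_{t|t-1}$ is its one-step prediction. The covariance is updated as:
$$
P_t = (I - g_t c_t^\top) P_{t|t-1}. \tag{14}
$$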
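The recursion (12)-(14) fits in a few lines of NumPy. The sketch below is a minimal, illustrative implementation of one predict/update step for a single channel; the function name `kf_step` and its argument layout are ours, not taken from the GKA code. The usage example checks it on a static scalar state, where the posterior variance has the closed form $1/(1/P_0 + t/r)$.

```python
import numpy as np

def kf_step(s_prev, P_prev, y, c, r, A, Q, B=None, u=None):
    """One Kalman Filter step for the scalar-measurement model (LGM).

    s_prev: (n,) previous estimate, P_prev: (n, n) previous error covariance,
    y: scalar measurement, c: (n,) emission vector, r: measurement-noise variance,
    A: (n, n) transition, Q: (n, n) process-noise covariance, B/u: optional control.
    """
    s_pred = A @ s_prev if B is None else A @ s_prev + B @ u  # predicted state, Eq. (12)
    P_pred = A @ P_prev @ A.T + Q                             # predicted covariance, Eq. (13)
    g = P_pred @ c / (c @ P_pred @ c + r)                     # Kalman gain, Eq. (13)
    e = y - c @ s_pred                                        # innovation, Eq. (12)
    s_new = s_pred + g * e                                    # state update, Eq. (12)
    P_new = P_pred - np.outer(g, c @ P_pred)                  # (I - g c^T) P_{t|t-1}, Eq. (14)
    return s_new, P_new

# Usage: a static scalar state (A = 1, Q = 0) observed through c = 1 with noise variance r.
# Starting from prior variance P0, the posterior variance after t steps is 1 / (1/P0 + t/r).
rng = np.random.default_rng(0)
P0, r, true_state = 4.0, 0.5, 1.3
s, P = np.zeros(1), np.full((1, 1), P0)
A, Q, c = np.eye(1), np.zeros((1, 1)), np.ones(1)
for t in range(10):
    y = true_state + rng.normal(scale=np.sqrt(r))
    s, P = kf_step(s, P, y, c, r, A, Q)
print(s[0], P[0, 0])
```

Note that the covariance update does not depend on the measured values $y_t$, only on the emission vectors and noise variances.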
6.3 Gated KalmaNet: A Steady-State Dynamical System for Large-Scale Training
Despite its optimality, the KF recursion in its most general form is inherently sequential; each update depends on the previous state estimate. This sequential dependency prevents the parallelization necessary for efficient large-scale training on modern hardware.
To enable parallelization, we make a key simplifying assumption: the underlying state remains static over time. This reduces the problem from tracking a dynamic state to estimating a fixed but unknown parameter from sequential noisy measurements. Formally, we assume a steady-state model:
$$
s_t = s_{t-1}, \qquad y_t = c_t^\top s_t + \varepsilon_t, \qquad \varepsilon_t \sim \mathcal{N}(0, r_t), \tag{15}
$$
where $A_t = I$, $B_t = 0$, and $Q_t = 0$ (i.e., no state evolution, no control input, and no process noise).
Adapting to evolving context. The steady-state assumption may initially seem restrictive, since contexts naturally evolve as topics change. GKA addresses this through adaptive weighting (Section 4.1): by assigning higher weights to recent measurements, it fades out older observations over time, allowing the model to track a shifting context despite the static formulation.
Under this simplification, the KF recursion reduces to:
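$$
g_t = \frac{P_{t-1} c_t}{c_t^\top P_{t-1} c_t + r_t}, \qquad \hat{s}_t = \hat{s}_{t-1} + g_t \big(y_t - c_t^\top \hat{s}_{t-1}\big), \qquad P_t = (I - g_t c_t^\top) P_{t-1}. \tag{16}
$$
Collecting all channels, these equations can be written compactly in matrix form as shown in (KF) (with the columns of $S_t$ transposed into rows, to be consistent with the notation in (KF), and with the noise variance chosen accordingly). A key insight of this work is that the KF recursion for the steady-state model admits an efficient parallel implementation via chunked processing (detailed in Section 4), which results in Gated KalmaNet.
Critically, the KF recursion accounts for the entire history when computing state estimates. The Kalman gain $g_t$ at each step depends on all previous keys through $P_{t-1}$: for instance, with prior covariance $P_0 = \lambda^{-1} I$ and $\hat{s}_0 = 0$, one has $P_{t-1}^{-1} = \lambda I + \sum_{j<t} c_j c_j^\top / r_j$, and $\hat{s}_t$ is exactly the solution of a weighted ridge regression over all past measurements. This contrasts with most existing SSMs, which, as we show next, can be viewed as approximations that ignore historical dependencies when computing their gain matrices. This principled treatment of the full history is a key advantage of our approach.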
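To make the ridge-regression connection concrete, the NumPy sketch below runs the steady-state recursion (16) on all value channels at once and compares the result with the closed-form weighted ridge solution $(\lambda I + \sum_j k_j k_j^\top / r_j)^{-1} \sum_j k_j v_j^\top / r_j$. It assumes $\hat{s}_0 = 0$ and $P_0 = \lambda^{-1} I$; the values of $\lambda$ and $r_t$ are illustrative, and the sequential loop is for exposition only (GKA itself relies on the chunked solver of Section 4).

```python
import numpy as np

def steady_state_kf(keys, values, lam, r):
    """Steady-state KF recursion (16), run on all value channels at once.

    keys: (T, d_k), values: (T, d_v), lam: prior precision (P_0 = I / lam),
    r: (T,) measurement-noise variances. Returns the final estimate S_T of shape (d_k, d_v).
    """
    d_k, d_v = keys.shape[1], values.shape[1]
    S = np.zeros((d_k, d_v))                  # \hat{s}_0 = 0 for every channel
    P = np.eye(d_k) / lam                     # P_0 = lam^{-1} I, shared by all channels
    for k, v, r_t in zip(keys, values, r):
        g = P @ k / (k @ P @ k + r_t)         # Kalman gain, identical for all channels
        S = S + np.outer(g, v - k @ S)        # per-channel innovation v_i - k^T s^{(i)}
        P = P - np.outer(g, k @ P)            # P_t = (I - g k^T) P_{t-1}
    return S

def weighted_ridge(keys, values, lam, r):
    """Closed form (lam I + sum_j k_j k_j^T / r_j)^{-1} sum_j k_j v_j^T / r_j."""
    w = 1.0 / r
    gram = lam * np.eye(keys.shape[1]) + (keys * w[:, None]).T @ keys
    return np.linalg.solve(gram, (keys * w[:, None]).T @ values)

rng = np.random.default_rng(0)
T, d_k, d_v, lam = 64, 8, 4, 0.1
keys = rng.standard_normal((T, d_k))
keys /= np.linalg.norm(keys, axis=1, keepdims=True)   # unit-normalized keys
values = rng.standard_normal((T, d_v))
r = rng.uniform(0.5, 2.0, size=T)
S_kf = steady_state_kf(keys, values, lam, r)
S_ridge = weighted_ridge(keys, values, lam, r)
print(np.max(np.abs(S_kf - S_ridge)))
```

The inverse noise variances $1/r_t$ act as per-token weights in the regression, which is where the weighting and gating discussed in Section 4.1 enter.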
6.4 Connection with Existing SSM Layers
DeltaNet [Yang-DeltaNet] approximates the KF recursion in (16) by assuming a fixed error covariance, $P_{t-1} = I$ for all $t$. This simplifies the Kalman gain to:
$$
g_t = \frac{k_t}{k_t^\top k_t + r_t} = \frac{k_t}{1 + r_t}, \tag{17}
$$
where the second equality assumes unit-normalized keys ($\|k_t\|_2 = 1$), a common choice in practical instantiations of DeltaNet. Substituting (17) into the state update (16), collecting all channels, and defining $\beta_t = 1/(1 + r_t) \in (0, 1)$ yields:
$$
S_t = S_{t-1} + \beta_t k_t \big(v_t^\top - k_t^\top S_{t-1}\big) = (I - \beta_t k_t k_t^\top) S_{t-1} + \beta_t k_t v_t^\top, \tag{DeltaNet}
$$
which is the DeltaNet recurrence. By fixing $P_t = I$, DeltaNet avoids tracking the evolving uncertainty in the state estimate, a key simplification that sacrifices optimality for computational efficiency. In contrast, GKA maintains the full error covariance $P_t$, allowing it to optimally weight measurements based on the entire history.
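A quick numerical check of this derivation: freezing the covariance at $P_{t-1} = I$ in the KF update and using a unit-norm key gives exactly the DeltaNet update with $\beta_t = 1/(1 + r_t)$. The dimensions and the value of $r_t$ are arbitrary illustrative choices; the state is stored as $S_t \in \mathbb{R}^{d_k \times d_v}$, with one column per value channel.

```python
import numpy as np

rng = np.random.default_rng(0)
d_k, d_v, r = 6, 3, 0.25
k = rng.standard_normal(d_k)
k /= np.linalg.norm(k)                       # unit-normalized key
v = rng.standard_normal(d_v)
S_prev = rng.standard_normal((d_k, d_v))     # arbitrary previous state, columns = value channels

# KF state update (16) for all channels with the covariance frozen to P_{t-1} = I
g = k / (k @ k + r)                          # Eq. (17); equals k / (1 + r) for a unit key
S_kf = S_prev + np.outer(g, v - k @ S_prev)

# DeltaNet recurrence with beta = 1 / (1 + r)
beta = 1.0 / (1.0 + r)
S_delta = (np.eye(d_k) - beta * np.outer(k, k)) @ S_prev + beta * np.outer(k, v)
print(np.max(np.abs(S_kf - S_delta)))
```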
Gated DeltaNet (GDN) [Yang-ICLR2025] extends DeltaNet by incorporating explicit forgetting through a time-dependent decay factor $\alpha_t$. Like DeltaNet, GDN can be viewed as fixing the covariance to the identity, but it applies this approximation to the KF recursion of a fading dynamical system in which the state decays over time.
Specifically, GDN assumes
$$
s_t = \alpha_t s_{t-1}, \qquad y_t = k_t^\top s_t + \varepsilon_t, \tag{18}
$$
where $\alpha_t \in [0, 1]$ is a learned decay factor controlling how much past information to retain. This corresponds to setting $A_t = \alpha_t I$, $B_t = 0$, and $Q_t = 0$ in (LGM). When $\alpha_t = 0$, the state "forgets" the past completely; when $\alpha_t = 1$, the state is fully retained.
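Under the identity covariance assumption $P_{t|t-1} = I$, the Kalman gain becomes:
$$
g_t = \frac{k_t}{k_t^\top k_t + r_t} = \frac{k_t}{1 + r_t}, \tag{19}
$$
where the second equality again assumes unit-normalized keys (as in DeltaNet). Defining $\beta_t = 1/(1 + r_t)$ and substituting into the state update (12), with predicted state $\hat{s}_{t|t-1} = \alpha_t \hat{s}_{t-1}$, yields (after collecting all channels):
$$
S_t = \alpha_t S_{t-1} + \beta_t k_t \big(v_t^\top - \alpha_t k_t^\top S_{t-1}\big) = \alpha_t (I - \beta_t k_t k_t^\top) S_{t-1} + \beta_t k_t v_t^\top, \tag{GDN}
$$
which recovers the GDN recurrence. In practice, $\alpha_t$ is an input-dependent learnable parameter.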
Like DeltaNet, GDN avoids tracking the evolving uncertainty $P_t$, trading optimality for computational simplicity. The key difference is that GDN's explicit forgetting factor $\alpha_t$ provides additional control over the memory horizon. However, by fixing $P_{t|t-1} = I$, GDN still ignores how the measurement history should optimally influence the Kalman gain, leading to suboptimal performance compared to GKA (see Section 5).
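The same check works for GDN: predicting with $A_t = \alpha_t I$ and then applying the frozen-covariance gain (19) reproduces the GDN recurrence. All numbers below are illustrative.

```python
import numpy as np

rng = np.random.default_rng(1)
d_k, d_v, r, alpha = 6, 3, 0.25, 0.9
k = rng.standard_normal(d_k)
k /= np.linalg.norm(k)                       # unit-normalized key, so k^T k = 1
v = rng.standard_normal(d_v)
S_prev = rng.standard_normal((d_k, d_v))
beta = 1.0 / (1.0 + r)

# KF view: predict with A_t = alpha I, then update with the frozen-covariance gain (19)
S_pred = alpha * S_prev
S_kf = S_pred + np.outer(k / (1.0 + r), v - k @ S_pred)

# Gated DeltaNet recurrence: alpha (I - beta k k^T) S_{t-1} + beta k v^T
S_gdn = alpha * (np.eye(d_k) - beta * np.outer(k, k)) @ S_prev + beta * np.outer(k, v)
print(np.max(np.abs(S_kf - S_gdn)))
```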
Kimi Delta Attention (KDA) [team2025kimi] further extends GDN by using channel-specific decay factors $\boldsymbol{\alpha}_t \in [0, 1]^{d_k}$ in place of the scalar $\alpha_t$. This allows different channels to have independent memory horizons. In the KF framework, this corresponds to:
$$
s_t^{(i)} = \mathrm{Diag}(\boldsymbol{\alpha}_t)\, s_{t-1}^{(i)}, \qquad v_{t,i} = k_t^\top s_t^{(i)} + \varepsilon_{t,i}, \tag{20}
$$
for each channel $i$. While this added flexibility can improve expressiveness, KDA still assumes $P_{t|t-1} = I$ and therefore does not optimally account for the entire past when computing its state update. Like DeltaNet and GDN, KDA sacrifices optimality for computational simplicity.
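For KDA, the decay becomes a diagonal transition shared by all value channels. The sketch below applies the frozen-covariance KF update to (20), which gives $S_t = (I - \beta_t k_t k_t^\top)\,\mathrm{Diag}(\boldsymbol{\alpha}_t) S_{t-1} + \beta_t k_t v_t^\top$, and checks that it collapses to GDN when all decays are equal. The function names are hypothetical, and the sketch covers only the recurrence, not other implementation details of KDA.

```python
import numpy as np

def kda_style_step(S_prev, k, v, alpha_vec, beta):
    """Frozen-covariance KF update for (20): (I - beta k k^T) Diag(alpha) S_{t-1} + beta k v^T."""
    S_pred = alpha_vec[:, None] * S_prev              # Diag(alpha) S_{t-1}: per-key-dimension decay
    return S_pred + beta * np.outer(k, v - k @ S_pred)

def gdn_step(S_prev, k, v, alpha, beta):
    """Gated DeltaNet recurrence with a scalar decay alpha."""
    return alpha * (S_prev - beta * np.outer(k, k @ S_prev)) + beta * np.outer(k, v)

rng = np.random.default_rng(2)
d_k, d_v, beta = 6, 3, 0.8
k = rng.standard_normal(d_k)
k /= np.linalg.norm(k)
v = rng.standard_normal(d_v)
S_prev = rng.standard_normal((d_k, d_v))
alpha_vec = rng.uniform(0.5, 1.0, size=d_k)           # channel-specific decays
S_kda = kda_style_step(S_prev, k, v, alpha_vec, beta)

# With identical decays in every key dimension, the update collapses to GDN.
S_uniform = kda_style_step(S_prev, k, v, np.full(d_k, 0.9), beta)
S_gdn = gdn_step(S_prev, k, v, 0.9, beta)
```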
7 Discussions and Limitations
Thanks to its expressive test-time ridge regression objective, Gated KalmaNet extends previous fading memory layers such as Mamba2, LongHorn, and Gated DeltaNet, all of which depend only on an instantaneous test-time objective. However, GKA is optimal only among linear memory layers; solving our test-time objective with non-linear updates while maintaining hardware efficiency and numerical stability is an interesting direction for future research. Despite the efficient kernels we implemented, we believe even faster implementations of our idea are possible, e.g., via sketching (see Appendix H for preliminary results). Finally, while we have shown promising results in combining GKA with Attention layers into Hybrid models (Appendix I), further scaling beyond 3B-parameter models is required to validate GKA on more challenging real-world problems.
Appendix A Related Work
Since the introduction of Self-Attention [Vaswani-NeurIPS2017], significant research has been conducted to reduce its quadratic cost in processing long input sequences. As models and systems scale to million-token contexts, Attention's bottlenecks have become critical blockers to frontier agentic applications in coding, information gathering, and scientific discovery [chen2024scienceagentbench, cui2025curie, jimenez2023swe]. Prior works have proposed various approximation schemes to overcome these limitations. For example, Reformer [kitaev2020reformer] uses locality-sensitive hashing to group tokens with similar embeddings. This enables the model to attend only to a subset of tokens rather than the entire sequence. Other works equip Transformer models with "compressed" memory tokens that are updated dynamically and causally over sliding windows on entire sequence chunks [dai2019transformerxl, munkhdalai2024leave, mohtashami2023landmark]. While much prior work has focused on reducing the quadratic complexity of Attention with sparse approximations [nunez2024expansion, yuan2025NSA], this work focuses on linear approximations of Attention.
A.1 Linear Attention
Linear Attention methods approximate the Attention mechanism with constant-size recurrent dynamical systems [Dao-ICML2024-mamba2, Yang-ICML2024-gla, beck2024xlstm, Yang-ICLR2025]. Numerous State-Space Model (SSM) variations have been proposed, ranging from those closely resembling Linear Attention [sun2023retentive] or Linear Time-Invariant dynamical systems [gu2021combining, zancato2022stacked], to those introducing novel adaptive or gated state updates [Yang-ICML2024-gla, Dao-ICML2024-mamba2, orvieto2023resurrecting].
Despite their differences, all SSMs follow the same basic working principle inspired by classical state-space models [Kalman-1960]: they process the input sequence by maintaining a fixed-size state that acts as a compressed (lossy) representation of all processed tokens. Moreover, when implemented in hardware, the state must have finite precision and “fades away the past” as more samples are processed. Successful SSM layers typically employ hardware-aware implementations that efficiently utilize modern matrix multiplication accelerators through highly parallelizable and scalable primitives, including associative scans [gu2023mamba, de2024griffin], chunking mechanisms [Dao-ICML2024-mamba2, Yang-ICML2024-gla], and techniques that avoid materializing the entire state in slow high-bandwidth memory [gu2023mamba].
From a modeling perspective, most Linear Attention implementations introduce data-dependent gating factors to control the speed of their “fading” memory, balancing expressivity with scalability. For example, the transition from Mamba to Mamba2 replaced channel-wise data-dependent gating with head-wise gating for better scalability and Tensor Core utilization. Input-dependent gating has been shown to empirically improve training stability [arora2023zoology, Yang-ICLR2025] and has driven the development of Linear Attention models (e.g., from S4 [alber_gu_s4] to Mamba [gu2023mamba] and from DeltaNet [Yang-DeltaNet] to Gated DeltaNet [Yang-ICLR2025]). In our work, we demonstrate that gating emerges naturally as a consequence of solving a weighted least squares objective, establishing a connection to the favorable numerical properties classically described in the adaptive filtering literature [LJUNG_RLS_stability, sayed2003fundamentals, sayed2011adaptive].
A.2 Hybrid State Space Attention Models
While extending the recurrent state in SSM layers has yielded performant models, they typically underperform on tasks requiring recall of information from the distant past [waleffe2024empirical, jelassi2024repeat]. Hybrid State-Space Models address this limitation by complementing the SSMs' “fading” state with Attention layers [dao2024transformers, de2024griffin, lieber2024jamba, glorioso2024zamba]. Early architectures simply stacked SSM and Attention layers with different blending ratios [waleffe2024empirical, gu2023mamba, Dao-ICML2024-mamba2] or replaced full Attention layers with Sliding Window Attention [de2024griffin]. More sophisticated designs have recently emerged [glorioso2024zamba, Zancato-NeurIPS2024-BMOJO].
Notably, B’MOJO [Zancato-NeurIPS2024-BMOJO] complements SSMs’ fading state with "eidetic" memory by combining SSMs with Sliding Window Attention (SWA) in a single layer. Within the window, tokens can attend to a selected set of past tokens that were deemed difficult to predict using an asynchronous causal selection mechanism. B’MOJO was the first hybrid model to propose a parallel fusion of SSM and SWA at the layer level. Subsequent works [dong2024hymba, bae2025hybrid] have shown this parallel fusion approach to be more performant (at equivalent compute) than the stacked approach of earlier works.
Thanks to their lower memory footprint and test-time scalability over long sequences, Hybrid architectures are expanding into long-range agentic tasks and have recently been trained with Reinforcement Learning at scale [chen2025minimax]. When coupled with system-level optimizations like prefix caching [pan2024marconi] and specialized inference engines [kwon2023efficient], Hybrid models can increase the number of rollouts (exploration), thereby improving end-to-end performance in Reinforcement Learning loops.
Appendix B Forward and Backward Passes of Chebyshev Iteration (Details)
In Section 4.2 we described our chunk-wise implementation of the CH method with adaptive regularization and gating. We now give the details omitted there.
B.1 Forward Pass
CH in Detail. We begin by describing the CH method (Algorithm 1) in more detail. Assume we have a linear system of equations $A x = b$, where $A$ is a symmetric positive definite matrix whose eigenvalues all lie in the interval $[\mu, L]$, with $\mu$ and $L$ known. Note that solving this system is equivalent to solving the following quadratic problem:
$$
\min_{x} \; \tfrac{1}{2} x^\top A x - b^\top x. \tag{21}
$$
The classic Chebyshev Iteration in its standard form is presented in Algorithm 1. In the initialization phase, we set $\rho = (L - \mu)/(L + \mu)$, which is the typical convergence rate of gradient descent applied to the above quadratic problem with stepsize $2/(L + \mu)$; roughly speaking, this stepsize choice is optimal in this setting (i.e., it allows gradient descent to converge as fast as possible). Algorithm 1 initializes two points, $x_0$ and $x_1$. Here $x_0$ is zero, and $x_1$ is a gradient step for Eq. (21) starting at $x_0$ with stepsize $2/(L + \mu)$. The final component of the initialization is the weight $\omega_1 = 2$, which is the starting point of the weight-schedule recursion for $\omega_t$ in Eq. (weight schedule). Similarly, the initialization of $x_0, x_1$ is where we start to compute $x_t$, whose update consists of Eq. (grad descent) and Eq. (momentum). Note that Eq. (grad descent) is a gradient step with stepsize $2\omega_{t+1}/(L + \mu)$. Since $\omega_{t+1} > 1$, this stepsize is strictly larger than $2/(L + \mu)$, the latter being the optimal stepsize for vanilla gradient descent. Such a large stepsize alone might not guarantee convergence, but it is balanced by the Eq. (momentum) term with positive weight $\omega_{t+1} - 1$, so that the convergence of the Chebyshev iterative method is ensured.
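The description above corresponds to the textbook three-term Chebyshev recurrence. The NumPy sketch below is written from that description, not from the GKA kernels: it assumes Eq. (weight schedule) takes the standard form $\omega_{t+1} = 1/(1 - \rho^2 \omega_t / 4)$, and the function name as well as the convention that `num_iters` counts the iterates $x_1, \dots, x_T$ are our own choices.

```python
import numpy as np

def chebyshev_solve(A, b, mu, L, num_iters):
    """Chebyshev Iteration for A x = b, with A symmetric and spectrum in [mu, L].

    num_iters counts the iterates x_1, ..., x_{num_iters} produced after x_0 = 0.
    """
    gamma = 2.0 / (L + mu)                     # optimal gradient-descent stepsize
    rho = (L - mu) / (L + mu)                  # gradient-descent contraction factor
    x_prev = np.zeros_like(b)                  # x_0 = 0
    x = x_prev - gamma * (A @ x_prev - b)      # x_1: gradient step from x_0
    omega = 2.0                                # omega_1 seeds the weight schedule
    for _ in range(num_iters - 1):
        omega = 1.0 / (1.0 - rho**2 * omega / 4.0)          # weight schedule
        y = x - omega * gamma * (A @ x - b)                  # gradient step, stepsize omega * gamma
        x, x_prev = y + (omega - 1.0) * (x - x_prev), x      # momentum with weight omega - 1
    return x

# Usage: a random SPD matrix with eigenvalues in [1, 10] (condition number at most 10).
rng = np.random.default_rng(0)
n = 32
Qm, _ = np.linalg.qr(rng.standard_normal((n, n)))
A = Qm @ np.diag(rng.uniform(1.0, 10.0, size=n)) @ Qm.T
b = rng.standard_normal(n)
x = chebyshev_solve(A, b, mu=1.0, L=10.0, num_iters=30)
x_true = np.linalg.solve(A, b)
rel_err = np.linalg.norm(x - x_true) / np.linalg.norm(x_true)
print(rel_err)
```

Only matrix-vector products and scalar updates appear in the loop, which is what makes the method easy to batch into matrix-matrix products.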
Numerical Stability Considerations. We now analyze the numerical properties of the Chebyshev Iteration. The main computation consists of matrix-vector multiplications; in a batched parallel implementation, these become matrix-matrix multiplications, whose numerical accuracy is well controlled (e.g., in Triton we can specify the precision used by tl.dot). The update of $\omega_t$ in Eq. (weight schedule) might raise numerical concerns, as it involves a division. However, we show that this division operates in a numerically well-behaved range, as $\omega_t$ is decreasing in $t$ yet bounded from below:
Lemma 3.
For any , we have , where is defined as
As a consequence, we have for all .
Proof.
If , then is a scaled identity matrix, and the algorithm simplifies considerably. So we assume in what follows. With we have . Since , we have and therefore . Repeating this argument, we see for all . By the definition of , showing amounts to showing
where is defined as . Note that has two roots, , as defined earlier, and ; are the two fixed points of the update Eq.˜weight schedule. Observing that lies in the interval , and moreover, for any , if we must have
This proves for all . Next, since lies in the interval where decreases, we have . Thus lies in again. We then conclude by induction that for all . ∎
From Lemma 3 we know that the update of in Eq. (weight schedule) does not create numerical issues in a forward pass, as we have for all . Furthermore, we can bound the rate at which converges to :
Lemma 4.
Define . For any , we have
where is defined as
Proof.
From the update rule of in Eq. (weight schedule) and the fixed point property of , we have
Here, (i) follows from the fact that is a fixed point, (ii) follows from Lemma 3 that , (iii) follows from the definition of and the fact , and (iv) follows from the definitions of and . The proof is concluded by unrolling the above recurrence. ∎
Remark 1.
Here, we call the linear convergence rate (or contraction factor) of to . First-order methods for solving converge at best at a rate , and we see converges at an even faster rate. Numerically, assuming , we then have:
Thus, with , the update of in Eq. (weight schedule) converges in at most 20 iterations up to bfloat16 precision.
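The behavior described in Lemmas 3 and 4 and Remark 1 can be checked numerically. The sketch below assumes the standard weight schedule $\omega_{t+1} = 1/(1 - \rho^2\omega_t/4)$ with $\omega_1 = 2$ and $\rho = (\kappa - 1)/(\kappa + 1)$, whose smaller fixed point is $\omega^\star = 2/(1 + \sqrt{1 - \rho^2})$. Subtracting the fixed-point equation gives $\omega_{t+1} - \omega^\star = \tfrac{\rho^2}{4}\,\omega_{t+1}\omega^\star(\omega_t - \omega^\star)$, so the per-step ratio tends to $\rho^2\omega^{\star 2}/4 = \omega^\star - 1 = \big((\sqrt{\kappa} - 1)/(\sqrt{\kappa} + 1)\big)^2$, the square of the optimal first-order rate. The tolerance $2^{-8}$ (bfloat16 unit roundoff) and the values of $\kappa$ are illustrative, and the printed iteration counts depend on them.

```python
import math

def omega_schedule(kappa, num_iters):
    """Chebyshev weights omega_1, ..., omega_T from the weight-schedule recursion."""
    rho = (kappa - 1.0) / (kappa + 1.0)
    omegas = [2.0]                                    # omega_1 = 2
    for _ in range(num_iters - 1):
        omegas.append(1.0 / (1.0 - rho**2 * omegas[-1] / 4.0))
    return omegas

def omega_fixed_point(kappa):
    """Smaller fixed point omega* = 2 / (1 + sqrt(1 - rho^2)) of the recursion."""
    rho = (kappa - 1.0) / (kappa + 1.0)
    return 2.0 / (1.0 + math.sqrt(1.0 - rho**2))

def iters_to_precision(kappa, tol=2.0**-8, max_iters=1000):
    """First t with (omega_t - omega*) / omega* <= tol."""
    w_star = omega_fixed_point(kappa)
    for t, w in enumerate(omega_schedule(kappa, max_iters), start=1):
        if (w - w_star) / w_star <= tol:
            return t
    return None

for kappa in (10.0, 100.0, 1000.0):
    rate = ((math.sqrt(kappa) - 1) / (math.sqrt(kappa) + 1)) ** 2   # equals omega* - 1
    print(f"kappa={kappa:g}: omega*={omega_fixed_point(kappa):.4f}, "
          f"contraction={rate:.4f}, iterations={iters_to_precision(kappa)}")
```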
B.2 Backward Pass
We now give details for backpropagation through the Chebyshev Iteration (Algorithm˜1) via implicit differentiation.
Computing and . First, we follow Table 1 and Lemma 2 and compute . Then, given the equation , we have that
| (22) |
Therefore . Since we set , this indicates
| (23) |
Note that this expression for is partial: it accounts only for the upstream gradient from , whereas all subsequent states also depend on . We accumulate these gradients later when needed.
Now, the recursion of in Eq. (CH) implies
| (24) | ||||
| (25) |
which proves Lemma 1. We refer the reader to Section B.2.1 for more detailed derivations of and .
Derivatives for Gating. In practice we often parameterize in log space to ensure numerical stability. Thus, let us first revise our notation for this case. Let and . Then the mask matrix is
| (26) |
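A common way to build such a log-space gating mask, assumed here for illustration since the exact definitions above are not reproduced, is $M_{ij} = \exp(c_i - c_j) = \prod_{l=j+1}^{i} \alpha_l$ for $i \ge j$ (and $0$ otherwise), with $c = \mathrm{cumsum}(\log \alpha)$. In this parameterization the basic derivative is $\partial M_{ij} / \partial \log \alpha_l = M_{ij}$ if $j < l \le i$ and $0$ otherwise; the sketch below checks both facts numerically. The function names are hypothetical.

```python
import numpy as np

def decay_mask(log_alpha):
    """Gating mask M[i, j] = prod_{l=j+1}^{i} alpha_l for i >= j, and 0 otherwise.

    Computed in log space from c = cumsum(log_alpha): M[i, j] = exp(c_i - c_j).
    """
    c = np.cumsum(log_alpha)
    T = len(log_alpha)
    lower = np.tril(np.ones((T, T), dtype=bool))
    # Mask the upper triangle with -inf before exponentiating (exp(-inf) = 0, no overflow).
    return np.exp(np.where(lower, c[:, None] - c[None, :], -np.inf))

def decay_mask_grad(M, l):
    """dM[i, j] / d log_alpha_l = M[i, j] if j < l <= i, else 0."""
    rows, cols = np.indices(M.shape)
    return np.where((cols < l) & (l <= rows), M, 0.0)

rng = np.random.default_rng(0)
T = 6
log_alpha = np.log(rng.uniform(0.5, 1.0, size=T))    # gates alpha in (0.5, 1)
M = decay_mask(log_alpha)
```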
Now, since for any we have
| (27) |
for any we have the following basic derivatives:
| (28) | ||||
| (29) |
With being the aggregated gradient from the future, we have for that
| (30) | ||||
| (31) | ||||
| (32) | ||||
| (33) |
Note that in one of the above equations we add and subtract the term , which will simplify the implementation.
Recall that with . In computing the derivatives of , the first term is the standard term that arises in the derivatives of Eq. (Linear-SSM), which we omit here. We now focus on the second term . The gradients and are then partly given, respectively, by (using the notation of Section 4.2.2 and omitting some algebraic steps)
| (34) |
Computing the first term in parallel is easy by invoking the definition of and the Frobenius norm of that we stored during the forward pass. Computing the quadratic terms and in parallel is also easy and follows from our computation of and for in Section 4.2.2. Computing is easy since we recompute the initial states of each chunk and have them available during the backward pass, while is updated backwards in a for loop.
B.2.1 Computing and .
In the forward pass we solve
| (35) | ||||
Recall that the gradient is the transpose of the Jacobian, thus we obtain
| (36) |
Thus, we can obtain by running a Chebyshev iteration to solve (for ) the linear system of equations
Now we have
| (37) | ||||
In the last equality we have used the identity .
Now we will compute the Jacobian of with respect to :
| (38) | ||||
Now, with gating, we have , which can be unrolled as
| (42) |
We will compute for some ,
| (43) |
Computing for some
| (44) |
Taking differentials on both sides,
| (45) | ||||
where in the last equality we used the identity .
| (46) |
Substituting the expression for into equation (46) we get:
| (47) | |||||
Note that the following equations hold:
| (48) | ||||
since and the fact that the Kronecker products after the simplification are a scalar times a vector.
For the other terms it holds:
| (49) | ||||
where we used the fact and the fact that the vec operator applied to a row vector returns the same result as applying it on its transpose (so we go from to ). Since is symmetric we can sum both contributions and get twice that amount.
Eventually we get:
| (50) |
Or equivalently, collecting the terms that are linear in the gradient:
| (51) |
Note: The last term creates a dependence on through , which is expected since the regularization couples the gradient computation.
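The NumPy sketch below illustrates implicit differentiation through a regularized linear solve, including the coupling noted above. It assumes, as in the data-generation snippet of Appendix D, that the system matrix is the Frobenius-normalized covariance plus a ridge term, $A = H/\|H\|_F + \lambda I$; this is an illustrative stand-in for the adaptive regularization of Section 4.1, not its exact form. With $u$ solving the adjoint system $A u = \partial L/\partial x$, one has $\partial L/\partial b = u$ and $\partial L/\partial A = -u x^\top$, and the chain rule through the normalization gives $\partial L/\partial H = (G - \langle G, \hat{H}\rangle \hat{H})/\|H\|_F$ with $G = \partial L/\partial A$ and $\hat{H} = H/\|H\|_F$. The adjoint solve uses `np.linalg.solve` here, whereas the text above solves it with a Chebyshev iteration.

```python
import numpy as np

def forward(H, b, lam):
    """Solve (H / ||H||_F + lam I) x = b; returns x and the system matrix A."""
    A = H / np.linalg.norm(H) + lam * np.eye(H.shape[0])   # Frobenius-normalized + ridge
    return np.linalg.solve(A, b), A

def backward(H, x, A, grad_x):
    """Implicit-differentiation gradients given the upstream gradient dL/dx."""
    adj = np.linalg.solve(A, grad_x)      # adjoint system A adj = dL/dx (A symmetric)
    grad_b = adj                          # dL/db
    grad_A = -np.outer(adj, x)            # dL/dA = -adj x^T
    nrm = np.linalg.norm(H)
    H_hat = H / nrm
    # Chain rule through H -> H / ||H||_F: the normalization couples every entry of dL/dH.
    grad_H = (grad_A - np.sum(grad_A * H_hat) * H_hat) / nrm
    return grad_b, grad_H

rng = np.random.default_rng(0)
n, lam = 5, 0.1
keys = rng.standard_normal((20, n))
H = keys.T @ keys                         # unnormalized key covariance
b = rng.standard_normal(n)
c = rng.standard_normal(n)                # loss L(x) = c^T x, so dL/dx = c
x, A = forward(H, b, lam)
grad_b, grad_H = backward(H, x, A, c)
```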
Appendix C Proof of Lemma˜2
Lemma 2 describes an interesting phenomenon: for the CH method (Algorithm 1), the gradient obtained from implicit differentiation coincides with the exact gradient obtained via backpropagation (chain rule). One way to prove this result is to derive an analytic expression for (Section C.1) and then inspect the recursions. However, this can be algebraically involved. Here, we present a simpler proof based on a few observations.
First, note that the output is linear in and moreover there is a matrix function such that
| (52) |
Here is a polynomial function of that encodes the Chebyshev iteration (Algorithm 1). Concretely, can be computed by applying the Chebyshev iteration with for iterations (together with other parameters such as ). Then, given the output gradient , we have
| (53) |
where the last equality follows since is symmetric, which implies is symmetric. The proof is finished by observing that can be computed via Algorithm 1 with and other parameters, which gives us .
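Lemma 2 can be sanity-checked numerically: because the Chebyshev output is $x = P(A)\,b$ for a polynomial $P$ and $A$ is symmetric, the Jacobian $J = P(A)$ is symmetric, so the backpropagated gradient $J^\top g$ equals the result of running the same iteration on $g$. The sketch below builds $J$ column by column for a small random SPD matrix with few iterations, so that $J$ is still far from $A^{-1}$. The solver is the textbook Chebyshev recurrence, assumed to match Algorithm 1.

```python
import numpy as np

def chebyshev_solve(A, b, mu, L, num_iters):
    """Textbook Chebyshev Iteration for A x = b, starting from x_0 = 0."""
    gamma, rho = 2.0 / (L + mu), (L - mu) / (L + mu)
    x_prev = np.zeros_like(b)
    x = x_prev - gamma * (A @ x_prev - b)
    omega = 2.0
    for _ in range(num_iters - 1):
        omega = 1.0 / (1.0 - rho**2 * omega / 4.0)
        x, x_prev = x - omega * gamma * (A @ x - b) + (omega - 1.0) * (x - x_prev), x
    return x

rng = np.random.default_rng(0)
n, num_iters = 8, 5                        # few iterations: x is far from A^{-1} b
Qm, _ = np.linalg.qr(rng.standard_normal((n, n)))
A = Qm @ np.diag(rng.uniform(1.0, 10.0, size=n)) @ Qm.T
g = rng.standard_normal(n)                 # upstream gradient dL/dx

# Exact Jacobian of the truncated solver w.r.t. b: x = P(A) b, so column j is P(A) e_j.
J = np.stack([chebyshev_solve(A, e, 1.0, 10.0, num_iters) for e in np.eye(n)], axis=1)
grad_backprop = J.T @ g                                      # exact chain-rule gradient
grad_implicit = chebyshev_solve(A, g, 1.0, 10.0, num_iters)  # same solver applied to g
```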
C.1 The Exact Backward Pass for and
Here we show how to obtain the exact gradients of and in Algorithm 1 given the output gradient , which might be of independent interest. The key insight here is that the Chebyshev iteration can be reversed.
Backward Pass for dq. Let be the identity matrix of suitable size. To derive a backward pass of Algorithm 1, we first write the update of concisely as the following recursion
| (54) |
where are defined as
| (55) |
Note that is symmetric. Define for every . With some loss function , assume we are now given , and our goal is to compute . Since appears in Eq. (54) for every , we know all depend on . Therefore, with , we have
It remains to compute for every . Applying the chain rule to Eq.˜54, we obtain
| (56) |
Note that depend on some constant terms and . Thus, to compute them backward, we assume access to and these constants. By reversing Eq. (weight schedule) we derive the following recursion:
| (57) |
Similarly to how decreases with and converges to , we may prove is convergent to the other fixed point, , as decreases (and the iterate does not stop at ).
Backward Pass for dA. From Eq. (54) and Eq. (55) we see that
| (58) |
where denotes the Kronecker product; since and are vectors, this is their outer product.
Reverse Chebyshev Iteration. At first glance, computing requires storing in the forward pass, and the actual calculation of is done after we run the backward pass for in Eq. (56). However, storing all of these iterates would be memory-inefficient. To address this issue, the main insight is that we can reverse Eq. (54) and write
| (59) |
This implies that we can recover all the iterates as soon as we have access to the last two iterates. Therefore, to obtain , we can run the two iteration schemes in Eq. (56) and Eq. (59) simultaneously.
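The reversal can be sketched as follows, assuming Eq. (54) is equivalent to the three-term recurrence $x_{t+1} = \omega_{t+1}\big(x_t - \gamma(A x_t - b)\big) + (1 - \omega_{t+1})\,x_{t-1}$ with $\gamma = 2/(L + \mu)$. Solving it for $x_{t-1}$ gives $x_{t-1} = \big(x_{t+1} - \omega_{t+1}(x_t - \gamma(A x_t - b))\big)/(1 - \omega_{t+1})$, which only divides by $1 - \omega_{t+1}$, a quantity bounded away from zero by Lemma 3. The forward pass below keeps all iterates solely to check the reconstruction; storing the scalar weights $\omega_t$ instead of recomputing them backward is a simplification on our part.

```python
import numpy as np

def cheb_forward(A, b, mu, L, num_iters):
    """Chebyshev Iteration that also returns every iterate (kept here only for checking)."""
    gamma, rho = 2.0 / (L + mu), (L - mu) / (L + mu)
    xs = [np.zeros_like(b)]
    xs.append(xs[0] - gamma * (A @ xs[0] - b))          # x_1
    omegas = [2.0]                                      # omega_1
    for _ in range(num_iters - 1):
        omegas.append(1.0 / (1.0 - rho**2 * omegas[-1] / 4.0))
        w = omegas[-1]
        # three-term form: x_{t+1} = w (x_t - gamma (A x_t - b)) + (1 - w) x_{t-1}
        xs.append(w * (xs[-1] - gamma * (A @ xs[-1] - b)) + (1.0 - w) * xs[-2])
    return xs, omegas, gamma

def cheb_reverse(A, b, x_last, x_second_last, omegas, gamma):
    """Recover x_{T-2}, ..., x_0 from the last two iterates by inverting the three-term update."""
    recovered = [x_last, x_second_last]
    for w in reversed(omegas[1:]):                      # omega_T, ..., omega_2
        x_next, x_cur = recovered[-2], recovered[-1]
        # division by 1 - w is safe: w > omega* > 1 keeps |1 - w| away from zero
        recovered.append((x_next - w * (x_cur - gamma * (A @ x_cur - b))) / (1.0 - w))
    return recovered[::-1]                              # x_0, ..., x_T

rng = np.random.default_rng(0)
n = 16
Qm, _ = np.linalg.qr(rng.standard_normal((n, n)))
A = Qm @ np.diag(rng.uniform(1.0, 10.0, size=n)) @ Qm.T
b = rng.standard_normal(n)
xs, omegas, gamma = cheb_forward(A, b, 1.0, 10.0, num_iters=20)
rec = cheb_reverse(A, b, xs[-1], xs[-2], omegas, gamma)
max_err = max(np.max(np.abs(r - x)) for r, x in zip(rec, xs))
print(max_err)
```

By contrast, reversing plain gradient descent $x_{t+1} = x_t - \gamma(A x_t - b)$ would require solving a linear system with $I - \gamma A$ at every step.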
Remark 2.
We find that being able to run the iterative update backward in a numerically stable fashion is a main feature of the Chebyshev iterative method (or, more generally, of gradient descent variants with momentum). Vanilla gradient descent cannot efficiently reverse its iterates with stepsize , as this requires inverting . Moreover, reversing Eq. (59) can be done stably, as is often in a good numerical range, which means that division by in Eq. (59) is not an issue. To see this, first note that by Lemma 3 we have
Note that defined in Lemma 3 is an increasing function of and therefore of . We then have that for any (we do not consider the case , as it would require a very large regularization strength, which might harm the minimization of the regression loss). In comparison, reversing the CG iteration would require dividing by a quantity that is often numerically as small as or as large as (see Fig. 5). This is why it is numerically unstable to reverse CG.
| (60) | ||||
| (61) | ||||
| (62) | ||||
| (63) | ||||
| (64) |
Appendix D Experimental Setup
Model Configurations. We consider models of three different sizes: 440M, 1B, and 2.8B, as summarized in Table 3. All models use the GPT2 tokenizer, similarly to [Von-arXiv2025-mesa].
| Model Size | Number of Layers | Number of Heads | Hidden Dimension |
| 440M | 28 | 8 | 1024 |
| 1B | 28 | 12 | 1536 |
| 2.8B | 32 | 20 | 2560 |
Training Configurations. All models are trained with the AdamW optimizer with initial learning rate , warm-up steps, cosine schedule, gradient clipping with maximum norm .
| Model Size | Global Batch Size | Total Number of Training Tokens | Sequence Length |
| 440M | 1M | 8B | 2048 |
| 1B | 2M | 20B | 2048 |
| 2.8B | 2M | 100B | 4096 |
Models of the same scale use the same training configurations. Specifically (see also Table 4):
- For 440M models, we use sequence length 2048 and 8B DCLM tokens.
- For 1B models, we use sequence length 2048 and 20B DCLM tokens.
- For 2.8B models, we use sequence length 4096 and 100B DCLM tokens.
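The token budgets in Table 4 translate directly into optimizer step counts, and the schedule itself is a warmup followed by cosine decay. The sketch below reads "B" as $10^9$ and "M" as $10^6$ tokens (an assumption), and uses a common linear-warmup-plus-cosine form; the peak learning rate and warmup length are not reproduced in this section, so the values passed below are hypothetical placeholders, and the function names are ours.

```python
import math

def lr_at(step, peak_lr, warmup_steps, total_steps, min_lr=0.0):
    """Linear warmup to peak_lr, then cosine decay to min_lr at total_steps."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

def optimizer_steps(total_tokens, tokens_per_batch):
    """Number of optimizer steps implied by a token budget and a global batch size in tokens."""
    return total_tokens // tokens_per_batch

# Table 4 budgets: (total training tokens, global batch size in tokens).
budgets = {
    "440M": (8 * 10**9, 1 * 10**6),
    "1B": (20 * 10**9, 2 * 10**6),
    "2.8B": (100 * 10**9, 2 * 10**6),
}
steps = {name: optimizer_steps(*budget) for name, budget in budgets.items()}
print(steps)

# Hypothetical peak LR and warmup length, for illustration only.
total = steps["2.8B"]
schedule = [lr_at(s, peak_lr=1e-3, warmup_steps=1000, total_steps=total) for s in range(total + 1)]
```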
Model Hyperparameters. For all other models, we use the default parameters given in the Flash-Linear-Attention v0.4.0 library (except those listed in Table 3). For our approach, we use , with gating and -connection enabled by default unless otherwise specified. We also run CH for iterations in all experiments.
Individual Experiments. We now describe the setups for each individual experiment.
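In Fig. 1a, we randomly generate query and key tensors q and k of shape (B, T, H, D) and normalize them along the last dimension. Here B, T, H, and D simulate the batch size, sequence length, number of heads, and head dimension, respectively. Then we compute the running covariance matrices of k, normalize every slice by its Frobenius norm, and add the ridge strength to the diagonal. The code to generate the data is shown below (the sizes, dtype, and ridge strength shown are illustrative values).
```python
import torch

B, T, H, D = 2, 256, 4, 64  # batch, sequence length, heads, head dim (illustrative)
dtype, ridge_strength = torch.bfloat16, 0.02  # illustrative values
device = 'cuda' if torch.cuda.is_available() else 'cpu'
k = torch.randn(B, T, H, D).to(dtype).to(device)
q = torch.randn(B, T, H, D).to(dtype).to(device)
q = q / torch.linalg.vector_norm(q, dim=-1, keepdim=True).to(q)  # unit-norm queries
k = k / torch.linalg.vector_norm(k, dim=-1, keepdim=True).to(k)  # unit-norm keys
kk = torch.einsum('...i,...j->...ij', k, k).cumsum(1)  # running sum of k k^T over time
kk = kk / torch.linalg.matrix_norm(kk, ord='fro')[..., None, None]  # Frobenius-normalize each slice
kk.diagonal(dim1=-2, dim2=-1).add_(ridge_strength)  # add the ridge term to the diagonal
```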
For Fig.˜1c and Fig.˜1e, we generate random input ids with vocabulary size 5000, sequence length 2048 within a 5-layer LLAMA; we set 2 heads and head dimension 128 for this architecture.
In the MQAR experiments of Fig.˜3a, we follow the standard experimental setting but consider a strictly harder setting with smaller model dimension (or hidden dimension). Indeed, in the setting of [arora2023zoology], the model dimension is always larger than or equal to the number of KV pairs, while in the setting here, in some cases the model dimension is smaller than the number of KV pairs, in which case linear SSMs could not perfectly memorize all KV pairs.
In the main paper, Fig.˜3a is without any gating or -connection.
In Fig.˜4 we considered the following tasks for long context evaluations. Reported results for each task is average over the score obtained for individual datasets in that task.
- Retrieval-Augmented Generation (RAG): These tasks consist of open-domain question answering, where the model is given a gold passage (a passage containing the answer) interspersed among many other passages retrieved from a corpus (a Wikipedia dump split into 100-word passages [petroni2021kilt]). The model must answer the question based on the provided passages. We consider the following datasets from HELMET [yen2025helmet] for this task: Natural Questions, TriviaQA, PopQA, HotpotQA.
- Many-shot In-Context Learning (ICL): ICL tests an LLM's ability to learn new skills from examples provided in its context. Here the task is to classify inputs into different concepts based on several in-context examples of each concept. We consider the following datasets from HELMET [yen2025helmet] for this task: TREC Coarse, TREC Fine, NLU, BANKING77, CLINIC150.
- Synthetic Recall: These tasks are variations of the “Needle-in-a-Haystack” task [needle2024haystack], where the goal is to retrieve an important piece of information (the “needle”) from a long context of distractor tokens (the “haystack”). These variations also test the model's multi-hop tracing and aggregation capabilities. We consider the following datasets from RULER [hsieh2024ruler] for this task: S-NIAH-1/2/3, MK-NIAH-1/2/3, MV-NIAH, MQ-NIAH, VT, CWE, FWE.
- LongQA: These are long-document question-answering tasks. The documents are typically made long by randomly sampling other paragraphs from the same dataset and adding them to the paragraph that contains the answer. We consider the following datasets from RULER [hsieh2024ruler] for this task: SQuAD, HotpotQA.
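The aggregation described above (one score per task, averaged over that task's datasets) can be written as a small configuration plus an unweighted mean. This is a sketch: the dictionary layout and the helper `task_scores` are hypothetical, and datasets are keyed by (benchmark, name) because HotpotQA appears both in the HELMET RAG task and in the RULER LongQA task.

```python
from statistics import mean

# Task -> datasets, keyed by (benchmark suite, dataset name), as listed above.
LONG_CONTEXT_TASKS = {
    "RAG": [("HELMET", n) for n in ("Natural Questions", "TriviaQA", "PopQA", "HotpotQA")],
    "Many-shot ICL": [
        ("HELMET", n) for n in ("TREC Coarse", "TREC Fine", "NLU", "BANKING77", "CLINIC150")
    ],
    "Synthetic Recall": [
        ("RULER", n)
        for n in ("S-NIAH-1", "S-NIAH-2", "S-NIAH-3", "MK-NIAH-1", "MK-NIAH-2", "MK-NIAH-3",
                  "MV-NIAH", "MQ-NIAH", "VT", "CWE", "FWE")
    ],
    "LongQA": [("RULER", n) for n in ("SQuAD", "HotpotQA")],
}


def task_scores(dataset_scores: dict) -> dict:
    """Average per-dataset scores into one score per task (unweighted mean)."""
    return {
        task: mean(dataset_scores[key] for key in keys)
        for task, keys in LONG_CONTEXT_TASKS.items()
    }
```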
Appendix E How Does the Performance of GKA Scale with Compute?
We consider models at three scales: 440M, 1B, and 2.8B. For training configurations and architectures, refer to Appendix D. We use prototypical tasks from LM-Harness (see Section 5.3.1 for the list of tasks) to evaluate the language modeling capabilities of GKA and compare it with baseline SSM/fading memory layers. Table 5 shows that at the 440M scale, GKA is competitive with Gated DeltaNet (GDN) and DeltaNet. However, differences emerge at larger scales, where GKA's advantage grows: its average score exceeds that of the best SSM baseline by 0.75 points at 1B and 1.89 points at 2.8B. In particular, its retrieval performance, as measured by FDA and SWDE, consistently exceeds that of all SSM baselines at the 1B and 2.8B scales. For completeness, we also report results for an equal-sized Transformer, which serves as a ceiling on average performance at each scale.
| Model | acc_n | acc_n | acc | acc | acc_n | acc_n | acc_n | acc | contains | contains | Avg. |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 440M Models | | | | | | | | | | | |
| Transformer | 24.40 | 42.26 | 59.88 | 70.00 | 36.19 | 64.15 | 61.50 | 51.70 | 5.17 | 35.64 | 45.09 |
| Gated Linear Attention | 24.06 | 40.28 | 56.57 | 71.00 | 32.70 | 62.24 | 57.80 | 50.67 | 1.00 | 9.18 | 40.55 |
| Gated DeltaNet | 25.17 | 41.96 | 58.23 | 72.00 | 36.96 | 64.69 | 63.6 | 51.7 | 1.91 | 11.88 | 42.81 |
| DeltaNet | 25.09 | 41.92 | 61.13 | 65.00 | 37.20 | 64.47 | 64.00 | 49.49 | 2.81 | 14.31 | 42.54 |
| Gated KalmaNet (Ours) | 24.57 | 43.22 | 56.94 | 71.00 | 37.22 | 64.47 | 62.8 | 50.83 | 1.45 | 14.04 | 42.65 |
| 1B Models | | | | | | | | | | | |
| Transformer | 26.62 | 46.42 | 59.94 | 77.00 | 44.01 | 67.14 | 68.30 | 54.06 | 8.35 | 45.18 | 49.70 |
| Mamba2 | 28.07 | 46.63 | 60.21 | 70.00 | 44.57 | 67.57 | 65.50 | 54.30 | 1.45 | 15.75 | 45.40 |
| Gated Linear Attention | 25.94 | 42.00 | 58.84 | 70.00 | 36.34 | 63.60 | 58.20 | 51.85 | 1.45 | 10.53 | 41.88 |
| Gated DeltaNet | 27.05 | 47.98 | 59.54 | 74.00 | 44.27 | 67.36 | 66.2 | 53.83 | 2.18 | 17.82 | 46.02 |
| DeltaNet | 27.56 | 46.25 | 59.97 | 71.00 | 43.18 | 67.74 | 65.90 | 55.41 | 3.09 | 20.61 | 46.07 |
| Gated KalmaNet (Ours) | 25.43 | 46.55 | 60.73 | 74.00 | 44.59 | 68.88 | 67.60 | 52.41 | 6.17 | 21.87 | 46.82 |
| 2.8B Models | | | | | | | | | | | |
| Transformer | 32.25 | 56.10 | 64.28 | 80.00 | 60.96 | 73.56 | 79.50 | 61.72 | 58.53 | 72.28 | 63.92 |
| Mamba2 | 32.24 | 59.64 | 58.72 | 82.00 | 62.23 | 73.78 | 79.80 | 62.19 | 7.71 | 41.13 | 55.94 |
| Gated Linear Attention | 27.82 | 50.80 | 52.57 | 78.00 | 48.83 | 70.13 | 69.60 | 54.54 | 2.81 | 20.43 | 47.55 |
| Gated DeltaNet | 32.59 | 60.02 | 62.75 | 82.00 | 62.8 | 74.32 | 80.6 | 62.35 | 8.26 | 44.28 | 57.00 |
| DeltaNet | 32.85 | 58.16 | 42.51 | 81.00 | 61.13 | 73.78 | 43.90 | 61.72 | 11.80 | 46.08 | 51.29 |
| Gated KalmaNet (Ours) | 32.51 | 59.89 | 61.68 | 85.00 | 63.84 | 74.81 | 83.2 | 64.17 | 12.89 | 50.95 | 58.89 |
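As a sanity check on Table 5, the Avg. column matches the unweighted mean of the ten per-task scores. The sketch below recomputes it for the 2.8B rows (numbers copied from the table) and identifies the strongest fading-memory baseline; the variable names are ours, not part of any released code.

```python
from statistics import mean

# 2.8B rows of Table 5: ten per-task scores and the reported average.
ROWS_2P8B = {
    "Transformer": ([32.25, 56.10, 64.28, 80.00, 60.96, 73.56, 79.50, 61.72, 58.53, 72.28], 63.92),
    "Mamba2": ([32.24, 59.64, 58.72, 82.00, 62.23, 73.78, 79.80, 62.19, 7.71, 41.13], 55.94),
    "Gated Linear Attention": ([27.82, 50.80, 52.57, 78.00, 48.83, 70.13, 69.60, 54.54, 2.81, 20.43], 47.55),
    "Gated DeltaNet": ([32.59, 60.02, 62.75, 82.00, 62.80, 74.32, 80.60, 62.35, 8.26, 44.28], 57.00),
    "DeltaNet": ([32.85, 58.16, 42.51, 81.00, 61.13, 73.78, 43.90, 61.72, 11.80, 46.08], 51.29),
    "Gated KalmaNet": ([32.51, 59.89, 61.68, 85.00, 63.84, 74.81, 83.20, 64.17, 12.89, 50.95], 58.89),
}

# Recompute each average and check it against the reported two-decimal value.
recomputed = {name: mean(scores) for name, (scores, _) in ROWS_2P8B.items()}
for name, (_, reported_avg) in ROWS_2P8B.items():
    assert abs(recomputed[name] - reported_avg) < 0.006, name

# Strongest baseline among the fading-memory layers (excluding Transformer and GKA).
baselines = {n: a for n, a in recomputed.items() if n not in ("Transformer", "Gated KalmaNet")}
best = max(baselines, key=baselines.get)
gain = recomputed["Gated KalmaNet"] - baselines[best]
print(f"best baseline: {best}, GKA gain: {gain:.2f} points")
```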
Appendix F Ablations
In this section, we ablate various modeling choices made in arriving at our final GKA model. Unless stated otherwise, all ablations use 2.8B models trained on 100B DCLM tokens at 4K context length, with the same architecture and training configuration as described in Appendix D.
F.1 Does Adaptive Regularization Help?
As discussed in Section 4.1, we introduced adaptive regularization to control the condition number of for numerical stability. Here we ablate this choice by comparing the following runs:
1. Adaptive regularization. We train a model with . We report results for for this run.
2. Constant regularization. We train the same model architecture as above with (a constant). This choice of is motivated by concurrent work [Von-arXiv2025-mesa], which explored a similar ridge regression objective for LLM training.
As shown in Fig. 6, without strict control of the condition number, gradient norms spike during training, which leads to a higher cross-entropy loss than in the run with adaptive regularization.
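The contrast between the two runs can be illustrated numerically. The sketch below accumulates the unnormalized key Gram matrix over time for keys with a strongly decaying spectrum, then compares the condition number obtained with a fixed ridge term against the Frobenius-normalized variant used in the PyTorch snippet shown earlier. The constant `LAMBDA_CONST = 0.02`, the key distribution, and the helper names are assumptions for illustration, not the settings of the runs in Fig. 6.

```python
import numpy as np


def condition_number(m: np.ndarray) -> float:
    """Ratio of largest to smallest eigenvalue of a symmetric positive definite matrix."""
    eigs = np.linalg.eigvalsh(m)
    return float(eigs[-1] / eigs[0])


rng = np.random.default_rng(0)
dim, seq_len = 32, 1024
LAMBDA_CONST = 0.02    # hypothetical fixed ridge term
RIDGE_STRENGTH = 0.02  # relative ridge term, added after Frobenius normalization

# Keys whose coordinates have geometrically decaying scales (ill-conditioned covariance).
scales = 0.5 ** np.arange(dim)
keys = rng.standard_normal((seq_len, dim)) * scales

gram = np.zeros((dim, dim))
cond_const, cond_adaptive = [], []
for t in range(seq_len):
    gram += np.outer(keys[t], keys[t])  # running sum of k_t k_t^T
    if (t + 1) % 128 == 0:               # inspect every 128 steps
        cond_const.append(condition_number(gram + LAMBDA_CONST * np.eye(dim)))
        normalized = gram / np.linalg.norm(gram, ord="fro")
        cond_adaptive.append(condition_number(normalized + RIDGE_STRENGTH * np.eye(dim)))

print(cond_const)     # grows with the scale of the accumulated Gram matrix
print(cond_adaptive)  # stays below (1 + RIDGE_STRENGTH) / RIDGE_STRENGTH
```

With a fixed ridge term, the condition number tracks the growing largest eigenvalue of the accumulated Gram matrix, whereas normalizing first caps it regardless of sequence length.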


F.2 Does Adaptive Weighting Help?
In Section 4.1, we discussed increasing the expressivity of our layer by introducing adaptive weights that re-weight past samples so that their influence decays exponentially in time. Given a constant-sized memory, we hypothesize that this adaptive weighting (gating) lets GKA learn an effective representation by incorporating a recency bias into its computation. We test this hypothesis with the following runs:
1. Adaptive weighting (gating). We train a model with adaptive weights. Specifically, for all , we parameterize the weight for the sample at time step as , with each learnable.
2. No weighting. We train the same model architecture as above, but without weights. This results in an unweighted ridge regression objective, obtained by setting in Eq. 3.
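With exponentially decaying weights, the weighted sum over the past never needs to be recomputed from scratch. Assuming, for illustration, that the weight of sample i at time t is the product of per-step gates γ_{i+1} × … × γ_t, the weighted Gram matrix obeys the recursion S_t = γ_t S_{t-1} + k_t k_tᵀ. The NumPy sketch below checks this equivalence; setting every gate to 1 recovers the unweighted objective of the "No weighting" run. The function names and the sigmoid gates are hypothetical.

```python
import numpy as np


def weighted_gram_direct(k: np.ndarray, gamma: np.ndarray) -> np.ndarray:
    """S_t = sum_{i<=t} w_{t,i} k_i k_i^T with w_{t,i} = prod_{j=i+1}^{t} gamma_j."""
    seq_len, dim = k.shape
    log_cum = np.cumsum(np.log(gamma))  # log_cum[t] = sum_{j<=t} log gamma_j
    out = np.zeros((seq_len, dim, dim))
    for t in range(seq_len):
        w = np.exp(log_cum[t] - log_cum[: t + 1])  # w[i] = prod_{j=i+1}^{t} gamma_j; w[t] = 1
        out[t] = (k[: t + 1] * w[:, None]).T @ k[: t + 1]
    return out


def weighted_gram_recursive(k: np.ndarray, gamma: np.ndarray) -> np.ndarray:
    """Same quantity via the constant-memory recursion S_t = gamma_t S_{t-1} + k_t k_t^T."""
    seq_len, dim = k.shape
    state = np.zeros((dim, dim))
    out = np.zeros((seq_len, dim, dim))
    for t in range(seq_len):
        state = gamma[t] * state + np.outer(k[t], k[t])
        out[t] = state
    return out


rng = np.random.default_rng(0)
keys = rng.standard_normal((64, 8))
# Hypothetical input-dependent gates in (0, 1), e.g. a sigmoid of a learned projection.
gates = 1.0 / (1.0 + np.exp(-rng.standard_normal(64)))
same = np.allclose(weighted_gram_direct(keys, gates), weighted_gram_recursive(keys, gates))
print(same)
```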
Table 6 shows clear benefits of adaptive weighting, with improvements on every LM-Harness task considered (average 58.89 vs. 50.45), thereby validating our hypothesis.
| Adaptive Weights | acc_n | acc_n | acc | acc | acc_n | acc_n | acc_n | acc | contains | contains | Avg. |
|---|---|---|---|---|---|---|---|---|---|---|---|
| ✗ | 28.24 | 51.73 | 57.68 | 76 | 53.87 | 71.87 | 71.6 | 54.38 | 6.08 | 33.03 | 50.45 |
| ✓ | 32.51 | 59.89 | 61.68 | 85 | 63.84 | 74.81 | 83.2 | 64.17 | 12.89 | 50.95 | 58.89 |
F.3 Does -connection Improve Training of GKA?
In Section 4.3, we introduce the -connection, a residual connection that establishes a direct path for gradient flow through the GLA solution and thereby improves training stability. It also allows the model to fall back on the GLA solution when CH produces poor-quality results because the iterative solver has not converged within the fixed iteration budget. To validate this design choice, we perform two runs: R1, GKA with the -connection, and R2, GKA without it.
On LM-Harness, both models perform similarly, with R1 and R2 achieving aggregate scores of 58.89 and 58.39, respectively. However, clear differences emerge in long-context evaluation, for which we further train both models on an additional 25B tokens of long documents at 128K context length. Fig. 7 shows that GKA without the -connection has worse long-context performance on average, with major degradation on Synthetic Recall and LongQA.
Appendix G Effects of Different Regularization Strengths
Recall that we proposed setting adaptive regularization . We now present experiments validating this choice.
Synthetic Experiments. First, we generate data as per Fig. 1a, where the covariance matrix is normalized by its Frobenius norm. In this case, we set for varying in . Fig. 8 shows that the maximum regularized residual norm (computed as the maximum of over all dimensions, where is the estimate of CH at iteration ) decreases as we enlarge . This is because a larger reduces the condition number. The downside of a large , though, is that it reduces the memorization capacity; namely, it might enlarge , the true residual of interest.
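The effect of the regularization strength under a fixed iteration budget can be reproduced with a textbook (unpreconditioned) Chebyshev iteration. This is a sketch, not the chunk-wise implementation of Algorithm 1: it solves (G/‖G‖_F + λI)x = b for a rank-deficient Gram matrix G of random keys, and it uses the spectral bounds [λ, 1 + λ], which hold for any Frobenius-normalized positive semidefinite matrix, so no eigenvalue estimation is needed. After the same number of iterations, the residual is much smaller for λ = 0.1 than for λ = 0.01, consistent with the trend described for Fig. 8. The helper names are hypothetical.

```python
import numpy as np


def chebyshev_iteration(A, b, lam_min, lam_max, num_iters):
    """Unpreconditioned Chebyshev iteration for SPD A with spectrum in [lam_min, lam_max]."""
    center = (lam_max + lam_min) / 2.0
    half_width = (lam_max - lam_min) / 2.0
    x = np.zeros_like(b)
    r = b.copy()
    p = np.zeros_like(b)
    alpha = 0.0
    for i in range(1, num_iters + 1):
        if i == 1:
            p = r.copy()
            alpha = 1.0 / center
        else:
            beta = 0.5 * (half_width * alpha) ** 2 if i == 2 else (half_width * alpha / 2.0) ** 2
            alpha = 1.0 / (center - beta / alpha)
            p = r + beta * p
        x = x + alpha * p
        r = b - A @ x  # recompute the residual explicitly
    return x


def relative_residual(ridge_strength, num_iters=10, dim=64, num_keys=16, seed=0):
    rng = np.random.default_rng(seed)
    keys = rng.standard_normal((num_keys, dim))
    gram = keys.T @ keys  # rank num_keys < dim, so the smallest eigenvalue of A is exactly lambda
    A = gram / np.linalg.norm(gram, ord="fro") + ridge_strength * np.eye(dim)
    b = rng.standard_normal(dim)
    x = chebyshev_iteration(A, b, ridge_strength, 1.0 + ridge_strength, num_iters)
    return float(np.linalg.norm(b - A @ x) / np.linalg.norm(b))


for lam in (0.01, 0.02, 0.05, 0.1):
    print(lam, relative_residual(lam))
```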




GKA with different regularization strengths. We train several 2.8B models with varying regularization strength by choosing . While LM-Harness performance (Table 7) varies little across these settings, we observe noticeable differences in long-context performance (Fig. 9), where memorization capacity matters most. Specifically, the long-context performance of GKA initially improves as we decrease from . This is expected, since weaker regularization increases the memorization capacity of the model. However, decreasing further from causes performance to drop. We attribute this to the increasing condition number of the problem, which reduces the quality of the solution computed by CH (Fig. 8).
| Regularization strength | acc_n | acc_n | acc | acc | acc_n | acc_n | acc_n | acc | contains | contains | Avg. |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0.01 | 33.45 | 58.63 | 62.63 | 85.00 | 63.36 | 73.99 | 81.40 | 63.14 | 11.16 | 51.49 | 58.43 |
| 0.02 | 32.51 | 59.89 | 61.68 | 85.00 | 63.84 | 74.81 | 83.20 | 64.17 | 12.89 | 50.95 | 58.89 |
| 0.05 | 32.68 | 61.66 | 53.57 | 79.00 | 63.46 | 74.84 | 82.60 | 63.77 | 11.98 | 49.68 | 57.32 |
| 0.1 | 32.76 | 59.85 | 63.52 | 84.00 | 63.95 | 75.08 | 83.20 | 63.54 | 11.43 | 51.22 | 58.86 |
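The other side of this trade-off, the loss of memorization capacity, can be seen in a closed-form ridge regression: the unregularized residual ‖KW_λ − V‖ of the ridge solution never decreases as λ grows. The sketch below assumes the regularization is applied as λ = a·‖KᵀK‖_F (mirroring the Frobenius normalization in the earlier PyTorch snippet) and sweeps a over the values in Table 7; keys and values are random, and `memorization_residual` is a hypothetical helper.

```python
import numpy as np


def memorization_residual(keys, values, ridge_strength):
    """Relative fit error ||K W - V|| / ||V|| of the ridge solution W."""
    gram = keys.T @ keys
    lam = ridge_strength * np.linalg.norm(gram, ord="fro")  # assumed adaptive scaling
    W = np.linalg.solve(gram + lam * np.eye(keys.shape[1]), keys.T @ values)
    return float(np.linalg.norm(keys @ W - values) / np.linalg.norm(values))


rng = np.random.default_rng(0)
keys = rng.standard_normal((24, 32))    # 24 key-value pairs, key dimension 32
values = rng.standard_normal((24, 16))
residuals = [memorization_residual(keys, values, a) for a in (0.01, 0.02, 0.05, 0.1)]
print(residuals)  # increases with the regularization strength
```

Together with the Chebyshev sketch, this shows why neither extreme is ideal: small λ hurts solver accuracy within a fixed budget, while large λ inflates the true residual.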
| Sketch dimension | LM-Harness avg. | Training throughput |
| 32 | 57.57 | 8.37 |
| no-sketch | 58.89 | 7.65 |
Appendix H Latent Sketching for Approximate Solutions
We introduce the idea of sketching from random matrix theory to further control the trade-off between FLOPs and accuracy in GKA. Sketching involves down-projecting the normal equations into a low-dimensional subspace, solving the equations in this subspace, and finally up-projecting the solution back to the original space. This reduces the worst-case computational complexity of our approach from to , where and is the number of iterations in Algorithm 1. To the best of our knowledge, our work is the first to introduce sketching as a viable way to increase the efficiency of neural network layers that are defined implicitly by the solution to an optimization problem. Sketching can be thought of as analogous to the Multi-head Latent Attention idea introduced by DeepSeek, but applied to fading memory layers. Table 8 shows preliminary results of this idea applied to GKA. Both models (no sketch and sketch dimension 32) are trained from scratch at the 2.8B scale on 100B tokens. Sketching to dimension 32 raises training throughput from 7.65 to 8.37 at the cost of a 1.32-point drop in the LM-Harness average (57.57 vs. 58.89).
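A minimal sketch of the down-project, solve, up-project pattern, assuming a Gaussian sketching matrix and a direct solve in the subspace; how GKA constructs its sketch and combines it with Chebyshev iteration is not specified here, so both choices are assumptions. For a symmetric positive definite system Ax = b, solving (SᵀAS)y = Sᵀb and returning x̂ = Sy gives the best approximation to x within the range of S in the A-norm, and it costs an r × r solve instead of a D × D one. `sketched_solve` and `a_norm` are hypothetical names.

```python
import numpy as np


def sketched_solve(A, b, sketch_dim, seed=0):
    """Approximately solve SPD A x = b inside a random sketch_dim-dimensional subspace."""
    rng = np.random.default_rng(seed)
    dim = A.shape[0]
    S = rng.standard_normal((dim, sketch_dim)) / np.sqrt(sketch_dim)  # down-projection basis
    A_small = S.T @ A @ S  # (sketch_dim, sketch_dim) projected system matrix
    b_small = S.T @ b      # projected right-hand side
    y = np.linalg.solve(A_small, b_small)
    return S @ y           # up-project back to the original space


def a_norm(A, v):
    """Energy norm sqrt(v^T A v) for SPD A."""
    return float(np.sqrt(v @ A @ v))


rng = np.random.default_rng(1)
keys = rng.standard_normal((256, 128))
gram = keys.T @ keys
A = gram / np.linalg.norm(gram, ord="fro") + 0.02 * np.eye(128)
b = rng.standard_normal(128)
x_exact = np.linalg.solve(A, b)
errors = {r: a_norm(A, x_exact - sketched_solve(A, b, r)) for r in (32, 64, 128)}
print(errors)
```

With a full-dimensional sketch the method reduces to an exact solve; smaller sketches trade accuracy for a cheaper inner solve.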
Appendix I Hybrid Gated KalmaNet
As discussed in Section A.2, augmenting SSM models with Attention layers has proven to be an effective way of improving performance on tasks that require recalling information from the distant past. In this section, we show that our Gated KalmaNet layer can be interleaved with Attention layers to yield even stronger models. Our Hybrid GKA model is based on the Qwen3 architecture [yang2025qwen3]: it consists of a stack of “decoder” blocks, each containing a sequence mixer (either Attention or GKA) followed by an MLP. As in Qwen3, our Attention layers use QK normalization. Our Hybrid model has 30 decoder blocks, 26 of which use GKA as the sequence mixer and 4 of which use Attention; the Attention blocks are at indices 6, 14, 22, and 29. Our Hybrid models follow the same training procedure as our non-Hybrid models: we pretrain on 100B tokens with a 4K context size, followed by fine-tuning on 25B tokens at a 128K context size.
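The block layout described above can be written down as a short configuration. This sketch assumes the block indices are 0-based (so index 29 is the last of the 30 blocks); `hybrid_layout` and the string labels are hypothetical, and each entry stands for a sequence mixer followed by an MLP.

```python
NUM_BLOCKS = 30
ATTENTION_INDICES = (6, 14, 22, 29)  # Attention blocks (with QK normalization)


def hybrid_layout(num_blocks=NUM_BLOCKS, attention_indices=ATTENTION_INDICES):
    """Sequence mixer used by each decoder block; every block is mixer -> MLP."""
    return ["attention" if i in attention_indices else "gka" for i in range(num_blocks)]


layout = hybrid_layout()
non_hybrid = hybrid_layout(attention_indices=())  # all 30 blocks use GKA
print(layout)
```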
When evaluating our pretrained Hybrid model on standard NLP benchmarks, we observe that it improves substantially on recall-oriented tasks (FDA & SWDE) compared to the non-Hybrid model, as shown in Table 9. (Our non-Hybrid model shares the same architecture as the Hybrid one, except that all 4 Attention layers are replaced with GKA layers.) Further, when evaluating our fine-tuned long-context model on tasks that require effective modeling of long-range dependencies, we observe a significant improvement across all context lengths, as shown in Fig. 10.
| Model | acc_n | acc_n | acc | acc | acc_n | acc_n | acc_n | acc | contains | contains | Avg. |
|---|---|---|---|---|---|---|---|---|---|---|---|
| Gated KalmaNet (Hybrid) | 33.02 | 59.47 | 64.07 | 80.00 | 62.74 | 74.59 | 81.40 | 64.64 | 53.18 | 72.46 | 64.56 |
| Gated KalmaNet | 32.51 | 59.89 | 61.68 | 85.00 | 63.84 | 74.81 | 83.20 | 64.17 | 12.89 | 50.95 | 58.89 |