batched_tiled: derive real-space cutoff from num_k#32
Conversation
`batched_tiled.prepare` derived `smearing` from `num_k` but took `cutoff` as a required, independent argument, breaking the num_k pathway: the reciprocal grid and smearing tracked `num_k` while the real-space truncation did not. With the num_k-derived smearing, a mismatched cutoff silently under-converges the real-space sum (energies off by several percent vs `batched_mixed`). Make `cutoff` optional and derive it as `lr_wavelength * 8` when omitted, mirroring `batched_mixed`; an explicit `cutoff` still overrides. Tests missed this because each picks `num_k = _num_k_for_lr(atoms, cutoff/8)` and pins `smearing = cutoff/4`, hand-constructing a consistent triple, while cross-checks feed the same cutoff to both backends. Add `test_num_k_pathway_matches_mixed` and `test_derived_cutoff_converges_real_space`. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
sirmarcel
left a comment
There was a problem hiding this comment.
Good overall but needs a bit of polish; see comments :)
| - **Serial**: `Ewald()`, `PME()`, or `P3M()` → `.prepare()` → `.energy()/.potentials()/.energy_forces()/.energy_forces_stress()` | ||
| - **Batched**: `jaxpme.batched_mixed.Ewald()` → `.prepare([atoms_list], cutoff)` → same methods but batched | ||
| - **Tiled batched**: `jaxpme.batched_tiled.Ewald()` → `.prepare([atoms_list], num_k, cutoff)` — sum-pads atoms per system (heterogeneous-batch friendly), routes reciprocal sum through tile-dispatched XLA kernel. `num_k` is REQUIRED (no cutoff-only path). | ||
| - **Tiled batched**: `jaxpme.batched_tiled.Ewald()` → `.prepare([atoms_list], num_k, cutoff=None)` — sum-pads atoms per system (heterogeneous-batch friendly), routes reciprocal sum through tile-dispatched XLA kernel. `num_k` is REQUIRED and fixes the k-grid; `cutoff` is optional and, when omitted, is derived from it as `lr_wavelength · 8` (mirrors `batched_mixed`). |
There was a problem hiding this comment.
Shouldn't it be ALWAYS omitted?
There was a problem hiding this comment.
Right — in normal use cutoff is omitted; num_k is the single knob and the cutoff derives from it. Updated the Key APIs line to show the canonical prepare([atoms_list], num_k) call and frame cutoff as override-only. (The one real reason to pass it is to widen the real-space radius beyond the default lr·8 for extra accuracy — what the tests do — so I kept it as an override rather than removing it.) 7981dcd
| smearing = lr_wavelength * 2.0 | ||
| else: | ||
| # non-pbc uses bare 1/r over all pairs: smearing is unused, and the | ||
| # derived cutoff only sizes the masked-out neighbor list. |
There was a problem hiding this comment.
this is confusingly worded. what role does the cutoff actually play in the non-pbc case? is it even used? where? how?
There was a problem hiding this comment.
Good question — traced it and verified empirically. For a non-pbc structure the physics is the bare 1/r sum over all pairs (triu_indices in to_lr → batch_nopbc), independent of cutoff. The cutoff only feeds the vesin list in to_structure; those pairs go into the PBC range-separated real-space term, which is masked off (pbc_mask_atom is False) for non-pbc atoms. So cutoff never changes a non-pbc result — it only sizes the padded pair buffer. Confirmed: energy identical for cutoff ∈ {2, 5, 12, None} while n_pair_slots goes 1→16→64. Rewrote the comment to say exactly this. 7981dcd
- CLAUDE.md: show the canonical `prepare([atoms_list], num_k)` call (cutoff normally omitted); drop the convention restatement and point to README as the single source of truth. - batching.py: spell out the role of `cutoff` in the non-pbc branch — the bare 1/r sum runs over all pairs (triu_indices), so cutoff doesn't affect the result, only the padded pair-buffer size. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
|
Polish round in
ruff clean; tiled suite still green. |
| # so cutoff does not affect the result here — it only feeds the vesin | ||
| # list in to_structure, whose pairs land in the PBC real-space term and | ||
| # are masked off for non-pbc atoms (it just sizes the padded pair | ||
| # buffer). smearing is likewise unused; None signals both. |
There was a problem hiding this comment.
what's the point of this comment?
There was a problem hiding this comment.
if the cutoff is passed to vesin and then the resulting pairs are masked out, we should either directly set a minimal NL there, or maybe a default cutoff, but not just put a random number there that may generate large, masked out, data
There was a problem hiding this comment.
Done — went with the override (no vesin coaxing). to_structure(cutoff=None) now returns an empty neighbor list, and tiled prepare passes None for non-pbc structures (their real space is the bare 1/r sum over all pairs in to_lr anyway). So no masked-out pairs are built or carried: the non-pbc pair buffer drops from the old 1→16→64 (it grew with the cell) back to the padding minimum, results unchanged. 10bb29b
— Claude
For non-pbc structures the real-space term is the bare 1/r sum over all pairs (triu_indices in to_lr); the cutoff neighbor list is masked off and never contributes. Deriving `cutoff = lr*8` for it (the previous behavior) built a list that grew with the cell and was carried as dead, masked-out pairs (n_pair_slots 1 -> 16 -> 64). Override it instead: `to_structure(cutoff=None)` returns an empty list (no vesin call), and tiled `prepare` passes `None` for non-pbc. Results are unchanged; the non-pbc pair buffer drops to the padding minimum. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Problem
batched_tiled.preparederivedsmearingfromnum_kbut tookcutoffas arequired, independent argument. This breaks the num_k pathway:
num_kfixedthe reciprocal grid and the smearing, but the real-space truncation was left
decoupled. With the num_k-derived smearing, a mismatched cutoff silently
under-converges the real-space erfc tail.
Reproduced vs
batched_mixed(frame 0,num_k=80→smearing=3.12, balancedcutoff
lr·8≈12.5):lr·8(matched)Fix
Make
cutoffoptional and derive it aslr_wavelength · 8when omitted,mirroring
batched_mixed. An explicitcutoffstill overrides (only thereal-space radius;
smearingkeeps trackingnum_k). Num_k-only pathway nowmatches
batched_mixedbit-exactly (E/F/S), including mixed pbc/non-pbc batches.Why tests missed it
Every tiled test picks
num_k = _num_k_for_lr(atoms, cutoff/8)and pinssmearing = cutoff/4— a hand-constructed consistent triple — and thecross-checks feed the same cutoff to both backends, so a wrong-but-consistent
cutoff still agrees. Nothing exercised the num_k-only pathway.
Adds
test_num_k_pathway_matches_mixed(num_k-only vsbatched_mixed) andtest_derived_cutoff_converges_real_space. Both fail on the pre-fix source.Docs (
README.md,CLAUDE.md) updated. Full suite: 592 passed.🤖 Generated with Claude Code