Skip to content

batched_tiled: derive real-space cutoff from num_k#32

Open
sirmarcel wants to merge 3 commits into
add/batched-tiledfrom
fix/batched-tiled-derive-cutoff
Open

batched_tiled: derive real-space cutoff from num_k#32
sirmarcel wants to merge 3 commits into
add/batched-tiledfrom
fix/batched-tiled-derive-cutoff

Conversation

@sirmarcel

Copy link
Copy Markdown
Collaborator

Problem

batched_tiled.prepare derived smearing from num_k but took cutoff as a
required, independent argument. This breaks the num_k pathway: num_k fixed
the reciprocal grid and the smearing, but the real-space truncation was left
decoupled. With the num_k-derived smearing, a mismatched cutoff silently
under-converges the real-space erfc tail.

Reproduced vs batched_mixed (frame 0, num_k=80smearing=3.12, balanced
cutoff lr·8≈12.5):

tiled cutoff rel. error vs batched_mixed
4.0 5.3 %
5.0 11.9 %
lr·8 (matched) 0

Fix

Make cutoff optional and derive it as lr_wavelength · 8 when omitted,
mirroring batched_mixed. An explicit cutoff still overrides (only the
real-space radius; smearing keeps tracking num_k). Num_k-only pathway now
matches batched_mixed bit-exactly (E/F/S), including mixed pbc/non-pbc batches.

Why tests missed it

Every tiled test picks num_k = _num_k_for_lr(atoms, cutoff/8) and pins
smearing = cutoff/4 — a hand-constructed consistent triple — and the
cross-checks feed the same cutoff to both backends, so a wrong-but-consistent
cutoff still agrees. Nothing exercised the num_k-only pathway.

Adds test_num_k_pathway_matches_mixed (num_k-only vs batched_mixed) and
test_derived_cutoff_converges_real_space. Both fail on the pre-fix source.

Docs (README.md, CLAUDE.md) updated. Full suite: 592 passed.

🤖 Generated with Claude Code

`batched_tiled.prepare` derived `smearing` from `num_k` but took `cutoff`
as a required, independent argument, breaking the num_k pathway: the
reciprocal grid and smearing tracked `num_k` while the real-space
truncation did not. With the num_k-derived smearing, a mismatched cutoff
silently under-converges the real-space sum (energies off by several
percent vs `batched_mixed`).

Make `cutoff` optional and derive it as `lr_wavelength * 8` when omitted,
mirroring `batched_mixed`; an explicit `cutoff` still overrides.

Tests missed this because each picks `num_k = _num_k_for_lr(atoms,
cutoff/8)` and pins `smearing = cutoff/4`, hand-constructing a consistent
triple, while cross-checks feed the same cutoff to both backends. Add
`test_num_k_pathway_matches_mixed` and `test_derived_cutoff_converges_real_space`.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

@sirmarcel sirmarcel left a comment

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good overall but needs a bit of polish; see comments :)

Comment thread CLAUDE.md Outdated
- **Serial**: `Ewald()`, `PME()`, or `P3M()` → `.prepare()` → `.energy()/.potentials()/.energy_forces()/.energy_forces_stress()`
- **Batched**: `jaxpme.batched_mixed.Ewald()` → `.prepare([atoms_list], cutoff)` → same methods but batched
- **Tiled batched**: `jaxpme.batched_tiled.Ewald()` → `.prepare([atoms_list], num_k, cutoff)` — sum-pads atoms per system (heterogeneous-batch friendly), routes reciprocal sum through tile-dispatched XLA kernel. `num_k` is REQUIRED (no cutoff-only path).
- **Tiled batched**: `jaxpme.batched_tiled.Ewald()` → `.prepare([atoms_list], num_k, cutoff=None)` — sum-pads atoms per system (heterogeneous-batch friendly), routes reciprocal sum through tile-dispatched XLA kernel. `num_k` is REQUIRED and fixes the k-grid; `cutoff` is optional and, when omitted, is derived from it as `lr_wavelength · 8` (mirrors `batched_mixed`).

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't it be ALWAYS omitted?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right — in normal use cutoff is omitted; num_k is the single knob and the cutoff derives from it. Updated the Key APIs line to show the canonical prepare([atoms_list], num_k) call and frame cutoff as override-only. (The one real reason to pass it is to widen the real-space radius beyond the default lr·8 for extra accuracy — what the tests do — so I kept it as an override rather than removing it.) 7981dcd

Comment thread jaxpme/batched_tiled/batching.py Outdated
smearing = lr_wavelength * 2.0
else:
# non-pbc uses bare 1/r over all pairs: smearing is unused, and the
# derived cutoff only sizes the masked-out neighbor list.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is confusingly worded. what role does the cutoff actually play in the non-pbc case? is it even used? where? how?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question — traced it and verified empirically. For a non-pbc structure the physics is the bare 1/r sum over all pairs (triu_indices in to_lrbatch_nopbc), independent of cutoff. The cutoff only feeds the vesin list in to_structure; those pairs go into the PBC range-separated real-space term, which is masked off (pbc_mask_atom is False) for non-pbc atoms. So cutoff never changes a non-pbc result — it only sizes the padded pair buffer. Confirmed: energy identical for cutoff ∈ {2, 5, 12, None} while n_pair_slots goes 1→16→64. Rewrote the comment to say exactly this. 7981dcd

Comment thread README.md
- CLAUDE.md: show the canonical `prepare([atoms_list], num_k)` call (cutoff
  normally omitted); drop the convention restatement and point to README as
  the single source of truth.
- batching.py: spell out the role of `cutoff` in the non-pbc branch — the
  bare 1/r sum runs over all pairs (triu_indices), so cutoff doesn't affect
  the result, only the padded pair-buffer size.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@sirmarcel

Copy link
Copy Markdown
Collaborator Author

Polish round in 7981dcd:

  • CLAUDE.md: canonical call is now prepare([atoms_list], num_k); convention detail removed and pointed to README (single source of truth).
  • README: kept as the one detailed source.
  • batching.py: non-pbc comment now explains exactly what cutoff does there (nothing, result-wise — bare 1/r over all pairs; it only sizes the masked-out pair buffer).

ruff clean; tiled suite still green.

Comment thread jaxpme/batched_tiled/batching.py Outdated
# so cutoff does not affect the result here — it only feeds the vesin
# list in to_structure, whose pairs land in the PBC real-space term and
# are masked off for non-pbc atoms (it just sizes the padded pair
# buffer). smearing is likewise unused; None signals both.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the point of this comment?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if the cutoff is passed to vesin and then the resulting pairs are masked out, we should either directly set a minimal NL there, or maybe a default cutoff, but not just put a random number there that may generate large, masked out, data

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — went with the override (no vesin coaxing). to_structure(cutoff=None) now returns an empty neighbor list, and tiled prepare passes None for non-pbc structures (their real space is the bare 1/r sum over all pairs in to_lr anyway). So no masked-out pairs are built or carried: the non-pbc pair buffer drops from the old 1→16→64 (it grew with the cell) back to the padding minimum, results unchanged. 10bb29b

— Claude

For non-pbc structures the real-space term is the bare 1/r sum over all
pairs (triu_indices in to_lr); the cutoff neighbor list is masked off and
never contributes. Deriving `cutoff = lr*8` for it (the previous behavior)
built a list that grew with the cell and was carried as dead, masked-out
pairs (n_pair_slots 1 -> 16 -> 64).

Override it instead: `to_structure(cutoff=None)` returns an empty list (no
vesin call), and tiled `prepare` passes `None` for non-pbc. Results are
unchanged; the non-pbc pair buffer drops to the padding minimum.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants