Type test_io, neighbors, and transforms #243
Conversation
Walkthrough

Tests were refactored to use concrete fixtures and order-insensitive distance comparisons; transforms tests use 0-D torch scalars and explicit int casts. torch_sim.neighbors and torch_sim.transforms were updated to accept an optional cell/cell_shifts (None), with added overloads/typing adjustments and branched distance/shift logic.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Test as Test/Caller
    participant Neigh as torch_sim/neighbors.py
    participant Trans as torch_sim/transforms.py
    Test->>Neigh: strict_nl(cutoff, positions, cell|None, mapping, system_mapping, shifts_idx)
    alt cell is None
        Note over Neigh: skip cell-shift computation (None)
        Neigh->>Trans: compute_distances_with_cell_shifts(pos, mapping, cell_shifts=None)
        Trans-->>Neigh: distances (no shifts)
    else cell provided
        Note over Neigh: compute & apply cell shifts
        Neigh->>Trans: compute_cell_shifts(cell, shifts_idx, system_mapping)
        Trans-->>Neigh: cell_shifts
        Neigh->>Trans: compute_distances_with_cell_shifts(pos, mapping, cell_shifts)
        Trans-->>Neigh: distances (with shifts)
    end
    Neigh-->>Test: neighbor_pairs, distances, metadata
```
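The branch in the diagram can be sketched as follows. This is a NumPy analogue for illustration only; the function and argument names are simplified stand-ins, not the torch_sim signatures:

```python
import numpy as np

def compute_distances_with_cell_shifts(pos, mapping, cell_shifts=None):
    """Pairwise distances; applies per-pair cell shifts only when provided."""
    i, j = mapping
    dr = pos[j] - pos[i]
    if cell_shifts is not None:  # periodic case: shift atom j to its image
        dr = dr + cell_shifts
    return np.linalg.norm(dr, axis=-1)

pos = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
mapping = (np.array([0]), np.array([1]))

# cell is None: raw distance, no shifts applied
d_free = compute_distances_with_cell_shifts(pos, mapping)

# cell provided: a per-pair shift moves atom j by one lattice translation
d_pbc = compute_distances_with_cell_shifts(
    pos, mapping, cell_shifts=np.array([[2.0, 0.0, 0.0]])
)
```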
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers
Pre-merge checks: ✅ 3 passed
Commits:

- fix more type issues
- fixing types for creating cell shifts
- fix type
- defining the fixture to use
- neighbor improvements
- wip type runners
- fix types in trajectory
- backup before thinking of messing with autobatching
- transforms is typed
- made test_io conform to types
- lint
- lint fixes
- fix transforms code
- fix safemask type
- revert trajectory file
- rm runners changes
- fix desc for fn
- ignore call-arg for pbc
Actionable comments posted: 0
🧹 Nitpick comments (6)
tests/test_io.py (1)
264-274: ImportError tests: fixture-based calls are correct; minor hardening suggestion

Using real fixtures instead of None exercises imports earlier and is more realistic. To make the import-failure simulation a bit more robust (and concise), consider monkeypatch.dict on sys.modules for the targets instead of multiple setitem calls.
Example:

```diff
-monkeypatch.setitem(sys.modules, "ase", None)
-monkeypatch.setitem(sys.modules, "ase.data", None)
+monkeypatch.dict(sys.modules, {"ase": None, "ase.data": None})
```

Also applies to: 276-287, 289-300, 302-312, 314-325, 327-337
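Outside pytest, the same mechanism can be demonstrated with unittest.mock.patch.dict, which restores sys.modules on exit just as monkeypatch.dict does (the module name below is a made-up stand-in, not a real dependency):

```python
import importlib
import sys
from unittest.mock import patch

# Mapping a name to None in sys.modules makes any import of it raise
# ImportError, which is how these tests simulate a missing dependency.
with patch.dict(sys.modules, {"fake_optional_dep": None}):
    try:
        importlib.import_module("fake_optional_dep")
        raised = False
    except ImportError:
        raised = True

# patch.dict restored sys.modules when the block exited
restored = "fake_optional_dep" not in sys.modules
```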
torch_sim/transforms.py (4)
352-355: wrap_positions: center scalar support is good; update the docstring and keep type clarity

Allowing a scalar center is ergonomic. Please update the "Args" doc to reflect tuple[float, float, float] | float and avoid ambiguity.

```diff
-    center (Tuple[float, float, float]): Center of the cell as
-        (x,y,z) tuple, defaults to (0.5, 0.5, 0.5).
+    center (tuple[float, float, float] | float): Center of the cell
+        as (x,y,z) tuple or scalar, defaults to (0.5, 0.5, 0.5).
```

Also applies to: 377-381, 389-399
469-496: Prefer public torch.dtype instead of private torch.types._dtype

Using the private _dtype type and a type: ignore can be avoided by switching to torch.dtype.

```diff
-from torch.types import _dtype
 ...
-def get_cell_shift_idx(num_repeats: torch.Tensor, dtype: _dtype) -> torch.Tensor:
+def get_cell_shift_idx(num_repeats: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
@@
-    )  # type: ignore[call-overload]
+    )
```

If _dtype isn't used elsewhere, remove the import to keep things clean.
538-565: compute_cell_shifts: doc return type/shape mismatch

The function can return None and returns per-pair shifts of shape (n_pairs, 3), not (n_systems, 3). Update the docstring to match the behavior.

```diff
-    Returns:
-        torch.Tensor: A tensor of shape (n_systems, 3) containing
-            the computed cell shifts.
+    Returns:
+        torch.Tensor | None: A tensor of shape (n_pairs, 3) containing
+            the computed cell shifts, or None if cell is None.
```
1165-1169: safe_mask: avoid evaluating fn on masked-out values

The current implementation may compute fn on zeros at masked-out positions (e.g., log(0) -> -inf), which is later hidden but can trigger warnings or NaNs in gradients. Apply fn only where the mask is True.

```diff
 def safe_mask(
     mask: torch.Tensor,
     fn: Callable[..., torch.Tensor],
     operand: torch.Tensor,
     placeholder: float = 0.0,
 ) -> torch.Tensor:
@@
-    masked = torch.where(mask, operand, torch.zeros_like(operand))
-    return torch.where(mask, fn(masked), torch.full_like(operand, placeholder))
+    out = torch.full_like(operand, placeholder)
+    if mask.any():
+        out[mask] = fn(operand[mask])
+    return out
```

tests/test_transforms.py (1)
701-702: Passing 0-dim tensors for r_onset/r_cutoff matches the API; optional ergonomics

Good mypy-friendly fix. Optionally, multiplicative_isotropic_cutoff could accept float | Tensor and call as_tensor internally, keeping tests (and users) free to pass floats.

If desired, I can draft a minimal patch to transforms.multiplicative_isotropic_cutoff to accept floats without changing types at call sites.

Also applies to: 719-721, 744-746, 767-768, 785-786, 803-805
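A minimal sketch of such a float-friendly wrapper, written with NumPy for illustration. The wrapper name is hypothetical (not the torch_sim signature), and the smoothing polynomial is one standard isotropic switching function assumed here, interpolating from 1 at r_onset to 0 at r_cutoff:

```python
import numpy as np

def multiplicative_isotropic_cutoff_sketch(fn, r_onset, r_cutoff):
    # Coerce up front so callers can pass plain floats or 0-D arrays.
    r_on2 = float(np.asarray(r_onset)) ** 2
    r_cut2 = float(np.asarray(r_cutoff)) ** 2

    def cutoff_fn(r):
        r = np.asarray(r, dtype=float)
        r2 = r ** 2
        # Smooth polynomial: 1 at r_onset, 0 at r_cutoff, C^1-continuous.
        smooth = ((r_cut2 - r2) ** 2
                  * (r_cut2 + 2.0 * r2 - 3.0 * r_on2)
                  / (r_cut2 - r_on2) ** 3)
        factor = np.where(r2 < r_on2, 1.0, np.where(r2 < r_cut2, smooth, 0.0))
        return fn(r) * factor

    return cutoff_fn

# Plain floats at the call site, no 0-D tensor construction needed.
ones_cut = multiplicative_isotropic_cutoff_sketch(lambda r: np.ones_like(r), 1.0, 2.0)
vals = ones_cut(np.array([0.5, 1.0, 3.0]))
```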
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)

- tests/test_io.py (1 hunks)
- tests/test_neighbors.py (12 hunks)
- tests/test_transforms.py (7 hunks)
- torch_sim/neighbors.py (8 hunks)
- torch_sim/transforms.py (13 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-08-31T11:15:22.654Z
Learnt from: curtischong
PR: Radical-AI/torch-sim#242
File: torch_sim/math.py:1022-1024
Timestamp: 2025-08-31T11:15:22.654Z
Learning: In PyTorch code, avoid using .item() on tensors when performance is critical as it causes thread synchronization between GPU and CPU, breaking parallelism. Use more specific type ignore comments like `# type: ignore[arg-type]` instead of generic `# type: ignore` to satisfy linting rules while maintaining performance.
Applied to files:
torch_sim/neighbors.py
🧬 Code graph analysis (4)
tests/test_io.py (2)

tests/conftest.py (5)
- si_sim_state (142-144)
- si_atoms (89-91)
- device (24-25)
- si_phonopy_atoms (119-138)
- si_structure (101-115)

torch_sim/io.py (6)
- state_to_atoms (29-72)
- state_to_phonopy (131-177)
- state_to_structures (75-128)
- atoms_to_state (180-245)
- phonopy_to_state (316-391)
- structures_to_state (248-313)
torch_sim/neighbors.py (1)

torch_sim/transforms.py (1)
- compute_cell_shifts_strict (567-588)
tests/test_neighbors.py (1)

torch_sim/neighbors.py (3)
- standard_nl (411-497)
- vesin_nl (571-637)
- vesin_nl_ts (501-568)
tests/test_transforms.py (1)

torch_sim/transforms.py (2)
- cutoff_fn (1120-1122)
- multiplicative_isotropic_cutoff (1067-1124)
🔇 Additional comments (16)
torch_sim/transforms.py (3)
498-536: compute_distances_with_cell_shifts: API widening looks good

cell_shifts: Optional[...] is a sensible, typed way to express "no shifts." Logic and validation are clear.

567-589: compute_cell_shifts_strict: clear separation; LGTM

The non-None contract is explicit and mirrors compute_cell_shifts behavior. Good addition.

1127-1162: high_precision_sum: widened dim types; LGTM

Accepting int | list[int] | tuple[int, ...] | None aligns with torch.sum. Behavior preserved.
tests/test_transforms.py (1)
458-466: Loop bound cast to int avoids range TypeError; LGTM

The explicit int() on the tensor count ensures compatibility with range and reads clearly.
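The approved pattern boils down to this, sketched in plain Python with a float standing in for a 0-D tensor:

```python
count = 3.0  # stands in for a 0-D tensor holding a loop bound

# range() rejects non-integer bounds outright...
try:
    range(count)
    cast_needed = False
except TypeError:
    cast_needed = True

# ...so the explicit int() cast is what makes the loop valid.
iterations = list(range(int(count)))
```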
tests/test_neighbors.py (9)
2-2: LGTM! Good addition of explicit typing import.

The import of Callable from collections.abc supports the new type annotations in the test functions.
108-123: LGTM! Type annotations for fixture functions are clear and correct.

The return type annotations -> list[Atoms] for both periodic_atoms_set() and molecule_atoms_set() fixtures properly document their return types and help with static type checking.

Also applies to: 127-130
135-141: Good parameterization change for fixture handling.

The switch from direct fixture usage to parameterized fixture names with request.getfixturevalue() provides better test control and clearer test naming.

Also applies to: 156-156
215-215: LGTM! Proper handling of distance comparison ordering.

Sorting distances before comparison (np.sort()) makes the comparison order-insensitive, which is the correct approach for validating neighbor list implementations that may return neighbors in different orders.

Also applies to: 241-242, 245-245, 249-249
254-265: LGTM! Consistent fixture parameterization pattern.

The parameterization pattern is consistently applied across test functions, and the Callable type annotation for nl_implementation provides proper typing.

Also applies to: 273-273
292-292: LGTM! Consistent distance sorting for reliable comparisons.

The sorted distance comparisons ensure that neighbor list implementations are validated correctly regardless of the order in which they return neighbor pairs.

Also applies to: 314-321
333-333: LGTM! Proper type annotation for callable parameter.

The Callable type annotation for nl_implementation maintains consistency with other test functions.
359-359: LGTM! Consistent sorting pattern maintained.

The distance sorting approach is consistently applied across all test functions for reliable neighbor list validation.

Also applies to: 372-372, 375-375
568-575: All test call signatures align with the current neighbor-list definitions; no annotation mismatches detected.

torch_sim/neighbors.py (3)
177-177: LGTM! Appropriate type ignore annotations for PyTorch overload resolution.

The # type: ignore[call-overload] annotations are correctly applied to PyTorch API calls where the type checker cannot resolve the correct overload. They are necessary for maintaining type safety while avoiding performance-degrading .item() calls, as mentioned in the retrieved learnings.

Also applies to: 196-197, 230-231, 247-249
366-369: LGTM! Variable renaming improves code clarity.

The renaming from bin_cnt to bin_cnt_sort_idx better reflects the variable's purpose as sorting indices, making the code more self-documenting.
643-643: LGTM! Clean handling of optional cell parameter.

The addition of the cell: torch.Tensor | None parameter and the conditional logic properly handle both periodic and non-periodic systems:

- When cell is None, distances are computed directly without cell shifts
- When cell is provided, the new compute_cell_shifts_strict function is used

The implementation maintains backward compatibility while extending functionality.

Also applies to: 661-662, 692-697
```diff
 # Convert center to tensor
 if not hasattr(center, "__len__"):
-    center = (center,) * 3
-center = torch.tensor(center, dtype=positions.dtype, device=device)
+    center_pos = torch.tensor((center,) * 3, dtype=positions.dtype, device=device)
+else:
+    center_pos = torch.tensor(center, dtype=positions.dtype, device=device)
```
This feels like a vague way to check this. Maybe an improvement would be isinstance(x, float), or isinstance(x, tuple) and len(x) == 3 and all(isinstance(y, float) for y in x)?
I think isinstance(x, float) is sufficient, because center is defined as tuple[float, float, float] | float.
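That check would look something like the following standalone sketch (normalize_center is a hypothetical helper name, not part of torch_sim):

```python
def normalize_center(center):
    """Accept tuple[float, float, float] | float and return a 3-tuple."""
    if isinstance(center, float):
        # Scalar case: broadcast to all three axes.
        return (center,) * 3
    return tuple(center)

scalar_case = normalize_center(0.5)
tuple_case = normalize_center((0.1, 0.2, 0.3))
```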
```diff
     return cell_shifts


+def compute_cell_shifts_strict(
```
Why not call this function from inside compute_cell_shifts? It would be more DRY.
two nits
Actionable comments posted: 0
♻️ Duplicate comments (1)
torch_sim/transforms.py (1)

558-570: Excellent separation of optional and strict variants

This is a well-designed solution that satisfies mypy type checking while maintaining flexibility:

- compute_cell_shifts handles the optional case by returning None when cell is None
- compute_cell_shifts_strict provides a non-optional variant for internal use where cell is guaranteed to be non-None

A past review comment suggested calling compute_cell_shifts_strict from inside compute_cell_shifts to be more DRY; the current implementation already does this correctly on line 560.
🧹 Nitpick comments (1)

torch_sim/transforms.py (1)

492-492: Address mypy type check warning

The type ignore comment suggests mypy is complaining about the torch.arange call, likely due to the dtype parameter's type annotation. Consider checking whether this can be resolved without the type ignore:

```python
r1 = torch.arange(
    -num_repeats[ii],
    num_repeats[ii] + 1,
    device=num_repeats.device,
    dtype=dtype,
)
```

If the issue persists, the type ignore is acceptable, but consider adding a more descriptive comment explaining why it's needed.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- torch_sim/transforms.py (13 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
torch_sim/transforms.py (1)

tests/test_transforms.py (3)
- test_wrap_positions_matches_ase (200-217)
- test_wrap_positions_basic (220-227)
- test_compute_distances_with_cell_shifts (1168-1179)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (33)
🔇 Additional comments (9)
torch_sim/transforms.py (9)
8-8: LGTM! Good addition for type annotation support

The import of Callable from collections.abc is appropriate for the type annotations used in function parameters later in the file.
377-380: Implementation handles both scalar and tuple center values correctly

The logic correctly handles both cases:

- Single float: creates a 3-tuple with the same value repeated
- Tuple: uses the provided tuple directly

The implementation preserves device and dtype consistency.
389-389: Consistent use of center_pos variable

Good refactoring to use the center_pos variable consistently instead of the original center parameter. This ensures type safety and consistent behavior regardless of the input format.

Also applies to: 397-397
501-501: Good type annotation for optional cell_shifts parameter

The addition of | None to the type annotation correctly reflects that cell_shifts can be None, improving type safety.

Also applies to: 517-517
539-540: Good type annotation updates for optional cell parameter

The type annotations correctly indicate that both the cell parameter and the return value can be None.
857-857: Consistent use of strict variant where cell is guaranteed non-None

Good choice to use compute_cell_shifts_strict here, since the cell parameter is guaranteed to be non-None in this context (it's accessed via cell.view(-1, 3, 3)).
1110-1110: Improved type flexibility for dimension parameter

The expanded annotation int | list[int] | tuple[int, ...] | None better reflects the types PyTorch's sum function actually accepts, improving type safety.
1147-1147: More flexible callable type annotation

Changing from torch.jit.ScriptFunction to Callable[..., torch.Tensor] is a good improvement that:

- Increases flexibility by accepting any callable that returns a tensor
- Improves type-checking compatibility
- Maintains the same functional requirements

Also applies to: 1159-1159
352-352: Approve; backward compatibility for center verified

Tuple inputs remain supported (converted with torch.tensor) and the default is unchanged; no call sites in the repo pass center= explicitly. No action required.
Summary
The test_io, neighbors, and transforms files now pass mypy.

The main issue this revealed is that pbc is inconsistent in the codebase: we use bools, tuples of bools, tensors, etc. We should just move to a tensor of bools later, but I know we don't support axis-aware pbc yet, so that's for a future PR.
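A future normalization could funnel every representation through one helper. A plain-Python sketch under stated assumptions (normalize_pbc is a hypothetical name, and a real torch_sim version would return a bool tensor rather than a tuple):

```python
def normalize_pbc(pbc):
    """Normalize bool | sequence-of-3 pbc flags to a tuple of 3 bools."""
    if isinstance(pbc, bool):
        # A single bool applies to all three lattice directions.
        return (pbc,) * 3
    flags = tuple(bool(p) for p in pbc)
    if len(flags) != 3:
        raise ValueError(f"expected 3 pbc flags, got {len(flags)}")
    return flags

fully_periodic = normalize_pbc(True)
slab = normalize_pbc([True, True, False])
```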
Checklist
Before a pull request can be merged, the following items must be checked:
We highly recommend installing the pre-commit hooks that run in CI locally to speed up the development process. Simply run pip install pre-commit && pre-commit install to install the hooks, which will check your code before each commit.