Skip to content

Conversation

curtischong
Copy link
Collaborator

@curtischong curtischong commented Aug 30, 2025

Summary

the test_io, neighbors, and transform files now pass mypy.

The main issue that was revealed to me is that pbc is inconsistent in the codebase. we use bools, tuple of bools, tensors, etc. we should just move to a tensor of bools later. but ik we don't support axis-aware pbc yet so that's for a future PR.

Checklist

Before a pull request can be merged, the following items must be checked:

  • Doc strings have been added in the Google docstring format.
  • Run ruff on your code.
  • Tests have been added for any new functionality or bug fixes.

We highly recommended installing the pre-commit hooks running in CI locally to speedup the development process. Simply run pip install pre-commit && pre-commit install to install the hooks which will check your code before each commit.

Summary by CodeRabbit

  • New Features

    • Neighbor-list generation now works for systems without a unit cell.
    • Position wrapping accepts either a scalar or a 3-value center.
    • Distance and cell-shift handling can operate without cell data; internal shift application is optional.
    • Summation and masking utilities accept broader dim/argument types and generic callables.
  • Tests

    • Tests updated to use concrete fixtures, explicit typing, order-insensitive distance comparisons, and consistent tensor scalar usage.
  • Refactor

    • Typing and annotations across neighbor/transform utilities clarified.

@cla-bot cla-bot bot added the cla-signed Contributor license agreement signed label Aug 30, 2025
Copy link

coderabbitai bot commented Aug 30, 2025

Walkthrough

Tests were refactored to use concrete fixtures and order-insensitive distance comparisons; transforms tests use 0‑D torch scalars and explicit int casts. torch_sim.neighbors and torch_sim.transforms were updated to accept optional cell/cell_shifts (None), added overloads/typing adjustments, and branched distance/shift logic accordingly.

Changes

Cohort / File(s) Summary
IO ImportError tests refactor
tests/test_io.py
Six ImportError tests now accept concrete fixtures (si_sim_state, si_atoms, si_phonopy_atoms, si_structure) and pass them into the corresponding ts.io conversion calls instead of None; test signatures updated.
Neighbor tests — typing & stable comparisons
tests/test_neighbors.py
Added return types for helper factories (periodic_atoms_set, molecule_atoms_set), switched tests to use atoms_list_fixture via request.getfixturevalue, annotated nl_implementation as Callable, and made distance comparisons order-insensitive by sorting arrays before assertions.
Transforms tests — scalar types & integers
tests/test_transforms.py
Ensured integer loop bounds (explicit int(...) cast) and passed r_onset/r_cutoff as 0‑D torch.tensor scalars (torch.tensor(1.0)) across multiple cutoff-related tests.
Neighbors core — optional cell & typing fixes
torch_sim/neighbors.py
strict_nl now accepts `cell: torch.Tensor
Transforms API — optional cell/cell_shifts & typing
torch_sim/transforms.py
compute_cell_shifts and compute_distances_with_cell_shifts accept None for cell/cell_shifts (overloads added); wrap_positions.center accepts `tuple

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Test as Test/Caller
  participant Neigh as torch_sim/neighbors.py
  participant Trans as torch_sim/transforms.py

  Test->>Neigh: strict_nl(cutoff, positions, cell|None, mapping, system_mapping, shifts_idx)
  alt cell is None
    Note over Neigh: skip cell-shift computation (None)
    Neigh->>Trans: compute_distances_with_cell_shifts(pos, mapping, cell_shifts=None)
    Trans-->>Neigh: distances (no shifts)
  else cell provided
    Note over Neigh: compute & apply cell shifts
    Neigh->>Trans: compute_cell_shifts(cell, shifts_idx, system_mapping)
    Trans-->>Neigh: cell_shifts
    Neigh->>Trans: compute_distances_with_cell_shifts(pos, mapping, cell_shifts)
    Trans-->>Neigh: distances (with shifts)
  end
  Neigh-->>Test: neighbor_pairs, distances, metadata
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • CompRhys

In whiskers, hops, and tiny paws,
I sort the distances without a pause.
Optional cells let shifts be free,
Fixtures snug — neat tests from me.
Carrot-coded cheer for CI glee 🥕

Pre-merge checks (3 passed)

✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title Check ✅ Passed The title "Type test_io, neighbors, and transforms" is concise and accurately captures the primary change—adding/adjusting type annotations across those tests and modules to satisfy mypy—so it directly relates to the changeset and is clear to reviewers scanning PR history. It is a single short sentence that highlights the main intent without extraneous detail. Therefore it aligns with the PR title guidance.
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.

Tip

👮 Agentic pre-merge checks are now available in preview!

Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.

  • Built-in checks – Quickly apply ready-made checks to enforce title conventions, require pull request descriptions that follow templates, validate linked issues for compliance, and more.
  • Custom agentic checks – Define your own rules using CodeRabbit’s advanced agentic capabilities to enforce organization-specific policies and workflows. For example, you can instruct CodeRabbit’s agent to verify that API documentation is updated whenever API schema files are modified in a PR. Note: Upto 5 custom checks are currently allowed during the preview period. Pricing for this feature will be announced in a few weeks.

Please see the documentation for more information.

Example:

reviews:
  pre_merge_checks:
    custom_checks:
      - name: "Undocumented Breaking Changes"
        mode: "warning"
        instructions: |
          Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).

Please share your feedback with us on this Discord post.

✨ Finishing touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch type-other-files

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@curtischong curtischong force-pushed the type-other-files branch 2 times, most recently from 4ccee76 to a5ee051 Compare August 30, 2025 23:59
@curtischong curtischong force-pushed the add-mypy branch 5 times, most recently from ee412da to 1b7e37c Compare September 4, 2025 01:25
Base automatically changed from add-mypy to main September 5, 2025 21:59
@curtischong curtischong force-pushed the type-other-files branch 3 times, most recently from b3eadc5 to e9261e2 Compare September 9, 2025 17:05
@curtischong curtischong changed the title Type other files Type io, neighbors, and transforms Sep 9, 2025
@curtischong curtischong force-pushed the type-other-files branch 3 times, most recently from 09edd54 to b2b7b1a Compare September 9, 2025 18:19
fix more type issues

fixing types for creating cell shifts

fix type defining the fixture to use

neighbor improvements

wip type runners

fix types in trajectory

backup before thinking of messing with autobatching

transforms is typed

made test_io conform to types

lint

lint fixes

fix transforms code

fix safemask type

revert trajectory file

rm runners changes

fix desc for fn

ignore call-arg for pbc
@curtischong curtischong changed the title Type io, neighbors, and transforms Type test_io, neighbors, and transforms Sep 9, 2025
@curtischong curtischong marked this pull request as ready for review September 9, 2025 18:43
@curtischong curtischong requested a review from CompRhys September 9, 2025 18:43
Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (6)
tests/test_io.py (1)

264-274: ImportError tests: fixture-based calls are correct; minor hardening suggestion

Using real fixtures instead of None exercises imports earlier and is more realistic. To make the import failure simulation a bit more robust (and concise), consider monkeypatch.dict on sys.modules for the targets instead of multiple setitem calls.

Example:

- monkeypatch.setitem(sys.modules, "ase", None)
- monkeypatch.setitem(sys.modules, "ase.data", None)
+ monkeypatch.dict(sys.modules, {"ase": None, "ase.data": None})

Also applies to: 276-287, 289-300, 302-312, 314-325, 327-337

torch_sim/transforms.py (4)

352-355: wrap_positions: center scalar support is good—update docstring and keep type clarity

Allowing a scalar center is ergonomic. Please update the “Args” doc to reflect tuple[float, float, float] | float, and avoid ambiguity.

-        center (Tuple[float, float, float]): Center of the cell as
-            (x,y,z) tuple, defaults to (0.5, 0.5, 0.5).
+        center (tuple[float, float, float] | float): Center of the cell
+            as (x,y,z) tuple or scalar, defaults to (0.5, 0.5, 0.5).

Also applies to: 377-381, 389-399


469-496: Prefer public torch.dtype instead of private torch.types._dtype

Using the private _dtype type and a type: ignore can be avoided by switching to torch.dtype.

-from torch.types import _dtype
...
-def get_cell_shift_idx(num_repeats: torch.Tensor, dtype: _dtype) -> torch.Tensor:
+def get_cell_shift_idx(num_repeats: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
@@
-        )  # type: ignore[call-overload]
+        )

If _dtype isn’t used elsewhere, remove the import to keep things clean.


538-565: compute_cell_shifts: doc return type/shape mismatch

Function can return None and returns per-pair shifts (n_pairs, 3), not (n_systems, 3). Update the doc to match behavior.

-    Returns:
-        torch.Tensor: A tensor of shape (n_systems, 3) containing
-            the computed cell shifts.
+    Returns:
+        torch.Tensor | None: A tensor of shape (n_pairs, 3) containing
+            the computed cell shifts, or None if cell is None.

1165-1169: safe_mask: avoid evaluating fn on masked-out values

Current implementation may compute fn on zeros at masked-out positions (e.g., log(0) -> -inf), which is later hidden but can trigger warnings/NaNs in grads. Apply fn only where mask is True.

-def safe_mask(
-    mask: torch.Tensor,
-    fn: Callable[..., torch.Tensor],
-    operand: torch.Tensor,
-    placeholder: float = 0.0,
-) -> torch.Tensor:
+def safe_mask(
+    mask: torch.Tensor,
+    fn: Callable[..., torch.Tensor],
+    operand: torch.Tensor,
+    placeholder: float = 0.0,
+) -> torch.Tensor:
@@
-    masked = torch.where(mask, operand, torch.zeros_like(operand))
-    return torch.where(mask, fn(masked), torch.full_like(operand, placeholder))
+    out = torch.full_like(operand, placeholder)
+    if mask.any():
+        out[mask] = fn(operand[mask])
+    return out
tests/test_transforms.py (1)

701-702: Passing 0-dim tensors for r_onset/r_cutoff matches the API; optional ergonomics

Good mypy-friendly fix. Optionally, we could make multiplicative_isotropic_cutoff accept float | Tensor and as_tensor internally to keep tests (and users) free to pass floats.

If desired, I can draft a minimal patch to transforms.multiplicative_isotropic_cutoff to accept floats without changing types at call sites.

Also applies to: 719-721, 744-746, 767-768, 785-786, 803-805

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 390c071 and 9ac309a.

📒 Files selected for processing (5)
  • tests/test_io.py (1 hunks)
  • tests/test_neighbors.py (12 hunks)
  • tests/test_transforms.py (7 hunks)
  • torch_sim/neighbors.py (8 hunks)
  • torch_sim/transforms.py (13 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-08-31T11:15:22.654Z
Learnt from: curtischong
PR: Radical-AI/torch-sim#242
File: torch_sim/math.py:1022-1024
Timestamp: 2025-08-31T11:15:22.654Z
Learning: In PyTorch code, avoid using .item() on tensors when performance is critical as it causes thread synchronization between GPU and CPU, breaking parallelism. Use more specific type ignore comments like `# type: ignore[arg-type]` instead of generic `# type: ignore` to satisfy linting rules while maintaining performance.

Applied to files:

  • torch_sim/neighbors.py
🧬 Code graph analysis (4)
tests/test_io.py (2)
tests/conftest.py (5)
  • si_sim_state (142-144)
  • si_atoms (89-91)
  • device (24-25)
  • si_phonopy_atoms (119-138)
  • si_structure (101-115)
torch_sim/io.py (6)
  • state_to_atoms (29-72)
  • state_to_phonopy (131-177)
  • state_to_structures (75-128)
  • atoms_to_state (180-245)
  • phonopy_to_state (316-391)
  • structures_to_state (248-313)
torch_sim/neighbors.py (1)
torch_sim/transforms.py (1)
  • compute_cell_shifts_strict (567-588)
tests/test_neighbors.py (1)
torch_sim/neighbors.py (3)
  • standard_nl (411-497)
  • vesin_nl (571-637)
  • vesin_nl_ts (501-568)
tests/test_transforms.py (1)
torch_sim/transforms.py (2)
  • cutoff_fn (1120-1122)
  • multiplicative_isotropic_cutoff (1067-1124)
🔇 Additional comments (16)
torch_sim/transforms.py (3)

498-536: compute_distances_with_cell_shifts: API widening looks good

cell_shifts: Optional[...] is a sensible, typed way to express “no shifts.” Logic and validation are clear.


567-589: compute_cell_shifts_strict: clear separation—LGTM

Non-None contract is explicit and mirrors compute_cell_shifts behavior. Good addition.


1127-1162: high_precision_sum: widened dim types—LGTM

Accepting int | list[int] | tuple[int, ...] | None aligns with torch.sum. Behavior preserved.

tests/test_transforms.py (1)

458-466: Loop bound cast to int avoids range TypeError—LGTM

Explicit int() on tensor.count ensures compatibility with range and reads clearly.

tests/test_neighbors.py (9)

2-2: LGTM! Good addition of explicit typing import.

The import of Callable from collections.abc supports the new type annotations in the test functions.


108-123: LGTM! Type annotations for fixture functions are clear and correct.

The return type annotations -> list[Atoms] for both periodic_atoms_set() and molecule_atoms_set() fixtures properly document their return types and help with static type checking.

Also applies to: 127-130


135-141: Good parameterization change for fixture handling.

The switch from direct fixture usage to parameterized fixture names with request.getfixturevalue() provides better test control and clearer test naming.

Also applies to: 156-156


215-215: LGTM! Proper handling of distance comparison ordering.

The change to sort distances before comparison (np.sort()) ensures order-insensitive comparison, which is the correct approach for validating neighbor list implementations that may return neighbors in different orders.

Also applies to: 241-242, 245-245, 249-249


254-265: LGTM! Consistent fixture parameterization pattern.

The parameterization pattern is consistently applied across test functions, and the Callable type annotation for nl_implementation provides proper typing.

Also applies to: 273-273


292-292: LGTM! Consistent distance sorting for reliable comparisons.

The sorted distance comparisons ensure that neighbor list implementations are validated correctly regardless of the order in which they return neighbor pairs.

Also applies to: 314-321


333-333: LGTM! Proper type annotation for callable parameter.

The Callable type annotation for nl_implementation maintains consistency with other test functions.


359-359: LGTM! Consistent sorting pattern maintained.

The distance sorting approach is consistently applied across all test functions for reliable neighbor list validation.

Also applies to: 372-372, 375-375


568-575: All test call signatures align with the current neighbor-list definitions; no annotation mismatches detected.

torch_sim/neighbors.py (3)

177-177: LGTM! Appropriate type ignore annotations for PyTorch overload resolution.

The # type: ignore[call-overload] annotations are correctly applied to PyTorch API calls where the type checker cannot resolve the correct overload. These are necessary for maintaining type safety while avoiding performance-degrading .item() calls, as mentioned in the retrieved learnings.

Also applies to: 196-197, 230-231, 247-249


366-369: LGTM! Variable renaming improves code clarity.

The renaming from bin_cnt to bin_cnt_sort_idx better reflects the variable's purpose as sorting indices, making the code more self-documenting.


643-643: LGTM! Clean handling of optional cell parameter.

The addition of cell: torch.Tensor | None parameter and the conditional logic properly handle both periodic and non-periodic systems:

  • When cell is None, distances are computed directly without cell shifts
  • When cell is provided, it uses the new compute_cell_shifts_strict function

The implementation maintains backward compatibility while extending functionality.

Also applies to: 661-662, 692-697

@curtischong curtischong force-pushed the type-other-files branch 2 times, most recently from 165027d to 9ac309a Compare September 9, 2025 21:23
Comment on lines 376 to 380
# Convert center to tensor
if not hasattr(center, "__len__"):
center = (center,) * 3
center = torch.tensor(center, dtype=positions.dtype, device=device)
center_pos = torch.tensor((center,) * 3, dtype=positions.dtype, device=device)
else:
center_pos = torch.tensor(center, dtype=positions.dtype, device=device)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels like a vague way to check this? maybe an improvement to do isinstance(x, float) and isinstance(x, tuple) and len(x)==3 and all(isinstance(y, float) for y in x)?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think isinstance(x, float) is sufficient. because center is defined as tuple[float, float, float] | float

return cell_shifts


def compute_cell_shifts_strict(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not call this function from inside compute_cell_shifts? would be more DRY

Copy link
Member

@CompRhys CompRhys left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

two nits

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

♻️ Duplicate comments (1)
torch_sim/transforms.py (1)

558-570: Excellent separation of optional and strict variants

This is a well-designed solution to handle mypy type checking while maintaining flexibility:

  • compute_cell_shifts handles the optional case by returning None when cell is None
  • compute_cell_shifts_strict provides a non-optional variant for internal use where cell is guaranteed to be non-None

However, I notice there's a past review comment suggesting to call compute_cell_shifts_strict from inside compute_cell_shifts to be more DRY. The current implementation already does this correctly on line 560.

🧹 Nitpick comments (1)
torch_sim/transforms.py (1)

492-492: Address mypy type check warning

The type ignore comment suggests mypy is complaining about the torch.arange call. This is likely due to the dtype parameter type annotation.

Consider checking if this can be resolved without the type ignore:

r1 = torch.arange(
    -num_repeats[ii],
    num_repeats[ii] + 1,
    device=num_repeats.device,
    dtype=dtype,
)

If the issue persists, the type ignore is acceptable but consider adding a more descriptive comment explaining why it's needed.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1c2288e and 26d7d31.

📒 Files selected for processing (1)
  • torch_sim/transforms.py (13 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
torch_sim/transforms.py (1)
tests/test_transforms.py (3)
  • test_wrap_positions_matches_ase (200-217)
  • test_wrap_positions_basic (220-227)
  • test_compute_distances_with_cell_shifts (1168-1179)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (33)
  • GitHub Check: test-examples (examples/scripts/7_Others/7.2_Stress_autograd.py)
  • GitHub Check: test-examples (examples/tutorials/high_level_tutorial.py)
  • GitHub Check: test-examples (examples/tutorials/state_tutorial.py)
  • GitHub Check: test-examples (examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py)
  • GitHub Check: test-examples (examples/tutorials/reporting_tutorial.py)
  • GitHub Check: test-examples (examples/tutorials/low_level_tutorial.py)
  • GitHub Check: test-examples (examples/tutorials/autobatching_tutorial.py)
  • GitHub Check: test-examples (examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py)
  • GitHub Check: test-examples (examples/scripts/5_Workflow/5.3_Elastic.py)
  • GitHub Check: test-examples (examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py)
  • GitHub Check: test-examples (examples/scripts/4_High_level_api/4.1_high_level_api.py)
  • GitHub Check: test-examples (examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py)
  • GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/models/test_mace.py)
  • GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, orb, tests/models/test_orb.py)
  • GitHub Check: test-model (macos-14, 3.12, lowest-direct, mattersim, tests/models/test_mattersim.py)
  • GitHub Check: test-model (macos-14, 3.11, highest, orb, tests/models/test_orb.py)
  • GitHub Check: test-model (macos-14, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
  • GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
  • GitHub Check: test-model (macos-14, 3.11, highest, sevenn, tests/models/test_sevennet.py)
  • GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_elastic.py)
  • GitHub Check: test-model (ubuntu-latest, 3.11, highest, fairchem, tests/models/test_fairchem.py)
  • GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
  • GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
  • GitHub Check: test-model (ubuntu-latest, 3.11, highest, orb, tests/models/test_orb.py)
  • GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/test_elastic.py)
  • GitHub Check: build-docs
  • GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
  • GitHub Check: test-core (ubuntu-latest, 3.12, lowest-direct)
🔇 Additional comments (9)
torch_sim/transforms.py (9)

8-8: LGTM - Good addition for type annotation support

The import of Callable from collections.abc is appropriate for the type annotations used in function parameters later in the file.


377-380: Implementation handles both scalar and tuple center values correctly

The logic correctly handles both cases:

  • Single float: creates a 3-tuple with the same value repeated
  • Tuple: uses the provided tuple directly

The implementation preserves device and dtype consistency.


389-389: Consistent use of center_pos variable

Good refactoring to use the center_pos variable consistently instead of the original center parameter. This ensures type safety and consistent behavior regardless of the input format.

Also applies to: 397-397


501-501: Good type annotation for optional cell_shifts parameter

The addition of | None to the type annotation correctly reflects that cell_shifts can be None, improving type safety.

Also applies to: 517-517


539-540: Good type annotation updates for optional cell parameter

The type annotations correctly indicate that both the cell parameter and return value can be None.


857-857: Consistent use of strict variant where cell is guaranteed non-None

Good choice to use compute_cell_shifts_strict here since the cell parameter is guaranteed to be non-None in this context (it's accessed via cell.view(-1, 3, 3)).


1110-1110: Improved type flexibility for dimension parameter

The expanded type annotation int | list[int] | tuple[int, ...] | None better reflects the actual accepted types for PyTorch's sum function, improving type safety.


1147-1147: More flexible callable type annotation

Changing from torch.jit.ScriptFunction to Callable[..., torch.Tensor] is a good improvement that:

  • Increases flexibility by accepting any callable that returns a tensor
  • Improves type checking compatibility
  • Maintains the same functional requirements

Also applies to: 1159-1159


352-352: Approve — backward-compatibility for center verified

Tuple inputs remain supported (converted with torch.tensor) and the default is unchanged; no call sites in the repo pass center= explicitly. No action required.

@curtischong curtischong merged commit 01782f8 into main Sep 12, 2025
92 of 93 checks passed
@curtischong curtischong deleted the type-other-files branch September 12, 2025 13:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cla-signed Contributor license agreement signed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants