Problem summary
- The `extra/torch_backend` shim contains many workarounds that materialize or force contiguity instead of correctly supporting arbitrary strides/`as_strided` semantics. Those hacks show up in places like `empty_strided`, `_copy_from`, `_as_strided`, and elsewhere.
- The movement-op layer (`extra/to_movement_ops.py`) currently treats stride changes very restrictively (it only treats -1/1 flips as `STRIDE`), which prevents correct reconstruction of arbitrary views.
- Result: certain PyTorch programs (especially ones relying on non-contiguous layouts, `as_strided`, or specific stride-based views) either get incorrect results or require expensive/incorrect copies.
Top-level goal (north star)
- Make the `tiny` PyTorch backend fully correct for PyTorch semantics regarding views and strides so that:
  - `aten::as_strided`, `aten::empty_strided`, and related view APIs are supported without forcing contiguity or doing ad-hoc copies.
  - Copies and assignments preserve the requested shape + stride semantics where possible.
  - The backend's behavior matches PyTorch for shape/stride/offset semantics (including negative strides) with well-defined performance characteristics (prefer lazy, minimal-realize paths).
Primary acceptance criteria
- Correctness:
  - `aten::as_strided` and `aten::empty_strided` return tensors with the exact requested shape/stride/storage_offset semantics (not just a contiguous copy, unless that is the only valid representation).
  - Copies between tensors with different stride patterns produce identical results to PyTorch (including negative strides).
  - No new regressions for existing PyTorch backend tests (`extra/torch_backend/torch_tests.py` should pass or show reduced failures).
- Tests:
  - Add unit tests exercising positive/negative/non-unit strides and `empty_strided`.
  - Add an end-to-end smoke test: run a small upstream script (e.g., hlb-CIFAR10 with steps=5) with the `tiny` torch shim and demonstrate a numeric loss (no crash).
- Maintainability:
  - Remove ad-hoc comments like "this only solves some cases" where a proper implementation exists.
  - Document any remaining unavoidable materializations and why they're necessary.
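The shape/stride/storage_offset semantics targeted by these criteria can be modeled in a few lines of pure Python. This is a reference sketch only (the `as_strided_read` helper name is invented here, not tinygrad or PyTorch API), but it pins down what "exact requested semantics" means, including negative and overlapping strides:

```python
import itertools

def as_strided_read(buf, shape, stride, offset):
    # Read a flat buffer the way a PyTorch strided view would:
    # element index = offset + sum(i_k * stride_k). Negative and
    # non-unit strides fall out of the same formula.
    return [buf[offset + sum(i * st for i, st in zip(idx, stride))]
            for idx in itertools.product(*(range(d) for d in shape))]

buf = list(range(6))                            # flat storage: [0, 1, 2, 3, 4, 5]
print(as_strided_read(buf, (3,), (-1,), 2))     # reversed view: [2, 1, 0]
print(as_strided_read(buf, (2, 2), (1, 2), 1))  # overlapping view: [1, 3, 2, 4]
```

A correct backend must make reads and writes through such a view behave exactly like this indexing model, rather than like a contiguous copy of the gathered values.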
Key files to inspect and why
- `extra/torch_backend/backend.py`
  - Central PyTorch compatibility layer. Look for:
    - `aten::as_strided` / `_as_strided`
    - `aten::empty_strided` / `empty.memory_format`
    - `_copy_from`, `_reshape_alias`, and uses of `.contiguous()` / `.realize()` as workarounds
  - These are where the current hacks and TODOs are concentrated.
- `extra/to_movement_ops.py`
  - Converts `ShapeTracker` views into a sequence of MovementOps (RESHAPE, PERMUTE, EXPAND, PAD, SHRINK, STRIDE, AS_STRIDED).
  - Current `STRIDE` handling asserts only [-1, 1] values and maps them to axis flips. This is a key limitation to fix.
- `extra/torch_backend/wrapped_tensor.cpp`
  - C++ glue used to map a tinygrad `Tensor` to a PyTorch `torch.Tensor`. Must ensure stride + offset metadata is passed correctly.
- `extra/torch_backend/torch_tests.py`
  - Test harness for PyTorch ops against the `tiny` backend. Use this to validate behavior.
- `tinygrad/tinygrad/tensor.py`
  - Core semantics for views, `contiguous()`, `realize()`, `assign()`, and other methods the backend relies on.
- Any `ShapeTracker` / `View` API that the movement layer expects
  - Many files import `tinygrad.shape.shapetracker` and `tinygrad.shape.view` (see `extra/to_movement_ops.py`), so ensure you find the canonical implementation in the tree (the codebase may place it under `tinygrad/shape` or `tinygrad/shape/*`).
Key technical changes to enable proper strides
- Represent arbitrary stride patterns in the view machinery
  - Ensure `View`/`ShapeTracker` can represent:
    - per-dimension stride (positive/negative integer)
    - storage offset (index into the flat buffer)
    - masks (for shrinks/pads) and a contiguous flag
  - If the representation already exists, confirm the API and tests; if gaps exist, extend it.
- Movement op generation / application
  - Update `extra/to_movement_ops.py`:
    - Remove the assert that restricts `STRIDE` to [-1, 1].
    - Either:
      - Implement `STRIDE` semantics that can express arbitrary stride steps (e.g., step > 1) and negative strides, or
      - Prefer `AS_STRIDED` semantics (explicitly generate an `AS_STRIDED` op carrying the stride tuple + offset) when arbitrary strides are needed.
    - Ensure `apply_mop` can apply these ops to a `ShapeTracker` and that the resulting scratch shape/buffer size calculation is correct (the `get_buffer_size` and `get_real_view` helpers are relevant).
    - Keep `to_movement_ops`'s correctness checks (e.g., `test_rebuild`) and extend them to validate arbitrary stride cases.
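One way to support step > 1 without a new primitive is to decompose a positive-step STRIDE into the movement ops that already exist. The 1-D sketch below works over plain Python lists to illustrate the idea (the helper name is hypothetical, not existing tinygrad code); it mirrors a PAD → RESHAPE → SHRINK pipeline:

```python
import math

def stride_via_movement_ops(xs, step):
    # Select every step-th element using only pad/reshape/shrink-style
    # transformations, i.e. ops the movement layer already emits.
    assert step > 0
    groups = math.ceil(len(xs) / step)
    padded = xs + [0] * (groups * step - len(xs))                 # PAD to a multiple of step
    blocks = [padded[i * step:(i + 1) * step] for i in range(groups)]  # RESHAPE to (groups, step)
    return [b[0] for b in blocks]                                 # SHRINK to column 0

print(stride_via_movement_ops(list(range(7)), 3))  # [0, 3, 6]
```

Negative steps can then be handled by composing this with the existing -1 axis flip, leaving `AS_STRIDED` as the fallback for layouts (e.g., overlapping views) that no pad/reshape/shrink sequence can express.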
- Allocation for strided buffers
  - Implement `aten::empty_strided` to allocate a base buffer of the correct minimal size (via `get_buffer_size`) and return a view with the requested shape, stride, and storage_offset (not a `.contiguous()` tensor).
  - Confirm that APIs to construct a `View` with explicit stride and offset exist (e.g., `View.create(...)`, already used in `backend.py`).
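The minimal allocation size follows directly from the strided indexing formula. A hedged sketch of what a `get_buffer_size`-style helper must compute (the name and signature here are assumptions for illustration, not the actual tinygrad API):

```python
def required_storage_size(shape, stride, offset):
    # Smallest flat-buffer element count such that every index of the
    # view is in bounds. Only positive strides push the maximum index
    # past the offset; negative strides walk backwards from it (and
    # additionally require the minimum index to stay >= 0, which is a
    # validity check rather than a size contribution).
    if any(d == 0 for d in shape):
        return 0
    max_index = offset + sum((d - 1) * st for d, st in zip(shape, stride) if st > 0)
    return max_index + 1

print(required_storage_size((2, 3), (3, 1), 0))  # contiguous 2x3 -> 6
print(required_storage_size((3,), (-1,), 2))     # reversed length-3 view -> 3
print(required_storage_size((2, 3), (1, 2), 0))  # overlapping layout -> 6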
- Copy / assign semantics
  - Improve the `_copy_from` and `assign` paths:
    - If the dest layout matches the requested stride exactly, perform the assign without extra copies.
    - If not, produce an intermediate view that maps the source into the target layout without forcing full realization or, if necessary, create a temporary buffer that has the exact layout (and limit such cases).
  - Avoid blind `src = src.contiguous()` fixes in `_copy_from`. Instead, transform the layout correctly using movement ops where possible.
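The layout-matching copy can be stated element-wise. The pure-Python reference below is illustrative only (a real implementation would go through movement ops or a kernel, not a scalar loop), but it defines the behavior `_copy_from` must reproduce when neither side is contiguous:

```python
import itertools

def strided_copy(src, src_stride, src_off, dst, dst_stride, dst_off, shape):
    # Copy a logical `shape`-sized region between two flat buffers with
    # independent stride/offset layouts; no contiguity assumed on either side.
    for idx in itertools.product(*(range(d) for d in shape)):
        s = src_off + sum(i * st for i, st in zip(idx, src_stride))
        d = dst_off + sum(i * st for i, st in zip(idx, dst_stride))
        dst[d] = src[s]

src = list(range(6))   # row-major 2x3, strides (3, 1)
dst = [0] * 6          # column-major 2x3, strides (1, 2)
strided_copy(src, (3, 1), 0, dst, (1, 2), 0, (2, 3))
print(dst)  # [0, 3, 1, 4, 2, 5]
```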
- C++ wrapper metadata
  - Update `extra/torch_backend/wrapped_tensor.cpp` to pass stride and storage_offset metadata into the PyTorch wrapper tensors so PyTorch clients see correct shapes and strides.
  - Ensure `unwrap`/`wrap` preserve stride + offset semantics and device mapping.
Testing strategy (unit + integration)
- Unit tests to add:
  - `test_as_strided_basic`: create a base tensor and use `as_strided` to create views with arbitrary strides (positive, negative, step > 1); assert that read/write semantics match PyTorch.
  - `test_empty_strided_alloc`: `empty_strided` returns a tensor with the correct shape and stride whose numel corresponds to the correct buffer size; writing then reading yields the expected layout.
  - `test_copy_between_strides`: copy between tensors with different strides (source contiguous, dest non-contiguous; source non-contiguous, dest contiguous; both non-contiguous with different patterns). Validate that values match PyTorch.
  - `test_negative_stride`: slicing with negative steps (e.g., `tensor.flip(dim)`) and `as_strided` with negative strides should behave like PyTorch.
- Add tests to `extra/to_movement_ops.py` to validate `to_movement_ops` and `apply_mop` with arbitrary stride/view combinations (cover `STRIDE` and/or `AS_STRIDED` cases).
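The negative-stride and overlap cases can be pinned down against a small pure-Python reference reader. This is a sketch of the expected values only (helper and test names invented here); the real tests would compare the `tiny` backend against PyTorch directly:

```python
import itertools

def gather(buf, shape, stride, offset):
    # reference semantics for reading a strided view of a flat buffer
    return [buf[offset + sum(i * s for i, s in zip(idx, stride))]
            for idx in itertools.product(*(range(d) for d in shape))]

def test_negative_stride():
    # a full reverse is just stride -1 starting at the last element
    assert gather(list(range(5)), (5,), (-1,), 4) == [4, 3, 2, 1, 0]

def test_overlapping_view():
    # stride (1, 1) makes adjacent rows share elements
    assert gather(list(range(4)), (3, 2), (1, 1), 0) == [0, 1, 1, 2, 2, 3]

test_negative_stride()
test_overlapping_view()
```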
- Integration smoke:
  - Run `extra/torch_backend/torch_tests.py` to exercise many aten ops with the `tiny` backend.
  - Run `examples/hlb_cifar10.py` (or the upstream hlb-CIFAR10 main script) with `STEPS=5` using the `tiny` torch shim to verify training proceeds end-to-end.
- Continuous validation:
  - Add a CI job (optional at first, then required) that runs the smoke test on PRs touching backend/shape code.
Concrete step-by-step plan / milestones
- Investigation (0.5–1 day)
  - Reproduce failing cases: run `extra/torch_backend/torch_tests.py` and any existing failing tests that hint at stride issues.
  - Find the exact places in `backend.py` that currently force `.contiguous()` or otherwise bypass stride semantics.
  - Confirm where the `View` and `ShapeTracker` implementations live and their API (if not located, search the repo for `class View` / `ShapeTracker`).
- Movement ops fix (1–2 days)
  - Remove the `STRIDE` restriction; implement richer `STRIDE` or `AS_STRIDED` handling in `apply_mop`.
  - Add unit tests for `to_movement_ops` with non-trivial stride patterns.
- `empty_strided` and `as_strided` (1 day)
  - Implement `aten::empty_strided` to allocate the correct buffer size and return a proper view (no forced contiguous).
  - Rework `_as_strided` / `_reshape_alias` to use the improved movement-op path instead of heavyweight or incorrect fallbacks.
- Copy/assign semantics (1–2 days)
  - Rework `_copy_from` to avoid blind `contiguous()` where possible; perform the minimal necessary transformation or create a layout-matching temporary buffer.
- C++ wrapper updates (0.5–1 day)
  - Ensure `wrapped_tensor.cpp` exposes the stride + offset metadata properly.
- Testing & cleanup (1–2 days)
  - Add the tests listed above.
  - Run `extra/torch_backend/torch_tests.py` and fix any regressions.
  - Add the hlb-CIFAR10 smoke run and iterate until it succeeds on small step counts.
- PR: include tests, documentation, and a short migration note.
  - Title: "torch backend: proper support for arbitrary strides and as_strided"
  - Include before/after test logs for `torch_tests.py` and the smoke test.
Helpful immediate edits you can make now (quick wins)
- Replace comments that call out hacks with a `TODO` + a link to this document and a short note on the expected semantics. This frames future work and prevents accidental re-hacks.
- Add targeted unit test(s) for `as_strided` that currently fail — these will guide the implementation and act as regression tests.
Risk analysis & edge cases
- Performance: correct stride support might require extra kernel logic or buffer rearrangement depending on the computation backend. Prioritize correctness first; optimize later to avoid copies where possible.
- Multi-device / sharded tensors: if `to()`/`shard()` paths move data across devices, stride semantics must be preserved or explicitly documented as normalized (e.g., contiguous) on transfer. Start with single-device correctness.
- ShapeTracker / symbolic dims: if shapes are symbolic (Variables), movement-op generation must still produce correct ops or defer realization. Ensure `to_movement_ops` handles symbolic cases robustly (it already contains symbolic checks — extend if needed).
- Some code paths previously relied on `realize()` side effects. Replacing hacks may expose latent bugs requiring careful testing.
PR checklist (what a completed PR should include)
- Implementation changes in `extra/to_movement_ops.py` and `extra/torch_backend/backend.py`.
- Any necessary updates to `extra/torch_backend/wrapped_tensor.cpp`.
- New unit tests for `as_strided`, `empty_strided`, negative strides, and copying between varied layouts.
- Updated or added comments documenting the design and any trade-offs.
- No unrelated whitespace or style changes (follow tinygrad style rules).
- A run of `extra/torch_backend/torch_tests.py` with the test results included in the PR description (or CI logs).
Suggested branch and PR title
- Branch name: `fix/torch-backend-strides`
- PR title: "torch backend: support arbitrary strides / as_strided without hacks"
References (local files to start with)
- `extra/torch_backend/backend.py`
- `extra/to_movement_ops.py`
- `extra/torch_backend/wrapped_tensor.cpp`
- `extra/torch_backend/torch_tests.py`
- `tinygrad/tinygrad/tensor.py`