[Enh] Improve type closure for primitive func#552
Conversation
|
may need update the same aiter kernels simultaneously. |
There was a problem hiding this comment.
Pull request overview
This PR updates FlyDSL’s Python expression-layer wrappers to better preserve (“close over”) DSL types (e.g., Numeric, Vector) when calling primitive ops, and adjusts affected tests/kernels to use the new return-value behavior.
Changes:
- Wrap primitive-op scalar results back into
Numerictypes (and propagate this through helpers likeget_scalar,get_leaves,ptr_load,memref_load). - Centralize
memref_load_vecto return aVectorwith shape/dtype metadata, and simplifyTensor.load()accordingly. - Update unit tests and a few kernel utilities to align with the revised scalar/vector return types and pipeline string formatting.
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/unit/test_static_vs_dynamic.py | Adjusts dynamic-layout tests to return i32 scalars directly and simplifies pipeline string formatting. |
| tests/unit/test_layout_algebra.py | Updates dynamic test functions to return i32 scalars from fx.get_scalar(...).ir_value() and simplifies pipeline string formatting. |
| python/flydsl/expr/typing.py | Updates IntTuple reconstruction and makes Tensor.load() rely on the updated memref_load_vec wrapper. |
| python/flydsl/expr/primitive.py | Introduces numeric re-wrapping helper and applies it across several primitive ops; moves vector wrapping into memref_load_vec. |
| python/flydsl/expr/math.py | Extends traced math-op wrapping to preserve DSL closure for both Numeric and Vector inputs. |
| kernels/silu_and_mul_fq.py | Simplifies scale-offset computation by relying on the updated fx.get_scalar behavior. |
| kernels/mfma_preshuffle_pipeline.py | Updates crd2idx helper to unwrap int-tuples and cast to index type using the new scalar typing behavior. |
| kernels/layout_utils.py | Updates dynamic-layout crd2idx fallback to unwrap/cast through fx.get_scalar(...).ir_value(). |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
CI failed. @sjfeng1999 |
205ec38 to
abb308a
Compare
jhinpan
left a comment
There was a problem hiding this comment.
I found two issues that still apply on the current head. The earlier mixed_moe index/i32 arithmetic concern appears fixed in this version: sorted_m is back to an index constant path.
Ensure type closure across the primitive/numeric/math expr layer: covariant int_tuple handling, drop redundant coercion in derived, remove IR value plumbing, fix idx2crd DSL type compatibility, and re-export select. Rebased onto latest main; the legacy python/flydsl/expr/rocdl.py is dropped in favor of the expr/rocdl/ package (removed upstream in #677). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
573a7ee to
4fff65b
Compare
Motivation
Technical Details
Test Plan
Test Result
Submission Checklist