Skip to content

Commit 9aefcf8

Browse files
committed
Add instruction for exporting inlined constant
1 parent 6f423d0 commit 9aefcf8

File tree

1 file changed

+27
-2
lines changed

1 file changed

+27
-2
lines changed

docs/source/features/stablehlo.md

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ There are 2 ways to accomplish this:
1717
from torch.export import export
1818
import torchvision
1919
import torch
20-
import torch_xla2 as tx
21-
import torch_xla2.export
20+
import torchax as tx
21+
import torchax.export
2222

2323
resnet18 = torchvision.models.resnet18()
2424
# Sample input is a tuple
@@ -64,6 +64,31 @@ print(stablehlo.mlir_module())
6464
The second to last line we used `jax.ShapedDtypeStruct` to specify the input shape.
6565
You can also pass a numpy array here.
6666

67+
### Inline some weights in generated stablehlo
68+
69+
Suppose that you want to inline some (or all) of the model's weight
70+
into the generated StableHLO graph as constant. You can accomplish it by
71+
exporting a different function that calls your model.
72+
73+
The convention used in `jax.jit` is that, all the input of the `jit`'d python
74+
functions are exported as parameters, and everything else are inlined as constants.
75+
76+
So as above, the function we exported `jfunc` takes `weights` and `args` as input, so
77+
they appear as paramters.
78+
79+
If you do this instead:
80+
81+
```
82+
def jfunc_inlined(args):
83+
return jfunc(weights, args)
84+
```
85+
and export / print out stablehlo for that:
86+
87+
```
88+
print(jax.jit(jfunc_inlined).lower((jax.ShapedDtypeStruct((4, 3, 224, 224), jnp.float32.dtype, ))))
89+
```
90+
Then, you will see inlined constants.
91+
6792

6893
## Preserving High-Level PyTorch Operations in StableHLO by generating `stablehlo.composite`
6994

0 commit comments

Comments
 (0)