@@ -17,8 +17,8 @@ There are 2 ways to accomplish this:
17
17
from torch.export import export
18
18
import torchvision
19
19
import torch
20
- import torch_xla2 as tx
21
- import torch_xla2 .export
20
+ import torchax as tx
21
+ import torchax .export
22
22
23
23
resnet18 = torchvision.models.resnet18()
24
24
# Sample input is a tuple
@@ -64,6 +64,31 @@ print(stablehlo.mlir_module())
64
64
The second to last line we used ` jax.ShapedDtypeStruct ` to specify the input shape.
65
65
You can also pass a numpy array here.
66
66
67
+ ### Inline some weights in generated stablehlo
68
+
69
+ Suppose that you want to inline some (or all) of the model's weight
70
+ into the generated StableHLO graph as constant. You can accomplish it by
71
+ exporting a different function that calls your model.
72
+
73
+ The convention used in ` jax.jit ` is that, all the input of the ` jit ` 'd python
74
+ functions are exported as parameters, and everything else are inlined as constants.
75
+
76
+ So as above, the function we exported ` jfunc ` takes ` weights ` and ` args ` as input, so
77
+ they appear as paramters.
78
+
79
+ If you do this instead:
80
+
81
+ ```
82
+ def jfunc_inlined(args):
83
+ return jfunc(weights, args)
84
+ ```
85
+ and export / print out stablehlo for that:
86
+
87
+ ```
88
+ print(jax.jit(jfunc_inlined).lower((jax.ShapedDtypeStruct((4, 3, 224, 224), jnp.float32.dtype, ))))
89
+ ```
90
+ Then, you will see inlined constants.
91
+
67
92
68
93
## Preserving High-Level PyTorch Operations in StableHLO by generating ` stablehlo.composite `
69
94
0 commit comments