Commit 5c07dc6

IvyZX authored and Flax Authors committed
Rolling back #4686 as it broke internal tests.
PiperOrigin-RevId: 745744499
1 parent fe0b35d commit 5c07dc6

File tree

5 files changed: +365 -493 lines changed


docs/guides/quantization/fp8_basics.ipynb

Lines changed: 90 additions & 147 deletions
@@ -13,10 +13,12 @@
 "as quantization (Q). Conversely, de-quantization (DQ) rescales the FP8 data back\n",
 "to its original type.\n",
 "\n",
-"While jnp.dot supports FP8 inputs directly, proper quantization and\n",
-"dequantization is needed for optimal performance. Flax provides\n",
-"nn.fp8_ops.Fp8DotGeneral and nn.fp8_ops.Fp8Einsum modules that handle\n",
-"this automatically and can be used with existing layers like nn.Dense.\n",
+"Although jnp.dot supports FP8 inputs, certain limitations make it impractical\n",
+"for real-world applications. Alternatively, XLA, our compiler, can recognize\n",
+"patterns like <FP8>->DQ->Dot and subsequently invoke FP8 backends (e.g.,\n",
+"cublasLt for GPUs). FLAX encapsulates such patterns into the\n",
+"nn.fp8_ops.Fp8DotGeneralOp module, allowing users to easily configure it for\n",
+"existing layers (e.g., nn.Dense).\n",
 "\n",
 "This tutorial will walk you through the basics of how to use it.\n",
 "\n",
@@ -48,6 +50,7 @@
 "from flax.linen import fp8_ops\n",
 "\n",
 "e4m3 = jnp.float8_e4m3fn\n",
+"e5m2 = jnp.float8_e5m2\n",
 "f32 = jnp.float32\n",
 "E4M3_MAX = jnp.finfo(e4m3).max.astype(f32)\n",
 "\n",
@@ -79,29 +82,34 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"k0, k1 = random.split(random.key(0), 2)\n",
-"a = random.uniform(k0, (16, 32))\n",
-"b = random.uniform(k1, (32, 64))\n",
+"key = random.key(0)\n",
+"A = random.uniform(key, (16, 32))\n",
+"B = random.uniform(key, (32, 64))\n",
 "@jax.jit\n",
-"def dot_fp8(a, b):\n",
-" return jnp.dot(a.astype(e4m3), b.astype(e4m3), preferred_element_type=f32)\n",
-"check_fp8_call(dot_fp8.lower(a, b))"
+"def dot_fp8(A, B):\n",
+" return jnp.dot(A.astype(e4m3), B.astype(e4m3), preferred_element_type=f32)\n",
+"check_fp8_call(dot_fp8.lower(A, B))"
 ]
 },
 {
 "cell_type": "markdown",
 "id": "adb22878",
 "metadata": {},
 "source": [
-"However, this approach has two key limitations:\n",
+"However, there are two main issues with this approach. Firstly, `jnp.dot` does\n",
+"not accept scaling factors for the operands, defaulting to a scaling factor of\n",
+"1.0. Secondly, it does not support operands of mixed FP8 data types. For\n",
+"example, when the operands are E5M2 and E4M3, the dot product is performed using\n",
+"the promoted FP16 data type.\n",
 "\n",
-"1. `jnp.dot` does not support custom scaling factors for operands, defaulting to\n",
-" a scale of 1.0\n",
-"2. The autodiff does not automatically use E5M2 for gradients and E4M3 for\n",
-" activations/weights during training, which is the recommended practice\n",
+"In real-world scenarios, it is essential to specify scaling factors, either from\n",
+"calibration for inference or a user-defined algorithm during training.\n",
+"Additionally, it is common practice to use E5M2 for gradients and E4M3 for\n",
+"activations and kernels. These limitations make this method less practical for\n",
+"real-world applications.\n",
 "\n",
-"To overcome these limitations and implement proper FP8 matrix multiplication, we\n",
-"recommend using the Flax FP8 APIs. Let's start with a basic scaling approach.\n",
+"To address these limitations and create a more versatile FP8 dot product, we\n",
+"recommend leveraging XLA-FP8. Let's begin with a simple scaling strategy.\n",
 "\n",
 "\n",
 "### Current Scaling\n",
@@ -121,38 +129,36 @@
 "outputs": [],
 "source": [
 "@jax.jit\n",
-"def dot_fp8(a, b):\n",
-" a_scale = jnp.max(jnp.abs(A)) / E4M3_MAX\n",
-" b_scale = jnp.max(jnp.abs(B)) / E4M3_MAX\n",
-" a = fp8_ops.quantize(a, e4m3, a_scale, f32)\n",
-" b = fp8_ops.quantize(b, e4m3, b_scale, f32)\n",
-"\n",
-" c = jnp.dot(a, b, preferred_element_type=f32)\n",
-" c = fp8_ops.dequantize(c, f32, a_scale * b_scale)\n",
-" return c\n",
+"def dot_fp8(A, B):\n",
+" A_scale = jnp.max(jnp.abs(A)) / E4M3_MAX\n",
+" B_scale = jnp.max(jnp.abs(B)) / E4M3_MAX\n",
+" A = fp8_ops.quantize_dequantize(A, e4m3, A_scale, f32)\n",
+" B = fp8_ops.quantize_dequantize(B, e4m3, B_scale, f32)\n",
+"\n",
+" C = jnp.dot(A, B)\n",
+" return C\n",
 "\n",
-"c = dot_fp8(a, b)\n",
-"check_fp8_call(dot_fp8.lower(a, b))"
+"C = dot_fp8(A, B)\n",
+"check_fp8_call(dot_fp8.lower(A, B))"
 ]
 },
 {
 "cell_type": "markdown",
 "id": "59aca6fe",
 "metadata": {},
 "source": [
-"As shown in the code, we perform quantization (`fp8_ops.quantize`) on the\n",
-"tensors to get the lower precision operands. The `jnp.dot` processes them and\n",
-"accumulates the output in high precision (i.e., the `preferred_element_type`).\n",
-"After that, we multiply the result by the scaling factors to dequantize back to\n",
-"the original range (`fp8_ops.dequantize`). Note that while this example uses\n",
-"E4M3 for both inputs, it is possible to use different FP8 dtypes like E4M3 and\n",
-"E5M2 for the inputs. The quantization method and the scaling factors can also be\n",
-"customized based on application needs.\n",
-"\n",
-"One major issue with the current scaling method is the performance overhead\n",
-"introduced by computing `a_scale` and `b_scale`, which requires additional\n",
-"loading of the operand tensors. To overcome this issue, we recommend the delayed\n",
-"scaling.\n",
+"As shown in the code, we perform fake quantization\n",
+"(`fp8_ops.quantize_dequantize`) on the operands of the dot product. Although the\n",
+"`jnp.dot` still processes higher-precision inputs, XLA detects this pattern and\n",
+"rewrites the dot operation as an FP8 dot call (e.g., cublasLt call for GPUs).\n",
+"This approach effectively mimics the first example but offers greater\n",
+"flexibility. We can control the input dtypes (both are set to E4M3 here, but we\n",
+"could use mixed E4M3 and E5M2) and define scaling factors, which XLA can detect\n",
+"and use in the dot backend.\n",
+"\n",
+"One major issue with the current scaling method is the overhead introduced by\n",
+"computing `A_scale` and `B_scale`, which requires additional loading of the\n",
+"operand tensors. To overcome this issue, we recommend the delayed scaling.\n",
 "\n",
 "### Delayed Scaling\n",
 "\n",
@@ -161,10 +167,8 @@
 "values from recent steps (e.g., 1024 steps). Both tensors are computed from\n",
 "previous steps and maintained in the model parameters.\n",
 "\n",
-"The quantization and dequantization operations for delayed scaling are provided\n",
-"by `fp8_ops.in_q` and `fp8_ops.out_dq` respectively. `fp8_ops.in_q` handles\n",
-"input quantization and update the amax history and scaling factor, while\n",
-"`fp8_ops.out_dq` performs output dequantization."
+"Fake quantization for delayed scaling is provided by `fp8_ops.in_qdq` for the\n",
+"activations and weights, and `fp8_ops.out_qdq` for the gradients."
 ]
 },
 {
@@ -176,20 +180,25 @@
 "source": [
 "a_scale = jnp.array(1.0)\n",
 "b_scale = jnp.array(1.0)\n",
+"g_scale = jnp.array(1.0)\n",
 "a_amax_hist = jnp.zeros((1024,))\n",
 "b_amax_hist = jnp.zeros((1024,))\n",
+"g_amax_hist = jnp.zeros((1024,))\n",
 "\n",
 "@jax.jit\n",
-"def dot_fp8(a, a_scale, a_amax_hist, b, b_scale, b_amax_hist):\n",
-" a, a_scale = fp8_ops.in_q(f32, e4m3, a, a_scale, a_amax_hist)\n",
-" b, b_scale = fp8_ops.in_q(f32, e4m3, b, b_scale, b_amax_hist)\n",
+"def dot_fp8(a, a_scale, a_amax_hist, b, b_scale, b_amax_hist,\n",
+" g_scale, g_amax_hist):\n",
+" a = fp8_ops.in_qdq(f32, e4m3, a, a_scale, a_amax_hist)\n",
+" b = fp8_ops.in_qdq(f32, e4m3, b, b_scale, b_amax_hist)\n",
 " \n",
-" c = jnp.dot(a, b, preferred_element_type=f32)\n",
-" c = fp8_ops.out_dq(f32, a_scale, b_scale, c)\n",
+" c = jnp.dot(a, b)\n",
+" c = fp8_ops.out_qdq(f32, e5m2, c, g_scale, g_amax_hist)\n",
 " return c\n",
 "\n",
-"c = dot_fp8(a, a_scale, a_amax_hist, b, b_scale, b_amax_hist)\n",
-"check_fp8_call(dot_fp8.lower(a, a_scale, a_amax_hist, b, b_scale, b_amax_hist))"
+"C = dot_fp8(A, a_scale, a_amax_hist, B, b_scale, b_amax_hist,\n",
+" g_scale, g_amax_hist)\n",
+"check_fp8_call(dot_fp8.lower(A, a_scale, a_amax_hist, B, b_scale, b_amax_hist,\n",
+" g_scale, g_amax_hist))"
 ]
 },
 {
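Similarly, the restored delayed-scaling cell, assembled from the + lines above into one hedged, runnable sketch; it reuses `A`, `B`, `e4m3`, `e5m2`, `f32`, and `fp8_ops` from the earlier snippet and again omits the notebook-only `check_fp8_call` helper.

```python
# Delayed scaling carries per-tensor scales and amax histories as state across
# steps; here they are simply initialized to placeholder values.
a_scale = jnp.array(1.0)
b_scale = jnp.array(1.0)
g_scale = jnp.array(1.0)
a_amax_hist = jnp.zeros((1024,))
b_amax_hist = jnp.zeros((1024,))
g_amax_hist = jnp.zeros((1024,))

@jax.jit
def dot_fp8(a, a_scale, a_amax_hist, b, b_scale, b_amax_hist,
            g_scale, g_amax_hist):
    # Fake-quantize activations and weights to E4M3 on the way in ...
    a = fp8_ops.in_qdq(f32, e4m3, a, a_scale, a_amax_hist)
    b = fp8_ops.in_qdq(f32, e4m3, b, b_scale, b_amax_hist)
    c = jnp.dot(a, b)
    # ... and fake-quantize the output gradient to E5M2 via a custom_vjp.
    c = fp8_ops.out_qdq(f32, e5m2, c, g_scale, g_amax_hist)
    return c

C = dot_fp8(A, a_scale, a_amax_hist, B, b_scale, b_amax_hist,
            g_scale, g_amax_hist)
```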
@@ -199,22 +208,22 @@
 "source": [
 "In this example, we first prepare three pairs of scaling factors and amax\n",
 "histories, treating them as results computed from previous steps. Then, we apply\n",
-"`fp8_ops.in_q` to the input operands of `jnp.dot`, followed by `fp8_ops.out_dq`\n",
-"to the output of `jnp.dot`.\n",
+"`fp8_ops.in_qdq` to the input operands of `jnp.dot`, followed by\n",
+"`fp8_ops.out_qdq` to the output of `jnp.dot`. Note the `fp8_ops.out_qdq` will\n",
+"apply fake quantization to the gradient of the output via custom_vjp functions.\n",
+"The new scaling factors and amax histories will be returned through their\n",
+"gradients, which will be covered in the next section.\n",
 "\n",
 "\n",
 "## FLAX High Level API\n",
-"Flax provides high-level operations to seamlessly integrate FP8 quantization\n",
-"into existing layers. Instead of manually handling quantization of the delayed\n",
-"scaling (e.g., the maintanence of the amax history and scaling factors), users\n",
-"can simply use these drop-in replacements:\n",
-"\n",
-"* `fp8_ops.Fp8DotGeneral` for `lax.dot_general` operations\n",
-"* `fp8_ops.Fp8Einsum` for `jnp.einsum` operations \n",
-"\n",
-"These operations automatically handle all FP8-related functionality, including\n",
-"quantization/dequantization, scale factor updates, and FP8 dtype selection for\n",
-"both forward and backward passes.\n",
+"With the FLAX library, incorporating FP8 operations into existing FLAX layers\n",
+"is a seamless process. Users don't need to manipulate the low-level APIs for\n",
+"quantization. Instead, they can integrate the provided custom FP8 dot\n",
+"(`fp8_ops.Fp8DotGeneralOp`) into FLAX layers using a straightforward\n",
+"\"code-injection\" approach. This custom operation encapsulates all FP8-related\n",
+"tasks, including the placement of quantization-dequantization ops, algorithms\n",
+"for updating scaling factors, and the selection of FP8 dtype combinations for\n",
+"forward and backward propagation.\n",
 "\n",
 "Consider the following example:"
 ]
@@ -226,8 +235,8 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"model = nn.Dense(features=64, dot_general_cls=fp8_ops.Fp8DotGeneral)\n",
-"params = model.init(k0, A)\n",
+"model = nn.Dense(features=64, dot_general_cls=fp8_ops.Fp8DotGeneralOp)\n",
+"params = model.init(key, A)\n",
 "\n",
 "@jax.jit\n",
 "def train_step(var, a): \n",
@@ -237,66 +246,18 @@
 "check_fp8_call(train_step.lower(params, A))"
 ]
 },
-{
-"cell_type": "markdown",
-"id": "ba280e79",
-"metadata": {},
-"source": [
-"By setting `dot_general_cls=fp8_ops.Fp8DotGeneral`, we replace the\n",
-"default `lax.dot_general` operation in `nn.Dense` with an FP8-enabled version.\n",
-"The model usage remains similar, but now includes additional parameters for FP8\n",
-"quantization: scaling factors and amax history values. The next section explains\n",
-"how to update these FP8-specific parameters.\n",
-"\n",
-"For models that use `jnp.einsum` operations, such as Mixture of Experts (MoE)\n",
-"layers, users can replace them with `fp8_ops.Fp8Einsum` to enable FP8\n",
-"quantization. Here's an example:"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"id": "961b4549",
-"metadata": {},
-"outputs": [],
-"source": [
-"from typing import Any\n",
-"class FooModule(nn.Module):\n",
-" einsum: Any = None\n",
-" @nn.compact\n",
-" def __call__(self, a, b):\n",
-" if self.einsum is not None:\n",
-" einsum_fn = self.einsum()\n",
-" elif self.einsum is None:\n",
-" einsum_fn = jnp.einsum\n",
-" c = einsum_fn(\"mk,kn->mn\", a, b)\n",
-" return c\n",
-"\n",
-"model = FooModule(einsum=fp8_ops.Fp8Einsum)\n",
-"params = model.init(k0, a, b)\n",
-"\n",
-"@jax.jit\n",
-"def train_step(var, a, b):\n",
-" c = model.apply(var, a, b)\n",
-" return jnp.sum(c)\n",
-"\n",
-"check_fp8_call(train_step.lower(params, a, b))"
-]
-},
 {
 "cell_type": "markdown",
 "id": "a83b0851",
 "metadata": {},
 "source": [
-"## Manipulate FP8 params\n",
-"\n",
-"The following sections explain the internal FP8 parameters managed by\n",
-"`fp8_ops.Fp8DotGeneral` and `fp8_ops.Fp8Einsum`. These parameters\n",
-"include scaling factors and amax history values that control the FP8\n",
-"quantization process. While most users don't need to interact with these\n",
-"directly, understanding them can be valuable for advanced optimization and\n",
-"debugging.\n",
+"In this example, we simply set `dot_general_cls=fp8_ops.Fp8DotGeneralOp` to\n",
+"enable the Dense layer to utilize the FP8 dot operation. The usage of the model\n",
+"remains almost the same as before. The main difference is the addition of a new\n",
+"category of parameters: the sets of scaling factors and amax history. In the\n",
+"next section, we will explore how to update these parameters.\n",
 "\n",
+"## Manipulate FP8 params\n",
 "Let's first examine the data structure of `params`. In the code below, we redact\n",
 "the parameter values and then display the PyTree structure."
 ]
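To make the restored high-level path concrete, here is a minimal end-to-end sketch. The body of `train_step` is not visible in this hunk, so the summed output below is only a stand-in modelled on the notebook's other cells, and the PyTree redaction is an assumed one-liner using `jax.tree_util.tree_map`; `key`, `A`, `jax`, `jnp`, and `fp8_ops` are reused from the earlier snippets.

```python
from flax import linen as nn

# Inject the FP8 dot into an ordinary Dense layer (the "code-injection" approach).
model = nn.Dense(features=64, dot_general_cls=fp8_ops.Fp8DotGeneralOp)
params = model.init(key, A)

@jax.jit
def train_step(var, a):
    # Stand-in body: the real cell's body lies outside this diff hunk.
    return jnp.sum(model.apply(var, a))

loss = train_step(params, A)

# Redact values to inspect the PyTree structure, including the extra
# '_overwrite_with_gradient' collection added by Fp8DotGeneralOp.
print(jax.tree_util.tree_map(lambda _: '*', params))
```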
@@ -324,12 +285,13 @@
 "The output is as follows:\n",
 "\n",
 "```plaintext\n",
-"{'_overwrite_with_gradient': {'Fp8Einsum_0': {'input_amax_history': '*',\n",
-" 'input_scale': '*',\n",
-" 'kernel_amax_history': '*',\n",
-" 'kernel_scale': '*',\n",
-" 'output_grad_amax_history': '*',\n",
-" 'output_grad_scale': '*'}}}\n",
+"{'_overwrite_with_gradient': {'Fp8DotGeneralOp_0': {'input_amax_history': '*',\n",
+" 'input_scale': '*',\n",
+" 'kernel_amax_history': '*',\n",
+" 'kernel_scale': '*',\n",
+" 'output_grad_amax_history': '*',\n",
+" 'output_grad_scale': '*'}},\n",
+" 'params': {'bias': '*', 'kernel': '*'}}\n",
 "```\n",
 "\n",
 "In addition to the expected `params`, there is an additional category called\n",
@@ -438,26 +400,7 @@
 "2.0 [5. 0. 0. ... 0. 0. 0.]\n",
 "```\n",
 "\n",
-"This casting is already included if users choose to use the high-level APIs.\n",
-"\n",
-"## Deprecated APIs\n",
-"Previously, we provided APIs like `fp8_ops.quantize_dequantize` for current\n",
-"scaling and `fp8_ops.[in|out]_qdq` for delayed scaling. These were used with\n",
-"high precision dot operations, leveraging an XLA-FP8 feature that\n",
-"pattern-matched QDQ->dot sequences to Q->fp8_cublas_gemm. The corresponding\n",
-"high-level API was called `fp8_ops.Fp8DotGeneralOp`. However, this pattern\n",
-"matching-based solution proved brittle, as the patterns could be easily broken\n",
-"by other XLA optimizations. We recommend users migrate from these deprecated\n",
-"APIs to the newer ones described above.\n",
-"\n",
-"For migration, users should replace:\n",
-"* `fp8_ops.quantize_dequantize -> jnp.dot` with `fp8_ops.quantize -> jnp.dot ->\n",
-" fp8_ops.dequantize`\n",
-"* `fp8_ops.in_qdq -> jnp.dot -> fp8_ops.out_qdq` with `fp8_ops.in_q -> jnp.dot\n",
-" -> fp8_ops.out_dq`\n",
-"* `fp8_ops.Fp8DotGeneralOp` with `fp8_ops.Fp8DotGeneral`\n",
-"\n",
-"Additionally, we provide an einsum variant through `fp8_ops.Fp8Einsum`."
+"This casting is already included if users choose to use the high-level APIs."
 ]
 }
 ],
