|
13 | 13 | "as quantization (Q). Conversely, de-quantization (DQ) rescales the FP8 data back\n",
|
14 | 14 | "to its original type.\n",
|
15 | 15 | "\n",
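| | + "As a rough, hypothetical sketch (not the actual Flax implementation), Q and DQ\n", |
| | + "with a single per-tensor scale can be pictured as:\n", |
| | + "\n", |
| | + "```python\n", |
| | + "import jax.numpy as jnp\n", |
| | + "\n", |
| | + "e4m3 = jnp.float8_e4m3fn\n", |
| | + "E4M3_MAX = jnp.finfo(e4m3).max.astype(jnp.float32)\n", |
| | + "\n", |
| | + "def q(x, scale):\n", |
| | + "    # Scale into the representable FP8 range, saturate, then cast to FP8.\n", |
| | + "    return jnp.clip(x / scale, -E4M3_MAX, E4M3_MAX).astype(e4m3)\n", |
| | + "\n", |
| | + "def dq(x_fp8, scale, dtype=jnp.float32):\n", |
| | + "    # Cast back to the wide dtype and undo the scaling.\n", |
| | + "    return x_fp8.astype(dtype) * scale\n", |
| | + "```\n", |
| | + "\n", |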
|
16 |
| - "While jnp.dot supports FP8 inputs directly, proper quantization and\n", |
17 |
| - "dequantization is needed for optimal performance. Flax provides\n", |
18 |
| - "nn.fp8_ops.Fp8DotGeneral and nn.fp8_ops.Fp8Einsum modules that handle\n", |
19 |
| - "this automatically and can be used with existing layers like nn.Dense.\n", |
| 16 | + "Although jnp.dot supports FP8 inputs, certain limitations make it impractical\n", |
| 17 | + "for real-world applications. Alternatively, XLA, our compiler, can recognize\n", |
| 18 | + "patterns like <FP8>->DQ->Dot and subsequently invoke FP8 backends (e.g.,\n", |
| 19 | + "cublasLt for GPUs). FLAX encapsulates such patterns into the\n", |
| 20 | + "nn.fp8_ops.Fp8DotGeneralOp module, allowing users to easily configure it for\n", |
| 21 | + "existing layers (e.g., nn.Dense).\n", |
20 | 22 | "\n",
|
21 | 23 | "This tutorial will walk you through the basics of how to use it.\n",
|
22 | 24 | "\n",
|
|
48 | 50 | "from flax.linen import fp8_ops\n",
|
49 | 51 | "\n",
|
50 | 52 | "e4m3 = jnp.float8_e4m3fn\n",
|
| 53 | + "e5m2 = jnp.float8_e5m2\n", |
51 | 54 | "f32 = jnp.float32\n",
|
52 | 55 | "E4M3_MAX = jnp.finfo(e4m3).max.astype(f32)\n",
|
53 | 56 | "\n",
|
|
79 | 82 | "metadata": {},
|
80 | 83 | "outputs": [],
|
81 | 84 | "source": [
|
82 |
| - "k0, k1 = random.split(random.key(0), 2)\n", |
83 |
| - "a = random.uniform(k0, (16, 32))\n", |
84 |
| - "b = random.uniform(k1, (32, 64))\n", |
| 85 | + "key = random.key(0)\n", |
| 86 | + "A = random.uniform(key, (16, 32))\n", |
| 87 | + "B = random.uniform(key, (32, 64))\n", |
85 | 88 | "@jax.jit\n",
|
86 |
| - "def dot_fp8(a, b):\n", |
87 |
| - " return jnp.dot(a.astype(e4m3), b.astype(e4m3), preferred_element_type=f32)\n", |
88 |
| - "check_fp8_call(dot_fp8.lower(a, b))" |
| 89 | + "def dot_fp8(A, B):\n", |
| 90 | + " return jnp.dot(A.astype(e4m3), B.astype(e4m3), preferred_element_type=f32)\n", |
| 91 | + "check_fp8_call(dot_fp8.lower(A, B))" |
89 | 92 | ]
|
90 | 93 | },
|
91 | 94 | {
|
92 | 95 | "cell_type": "markdown",
|
93 | 96 | "id": "adb22878",
|
94 | 97 | "metadata": {},
|
95 | 98 | "source": [
|
96 |
| - "However, this approach has two key limitations:\n", |
| 99 | + "However, there are two main issues with this approach. Firstly, `jnp.dot` does\n", |
| 100 | + "not accept scaling factors for the operands, defaulting to a scaling factor of\n", |
| 101 | + "1.0. Secondly, it does not support operands of mixed FP8 data types. For\n", |
| 102 | + "example, when the operands are E5M2 and E4M3, the dot product is performed using\n", |
| 103 | + "the promoted FP16 data type.\n", |
97 | 104 | "\n",
|
98 |
| - "1. `jnp.dot` does not support custom scaling factors for operands, defaulting to\n", |
99 |
| - " a scale of 1.0\n", |
100 |
| - "2. The autodiff does not automatically use E5M2 for gradients and E4M3 for\n", |
101 |
| - " activations/weights during training, which is the recommended practice\n", |
| 105 | + "In real-world scenarios, it is essential to specify scaling factors, either from\n", |
| 106 | + "calibration for inference or a user-defined algorithm during training.\n", |
| 107 | + "Additionally, it is common practice to use E5M2 for gradients and E4M3 for\n", |
| 108 | + "activations and kernels. These limitations make this method less practical for\n", |
| 109 | + "real-world applications.\n", |
102 | 110 | "\n",
|
103 |
| - "To overcome these limitations and implement proper FP8 matrix multiplication, we\n", |
104 |
| - "recommend using the Flax FP8 APIs. Let's start with a basic scaling approach.\n", |
| 111 | + "To address these limitations and create a more versatile FP8 dot product, we\n", |
| 112 | + "recommend leveraging XLA-FP8. Let's begin with a simple scaling strategy.\n", |
105 | 113 | "\n",
|
106 | 114 | "\n",
|
107 | 115 | "### Current Scaling\n",
|
|
121 | 129 | "outputs": [],
|
122 | 130 | "source": [
|
123 | 131 | "@jax.jit\n",
|
124 |
| - "def dot_fp8(a, b):\n", |
125 |
| - " a_scale = jnp.max(jnp.abs(A)) / E4M3_MAX\n", |
126 |
| - " b_scale = jnp.max(jnp.abs(B)) / E4M3_MAX\n", |
127 |
| - " a = fp8_ops.quantize(a, e4m3, a_scale, f32)\n", |
128 |
| - " b = fp8_ops.quantize(b, e4m3, b_scale, f32)\n", |
129 |
| - "\n", |
130 |
| - " c = jnp.dot(a, b, preferred_element_type=f32)\n", |
131 |
| - " c = fp8_ops.dequantize(c, f32, a_scale * b_scale)\n", |
132 |
| - " return c\n", |
| 132 | + "def dot_fp8(A, B):\n", |
| 133 | + " A_scale = jnp.max(jnp.abs(A)) / E4M3_MAX\n", |
| 134 | + " B_scale = jnp.max(jnp.abs(B)) / E4M3_MAX\n", |
| 135 | + " A = fp8_ops.quantize_dequantize(A, e4m3, A_scale, f32)\n", |
| 136 | + " B = fp8_ops.quantize_dequantize(B, e4m3, B_scale, f32)\n", |
| 137 | + "\n", |
| 138 | + " C = jnp.dot(A, B)\n", |
| 139 | + " return C\n", |
133 | 140 | "\n",
|
134 |
| - "c = dot_fp8(a, b)\n", |
135 |
| - "check_fp8_call(dot_fp8.lower(a, b))" |
| 141 | + "C = dot_fp8(A, B)\n", |
| 142 | + "check_fp8_call(dot_fp8.lower(A, B))" |
136 | 143 | ]
|
137 | 144 | },
|
138 | 145 | {
|
139 | 146 | "cell_type": "markdown",
|
140 | 147 | "id": "59aca6fe",
|
141 | 148 | "metadata": {},
|
142 | 149 | "source": [
|
143 |
| - "As shown in the code, we perform quantization (`fp8_ops.quantize`) on the\n", |
144 |
| - "tensors to get the lower precision operands. The `jnp.dot` processes them and\n", |
145 |
| - "accumulates the output in high precision (i.e., the `preferred_element_type`).\n", |
146 |
| - "After that, we multiply the result by the scaling factors to dequantize back to\n", |
147 |
| - "the original range (`fp8_ops.dequantize`). Note that while this example uses\n", |
148 |
| - "E4M3 for both inputs, it is possible to use different FP8 dtypes like E4M3 and\n", |
149 |
| - "E5M2 for the inputs. The quantization method and the scaling factors can also be\n", |
150 |
| - "customized based on application needs.\n", |
151 |
| - "\n", |
152 |
| - "One major issue with the current scaling method is the performance overhead\n", |
153 |
| - "introduced by computing `a_scale` and `b_scale`, which requires additional\n", |
154 |
| - "loading of the operand tensors. To overcome this issue, we recommend the delayed\n", |
155 |
| - "scaling.\n", |
| 150 | + "As shown in the code, we perform fake quantization\n", |
| 151 | + "(`fp8_ops.quantize_dequantize`) on the operands of the dot product. Although the\n", |
| 152 | + "`jnp.dot` still processes higher-precision inputs, XLA detects this pattern and\n", |
| 153 | + "rewrites the dot operation as an FP8 dot call (e.g., cublasLt call for GPUs).\n", |
| 154 | + "This approach effectively mimics the first example but offers greater\n", |
| 155 | + "flexibility. We can control the input dtypes (both are set to E4M3 here, but we\n", |
| 156 | + "could use mixed E4M3 and E5M2) and define scaling factors, which XLA can detect\n", |
| 157 | + "and use in the dot backend.\n", |
| 158 | + "\n", |
| 159 | + "One major issue with the current scaling method is the overhead introduced by\n", |
| 160 | + "computing `A_scale` and `B_scale`, which requires additional loading of the\n", |
| 161 | + "operand tensors. To overcome this issue, we recommend the delayed scaling.\n", |
156 | 162 | "\n",
|
157 | 163 | "### Delayed Scaling\n",
|
158 | 164 | "\n",
|
|
161 | 167 | "values from recent steps (e.g., 1024 steps). Both tensors are computed from\n",
|
162 | 168 | "previous steps and maintained in the model parameters.\n",
|
163 | 169 | "\n",
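| | + "Conceptually, the scale used at the current step is derived from the amax\n", |
| | + "history of previous steps. A hypothetical sketch (not the exact `fp8_ops`\n", |
| | + "algorithm), following the same convention as the current-scaling example above,\n", |
| | + "could look like:\n", |
| | + "\n", |
| | + "```python\n", |
| | + "def delayed_scale(amax_history, fp8_max=E4M3_MAX):\n", |
| | + "    # The largest amax observed over the window sets the dynamic range.\n", |
| | + "    return jnp.max(amax_history) / fp8_max\n", |
| | + "```\n", |
| | + "\n", |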
|
164 |
| - "The quantization and dequantization operations for delayed scaling are provided\n", |
165 |
| - "by `fp8_ops.in_q` and `fp8_ops.out_dq` respectively. `fp8_ops.in_q` handles\n", |
166 |
| - "input quantization and update the amax history and scaling factor, while\n", |
167 |
| - "`fp8_ops.out_dq` performs output dequantization." |
| 170 | + "Fake quantization for delayed scaling is provided by `fp8_ops.in_qdq` for the\n", |
| 171 | + "activations and weights, and `fp8_ops.out_qdq` for the gradients." |
168 | 172 | ]
|
169 | 173 | },
|
170 | 174 | {
|
|
176 | 180 | "source": [
|
177 | 181 | "a_scale = jnp.array(1.0)\n",
|
178 | 182 | "b_scale = jnp.array(1.0)\n",
|
| 183 | + "g_scale = jnp.array(1.0)\n", |
179 | 184 | "a_amax_hist = jnp.zeros((1024,))\n",
|
180 | 185 | "b_amax_hist = jnp.zeros((1024,))\n",
|
| 186 | + "g_amax_hist = jnp.zeros((1024,))\n", |
181 | 187 | "\n",
|
182 | 188 | "@jax.jit\n",
|
183 |
| - "def dot_fp8(a, a_scale, a_amax_hist, b, b_scale, b_amax_hist):\n", |
184 |
| - " a, a_scale = fp8_ops.in_q(f32, e4m3, a, a_scale, a_amax_hist)\n", |
185 |
| - " b, b_scale = fp8_ops.in_q(f32, e4m3, b, b_scale, b_amax_hist)\n", |
| 189 | + "def dot_fp8(a, a_scale, a_amax_hist, b, b_scale, b_amax_hist,\n", |
| 190 | + " g_scale, g_amax_hist):\n", |
| 191 | + " a = fp8_ops.in_qdq(f32, e4m3, a, a_scale, a_amax_hist)\n", |
| 192 | + " b = fp8_ops.in_qdq(f32, e4m3, b, b_scale, b_amax_hist)\n", |
186 | 193 | " \n",
|
187 |
| - " c = jnp.dot(a, b, preferred_element_type=f32)\n", |
188 |
| - " c = fp8_ops.out_dq(f32, a_scale, b_scale, c)\n", |
| 194 | + " c = jnp.dot(a, b)\n", |
| 195 | + " c = fp8_ops.out_qdq(f32, e5m2, c, g_scale, g_amax_hist)\n", |
189 | 196 | " return c\n",
|
190 | 197 | "\n",
|
191 |
| - "c = dot_fp8(a, a_scale, a_amax_hist, b, b_scale, b_amax_hist)\n", |
192 |
| - "check_fp8_call(dot_fp8.lower(a, a_scale, a_amax_hist, b, b_scale, b_amax_hist))" |
| 198 | + "C = dot_fp8(A, a_scale, a_amax_hist, B, b_scale, b_amax_hist,\n", |
| 199 | + " g_scale, g_amax_hist)\n", |
| 200 | + "check_fp8_call(dot_fp8.lower(A, a_scale, a_amax_hist, B, b_scale, b_amax_hist,\n", |
| 201 | + " g_scale, g_amax_hist))" |
193 | 202 | ]
|
194 | 203 | },
|
195 | 204 | {
|
|
199 | 208 | "source": [
|
200 | 209 | "In this example, we first prepare three pairs of scaling factors and amax\n",
|
201 | 210 | "histories, treating them as results computed from previous steps. Then, we apply\n",
|
202 |
| - "`fp8_ops.in_q` to the input operands of `jnp.dot`, followed by `fp8_ops.out_dq`\n", |
203 |
| - "to the output of `jnp.dot`.\n", |
| 211 | + "`fp8_ops.in_qdq` to the input operands of `jnp.dot`, followed by\n", |
| 212 | + "`fp8_ops.out_qdq` to the output of `jnp.dot`. Note the `fp8_ops.out_qdq` will\n", |
| 213 | + "apply fake quantization to the gradient of the output via custom_vjp functions.\n", |
| 214 | + "The new scaling factors and amax histories will be returned through their\n", |
| 215 | + "gradients, which will be covered in the next section.\n", |
204 | 216 | "\n",
|
205 | 217 | "\n",
|
206 | 218 | "## FLAX High Level API\n",
|
207 |
| - "Flax provides high-level operations to seamlessly integrate FP8 quantization\n", |
208 |
| - "into existing layers. Instead of manually handling quantization of the delayed\n", |
209 |
| - "scaling (e.g., the maintanence of the amax history and scaling factors), users\n", |
210 |
| - "can simply use these drop-in replacements:\n", |
211 |
| - "\n", |
212 |
| - "* `fp8_ops.Fp8DotGeneral` for `lax.dot_general` operations\n", |
213 |
| - "* `fp8_ops.Fp8Einsum` for `jnp.einsum` operations \n", |
214 |
| - "\n", |
215 |
| - "These operations automatically handle all FP8-related functionality, including\n", |
216 |
| - "quantization/dequantization, scale factor updates, and FP8 dtype selection for\n", |
217 |
| - "both forward and backward passes.\n", |
| 219 | + "With the FLAX library, incorporating FP8 operations into existing FLAX layers\n", |
| 220 | + "is a seamless process. Users don't need to manipulate the low-level APIs for\n", |
| 221 | + "quantization. Instead, they can integrate the provided custom FP8 dot\n", |
| 222 | + "(`fp8_ops.Fp8DotGeneralOp`) into FLAX layers using a straightforward\n", |
| 223 | + "\"code-injection\" approach. This custom operation encapsulates all FP8-related\n", |
| 224 | + "tasks, including the placement of quantization-dequantization ops, algorithms\n", |
| 225 | + "for updating scaling factors, and the selection of FP8 dtype combinations for\n", |
| 226 | + "forward and backward propagation.\n", |
218 | 227 | "\n",
|
219 | 228 | "Consider the following example:"
|
220 | 229 | ]
|
|
226 | 235 | "metadata": {},
|
227 | 236 | "outputs": [],
|
228 | 237 | "source": [
|
229 |
| - "model = nn.Dense(features=64, dot_general_cls=fp8_ops.Fp8DotGeneral)\n", |
230 |
| - "params = model.init(k0, A)\n", |
| 238 | + "model = nn.Dense(features=64, dot_general_cls=fp8_ops.Fp8DotGeneralOp)\n", |
| 239 | + "params = model.init(key, A)\n", |
231 | 240 | "\n",
|
232 | 241 | "@jax.jit\n",
|
233 | 242 | "def train_step(var, a): \n",
|
|
237 | 246 | "check_fp8_call(train_step.lower(params, A))"
|
238 | 247 | ]
|
239 | 248 | },
|
240 |
| - { |
241 |
| - "cell_type": "markdown", |
242 |
| - "id": "ba280e79", |
243 |
| - "metadata": {}, |
244 |
| - "source": [ |
245 |
| - "By setting `dot_general_cls=fp8_ops.Fp8DotGeneral`, we replace the\n", |
246 |
| - "default `lax.dot_general` operation in `nn.Dense` with an FP8-enabled version.\n", |
247 |
| - "The model usage remains similar, but now includes additional parameters for FP8\n", |
248 |
| - "quantization: scaling factors and amax history values. The next section explains\n", |
249 |
| - "how to update these FP8-specific parameters.\n", |
250 |
| - "\n", |
251 |
| - "For models that use `jnp.einsum` operations, such as Mixture of Experts (MoE)\n", |
252 |
| - "layers, users can replace them with `fp8_ops.Fp8Einsum` to enable FP8\n", |
253 |
| - "quantization. Here's an example:" |
254 |
| - ] |
255 |
| - }, |
256 |
| - { |
257 |
| - "cell_type": "code", |
258 |
| - "execution_count": null, |
259 |
| - "id": "961b4549", |
260 |
| - "metadata": {}, |
261 |
| - "outputs": [], |
262 |
| - "source": [ |
263 |
| - "from typing import Any\n", |
264 |
| - "class FooModule(nn.Module):\n", |
265 |
| - " einsum: Any = None\n", |
266 |
| - " @nn.compact\n", |
267 |
| - " def __call__(self, a, b):\n", |
268 |
| - " if self.einsum is not None:\n", |
269 |
| - " einsum_fn = self.einsum()\n", |
270 |
| - " elif self.einsum is None:\n", |
271 |
| - " einsum_fn = jnp.einsum\n", |
272 |
| - " c = einsum_fn(\"mk,kn->mn\", a, b)\n", |
273 |
| - " return c\n", |
274 |
| - "\n", |
275 |
| - "model = FooModule(einsum=fp8_ops.Fp8Einsum)\n", |
276 |
| - "params = model.init(k0, a, b)\n", |
277 |
| - "\n", |
278 |
| - "@jax.jit\n", |
279 |
| - "def train_step(var, a, b):\n", |
280 |
| - " c = model.apply(var, a, b)\n", |
281 |
| - " return jnp.sum(c)\n", |
282 |
| - "\n", |
283 |
| - "check_fp8_call(train_step.lower(params, a, b))" |
284 |
| - ] |
285 |
| - }, |
286 | 249 | {
|
287 | 250 | "cell_type": "markdown",
|
288 | 251 | "id": "a83b0851",
|
289 | 252 | "metadata": {},
|
290 | 253 | "source": [
|
291 |
| - "## Manipulate FP8 params\n", |
292 |
| - "\n", |
293 |
| - "The following sections explain the internal FP8 parameters managed by\n", |
294 |
| - "`fp8_ops.Fp8DotGeneral` and `fp8_ops.Fp8Einsum`. These parameters\n", |
295 |
| - "include scaling factors and amax history values that control the FP8\n", |
296 |
| - "quantization process. While most users don't need to interact with these\n", |
297 |
| - "directly, understanding them can be valuable for advanced optimization and\n", |
298 |
| - "debugging.\n", |
| 254 | + "In this example, we simply set `dot_general_cls=fp8_ops.Fp8DotGeneralOp` to\n", |
| 255 | + "enable the Dense layer to utilize the FP8 dot operation. The usage of the model\n", |
| 256 | + "remains almost the same as before. The main difference is the addition of a new\n", |
| 257 | + "category of parameters: the sets of scaling factors and amax history. In the\n", |
| 258 | + "next section, we will explore how to update these parameters.\n", |
299 | 259 | "\n",
|
| 260 | + "## Manipulate FP8 params\n", |
300 | 261 | "Let's first examine the data structure of `params`. In the code below, we redact\n",
|
301 | 262 | "the parameter values and then display the PyTree structure."
|
302 | 263 | ]
|
|
324 | 285 | "The output is as follows:\n",
|
325 | 286 | "\n",
|
326 | 287 | "```plaintext\n",
|
327 |
| - "{'_overwrite_with_gradient': {'Fp8Einsum_0': {'input_amax_history': '*',\n", |
328 |
| - " 'input_scale': '*',\n", |
329 |
| - " 'kernel_amax_history': '*',\n", |
330 |
| - " 'kernel_scale': '*',\n", |
331 |
| - " 'output_grad_amax_history': '*',\n", |
332 |
| - " 'output_grad_scale': '*'}}}\n", |
| 288 | + "{'_overwrite_with_gradient': {'Fp8DotGeneralOp_0': {'input_amax_history': '*',\n", |
| 289 | + " 'input_scale': '*',\n", |
| 290 | + " 'kernel_amax_history': '*',\n", |
| 291 | + " 'kernel_scale': '*',\n", |
| 292 | + " 'output_grad_amax_history': '*',\n", |
| 293 | + " 'output_grad_scale': '*'}},\n", |
| 294 | + " 'params': {'bias': '*', 'kernel': '*'}}\n", |
333 | 295 | "```\n",
|
334 | 296 | "\n",
|
335 | 297 | "In addition to the expected `params`, there is an additional category called\n",
|
|
438 | 400 | "2.0 [5. 0. 0. ... 0. 0. 0.]\n",
|
439 | 401 | "```\n",
|
440 | 402 | "\n",
|
441 |
| - "This casting is already included if users choose to use the high-level APIs.\n", |
442 |
| - "\n", |
443 |
| - "## Deprecated APIs\n", |
444 |
| - "Previously, we provided APIs like `fp8_ops.quantize_dequantize` for current\n", |
445 |
| - "scaling and `fp8_ops.[in|out]_qdq` for delayed scaling. These were used with\n", |
446 |
| - "high precision dot operations, leveraging an XLA-FP8 feature that\n", |
447 |
| - "pattern-matched QDQ->dot sequences to Q->fp8_cublas_gemm. The corresponding\n", |
448 |
| - "high-level API was called `fp8_ops.Fp8DotGeneralOp`. However, this pattern\n", |
449 |
| - "matching-based solution proved brittle, as the patterns could be easily broken\n", |
450 |
| - "by other XLA optimizations. We recommend users migrate from these deprecated\n", |
451 |
| - "APIs to the newer ones described above.\n", |
452 |
| - "\n", |
453 |
| - "For migration, users should replace:\n", |
454 |
| - "* `fp8_ops.quantize_dequantize -> jnp.dot` with `fp8_ops.quantize -> jnp.dot ->\n", |
455 |
| - " fp8_ops.dequantize`\n", |
456 |
| - "* `fp8_ops.in_qdq -> jnp.dot -> fp8_ops.out_qdq` with `fp8_ops.in_q -> jnp.dot\n", |
457 |
| - " -> fp8_ops.out_dq`\n", |
458 |
| - "* `fp8_ops.Fp8DotGeneralOp` with `fp8_ops.Fp8DotGeneral`\n", |
459 |
| - "\n", |
460 |
| - "Additionally, we provide an einsum variant through `fp8_ops.Fp8Einsum`." |
| 403 | + "This casting is already included if users choose to use the high-level APIs." |
461 | 404 | ]
|
462 | 405 | }
|
463 | 406 | ],
|
|