From 05ef9c76fda4fd18310fddbc536c6bac40dbdffc Mon Sep 17 00:00:00 2001 From: Raymond Zou Date: Tue, 5 Nov 2024 19:30:18 +0000 Subject: [PATCH 01/28] Add sequence parallelism sharding q --- MaxText/common_types.py | 1 + MaxText/configs/base.yml | 25 ++++++++++++++---------- MaxText/layers/attentions.py | 38 +++++++++++++++++++++++------------- MaxText/max_utils.py | 23 ++++++++++++++++++++++ MaxText/maxtext_utils.py | 2 +- 5 files changed, 64 insertions(+), 25 deletions(-) diff --git a/MaxText/common_types.py b/MaxText/common_types.py index c96bcaeef..bdca4f380 100644 --- a/MaxText/common_types.py +++ b/MaxText/common_types.py @@ -36,6 +36,7 @@ BATCH = "activation_batch" LENGTH = "activation_length" +KV_LENGTH = "activation_kv_length" EMBED = "activation_embed" HEAD = "activation_heads" PREFILL_KV_BATCH = "activation_prefill_kv_batch" diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index e784535e6..a2ba4bc7e 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -228,15 +228,18 @@ jax_cache_dir: "~/jax_cache" hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess' and 'cpu' # Parallelism -mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'tensor_sequence', 'expert', 'autoregressive'] +mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor', 'tensor_sequence', 'expert', 'autoregressive'] logical_axis_rules: [ ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']], ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_heads', ['tensor','sequence','tensor_sequence']], ['activation_kv_heads', ['tensor','sequence','tensor_sequence']], - ['activation_length', ['sequence']], ['activation_norm_length', ['tensor_sequence', 'sequence']], + ['activation_length', ['sequence', 'context']], + ['activation_length', ['context']], + ['activation_length_q', ['context']], + ['activation_kv_length', []], ['activation_embed', 'tensor'], ['activation_mlp', ['tensor', 'tensor_sequence']], ['activation_kv', ['tensor', 'tensor_sequence']], @@ -251,13 +254,13 @@ logical_axis_rules: [ ['activation_exp', 'expert'], ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['vocab', ['tensor', 'tensor_sequence', 'autoregressive']], - ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']], - ['embed', ['fsdp', 'sequence', 'expert']], - ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence']], - ['embed_no_exp', ['fsdp', 'sequence']], - ['norm', ['tensor', 'tensor_sequence']], - ['q_heads', ['tensor', 'tensor_sequence', 'autoregressive']], - ['heads', ['tensor', 'tensor_sequence', 'autoregressive']], + ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], + ['embed', ['fsdp', 'sequence', 'context', 'expert']], + ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']], + ['embed_no_exp', ['fsdp', 'sequence', 'context']], + ['norm', 'tensor'], + ['q_heads', ['tensor', 'autoregressive']], + ['heads', ['tensor', 'autoregressive']], ['layers', 'stage'], ['kv', []], ['kv_heads', ['tensor', 'tensor_sequence', 'autoregressive']], @@ -270,7 +273,7 @@ logical_axis_rules: [ ['exp', 'expert'], ] # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details -data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'tensor_sequence', 'expert', 'autoregressive']] +data_sharding: [['data', 
'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor', 'tensor_sequence', 'expert', 'autoregressive']] # sharding tolerance: float between 0.0 and 1.0 representing the allowed percentage of non-sharded parameters. sharding_tolerance: 0.02 @@ -288,6 +291,7 @@ dcn_tensor_sequence_parallelism: 1 # never recommended dcn_pipeline_parallelism: 1 dcn_expert_parallelism: 1 dcn_autoregressive_parallelism: 1 # never recommended +dcn_context_parallelism: 1 ici_data_parallelism: 1 ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded ici_fsdp_transpose_parallelism: 1 @@ -297,6 +301,7 @@ ici_tensor_sequence_parallelism: 1 ici_autoregressive_parallelism: 1 ici_pipeline_parallelism: 1 ici_expert_parallelism: 1 +ici_context_parallelism: 1 # The number of TPU slices is automatically determined, you should not set this explicitly. For ahead of time compilation, # you should set compile_toplogy_num_slices, which will in turn set this value. For non-TPU environments this is set to 1. diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index f5990e7d9..7720c73db 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -34,6 +34,7 @@ from layers import initializers from layers import linears from layers import quantizations +import max_logging # pylint: disable=line-too-long, g-doc-args, g-doc-return-or-yield, bad-continuation, g-inconsistent-quotes @@ -64,6 +65,7 @@ class AttentionType(enum.Enum): PREFILL_KV_BATCH = common_types.PREFILL_KV_BATCH KV_BATCH = common_types.KV_BATCH LENGTH = common_types.LENGTH +KV_LENGTH = common_types.KV_LENGTH HEAD = common_types.HEAD EMBED = common_types.EMBED KV_HEAD = common_types.KV_HEAD @@ -141,7 +143,9 @@ class AttentionOp(nn.Module): float32_qk_product: bool = False max_prefill_predict_length: int = -1 float32_logits: bool = False - flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV) + # flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV) + flash_axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV) + flash_axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV) prefill_cache_logical_axis_names: AxisNames = (CACHE_BATCH_PREFILL, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV) cache_logical_axis_names: AxisNames = (CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV) cache_scale_logical_axis_names: AxisNames = (CACHE_SCALE_BATCH, CACHE_SCALE_SEQUENCE, CACHE_SCALE_HEADS, CACHE_SCALE_KV) @@ -304,9 +308,15 @@ def tpu_flash_attention( value = jnp.transpose(value, axes=(0, 2, 1, 3)) if decoder_segment_ids is not None: - decoder_segment_ids = splash_attention_kernel.SegmentIds(decoder_segment_ids, decoder_segment_ids) - axis_names = nn.logical_to_mesh_axes(self.flash_axis_names) - segment_axis_names = nn.logical_to_mesh_axes((BATCH, "activation_length_no_heads")) + decoder_segment_ids = splash_attention_kernel.SegmentIds(decoder_segment_ids) + axis_names_q = nn.logical_to_mesh_axes(self.flash_axis_names_q) + axis_names_kv = nn.logical_to_mesh_axes(self.flash_axis_names_kv) + max_logging.log(f'axis_names_q: {axis_names_q}') + max_logging.log(f'axis_names_kv: {axis_names_kv}') + segment_axis_names_q = nn.logical_to_mesh_axes((BATCH, "activation_length_q")) + segment_axis_names_kv = nn.logical_to_mesh_axes((BATCH, "activation_length_kv")) + + segment_axis_names = [{"q": segment_axis_names_q, "kv": segment_axis_names_kv}] global_block_q = self.config.sa_block_q global_block_kv = self.config.sa_block_kv @@ -325,19 +335,19 @@ def tpu_flash_attention( shard_map, mesh=self.mesh, in_specs=( - axis_names, - 
axis_names, - axis_names, + axis_names_q, + axis_names_kv, + axis_names_kv, segment_axis_names, ), - out_specs=axis_names, + out_specs=axis_names_q, check_rep=False, ) def wrap_flash_attention(query, key, value, decoder_segment_ids): - if decoder_segment_ids is not None: - assert ( - query.shape[2] == decoder_segment_ids.q.shape[1] - ), "Sharding along sequence dimension not allowed in tpu kernel attention" + # if decoder_segment_ids is not None: + # assert ( + # query.shape[2] == decoder_segment_ids.q.shape[1] + # ), "Sharding along sequence dimension not allowed in tpu kernel attention" block_sizes = splash_attention_kernel.BlockSizes( block_q=min(global_block_q, query.shape[2]), block_kv=min(global_block_kv, key.shape[2]), @@ -375,13 +385,13 @@ def wrap_flash_attention(query, key, value, decoder_segment_ids): attn_logits_soft_cap=attn_logits_soft_cap, ) - return jax.vmap(splash_kernel)(query, key, value, segment_ids=decoder_segment_ids) + return jax.vmap(splash_kernel)(query, key, value, segment_ids=decoder_segment_ids[0]) devices_in_data_fsdp = self.mesh.shape["data"] * self.mesh.shape["fsdp"] assert (query.shape[0] / devices_in_data_fsdp).is_integer(), ( "Batch dimension should be shardable among the devices in data and fsdp" " axis" ) - x = wrap_flash_attention(query, key, value, decoder_segment_ids) + x = wrap_flash_attention(query, key, value, [decoder_segment_ids]) x = jnp.transpose(x, axes=(0, 2, 1, 3)) return x diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index fb41b8347..92100375d 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -573,6 +573,29 @@ def create_device_mesh(config, devices=None): multi_slice_env = num_slices > 1 + dcn_parallelism = [ + config.dcn_data_parallelism, + config.dcn_pipeline_parallelism, + config.dcn_fsdp_parallelism, + config.dcn_fsdp_transpose_parallelism, + config.dcn_sequence_parallelism, + config.dcn_context_parallelism, + config.dcn_tensor_parallelism, + config.dcn_expert_parallelism, + config.dcn_autoregressive_parallelism, + ] + ici_parallelism = [ + config.ici_data_parallelism, + config.ici_pipeline_parallelism, + config.ici_fsdp_parallelism, + config.ici_fsdp_transpose_parallelism, + config.ici_sequence_parallelism, + config.ici_context_parallelism, + config.ici_tensor_parallelism, + config.ici_expert_parallelism, + config.ici_autoregressive_parallelism, + ] + # Find possible unspecified parallelisms ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI") diff --git a/MaxText/maxtext_utils.py b/MaxText/maxtext_utils.py index 64caf0613..46a850364 100644 --- a/MaxText/maxtext_utils.py +++ b/MaxText/maxtext_utils.py @@ -239,7 +239,7 @@ def assert_params_sufficiently_sharded(params, mesh, tolerance): """ total_num_params = max_utils.calculate_num_params_from_pytree(params) product_num_devices_for_weight_sharding = 1 - for axis in ["fsdp", "fsdp_transpose", "sequence", "tensor", "tensor_sequence", "stage", "expert"]: + for axis in ["fsdp", "fsdp_transpose", "sequence", "context", "tensor", "tensor_sequence", "stage", "expert"]: product_num_devices_for_weight_sharding *= mesh.shape[axis] total_num_params_per_chip = max_utils.calculate_total_params_per_chip(params) perfectly_sharded_params_per_chip = total_num_params / product_num_devices_for_weight_sharding From c87ecf944fb3920b3efb419b698f303b55be630f Mon Sep 17 00:00:00 2001 From: Raymond Zou Date: Tue, 5 Nov 2024 21:27:20 +0000 Subject: [PATCH 02/28] Add segmentid implementation --- MaxText/layers/attentions.py | 34 
++++++++++++++++------------------ MaxText/max_utils.py | 23 ----------------------- MaxText/pyconfig.py | 2 ++ 3 files changed, 18 insertions(+), 41 deletions(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 7720c73db..d49448a13 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -143,7 +143,6 @@ class AttentionOp(nn.Module): float32_qk_product: bool = False max_prefill_predict_length: int = -1 float32_logits: bool = False - # flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV) flash_axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV) flash_axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV) prefill_cache_logical_axis_names: AxisNames = (CACHE_BATCH_PREFILL, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV) @@ -306,18 +305,16 @@ def tpu_flash_attention( query = jnp.transpose(query, axes=(0, 2, 1, 3)) key = jnp.transpose(key, axes=(0, 2, 1, 3)) value = jnp.transpose(value, axes=(0, 2, 1, 3)) - + segment_axis_names_q = None + segment_axis_names_kv = None if decoder_segment_ids is not None: - decoder_segment_ids = splash_attention_kernel.SegmentIds(decoder_segment_ids) + segment_axis_names_q = nn.logical_to_mesh_axes((BATCH, "activation_length_q")) + segment_axis_names_kv = nn.logical_to_mesh_axes((BATCH, "activation_length_kv")) axis_names_q = nn.logical_to_mesh_axes(self.flash_axis_names_q) axis_names_kv = nn.logical_to_mesh_axes(self.flash_axis_names_kv) max_logging.log(f'axis_names_q: {axis_names_q}') max_logging.log(f'axis_names_kv: {axis_names_kv}') - segment_axis_names_q = nn.logical_to_mesh_axes((BATCH, "activation_length_q")) - segment_axis_names_kv = nn.logical_to_mesh_axes((BATCH, "activation_length_kv")) - - segment_axis_names = [{"q": segment_axis_names_q, "kv": segment_axis_names_kv}] - + global_block_q = self.config.sa_block_q global_block_kv = self.config.sa_block_kv global_block_kv_compute = self.config.sa_block_kv_compute @@ -338,16 +335,13 @@ def tpu_flash_attention( axis_names_q, axis_names_kv, axis_names_kv, - segment_axis_names, + segment_axis_names_q, + segment_axis_names_kv, ), out_specs=axis_names_q, check_rep=False, ) - def wrap_flash_attention(query, key, value, decoder_segment_ids): - # if decoder_segment_ids is not None: - # assert ( - # query.shape[2] == decoder_segment_ids.q.shape[1] - # ), "Sharding along sequence dimension not allowed in tpu kernel attention" + def wrap_flash_attention(query, key, value, decoder_segment_ids_q, decoder_segment_ids_kv): block_sizes = splash_attention_kernel.BlockSizes( block_q=min(global_block_q, query.shape[2]), block_kv=min(global_block_kv, key.shape[2]), @@ -363,14 +357,14 @@ def wrap_flash_attention(query, key, value, decoder_segment_ids): v_layout=splash_attention_kernel.QKVLayout[global_v_layout], ) - mask = splash_attention_mask.CausalMask(shape=(query.shape[2], query.shape[2])) + mask = splash_attention_mask.CausalMask(shape=(query.shape[2], key.shape[2])) # Apply local masking if local sliding attention is enabled. 
if self.attention_type == AttentionType.LOCAL_SLIDING: if self.sliding_window_size is None: raise ValueError("Sliding_window_size must be set if Local Sliding attention type") mask &= splash_attention_mask.LocalMask( - shape=(query.shape[2], query.shape[2]), + shape=(query.shape[2], key.shape[2]), window_size=(self.sliding_window_size, self.sliding_window_size), offset=0, ) @@ -385,13 +379,17 @@ def wrap_flash_attention(query, key, value, decoder_segment_ids): attn_logits_soft_cap=attn_logits_soft_cap, ) - return jax.vmap(splash_kernel)(query, key, value, segment_ids=decoder_segment_ids[0]) + if decoder_segment_ids_q is not None: + decoder_segment_ids_tuple = splash_attention_kernel.SegmentIds(decoder_segment_ids_q, decoder_segment_ids_kv) + else: + decoder_segment_ids_tuple = None + return jax.vmap(splash_kernel)(query, key, value, segment_ids=decoder_segment_ids_tuple) devices_in_data_fsdp = self.mesh.shape["data"] * self.mesh.shape["fsdp"] assert (query.shape[0] / devices_in_data_fsdp).is_integer(), ( "Batch dimension should be shardable among the devices in data and fsdp" " axis" ) - x = wrap_flash_attention(query, key, value, [decoder_segment_ids]) + x = wrap_flash_attention(query, key, value, decoder_segment_ids, decoder_segment_ids) x = jnp.transpose(x, axes=(0, 2, 1, 3)) return x diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index 92100375d..fb41b8347 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -573,29 +573,6 @@ def create_device_mesh(config, devices=None): multi_slice_env = num_slices > 1 - dcn_parallelism = [ - config.dcn_data_parallelism, - config.dcn_pipeline_parallelism, - config.dcn_fsdp_parallelism, - config.dcn_fsdp_transpose_parallelism, - config.dcn_sequence_parallelism, - config.dcn_context_parallelism, - config.dcn_tensor_parallelism, - config.dcn_expert_parallelism, - config.dcn_autoregressive_parallelism, - ] - ici_parallelism = [ - config.ici_data_parallelism, - config.ici_pipeline_parallelism, - config.ici_fsdp_parallelism, - config.ici_fsdp_transpose_parallelism, - config.ici_sequence_parallelism, - config.ici_context_parallelism, - config.ici_tensor_parallelism, - config.ici_expert_parallelism, - config.ici_autoregressive_parallelism, - ] - # Find possible unspecified parallelisms ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI") diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index fd0d17756..a26da6e7b 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -511,6 +511,7 @@ def create_parallelisms_list(raw_keys): raw_keys["ici_fsdp_parallelism"], raw_keys["ici_fsdp_transpose_parallelism"], raw_keys["ici_sequence_parallelism"], + raw_keys["ici_context_parallelism"], raw_keys["ici_tensor_parallelism"], raw_keys["ici_tensor_sequence_parallelism"], raw_keys["ici_expert_parallelism"], @@ -522,6 +523,7 @@ def create_parallelisms_list(raw_keys): raw_keys["dcn_fsdp_parallelism"], raw_keys["dcn_fsdp_transpose_parallelism"], raw_keys["dcn_sequence_parallelism"], + raw_keys["dcn_context_parallelism"], raw_keys["dcn_tensor_parallelism"], raw_keys["dcn_tensor_sequence_parallelism"], raw_keys["dcn_expert_parallelism"], From a965a1356c69f0ed6e55d4b39a28b2d88a41e467 Mon Sep 17 00:00:00 2001 From: A9isha Date: Fri, 3 Jan 2025 17:28:19 +0000 Subject: [PATCH 03/28] add debug statements --- MaxText/train.py | 60 +++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 59 insertions(+), 1 deletion(-) diff --git a/MaxText/train.py b/MaxText/train.py index 63cfa137e..13ef6daf6 
100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -19,7 +19,7 @@ # Calling jax.device_count here prevents a "TPU platform already registered" error. # See github.com/google/maxtext/issues/20 for more - +import pdb import datetime import os import sys @@ -68,6 +68,7 @@ from ml_goodput_measurement import monitoring # pylint: disable=too-many-positional-arguments +jax.config.update("jax_debug_nans", True) Transformer = models.Transformer EPS = 1e-8 @@ -517,6 +518,63 @@ def reshape_to_microbatch_accumulations(batch_arr): extra_dpo_args = [reference_params] grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True) (loss, aux), raw_grads = grad_func(model, config, data, dropout_rng, state.params, *extra_dpo_args, is_train=True) + # pdb.set_trace() + def print_dict(d, indent_level=0): + for k, v in d.items(): + indent = " " * indent_level + if isinstance(v, dict): + # jax.debug.print("{indent}Key: {key}, Value: (nested dict)", indent=indent, key=k) + print_dict(v, indent_level + 1) + else: + jax.debug.print("{indent}Key: {key}, Value shape: {value_shape}", indent=indent, key=k, value_shape=v.shape) + # pdb.set_trace() + print(f"key = {k}") + jax.lax.cond(jnp.any(jnp.isnan(v)), lambda v: jax.debug.print(" NaN found: value: {val}",val = v), lambda v: jax.debug.print(" is clean"), jnp.inf) + jax.lax.cond(jnp.any(jnp.isinf(v)), lambda v: jax.debug.print(" Inf found: value: {val}",val = v), lambda v: jax.debug.print(" is clean"), jnp.inf) + # jax.lax.cond(jnp.any(jnp.isnan(v)), lambda x: jax.debug.print("NaN found in key: {key}, value: {val}",key = x[0],val = x[1]), lambda x: jax.debug.print("{key} is clean", key = x[0]), (k,v)) + # if jnp.any(jnp.isnan(v)): + # jax.debug.print(f"NaN found in key: {key}, value: {x[1]}") + # else: + # jax.debug.print(f"{x[0]} is clean"), (key,value) + + # print_dict(raw_grads) + # x = 0.0/0.0 + # jax.tree_util.tree_map(lambda x: jax.debug.print("Nan found = {bool_result}", bool_result=jnp.any(jnp.isnan(x))), {"x":jnp.inf, "y": 1, "z":jnp.nan}) + # jax.tree_util.tree_map(lambda x: jax.debug.print("Inf found = {bool_result}", bool_result=jnp.any(jnp.isinf(x))), {"x":jnp.inf, "y": 1, "z":jnp.nan}) + # jax.tree_util.tree_map(lambda x: jax.debug.print("Nan found = {bool_result}", bool_result=jnp.any(jnp.isnan(x))), jax.tree_leaves({"x":jnp.inf, "y": 1, "z":jnp.nan})) + # def check_nested_jax_dict_for_nans(nested_dict): + # """ + # Recursively checks a nested dictionary containing JAX ndarrays for NaN values. + + # Args: + # nested_dict: A dictionary that can potentially contain other dictionaries + # or JAX ndarrays (jnp.ndarray). + + # Returns: + # True if any NaN is found, False otherwise. 
+ # """ + + # for key, value in nested_dict.items(): + # if isinstance(value, jnp.ndarray): + # jax.lax.cond(jnp.any(jnp.isnan(value)), lambda x: jax.debug.print(f"NaN found in key: {x[0]}, value: {x[1]}"), lambda x: jax.debug.print(f"{x[0]} is clean"), (key,value)) + # # if jnp.any(jnp.isnan(value)): + # # jax.debug.print(f"NaN found in key: {key}, value: {value}") + # # return True + # elif isinstance(value, dict): + # jax.lax.cond(check_nested_jax_dict_for_nans(value), lambda key: jax.debug.print(f"NaN found in nested key: {key}"), lambda key: jax.debug.print(f"{key} is clean"), key) + # # if check_nested_jax_dict_for_nans(value): # Recursive call + # # jax.debug.print(f"NaN found in nested key: {key}") + # # return True + + # return False # No NaNs found + # key = 'raw_grads["params"]["token_embedder"]["embedding"]' + # jax.debug.print("Key: {key}, Value: {value}", key=key, value = raw_grads["params"]["token_embedder"]['embedding']) + # pdb.set_trace() + # check_nested_jax_dict_for_nans(raw_grads) + # isnan = jnp.any(jnp.isnan(raw_grads)) + # jax.debug.print(f"{isnan=}") + # jax.tree_util.tree_map(lambda x: arr / grad_and_loss["total_weights"], raw_grads) + intermediate_outputs = aux["intermediate_outputs"] total_weights = aux["total_weights"] moe_lb_loss = aux["moe_lb_loss"] From 61d5ef8e0c392a9956e61f353b0eda575f5f2b42 Mon Sep 17 00:00:00 2001 From: Raymond Zou Date: Tue, 5 Nov 2024 21:27:20 +0000 Subject: [PATCH 04/28] Add segmentid implementation --- MaxText/train.py | 87 +++++++++++++++++------------------------------- 1 file changed, 30 insertions(+), 57 deletions(-) diff --git a/MaxText/train.py b/MaxText/train.py index 13ef6daf6..b45f1c30e 100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -518,63 +518,36 @@ def reshape_to_microbatch_accumulations(batch_arr): extra_dpo_args = [reference_params] grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True) (loss, aux), raw_grads = grad_func(model, config, data, dropout_rng, state.params, *extra_dpo_args, is_train=True) - # pdb.set_trace() - def print_dict(d, indent_level=0): - for k, v in d.items(): - indent = " " * indent_level - if isinstance(v, dict): - # jax.debug.print("{indent}Key: {key}, Value: (nested dict)", indent=indent, key=k) - print_dict(v, indent_level + 1) - else: - jax.debug.print("{indent}Key: {key}, Value shape: {value_shape}", indent=indent, key=k, value_shape=v.shape) - # pdb.set_trace() - print(f"key = {k}") - jax.lax.cond(jnp.any(jnp.isnan(v)), lambda v: jax.debug.print(" NaN found: value: {val}",val = v), lambda v: jax.debug.print(" is clean"), jnp.inf) - jax.lax.cond(jnp.any(jnp.isinf(v)), lambda v: jax.debug.print(" Inf found: value: {val}",val = v), lambda v: jax.debug.print(" is clean"), jnp.inf) - # jax.lax.cond(jnp.any(jnp.isnan(v)), lambda x: jax.debug.print("NaN found in key: {key}, value: {val}",key = x[0],val = x[1]), lambda x: jax.debug.print("{key} is clean", key = x[0]), (k,v)) - # if jnp.any(jnp.isnan(v)): - # jax.debug.print(f"NaN found in key: {key}, value: {x[1]}") - # else: - # jax.debug.print(f"{x[0]} is clean"), (key,value) - - # print_dict(raw_grads) - # x = 0.0/0.0 - # jax.tree_util.tree_map(lambda x: jax.debug.print("Nan found = {bool_result}", bool_result=jnp.any(jnp.isnan(x))), {"x":jnp.inf, "y": 1, "z":jnp.nan}) - # jax.tree_util.tree_map(lambda x: jax.debug.print("Inf found = {bool_result}", bool_result=jnp.any(jnp.isinf(x))), {"x":jnp.inf, "y": 1, "z":jnp.nan}) - # jax.tree_util.tree_map(lambda x: jax.debug.print("Nan found = {bool_result}", 
bool_result=jnp.any(jnp.isnan(x))), jax.tree_leaves({"x":jnp.inf, "y": 1, "z":jnp.nan})) - # def check_nested_jax_dict_for_nans(nested_dict): - # """ - # Recursively checks a nested dictionary containing JAX ndarrays for NaN values. - - # Args: - # nested_dict: A dictionary that can potentially contain other dictionaries - # or JAX ndarrays (jnp.ndarray). - - # Returns: - # True if any NaN is found, False otherwise. - # """ - - # for key, value in nested_dict.items(): - # if isinstance(value, jnp.ndarray): - # jax.lax.cond(jnp.any(jnp.isnan(value)), lambda x: jax.debug.print(f"NaN found in key: {x[0]}, value: {x[1]}"), lambda x: jax.debug.print(f"{x[0]} is clean"), (key,value)) - # # if jnp.any(jnp.isnan(value)): - # # jax.debug.print(f"NaN found in key: {key}, value: {value}") - # # return True - # elif isinstance(value, dict): - # jax.lax.cond(check_nested_jax_dict_for_nans(value), lambda key: jax.debug.print(f"NaN found in nested key: {key}"), lambda key: jax.debug.print(f"{key} is clean"), key) - # # if check_nested_jax_dict_for_nans(value): # Recursive call - # # jax.debug.print(f"NaN found in nested key: {key}") - # # return True - - # return False # No NaNs found - # key = 'raw_grads["params"]["token_embedder"]["embedding"]' - # jax.debug.print("Key: {key}, Value: {value}", key=key, value = raw_grads["params"]["token_embedder"]['embedding']) - # pdb.set_trace() - # check_nested_jax_dict_for_nans(raw_grads) - # isnan = jnp.any(jnp.isnan(raw_grads)) - # jax.debug.print(f"{isnan=}") - # jax.tree_util.tree_map(lambda x: arr / grad_and_loss["total_weights"], raw_grads) - + + # isinf = jnp.any(jnp.isinf(raw_grads["params"]["decoder"]["decoder_norm"]["scale"])) + # isinf2 = jnp.any(jnp.isinf(raw_grads["params"]["decoder"]["layers"]["mlp"]["wi_0"]["kernel"])) # + # isinf3 = jnp.any(jnp.isinf(raw_grads["params"]["decoder"]["layers"]["mlp"]["wi_1"]["kernel"])) # + # isinf4 = jnp.any(jnp.isinf(raw_grads["params"]["decoder"]["layers"]["mlp"]["wo"]["kernel"])) # + # isinf5 = jnp.any(jnp.isinf(raw_grads["params"]["decoder"]["layers"]["post_self_attention_layer_norm"]["scale"])) # + # isinf6 = jnp.any(jnp.isinf(raw_grads["params"]["decoder"]["layers"]["pre_self_attention_layer_norm"]["scale"])) # + # isinf7 = jnp.any(jnp.isinf(raw_grads["params"]["decoder"]["layers"]["self_attention"]["key"]["kernel"])) # + # isinf8 = jnp.any(jnp.isinf(raw_grads["params"]["decoder"]["layers"]["self_attention"]["out"]["kernel"])) # + # isinf9 = jnp.any(jnp.isinf(raw_grads["params"]["decoder"]["layers"]["self_attention"]["query"]["kernel"])) # + # isinf10 = jnp.any(jnp.isinf(raw_grads["params"]["decoder"]["layers"]["self_attention"]["value"]["kernel"])) # + # isinf11 = jnp.any(jnp.isinf(raw_grads["params"]["decoder"]["logits_dense"]["kernel"])) # + # isinf12 = jnp.any(jnp.isinf(raw_grads["params"]["token_embedder"]["embedding"])) + + # jax.debug.print("debug isinf raw_grads 1: {x}", x=isinf) + # jax.debug.print("debug isinf raw_grads 2: {x}", x=isinf2) + # jax.debug.print("debug isinf raw_grads 3: {x}", x=isinf3) + # jax.debug.print("debug isinf raw_grads 4: {x}", x=isinf4) + # jax.debug.print("debug isinf raw_grads 5: {x}", x=isinf5) + # jax.debug.print("debug isinf raw_grads 6: {x}", x=isinf6) + # jax.debug.print("debug isinf raw_grads 7: {x}", x=isinf7) + # jax.debug.print("debug isinf raw_grads 8: {x}", x=isinf8) + # jax.debug.print("debug isinf raw_grads 9: {x}", x=isinf9) + # jax.debug.print("debug isinf raw_grads 10: {x}", x=isinf10) + # jax.debug.print("debug isinf raw_grads 11: {x}", x=isinf11) + # 
jax.debug.print("debug isinf raw_grads 12: {x}", x=isinf12) + + # isinf_l2norm = jnp.any(jnp.isinf(max_utils.l2norm_pytree(raw_grads))) + # jax.debug.print("debug isinf l2norm raw_grads: {x}", x=isinf_l2norm) + intermediate_outputs = aux["intermediate_outputs"] total_weights = aux["total_weights"] moe_lb_loss = aux["moe_lb_loss"] From d01af2e8e8babac4c240d9783d6ad07a93356133 Mon Sep 17 00:00:00 2001 From: A9isha Date: Sat, 4 Jan 2025 01:56:15 +0000 Subject: [PATCH 05/28] temp commit to fix splash attention kernel call --- MaxText/common_types.py | 2 +- MaxText/configs/base.yml | 2 +- MaxText/layers/attentions.py | 116 ++++++++++++++++++++++------------- 3 files changed, 74 insertions(+), 46 deletions(-) diff --git a/MaxText/common_types.py b/MaxText/common_types.py index bdca4f380..fba7716f5 100644 --- a/MaxText/common_types.py +++ b/MaxText/common_types.py @@ -36,7 +36,7 @@ BATCH = "activation_batch" LENGTH = "activation_length" -KV_LENGTH = "activation_kv_length" +KV_LENGTH = "activation_length_kv" EMBED = "activation_embed" HEAD = "activation_heads" PREFILL_KV_BATCH = "activation_prefill_kv_batch" diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index a2ba4bc7e..0b2250b5d 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -239,7 +239,7 @@ logical_axis_rules: [ ['activation_length', ['sequence', 'context']], ['activation_length', ['context']], ['activation_length_q', ['context']], - ['activation_kv_length', []], + ['activation_length_kv', []], ['activation_embed', 'tensor'], ['activation_mlp', ['tensor', 'tensor_sequence']], ['activation_kv', ['tensor', 'tensor_sequence']], diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index d49448a13..807443593 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -35,6 +35,7 @@ from layers import linears from layers import quantizations import max_logging +import pdb # pylint: disable=line-too-long, g-doc-args, g-doc-return-or-yield, bad-continuation, g-inconsistent-quotes @@ -145,6 +146,7 @@ class AttentionOp(nn.Module): float32_logits: bool = False flash_axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV) flash_axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV) + flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH) prefill_cache_logical_axis_names: AxisNames = (CACHE_BATCH_PREFILL, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV) cache_logical_axis_names: AxisNames = (CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV) cache_scale_logical_axis_names: AxisNames = (CACHE_SCALE_BATCH, CACHE_SCALE_SEQUENCE, CACHE_SCALE_HEADS, CACHE_SCALE_KV) @@ -310,10 +312,14 @@ def tpu_flash_attention( if decoder_segment_ids is not None: segment_axis_names_q = nn.logical_to_mesh_axes((BATCH, "activation_length_q")) segment_axis_names_kv = nn.logical_to_mesh_axes((BATCH, "activation_length_kv")) + #Anisha + # does segment_axis_names_splash_kernel also need to be inside `if decoder_segment_ids is not None:`? 
+ axis_names_splash_kernel = nn.logical_to_mesh_axes(self.flash_axis_names_splash_kernel) axis_names_q = nn.logical_to_mesh_axes(self.flash_axis_names_q) axis_names_kv = nn.logical_to_mesh_axes(self.flash_axis_names_kv) max_logging.log(f'axis_names_q: {axis_names_q}') max_logging.log(f'axis_names_kv: {axis_names_kv}') + max_logging.log(f'axis_names_splash_kernel: {axis_names_splash_kernel}') global_block_q = self.config.sa_block_q global_block_kv = self.config.sa_block_kv @@ -328,6 +334,56 @@ def tpu_flash_attention( global_k_layout = self.config.sa_k_layout global_v_layout = self.config.sa_v_layout + + devices_in_data_fsdp = self.mesh.shape["data"] * self.mesh.shape["fsdp"] + assert (query.shape[0] / devices_in_data_fsdp).is_integer(), ( + "Batch dimension should be shardable among the devices in data and fsdp" " axis" + ) + + #Anisha + #create_splash_attention kernel + + block_sizes = splash_attention_kernel.BlockSizes( + block_q=min(global_block_q, query.shape[2]), + block_kv=min(global_block_kv, key.shape[2]), + block_kv_compute=min(global_block_kv_compute, key.shape[2]), + block_q_dkv=min(global_block_q_dkv, query.shape[2]), + block_kv_dkv=min(global_block_kv_dkv, key.shape[2]), + block_kv_dkv_compute=min(global_block_kv_dkv_compute, query.shape[2]), + block_q_dq=None if global_use_fused_bwd_kernel else min(global_block_q_dq, query.shape[2]), + block_kv_dq=None if global_use_fused_bwd_kernel else min(global_block_kv_dq, query.shape[2]), + use_fused_bwd_kernel=global_use_fused_bwd_kernel, + q_layout=splash_attention_kernel.QKVLayout[global_q_layout], + k_layout=splash_attention_kernel.QKVLayout[global_k_layout], + v_layout=splash_attention_kernel.QKVLayout[global_v_layout], + ) + # jax.debug.print("query.shape = {qs}, key.shape = {ks}", qs = query.shape, ks = key.shape) + mask = splash_attention_mask.CausalMask(shape=(query.shape[2], key.shape[2])) + + # Apply local masking if local sliding attention is enabled. 
+ if self.attention_type == AttentionType.LOCAL_SLIDING: + if self.sliding_window_size is None: + raise ValueError("Sliding_window_size must be set if Local Sliding attention type") + mask &= splash_attention_mask.LocalMask( + shape=(query.shape[2], key.shape[2]), + window_size=(self.sliding_window_size, self.sliding_window_size), + offset=0, + ) + + # Create multi-head mask + multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) + splash_kernel = splash_attention_kernel.make_splash_mha( + mask=multi_head_mask, + head_shards=1, + q_seq_shards=1, #seq shard + block_sizes=block_sizes, + attn_logits_soft_cap=attn_logits_soft_cap, + ) + + named_sharding = jax.sharding.NamedSharding(self.mesh, axis_names_splash_kernel) + segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding) + + @functools.partial( shard_map, mesh=self.mesh, @@ -337,59 +393,31 @@ def tpu_flash_attention( axis_names_kv, segment_axis_names_q, segment_axis_names_kv, + segment_axis_names_splash_kernel, #TODO: add the manual sharding, ), out_specs=axis_names_q, check_rep=False, ) - def wrap_flash_attention(query, key, value, decoder_segment_ids_q, decoder_segment_ids_kv): - block_sizes = splash_attention_kernel.BlockSizes( - block_q=min(global_block_q, query.shape[2]), - block_kv=min(global_block_kv, key.shape[2]), - block_kv_compute=min(global_block_kv_compute, key.shape[2]), - block_q_dkv=min(global_block_q_dkv, query.shape[2]), - block_kv_dkv=min(global_block_kv_dkv, key.shape[2]), - block_kv_dkv_compute=min(global_block_kv_dkv_compute, query.shape[2]), - block_q_dq=None if global_use_fused_bwd_kernel else min(global_block_q_dq, query.shape[2]), - block_kv_dq=None if global_use_fused_bwd_kernel else min(global_block_kv_dq, query.shape[2]), - use_fused_bwd_kernel=global_use_fused_bwd_kernel, - q_layout=splash_attention_kernel.QKVLayout[global_q_layout], - k_layout=splash_attention_kernel.QKVLayout[global_k_layout], - v_layout=splash_attention_kernel.QKVLayout[global_v_layout], - ) - - mask = splash_attention_mask.CausalMask(shape=(query.shape[2], key.shape[2])) - - # Apply local masking if local sliding attention is enabled. 
- if self.attention_type == AttentionType.LOCAL_SLIDING: - if self.sliding_window_size is None: - raise ValueError("Sliding_window_size must be set if Local Sliding attention type") - mask &= splash_attention_mask.LocalMask( - shape=(query.shape[2], key.shape[2]), - window_size=(self.sliding_window_size, self.sliding_window_size), - offset=0, - ) - - # Create multi-head mask - multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) - splash_kernel = splash_attention_kernel.make_splash_mha( - mask=multi_head_mask, - head_shards=1, - q_seq_shards=1, - block_sizes=block_sizes, - attn_logits_soft_cap=attn_logits_soft_cap, - ) + def wrap_flash_attention(query, key, value, decoder_segment_ids_q, decoder_segment_ids_kv, splash_kernel): if decoder_segment_ids_q is not None: decoder_segment_ids_tuple = splash_attention_kernel.SegmentIds(decoder_segment_ids_q, decoder_segment_ids_kv) else: decoder_segment_ids_tuple = None - return jax.vmap(splash_kernel)(query, key, value, segment_ids=decoder_segment_ids_tuple) - - devices_in_data_fsdp = self.mesh.shape["data"] * self.mesh.shape["fsdp"] - assert (query.shape[0] / devices_in_data_fsdp).is_integer(), ( - "Batch dimension should be shardable among the devices in data and fsdp" " axis" - ) - x = wrap_flash_attention(query, key, value, decoder_segment_ids, decoder_segment_ids) + attention_output = jax.vmap(splash_kernel)(query, key, value, segment_ids=decoder_segment_ids_tuple) + # pdb.set_trace() + # jax.debug.print("attention_output.shape = {ash}", ash = attention_output.shape) + # full_mask = [per_head_mask for per_head_mask in multi_head_mask.masks] + # valid_tokens = multi_head_mask.masks.any(dim=-1) # [q_sl] -> [q_sl, 1] -> [q_sl, head_dim] + # valid_tokens = decoder_segment_ids_q & multi_head_mask.masks.any(dim=-1) + # attention_output = attention_output * valid_tokens # broadcasting along head_dim + + return attention_output + + x = wrap_flash_attention(query, key, value, decoder_segment_ids, decoder_segment_ids, splash_kernel) + # pdb.set_trace() + # jax.debug.print("{x}", x=x) + x = jnp.transpose(x, axes=(0, 2, 1, 3)) return x From a0e093efcde7dca5615ba5b23f30edc413cce74c Mon Sep 17 00:00:00 2001 From: A9isha Date: Mon, 6 Jan 2025 23:23:13 +0000 Subject: [PATCH 06/28] fix seq sharding parameter in make_splash_mha --- MaxText/layers/attentions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 807443593..79481b9d1 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -369,13 +369,13 @@ def tpu_flash_attention( window_size=(self.sliding_window_size, self.sliding_window_size), offset=0, ) - + # pdb.set_trace() # Create multi-head mask multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) splash_kernel = splash_attention_kernel.make_splash_mha( mask=multi_head_mask, head_shards=1, - q_seq_shards=1, #seq shard + q_seq_shards=int(query.shape[2]/global_block_q), #seq shard block_sizes=block_sizes, attn_logits_soft_cap=attn_logits_soft_cap, ) From ce2c42fb90a93931dbfcdded96c805f5e98e8613 Mon Sep 17 00:00:00 2001 From: A9isha Date: Tue, 7 Jan 2025 03:39:40 +0000 Subject: [PATCH 07/28] update --- MaxText/layers/attentions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 79481b9d1..393bd029a 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -374,7 +374,7 @@ def 
tpu_flash_attention( multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) splash_kernel = splash_attention_kernel.make_splash_mha( mask=multi_head_mask, - head_shards=1, + head_shards=1, # we would need to change this to the size of the axis if sharding over heads q_seq_shards=int(query.shape[2]/global_block_q), #seq shard block_sizes=block_sizes, attn_logits_soft_cap=attn_logits_soft_cap, From ece53a55064564cf9c47a9a764b242b1848f4200 Mon Sep 17 00:00:00 2001 From: A9isha Date: Tue, 7 Jan 2025 22:51:04 +0000 Subject: [PATCH 08/28] fix the -1 value for context parallelism --- MaxText/layers/attentions.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 393bd029a..88159ac69 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -340,9 +340,7 @@ def tpu_flash_attention( "Batch dimension should be shardable among the devices in data and fsdp" " axis" ) - #Anisha #create_splash_attention kernel - block_sizes = splash_attention_kernel.BlockSizes( block_q=min(global_block_q, query.shape[2]), block_kv=min(global_block_kv, key.shape[2]), @@ -369,13 +367,13 @@ def tpu_flash_attention( window_size=(self.sliding_window_size, self.sliding_window_size), offset=0, ) - # pdb.set_trace() + # Create multi-head mask multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) splash_kernel = splash_attention_kernel.make_splash_mha( mask=multi_head_mask, head_shards=1, # we would need to change this to the size of the axis if sharding over heads - q_seq_shards=int(query.shape[2]/global_block_q), #seq shard + q_seq_shards=self.mesh.shape["context"], #axis for sequence sharding block_sizes=block_sizes, attn_logits_soft_cap=attn_logits_soft_cap, ) From 54e9a55c1dfe65cbc3124c95bed6c073f0045304 Mon Sep 17 00:00:00 2001 From: A9isha Date: Wed, 8 Jan 2025 21:54:13 +0000 Subject: [PATCH 09/28] add load_balancing in causal mask --- MaxText/layers/attentions.py | 90 +++++++++++++++++++++++++++++++++++- 1 file changed, 89 insertions(+), 1 deletion(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 88159ac69..84b82eb67 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -301,8 +301,21 @@ def tpu_flash_attention( value: Array, decoder_segment_ids: Array | None, attn_logits_soft_cap: float | None = None, + load_balanced_context_parallel: bool = True ) -> Array: """TPU Flash Attention.""" + + + #Anisha: reorder tensors which is currently [B,S,H,KV] + cp_size = self.mesh.shape["context"] + if cp_size>1 and load_balanced_context_parallel: + query = self.reorder_causal_load_balancing(tensor = query, cp_size= cp_size, seq_dim= 1, to_contiguous=False) + decoder_segment_ids = self.reorder_causal_load_balancing(tensor = decoder_segment_ids, cp_size= cp_size, seq_dim= 1, to_contiguous=False) + + + + + # Transpose to ('batch', 'heads', 'length', 'kv') query = jnp.transpose(query, axes=(0, 2, 1, 3)) key = jnp.transpose(key, axes=(0, 2, 1, 3)) @@ -340,6 +353,8 @@ def tpu_flash_attention( "Batch dimension should be shardable among the devices in data and fsdp" " axis" ) + + #create_splash_attention kernel block_sizes = splash_attention_kernel.BlockSizes( block_q=min(global_block_q, query.shape[2]), @@ -358,6 +373,11 @@ def tpu_flash_attention( # jax.debug.print("query.shape = {qs}, key.shape = {ks}", qs = query.shape, ks = key.shape) mask = splash_attention_mask.CausalMask(shape=(query.shape[2], 
key.shape[2])) + # Anisha: permute the mask + if cp_size>1 and load_balanced_context_parallel: + mask = self.reorder_causal_load_balancing(mask, cp_size, 1, to_contiguous=False) + + #Anisha: figure out local_sliding attention + load_balancing, default is global # Apply local masking if local sliding attention is enabled. if self.attention_type == AttentionType.LOCAL_SLIDING: if self.sliding_window_size is None: @@ -373,7 +393,7 @@ def tpu_flash_attention( splash_kernel = splash_attention_kernel.make_splash_mha( mask=multi_head_mask, head_shards=1, # we would need to change this to the size of the axis if sharding over heads - q_seq_shards=self.mesh.shape["context"], #axis for sequence sharding + q_seq_shards=cp_size, #axis for sequence sharding block_sizes=block_sizes, attn_logits_soft_cap=attn_logits_soft_cap, ) @@ -417,8 +437,76 @@ def wrap_flash_attention(query, key, value, decoder_segment_ids_q, decoder_segme # jax.debug.print("{x}", x=x) x = jnp.transpose(x, axes=(0, 2, 1, 3)) + + if cp_size>1 and load_balanced_context_parallel: + #Anisha: inverse reorder for load_balancing + x = self.reorder_causal_load_balancing(tensor = x, cp_size= cp_size, seq_dim= 1, to_contiguous=True) + + return x + + def reorder_causal_load_balancing(self, tensor, cp_size: int, seq_dim: int, to_contiguous: bool): + """Reorders a tensor for load balancing the compute of causal attention.""" + + if tensor is None: + return tensor + + if cp_size == 1: + return tensor + + if cp_size % 2 != 0: + raise ValueError(f"{cp_size=} must be a multiple of 2.") + + # Need to ensure we have 2 pairs to swap for balancing between cp ranks + if tensor.shape[seq_dim] % (cp_size * 2) != 0: + raise ValueError(f"{tensor.shape=} is not a multiple of {cp_size*2=}") + + # [B, S, H, D] -> [B, 2*cp_size, S/2*cp_size, D] #Anisha: this is ours + # [S, B, H, D] -> [2*cp_size, S/2*cp_size, B, H, D] + + ori_tensor_shape = tensor.shape + tensor = jnp.reshape( + tensor, + ( + *ori_tensor_shape[:seq_dim], + 2 * cp_size, + ori_tensor_shape[seq_dim] // (2 * cp_size), + *ori_tensor_shape[seq_dim + 1 :], + ) + ) + + parts = [] + if not to_contiguous: + for cp_rank in range(cp_size): + # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] + # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] + index = jnp.array([cp_rank, (2 * cp_size - cp_rank - 1)]) + parts.append(jnp.take(tensor, index, axis=seq_dim)) + else: + for cp_rank in range(cp_size // 2): + # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] + # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] + base = 4 * cp_rank + index = jnp.array([base, base + 2]) + parts.append(jnp.take(tensor, index, axis=seq_dim)) + for cp_rank in range(cp_size // 2): + # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] + # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] + base = 2 * cp_size - 1 - 4 * cp_rank + index = jnp.array([base, base - 2]) + parts.append(jnp.take(tensor, index, axis=seq_dim)) + + # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] + # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] + combined = jnp.stack(parts, axis=seq_dim) + + return jnp.reshape(combined,ori_tensor_shape) + + + + + def cudnn_flash_attention( self, query: Array, From 42415709d7e8e164353225ee1afe77ee99dc5d59 Mon Sep 17 00:00:00 2001 From: A9isha Date: Wed, 8 Jan 2025 23:23:05 +0000 Subject: [PATCH 10/28] reorder mask first commit --- MaxText/layers/attentions.py | 76 
++++++++++++++++++++++++++++++++++-- 1 file changed, 73 insertions(+), 3 deletions(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 84b82eb67..5c096fbeb 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -36,6 +36,7 @@ from layers import quantizations import max_logging import pdb +import numpy as np # pylint: disable=line-too-long, g-doc-args, g-doc-return-or-yield, bad-continuation, g-inconsistent-quotes @@ -370,12 +371,22 @@ def tpu_flash_attention( k_layout=splash_attention_kernel.QKVLayout[global_k_layout], v_layout=splash_attention_kernel.QKVLayout[global_v_layout], ) - # jax.debug.print("query.shape = {qs}, key.shape = {ks}", qs = query.shape, ks = key.shape) + + mask = splash_attention_mask.CausalMask(shape=(query.shape[2], key.shape[2])) - # Anisha: permute the mask + # jax.debug.print("original: mask items = {items}", items = mask.__getitem__((slice(mask.shape[0]),slice(mask.shape[1])))) + + # Anisha: permute the mask if cp and load_balancing if cp_size>1 and load_balanced_context_parallel: - mask = self.reorder_causal_load_balancing(mask, cp_size, 1, to_contiguous=False) + original_mask_ndarray = mask.__getitem__((slice(mask.shape[0]),slice(mask.shape[1]))) + permuted_mask_ndarray = self.reorder_causal_load_balancing(tensor = original_mask_ndarray, cp_size= cp_size, seq_dim= 0, to_contiguous=False) + mask = LoadBalancedCausalMask(shape=(query.shape[2], key.shape[2]),mask_ndarray=permuted_mask_ndarray) + + # pdb.set_trace() + # jax.debug.print("permuted: mask items = {items}", items = new_mask.__getitem__((slice(mask.shape[0]),slice(mask.shape[1])))) + + # jax.debug.print("new_mask == old_mask = {equal}", equal = new_mask.__getitem__((slice(mask.shape[0]),slice(mask.shape[1])))==mask.__getitem__((slice(mask.shape[0]),slice(mask.shape[1])))) #Anisha: figure out local_sliding attention + load_balancing, default is global # Apply local masking if local sliding attention is enabled. @@ -1450,3 +1461,62 @@ def __call__( out = self.out_projection(inputs_q.shape[-1], out) out = checkpoint_name(out, "out_proj") return out + +class LoadBalancedCausalMask(splash_attention_mask._ComputableMask): + """Lazy causal mask, prevents the model from attending to future tokens. + + Attributes: + offset: Offset of q start wrt kv. A positive offset shifts the bottom + triangle upward, a negative one shifts it downward. A negative offset + makes the first 'offset' rows of the attention matrix all 0s which leads + to undefined softmax. + """ + + offset: int + mask_ndarray: np.ndarray + + def __init__( + self, + shape: tuple[int, int], + offset: int = 0, + mask_ndarray: np.ndarray = None, + shard_count: int = 1, + ): + self.offset = offset + self.mask_ndarray = mask_ndarray + + def causal_mask_function_load_balanced(q_ids, kv_ids): + return self.mask_ndarray[q_ids, kv_ids] + # # When evaluating the mask in _process_mask we typically work with numpy + # # array views. + # # Avoid the addition when possible to avoid instantiating an actual array. 
+ # if self.offset == 0: + # return q_ids >= kv_ids + # else: + # return q_ids + self.offset >= kv_ids + + mask_function = causal_mask_function_load_balanced + + super().__init__( + shape=shape, + mask_function=mask_function, + shard_count=shard_count, + ) + + def __eq__(self, other: object): + if not isinstance(other, type(self)): + return NotImplemented + + return ( + self.shape == other.shape + and self.offset == other.offset + and np.array_equal(self.q_sequence, other.q_sequence) + ) + + def __hash__(self): + return hash(( + type(self), + self.shape, + self.offset, + self.q_sequence.tobytes() if self.q_sequence is not None else None, + )) \ No newline at end of file From 567233d5e4f6ef05a8a4671bc7266286026d68a1 Mon Sep 17 00:00:00 2001 From: A9isha Date: Wed, 8 Jan 2025 23:57:56 +0000 Subject: [PATCH 11/28] try to make static --- MaxText/layers/attentions.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 5c096fbeb..eb3a0150c 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -372,15 +372,24 @@ def tpu_flash_attention( v_layout=splash_attention_kernel.QKVLayout[global_v_layout], ) - - mask = splash_attention_mask.CausalMask(shape=(query.shape[2], key.shape[2])) + mask_shape = (query.shape[2], key.shape[2]) + mask = splash_attention_mask.CausalMask(shape=mask_shape) # jax.debug.print("original: mask items = {items}", items = mask.__getitem__((slice(mask.shape[0]),slice(mask.shape[1])))) # Anisha: permute the mask if cp and load_balancing if cp_size>1 and load_balanced_context_parallel: - original_mask_ndarray = mask.__getitem__((slice(mask.shape[0]),slice(mask.shape[1]))) + idx = (slice(mask_shape[0]),slice(mask_shape[1])) + q_slice, kv_slice = idx + q_slice = splash_attention_mask._fill_slice(q_slice, mask_shape[0]) + kv_slice = splash_attention_mask._fill_slice(kv_slice, mask_shape[1]) + q_ids = mask.q_sequence[q_slice] + kv_ids = np.arange(kv_slice.start, kv_slice.stop) + #assuming offset == 0: + original_mask_ndarray = q_ids >= kv_ids + permuted_mask_ndarray = self.reorder_causal_load_balancing(tensor = original_mask_ndarray, cp_size= cp_size, seq_dim= 0, to_contiguous=False) + pdb.set_trace() mask = LoadBalancedCausalMask(shape=(query.shape[2], key.shape[2]),mask_ndarray=permuted_mask_ndarray) # pdb.set_trace() @@ -1462,6 +1471,7 @@ def __call__( out = checkpoint_name(out, "out_proj") return out +partial = functools.partial class LoadBalancedCausalMask(splash_attention_mask._ComputableMask): """Lazy causal mask, prevents the model from attending to future tokens. 
@@ -1474,7 +1484,12 @@ class LoadBalancedCausalMask(splash_attention_mask._ComputableMask): offset: int mask_ndarray: np.ndarray - + @partial( + jax.jit, + static_argnames=[ + "mask_ndarray", + ], +) def __init__( self, shape: tuple[int, int], From b70e50b75d0b2ed4860d25eac9db3506115d22db Mon Sep 17 00:00:00 2001 From: A9isha Date: Thu, 9 Jan 2025 00:39:51 +0000 Subject: [PATCH 12/28] static argnames in progress --- MaxText/layers/attentions.py | 64 ++++++++++++++++++++++++++++++++---- 1 file changed, 58 insertions(+), 6 deletions(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index eb3a0150c..38dc9c581 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -372,7 +372,8 @@ def tpu_flash_attention( v_layout=splash_attention_kernel.QKVLayout[global_v_layout], ) - mask_shape = (query.shape[2], key.shape[2]) + # mask_shape = (query.shape[2], key.shape[2]) + mask_shape = (self.config.max_target_length, self.config.max_target_length) mask = splash_attention_mask.CausalMask(shape=mask_shape) # jax.debug.print("original: mask items = {items}", items = mask.__getitem__((slice(mask.shape[0]),slice(mask.shape[1])))) @@ -383,14 +384,15 @@ def tpu_flash_attention( q_slice, kv_slice = idx q_slice = splash_attention_mask._fill_slice(q_slice, mask_shape[0]) kv_slice = splash_attention_mask._fill_slice(kv_slice, mask_shape[1]) - q_ids = mask.q_sequence[q_slice] - kv_ids = np.arange(kv_slice.start, kv_slice.stop) + q_sequence = jnp.arange(mask_shape[0], dtype=jnp.int32) + q_ids = q_sequence[q_slice] + kv_ids = jnp.arange(kv_slice.start, kv_slice.stop) #assuming offset == 0: original_mask_ndarray = q_ids >= kv_ids - permuted_mask_ndarray = self.reorder_causal_load_balancing(tensor = original_mask_ndarray, cp_size= cp_size, seq_dim= 0, to_contiguous=False) - pdb.set_trace() - mask = LoadBalancedCausalMask(shape=(query.shape[2], key.shape[2]),mask_ndarray=permuted_mask_ndarray) + permuted_mask_ndarray = self.reorder_mask_load_balancing(tensor = original_mask_ndarray, cp_size= cp_size, seq_dim= 0) + # pdb.set_trace() + mask = LoadBalancedCausalMask(shape=mask_shape,mask_ndarray=permuted_mask_ndarray) # pdb.set_trace() # jax.debug.print("permuted: mask items = {items}", items = new_mask.__getitem__((slice(mask.shape[0]),slice(mask.shape[1])))) @@ -465,6 +467,56 @@ def wrap_flash_attention(query, key, value, decoder_segment_ids_q, decoder_segme return x + @functools.partial( + jax.jit, + static_argnames=[ + "tensor", + "cp_size", + "seq_dim", + ], + ) + def reorder_mask_load_balancing(self, tensor, cp_size: int, seq_dim: int): + """Reorders a tensor for load balancing the compute of causal attention.""" + + if tensor is None: + return tensor + + if cp_size == 1: + return tensor + + if cp_size % 2 != 0: + raise ValueError(f"{cp_size=} must be a multiple of 2.") + + # Need to ensure we have 2 pairs to swap for balancing between cp ranks + if tensor.shape[seq_dim] % (cp_size * 2) != 0: + raise ValueError(f"{tensor.shape=} is not a multiple of {cp_size*2=}") + + # [B, S, H, D] -> [B, 2*cp_size, S/2*cp_size, D] #Anisha: this is ours + # [S, B, H, D] -> [2*cp_size, S/2*cp_size, B, H, D] + + ori_tensor_shape = tensor.shape + tensor = jnp.reshape( + tensor, + ( + *ori_tensor_shape[:seq_dim], + 2 * cp_size, + ori_tensor_shape[seq_dim] // (2 * cp_size), + *ori_tensor_shape[seq_dim + 1 :], + ) + ) + + parts = [] + for cp_rank in range(cp_size): + # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] + # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, 
H, D] -> [2, S/2*cp_size, B, H, D] + index = jnp.array([cp_rank, (2 * cp_size - cp_rank - 1)]) + parts.append(jnp.take(tensor, index, axis=seq_dim)) + + # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] + # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] + combined = jnp.stack(parts, axis=seq_dim) + + return jnp.reshape(combined,ori_tensor_shape) def reorder_causal_load_balancing(self, tensor, cp_size: int, seq_dim: int, to_contiguous: bool): """Reorders a tensor for load balancing the compute of causal attention.""" From 106ae91dc68e8be5c4a6487936432221d0a12622 Mon Sep 17 00:00:00 2001 From: A9isha Date: Thu, 9 Jan 2025 01:32:29 +0000 Subject: [PATCH 13/28] made multi_head_mask np.ndarray --- MaxText/layers/attentions.py | 97 +++++++++++++++++++++++------------- 1 file changed, 61 insertions(+), 36 deletions(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 38dc9c581..846466b4b 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -380,19 +380,19 @@ def tpu_flash_attention( # Anisha: permute the mask if cp and load_balancing if cp_size>1 and load_balanced_context_parallel: - idx = (slice(mask_shape[0]),slice(mask_shape[1])) - q_slice, kv_slice = idx - q_slice = splash_attention_mask._fill_slice(q_slice, mask_shape[0]) - kv_slice = splash_attention_mask._fill_slice(kv_slice, mask_shape[1]) - q_sequence = jnp.arange(mask_shape[0], dtype=jnp.int32) - q_ids = q_sequence[q_slice] - kv_ids = jnp.arange(kv_slice.start, kv_slice.stop) - #assuming offset == 0: - original_mask_ndarray = q_ids >= kv_ids - - permuted_mask_ndarray = self.reorder_mask_load_balancing(tensor = original_mask_ndarray, cp_size= cp_size, seq_dim= 0) - # pdb.set_trace() - mask = LoadBalancedCausalMask(shape=mask_shape,mask_ndarray=permuted_mask_ndarray) + # idx = (slice(mask_shape[0]),slice(mask_shape[1])) + # q_slice, kv_slice = idx + # q_slice = splash_attention_mask._fill_slice(q_slice, mask_shape[0]) + # kv_slice = splash_attention_mask._fill_slice(kv_slice, mask_shape[1]) + # q_sequence = jnp.arange(mask_shape[0], dtype=jnp.int32) + # q_ids = q_sequence[q_slice] + # kv_ids = jnp.arange(kv_slice.start, kv_slice.stop) + # #assuming offset == 0: + # original_mask_ndarray = q_ids >= kv_ids + + # permuted_mask_ndarray = reorder_mask_load_balancing(tensor = original_mask_ndarray, cp_size= cp_size, seq_dim= 0) + + mask = LoadBalancedCausalMask(shape=mask_shape,cp_size=cp_size) # pdb.set_trace() # jax.debug.print("permuted: mask items = {items}", items = new_mask.__getitem__((slice(mask.shape[0]),slice(mask.shape[1])))) @@ -412,13 +412,24 @@ def tpu_flash_attention( # Create multi-head mask multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) - splash_kernel = splash_attention_kernel.make_splash_mha( - mask=multi_head_mask, - head_shards=1, # we would need to change this to the size of the axis if sharding over heads - q_seq_shards=cp_size, #axis for sequence sharding - block_sizes=block_sizes, - attn_logits_soft_cap=attn_logits_soft_cap, + + @partial( + jax.jit, + static_argnames=[ + "multi_head_mask", + ], ) + def wrap_splash_kernel(multi_head_mask): + splash_kernel = splash_attention_kernel.make_splash_mha( + mask=multi_head_mask, + head_shards=1, # we would need to change this to the size of the axis if sharding over heads + q_seq_shards=cp_size, #axis for sequence sharding + block_sizes=block_sizes, + attn_logits_soft_cap=attn_logits_soft_cap, + ) + return splash_kernel + + splash_kernel = wrap_splash_kernel(multi_head_mask) 
named_sharding = jax.sharding.NamedSharding(self.mesh, axis_names_splash_kernel) segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding) @@ -467,15 +478,15 @@ def wrap_flash_attention(query, key, value, decoder_segment_ids_q, decoder_segme return x - @functools.partial( - jax.jit, - static_argnames=[ - "tensor", - "cp_size", - "seq_dim", - ], - ) - def reorder_mask_load_balancing(self, tensor, cp_size: int, seq_dim: int): + # @functools.partial( + # jax.jit, + # static_argnames=[ + # "tensor", + # "cp_size", + # "seq_dim", + # ], + # ) + def reorder_mask_load_balancing(tensor, cp_size: int, seq_dim: int): """Reorders a tensor for load balancing the compute of causal attention.""" if tensor is None: @@ -1536,21 +1547,35 @@ class LoadBalancedCausalMask(splash_attention_mask._ComputableMask): offset: int mask_ndarray: np.ndarray - @partial( - jax.jit, - static_argnames=[ - "mask_ndarray", - ], -) +# @partial( +# jax.jit, +# static_argnames=[ +# "mask_ndarray", +# ], +# ) def __init__( self, shape: tuple[int, int], offset: int = 0, - mask_ndarray: np.ndarray = None, + # mask_ndarray: jnp.ndarray = None, + cp_size: int = 1, shard_count: int = 1, ): self.offset = offset - self.mask_ndarray = mask_ndarray + idx = (slice(shape[0]),slice(shape[1])) + q_slice, kv_slice = idx + q_slice = splash_attention_mask._fill_slice(q_slice, shape[0]) + kv_slice = splash_attention_mask._fill_slice(kv_slice, shape[1]) + q_sequence = np.arange(shape[0], dtype=np.int32) + rows = q_sequence[q_slice] + cols = np.arange(kv_slice.start, kv_slice.stop) + q_ids = rows[:, None] + kv_ids = cols[None, :] + #assuming offset == 0: + original_mask_ndarray = q_ids >= kv_ids + permuted_mask_ndarray = AttentionOp.reorder_mask_load_balancing(tensor = original_mask_ndarray, cp_size= cp_size, seq_dim= 0) + self.mask_ndarray = permuted_mask_ndarray + # pdb.set_trace() def causal_mask_function_load_balanced(q_ids, kv_ids): return self.mask_ndarray[q_ids, kv_ids] From 442c27a3ff05039b5317f09fe485d781e7f349b6 Mon Sep 17 00:00:00 2001 From: A9isha Date: Fri, 10 Jan 2025 00:31:30 +0000 Subject: [PATCH 14/28] wrap ndarray --- MaxText/layers/attentions.py | 71 +++++++++++++++++++++++++----------- 1 file changed, 50 insertions(+), 21 deletions(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 846466b4b..7970c9222 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -392,9 +392,11 @@ def tpu_flash_attention( # permuted_mask_ndarray = reorder_mask_load_balancing(tensor = original_mask_ndarray, cp_size= cp_size, seq_dim= 0) - mask = LoadBalancedCausalMask(shape=mask_shape,cp_size=cp_size) + permuted_mask_ndarray = LoadBalancedCausalMask.create_mask(shape=mask_shape,cp_size=cp_size) + pdb.set_trace() + + mask = LoadBalancedCausalMask(mask_ndarray=permuted_mask_ndarray,shape=mask_shape,cp_size=cp_size) - # pdb.set_trace() # jax.debug.print("permuted: mask items = {items}", items = new_mask.__getitem__((slice(mask.shape[0]),slice(mask.shape[1])))) # jax.debug.print("new_mask == old_mask = {equal}", equal = new_mask.__getitem__((slice(mask.shape[0]),slice(mask.shape[1])))==mask.__getitem__((slice(mask.shape[0]),slice(mask.shape[1])))) @@ -486,6 +488,7 @@ def wrap_flash_attention(query, key, value, decoder_segment_ids_q, decoder_segme # "seq_dim", # ], # ) + @staticmethod def reorder_mask_load_balancing(tensor, cp_size: int, seq_dim: int): """Reorders a tensor for load balancing the compute of causal attention.""" @@ -506,7 +509,7 @@ def 
reorder_mask_load_balancing(tensor, cp_size: int, seq_dim: int): # [S, B, H, D] -> [2*cp_size, S/2*cp_size, B, H, D] ori_tensor_shape = tensor.shape - tensor = jnp.reshape( + tensor = np.reshape( tensor, ( *ori_tensor_shape[:seq_dim], @@ -520,14 +523,14 @@ def reorder_mask_load_balancing(tensor, cp_size: int, seq_dim: int): for cp_rank in range(cp_size): # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] - index = jnp.array([cp_rank, (2 * cp_size - cp_rank - 1)]) - parts.append(jnp.take(tensor, index, axis=seq_dim)) + index = np.array([cp_rank, (2 * cp_size - cp_rank - 1)]) + parts.append(np.take(tensor, index, axis=seq_dim)) # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] - combined = jnp.stack(parts, axis=seq_dim) + combined = np.stack(parts, axis=seq_dim) - return jnp.reshape(combined,ori_tensor_shape) + return np.reshape(combined,ori_tensor_shape) def reorder_causal_load_balancing(self, tensor, cp_size: int, seq_dim: int, to_contiguous: bool): """Reorders a tensor for load balancing the compute of causal attention.""" @@ -1535,6 +1538,19 @@ def __call__( return out partial = functools.partial + +class WrapperNpNDArray(): + np_ndarray: np.ndarray + + def __init__(self,np_ndarray): + self.np_ndarray = np_ndarray + + def __hash__(self): + return hash(( + type(self), + self.np_ndarray.tobytes() if self.np_ndarray is not None else None, + )) + class LoadBalancedCausalMask(splash_attention_mask._ComputableMask): """Lazy causal mask, prevents the model from attending to future tokens. @@ -1546,22 +1562,17 @@ class LoadBalancedCausalMask(splash_attention_mask._ComputableMask): """ offset: int - mask_ndarray: np.ndarray -# @partial( -# jax.jit, -# static_argnames=[ -# "mask_ndarray", -# ], -# ) - def __init__( - self, + mask_ndarray: WrapperNpNDArray + + @staticmethod + def create_mask( shape: tuple[int, int], - offset: int = 0, + # offset: int = 0, #Anisha: do we need offset? # mask_ndarray: jnp.ndarray = None, cp_size: int = 1, shard_count: int = 1, ): - self.offset = offset + # self.offset = offset idx = (slice(shape[0]),slice(shape[1])) q_slice, kv_slice = idx q_slice = splash_attention_mask._fill_slice(q_slice, shape[0]) @@ -1573,12 +1584,30 @@ def __init__( kv_ids = cols[None, :] #assuming offset == 0: original_mask_ndarray = q_ids >= kv_ids - permuted_mask_ndarray = AttentionOp.reorder_mask_load_balancing(tensor = original_mask_ndarray, cp_size= cp_size, seq_dim= 0) - self.mask_ndarray = permuted_mask_ndarray + mask_ndarray = AttentionOp.reorder_mask_load_balancing(tensor = original_mask_ndarray, cp_size= cp_size, seq_dim= 0) + return WrapperNpNDArray(mask_ndarray) + # self.mask_ndarray = permuted_mask_ndarray # pdb.set_trace() + @partial( + jax.jit, + static_argnames=[ + "mask_ndarray", + ], + ) + def __init__( + self, + shape: tuple[int, int], + offset: int = 0, + mask_ndarray: WrapperNpNDArray = None, + # cp_size: int = 1, + shard_count: int = 1, + ): + self.mask_ndarray = mask_ndarray + self.offset = offset + def causal_mask_function_load_balanced(q_ids, kv_ids): - return self.mask_ndarray[q_ids, kv_ids] + return self.mask_ndarray.np_ndarray[q_ids, kv_ids] # # When evaluating the mask in _process_mask we typically work with numpy # # array views. # # Avoid the addition when possible to avoid instantiating an actual array. 
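
A quick aside before the next patch: the load-balancing permutation that reorder_causal_load_balancing and reorder_mask_load_balancing implement is easiest to see on a toy sequence. The sketch below is illustrative only and is not part of the patch series; reorder is a hypothetical stand-in for those helpers, specialized to a small NumPy input with the sequence on the leading axis.

import numpy as np

def reorder(x, cp_size, seq_dim=0):
  # Split the sequence axis into 2*cp_size chunks and give rank r the pair of
  # chunks (r, 2*cp_size - 1 - r), so every rank owns one early (cheap) and one
  # late (expensive) chunk of the causal computation.
  shape = x.shape
  x = x.reshape(
      *shape[:seq_dim],
      2 * cp_size,
      shape[seq_dim] // (2 * cp_size),
      *shape[seq_dim + 1 :],
  )
  parts = [np.take(x, [r, 2 * cp_size - 1 - r], axis=seq_dim) for r in range(cp_size)]
  return np.stack(parts, axis=seq_dim).reshape(shape)

seq_len, cp_size = 8, 2
print(reorder(np.arange(seq_len), cp_size))  # [0 1 6 7 2 3 4 5]

# The permuted causal mask only needs the query ids permuted the same way;
# key/value positions keep their original order.
q_ids = reorder(np.arange(seq_len), cp_size)
kv_ids = np.arange(seq_len)
permuted_mask = q_ids[:, None] >= kv_ids[None, :]
causal_mask = np.arange(seq_len)[:, None] >= kv_ids[None, :]
assert np.array_equal(permuted_mask, reorder(causal_mask, cp_size, seq_dim=0))

This is the same observation the later patches converge on: instead of materializing and permuting the full boolean mask, it is enough to permute q_sequence and keep the plain q_ids >= kv_ids comparison.
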
From 0ca559a51fff6cf137fefa250d93614cec596934 Mon Sep 17 00:00:00 2001 From: A9isha Date: Fri, 10 Jan 2025 01:14:56 +0000 Subject: [PATCH 15/28] fix permuted mask --- MaxText/layers/attentions.py | 90 +++++++++++++++++------------------- 1 file changed, 43 insertions(+), 47 deletions(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 7970c9222..0deeec7e2 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -392,10 +392,11 @@ def tpu_flash_attention( # permuted_mask_ndarray = reorder_mask_load_balancing(tensor = original_mask_ndarray, cp_size= cp_size, seq_dim= 0) - permuted_mask_ndarray = LoadBalancedCausalMask.create_mask(shape=mask_shape,cp_size=cp_size) - pdb.set_trace() + # permuted_mask_ndarray = LoadBalancedCausalMask.create_mask(shape=mask_shape,cp_size=cp_size) + # pdb.set_trace() - mask = LoadBalancedCausalMask(mask_ndarray=permuted_mask_ndarray,shape=mask_shape,cp_size=cp_size) + mask = create_load_balance_causal_mask(shape=mask_shape,cp_size=cp_size) + breakpoint() # jax.debug.print("permuted: mask items = {items}", items = new_mask.__getitem__((slice(mask.shape[0]),slice(mask.shape[1])))) @@ -1550,8 +1551,31 @@ def __hash__(self): type(self), self.np_ndarray.tobytes() if self.np_ndarray is not None else None, )) - -class LoadBalancedCausalMask(splash_attention_mask._ComputableMask): + +def create_load_balance_causal_mask( + shape: tuple[int, int], + # offset: int = 0, #Anisha: do we need offset? + # mask_ndarray: jnp.ndarray = None, + cp_size: int = 1, + # shard_count: int = 1, + ): + # self.offset = offset + idx = (slice(shape[0]),slice(shape[1])) + q_slice, kv_slice = idx + q_slice = splash_attention_mask._fill_slice(q_slice, shape[0]) + kv_slice = splash_attention_mask._fill_slice(kv_slice, shape[1]) + q_sequence = np.arange(shape[0], dtype=np.int32) + rows = q_sequence[q_slice] + cols = np.arange(kv_slice.start, kv_slice.stop) + q_ids = rows[:, None] + kv_ids = cols[None, :] + #assuming offset == 0: + original_mask_ndarray = q_ids >= kv_ids + mask_ndarray = AttentionOp.reorder_mask_load_balancing(tensor = original_mask_ndarray, cp_size= cp_size, seq_dim= 0) + return splash_attention_mask.NumpyMask(mask_ndarray) + + +class LoadBalancedCausalMask(splash_attention_mask.NumpyMask): """Lazy causal mask, prevents the model from attending to future tokens. Attributes: @@ -1562,59 +1586,31 @@ class LoadBalancedCausalMask(splash_attention_mask._ComputableMask): """ offset: int - mask_ndarray: WrapperNpNDArray + # mask_ndarray: WrapperNpNDArray - @staticmethod - def create_mask( - shape: tuple[int, int], - # offset: int = 0, #Anisha: do we need offset? 
- # mask_ndarray: jnp.ndarray = None, - cp_size: int = 1, - shard_count: int = 1, - ): - # self.offset = offset - idx = (slice(shape[0]),slice(shape[1])) - q_slice, kv_slice = idx - q_slice = splash_attention_mask._fill_slice(q_slice, shape[0]) - kv_slice = splash_attention_mask._fill_slice(kv_slice, shape[1]) - q_sequence = np.arange(shape[0], dtype=np.int32) - rows = q_sequence[q_slice] - cols = np.arange(kv_slice.start, kv_slice.stop) - q_ids = rows[:, None] - kv_ids = cols[None, :] - #assuming offset == 0: - original_mask_ndarray = q_ids >= kv_ids - mask_ndarray = AttentionOp.reorder_mask_load_balancing(tensor = original_mask_ndarray, cp_size= cp_size, seq_dim= 0) - return WrapperNpNDArray(mask_ndarray) - # self.mask_ndarray = permuted_mask_ndarray - # pdb.set_trace() - - @partial( - jax.jit, - static_argnames=[ - "mask_ndarray", - ], - ) + def __init__( self, shape: tuple[int, int], offset: int = 0, - mask_ndarray: WrapperNpNDArray = None, + # mask_ndarray: WrapperNpNDArray = None, # cp_size: int = 1, shard_count: int = 1, ): - self.mask_ndarray = mask_ndarray + # self.mask_ndarray = mask_ndarray self.offset = offset def causal_mask_function_load_balanced(q_ids, kv_ids): - return self.mask_ndarray.np_ndarray[q_ids, kv_ids] - # # When evaluating the mask in _process_mask we typically work with numpy - # # array views. - # # Avoid the addition when possible to avoid instantiating an actual array. - # if self.offset == 0: - # return q_ids >= kv_ids - # else: - # return q_ids + self.offset >= kv_ids + + + return self.array[q_ids, kv_ids] + # # When evaluating the mask in _process_mask we typically work with numpy + # # array views. + # # Avoid the addition when possible to avoid instantiating an actual array. + # if self.offset == 0: + # return q_ids >= kv_ids + # else: + # return q_ids + self.offset >= kv_ids mask_function = causal_mask_function_load_balanced From c283c4f8e32a61d00bf32bd75aa86b234e4b9f88 Mon Sep 17 00:00:00 2001 From: A9isha Date: Fri, 10 Jan 2025 01:39:17 +0000 Subject: [PATCH 16/28] fix permutation --- MaxText/layers/attentions.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 0deeec7e2..7b8f5c636 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -306,12 +306,13 @@ def tpu_flash_attention( ) -> Array: """TPU Flash Attention.""" + decoder_segment_ids_permuted = None #Anisha: reorder tensors which is currently [B,S,H,KV] cp_size = self.mesh.shape["context"] if cp_size>1 and load_balanced_context_parallel: query = self.reorder_causal_load_balancing(tensor = query, cp_size= cp_size, seq_dim= 1, to_contiguous=False) - decoder_segment_ids = self.reorder_causal_load_balancing(tensor = decoder_segment_ids, cp_size= cp_size, seq_dim= 1, to_contiguous=False) + decoder_segment_ids_permuted = self.reorder_causal_load_balancing(tensor = decoder_segment_ids, cp_size= cp_size, seq_dim= 1, to_contiguous=False) @@ -468,7 +469,7 @@ def wrap_flash_attention(query, key, value, decoder_segment_ids_q, decoder_segme return attention_output - x = wrap_flash_attention(query, key, value, decoder_segment_ids, decoder_segment_ids, splash_kernel) + x = wrap_flash_attention(query, key, value, decoder_segment_ids_permuted, decoder_segment_ids, splash_kernel) # pdb.set_trace() # jax.debug.print("{x}", x=x) From 4d52390b6e16d0df27d8bc1b42ff20776cc33c30 Mon Sep 17 00:00:00 2001 From: A9isha Date: Fri, 10 Jan 2025 04:22:35 +0000 Subject: [PATCH 17/28] clean up --- 
MaxText/layers/attentions.py | 89 ++---------------------------------- 1 file changed, 3 insertions(+), 86 deletions(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 7b8f5c636..11f5fdf58 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -377,33 +377,15 @@ def tpu_flash_attention( mask_shape = (self.config.max_target_length, self.config.max_target_length) mask = splash_attention_mask.CausalMask(shape=mask_shape) - # jax.debug.print("original: mask items = {items}", items = mask.__getitem__((slice(mask.shape[0]),slice(mask.shape[1])))) - - # Anisha: permute the mask if cp and load_balancing + # permute the mask if cp and load_balancing if cp_size>1 and load_balanced_context_parallel: - # idx = (slice(mask_shape[0]),slice(mask_shape[1])) - # q_slice, kv_slice = idx - # q_slice = splash_attention_mask._fill_slice(q_slice, mask_shape[0]) - # kv_slice = splash_attention_mask._fill_slice(kv_slice, mask_shape[1]) - # q_sequence = jnp.arange(mask_shape[0], dtype=jnp.int32) - # q_ids = q_sequence[q_slice] - # kv_ids = jnp.arange(kv_slice.start, kv_slice.stop) - # #assuming offset == 0: - # original_mask_ndarray = q_ids >= kv_ids - - # permuted_mask_ndarray = reorder_mask_load_balancing(tensor = original_mask_ndarray, cp_size= cp_size, seq_dim= 0) - - # permuted_mask_ndarray = LoadBalancedCausalMask.create_mask(shape=mask_shape,cp_size=cp_size) - # pdb.set_trace() - mask = create_load_balance_causal_mask(shape=mask_shape,cp_size=cp_size) - breakpoint() # jax.debug.print("permuted: mask items = {items}", items = new_mask.__getitem__((slice(mask.shape[0]),slice(mask.shape[1])))) # jax.debug.print("new_mask == old_mask = {equal}", equal = new_mask.__getitem__((slice(mask.shape[0]),slice(mask.shape[1])))==mask.__getitem__((slice(mask.shape[0]),slice(mask.shape[1])))) - #Anisha: figure out local_sliding attention + load_balancing, default is global + #TODO: figure out local_sliding attention + load_balancing, default is global # Apply local masking if local sliding attention is enabled. if self.attention_type == AttentionType.LOCAL_SLIDING: if self.sliding_window_size is None: @@ -448,7 +430,7 @@ def wrap_splash_kernel(multi_head_mask): axis_names_kv, segment_axis_names_q, segment_axis_names_kv, - segment_axis_names_splash_kernel, #TODO: add the manual sharding, + segment_axis_names_splash_kernel, ), out_specs=axis_names_q, check_rep=False, @@ -1556,9 +1538,7 @@ def __hash__(self): def create_load_balance_causal_mask( shape: tuple[int, int], # offset: int = 0, #Anisha: do we need offset? - # mask_ndarray: jnp.ndarray = None, cp_size: int = 1, - # shard_count: int = 1, ): # self.offset = offset idx = (slice(shape[0]),slice(shape[1])) @@ -1575,66 +1555,3 @@ def create_load_balance_causal_mask( mask_ndarray = AttentionOp.reorder_mask_load_balancing(tensor = original_mask_ndarray, cp_size= cp_size, seq_dim= 0) return splash_attention_mask.NumpyMask(mask_ndarray) - -class LoadBalancedCausalMask(splash_attention_mask.NumpyMask): - """Lazy causal mask, prevents the model from attending to future tokens. - - Attributes: - offset: Offset of q start wrt kv. A positive offset shifts the bottom - triangle upward, a negative one shifts it downward. A negative offset - makes the first 'offset' rows of the attention matrix all 0s which leads - to undefined softmax. 
- """ - - offset: int - # mask_ndarray: WrapperNpNDArray - - - def __init__( - self, - shape: tuple[int, int], - offset: int = 0, - # mask_ndarray: WrapperNpNDArray = None, - # cp_size: int = 1, - shard_count: int = 1, - ): - # self.mask_ndarray = mask_ndarray - self.offset = offset - - def causal_mask_function_load_balanced(q_ids, kv_ids): - - - return self.array[q_ids, kv_ids] - # # When evaluating the mask in _process_mask we typically work with numpy - # # array views. - # # Avoid the addition when possible to avoid instantiating an actual array. - # if self.offset == 0: - # return q_ids >= kv_ids - # else: - # return q_ids + self.offset >= kv_ids - - mask_function = causal_mask_function_load_balanced - - super().__init__( - shape=shape, - mask_function=mask_function, - shard_count=shard_count, - ) - - def __eq__(self, other: object): - if not isinstance(other, type(self)): - return NotImplemented - - return ( - self.shape == other.shape - and self.offset == other.offset - and np.array_equal(self.q_sequence, other.q_sequence) - ) - - def __hash__(self): - return hash(( - type(self), - self.shape, - self.offset, - self.q_sequence.tobytes() if self.q_sequence is not None else None, - )) \ No newline at end of file From 7feeb07704861f5fbcee05950fd71e0c1aa9d9db Mon Sep 17 00:00:00 2001 From: A9isha Date: Fri, 10 Jan 2025 04:24:50 +0000 Subject: [PATCH 18/28] fix non load balanced case --- MaxText/layers/attentions.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 11f5fdf58..817ed61b3 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -451,14 +451,15 @@ def wrap_flash_attention(query, key, value, decoder_segment_ids_q, decoder_segme return attention_output - x = wrap_flash_attention(query, key, value, decoder_segment_ids_permuted, decoder_segment_ids, splash_kernel) - # pdb.set_trace() - # jax.debug.print("{x}", x=x) + if cp_size>1 and load_balanced_context_parallel: + x = wrap_flash_attention(query, key, value, decoder_segment_ids_permuted, decoder_segment_ids, splash_kernel) + else: + x = wrap_flash_attention(query, key, value, decoder_segment_ids, decoder_segment_ids, splash_kernel) x = jnp.transpose(x, axes=(0, 2, 1, 3)) if cp_size>1 and load_balanced_context_parallel: - #Anisha: inverse reorder for load_balancing + #inverse reorder for load_balancing x = self.reorder_causal_load_balancing(tensor = x, cp_size= cp_size, seq_dim= 1, to_contiguous=True) From bf2d21a28f337a8ad897f5c6533ceaff9059d512 Mon Sep 17 00:00:00 2001 From: A9isha Date: Fri, 10 Jan 2025 04:25:55 +0000 Subject: [PATCH 19/28] clean up --- MaxText/layers/attentions.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 817ed61b3..3a01cdc2c 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -308,7 +308,7 @@ def tpu_flash_attention( decoder_segment_ids_permuted = None - #Anisha: reorder tensors which is currently [B,S,H,KV] + #Reorder tensors which is currently [B,S,H,KV] cp_size = self.mesh.shape["context"] if cp_size>1 and load_balanced_context_parallel: query = self.reorder_causal_load_balancing(tensor = query, cp_size= cp_size, seq_dim= 1, to_contiguous=False) @@ -327,8 +327,6 @@ def tpu_flash_attention( if decoder_segment_ids is not None: segment_axis_names_q = nn.logical_to_mesh_axes((BATCH, "activation_length_q")) segment_axis_names_kv = nn.logical_to_mesh_axes((BATCH, 
"activation_length_kv")) - #Anisha - # does segment_axis_names_splash_kernel also need to be inside `if decoder_segment_ids is not None:`? axis_names_splash_kernel = nn.logical_to_mesh_axes(self.flash_axis_names_splash_kernel) axis_names_q = nn.logical_to_mesh_axes(self.flash_axis_names_q) axis_names_kv = nn.logical_to_mesh_axes(self.flash_axis_names_kv) From 2edbfcf6ee95cda0c9c7c9f810516fd296608539 Mon Sep 17 00:00:00 2001 From: A9isha Date: Fri, 10 Jan 2025 23:45:24 +0000 Subject: [PATCH 20/28] use ComputableMask --- MaxText/layers/attentions.py | 102 ++++++++++++++++++++++++++++------- 1 file changed, 83 insertions(+), 19 deletions(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 3a01cdc2c..4565d818c 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -377,7 +377,8 @@ def tpu_flash_attention( # permute the mask if cp and load_balancing if cp_size>1 and load_balanced_context_parallel: - mask = create_load_balance_causal_mask(shape=mask_shape,cp_size=cp_size) + # mask = create_load_balance_causal_mask(shape=mask_shape,cp_size=cp_size) + mask = LoadBalancedCausalMask(shape=mask_shape,cp_size=cp_size) # jax.debug.print("permuted: mask items = {items}", items = new_mask.__getitem__((slice(mask.shape[0]),slice(mask.shape[1])))) @@ -508,6 +509,8 @@ def reorder_mask_load_balancing(tensor, cp_size: int, seq_dim: int): # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] index = np.array([cp_rank, (2 * cp_size - cp_rank - 1)]) parts.append(np.take(tensor, index, axis=seq_dim)) + # breakpoint() + # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] @@ -1534,23 +1537,84 @@ def __hash__(self): self.np_ndarray.tobytes() if self.np_ndarray is not None else None, )) -def create_load_balance_causal_mask( - shape: tuple[int, int], - # offset: int = 0, #Anisha: do we need offset? - cp_size: int = 1, + +class LoadBalancedCausalMask(splash_attention_mask._ComputableMask): + """Lazy causal mask, prevents the model from attending to future tokens. + Attributes: + offset: Offset of q start wrt kv. A positive offset shifts the bottom + triangle upward, a negative one shifts it downward. A negative offset + makes the first 'offset' rows of the attention matrix all 0s which leads + to undefined softmax. + """ + offset: int + shape: tuple[int, int] + cp_size: int + + def __init__( + self, + shape: tuple[int, int], + cp_size: int, + offset: int = 0, + shard_count: int = 1, ): - # self.offset = offset - idx = (slice(shape[0]),slice(shape[1])) - q_slice, kv_slice = idx - q_slice = splash_attention_mask._fill_slice(q_slice, shape[0]) - kv_slice = splash_attention_mask._fill_slice(kv_slice, shape[1]) - q_sequence = np.arange(shape[0], dtype=np.int32) - rows = q_sequence[q_slice] - cols = np.arange(kv_slice.start, kv_slice.stop) - q_ids = rows[:, None] - kv_ids = cols[None, :] - #assuming offset == 0: - original_mask_ndarray = q_ids >= kv_ids - mask_ndarray = AttentionOp.reorder_mask_load_balancing(tensor = original_mask_ndarray, cp_size= cp_size, seq_dim= 0) - return splash_attention_mask.NumpyMask(mask_ndarray) + self.offset = offset + self.cp_size = cp_size + + + def causal_mask_function(q_ids, kv_ids): + # When evaluating the mask in _process_mask we typically work with numpy + # array views. + # Avoid the addition when possible to avoid instantiating an actual array. 
+ + def create_load_balance_causal_mask( + shape: tuple[int, int], + q_ids: int, + kv_ids: int, + offset: int = 0, #Anisha: do we need offset? + ): + # self.offset = offset + # idx = (slice(shape[0]),slice(shape[1])) + # q_slice, kv_slice = idx + # q_slice = splash_attention_mask._fill_slice(q_slice, shape[0]) + # kv_slice = splash_attention_mask._fill_slice(kv_slice, shape[1]) + # q_sequence = np.arange(shape[0], dtype=np.int32) + # rows = q_sequence[q_slice] + # cols = np.arange(kv_slice.start, kv_slice.stop) + # q_ids = rows[:, None] + # kv_ids = cols[None, :] + if offset == 0: + return q_ids >= kv_ids + else: + return q_ids + offset >= kv_ids + + original_mask_ndarray = create_load_balance_causal_mask(self.shape, q_ids, kv_ids, self.offset) + mask_ndarray = AttentionOp.reorder_mask_load_balancing(tensor = original_mask_ndarray, cp_size= cp_size, seq_dim= 0) + return mask_ndarray + # return splash_attention_mask.NumpyMask(mask_ndarray) + + + mask_function = causal_mask_function + + super().__init__( + shape=shape, + mask_function=mask_function, + shard_count=shard_count, + ) + + def __eq__(self, other: object): + if not isinstance(other, type(self)): + return NotImplemented + return ( + self.shape == other.shape + and self.offset == other.offset + and np.array_equal(self.q_sequence, other.q_sequence) + ) + + def __hash__(self): + return hash(( + type(self), + self.shape, + self.offset, + self.q_sequence.tobytes() if self.q_sequence is not None else None, + )) From 32539c9fdc3ff77f875fe74fe5bac56c4ed88418 Mon Sep 17 00:00:00 2001 From: A9isha Date: Sat, 11 Jan 2025 00:07:59 +0000 Subject: [PATCH 21/28] fix traced array --- MaxText/layers/attentions.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 4565d818c..5b30494b3 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -508,8 +508,11 @@ def reorder_mask_load_balancing(tensor, cp_size: int, seq_dim: int): # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] index = np.array([cp_rank, (2 * cp_size - cp_rank - 1)]) - parts.append(np.take(tensor, index, axis=seq_dim)) - # breakpoint() + try: + parts.append(np.take(tensor, index, axis=seq_dim)) + except Exception as e: + print(f"Got exception={e}") + # breakpoint() # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] @@ -1588,7 +1591,7 @@ def create_load_balance_causal_mask( return q_ids + offset >= kv_ids original_mask_ndarray = create_load_balance_causal_mask(self.shape, q_ids, kv_ids, self.offset) - mask_ndarray = AttentionOp.reorder_mask_load_balancing(tensor = original_mask_ndarray, cp_size= cp_size, seq_dim= 0) + mask_ndarray = AttentionOp.reorder_mask_load_balancing(tensor = original_mask_ndarray, cp_size= self.cp_size, seq_dim= 0) return mask_ndarray # return splash_attention_mask.NumpyMask(mask_ndarray) From b47eb81f3f21bbecc35a4727eb185562be823e9a Mon Sep 17 00:00:00 2001 From: A9isha Date: Mon, 13 Jan 2025 22:09:42 +0000 Subject: [PATCH 22/28] debugging _ComputableMask but moving to dynamic slice creation --- MaxText/layers/attentions.py | 34 ++++++++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 5b30494b3..037442c4c 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -475,7 +475,9 @@ def wrap_flash_attention(query, key, value, 
decoder_segment_ids_q, decoder_segme @staticmethod def reorder_mask_load_balancing(tensor, cp_size: int, seq_dim: int): """Reorders a tensor for load balancing the compute of causal attention.""" - + if type(tensor) is not np.ndarray: + breakpoint() + if tensor is None: return tensor @@ -512,12 +514,16 @@ def reorder_mask_load_balancing(tensor, cp_size: int, seq_dim: int): parts.append(np.take(tensor, index, axis=seq_dim)) except Exception as e: print(f"Got exception={e}") - # breakpoint() + breakpoint() # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] - combined = np.stack(parts, axis=seq_dim) + try: + combined = np.stack(parts, axis=seq_dim) + except Exception as e: + print(f"Got exception={e}") + breakpoint() return np.reshape(combined,ori_tensor_shape) @@ -1569,12 +1575,28 @@ def causal_mask_function(q_ids, kv_ids): # array views. # Avoid the addition when possible to avoid instantiating an actual array. + def create_causal_mask_for_index( + shape: tuple[int, int], + idx: int, + cp: int, #context parallelism val + ): + + q_slice, kv_slice = idx, slice(shape[1]) + + + def create_load_balance_causal_mask( shape: tuple[int, int], q_ids: int, kv_ids: int, - offset: int = 0, #Anisha: do we need offset? + offset: int = 0, #This is important for auto regressive decoding, + #we are not supporting flash/splash attention for auto regressive decoding ): + (slice * i + jnp.arange(slice))[:, None] <= jnp.arange(seq_len)[None, :] - something like this + for each index, for i = i and i = n - i -1 (perhaps 2 * n - i - 1 here maybe?) + + #then concatenate + # self.offset = offset # idx = (slice(shape[0]),slice(shape[1])) # q_slice, kv_slice = idx @@ -1591,6 +1613,10 @@ def create_load_balance_causal_mask( return q_ids + offset >= kv_ids original_mask_ndarray = create_load_balance_causal_mask(self.shape, q_ids, kv_ids, self.offset) + if type(original_mask_ndarray) is not np.ndarray: + raise ValueError("Something went wrong in function_create_load_balance_causal_mask") + else: + print("np ndarray found!!") mask_ndarray = AttentionOp.reorder_mask_load_balancing(tensor = original_mask_ndarray, cp_size= self.cp_size, seq_dim= 0) return mask_ndarray # return splash_attention_mask.NumpyMask(mask_ndarray) From 968f2352ca6e01c7fd7e96de682c126786fa02ce Mon Sep 17 00:00:00 2001 From: A9isha Date: Tue, 21 Jan 2025 23:04:51 +0000 Subject: [PATCH 23/28] try out new computable mask --- MaxText/layers/attentions.py | 134 +++++++++++------- .../llama2/test_llama2_context_parallel.sh | 5 + 2 files changed, 88 insertions(+), 51 deletions(-) create mode 100644 end_to_end/tpu/llama2/test_llama2_context_parallel.sh diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 037442c4c..bd4280049 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -1562,65 +1562,21 @@ class LoadBalancedCausalMask(splash_attention_mask._ComputableMask): def __init__( self, shape: tuple[int, int], - cp_size: int, offset: int = 0, shard_count: int = 1, + cp_size: int = 4 ): self.offset = offset - self.cp_size = cp_size - def causal_mask_function(q_ids, kv_ids): - # When evaluating the mask in _process_mask we typically work with numpy - # array views. - # Avoid the addition when possible to avoid instantiating an actual array. 
- - def create_causal_mask_for_index( - shape: tuple[int, int], - idx: int, - cp: int, #context parallelism val - ): - - q_slice, kv_slice = idx, slice(shape[1]) - - - - def create_load_balance_causal_mask( - shape: tuple[int, int], - q_ids: int, - kv_ids: int, - offset: int = 0, #This is important for auto regressive decoding, - #we are not supporting flash/splash attention for auto regressive decoding - ): - (slice * i + jnp.arange(slice))[:, None] <= jnp.arange(seq_len)[None, :] - something like this - for each index, for i = i and i = n - i -1 (perhaps 2 * n - i - 1 here maybe?) - - #then concatenate - - # self.offset = offset - # idx = (slice(shape[0]),slice(shape[1])) - # q_slice, kv_slice = idx - # q_slice = splash_attention_mask._fill_slice(q_slice, shape[0]) - # kv_slice = splash_attention_mask._fill_slice(kv_slice, shape[1]) - # q_sequence = np.arange(shape[0], dtype=np.int32) - # rows = q_sequence[q_slice] - # cols = np.arange(kv_slice.start, kv_slice.stop) - # q_ids = rows[:, None] - # kv_ids = cols[None, :] - if offset == 0: - return q_ids >= kv_ids - else: - return q_ids + offset >= kv_ids - - original_mask_ndarray = create_load_balance_causal_mask(self.shape, q_ids, kv_ids, self.offset) - if type(original_mask_ndarray) is not np.ndarray: - raise ValueError("Something went wrong in function_create_load_balance_causal_mask") + if self.offset == 0: + return q_ids >= kv_ids else: - print("np ndarray found!!") - mask_ndarray = AttentionOp.reorder_mask_load_balancing(tensor = original_mask_ndarray, cp_size= self.cp_size, seq_dim= 0) - return mask_ndarray - # return splash_attention_mask.NumpyMask(mask_ndarray) + return q_ids + self.offset >= kv_ids + arr = np.arange(shape[0]) + out = AttentionOp.reorder_mask_load_balancing(arr[None, :, None, None], cp_size, seq_dim=1) + q_sequence = out[0, :, 0, 0] mask_function = causal_mask_function @@ -1629,6 +1585,82 @@ def create_load_balance_causal_mask( mask_function=mask_function, shard_count=shard_count, ) + self.q_sequence = q_sequence + + # def __init__( + # self, + # shape: tuple[int, int], + # cp_size: int, + # offset: int = 0, + # shard_count: int = 1, + # ): + # self.offset = offset + # self.cp_size = cp_size + + + # def causal_mask_function(q_ids, kv_ids): + # # When evaluating the mask in _process_mask we typically work with numpy + # # array views. + # # Avoid the addition when possible to avoid instantiating an actual array. + + # def create_causal_mask_for_index( + # shape: tuple[int, int], + # idx: int, + # cp: int, #context parallelism val + # ): + + # q_slice, kv_slice = idx, slice(shape[1]) + + + + # def create_load_balance_causal_mask( + # shape: tuple[int, int], + # q_ids: int, + # kv_ids: int, + # offset: int = 0, #This is important for auto regressive decoding, + # #we are not supporting flash/splash attention for auto regressive decoding + # ): + # (slice * i + jnp.arange(slice))[:, None] <= jnp.arange(seq_len)[None, :] - something like this + # for each index, for i = i and i = n - i -1 (perhaps 2 * n - i - 1 here maybe?) + + # #then concatenate + + # """ + # 1. 
create all + # """ + + # # self.offset = offset + # # idx = (slice(shape[0]),slice(shape[1])) + # # q_slice, kv_slice = idx + # # q_slice = splash_attention_mask._fill_slice(q_slice, shape[0]) + # # kv_slice = splash_attention_mask._fill_slice(kv_slice, shape[1]) + # # q_sequence = np.arange(shape[0], dtype=np.int32) + # # rows = q_sequence[q_slice] + # # cols = np.arange(kv_slice.start, kv_slice.stop) + # # q_ids = rows[:, None] + # # kv_ids = cols[None, :] + # if offset == 0: + # return q_ids >= kv_ids + # else: + # return q_ids + offset >= kv_ids + + # original_mask_ndarray = create_load_balance_causal_mask(self.shape, q_ids, kv_ids, self.offset) + # if type(original_mask_ndarray) is not np.ndarray: + # raise ValueError("Something went wrong in function_create_load_balance_causal_mask") + # else: + # print("np ndarray found!!") + # mask_ndarray = AttentionOp.reorder_mask_load_balancing(tensor = original_mask_ndarray, cp_size= self.cp_size, seq_dim= 0) + # return mask_ndarray + # # return splash_attention_mask.NumpyMask(mask_ndarray) + + + # mask_function = causal_mask_function + + # super().__init__( + # shape=shape, + # mask_function=mask_function, + # shard_count=shard_count, + # ) def __eq__(self, other: object): if not isinstance(other, type(self)): diff --git a/end_to_end/tpu/llama2/test_llama2_context_parallel.sh b/end_to_end/tpu/llama2/test_llama2_context_parallel.sh new file mode 100644 index 000000000..72ec8ed96 --- /dev/null +++ b/end_to_end/tpu/llama2/test_llama2_context_parallel.sh @@ -0,0 +1,5 @@ +#!/bin/bash + + +idx=$(date +%Y-%m-%d-%H-%M) +python3 MaxText/train.py MaxText/configs/base.yml ici_context_parallelism=-1 ici_fsdp_parallelism=1 enable_checkpointing=false base_output_directory=gs://mazumdera-test-bucket-us-east5/maxtext/seqpara/${idx} dataset_path=gs://max-datasets-rogue run_name=context_test enable_goodput_recording=false monitor_goodput=false per_device_batch_size=10 steps=30 profiler=xplane profiler_steps=20 max_target_length=65536 \ No newline at end of file From a026cb1a0ec93e61f7db94de00cade37cc4c8ee2 Mon Sep 17 00:00:00 2001 From: A9isha Date: Wed, 22 Jan 2025 01:57:12 +0000 Subject: [PATCH 24/28] try running v6e-256 experiment --- benchmarks/benchmark_runner.py | 12 ++++ benchmarks/maxtext_trillium_model_configs.py | 73 ++++++++++++++++++++ 2 files changed, 85 insertions(+) diff --git a/benchmarks/benchmark_runner.py b/benchmarks/benchmark_runner.py index b87fae7d5..ef501cded 100644 --- a/benchmarks/benchmark_runner.py +++ b/benchmarks/benchmark_runner.py @@ -29,6 +29,16 @@ from maxtext_xpk_runner import xpk_benchmark_runner from maxtext_xpk_runner import XpkClusterConfig from maxtext_xpk_runner import LibTpuType +import datetime + +def get_current_day_hour_minute(): + """Gets the current day, hour, and minute as a formatted string. + + Returns: + A string in the format 'YYYYMMDDHHMM'. + """ + now = datetime.datetime.now() + return now.strftime("%Y%m%d%H%M") def add_shared_arguments(custom_parser: argparse.ArgumentParser): """Add shared arguments to the parser. 
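
The helper added above is presumably there so each benchmark launch can be tagged with its start time; the hunk that follows assigns it to idx inside main(). A minimal usage sketch, with a hypothetical run-name format that is not taken from the patch:

idx = get_current_day_hour_minute()   # e.g. '202501220157', per the 'YYYYMMDDHHMM' format above
run_name = f"llama3-cp-{idx}"         # hypothetical: tag the XPK run with the launch timestamp
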
@@ -156,6 +166,8 @@ def main() -> None: add_shared_arguments(parser) options = parser.parse_args() + idx = get_current_day_hour_minute() + cluster_config = XpkClusterConfig( cluster_name=options.cluster_name, project=options.project, diff --git a/benchmarks/maxtext_trillium_model_configs.py b/benchmarks/maxtext_trillium_model_configs.py index 1e1acb654..c9302466f 100644 --- a/benchmarks/maxtext_trillium_model_configs.py +++ b/benchmarks/maxtext_trillium_model_configs.py @@ -1062,3 +1062,76 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_ ), ) ) +<<<<<<< HEAD +======= + +llama3_1_70b_131072 = MaxTextModel( + model_name="llama3_1_70b_131072", + model_type="llama3.1-70b", + tuning_params={ + "per_device_batch_size": 0.125, + "ici_fsdp_parallelism": -1, + "ici_context_parallelism": 16, + "remat_policy": "custom", + "decoder_layer_input": "offload", + "out_proj": "offload", + "query_proj": "offload", + "key_proj": "offload", + "value_proj": "offload", + "max_target_length": 131072, + "attention": "flash", + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "enable_checkpointing": False, + "sa_block_q": 2048, + "sa_block_kv": 2048, + "sa_block_kv_compute": 2048, + "sa_block_q_dkv": 2048, + "sa_block_kv_dkv": 2048, + "sa_block_kv_dkv_compute": 2048, + "sa_block_q_dq": 2048, + "sa_block_kv_dq": 2048, + "sa_use_fused_bwd_kernel": True, + "profiler": "xplane", + "skip_first_n_steps_for_profiler": 10, + "profiler_steps": 5, + "allow_split_physical_axes": True, + # "custom_mesh": "hybrid_ring_32x8", + "tokenizer_path": "assets/tokenizer_llama3.tiktoken", + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER + + xla_flags_library.DATA_PARALLEL_OVERLAP + + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_1D_ALL_GATHER + + xla_flags_library.HOST_OFFLOAD_FLAGS + ), +) + +# TODO(b/368441022) LLAMA3.1 8B, 70B, 405B +# TODO(b/368441022) MaxDiffusion BEST +# TODO(b/368441022) Determine largest batch per slice for non-optimized models +# List of all models +maxstar_models = [ + default_basic, + default_32, + default_64, # Not Optimizied yet + default_128, # Not Optimizied yet + # default_256, # OOM, Not Optimizied yet + # default_512, # OOM, Not Optimizied yet + gpt_3_175b, + llama2_7b_4096, + llama2_70b_4096, + llama2_70b_4096_real_data, + llama3_8b_8192, # Not Optimizied yet + llama3_70b_8192, # Not Optimizied yet + llama3_1_405b_8192_fsdp_dcn, + llama3_1_70b_129024, + mixtral_8x7b_dropped, + mixtral_8x7b_dropped_int8, + mixtral_8x7b_dropless, + gemma2_9b_8192, + gemma2_27b_8192, +] +>>>>>>> 5efc11e1 (try running v6e-256 experiment) From 5adc54f0a01d0a137ed3b0b42fa66109aa031c2d Mon Sep 17 00:00:00 2001 From: A9isha Date: Wed, 22 Jan 2025 07:15:35 +0000 Subject: [PATCH 25/28] clean up config --- benchmarks/maxtext_trillium_model_configs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/maxtext_trillium_model_configs.py b/benchmarks/maxtext_trillium_model_configs.py index c9302466f..f2a4cf5dd 100644 --- a/benchmarks/maxtext_trillium_model_configs.py +++ b/benchmarks/maxtext_trillium_model_configs.py @@ -1082,7 +1082,7 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_ "attention": "flash", "use_iota_embed": True, "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", + # "dataset_type": "synthetic", "enable_checkpointing": False, "sa_block_q": 2048, "sa_block_kv": 2048, @@ -1096,7 
+1096,7 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_ "profiler": "xplane", "skip_first_n_steps_for_profiler": 10, "profiler_steps": 5, - "allow_split_physical_axes": True, + # "allow_split_physical_axes": True, # "custom_mesh": "hybrid_ring_32x8", "tokenizer_path": "assets/tokenizer_llama3.tiktoken", }, From fca5d0ee7bd653ee71b6c1d32e666be5084124b3 Mon Sep 17 00:00:00 2001 From: A9isha Date: Wed, 22 Jan 2025 21:08:56 +0000 Subject: [PATCH 26/28] try different sa_block values --- benchmarks/maxtext_trillium_model_configs.py | 137 ++++++++++++++++++- 1 file changed, 136 insertions(+), 1 deletion(-) diff --git a/benchmarks/maxtext_trillium_model_configs.py b/benchmarks/maxtext_trillium_model_configs.py index f2a4cf5dd..65c4d17b6 100644 --- a/benchmarks/maxtext_trillium_model_configs.py +++ b/benchmarks/maxtext_trillium_model_configs.py @@ -1069,7 +1069,7 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_ model_name="llama3_1_70b_131072", model_type="llama3.1-70b", tuning_params={ - "per_device_batch_size": 0.125, + "per_device_batch_size": 0.0625, "ici_fsdp_parallelism": -1, "ici_context_parallelism": 16, "remat_policy": "custom", @@ -1099,6 +1099,141 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_ # "allow_split_physical_axes": True, # "custom_mesh": "hybrid_ring_32x8", "tokenizer_path": "assets/tokenizer_llama3.tiktoken", + "steps": 30, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER + + xla_flags_library.DATA_PARALLEL_OVERLAP + + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_1D_ALL_GATHER + + xla_flags_library.HOST_OFFLOAD_FLAGS + ), +) + +llama3_1_70b_131072 = MaxTextModel( + model_name="llama3_1_70b_131072", + model_type="llama3.1-70b", + tuning_params={ + "per_device_batch_size": 0.0625, + "ici_fsdp_parallelism": -1, + "ici_context_parallelism": 16, + "remat_policy": "custom", + "decoder_layer_input": "offload", + "out_proj": "offload", + "query_proj": "offload", + "key_proj": "offload", + "value_proj": "offload", + "max_target_length": 131072, + "attention": "flash", + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + # "dataset_type": "synthetic", + "enable_checkpointing": False, + "sa_block_q": 2048, + "sa_block_kv": 2048, + "sa_block_kv_compute": 2048, + "sa_block_q_dkv": 2048, + "sa_block_kv_dkv": 2048, + "sa_block_kv_dkv_compute": 2048, + "sa_block_q_dq": 2048, + "sa_block_kv_dq": 2048, + "sa_use_fused_bwd_kernel": True, + "profiler": "xplane", + "skip_first_n_steps_for_profiler": 10, + "profiler_steps": 5, + # "allow_split_physical_axes": True, + # "custom_mesh": "hybrid_ring_32x8", + "tokenizer_path": "assets/tokenizer_llama3.tiktoken", + "steps": 30, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER + + xla_flags_library.DATA_PARALLEL_OVERLAP + + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_1D_ALL_GATHER + + xla_flags_library.HOST_OFFLOAD_FLAGS + ), +) + +llama3_1_70b_131072_1 = MaxTextModel( + model_name="llama3_1_70b_131072_1", + model_type="llama3.1-70b", + tuning_params={ + "per_device_batch_size": 0.0625, + "ici_fsdp_parallelism": -1, + "ici_context_parallelism": 16, + "remat_policy": "custom", + "decoder_layer_input": "offload", + "out_proj": "offload", + "query_proj": "offload", + "key_proj": "offload", + "value_proj": "offload", + "max_target_length": 131072, + "attention": "flash", + "use_iota_embed": 
True, + "dataset_path": "gs://max-datasets-rogue", + # "dataset_type": "synthetic", + "enable_checkpointing": False, + "sa_block_q": 6144, + "sa_block_kv": 6144, + "sa_block_kv_compute": 6144, + "sa_block_q_dkv": 6144, + "sa_block_kv_dkv": 6144, + "sa_block_kv_dkv_compute": 6144, + "sa_block_q_dq": 6144, + "sa_block_kv_dq": 6144, + "sa_use_fused_bwd_kernel": True, + "profiler": "xplane", + "skip_first_n_steps_for_profiler": 10, + "profiler_steps": 5, + # "allow_split_physical_axes": True, + # "custom_mesh": "hybrid_ring_32x8", + "tokenizer_path": "assets/tokenizer_llama3.tiktoken", + "steps": 20, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER + + xla_flags_library.DATA_PARALLEL_OVERLAP + + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_1D_ALL_GATHER + + xla_flags_library.HOST_OFFLOAD_FLAGS + ), +) +llama3_1_70b_131072_2 = MaxTextModel( + model_name="llama3_1_70b_131072_2", + model_type="llama3.1-70b", + tuning_params={ + "per_device_batch_size": 0.0625, + "ici_fsdp_parallelism": -1, + "ici_context_parallelism": 16, + "remat_policy": "custom", + "decoder_layer_input": "offload", + "out_proj": "offload", + "query_proj": "offload", + "key_proj": "offload", + "value_proj": "offload", + "max_target_length": 131072, + "attention": "flash", + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + # "dataset_type": "synthetic", + "enable_checkpointing": False, + "sa_block_q": 8192, + "sa_block_kv": 8192, + "sa_block_kv_compute": 8192, + "sa_block_q_dkv": 8192, + "sa_block_kv_dkv": 8192, + "sa_block_kv_dkv_compute": 8192, + "sa_block_q_dq": 8192, + "sa_block_kv_dq": 8192, + "sa_use_fused_bwd_kernel": True, + "profiler": "xplane", + "skip_first_n_steps_for_profiler": 10, + "profiler_steps": 5, + # "allow_split_physical_axes": True, + # "custom_mesh": "hybrid_ring_32x8", + "tokenizer_path": "assets/tokenizer_llama3.tiktoken", + "steps": 20, }, xla_flags=( xla_flags_library.DENSE_VMEM_LIMIT_FLAG From b70ea1e1e72de4aaac392892d05ae0a6b5b9f1b2 Mon Sep 17 00:00:00 2001 From: A9isha Date: Wed, 29 Jan 2025 23:25:15 +0000 Subject: [PATCH 27/28] codestyle changes --- MaxText/layers/attentions.py | 241 +++++++++---------- benchmarks/maxtext_trillium_model_configs.py | 73 ++---- 2 files changed, 128 insertions(+), 186 deletions(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index bd4280049..1eba91dc4 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -302,21 +302,19 @@ def tpu_flash_attention( value: Array, decoder_segment_ids: Array | None, attn_logits_soft_cap: float | None = None, - load_balanced_context_parallel: bool = True + load_balanced_context_parallel: bool = True, ) -> Array: """TPU Flash Attention.""" decoder_segment_ids_permuted = None - #Reorder tensors which is currently [B,S,H,KV] + # Reorder tensors which is currently [B,S,H,KV] cp_size = self.mesh.shape["context"] - if cp_size>1 and load_balanced_context_parallel: - query = self.reorder_causal_load_balancing(tensor = query, cp_size= cp_size, seq_dim= 1, to_contiguous=False) - decoder_segment_ids_permuted = self.reorder_causal_load_balancing(tensor = decoder_segment_ids, cp_size= cp_size, seq_dim= 1, to_contiguous=False) - - - - + if cp_size > 1 and load_balanced_context_parallel: + query = self.reorder_causal_load_balancing(tensor=query, cp_size=cp_size, seq_dim=1, to_contiguous=False) + decoder_segment_ids_permuted = self.reorder_causal_load_balancing( + tensor=decoder_segment_ids, cp_size=cp_size, 
seq_dim=1, to_contiguous=False + ) # Transpose to ('batch', 'heads', 'length', 'kv') query = jnp.transpose(query, axes=(0, 2, 1, 3)) @@ -330,10 +328,10 @@ def tpu_flash_attention( axis_names_splash_kernel = nn.logical_to_mesh_axes(self.flash_axis_names_splash_kernel) axis_names_q = nn.logical_to_mesh_axes(self.flash_axis_names_q) axis_names_kv = nn.logical_to_mesh_axes(self.flash_axis_names_kv) - max_logging.log(f'axis_names_q: {axis_names_q}') - max_logging.log(f'axis_names_kv: {axis_names_kv}') - max_logging.log(f'axis_names_splash_kernel: {axis_names_splash_kernel}') - + max_logging.log(f"axis_names_q: {axis_names_q}") + max_logging.log(f"axis_names_kv: {axis_names_kv}") + max_logging.log(f"axis_names_splash_kernel: {axis_names_splash_kernel}") + global_block_q = self.config.sa_block_q global_block_kv = self.config.sa_block_kv global_block_kv_compute = self.config.sa_block_kv_compute @@ -347,15 +345,12 @@ def tpu_flash_attention( global_k_layout = self.config.sa_k_layout global_v_layout = self.config.sa_v_layout - devices_in_data_fsdp = self.mesh.shape["data"] * self.mesh.shape["fsdp"] assert (query.shape[0] / devices_in_data_fsdp).is_integer(), ( "Batch dimension should be shardable among the devices in data and fsdp" " axis" ) - - - #create_splash_attention kernel + # create_splash_attention kernel block_sizes = splash_attention_kernel.BlockSizes( block_q=min(global_block_q, query.shape[2]), block_kv=min(global_block_kv, key.shape[2]), @@ -370,21 +365,21 @@ def tpu_flash_attention( k_layout=splash_attention_kernel.QKVLayout[global_k_layout], v_layout=splash_attention_kernel.QKVLayout[global_v_layout], ) - + # mask_shape = (query.shape[2], key.shape[2]) mask_shape = (self.config.max_target_length, self.config.max_target_length) mask = splash_attention_mask.CausalMask(shape=mask_shape) # permute the mask if cp and load_balancing - if cp_size>1 and load_balanced_context_parallel: + if cp_size > 1 and load_balanced_context_parallel: # mask = create_load_balance_causal_mask(shape=mask_shape,cp_size=cp_size) - mask = LoadBalancedCausalMask(shape=mask_shape,cp_size=cp_size) - + mask = LoadBalancedCausalMask(shape=mask_shape, cp_size=cp_size) + # jax.debug.print("permuted: mask items = {items}", items = new_mask.__getitem__((slice(mask.shape[0]),slice(mask.shape[1])))) - + # jax.debug.print("new_mask == old_mask = {equal}", equal = new_mask.__getitem__((slice(mask.shape[0]),slice(mask.shape[1])))==mask.__getitem__((slice(mask.shape[0]),slice(mask.shape[1])))) - #TODO: figure out local_sliding attention + load_balancing, default is global + # TODO: figure out local_sliding attention + load_balancing, default is global # Apply local masking if local sliding attention is enabled. 
if self.attention_type == AttentionType.LOCAL_SLIDING: if self.sliding_window_size is None: @@ -407,19 +402,18 @@ def tpu_flash_attention( def wrap_splash_kernel(multi_head_mask): splash_kernel = splash_attention_kernel.make_splash_mha( mask=multi_head_mask, - head_shards=1, # we would need to change this to the size of the axis if sharding over heads - q_seq_shards=cp_size, #axis for sequence sharding + head_shards=1, # we would need to change this to the size of the axis if sharding over heads + q_seq_shards=cp_size, # axis for sequence sharding block_sizes=block_sizes, attn_logits_soft_cap=attn_logits_soft_cap, ) return splash_kernel - + splash_kernel = wrap_splash_kernel(multi_head_mask) named_sharding = jax.sharding.NamedSharding(self.mesh, axis_names_splash_kernel) segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding) - @functools.partial( shard_map, mesh=self.mesh, @@ -444,24 +438,23 @@ def wrap_flash_attention(query, key, value, decoder_segment_ids_q, decoder_segme # pdb.set_trace() # jax.debug.print("attention_output.shape = {ash}", ash = attention_output.shape) # full_mask = [per_head_mask for per_head_mask in multi_head_mask.masks] - # valid_tokens = multi_head_mask.masks.any(dim=-1) # [q_sl] -> [q_sl, 1] -> [q_sl, head_dim] + # valid_tokens = multi_head_mask.masks.any(dim=-1) # [q_sl] -> [q_sl, 1] -> [q_sl, head_dim] # valid_tokens = decoder_segment_ids_q & multi_head_mask.masks.any(dim=-1) # attention_output = attention_output * valid_tokens # broadcasting along head_dim return attention_output - if cp_size>1 and load_balanced_context_parallel: + if cp_size > 1 and load_balanced_context_parallel: x = wrap_flash_attention(query, key, value, decoder_segment_ids_permuted, decoder_segment_ids, splash_kernel) else: x = wrap_flash_attention(query, key, value, decoder_segment_ids, decoder_segment_ids, splash_kernel) - + x = jnp.transpose(x, axes=(0, 2, 1, 3)) - if cp_size>1 and load_balanced_context_parallel: - #inverse reorder for load_balancing - x = self.reorder_causal_load_balancing(tensor = x, cp_size= cp_size, seq_dim= 1, to_contiguous=True) + if cp_size > 1 and load_balanced_context_parallel: + # inverse reorder for load_balancing + x = self.reorder_causal_load_balancing(tensor=x, cp_size=cp_size, seq_dim=1, to_contiguous=True) - return x # @functools.partial( @@ -477,116 +470,111 @@ def reorder_mask_load_balancing(tensor, cp_size: int, seq_dim: int): """Reorders a tensor for load balancing the compute of causal attention.""" if type(tensor) is not np.ndarray: breakpoint() - + if tensor is None: return tensor - + if cp_size == 1: - return tensor + return tensor if cp_size % 2 != 0: - raise ValueError(f"{cp_size=} must be a multiple of 2.") + raise ValueError(f"{cp_size=} must be a multiple of 2.") # Need to ensure we have 2 pairs to swap for balancing between cp ranks if tensor.shape[seq_dim] % (cp_size * 2) != 0: - raise ValueError(f"{tensor.shape=} is not a multiple of {cp_size*2=}") + raise ValueError(f"{tensor.shape=} is not a multiple of {cp_size*2=}") # [B, S, H, D] -> [B, 2*cp_size, S/2*cp_size, D] #Anisha: this is ours # [S, B, H, D] -> [2*cp_size, S/2*cp_size, B, H, D] - + ori_tensor_shape = tensor.shape tensor = np.reshape( - tensor, - ( - *ori_tensor_shape[:seq_dim], - 2 * cp_size, - ori_tensor_shape[seq_dim] // (2 * cp_size), - *ori_tensor_shape[seq_dim + 1 :], - ) - ) + tensor, + ( + *ori_tensor_shape[:seq_dim], + 2 * cp_size, + ori_tensor_shape[seq_dim] // (2 * cp_size), + *ori_tensor_shape[seq_dim + 1 :], + ), + ) parts = [] for 
cp_rank in range(cp_size): - # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] - # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] - index = np.array([cp_rank, (2 * cp_size - cp_rank - 1)]) - try: - parts.append(np.take(tensor, index, axis=seq_dim)) - except Exception as e: - print(f"Got exception={e}") - breakpoint() - + # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] + # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] + index = np.array([cp_rank, (2 * cp_size - cp_rank - 1)]) + try: + parts.append(np.take(tensor, index, axis=seq_dim)) + except Exception as e: + print(f"Got exception={e}") + breakpoint() # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] try: combined = np.stack(parts, axis=seq_dim) except Exception as e: - print(f"Got exception={e}") - breakpoint() + print(f"Got exception={e}") + breakpoint() - return np.reshape(combined,ori_tensor_shape) + return np.reshape(combined, ori_tensor_shape) def reorder_causal_load_balancing(self, tensor, cp_size: int, seq_dim: int, to_contiguous: bool): """Reorders a tensor for load balancing the compute of causal attention.""" if tensor is None: return tensor - + if cp_size == 1: - return tensor + return tensor if cp_size % 2 != 0: - raise ValueError(f"{cp_size=} must be a multiple of 2.") + raise ValueError(f"{cp_size=} must be a multiple of 2.") # Need to ensure we have 2 pairs to swap for balancing between cp ranks if tensor.shape[seq_dim] % (cp_size * 2) != 0: - raise ValueError(f"{tensor.shape=} is not a multiple of {cp_size*2=}") + raise ValueError(f"{tensor.shape=} is not a multiple of {cp_size*2=}") # [B, S, H, D] -> [B, 2*cp_size, S/2*cp_size, D] #Anisha: this is ours # [S, B, H, D] -> [2*cp_size, S/2*cp_size, B, H, D] - + ori_tensor_shape = tensor.shape tensor = jnp.reshape( - tensor, - ( - *ori_tensor_shape[:seq_dim], - 2 * cp_size, - ori_tensor_shape[seq_dim] // (2 * cp_size), - *ori_tensor_shape[seq_dim + 1 :], - ) - ) + tensor, + ( + *ori_tensor_shape[:seq_dim], + 2 * cp_size, + ori_tensor_shape[seq_dim] // (2 * cp_size), + *ori_tensor_shape[seq_dim + 1 :], + ), + ) parts = [] if not to_contiguous: - for cp_rank in range(cp_size): - # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] - # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] - index = jnp.array([cp_rank, (2 * cp_size - cp_rank - 1)]) - parts.append(jnp.take(tensor, index, axis=seq_dim)) + for cp_rank in range(cp_size): + # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] + # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] + index = jnp.array([cp_rank, (2 * cp_size - cp_rank - 1)]) + parts.append(jnp.take(tensor, index, axis=seq_dim)) else: - for cp_rank in range(cp_size // 2): - # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] - # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] - base = 4 * cp_rank - index = jnp.array([base, base + 2]) - parts.append(jnp.take(tensor, index, axis=seq_dim)) - for cp_rank in range(cp_size // 2): - # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] - # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] - base = 2 * cp_size - 1 - 4 * cp_rank - index = jnp.array([base, base - 2]) - parts.append(jnp.take(tensor, index, axis=seq_dim)) + for cp_rank in 
range(cp_size // 2): + # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] + # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] + base = 4 * cp_rank + index = jnp.array([base, base + 2]) + parts.append(jnp.take(tensor, index, axis=seq_dim)) + for cp_rank in range(cp_size // 2): + # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D] + # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D] + base = 2 * cp_size - 1 - 4 * cp_rank + index = jnp.array([base, base - 2]) + parts.append(jnp.take(tensor, index, axis=seq_dim)) # [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] # [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] combined = jnp.stack(parts, axis=seq_dim) - return jnp.reshape(combined,ori_tensor_shape) - - - - + return jnp.reshape(combined, ori_tensor_shape) def cudnn_flash_attention( self, @@ -1532,19 +1520,23 @@ def __call__( out = checkpoint_name(out, "out_proj") return out + partial = functools.partial -class WrapperNpNDArray(): + +class WrapperNpNDArray: np_ndarray: np.ndarray - def __init__(self,np_ndarray): + def __init__(self, np_ndarray): self.np_ndarray = np_ndarray def __hash__(self): - return hash(( - type(self), - self.np_ndarray.tobytes() if self.np_ndarray is not None else None, - )) + return hash( + ( + type(self), + self.np_ndarray.tobytes() if self.np_ndarray is not None else None, + ) + ) class LoadBalancedCausalMask(splash_attention_mask._ComputableMask): @@ -1554,18 +1546,13 @@ class LoadBalancedCausalMask(splash_attention_mask._ComputableMask): triangle upward, a negative one shifts it downward. A negative offset makes the first 'offset' rows of the attention matrix all 0s which leads to undefined softmax. - """ + """ + offset: int shape: tuple[int, int] cp_size: int - def __init__( - self, - shape: tuple[int, int], - offset: int = 0, - shard_count: int = 1, - cp_size: int = 4 - ): + def __init__(self, shape: tuple[int, int], offset: int = 0, shard_count: int = 1, cp_size: int = 4): self.offset = offset def causal_mask_function(q_ids, kv_ids): @@ -1574,7 +1561,7 @@ def causal_mask_function(q_ids, kv_ids): else: return q_ids + self.offset >= kv_ids - arr = np.arange(shape[0]) + arr = jnp.arange(shape[0]) out = AttentionOp.reorder_mask_load_balancing(arr[None, :, None, None], cp_size, seq_dim=1) q_sequence = out[0, :, 0, 0] @@ -1596,7 +1583,6 @@ def causal_mask_function(q_ids, kv_ids): # ): # self.offset = offset # self.cp_size = cp_size - # def causal_mask_function(q_ids, kv_ids): # # When evaluating the mask in _process_mask we typically work with numpy @@ -1608,16 +1594,14 @@ def causal_mask_function(q_ids, kv_ids): # idx: int, # cp: int, #context parallelism val # ): - - # q_slice, kv_slice = idx, slice(shape[1]) - + # q_slice, kv_slice = idx, slice(shape[1]) # def create_load_balance_causal_mask( # shape: tuple[int, int], # q_ids: int, # kv_ids: int, - # offset: int = 0, #This is important for auto regressive decoding, + # offset: int = 0, #This is important for auto regressive decoding, # #we are not supporting flash/splash attention for auto regressive decoding # ): # (slice * i + jnp.arange(slice))[:, None] <= jnp.arange(seq_len)[None, :] - something like this @@ -1626,7 +1610,7 @@ def causal_mask_function(q_ids, kv_ids): # #then concatenate # """ - # 1. create all + # 1. 
create all # """ # # self.offset = offset @@ -1643,17 +1627,16 @@ def causal_mask_function(q_ids, kv_ids): # return q_ids >= kv_ids # else: # return q_ids + offset >= kv_ids - + # original_mask_ndarray = create_load_balance_causal_mask(self.shape, q_ids, kv_ids, self.offset) # if type(original_mask_ndarray) is not np.ndarray: # raise ValueError("Something went wrong in function_create_load_balance_causal_mask") # else: # print("np ndarray found!!") - # mask_ndarray = AttentionOp.reorder_mask_load_balancing(tensor = original_mask_ndarray, cp_size= self.cp_size, seq_dim= 0) + # mask_ndarray = AttentionOp.reorder_mask_load_balancing(tensor = original_mask_ndarray, cp_size= self.cp_size, seq_dim= 0) # return mask_ndarray # # return splash_attention_mask.NumpyMask(mask_ndarray) - # mask_function = causal_mask_function # super().__init__( @@ -1666,16 +1649,14 @@ def __eq__(self, other: object): if not isinstance(other, type(self)): return NotImplemented - return ( - self.shape == other.shape - and self.offset == other.offset - and np.array_equal(self.q_sequence, other.q_sequence) - ) + return self.shape == other.shape and self.offset == other.offset and jnp.array_equal(self.q_sequence, other.q_sequence) def __hash__(self): - return hash(( - type(self), - self.shape, - self.offset, - self.q_sequence.tobytes() if self.q_sequence is not None else None, - )) + return hash( + ( + type(self), + self.shape, + self.offset, + # self.q_sequence.tobytes() if self.q_sequence is not None else None, + ) + ) diff --git a/benchmarks/maxtext_trillium_model_configs.py b/benchmarks/maxtext_trillium_model_configs.py index 65c4d17b6..b0c864e82 100644 --- a/benchmarks/maxtext_trillium_model_configs.py +++ b/benchmarks/maxtext_trillium_model_configs.py @@ -1062,55 +1062,10 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_ ), ) ) -<<<<<<< HEAD -======= -llama3_1_70b_131072 = MaxTextModel( - model_name="llama3_1_70b_131072", - model_type="llama3.1-70b", - tuning_params={ - "per_device_batch_size": 0.0625, - "ici_fsdp_parallelism": -1, - "ici_context_parallelism": 16, - "remat_policy": "custom", - "decoder_layer_input": "offload", - "out_proj": "offload", - "query_proj": "offload", - "key_proj": "offload", - "value_proj": "offload", - "max_target_length": 131072, - "attention": "flash", - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - # "dataset_type": "synthetic", - "enable_checkpointing": False, - "sa_block_q": 2048, - "sa_block_kv": 2048, - "sa_block_kv_compute": 2048, - "sa_block_q_dkv": 2048, - "sa_block_kv_dkv": 2048, - "sa_block_kv_dkv_compute": 2048, - "sa_block_q_dq": 2048, - "sa_block_kv_dq": 2048, - "sa_use_fused_bwd_kernel": True, - "profiler": "xplane", - "skip_first_n_steps_for_profiler": 10, - "profiler_steps": 5, - # "allow_split_physical_axes": True, - # "custom_mesh": "hybrid_ring_32x8", - "tokenizer_path": "assets/tokenizer_llama3.tiktoken", - "steps": 30, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER - + xla_flags_library.DATA_PARALLEL_OVERLAP - + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_1D_ALL_GATHER - + xla_flags_library.HOST_OFFLOAD_FLAGS - ), -) - -llama3_1_70b_131072 = MaxTextModel( +llama3_1_70b_131072 = _add_to_model_dictionary( + trillium_model_dict, + MaxTextModel( model_name="llama3_1_70b_131072", model_type="llama3.1-70b", tuning_params={ @@ -1150,12 +1105,15 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_ 
xla_flags_library.DENSE_VMEM_LIMIT_FLAG + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER + xla_flags_library.DATA_PARALLEL_OVERLAP - + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_1D_ALL_GATHER + + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_GATHER + xla_flags_library.HOST_OFFLOAD_FLAGS ), + ) ) -llama3_1_70b_131072_1 = MaxTextModel( +llama3_1_70b_131072_1 = _add_to_model_dictionary( + trillium_model_dict, + MaxTextModel( model_name="llama3_1_70b_131072_1", model_type="llama3.1-70b", tuning_params={ @@ -1195,11 +1153,14 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_ xla_flags_library.DENSE_VMEM_LIMIT_FLAG + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER + xla_flags_library.DATA_PARALLEL_OVERLAP - + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_1D_ALL_GATHER + + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_GATHER + xla_flags_library.HOST_OFFLOAD_FLAGS ), + ) ) -llama3_1_70b_131072_2 = MaxTextModel( +llama3_1_70b_131072_2 = _add_to_model_dictionary( + trillium_model_dict, + MaxTextModel( model_name="llama3_1_70b_131072_2", model_type="llama3.1-70b", tuning_params={ @@ -1239,9 +1200,10 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_ xla_flags_library.DENSE_VMEM_LIMIT_FLAG + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER + xla_flags_library.DATA_PARALLEL_OVERLAP - + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_1D_ALL_GATHER + + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_GATHER + xla_flags_library.HOST_OFFLOAD_FLAGS ), + ) ) # TODO(b/368441022) LLAMA3.1 8B, 70B, 405B @@ -1249,7 +1211,7 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_ # TODO(b/368441022) Determine largest batch per slice for non-optimized models # List of all models maxstar_models = [ - default_basic, + default_basic_1, default_32, default_64, # Not Optimizied yet default_128, # Not Optimizied yet @@ -1258,7 +1220,7 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_ gpt_3_175b, llama2_7b_4096, llama2_70b_4096, - llama2_70b_4096_real_data, + llama2_70b_4096_sc_real_data_tfds, llama3_8b_8192, # Not Optimizied yet llama3_70b_8192, # Not Optimizied yet llama3_1_405b_8192_fsdp_dcn, @@ -1269,4 +1231,3 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_ gemma2_9b_8192, gemma2_27b_8192, ] ->>>>>>> 5efc11e1 (try running v6e-256 experiment) From a829d234766d0ea51755f99e750075da2153a53d Mon Sep 17 00:00:00 2001 From: A9isha Date: Wed, 29 Jan 2025 23:41:53 +0000 Subject: [PATCH 28/28] revert remaining jnp to np in mask computation --- MaxText/layers/attentions.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 1eba91dc4..4a7f58fd3 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -1561,7 +1561,7 @@ def causal_mask_function(q_ids, kv_ids): else: return q_ids + self.offset >= kv_ids - arr = jnp.arange(shape[0]) + arr = np.arange(shape[0]) out = AttentionOp.reorder_mask_load_balancing(arr[None, :, None, None], cp_size, seq_dim=1) q_sequence = out[0, :, 0, 0] @@ -1649,7 +1649,7 @@ def __eq__(self, other: object): if not isinstance(other, type(self)): return NotImplemented - return self.shape == other.shape and self.offset == other.offset and jnp.array_equal(self.q_sequence, other.q_sequence) + return self.shape == other.shape and self.offset == other.offset and np.array_equal(self.q_sequence, 
other.q_sequence) def __hash__(self): return hash( @@ -1657,6 +1657,6 @@ def __hash__(self): type(self), self.shape, self.offset, - # self.q_sequence.tobytes() if self.q_sequence is not None else None, + self.q_sequence.tobytes() if self.q_sequence is not None else None, ) )
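
The load-balanced context-parallel path in these patches splits the sequence into 2*cp_size chunks and pairs chunk r with chunk 2*cp_size-1-r, so every rank gets one short and one long causal-attention workload. Below is a minimal NumPy sketch of that forward reorder; the helper name reorder_for_load_balance is hypothetical, and it only mirrors the gather-and-stack pattern of reorder_causal_load_balancing with to_contiguous=False rather than reproducing the full method.

import numpy as np

def reorder_for_load_balance(x, cp_size, seq_dim=1):
    """Pair chunk r with chunk 2*cp_size-1-r along seq_dim, so each context-
    parallel rank gets one short and one long causal-attention workload."""
    shape = x.shape
    if shape[seq_dim] % (2 * cp_size):
        raise ValueError(f"{shape[seq_dim]=} must be divisible by {2 * cp_size=}")
    x = x.reshape(
        *shape[:seq_dim], 2 * cp_size, shape[seq_dim] // (2 * cp_size), *shape[seq_dim + 1:]
    )
    # Gather the matched chunk pair for each rank, then lay the pairs out contiguously.
    parts = [np.take(x, [r, 2 * cp_size - 1 - r], axis=seq_dim) for r in range(cp_size)]
    return np.stack(parts, axis=seq_dim).reshape(shape)

tokens = np.arange(8)[None, :]                      # [batch=1, seq=8]
perm = reorder_for_load_balance(tokens, cp_size=2)  # [[0 1 6 7 2 3 4 5]]
inv = np.argsort(perm[0])                           # perm[0][inv] == arange(8)

Applying the reorder to an index vector and inverting it with argsort, as above, is a quick way to check that an inverse pass (the to_contiguous=True branch in the patch) restores the original token order.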
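
LoadBalancedCausalMask applies the same permutation to the query positions only: the predicate stays q_ids + offset >= kv_ids, but the query ids follow the reordered layout while the key/value ids stay contiguous. The sketch below is a dense-mask illustration of that idea under the same chunk-pairing assumption; the actual class builds a computable splash-attention mask from q_sequence rather than materializing a dense array, and load_balanced_causal_mask is a hypothetical name used only here.

import numpy as np

def load_balanced_causal_mask(seq_len, cp_size, offset=0):
    # Permuted query positions: chunk r is paired with chunk 2*cp_size-1-r.
    chunks = np.arange(seq_len).reshape(2 * cp_size, seq_len // (2 * cp_size))
    q_ids = np.concatenate(
        [np.concatenate([chunks[r], chunks[2 * cp_size - 1 - r]]) for r in range(cp_size)]
    )
    kv_ids = np.arange(seq_len)
    # Same causal predicate, evaluated on permuted query rows against unpermuted columns.
    return (q_ids[:, None] + offset) >= kv_ids[None, :]

mask = load_balanced_causal_mask(seq_len=8, cp_size=2)
# mask[2] is the causal row for original position 6, even though it now sits in the
# first rank's shard after reordering, so each shard mixes short and long rows.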