diff --git a/MaxText/common_types.py b/MaxText/common_types.py index c96bcaeef..fba7716f5 100644 --- a/MaxText/common_types.py +++ b/MaxText/common_types.py @@ -36,6 +36,7 @@ BATCH = "activation_batch" LENGTH = "activation_length" +KV_LENGTH = "activation_length_kv" EMBED = "activation_embed" HEAD = "activation_heads" PREFILL_KV_BATCH = "activation_prefill_kv_batch" diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index e784535e6..0b2250b5d 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -228,15 +228,18 @@ jax_cache_dir: "~/jax_cache" hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess' and 'cpu' # Parallelism -mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'tensor_sequence', 'expert', 'autoregressive'] +mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor', 'tensor_sequence', 'expert', 'autoregressive'] logical_axis_rules: [ ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']], ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_heads', ['tensor','sequence','tensor_sequence']], ['activation_kv_heads', ['tensor','sequence','tensor_sequence']], - ['activation_length', ['sequence']], ['activation_norm_length', ['tensor_sequence', 'sequence']], + ['activation_length', ['sequence', 'context']], + ['activation_length', ['context']], + ['activation_length_q', ['context']], + ['activation_length_kv', []], ['activation_embed', 'tensor'], ['activation_mlp', ['tensor', 'tensor_sequence']], ['activation_kv', ['tensor', 'tensor_sequence']], @@ -251,13 +254,13 @@ logical_axis_rules: [ ['activation_exp', 'expert'], ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['vocab', ['tensor', 'tensor_sequence', 'autoregressive']], - ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']], - ['embed', ['fsdp', 'sequence', 'expert']], - ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence']], - ['embed_no_exp', ['fsdp', 'sequence']], - ['norm', ['tensor', 'tensor_sequence']], - ['q_heads', ['tensor', 'tensor_sequence', 'autoregressive']], - ['heads', ['tensor', 'tensor_sequence', 'autoregressive']], + ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], + ['embed', ['fsdp', 'sequence', 'context', 'expert']], + ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']], + ['embed_no_exp', ['fsdp', 'sequence', 'context']], + ['norm', 'tensor'], + ['q_heads', ['tensor', 'autoregressive']], + ['heads', ['tensor', 'autoregressive']], ['layers', 'stage'], ['kv', []], ['kv_heads', ['tensor', 'tensor_sequence', 'autoregressive']], @@ -270,7 +273,7 @@ logical_axis_rules: [ ['exp', 'expert'], ] # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details -data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'tensor_sequence', 'expert', 'autoregressive']] +data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor', 'tensor_sequence', 'expert', 'autoregressive']] # sharding tolerance: float between 0.0 and 1.0 representing the allowed percentage of non-sharded parameters. 
 sharding_tolerance: 0.02
@@ -288,6 +291,7 @@ dcn_tensor_sequence_parallelism: 1 # never recommended
 dcn_pipeline_parallelism: 1
 dcn_expert_parallelism: 1
 dcn_autoregressive_parallelism: 1 # never recommended
+dcn_context_parallelism: 1
 ici_data_parallelism: 1
 ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
 ici_fsdp_transpose_parallelism: 1
@@ -297,6 +301,7 @@ ici_tensor_sequence_parallelism: 1
 ici_autoregressive_parallelism: 1
 ici_pipeline_parallelism: 1
 ici_expert_parallelism: 1
+ici_context_parallelism: 1
 # The number of TPU slices is automatically determined, you should not set this explicitly. For ahead of time compilation,
 # you should set compile_toplogy_num_slices, which will in turn set this value. For non-TPU environments this is set to 1.
diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py
index f5990e7d9..4a7f58fd3 100644
--- a/MaxText/layers/attentions.py
+++ b/MaxText/layers/attentions.py
@@ -34,6 +34,8 @@
 from layers import initializers
 from layers import linears
 from layers import quantizations
+import max_logging
+import numpy as np

 # pylint: disable=line-too-long, g-doc-args, g-doc-return-or-yield, bad-continuation, g-inconsistent-quotes
@@ -64,6 +67,7 @@ class AttentionType(enum.Enum):
 PREFILL_KV_BATCH = common_types.PREFILL_KV_BATCH
 KV_BATCH = common_types.KV_BATCH
 LENGTH = common_types.LENGTH
+KV_LENGTH = common_types.KV_LENGTH
 HEAD = common_types.HEAD
 EMBED = common_types.EMBED
 KV_HEAD = common_types.KV_HEAD
@@ -141,7 +145,9 @@ class AttentionOp(nn.Module):
   float32_qk_product: bool = False
   max_prefill_predict_length: int = -1
   float32_logits: bool = False
-  flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV)
+  flash_axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV)
+  flash_axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV)
+  flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH)
   prefill_cache_logical_axis_names: AxisNames = (CACHE_BATCH_PREFILL, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV)
   cache_logical_axis_names: AxisNames = (CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV)
   cache_scale_logical_axis_names: AxisNames = (CACHE_SCALE_BATCH, CACHE_SCALE_SEQUENCE, CACHE_SCALE_HEADS, CACHE_SCALE_KV)
@@ -296,17 +302,35 @@ def tpu_flash_attention(
       value: Array,
       decoder_segment_ids: Array | None,
       attn_logits_soft_cap: float | None = None,
+      load_balanced_context_parallel: bool = True,
   ) -> Array:
     """TPU Flash Attention."""
+
+    decoder_segment_ids_permuted = None
+
+    # Reorder the [B, S, H, KV] tensors so every context-parallel shard gets a balanced
+    # share of the causal-attention work (cheap early rows paired with expensive late rows).
+    cp_size = self.mesh.shape["context"]
+    if cp_size > 1 and load_balanced_context_parallel:
+      query = self.reorder_causal_load_balancing(tensor=query, cp_size=cp_size, seq_dim=1, to_contiguous=False)
+      decoder_segment_ids_permuted = self.reorder_causal_load_balancing(
+          tensor=decoder_segment_ids, cp_size=cp_size, seq_dim=1, to_contiguous=False
+      )
+
     # Transpose to ('batch', 'heads', 'length', 'kv')
     query = jnp.transpose(query, axes=(0, 2, 1, 3))
     key = jnp.transpose(key, axes=(0, 2, 1, 3))
     value = jnp.transpose(value, axes=(0, 2, 1, 3))
-
+    segment_axis_names_q = None
+    segment_axis_names_kv = None
     if decoder_segment_ids is not None:
-      decoder_segment_ids = splash_attention_kernel.SegmentIds(decoder_segment_ids, decoder_segment_ids)
-    axis_names = nn.logical_to_mesh_axes(self.flash_axis_names)
-    segment_axis_names = nn.logical_to_mesh_axes((BATCH, "activation_length_no_heads"))
+      segment_axis_names_q = nn.logical_to_mesh_axes((BATCH, "activation_length_q"))
+      segment_axis_names_kv = nn.logical_to_mesh_axes((BATCH, "activation_length_kv"))
+    axis_names_splash_kernel = nn.logical_to_mesh_axes(self.flash_axis_names_splash_kernel)
+    axis_names_q = nn.logical_to_mesh_axes(self.flash_axis_names_q)
+    axis_names_kv = nn.logical_to_mesh_axes(self.flash_axis_names_kv)
+    max_logging.log(f"axis_names_q: {axis_names_q}")
+    max_logging.log(f"axis_names_kv: {axis_names_kv}")
+    max_logging.log(f"axis_names_splash_kernel: {axis_names_splash_kernel}")

     global_block_q = self.config.sa_block_q
     global_block_kv = self.config.sa_block_kv
@@ -321,70 +345,237 @@ def tpu_flash_attention(
     global_k_layout = self.config.sa_k_layout
     global_v_layout = self.config.sa_v_layout

-    @functools.partial(
-        shard_map,
-        mesh=self.mesh,
-        in_specs=(
-            axis_names,
-            axis_names,
-            axis_names,
-            segment_axis_names,
-        ),
-        out_specs=axis_names,
-        check_rep=False,
+    devices_in_data_fsdp = self.mesh.shape["data"] * self.mesh.shape["fsdp"]
+    assert (query.shape[0] / devices_in_data_fsdp).is_integer(), (
+        "Batch dimension should be shardable among the devices in data and fsdp" " axis"
     )
-    def wrap_flash_attention(query, key, value, decoder_segment_ids):
-      if decoder_segment_ids is not None:
-        assert (
-            query.shape[2] == decoder_segment_ids.q.shape[1]
-        ), "Sharding along sequence dimension not allowed in tpu kernel attention"
-      block_sizes = splash_attention_kernel.BlockSizes(
-          block_q=min(global_block_q, query.shape[2]),
-          block_kv=min(global_block_kv, key.shape[2]),
-          block_kv_compute=min(global_block_kv_compute, key.shape[2]),
-          block_q_dkv=min(global_block_q_dkv, query.shape[2]),
-          block_kv_dkv=min(global_block_kv_dkv, key.shape[2]),
-          block_kv_dkv_compute=min(global_block_kv_dkv_compute, query.shape[2]),
-          block_q_dq=None if global_use_fused_bwd_kernel else min(global_block_q_dq, query.shape[2]),
-          block_kv_dq=None if global_use_fused_bwd_kernel else min(global_block_kv_dq, query.shape[2]),
-          use_fused_bwd_kernel=global_use_fused_bwd_kernel,
-          q_layout=splash_attention_kernel.QKVLayout[global_q_layout],
-          k_layout=splash_attention_kernel.QKVLayout[global_k_layout],
-          v_layout=splash_attention_kernel.QKVLayout[global_v_layout],
-      )
-      mask = splash_attention_mask.CausalMask(shape=(query.shape[2], query.shape[2]))
+    # create the splash attention kernel
+    block_sizes = splash_attention_kernel.BlockSizes(
+        block_q=min(global_block_q, query.shape[2]),
+        block_kv=min(global_block_kv, key.shape[2]),
+        block_kv_compute=min(global_block_kv_compute, key.shape[2]),
+        block_q_dkv=min(global_block_q_dkv, query.shape[2]),
+        block_kv_dkv=min(global_block_kv_dkv, key.shape[2]),
+        block_kv_dkv_compute=min(global_block_kv_dkv_compute, query.shape[2]),
+        block_q_dq=None if global_use_fused_bwd_kernel else min(global_block_q_dq, query.shape[2]),
+        block_kv_dq=None if global_use_fused_bwd_kernel else min(global_block_kv_dq, query.shape[2]),
+        use_fused_bwd_kernel=global_use_fused_bwd_kernel,
+        q_layout=splash_attention_kernel.QKVLayout[global_q_layout],
+        k_layout=splash_attention_kernel.QKVLayout[global_k_layout],
+        v_layout=splash_attention_kernel.QKVLayout[global_v_layout],
+    )

-      # Apply local masking if local sliding attention is enabled.
-      if self.attention_type == AttentionType.LOCAL_SLIDING:
-        if self.sliding_window_size is None:
-          raise ValueError("Sliding_window_size must be set if Local Sliding attention type")
-        mask &= splash_attention_mask.LocalMask(
-            shape=(query.shape[2], query.shape[2]),
-            window_size=(self.sliding_window_size, self.sliding_window_size),
-            offset=0,
-        )
+    mask_shape = (self.config.max_target_length, self.config.max_target_length)
+    mask = splash_attention_mask.CausalMask(shape=mask_shape)
+
+    # Permute the mask when load-balanced context parallelism is enabled.
+    if cp_size > 1 and load_balanced_context_parallel:
+      mask = LoadBalancedCausalMask(shape=mask_shape, cp_size=cp_size)

-      # Create multi-head mask
-      multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+    # TODO: figure out local_sliding attention + load_balancing, default is global
+    # Apply local masking if local sliding attention is enabled.
+    if self.attention_type == AttentionType.LOCAL_SLIDING:
+      if self.sliding_window_size is None:
+        raise ValueError("Sliding_window_size must be set if Local Sliding attention type")
+      mask &= splash_attention_mask.LocalMask(
+          shape=(query.shape[2], key.shape[2]),
+          window_size=(self.sliding_window_size, self.sliding_window_size),
+          offset=0,
+      )
+
+    # Create multi-head mask
+    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+
+    @partial(
+        jax.jit,
+        static_argnames=[
+            "multi_head_mask",
+        ],
+    )
+    def wrap_splash_kernel(multi_head_mask):
       splash_kernel = splash_attention_kernel.make_splash_mha(
           mask=multi_head_mask,
-          head_shards=1,
-          q_seq_shards=1,
+          head_shards=1,  # we would need to change this to the size of the axis if sharding over heads
+          q_seq_shards=cp_size,  # axis for sequence sharding
           block_sizes=block_sizes,
           attn_logits_soft_cap=attn_logits_soft_cap,
       )
+      return splash_kernel

-      return jax.vmap(splash_kernel)(query, key, value, segment_ids=decoder_segment_ids)
+    splash_kernel = wrap_splash_kernel(multi_head_mask)

-    devices_in_data_fsdp = self.mesh.shape["data"] * self.mesh.shape["fsdp"]
-    assert (query.shape[0] / devices_in_data_fsdp).is_integer(), (
-        "Batch dimension should be shardable among the devices in data and fsdp" " axis"
-    )
+    named_sharding = jax.sharding.NamedSharding(self.mesh, axis_names_splash_kernel)
+    segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding)
+
+    @functools.partial(
+        shard_map,
+        mesh=self.mesh,
+        in_specs=(
+            axis_names_q,
+            axis_names_kv,
+            axis_names_kv,
+            segment_axis_names_q,
+            segment_axis_names_kv,
+            segment_axis_names_splash_kernel,
+        ),
+        out_specs=axis_names_q,
+        check_rep=False,
+    )
+    def wrap_flash_attention(query, key, value, decoder_segment_ids_q, decoder_segment_ids_kv, splash_kernel):
+      if decoder_segment_ids_q is not None:
+        decoder_segment_ids_tuple = splash_attention_kernel.SegmentIds(decoder_segment_ids_q, decoder_segment_ids_kv)
+      else:
+        decoder_segment_ids_tuple = None
+      attention_output = jax.vmap(splash_kernel)(query, key, value, segment_ids=decoder_segment_ids_tuple)
+
+      return attention_output

-    x = wrap_flash_attention(query, key, value, decoder_segment_ids)
+    if cp_size > 1 and load_balanced_context_parallel:
+      x = wrap_flash_attention(query, key, value, decoder_segment_ids_permuted, decoder_segment_ids, splash_kernel)
+    else:
+      x = wrap_flash_attention(query, key, value, decoder_segment_ids, decoder_segment_ids, splash_kernel)
     x = jnp.transpose(x, axes=(0, 2, 1, 3))
+
+    if cp_size > 1 and load_balanced_context_parallel:
+      # Inverse reorder so the output is back in the original (contiguous) sequence order.
+      x = self.reorder_causal_load_balancing(tensor=x, cp_size=cp_size, seq_dim=1, to_contiguous=True)
     return x

+  @staticmethod
+  def reorder_mask_load_balancing(tensor, cp_size: int, seq_dim: int):
+    """Reorders a NumPy tensor for load balancing the compute of causal attention.
+
+    NumPy counterpart of `reorder_causal_load_balancing` (forward direction only);
+    used when building the permuted causal mask, which is constructed with NumPy arrays.
+    """
+    if tensor is None:
+      return tensor
+
+    if cp_size == 1:
+      return tensor
+
+    if cp_size % 2 != 0:
+      raise ValueError(f"{cp_size=} must be a multiple of 2.")
+
+    # Need to ensure we have 2 pairs to swap for balancing between cp ranks
+    if tensor.shape[seq_dim] % (cp_size * 2) != 0:
+      raise ValueError(f"{tensor.shape=} is not a multiple of {cp_size*2=}")
+
+    # [B, S, H, D] -> [B, 2*cp_size, S/(2*cp_size), H, D]
+    ori_tensor_shape = tensor.shape
+    tensor = np.reshape(
+        tensor,
+        (
+            *ori_tensor_shape[:seq_dim],
+            2 * cp_size,
+            ori_tensor_shape[seq_dim] // (2 * cp_size),
+            *ori_tensor_shape[seq_dim + 1 :],
+        ),
+    )
+
+    parts = []
+    for cp_rank in range(cp_size):
+      # Pair chunk cp_rank with its mirror chunk (2 * cp_size - cp_rank - 1).
+      index = np.array([cp_rank, (2 * cp_size - cp_rank - 1)])
+      parts.append(np.take(tensor, index, axis=seq_dim))
+
+    combined = np.stack(parts, axis=seq_dim)
+    return np.reshape(combined, ori_tensor_shape)
+
+  def reorder_causal_load_balancing(self, tensor, cp_size: int, seq_dim: int, to_contiguous: bool):
+    """Reorders a tensor for load balancing the compute of causal attention."""
+    if tensor is None:
+      return tensor
+
+    if cp_size == 1:
+      return tensor
+
+    if cp_size % 2 != 0:
+      raise ValueError(f"{cp_size=} must be a multiple of 2.")
+
+    # Need to ensure we have 2 pairs to swap for balancing between cp ranks
+    if tensor.shape[seq_dim] % (cp_size * 2) != 0:
+      raise ValueError(f"{tensor.shape=} is not a multiple of {cp_size*2=}")
+
+    # [B, S, H, D] -> [B, 2*cp_size, S/(2*cp_size), H, D]
+    ori_tensor_shape = tensor.shape
+    tensor = jnp.reshape(
+        tensor,
+        (
+            *ori_tensor_shape[:seq_dim],
+            2 * cp_size,
+            ori_tensor_shape[seq_dim] // (2 * cp_size),
+            *ori_tensor_shape[seq_dim + 1 :],
+        ),
+    )
+
+    parts = []
+    if not to_contiguous:
+      # Forward permutation: shard r receives chunk r and chunk 2 * cp_size - r - 1.
+      for cp_rank in range(cp_size):
+        index = jnp.array([cp_rank, (2 * cp_size - cp_rank - 1)])
+        parts.append(jnp.take(tensor, index, axis=seq_dim))
+    else:
+      # Inverse permutation: gather the balanced chunks back into contiguous order.
+      for cp_rank in range(cp_size // 2):
+        base = 4 * cp_rank
+        index = jnp.array([base, base + 2])
+        parts.append(jnp.take(tensor, index, axis=seq_dim))
+      for cp_rank in range(cp_size // 2):
+        base = 2 * cp_size - 1 - 4 * cp_rank
+        index = jnp.array([base, base - 2])
+        parts.append(jnp.take(tensor, index, axis=seq_dim))
+
+    combined = jnp.stack(parts, axis=seq_dim)
+    return jnp.reshape(combined, ori_tensor_shape)
+
   def cudnn_flash_attention(
       self,
       query: Array,
@@ -1328,3 +1519,144 @@ def __call__(
     out = self.out_projection(inputs_q.shape[-1], out)
     out = checkpoint_name(out, "out_proj")
     return out
+
+
+# Short alias for functools.partial, used by the jax.jit decorator in tpu_flash_attention.
+partial = functools.partial
+
+
+class WrapperNpNDArray:
+  """Hashable wrapper around a np.ndarray."""
+
+  np_ndarray: np.ndarray
+
+  def __init__(self, np_ndarray):
+    self.np_ndarray = np_ndarray
+
+  def __hash__(self):
+    return hash(
+        (
+            type(self),
+            self.np_ndarray.tobytes() if self.np_ndarray is not None else None,
+        )
+    )
+
+
+class LoadBalancedCausalMask(splash_attention_mask._ComputableMask):
+  """Lazy causal mask whose query rows are permuted for load-balanced context parallelism.
+
+  Prevents the model from attending to future tokens.
+
+  Attributes:
+    offset: Offset of q start wrt kv. A positive offset shifts the bottom
+      triangle upward, a negative one shifts it downward. A negative offset
+      makes the first 'offset' rows of the attention matrix all 0s which leads
+      to undefined softmax.
+  """
+
+  offset: int
+  shape: tuple[int, int]
+  cp_size: int
+
+  def __init__(self, shape: tuple[int, int], offset: int = 0, shard_count: int = 1, cp_size: int = 4):
+    self.offset = offset
+
+    def causal_mask_function(q_ids, kv_ids):
+      if self.offset == 0:
+        return q_ids >= kv_ids
+      else:
+        return q_ids + self.offset >= kv_ids
+
+    # Permute the query positions the same way the query tensor is permuted for load
+    # balancing, so the lazily evaluated mask lines up with the reordered queries.
+    arr = np.arange(shape[0])
+    out = AttentionOp.reorder_mask_load_balancing(arr[None, :, None, None], cp_size, seq_dim=1)
+    q_sequence = out[0, :, 0, 0]
+
+    mask_function = causal_mask_function
+
+    super().__init__(
+        shape=shape,
+        mask_function=mask_function,
+        shard_count=shard_count,
+    )
+    self.q_sequence = q_sequence
+
+  def __eq__(self, other: object):
+    if not isinstance(other, type(self)):
+      return NotImplemented
+
+    return self.shape == other.shape and self.offset == other.offset and np.array_equal(self.q_sequence, other.q_sequence)
+
+  def __hash__(self):
+    return hash(
+        (
+            type(self),
+            self.shape,
+            self.offset,
+            self.q_sequence.tobytes() if self.q_sequence is not None else None,
+        )
+    )
diff --git a/MaxText/maxtext_utils.py b/MaxText/maxtext_utils.py
index 64caf0613..46a850364 100644
--- a/MaxText/maxtext_utils.py
+++ b/MaxText/maxtext_utils.py
@@ -239,7 +239,7 @@ def assert_params_sufficiently_sharded(params, mesh, tolerance):
   """
   total_num_params = max_utils.calculate_num_params_from_pytree(params)
   product_num_devices_for_weight_sharding = 1
-  for axis in ["fsdp", "fsdp_transpose", "sequence", "tensor", "tensor_sequence", "stage", "expert"]:
+  for axis in ["fsdp", "fsdp_transpose", "sequence", "context", "tensor", "tensor_sequence", "stage", "expert"]:
     product_num_devices_for_weight_sharding *= mesh.shape[axis]
   total_num_params_per_chip = max_utils.calculate_total_params_per_chip(params)
   perfectly_sharded_params_per_chip = total_num_params / product_num_devices_for_weight_sharding
diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py
index fd0d17756..a26da6e7b 100644
--- a/MaxText/pyconfig.py
+++ b/MaxText/pyconfig.py
@@ -511,6 +511,7 @@ def create_parallelisms_list(raw_keys):
       raw_keys["ici_fsdp_parallelism"],
       raw_keys["ici_fsdp_transpose_parallelism"],
       raw_keys["ici_sequence_parallelism"],
+      raw_keys["ici_context_parallelism"],
       raw_keys["ici_tensor_parallelism"],
       raw_keys["ici_tensor_sequence_parallelism"],
       raw_keys["ici_expert_parallelism"],
@@ -522,6 +523,7 @@ def create_parallelisms_list(raw_keys):
       raw_keys["dcn_fsdp_parallelism"],
       raw_keys["dcn_fsdp_transpose_parallelism"],
       raw_keys["dcn_sequence_parallelism"],
+      raw_keys["dcn_context_parallelism"],
       raw_keys["dcn_tensor_parallelism"],
       raw_keys["dcn_tensor_sequence_parallelism"],
       raw_keys["dcn_expert_parallelism"],
diff --git a/benchmarks/benchmark_runner.py b/benchmarks/benchmark_runner.py
index b87fae7d5..ef501cded 100644
---
a/benchmarks/benchmark_runner.py +++ b/benchmarks/benchmark_runner.py @@ -29,6 +29,16 @@ from maxtext_xpk_runner import xpk_benchmark_runner from maxtext_xpk_runner import XpkClusterConfig from maxtext_xpk_runner import LibTpuType +import datetime + +def get_current_day_hour_minute(): + """Gets the current day, hour, and minute as a formatted string. + + Returns: + A string in the format 'YYYYMMDDHHMM'. + """ + now = datetime.datetime.now() + return now.strftime("%Y%m%d%H%M") def add_shared_arguments(custom_parser: argparse.ArgumentParser): """Add shared arguments to the parser. @@ -156,6 +166,8 @@ def main() -> None: add_shared_arguments(parser) options = parser.parse_args() + idx = get_current_day_hour_minute() + cluster_config = XpkClusterConfig( cluster_name=options.cluster_name, project=options.project, diff --git a/benchmarks/maxtext_trillium_model_configs.py b/benchmarks/maxtext_trillium_model_configs.py index 1e1acb654..b0c864e82 100644 --- a/benchmarks/maxtext_trillium_model_configs.py +++ b/benchmarks/maxtext_trillium_model_configs.py @@ -1062,3 +1062,172 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_ ), ) ) + +llama3_1_70b_131072 = _add_to_model_dictionary( + trillium_model_dict, + MaxTextModel( + model_name="llama3_1_70b_131072", + model_type="llama3.1-70b", + tuning_params={ + "per_device_batch_size": 0.0625, + "ici_fsdp_parallelism": -1, + "ici_context_parallelism": 16, + "remat_policy": "custom", + "decoder_layer_input": "offload", + "out_proj": "offload", + "query_proj": "offload", + "key_proj": "offload", + "value_proj": "offload", + "max_target_length": 131072, + "attention": "flash", + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + # "dataset_type": "synthetic", + "enable_checkpointing": False, + "sa_block_q": 2048, + "sa_block_kv": 2048, + "sa_block_kv_compute": 2048, + "sa_block_q_dkv": 2048, + "sa_block_kv_dkv": 2048, + "sa_block_kv_dkv_compute": 2048, + "sa_block_q_dq": 2048, + "sa_block_kv_dq": 2048, + "sa_use_fused_bwd_kernel": True, + "profiler": "xplane", + "skip_first_n_steps_for_profiler": 10, + "profiler_steps": 5, + # "allow_split_physical_axes": True, + # "custom_mesh": "hybrid_ring_32x8", + "tokenizer_path": "assets/tokenizer_llama3.tiktoken", + "steps": 30, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER + + xla_flags_library.DATA_PARALLEL_OVERLAP + + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_GATHER + + xla_flags_library.HOST_OFFLOAD_FLAGS + ), + ) +) + +llama3_1_70b_131072_1 = _add_to_model_dictionary( + trillium_model_dict, + MaxTextModel( + model_name="llama3_1_70b_131072_1", + model_type="llama3.1-70b", + tuning_params={ + "per_device_batch_size": 0.0625, + "ici_fsdp_parallelism": -1, + "ici_context_parallelism": 16, + "remat_policy": "custom", + "decoder_layer_input": "offload", + "out_proj": "offload", + "query_proj": "offload", + "key_proj": "offload", + "value_proj": "offload", + "max_target_length": 131072, + "attention": "flash", + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + # "dataset_type": "synthetic", + "enable_checkpointing": False, + "sa_block_q": 6144, + "sa_block_kv": 6144, + "sa_block_kv_compute": 6144, + "sa_block_q_dkv": 6144, + "sa_block_kv_dkv": 6144, + "sa_block_kv_dkv_compute": 6144, + "sa_block_q_dq": 6144, + "sa_block_kv_dq": 6144, + "sa_use_fused_bwd_kernel": True, + "profiler": "xplane", + "skip_first_n_steps_for_profiler": 10, + "profiler_steps": 5, + # 
"allow_split_physical_axes": True, + # "custom_mesh": "hybrid_ring_32x8", + "tokenizer_path": "assets/tokenizer_llama3.tiktoken", + "steps": 20, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER + + xla_flags_library.DATA_PARALLEL_OVERLAP + + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_GATHER + + xla_flags_library.HOST_OFFLOAD_FLAGS + ), + ) +) +llama3_1_70b_131072_2 = _add_to_model_dictionary( + trillium_model_dict, + MaxTextModel( + model_name="llama3_1_70b_131072_2", + model_type="llama3.1-70b", + tuning_params={ + "per_device_batch_size": 0.0625, + "ici_fsdp_parallelism": -1, + "ici_context_parallelism": 16, + "remat_policy": "custom", + "decoder_layer_input": "offload", + "out_proj": "offload", + "query_proj": "offload", + "key_proj": "offload", + "value_proj": "offload", + "max_target_length": 131072, + "attention": "flash", + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + # "dataset_type": "synthetic", + "enable_checkpointing": False, + "sa_block_q": 8192, + "sa_block_kv": 8192, + "sa_block_kv_compute": 8192, + "sa_block_q_dkv": 8192, + "sa_block_kv_dkv": 8192, + "sa_block_kv_dkv_compute": 8192, + "sa_block_q_dq": 8192, + "sa_block_kv_dq": 8192, + "sa_use_fused_bwd_kernel": True, + "profiler": "xplane", + "skip_first_n_steps_for_profiler": 10, + "profiler_steps": 5, + # "allow_split_physical_axes": True, + # "custom_mesh": "hybrid_ring_32x8", + "tokenizer_path": "assets/tokenizer_llama3.tiktoken", + "steps": 20, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER + + xla_flags_library.DATA_PARALLEL_OVERLAP + + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_GATHER + + xla_flags_library.HOST_OFFLOAD_FLAGS + ), + ) +) + +# TODO(b/368441022) LLAMA3.1 8B, 70B, 405B +# TODO(b/368441022) MaxDiffusion BEST +# TODO(b/368441022) Determine largest batch per slice for non-optimized models +# List of all models +maxstar_models = [ + default_basic_1, + default_32, + default_64, # Not Optimizied yet + default_128, # Not Optimizied yet + # default_256, # OOM, Not Optimizied yet + # default_512, # OOM, Not Optimizied yet + gpt_3_175b, + llama2_7b_4096, + llama2_70b_4096, + llama2_70b_4096_sc_real_data_tfds, + llama3_8b_8192, # Not Optimizied yet + llama3_70b_8192, # Not Optimizied yet + llama3_1_405b_8192_fsdp_dcn, + llama3_1_70b_129024, + mixtral_8x7b_dropped, + mixtral_8x7b_dropped_int8, + mixtral_8x7b_dropless, + gemma2_9b_8192, + gemma2_27b_8192, +] diff --git a/end_to_end/tpu/llama2/test_llama2_context_parallel.sh b/end_to_end/tpu/llama2/test_llama2_context_parallel.sh new file mode 100644 index 000000000..72ec8ed96 --- /dev/null +++ b/end_to_end/tpu/llama2/test_llama2_context_parallel.sh @@ -0,0 +1,5 @@ +#!/bin/bash + + +idx=$(date +%Y-%m-%d-%H-%M) +python3 MaxText/train.py MaxText/configs/base.yml ici_context_parallelism=-1 ici_fsdp_parallelism=1 enable_checkpointing=false base_output_directory=gs://mazumdera-test-bucket-us-east5/maxtext/seqpara/${idx} dataset_path=gs://max-datasets-rogue run_name=context_test enable_goodput_recording=false monitor_goodput=false per_device_batch_size=10 steps=30 profiler=xplane profiler_steps=20 max_target_length=65536 \ No newline at end of file