diff --git a/examples/habana/gaudi_spawn.py b/examples/habana/gaudi_spawn.py new file mode 100644 index 00000000000..b7833c41773 --- /dev/null +++ b/examples/habana/gaudi_spawn.py @@ -0,0 +1,110 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +A simple launcher script for distributed training on HPUs. + +Single node: +:: + >>> python gaudi_spawn.py --world_size=NUM_CARDS_YOU_HAVE --use_mpi + YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other + arguments of your training script) + +Multi node: +:: + >>> python gaudi_spawn.py --hostfile=PATH_TO_HOSTFILE --use_deepspeed + YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other + arguments of your training script) +""" + +import sys +from argparse import REMAINDER, ArgumentParser + +from optimum.habana.distributed import DistributedRunner +from optimum.utils import logging + + +logger = logging.get_logger(__name__) + + +def parse_args(): + """ + Helper function parsing the command line options. + @retval ArgumentParser + """ + parser = ArgumentParser( + description=( + "Habana Gaudi distributed training launch helper utility that will spawn up multiple distributed" + " processes." + ) + ) + + # Optional arguments for the launch helper + parser.add_argument("--world_size", type=int, default=1, help="Number of HPUs to use (1 or 8)") + parser.add_argument("--hostfile", type=str, default=None, help="Path to the file where hosts are specified.") + parser.add_argument("--use_mpi", action="store_true", help="Use MPI for distributed training") + parser.add_argument("--use_deepspeed", action="store_true", help="Use DeepSpeed for distributed training") + parser.add_argument("--master_port", type=int, default=29500, help="Master port used by DeepSpeed and MPI") + + # positional + parser.add_argument( + "training_script", + type=str, + help=( + "The full path to the single HPU training " + "program/script to be launched in parallel, " + "followed by all the arguments for the " + "training script." + ), + ) + + # rest from the training program + parser.add_argument("training_script_args", nargs=REMAINDER) + + return parser.parse_args() + + +def main(): + args = parse_args() + + if args.use_deepspeed: + from transformers.integrations.deepspeed import is_deepspeed_available + + if not is_deepspeed_available(): + raise ImportError( + "--use_deepspeed requires deepspeed: `pip install" + " git+https://github.com/HabanaAI/DeepSpeed.git@1.15.0`." + ) + + # Patch sys.argv + sys.argv = [args.training_script] + args.training_script_args + # Handle the case where arguments contain whitespaces + argv = ['"{}"'.format(arg) if " " in arg and arg[0] != '"' and arg[-1] != '"' else arg for arg in sys.argv] + command_list = [" ".join(argv)] + + distributed_runner = DistributedRunner( + command_list=command_list, + world_size=args.world_size, + hostfile=args.hostfile, + use_mpi=args.use_mpi, + use_deepspeed=args.use_deepspeed, + master_port=args.master_port, + ) + + ret_code = distributed_runner.run() + sys.exit(ret_code) + + +if __name__ == "__main__": + main() diff --git a/examples/habana/run_measure.sh b/examples/habana/run_measure.sh index b585931112e..ecd14496a47 100644 --- a/examples/habana/run_measure.sh +++ b/examples/habana/run_measure.sh @@ -1,14 +1,15 @@ for i in {1..1..2} do - python run_generation.py \ + python gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \ --use_hpu_graphs \ --use_kv_cache \ + --model_name_or_path /chenxi/models--meta-llama--Llama-2-7b-hf/snapshots/01c7f73d771dfac7d292323805ebc428287df4f9/ \ --size $i \ --trim_logits \ --batch_size 1 \ --bf16 \ - --model_name_or_path /chenxi/models--01-ai--Yi-34B/snapshots/f9cec17e8fcc054d6c8d98fd5a41ed14895caa8b \ --prompt "It is done, and submitted. You can play 'Survival of the Tastiest' on the Android, and on the web. Playing on the web works, but you have to simulate multiple touch for table moving and that can be a bit confusing. There is a lot I'd like to talk about. I will go through every topic, instead of making the typical what went right/wrong list. Concept Working over the theme was probably one of the hardest tasks which I had to face. Originally, I had an idea of what kind of game I wanted to develop, gameplay wise - something with a lot of enemies/actors, simple graphics, maybe set in the space, controlled from a top-down view. I was confident that I could fit any theme around it. In the end, the problem with a theme like 'Evolution' in a game is that evolution is unassisted. It happens through several seemingly random mutations over time, with the most apt permutation surviving. This genetic car simulator is, in my opinion, a great example of actual evolution of a species facing a challenge. But is it a game? In a game, you need to control something to reach an objective. That control goes against what evolution is supposed to be like. If you allow the user to pick how to evolve something, it's not evolution anymore - it's the equivalent of intelligent design, the fable invented by creationists to combat the idea of evolution. Being agnostic and a Pastafarian, that's not something that rubbed me the right way. Hence, my biggest dilemma when deciding what to create was not with what I wanted to create, but with what I did not. I didn't want to create an 'intelligent design' simulator and wrongly call it evolution. This is a problem, of course, every other contestant also had to face it. And judging by the entries submitted, not many managed to work around it. I'd say the only real solution was through the use of artificial selection, somehow. So far, I haven't seen any entry using this at its core gameplay. Alas, this is just a fun competition and after a while I decided not to be as strict with the game idea, and allowed myself to pick whatever I thought would work out. My initial idea was to create something where humanity tried to evolve to a next level, but had some kind of foe trying to stop them from doing so. I kind of had this image of human souls flying in space towards a monolith or a space baby (all based in 2001: A Space Odyssey of course) but I couldn't think of compelling (read: serious) mechanics for that. Borgs were my next inspiration, as their whole hypothesis fit pretty well into the evolution theme. But how to make it work? Are you the borg, or fighting the Borg? The third and final idea came to me through my girlfriend, who somehow gave me the idea of making something about the evolution of Pasta. The more I thought about it the more it sounded like it would work, so I decided to go with it. Conversations with my inspiring co-worker Roushey (who also created the 'Mechanical Underdogs' signature logo for my intros) further matured the concept, as it involved into the idea of having individual pieces of pasta flying around and trying to evolve until they became all-powerful. A secondary idea here was that the game would work to explain how the Flying Spaghetti Monster came to exist - by evolving from a normal dinner table. So the idea evolved more or less into this: you are sitting a table. You have your own plate, with is your 'base'. There are 5 other guests at the table, each with their own plate. Your plate can spawn little pieces of pasta. You do so by 'ordering' them through a menu. Some pastas are better than others; some are faster, some are stronger. They have varying 'costs', which are debited from your credits (you start with a number of credits). Once spawned, your pastas start flying around. Their instinct is to fly to other plates, in order to conquer them (the objective of the game is having your pasta conquer all the plates on the table). But they are really autonomous, so after being spawned, you have no control over your pasta (think DotA or LoL creeps). Your pasta doesn't like other people's pasta, so if they meet, they shoot sauce at each other until one dies. You get credits for other pastas your own pasta kill." \ + # --model_name_or_path /chenxi/models--01-ai--Yi-34B/snapshots/51a24adb588163efeefde6cb452feef8a677cdae \ sleep 1 done echo "Test Done...." diff --git a/examples/habana/run_tp.sh b/examples/habana/run_tp.sh new file mode 100644 index 00000000000..a48827eed8f --- /dev/null +++ b/examples/habana/run_tp.sh @@ -0,0 +1,19 @@ +for i in {10..10..2} +do + python gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \ + --use_hpu_graphs \ + --use_kv_cache \ + --limit_hpu_graphs \ + --size $i \ + --batch_size 1 \ + --model_name_or_path /chenxi/models--meta-llama--Llama-2-7b-hf/snapshots/01c7f73d771dfac7d292323805ebc428287df4f9/ \ + --trim_logits \ + --fp8 \ + --max_input_tokens -1 \ + --bf16 \ + --prompt "It is done, and submitted. You can play 'Survival of the Tastiest' on the Android, and on the web. Playing on the web works, but you have to simulate multiple touch for table moving and that can be a bit confusing. There is a lot I'd like to talk about. I will go through every topic, instead of making the typical what went right/wrong list. Concept Working over the theme was probably one of the hardest tasks which I had to face. Originally, I had an idea of what kind of game I wanted to develop, gameplay wise - something with a lot of enemies/actors, simple graphics, maybe set in the space, controlled from a top-down view. I was confident that I could fit any theme around it. In the end, the problem with a theme like 'Evolution' in a game is that evolution is unassisted. It happens through several seemingly random mutations over time, with the most apt permutation surviving. This genetic car simulator is, in my opinion, a great example of actual evolution of a species facing a challenge. But is it a game? In a game, you need to control something to reach an objective. That control goes against what evolution is supposed to be like. If you allow the user to pick how to evolve something, it's not evolution anymore - it's the equivalent of intelligent design, the fable invented by creationists to combat the idea of evolution. Being agnostic and a Pastafarian, that's not something that rubbed me the right way. Hence, my biggest dilemma when deciding what to create was not with what I wanted to create, but with what I did not. I didn't want to create an 'intelligent design' simulator and wrongly call it evolution. This is a problem, of course, every other contestant also had to face it. And judging by the entries submitted, not many managed to work around it. I'd say the only real solution was through the use of artificial selection, somehow. So far, I haven't seen any entry using this at its core gameplay. Alas, this is just a fun competition and after a while I decided not to be as strict with the game idea, and allowed myself to pick whatever I thought would work out. My initial idea was to create something where humanity tried to evolve to a next level, but had some kind of foe trying to stop them from doing so. I kind of had this image of human souls flying in space towards a monolith or a space baby (all based in 2001: A Space Odyssey of course) but I couldn't think of compelling (read: serious) mechanics for that. Borgs were my next inspiration, as their whole hypothesis fit pretty well into the evolution theme. But how to make it work? Are you the borg, or fighting the Borg? The third and final idea came to me through my girlfriend, who somehow gave me the idea of making something about the evolution of Pasta. The more I thought about it the more it sounded like it would work, so I decided to go with it. Conversations with my inspiring co-worker Roushey (who also created the 'Mechanical Underdogs' signature logo for my intros) further matured the concept, as it involved into the idea of having individual pieces of pasta flying around and trying to evolve until they became all-powerful. A secondary idea here was that the game would work to explain how the Flying Spaghetti Monster came to exist - by evolving from a normal dinner table. So the idea evolved more or less into this: you are sitting a table. You have your own plate, with is your 'base'. There are 5 other guests at the table, each with their own plate. Your plate can spawn little pieces of pasta. You do so by 'ordering' them through a menu. Some pastas are better than others; some are faster, some are stronger. They have varying 'costs', which are debited from your credits (you start with a number of credits). Once spawned, your pastas start flying around. Their instinct is to fly to other plates, in order to conquer them (the objective of the game is having your pasta conquer all the plates on the table). But they are really autonomous, so after being spawned, you have no control over your pasta (think DotA or LoL creeps). Your pasta doesn't like other people's pasta, so if they meet, they shoot sauce at each other until one dies. You get credits for other pastas your own pasta kill." \ + # --prompt "how are you ?" \ + # --model_name_or_path /chenxi/models--01-ai--Yi-34B/snapshots/51a24adb588163efeefde6cb452feef8a677cdae \ + sleep 1 +done +echo "Test Done...." diff --git a/examples/habana/utils.py b/examples/habana/utils.py index 1e3bfffab6e..69edbf51aeb 100644 --- a/examples/habana/utils.py +++ b/examples/habana/utils.py @@ -179,14 +179,24 @@ def get_torch_compiled_model(model): def setup_model(args, model_dtype, model_kwargs): - model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype, **model_kwargs) + + config = AutoConfig.from_pretrained( + args.model_name_or_path, + torch_dtype=model_dtype, + **model_kwargs) + # config.max_position_embeddings = max(config.max_position_embeddings, 20000) + config.tensor_split = False + model = AutoModelForCausalLM.from_pretrained( + args.model_name_or_path, + config=config, + torch_dtype=model_dtype, + **model_kwargs) if args.quant_config: import habana_quantization_toolkit habana_quantization_toolkit.prep_model(model) model = model.eval() - # import pdb; pdb.set_trace() model = model.to("hpu") if args.use_hpu_graphs: @@ -208,6 +218,7 @@ def setup_distributed_model(args, model_dtype, model_kwargs): deepspeed.init_distributed(dist_backend="hccl") config = AutoConfig.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype, **model_kwargs) + config.tensor_split = False load_to_meta = model_on_meta(config) if load_to_meta: @@ -219,29 +230,27 @@ def setup_distributed_model(args, model_dtype, model_kwargs): checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w") # For PEFT models, write the merged model on disk to be able to load it on the meta device - if args.peft_model is not None: - merged_model_dir = "/tmp/text_generation_merged_peft_model" - if args.local_rank == 0: - if Path(merged_model_dir).is_dir(): - shutil.rmtree(merged_model_dir) - peft_model(args, model_dtype, **model_kwargs).save_pretrained(merged_model_dir) - torch.distributed.barrier() + # if args.peft_model is not None: + # merged_model_dir = "/tmp/text_generation_merged_peft_model" + # if args.local_rank == 0: + # if Path(merged_model_dir).is_dir(): + # shutil.rmtree(merged_model_dir) + # peft_model(args, model_dtype, **model_kwargs).save_pretrained(merged_model_dir) + # torch.distributed.barrier() write_checkpoints_json( - merged_model_dir if args.peft_model is not None else args.model_name_or_path, + args.model_name_or_path, + # merged_model_dir if args.peft_model is not None else args.model_name_or_path, args.local_rank, checkpoints_json, - token=args.token, + token=None, ) else: # TODO: revisit placement on CPU when auto-injection is possible with deepspeed.OnDevice(dtype=model_dtype, device="cpu"): - if args.peft_model is not None: - model = peft_model(args, model_dtype, **model_kwargs) - else: - model = AutoModelForCausalLM.from_pretrained( - args.model_name_or_path, torch_dtype=model_dtype, **model_kwargs - ) + model = AutoModelForCausalLM.from_pretrained( + args.model_name_or_path, torch_dtype=model_dtype, **model_kwargs + ) model.eval() # Initialize the model diff --git a/intel_extension_for_transformers/transformers/modeling/modeling_gaudi/models/llama/modeling_llama.py b/intel_extension_for_transformers/transformers/modeling/modeling_gaudi/models/llama/modeling_llama.py index 11dcf4e6d34..0f9fa02f72f 100755 --- a/intel_extension_for_transformers/transformers/modeling/modeling_gaudi/models/llama/modeling_llama.py +++ b/intel_extension_for_transformers/transformers/modeling/modeling_gaudi/models/llama/modeling_llama.py @@ -207,6 +207,41 @@ def get_shape(self): def forward(self, cur, dim, idx): return self.update(self.cache, cur, dim, idx, self.inp_seq_len) +class KVCacheCPU(KVCache): + def __init__(self): + super(KVCacheCPU, self).__init__() + + def allocate(self, inp_seq_len, dtype, device, shape): + if self.cache is None or self.cache.shape != shape: + self.inp_seq_len = inp_seq_len + cpu_dtype = dtype + # TODO allocate int8 as cpu has no float8_e4m3fn + if dtype == torch.float8_e4m3fn: + cpu_dtype = torch.int8 + self.cache = torch.zeros(shape, dtype=cpu_dtype, device="cpu") + else: + assert ( + self.inp_seq_len == inp_seq_len + ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" + self.cache.fill_(0) + + def update(self, prev, cur, dim, idx, inp_seq_len): + orig_cur = cur + cpu_cur = cur.to('cpu') + cpu_idx = idx.to('cpu') + if prev.shape == cur.shape: + prev.copy_(cpu_cur) + return orig_cur + if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: + # Initialize + prev[:, :, :inp_seq_len, :].copy_(cpu_cur) + return orig_cur + assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" + if idx is not None: + prev.index_copy_(dim, cpu_idx - 1, cpu_cur) + return prev + else: + return torch.cat((prev, cpu_cur), dim=dim) class GaudiLlamaRotaryEmbedding(torch.nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): @@ -447,24 +482,50 @@ def pre_attn_forward( query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv( query_states, key_states, value_states, attention_mask, self.num_key_value_groups ) + if self.config.tensor_split: + q_slices = query_states.split(1, dim=1) + k_slices = key_states.transpose(-2, -1).split(1, dim=1) + v_slices = value_states.split(1, dim=1) + attn_outputs = [] + for idx in range(query_states.shape[1]): + attn_weights = self.matmul_qk(q_slices[idx], k_slices[idx]) * self.norm_factor + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask + if cache_position is not None: + causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + if attn_softmax_bf16: + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype) + else: + # upcast attention to fp32 + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query_states.dtype + ) + attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_out = self.matmul_av(attn_weights, v_slices[idx]) + attn_outputs.append(attn_out) + attn_output = torch.cat(attn_outputs, dim=1) + else: + attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) * self.norm_factor - attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) * self.norm_factor + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask + if cache_position is not None: + causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask - if cache_position is not None: - causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask + if attn_softmax_bf16: + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype) + else: + # upcast attention to fp32 + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query_states.dtype + ) + attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = self.matmul_av(attn_weights, value_states) - if attn_softmax_bf16: - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype) - else: - # upcast attention to fp32 - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( - query_states.dtype - ) - attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = self.matmul_av(attn_weights, value_states) attn_output = attn_output.reshape(bsz, -1, q_len, self.head_dim) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):