diff --git a/examples/run_text_generation_starcoder.sh b/examples/run_text_generation_starcoder.sh
new file mode 100644
index 0000000000..b2feb7bf68
--- /dev/null
+++ b/examples/run_text_generation_starcoder.sh
@@ -0,0 +1,43 @@
+#!/bin/bash
+# This example starts a text generation server for the StarCoder2 1B model.
+# You may need to adapt the Flask port in the MegatronServer class if it is occupied; we changed it from 5000 (default) to 8080.
+DISTRIBUTED_ARGS="--nproc_per_node 1 \
+                  --nnodes 1 \
+                  --node_rank 0 \
+                  --master_addr ip-26-0-156-228 \
+                  --master_port 6000"
+
+
+CHECKPOINT=/fsx/loubna/data/extra/generations_starcoder2_1b_200k/megatron
+
+#/mp_rank_00/model_optim_rng.pt
+VOCAB_FILE=/fsx/bigcode/experiments/pretraining/starcoder2-1B/checkpoints/conversions/vocab.json
+MERGE_FILE=/fsx/bigcode/experiments/pretraining/starcoder2-1B/checkpoints/conversions/merges.txt
+TOKENIZER_FILE=/fsx/loubna/data/tokenizer/starcoder2-smol-internal-1/tokenizer.json
+
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+
+# pip install flask-restful
+
+torchrun $DISTRIBUTED_ARGS tools/run_text_generation_server.py \
+    --tensor-model-parallel-size 1 \
+    --pipeline-model-parallel-size 1 \
+    --num-layers 24 \
+    --hidden-size 2048 \
+    --num-attention-heads 16 \
+    --attention-head-type multiquery \
+    --init-method-std 0.02209 \
+    --seq-length 4096 \
+    --use-rotary-position-embeddings \
+    --max-position-embeddings 4096 \
+    --rotary-theta 100000 \
+    --attention-dropout 0.1 \
+    --hidden-dropout 0.1 \
+    --load ${CHECKPOINT} \
+    --tokenizer-type TokenizerFromFile \
+    --tokenizer-file $TOKENIZER_FILE \
+    --bf16 \
+    --micro-batch-size 1 \
+    --out-seq-length 512 \
+    --seed 42
+    # --output_file
diff --git a/megatron/text_generation_server.py b/megatron/text_generation_server.py
index 58550f2e63..80dd0b288a 100644
--- a/megatron/text_generation_server.py
+++ b/megatron/text_generation_server.py
@@ -238,4 +238,4 @@ def __init__(self, model):
         api.add_resource(MegatronGenerate, '/api', resource_class_args=[model])
 
     def run(self, url):
-        self.app.run(url, threaded=True, debug=False)
+        self.app.run(url, threaded=True, debug=False, port=8080)
diff --git a/tools/run_requests_humaneval.py b/tools/run_requests_humaneval.py
new file mode 100644
index 0000000000..ff52c54cd9
--- /dev/null
+++ b/tools/run_requests_humaneval.py
@@ -0,0 +1,83 @@
+import requests
+import json
+from human_eval.data import write_jsonl, read_problems
+
+
+NUM_SAMPLES_PER_TASK = 1
+stop_tokens = ["\nclass", "\ndef", "\n#", "\n@", "\nprint", "\nif", "\n```", "<|endoftext|>"]
+
+
+def query_server(prompt, temperature=0.00001):
+    url = 'http://localhost:8080/api'
+    headers = {'Content-Type': 'application/json; charset=UTF-8'}
+    data = {"prompts": [prompt], "tokens_to_generate": 256, "temperature": temperature, "stop_token": 0, "random_seed": 1234}
+    response = requests.put(url, json=data, headers=headers)
+    result = json.loads(response.text)["text"]
+    return result[0]
+
+
+def stop_at_stop_token(decoded_string, stop_tokens):
+    """
+    Produces the prefix of decoded_string that ends at the first occurrence of
+    a stop_token.
+    WARNING: the decoded_string *must not* include the prompt, which may have stop tokens
+    itself.
+    """
+    min_stop_index = len(decoded_string)
+    for stop_token in stop_tokens:
+        stop_index = decoded_string.find(stop_token)
+        if stop_index != -1 and stop_index < min_stop_index:
+            min_stop_index = stop_index
+    return decoded_string[:min_stop_index]
+
+
+def postprocess_generation(generation, prompt):
+    """Defines the postprocessing for a LM generation.
+    :param generation: str
+        code generation from LM (includes the prompt)
+    :param prompt: str
+        the prompt the generation was produced from
+    """
+    if not generation.startswith(prompt[:20]):
+        print(f"Issue with generation: {generation}")
+        print(f"Original prompt: {prompt}")
+    generation = generation[len(prompt):]
+    return prompt + stop_at_stop_token(generation, stop_tokens)
+
+
+def main():
+    problems = read_problems()
+    prompts = [
+        problems[task_id]["prompt"]
+        for task_id in problems
+    ]
+    errors = []
+    success = 0
+    generations = []
+    postprocessed_generations = []
+    for i, prompt in enumerate(prompts):
+        prompt = prompt.strip()
+        try:
+            result = query_server(prompt)
+            generations.append([result])
+            postprocessed_generations.append([postprocess_generation(result, prompt)])
+            success += 1
+        except Exception as e:
+            print(f"Error processing problem '{i}': {e}")
+            errors.append(i)
+        if i % 10 == 0:
+            print(f"Processed {i} problems")
+            print(f"Failed problems so far: {errors}")
+            # print(f"Example:\n{result}END\n")
+
+    print(f"Done! {success} successful problems out of {len(prompts)}, failed problems: {errors}")
+
+    with open('megatron_generations_fixtemp_50.json', 'w') as f:
+        json.dump(generations, f)
+
+    with open('megatron_postprocessed_generations_fixtemp_50.json', 'w') as f:
+        json.dump(postprocessed_generations, f)
+
+
+if __name__ == '__main__':
+    main()
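A minimal smoke-test sketch, assuming the server from examples/run_text_generation_starcoder.sh is reachable at localhost:8080 as configured in this patch: it sends a single request to the /api endpoint before launching the full HumanEval sweep. The payload fields mirror those used in tools/run_requests_humaneval.py; the prompt string is an arbitrary example, not taken from the patch.

import json
import requests

# Single request against the patched Megatron text generation server.
# Assumption: the server runs locally on port 8080 (as set in this patch).
payload = {
    "prompts": ["def fibonacci(n):"],   # arbitrary example prompt
    "tokens_to_generate": 64,
    "temperature": 0.00001,
    "stop_token": 0,
    "random_seed": 1234,
}
response = requests.put(
    "http://localhost:8080/api",
    json=payload,
    headers={"Content-Type": "application/json; charset=UTF-8"},
)
response.raise_for_status()
# The server returns generations under the "text" key, as in run_requests_humaneval.py.
print(json.loads(response.text)["text"][0])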