
[docs] torch.compile usage guide #11078

Open
youkaichao wants to merge 28 commits into main

Conversation

youkaichao
Member

No description provided.

Signed-off-by: youkaichao <[email protected]>

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

@mergify mergify bot added the documentation (Improvements or additions to documentation) label on Dec 10, 2024
Comment on lines +114 to +132
$ # running an 8B model on H100 with batch size 1, 36.39 seconds of compilation time, 7.7% improvement in latency

$ python3 benchmarks/benchmark_latency.py --model meta-llama/Meta-Llama-3-8B --batch-size 1 --load-format dummy
init engine (profile, create kv cache, warmup model) took 11.79 seconds
Avg latency: 0.9704469823899369 seconds

$ python3 benchmarks/benchmark_latency.py --model meta-llama/Meta-Llama-3-8B --batch-size 1 --load-format dummy -O "{'level': 3, 'candidate_compile_sizes': [1]}"
init engine (profile, create kv cache, warmup model) took 48.18 seconds
Avg latency: 0.8950413154981409 seconds

$ # running an 8B model on L4 with batch size 1, 66.54 seconds of compilation time, 4.1% improvement in latency

$ python3 benchmarks/benchmark_latency.py --model meta-llama/Meta-Llama-3-8B --batch-size 1 --load-format dummy
init engine (profile, create kv cache, warmup model) took 20.63 seconds
Avg latency: 7.81603614680001 seconds

$ python3 benchmarks/benchmark_latency.py --model meta-llama/Meta-Llama-3-8B --batch-size 1 --load-format dummy -O "{'level': 3, 'candidate_compile_sizes': [1]}"
init engine (profile, create kv cache, warmup model) took 87.17 seconds
Avg latency: 7.495755991366673 seconds
youkaichao (Member Author)

I also ran it with the Llama 3.2 1B model:

    $ # running a 1B model on H100 with batch size 1, 21.29 seconds of compilation time, 13.7% improvement in latency

    $ python3 benchmarks/benchmark_latency.py --model meta-llama/Llama-3.2-1B --batch-size 1 --load-format dummy --num-scheduler-steps 16
    init engine (profile, create kv cache, warmup model) took 11.79 seconds
    Avg latency: 0.2771991847005362 seconds

    $ python3 benchmarks/benchmark_latency.py --model meta-llama/Llama-3.2-1B --batch-size 1 --load-format dummy --num-scheduler-steps 16 -O "{'level': 3, 'candidate_compile_sizes': [1]}"
    init engine (profile, create kv cache, warmup model) took 33.08 seconds
    Avg latency: 0.23920089063079406 seconds

    $ # running a 1B model on L4 with batch size 1, 42.0 seconds of compilation time, 4.0% improvement in latency

    $ python3 benchmarks/benchmark_latency.py --model meta-llama/Llama-3.2-1B --batch-size 1 --load-format dummy --num-scheduler-steps 16
    init engine (profile, create kv cache, warmup model) took 20.32 seconds
    Avg latency: 1.526933370166671 seconds

    $ python3 benchmarks/benchmark_latency.py --model meta-llama/Llama-3.2-1B --batch-size 1 --load-format dummy --num-scheduler-steps 16 -O "{'level': 3, 'candidate_compile_sizes': [1]}"
    init engine (profile, create kv cache, warmup model) took 62.32 seconds
    Avg latency: 1.4660025673666648 seconds
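
For offline inference through the Python API, the same compilation options can be passed programmatically instead of via the ``-O`` flag. A minimal sketch, assuming the ``LLM`` constructor accepts a ``compilation_config`` argument taking the same dict that ``-O`` accepts (verify against the API of your vLLM version):

    from vllm import LLM, SamplingParams

    # Assumption: compilation_config mirrors the dict passed to -O above;
    # check the exact parameter name and accepted types for your vLLM version.
    llm = LLM(
        model="meta-llama/Meta-Llama-3-8B",
        compilation_config={"level": 3, "candidate_compile_sizes": [1]},
    )

    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)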

Comment on lines +74 to +75
- **Inductor graph compilation**: Time taken for the inductor to compile the computation graph into Triton kernels. It includes compilation for a general shape and specific shapes. Check the logs for ``Compiling a graph for general shape takes 14.77 s`` and ``Compiling a graph for shape 1 takes 13.52 s``.
- **Triton kernel compilation**: Time taken for Triton to compile the Triton kernels into GPU kernels. No specific logs are available for this part.
Collaborator

The inductor graph compilation time is inclusive of the triton kernel compilation time, right? Could you clarify that if so?

youkaichao (Member Author)

I don't think so. Some triton kernels might have JIT compilation that is not included in the inductor compilation time.
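
A plain-PyTorch sketch (not vLLM-specific) of why the two numbers can differ: the first call of a compiled function pays for Dynamo tracing, Inductor graph compilation, and any lazy Triton JIT compilation, while warm calls reuse the cached kernels. Shapes here are arbitrary and a CUDA device is assumed:

    import time

    import torch

    def f(x):
        return torch.relu(x) * 2 + 1

    compiled_f = torch.compile(f)  # default Inductor backend
    x = torch.randn(1024, 1024, device="cuda")

    # Cold call: Dynamo tracing + Inductor graph compilation
    # + JIT compilation of the generated Triton kernels.
    t0 = time.perf_counter()
    compiled_f(x)
    torch.cuda.synchronize()
    print(f"cold call (includes compilation): {time.perf_counter() - t0:.2f} s")

    # Warm call with the same shape: compiled artifacts are reused.
    t0 = time.perf_counter()
    compiled_f(x)
    torch.cuda.synchronize()
    print(f"warm call: {time.perf_counter() - t0:.4f} s")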

@tlrmchlsmth (Collaborator) left a comment

Should we document how to enable/disable custom inductor passes?

Comment on lines 138 to 152
$ # running an 8B model on H100 with various batch sizes, 72.76 seconds of compilation time, 3.9% improvement in throughput
$
$ # 1. Run the baseline setting
$ python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 --model meta-llama/Meta-Llama-3-8B --load-format dummy --num-scheduler-steps 64
init engine (profile, create kv cache, warmup model) took 14.42 seconds
Throughput: 44.39 requests/s, 22728.17 total tokens/s, 11364.08 output tokens/s

$ # 2. Run the same setting with profiling
$ VLLM_LOG_BATCHSIZE_INTERVAL=1.0 python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 --model meta-llama/Meta-Llama-3-8B --num-scheduler-steps 64
INFO 12-10 15:42:47 forward_context.py:58] Batchsize distribution (batchsize, count): [(256, 769), (232, 215), ...]

$ # 3. The most common batch sizes are 256 and 232, so we can compile the model for these two batch sizes
$ python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 --model meta-llama/Meta-Llama-3-8B --num-scheduler-steps 64 -O "{'level': 3, 'candidate_compile_sizes': [232, 256]}"
init engine (profile, create kv cache, warmup model) took 87.18 seconds
Throughput: 46.11 requests/s, 23606.51 total tokens/s, 11803.26 output tokens/s
youkaichao (Member Author)

Repeating for Llama 3.2 1B on H100, there is almost no improvement:

    $ # running a 1B model on H100 with various batch sizes, 39.79 seconds of compilation time, 0.5% improvement in throughput
    $
    $ # 1. Run the baseline setting
    $ python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 --model meta-llama/Llama-3.2-1B --load-format dummy --num-scheduler-steps 64
    init engine (profile, create kv cache, warmup model) took 13.14 seconds
    Throughput: 116.83 requests/s, 59814.48 total tokens/s, 29907.24 output tokens/s

    $ # 2. Run the same setting with profiling
    $ VLLM_LOG_BATCHSIZE_INTERVAL=1.0 python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 --model meta-llama/Llama-3.2-1B --num-scheduler-steps 64
    INFO 12-10 15:42:47 forward_context.py:58] Batchsize distribution (batchsize, count): [(256, 769), (232, 215), ...]

    $ # 3. The most common batch sizes are 256 and 232, so we can compile the model for these two batch sizes
    $ python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 --model meta-llama/Llama-3.2-1B --num-scheduler-steps 64 -O "{'level': 3, 'candidate_compile_sizes': [232, 256]}"
    init engine (profile, create kv cache, warmup model) took 52.93 seconds
    Throughput: 117.38 requests/s, 60100.50 total tokens/s, 30050.25 output tokens/s

Repeating for Llama 3.2 1B on L4 (here compilation actually makes it slower):

    $ # running a 1B model on L4 with various batch sizes, 58.5 seconds of compilation time, 6.0% regression in throughput
    $
    $ # 1. Run the baseline setting
    $ python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 --model meta-llama/Llama-3.2-1B --load-format dummy --num-scheduler-steps 64
    init engine (profile, create kv cache, warmup model) took 21.77 seconds
    Throughput: 16.36 requests/s, 8376.21 total tokens/s, 4188.10 output tokens/s

    $ # 2. Run the same setting with profiling
    $ VLLM_LOG_BATCHSIZE_INTERVAL=1.0 python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 --model meta-llama/Llama-3.2-1B --load-format dummy --num-scheduler-steps 64
    INFO 12-10 15:42:47 forward_context.py:58] Batchsize distribution (batchsize, count): [(256, 769), (232, 215), ...]

    $ # 3. The most common batch sizes are 256 and 232, so we can compile the model for these two batch sizes
    $ python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 --model meta-llama/Llama-3.2-1B --load-format dummy --num-scheduler-steps 64 -O "{'level': 3, 'candidate_compile_sizes': [232, 256]}"
    init engine (profile, create kv cache, warmup model) took 80.27 seconds
    Throughput: 15.38 requests/s, 7873.07 total tokens/s, 3936.54 output tokens/s
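
To turn the logged batch-size distribution into ``candidate_compile_sizes``, a small helper can pick the most frequent sizes. This is only a sketch (the function name and the top-k choice are made up, not part of vLLM); the log format follows the ``VLLM_LOG_BATCHSIZE_INTERVAL`` output shown above:

    import ast
    import re

    def top_batch_sizes(log_line: str, k: int = 2) -> list[int]:
        """Pick the k most frequent batch sizes from a vLLM
        'Batchsize distribution' log line (hypothetical helper)."""
        m = re.search(r"Batchsize distribution \(batchsize, count\): (\[.*\])", log_line)
        if m is None:
            return []
        pairs = ast.literal_eval(m.group(1))  # e.g. [(256, 769), (232, 215)]
        pairs.sort(key=lambda p: p[1], reverse=True)
        return sorted(size for size, _ in pairs[:k])

    line = ("INFO 12-10 15:42:47 forward_context.py:58] Batchsize distribution "
            "(batchsize, count): [(256, 769), (232, 215)]")
    print(top_batch_sizes(line))  # [232, 256] -> use as candidate_compile_sizes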

@youkaichao (Member Author)

Should we document how to enable/disable custom inductor passes?

Not right now, not until we have some passes with a significant perf gain.

Signed-off-by: youkaichao <[email protected]>
@DarkLight1337 (Member)

Can you update the Compatibility Matrix with this feature as well?

@youkaichao (Member Author) commented Dec 11, 2024

Can you update the Compatibility Matrix with this feature as well?

It's quite complicated, especially for vision-language models. Maybe I can just list all the models that do not support torch.compile in the future. The criteria are simple:

  • For a text-only model, if the modeling file has support_torch_compile, then torch.compile is supported. If the superclass supports it, the subclass supports it as well, e.g. class AriaMoELMModel(LlamaModel). (A rough way to check the first condition is sketched below.)
  • For a multi-modal model, if the text part supports torch.compile, then the model is supported as well.

I think mainly cross-attention models do not support torch.compile right now.
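
As a rough way to apply the first criterion, one can check whether a modeling file mentions the decorator. A sketch under the assumption that vLLM is installed; note that a plain source check like this will not catch support inherited from a decorated superclass, as in the AriaMoELMModel example above:

    import importlib
    import inspect

    def mentions_support_torch_compile(module_name: str) -> bool:
        """Return True if the modeling file's source contains the
        support_torch_compile decorator (a crude, illustrative check)."""
        module = importlib.import_module(module_name)
        return "support_torch_compile" in inspect.getsource(module)

    # Example: the Llama modeling file decorates its model class.
    print(mentions_support_torch_compile("vllm.model_executor.models.llama"))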

Signed-off-by: youkaichao <[email protected]>
Comment on lines +170 to +178
The following models are currently not supported by ``torch.compile``, because their computation graphs are too dynamic to compile:

- ``InternLM2VEForCausalLM``, ``InternVLChatModel``
- cross-attention models like ``MllamaForConditionalGeneration`` and ``BartForConditionalGeneration``

The following models should be supported by ``torch.compile`` in the future, but not supported yet due to bandwidth limitations:

- ``Mamba`` related models
- ``ChameleonModel``, ``ChatGLMModel``, ``DbrxModel``, ``DeepseekModel``, ``MixtralModel``, ``Olmo2Model``, ``Phi3SmallModel``, ``StableLMEpochModel``
youkaichao (Member Author)

@DarkLight1337 I checked all the models, and the above should be the complete list of unsupported models. I will also update the add-models page shortly to show how to add support for new models.


To effectively use ``torch.compile``, the TL;DR is:

- Ensure GPUs are busy executing the model before enabling ``torch.compile``.
Contributor

Can you clarify this statement? I think I know what it is trying to say but I think it is a bit ambiguous.

youkaichao (Member Author)

do you have better ideas?

Contributor

Maybe something like "torch.compile works best for models that are GPU bound"?
