diff --git a/community/README.md b/community/README.md index 5a074cc2..777e3322 100644 --- a/community/README.md +++ b/community/README.md @@ -86,4 +86,8 @@ Community examples are sample code and deployments for RAG pipelines that are no * [Chat with LLM Llama 3.1 Nemotron Nano 4B](./chat-llama-nemotron/) - This is a React-based conversational UI designed for interacting with a powerful local LLM. It incorporates RAG to enhance contextual understanding and is backed by an NVIDIA Dynamo inference server running the NVIDIA Llama-3.1-Nemotron-Nano-4B-v1.1 model. The setup enables low-latency, cloud-free AI assistant capabilities, with live document search and reasoning, all deployable on local or edge infrastructure. \ No newline at end of file + This is a React-based conversational UI designed for interacting with a powerful local LLM. It incorporates RAG to enhance contextual understanding and is backed by an NVIDIA Dynamo inference server running the NVIDIA Llama-3.1-Nemotron-Nano-4B-v1.1 model. The setup enables low-latency, cloud-free AI assistant capabilities, with live document search and reasoning, all deployable on local or edge infrastructure. + +* [LLM Inference Series: Performance, Optimization & Deployment with LLMs](llm-inference-series) + +This repository supports a video + notebook series exploring how to run, optimize, and serve Large Language Models (LLMs) with a focus on latency, throughput, user experience (UX), and NVIDIA GPU acceleration. \ No newline at end of file diff --git a/community/llm-inference-series/.gitignore b/community/llm-inference-series/.gitignore new file mode 100644 index 00000000..afae3e0a --- /dev/null +++ b/community/llm-inference-series/.gitignore @@ -0,0 +1,35 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +env/ +venv/ +.venv/ + +# Jupyter +.ipynb_checkpoints/ + +# Data files +*.csv +*.json +*.pkl +*.h5 +*.hdf5 + +# OS +.DS_Store +Thumbs.db + +# IDE +.vscode/ +.idea/ +*.swp +*.swo + +# Logs +*.log + +# Project +01_inference_101/batch_benchmark.csv \ No newline at end of file diff --git a/community/llm-inference-series/01_inference_101/inference_101.ipynb b/community/llm-inference-series/01_inference_101/inference_101.ipynb new file mode 100755 index 00000000..6cabfc21 --- /dev/null +++ b/community/llm-inference-series/01_inference_101/inference_101.ipynb @@ -0,0 +1,957 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "1460d3ac-17e9-4bba-831b-957b0882656d", + "metadata": {}, + "source": [ + "# Inference 101: Latency, Throughput & UX" + ] + }, + { + "cell_type": "markdown", + "id": "087182c3-ce37-4540-8e5c-1d92da128a07", + "metadata": {}, + "source": [ + "This notebook is a hands-on introduction to the key performance aspects of running inference with Large Language Models (LLMs).\n", + "\n", + "You’ll learn:\n", + "\n", + "✅ What **latency** and **throughput** mean in the context of LLMs \n", + "✅ Why these metrics often trade off against each other \n", + "✅ How different parameters (like batch size, prompt length, sampling strategy) affect performance \n", + "✅ How to measure and visualize **p50 vs p99 latency**, **first-token vs total latency**, and find the \"sweet spot\" \n", + "✅ What this means for real-world **user experience**\n", + "\n", + "We'll use **TensorRT-LLM** with **PyTorch** backend and a **Hugging Face**-hosted model (**Mistral 7B Instruct**) for all experiments.\n", + "\n", + "By the end of this notebook, you'll not only be able to benchmark an LLM — you'll know what the numbers actually mean and how 
to tune them for real applications." + ] + }, + { + "cell_type": "markdown", + "id": "0a2f5bb3-4418-4e5a-8fdf-3a533a58778c", + "metadata": {}, + "source": [ + "## Preliminaries" + ] + }, + { + "cell_type": "markdown", + "id": "56aa9fbb-0528-4e22-aff0-0f2445b88fa0", + "metadata": {}, + "source": [ + "**Before you begin**, make sure you have:\n", + "\n", + "- An NVIDIA GPU environment\n", + "- Access to the gated [Mistral-7B-Instruct-v0.3](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3) model on Hugging Face\n", + "- Your Hugging Face access [token](https://huggingface.co/settings/token)\n", + "\n", + "Let's test which GPUs are avaialable in our system:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "91ab4753-aa5a-40b0-abe5-c28a365de8c7", + "metadata": {}, + "outputs": [], + "source": [ + "!nvidia-smi" + ] + }, + { + "cell_type": "markdown", + "id": "939fc25f-3cd5-4b5d-aec6-9b295b5734a9", + "metadata": {}, + "source": [ + "### Authenticating with Hugging Face\n", + "\n", + "To download the model from Hugging Face, you’ll need to enter your personal access token.\n", + "\n", + "The cell below provides a simple interface to enter and save your token securely. It will be cached locally, so you only need to do this once per environment.\n", + "\n", + "➡️ You can find or generate your token at: [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)\n", + "\n", + "Once saved, the token will allow seamless access to gated models like `mistralai/Mistral-7B-Instruct-v0.3`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11099e7d-6ca8-47e7-bc24-934321e4f1ea", + "metadata": {}, + "outputs": [], + "source": [ + "# ⬇️ Run this cell once\n", + "from ipywidgets import Password, Button, HBox, Output\n", + "import os, pathlib\n", + "import sys\n", + "\n", + "from huggingface_hub import HfFolder, whoami\n", + "\n", + "# ---- UI widgets ----\n", + "token_box = Password(\n", + " description=\"HF Token:\",\n", + " placeholder=\"paste your Hugging Face token here\",\n", + " layout={\"width\": \"450px\"},\n", + ")\n", + "save_btn = Button(description=\"Save\", button_style=\"success\")\n", + "out = Output()\n", + "\n", + "# ---- Callback ----\n", + "def save_token(_):\n", + " out.clear_output()\n", + " token = token_box.value.strip()\n", + " with out:\n", + " if not token:\n", + " print(\"❌ No token entered.\")\n", + " return\n", + " # Persist token\n", + " HfFolder.save_token(token) # writes to ~/.cache/huggingface/token\n", + " os.environ[\"HF_TOKEN\"] = token # current kernel env (optional)\n", + " # Sanity-check who we are\n", + " try:\n", + " user = whoami(token)[\"name\"]\n", + " print(f\"✅ Token saved. Logged in as: {user}\")\n", + " except Exception as e:\n", + " print(\"⚠️ Token saved, but user lookup failed:\", e)\n", + "\n", + "save_btn.on_click(save_token)\n", + "\n", + "display(HBox([token_box, save_btn]), out)" + ] + }, + { + "cell_type": "markdown", + "id": "753f51cf-59a9-4094-b449-e2f5c9db502e", + "metadata": {}, + "source": [ + "### Downloading and optimizing the model" + ] + }, + { + "cell_type": "markdown", + "id": "eba6a81f-1fff-4f3a-b232-426d346ecf54", + "metadata": {}, + "source": [ + "We'll be using **TensorRT-LLM** with a **PyTorch** backend to run inference. 
\n", + "\n", + "[TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) is an open-source library from NVIDIA designed specifically to **optimize and accelerate LLM inference** on NVIDIA GPUs.\n", + "\n", + "The PyTorch backend provides:\n", + "\n", + "- A familiar Pythonic API for rapid prototyping and ease of integration\n", + "- Compatibility with the Hugging Face model ecosystem\n", + "- Seamless fallback to PyTorch ops when certain layers or patterns can't be fully optimized\n", + "\n", + "This makes it an ideal choice for developers who want **high performance** without sacrificing **flexibility**.\n", + "\n", + "Compared to native unoptimized inference, it offers significantly better performance (especially in terms of throughput and latency) by leveraging features like:\n", + "\n", + "- quantization\n", + "- KV cache management\n", + "- kernel fusion\n", + "- efficient batching\n", + "\n", + "In this step, we’ll load the `mistralai/Mistral-7B-Instruct-v0.3` model from Hugging Face.\n", + "\n", + "📌 **Note:** The first time you run this, it may take a few minutes to:\n", + "\n", + "- Download the model weights into your local Hugging Face cache\n", + "- Optimize the model for your GPU (this step is automatic)\n", + "\n", + "Subsequent runs will be faster, as both the model and its compiled artifacts will be reused." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2bcd87be-29c6-4880-9863-090a3742defd", + "metadata": {}, + "outputs": [], + "source": [ + "from tensorrt_llm._torch import LLM\n", + "from tensorrt_llm import SamplingParams\n", + "\n", + "# Instantiate the model once, reuse for every experiment\n", + "model = LLM(model=\"mistralai/Mistral-7B-Instruct-v0.3\")" + ] + }, + { + "cell_type": "markdown", + "id": "991f0306-01f9-4829-8b56-2499feccb0ba", + "metadata": {}, + "source": [ + "## Executing inference" + ] + }, + { + "cell_type": "markdown", + "id": "66fe43d3-4382-4468-abd0-575fffa05226", + "metadata": {}, + "source": [ + "Let’s see what our model can do!\n", + "\n", + "We’ll run inference on a batch of prompts and observe the outputs. This is your first direct interaction with the model.\n", + "\n", + "### What is inference?\n", + "\n", + "**Inference** is the process of using a trained model to generate outputs for new inputs without any learning or weight updates.\n", + "\n", + "For large language models (LLMs), this means:\n", + "\n", + "- Receiving a **prompt** (a string of text)\n", + "- Computing the most likely **next token(s)**\n", + "- Repeating the process token-by-token to generate a full response\n", + "\n", + "Inference is the core of every production LLM system — whether you’re building a chatbot, writing assistant, summarizer, or anything else.\n", + "\n", + "### Sampling parameters\n", + "\n", + "When generating text, the model can either:\n", + "\n", + "- **Always pick the highest-probability token** (greedy decoding — fast and deterministic)\n", + "- **Sample from the probability distribution** over possible next tokens — which adds variety and creativity\n", + "\n", + "We use the following parameters to control that behavior:\n", + "\n", + "- `temperature = 0.7`: Controls randomness. Lower values → more confident, deterministic outputs. Higher values → more diverse, sometimes erratic responses.\n", + "- `top_p = 0.9`: Enables **nucleus sampling** — the model samples only from the top tokens that together make up 90% of the probability mass. 
Balances diversity and coherence.\n", + "- `top_k = 50`: Further restricts sampling to the top 50 tokens at each step (can be used alone or with `top_p`).\n", + "- `max_tokens = 512`: The maximum number of tokens to generate per prompt.\n", + "- `stop = [\"\"]`: Tells the model to stop when it sees the end-of-sequence token.\n", + "\n", + "Together, these settings ensure the output is diverse but not chaotic — a good balance for most use cases.\n", + "\n", + "Try changing the prompts and sampling values below to see how the model’s behavior changes!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "86d0e074-9f3c-45d3-9582-a2d2ef21571a", + "metadata": {}, + "outputs": [], + "source": [ + "prompts = [\n", + " \"Summer is\",\n", + " \"The president of France is\",\n", + " \"The capital of Germany is\",\n", + " \"The future of AI is\",\n", + " ]\n", + "\n", + "sampling_params = SamplingParams(\n", + " temperature=0.7,\n", + " top_p=0.9,\n", + " top_k=50,\n", + " max_tokens=512,\n", + " stop=[\"\"]\n", + ")\n", + "\n", + "outputs = model.generate(prompts, sampling_params)\n", + "\n", + "for i, output in enumerate(outputs):\n", + " prompt = output.prompt\n", + " generated_text = output.outputs[0].text\n", + " print(f\"[{i}] Prompt: {prompt!r}, Generated text: {generated_text!r}\\n\")" + ] + }, + { + "cell_type": "markdown", + "id": "f45af53c-e233-4927-989a-81e2650dbf48", + "metadata": {}, + "source": [ + "## Performance experiments" + ] + }, + { + "cell_type": "markdown", + "id": "44746565-040b-4b0c-9d39-e83d346f82dd", + "metadata": {}, + "source": [ + "Before we dive into benchmarking, we’ll configure a few things to ensure our measurements are accurate and repeatable:\n", + "\n", + "What this cell does:\n", + "\n", + "- **Imports** all required Python libraries for timing, statistics, and plotting\n", + "- **Sets a fixed random seed** for reproducibility \n", + " > This doesn't significantly affect latency, but ensures the same outputs are generated each time — useful when debugging or comparing sampling runs.\n", + "- **Initializes the tokenizer** from Hugging Face, which we’ll use to measure how many tokens the model actually generates\n", + "- **Defines two helper functions**:\n", + " - `count_generated_tokens()` → counts only the tokens *generated* by the model, excluding the prompt\n", + " - `timed_generate()` → wraps the model's `.generate()` call and returns:\n", + " - total **latency**\n", + " - number of **generated tokens**\n", + " - raw **outputs** (including prompts and completions)\n", + "\n", + "We also apply `nest_asyncio` to avoid runtime warnings from Jupyter's existing event loop — it won’t affect performance but ensures clean execution in notebook environments." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "019cc08b-4539-4cc2-9e80-536964167eca", + "metadata": {}, + "outputs": [], + "source": [ + "import os, time, random, math, json, itertools, statistics, gc\n", + "import numpy as np\n", + "import torch\n", + "import matplotlib.pyplot as plt\n", + "from transformers import AutoTokenizer\n", + "import nest_asyncio; nest_asyncio.apply()\n", + "import pandas as pd\n", + "\n", + "# Reproducibility – will NOT change latency much, but guarantees identical tokens\n", + "SEED = 42\n", + "random.seed(SEED)\n", + "np.random.seed(SEED)\n", + "torch.manual_seed(SEED)\n", + "torch.cuda.manual_seed_all(SEED)\n", + "torch.backends.cudnn.benchmark = False # slightly hurts speed but stabilises variance\n", + "\n", + "# Use the same model card so we don’t download twice\n", + "tokenizer = AutoTokenizer.from_pretrained(\n", + " \"mistralai/Mistral-7B-Instruct-v0.3\",\n", + " use_fast=True\n", + ")\n", + "\n", + "def count_generated_tokens(output_obj, prompt_text):\n", + " \"\"\"\n", + " Return only the *new* tokens produced by the model\n", + " (prompt tokens aren't counted towards throughput).\n", + " \"\"\"\n", + " gen = output_obj.outputs[0]\n", + "\n", + " # Fast path: some TRT-LLM builds expose raw token IDs\n", + " if hasattr(gen, \"token_ids\") and gen.token_ids is not None:\n", + " return len(gen.token_ids)\n", + "\n", + " # Fallback: tokenize the generated string itself\n", + " # (prompt not included, so no subtraction needed)\n", + " return len(tokenizer.encode(gen.text, add_special_tokens=False))\n", + "\n", + "\n", + "def timed_generate(prompts, sampling_params):\n", + " \"\"\"Return latency, generated-token count, and the raw outputs.\"\"\"\n", + " torch.cuda.synchronize()\n", + " t0 = time.perf_counter()\n", + " outputs = model.generate(prompts, sampling_params)\n", + " torch.cuda.synchronize()\n", + " elapsed = time.perf_counter() - t0\n", + "\n", + " gen_tokens = sum(count_generated_tokens(o, p)\n", + " for o, p in zip(outputs, prompts))\n", + " return elapsed, gen_tokens, outputs" + ] + }, + { + "cell_type": "markdown", + "id": "9ea029ea-4250-4417-b34d-de954414c3c2", + "metadata": {}, + "source": [ + "### Baseline latency and throughput" + ] + }, + { + "cell_type": "markdown", + "id": "738e48d9-844e-4c82-85ec-b926d75f25a1", + "metadata": {}, + "source": [ + "Let’s establish a **baseline** for how our model performs with a single prompt under typical sampling settings.\n", + "\n", + "This test measures:\n", + "\n", + "- **Latency**: the total time it takes to generate a complete response\n", + "- **Throughput**: the number of tokens generated per second\n", + "\n", + "What we’re doing here:\n", + "\n", + "- Use a **single input prompt** (feel free to change it!)\n", + "- Generate up to `128` tokens using moderate sampling (`temperature=0.7`, `top_p=0.9`)\n", + "- Call `timed_generate()` to:\n", + " - time the entire inference run\n", + " - count how many tokens were actually generated\n", + "- Print the generated response\n", + "\n", + "This gives us a **reference point** to compare against later experiments where we vary different factors.\n", + "\n", + "> This is also the simplest “real-world” case: one user, one prompt, one answer." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "96dda97f-295a-4342-9a13-ca4f71437b46", + "metadata": {}, + "outputs": [], + "source": [ + "prompts = [\"The future of AI is\"] # single-prompt baseline\n", + "sampling_params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=128, stop=[\"\"])\n", + "\n", + "lat, toks, outs = timed_generate(prompts, sampling_params)\n", + "print(f\"Latency: {lat:.3f}s | Throughput: {toks/lat:.1f} tok/s\")\n", + "\n", + "for o in outs:\n", + " print(o.outputs[0].text)" + ] + }, + { + "cell_type": "markdown", + "id": "ee2f1674-a8e8-4902-9885-f76a1b5c7bf9", + "metadata": {}, + "source": [ + "### Batch size sweep" + ] + }, + { + "cell_type": "markdown", + "id": "4376d2ac-e894-4b23-b74f-7ed3c4a54e29", + "metadata": {}, + "source": [ + "Now we’re going to benchmark how inference performance changes with **different batch sizes** — that is, how many prompts we process in parallel.\n", + "\n", + "This experiment helps us understand the **tradeoff between throughput and latency**, and find the batch size that balances performance and responsiveness.\n", + "\n", + "For each batch size in:\n", + "\n", + "```python\n", + "[1, 8, 32, 128, 256, 512]\n", + "```\n", + "\n", + "we run 10 repetitions and record:\n", + "\n", + "- **Time to first token (TTFT)** — how quickly the model starts responding\n", + "- **Total latency** — how long it takes to generate the full output\n", + "- **p50 and p99 latency** — to capture both typical and tail performance\n", + "- **Throughput** — total tokens generated per second (real tokens, not just assumed)\n", + "\n", + "Why this matters?\n", + "\n", + "- Small batches (1–8) respond quickly but underutilize the GPU\n", + "- Large batches (128–512) maximize throughput but increase individual latency\n", + "- Real-time systems typically operate at a batch size that balances tail latency (p99) and system throughput\n", + "\n", + "> These runs can take a few minutes, especially at large batch sizes — start executing the cell and grab a coffee!\n", + "\n", + "We will be saving the results in the file `batch_benchmark.csv` so you can reuse or visualize the data later." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17391b5e-ad48-40ed-81e0-97c92d314539", + "metadata": {}, + "outputs": [], + "source": [ + "BATCH_SIZES = [1, 8, 32, 128, 256, 512]\n", + "RUNS_PER_SIZE = 10\n", + "MAX_TOKENS_FULL = 128 # length for total-latency test\n", + "SAMPLING_OPTS = dict(temperature=0.7, top_p=0.9, stop=[\"\"])\n", + "\n", + "records = []\n", + "\n", + "for bs in BATCH_SIZES:\n", + " first_latencies, total_latencies = [], []\n", + " gen_token_total = 0\n", + " for _ in range(RUNS_PER_SIZE):\n", + " prompts = [\"The future of AI is\"] * bs\n", + "\n", + " # TTFT – first token\n", + " t_first, _, _ = timed_generate(\n", + " prompts,\n", + " SamplingParams(**SAMPLING_OPTS, max_tokens=1)\n", + " )\n", + " first_latencies.append(t_first)\n", + "\n", + " # Total latency – full answer\n", + " t_total, gen_tokens, _ = timed_generate(\n", + " prompts,\n", + " SamplingParams(**SAMPLING_OPTS, max_tokens=MAX_TOKENS_FULL)\n", + " )\n", + " total_latencies.append(t_total)\n", + " gen_token_total += gen_tokens\n", + "\n", + " gc.collect(); torch.cuda.empty_cache()\n", + "\n", + " # derive stats once per batch size\n", + " p50_first = statistics.quantiles(first_latencies, n=100)[49]\n", + " p99_first = statistics.quantiles(first_latencies, n=100)[98]\n", + " p50_total = statistics.quantiles(total_latencies, n=100)[49]\n", + " p99_total = statistics.quantiles(total_latencies, n=100)[98]\n", + "\n", + " # throughput = tokens produced per second (generated tokens only)\n", + " avg_latency = p50_total\n", + " thru = gen_token_total / (RUNS_PER_SIZE * avg_latency)\n", + "\n", + " # We can also assume every run produced MAX_TOKENS_FULL tokens per request\n", + " # thru = (bs * MAX_TOKENS_FULL) / p50_total\n", + "\n", + " records.append(dict(\n", + " batch_size = bs,\n", + " p50_ttft = p50_first,\n", + " p99_ttft = p99_first,\n", + " p50_latency = p50_total,\n", + " p99_latency = p99_total,\n", + " throughput = thru\n", + " ))\n", + "\n", + " print(f\"BS={bs:>2} | p50={p50_total:.3f}s | p99={p99_total:.3f}s | thru={thru:.1f} tok/s\")\n", + "\n", + "df = pd.DataFrame.from_records(records).sort_values(\"batch_size\")\n", + "display(df) # Jupyter pretty-prints\n", + "\n", + "# Persist for later sessions\n", + "df.to_csv(\"batch_benchmark.csv\", index=False)\n", + "print(\"✅ Benchmark finished – data stored in DataFrame `df` and batch_benchmark.csv\")" + ] + }, + { + "cell_type": "markdown", + "id": "1fc1da67-8bad-42bb-8d32-c76e9a9b8bf1", + "metadata": {}, + "source": [ + "### Interpreting our benchmarks" + ] + }, + { + "cell_type": "markdown", + "id": "b3950c3c-5aa1-48d1-8c9b-2faa00f7fe7c", + "metadata": {}, + "source": [ + "Now that we've run our performance sweep, let’s dive into what the results actually mean — and what insights we can extract from them." + ] + }, + { + "cell_type": "markdown", + "id": "b3619f2d-ded8-45df-aea1-ef0bef992db0", + "metadata": {}, + "source": [ + "#### First-token vs total latency, p50 vs p99" + ] + }, + { + "cell_type": "markdown", + "id": "4aefbefb-c98d-42e5-9c0a-1f80f9ae270a", + "metadata": {}, + "source": [ + "A critical aspect of LLM user experience is how quickly the model starts responding. This is often referred to as:\n", + "\n", + "- **Time to first token (TTFT)** — the latency until the first word appears on screen. \n", + "- **Total Latency** — the time it takes to generate the *entire* response.\n", + "\n", + "In many real-world applications (e.g. 
chatbots, assistants), **TTFT is more important than total latency** — users start reading as soon as the model begins responding.\n", + "\n", + "Latency can vary between runs due to scheduling, GPU memory pressure, or sampling variability. That’s why we track:\n", + "\n", + "- **p50 latency**: the median latency — what a “typical” user experiences.\n", + "- **p99 latency**: the slowest 1% — what users see during worst-case delays.\n", + "\n", + "Why this matters:\n", + "\n", + "- A low p50 is nice.\n", + "- A high p99 is a problem. It means some users are waiting 2–3× longer than expected.\n", + "- **Good systems optimize for p99, not just average.**\n", + "\n", + "Let’s now visualize how **batch size** affects all of these:\n", + "\n", + "- Time to first token (TTFT)\n", + "- Total latency\n", + "- Latency variability (p50 vs. p99)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a2a63786-a317-4ace-9327-04dfa4cf815e", + "metadata": {}, + "outputs": [], + "source": [ + "bs = df[\"batch_size\"]\n", + "\n", + "plt.figure(figsize=(7,4))\n", + "plt.plot(bs, df[\"p50_ttft\"], 'o-', label='TTFT p50')\n", + "plt.plot(bs, df[\"p99_ttft\"], 'x-', label='TTFT p99')\n", + "plt.plot(bs, df[\"p50_latency\"], 's--', label='Total p50')\n", + "plt.plot(bs, df[\"p99_latency\"], '^--', label='Total p99')\n", + "plt.xlabel(\"Batch size\"); plt.ylabel(\"Latency (s)\")\n", + "plt.title(\"First-token vs Total latency across batch sizes\")\n", + "plt.legend(); plt.grid(True); plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "cbd4d28d-82de-4ff7-918b-9d195af9d731", + "metadata": {}, + "source": [ + "This plot shows that first-token latency (TTFT) remains low across all batch sizes, making it ideal for streaming or interactive applications, while total latency increases sharply — especially at higher batch sizes. The widening gap between p50 and p99 total latency indicates that tail latency becomes a problem as batch size grows, which can degrade user experience. For real-time systems, batch sizes above 128–256 may introduce noticeable delays, even if throughput improves. TTFT’s stability suggests that batching is still viable for streaming use cases, where the first token appears quickly and the rest of the response streams in." + ] + }, + { + "cell_type": "markdown", + "id": "23abc1dd-2571-4cba-8279-997b9c85efa8", + "metadata": {}, + "source": [ + "#### Throughput vs p50 latency" + ] + }, + { + "cell_type": "markdown", + "id": "8c5fc1c1-c130-46bb-8c02-6028422055fb", + "metadata": {}, + "source": [ + "Now let’s visualize the tradeoff between throughput and p50 latency. Higher batch sizes typically improve throughput by maximizing GPU utilization, but also increase latency. Our goal is to identify the “sweet spot” — the batch size that delivers strong throughput without compromising responsiveness, striking the right balance for real-time inference." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "37c7f286-95be-4ead-9277-10b938248348", + "metadata": {}, + "outputs": [], + "source": [ + "lat_n = df[\"p50_latency\"] / df[\"p50_latency\"].max() # 0-1\n", + "thr_inv= 1 - df[\"throughput\"] / df[\"throughput\"].max() # 0-1, inverted\n", + "gap = (lat_n - thr_inv).abs()\n", + "best_i = gap.argmin()\n", + "\n", + "plt.figure(figsize=(7,4))\n", + "plt.plot(bs, lat_n, 'o-', label='p50 latency (norm)')\n", + "plt.plot(bs, thr_inv,'^--', label='throughput (inverted norm)')\n", + "plt.scatter(bs.iloc[best_i], lat_n.iloc[best_i], c='red', s=120,\n", + " label=f'sweet spot ≈ BS {bs.iloc[best_i]}')\n", + "plt.xlabel(\"Batch size\"); plt.ylabel(\"Normalised metric (0-1)\")\n", + "plt.title(\"Latency vs Throughput – knee point\")\n", + "plt.legend(); plt.grid(True); plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "7e0d5c3a-3948-4885-a7ff-ca2f987d30ac", + "metadata": {}, + "source": [ + "This generated plot shows the tradeoff between p50 latency and throughput as batch size increases, helping us identify the optimal “sweet spot” for inference performance. As batch size grows, latency increases, while throughput improves (the line inverted for visual comparison). \n", + "\n", + "The two curves intersect around batch size 128, which represents the **knee point** — the smallest batch size that gets us close to maximum throughput without incurring excessive latency. Going beyond this point yields diminishing returns: throughput flattens while latency continues to rise sharply. For many real-time or near-real-time use cases, this intersection represents the best balance between speed and efficiency." + ] + }, + { + "cell_type": "markdown", + "id": "40f82fd4-cacf-4b5f-93ae-df595a76092d", + "metadata": {}, + "source": [ + "#### Finding the sweet spot based on UX budget" + ] + }, + { + "cell_type": "markdown", + "id": "a7112420-16d4-49d0-ae72-812a4d588bca", + "metadata": {}, + "source": [ + "Choosing the optimal batch size isn’t just about throughput — it’s about respecting your users’ **latency expectations**.\n", + "\n", + "This section helps you find the best-performing batch size that stays within a given **latency budget**, defined in seconds. You can adjust the value of `latency_budget` depending on your application’s needs (e.g. 2.5s for chat, 5s for summarization, etc.).\n", + "\n", + "What this code does:\n", + "\n", + "- Filters all benchmark results to include only configurations where latency is **below your target threshold**\n", + "- Then selects the **batch size with the highest throughput** from the remaining options\n", + "- Runs this selection twice:\n", + " - Once using **p50 latency** (typical case)\n", + " - Once using **p99 latency** (worst-case tail latency)\n", + "\n", + "Try changing the `latency_budget` value and observe how the recommended batch size shifts depending on the metric used." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e8d6fc98-3062-472c-87f4-49f72dede876", + "metadata": {}, + "outputs": [], + "source": [ + "latency_budget = 2.5 # seconds\n", + "\n", + "filtered_p50 = df[df[\"p50_latency\"] <= latency_budget]\n", + "filtered_p99 = df[df[\"p99_latency\"] <= latency_budget]\n", + "\n", + "if not filtered_p50.empty:\n", + " best_row_p50 = filtered_p50.loc[filtered_p50[\"throughput\"].idxmax()]\n", + " print(f\"✅ Recommended: batch_size={int(best_row_p50.batch_size)} \"\n", + " f\"⇒ p50 latency = {best_row_p50.p50_latency:.2f}s, \"\n", + " f\"throughput = {best_row_p50.throughput:.0f} tok/s\")\n", + "else:\n", + " print(f\"❌ No configuration meets the latency budget of {latency_budget:.1f}s\")\n", + "\n", + "if not filtered_p99.empty:\n", + " best_row_p99 = filtered_p99.loc[filtered_p99[\"throughput\"].idxmax()]\n", + " print(f\"✅ Recommended: batch_size={int(best_row_p99.batch_size)} \"\n", + " f\"⇒ p99 latency = {best_row_p99.p99_latency:.2f}s, \"\n", + " f\"throughput = {best_row_p99.throughput:.0f} tok/s\")\n", + "else:\n", + " print(f\"❌ No configuration meets the latency budget of {latency_budget:.1f}s based on p99 latency\")" + ] + }, + { + "cell_type": "markdown", + "id": "468828c9-a9c3-475f-8818-1d1a0f4d1175", + "metadata": {}, + "source": [ + "### Asessing impact of other parameters" + ] + }, + { + "cell_type": "markdown", + "id": "671b9e3d-f1d6-470d-b510-cc98cf38444b", + "metadata": {}, + "source": [ + "So far, we’ve seen that **batch size** directly affects latency and throughput. But batch size isn’t the only factor that matters.\n", + "\n", + "Let’s now explore how **other inference parameters**, starting with **output sequence length**, influence performance." + ] + }, + { + "cell_type": "markdown", + "id": "c1a8afac-22cd-4168-bc02-b8876e10bbe2", + "metadata": {}, + "source": [ + "#### Sequence Length" + ] + }, + { + "cell_type": "markdown", + "id": "53396c0b-f26d-42c0-94df-8bc386383878", + "metadata": {}, + "source": [ + "The **sequence length** `max_tokens` parameter defines how many tokens the model is allowed to generate per prompt.\n", + "\n", + "In practice:\n", + "\n", + "- Short outputs (e.g. 32 tokens) return quickly\n", + "- Long outputs (e.g. 512 tokens) take significantly more time, especially at high batch sizes\n", + "\n", + "This is because while **prefill cost is fixed**, the **decode phase scales linearly** with the number of tokens generated — we will look into that in a bit!\n", + "\n", + "What this experiment does:\n", + "\n", + "- Uses a fixed batch size (`batch = 128`)\n", + "- Varies the `max_tokens` cap from 32 to 512\n", + "- Measures **mean** and **p99 latency** across 5 runs for each setting\n", + "\n", + "> Note: This experiment can take a few minutes — longer sequence lengths at high batch size are compute-intensive.\n", + "\n", + "Let’s visualize how increasing the output length impacts both average latency and tail latency." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f11c4afe-6cfb-40e2-85ce-39a93263c9d2", + "metadata": {}, + "outputs": [], + "source": [ + "seq_lengths = [32, 64, 128, 256, 512]\n", + "batch = 128\n", + "lat_means, lat_p99s = [], []\n", + "for mt in seq_lengths:\n", + " sampling_params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=mt, stop=[\"\"])\n", + " lats = [ timed_generate([\"Summer is\"]*batch, sampling_params)[0] for _ in range(5) ]\n", + " lat_means.append(sum(lats)/len(lats))\n", + " lat_p99s.append(statistics.quantiles(lats, n=100)[98])\n", + "\n", + "plt.figure()\n", + "plt.plot(seq_lengths, lat_means, marker='o', label='Mean latency')\n", + "plt.plot(seq_lengths, lat_p99s, marker='x', label='p99 latency')\n", + "plt.xlabel('max_tokens'); plt.ylabel('Latency (s)'); plt.title('Latency vs Generated Length')\n", + "plt.legend(); plt.grid(True); plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "6ccc89ce-5c28-4c08-b8f6-3897b4be1035", + "metadata": {}, + "source": [ + "Once the execution finished, you can see, that latency grows roughly linearly with tokens generated. For chat UX, cap responses (e.g., 128-256 tokens) or stream early chunks." + ] + }, + { + "cell_type": "markdown", + "id": "4ffded4d-56a5-400d-99eb-05038c840adc", + "metadata": {}, + "source": [ + "#### Sampling knobs (temperature & top-p) and determinism" + ] + }, + { + "cell_type": "markdown", + "id": "5ce7e29e-5be9-4de8-8c28-778823c8e59d", + "metadata": {}, + "source": [ + "You might be wondering — do **sampling parameters** like `temperature` and `top_p` affect latency? Let's check!\n", + "\n", + "We run multiple generations using different combinations of:\n", + "\n", + "- `temperature = [0.0, 0.7, 1.3]`\n", + "- `top_p = [0.8, 0.9, 1.0]`\n", + "\n", + "For each pair, we measure:\n", + "\n", + "- **Mean latency**\n", + "- **p99 latency**\n", + "\n", + "This helps us assess whether sampling diversity impacts performance.\n", + "\n", + "Let’s visualize the results!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9ffeda89-5fac-4a7e-9d52-cc20db791b67", + "metadata": {}, + "outputs": [], + "source": [ + "# Benchmark settings\n", + "temps = [0.0, 0.7, 1.3]\n", + "tops = [0.8, 0.9, 1.0]\n", + "grid = list(itertools.product(temps, tops))\n", + "\n", + "# Results containers\n", + "mean_lats = []\n", + "p99_lats = []\n", + "\n", + "# Benchmark loop\n", + "for t, p in grid:\n", + " sp = SamplingParams(temperature=t, top_p=p, max_tokens=64, stop=[\"\"])\n", + " lats = [timed_generate([\"The president of France is\"], sp)[0] for _ in range(30)]\n", + " mean_lat = statistics.mean(lats)\n", + " p99_lat = statistics.quantiles(lats, n=100)[98]\n", + "\n", + " mean_lats.append(mean_lat)\n", + " p99_lats.append(p99_lat)\n", + "\n", + " print(f\"T={t:<3} top_p={p:<3} | mean {mean_lat:.3f}s | p99 {p99_lat:.3f}s\")\n", + "\n", + "# Visualization\n", + "x = range(len(grid))\n", + "labels = [f\"T={t}, p={p}\" for (t, p) in grid]\n", + "\n", + "plt.figure(figsize=(8, 4))\n", + "plt.plot(x, mean_lats, marker='o', label='Mean latency')\n", + "plt.plot(x, p99_lats, marker='x', label='p99 latency')\n", + "plt.xticks(x, labels, rotation=45, ha='right')\n", + "plt.xlabel('(temperature, top_p)')\n", + "plt.ylabel('Latency (s)')\n", + "plt.title('Latency vs Sampling Parameters')\n", + "plt.legend(); plt.grid(True); plt.tight_layout()\n", + "plt.show()\n" + ] + }, + { + "cell_type": "markdown", + "id": "fa995662-0378-4e9c-a4c6-ad9a91a731cf", + "metadata": {}, + "source": [ + "You’ll notice sampling strategy barely moves latency (sub-1 % deltas). So worry about output quality, not perf, when tuning temperature and top-p.\n", + "\n", + "p99 latency is more volatile and fluctuates randomly." + ] + }, + { + "cell_type": "markdown", + "id": "50f7bdb7-6f73-468b-9295-7064c35e43bb", + "metadata": {}, + "source": [ + "## Note on reasoning models\n", + "\n", + "While this notebook focuses on explaining LLM inference and its performance, it's important to recognize that not all language models are built the same.\n", + "\n", + "Some are optimized for **throughput and deployment efficiency** (e.g. Mistral 7B, LLaMA 3 8B), while others are designed for **complex reasoning and instruction following** (e.g. GPT-4, Claude Opus, Gemini Pro, LLaMA 3 70B, NVIDIA Llama Nemotron).\n", + "\n", + "When working with reasoning models, it's important to evaluate them not just on latency and throughput, but also on:\n", + "\n", + "- Accuracy and depth of reasoning\n", + "- Consistency across temperature settings\n", + "- Ability to follow multi-step or long-form instructions\n", + "\n", + "That said, **UX rules still apply**: a powerful model that takes 10 seconds to respond without streaming may still feel unusable, no matter how smart it is.\n", + "\n", + "> In production, you'll often need to balance **speed** with **output quality** — especially for use cases involving multi-step reasoning or decision making.\n", + "\n", + "👉 **Choose the model based on the application**: chat assistants, real-time agents, RAG pipelines, and batch summarization all have different performance vs. quality requirements.\n", + "\n", + "The right tradeoff isn't about raw metrics — it's about delivering the right user experience for the task." 
+ ] + }, + { + "cell_type": "markdown", + "id": "921beecb-942c-43ff-ade4-552171bf52e2", + "metadata": {}, + "source": [ + "## Summary" + ] + }, + { + "cell_type": "markdown", + "id": "2664ff55-02f5-420d-afb5-0f4c3e150901", + "metadata": {}, + "source": [ + "In this notebook, we explored the key performance aspects of running inference with large language models (LLMs) — and how to make informed tradeoffs between **latency**, **throughput**, and **user experience (UX)**.\n", + "\n", + "### Key takeaways\n", + "\n", + "- **Latency vs. Throughput is a balancing act**\n", + " - Larger batch sizes boost throughput but increase response time \n", + " - The “sweet spot” depends on your UX latency budget (e.g. < 2.5s for chat)\n", + "- **First-token latency (TTFT) matters for UX**\n", + " - TTFT stays low even at high batch sizes, making streaming UIs responsive \n", + " - Total latency, however, grows linearly with output length and batch size\n", + "- **p50 vs. p99 latency tells you everything about real-world performance**\n", + " - Optimizing for p50 gives nice averages \n", + " - Optimizing for p99 ensures consistent UX and avoids tail spikes\n", + "- **Sequence length directly affects decode time**\n", + " - Longer outputs = higher latency \n", + " - Always set `max_tokens` intentionally\n", + "- **Sampling parameters (temperature, top_p) have little impact on latency**\n", + " - You can safely tune them for quality without worrying about performance \n", + " - Only high-entropy sampling (e.g. temp > 1.3) may slightly increase p99\n", + "\n", + "### Final thought\n", + "\n", + "LLM inference isn't just about making the model run fast — it's about **making it feel fast** for the user. By understanding and measuring the right metrics, you can deliver models that are not only efficient, but actually usable in production." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a618e856-ec00-45c1-9eda-a64c81b97f70", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/community/llm-inference-series/README.md b/community/llm-inference-series/README.md new file mode 100644 index 00000000..d94afcdf --- /dev/null +++ b/community/llm-inference-series/README.md @@ -0,0 +1,67 @@ +# LLM Inference Series: Performance, Optimization & Deployment with LLMs + +This repository supports a video + notebook series exploring how to run, optimize, and serve Large Language Models (LLMs) with a focus on latency, throughput, user experience (UX), and NVIDIA GPU acceleration. + +All notebooks run in a shared container environment based on NVIDIA's TensorRT-LLM stack with a PyTorch backend. 
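+
+Before pulling the image, it can help to confirm that the host GPU and the NVIDIA Container Toolkit are working. A minimal sanity check, assuming Docker with the NVIDIA Container Toolkit is already installed (the image tag is the same one used in Getting started below):
+
+```bash
+# Host-side check: NVIDIA driver is loaded and the GPU is visible
+nvidia-smi
+
+# Container-side check: GPU passthrough works with the TensorRT-LLM image
+docker run --rm --gpus all nvcr.io/nvidia/tensorrt-llm/release:0.21.0rc1 nvidia-smi
+```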
+
+## Getting started
+
+Clone this project:
+
+```bash
+git clone 
+cd 
+```
+
+Assuming you are in your working directory, save the current path:
+
+```bash
+export ROOT_DIR=$(pwd)
+```
+
+Now start the container:
+
+```bash
+export PROJECT_DIR=$ROOT_DIR/community/llm-inference-series
+export HF_CACHE_DIR=$ROOT_DIR/community/huggingface
+
+mkdir -p "$HF_CACHE_DIR"
+
+# Run container
+docker run --gpus all -it --ipc=host \
+    -v "$PROJECT_DIR":/workspace \
+    -v "$HF_CACHE_DIR":/hf_cache \
+    -p 8888:8888 \
+    -e HF_HOME=/hf_cache \
+    -e LOCAL_UID=$(id -u) \
+    -e LOCAL_GID=$(id -g) \
+    nvcr.io/nvidia/tensorrt-llm/release:0.21.0rc1 \
+    bash -c '
+        groupadd -g $LOCAL_GID hostgrp 2>/dev/null || true
+        useradd -u $LOCAL_UID -g $LOCAL_GID -M -d /workspace hostusr 2>/dev/null || true
+
+        pip install --no-cache-dir -r /workspace/requirements.txt
+
+        su hostusr -c "cd /workspace && HOME=/workspace HF_HOME=/hf_cache \
+            jupyter lab --ip=0.0.0.0 --port=8888 --no-browser"
+    '
+```
+
+Open Jupyter in your browser at http://localhost:8888 and use the token shown in the container logs. In the JupyterLab environment, navigate to the folder for the episode you want to run.
+
+## 🎥 Episode Guide
+
+✅ Episode 1: Inference 101 – Latency, Throughput & UX
+
+Folder: `01_inference_101/`
+
+What you'll learn:
+
+- What LLM inference means
+- Latency vs tokens/sec vs p99: how and why they differ
+- When latency matters and how to measure it
+- How to visualize and interpret performance metrics
+
+👉 Watch Episode 1 (YouTube link coming soon)
+
+🔜 Stay tuned for updates as each episode is released!
\ No newline at end of file
diff --git a/community/llm-inference-series/constraints.txt b/community/llm-inference-series/constraints.txt
new file mode 100755
index 00000000..8423feb4
--- /dev/null
+++ b/community/llm-inference-series/constraints.txt
@@ -0,0 +1,2 @@
+pillow==10.3.0
+torch==2.7.0a0
\ No newline at end of file
diff --git a/community/llm-inference-series/requirements.txt b/community/llm-inference-series/requirements.txt
new file mode 100755
index 00000000..2ab262dd
--- /dev/null
+++ b/community/llm-inference-series/requirements.txt
@@ -0,0 +1,12 @@
+jupyterlab>=4.1
+ipywidgets>=8.0
+jupyterlab_widgets
+tqdm>=4.66
+huggingface_hub>=0.33.0
+transformers>=4.40.0
+scipy
+matplotlib
+requests
+jupytext
+jupyterlab_code_formatter
+jupyterlab_tensorboard_pro