diff --git a/extensions/thunder/pretrain.py b/extensions/thunder/pretrain.py
index 186e2e5f77..f5a47bb4ff 100644
--- a/extensions/thunder/pretrain.py
+++ b/extensions/thunder/pretrain.py
@@ -24,6 +24,7 @@
 from litgpt.args import EvalArgs, LogArgs, TrainArgs
 from litgpt.data import DataModule, TinyLlama
 from litgpt.model import GPT, Block, CausalSelfAttention, Config, LLaMAMLP, MultiheadLatentAttention
+from litgpt.parser_config import save_hyperparameters
 from litgpt.utils import (
     CLI,
     CycleIterator,
@@ -37,7 +38,6 @@
     parse_devices,
     reset_parameters,
     save_config,
-    save_hyperparameters,
 )
 
 # support running without installing as a package
diff --git a/litgpt/__main__.py b/litgpt/__main__.py
index 6f63f589a2..649f9ee960 100644
--- a/litgpt/__main__.py
+++ b/litgpt/__main__.py
@@ -12,7 +12,6 @@
 from litgpt.finetune.adapter_v2 import setup as finetune_adapter_v2_fn
 from litgpt.finetune.full import setup as finetune_full_fn
 from litgpt.finetune.lora import setup as finetune_lora_fn
-from litgpt.finetune.lora_legacy import setup as finetune_lora_legacy_fn
 from litgpt.generate.adapter import main as generate_adapter_fn
 from litgpt.generate.adapter_v2 import main as generate_adapter_v2_fn
 from litgpt.generate.base import main as generate_base_fn
@@ -20,6 +19,7 @@
 from litgpt.generate.sequentially import main as generate_sequentially_fn
 from litgpt.generate.speculative_decoding import main as generate_speculatively_fn
 from litgpt.generate.tp import main as generate_tp_fn
+from litgpt.parser_config import parser_commands
 from litgpt.pretrain import setup as pretrain_fn
 from litgpt.scripts.convert_hf_checkpoint import convert_hf_checkpoint as convert_hf_checkpoint_fn
 from litgpt.scripts.convert_lit_checkpoint import convert_lit_checkpoint as convert_lit_checkpoint_fn
@@ -29,32 +29,39 @@
 from litgpt.scripts.download import download_from_hub as download_fn
 from litgpt.scripts.merge_lora import merge_lora as merge_lora_fn
 
+PARSER_DATA = {
+    "download": download_fn,
+    "chat": chat_fn,
+    "finetune": finetune_lora_fn,
+    "finetune_lora": finetune_lora_fn,
+    "finetune_full": finetune_full_fn,
+    "finetune_adapter": finetune_adapter_fn,
+    "finetune_adapter_v2": finetune_adapter_v2_fn,
+    "pretrain": pretrain_fn,
+    "generate": generate_base_fn,
+    "generate_full": generate_full_fn,
+    "generate_adapter": generate_adapter_fn,
+    "generate_adapter_v2": generate_adapter_v2_fn,
+    "generate_sequentially": generate_sequentially_fn,
+    "generate_speculatively": generate_speculatively_fn,
+    "generate_tp": generate_tp_fn,
+    "convert_to_litgpt": convert_hf_checkpoint_fn,
+    "convert_from_litgpt": convert_lit_checkpoint_fn,
+    "convert_pretrained_checkpoint": convert_pretrained_checkpoint_fn,
+    "merge_lora": merge_lora_fn,
+    "evaluate": evaluate_fn,
+    "serve": serve_fn,
+}
+
+
+def _check_commands():
+    assert set(parser_commands()) == set(PARSER_DATA.keys()), (
+        "PARSER_DATA has to be kept in sync with litgpt.parser_config.parser_commands()"
+    )
+
 
 def main() -> None:
-    parser_data = {
-        "download": download_fn,
-        "chat": chat_fn,
-        "finetune": finetune_lora_fn,
-        "finetune_lora": finetune_lora_fn,
-        "finetune_lora_legacy": finetune_lora_legacy_fn,
-        "finetune_full": finetune_full_fn,
-        "finetune_adapter": finetune_adapter_fn,
-        "finetune_adapter_v2": finetune_adapter_v2_fn,
-        "pretrain": pretrain_fn,
-        "generate": generate_base_fn,
-        "generate_full": generate_full_fn,
-        "generate_adapter": generate_adapter_fn,
-        "generate_adapter_v2": generate_adapter_v2_fn,
-        "generate_sequentially": generate_sequentially_fn,
-        "generate_speculatively": generate_speculatively_fn,
-        "generate_tp": generate_tp_fn,
-        "convert_to_litgpt": convert_hf_checkpoint_fn,
-        "convert_from_litgpt": convert_lit_checkpoint_fn,
-        "convert_pretrained_checkpoint": convert_pretrained_checkpoint_fn,
-        "merge_lora": merge_lora_fn,
-        "evaluate": evaluate_fn,
-        "serve": serve_fn,
-    }
+    _check_commands()
 
     set_docstring_parse_options(attribute_docstrings=True)
     set_config_read_mode(urls_enabled=True)
@@ -68,7 +75,7 @@
     )
 
     torch.set_float32_matmul_precision("high")
-    CLI(parser_data)
+    CLI(PARSER_DATA)
 
 
 if __name__ == "__main__":
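Note (reviewer sketch, not part of the patch): the new `_check_commands()` guard asserts that the CLI dispatch table and `litgpt.parser_config.parser_commands()` list the same subcommand names. A minimal test pinning down the same invariant could look like the following; the test name is illustrative.

```python
# Illustrative only -- exercises the invariant that _check_commands() asserts.
from litgpt.__main__ import PARSER_DATA
from litgpt.parser_config import parser_commands


def test_parser_data_matches_parser_commands():
    # Every subcommand name must appear in both registries, and nowhere else.
    assert set(parser_commands()) == set(PARSER_DATA.keys())
```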
"generate_speculatively": generate_speculatively_fn, - "generate_tp": generate_tp_fn, - "convert_to_litgpt": convert_hf_checkpoint_fn, - "convert_from_litgpt": convert_lit_checkpoint_fn, - "convert_pretrained_checkpoint": convert_pretrained_checkpoint_fn, - "merge_lora": merge_lora_fn, - "evaluate": evaluate_fn, - "serve": serve_fn, - } + _check_commands() set_docstring_parse_options(attribute_docstrings=True) set_config_read_mode(urls_enabled=True) @@ -68,7 +75,7 @@ def main() -> None: ) torch.set_float32_matmul_precision("high") - CLI(parser_data) + CLI(PARSER_DATA) if __name__ == "__main__": diff --git a/litgpt/config.py b/litgpt/config.py index da7d3ee5bb..98674154f8 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -68,8 +68,13 @@ class Config: n_query_groups: Optional[int] = None attn_bias: bool = False attention_scores_scalar: Optional[int] = None + # If `sliding_window_size` is given, sliding window attention with this + # size is used in layers where `sliding_window_indices` has a 1. The + # default is all 1, so that sliding window attention is used in all + # layers. If `len(sliding_window_indices) > n_layer`, we only use the + # initial part. sliding_window_size: Optional[int] = None - sliding_window_indices: Optional[List] = None + sliding_window_indices: Optional[List[int]] = None # if `attention_logit_softcapping` is used, cannot use optimized # `torch.nn.functional.scaled_dot_product_attention` (which implements # Flash attention), may result in higher memory and runtime footprint. @@ -102,9 +107,12 @@ class Config: norm_2: bool = True latent_attention: Optional[dict] = None # The base period of the RoPE embeddings for local attention. - # If not provided, rope_theta will be used for both local and global attention. + # If not provided, `rope_base` will be used for both local and global attention. rope_local_base_freq: Optional[float] = None - rope_indices: Optional[List] = None + # If provided, must have `>= n_layer` entries, either 0 or 1. For 0, + # `rope_base` is used, for 1 `rope_local_base_freq` is used. If + # `len(rope_indices) > n_layer`, we only use the initial part. 
diff --git a/litgpt/data/base.py b/litgpt/data/base.py
index b571cb7c40..4cd360b8c4 100644
--- a/litgpt/data/base.py
+++ b/litgpt/data/base.py
@@ -17,7 +17,11 @@ class DataModule(LightningDataModule):
 
     @abstractmethod
     def connect(
-        self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None
+        self,
+        tokenizer: Optional[Tokenizer] = None,
+        batch_size: int = 1,
+        max_seq_length: Optional[int] = None,
+        **kwargs,
     ) -> None:
         """All settings that can't be determined at the time of instantiation need to be passed through here
         before any dataloaders can be accessed.
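Not part of the patch: a hypothetical `DataModule` subclass sketching what the widened `connect()` signature allows — concrete data modules can now receive (or simply ignore) extra keyword arguments forwarded by the training scripts. The class name and attributes below are illustrative only, and the other abstract methods are omitted.

```python
from typing import Optional

from litgpt.data import DataModule
from litgpt.tokenizer import Tokenizer


class MyDataModule(DataModule):  # hypothetical example, not in the codebase
    def connect(
        self,
        tokenizer: Optional[Tokenizer] = None,
        batch_size: int = 1,
        max_seq_length: Optional[int] = None,
        **kwargs,
    ) -> None:
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.max_seq_length = max_seq_length
        # Any additional, module-specific settings arrive through **kwargs.
        self.extra_settings = kwargs
```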
diff --git a/litgpt/finetune/adapter.py b/litgpt/finetune/adapter.py
index 813f2c8226..b874231bf0 100644
--- a/litgpt/finetune/adapter.py
+++ b/litgpt/finetune/adapter.py
@@ -20,6 +20,7 @@
 from litgpt.args import EvalArgs, LogArgs, TrainArgs
 from litgpt.data import Alpaca, DataModule
 from litgpt.generate.base import generate
+from litgpt.parser_config import save_hyperparameters
 from litgpt.prompts import save_prompt_style
 from litgpt.tokenizer import Tokenizer
 from litgpt.utils import (
@@ -39,7 +40,6 @@
     load_checkpoint,
     num_parameters,
     parse_devices,
-    save_hyperparameters,
     select_sft_generate_example,
 )
 
diff --git a/litgpt/finetune/adapter_v2.py b/litgpt/finetune/adapter_v2.py
index b80f5688a5..8ec25c905c 100644
--- a/litgpt/finetune/adapter_v2.py
+++ b/litgpt/finetune/adapter_v2.py
@@ -20,6 +20,7 @@
 from litgpt.args import EvalArgs, LogArgs, TrainArgs
 from litgpt.data import Alpaca, DataModule
 from litgpt.generate.base import generate
+from litgpt.parser_config import save_hyperparameters
 from litgpt.prompts import save_prompt_style
 from litgpt.tokenizer import Tokenizer
 from litgpt.utils import (
@@ -40,7 +41,6 @@
     load_checkpoint_update,
     num_parameters,
     parse_devices,
-    save_hyperparameters,
     select_sft_generate_example,
 )
 
diff --git a/litgpt/finetune/full.py b/litgpt/finetune/full.py
index 22699b8c5c..eec16bcde2 100644
--- a/litgpt/finetune/full.py
+++ b/litgpt/finetune/full.py
@@ -17,6 +17,7 @@
 from litgpt.data import Alpaca, DataModule
 from litgpt.generate.base import generate
 from litgpt.model import GPT, Block, Config
+from litgpt.parser_config import save_hyperparameters
 from litgpt.prompts import save_prompt_style
 from litgpt.tokenizer import Tokenizer
 from litgpt.utils import (
@@ -35,7 +36,6 @@
     load_checkpoint,
     num_parameters,
     parse_devices,
-    save_hyperparameters,
     select_sft_generate_example,
 )
 
diff --git a/litgpt/finetune/lora.py b/litgpt/finetune/lora.py
index 1ef450f620..eee57cff78 100644
--- a/litgpt/finetune/lora.py
+++ b/litgpt/finetune/lora.py
@@ -20,6 +20,7 @@
 from litgpt.data import Alpaca, DataModule
 from litgpt.generate.base import generate
 from litgpt.lora import GPT, Block, Config, mark_only_lora_as_trainable
+from litgpt.parser_config import save_hyperparameters
 from litgpt.prompts import save_prompt_style
 from litgpt.scripts.merge_lora import merge_lora
 from litgpt.tokenizer import Tokenizer
@@ -40,7 +41,6 @@
     load_checkpoint,
     num_parameters,
     parse_devices,
-    save_hyperparameters,
     select_sft_generate_example,
 )
 
diff --git a/litgpt/finetune/lora_legacy.py b/litgpt/finetune/lora_legacy.py
index fe05896df6..6575bc10db 100644
--- a/litgpt/finetune/lora_legacy.py
+++ b/litgpt/finetune/lora_legacy.py
@@ -20,6 +20,7 @@
 from litgpt.args import EvalArgs, LogArgs, TrainArgs
 from litgpt.data import Alpaca, DataModule
 from litgpt.generate.base import generate
 from litgpt.lora import GPT, Block, Config, lora_filter, mark_only_lora_as_trainable
+from litgpt.parser_config import save_hyperparameters
 from litgpt.prompts import save_prompt_style
 from litgpt.scripts.merge_lora import merge_lora
 from litgpt.tokenizer import Tokenizer
@@ -40,7 +41,6 @@
     load_checkpoint,
     num_parameters,
     parse_devices,
-    save_hyperparameters,
     select_sft_generate_example,
 )
 
diff --git a/litgpt/generate/sequentially.py b/litgpt/generate/sequentially.py
index 8cdc1ed75d..0711f5f9ff 100644
--- a/litgpt/generate/sequentially.py
+++ b/litgpt/generate/sequentially.py
@@ -2,7 +2,6 @@
 
 import itertools
 import logging
-import math
 import re
 import sys
 import time
@@ -11,7 +10,7 @@
 from functools import partial
 from pathlib import Path
 from pprint import pprint
-from typing import Literal, Optional, Type
+from typing import List, Literal, Optional, Type
 
 import lightning as L
 import torch
@@ -41,18 +40,12 @@ def sequential(model: GPT, root: torch.device, max_seq_length: int, devices: int
            f" n_layer={model.config.n_layer} and devices={devices}."
        )
 
-    # The last device might get fewer layers if number of layers not evenly divisible by device count
-    max_layers_per_device = math.ceil(model.config.n_layer / devices)
-    # dictates where each block should be instantiated
-    mapping = layer_to_device(model, chunk_on=Block, chunk_size=max_layers_per_device)
-
-    if set(mapping.values()) != set(range(devices)):
-        # TODO: support smarter partitioning schemes
-        raise RuntimeError(
-            f"Not able to distribute the {model.config.n_layer} layers across {devices} devices."
-            " Try running with a lower number of devices."
-        )
-
+    # Dictates where each block should be instantiated
+    mapping = layer_to_device(
+        model,
+        chunk_on=Block,
+        chunk_sizes=chunk_sizes(model.config.n_layer, devices),
+    )
     num_layers_per_device = {i: sum(1 for v in mapping.values() if v == i) for i in range(devices)}
 
     # materialize each block on the appropriate device
@@ -100,13 +93,25 @@ def sequential(model: GPT, root: torch.device, max_seq_length: int, devices: int
     return model
 
 
+def chunk_sizes(num_units: int, devices: int) -> List[int]:
+    cs = num_units // devices
+    k = devices * (cs + 1) - num_units
+    return [cs] * k + [cs + 1] * (devices - k)
+
+
 def layer_to_device(
-    module: torch.nn.Module, chunk_on: Type[torch.nn.Module], chunk_size: int
+    module: torch.nn.Module,
+    chunk_on: Type[torch.nn.Module],
+    chunk_sizes: List[int],
 ) -> "OrderedDict[str, int]":
     """Create a mapping from layer (block) to device."""
     # this assumes that the definition order is the same as the execution order
     hits = [name for name, submodule in module.named_modules() if isinstance(submodule, chunk_on)]
-    return OrderedDict((name, i // chunk_size) for i, name in enumerate(hits))
+    if sum(chunk_sizes) != len(hits):
+        raise ValueError(f"Found {len(hits)} modules for chunk_on={chunk_on}, not covered by chunk_sizes={chunk_sizes}")
+    _devices = [[d] * cs for d, cs in enumerate(chunk_sizes)]
+    devices = [d for lst in _devices for d in lst]
+    return OrderedDict(zip(hits, devices))
 
 
 def move_block_input(device: torch.device, module: torch.nn.Module, ins):
@@ -263,9 +268,9 @@ def main(
     for i in range(num_samples):
         t0 = time.perf_counter()
         y = generate_base.generate(
-            model,
-            encoded,
-            max_returned_tokens,
+            model=model,
+            prompt=encoded,
+            max_returned_tokens=max_returned_tokens,
             temperature=temperature,
             top_k=top_k,
             top_p=top_p,
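Not part of the patch: a worked example of the new partitioning helpers. Earlier devices receive the smaller chunks, which is what the updated expectations in `tests/generate/test_sequentially.py` further below encode.

```python
from litgpt.generate.sequentially import chunk_sizes

assert chunk_sizes(6, 3) == [2, 2, 2]
assert chunk_sizes(6, 4) == [1, 1, 2, 2]      # devices 0 and 1 get one block, 2 and 3 get two
assert chunk_sizes(6, 5) == [1, 1, 1, 1, 2]   # only the last device gets two blocks
# layer_to_device() then walks the blocks in definition order and assigns the
# first chunk to device 0, the next chunk to device 1, and so on.
```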
diff --git a/litgpt/parser_config.py b/litgpt/parser_config.py
new file mode 100644
index 0000000000..d90cc459ab
--- /dev/null
+++ b/litgpt/parser_config.py
@@ -0,0 +1,55 @@
+import sys
+from pathlib import Path
+from typing import List, Optional
+
+from litgpt.utils import CLI
+
+
+def parser_commands() -> List[str]:
+    return [
+        "download",
+        "chat",
+        "finetune",
+        "finetune_lora",
+        "finetune_full",
+        "finetune_adapter",
+        "finetune_adapter_v2",
+        "pretrain",
+        "generate",
+        "generate_full",
+        "generate_adapter",
+        "generate_adapter_v2",
+        "generate_sequentially",
+        "generate_speculatively",
+        "generate_tp",
+        "convert_to_litgpt",
+        "convert_from_litgpt",
+        "convert_pretrained_checkpoint",
+        "merge_lora",
+        "evaluate",
+        "serve",
+    ]
+
+
+def save_hyperparameters(
+    function: callable,
+    checkpoint_dir: Path,
+    known_commands: Optional[List[str]] = None,
+) -> None:
+    """Captures the CLI parameters passed to `function` without running `function` and saves them to the checkpoint."""
+    from jsonargparse import capture_parser
+
+    # TODO: Make this more robust
+    # This hack strips away the subcommands from the top-level CLI
+    # to parse the file as if it was called as a script
+    if known_commands is None:
+        known_commands = parser_commands()
+    known_commands = [(c,) for c in known_commands]
+    for known_command in known_commands:
+        unwanted = slice(1, 1 + len(known_command))
+        if tuple(sys.argv[unwanted]) == known_command:
+            sys.argv[unwanted] = []
+
+    parser = capture_parser(lambda: CLI(function))
+    config = parser.parse_args()
+    parser.save(config, checkpoint_dir / "hyperparameters.yaml", overwrite=True)
diff --git a/litgpt/pretrain.py b/litgpt/pretrain.py
index cdf4583195..e61b1494e7 100644
--- a/litgpt/pretrain.py
+++ b/litgpt/pretrain.py
@@ -23,6 +23,7 @@
 from litgpt.config import name_to_config
 from litgpt.data import DataModule, TinyLlama
 from litgpt.model import GPT, Block, CausalSelfAttention, Config, LLaMAMLP
+from litgpt.parser_config import save_hyperparameters
 from litgpt.utils import (
     _TORCH_EQUAL_2_7,
     _TORCH_EQUAL_2_8,
@@ -41,7 +42,6 @@
     parse_devices,
     reset_parameters,
     save_config,
-    save_hyperparameters,
 )
 
 
diff --git a/litgpt/utils.py b/litgpt/utils.py
index 303c4d1bf1..6a175d2c98 100644
--- a/litgpt/utils.py
+++ b/litgpt/utils.py
@@ -552,31 +552,6 @@ def capture_hparams() -> Dict[str, Any]:
     return hparams
 
 
-def save_hyperparameters(function: callable, checkpoint_dir: Path) -> None:
-    """Captures the CLI parameters passed to `function` without running `function` and saves them to the checkpoint."""
-    from jsonargparse import capture_parser
-
-    # TODO: Make this more robust
-    # This hack strips away the subcommands from the top-level CLI
-    # to parse the file as if it was called as a script
-    known_commands = [
-        ("finetune_full",),  # For subcommands, use `("finetune", "full")` etc
-        ("finetune_lora",),
-        ("finetune_adapter",),
-        ("finetune_adapter_v2",),
-        ("finetune",),
-        ("pretrain",),
-    ]
-    for known_command in known_commands:
-        unwanted = slice(1, 1 + len(known_command))
-        if tuple(sys.argv[unwanted]) == known_command:
-            sys.argv[unwanted] = []
-
-    parser = capture_parser(lambda: CLI(function))
-    config = parser.parse_args()
-    parser.save(config, checkpoint_dir / "hyperparameters.yaml", overwrite=True)
-
-
 def save_config(config: "Config", checkpoint_dir: Path) -> None:
     config_dict = asdict(config)
     with open(checkpoint_dir / "model_config.yaml", "w", encoding="utf-8") as fp:
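Not part of the patch: a small sketch of what the argv-stripping loop in the relocated `save_hyperparameters()` does. The flag below is hypothetical; the point is only that the subcommand token is removed so the remaining argv can be parsed as if the setup function itself were the script.

```python
import sys

# Suppose litgpt was invoked as: litgpt finetune_lora --some_flag value   (flag is hypothetical)
sys.argv = ["litgpt", "finetune_lora", "--some_flag", "value"]

known_command = ("finetune_lora",)
unwanted = slice(1, 1 + len(known_command))
if tuple(sys.argv[unwanted]) == known_command:
    sys.argv[unwanted] = []

# The subcommand is gone; jsonargparse's capture_parser() now sees the
# arguments of the setup function directly.
assert sys.argv == ["litgpt", "--some_flag", "value"]
```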
diff --git a/tests/generate/test_sequentially.py b/tests/generate/test_sequentially.py
index 37175fa489..95dca762d2 100644
--- a/tests/generate/test_sequentially.py
+++ b/tests/generate/test_sequentially.py
@@ -1,7 +1,6 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
 
 import itertools
-import math
 import subprocess
 import sys
 from dataclasses import asdict
@@ -14,7 +13,12 @@
 from lightning import Fabric
 
 from litgpt import Config
-from litgpt.generate.sequentially import layer_to_device, replace_device, sequential
+from litgpt.generate.sequentially import (
+    chunk_sizes,
+    layer_to_device,
+    replace_device,
+    sequential,
+)
 from litgpt.model import GPT, Block
 from litgpt.scripts.download import download_from_hub
 from litgpt.utils import _RunIf
@@ -28,8 +32,8 @@
         (6, 1, {0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0}),
         (6, 2, {0: 0, 1: 0, 2: 0, 3: 1, 4: 1, 5: 1}),
         (6, 3, {0: 0, 1: 0, 2: 1, 3: 1, 4: 2, 5: 2}),
-        (6, 4, {0: 0, 1: 0, 2: 1, 3: 1, 4: 2, 5: 2}),
-        (6, 5, {0: 0, 1: 0, 2: 1, 3: 1, 4: 2, 5: 2}),
+        (6, 4, {0: 0, 1: 1, 2: 2, 3: 2, 4: 3, 5: 3}),
+        (6, 5, {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 4}),
         (6, 6, {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5}),
     ],
 )
@@ -37,28 +41,12 @@ def test_layer_to_device(n_layer, devices, expected):
     with torch.device("meta"):
         model = GPT.from_name("pythia-14m", n_layer=n_layer)
 
-    max_layers_per_device = math.ceil(n_layer / devices)
-    actual = layer_to_device(model, Block, chunk_size=max_layers_per_device)
+    c_sizes = chunk_sizes(n_layer, devices)
+    actual = layer_to_device(model, Block, chunk_sizes=c_sizes)
     expected = {f"transformer.h.{i}": v for i, v in expected.items()}
     assert actual == expected
 
 
-def test_sequential_layer_to_device_mapping_not_possible():
-    # Fewer layers than devices
-    config = Config(n_layer=1)
-    with torch.device("meta"):
-        model = GPT(config)
-    with pytest.raises(ValueError, match="number of layers in the model must be larger than the number of devices"):
-        sequential(model, root=torch.device("cpu"), max_seq_length=128, devices=2)
-
-    # Last device would get 0 layers
-    config = Config(n_layer=6)
-    with torch.device("meta"):
-        model = GPT(config)
-    with pytest.raises(RuntimeError, match="Not able to distribute the 6 layers across 4 devices"):
-        sequential(model, root=torch.device("cpu"), max_seq_length=128, devices=4)
-
-
 def path_to_device(model):
     return {k: str(v.device) for k, v in itertools.chain(model.named_parameters(), model.named_buffers())}
 
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 1ddbf5588e..6d1c1091be 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -16,7 +16,7 @@ def test_cli():
     out = out.getvalue()
     assert "usage: litgpt" in out
     assert (
-        "{download,chat,finetune,finetune_lora,finetune_lora_legacy,finetune_full,finetune_adapter,finetune_adapter_v2,"
+        "{download,chat,finetune,finetune_lora,finetune_full,finetune_adapter,finetune_adapter_v2,"
         "pretrain,generate,generate_full,generate_adapter,generate_adapter_v2,generate_sequentially,"
        "generate_speculatively,generate_tp,convert_to_litgpt,convert_from_litgpt,convert_pretrained_checkpoint,"
         "merge_lora,evaluate,serve}" in out
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 4f41366ca8..3a04c6ed57 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -19,6 +19,7 @@
 
 from litgpt import GPT
 from litgpt.args import TrainArgs
+from litgpt.parser_config import save_hyperparameters
 from litgpt.utils import (
     CLI,
     CycleIterator,
@@ -39,7 +40,6 @@
     instantiate_torch_optimizer,
     num_parameters,
     parse_devices,
-    save_hyperparameters,
     select_sft_generate_example,
 )
 
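Not part of the patch: the removed `test_sequential_layer_to_device_mapping_not_possible` covered impossible partitions; under the new scheme, calling `layer_to_device` with chunk sizes that do not cover every block now raises a `ValueError`, e.g.:

```python
import pytest
import torch

from litgpt.generate.sequentially import layer_to_device
from litgpt.model import GPT, Block

with torch.device("meta"):
    model = GPT.from_name("pythia-14m", n_layer=6)

with pytest.raises(ValueError, match="not covered by chunk_sizes"):
    layer_to_device(model, Block, chunk_sizes=[2, 2])  # covers only 4 of the 6 blocks
```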