2 changes: 1 addition & 1 deletion extensions/thunder/pretrain.py
@@ -24,6 +24,7 @@
from litgpt.args import EvalArgs, LogArgs, TrainArgs
from litgpt.data import DataModule, TinyLlama
from litgpt.model import GPT, Block, CausalSelfAttention, Config, LLaMAMLP, MultiheadLatentAttention
from litgpt.parser_config import save_hyperparameters
from litgpt.utils import (
CLI,
CycleIterator,
@@ -37,7 +38,6 @@
parse_devices,
reset_parameters,
save_config,
save_hyperparameters,
)

# support running without installing as a package
59 changes: 33 additions & 26 deletions litgpt/__main__.py
@@ -12,14 +12,14 @@
from litgpt.finetune.adapter_v2 import setup as finetune_adapter_v2_fn
from litgpt.finetune.full import setup as finetune_full_fn
from litgpt.finetune.lora import setup as finetune_lora_fn
from litgpt.finetune.lora_legacy import setup as finetune_lora_legacy_fn
from litgpt.generate.adapter import main as generate_adapter_fn
from litgpt.generate.adapter_v2 import main as generate_adapter_v2_fn
from litgpt.generate.base import main as generate_base_fn
from litgpt.generate.full import main as generate_full_fn
from litgpt.generate.sequentially import main as generate_sequentially_fn
from litgpt.generate.speculative_decoding import main as generate_speculatively_fn
from litgpt.generate.tp import main as generate_tp_fn
from litgpt.parser_config import parser_commands
from litgpt.pretrain import setup as pretrain_fn
from litgpt.scripts.convert_hf_checkpoint import convert_hf_checkpoint as convert_hf_checkpoint_fn
from litgpt.scripts.convert_lit_checkpoint import convert_lit_checkpoint as convert_lit_checkpoint_fn
@@ -29,32 +29,39 @@
from litgpt.scripts.download import download_from_hub as download_fn
from litgpt.scripts.merge_lora import merge_lora as merge_lora_fn

PARSER_DATA = {
"download": download_fn,
"chat": chat_fn,
"finetune": finetune_lora_fn,
"finetune_lora": finetune_lora_fn,
"finetune_full": finetune_full_fn,
"finetune_adapter": finetune_adapter_fn,
"finetune_adapter_v2": finetune_adapter_v2_fn,
"pretrain": pretrain_fn,
"generate": generate_base_fn,
"generate_full": generate_full_fn,
"generate_adapter": generate_adapter_fn,
"generate_adapter_v2": generate_adapter_v2_fn,
"generate_sequentially": generate_sequentially_fn,
"generate_speculatively": generate_speculatively_fn,
"generate_tp": generate_tp_fn,
"convert_to_litgpt": convert_hf_checkpoint_fn,
"convert_from_litgpt": convert_lit_checkpoint_fn,
"convert_pretrained_checkpoint": convert_pretrained_checkpoint_fn,
"merge_lora": merge_lora_fn,
"evaluate": evaluate_fn,
"serve": serve_fn,
}


def _check_commands():
assert set(parser_commands()) == set(PARSER_DATA.keys()), (
"PARSER_DATA has to be kept in sync with litgpt.parser_config.parser_commands()"
)


def main() -> None:
parser_data = {
"download": download_fn,
"chat": chat_fn,
"finetune": finetune_lora_fn,
"finetune_lora": finetune_lora_fn,
"finetune_lora_legacy": finetune_lora_legacy_fn,
"finetune_full": finetune_full_fn,
"finetune_adapter": finetune_adapter_fn,
"finetune_adapter_v2": finetune_adapter_v2_fn,
"pretrain": pretrain_fn,
"generate": generate_base_fn,
"generate_full": generate_full_fn,
"generate_adapter": generate_adapter_fn,
"generate_adapter_v2": generate_adapter_v2_fn,
"generate_sequentially": generate_sequentially_fn,
"generate_speculatively": generate_speculatively_fn,
"generate_tp": generate_tp_fn,
"convert_to_litgpt": convert_hf_checkpoint_fn,
"convert_from_litgpt": convert_lit_checkpoint_fn,
"convert_pretrained_checkpoint": convert_pretrained_checkpoint_fn,
"merge_lora": merge_lora_fn,
"evaluate": evaluate_fn,
"serve": serve_fn,
}
_check_commands()

set_docstring_parse_options(attribute_docstrings=True)
set_config_read_mode(urls_enabled=True)
@@ -68,7 +75,7 @@ def main() -> None:
)

torch.set_float32_matmul_precision("high")
CLI(parser_data)
CLI(PARSER_DATA)


if __name__ == "__main__":
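A minimal sketch, not part of the diff, of the consistency check that the new module-level PARSER_DATA enables: the command names registered for the CLI have to match what litgpt.parser_config.parser_commands() advertises. The symmetric-difference formulation below is illustrative only; the PR itself asserts set equality inside _check_commands().

from litgpt.__main__ import PARSER_DATA
from litgpt.parser_config import parser_commands

# Any command present in only one of the two places shows up here and
# fails fast, before the CLI is built.
out_of_sync = set(parser_commands()) ^ set(PARSER_DATA)
assert not out_of_sync, f"PARSER_DATA and parser_commands() disagree on: {sorted(out_of_sync)}"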
49 changes: 42 additions & 7 deletions litgpt/config.py
@@ -68,8 +68,13 @@ class Config:
n_query_groups: Optional[int] = None
attn_bias: bool = False
attention_scores_scalar: Optional[int] = None
# If `sliding_window_size` is given, sliding window attention with this
# size is used in layers where `sliding_window_indices` has a 1. The
# default is all 1, so that sliding window attention is used in all
# layers. If `len(sliding_window_indices) > n_layer`, we only use the
# initial part.
sliding_window_size: Optional[int] = None
sliding_window_indices: Optional[List] = None
sliding_window_indices: Optional[List[int]] = None
# if `attention_logit_softcapping` is used, cannot use optimized
# `torch.nn.functional.scaled_dot_product_attention` (which implements
# Flash attention), may result in higher memory and runtime footprint.
@@ -102,9 +107,12 @@ class Config:
norm_2: bool = True
latent_attention: Optional[dict] = None
# The base period of the RoPE embeddings for local attention.
# If not provided, rope_theta will be used for both local and global attention.
# If not provided, `rope_base` will be used for both local and global attention.
rope_local_base_freq: Optional[float] = None
rope_indices: Optional[List] = None
# If provided, must have `>= n_layer` entries, either 0 or 1. For 0,
# `rope_base` is used, for 1 `rope_local_base_freq` is used. If
# `len(rope_indices) > n_layer`, we only use the initial part.
rope_indices: Optional[List[int]] = None

def __post_init__(self):
if not self.name:
@@ -135,11 +143,19 @@ def __post_init__(self):

self.rope_n_elem = int(self.rotary_percentage * self.head_size)

if self.sliding_window_size is not None and self.sliding_window_indices is None:
self.sliding_window_indices = [1] * self.n_layer
if self.sliding_window_size is not None:
self.sliding_window_indices = check_indicator_and_length(
self.sliding_window_indices,
name="sliding_window_indices",
required_length=self.n_layer,
)

if self.rope_local_base_freq is not None and self.rope_indices is None:
self.rope_indices = [1] * self.n_layer
if self.rope_local_base_freq is not None:
self.rope_indices = check_indicator_and_length(
self.rope_indices,
name="rope_indices",
required_length=self.n_layer,
)

if self.latent_attention is not None:
self.q_lora_rank = self.latent_attention.get("q_lora_rank")
@@ -232,6 +248,25 @@ def norm_class(self) -> Type:
return getattr(torch.nn, self.norm_class_name)


def check_indicator_and_length(
params: Optional[List[int]],
name: str,
required_length: int,
use_initial_part: bool = True,
def_val: int = 1,
) -> List[int]:
if params is None:
return [def_val] * required_length
if len(params) != required_length:
if use_initial_part and len(params) > required_length:
params = params[:required_length]
else:
raise ValueError(f"{name} = {params}, must have length {required_length}")
if not set(params).issubset({0, 1}):
raise ValueError(f"{name} = {params}, must only contain 0 and 1")
return params


########################
# Stability AI StableLM
########################
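A minimal usage sketch of the new check_indicator_and_length helper added to litgpt/config.py above; the layer count and indicator values are made up for illustration and assume this branch of litgpt is importable.

from litgpt.config import check_indicator_and_length

n_layer = 4

# None falls back to the default indicator (all ones), i.e. sliding window
# attention or the local RoPE base is used in every layer.
assert check_indicator_and_length(None, name="sliding_window_indices", required_length=n_layer) == [1, 1, 1, 1]

# A list longer than n_layer is truncated to its initial part.
assert check_indicator_and_length([1, 0, 1, 0, 1, 1], name="rope_indices", required_length=n_layer) == [1, 0, 1, 0]

# A list that is too short, or that contains values other than 0 and 1,
# raises a ValueError.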
6 changes: 5 additions & 1 deletion litgpt/data/base.py
@@ -17,7 +17,11 @@ class DataModule(LightningDataModule):

@abstractmethod
def connect(
self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None
self,
tokenizer: Optional[Tokenizer] = None,
batch_size: int = 1,
max_seq_length: Optional[int] = None,
**kwargs,
) -> None:
"""All settings that can't be determined at the time of instantiation need to be passed through here
before any dataloaders can be accessed.
2 changes: 1 addition & 1 deletion litgpt/finetune/adapter.py
@@ -20,6 +20,7 @@
from litgpt.args import EvalArgs, LogArgs, TrainArgs
from litgpt.data import Alpaca, DataModule
from litgpt.generate.base import generate
from litgpt.parser_config import save_hyperparameters
from litgpt.prompts import save_prompt_style
from litgpt.tokenizer import Tokenizer
from litgpt.utils import (
@@ -39,7 +40,6 @@
load_checkpoint,
num_parameters,
parse_devices,
save_hyperparameters,
select_sft_generate_example,
)

2 changes: 1 addition & 1 deletion litgpt/finetune/adapter_v2.py
@@ -20,6 +20,7 @@
from litgpt.args import EvalArgs, LogArgs, TrainArgs
from litgpt.data import Alpaca, DataModule
from litgpt.generate.base import generate
from litgpt.parser_config import save_hyperparameters
from litgpt.prompts import save_prompt_style
from litgpt.tokenizer import Tokenizer
from litgpt.utils import (
@@ -40,7 +41,6 @@
load_checkpoint_update,
num_parameters,
parse_devices,
save_hyperparameters,
select_sft_generate_example,
)

2 changes: 1 addition & 1 deletion litgpt/finetune/full.py
@@ -17,6 +17,7 @@
from litgpt.data import Alpaca, DataModule
from litgpt.generate.base import generate
from litgpt.model import GPT, Block, Config
from litgpt.parser_config import save_hyperparameters
from litgpt.prompts import save_prompt_style
from litgpt.tokenizer import Tokenizer
from litgpt.utils import (
@@ -35,7 +36,6 @@
load_checkpoint,
num_parameters,
parse_devices,
save_hyperparameters,
select_sft_generate_example,
)

2 changes: 1 addition & 1 deletion litgpt/finetune/lora.py
@@ -20,6 +20,7 @@
from litgpt.data import Alpaca, DataModule
from litgpt.generate.base import generate
from litgpt.lora import GPT, Block, Config, mark_only_lora_as_trainable
from litgpt.parser_config import save_hyperparameters
from litgpt.prompts import save_prompt_style
from litgpt.scripts.merge_lora import merge_lora
from litgpt.tokenizer import Tokenizer
@@ -40,7 +41,6 @@
load_checkpoint,
num_parameters,
parse_devices,
save_hyperparameters,
select_sft_generate_example,
)

2 changes: 1 addition & 1 deletion litgpt/finetune/lora_legacy.py
@@ -20,6 +20,7 @@
from litgpt.data import Alpaca, DataModule
from litgpt.generate.base import generate
from litgpt.lora import GPT, Block, Config, lora_filter, mark_only_lora_as_trainable
from litgpt.parser_config import save_hyperparameters
from litgpt.prompts import save_prompt_style
from litgpt.scripts.merge_lora import merge_lora
from litgpt.tokenizer import Tokenizer
@@ -40,7 +41,6 @@
load_checkpoint,
num_parameters,
parse_devices,
save_hyperparameters,
select_sft_generate_example,
)

43 changes: 24 additions & 19 deletions litgpt/generate/sequentially.py
@@ -2,7 +2,6 @@

import itertools
import logging
import math
import re
import sys
import time
@@ -11,7 +10,7 @@
from functools import partial
from pathlib import Path
from pprint import pprint
from typing import Literal, Optional, Type
from typing import List, Literal, Optional, Type

import lightning as L
import torch
@@ -41,18 +40,12 @@ def sequential(model: GPT, root: torch.device, max_seq_length: int, devices: int
f" n_layer={model.config.n_layer} and devices={devices}."
)

# The last device might get fewer layers if number of layers not evenly divisible by device count
max_layers_per_device = math.ceil(model.config.n_layer / devices)
# dictates where each block should be instantiated
mapping = layer_to_device(model, chunk_on=Block, chunk_size=max_layers_per_device)

if set(mapping.values()) != set(range(devices)):
# TODO: support smarter partitioning schemes
raise RuntimeError(
f"Not able to distribute the {model.config.n_layer} layers across {devices} devices."
" Try running with a lower number of devices."
)

# Dictates where each block should be instantiated
mapping = layer_to_device(
model,
chunk_on=Block,
chunk_sizes=chunk_sizes(model.config.n_layer, devices),
)
num_layers_per_device = {i: sum(1 for v in mapping.values() if v == i) for i in range(devices)}

# materialize each block on the appropriate device
@@ -100,13 +93,25 @@ def sequential(model: GPT, root: torch.device, max_seq_length: int, devices: int
return model


def chunk_sizes(num_units: int, devices: int) -> List[int]:
cs = num_units // devices
k = devices * (cs + 1) - num_units
return [cs] * k + [cs + 1] * (devices - k)


def layer_to_device(
module: torch.nn.Module, chunk_on: Type[torch.nn.Module], chunk_size: int
module: torch.nn.Module,
chunk_on: Type[torch.nn.Module],
chunk_sizes: List[int],
) -> "OrderedDict[str, int]":
"""Create a mapping from layer (block) to device."""
# this assumes that the definition order is the same as the execution order
hits = [name for name, submodule in module.named_modules() if isinstance(submodule, chunk_on)]
return OrderedDict((name, i // chunk_size) for i, name in enumerate(hits))
if sum(chunk_sizes) != len(hits):
raise ValueError(f"Found {len(hits)} for chunk_on={chunk_on}, not covered by chunk_sizes={chunk_sizes}")
_devices = [[d] * cs for d, cs in enumerate(chunk_sizes)]
devices = [d for lst in _devices for d in lst]
return OrderedDict(zip(hits, devices))


def move_block_input(device: torch.device, module: torch.nn.Module, ins):
@@ -263,9 +268,9 @@ def main(
for i in range(num_samples):
t0 = time.perf_counter()
y = generate_base.generate(
model,
encoded,
max_returned_tokens,
model=model,
prompt=encoded,
max_returned_tokens=max_returned_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
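A minimal sketch of the new chunk_sizes partitioning that sequential() now uses to assign transformer blocks to devices; the layer and device counts below are illustrative. Blocks are split as evenly as possible, and when n_layer is not divisible by the device count, the later devices each receive one extra block.

from litgpt.generate.sequentially import chunk_sizes

assert chunk_sizes(32, 4) == [8, 8, 8, 8]      # even split
assert chunk_sizes(32, 5) == [6, 6, 6, 7, 7]   # sums to 32; the last two devices hold one extra block
assert chunk_sizes(7, 3) == [2, 2, 3]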