Skip to content

Commit

Permalink
fix(server): fix llamav2 config (#635)
Browse files Browse the repository at this point in the history
  • Loading branch information
OlivierDehaene authored Jul 18, 2023
1 parent cf83f9b commit 5e6ddfd
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple

# Flash attention imports
Expand All @@ -43,6 +44,56 @@
)


class LlamaConfig(PretrainedConfig):
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
pretraining_tp=1,
tie_word_embeddings=False,
rope_scaling=None,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads

# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads

self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_scaling = rope_scaling

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)


class LlamaRMSNorm(nn.Module):
def __init__(self, prefix, weights, eps=1e-6):
"""
Expand Down
4 changes: 2 additions & 2 deletions server/text_generation_server/models/flash_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
import torch.distributed

from opentelemetry import trace
from transformers import AutoConfig
from transformers.models.llama import LlamaTokenizer, LlamaTokenizerFast
from typing import Optional

from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
LlamaConfig,
)
from text_generation_server.utils import (
initialize_torch_distributed,
Expand Down Expand Up @@ -52,7 +52,7 @@ def __init__(
trust_remote_code=trust_remote_code,
)

config = AutoConfig.from_pretrained(
config = LlamaConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)

Expand Down

0 comments on commit 5e6ddfd

Please sign in to comment.