
Commit 0105829

Merge pull request #560 from yaoguany/main
Allow exceeding the model maximum length during training & inference
2 parents 1449cd4 + 78159d0

4 files changed: +40 -14 lines


src/lmflow/args.py

Lines changed: 8 additions & 0 deletions

@@ -198,6 +198,14 @@ class ModelArguments:
             )
         }
     )
+    truncate_to_model_max_length: bool = field(
+        default=True,
+        metadata={
+            "help": (
+                "whether to truncate the dataset to the model max length."
+            )
+        }
+    )
     use_int8: bool = field(
         default=False,
         metadata={"help": "whether to load int8 quantization for inference"}

src/lmflow/models/hf_decoder_model.py

Lines changed: 3 additions & 1 deletion

@@ -248,7 +248,9 @@ def __init__(
         # We resize the embeddings only when necessary to avoid index errors.
         # If you are creating a model from scratch on a small vocab and want a
         # smaller embedding size, remove this test.
-        embedding_size = model.get_input_embeddings().weight.shape[0]
+        with deepspeed.zero.GatheredParameters(model.get_input_embeddings().weight, modifier_rank=None):
+            weights = model.get_input_embeddings().weight
+            embedding_size = weights.shape[0]
         if len(tokenizer) > embedding_size:
             model.resize_token_embeddings(len(tokenizer))
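
Why the gather is needed: under DeepSpeed ZeRO stage 3 the parameter tensors are partitioned across ranks, so reading `.weight.shape` outside a gather can report the local placeholder shape instead of the full vocabulary dimension. `deepspeed.zero.GatheredParameters` temporarily materializes the full parameter inside the `with` block, and `modifier_rank=None` declares read-only access; when the model is not ZeRO-managed the context manager is a no-op. A runnable sketch of the same pattern (gpt2 is used purely as a small stand-in checkpoint):

    import deepspeed
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # No-op without ZeRO-3, so the same code path works for both sharded
    # and unsharded runs.
    with deepspeed.zero.GatheredParameters(
        model.get_input_embeddings().weight, modifier_rank=None
    ):
        embedding_size = model.get_input_embeddings().weight.shape[0]

    if len(tokenizer) > embedding_size:
        model.resize_token_embeddings(len(tokenizer))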

src/lmflow/pipeline/evaluator.py

Lines changed: 8 additions & 5 deletions

@@ -323,11 +323,14 @@ def _evaluate_ppl(self, model, dataset: Dataset, verbose=True):
         texts = [ instance["text"] for instance in data_dict["instances"] ]
         encodings = model.get_tokenizer()("\n\n".join(texts), return_tensors="pt")
         # Define some constant
-        try:
-            max_length = min(model.get_backend_model().config.n_positions, model.get_max_length())
-        except:
-            max_length = min(1024, model.get_max_length())
-
+        if self.model_args.truncate_to_model_max_length:
+            try:
+                max_length = min(model.get_backend_model().config.n_positions, model.get_max_length())
+            except:
+                max_length = min(1024, model.get_max_length())
+        else:
+            max_length = self.block_size
+
         if verbose:
             print(f"The maximum sequence length : {max_length}")
         seq_len = encodings.input_ids.size(1)
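
Restated as a standalone helper (hypothetical, not part of the repo), the selection logic reads: with truncation enabled, prefer the config's `n_positions` (present on GPT-2-style configs, absent elsewhere, hence the try/except) capped by the tokenizer's maximum, falling back to 1024; with truncation disabled, trust the user-supplied block size.

    # Hypothetical helper mirroring the branch above; `model` is assumed to be
    # an lmflow model wrapper exposing get_backend_model() and get_max_length().
    def resolve_max_length(model, truncate_to_model_max_length: bool,
                           block_size: int) -> int:
        if truncate_to_model_max_length:
            try:
                # GPT-2-style configs expose n_positions; others raise here.
                return min(model.get_backend_model().config.n_positions,
                           model.get_max_length())
            except AttributeError:
                return min(1024, model.get_max_length())
        return block_size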

src/lmflow/pipeline/finetuner.py

Lines changed: 21 additions & 8 deletions

@@ -130,14 +130,27 @@ def group_text(self, tokenized_datasets, model_max_length):
                 block_size = 1024
         else:
             if data_args.block_size > model_max_length:
-                logger.warning(
-                    f"The block_size passed ({data_args.block_size}) is larger"
-                    f" than the maximum length for the model"
-                    f"({model_max_length})."
-                    f" Using block_size={model_max_length}."
-                )
-            block_size = min(data_args.block_size, model_max_length)
-
+                if self.model_args.truncate_to_model_max_length:
+                    logger.warning(
+                        f"The block_size passed ({data_args.block_size}) is larger"
+                        f" than the maximum length for the model"
+                        f" ({model_max_length})."
+                        f" Using block_size={model_max_length}."
+                        f" If you would like to use a block_size longer than"
+                        f" the maximum length supported by the model, you can"
+                        f" override this behavior with"
+                        f" `--truncate_to_model_max_length False`."
+                    )
+                    block_size = model_max_length
+                else:
+                    logger.warning(
+                        f"The block_size passed ({data_args.block_size}) is larger"
+                        f" than the maximum length for the model"
+                        f" ({model_max_length})."
+                        f" Using block_size={data_args.block_size}.")
+                    block_size = data_args.block_size
+            else:
+                block_size = data_args.block_size
         # Main data processing function that will concatenate all texts from
         # our dataset and generate chunks of block_size.
         def group_texts(examples):
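
For reference, the `group_texts` named in the trailing context follows the standard concatenate-then-chunk pattern from Hugging Face's run_clm example; a self-contained sketch of that pattern (not the exact repo code) shows why `block_size` is the quantity being guarded above: every training example ends up exactly `block_size` tokens long.

    # Sketch of the standard group_texts pattern (adapted from the Hugging
    # Face run_clm example; the repo's version may differ in detail).
    def group_texts(examples, block_size=1024):
        # Concatenate every tokenized field (input_ids, attention_mask, ...).
        concatenated = {k: sum(examples[k], []) for k in examples.keys()}
        total_length = len(concatenated[list(examples.keys())[0]])
        # Drop the ragged tail so every chunk is exactly block_size tokens.
        total_length = (total_length // block_size) * block_size
        result = {
            k: [t[i:i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated.items()
        }
        # Causal LM training predicts the next token, so labels mirror inputs.
        result["labels"] = result["input_ids"].copy()
        return result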
