From 77fdf8f61105fe0b6451afe147470465a5f6ca95 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Louf?=
Date: Wed, 16 Oct 2024 22:36:31 +0200
Subject: [PATCH] Time tree building for LFE

---
 src/benchmark_lfe.py | 14 ++++----------
 1 file changed, 4 insertions(+), 10 deletions(-)

diff --git a/src/benchmark_lfe.py b/src/benchmark_lfe.py
index 22a62cf..cdff195 100644
--- a/src/benchmark_lfe.py
+++ b/src/benchmark_lfe.py
@@ -23,23 +23,20 @@ def setup(self, model, _):
         self.tokenizer = AutoTokenizer.from_pretrained(
             model, clean_up_tokenization_spaces=True
         )
-        self.tokenizer_data = build_token_enforcer_tokenizer_data(self.tokenizer)
 
     def time_lfe(self, _, regex_name):
         regex_string = regex_cases[regex_name]["regex"]
         regex_samples = regex_cases[regex_name]["samples"]
 
         parser = RegexParser(regex_string)
-        token_enforcer = TokenEnforcer(self.tokenizer_data, parser)
+        tokenizer_data = build_token_enforcer_tokenizer_data(self.tokenizer)
+        token_enforcer = TokenEnforcer(tokenizer_data, parser)
 
         for regex_sample in regex_samples:
             regex_sample_tokens = self.tokenizer.encode(regex_sample)
             for i in range(len(regex_sample_tokens)):
                 _ = token_enforcer.get_allowed_tokens(regex_sample_tokens[: i + 1])
 
-    def teardown(self, *args):
-        del self.tokenizer_data
-
 
 class LMFormatEnforcerJsonSchema:
     params = [models, json_cases.keys()]
@@ -56,19 +53,16 @@ def setup(self, model, _):
         self.tokenizer = AutoTokenizer.from_pretrained(
             model, clean_up_tokenization_spaces=True
         )
-        self.tokenizer_data = build_token_enforcer_tokenizer_data(self.tokenizer)
 
     def time_lfe(self, _, json_schema_name):
         json_string = json_cases[json_schema_name]["schema"]
         json_samples = json_cases[json_schema_name]["samples"]
 
         parser = JsonSchemaParser(json_string)
-        token_enforcer = TokenEnforcer(self.tokenizer_data, parser)
+        tokenizer_data = build_token_enforcer_tokenizer_data(self.tokenizer)
+        token_enforcer = TokenEnforcer(tokenizer_data, parser)
 
         for json_sample in json_samples:
            json_sample_tokens = self.tokenizer.encode(json_sample)
             for i in range(len(json_sample_tokens)):
                 _ = token_enforcer.get_allowed_tokens(json_sample_tokens[: i + 1])
-
-    def teardown(self, *args):
-        del self.tokenizer_data
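
For context, a minimal standalone sketch of the code path this patch changes: tokenizer-data construction now runs inside the timed function, so each time_lfe call also pays for building lm-format-enforcer's token tree instead of having it precomputed in setup and discarded in teardown. The import paths, the "gpt2" model, and the regex/sample values below are illustrative assumptions, not values taken from the benchmark suite; only build_token_enforcer_tokenizer_data, TokenEnforcer, RegexParser, and get_allowed_tokens appear in the patched file itself.

# Sketch only: reproduces the measured region of the patched time_lfe for one regex case.
import time

from lmformatenforcer import RegexParser, TokenEnforcer
from lmformatenforcer.integrations.transformers import (
    build_token_enforcer_tokenizer_data,
)
from transformers import AutoTokenizer

# "gpt2" and the date regex/sample are stand-ins for the benchmark's model and case lists.
tokenizer = AutoTokenizer.from_pretrained("gpt2", clean_up_tokenization_spaces=True)
sample = "2024-10-16"

start = time.perf_counter()
parser = RegexParser(r"\d{4}-\d{2}-\d{2}")
# With this patch, tree building is part of the timed region rather than setup().
tokenizer_data = build_token_enforcer_tokenizer_data(tokenizer)
token_enforcer = TokenEnforcer(tokenizer_data, parser)
sample_tokens = tokenizer.encode(sample)
for i in range(len(sample_tokens)):
    _ = token_enforcer.get_allowed_tokens(sample_tokens[: i + 1])
elapsed = time.perf_counter() - start
print(f"one regex case: {elapsed:.3f}s (including tokenizer-data build)")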