diff --git a/clrs/__init__.py b/clrs/__init__.py index bf44d5cb..e30d5d02 100644 --- a/clrs/__init__.py +++ b/clrs/__init__.py @@ -16,21 +16,29 @@ """The CLRS Algorithmic Reasoning Benchmark.""" from clrs import models + from clrs._src import algorithms +from clrs._src import clrs_text from clrs._src import decoders from clrs._src import processors + from clrs._src.dataset import chunkify from clrs._src.dataset import CLRSDataset from clrs._src.dataset import create_chunked_dataset from clrs._src.dataset import create_dataset from clrs._src.dataset import get_clrs_folder from clrs._src.dataset import get_dataset_gcp_url + from clrs._src.evaluation import evaluate from clrs._src.evaluation import evaluate_hints + from clrs._src.model import Model + from clrs._src.probing import DataPoint from clrs._src.probing import predecessor_to_cyclic_predecessor_and_first + from clrs._src.processors import get_processor_factory + from clrs._src.samplers import build_sampler from clrs._src.samplers import CLRS30 from clrs._src.samplers import Features @@ -40,6 +48,7 @@ from clrs._src.samplers import process_random_pos from clrs._src.samplers import Sampler from clrs._src.samplers import Trajectory + from clrs._src.specs import ALGO_IDX_INPUT_NAME from clrs._src.specs import CLRS_30_ALGS_SETTINGS from clrs._src.specs import Location diff --git a/clrs/_src/clrs_text/__init__.py b/clrs/_src/clrs_text/__init__.py new file mode 100644 index 00000000..6ab5403f --- /dev/null +++ b/clrs/_src/clrs_text/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The CLRS Text Algorithmic Reasoning Benchmark.""" + +from clrs._src.clrs_text import clrs_utils diff --git a/clrs/_src/clrs_text/clrs_utils.py b/clrs/_src/clrs_text/clrs_utils.py index ea790032..3289e529 100644 --- a/clrs/_src/clrs_text/clrs_utils.py +++ b/clrs/_src/clrs_text/clrs_utils.py @@ -14,8 +14,8 @@ # ============================================================================== """Functions to create text versions of CLRS data.""" from typing import Any, Optional - -import clrs +from clrs._src import samplers +from clrs._src import specs import numpy as np @@ -78,7 +78,7 @@ def format_clrs_example( algo: str, - sample: clrs.Feedback, + sample: samplers.Feedback, use_hints: bool = False, ) -> tuple[str, str]: """Formats CLRS example into prompt for the LLM. @@ -112,7 +112,7 @@ def format_clrs_example( def _get_output_names( algo_name: str, - spec: clrs.Spec, + spec: specs.Spec, use_hints: bool, ) -> list[str]: """Gets the output names for a CLRS algorithm.""" @@ -124,12 +124,12 @@ def _get_output_names( return [ spec_name for spec_name in spec - if spec[spec_name][0] == clrs.Stage.OUTPUT + if spec[spec_name][0] == specs.Stage.OUTPUT ] def _get_output_str( - sample: clrs.Feedback, spec, algo_name: str, use_hints: bool + sample: samplers.Feedback, spec, algo_name: str, use_hints: bool ) -> list[str]: """Gets the output string for a CLRS algorithm.""" if algo_name in CLRS_SEARCH_TAKS_OUTPUT_REPLACER and use_hints: @@ -157,7 +157,7 @@ def _get_output_str( def sample_to_str( algo: str, - sample: clrs.Feedback, + sample: samplers.Feedback, use_hints: bool = False, ) -> tuple[str, str, str, bool]: """Converts a CLRS sample into input and output strings. @@ -206,7 +206,7 @@ def sample_to_str( Returns: A 3-tuple of (input, output_names, output) strings. """ - spec = clrs.SPECS[algo] + spec = specs.SPECS[algo] # Create input prompt. input_strs = _create_input_feature_strs(spec, sample.features.inputs) @@ -252,8 +252,8 @@ def sample_to_str( def _create_input_feature_strs( - spec: clrs.Spec, - inputs: clrs.Features, + spec: specs.Spec, + inputs: samplers.Features, ) -> list[str]: """Extracts input features and convert them into strings.""" input_strs = [] @@ -261,7 +261,7 @@ def _create_input_feature_strs( stage, _, _ = spec[spec_name] # (stage, location, type) - if stage != clrs.Stage.INPUT: + if stage != specs.Stage.INPUT: continue if _do_not_include_input_in_text(spec_name, spec): @@ -279,16 +279,16 @@ def _create_input_feature_strs( def _create_output_feature_strs( - spec: clrs.Spec, - inputs: clrs.Features, - outputs: clrs.Features, + spec: specs.Spec, + inputs: samplers.Features, + outputs: samplers.Features, ) -> list[str]: """Extracts output features and convert them into strings.""" output_strs = [] for spec_name in spec: stage, _, _ = spec[spec_name] - if stage != clrs.Stage.OUTPUT: + if stage != specs.Stage.OUTPUT: continue x = _get_feature_by_name(outputs, spec_name).data @@ -339,9 +339,9 @@ def _format_hint(hints: list[str], algo_name: str) -> str: def _create_hint_feature_strs( algo_name: str, - spec: clrs.Spec, - inputs: clrs.Features, - hints: clrs.Features, + spec: specs.Spec, + inputs: samplers.Features, + hints: samplers.Features, output_names: list[str], ) -> tuple[str, str, bool]: """Extracts hint features and convert them into strings.""" @@ -405,10 +405,10 @@ def _create_hint_feature_strs( def _feature_to_str( name: str, - spec: clrs.Spec, + spec: specs.Spec, x: np.ndarray, with_name: bool, - inputs: Optional[clrs.Features] = None, + inputs: Optional[samplers.Features] = None, edge_masks_as_edge_list: bool = False, ) -> str: """Converts a numerical CLRS feature into a string.""" @@ -421,7 +421,7 @@ def _feature_to_str( x = x[0] unused_stage, location, typ_ = spec[name] match location: - case clrs.Location.NODE: + case specs.Location.NODE: output = _convert_node_features_to_str( x=x, spec_name=name, @@ -429,14 +429,14 @@ def _feature_to_str( spec_type=typ_, inputs=inputs, ) - case clrs.Location.GRAPH: + case specs.Location.GRAPH: output = _convert_graph_features_to_str( x=x, spec_name=name, spec=spec, spec_type=typ_, ) - case clrs.Location.EDGE: + case specs.Location.EDGE: output = _convert_edge_features_to_str( x=x, spec_name=name, @@ -469,13 +469,13 @@ def predecessors_to_order(x: np.ndarray) -> np.ndarray: def _convert_node_features_to_str( x: np.ndarray, spec_name: str, - spec: clrs.Spec, + spec: specs.Spec, spec_type: str, - inputs: Optional[clrs.Features] = None, + inputs: Optional[samplers.Features] = None, ) -> str: """Converts node features into string.""" match spec_type: - case clrs.Type.SHOULD_BE_PERMUTATION: + case specs.Type.SHOULD_BE_PERMUTATION: # For the text version of CLRS, if the output is a permutation, we present # the "key" input values in the order given by the permutation. nonsorted_values = _get_feature_by_name(inputs, 'key').data[0] @@ -488,15 +488,15 @@ def _convert_node_features_to_str( SEQUENCE_SEPARATOR.join([f'{scalar:.3g}' for scalar in sorted_values]) ) - case clrs.Type.MASK_ONE: + case specs.Type.MASK_ONE: [index] = x.nonzero()[0] return f'{index}' - case clrs.Type.SCALAR: + case specs.Type.SCALAR: return _bracket(SEQUENCE_SEPARATOR.join([f'{a:.3g}' for a in x])) - case clrs.Type.MASK | clrs.Type.POINTER | clrs.Type.CATEGORICAL: - if spec_type == clrs.Type.CATEGORICAL: + case specs.Type.MASK | specs.Type.POINTER | specs.Type.CATEGORICAL: + if spec_type == specs.Type.CATEGORICAL: categories = np.argmax(x, axis=-1) int_output = categories else: @@ -510,20 +510,24 @@ def _convert_node_features_to_str( def _convert_graph_features_to_str( x: np.ndarray, spec_name: str, - spec: clrs.Spec, + spec: specs.Spec, spec_type: str, ) -> str: """Converts graph features into string.""" match spec_type: - case clrs.Type.SCALAR: + case specs.Type.SCALAR: return f'{x:.3f}' - case clrs.Type.CATEGORICAL: + case specs.Type.CATEGORICAL: categories = np.argmax(x, axis=-1) return f'{categories}' case _: - if spec_type in [clrs.Type.MASK, clrs.Type.MASK_ONE, clrs.Type.POINTER]: + if spec_type in [ + specs.Type.MASK, + specs.Type.MASK_ONE, + specs.Type.POINTER, + ]: return f'{x.astype(int)}' else: raise KeyError(f'Feature type not supported in spec {spec[spec_name]}') @@ -532,24 +536,24 @@ def _convert_graph_features_to_str( def _convert_edge_features_to_str( x: np.ndarray, spec_name: str, - spec: clrs.Spec, + spec: specs.Spec, spec_type: str, edge_masks_as_edge_list: bool, ): """Converts edge features into string.""" if edge_masks_as_edge_list: - if spec_type == clrs.Type.MASK or ( - spec_type == clrs.Type.SCALAR and _is_binary(x) + if spec_type == specs.Type.MASK or ( + spec_type == specs.Type.SCALAR and _is_binary(x) ): edges = list(zip(*np.nonzero(x > 0))) return DEFAULT_SEPARATOR.join([f'({x},{y})' for x, y in edges]) else: match spec_type: - case clrs.Type.POINTER | clrs.Type.MASK | clrs.Type.CATEGORICAL: - if spec_type == clrs.Type.CATEGORICAL: + case specs.Type.POINTER | specs.Type.MASK | specs.Type.CATEGORICAL: + if spec_type == specs.Type.CATEGORICAL: # lcs_length includes masked elements where the category is -1 - mask = np.any(x == clrs.OutputClass.MASKED, axis=-1) + mask = np.any(x == specs.OutputClass.MASKED, axis=-1) categories = np.argmax(x, axis=-1) categories[mask] = -1 int_output = categories @@ -562,14 +566,14 @@ def _convert_edge_features_to_str( ), ) - case clrs.Type.SCALAR: + case specs.Type.SCALAR: row_to_str = lambda r: _bracket(' '.join([f'{a:.3g}' for a in r])) return _bracket(DEFAULT_SEPARATOR.join([row_to_str(r) for r in x])) raise KeyError(f'Feature type not supported in spec {spec[spec_name]}') -def _get_feature_by_name(examples: clrs.Features, spec_name: str) -> Any: +def _get_feature_by_name(examples: samplers.Features, spec_name: str) -> Any: filtered_inputs = [ example for example in examples if example.name == spec_name ] @@ -590,7 +594,7 @@ def _bracket(s: str) -> str: return f'[{s}]' -def _do_not_include_input_in_text(spec_name: str, spec: clrs.Spec) -> bool: +def _do_not_include_input_in_text(spec_name: str, spec: specs.Spec) -> bool: if spec_name == 'pos': return True if spec_name == 'adj' and 'A' in spec: