Skip to content

Commit

Permalink
Merge pull request #26 from twosigma/remove-is-node-optimization
Browse files Browse the repository at this point in the history
Remove unnecessary _is_node optimization and clean up type hints.
  • Loading branch information
daniel-shields authored Aug 30, 2024
2 parents a8c0dcc + d4ce856 commit 952105d
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 24 deletions.
11 changes: 3 additions & 8 deletions src/uberjob/_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,6 @@
}


def _is_node(value) -> bool:
"""Efficiently determines whether the given value is a :class:`~uberjob.graph.Node`."""
return type(value) in (Call, Literal)


class Plan:
"""Represents a symbolic call graph."""

Expand Down Expand Up @@ -90,7 +85,7 @@ def lit(self, value) -> Literal:
:param value: The literal value.
:return: The symbolic literal value.
"""
if _is_node(value):
if isinstance(value, Node):
raise TypeError(f"The value is already a {Node.__name__}.")
literal = Literal(value, scope=self._scope)
self.graph.add_node(literal)
Expand Down Expand Up @@ -118,12 +113,12 @@ def recurse(root):
if gather_fn is not None:
items = root.items() if root_type is dict else root
children = [recurse(item) for item in items]
if any(_is_node(child) for child in children):
if any(isinstance(child, Node) for child in children):
return self._call(stack_frame, gather_fn, *children)
return root

value = recurse(value)
return value if _is_node(value) else self.lit(value)
return value if isinstance(value, Node) else self.lit(value)

def gather(self, value) -> Node:
"""
Expand Down
6 changes: 3 additions & 3 deletions src/uberjob/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
#
import copy
import typing
from collections.abc import Iterable, KeysView

from uberjob._builtins import source
from uberjob._plan import Node, Plan
Expand Down Expand Up @@ -99,7 +99,7 @@ def get(self, node: Node) -> ValueStore | None:
v = self.mapping.get(node)
return v.value_store if v else None

def keys(self) -> typing.KeysView[Node]:
def keys(self) -> KeysView[Node]:
"""
Get all registered :class:`~uberjob.graph.Node` instances.
Expand All @@ -123,7 +123,7 @@ def items(self) -> list[tuple[Node, ValueStore]]:
"""
return [(k, v.value_store) for k, v in self.mapping.items()]

def __iter__(self) -> typing.Iterable[Node]:
def __iter__(self) -> Iterable[Node]:
"""
Get all registered :class:`~uberjob.graph.Node` instances.
Expand Down
11 changes: 5 additions & 6 deletions src/uberjob/_rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import typing
from collections import OrderedDict
from collections.abc import Callable

from uberjob._plan import Plan
from uberjob._registry import Registry
Expand All @@ -35,7 +34,7 @@
GRAY = (0.4, 0.4, 0.4)


def default_style(registry: Registry = None):
def default_style(registry: Registry | None = None):
import nxv

if registry is None:
Expand Down Expand Up @@ -132,8 +131,8 @@ class Scope:
def render(
plan: Plan | Graph | tuple[Plan, Node | None],
*,
registry: Registry = None,
predicate: typing.Callable[[Node, dict], bool] = None,
registry: Registry | None = None,
predicate: None | Callable[[Node, dict], bool] = None,
level: int | None = None,
format: str | None = None
) -> bytes | None:
Expand Down Expand Up @@ -165,7 +164,7 @@ def render(
)

if level is not None:
scope_groups = OrderedDict()
scope_groups = {}
for u in graph.nodes():
scope = u.scope
if scope:
Expand Down
4 changes: 2 additions & 2 deletions src/uberjob/_testing/test_mounted_file_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
#
import datetime as dt
import typing
from collections.abc import Callable

from uberjob._testing.test_store import TestStore
from uberjob._util import repr_helper
Expand All @@ -23,7 +23,7 @@


class TestMountedFileStore(MountedStore):
def __init__(self, create_file_store: typing.Callable[[str], FileStore]):
def __init__(self, create_file_store: Callable[[str], FileStore]):
super().__init__(create_file_store)
self.remote_store = TestStore()

Expand Down
6 changes: 3 additions & 3 deletions src/uberjob/_util/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
#
import inspect
import typing
from collections.abc import Callable
from functools import lru_cache

from uberjob._util import fully_qualified_name
Expand Down Expand Up @@ -57,14 +57,14 @@ def assert_is_instance(


@lru_cache(4096)
def try_get_signature(fn: typing.Callable):
def try_get_signature(fn: Callable):
try:
return inspect.signature(fn)
except ValueError:
return None


def assert_can_bind(fn: typing.Callable, *args, **kwargs):
def assert_can_bind(fn: Callable, *args, **kwargs):
sig = try_get_signature(fn)
if sig is None:
return
Expand Down
4 changes: 2 additions & 2 deletions src/uberjob/stores/_mounted_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
#
import os
import tempfile
import typing
from abc import ABC, abstractmethod
from collections.abc import Callable
from contextlib import contextmanager

from uberjob._util import repr_helper
Expand All @@ -38,7 +38,7 @@ class MountedStore(ValueStore, ABC):

__slots__ = ("create_store",)

def __init__(self, create_store: typing.Callable[[str], ValueStore]):
def __init__(self, create_store: Callable[[str], ValueStore]):
self.create_store = create_store
"""Creates an instance of the underlying :class:`~uberjob.ValueStore` for the given path."""

Expand Down

0 comments on commit 952105d

Please sign in to comment.