Commit

cache tests
farazkh80 committed Dec 14, 2024
1 parent 5cc3fed commit 3a3b94a
Showing 8 changed files with 156 additions and 102 deletions.
1 change: 1 addition & 0 deletions tripy/docs/post0_developer_guides/debugging.md
@@ -8,6 +8,7 @@ This guide outlines some methods of doing so.

We include some environment variables to enable extra debugging information from MLIR-TRT:

- `export TRIPY_EAGER_CACHE=1` will enable eager caching for Tripy tensors, storing the compiled executable for each evaluated trace so it can be reused (see the example after this list).
- `export TRIPY_MLIR_DEBUG_ENABLED=1` will enable debug prints in MLIR-TRT and dump all intermediate IRs to a directory.
- `export TRIPY_MLIR_DEBUG_PATH=<mlir-debug-path>` sets the directory for IR dumps. The default path is `mlir-dumps`.
- `export TRIPY_TRT_DEBUG_ENABLED=1` will dump TensorRT engines and their layer information.
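
For example, a minimal sketch of exercising the eager cache (a hedged illustration: the cache-hit behavior noted in the comments is inferred from the tests added in this commit, not documented behavior):

```python
# Hedged sketch: enable the eager cache, then evaluate the same graph twice.
import os

os.environ["TRIPY_EAGER_CACHE"] = "1"  # assumed to be read at eval time

import tripy as tp

a = tp.Tensor([[1.0, 2.0], [3.0, 4.0]], dtype=tp.float32)
layer = tp.Linear(2, 3)

layer(a).eval()  # first eval compiles and stores the executable
layer(a).eval()  # an identical trace should now be served from the cache
```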
140 changes: 82 additions & 58 deletions tripy/tests/frontend/test_cache.py
@@ -1,79 +1,103 @@
# # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# # SPDX-License-Identifier: Apache-2.0
# #
# # Licensed under the Apache License, Version 2.0 (the "License");
# # you may not use this file except in compliance with the License.
# # You may obtain a copy of the License at
# #
# # http://www.apache.org/licenses/LICENSE-2.0
# #
# # Unless required by applicable law or agreed to in writing, software
# # distributed under the License is distributed on an "AS IS" BASIS,
# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# # See the License for the specific language governing permissions and
# # limitations under the License.
# import pytest
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import os

# import tripy as tp
import tripy as tp

# from tripy.frontend.trace import Trace
# from tripy.frontend.cache import ExecutableCache
from tripy.frontend.trace import Trace
from tripy.frontend.cache import ExecutableCache


# @pytest.fixture
# def cache():
# return ExecutableCache()
@pytest.fixture
def set_env(monkeypatch):
monkeypatch.setenv("TRIPY_EAGER_CACHE", "1")


# class TestCache:
# def test_get_nonexistent_key(self, cache):
# """Test getting a value for a nonexistent key."""
# cached_value = cache.get("nonexistent_key")
@pytest.fixture
def cache():
return ExecutableCache()

# assert cached_value is None, "Expected None for a nonexistent key"

# def test_normalize_key(self, cache):
# raw_trace =
# expected_key =
# assert cache._normalize_key(raw_key) == expected_key
@pytest.fixture
def mock_global_cache(monkeypatch, cache):
monkeypatch.setattr(tp.frontend.cache, "global_cache", cache)
return cache

# def test_same_operation_different_tensor(self, cache, monkeypatch):
# # Mock the global cache
# monkeypatch.setattr(tp.frontend, "global_cache", cache)

# tensor = tp.Tensor([[1.0, 2.0], [3.0, 4.0]], dtype=tp.float32)
class TestCache:
def test_identical_graph_different_input_shapes(self, mock_global_cache, set_env):
"""Test caching with identical computation graph but different input shapes."""
input1 = tp.Tensor([[1.0, 2.0], [3.0, 4.0]], dtype=tp.float32)
input2 = tp.Tensor([[[1.0, 2.0], [3.0, 4.0]]], dtype=tp.float32)

# cache.clear() # Ensure the cache is empty
# assert cache.size() == 0
layer = tp.Linear(2, 3)

# tensor.eval()
output1 = layer(input1)
assert mock_global_cache.get(Trace([output1.trace_tensor]), devices=[output1.device]) is None
output1.eval()
assert mock_global_cache.get(Trace([layer(input1).trace_tensor]), devices=[output1.device]) is not None

# tensor2 = tp.Tensor([[1.0, 2.0], [3.0, 4.0]], dtype=tp.float32)
# assert cache.size() == 1
# assert cache.get(str(Trace([tensor2]))) is not None # Ensure tensor trace is in cache
output2 = layer(input2)
assert mock_global_cache.get(Trace([output2.trace_tensor]), devices=[output2.device]) is None
output2.eval()
assert mock_global_cache.get(Trace([layer(input2).trace_tensor]), devices=[output2.device]) is not None

# def test_equivalance(self, cache, monkeypatch):
# # Mock the global cache
# monkeypatch.setattr(tp.frontend, "global_cache", cache)
def test_identical_graph_different_input_names(self, mock_global_cache, set_env):
"""Test caching with identical computation graph but different input names."""
input1 = tp.Tensor([[1.0, 2.0]], dtype=tp.float32, name="input_a")
input2 = tp.Tensor([[1.0, 2.0]], dtype=tp.float32, name="input_b")

# input_tensor = tp.Tensor([[1.0, 2.0], [3.0, 4.0]], dtype=tp.float32)
# layer = tp.Linear(2, 3)
layer = tp.Linear(2, 3)
output1 = layer(input1)
output1.eval()

# # Operation without caching
# output_without_cache = layer(input_tensor)
output2 = layer(input2)
assert mock_global_cache.get(Trace([output2.trace_tensor]), devices=[output2.device]) is not None

# cache.clear() # Ensure the cache is empty
# assert cache.size() == 0
def test_identical_graph_different_output_names(self, mock_global_cache, set_env):
"""Test caching with identical computation graph but different output tensor names."""
input_tensor = tp.Tensor([[1.0, 2.0]], dtype=tp.float32)

# # Eval operation without caching
# output_without_cache.eval()
layer = tp.Linear(2, 3)

# # Operation with caching
# output_with_cache = layer(input_tensor)
# assert cache.get(str(Trace([output_with_cache]))) is not None
output1 = layer(input_tensor)
output1.name = "output_a"
output1.eval()

# # Run the operation with caching
# output_with_cache.eval()
output2 = layer(input_tensor)
output2.name = "output_b"
assert mock_global_cache.get(Trace([output2.trace_tensor]), devices=[output2.device]) is not None

# # Assert outputs are equivalent
# assert tp.allclose(output_without_cache, output_with_cache)
def test_different_graphs_different_cache_entries(self, mock_global_cache, set_env):
"""Test caching with different computation graphs having different cache entries."""
input_tensor = tp.Tensor([[1.0, 2.0]], dtype=tp.float32)

layer1 = tp.Linear(2, 3)
layer2 = tp.Linear(2, 4)

output1 = layer1(input_tensor)
assert mock_global_cache.get(Trace([output1.trace_tensor]), devices=[output1.device]) is None
output1.eval()
assert mock_global_cache.get(Trace([layer1(input_tensor).trace_tensor]), devices=[output1.device]) is not None

output2 = layer2(input_tensor)
assert mock_global_cache.get(Trace([output2.trace_tensor]), devices=[output2.device]) is None
output2.eval()
assert mock_global_cache.get(Trace([layer2(input_tensor).trace_tensor]), devices=[output2.device]) is not None

# test_trace_normalize
# test_trace_normalize with storage op shape < threshold
# test_trace_normalize with storage op shape > threshold
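
The trailing TODOs suggest normalization tests; a hedged sketch of the first (hypothetical, not part of this commit — the `a + a` graph and the string-equality assertion are assumptions about `_normalize_trace`):

```python
def test_trace_normalize(cache):
    # Two structurally identical graphs that differ only in tensor names.
    a = tp.Tensor([[1.0, 2.0]], dtype=tp.float32, name="a")
    b = tp.Tensor([[1.0, 2.0]], dtype=tp.float32, name="b")
    trace_a = Trace([(a + a).trace_tensor])
    trace_b = Trace([(b + b).trace_tensor])
    # Normalization should erase user-facing names, so both traces
    # normalize to the same string.
    assert cache._normalize_trace(trace_a) == cache._normalize_trace(trace_b)
```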
2 changes: 1 addition & 1 deletion tripy/tests/frontend/trace/ops/test_storage.py
@@ -78,7 +78,7 @@ def test_from_list_float(self):
def test_empty_list(self):
data = [[]]
storage = Storage([], [TraceTensor("test", None, None, None, None, None)], data)
assert storage.dtype == tp.int32
assert storage.dtype == tp.float32
assert storage.shape == (1, 0)
assert storage.device.kind == "gpu"
assert storage.data_str == "[[]]"
3 changes: 0 additions & 3 deletions tripy/tripy/frontend/__init__.py
@@ -17,6 +17,3 @@

from tripy.frontend.tensor import Tensor
from tripy.frontend.trace import Trace
from tripy.frontend.cache import ExecutableCache

global_cache = ExecutableCache()
61 changes: 44 additions & 17 deletions tripy/tripy/frontend/cache.py
@@ -22,7 +22,7 @@ class ExecutableCache:
"""Global cache for storing compiled executables."""

def __init__(self):
self._cache = {}
self._cache: Dict[str, runtime.Executable] = {}

def _assign_tensor_name(
self,
@@ -35,13 +35,13 @@ def _assign_tensor_name(
Assign or retrieve a tensor name.
Args:
tensor: The tensor to name
tensor_map: Mapping of tensor ids to names (clean or original)
next_id: Mutable list for tracking next tensor ID
backup_map: Mapping to store original names
tensor (TraceTensor): The tensor to name.
tensor_map (Dict[int, str]): Mapping of tensor ids to names (clean or original).
next_id (List[int]): Mutable list for tracking next tensor ID.
backup_map (Dict[int, str], optional): Mapping to store original names. Defaults to None.
Returns:
str: The assigned or retrieved tensor name
str: The assigned or retrieved tensor name.
"""
t_id = id(tensor)

@@ -62,10 +62,10 @@ def _update_trace_names(
Update names for inputs, outputs, and operations in the trace.
Args:
trace: The trace to update
tensor_map: Mapping of tensor ids to names
next_id: Mutable list for tracking next tensor ID
backup_map: Mapping of original tensor names
trace (Trace): The trace to update.
tensor_map (Dict[int, str]): Mapping of tensor ids to names.
next_id (List[int]): Mutable list for tracking next tensor ID.
backup_map (Dict[int, str], optional): Mapping of original tensor names. Defaults to None.
"""
# Update input names
for inp in trace.inputs:
@@ -87,10 +87,10 @@ def _normalize_trace(self, trace: "Trace") -> str:
Normalize the trace by renaming all tensor names while preserving the structure.
Args:
trace: The trace to normalize
trace (Trace): The trace to normalize.
Returns:
str: Normalized trace as a string
str: Normalized trace as a string.
"""
# Initialize maps and next tensor ID
tensor_map: Dict[int, str] = {}
@@ -109,18 +109,45 @@
return trace_str

def _generate_key(self, trace: "Trace", devices: List["tripy.common.device"]) -> str:
"""
Generate a unique key for a given trace and device configuration.
Args:
trace (Trace): The trace for which to generate the key.
devices (List[Device]): List of devices associated with the trace.
Returns:
str: A unique hash key representing the trace and devices.
"""
normalized_trace = self._normalize_trace(trace)
key = normalized_trace + "\ndevices:\n" + "\n".join([str(device) for device in devices])
return hashlib.sha256(key.encode("utf-8")).hexdigest()

def get(self, trace: "Trace", devices: List["tripy.common.device"]):
def get(self, trace: "Trace", devices: List["tripy.common.device"]) -> runtime.Executable:
"""
Retrieve a cached executable for the given trace and devices.
Args:
trace (Trace): The trace used as a key.
devices (List[Device]): List of devices associated with the trace.
Returns:
Executable: The cached executable, or None if not found.
"""
key = self._generate_key(trace, devices)
return self._cache.get(key, None)

def set(self, trace: "Trace", executable: runtime.Executable, devices: List["tripy.common.device"]):
def set(self, trace: "Trace", executable: runtime.Executable, devices: List["tripy.common.device"]) -> None:
"""
Cache an executable for the given trace and devices.
Args:
trace (Trace): The trace used as a key.
executable (Executable): The executable to cache.
devices (List[Device]): List of devices associated with the trace.
"""
key = self._generate_key(trace, devices)
self._cache[key] = executable

def size(self) -> int:
"""Return the number of items in the cache."""
return len(self._cache)

global_cache = ExecutableCache()
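
For intuition, a self-contained sketch of the key scheme above (plain Python, not Tripy's actual code; the toy `normalize` glosses over `_assign_tensor_name`'s backup/restore bookkeeping):

```python
import hashlib
from typing import Dict, List

def normalize(trace_lines: List[str], tensor_names: List[str]) -> str:
    # Rename tensors to sequential ids so structurally identical traces
    # produce identical strings.
    mapping: Dict[str, str] = {}
    out = []
    for line in trace_lines:
        for name in tensor_names:
            alias = mapping.setdefault(name, f"t{len(mapping)}")
            line = line.replace(name, alias)
        out.append(line)
    return "\n".join(out)

def cache_key(trace_lines: List[str], tensor_names: List[str], devices: List[str]) -> str:
    # Same shape as _generate_key: normalized trace + device list, hashed.
    key = normalize(trace_lines, tensor_names) + "\ndevices:\n" + "\n".join(devices)
    return hashlib.sha256(key.encode("utf-8")).hexdigest()

# Traces that differ only in tensor names hash to the same key.
k1 = cache_key(["out_a = linear(input_a)"], ["input_a", "out_a"], ["gpu:0"])
k2 = cache_key(["out_b = linear(input_b)"], ["input_b", "out_b"], ["gpu:0"])
assert k1 == k2
```

This is why the tests above can rename inputs and outputs freely and still hit the same cache entry.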
27 changes: 5 additions & 22 deletions tripy/tripy/frontend/tensor.py
@@ -149,7 +149,9 @@ def raw_init(
if data is None:
return

Storage.build_internal([], [instance.trace_tensor], data)
Storage.build_internal(
[], [instance.trace_tensor], data, device=device if not hasattr(data, "__dlpack__") else None
)

# TODO(#155): Remove this hack:
instance.trace_tensor.device = utils.default(device, instance.trace_tensor.device)
@@ -209,10 +211,10 @@ def eval(self) -> runtime.MemRefValue:
from tripy.backend.mlir.compiler import Compiler
from tripy.backend.mlir.executor import Executor
from tripy.frontend.trace import Trace
from tripy.frontend import global_cache
from tripy.frontend.cache import global_cache

# Collect inputs
inputs = self._collect_storage_tensors() # TODO: how to test real inputs? not shape inputs
inputs = Trace._collect_storage_tensors(self.trace_tensor) # TODO: how to test real inputs? not shape inputs
input_shapes = [ShapeBounds(min=tuple(inp.shape), opt=tuple(inp.shape), max=tuple(inp.shape)) for inp in inputs]

trace = Trace([self.trace_tensor], inputs=inputs, shapes=input_shapes)
@@ -257,25 +259,6 @@ def eval(self) -> runtime.MemRefValue:

return data

def _collect_storage_tensors(self):
visited = set()
inputs = []

def dfs(trace_tensor):
if id(trace_tensor) in visited:
return
visited.add(id(trace_tensor))

producer = trace_tensor.producer
if isinstance(producer, Storage) and utils.should_lift_storage_op_as_input(producer.shape):
inputs.append(trace_tensor)
else:
for inp in producer.inputs:
dfs(inp)

dfs(self.trace_tensor)
return inputs

def tolist(self):
data_memref = self.eval()
if self.dtype not in (
2 changes: 1 addition & 1 deletion tripy/tripy/frontend/trace/ops/storage.py
@@ -63,7 +63,7 @@ def __init__(
)
else:
if common_utils.is_empty(data):
self.dtype = datatype.int32
self.dtype = datatype.float32
data_array = None
else:
self.dtype = common_utils.get_element_type(data)
22 changes: 22 additions & 0 deletions tripy/tripy/frontend/trace/trace.py
@@ -18,8 +18,10 @@
import copy
from typing import List, Sequence, Set

from tripy import utils
from tripy.common.exception import raise_error
from tripy.common.shape_bounds import ShapeBounds
from tripy.frontend.trace.ops import Storage
from tripy.frontend.trace.tensor import TraceTensor
from tripy.frontend.utils import topological_sort
from tripy.logging import logger
@@ -105,6 +107,26 @@ def __str__(self) -> str:
layer_strs.append(f" {str(out)}")
return "\n".join(layer_strs)

@staticmethod
def _collect_storage_tensors(trace_tensor):
visited = set()
inputs = []

def dfs(trace_tensor):
if id(trace_tensor) in visited:
return
visited.add(id(trace_tensor))

producer = trace_tensor.producer
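# Large constants (Storage ops past the lifting threshold) become trace
# inputs; everything else is traversed recursively toward its producers.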
if isinstance(producer, Storage) and utils.should_lift_storage_op_as_input(producer.shape):
inputs.append(trace_tensor)
else:
for inp in producer.inputs:
dfs(inp)

dfs(trace_tensor)
return inputs

def to_flat_ir(self):
from tripy.flat_ir.flat_ir import FlatIR

