-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
156 additions
and
102 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,79 +1,103 @@ | ||
# # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# # SPDX-License-Identifier: Apache-2.0 | ||
# # | ||
# # Licensed under the Apache License, Version 2.0 (the "License"); | ||
# # you may not use this file except in compliance with the License. | ||
# # You may obtain a copy of the License at | ||
# # | ||
# # http://www.apache.org/licenses/LICENSE-2.0 | ||
# # | ||
# # Unless required by applicable law or agreed to in writing, software | ||
# # distributed under the License is distributed on an "AS IS" BASIS, | ||
# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# # See the License for the specific language governing permissions and | ||
# # limitations under the License. | ||
# import pytest | ||
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import pytest | ||
import os | ||
|
||
# import tripy as tp | ||
import tripy as tp | ||
|
||
# from tripy.frontend.trace import Trace | ||
# from tripy.frontend.cache import ExecutableCache | ||
from tripy.frontend.trace import Trace | ||
from tripy.frontend.cache import ExecutableCache | ||
|
||
|
||
# @pytest.fixture | ||
# def cache(): | ||
# return ExecutableCache() | ||
@pytest.fixture | ||
def set_env(monkeypatch): | ||
monkeypatch.setenv("TRIPY_EAGER_CACHE", "1") | ||
|
||
|
||
# class TestCache: | ||
# def test_get_nonexistent_key(self, cache): | ||
# """Test getting a value for a nonexistent key.""" | ||
# cached_value = cache.get("nonexistent_key") | ||
@pytest.fixture | ||
def cache(): | ||
return ExecutableCache() | ||
|
||
# assert cached_value is None, "Expected None for a nonexistent key" | ||
|
||
# def test_normalize_key(self, cache): | ||
# raw_trace = | ||
# expected_key = | ||
# assert cache._normalize_key(raw_key) == expected_key | ||
@pytest.fixture | ||
def mock_global_cache(monkeypatch, cache): | ||
monkeypatch.setattr(tp.frontend.cache, "global_cache", cache) | ||
return cache | ||
|
||
# def test_same_operation_different_tensor(self, cache, monkeypatch): | ||
# # Mock the global cache | ||
# monkeypatch.setattr(tp.frontend, "global_cache", cache) | ||
|
||
# tensor = tp.Tensor([[1.0, 2.0], [3.0, 4.0]], dtype=tp.float32) | ||
class TestCache: | ||
def test_identical_graph_different_input_shapes(self, mock_global_cache, set_env): | ||
"""Test caching with identical computation graph but different input shapes.""" | ||
input1 = tp.Tensor([[1.0, 2.0], [3.0, 4.0]], dtype=tp.float32) | ||
input2 = tp.Tensor([[[1.0, 2.0], [3.0, 4.0]]], dtype=tp.float32) | ||
|
||
# cache.clear() # Ensure the cache is empty | ||
# assert cache.size() == 0 | ||
layer = tp.Linear(2, 3) | ||
|
||
# tensor.eval() | ||
output1 = layer(input1) | ||
assert mock_global_cache.get(Trace([output1.trace_tensor]), devices=[output1.device]) is None | ||
output1.eval() | ||
assert mock_global_cache.get(Trace([layer(input1).trace_tensor]), devices=[output1.device]) is not None | ||
|
||
# tensor2 = tp.Tensor([[1.0, 2.0], [3.0, 4.0]], dtype=tp.float32) | ||
# assert cache.size() == 1 | ||
# assert cache.get(str(Trace([tensor2]))) is not None # Ensure tensor trace is in cache | ||
output2 = layer(input2) | ||
assert mock_global_cache.get(Trace([output2.trace_tensor]), devices=[output2.device]) is None | ||
output2.eval() | ||
assert mock_global_cache.get(Trace([layer(input2).trace_tensor]), devices=[output2.device]) is not None | ||
|
||
# def test_equivalance(self, cache, monkeypatch): | ||
# # Mock the global cache | ||
# monkeypatch.setattr(tp.frontend, "global_cache", cache) | ||
def test_identical_graph_different_input_names(self, mock_global_cache, set_env): | ||
"""Test caching with identical computation graph but different input names.""" | ||
input1 = tp.Tensor([[1.0, 2.0]], dtype=tp.float32, name="input_a") | ||
input2 = tp.Tensor([[1.0, 2.0]], dtype=tp.float32, name="input_b") | ||
|
||
# input_tensor = tp.Tensor([[1.0, 2.0], [3.0, 4.0]], dtype=tp.float32) | ||
# layer = tp.Linear(2, 3) | ||
layer = tp.Linear(2, 3) | ||
output1 = layer(input1) | ||
output1.eval() | ||
|
||
# # Operation without caching | ||
# output_without_cache = layer(input_tensor) | ||
output2 = layer(input2) | ||
assert mock_global_cache.get(Trace([output2.trace_tensor]), devices=[output2.device]) is not None | ||
|
||
# cache.clear() # Ensure the cache is empty | ||
# assert cache.size() == 0 | ||
def test_identical_graph_different_output_names(self, mock_global_cache, set_env): | ||
"""Test caching with identical computation graph but different output tensor names.""" | ||
input_tensor = tp.Tensor([[1.0, 2.0]], dtype=tp.float32) | ||
|
||
# # Eval operation without caching | ||
# output_without_cache.eval() | ||
layer = tp.Linear(2, 3) | ||
|
||
# # Operation with caching | ||
# output_with_cache = layer(input_tensor) | ||
# assert cache.get(str(Trace([output_with_cache]))) is not None | ||
output1 = layer(input_tensor) | ||
output1.name = "output_a" | ||
output1.eval() | ||
|
||
# # Run the operation with caching | ||
# output_with_cache.eval() | ||
output2 = layer(input_tensor) | ||
output2.name = "output_b" | ||
assert mock_global_cache.get(Trace([output2.trace_tensor]), devices=[output2.device]) is not None | ||
|
||
# # Assert outputs are equivalent | ||
# assert tp.allclose(output_without_cache, output_with_cache) | ||
def test_different_graphs_different_cache_entries(self, mock_global_cache, set_env): | ||
"""Test caching with different computation graphs having different cache entries.""" | ||
input_tensor = tp.Tensor([[1.0, 2.0]], dtype=tp.float32) | ||
|
||
layer1 = tp.Linear(2, 3) | ||
layer2 = tp.Linear(2, 4) | ||
|
||
output1 = layer1(input_tensor) | ||
assert mock_global_cache.get(Trace([output1.trace_tensor]), devices=[output1.device]) is None | ||
output1.eval() | ||
assert mock_global_cache.get(Trace([layer1(input_tensor).trace_tensor]), devices=[output1.device]) is not None | ||
|
||
output2 = layer2(input_tensor) | ||
assert mock_global_cache.get(Trace([output2.trace_tensor]), devices=[output2.device]) is None | ||
output2.eval() | ||
assert mock_global_cache.get(Trace([layer2(input_tensor).trace_tensor]), devices=[output2.device]) is not None | ||
|
||
# test_trace_normalize | ||
# test_trace_normalize with storage op shape < thershold | ||
# test_trace_normalize with storage op shape > thershold |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters