Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SOT][Faster Guard] add StringCompareGuard and support RangeVariable #70362

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions paddle/fluid/pybind/jit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@ void BindGuard(pybind11::module *m) {
std::shared_ptr<InstanceCheckGuard>>(
*m, "InstanceCheckGuard", R"DOC(InstanceCheckGuard Class.)DOC")
.def(py::init<const py::object &>(), py::arg("isinstance_obj"));
py::class_<StringCompareGuard,
GuardBase,
std::shared_ptr<StringCompareGuard>>(
*m, "StringCompareGuard", R"DOC(StringCompareGuard Class.)DOC")
.def(py::init<const char &>(), py::arg("str"));

m->def(
"merge_guard",
Expand Down
11 changes: 11 additions & 0 deletions paddle/fluid/pybind/sot/guards.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,15 @@ bool InstanceCheckGuard::check(PyObject* value) {
return PyObject_IsInstance(value, expected_);
}

bool StringCompareGuard::check(PyObject* value) {
// std::string expected_str = PyUnicode_AsUTF8(expected_);
std::cout << "[StringCheckGuard]" << expected_ << std::endl;
// std::string str = PyUnicode_AsUTF8(value);
std::cout << "[StringCheckGuard]" << value << std::endl;

// return PyUnicode_(value, expected_) == 0;
return PyUnicode_CompareWithASCIIString(value, &expected_) == 0;
// return value->equal(*expected_);
}

#endif
10 changes: 10 additions & 0 deletions paddle/fluid/pybind/sot/guards.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,4 +215,14 @@ class InstanceCheckGuard : public GuardBase {
PyObject* expected_;
};

class StringCompareGuard : public GuardBase {
public:
explicit StringCompareGuard(const char& expected) : expected_(expected) {}

bool check(PyObject* value) override;

private:
char expected_;
};

#endif
3 changes: 2 additions & 1 deletion python/paddle/jit/sot/opcode_translator/executor/guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,9 @@ def object_equal_stringified_guard(self) -> list[StringifiedExpression]:
)
]
return [
StringifiedExpression(
FasterStringifiedExpression(
f"{{}} == {obj_free_var_name}",
paddle.framework.core.StringCompareGuard(obj_free_var_name),
[frame_value_tracer],
union_free_vars(
frame_value_tracer.free_vars,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -744,11 +744,12 @@ def make_stringified_guard(self) -> list[StringifiedExpression]:
frame_value_tracer = self.tracker.trace_value_from_frame()

return [
StringifiedExpression(
"isinstance({0}, range) and "
FasterStringifiedExpression(
"id(type({0})) == id(range) and "
Copy link
Preview

Copilot AI Dec 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The attribute 'init_value' is used but not defined in the 'RangeVariable' class. This will lead to an AttributeError. Define 'init_value' in the class or use an existing attribute.

Suggested change
"id(type({0})) == id(range) and "
+ f"{{0}}.start == {self.some_existing_attribute.start} and "

Copilot is powered by AI, so mistakes are possible. Review output carefully before use.

Positive Feedback
Negative Feedback

Provide additional feedback

Please help us improve GitHub Copilot by sharing more details about this comment.

Please select one or more of the options
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The attribute 'init_value' is used but not defined in the 'RangeVariable' class. This will lead to an AttributeError. Define 'init_value' in the class or use an existing attribute.

差评

+ f"{{0}}.start == {self.init_value.start} and "
+ f"{{0}}.stop == {self.init_value.stop} and "
+ f"{{0}}.step == {self.init_value.step}",
paddle.framework.core.RangeMatchGuard(self.init_value),
[frame_value_tracer],
frame_value_tracer.free_vars,
)
Expand Down
44 changes: 27 additions & 17 deletions test/sot/test_06_call_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from test_case_base import (
TestCaseBase,
test_instruction_translator_cache_context,
test_with_faster_guard,
)

import paddle
Expand Down Expand Up @@ -127,30 +128,38 @@ def foo_8(x: paddle.Tensor):
return m


class TestCall(TestCaseBase):
def test_call1(self):
self.assert_results(foo_1, paddle.to_tensor(2))
# class TestCall(TestCaseBase):
# @test_with_faster_guard
# def test_call1(self):
# self.assert_results(foo_1, paddle.to_tensor(2))

def test_call2(self):
self.assert_results(foo_2, paddle.to_tensor(3))
# @test_with_faster_guard
# def test_call2(self):
# self.assert_results(foo_2, paddle.to_tensor(3))

def test_call3(self):
self.assert_results(foo_3, paddle.to_tensor(4))
# @test_with_faster_guard
# def test_call3(self):
# self.assert_results(foo_3, paddle.to_tensor(4))

def test_call4(self):
self.assert_results(foo_4, paddle.to_tensor(5))
# @test_with_faster_guard
# def test_call4(self):
# self.assert_results(foo_4, paddle.to_tensor(5))

def test_call5(self):
self.assert_results(foo_5, paddle.to_tensor(6))
# @test_with_faster_guard
# def test_call5(self):
# self.assert_results(foo_5, paddle.to_tensor(6))

def test_call6(self):
self.assert_results(foo_6, paddle.to_tensor(7))
# @test_with_faster_guard
# def test_call6(self):
# self.assert_results(foo_6, paddle.to_tensor(7))

def test_call7(self):
self.assert_results(foo_7, paddle.to_tensor(8))
# @test_with_faster_guard
# def test_call7(self):
# self.assert_results(foo_7, paddle.to_tensor(8))

def test_call8(self):
self.assert_results(foo_8, paddle.to_tensor(9))
# @test_with_faster_guard
# def test_call8(self):
# self.assert_results(foo_8, paddle.to_tensor(9))


def apply_fn(fn, x):
Expand All @@ -166,6 +175,7 @@ def fn2(x):


class TestApplyDifferentFunctions(TestCaseBase):
@test_with_faster_guard
def test_apply_fn(self):
x = 1
with test_instruction_translator_cache_context() as ctx:
Expand Down
3 changes: 2 additions & 1 deletion test/sot/test_enumerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import unittest

from test_case_base import TestCaseBase
from test_case_base import TestCaseBase, test_with_faster_guard

import paddle
from paddle.jit.sot.utils import strict_mode_guard
Expand Down Expand Up @@ -86,6 +86,7 @@ def test_enumerate_10(layer_list, x):


class TestEnumerate(TestCaseBase):
@test_with_faster_guard
def test_cases(self):
x = 8
y = 5
Expand Down
5 changes: 5 additions & 0 deletions test/sot/test_faster_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ def test_id_match_guard(self):
self.assertTrue(guard_id.check(layer))
self.assertFalse(guard_id.check(paddle.nn.Linear(10, 10)))

def test_string_match_guard(self):
guard_string = paddle.framework.core.StringCompareGuard("1")
self.assertTrue(guard_string.check("1"))
self.assertFalse(guard_string.check("2"))


class TestFasterGuardGroup(unittest.TestCase):
def test_guard_group(self):
Expand Down