diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc
index 780a9511a9709..5271658e856f1 100644
--- a/paddle/fluid/pybind/pir.cc
+++ b/paddle/fluid/pybind/pir.cc
@@ -594,7 +594,7 @@ void BindProgram(py::module *m) {
       })
       .def("num_ops", [](Program &self) { return self.num_ops(); })
       .def(
-          "state_dict",
+          "_state_dict",
           [](std::shared_ptr<Program> self,
              const std::string &mode = "all",
              const framework::Scope &scope = framework::Scope()) {
diff --git a/python/paddle/pir/program_patch.py b/python/paddle/pir/program_patch.py
index df7f55083d126..519a5041f0937 100644
--- a/python/paddle/pir/program_patch.py
+++ b/python/paddle/pir/program_patch.py
@@ -27,8 +27,22 @@ def _lr_schedule_guard(self, is_with_opt=False):
         # be fixed in the future.
         yield
 
+    def state_dict(self, mode="all", scope=None):
+        from paddle.base import core
+        from paddle.base.executor import global_scope
+
+        if scope is not None and not isinstance(scope, core._Scope):
+            raise TypeError(
+                f"`scope` should be None or `paddle.static.Scope` type, but received {type(scope)}."
+            )
+
+        if scope is None:
+            scope = global_scope()
+        return self._state_dict(mode, scope)
+
     program_attrs = {
         "_lr_schedule_guard": _lr_schedule_guard,
+        "state_dict": state_dict,
     }
 
     global _already_patch_program
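
A minimal usage sketch of the patched call path (assumptions: PIR mode is enabled so paddle.static.default_main_program() returns the pir.Program that this patch targets; mode "all" is just the default inherited from the C++ binding above):

    import paddle

    paddle.enable_static()
    main_prog = paddle.static.default_main_program()

    # scope now defaults to global_scope() on the Python side, and a
    # non-Scope argument raises TypeError before reaching pybind.
    sd = main_prog.state_dict()
    sd = main_prog.state_dict("all", paddle.static.global_scope())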