From d4ea54f87652a064659df29273b12e872a904b3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Capretto?= Date: Thu, 30 Jan 2025 19:12:51 -0300 Subject: [PATCH] Add as_dataframe() method to GroupEffectsMatrix (#115) * Update python-version in test.yml * Add as_dataframe() method to GroupEffectsMatrix --- .pylintrc | 11 +---------- formulae/matrices.py | 7 +++++++ tests/test_design_matrices.py | 16 ++++++++++++++++ 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/.pylintrc b/.pylintrc index 02a77bd..c477072 100644 --- a/.pylintrc +++ b/.pylintrc @@ -61,9 +61,7 @@ disable=missing-docstring, too-many-locals, too-many-branches, too-many-statements, - no-self-use, too-few-public-methods, - bad-continuation, invalid-name @@ -132,13 +130,6 @@ max-line-length=100 # Maximum number of lines in a module max-module-lines=1000 -# List of optional constructs for which whitespace checking is disabled. `dict- -# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. -# `trailing-comma` allows a space between comma and closing bracket: (a, ). -# `empty-line` allows space-only lines. -no-space-check=trailing-comma, - dict-separator - # Allow the body of a class to be on the same line as the declaration if body # contains single statement. single-line-class-stmt=no @@ -500,4 +491,4 @@ min-public-methods=2 # Exceptions that will emit a warning when being caught. Defaults to # "Exception" -overgeneral-exceptions=Exception +overgeneral-exceptions=builtins.Exception diff --git a/formulae/matrices.py b/formulae/matrices.py index 68fde88..fddd777 100644 --- a/formulae/matrices.py +++ b/formulae/matrices.py @@ -428,6 +428,13 @@ def evaluate_new_data(self, data): new_instance.evaluated = True return new_instance + def as_dataframe(self): + """Returns `self.design_matrix` as a pandas.DataFrame.""" + columns = [] + for term in self.terms.values(): + columns.extend(term.labels) + return pd.DataFrame(self.design_matrix, columns=columns) + def __getitem__(self, term): """Get the sub-matrix that corresponds to a given term. diff --git a/tests/test_design_matrices.py b/tests/test_design_matrices.py index b1178ae..f040559 100644 --- a/tests/test_design_matrices.py +++ b/tests/test_design_matrices.py @@ -1052,6 +1052,22 @@ def test_group_specific_as_array(data): assert np.asarray(group).shape == (20, 4) +def test_group_specific_as_data_frame(data): + _, _, group = design_matrices("y ~ 1 + (x1|g) + (x1:x2|g) + (h|g)", data) + group_specific_dataframe = group.as_dataframe() + assert group_specific_dataframe.shape == (20, 8) + assert group_specific_dataframe.columns.tolist() == [ + "1|g[A]", + "1|g[B]", + "x1|g[A]", + "x1|g[B]", + "x1:x2|g[A]", + "x1:x2|g[B]", + "h[B]|g[A]", + "h[B]|g[B]", + ] + + def test_group_specific_repr_and_str(data): _, _, group = design_matrices("y ~ 1 + (x1|g) + (h|g)", data) text = (