Skip to content

Commit

Permalink
Merge pull request #30 from dbatten5/col-index
Browse files Browse the repository at this point in the history
Col index
  • Loading branch information
dbatten5 committed Apr 11, 2023
2 parents 7f339a5 + 67d9ca1 commit 42d8bf1
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 21 deletions.
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "arraytex"
version = "0.0.6"
version = "0.0.7"
description = "ArrayTeX"
authors = ["Dom Batten <[email protected]>"]
license = "MIT"
Expand Down Expand Up @@ -106,7 +106,6 @@ line-length = 80
select = [
'B',
'B9',
'C',
'D',
'E',
'F',
Expand Down
56 changes: 43 additions & 13 deletions src/arraytex/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ def to_tabular(
num_format: Optional[str] = None,
scientific_notation: bool = False,
col_align: Union[List[str], str] = "c",
cols: Optional[List[str]] = None,
col_names: Optional[List[str]] = None,
col_index: Optional[List[str]] = None,
to_clp: bool = False,
) -> str:
"""Convert a numpy.NDArray to LaTeX tabular environment.
Expand All @@ -72,7 +73,8 @@ def to_tabular(
single character is provided then it will be broadcast to all columns. If a list
is provided then each item will be assigned to each column, list size and
number of columns must match
cols: an optional list of column names, otherwise generic names will be assigned
col_names: an optional list of column names, otherwise generic names will be assigned
col_index: an optional list of column indices, i.e. row identifiers
to_clp: copy the output to the system clipboard
Returns:
Expand All @@ -94,29 +96,57 @@ def to_tabular(
else:
raise TooManyDimensionsError

if isinstance(col_align, list) and len(col_align) != n_cols:
if not col_index:
if isinstance(col_align, list) and len(col_align) != n_cols:
raise DimensionMismatchError(
f"Number of `col_align` items ({len(col_align)}) "
+ f"doesn't match number of columns ({n_cols})"
)

if col_names and len(col_names) != n_cols:
raise DimensionMismatchError(
f"Number of `col_names` items ({len(col_names)}) "
+ f"doesn't match number of columns ({n_cols})"
)

if (
col_index
and col_names
and isinstance(col_align, list)
and len(col_names) != len(col_align)
):
raise DimensionMismatchError(
f"Number of `col_align` items ({len(col_align)}) "
+ f"doesn't match number of columns ({n_cols})"
+ f"doesn't match number of columns ({len(col_names)})"
)

if isinstance(col_align, str):
col_align = [col_align for _ in range(n_cols)]

if cols and len(cols) != n_cols:
raise DimensionMismatchError(
f"Number of `cols` items ({len(cols)}) "
+ f"doesn't match number of columns ({n_cols})"
)

if not cols:
cols = [f"Col {i + 1}" for i in range(n_cols)]
if not col_names:
col_names = [f"Col {i + 1}" for i in range(n_cols)]

lines = _parse_lines(arr, num_format, scientific_notation)

if col_index:
if len(col_index) != len(lines):
raise DimensionMismatchError(
f"Number of `col_index` items ({len(col_index)}) "
+ f"doesn't match number of rows ({len(lines)})"
)

if len(col_align) == n_cols:
col_align.insert(0, "l")

if len(col_names) == n_cols:
col_names.insert(0, "Index")

for idx, line in enumerate(lines):
lines[idx] = f"{col_index[idx]} & " + line.strip()

rv = [f"\\begin{{tabular}}{{{' '.join(col_align)}}}"]
rv += [r"\toprule"]
rv += [" & ".join(cols) + r" \\"]
rv += [" & ".join(col_names) + r" \\"]
rv += [r"\midrule"]
rv += [line.strip() + r" \\" for line in lines]
rv += [r"\bottomrule"]
Expand Down
130 changes: 124 additions & 6 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,15 +179,15 @@ def test_mismatch_col_align(self) -> None:
"Number of `col_align` items (2) doesn't match number of columns (3)"
)

def test_mismatch_cols(self) -> None:
"""Error is thrown if wrong number of cols items."""
def test_mismatch_col_names(self) -> None:
"""Error is thrown if wrong number of col_names items."""
mat = np.arange(6).reshape(2, 3)

with pytest.raises(DimensionMismatchError) as exc:
to_tabular(mat, cols=["1", "2"])
to_tabular(mat, col_names=["1", "2"])

assert str(exc.value) == (
"Number of `cols` items (2) doesn't match number of columns (3)"
"Number of `col_names` items (2) doesn't match number of columns (3)"
)

def test_default(self) -> None:
Expand All @@ -208,11 +208,11 @@ def test_default(self) -> None:
\end{tabular}"""
)

def test_given_cols(self) -> None:
def test_given_col_names(self) -> None:
"""User can supply col names."""
mat = np.arange(1, 5).reshape(2, 2)

out = to_tabular(mat, cols=["1", "b"])
out = to_tabular(mat, col_names=["1", "b"])

assert (
out
Expand Down Expand Up @@ -277,3 +277,121 @@ def test_0_d(self) -> None:
\bottomrule
\end{tabular}"""
)

class TestColIndex:
"""Tests for the `col_index` support."""

def test_default(self) -> None:
"""`col_index` forms the row names."""
col_index = ["Row 1", "Row 2"]
mat = np.arange(1, 5).reshape(2, 2)

out = to_tabular(mat, col_index=col_index)

assert (
out
== r"""\begin{tabular}{l c c}
\toprule
Index & Col 1 & Col 2 \\
\midrule
Row 1 & 1 & 2 \\
Row 2 & 3 & 4 \\
\bottomrule
\end{tabular}"""
)

def test_bad_dimensions(self) -> None:
"""An error is raised if wrong dimension of `col_index`."""
col_index = ["Row 1", "Row 2"]
mat = np.arange(1, 4).reshape(3, 1)

with pytest.raises(DimensionMismatchError) as exc:
to_tabular(mat, col_index=col_index)

assert str(exc.value) == (
"Number of `col_index` items (2) doesn't match number of rows (3)"
)

def test_given_col_names(self) -> None:
"""A given index name as part of `col_names` is used."""
col_index = ["Row 1", "Row 2"]
col_names = ["My Index", "Col 1", "Col 2"]
mat = np.arange(1, 5).reshape(2, 2)

out = to_tabular(mat, col_index=col_index, col_names=col_names)

assert (
out
== r"""\begin{tabular}{l c c}
\toprule
My Index & Col 1 & Col 2 \\
\midrule
Row 1 & 1 & 2 \\
Row 2 & 3 & 4 \\
\bottomrule
\end{tabular}"""
)

def test_given_col_align(self) -> None:
"""A given col align char can be used for the col index."""
col_index = ["Row 1", "Row 2"]
mat = np.arange(1, 5).reshape(2, 2)

out = to_tabular(mat, col_index=col_index, col_align=["r", "c", "c"])

assert (
out
== r"""\begin{tabular}{r c c}
\toprule
Index & Col 1 & Col 2 \\
\midrule
Row 1 & 1 & 2 \\
Row 2 & 3 & 4 \\
\bottomrule
\end{tabular}"""
)

def test_given_col_name_and_align(self) -> None:
"""A given col index name and align can be used for the index."""
col_index = ["Row 1", "Row 2"]
col_names = ["My Index", "Col 1", "Col 2"]
col_align = ["r", "c", "c"]
mat = np.arange(1, 5).reshape(2, 2)

out = to_tabular(
mat,
col_align=col_align,
col_names=col_names,
col_index=col_index,
)

assert (
out
== r"""\begin{tabular}{r c c}
\toprule
My Index & Col 1 & Col 2 \\
\midrule
Row 1 & 1 & 2 \\
Row 2 & 3 & 4 \\
\bottomrule
\end{tabular}"""
)

def test_col_align_bad_dimensions(self) -> None:
"""Bad dimensions of `col_align` is caught."""
col_index = ["Row 1", "Row 2"]
col_names = ["My Index", "Col 1", "Col 2"]
col_align = ["r", "c"]
mat = np.arange(1, 5).reshape(2, 2)

with pytest.raises(DimensionMismatchError) as exc:
to_tabular(
mat,
col_align=col_align,
col_names=col_names,
col_index=col_index,
)

assert str(exc.value) == (
"Number of `col_align` items (2) doesn't match number of columns (3)"
)

0 comments on commit 42d8bf1

Please sign in to comment.