From df88e02539538e25c1c4fab2dbe2782e360d7c9c Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Mon, 3 Mar 2025 14:39:31 +0800 Subject: [PATCH 01/20] Figure.plot/plot3d/text: Pass a dict of vectors rather than x/y/extra_arrays --- pygmt/src/plot.py | 25 +++++++++++++------------ pygmt/src/plot3d.py | 31 ++++++++++++++----------------- pygmt/src/text.py | 17 ++++++++--------- 3 files changed, 35 insertions(+), 38 deletions(-) diff --git a/pygmt/src/plot.py b/pygmt/src/plot.py index 23c5bde12fd..8114f5dfec6 100644 --- a/pygmt/src/plot.py +++ b/pygmt/src/plot.py @@ -50,7 +50,7 @@ w="wrap", ) @kwargs_to_strings(R="sequence", c="sequence_comma", i="sequence_comma", p="sequence") -def plot( +def plot( # noqa: PLR0912 self, data=None, x=None, @@ -232,8 +232,8 @@ def plot( kwargs = self._preprocess(**kwargs) kind = data_kind(data) - extra_arrays = [] - if kind == "empty": # Add more columns for vectors input + if kind == "empty": # Data is given via a series of vectors. + data = {"x": x, "y": y} # Parameters for vector styles if ( isinstance(kwargs.get("S"), str) @@ -241,25 +241,28 @@ def plot( and kwargs["S"][0] in "vV" and is_nonstr_iter(direction) ): - extra_arrays.extend(direction) + data.update({"x2": direction[0], "y2": direction[1]}) # Fill if is_nonstr_iter(kwargs.get("G")): - extra_arrays.append(kwargs.get("G")) + data["fill"] = kwargs["G"] del kwargs["G"] # Size if is_nonstr_iter(size): - extra_arrays.append(size) + data["size"] = size # Intensity and transparency - for flag in ["I", "t"]: + for flag, name in ["I", "intensity"], ["t", "transparency"]: if is_nonstr_iter(kwargs.get(flag)): - extra_arrays.append(kwargs.get(flag)) + data[name] = kwargs[flag] kwargs[flag] = "" # Symbol must be at the last column if is_nonstr_iter(symbol): if "S" not in kwargs: kwargs["S"] = True - extra_arrays.append(symbol) + data["symbol"] = symbol else: + if any(v is not None for v in (x, y)): + msg = "Too much data. Use either data or x/y/z." + raise GMTInvalidInput(msg) for name, value in [ ("direction", direction), ("fill", kwargs.get("G")), @@ -277,7 +280,5 @@ def plot( kwargs["S"] = "s0.2c" with Session() as lib: - with lib.virtualfile_in( - check_kind="vector", data=data, x=x, y=y, extra_arrays=extra_arrays - ) as vintbl: + with lib.virtualfile_in(check_kind="vector", data=data) as vintbl: lib.call_module(module="plot", args=build_arg_list(kwargs, infile=vintbl)) diff --git a/pygmt/src/plot3d.py b/pygmt/src/plot3d.py index e8e75382d74..6f6a3919d76 100644 --- a/pygmt/src/plot3d.py +++ b/pygmt/src/plot3d.py @@ -51,7 +51,7 @@ w="wrap", ) @kwargs_to_strings(R="sequence", c="sequence_comma", i="sequence_comma", p="sequence") -def plot3d( +def plot3d( # noqa: PLR0912 self, data=None, x=None, @@ -210,9 +210,8 @@ def plot3d( kwargs = self._preprocess(**kwargs) kind = data_kind(data) - extra_arrays = [] - - if kind == "empty": # Add more columns for vectors input + if kind == "empty": # Data is given via a series of vectors. + data = {"x": x, "y": y, "z": z} # Parameters for vector styles if ( isinstance(kwargs.get("S"), str) @@ -220,25 +219,29 @@ def plot3d( and kwargs["S"][0] in "vV" and is_nonstr_iter(direction) ): - extra_arrays.extend(direction) + data.update({"x2": direction[0], "y2": direction[1]}) # Fill if is_nonstr_iter(kwargs.get("G")): - extra_arrays.append(kwargs.get("G")) + data["fill"] = kwargs["G"] del kwargs["G"] # Size if is_nonstr_iter(size): - extra_arrays.append(size) + data["size"] = size # Intensity and transparency - for flag in ["I", "t"]: + for flag, name in [("I", "intensity"), ("t", "transparency")]: if is_nonstr_iter(kwargs.get(flag)): - extra_arrays.append(kwargs.get(flag)) + data[name] = kwargs[flag] kwargs[flag] = "" # Symbol must be at the last column if is_nonstr_iter(symbol): if "S" not in kwargs: kwargs["S"] = True - extra_arrays.append(symbol) + data["symbol"] = symbol else: + if any(v is not None for v in (x, y, z)): + msg = "Too much data. Use either data or x/y/z." + raise GMTInvalidInput(msg) + for name, value in [ ("direction", direction), ("fill", kwargs.get("G")), @@ -257,12 +260,6 @@ def plot3d( with Session() as lib: with lib.virtualfile_in( - check_kind="vector", - data=data, - x=x, - y=y, - z=z, - extra_arrays=extra_arrays, - required_z=True, + check_kind="vector", data=data, required_z=True ) as vintbl: lib.call_module(module="plot3d", args=build_arg_list(kwargs, infile=vintbl)) diff --git a/pygmt/src/text.py b/pygmt/src/text.py index b507510f620..6eda035f025 100644 --- a/pygmt/src/text.py +++ b/pygmt/src/text.py @@ -222,22 +222,24 @@ def text_( # noqa: PLR0912 elif isinstance(arg, int | float | str): kwargs["F"] += f"{flag}{arg}" - extra_arrays = [] confdict = {} + data = None if kind == "empty": + data = {"x": x, "y": y} + for arg, flag, name in array_args: if is_nonstr_iter(arg): kwargs["F"] += flag # angle is numeric type and font/justify are str type. if name == "angle": - extra_arrays.append(arg) + data["angle"] = arg else: - extra_arrays.append(np.asarray(arg, dtype=np.str_)) + data[name] = np.asarray(arg, dtype=np.str_) # If an array of transparency is given, GMT will read it from the last numerical # column per data record. if is_nonstr_iter(kwargs.get("t")): - extra_arrays.append(kwargs["t"]) + data["transparency"] = kwargs["t"] kwargs["t"] = True # Append text to the last column. Text must be passed in as str type. @@ -247,7 +249,7 @@ def text_( # noqa: PLR0912 text, encoding=encoding ) confdict["PS_CHAR_ENCODING"] = encoding - extra_arrays.append(text) + data["text"] = text else: if isinstance(position, str): kwargs["F"] += f"+c{position}+t{text}" @@ -260,10 +262,7 @@ def text_( # noqa: PLR0912 with Session() as lib: with lib.virtualfile_in( check_kind="vector", - data=textfiles, - x=x, - y=y, - extra_arrays=extra_arrays, + data=textfiles or data, required_data=required_data, ) as vintbl: lib.call_module( From ce57c59680365e900c3244db8ef80db61008fe64 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Mon, 3 Mar 2025 15:00:35 +0800 Subject: [PATCH 02/20] Check if a dict of vectors contain None --- pygmt/helpers/utils.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/pygmt/helpers/utils.py b/pygmt/helpers/utils.py index 94c1d8a8901..09201408e20 100644 --- a/pygmt/helpers/utils.py +++ b/pygmt/helpers/utils.py @@ -12,6 +12,7 @@ import time import webbrowser from collections.abc import Iterable, Mapping, Sequence +from itertools import islice from pathlib import Path from typing import Any, Literal @@ -41,7 +42,7 @@ ] -def _validate_data_input( +def _validate_data_input( # noqa: PLR0912 data=None, x=None, y=None, z=None, required_z=False, required_data=True, kind=None ) -> None: """ @@ -143,6 +144,15 @@ def _validate_data_input( raise GMTInvalidInput(msg) if hasattr(data, "data_vars") and len(data.data_vars) < 3: # xr.Dataset raise GMTInvalidInput(msg) + if kind == "vectors" and isinstance(data, dict): + # Iterator over the up-to-3 first elements. + arrays = list(islice(data.values(), 3)) + if len(arrays) < 2 or any(v is None for v in arrays[:2]): # Check x/y + msg = "Must provide x and y." + raise GMTInvalidInput(msg) + if required_z and (len(arrays) < 3 or arrays[2] is None): # Check z + msg = "Must provide x, y, and z." + raise GMTInvalidInput(msg) def _is_printable_ascii(argstr: str) -> bool: From 2cb9295538f7b6e369ab617fe2bd67e190932e32 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Mon, 3 Mar 2025 15:01:08 +0800 Subject: [PATCH 03/20] clib.virtualfile_in: Remove the 'extra_arrays' parameter --- pygmt/clib/session.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pygmt/clib/session.py b/pygmt/clib/session.py index 9ab3e37574e..6b29d6183dd 100644 --- a/pygmt/clib/session.py +++ b/pygmt/clib/session.py @@ -1772,7 +1772,6 @@ def virtualfile_in( x=None, y=None, z=None, - extra_arrays=None, required_z=False, required_data=True, ): @@ -1794,9 +1793,6 @@ def virtualfile_in( data input. x/y/z : 1-D arrays or None x, y, and z columns as numpy arrays. - extra_arrays : list of 1-D arrays - Optional. A list of numpy arrays in addition to x, y, and z. - All of these arrays must be of the same size as the x/y/z arrays. required_z : bool State whether the 'z' column is required. required_data : bool @@ -1879,8 +1875,6 @@ def virtualfile_in( _data = [x, y] if z is not None: _data.append(z) - if extra_arrays: - _data.extend(extra_arrays) case "vectors": if hasattr(data, "items") and not hasattr(data, "to_frame"): # pandas.DataFrame or xarray.Dataset types. From 362e5bd9efa03cb30d4cf150783d4c982eade67f Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Mon, 3 Mar 2025 16:41:21 +0800 Subject: [PATCH 04/20] Figure.plot3d: Add one more test to increase code coverage --- pygmt/tests/test_plot3d.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/pygmt/tests/test_plot3d.py b/pygmt/tests/test_plot3d.py index f3a616d75e6..879b3783b9a 100644 --- a/pygmt/tests/test_plot3d.py +++ b/pygmt/tests/test_plot3d.py @@ -88,6 +88,21 @@ def test_plot3d_fail_1d_array_with_data(data, region): fig.plot3d(style="cc", fill="red", transparency=data[:, 2] * 100, **kwargs) +def test_plot3d_fail_no_data(data, region): + """ + Should raise an exception if data is not enough. + """ + fig = Figure() + with pytest.raises(GMTInvalidInput): + fig.plot3d( + style="c0.2c", x=data[0], y=data[1], region=region, projection="X10c" + ) + with pytest.raises(GMTInvalidInput): + fig.plot3d( + style="c0.2c", data=data, x=data[0], region=region, projection="X10c" + ) + + @pytest.mark.mpl_image_compare def test_plot3d_projection(data, region): """ From 8f72e4cad8a4f61fa71b4ffeaa10bbf406036e24 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Mon, 3 Mar 2025 16:55:14 +0800 Subject: [PATCH 05/20] Clarify that dictionary will be recognized as vectors --- pygmt/clib/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pygmt/clib/session.py b/pygmt/clib/session.py index 6b29d6183dd..8153d4dbbce 100644 --- a/pygmt/clib/session.py +++ b/pygmt/clib/session.py @@ -1877,7 +1877,7 @@ def virtualfile_in( _data.append(z) case "vectors": if hasattr(data, "items") and not hasattr(data, "to_frame"): - # pandas.DataFrame or xarray.Dataset types. + # Dictionary, pandas.DataFrame or xarray.Dataset types. # pandas.Series will be handled below like a 1-D numpy.ndarray. _data = [array for _, array in data.items()] else: From 5d3d308643cc59fa2cf71acf6315255e27edf2c0 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Wed, 5 Mar 2025 16:01:41 +0800 Subject: [PATCH 06/20] Rename 'required_z' to 'required_ncols' in _validate_data_input --- pygmt/helpers/utils.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pygmt/helpers/utils.py b/pygmt/helpers/utils.py index 94c1d8a8901..e14c1c4a15e 100644 --- a/pygmt/helpers/utils.py +++ b/pygmt/helpers/utils.py @@ -42,7 +42,7 @@ def _validate_data_input( - data=None, x=None, y=None, z=None, required_z=False, required_data=True, kind=None + data=None, x=None, y=None, z=None, required_data=True, required_ncols=2, kind=None ) -> None: """ Check if the combination of data/x/y/z is valid. @@ -65,7 +65,7 @@ def _validate_data_input( Traceback (most recent call last): ... pygmt.exceptions.GMTInvalidInput: Must provide both x and y. - >>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6], required_z=True) + >>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6], required_ncols=3) Traceback (most recent call last): ... pygmt.exceptions.GMTInvalidInput: Must provide x, y, and z. @@ -73,13 +73,13 @@ def _validate_data_input( >>> import pandas as pd >>> import xarray as xr >>> data = np.arange(8).reshape((4, 2)) - >>> _validate_data_input(data=data, required_z=True, kind="matrix") + >>> _validate_data_input(data=data, required_ncols=3, kind="matrix") Traceback (most recent call last): ... pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns. >>> _validate_data_input( ... data=pd.DataFrame(data, columns=["x", "y"]), - ... required_z=True, + ... required_ncols=3, ... kind="vectors", ... ) Traceback (most recent call last): @@ -87,7 +87,7 @@ def _validate_data_input( pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns. >>> _validate_data_input( ... data=xr.Dataset(pd.DataFrame(data, columns=["x", "y"])), - ... required_z=True, + ... required_ncols=3, ... kind="vectors", ... ) Traceback (most recent call last): @@ -115,6 +115,7 @@ def _validate_data_input( GMTInvalidInput If the data input is not valid. """ + required_z = required_ncols >= 3 if data is None: # data is None if x is None and y is None: # both x and y are None if required_data: # data is not optional From 406cb39bfd16b5e3b4902a75852f67f0565788d2 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Wed, 5 Mar 2025 16:03:19 +0800 Subject: [PATCH 07/20] Rename 'required_z' to 'required_ncols' in Session.virtualfile_in --- pygmt/clib/session.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pygmt/clib/session.py b/pygmt/clib/session.py index 9ab3e37574e..96378d92767 100644 --- a/pygmt/clib/session.py +++ b/pygmt/clib/session.py @@ -1773,8 +1773,8 @@ def virtualfile_in( y=None, z=None, extra_arrays=None, - required_z=False, required_data=True, + required_ncols=2, ): """ Store any data inside a virtual file. @@ -1797,11 +1797,11 @@ def virtualfile_in( extra_arrays : list of 1-D arrays Optional. A list of numpy arrays in addition to x, y, and z. All of these arrays must be of the same size as the x/y/z arrays. - required_z : bool - State whether the 'z' column is required. required_data : bool Set to True when 'data' is required, or False when dealing with optional virtual files. [Default is True]. + required_ncols + Number of minimum required columns. Returns ------- @@ -1835,8 +1835,8 @@ def virtualfile_in( x=x, y=y, z=z, - required_z=required_z, required_data=required_data, + required_ncols=required_ncols, kind=kind, ) From af189e675e4ef81044ad6474312ac5c18b1f8ba1 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Wed, 5 Mar 2025 16:03:34 +0800 Subject: [PATCH 08/20] Rename 'required_z' to 'required_ncols' in Session.virtualfile_in tests --- pygmt/tests/test_clib_virtualfile_in.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pygmt/tests/test_clib_virtualfile_in.py b/pygmt/tests/test_clib_virtualfile_in.py index 8a43c1dc273..35477bcf845 100644 --- a/pygmt/tests/test_clib_virtualfile_in.py +++ b/pygmt/tests/test_clib_virtualfile_in.py @@ -42,7 +42,7 @@ def test_virtualfile_in_required_z_matrix(array_func, kind): data = array_func(dataframe) with clib.Session() as lib: with lib.virtualfile_in( - data=data, required_z=True, check_kind="vector" + data=data, required_ncols=3, check_kind="vector" ) as vfile: with GMTTempFile() as outfile: lib.call_module("info", [vfile, f"->{outfile.name}"]) @@ -64,7 +64,7 @@ def test_virtualfile_in_required_z_matrix_missing(): data = np.ones((5, 2)) with clib.Session() as lib: with pytest.raises(GMTInvalidInput): - with lib.virtualfile_in(data=data, required_z=True, check_kind="vector"): + with lib.virtualfile_in(data=data, required_ncols=3, check_kind="vector"): pass @@ -91,7 +91,7 @@ def test_virtualfile_in_fail_non_valid_data(data): with clib.Session() as lib: with pytest.raises(GMTInvalidInput): lib.virtualfile_in( - x=variable[0], y=variable[1], z=variable[2], required_z=True + x=variable[0], y=variable[1], z=variable[2], required_ncols=3 ) # Should also fail if given too much data From 2d55e5a37241fa7dbd890d4061c8be8d3fadffb3 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Wed, 5 Mar 2025 16:04:28 +0800 Subject: [PATCH 09/20] Rename 'required_z' to 'required_ncols' in module wrappers --- pygmt/src/blockm.py | 2 +- pygmt/src/contour.py | 2 +- pygmt/src/nearneighbor.py | 2 +- pygmt/src/plot3d.py | 2 +- pygmt/src/project.py | 2 +- pygmt/src/surface.py | 2 +- pygmt/src/triangulate.py | 4 ++-- pygmt/src/wiggle.py | 2 +- pygmt/src/xyz2grd.py | 2 +- 9 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pygmt/src/blockm.py b/pygmt/src/blockm.py index 581167bcd5d..1afbd51d978 100644 --- a/pygmt/src/blockm.py +++ b/pygmt/src/blockm.py @@ -55,7 +55,7 @@ def _blockm( with Session() as lib: with ( lib.virtualfile_in( - check_kind="vector", data=data, x=x, y=y, z=z, required_z=True + check_kind="vector", data=data, x=x, y=y, z=z, required_ncols=3 ) as vintbl, lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl, ): diff --git a/pygmt/src/contour.py b/pygmt/src/contour.py index c5aa26a3b10..a951811a9ed 100644 --- a/pygmt/src/contour.py +++ b/pygmt/src/contour.py @@ -145,7 +145,7 @@ def contour(self, data=None, x=None, y=None, z=None, **kwargs): with Session() as lib: with lib.virtualfile_in( - check_kind="vector", data=data, x=x, y=y, z=z, required_z=True + check_kind="vector", data=data, x=x, y=y, z=z, required_ncols=3 ) as vintbl: lib.call_module( module="contour", args=build_arg_list(kwargs, infile=vintbl) diff --git a/pygmt/src/nearneighbor.py b/pygmt/src/nearneighbor.py index 94cb02bdf68..7a3e0412c01 100644 --- a/pygmt/src/nearneighbor.py +++ b/pygmt/src/nearneighbor.py @@ -141,7 +141,7 @@ def nearneighbor( with Session() as lib: with ( lib.virtualfile_in( - check_kind="vector", data=data, x=x, y=y, z=z, required_z=True + check_kind="vector", data=data, x=x, y=y, z=z, required_ncols=3 ) as vintbl, lib.virtualfile_out(kind="grid", fname=outgrid) as voutgrd, ): diff --git a/pygmt/src/plot3d.py b/pygmt/src/plot3d.py index e8e75382d74..e8b4f9b2372 100644 --- a/pygmt/src/plot3d.py +++ b/pygmt/src/plot3d.py @@ -263,6 +263,6 @@ def plot3d( y=y, z=z, extra_arrays=extra_arrays, - required_z=True, + required_ncols=3, ) as vintbl: lib.call_module(module="plot3d", args=build_arg_list(kwargs, infile=vintbl)) diff --git a/pygmt/src/project.py b/pygmt/src/project.py index a49d5a1ad1f..33aa5dcaac5 100644 --- a/pygmt/src/project.py +++ b/pygmt/src/project.py @@ -245,7 +245,7 @@ def project( x=x, y=y, z=z, - required_z=False, + required_ncols=2, required_data=False, ) as vintbl, lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl, diff --git a/pygmt/src/surface.py b/pygmt/src/surface.py index 4d6a828986e..cab9d313e02 100644 --- a/pygmt/src/surface.py +++ b/pygmt/src/surface.py @@ -155,7 +155,7 @@ def surface( with Session() as lib: with ( lib.virtualfile_in( - check_kind="vector", data=data, x=x, y=y, z=z, required_z=True + check_kind="vector", data=data, x=x, y=y, z=z, required_ncols=3 ) as vintbl, lib.virtualfile_out(kind="grid", fname=outgrid) as voutgrd, ): diff --git a/pygmt/src/triangulate.py b/pygmt/src/triangulate.py index 1fa60d6a301..19718841f8f 100644 --- a/pygmt/src/triangulate.py +++ b/pygmt/src/triangulate.py @@ -138,7 +138,7 @@ def regular_grid( with Session() as lib: with ( lib.virtualfile_in( - check_kind="vector", data=data, x=x, y=y, z=z, required_z=False + check_kind="vector", data=data, x=x, y=y, z=z, required_ncols=2 ) as vintbl, lib.virtualfile_out(kind="grid", fname=outgrid) as voutgrd, ): @@ -238,7 +238,7 @@ def delaunay_triples( with Session() as lib: with ( lib.virtualfile_in( - check_kind="vector", data=data, x=x, y=y, z=z, required_z=False + check_kind="vector", data=data, x=x, y=y, z=z, required_ncols=2 ) as vintbl, lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl, ): diff --git a/pygmt/src/wiggle.py b/pygmt/src/wiggle.py index 921c5317349..917cd753167 100644 --- a/pygmt/src/wiggle.py +++ b/pygmt/src/wiggle.py @@ -108,6 +108,6 @@ def wiggle( with Session() as lib: with lib.virtualfile_in( - check_kind="vector", data=data, x=x, y=y, z=z, required_z=True + check_kind="vector", data=data, x=x, y=y, z=z, required_ncols=3 ) as vintbl: lib.call_module(module="wiggle", args=build_arg_list(kwargs, infile=vintbl)) diff --git a/pygmt/src/xyz2grd.py b/pygmt/src/xyz2grd.py index ca7c9e94c6b..e87899c6129 100644 --- a/pygmt/src/xyz2grd.py +++ b/pygmt/src/xyz2grd.py @@ -149,7 +149,7 @@ def xyz2grd( with Session() as lib: with ( lib.virtualfile_in( - check_kind="vector", data=data, x=x, y=y, z=z, required_z=True + check_kind="vector", data=data, x=x, y=y, z=z, required_ncols=3 ) as vintbl, lib.virtualfile_out(kind="grid", fname=outgrid) as voutgrd, ): From c199dbf88712277eaf8b5937de9bd5c29156bb29 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Thu, 6 Mar 2025 10:09:08 +0800 Subject: [PATCH 10/20] Clarify that the data parameter can accept dicts --- pygmt/clib/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pygmt/clib/session.py b/pygmt/clib/session.py index 8153d4dbbce..4556c0d9ce3 100644 --- a/pygmt/clib/session.py +++ b/pygmt/clib/session.py @@ -1787,7 +1787,7 @@ def virtualfile_in( check_kind : str or None Used to validate the type of data that can be passed in. Choose from 'raster', 'vector', or None. Default is None (no validation). - data : str or pathlib.Path or xarray.DataArray or {table-like} or None + data : str or pathlib.Path or xarray.DataArray or {table-like} or dict | None Any raster or vector data format. This could be a file name or path, a raster grid, a vector matrix/arrays, or other supported data input. From 1d0d0dcff12b90e260ac6d3f86d5d29d2791dded Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Mon, 10 Mar 2025 15:52:16 +0800 Subject: [PATCH 11/20] Fix a typo --- pygmt/clib/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pygmt/clib/session.py b/pygmt/clib/session.py index 4556c0d9ce3..e8f6299885b 100644 --- a/pygmt/clib/session.py +++ b/pygmt/clib/session.py @@ -1787,7 +1787,7 @@ def virtualfile_in( check_kind : str or None Used to validate the type of data that can be passed in. Choose from 'raster', 'vector', or None. Default is None (no validation). - data : str or pathlib.Path or xarray.DataArray or {table-like} or dict | None + data : str or pathlib.Path or xarray.DataArray or {table-like} or dict or None Any raster or vector data format. This could be a file name or path, a raster grid, a vector matrix/arrays, or other supported data input. From 5eeb37bcb3117560efef27ee59418065cb1a76e0 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Mon, 10 Mar 2025 16:08:06 +0800 Subject: [PATCH 12/20] Rename required_ncols to the shorter ncols --- pygmt/clib/session.py | 6 +++--- pygmt/helpers/utils.py | 12 ++++++------ pygmt/src/blockm.py | 2 +- pygmt/src/contour.py | 2 +- pygmt/src/nearneighbor.py | 2 +- pygmt/src/plot3d.py | 2 +- pygmt/src/project.py | 2 +- pygmt/src/surface.py | 2 +- pygmt/src/triangulate.py | 4 ++-- pygmt/src/wiggle.py | 2 +- pygmt/src/xyz2grd.py | 2 +- pygmt/tests/test_clib_virtualfile_in.py | 10 +++------- 12 files changed, 22 insertions(+), 26 deletions(-) diff --git a/pygmt/clib/session.py b/pygmt/clib/session.py index 96378d92767..d40546cfdd1 100644 --- a/pygmt/clib/session.py +++ b/pygmt/clib/session.py @@ -1774,7 +1774,7 @@ def virtualfile_in( z=None, extra_arrays=None, required_data=True, - required_ncols=2, + ncols=2, ): """ Store any data inside a virtual file. @@ -1800,7 +1800,7 @@ def virtualfile_in( required_data : bool Set to True when 'data' is required, or False when dealing with optional virtual files. [Default is True]. - required_ncols + ncols Number of minimum required columns. Returns @@ -1836,7 +1836,7 @@ def virtualfile_in( y=y, z=z, required_data=required_data, - required_ncols=required_ncols, + ncols=ncols, kind=kind, ) diff --git a/pygmt/helpers/utils.py b/pygmt/helpers/utils.py index e14c1c4a15e..c8a37a24a5e 100644 --- a/pygmt/helpers/utils.py +++ b/pygmt/helpers/utils.py @@ -42,7 +42,7 @@ def _validate_data_input( - data=None, x=None, y=None, z=None, required_data=True, required_ncols=2, kind=None + data=None, x=None, y=None, z=None, required_data=True, ncols=2, kind=None ) -> None: """ Check if the combination of data/x/y/z is valid. @@ -65,7 +65,7 @@ def _validate_data_input( Traceback (most recent call last): ... pygmt.exceptions.GMTInvalidInput: Must provide both x and y. - >>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6], required_ncols=3) + >>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6], ncols=3) Traceback (most recent call last): ... pygmt.exceptions.GMTInvalidInput: Must provide x, y, and z. @@ -73,13 +73,13 @@ def _validate_data_input( >>> import pandas as pd >>> import xarray as xr >>> data = np.arange(8).reshape((4, 2)) - >>> _validate_data_input(data=data, required_ncols=3, kind="matrix") + >>> _validate_data_input(data=data, ncols=3, kind="matrix") Traceback (most recent call last): ... pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns. >>> _validate_data_input( ... data=pd.DataFrame(data, columns=["x", "y"]), - ... required_ncols=3, + ... ncols=3, ... kind="vectors", ... ) Traceback (most recent call last): @@ -87,7 +87,7 @@ def _validate_data_input( pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns. >>> _validate_data_input( ... data=xr.Dataset(pd.DataFrame(data, columns=["x", "y"])), - ... required_ncols=3, + ... ncols=3, ... kind="vectors", ... ) Traceback (most recent call last): @@ -115,7 +115,7 @@ def _validate_data_input( GMTInvalidInput If the data input is not valid. """ - required_z = required_ncols >= 3 + required_z = ncols >= 3 if data is None: # data is None if x is None and y is None: # both x and y are None if required_data: # data is not optional diff --git a/pygmt/src/blockm.py b/pygmt/src/blockm.py index 1afbd51d978..1d3d1dd099c 100644 --- a/pygmt/src/blockm.py +++ b/pygmt/src/blockm.py @@ -55,7 +55,7 @@ def _blockm( with Session() as lib: with ( lib.virtualfile_in( - check_kind="vector", data=data, x=x, y=y, z=z, required_ncols=3 + check_kind="vector", data=data, x=x, y=y, z=z, ncols=3 ) as vintbl, lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl, ): diff --git a/pygmt/src/contour.py b/pygmt/src/contour.py index a951811a9ed..d7ac6ab75b0 100644 --- a/pygmt/src/contour.py +++ b/pygmt/src/contour.py @@ -145,7 +145,7 @@ def contour(self, data=None, x=None, y=None, z=None, **kwargs): with Session() as lib: with lib.virtualfile_in( - check_kind="vector", data=data, x=x, y=y, z=z, required_ncols=3 + check_kind="vector", data=data, x=x, y=y, z=z, ncols=3 ) as vintbl: lib.call_module( module="contour", args=build_arg_list(kwargs, infile=vintbl) diff --git a/pygmt/src/nearneighbor.py b/pygmt/src/nearneighbor.py index 7a3e0412c01..6d25ecd65eb 100644 --- a/pygmt/src/nearneighbor.py +++ b/pygmt/src/nearneighbor.py @@ -141,7 +141,7 @@ def nearneighbor( with Session() as lib: with ( lib.virtualfile_in( - check_kind="vector", data=data, x=x, y=y, z=z, required_ncols=3 + check_kind="vector", data=data, x=x, y=y, z=z, ncols=3 ) as vintbl, lib.virtualfile_out(kind="grid", fname=outgrid) as voutgrd, ): diff --git a/pygmt/src/plot3d.py b/pygmt/src/plot3d.py index e8b4f9b2372..1d7f2b4357a 100644 --- a/pygmt/src/plot3d.py +++ b/pygmt/src/plot3d.py @@ -263,6 +263,6 @@ def plot3d( y=y, z=z, extra_arrays=extra_arrays, - required_ncols=3, + ncols=3, ) as vintbl: lib.call_module(module="plot3d", args=build_arg_list(kwargs, infile=vintbl)) diff --git a/pygmt/src/project.py b/pygmt/src/project.py index 33aa5dcaac5..7b3afc25127 100644 --- a/pygmt/src/project.py +++ b/pygmt/src/project.py @@ -245,7 +245,7 @@ def project( x=x, y=y, z=z, - required_ncols=2, + ncols=2, required_data=False, ) as vintbl, lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl, diff --git a/pygmt/src/surface.py b/pygmt/src/surface.py index cab9d313e02..11d13506efb 100644 --- a/pygmt/src/surface.py +++ b/pygmt/src/surface.py @@ -155,7 +155,7 @@ def surface( with Session() as lib: with ( lib.virtualfile_in( - check_kind="vector", data=data, x=x, y=y, z=z, required_ncols=3 + check_kind="vector", data=data, x=x, y=y, z=z, ncols=3 ) as vintbl, lib.virtualfile_out(kind="grid", fname=outgrid) as voutgrd, ): diff --git a/pygmt/src/triangulate.py b/pygmt/src/triangulate.py index 19718841f8f..1daf477a7d4 100644 --- a/pygmt/src/triangulate.py +++ b/pygmt/src/triangulate.py @@ -138,7 +138,7 @@ def regular_grid( with Session() as lib: with ( lib.virtualfile_in( - check_kind="vector", data=data, x=x, y=y, z=z, required_ncols=2 + check_kind="vector", data=data, x=x, y=y, z=z, ncols=2 ) as vintbl, lib.virtualfile_out(kind="grid", fname=outgrid) as voutgrd, ): @@ -238,7 +238,7 @@ def delaunay_triples( with Session() as lib: with ( lib.virtualfile_in( - check_kind="vector", data=data, x=x, y=y, z=z, required_ncols=2 + check_kind="vector", data=data, x=x, y=y, z=z, ncols=2 ) as vintbl, lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl, ): diff --git a/pygmt/src/wiggle.py b/pygmt/src/wiggle.py index 917cd753167..4b28d5fd640 100644 --- a/pygmt/src/wiggle.py +++ b/pygmt/src/wiggle.py @@ -108,6 +108,6 @@ def wiggle( with Session() as lib: with lib.virtualfile_in( - check_kind="vector", data=data, x=x, y=y, z=z, required_ncols=3 + check_kind="vector", data=data, x=x, y=y, z=z, ncols=3 ) as vintbl: lib.call_module(module="wiggle", args=build_arg_list(kwargs, infile=vintbl)) diff --git a/pygmt/src/xyz2grd.py b/pygmt/src/xyz2grd.py index e87899c6129..7b5b355b3f6 100644 --- a/pygmt/src/xyz2grd.py +++ b/pygmt/src/xyz2grd.py @@ -149,7 +149,7 @@ def xyz2grd( with Session() as lib: with ( lib.virtualfile_in( - check_kind="vector", data=data, x=x, y=y, z=z, required_ncols=3 + check_kind="vector", data=data, x=x, y=y, z=z, ncols=3 ) as vintbl, lib.virtualfile_out(kind="grid", fname=outgrid) as voutgrd, ): diff --git a/pygmt/tests/test_clib_virtualfile_in.py b/pygmt/tests/test_clib_virtualfile_in.py index 35477bcf845..1f97abe3bb2 100644 --- a/pygmt/tests/test_clib_virtualfile_in.py +++ b/pygmt/tests/test_clib_virtualfile_in.py @@ -41,9 +41,7 @@ def test_virtualfile_in_required_z_matrix(array_func, kind): ) data = array_func(dataframe) with clib.Session() as lib: - with lib.virtualfile_in( - data=data, required_ncols=3, check_kind="vector" - ) as vfile: + with lib.virtualfile_in(data=data, ncols=3, check_kind="vector") as vfile: with GMTTempFile() as outfile: lib.call_module("info", [vfile, f"->{outfile.name}"]) output = outfile.read(keep_tabs=True) @@ -64,7 +62,7 @@ def test_virtualfile_in_required_z_matrix_missing(): data = np.ones((5, 2)) with clib.Session() as lib: with pytest.raises(GMTInvalidInput): - with lib.virtualfile_in(data=data, required_ncols=3, check_kind="vector"): + with lib.virtualfile_in(data=data, ncols=3, check_kind="vector"): pass @@ -90,9 +88,7 @@ def test_virtualfile_in_fail_non_valid_data(data): continue with clib.Session() as lib: with pytest.raises(GMTInvalidInput): - lib.virtualfile_in( - x=variable[0], y=variable[1], z=variable[2], required_ncols=3 - ) + lib.virtualfile_in(x=variable[0], y=variable[1], z=variable[2], ncols=3) # Should also fail if given too much data with clib.Session() as lib: From e44e712115359a0e2211ff810c4d975679f97dfb Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Wed, 26 Mar 2025 16:02:18 +0800 Subject: [PATCH 13/20] Improve docstrings for the new test --- pygmt/tests/test_plot3d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pygmt/tests/test_plot3d.py b/pygmt/tests/test_plot3d.py index 879b3783b9a..a1f4e306d41 100644 --- a/pygmt/tests/test_plot3d.py +++ b/pygmt/tests/test_plot3d.py @@ -90,7 +90,7 @@ def test_plot3d_fail_1d_array_with_data(data, region): def test_plot3d_fail_no_data(data, region): """ - Should raise an exception if data is not enough. + Should raise an exception if data is not enough or too much. """ fig = Figure() with pytest.raises(GMTInvalidInput): From 5330d0ac4c388bd107877582fefafc353219b4c1 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Wed, 26 Mar 2025 16:32:24 +0800 Subject: [PATCH 14/20] Raise warnings when extra_arrays is used --- pygmt/clib/session.py | 16 ++++++++++++++ pygmt/tests/test_clib_virtualfile_in.py | 28 +++++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/pygmt/clib/session.py b/pygmt/clib/session.py index e8f6299885b..03f3843ee75 100644 --- a/pygmt/clib/session.py +++ b/pygmt/clib/session.py @@ -1774,6 +1774,7 @@ def virtualfile_in( z=None, required_z=False, required_data=True, + extra_arrays=None, ): """ Store any data inside a virtual file. @@ -1798,6 +1799,13 @@ def virtualfile_in( required_data : bool Set to True when 'data' is required, or False when dealing with optional virtual files. [Default is True]. + extra_arrays : list of 1-D arrays + Optional. A list of numpy arrays in addition to x, y, and z. All of these + arrays must be of the same size as the x/y/z arrays. + + .. deprecated:: v0.16.0 + The parameter 'extra_arrays' will be removed in v0.20.0. Prepare and pass + a dictionary of arrays instead. E.g., `{"x": x, "y": y, "size": size}`. Returns ------- @@ -1875,6 +1883,14 @@ def virtualfile_in( _data = [x, y] if z is not None: _data.append(z) + if extra_arrays: + msg = ( + "The parameter 'extra_arrays' will be removed in v0.20.0. " + "Prepare and pass a dictionary of arrays instead. E.g., " + "`{'x': x, 'y': y, 'size': size}`." + ) + warnings.warn(message=msg, category=FutureWarning, stacklevel=1) + _data.extend(extra_arrays) case "vectors": if hasattr(data, "items") and not hasattr(data, "to_frame"): # Dictionary, pandas.DataFrame or xarray.Dataset types. diff --git a/pygmt/tests/test_clib_virtualfile_in.py b/pygmt/tests/test_clib_virtualfile_in.py index 8a43c1dc273..dd885f56f0f 100644 --- a/pygmt/tests/test_clib_virtualfile_in.py +++ b/pygmt/tests/test_clib_virtualfile_in.py @@ -128,3 +128,31 @@ def test_virtualfile_in_matrix_string_dtype(): assert output == "347.5 348.5 -30.5 -30\n" # Should check that lib.virtualfile_from_vectors is called once, # not lib.virtualfile_from_matrix, but it's technically complicated. + + +def test_virtualfile_in_extra_arrays(data): + """ + Test that the extra_arrays parameter is deprecated. + """ + with clib.Session() as lib: + # Call the method twice to ensure only one statement in the with block. + # Test that a FutureWarning is raised when extra_arrays is used. + with pytest.warns(FutureWarning): + with lib.virtualfile_in( + check_kind="vector", + x=data[:, 0], + y=data[:, 1], + extra_arrays=[data[:, 2]], + ) as vfile: + pass + # Test that the output is correct. + with GMTTempFile() as outfile: + with lib.virtualfile_in( + check_kind="vector", + x=data[:, 0], + y=data[:, 1], + extra_arrays=[data[:, 2]], + ) as vfile: + lib.call_module("info", [vfile, "-C", f"->{outfile.name}"]) + output = outfile.read(keep_tabs=False) + assert output == "11.5309 61.7074 -2.9289 7.8648 0.1412 0.9338\n" From 0a63ef0a34ffccd920e25aa0ce747e9f602e5200 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Wed, 26 Mar 2025 16:43:23 +0800 Subject: [PATCH 15/20] Add TODO comments --- pygmt/clib/session.py | 1 + pygmt/tests/test_clib_virtualfile_in.py | 1 + 2 files changed, 2 insertions(+) diff --git a/pygmt/clib/session.py b/pygmt/clib/session.py index 03f3843ee75..572e318f895 100644 --- a/pygmt/clib/session.py +++ b/pygmt/clib/session.py @@ -1883,6 +1883,7 @@ def virtualfile_in( _data = [x, y] if z is not None: _data.append(z) + # TODO(PyGMT>=0.20.0): Remove the deprecated parameter 'extra_arrays'. if extra_arrays: msg = ( "The parameter 'extra_arrays' will be removed in v0.20.0. " diff --git a/pygmt/tests/test_clib_virtualfile_in.py b/pygmt/tests/test_clib_virtualfile_in.py index dd885f56f0f..bf7c54b1bba 100644 --- a/pygmt/tests/test_clib_virtualfile_in.py +++ b/pygmt/tests/test_clib_virtualfile_in.py @@ -130,6 +130,7 @@ def test_virtualfile_in_matrix_string_dtype(): # not lib.virtualfile_from_matrix, but it's technically complicated. +# TODO(PyGMT>=0.20.0): Remove the test related to deprecated parameter 'extra_arrays'. def test_virtualfile_in_extra_arrays(data): """ Test that the extra_arrays parameter is deprecated. From 0e0c1d8e212aef35ba567faa7b2c52603e118847 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Wed, 26 Mar 2025 16:59:19 +0800 Subject: [PATCH 16/20] Deprecate required_z in a backward-compatible way --- pygmt/clib/session.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/pygmt/clib/session.py b/pygmt/clib/session.py index d40546cfdd1..bc0a7de836f 100644 --- a/pygmt/clib/session.py +++ b/pygmt/clib/session.py @@ -1765,7 +1765,7 @@ def virtualfile_from_stringio( seg.header = None seg.text = None - def virtualfile_in( + def virtualfile_in( # noqa: PLR0912 self, check_kind=None, data=None, @@ -1775,6 +1775,7 @@ def virtualfile_in( extra_arrays=None, required_data=True, ncols=2, + required_z=False, ): """ Store any data inside a virtual file. @@ -1802,6 +1803,12 @@ def virtualfile_in( optional virtual files. [Default is True]. ncols Number of minimum required columns. + required_z : bool + State whether the 'z' column is required. + + .. deprecated:: v0.16.0 + The parameter 'required_z' will be removed in v0.20.0. Use parameter + 'ncols' instead. E.g., ``required_z=True`` is equivalent to ``ncols=3``. Returns ------- @@ -1829,6 +1836,17 @@ def virtualfile_in( ... print(fout.read().strip()) : N = 3 <7/9> <4/6> <1/3> """ + # TODO(PyGMT>=0.20.0): Remove the deprecated 'required_z' parameter. + if required_z is True: + warnings.warn( + "The parameter 'required_z' is deprecated and will be removed in " + "v0.20.0. Use parameter 'ncols' instead. E.g., ``required_z=True`` is " + "equivalent to ``ncols=3``.", + category=FutureWarning, + stacklevel=1, + ) + ncols = 3 + kind = data_kind(data, required=required_data) _validate_data_input( data=data, From f8293cdf2c63b2fe1a254a8ad52d9c9c3db6b7b9 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Wed, 26 Mar 2025 17:11:24 +0800 Subject: [PATCH 17/20] Add one test with the deprecated 'required_z' parameter to increase code coverage --- pygmt/tests/test_clib_virtualfile_in.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/pygmt/tests/test_clib_virtualfile_in.py b/pygmt/tests/test_clib_virtualfile_in.py index 1f97abe3bb2..c1616565709 100644 --- a/pygmt/tests/test_clib_virtualfile_in.py +++ b/pygmt/tests/test_clib_virtualfile_in.py @@ -66,6 +66,22 @@ def test_virtualfile_in_required_z_matrix_missing(): pass +# TODO(PyGMT>=0.20.0): Remove this test for the deprecated 'required_z' parameter. +def test_virtualfile_in_required_z_deprecated(): + """ + Same as test_virtualfile_in_required_z_matrix_missing but using the deprecated + 'required_z' parameter. + """ + data = np.ones((5, 2)) + with clib.Session() as lib: + with pytest.raises(GMTInvalidInput): # noqa: PT012 + with pytest.warns(FutureWarning): + with lib.virtualfile_in( + data=data, required_z=True, check_kind="vector" + ): + pass + + def test_virtualfile_in_fail_non_valid_data(data): """ Should raise an exception if too few or too much data is given. From 60f17582e9aeb2d8c11764459fe059cc3d1c9757 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Fri, 21 Feb 2025 23:46:59 +0800 Subject: [PATCH 18/20] Refactor _validate_data_input --- pygmt/clib/session.py | 33 ++++++----- pygmt/helpers/utils.py | 122 ++++++++++++++++------------------------- 2 files changed, 67 insertions(+), 88 deletions(-) diff --git a/pygmt/clib/session.py b/pygmt/clib/session.py index 572e318f895..f01d9d606e6 100644 --- a/pygmt/clib/session.py +++ b/pygmt/clib/session.py @@ -1765,7 +1765,7 @@ def virtualfile_from_stringio( seg.header = None seg.text = None - def virtualfile_in( + def virtualfile_in( # noqa: PLR0912 self, check_kind=None, data=None, @@ -1833,23 +1833,25 @@ def virtualfile_in( ... print(fout.read().strip()) : N = 3 <7/9> <4/6> <1/3> """ + # Specify either data or x/y/z. + if data is not None and any(v is not None for v in (x, y, z)): + msg = "Too much data. Use either data or x/y/z." + raise GMTInvalidInput(msg) + + # Determine the kind of data. kind = data_kind(data, required=required_data) - _validate_data_input( - data=data, - x=x, - y=y, - z=z, - required_z=required_z, - required_data=required_data, - kind=kind, - ) + # Check if the kind of data is valid. if check_kind: valid_kinds = ("file", "arg") if required_data is False else ("file",) - if check_kind == "raster": - valid_kinds += ("grid", "image") - elif check_kind == "vector": - valid_kinds += ("empty", "matrix", "vectors", "geojson") + match check_kind: + case "raster": + valid_kinds += ("grid", "image") + case "vector": + valid_kinds += ("empty", "matrix", "vectors", "geojson") + case _: + msg = f"Invalid value for check_kind: '{check_kind}'." + raise GMTInvalidInput(msg) if kind not in valid_kinds: msg = f"Unrecognized data type for {check_kind}: {type(data)}." raise GMTInvalidInput(msg) @@ -1909,6 +1911,9 @@ def virtualfile_in( _virtualfile_from = self.virtualfile_from_vectors _data = data.T + # Check if _data to be passed to the virtualfile_from_ function is valid. + _validate_data_input(data=_data, kind=kind, required_z=required_z) + # Finally create the virtualfile from the data, to be passed into GMT file_context = _virtualfile_from(_data) return file_context diff --git a/pygmt/helpers/utils.py b/pygmt/helpers/utils.py index 09201408e20..75d8b705d2b 100644 --- a/pygmt/helpers/utils.py +++ b/pygmt/helpers/utils.py @@ -12,7 +12,6 @@ import time import webbrowser from collections.abc import Iterable, Mapping, Sequence -from itertools import islice from pathlib import Path from typing import Any, Literal @@ -40,118 +39,97 @@ "ISO-8859-15", "ISO-8859-16", ] +# Type hints for the list of possible data kinds. +Kind = Literal[ + "arg", "empty", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors" +] -def _validate_data_input( # noqa: PLR0912 - data=None, x=None, y=None, z=None, required_z=False, required_data=True, kind=None -) -> None: +def _validate_data_input(data: Any, kind: Kind, required_z: bool = False) -> None: """ - Check if the combination of data/x/y/z is valid. + Check if the data to be passed to the virtualfile_from_ functions is valid. Examples -------- - >>> _validate_data_input(data="infile") - >>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6]) - >>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6], z=[7, 8, 9]) - >>> _validate_data_input(data=None, required_data=False) - >>> _validate_data_input() + The "empty" kind means the data is given via a series of vectors like x/y/z. + + >>> _validate_data_input(data=[[1, 2, 3], [4, 5, 6]], kind="empty") + >>> _validate_data_input(data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]], kind="empty") + >>> _validate_data_input(data=[None, [4, 5, 6]], kind="empty") Traceback (most recent call last): ... - pygmt.exceptions.GMTInvalidInput: No input data provided. - >>> _validate_data_input(x=[1, 2, 3]) + pygmt.exceptions.GMTInvalidInput: Must provide both x and y. + >>> _validate_data_input(data=[[1, 2, 3], None], kind="empty") Traceback (most recent call last): ... pygmt.exceptions.GMTInvalidInput: Must provide both x and y. - >>> _validate_data_input(y=[4, 5, 6]) + >>> _validate_data_input(data=[None, None], kind="empty") Traceback (most recent call last): ... pygmt.exceptions.GMTInvalidInput: Must provide both x and y. - >>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6], required_z=True) + >>> _validate_data_input(data=[[1, 2, 3], [4, 5, 6]], kind="empty", required_z=True) Traceback (most recent call last): ... pygmt.exceptions.GMTInvalidInput: Must provide x, y, and z. + + The "matrix" kind means the data is given via a 2-D numpy.ndarray. + >>> import numpy as np >>> import pandas as pd >>> import xarray as xr >>> data = np.arange(8).reshape((4, 2)) - >>> _validate_data_input(data=data, required_z=True, kind="matrix") + >>> _validate_data_input(data=data, kind="matrix", required_z=True) Traceback (most recent call last): ... - pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns. + pygmt.exceptions.GMTInvalidInput: Need at least 3 columns but 2 column(s) are given. + + The "vectors" kind means the original data is either dictionary, list, tuple, + pandas.DataFrame, pandas.Series, xarray.Dataset, or xarray.DataArray. + >>> _validate_data_input( ... data=pd.DataFrame(data, columns=["x", "y"]), - ... required_z=True, ... kind="vectors", + ... required_z=True, ... ) Traceback (most recent call last): ... - pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns. + pygmt.exceptions.GMTInvalidInput: Need at least 3 columns but 2 column(s) are given. >>> _validate_data_input( ... data=xr.Dataset(pd.DataFrame(data, columns=["x", "y"])), - ... required_z=True, ... kind="vectors", + ... required_z=True, ... ) Traceback (most recent call last): ... - pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns. - >>> _validate_data_input(data="infile", x=[1, 2, 3]) - Traceback (most recent call last): - ... - pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z. - >>> _validate_data_input(data="infile", y=[4, 5, 6]) - Traceback (most recent call last): - ... - pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z. - >>> _validate_data_input(data="infile", x=[1, 2, 3], y=[4, 5, 6]) - Traceback (most recent call last): - ... - pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z. - >>> _validate_data_input(data="infile", z=[7, 8, 9]) - Traceback (most recent call last): - ... - pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z. + pygmt.exceptions.GMTInvalidInput: Need at least 3 columns but 2 column(s) are given. Raises ------ GMTInvalidInput If the data input is not valid. """ - if data is None: # data is None - if x is None and y is None: # both x and y are None - if required_data: # data is not optional - msg = "No input data provided." + # Determine the required number of columns based on the required_z flag. + required_cols = 3 if required_z else 1 + + match kind: + case "empty": # data = [x, y], [x, y, z], [x, y, z, ...] + if len(data) < 2 or any(v is None for v in data[:2]): + msg = "Must provide both x and y." raise GMTInvalidInput(msg) - elif x is None or y is None: # either x or y is None - msg = "Must provide both x and y." - raise GMTInvalidInput(msg) - if required_z and z is None: # both x and y are not None, now check z - msg = "Must provide x, y, and z." - raise GMTInvalidInput(msg) - else: # data is not None - if x is not None or y is not None or z is not None: - msg = "Too much data. Use either data or x/y/z." - raise GMTInvalidInput(msg) - # check if data has the required z column - if required_z: - msg = "data must provide x, y, and z columns." - if kind == "matrix" and data.shape[1] < 3: + if required_z and (len(data) < 3 or data[:3] is None): + msg = "Must provide x, y, and z." raise GMTInvalidInput(msg) - if kind == "vectors": - if hasattr(data, "shape") and ( - (len(data.shape) == 1 and data.shape[0] < 3) - or (len(data.shape) > 1 and data.shape[1] < 3) - ): # np.ndarray or pd.DataFrame - raise GMTInvalidInput(msg) - if hasattr(data, "data_vars") and len(data.data_vars) < 3: # xr.Dataset - raise GMTInvalidInput(msg) - if kind == "vectors" and isinstance(data, dict): - # Iterator over the up-to-3 first elements. - arrays = list(islice(data.values(), 3)) - if len(arrays) < 2 or any(v is None for v in arrays[:2]): # Check x/y - msg = "Must provide x and y." + case "matrix": # 2-D numpy.ndarray + if (actual_cols := data.shape[1]) < required_cols: + msg = f"Need at least {required_cols} columns but {actual_cols} column(s) are given." raise GMTInvalidInput(msg) - if required_z and (len(arrays) < 3 or arrays[2] is None): # Check z - msg = "Must provide x, y, and z." + case "vectors": + # "vectors" means the original data is either dictionary, list, tuple, + # pandas.DataFrame, pandas.Series, xarray.Dataset, or xarray.DataArray. + # The original data is converted to a list of vectors or a 2-D numpy.ndarray + # in the virtualfile_in function. + if (actual_cols := len(data)) < required_cols: + msg = f"Need at least {required_cols} columns but {actual_cols} column(s) are given." raise GMTInvalidInput(msg) @@ -271,11 +249,7 @@ def _check_encoding(argstr: str) -> Encoding: return "ISOLatin1+" -def data_kind( - data: Any, required: bool = True -) -> Literal[ - "arg", "empty", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors" -]: +def data_kind(data: Any, required: bool = True) -> Kind: r""" Check the kind of data that is provided to a module. From b7ca239f84f373490d600cae6bad55f35b3808b5 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Fri, 28 Mar 2025 21:43:29 +0800 Subject: [PATCH 19/20] Move TODO comment to the top --- pygmt/clib/session.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pygmt/clib/session.py b/pygmt/clib/session.py index bc0a7de836f..c8352421c8a 100644 --- a/pygmt/clib/session.py +++ b/pygmt/clib/session.py @@ -1765,6 +1765,7 @@ def virtualfile_from_stringio( seg.header = None seg.text = None + # TODO(PyGMT>=0.20.0): Remove the deprecated parameter 'required_z'. def virtualfile_in( # noqa: PLR0912 self, check_kind=None, @@ -1836,12 +1837,11 @@ def virtualfile_in( # noqa: PLR0912 ... print(fout.read().strip()) : N = 3 <7/9> <4/6> <1/3> """ - # TODO(PyGMT>=0.20.0): Remove the deprecated 'required_z' parameter. if required_z is True: warnings.warn( - "The parameter 'required_z' is deprecated and will be removed in " - "v0.20.0. Use parameter 'ncols' instead. E.g., ``required_z=True`` is " - "equivalent to ``ncols=3``.", + "The parameter 'required_z' is deprecated in v0.16.0 and will be " + "removed in v0.20.0. Use parameter 'ncols' instead. E.g., " + "``required_z=True`` is equivalent to ``ncols=3``.", category=FutureWarning, stacklevel=1, ) From da874d107b628a3f1c38a2b5e910720e0c354651 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Tue, 15 Apr 2025 10:53:42 +0800 Subject: [PATCH 20/20] Put ncols before required_data --- pygmt/clib/session.py | 2 +- pygmt/helpers/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pygmt/clib/session.py b/pygmt/clib/session.py index a53c152c5bc..f5b2c26d684 100644 --- a/pygmt/clib/session.py +++ b/pygmt/clib/session.py @@ -1842,8 +1842,8 @@ def virtualfile_in( # noqa: PLR0912 x=x, y=y, z=z, - required_data=required_data, ncols=ncols, + required_data=required_data, kind=kind, ) diff --git a/pygmt/helpers/utils.py b/pygmt/helpers/utils.py index 4be83b10ced..5c4fa66c614 100644 --- a/pygmt/helpers/utils.py +++ b/pygmt/helpers/utils.py @@ -43,7 +43,7 @@ def _validate_data_input( # noqa: PLR0912 - data=None, x=None, y=None, z=None, required_data=True, ncols=2, kind=None + data=None, x=None, y=None, z=None, ncols=2, required_data=True, kind=None ) -> None: """ Check if the combination of data/x/y/z is valid.