Skip to content

Commit

Permalink
Ensure consistent string type for geos in InputData and fix tests
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 709908398
  • Loading branch information
andyl7an authored and The Meridian Authors committed Dec 27, 2024
1 parent b7027a1 commit fa61ba6
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
-->

## [Unreleased]
* Convert `InputData` geo coordinates to strings upon initialization to avoid
type mismatches with `GeoInfo` proto which expects strings.

* Add `get_historical_spend` method to `Analyzer` class.

## [0.14.0] - 2024-12-17
Expand Down
10 changes: 9 additions & 1 deletion meridian/data/input_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ class InputData:
non_media_treatments: xr.DataArray | None = None

def __post_init__(self):
self._convert_geos_to_strings()
self._validate_kpi()
self._validate_scenarios()
self._validate_names()
Expand All @@ -241,10 +242,17 @@ def __post_init__(self):
self._validate_time_formats()
self._validate_times()

def _convert_geos_to_strings(self):
"""Converts geo coordinates to strings in all relevant DataArrays."""
for field in dataclasses.fields(self):
array = getattr(self, field.name)
if isinstance(array, xr.DataArray) and constants.GEO in array.dims:
array.coords[constants.GEO] = array.coords[constants.GEO].astype(str)

@property
def geo(self) -> xr.DataArray:
"""Returns the geo dimension."""
return self.kpi[constants.GEO].astype(str)
return self.kpi[constants.GEO]

@property
def time(self) -> xr.DataArray:
Expand Down
1 change: 1 addition & 0 deletions meridian/data/input_data_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,7 @@ def test_geo_property_returns_strings(self):
data.geo.values.tolist(),
["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],
)
self.assertIn("1", data.population.coords[constants.GEO].values)

def test_properties_media_and_rf(self):
data = input_data.InputData(
Expand Down

0 comments on commit fa61ba6

Please sign in to comment.