Skip to content

Commit

Permalink
Initial work on supporting units in the scatter viewer
Browse files Browse the repository at this point in the history
  • Loading branch information
astrofrog committed Aug 12, 2024
1 parent 6cb7094 commit 89e6aab
Show file tree
Hide file tree
Showing 3 changed files with 287 additions and 6 deletions.
88 changes: 87 additions & 1 deletion glue/viewers/scatter/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from glue.core import BaseData, Subset
from glue.core import BaseData, Subset, Data

from glue.config import colormaps
from glue.viewers.matplotlib.state import (MatplotlibDataViewerState,
Expand All @@ -14,6 +14,7 @@
from glue.core.data_combo_helper import ComponentIDComboHelper, ComboHelper
from glue.core.exceptions import IncompatibleAttribute
from glue.viewers.common.stretch_state_mixin import StretchStateMixin
from glue.core.units import find_unit_choices, UnitConverter

from matplotlib.projections import get_projection_names

Expand All @@ -34,6 +35,9 @@ class ScatterViewerState(MatplotlibDataViewerState):
x_limits_percentile = DDCProperty(100, docstring="Percentile to use when automatically determining x limits")
y_limits_percentile = DDCProperty(100, docstring="Percentile to use when automatically determining y limits")

x_display_unit = DDSCProperty(docstring='The units to use to display the x-axis.')
y_display_unit = DDSCProperty(docstring='The units to use to display the y-axis')

def __init__(self, **kwargs):

super(ScatterViewerState, self).__init__()
Expand All @@ -43,15 +47,20 @@ def __init__(self, **kwargs):
self.x_lim_helper = StateAttributeLimitsHelper(self, attribute='x_att',
lower='x_min', upper='x_max',
log='x_log', margin=0.04,
display_units='x_display_unit',
limits_cache=self.limits_cache)

self.y_lim_helper = StateAttributeLimitsHelper(self, attribute='y_att',
lower='y_min', upper='y_max',
log='y_log', margin=0.04,
display_units='y_display_unit',
limits_cache=self.limits_cache)

self.add_callback('layers', self._layers_changed)

# self.add_callback('x_display_unit', self._convert_units_x_limits, echo_old=True)
# self.add_callback('y_display_unit', self._convert_units_y_limits, echo_old=True)

self.x_att_helper = ComponentIDComboHelper(self, 'x_att', pixel_coord=True, world_coord=True)
self.y_att_helper = ComponentIDComboHelper(self, 'y_att', pixel_coord=True, world_coord=True)

Expand All @@ -68,6 +77,9 @@ def __init__(self, **kwargs):
self.add_callback('x_log', self._reset_x_limits)
self.add_callback('y_log', self._reset_y_limits)

self.add_callback('x_att', self._update_x_display_unit_choices)
self.add_callback('y_att', self._update_y_display_unit_choices)

if self.using_polar:
self.full_circle()

Expand Down Expand Up @@ -197,6 +209,80 @@ def _layers_changed(self, *args):

self._layers_data_cache = layers_data

# def _convert_units_x_limits(self, old_unit, new_unit):

# print(repr(old_unit), repr(new_unit))

# if old_unit != new_unit:

# if old_unit is None or new_unit is None:
# self._reset_x_limits()
# return

# limits = np.array([self.x_min, self.x_max])

# converter = UnitConverter()

# data = self.x_att.parent

# limits_native = converter.to_native(data, self.x_att, limits, old_unit)

# limits_new = converter.to_unit(data, self.x_att, limits_native, new_unit)

# with delay_callback(self, 'x_min', 'x_max'):
# print('Setting limits to new', limits_new)
# self.x_min, self.x_max = sorted(limits_new)

# def _convert_units_y_limits(self, old_unit, new_unit):

# if old_unit != new_unit:

# if old_unit is None or new_unit is None:
# self._reset_y_limits()
# return

# limits = np.array([self.y_min, self.y_max])

# converter = UnitConverter()

# data = self.y_att.parent

# limits_native = converter.to_native(data, self.y_att, limits, old_unit)

# limits_new = converter.to_unit(data, self.y_att, limits_native, new_unit)

# with delay_callback(self, 'y_min', 'y_max'):
# self.y_min, self.y_max = sorted(limits_new)

def _update_x_display_unit_choices(self, *args):

# NOTE: only Data and its subclasses support specifying units
if self.x_att is None or not isinstance(self.x_att.parent, Data):
ScatterViewerState.x_display_unit.set_choices(self, [])
return

component = self.x_att.parent.get_component(self.x_att)
if component.units:
x_choices = find_unit_choices([(self.x_att.parent, self.x_att, component.units)])
else:
x_choices = ['']
ScatterViewerState.x_display_unit.set_choices(self, x_choices)
self.x_display_unit = component.units

def _update_y_display_unit_choices(self, *args):

# NOTE: only Data and its subclasses support specifying units
if self.y_att is None or not isinstance(self.y_att.parent, Data):
ScatterViewerState.y_display_unit.set_choices(self, [])
return

component = self.y_att.parent.get_component(self.y_att)
if component.units:
y_choices = find_unit_choices([(self.y_att.parent, self.y_att, component.units)])
else:
y_choices = ['']
ScatterViewerState.y_display_unit.set_choices(self, y_choices)
self.y_display_unit = component.units

def display_func_slow(x):
if x == 'Linear':
Expand Down
175 changes: 173 additions & 2 deletions glue/viewers/scatter/tests/test_viewer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from numpy.testing import assert_allclose
from numpy.testing import assert_allclose, assert_equal

import matplotlib.pyplot as plt

Expand All @@ -8,8 +8,9 @@
from glue.viewers.scatter.viewer import SimpleScatterViewer
from glue.core.application_base import Application
from glue.core.data import Data
from glue.core.link_helpers import LinkSame
from glue.core.link_helpers import LinkSame, LinkSameWithUnits
from glue.core.data_derived import IndexedData
from glue.core.roi import RectangularROI


@visual_test
Expand Down Expand Up @@ -131,3 +132,173 @@ def test_indexed_data():

assert viewer.state.x_att is data_2d.main_components[0]
assert viewer.state.y_att is data_2d.main_components[1]


def test_unit_conversion():

d1 = Data(a=[1, 2, 3], b=[2, 3, 4])
d1.get_component('a').units = 'm'
d1.get_component('b').units = 's'

d2 = Data(c=[2000, 1000, 3000], d=[0.001, 0.002, 0.004])
d2.get_component('c').units = 'mm'
d2.get_component('d').units = 'ks'

# d3 is the same as d2 but we will link it differently
d3 = Data(e=[2000, 1000, 3000], f=[0.001, 0.002, 0.004])
d3.get_component('e').units = 'mm'
d3.get_component('f').units = 'ks'

d4 = Data(g=[2, 2, 3], h=[1, 2, 1])
d4.get_component('g').units = 'kg'
d4.get_component('h').units = 'm/s'

app = Application()
session = app.session

data_collection = session.data_collection
data_collection.append(d1)
data_collection.append(d2)
data_collection.append(d3)
data_collection.append(d4)

data_collection.add_link(LinkSameWithUnits(d1.id['a'], d2.id['c']))
data_collection.add_link(LinkSameWithUnits(d1.id['b'], d2.id['d']))
data_collection.add_link(LinkSame(d1.id['a'], d3.id['e']))
data_collection.add_link(LinkSame(d1.id['b'], d3.id['f']))
data_collection.add_link(LinkSame(d1.id['a'], d4.id['g']))
data_collection.add_link(LinkSame(d1.id['b'], d4.id['h']))

viewer = app.new_data_viewer(SimpleScatterViewer)
viewer.add_data(d1)
viewer.add_data(d2)
viewer.add_data(d3)
viewer.add_data(d4)

assert viewer.layers[0].enabled
assert viewer.layers[1].enabled
assert viewer.layers[2].enabled
assert viewer.layers[3].enabled

assert viewer.state.x_min == 0.92
assert viewer.state.x_max == 3.08
assert viewer.state.y_min == 1.92
assert viewer.state.y_max == 4.08

roi = RectangularROI(0.5, 2.5, 1.5, 4.5)
viewer.apply_roi(roi)

assert len(d1.subsets) == 1
assert_equal(d1.subsets[0].to_mask(), [1, 1, 0])

# Because of the LinkSameWithUnits, the points actually appear in the right
# place even before we set the display units.
assert len(d2.subsets) == 1
assert_equal(d2.subsets[0].to_mask(), [0, 1, 0])

# d3 is only linked with LinkSame not LinkSameWithUnits so currently the
# points are outside the visible axes
assert len(d3.subsets) == 1
assert_equal(d3.subsets[0].to_mask(), [0, 0, 0])

# As we haven't set display units yet, the values for this dataset are shown
# on the same scale as for d1 as if the units had never been set.
assert len(d4.subsets) == 1
assert_equal(d4.subsets[0].to_mask(), [0, 1, 0])

# Now try setting the units explicitly

viewer.state.x_display_unit = 'km'
viewer.state.y_display_unit = 'ms'

assert_allclose(viewer.state.x_min, 0.92e-3)
assert_allclose(viewer.state.x_max, 3.08e-3)
assert_allclose(viewer.state.y_min, 1.92e3)
assert_allclose(viewer.state.y_max, 4.08e3)

roi = RectangularROI(0.5e-3, 2.5e-3, 1.5e3, 4.5e3)
viewer.apply_roi(roi)

# d1 and d2 will be as above, but d3 will now work correctly while d4 should
# not be shown.

assert_equal(d1.subsets[1].to_mask(), [1, 1, 0])
assert_equal(d2.subsets[1].to_mask(), [0, 1, 0])
assert_equal(d3.subsets[1].to_mask(), [0, 0, 0])
assert_equal(d4.subsets[1].to_mask(), [0, 1, 0])


# # Change the limits to make sure they are always converted
# viewer.state.x_min = 5e8
# viewer.state.x_max = 4e9
# viewer.state.y_min = 0.5
# viewer.state.y_max = 3.5

# roi = XRangeROI(1.4e9, 2.1e9)
# viewer.apply_roi(roi)

# assert len(d1.subsets) == 1
# assert_equal(d1.subsets[0].to_mask(), [0, 1, 0])

# assert len(d2.subsets) == 1
# assert_equal(d2.subsets[0].to_mask(), [0, 1, 0])

# viewer.state.x_display_unit = 'GHz'
# viewer.state.y_display_unit = 'mJy'

# x, y = viewer.state.layers[0].profile
# assert_allclose(x, [1, 2, 3])
# assert_allclose(y, [1000, 2000, 3000])

# x, y = viewer.state.layers[1].profile
# assert_allclose(x, 2.99792458 / np.array([1, 2, 3]))
# assert_allclose(y, [2000, 1000, 3000])

# assert viewer.state.x_min == 0.5
# assert viewer.state.x_max == 4.

# # Units get reset because they were originally 'native' and 'native' to a
# # specific unit always trigger resetting the limits since different datasets
# # might be converted in different ways.
# assert viewer.state.y_min == 1000.
# assert viewer.state.y_max == 3000.

# # Now set the limits explicitly again and make sure in future they are converted
# viewer.state.y_min = 500.
# viewer.state.y_max = 3500.

# roi = XRangeROI(0.5, 1.2)
# viewer.apply_roi(roi)

# assert len(d1.subsets) == 1
# assert_equal(d1.subsets[0].to_mask(), [1, 0, 0])

# assert len(d2.subsets) == 1
# assert_equal(d2.subsets[0].to_mask(), [0, 0, 1])

# viewer.state.x_display_unit = 'cm'
# viewer.state.y_display_unit = 'Jy'

# roi = XRangeROI(15, 35)
# viewer.apply_roi(roi)

# assert len(d1.subsets) == 1
# assert_equal(d1.subsets[0].to_mask(), [1, 0, 0])

# assert len(d2.subsets) == 1
# assert_equal(d2.subsets[0].to_mask(), [0, 1, 1])

# assert_allclose(viewer.state.x_min, (4 * u.GHz).to_value(u.cm, equivalencies=u.spectral()))
# assert_allclose(viewer.state.x_max, (0.5 * u.GHz).to_value(u.cm, equivalencies=u.spectral()))
# assert_allclose(viewer.state.y_min, 0.5)
# assert_allclose(viewer.state.y_max, 3.5)

# # Regression test for a bug that caused unit changes to not work on y axis
# # if reference data was not first layer

# viewer.state.reference_data = d2
# viewer.state.y_display_unit = 'mJy'


# data_collection.add_link(LinkSame(d1.id['a'], d2.id['e']))
# data_collection.add_link(LinkSame(d1.id['b'], d2.id['f']))
30 changes: 27 additions & 3 deletions glue/viewers/scatter/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from glue.viewers.matplotlib.viewer import SimpleMatplotlibViewer
from glue.viewers.scatter.state import ScatterViewerState
from glue.viewers.scatter.layer_artist import ScatterLayerArtist
from glue.core.units import UnitConverter


__all__ = ['MatplotlibScatterMixin', 'SimpleScatterViewer']

Expand Down Expand Up @@ -152,9 +154,31 @@ def apply_roi(self, roi, override_mode=None):
x_date = 'datetime' in self.state.x_kinds
y_date = 'datetime' in self.state.y_kinds

if x_date or y_date:
roi = roi.transformed(xfunc=mpl_to_datetime64 if x_date else None,
yfunc=mpl_to_datetime64 if y_date else None)
converter = UnitConverter()

xfunc = None
if x_date:
xfunc = mpl_to_datetime64
else:
if self.state.x_display_unit:
xfunc = lambda x: converter.to_native(self.state.x_att.parent,
self.state.x_att, x,
self.state.x_display_unit)

yfunc = None
if y_date:
yfunc = mpl_to_datetime64
else:
if self.state.y_display_unit:
yfunc = lambda y: converter.to_native(self.state.y_att.parent,
self.state.y_att, y,
self.state.y_display_unit)

print(xfunc)
print(yfunc)

if xfunc or yfunc:
roi = roi.transformed(xfunc=xfunc, yfunc=yfunc)

use_transform = not self.using_rectilinear()
subset_state = roi_to_subset_state(roi,
Expand Down

0 comments on commit 89e6aab

Please sign in to comment.