Skip to content

Commit

Permalink
ENH: working code, needs some documentation improvements and possible…
Browse files Browse the repository at this point in the history
… exposing of other functions, but almost ready.
  • Loading branch information
marklescroart committed Oct 25, 2023
1 parent 188c1ea commit 93081e6
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 92 deletions.
1 change: 1 addition & 0 deletions cortex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from cortex.volume import mosaic, unmask
import cortex.export
from cortex.version import __version__, __full_version__
from cortex import dartboards

try:
from cortex import formats
Expand Down
193 changes: 101 additions & 92 deletions cortex/dartboards.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from . import polyutils
from . import quickflat

from .dataset.views import Vertex
from .dataset.views import Vertex, Vertex2D
from scipy import interpolate
from scipy.spatial import ConvexHull
from .utils import get_cmap
Expand Down Expand Up @@ -329,9 +329,9 @@ def _angles_from_vertex(overlay, start_idx, target_idxs=None, period=2*np.pi, th
overlay : SVGOverlay object
svg overlay object for a subject
start_idx : int
idk
index for center vertex from which to compute angles
target_idxs : list, optional
idk, by default None
index for anchor vertex to which to compute angles, by default None
period : scalar, optional
max angle, by default 2*np.pi
theta_direction : str, optional
Expand Down Expand Up @@ -465,7 +465,7 @@ def _get_closest_vertex_to_roi(overlay, roi, comparison_roi, roi_full=False, com
if return_indices:
nearest_vertices.append(roi_verts[hemi][nearest_vertex])
else:
nearest_vertices.append(hemi_roi_coords[:, nearest_vertex])
nearest_vertices.append(hemi_roi_coords[nearest_vertex, :])
return nearest_vertices


Expand Down Expand Up @@ -603,6 +603,7 @@ def show_dartboard(data,
vmax2=None,
theta_direction=-1,
show_grid=True,
max_radius=None,
grid_linewidth=0.5,
grid_linecolor='lightgray'):
"""Given values masked by angle and eccentricity, shows them as a radial grid ('dartboard'-style visualization).
Expand Down Expand Up @@ -641,7 +642,8 @@ def show_dartboard(data,
plt.Axes
Matplotlib axis in which data is plotted.
"""
max_radius = 1
if max_radius is None:
max_radius = 1
data = np.array(data).astype(np.float)
if isinstance(data2, np.ndarray):
data = np.stack([data, data2], axis=0)
Expand Down Expand Up @@ -1105,7 +1107,11 @@ def show_dartboard_pair(dartboard_data,
# Loop over hemispheres
# Hemisphere index goes (0, 1) = (left, right)
directions = [-1, 1]
max_radii = dartboard_spec['max_radii']
if not isinstance(max_radii, (list, tuple)):
max_radii = [max_radii, max_radii]
for hemi_index, data in enumerate(zip(data0, data1)):
max_radius = max_radii[hemi_index]
axis = axes[hemi_index]
d0, d1 = data
_ = show_dartboard(d0, data2=d1,
Expand All @@ -1116,6 +1122,7 @@ def show_dartboard_pair(dartboard_data,
show_grid=show_grid,
grid_linewidth=grid_linewidth,
grid_linecolor=grid_linecolor,
max_radius=max_radius,
vmin=vmin,
vmax=vmax,
vmin2=vmin2,
Expand Down Expand Up @@ -1435,14 +1442,8 @@ def _get_dartboard_str(**dartboard_spec):
fname = '_center{center_roi}-anchors_{anchor_str}-{n_angles}ang-{n_eccentricities}ecc-{rad_str}'
return fname.format(**fmt)

def _compute_centroids_angles_from_spec(svg, center_roi, anchors, verbose=False, **kwargs):
"""kwargs catches extra inputs from dartboard_spec.
Perhaps not the most principled.
functional.
"""
# Compute centroids of each ROI
t0 = time.time()
# Compute centroids

def _get_anchor_points(svg, center_roi, anchors, return_indices=True):
centroids = {}
for j, anchor in enumerate([center_roi] + anchors):
if isinstance(anchor, tuple):
Expand All @@ -1452,15 +1453,30 @@ def _compute_centroids_angles_from_spec(svg, center_roi, anchors, verbose=False,
if j == 0:
center_name, center_type = anchor_name, anchor_type
if anchor_type == 'nearest':
centroids[anchor_name] = _get_closest_vertex_to_roi(svg, anchor_name, center_name)
centroids[anchor_name] = _get_closest_vertex_to_roi(svg, anchor_name, center_name, return_indices=return_indices)
elif anchor_type == 'centroid':
centroids[anchor_name] = get_roi_centroids(svg, anchor_name)
centroids[anchor_name] = get_roi_centroids(svg, anchor_name, return_indices=return_indices)
else:
raise ValueError("unknown anchor type specified: %s\n(Must be 'nearest' or 'centroid')"%(anchor_type))
raise ValueError("unknown anchor type specified: %s\n(Must be 'nearest' or 'centroid')"%(anchor_type))
return centroids


def _compute_centroids_angles_from_spec(svg, center_roi, anchors, verbose=False, **kwargs):
"""kwargs catches extra inputs from dartboard_spec.
Perhaps not the most principled.
functional.
"""
# Compute centroids of each ROI
t0 = time.time()
# Compute centroids
centroids = _get_anchor_points(svg, center_roi, anchors)
t1 = time.time()
if verbose:
print('Time to get centroids:', t1 - t0)

if isinstance(center_roi, tuple) and (len(center_roi) == 2):
center_name, _ = center_roi
else:
center_name = center_roi
# Compute angles from center ROI to each of the anchors
anchor_angles_dict = {}
for anchor in anchors:
Expand Down Expand Up @@ -1562,6 +1578,10 @@ def get_dartboard_data(vertex_obj,

# Compute the masks, based on specified eccentricity bins, angle bins, and the previously-computed variables
t2 = time.time()
# NOTE: magic number here, not great. 4 sub-bins only works specifically
# for 4 anchors and 16 total anglular bins. This could be re-computed with
# an assumption of even spacing, or fundamentally changed by specifying
# the desired angles of the anchor points
masks = compute_eccentricity_angle_masks(
svg, centroids[center_name],
eccentricities=eccentricities,
Expand Down Expand Up @@ -1589,12 +1609,13 @@ def dartboard_on_flatmap(vertex_data,
cmap=None,
# Dartboard args
center_roi=None,
anchor_rois=None,
display_rois=None,
anchors=None,
rois=None,
n_angles=16,
n_eccentricities=8,
max_radii=(50, 50),
surf_type='inflated',
eccentricities=None,
#surf_type='inflated', # Was for choosing which
# Plotting args
figsize=(12, 6),
dartboard_axes_dist_from_midline=0.15,
Expand All @@ -1605,7 +1626,7 @@ def dartboard_on_flatmap(vertex_data,
quickflat_kw=None,
flatmap_line_linewidth=1.5,
flatmap_line_color='y',
flatmap_line_style=None,
flatmap_line_style=('--','-','--','-'),
show_anchor_lines=None,
show_dartboard_grid=True,
show_dartboard_edge=True,
Expand All @@ -1614,10 +1635,9 @@ def dartboard_on_flatmap(vertex_data,
n_roi_border_points=64,
roi_outline_smooth_factor=5, # every 5th point kept, smoothed with cubic spline
roi_border_kw=None,
cache_dir=None,
verbose=False,
outline_kw=None,
**kwargs):
):
"""Make a flatmap with overlaid dartboard plots
Parameters
Expand Down Expand Up @@ -1652,21 +1672,31 @@ def dartboard_on_flatmap(vertex_data,
"""
if verbose:
print("Getting masks...")
masks = generate_dartboard_masks(
vertex_data.subject,
center_roi,
anchor_rois,
n_angles=n_angles,
n_eccentricities=n_eccentricities,
max_radii=max_radii,
cache_dir=cache_dir,)
# Allow manually specified eccentricities to override linear spacing
if eccentricities is None:
eccentricities = [np.linspace(
0, mr, n_eccentricities + 1) for mr in max_radii]
dartboard_spec = dict(center_roi=center_roi,
anchors=anchors,
n_angles=n_angles,
n_eccentricities=n_eccentricities,
eccentricities=eccentricities,
max_radii=max_radii,)
# masks = generate_dartboard_masks(
# vertex_data.subject,
# center_roi,
# anchor_rois,
# n_angles=n_angles,
# n_eccentricities=n_eccentricities,
# max_radii=max_radii,
# cache_dir=cache_dir,)
# Manage inputs
if isinstance(vertex_data, cx.Vertex):
if isinstance(vertex_data, Vertex):
if vmin is None:
vmin = vertex_data.vmin
if vmax is None:
vmax = vertex_data.vmax
elif isinstance(vertex_data, cx.Vertex2D):
elif isinstance(vertex_data, Vertex2D):
if vmin is None:
vmin = (vertex_data.vmin, vertex_data.vmin2)
if vmax is None:
Expand All @@ -1675,39 +1705,37 @@ def dartboard_on_flatmap(vertex_data,
raise NotImplementedError("No VertexRGB yet!")
if cmap is None:
cmap = vertex_data.cmap
#print(vmin, vmax)
if flatmap_line_style is None:
flatmap_line_style = '-'
if quickflat_kw is None:
quickflat_kw = {}
quickflat_kw = {'with_curvature' : True, }
if roi_border_kw is None:
roi_border_kw = {}
if anchor_rois is None:
anchor_rois = []
if outline_kw is None:
outline_kw = {}
# Get masked data
data_lh, data_rh = get_dartboard_data(vertex_data, masks,
n_angles=n_angles, n_eccentricities=n_eccentricities,
fn=fn)
# # Get masked data
# data_lh, data_rh = get_dartboard_data(vertex_data, masks,
# n_angles=n_angles, n_eccentricities=n_eccentricities,
# fn=fn)
masks, to_plot = get_dartboard_data(vertex_data, **dartboard_spec, mean_func=fn)
# Vertex flatmap plot
fig, ax = plt.subplots(figsize=figsize)
_ = cx.quickflat.make_figure(
vertex_data, with_curvature=True, fig=ax, **quickflat_kw)
_ = quickflat.make_figure(
vertex_data, fig=ax, **quickflat_kw)

# Augment plot
if show_dartboard_grid:
# Grid fill for dartboard area
vx_grid = generate_dartboard_vertex_object(
masks, vertex_data.subject, type='grid', bg_value=np.nan)
img_grid, extent = cx.quickflat.composite.make_flatmap_image(vx_grid)
img_grid, extent = quickflat.composite.make_flatmap_image(vx_grid)
ax.imshow(img_grid, extent=extent,
alpha=dartboard_display_alpha, cmap='gray')
if show_dartboard_edge:
# Solid fill for dartboard area
vx_fill = generate_dartboard_vertex_object(
masks, vertex_data.subject, type='solid', bg_value=0)
img_fill, extent = cx.quickflat.composite.make_flatmap_image(vx_fill)
img_fill, extent = quickflat.composite.make_flatmap_image(vx_fill)
xt = np.linspace(extent[0], extent[1], img_fill.shape[1])
yt = np.linspace(extent[3], extent[2], img_fill.shape[0])
xg, yg = np.meshgrid(xt, yt)
Expand All @@ -1717,81 +1745,62 @@ def dartboard_on_flatmap(vertex_data,
zorder=10, # (in front)
)
if show_anchor_lines:
overlay = db.get_overlay(vertex_data.subject)
# Draw lines from center of dartboard to each anchor ROI center
if not isinstance(flatmap_line_style, (list, tuple)):
flatmap_line_style = [flatmap_line_style] * len(anchor_rois)
pts, _ = cx.db.get_surf(vertex_data.subject,
flatmap_line_style = [flatmap_line_style] * len(anchors)
pts, _ = db.get_surf(vertex_data.subject,
'flat', merge=True, nudge=True)
roi_centers = {}
for roi in [center_roi] + anchor_rois:
vertex_indices = get_roi_centroids(vertex_data.subject, roi,
surf_type=surf_type,
cache_dir=cache_dir,
verbose=verbose)
roi_centers[roi] = np.array([pts[vertex_indices[0]][:2],
pts[vertex_indices[1]][:2]])
if roi == center_roi:
center_roi_vertex = vertex_indices
roi_center_indices = _get_anchor_points(overlay, center_roi, anchors)
for roi, c in roi_center_indices.items():
roi_centers[roi] = np.array([pts[c[0]][:2],
pts[c[1]][:2]])
center_pt = roi_centers.pop(center_roi)
for lr in [0, 1]: # left, right
for pt, ls in zip(anchor_rois, flatmap_line_style):
x = [roi_centers[center_roi][lr, 0],
for pt, ls in zip(roi_centers.keys(), flatmap_line_style):
x = [center_pt[lr, 0],
roi_centers[pt][lr, 0]]
y = [roi_centers[center_roi][lr, 1],
y = [center_pt[lr, 1],
roi_centers[pt][lr, 1]]
ax.plot(x, y, lw=flatmap_line_linewidth,
ls=ls, color=flatmap_line_color)
# Replace below with...?
# overlaid_axis = fig.add_axes(
# mtrans.Bbox([
# [position_x-scale/2, position_y-scale/2],
# [position_x+scale/2, position_y+scale/2]]),
# frameon=False
# )

# These are in unitless percentages of the figure size. (0,0 is bottom left)
left, bottom, width, height = [0.5 - dartboard_axes_dist_from_midline - dartboard_axes_width,
dartboard_axes_bottom,
dartboard_axes_width,
dartboard_axes_height]
ax_lhem = fig.add_axes([left, bottom, width, height])
# lh_h = dartboard(data_lh,
# cmap=vertex_data.cmap,
# ax=ax_lhem,
# theta_direction='counter_clockwise',#1,
# vmin=vmin,
# vmax=vmax,
# max_radius=max_radii[0],
# **kwargs)
left, bottom, width, height = [0.5 + dartboard_axes_dist_from_midline,
dartboard_axes_bottom,
dartboard_axes_width,
dartboard_axes_height]
ax_rhem = fig.add_axes([left, bottom, width, height])
# rh_h = dartboard(data_rh,
# cmap=vertex_data.cmap,
# ax=ax_rhem,
# theta_direction='clockwise', #-1,
# vmin=vmin,
# vmax=vmax,
# max_radius=max_radii[1],
# **kwargs,
# )
print(cmap)
dartboard_pair((data_lh, data_rh),
vertex_data.subject,
center_roi,
anchor_rois,
display_rois=display_rois,
n_angles=n_angles,
n_eccentricities=n_eccentricities,
max_radii=max_radii,
fn=fn,
surf_type=surf_type,
axs=np.array([ax_lhem, ax_rhem]),

show_dartboard_pair(vertex_data,
**dartboard_spec,
rois=rois,
mean_func=fn,
axes=np.array([ax_lhem, ax_rhem]),
vmin=vmin,
vmax=vmax,
cmap=cmap,
outline_kw=outline_kw,
cache_dir=cache_dir,
#outline_kw=outline_kw,
verbose=verbose,
**kwargs,
#**kwargs,
###
)
if show_anchor_lines:
# TEMP: Will need changing if anchors are not meant to specify
# Show vertical and horizontal lines on dartboard plots
# NOTE: Will need changing if anchors are not meant to specify
# 90 degree ticks around dartboard. For that, should be something
# like defining regularly spaced angles:
# angles = np.linspace(0, 2 * np.pi, n_anchor_points)
Expand Down

0 comments on commit 93081e6

Please sign in to comment.