Skip to content

Commit

Permalink
Minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
tomasstolker committed Jun 6, 2024
1 parent cb0b6eb commit 2a868d1
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 12 deletions.
10 changes: 8 additions & 2 deletions species/data/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -1284,7 +1284,10 @@ def add_object(
and hdulist[0].header["INSTRU"] == "GRAVITY"
):
# Read data from a FITS file with the GRAVITY format
gravity_object = hdulist[0].header["OBJECT"]
if "OBJECT" in hdulist[0].header:
gravity_object = hdulist[0].header["OBJECT"]
else:
gravity_object = None

if verbose:
print(" - GRAVITY spectrum:")
Expand Down Expand Up @@ -1448,7 +1451,10 @@ def add_object(
and hdulist[0].header["INSTRU"] == "GRAVITY"
):
# Read data from a FITS file with the GRAVITY format
gravity_object = hdulist[0].header["OBJECT"]
if "OBJECT" in hdulist[0].header:
gravity_object = hdulist[0].header["OBJECT"]
else:
gravity_object = None

if verbose:
print(" - GRAVITY covariance matrix:")
Expand Down
26 changes: 16 additions & 10 deletions species/plot/plot_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,13 +535,13 @@ def plot_spectrum(
wavelength = box_item.wavelength
flux = box_item.flux

data_in = np.column_stack([wavelength, flux])
data_out = convert_units(data_in, units, convert_from=False)
if isinstance(wavelength[0], (np.float32, np.float64)):
data_in = np.column_stack([wavelength, flux])
data_out = convert_units(data_in, units, convert_from=False)

wavelength = data_out[:, 0]
flux = data_out[:, 1]
wavelength = data_out[:, 0]
flux = data_out[:, 1]

if isinstance(wavelength[0], (np.float32, np.float64)):
data = np.array(flux, dtype=np.float64)
flux_masked = np.ma.array(data, mask=np.isnan(data))

Expand Down Expand Up @@ -597,8 +597,14 @@ def plot_spectrum(
)

elif isinstance(wavelength[0], (np.ndarray)):
for i, item in enumerate(wavelength):
data = np.array(flux[i], dtype=np.float64)
for i in range(len(wavelength)):
data_in = np.column_stack([wavelength[i], flux[i]])
data_out = convert_units(data_in, units, convert_from=False)

wavelength = data_out[:, 0]
flux = data_out[:, 1]

data = np.array(flux, dtype=np.float64)
flux_masked = np.ma.array(data, mask=np.isnan(data))

if isinstance(box_item.name[i], bytes):
Expand All @@ -607,10 +613,10 @@ def plot_spectrum(
label = box_item.name[i]

if quantity == "flux":
flux_scaling = item
flux_scaling = wavelength

ax1.plot(
item, flux_scaling * flux_masked / scaling, lw=0.5, label=label
wavelength, flux_scaling * flux_masked / scaling, lw=0.5, label=label
)

elif isinstance(box_item, list):
Expand Down Expand Up @@ -943,7 +949,7 @@ def plot_spectrum(
if isinstance(box_item.flux[filter_item][0], np.ndarray):
for phot_idx in range(box_item.flux[filter_item].shape[1]):
plot_obj = ax1.errorbar(
wavelength,
wavelength[phot_idx],
scale_tmp * box_item.flux[filter_item][0, phot_idx],
xerr=fwhm / 2.0,
yerr=scale_tmp
Expand Down

0 comments on commit 2a868d1

Please sign in to comment.