-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a07c5da
commit afad4b0
Showing
3 changed files
with
268 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
import os, sys, glob | ||
import h5py | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
import matplotlib as mpl | ||
import torch | ||
from torch import optim, nn | ||
import lightning as L | ||
from torch.utils.data import Dataset, DataLoader | ||
from matplotlib.colors import LogNorm | ||
|
||
import train | ||
import decoder | ||
|
||
if __name__ == '__main__': | ||
if len(sys.argv) == 1: | ||
N = 16384 | ||
SMPL_RATE_DEC1 = 125e6 | ||
decimation = 256 | ||
smpl_rate = SMPL_RATE_DEC1 / decimation | ||
time_data = np.linspace(0, N - 1, num=N) / smpl_rate | ||
|
||
# valid_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\valid_1to1kHz_misaligned_invertspectra_trigdelay8192_sleep100ms_2kx1shots_randampl.h5py' | ||
# test_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\test_1to1kHz_misaligned_invertspectra_trigdelay8192_sleep100ms_2kx1shots_randampl.h5py' | ||
# train_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\train_1to1kHz_misaligned_invertspectra_trigdelay8192_sleep100ms_10kx1shots_randampl.h5py' | ||
valid_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\val_1to1kHz_invertspectra_trigdelay8192_sleep100ms_2kx1shots_dec=256_8192_randampl.h5py' | ||
test_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\test_1to1kHz_invertspectra_trigdelay8192_sleep100ms_2kx1shots_dec=256_8192_randampl.h5py' | ||
train_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\train_1to1kHz_invertspectra_trigdelay8192_sleep100ms_10kx1shots_dec=256_8192_randampl.h5py' | ||
|
||
model_tag = "z4dscrhq" #"qn4eo8zu"#"z4dscrhq" | ||
step = 64 | ||
group_size = 256 | ||
|
||
runner = train.TrainingRunner(train_file, valid_file, test_file, step=step) | ||
model, result = runner.load_model(model_tag) | ||
|
||
modes = ['train', 'valid', 'test'] | ||
|
||
# only works for datasets using test_mode=True | ||
for mode in ['test']: | ||
if mode == 'valid': | ||
iter_loader = iter(runner.valid_loader_testmode) | ||
elif mode == 'test': # test mode | ||
iter_loader = iter(runner.test_loader) | ||
else: | ||
raise ValueError('Invalid mode') | ||
|
||
preds = torch.empty(0) | ||
truths = torch.empty(0) | ||
input_segments = torch.empty(0) | ||
|
||
num_batches = 10 # ceil[2k/128] | ||
for i in range(num_batches): | ||
print(i) | ||
batch = next(iter_loader) | ||
inputs = torch.squeeze(batch[0]) | ||
targets = torch.squeeze(batch[1]) | ||
outputs = torch.squeeze(model.predict_step(batch, 1, test_mode=True)[0]) # store only y_hat | ||
# print("Input shape", inputs.shape) # [batch_size, 16384] | ||
# print("Target shape", targets.shape) # [batch_size, num_groups] | ||
# print("Pred shape", outputs.shape) # [batch_size, num_groups] | ||
|
||
preds = torch.cat((preds, torch.flatten(outputs)), dim=0) | ||
truths = torch.cat((truths, torch.flatten(targets)), dim=0) | ||
|
||
preds = preds.cpu().detach().numpy() | ||
truths = truths.cpu().detach().numpy() | ||
|
||
num_samples = preds.shape[0] | ||
print("samples", num_samples) | ||
losses = (preds - truths)**2 | ||
print("losses", losses.shape) | ||
|
||
fig, axs = plt.subplots(1, 3, figsize=(18, 6), sharex=True, sharey=True, tight_layout=True) #, gridspec_kw={'height_ratios': [1, 1, 1]}) | ||
num_bins = 51 # odd to see if 0-velocity over/under-predicted | ||
|
||
range_min = min(np.min(truths), np.min(preds)) | ||
range_max = max(np.max(truths), np.max(preds)) | ||
range_max = max(np.abs(range_min), range_max) # find max magnitude in all data | ||
hist_range = [[-range_max, range_max]] * 2 | ||
# 2D Histogram for Counts | ||
hist_counts, xedges_counts, yedges_counts = np.histogram2d(truths, preds, bins=num_bins, range=hist_range) | ||
masked_hist_counts = np.ma.masked_where(hist_counts == 0, hist_counts) | ||
im_counts_log = axs[0].imshow(masked_hist_counts.T, extent=[xedges_counts[0], xedges_counts[-1], yedges_counts[0], yedges_counts[-1]], | ||
origin='lower', aspect='equal', cmap='coolwarm', norm=LogNorm()) | ||
axs[0].set_xlabel('Expected Velocity (um/s)') | ||
axs[0].set_ylabel('Predicted Velocity (um/s)') | ||
axs[0].set_title('Counts for Predicted vs Expected Velocity') | ||
cbar_0 = fig.colorbar(im_counts_log, ax=axs[0], orientation='horizontal', pad=0.1) | ||
cbar_0.set_label('Counts Per Bin') | ||
# avg_mse = hist_mse_sum / hist_counts # avg mse per bin | ||
|
||
# 2D Histogram for Average MSELoss | ||
hist_mse_sum, xedges, yedges = np.histogram2d(truths, preds, bins=num_bins, weights=losses, range=hist_range) | ||
# masked_hist = np.ma.masked_where(hist_mse_sum == 0, hist_mse_sum) | ||
cax = axs[1].imshow((hist_mse_sum/masked_hist_counts).T, extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]], | ||
origin='lower', aspect='equal', cmap='coolwarm', norm=LogNorm()) | ||
# with np.errstate(divide='ignore', invalid='ignore'): | ||
# cax = axs[1].imshow((hist_mse_sum / hist_counts).T, extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]], origin='lower', aspect='auto') | ||
cbar_1 = fig.colorbar(cax, ax=axs[1], orientation='horizontal', pad=0.1) | ||
cbar_1.set_label('Average MSE Loss Per Bin') | ||
axs[1].set_xlabel('Expected Velocity (um/s)') | ||
# axs[1].set_ylabel('Predicted Velocity') | ||
axs[1].set_title('Avg MSELoss for Predicted vs Expected Velocity') | ||
|
||
# 2D Histogram for Summed MSELoss | ||
sum_ax = axs[2].imshow(hist_mse_sum.T, extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]], | ||
origin='lower', aspect='equal', cmap='coolwarm', norm=LogNorm()) | ||
cbar_sum = fig.colorbar(sum_ax, ax=axs[2], orientation='horizontal', pad=0.1) | ||
cbar_sum.set_label('Summed MSE Loss Per Bin') | ||
axs[2].set_xlabel('Expected Velocity (um/s)') | ||
# axs[2].set_ylabel('Predicted Velocity') | ||
axs[2].set_title('Summed MSELoss for Predicted vs Expected Velocity') | ||
|
||
for ax in axs: | ||
ax.axline((0, 0), slope=1, color='black') | ||
ax.axline((0, 0), slope=-1, color='black') | ||
ax.set_aspect(np.diff(ax.get_xlim()) / np.diff(ax.get_ylim())) | ||
plt.show() | ||
else: | ||
print("Error: Unsupported number of command-line arguments") |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
import os, sys, glob | ||
import h5py | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
import matplotlib as mpl | ||
import torch | ||
from torch import optim, nn | ||
import lightning as L | ||
from torch.utils.data import Dataset, DataLoader | ||
from matplotlib.colors import LogNorm | ||
|
||
import train | ||
import decoder | ||
|
||
if __name__ == '__main__': | ||
if len(sys.argv) == 1: | ||
N = 16384 | ||
SMPL_RATE_DEC1 = 125e6 | ||
decimation = 256 | ||
smpl_rate = SMPL_RATE_DEC1 / decimation | ||
time_data = np.linspace(0, N - 1, num=N) / smpl_rate | ||
|
||
# valid_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\valid_1to1kHz_misaligned_invertspectra_trigdelay8192_sleep100ms_2kx1shots_randampl.h5py' | ||
# test_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\test_1to1kHz_misaligned_invertspectra_trigdelay8192_sleep100ms_2kx1shots_randampl.h5py' | ||
# train_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\train_1to1kHz_misaligned_invertspectra_trigdelay8192_sleep100ms_10kx1shots_randampl.h5py' | ||
valid_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\val_1to1kHz_invertspectra_trigdelay8192_sleep100ms_2kx1shots_dec=256_8192_randampl.h5py' | ||
test_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\test_1to1kHz_invertspectra_trigdelay8192_sleep100ms_2kx1shots_dec=256_8192_randampl.h5py' | ||
train_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\train_1to1kHz_invertspectra_trigdelay8192_sleep100ms_10kx1shots_dec=256_8192_randampl.h5py' | ||
|
||
model_tag = "z4dscrhq" #"qn4eo8zu"#"z4dscrhq" | ||
step = 64 | ||
group_size = 256 | ||
|
||
runner = train.TrainingRunner(train_file, valid_file, test_file, step=step) | ||
model, result = runner.load_model(model_tag) | ||
|
||
modes = ['train', 'valid', 'test'] | ||
|
||
# only works for datasets using test_mode=True | ||
iter_loader = iter(runner.test_loader) | ||
|
||
preds = torch.empty(0) | ||
truths = torch.empty(0) | ||
inputs = torch.empty(0) | ||
|
||
num_batches = 10 # ceil[2k/128] | ||
for i in range(num_batches): | ||
print(i) | ||
batch = next(iter_loader) | ||
input_buffers = torch.squeeze(batch[0]) | ||
targets = torch.squeeze(batch[1]) | ||
outputs = torch.squeeze(model.predict_step(batch, 1, test_mode=True)[0]) # store only y_hat | ||
# print("Input shape", inputs.shape) # [batch_size, 16384] | ||
# print("Target shape", targets.shape) # [batch_size, num_groups] | ||
# print("Pred shape", outputs.shape) # [batch_size, num_groups] | ||
|
||
preds = torch.cat((preds, outputs), dim=0) | ||
truths = torch.cat((truths, targets), dim=0) | ||
inputs = torch.cat((inputs, input_buffers), dim=0) | ||
|
||
preds = preds.cpu().detach().numpy() # [num_batches * batch_size, num_groups] | ||
truths = truths.cpu().detach().numpy() # [num_batches * batch_size, num_groups] | ||
inputs = inputs.cpu().detach().numpy() # [num_batches * batch_size, 16384] | ||
|
||
print("preds shape", preds.shape) | ||
print("truths shape", truths.shape) | ||
print("inputs shape", inputs.shape) | ||
|
||
losses = np.mean((preds - truths)**2, axis=1) # avg_mse_loss per buffer, [num_batches * batch_size] | ||
print("losses shape", losses.shape) | ||
|
||
sorted_idxs = np.argsort(losses) | ||
# all sorted arrays should have same shape as their unsorted originals | ||
sorted_losses = losses[sorted_idxs] | ||
print("sorted losses shape", sorted_losses.shape) | ||
sorted_truths = truths[sorted_idxs] | ||
print("sorted truths shape", sorted_truths.shape) | ||
sorted_preds = preds[sorted_idxs] | ||
print("sorted preds shape", sorted_preds.shape) | ||
sorted_inputs = inputs[sorted_idxs] | ||
print("sorted inputs shape", sorted_inputs.shape) | ||
|
||
num_buffers = sorted_losses.shape[0] | ||
q2_start_idx = num_buffers // 4 # End of first quartile | ||
q3_start_idx = num_buffers // 2 # End of second quartile | ||
q4_start_idx = 3 * num_buffers // 4 # End of third quartile | ||
|
||
first_quartile = { | ||
"losses": sorted_losses[:q2_start_idx], | ||
"truths": sorted_truths[:q2_start_idx], | ||
"preds": sorted_preds[:q2_start_idx], | ||
"inputs": sorted_inputs[:q2_start_idx] | ||
} | ||
|
||
fourth_quartile = { | ||
"losses": sorted_losses[q4_start_idx:], | ||
"truths": sorted_truths[q4_start_idx:], | ||
"preds": sorted_preds[q4_start_idx:], | ||
"inputs": sorted_inputs[q4_start_idx:] | ||
} | ||
|
||
print("First quartile buffer count:", first_quartile["inputs"].shape) | ||
print("Fourth quartile buffer count:", fourth_quartile["inputs"].shape) | ||
|
||
def plot_quartile_examples(quartiles, num_exs=4): | ||
fig, axes = plt.subplots(num_exs * 2, 2, figsize=(15, num_exs * 5), | ||
gridspec_kw={'height_ratios': [1, 2] * num_exs}) | ||
|
||
time_data = np.linspace(0, N - 1, num=N) / smpl_rate | ||
num_groups = truths.shape[1] | ||
start_idxs = torch.arange(num_groups) * step | ||
for i in range(num_exs): | ||
for j, quartile in enumerate(quartiles): | ||
# Plot PD trace | ||
axes[i*2, j].plot(time_data, quartile['inputs'][i], color='blue') | ||
axes[i*2, j].set_ylabel('V') | ||
axes[i*2, j].set_xticks([]) | ||
|
||
# Plot velocities | ||
axes[i*2 + 1, j].plot(time_data[start_idxs], quartile['truths'][i], | ||
label='Target', color='blue', | ||
marker='o', markersize=2) | ||
axes[i*2 + 1, j].plot(time_data[start_idxs], quartile['preds'][i], | ||
label=f'Pred (Avg MSE: {quartile["losses"][i]:.2f})', color='orange', marker='o', | ||
markersize=2) | ||
axes[i*2 + 1, j].set_ylabel('µm/s') | ||
axes[i*2 + 1, j].legend() | ||
|
||
if i == num_exs - 1: # a bottom plot | ||
axes[i*2 + 1, j].set_xlabel('Time (s)') | ||
else: | ||
axes[i*2 + 1, j].set_xticks([]) | ||
|
||
fig.text(0.25, 0.98, '1st Quartile (Best Performance)', ha='center', fontsize=15) | ||
fig.text(0.75, 0.98, '4th Quartile (Worst Performance)', ha='center', fontsize=15) | ||
plt.tight_layout(rect=[0, 0, 1, 0.96]) # Make space for suptitles | ||
plt.show() | ||
|
||
# Call the function with the first and fourth quartile data | ||
plot_quartile_examples( | ||
quartiles=[ | ||
first_quartile, # Best quartile data | ||
fourth_quartile # Worst quartile data | ||
] | ||
) | ||
else: | ||
print("Error: Unsupported number of command-line arguments") |