diff --git a/plot_2dhists.py b/plot_2dhists.py new file mode 100644 index 0000000..680c93a --- /dev/null +++ b/plot_2dhists.py @@ -0,0 +1,121 @@ +import os, sys, glob +import h5py +import numpy as np +import matplotlib.pyplot as plt +import matplotlib as mpl +import torch +from torch import optim, nn +import lightning as L +from torch.utils.data import Dataset, DataLoader +from matplotlib.colors import LogNorm + +import train +import decoder + +if __name__ == '__main__': + if len(sys.argv) == 1: + N = 16384 + SMPL_RATE_DEC1 = 125e6 + decimation = 256 + smpl_rate = SMPL_RATE_DEC1 / decimation + time_data = np.linspace(0, N - 1, num=N) / smpl_rate + + # valid_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\valid_1to1kHz_misaligned_invertspectra_trigdelay8192_sleep100ms_2kx1shots_randampl.h5py' + # test_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\test_1to1kHz_misaligned_invertspectra_trigdelay8192_sleep100ms_2kx1shots_randampl.h5py' + # train_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\train_1to1kHz_misaligned_invertspectra_trigdelay8192_sleep100ms_10kx1shots_randampl.h5py' + valid_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\val_1to1kHz_invertspectra_trigdelay8192_sleep100ms_2kx1shots_dec=256_8192_randampl.h5py' + test_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\test_1to1kHz_invertspectra_trigdelay8192_sleep100ms_2kx1shots_dec=256_8192_randampl.h5py' + train_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\train_1to1kHz_invertspectra_trigdelay8192_sleep100ms_10kx1shots_dec=256_8192_randampl.h5py' + + model_tag = "z4dscrhq" #"qn4eo8zu"#"z4dscrhq" + step = 64 + group_size = 256 + + runner = train.TrainingRunner(train_file, valid_file, test_file, step=step) + model, result = runner.load_model(model_tag) + + modes = ['train', 'valid', 'test'] + + # only works for datasets using test_mode=True + for mode in ['test']: + if mode == 'valid': + iter_loader = iter(runner.valid_loader_testmode) + elif mode == 'test': # test mode + iter_loader = iter(runner.test_loader) + else: + raise ValueError('Invalid mode') + + preds = torch.empty(0) + truths = torch.empty(0) + input_segments = torch.empty(0) + + num_batches = 10 # ceil[2k/128] + for i in range(num_batches): + print(i) + batch = next(iter_loader) + inputs = torch.squeeze(batch[0]) + targets = torch.squeeze(batch[1]) + outputs = torch.squeeze(model.predict_step(batch, 1, test_mode=True)[0]) # store only y_hat + # print("Input shape", inputs.shape) # [batch_size, 16384] + # print("Target shape", targets.shape) # [batch_size, num_groups] + # print("Pred shape", outputs.shape) # [batch_size, num_groups] + + preds = torch.cat((preds, torch.flatten(outputs)), dim=0) + truths = torch.cat((truths, torch.flatten(targets)), dim=0) + + preds = preds.cpu().detach().numpy() + truths = truths.cpu().detach().numpy() + + num_samples = preds.shape[0] + print("samples", num_samples) + losses = (preds - truths)**2 + print("losses", losses.shape) + + fig, axs = plt.subplots(1, 3, figsize=(18, 6), sharex=True, sharey=True, tight_layout=True) #, gridspec_kw={'height_ratios': [1, 1, 1]}) + num_bins = 51 # odd to see if 0-velocity over/under-predicted + + range_min = min(np.min(truths), np.min(preds)) + range_max = max(np.max(truths), np.max(preds)) + range_max = max(np.abs(range_min), range_max) # find max magnitude in all data + hist_range = [[-range_max, range_max]] * 2 + # 2D Histogram for Counts + hist_counts, xedges_counts, yedges_counts = np.histogram2d(truths, preds, bins=num_bins, range=hist_range) + masked_hist_counts = np.ma.masked_where(hist_counts == 0, hist_counts) + im_counts_log = axs[0].imshow(masked_hist_counts.T, extent=[xedges_counts[0], xedges_counts[-1], yedges_counts[0], yedges_counts[-1]], + origin='lower', aspect='equal', cmap='coolwarm', norm=LogNorm()) + axs[0].set_xlabel('Expected Velocity (um/s)') + axs[0].set_ylabel('Predicted Velocity (um/s)') + axs[0].set_title('Counts for Predicted vs Expected Velocity') + cbar_0 = fig.colorbar(im_counts_log, ax=axs[0], orientation='horizontal', pad=0.1) + cbar_0.set_label('Counts Per Bin') + # avg_mse = hist_mse_sum / hist_counts # avg mse per bin + + # 2D Histogram for Average MSELoss + hist_mse_sum, xedges, yedges = np.histogram2d(truths, preds, bins=num_bins, weights=losses, range=hist_range) + # masked_hist = np.ma.masked_where(hist_mse_sum == 0, hist_mse_sum) + cax = axs[1].imshow((hist_mse_sum/masked_hist_counts).T, extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]], + origin='lower', aspect='equal', cmap='coolwarm', norm=LogNorm()) + # with np.errstate(divide='ignore', invalid='ignore'): + # cax = axs[1].imshow((hist_mse_sum / hist_counts).T, extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]], origin='lower', aspect='auto') + cbar_1 = fig.colorbar(cax, ax=axs[1], orientation='horizontal', pad=0.1) + cbar_1.set_label('Average MSE Loss Per Bin') + axs[1].set_xlabel('Expected Velocity (um/s)') + # axs[1].set_ylabel('Predicted Velocity') + axs[1].set_title('Avg MSELoss for Predicted vs Expected Velocity') + + # 2D Histogram for Summed MSELoss + sum_ax = axs[2].imshow(hist_mse_sum.T, extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]], + origin='lower', aspect='equal', cmap='coolwarm', norm=LogNorm()) + cbar_sum = fig.colorbar(sum_ax, ax=axs[2], orientation='horizontal', pad=0.1) + cbar_sum.set_label('Summed MSE Loss Per Bin') + axs[2].set_xlabel('Expected Velocity (um/s)') + # axs[2].set_ylabel('Predicted Velocity') + axs[2].set_title('Summed MSELoss for Predicted vs Expected Velocity') + + for ax in axs: + ax.axline((0, 0), slope=1, color='black') + ax.axline((0, 0), slope=-1, color='black') + ax.set_aspect(np.diff(ax.get_xlim()) / np.diff(ax.get_ylim())) + plt.show() + else: + print("Error: Unsupported number of command-line arguments") diff --git a/model_analysis.py b/plot_predictions.py similarity index 100% rename from model_analysis.py rename to plot_predictions.py diff --git a/plot_quartiles.py b/plot_quartiles.py new file mode 100644 index 0000000..507ef94 --- /dev/null +++ b/plot_quartiles.py @@ -0,0 +1,147 @@ +import os, sys, glob +import h5py +import numpy as np +import matplotlib.pyplot as plt +import matplotlib as mpl +import torch +from torch import optim, nn +import lightning as L +from torch.utils.data import Dataset, DataLoader +from matplotlib.colors import LogNorm + +import train +import decoder + +if __name__ == '__main__': + if len(sys.argv) == 1: + N = 16384 + SMPL_RATE_DEC1 = 125e6 + decimation = 256 + smpl_rate = SMPL_RATE_DEC1 / decimation + time_data = np.linspace(0, N - 1, num=N) / smpl_rate + + # valid_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\valid_1to1kHz_misaligned_invertspectra_trigdelay8192_sleep100ms_2kx1shots_randampl.h5py' + # test_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\test_1to1kHz_misaligned_invertspectra_trigdelay8192_sleep100ms_2kx1shots_randampl.h5py' + # train_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\train_1to1kHz_misaligned_invertspectra_trigdelay8192_sleep100ms_10kx1shots_randampl.h5py' + valid_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\val_1to1kHz_invertspectra_trigdelay8192_sleep100ms_2kx1shots_dec=256_8192_randampl.h5py' + test_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\test_1to1kHz_invertspectra_trigdelay8192_sleep100ms_2kx1shots_dec=256_8192_randampl.h5py' + train_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\train_1to1kHz_invertspectra_trigdelay8192_sleep100ms_10kx1shots_dec=256_8192_randampl.h5py' + + model_tag = "z4dscrhq" #"qn4eo8zu"#"z4dscrhq" + step = 64 + group_size = 256 + + runner = train.TrainingRunner(train_file, valid_file, test_file, step=step) + model, result = runner.load_model(model_tag) + + modes = ['train', 'valid', 'test'] + + # only works for datasets using test_mode=True + iter_loader = iter(runner.test_loader) + + preds = torch.empty(0) + truths = torch.empty(0) + inputs = torch.empty(0) + + num_batches = 10 # ceil[2k/128] + for i in range(num_batches): + print(i) + batch = next(iter_loader) + input_buffers = torch.squeeze(batch[0]) + targets = torch.squeeze(batch[1]) + outputs = torch.squeeze(model.predict_step(batch, 1, test_mode=True)[0]) # store only y_hat + # print("Input shape", inputs.shape) # [batch_size, 16384] + # print("Target shape", targets.shape) # [batch_size, num_groups] + # print("Pred shape", outputs.shape) # [batch_size, num_groups] + + preds = torch.cat((preds, outputs), dim=0) + truths = torch.cat((truths, targets), dim=0) + inputs = torch.cat((inputs, input_buffers), dim=0) + + preds = preds.cpu().detach().numpy() # [num_batches * batch_size, num_groups] + truths = truths.cpu().detach().numpy() # [num_batches * batch_size, num_groups] + inputs = inputs.cpu().detach().numpy() # [num_batches * batch_size, 16384] + + print("preds shape", preds.shape) + print("truths shape", truths.shape) + print("inputs shape", inputs.shape) + + losses = np.mean((preds - truths)**2, axis=1) # avg_mse_loss per buffer, [num_batches * batch_size] + print("losses shape", losses.shape) + + sorted_idxs = np.argsort(losses) + # all sorted arrays should have same shape as their unsorted originals + sorted_losses = losses[sorted_idxs] + print("sorted losses shape", sorted_losses.shape) + sorted_truths = truths[sorted_idxs] + print("sorted truths shape", sorted_truths.shape) + sorted_preds = preds[sorted_idxs] + print("sorted preds shape", sorted_preds.shape) + sorted_inputs = inputs[sorted_idxs] + print("sorted inputs shape", sorted_inputs.shape) + + num_buffers = sorted_losses.shape[0] + q2_start_idx = num_buffers // 4 # End of first quartile + q3_start_idx = num_buffers // 2 # End of second quartile + q4_start_idx = 3 * num_buffers // 4 # End of third quartile + + first_quartile = { + "losses": sorted_losses[:q2_start_idx], + "truths": sorted_truths[:q2_start_idx], + "preds": sorted_preds[:q2_start_idx], + "inputs": sorted_inputs[:q2_start_idx] + } + + fourth_quartile = { + "losses": sorted_losses[q4_start_idx:], + "truths": sorted_truths[q4_start_idx:], + "preds": sorted_preds[q4_start_idx:], + "inputs": sorted_inputs[q4_start_idx:] + } + + print("First quartile buffer count:", first_quartile["inputs"].shape) + print("Fourth quartile buffer count:", fourth_quartile["inputs"].shape) + + def plot_quartile_examples(quartiles, num_exs=4): + fig, axes = plt.subplots(num_exs * 2, 2, figsize=(15, num_exs * 5), + gridspec_kw={'height_ratios': [1, 2] * num_exs}) + + time_data = np.linspace(0, N - 1, num=N) / smpl_rate + num_groups = truths.shape[1] + start_idxs = torch.arange(num_groups) * step + for i in range(num_exs): + for j, quartile in enumerate(quartiles): + # Plot PD trace + axes[i*2, j].plot(time_data, quartile['inputs'][i], color='blue') + axes[i*2, j].set_ylabel('V') + axes[i*2, j].set_xticks([]) + + # Plot velocities + axes[i*2 + 1, j].plot(time_data[start_idxs], quartile['truths'][i], + label='Target', color='blue', + marker='o', markersize=2) + axes[i*2 + 1, j].plot(time_data[start_idxs], quartile['preds'][i], + label=f'Pred (Avg MSE: {quartile["losses"][i]:.2f})', color='orange', marker='o', + markersize=2) + axes[i*2 + 1, j].set_ylabel('µm/s') + axes[i*2 + 1, j].legend() + + if i == num_exs - 1: # a bottom plot + axes[i*2 + 1, j].set_xlabel('Time (s)') + else: + axes[i*2 + 1, j].set_xticks([]) + + fig.text(0.25, 0.98, '1st Quartile (Best Performance)', ha='center', fontsize=15) + fig.text(0.75, 0.98, '4th Quartile (Worst Performance)', ha='center', fontsize=15) + plt.tight_layout(rect=[0, 0, 1, 0.96]) # Make space for suptitles + plt.show() + + # Call the function with the first and fourth quartile data + plot_quartile_examples( + quartiles=[ + first_quartile, # Best quartile data + fourth_quartile # Worst quartile data + ] + ) + else: + print("Error: Unsupported number of command-line arguments") \ No newline at end of file