Skip to content

Commit

Permalink
wrote function to write multichannel pretraining data, some modificat…
Browse files Browse the repository at this point in the history
…ions to VelocityDataset and write_data utility function
  • Loading branch information
npeard committed Sep 24, 2024
1 parent b509798 commit 4fa561f
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 32 deletions.
90 changes: 71 additions & 19 deletions interferometers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
import numpy as np
import matplotlib.pyplot as plt
import util
from tqdm import tqdm


class MichelsonInterferometer:
def __init__(self, wavelength, displacement_amplitude, phase):
self.wavelength = wavelength # in microns
self.displacement_amplitude = displacement_amplitude # in microns
self.phase = phase
self.displacement = None
self.velocity = None
self.phase = phase # in radians, stands for random position offset
self.displacement = None # in microns
self.velocity = None # in microns/s
self.time = None # in seconds

def get_displacement(self, start_frequency, end_frequency, length, sample_rate):
# Get a random displacement in time, resets each time it is called
Expand All @@ -22,44 +25,93 @@ def get_displacement(self, start_frequency, end_frequency, length, sample_rate):

return time, displacement

def interferometer_output(self, start_frequency, end_frequency,
measurement_noise_level, length, sample_rate):
def set_displacement(self, displacement, time):
# TODO: add checks and tests here, that the displacement is in right
# range, sampling, etc.
self.displacement = displacement
self.time = time

def get_interferometer_output(self, start_frequency, end_frequency,
measurement_noise_level, length, sample_rate):
E0 = 1 + measurement_noise_level * util.bounded_frequency_waveform(1e3,
1e6,
length, sample_rate)[1]
ER = 0.1 + measurement_noise_level * util.bounded_frequency_waveform(1e3,
1e6,
length, sample_rate)[1]

self.time, self.displacement = self.get_displacement(start_frequency,
if self.displacement is None:
self.time, self.displacement = self.get_displacement(start_frequency,
end_frequency, length, sample_rate)

interference = np.cos(2 * np.pi / self.wavelength * self.displacement)
interference = np.cos(2 * np.pi / self.wavelength * self.displacement
+ self.phase)

signal = E0**2 + ER**2 + 2 * E0 * ER * interference

return signal

def get_pretraining_data(self, start_frequency, end_frequency):
self.signal = self.interferometer_output(start_frequency,
end_frequency, 0.1, 8192, 1e6)
def get_buffer(self, start_frequency=0, end_frequency=1e3):
self.signal = self.get_interferometer_output(start_frequency,
end_frequency, 0.3, 8192, 1e6)

self.velocity = np.diff(self.displacement)
self.velocity = np.insert(self.velocity, 0, self.velocity[0])
self.velocity /= (self.time[1] - self.time[0])

# Remove DC offset
self.signal = self.signal - np.mean(self.signal)

return self.time, self.signal, self.displacement, self.velocity

def plot_pretraining_data(self):
time, signal, displacement, velocity = self.get_pretraining_data(0, 1e3)

plt.plot(time, signal)
plt.plot(time, displacement)
plt.plot(time, velocity)
def plot_buffer(self):
time, signal, displacement, velocity = self.get_buffer(0, 1e3)

time = time[0:256]
signal = signal[0:256]
displacement = displacement[0:256]
velocity = velocity[0:256]

fig, ax1 = plt.subplots(figsize=(18, 6))
ax1.plot(time, signal, color='b')
ax1.set_xlabel('Time (s)')
ax1.set_ylabel('Signal (V)', color='b')
ax1.tick_params('y', colors='b')

ax2 = ax1.twinx()
ax2.plot(time, displacement, color='r')
ax2.plot(time, velocity, color='g')
ax2.set_ylabel('Displacement (microns)', color='r')
ax2.tick_params('y', colors='r')

plt.tight_layout()
plt.show()



def write_pretraining_data(num_shots, num_channels, file_path):
if num_channels == 1:
interferometer = MichelsonInterferometer(0.5, 5, np.pi / 4)
for _ in tqdm(range(num_shots)):
interferometer.plot_buffer()
_, signal, _, velocity = interferometer.get_buffer()
signal = np.expand_dims(signal, axis=-1)
velocity = np.expand_dims(velocity, axis=-1)
entries = {"signal": signal, "velocity": velocity}
util.write_data(file_path, entries)
elif num_channels == 2:
interferometer1 = MichelsonInterferometer(0.5, 5, np.pi / 4)
interferometer2 = MichelsonInterferometer(0.3, 5, 2 * np.pi / 4)
for _ in tqdm(range(num_shots)):
time, signal1, displacement, velocity = interferometer1.get_buffer()
interferometer2.set_displacement(displacement, time)
_, signal2, _, _ = interferometer2.get_buffer()
signal = np.stack((signal1, signal2), axis=-1)
velocity = np.expand_dims(velocity, axis=-1)
print(signal.shape)
entries = {"signal": signal, "velocity": velocity}
util.write_data(file_path, entries)


if __name__ == '__main__':
interferometer = MichelsonInterferometer(0.5, 5, 0)
interferometer.plot_pretraining_data()
write_pretraining_data(2, 2, "/Users/nolanpeard/Desktop/test.h5")

33 changes: 21 additions & 12 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self, test_mode, h5_file, step, group_size=256):
self.opened_flag = False
self.test_mode = test_mode

def open_hdf5(self, rolling=True, step=256, group_size=256, ch_in=1):
def open_hdf5(self, rolling=True, step=256, group_size=256):
"""Set up inputs and targets. For each shot, buffer is split into groups of sequences.
Inputs include grouped photodiode trace of 'group_size', spaced interval 'step' apart for each buffer.
Targets include average velocity of each group.
Expand All @@ -55,34 +55,43 @@ def open_hdf5(self, rolling=True, step=256, group_size=256, ch_in=1):
# solves issue where hdf5 file opened in __init__ prevents multiple
# workers: https://github.com/pytorch/pytorch/issues/11929
self.file = h5py.File(self.h5_file, 'r')
pds = torch.Tensor(np.array(self.file['PD (V)'])) # [num_shots, buffer_size]
vels = torch.Tensor(np.array(self.file['Speaker (Microns/s)'])) # [num_shots, buffer_size]

signal = torch.Tensor(np.array(self.file['signal']))
# [num_shots, buffer_size, num_channels]
velocity = torch.Tensor(np.array(self.file['velocity']))
# [num_shots, buffer_size]

num_channels = signal.shape[-1]
if num_channels == 1:
signal = torch.squeeze(signal, dim=-1)
else:
raise ValueError('num_channels must be 1')
pass

if rolling:
# ROLLING INPUT INDICES
num_groups = (pds.shape[1] - group_size) // step + 1
num_groups = (signal.shape[1] - group_size) // step + 1
start_idxs = torch.arange(num_groups) * step # starting indices for each group
idxs = torch.arange(group_size)[:, None] + start_idxs
idxs = torch.transpose(idxs, dim0=0, dim1=1) # indices in shape [num_groups, group_size]
if self.test_mode:
self.inputs = pds # [num_shots, buffer_size]
grouped_vels = vels[:, idxs] # [num_shots, num_groups, group_size]
self.inputs = signal # [num_shots, buffer_size]
grouped_vels = velocity[:, idxs] # [num_shots, num_groups, group_size]
self.targets = torch.mean(grouped_vels, dim=2) # [num_shots, num_groups]
else:
self.inputs = pds[:, idxs].reshape(-1, group_size) # [num_shots * num_groups, group_size]
grouped_vels = vels[:, idxs].reshape(-1, group_size) # [num_shots * num_groups, group_size]
self.inputs = signal[:, idxs].reshape(-1, group_size) # [num_shots * num_groups, group_size]
grouped_vels = velocity[:, idxs].reshape(-1, group_size) # [num_shots * num_groups, group_size]
self.targets = torch.unsqueeze(torch.mean(grouped_vels, dim=1), dim=1) # [num_shots * num_groups, 1]
else:
# STEP INPUT
if self.test_mode:
assert False, 'test_mode not implemented for step input. use rolling step=256'
else:
# [num_shots * num_groups, group_size]
self.inputs = torch.cat(torch.split(pds, group_size, dim=1), dim=0)
grouped_vels = torch.cat(torch.split(vels, group_size, dim=1), dim=0)
self.inputs = torch.cat(torch.split(signal, group_size, dim=1), dim=0)
grouped_vels = torch.cat(torch.split(velocity, group_size, dim=1), dim=0)
self.targets = torch.unsqueeze(torch.mean(grouped_vels, dim=1), dim=1) # [num_shots * num_groups, 1]

if ch_in == 1:
if num_channels == 1:
self.inputs = torch.unsqueeze(self.inputs, dim=1)
self.targets = torch.unsqueeze(self.targets, dim=1)
else:
Expand Down
3 changes: 2 additions & 1 deletion util.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ def write_data(file_path, entries):
else:
f.create_dataset(col_name,
data=np.expand_dims(col_data, axis=0),
maxshape=(None, col_data.shape[0]),
maxshape=(None, col_data.shape[0],
col_data.shape[1]),
chunks=True)


Expand Down

0 comments on commit 4fa561f

Please sign in to comment.