
Commit

dauhst update
lpaillet-laas committed Mar 8, 2024
1 parent 074a221 commit 73add40
Showing 5 changed files with 32 additions and 15 deletions.
MST/simulation/train_code/architecture/DAUHST.py (17 changes: 11 additions & 6 deletions)
@@ -50,7 +50,8 @@ class HS_MSA(nn.Module):
     def __init__(
             self,
             dim,
-            window_size=(8, 8),
+            #window_size=(8, 8),
+            window_size=(7, 7),
             dim_head=28,
             heads=8,
             only_local_branch=False
@@ -71,7 +72,8 @@ def __init__(
         else:
             seq_l1 = window_size[0] * window_size[1]
             self.pos_emb1 = nn.Parameter(torch.Tensor(1, 1, heads//2, seq_l1, seq_l1))
-            h,w = 256//self.heads,320//self.heads
+            #h,w = 256//self.heads,320//self.heads
+            h,w = 112//self.heads,112//self.heads
             seq_l2 = h*w//seq_l1
             self.pos_emb2 = nn.Parameter(torch.Tensor(1, 1, heads//2, seq_l2, seq_l2))
             trunc_normal_(self.pos_emb1)
@@ -89,6 +91,7 @@ def forward(self, x):
         """
         b, h, w, c = x.shape
         w_size = self.window_size
+
         assert h % w_size[0] == 0 and w % w_size[1] == 0, 'fmap dimensions must be divisible by the window size'
         if self.only_local_branch:
             x_inp = rearrange(x, 'b (h b0) (w b1) c -> (b h w) (b0 b1) c', b0=w_size[0], b1=w_size[1])
@@ -145,7 +148,8 @@ class HSAB(nn.Module):
     def __init__(
             self,
             dim,
-            window_size=(8, 8),
+            #window_size=(8, 8),
+            window_size=(7, 7),
             dim_head=64,
             heads=8,
             num_blocks=2,
@@ -338,9 +342,10 @@ def initial(self, y, Phi):
         nC, step = 28, 2
         y = y / nC * 2
         bs,row,col = y.shape
-        y_shift = torch.zeros(bs, nC, row, col).cuda().float()
+        """ y_shift = torch.zeros(bs, nC, row, col).cuda().float()
         for i in range(nC):
-            y_shift[:, i, :, step * i:step * i + col - (nC - 1) * step] = y[:, :, step * i:step * i + col - (nC - 1) * step]
+            y_shift[:, i, :, step * i:step * i + col - (nC - 1) * step] = y[:, :, step * i:step * i + col - (nC - 1) * step] """
+        y_shift = y.unsqueeze(1).repeat((1, 28, 1, 1)).float()
         z = self.fution(torch.cat([y_shift, Phi], dim=1))
         alpha, beta = self.para_estimator(self.fution(torch.cat([y_shift, Phi], dim=1)))
         return z, alpha, beta
@@ -363,5 +368,5 @@ def forward(self, y, input_mask=None):
             z = self.denoisers[i](torch.cat([x, beta_repeat],dim=1))
             if i<self.num_iterations-1:
                 z = shift_3d(z)
-        return z[:, :, :, 0:256]
+        return z[:, :, :, 0:112]

MST/simulation/train_code/utils.py (1 change: 0 additions & 1 deletion)
@@ -25,7 +25,6 @@ def generate_shift_masks(mask_path, batch_size):
     Phi_batch = mask_3d_shift.expand([batch_size, nC, H, W]).cuda().float()
     Phi_s_batch = torch.sum(Phi_batch**2,1)
     Phi_s_batch[Phi_s_batch==0] = 1
-    # print(Phi_batch.shape, Phi_s_batch.shape)
     return Phi_batch, Phi_s_batch

 def LoadTraining(path):
optimization_modules.py (21 changes: 17 additions & 4 deletions)
@@ -99,11 +99,19 @@ def forward(self, x):
             # acquisition = self.acquired_image1.unsqueeze(1)
             acquisition = self.acquired_image1.float()
             filtering_cubes = filtering_cubes.float()
+        elif "dauhst" in self.model_name:
+            acquisition = self.acquired_image1.float()
+
+            filtering_cubes_s = torch.sum(filtering_cubes**2,1)
+            filtering_cubes_s[filtering_cubes_s==0] = 1
+            filtering_cubes = (filtering_cubes.float(), filtering_cubes_s.float())
+
         elif self.model_name == "mst_plus_plus":
             acquisition = self.acquired_image1.unsqueeze(1).repeat((1, 28, 1, 1)).float().to(self.device)
-        print(f"acquisition shape: {acquisition.shape}")
-        print(f"filtering_cubes shape: {filtering_cubes.shape}")
-        reconstructed_cube = self.reconstruction_model(acquisition, filtering_cubes.to(self.device))
+        #print(f"acquisition shape: {acquisition.shape}")
+        #print(f"filtering_cubes shape: {filtering_cubes.shape}")
+
+        reconstructed_cube = self.reconstruction_model(acquisition, filtering_cubes)


         return reconstructed_cube
@@ -200,7 +208,12 @@ def _common_step(self, batch, batch_idx):

     def configure_optimizers(self):
         optimizer = torch.optim.Adam(self.parameters(), lr=4e-4)
-        return optimizer
+        return { "optimizer":optimizer,
+                 "lr_scheduler":{
+                     "scheduler":torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 500, eta_min=1e-6),
+                     "interval": "epoch"
+                 }
+               }

     def _log_images(self, tag, images, global_step):
         # Convert model output to image grid and log to TensorBoard
simca/functions_acquisition_torch.py (2 changes: 1 addition & 1 deletion)
@@ -82,7 +82,7 @@ def interpolate_data_on_grid_positions_torch(data, X_init, Y_init, X_target, Y_t
         torch.Tensor: Interpolated 4D data on the target grid.
     """

-    print(data.shape)
+    #print(data.shape)

     # Ensure tensors are on the correct device and data type
     device = data.device
training_simca_reconstruction.py (6 changes: 3 additions & 3 deletions)
@@ -7,12 +7,12 @@



-data_dir = "./datasets_reconstruction/cave_1024_28"
-datamodule = CubesDataModule(data_dir, batch_size=2, num_workers=1)
+data_dir = "./datasets_reconstruction/mst_datasets/cave_1024_28"
+datamodule = CubesDataModule(data_dir, batch_size=4, num_workers=11)


 name = "testing_simca_reconstruction"
-model_name = "birnat"
+model_name = "dauhst_5"

 log_dir = 'tb_logs'
