From 8a53b8ddebc73d9583f1a5e68640ae44392a7f16 Mon Sep 17 00:00:00 2001 From: Zhaoyi-Yan Date: Wed, 19 Jun 2024 12:33:43 +0800 Subject: [PATCH] Update archs.py --- Seg_UKAN/archs.py | 38 ++++++++++++++++++-------------------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/Seg_UKAN/archs.py b/Seg_UKAN/archs.py index 8275eee..b7789ab 100644 --- a/Seg_UKAN/archs.py +++ b/Seg_UKAN/archs.py @@ -55,7 +55,7 @@ def __init__(self, in_features, hidden_features=None, out_features=None, act_lay ) self.fc2 = KANLinear( hidden_features, - out_features, + hidden_features, grid_size=grid_size, spline_order=spline_order, scale_noise=scale_noise, @@ -80,7 +80,7 @@ def __init__(self, in_features, hidden_features=None, out_features=None, act_lay self.dwconv_1 = DW_bn_relu(hidden_features) self.dwconv_2 = DW_bn_relu(hidden_features) - self.dwconv_3 = DW_bn_relu(hidden_features) + self.dwconv_3 = DW_bn_relu(out_features) self.drop = nn.Dropout(drop) @@ -104,17 +104,16 @@ def _init_weights(self, m): def forward(self, x, H, W): - # pdb.set_trace() - B, N, C = x.shape + B, N, _ = x.shape - x = self.fc1(x.reshape(B*N,C)) - x = x.reshape(B,N,C).contiguous() + x = self.fc1(x.reshape(B*N,-1)) + x = x.reshape(B,N,-1).contiguous() x = self.dwconv_1(x, H, W) - x = self.fc2(x.reshape(B*N,C)) - x = x.reshape(B,N,C).contiguous() + x = self.fc2(x.reshape(B*N,-1)) + x = x.reshape(B,N,-1).contiguous() x = self.dwconv_2(x, H, W) - x = self.fc3(x.reshape(B*N,C)) - x = x.reshape(B,N,C).contiguous() + x = self.fc3(x.reshape(B*N,-1)) + x = x.reshape(B,N,-1).contiguous() x = self.dwconv_3(x, H, W) return x @@ -169,7 +168,7 @@ def __init__(self, in_features, hidden_features=None, out_features=None, act_lay self.fc1 = nn.Linear(in_features, hidden_features) - self.fc2 = nn.Linear(hidden_features, out_features) + self.fc2 = nn.Linear(hidden_features, hidden_features) self.fc3 = nn.Linear(hidden_features, out_features) @@ -191,7 +190,7 @@ def __init__(self, in_features, hidden_features=None, out_features=None, act_lay self.dwconv_1 = DW_bn_relu(hidden_features) self.dwconv_2 = DW_bn_relu(hidden_features) - self.dwconv_3 = DW_bn_relu(hidden_features) + self.dwconv_3 = DW_bn_relu(out_features) self.drop = nn.Dropout(drop) @@ -214,17 +213,16 @@ def _init_weights(self, m): def forward(self, x, H, W): - # pdb.set_trace() - B, N, C = x.shape + B, N, _ = x.shape - x = self.fc1(x.reshape(B*N,C)) - x = x.reshape(B,N,C).contiguous() + x = self.fc1(x.reshape(B*N,-1)) + x = x.reshape(B,N,-1).contiguous() x = self.dwconv_1(x, H, W) - x = self.fc2(x.reshape(B*N,C)) - x = x.reshape(B,N,C).contiguous() + x = self.fc2(x.reshape(B*N,-1)) + x = x.reshape(B,N,-1).contiguous() x = self.dwconv_2(x, H, W) - x = self.fc3(x.reshape(B*N,C)) - x = x.reshape(B,N,C).contiguous() + x = self.fc3(x.reshape(B*N,-1)) + x = x.reshape(B,N,-1).contiguous() x = self.dwconv_3(x, H, W) return x