- import torch
- import torch.nn as nn
- import torchvision.models as models
-
- import os
- import dotenv
-
- dotenv.load_dotenv()
-
- ML_MODEL_CHECKPOINT_PATH = os.environ.get("ML_MODEL_CHECKPOINT_PATH")
-
- # taken from https://github.com/RS2002/CrossFi/blob/main/model.py, introduced by "CrossFi: A Cross Domain WiFi Sensing Framework Based on Siamese Network"
- class ResnetBasedSiamese2dNetwork(nn.Module):
-     def __init__(self, output_dims: int = 64, channel: int = 2, pretrained: bool = True, norm: bool = False):
-         super().__init__()
-         self.model = models.resnet18(pretrained)
-         self.model.conv1 = nn.Conv2d(channel, 64, kernel_size=7, stride=2, padding=3, bias=False)
-         self.model.fc = nn.Linear(self.model.fc.in_features, output_dims)
-         self.norm = norm
-
-     def forward(self, x):
-         if self.norm:
-             mean = torch.mean(x, dim=-1, keepdim=True)
-             std = torch.std(x, dim=-1, keepdim=True)
-             y = (x - mean) / std
-         else:
-             y = x
-         return self.model(y)
-
-
- class SiameseDiscriminator(nn.Module):
-     def __init__(self, embedding_dim=64):
-         super().__init__()
-         # after concat you have 2×embedding_dim input
-         self.classifier = nn.Sequential(
-             nn.Linear(2 * embedding_dim, 256),
-             nn.ReLU(inplace=True),
-             nn.Dropout(0.3),
-             nn.Linear(256, 64),
-             nn.ReLU(inplace=True),
-             nn.Linear(64, 1),
-             nn.Sigmoid(),  # outputs P(same)
-         )
-
-     def forward(self, e1, e2):
-         # you could also do torch.abs(e1 - e2) or elementwise mult
-         x = torch.cat([e1, e2], dim=1)
-         return self.classifier(x)
-
-
- class SiameseModel(nn.Module):
-     def __init__(self, embedding_dim=64):
-         super().__init__()
-         self.encoder = ResnetBasedSiamese2dNetwork()
-         self.discriminator = SiameseDiscriminator(embedding_dim)
-
-     def forward(self, x1, x2):
-         e1 = self.encoder(x1)
-         e2 = self.encoder(x2)
-         p_same = self.discriminator(e1, e2)
-         return p_same
-
-
- def model_predict(embedding_model, csi_1, csi_2, device):
-     with torch.no_grad():
-         csi_1, csi_2 = csi_1.to(device), csi_2.to(device)
-
-         p_same = embedding_model(csi_1, csi_2).flatten()
-         pred = (p_same >= 0.9).float().item()
-
-         return pred > 0
-
-
-
- ## External Func
- def authenticate(csi_1, csi_2):
-     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-     embedding_model = SiameseModel().to(device).eval()
-
-     embedding_model.load_state_dict(torch.load(ML_MODEL_CHECKPOINT_PATH, map_location=device))
-
-     return model_predict(embedding_model, csi_1, csi_2, device)
+ import torch
+ import torch.nn as nn
+ import torchvision.models as models
+
+ import os
+ import dotenv
+
+ dotenv.load_dotenv()
+
+ ML_MODEL_CHECKPOINT_PATH = os.environ.get("ML_MODEL_CHECKPOINT_PATH")
+
+ # small CNN that maps a CSI sample to a 32-dimensional embedding
+ class CNNEmbedder(nn.Module):
+     def __init__(self, channel: int = 1):
+         super().__init__()
+         self.model = torch.nn.Sequential(
+             nn.Conv2d(channel, 4, kernel_size=3, stride=1),
+             nn.Tanh(),
+             nn.Dropout2d(p=0.5),
+
+             nn.Conv2d(4, 8, kernel_size=3, stride=1),
+             nn.Tanh(),
+             nn.Dropout2d(p=0.5),
+
+             nn.MaxPool2d(kernel_size=3, stride=1),
+
+             nn.Conv2d(8, 16, kernel_size=5, stride=1),
+             nn.Tanh(),
+             nn.Dropout2d(p=0.5),
+
+             nn.Conv2d(16, 32, kernel_size=3, stride=1),
+             nn.Tanh(),
+             nn.AdaptiveAvgPool2d(1),
+
+             nn.Flatten(),
+         )
+
+     def forward(self, x):
+         return self.model(x)
+
+
+ class SiameseDiscriminator(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.classifier = nn.Sequential(
+             nn.Linear(32, 1),
+             nn.Sigmoid(),  # outputs P(same)
+         )
+
+     def forward(self, e1, e2):
+         # element-wise absolute difference of the two embeddings
+         x = torch.abs(e1 - e2)
+         return self.classifier(x)
+
+
+ class SiameseModel(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.encoder = CNNEmbedder()
+         self.discriminator = SiameseDiscriminator()
+
+     def forward(self, x1, x2):
+         e1 = self.encoder(x1)
+         e2 = self.encoder(x2)
+         p_same = self.discriminator(e1, e2)
+         return p_same
+
+
+ def model_predict(embedding_model, csi_1, csi_2, device):
+     with torch.no_grad():
+         csi_1, csi_2 = csi_1.to(device), csi_2.to(device)
+
+         p_same = embedding_model(csi_1, csi_2).flatten()
+         pred = (p_same > 0.7).float().item()  # decision threshold on P(same)
+
+         return pred > 0
+
+
+
+ ## External Func
+ def authenticate(csi_1, csi_2):
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     embedding_model = SiameseModel().to(device).eval()
+
+     embedding_model.load_state_dict(torch.load(ML_MODEL_CHECKPOINT_PATH, map_location=device))
+
+     return model_predict(embedding_model, csi_1, csi_2, device)
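
For reference, a minimal usage sketch (not part of the commit), assuming `ML_MODEL_CHECKPOINT_PATH` points to a trained checkpoint and that CSI samples arrive as single-channel 4D tensors; the 64×64 spatial size below is an illustrative assumption, not taken from this repo:

```python
import torch

# Hypothetical driver for authenticate(): two random stand-ins for CSI samples,
# shaped (batch=1, channel=1, 64, 64). Any spatial size large enough to survive
# CNNEmbedder's convolutions and pooling would behave the same way.
# (Assumes this snippet runs in the same module that defines authenticate.)
csi_1 = torch.randn(1, 1, 64, 64)
csi_2 = torch.randn(1, 1, 64, 64)

same_user = authenticate(csi_1, csi_2)  # True when the model's P(same) exceeds 0.7
print(same_user)
```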