Skip to content

Commit 006e76c

Browse files
committed
update: more advanced ML model + support code
1 parent 7a4e9f1 commit 006e76c

File tree

3 files changed

+144
-156
lines changed

3 files changed

+144
-156
lines changed

e2e.pt

-43 MB
Binary file not shown.

lib/csi_preprocessor.py

Lines changed: 58 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,73 +1,58 @@
1-
import math
2-
3-
import torch
4-
import numpy as np
5-
from scipy.signal import butter, filtfilt
6-
7-
import os
8-
import dotenv
9-
10-
from .csi_database import CSIRecord
11-
12-
dotenv.load_dotenv()
13-
14-
ML_MODEL_SAMPLES_PER_RECORDING = int(os.environ.get("ML_MODEL_SAMPLES_PER_RECORDING"))
15-
16-
17-
18-
def csi_int16_from_bytes(b: bytes) -> list[int]:
19-
return [
20-
int.from_bytes(b[2*i:2*i+2], byteorder='little', signed=True)
21-
for i in range(len(b) // 2)
22-
]
23-
24-
25-
def lowpass_filter(data, cutoff=0.1, order=4):
26-
b, a = butter(order, cutoff, btype='low')
27-
return filtfilt(b, a, data)
28-
29-
30-
def bytes_to_amplitude_phase(raw_csi_bytes: bytes) -> tuple[list[float], list[float]]:
31-
csi_fingerprint = csi_int16_from_bytes(raw_csi_bytes)
32-
amplitude = []
33-
phase = []
34-
for j in range(0, 128, 2):
35-
real = csi_fingerprint[j]
36-
imag = csi_fingerprint[j + 1]
37-
amp = math.sqrt(real ** 2 + imag ** 2)
38-
phs = math.atan2(imag, real)
39-
amplitude.append(amp) # [64]
40-
phase.append(phs) # [64]
41-
42-
amplitude = lowpass_filter(amplitude)
43-
44-
delta_ampl = np.max(amplitude) - np.min(amplitude)
45-
amplitude = (amplitude - np.min(amplitude)) / (delta_ampl if delta_ampl != 0 else 1) # min-max
46-
47-
phase = list(np.unwrap(phase))
48-
delta_phase = np.max(phase) - np.min(phase)
49-
phase = (phase - np.min(phase)) / (delta_phase if delta_phase != 0 else 1)
50-
51-
return amplitude, phase
52-
53-
#csi_fingerprint = [amplitude, phase]
54-
55-
#csi_fingerprint = torch.tensor(csi_fingerprint, dtype=torch.float64) # [time, 2, subcarrier_idx]
56-
#print(csi_fingerprint.shape)
57-
#csi_fingerprint = csi_fingerprint.transpose(0, 1) # [2, time, subcarrier_idx]
58-
#print(csi_fingerprint.shape)
59-
60-
#return csi_fingerprint.tolist()
61-
62-
63-
def records_to_tensor(csi_records: list[CSIRecord]) -> torch.tensor:
64-
amplitudes = []
65-
phases = []
66-
for record in csi_records:
67-
csi_bytes = record.get_csi_bytes()
68-
amplitude, phase = bytes_to_amplitude_phase(csi_bytes)
69-
amplitudes.append(amplitude)
70-
phases.append(phase)
71-
72-
arr = np.array([[amplitudes, phases]])
73-
return torch.tensor(arr, dtype=torch.float)
1+
import math
2+
3+
import torch
4+
import numpy as np
5+
from scipy.signal import butter, filtfilt
6+
7+
import os
8+
import dotenv
9+
10+
from .csi_database import CSIRecord
11+
12+
dotenv.load_dotenv()
13+
14+
ML_MODEL_SAMPLES_PER_RECORDING = int(os.environ.get("ML_MODEL_SAMPLES_PER_RECORDING"))
15+
16+
17+
18+
def csi_int16_from_bytes(b: bytes) -> list[int]:
19+
return [
20+
int.from_bytes(b[2*i:2*i+2], byteorder='little', signed=True)
21+
for i in range(len(b) // 2)
22+
]
23+
24+
25+
def bytes_to_amplitude_phase(raw_csi_bytes: bytes) -> list[float]:
26+
csi_fingerprint = csi_int16_from_bytes(raw_csi_bytes)
27+
amplitude = []
28+
for j in range(0, 128, 2):
29+
if j in [0, 54, 56, 58, 60, 62, 64, 66, 68, 70, 72, 74]:
30+
continue
31+
real = csi_fingerprint[j]
32+
imag = csi_fingerprint[j + 1]
33+
amplitude.append(math.sqrt(real ** 2 + imag ** 2)) # [52]
34+
35+
36+
delta_ampl = np.max(amplitude) - np.min(amplitude)
37+
amplitude = (amplitude - np.min(amplitude)) / (delta_ampl if delta_ampl != 0 else 1) # min-max normalization
38+
39+
return amplitude
40+
#csi_fingerprint = [amplitude, phase]
41+
42+
#csi_fingerprint = torch.tensor(csi_fingerprint, dtype=torch.float64) # [time, 2, subcarrier_idx]
43+
#print(csi_fingerprint.shape)
44+
#csi_fingerprint = csi_fingerprint.transpose(0, 1) # [2, time, subcarrier_idx]
45+
#print(csi_fingerprint.shape)
46+
47+
#return csi_fingerprint.tolist()
48+
49+
50+
def records_to_tensor(csi_records: list[CSIRecord]) -> torch.tensor:
51+
amplitudes = []
52+
for record in csi_records:
53+
csi_bytes = record.get_csi_bytes()
54+
amplitude = bytes_to_amplitude_phase(csi_bytes)
55+
amplitudes.append(amplitude)
56+
57+
arr = np.array([[amplitudes]])
58+
return torch.tensor(arr, dtype=torch.float)

lib/ml_model.py

Lines changed: 86 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1,83 +1,86 @@
1-
import torch
2-
import torch.nn as nn
3-
import torchvision.models as models
4-
5-
import os
6-
import dotenv
7-
8-
dotenv.load_dotenv()
9-
10-
ML_MODEL_CHECKPOINT_PATH = os.environ.get("ML_MODEL_CHECKPOINT_PATH")
11-
12-
# taken from https://github.com/RS2002/CrossFi/blob/main/model.py, introduced by "CrossFi: A Cross Domain WiFi Sensing Framework Based on Siamese Network"
13-
class ResnetBasedSiamese2dNetwork(nn.Module):
14-
def __init__(self, output_dims: int = 64, channel: int = 2, pretrained: bool = True, norm: bool = False):
15-
super().__init__()
16-
self.model = models.resnet18(pretrained)
17-
self.model.conv1 = nn.Conv2d(channel, 64, kernel_size=7, stride=2, padding=3, bias=False)
18-
self.model.fc = nn.Linear(self.model.fc.in_features, output_dims)
19-
self.norm = norm
20-
21-
def forward(self,x):
22-
if self.norm:
23-
mean = torch.mean(x, dim=-1, keepdim=True)
24-
std = torch.std(x, dim=-1, keepdim=True)
25-
y = (x - mean) / std
26-
else:
27-
y = x
28-
return self.model(y)
29-
30-
31-
class SiameseDiscriminator(nn.Module):
32-
def __init__(self, embedding_dim=64):
33-
super().__init__()
34-
# after concat you have 2×embedding_dim input
35-
self.classifier = nn.Sequential(
36-
nn.Linear(2 * embedding_dim, 256),
37-
nn.ReLU(inplace=True),
38-
nn.Dropout(0.3),
39-
nn.Linear(256, 64),
40-
nn.ReLU(inplace=True),
41-
nn.Linear(64, 1),
42-
nn.Sigmoid(), # outputs P(same)
43-
)
44-
45-
def forward(self, e1, e2):
46-
# you could also do torch.abs(e1 - e2) or elementwise mult
47-
x = torch.cat([e1, e2], dim=1)
48-
return self.classifier(x)
49-
50-
51-
class SiameseModel(nn.Module):
52-
def __init__(self, embedding_dim=64):
53-
super().__init__()
54-
self.encoder = ResnetBasedSiamese2dNetwork()
55-
self.discriminator = SiameseDiscriminator(embedding_dim)
56-
57-
def forward(self, x1, x2):
58-
e1 = self.encoder(x1)
59-
e2 = self.encoder(x2)
60-
p_same = self.discriminator(e1, e2)
61-
return p_same
62-
63-
64-
def model_predict(embedding_model, csi_1, csi_2, device):
65-
with torch.no_grad():
66-
csi_1, csi_2 = csi_1.to(device), csi_2.to(device)
67-
68-
p_same = embedding_model(csi_1, csi_2).flatten()
69-
pred = (p_same >= 0.9).float().item()
70-
71-
return pred > 0
72-
73-
74-
75-
## External Func
76-
def authenticate(csi_1, csi_2):
77-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
78-
79-
embedding_model = SiameseModel().to(device).eval()
80-
81-
embedding_model.load_state_dict(torch.load(ML_MODEL_CHECKPOINT_PATH, map_location=device))
82-
83-
return model_predict(embedding_model, csi_1, csi_2, device)
1+
import torch
2+
import torch.nn as nn
3+
import torchvision.models as models
4+
5+
import os
6+
import dotenv
7+
8+
dotenv.load_dotenv()
9+
10+
ML_MODEL_CHECKPOINT_PATH = os.environ.get("ML_MODEL_CHECKPOINT_PATH")
11+
12+
class CNNEmbedder(nn.Module):
13+
def __init__(self, channel: int = 1):
14+
super().__init__()
15+
self.model = torch.nn.Sequential(
16+
nn.Conv2d(channel, 4, kernel_size=3, stride=1),
17+
nn.Tanh(),
18+
nn.Dropout2d(p=0.5),
19+
20+
nn.Conv2d(4, 8, kernel_size=3, stride=1),
21+
nn.Tanh(),
22+
nn.Dropout2d(p=0.5),
23+
24+
nn.MaxPool2d(kernel_size=3, stride=1),
25+
26+
nn.Conv2d(8, 16, kernel_size=5, stride=1),
27+
nn.Tanh(),
28+
nn.Dropout2d(p=0.5),
29+
30+
nn.Conv2d(16, 32, kernel_size=3, stride=1),
31+
nn.Tanh(),
32+
nn.AdaptiveAvgPool2d(1),
33+
34+
nn.Flatten(),
35+
)
36+
37+
def forward(self,x):
38+
return self.model(x)
39+
40+
41+
class SiameseDiscriminator(nn.Module):
42+
def __init__(self):
43+
super().__init__()
44+
self.classifier = nn.Sequential(
45+
nn.Linear(32, 1),
46+
nn.Sigmoid(),
47+
)
48+
49+
def forward(self, e1, e2):
50+
x = torch.abs(e1 - e2)
51+
return self.classifier(x)
52+
53+
54+
class SiameseModel(nn.Module):
55+
def __init__(self):
56+
super().__init__()
57+
self.encoder = CNNEmbedder()
58+
self.discriminator = SiameseDiscriminator()
59+
60+
def forward(self, x1, x2):
61+
e1 = self.encoder(x1)
62+
e2 = self.encoder(x2)
63+
p_same = self.discriminator(e1, e2)
64+
return p_same
65+
66+
67+
def model_predict(embedding_model, csi_1, csi_2, device):
68+
with torch.no_grad():
69+
csi_1, csi_2 = csi_1.to(device), csi_2.to(device)
70+
71+
p_same = embedding_model(csi_1, csi_2).flatten()
72+
pred = (p_same > 0.7).float().item()
73+
74+
return pred > 0
75+
76+
77+
78+
## External Func
79+
def authenticate(csi_1, csi_2):
80+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
81+
82+
embedding_model = SiameseModel().to(device).eval()
83+
84+
embedding_model.load_state_dict(torch.load(ML_MODEL_CHECKPOINT_PATH, map_location=device))
85+
86+
return model_predict(embedding_model, csi_1, csi_2, device)

0 commit comments

Comments
 (0)