import math
import copy
import os
import random
import torch
import pickle
import torch.nn as nn
import numpy as np
from collections import defaultdict
# all possible batch sizes to choose from
supported_batch_sizes = [1, 2, 4, 8, 16, 32, 64]
# supported model architectures
all_cv_models = [
# distiller
"resnet20_cifar10", "resnet32_cifar10", "resnet44_cifar10",
"resnet56_cifar10", "resnet110_cifar10", "resnet1202_cifar10",
"resnet20_cifar100", "resnet32_cifar100", "resnet44_cifar100",
"resnet56_cifar100", "resnet110_cifar100", "resnet1202_cifar100",
# custom datasets
"resnet18_waymo", "resnet50_waymo",
"resnet18_urban", "resnet50_urban",
]
all_nlp_models = [
# dee{bert,roberta,distilbert}
"bert-base-uncased",
"bert-large-uncased",
# "roberta",
"distilbert-base-uncased",
]
all_supported_models = all_cv_models + all_nlp_models
# supported datasets
all_cv_datasets = [
# CV
"cifar10", "cifar100",
# custom CV datasets
"waymo", "urban",
]
all_nlp_datasets = [
# NLP - GLUE
"mnli", "mnli-mm", "mrpc", "qnli", "qqp", "rte", "sst-2", "wnli",
]
all_supported_datasets = all_cv_datasets + all_nlp_datasets
def round_up_batch_size(batch_size):
"""Rounds up the batch size to 2^n.
This is mostly used to accommodate the last batch in a dataset,
where the bs might not be 2^n.
Args:
batch_size (int): batch size
Returns:
int: rounded-up batch size
"""
return min([x for x in supported_batch_sizes if x >= batch_size])
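# Illustrative doctest-style examples (values chosen for demonstration):
# >>> round_up_batch_size(5)
# 8
# >>> round_up_batch_size(64)
# 64
# Note: a batch size above max(supported_batch_sizes) leaves the list
# comprehension empty, so min() raises a ValueError.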
def set_seeds(seed=42):
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def nth_repl(s, sub, repl, n):
# https://stackoverflow.com/a/35092436/9601555
"""Replace the n-th occurrence of sub in s with repl
Args:
s (str): original string
sub (str): substring to be substituted
repl (str): replacement substring
n (int): 1-indexed occurrence ID
In [14]: s = "foobarfoofoobarbar"
In [15]: nth_repl(s, "bar", "replaced", 3)
Out[15]: 'foobarfoofoobarreplaced'
Returns:
str: s with the n-th occurrence of sub replaced by repl,
or s unchanged if sub occurs fewer than n times
"""
find = s.find(sub)
# If find is not -1 we have found at least one match for the substring
i = find != -1
# loop until we find the nth occurrence or run out of matches
while find != -1 and i != n:
# find + 1 means we start searching from after the last match
find = s.find(sub, find + 1)
i += 1
# If i is equal to n we found nth match so replace
if i == n:
return s[:find] + repl + s[find+len(sub):]
return s
def compare_lists(list1, list2):
"""Lexicographically compares two lists; returns True iff list1 < list2."""
# Compare the lists element-wise
for element1, element2 in zip(list1, list2):
if element1 < element2:
return True
elif element1 > element2:
return False
return len(list1) < len(list2)
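# Illustrative examples of the lexicographic "list1 < list2" semantics:
# >>> compare_lists([1, 2, 3], [1, 2, 4])
# True
# >>> compare_lists([1, 2], [1, 2, 3])  # equal prefix, shorter list is smaller
# True
# >>> compare_lists([2, 0], [1, 9])
# False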
def is_monotonic(l, increasing):
"""Checks whether items in a list are monotonically
increasing/decreasing
Args:
l (list): list of items
increasing (bool): True if checking
for monotonic increasing, and False if
checking for monotonic decreasing.
Returns:
bool: whether items in a list are monotonically
increasing/decreasing
"""
l = list(l)
if increasing:
return all(l[i] <= l[i+1] for i in range(len(l)-1))
else: # check for monotonic decreasing
return all(l[i] >= l[i+1] for i in range(len(l)-1))
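# Illustrative examples (non-strict monotonicity; equal neighbors are allowed):
# >>> is_monotonic([1, 2, 2, 3], increasing=True)
# True
# >>> is_monotonic([3, 1, 2], increasing=False)
# False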
def normalize_exit_rate(exit_counter, total_num_samples):
"""Given an exit counter as a dict,
return the normalized exit rates such that
the exit rate at each ramp sum up to 100.
Args:
exit_counter (dict): key: 0-indexed ramp id, value: number of samples exited
total_num_samples (int): total number of samples in the dataset
Returns:
dict: key: 0-indexed ramp id, value: exit rate %
"""
exit_rate = {}
for ramp_id, num_samples_exited in exit_counter.items():
rate = round(num_samples_exited / total_num_samples * 100, 5)
exit_rate[ramp_id] = rate
return exit_rate
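# Illustrative example: 1000 samples, with 250/500/250 exiting at ramps 0/1/2:
# >>> normalize_exit_rate({0: 250, 1: 500, 2: 250}, 1000)
# {0: 25.0, 1: 50.0, 2: 25.0}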
def get_remaining_rate(exit_rate):
"""Given the normalized exit rate at each ramp,
return the remaining sample rate (CDF) at
each ramp
Args:
exit_rate (dict): key: 0-indexed ramp IDs,
value: normalized exit rate (0-1)
OR: (np.ndarray): index x: normalized exit
rate at ramp of index x
Returns:
remaining_rate (dict): key: 0-indexed ramp IDs,
value: percentage of remaining samples (0-1).
OR: (np.ndarray): index x: normalized remaining
rate after ramp of index x
"""
if type(exit_rate) is dict:
remaining_rate = {}
for ramp_id, exit_rate_at_ramp in exit_rate.items():
last_ramp_rate = 1 if ramp_id == 0 else remaining_rate[ramp_id - 1]
remaining_rate[ramp_id] = last_ramp_rate - exit_rate_at_ramp
return remaining_rate
elif type(exit_rate) is np.ndarray:
remaining_rate = []
for ramp_index, exit_rate_at_ramp in enumerate(exit_rate):
last_ramp_rate = 1 if ramp_index == 0 else remaining_rate[ramp_index - 1]
remaining_rate.append(last_ramp_rate - exit_rate_at_ramp)
return np.array(remaining_rate)
else:
raise NotImplementedError
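# Illustrative example (note: unlike normalize_exit_rate, the rates here are
# fractions in [0, 1], not percentages):
# >>> get_remaining_rate(np.array([0.3, 0.2, 0.5]))
# array([0.7, 0.5, 0.0])  (up to floating-point rounding)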
def merge_batches(batches, total_num_ramps):
merged = {'conf': [[] for _ in range(total_num_ramps)],
'acc': [[] for _ in range(total_num_ramps)]}
for i, batch in enumerate(batches):
for ramp_id in range(total_num_ramps):
merged["conf"][ramp_id] += batch["conf"][ramp_id]
merged["acc"][ramp_id] += batch["acc"][ramp_id]
return merged
def get_avg_exit_point(ramp_ids, exit_rate, latency_calc_list):
"""Returns the average exit point of all samples.
Args:
ramp_ids (list): list of ramp ids
exit_rate (np.ndarray): exit rate at all ramps and the final model,
sums up to 1
latency_calc_list (list): list (len = num_all_ramps + 1) of tuples.
Index x: vanilla model latency before ramp x and the latency of ramp x.
Last index: (vanilla model latency, None).
Returns:
avg_exit_point (float): normalized (0-1) average exit point of all samples
"""
# vanilla model latency
vanilla_latency = latency_calc_list[-1][0]
# normalized latency (w.r.t to vanilla model latency) of all current exit points
vanilla_distances = np.array(
[latency_calc_list[i][0] / vanilla_latency for i in ramp_ids] + [1.0])
print(f"vanilla_distances of ramps {ramp_ids}: {vanilla_distances}")
avg_exit_point = sum(exit_rate * vanilla_distances)
return avg_exit_point
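# Worked example with a hypothetical latency_calc_list
# [(10.0, 1.0), (20.0, 1.0), (40.0, None)] (vanilla latency 40 ms),
# ramp_ids [0, 1], and exit_rate np.array([0.5, 0.3, 0.2]):
# vanilla_distances = [10/40, 20/40, 1.0] = [0.25, 0.5, 1.0], so
# avg_exit_point = 0.5*0.25 + 0.3*0.5 + 0.2*1.0 = 0.475.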
def get_batches(pickle_dict: dict, batch_size: int, batch_size_schedule: list = None):
"""Given a fixed batch size or a schedule of the per-batch bs,
return the entropy dict for emulating serving
Args:
pickle_dict (dict): _description_
batch_size (int): fixed batch size
batch_size_schedule (list): list of ints of batch sizes in every batch
Yields:
_type_: _description_
"""
total_num_samples = len(pickle_dict["conf"][0])
num_ramps = len(pickle_dict["conf"]) - 1
if batch_size_schedule is None: # fixed batch size
start_indices = list(
range(0, total_num_samples, batch_size))
# cap the last end index at the dataset size so the final
# (possibly partial) batch keeps its last sample
end_indices = [
min(x + batch_size, total_num_samples) for x in start_indices]
else: # dynamic/adaptive batch size, read from schedule
start_indices = [sum(batch_size_schedule[:i]) for i in range(len(batch_size_schedule))]
end_indices = [sum(x) for x in zip(start_indices, batch_size_schedule)]
for start, end in zip(start_indices, end_indices):
assert end - start <= max(supported_batch_sizes), \
"Batch size in schedule exceeded max supported by our system!"
batch = {"conf": [], "acc": []}
if start == end:
continue
for ramp_id in range(num_ramps):
batch["conf"].append(
pickle_dict["conf"][ramp_id][start:end])
batch["acc"].append(
pickle_dict["acc"][ramp_id][start:end])
yield batch
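# Minimal usage sketch (assumes `profile_pickle` is a pickle in the CV format
# documented in pickle_format_convert below):
# for batch in get_batches(profile_pickle, batch_size=32):
#     ...  # batch["conf"][ramp_id] / batch["acc"][ramp_id] are per-sample lists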
def get_subdatasets(pickle_dict: dict, by_hardness: bool = False, num_subdatasets: int = 10):
print(f"pickle_dict.keys() {pickle_dict.keys()}") # seems to be a pickle format issue (key rn is sample id, needs to be conf and acc)
total_num_samples = len(pickle_dict["conf"][0])
num_ramps = len(pickle_dict["conf"]) - 1
print(f"total_num_samples {total_num_samples}, num_ramps {num_ramps}")
if by_hardness:
#############################################################################
# different methods of partitioning a subdataset
#############################################################################
# method 1: use entropy at first ramp as proxy for hardness
# works well for QQP, but less so for others
num_samples_in_subdataset = math.ceil(
total_num_samples / num_subdatasets)
# get the indices of sample ids sorted by entropy at first ramp ascending
sorted_index = np.argsort(pickle_dict["conf"][0])
start_indices = list(
range(0, total_num_samples, num_samples_in_subdataset))
end_indices = [
min(x + num_samples_in_subdataset, total_num_samples) for x in start_indices]
for start, end in zip(start_indices, end_indices):
subdataset = {"conf": [], "acc": []}
for ramp_id in range(num_ramps + 1):
conf = np.array(pickle_dict["conf"][ramp_id])[
sorted_index[start:end]]
conf = list(conf)
acc = np.array(pickle_dict["acc"][ramp_id])[
sorted_index[start:end]]
acc = list(acc)
subdataset["conf"].append(conf)
subdataset["acc"].append(acc)
yield subdataset
#############################################################################
# method 2.1: look at the series of predictions made by the ramps.
# the earlier "a correct prediction" appears, the easier an input is.
#############################################################################
# method 2.2: look at the series of predictions made by the ramps.
# the earlier "a series of correct predictions until the end of the model"
# appears, the easier an input is.
#############################################################################
#
#############################################################################
else:
num_samples_in_subdataset = math.ceil(
total_num_samples / num_subdatasets)
start_indices = list(
range(0, total_num_samples, num_samples_in_subdataset))
end_indices = [
min(x + num_samples_in_subdataset, total_num_samples) for x in start_indices]
for start, end in zip(start_indices, end_indices):
subdataset = {"conf": [], "acc": []}
for ramp_id in range(num_ramps + 1):
subdataset["conf"].append(
pickle_dict["conf"][ramp_id][start:end])
subdataset["acc"].append(
pickle_dict["acc"][ramp_id][start:end])
yield subdataset
# def get_subdatasets_v2(pickle_dict: dict, by_hardness: bool = False):
# # outdated, nlp pickle format
# total_num_samples = len(pickle_dict)
# num_ramps = len(pickle_dict[0]["all_entropies"])
# if by_hardness:
# num_subdatasets = 10
# num_samples_in_subdataset = math.ceil(total_num_samples / num_subdatasets)
# # get the indices of sample ids sorted by entropy at first ramp ascending
# sorted_index = np.argsort([sample[0]["all_entropies"] for sample in pickle_dict.values()])
# start_indices = list(range(0, total_num_samples, num_samples_in_subdataset))
# end_indices = [x + num_samples_in_subdataset for x in start_indices][:-1] + [total_num_samples - 1]
# for start, end in zip(start_indices, end_indices):
# subdataset = {}
# for sample_id, profile in pickle_dict.items():
# if sample_id in sorted_index[start:end]:
# subdataset[sample_id] = profile
# yield subdataset
# else:
# # at least 1000 samples in each subdataset for statistical significance
# num_samples_in_subdataset = 1000
# start_indices = list(range(0, total_num_samples, num_samples_in_subdataset))
# end_indices = [x + num_samples_in_subdataset for x in start_indices][:-1] + [total_num_samples - 1]
# for start, end in zip(start_indices, end_indices):
# subdataset = {k: v for k, v in pickle_dict.items() if start <= k < end}
# yield subdataset
def pickle_format_convert(pickle_dict):
"""
Converts an NLP pickle's format to match with that of CV pickles.
nlp/deebert pickle format: {
sample_id: {
"all_entropies": [], # list of entropies of this sample at each ramp
"all_logits": [], # list of logits of this sample at each ramp
"all_predictions": [], # list of predictions of this sample at each ramp
"orig_model_prediction": [], # ground-truth label from dataset
}
}
cv/distiller pickle format: {
"conf": [ # confidence = 1 - entropy
[sample_0_entropy_ramp_0, sample_1_entropy_ramp_0, ...], # float between 0 and 0.5
[sample_0_entropy_ramp_1, sample_1_entropy_ramp_1, ...],
...
# original model is considered a ramp
],
"acc": [
[sample_0_ee_prediction_ramp_0, sample_1_ee_prediction_ramp_0, ...], # bool
[sample_0_ee_prediction_ramp_1, sample_1_ee_prediction_ramp_1, ...],
...
],
}
Args:
pickle_dict (dict): nlp pickle
Returns:
dict: nlp pickle with cv pickle's format
"""
num_samples = len(pickle_dict)
num_ramps = len(pickle_dict[0]["all_entropies"])
# print(f"num_samples {num_samples}, num_ramps {num_ramps}")
"""
for now, conf and acc are: [
[sample_0_entropy_ramp_0, sample_0_entropy_ramp_1, ...],
[sample_1_entropy_ramp_0, sample_1_entropy_ramp_1, ...], ...
]
"""
conf = [[] for _ in range(num_samples)]
acc = [[] for _ in range(num_samples)]
if type(pickle_dict) == dict:
pickle_dict = list(pickle_dict.values())
for sample_id, sample in enumerate(pickle_dict):
# also change entropy definition to match cv pickle format
# removes deebert ramp at end of model
conf[sample_id] = [
1 - x for x in list(sample["all_entropies"])][:-1] + [None]
label = sample["orig_model_prediction"][0]
acc[sample_id] = [label == p[0]
for p in sample["all_predictions"]][:-1] + [True]
# transpose list of lists
conf = list(map(list, zip(*conf)))
acc = list(map(list, zip(*acc)))
return {
"conf": conf,
"acc": acc,
}
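# Sanity-check sketch for the converted format (hypothetical `nlp_pickle`):
# converted = pickle_format_convert(nlp_pickle)
# len(converted["conf"])     # one row per ramp, plus one for the final exit
# len(converted["conf"][0])  # one column per sample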
def query_performance(config, all_ramps_conf, all_ramps_acc, ramp_ids, latency_config, baseline):
"""
Given a configuration, return the performance of the model
under this configuration.
Args:
config (list): list of exit thresholds
all_ramps_conf (list): confidences of all ramps for all samples
all_ramps_acc (list): accuracies of all ramps for all samples
ramp_ids (list): list of ramp ids
latency_config (numpy array): latency configuration
baseline (float): baseline latency
Returns:
acc (float): accuracy of the model under this configuration
latency_improvement (float): latency improvement (%) over the baseline under this configuration
exit_rate (numpy array): exit rate of each ramp
"""
correct = 0
nums_exit = [0 for i in range(len(ramp_ids) + 1)]
for i in range(len(all_ramps_conf[ramp_ids[0]])):
earlyexit_taken = False
for j in range(len(ramp_ids)):
id = ramp_ids[j]
if 1 - all_ramps_conf[id][i] < config[j]:
nums_exit[j] += 1
earlyexit_taken = True
if all_ramps_acc[id][i]:
correct += 1
break
if not earlyexit_taken:
nums_exit[-1] += 1
correct += 1
exit_rate = np.array(
[(n+0.0)/len(all_ramps_conf[ramp_ids[0]]) for n in nums_exit])
acc = round((correct+0.0)/len(all_ramps_conf[ramp_ids[0]]), 7)
latency_improvement = (
baseline - sum(exit_rate * latency_config)) / baseline * 100
return acc, latency_improvement, exit_rate
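# Worked micro-example (illustrative numbers): one ramp (id 0), two samples
# with conf [0.99, 0.6] and acc [True, False] at ramp 0, thresholds [0.1]:
# sample 0: 1 - 0.99 = 0.01 < 0.1 -> exits at ramp 0, counted correct;
# sample 1: 1 - 0.6 = 0.4 >= 0.1 -> reaches the final exit, counted correct.
# So acc = 1.0 and exit_rate = [0.5, 0.5].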
def get_latency_config(path, ramp_ids):
if "resnet" in path and "cifar" in path:
res = [0.0 for _ in ramp_ids]
for i in range(len(ramp_ids)):
res[i] = 0.25 + 0.11*(i+1) + 0.31*(ramp_ids[i] + 1)
res.append(8.72 + 0.11*len(ramp_ids))
return np.array(res), 8.72
elif "bert" in path:
res = [0.0 for _ in ramp_ids]
for i in range(len(ramp_ids)):
res[i] = 0.106*(i+1) + 0.622*(ramp_ids[i] + 1)
res.append(7.464 + 0.106*len(ramp_ids))
# print(res)
return np.array(res), 7.464
elif "resnet50" in path and "phx" in path:
res = [0.0 for _ in ramp_ids]
for i in range(len(ramp_ids)):
res[i] = 0.25 + 0.11*(i+1) + 0.31*(ramp_ids[i] + 1)
res.append(5.31 + 0.11*len(ramp_ids))
return np.array(res), 5.31
elif "resnet18" in path and "urban" in path:
res = [0.0 for _ in ramp_ids]
for i in range(len(ramp_ids)):
res[i] = 0.25 + 0.11*(i+1) + 0.31*(ramp_ids[i] + 1)
res.append(2.83 + 0.11*len(ramp_ids))
return np.array(res), 2.83
else:
# fail loudly instead of implicitly returning None on unknown paths
raise NotImplementedError(f"no latency config for path: {path}")
def num_trials_for_global_optimal(num_ramps, num_threshold_options):
total_num = 0
for curr_num_ramps in reversed(range(1, num_ramps+1)):
# number of ramp combinations for curr_num_ramps
num_ramp_combinations = math.comb(num_ramps, curr_num_ramps)
# number of threshold combinations for curr_num_ramps
num_threshold_combinations = num_threshold_options ** curr_num_ramps
total_num += num_ramp_combinations * num_threshold_combinations
print(
f"total num trials for optimal grid search (num_ramps {num_ramps}, num_threshold_options {num_threshold_options}): {total_num}")
return total_num
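# Worked example: num_ramps=2, num_threshold_options=3 gives
# C(2,2)*3^2 + C(2,1)*3^1 = 9 + 6 = 15 trials.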
def get_batch(data, batch_size):
"""Get a batch of data.
Args:
data (list): list of data
batch_size (int): batch size
Yields:
list: batch of data
"""
max_idx = len(data['conf'][0])
for i in range(0, max_idx, batch_size):
res = {'conf': [data['conf'][j][i: min(i + batch_size, max_idx)] for j in range(len(data['conf']))],
'acc': [data['acc'][j][i: min(i + batch_size, max_idx)] for j in range(len(data['acc']))]}
yield res
# for i in range(50):
# with open('./pickles/urban-{}_resnet18.pickle'.format(i), 'rb') as f:
# data = pickle.load(f)
# yield data
def parse_profile(profile):
"""Perform intermediate processing to the raw model profile, so that
we don't need to access the profile object every time we calculate
latency savings.
Args:
profile (Profiler.profile): model profile
Returns:
latency_calc_list (list): list (len = num_all_ramps + 1) of tuples.
Index x: vanilla model latency before ramp x and the latency of ramp x.
Last index: (vanilla model latency, None).
"""
# Distiller injection method (sequential execution)
latency_calc_list = []
all_branchpoints = profile.get_all_children_with_name(
"branch_net")
assert all_branchpoints != [], "No branchpoints found!"
for ramp in all_branchpoints:
vanilla_latency_before_ramp = ramp.vanilla_latency_up_until_me # TODO(ruipan): change to me once all profile pickles are unified
# vanilla_latency_before_ramp = ramp.vanilla_latency_after_me # NOTE(ruipan): this is for old versions of profile pickles
ramp_latency = ramp.fwd_latency
latency_calc_list.append(
(vanilla_latency_before_ramp, ramp_latency,))
# add an entry for the vanilla model latency
vanilla_model_latency = profile.vanilla_latency_up_until_me # TODO(ruipan): change to me once all profile pickles are unified
# vanilla_model_latency = profile.vanilla_latency_after_me # NOTE(ruipan): this is for old versions of profile pickles
latency_calc_list.append((vanilla_model_latency, None,))
return latency_calc_list
def get_ramp_latencies(active_ramp_ids: list, latency_calc_list: list):
"""Compute the latency of exiting from different ramp locations
in a numpy array given a ramp configuration.
Args:
active_ramp_ids (list): 0-indexed ramp ids that are currently active
latency_calc_list (list): list of tuples, each tuple contains
the vanilla latency before the ramp and the ramp latency
Returns:
np.array: index x: latency in ms of exiting from the xth ramp.
float: latency in ms of traversing the vanilla model
"""
assert is_monotonic(active_ramp_ids, increasing=True), \
"Ramp IDs in configuration are not in order!"
assert latency_calc_list != [], "latency_calc_list not yet calculated!"
# latency in ms of the vanilla model without any ramps
vanilla_latency = latency_calc_list[-1][0]
latencies = []
for ramp_id in active_ramp_ids:
latencies.append(
latency_calc_list[ramp_id][0] +
# vanilla latency + sum of all ramp latencies prior to current ramp
sum([latency_calc_list[id][1]
for id in active_ramp_ids if id <= ramp_id])
)
# add latency for samples that went through all exits but did not exit
latencies.append(vanilla_latency +
sum(latency_calc_list[id][1] for id in active_ramp_ids))
latencies = np.array(latencies)
return latencies, vanilla_latency
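# Worked example with a hypothetical latency_calc_list
# [(10.0, 1.0), (20.0, 2.0), (40.0, None)] and active_ramp_ids [0, 1]:
# exit at ramp 0: 10 + 1 = 11 ms; exit at ramp 1: 20 + (1 + 2) = 23 ms;
# no early exit: 40 + (1 + 2) = 43 ms. Returns (array([11., 23., 43.]), 40.0).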
def serve_batch(thresholds, batch_data, ramp_ids, latency_calc_list):
"""Given a configuration, return the performance of the model
under this configuration.
Args:
thresholds (list): list of exit thresholds
batch_data (dict): batch of data
ramp_ids (list): list of ramp ids
latency_calc_list (list): list (len = num_all_ramps + 1) of tuples.
Index x: vanilla model latency before ramp x and the latency of ramp x.
Last index: (vanilla model latency, None).
Returns:
acc (float): accuracy
latency_improvement (float): latency improvement
exit_rate (numpy array): exit rate of each ramp
"""
all_ramps_conf = batch_data['conf']
all_ramps_acc = batch_data['acc']
# latency_config, baseline = \
# get_latency_config('resnet18_urban_config', ramp_ids)
latency_config, baseline = get_ramp_latencies(ramp_ids, latency_calc_list)
return query_performance(thresholds, all_ramps_conf, all_ramps_acc,
ramp_ids, latency_config, baseline)
def tune_threshold(ramp_ids, shadow_ramp_idx, data, acc_loss_budget, latency_calc_list, min_step_size=0.01, fixed_ramps_info = {}):
"""Tune the exit threshold for each ramp.
Args:
ramp_ids (list): list of ramp ids
shadow_ramp_idx (int): idx of shadow ramp in ramp_ids
data (dict): data
acc_loss_budget (float): max accuracy loss (compared to the original model output) we can afford
latency_calc_list (list): list (len = num_all_ramps + 1) of tuples.
Index x: vanilla model latency before ramp x and the latency of ramp x.
Last index: (vanilla model latency, None).
min_step_size (float): minimum step sizes
Returns:
thresholds (list): list of exit thresholds
latency_improvement (float): latency improvement
exit_rate (numpy array): exit rate of each ramp
acc (float): accuracy
"""
assert ramp_ids != [], "Empty ramp_ids!"
if shadow_ramp_idx is not None:
active_ramp_ids = ramp_ids[:shadow_ramp_idx] + \
ramp_ids[shadow_ramp_idx + 1:]
else:
active_ramp_ids = ramp_ids
thresholds, latency_improvement, exit_rate, acc = None, float(
"-inf"), None, None
# NOTE: overrides the min_step_size argument
min_step_size = 0.001
for s in [0.01]:
# for s in [0.01, 0.02, 0.04]:
s = round(s, 4)
cur_config, curr_latency_improvement, curr_exit_rates, curr_acc = \
greedy_search_step(active_ramp_ids, min_step_size, s, acc_loss_budget,
data, latency_calc_list, fixed_ramps_info)
if curr_latency_improvement > latency_improvement:
thresholds = cur_config
latency_improvement = curr_latency_improvement
exit_rate = curr_exit_rates
acc = curr_acc
# print("greedy search: ", ramp_ids, thresholds,
# latency_improvement, exit_rate, acc, flush=True)
if shadow_ramp_idx is not None:
thresholds.insert(shadow_ramp_idx, 0.0)
return thresholds, latency_improvement, exit_rate, acc
def greedy_search_step(ramp_ids, min_step_size, step_size, acc_loss_budget, data, latency_calc_list, fixed_ramps_info = {}):
"""Perform greedy search.
Args:
ramp_ids (list): list of ramp ids
min_step_size (float): minimum step sizes
step_size (float): step sizes
acc_loss_budget (float): max accuracy loss (compared to the original model output) we can afford
data (dict): data
latency_calc_list (list): list (len = num_all_ramps + 1) of tuples.
Index x: vanilla model latency before ramp x and the latency of ramp x.
Last index: (vanilla model latency, None).
Returns:
thresholds (list): list of exit thresholds
latency_improvement (float): latency improvement
exit_rate (list): exit rate
acc (float): accuracy
"""
latency_config, baseline = get_ramp_latencies(ramp_ids, latency_calc_list)
step_sizes = [step_size if id not in fixed_ramps_info else min_step_size for id in ramp_ids]
thresholds = [0.0 if id not in fixed_ramps_info else fixed_ramps_info[id] for id in ramp_ids]
acc, latency_improvement, exit_rate = None, None, None
acc, latency_improvement, exit_rate = \
query_performance(
thresholds, data['conf'], data['acc'], ramp_ids, latency_config, baseline)
while True:
next_exit_rate, positive_dirs = None, None
next_direction, next_acc, next_latency_improvement = None, None, None
next_direction, next_acc, next_latency_improvement, next_exit_rate, positive_dirs = \
explore_direction(data, ramp_ids, thresholds, step_sizes, latency_config,
baseline, acc, latency_improvement, exit_rate, acc_loss_budget, fixed_ramps_info)
if next_direction is not None and thresholds[next_direction] <= 1:
acc = next_acc
latency_improvement = next_latency_improvement
exit_rate = next_exit_rate
thresholds[next_direction] =\
round(thresholds[next_direction] +
step_sizes[next_direction], 4)
step_sizes[next_direction] *= 2
for i in positive_dirs:
if i != next_direction:
step_sizes[i] *= 2
else:
flag = True
for i in range(len(step_sizes)):
if round(step_sizes[i], 4) <= min_step_size \
or thresholds[i] > 1:
continue
else:
flag = False
step_sizes[i] /= 2
if flag:
break
return thresholds, latency_improvement, exit_rate, acc
def explore_direction(data, ramp_ids, thresholds, step_sizes, latency_config,
baseline, curr_acc, curr_latency_improvement, curr_exit_rate,
acc_loss_budget, fixed_ramps_info={}):
"""Explore the direction of the next step.
Args:
data (dict): data
ramp_ids (list): list of ramp ids
thresholds (list): list of exit thresholds
step_sizes (list): list of step sizes
latency_config (list): list of latencies
baseline (float): baseline latency
curr_acc (float): current accuracy
curr_latency_improvement (float): current latency improvement
curr_exit_rate (list): current exit rate
acc_loss_budget (float): max accuracy loss (compared to the original model output) we can afford
Returns:
best_direction (int): best direction
res_acc (float): accuracy in best direction
res_latency_improvement (float): latency improvement in best direction
res_exit_rate (list): exit rate in best direction
positive_dirs (list): list of directions that have positive improvement
"""
best_direction = None
best_score = float("inf")
res_acc = None
res_latency_improvement = None
res_exit_rate = None
equal_num = 0
positive_dirs = []
positive_dirs_data = []
for direction in range(len(ramp_ids)):
if ramp_ids[direction] in fixed_ramps_info:
continue
temp_config = copy.deepcopy(thresholds)
temp_config[direction] = round(
temp_config[direction] + step_sizes[direction], 4)
temp_acc, temp_latency_improvement, temp_exit_rate = \
query_performance(
temp_config, data['conf'], data['acc'], ramp_ids, latency_config, baseline)
if abs(1 - temp_acc) < acc_loss_budget:
if temp_latency_improvement != curr_latency_improvement:
score = abs(temp_acc - curr_acc) / \
abs(temp_latency_improvement - curr_latency_improvement)
if score < best_score:
best_score = score
best_direction = direction
res_acc = temp_acc
res_exit_rate = temp_exit_rate
res_latency_improvement = temp_latency_improvement
else:
equal_num += 1
if temp_latency_improvement == curr_latency_improvement or \
temp_acc == curr_acc:
positive_dirs += [direction]
positive_dirs_data += [[temp_acc,
temp_latency_improvement, temp_exit_rate]]
if equal_num == len(ramp_ids):
return 0, curr_acc, curr_latency_improvement, curr_exit_rate, positive_dirs
# best_direction can legitimately be 0 (the first ramp), so compare
# against None instead of relying on truthiness
if best_direction is None and len(positive_dirs) > 0:
return positive_dirs[0], positive_dirs_data[0][0], \
positive_dirs_data[0][1], positive_dirs_data[0][2], positive_dirs
return best_direction, res_acc, \
res_latency_improvement, res_exit_rate, positive_dirs
def earlyexit_infer_per_sample(output, target, ramp_ids, thresholds, total_num_ramps, queuing_delay, ramp_latencies, optimal=False, simulated_pickle=None):
"""
Early exit inference per sample.
Args:
output (torch.Tensor): output tensor consisting of multiple exit outputs and final output
target (torch.Tensor): model prediction
ramp_ids (list): list of activated ramp ids
thresholds (list): list of exit thresholds
total_num_ramps (int): total possible exits including final exit
queuing_delay (list of float): queuing delay for each data sample in the batch
ramp_latencies (list of float): ramp latencies for each ramp including final exit
optimal (bool): whether to use optimal early exit, i.e., only exit when the prediction is correct
Returns:
batch_meta_data (dict): per-ramp confidence/accuracy for historical data updates
sample_latencies (list of float): latencies for each data sample (inference + queuing delay)
sample_acc (list of bool): prediction correctness for each data sample
sample_exit_points (list of int): exit ramp id for each data sample
"""
res_conf = [[] for _ in range(total_num_ramps)]
res_acc = [[] for _ in range(total_num_ramps)]
if simulated_pickle is None:
this_batch_size = target.size(0)
softmax = nn.Softmax(dim=1)
# calculate confidence and accuracy for each ramp
for exitnum in range(len(ramp_ids)):
out = softmax(output[exitnum])
out, inds = torch.max(out, dim=1)
res_conf[ramp_ids[exitnum]] += out.cpu().tolist()
res_acc[ramp_ids[exitnum]] += (inds == target).cpu().tolist()
res_conf[-1] += output[-1].cpu().tolist()
res_acc[-1] += [True for _ in range(this_batch_size)]
else:
this_batch_size = len(simulated_pickle["conf"][0])
for ramp_id in ramp_ids:
res_conf[ramp_id] = simulated_pickle["conf"][ramp_id]
res_acc[ramp_id] = simulated_pickle["acc"][ramp_id]
sample_latencies = []
sample_acc = []
sample_exit_points = []
for i in range(this_batch_size):
earlyexit = False
for j in range(len(ramp_ids)):
if not optimal:
ramp_id = ramp_ids[j]
if 1 - res_conf[ramp_id][i] < thresholds[j]:
sample_latencies += [(queuing_delay[i],
ramp_latencies[j],)]
sample_acc += [res_acc[ramp_id][i]]
sample_exit_points += [ramp_id]
earlyexit = True
break
else:
ramp_id = ramp_ids[j]
if res_acc[ramp_id][i]:
# if 1 - res_conf[ramp_id][i] < thresholds[j] and res_acc[ramp_id][i]:
sample_latencies += [(queuing_delay[i],
ramp_latencies[j],)]
sample_acc += [res_acc[ramp_id][i]]
sample_exit_points += [ramp_id]
earlyexit = True
break
if not earlyexit:
sample_latencies += [(queuing_delay[i], ramp_latencies[-1],)]
sample_exit_points += [total_num_ramps - 1]
sample_acc += [True]
batch_meta_data = {"conf": res_conf, "acc": res_acc}
return batch_meta_data, sample_latencies, sample_acc, sample_exit_points
def earlyexit_inference(output, target, ramp_ids, thresholds, total_num_ramps):
"""
Early exit inference.
Args:
output (torch.Tensor): output tensor consisting of multiple exit outputs and final output
target (torch.Tensor): model prediction
ramp_ids (list): list of activated ramp ids
thresholds (list): list of exit thresholds
total_num_ramps (int): total possible exits including final exit
Returns:
batch_meta_data (dict): per-ramp confidence/accuracy for historical data updates
exit_rate (numpy array): exit rate for each ramp
"""
this_batch_size = target.size(0)
softmax = nn.Softmax(dim=1)
num_exits = len(ramp_ids) + 1
res_conf = [[] for _ in range(total_num_ramps)]
res_acc = [[] for _ in range(total_num_ramps)]
# calculate confidence and accuracy for each ramp
for exitnum in range(len(ramp_ids)):
out = softmax(output[exitnum])
out, inds = torch.max(out, dim=1)
res_conf[ramp_ids[exitnum]] += out.cpu().tolist()
res_acc[ramp_ids[exitnum]] += (inds == target).cpu().tolist()
# calculate confidence and accuracy for final exit
res_conf[-1] += output[-1].cpu().tolist()
res_acc[-1] += [True for _ in range(this_batch_size)]
exit_counter = [0 for i in range(len(ramp_ids) + 1)]
for batch_index in range(this_batch_size):
earlyexit_taken = False
# take the exit using CrossEntropyLoss as confidence measure (lower is more confident)
for exitnum in range(num_exits - 1):
if 1 - res_conf[ramp_ids[exitnum]][batch_index] < thresholds[exitnum]:
# take the results from early exit since lower than threshold
exit_counter[exitnum] += 1
earlyexit_taken = True
break
if not earlyexit_taken:
exit_counter[-1] += 1
exit_rate = np.array([(n+0.0)/this_batch_size for n in exit_counter])
batch_meta_data = {"conf": res_conf, "acc": res_acc}
return batch_meta_data, exit_rate
def get_queuing_delay(request_rate, batch_size):
"""
Get queuing delay.
Args:
request_rate (float): requests per second
batch_size (int): number of samples in the batch
Returns:
queuing_delay (list): list of queuing delays for each sample in the batch (ms)
"""
# time gap between two requests (ms)
gap = 1.0 / request_rate * 1000
queuing_delay = []
for i in range(batch_size):
queuing_delay.append(gap * i)
return queuing_delay[::-1]
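# Illustrative example: at 100 requests/s the inter-arrival gap is 10 ms, so
# for a batch of 4 the per-sample queuing delays are [30.0, 20.0, 10.0, 0.0] ms
# (the earliest-arriving sample waits the longest).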
def calculate_batch_size(request_rate, model_inference_time, slo):
"""
Calculate batch size based on request rate and latency.
Args:
request_rate (float): requests per second
model_inference_time (float): model inference time (ms)
slo (float): slo for request (ms)
Returns:
batch_size (int): batch size that can satisfy the slo
based on the given request rate
"""
batch_size = int((slo - model_inference_time) * request_rate / 1000.0)
return batch_size
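# Worked example: with a 100 ms SLO, 20 ms inference time, and 50 requests/s,
# batch_size = int((100 - 20) * 50 / 1000) = int(4.0) = 4.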
def get_batch_perf(sample_latencies, sample_acc, sample_exit_points, vanilla_latency, ramp_ids, total_num_ramps):
"""
Query batch performance.
Args:
sample_latencies (list): list of tuples of (queuing_delay, ramp_latency)
sample_acc (list): list of bool of accuracy
sample_exit_points (list): list of int of exit points
vanilla_latency (float): latency of vanilla model
ramp_ids (list): list of activated ramp ids
total_num_ramps (int): total number of ramps
Returns:
acc (float): batch accuracy
latency_improvement (float): batch latency improvement
exit_rate (list): list of exit rate for each ramp
"""
acc = sum(sample_acc) / float(len(sample_acc))
curr_serving_latencies = [l[1] for l in sample_latencies]
latency_improvement = 100 * \
(vanilla_latency - sum(curr_serving_latencies) /
len(curr_serving_latencies)) / vanilla_latency
num_exit = [0 for _ in range(len(ramp_ids) + 1)]
for i, exit_point in enumerate(sample_exit_points):
for j, ramp_id in enumerate(ramp_ids):
if exit_point == ramp_id: