#include "bucket.h"
#include "sketch.h"
#include <cassert>
#include <cmath>
#include <cstdint>
#include <span>
#include <vector>

enum RecoveryResultTypes {
  SUCCESS,
  FAILURE
};
struct RecoveryResult {
  RecoveryResultTypes result;
  std::vector<vec_t> recovered_indices;
};

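// SparseRecovery: recovers up to max_recovery_size nonzero indices from a stream of
// XOR-toggle updates. It keeps a sequence of geometrically shrinking bucket levels
// (the "CFRs"), a cleanup Sketch for whatever the levels fail to peel, and a single
// deterministic bucket used to detect when everything has been recovered.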
class SparseRecovery {
  private:
    size_t universe_size;
    size_t max_recovery_size;
    size_t cleanup_sketch_support;
    // Reduction factor between consecutive CFR levels. 0.82 would be roughly 1 - 1/(2e);
    // TODO - can do better, closer to 1 - 1/e with the right bounding parameters.
    // With the power-of-two rounding below we use 0.69 instead: comfortably below
    // 1/sqrt(2), so the rounded size shrinks at least every two levels.
    static constexpr double reduction_factor = 0.69;
    uint64_t checksum_seed;
    uint64_t seed;
    // Flat storage for all CFR levels (a single array for locality); starter_indices[i]
    // is the offset of level i, with one extra trailing entry holding the total size.
    std::vector<Bucket> recovery_buckets;
    std::vector<size_t> starter_indices;
    Sketch cleanup_sketch;
    // TODO - see if we want to continue maintaining the deterministic bucket
    Bucket deterministic_bucket;
  public:
    SparseRecovery(size_t universe_size, size_t max_recovery_size, double cleanup_sketch_support_factor, uint64_t seed):
      // TODO - ugly constructor
      cleanup_sketch(universe_size, seed, ceil(cleanup_sketch_support_factor * log2(universe_size)) * 2, 1)
    {
      // TODO - define the seed better
      checksum_seed = seed;
      this->seed = seed * seed + 13;
      // the constructor parameters shadow the members, so assign through this->
      this->universe_size = universe_size;
      this->max_recovery_size = max_recovery_size;
      starter_indices.reserve(2 + ceil(log2(universe_size) - log2(log2(cleanup_sketch_support_factor * universe_size))));
      starter_indices.push_back(0);
      cleanup_sketch_support = ceil(cleanup_sketch_support_factor * log2(universe_size));
      size_t current_cfr_size = max_recovery_size;
      size_t current_cfr_idx = 0;
      while (current_cfr_size > cleanup_sketch_support) {
        // Round each level's size up to a power of two -- important for keeping the
        // bucket-index distribution uniform. TODO - examine whether something else is better.
        size_t power_of_two_rounded_size = (size_t) 1 << (size_t) ceil(log2(current_cfr_size));
        auto current_start_idx = starter_indices[current_cfr_idx++] + power_of_two_rounded_size;
        starter_indices.push_back(current_start_idx);
        current_cfr_size = ceil(current_cfr_size * reduction_factor);
      }
      // Keep the trailing entry of starter_indices: it records the total storage size
      // and lets get_cfr_size() work for the last level.
      auto full_storage_size = starter_indices.back();
      recovery_buckets.resize(full_storage_size);
    };
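    // For intuition (illustrative numbers only): with max_recovery_size = 1024 and
    // reduction_factor = 0.69, the raw level sizes run 1024, 707, 488, 337, 233, 161, ...
    // and after power-of-two rounding become 1024, 1024, 512, 512, 256, 256, ...
    // so the allocated size halves every two levels.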
  private:
    // Number of CFR levels (starter_indices carries one extra trailing entry).
    size_t num_levels() const {
      return starter_indices.size() - 1;
    }
    // Number of buckets in the given level.
    size_t get_cfr_size(size_t level) const {
      assert(level < starter_indices.size() - 1);
      return starter_indices[level+1] - starter_indices[level];
    }
    // Bucket `col` of level `row` inside the flat bucket array.
    Bucket& get_cfr_bucket(size_t row, size_t col) {
      size_t cfr_start_idx = starter_indices[row];
      return recovery_buckets[cfr_start_idx + col];
    }
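    // Layout illustration (hypothetical level sizes 8, 8, 4): starter_indices = {0, 8, 16, 20},
    // so level 2 occupies recovery_buckets[16..19] and get_cfr_size(2) == 4.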

  public:
    inline uint64_t get_seed() const { return seed; }
    // Per-level seed used to pick which bucket an index hashes to within each CFR level.
    inline uint64_t level_seed(size_t level) const {
      return seed * (2 + seed) + level * 30;
    }
    inline uint64_t get_checksum_seed() const { return checksum_seed; }
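    // Feed one update into the structure. Updates are XOR-toggles: applying the same
    // index twice cancels out, which recover() relies on to remove (and restore) items.
    // Each call touches one bucket per CFR level, the cleanup sketch, and the
    // deterministic bucket.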
    void update(const vec_t update) {
      // Bucket_Boruvka::update expects the index's checksum hash while is_good expects
      // the seed, so both must be derived from the same checksum seed.
      vec_hash_t checksum = Bucket_Boruvka::get_index_hash(update, get_checksum_seed());
      for (size_t cfr_idx=0; cfr_idx < num_levels(); cfr_idx++) {
        size_t hash_index = Bucket_Boruvka::get_index_hash(update, level_seed(cfr_idx)) % get_cfr_size(cfr_idx);
        Bucket_Boruvka::update(get_cfr_bucket(cfr_idx, hash_index), update, checksum);
      }
      cleanup_sketch.update(update);
      Bucket_Boruvka::update(deterministic_bucket, update, checksum);
    }
    void reset() {
      // zero contents of the CFRs, the deterministic bucket, and the cleanup sketch
      for (size_t i=0; i < recovery_buckets.size(); i++) {
        recovery_buckets[i] = {0, 0};
      }
      deterministic_bucket = {0, 0};
      cleanup_sketch.zero_contents();
    };
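    // Attempt to recover every index currently in the structure.
    // Strategy: peel "good" buckets level by level (each recovered index is XOR-updated
    // back into the structure, which removes it everywhere), stopping early once the
    // deterministic bucket is empty; then fall back on the cleanup sketch for anything
    // the levels could not peel. On FAILURE the indices pulled out so far are toggled
    // back in and an empty list is returned.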
    // NOTE: on SUCCESS this is a destructive operation (recovered indices are removed).
    RecoveryResult recover() {
      std::vector<vec_t> recovered_indices;
      for (size_t cfr_idx=0; cfr_idx < num_levels(); cfr_idx++) {
        // go hunting for good buckets
        auto cfr_size = get_cfr_size(cfr_idx);
        for (size_t bucket_idx=0; bucket_idx < cfr_size; bucket_idx++) {
          Bucket &bucket = get_cfr_bucket(cfr_idx, bucket_idx);
          if (Bucket_Boruvka::is_good(bucket, get_checksum_seed())) {
            recovered_indices.push_back(bucket.alpha);
            // update it out of the structure everywhere (XOR semantics make this a removal)
            this->update(bucket.alpha);

            // EARLY EXIT CONDITION: deterministic bucket empty, so everything is recovered
            if (Bucket_Boruvka::is_empty(deterministic_bucket)) {
              return {SUCCESS, recovered_indices};
            }
          }
        }
        // move on to the next (smaller) CFR level and keep peeling
      }
      // Fall back on the cleanup sketch for whatever the CFR levels could not peel.
      size_t i=0;
      for (; i < cleanup_sketch.get_num_samples(); i++) {
        ExhaustiveSketchSample sample = cleanup_sketch.exhaustive_sample();
        if (sample.result == ZERO) {
          // the cleanup sketch is empty, so everything has been recovered
          return {SUCCESS, recovered_indices};
        }
        for (auto idx: sample.idxs) {
          // TODO - checksum stuff. This is bad code writing, but whatever, anything
          // to get out of writing pseudocode...
          recovered_indices.push_back(idx);
          // TODO - this is inefficient: we recalculate the bucket hashes for no reason.
          // But going through this->update() matters for undoing our recovery;
          // otherwise we are stuck with a bunch of extra bookkeeping.
          this->update(idx);
        }
      }
      if (i == cleanup_sketch.get_num_samples()) {
        // We ran out of samples: undo the recovery by toggling back everything we pulled out.
        for (auto idx: recovered_indices) {
          this->update(idx);
        }
        recovered_indices.clear();
      }
      return {FAILURE, recovered_indices};
    };
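    // Combine another SparseRecovery into this one. Both structures must have been built
    // with the same parameters and seed so that their levels and hashes line up; bucket
    // contents combine by XOR, as do the underlying sketches.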
    void merge(const SparseRecovery &other) {
      assert(other.recovery_buckets.size() == recovery_buckets.size());
      for (size_t i=0; i < recovery_buckets.size(); i++) {
        recovery_buckets[i] ^= other.recovery_buckets[i];
      }
      // keep the deterministic bucket in sync as well, since update() maintains it
      deterministic_bucket ^= other.deterministic_bucket;
      cleanup_sketch.merge(other.cleanup_sketch);
    };
    ~SparseRecovery();
};
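
// Example usage (illustrative only; parameter values are arbitrary):
//
//   SparseRecovery recovery(/* universe_size */ 1 << 20,
//                           /* max_recovery_size */ 1024,
//                           /* cleanup_sketch_support_factor */ 2.0,
//                           /* seed */ 0xabcd1234);
//   recovery.update(42);   // XOR-toggle index 42 into the structure
//   recovery.update(7);
//   RecoveryResult res = recovery.recover();
//   // with high probability res.result == SUCCESS and
//   // res.recovered_indices contains 42 and 7 (in some order)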