Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add lazy propagation to treaps #262

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 50 additions & 42 deletions content/data-structures/Treap.h
Original file line number Diff line number Diff line change
@@ -1,67 +1,75 @@
/**
* Author: someone on Codeforces
* Date: 2017-03-14
* Source: folklore
* Author: Unknown
* Date: Unknown
* Source: https://cp-algorithms.com/data_structures/treap.html
* Description: A short self-balancing tree. It acts as a
* sequential container with log-time splits/joins, and
* is easy to augment with additional data.
* Time: $O(\log N)$
* Status: stress-tested
* is easy to augment with additional data. Also supports
* range updates.
* Time: $O(\log N)$ per split/merge
* Status: stress-tested, tested on cses Substring Reversals and Cut and Paste
*/
#pragma once

struct Node {
Node *l = 0, *r = 0;
int val, y, c = 1;
int val, y, c = 1, rev = 0;
Node(int val) : val(val), y(rand()) {}
void recalc();
};

int cnt(Node* n) { return n ? n->c : 0; }
void Node::recalc() { c = cnt(l) + cnt(r) + 1; }
void pull(Node* n) { if (n) n->c = cnt(n->l) + cnt(n->r)+1; }
void push(Node* n) {
if (!n) return;
if (n->rev) {
swap(n->l, n->r);
if (n->l) n->l->rev ^= 1;
if (n->r) n->r->rev ^= 1;
n->rev = 0;
}
}

template<class F> void each(Node* n, F f) {
if (n) { each(n->l, f); f(n->val); each(n->r, f); }
if (n) { push(n);each(n->l, f);f(n->val);each(n->r, f); }
}

pair<Node*, Node*> split(Node* n, int k) {
if (!n) return {};
if (cnt(n->l) >= k) { // "n->val >= k" for lower_bound(k)
auto pa = split(n->l, k);
n->l = pa.second;
n->recalc();
return {pa.first, n};
} else {
auto pa = split(n->r, k - cnt(n->l) - 1); // and just "k"
n->r = pa.first;
n->recalc();
return {n, pa.second};
}
// Put i first nodes into l, the rest into r
void split(Node* x, Node*& l, Node*& r, int i) {
if (!x) return void(l = r = 0);
push(x);
// replace cnt(x->l) with x->val for lower_bound(i)
if (i <= cnt(x->l)) split(x->l, l, x->l, i), r = x;
// and just i instead
else split(x->r, x->r, r, i - cnt(x->l) - 1), l = x;
pull(x);
}

Node* merge(Node* l, Node* r) {
if (!l) return r;
if (!r) return l;
if (l->y > r->y) {
l->r = merge(l->r, r);
l->recalc();
return l;
} else {
r->l = merge(l, r->l);
r->recalc();
return r;
}
// Append r to l, store it in x
void merge(Node*& x, Node* l, Node* r) {
push(l), push(r);
if (!l || !r) x = l ? l : r;
else if (l->y < r->y) merge(r->l, l, r->l), x = r;
else merge(l->r, l->r, r), x = l;
pull(x);
}

Node* ins(Node* t, Node* n, int pos) {
auto pa = split(t, pos);
return merge(merge(pa.first, n), pa.second);
void insert(Node*& t, Node* n, int pos) {
Node* l, * r;
split(t, l, r, pos), merge(l, l, n), merge(t, l, r);
}

// Example application: move the range [l, r) to index k
void move(Node*& t, int l, int r, int k) {
Node *a, *b, *c;
tie(a,b) = split(t, l); tie(b,c) = split(b, r - l);
if (k <= l) t = merge(ins(a, b, k), c);
else t = merge(a, ins(c, b, k - r));
Node* a, * b, * c;
split(t, a, c, r), split(a, a, b, l), merge(t, a, c);
if (k<=l) insert(t, b, k);
else insert(t, b, k - r + l);
}

// Reverse the range [l, r)
void rev(Node*& t, int l, int r) {
Node* a, * b, * c;
split(t, a, c, r), split(a, a, b, l);
b->rev ^= 1;
merge(a, a, b), merge(t, a, c);
}
75 changes: 47 additions & 28 deletions stress-tests/data-structures/Treap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,63 +2,57 @@

#include "../../content/data-structures/Treap.h"

pair<Node*, Node*> split2(Node* n, int v) {
if (!n) return {};
if (n->val >= v) {
auto pa = split2(n->l, v);
n->l = pa.second;
n->recalc();
return {pa.first, n};
} else {
auto pa = split2(n->r, v);
n->r = pa.first;
n->recalc();
return {n, pa.second};
}
// l will have nodes <= i, r rest
void split2(Node* x, Node*& l, Node*& r, int i) {
if (!x) return void(l = r = 0);
push(x);
if (i <= x->val) split2(x->l, l, x->l, i), r = x;
else split2(x->r, x->r, r, i), l = x;
pull(x);
}

int ra() {
static unsigned x;
x *= 4176481;
x += 193861934;
return x >> 1;
mt19937 rng(10);
int ra(int hi) {
return uniform_int_distribution<int>(0, hi)(rng);
}

int main() {
srand(3);
// Treaps as sets
rep(it,0,1000) {
vector<Node> nodes;
vi exp;
rep(i,0,10) {
nodes.emplace_back(i*2+2);
exp.emplace_back(i*2+2);
}
Node* n = 0;
rep(i,0,10)
n = merge(n, &nodes[i]);
Node* root = 0;
rep(i,0,10) merge(root, root, &nodes[i]);

int v = rand() % 25;
int left = cnt(split2(n, v).first);
Node *d1, *d2;
split2(root, d1, d2, v);
int left = cnt(d1);
int rleft = (int)(lower_bound(all(exp), v) - exp.begin());
assert(left == rleft);
}

// move range
rep(it,0,10000) {
vector<Node> nodes;
vi exp;
rep(i,0,10) nodes.emplace_back(i);
rep(i,0,10) exp.emplace_back(i);
Node* n = 0;
rep(i,0,10)
n = merge(n, &nodes[i]);
merge(n, n, &nodes[i]);

int i = ra() % 11, j = ra() % 11;
int i = ra(10), j = ra(10);
if (i > j) swap(i, j);
int k = ra() % 11;
int k = ra(10);
if (i < k && k < j) continue;

move(n, i, j, k);
// cerr << i << ' ' << j << ' ' << k << endl;

int nk = (k >= j ? k - (j - i) : k);
vi iv(exp.begin() + i, exp.begin() + j);
Expand All @@ -67,10 +61,35 @@ int main() {

int ind = 0;
each(n, [&](int x) {
// cerr << x << ' ';
assert(x == exp[ind++]);
});
// cerr << endl;
}

// reverse range
rep(it,0,10000) {
vector<Node> nodes;
vi exp;
rep(i,0,10) nodes.emplace_back(i);
rep(i,0,10) exp.emplace_back(i);
Node* n = 0;
rep(i,0,10)
merge(n, n, &nodes[i]);

int rounds = ra(10);
// do multiple rounds to try and break lazy propagation
rep(i,0,rounds) {
int l = ra(9);
int r = ra(9);
if (r<l) swap(l,r);
if (l==r) r++;
reverse(exp.begin()+l, exp.begin()+r);
rev(n, l, r);
}

int ind = 0;
each(n, [&](int x) {
assert(x == exp[ind++]);
});
}
cout<<"Tests passed!"<<endl;
}
Loading