Skip to content

Commit 4f33036

Browse files
committed
add Extra.DsuMonoid
1 parent acb6bf3 commit 4f33036

File tree

4 files changed

+313
-0
lines changed

4 files changed

+313
-0
lines changed

ac-library-hs.cabal

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ library
5656
AtCoder.Dsu
5757
AtCoder.Extra.AhoCorasick
5858
AtCoder.Extra.Bisect
59+
AtCoder.Extra.DsuMonoid
5960
AtCoder.Extra.DynLazySegTree
6061
AtCoder.Extra.DynLazySegTree.Persistent
6162
AtCoder.Extra.DynLazySegTree.Raw

src/AtCoder/Extra/DsuMonoid.hs

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
{-# LANGUAGE RecordWildCards #-}
2+
3+
-- | A disjoint set union with commutative monoid values associated with each group.
4+
--
5+
-- ==== __Example__
6+
--
7+
-- >>> import AtCoder.Extra.DsuMonoid qualified as Dm
8+
-- >>> import Data.Semigroup (Sum (..))
9+
-- >>> import Data.Vector.Unboxed qualified as VU
10+
-- >>> dsu <- Dm.build $ VU.generate 4 Sum
11+
-- >>> Dm.merge dsu 0 1
12+
-- 0
13+
--
14+
-- >>> Dm.read dsu 0
15+
-- Sum {getSum = 1}
16+
--
17+
-- >>> Dm.read dsu 1
18+
-- Sum {getSum = 1}
19+
--
20+
-- >>> Dm.mergeMaybe dsu 0 2
21+
-- Just 0
22+
--
23+
-- >>> Dm.read dsu 0
24+
-- Sum {getSum = 3}
25+
--
26+
-- @since 1.5.3.0
27+
module AtCoder.Extra.DsuMonoid
28+
( -- * Disjoint set union
29+
DsuMonoid (dsuDm, mDm),
30+
31+
-- * Constructors
32+
new,
33+
build,
34+
35+
-- * Merging
36+
merge,
37+
mergeMaybe,
38+
merge_,
39+
40+
-- * Leader
41+
leader,
42+
43+
-- * Component information
44+
same,
45+
size,
46+
groups,
47+
48+
-- * Monoid values
49+
read,
50+
unsafeRead,
51+
unsafeWrite,
52+
)
53+
where
54+
55+
import AtCoder.Dsu qualified as Dsu
56+
import Control.Monad.Primitive (PrimMonad, PrimState, stToPrim)
57+
import Data.Vector qualified as V
58+
import Data.Vector.Generic.Mutable qualified as VGM
59+
import Data.Vector.Unboxed qualified as VU
60+
import Data.Vector.Unboxed.Mutable qualified as VUM
61+
import GHC.Stack (HasCallStack)
62+
import Prelude hiding (read)
63+
64+
-- | A disjoint set union with commutative monoid values associated with each group.
65+
--
66+
-- @since 1.5.3.0
67+
data DsuMonoid s a = DsuMonoid
68+
{ -- | The original DSU.
69+
--
70+
-- @since 1.5.3.0
71+
dsuDm :: {-# UNPACK #-} !(Dsu.Dsu s),
72+
-- | Commutative monoid values for each group.
73+
--
74+
-- @since 1.5.3.0
75+
mDm :: !(VUM.MVector s a)
76+
}
77+
78+
-- | Creates an undirected graph with \(n\) vertices and \(0\) edges.
79+
--
80+
-- ==== Constraints
81+
-- - \(0 \le n\)
82+
--
83+
-- ==== Complexity
84+
-- - \(O(n)\)
85+
--
86+
-- @since 1.5.3.0
87+
{-# INLINE new #-}
88+
new :: (PrimMonad m, Monoid a, VU.Unbox a) => Int -> m (DsuMonoid (PrimState m) a)
89+
new n
90+
| n >= 0 = build $ VU.replicate n mempty
91+
| otherwise = error $ "AtCoder.Extra.DsuMonoid: given negative size (`" ++ show n ++ "`)"
92+
93+
-- | Creates an undirected graph with \(n\) vertices and \(0\) edges.
94+
--
95+
-- ==== Constraints
96+
-- - \(0 \le n\)
97+
--
98+
-- ==== Complexity
99+
-- - \(O(n)\)
100+
--
101+
-- @since 1.5.3.0
102+
{-# INLINE build #-}
103+
build :: (PrimMonad m, Semigroup a, VU.Unbox a) => VU.Vector a -> m (DsuMonoid (PrimState m) a)
104+
build ms = stToPrim $ do
105+
dsuDm <- Dsu.new $ VU.length ms
106+
mDm <- VU.thaw ms
107+
pure $ DsuMonoid {..}
108+
109+
-- | Adds an edge \((a, b)\). If the vertices \(a\) and \(b\) are in the same connected component, it
110+
-- returns the representative (`leader`) of this connected component. Otherwise, it returns the
111+
-- representative of the new connected component.
112+
--
113+
-- ==== Constraints
114+
-- - \(0 \leq a < n\)
115+
-- - \(0 \leq b < n\)
116+
--
117+
-- ==== Complexity
118+
-- - \(O(\alpha(n))\) amortized
119+
--
120+
-- @since 1.5.3.0
121+
{-# INLINEABLE merge #-}
122+
merge :: (HasCallStack, PrimMonad m, Semigroup a, VU.Unbox a) => DsuMonoid (PrimState m) a -> Int -> Int -> m Int
123+
merge DsuMonoid {..} a b = stToPrim $ do
124+
r1 <- Dsu.leader dsuDm a
125+
r2 <- Dsu.leader dsuDm b
126+
if r1 == r2
127+
then pure r1
128+
else do
129+
!m1 <- VGM.read mDm r1
130+
!m2 <- VGM.read mDm r2
131+
r' <- Dsu.merge dsuDm a b
132+
VGM.write mDm r' $! m1 <> m2
133+
pure r'
134+
135+
-- | Adds an edge \((a, b)\). It returns the representative of the new connected component, or
136+
-- `Nothing` if the two vertices are in the same connected component.
137+
--
138+
-- ==== Constraints
139+
-- - \(0 \leq a < n\)
140+
-- - \(0 \leq b < n\)
141+
--
142+
-- ==== Complexity
143+
-- - \(O(\alpha(n))\) amortized
144+
--
145+
-- @since 1.2.4.0
146+
{-# INLINEABLE mergeMaybe #-}
147+
mergeMaybe :: (HasCallStack, PrimMonad m, Semigroup a, VU.Unbox a) => DsuMonoid (PrimState m) a -> Int -> Int -> m (Maybe Int)
148+
mergeMaybe DsuMonoid {..} a b = stToPrim $ do
149+
r1 <- Dsu.leader dsuDm a
150+
r2 <- Dsu.leader dsuDm b
151+
if r1 == r2
152+
then pure Nothing
153+
else do
154+
!m1 <- VGM.read mDm r1
155+
!m2 <- VGM.read mDm r2
156+
r' <- Dsu.merge dsuDm a b
157+
VGM.write mDm r' $! m1 <> m2
158+
pure $ Just r'
159+
160+
-- | `merge` with the return value discarded.
161+
--
162+
-- ==== Constraints
163+
-- - \(0 \leq a < n\)
164+
-- - \(0 \leq b < n\)
165+
--
166+
-- ==== Complexity
167+
-- - \(O(\alpha(n))\) amortized
168+
--
169+
-- @since 1.5.3.0
170+
{-# INLINE merge_ #-}
171+
merge_ :: (PrimMonad m, Semigroup a, VU.Unbox a) => DsuMonoid (PrimState m) a -> Int -> Int -> m ()
172+
merge_ dsu a b = do
173+
_ <- merge dsu a b
174+
pure ()
175+
176+
-- | Returns whether the vertices \(a\) and \(b\) are in the same connected component.
177+
--
178+
-- ==== Constraints
179+
-- - \(0 \leq a < n\)
180+
-- - \(0 \leq b < n\)
181+
--
182+
-- ==== Complexity
183+
-- - \(O(\alpha(n))\) amortized
184+
--
185+
-- @since 1.5.3.0
186+
{-# INLINE same #-}
187+
same :: (HasCallStack, PrimMonad m) => DsuMonoid (PrimState m) a -> Int -> Int -> m Bool
188+
same dsu = Dsu.same (dsuDm dsu)
189+
190+
-- | Returns the representative of the connected component that contains the vertex \(a\).
191+
--
192+
-- ==== Constraints
193+
-- - \(0 \leq a \lt n\)
194+
--
195+
-- ==== Complexity
196+
-- - \(O(\alpha(n))\) amortized
197+
--
198+
-- @since 1.5.3.0
199+
{-# INLINE leader #-}
200+
leader :: (HasCallStack, PrimMonad m) => DsuMonoid (PrimState m) a -> Int -> m Int
201+
leader dsu = Dsu.leader (dsuDm dsu)
202+
203+
-- | Returns the size of the connected component that contains the vertex \(a\).
204+
--
205+
-- ==== Constraints
206+
-- - \(0 \leq a < n\)
207+
--
208+
-- ==== Complexity
209+
-- - \(O(\alpha(n))\)
210+
--
211+
-- @since 1.5.3.0
212+
{-# INLINE size #-}
213+
size :: (HasCallStack, PrimMonad m) => DsuMonoid (PrimState m) a -> Int -> m Int
214+
size dsu = Dsu.size (dsuDm dsu)
215+
216+
-- | \O(n)\) Divides the graph into connected components and returns the vector of them.
217+
--
218+
-- More precisely, it returns a vector of the "vector of the vertices in a connected component".
219+
-- Both of the orders of the connected components and the vertices are undefined.
220+
--
221+
-- @since 1.5.3.0
222+
{-# INLINE groups #-}
223+
groups :: (PrimMonad m) => DsuMonoid (PrimState m) a -> m (V.Vector (VU.Vector Int))
224+
groups dsu = Dsu.groups (dsuDm dsu)
225+
226+
-- | \(O(1)\) Reads the group value of the \(k\)-th node.
227+
--
228+
-- @since 1.5.3.0
229+
{-# INLINE read #-}
230+
read :: (PrimMonad m, VU.Unbox a) => DsuMonoid (PrimState m) a -> Int -> m a
231+
read DsuMonoid {..} i = do
232+
VGM.read mDm =<< Dsu.leader dsuDm i
233+
234+
-- | \(O(1)\) Reads the group value of the \(k\)-th node.
235+
--
236+
-- @since 1.5.3.0
237+
{-# INLINE unsafeRead #-}
238+
unsafeRead :: (PrimMonad m, VU.Unbox a) => DsuMonoid (PrimState m) a -> Int -> m a
239+
unsafeRead DsuMonoid {..} i = do
240+
VGM.read mDm i
241+
242+
-- | \(O(1)\) Writes to the group value of the \(k\)-th node.
243+
--
244+
-- @since 1.5.3.0
245+
{-# INLINE unsafeWrite #-}
246+
unsafeWrite :: (PrimMonad m, VU.Unbox a) => DsuMonoid (PrimState m) a -> Int -> a -> m ()
247+
unsafeWrite DsuMonoid {..} i x = do
248+
VGM.write mDm i x

test/Main.hs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import Tests.Convolution qualified
66
import Tests.Dsu qualified
77
import Tests.Extra.AhoCorasick qualified
88
import Tests.Extra.Bisect qualified
9+
import Tests.Extra.DsuMonoid qualified
910
import Tests.Extra.DynLazySegTree qualified
1011
import Tests.Extra.DynLazySegTree.Persistent qualified
1112
import Tests.Extra.DynSegTree qualified
@@ -70,6 +71,7 @@ main =
7071
"Extra"
7172
[ testGroup "AhoCorasick" Tests.Extra.AhoCorasick.tests,
7273
testGroup "Bisect" Tests.Extra.Bisect.tests,
74+
testGroup "DsuMonoid" Tests.Extra.DsuMonoid.tests,
7375
testGroup "DynLazySegTree" Tests.Extra.DynLazySegTree.tests,
7476
testGroup "DynLazySegTree.Persistent" Tests.Extra.DynLazySegTree.Persistent.tests,
7577
testGroup "DynSegTree" Tests.Extra.DynSegTree.tests,

test/Tests/Extra/DsuMonoid.hs

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
-- | Disjoint set union tests.
2+
module Tests.Extra.DsuMonoid (tests) where
3+
4+
import AtCoder.Extra.DsuMonoid qualified as Dsu
5+
import Data.Foldable
6+
import Data.Semigroup (Sum (..))
7+
import Data.Vector qualified as V
8+
import System.IO.Unsafe (unsafePerformIO)
9+
import Test.Hspec
10+
import Test.Tasty
11+
import Test.Tasty.HUnit
12+
import Test.Tasty.Hspec
13+
14+
unit_zero :: TestTree
15+
unit_zero = testCase "zero" $ do
16+
uf <- Dsu.new @_ @(Sum Int) 0
17+
(@?= V.empty) =<< Dsu.groups uf
18+
19+
-- empty
20+
-- assign
21+
22+
unit_simple :: TestTree
23+
unit_simple = testCase "simple" $ do
24+
uf <- Dsu.new @_ @(Sum Int) 2
25+
(@?= False) =<< Dsu.same uf 0 1
26+
x <- Dsu.merge uf 0 1
27+
(@?= x) =<< Dsu.leader uf 0
28+
(@?= x) =<< Dsu.leader uf 1
29+
(@?= True) =<< Dsu.same uf 0 1
30+
(@?= 2) =<< Dsu.size uf 0
31+
32+
unit_line :: TestTree
33+
unit_line = testCase "line" $ do
34+
let n = 500000
35+
uf <- Dsu.new @_ @(Sum Int) n
36+
for_ [0 .. n - 2] $ \i -> do
37+
Dsu.merge uf i (i + 1)
38+
(@?= n) =<< Dsu.size uf 0
39+
(@?= 1) . V.length =<< Dsu.groups uf
40+
41+
unit_lineReverse :: TestTree
42+
unit_lineReverse = testCase "lineReverse" $ do
43+
let n = 500000
44+
uf <- Dsu.new @_ @(Sum Int) n
45+
for_ [n - 2, n - 3 .. 0] $ \i -> do
46+
Dsu.merge uf i (i + 1)
47+
(@?= n) =<< Dsu.size uf 0
48+
(@?= 1) . V.length =<< Dsu.groups uf
49+
50+
spec_invalid :: IO TestTree
51+
spec_invalid = testSpec "invalid" $ do
52+
it "throws error" $ do
53+
Dsu.new @_ @(Sum Int) (-1) `shouldThrow` anyException
54+
55+
tests :: [TestTree]
56+
tests =
57+
[ unit_zero,
58+
unit_simple,
59+
unit_line,
60+
unit_lineReverse,
61+
unsafePerformIO spec_invalid
62+
]

0 commit comments

Comments
 (0)