Very slow AIG walking #135


Closed
masinag opened this issue Nov 1, 2024 · 22 comments · Fixed by #136

@masinag
Contributor

masinag commented Nov 1, 2024

Hi, I am using py-aiger to perform some AIG manipulation.
In particular, I am trying to convert AIG to PySMT formulas.
I have found an example of AIG walking in

def dump(circ):

I have tried to adapt this to my scenario, but I have noticed it gets very slow on medium-to-large instances.
And I mean very slow: hours instead of seconds.
E.g. https://github.com/yogevshalmon/allsat-circuits/blob/b57c2d6cba244460008dc6400beef2604a720c24/benchmarks/random_aig/large_cir_or/bench1/bench1.aag

It seems to me that the bottleneck is somewhere in the aiger.common.dfs function, likely in operations on sets of nodes.
I suppose this could be due to the computation of hashes for nodes (generated by @attr.frozen), which traverses the whole subgraph each time, for each node.
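
To make the suspicion concrete, here is a minimal sketch (stand-in classes, not py-aiger's actual node types) of why structural hashing blows up on chain-shaped circuits: every `__hash__` of a node near the root re-walks its entire subgraph, so inserting all n nodes of a chain into a set costs O(n²) hash calls.

```python
# Illustrative stand-in, NOT py-aiger's real classes: a node whose __hash__
# recurses into its children, as a structurally-hashed AIG node would.
calls = 0  # counts individual Node.__hash__ invocations

class Node:
    def __init__(self, children=()):
        self.children = children

    def __eq__(self, other):
        return isinstance(other, Node) and self.children == other.children

    def __hash__(self):
        global calls
        calls += 1
        return hash(self.children)  # hashing the tuple re-hashes every child

# Build a chain of 100 nodes; adding each to a set re-hashes its whole tail,
# so total hash calls grow quadratically: 1 + 2 + ... + 100 = 5050.
chain = [Node(())]
for _ in range(99):
    chain.append(Node((chain[-1],)))

seen = set()
for node in chain:
    seen.add(node)
```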

I attach code to replicate the issue.

import aiger
import funcy as fn


def gates(circ):
    gg = []
    count = 0

    class NodeAlg:
        def __init__(self, lit: int):
            self.lit = lit

        @fn.memoize
        def __and__(self, other):
            nonlocal count
            nonlocal gg
            count += 1
            new = NodeAlg(count << 1)
            right, left = sorted([self.lit, other.lit])
            gg.append((new.lit, left, right))
            return new

        @fn.memoize
        def __invert__(self):
            return NodeAlg(self.lit ^ 1)

    def lift(obj) -> NodeAlg:
        if isinstance(obj, bool):
            return NodeAlg(int(obj))
        elif isinstance(obj, NodeAlg):
            return obj
        raise NotImplementedError

    start = 1
    inputs = {k: NodeAlg(i << 1) for i, k in enumerate(sorted(circ.inputs), start)}
    count += len(inputs)

    omap, _ = circ(inputs=inputs, lift=lift)

    return gg


def main():
    circ = aiger.to_aig("bench1.aag")
    gg = gates(circ)

    print(len(gg))


if __name__ == '__main__':
    main()
@mvcisback
Owner

Hi @masinag ,

Thanks for reaching out. I can take a look sometime in the coming weeks.

I suspect you're right about the hash issue; it's been a bit of a wart for a while, and one of the reasons the lazy API was initially developed -- although that won't help here.

I recommend profiling with a tool like py-spy to get a flame graph and confirm.

https://github.com/benfred/py-spy

If you have a chance to run py-spy, let me know (feel free to attach the output SVG).

Supposing it is the hashing in common.dfs, we can look at two solutions:

  1. accelerating hashing in general.
  2. re-writing common.dfs to avoid the hashing.

Option 1

For option 1, I would have thought this was solved by cache_hash.

@attr.frozen(cache_hash=True)

Perhaps we're having a lot of hash collisions and being killed by equality checks? Either way it's strange; worst case, we'll need to manually introduce smarter hashing and caching.
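
For reference, a stdlib sketch (stand-in class, not py-aiger's) of what hash caching buys: the structural hash is computed once per instance and reused, but equality remains a deep structural comparison, so caching the hash alone may not remove the slowdown.

```python
# Stdlib sketch of what attrs' cache_hash=True does: compute the structural
# hash once per instance, lazily, and reuse it on every later call.
class CachedNode:
    __slots__ = ("children", "_hash")

    def __init__(self, children=()):
        self.children = children
        self._hash = None  # filled in lazily on first hash() call

    def __eq__(self, other):
        # Still a deep structural comparison -- caching does not help here.
        return isinstance(other, CachedNode) and self.children == other.children

    def __hash__(self):
        if self._hash is None:  # first call walks the subgraph once
            self._hash = hash(self.children)
        return self._hash

a = CachedNode((CachedNode(), CachedNode()))
b = CachedNode((CachedNode(), CachedNode()))
```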

Option 2

I think this is the easiest to code, but not a very satisfying solution. Essentially, we could switch to checking whether that exact node has already been emitted. This could be done roughly as follows.

def dfs(circ):
    """Generates nodes via depth first traversal in pre-order."""
    emitted = set()
    stack = list(circ.cones | circ.latch_cones)

    while stack:
        node = stack.pop()

        if id(node) in emitted:
            continue

        remaining = [c for c in node.children if id(c) not in emitted]

        if len(remaining) == 0:
            yield node
            emitted.add(id(node))   # track by identity, not structural equality
            continue

        stack.append(node)  # Add to emit after remaining children.
        stack.extend(remaining)

@mvcisback
Owner

@masinag looking at your code again, it may actually be that NodeAlg doesn't cache its hashes. Could you try again with that?

@masinag
Contributor Author

masinag commented Nov 4, 2024

@masinag looking at your code again, it may actually be that NodeAlg doesn't cache its hashes. Could you try again with that?

I don't think I understand what you mean. I don't see where I am hashing NodeAlg objects

@masinag
Contributor Author

masinag commented Nov 4, 2024

About the proposed options, I can try to profile the execution with py-spy.

Option 2 seems an easy fix, but if the issue is really the hashing, speeding it up could improve performance in many other contexts. So it could be worth looking deeper into that.

@mvcisback
Owner

@masinag looking at your code again, it may actually be that NodeAlg doesn't cache its hashes. Could you try again with that?

I don't think I understand what you mean. I don't see where I am hashing NodeAlg objects

Err, actually ignore what I said.

@masinag
Contributor Author

masinag commented Nov 22, 2024

Hi, I profiled the above code using py-spy. I stopped it after 3 hours of execution. This is the flame graph.

[flame graph image: profile]

It looks like the problem is the equality check between nodes, since most of the time is taken by the __eq__ function generated by attrs.

@mvcisback
Owner

Thanks @masinag ! Could you share the SVG as well since it's interactive?

That said, I suppose this makes sense. It's not the hash that's the problem, it's the equality check that runs afterwards to rule out a hash collision...

I will need to think about how to speed that up. Given that

if children <= emitted:

seems to be the bottleneck, it may be good to implement option 2 anyway.

Alternatively, I could do a breaking change and make equality the same as an id check. It's been a while since I really thought about what implications that would have, but it seems like a reasonable option.
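
As a sketch of what identity semantics would look like (stand-in class, not the actual AIG node types): simply not defining `__eq__`/`__hash__`, which is also what attrs' `@define(eq=False)` falls back to, makes comparison and hashing O(1) regardless of subgraph depth.

```python
# Sketch of the proposed breaking change: with no __eq__/__hash__ defined,
# Python falls back to identity comparison and an id()-derived hash, both
# O(1) no matter how deep the subgraph is.
class IdNode:
    def __init__(self, children=()):
        self.children = children

deep = IdNode()
for _ in range(1000):
    deep = IdNode((deep,))

twin = IdNode((deep.children[0],))  # structurally identical sibling

# Identity semantics: a node only equals itself, however it is built.
assert deep == deep
assert deep != twin
assert len({deep, twin}) == 2  # set membership is now cheap per node
```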

@mvcisback
Owner

@masinag, if I could ask a favor: if you have the bandwidth, could you regenerate the above flame graph with the change suggested in:

#135 (comment)

I suspect we'll see a huge speed improvement, but unfortunately I'm not able to run it myself for a while.

@masinag
Contributor Author

masinag commented Nov 22, 2024

Thanks @masinag ! Could you share the SVG as well since it's interactive?

Sure, you can find it at https://github.com/user-attachments/assets/b11cc6ba-1966-4108-b447-e6d3997d51ba

@masinag
Contributor Author

masinag commented Nov 22, 2024

@masinag, if I could ask a favor: if you have the bandwidth, could you regenerate the above flame graph with the change suggested in:

I stopped it after 30 min.

[flame graph image: profile2]

(interactive version https://github.com/user-attachments/assets/e000cc4a-aaab-4274-9193-3db3fdc4dc17)

It looks like there is another call to __eq__ when nodes are used as keys for the mem dictionary.

mem[gate] = neg(mem[gate.input])

@masinag
Contributor Author

masinag commented Nov 22, 2024

The code is much faster (milliseconds instead of hours!) if the AIG class's __call__ method

def __call__(self, inputs, latches=None, *, lift=None):

is modified as follows:

@attr.frozen(repr=False)
class AIG:
    ...
    def __call__(self, inputs, latches=None, *, lift=None):
        """Evaluate AIG on inputs (and latches).
        If `latches` is `None` initial latch value is used.

        `lift` is an optional argument used to interpret constants
        (False, True) in some other Boolean algebra over (&, ~).

        - See py-aiger-bdd and py-aiger-cnf for examples.
        """
        if latches is None:
            latches = dict()

        if lift is None:
            lift = fn.identity
            and_, neg = op.and_, op.not_
        else:
            and_, neg = op.__and__, op.__invert__

        latchins = fn.merge(dict(self.latch2init), latches)
        # Remove latch inputs not used by self.
        latchins = fn.project(latchins, self.latches)

        latch_map = dict(self.latch_map)
        boundary = set(self.node_map.values()) | set(latch_map.values())

        store, prev, mem = {}, set(), {}

        for node_batch in self.__iter_nodes__():
            prev = set(mem.keys()) - prev
            mem = fn.project(mem, prev)  # Forget about unnecessary gates.

            for gate in node_batch:
                if isinstance(gate, Inverter):
                    mem[id(gate)] = neg(mem[id(gate.input)])
                elif isinstance(gate, AndGate):
                    mem[id(gate)] = and_(mem[id(gate.left)], mem[id(gate.right)])
                elif isinstance(gate, Input):
                    mem[id(gate)] = lift(inputs[gate.name])
                elif isinstance(gate, LatchIn):
                    mem[id(gate)] = lift(latchins[gate.name])
                elif isinstance(gate, ConstFalse):
                    mem[id(gate)] = lift(False)

                if gate in boundary:
                    store[id(gate)] = mem[id(gate)]  # Store for eventual output.

        outs = {out: store[id(gate)] for out, gate in self.node_map.items()}
        louts = {out: store[id(gate)] for out, gate in latch_map.items()}
        return outs, louts

Notice that I used id(gate) instead of gate as dictionary key.

[flame graph image: profile4]
(interactive version https://github.com/user-attachments/assets/2ec73a97-4541-4a7a-b82d-7e6374c9e838)
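
The id-keyed pattern generalizes beyond AIG.__call__. The following is a hypothetical helper (not part of py-aiger; all names are made up) sketching post-order DAG evaluation keyed by id(node), with the usual caveat that id() values can be reused once an object is garbage collected, so the memo table is only valid while the nodes it indexes stay alive.

```python
# Hypothetical helper (not py-aiger API): post-order DAG evaluation keyed by
# id(node) to sidestep structural __eq__/__hash__ entirely. Caveat: id() can
# be reused after GC, so the table must not outlive the nodes it indexes
# (here the roots pin the whole DAG alive for the call's duration).
def eval_dag(roots, children, compute):
    mem = {}  # id(node) -> computed value
    stack = list(roots)
    while stack:
        node = stack.pop()
        if id(node) in mem:
            continue
        pending = [c for c in children(node) if id(c) not in mem]
        if pending:
            stack.append(node)    # revisit after its children are done
            stack.extend(pending)
        else:
            mem[id(node)] = compute(node, [mem[id(c)] for c in children(node)])
    return [mem[id(r)] for r in roots]

# Tiny shared-subgraph example: b references the gate a twice, yet a is
# evaluated only once.
a = ("+", 1, 2)
b = ("+", a, a)
as_children = lambda n: n[1:] if isinstance(n, tuple) else ()
as_value = lambda n, vals: sum(vals) if isinstance(n, tuple) else n
result = eval_dag([b], as_children, as_value)
```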

@mvcisback
Owner

Amazing! This does suggest to me that the right solution might be to do a breaking change and make __eq__ work via id. This would have the same effect (since after hashing, eq is called), but would be done throughout the codebase.

I'll think about it over the weekend to make sure there aren't any gotchas to applying it more widely.

@masinag
Contributor Author

masinag commented Nov 22, 2024

Great! Btw, from attrs' docs:

If you want hashing and equality by object identity: use @define(eq=False)

@mvcisback
Owner

Hi,

sorry for the slow follow-up. I tried implementing this, and it seemed to break the unit tests (run via pytest).

@masinag when you made your change, did you happen to run the pytest suite? I'm worried there's a subtle issue I couldn't figure out in the 30 min I could dedicate to exploring this.

@masinag
Contributor Author

masinag commented Dec 4, 2024

I confirm that the tests are also failing for me. I don't have much time to debug it now, but the failing tests are very small, so it should be easy to spot the bug.

@masinag
Contributor Author

masinag commented Dec 14, 2024

Hi @mvcisback, I have done some debugging and found the following.
In tests.test_common.test_tee, at some point __call__ is called on a circuit with the following node_map:

pmap({
    'b': Input(name='562f849c-ba08-11ef-b60b-6c24083d2fd2'), 
    'a': Input(name='562f849c-ba08-11ef-b60b-6c24083d2fd2'), 
    'c': Input(name='562f849c-ba08-11ef-b60b-6c24083d2fd2')
})

So we have three inputs with the same name but different keys in the node map, and also different id(node) values:
[127306698814800, 127306698805456, 127306698802512].

However, since they have the same name, these nodes are treated as the same node by the DFS, because AIG.cones returns a frozenset of the node_map.
Because of this, only b is visited and then stored in store by __call__.

I am not really sure if this input makes sense, and what the expected behavior is.
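
The collapse can be reproduced with stand-in classes (hypothetical names, not py-aiger's): under structural equality, three distinct Input objects with the same name become a single frozenset element, so a traversal seeded from that set visits only one of them, while identity semantics keeps all three.

```python
# Reproduction sketch of the test_tee failure mode with stand-in classes.
from dataclasses import dataclass

@dataclass(frozen=True)  # structural __eq__/__hash__, like attrs' default
class StructInput:
    name: str

node_map = {"a": StructInput("x"), "b": StructInput("x"), "c": StructInput("x")}
# Three distinct keys, but the values collapse to one frozenset element, so
# a DFS seeded from this set only ever visits one node.
assert len(frozenset(node_map.values())) == 1

# Under identity semantics (the proposed breaking change), all three survive:
class IdInput:
    def __init__(self, name):
        self.name = name

assert len(frozenset(IdInput("x") for _ in range(3))) == 3
```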

@masinag
Contributor Author

masinag commented Dec 14, 2024

For instance, could we use node.name to check for equality?

@mvcisback
Owner

Thanks @masinag .

I will take a look today.

mvcisback added a commit that referenced this issue Dec 16, 2024
closes: #135
BREAKING_CHANGE: Structurally equal circuits will not necessarily be equal now.
co-authored-by: Gabriele Masina <[email protected]>
@mvcisback
Owner

@masinag I created a quick PR to try out this idea.

It seems to pass all the tests (except for hash stability, which is expected given the breaking change).

Do you mind benchmarking when you have some time? I need to context-switch back to my day job :)

@mvcisback
Owner

@masinag, I had some free time to sit down and try it.

Looks like the issue is resolved. I added a performance test that checks that your example is fast.

mvcisback added a commit that referenced this issue Dec 17, 2024
closes: #135
BREAKING_CHANGE: Structurally equal circuits will not necessarily be equal now.

Co-authored-by: Gabriele Masina <[email protected]>
@mvcisback
Owner

Should be published now as 7.0.0.

@masinag I also made sure the git trailer listed you as a co-author of the commit. Hopefully I got the email right.

@masinag
Contributor Author

masinag commented Dec 17, 2024

Wonderful 🎉
