From 525f7eb0328b33b4854c4682b6ffa82497251ba2 Mon Sep 17 00:00:00 2001 From: Caleb Donovick Date: Tue, 1 Aug 2023 16:36:41 -0700 Subject: [PATCH 1/5] Add broken tests --- tests/test_magma_protocol.py | 109 +++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) diff --git a/tests/test_magma_protocol.py b/tests/test_magma_protocol.py index 9dde24d5..a9a87174 100644 --- a/tests/test_magma_protocol.py +++ b/tests/test_magma_protocol.py @@ -1,4 +1,7 @@ from typing import Optional + +import pytest + import magma as m import fault from hwtypes import BitVector @@ -46,3 +49,109 @@ class Bar(m.Circuit): (BitVector[8](0xDE) << 1)[0] | BitVector[8](0xDE)[0]) tester.compile_and_run("verilator") + + +class SimpleMagmaProtocolMeta(m.MagmaProtocolMeta): + _CACHE = {} + + def _to_magma_(cls): + return cls.T + + def _qualify_magma_(cls, direction): + return cls[cls.T.qualify(direction)] + + def _flip_magma_(cls): + return cls[cls.T.flip()] + + def _from_magma_value_(cls, val): + return cls(val) + + def __getitem__(cls, T): + try: + base = cls.base + except AttributeError: + base = cls + dct = {"T": T, "base": base} + derived = type(cls)(f"{base.__name__}[{T}]", (cls,), dct) + return SimpleMagmaProtocolMeta._CACHE.setdefault(T, derived) + + def __repr__(cls): + return str(cls) + + def __str__(cls): + return cls.__name__ + + +class BrokenProtocol(m.MagmaProtocol, metaclass=SimpleMagmaProtocolMeta): + def __init__(self, val = None): + if val is None: + self._val = self.T() + elif isinstance(val, self.T): + self._val = val + else: + self._val = self.T(val) + + def _get_magma_value_(self): + return self._val + + +class FixedProtocol(BrokenProtocol): + @property + def debug_name(self): + # Beyond not liking the number of names already reserved by + # MagmaProtocol I have not strong feeling about adding `debug_name` + # to the list. + return self._get_magma_value_().debug_name + + def __len__(self): + # !!! Please do not reserve len !!! + # I have container types which use it + cls = type(self) + magma_t = cls._to_magma_() + if issubclass(magma_t, m.Digital): + # for some reason this breaks on m.Bits + # dont know if bug + return len(magma_t) + else: + return len(self._get_magma_value_()) + +def gen_DUT(T, BoxT, i, o): + if i: + i_t = BoxT[T] + else: + i_t = T + + if o: + o_t = BoxT[T] + else: + o_t = T + + class DUT(m.Circuit): + io = m.IO( + I=m.In(i_t), + O=m.Out(o_t) + ) + if i and not o: + io.O @= io.I._get_magma_value_() + elif o and not i: + io.O @= o_t._from_magma_value_(io.I) + else: + io.O @= io.I + return DUT + + + +@pytest.mark.parametrize('T', [m.Bit, m.Bits[16]]) +@pytest.mark.parametrize('BoxT', [BrokenProtocol]) +@pytest.mark.parametrize('proto_in, proto_out', [ + (True, False), + (False, True), + (True, True), + ]) +def test_protocol_as_input_and_output(T, BoxT, proto_in, proto_out): + DUT = gen_DUT(T, BoxT, proto_in, proto_out) + tester = fault.Tester(DUT) + tester.circuit.I = BoxT[T](T(0)) + tester.eval() + tester.circuit.O.expect(T(0)) + tester.compile_and_run("verilator") From b85216036f89f1f41d61d8e12c2eb9dedff66829 Mon Sep 17 00:00:00 2001 From: Lenny Truong Date: Thu, 17 Aug 2023 15:15:43 -0700 Subject: [PATCH 2/5] Get magma value from protocol at fault entrypoint --- fault/actions.py | 6 ++++-- fault/wrapper.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/fault/actions.py b/fault/actions.py index c5e764f7..91a16104 100644 --- a/fault/actions.py +++ b/fault/actions.py @@ -9,6 +9,8 @@ from fault.select_path import SelectPath import fault.expression as expression +import magma as m + class Action(ABC): @abstractmethod @@ -25,8 +27,8 @@ def __repr__(self): class PortAction(Action): def __init__(self, port, value): super().__init__() - self.port = port - self.value = value + self.port = m.protocol_type.magma_value(port) + self.value = m.protocol_type.magma_value(value) def __str__(self): type_name = type(self).__name__ diff --git a/fault/wrapper.py b/fault/wrapper.py index bc0dec61..0350fbcf 100644 --- a/fault/wrapper.py +++ b/fault/wrapper.py @@ -61,7 +61,7 @@ class CircuitWrapper(Wrapper): class PortWrapper(expression.Expression): def __init__(self, port, parent): - self.port = port + self.port = m.protocol_type.magma_value(port) self.parent = parent self.init_done = True From ce65667b4a3dd0184007d2c1687b4acda735ad2e Mon Sep 17 00:00:00 2001 From: Lenny Truong Date: Thu, 17 Aug 2023 15:47:20 -0700 Subject: [PATCH 3/5] Style --- tests/test_magma_protocol.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_magma_protocol.py b/tests/test_magma_protocol.py index a9a87174..7c3d4e44 100644 --- a/tests/test_magma_protocol.py +++ b/tests/test_magma_protocol.py @@ -83,7 +83,7 @@ def __str__(cls): class BrokenProtocol(m.MagmaProtocol, metaclass=SimpleMagmaProtocolMeta): - def __init__(self, val = None): + def __init__(self, val=None): if val is None: self._val = self.T() elif isinstance(val, self.T): @@ -115,6 +115,7 @@ def __len__(self): else: return len(self._get_magma_value_()) + def gen_DUT(T, BoxT, i, o): if i: i_t = BoxT[T] @@ -140,14 +141,13 @@ class DUT(m.Circuit): return DUT - @pytest.mark.parametrize('T', [m.Bit, m.Bits[16]]) @pytest.mark.parametrize('BoxT', [BrokenProtocol]) @pytest.mark.parametrize('proto_in, proto_out', [ (True, False), (False, True), (True, True), - ]) +]) def test_protocol_as_input_and_output(T, BoxT, proto_in, proto_out): DUT = gen_DUT(T, BoxT, proto_in, proto_out) tester = fault.Tester(DUT) From 3a9bbf86f9d944d67725afe65decea0d7e5d06d1 Mon Sep 17 00:00:00 2001 From: Lenny Truong Date: Thu, 17 Aug 2023 15:48:56 -0700 Subject: [PATCH 4/5] Add tempdir logic --- tests/test_magma_protocol.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_magma_protocol.py b/tests/test_magma_protocol.py index 7c3d4e44..fe711618 100644 --- a/tests/test_magma_protocol.py +++ b/tests/test_magma_protocol.py @@ -1,3 +1,4 @@ +import tempfile from typing import Optional import pytest @@ -154,4 +155,5 @@ def test_protocol_as_input_and_output(T, BoxT, proto_in, proto_out): tester.circuit.I = BoxT[T](T(0)) tester.eval() tester.circuit.O.expect(T(0)) - tester.compile_and_run("verilator") + with tempfile.TemporaryDirectory(dir=".") as tempdir: + tester.compile_and_run("verilator", directory="tempdir") From 5db3c124f0e57efb6e472b85e1396fb7d9db5c60 Mon Sep 17 00:00:00 2001 From: Lenny Truong Date: Thu, 17 Aug 2023 16:06:39 -0700 Subject: [PATCH 5/5] Fix style --- tests/test_tester/test_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_tester/test_core.py b/tests/test_tester/test_core.py index 4315b022..857aaf98 100644 --- a/tests/test_tester/test_core.py +++ b/tests/test_tester/test_core.py @@ -23,7 +23,7 @@ def pytest_generate_tests(metafunc): def check(got, expected): - assert type(got) == type(expected) + assert isinstance(got, type(expected)) if isinstance(got, actions.PortAction): assert got.port is expected.port assert got.value == expected.value