diff --git a/tensorflow_quantum/python/util_test.py b/tensorflow_quantum/python/util_test.py index 00715899b..365bb18ef 100644 --- a/tensorflow_quantum/python/util_test.py +++ b/tensorflow_quantum/python/util_test.py @@ -428,6 +428,42 @@ def test_get_circuit_symbols_error(self): 'cirq.Circuit'): util.get_circuit_symbols(param) + def test_random_circuit_resolver_batch(self): + """Confirm that random_circuit_resolver_batch works.""" + qubits = cirq.GridQubit.rect(1, 2) + batch_size = 5 + circuits, resolvers = util.random_circuit_resolver_batch( + qubits, batch_size) + self.assertEqual(len(circuits), batch_size) + self.assertEqual(len(resolvers), batch_size) + for circuit in circuits: + self.assertIsInstance(circuit, cirq.Circuit) + self.assertGreater(len(circuit), 0, "Generated circuit should not be empty.") + self.assertEqual(len(util.get_circuit_symbols(circuit)), 0, "Circuit should not have symbols.") + for resolver in resolvers: + self.assertIsInstance(resolver, cirq.ParamResolver) + self.assertEqual(len(resolver.param_dict), 0) + + def test_random_symbol_circuit_resolver_batch(self): + """Confirm that random_symbol_circuit_resolver_batch works.""" + qubits = cirq.GridQubit.rect(1, 2) + symbols = [sympy.Symbol('a'), sympy.Symbol('b')] + batch_size = 5 + circuits, resolvers = util.random_symbol_circuit_resolver_batch( + qubits, symbols, batch_size) + self.assertEqual(len(circuits), batch_size) + self.assertEqual(len(resolvers), batch_size) + for circuit in circuits: + self.assertIsInstance(circuit, cirq.Circuit) + extracted_symbols = util.get_circuit_symbols(circuit) + expected_symbols = sorted([str(s) for s in symbols]) + self.assertListEqual(expected_symbols, sorted(extracted_symbols)) + for resolver in resolvers: + self.assertIsInstance(resolver, cirq.ParamResolver) + self.assertEqual(len(resolver.param_dict), len(symbols)) + for symbol in symbols: + self.assertIn(symbol, resolver.param_dict) + class ExponentialUtilFunctionsTest(tf.test.TestCase): """Test that Exponential utility functions work."""