Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions tensorflow_quantum/python/util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,42 @@
'cirq.Circuit'):
util.get_circuit_symbols(param)

def test_random_circuit_resolver_batch(self):
"""Confirm that random_circuit_resolver_batch works."""
qubits = cirq.GridQubit.rect(1, 2)
batch_size = 5
circuits, resolvers = util.random_circuit_resolver_batch(
qubits, batch_size)
self.assertEqual(len(circuits), batch_size)
self.assertEqual(len(resolvers), batch_size)
for circuit in circuits:
self.assertIsInstance(circuit, cirq.Circuit)
self.assertGreater(len(circuit), 0, "Generated circuit should not be empty.")

Check warning on line 441 in tensorflow_quantum/python/util_test.py

View workflow job for this annotation

GitHub Actions / Check Python lint

C0301: Line too long (89/80) (line-too-long)
self.assertEqual(len(util.get_circuit_symbols(circuit)), 0, "Circuit should not have symbols.")

Check warning on line 442 in tensorflow_quantum/python/util_test.py

View workflow job for this annotation

GitHub Actions / Check Python lint

C0301: Line too long (107/80) (line-too-long)
for resolver in resolvers:
self.assertIsInstance(resolver, cirq.ParamResolver)
self.assertEqual(len(resolver.param_dict), 0)

def test_random_symbol_circuit_resolver_batch(self):
"""Confirm that random_symbol_circuit_resolver_batch works."""
qubits = cirq.GridQubit.rect(1, 2)
symbols = [sympy.Symbol('a'), sympy.Symbol('b')]
batch_size = 5
circuits, resolvers = util.random_symbol_circuit_resolver_batch(
qubits, symbols, batch_size)
self.assertEqual(len(circuits), batch_size)
self.assertEqual(len(resolvers), batch_size)
for circuit in circuits:
self.assertIsInstance(circuit, cirq.Circuit)
extracted_symbols = util.get_circuit_symbols(circuit)
expected_symbols = sorted([str(s) for s in symbols])
self.assertListEqual(expected_symbols, sorted(extracted_symbols))
for resolver in resolvers:
self.assertIsInstance(resolver, cirq.ParamResolver)
self.assertEqual(len(resolver.param_dict), len(symbols))
for symbol in symbols:
self.assertIn(symbol, resolver.param_dict)


class ExponentialUtilFunctionsTest(tf.test.TestCase):
"""Test that Exponential utility functions work."""
Expand Down
Loading