diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 866d69af216..e84e63f19a2 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -10,7 +10,7 @@ cirq-google/**/*.* @wcourtney @quantumlib/cirq-maintainers @vtomole @cduck @verult -cirq-ionq/**/*.* @dabacon @ColemanCollins @nakardo @gmauricio @Cynocracy @quantumlib/cirq-maintainers @vtomole @cduck +cirq-ionq/**/*.* @dabacon @ColemanCollins @nakardo @gmauricio @Cynocracy @quantumlib/cirq-maintainers @vtomole @cduck @splch cirq-aqt/**/*.* @ma5x @pschindler @alfrisch @quantumlib/cirq-maintainers @vtomole @cduck @@ -37,7 +37,7 @@ docs/**/*.* @aasfaw @rmlarose @quantumlib/cirq-maintainers @vtomole @cduck docs/google/**/*.* @wcourtney @aasfaw @rmlarose @quantumlib/cirq-maintainers @vtomole @cduck @verult docs/tutorials/google/**/*.* @wcourtney @aasfaw @rmlarose @quantumlib/cirq-maintainers @vtomole @cduck @verult -docs/hardware/ionq/**/*.* @dabacon @ColemanCollins @nakardo @gmauricio @aasfaw @rmlarose @Cynocracy @quantumlib/cirq-maintainers @vtomole @cduck +docs/hardware/ionq/**/*.* @dabacon @ColemanCollins @nakardo @gmauricio @aasfaw @rmlarose @Cynocracy @quantumlib/cirq-maintainers @vtomole @cduck @splch docs/hardware/aqt/**/*.* @ma5x @pschindler @alfrisch @aasfaw @rmlarose @quantumlib/cirq-maintainers @vtomole @cduck diff --git a/cirq-core/cirq/contrib/svg/svg.py b/cirq-core/cirq/contrib/svg/svg.py index be7a1d60c56..3a9be84bdeb 100644 --- a/cirq-core/cirq/contrib/svg/svg.py +++ b/cirq-core/cirq/contrib/svg/svg.py @@ -2,12 +2,14 @@ from typing import TYPE_CHECKING, List, Tuple, cast, Dict import matplotlib.textpath +import matplotlib.font_manager + if TYPE_CHECKING: import cirq QBLUE = '#1967d2' -FONT = "Arial" +FONT = matplotlib.font_manager.FontProperties(family="Arial") EMPTY_MOMENT_COLWIDTH = float(21) # assumed default column width diff --git a/cirq-core/cirq/devices/insertion_noise_model.py b/cirq-core/cirq/devices/insertion_noise_model.py index ab6604868fc..cbe158ae8a5 100644 --- a/cirq-core/cirq/devices/insertion_noise_model.py +++ b/cirq-core/cirq/devices/insertion_noise_model.py @@ -13,7 +13,7 @@ # limitations under the License. import dataclasses -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence from cirq import devices from cirq.devices import noise_utils @@ -74,3 +74,23 @@ def noisy_moment( if self.prepend: return [*noise_steps.moments, moment] return [moment, *noise_steps.moments] + + def __repr__(self) -> str: + return ( + f'cirq.devices.InsertionNoiseModel(ops_added={self.ops_added},' + + f' prepend={self.prepend},' + + f' require_physical_tag={self.require_physical_tag})' + ) + + def _json_dict_(self) -> Dict[str, Any]: + return { + 'ops_added': list(self.ops_added.items()), + 'prepend': self.prepend, + 'require_physical_tag': self.require_physical_tag, + } + + @classmethod + def _from_json_dict_(cls, ops_added, prepend, require_physical_tag, **kwargs): + return cls( + ops_added=dict(ops_added), prepend=prepend, require_physical_tag=require_physical_tag + ) diff --git a/cirq-core/cirq/devices/insertion_noise_model_test.py b/cirq-core/cirq/devices/insertion_noise_model_test.py index dc15eec2b44..3a316b805e6 100644 --- a/cirq-core/cirq/devices/insertion_noise_model_test.py +++ b/cirq-core/cirq/devices/insertion_noise_model_test.py @@ -47,6 +47,8 @@ def test_insertion_noise(): moment_3 = cirq.Moment(cirq.Z(q0), cirq.X(q1)) assert model.noisy_moment(moment_3, system_qubits=[q0, q1]) == [moment_3] + cirq.testing.assert_equivalent_repr(model) + def test_colliding_noise_qubits(): # Check that noise affecting other qubits doesn't cause issues. @@ -61,6 +63,8 @@ def test_colliding_noise_qubits(): cirq.Moment(cirq.CNOT(q1, q2)), ] + cirq.testing.assert_equivalent_repr(model) + def test_prepend(): q0, q1 = cirq.LineQubit.range(2) @@ -106,3 +110,5 @@ def test_supertype_matching(): moment_1 = cirq.Moment(cirq.Y(q0)) assert model.noisy_moment(moment_1, system_qubits=[q0]) == [moment_1, cirq.Moment(cirq.T(q0))] + + cirq.testing.assert_equivalent_repr(model) diff --git a/cirq-core/cirq/devices/named_topologies.py b/cirq-core/cirq/devices/named_topologies.py index 5f32d8b1d5d..6aa46e19e94 100644 --- a/cirq-core/cirq/devices/named_topologies.py +++ b/cirq-core/cirq/devices/named_topologies.py @@ -74,7 +74,7 @@ def _node_and_coordinates( def draw_gridlike( - graph: nx.Graph, ax: plt.Axes = None, tilted: bool = True, **kwargs + graph: nx.Graph, ax: Optional[plt.Axes] = None, tilted: bool = True, **kwargs ) -> Dict[Any, Tuple[int, int]]: """Draw a grid-like graph using Matplotlib. diff --git a/cirq-core/cirq/experiments/qubit_characterizations.py b/cirq-core/cirq/experiments/qubit_characterizations.py index 114e2e28659..ed12b311e22 100644 --- a/cirq-core/cirq/experiments/qubit_characterizations.py +++ b/cirq-core/cirq/experiments/qubit_characterizations.py @@ -15,13 +15,13 @@ import dataclasses import itertools -from typing import Any, Iterator, List, Optional, Sequence, Tuple, TYPE_CHECKING +from typing import Any, cast, Iterator, List, Optional, Sequence, Tuple, TYPE_CHECKING import numpy as np from matplotlib import pyplot as plt # this is for older systems with matplotlib <3.2 otherwise 3d projections fail -from mpl_toolkits import mplot3d # pylint: disable=unused-import +from mpl_toolkits import mplot3d from cirq import circuits, ops, protocols if TYPE_CHECKING: @@ -89,8 +89,9 @@ def plot(self, ax: Optional[plt.Axes] = None, **plot_kwargs: Any) -> plt.Axes: """ show_plot = not ax if not ax: - fig, ax = plt.subplots(1, 1, figsize=(8, 8)) - ax.set_ylim([0, 1]) + fig, ax = plt.subplots(1, 1, figsize=(8, 8)) # pragma: no cover + ax = cast(plt.Axes, ax) # pragma: no cover + ax.set_ylim((0.0, 1.0)) # pragma: no cover ax.plot(self._num_cfds_seq, self._gnd_state_probs, 'ro-', **plot_kwargs) ax.set_xlabel(r"Number of Cliffords") ax.set_ylabel('Ground State Probability') @@ -541,7 +542,7 @@ def _find_inv_matrix(mat: np.ndarray, mat_sequence: np.ndarray) -> int: def _matrix_bar_plot( mat: np.ndarray, z_label: str, - ax: plt.Axes, + ax: mplot3d.axes3d.Axes3D, kets: Optional[Sequence[str]] = None, title: Optional[str] = None, ylim: Tuple[int, int] = (-1, 1), diff --git a/cirq-core/cirq/json_resolver_cache.py b/cirq-core/cirq/json_resolver_cache.py index 20a7377b294..f1d178bb530 100644 --- a/cirq-core/cirq/json_resolver_cache.py +++ b/cirq-core/cirq/json_resolver_cache.py @@ -47,6 +47,7 @@ def _class_resolver_dictionary() -> Dict[str, ObjectFactory]: import pandas as pd import numpy as np from cirq.devices.noise_model import _NoNoiseModel + from cirq.devices import InsertionNoiseModel from cirq.experiments import GridInteractionLayer from cirq.experiments.grid_parallel_two_qubit_xeb import GridParallelXEBMetadata @@ -147,6 +148,7 @@ def _symmetricalqidpair(qids): 'ISwapPowGate': cirq.ISwapPowGate, 'IdentityGate': cirq.IdentityGate, 'InitObsSetting': cirq.work.InitObsSetting, + 'InsertionNoiseModel': InsertionNoiseModel, 'KeyCondition': cirq.KeyCondition, 'KrausChannel': cirq.KrausChannel, 'LinearDict': cirq.LinearDict, diff --git a/cirq-core/cirq/linalg/decompositions.py b/cirq-core/cirq/linalg/decompositions.py index 60dc0123640..43434ff4d1b 100644 --- a/cirq-core/cirq/linalg/decompositions.py +++ b/cirq-core/cirq/linalg/decompositions.py @@ -20,6 +20,7 @@ from typing import ( Any, Callable, + cast, Iterable, List, Optional, @@ -33,7 +34,7 @@ import matplotlib.pyplot as plt # this is for older systems with matplotlib <3.2 otherwise 3d projections fail -from mpl_toolkits import mplot3d # pylint: disable=unused-import +from mpl_toolkits import mplot3d import numpy as np from cirq import value, protocols @@ -554,7 +555,7 @@ def scatter_plot_normalized_kak_interaction_coefficients( interactions: Iterable[Union[np.ndarray, 'cirq.SupportsUnitary', 'KakDecomposition']], *, include_frame: bool = True, - ax: Optional[plt.Axes] = None, + ax: Optional[mplot3d.axes3d.Axes3D] = None, **kwargs, ): r"""Plots the interaction coefficients of many two-qubit operations. @@ -633,13 +634,13 @@ def scatter_plot_normalized_kak_interaction_coefficients( show_plot = not ax if not ax: fig = plt.figure() - ax = fig.add_subplot(1, 1, 1, projection='3d') + ax = cast(mplot3d.axes3d.Axes3D, fig.add_subplot(1, 1, 1, projection='3d')) def coord_transform( pts: Union[List[Tuple[int, int, int]], np.ndarray] - ) -> Tuple[Iterable[float], Iterable[float], Iterable[float]]: + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: if len(pts) == 0: - return [], [], [] + return np.array([]), np.array([]), np.array([]) xs, ys, zs = np.transpose(pts) return xs, zs, ys diff --git a/cirq-core/cirq/ops/fsim_gate.py b/cirq-core/cirq/ops/fsim_gate.py index b3369411324..01cf6291b8b 100644 --- a/cirq-core/cirq/ops/fsim_gate.py +++ b/cirq-core/cirq/ops/fsim_gate.py @@ -81,7 +81,7 @@ class FSimGate(gate_features.InterchangeableQubitsGate, raw_types.Gate): $$ $$ - c = e^{i \phi} + c = e^{-i \phi} $$ Note the difference in sign conventions between FSimGate and the diff --git a/cirq-core/cirq/protocols/json_test_data/InsertionNoiseModel.json b/cirq-core/cirq/protocols/json_test_data/InsertionNoiseModel.json new file mode 100644 index 00000000000..1a825bdbe48 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/InsertionNoiseModel.json @@ -0,0 +1,91 @@ +[ + { + "cirq_type": "InsertionNoiseModel", + "ops_added": [ + [ + { + "cirq_type": "OpIdentifier", + "gate_type": "XPowGate", + "qubits": [ + { + "cirq_type": "LineQubit", + "x": 0 + } + ] + }, + { + "cirq_type": "GateOperation", + "gate": { + "cirq_type": "BitFlipChannel", + "p": 0.2 + }, + "qubits": [ + { + "cirq_type": "LineQubit", + "x": 0 + } + ] + } + ] + ], + "prepend": false, + "require_physical_tag": false + }, + { + "cirq_type": "InsertionNoiseModel", + "ops_added": [ + [ + { + "cirq_type": "OpIdentifier", + "gate_type": "XPowGate", + "qubits": [ + { + "cirq_type": "LineQubit", + "x": 0 + } + ] + }, + { + "cirq_type": "GateOperation", + "gate": { + "cirq_type": "BitFlipChannel", + "p": 0.2 + }, + "qubits": [ + { + "cirq_type": "LineQubit", + "x": 0 + } + ] + } + ], + [ + { + "cirq_type": "OpIdentifier", + "gate_type": "HPowGate", + "qubits": [ + { + "cirq_type": "LineQubit", + "x": 1 + } + ] + }, + { + "cirq_type": "GateOperation", + "gate": { + "cirq_type": "BitFlipChannel", + "p": 0.1 + }, + "qubits": [ + { + "cirq_type": "LineQubit", + "x": 1 + } + ] + } + ] + ], + "prepend": false, + "require_physical_tag": false + } +] \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/InsertionNoiseModel.repr b/cirq-core/cirq/protocols/json_test_data/InsertionNoiseModel.repr new file mode 100644 index 00000000000..d0650e0d6fc --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/InsertionNoiseModel.repr @@ -0,0 +1,4 @@ +[ +cirq.devices.InsertionNoiseModel(ops_added={cirq.devices.noise_utils.OpIdentifier(cirq.ops.common_gates.XPowGate, cirq.LineQubit(0)): cirq.bit_flip(p=0.2).on(cirq.LineQubit(0))}, prepend=False, require_physical_tag=False), +cirq.devices.InsertionNoiseModel(ops_added={cirq.devices.noise_utils.OpIdentifier(cirq.ops.common_gates.XPowGate, cirq.LineQubit(0)): cirq.bit_flip(p=0.2).on(cirq.LineQubit(0)), cirq.devices.noise_utils.OpIdentifier(cirq.ops.common_gates.HPowGate, cirq.LineQubit(1)): cirq.bit_flip(p=0.1).on(cirq.LineQubit(1))}, prepend=False, require_physical_tag=False) +] \ No newline at end of file diff --git a/cirq-core/cirq/vis/heatmap.py b/cirq-core/cirq/vis/heatmap.py index e5598f59450..e672a2b8c27 100644 --- a/cirq-core/cirq/vis/heatmap.py +++ b/cirq-core/cirq/vis/heatmap.py @@ -15,6 +15,7 @@ from dataclasses import astuple, dataclass from typing import ( Any, + cast, Dict, List, Mapping, @@ -217,7 +218,7 @@ def _plot_colorbar( ) position = self._config['colorbar_position'] orien = 'vertical' if position in ('left', 'right') else 'horizontal' - colorbar = ax.figure.colorbar( + colorbar = cast(plt.Figure, ax.figure).colorbar( mappable, colorbar_ax, ax, orientation=orien, **self._config.get("colorbar_options", {}) ) colorbar_ax.tick_params(axis='y', direction='out') @@ -230,15 +231,15 @@ def _write_annotations( ax: plt.Axes, ) -> None: """Writes annotations to the center of cells. Internal.""" - for (center, annotation), facecolor in zip(centers_and_annot, collection.get_facecolors()): + for (center, annotation), facecolor in zip(centers_and_annot, collection.get_facecolor()): # Calculate the center of the cell, assuming that it is a square # centered at (x=col, y=row). if not annotation: continue x, y = center - face_luminance = vis_utils.relative_luminance(facecolor) + face_luminance = vis_utils.relative_luminance(facecolor) # type: ignore text_color = 'black' if face_luminance > 0.4 else 'white' - text_kwargs = dict(color=text_color, ha="center", va="center") + text_kwargs: Dict[str, Any] = dict(color=text_color, ha="center", va="center") text_kwargs.update(self._config.get('annotation_text_kwargs', {})) ax.text(x, y, annotation, **text_kwargs) @@ -295,6 +296,7 @@ def plot( show_plot = not ax if not ax: fig, ax = plt.subplots(figsize=(8, 8)) + ax = cast(plt.Axes, ax) original_config = copy.deepcopy(self._config) self.update_config(**kwargs) collection = self._plot_on_axis(ax) @@ -381,6 +383,7 @@ def plot( show_plot = not ax if not ax: fig, ax = plt.subplots(figsize=(8, 8)) + ax = cast(plt.Axes, ax) original_config = copy.deepcopy(self._config) self.update_config(**kwargs) qubits = set([q for qubits in self._value_map.keys() for q in qubits]) diff --git a/cirq-core/cirq/vis/heatmap_test.py b/cirq-core/cirq/vis/heatmap_test.py index 1ca493386f5..dceb00cff1c 100644 --- a/cirq-core/cirq/vis/heatmap_test.py +++ b/cirq-core/cirq/vis/heatmap_test.py @@ -34,6 +34,14 @@ def ax(): return figure.add_subplot(111) +def test_default_ax(): + row_col_list = ((0, 5), (8, 1), (7, 0), (13, 5), (1, 6), (3, 2), (2, 8)) + test_value_map = { + grid_qubit.GridQubit(row, col): np.random.random() for (row, col) in row_col_list + } + _, _ = heatmap.Heatmap(test_value_map).plot() + + @pytest.mark.parametrize('tuple_keys', [True, False]) def test_cells_positions(ax, tuple_keys): row_col_list = ((0, 5), (8, 1), (7, 0), (13, 5), (1, 6), (3, 2), (2, 8)) @@ -61,6 +69,8 @@ def test_two_qubit_heatmap(ax): title = "Two Qubit Interaction Heatmap" heatmap.TwoQubitInteractionHeatmap(value_map, title=title).plot(ax) assert ax.get_title() == title + # Test default axis + heatmap.TwoQubitInteractionHeatmap(value_map, title=title).plot() def test_invalid_args(): diff --git a/cirq-core/cirq/vis/histogram.py b/cirq-core/cirq/vis/histogram.py index f3b0a8047bc..88349097a97 100644 --- a/cirq-core/cirq/vis/histogram.py +++ b/cirq-core/cirq/vis/histogram.py @@ -100,9 +100,9 @@ def integrated_histogram( plot_options.update(kwargs) if cdf_on_x: - ax.step(bin_values, parameter_values, **plot_options) + ax.step(bin_values, parameter_values, **plot_options) # type: ignore else: - ax.step(parameter_values, bin_values, **plot_options) + ax.step(parameter_values, bin_values, **plot_options) # type: ignore set_semilog = ax.semilogy if cdf_on_x else ax.semilogx set_lim = ax.set_xlim if cdf_on_x else ax.set_ylim @@ -128,7 +128,7 @@ def integrated_histogram( if median_line: set_line( - np.median(float_data), + float(np.median(float_data)), linestyle='--', color=plot_options['color'], alpha=0.5, @@ -136,7 +136,7 @@ def integrated_histogram( ) if mean_line: set_line( - np.mean(float_data), + float(np.mean(float_data)), linestyle='-.', color=plot_options['color'], alpha=0.5, diff --git a/cirq-core/cirq/vis/state_histogram.py b/cirq-core/cirq/vis/state_histogram.py index 51ccfc5f073..3a3706cf04f 100644 --- a/cirq-core/cirq/vis/state_histogram.py +++ b/cirq-core/cirq/vis/state_histogram.py @@ -14,7 +14,7 @@ """Tool to visualize the results of a study.""" -from typing import Union, Optional, Sequence, SupportsFloat +from typing import cast, Optional, Sequence, SupportsFloat, Union import collections import numpy as np import matplotlib.pyplot as plt @@ -51,13 +51,13 @@ def get_state_histogram(result: 'result.Result') -> np.ndarray: def plot_state_histogram( data: Union['result.Result', collections.Counter, Sequence[SupportsFloat]], - ax: Optional['plt.Axis'] = None, + ax: Optional[plt.Axes] = None, *, tick_label: Optional[Sequence[str]] = None, xlabel: Optional[str] = 'qubit state', ylabel: Optional[str] = 'result count', title: Optional[str] = 'Result State Histogram', -) -> 'plt.Axis': +) -> plt.Axes: """Plot the state histogram from either a single result with repetitions or a histogram computed using `result.histogram()` or a flattened histogram of measurement results computed using `get_state_histogram`. @@ -87,6 +87,7 @@ def plot_state_histogram( show_fig = not ax if not ax: fig, ax = plt.subplots(1, 1) + ax = cast(plt.Axes, ax) if isinstance(data, result.Result): values = get_state_histogram(data) elif isinstance(data, collections.Counter): @@ -96,9 +97,12 @@ def plot_state_histogram( if tick_label is None: tick_label = [str(i) for i in range(len(values))] ax.bar(np.arange(len(values)), values, tick_label=tick_label) - ax.set_xlabel(xlabel) - ax.set_ylabel(ylabel) - ax.set_title(title) + if xlabel: + ax.set_xlabel(xlabel) + if ylabel: + ax.set_ylabel(ylabel) + if title: + ax.set_title(title) if show_fig: fig.show() return ax diff --git a/cirq-core/cirq/vis/state_histogram_test.py b/cirq-core/cirq/vis/state_histogram_test.py index 220030d0e81..a922b12b1ff 100644 --- a/cirq-core/cirq/vis/state_histogram_test.py +++ b/cirq-core/cirq/vis/state_histogram_test.py @@ -78,6 +78,8 @@ def test_plot_state_histogram_result(): for r1, r2 in zip(ax1.get_children(), ax2.get_children()): if isinstance(r1, mpl.patches.Rectangle) and isinstance(r2, mpl.patches.Rectangle): assert str(r1) == str(r2) + # Test default axis + state_histogram.plot_state_histogram(expected_values) @pytest.mark.usefixtures('closefigures') diff --git a/cirq-google/cirq_google/engine/calibration.py b/cirq-google/cirq_google/engine/calibration.py index d28434da6c0..8e0ac4c1560 100644 --- a/cirq-google/cirq_google/engine/calibration.py +++ b/cirq-google/cirq_google/engine/calibration.py @@ -17,7 +17,7 @@ from collections import abc, defaultdict import datetime from itertools import cycle -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Sequence +from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union, Sequence import matplotlib as mpl import matplotlib.pyplot as plt @@ -277,6 +277,7 @@ def plot_histograms( show_plot = not ax if not ax: fig, ax = plt.subplots(1, 1) + ax = cast(plt.Axes, ax) if isinstance(keys, str): keys = [keys] @@ -322,7 +323,7 @@ def plot( show_plot = not fig if not fig: fig = plt.figure() - axs = fig.subplots(1, 2) + axs = cast(List[plt.Axes], fig.subplots(1, 2)) self.heatmap(key).plot(axs[0]) self.plot_histograms(key, axs[1]) if show_plot: diff --git a/cirq-google/cirq_google/engine/stream_manager.py b/cirq-google/cirq_google/engine/stream_manager.py index b5bb5696eda..c45e43d81fc 100644 --- a/cirq-google/cirq_google/engine/stream_manager.py +++ b/cirq-google/cirq_google/engine/stream_manager.py @@ -109,8 +109,6 @@ class StreamManager: def __init__(self, grpc_client: quantum.QuantumEngineServiceAsyncClient): self._grpc_client = grpc_client - # TODO(#5996) Make this local to the asyncio thread. - self._request_queue: Optional[asyncio.Queue] = None # Used to determine whether the stream coroutine is actively running, and provides a way to # cancel it. self._manage_stream_loop_future: Optional[duet.AwaitableFuture[None]] = None @@ -121,6 +119,16 @@ def __init__(self, grpc_client: quantum.QuantumEngineServiceAsyncClient): # interface. self._response_demux = ResponseDemux() self._next_available_message_id = 0 + # Construct queue in AsyncioExecutor to ensure it binds to the correct event loop, since it + # is used by asyncio coroutines. + self._request_queue = self._executor.submit(self._make_request_queue).result() + + async def _make_request_queue(self) -> asyncio.Queue[Optional[quantum.QuantumRunStreamRequest]]: + """Returns a queue used to back the request iterator passed to the stream. + + If `None` is put into the queue, the request iterator will stop. + """ + return asyncio.Queue() def submit( self, project_name: str, program: quantum.QuantumProgram, job: quantum.QuantumJob @@ -153,8 +161,12 @@ def submit( raise ValueError('Program name must be set.') if self._manage_stream_loop_future is None or self._manage_stream_loop_future.done(): - self._manage_stream_loop_future = self._executor.submit(self._manage_stream) - return self._executor.submit(self._manage_execution, project_name, program, job) + self._manage_stream_loop_future = self._executor.submit( + self._manage_stream, self._request_queue + ) + return self._executor.submit( + self._manage_execution, self._request_queue, project_name, program, job + ) def stop(self) -> None: """Closes the open stream and resets all management resources.""" @@ -168,9 +180,9 @@ def stop(self) -> None: def _reset(self): """Resets the manager state.""" - self._request_queue = None self._manage_stream_loop_future = None self._response_demux = ResponseDemux() + self._request_queue = self._executor.submit(self._make_request_queue).result() @property def _executor(self) -> AsyncioExecutor: @@ -178,7 +190,9 @@ def _executor(self) -> AsyncioExecutor: # clients: https://github.com/grpc/grpc/issues/25364. return AsyncioExecutor.instance() - async def _manage_stream(self) -> None: + async def _manage_stream( + self, request_queue: asyncio.Queue[Optional[quantum.QuantumRunStreamRequest]] + ) -> None: """The stream coroutine, an asyncio coroutine to manage QuantumRunStream. This coroutine reads responses from the stream and forwards them to the ResponseDemux, where @@ -187,25 +201,32 @@ async def _manage_stream(self) -> None: When the stream breaks, the stream is reopened, and all execution coroutines are notified. There is at most a single instance of this coroutine running. + + Args: + request_queue: The queue holding requests from the execution coroutine. """ - self._request_queue = asyncio.Queue() while True: try: # The default gRPC client timeout is used. response_iterable = await self._grpc_client.quantum_run_stream( - _request_iterator(self._request_queue) + _request_iterator(request_queue) ) async for response in response_iterable: self._response_demux.publish(response) except asyncio.CancelledError: + await request_queue.put(None) break except BaseException as e: - # TODO(#5996) Close the request iterator to close the existing stream. # Note: the message ID counter is not reset upon a new stream. + await request_queue.put(None) self._response_demux.publish_exception(e) # Raise to all request tasks async def _manage_execution( - self, project_name: str, program: quantum.QuantumProgram, job: quantum.QuantumJob + self, + request_queue: asyncio.Queue[Optional[quantum.QuantumRunStreamRequest]], + project_name: str, + program: quantum.QuantumProgram, + job: quantum.QuantumJob, ) -> Union[quantum.QuantumResult, quantum.QuantumJob]: """The execution coroutine, an asyncio coroutine to manage the lifecycle of a job execution. @@ -216,8 +237,20 @@ async def _manage_execution( error by sending another request. The exact request type depends on the error. There is one execution coroutine per running job submission. + + Args: + request_queue: The queue used to send requests to the stream coroutine. + project_name: The full project ID resource path associated with the job. + program: The Quantum Engine program representing the circuit to be executed. + job: The Quantum Engine job to be executed. + + Raises: + concurrent.futures.CancelledError: if either the request is cancelled or the stream + coroutine is cancelled. + google.api_core.exceptions.GoogleAPICallError: if the stream breaks with a non-retryable + error. + ValueError: if the response is of a type which is not recognized by this client. """ - # Construct requests ahead of time to be reused for retries. create_program_and_job_request = quantum.QuantumRunStreamRequest( parent=project_name, create_quantum_program_and_job=quantum.CreateQuantumProgramAndJobRequest( @@ -225,19 +258,12 @@ async def _manage_execution( ), ) - while self._request_queue is None: - # Wait for the stream coroutine to start. - # Ignoring coverage since this is rarely triggered. - # TODO(#5996) Consider awaiting for the queue to become available, once it is changed - # to be local to the asyncio thread. - await asyncio.sleep(1) # pragma: no cover - current_request = create_program_and_job_request while True: try: current_request.message_id = self._generate_message_id() response_future = self._response_demux.subscribe(current_request.message_id) - await self._request_queue.put(current_request) + await request_queue.put(current_request) response = await response_future # Broken stream @@ -325,16 +351,15 @@ def _is_retryable_error(e: google_exceptions.GoogleAPICallError) -> bool: return any(isinstance(e, exception_type) for exception_type in RETRYABLE_GOOGLE_API_EXCEPTIONS) -# TODO(#5996) Add stop signal to the request iterator. async def _request_iterator( - request_queue: asyncio.Queue, + request_queue: asyncio.Queue[Optional[quantum.QuantumRunStreamRequest]], ) -> AsyncIterator[quantum.QuantumRunStreamRequest]: """The request iterator for Quantum Engine client RPC quantum_run_stream(). Every call to this method generates a new iterator. """ - while True: - yield await request_queue.get() + while request := await request_queue.get(): + yield request def _to_create_job_request( diff --git a/cirq-google/cirq_google/engine/stream_manager_test.py b/cirq-google/cirq_google/engine/stream_manager_test.py index 7b56dcb8bb3..42e6defbcc8 100644 --- a/cirq-google/cirq_google/engine/stream_manager_test.py +++ b/cirq-google/cirq_google/engine/stream_manager_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import AsyncIterable, AsyncIterator, Awaitable, List, Union +from typing import AsyncIterable, AsyncIterator, Awaitable, List, Sequence, Union import asyncio import concurrent from unittest import mock @@ -21,6 +21,7 @@ import pytest import google.api_core.exceptions as google_exceptions +from cirq_google.engine.asyncio_executor import AsyncioExecutor from cirq_google.engine.stream_manager import ( _get_retry_request_or_raise, ProgramAlreadyExistsError, @@ -49,62 +50,130 @@ # StreamManager test suite constants REQUEST_PROJECT_NAME = 'projects/proj' REQUEST_PROGRAM = quantum.QuantumProgram(name='projects/proj/programs/prog') -REQUEST_JOB = quantum.QuantumJob(name='projects/proj/programs/prog/jobs/job0') +REQUEST_JOB0 = quantum.QuantumJob(name='projects/proj/programs/prog/jobs/job0') +REQUEST_JOB1 = quantum.QuantumJob(name='projects/proj/programs/prog/jobs/job1') -def setup_fake_quantum_run_stream_client(client_constructor, responses_and_exceptions): - grpc_client = FakeQuantumRunStream(responses_and_exceptions) - client_constructor.return_value = grpc_client - return grpc_client +def setup_client(client_constructor): + fake_client = FakeQuantumRunStream() + client_constructor.return_value = fake_client + return fake_client + + +def setup(client_constructor): + fake_client = setup_client(client_constructor) + return fake_client, StreamManager(fake_client) class FakeQuantumRunStream: """A fake Quantum Engine client which supports QuantumRunStream and CancelQuantumJob.""" - def __init__( - self, responses_and_exceptions: List[Union[quantum.QuantumRunStreamResponse, BaseException]] - ): - self.stream_requests: List[quantum.QuantumRunStreamRequest] = [] - self.cancel_requests: List[quantum.CancelQuantumJobRequest] = [] - self.responses_and_exceptions = responses_and_exceptions + _REQUEST_STOPPED = 'REQUEST_STOPPED' + + def __init__(self): + self.all_stream_requests: List[quantum.QuantumRunStreamRequest] = [] + self.all_cancel_requests: List[quantum.CancelQuantumJobRequest] = [] + self._executor = AsyncioExecutor.instance() + self._request_buffer = duet.AsyncCollector[quantum.QuantumRunStreamRequest]() + self._request_iterator_stopped = duet.AwaitableFuture() + # asyncio.Queue needs to be initialized inside the asyncio thread because all callers need + # to use the same event loop. + self._responses_and_exceptions_future: duet.AwaitableFuture[ + asyncio.Queue[Union[quantum.QuantumRunStreamResponse, BaseException]] + ] = duet.AwaitableFuture() async def quantum_run_stream( self, requests: AsyncIterator[quantum.QuantumRunStreamRequest], **kwargs ) -> Awaitable[AsyncIterable[quantum.QuantumRunStreamResponse]]: """Fakes the QuantumRunStream RPC. - Expects the number of requests to be the same as len(self.responses_and_exceptions). - - For every request, a response or exception is popped from `self.responses_and_exceptions`. - Before the next request: - * If it is a response, it is sent back through the stream. - * If it is an exception, the exception is raised. + Once a request is received, it is appended to `all_stream_requests`, and the test calling + `wait_for_requests()` is notified. - This fake does not support out-of-order responses. + The response is sent when a test calls `reply()` with a `QuantumRunStreamResponse`. If a + test calls `reply()` with an exception, it is raised here to the `quantum_run_stream()` + caller. - No responses are ever made if `self.responses_and_exceptions` is empty. + This is called from the asyncio thread. """ + responses_and_exceptions: asyncio.Queue[ + Union[quantum.QuantumRunStreamResponse, BaseException] + ] = asyncio.Queue() + self._responses_and_exceptions_future.try_set_result(responses_and_exceptions) - async def run_async_iterator(): + async def read_requests(): async for request in requests: - self.stream_requests.append(request) + self.all_stream_requests.append(request) + self._request_buffer.add(request) + await responses_and_exceptions.put(FakeQuantumRunStream._REQUEST_STOPPED) + self._request_iterator_stopped.try_set_result(None) + + async def response_iterator(): + asyncio.create_task(read_requests()) + while ( + message := await responses_and_exceptions.get() + ) != FakeQuantumRunStream._REQUEST_STOPPED: + if isinstance(message, quantum.QuantumRunStreamResponse): + yield message + else: # isinstance(message, BaseException) + self._responses_and_exceptions_future = duet.AwaitableFuture() + raise message + + return response_iterator() - if not self.responses_and_exceptions: - while True: - await asyncio.sleep(1) - - response_or_exception = self.responses_and_exceptions.pop(0) - if isinstance(response_or_exception, BaseException): - raise response_or_exception - response_or_exception.message_id = request.message_id - yield response_or_exception + async def cancel_quantum_job(self, request: quantum.CancelQuantumJobRequest) -> None: + """Records the cancellation in `cancel_requests`. + This is called from the asyncio thread. + """ + self.all_cancel_requests.append(request) await asyncio.sleep(0) - return run_async_iterator() - async def cancel_quantum_job(self, request: quantum.CancelQuantumJobRequest) -> None: - self.cancel_requests.append(request) - await asyncio.sleep(0) + async def wait_for_requests(self, num_requests=1) -> Sequence[quantum.QuantumRunStreamRequest]: + """Wait til `num_requests` number of requests are received via `quantum_run_stream()`. + + This must be called from the duet thread. + + Returns: + The received requests. + """ + requests = [] + for _ in range(num_requests): + requests.append(await self._request_buffer.__anext__()) + return requests + + async def reply( + self, response_or_exception: Union[quantum.QuantumRunStreamResponse, BaseException] + ): + """Sends a response or raises an exception to the `quantum_run_stream()` caller. + + If input response is missing `message_id`, it is defaulted to the `message_id` of the most + recent request. This is to support the most common use case of responding immediately after + a request. + + Assumes that at least one request must have been submitted to the StreamManager. + + This must be called from the duet thread. + """ + responses_and_exceptions = await self._responses_and_exceptions_future + if ( + isinstance(response_or_exception, quantum.QuantumRunStreamResponse) + and not response_or_exception.message_id + ): + response_or_exception.message_id = self.all_stream_requests[-1].message_id + + async def send(): + await responses_and_exceptions.put(response_or_exception) + + await self._executor.submit(send) + + async def wait_for_request_iterator_stop(self): + """Wait for the request iterator to stop. + + This must be called from a duet thread. + """ + await self._request_iterator_stopped + self._request_iterator_stopped = duet.AwaitableFuture() class TestResponseDemux: @@ -207,49 +276,38 @@ async def test_publish_exception_after_publishing_response_does_not_change_futur class TestStreamManager: @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_submit_expects_result_response(self, client_constructor): + # Arrange + expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') + fake_client, manager = setup(client_constructor) + async def test(): async with duet.timeout_scope(5): - # Arrange - expected_result = quantum.QuantumResult( - parent='projects/proj/programs/prog/jobs/job0' - ) - mock_responses = [quantum.QuantumRunStreamResponse(result=expected_result)] - fake_client = setup_fake_quantum_run_stream_client( - client_constructor, responses_and_exceptions=mock_responses - ) - manager = StreamManager(fake_client) - # Act - actual_result = await manager.submit( - REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB + actual_result_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 ) + await fake_client.wait_for_requests() + await fake_client.reply(quantum.QuantumRunStreamResponse(result=expected_result)) + actual_result = await actual_result_future manager.stop() # Assert assert actual_result == expected_result - assert len(fake_client.stream_requests) == 1 + assert len(fake_client.all_stream_requests) == 1 # assert that the first request is a CreateQuantumProgramAndJobRequest. - assert 'create_quantum_program_and_job' in fake_client.stream_requests[0] + assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[0] duet.run(test) @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_submit_program_without_name_raises(self, client_constructor): + _, manager = setup(client_constructor) + async def test(): async with duet.timeout_scope(5): - # Arrange - expected_result = quantum.QuantumResult( - parent='projects/proj/programs/prog/jobs/job0' - ) - mock_responses = [quantum.QuantumRunStreamResponse(result=expected_result)] - fake_client = setup_fake_quantum_run_stream_client( - client_constructor, responses_and_exceptions=mock_responses - ) - manager = StreamManager(fake_client) - with pytest.raises(ValueError, match='Program name must be set'): await manager.submit( - REQUEST_PROJECT_NAME, quantum.QuantumProgram(), REQUEST_JOB + REQUEST_PROJECT_NAME, quantum.QuantumProgram(), REQUEST_JOB0 ) manager.stop() @@ -257,20 +315,17 @@ async def test(): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_submit_cancel_future_expects_engine_cancellation_rpc_call(self, client_constructor): + fake_client, manager = setup(client_constructor) + async def test(): async with duet.timeout_scope(5): - fake_client = setup_fake_quantum_run_stream_client( - client_constructor, responses_and_exceptions=[] - ) - manager = StreamManager(fake_client) - - result_future = manager.submit(REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB) + result_future = manager.submit(REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0) result_future.cancel() await duet.sleep(1) # Let cancellation complete asynchronously manager.stop() - assert len(fake_client.cancel_requests) == 1 - assert fake_client.cancel_requests[0] == quantum.CancelQuantumJobRequest( + assert len(fake_client.all_cancel_requests) == 1 + assert fake_client.all_cancel_requests[0] == quantum.CancelQuantumJobRequest( name='projects/proj/programs/prog/jobs/job0' ) @@ -280,31 +335,28 @@ async def test(): def test_submit_stream_broken_twice_expects_retry_with_get_quantum_result_twice( self, client_constructor ): + expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') + fake_client, manager = setup(client_constructor) + async def test(): async with duet.timeout_scope(5): - expected_result = quantum.QuantumResult( - parent='projects/proj/programs/prog/jobs/job0' - ) - mock_responses_and_exceptions = [ - google_exceptions.ServiceUnavailable('unavailable'), - google_exceptions.ServiceUnavailable('unavailable'), - quantum.QuantumRunStreamResponse(result=expected_result), - ] - fake_client = setup_fake_quantum_run_stream_client( - client_constructor, responses_and_exceptions=mock_responses_and_exceptions - ) - manager = StreamManager(fake_client) - - actual_result = await manager.submit( - REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB + actual_result_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 ) + await fake_client.wait_for_requests() + await fake_client.reply(google_exceptions.ServiceUnavailable('unavailable')) + await fake_client.wait_for_requests() + await fake_client.reply(google_exceptions.ServiceUnavailable('unavailable')) + await fake_client.wait_for_requests() + await fake_client.reply(quantum.QuantumRunStreamResponse(result=expected_result)) + actual_result = await actual_result_future manager.stop() assert actual_result == expected_result - assert len(fake_client.stream_requests) == 3 - assert 'create_quantum_program_and_job' in fake_client.stream_requests[0] - assert 'get_quantum_result' in fake_client.stream_requests[1] - assert 'get_quantum_result' in fake_client.stream_requests[2] + assert len(fake_client.all_stream_requests) == 3 + assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[0] + assert 'get_quantum_result' in fake_client.all_stream_requests[1] + assert 'get_quantum_result' in fake_client.all_stream_requests[2] duet.run(test) @@ -319,25 +371,24 @@ async def test(): def test_submit_with_retryable_stream_breakage_expects_get_result_request( self, client_constructor, error ): + expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') + fake_client, manager = setup(client_constructor) + async def test(): async with duet.timeout_scope(5): - mock_responses = [ - error, - quantum.QuantumRunStreamResponse( - result=quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') - ), - ] - fake_client = setup_fake_quantum_run_stream_client( - client_constructor, responses_and_exceptions=mock_responses + actual_result_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 ) - manager = StreamManager(fake_client) - - await manager.submit(REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB) + await fake_client.wait_for_requests() + await fake_client.reply(error) + await fake_client.wait_for_requests() + await fake_client.reply(quantum.QuantumRunStreamResponse(result=expected_result)) + await actual_result_future manager.stop() - assert len(fake_client.stream_requests) == 2 - assert 'create_quantum_program_and_job' in fake_client.stream_requests[0] - assert 'get_quantum_result' in fake_client.stream_requests[1] + assert len(fake_client.all_stream_requests) == 2 + assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[0] + assert 'get_quantum_result' in fake_client.all_stream_requests[1] duet.run(test) @@ -360,80 +411,73 @@ async def test(): def test_submit_with_non_retryable_stream_breakage_raises_error( self, client_constructor, error ): + fake_client, manager = setup(client_constructor) + async def test(): async with duet.timeout_scope(5): - mock_responses = [ - error, - quantum.QuantumRunStreamResponse( - result=quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') - ), - ] - fake_client = setup_fake_quantum_run_stream_client( - client_constructor, responses_and_exceptions=mock_responses + actual_result_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 ) - manager = StreamManager(fake_client) - + await fake_client.wait_for_requests() + await fake_client.reply(error) with pytest.raises(type(error)): - await manager.submit(REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB) + await actual_result_future manager.stop() - assert len(fake_client.stream_requests) == 1 - assert 'create_quantum_program_and_job' in fake_client.stream_requests[0] + assert len(fake_client.all_stream_requests) == 1 + assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[0] duet.run(test) @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_submit_expects_job_response(self, client_constructor): + expected_job = quantum.QuantumJob(name='projects/proj/programs/prog/jobs/job0') + fake_client, manager = setup(client_constructor) + async def test(): async with duet.timeout_scope(5): - expected_job = quantum.QuantumJob(name='projects/proj/programs/prog/jobs/job0') - mock_responses = [quantum.QuantumRunStreamResponse(job=expected_job)] - fake_client = setup_fake_quantum_run_stream_client( - client_constructor, responses_and_exceptions=mock_responses - ) - manager = StreamManager(fake_client) - - actual_job = await manager.submit( - REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB + actual_job_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 ) + await fake_client.wait_for_requests() + await fake_client.reply(quantum.QuantumRunStreamResponse(job=expected_job)) + actual_job = await actual_job_future manager.stop() assert actual_job == expected_job - assert len(fake_client.stream_requests) == 1 - # assert that the first request is a CreateQuantumProgramAndJobRequest. - assert 'create_quantum_program_and_job' in fake_client.stream_requests[0] + assert len(fake_client.all_stream_requests) == 1 + assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[0] duet.run(test) @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_submit_job_does_not_exist_expects_create_quantum_job_request(self, client_constructor): + expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') + fake_client, manager = setup(client_constructor) + async def test(): async with duet.timeout_scope(5): - expected_result = quantum.QuantumResult( - parent='projects/proj/programs/prog/jobs/job0' + actual_result_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 ) - mock_responses_and_exceptions = [ - google_exceptions.ServiceUnavailable('unavailable'), + await fake_client.wait_for_requests() + await fake_client.reply(google_exceptions.ServiceUnavailable('unavailable')) + await fake_client.wait_for_requests() + await fake_client.reply( quantum.QuantumRunStreamResponse( error=quantum.StreamError(code=quantum.StreamError.Code.JOB_DOES_NOT_EXIST) - ), - quantum.QuantumRunStreamResponse(result=expected_result), - ] - fake_client = setup_fake_quantum_run_stream_client( - client_constructor, responses_and_exceptions=mock_responses_and_exceptions - ) - manager = StreamManager(fake_client) - - actual_result = await manager.submit( - REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB + ) ) + await fake_client.wait_for_requests() + await fake_client.reply(quantum.QuantumRunStreamResponse(result=expected_result)) + actual_result = await actual_result_future manager.stop() assert actual_result == expected_result - assert len(fake_client.stream_requests) == 3 - assert 'create_quantum_program_and_job' in fake_client.stream_requests[0] - assert 'get_quantum_result' in fake_client.stream_requests[1] - assert 'create_quantum_job' in fake_client.stream_requests[2] + assert len(fake_client.all_stream_requests) == 3 + assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[0] + assert 'get_quantum_result' in fake_client.all_stream_requests[1] + assert 'create_quantum_job' in fake_client.all_stream_requests[2] duet.run(test) @@ -441,39 +485,41 @@ async def test(): def test_submit_program_does_not_exist_expects_create_quantum_program_and_job_request( self, client_constructor ): + expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') + fake_client, manager = setup(client_constructor) + async def test(): async with duet.timeout_scope(5): - expected_result = quantum.QuantumResult( - parent='projects/proj/programs/prog/jobs/job0' + actual_result_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 ) - mock_responses_and_exceptions = [ - google_exceptions.ServiceUnavailable('unavailable'), + await fake_client.wait_for_requests() + await fake_client.reply(google_exceptions.ServiceUnavailable('unavailable')) + await fake_client.wait_for_requests() + await fake_client.reply( quantum.QuantumRunStreamResponse( error=quantum.StreamError(code=quantum.StreamError.Code.JOB_DOES_NOT_EXIST) - ), + ) + ) + await fake_client.wait_for_requests() + await fake_client.reply( quantum.QuantumRunStreamResponse( error=quantum.StreamError( code=quantum.StreamError.Code.PROGRAM_DOES_NOT_EXIST ) - ), - quantum.QuantumRunStreamResponse(result=expected_result), - ] - fake_client = setup_fake_quantum_run_stream_client( - client_constructor, responses_and_exceptions=mock_responses_and_exceptions - ) - manager = StreamManager(fake_client) - - actual_result = await manager.submit( - REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB + ) ) + await fake_client.wait_for_requests() + await fake_client.reply(quantum.QuantumRunStreamResponse(result=expected_result)) + actual_result = await actual_result_future manager.stop() assert actual_result == expected_result - assert len(fake_client.stream_requests) == 4 - assert 'create_quantum_program_and_job' in fake_client.stream_requests[0] - assert 'get_quantum_result' in fake_client.stream_requests[1] - assert 'create_quantum_job' in fake_client.stream_requests[2] - assert 'create_quantum_program_and_job' in fake_client.stream_requests[3] + assert len(fake_client.all_stream_requests) == 4 + assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[0] + assert 'get_quantum_result' in fake_client.all_stream_requests[1] + assert 'create_quantum_job' in fake_client.all_stream_requests[2] + assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[3] duet.run(test) @@ -481,124 +527,129 @@ async def test(): def test_submit_program_already_exists_expects_program_already_exists_error( self, client_constructor ): + fake_client, manager = setup(client_constructor) + async def test(): async with duet.timeout_scope(5): - mock_responses_and_exceptions = [ + actual_result_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 + ) + await fake_client.wait_for_requests() + await fake_client.reply( quantum.QuantumRunStreamResponse( error=quantum.StreamError( code=quantum.StreamError.Code.PROGRAM_ALREADY_EXISTS ) ) - ] - fake_client = setup_fake_quantum_run_stream_client( - client_constructor, responses_and_exceptions=mock_responses_and_exceptions ) - manager = StreamManager(fake_client) - with pytest.raises(ProgramAlreadyExistsError): - await manager.submit(REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB) + await actual_result_future manager.stop() duet.run(test) @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_submit_twice_in_parallel_expect_result_responses(self, client_constructor): + expected_result0 = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') + expected_result1 = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job1') + fake_client, manager = setup(client_constructor) + async def test(): async with duet.timeout_scope(5): - request_job1 = quantum.QuantumJob(name='projects/proj/programs/prog/jobs/job1') - expected_result0 = quantum.QuantumResult( - parent='projects/proj/programs/prog/jobs/job0' + actual_result0_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 ) - expected_result1 = quantum.QuantumResult( - parent='projects/proj/programs/prog/jobs/job1' + actual_result1_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB1 ) - mock_responses = [ - quantum.QuantumRunStreamResponse(result=expected_result0), - quantum.QuantumRunStreamResponse(result=expected_result1), - ] - fake_client = setup_fake_quantum_run_stream_client( - client_constructor, responses_and_exceptions=mock_responses + await fake_client.wait_for_requests(num_requests=2) + await fake_client.reply( + quantum.QuantumRunStreamResponse( + message_id=fake_client.all_stream_requests[0].message_id, + result=expected_result0, + ) + ) + await fake_client.reply( + quantum.QuantumRunStreamResponse( + message_id=fake_client.all_stream_requests[1].message_id, + result=expected_result1, + ) ) - manager = StreamManager(fake_client) + actual_result1 = await actual_result1_future + actual_result0 = await actual_result0_future + manager.stop() + + assert actual_result0 == expected_result0 + assert actual_result1 == expected_result1 + assert len(fake_client.all_stream_requests) == 2 + assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[0] + assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[1] + + duet.run(test) + + @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) + def test_submit_twice_and_break_stream_expect_result_responses(self, client_constructor): + expected_result0 = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') + expected_result1 = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job1') + fake_client, manager = setup(client_constructor) + async def test(): + async with duet.timeout_scope(5): actual_result0_future = manager.submit( - REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 ) actual_result1_future = manager.submit( - REQUEST_PROJECT_NAME, REQUEST_PROGRAM, request_job1 + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB1 + ) + await fake_client.wait_for_requests(num_requests=2) + await fake_client.reply(google_exceptions.ServiceUnavailable('unavailable')) + await fake_client.wait_for_requests(num_requests=2) + await fake_client.reply( + quantum.QuantumRunStreamResponse( + message_id=next( + req.message_id + for req in fake_client.all_stream_requests[2:] + if req.get_quantum_result.parent == expected_result0.parent + ), + result=expected_result0, + ) + ) + await fake_client.reply( + quantum.QuantumRunStreamResponse( + message_id=next( + req.message_id + for req in fake_client.all_stream_requests[2:] + if req.get_quantum_result.parent == expected_result1.parent + ), + result=expected_result1, + ) ) - actual_result1 = await actual_result1_future actual_result0 = await actual_result0_future + actual_result1 = await actual_result1_future manager.stop() assert actual_result0 == expected_result0 assert actual_result1 == expected_result1 - assert len(fake_client.stream_requests) == 2 - assert 'create_quantum_program_and_job' in fake_client.stream_requests[0] - assert 'create_quantum_program_and_job' in fake_client.stream_requests[1] + assert len(fake_client.all_stream_requests) == 4 + assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[0] + assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[1] + assert 'get_quantum_result' in fake_client.all_stream_requests[2] + assert 'get_quantum_result' in fake_client.all_stream_requests[3] duet.run(test) - # TODO(#5996) Update fake client implementation to support this test case. - # @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) - # def test_submit_twice_and_break_stream_expect_result_responses(self, client_constructor): - # async def test(): - # async with duet.timeout_scope(5): - # request_job1 = quantum.QuantumJob(name='projects/proj/programs/prog/jobs/job1') - # expected_result0 = quantum.QuantumResult( - # parent='projects/proj/programs/prog/jobs/job0' - # ) - # expected_result1 = quantum.QuantumResult( - # parent='projects/proj/programs/prog/jobs/job1' - # ) - # # TODO the current fake client doesn't have the response timing flexibility - # # required by this test. - # # Ideally, the client raises ServiceUnavailable after both initial requests are - # # sent. - # mock_responses = [ - # google_exceptions.ServiceUnavailable('unavailable'), - # google_exceptions.ServiceUnavailable('unavailable'), - # quantum.QuantumRunStreamResponse(result=expected_result0), - # quantum.QuantumRunStreamResponse(result=expected_result1), - # ] - # fake_client = setup_fake_quantum_run_stream_client( - # client_constructor, responses_and_exceptions=mock_responses - # ) - # manager = StreamManager(fake_client) - - # actual_result0_future = manager.submit( - # REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB - # ) - # actual_result1_future = manager.submit( - # REQUEST_PROJECT_NAME, REQUEST_PROGRAM, request_job1 - # ) - # actual_result1 = await actual_result1_future - # actual_result0 = await actual_result0_future - # manager.stop() - - # assert actual_result0 == expected_result0 - # assert actual_result1 == expected_result1 - # assert len(fake_client.stream_requests) == 2 - # assert 'create_quantum_program_and_job' in fake_client.stream_requests[0] - # assert 'create_quantum_program_and_job' in fake_client.stream_requests[1] - - # duet.run(test) - @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_stop_cancels_existing_sends(self, client_constructor): + fake_client, manager = setup(client_constructor) + async def test(): async with duet.timeout_scope(5): - fake_client = setup_fake_quantum_run_stream_client( - client_constructor, responses_and_exceptions=[] - ) - manager = StreamManager(fake_client) - actual_result_future = manager.submit( - REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 ) # Wait for the manager to submit a request. If request submission runs after stop(), # it will start the manager again and the test will block waiting for a response. - await duet.sleep(1) + await fake_client.wait_for_requests() manager.stop() with pytest.raises(concurrent.futures.CancelledError): @@ -609,28 +660,24 @@ async def test(): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_stop_then_send_expects_result_response(self, client_constructor): """New requests should work after stopping the manager.""" + expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') + fake_client, manager = setup(client_constructor) async def test(): async with duet.timeout_scope(5): - expected_result = quantum.QuantumResult( - parent='projects/proj/programs/prog/jobs/job0' - ) - mock_responses = [quantum.QuantumRunStreamResponse(result=expected_result)] - fake_client = setup_fake_quantum_run_stream_client( - client_constructor, responses_and_exceptions=mock_responses - ) - manager = StreamManager(fake_client) - manager.stop() - actual_result = await manager.submit( - REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB + actual_result_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 ) + await fake_client.wait_for_requests() + await fake_client.reply(quantum.QuantumRunStreamResponse(result=expected_result)) + actual_result = await actual_result_future manager.stop() assert actual_result == expected_result - assert len(fake_client.stream_requests) == 1 + assert len(fake_client.all_stream_requests) == 1 # assert that the first request is a CreateQuantumProgramAndJobRequest. - assert 'create_quantum_program_and_job' in fake_client.stream_requests[0] + assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[0] duet.run(test) @@ -674,3 +721,91 @@ def test_get_retry_request_or_raise_expects_stream_error( create_quantum_program_and_job_request, create_quantum_job_request, ) + + @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) + def test_broken_stream_stops_request_iterator(self, client_constructor): + expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') + fake_client, manager = setup(client_constructor) + + async def test(): + async with duet.timeout_scope(5): + actual_result_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 + ) + await fake_client.wait_for_requests() + await fake_client.reply( + quantum.QuantumRunStreamResponse( + message_id=fake_client.all_stream_requests[0].message_id, + result=expected_result, + ) + ) + await actual_result_future + await fake_client.reply(google_exceptions.ServiceUnavailable('service unavailable')) + await fake_client.wait_for_request_iterator_stop() + manager.stop() + + duet.run(test) + + @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) + def test_stop_stops_request_iterator(self, client_constructor): + expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') + fake_client, manager = setup(client_constructor) + + async def test(): + async with duet.timeout_scope(5): + actual_result_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 + ) + await fake_client.wait_for_requests() + await fake_client.reply( + quantum.QuantumRunStreamResponse( + message_id=fake_client.all_stream_requests[0].message_id, + result=expected_result, + ) + ) + await actual_result_future + manager.stop() + await fake_client.wait_for_request_iterator_stop() + + duet.run(test) + + @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) + def test_submit_after_stream_breakage(self, client_constructor): + expected_result0 = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') + expected_result1 = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job1') + fake_client, manager = setup(client_constructor) + + async def test(): + async with duet.timeout_scope(5): + actual_result0_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 + ) + await fake_client.wait_for_requests() + await fake_client.reply( + quantum.QuantumRunStreamResponse( + message_id=fake_client.all_stream_requests[0].message_id, + result=expected_result0, + ) + ) + actual_result0 = await actual_result0_future + await fake_client.reply(google_exceptions.ServiceUnavailable('service unavailable')) + actual_result1_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 + ) + await fake_client.wait_for_requests() + await fake_client.reply( + quantum.QuantumRunStreamResponse( + message_id=fake_client.all_stream_requests[1].message_id, + result=expected_result1, + ) + ) + actual_result1 = await actual_result1_future + manager.stop() + + assert len(fake_client.all_stream_requests) == 2 + assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[0] + assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[1] + assert actual_result0 == expected_result0 + assert actual_result1 == expected_result1 + + duet.run(test) diff --git a/dev_tools/codeowners_test.py b/dev_tools/codeowners_test.py index b0e788371f5..64c6786acb9 100644 --- a/dev_tools/codeowners_test.py +++ b/dev_tools/codeowners_test.py @@ -26,7 +26,8 @@ GOOGLE_MAINTAINERS = BASE_MAINTAINERS.union(GOOGLE_TEAM) IONQ_TEAM = { - ('USERNAME', u) for u in ["@dabacon", "@ColemanCollins", "@nakardo", "@gmauricio", "@Cynocracy"] + ('USERNAME', u) + for u in ["@dabacon", "@ColemanCollins", "@nakardo", "@gmauricio", "@Cynocracy", "@splch"] } IONQ_MAINTAINERS = BASE_MAINTAINERS.union(IONQ_TEAM) diff --git a/docs/experiments/textbook_algorithms.ipynb b/docs/experiments/textbook_algorithms.ipynb index 182a91e5ff2..9bec52408b1 100644 --- a/docs/experiments/textbook_algorithms.ipynb +++ b/docs/experiments/textbook_algorithms.ipynb @@ -1010,7 +1010,7 @@ "outputs": [], "source": [ "\"\"\"Plot the results.\"\"\"\n", - "plt.style.use(\"seaborn-whitegrid\")\n", + "plt.style.use(\"seaborn-v0_8-whitegrid\")\n", "\n", "plt.plot(nvals, estimates, \"--o\", label=\"Phase estimation\")\n", "plt.axhline(theta, label=\"True value\", color=\"black\")\n", diff --git a/docs/start/intro.ipynb b/docs/start/intro.ipynb index 6929b08fce3..42599d0cfe2 100644 --- a/docs/start/intro.ipynb +++ b/docs/start/intro.ipynb @@ -1453,7 +1453,7 @@ " probs.append(prob[0])\n", "\n", "# Plot the probability of the ground state at each simulation step.\n", - "plt.style.use('seaborn-whitegrid')\n", + "plt.style.use('seaborn-v0_8-whitegrid')\n", "plt.plot(probs, 'o')\n", "plt.xlabel(\"Step\")\n", "plt.ylabel(\"Probability of ground state\");" @@ -1490,7 +1490,7 @@ "\n", "\n", "# Plot the probability of the ground state at each simulation step.\n", - "plt.style.use('seaborn-whitegrid')\n", + "plt.style.use('seaborn-v0_8-whitegrid')\n", "plt.plot(sampled_probs, 'o')\n", "plt.xlabel(\"Step\")\n", "plt.ylabel(\"Probability of ground state\");" diff --git a/examples/two_qubit_gate_compilation.py b/examples/two_qubit_gate_compilation.py index 2dd1a9e3260..9362ce9c12c 100644 --- a/examples/two_qubit_gate_compilation.py +++ b/examples/two_qubit_gate_compilation.py @@ -88,7 +88,7 @@ def main(samples: int = 1000, max_infidelity: float = 0.01): print(f'Maximum infidelity of "failed" compilation: {np.max(failed_infidelities_arr)}') plt.figure() - plt.hist(infidelities_arr, bins=25, range=[0, max_infidelity * 1.1]) + plt.hist(infidelities_arr, bins=25, range=(0.0, max_infidelity * 1.1)) # pragma: no cover ylim = plt.ylim() plt.plot([max_infidelity] * 2, ylim, '--', label='Maximum tabulation infidelity') plt.xlabel('Compiled gate infidelity vs target')