Skip to content

Commit 70a1fe8

Browse files
ianyfanThomas Hoffmannnikhilkhatri
committed
Release version 0.2.8
Co-authored-by: Thomas Hoffmann <[email protected]> Co-authored-by: Nikhil Khatri <[email protected]>
1 parent c4e361a commit 70a1fe8

File tree

14 files changed

+70
-40
lines changed

14 files changed

+70
-40
lines changed

.github/workflows/build_test.yml

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ jobs:
2121
outputs:
2222
error-check: ${{ steps.error-check.conclusion }}
2323
steps:
24-
- uses: actions/checkout@v2
24+
- uses: actions/checkout@v3
2525
- name: Setup Python ${{ matrix.python-version }}
26-
uses: actions/setup-python@v2
26+
uses: actions/setup-python@v4
2727
with:
2828
python-version: ${{ matrix.python-version }}
2929
- name: Install linter
@@ -49,25 +49,14 @@ jobs:
4949
matrix:
5050
python-version: [ 3.8, 3.9, "3.10" ]
5151
steps:
52-
- uses: actions/checkout@v2
52+
- uses: actions/checkout@v3
5353
- name: Setup Python ${{ matrix.python-version }}
54-
uses: actions/setup-python@v2
54+
uses: actions/setup-python@v4
5555
with:
5656
python-version: ${{ matrix.python-version }}
57-
- name: Locate pip cache
58-
id: loc-pip-cache
59-
run: echo "::set-output name=dir::$(pip cache dir)"
60-
- name: Restore pip dependencies from cache
61-
uses: actions/cache@v2
62-
with:
63-
path: ${{ steps.loc-pip-cache.outputs.dir }}
64-
key: build_and_test-${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('setup.cfg') }}
65-
restore-keys: |
66-
build_and_test-${{ runner.os }}-pip-${{ matrix.python-version }}-
67-
build_and_test-${{ runner.os }}-pip-
68-
- name: Install DisCoPy from GitHub
57+
- name: Install DisCoPy 0.5 from GitHub
6958
if: github.ref_name != 'release' && github.ref_name != 'beta'
70-
run: pip install git+https://github.com/oxford-quantum-group/discopy
59+
run: pip install git+https://github.com/discopy/discopy@0.5
7160
- name: Install base package
7261
run: pip install .
7362
- name: Check package import works
@@ -130,9 +119,9 @@ jobs:
130119
matrix:
131120
python-version: [ 3.8, 3.9, "3.10" ]
132121
steps:
133-
- uses: actions/checkout@v2
122+
- uses: actions/checkout@v3
134123
- name: Setup Python ${{ matrix.python-version }}
135-
uses: actions/setup-python@v2
124+
uses: actions/setup-python@v4
136125
with:
137126
python-version: ${{ matrix.python-version }}
138127
- name: Install dependencies with type hints

.github/workflows/docs.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@ jobs:
1616
name: Build and deploy documentation
1717
runs-on: ubuntu-latest
1818
steps:
19-
- uses: actions/checkout@v2
19+
- uses: actions/checkout@v3
2020
with:
2121
fetch-depth: 0 # fetches tags, required for version info
2222
- name: Set up Python
23-
uses: actions/setup-python@v2
23+
uses: actions/setup-python@v4
2424
with:
2525
python-version: 3.8
2626
- name: Build lambeq

docs/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
]
4747

4848
intersphinx_mapping = {
49-
'discopy': ("https://discopy.readthedocs.io/en/main/", None),
49+
'discopy': ("https://discopy.readthedocs.io/en/0.5/", None),
5050
'pennylane': ("https://pennylane.readthedocs.io/en/stable/", None),
5151
}
5252

docs/release_notes.rst

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,20 @@
33
Release notes
44
=============
55

6+
.. _rel-0.2.8:
7+
8+
`0.2.8 <https://github.com/CQCL/lambeq/releases/tag/0.2.8>`_
9+
------------------------------------------------------------
10+
Changed:
11+
12+
- Improved the performance of :py:class:`.NumpyModel` when using Jax JIT-compilation.
13+
- Dependencies: pinned the required version of DisCoPy to 0.5.X.
14+
15+
Fixed:
16+
17+
- Fixed incorrectly scaled validation loss in progress bar during model training.
18+
- Fixed symbol type mismatch in the quantum models when a circuit was previously converted to tket.
19+
620
.. _rel-0.2.7:
721

822
`0.2.7 <https://github.com/CQCL/lambeq/releases/tag/0.2.7>`_

lambeq/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,10 @@
1717
'__version_info__',
1818

1919
'ansatz',
20-
'text2diagram',
2120
'core',
2221
'pregroups',
23-
'reader',
2422
'rewrite',
23+
'text2diagram',
2524
'tokeniser',
2625
'training',
2726

lambeq/ansatz/circuit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@
3434
from discopy.quantum.gates import Bra, H, Ket, Rx, Ry, Rz
3535
from discopy.rigid import Box, Diagram, Ty
3636
import numpy as np
37-
from sympy import symbols
37+
from sympy import Symbol, symbols
3838

39-
from lambeq.ansatz import BaseAnsatz, Symbol
39+
from lambeq.ansatz import BaseAnsatz
4040

4141
computational_basis = Id(qubit)
4242

lambeq/text2diagram/ccg_parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from typing import Optional
2323

2424
from discopy import Diagram
25-
from tqdm.autonotebook import tqdm
25+
from tqdm.auto import tqdm
2626

2727
from lambeq.core.globals import VerbosityLevel
2828
from lambeq.core.utils import (SentenceBatchType, SentenceType,

lambeq/training/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from typing import Any, Union
2727

2828
from discopy.tensor import Diagram
29-
from sympy import default_sort_key
29+
from sympy import default_sort_key, Symbol as SymPySymbol
3030

3131
from lambeq.ansatz.base import Symbol
3232
from lambeq.training.checkpoint import Checkpoint
@@ -50,7 +50,7 @@ class Model(ABC):
5050

5151
def __init__(self) -> None:
5252
"""Initialise an instance of :py:class:`Model` base class."""
53-
self.symbols: list[Symbol] = []
53+
self.symbols: list[Union[Symbol, SymPySymbol]] = []
5454
self.weights: Collection = []
5555

5656
def __call__(self, *args: Any, **kwds: Any) -> Any:

lambeq/training/numpy_model.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,19 @@
2828

2929
from collections.abc import Callable, Iterable
3030
import pickle
31-
from typing import Any
31+
from typing import Any, TYPE_CHECKING, Union
3232

3333
from discopy import Tensor
3434
from discopy.tensor import Diagram
3535
import numpy
3636
from numpy.typing import ArrayLike
3737
from sympy import lambdify
3838

39+
40+
if TYPE_CHECKING:
41+
from jax import numpy as jnp
42+
43+
3944
from lambeq.training.quantum_model import QuantumModel
4045

4146

@@ -74,7 +79,7 @@ def _get_lambda(self, diagram: Diagram) -> Callable[[Any], Any]:
7479
if diagram in self.lambdas:
7580
return self.lambdas[diagram]
7681

77-
def diagram_output(*x: ArrayLike) -> ArrayLike:
82+
def diagram_output(x: Iterable[ArrayLike]) -> ArrayLike:
7883
with Tensor.backend('jax'), tn.DefaultBackend('jax'):
7984
sub_circuit = self._fast_subs([diagram], x)[0]
8085
result = tn.contractors.auto(*sub_circuit.to_tn()).tensor
@@ -112,7 +117,9 @@ def _fast_subs(self,
112117
b._phase = b._data
113118
return diagrams
114119

115-
def get_diagram_output(self, diagrams: list[Diagram]) -> numpy.ndarray:
120+
def get_diagram_output(self,
121+
diagrams: list[Diagram]) -> Union[jnp.ndarray,
122+
numpy.ndarray]:
116123
"""Return the exact prediction for each diagram.
117124
118125
Parameters
@@ -142,9 +149,13 @@ def get_diagram_output(self, diagrams: list[Diagram]) -> numpy.ndarray:
142149
'from pre-trained checkpoint.')
143150

144151
if self.use_jit:
152+
from jax import numpy as jnp
153+
145154
lambdified_diagrams = [self._get_lambda(d) for d in diagrams]
146-
return numpy.array([diag_f(*self.weights)
147-
for diag_f in lambdified_diagrams])
155+
res: jnp.ndarray = jnp.array([diag_f(self.weights)
156+
for diag_f in lambdified_diagrams])
157+
158+
return res
148159

149160
diagrams = self._fast_subs(diagrams, self.weights)
150161
with Tensor.backend('numpy'):

lambeq/training/pytorch_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ class PytorchModel(Model, torch.nn.Module):
3535
"""A lambeq model for the classical pipeline using PyTorch."""
3636

3737
weights: torch.nn.ParameterList # type: ignore[assignment]
38+
symbols: list[Symbol] # type: ignore[assignment]
3839

3940
def __init__(self) -> None:
4041
"""Initialise a PytorchModel."""

0 commit comments

Comments
 (0)