From 7774693e8c4b82676932ec23886f8c15152b53a8 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Tue, 23 Apr 2024 17:23:23 +0200 Subject: [PATCH] Make `vector` package optional (#110) * Make `vector` package optional * FIX: import `vector` inline --- phasespace/phasespace.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/phasespace/phasespace.py b/phasespace/phasespace.py index da9cbce..e73dec3 100644 --- a/phasespace/phasespace.py +++ b/phasespace/phasespace.py @@ -17,16 +17,20 @@ import inspect from collections.abc import Callable from math import pi +from typing import TYPE_CHECKING, NoReturn import numpy as np import tensorflow as tf import tensorflow.experimental.numpy as tnp -import vector from . import kinematics as kin from .backend import function, function_jit_fixedshape from .random import SeedLike, get_rng +if TYPE_CHECKING: + import vector + + RELAX_SHAPES = False @@ -669,6 +673,10 @@ def generate( """ rng = get_rng(seed) if boost_to is not None: + try: + import vector + except ImportError as error: + _raise_missing_vector_package(error) if isinstance(boost_to, vector.Vector): if not ( isinstance(boost_to, vector.Momentum) @@ -815,8 +823,18 @@ def to_vectors(particles: dict[str, tf.Tensor]) -> dict[str, vector.Momentum]: Return: dict: Dictionary of `vector.Momentum` instances with numpy arrays """ + try: + import vector + except ImportError as error: + _raise_missing_vector_package(error) newparticles = {} for name, particle in particles.items(): px, py, pz, e = np.moveaxis(particle, -1, 0) # numpy "unstack" newparticles[name] = vector.array(dict(px=px, py=py, pz=pz, energy=e)) return newparticles + + +def _raise_missing_vector_package(exception: ImportError) -> NoReturn: + raise ImportError( + "To use `boost_to`, the `vector` package has to be installed." + ) from exception