diff --git a/tests/test_traitlets.py b/tests/test_traitlets.py index f9f623b4..665422cd 100644 --- a/tests/test_traitlets.py +++ b/tests/test_traitlets.py @@ -7,6 +7,7 @@ # also under the terms of the Modified BSD License. from __future__ import annotations +import decimal import pickle import re import typing as t @@ -1391,7 +1392,7 @@ class TestLong(TraitTestBase): obj = LongTrait() _default_value = 99 - _good_values = [10, -10] + _good_values = [10, -10, 10.0, decimal.Decimal("10.0")] _bad_values = [ "ten", [10], @@ -1401,6 +1402,7 @@ class TestLong(TraitTestBase): 1j, 10.1, -10.1, + decimal.Decimal("10.1"), "10", "-10", "10L", diff --git a/traitlets/traitlets.py b/traitlets/traitlets.py index a9d02c72..c8c9d399 100644 --- a/traitlets/traitlets.py +++ b/traitlets/traitlets.py @@ -44,6 +44,7 @@ import contextlib import enum import inspect +import numbers import os import re import sys @@ -2646,6 +2647,15 @@ def __init__( ) def validate(self, obj: t.Any, value: t.Any) -> G: + if not isinstance(value, int) and isinstance(value, numbers.Number): + # allow casting integer-valued numbers to int + # allows for more concise assignment like `4e9` which is a float + try: + int_value = int(value) + if int_value == value: + value = int_value + except Exception: + pass if not isinstance(value, int): self.error(obj, value) return _validate_bounds(self, obj, value) # type:ignore[no-any-return]