5
5
# LICENSE file in the root directory of this source tree.
6
6
7
7
import math
8
- import random
9
8
from typing import Any , List , Optional , Set , Union
10
9
10
+ from inputgen .utils .random_manager import random_manager as rm
11
11
from inputgen .variable .constants import BOUND_ON_INF , INT64_MAX , INT64_MIN
12
12
from inputgen .variable .space import Interval , Intervals , VariableSpace
13
+ from inputgen .variable .type import sort_values_of_type
13
14
from inputgen .variable .utils import nextdown , nextup
14
15
15
16
@@ -51,7 +52,7 @@ def gen_float_from_interval(r: Interval) -> Optional[float]:
51
52
elif lower > upper :
52
53
return None
53
54
else :
54
- return random .uniform (lower , upper )
55
+ return rm . get_random () .uniform (lower , upper )
55
56
56
57
57
58
def gen_min_float_from_intervals (rs : Intervals ) -> Optional [float ]:
@@ -69,7 +70,7 @@ def gen_max_float_from_intervals(rs: Intervals) -> Optional[float]:
69
70
def gen_float_from_intervals (rs : Intervals ) -> Optional [float ]:
70
71
if rs .empty ():
71
72
return None
72
- r = random .choice (rs .intervals )
73
+ r = rm . get_random () .choice (rs .intervals )
73
74
return gen_float_from_interval (r )
74
75
75
76
@@ -112,7 +113,7 @@ def gen_int_from_interval(r: Interval) -> Optional[int]:
112
113
elif upper is None :
113
114
upper = max (lower , 0 ) + BOUND_ON_INF
114
115
assert lower is not None and upper is not None
115
- return random .randint (lower , upper )
116
+ return rm . get_random () .randint (lower , upper )
116
117
117
118
118
119
def gen_min_int_from_intervals (rs : Intervals ) -> Optional [int ]:
@@ -133,7 +134,7 @@ def gen_int_from_intervals(rs: Intervals) -> Optional[int]:
133
134
intervals_with_ints = [r for r in rs .intervals if r .contains_int ()]
134
135
if len (intervals_with_ints ) == 0 :
135
136
return None
136
- r = random .choice (intervals_with_ints )
137
+ r = rm . get_random () .choice (intervals_with_ints )
137
138
return gen_int_from_interval (r )
138
139
139
140
@@ -147,6 +148,12 @@ def __init__(self, space: VariableSpace):
147
148
self .vtype = space .vtype
148
149
self .space = space
149
150
151
+ def _sorted (self , values : Set [Any ]) -> List [Any ]:
152
+ return sort_values_of_type (self .vtype , values )
153
+
154
+ def _sample (self , values : Set [Any ], num : int ) -> List [Any ]:
155
+ return rm .get_random ().sample (self ._sorted (values ), num )
156
+
150
157
def gen_min (self ) -> Any :
151
158
"""Returns the minimum value of the space."""
152
159
if self .space .empty () or self .vtype not in [bool , int , float ]:
@@ -221,7 +228,7 @@ def gen_edges_non_extreme(self, num: int = 2) -> Set[Any]:
221
228
edges_not_extreme = self .gen_edges () - self .gen_extremes ()
222
229
if num >= len (edges_not_extreme ):
223
230
return edges_not_extreme
224
- return set (random . sample ( list ( edges_not_extreme ) , num ))
231
+ return set (self . _sample ( edges_not_extreme , num ))
225
232
226
233
def gen_non_edges (self , num : int = 2 ) -> Set [Any ]:
227
234
"""Generates non-edge (or interior) values of the space."""
@@ -232,7 +239,7 @@ def gen_non_edges(self, num: int = 2) -> Set[Any]:
232
239
if self .space .discrete .initialized :
233
240
vals = self .space .discrete .values - edge_or_extreme_vals
234
241
if num < len (vals ):
235
- vals = set (random . sample ( list ( vals ) , num ))
242
+ vals = set (self . _sample ( vals , num ))
236
243
else :
237
244
for _ in range (100 ):
238
245
v : Optional [Union [int , float ]] = None
@@ -269,11 +276,8 @@ def gen_balanced(self, num: int = 6) -> Set[Any]:
269
276
270
277
if num >= len (balanced ):
271
278
return balanced
272
- return set (random . sample ( list ( balanced ) , num ))
279
+ return set (self . _sample ( balanced , num ))
273
280
274
281
def gen (self , num : int = 6 ) -> List [Any ]:
275
282
"""Generates a sorted (if applicable), balanced sample of the space."""
276
- vals = list (self .gen_balanced (num ))
277
- if self .vtype in [bool , int , float , str ]:
278
- return sorted (vals )
279
- return vals
283
+ return sort_values_of_type (self .vtype , self .gen_balanced (num ))
0 commit comments