15
15
"""Commonly repeated jax expressions."""
16
16
17
17
import contextlib
18
- import dataclasses
19
18
import functools
20
19
import os
21
- from typing import Any , Callable , Optional , TypeVar , Union
20
+ from typing import Any , Callable , Optional , TypeVar
22
21
import chex
23
22
import equinox as eqx
24
23
import jax
@@ -155,39 +154,6 @@ def error_if_negative(
155
154
return error_if (to_wrap , min_var < 0 , msg )
156
155
157
156
158
- def jax_default (value : chex .Numeric ) -> ...:
159
- """Define a dataclass field with a jax-type default value.
160
-
161
- Args:
162
- value: The default value of the field.
163
-
164
- Returns:
165
- field: The dataclass field.
166
- """
167
- jax_value = lambda : jnp .array (value )
168
- return dataclasses .field (default_factory = jax_value )
169
-
170
-
171
- def compat_linspace (
172
- start : Union [chex .Numeric , jax .Array ], stop : jax .Array , num : jax .Array
173
- ) -> jax .Array :
174
- """See np.linspace.
175
-
176
- This implementation of a subset of the linspace API reproduces the
177
- output of numpy better (at least when run in float64 mode) than
178
- jnp.linspace does.
179
-
180
- Args:
181
- start: first value
182
- stop: last value
183
- num: Number of points in the series
184
-
185
- Returns:
186
- linspace: array of shape (num) increasing linearly from `start` to `stop`
187
- """
188
- return jnp .arange (num ) * ((stop - start ) / (num - 1 )) + start
189
-
190
-
191
157
def assert_rank (
192
158
inputs : chex .Numeric | jax .stages .ArgInfo ,
193
159
rank : int ,
@@ -199,34 +165,6 @@ def assert_rank(
199
165
chex .assert_rank (inputs , rank )
200
166
201
167
202
- def select (
203
- cond : jax .Array | bool ,
204
- true_val : jax .Array ,
205
- false_val : jax .Array ,
206
- ) -> jax .Array :
207
- """Wrapper around jnp.where for readability."""
208
- return jnp .where (cond , true_val , false_val )
209
-
210
-
211
- def is_tracer (var : jax .Array ) -> bool :
212
- """Checks whether `var` is a jax tracer.
213
-
214
- Args:
215
- var: The jax variable to inspect.
216
-
217
- Returns:
218
- output: True `var` is a tracer, False if concrete.
219
- """
220
-
221
- try :
222
- if var .sum () > 0 :
223
- return False
224
- return False
225
- except jax .errors .TracerBoolConversionError :
226
- return True
227
- assert False # Should be unreachable
228
-
229
-
230
168
def jit (* args , ** kwargs ) -> Callable [..., Any ]:
231
169
"""Calls jax.jit if TORAX_COMPILATION_ENABLED is True, otherwise no-op."""
232
170
if env_bool ('TORAX_COMPILATION_ENABLED' , True ):
0 commit comments