Skip to content

Commit e2c1f2d

Browse files
Nush395Torax team
authored andcommitted
Remove unused jax_utils functions.
PiperOrigin-RevId: 742877934
1 parent 0653d5a commit e2c1f2d

File tree

1 file changed

+1
-63
lines changed

1 file changed

+1
-63
lines changed

torax/jax_utils.py

Lines changed: 1 addition & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,9 @@
1515
"""Commonly repeated jax expressions."""
1616

1717
import contextlib
18-
import dataclasses
1918
import functools
2019
import os
21-
from typing import Any, Callable, Optional, TypeVar, Union
20+
from typing import Any, Callable, Optional, TypeVar
2221
import chex
2322
import equinox as eqx
2423
import jax
@@ -155,39 +154,6 @@ def error_if_negative(
155154
return error_if(to_wrap, min_var < 0, msg)
156155

157156

158-
def jax_default(value: chex.Numeric) -> ...:
159-
"""Define a dataclass field with a jax-type default value.
160-
161-
Args:
162-
value: The default value of the field.
163-
164-
Returns:
165-
field: The dataclass field.
166-
"""
167-
jax_value = lambda: jnp.array(value)
168-
return dataclasses.field(default_factory=jax_value)
169-
170-
171-
def compat_linspace(
172-
start: Union[chex.Numeric, jax.Array], stop: jax.Array, num: jax.Array
173-
) -> jax.Array:
174-
"""See np.linspace.
175-
176-
This implementation of a subset of the linspace API reproduces the
177-
output of numpy better (at least when run in float64 mode) than
178-
jnp.linspace does.
179-
180-
Args:
181-
start: first value
182-
stop: last value
183-
num: Number of points in the series
184-
185-
Returns:
186-
linspace: array of shape (num) increasing linearly from `start` to `stop`
187-
"""
188-
return jnp.arange(num) * ((stop - start) / (num - 1)) + start
189-
190-
191157
def assert_rank(
192158
inputs: chex.Numeric | jax.stages.ArgInfo,
193159
rank: int,
@@ -199,34 +165,6 @@ def assert_rank(
199165
chex.assert_rank(inputs, rank)
200166

201167

202-
def select(
203-
cond: jax.Array | bool,
204-
true_val: jax.Array,
205-
false_val: jax.Array,
206-
) -> jax.Array:
207-
"""Wrapper around jnp.where for readability."""
208-
return jnp.where(cond, true_val, false_val)
209-
210-
211-
def is_tracer(var: jax.Array) -> bool:
212-
"""Checks whether `var` is a jax tracer.
213-
214-
Args:
215-
var: The jax variable to inspect.
216-
217-
Returns:
218-
output: True `var` is a tracer, False if concrete.
219-
"""
220-
221-
try:
222-
if var.sum() > 0:
223-
return False
224-
return False
225-
except jax.errors.TracerBoolConversionError:
226-
return True
227-
assert False # Should be unreachable
228-
229-
230168
def jit(*args, **kwargs) -> Callable[..., Any]:
231169
"""Calls jax.jit if TORAX_COMPILATION_ENABLED is True, otherwise no-op."""
232170
if env_bool('TORAX_COMPILATION_ENABLED', True):

0 commit comments

Comments
 (0)