Skip to content

Commit

Permalink
added a function
Browse files Browse the repository at this point in the history
  • Loading branch information
Aarsh2001 committed Nov 4, 2023
1 parent 687fb7c commit a7444ed
Showing 1 changed file with 110 additions and 1 deletion.
111 changes: 110 additions & 1 deletion dummy-files/test.py
Original file line number Diff line number Diff line change
@@ -1 +1,110 @@
1234
# global
from numbers import Number
import numpy as np
from typing import Union, Optional, List, Sequence, Tuple

import jax.dlpack
import jax.numpy as jnp
import jax._src as _src
import jaxlib.xla_extension
import tensorflow as tf

# local
import ivy
from ivy import as_native_dtype
from ivy.functional.backends.jax import JaxArray
from ivy.functional.ivy.creation import (
_asarray_to_native_arrays_and_back,
_asarray_infer_device,
_asarray_infer_dtype,
_asarray_handle_nestable,
NestedSequence,
SupportsBufferProtocol,
_asarray_inputs_to_native_shapes,
)


# Array API Standard #
# ------------------ #

@_asarray_to_native_arrays_and_back
@_asarray_infer_device
@_asarray_handle_nestable
@_asarray_inputs_to_native_shapes
@_asarray_infer_dtype
def asarray(
obj: Union[
JaxArray,
bool,
int,
float,
tuple,
NestedSequence,
SupportsBufferProtocol,
np.ndarray,
],
/,
*,
copy: Optional[bool] = None,
dtype: Optional[jnp.dtype] = None,
device: jaxlib.xla_extension.Device = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
ivy.utils.assertions._check_jax_x64_flag(dtype)
if copy is True:
return jnp.array(obj, dtype=dtype, copy=True)
else:
return jnp.asarray(obj, dtype=dtype)

def arange(
start: float,
/,
stop: Optional[float] = None,
step: float = 1,
*,
dtype: Optional[jnp.dtype] = None,
device: jaxlib.xla_extension.Device = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
if dtype:
dtype = as_native_dtype(dtype)
ivy.utils.assertions._check_jax_x64_flag(dtype.name)
res = jnp.arange(start, stop, step, dtype=dtype)
if not dtype:
if res.dtype == jnp.float64:
return res.astype(jnp.float32)
elif res.dtype == jnp.int64:
return res.astype(jnp.int32)
return res

def meshgrid(
*arrays: JaxArray,
sparse: bool = False,
indexing: str = "xy",
out: Optional[JaxArray] = None,
) -> List[JaxArray]:
return jnp.meshgrid(*arrays, sparse=sparse, indexing=indexing)

def broadcast_arrays(
*arrays: Union[tf.Tensor, tf.Variable],
) -> List[Union[tf.Tensor, tf.Variable]]:
if len(arrays) > 1:
try:
desired_shape = tf.broadcast_dynamic_shape(arrays[0].shape, arrays[1].shape)
except tf.errors.InvalidArgumentError as e:
raise ivy.utils.exceptions.IvyBroadcastShapeError(e)
if len(arrays) > 2:
for i in range(2, len(arrays)):
try:
desired_shape = tf.broadcast_dynamic_shape(
desired_shape, arrays[i].shape
)
except tf.errors.InvalidArgumentError as e:
raise ivy.utils.exceptions.IvyBroadcastShapeError(e)
else:
return [arrays[0]]
result = []
for tensor in arrays:
result.append(tf.broadcast_to(tensor, desired_shape))

return result

0 comments on commit a7444ed

Please sign in to comment.