Skip to content

Commit

Permalink
added a function
Browse files Browse the repository at this point in the history
  • Loading branch information
Aarsh2001 committed Nov 4, 2023
1 parent 589f773 commit 687fb7c
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 95 deletions.
16 changes: 16 additions & 0 deletions claude/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,22 @@ def starter_code(key):
with open("diff.txt", '+rb') as f:
# intelligent regex
content = f.open()
fns_without_docstring = dict()
contains_docstring = False
for line in content:
if line.contains("+def "):
in_func = True
func_name = line.split('+def ')[1].split('(')[0]
# regex to check if there exists a docstring
if line == "+":
if in_func and not contains_docstring:
fns_without_docstring[func_name] = generate_docstring(filename, func_name, key)
in_func = False
contains_docstring = False
func_name = ""
if line.contains('"""'):
contains_docstring = True

# changed function names here
# changed file names
with open("~/files.json", 'r') as f:
Expand Down
96 changes: 1 addition & 95 deletions dummy-files/test.py
Original file line number Diff line number Diff line change
@@ -1,95 +1 @@
# global
from numbers import Number
import numpy as np
from typing import Union, Optional, List, Sequence, Tuple

import jax.dlpack
import jax.numpy as jnp
import jax._src as _src
import jaxlib.xla_extension

# local
import ivy
from ivy import as_native_dtype
from ivy.functional.backends.jax import JaxArray
from ivy.functional.ivy.creation import (
_asarray_to_native_arrays_and_back,
_asarray_infer_device,
_asarray_infer_dtype,
_asarray_handle_nestable,
NestedSequence,
SupportsBufferProtocol,
_asarray_inputs_to_native_shapes,
)


# Array API Standard #
# ------------------ #

@_asarray_to_native_arrays_and_back
@_asarray_infer_device
@_asarray_handle_nestable
@_asarray_inputs_to_native_shapes
@_asarray_infer_dtype
def asarray(
obj: Union[
JaxArray,
bool,
int,
float,
tuple,
NestedSequence,
SupportsBufferProtocol,
np.ndarray,
],
/,
*,
copy: Optional[bool] = None,
dtype: Optional[jnp.dtype] = None,
device: jaxlib.xla_extension.Device = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
ivy.utils.assertions._check_jax_x64_flag(dtype)
if copy is True:
return jnp.array(obj, dtype=dtype, copy=True)
else:
return jnp.asarray(obj, dtype=dtype)

def arange(
start: float,
/,
stop: Optional[float] = None,
step: float = 1,
*,
dtype: Optional[jnp.dtype] = None,
device: jaxlib.xla_extension.Device = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
if dtype:
dtype = as_native_dtype(dtype)
ivy.utils.assertions._check_jax_x64_flag(dtype.name)
res = jnp.arange(start, stop, step, dtype=dtype)
if not dtype:
if res.dtype == jnp.float64:
return res.astype(jnp.float32)
elif res.dtype == jnp.int64:
return res.astype(jnp.int32)
return res

def meshgrid(
*arrays: JaxArray,
sparse: bool = False,
indexing: str = "xy",
out: Optional[JaxArray] = None,
) -> List[JaxArray]:
return jnp.meshgrid(*arrays, sparse=sparse, indexing=indexing)

def ones_like(
x: JaxArray,
/,
*,
dtype: jnp.dtype,
device: jaxlib.xla_extension.Device = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
return jnp.ones_like(x, dtype=dtype)
1234

0 comments on commit 687fb7c

Please sign in to comment.