diff --git a/claude/script.py b/claude/script.py index 028ed1a..4c822ed 100644 --- a/claude/script.py +++ b/claude/script.py @@ -32,6 +32,22 @@ def starter_code(key): with open("diff.txt", '+rb') as f: # intelligent regex content = f.open() + fns_without_docstring = dict() + contains_docstring = False + for line in content: + if line.contains("+def "): + in_func = True + func_name = line.split('+def ')[1].split('(')[0] + # regex to check if there exists a docstring + if line == "+": + if in_func and not contains_docstring: + fns_without_docstring[func_name] = generate_docstring(filename, func_name, key) + in_func = False + contains_docstring = False + func_name = "" + if line.contains('"""'): + contains_docstring = True + # changed function names here # changed file names with open("~/files.json", 'r') as f: diff --git a/dummy-files/test.py b/dummy-files/test.py index 9e713b4..274c005 100644 --- a/dummy-files/test.py +++ b/dummy-files/test.py @@ -1,95 +1 @@ -# global -from numbers import Number -import numpy as np -from typing import Union, Optional, List, Sequence, Tuple - -import jax.dlpack -import jax.numpy as jnp -import jax._src as _src -import jaxlib.xla_extension - -# local -import ivy -from ivy import as_native_dtype -from ivy.functional.backends.jax import JaxArray -from ivy.functional.ivy.creation import ( - _asarray_to_native_arrays_and_back, - _asarray_infer_device, - _asarray_infer_dtype, - _asarray_handle_nestable, - NestedSequence, - SupportsBufferProtocol, - _asarray_inputs_to_native_shapes, -) - - -# Array API Standard # -# ------------------ # - -@_asarray_to_native_arrays_and_back -@_asarray_infer_device -@_asarray_handle_nestable -@_asarray_inputs_to_native_shapes -@_asarray_infer_dtype -def asarray( - obj: Union[ - JaxArray, - bool, - int, - float, - tuple, - NestedSequence, - SupportsBufferProtocol, - np.ndarray, - ], - /, - *, - copy: Optional[bool] = None, - dtype: Optional[jnp.dtype] = None, - device: jaxlib.xla_extension.Device = None, - out: Optional[JaxArray] = None, -) -> JaxArray: - ivy.utils.assertions._check_jax_x64_flag(dtype) - if copy is True: - return jnp.array(obj, dtype=dtype, copy=True) - else: - return jnp.asarray(obj, dtype=dtype) - -def arange( - start: float, - /, - stop: Optional[float] = None, - step: float = 1, - *, - dtype: Optional[jnp.dtype] = None, - device: jaxlib.xla_extension.Device = None, - out: Optional[JaxArray] = None, -) -> JaxArray: - if dtype: - dtype = as_native_dtype(dtype) - ivy.utils.assertions._check_jax_x64_flag(dtype.name) - res = jnp.arange(start, stop, step, dtype=dtype) - if not dtype: - if res.dtype == jnp.float64: - return res.astype(jnp.float32) - elif res.dtype == jnp.int64: - return res.astype(jnp.int32) - return res - -def meshgrid( - *arrays: JaxArray, - sparse: bool = False, - indexing: str = "xy", - out: Optional[JaxArray] = None, -) -> List[JaxArray]: - return jnp.meshgrid(*arrays, sparse=sparse, indexing=indexing) - -def ones_like( - x: JaxArray, - /, - *, - dtype: jnp.dtype, - device: jaxlib.xla_extension.Device = None, - out: Optional[JaxArray] = None, -) -> JaxArray: - return jnp.ones_like(x, dtype=dtype) \ No newline at end of file +1234 \ No newline at end of file