Skip to content

Commit

Permalink
refactor: reuse Schema.to_<backend>() in from_numpy (#2024)
Browse files: browse the repository at this point in the history
  • Loading branch information
dangotbanned authored Feb 16, 2025
1 parent 6662df5 commit 43d072b
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 60 deletions.
68 changes: 16 additions & 52 deletions narwhals/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
if TYPE_CHECKING:
from types import ModuleType

import polars as pl
import pyarrow as pa
from typing_extensions import Self

Expand Down Expand Up @@ -476,20 +477,14 @@ def from_numpy(
| e: [[1,3]] |
└──────────────────┘
"""
return _from_numpy_impl(
data,
schema,
native_namespace=native_namespace,
version=Version.MAIN,
)
return _from_numpy_impl(data, schema, native_namespace=native_namespace)


def _from_numpy_impl(
data: _2DArray,
schema: Mapping[str, DType] | Schema | Sequence[str] | None = None,
*,
native_namespace: ModuleType,
version: Version,
) -> DataFrame[Any]:
from narwhals.schema import Schema

Expand All @@ -500,56 +495,34 @@ def _from_numpy_impl(

if implementation is Implementation.POLARS:
if isinstance(schema, (Mapping, Schema)):
from narwhals._polars.utils import (
narwhals_to_native_dtype as polars_narwhals_to_native_dtype,
)

backend_version = parse_version(native_namespace.__version__)
schema = {
name: polars_narwhals_to_native_dtype( # type: ignore[misc]
dtype,
version=version,
backend_version=backend_version,
)
for name, dtype in schema.items()
}
elif schema is None:
native_frame = native_namespace.from_numpy(data)
elif not is_sequence_but_not_str(schema):
schema_pl: pl.Schema | Sequence[str] | None = Schema(schema).to_polars()
elif is_sequence_but_not_str(schema) or schema is None:
schema_pl = schema
else:
msg = (
"`schema` is expected to be one of the following types: "
"Mapping[str, DType] | Schema | Sequence[str]. "
f"Got {type(schema)}."
)
raise TypeError(msg)
native_frame = native_namespace.from_numpy(data, schema=schema)
native_frame = native_namespace.from_numpy(data, schema=schema_pl)

elif implementation.is_pandas_like():
if isinstance(schema, (Mapping, Schema)):
from narwhals._pandas_like.utils import get_dtype_backend
from narwhals._pandas_like.utils import (
narwhals_to_native_dtype as pandas_like_narwhals_to_native_dtype,
)

backend_version = parse_version(native_namespace)
pd_schema = {
name: pandas_like_narwhals_to_native_dtype(
dtype=schema[name],
dtype_backend=get_dtype_backend(native_type, implementation),
implementation=implementation,
backend_version=backend_version,
version=version,
)
for name, native_type in schema.items()
}
it: Iterable[DTypeBackend] = (
get_dtype_backend(native_type, implementation)
for native_type in schema.values()
)
native_frame = native_namespace.DataFrame(data, columns=schema.keys()).astype(
pd_schema
Schema(schema).to_pandas(it)
)
elif is_sequence_but_not_str(schema):
native_frame = native_namespace.DataFrame(data, columns=list(schema))
elif schema is None:
native_frame = native_namespace.DataFrame(
data, columns=["column_" + str(x) for x in range(data.shape[1])]
data, columns=[f"column_{x}" for x in range(data.shape[1])]
)
else:
msg = (
Expand All @@ -562,24 +535,15 @@ def _from_numpy_impl(
elif implementation is Implementation.PYARROW:
pa_arrays = [native_namespace.array(val) for val in data.T]
if isinstance(schema, (Mapping, Schema)):
from narwhals._arrow.utils import (
narwhals_to_native_dtype as arrow_narwhals_to_native_dtype,
)

schema = native_namespace.schema(
[
(name, arrow_narwhals_to_native_dtype(dtype, version))
for name, dtype in schema.items()
]
)
native_frame = native_namespace.Table.from_arrays(pa_arrays, schema=schema)
schema_pa = Schema(schema).to_arrow()
native_frame = native_namespace.Table.from_arrays(pa_arrays, schema=schema_pa)
elif is_sequence_but_not_str(schema):
native_frame = native_namespace.Table.from_arrays(
pa_arrays, names=list(schema)
)
elif schema is None:
native_frame = native_namespace.Table.from_arrays(
pa_arrays, names=["column_" + str(x) for x in range(data.shape[1])]
pa_arrays, names=[f"column_{x}" for x in range(data.shape[1])]
)
else:
msg = (
Expand Down
9 changes: 1 addition & 8 deletions narwhals/stable/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2237,14 +2237,7 @@ def from_numpy(
Returns:
A new DataFrame.
"""
return _stableify( # type: ignore[no-any-return]
_from_numpy_impl(
data,
schema,
native_namespace=native_namespace,
version=Version.V1,
)
)
return _stableify(_from_numpy_impl(data, schema, native_namespace=native_namespace))


def read_csv(
Expand Down

0 comments on commit 43d072b

Please sign in to comment.