Skip to content

Commit

Permalink
ci(python): type checking with pyright (#3286)
Browse files Browse the repository at this point in the history
Initial step of #3285
  • Loading branch information
wjones127 authored Dec 24, 2024
1 parent d06488e commit 877b018
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 36 deletions.
1 change: 1 addition & 0 deletions python/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ lint: lint-python lint-rust
lint-python:
ruff format --check python
ruff check python
pyright
.PHONY: lint-python

lint-rust:
Expand Down
18 changes: 12 additions & 6 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ tests = [
"tensorflow",
"tqdm",
]
dev = ["ruff==0.4.1"]
dev = ["ruff==0.4.1", "pyright"]
benchmarks = ["pytest-benchmark"]
torch = ["torch"]
ray = ["ray[data]<2.38; python_version<'3.12'"]
Expand All @@ -68,11 +68,17 @@ lint.select = ["F", "E", "W", "I", "G", "TCH", "PERF", "B019"]
[tool.ruff.lint.per-file-ignores]
"*.pyi" = ["E301", "E302"]

[tool.mypy]
python_version = "3.12"
check_untyped_defs = true
warn_redundant_casts = true
warn_unused_ignores = true
[tool.pyright]
pythonVersion = "3.12"
# TODO: expand this list as we fix more files.
include = ["python/lance/util.py"]
# Dependencies like pyarrow make this difficult to enforce strictly.
reportMissingTypeStubs = "warning"
reportImportCycles = "error"
reportUnusedImport = "error"
reportPropertyTypeMismatch = "error"
reportUnnecessaryCast = "error"


[tool.pytest.ini_options]
markers = [
Expand Down
8 changes: 4 additions & 4 deletions python/python/lance/torch/kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from lance.dependencies import numpy as np
from lance.log import LOGGER
from lance.util import MetricType, _normalize_metric_type

from . import preferred_device
from .data import TensorDataset
Expand Down Expand Up @@ -53,7 +54,7 @@ def __init__(
self,
k: int,
*,
metric: Literal["l2", "euclidean", "cosine", "dot"] = "l2",
metric: MetricType = "l2",
init: Literal["random"] = "random",
max_iters: int = 50,
tolerance: float = 1e-4,
Expand All @@ -64,9 +65,8 @@ def __init__(
self.k = k
self.max_iters = max_iters

metric = metric.lower()
self.metric = metric
if metric in ["l2", "euclidean", "cosine"]:
self.metric = _normalize_metric_type(metric)
if metric in ["l2", "cosine"]:
# Cosine uses normalized unit vector and calculate l2 distance
self.dist_func = l2_distance
elif metric == "dot":
Expand Down
39 changes: 18 additions & 21 deletions python/python/lance/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from __future__ import annotations

from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Iterator, Literal, Optional, Union
from typing import TYPE_CHECKING, Iterator, Literal, Optional, Union, cast

import pyarrow as pa

Expand All @@ -16,14 +16,16 @@
if TYPE_CHECKING:
ts_types = Union[datetime, pd.Timestamp, str]

try:
from pyarrow import FixedShapeTensorType
MetricType = Literal["l2", "euclidean", "dot", "cosine"]

CENTROIDS_TYPE = FixedShapeTensorType
has_fixed_shape_tensor = True
except ImportError:
has_fixed_shape_tensor = False
CENTROIDS_TYPE = pa.FixedSizeListType

def _normalize_metric_type(metric_type: str) -> MetricType:
normalized = metric_type.lower()
if normalized == "euclidean":
normalized = "l2"
if normalized not in {"l2", "dot", "cosine"}:
raise ValueError(f"Invalid metric_type: {metric_type}")
return cast(MetricType, normalized)


def sanitize_ts(ts: ts_types) -> datetime:
Expand Down Expand Up @@ -76,7 +78,7 @@ class KMeans:
def __init__(
self,
k: int,
metric_type: Literal["l2", "dot", "cosine"] = "l2",
metric_type: MetricType = "l2",
max_iters: int = 50,
centroids: Optional[pa.FixedSizeListArray] = None,
):
Expand All @@ -93,11 +95,7 @@ def __init__(
The maximum number of iterations to run the KMeans algorithm. Default: 50.
centroids (pyarrow.FixedSizeListArray, optional.) – Provide existing centroids.
"""
metric_type = metric_type.lower()
if metric_type not in ["l2", "dot", "cosine"]:
raise ValueError(
f"metric_type must be one of 'l2', 'dot', 'cosine', got: {metric_type}"
)
metric_type = _normalize_metric_type(metric_type)
self.k = k
self._metric_type = metric_type
self._kmeans = _KMeans(
Expand All @@ -108,19 +106,18 @@ def __repr__(self) -> str:
return f"lance.KMeans(k={self.k}, metric_type={self._metric_type})"

@property
def centroids(self) -> Optional[CENTROIDS_TYPE]:
def centroids(self) -> Optional[pa.FixedShapeTensorArray]:
"""Returns the centroids of the model,
Returns None if the model is not trained.
"""
ret = self._kmeans.centroids()
if ret is None:
return None
if has_fixed_shape_tensor:
# Pyarrow compatibility
shape = (ret.type.list_size,)
tensor_type = pa.fixed_shape_tensor(ret.type.value_type, shape)
ret = pa.FixedShapeTensorArray.from_storage(tensor_type, ret)

shape = (ret.type.list_size,)
tensor_type = pa.fixed_shape_tensor(ret.type.value_type, shape)
ret = pa.FixedShapeTensorArray.from_storage(tensor_type, ret)
return ret

def _to_fixed_size_list(self, data: pa.Array) -> pa.FixedSizeListArray:
Expand All @@ -130,7 +127,7 @@ def _to_fixed_size_list(self, data: pa.Array) -> pa.FixedSizeListArray:
f"Array must be float32 type, got: {data.type.value_type}"
)
return data
elif has_fixed_shape_tensor and isinstance(data, pa.FixedShapeTensorArray):
elif isinstance(data, pa.FixedShapeTensorArray):
if len(data.type.shape) != 1:
raise ValueError(
f"Fixed shape tensor array must be a 1-D array, "
Expand Down
16 changes: 11 additions & 5 deletions python/python/lance/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import re
import tempfile
from typing import TYPE_CHECKING, Any, Iterable, List, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union

import pyarrow as pa
from tqdm.auto import tqdm
Expand All @@ -19,6 +19,7 @@
)
from .dependencies import numpy as np
from .log import LOGGER
from .util import MetricType, _normalize_metric_type

if TYPE_CHECKING:
from pathlib import Path
Expand Down Expand Up @@ -132,7 +133,7 @@ def vec_to_table(

def train_pq_codebook_on_accelerator(
dataset: LanceDataset | Path | str,
metric_type: Literal["l2", "cosine", "dot"],
metric_type: MetricType,
accelerator: Union[str, "torch.Device"],
num_sub_vectors: int,
batch_size: int = 1024 * 10 * 4,
Expand All @@ -142,6 +143,8 @@ def train_pq_codebook_on_accelerator(
from .torch.data import LanceDataset as TorchDataset
from .torch.kmeans import KMeans

metric_type = _normalize_metric_type(metric_type)

centroids_list = []
kmeans_list = []

Expand Down Expand Up @@ -197,7 +200,7 @@ def train_ivf_centroids_on_accelerator(
dataset: LanceDataset,
column: str,
k: int,
metric_type: Literal["l2", "cosine", "dot"],
metric_type: MetricType,
accelerator: Union[str, "torch.Device"],
batch_size: int = 1024 * 10 * 4,
*,
Expand All @@ -210,6 +213,8 @@ def train_ivf_centroids_on_accelerator(
from .torch.data import LanceDataset as TorchDataset
from .torch.kmeans import KMeans

metric_type = _normalize_metric_type(metric_type)

if isinstance(accelerator, str) and (
not (CUDA_REGEX.match(accelerator) or accelerator == "mps")
):
Expand Down Expand Up @@ -558,7 +563,7 @@ def one_pass_train_ivf_pq_on_accelerator(
dataset: LanceDataset,
column: str,
k: int,
metric_type: Literal["l2", "cosine", "dot"],
metric_type: MetricType,
accelerator: Union[str, "torch.Device"],
num_sub_vectors: int,
batch_size: int = 1024 * 10 * 4,
Expand All @@ -567,6 +572,7 @@ def one_pass_train_ivf_pq_on_accelerator(
max_iters: int = 50,
filter_nan: bool = True,
):
metric_type = _normalize_metric_type(metric_type)
centroids, kmeans = train_ivf_centroids_on_accelerator(
dataset,
column,
Expand Down Expand Up @@ -597,7 +603,7 @@ def one_pass_train_ivf_pq_on_accelerator(
def one_pass_assign_ivf_pq_on_accelerator(
dataset: LanceDataset,
column: str,
metric_type: Literal["l2", "cosine", "dot"],
metric_type: MetricType,
accelerator: Union[str, "torch.Device"],
ivf_kmeans: Any, # KMeans
pq_kmeans_list: List[Any], # List[KMeans]
Expand Down

0 comments on commit 877b018

Please sign in to comment.