Skip to content

Commit

Permalink
support run analyzer
Browse files Browse the repository at this point in the history
Signed-off-by: aoiasd <[email protected]>
  • Loading branch information
aoiasd committed Feb 18, 2025
1 parent 5220f8d commit 4796d39
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 19 deletions.
13 changes: 13 additions & 0 deletions pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2258,3 +2258,16 @@ def remove_privileges_from_group(
)
resp = self._stub.OperatePrivilegeGroup(req, wait_for_ready=True, timeout=timeout)
check_status(resp)

@retry_on_rpc_failure()
def run_analyzer(
    self,
    texts: Union[str, List[str]],
    analyzer_params: Union[str, Dict],
    timeout: Optional[float] = None,
    **kwargs,
):
    """Send a RunAnalyzer request and return the per-text results.

    :param texts: A single text or a list of texts to analyze.
    :param analyzer_params: Analyzer configuration, either as a dict or as
        an already-serialized string (see ``Prepare.run_analyzer``).
    :param timeout: Forwarded as the ``timeout`` of the RPC call. Defaults
        to ``None`` so callers may omit it.
    :param kwargs: Accepted for signature compatibility; not used by this
        method body.
    :return: ``resp.results`` of the RunAnalyzer response, one entry per
        input text.
    """
    req = Prepare.run_analyzer(texts, analyzer_params)
    resp = self._stub.RunAnalyzer(req, timeout=timeout)
    # The response wraps its status; validate it before handing results back.
    check_status(resp.status)
    return resp.results
15 changes: 15 additions & 0 deletions pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -1728,3 +1728,18 @@ def operate_privilege_group_req(
privileges=[milvus_types.PrivilegeEntity(name=p) for p in privileges],
type=operate_privilege_group_type,
)

@classmethod
def run_analyzer(cls, texts: Union[str, List[str]], analyzer_params: Union[str, Dict]):
    """Build the RunAnalyzer request message.

    Each text is UTF-8 encoded and stored in ``placeholder``; a single
    string is handled as a one-element batch. A dict ``analyzer_params``
    is serialized with ``ujson.dumps``; any other value is assigned as-is.
    """
    request = milvus_types.RunAnalyzerRequset()

    # Normalize the input so one code path fills the repeated field.
    batch = [texts] if isinstance(texts, str) else texts
    request.placeholder.extend([item.encode("utf-8") for item in batch])

    request.analyzer_params = (
        ujson.dumps(analyzer_params) if isinstance(analyzer_params, dict) else analyzer_params
    )
    return request
2 changes: 1 addition & 1 deletion pymilvus/grpc_gen/milvus-proto
40 changes: 23 additions & 17 deletions pymilvus/grpc_gen/milvus_pb2.py

Large diffs are not rendered by default.

24 changes: 24 additions & 0 deletions pymilvus/grpc_gen/milvus_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2211,3 +2211,27 @@ class ListImportsAuthPlaceholder(_message.Message):
db_name: str
collection_name: str
def __init__(self, db_name: _Optional[str] = ..., collection_name: _Optional[str] = ...) -> None: ...

# Type stub for the request message of the MilvusService RunAnalyzer RPC.
# NOTE(review): "Requset" is the spelling of the generated name; the client
# and gRPC bindings reference it under this spelling, so do not "fix" it here.
class RunAnalyzerRequset(_message.Message):
__slots__ = ("base", "analyzer_params", "placeholder")
BASE_FIELD_NUMBER: _ClassVar[int]
ANALYZER_PARAMS_FIELD_NUMBER: _ClassVar[int]
PLACEHOLDER_FIELD_NUMBER: _ClassVar[int]
# base: common message header (MsgBase).
base: _common_pb2.MsgBase
# analyzer_params: analyzer configuration carried as a string.
analyzer_params: str
# placeholder: repeated bytes field; the client fills it with the UTF-8
# encoded input texts.
placeholder: _containers.RepeatedScalarFieldContainer[bytes]
def __init__(self, base: _Optional[_Union[_common_pb2.MsgBase, _Mapping]] = ..., analyzer_params: _Optional[str] = ..., placeholder: _Optional[_Iterable[bytes]] = ...) -> None: ...

# Type stub for one analysis result: the repeated string field `tokens`.
class AnalyzerResult(_message.Message):
__slots__ = ("tokens",)
TOKENS_FIELD_NUMBER: _ClassVar[int]
# tokens: repeated string field holding the result tokens.
tokens: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, tokens: _Optional[_Iterable[str]] = ...) -> None: ...

# Type stub for the response message of the MilvusService RunAnalyzer RPC.
class RunAnalyzerResponse(_message.Message):
__slots__ = ("status", "results")
STATUS_FIELD_NUMBER: _ClassVar[int]
RESULTS_FIELD_NUMBER: _ClassVar[int]
# status: common Status message describing the outcome of the call.
status: _common_pb2.Status
# results: repeated AnalyzerResult messages.
# NOTE(review): presumably one entry per input text, in request order — confirm.
results: _containers.RepeatedCompositeFieldContainer[AnalyzerResult]
def __init__(self, status: _Optional[_Union[_common_pb2.Status, _Mapping]] = ..., results: _Optional[_Iterable[_Union[AnalyzerResult, _Mapping]]] = ...) -> None: ...
33 changes: 33 additions & 0 deletions pymilvus/grpc_gen/milvus_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,11 @@ def __init__(self, channel):
request_serializer=milvus__pb2.OperatePrivilegeGroupRequest.SerializeToString,
response_deserializer=common__pb2.Status.FromString,
)
self.RunAnalyzer = channel.unary_unary(
'/milvus.proto.milvus.MilvusService/RunAnalyzer',
request_serializer=milvus__pb2.RunAnalyzerRequset.SerializeToString,
response_deserializer=milvus__pb2.RunAnalyzerResponse.FromString,
)


class MilvusServiceServicer(object):
Expand Down Expand Up @@ -1062,6 +1067,12 @@ def OperatePrivilegeGroup(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def RunAnalyzer(self, request, context):
    """Default RunAnalyzer handler: report UNIMPLEMENTED and raise.

    The .proto file carries no documentation comment for this method.
    """
    not_implemented = 'Method not implemented!'
    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
    context.set_details(not_implemented)
    raise NotImplementedError(not_implemented)


def add_MilvusServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
Expand Down Expand Up @@ -1535,6 +1546,11 @@ def add_MilvusServiceServicer_to_server(servicer, server):
request_deserializer=milvus__pb2.OperatePrivilegeGroupRequest.FromString,
response_serializer=common__pb2.Status.SerializeToString,
),
'RunAnalyzer': grpc.unary_unary_rpc_method_handler(
servicer.RunAnalyzer,
request_deserializer=milvus__pb2.RunAnalyzerRequset.FromString,
response_serializer=milvus__pb2.RunAnalyzerResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'milvus.proto.milvus.MilvusService', rpc_method_handlers)
Expand Down Expand Up @@ -3143,6 +3159,23 @@ def OperatePrivilegeGroup(request,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def RunAnalyzer(request,
        target,
        options=(),
        channel_credentials=None,
        call_credentials=None,
        insecure=False,
        compression=None,
        wait_for_ready=None,
        timeout=None,
        metadata=None):
    """Invoke the RunAnalyzer unary-unary RPC via ``grpc.experimental``."""
    rpc_path = '/milvus.proto.milvus.MilvusService/RunAnalyzer'
    # Positional order below mirrors the generated call exactly.
    return grpc.experimental.unary_unary(
        request,
        target,
        rpc_path,
        milvus__pb2.RunAnalyzerRequset.SerializeToString,
        milvus__pb2.RunAnalyzerResponse.FromString,
        options,
        channel_credentials,
        insecure,
        call_credentials,
        compression,
        wait_for_ready,
        timeout,
        metadata,
    )


class ProxyServiceStub(object):
"""Missing associated documentation comment in .proto file."""
Expand Down
38 changes: 37 additions & 1 deletion pymilvus/orm/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# the License.

from datetime import datetime, timedelta, timezone
from typing import List, Mapping, Optional
from typing import Dict, List, Mapping, Optional, Union

from pymilvus.client.types import (
BulkInsertState,
Expand Down Expand Up @@ -1310,3 +1310,39 @@ def list_indexes(
# list all indexes of this field.
index_name_list.append(index.index_name)
return index_name_list


def run_analyzer(
    text: Union[str, List[str]],
    analyzer_params: Union[str, Dict, None] = None,
    using: str = "default",
    timeout: Optional[float] = None,
):
    """Run analyzer. Return result tokens of analysis.

    :param text: The input text (string or string list).
    :type text: str or List[str]

    :param analyzer_params: The parameters of analyzer. ``None`` is replaced
        by an empty dict before the request is sent.
    :type analyzer_params: str or Dict or None

    :param using: Alias to the connection. Default connection is used if this is not specified.
    :type using: str

    :param timeout: Forwarded to the connection's ``run_analyzer`` call.
        NOTE(review): presumably an RPC timeout in seconds -- confirm.
    :type timeout: float or None

    :return: The result tokens of analysis: a flat token list when ``text``
        is a single string, otherwise one token list per input text.
    :rtype: List[str] or List[List[str]]
    """
    if analyzer_params is None:
        analyzer_params = {}

    results = _get_connection(using).run_analyzer(text, analyzer_params, timeout)

    # Copy the repeated token fields into plain lists so the returned value
    # matches the documented List[str] / List[List[str]] types.
    if isinstance(text, str):
        return list(results[0].tokens)
    return [list(result.tokens) for result in results]

0 comments on commit 4796d39

Please sign in to comment.