Skip to content

Commit

Permalink
Chore: Implemented tests for get_key and list_clients
Browse files Browse the repository at this point in the history
  • Loading branch information
lordsarcastic committed Nov 30, 2024
1 parent 79cfcd4 commit f5f7a70
Show file tree
Hide file tree
Showing 12 changed files with 211 additions and 68 deletions.
5 changes: 5 additions & 0 deletions sdk/ahnlich-client-py/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
install:
@poetry install

test:
@poetry run pytest . -s -vv
8 changes: 8 additions & 0 deletions sdk/ahnlich-client-py/ahnlich_client_py/builders/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def get_sim_n(
closest_n: st.uint64 = 1,
algorithm: ai_query.Algorithm = ai_query.Algorithm__CosineSimilarity,
condition: typing.Optional[ai_query.PredicateCondition] = None,
preprocess_action: ai_query.PreprocessAction = ai_query.PreprocessAction__ModelPreprocessing,
):
nonzero_n = NonZeroSizeInteger(closest_n)
self.queries.append(
Expand All @@ -61,6 +62,7 @@ def get_sim_n(
closest_n=nonzero_n.value,
algorithm=algorithm,
condition=condition,
preprocess_action=preprocess_action,
)
)

Expand Down Expand Up @@ -127,6 +129,9 @@ def set(
def del_key(self, store_name: str, key: ai_query.StoreInput):
self.queries.append(ai_query.AIQuery__DelKey(store=store_name, key=key))

def get_key(self, store_name: str, keys: typing.Sequence[ai_query.StoreInput]):
self.queries.append(ai_query.AIQuery__GetKey(store=store_name, keys=keys))

def drop_store(self, store_name: str, error_if_not_exists: bool = True):
self.queries.append(
ai_query.AIQuery__DropStore(
Expand All @@ -143,6 +148,9 @@ def info_server(self):
def list_stores(self):
self.queries.append(ai_query.AIQuery__ListStores())

def list_clients(self):
self.queries.append(ai_query.AIQuery__ListClients())

def ping(self):
self.queries.append(ai_query.AIQuery__Ping())

Expand Down
18 changes: 18 additions & 0 deletions sdk/ahnlich-client-py/ahnlich_client_py/clients/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,16 @@ def del_key(
builder.del_key(store_name=store_name, key=key)
return self.process_request(builder.to_server_query())

def get_key(
self,
store_name: str,
keys: typing.Sequence[ai_query.StoreInput],
tracing_id: typing.Optional[str] = None,
):
builder = builders.AhnlichAIRequestBuilder(tracing_id)
builder.get_key(store_name=store_name, keys=keys)
return self.process_request(builder.to_server_query())

def drop_store(
self,
store_name: str,
Expand Down Expand Up @@ -189,6 +199,14 @@ def list_stores(
builder = builders.AhnlichAIRequestBuilder(tracing_id)
builder.list_stores()
return self.process_request(builder.to_server_query())

def list_clients(
self,
tracing_id: typing.Optional[str] = None,
):
builder = builders.AhnlichAIRequestBuilder(tracing_id)
builder.list_clients()
return self.process_request(builder.to_server_query())

def ping(
self,
Expand Down
20 changes: 20 additions & 0 deletions sdk/ahnlich-client-py/ahnlich_client_py/clients/non_blocking/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ async def get_sim_n(
closest_n: st.uint64 = 1,
algorithm: ai_query.Algorithm = ai_query.Algorithm__CosineSimilarity,
condition: typing.Optional[ai_query.PredicateCondition] = None,
preprocess_action: ai_query.PreprocessAction = ai_query.PreprocessAction__ModelPreprocessing,
tracing_id: typing.Optional[str] = None,
):
builder = AsyncAhnlichAIRequestBuilder(tracing_id)
Expand All @@ -74,6 +75,7 @@ async def get_sim_n(
closest_n=closest_n,
algorithm=algorithm,
condition=condition,
preprocess_action=preprocess_action,
)
return await self.process_request(builder.to_server_query())

Expand Down Expand Up @@ -154,6 +156,16 @@ async def del_key(
builder.del_key(store_name=store_name, key=key)
return await self.process_request(builder.to_server_query())

async def get_key(
self,
store_name: str,
keys: typing.Sequence[ai_query.StoreInput],
tracing_id: typing.Optional[str] = None,
):
builder = AsyncAhnlichAIRequestBuilder(tracing_id)
builder.get_key(store_name=store_name, keys=keys)
return await self.process_request(builder.to_server_query())

async def drop_store(
self,
store_name: str,
Expand Down Expand Up @@ -190,6 +202,14 @@ async def list_stores(
builder.list_stores()
return await self.process_request(builder.to_server_query())

async def list_clients(
self,
tracing_id: typing.Optional[str] = None,
):
builder = AsyncAhnlichAIRequestBuilder(tracing_id)
builder.list_clients()
return await self.process_request(builder.to_server_query())

async def ping(
self,
tracing_id: typing.Optional[str] = None,
Expand Down
43 changes: 28 additions & 15 deletions sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_query.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# pyre-strict
from dataclasses import dataclass
import typing
from ahnlich_client_py.internals import serde_types as st
from dataclasses import dataclass

from ahnlich_client_py.internals import bincode
from ahnlich_client_py.internals import serde_types as st


class AIModel:
VARIANTS = [] # type: typing.Sequence[typing.Type[AIModel]]
Expand All @@ -11,7 +13,7 @@ def bincode_serialize(self) -> bytes:
return bincode.serialize(self, AIModel)

@staticmethod
def bincode_deserialize(input: bytes) -> 'AIModel':
def bincode_deserialize(input: bytes) -> "AIModel":
v, buffer = bincode.deserialize(input, AIModel)
if buffer:
raise st.DeserializationError("Some input bytes were not read")
Expand Down Expand Up @@ -59,6 +61,7 @@ class AIModel__ClipVitB32Text(AIModel):
INDEX = 6 # type: int
pass


AIModel.VARIANTS = [
AIModel__AllMiniLML6V2,
AIModel__AllMiniLML12V2,
Expand All @@ -77,7 +80,7 @@ def bincode_serialize(self) -> bytes:
return bincode.serialize(self, AIQuery)

@staticmethod
def bincode_deserialize(input: bytes) -> 'AIQuery':
def bincode_deserialize(input: bytes) -> "AIQuery":
v, buffer = bincode.deserialize(input, AIQuery)
if buffer:
raise st.DeserializationError("Some input bytes were not read")
Expand Down Expand Up @@ -148,7 +151,9 @@ class AIQuery__DropNonLinearAlgorithmIndex(AIQuery):
class AIQuery__Set(AIQuery):
INDEX = 7 # type: int
store: str
inputs: typing.Sequence[typing.Tuple["StoreInput", typing.Dict[str, "MetadataValue"]]]
inputs: typing.Sequence[
typing.Tuple["StoreInput", typing.Dict[str, "MetadataValue"]]
]
preprocess_action: "PreprocessAction"


Expand Down Expand Up @@ -202,6 +207,7 @@ class AIQuery__Ping(AIQuery):
INDEX = 15 # type: int
pass


AIQuery.VARIANTS = [
AIQuery__CreateStore,
AIQuery__GetPred,
Expand Down Expand Up @@ -231,7 +237,7 @@ def bincode_serialize(self) -> bytes:
return bincode.serialize(self, AIServerQuery)

@staticmethod
def bincode_deserialize(input: bytes) -> 'AIServerQuery':
def bincode_deserialize(input: bytes) -> "AIServerQuery":
v, buffer = bincode.deserialize(input, AIServerQuery)
if buffer:
raise st.DeserializationError("Some input bytes were not read")
Expand All @@ -245,7 +251,7 @@ def bincode_serialize(self) -> bytes:
return bincode.serialize(self, AIStoreInputType)

@staticmethod
def bincode_deserialize(input: bytes) -> 'AIStoreInputType':
def bincode_deserialize(input: bytes) -> "AIStoreInputType":
v, buffer = bincode.deserialize(input, AIStoreInputType)
if buffer:
raise st.DeserializationError("Some input bytes were not read")
Expand All @@ -263,6 +269,7 @@ class AIStoreInputType__Image(AIStoreInputType):
INDEX = 1 # type: int
pass


AIStoreInputType.VARIANTS = [
AIStoreInputType__RawString,
AIStoreInputType__Image,
Expand All @@ -276,7 +283,7 @@ def bincode_serialize(self) -> bytes:
return bincode.serialize(self, Algorithm)

@staticmethod
def bincode_deserialize(input: bytes) -> 'Algorithm':
def bincode_deserialize(input: bytes) -> "Algorithm":
v, buffer = bincode.deserialize(input, Algorithm)
if buffer:
raise st.DeserializationError("Some input bytes were not read")
Expand Down Expand Up @@ -306,6 +313,7 @@ class Algorithm__KDTree(Algorithm):
INDEX = 3 # type: int
pass


Algorithm.VARIANTS = [
Algorithm__EuclideanDistance,
Algorithm__DotProductSimilarity,
Expand All @@ -321,7 +329,7 @@ def bincode_serialize(self) -> bytes:
return bincode.serialize(self, MetadataValue)

@staticmethod
def bincode_deserialize(input: bytes) -> 'MetadataValue':
def bincode_deserialize(input: bytes) -> "MetadataValue":
v, buffer = bincode.deserialize(input, MetadataValue)
if buffer:
raise st.DeserializationError("Some input bytes were not read")
Expand All @@ -339,6 +347,7 @@ class MetadataValue__Image(MetadataValue):
INDEX = 1 # type: int
value: typing.Sequence[st.uint8]


MetadataValue.VARIANTS = [
MetadataValue__RawString,
MetadataValue__Image,
Expand All @@ -352,7 +361,7 @@ def bincode_serialize(self) -> bytes:
return bincode.serialize(self, NonLinearAlgorithm)

@staticmethod
def bincode_deserialize(input: bytes) -> 'NonLinearAlgorithm':
def bincode_deserialize(input: bytes) -> "NonLinearAlgorithm":
v, buffer = bincode.deserialize(input, NonLinearAlgorithm)
if buffer:
raise st.DeserializationError("Some input bytes were not read")
Expand All @@ -364,6 +373,7 @@ class NonLinearAlgorithm__KDTree(NonLinearAlgorithm):
INDEX = 0 # type: int
pass


NonLinearAlgorithm.VARIANTS = [
NonLinearAlgorithm__KDTree,
]
Expand All @@ -376,7 +386,7 @@ def bincode_serialize(self) -> bytes:
return bincode.serialize(self, Predicate)

@staticmethod
def bincode_deserialize(input: bytes) -> 'Predicate':
def bincode_deserialize(input: bytes) -> "Predicate":
v, buffer = bincode.deserialize(input, Predicate)
if buffer:
raise st.DeserializationError("Some input bytes were not read")
Expand Down Expand Up @@ -410,6 +420,7 @@ class Predicate__NotIn(Predicate):
key: str
value: typing.Sequence["MetadataValue"]


Predicate.VARIANTS = [
Predicate__Equals,
Predicate__NotEquals,
Expand All @@ -425,7 +436,7 @@ def bincode_serialize(self) -> bytes:
return bincode.serialize(self, PredicateCondition)

@staticmethod
def bincode_deserialize(input: bytes) -> 'PredicateCondition':
def bincode_deserialize(input: bytes) -> "PredicateCondition":
v, buffer = bincode.deserialize(input, PredicateCondition)
if buffer:
raise st.DeserializationError("Some input bytes were not read")
Expand All @@ -449,6 +460,7 @@ class PredicateCondition__Or(PredicateCondition):
INDEX = 2 # type: int
value: typing.Tuple["PredicateCondition", "PredicateCondition"]


PredicateCondition.VARIANTS = [
PredicateCondition__Value,
PredicateCondition__And,
Expand All @@ -463,7 +475,7 @@ def bincode_serialize(self) -> bytes:
return bincode.serialize(self, PreprocessAction)

@staticmethod
def bincode_deserialize(input: bytes) -> 'PreprocessAction':
def bincode_deserialize(input: bytes) -> "PreprocessAction":
v, buffer = bincode.deserialize(input, PreprocessAction)
if buffer:
raise st.DeserializationError("Some input bytes were not read")
Expand All @@ -481,6 +493,7 @@ class PreprocessAction__ModelPreprocessing(PreprocessAction):
INDEX = 1 # type: int
pass


PreprocessAction.VARIANTS = [
PreprocessAction__NoPreprocessing,
PreprocessAction__ModelPreprocessing,
Expand All @@ -494,7 +507,7 @@ def bincode_serialize(self) -> bytes:
return bincode.serialize(self, StoreInput)

@staticmethod
def bincode_deserialize(input: bytes) -> 'StoreInput':
def bincode_deserialize(input: bytes) -> "StoreInput":
v, buffer = bincode.deserialize(input, StoreInput)
if buffer:
raise st.DeserializationError("Some input bytes were not read")
Expand All @@ -512,8 +525,8 @@ class StoreInput__Image(StoreInput):
INDEX = 1 # type: int
value: typing.Sequence[st.uint8]


StoreInput.VARIANTS = [
StoreInput__RawString,
StoreInput__Image,
]

Loading

0 comments on commit f5f7a70

Please sign in to comment.