diff --git a/examples/example_bulkwriter.py b/examples/example_bulkwriter.py index 36b098870..9731cf9f6 100644 --- a/examples/example_bulkwriter.py +++ b/examples/example_bulkwriter.py @@ -230,9 +230,11 @@ def test_all_types_writer(bin_vec: bool, schema: CollectionSchema)->list: secret_key=MINIO_SECRET_KEY, bucket_name="a-bucket", ), + file_type=BulkFileType.JSON_RB, ) as remote_writer: print("Append rows") - for i in range(10000): + batch_count = 10000 + for i in range(batch_count): row = { "id": i, "bool": True if i%5 == 0 else False, @@ -249,6 +251,23 @@ def test_all_types_writer(bin_vec: bool, schema: CollectionSchema)->list: } remote_writer.append_row(row) + # append rows by numpy type + for i in range(batch_count): + remote_writer.append_row({ + "id": np.int64(i+batch_count), + "bool": True if i % 3 == 0 else False, + "int8": np.int8(i%128), + "int16": np.int16(i%1000), + "int32": np.int32(i%100000), + "int64": np.int64(i), + "float": np.float32(i/3), + "double": np.float64(i/7), + "varchar": f"varchar_{i}", + "json": json.dumps({"dummy": i, "ok": f"name_{i}"}), + "vector": gen_binary_vector() if bin_vec else gen_float_vector(), + f"dynamic_{i}": i, + }) + print("Generate data files...") remote_writer.commit() print(f"Data files have been uploaded: {remote_writer.batch_files}") @@ -371,6 +390,6 @@ def test_cloud_bulkinsert(): test_call_bulkinsert(schema, batch_files) test_retrieve_imported_data(bin_vec=True) - # # to test cloud bulkinsert api, you need to apply a cloud service from Zilliz Cloud(https://zilliz.com/cloud) - # test_cloud_bulkinsert() + # to test cloud bulkinsert api, you need to apply a cloud service from Zilliz Cloud(https://zilliz.com/cloud) + test_cloud_bulkinsert() diff --git a/pymilvus/bulk_writer/bulk_writer.py b/pymilvus/bulk_writer/bulk_writer.py index 2f53256b0..88bc44c9b 100644 --- a/pymilvus/bulk_writer/bulk_writer.py +++ b/pymilvus/bulk_writer/bulk_writer.py @@ -10,9 +10,12 @@ # or implied. See the License for the specific language governing permissions and limitations under # the License. +import json import logging from threading import Lock +import numpy as np + from pymilvus.client.types import DataType from pymilvus.exceptions import MilvusException from pymilvus.orm.schema import CollectionSchema @@ -85,6 +88,16 @@ def commit(self, **kwargs): def data_path(self): return "" + def _try_convert_json(self, field_name: str, obj: object): + if isinstance(obj, str): + try: + return json.loads(obj) + except Exception as e: + self._throw( + f"Illegal JSON value for field '{field_name}', type mismatch or illegal format, error: {e}" + ) + return obj + def _throw(self, msg: str): logger.error(msg) raise MilvusException(message=msg) @@ -109,10 +122,12 @@ def _verify_row(self, row: dict): dtype = DataType(field.dtype) validator = TYPE_VALIDATOR[dtype.name] if dtype in {DataType.BINARY_VECTOR, DataType.FLOAT_VECTOR}: + if isinstance(row[field.name], np.ndarray): + row[field.name] = row[field.name].tolist() dim = field.params["dim"] if not validator(row[field.name], dim): self._throw( - f"Illegal vector data for vector field: '{dtype.name}'," + f"Illegal vector data for vector field: '{field.name}'," f" dim is not {dim} or type mismatch" ) @@ -126,20 +141,23 @@ def _verify_row(self, row: dict): max_len = field.params["max_length"] if not validator(row[field.name], max_len): self._throw( - f"Illegal varchar value for field '{dtype.name}'," + f"Illegal varchar value for field '{field.name}'," f" length exceeds {max_len} or type mismatch" ) row_size = row_size + len(row[field.name]) elif dtype == DataType.JSON: + row[field.name] = self._try_convert_json(field.name, row[field.name]) if not validator(row[field.name]): - self._throw(f"Illegal varchar value for field '{dtype.name}', type mismatch") + self._throw(f"Illegal JSON value for field '{field.name}', type mismatch") row_size = row_size + len(row[field.name]) else: + if isinstance(row[field.name], np.generic): + row[field.name] = row[field.name].item() if not validator(row[field.name]): self._throw( - f"Illegal scalar value for field '{dtype.name}', value overflow or type mismatch" + f"Illegal scalar value for field '{field.name}', value overflow or type mismatch" ) row_size = row_size + TYPE_SIZE[dtype.name]