Skip to content

Commit 76f46c5

Browse files
committed
Support numpy type value for bulkwriter
Signed-off-by: yhmo <[email protected]>
1 parent 50f3c2b commit 76f46c5

File tree

2 files changed

+41
-5
lines changed

2 files changed

+41
-5
lines changed

examples/example_bulkwriter.py

Lines changed: 19 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -232,7 +232,8 @@ def test_all_types_writer(bin_vec: bool, schema: CollectionSchema)->list:
232232
),
233233
) as remote_writer:
234234
print("Append rows")
235-
for i in range(10000):
235+
batch_count = 10000
236+
for i in range(batch_count):
236237
row = {
237238
"id": i,
238239
"bool": True if i%5 == 0 else False,
@@ -249,6 +250,23 @@ def test_all_types_writer(bin_vec: bool, schema: CollectionSchema)->list:
249250
}
250251
remote_writer.append_row(row)
251252

253+
# append rows by numpy type
254+
for i in range(batch_count):
255+
remote_writer.append_row({
256+
"id": np.int64(i+batch_count),
257+
"bool": True if i % 3 == 0 else False,
258+
"int8": np.int8(i%128),
259+
"int16": np.int16(i%1000),
260+
"int32": np.int32(i%100000),
261+
"int64": np.int64(i),
262+
"float": np.float32(i/3),
263+
"double": np.float64(i/7),
264+
"varchar": f"varchar_{i}",
265+
"json": json.dumps({"dummy": i, "ok": f"name_{i}"}),
266+
"vector": gen_binary_vector() if bin_vec else gen_float_vector(),
267+
f"dynamic_{i}": i,
268+
})
269+
252270
print("Generate data files...")
253271
remote_writer.commit()
254272
print(f"Data files have been uploaded: {remote_writer.batch_files}")

pymilvus/bulk_writer/bulk_writer.py

Lines changed: 22 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -10,9 +10,12 @@
1010
# or implied. See the License for the specific language governing permissions and limitations under
1111
# the License.
1212

13+
import json
1314
import logging
1415
from threading import Lock
1516

17+
import numpy as np
18+
1619
from pymilvus.client.types import DataType
1720
from pymilvus.exceptions import MilvusException
1821
from pymilvus.orm.schema import CollectionSchema
@@ -85,6 +88,16 @@ def commit(self, **kwargs):
8588
def data_path(self):
8689
return ""
8790

91+
def _try_convert_json(self, field_name: str, obj: object):
92+
if isinstance(obj, str):
93+
try:
94+
return json.loads(obj)
95+
except Exception as e:
96+
self._throw(
97+
f"Illegal JSON value for field '{field_name}', type mismatch or illegal format, error: {e}"
98+
)
99+
return obj
100+
88101
def _throw(self, msg: str):
89102
logger.error(msg)
90103
raise MilvusException(message=msg)
@@ -109,10 +122,12 @@ def _verify_row(self, row: dict):
109122
dtype = DataType(field.dtype)
110123
validator = TYPE_VALIDATOR[dtype.name]
111124
if dtype in {DataType.BINARY_VECTOR, DataType.FLOAT_VECTOR}:
125+
if isinstance(row[field.name], np.ndarray):
126+
row[field.name] = row[field.name].tolist()
112127
dim = field.params["dim"]
113128
if not validator(row[field.name], dim):
114129
self._throw(
115-
f"Illegal vector data for vector field: '{dtype.name}',"
130+
f"Illegal vector data for vector field: '{field.name}',"
116131
f" dim is not {dim} or type mismatch"
117132
)
118133

@@ -126,20 +141,23 @@ def _verify_row(self, row: dict):
126141
max_len = field.params["max_length"]
127142
if not validator(row[field.name], max_len):
128143
self._throw(
129-
f"Illegal varchar value for field '{dtype.name}',"
144+
f"Illegal varchar value for field '{field.name}',"
130145
f" length exceeds {max_len} or type mismatch"
131146
)
132147

133148
row_size = row_size + len(row[field.name])
134149
elif dtype == DataType.JSON:
150+
row[field.name] = self._try_convert_json(field.name, row[field.name])
135151
if not validator(row[field.name]):
136-
self._throw(f"Illegal varchar value for field '{dtype.name}', type mismatch")
152+
self._throw(f"Illegal JSON value for field '{field.name}', type mismatch")
137153

138154
row_size = row_size + len(row[field.name])
139155
else:
156+
if isinstance(row[field.name], np.generic):
157+
row[field.name] = row[field.name].item()
140158
if not validator(row[field.name]):
141159
self._throw(
142-
f"Illegal scalar value for field '{dtype.name}', value overflow or type mismatch"
160+
f"Illegal scalar value for field '{field.name}', value overflow or type mismatch"
143161
)
144162

145163
row_size = row_size + TYPE_SIZE[dtype.name]

0 commit comments

Comments
 (0)