diff --git a/bson/primitive_codecs_test.go b/bson/primitive_codecs_test.go
index 6071ea02f9..5ff03d302c 100644
--- a/bson/primitive_codecs_test.go
+++ b/bson/primitive_codecs_test.go
@@ -8,6 +8,7 @@ package bson
 
 import (
 	"bytes"
+	"encoding/binary"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -1116,3 +1117,50 @@ func compareDecimal128(d1, d2 Decimal128) bool {
 
 	return true
 }
+
+func TestSliceCodec(t *testing.T) {
+	t.Run("[]byte is treated as binary data", func(t *testing.T) {
+		type testStruct struct {
+			B []byte `bson:"b"`
+		}
+
+		testData := testStruct{B: []byte{0x01, 0x02, 0x03}}
+		data, err := Marshal(testData)
+		assert.Nil(t, err, "Marshal error: %v", err)
+		var doc D
+		err = Unmarshal(data, &doc)
+		assert.Nil(t, err, "Unmarshal error: %v", err)
+
+		offset := 4 + 1 + 2 // Skip document length (4), element type (1), and key "b\x00" (2)
+		length := int32(binary.LittleEndian.Uint32(data[offset:]))
+		offset += 4 // Skip length
+		subtype := data[offset]
+		offset++ // Skip subtype
+		dataBytes := data[offset : offset+int(length)]
+
+		assert.Equal(t, byte(0x00), subtype, "Expected binary subtype 0x00")
+		assert.Equal(t, []byte{0x01, 0x02, 0x03}, dataBytes, "Binary data mismatch")
+	})
+
+	t.Run("[]int8 is not treated as binary data", func(t *testing.T) {
+		type testStruct struct {
+			I []int8 `bson:"i"`
+		}
+		testData := testStruct{I: []int8{1, 2, 3}}
+		data, err := Marshal(testData)
+		assert.Nil(t, err, "Marshal error: %v", err)
+
+		offset := 4 // Skip document length
+		assert.Equal(t, byte(0x04), data[offset], "Expected array type (0x04), got: 0x%02x", data[offset])
+
+		var result struct {
+			I []int32 `bson:"i"`
+		}
+		err = Unmarshal(data, &result)
+		assert.Nil(t, err, "Unmarshal result error: %v", err)
+		assert.Equal(t, 3, len(result.I), "Expected array length 3")
+		assert.Equal(t, int32(1), result.I[0], "Array element 0 mismatch")
+		assert.Equal(t, int32(2), result.I[1], "Array element 1 mismatch")
+		assert.Equal(t, int32(3), result.I[2], "Array element 2 mismatch")
+	})
+}
diff --git a/bson/slice_codec.go b/bson/slice_codec.go
index c8719dcc18..b834762533 100644
--- a/bson/slice_codec.go
+++ b/bson/slice_codec.go
@@ -7,8 +7,10 @@
 package bson
 
 import (
+	"encoding/binary"
 	"errors"
 	"fmt"
+	"math"
 	"reflect"
 )
 
@@ -19,6 +21,43 @@ type sliceCodec struct {
 	encodeNilAsEmpty bool
 }
 
+// decodeVectorBinary decodes a BSON Vector binary (subtype 9) into val, which must be a
+// []int8 or []float32. It returns errNotAVectorBinary, without reading from vr, for any other
+// element type, and also after reading the binary when its subtype is not the Vector subtype.
+func (sc *sliceCodec) decodeVectorBinary(vr ValueReader, val reflect.Value) error {
+	elemType := val.Type().Elem()
+
+	if elemType != tInt8 && elemType != tFloat32 {
+		return errNotAVectorBinary
+	}
+
+	data, subtype, err := vr.ReadBinary()
+	if err != nil {
+		return err
+	}
+
+	if subtype != TypeBinaryVector {
+		return errNotAVectorBinary
+	}
+
+	switch elemType {
+	case tInt8:
+		int8Slice, err := decodeVectorInt8(data)
+		if err != nil {
+			return err
+		}
+		val.Set(reflect.ValueOf(int8Slice))
+	case tFloat32:
+		float32Slice, err := decodeVectorFloat32(data)
+		if err != nil {
+			return err
+		}
+		val.Set(reflect.ValueOf(float32Slice))
+	}
+
+	return nil
+}
+
 // EncodeValue is the ValueEncoder for slice types.
 func (sc *sliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
 	if !val.IsValid() || val.Kind() != reflect.Slice {
@@ -29,8 +68,9 @@ func (sc *sliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.
 		return vw.WriteNull()
 	}
 
-	// If we have a []byte we want to treat it as a binary instead of as an array.
+	// If we have a []byte we want to treat it as a binary instead of as an array. A []int8 never
+	// takes this branch: int8 and byte (uint8) are distinct types, so its element type is not tByte.
 	if val.Type().Elem() == tByte {
 		byteSlice := make([]byte, val.Len())
 		reflect.Copy(reflect.ValueOf(byteSlice), val)
 		return vw.WriteBinary(byteSlice)
@@ -112,6 +152,14 @@ func (sc *sliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.
 			return fmt.Errorf("cannot decode document into %s", val.Type())
 		}
 	case TypeBinary:
+		err := sc.decodeVectorBinary(vr, val)
+		if err == nil {
+			return nil
+		}
+		if !errors.Is(err, errNotAVectorBinary) {
+			return err
+		}
+
 		if val.Type().Elem() != tByte {
 			return fmt.Errorf("SliceDecodeValue can only decode a binary into a byte array, got %v", vrType)
 		}
@@ -171,3 +219,60 @@ func (sc *sliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.
 
 	return nil
 }
+
+// decodeVectorInt8 decodes the payload of a BSON Vector binary (subtype 9) into a []int8.
+// The payload layout is: vector type (1 byte), padding (1 byte), then one byte per element.
+// The vector type must be Int8Vector (0x03) and the padding byte must be 0.
+func decodeVectorInt8(data []byte) ([]int8, error) {
+	if len(data) < 2 {
+		return nil, fmt.Errorf("insufficient bytes to decode vector: expected at least 2 bytes")
+	}
+
+	vectorType := data[0]
+	if vectorType != Int8Vector {
+		return nil, fmt.Errorf("invalid vector type: expected int8 vector (0x%02x), got 0x%02x", Int8Vector, vectorType)
+	}
+
+	if padding := data[1]; padding != 0 {
+		return nil, fmt.Errorf("invalid vector: padding byte must be 0")
+	}
+
+	values := make([]int8, 0, len(data)-2)
+	for i := 2; i < len(data); i++ {
+		values = append(values, int8(data[i]))
+	}
+
+	return values, nil
+}
+
+// decodeVectorFloat32 decodes the payload of a BSON Vector binary (subtype 9) into a []float32.
+// The payload layout is: vector type (1 byte), padding (1 byte), then 4 little-endian bytes per
+// element. The vector type must be Float32Vector (0x27) and the padding byte must be 0.
+func decodeVectorFloat32(data []byte) ([]float32, error) {
+	if len(data) < 2 {
+		return nil, fmt.Errorf("insufficient bytes to decode vector: expected at least 2 bytes")
+	}
+
+	vectorType := data[0]
+	if vectorType != Float32Vector {
+		return nil, fmt.Errorf("invalid vector type: expected float32 vector (0x%02x), got 0x%02x", Float32Vector, vectorType)
+	}
+
+	if padding := data[1]; padding != 0 {
+		return nil, fmt.Errorf("invalid vector: padding byte must be 0")
+	}
+
+	floatData := data[2:]
+	if len(floatData)%4 != 0 {
+		return nil, fmt.Errorf("invalid float32 vector: data length must be a multiple of 4")
+	}
+
+	// The length check above guarantees that every 4-byte window below is complete.
+	values := make([]float32, 0, len(floatData)/4)
+	for i := 0; i < len(floatData); i += 4 {
+		bits := binary.LittleEndian.Uint32(floatData[i : i+4])
+		values = append(values, math.Float32frombits(bits))
+	}
+
+	return values, nil
+}
diff --git a/bson/types.go b/bson/types.go
index c2883aa4ef..91bd9e32fb 100644
--- a/bson/types.go
+++ b/bson/types.go
@@ -77,7 +77,9 @@ const (
 )
 
 var tBool = reflect.TypeOf(false)
+var tFloat32 = reflect.TypeOf(float32(0))
 var tFloat64 = reflect.TypeOf(float64(0))
+var tInt8 = reflect.TypeOf(int8(0))
 var tInt32 = reflect.TypeOf(int32(0))
 var tInt64 = reflect.TypeOf(int64(0))
 var tString = reflect.TypeOf("")
diff --git a/bson/vector.go b/bson/vector.go
index 31a10bd5be..f0735f806f 100644
--- a/bson/vector.go
+++ b/bson/vector.go
@@ -25,6 +25,7 @@ var (
 	errInsufficientVectorData = errors.New("insufficient data")
 	errNonZeroVectorPadding   = errors.New("padding must be 0")
 	errVectorPaddingTooLarge  = errors.New("padding cannot be larger than 7")
+	errNotAVectorBinary       = errors.New("not a vector binary")
 )
 
 type vectorTypeError struct {
diff --git a/bson/vector_unmarshal_test.go b/bson/vector_unmarshal_test.go
new file mode 100644
index 0000000000..01ca9448ac
--- /dev/null
+++ b/bson/vector_unmarshal_test.go
@@ -0,0 +1,155 @@
+// Copyright (C) MongoDB, Inc. 2025-present.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License. You may obtain
+// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+
+package bson
+
+import (
+	"fmt"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+// createBSONWithBinary returns the bytes of the BSON document {"v": BinData(0x09, data)}:
+// a single binary field "v" whose subtype is Vector (0x09) and whose payload is data.
+func createBSONWithBinary(data []byte) []byte {
+	buf := make([]byte, 0, 32+len(data))
+
+	buf = append(buf, 0x00, 0x00, 0x00, 0x00) // Document length placeholder, patched below
+	buf = append(buf, 0x05)                   // Element type: binary
+	buf = append(buf, 'v', 0x00)              // Field name "v"
+
+	// Binary payload length as a 4-byte little-endian value, then the subtype.
+	binLen := len(data)
+	buf = append(buf,
+		byte(binLen), byte(binLen>>8), byte(binLen>>16), byte(binLen>>24),
+		0x09, // Binary subtype for Vector
+	)
+	buf = append(buf, data...)
+	buf = append(buf, 0x00) // Document terminator
+
+	docLen := len(buf)
+	buf[0] = byte(docLen)
+	buf[1] = byte(docLen >> 8)
+	buf[2] = byte(docLen >> 16)
+	buf[3] = byte(docLen >> 24)
+
+	return buf
+}
+
+func TestVectorBackwardCompatibility(t *testing.T) {
+	t.Parallel()
+
+	t.Run("unmarshal to Vector field", func(t *testing.T) {
+		t.Parallel()
+
+		vectorData := []byte{
+			0x03, // int8 vector type (0x03 is Int8Vector)
+			0x00, // padding
+			0x01, 0x02, 0x03, 0x04, // int8 values
+		}
+
+		doc := createBSONWithBinary(vectorData)
+
+		var result struct {
+			V Vector
+		}
+		err := Unmarshal(doc, &result)
+		require.NoError(t, err)
+
+		require.Equal(t, Int8Vector, result.V.Type())
+		int8Data, ok := result.V.Int8OK()
+		require.True(t, ok, "expected int8 vector")
+		require.Equal(t, []int8{1, 2, 3, 4}, int8Data)
+	})
+}
+
+func TestUnmarshalVectorToSlices(t *testing.T) {
+	t.Parallel()
+
+	t.Run("int8 vector to []int8", func(t *testing.T) {
+		t.Parallel()
+
+		doc := D{{"v", NewVector([]int8{-2, 1, 2, 3, 4})}}
+		bsonData, err := Marshal(doc)
+		require.NoError(t, err)
+		var result struct{ V []int8 }
+		err = Unmarshal(bsonData, &result)
+		require.NoError(t, err)
+		require.Equal(t, []int8{-2, 1, 2, 3, 4}, result.V)
+	})
+
+	t.Run("float32 vector to []float32", func(t *testing.T) {
+		t.Parallel()
+
+		doc := D{{"v", NewVector([]float32{1.1, 2.2, 3.3, 4.4})}}
+		bsonData, err := Marshal(doc)
+		require.NoError(t, err)
+		var result struct{ V []float32 }
+		err = Unmarshal(bsonData, &result)
+		require.NoError(t, err)
+		require.InDeltaSlice(t, []float32{1.1, 2.2, 3.3, 4.4}, result.V, 0.001)
+	})
+
+	t.Run("invalid vector type to slice", func(t *testing.T) {
+		t.Parallel()
+
+		vectorData := []byte{
+			0x10, // packed bit vector type (unsupported for direct unmarshaling)
+			0x00, // padding
+			0x01, 0x02, // some data
+		}
+		bsonData := createBSONWithBinary(vectorData)
+
+		t.Run("to []int8", func(t *testing.T) {
+			t.Parallel()
+
+			vectorData := []byte{0x10, 0x00} // Invalid vector type
+			bsonData := createBSONWithBinary(vectorData)
+
+			var result struct{ V []int8 }
+			err := Unmarshal(bsonData, &result)
+			require.Error(t, err)
+			require.Contains(t, err.Error(), fmt.Sprintf("invalid vector type: expected int8 vector (0x%02x)", Int8Vector))
+		})
+
+		t.Run("to []float32", func(t *testing.T) {
+			t.Parallel()
+			var result struct{ V []float32 }
+			err := Unmarshal(bsonData, &result)
+			require.Error(t, err)
+			require.Contains(t, err.Error(), fmt.Sprintf("invalid vector type: expected float32 vector (0x%02x)", Float32Vector))
+		})
+	})
+
+	t.Run("invalid binary data", func(t *testing.T) {
+		t.Parallel()
+
+		vectorData := []byte{0x01, 0x00, 0x01, 0x02, 0x03, 0x04}
+		bsonData := createBSONWithBinary(vectorData)
+
+		t.Run("to []int8", func(t *testing.T) {
+			t.Parallel()
+
+			var result struct{ V []int8 }
+			err := Unmarshal(bsonData, &result)
+			require.Error(t, err)
+			require.Contains(t, err.Error(), fmt.Sprintf("invalid vector type: expected int8 vector (0x%02x)", Int8Vector))
+		})
+
+		t.Run("to []float32", func(t *testing.T) {
+			t.Parallel()
+
+			vectorData := []byte{0x01, 0x00}
+			bsonData := createBSONWithBinary(vectorData)
+
+			var result struct{ V []float32 }
+			err := Unmarshal(bsonData, &result)
+			require.Error(t, err)
+			require.Contains(t, err.Error(), fmt.Sprintf("invalid vector type: expected float32 vector (0x%02x)", Float32Vector))
+		})
+	})
+}
diff --git a/mongo/mongo_test.go b/mongo/mongo_test.go
index 96be905cb5..3270ee638e 100644
--- a/mongo/mongo_test.go
+++ b/mongo/mongo_test.go
@@ -7,6 +7,7 @@
 package mongo
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"reflect"
@@ -652,3 +653,188 @@ type bvMarsh struct {
 func (b bvMarsh) MarshalBSONValue() (byte, []byte, error) {
 	return byte(b.t), b.data, b.err
 }
+
+func TestVectorIntegration(t *testing.T) {
+	t.Run("roundtrip int8 vector", func(t *testing.T) {
+		if testing.Short() {
+			t.Skip("skipping integration test in short mode")
+		}
+		type vectorDoc struct {
+			ID  string `bson:"_id"`
+			Vec []int8 `bson:"v"`
+		}
+
+		ctx := context.Background()
+		client := setupClient()
+		defer func() {
+			_ = client.Disconnect(ctx)
+		}()
+
+		db := client.Database("test")
+		coll := db.Collection("vector_test")
+
+		_, _ = coll.DeleteMany(ctx, bson.M{"$or": []bson.M{
+			{"_id": "test_int8"},
+		}})
+
+		expected := vectorDoc{
+			ID:  "test_int8",
+			Vec: []int8{-2, -1, 0, 1, 2},
+		}
+
+		_, err := coll.InsertOne(ctx, expected)
+		if err != nil {
+			t.Fatalf("InsertOne error: %v", err)
+		}
+
+		var result vectorDoc
+		err = coll.FindOne(ctx, bson.M{"_id": "test_int8"}).Decode(&result)
+		if err != nil {
+			t.Fatalf("FindOne error: %v", err)
+		}
+
+		if !reflect.DeepEqual(expected.Vec, result.Vec) {
+			t.Errorf("vector data does not match. Expected %v, got %v", expected.Vec, result.Vec)
+		}
+	})
+
+	t.Run("roundtrip float32 vector", func(t *testing.T) {
+		if testing.Short() {
+			t.Skip("skipping integration test in short mode")
+		}
+		type vectorDoc struct {
+			ID  string    `bson:"_id"`
+			Vec []float32 `bson:"v"`
+		}
+
+		ctx := context.Background()
+		client := setupClient()
+		defer func() {
+			_ = client.Disconnect(ctx)
+		}()
+
+		db := client.Database("test")
+		coll := db.Collection("vector_test")
+
+		_, _ = coll.DeleteMany(ctx, bson.M{"$or": []bson.M{
+			{"_id": "test_float32"},
+		}})
+		expected := vectorDoc{
+			ID:  "test_float32",
+			Vec: []float32{-1.1, 0.0, 0.5, 1.1, 2.2},
+		}
+
+		_, err := coll.InsertOne(ctx, expected)
+		if err != nil {
+			t.Fatalf("InsertOne error: %v", err)
+		}
+
+		var result vectorDoc
+		err = coll.FindOne(ctx, bson.M{"_id": "test_float32"}).Decode(&result)
+		if err != nil {
+			t.Fatalf("FindOne error: %v", err)
+		}
+
+		if len(expected.Vec) != len(result.Vec) {
+			t.Fatalf("vector length mismatch: expected %d, got %d", len(expected.Vec), len(result.Vec))
+		}
+		for i := range expected.Vec {
+			if diff := expected.Vec[i] - result.Vec[i]; diff < -0.0001 || diff > 0.0001 {
+				t.Errorf("vector element %d mismatch: expected %v, got %v", i, expected.Vec[i], result.Vec[i])
+			}
+		}
+	})
+
+	t.Run("bson.NewVector with int8", func(t *testing.T) {
+		if testing.Short() {
+			t.Skip("skipping integration test in short mode")
+		}
+		type vectorDoc struct {
+			ID  string `bson:"_id"`
+			Vec []int8 `bson:"v"`
+		}
+
+		ctx := context.Background()
+		client := setupClient()
+		defer func() {
+			_ = client.Disconnect(ctx)
+		}()
+
+		db := client.Database("test")
+		coll := db.Collection("vector_test")
+
+		testID := "test_new_vector_int8"
+		_, _ = coll.DeleteMany(ctx, bson.M{"$or": []bson.M{
+			{"_id": testID},
+		}})
+
+		expected := []int8{-2, -1, 0, 1, 2}
+
+		_, err := coll.InsertOne(ctx, bson.D{
+			{Key: "_id", Value: testID},
+			{Key: "v", Value: bson.NewVector(expected)},
+		})
+		if err != nil {
+			t.Fatalf("InsertOne error: %v", err)
+		}
+
+		var result vectorDoc
+		err = coll.FindOne(ctx, bson.M{"_id": testID}).Decode(&result)
+		if err != nil {
+			t.Fatalf("FindOne error: %v", err)
+		}
+
+		if !reflect.DeepEqual(expected, result.Vec) {
+			t.Errorf("vector data does not match. Expected %v, got %v", expected, result.Vec)
+		}
+	})
+
+	t.Run("bson.NewVector with float32", func(t *testing.T) {
+		if testing.Short() {
+			t.Skip("skipping integration test in short mode")
+		}
+		type vectorDoc struct {
+			ID  string    `bson:"_id"`
+			Vec []float32 `bson:"v"`
+		}
+
+		ctx := context.Background()
+		client := setupClient()
+		defer func() {
+			_ = client.Disconnect(ctx)
+		}()
+
+		db := client.Database("test")
+		coll := db.Collection("vector_test")
+
+		testID := "test_new_vector_float32"
+		_, _ = coll.DeleteMany(ctx, bson.M{"$or": []bson.M{
+			{"_id": testID},
+		}})
+
+		expected := []float32{-1.1, 0.0, 0.5, 1.1, 2.2}
+
+		_, err := coll.InsertOne(ctx, bson.D{
+			{Key: "_id", Value: testID},
+			{Key: "v", Value: bson.NewVector(expected)},
+		})
+		if err != nil {
+			t.Fatalf("InsertOne error: %v", err)
+		}
+
+		var result vectorDoc
+		err = coll.FindOne(ctx, bson.M{"_id": testID}).Decode(&result)
+		if err != nil {
+			t.Fatalf("FindOne error: %v", err)
+		}
+
+		if len(expected) != len(result.Vec) {
+			t.Fatalf("vector length mismatch: expected %d, got %d", len(expected), len(result.Vec))
+		}
+		for i := range expected {
+			if diff := expected[i] - result.Vec[i]; diff < -0.0001 || diff > 0.0001 {
+				t.Errorf("vector element %d mismatch: expected %v, got %v", i, expected[i], result.Vec[i])
+			}
+		}
+	})
+}