diff --git a/src/main/java/io/stargate/sgv2/jsonapi/exception/catchable/ToCQLCodecException.java b/src/main/java/io/stargate/sgv2/jsonapi/exception/catchable/ToCQLCodecException.java index 36059ee5d..5817a8c94 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/exception/catchable/ToCQLCodecException.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/exception/catchable/ToCQLCodecException.java @@ -9,6 +9,11 @@ * into the appropriate API error. */ public class ToCQLCodecException extends CheckedApiException { + /** + * Since String representation of some input values can be huge (like for Vectors), we limit the + * length of the description to avoid flooding logs. + */ + private static final int MAX_VALUE_DESC_LENGTH = 1000; public final Object value; public final DataType targetCQLType; @@ -39,10 +44,11 @@ private static String valueDesc(Object value) { if (value == null) { return "null"; } + String desc = maybeTruncate(String.valueOf(value)); if (value instanceof String) { - return "\"" + value + "\""; + desc = "\"" + desc + "\""; } - return String.valueOf(value); + return desc; } private static String className(Object value) { @@ -51,4 +57,11 @@ private static String className(Object value) { } return value.getClass().getName(); } + + private static String maybeTruncate(String value) { + if (value.length() <= MAX_VALUE_DESC_LENGTH) { + return value; + } + return value.substring(0, MAX_VALUE_DESC_LENGTH) + "[...](TRUNCATED)"; + } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/filters/table/codecs/JSONCodecRegistry.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/filters/table/codecs/JSONCodecRegistry.java index f26cd507d..f53eb22d2 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/filters/table/codecs/JSONCodecRegistry.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/filters/table/codecs/JSONCodecRegistry.java @@ -7,6 +7,7 @@ import com.datastax.oss.driver.api.core.type.DataTypes; import com.datastax.oss.driver.api.core.type.ListType; import com.datastax.oss.driver.api.core.type.SetType; +import com.datastax.oss.driver.api.core.type.VectorType; import com.datastax.oss.driver.api.core.type.reflect.GenericType; import io.stargate.sgv2.jsonapi.exception.catchable.MissingJSONCodecException; import io.stargate.sgv2.jsonapi.exception.catchable.ToCQLCodecException; @@ -116,6 +117,17 @@ public JSONCodec codecToCQL( CollectionCodecs.buildToCQLSetCodec(valueCodecCandidates, st.getElementType()); } // fall through + } else if (columnType instanceof VectorType vt) { + // Only Float supported for now + if (!vt.getElementType().equals(DataTypes.FLOAT)) { + throw new ToCQLCodecException(value, columnType, "only Vector supported"); + } + if (value instanceof Collection) { + return VectorCodecs.arrayToCQLFloatVectorCodec(vt); + } + // !!! TODO: different Codec for Base64 encoded (String) Float vectors + + throw new ToCQLCodecException(value, columnType, "no codec matching value type"); } throw new MissingJSONCodecException(table, columnMetadata, value.getClass(), value); @@ -189,6 +201,13 @@ public JSONCodec codecToJSON(DataType fromCQLType) { return (JSONCodec) CollectionCodecs.buildToJsonSetCodec(valueCodecCandidates.get(0)); } + if (fromCQLType instanceof VectorType vt) { + // Only Float supported for now + if (vt.getElementType().equals(DataTypes.FLOAT)) { + return VectorCodecs.toJSONFloatVectorCodec(vt); + } + // fall through + } return null; } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/filters/table/codecs/VectorCodecs.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/filters/table/codecs/VectorCodecs.java new file mode 100644 index 000000000..410197788 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/filters/table/codecs/VectorCodecs.java @@ -0,0 +1,88 @@ +package io.stargate.sgv2.jsonapi.service.operation.filters.table.codecs; + +import com.datastax.oss.driver.api.core.data.CqlVector; +import com.datastax.oss.driver.api.core.type.VectorType; +import com.datastax.oss.driver.api.core.type.reflect.GenericType; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.JsonLiteral; +import io.stargate.sgv2.jsonapi.exception.catchable.ToCQLCodecException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +/** + * Container for factories of codecs that handle CQL Vector type. Separated from main {@link + * JSONCodecs} to keep the code somewhat modular. + */ +public abstract class VectorCodecs { + private static final GenericType> FLOAT_LIST = GenericType.listOf(Float.class); + + public static JSONCodec arrayToCQLFloatVectorCodec( + VectorType vectorType) { + // Unfortunately we cannot simply construct and return a single Codec instance here + // because VectorType's dimensions vary, and we need to know the expected dimensions + // (unless we want to rely on DB validating dimension as part of write and catch failure) + return (JSONCodec) + new JSONCodec<>( + FLOAT_LIST, + vectorType, + (cqlType, value) -> toCQLFloatVector(vectorType, value), + // This codec only for to-cql case, not to-json, so we don't need this + null); + } + + public static JSONCodec toJSONFloatVectorCodec(VectorType vectorType) { + return (JSONCodec) + new JSONCodec<>( + FLOAT_LIST, + vectorType, + // This codec only for to-json case, not to-cql, so we don't need this + null, + (objectMapper, cqlType, value) -> toJsonNode(objectMapper, (CqlVector) value)); + } + + static CqlVector toCQLFloatVector(VectorType vectorType, Collection listValue) + throws ToCQLCodecException { + Collection> vectorIn = (Collection>) listValue; + final int expLen = vectorType.getDimensions(); + if (expLen != vectorIn.size()) { + throw new ToCQLCodecException( + vectorIn, + vectorType, + "expected vector of length " + + expLen + + ", got one with " + + vectorIn.size() + + " elements"); + } + List floats = new ArrayList<>(expLen); + for (JsonLiteral literalElement : vectorIn) { + Object element = literalElement.value(); + if (element instanceof Number num) { + floats.add(num.floatValue()); + continue; + } + throw new ToCQLCodecException( + vectorIn, + vectorType, + String.format( + "expected JSON Number value as Vector element at position #%d (of %d), instead have: %s", + floats.size(), expLen, literalElement)); + } + return CqlVector.newInstance(floats); + } + + static JsonNode toJsonNode(ObjectMapper objectMapper, CqlVector vectorValue) { + final ArrayNode result = objectMapper.createArrayNode(); + for (Number element : vectorValue) { + if (element == null) { // is this even legal? + result.addNull(); + } else { + result.add(element.floatValue()); + } + } + return result; + } +} diff --git a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/InsertOneTableIntegrationTest.java b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/InsertOneTableIntegrationTest.java index 693b487b2..c53cdae27 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/InsertOneTableIntegrationTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/api/v1/tables/InsertOneTableIntegrationTest.java @@ -27,6 +27,7 @@ public class InsertOneTableIntegrationTest extends AbstractTableIntegrationTestB static final String TABLE_WITH_INET_COLUMN = "insertOneInetColumnTable"; static final String TABLE_WITH_LIST_COLUMNS = "insertOneListColumnsTable"; static final String TABLE_WITH_SET_COLUMNS = "insertOneSetColumnsTable"; + static final String TABLE_WITH_VECTOR_COLUMN = "insertOneVectorColumnTable"; final JSONCodecRegistryTestData codecTestData = new JSONCodecRegistryTestData(); @@ -109,6 +110,12 @@ public final void createDefaultTables() { "stringSet", Map.of("type", "set", "valueType", "text")), "id"); + + createTableWithColumns( + TABLE_WITH_VECTOR_COLUMN, + Map.of( + "id", "text", "vector", Map.of("type", "vector", "valueType", "float", "dimension", 3)), + "id"); } @Nested @@ -790,4 +797,59 @@ void failOnWrongSetElementValue() { "Unsupported String value: only"); } } + + @Nested + @Order(10) + class InsertVectorColumns { + @Test + void insertValidVectorValue() { + String docJSON = + """ + { "id": "vectorValid", + "vector": [0.0, -0.5, 3.125] + } + """; + insertOneInTable(TABLE_WITH_VECTOR_COLUMN, docJSON); + DataApiCommandSenders.assertTableCommand(keyspaceName, TABLE_WITH_VECTOR_COLUMN) + .postFindOne("{ \"filter\": { \"id\": \"vectorValid\" } }") + .hasNoErrors() + .hasJSONField("data.document", docJSON); + } + + @Test + void failOnNonArrayVectorValue() { + DataApiCommandSenders.assertTableCommand(keyspaceName, TABLE_WITH_VECTOR_COLUMN) + .postInsertOne( + """ + { + "id": "vectorInvalid", + "vector": "abc" + } + """) + .hasSingleApiError( + DocumentException.Code.INVALID_COLUMN_VALUES, + DocumentException.class, + "Only values that are supported by", + "Error trying to convert to targetCQLType `Vector(FLOAT", + "no codec matching value type"); + } + + @Test + void failOnWrongVectorElementValue() { + DataApiCommandSenders.assertTableCommand(keyspaceName, TABLE_WITH_VECTOR_COLUMN) + .postInsertOne( + """ + { + "id":" vectorInvalid", + "vector": ["abc", 123, false] + } + """) + .hasSingleApiError( + DocumentException.Code.INVALID_COLUMN_VALUES, + DocumentException.class, + "Only values that are supported by", + "Error trying to convert to targetCQLType `Vector(FLOAT", + "expected JSON Number value as Vector element at position #0"); + } + } } diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/operation/filters/table/codecs/JSONCodecRegistryTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/operation/filters/table/codecs/JSONCodecRegistryTest.java index 547a5d8bb..843bcd9ae 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/operation/filters/table/codecs/JSONCodecRegistryTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/operation/filters/table/codecs/JSONCodecRegistryTest.java @@ -6,6 +6,7 @@ import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.core.data.CqlDuration; +import com.datastax.oss.driver.api.core.data.CqlVector; import com.datastax.oss.driver.api.core.type.DataType; import com.datastax.oss.driver.api.core.type.DataTypes; import com.fasterxml.jackson.databind.JsonNode; @@ -64,7 +65,6 @@ private JSONCodec assertGetCodecToCQL( String.format( "Get codec for cqlType=%s and fromValue.class=%s", cqlType, fromValue.getClass().getName())); - assertThat(codec) .isNotNull() .satisfies( @@ -133,6 +133,12 @@ public void codecToCQLCollections(DataType cqlType, Object fromValue, Object exp _codecToCQL(cqlType, fromValue, expectedCqlValue); } + @ParameterizedTest + @MethodSource("validCodecToCQLTestCasesVectors") + public void codecToCQLVectors(DataType cqlType, Object fromValue, Object expectedCqlValue) { + _codecToCQL(cqlType, fromValue, expectedCqlValue); + } + private void _codecToCQL(DataType cqlType, Object fromValue, Object expectedCqlValue) { var codec = assertGetCodecToCQL(cqlType, fromValue); @@ -329,6 +335,21 @@ private static Stream validCodecToCQLTestCasesCollections() { Set.of(-0.75, 42.5))); } + private static Stream validCodecToCQLTestCasesVectors() { + // Arguments: (CQL-type, from-caller-json, bound-by-driver-for-cql) + return Stream.of( + // // Lists: + Arguments.of( + DataTypes.vectorOf(DataTypes.FLOAT, 3), + // Important: all incoming JSON numbers are represented as Long, BigInteger, + // or BigDecimal. All legal as source for Float. + Arrays.asList( + numberLiteral(0L), + numberLiteral(new BigDecimal(-0.5)), + numberLiteral(new BigDecimal(0.25))), + CqlVector.newInstance(0.0f, -0.5f, 0.25f))); + } + private static JsonLiteral numberLiteral(Number value) { return new JsonLiteral<>(value, JsonType.NUMBER); } @@ -389,6 +410,12 @@ public void codecToJSONCollections( _codecToJSON(cqlType, fromValue, expectedJsonValue); } + @ParameterizedTest + @MethodSource("validCodecToJSONTestCasesVectors") + public void codecToJSONVectors(DataType cqlType, Object fromValue, JsonNode expectedJsonValue) { + _codecToJSON(cqlType, fromValue, expectedJsonValue); + } + private void _codecToJSON(DataType cqlType, Object fromValue, JsonNode expectedJsonValue) { var codec = assertGetCodecToJSON(cqlType); @@ -576,6 +603,19 @@ private static Stream validCodecToJSONTestCasesCollections() throws I OBJECT_MAPPER.readTree("[0.25,-4.5]"))); } + private static Stream validCodecToJSONTestCasesVectors() throws IOException { + // Arguments: (CQL-type, from-CQL-result-set, JsonNode-to-serialize) + return Stream.of( + Arguments.of( + DataTypes.vectorOf(DataTypes.FLOAT, 2), + CqlVector.newInstance(0.25f, -0.5f), + OBJECT_MAPPER.readTree("[0.25,-0.5]")), + Arguments.of( + DataTypes.vectorOf(DataTypes.FLOAT, 3), + CqlVector.newInstance(0.25f, -0.5f, 1.0f), + OBJECT_MAPPER.readTree("[0.25,-0.5,1.0]"))); + } + @Test public void missingJSONCodecException() { @@ -639,28 +679,13 @@ public void unknownColumnException() { @ParameterizedTest @MethodSource("outOfRangeOfCqlNumberTestCases") public void outOfRangeOfCqlNumber(DataType typeToTest, Number valueToTest, String rootCause) { - var codec = assertGetCodecToCQL(typeToTest, valueToTest); - - var error = - assertThrowsExactly( - ToCQLCodecException.class, - () -> codec.toCQL(valueToTest), - String.format( - "Throw ToCQLCodecException for out of range `%s` value: %s", - typeToTest, valueToTest)); - - assertThat(error) - .satisfies( - e -> { - assertThat(e.targetCQLType).isEqualTo(typeToTest); - assertThat(e.value).isEqualTo(valueToTest); - - assertThat(e.getMessage()) - .contains(typeToTest.toString()) - .contains(valueToTest.getClass().getName()) - .contains(valueToTest.toString()) - .contains("Root cause: " + rootCause); - }); + assertToCQLFail( + typeToTest, + valueToTest, + typeToTest.toString(), + valueToTest.getClass().getName(), + valueToTest.toString(), + "Root cause: " + rootCause); } private static Stream outOfRangeOfCqlNumberTestCases() { @@ -701,28 +726,13 @@ private static Stream outOfRangeOfCqlNumberTestCases() { @ParameterizedTest @MethodSource("nonExactToCqlIntegerTestCases") public void nonExactToCqlInteger(DataType typeToTest, Number valueToTest) { - var codec = assertGetCodecToCQL(typeToTest, valueToTest); - - var error = - assertThrowsExactly( - ToCQLCodecException.class, - () -> codec.toCQL(valueToTest), - String.format( - "Throw ToCQLCodecException when attempting to convert `%s` from non-integer value %s", - typeToTest, valueToTest)); - - assertThat(error) - .satisfies( - e -> { - assertThat(e.targetCQLType).isEqualTo(typeToTest); - assertThat(e.value).isEqualTo(valueToTest); - - assertThat(e.getMessage()) - .contains(typeToTest.toString()) - .contains(valueToTest.getClass().getName()) - .contains(valueToTest.toString()) - .contains("Root cause: Rounding necessary"); - }); + assertToCQLFail( + typeToTest, + valueToTest, + typeToTest.toString(), + valueToTest.getClass().getName(), + valueToTest.toString(), + "Root cause: Rounding necessary"); } private static Stream nonExactToCqlIntegerTestCases() { @@ -738,28 +748,12 @@ private static Stream nonExactToCqlIntegerTestCases() { @ParameterizedTest @MethodSource("nonAsciiValueFailTestCases") public void nonAsciiValueFail(String valueToTest) { - var codec = assertGetCodecToCQL(DataTypes.ASCII, valueToTest); - - var error = - assertThrowsExactly( - ToCQLCodecException.class, - () -> codec.toCQL(valueToTest), - String.format( - "Throw ToCQLCodecException when attempting to convert `%s` from non-ASCII value %s", - DataTypes.ASCII, valueToTest)); - - assertThat(error) - .satisfies( - e -> { - assertThat(e.targetCQLType).isEqualTo(DataTypes.ASCII); - assertThat(e.value).isEqualTo(valueToTest); - - assertThat(e.getMessage()) - .contains(DataTypes.ASCII.toString()) - .contains(valueToTest.getClass().getName()) - .contains(valueToTest.toString()) - .contains("Root cause: String contains non-ASCII character at index"); - }); + assertToCQLFail( + DataTypes.ASCII, + valueToTest, + valueToTest.getClass().getName(), + valueToTest.toString(), + "Root cause: String contains non-ASCII character at index"); } private static Stream nonAsciiValueFailTestCases() { @@ -809,78 +803,32 @@ private static Stream invalidCodecToCQLTestCasesDatetime() { // difficult to parameterize this test, so just test a few cases @Test public void invalidBinaryInputs() { - EJSONWrapper valueToTest1 = + assertToCQLFail( + DataTypes.BLOB, new EJSONWrapper( - EJSONWrapper.EJSONType.BINARY, JsonNodeFactory.instance.textNode("bad-base64!")); - final var codec = assertGetCodecToCQL(DataTypes.BLOB, valueToTest1); - var error = - assertThrowsExactly( - ToCQLCodecException.class, - () -> codec.toCQL(valueToTest1), - "Throw ToCQLCodecException when attempting to convert DataTypes.BLOB from invalid Base64 value"); - assertThat(error) - .satisfies( - e -> { - assertThat(e.targetCQLType).isEqualTo(DataTypes.BLOB); - assertThat(e.value).isEqualTo(valueToTest1); - assertThat(e.getMessage()) - .contains("Root cause: Invalid content in EJSON $binary wrapper"); - }); + EJSONWrapper.EJSONType.BINARY, JsonNodeFactory.instance.textNode("bad-base64!")), + "Root cause: Invalid content in EJSON $binary wrapper"); - EJSONWrapper valueToTest2 = - new EJSONWrapper(EJSONWrapper.EJSONType.BINARY, JsonNodeFactory.instance.numberNode(42)); + assertToCQLFail( + DataTypes.BLOB, + new EJSONWrapper(EJSONWrapper.EJSONType.BINARY, JsonNodeFactory.instance.numberNode(42)), + "Root cause: Unsupported JSON value type in EJSON $binary wrapper (NUMBER): only STRING allowed"); - error = - assertThrowsExactly( - ToCQLCodecException.class, - () -> codec.toCQL(valueToTest2), - "Throw ToCQLCodecException when attempting to convert DataTypes.BLOB from non-String EJSONWrapper value"); - assertThat(error) - .satisfies( - e -> { - assertThat(e.targetCQLType).isEqualTo(DataTypes.BLOB); - assertThat(e.value).isEqualTo(valueToTest2); - assertThat(e.getMessage()) - .contains( - "Root cause: Unsupported JSON value type in EJSON $binary wrapper (NUMBER): only STRING allowed"); - }); - - // Test with unpadded base64 - EJSONWrapper valueToTest3 = binaryWrapper(TEST_DATA.BASE64_UNPADDED_ENCODED_STR); - error = - assertThrowsExactly( - ToCQLCodecException.class, - () -> codec.toCQL(valueToTest3), - "Throw ToCQLCodecException when attempting to convert DataTypes.BLOB from non-String EJSONWrapper value"); - assertThat(error) - .satisfies( - e -> { - assertThat(e.targetCQLType).isEqualTo(DataTypes.BLOB); - assertThat(e.value).isEqualTo(valueToTest3); - assertThat(e.getMessage()) - .contains("Unexpected end of base64-encoded String") - .contains("expects padding"); - }); + // We require Base64 padding + assertToCQLFail( + DataTypes.BLOB, + binaryWrapper(TEST_DATA.BASE64_UNPADDED_ENCODED_STR), + "Unexpected end of base64-encoded String", + "expects padding"); } @Test public void invalidInetAddress() { - final String valueToTest = TEST_DATA.INET_ADDRESS_INVALID_STRING; - final var codec = assertGetCodecToCQL(DataTypes.INET, valueToTest); - var error = - assertThrowsExactly( - ToCQLCodecException.class, - () -> codec.toCQL(valueToTest), - "Throw ToCQLCodecException when attempting to convert DataTypes.INET from invalid Base64 value"); - assertThat(error) - .satisfies( - e -> { - assertThat(e.targetCQLType).isEqualTo(DataTypes.INET); - assertThat(e.value).isEqualTo(valueToTest); - assertThat(e.getMessage()) - .contains("Root cause: Invalid String value for type `INET`") - .contains("Invalid IP address value"); - }); + assertToCQLFail( + DataTypes.INET, + TEST_DATA.INET_ADDRESS_INVALID_STRING, + "Root cause: Invalid String value for type `INET`", + "Invalid IP address value"); } @Test @@ -906,7 +854,7 @@ public void invalidListValueFail() { public void invalidSetValueFail() { DataType cqlTypeToTest = DataTypes.setOf(DataTypes.INT); List> valueToTest = List.of(stringLiteral("xyz")); - var codec = assertGetCodecToCQL(cqlTypeToTest, new ArrayList<>()); + var codec = assertGetCodecToCQL(cqlTypeToTest, valueToTest); var error = assertThrowsExactly( @@ -920,4 +868,43 @@ public void invalidSetValueFail() { assertThat(e.getMessage()).contains("no codec matching (list/set)"); }); } + + @Test + public void invalidVectorValueNonNumberFail() { + DataType cqlTypeToTest = DataTypes.vectorOf(DataTypes.FLOAT, 1); + List> valueToTest = List.of(stringLiteral("abc")); + assertToCQLFail( + cqlTypeToTest, valueToTest, "expected JSON Number value as Vector element at position #0"); + } + + @Test + public void invalidVectorValueWrongDimensionFail() { + DataType cqlTypeToTest = DataTypes.vectorOf(DataTypes.FLOAT, 1); + List> valueToTest = List.of(numberLiteral(1.0), numberLiteral(-0.5)); + assertToCQLFail( + cqlTypeToTest, valueToTest, "expected vector of length 1, got one with 2 elements"); + } + + private void assertToCQLFail(DataType cqlType, Object valueToTest, String... expectedMessages) { + var codec = assertGetCodecToCQL(cqlType, valueToTest); + + ToCQLCodecException error = + assertThrowsExactly( + ToCQLCodecException.class, + () -> codec.toCQL(valueToTest), + String.format( + "Throw ToCQLCodecException when attempting to convert `%s` from value of %s", + cqlType, valueToTest)); + + assertThat(error) + .satisfies( + e -> { + assertThat(e.targetCQLType).isEqualTo(cqlType); + assertThat(e.value).isEqualTo(valueToTest); + + for (String expectedMessage : expectedMessages) { + assertThat(e.getMessage()).contains(expectedMessage); + } + }); + } }