Skip to content

Commit

Permalink
Implement #1492: basic Vector<Float> support with Number arrays (#1542)
Browse files Browse the repository at this point in the history
  • Loading branch information
tatu-at-datastax authored Oct 15, 2024
1 parent ebc1f85 commit e5f718c
Show file tree
Hide file tree
Showing 5 changed files with 304 additions and 135 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
* into the appropriate API error.
*/
public class ToCQLCodecException extends CheckedApiException {
/**
* Since String representation of some input values can be huge (like for Vectors), we limit the
* length of the description to avoid flooding logs.
*/
private static final int MAX_VALUE_DESC_LENGTH = 1000;

public final Object value;
public final DataType targetCQLType;
Expand Down Expand Up @@ -39,10 +44,11 @@ private static String valueDesc(Object value) {
if (value == null) {
return "null";
}
String desc = maybeTruncate(String.valueOf(value));
if (value instanceof String) {
return "\"" + value + "\"";
desc = "\"" + desc + "\"";
}
return String.valueOf(value);
return desc;
}

private static String className(Object value) {
Expand All @@ -51,4 +57,11 @@ private static String className(Object value) {
}
return value.getClass().getName();
}

/**
 * Shortens an over-long value description so that huge values do not flood logs.
 *
 * <p>Descriptions of at most {@code MAX_VALUE_DESC_LENGTH} chars are returned as-is. Longer ones
 * are cut to (at most) that many chars and get the marker {@code "[...](TRUNCATED)"} appended.
 * The cut never lands between the two halves of a UTF-16 surrogate pair: if the last retained
 * char would be a high surrogate, it is dropped as well so that the result does not contain a
 * dangling, unprintable half of a supplementary character.
 *
 * @param value description to shorten; must not be {@code null}
 * @return the original description, or its truncated prefix followed by the truncation marker
 */
private static String maybeTruncate(String value) {
  if (value.length() <= MAX_VALUE_DESC_LENGTH) {
    return value;
  }
  int end = MAX_VALUE_DESC_LENGTH;
  // Do not split a surrogate pair: back off by one char if the cut would orphan a high surrogate
  if (end > 0 && Character.isHighSurrogate(value.charAt(end - 1))) {
    --end;
  }
  return value.substring(0, end) + "[...](TRUNCATED)";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import com.datastax.oss.driver.api.core.type.DataTypes;
import com.datastax.oss.driver.api.core.type.ListType;
import com.datastax.oss.driver.api.core.type.SetType;
import com.datastax.oss.driver.api.core.type.VectorType;
import com.datastax.oss.driver.api.core.type.reflect.GenericType;
import io.stargate.sgv2.jsonapi.exception.catchable.MissingJSONCodecException;
import io.stargate.sgv2.jsonapi.exception.catchable.ToCQLCodecException;
Expand Down Expand Up @@ -116,6 +117,17 @@ public <JavaT, CqlT> JSONCodec<JavaT, CqlT> codecToCQL(
CollectionCodecs.buildToCQLSetCodec(valueCodecCandidates, st.getElementType());
}
// fall through
} else if (columnType instanceof VectorType vt) {
// Only Vector<Float> supported for now
if (!vt.getElementType().equals(DataTypes.FLOAT)) {
throw new ToCQLCodecException(value, columnType, "only Vector<Float> supported");
}
if (value instanceof Collection<?>) {
return VectorCodecs.arrayToCQLFloatVectorCodec(vt);
}
// !!! TODO: different Codec for Base64 encoded (String) Float vectors

throw new ToCQLCodecException(value, columnType, "no codec matching value type");
}

throw new MissingJSONCodecException(table, columnMetadata, value.getClass(), value);
Expand Down Expand Up @@ -189,6 +201,13 @@ public <JavaT, CqlT> JSONCodec<JavaT, CqlT> codecToJSON(DataType fromCQLType) {
return (JSONCodec<JavaT, CqlT>)
CollectionCodecs.buildToJsonSetCodec(valueCodecCandidates.get(0));
}
if (fromCQLType instanceof VectorType vt) {
// Only Vector<Float> supported for now
if (vt.getElementType().equals(DataTypes.FLOAT)) {
return VectorCodecs.toJSONFloatVectorCodec(vt);
}
// fall through
}

return null;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package io.stargate.sgv2.jsonapi.service.operation.filters.table.codecs;

import com.datastax.oss.driver.api.core.data.CqlVector;
import com.datastax.oss.driver.api.core.type.VectorType;
import com.datastax.oss.driver.api.core.type.reflect.GenericType;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.JsonLiteral;
import io.stargate.sgv2.jsonapi.exception.catchable.ToCQLCodecException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

/**
 * Holder of factory methods for codecs that deal with the CQL Vector type. Lives apart from the
 * main {@link JSONCodecs} so that the codec code stays reasonably modular.
 */
public abstract class VectorCodecs {
  private static final GenericType<List<Float>> FLOAT_LIST = GenericType.listOf(Float.class);

  /**
   * Creates the to-CQL codec that converts a collection of JSON number literals into a CQL float
   * vector of the given type.
   *
   * @param vectorType target CQL vector type; its dimension is used to validate input length
   * @return codec with only the to-CQL direction populated (to-JSON converter is {@code null})
   */
  public static <JavaT, CqlT> JSONCodec<JavaT, CqlT> arrayToCQLFloatVectorCodec(
      VectorType vectorType) {
    // A single shared Codec instance is not an option: dimensions differ between VectorTypes
    // and the expected dimension must be known up front (the alternative would be to let the
    // DB validate the dimension on write and catch that failure).
    return (JSONCodec<JavaT, CqlT>)
        new JSONCodec<>(
            FLOAT_LIST,
            vectorType,
            (targetType, jsonValue) -> toCQLFloatVector(vectorType, jsonValue),
            // to-JSON direction is never used with this codec, hence no converter
            null);
  }

  /**
   * Creates the to-JSON codec that converts a CQL vector value into a JSON array.
   *
   * @param vectorType source CQL vector type
   * @return codec with only the to-JSON direction populated (to-CQL converter is {@code null})
   */
  public static <JavaT, CqlT> JSONCodec<JavaT, CqlT> toJSONFloatVectorCodec(VectorType vectorType) {
    return (JSONCodec<JavaT, CqlT>)
        new JSONCodec<>(
            FLOAT_LIST,
            vectorType,
            // to-CQL direction is never used with this codec, hence no converter
            null,
            (mapper, sourceType, cqlValue) -> toJsonNode(mapper, (CqlVector<Number>) cqlValue));
  }

  /**
   * Converts a collection of {@link JsonLiteral}s into a {@link CqlVector} of floats.
   *
   * @param vectorType target vector type, source of the expected dimension
   * @param listValue input elements; cast (unchecked) to a collection of {@link JsonLiteral}
   * @return vector holding {@code floatValue()} of each input number, in iteration order
   * @throws ToCQLCodecException if the element count differs from the vector dimension, or if an
   *     element's value is not a {@link Number}
   */
  static CqlVector<Float> toCQLFloatVector(VectorType vectorType, Collection<?> listValue)
      throws ToCQLCodecException {
    final Collection<JsonLiteral<?>> literals = (Collection<JsonLiteral<?>>) listValue;
    final int expectedSize = vectorType.getDimensions();
    final int actualSize = literals.size();
    if (actualSize != expectedSize) {
      throw new ToCQLCodecException(
          literals,
          vectorType,
          "expected vector of length " + expectedSize + ", got one with " + actualSize + " elements");
    }

    final List<Float> converted = new ArrayList<>(expectedSize);
    for (JsonLiteral<?> literal : literals) {
      if (!(literal.value() instanceof Number number)) {
        // Count of elements converted so far doubles as the index of the offending element
        throw new ToCQLCodecException(
            literals,
            vectorType,
            String.format(
                "expected JSON Number value as Vector element at position #%d (of %d), instead have: %s",
                converted.size(), expectedSize, literal));
      }
      converted.add(number.floatValue());
    }
    return CqlVector.newInstance(converted);
  }

  /**
   * Converts a CQL vector into a JSON array node.
   *
   * @param objectMapper mapper used to create the array node
   * @param vectorValue vector to convert
   * @return array with one float entry per element; {@code null} elements become JSON nulls
   */
  static JsonNode toJsonNode(ObjectMapper objectMapper, CqlVector<Number> vectorValue) {
    final ArrayNode array = objectMapper.createArrayNode();
    for (Number number : vectorValue) {
      if (number != null) {
        array.add(number.floatValue());
      } else { // unclear whether null elements can legally occur; handled defensively
        array.addNull();
      }
    }
    return array;
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ public class InsertOneTableIntegrationTest extends AbstractTableIntegrationTestB
static final String TABLE_WITH_INET_COLUMN = "insertOneInetColumnTable";
static final String TABLE_WITH_LIST_COLUMNS = "insertOneListColumnsTable";
static final String TABLE_WITH_SET_COLUMNS = "insertOneSetColumnsTable";
static final String TABLE_WITH_VECTOR_COLUMN = "insertOneVectorColumnTable";

final JSONCodecRegistryTestData codecTestData = new JSONCodecRegistryTestData();

Expand Down Expand Up @@ -109,6 +110,12 @@ public final void createDefaultTables() {
"stringSet",
Map.of("type", "set", "valueType", "text")),
"id");

createTableWithColumns(
TABLE_WITH_VECTOR_COLUMN,
Map.of(
"id", "text", "vector", Map.of("type", "vector", "valueType", "float", "dimension", 3)),
"id");
}

@Nested
Expand Down Expand Up @@ -790,4 +797,59 @@ void failOnWrongSetElementValue() {
"Unsupported String value: only");
}
}

@Nested
@Order(10)
class InsertVectorColumns {
  /** Inserts a row with a numeric vector, then reads it back by id and expects the same doc. */
  @Test
  void insertValidVectorValue() {
    final String expectedDoc =
        """
        { "id": "vectorValid",
        "vector": [0.0, -0.5, 3.125]
        }
        """;
    final String findById = "{ \"filter\": { \"id\": \"vectorValid\" } }";

    insertOneInTable(TABLE_WITH_VECTOR_COLUMN, expectedDoc);

    DataApiCommandSenders.assertTableCommand(keyspaceName, TABLE_WITH_VECTOR_COLUMN)
        .postFindOne(findById)
        .hasNoErrors()
        .hasJSONField("data.document", expectedDoc);
  }

  /** A JSON String given for the vector column must be rejected: no codec matches it. */
  @Test
  void failOnNonArrayVectorValue() {
    final String invalidDoc =
        """
        {
        "id": "vectorInvalid",
        "vector": "abc"
        }
        """;

    DataApiCommandSenders.assertTableCommand(keyspaceName, TABLE_WITH_VECTOR_COLUMN)
        .postInsertOne(invalidDoc)
        .hasSingleApiError(
            DocumentException.Code.INVALID_COLUMN_VALUES,
            DocumentException.class,
            "Only values that are supported by",
            "Error trying to convert to targetCQLType `Vector(FLOAT",
            "no codec matching value type");
  }

  /** An array whose first element is not a JSON Number must be rejected at position #0. */
  @Test
  void failOnWrongVectorElementValue() {
    // NOTE(review): the "id" value below starts with a space; kept verbatim -- confirm intended
    final String invalidDoc =
        """
        {
        "id":" vectorInvalid",
        "vector": ["abc", 123, false]
        }
        """;

    DataApiCommandSenders.assertTableCommand(keyspaceName, TABLE_WITH_VECTOR_COLUMN)
        .postInsertOne(invalidDoc)
        .hasSingleApiError(
            DocumentException.Code.INVALID_COLUMN_VALUES,
            DocumentException.class,
            "Only values that are supported by",
            "Error trying to convert to targetCQLType `Vector(FLOAT",
            "expected JSON Number value as Vector element at position #0");
  }
}
}
Loading

0 comments on commit e5f718c

Please sign in to comment.