Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #1775: use optimized CqlVector<Float> codec to improve performance #1801

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import io.stargate.sgv2.jsonapi.config.OperationsConfig;
import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.SchemaCache;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.optvector.SubtypeOnlyFloatVectorToArrayCodec;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;
import java.net.InetSocketAddress;
Expand Down Expand Up @@ -148,6 +149,9 @@ private CqlSession getNewSession(SessionCacheKey cacheKey) {
builder.addContactPoints(seeds);
}

// Add optimized CqlVector codec (see [data-api#1775])
builder = builder.addTypeCodecs(SubtypeOnlyFloatVectorToArrayCodec.instance());

// aaron - this used to have an if / else that threw an exception if the database type was not
// known but we test that when creating the credentials for the cache key so no need to do it
// here.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
package io.stargate.sgv2.jsonapi.service.cqldriver.executor.optvector;

import com.datastax.oss.driver.api.core.ProtocolVersion;
import com.datastax.oss.driver.api.core.type.DataType;
import com.datastax.oss.driver.api.core.type.DataTypes;
import com.datastax.oss.driver.api.core.type.VectorType;
import com.datastax.oss.driver.api.core.type.codec.TypeCodec;
import com.datastax.oss.driver.api.core.type.reflect.GenericType;
import com.datastax.oss.driver.internal.core.type.codec.FloatCodec;
import com.datastax.oss.driver.shaded.guava.common.base.Splitter;
import com.datastax.oss.driver.shaded.guava.common.collect.Iterators;
import edu.umd.cs.findbugs.annotations.NonNull;
import edu.umd.cs.findbugs.annotations.Nullable;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Objects;

/**
* Implementation of {@link TypeCodec} which translates CQL vectors into float arrays. Difference
* between this and {@link
* com.datastax.oss.driver.internal.core.type.codec.extras.vector.FloatVectorToArrayCodec} is that
* we don't concern ourselves with the dimensionality specified in the input CQL type. This codec
* just reads all the bytes, tries to deserialize them consecutively into subtypes and then returns
* the result. Serialization is similar: we take the input array, serialize each element and return
* the result.
*/
public class SubtypeOnlyFloatVectorToArrayCodec implements TypeCodec<float[]> {

// Number of bytes occupied by one encoded float element; used to size the output buffer in
// encode() and to validate the input length in decode().
private static final int ELEMENT_SIZE = 4;

// CQL-side type reported by getCqlType(); assigned in the constructor as a SubtypeOnlyVectorType.
@NonNull protected final VectorType cqlType;
// Java-side type reported by getJavaType(); assigned in the constructor as float[].
@NonNull protected final GenericType<float[]> javaType;

// Element codec used by parse(String) to turn each textual element into a float.
private final FloatCodec floatCodec = new FloatCodec();

// Shared instance handed out by instance(); built with DataTypes.FLOAT as the vector subtype.
private static final SubtypeOnlyFloatVectorToArrayCodec INSTANCE =
new SubtypeOnlyFloatVectorToArrayCodec(DataTypes.FLOAT);

/**
 * Builds a codec whose CQL type is a {@link SubtypeOnlyVectorType} over the given element type
 * and whose Java type is {@code float[]}.
 *
 * @param subType element type of the vectors handled by this codec
 * @throws NullPointerException if {@code subType} is null
 */
private SubtypeOnlyFloatVectorToArrayCodec(@NonNull DataType subType) {
  DataType elementType = Objects.requireNonNull(subType, "subType cannot be null");
  this.cqlType = new SubtypeOnlyVectorType(elementType);
  this.javaType = GenericType.of(float[].class);
}

/**
 * Returns the shared codec instance held by this class.
 *
 * @return the single pre-built codec
 */
public static TypeCodec<float[]> instance() {
  return SubtypeOnlyFloatVectorToArrayCodec.INSTANCE;
}

/**
 * Reports the Java type this codec handles.
 *
 * @return the {@code float[]} generic type assigned in the constructor
 */
@NonNull
@Override
public GenericType<float[]> getJavaType() {
  return this.javaType;
}

/**
 * Reports the CQL type this codec handles.
 *
 * @return the subtype-only vector type assigned in the constructor
 */
@NonNull
@Override
public DataType getCqlType() {
  return this.cqlType;
}

/**
 * Tells whether this codec handles the given Java class.
 *
 * @param javaClass the class to check
 * @return {@code true} only when {@code javaClass} is exactly {@code float[].class}
 */
@Override
public boolean accepts(@NonNull Class<?> javaClass) {
  // Class objects use identity equality, so == is equivalent to equals() here.
  return float[].class == javaClass;
}

/**
 * Tells whether this codec handles the given Java value.
 *
 * @param value the value to check
 * @return {@code true} when {@code value} is a {@code float[]}
 */
@Override
public boolean accepts(@NonNull Object value) {
  final boolean isFloatArray = value instanceof float[];
  return isFloatArray;
}

/**
 * Tells whether this codec handles the given CQL type. Only the element type is compared; the
 * dimension of the candidate vector type is not consulted.
 *
 * @param value the CQL type to check
 * @return {@code true} when {@code value} is a {@link VectorType} whose element type equals this
 *     codec's element type
 */
@Override
public boolean accepts(@NonNull DataType value) {
  if (value instanceof VectorType candidate) {
    return this.cqlType.getElementType().equals(candidate.getElementType());
  }
  return false;
}

/**
 * Serializes a float array into a buffer of consecutive elements, {@code ELEMENT_SIZE} bytes
 * each, by calling {@link #serializeElement} once per index in ascending order.
 *
 * @param array the values to encode; may be null
 * @param protocolVersion passed through to {@link #serializeElement}
 * @return a flipped buffer holding the encoded elements, or {@code null} when {@code array} is
 *     null
 */
@Nullable
@Override
public ByteBuffer encode(@Nullable float[] array, @NonNull ProtocolVersion protocolVersion) {
  if (array == null) {
    return null;
  }
  final int elementCount = array.length;
  final ByteBuffer buffer = ByteBuffer.allocate(elementCount * ELEMENT_SIZE);
  int position = 0;
  while (position < elementCount) {
    serializeElement(buffer, array, position, protocolVersion);
    position++;
  }
  // Switch the buffer from writing to reading before handing it back.
  buffer.flip();
  return buffer;
}

/**
 * Deserializes a buffer of consecutive float elements into a {@code float[]}.
 *
 * <p>The input is read through {@link ByteBuffer#duplicate()}, so the position of {@code bytes}
 * itself is not advanced. Each element is read by {@link #deserializeElement} in ascending index
 * order.
 *
 * <p>NOTE(review): although the parameter and return are annotated {@code @Nullable}, a null or
 * empty buffer is rejected with an exception rather than mapped to {@code null}; confirm that
 * callers never pass the bytes of a null column value.
 *
 * <p>NOTE(review): the element count is computed by dividing by the literal {@code 4} while the
 * length check uses {@code ELEMENT_SIZE}; the two are equal today but should share the constant.
 *
 * @param bytes the encoded vector; must be non-null with a non-zero number of remaining bytes
 * @param protocolVersion passed through to {@link #deserializeElement}
 * @return a new array with one entry per 4 remaining bytes of {@code bytes}
 * @throws IllegalArgumentException if {@code bytes} is null, has no remaining bytes, or its
 *     remaining byte count is not a multiple of {@code ELEMENT_SIZE}
 */
@Nullable
@Override
public float[] decode(@Nullable ByteBuffer bytes, @NonNull ProtocolVersion protocolVersion) {
if (bytes == null || bytes.remaining() == 0) {
throw new IllegalArgumentException(
"Input ByteBuffer must not be null and must have non-zero remaining bytes");
}
// TODO: Do we want to treat this as an error? We could also just ignore any extraneous bytes
// if they appear.
if (bytes.remaining() % ELEMENT_SIZE != 0) {
throw new IllegalArgumentException(
String.format("Input ByteBuffer should have a multiple of %d bytes", ELEMENT_SIZE));
Comment on lines +106 to +107
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure how common this exception is. If this is a common error, should we use Error V2 instead of IllegalArgumentException? Have the same question for other IllegalArgumentExceptions. Or all these IllegalArgumentExceptions will be caught eventually and converted to our own errors? I see there is caught in VectorCodec and convert it to ToCQLCodecException, but not sure if it covers all the cases.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This codec is used by Java CQL Driver, it's not a Data API codec. But I think code we use for java driver will have to catch exceptions anyway. Problem itself should not be possible to occur by any external calls; it is more an assertion (and as such not sure there is a way to easily test).

}
ByteBuffer input = bytes.duplicate();
int elementCount = input.remaining() / 4;
float[] array = new float[elementCount];
for (int i = 0; i < elementCount; i++) {
deserializeElement(input, array, i, protocolVersion);
}
return array;
}

/**
 * Appends one element of {@code array} to {@code output} as a float.
 *
 * @param output the buffer that receives the element
 * @param array the source values
 * @param index position in {@code array} of the element to write
 * @param protocolVersion the protocol version in use; not consulted by this implementation
 */
protected void serializeElement(
    @NonNull ByteBuffer output,
    @NonNull float[] array,
    int index,
    @NonNull ProtocolVersion protocolVersion) {
  final float element = array[index];
  output.putFloat(element);
}

/**
 * Reads the next float from {@code input} and stores it in {@code array} at {@code index}.
 *
 * @param input the buffer to read the element from
 * @param array the destination array
 * @param index position in {@code array} that receives the element
 * @param protocolVersion the protocol version in use; not consulted by this implementation
 */
protected void deserializeElement(
    @NonNull ByteBuffer input,
    @NonNull float[] array,
    int index,
    @NonNull ProtocolVersion protocolVersion) {
  final float element = input.getFloat();
  array[index] = element;
}

/**
 * Renders a float array as text.
 *
 * @param value the array to render; may be null
 * @return the literal {@code "NULL"} for a null array, otherwise the result of {@link
 *     Arrays#toString(float[])}
 */
@NonNull
@Override
public String format(@Nullable float[] value) {
  if (value == null) {
    return "NULL";
  }
  return Arrays.toString(value);
}

/**
 * Parses the textual form of a float array.
 *
 * <p>The first and last characters of {@code str} are dropped (presumably the enclosing brackets
 * produced by {@code format}), the remainder is split on {@code ", "} with each piece trimmed,
 * and every piece is parsed with {@code floatCodec}.
 *
 * <p>NOTE(review): {@code format} emits {@code "NULL"} for a null array, but this method has no
 * matching special case for that text, so the pair does not appear to round-trip null.
 *
 * <p>NOTE(review): an empty array formats as {@code "[]"}; the empty remainder after dropping the
 * brackets looks like it yields a single empty piece that is then rejected. Confirm against the
 * Splitter documentation. A one-character input would also reach {@code substring(1, 0)}.
 *
 * <p>NOTE(review): the in-body TODO about Java 11 appears stale; the review discussion attached
 * to that line states the project builds with Java 21.
 *
 * @param str the text to parse; must be non-null and non-empty
 * @return a new array with one entry per parsed piece
 * @throws IllegalArgumentException if {@code str} is null or empty, or if any piece is null or
 *     empty
 */
@Nullable
@Override
public float[] parse(@Nullable String str) {
/* TODO: Logic below requires a double traversal through the input String but there's no other obvious way to
* get the size. It's still probably worth the initial pass through in order to avoid having to deal with
* resizing ops. Fortunately we're only dealing with the format/parse pair here so this shouldn't impact
* general performance much. */
if ((str == null) || str.isEmpty()) {
throw new IllegalArgumentException("Cannot create float array from null or empty string");
}
Iterable<String> strIterable =
Splitter.on(", ").trimResults().split(str.substring(1, str.length() - 1));
float[] rv = new float[Iterators.size(strIterable.iterator())];
Iterator<String> strIterator = strIterable.iterator();
for (int i = 0; i < rv.length; ++i) {
String strVal = strIterator.next();
// TODO: String.isBlank() should be included here but it's only available with Java11+
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we are using Java 21, why don't we use String.isBlank() here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code was first written for Java CQL driver (by Brett). I can change that, thanks.

if (strVal == null || strVal.isEmpty()) {
throw new IllegalArgumentException("Null element observed in float array string");
}
rv[i] = floatCodec.parse(strVal).floatValue();
}
return rv;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package io.stargate.sgv2.jsonapi.service.cqldriver.executor.optvector;

import com.datastax.oss.driver.api.core.detach.AttachmentPoint;
import com.datastax.oss.driver.api.core.type.DataType;
import com.datastax.oss.driver.api.core.type.VectorType;
import com.datastax.oss.driver.internal.core.type.DefaultVectorType;
import java.util.Objects;

/**
 * A {@link VectorType} that is identified solely by the subtype (element type) of the vector.
 * Useful for describing a class of vector types that share a subtype but differ in dimension.
 */
public class SubtypeOnlyVectorType extends DefaultVectorType {
  /** Placeholder dimension passed to the superclass constructor. */
  private static final int NO_DIMENSION = -1;

  /**
   * Creates a vector type over the given element type, with a placeholder dimension.
   *
   * @param subtype element type of the vectors this type describes
   */
  public SubtypeOnlyVectorType(DataType subtype) {
    super(subtype, NO_DIMENSION);
  }

  /**
   * Not supported: this type deliberately carries no usable dimension.
   *
   * @throws UnsupportedOperationException always
   */
  @Override
  public int getDimensions() {
    throw new UnsupportedOperationException("Subtype-only vectors do not support dimensions");
  }

  /* ============== General class implementation ============== */

  /**
   * Compares by element type only: any {@link VectorType} with an equal element type is
   * considered equal, whatever its dimension.
   */
  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    if (!(o instanceof VectorType other)) {
      return false;
    }
    return other.getElementType().equals(getElementType());
  }

  /** Combines the superclass hash with the hash of the element type. */
  @Override
  public int hashCode() {
    int inherited = super.hashCode();
    int elementHash = Objects.hashCode(getElementType());
    return inherited ^ elementHash;
  }

  /** Renders this type with a marker showing that only the subtype is significant. */
  @Override
  public String toString() {
    DataType elementType = getElementType();
    return String.format("(Subtype-only) Vector(%s)", elementType);
  }

  /** Always reports this type as not detached. */
  @Override
  public boolean isDetached() {
    return false;
  }

  /** Attaching is a no-op for this type. */
  @Override
  public void attach(AttachmentPoint attachmentPoint) {
    // nothing to do
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,10 @@ public <JavaT, CqlT> JSONCodec<JavaT, CqlT> codecToCQL(
throw new ToCQLCodecException(value, columnType, "only Vector<Float> supported");
}
if (value instanceof Collection<?>) {
return VectorCodecs.arrayToCQLFloatVectorCodec(vt);
return VectorCodecs.arrayToCQLFloatArrayCodec(vt);
}
if (value instanceof EJSONWrapper) {
return VectorCodecs.binaryToCQLFloatVectorCodec(vt);
return VectorCodecs.binaryToCQLFloatArrayCodec(vt);
}

throw new ToCQLCodecException(value, columnType, "no codec matching value type");
Expand Down
Loading
Loading