Hi everyone,
I am posting my issue here, but please let me know if this is not the appropriate place to do so.
I am trying to write a Criteria query that orders the results by vector distance.
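For context, the distance in question is the Euclidean (L2) distance between two vectors of equal length, and ordering by it in ascending order returns the closest embeddings first. The minimal plain-Java sketch below illustrates only those semantics. It assumes the database function behaves like the textbook definition sqrt(sum((a[i] - b[i])^2)), it does not use Hibernate at all, and the names `VectorDistanceSketch`, `l2Distance` and `orderByDistance` are hypothetical.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

public class VectorDistanceSketch {

    // Euclidean (L2) distance: sqrt(sum over i of (a[i] - b[i])^2).
    // Hypothetical guard: vectors of different dimensions have no L2 distance.
    static double l2Distance(float[] a, float[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException(
                    "Dimension mismatch: " + a.length + " vs " + b.length);
        }
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = (double) a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    // In-memory analogue of an ascending "order by <distance>" clause:
    // returns a sorted copy, nearest embedding first.
    static List<float[]> orderByDistance(List<float[]> embeddings, float[] query) {
        List<float[]> sorted = new ArrayList<>(embeddings);
        sorted.sort(Comparator.comparingDouble(e -> l2Distance(e, query)));
        return sorted;
    }

    public static void main(String[] args) {
        float[] query = {0f, 0f};
        List<float[]> rows = List.of(
                new float[] {3f, 4f},  // distance 5
                new float[] {1f, 0f},  // distance 1
                new float[] {0f, 2f}); // distance 2
        for (float[] row : orderByDistance(rows, query)) {
            System.out.println(l2Distance(row, query));
        }
    }
}
```

Running `main` prints the three distances in ascending order, mirroring what an ascending `order by` on the distance would return for these rows.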
My entity field holding the vector (an embedding produced by an LLM) is defined as follows:
```java
@JdbcTypeCode(SqlTypes.VECTOR)
@Array(length = 1024)
private float[] embedding;
```
Using the Criteria API, I am trying to write the equivalent of the following JPQL query, which works perfectly:
```java
@Query("select data from Data data order by l2_distance(data.embedding, cast(:queryEmbedding as vector))")
Page<Data> findByEmbedding(float[] queryEmbedding, Pageable pageable);
```
However, this is proving challenging, and the documentation does not seem to include an example for this case.
If I try the following:
```java
var distanceExpr = cb.function(
        "l2_distance",
        Double.class,
        root.get("embedding"),
        cb.literal(embeddingQuery) // embeddingQuery is a float[]
);
criteriaQuery.orderBy(cb.asc(distanceExpr));
```
Hibernate's argument validation rejects it with the following exception:
```
Caused by: org.hibernate.query.sqm.produce.function.FunctionArgumentException: Parameter 1 of function 'euclidean_distance()' requires a vector type, but argument is of type 'float[]'
    at org.hibernate.vector.VectorArgumentValidator.validate(VectorArgumentValidator.java:38)
    at org.hibernate.query.sqm.produce.function.StandardArgumentsValidators$8.lambda$validate$0(StandardArgumentsValidators.java:254)
    at java.base/java.util.Arrays$ArrayList.forEach(Arrays.java:4305)
    at org.hibernate.query.sqm.produce.function.StandardArgumentsValidators$8.validate(StandardArgumentsValidators.java:254)
    at org.hibernate.query.sqm.function.AbstractSqmFunctionDescriptor.generateSqmExpression(AbstractSqmFunctionDescriptor.java:102)
    at org.hibernate.query.sqm.internal.SqmCriteriaNodeBuilder.function(SqmCriteriaNodeBuilder.java:1747)
    at org.hibernate.query.sqm.internal.SqmCriteriaNodeBuilder.function(SqmCriteriaNodeBuilder.java:192)
```
I have tried several ways of casting the second argument to the appropriate type, but none of them seem to work. How can an argument of vector type be passed to the function?
Note: I am using Hibernate 6.6.13.Final