How to use the Criteria API to build a query with a vector operation such as Euclidean distance

Hi everyone,

I am posting my issue here, but please let me know if this is not the appropriate place to do so.

I am trying to write a Criteria query that orders the results by vector distance.

My entity field holding the vector (an embedding produced by an LLM) is defined as follows:

    @JdbcTypeCode(SqlTypes.VECTOR)
    @Array(length = 1024)
    private float[] embedding;

Using the Criteria API, I am trying to write the equivalent of the following JPQL query (which works perfectly):

    @Query("select data from Data data order by l2_distance(data.embedding, cast(:queryEmbedding as vector))")
    Page<Data> findByEmbedding(float[] queryEmbedding, Pageable pageable);
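For reference, `order by l2_distance(a, b)` sorts rows by the Euclidean (L2) distance between two vectors, i.e. the square root of the sum of squared component differences. The plain-Java sketch below reproduces that ordering in memory on a few tiny vectors, purely to illustrate what the query computes; the class and method names (`L2Ordering`, `l2Distance`, `orderByL2Distance`) are hypothetical, and in a real application the ordering should of course stay in the database.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

public class L2Ordering {

    /** Euclidean (L2) distance: square root of the sum of squared component differences. */
    static double l2Distance(float[] a, float[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException(
                    "vector dimensions differ: " + a.length + " vs " + b.length);
        }
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = (double) a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    /** Returns a copy of the candidates sorted by ascending distance to the query (nearest first). */
    static List<float[]> orderByL2Distance(List<float[]> candidates, float[] query) {
        List<float[]> sorted = new ArrayList<>(candidates);
        sorted.sort(Comparator.comparingDouble(v -> l2Distance(v, query)));
        return sorted;
    }

    public static void main(String[] args) {
        float[] query = {0f, 0f};
        List<float[]> rows = List.of(new float[]{3f, 4f}, new float[]{1f, 0f}, new float[]{0f, 2f});
        // Prints each vector with its distance to the query, nearest first
        for (float[] v : orderByL2Distance(rows, query)) {
            System.out.println(v[0] + ", " + v[1] + " -> " + l2Distance(v, query));
        }
    }
}
```

The comparator recomputes the distance on every comparison, which is fine for a toy example; the database-side `order by` avoids materializing all rows in the first place.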

but this is proving challenging, and the documentation doesn’t seem to include an example for it.

If I try the following:

    var distanceExpr = cb.function(
        "l2_distance",
        Double.class,
        root.get("embedding"),
        cb.literal(embeddingQuery) // embeddingQuery is a float[]
    );

    criteriaQuery.orderBy(cb.asc(distanceExpr));

Hibernate's argument validation fails with the following exception:

Caused by: org.hibernate.query.sqm.produce.function.FunctionArgumentException: Parameter 1 of function 'euclidean_distance()' requires a vector type, but argument is of type 'float[]'
	at org.hibernate.vector.VectorArgumentValidator.validate(VectorArgumentValidator.java:38)
	at org.hibernate.query.sqm.produce.function.StandardArgumentsValidators$8.lambda$validate$0(StandardArgumentsValidators.java:254)
	at java.base/java.util.Arrays$ArrayList.forEach(Arrays.java:4305)
	at org.hibernate.query.sqm.produce.function.StandardArgumentsValidators$8.validate(StandardArgumentsValidators.java:254)
	at org.hibernate.query.sqm.function.AbstractSqmFunctionDescriptor.generateSqmExpression(AbstractSqmFunctionDescriptor.java:102)
	at org.hibernate.query.sqm.internal.SqmCriteriaNodeBuilder.function(SqmCriteriaNodeBuilder.java:1747)
	at org.hibernate.query.sqm.internal.SqmCriteriaNodeBuilder.function(SqmCriteriaNodeBuilder.java:192)

I tried several ways of casting the second argument to an appropriate type, but none of them seem to work. How should an argument of vector type be passed to the function?

Note: I am using Hibernate 6.6.13.Final

This could be a bug or simply a limitation, though I’d have to dig deeper to say for sure. Try using a named parameter instead of a literal: parameters are typed according to type inference rules, whereas literals might not be, though they probably should be in this case.
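A minimal, untested sketch of that suggestion, assuming Jakarta Persistence 3.x with Hibernate 6.6 and the `hibernate-vector` module on the classpath. The class and method names (`VectorQueries`, `findNearest`) are hypothetical; the method is generic so it does not need the `Data` entity from the question, but it assumes the entity has the `embedding` attribute shown above. The only real change from the failing code is that `cb.literal(...)` is replaced by `cb.parameter(...)`, bound later with `setParameter`. Whether the vector argument validator accepts this in 6.6.13 is exactly what the reply says still needs checking.

```java
import jakarta.persistence.EntityManager;
import jakarta.persistence.TypedQuery;
import jakarta.persistence.criteria.CriteriaBuilder;
import jakarta.persistence.criteria.CriteriaQuery;
import jakarta.persistence.criteria.Expression;
import jakarta.persistence.criteria.ParameterExpression;
import jakarta.persistence.criteria.Root;
import java.util.List;

public class VectorQueries {

    /**
     * Hypothetical helper: returns up to maxResults entities ordered by
     * l2_distance between their "embedding" attribute and queryEmbedding.
     */
    public static <T> List<T> findNearest(EntityManager em, Class<T> entityClass,
                                          float[] queryEmbedding, int maxResults) {
        CriteriaBuilder cb = em.getCriteriaBuilder();
        CriteriaQuery<T> cq = cb.createQuery(entityClass);
        Root<T> root = cq.from(entityClass);

        // Named parameter instead of cb.literal(...), as suggested in the reply;
        // the value is bound on the TypedQuery below.
        ParameterExpression<float[]> queryParam =
                cb.parameter(float[].class, "queryEmbedding");

        Expression<Double> distance = cb.function(
                "l2_distance", Double.class, root.get("embedding"), queryParam);

        cq.select(root).orderBy(cb.asc(distance));

        TypedQuery<T> query = em.createQuery(cq);
        query.setParameter("queryEmbedding", queryEmbedding);
        query.setMaxResults(maxResults); // rough stand-in for the Pageable limit
        return query.getResultList();
    }
}
```

If `cb.function(...)` still throws the same `FunctionArgumentException` with a parameter, that would point toward the bug/limitation explanation rather than a usage error.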

Thanks for the suggestion; I will give it a try.