How to make Hibernate Proxies implement a certain marker interface

Hi again,

I’ve had a problem with the Autofetch migration project for a while now. It came up after we started trying to enforce lazy loading for all entities, when we realized that something is wrong with the proxies and their ability to be tracked. I want to put a marker interface on proxies so that my program can perform certain operations whenever the current proxy is an instance of this marker interface. The problem is that I can only get it to work for collection associations, not for single-valued associations. Collections work because we inject our own custom collection types that implement this interface. For some reason, it also works as expected when the association holds the real object rather than a proxy.

I’ve tried injecting the interface into the LazyInitializer (which acts as the proxy’s handler), but that did not help. Contrary to what I initially thought, setting the interfaces field in the LazyInitializer does not make the proxy actually implement the interface.

So basically, what I need is to make sure that our proxies implement this “TrackableEntity” marker interface. Note that the code checks whether the entity is an instance of Trackable, but Trackable is a superinterface of TrackableEntity, so implementing TrackableEntity also means implementing Trackable. Does anyone know how I can get my proxies to implement this interface as well?
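
This matches how proxy generation works in general: a proxy class implements exactly the interfaces it was generated with (in getProxy further down, the ones passed to factory.setInterfaces), and a field on the handler has no influence on that. The sketch below reproduces the effect with JDK dynamic proxies (java.lang.reflect.Proxy) standing in for Javassist; Trackable, TrackableEntity and Customer are minimal hypothetical stand-ins, not the project's real types.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;

public class ProxyInterfacesDemo {

    // Hypothetical stand-ins for the project's types.
    public interface Trackable {
        void enableTracking();
    }

    public interface TrackableEntity extends Trackable {
    }

    public interface Customer {
        String getName();
    }

    // Handler that merely stores an interface list, like the initializer's "interfaces" field.
    static final class RecordingHandler implements InvocationHandler {
        private final Class<?>[] interfaces; // kept, but never used to build the proxy class

        RecordingHandler(Class<?>[] interfaces) {
            this.interfaces = interfaces;
        }

        @Override
        public Object invoke(Object proxy, Method method, Object[] args) {
            return null; // dispatch is irrelevant here: instanceof never calls the handler
        }
    }

    // The proxy's interfaces are fixed here, when the proxy class is generated.
    public static Object createProxy(Class<?>[] proxyInterfaces, Class<?>[] handlerInterfaces) {
        return Proxy.newProxyInstance(
                ProxyInterfacesDemo.class.getClassLoader(),
                proxyInterfaces,
                new RecordingHandler(handlerInterfaces));
    }

    public static void main(String[] args) {
        Class<?>[] withMarker = {Customer.class, TrackableEntity.class};

        Object markerOnlyInHandler = createProxy(new Class<?>[]{Customer.class}, withMarker);
        Object markerAtCreation = createProxy(withMarker, withMarker);

        System.out.println(markerOnlyInHandler instanceof Trackable); // false
        System.out.println(markerAtCreation instanceof Trackable);    // true
    }
}
```

Because TrackableEntity extends Trackable, the second proxy also passes an instanceof Trackable check, which is the check the tracking code relies on.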

I added a picture showing where I need propVal to be an instance of Trackable, but it isn’t. If I do the same thing with eager loading, I get the real object instead, and execution enters this if-clause.

public class AutofetchLazyInitializer extends BasicLazyInitializer implements MethodHandler {

    private static final CoreMessageLogger LOG = CoreLogging.messageLogger(AutofetchLazyInitializer.class);

    private final EntityTracker entityTracker;

    private final Class[] interfaces;

    private boolean entityTrackersSet;

    private boolean constructed;

    private AutofetchLazyInitializer(String entityName,
                                     Class persistentClass,
                                     Class[] interfaces,
                                     Serializable id,
                                     Method getIdentifierMethod,
                                     Method setIdentifierMethod,
                                     CompositeType componentIdType,
                                     SessionImplementor session,
                                     Set<Property> persistentProperties,
                                     boolean classOverridesEquals) {

        super(entityName, persistentClass, id, getIdentifierMethod, setIdentifierMethod,
                componentIdType, session, classOverridesEquals);

        this.interfaces = interfaces;

        AutofetchService autofetchService = session.getFactory().getServiceRegistry().getService(AutofetchService.class);
        this.entityTracker = new EntityTracker(persistentProperties, autofetchService.getExtentManager());
        this.entityTrackersSet = false;
    }

    @Override
    public Object invoke(final Object proxy, final Method thisMethod, final Method proceed, final Object[] args) throws Throwable {
        if (this.constructed) {
            Object result;
            try {
                result = this.invoke(thisMethod, args, proxy);
            } catch (Throwable t) {
                throw new Exception(t.getCause());
            }

            if (result == INVOKE_IMPLEMENTATION) {
                if (args.length == 0) {
                    switch (thisMethod.getName()) {
                        case "enableTracking":
                            return handleEnableTracking();
                        case "disableTracking":
                            return handleDisableTracking();
                        case "isAccessed":
                            return entityTracker.isAccessed();
                    }
                } else if (args.length == 1) {
                    if (thisMethod.getName().equals("addTracker") && thisMethod.getParameterTypes()[0].equals(Statistics.class)) {
                        return handleAddTracked(args[0]);
                    } else if (thisMethod.getName().equals("addTrackers") && thisMethod.getParameterTypes()[0].equals(Set.class)) {
                        return handleAddTrackers(args[0]);
                    } else if (thisMethod.getName().equals("removeTracker") && thisMethod.getParameterTypes()[0].equals(Statistics.class)) {
                        entityTracker.removeTracker((Statistics) args[0]);
                        return handleRemoveTracker(args);
                    } else if (thisMethod.getName().equals("extendProfile") && thisMethod.getParameterTypes()[0].equals(Statistics.class)) {
                        return extendProfile(args);
                    }
                }

                final Object target = getImplementation();
                final Object returnValue;

                try {
                    if (ReflectHelper.isPublic(persistentClass, thisMethod)) {
                        if (!thisMethod.getDeclaringClass().isInstance(target)) {
                            throw new ClassCastException(
                                    target.getClass().getName() + " incompatible with " + thisMethod.getDeclaringClass().getName()
                            );
                        }
                    } else {
                        thisMethod.setAccessible(true);
                    }

                    returnValue = thisMethod.invoke(target, args);
                    if (returnValue == target) {
                        if (returnValue.getClass().isInstance(proxy)) {
                            return proxy;
                        } else {
                            LOG.narrowingProxy(returnValue.getClass());
                        }
                    }

                    return returnValue;
                } catch (InvocationTargetException ite) {
                    throw ite.getTargetException();
                } finally {
                    if (!entityTrackersSet && target instanceof Trackable) {
                        entityTrackersSet = true;
                        Trackable entity = (Trackable) target;
                        entity.addTrackers(entityTracker.getTrackers());
                        if (entityTracker.isTracking()) {
                            entity.enableTracking();
                        } else {
                            entity.disableTracking();
                        }
                    }
                }
            } else {
                return result;
            }
        } else {
            // while constructor is running
            if (thisMethod.getName().equals("getHibernateLazyInitializer")) {
                return this;
            } else {
                return proceed.invoke(proxy, args);
            }
        }
    }

    private Object handleDisableTracking() {
        boolean oldValue = entityTracker.isTracking();
        this.entityTracker.setTracking(false);
        if (!isUninitialized()) {
            Object o = getImplementation();
            if (o instanceof Trackable) {
                Trackable entity = (Trackable) o;
                entity.disableTracking();
            }
        }

        return oldValue;
    }

    private Object handleEnableTracking() {
        boolean oldValue = this.entityTracker.isTracking();
        this.entityTracker.setTracking(true);

        if (!isUninitialized()) {
            Object o = getImplementation();
            if (o instanceof Trackable) {
                Trackable entity = (Trackable) o;
                entity.enableTracking();
            }
        }

        return oldValue;
    }

    private Object extendProfile(Object[] params) {
        if (!isUninitialized()) {
            Object o = getImplementation();
            if (o instanceof TrackableEntity) {
                TrackableEntity entity = (TrackableEntity) o;
                entity.extendProfile((Statistics) params[0]);
            }
        } else {
            throw new IllegalStateException("Can't call extendProfile on unloaded self.");
        }

        return null;
    }

    private Object handleRemoveTracker(Object[] params) {
        if (!isUninitialized()) {
            Object o = getImplementation();
            if (o instanceof Trackable) {
                Trackable entity = (Trackable) o;
                entity.removeTracker((Statistics) params[0]);
            }
        }
        return null;
    }

    @SuppressWarnings("unchecked")
    private Object handleAddTrackers(Object param) {
        Set<Statistics> newTrackers = (Set<Statistics>) param;
        this.entityTracker.addTrackers(newTrackers);
        if (!isUninitialized()) {
            Object o = getImplementation();
            if (o instanceof Trackable) {
                Trackable entity = (Trackable) o;
                entity.addTrackers(newTrackers);
            }
        }

        return null;
    }

    private Object handleAddTracked(Object param) {
        this.entityTracker.addTracker((Statistics) param);
        if (!isUninitialized()) {
            Object o = getImplementation();
            if (o instanceof Trackable) {
                Trackable entity = (Trackable) o;
                entity.addTracker((Statistics) param);
            }
        }

        return null;
    }

    @Override
    protected Object serializableProxy() {
        return new AutofetchSerializableProxy(
                getEntityName(),
                this.persistentClass,
                this.interfaces,
                getIdentifier(),
                (isReadOnlySettingAvailable() ? Boolean.valueOf(isReadOnly()) : isReadOnlyBeforeAttachedToSession()),
                this.getIdentifierMethod,
                this.setIdentifierMethod,
                this.componentIdType,
                this.entityTracker.getPersistentProperties()
        );
    }

    public static HibernateProxy getProxy(
            final String entityName,
            final Class persistentClass,
            final Class[] interfaces,
            final Method getIdentifierMethod,
            final Method setIdentifierMethod,
            final CompositeType componentIdType,
            final Serializable id,
            final SessionImplementor session,
            final Set<Property> persistentProperties) throws HibernateException {

        // note: interface is assumed to already contain HibernateProxy.class
        try {
            final AutofetchLazyInitializer instance = new AutofetchLazyInitializer(
                    entityName,
                    persistentClass,
                    interfaces,
                    id,
                    getIdentifierMethod,
                    setIdentifierMethod,
                    componentIdType,
                    session,
                    persistentProperties,
                    ReflectHelper.overridesEquals(persistentClass)
            );

            final ProxyFactory factory = new ProxyFactory();
            factory.setSuperclass(interfaces.length == 1 ? persistentClass : null);
            factory.setInterfaces(interfaces);
            factory.setFilter(FINALIZE_FILTER);
            Class cl = factory.createClass();
            final HibernateProxy proxy = (HibernateProxy) cl.newInstance();
            ((Proxy) proxy).setHandler(instance);
            instance.constructed = true;
            return proxy;
        } catch (Throwable t) {
            LOG.error(LOG.javassistEnhancementFailed(entityName), t);
            throw new HibernateException(LOG.javassistEnhancementFailed(entityName), t);
        }
    }

    public static HibernateProxy getProxy(
            final Class factory,
            final String entityName,
            final Class persistentClass,
            final Class[] interfaces,
            final Method getIdentifierMethod,
            final Method setIdentifierMethod,
            final CompositeType componentIdType,
            final Serializable id,
            final SessionImplementor session,
            final Set<Property> persistentProperties) throws HibernateException {

        // note: interfaces is assumed to already contain HibernateProxy.class
        final AutofetchLazyInitializer instance = new AutofetchLazyInitializer(
                entityName,
                persistentClass,
                interfaces,
                id,
                getIdentifierMethod,
                setIdentifierMethod,
                componentIdType,
                session,
                persistentProperties,
                ReflectHelper.overridesEquals(persistentClass)
        );

        final HibernateProxy proxy;
        try {
            proxy = (HibernateProxy) factory.newInstance();
        } catch (Exception e) {
            throw new HibernateException("Javassist Enhancement failed: " + persistentClass.getName(), e);
        }

        ((Proxy) proxy).setHandler(instance);
        instance.constructed = true;

        return proxy;
    }

    private static final MethodFilter FINALIZE_FILTER = new MethodFilter() {

        @Override
        public boolean isHandled(Method m) {
            // skip finalize methods
            return !(m.getParameterTypes().length == 0 && m.getName().equals("finalize"));
        }
    };
}
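
In the first getProxy overload above, what the proxy implements is decided by factory.setInterfaces(interfaces), so one option to experiment with is appending the marker to that array. The helper below is a minimal sketch: the name withMarker and the stand-in TrackableEntity are hypothetical, and it has not been tried against Hibernate or Javassist. Two caveats follow from the code above. First, factory.setSuperclass(...) picks persistentClass only when interfaces.length == 1, so that test would have to be evaluated on the original array before the marker is appended. Second, the second overload receives an already generated class, so the same change would be needed wherever that class is built, which the post does not show.

```java
import java.util.Arrays;

public class MarkerInterfaces {

    // Hypothetical stand-in for the project's TrackableEntity marker interface.
    public interface TrackableEntity {
    }

    /**
     * Returns a new array containing {@code interfaces} followed by {@code marker}.
     * If the marker is already present, an unchanged copy is returned, so the same
     * interface is never listed twice (JDK dynamic proxies, for example, reject duplicates).
     */
    public static Class<?>[] withMarker(Class<?>[] interfaces, Class<?> marker) {
        if (!marker.isInterface()) {
            throw new IllegalArgumentException(marker.getName() + " is not an interface");
        }
        if (Arrays.asList(interfaces).contains(marker)) {
            return interfaces.clone();
        }
        // Copy so the caller's array is never modified.
        Class<?>[] result = Arrays.copyOf(interfaces, interfaces.length + 1);
        result[interfaces.length] = marker;
        return result;
    }

    public static void main(String[] args) {
        Class<?>[] original = {Runnable.class};
        Class<?>[] extended = withMarker(original, TrackableEntity.class);
        System.out.println(extended.length);                                     // 2
        System.out.println(withMarker(extended, TrackableEntity.class).length);  // 2
    }
}
```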

public class EntityProxyFactory {

    private static final CoreMessageLogger LOG = CoreLogging.messageLogger(EntityProxyFactory.class);

    private static final MethodFilter FINALIZE_FILTER = new MethodFilter() {
        @Override
        public boolean isHandled(Method m) {
            // skip finalize methods
            return !(m.getParameterTypes().length == 0 && m.getName().equals("finalize"));
        }
    };

    private static final ConcurrentMap<Class<?>, Class<?>> entityFactoryMap = new ConcurrentHashMap<>();

    private static final ConcurrentMap<Class<?>, Constructor<?>> entityConstructorMap = new ConcurrentHashMap<>();

    private static Class<?> getProxyFactory(Class<?> persistentClass, String idMethodName) {
        // computeIfAbsent generates the proxy class at most once per entity class,
        // avoiding the check-then-act race of containsKey/putIfAbsent/get
        return entityFactoryMap.computeIfAbsent(persistentClass, clazz -> {
            ProxyFactory factory = new ProxyFactory();
            factory.setSuperclass(clazz);
            factory.setInterfaces(new Class[]{TrackableEntity.class});
            factory.setFilter(FINALIZE_FILTER);
            return factory.createClass();
        });
    }

    private static <T> Constructor<T> getDefaultConstructor(Class<T> clazz) throws NoSuchMethodException {
        Constructor<T> constructor = clazz.getDeclaredConstructor();
        if (!constructor.isAccessible()) {
            constructor.setAccessible(true);
        }

        return constructor;
    }

    public static Object getProxyInstance(Class persistentClass, String idMethodName, Set<Property> persistentProperties,
                                          ExtentManager extentManager)
            throws InstantiationException, IllegalAccessException, NoSuchMethodException, InvocationTargetException {

        if (Modifier.isFinal(persistentClass.getModifiers())) {
            // Use the default constructor, because final classes cannot be inherited.
            return useDefaultConstructor(persistentClass);
        }

        Class<?> factory = getProxyFactory(persistentClass, idMethodName);
        try {
            final Object proxy = factory.newInstance();
            ((Proxy) proxy).setHandler(new EntityProxyMethodHandler(persistentProperties, extentManager));
            return proxy;
        } catch (IllegalAccessException | InstantiationException e) {
            return useDefaultConstructor(persistentClass);
        }
    }

    private static Object useDefaultConstructor(Class<?> clazz) throws NoSuchMethodException, InstantiationException,
            InvocationTargetException, IllegalAccessException {

        if (!entityConstructorMap.containsKey(clazz)) {
            entityConstructorMap.put(clazz, getDefaultConstructor(clazz));
        }

        final Constructor<?> c = entityConstructorMap.get(clazz);

        return c.newInstance((Object[]) null);
    }
}


public interface TrackableEntity extends Trackable {
    
    void extendProfile(Statistics tracker);

}

[Image: proxyproblem, showing where propVal needs to be an instance of Trackable]

The only way to do that is to use bytecode enhancement. Otherwise, to-one associations can be either POJOs or proxies, depending on whether you fetch them eagerly or lazily.

Therefore, you could mandate that clients use bytecode enhancement with the AutoFetch library. You could also provide a custom bytecode enhancement mechanism based on the Hibernate one and offer more tracking options if you want.
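
If clients are required to use bytecode enhancement, AutoFetch could fail fast when an entity class was not enhanced. A minimal sketch, assuming (this is not stated in the thread) that classes processed by Hibernate's enhancer implement org.hibernate.engine.spi.ManagedEntity; the class and method names are hypothetical. The interface is looked up by name only so that the sketch compiles without Hibernate; inside AutoFetch, a direct ManagedEntity.class.isAssignableFrom(entityClass) check would be simpler.

```java
public class EnhancementCheck {

    // Assumption: enhanced entity classes implement this interface.
    // Verify the name against the Hibernate version in use.
    static final String MANAGED_ENTITY = "org.hibernate.engine.spi.ManagedEntity";

    /** True if {@code type} is assignable to the interface with the given fully qualified name. */
    public static boolean isAssignableToNamed(Class<?> type, String interfaceName) {
        try {
            // Resolve through the type's own class loader (null means the bootstrap loader).
            Class<?> iface = Class.forName(interfaceName, false, type.getClassLoader());
            return iface.isAssignableFrom(type);
        } catch (ClassNotFoundException e) {
            return false; // the interface is not visible, so the type cannot implement it
        }
    }

    /** Throws if the entity class does not look bytecode-enhanced. */
    public static void requireEnhanced(Class<?> entityClass) {
        if (!isAssignableToNamed(entityClass, MANAGED_ENTITY)) {
            throw new IllegalStateException(
                    entityClass.getName() + " does not appear to be bytecode-enhanced");
        }
    }
}
```

Such a check could run once per mapped entity class at startup, so a missing enhancement step surfaces immediately rather than as proxies that silently skip tracking.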

Yes, I am using bytecode enhancement (I have the plugin in my pom.xml file), but how does that help with the proxies not implementing this interface? If I enforce lazy loading on all types of associations, we will always get proxies instead of POJOs, and right now the problem is with the proxies.

Bytecode enhancement allows you to change the class signature any way you want. For your use case, that might be a great advantage.

That sounds great, and it is exactly what I want. Is there somewhere I can read more about it? I have been looking around, but I’ve mostly found instructions on how to install it and how to set the four enhancement capabilities available in the Maven plugin. A tutorial or something similar would certainly help, since I am not sure how to use it properly.

I’m not sure there is any such documentation, but you can look at the bytecode enhancement package in hibernate-core and at the enhance tool in the hibernate-orm GitHub repository.

Thanks again for all the help. I will try to share the solution once I figure it out.
