Why does getOldState return the new state for @OneToMany and @ManyToMany?

Hello,

I have a Spring Boot project where I use Hibernate.
I am writing exploratory tests to understand the behaviour of lifecycle listeners and I encountered a behaviour I did not expect.

Specifically, I made an entity with one field per type (normal, @ManyToOne, @OneToMany, @ManyToMany), and verified that the state in the event is what I expected… but I found that getOldState() returns the new state for @OneToMany and @ManyToMany fields.
Am I doing something wrong or is this expected behaviour?

This is a complete reproducer:

package exploratory.hibernate.beforeafter

import jdlf.compass.common.ProfileResolver
import org.hibernate.event.service.spi.EventListenerRegistry
import org.hibernate.event.spi.EventType
import org.hibernate.event.spi.PostUpdateEvent
import org.hibernate.event.spi.PostUpdateEventListener
import org.hibernate.internal.SessionFactoryImpl
import org.hibernate.persister.entity.EntityPersister
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertNull
import org.junit.jupiter.api.Tag
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.TestInstance
import org.junit.jupiter.api.assertAll
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.boot.autoconfigure.EnableAutoConfiguration
import org.springframework.boot.autoconfigure.domain.EntityScan
import org.springframework.boot.test.context.SpringBootTest
import org.springframework.data.jpa.repository.JpaRepository
import org.springframework.data.jpa.repository.config.EnableJpaRepositories
import org.springframework.stereotype.Component
import org.springframework.stereotype.Repository
import org.springframework.test.context.ActiveProfiles
import org.springframework.test.context.TestPropertySource
import java.util.UUID
import java.util.UUID.randomUUID
import javax.annotation.PostConstruct
import javax.persistence.Entity
import javax.persistence.EntityManagerFactory
import javax.persistence.Id
import javax.persistence.ManyToMany
import javax.persistence.ManyToOne
import javax.persistence.OneToMany
import javax.persistence.PersistenceUnit


@Tag("exploratory-test")
@SpringBootTest(
    webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT,
    classes = [
        AccessingBeforeAndAfterStatesListener::class,
        MyEntityRepository::class,
        HibernateListenerConfigurer::class,
    ]
)
@EntityScan(basePackages = ["exploratory.hibernate.beforeafter"])
@EnableJpaRepositories(basePackages = ["exploratory.hibernate.beforeafter"])
@EnableAutoConfiguration
@TestPropertySource(properties = ["spring.jpa.hibernate.ddl-auto=create-drop", "spring.flyway.enabled=false"])
@ActiveProfiles(resolver = ProfileResolver::class)
@TestInstance(TestInstance.Lifecycle.PER_METHOD)
class AccessingBeforeAndAfterStatesListenerTest {
    /**
     * State captured by [AccessingBeforeAndAfterStatesListener] on the most recent post-update event.
     *
     * Elements are nullable because Hibernate's state arrays contain `null` for unset properties
     * (e.g. `manyToOne` before the update below).
     */
    companion object {
        var oldState: Array<Any?>? = null
        var newState: Array<Any?>? = null
        var propertyNames: Array<String>? = null
    }

    @Autowired
    lateinit var myEntityRepository: MyEntityRepository

    /**
     * Updates one property of each mapping kind and checks the old/new state reported by the
     * post-update event for each of them.
     */
    @Test
    fun testCanAccessStateBeforeAndAfter() {
        // Start from a clean slate so the "not cheating" check does not depend on earlier tests
        oldState = null
        newState = null
        propertyNames = null

        val myEntity1 = MyEntity()
        myEntity1.id = randomUUID()
        myEntity1.name = "Goofy"
        myEntityRepository.save(myEntity1)

        val myEntity2 = MyEntity()
        myEntity2.id = randomUUID()
        myEntity2.name = "Pluto"
        myEntityRepository.save(myEntity2)

        // Double check we are not cheating: the two inserts must not have triggered the listener
        assertAll(
            { assertNull(oldState) },
            { assertNull(newState) },
            { assertNull(propertyNames) },
        )

        // Double check saved state
        assertAll(
            { assertEquals("Goofy", myEntityRepository.findById(myEntity1.id).orElseThrow().name) },
            { assertNull(myEntityRepository.findById(myEntity1.id).orElseThrow().manyToOne) },
            { assertEquals(emptySet<MyEntity>(), myEntityRepository.findById(myEntity1.id).orElseThrow().oneToMany) },
            { assertEquals(emptySet<MyEntity>(), myEntityRepository.findById(myEntity1.id).orElseThrow().manyToMany) },
        )

        // Edit every field
        myEntity1.name = "Mickey"
        myEntity1.manyToOne = myEntity2
        myEntity1.oneToMany.add(myEntity2)
        myEntity1.manyToMany.add(myEntity2)
        myEntityRepository.save(myEntity1)

        // Double check saved state
        assertAll(
            { assertEquals("Mickey", myEntityRepository.findById(myEntity1.id).orElseThrow().name) },
            { assertEquals(myEntity2, myEntityRepository.findById(myEntity1.id).orElseThrow().manyToOne) },
            { assertEquals(setOf(myEntity2), myEntityRepository.findById(myEntity1.id).orElseThrow().oneToMany) },
            { assertEquals(setOf(myEntity2), myEntityRepository.findById(myEntity1.id).orElseThrow().manyToMany) },
        )

        // Get property indices (fails with a descriptive message if the listener never ran)
        val iName = indexOfProperty("name")
        val iManyToOne = indexOfProperty("manyToOne")
        val iOneToMany = indexOfProperty("oneToMany")
        val iManyToMany = indexOfProperty("manyToMany")

        // Assert we got the correct before and after states
        assertAll(
            { assertEquals(4, oldState?.size) },
            { assertEquals(4, newState?.size) },
            { assertEquals(4, propertyNames?.size) },

            { assertEquals("Goofy", oldState?.get(iName)) },
            { assertEquals("Mickey", newState?.get(iName)) },

            { assertNull(oldState?.get(iManyToOne)) },
            { assertEquals(myEntity2, newState?.get(iManyToOne)) },

            // For collections, oldState and state reference the same PersistentCollection instance,
            // so the "old" value already shows the new contents.
            // HERE: Expected to pass, actually fails
            // { assertEquals(emptySet<MyEntity>(), oldState?.get(iOneToMany)) },
            // HERE: Expected to fail, actually passes
            { assertEquals(setOf(myEntity2), oldState?.get(iOneToMany)) },
            { assertEquals(setOf(myEntity2), newState?.get(iOneToMany)) },

            // HERE: Expected to pass, actually fails
            // { assertEquals(emptySet<MyEntity>(), oldState?.get(iManyToMany)) },
            // HERE: Expected to fail, actually passes
            { assertEquals(setOf(myEntity2), oldState?.get(iManyToMany)) },
            { assertEquals(setOf(myEntity2), newState?.get(iManyToMany)) },
        )
    }

    /**
     * Returns the index of [name] in the captured [propertyNames].
     *
     * @throws IllegalStateException if the post-update listener never ran or did not report [name]
     */
    private fun indexOfProperty(name: String): Int {
        val names = checkNotNull(propertyNames) { "PostUpdateEventListener was not invoked" }
        val index = names.indexOf(name)
        check(index >= 0) { "Property '$name' not found in ${names.contentToString()}" }
        return index
    }
}

/**
 * Post-update listener that copies the old state, new state and property names of each
 * [PostUpdateEvent] into the companion object of [AccessingBeforeAndAfterStatesListenerTest],
 * so the test can inspect what Hibernate reported.
 */
@Component
class AccessingBeforeAndAfterStatesListener : PostUpdateEventListener {
    override fun onPostUpdate(event: PostUpdateEvent) {
        // The class name used as a value refers to the test's companion object
        with(AccessingBeforeAndAfterStatesListenerTest) {
            oldState = event.oldState
            newState = event.state
            propertyNames = event.persister.propertyNames
        }
    }

    /** Returns `false`: this listener does not ask for post-commit handling. */
    @Deprecated(
        "Use requiresPostCommitHandling instead",
        ReplaceWith("requiresPostCommitHandling")
    )
    override fun requiresPostCommitHanding(persister: EntityPersister): Boolean = false
}

/**
 * Self-referencing test entity with one property per mapping kind: a basic `name`,
 * a `@ManyToOne`, a `@OneToMany` and a `@ManyToMany`.
 *
 * Equality and hash code depend only on the runtime class and the assigned [id].
 */
@Entity
class MyEntity {
    @Id
    lateinit var id: UUID

    lateinit var name: String

    @ManyToOne
    var manyToOne: MyEntity? = null

    @OneToMany
    var oneToMany: MutableSet<MyEntity> = mutableSetOf()

    @ManyToMany
    var manyToMany: MutableSet<MyEntity> = mutableSetOf()

    override fun equals(other: Any?): Boolean = when {
        this === other -> true
        other == null || other.javaClass != javaClass -> false
        else -> id == (other as MyEntity).id
    }

    override fun hashCode(): Int = id.hashCode()
}

/** Spring Data repository for [MyEntity], keyed by its [UUID] id. */
@Repository
interface MyEntityRepository : JpaRepository<MyEntity, UUID> {}

/**
 * Registers [AccessingBeforeAndAfterStatesListener] for [EventType.POST_UPDATE] with Hibernate's
 * [EventListenerRegistry] once Spring has injected the dependencies.
 */
@Component
class HibernateListenerConfigurer {
    // lateinit var instead of `val ... = null`: Spring assigns these after construction, and a
    // missing injection fails with a named UninitializedPropertyAccessException instead of an NPE.
    @PersistenceUnit
    private lateinit var emf: EntityManagerFactory

    @Autowired
    private lateinit var listener: AccessingBeforeAndAfterStatesListener

    /** Appends the listener to the POST_UPDATE listener group of the underlying session factory. */
    @PostConstruct
    protected fun init() {
        val sessionFactory = emf.unwrap(SessionFactoryImpl::class.java)
        val registry = sessionFactory.serviceRegistry.getService(
            EventListenerRegistry::class.java
        )
        registry.getEventListenerGroup(EventType.POST_UPDATE).appendListener(listener)
    }
}

I don’t know Kotlin well enough to help you with that. Please rewrite this in Java if you want assistance.

Hi @beikov, no worries, sorry about that.

Please find the Java code below.

package exploratory.hibernate.beforeafter;

import jdlf.compass.common.ProfileResolver;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.boot.autoconfigure.domain.EntityScan;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.data.jpa.repository.config.EnableJpaRepositories;
import org.springframework.test.context.ActiveProfiles;
import org.springframework.test.context.TestPropertySource;

import java.util.HashSet;
import java.util.Set;
import java.util.UUID;


@Tag("exploratory-test")
@SpringBootTest(
        webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT,
        classes = {
                AccessingBeforeAndAfterStatesListener.class,
                MyEntityRepository.class,
                HibernateListenerConfigurer.class
        }
)
@EntityScan(basePackages = {"exploratory.hibernate.beforeafter"})
@EnableJpaRepositories(basePackages = {"exploratory.hibernate.beforeafter"})
@EnableAutoConfiguration
@TestPropertySource(properties = {"spring.jpa.hibernate.ddl-auto=create-drop", "spring.flyway.enabled=false"})
@ActiveProfiles(resolver = ProfileResolver.class)
@TestInstance(TestInstance.Lifecycle.PER_METHOD)
public class AccessingBeforeAndAfterStatesListenerTest {
    // Written by AccessingBeforeAndAfterStatesListener#onPostUpdate; null until a post-update event fires.
    public static Object[] oldState;
    public static Object[] newState;
    public static String[] propertyNames;

    @Autowired
    private MyEntityRepository myEntityRepository;

    /**
     * Updates one property of each mapping kind and checks the old/new state reported by the
     * post-update event for each of them.
     */
    @Test
    public void testCanAccessStateBeforeAndAfter() {
        // Start from a clean slate so the "not cheating" check does not depend on earlier tests
        oldState = null;
        newState = null;
        propertyNames = null;

        MyEntity myEntity1 = new MyEntity();
        myEntity1.setId(UUID.randomUUID());
        myEntity1.setName("Goofy");
        myEntityRepository.save(myEntity1);

        MyEntity myEntity2 = new MyEntity();
        myEntity2.setId(UUID.randomUUID());
        myEntity2.setName("Pluto");
        myEntityRepository.save(myEntity2);

        // Double check we are not cheating: the two inserts must not have triggered the listener
        Assertions.assertAll(
                () -> Assertions.assertNull(oldState),
                () -> Assertions.assertNull(newState),
                () -> Assertions.assertNull(propertyNames)
        );

        // Double check saved state
        Assertions.assertAll(
                () -> Assertions.assertEquals("Goofy", myEntityRepository.findById(myEntity1.getId()).orElseThrow().getName()),
                () -> Assertions.assertNull(myEntityRepository.findById(myEntity1.getId()).orElseThrow().getManyToOne()),
                () -> Assertions.assertEquals(new HashSet<MyEntity>(), myEntityRepository.findById(myEntity1.getId()).orElseThrow().getOneToMany()),
                () -> Assertions.assertEquals(new HashSet<MyEntity>(), myEntityRepository.findById(myEntity1.getId()).orElseThrow().getManyToMany())
        );

        // Edit every field
        myEntity1.setName("Mickey");
        myEntity1.setManyToOne(myEntity2);
        myEntity1.getOneToMany().add(myEntity2);
        myEntity1.getManyToMany().add(myEntity2);
        myEntityRepository.save(myEntity1);

        // Double check saved state
        Assertions.assertAll(
                () -> Assertions.assertEquals("Mickey", myEntityRepository.findById(myEntity1.getId()).orElseThrow().getName()),
                () -> Assertions.assertEquals(myEntity2, myEntityRepository.findById(myEntity1.getId()).orElseThrow().getManyToOne()),
                () -> Assertions.assertEquals(Set.of(myEntity2), myEntityRepository.findById(myEntity1.getId()).orElseThrow().getOneToMany()),
                () -> Assertions.assertEquals(Set.of(myEntity2), myEntityRepository.findById(myEntity1.getId()).orElseThrow().getManyToMany())
        );

        // Get property indices (fails with a descriptive message if the listener never ran)
        int iName = indexOf(propertyNames, "name");
        int iManyToOne = indexOf(propertyNames, "manyToOne");
        int iOneToMany = indexOf(propertyNames, "oneToMany");
        int iManyToMany = indexOf(propertyNames, "manyToMany");

        // Assert we got the correct before and after states
        Assertions.assertAll(
                () -> Assertions.assertEquals(4, oldState.length),
                () -> Assertions.assertEquals(4, newState.length),
                () -> Assertions.assertEquals(4, propertyNames.length),

                () -> Assertions.assertEquals("Goofy", oldState[iName]),
                () -> Assertions.assertEquals("Mickey", newState[iName]),

                () -> Assertions.assertNull(oldState[iManyToOne]),
                () -> Assertions.assertEquals(myEntity2, newState[iManyToOne]),

                // For collections, oldState and state reference the same PersistentCollection instance,
                // so the "old" value already shows the new contents.
                // HERE: Expected to pass, actually fails
                // () -> Assertions.assertEquals(new HashSet<MyEntity>(), oldState[iOneToMany]),
                // HERE: Expected to fail, actually passes
                () -> Assertions.assertEquals(Set.of(myEntity2), oldState[iOneToMany]),
                () -> Assertions.assertEquals(Set.of(myEntity2), newState[iOneToMany]),

                // HERE: Expected to pass, actually fails
                // () -> Assertions.assertEquals(new HashSet<MyEntity>(), oldState[iManyToMany]),
                // HERE: Expected to fail, actually passes
                () -> Assertions.assertEquals(Set.of(myEntity2), oldState[iManyToMany]),
                () -> Assertions.assertEquals(Set.of(myEntity2), newState[iManyToMany])
        );
    }

    /**
     * Returns the position of {@code value} in {@code array}.
     *
     * @param array the captured property names; {@code null} if the listener never ran
     * @param value the property name to look up
     * @return the index of {@code value} in {@code array}
     * @throws AssertionError if {@code array} is {@code null} or does not contain {@code value}
     */
    private int indexOf(String[] array, String value) {
        if (array == null) {
            throw new AssertionError("No property names captured: the post-update listener did not run");
        }
        for (int i = 0; i < array.length; i++) {
            if (value.equals(array[i])) {
                return i;
            }
        }
        throw new AssertionError("Property '" + value + "' not found among the captured property names");
    }
}
package exploratory.hibernate.beforeafter;

import org.hibernate.event.spi.PostUpdateEvent;
import org.hibernate.event.spi.PostUpdateEventListener;
import org.hibernate.persister.entity.EntityPersister;
import org.springframework.stereotype.Component;

/**
 * Post-update listener that copies the event's old state, new state and property names into the
 * static fields of {@link AccessingBeforeAndAfterStatesListenerTest} so the test can inspect them.
 */
@Component
class AccessingBeforeAndAfterStatesListener implements PostUpdateEventListener {
    @Override
    public void onPostUpdate(PostUpdateEvent event) {
        final Object[] before = event.getOldState();
        final Object[] after = event.getState();
        final EntityPersister persister = event.getPersister();

        AccessingBeforeAndAfterStatesListenerTest.oldState = before;
        AccessingBeforeAndAfterStatesListenerTest.newState = after;
        AccessingBeforeAndAfterStatesListenerTest.propertyNames = persister.getPropertyNames();
    }

    /** Returns {@code false}: this listener does not ask for post-commit handling. */
    @Override
    @Deprecated
    public boolean requiresPostCommitHanding(EntityPersister persister) {
        final boolean postCommitHandlingNeeded = false;
        return postCommitHandlingNeeded;
    }
}
package exploratory.hibernate.beforeafter;

import org.hibernate.event.service.spi.EventListenerRegistry;
import org.hibernate.event.spi.EventType;
import org.hibernate.internal.SessionFactoryImpl;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import javax.annotation.PostConstruct;
import javax.persistence.EntityManagerFactory;
import javax.persistence.PersistenceUnit;

/**
 * Registers {@link AccessingBeforeAndAfterStatesListener} for {@link EventType#POST_UPDATE} events
 * on the Hibernate session factory behind the injected {@link EntityManagerFactory}.
 */
@Component
class HibernateListenerConfigurer {
    @PersistenceUnit
    private EntityManagerFactory emf;

    @Autowired
    private AccessingBeforeAndAfterStatesListener listener;

    /** Appends the listener to the POST_UPDATE listener group once injection has completed. */
    @PostConstruct
    protected void init() {
        emf.unwrap(SessionFactoryImpl.class)
                .getServiceRegistry()
                .getService(EventListenerRegistry.class)
                .getEventListenerGroup(EventType.POST_UPDATE)
                .appendListener(listener);
    }
}
package exploratory.hibernate.beforeafter;

import javax.persistence.Entity;
import javax.persistence.Id;
import javax.persistence.ManyToMany;
import javax.persistence.ManyToOne;
import javax.persistence.OneToMany;
import java.util.HashSet;
import java.util.Set;
import java.util.UUID;

/**
 * Self-referencing test entity with one property per mapping kind: a basic {@code name},
 * a {@code @ManyToOne}, a {@code @OneToMany} and a {@code @ManyToMany}.
 */
@Entity
class MyEntity {
    @Id
    private UUID id;

    private String name;

    @ManyToOne
    private MyEntity manyToOne;

    @OneToMany
    private Set<MyEntity> oneToMany = new HashSet<>();

    @ManyToMany
    private Set<MyEntity> manyToMany = new HashSet<>();

    public UUID getId() {
        return id;
    }

    public void setId(UUID id) {
        this.id = id;
    }

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public MyEntity getManyToOne() {
        return manyToOne;
    }

    public void setManyToOne(MyEntity manyToOne) {
        this.manyToOne = manyToOne;
    }

    public Set<MyEntity> getOneToMany() {
        return oneToMany;
    }

    public void setOneToMany(Set<MyEntity> oneToMany) {
        this.oneToMany = oneToMany;
    }

    public Set<MyEntity> getManyToMany() {
        return manyToMany;
    }

    public void setManyToMany(Set<MyEntity> manyToMany) {
        this.manyToMany = manyToMany;
    }

    /**
     * Two instances are equal when they have the same runtime class and the same non-null id.
     * An instance whose id has not been assigned yet is only equal to itself (instead of throwing
     * a {@link NullPointerException}, which {@code equals} must never do for a non-null argument).
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        MyEntity myEntity = (MyEntity) o;
        return id != null && id.equals(myEntity.id);
    }

    /** Hash of the id, or {@code 0} while the id is still unassigned. */
    @Override
    public int hashCode() {
        return id == null ? 0 : id.hashCode();
    }
}
package exploratory.hibernate.beforeafter;

import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.stereotype.Repository;

import java.util.UUID;

/** Spring Data repository for {@link MyEntity}, keyed by its {@link UUID} id. */
@Repository
interface MyEntityRepository extends JpaRepository<MyEntity, UUID> {}

So am I understanding this right, that the problem you have is that oldState for collections refers to the same PersistentCollection instance which is also referred to in the newState array, which is non-empty?

Yes, that’s correct. Is this expected behaviour? Is it possible to get both the old and new states for one-to-many and many-to-many collections?

You will have to diff that yourself by comparing the collection with its snapshot via PersistentCollection#getStoredSnapshot(). I’m not 100% sure why the snapshot is not exposed directly in the oldState, but it has been this way for a long time. I guess one could argue that this is correct, though, since the collection is managed and does its own dirty tracking, and the collection field in the entity itself didn’t change.

Thanks!

I’ll give you my unrequested take on this:

As someone that comes into Hibernate from tutorials (as opposed to reading TFM, which I use as a reference, but not as study material), this is unexpected behaviour. From my uninformed point of view, there are entities and there are the fields of entities. The idea that collections are managed separately from entities, while making sense in hindsight and likely being quite the neat design choice, is not the first “mental model” one comes to - at least it didn’t come to me.

With all that said, Hibernate does a lot of heavy lifting and I am not even sure if it’s possible to hide this complexity in a “less surprising” API. I am not sure what I could suggest, but perhaps it could be worth spelling out in the Javadoc of PostUpdateEventListener what it does and does not do. I suggest this because, in my workflow, the first level of documentation I read is Javadocs, as it’s really easy to do from within IntelliJ. Currently, the Javadoc for PostUpdateEventListener is a one-liner:

Called after updating the datastore

An AI assistant proposed the following Javadoc based on my tests:

/**
 * Listener interface for handling Hibernate post-update events.
 * <p>
 * Implementations of this interface are notified after an entity has been updated in the database,
 * allowing for custom logic or side effects based on the entity's updated state. These events are
 * fired after a flush occurs but before the transaction is committed.
 *
 * <h2>Key Features</h2>
 * <ul>
 *   <li>Access to both old and new entity states</li>
 *   <li>Ability to perform custom logic post-update</li>
 *   <li>Executes within the same transaction as the update</li>
 * </ul>
 *
 * <h2>Usage</h2>
 * <ol>
 *   <li>Implement this interface in your custom listener class</li>
 *   <li>Register your listener with Hibernate's {@code EventListenerRegistry}</li>
 * </ol>
 *
 * <h2>Important Considerations</h2>
 * <ul>
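 *   <li><strong>Collection Changes:</strong> For collection-valued properties, {@code getOldState()} and
 *       {@code getState()} reference the same {@code PersistentCollection}, so both show the new contents.
 *       Use {@link PostCollectionUpdateEventListener} together with
 *       {@code PersistentCollection#getStoredSnapshot()} to compare a collection with its previous contents.</li>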
 *   <li><strong>Transient Fields:</strong> Changes to transient or unmapped fields do not trigger this event.</li>
 *   <li><strong>Performance:</strong> Complex logic may impact performance, especially with large datasets.</li>
 *   <li><strong>Thread Safety:</strong> Implementations must be thread-safe.</li>
 * </ul>
 *
 * <h2>Best Practices</h2>
 * <ul>
 *   <li>Keep {@code onPostUpdate} logic lightweight to minimize performance impact</li>
 *   <li>Avoid modifying the entity or persisting new entities within the listener</li>
 *   <li>Use try-catch blocks for graceful exception handling</li>
 *   <li>Consider asynchronous processing for heavy computations or I/O operations</li>
 * </ul>
 *
 * @see PostCollectionUpdateEventListener
 * @see org.hibernate.event.spi.PostUpdateEvent
 */

In case other people land here, this is how I edited the tests to show what is and isn’t possible with PostUpdateEventListener. Feel free to include these tests (edited or not) in Hibernate’s examples, tutorials, docs, or whatever you deem sensible.

Kotlin, original code first, followed by AI-converted Java code (I didn’t check it, but last time it did a good job)

package exploratory.hibernate.beforeafter

import jdlf.compass.common.ProfileResolver
import org.hibernate.event.service.spi.EventListenerRegistry
import org.hibernate.event.spi.EventType
import org.hibernate.event.spi.PostCollectionUpdateEvent
import org.hibernate.event.spi.PostCollectionUpdateEventListener
import org.hibernate.event.spi.PostUpdateEvent
import org.hibernate.event.spi.PostUpdateEventListener
import org.hibernate.internal.SessionFactoryImpl
import org.hibernate.persister.entity.EntityPersister
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertNull
import org.junit.jupiter.api.Tag
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.TestInstance
import org.junit.jupiter.api.assertAll
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.boot.autoconfigure.EnableAutoConfiguration
import org.springframework.boot.autoconfigure.domain.EntityScan
import org.springframework.boot.test.context.SpringBootTest
import org.springframework.data.jpa.repository.JpaRepository
import org.springframework.data.jpa.repository.config.EnableJpaRepositories
import org.springframework.stereotype.Component
import org.springframework.stereotype.Repository
import org.springframework.test.context.ActiveProfiles
import org.springframework.test.context.TestPropertySource
import java.io.Serializable
import java.util.UUID
import java.util.UUID.randomUUID
import javax.annotation.PostConstruct
import javax.persistence.Entity
import javax.persistence.EntityManagerFactory
import javax.persistence.Id
import javax.persistence.ManyToMany
import javax.persistence.ManyToOne
import javax.persistence.OneToMany
import javax.persistence.PersistenceUnit


@Tag("exploratory-test")
@SpringBootTest(
    webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT,
    classes = [
        AccessingBeforeAndAfterStatesListener::class,
        AccessingBeforeAndAfterStatesCollectionListener::class,
        MyEntityRepository::class,
        HibernateListenerConfigurer::class,
    ]
)
@EntityScan(basePackages = ["exploratory.hibernate.beforeafter"])
@EnableJpaRepositories(basePackages = ["exploratory.hibernate.beforeafter"])
@EnableAutoConfiguration
@TestPropertySource(properties = ["spring.jpa.hibernate.ddl-auto=create-drop", "spring.flyway.enabled=false"])
@ActiveProfiles(resolver = ProfileResolver::class)
@TestInstance(TestInstance.Lifecycle.PER_METHOD)
class AccessingBeforeAndAfterStatesListenerTest {
    /**
     * State captured by the listeners.
     *
     * [AccessingBeforeAndAfterStatesListener] replaces these lists on every post-update event, while
     * [AccessingBeforeAndAfterStatesCollectionListener] appends one entry per updated collection.
     * Elements are nullable because Hibernate's state arrays contain `null` for unset properties.
     */
    companion object {
        var oldState: MutableList<Any?> = mutableListOf()
        var newState: MutableList<Any?> = mutableListOf()
        var propertyNames: MutableList<String> = mutableListOf()
    }

    @Autowired
    lateinit var myEntityRepository: MyEntityRepository

    /** Non-collection properties: the entity post-update event exposes distinct old and new values. */
    @Test
    fun testCanAccessStateBeforeAndAfter() {
        val myEntity1 = MyEntity()
        myEntity1.id = randomUUID()
        myEntity1.name = "Goofy"
        myEntityRepository.save(myEntity1)

        val myEntity2 = MyEntity()
        myEntity2.id = randomUUID()
        myEntity2.name = "Pluto"
        myEntityRepository.save(myEntity2)

        clearCapturedState()

        // Double check saved state
        assertAll(
            { assertEquals("Goofy", myEntityRepository.findById(myEntity1.id).orElseThrow().name) },
            { assertNull(myEntityRepository.findById(myEntity1.id).orElseThrow().manyToOne) },
            { assertEquals(emptySet<MyEntity>(), myEntityRepository.findById(myEntity1.id).orElseThrow().oneToMany) },
            { assertEquals(emptySet<MyEntity>(), myEntityRepository.findById(myEntity1.id).orElseThrow().manyToMany) },
        )

        // Edit every non-collection field
        myEntity1.name = "Mickey"
        myEntity1.manyToOne = myEntity2
        myEntityRepository.save(myEntity1)

        // Double check saved state
        assertAll(
            { assertEquals("Mickey", myEntityRepository.findById(myEntity1.id).orElseThrow().name) },
            { assertEquals(myEntity2, myEntityRepository.findById(myEntity1.id).orElseThrow().manyToOne) },
        )

        // Get property indices (fails with a descriptive message if nothing was captured)
        val iName = indexOfProperty("name")
        val iManyToOne = indexOfProperty("manyToOne")

        // Assert we got the correct before and after states
        assertAll(
            { assertEquals(4, oldState.size) },
            { assertEquals(4, newState.size) },
            { assertEquals(4, propertyNames.size) },

            { assertEquals("Goofy", oldState[iName]) },
            { assertEquals("Mickey", newState[iName]) },

            { assertNull(oldState[iManyToOne]) },
            { assertEquals(myEntity2, newState[iManyToOne]) }
        )
    }

    /** Collection properties: old contents are rebuilt from the collection's stored snapshot. */
    @Test
    fun testCanAccessCollectionStateBeforeAndAfter() {
        val myEntity1 = MyEntity()
        myEntity1.id = randomUUID()
        myEntity1.name = "Goofy"
        myEntityRepository.save(myEntity1)

        val myEntity2 = MyEntity()
        myEntity2.id = randomUUID()
        myEntity2.name = "Pluto"
        myEntityRepository.save(myEntity2)

        clearCapturedState()

        // Double check saved state
        assertAll(
            { assertEquals(emptySet<MyEntity>(), myEntityRepository.findById(myEntity1.id).orElseThrow().oneToMany) },
            { assertEquals(emptySet<MyEntity>(), myEntityRepository.findById(myEntity1.id).orElseThrow().manyToMany) },
        )

        // Edit every collection field
        myEntity1.oneToMany.add(myEntity2)
        myEntity1.manyToMany.add(myEntity2)
        myEntityRepository.save(myEntity1)

        // Double check saved state
        assertAll(
            { assertEquals(setOf(myEntity2), myEntityRepository.findById(myEntity1.id).orElseThrow().oneToMany) },
            { assertEquals(setOf(myEntity2), myEntityRepository.findById(myEntity1.id).orElseThrow().manyToMany) },
        )

        // Get property indices (fails with a descriptive message if nothing was captured)
        val iOneToMany = indexOfProperty("oneToMany")
        val iManyToMany = indexOfProperty("manyToMany")

        // Assert we got the correct before and after states
        assertAll(
            { assertEquals(emptySet<MyEntity>(), oldState[iOneToMany]) },
            { assertEquals(setOf(myEntity2), newState[iOneToMany]) },

            { assertEquals(emptySet<MyEntity>(), oldState[iManyToMany]) },
            { assertEquals(setOf(myEntity2), newState[iManyToMany]) },
        )
    }

    /** Empties the shared capture lists so only events from the code under test are inspected. */
    private fun clearCapturedState() {
        oldState.clear()
        newState.clear()
        propertyNames.clear()
    }

    /**
     * Returns the index of [name] in the captured [propertyNames].
     *
     * @throws IllegalStateException if no listener reported [name]
     */
    private fun indexOfProperty(name: String): Int {
        val index = propertyNames.indexOf(name)
        check(index >= 0) { "Property '$name' was not captured by any listener; got $propertyNames" }
        return index
    }
}

/**
 * Post-update listener that copies the old state, new state and property names of each
 * [PostUpdateEvent] into fresh lists on [AccessingBeforeAndAfterStatesListenerTest]'s companion
 * object, replacing whatever was captured before.
 */
@Component
class AccessingBeforeAndAfterStatesListener : PostUpdateEventListener {
    override fun onPostUpdate(event: PostUpdateEvent) {
        val before = event.oldState.toMutableList()
        val after = event.state.toMutableList()
        val names = event.persister.propertyNames.toMutableList()

        AccessingBeforeAndAfterStatesListenerTest.oldState = before
        AccessingBeforeAndAfterStatesListenerTest.newState = after
        AccessingBeforeAndAfterStatesListenerTest.propertyNames = names
    }

    /** Returns `false`: this listener does not ask for post-commit handling. */
    @Deprecated(
        "Use requiresPostCommitHandling instead",
        ReplaceWith("requiresPostCommitHandling")
    )
    override fun requiresPostCommitHanding(persister: EntityPersister): Boolean = false
}

/**
 * Post-collection-update listener that appends, for every updated collection, its current
 * contents, a rebuilt copy of its previous contents and the owning property name to the shared
 * lists on [AccessingBeforeAndAfterStatesListenerTest].
 */
@Component
class AccessingBeforeAndAfterStatesCollectionListener : PostCollectionUpdateEventListener {
    override fun onPostUpdateCollection(event: PostCollectionUpdateEvent) {
        AccessingBeforeAndAfterStatesListenerTest.newState.add(event.collection)
        AccessingBeforeAndAfterStatesListenerTest.oldState.add(getOldState(event))
        AccessingBeforeAndAfterStatesListenerTest.propertyNames.add(getFieldName(event))
    }

    /** Returns the text after the last '.' of the collection role (the whole role if it has none). */
    private fun getFieldName(event: PostCollectionUpdateEvent): String =
        event.collection.role.substringAfterLast('.')

    /**
     * Rebuilds the collection's previous contents from its stored snapshot into a new container
     * of the kind Hibernate uses for this collection role.
     */
    private fun getOldState(event: PostCollectionUpdateEvent): Any {
        val role: String = event.collection.role
        val persister = event.session.factory.metamodel.collectionPersister(role)
        // CollectionType.returnedClass is the declared interface (e.g. java.util.Set), which has no
        // constructor to call reflectively; CollectionType.instantiate creates a concrete container.
        val container: Any = persister.collectionType.instantiate(-1)
            ?: error("Hibernate did not instantiate a container for collection role $role")
        return fillFromSnapshot(container, event.collection.storedSnapshot)
    }

    /**
     * Copies [snapshot] into [container] and returns [container]. A `null` snapshot yields an
     * empty container.
     */
    @Suppress("UNCHECKED_CAST")
    private fun fillFromSnapshot(container: Any, snapshot: Serializable?): Any {
        if (snapshot == null) {
            return container
        }
        when (container) {
            is Map<*, *> -> (container as MutableMap<Any?, Any?>).putAll(snapshot as Map<Any?, Any?>)
            is Collection<*> -> {
                // A Set-typed collection can have a Map snapshot. NOTE(review): assumes Hibernate 5.x
                // layouts where the map values are the element copies (PersistentSet maps each copy to
                // itself, PersistentIdentifierBag maps id -> element) — confirm on upgrade.
                val elements: Collection<*> =
                    if (snapshot is Map<*, *>) snapshot.values else snapshot as Collection<*>
                (container as MutableCollection<Any?>).addAll(elements)
            }
            else -> error("Unsupported collection container type: ${container.javaClass.name}")
        }
        return container
    }
}

/**
 * Self-referencing test entity with one property per mapping kind: a basic `name`,
 * a `@ManyToOne`, a `@OneToMany` and a `@ManyToMany`.
 *
 * Equality and hash code depend only on the runtime class and the assigned [id].
 */
@Entity
class MyEntity {
    @Id
    lateinit var id: UUID

    lateinit var name: String

    @ManyToOne
    var manyToOne: MyEntity? = null

    @OneToMany
    var oneToMany: MutableSet<MyEntity> = mutableSetOf()

    @ManyToMany
    var manyToMany: MutableSet<MyEntity> = mutableSetOf()

    override fun equals(other: Any?): Boolean {
        if (this === other) {
            return true
        }
        val sameClass = other != null && other.javaClass == javaClass
        return sameClass && id == (other as MyEntity).id
    }

    override fun hashCode(): Int = id.hashCode()
}

/** Spring Data repository for [MyEntity], keyed by its [UUID] id. */
@Repository
interface MyEntityRepository : JpaRepository<MyEntity, UUID> {}

/**
 * Registers the recording listeners with Hibernate's [EventListenerRegistry] once Spring has
 * injected the persistence unit and the listener beans:
 * - the entity listener for POST_UPDATE;
 * - the collection listener for POST_COLLECTION_UPDATE.
 */
@Component
class HibernateListenerConfigurer {
    // `lateinit var` rather than `private val … = null`: a Kotlin `val` compiles to a final field,
    // and field injection should not be writing into final fields. This also removes the `!!`s.
    @PersistenceUnit
    private lateinit var emf: EntityManagerFactory

    @Autowired
    private lateinit var listener: AccessingBeforeAndAfterStatesListener

    @Autowired
    private lateinit var collectionListener: AccessingBeforeAndAfterStatesCollectionListener

    @PostConstruct
    protected fun init() {
        val sessionFactory = emf.unwrap(SessionFactoryImpl::class.java)
        val registry = sessionFactory.serviceRegistry.getService(
            EventListenerRegistry::class.java
        )
        registry.getEventListenerGroup(EventType.POST_UPDATE).appendListener(listener)
        registry.getEventListenerGroup(EventType.POST_COLLECTION_UPDATE).appendListener(collectionListener)
    }
}
package exploratory.hibernate.beforeafter;

import jdlf.compass.common.ProfileResolver;
import org.hibernate.event.service.spi.EventListenerRegistry;
import org.hibernate.event.spi.*;
import org.hibernate.internal.SessionFactoryImpl;
import org.hibernate.persister.entity.EntityPersister;
import org.junit.jupiter.api.*;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.boot.autoconfigure.domain.EntityScan;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.config.EnableJpaRepositories;
import org.springframework.stereotype.Component;
import org.springframework.stereotype.Repository;
import org.springframework.test.context.ActiveProfiles;
import org.springframework.test.context.TestPropertySource;

import javax.annotation.PostConstruct;
import javax.persistence.*;
import java.io.Serializable;
import java.util.*;

/**
 * Exploratory test showing which "before" and "after" states Hibernate hands to post-update
 * listeners for plain, {@code @ManyToOne}, {@code @OneToMany} and {@code @ManyToMany} fields.
 *
 * <p>The listeners in this file record what they observe into the static lists below. The lists
 * are package-private rather than {@code private} because those listeners are separate top-level
 * classes, and a {@code private} member would not be accessible to them.
 */
@Tag("exploratory-test")
@SpringBootTest(
    webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT,
    classes = {
        AccessingBeforeAndAfterStatesListener.class,
        AccessingBeforeAndAfterStatesCollectionListener.class,
        MyEntityRepository.class,
        HibernateListenerConfigurer.class
    }
)
@EntityScan(basePackages = {"exploratory.hibernate.beforeafter"})
@EnableJpaRepositories(basePackages = {"exploratory.hibernate.beforeafter"})
@EnableAutoConfiguration
@TestPropertySource(properties = {"spring.jpa.hibernate.ddl-auto=create-drop", "spring.flyway.enabled=false"})
@ActiveProfiles(resolver = ProfileResolver.class)
@TestInstance(TestInstance.Lifecycle.PER_METHOD)
class AccessingBeforeAndAfterStatesListenerTest {
    // Written by the listeners during flush. The entity listener reassigns them (rather than
    // mutating them), so they cannot be final.
    static List<Object> oldState = new ArrayList<>();
    static List<Object> newState = new ArrayList<>();
    static List<String> propertyNames = new ArrayList<>();

    @Autowired
    private MyEntityRepository myEntityRepository;

    @Test
    void testCanAccessStateBeforeAndAfter() {
        MyEntity myEntity1 = saveNewEntity("Goofy");
        MyEntity myEntity2 = saveNewEntity("Pluto");

        clearRecordedState();

        // Double check saved state
        Assertions.assertAll(
            () -> Assertions.assertEquals("Goofy", reload(myEntity1).getName()),
            () -> Assertions.assertNull(reload(myEntity1).getManyToOne()),
            () -> Assertions.assertEquals(Collections.emptySet(), reload(myEntity1).getOneToMany()),
            () -> Assertions.assertEquals(Collections.emptySet(), reload(myEntity1).getManyToMany())
        );

        // Edit the basic and @ManyToOne fields (collections are covered by the other test)
        myEntity1.setName("Mickey");
        myEntity1.setManyToOne(myEntity2);
        myEntityRepository.save(myEntity1);

        // Double check saved state
        Assertions.assertAll(
            () -> Assertions.assertEquals("Mickey", reload(myEntity1).getName()),
            () -> Assertions.assertEquals(myEntity2, reload(myEntity1).getManyToOne())
        );

        // Get property indices
        int iName = propertyNames.indexOf("name");
        int iManyToOne = propertyNames.indexOf("manyToOne");

        // Assert we got the correct before and after states
        Assertions.assertAll(
            () -> Assertions.assertEquals(4, oldState.size()),
            () -> Assertions.assertEquals(4, newState.size()),
            () -> Assertions.assertEquals(4, propertyNames.size()),

            () -> Assertions.assertEquals("Goofy", oldState.get(iName)),
            () -> Assertions.assertEquals("Mickey", newState.get(iName)),

            () -> Assertions.assertNull(oldState.get(iManyToOne)),
            () -> Assertions.assertEquals(myEntity2, newState.get(iManyToOne))
        );
    }

    @Test
    void testCanAccessCollectionStateBeforeAndAfter() {
        MyEntity myEntity1 = saveNewEntity("Goofy");
        MyEntity myEntity2 = saveNewEntity("Pluto");

        clearRecordedState();

        // Double check saved state
        Assertions.assertAll(
            () -> Assertions.assertEquals(Collections.emptySet(), reload(myEntity1).getOneToMany()),
            () -> Assertions.assertEquals(Collections.emptySet(), reload(myEntity1).getManyToMany())
        );

        // Edit both collection fields
        myEntity1.getOneToMany().add(myEntity2);
        myEntity1.getManyToMany().add(myEntity2);
        myEntityRepository.save(myEntity1);

        // Double check saved state
        Assertions.assertAll(
            () -> Assertions.assertEquals(Collections.singleton(myEntity2), reload(myEntity1).getOneToMany()),
            () -> Assertions.assertEquals(Collections.singleton(myEntity2), reload(myEntity1).getManyToMany())
        );

        // Get property indices
        int iOneToMany = propertyNames.indexOf("oneToMany");
        int iManyToMany = propertyNames.indexOf("manyToMany");

        // Assert we got the correct before and after states.
        // NOTE: the old-state assertions are the subject of this reproducer. The collection listener
        // reads the stored snapshot in a POST_COLLECTION_UPDATE event, where it already holds the
        // updated contents, so those assertions observe the new state.
        Assertions.assertAll(
            () -> Assertions.assertEquals(Collections.emptySet(), oldState.get(iOneToMany)),
            () -> Assertions.assertEquals(Collections.singleton(myEntity2), newState.get(iOneToMany)),

            () -> Assertions.assertEquals(Collections.emptySet(), oldState.get(iManyToMany)),
            () -> Assertions.assertEquals(Collections.singleton(myEntity2), newState.get(iManyToMany))
        );
    }

    /** Creates an entity with a random id and the given name, saves it, and returns the instance passed to save. */
    private MyEntity saveNewEntity(String name) {
        MyEntity entity = new MyEntity();
        entity.setId(UUID.randomUUID());
        entity.setName(name);
        myEntityRepository.save(entity);
        return entity;
    }

    /** Re-reads the given entity from the repository by id, failing if it no longer exists. */
    private MyEntity reload(MyEntity entity) {
        return myEntityRepository.findById(entity.getId()).orElseThrow();
    }

    /** Discards everything the listeners recorded so far (e.g. during the initial inserts). */
    private static void clearRecordedState() {
        oldState.clear();
        newState.clear();
        propertyNames.clear();
    }
}

/**
 * On every {@link PostUpdateEvent}, replaces the static lists of
 * {@link AccessingBeforeAndAfterStatesListenerTest} with copies of the event's old state, new
 * state and the persister's property names. All three are index-aligned by property.
 */
@Component
class AccessingBeforeAndAfterStatesListener implements PostUpdateEventListener {
    @Override
    public void onPostUpdate(PostUpdateEvent event) {
        Object[] previousState = event.getOldState();
        // Hibernate may report no old state (null). Arrays.asList(null) would throw an NPE in the
        // middle of the flush, so an empty list is recorded instead; the test then fails on its
        // size assertion rather than inside Hibernate.
        AccessingBeforeAndAfterStatesListenerTest.oldState = previousState == null
            ? new ArrayList<>()
            : new ArrayList<>(Arrays.asList(previousState));
        AccessingBeforeAndAfterStatesListenerTest.newState = new ArrayList<>(Arrays.asList(event.getState()));
        AccessingBeforeAndAfterStatesListenerTest.propertyNames = new ArrayList<>(Arrays.asList(event.getPersister().getPropertyNames()));
    }

    @Override
    @Deprecated
    public boolean requiresPostCommitHanding(EntityPersister persister) {
        return false;
    }
}

/**
 * Records what Hibernate exposes for every {@link PostCollectionUpdateEvent} into the static lists
 * of {@link AccessingBeforeAndAfterStatesListenerTest}:
 * <ul>
 *   <li>the collection wrapper itself, as the "new" state;</li>
 *   <li>a copy of its stored snapshot, as the "old" state;</li>
 *   <li>the owning property name.</li>
 * </ul>
 *
 * <p><b>Pitfall:</b> this is why the reproducer sees the new state where it expects the old one.
 * By the time POST_COLLECTION_UPDATE fires, Hibernate has already executed the collection update
 * and re-taken the collection's snapshot, so {@code getStoredSnapshot()} returns the
 * <em>updated</em> contents.
 *
 * <p>To observe the contents from before the flush, read the stored snapshot in a
 * {@link PreCollectionUpdateEventListener} instead.
 */
@Component
class AccessingBeforeAndAfterStatesCollectionListener implements PostCollectionUpdateEventListener {
    @Override
    public void onPostUpdateCollection(PostCollectionUpdateEvent event) {
        AccessingBeforeAndAfterStatesListenerTest.newState.add(event.getCollection());
        AccessingBeforeAndAfterStatesListenerTest.oldState.add(getOldState(event));
        AccessingBeforeAndAfterStatesListenerTest.propertyNames.add(getFieldName(event));
    }

    /** Returns the last dot-separated segment of the collection role, i.e. the owning property name. */
    private String getFieldName(PostCollectionUpdateEvent event) {
        String role = event.getCollection().getRole();
        return role.substring(role.lastIndexOf('.') + 1);
    }

    /**
     * Returns a copy of the collection's stored snapshot, shaped like the mapped collection type.
     * In a post-update event this is already the updated state (see the class comment).
     */
    private Object getOldState(PostCollectionUpdateEvent event) {
        Serializable storedSnapshot = event.getCollection().getStoredSnapshot();
        Class<?> collectionClass = getCollectionClass(event);

        return instantiateSnapshotAsCollectionClass(collectionClass, storedSnapshot);
    }

    /**
     * Converts a raw Hibernate snapshot into a plain collection comparable with the mapped property.
     *
     * @param collectionClass the class Hibernate reports for the mapped collection. For the built-in
     *     collection types this is usually an interface such as {@code Set} or {@code List}, which
     *     cannot be instantiated reflectively.
     * @param snapshot the stored snapshot; may be {@code null}
     * @return {@code null} if there is no snapshot; the snapshot itself for map-valued properties or
     *     unrecognised snapshot shapes; otherwise a new collection holding the snapshot's elements
     * @throws IllegalStateException if a concrete collection class's copy constructor itself fails
     */
    private Object instantiateSnapshotAsCollectionClass(Class<?> collectionClass, Serializable snapshot) {
        if (snapshot == null) {
            return null;
        }
        if (Map.class.isAssignableFrom(collectionClass)) {
            return snapshot;
        }

        Collection<?> elements;
        if (snapshot instanceof Map) {
            // Set (and id-bag) snapshots are stored as maps. For sets, keys and values are the same
            // element copies; for id-bags, the keys are row identifiers and the values are the elements.
            elements = ((Map<?, ?>) snapshot).values();
        } else if (snapshot instanceof Collection) {
            elements = (Collection<?>) snapshot;
        } else {
            // e.g. array snapshots: nothing sensible to convert, so expose them as-is
            return snapshot;
        }

        // Only a concrete class can be instantiated reflectively. Interfaces such as Set/List would
        // make getConstructor throw NoSuchMethodException, so they get a standard copy instead.
        if (!collectionClass.isInterface()) {
            try {
                return collectionClass.getConstructor(Collection.class).newInstance(elements);
            } catch (NoSuchMethodException | InstantiationException ignored) {
                // No public (Collection) constructor, or an abstract class: fall back to a standard copy.
            } catch (ReflectiveOperationException e) {
                throw new IllegalStateException("Cannot copy snapshot into " + collectionClass.getName(), e);
            }
        }
        if (Set.class.isAssignableFrom(collectionClass)) {
            return new LinkedHashSet<>(elements);
        }
        return new ArrayList<>(elements);
    }

    /**
     * Returns the class the collection persister for this event's role reports as the mapped type.
     * This may be a {@code Map} or array class, hence {@code Class<?>} rather than a Collection class.
     */
    private Class<?> getCollectionClass(PostCollectionUpdateEvent event) {
        String role = event.getCollection().getRole();
        org.hibernate.engine.spi.SessionFactoryImplementor factory = event.getSession().getFactory();
        org.hibernate.metamodel.spi.MetamodelImplementor metamodel = factory.getMetamodel();
        org.hibernate.persister.collection.CollectionPersister persister = metamodel.collectionPersister(role);

        return persister.getCollectionType().getReturnedClass();
    }
}

/**
 * Self-referencing JPA entity with one field of each relationship kind (basic, {@code @ManyToOne},
 * {@code @OneToMany}, {@code @ManyToMany}). It is used to observe what Hibernate's update listeners
 * report for each kind.
 *
 * <p>Equality and hash code depend only on {@code id}.
 */
@Entity
class MyEntity {
    @Id
    private UUID id;

    private String name;

    @ManyToOne
    private MyEntity manyToOne;

    @OneToMany
    private Set<MyEntity> oneToMany = new HashSet<>();

    @ManyToMany
    private Set<MyEntity> manyToMany = new HashSet<>();

    // Accessors used by the test (the original only had a "Getters and setters" placeholder).

    public UUID getId() {
        return id;
    }

    public void setId(UUID id) {
        this.id = id;
    }

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public MyEntity getManyToOne() {
        return manyToOne;
    }

    public void setManyToOne(MyEntity manyToOne) {
        this.manyToOne = manyToOne;
    }

    public Set<MyEntity> getOneToMany() {
        return oneToMany;
    }

    public void setOneToMany(Set<MyEntity> oneToMany) {
        this.oneToMany = oneToMany;
    }

    public Set<MyEntity> getManyToMany() {
        return manyToMany;
    }

    public void setManyToMany(Set<MyEntity> manyToMany) {
        this.manyToMany = manyToMany;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        MyEntity myEntity = (MyEntity) o;
        return Objects.equals(id, myEntity.id);
    }

    @Override
    public int hashCode() {
        return Objects.hash(id);
    }
}

/** Spring Data repository for {@link MyEntity}, keyed by its {@code UUID} identifier. */
@Repository
interface MyEntityRepository extends JpaRepository<MyEntity, UUID> {
}

/**
 * Registers the two recording listeners with Hibernate's event listener registry once the bean
 * has been initialised:
 * <ul>
 *   <li>the entity listener for POST_UPDATE;</li>
 *   <li>the collection listener for POST_COLLECTION_UPDATE.</li>
 * </ul>
 */
@Component
class HibernateListenerConfigurer {
    @PersistenceUnit
    private EntityManagerFactory emf;

    @Autowired
    private AccessingBeforeAndAfterStatesListener listener;

    @Autowired
    private AccessingBeforeAndAfterStatesCollectionListener collectionListener;

    @PostConstruct
    protected void init() {
        EventListenerRegistry eventRegistry = emf.unwrap(SessionFactoryImpl.class)
            .getServiceRegistry()
            .getService(EventListenerRegistry.class);

        eventRegistry.getEventListenerGroup(EventType.POST_UPDATE)
            .appendListener(listener);
        eventRegistry.getEventListenerGroup(EventType.POST_COLLECTION_UPDATE)
            .appendListener(collectionListener);
    }
}

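I agree that the Javadoc should mention this potential pitfall. By the time a `PostCollectionUpdateEvent` is fired, Hibernate has already refreshed the collection's stored snapshot, so `getStoredSnapshot()` returns the new state rather than the old one. To observe the previous contents, read the snapshot from a `PreCollectionUpdateEventListener` instead. Please create a PR with a note about this.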