diff --git a/core/src/main/java/dev/morphia/mapping/codec/DecodeSession.java b/core/src/main/java/dev/morphia/mapping/codec/DecodeSession.java new file mode 100644 index 00000000000..9a23f680f2c --- /dev/null +++ b/core/src/main/java/dev/morphia/mapping/codec/DecodeSession.java @@ -0,0 +1,89 @@ +package dev.morphia.mapping.codec; + +import java.util.HashMap; +import java.util.Map; + +import com.mongodb.lang.Nullable; + +import dev.morphia.annotations.internal.MorphiaInternal; + +/** + * Per-document-decode cache that maps (collection, id) → entity instance. + * Activated via {@link DecodeSession#activate()} before each document decode + * and cleared via {@link DecodeSession#deactivate()} afterwards. + * + * @hidden + * @morphia.internal + */ +@MorphiaInternal +public class DecodeSession { + private static final ThreadLocal CURRENT = new ThreadLocal<>(); + + private final Map> cache = new HashMap<>(); + + private DecodeSession() { + } + + /** + * Activates a session on the current thread. If a session is already active it is + * reused, so nested activations (e.g. fetching a @Reference while decoding an outer + * document) share one cache. Returns {@code true} if this call created the root session + * and therefore owns the responsibility of calling {@link #deactivate()}. + * + * @return {@code true} if a new root session was created; {@code false} if an existing session was reused + */ + public static boolean activate() { + if (CURRENT.get() != null) { + return false; + } + CURRENT.set(new DecodeSession()); + return true; + } + + /** + * Returns the session active on the current thread, or {@code null} if none. + */ + @Nullable + public static DecodeSession current() { + return CURRENT.get(); + } + + /** + * Removes the session from the current thread. + */ + public static void deactivate() { + CURRENT.remove(); + } + + /** + * Stores a decoded entity in the cache. + * + * @param collection the MongoDB collection name + * @param id the entity's {@code _id} value + * @param entity the decoded entity instance + */ + public void register(String collection, Object id, Object entity) { + cache.computeIfAbsent(collection, k -> new HashMap<>()).put(id, entity); + } + + /** + * Returns a previously cached entity, or {@code null} if not present. + * + * @param collection the MongoDB collection name + * @param id the entity's {@code _id} value + */ + @Nullable + public Object lookup(String collection, Object id) { + Map col = cache.get(collection); + return col != null ? col.get(id) : null; + } + + /** + * Returns {@code true} if an entity with this collection+id is already in the cache + * (even if still being populated — used for cycle detection). + */ + public boolean contains(String collection, Object id) { + Map col = cache.get(collection); + return col != null && col.containsKey(id); + } +} diff --git a/core/src/main/java/dev/morphia/mapping/codec/pojo/EntityDecoder.java b/core/src/main/java/dev/morphia/mapping/codec/pojo/EntityDecoder.java index 67741760b74..65ead0db7fc 100644 --- a/core/src/main/java/dev/morphia/mapping/codec/pojo/EntityDecoder.java +++ b/core/src/main/java/dev/morphia/mapping/codec/pojo/EntityDecoder.java @@ -4,6 +4,7 @@ import dev.morphia.annotations.internal.MorphiaInternal; import dev.morphia.mapping.DiscriminatorLookup; +import dev.morphia.mapping.codec.DecodeSession; import dev.morphia.mapping.codec.MorphiaInstanceCreator; import org.bson.BsonInvalidOperationException; @@ -45,7 +46,29 @@ public T decode(BsonReader reader, DecoderContext decoderContext) { if (decoderContext.hasCheckedDiscriminator()) { LOG.debug(format("Decoding document using codec for %s'", morphiaCodec.getEntityModel().getType().getName())); MorphiaInstanceCreator instanceCreator = getInstanceCreator(); + T instance = (T) instanceCreator.getInstance(); + + DecodeSession session = DecodeSession.current(); + Object prereadId = null; + if (session != null) { + prereadId = peekId(reader); + if (prereadId != null) { + session.register(classModel.collectionName(), prereadId, instance); + } + } + decodeProperties(reader, decoderContext, instanceCreator, classModel); + + if (session != null && prereadId == null) { + PropertyModel idProp = classModel.getIdProperty(); + if (idProp != null) { + Object id = morphiaCodec.getDatastore().getMapper().getId(instance); + if (id != null) { + session.register(classModel.collectionName(), id, instance); + } + } + } + return (T) instanceCreator.getInstance(); } else { entity = getCodecFromDocument(reader, classModel.useDiscriminator(), classModel.discriminatorKey(), @@ -117,6 +140,32 @@ protected Codec getCodecFromDocument(BsonReader reader, boolean useDiscrimina return codec != null ? codec : defaultCodec; } + @Nullable + private Object peekId(BsonReader reader) { + BsonReaderMark mark = reader.getMark(); + try { + reader.readStartDocument(); + String idName = classModel.getIdProperty() != null + ? classModel.getIdProperty().getMappedName() + : "_id"; + while (reader.readBsonType() != BsonType.END_OF_DOCUMENT) { + String name = reader.readName(); + if ("_id".equals(name) || name.equals(idName)) { + return morphiaCodec.getRegistry() + .get(Object.class) + .decode(reader, DecoderContext.builder().build()); + } else { + reader.skipValue(); + } + } + return null; + } catch (Exception e) { + return null; + } finally { + mark.reset(); + } + } + protected MorphiaInstanceCreator getInstanceCreator() { return classModel.getInstanceCreator(morphiaCodec.getConversions()); } diff --git a/core/src/main/java/dev/morphia/mapping/codec/pojo/MorphiaCodec.java b/core/src/main/java/dev/morphia/mapping/codec/pojo/MorphiaCodec.java index 752104e74a1..bf9eaba57de 100644 --- a/core/src/main/java/dev/morphia/mapping/codec/pojo/MorphiaCodec.java +++ b/core/src/main/java/dev/morphia/mapping/codec/pojo/MorphiaCodec.java @@ -7,6 +7,7 @@ import dev.morphia.mapping.DiscriminatorLookup; import dev.morphia.mapping.MappingException; import dev.morphia.mapping.codec.Conversions; +import dev.morphia.mapping.codec.DecodeSession; import dev.morphia.mapping.codec.PropertyCodecRegistryImpl; import dev.morphia.sofia.Sofia; @@ -77,7 +78,14 @@ public MorphiaCodec(MorphiaDatastore datastore, EntityModel model, @Override public T decode(BsonReader reader, DecoderContext decoderContext) { - return getDecoder().decode(reader, decoderContext); + boolean root = DecodeSession.activate(); + try { + return getDecoder().decode(reader, decoderContext); + } finally { + if (root) { + DecodeSession.deactivate(); + } + } } @Override diff --git a/core/src/main/java/dev/morphia/mapping/codec/references/ReferenceCodec.java b/core/src/main/java/dev/morphia/mapping/codec/references/ReferenceCodec.java index 508d93999fe..767992896fc 100644 --- a/core/src/main/java/dev/morphia/mapping/codec/references/ReferenceCodec.java +++ b/core/src/main/java/dev/morphia/mapping/codec/references/ReferenceCodec.java @@ -27,6 +27,7 @@ import dev.morphia.annotations.internal.MorphiaInternal; import dev.morphia.mapping.Mapper; import dev.morphia.mapping.MappingException; +import dev.morphia.mapping.codec.DecodeSession; import dev.morphia.mapping.codec.pojo.EntityModel; import dev.morphia.mapping.codec.pojo.PropertyHandler; import dev.morphia.mapping.codec.pojo.PropertyModel; @@ -322,6 +323,40 @@ private Class makeProxy() { .getLoaded(); } + @Nullable + private Object lookupInSession(Object id, EntityModel entityModel) { + DecodeSession session = DecodeSession.current(); + if (session == null) { + return null; + } + String collection = id instanceof DBRef + ? ((DBRef) id).getCollectionName() + : entityModel.collectionName(); + Object lookupId = id instanceof DBRef ? ((DBRef) id).getId() : id; + return session.lookup(collection, lookupId); + } + + @Nullable + private List lookupCollectionInSession(List rawIds, EntityModel entityModel) { + DecodeSession session = DecodeSession.current(); + if (session == null) { + return null; + } + List results = new ArrayList<>(); + for (Object id : rawIds) { + String collection = id instanceof DBRef + ? ((DBRef) id).getCollectionName() + : entityModel.collectionName(); + Object lookupId = id instanceof DBRef ? ((DBRef) id).getId() : id; + Object cached = session.lookup(collection, lookupId); + if (cached == null) { + return null; // at least one miss — fall through to DB fetch + } + results.add(cached); + } + return results; + } + @Nullable private Object fetch(Object value) { boolean lazy = annotation.lazy(); @@ -335,6 +370,10 @@ private Object fetch(Object value) { if (!preDecoded.isEmpty()) { return preDecoded; } + List cachedList = lookupCollectionInSession(rawIds, entityModel); + if (cachedList != null) { + return cachedList; + } List ids = stripDbRefs(rawIds); Supplier loader = () -> fetchCollection(rawIds, entityModel, ignoreMissing); return lazy ? createProxy(loader, ids, entityModel.getType()) : loader.get(); @@ -345,6 +384,10 @@ private Object fetch(Object value) { if (!preDecoded.isEmpty()) { return new LinkedHashSet<>(preDecoded); } + List cachedSet = lookupCollectionInSession(rawIds, entityModel); + if (cachedSet != null) { + return new LinkedHashSet<>(cachedSet); + } List ids = stripDbRefs(rawIds); Supplier loader = () -> new LinkedHashSet<>(fetchCollection(rawIds, entityModel, ignoreMissing)); return lazy ? createProxy(loader, ids, entityModel.getType()) : loader.get(); @@ -366,6 +409,10 @@ private Object fetch(Object value) { if (entityModel.getType().isInstance(id)) { return id; } + Object cached = lookupInSession(id, entityModel); + if (cached != null) { + return cached; + } List ids = List.of(stripDbRef(id)); Supplier loader = () -> fetchSingle(id, entityModel, ignoreMissing); return lazy ? createProxy(loader, ids, entityModel.getType()) : loader.get(); diff --git a/core/src/main/java/dev/morphia/query/MorphiaCursor.java b/core/src/main/java/dev/morphia/query/MorphiaCursor.java index a0a5efc99ed..9321f34f5d9 100644 --- a/core/src/main/java/dev/morphia/query/MorphiaCursor.java +++ b/core/src/main/java/dev/morphia/query/MorphiaCursor.java @@ -9,6 +9,7 @@ import com.mongodb.lang.NonNull; import dev.morphia.annotations.internal.MorphiaInternal; +import dev.morphia.mapping.codec.DecodeSession; /** * @param the original type being iterated @@ -44,7 +45,14 @@ public boolean hasNext() { @Override @NonNull public T next() { - return wrapped.next(); + boolean root = DecodeSession.activate(); + try { + return wrapped.next(); + } finally { + if (root) { + DecodeSession.deactivate(); + } + } } @Override @@ -54,7 +62,14 @@ public int available() { @Override public T tryNext() { - return wrapped.tryNext(); + boolean root = DecodeSession.activate(); + try { + return wrapped.tryNext(); + } finally { + if (root) { + DecodeSession.deactivate(); + } + } } @Override diff --git a/core/src/test/java/dev/morphia/test/mapping/TestReferences.java b/core/src/test/java/dev/morphia/test/mapping/TestReferences.java index c07928ec175..bd43153098d 100644 --- a/core/src/test/java/dev/morphia/test/mapping/TestReferences.java +++ b/core/src/test/java/dev/morphia/test/mapping/TestReferences.java @@ -1223,4 +1223,72 @@ public void setId(ObjectId id) { this.id = id; } } + + @Entity + private static class TwoRefContainer { + @Id + private ObjectId id; + @Reference(idOnly = true) + private Ref ref1; + @Reference(idOnly = true) + private Ref ref2; + } + + @Entity + private static class NodeA { + @Id + private ObjectId id = new ObjectId(); + private String name; + @Reference + private NodeB partner; + } + + @Entity + private static class NodeB { + @Id + private ObjectId id = new ObjectId(); + private String name; + @Reference + private NodeA partner; + } + + @Test + public void testReferenceDeduplication() { + // A single document with two @Reference fields pointing to the same entity. + // Both fields should decode to the same Java instance within one decode session. + Ref shared = new Ref("shared-ref"); + getDs().save(shared); + + TwoRefContainer container = new TwoRefContainer(); + container.ref1 = shared; + container.ref2 = shared; + getDs().save(container); + + TwoRefContainer loaded = getDs().find(TwoRefContainer.class).first(); + assertNotNull(loaded); + assertSame(loaded.ref1, loaded.ref2, "Both ref fields should point to the same Ref instance"); + } + + @Test + public void testCyclicReferenceDoesNotStackOverflow() { + NodeA a = new NodeA(); + a.name = "alpha"; + NodeB b = new NodeB(); + b.name = "beta"; + + getDs().save(a); + getDs().save(b); + + a.partner = b; + b.partner = a; + getDs().save(a); + getDs().save(b); + + NodeA loaded = getDs().find(NodeA.class).filter(eq("_id", a.id)).first(); + assertNotNull(loaded); + assertNotNull(loaded.partner); + assertEquals(loaded.partner.name, "beta"); + assertNotNull(loaded.partner.partner); + assertEquals(loaded.partner.partner.name, "alpha"); + } }