diff --git a/tavern/internal/portals/mux/mux_create.go b/tavern/internal/portals/mux/mux_create.go
index 49ee8936e..02d1f9d91 100644
--- a/tavern/internal/portals/mux/mux_create.go
+++ b/tavern/internal/portals/mux/mux_create.go
@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"time"
 
+	"gocloud.dev/pubsub"
 	"realm.pub/tavern/internal/ent"
 	"realm.pub/tavern/internal/ent/task"
 )
@@ -48,69 +49,33 @@ func (m *Mux) CreatePortal(ctx context.Context, client *ent.Client, taskID int)
 	topicOut := m.TopicOut(portalID)
 	subName := m.SubName(topicIn)
 
-	// 2. Provisioning
-	// Ensure topics exist
-	if err := m.ensureTopic(ctx, topicIn); err != nil {
-		return portalID, nil, fmt.Errorf("failed to ensure topic in: %w", err)
-	}
-	if err := m.ensureTopic(ctx, topicOut); err != nil {
-		return portalID, nil, fmt.Errorf("failed to ensure topic out: %w", err)
-	}
+	provision := func() error {
+		// Ensure topics exist
+		if err := m.ensureTopic(ctx, topicIn); err != nil {
+			return fmt.Errorf("failed to ensure topic in: %w", err)
+		}
+		if err := m.ensureTopic(ctx, topicOut); err != nil {
+			return fmt.Errorf("failed to ensure topic out: %w", err)
+		}
 
-	// Ensure subscription exists
-	if err := m.ensureSub(ctx, topicIn, subName); err != nil {
-		return portalID, nil, fmt.Errorf("failed to ensure subscription: %w", err)
+		// Ensure subscription exists
+		if err := m.ensureSub(ctx, topicIn, subName); err != nil {
+			return fmt.Errorf("failed to ensure subscription: %w", err)
+		}
+		return nil
 	}
 
-	// 3. Connect
-	// Updated SubURL usage
-	subURL := m.SubURL(topicIn, subName)
-	sub, err := m.openSubscription(ctx, subURL)
-	if err != nil {
-		return portalID, nil, fmt.Errorf("failed to open subscription %s: %w", subURL, err)
+	startLoop := func(ctxLoop context.Context, sub *pubsub.Subscription) {
+		m.receiveLoop(ctxLoop, topicIn, sub)
 	}
 
-	// Store in subMgr
-	m.subMgr.Lock()
-
-	// Check if existing sub
-	if existingSub, ok := m.subMgr.active[subName]; ok {
-		// Existing found. Shutdown new one and cleanup old one if needed.
-		if cancel, ok := m.subMgr.cancelFuncs[subName]; ok {
-			cancel()
-		}
-		// We are overwriting, so we must assume the old one is invalid or we are restarting.
-		existingSub.Shutdown(context.Background())
+	teardownSub, err := m.acquireSubscription(ctx, subName, topicIn, provision, startLoop)
+	if err != nil {
+		return portalID, nil, err
 	}
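+
+	// teardownSub releases only this caller's reference to the subscription;
+	// the underlying pubsub.Subscription is shut down once the final
+	// reference is released (see makeTeardown in mux_pubsub.go).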
-	m.subMgr.active[subName] = sub
-
-	ctxLoop, cancelLoop := context.WithCancel(context.Background())
-	m.subMgr.cancelFuncs[subName] = cancelLoop
-	m.subMgr.Unlock()
-
-	// 4. Spawn
-	go func() {
-		m.receiveLoop(ctxLoop, topicIn, sub)
-	}()
 
 	teardown := func() {
-		m.subMgr.Lock()
-		s, ok := m.subMgr.active[subName]
-		cancel := m.subMgr.cancelFuncs[subName]
-		if ok {
-			delete(m.subMgr.active, subName)
-			delete(m.subMgr.cancelFuncs, subName)
-		}
-		m.subMgr.Unlock()
-
-		if ok {
-			if cancel != nil {
-				cancel()
-			}
-			// Shutdown using Background context
-			s.Shutdown(context.Background())
-		}
+		teardownSub()
 
 		// Update DB to Closed using ID
 		client.Portal.UpdateOneID(p.ID).
diff --git a/tavern/internal/portals/mux/mux_open.go b/tavern/internal/portals/mux/mux_open.go
index 9365b11bb..bbd827057 100644
--- a/tavern/internal/portals/mux/mux_open.go
+++ b/tavern/internal/portals/mux/mux_open.go
@@ -2,7 +2,6 @@ package mux
 
 import (
 	"context"
-	"fmt"
 
 	"gocloud.dev/pubsub"
 )
@@ -12,138 +11,21 @@ func (m *Mux) OpenPortal(ctx context.Context, portalID int) (func(), error) {
 	topicOut := m.TopicOut(portalID)
 	subName := m.SubName(topicOut)
 
-	m.subMgr.Lock()
-	// Check Cache
-	if _, ok := m.subMgr.active[subName]; ok {
-		m.subMgr.refs[subName]++
-		m.subMgr.Unlock()
-		return func() {
-			m.subMgr.Lock()
-			m.subMgr.refs[subName]--
-			shouldShutdown := false
-			var s *pubsub.Subscription
-			var cancel context.CancelFunc
-			if m.subMgr.refs[subName] <= 0 {
-				if sub, ok := m.subMgr.active[subName]; ok {
-					s = sub
-					cancel = m.subMgr.cancelFuncs[subName]
-					delete(m.subMgr.active, subName)
-					delete(m.subMgr.refs, subName)
-					delete(m.subMgr.cancelFuncs, subName)
-					shouldShutdown = true
-				}
-			}
-			m.subMgr.Unlock()
-
-			if shouldShutdown {
-				if cancel != nil {
-					cancel()
-				}
-				if s != nil {
-					s.Shutdown(context.Background())
-				}
-			}
-		}, nil
+	provision := func() error {
+		// Ensure the subscription exists for the OUT topic.
+		// acquireSubscription wraps any error returned here, so no
+		// additional wrapping is needed.
+		if err := m.ensureSub(ctx, topicOut, subName); err != nil {
+			return err
+		}
+		return nil
 	}
-	m.subMgr.Unlock()
 
-	// Provisioning
-	// Ensure subscription exists for the OUT topic
-	if err := m.ensureSub(ctx, topicOut, subName); err != nil {
-		return nil, fmt.Errorf("failed to ensure subscription: %w", err)
+	startLoop := func(ctxLoop context.Context, sub *pubsub.Subscription) {
+		m.receiveLoop(ctxLoop, topicOut, sub)
 	}
 
-	// Connect
-	// Updated SubURL usage
-	subURL := m.SubURL(topicOut, subName)
-	sub, err := m.openSubscription(ctx, subURL)
+	teardown, err := m.acquireSubscription(ctx, subName, topicOut, provision, startLoop)
 	if err != nil {
-		return nil, fmt.Errorf("failed to open subscription %s: %w", subURL, err)
-	}
-
-	m.subMgr.Lock()
-	// RACE CONDITION CHECK:
-	// Re-check cache in case another goroutine created it while we were provisioning/connecting
-	if existingSub, ok := m.subMgr.active[subName]; ok {
-		// Another routine won the race. Use theirs.
-		m.subMgr.refs[subName]++
-		m.subMgr.Unlock()
-
-		// Close our unused subscription immediately
-		sub.Shutdown(context.Background())
-
-		// Return teardown for the EXISTING subscription
-		return func() {
-			m.subMgr.Lock()
-			m.subMgr.refs[subName]--
-			shouldShutdown := false
-			var s *pubsub.Subscription
-			var cancel context.CancelFunc
-			if m.subMgr.refs[subName] <= 0 {
-				if sub, ok := m.subMgr.active[subName]; ok {
-					s = sub
-					cancel = m.subMgr.cancelFuncs[subName]
-					delete(m.subMgr.active, subName)
-					delete(m.subMgr.refs, subName)
-					delete(m.subMgr.cancelFuncs, subName)
-					shouldShutdown = true
-				}
-			}
-			m.subMgr.Unlock()
-
-			if shouldShutdown {
-				if cancel != nil {
-					cancel()
-				}
-				if existingSub != nil {
-					s.Shutdown(context.Background())
-				}
-			}
-		}, nil
-	}
-
-	// We won the race (or are the first).
-	m.subMgr.active[subName] = sub
-	m.subMgr.refs[subName] = 1
-
-	// Prepare Loop Context
-	ctxLoop, cancelLoop := context.WithCancel(context.Background())
-	m.subMgr.cancelFuncs[subName] = cancelLoop
-
-	m.subMgr.Unlock()
-
-	// Spawn
-	go func() {
-		// No defer cancelLoop() here, controlled by teardown/map
-		m.receiveLoop(ctxLoop, topicOut, sub)
-	}()
-
-	teardown := func() {
-		m.subMgr.Lock()
-		m.subMgr.refs[subName]--
-		shouldShutdown := false
-		var s *pubsub.Subscription
-		var cancel context.CancelFunc
-		if m.subMgr.refs[subName] <= 0 {
-			if sub, ok := m.subMgr.active[subName]; ok {
-				s = sub
-				cancel = m.subMgr.cancelFuncs[subName]
-				delete(m.subMgr.active, subName)
-				delete(m.subMgr.refs, subName)
-				delete(m.subMgr.cancelFuncs, subName)
-				shouldShutdown = true
-			}
-		}
-		m.subMgr.Unlock()
-
-		if shouldShutdown {
-			if cancel != nil {
-				cancel()
-			}
-			if s != nil {
-				s.Shutdown(context.Background())
-			}
-		}
+		return nil, err // acquireSubscription already wraps the error
 	}
 
 	return teardown, nil
diff --git a/tavern/internal/portals/mux/mux_pubsub.go b/tavern/internal/portals/mux/mux_pubsub.go
index 90dafb5e0..89246605d 100644
--- a/tavern/internal/portals/mux/mux_pubsub.go
+++ b/tavern/internal/portals/mux/mux_pubsub.go
@@ -219,3 +219,95 @@ func (m *Mux) openSubscription(ctx context.Context, url string) (*pubsub.Subscription, error) {
 
 	return sub, nil
 }
+
+// acquireSubscription acquires a subscription, reusing an existing one if available.
+// It handles reference counting and race conditions.
+func (m *Mux) acquireSubscription(
+	ctx context.Context,
+	subName string,
+	topicID string,
+	provision func() error,
+	startLoop func(context.Context, *pubsub.Subscription),
+) (func(), error) {
+	m.subMgr.Lock()
+	// Check Cache
+	if _, ok := m.subMgr.active[subName]; ok {
+		m.subMgr.refs[subName]++
+		m.subMgr.Unlock()
+		return m.makeTeardown(subName), nil
+	}
+	m.subMgr.Unlock()
+
+	// Provisioning
+	if err := provision(); err != nil {
+		return nil, fmt.Errorf("failed to provision subscription: %w", err)
+	}
+
+	// Connect
+	subURL := m.SubURL(topicID, subName)
+	sub, err := m.openSubscription(ctx, subURL)
+	if err != nil {
+		return nil, fmt.Errorf("failed to open subscription %s: %w", subURL, err)
+	}
+
+	m.subMgr.Lock()
+	// RACE CONDITION CHECK:
+	// Re-check cache in case another goroutine created it while we were provisioning/connecting
+	if _, ok := m.subMgr.active[subName]; ok {
+		// Another routine won the race. Use theirs.
+		m.subMgr.refs[subName]++
+		m.subMgr.Unlock()
+
+		// Close our unused subscription immediately
+		sub.Shutdown(context.Background())
+
+		// Return teardown for the EXISTING subscription
+		return m.makeTeardown(subName), nil
+	}
+
+	// We won the race (or are the first).
+	m.subMgr.active[subName] = sub
+	m.subMgr.refs[subName] = 1
+
+	// Prepare Loop Context
+	ctxLoop, cancelLoop := context.WithCancel(context.Background())
+	m.subMgr.cancelFuncs[subName] = cancelLoop
+
+	m.subMgr.Unlock()
+
+	// Spawn
+	go startLoop(ctxLoop, sub)
+
+	return m.makeTeardown(subName), nil
+}
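+
+// Usage sketch (illustrative): callers supply their own provisioning and
+// receive-loop callbacks, as CreatePortal and OpenPortal do; "topic" here
+// stands in for a real topic ID:
+//
+//	provision := func() error { return m.ensureSub(ctx, topic, subName) }
+//	startLoop := func(ctxLoop context.Context, sub *pubsub.Subscription) {
+//		m.receiveLoop(ctxLoop, topic, sub)
+//	}
+//	teardown, err := m.acquireSubscription(ctx, subName, topic, provision, startLoop)
+//	if err != nil {
+//		return err
+//	}
+//	defer teardown() // releases this caller's reference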
+
+// makeTeardown creates a teardown function for a specific subscription.
+func (m *Mux) makeTeardown(subName string) func() {
+	return func() {
+		m.subMgr.Lock()
+		m.subMgr.refs[subName]--
+		shouldShutdown := false
+		var s *pubsub.Subscription
+		var cancel context.CancelFunc
+		if m.subMgr.refs[subName] <= 0 {
+			if sub, ok := m.subMgr.active[subName]; ok {
+				s = sub
+				cancel = m.subMgr.cancelFuncs[subName]
+				delete(m.subMgr.active, subName)
+				delete(m.subMgr.refs, subName)
+				delete(m.subMgr.cancelFuncs, subName)
+				shouldShutdown = true
+			}
+		}
+		m.subMgr.Unlock()
+
+		if shouldShutdown {
+			if cancel != nil {
+				cancel()
+			}
+			if s != nil {
+				s.Shutdown(context.Background())
+			}
+		}
+	}
+}
diff --git a/tavern/internal/portals/mux/mux_test.go b/tavern/internal/portals/mux/mux_test.go
index 5787a0a58..231f19dcf 100644
--- a/tavern/internal/portals/mux/mux_test.go
+++ b/tavern/internal/portals/mux/mux_test.go
@@ -197,3 +197,72 @@ func TestWithSubscriberBufferSize(t *testing.T) {
 	m := New(WithSubscriberBufferSize(expected))
 	assert.Equal(t, expected, m.subs.bufferSize)
 }
+
+func TestMux_CreatePortal_RefCounting(t *testing.T) {
+	// Setup DB
+	client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
+	defer client.Close()
+
+	// Setup Mux
+	m := New(WithInMemoryDriver())
+	ctx := context.Background()
+
+	// Create User, Tome, Quest, Host, Beacon, and Task fixtures
+	u := client.User.Create().SetName("testuser_rc").SetOauthID("oauth_rc").SetPhotoURL("photo_rc").SaveX(ctx)
+	tomeEnt := client.Tome.Create().SetName("testtome_rc").SetDescription("desc").SetEldritch("code").SetAuthor(u.Name).SetUploader(u).SetTactic(tome.TacticRECON).SaveX(ctx)
+	quest := client.Quest.Create().SetName("testquest_rc").SetParameters("{}").SetCreator(u).SetTome(tomeEnt).SaveX(ctx)
+	h := client.Host.Create().SetName("host_rc").SetIdentifier("ident_rc").SetPlatform(c2pb.Host_PLATFORM_LINUX).SaveX(ctx)
+	b := client.Beacon.Create().SetName("beacon_rc").SetTransport(c2pb.Transport_TRANSPORT_HTTP1).SetHost(h).SaveX(ctx)
+	task := client.Task.Create().SetQuest(quest).SetBeacon(b).SaveX(ctx)
+
+	// 1. Create Portal
+	portalID, teardown, err := m.CreatePortal(ctx, client, task.ID)
+	require.NoError(t, err)
+
+	topicIn := m.TopicIn(portalID)
+	subName := m.SubName(topicIn)
+
+	m.subMgr.Lock()
+	// Check that refs was initialized (should be 1)
+	refs, ok := m.subMgr.refs[subName]
+	m.subMgr.Unlock()
+
+	assert.True(t, ok, "refs should be initialized")
+	assert.Equal(t, 1, refs)
+
+	// Simulate a second caller acquiring the same subscription. We can't
+	// easily trigger a CreatePortal collision here, so bump the ref count
+	// manually to verify that teardown respects it.
+	m.subMgr.Lock()
+	m.subMgr.refs[subName] = 2
+	m.subMgr.Unlock()
+
+	// 2. Teardown
+	teardown()
+
+	// 3. Verify
+	m.subMgr.Lock()
+	_, active := m.subMgr.active[subName]
+	currRefs := m.subMgr.refs[subName]
+	m.subMgr.Unlock()
+
+	// Teardown must respect the ref count: the subscription stays active
+	// and the count drops from 2 to 1.
+	assert.True(t, active, "Subscription should still be active due to ref count")
+	assert.Equal(t, 1, currRefs, "Ref count should be decremented to 1")
+
+	// Cleanup the manually added ref
+	m.subMgr.Lock()
+	if active {
+		delete(m.subMgr.active, subName)
+		delete(m.subMgr.refs, subName)
+		delete(m.subMgr.cancelFuncs, subName)
+	}
+	m.subMgr.Unlock()
+}
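+
+// Ref-count behavior at a glance (illustrative sketch, not an executable
+// test; both calls use the same portalID and therefore share a subscription):
+//
+//	t1, _ := m.OpenPortal(ctx, portalID) // refs == 1, receive loop started
+//	t2, _ := m.OpenPortal(ctx, portalID) // refs == 2, cached subscription reused
+//	t1()                                 // refs == 1, subscription stays active
+//	t2()                                 // refs == 0, loop cancelled, subscription shut down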