diff --git a/pkg/synchronizer/integration_test.go b/pkg/synchronizer/integration_test.go new file mode 100644 index 00000000..05fe5bb8 --- /dev/null +++ b/pkg/synchronizer/integration_test.go @@ -0,0 +1,212 @@ +// Copyright 2026 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package synchronizer + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// TestIntegration_CrossParticipantClock exercises the full SyncEngine stack +// (NtpEstimator -> SessionTimeline -> ParticipantClock -> SyncEngine) to verify +// that two participants producing audio at the same real-world time are aligned +// on the session timeline despite having different NTP clock offsets. +// +// Setup: +// - Alice: audio 48kHz, SSRC=1000, NTP clock offset = 0 (matches real time) +// - Bob: audio 48kHz, SSRC=2000, NTP clock offset = +500ms (ahead) +// - Both have real OWD of 50ms (same SFU -> egress path) +// - Session starts when Alice's first packet arrives +// +// The OWD estimator sees: +// - Alice: receivedAt - senderNTP = (realTime+50ms) - realTime = 50ms +// - Bob: receivedAt - senderNTP = (realTime+50ms) - (realTime+500ms) = -450ms +// +// The formula sessionPTS = ntpTime + OWD - sessionStart normalizes the clock +// offset because ntpTime includes the +500ms and OWD reflects the -500ms. +func TestIntegration_CrossParticipantClock(t *testing.T) { + const ( + clockRate = uint32(48000) + owd = 50 * time.Millisecond + bobNTPOffset = 500 * time.Millisecond + ) + + engine := NewSyncEngine(WithSyncEngineOldPacketThreshold(0)) + + aliceTrack := newMockAudioTrack("audio-alice", 1000) + bobTrack := newMockAudioTrack("audio-bob", 2000) + + aliceTS := engine.AddTrack(aliceTrack, "alice") + bobTS := engine.AddTrack(bobTrack, "bob") + + // Session starts at a fixed base time. Both participants' first packets + // arrive at the same instant (same real OWD from the SFU). + baseTime := time.Date(2025, 7, 1, 12, 0, 0, 0, time.UTC) + firstArrival := baseTime.Add(owd) + + // Prime both tracks with their first packets (same arrival time). + alicePkt0 := makeExtPacket(0, 0, firstArrival) + bobPkt0 := makeExtPacket(0, 0, firstArrival) + _, _, aliceDone := aliceTS.PrimeForStart(alicePkt0) + _, _, bobDone := bobTS.PrimeForStart(bobPkt0) + require.True(t, aliceDone) + require.True(t, bobDone) + + // Feed 5 sender reports for each participant, 5 seconds apart. + // Alice's NTP = realTime (no offset), Bob's NTP = realTime + 500ms. + // Both SRs arrive at realTime + OWD. + for i := 0; i < 5; i++ { + realTime := baseTime.Add(time.Duration(i) * 5 * time.Second) + receivedAt := realTime.Add(owd) + rtpTS := uint32(i) * 5 * clockRate + + // Alice SR: NTP = realTime + aliceNTP := ntpToUint64(realTime) + aliceSR := makeSenderReport(1000, aliceNTP, rtpTS) + // Manually set receivedAt by calling OnSenderReport on the timeline directly + // since OnRTCP uses time.Now(). We need deterministic timing. + engine.timeline.OnSenderReport("alice", "audio-alice", clockRate, aliceNTP, rtpTS, receivedAt) + _ = aliceSR // used above indirectly + + // Bob SR: NTP = realTime + 500ms (Bob's NTP clock is 500ms ahead) + bobNTP := ntpToUint64(realTime.Add(bobNTPOffset)) + engine.timeline.OnSenderReport("bob", "audio-bob", clockRate, bobNTP, rtpTS, receivedAt) + } + + // Get PTS for both participants at "real time + 10s" with corresponding + // RTP timestamps (10s * 48kHz = 480000). + realTimeAt10s := baseTime.Add(10 * time.Second) + receivedAtAt10s := realTimeAt10s.Add(owd) + rtpAt10s := uint32(10) * clockRate + + alicePkt := makeExtPacket(rtpAt10s, 100, receivedAtAt10s) + bobPkt := makeExtPacket(rtpAt10s, 100, receivedAtAt10s) + + alicePTS, err := aliceTS.GetPTS(alicePkt) + require.NoError(t, err) + + bobPTS, err := bobTS.GetPTS(bobPkt) + require.NoError(t, err) + + // The 500ms NTP clock difference should be normalized away by OWD estimation. + diff := alicePTS - bobPTS + if diff < 0 { + diff = -diff + } + + t.Logf("Alice PTS: %v, Bob PTS: %v, diff: %v", alicePTS, bobPTS, diff) + require.Less(t, diff, 50*time.Millisecond, + "cross-participant PTS should be aligned despite 500ms NTP clock offset; alice=%v bob=%v diff=%v", + alicePTS, bobPTS, diff) +} + +// TestIntegration_AVLipSync exercises the full SyncEngine stack to verify that +// a single participant's audio and video tracks are kept in sync despite an +// 80ms video encoder delay (video NTP timestamps lag audio by 80ms in the +// sender's clock domain). +// +// Setup: +// - Audio: 48kHz, SSRC=1000 +// - Video: 90kHz, SSRC=2000 +// - Same participant "alice" +// - OWD = 50ms for both tracks +// - Video has 80ms encoder delay: video NTP = audio NTP + 80ms for same +// real-world instant (video capture is delayed by encoding pipeline) +// +// The ParticipantClock detects the A/V NTP offset and applies a slew-limited +// correction on the video track to bring them into alignment. +func TestIntegration_AVLipSync(t *testing.T) { + const ( + audioClockRate = uint32(48000) + videoClockRate = uint32(90000) + owd = 50 * time.Millisecond + videoEncoderDelay = 80 * time.Millisecond + ) + + engine := NewSyncEngine(WithSyncEngineOldPacketThreshold(0)) + + audioTrack := newMockAudioTrack("audio-alice", 1000) + videoTrack := newMockVideoTrack("video-alice", 2000) + + audioTS := engine.AddTrack(audioTrack, "alice") + videoTS := engine.AddTrack(videoTrack, "alice") + + baseTime := time.Date(2025, 7, 1, 12, 0, 0, 0, time.UTC) + firstArrival := baseTime.Add(owd) + + // Prime both tracks. + audioPkt0 := makeExtPacket(0, 0, firstArrival) + videoPkt0 := makeExtPacket(0, 0, firstArrival) + _, _, audioDone := audioTS.PrimeForStart(audioPkt0) + _, _, videoDone := videoTS.PrimeForStart(videoPkt0) + require.True(t, audioDone) + require.True(t, videoDone) + + // Feed 5 SRs for audio and video, 5 seconds apart. + // Audio: NTP = baseNtp + i*5s, RTP = i * 5 * audioClockRate + // Video: NTP = baseNtp + i*5s + 80ms (encoder delay), RTP = i * 5 * videoClockRate + for i := 0; i < 5; i++ { + srTime := baseTime.Add(time.Duration(i) * 5 * time.Second) + receivedAt := srTime.Add(owd) + + audioRTP := uint32(i) * 5 * audioClockRate + audioNTP := ntpToUint64(srTime) + engine.timeline.OnSenderReport("alice", "audio-alice", audioClockRate, audioNTP, audioRTP, receivedAt) + + videoRTP := uint32(i) * 5 * videoClockRate + videoNTP := ntpToUint64(srTime.Add(videoEncoderDelay)) + engine.timeline.OnSenderReport("alice", "video-alice", videoClockRate, videoNTP, videoRTP, receivedAt) + } + + // Push multiple packets through GetPTS to drive the transition slew + // and allow the sync engine's per-call slew to converge. + for i := 1; i <= 200; i++ { + recvAt := firstArrival.Add(time.Duration(i) * 20 * time.Millisecond) + audioRTP := uint32(i) * 960 // 20ms at 48kHz + videoRTP := uint32(i) * 1800 // 20ms at 90kHz + + aPkt := makeExtPacket(audioRTP, uint16(i), recvAt) + vPkt := makeExtPacket(videoRTP, uint16(i), recvAt) + + audioTS.GetPTS(aPkt) + videoTS.GetPTS(vPkt) + } + + // Get PTS for audio at RTP=480000 (10s at 48kHz) and video at RTP=900000 (10s at 90kHz). + recvAt10s := firstArrival.Add(10 * time.Second) + audioPktFinal := makeExtPacket(10*audioClockRate, 500, recvAt10s) + videoPktFinal := makeExtPacket(10*videoClockRate, 500, recvAt10s) + + audioPTS, err := audioTS.GetPTS(audioPktFinal) + require.NoError(t, err) + + videoPTS, err := videoTS.GetPTS(videoPktFinal) + require.NoError(t, err) + + // The 80ms encoder delay should be corrected (or mostly corrected) by + // ParticipantClock's slew-limited adjustment. Allow 100ms tolerance to + // account for slew rate convergence. + diff := audioPTS - videoPTS + if diff < 0 { + diff = -diff + } + + t.Logf("Audio PTS: %v, Video PTS: %v, diff: %v", audioPTS, videoPTS, diff) + require.Less(t, diff, 100*time.Millisecond, + "A/V lip sync should be within 100ms after convergence; audio=%v video=%v diff=%v", + audioPTS, videoPTS, diff) +} diff --git a/pkg/synchronizer/interfaces.go b/pkg/synchronizer/interfaces.go new file mode 100644 index 00000000..8dd30e9b --- /dev/null +++ b/pkg/synchronizer/interfaces.go @@ -0,0 +1,45 @@ +// Copyright 2026 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package synchronizer + +import ( + "time" + + "github.com/pion/rtcp" + + "github.com/livekit/media-sdk/jitter" +) + +// Sync is the top-level synchronization interface. +// Implemented by both Synchronizer (legacy) and SyncEngine (new). +type Sync interface { + AddTrack(track TrackRemote, participantID string) TrackSync + RemoveTrack(trackID string) + OnRTCP(packet rtcp.Packet) + End() + GetStartedAt() int64 + GetEndedAt() int64 + SetMediaRunningTime(mediaRunningTime func() (time.Duration, bool)) +} + +// TrackSync is the per-track synchronization interface. +// Implemented by both TrackSynchronizer (legacy) and syncEngineTrack (new). +type TrackSync interface { + PrimeForStart(pkt jitter.ExtPacket) ([]jitter.ExtPacket, int, bool) + GetPTS(pkt jitter.ExtPacket) (time.Duration, error) + OnSenderReport(f func(drift time.Duration)) + LastPTSAdjusted() time.Duration + Close() +} diff --git a/pkg/synchronizer/ntpestimator.go b/pkg/synchronizer/ntpestimator.go new file mode 100644 index 00000000..c1e80da6 --- /dev/null +++ b/pkg/synchronizer/ntpestimator.go @@ -0,0 +1,309 @@ +// Copyright 2026 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package synchronizer + +import ( + "errors" + "math" + "sync" + "time" +) + +const ( + // maxSRSamples is the sliding window size for sender report pairs. + maxSRSamples = 20 + + // minSamplesReady is the minimum number of SR pairs needed before the + // regression is considered ready. With only 2 points the slope is entirely + // determined by SR timing jitter; 4 gives a much more stable fit. + minSamplesReady = 4 + + // outlierThresholdStdDevs is the number of standard deviations beyond which + // a new SR is considered an outlier and excluded from the regression. + outlierThresholdStdDevs = 3.0 + + // ntpEpochOffset is the number of seconds between the NTP epoch (1900-01-01) + // and the Unix epoch (1970-01-01). + ntpEpochOffset = 2208988800 +) + +var errNotReady = errors.New("NtpEstimator: not enough sender reports for regression (need >= 2)") + +// srSample holds one sender report observation used in the regression. +type srSample struct { + unwrappedRTP int64 // RTP timestamp unwrapped to 64-bit + ntpNanos int64 // NTP wall-clock in nanoseconds since Unix epoch + receivedAt time.Time +} + +// NtpEstimator maintains a linear regression over a sliding window of RTCP +// sender report pairs to map RTP timestamps to NTP time. It is modeled after +// Chrome's RtpToNtpEstimator. +type NtpEstimator struct { + mu sync.Mutex + clockRate uint32 + + samples [maxSRSamples]srSample + sampleLen int // number of valid samples in the buffer (0..maxSRSamples) + sampleHead int // index of the next write position + + // RTP unwrapping state + lastRTP uint32 + rtpOffset int64 // cumulative offset from wraparounds + hasLastRTP bool + + // Regression results (valid when sampleLen >= minSamplesReady) + // The internal model is: ntpNanos = slopeNanos * (unwrappedRTP - meanX) + meanY + // where slopeNanos is nanos per RTP tick. + slopeNanos float64 // nanos of NTP time per RTP tick + meanX float64 // mean of unwrapped RTP values in the current window + meanY float64 // mean of NTP nanos values in the current window + residStd float64 // residual standard deviation in NTP nanos + ready bool +} + +// NewNtpEstimator creates an NtpEstimator for a codec with the given clock rate. +func NewNtpEstimator(clockRate uint32) *NtpEstimator { + return &NtpEstimator{ + clockRate: clockRate, + } +} + +// Reset clears all state, returning the estimator to its initial condition. +// Used when a stream discontinuity is detected (e.g., stream restart with a new +// RTP offset) and the old regression is no longer valid. +func (e *NtpEstimator) Reset() { + e.mu.Lock() + defer e.mu.Unlock() + e.samples = [maxSRSamples]srSample{} + e.sampleLen = 0 + e.sampleHead = 0 + e.lastRTP = 0 + e.rtpOffset = 0 + e.hasLastRTP = false + e.slopeNanos = 0 + e.meanX = 0 + e.meanY = 0 + e.residStd = 0 + e.ready = false +} + +// SRResult indicates the outcome of processing a sender report. +type SRResult int + +const ( + SRAccepted SRResult = iota + SRDuplicate + SROutlier +) + +// OnSenderReport ingests a new RTCP sender report observation. +// ntpTime is the 64-bit NTP timestamp from the SR, rtpTimestamp is the +// corresponding RTP timestamp, and receivedAt is the local wall-clock time +// when the SR was received. +func (e *NtpEstimator) OnSenderReport(ntpTime uint64, rtpTimestamp uint32, receivedAt time.Time) SRResult { + e.mu.Lock() + defer e.mu.Unlock() + + ntpNanos := ntpTimestampToNanos(ntpTime) + unwrapped := e.unwrapRTP(rtpTimestamp) + + // Skip duplicate SRs (same NTP/RTP pair as the most recent sample). + // This happens when the same SR is dispatched multiple times via + // per-publication RTCP callbacks. + if e.sampleLen > 0 { + lastIdx := (e.sampleHead - 1 + maxSRSamples) % maxSRSamples + last := e.samples[lastIdx] + if last.unwrappedRTP == unwrapped && last.ntpNanos == ntpNanos { + return SRDuplicate + } + } + + // Outlier rejection: if we already have a valid regression, check whether + // this new sample deviates from the prediction by more than 3 standard + // deviations. + if e.ready && e.residStd > 0 { + predicted := e.slopeNanos*(float64(unwrapped)-e.meanX) + e.meanY + residual := math.Abs(float64(ntpNanos) - predicted) + if residual > outlierThresholdStdDevs*e.residStd { + return SROutlier + } + } + + // Write into circular buffer. + e.samples[e.sampleHead] = srSample{ + unwrappedRTP: unwrapped, + ntpNanos: ntpNanos, + receivedAt: receivedAt, + } + e.sampleHead = (e.sampleHead + 1) % maxSRSamples + if e.sampleLen < maxSRSamples { + e.sampleLen++ + } + + // Recompute regression if we have enough samples. + if e.sampleLen >= minSamplesReady { + e.computeRegression() + e.ready = true + } + + return SRAccepted +} + +// IsReady returns true once at least 2 sender reports have been processed +// and the regression is valid. +func (e *NtpEstimator) IsReady() bool { + e.mu.Lock() + defer e.mu.Unlock() + return e.ready +} + +// RtpToNtp maps an RTP timestamp to wall-clock time using the current regression. +func (e *NtpEstimator) RtpToNtp(rtpTimestamp uint32) (time.Time, error) { + e.mu.Lock() + defer e.mu.Unlock() + if !e.ready { + return time.Time{}, errNotReady + } + + unwrapped := e.unwrapRTPQuery(rtpTimestamp) + ntpNanos := e.slopeNanos*(float64(unwrapped)-e.meanX) + e.meanY + return nanosToTime(int64(math.Round(ntpNanos))), nil +} + +// Slope returns the regression slope: seconds of NTP time per RTP tick. +// For a perfect clock this equals 1/clockRate. +func (e *NtpEstimator) Slope() float64 { + e.mu.Lock() + defer e.mu.Unlock() + return e.slopeNanos / 1e9 +} + +// computeRegression performs ordinary least squares on the current samples +// using centered data to preserve float64 precision. +// Model: ntpNanos = slopeNanos * (unwrappedRTP - meanX) + meanY +func (e *NtpEstimator) computeRegression() { + n := float64(e.sampleLen) + + // First pass: compute means for centering. + var sumX, sumY float64 + e.iterSamples(func(s srSample) { + sumX += float64(s.unwrappedRTP) + sumY += float64(s.ntpNanos) + }) + mX := sumX / n + mY := sumY / n + + // Second pass: compute centered sums for regression. + var sumDxDx, sumDxDy float64 + e.iterSamples(func(s srSample) { + dx := float64(s.unwrappedRTP) - mX + dy := float64(s.ntpNanos) - mY + sumDxDx += dx * dx + sumDxDy += dx * dy + }) + + if sumDxDx == 0 { + // Degenerate case: all RTP timestamps identical. + return + } + + e.slopeNanos = sumDxDy / sumDxDx + e.meanX = mX + e.meanY = mY + + // Compute residual standard deviation. + var sumResidSq float64 + e.iterSamples(func(s srSample) { + predicted := e.slopeNanos*(float64(s.unwrappedRTP)-mX) + mY + r := float64(s.ntpNanos) - predicted + sumResidSq += r * r + }) + + if e.sampleLen > 2 { + e.residStd = math.Sqrt(sumResidSq / (n - 2)) + } else { + // With exactly 2 points the regression is exact; use a small positive + // value so that the 3-sigma check is not trivially zero. + e.residStd = math.Sqrt(sumResidSq / n) + } +} + +// iterSamples calls fn for each valid sample in the circular buffer. +func (e *NtpEstimator) iterSamples(fn func(srSample)) { + start := 0 + if e.sampleLen == maxSRSamples { + start = e.sampleHead // oldest entry is at head when buffer is full + } + for i := 0; i < e.sampleLen; i++ { + idx := (start + i) % maxSRSamples + fn(e.samples[idx]) + } +} + +// unwrapRTP unwraps a 32-bit RTP timestamp to a 64-bit value, tracking +// forward/backward jumps via signed diff. This is used when ingesting SRs +// to maintain the running unwrap state. +func (e *NtpEstimator) unwrapRTP(rtpTS uint32) int64 { + if !e.hasLastRTP { + e.hasLastRTP = true + e.lastRTP = rtpTS + e.rtpOffset = 0 + return int64(rtpTS) + } + + diff := int32(rtpTS - e.lastRTP) + if diff > 0 && rtpTS < e.lastRTP { + // Forward jump that crossed the uint32 boundary. + e.rtpOffset += 1 << 32 + } else if diff < 0 && rtpTS > e.lastRTP { + // Backward jump that crossed the uint32 boundary. + e.rtpOffset -= 1 << 32 + } + + e.lastRTP = rtpTS + return e.rtpOffset + int64(rtpTS) +} + +// unwrapRTPQuery unwraps an RTP timestamp for a query (RtpToNtp) without +// mutating the unwrap state. It uses the current offset tracked from SRs. +func (e *NtpEstimator) unwrapRTPQuery(rtpTS uint32) int64 { + if !e.hasLastRTP { + return int64(rtpTS) + } + + offset := e.rtpOffset + diff := int32(rtpTS - e.lastRTP) + if diff > 0 && rtpTS < e.lastRTP { + offset += 1 << 32 + } else if diff < 0 && rtpTS > e.lastRTP { + offset -= 1 << 32 + } + return offset + int64(rtpTS) +} + +// ntpTimestampToNanos converts a 64-bit NTP timestamp to nanoseconds since +// the Unix epoch. +func ntpTimestampToNanos(ntpTS uint64) int64 { + secs := int64(ntpTS>>32) - ntpEpochOffset + frac := ntpTS & 0xFFFFFFFF + nanos := int64(frac) * 1e9 / (1 << 32) + return secs*1e9 + nanos +} + +// nanosToTime converts nanoseconds since the Unix epoch to a time.Time. +func nanosToTime(nanos int64) time.Time { + return time.Unix(0, nanos) +} diff --git a/pkg/synchronizer/ntpestimator_test.go b/pkg/synchronizer/ntpestimator_test.go new file mode 100644 index 00000000..e01f425d --- /dev/null +++ b/pkg/synchronizer/ntpestimator_test.go @@ -0,0 +1,225 @@ +// Copyright 2026 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package synchronizer + +import ( + "math" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// ntpToUint64 converts a time.Time to a 64-bit NTP timestamp. +// Upper 32 bits = seconds since NTP epoch (1900-01-01), +// lower 32 bits = fractional seconds. +func ntpToUint64(t time.Time) uint64 { + const ntpEpochOffset = 2208988800 + secs := uint64(t.Unix()) + ntpEpochOffset + frac := uint64(t.Nanosecond()) * (1 << 32) / 1e9 + return secs<<32 | frac +} + +func TestNtpEstimator_NotReadyBeforeEnoughSRs(t *testing.T) { + e := NewNtpEstimator(90000) + + require.False(t, e.IsReady(), "should not be ready with 0 SRs") + + _, err := e.RtpToNtp(1000) + require.Error(t, err, "RtpToNtp should error when not ready") + + // Feed SRs one at a time, checking readiness + now := time.Now() + for i := 0; i < minSamplesReady-1; i++ { + srTime := now.Add(time.Duration(i) * time.Second) + rtpTS := uint32(i+1) * 90000 + e.OnSenderReport(ntpToUint64(srTime), rtpTS, srTime) + require.False(t, e.IsReady(), "should not be ready with %d SRs", i+1) + } + + // One more SR makes it ready + srTime := now.Add(time.Duration(minSamplesReady-1) * time.Second) + rtpTS := uint32(minSamplesReady) * 90000 + e.OnSenderReport(ntpToUint64(srTime), rtpTS, srTime) + require.True(t, e.IsReady(), "should be ready with %d SRs", minSamplesReady) + + _, err = e.RtpToNtp(135000) + require.NoError(t, err, "RtpToNtp should succeed when ready") +} + +func TestNtpEstimator_AccurateMapping(t *testing.T) { + const clockRate = 90000 + e := NewNtpEstimator(clockRate) + + baseTime := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) + + // Feed 10 perfect SRs at 1-second intervals + for i := 0; i < 10; i++ { + wallTime := baseTime.Add(time.Duration(i) * time.Second) + rtpTS := uint32(i) * clockRate + e.OnSenderReport(ntpToUint64(wallTime), rtpTS, wallTime) + } + + require.True(t, e.IsReady()) + + // Verify mapping at intermediate points + for _, tc := range []struct { + name string + rtpTS uint32 + wantNTP time.Time + }{ + {"at SR 0", 0, baseTime}, + {"at SR 5", 5 * clockRate, baseTime.Add(5 * time.Second)}, + {"between SR 2 and 3", uint32(2.5 * clockRate), baseTime.Add(2500 * time.Millisecond)}, + {"at SR 9", 9 * clockRate, baseTime.Add(9 * time.Second)}, + } { + t.Run(tc.name, func(t *testing.T) { + got, err := e.RtpToNtp(tc.rtpTS) + require.NoError(t, err) + diff := got.Sub(tc.wantNTP) + if diff < 0 { + diff = -diff + } + require.Less(t, diff, time.Millisecond, + "mapping off by %v; got %v, want %v", diff, got, tc.wantNTP) + }) + } +} + +func TestNtpEstimator_OutlierRejection(t *testing.T) { + const clockRate = 90000 + e := NewNtpEstimator(clockRate) + + baseTime := time.Date(2025, 6, 15, 10, 0, 0, 0, time.UTC) + + // Feed 5 good SRs at 1-second intervals + for i := 0; i < 5; i++ { + wallTime := baseTime.Add(time.Duration(i) * time.Second) + rtpTS := uint32(i) * clockRate + e.OnSenderReport(ntpToUint64(wallTime), rtpTS, wallTime) + } + + require.True(t, e.IsReady()) + + // Feed 1 wildly wrong SR: RTP says 5 seconds but NTP says 50 seconds + badWallTime := baseTime.Add(50 * time.Second) + badRTP := uint32(5) * clockRate + e.OnSenderReport(ntpToUint64(badWallTime), badRTP, badWallTime) + + // Verify mapping is still accurate (outlier should have been rejected) + got, err := e.RtpToNtp(uint32(2.5 * clockRate)) + require.NoError(t, err) + want := baseTime.Add(2500 * time.Millisecond) + diff := got.Sub(want) + if diff < 0 { + diff = -diff + } + require.Less(t, diff, time.Millisecond, + "mapping should be accurate despite outlier; off by %v", diff) +} + +func TestNtpEstimator_Wraparound(t *testing.T) { + const clockRate = 90000 + e := NewNtpEstimator(clockRate) + + baseTime := time.Date(2025, 3, 1, 0, 0, 0, 0, time.UTC) + + // Start RTP near uint32 max so wraparound occurs + // math.MaxUint32 - 5*clockRate puts us 5 seconds before wrap + startRTP := uint32(math.MaxUint32 - 5*clockRate) + + for i := 0; i < 10; i++ { + wallTime := baseTime.Add(time.Duration(i) * time.Second) + rtpTS := startRTP + uint32(i)*clockRate // will wrap around uint32 + e.OnSenderReport(ntpToUint64(wallTime), rtpTS, wallTime) + } + + require.True(t, e.IsReady()) + + // Test mapping at points before and after the wraparound + // SR at i=5 is exactly where RTP wraps past 0 + for _, tc := range []struct { + name string + idx int + wantNTP time.Time + }{ + {"before wrap (i=3)", 3, baseTime.Add(3 * time.Second)}, + {"at wrap (i=5)", 5, baseTime.Add(5 * time.Second)}, + {"after wrap (i=8)", 8, baseTime.Add(8 * time.Second)}, + } { + t.Run(tc.name, func(t *testing.T) { + rtpTS := startRTP + uint32(tc.idx)*clockRate + got, err := e.RtpToNtp(rtpTS) + require.NoError(t, err) + diff := got.Sub(tc.wantNTP) + if diff < 0 { + diff = -diff + } + require.Less(t, diff, time.Millisecond, + "mapping off by %v across wraparound; got %v, want %v", diff, got, tc.wantNTP) + }) + } +} + +func TestNtpEstimator_SlidingWindow(t *testing.T) { + const clockRate = 90000 + e := NewNtpEstimator(clockRate) + + baseTime := time.Date(2025, 4, 1, 0, 0, 0, 0, time.UTC) + + // Feed 25 SRs (exceeds window of 20) + for i := 0; i < 25; i++ { + wallTime := baseTime.Add(time.Duration(i) * time.Second) + rtpTS := uint32(i) * clockRate + e.OnSenderReport(ntpToUint64(wallTime), rtpTS, wallTime) + } + + require.True(t, e.IsReady()) + + // Verify mapping still works accurately in the recent window + got, err := e.RtpToNtp(uint32(22) * clockRate) + require.NoError(t, err) + want := baseTime.Add(22 * time.Second) + diff := got.Sub(want) + if diff < 0 { + diff = -diff + } + require.Less(t, diff, time.Millisecond, + "mapping should be accurate after sliding window overflow; off by %v", diff) +} + +func TestNtpEstimator_Slope(t *testing.T) { + const clockRate = 90000 + e := NewNtpEstimator(clockRate) + + baseTime := time.Date(2025, 5, 1, 0, 0, 0, 0, time.UTC) + + // Feed 5 perfect SRs + for i := 0; i < 5; i++ { + wallTime := baseTime.Add(time.Duration(i) * time.Second) + rtpTS := uint32(i) * clockRate + e.OnSenderReport(ntpToUint64(wallTime), rtpTS, wallTime) + } + + require.True(t, e.IsReady()) + + // Slope should be close to 1/clockRate (seconds per RTP tick) + expectedSlope := 1.0 / float64(clockRate) + gotSlope := e.Slope() + + relError := math.Abs(gotSlope-expectedSlope) / expectedSlope + require.Less(t, relError, 1e-6, + "slope should be ~%e, got %e (relative error %e)", expectedSlope, gotSlope, relError) +} diff --git a/pkg/synchronizer/participantclock.go b/pkg/synchronizer/participantclock.go new file mode 100644 index 00000000..2c067cf4 --- /dev/null +++ b/pkg/synchronizer/participantclock.go @@ -0,0 +1,132 @@ +// Copyright 2026 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package synchronizer + +import ( + "sync" + "time" + + "github.com/livekit/mediatransportutil/pkg/latency" + "github.com/livekit/protocol/logger" +) + +// ParticipantClock holds OWD and NTP estimation state for a single participant. +type ParticipantClock struct { + mu sync.Mutex + logger logger.Logger + owdEstimator *latency.OWDEstimator + tracks map[string]*NtpEstimator + ntpEpoch time.Time // NTP time from first SR + hasEpoch bool +} + +// NewParticipantClock creates a new ParticipantClock. +func NewParticipantClock(l logger.Logger) *ParticipantClock { + return &ParticipantClock{ + logger: l, + owdEstimator: latency.NewOWDEstimator(latency.OWDEstimatorParamsDefault), + tracks: make(map[string]*NtpEstimator), + } +} + +// OnSenderReport processes an RTCP sender report for a track. +// It updates the NTP estimator, OWD estimator, and records the NTP epoch. +func (pc *ParticipantClock) OnSenderReport(trackID string, clockRate uint32, ntpTime uint64, rtpTimestamp uint32, receivedAt time.Time) { + pc.mu.Lock() + defer pc.mu.Unlock() + + est, ok := pc.tracks[trackID] + if !ok { + est = NewNtpEstimator(clockRate) + pc.tracks[trackID] = est + } + + result := est.OnSenderReport(ntpTime, rtpTimestamp, receivedAt) + if result == SROutlier && pc.logger != nil { + pc.logger.Warnw("sender report rejected as outlier", nil, + "trackID", trackID, + "rtpTimestamp", rtpTimestamp, + "ntpTime", ntpTime, + ) + } + if result != SRAccepted { + return + } + + senderNtpNanos := ntpTimestampToNanos(ntpTime) + pc.owdEstimator.Update(senderNtpNanos, receivedAt.UnixNano()) + + if !pc.hasEpoch { + pc.ntpEpoch = nanosToTime(senderNtpNanos) + pc.hasEpoch = true + } +} + +// RtpToReceiverClock maps an RTP timestamp to a time on the receiver's clock. +// The result is ntpTime + estimatedOWD, which places the sender's NTP time +// into the receiver's clock domain. +func (pc *ParticipantClock) RtpToReceiverClock(trackID string, rtpTimestamp uint32) (time.Time, error) { + pc.mu.Lock() + defer pc.mu.Unlock() + + est, ok := pc.tracks[trackID] + if !ok { + return time.Time{}, errNoSenderReports + } + + if !est.IsReady() { + return time.Time{}, errNotReady + } + + if !pc.hasEpoch { + return time.Time{}, errNoSenderReports + } + + ntpTime, err := est.RtpToNtp(rtpTimestamp) + if err != nil { + return time.Time{}, err + } + + estimatedOWD := time.Duration(pc.owdEstimator.EstimatedPropagationDelay()) + return ntpTime.Add(estimatedOWD), nil +} + +// ResetTrack clears the NTP estimator for a track, forcing it to rebuild +// from new sender reports. Used when a stream discontinuity is detected. +func (pc *ParticipantClock) ResetTrack(trackID string) { + pc.mu.Lock() + defer pc.mu.Unlock() + + if est, ok := pc.tracks[trackID]; ok { + est.Reset() + } +} + +// RemoveTrack removes a track. +func (pc *ParticipantClock) RemoveTrack(trackID string) { + pc.mu.Lock() + defer pc.mu.Unlock() + + delete(pc.tracks, trackID) +} + +// HasTrack returns true if the participant has a track with the given ID. +func (pc *ParticipantClock) HasTrack(trackID string) bool { + pc.mu.Lock() + defer pc.mu.Unlock() + + _, ok := pc.tracks[trackID] + return ok +} diff --git a/pkg/synchronizer/participantclock_test.go b/pkg/synchronizer/participantclock_test.go new file mode 100644 index 00000000..31df7318 --- /dev/null +++ b/pkg/synchronizer/participantclock_test.go @@ -0,0 +1,55 @@ +// Copyright 2026 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package synchronizer + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// readyEstimator creates an NtpEstimator pre-loaded with `count` sender reports +// so that it is ready for use. The SR samples are spaced 5 seconds apart in both +// NTP and RTP time. +func readyEstimator(clockRate uint32, baseNtp time.Time, baseRtp uint32, count int) *NtpEstimator { + e := NewNtpEstimator(clockRate) + for i := 0; i < count; i++ { + ntpTime := baseNtp.Add(time.Duration(i) * 5 * time.Second) + rtpTS := baseRtp + uint32(i)*uint32(clockRate)*5 + e.OnSenderReport(ntpToUint64(ntpTime), rtpTS, ntpTime.Add(30*time.Millisecond)) + } + return e +} + +func TestParticipantClock_RemoveTrack(t *testing.T) { + st := NewSessionTimeline(nil) + st.AddParticipant("alice") + + baseNtp := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) + // Feed SRs to create the track estimator via the timeline. + for i := 0; i < 5; i++ { + ntpTime := baseNtp.Add(time.Duration(i) * 5 * time.Second) + rtpTS := uint32(i) * 5 * 48000 + st.OnSenderReport("alice", "audio-1", 48000, ntpToUint64(ntpTime), rtpTS, ntpTime.Add(30*time.Millisecond)) + } + + pc := st.GetParticipantClock("alice") + require.NotNil(t, pc) + require.True(t, pc.HasTrack("audio-1")) + + pc.RemoveTrack("audio-1") + require.False(t, pc.HasTrack("audio-1")) +} diff --git a/pkg/synchronizer/sessiontimeline.go b/pkg/synchronizer/sessiontimeline.go new file mode 100644 index 00000000..fea25f5b --- /dev/null +++ b/pkg/synchronizer/sessiontimeline.go @@ -0,0 +1,175 @@ +// Copyright 2026 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package synchronizer + +import ( + "errors" + "fmt" + "sync" + "time" + + "github.com/livekit/protocol/logger" +) + +var ( + errNoSenderReports = errors.New("SessionTimeline: no sender reports received for track") + errNoSessionStart = errors.New("SessionTimeline: session start time not set") +) + +// SessionTimeline establishes a shared recording timeline and maps each +// participant's NTP clock domain onto it using OWD (one-way delay) +// normalization. This is the key component that fixes cross-participant +// misalignment. +// +// Algorithm: +// 1. Each SR provides a pair: (senderNtpTime, receivedAtWallClock). The +// difference is the one-way delay (OWD). +// 2. Using the OWDEstimator, estimate each participant's OWD. The min +// observed OWD approximates true propagation delay. +// 3. To map a participant's RTP timestamp to the session timeline: +// sessionPTS = ntpTime + estimatedOWD - sessionStart +type SessionTimeline struct { + mu sync.RWMutex + logger logger.Logger + participants map[string]*ParticipantClock + sessionStart time.Time + hasStart bool +} + +// NewSessionTimeline creates a new SessionTimeline. +func NewSessionTimeline(l logger.Logger) *SessionTimeline { + return &SessionTimeline{ + logger: l, + participants: make(map[string]*ParticipantClock), + } +} + +// SetSessionStart sets the session start time (wall-clock time when the first +// packet of any track arrived at the receiver). +func (st *SessionTimeline) SetSessionStart(t time.Time) { + st.mu.Lock() + defer st.mu.Unlock() + + st.sessionStart = t + st.hasStart = true +} + +// AddParticipant registers a new participant with the given participantID. +func (st *SessionTimeline) AddParticipant(participantID string) *ParticipantClock { + st.mu.Lock() + defer st.mu.Unlock() + + pc := NewParticipantClock(st.logger) + st.participants[participantID] = pc + return pc +} + +// GetOrAddParticipant returns the ParticipantClock for the given participantID, +// creating one if it doesn't exist. This is safe for concurrent use. +func (st *SessionTimeline) GetOrAddParticipant(participantID string) *ParticipantClock { + st.mu.Lock() + defer st.mu.Unlock() + + if pc, ok := st.participants[participantID]; ok { + return pc + } + + pc := NewParticipantClock(st.logger) + st.participants[participantID] = pc + return pc +} + +// GetParticipantClock returns the ParticipantClock for a participant, or nil. +func (st *SessionTimeline) GetParticipantClock(participantID string) *ParticipantClock { + st.mu.RLock() + defer st.mu.RUnlock() + + return st.participants[participantID] +} + +// RemoveParticipant removes the participant with the given participantID. +func (st *SessionTimeline) RemoveParticipant(participantID string) { + st.mu.Lock() + defer st.mu.Unlock() + + delete(st.participants, participantID) +} + +// ResetTrack clears the NTP estimator for a track, forcing it to rebuild from +// new sender reports. Used when a stream discontinuity is detected. +func (st *SessionTimeline) ResetTrack(participantID, trackID string) { + st.mu.RLock() + pc, ok := st.participants[participantID] + st.mu.RUnlock() + + if ok { + pc.ResetTrack(trackID) + } +} + +// OnSenderReport processes an RTCP sender report for a participant's track. +// It delegates to the ParticipantClock to update the NTP estimator, OWD +// estimator, and NTP epoch. +func (st *SessionTimeline) OnSenderReport(participantID, trackID string, clockRate uint32, ntpTime uint64, rtpTimestamp uint32, receivedAt time.Time) { + st.mu.RLock() + pc, ok := st.participants[participantID] + st.mu.RUnlock() + + if !ok { + return + } + + pc.OnSenderReport(trackID, clockRate, ntpTime, rtpTimestamp, receivedAt) +} + +// GetSessionPTS maps an RTP timestamp for a participant's track to a position +// on the shared session timeline. +// +// The formula is: sessionPTS = ntpTime + estimatedOWD - sessionStart +func (st *SessionTimeline) GetSessionPTS(participantID, trackID string, rtpTimestamp uint32) (time.Duration, error) { + st.mu.RLock() + if !st.hasStart { + st.mu.RUnlock() + return 0, errNoSessionStart + } + pc, ok := st.participants[participantID] + sessionStart := st.sessionStart + st.mu.RUnlock() + + if !ok { + return 0, fmt.Errorf("SessionTimeline: unknown participant %q", participantID) + } + + receiverTime, err := pc.RtpToReceiverClock(trackID, rtpTimestamp) + if err != nil { + return 0, err + } + + sessionPTS := receiverTime.Sub(sessionStart) + + if (sessionPTS < 0 || sessionPTS > 24*time.Hour) && st.logger != nil { + st.logger.Warnw("GetSessionPTS: abnormal result", + nil, + "participantID", participantID, + "trackID", trackID, + "rtpTimestamp", rtpTimestamp, + "receiverTime", receiverTime, + "sessionStart", sessionStart, + "sessionPTS", sessionPTS, + ) + } + + return sessionPTS, nil +} diff --git a/pkg/synchronizer/sessiontimeline_test.go b/pkg/synchronizer/sessiontimeline_test.go new file mode 100644 index 00000000..415825c1 --- /dev/null +++ b/pkg/synchronizer/sessiontimeline_test.go @@ -0,0 +1,223 @@ +// Copyright 2026 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package synchronizer + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestSessionTimeline_SingleParticipant(t *testing.T) { + // One participant with 50ms OWD, feed 5 SRs, verify PTS at 10s is ~10s. + const ( + clockRate = 90000 + owd = 50 * time.Millisecond + identity = "alice" + trackID = "audio-1" + ) + + st := NewSessionTimeline(nil) + + // Session starts at a fixed wall-clock time. + sessionStart := time.Date(2025, 6, 1, 12, 0, 0, 0, time.UTC) + st.SetSessionStart(sessionStart) + + st.AddParticipant(identity) + + // The participant's NTP clock is offset from wall clock by OWD. + // senderNtpTime + OWD = receivedAt (approximately). + // So receivedAt = senderNtpTime + OWD. + baseNTP := sessionStart // participant's NTP epoch starts at sessionStart + for i := 0; i < 5; i++ { + senderNTP := baseNTP.Add(time.Duration(i) * 2 * time.Second) + rtpTS := uint32(i) * 2 * clockRate + receivedAt := senderNTP.Add(owd) + st.OnSenderReport(identity, trackID, clockRate, ntpToUint64(senderNTP), rtpTS, receivedAt) + } + + // Query PTS at RTP timestamp corresponding to 10s into the stream. + rtpAt10s := uint32(10 * clockRate) + pts, err := st.GetSessionPTS(identity, trackID, rtpAt10s) + require.NoError(t, err) + + // Expected: ~10s on the session timeline. + diff := pts - 10*time.Second + if diff < 0 { + diff = -diff + } + require.Less(t, diff, 100*time.Millisecond, + "PTS at 10s should be ~10s, got %v (diff %v)", pts, diff) +} + +func TestSessionTimeline_CrossParticipantAlignment(t *testing.T) { + // Two participants with different OWDs (50ms and 200ms), both producing + // media at the same real-world time. The SessionTimeline maps each + // participant's NTP clock domain onto the receiver's clock using OWD. + // + // Because both start producing at the same real-world time but have + // different network path delays, the receiver-clock-based timeline + // correctly reflects the OWD difference: bob's media arrives 150ms + // later than alice's for the same production instant. + // + // Additionally, we verify that NTP clock offset differences between + // participants are properly normalized via the OWD mapping: if bob's + // NTP clock is offset by +500ms relative to alice's, the OWD estimator + // absorbs this, and the session PTS still reflects the real receiver-clock + // arrival times. + const ( + clockRate = 90000 + owd1 = 50 * time.Millisecond // alice's real network delay + owd2 = 50 * time.Millisecond // bob's real network delay (same) + ) + + // Bob's NTP clock is offset by 500ms relative to alice's (different NTP servers). + bobNTPOffset := 500 * time.Millisecond + + st := NewSessionTimeline(nil) + sessionStart := time.Date(2025, 6, 1, 12, 0, 0, 0, time.UTC) + st.SetSessionStart(sessionStart) + + st.AddParticipant("alice") + st.AddParticipant("bob") + + // Both participants start producing at the same real-world time. + // Alice's NTP clock = real-world time. + // Bob's NTP clock = real-world time + 500ms (NTP offset). + // Both have the same real OWD of 50ms, so: + // receivedAt = realWorldTime + owd + // aliceNTP = realWorldTime + // bobNTP = realWorldTime + 500ms + // OWD as seen by estimator: + // alice: receivedAt - aliceNTP = owd = 50ms + // bob: receivedAt - bobNTP = owd - 500ms = -450ms (negative! but the estimator handles this) + // + // Actually, OWD = receivedAt - senderNTP. For bob: + // receivedAt = realWorldTime + 50ms + // senderNTP = realWorldTime + 500ms + // observed OWD = (realWorldTime + 50ms) - (realWorldTime + 500ms) = -450ms + // + // This negative OWD is fine - it just means bob's NTP clock is ahead of + // the receiver's clock by more than the real OWD. The formula still works + // because: ntpTime + OWD - sessionStart = (realWorldTime + 500ms) + (-450ms) - sessionStart + // = realWorldTime + 50ms - sessionStart + // Which matches alice's: realWorldTime + 50ms - sessionStart + + for i := 0; i < 5; i++ { + realTime := sessionStart.Add(time.Duration(i) * 2 * time.Second) + rtpTS := uint32(i) * 2 * clockRate + receivedAt := realTime.Add(owd1) + + aliceNTP := realTime // alice NTP = real time + st.OnSenderReport("alice", "audio-a", clockRate, ntpToUint64(aliceNTP), rtpTS, receivedAt) + + bobNTP := realTime.Add(bobNTPOffset) // bob NTP = real time + offset + bobRecv := realTime.Add(owd2) + st.OnSenderReport("bob", "audio-b", clockRate, ntpToUint64(bobNTP), rtpTS, bobRecv) + } + + // Both participants produce a frame at RTP timestamp corresponding to 5s. + rtpAt5s := uint32(5 * clockRate) + + alicePTS, err := st.GetSessionPTS("alice", "audio-a", rtpAt5s) + require.NoError(t, err) + + bobPTS, err := st.GetSessionPTS("bob", "audio-b", rtpAt5s) + require.NoError(t, err) + + // Despite bob's NTP clock being 500ms offset, the OWD-based mapping + // normalizes both to the receiver's clock domain. Their PTS values + // should be within a small tolerance. + diff := alicePTS - bobPTS + if diff < 0 { + diff = -diff + } + require.Less(t, diff, 50*time.Millisecond, + "cross-participant PTS should be aligned despite NTP clock offset; alice=%v bob=%v diff=%v", alicePTS, bobPTS, diff) +} + +func TestSessionTimeline_LateJoiner(t *testing.T) { + // One participant starts, 30s later another joins. + // Verify the late joiner's first frame maps to ~30s on the session timeline. + const ( + clockRate = 90000 + owd = 50 * time.Millisecond + ) + + st := NewSessionTimeline(nil) + sessionStart := time.Date(2025, 6, 1, 12, 0, 0, 0, time.UTC) + st.SetSessionStart(sessionStart) + + // Alice joins at session start. + st.AddParticipant("alice") + aliceBaseNTP := sessionStart + for i := 0; i < 5; i++ { + aliceNTP := aliceBaseNTP.Add(time.Duration(i) * 2 * time.Second) + aliceRTP := uint32(i) * 2 * clockRate + aliceRecv := aliceNTP.Add(owd) + st.OnSenderReport("alice", "audio-a", clockRate, ntpToUint64(aliceNTP), aliceRTP, aliceRecv) + } + + // Bob joins 30s later. + st.AddParticipant("bob") + bobBaseNTP := sessionStart.Add(30 * time.Second) + for i := 0; i < 5; i++ { + bobNTP := bobBaseNTP.Add(time.Duration(i) * 2 * time.Second) + bobRTP := uint32(i) * 2 * clockRate + bobRecv := bobNTP.Add(owd) + st.OnSenderReport("bob", "audio-b", clockRate, ntpToUint64(bobNTP), bobRTP, bobRecv) + } + + // Bob's first frame (RTP=0) should map to ~30s on session timeline. + bobPTS, err := st.GetSessionPTS("bob", "audio-b", 0) + require.NoError(t, err) + + diff := bobPTS - 30*time.Second + if diff < 0 { + diff = -diff + } + require.Less(t, diff, 100*time.Millisecond, + "late joiner's first frame should be at ~30s; got %v (diff %v)", bobPTS, diff) + + // Alice's first frame should be at ~0s. + alicePTS, err := st.GetSessionPTS("alice", "audio-a", 0) + require.NoError(t, err) + + diff = alicePTS + if diff < 0 { + diff = -diff + } + require.Less(t, diff, 100*time.Millisecond, + "first participant's first frame should be at ~0s; got %v", alicePTS) +} + +func TestSessionTimeline_FallbackBeforeSRs(t *testing.T) { + // Verify error when no SRs received. + st := NewSessionTimeline(nil) + sessionStart := time.Date(2025, 6, 1, 12, 0, 0, 0, time.UTC) + st.SetSessionStart(sessionStart) + + st.AddParticipant("alice") + + // No SRs have been received: should return error. + _, err := st.GetSessionPTS("alice", "audio-a", 1000) + require.Error(t, err) + require.ErrorIs(t, err, errNoSenderReports) + + // Unknown participant should also error. + _, err = st.GetSessionPTS("unknown", "track-x", 1000) + require.Error(t, err) +} diff --git a/pkg/synchronizer/syncengine.go b/pkg/synchronizer/syncengine.go new file mode 100644 index 00000000..0b9849c8 --- /dev/null +++ b/pkg/synchronizer/syncengine.go @@ -0,0 +1,329 @@ +// Copyright 2026 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package synchronizer + +import ( + "sync" + "sync/atomic" + "time" + + "github.com/pion/rtcp" + + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils/rtputil" +) + +const ( + // defaultOldPacketThreshold is the default age after which packets are dropped. + defaultOldPacketThreshold = 500 * time.Millisecond +) + +// SyncEngineOption configures a SyncEngine. +type SyncEngineOption func(*SyncEngine) + +// WithSyncEngineOnStarted sets a callback invoked once the first track is initialized. +func WithSyncEngineOnStarted(f func()) SyncEngineOption { + return func(e *SyncEngine) { + e.onStarted = f + } +} + +// WithSyncEngineStartGate enables the burst-estimation start gate on all tracks. +func WithSyncEngineStartGate() SyncEngineOption { + return func(e *SyncEngine) { + e.enableStartGate = true + } +} + +// WithSyncEngineOldPacketThreshold sets the age after which packets are dropped. +// Zero disables the check. +func WithSyncEngineOldPacketThreshold(d time.Duration) SyncEngineOption { + return func(e *SyncEngine) { + e.oldPacketThreshold = d + } +} + +// WithSyncEngineMediaRunningTime sets the initial media running time provider and max delay. +// If a track's PTS falls behind the deadline by more than maxDelay for >10s, PTS is force-corrected. +func WithSyncEngineMediaRunningTime(mediaRunningTime func() (time.Duration, bool), maxDelay time.Duration) SyncEngineOption { + return func(e *SyncEngine) { + e.mediaRunningTime = mediaRunningTime + e.maxMediaRunningTimeDelay = maxDelay + } +} + +// WithSyncEngineLogger sets the logger for the sync engine and all sub-components. +func WithSyncEngineLogger(l logger.Logger) SyncEngineOption { + return func(e *SyncEngine) { + e.logger = l + } +} + +// WithSyncEngineAudioDriftCompensated signals that audio drift is handled +// externally (e.g., by a tempo controller) and the sync engine should not +// apply NTP PTS corrections to audio tracks. NTP regression still runs for +// drift measurement and reporting. +func WithSyncEngineAudioDriftCompensated() SyncEngineOption { + return func(e *SyncEngine) { + e.audioDriftCompensated = true + } +} + +// SyncEngine orchestrates NtpEstimator, ParticipantClock, and SessionTimeline +// to provide cross-participant alignment and per-participant A/V lip sync. +// It implements the Sync interface. +type SyncEngine struct { + mu sync.Mutex + timeline *SessionTimeline + tracks map[uint32]*syncEngineTrack // keyed by SSRC + trackIDs map[string]*syncEngineTrack // keyed by track ID + + startedAt atomic.Int64 + endedAt atomic.Int64 + + // high-water mark for removed tracks, so End() includes their PTS + maxRemovedPTS time.Duration + + logger logger.Logger + enableStartGate bool + oldPacketThreshold time.Duration + audioDriftCompensated bool // audio drift handled externally (e.g., tempo controller) + onStarted func() + + mediaRunningTime func() (time.Duration, bool) + maxMediaRunningTimeDelay time.Duration + mediaRunningTimeLock sync.RWMutex +} + +// NewSyncEngine creates a new SyncEngine with the given options. +func NewSyncEngine(opts ...SyncEngineOption) *SyncEngine { + e := &SyncEngine{ + tracks: make(map[uint32]*syncEngineTrack), + trackIDs: make(map[string]*syncEngineTrack), + oldPacketThreshold: defaultOldPacketThreshold, + } + for _, opt := range opts { + opt(e) + } + e.timeline = NewSessionTimeline(e.logger) + return e +} + +// AddTrack registers a new track and returns a TrackSync handle. +func (e *SyncEngine) AddTrack(track TrackRemote, participantID string) TrackSync { + ssrc := uint32(track.SSRC()) + clockRate := track.Codec().ClockRate + + e.mu.Lock() + defer e.mu.Unlock() + + // Ensure the participant exists in the timeline. + e.timeline.GetOrAddParticipant(participantID) + + st := &syncEngineTrack{ + engine: e, + track: track, + participantID: participantID, + logger: e.getTrackLogger(track), + converter: rtputil.NewRTPConverter(int64(clockRate)), + } + + if e.enableStartGate { + st.startGate = newStartGate(clockRate, track.Kind(), nil) + } + + e.tracks[ssrc] = st + e.trackIDs[track.ID()] = st + + return st +} + +// RemoveTrack removes a track by track ID. +func (e *SyncEngine) RemoveTrack(trackID string) { + e.mu.Lock() + st, ok := e.trackIDs[trackID] + if !ok { + e.mu.Unlock() + return + } + + // Preserve removed track's PTS high-water mark so End() includes it. + st.mu.Lock() + if st.lastPTSAdjusted > e.maxRemovedPTS { + e.maxRemovedPTS = st.lastPTSAdjusted + } + st.mu.Unlock() + + ssrc := uint32(st.track.SSRC()) + delete(e.tracks, ssrc) + delete(e.trackIDs, trackID) + e.mu.Unlock() + + // Clean up track from participant, and remove the participant from the + // timeline if this was their last track. + participantID := st.participantID + if pc := e.timeline.GetParticipantClock(participantID); pc != nil { + pc.RemoveTrack(trackID) + } + if !e.hasTracksForParticipant(participantID) { + e.timeline.RemoveParticipant(participantID) + } + + st.logger.Infow("track removed", "lastPTS", st.lastPTSAdjusted) + st.Close() +} + +// OnRTCP processes an RTCP packet, dispatching sender reports to the appropriate +// track's NTP estimator and ParticipantClock. +func (e *SyncEngine) OnRTCP(packet rtcp.Packet) { + sr, ok := packet.(*rtcp.SenderReport) + if !ok { + return + } + + e.mu.Lock() + st, ok := e.tracks[sr.SSRC] + if !ok { + e.mu.Unlock() + return + } + participantID := st.participantID + trackID := st.track.ID() + clockRate := st.track.Codec().ClockRate + e.mu.Unlock() + + now := time.Now() + + // Feed the SR to the session timeline (updates NTP estimator + OWD). + e.timeline.OnSenderReport(participantID, trackID, clockRate, sr.NTPTime, sr.RTPTime, now) + + // Call onSR callback if set. + st.mu.Lock() + onSR := st.onSR + st.mu.Unlock() + + if onSR != nil { + // Compute drift using OWD-normalized session PTS (not raw NTP, which + // includes the sender's clock offset and would produce phantom drift + // if the sender's NTP clock adjusts during the recording). + startedAt := e.startedAt.Load() + if startedAt > 0 { + sessionPTS, err := e.timeline.GetSessionPTS(participantID, trackID, sr.RTPTime) + if err == nil { + sessionStart := time.Unix(0, startedAt) + expectedElapsed := now.Sub(sessionStart) + drift := sessionPTS - expectedElapsed + st.logger.Debugw("sender report", + "drift", drift, + "sessionPTS", sessionPTS, + "expectedElapsed", expectedElapsed, + ) + onSR(drift) + } + } + } +} + +// End signals the end of the session and sets drain ceilings on all tracks. +func (e *SyncEngine) End() { + e.mu.Lock() + defer e.mu.Unlock() + + // Start from the high-water mark of removed tracks. + maxPTS := e.maxRemovedPTS + for _, st := range e.tracks { + st.mu.Lock() + if st.lastPTSAdjusted > maxPTS { + maxPTS = st.lastPTSAdjusted + } + st.mu.Unlock() + } + + startedAt := e.startedAt.Load() + if startedAt > 0 { + e.endedAt.Store(startedAt + int64(maxPTS)) + } else { + e.endedAt.Store(time.Now().UnixNano()) + } + + // Set drain ceiling on all tracks. + for _, st := range e.tracks { + st.mu.Lock() + st.maxPTS = maxPTS + st.maxPTSSet = true + st.mu.Unlock() + } +} + +// GetStartedAt returns the start timestamp in nanoseconds, or 0 if not started. +func (e *SyncEngine) GetStartedAt() int64 { + return e.startedAt.Load() +} + +// GetEndedAt returns the end timestamp in nanoseconds, or 0 if not ended. +func (e *SyncEngine) GetEndedAt() int64 { + return e.endedAt.Load() +} + +// SetMediaRunningTime sets the external media running time provider. +func (e *SyncEngine) SetMediaRunningTime(mediaRunningTime func() (time.Duration, bool)) { + e.mediaRunningTimeLock.Lock() + e.mediaRunningTime = mediaRunningTime + e.mediaRunningTimeLock.Unlock() +} + +// getMediaDeadline returns the current pipeline deadline, or false if unavailable. +func (e *SyncEngine) getMediaDeadline() (time.Duration, bool) { + e.mediaRunningTimeLock.RLock() + fn := e.mediaRunningTime + e.mediaRunningTimeLock.RUnlock() + if fn == nil { + return 0, false + } + return fn() +} + +// initializeIfNeeded sets the session start time and fires the onStarted callback +// on the first track initialization. Returns the startedAt value. +func (e *SyncEngine) initializeIfNeeded(receivedAt time.Time) int64 { + nano := receivedAt.UnixNano() + if e.startedAt.CompareAndSwap(0, nano) { + e.timeline.SetSessionStart(receivedAt) + if e.onStarted != nil { + e.onStarted() + } + } + return e.startedAt.Load() +} + +// hasTracksForParticipant returns true if any remaining track belongs to the +// given participant participantID. Caller must NOT hold e.mu. +func (e *SyncEngine) hasTracksForParticipant(participantID string) bool { + e.mu.Lock() + defer e.mu.Unlock() + for _, st := range e.tracks { + if st.participantID == participantID { + return true + } + } + return false +} + +func (e *SyncEngine) getTrackLogger(track TrackRemote) logger.Logger { + if e.logger != nil { + return e.logger.WithValues("trackID", track.ID(), "kind", track.Kind().String()) + } + return logger.GetLogger().WithValues("trackID", track.ID(), "kind", track.Kind().String(), "syncEngine", true) +} diff --git a/pkg/synchronizer/syncengine_test.go b/pkg/synchronizer/syncengine_test.go new file mode 100644 index 00000000..7cf317c7 --- /dev/null +++ b/pkg/synchronizer/syncengine_test.go @@ -0,0 +1,197 @@ +// Copyright 2026 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package synchronizer + +import ( + "testing" + "time" + + "github.com/pion/rtcp" + "github.com/pion/rtp" + "github.com/pion/webrtc/v4" + "github.com/stretchr/testify/require" + + "github.com/livekit/media-sdk/jitter" +) + +// --- Test helpers --- + +type mockTrackRemote struct { + id string + codec webrtc.RTPCodecParameters + kind webrtc.RTPCodecType + ssrc webrtc.SSRC +} + +func (m *mockTrackRemote) ID() string { return m.id } +func (m *mockTrackRemote) Codec() webrtc.RTPCodecParameters { return m.codec } +func (m *mockTrackRemote) Kind() webrtc.RTPCodecType { return m.kind } +func (m *mockTrackRemote) SSRC() webrtc.SSRC { return m.ssrc } + +func newMockAudioTrack(id string, ssrc uint32) *mockTrackRemote { + return &mockTrackRemote{ + id: id, + codec: webrtc.RTPCodecParameters{RTPCodecCapability: webrtc.RTPCodecCapability{ClockRate: 48000}}, + kind: webrtc.RTPCodecTypeAudio, + ssrc: webrtc.SSRC(ssrc), + } +} + +func newMockVideoTrack(id string, ssrc uint32) *mockTrackRemote { + return &mockTrackRemote{ + id: id, + codec: webrtc.RTPCodecParameters{RTPCodecCapability: webrtc.RTPCodecCapability{ClockRate: 90000}}, + kind: webrtc.RTPCodecTypeVideo, + ssrc: webrtc.SSRC(ssrc), + } +} + +func makeExtPacket(ts uint32, sn uint16, receivedAt time.Time) jitter.ExtPacket { + return jitter.ExtPacket{ + ReceivedAt: receivedAt, + Packet: &rtp.Packet{Header: rtp.Header{Timestamp: ts, SequenceNumber: sn}}, + } +} + +// --- Tests --- + +func TestSyncEngine_ImplementsSyncInterface(t *testing.T) { + // Compile-time check that SyncEngine implements Sync. + var _ Sync = (*SyncEngine)(nil) +} + +func TestSyncEngine_FallbackToWallClockBeforeSRs(t *testing.T) { + engine := NewSyncEngine() + + track := newMockAudioTrack("audio-1", 1000) + ts := engine.AddTrack(track, "alice") + + now := time.Now() + + // Prime the track with the first packet. + pkt0 := makeExtPacket(0, 0, now) + _, _, done := ts.PrimeForStart(pkt0) + require.True(t, done, "without start gate, track should be ready immediately") + + // Get PTS for first packet (same as prime packet). + pts0, err := ts.GetPTS(pkt0) + require.NoError(t, err) + require.GreaterOrEqual(t, int64(pts0), int64(0), "first PTS should be >= 0") + + // Second packet 100ms later. + pkt1 := makeExtPacket(4800, 1, now.Add(100*time.Millisecond)) + pts1, err := ts.GetPTS(pkt1) + require.NoError(t, err) + require.Greater(t, int64(pts1), int64(0), "second packet PTS should be > 0") + require.Greater(t, int64(pts1), int64(pts0), "PTS should advance") +} + +func TestSyncEngine_TransitionsToNTPAfterSRs(t *testing.T) { + engine := NewSyncEngine() + + track := newMockAudioTrack("audio-1", 1000) + ts := engine.AddTrack(track, "alice") + + now := time.Now() + + // Prime and get initial wall-clock PTS. + pkt0 := makeExtPacket(0, 0, now) + ts.PrimeForStart(pkt0) + pts0, err := ts.GetPTS(pkt0) + require.NoError(t, err) + + // Get a wall-clock PTS at 500ms. + pkt1 := makeExtPacket(24000, 1, now.Add(500*time.Millisecond)) + pts1, err := ts.GetPTS(pkt1) + require.NoError(t, err) + require.Greater(t, int64(pts1), int64(pts0)) + + // Feed 3 sender reports to make NTP estimator ready. + for i := 0; i < 3; i++ { + srTime := now.Add(time.Duration(i) * time.Second) + rtpTS := uint32(i) * 48000 + ntpTime := ntpToUint64(srTime) + sr := makeSenderReport(1000, ntpTime, rtpTS) + engine.OnRTCP(sr) + } + + // Get PTS after NTP transition - should still be valid and advancing. + pkt2 := makeExtPacket(48000, 2, now.Add(time.Second)) + pts2, err := ts.GetPTS(pkt2) + require.NoError(t, err) + require.Greater(t, int64(pts2), int64(pts1), "PTS should continue to advance after NTP transition") +} + +func TestSyncEngine_MonotonicPTS(t *testing.T) { + engine := NewSyncEngine() + + track := newMockAudioTrack("audio-1", 1000) + ts := engine.AddTrack(track, "alice") + + now := time.Now() + + // Prime with first packet. + pkt0 := makeExtPacket(0, 0, now) + ts.PrimeForStart(pkt0) + + var lastPTS time.Duration + for i := 0; i < 100; i++ { + recvAt := now.Add(time.Duration(i) * 20 * time.Millisecond) + rtpTS := uint32(i) * 960 // 20ms at 48kHz + pkt := makeExtPacket(rtpTS, uint16(i), recvAt) + pts, err := ts.GetPTS(pkt) + require.NoError(t, err) + require.GreaterOrEqual(t, int64(pts), int64(lastPTS), + "PTS must be monotonically non-decreasing: packet %d got %v, last was %v", i, pts, lastPTS) + lastPTS = pts + } +} + +func TestSyncEngine_EndDrain(t *testing.T) { + engine := NewSyncEngine() + + track := newMockAudioTrack("audio-1", 1000) + ts := engine.AddTrack(track, "alice") + + now := time.Now() + + // Prime and push some packets. + pkt0 := makeExtPacket(0, 0, now) + ts.PrimeForStart(pkt0) + ts.GetPTS(pkt0) + + for i := 1; i <= 10; i++ { + recvAt := now.Add(time.Duration(i) * 20 * time.Millisecond) + rtpTS := uint32(i) * 960 + pkt := makeExtPacket(rtpTS, uint16(i), recvAt) + ts.GetPTS(pkt) + } + + require.Equal(t, int64(0), engine.GetEndedAt(), "endedAt should be 0 before End()") + + engine.End() + + require.Greater(t, engine.GetEndedAt(), int64(0), "endedAt should be > 0 after End()") + require.Greater(t, engine.GetStartedAt(), int64(0), "startedAt should be > 0") +} + +// makeSenderReport creates an rtcp.SenderReport with the given fields. +func makeSenderReport(ssrc uint32, ntpTime uint64, rtpTime uint32) *rtcp.SenderReport { + return &rtcp.SenderReport{ + SSRC: ssrc, + NTPTime: ntpTime, + RTPTime: rtpTime, + } +} diff --git a/pkg/synchronizer/syncenginetrack.go b/pkg/synchronizer/syncenginetrack.go new file mode 100644 index 00000000..3a339a73 --- /dev/null +++ b/pkg/synchronizer/syncenginetrack.go @@ -0,0 +1,385 @@ +// Copyright 2026 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package synchronizer + +import ( + "io" + "sync" + "time" + + "github.com/pion/webrtc/v4" + + "github.com/livekit/media-sdk/jitter" + + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils/rtputil" +) + +const ( + // transitionSlewRatePerSecond is the rate at which the wall-clock→NTP + // transition correction is absorbed: 5ms per second of real time. + transitionSlewRatePerSecond = 5 * time.Millisecond + + // wallClockSanityThreshold is the maximum divergence between RTP-derived PTS + // and wall-clock PTS before falling back to wall clock in wallClockPTS(). + wallClockSanityThreshold = 5 * time.Second + + // ntpTrustThreshold is the maximum allowed divergence between NTP-derived PTS + // and wall-clock PTS. If NTP disagrees with wall clock by more than this, + // the NTP data is suspect (bad SRs, clock jumps, nonsensical timing) and + // we clamp to wall clock. This prevents bad publishers from dragging PTS far + // from reality. + ntpTrustThreshold = 500 * time.Millisecond + + // maxTimelyPacketAge is how long a track can be behind the pipeline deadline + // before its PTS is force-corrected forward. + maxTimelyPacketAge = 10 * time.Second + + // slewRatePerSecond is the maximum rate at which PTS corrections are absorbed. + slewRatePerSecond = 5 * time.Millisecond + + // deadbandThreshold is the minimum |correction| before slew smoothing kicks in. + deadbandThreshold = 5 * time.Millisecond +) + +// syncEngineTrack implements TrackSync for a single track within a SyncEngine. +type syncEngineTrack struct { + engine *SyncEngine + track TrackRemote + participantID string + logger logger.Logger + converter *rtputil.RTPConverter + startGate startGate // from start_gate.go, nil if not enabled + + mu sync.Mutex + startTime time.Time + sessionOffset time.Duration // offset from session start to this track's start + lastTS uint32 + lastPTS time.Duration + lastPTSAdjusted time.Duration + initialized bool + closed bool + + // NTP transition and smoothing + ntpTransitioned bool + transitionSlew time.Duration + lastSlewPTS time.Duration // PTS at which slew was last updated + lastNtpPTS time.Duration // last raw NTP PTS (before corrections), for jump detection + ntpCorrection time.Duration // smoothing correction for SR-induced NTP jumps + + // pipeline time feedback + lastTimelyPacket time.Time + + // drain + maxPTS time.Duration + maxPTSSet bool + + onSR func(drift time.Duration) +} + +// PrimeForStart implements TrackSync. It buffers packets through the optional +// start gate and initializes the track on the first valid packet. +func (st *syncEngineTrack) PrimeForStart(pkt jitter.ExtPacket) ([]jitter.ExtPacket, int, bool) { + st.mu.Lock() + defer st.mu.Unlock() + + if st.initialized || st.startGate == nil { + if !st.initialized { + st.initializeLocked(pkt) + } + return []jitter.ExtPacket{pkt}, 0, true + } + + ready, dropped, done := st.startGate.Push(pkt) + if !done { + return nil, dropped, false + } + + if len(ready) == 0 { + ready = []jitter.ExtPacket{pkt} + } + + if !st.initialized { + st.initializeLocked(ready[0]) + } + + return ready, dropped, true +} + +// initializeLocked sets the track's start time and registers with the engine. +// Caller must hold st.mu. +func (st *syncEngineTrack) initializeLocked(pkt jitter.ExtPacket) { + receivedAt := pkt.ReceivedAt + if receivedAt.IsZero() { + receivedAt = time.Now() + } + + st.startTime = receivedAt + st.lastTS = pkt.Timestamp + st.lastTimelyPacket = receivedAt + st.initialized = true + + // Initialize the engine's session start time. + sessionStart := st.engine.initializeIfNeeded(receivedAt) + st.sessionOffset = time.Duration(receivedAt.UnixNano() - sessionStart) + + st.logger.Infow("initialized track", + "startTime", st.startTime, + "sessionOffset", st.sessionOffset, + "rtpTS", pkt.Timestamp, + ) +} + +// GetPTS implements TrackSync. It computes the presentation timestamp for a packet +// using the NTP-grounded timeline when available, falling back to wall clock otherwise. +func (st *syncEngineTrack) GetPTS(pkt jitter.ExtPacket) (time.Duration, error) { + st.mu.Lock() + defer st.mu.Unlock() + + if st.closed { + return 0, io.EOF + } + + if !st.initialized { + st.initializeLocked(pkt) + } + + ts := pkt.Timestamp + + // Same RTP timestamp as last packet: return same PTS (same frame). + if ts == st.lastTS && st.lastPTSAdjusted > 0 { + return st.lastPTSAdjusted, nil + } + + // Drop packets older than threshold. + if st.engine.oldPacketThreshold > 0 && !pkt.ReceivedAt.IsZero() { + if time.Since(pkt.ReceivedAt) > st.engine.oldPacketThreshold { + return 0, ErrPacketTooOld + } + } + + // Step 1: Try NTP-grounded PTS from SessionTimeline. + rawNtpPTS, ntpErr := st.engine.timeline.GetSessionPTS(st.participantID, st.track.ID(), ts) + + wallPTS := st.wallClockPTS(pkt) + + // Audio tracks with external drift compensation (e.g., tempo controller) skip + // NTP PTS corrections — drift is handled by resampling, not PTS adjustment. + // NTP regression still runs (via OnSenderReport) for drift measurement. + useWallClockOnly := st.engine.audioDriftCompensated && st.track.Kind() == webrtc.RTPCodecTypeAudio + + // Step 2: Detect discontinuities and NTP regression jumps on RAW NTP PTS. + // This operates before any corrections to avoid feedback loops. + rtpDelta := ts - st.lastTS + rtpDeltaDuration := st.converter.ToDuration(rtpDelta) + + if st.lastTS != 0 && rtpDeltaDuration >= 30*time.Second { + // Discontinuity: stream restart, SSRC reuse with new RTP offset, or massive gap. + st.engine.timeline.ResetTrack(st.participantID, st.track.ID()) + st.lastNtpPTS = 0 + st.ntpCorrection = 0 + st.ntpTransitioned = false + st.transitionSlew = 0 + st.lastSlewPTS = 0 + st.logger.Warnw("stream discontinuity detected, resetting NTP state", nil, + "rtpDelta", rtpDelta, + "rtpDeltaDuration", rtpDeltaDuration, + ) + } else if !useWallClockOnly && ntpErr == nil && st.lastNtpPTS > 0 && rtpDelta > 0 { + // Detect regression jumps: compare raw NTP PTS against expected. + expectedRawNtpPTS := st.lastNtpPTS + rtpDeltaDuration + jump := rawNtpPTS - expectedRawNtpPTS + if jump > deadbandThreshold || jump < -deadbandThreshold { + st.ntpCorrection -= jump + st.logger.Debugw("NTP regression jump detected", + "jump", jump, + "ntpCorrection", st.ntpCorrection, + ) + } + } + if ntpErr == nil { + st.lastNtpPTS = rawNtpPTS // Always track raw NTP PTS, never corrected + } + + // Step 3: Compute final PTS with corrections. + var pts time.Duration + if ntpErr != nil || useWallClockOnly { + pts = wallPTS + } else { + // Apply NTP jump correction. + pts = rawNtpPTS + st.ntpCorrection + + // Clamp corrected PTS to within trust threshold of wall clock. + diff := pts - wallPTS + if diff > ntpTrustThreshold || diff < -ntpTrustThreshold { + st.logger.Warnw("NTP PTS exceeds trust threshold, clamping to wall clock", nil, + "rawNtpPTS", rawNtpPTS, + "ntpCorrection", st.ntpCorrection, + "wallPTS", wallPTS, + "diff", diff, + ) + pts = wallPTS + } + + // On first successful NTP PTS, compute transition correction. + if !st.ntpTransitioned { + st.transitionSlew = wallPTS - pts + st.ntpTransitioned = true + st.logger.Infow("NTP transition", + "wallPTS", wallPTS, + "ntpPTS", rawNtpPTS, + "transitionSlew", st.transitionSlew, + ) + } + } + + // Compute PTS delta for slew rate calculations. + var slewPTSDelta time.Duration + if st.lastSlewPTS > 0 { + slewPTSDelta = pts - st.lastSlewPTS + } + + // Step 4: Apply transition slew (absorb gradually toward zero). + if st.transitionSlew != 0 { + pts += st.transitionSlew + + if slewPTSDelta > 0 { + maxStep := time.Duration(float64(transitionSlewRatePerSecond) * slewPTSDelta.Seconds()) + if st.transitionSlew > 0 { + st.transitionSlew -= maxStep + if st.transitionSlew < 0 { + st.transitionSlew = 0 + } + } else { + st.transitionSlew += maxStep + if st.transitionSlew > 0 { + st.transitionSlew = 0 + } + } + } + } + + // Decay ntpCorrection toward zero via slew. + if st.ntpCorrection != 0 { + if slewPTSDelta > 0 { + maxStep := time.Duration(float64(slewRatePerSecond) * slewPTSDelta.Seconds()) + if st.ntpCorrection > 0 { + st.ntpCorrection -= maxStep + if st.ntpCorrection < 0 { + st.ntpCorrection = 0 + } + } else { + st.ntpCorrection += maxStep + if st.ntpCorrection > 0 { + st.ntpCorrection = 0 + } + } + } + } + + st.lastSlewPTS = pts + + // Step 6: Pipeline time feedback — if the track has fallen behind the + // pipeline's deadline for too long, force-correct PTS forward. + if deadline, ok := st.engine.getMediaDeadline(); ok && st.engine.maxMediaRunningTimeDelay > 0 { + limit := deadline - st.engine.maxMediaRunningTimeDelay + if pts < limit { + if time.Since(st.lastTimelyPacket) > maxTimelyPacketAge { + oldPTS := pts + pts = deadline - st.engine.maxMediaRunningTimeDelay/2 + st.logger.Warnw("force-correcting PTS forward, track behind pipeline deadline", nil, + "oldPTS", oldPTS, + "newPTS", pts, + "deadline", deadline, + "behindBy", limit-oldPTS, + ) + } + } else { + st.lastTimelyPacket = time.Now() + } + } + + // Step 7: Enforce monotonicity. + if pts < st.lastPTSAdjusted+time.Millisecond && st.lastPTSAdjusted > 0 { + pts = st.lastPTSAdjusted + time.Millisecond + } + + // Step 7: Enforce drain ceiling. + if st.maxPTSSet && pts > st.maxPTS { + return 0, io.EOF + } + + // Update state. + st.lastTS = ts + st.lastPTS = pts // the raw PTS before adjustment (for wall clock computation) + st.lastPTSAdjusted = pts + + return pts, nil +} + +// wallClockPTS computes a PTS based on wall-clock timing and RTP deltas. +func (st *syncEngineTrack) wallClockPTS(pkt jitter.ExtPacket) time.Duration { + ts := pkt.Timestamp + + // Same RTP timestamp as last packet: same frame. + if st.lastTS == ts && st.lastPTS > 0 { + return st.lastPTS + } + + // Wall-clock elapsed since this track started, plus session offset + wallElapsed := pkt.ReceivedAt.Sub(st.startTime) + st.sessionOffset + + // If we have a previous timestamp, use RTP delta for more precision. + if st.lastPTS > 0 { + rtpDelta := ts - st.lastTS + rtpDerived := st.lastPTS + st.converter.ToDuration(rtpDelta) + + // Sanity check: if RTP-derived PTS diverges from wall-clock by > 5s, use wall clock. + diff := rtpDerived - wallElapsed + if diff < 0 { + diff = -diff + } + if diff <= wallClockSanityThreshold { + return rtpDerived + } + } + + // Use wall-clock elapsed, ensuring non-negative. + if wallElapsed < 0 { + wallElapsed = 0 + } + return wallElapsed +} + +// OnSenderReport implements TrackSync. It stores a callback invoked on sender reports. +func (st *syncEngineTrack) OnSenderReport(f func(drift time.Duration)) { + st.mu.Lock() + defer st.mu.Unlock() + st.onSR = f +} + +// LastPTSAdjusted implements TrackSync. +func (st *syncEngineTrack) LastPTSAdjusted() time.Duration { + st.mu.Lock() + defer st.mu.Unlock() + return st.lastPTSAdjusted +} + +// Close implements TrackSync. +func (st *syncEngineTrack) Close() { + st.mu.Lock() + defer st.mu.Unlock() + st.closed = true +} diff --git a/pkg/synchronizer/synchronizer.go b/pkg/synchronizer/synchronizer.go index 9227aa7d..673f72c1 100644 --- a/pkg/synchronizer/synchronizer.go +++ b/pkg/synchronizer/synchronizer.go @@ -275,14 +275,14 @@ func NewSynchronizerWithOptions(opts ...SynchronizerOption) *Synchronizer { } } -func (s *Synchronizer) AddTrack(track TrackRemote, identity string) *TrackSynchronizer { +func (s *Synchronizer) AddTrack(track TrackRemote, participantID string) *TrackSynchronizer { t := newTrackSynchronizer(s, track) s.Lock() - p := s.psByIdentity[identity] + p := s.psByIdentity[participantID] if p == nil { p = newParticipantSynchronizer() - s.psByIdentity[identity] = p + s.psByIdentity[participantID] = p } ssrc := uint32(track.SSRC()) s.ssrcByID[track.ID()] = ssrc @@ -386,6 +386,22 @@ func (s *Synchronizer) GetEndedAt() int64 { return s.endedAt } +// SynchronizerAdapter wraps the legacy Synchronizer to implement the Sync interface. +// The Synchronizer's own AddTrack returns *TrackSynchronizer (concrete type); this +// adapter's AddTrack returns TrackSync so that *SynchronizerAdapter satisfies Sync. +type SynchronizerAdapter struct { + *Synchronizer +} + +func (a *SynchronizerAdapter) AddTrack(track TrackRemote, participantID string) TrackSync { + return a.Synchronizer.AddTrack(track, participantID) +} + +// AsSyncInterface returns a Sync-compatible wrapper around this Synchronizer. +func (s *Synchronizer) AsSyncInterface() Sync { + return &SynchronizerAdapter{Synchronizer: s} +} + func (s *Synchronizer) getExternalMediaDeadline() (time.Duration, bool) { s.RLock() startTime := s.externalMediaStartTime diff --git a/pkg/synchronizer/synchronizer_test.go b/pkg/synchronizer/synchronizer_test.go index a485989e..c5f40b6f 100644 --- a/pkg/synchronizer/synchronizer_test.go +++ b/pkg/synchronizer/synchronizer_test.go @@ -15,6 +15,10 @@ import ( "github.com/livekit/server-sdk-go/v2/pkg/synchronizer/synchronizerfakes" ) +// Compile-time interface checks +var _ synchronizer.Sync = (*synchronizer.SynchronizerAdapter)(nil) +var _ synchronizer.TrackSync = (*synchronizer.TrackSynchronizer)(nil) + const timeTolerance = time.Millisecond * 10 const fakeAudioTrackID = "audio-1"