diff --git a/Documentation/SpeakerManager.md b/Documentation/SpeakerManager.md index 8e8a9850b..026ff7ea1 100644 --- a/Documentation/SpeakerManager.md +++ b/Documentation/SpeakerManager.md @@ -1,3 +1,4 @@ + # SpeakerManager API Tracks and manages speaker identities across audio chunks for streaming diarization. @@ -73,6 +74,21 @@ let bob = Speaker(id: "bob", name: "Bob", currentEmbedding: bobEmbedding) speakerManager.initializeKnownSpeakers([alice, bob]) ``` +Sometimes, there are already speakers in the database that may have the same ID. +```swift +let alice = Speaker(id: "alice", name: "Alice", currentEmbedding: aliceEmbedding) +let bob = Speaker(id: "bob", name: "Bob", currentEmbedding: bobEmbedding) +speakerManager.initializeKnownSpeakers([alice, bob], mode: .overwrite, preserveIfPermanent: false) // replace any speakers with ID "alice" or "bob" with the new speakers, even if the old ones were marked as permanent. +``` + +> The `mode` argument dictates how to handle redundant speakers. It is of type `SpeakerInitializationMode`, and can take on one of four values: +> - `.reset`: reset the speaker database and add the new speakers +> - `.merge`: merge new speakers whose IDs match with existing ones +> - `.overwrite`: overwrite existing speakers with the same IDs as the new ones +> - `.skip`: skip adding speakers whose IDs match existing ones +> +> The `preserveIfPermanent` argument determines whether existing speakers marked as permanent should be preserved (i.e., not overwritten or merged). It is `true` by default. + **Use case:** When you have pre-recorded voice samples of known speakers and want to recognize them by name instead of numeric IDs. 
#### upsertSpeaker @@ -91,6 +107,7 @@ speakerManager.upsertSpeaker( updateCount: 5, // optional createdAt: Date(), // optional - updatedAt: Date() // optional + updatedAt: Date(), // optional + isPermanent: false // optional ) ``` @@ -98,9 +115,137 @@ speakerManager.upsertSpeaker( - If speaker ID exists: updates the existing speaker's data - If speaker ID is new: inserts as a new speaker - Maintains ID uniqueness and tracks numeric IDs for auto-increment +- If `isPermanent` is true, then the new speaker or the existing speaker will become permanent. This means that the speaker will not be merged or removed without an override. + +#### mergeSpeaker +```swift +// merge speaker 1 into "alice" +speakerManager.mergeSpeaker("1", into: "alice") + +// merge speaker 2 into speaker 3 under the name "Bob", regardless of whether speaker 2 is permanent. +speakerManager.mergeSpeaker("2", into: "3", mergedName: "Bob", stopIfPermanent: false) +``` + +**Behavior:** +- Unless `stopIfPermanent` is `false`, the merge will be stopped if the first speaker is permanent. +- Otherwise: Merges the first speaker into the destination speaker and removes the first speaker from the known speaker database. +- If `mergedName` is provided, the destination speaker will be renamed. Otherwise, its name will be preserved. + +> Note: the `mergedName` argument is optional. +> Note: `stopIfPermanent` is `true` by default. + +#### removeSpeaker +Remove a speaker from the database. + +```swift +// remove speaker 1 +speakerManager.removeSpeaker("1") + +// remove "alice" from the known speaker database, even if they are marked as permanent +speakerManager.removeSpeaker("alice", keepIfPermanent: false) +``` +> Note: `keepIfPermanent` is `true` by default. + +#### removeSpeakersInactive +Remove speakers that have been inactive since a certain date or for a certain duration. 
+ +```swift +// remove speakers that have been inactive since `date` +speakerManager.removeSpeakersInactive(since: date) + +// remove speakers that have been inactive for 10 seconds, even if they were marked as permanent +speakerManager.removeSpeakersInactive(for: 10.0, keepIfPermanent: false) +``` + +> Note: Both versions of the method have an optional `keepIfPermanent` argument that defaults to `true`. + +#### removeSpeakers +Remove speakers that match a given predicate (permanent speakers are kept unless `keepIfPermanent` is `false`). + +```swift +// remove all speakers with less than 5 seconds of speaking time +speakerManager.removeSpeakers( + where: { $0.duration < 5.0 }, + keepIfPermanent: false // also remove permanent speakers (optional) +) + +// Alternate syntax (does NOT remove permanent speakers) +speakerManager.removeSpeakers { + $0.duration < 5.0 +} +``` + +> Note: the predicate should take in a `Speaker` object and return a `Bool`. + +#### makeSpeakerPermanent +Make the speaker permanent. + +```swift +speakerManager.makeSpeakerPermanent("alice") // mark "alice" as permanent +``` + +#### revokePermanence +Make the speaker not permanent. + +```swift +speakerManager.revokePermanence(from: "alice") // mark "alice" as not permanent +``` + +#### resetPermanentFlags +Mark all speakers as not permanent. + +```swift +speakerManager.resetPermanentFlags() +``` ### Speaker Retrieval +#### findSpeaker +Find the speaker that best matches an embedding vector, along with the cosine distance to that speaker. If no speaker is within the threshold, the returned `id` is `nil` and the distance is `.infinity`. + +```swift +let (id, distance) = speakerManager.findSpeaker(with: embedding) +``` +> Note: there is an optional `speakerThreshold` argument to use a threshold other than the default. + +#### findMatchingSpeakers +Find all speakers whose cosine distance to an embedding vector is at most `speakerThreshold`, sorted from closest to farthest. 
+ +```swift +for speaker in speakerManager.findMatchingSpeakers(with: embedding) { + print("ID: \(speaker.id), Distance: \(speaker.distance)") +} +``` + +> Note: there is an optional `speakerThreshold` argument to use a threshold other than the default. + +#### findSpeakers +Find all speakers that meet a certain predicate. +```swift +// two ways to find all speakers with > 5.0s of speaking time. +speakerManager.findSpeakers(where: { $0.duration > 5.0 }) +speakerManager.findSpeakers{ $0.duration > 5.0 } +// Returns an array of IDs corresponding to speakers that meet the predicate. +``` + +> Note: the predicate should take in a `Speaker` object and return a `Bool`. + +#### findMergeablePairs +Find all pairs of speakers that might be the same person. Specifically, find the pairs of speakers such that the cosine distance between them is less than the `speakerThreshold`. + +Returns a list of pairs of speaker IDs. + +```swift +let pairs = speakerManager.findMergeablePairs( + speakerThreshold: 0.6, // optional + excludeIfBothPermanent: true // optional +) + +for pair in pairs { + print("Merge Speaker \(pair.speakerToMerge) into Speaker \(pair.destination)") +} +``` + #### getSpeaker Get a specific speaker by ID. @@ -118,6 +263,22 @@ let allSpeakers = speakerManager.getAllSpeakers() // Returns: [String: Speaker] - dictionary keyed by speaker ID ``` +#### getSpeakerList +Get all speakers in the database as an array of speakers (for testing/debugging) +```swift +let allSpeakers = speakerManager.getSpeakerList() +// Returns: [Speaker] - Array of speakers +``` + +#### hasSpeaker +Check if the speaker database has a speaker with a given ID. + +```swift +if speakerManager.hasSpeaker("alice") { + print("Alice was found in the database") +} +``` + #### speakerCount Get the total number of tracked speakers. @@ -140,6 +301,7 @@ Clear all speakers from the database. 
```swift speakerManager.reset() +speakerManager.reset(keepIfPermanent: true) // remove all non-permanent speakers from the database ``` Useful for: @@ -147,6 +309,8 @@ Useful for: - Freeing memory between recordings - Resetting speaker tracking + + ## Speaker Enrollment The `Speaker` class includes a `name` field for speaker enrollment workflows: @@ -237,6 +401,7 @@ public final class Speaker: Identifiable, Codable { public var updatedAt: Date // Last update timestamp public var updateCount: Int // Number of updates public var rawEmbeddings: [RawEmbedding] // Historical embeddings (max 50) + public var isPermanent: Bool // Permanence flag } ``` @@ -547,13 +712,25 @@ class RealtimeDiarizer { | Method | Returns | Description | |--------|---------|-------------| | `assignSpeaker(_:speechDuration:confidence:)` | `Speaker?` | Assign/create speaker from embedding | -| `initializeKnownSpeakers(_:)` | `Void` | Pre-load known speaker profiles | +| `initializeKnownSpeakers(_:mode:preserveIfPermanent:)` | `Void` | Pre-load known speaker profiles | +| `findSpeaker(with:speakerThreshold:)` | `(id: String?, distance: Float)` | Find speaker that matches an embedding | +| `findMatchingSpeakers(with:speakerThreshold:)` | `[(id: String, distance: Float)]` | Find all speakers that match an embedding | +| `findSpeakers(where:)` | `[String]` | Find all speakers that meet a certain predicate | +| `findMergeablePairs(speakerThreshold:excludeIfBothPermanent:)` | `[(speakerToMerge: String, destination: String)]` | Find all pairs of very similar speakers | +| `removeSpeaker(_:keepIfPermanent:)` | `Void` | Remove a speaker from the database | +| `removeSpeakersInactive(since:keepIfPermanent:)` | `Void` | Remove speakers inactive since a given date | +| `removeSpeakersInactive(for:keepIfPermanent:)` | `Void` | Remove speakers inactive for a given duration | +| `removeSpeakers(where:)` | `Void` | Remove non-permanent speakers that satisfy a given predicate | +| `removeSpeakers(where:keepIfPermanent:)` | 
`Void` | Remove speakers that satisfy a given predicate | +| `mergeSpeaker(_:into:mergedName:stopIfPermanent:)` | `Void` | Merge a speaker into another one | | `upsertSpeaker(_:)` | `Void` | Update or insert speaker (from object) | | `upsertSpeaker(id:currentEmbedding:duration:...)` | `Void` | Update or insert speaker (from params) | | `getSpeaker(for:)` | `Speaker?` | Get speaker by ID | | `getAllSpeakers()` | `[String: Speaker]` | Get all speakers (debugging) | -| `reset()` | `Void` | Clear speaker database | -| `reassignSegment(segmentId:from:to:)` | `Bool` | Move segment between speakers | +| `getSpeakerList()` | `[Speaker]` | Get array of all speakers (debugging) | +| `hasSpeaker(_:)` | `Bool` | Check if database has a speaker with a given ID | +| `reset(keepIfPermanent:)` | `Void` | Clear speaker database | +| `resetPermanentFlags()` | `Void` | Mark all speakers as not permanent | | `getCurrentSpeakerNames()` | `[String]` | Get sorted speaker IDs | | `getGlobalSpeakerStats()` | `(Int, Float, Float, Int)` | Aggregate statistics | @@ -567,6 +744,7 @@ class RealtimeDiarizer { | `minEmbeddingUpdateDuration` | `Float` | Min duration to update embeddings (seconds) | | `speakerCount` | `Int` | Number of tracked speakers | | `speakerIds` | `[String]` | Sorted array of speaker IDs | +| `permanentSpeakerIds` | `[String]` | Sorted array of speaker IDs of permanent speakers | ### Speaker Properties @@ -580,6 +758,7 @@ class RealtimeDiarizer { | `updatedAt` | `Date` | Last update timestamp | | `updateCount` | `Int` | Number of embedding updates | | `rawEmbeddings` | `[RawEmbedding]` | Historical embeddings (max 50) | +| `isPermanent` | `Bool` | Permanence flag | ### Speaker Methods @@ -602,6 +781,7 @@ class RealtimeDiarizer { | `averageEmbeddings(_:)` | `[Float]?` | Average multiple embeddings | | `createSpeaker(id:name:duration:embedding:config:)` | `Speaker?` | Create validated speaker | | `updateEmbedding(current:new:alpha:)` | `[Float]?` | EMA update (pure function) | 
+| `reassignSegment(segmentId:from:to:)` | `Bool` | Move segment between speakers | ## See Also diff --git a/Sources/FluidAudio/Diarizer/Clustering/SpeakerManager.swift b/Sources/FluidAudio/Diarizer/Clustering/SpeakerManager.swift index fccad9bae..ba340265a 100644 --- a/Sources/FluidAudio/Diarizer/Clustering/SpeakerManager.swift +++ b/Sources/FluidAudio/Diarizer/Clustering/SpeakerManager.swift @@ -35,7 +35,18 @@ public class SpeakerManager { self.minEmbeddingUpdateDuration = minEmbeddingUpdateDuration } - public func initializeKnownSpeakers(_ speakers: [Speaker]) { + /// Add known speakers to the database + /// - Parameters: + /// - speakers: Array of `Speaker`s to add + /// - mode: Mode for handling overlapping ID conflicts. + /// - preserveIfPermanent: Whether to avoid overwriting/merging pre-existing permanent speakers + public func initializeKnownSpeakers( + _ speakers: [Speaker], mode: SpeakerInitializationMode = .skip, preserveIfPermanent: Bool = true + ) { + if mode == .reset { + self.reset(keepIfPermanent: preserveIfPermanent) + } + queue.sync(flags: .barrier) { var maxNumericId = 0 @@ -46,7 +57,36 @@ public class SpeakerManager { continue } - speakerDatabase[speaker.id] = speaker + // Check if the speaker ID is already initialized + if let oldSpeaker = self.speakerDatabase[speaker.id] { + // Handle duplicate speaker + switch mode { + case .reset, .overwrite: + if !(oldSpeaker.isPermanent && preserveIfPermanent) { + logger.warning("Speaker \(speaker.id) is already initialized. Overwriting old speaker.") + speakerDatabase[speaker.id] = speaker + } else { + logger.warning( + "Failed to overwrite Speaker \(speaker.id) because it is permanent. Skipping") + continue + } + case .merge: + if !(oldSpeaker.isPermanent && preserveIfPermanent) { + logger.warning("Speaker \(speaker.id) is already initialized. 
Merging with old speaker.") + oldSpeaker.mergeWith(speaker, keepName: speaker.name) + } else { + logger.warning( + "Failed to merge Speaker \(speaker.id) into Speaker \(oldSpeaker.id) because the existing speaker is permanent. Skipping" + ) + continue + } + case .skip: + logger.warning("Speaker \(speaker.id) is already initialized. Skipping new speaker.") + continue + } + } else { + speakerDatabase[speaker.id] = speaker + } // Try to extract numeric ID if it's a pure number if let numericId = Int(speaker.id) { @@ -67,10 +107,20 @@ public class SpeakerManager { } } + /// Match the embedding to the closest existing speaker if sufficiently similar or create a new one if not. + /// - Parameters: + /// - embedding: 256D speaker embedding vector + /// - speechDuration: Duration of the speech segment during which this speaker was active + /// - confidence: Confidence in the embedding vector being correct + /// - speakerThreshold: The maximum cosine distance to an existing speaker to create a new one (uses the default threshold for this `SpeakerManager` object if none is provided) + /// - newName: Name to assign the speaker if a new one is created (default: `Speaker $id`) + /// - Returns: A `Speaker` object if a match was found or a new one was created. Returns `nil` if an error occurred. public func assignSpeaker( _ embedding: [Float], speechDuration: Float, - confidence: Float = 1.0 + confidence: Float = 1.0, + speakerThreshold: Float? = nil, + newName: String? = nil ) -> Speaker? { guard !embedding.isEmpty && embedding.count == Self.embeddingSize else { logger.error("Invalid embedding size: \(embedding.count)") @@ -78,6 +128,7 @@ public class SpeakerManager { } let normalizedEmbedding = VDSPOperations.l2Normalize(embedding) + let speakerThreshold = speakerThreshold ?? 
self.speakerThreshold return queue.sync(flags: .barrier) { let (closestSpeaker, distance) = findClosestSpeaker(to: normalizedEmbedding) @@ -117,6 +168,265 @@ public class SpeakerManager { } } + /// Find the closest existing speaker to an embedding, up to a maximum cosine distance of `speakerThreshold`. + /// - Parameters: + /// - embedding: 256D speaker embedding vector + /// - speakerThreshold: Maximum cosine distance to an existing speaker for it to count as a match (uses the default threshold for this `SpeakerManager` object if none is provided) + /// - Returns: ID of the match and the distance to it, or `(nil, .infinity)` if no speaker is within the threshold. + public func findSpeaker(with embedding: [Float], speakerThreshold: Float? = nil) -> (id: String?, distance: Float) { + queue.sync { + let (closestSpeakerId, minDistance) = findClosestSpeaker(to: embedding) + let speakerThreshold = speakerThreshold ?? self.speakerThreshold + if let closestSpeakerId, minDistance <= speakerThreshold { + return (closestSpeakerId, minDistance) + } + return (nil, .infinity) + } + } + + /// Find all existing speakers within a maximum cosine distance of `speakerThreshold` from an embedding. + /// - Parameters: + /// - embedding: 256D speaker embedding vector + /// - speakerThreshold: Maximum cosine distance between `embedding` and another speaker for them to be a match (default: `self.speakerThreshold`) + /// - Returns: Array of the matching speakers' IDs and their distances from `embedding`, sorted by ascending cosine distance (from closest to farthest). + public func findMatchingSpeakers( + with embedding: [Float], speakerThreshold: Float? = nil + ) -> [(id: String, distance: Float)] { + queue.sync { + var matches: [(id: String, distance: Float)] = [] + let speakerThreshold = speakerThreshold ?? 
self.speakerThreshold + + for (speakerId, speaker) in speakerDatabase { + let distance = cosineDistance(embedding, speaker.currentEmbedding) + if distance <= speakerThreshold { + matches.append((speakerId, distance)) + } + } + matches.sort { $0.distance < $1.distance } + return matches + } + } + + /// Find all speakers that meet a certain predicate + /// - Parameter predicate: Condition the speakers must meet to be returned + /// - Returns: A list of all Speaker IDs corresponding to Speakers that meet the predicate + public func findSpeakers(where predicate: (Speaker) -> Bool) -> [String] { + queue.sync { + return speakerDatabase.filter { predicate($0.value) }.map(\.key) + } + } + + /// Mark a speaker as permanent + /// - Parameter speakerId: ID of the speaker to mark as permanent + public func makeSpeakerPermanent(_ speakerId: String) { + queue.sync(flags: .barrier) { + guard let speaker = speakerDatabase[speakerId] else { + logger.warning("Failed to mark speaker \(speakerId) as permanent (speaker not found).") + return + } + logger.info("Marking speaker \(speakerId) as permanent.") + speaker.isPermanent = true + } + } + + /// Remove a speaker's permanent marker + /// - Parameter speakerId: ID of the speaker from which to remove the permanent marker + public func revokePermanence(from speakerId: String) { + queue.sync(flags: .barrier) { + guard let speaker = speakerDatabase[speakerId] else { + logger.warning("Failed to revoke permanence from speaker \(speakerId) (speaker not found).") + return + } + + logger.info("Revoking permanence from speaker \(speakerId).") + speaker.isPermanent = false + } + } + + /// Merge two speakers in the database. 
+ /// - Parameters: + /// - sourceId: ID of the `Speaker` being merged + /// - destinationId: ID of the `Speaker` that absorbs the other one + /// - mergedName: New name for the merged speaker (uses `destination`'s name if not provided) + /// - stopIfPermanent: Whether to stop merging if the source speaker is permanent + public func mergeSpeaker( + _ sourceId: String, into destinationId: String, mergedName: String? = nil, stopIfPermanent: Bool = true + ) { + // don't merge a speaker into itself + guard sourceId != destinationId else { + return + } + + queue.sync(flags: .barrier) { + // ensure both speakers exist + guard let speakerToMerge = speakerDatabase[sourceId], + let destinationSpeaker = speakerDatabase[destinationId] + else { + return + } + + // don't merge permanent speakers into another one + guard !(stopIfPermanent && speakerToMerge.isPermanent) else { + return + } + + // merge source into destination + destinationSpeaker.mergeWith(speakerToMerge, keepName: mergedName) + + // remove source speaker + speakerDatabase.removeValue(forKey: sourceId) + } + } + + /// Find all pairs of speakers that can be merged + /// - Parameters: + /// - speakerThreshold: Max cosine distance between speakers to let them be considered mergeable + /// - excludeIfBothPermanent: Whether to exclude speaker pairs where both speakers are permanent + /// - Returns: Array of speaker ID pairs that belong to speakers that are similar enough to be merged + public func findMergeablePairs( + speakerThreshold: Float? = nil, excludeIfBothPermanent: Bool = true + ) -> [(speakerToMerge: String, destination: String)] { + queue.sync { + let speakerThreshold = speakerThreshold ?? self.speakerThreshold + var pairs: [(speakerToMerge: String, destination: String)] = [] + let ids = Array(speakerDatabase.keys) + + for i in (0.. 
Bool, keepIfPermanent: Bool = true) { + queue.sync(flags: .barrier) { + if keepIfPermanent { + // don't remove permanent speakers + for (speakerId, speaker) in speakerDatabase where predicate(speaker) && !speaker.isPermanent { + speakerDatabase.removeValue(forKey: speakerId) + logger.info("Removing speaker \(speakerId) based on predicate") + } + } else { + for (speakerId, speaker) in speakerDatabase where predicate(speaker) { + speakerDatabase.removeValue(forKey: speakerId) + logger.info("Removing speaker \(speakerId) based on predicate") + } + } + } + } + + /// Remove non-permanent speakers that meet a certain predicate + /// - Parameters: + /// - predicate: Predicate to determine whether the speaker should be removed + public func removeSpeakers(where predicate: (Speaker) -> Bool) { + removeSpeakers(where: predicate, keepIfPermanent: true) + } + + /// Check if the speaker database has a speaker with a given ID. + /// - Parameter speakerId: ID to check + /// - Returns: `true` if a speaker is found, `false` if not + public func hasSpeaker(_ speakerId: String) -> Bool { + queue.sync { + return speakerDatabase.keys.contains(speakerId) + } + } + + private func findDistanceToClosestSpeaker(to embedding: [Float]) -> Float { + return speakerDatabase.values.reduce(Float.infinity) { + min($0, cosineDistance(embedding, $1.currentEmbedding)) + } + } + private func findClosestSpeaker(to embedding: [Float]) -> (speakerId: String?, distance: Float) { var minDistance: Float = Float.infinity var closestSpeakerId: String? @@ -167,19 +477,23 @@ public class SpeakerManager { private func createNewSpeaker( embedding: [Float], duration: Float, - distanceToClosest: Float + distanceToClosest: Float, + name: String? = nil, + isPermanent: Bool = false ) -> String { let normalizedEmbedding = VDSPOperations.l2Normalize(embedding) let newSpeakerId = String(nextSpeakerId) + let newSpeakerName = name ?? 
"Speaker \(newSpeakerId)" // Default name with number if not provided nextSpeakerId += 1 highestSpeakerId = max(highestSpeakerId, nextSpeakerId - 1) // Create new Speaker object let newSpeaker = Speaker( id: newSpeakerId, - name: "Speaker \(newSpeakerId)", // Default name with number + name: newSpeakerName, currentEmbedding: normalizedEmbedding, - duration: duration + duration: duration, + isPermanent: isPermanent ) // Add initial raw embedding @@ -206,6 +520,10 @@ public class SpeakerManager { queue.sync { Array(speakerDatabase.keys).sorted() } } + public var permanentSpeakerIds: [String] { + queue.sync { Array(speakerDatabase.filter(\.value.isPermanent).keys).sorted() } + } + /// Get all speakers (for testing/debugging). public func getAllSpeakers() -> [String: Speaker] { queue.sync { @@ -213,6 +531,13 @@ public class SpeakerManager { } } + /// Get list of all speakers. + public func getSpeakerList() -> [Speaker] { + queue.sync { + return [Speaker](speakerDatabase.values) + } + } + public func getSpeaker(for speakerId: String) -> Speaker? { queue.sync { speakerDatabase[speakerId] } } @@ -226,11 +551,12 @@ public class SpeakerManager { rawEmbeddings: speaker.rawEmbeddings, updateCount: speaker.updateCount, createdAt: speaker.createdAt, - updatedAt: speaker.updatedAt + updatedAt: speaker.updatedAt, + isPermanent: speaker.isPermanent ) } - /// Upsert a speaker - update if exists, insert if new + /// Upsert a speaker - update if ID exists, insert if new /// /// - Parameters: /// - id: The speaker ID @@ -240,6 +566,7 @@ public class SpeakerManager { /// - updateCount: Number of updates to this speaker /// - createdAt: Creation timestamp /// - updatedAt: Last update timestamp + /// - isPermanent: Whether the speaker should be protected from merges and removals by default public func upsertSpeaker( id: String, currentEmbedding: [Float], @@ -247,7 +574,8 @@ public class SpeakerManager { rawEmbeddings: [RawEmbedding] = [], updateCount: Int = 1, createdAt: Date? 
= nil, - updatedAt: Date? = nil + updatedAt: Date? = nil, + isPermanent: Bool = false ) { queue.sync(flags: .barrier) { let now = Date() @@ -259,6 +587,7 @@ public class SpeakerManager { existingSpeaker.rawEmbeddings = rawEmbeddings existingSpeaker.updateCount = updateCount existingSpeaker.updatedAt = updatedAt ?? now + existingSpeaker.isPermanent = existingSpeaker.isPermanent || isPermanent // Keep original createdAt and name speakerDatabase[id] = existingSpeaker @@ -271,7 +600,8 @@ public class SpeakerManager { currentEmbedding: currentEmbedding, duration: duration, createdAt: createdAt ?? now, - updatedAt: updatedAt ?? now + updatedAt: updatedAt ?? now, + isPermanent: isPermanent ) newSpeaker.rawEmbeddings = rawEmbeddings @@ -290,12 +620,36 @@ public class SpeakerManager { } } - public func reset() { + /// Reset the speaker database + /// - Parameter keepIfPermanent: Whether to keep permanent speakers + public func reset(keepIfPermanent: Bool = false) { queue.sync(flags: .barrier) { - speakerDatabase.removeAll() - nextSpeakerId = 1 - highestSpeakerId = 0 + if !keepIfPermanent { + speakerDatabase.removeAll() + nextSpeakerId = 1 + highestSpeakerId = 0 + } else { + speakerDatabase = speakerDatabase.filter(\.value.isPermanent) + // Recalculate nextSpeakerId and highestSpeakerId based on remaining permanent speakers + var maxNumericId = 0 + for id in speakerDatabase.keys { + if let numericId = Int(id) { + maxNumericId = max(maxNumericId, numericId) + } + } + highestSpeakerId = maxNumericId + nextSpeakerId = maxNumericId + 1 + } logger.info("Speaker database reset") } } + + /// Mark all speakers as not permanent + public func resetPermanentFlags() { + queue.sync(flags: .barrier) { + speakerDatabase.forEach { + $0.value.isPermanent = false + } + } + } } diff --git a/Sources/FluidAudio/Diarizer/Clustering/SpeakerTypes.swift b/Sources/FluidAudio/Diarizer/Clustering/SpeakerTypes.swift index 77562d117..741285ebb 100644 --- 
a/Sources/FluidAudio/Diarizer/Clustering/SpeakerTypes.swift +++ b/Sources/FluidAudio/Diarizer/Clustering/SpeakerTypes.swift @@ -4,22 +4,43 @@ import Foundation /// Speaker profile representation for tracking speakers across audio /// This represents a speaker's identity, not a specific segment public final class Speaker: Identifiable, Codable, Equatable, Hashable { + /// Speaker ID public let id: String + /// Speaker name public var name: String + /// Main embedding vector for this speaker's voice public var currentEmbedding: [Float] + /// Total speech duration for this speaker public var duration: Float = 0 + /// Date that this speaker object was created public var createdAt: Date + /// Date that this speaker object was last updated public var updatedAt: Date + /// Number of times the embedding vector was updated public var updateCount: Int = 1 + /// Array of raw embedding vectors public var rawEmbeddings: [RawEmbedding] = [] - + /// Whether this speaker is protected from merges and removals by default + public var isPermanent: Bool = false + + /// - Parameters: + /// - id: Speaker ID + /// - name: Speaker name + /// - currentEmbedding: Main embedding vector for this speaker's voice + /// - duration: Total speech duration for this speaker + /// - createdAt: Date that this speaker object was created + /// - updatedAt: Date that this speaker object was last updated + /// - isPermanent: Whether this speaker is protected from merges and removals by default + /// + /// `updateCount` starts at 1 and `rawEmbeddings` starts empty; neither is an initializer parameter. public init( id: String? = nil, name: String? = nil, currentEmbedding: [Float], duration: Float = 0, createdAt: Date? = nil, - updatedAt: Date? = nil + updatedAt: Date? = nil, + isPermanent: Bool = false ) { let now = Date() self.id = id ?? UUID().uuidString @@ -30,6 +51,7 @@ public final class Speaker: Identifiable, Codable, Equatable, Hashable { self.updatedAt = updatedAt ?? 
now self.updateCount = 1 self.rawEmbeddings = [] + self.isPermanent = isPermanent } /// Convert to SendableSpeaker format for cross-boundary usage. @@ -37,14 +59,18 @@ public final class Speaker: Identifiable, Codable, Equatable, Hashable { return SendableSpeaker(from: self) } - /// Update main embedding with new segment data using exponential moving average + /// Update main embedding with new segment data using exponential moving average (EMA) + /// - Parameters: + /// - duration: Segment duration + /// - embedding: 256D speaker embedding vector + /// - segmentId: The ID of the segment + /// - alpha: EMA blending parameter public func updateMainEmbedding( duration: Float, embedding: [Float], segmentId: UUID, alpha: Float = 0.9 ) { - // Validate embedding quality var sumSquares: Float = 0 vDSP_svesq(embedding, 1, &sumSquares, vDSP_Length(embedding.count)) @@ -136,6 +162,9 @@ public final class Speaker: Identifiable, Codable, Equatable, Hashable { } /// Merge another speaker into this one + /// - Parameters: + /// - other: Other Speaker to merge + /// - keepName: The resulting name after the merge public func mergeWith(_ other: Speaker, keepName: String? 
= nil) { // Merge raw embeddings var allEmbeddings = rawEmbeddings + other.rawEmbeddings @@ -239,3 +268,15 @@ public struct SendableSpeaker: Sendable, Identifiable, Hashable { return lhs.id == rhs.id && lhs.name == rhs.name } } + +/// Configuration for handling initializing known speakers +public enum SpeakerInitializationMode { + /// Reset the speaker database and add the new speakers + case reset + /// Merge new speakers whose IDs match with existing ones + case merge + /// Overwrite existing speakers with the same IDs as the new ones + case overwrite + /// Skip speakers whose IDs match existing ones + case skip +} diff --git a/Sources/FluidAudio/Diarizer/Core/DiarizerManager.swift b/Sources/FluidAudio/Diarizer/Core/DiarizerManager.swift index 00d27c3ae..a671358d5 100644 --- a/Sources/FluidAudio/Diarizer/Core/DiarizerManager.swift +++ b/Sources/FluidAudio/Diarizer/Core/DiarizerManager.swift @@ -474,7 +474,7 @@ public final class DiarizerManager { } private func calculateEmbeddingQuality(_ embedding: [Float]) -> Float { - let magnitude = sqrt(embedding.map { $0 * $0 }.reduce(0, +)) + let magnitude = sqrt(vDSP.sumOfSquares(embedding)) return min(1.0, magnitude / 10.0) } diff --git a/Tests/FluidAudioTests/SpeakerManagerTests.swift b/Tests/FluidAudioTests/SpeakerManagerTests.swift index b8c0ab350..84986f620 100644 --- a/Tests/FluidAudioTests/SpeakerManagerTests.swift +++ b/Tests/FluidAudioTests/SpeakerManagerTests.swift @@ -15,6 +15,10 @@ final class SpeakerManagerTests: XCTestCase { return embedding } + private func normalizedEmbedding(pattern: Int) -> [Float] { + VDSPOperations.l2Normalize(createDistinctEmbedding(pattern: pattern)) + } + // MARK: - Basic Operations func testInitialization() { @@ -119,6 +123,80 @@ final class SpeakerManagerTests: XCTestCase { XCTAssertEqual(assignedSpeaker?.id, "Alice") } + func testInitializeKnownSpeakersPreservesPermanentByDefault() { + let manager = SpeakerManager() + + let original = Speaker( + id: "Alice", + name: "Original", + 
currentEmbedding: createDistinctEmbedding(pattern: 10), + duration: 4.0 + ) + manager.initializeKnownSpeakers([original]) + manager.makeSpeakerPermanent("Alice") + + let replacement = Speaker( + id: "Alice", + name: "Replacement", + currentEmbedding: createDistinctEmbedding(pattern: 20), + duration: 8.0 + ) + + manager.initializeKnownSpeakers([replacement], mode: .overwrite, preserveIfPermanent: true) + + let stored = manager.getSpeaker(for: "Alice") + XCTAssertEqual(stored?.name, "Original") + XCTAssertEqual(stored?.duration, 4.0) + } + + func testInitializeKnownSpeakersOverwriteCanReplacePermanentWhenAllowed() { + let manager = SpeakerManager() + + let original = Speaker( + id: "Alice", + name: "Original", + currentEmbedding: createDistinctEmbedding(pattern: 10), + duration: 4.0, + isPermanent: true + ) + manager.initializeKnownSpeakers([original]) + + let replacement = Speaker( + id: "Alice", + name: "Replacement", + currentEmbedding: createDistinctEmbedding(pattern: 20), + duration: 10.0 + ) + + manager.initializeKnownSpeakers([replacement], mode: .overwrite, preserveIfPermanent: false) + + let stored = manager.getSpeaker(for: "Alice") + XCTAssertEqual(stored?.name, "Replacement") + XCTAssertEqual(stored?.duration, 10.0) + } + + func testInitializeKnownSpeakersMergeCombinesDurations() { + let manager = SpeakerManager() + + let base = Speaker( + id: "Alice", + name: "Alice", + currentEmbedding: createDistinctEmbedding(pattern: 10), + duration: 2.0 + ) + let incoming = Speaker( + id: "Alice", + name: "Alice", + currentEmbedding: createDistinctEmbedding(pattern: 11), + duration: 3.0 + ) + + manager.initializeKnownSpeakers([base]) + manager.initializeKnownSpeakers([incoming], mode: .merge) + + XCTAssertEqual(manager.getSpeaker(for: "Alice")?.duration, 5.0) + } + func testInvalidEmbeddingSize() { let manager = SpeakerManager() @@ -206,6 +284,51 @@ final class SpeakerManagerTests: XCTestCase { } } + // MARK: - Lookup Helpers + + func 
testFindSpeakerAndMatchingSpeakers() { + let manager = SpeakerManager(speakerThreshold: 0.8) + + manager.upsertSpeaker(id: "A", currentEmbedding: normalizedEmbedding(pattern: 1), duration: 5.0) + manager.upsertSpeaker(id: "B", currentEmbedding: normalizedEmbedding(pattern: 2), duration: 5.0) + + let (matchId, distance) = manager.findSpeaker(with: normalizedEmbedding(pattern: 1)) + XCTAssertEqual(matchId, "A") + XCTAssertEqual(distance, 0.0, accuracy: 0.0001) + + var orthogonalEmbedding0 = [Float](repeating: 0, count: 256) + var orthogonalEmbedding1 = [Float](repeating: 0, count: 256) + orthogonalEmbedding0[0] = 1 + orthogonalEmbedding1[1] = 1 + manager.upsertSpeaker(id: "C", currentEmbedding: orthogonalEmbedding0, duration: 5.0) + let (missingId, missingDistance) = manager.findSpeaker( + with: orthogonalEmbedding1, + speakerThreshold: 0.5 + ) + XCTAssertNil(missingId) + XCTAssertEqual(missingDistance, .infinity) + + let combined = zip(normalizedEmbedding(pattern: 1), normalizedEmbedding(pattern: 2)).map { ($0 + $1) / 2 } + let matches = manager.findMatchingSpeakers( + with: VDSPOperations.l2Normalize(combined), + speakerThreshold: 2.0 + ) + + XCTAssertEqual(matches.count, 3) + XCTAssertLessThanOrEqual(matches[0].distance, matches[1].distance) + XCTAssertEqual(Set(matches.map(\.id)), Set(["A", "B", "C"])) + } + + func testFindSpeakersWhereFiltersByPredicate() { + let manager = SpeakerManager() + manager.upsertSpeaker(id: "short", currentEmbedding: normalizedEmbedding(pattern: 10), duration: 1.0) + manager.upsertSpeaker(id: "long", currentEmbedding: normalizedEmbedding(pattern: 20), duration: 8.0) + + let filtered = manager.findSpeakers { $0.duration > 5.0 } + XCTAssertEqual(filtered.count, 1) + XCTAssertEqual(filtered.first, "long") + } + // MARK: - Clear Operations func testResetSpeakers() { @@ -426,6 +549,118 @@ final class SpeakerManagerTests: XCTestCase { XCTAssertEqual(info?.rawEmbeddings.count, 1) } + // MARK: - Permanence & Merge Operations + + func 
testMakeAndRevokePermanentSpeakers() throws { + let manager = SpeakerManager() + let speaker = manager.assignSpeaker(createDistinctEmbedding(pattern: 1), speechDuration: 2.5) + let id = try XCTUnwrap(speaker?.id) + + manager.makeSpeakerPermanent(id) + XCTAssertTrue(manager.permanentSpeakerIds.contains(id)) + + manager.removeSpeaker(id) + XCTAssertTrue(manager.hasSpeaker(id)) + + manager.revokePermanence(from: id) + manager.removeSpeaker(id) + XCTAssertFalse(manager.hasSpeaker(id)) + } + + func testMergeSpeakerRespectsPermanentFlag() throws { + let manager = SpeakerManager() + let speaker1 = manager.assignSpeaker(createDistinctEmbedding(pattern: 1), speechDuration: 3.0) + let speaker2 = manager.assignSpeaker(createDistinctEmbedding(pattern: 2), speechDuration: 4.0) + + let id1 = try XCTUnwrap(speaker1?.id) + let id2 = try XCTUnwrap(speaker2?.id) + + manager.makeSpeakerPermanent(id1) + manager.mergeSpeaker(id1, into: id2) + XCTAssertTrue(manager.hasSpeaker(id1)) + XCTAssertTrue(manager.hasSpeaker(id2)) + + manager.mergeSpeaker(id1, into: id2, mergedName: "Merged Speaker", stopIfPermanent: false) + XCTAssertFalse(manager.hasSpeaker(id1)) + let merged = try XCTUnwrap(manager.getSpeaker(for: id2)) + XCTAssertEqual(merged.name, "Merged Speaker") + XCTAssertEqual(manager.speakerCount, 1) + XCTAssertGreaterThan(merged.duration, 4.0) + } + + func testFindMergeablePairsRespectsPermanentExclusion() { + let manager = SpeakerManager(speakerThreshold: 0.3) + let base = normalizedEmbedding(pattern: 1) + var close = base + close[0] += 0.001 + close = VDSPOperations.l2Normalize(close) + let far = normalizedEmbedding(pattern: 80) + + manager.upsertSpeaker(id: "A", currentEmbedding: base, duration: 5.0) + manager.upsertSpeaker(id: "B", currentEmbedding: close, duration: 5.0) + manager.upsertSpeaker(id: "C", currentEmbedding: far, duration: 5.0) + + let pairs = manager.findMergeablePairs(speakerThreshold: 0.2) + XCTAssertEqual(pairs.count, 1) + 
XCTAssertEqual(Set([pairs[0].speakerToMerge, pairs[0].destination]), Set(["A", "B"])) + + manager.makeSpeakerPermanent("A") + manager.makeSpeakerPermanent("B") + + let filtered = manager.findMergeablePairs(speakerThreshold: 0.2, excludeIfBothPermanent: true) + XCTAssertTrue(filtered.isEmpty) + + let unfiltered = manager.findMergeablePairs(speakerThreshold: 0.2, excludeIfBothPermanent: false) + XCTAssertEqual(unfiltered.count, 1) + XCTAssertEqual(Set([unfiltered[0].speakerToMerge, unfiltered[0].destination]), Set(["A", "B"])) + } + + // MARK: - Removal & Reset + + func testRemoveSpeakersInactiveAndPredicateVariants() { + let manager = SpeakerManager() + let now = Date() + manager.upsertSpeaker( + id: "old", + currentEmbedding: normalizedEmbedding(pattern: 3), + duration: 2.0, + updatedAt: now.addingTimeInterval(-120) + ) + manager.upsertSpeaker( + id: "recent", + currentEmbedding: normalizedEmbedding(pattern: 4), + duration: 2.0, + updatedAt: now + ) + + manager.removeSpeakersInactive(since: now.addingTimeInterval(-60)) + XCTAssertFalse(manager.hasSpeaker("old")) + XCTAssertTrue(manager.hasSpeaker("recent")) + + manager.makeSpeakerPermanent("recent") + manager.removeSpeakers { $0.duration <= 2.0 } + XCTAssertTrue(manager.hasSpeaker("recent")) + + manager.removeSpeakers(where: { $0.duration <= 2.0 }, keepIfPermanent: false) + XCTAssertFalse(manager.hasSpeaker("recent")) + } + + func testResetKeepsPermanentSpeakers() throws { + let manager = SpeakerManager() + let speaker1 = manager.assignSpeaker(createDistinctEmbedding(pattern: 1), speechDuration: 2.0) + let speaker2 = manager.assignSpeaker(createDistinctEmbedding(pattern: 2), speechDuration: 2.0) + + let id1 = try XCTUnwrap(speaker1?.id) + let id2 = try XCTUnwrap(speaker2?.id) + + manager.makeSpeakerPermanent(id1) + manager.reset(keepIfPermanent: true) + + XCTAssertTrue(manager.hasSpeaker(id1)) + XCTAssertFalse(manager.hasSpeaker(id2)) + XCTAssertEqual(manager.speakerIds, [id1]) + } + // MARK: - Embedding Update 
Tests func testEmbeddingUpdateWithinAssignSpeaker() {