diff --git a/lua/opencode/event_manager.lua b/lua/opencode/event_manager.lua index 7eeeb4c5..f42777d8 100644 --- a/lua/opencode/event_manager.lua +++ b/lua/opencode/event_manager.lua @@ -23,6 +23,16 @@ local util = require('opencode.util') --- @field type "message.part.updated" --- @field properties {part: OpencodeMessagePart} +--- @class EventMessagePartDelta +--- @field type "message.part.delta" +--- @field properties { +--- sessionID: string, +--- messageID: string, +--- partID: string, +--- field: string, +--- delta: string +--- } + --- @class EventMessagePartRemoved --- @field type "message.part.removed" --- @field properties {sessionID: string, messageID: string, partID: string} @@ -128,6 +138,7 @@ local util = require('opencode.util') --- | "message.updated" --- | "message.removed" --- | "message.part.updated" +--- | "message.part.delta" --- | "message.part.removed" --- | "session.compacted" --- | "session.idle" @@ -170,6 +181,7 @@ function EventManager.new() state_server_listener = nil, is_started = false, captured_events = {}, + _parts_by_id = {}, }, EventManager) local throttle_ms = config.ui.output.rendering.event_throttle_ms @@ -186,6 +198,7 @@ end --- @overload fun(self: EventManager, event_name: "message.updated", callback: fun(data: EventMessageUpdated['properties']): nil) --- @overload fun(self: EventManager, event_name: "message.removed", callback: fun(data: EventMessageRemoved['properties']): nil) --- @overload fun(self: EventManager, event_name: "message.part.updated", callback: fun(data: EventMessagePartUpdated['properties']): nil) +--- @overload fun(self: EventManager, event_name: "message.part.delta", callback: fun(data: EventMessagePartDelta['properties']): nil) --- @overload fun(self: EventManager, event_name: "message.part.removed", callback: fun(data: EventMessagePartRemoved['properties']): nil) --- @overload fun(self: EventManager, event_name: "session.compacted", callback: fun(data: EventSessionCompacted['properties']): nil) --- @overload fun(self: EventManager, event_name: "session.idle", callback: fun(data: EventSessionIdle['properties']): nil) @@ -226,6 +239,7 @@ end --- @overload fun(self: EventManager, event_name: "message.updated", callback: fun(data: EventMessageUpdated['properties']): nil) --- @overload fun(self: EventManager, event_name: "message.removed", callback: fun(data: EventMessageRemoved['properties']): nil) --- @overload fun(self: EventManager, event_name: "message.part.updated", callback: fun(data: EventMessagePartUpdated['properties']): nil) +--- @overload fun(self: EventManager, event_name: "message.part.delta", callback: fun(data: EventMessagePartDelta['properties']): nil) --- @overload fun(self: EventManager, event_name: "message.part.removed", callback: fun(data: EventMessagePartRemoved['properties']): nil) --- @overload fun(self: EventManager, event_name: "session.compacted", callback: fun(data: EventSessionCompacted['properties']): nil) --- @overload fun(self: EventManager, event_name: "session.idle", callback: fun(data: EventSessionIdle['properties']): nil) @@ -260,6 +274,76 @@ function EventManager:unsubscribe(event_name, callback) end end +---Normalize message.part.delta events into message.part.updated events so +---consumers can continue rendering full part payloads. +---@param event table +---@return table|nil +function EventManager:_normalize_stream_event(event) + if not event or not event.type then + return nil + end + + local properties = event.properties or {} + + if event.type == 'message.part.updated' and properties.part and properties.part.id then + self._parts_by_id[properties.part.id] = vim.deepcopy(properties.part) + return event + end + + if event.type == 'message.part.removed' and properties.partID then + self._parts_by_id[properties.partID] = nil + return event + end + + if event.type ~= 'message.part.delta' then + return event + end + + local part_id = properties.partID + local message_id = properties.messageID + local session_id = properties.sessionID + local field = properties.field + + if not part_id or not message_id or not session_id or not field then + return nil + end + + local part = vim.deepcopy(self._parts_by_id[part_id]) + if not part then + part = { + id = part_id, + messageID = message_id, + sessionID = session_id, + } + + if field == 'text' then + part.type = 'text' + part.text = '' + end + end + + local delta = properties.delta + local current = part[field] + if type(delta) == 'string' then + if type(current) == 'string' then + part[field] = current .. delta + else + part[field] = delta + end + else + part[field] = delta + end + + self._parts_by_id[part_id] = part + + return { + type = 'message.part.updated', + properties = { + part = part, + }, + } +end + ---Callback from ThrottlingEmitter when the events are now ready to be processed. ---Collapses parts that are duplicated, making sure to replace earlier parts with later ---ones (but keeping the earlier position) @@ -267,8 +351,16 @@ end function EventManager:_on_drained_events(events) self:emit('custom.emit_events.started', {}) + local normalized_events = {} + for _, event in ipairs(events) do + local normalized_event = self:_normalize_stream_event(event) + if normalized_event then + table.insert(normalized_events, normalized_event) + end + end + if not config.ui.output.rendering.event_collapsing then - for _, event in ipairs(events) do + for _, event in ipairs(normalized_events) do self:emit(event.type, event.properties) end self:emit('custom.emit_events.finished', {}) @@ -278,7 +370,7 @@ function EventManager:_on_drained_events(events) local collapsed_events = {} local part_update_indices = {} - for i, event in ipairs(events) do + for i, event in ipairs(normalized_events) do if event.type == 'message.part.updated' and event.properties.part then local part_id = event.properties.part.id if part_update_indices[part_id] then @@ -289,7 +381,10 @@ function EventManager:_on_drained_events(events) -- permission.updated/permission.asked sits between the two updates. local has_intervening_permission_event = false for j = previous_index + 1, i - 1 do - if events[j] and (events[j].type == 'permission.updated' or events[j].type == 'permission.asked') then + if normalized_events[j] and ( + normalized_events[j].type == 'permission.updated' + or normalized_events[j].type == 'permission.asked' + ) then has_intervening_permission_event = true break end @@ -312,7 +407,7 @@ function EventManager:_on_drained_events(events) end end - for i = 1, #events do + for i = 1, #normalized_events do local event = collapsed_events[i] if event then self:emit(event.type, event.properties) @@ -404,6 +499,7 @@ function EventManager:stop() self:_cleanup_server_subscription() self.throttling_emitter:clear() + self._parts_by_id = {} self.events = {} end diff --git a/tests/helpers.lua b/tests/helpers.lua index 189741b9..53d15189 100644 --- a/tests/helpers.lua +++ b/tests/helpers.lua @@ -174,6 +174,7 @@ end function M.load_session_from_events(events) local session_data = {} + local parts_by_id = {} for _, event in ipairs(events) do local properties = event.properties @@ -210,12 +211,53 @@ function M.load_session_from_events(events) if existing_part then msg.parts[existing_part] = vim.deepcopy(part) + parts_by_id[part.id] = msg.parts[existing_part] else table.insert(msg.parts, vim.deepcopy(part)) + parts_by_id[part.id] = msg.parts[#msg.parts] end break end end + elseif event.type == 'message.part.delta' + and properties.partID + and properties.messageID + and properties.field then + local part = parts_by_id[properties.partID] + + if not part then + for _, msg in ipairs(session_data) do + if msg.info.id == properties.messageID then + part = { + id = properties.partID, + messageID = properties.messageID, + sessionID = properties.sessionID, + type = properties.field == 'text' and 'text' or nil, + } + if properties.field == 'text' then + part.text = '' + end + table.insert(msg.parts, part) + parts_by_id[properties.partID] = part + break + end + end + end + + if part then + local field = properties.field + local delta = properties.delta + if type(delta) == 'string' then + local current = part[field] + if type(current) == 'string' then + part[field] = current .. delta + else + part[field] = delta + end + else + part[field] = delta + end + end end end @@ -247,7 +289,9 @@ function M.get_session_from_events(events, with_session_updates) for _, event in ipairs(events) do -- find the session id in a message or part event local properties = event.properties - local session_id = properties.info and properties.info.sessionID or properties.part and properties.part.sessionID + local session_id = properties.info and properties.info.sessionID + or properties.part and properties.part.sessionID + or properties.sessionID if session_id then ---@diagnostic disable-next-line: missing-fields diff --git a/tests/unit/event_manager_spec.lua b/tests/unit/event_manager_spec.lua index 01915f87..3eb79074 100644 --- a/tests/unit/event_manager_spec.lua +++ b/tests/unit/event_manager_spec.lua @@ -1,6 +1,7 @@ local EventManager = require('opencode.event_manager') local Promise = require('opencode.promise') local state = require('opencode.state') +local config = require('opencode.config') describe('EventManager', function() local event_manager @@ -189,6 +190,93 @@ describe('EventManager', function() vim.defer_fn = original_defer_fn end) + it('normalizes message.part.delta into message.part.updated', function() + local original_event_collapsing = config.ui.output.rendering.event_collapsing + config.ui.output.rendering.event_collapsing = true + + local received = {} + event_manager:subscribe('message.part.updated', function(data) + table.insert(received, vim.deepcopy(data.part)) + end) + + event_manager:_on_drained_events({ + { + type = 'message.part.updated', + properties = { + part = { + id = 'part_1', + messageID = 'msg_1', + sessionID = 'ses_1', + type = 'text', + text = '', + }, + }, + }, + { + type = 'message.part.delta', + properties = { + partID = 'part_1', + messageID = 'msg_1', + sessionID = 'ses_1', + field = 'text', + delta = 'hello', + }, + }, + { + type = 'message.part.delta', + properties = { + partID = 'part_1', + messageID = 'msg_1', + sessionID = 'ses_1', + field = 'text', + delta = ' world', + }, + }, + }) + + config.ui.output.rendering.event_collapsing = original_event_collapsing + + assert.are.equal(1, #received) + assert.are.equal('hello world', received[1].text) + end) + + it('keeps accumulated delta text across event batches', function() + local received = {} + event_manager:subscribe('message.part.updated', function(data) + table.insert(received, vim.deepcopy(data.part)) + end) + + event_manager:_on_drained_events({ + { + type = 'message.part.updated', + properties = { + part = { + id = 'part_2', + messageID = 'msg_2', + sessionID = 'ses_2', + type = 'text', + text = '', + }, + }, + }, + }) + + event_manager:_on_drained_events({ + { + type = 'message.part.delta', + properties = { + partID = 'part_2', + messageID = 'msg_2', + sessionID = 'ses_2', + field = 'text', + delta = 'abc', + }, + }, + }) + + assert.are.equal('abc', received[#received].text) + end) + describe('User autocmd events', function() it('should fire User autocmd when emitting events', function() local autocmd_called = false