Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 100 additions & 4 deletions lua/opencode/event_manager.lua
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,16 @@ local util = require('opencode.util')
--- @field type "message.part.updated"
--- @field properties {part: OpencodeMessagePart}

--- @class EventMessagePartDelta
--- @field type "message.part.delta"
--- @field properties {
--- sessionID: string,
--- messageID: string,
--- partID: string,
--- field: string,
--- delta: string
--- }

--- @class EventMessagePartRemoved
--- @field type "message.part.removed"
--- @field properties {sessionID: string, messageID: string, partID: string}
Expand Down Expand Up @@ -128,6 +138,7 @@ local util = require('opencode.util')
--- | "message.updated"
--- | "message.removed"
--- | "message.part.updated"
--- | "message.part.delta"
--- | "message.part.removed"
--- | "session.compacted"
--- | "session.idle"
Expand Down Expand Up @@ -170,6 +181,7 @@ function EventManager.new()
state_server_listener = nil,
is_started = false,
captured_events = {},
_parts_by_id = {},
}, EventManager)

local throttle_ms = config.ui.output.rendering.event_throttle_ms
Expand All @@ -186,6 +198,7 @@ end
--- @overload fun(self: EventManager, event_name: "message.updated", callback: fun(data: EventMessageUpdated['properties']): nil)
--- @overload fun(self: EventManager, event_name: "message.removed", callback: fun(data: EventMessageRemoved['properties']): nil)
--- @overload fun(self: EventManager, event_name: "message.part.updated", callback: fun(data: EventMessagePartUpdated['properties']): nil)
--- @overload fun(self: EventManager, event_name: "message.part.delta", callback: fun(data: EventMessagePartDelta['properties']): nil)
--- @overload fun(self: EventManager, event_name: "message.part.removed", callback: fun(data: EventMessagePartRemoved['properties']): nil)
--- @overload fun(self: EventManager, event_name: "session.compacted", callback: fun(data: EventSessionCompacted['properties']): nil)
--- @overload fun(self: EventManager, event_name: "session.idle", callback: fun(data: EventSessionIdle['properties']): nil)
Expand Down Expand Up @@ -226,6 +239,7 @@ end
--- @overload fun(self: EventManager, event_name: "message.updated", callback: fun(data: EventMessageUpdated['properties']): nil)
--- @overload fun(self: EventManager, event_name: "message.removed", callback: fun(data: EventMessageRemoved['properties']): nil)
--- @overload fun(self: EventManager, event_name: "message.part.updated", callback: fun(data: EventMessagePartUpdated['properties']): nil)
--- @overload fun(self: EventManager, event_name: "message.part.delta", callback: fun(data: EventMessagePartDelta['properties']): nil)
--- @overload fun(self: EventManager, event_name: "message.part.removed", callback: fun(data: EventMessagePartRemoved['properties']): nil)
--- @overload fun(self: EventManager, event_name: "session.compacted", callback: fun(data: EventSessionCompacted['properties']): nil)
--- @overload fun(self: EventManager, event_name: "session.idle", callback: fun(data: EventSessionIdle['properties']): nil)
Expand Down Expand Up @@ -260,15 +274,93 @@ function EventManager:unsubscribe(event_name, callback)
end
end

---Normalize message.part.delta events into message.part.updated events so
---consumers can continue rendering full part payloads.
---@param event table
---@return table|nil
function EventManager:_normalize_stream_event(event)
if not event or not event.type then
return nil
end

local properties = event.properties or {}

if event.type == 'message.part.updated' and properties.part and properties.part.id then
self._parts_by_id[properties.part.id] = vim.deepcopy(properties.part)
return event
end

if event.type == 'message.part.removed' and properties.partID then
self._parts_by_id[properties.partID] = nil
return event
end

if event.type ~= 'message.part.delta' then
return event
end

local part_id = properties.partID
local message_id = properties.messageID
local session_id = properties.sessionID
local field = properties.field

if not part_id or not message_id or not session_id or not field then
return nil
end

local part = vim.deepcopy(self._parts_by_id[part_id])
if not part then
part = {
id = part_id,
messageID = message_id,
sessionID = session_id,
}

if field == 'text' then
part.type = 'text'
part.text = ''
end
end

local delta = properties.delta
local current = part[field]
if type(delta) == 'string' then
if type(current) == 'string' then
part[field] = current .. delta
else
part[field] = delta
end
else
part[field] = delta
end

self._parts_by_id[part_id] = part

return {
type = 'message.part.updated',
properties = {
part = part,
},
}
end

---Callback from ThrottlingEmitter when the events are now ready to be processed.
---Collapses parts that are duplicated, making sure to replace earlier parts with later
---ones (but keeping the earlier position)
---@param events any
function EventManager:_on_drained_events(events)
self:emit('custom.emit_events.started', {})

local normalized_events = {}
for _, event in ipairs(events) do
local normalized_event = self:_normalize_stream_event(event)
if normalized_event then
table.insert(normalized_events, normalized_event)
end
end

if not config.ui.output.rendering.event_collapsing then
for _, event in ipairs(events) do
for _, event in ipairs(normalized_events) do
self:emit(event.type, event.properties)
end
self:emit('custom.emit_events.finished', {})
Expand All @@ -278,7 +370,7 @@ function EventManager:_on_drained_events(events)
local collapsed_events = {}
local part_update_indices = {}

for i, event in ipairs(events) do
for i, event in ipairs(normalized_events) do
if event.type == 'message.part.updated' and event.properties.part then
local part_id = event.properties.part.id
if part_update_indices[part_id] then
Expand All @@ -289,7 +381,10 @@ function EventManager:_on_drained_events(events)
-- permission.updated/permission.asked sits between the two updates.
local has_intervening_permission_event = false
for j = previous_index + 1, i - 1 do
if events[j] and (events[j].type == 'permission.updated' or events[j].type == 'permission.asked') then
if normalized_events[j] and (
normalized_events[j].type == 'permission.updated'
or normalized_events[j].type == 'permission.asked'
) then
has_intervening_permission_event = true
break
end
Expand All @@ -312,7 +407,7 @@ function EventManager:_on_drained_events(events)
end
end

for i = 1, #events do
for i = 1, #normalized_events do
local event = collapsed_events[i]
if event then
self:emit(event.type, event.properties)
Expand Down Expand Up @@ -404,6 +499,7 @@ function EventManager:stop()
self:_cleanup_server_subscription()

self.throttling_emitter:clear()
self._parts_by_id = {}
self.events = {}
end

Expand Down
46 changes: 45 additions & 1 deletion tests/helpers.lua
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ end

function M.load_session_from_events(events)
local session_data = {}
local parts_by_id = {}

for _, event in ipairs(events) do
local properties = event.properties
Expand Down Expand Up @@ -210,12 +211,53 @@ function M.load_session_from_events(events)

if existing_part then
msg.parts[existing_part] = vim.deepcopy(part)
parts_by_id[part.id] = msg.parts[existing_part]
else
table.insert(msg.parts, vim.deepcopy(part))
parts_by_id[part.id] = msg.parts[#msg.parts]
end
break
end
end
elseif event.type == 'message.part.delta'
and properties.partID
and properties.messageID
and properties.field then
local part = parts_by_id[properties.partID]

if not part then
for _, msg in ipairs(session_data) do
if msg.info.id == properties.messageID then
part = {
id = properties.partID,
messageID = properties.messageID,
sessionID = properties.sessionID,
type = properties.field == 'text' and 'text' or nil,
}
if properties.field == 'text' then
part.text = ''
end
table.insert(msg.parts, part)
parts_by_id[properties.partID] = part
break
end
end
end

if part then
local field = properties.field
local delta = properties.delta
if type(delta) == 'string' then
local current = part[field]
if type(current) == 'string' then
part[field] = current .. delta
else
part[field] = delta
end
else
part[field] = delta
end
end
end
end

Expand Down Expand Up @@ -247,7 +289,9 @@ function M.get_session_from_events(events, with_session_updates)
for _, event in ipairs(events) do
-- find the session id in a message or part event
local properties = event.properties
local session_id = properties.info and properties.info.sessionID or properties.part and properties.part.sessionID
local session_id = properties.info and properties.info.sessionID
or properties.part and properties.part.sessionID
or properties.sessionID

if session_id then
---@diagnostic disable-next-line: missing-fields
Expand Down
88 changes: 88 additions & 0 deletions tests/unit/event_manager_spec.lua
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
local EventManager = require('opencode.event_manager')
local Promise = require('opencode.promise')
local state = require('opencode.state')
local config = require('opencode.config')

describe('EventManager', function()
local event_manager
Expand Down Expand Up @@ -189,6 +190,93 @@ describe('EventManager', function()
vim.defer_fn = original_defer_fn
end)

it('normalizes message.part.delta into message.part.updated', function()
local original_event_collapsing = config.ui.output.rendering.event_collapsing
config.ui.output.rendering.event_collapsing = true

local received = {}
event_manager:subscribe('message.part.updated', function(data)
table.insert(received, vim.deepcopy(data.part))
end)

event_manager:_on_drained_events({
{
type = 'message.part.updated',
properties = {
part = {
id = 'part_1',
messageID = 'msg_1',
sessionID = 'ses_1',
type = 'text',
text = '',
},
},
},
{
type = 'message.part.delta',
properties = {
partID = 'part_1',
messageID = 'msg_1',
sessionID = 'ses_1',
field = 'text',
delta = 'hello',
},
},
{
type = 'message.part.delta',
properties = {
partID = 'part_1',
messageID = 'msg_1',
sessionID = 'ses_1',
field = 'text',
delta = ' world',
},
},
})

config.ui.output.rendering.event_collapsing = original_event_collapsing

assert.are.equal(1, #received)
assert.are.equal('hello world', received[1].text)
end)

it('keeps accumulated delta text across event batches', function()
local received = {}
event_manager:subscribe('message.part.updated', function(data)
table.insert(received, vim.deepcopy(data.part))
end)

event_manager:_on_drained_events({
{
type = 'message.part.updated',
properties = {
part = {
id = 'part_2',
messageID = 'msg_2',
sessionID = 'ses_2',
type = 'text',
text = '',
},
},
},
})

event_manager:_on_drained_events({
{
type = 'message.part.delta',
properties = {
partID = 'part_2',
messageID = 'msg_2',
sessionID = 'ses_2',
field = 'text',
delta = 'abc',
},
},
})

assert.are.equal('abc', received[#received].text)
end)

describe('User autocmd events', function()
it('should fire User autocmd when emitting events', function()
local autocmd_called = false
Expand Down