Skip to content

Commit a67a7da

Browse files
authored
fix(event_manager): normalize message.part.delta streams (#290)
1 parent ba94536 commit a67a7da

File tree

3 files changed

+233
-5
lines changed

3 files changed

+233
-5
lines changed

lua/opencode/event_manager.lua

Lines changed: 100 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,16 @@ local util = require('opencode.util')
2323
--- @field type "message.part.updated"
2424
--- @field properties {part: OpencodeMessagePart}
2525

26+
--- @class EventMessagePartDelta
27+
--- @field type "message.part.delta"
28+
--- @field properties {
29+
--- sessionID: string,
30+
--- messageID: string,
31+
--- partID: string,
32+
--- field: string,
33+
--- delta: string
34+
--- }
35+
2636
--- @class EventMessagePartRemoved
2737
--- @field type "message.part.removed"
2838
--- @field properties {sessionID: string, messageID: string, partID: string}
@@ -128,6 +138,7 @@ local util = require('opencode.util')
128138
--- | "message.updated"
129139
--- | "message.removed"
130140
--- | "message.part.updated"
141+
--- | "message.part.delta"
131142
--- | "message.part.removed"
132143
--- | "session.compacted"
133144
--- | "session.idle"
@@ -170,6 +181,7 @@ function EventManager.new()
170181
state_server_listener = nil,
171182
is_started = false,
172183
captured_events = {},
184+
_parts_by_id = {},
173185
}, EventManager)
174186

175187
local throttle_ms = config.ui.output.rendering.event_throttle_ms
@@ -186,6 +198,7 @@ end
186198
--- @overload fun(self: EventManager, event_name: "message.updated", callback: fun(data: EventMessageUpdated['properties']): nil)
187199
--- @overload fun(self: EventManager, event_name: "message.removed", callback: fun(data: EventMessageRemoved['properties']): nil)
188200
--- @overload fun(self: EventManager, event_name: "message.part.updated", callback: fun(data: EventMessagePartUpdated['properties']): nil)
201+
--- @overload fun(self: EventManager, event_name: "message.part.delta", callback: fun(data: EventMessagePartDelta['properties']): nil)
189202
--- @overload fun(self: EventManager, event_name: "message.part.removed", callback: fun(data: EventMessagePartRemoved['properties']): nil)
190203
--- @overload fun(self: EventManager, event_name: "session.compacted", callback: fun(data: EventSessionCompacted['properties']): nil)
191204
--- @overload fun(self: EventManager, event_name: "session.idle", callback: fun(data: EventSessionIdle['properties']): nil)
@@ -226,6 +239,7 @@ end
226239
--- @overload fun(self: EventManager, event_name: "message.updated", callback: fun(data: EventMessageUpdated['properties']): nil)
227240
--- @overload fun(self: EventManager, event_name: "message.removed", callback: fun(data: EventMessageRemoved['properties']): nil)
228241
--- @overload fun(self: EventManager, event_name: "message.part.updated", callback: fun(data: EventMessagePartUpdated['properties']): nil)
242+
--- @overload fun(self: EventManager, event_name: "message.part.delta", callback: fun(data: EventMessagePartDelta['properties']): nil)
229243
--- @overload fun(self: EventManager, event_name: "message.part.removed", callback: fun(data: EventMessagePartRemoved['properties']): nil)
230244
--- @overload fun(self: EventManager, event_name: "session.compacted", callback: fun(data: EventSessionCompacted['properties']): nil)
231245
--- @overload fun(self: EventManager, event_name: "session.idle", callback: fun(data: EventSessionIdle['properties']): nil)
@@ -260,15 +274,93 @@ function EventManager:unsubscribe(event_name, callback)
260274
end
261275
end
262276

277+
---Normalize message.part.delta events into message.part.updated events so
278+
---consumers can continue rendering full part payloads.
279+
---@param event table
280+
---@return table|nil
281+
function EventManager:_normalize_stream_event(event)
282+
if not event or not event.type then
283+
return nil
284+
end
285+
286+
local properties = event.properties or {}
287+
288+
if event.type == 'message.part.updated' and properties.part and properties.part.id then
289+
self._parts_by_id[properties.part.id] = vim.deepcopy(properties.part)
290+
return event
291+
end
292+
293+
if event.type == 'message.part.removed' and properties.partID then
294+
self._parts_by_id[properties.partID] = nil
295+
return event
296+
end
297+
298+
if event.type ~= 'message.part.delta' then
299+
return event
300+
end
301+
302+
local part_id = properties.partID
303+
local message_id = properties.messageID
304+
local session_id = properties.sessionID
305+
local field = properties.field
306+
307+
if not part_id or not message_id or not session_id or not field then
308+
return nil
309+
end
310+
311+
local part = vim.deepcopy(self._parts_by_id[part_id])
312+
if not part then
313+
part = {
314+
id = part_id,
315+
messageID = message_id,
316+
sessionID = session_id,
317+
}
318+
319+
if field == 'text' then
320+
part.type = 'text'
321+
part.text = ''
322+
end
323+
end
324+
325+
local delta = properties.delta
326+
local current = part[field]
327+
if type(delta) == 'string' then
328+
if type(current) == 'string' then
329+
part[field] = current .. delta
330+
else
331+
part[field] = delta
332+
end
333+
else
334+
part[field] = delta
335+
end
336+
337+
self._parts_by_id[part_id] = part
338+
339+
return {
340+
type = 'message.part.updated',
341+
properties = {
342+
part = part,
343+
},
344+
}
345+
end
346+
263347
---Callback from ThrottlingEmitter when the events are now ready to be processed.
264348
---Collapses parts that are duplicated, making sure to replace earlier parts with later
265349
---ones (but keeping the earlier position)
266350
---@param events any
267351
function EventManager:_on_drained_events(events)
268352
self:emit('custom.emit_events.started', {})
269353

354+
local normalized_events = {}
355+
for _, event in ipairs(events) do
356+
local normalized_event = self:_normalize_stream_event(event)
357+
if normalized_event then
358+
table.insert(normalized_events, normalized_event)
359+
end
360+
end
361+
270362
if not config.ui.output.rendering.event_collapsing then
271-
for _, event in ipairs(events) do
363+
for _, event in ipairs(normalized_events) do
272364
self:emit(event.type, event.properties)
273365
end
274366
self:emit('custom.emit_events.finished', {})
@@ -278,7 +370,7 @@ function EventManager:_on_drained_events(events)
278370
local collapsed_events = {}
279371
local part_update_indices = {}
280372

281-
for i, event in ipairs(events) do
373+
for i, event in ipairs(normalized_events) do
282374
if event.type == 'message.part.updated' and event.properties.part then
283375
local part_id = event.properties.part.id
284376
if part_update_indices[part_id] then
@@ -289,7 +381,10 @@ function EventManager:_on_drained_events(events)
289381
-- permission.updated/permission.asked sits between the two updates.
290382
local has_intervening_permission_event = false
291383
for j = previous_index + 1, i - 1 do
292-
if events[j] and (events[j].type == 'permission.updated' or events[j].type == 'permission.asked') then
384+
if normalized_events[j] and (
385+
normalized_events[j].type == 'permission.updated'
386+
or normalized_events[j].type == 'permission.asked'
387+
) then
293388
has_intervening_permission_event = true
294389
break
295390
end
@@ -312,7 +407,7 @@ function EventManager:_on_drained_events(events)
312407
end
313408
end
314409

315-
for i = 1, #events do
410+
for i = 1, #normalized_events do
316411
local event = collapsed_events[i]
317412
if event then
318413
self:emit(event.type, event.properties)
@@ -404,6 +499,7 @@ function EventManager:stop()
404499
self:_cleanup_server_subscription()
405500

406501
self.throttling_emitter:clear()
502+
self._parts_by_id = {}
407503
self.events = {}
408504
end
409505

tests/helpers.lua

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ end
174174

175175
function M.load_session_from_events(events)
176176
local session_data = {}
177+
local parts_by_id = {}
177178

178179
for _, event in ipairs(events) do
179180
local properties = event.properties
@@ -210,12 +211,53 @@ function M.load_session_from_events(events)
210211

211212
if existing_part then
212213
msg.parts[existing_part] = vim.deepcopy(part)
214+
parts_by_id[part.id] = msg.parts[existing_part]
213215
else
214216
table.insert(msg.parts, vim.deepcopy(part))
217+
parts_by_id[part.id] = msg.parts[#msg.parts]
215218
end
216219
break
217220
end
218221
end
222+
elseif event.type == 'message.part.delta'
223+
and properties.partID
224+
and properties.messageID
225+
and properties.field then
226+
local part = parts_by_id[properties.partID]
227+
228+
if not part then
229+
for _, msg in ipairs(session_data) do
230+
if msg.info.id == properties.messageID then
231+
part = {
232+
id = properties.partID,
233+
messageID = properties.messageID,
234+
sessionID = properties.sessionID,
235+
type = properties.field == 'text' and 'text' or nil,
236+
}
237+
if properties.field == 'text' then
238+
part.text = ''
239+
end
240+
table.insert(msg.parts, part)
241+
parts_by_id[properties.partID] = part
242+
break
243+
end
244+
end
245+
end
246+
247+
if part then
248+
local field = properties.field
249+
local delta = properties.delta
250+
if type(delta) == 'string' then
251+
local current = part[field]
252+
if type(current) == 'string' then
253+
part[field] = current .. delta
254+
else
255+
part[field] = delta
256+
end
257+
else
258+
part[field] = delta
259+
end
260+
end
219261
end
220262
end
221263

@@ -247,7 +289,9 @@ function M.get_session_from_events(events, with_session_updates)
247289
for _, event in ipairs(events) do
248290
-- find the session id in a message or part event
249291
local properties = event.properties
250-
local session_id = properties.info and properties.info.sessionID or properties.part and properties.part.sessionID
292+
local session_id = properties.info and properties.info.sessionID
293+
or properties.part and properties.part.sessionID
294+
or properties.sessionID
251295

252296
if session_id then
253297
---@diagnostic disable-next-line: missing-fields

tests/unit/event_manager_spec.lua

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
local EventManager = require('opencode.event_manager')
22
local Promise = require('opencode.promise')
33
local state = require('opencode.state')
4+
local config = require('opencode.config')
45

56
describe('EventManager', function()
67
local event_manager
@@ -189,6 +190,93 @@ describe('EventManager', function()
189190
vim.defer_fn = original_defer_fn
190191
end)
191192

193+
it('normalizes message.part.delta into message.part.updated', function()
194+
local original_event_collapsing = config.ui.output.rendering.event_collapsing
195+
config.ui.output.rendering.event_collapsing = true
196+
197+
local received = {}
198+
event_manager:subscribe('message.part.updated', function(data)
199+
table.insert(received, vim.deepcopy(data.part))
200+
end)
201+
202+
event_manager:_on_drained_events({
203+
{
204+
type = 'message.part.updated',
205+
properties = {
206+
part = {
207+
id = 'part_1',
208+
messageID = 'msg_1',
209+
sessionID = 'ses_1',
210+
type = 'text',
211+
text = '',
212+
},
213+
},
214+
},
215+
{
216+
type = 'message.part.delta',
217+
properties = {
218+
partID = 'part_1',
219+
messageID = 'msg_1',
220+
sessionID = 'ses_1',
221+
field = 'text',
222+
delta = 'hello',
223+
},
224+
},
225+
{
226+
type = 'message.part.delta',
227+
properties = {
228+
partID = 'part_1',
229+
messageID = 'msg_1',
230+
sessionID = 'ses_1',
231+
field = 'text',
232+
delta = ' world',
233+
},
234+
},
235+
})
236+
237+
config.ui.output.rendering.event_collapsing = original_event_collapsing
238+
239+
assert.are.equal(1, #received)
240+
assert.are.equal('hello world', received[1].text)
241+
end)
242+
243+
it('keeps accumulated delta text across event batches', function()
244+
local received = {}
245+
event_manager:subscribe('message.part.updated', function(data)
246+
table.insert(received, vim.deepcopy(data.part))
247+
end)
248+
249+
event_manager:_on_drained_events({
250+
{
251+
type = 'message.part.updated',
252+
properties = {
253+
part = {
254+
id = 'part_2',
255+
messageID = 'msg_2',
256+
sessionID = 'ses_2',
257+
type = 'text',
258+
text = '',
259+
},
260+
},
261+
},
262+
})
263+
264+
event_manager:_on_drained_events({
265+
{
266+
type = 'message.part.delta',
267+
properties = {
268+
partID = 'part_2',
269+
messageID = 'msg_2',
270+
sessionID = 'ses_2',
271+
field = 'text',
272+
delta = 'abc',
273+
},
274+
},
275+
})
276+
277+
assert.are.equal('abc', received[#received].text)
278+
end)
279+
192280
describe('User autocmd events', function()
193281
it('should fire User autocmd when emitting events', function()
194282
local autocmd_called = false

0 commit comments

Comments
 (0)