From 8eebae5746074c55c5443d831385a70861a9f304 Mon Sep 17 00:00:00 2001 From: Morten Hersson Date: Fri, 15 May 2026 06:49:04 +0200 Subject: [PATCH 1/6] feat(chat): SQLite-backed session manager with SSE - Schema migrations and monotonic seq with UNIQUE (session_id, seq) - Per-session append mutex so unrelated sessions skip the global lock - Singleflight cold-open lets distinct sessions call StartChat in parallel - setRehydrationActive store + cache writes share one critical section - SSE hub with per-session subscriber cap and reattach - Idle reaper transitions warm-idle to cold under quota --- .gitignore | 2 + config.yaml.example | 45 + go.mod | 9 + go.sum | 45 + internal/chat/doc.go | 15 + internal/chat/export_test.go | 29 + internal/chat/failing_store_test.go | 374 ++++ internal/chat/manager.go | 1122 ++++++++++++ internal/chat/manager_test.go | 1713 +++++++++++++++++++ internal/chat/reaper.go | 101 ++ internal/chat/reaper_export_test.go | 8 + internal/chat/reaper_test.go | 228 +++ internal/chat/runner.go | 247 +++ internal/chat/runner_test.go | 102 ++ internal/chat/sqlite/migrations.go | 211 +++ internal/chat/sqlite/migrations_test.go | 220 +++ internal/chat/sqlite/store.go | 399 +++++ internal/chat/sqlite/store_test.go | 154 ++ internal/chat/sse.go | 238 +++ internal/chat/sse_internal_test.go | 243 +++ internal/chat/sse_test.go | 146 ++ internal/chat/store.go | 64 + internal/chat/transcript/transcript.go | 325 ++++ internal/chat/transcript/transcript_test.go | 253 +++ internal/chat/types.go | 129 ++ internal/chat/types_test.go | 41 + internal/config/config.go | 152 ++ internal/config/config_test.go | 153 ++ 28 files changed, 6768 insertions(+) create mode 100644 internal/chat/doc.go create mode 100644 internal/chat/export_test.go create mode 100644 internal/chat/failing_store_test.go create mode 100644 internal/chat/manager.go create mode 100644 internal/chat/manager_test.go create mode 100644 internal/chat/reaper.go create mode 100644 
internal/chat/reaper_export_test.go create mode 100644 internal/chat/reaper_test.go create mode 100644 internal/chat/runner.go create mode 100644 internal/chat/runner_test.go create mode 100644 internal/chat/sqlite/migrations.go create mode 100644 internal/chat/sqlite/migrations_test.go create mode 100644 internal/chat/sqlite/store.go create mode 100644 internal/chat/sqlite/store_test.go create mode 100644 internal/chat/sse.go create mode 100644 internal/chat/sse_internal_test.go create mode 100644 internal/chat/sse_test.go create mode 100644 internal/chat/store.go create mode 100644 internal/chat/transcript/transcript.go create mode 100644 internal/chat/transcript/transcript_test.go create mode 100644 internal/chat/types.go create mode 100644 internal/chat/types_test.go diff --git a/.gitignore b/.gitignore index c758107a..bb956cbf 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,5 @@ boards/ config.yaml test/integration/stub-worker/stub-worker +.superpowers/ +.worktrees diff --git a/config.yaml.example b/config.yaml.example index 1651a33a..7bbe7cf9 100644 --- a/config.yaml.example +++ b/config.yaml.example @@ -218,6 +218,51 @@ runner: # Env: CONTEXTMATRIX_RUNNER_RECONCILE_INTERVAL reconcile_interval: "60s" +# Chat (global chat panel) +chat: + # SQLite database for chat sessions and transcripts. + # CONTEXTMATRIX_CHAT_DB_PATH overrides this. + # Default: $XDG_STATE_HOME/contextmatrix/chats.db + # db_path: /var/lib/contextmatrix/chats.db + + # How long a chat container survives after browser disconnects. + # CONTEXTMATRIX_CHAT_IDLE_TTL overrides this. + idle_ttl: 1h + + # Maximum concurrent chat containers. Default is 8 — enough headroom for + # the multi-pane chat UI's 4 user-facing panes plus 2-4 agent-owned + # background sessions. + # CONTEXTMATRIX_CHAT_MAX_CONCURRENT overrides this. + max_concurrent: 8 + + # Claude model used when a chat is created without an explicit + # selection in the New Chat dialog. Must be a key in chat.models. 
+ default_model: claude-sonnet-4-6 + + # Rough token budget for the rehydration payload sent to the runner on + # cold-reopen. Older transcript turns are dropped (first user turn and + # last 20 turns are always preserved) until the estimate fits. + resume_budget_tokens: 40000 + + # Force the rehydration phase off after this duration even if the + # agent never called chat_rehydration_complete. The first user message + # also ends the phase, so this is a belt-and-suspenders cap. + rehydration_timeout: 10m + + # Allowlist of selectable models for new chats. The label is shown in + # the picker; max_tokens drives the context-window denominator in the + # ChatThread header indicator. Adding a new model is a single edit. + models: + claude-sonnet-4-6: + label: "Sonnet 4.6" + max_tokens: 1000000 + claude-opus-4-7: + label: "Opus 4.7" + max_tokens: 1000000 + claude-haiku-4-5-20251001: + label: "Haiku 4.5" + max_tokens: 200000 + # GitHub authentication and integration. # Used for boards git, task-skills git, issue importing, and branch listing. 
github: diff --git a/go.mod b/go.mod index f83eba7a..4f440300 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,9 @@ require ( github.com/modelcontextprotocol/go-sdk v1.6.0 github.com/prometheus/client_golang v1.23.2 github.com/prometheus/client_model v0.6.2 + golang.org/x/sync v0.20.0 gopkg.in/yaml.v3 v3.0.1 + modernc.org/sqlite v1.50.1 ) require ( @@ -20,6 +22,7 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudflare/circl v1.6.3 // indirect github.com/cyphar/filepath-securejoin v0.6.1 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect github.com/emirpasic/gods v1.18.1 // indirect github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect github.com/go-git/go-billy/v5 v5.9.0 // indirect @@ -30,10 +33,13 @@ require ( github.com/kevinburke/ssh_config v1.2.0 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/kylelemons/godebug v1.1.0 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/ncruces/go-strftime v1.0.0 // indirect github.com/pjbgf/sha1cd v0.6.0 // indirect github.com/prometheus/common v0.66.1 // indirect github.com/prometheus/procfs v0.16.1 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/segmentio/asm v1.1.3 // indirect github.com/segmentio/encoding v0.5.4 // indirect github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect @@ -47,6 +53,9 @@ require ( golang.org/x/sys v0.43.0 // indirect google.golang.org/protobuf v1.36.8 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect + modernc.org/libc v1.72.3 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.11.0 // indirect ) require ( diff --git a/go.sum b/go.sum index 33a26591..ee7fbacf 100644 --- a/go.sum +++ b/go.sum @@ -20,6 +20,8 @@ github.com/cyphar/filepath-securejoin v0.6.1/go.mod h1:A8hd4EnAeyujCJRrICiOWqjS1 github.com/davecgh/go-spew 
v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o= github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE= github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= @@ -42,8 +44,12 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/jsonschema-go v0.4.3 h1:/DBOLZTfDow7pe2GmaJNhltueGTtDKICi8V8p+DQPd0= github.com/google/jsonschema-go v0.4.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A= github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99/go.mod h1:1lJo3i6rXxKeerYnT8Nvf0QmHCRC1n8sfWVwXF2Frvo= github.com/kevinburke/ssh_config v1.2.0 h1:x584FjTGwHzMwvHx18PXxbBVzfnxogHaAReU4gf13a4= @@ -61,12 +67,16 @@ github.com/kr/text v0.2.0 
h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mhersson/contextmatrix-githubauth v0.1.0 h1:wzbTo6DTPAL6BQUcFcOhafPU6m92DLt4Er4xIp7LfD8= github.com/mhersson/contextmatrix-githubauth v0.1.0/go.mod h1:jjMMyN9ieQgdy56N23korXlV+msG0+Wd+en+zQkuN7U= github.com/modelcontextprotocol/go-sdk v1.6.0 h1:PPLS3kn7WtOEnR+Af4X5H96SG0qSab8R/ZQT/HkhPkY= github.com/modelcontextprotocol/go-sdk v1.6.0/go.mod h1:kzm3kzFL1/+AziGOE0nUs3gvPoNxMCvkxokMkuFapXQ= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= +github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k= github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY= github.com/pjbgf/sha1cd v0.6.0 h1:3WJ8Wz8gvDz29quX1OcEmkAlUg9diU4GxJHqs0/XiwU= @@ -83,6 +93,8 @@ github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9Z github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA= github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec 
h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/segmentio/asm v1.1.3 h1:WM03sfUOENvvKexOLp+pCqgb/WDjsi7EK8gIsICtzhc= @@ -112,17 +124,22 @@ golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f h1:W3F4c+6OLc6H2lb//N1q4WpJkhzJCK5J6kUi1NTVXfM= golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f/go.mod h1:J1xhfL/vlindoeF/aINzNzt2Bket5bjo9sdOYzOsU80= +golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= +golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ= golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys 
v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= @@ -146,3 +163,31 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +modernc.org/cc/v4 v4.28.2 h1:3tQ0lf2ADtoby2EtSP+J7IE2SHwEJdP8ioR59wx7XpY= +modernc.org/cc/v4 v4.28.2/go.mod h1:OnovgIhbbMXMu1aISnJ0wvVD1KnW+cAUJkIrAWh+kVI= +modernc.org/ccgo/v4 v4.34.0 h1:yRLPFZieg532OT4rp4JFNIVcquwalMX26G95WQDqwCQ= +modernc.org/ccgo/v4 v4.34.0/go.mod h1:AS5WYMyBakQ+fhsHhtP8mWB82KTGPkNNJDGfGQCe0/A= +modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM= +modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU= +modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= +modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo= +modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= +modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= +modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= +modernc.org/libc v1.72.3 h1:ZnDF4tXn4NBXFutMMQC4vtbTFSXhhKzR73fv0beZEAU= +modernc.org/libc v1.72.3/go.mod h1:dn0dZNnnn1clLyvRxLxYExxiKRZIRENOfqQ8XEeg4Qs= +modernc.org/mathutil v1.7.1 
h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/opt v0.2.0 h1:tGyef5ApycA7FSEOMraay9SaTk5zmbx7Tu+cJs4QKZg= +modernc.org/opt v0.2.0/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= +modernc.org/sqlite v1.50.1 h1:l+cQvn0sd0zJJtfygGHuQJ5AjlrwXmWPw4KP3ZMwr9w= +modernc.org/sqlite v1.50.1/go.mod h1:tcNzv5p84E0skkmJn038y+hWJbLQXQqEnQfeh5r2JLM= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/internal/chat/doc.go b/internal/chat/doc.go new file mode 100644 index 00000000..cdaa5fb8 --- /dev/null +++ b/internal/chat/doc.go @@ -0,0 +1,15 @@ +// Package chat provides the global chat panel orchestration layer. +// It coordinates the SQLite-backed session/transcript store, the runner +// client for spawning chat containers, the idle TTL reaper, and the SSE +// fan-out for browser subscribers. +// +// File layout: +// - doc.go — package documentation. +// - types.go — Session, Message, Status enums. +// - store.go — Store interface used by Manager. +// - sqlite/ — SQLite implementation of Store. +// - manager.go — Manager: session lifecycle orchestration. +// - runner.go — RunnerClient: wraps chat-mode runner webhooks. +// - reaper.go — IdleReaper: TTL goroutine. +// - sse.go — Per-session SSE buffer and fan-out hub. 
+package chat diff --git a/internal/chat/export_test.go b/internal/chat/export_test.go new file mode 100644 index 00000000..6bb38d9f --- /dev/null +++ b/internal/chat/export_test.go @@ -0,0 +1,29 @@ +package chat + +import "context" + +// BuildResumeForTest is a test-only export of buildResume, allowing +// package chat_test to exercise the tail-loading behaviour without +// accessing unexported methods directly. +func (m *Manager) BuildResumeForTest(ctx context.Context, sessionID string) *ResumeContext { + return m.buildResume(ctx, sessionID) +} + +// SetRehydrationActiveForTest is a test-only export of setRehydrationActive. +func (m *Manager) SetRehydrationActiveForTest(ctx context.Context, sessionID string, active bool) error { + return m.setRehydrationActive(ctx, sessionID, active) +} + +// RehydrationActiveCacheForTest reads only the in-memory cache value under +// m.mu, returning (value, present). Unlike isRehydrationActive, it does not +// fall back to the store on miss — tests use it to assert that the cache +// reflects exactly what setRehydrationActive committed, without masking a +// store/cache divergence by silently re-populating from disk. +func (m *Manager) RehydrationActiveCacheForTest(sessionID string) (bool, bool) { + m.mu.Lock() + defer m.mu.Unlock() + + v, ok := m.rehydrationActive[sessionID] + + return v, ok +} diff --git a/internal/chat/failing_store_test.go b/internal/chat/failing_store_test.go new file mode 100644 index 00000000..78c5016f --- /dev/null +++ b/internal/chat/failing_store_test.go @@ -0,0 +1,374 @@ +package chat_test + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "time" + + "github.com/mhersson/contextmatrix/internal/chat" +) + +// failingStore wraps a real chat.Store and can inject a one-shot error on +// SetRehydrationActive. All other methods delegate directly to the inner store. +// Used only in tests. 
+type failingStore struct { + chat.Store + failNextSetRehydration atomic.Bool +} + +// FailNextSetRehydration arms the one-shot fault: the next call to +// SetRehydrationActive returns an error and disarms the fault. +func (f *failingStore) FailNextSetRehydration() { + f.failNextSetRehydration.Store(true) +} + +func (f *failingStore) SetRehydrationActive(ctx context.Context, sessionID string, active bool) error { + if f.failNextSetRehydration.CompareAndSwap(true, false) { + return errors.New("injected: SetRehydrationActive failure") + } + + return f.Store.SetRehydrationActive(ctx, sessionID, active) +} + +// Remaining Store methods delegate to the inner store via embedding, so we +// only need explicit overrides for methods where we inject faults. +// These thin wrappers exist to satisfy the interface with pointer receiver: + +func (f *failingStore) CreateSession(ctx context.Context, s chat.Session) error { + return f.Store.CreateSession(ctx, s) +} + +func (f *failingStore) GetSession(ctx context.Context, id string) (chat.Session, error) { + return f.Store.GetSession(ctx, id) +} + +func (f *failingStore) ListSessions(ctx context.Context, filter chat.SessionFilter) ([]chat.Session, error) { + return f.Store.ListSessions(ctx, filter) +} + +func (f *failingStore) UpdateSession(ctx context.Context, s chat.Session) error { + return f.Store.UpdateSession(ctx, s) +} + +func (f *failingStore) DeleteSession(ctx context.Context, id string) error { + return f.Store.DeleteSession(ctx, id) +} + +func (f *failingStore) UpdateContextTokens(ctx context.Context, sessionID string, tokens int64, updatedAt time.Time) error { + return f.Store.UpdateContextTokens(ctx, sessionID, tokens, updatedAt) +} + +func (f *failingStore) AppendMessage(ctx context.Context, m chat.Message) (int64, error) { + return f.Store.AppendMessage(ctx, m) +} + +func (f *failingStore) ListMessages(ctx context.Context, sessionID string, sinceSeq int64, limit int) ([]chat.Message, error) { + return 
f.Store.ListMessages(ctx, sessionID, sinceSeq, limit) +} + +func (f *failingStore) ListMessagesTail(ctx context.Context, sessionID string, limit int) ([]chat.Message, error) { + return f.Store.ListMessagesTail(ctx, sessionID, limit) +} + +func (f *failingStore) MaxSeq(ctx context.Context, sessionID string) (int64, error) { + return f.Store.MaxSeq(ctx, sessionID) +} + +func (f *failingStore) Close() error { + return f.Store.Close() +} + +// trackingStore wraps a real chat.Store and records the Status of every +// UpdateSession call so tests can assert on intermediate writes. +type trackingStore struct { + chat.Store + mu sync.Mutex + statuses []chat.Status +} + +func (ts *trackingStore) UpdateSession(ctx context.Context, s chat.Session) error { + ts.mu.Lock() + ts.statuses = append(ts.statuses, s.Status) + ts.mu.Unlock() + + return ts.Store.UpdateSession(ctx, s) +} + +func (ts *trackingStore) writtenStatuses() []chat.Status { + ts.mu.Lock() + defer ts.mu.Unlock() + + out := make([]chat.Status, len(ts.statuses)) + copy(out, ts.statuses) + + return out +} + +// Explicit interface delegation (kept for symmetry with failingStore; the +// pointer-receiver UpdateSession shadows only that one promoted method).
+ +func (ts *trackingStore) CreateSession(ctx context.Context, s chat.Session) error { + return ts.Store.CreateSession(ctx, s) +} + +func (ts *trackingStore) GetSession(ctx context.Context, id string) (chat.Session, error) { + return ts.Store.GetSession(ctx, id) +} + +func (ts *trackingStore) ListSessions(ctx context.Context, filter chat.SessionFilter) ([]chat.Session, error) { + return ts.Store.ListSessions(ctx, filter) +} + +func (ts *trackingStore) DeleteSession(ctx context.Context, id string) error { + return ts.Store.DeleteSession(ctx, id) +} + +func (ts *trackingStore) SetRehydrationActive(ctx context.Context, sessionID string, active bool) error { + return ts.Store.SetRehydrationActive(ctx, sessionID, active) +} + +func (ts *trackingStore) UpdateContextTokens(ctx context.Context, sessionID string, tokens int64, updatedAt time.Time) error { + return ts.Store.UpdateContextTokens(ctx, sessionID, tokens, updatedAt) +} + +func (ts *trackingStore) AppendMessage(ctx context.Context, m chat.Message) (int64, error) { + return ts.Store.AppendMessage(ctx, m) +} + +func (ts *trackingStore) ListMessages(ctx context.Context, sessionID string, sinceSeq int64, limit int) ([]chat.Message, error) { + return ts.Store.ListMessages(ctx, sessionID, sinceSeq, limit) +} + +func (ts *trackingStore) ListMessagesTail(ctx context.Context, sessionID string, limit int) ([]chat.Message, error) { + return ts.Store.ListMessagesTail(ctx, sessionID, limit) +} + +func (ts *trackingStore) MaxSeq(ctx context.Context, sessionID string) (int64, error) { + return ts.Store.MaxSeq(ctx, sessionID) +} + +func (ts *trackingStore) Close() error { + return ts.Store.Close() +} + +// yieldingStore wraps a real chat.Store and inserts a randomised sleep +// after every SetRehydrationActive. 
SQLite's UPDATE is heavyweight +// relative to the trivial cache write that follows in setRehydrationActive, +// so the natural race window between the two writes is tight enough to +// hide a store/cache ordering bug under -race -count. The variable +// sleeps scatter cache writes out of the order their preceding store +// writes committed in — the schedule that exposes the regression. +type yieldingStore struct { + chat.Store + rng atomic.Int64 +} + +func (y *yieldingStore) SetRehydrationActive(ctx context.Context, sessionID string, active bool) error { + err := y.Store.SetRehydrationActive(ctx, sessionID, active) + // xorshift-style scramble of an atomic counter gives each call a + // different sleep duration without pulling in math/rand and without + // introducing a shared mutex. Range is [0, ~2ms). Single-millisecond + // jitter is what reliably surfaces the bug — narrower windows are + // hidden by SQLite serialisation and the race detector's coarser + // scheduling. + n := y.rng.Add(1) + n ^= n << 13 + n ^= n >> 7 + n ^= n << 17 + + if n < 0 { + n = -n + } + + time.Sleep(time.Duration(n%2000) * time.Microsecond) + + return err +} + +// Explicit interface delegation — pointer-receiver SetRehydrationActive +// would otherwise shadow only that method while the rest fall through to +// the embedded chat.Store. Listing every method keeps the delegation explicit and +// mirrors the failingStore/trackingStore pattern.
+ +func (y *yieldingStore) CreateSession(ctx context.Context, s chat.Session) error { + return y.Store.CreateSession(ctx, s) +} + +func (y *yieldingStore) GetSession(ctx context.Context, id string) (chat.Session, error) { + return y.Store.GetSession(ctx, id) +} + +func (y *yieldingStore) ListSessions(ctx context.Context, filter chat.SessionFilter) ([]chat.Session, error) { + return y.Store.ListSessions(ctx, filter) +} + +func (y *yieldingStore) UpdateSession(ctx context.Context, s chat.Session) error { + return y.Store.UpdateSession(ctx, s) +} + +func (y *yieldingStore) DeleteSession(ctx context.Context, id string) error { + return y.Store.DeleteSession(ctx, id) +} + +func (y *yieldingStore) UpdateContextTokens(ctx context.Context, sessionID string, tokens int64, updatedAt time.Time) error { + return y.Store.UpdateContextTokens(ctx, sessionID, tokens, updatedAt) +} + +func (y *yieldingStore) AppendMessage(ctx context.Context, m chat.Message) (int64, error) { + return y.Store.AppendMessage(ctx, m) +} + +func (y *yieldingStore) ListMessages(ctx context.Context, sessionID string, sinceSeq int64, limit int) ([]chat.Message, error) { + return y.Store.ListMessages(ctx, sessionID, sinceSeq, limit) +} + +func (y *yieldingStore) ListMessagesTail(ctx context.Context, sessionID string, limit int) ([]chat.Message, error) { + return y.Store.ListMessagesTail(ctx, sessionID, limit) +} + +func (y *yieldingStore) MaxSeq(ctx context.Context, sessionID string) (int64, error) { + return y.Store.MaxSeq(ctx, sessionID) +} + +func (y *yieldingStore) Close() error { + return y.Store.Close() +} + +// sessionGate is a per-session-id blocking gate used by gatingStore. block(id) +// arms the gate so the next AppendMessage call with that id parks on a +// channel; waiting(id) reports whether a goroutine is currently parked there; +// release(id) unblocks the parked call. Designed for tests that need to assert +// two AppendMessage calls reach the store concurrently. 
+type sessionGate struct { + mu sync.Mutex + armed map[string]chan struct{} + parked map[string]bool +} + +func newSessionGate() *sessionGate { + return &sessionGate{ + armed: map[string]chan struct{}{}, + parked: map[string]bool{}, + } +} + +// block arms the gate for sessionID so the next AppendMessage call for that +// id will park until release(sessionID) runs. +func (g *sessionGate) block(sessionID string) { + g.mu.Lock() + defer g.mu.Unlock() + + g.armed[sessionID] = make(chan struct{}) +} + +// waiting reports whether a goroutine is currently parked on the gate for +// sessionID. +func (g *sessionGate) waiting(sessionID string) bool { + g.mu.Lock() + defer g.mu.Unlock() + + return g.parked[sessionID] +} + +// release closes the gate channel for sessionID, unblocking the parked call. +// Safe to call before any goroutine has reached the gate — the close happens +// before the channel receive. +func (g *sessionGate) release(sessionID string) { + g.mu.Lock() + + ch, ok := g.armed[sessionID] + + g.mu.Unlock() + + if !ok { + return + } + + close(ch) +} + +// wait is invoked from gatingStore.AppendMessage. If the gate is armed for +// sessionID, the call parks until release runs. Marks parked[sessionID]=true +// before parking and clears it after release. +func (g *sessionGate) wait(sessionID string) { + g.mu.Lock() + + ch, ok := g.armed[sessionID] + if !ok { + g.mu.Unlock() + + return + } + + g.parked[sessionID] = true + g.mu.Unlock() + + <-ch + + g.mu.Lock() + g.parked[sessionID] = false + g.mu.Unlock() +} + +// gatingStore wraps a real chat.Store and routes AppendMessage through a +// sessionGate so tests can park individual append calls per-session-id. All +// other methods delegate to the embedded store. Modelled on yieldingStore and +// failingStore — explicit method overrides for every Store method because the +// pointer-receiver AppendMessage would otherwise shadow only that one and +// leave the rest to method promotion; listed explicitly for symmetry.
+type gatingStore struct { + chat.Store + gate *sessionGate +} + +func (g *gatingStore) AppendMessage(ctx context.Context, m chat.Message) (int64, error) { + g.gate.wait(m.SessionID) + + return g.Store.AppendMessage(ctx, m) +} + +func (g *gatingStore) CreateSession(ctx context.Context, s chat.Session) error { + return g.Store.CreateSession(ctx, s) +} + +func (g *gatingStore) GetSession(ctx context.Context, id string) (chat.Session, error) { + return g.Store.GetSession(ctx, id) +} + +func (g *gatingStore) ListSessions(ctx context.Context, filter chat.SessionFilter) ([]chat.Session, error) { + return g.Store.ListSessions(ctx, filter) +} + +func (g *gatingStore) UpdateSession(ctx context.Context, s chat.Session) error { + return g.Store.UpdateSession(ctx, s) +} + +func (g *gatingStore) DeleteSession(ctx context.Context, id string) error { + return g.Store.DeleteSession(ctx, id) +} + +func (g *gatingStore) SetRehydrationActive(ctx context.Context, sessionID string, active bool) error { + return g.Store.SetRehydrationActive(ctx, sessionID, active) +} + +func (g *gatingStore) UpdateContextTokens(ctx context.Context, sessionID string, tokens int64, updatedAt time.Time) error { + return g.Store.UpdateContextTokens(ctx, sessionID, tokens, updatedAt) +} + +func (g *gatingStore) ListMessages(ctx context.Context, sessionID string, sinceSeq int64, limit int) ([]chat.Message, error) { + return g.Store.ListMessages(ctx, sessionID, sinceSeq, limit) +} + +func (g *gatingStore) ListMessagesTail(ctx context.Context, sessionID string, limit int) ([]chat.Message, error) { + return g.Store.ListMessagesTail(ctx, sessionID, limit) +} + +func (g *gatingStore) MaxSeq(ctx context.Context, sessionID string) (int64, error) { + return g.Store.MaxSeq(ctx, sessionID) +} + +func (g *gatingStore) Close() error { + return g.Store.Close() +} diff --git a/internal/chat/manager.go b/internal/chat/manager.go new file mode 100644 index 00000000..62eacf78 --- /dev/null +++ b/internal/chat/manager.go @@ 
package chat

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"slices"
	"sync"
	"time"
	"unicode/utf8"

	"golang.org/x/sync/singleflight"

	"github.com/mhersson/contextmatrix/internal/chat/transcript"
	"github.com/mhersson/contextmatrix/internal/clock"
)

// ErrTooManyConcurrent is returned by OpenSession when the number of active
// or warm-idle sessions has reached the configured MaxConcurrent ceiling.
var ErrTooManyConcurrent = errors.New("chat: too many concurrent sessions")

// RunnerClient is the subset of the runner webhook surface that
// chat.Manager uses. The real implementation lives in internal/chat/runner.go;
// tests inject stubs.
type RunnerClient interface {
	// StartChat launches a chat container and returns its container ID.
	StartChat(ctx context.Context, opts StartChatOpts) (containerID string, err error)
	// EndChat tears down the runner-side container for sessionID.
	EndChat(ctx context.Context, sessionID string) error
	// SendChatMessage forwards a user message; messageID is used for
	// runner-side echo dedup.
	SendChatMessage(ctx context.Context, sessionID, content, messageID string) error
	// StreamLogs opens a long-lived SSE subscription to the runner's
	// /logs?session_id= endpoint and invokes onEntry for each parsed
	// LogEntry. Returns when ctx is cancelled or the stream closes.
	StreamLogs(ctx context.Context, sessionID string, onEntry func(LogEntry)) error
}

// StartChatOpts carries every input to RunnerClient.StartChat. Bundling the
// arguments lets us add fields (Model, Resume) without breaking the wire
// for tests with bespoke fakes.
type StartChatOpts struct {
	SessionID string
	Project   string
	RepoURL   string
	Model     string
	Resume    *ResumeContext
}

// Config carries Manager dependencies.
type Config struct {
	Store   Store
	Runner  RunnerClient
	Clock   clock.Clock
	IdleTTL time.Duration
	Logger  *slog.Logger
	// MaxConcurrent is the maximum number of sessions that may be active or
	// warm-idle at the same time. Zero means unlimited.
	MaxConcurrent int
	// Hub is the per-session SSE fan-out. When non-nil, SendUserMessage
	// publishes a user echo so the originator sees their own message in
	// the transcript without depending on a runner-side log round-trip.
	Hub *SSEHub
	// ResolveRepoURL returns the repo URL for a project, or "" if the
	// project has no repo. Caller wires this from service.CardService.GetProject.
	ResolveRepoURL func(ctx context.Context, project string) (string, error)

	// ResumeBudgetTokens caps the rehydration payload size passed to
	// transcript.Build on cold-reopen. Zero falls back to
	// transcript.DefaultBudgetTokens.
	ResumeBudgetTokens int

	// RehydrationTimeout is the upper bound on how long a session may
	// remain in the rehydration phase before the reaper forces it off.
	// Zero means "do not force off by timer" (user-message and
	// chat_rehydration_complete remain the only end signals). Production
	// wires this to chat.rehydration_timeout from config.
	RehydrationTimeout time.Duration

	// DefaultModel is used when a session row's model column is empty
	// (legacy rows, or callers that didn't pass a model on creation).
	// Production wires this to chat.default_model from config.
	DefaultModel string
}

// Manager orchestrates chat session lifecycle, transcript persistence,
// and runner-client coordination.
type Manager struct {
	store              Store
	runner             RunnerClient
	clk                clock.Clock
	idleTTL            time.Duration
	maxConcurrent      int
	logger             *slog.Logger
	hub                *SSEHub
	resolveRepoURL     func(ctx context.Context, project string) (string, error)
	resumeBudgetTokens int
	rehydrationTimeout time.Duration
	defaultModel       string

	// mu guards the shared per-session state maps below (seqMap, titled,
	// consumers, rehydrationActive).
	mu        sync.Mutex
	seqMap    map[string]int64           // sessionID → last assigned seq
	titled    map[string]bool            // sessionID → auto-title work already completed
	consumers map[string]*consumerHandle // sessionID → runner-log consumer lifecycle
	// rehydrationActive mirrors chat_sessions.rehydration_active. Reads
	// from AppendMessage's hot path go through the cache to avoid an
	// extra DB round-trip per log entry; cache misses fall back to the
	// store and populate. setRehydrationActive updates both store and
	// cache atomically (under mu).
	rehydrationActive map[string]bool
	// wg tracks consumer goroutines so Close can wait for them to exit.
	wg sync.WaitGroup

	// openGroup deduplicates concurrent cold-open work per sessionID. Two
	// callers racing to open the same id share one runner.StartChat
	// round-trip; callers on *different* ids no longer serialise behind a
	// global mutex while a slow docker pull is in flight.
	openGroup singleflight.Group
	// openLimitMu serialises just the MaxConcurrent count check + the
	// StartChat reservation window so concurrent cold opens cannot pass a
	// stale count and overshoot the limit. Held across StartChat for
	// limit-bounded callers only; when MaxConcurrent is 0 the lock is not
	// acquired and singleflight alone gates per-id work.
	openLimitMu sync.Mutex

	// appendLocks holds a per-session mutex used by AppendMessage to keep
	// the (seq-assign → store-write) window atomic for a given session
	// without coupling unrelated sessions through m.mu. The
	// UNIQUE(session_id, seq) index on disk is the final guarantor of seq
	// uniqueness. Lazily populated on first AppendMessage per session and
	// cleaned up by DeleteSession.
	appendLocksMu sync.Mutex
	appendLocks   map[string]*sync.Mutex

	// closeOnce makes Close's consumer-cancellation idempotent.
	closeOnce sync.Once
}
+func NewManager(cfg Config) *Manager { + if cfg.Clock == nil { + cfg.Clock = clock.Real() + } + + if cfg.Logger == nil { + cfg.Logger = slog.Default() + } + + return &Manager{ + store: cfg.Store, + runner: cfg.Runner, + clk: cfg.Clock, + idleTTL: cfg.IdleTTL, + maxConcurrent: cfg.MaxConcurrent, + logger: cfg.Logger, + hub: cfg.Hub, + resolveRepoURL: cfg.ResolveRepoURL, + resumeBudgetTokens: cfg.ResumeBudgetTokens, + rehydrationTimeout: cfg.RehydrationTimeout, + defaultModel: cfg.DefaultModel, + seqMap: make(map[string]int64), + titled: make(map[string]bool), + consumers: make(map[string]*consumerHandle), + rehydrationActive: make(map[string]bool), + appendLocks: map[string]*sync.Mutex{}, + } +} + +// appendLock returns the per-session append mutex, creating it on first use. +// Lazily allocated so cold sessions don't pay the cost upfront. The map +// itself is guarded by appendLocksMu, which is independent of m.mu so the +// hot path of AppendMessage does not serialise on the same lock that +// guards shared session state. +func (m *Manager) appendLock(sessionID string) *sync.Mutex { + m.appendLocksMu.Lock() + defer m.appendLocksMu.Unlock() + + mu, ok := m.appendLocks[sessionID] + if !ok { + mu = &sync.Mutex{} + m.appendLocks[sessionID] = mu + } + + return mu +} + +// consumerHandle is the per-session lifecycle handle for a runner-log consumer +// goroutine. done is closed when the goroutine returns, so stopConsumer can +// block until the cleanup defers have executed and the consumers map is +// guaranteed clean before a subsequent startConsumer runs. +type consumerHandle struct { + cancel context.CancelFunc + done chan struct{} +} + +// roleFromLogType maps a runner LogEntry.Type to a chat.Role. +// Unknown types fall back to RoleSystem so transcripts remain complete. 
+func roleFromLogType(typ string) Role { + switch typ { + case "text": + return RoleAssistantText + case "thinking": + return RoleAssistantThinking + case "tool_call": + return RoleToolCall + case "tool_result": + return RoleToolResult + case "stderr": + return RoleStderr + case "user": + // User echoes are produced CM-side via SendUserMessage; the runner + // re-emits them on the broadcaster as a courtesy. Ignore to avoid + // duplicate transcript entries. + return "" + case "usage": + // Usage entries are metadata; handled separately by handleUsageEntry + // and do not become transcript rows. + return "" + default: + return RoleSystem + } +} + +// handleUsageEntry processes a Claude stream-json usage block reported by the +// runner. The session row's context_tokens are updated and a session_updated +// SSE event is published so the UI header indicator refreshes in real time. +// Errors are non-fatal — usage is a UI niceness, not a correctness property. +func (m *Manager) handleUsageEntry(ctx context.Context, sessionID string, e LogEntry) { + if e.Usage == nil { + return + } + + tokens := e.Usage.InputTokens + e.Usage.CacheReadTokens + e.Usage.CacheCreateTokens + + updatedAt := e.Timestamp + if updatedAt.IsZero() { + updatedAt = m.clk.Now().UTC() + } + + if err := m.store.UpdateContextTokens(ctx, sessionID, tokens, updatedAt); err != nil { + // Session may have been deleted between the runner emitting the + // event and CM consuming it — log at debug rather than warn. + m.logger.Debug("chat: handleUsageEntry: update context_tokens failed", + "session_id", sessionID, "error", err) + + return + } + + if m.hub != nil { + m.hub.PublishSessionUpdate(sessionID, SessionUpdate{ + ContextTokens: tokens, + ContextTokensUpdatedAt: updatedAt, + Model: e.Model, + }) + } +} + +// startConsumer ensures a goroutine is bridging runner /logs for sessionID +// into AppendMessage + hub.Publish. 
// startConsumer ensures a goroutine is bridging runner /logs for sessionID
// into AppendMessage + hub.Publish. Idempotent: subsequent calls for the
// same session are no-ops while a consumer is already running.
func (m *Manager) startConsumer(sessionID string) {
	m.mu.Lock()

	// Idempotency check: a live consumer for this session means nothing
	// to do.
	if _, ok := m.consumers[sessionID]; ok {
		m.mu.Unlock()

		return
	}

	ctx, cancel := context.WithCancel(context.Background())
	handle := &consumerHandle{cancel: cancel, done: make(chan struct{})}
	m.consumers[sessionID] = handle
	m.mu.Unlock()

	m.wg.Add(1)

	go func() {
		defer m.wg.Done()
		// close(done) runs LAST so stopConsumer's <-done blocks until the
		// map-entry cleanup has executed; that guarantees a follow-on
		// startConsumer sees a clean slate and is not defeated by a stale
		// entry.
		defer close(handle.done)
		defer func() {
			m.mu.Lock()
			// Defensive identity check: stopConsumer may have already removed
			// the entry. Only delete if we still own it.
			if cur, ok := m.consumers[sessionID]; ok && cur == handle {
				delete(m.consumers, sessionID)
			}
			m.mu.Unlock()
		}()

		// onEntry translates each runner log entry into a persisted
		// transcript row (AppendMessage) plus an SSE publish.
		onEntry := func(e LogEntry) {
			// Usage blocks are metadata, not transcript rows.
			if e.Type == "usage" {
				m.handleUsageEntry(ctx, sessionID, e)

				return
			}

			role := roleFromLogType(e.Type)
			// Empty role means "skip" (user echoes, unknown-to-transcript
			// entries) — see roleFromLogType.
			if role == "" {
				return
			}

			msg, err := m.AppendMessage(ctx, sessionID, role, e.Content)
			if err != nil {
				m.logger.Warn("chat: consumer AppendMessage failed",
					"session_id", sessionID, "type", e.Type, "error", err)

				return
			}

			if m.hub != nil {
				m.hub.Publish(sessionID, SSEEvent{
					Seq:              msg.Seq,
					Role:             role,
					Content:          e.Content,
					RehydrationPhase: msg.RehydrationPhase,
				})
			}
		}

		m.logger.Info("chat: runner-log consumer started", "session_id", sessionID)

		// Blocks until ctx is cancelled or the stream closes. A plain
		// context.Canceled is the normal shutdown path, not an error.
		if err := m.runner.StreamLogs(ctx, sessionID, onEntry); err != nil && !errors.Is(err, context.Canceled) {
			m.logger.Warn("chat: runner-log consumer exited with error",
				"session_id", sessionID, "error", err)

			return
		}

		m.logger.Info("chat: runner-log consumer stopped", "session_id", sessionID)
	}()
}
// stopConsumer cancels the runner-log consumer for sessionID and blocks until
// the goroutine has exited. Synchronous teardown is required so a fast Reopen
// after EndSession is guaranteed to start a fresh consumer — an asynchronous
// cleanup defer would leave the map entry visible to startConsumer's
// idempotency check, dropping the new session's log bridge.
func (m *Manager) stopConsumer(sessionID string) {
	m.mu.Lock()

	handle, ok := m.consumers[sessionID]
	if ok {
		delete(m.consumers, sessionID)
	}
	m.mu.Unlock()

	// No consumer running — nothing to stop.
	if !ok {
		return
	}

	handle.cancel()
	// Wait for the goroutine's defers (map cleanup, then close(done)).
	<-handle.done
}

// Close cancels every active runner-log consumer goroutine and waits for them
// to exit. The supplied ctx acts as a deadline: if consumers have not all
// stopped by ctx.Done(), Close returns an error wrapping ctx.Err(). Idempotent
// — subsequent calls are no-ops and return nil.
func (m *Manager) Close(ctx context.Context) error {
	// Only the first call cancels consumers; every call still waits on
	// m.wg below, which is a no-op once the goroutines have drained.
	m.closeOnce.Do(func() {
		m.mu.Lock()
		for id, handle := range m.consumers {
			handle.cancel()
			delete(m.consumers, id)
		}
		m.mu.Unlock()
	})

	done := make(chan struct{})

	go func() {
		m.wg.Wait()
		close(done)
	}()

	select {
	case <-done:
		return nil
	case <-ctx.Done():
		return fmt.Errorf("chat: Close: timeout waiting for consumers: %w", ctx.Err())
	}
}

// CreateInput is the user-facing payload for creating a new session.
type CreateInput struct {
	Title     string
	Project   string
	CreatedBy string
	Model     string
}
+func (m *Manager) CreateSession(ctx context.Context, in CreateInput) (Session, error) { + now := m.clk.Now().UTC().Truncate(time.Second) + + sess := Session{ + ID: NewID(), + Title: in.Title, + Project: in.Project, + Status: StatusCold, + CreatedAt: now, + LastActive: now, + CreatedBy: in.CreatedBy, + Model: in.Model, + } + if err := m.store.CreateSession(ctx, sess); err != nil { + return Session{}, fmt.Errorf("chat: CreateSession: %w", err) + } + + return sess, nil +} + +// buildResume loads the prior transcript and returns a ResumeContext for the +// runner, or nil when there's nothing worth resuming (fresh session, all +// messages filtered, or a DB error — the last case is logged and degrades to +// a fresh agent rather than refusing to open). +func (m *Manager) buildResume(ctx context.Context, sessionID string) *ResumeContext { + const maxMessagesForBuild = 600 // matches transcript.MaxTurns + headroom + + msgs, err := m.store.ListMessagesTail(ctx, sessionID, maxMessagesForBuild) + if err != nil { + m.logger.Warn("chat: buildResume list messages failed; skipping rehydration", + "session_id", sessionID, "error", err) + + return nil + } + + tmsgs := make([]transcript.Message, len(msgs)) + for i, msg := range msgs { + tmsgs[i] = transcript.Message{ + Seq: msg.Seq, + Role: string(msg.Role), + Content: msg.Content, + RehydrationPhase: msg.RehydrationPhase, + } + } + + return transcript.Build(tmsgs, transcript.BuildOpts{BudgetTokens: m.resumeBudgetTokens}) +} + +// resumeTurnCount is a small helper for structured logging — returns the +// number of turns in a Resume, or 0 when nil. +func resumeTurnCount(r *ResumeContext) int { + if r == nil { + return 0 + } + + return len(r.Turns) +} + +// isRehydrationActive reports whether the session is currently in its +// rehydration phase. Reads go through the in-memory cache first; misses +// fall back to the store and populate. Errors fall back to false rather +// than blocking AppendMessage's hot path. 
+func (m *Manager) isRehydrationActive(ctx context.Context, sessionID string) bool { + m.mu.Lock() + + if v, ok := m.rehydrationActive[sessionID]; ok { + m.mu.Unlock() + + return v + } + + m.mu.Unlock() + + sess, err := m.store.GetSession(ctx, sessionID) + if err != nil { + return false + } + + m.mu.Lock() + m.rehydrationActive[sessionID] = sess.RehydrationActive + m.mu.Unlock() + + return sess.RehydrationActive +} + +// setRehydrationActive writes the flag to the store and mirrors it to the +// in-memory cache. Called from OpenSession (cold-resume → true), +// SendUserMessage (first user msg → false), CompleteRehydration (MCP tool +// → false), EndSession (cold transition → false), and the reaper sweep +// (timeout → false). +// +// Hold m.mu across the store write so the persisted value and the cached +// value cannot diverge under concurrent flips. Two callers that race to +// write opposite booleans now serialise here, and whichever commits to +// disk last is the value the cache holds on return. +func (m *Manager) setRehydrationActive(ctx context.Context, sessionID string, active bool) error { + m.mu.Lock() + defer m.mu.Unlock() + + if err := m.store.SetRehydrationActive(ctx, sessionID, active); err != nil { + return err + } + + m.rehydrationActive[sessionID] = active + + return nil +} + +// CompleteRehydration ends the per-session rehydration phase: persists +// `summary` as a normal (non-phase) assistant_text message, flips the +// session flag off, and publishes the summary to the SSE hub. Idempotent: +// a second call with the flag already off returns success and no-ops. 
// CompleteRehydration ends the per-session rehydration phase: persists
// `summary` as a normal (non-phase) assistant_text message, flips the
// session flag off, and publishes the summary to the SSE hub. Idempotent:
// a second call with the flag already off returns success and no-ops.
func (m *Manager) CompleteRehydration(ctx context.Context, sessionID, summary string) error {
	// Existence check first so a deleted session yields a clear error.
	if _, err := m.store.GetSession(ctx, sessionID); err != nil {
		return fmt.Errorf("chat: CompleteRehydration: %w", err)
	}

	if !m.isRehydrationActive(ctx, sessionID) {
		m.logger.Debug("chat: CompleteRehydration: already inactive, no-op",
			"session_id", sessionID)

		return nil
	}

	// Flip the flag BEFORE appending so the summary row itself is
	// persisted as non-phase (included in future resume payloads).
	if err := m.setRehydrationActive(ctx, sessionID, false); err != nil {
		return fmt.Errorf("chat: CompleteRehydration: flip flag: %w", err)
	}

	msg, err := m.AppendMessage(ctx, sessionID, RoleAssistantText, summary)
	if err != nil {
		return fmt.Errorf("chat: CompleteRehydration: append summary: %w", err)
	}

	if m.hub != nil {
		m.hub.Publish(sessionID, SSEEvent{
			Seq:     msg.Seq,
			Role:    RoleAssistantText,
			Content: summary,
		})
	}

	m.logger.Info("chat: rehydration complete",
		"session_id", sessionID, "summary_len", len(summary))

	return nil
}

// OpenSession transitions a session into the active state, starting a
// new container if cold or reattaching if warm-idle. Idempotent on
// already-active sessions.
func (m *Manager) OpenSession(ctx context.Context, id string) (Session, error) {
	sess, err := m.store.GetSession(ctx, id)
	if err != nil {
		return Session{}, fmt.Errorf("chat: OpenSession: %w", err)
	}

	switch sess.Status {
	case StatusActive:
		// Idempotent for already-active sessions. Also ensure the runner-log
		// consumer is bridging /logs back into the SSE hub: a CM restart
		// strands that goroutine while the row stays active, and the only
		// recovery path used to be End → Reopen (which kills the container
		// and rehydrates a fresh one). startConsumer is a no-op when a
		// consumer for this session is already running.
		m.startConsumer(sess.ID)

		return sess, nil

	case StatusWarmIdle:
		// Container is still running — just flip the row back to active
		// and refresh the idle clock; no runner round-trip needed.
		sess.Status = StatusActive

		sess.LastActive = m.clk.Now().UTC().Truncate(time.Second)
		if err := m.store.UpdateSession(ctx, sess); err != nil {
			return Session{}, fmt.Errorf("chat: warm reattach: %w", err)
		}

		m.logger.Info("chat: warm-idle reattached", "session_id", sess.ID)
		m.startConsumer(sess.ID)

		return sess, nil

	case StatusCold:
		// Route the cold-start path through singleflight keyed on
		// sessionID so concurrent callers for the same id share one
		// runner.StartChat round-trip, and callers for *different* ids
		// no longer serialise on a global mutex behind a slow docker
		// run / image pull.
		v, err, _ := m.openGroup.Do(id, func() (any, error) {
			return m.openCold(ctx, id)
		})
		if err != nil {
			return Session{}, err
		}

		return v.(Session), nil

	case StatusEnding:
		return Session{}, fmt.Errorf("chat: session is ending")
	}

	return Session{}, fmt.Errorf("chat: unknown status %q", sess.Status)
}
// openCold runs the cold→active transition for a single sessionID. It is
// invoked under singleflight by OpenSession so concurrent callers for the
// same id share one runner.StartChat round-trip; callers for *different*
// ids no longer serialise on a global lock when MaxConcurrent is 0.
//
// The MaxConcurrent count check + StartChat reservation are still held
// under m.openLimitMu so racing limit-bounded opens cannot pass a stale
// count and overshoot. Holding the lock across StartChat keeps the
// limit-bounded path serial at runner-latency timescale — still strictly
// better than the old global serialisation, which gated even MaxConcurrent=0
// callers. A truly parallel cold-open under a hard limit would need a
// reservation counter; out of scope here.
func (m *Manager) openCold(ctx context.Context, id string) (Session, error) {
	sess, err := m.store.GetSession(ctx, id)
	if err != nil {
		return Session{}, fmt.Errorf("chat: OpenSession (re-read): %w", err)
	}

	if sess.Status != StatusCold {
		// Another caller raced ahead and opened this session. Treat as
		// already-active.
		return sess, nil
	}

	if m.maxConcurrent > 0 {
		// Deferred unlock intentionally spans the rest of the function,
		// including StartChat — see the function comment.
		m.openLimitMu.Lock()
		defer m.openLimitMu.Unlock()

		active, err := m.store.ListSessions(ctx, SessionFilter{Status: StatusActive})
		if err != nil {
			return Session{}, fmt.Errorf("chat: count active: %w", err)
		}

		warm, err := m.store.ListSessions(ctx, SessionFilter{Status: StatusWarmIdle})
		if err != nil {
			return Session{}, fmt.Errorf("chat: count warm: %w", err)
		}

		if len(active)+len(warm) >= m.maxConcurrent {
			return Session{}, ErrTooManyConcurrent
		}
	}

	var repoURL string
	if sess.Project != "" && m.resolveRepoURL != nil {
		repoURL, err = m.resolveRepoURL(ctx, sess.Project)
		if err != nil {
			return Session{}, fmt.Errorf("chat: resolve repo for %q: %w", sess.Project, err)
		}
	}

	// Build the rehydration payload from the persisted transcript.
	// Errors here are non-fatal — fall back to "no resume" so we
	// never block the user from opening the chat.
	resume := m.buildResume(ctx, sess.ID)

	model := sess.Model
	if model == "" {
		model = m.defaultModel
	}

	m.logger.Info("chat: opening cold session",
		"session_id", sess.ID, "project", sess.Project, "repo_url", repoURL,
		"model", model, "has_resume", resume != nil,
		"resume_turn_count", resumeTurnCount(resume))

	containerID, err := m.runner.StartChat(ctx, StartChatOpts{
		SessionID: sess.ID,
		Project:   sess.Project,
		RepoURL:   repoURL,
		Model:     model,
		Resume:    resume,
	})
	if err != nil {
		return Session{}, fmt.Errorf("chat: start container: %w", err)
	}

	sess.Status = StatusActive
	sess.ContainerID = containerID
	sess.Model = model

	sess.LastActive = m.clk.Now().UTC().Truncate(time.Second)
	if sess.Project != "" && !slices.Contains(sess.Workspace, sess.Project) {
		sess.Workspace = append(sess.Workspace, sess.Project)
	}

	if err := m.store.UpdateSession(ctx, sess); err != nil {
		// Roll back the container start so we don't leak. Use a fresh
		// background context — ctx may already be cancelled/expired.
		if rbErr := m.runner.EndChat(context.Background(), sess.ID); rbErr != nil {
			m.logger.Warn("chat: rollback EndChat failed after persist failure",
				"session_id", sess.ID, "container_id", containerID, "error", rbErr)
		}

		return Session{}, fmt.Errorf("chat: persist active: %w", err)
	}

	if resume != nil {
		// Pre-arm the in-memory cache so concurrent log writes during
		// the persist window stamp rehydration_phase=TRUE even before
		// the DB write completes.
		// NOTE: if the setRehydrationActive persist below fails, the cache
		// is rolled back and the session is reset to cold. However, any
		// messages appended during this narrow window (between the pre-arm
		// and the rollback) will keep their rehydration_phase=TRUE stamp
		// permanently — they will be excluded from future resume payloads
		// by transcript.Build. In practice this window spans a single store
		// write and no user-driven AppendMessage can race here: the
		// runner-consumer goroutine is only spawned after a successful
		// OpenSession returns, so the risk is negligible.
		m.mu.Lock()
		m.rehydrationActive[sess.ID] = true
		m.mu.Unlock()

		if err := m.setRehydrationActive(ctx, sess.ID, true); err != nil {
			// Roll back: clear the cache, stop the container, and
			// reset the session row to cold so the next open retries
			// cleanly.
			m.mu.Lock()
			delete(m.rehydrationActive, sess.ID)
			m.mu.Unlock()

			if rbErr := m.runner.EndChat(context.Background(), sess.ID); rbErr != nil {
				m.logger.Warn("chat: OpenSession: rollback EndChat failed after rehydration persist failure",
					"session_id", sess.ID, "error", rbErr)
			}

			sess.Status = StatusCold
			sess.ContainerID = ""

			if err := m.store.UpdateSession(ctx, sess); err != nil {
				m.logger.Warn("chat: OpenSession: rollback reset to cold failed",
					"session_id", sess.ID, "error", err)
			}

			return Session{}, fmt.Errorf("chat: OpenSession: persist rehydration_active: %w", err)
		}

		sess.RehydrationActive = true
	}

	m.logger.Info("chat: cold session active",
		"session_id", sess.ID, "container_id", containerID)
	m.startConsumer(sess.ID)

	return sess, nil
}

// maxMessageBytes caps a single persisted transcript entry. Verbose tool
// output (e.g. a tool_result containing a large file dump) would otherwise
// grow chats.db linearly without bound. The user-message path is already
// capped at the HTTP boundary (8192 bytes), so this cap mainly fires on
// runner-emitted entries.
const maxMessageBytes = 32 * 1024

// truncationMarker is appended to messages that exceeded maxMessageBytes.
const truncationMarker = "\n... [truncated]"
Truncation respects UTF-8 rune boundaries so the marker +// is not appended in the middle of a multibyte sequence. +func truncateMessageContent(content string) string { + if len(content) <= maxMessageBytes { + return content + } + + cut := maxMessageBytes - len(truncationMarker) + // Back up to a rune start so we don't slice mid-rune. + for cut > 0 && (content[cut]&0xC0) == 0x80 { + cut-- + } + + return content[:cut] + truncationMarker +} + +// AppendMessage persists a transcript entry with a monotonic seq. +// Seq is assigned server-side; the caller does not provide it. The +// rehydration_phase column on the persisted row is sourced from the +// in-memory cache (mirrors session.rehydration_active) so messages emitted +// during the rehydration phase are excluded from future resume payloads +// by transcript.Build. +func (m *Manager) AppendMessage(ctx context.Context, sessionID string, role Role, content string) (Message, error) { + content = truncateMessageContent(content) + + phase := m.isRehydrationActive(ctx, sessionID) + + // Auto-title: if this is the first user message on a still-untitled session, + // derive a title from the content (50-byte truncation with ellipsis). The + // `titled` cache skips the SELECT+UPDATE round-trip once we've confirmed a + // title exists for the session. + if role == RoleUser { + m.mu.Lock() + alreadyTitled := m.titled[sessionID] + m.mu.Unlock() + + if !alreadyTitled { + sess, err := m.store.GetSession(ctx, sessionID) + if err == nil { + if sess.Title == "" { + title := content + // Truncate at rune boundary, not byte boundary — slicing + // bytes mid-UTF-8-rune produces invalid sequences that + // JSON-marshal as U+FFFD. 
+ if utf8.RuneCountInString(title) > 50 { + runes := []rune(title) + title = string(runes[:50]) + "…" + } + + sess.Title = title + if err := m.store.UpdateSession(ctx, sess); err != nil { + m.logger.Warn("chat: auto-title persist failed", + "session_id", sessionID, "error", err) + } + } + + m.mu.Lock() + m.titled[sessionID] = true + m.mu.Unlock() + } + } + } + + // Per-session lock keeps the (seq-assign → store-write) window atomic + // for this session without coupling unrelated sessions through m.mu. + // One slow fsync on session A no longer stalls appends to session B. + // SQLite serialises writes at the engine level and the + // UNIQUE(session_id, seq) index is the final correctness guarantor. + sl := m.appendLock(sessionID) + sl.Lock() + defer sl.Unlock() + + // seqMap is shared across sessions so the seq-assign window must take + // m.mu briefly. The store write below runs outside m.mu — only the + // per-session lock is held across the I/O. + m.mu.Lock() + + // Lazy seed from the store if first call this process. Uses an indexed + // MAX(seq) query so the seed cost is constant time even on long sessions. + if _, ok := m.seqMap[sessionID]; !ok { + maxSeq, err := m.store.MaxSeq(ctx, sessionID) + if err != nil { + m.mu.Unlock() + + return Message{}, fmt.Errorf("chat: seed seq: %w", err) + } + + m.seqMap[sessionID] = maxSeq + } + + m.seqMap[sessionID]++ + seq := m.seqMap[sessionID] + m.mu.Unlock() + + msg := Message{ + SessionID: sessionID, + Seq: seq, + Role: role, + Content: content, + CreatedAt: m.clk.Now().UTC().Truncate(time.Second), + RehydrationPhase: phase, + } + + if _, err := m.store.AppendMessage(ctx, msg); err != nil { + // The seq was claimed but the store write failed. Roll back the + // in-memory counter so the next append re-uses it; the UNIQUE + // index on (session_id, seq) would otherwise reject a future + // AppendMessage if the failed write somehow made it to disk. 
The + // per-session lock is still held so the rollback is sequenced + // before any other append on this session. + m.mu.Lock() + m.seqMap[sessionID]-- + m.mu.Unlock() + + return Message{}, fmt.Errorf("chat: append: %w", err) + } + + return msg, nil +} + +// Reattach ensures the runner-log consumer is running for an active or +// warm-idle session and refreshes its LastActive timestamp so the idle +// reaper doesn't end it while the user is interacting with it. No-op on +// cold and ending sessions, which have no live runner container to +// bridge to. +// +// Status is intentionally left untouched: warm-idle stays warm-idle. +// Flipping to active would require a session_updated SSE push to keep +// the sidebar in sync, but SessionUpdate carries no Status field today. +// Leaving status as-is means the chat still works (SendUserMessage +// accepts warm-idle), the sidebar stays consistent, and the next +// natural transition (user types a message, or the grace timer fires +// after disconnect) keeps the state machine clean. +// +// Idempotent and safe for concurrent callers — startConsumer guards +// against duplicate consumer goroutines internally. +func (m *Manager) Reattach(ctx context.Context, sessionID string) error { + sess, err := m.store.GetSession(ctx, sessionID) + if err != nil { + return fmt.Errorf("chat: Reattach: %w", err) + } + + if sess.Status != StatusActive && sess.Status != StatusWarmIdle { + return nil + } + + sess.LastActive = m.clk.Now().UTC().Truncate(time.Second) + if err := m.store.UpdateSession(ctx, sess); err != nil { + return fmt.Errorf("chat: Reattach: persist last-active: %w", err) + } + + m.startConsumer(sess.ID) + + return nil +} + +// MarkWarmIdle transitions an active session to warm-idle. No-op if the +// session is not active. Tolerant of ErrSessionNotFound — a grace timer +// fired against a session that was already deleted (DeleteSession, +// reconcile sweep) is a benign race, not an error. 
func (m *Manager) MarkWarmIdle(ctx context.Context, id string) error {
	sess, err := m.store.GetSession(ctx, id)
	if err != nil {
		// Unknown IDs are tolerated so disconnect handlers can call this
		// blindly without racing a concurrent delete.
		if errors.Is(err, ErrSessionNotFound) {
			m.logger.Debug("chat: MarkWarmIdle: session not found, ignoring", "session_id", id)

			return nil
		}

		return fmt.Errorf("chat: MarkWarmIdle: %w", err)
	}

	// Only active sessions transition; cold/warm-idle/ending rows are
	// left untouched.
	if sess.Status != StatusActive {
		return nil
	}

	sess.Status = StatusWarmIdle

	sess.LastActive = m.clk.Now().UTC().Truncate(time.Second)
	if err := m.store.UpdateSession(ctx, sess); err != nil {
		return fmt.Errorf("chat: MarkWarmIdle persist: %w", err)
	}

	return nil
}

// GetSession returns the persisted session by ID.
func (m *Manager) GetSession(ctx context.Context, id string) (Session, error) {
	return m.store.GetSession(ctx, id)
}

// EndSession transitions a session to cold, stopping the runner container.
// Idempotent on already-cold sessions and re-entrant against status=ending
// rows (which can result from a prior partial failure). Runner teardown and
// consumer stop are both idempotent, so calling EndSession on a wedged
// ending row safely completes the transition in a single store write.
func (m *Manager) EndSession(ctx context.Context, id string) error {
	sess, err := m.store.GetSession(ctx, id)
	if err != nil {
		return fmt.Errorf("chat: EndSession: %w", err)
	}

	if sess.Status == StatusCold {
		return nil
	}

	m.logger.Info("chat: ending session", "session_id", sess.ID,
		"from_status", string(sess.Status))

	// Tear down runner-side resources first; both calls are idempotent so
	// re-entry from a status=ending row is safe.
	m.stopConsumer(sess.ID)

	// Runner failure is logged, not returned: the row must still reach
	// cold so the session is not wedged by an unreachable runner.
	if err := m.runner.EndChat(ctx, sess.ID); err != nil {
		m.logger.Warn("chat: runner end failed, marking cold anyway",
			"session_id", sess.ID, "error", err)
	}

	// Single store write: transition directly to cold without an intermediate
	// status=ending persist. Collapsing to one write means a failure here
	// leaves the row in its original state (active/warm-idle/ending) rather
	// than wedged in ending, making the next EndSession call a clean retry.
	sess.Status = StatusCold
	sess.ContainerID = ""
	sess.LastActive = m.clk.Now().UTC().Truncate(time.Second)

	if err := m.store.UpdateSession(ctx, sess); err != nil {
		return fmt.Errorf("chat: mark cold: %w", err)
	}

	// Reset any leftover rehydration flag so a subsequent reopen starts
	// from a clean state. setRehydrationActive is idempotent and tolerant
	// of an already-false value.
	if err := m.setRehydrationActive(ctx, sess.ID, false); err != nil {
		m.logger.Warn("chat: EndSession: clear rehydration flag failed",
			"session_id", sess.ID, "error", err)
	}

	// TODO: publish a session_updated SSE event here so the UI refreshes its
	// status indicator on the cold transition. SessionUpdate currently has no
	// Status field; add one and call m.hub.PublishSessionUpdate when that lands.

	m.logger.Info("chat: session cold", "session_id", sess.ID)

	return nil
}

// ListSessions returns sessions matching the filter, newest-active first.
func (m *Manager) ListSessions(ctx context.Context, f SessionFilter) ([]Session, error) {
	return m.store.ListSessions(ctx, f)
}

// DeleteSession ends the container if running, then deletes the row.
func (m *Manager) DeleteSession(ctx context.Context, id string) error {
	sess, err := m.store.GetSession(ctx, id)
	if err != nil {
		return err
	}

	if sess.Status == StatusActive || sess.Status == StatusWarmIdle {
		// Best-effort teardown: deletion proceeds even if EndSession fails.
		if err := m.EndSession(ctx, id); err != nil {
			m.logger.Warn("chat: DeleteSession: EndSession failed, deleting anyway",
				"session_id", id, "error", err)
		}
	}

	// Idempotent; also covers the cold path where EndSession was skipped.
	m.stopConsumer(id)

	if err := m.store.DeleteSession(ctx, id); err != nil {
		return err
	}

	// Release the SSE hub's per-session ring buffer + subscriber set so the
	// hub doesn't grow without bound across session churn.
	if m.hub != nil {
		m.hub.Drop(id)
	}

	m.logger.Info("chat: session deleted", "session_id", id)

	// Drop the seq cache entry so a future session that happens to reuse
	// the ID (or an accidental Append after delete) does not leak memory.
	m.mu.Lock()
	delete(m.seqMap, id)
	delete(m.titled, id)
	delete(m.rehydrationActive, id)
	m.mu.Unlock()

	// Drop the per-session append lock entry. Held under appendLocksMu
	// rather than m.mu so the AppendMessage hot path's appendLock() call
	// does not serialise on the same lock that guards shared session state.
	m.appendLocksMu.Lock()
	delete(m.appendLocks, id)
	m.appendLocksMu.Unlock()

	return nil
}

// SendUserMessage forwards a user message to the runner first; only on a
// successful runner call is the message persisted and fanned out via the
// SSE hub. If the runner is unreachable the caller gets an error and the
// UI can retry — the alternative (snappy echo, then runner failure) used
// to leave the user staring at their own message with no reply path.
// Cold-state sessions are opened first. Returns the generated message_id
// used for runner-side echo dedup.
//
// If the session is currently in its rehydration phase, the user typing
// ends the phase as a belt-and-suspenders safety net for agents that
// forget to call chat_rehydration_complete. The flag is flipped BEFORE
// AppendMessage so the user's message itself is persisted as non-phase.
func (m *Manager) SendUserMessage(ctx context.Context, sessionID, content string) (string, error) {
	sess, err := m.store.GetSession(ctx, sessionID)
	if err != nil {
		return "", err
	}

	if sess.Status == StatusCold {
		if _, err := m.OpenSession(ctx, sessionID); err != nil {
			return "", err
		}
	}

	if m.isRehydrationActive(ctx, sessionID) {
		if err := m.setRehydrationActive(ctx, sessionID, false); err != nil {
			m.logger.Warn("chat: SendUserMessage: clear rehydration flag failed",
				"session_id", sessionID, "error", err)
		} else {
			m.logger.Info("chat: rehydration ended by user message",
				"session_id", sessionID)
		}
	}

	msgID := NewID()

	m.logger.Info("chat: forwarding user message to runner",
		"session_id", sessionID, "message_id", msgID, "content_len", len(content))

	if err := m.runner.SendChatMessage(ctx, sessionID, content, msgID); err != nil {
		return "", err
	}

	// Runner accepted the message — now safe to persist + publish.
	msg, err := m.AppendMessage(ctx, sessionID, RoleUser, content)
	if err != nil {
		return "", err
	}

	if m.hub != nil {
		m.hub.Publish(sessionID, SSEEvent{
			Seq:     msg.Seq,
			Role:    RoleUser,
			Content: content,
		})
	}

	return msgID, nil
}

// UpdateSessionMetadata writes session metadata changes (title, last_active).
func (m *Manager) UpdateSessionMetadata(ctx context.Context, s Session) error {
	return m.store.UpdateSession(ctx, s)
}

// ListMessages returns the transcript slice (seq > sinceSeq, oldest-first,
// bounded by limit). Used by the REST bootstrap endpoint that backfills the
// browser ring buffer beyond what the SSE in-memory ring can replay.
+func (m *Manager) ListMessages(ctx context.Context, sessionID string, sinceSeq int64, limit int) ([]Message, error) { + return m.store.ListMessages(ctx, sessionID, sinceSeq, limit) +} diff --git a/internal/chat/manager_test.go b/internal/chat/manager_test.go new file mode 100644 index 00000000..b2bece43 --- /dev/null +++ b/internal/chat/manager_test.go @@ -0,0 +1,1713 @@ +package chat_test + +import ( + "context" + "database/sql" + "errors" + "fmt" + "path/filepath" + "strconv" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + "unicode/utf8" + + _ "modernc.org/sqlite" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/mhersson/contextmatrix/internal/chat" + "github.com/mhersson/contextmatrix/internal/chat/sqlite" + "github.com/mhersson/contextmatrix/internal/clock" +) + +// stubRunner is a fake chat.Runner used by manager tests. Counters are atomic +// because Manager.startConsumer spawns a goroutine that calls StreamLogs +// independently of the test goroutine — plain ints would race under -race. 
type stubRunner struct {
	startCalls    atomic.Int64
	endCalls      atomic.Int64
	sendCalls     atomic.Int64
	streamCalls   atomic.Int64
	activeStreams atomic.Int32
	startErr      error
	sendErr       error
	streamLogsFn  func(ctx context.Context, sessionID string, onEntry func(chat.LogEntry)) error
	mu            sync.Mutex
	lastOpts      chat.StartChatOpts
}

// StartChat records the call and its opts, then either fails with startErr
// or returns a fixed container ID.
func (s *stubRunner) StartChat(ctx context.Context, opts chat.StartChatOpts) (string, error) {
	s.startCalls.Add(1)
	s.mu.Lock()
	s.lastOpts = opts
	s.mu.Unlock()

	if s.startErr != nil {
		return "", s.startErr
	}

	return "container-abc", nil
}

func (s *stubRunner) EndChat(ctx context.Context, sessionID string) error {
	s.endCalls.Add(1)

	return nil
}

func (s *stubRunner) SendChatMessage(ctx context.Context, sessionID, content, messageID string) error {
	s.sendCalls.Add(1)

	return s.sendErr
}

// StreamLogs delegates to streamLogsFn when set; otherwise it blocks until
// the consumer context is cancelled, mimicking a quiet log stream.
func (s *stubRunner) StreamLogs(ctx context.Context, sessionID string, onEntry func(chat.LogEntry)) error {
	s.streamCalls.Add(1)

	s.activeStreams.Add(1)
	defer s.activeStreams.Add(-1)

	if s.streamLogsFn != nil {
		return s.streamLogsFn(ctx, sessionID, onEntry)
	}

	<-ctx.Done()

	return ctx.Err()
}

// newManagerWithStubs builds a Manager backed by a real SQLite store on a
// per-test temp file and the stubRunner above.
func newManagerWithStubs(t *testing.T) (*chat.Manager, *stubRunner, chat.Store) {
	t.Helper()
	store, err := sqlite.Open(filepath.Join(t.TempDir(), "chats.db"))
	require.NoError(t, err)
	t.Cleanup(func() { _ = store.Close() })

	runner := &stubRunner{}
	mgr := chat.NewManager(chat.Config{
		Store:   store,
		Runner:  runner,
		Clock:   clock.Real(),
		IdleTTL: time.Hour,
	})

	return mgr, runner, store
}

func TestManager_CreateSession_RowExists(t *testing.T) {
	mgr, _, _ := newManagerWithStubs(t)
	ctx := context.Background()

	sess, err := mgr.CreateSession(ctx, chat.CreateInput{
		Title:     "runner-auth",
		Project:   "contextmatrix-runner",
		CreatedBy: "human:web-abcd1234",
	})
	require.NoError(t, err)
	assert.NotEmpty(t, sess.ID)
	assert.Equal(t, chat.StatusCold, sess.Status,
		"newly-created sessions are cold")
	assert.Equal(t, "runner-auth", sess.Title)
}

func TestManager_OpenSession_ColdStartsContainer(t *testing.T) {
	store, err := sqlite.Open(filepath.Join(t.TempDir(), "chats.db"))
	require.NoError(t, err)
	t.Cleanup(func() { _ = store.Close() })

	runner := &stubRunner{}
	mgr := chat.NewManager(chat.Config{
		Store:   store,
		Runner:  runner,
		Clock:   clock.Real(),
		IdleTTL: time.Hour,
		ResolveRepoURL: func(_ context.Context, _ string) (string, error) {
			return "https://example.com/alpha.git", nil
		},
	})

	ctx := context.Background()
	sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "t", Project: "alpha", CreatedBy: "human:web-x"})
	require.NoError(t, err)

	got, err := mgr.OpenSession(ctx, sess.ID)
	require.NoError(t, err)
	assert.Equal(t, chat.StatusActive, got.Status)
	assert.Equal(t, "container-abc", got.ContainerID)
	assert.Equal(t, int64(1), runner.startCalls.Load(), "container started exactly once")
	assert.Equal(t, []string{"alpha"}, got.Workspace, "project recorded in workspace list")
}

func TestManager_OpenSession_WarmIdleReattaches(t *testing.T) {
	mgr, runner, store := newManagerWithStubs(t)
	ctx := context.Background()
	sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "t", CreatedBy: "human:web-x"})
	require.NoError(t, err)

	sess.Status = chat.StatusWarmIdle
	sess.ContainerID = "container-existing"
	require.NoError(t, store.UpdateSession(ctx, sess))

	got, err := mgr.OpenSession(ctx, sess.ID)
	require.NoError(t, err)
	assert.Equal(t, chat.StatusActive, got.Status)
	assert.Equal(t, "container-existing", got.ContainerID)
	assert.Equal(t, int64(0), runner.startCalls.Load(), "warm-idle reattach must not start a new container")
}

func TestManager_OpenSession_AlreadyActive_NoOp(t *testing.T) {
	mgr, runner, store := newManagerWithStubs(t)
	ctx := context.Background()
	sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "t", CreatedBy: "human:web-x"})
	require.NoError(t, err)

	sess.Status = chat.StatusActive
	sess.ContainerID = "container-x"
	require.NoError(t, store.UpdateSession(ctx, sess))

	_, err = mgr.OpenSession(ctx, sess.ID)
	require.NoError(t, err)
	assert.Equal(t, int64(0), runner.startCalls.Load())
}

// TestManager_OpenSession_AlreadyActive_StartsConsumer ensures the active
// branch reattaches the runner-log consumer. CM-restart strands the in-
// memory consumer goroutine while the session row stays active; without
// this, an /open call on an active session leaves the bridge missing.
func TestManager_OpenSession_AlreadyActive_StartsConsumer(t *testing.T) {
	mgr, runner, store := newManagerWithStubs(t)
	ctx := context.Background()
	sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "t", CreatedBy: "human:web-x"})
	require.NoError(t, err)

	sess.Status = chat.StatusActive
	sess.ContainerID = "container-x"
	require.NoError(t, store.UpdateSession(ctx, sess))

	_, err = mgr.OpenSession(ctx, sess.ID)
	require.NoError(t, err)

	require.Eventually(t, func() bool {
		return runner.streamCalls.Load() == 1
	}, 2*time.Second, 10*time.Millisecond, "consumer must stream logs from runner")
	assert.Equal(t, int64(0), runner.startCalls.Load(), "no new container")
}

// TestManager_Reattach_Active starts a runner-log consumer for an already-
// active session whose in-memory consumer was lost (CM restart). The DB
// row is left as-is.
+func TestManager_Reattach_Active(t *testing.T) { + mgr, runner, store := newManagerWithStubs(t) + ctx := context.Background() + sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "t", CreatedBy: "x"}) + require.NoError(t, err) + + sess.Status = chat.StatusActive + sess.ContainerID = "container-x" + require.NoError(t, store.UpdateSession(ctx, sess)) + + require.NoError(t, mgr.Reattach(ctx, sess.ID)) + + require.Eventually(t, func() bool { + return runner.streamCalls.Load() == 1 + }, 2*time.Second, 10*time.Millisecond) + + got, err := store.GetSession(ctx, sess.ID) + require.NoError(t, err) + assert.Equal(t, chat.StatusActive, got.Status, "Reattach must not change status") + assert.Equal(t, int64(0), runner.startCalls.Load()) +} + +// TestManager_Reattach_WarmIdle starts a consumer for a warm-idle session +// and refreshes LastActive so the idle reaper doesn't end it. Status is +// intentionally left at warm-idle — the SessionUpdate SSE type has no +// Status field yet, so a flip-to-active here would desync the sidebar. 
+func TestManager_Reattach_WarmIdle(t *testing.T) { + mgr, runner, store := newManagerWithStubs(t) + ctx := context.Background() + sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "t", CreatedBy: "x"}) + require.NoError(t, err) + + sess.Status = chat.StatusWarmIdle + sess.ContainerID = "container-warm" + old := time.Now().Add(-time.Hour).UTC().Truncate(time.Second) + sess.LastActive = old + require.NoError(t, store.UpdateSession(ctx, sess)) + + require.NoError(t, mgr.Reattach(ctx, sess.ID)) + + require.Eventually(t, func() bool { + return runner.streamCalls.Load() == 1 + }, 2*time.Second, 10*time.Millisecond) + + got, err := store.GetSession(ctx, sess.ID) + require.NoError(t, err) + assert.Equal(t, chat.StatusWarmIdle, got.Status) + assert.True(t, got.LastActive.After(old), "LastActive must be refreshed") + assert.Equal(t, int64(0), runner.startCalls.Load()) +} + +// TestManager_Reattach_Cold is a no-op — cold sessions have no container +// to reattach to. +func TestManager_Reattach_Cold(t *testing.T) { + mgr, runner, _ := newManagerWithStubs(t) + ctx := context.Background() + sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "t", CreatedBy: "x"}) + require.NoError(t, err) + + require.NoError(t, mgr.Reattach(ctx, sess.ID)) + + // Give any (incorrect) goroutine spawn time to call StreamLogs. + time.Sleep(50 * time.Millisecond) + assert.Equal(t, int64(0), runner.streamCalls.Load()) + assert.Equal(t, int64(0), runner.startCalls.Load()) +} + +// TestManager_Reattach_Idempotent guarantees concurrent or repeated calls +// don't spawn extra consumer goroutines. 
func TestManager_Reattach_Idempotent(t *testing.T) {
	mgr, runner, store := newManagerWithStubs(t)
	ctx := context.Background()
	sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "t", CreatedBy: "x"})
	require.NoError(t, err)

	sess.Status = chat.StatusActive
	sess.ContainerID = "container-x"
	require.NoError(t, store.UpdateSession(ctx, sess))

	require.NoError(t, mgr.Reattach(ctx, sess.ID))
	require.NoError(t, mgr.Reattach(ctx, sess.ID))
	require.NoError(t, mgr.Reattach(ctx, sess.ID))

	require.Eventually(t, func() bool {
		return runner.streamCalls.Load() == 1
	}, 2*time.Second, 10*time.Millisecond)
	// Give any duplicate goroutine spawn a chance to (wrongly) increment.
	time.Sleep(50 * time.Millisecond)
	assert.Equal(t, int64(1), runner.streamCalls.Load(), "exactly one consumer")
}

func TestManager_EndSession_ActiveTransitionsToCold(t *testing.T) {
	mgr, runner, store := newManagerWithStubs(t)
	ctx := context.Background()
	sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "t", CreatedBy: "x"})
	require.NoError(t, err)

	sess.Status = chat.StatusActive
	sess.ContainerID = "container-x"
	require.NoError(t, store.UpdateSession(ctx, sess))

	require.NoError(t, mgr.EndSession(ctx, sess.ID))
	got, err := store.GetSession(ctx, sess.ID)
	require.NoError(t, err)
	assert.Equal(t, chat.StatusCold, got.Status)
	assert.Empty(t, got.ContainerID)
	assert.Equal(t, int64(1), runner.endCalls.Load())
}

func TestManager_EndSession_AlreadyCold_NoOp(t *testing.T) {
	mgr, runner, _ := newManagerWithStubs(t)
	ctx := context.Background()
	sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "t", CreatedBy: "x"})
	require.NoError(t, err)
	require.NoError(t, mgr.EndSession(ctx, sess.ID))
	assert.Equal(t, int64(0), runner.endCalls.Load(), "ending an already-cold session must not call runner")
}

// TestManager_EndSession_RecoversFromStuckEnding verifies that EndSession
// succeeds when the row is already in status=ending (a prior call failed
// between the two-write pattern and left the row wedged), and that the session
// can subsequently be reopened via OpenSession.
func TestManager_EndSession_RecoversFromStuckEnding(t *testing.T) {
	t.Parallel()
	mgr, _, store := newManagerWithStubs(t)
	ctx := context.Background()

	sess, err := mgr.CreateSession(ctx, chat.CreateInput{Project: "p", CreatedBy: "human:t"})
	require.NoError(t, err)

	// Simulate a prior partial failure: set status=ending directly in the store.
	sess.Status = chat.StatusEnding
	require.NoError(t, store.UpdateSession(ctx, sess))

	// EndSession must succeed even though the row is already in ending.
	require.NoError(t, mgr.EndSession(ctx, sess.ID))

	got, err := mgr.GetSession(ctx, sess.ID)
	require.NoError(t, err)
	require.Equal(t, chat.StatusCold, got.Status, "session must be cold after EndSession recovers from stuck-ending")

	// The recovered session must be openable again (OpenSession previously
	// rejected status=ending rows, so a stuck row would prevent reopening).
	_, err = mgr.OpenSession(ctx, sess.ID)
	require.NoError(t, err, "session must be openable after EndSession clears the stuck-ending state")
}

// TestManager_EndSession_NeverPersistsEndingStatus verifies that a successful
// EndSession call never writes status=ending to the store (single-write
// contract). If the first write in the old two-step pattern had written
// status=ending, the injected fault on that write would cause EndSession to
// fail — but with the single-write pattern the fault is never triggered.
func TestManager_EndSession_NeverPersistsEndingStatus(t *testing.T) {
	t.Parallel()

	ctx := context.Background()

	inner, err := sqlite.Open(filepath.Join(t.TempDir(), "chats.db"))
	require.NoError(t, err)
	t.Cleanup(func() { _ = inner.Close() })

	// trackingStore wraps the real store and records every status written.
	ts := &trackingStore{Store: inner}
	runner := &stubRunner{}
	mgr := chat.NewManager(chat.Config{
		Store:   ts,
		Runner:  runner,
		Clock:   clock.Real(),
		IdleTTL: time.Hour,
	})

	sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "t", CreatedBy: "x"})
	require.NoError(t, err)

	// Move the session to active so EndSession has real work to do.
	sess.Status = chat.StatusActive
	sess.ContainerID = "container-x"
	require.NoError(t, inner.UpdateSession(ctx, sess))

	require.NoError(t, mgr.EndSession(ctx, sess.ID))

	// Verify no intermediate write ever persisted status=ending.
	for _, s := range ts.writtenStatuses() {
		require.NotEqual(t, chat.StatusEnding, s,
			"EndSession must never persist status=ending; got intermediate statuses: %v", ts.writtenStatuses())
	}

	// Session must end up cold.
	got, err := inner.GetSession(ctx, sess.ID)
	require.NoError(t, err)
	require.Equal(t, chat.StatusCold, got.Status)
}

func TestManager_AppendMessage_AssignsMonotonicSeq(t *testing.T) {
	mgr, _, store := newManagerWithStubs(t)
	ctx := context.Background()
	sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "t", CreatedBy: "x"})
	require.NoError(t, err)

	m1, err := mgr.AppendMessage(ctx, sess.ID, chat.RoleUser, `{"text":"hi"}`)
	require.NoError(t, err)
	assert.Equal(t, int64(1), m1.Seq)

	m2, err := mgr.AppendMessage(ctx, sess.ID, chat.RoleAssistantText, `{"text":"hello"}`)
	require.NoError(t, err)
	assert.Equal(t, int64(2), m2.Seq)

	msgs, err := store.ListMessages(ctx, sess.ID, 0, 100)
	require.NoError(t, err)
	require.Len(t, msgs, 2)
}

func TestManager_AutoTitle_FromFirstUserMessage(t *testing.T) {
	mgr, _, store := newManagerWithStubs(t)
	ctx := context.Background()
	sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "", CreatedBy: "x"})
	require.NoError(t, err)
	assert.Empty(t, sess.Title)

	_, err = mgr.AppendMessage(ctx, sess.ID, chat.RoleUser, "let's investigate the auth flow")
	require.NoError(t, err)

	got, _ := store.GetSession(ctx, sess.ID)
	assert.Equal(t, "let's investigate the auth flow", got.Title)
}

func TestManager_AutoTitle_TruncatesAt50Chars(t *testing.T) {
	mgr, _, store := newManagerWithStubs(t)
	ctx := context.Background()
	sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "", CreatedBy: "x"})
	require.NoError(t, err)

	long := "this is a fairly long first message that exceeds fifty characters total"
	_, err = mgr.AppendMessage(ctx, sess.ID, chat.RoleUser, long)
	require.NoError(t, err)

	got, _ := store.GetSession(ctx, sess.ID)
	assert.LessOrEqual(t, utf8.RuneCountInString(got.Title), 51) // 50 runes + ellipsis
	assert.True(t, strings.HasSuffix(got.Title, "…"))
}

// TestManager_AutoTitle_RuneSafe verifies that auto-title slices at a rune
// boundary, not a byte boundary. Multi-byte characters (UTF-8) like "é"
// (2 bytes) would otherwise be cut mid-rune and round-trip as U+FFFD garbage
// through JSON marshaling.
func TestManager_AutoTitle_RuneSafe(t *testing.T) {
	mgr, _, store := newManagerWithStubs(t)
	ctx := context.Background()
	sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "", CreatedBy: "x"})
	require.NoError(t, err)

	// 49 ASCII chars + "é" places the first byte of "é" at byte index 49 and
	// the second byte at index 50. A naive byte-slice [:50] cuts mid-rune.
	long := strings.Repeat("a", 49) + "é trailing words"
	_, err = mgr.AppendMessage(ctx, sess.ID, chat.RoleUser, long)
	require.NoError(t, err)

	got, _ := store.GetSession(ctx, sess.ID)
	assert.True(t, utf8.ValidString(got.Title),
		"auto-title must remain valid UTF-8; got %q", got.Title)
	assert.True(t, strings.HasSuffix(got.Title, "…"))
}

func TestManager_MarkWarmIdle_ActiveToWarmIdle(t *testing.T) {
	mgr, _, store := newManagerWithStubs(t)
	ctx := context.Background()
	sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "t", CreatedBy: "x"})
	require.NoError(t, err)

	sess.Status = chat.StatusActive
	sess.ContainerID = "c-1"
	require.NoError(t, store.UpdateSession(ctx, sess))

	require.NoError(t, mgr.MarkWarmIdle(ctx, sess.ID))
	got, err := store.GetSession(ctx, sess.ID)
	require.NoError(t, err)
	assert.Equal(t, chat.StatusWarmIdle, got.Status)
	assert.Equal(t, "c-1", got.ContainerID, "container ID must survive warm-idle")
}

func TestManager_MarkWarmIdle_ColdNoOp(t *testing.T) {
	mgr, _, _ := newManagerWithStubs(t)
	ctx := context.Background()
	sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "t", CreatedBy: "x"})
	require.NoError(t, err)
	// session is cold; MarkWarmIdle should not change anything
	require.NoError(t, mgr.MarkWarmIdle(ctx, sess.ID))
	got, _ := mgr.GetSession(ctx, sess.ID)
	assert.Equal(t, chat.StatusCold, got.Status, "cold sessions stay cold")
}

// TestManager_OpenSession_MaxConcurrent_ParallelTOCTOU exercises the
// concurrency cap under a tight race: ten goroutines call OpenSession at
// once with MaxConcurrent=2. Without the lock fix, the two ListSessions
// reads happen before any StartChat call mutates the store, so several
// goroutines pass the limit check simultaneously and the runner sees
// more than two StartChat calls. With the fix exactly two start.
func TestManager_OpenSession_MaxConcurrent_ParallelTOCTOU(t *testing.T) {
	store, err := sqlite.Open(filepath.Join(t.TempDir(), "chats.db"))
	require.NoError(t, err)
	t.Cleanup(func() { _ = store.Close() })

	const total = 10

	// slowRunner stalls StartChat briefly to widen the race window.
	runner := &slowStartRunner{delay: 10 * time.Millisecond}

	mgr := chat.NewManager(chat.Config{
		Store: store, Runner: runner, Clock: clock.Real(),
		IdleTTL: time.Hour, MaxConcurrent: 2,
	})

	ctx := context.Background()

	ids := make([]string, total)

	for i := range total {
		sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "", CreatedBy: "x"})
		require.NoError(t, err)

		ids[i] = sess.ID
	}

	var (
		wg        sync.WaitGroup
		successes atomic.Int64
		rejects   atomic.Int64
	)

	for _, id := range ids {
		wg.Add(1)

		go func(sessID string) {
			defer wg.Done()

			_, err := mgr.OpenSession(ctx, sessID)

			switch {
			case err == nil:
				successes.Add(1)
			case errors.Is(err, chat.ErrTooManyConcurrent):
				rejects.Add(1)
			default:
				t.Errorf("OpenSession(%s) unexpected error: %v", sessID, err)
			}
		}(id)
	}

	wg.Wait()

	assert.Equal(t, int64(2), successes.Load(), "exactly MaxConcurrent (=2) opens must succeed")
	assert.Equal(t, int64(total-2), rejects.Load(), "all other opens must be rejected")
	assert.LessOrEqual(t, runner.startCalls.Load(), int64(2),
		"runner.StartChat must be called at most MaxConcurrent times (no leaked containers)")
}

// TestManager_AppendMessage_SeqMonotonicUnderConcurrency exercises the
// serialisation fix: concurrent AppendMessage calls on the same session must
// land in the store both (a) with strictly monotonic seq values and (b) in
// insertion order — so the rowid order matches the seq order. Without
// holding m.mu across the store insert, two appends can race past one
// another and land out of seq order on disk.
func TestManager_AppendMessage_SeqMonotonicUnderConcurrency(t *testing.T) {
	dbPath := filepath.Join(t.TempDir(), "chats.db")

	store, err := sqlite.Open(dbPath)
	require.NoError(t, err)
	t.Cleanup(func() { _ = store.Close() })

	mgr := chat.NewManager(chat.Config{
		Store: store, Runner: &stubRunner{}, Clock: clock.Real(), IdleTTL: time.Hour,
	})

	ctx := context.Background()

	sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "seq", CreatedBy: "x"})
	require.NoError(t, err)

	const N = 50

	var wg sync.WaitGroup

	wg.Add(N)

	for i := range N {
		go func(i int) {
			defer wg.Done()

			_, err := mgr.AppendMessage(ctx, sess.ID, chat.RoleAssistantText, strconv.Itoa(i))
			assert.NoError(t, err)
		}(i)
	}

	wg.Wait()

	// (a) ListMessages orders by seq; verify seqs are 1..N with no holes.
	msgs, err := store.ListMessages(ctx, sess.ID, 0, 1000)
	require.NoError(t, err)
	require.Len(t, msgs, N)

	for i, m := range msgs {
		assert.Equal(t, int64(i+1), m.Seq, "seq %d should be %d", i, i+1)
	}

	// (b) Open the DB directly and query in rowid order. The seq column
	// must increase monotonically with rowid — i.e. the insertion order
	// matches the seq order. This is the assertion that fails when the
	// store write happens outside the seq-assignment lock.
	db, err := sql.Open("sqlite", dbPath)
	require.NoError(t, err)
	t.Cleanup(func() { _ = db.Close() })

	rows, err := db.QueryContext(ctx,
		`SELECT seq FROM chat_messages WHERE session_id = ?
		 ORDER BY id ASC`, sess.ID)
	require.NoError(t, err)
	t.Cleanup(func() { _ = rows.Close() })

	var prev int64

	for rows.Next() {
		var seq int64
		require.NoError(t, rows.Scan(&seq))
		assert.Greater(t, seq, prev,
			"insertion order: seq must increase with rowid, got prev=%d cur=%d", prev, seq)
		prev = seq
	}

	require.NoError(t, rows.Err())
}

// slowStartRunner stalls StartChat for a configurable delay to widen race
// windows in concurrency tests.
type slowStartRunner struct {
	delay      time.Duration
	startCalls atomic.Int64
}

func (s *slowStartRunner) StartChat(_ context.Context, _ chat.StartChatOpts) (string, error) {
	s.startCalls.Add(1)
	time.Sleep(s.delay)

	return "container-" + strconv.FormatInt(s.startCalls.Load(), 10), nil
}

func (s *slowStartRunner) EndChat(_ context.Context, _ string) error { return nil }

func (s *slowStartRunner) SendChatMessage(_ context.Context, _, _, _ string) error { return nil }

func (s *slowStartRunner) StreamLogs(ctx context.Context, _ string, _ func(chat.LogEntry)) error {
	<-ctx.Done()

	return ctx.Err()
}

func TestManager_OpenSession_RespectsMaxConcurrent(t *testing.T) {
	store, err := sqlite.Open(filepath.Join(t.TempDir(), "chats.db"))
	require.NoError(t, err)
	t.Cleanup(func() { _ = store.Close() })

	runner := &stubRunner{}
	mgr := chat.NewManager(chat.Config{
		Store: store, Runner: runner, Clock: clock.Real(),
		IdleTTL: time.Hour, MaxConcurrent: 2,
		ResolveRepoURL: func(ctx context.Context, project string) (string, error) {
			return "", nil
		},
	})

	ctx := context.Background()
	for i := 0; i < 2; i++ {
		sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "", CreatedBy: "x"})
		require.NoError(t, err)
		_, err = mgr.OpenSession(ctx, sess.ID)
		require.NoError(t, err)
	}

	// Third open should fail with ErrTooManyConcurrent.
	sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "", CreatedBy: "x"})
	require.NoError(t, err)
	_, err = mgr.OpenSession(ctx, sess.ID)
	require.ErrorIs(t, err, chat.ErrTooManyConcurrent)
}

func TestManager_ListSessions_FilterByStatus(t *testing.T) {
	mgr, _, store := newManagerWithStubs(t)
	ctx := context.Background()

	sess1, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "a", CreatedBy: "x"})
	require.NoError(t, err)
	sess2, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "b", CreatedBy: "x"})
	require.NoError(t, err)

	// Flip sess2 to warm-idle in the store.
	sess2.Status = chat.StatusWarmIdle
	require.NoError(t, store.UpdateSession(ctx, sess2))

	all, err := mgr.ListSessions(ctx, chat.SessionFilter{})
	require.NoError(t, err)
	assert.Len(t, all, 2)

	cold, err := mgr.ListSessions(ctx, chat.SessionFilter{Status: chat.StatusCold})
	require.NoError(t, err)
	assert.Len(t, cold, 1)
	assert.Equal(t, sess1.ID, cold[0].ID)
}

func TestManager_DeleteSession_ColdDeletesRow(t *testing.T) {
	mgr, _, store := newManagerWithStubs(t)
	ctx := context.Background()

	sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "t", CreatedBy: "x"})
	require.NoError(t, err)

	require.NoError(t, mgr.DeleteSession(ctx, sess.ID))

	_, err = store.GetSession(ctx, sess.ID)
	require.ErrorIs(t, err, chat.ErrSessionNotFound)
}

func TestManager_DeleteSession_ActiveEndsFirst(t *testing.T) {
	mgr, runner, store := newManagerWithStubs(t)
	ctx := context.Background()

	sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "t", CreatedBy: "x"})
	require.NoError(t, err)

	sess.Status = chat.StatusActive
	sess.ContainerID = "c-1"
	require.NoError(t, store.UpdateSession(ctx, sess))

	require.NoError(t, mgr.DeleteSession(ctx, sess.ID))

	assert.Equal(t, int64(1), runner.endCalls.Load(), "EndSession must have stopped the container")

	_, err = store.GetSession(ctx, sess.ID)
	require.ErrorIs(t, err, chat.ErrSessionNotFound)
}

func TestManager_SendUserMessage_HappyPath(t *testing.T) {
	mgr, runner, store := newManagerWithStubs(t)
	ctx := context.Background()

	sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "t", CreatedBy: "x"})
	require.NoError(t, err)
	// Pre-open to active so OpenSession is not needed inside SendUserMessage.
	sess.Status = chat.StatusActive
	sess.ContainerID = "c-1"
	require.NoError(t, store.UpdateSession(ctx, sess))

	msgID, err := mgr.SendUserMessage(ctx, sess.ID, "hello world")
	require.NoError(t, err)
	assert.NotEmpty(t, msgID)
	assert.Equal(t, int64(1), runner.sendCalls.Load(), "SendChatMessage must be called once")

	msgs, err := store.ListMessages(ctx, sess.ID, 0, 100)
	require.NoError(t, err)
	require.Len(t, msgs, 1)
	assert.Equal(t, chat.RoleUser, msgs[0].Role)
	assert.Equal(t, "hello world", msgs[0].Content)
}

// TestManager_SendUserMessage_RunnerErrorDoesNotPersist exercises the
// runner-first ordering: if the runner.SendChatMessage call fails, the
// user message is NOT persisted and not published to the hub. The UI sees
// the error and can retry without ending up with an orphaned echo.
func TestManager_SendUserMessage_RunnerErrorDoesNotPersist(t *testing.T) {
	mgr, runner, store := newManagerWithStubs(t)
	ctx := context.Background()

	sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "t", CreatedBy: "x"})
	require.NoError(t, err)

	sess.Status = chat.StatusActive
	sess.ContainerID = "c-1"
	require.NoError(t, store.UpdateSession(ctx, sess))

	runner.sendErr = errors.New("runner unreachable")

	_, err = mgr.SendUserMessage(ctx, sess.ID, "hello")
	require.Error(t, err, "runner failure must propagate to the caller")
	assert.Contains(t, err.Error(), "runner unreachable")

	// No persisted user message — the runner-first ordering means we never
	// got past the runner call.
	msgs, err := store.ListMessages(ctx, sess.ID, 0, 100)
	require.NoError(t, err)
	assert.Empty(t, msgs, "no user message must be persisted when runner.SendChatMessage fails")
}

func TestManager_SendUserMessage_OpensColdSession(t *testing.T) {
	mgr, runner, _ := newManagerWithStubs(t)
	ctx := context.Background()

	sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "t", CreatedBy: "x"})
	require.NoError(t, err)
	// Session remains cold — SendUserMessage must open it first.

	_, err = mgr.SendUserMessage(ctx, sess.ID, "hi")
	require.NoError(t, err)
	assert.Equal(t, int64(1), runner.startCalls.Load(), "cold session must trigger StartChat")
	assert.Equal(t, int64(1), runner.sendCalls.Load())
}

func TestManager_UpdateSessionMetadata_ChangesTitle(t *testing.T) {
	mgr, _, store := newManagerWithStubs(t)
	ctx := context.Background()

	sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "old", CreatedBy: "x"})
	require.NoError(t, err)

	sess.Title = "new title"
	require.NoError(t, mgr.UpdateSessionMetadata(ctx, sess))

	got, err := store.GetSession(ctx, sess.ID)
	require.NoError(t, err)
	assert.Equal(t, "new title", got.Title)
}

// TestManager_OpenSession_BridgesRunnerLogs verifies that an assistant text
// event emitted by the runner's /logs stream is persisted as an
// assistant_text message and published to the SSE hub. Without this
// bridge, the browser would see only the user echo and no reply.
func TestManager_OpenSession_BridgesRunnerLogs(t *testing.T) {
	store, err := sqlite.Open(filepath.Join(t.TempDir(), "chats.db"))
	require.NoError(t, err)
	t.Cleanup(func() { _ = store.Close() })

	// delivered is closed once the stub has pushed its single log entry,
	// letting the test distinguish "consumer never started" from "hub
	// never received the bridged event".
	delivered := make(chan struct{})
	runner := &stubRunner{
		streamLogsFn: func(ctx context.Context, _ string, onEntry func(chat.LogEntry)) error {
			onEntry(chat.LogEntry{Type: "text", Content: "Hello back."})
			close(delivered)

			// Block like a real log stream until the consumer is cancelled.
			<-ctx.Done()

			return ctx.Err()
		},
	}

	hub := chat.NewSSEHub(128)
	mgr := chat.NewManager(chat.Config{
		Store:   store,
		Runner:  runner,
		Clock:   clock.Real(),
		IdleTTL: time.Hour,
		Hub:     hub,
	})

	ctx := context.Background()

	sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "bridge", CreatedBy: "human:test"})
	require.NoError(t, err)

	// Subscribe before opening so the bridged event cannot be published
	// before anyone is listening.
	ch, _, _ := hub.Subscribe(sess.ID, 0)

	t.Cleanup(func() { hub.Unsubscribe(sess.ID, ch) })

	_, err = mgr.OpenSession(ctx, sess.ID)
	require.NoError(t, err)

	select {
	case <-delivered:
	case <-time.After(2 * time.Second):
		t.Fatal("StreamLogs onEntry never invoked")
	}

	// The bridged message must arrive as an assistant_text event with the
	// first sequence number for the session.
	select {
	case e := <-ch:
		assert.Equal(t, chat.RoleAssistantText, e.Role)
		assert.Equal(t, "Hello back.", e.Content)
		assert.Equal(t, int64(1), e.Seq)
	case <-time.After(2 * time.Second):
		t.Fatal("hub did not receive assistant_text event")
	}

	// EndSession should stop the consumer; verify via streamCalls staying at 1.
	require.NoError(t, mgr.EndSession(ctx, sess.ID))

	// Re-opening should kick off a new consumer (idempotency check would
	// require another OpenSession; verifying stop is enough for this test).
	assert.Equal(t, int64(1), runner.streamCalls.Load())
}

// TestManager_EndThenReopen_SpawnsFreshConsumer is the regression for the
// startConsumer ↔ stopConsumer cleanup race. With the unfixed code,
// stopConsumer cancels the consumer context and returns immediately; the
// goroutine's deferred map-delete runs asynchronously. A fast Reopen that
// runs while the deferred delete is still pending finds a stale entry in
// m.consumers and returns early — the new session has no log bridge.
//
// We simulate slow goroutine exit with a streamLogsFn that sleeps after
// ctx.Done. With the fix, stopConsumer waits on a per-consumer done channel
// and the entry is gone before Reopen runs.
func TestManager_EndThenReopen_SpawnsFreshConsumer(t *testing.T) {
	store, err := sqlite.Open(filepath.Join(t.TempDir(), "chats.db"))
	require.NoError(t, err)
	t.Cleanup(func() { _ = store.Close() })

	runner := &stubRunner{
		streamLogsFn: func(ctx context.Context, _ string, _ func(chat.LogEntry)) error {
			<-ctx.Done()
			// Simulate slow goroutine exit — the goroutine has received cancel
			// but has not yet run its cleanup defers.
			time.Sleep(50 * time.Millisecond)

			return ctx.Err()
		},
	}

	mgr := chat.NewManager(chat.Config{
		Store:   store,
		Runner:  runner,
		Clock:   clock.Real(),
		IdleTTL: time.Hour,
	})

	ctx := context.Background()
	sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "t", CreatedBy: "human:x"})
	require.NoError(t, err)

	_, err = mgr.OpenSession(ctx, sess.ID)
	require.NoError(t, err)
	require.Eventually(t, func() bool { return runner.streamCalls.Load() == 1 },
		time.Second, 5*time.Millisecond, "first open must start consumer")

	require.NoError(t, mgr.EndSession(ctx, sess.ID))

	_, err = mgr.OpenSession(ctx, sess.ID)
	require.NoError(t, err)

	// Without the fix, streamCalls stays at 1 because the second startConsumer
	// returned early on a stale map entry.
	require.Eventually(t, func() bool { return runner.streamCalls.Load() == 2 },
		time.Second, 5*time.Millisecond,
		"Reopen after End must spawn a fresh runner-log consumer")

	require.NoError(t, mgr.EndSession(ctx, sess.ID))
}

// TestManager_AppendMessage_TruncatesOversizedContent verifies that
// runner-emitted entries beyond the per-message size cap are truncated with
// a marker before persistence. Without this cap, a chatty tool (cat of a
// large file, verbose tool_result) fills chats.db linearly and never
// reclaims the space.
func TestManager_AppendMessage_TruncatesOversizedContent(t *testing.T) {
	mgr, _, store := newManagerWithStubs(t)
	ctx := context.Background()
	sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "t", CreatedBy: "x"})
	require.NoError(t, err)

	// 100 KiB payload — well past the 32 KiB cap asserted below.
	huge := strings.Repeat("a", 100*1024)

	msg, err := mgr.AppendMessage(ctx, sess.ID, chat.RoleAssistantText, huge)
	require.NoError(t, err)
	assert.Less(t, len(msg.Content), len(huge), "oversized content must be truncated")
	assert.LessOrEqual(t, len(msg.Content), 32*1024+64, "truncated content must fit the cap (with marker)")
	assert.Contains(t, msg.Content, "[truncated]", "truncation must leave a marker")

	// The returned message and the persisted row must agree — truncation
	// must happen before the store write, not after.
	msgs, err := store.ListMessages(ctx, sess.ID, 0, 10)
	require.NoError(t, err)
	require.Len(t, msgs, 1)
	assert.Equal(t, msg.Content, msgs[0].Content, "persisted content must match returned content")
}

// TestManager_AppendMessage_DoesNotTruncateSmallContent ensures the cap only
// fires on oversized content.
+func TestManager_AppendMessage_DoesNotTruncateSmallContent(t *testing.T) { + mgr, _, _ := newManagerWithStubs(t) + ctx := context.Background() + sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "t", CreatedBy: "x"}) + require.NoError(t, err) + + msg, err := mgr.AppendMessage(ctx, sess.ID, chat.RoleAssistantText, "hello world") + require.NoError(t, err) + assert.Equal(t, "hello world", msg.Content, "small content must not be touched") +} + +func TestManager_OpenSession_ColdWithPriorTranscript_SendsResume(t *testing.T) { + mgr, runner, _ := newManagerWithStubsAndConfig(t, chat.Config{IdleTTL: time.Hour}) + ctx := context.Background() + + sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "t", CreatedBy: "x"}) + require.NoError(t, err) + + // Seed a transcript so transcript.Build returns a non-nil resume. + _, err = mgr.AppendMessage(ctx, sess.ID, chat.RoleUser, "first goal") + require.NoError(t, err) + _, err = mgr.AppendMessage(ctx, sess.ID, chat.RoleAssistantText, "okay") + require.NoError(t, err) + + // End so the next OpenSession follows the cold-branch. + _, err = mgr.OpenSession(ctx, sess.ID) + require.NoError(t, err) + require.NoError(t, mgr.EndSession(ctx, sess.ID)) + + // Reopen. 
+ reopened, err := mgr.OpenSession(ctx, sess.ID) + require.NoError(t, err) + assert.Equal(t, chat.StatusActive, reopened.Status) + + runner.mu.Lock() + opts := runner.lastOpts + runner.mu.Unlock() + + assert.Equal(t, sess.ID, opts.SessionID) + require.NotNil(t, opts.Resume, "Resume must be sent on cold-reopen with prior transcript") + require.GreaterOrEqual(t, len(opts.Resume.Turns), 2, + "resume payload should carry the prior user + assistant turns") +} + +func TestManager_OpenSession_ColdEmptyTranscript_OmitsResume(t *testing.T) { + mgr, runner, _ := newManagerWithStubsAndConfig(t, chat.Config{IdleTTL: time.Hour}) + ctx := context.Background() + + sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "t", CreatedBy: "x"}) + require.NoError(t, err) + + _, err = mgr.OpenSession(ctx, sess.ID) + require.NoError(t, err) + + runner.mu.Lock() + opts := runner.lastOpts + runner.mu.Unlock() + + assert.Nil(t, opts.Resume, "fresh session must not carry a Resume") +} + +func TestManager_OpenSession_PassesModel(t *testing.T) { + mgr, runner, _ := newManagerWithStubsAndConfig(t, chat.Config{ + IdleTTL: time.Hour, + DefaultModel: "claude-sonnet-4-6", + }) + ctx := context.Background() + + sess, err := mgr.CreateSession(ctx, chat.CreateInput{ + Title: "t", + CreatedBy: "x", + Model: "claude-opus-4-7", + }) + require.NoError(t, err) + + _, err = mgr.OpenSession(ctx, sess.ID) + require.NoError(t, err) + + runner.mu.Lock() + opts := runner.lastOpts + runner.mu.Unlock() + + assert.Equal(t, "claude-opus-4-7", opts.Model, + "session-stored model must be passed to runner") +} + +func TestManager_OpenSession_FallsBackToDefaultModel(t *testing.T) { + mgr, runner, _ := newManagerWithStubsAndConfig(t, chat.Config{ + IdleTTL: time.Hour, + DefaultModel: "claude-sonnet-4-6", + }) + ctx := context.Background() + + sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "t", CreatedBy: "x"}) + require.NoError(t, err) + + _, err = mgr.OpenSession(ctx, sess.ID) + require.NoError(t, err) 
+ + runner.mu.Lock() + opts := runner.lastOpts + runner.mu.Unlock() + + assert.Equal(t, "claude-sonnet-4-6", opts.Model, + "empty session.Model falls back to config DefaultModel") +} + +func TestManager_CompleteRehydration_PersistsSummaryAndFlipsFlag(t *testing.T) { + mgr, _, store := newManagerWithStubsAndConfig(t, chat.Config{IdleTTL: time.Hour}) + ctx := context.Background() + + sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "t", CreatedBy: "x"}) + require.NoError(t, err) + + // Seed a transcript and reopen so rehydration_active flips on. + _, err = mgr.AppendMessage(ctx, sess.ID, chat.RoleUser, "task") + require.NoError(t, err) + _, err = mgr.OpenSession(ctx, sess.ID) + require.NoError(t, err) + require.NoError(t, mgr.EndSession(ctx, sess.ID)) + _, err = mgr.OpenSession(ctx, sess.ID) + require.NoError(t, err) + + reopened, err := store.GetSession(ctx, sess.ID) + require.NoError(t, err) + require.True(t, reopened.RehydrationActive, "reopen with prior transcript should set rehydration_active=true") + + err = mgr.CompleteRehydration(ctx, sess.ID, "Picking up where we left off — re-cloned foo.") + require.NoError(t, err) + + flipped, err := store.GetSession(ctx, sess.ID) + require.NoError(t, err) + assert.False(t, flipped.RehydrationActive, "CompleteRehydration must flip flag off") + + msgs, err := store.ListMessages(ctx, sess.ID, 0, 100) + require.NoError(t, err) + + var summary *chat.Message + + for i, msg := range msgs { + if msg.Role == chat.RoleAssistantText && msg.Content[:7] == "Picking" { + summary = &msgs[i] + + break + } + } + + require.NotNil(t, summary, "summary message must be persisted") + assert.False(t, summary.RehydrationPhase, "summary message must NOT carry the phase flag") +} + +func TestManager_CompleteRehydration_Idempotent(t *testing.T) { + t.Parallel() + mgr, _, store := newManagerWithStubsAndConfig(t, chat.Config{IdleTTL: time.Hour}) + ctx := context.Background() + + sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: 
"t", CreatedBy: "x"}) + require.NoError(t, err) + + // Set rehydration active, then complete it. + require.NoError(t, mgr.SetRehydrationActiveForTest(ctx, sess.ID, true)) + require.NoError(t, mgr.CompleteRehydration(ctx, sess.ID, "first call")) + + // Second call — must succeed and NOT append a second summary. + before, err := store.ListMessages(ctx, sess.ID, 0, 100) + require.NoError(t, err) + + require.NoError(t, mgr.CompleteRehydration(ctx, sess.ID, "second call")) + + after, err := store.ListMessages(ctx, sess.ID, 0, 100) + require.NoError(t, err) + + assert.Len(t, after, len(before), + "second call must not append another summary") +} + +func TestManager_SendUserMessage_EndsRehydrationPhase(t *testing.T) { + mgr, _, store := newManagerWithStubsAndConfig(t, chat.Config{IdleTTL: time.Hour}) + ctx := context.Background() + + sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "t", CreatedBy: "x"}) + require.NoError(t, err) + + _, err = mgr.AppendMessage(ctx, sess.ID, chat.RoleUser, "task") + require.NoError(t, err) + _, err = mgr.OpenSession(ctx, sess.ID) + require.NoError(t, err) + require.NoError(t, mgr.EndSession(ctx, sess.ID)) + _, err = mgr.OpenSession(ctx, sess.ID) + require.NoError(t, err) + + active, err := store.GetSession(ctx, sess.ID) + require.NoError(t, err) + require.True(t, active.RehydrationActive) + + _, err = mgr.SendUserMessage(ctx, sess.ID, "follow up") + require.NoError(t, err) + + after, err := store.GetSession(ctx, sess.ID) + require.NoError(t, err) + assert.False(t, after.RehydrationActive, + "first user message during rehydration must flip the flag off") +} + +func TestManager_EndSession_ResetsRehydrationActive(t *testing.T) { + mgr, _, store := newManagerWithStubsAndConfig(t, chat.Config{IdleTTL: time.Hour}) + ctx := context.Background() + + sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "t", CreatedBy: "x"}) + require.NoError(t, err) + + _, err = mgr.AppendMessage(ctx, sess.ID, chat.RoleUser, "task") + 
require.NoError(t, err) + _, err = mgr.OpenSession(ctx, sess.ID) + require.NoError(t, err) + require.NoError(t, mgr.EndSession(ctx, sess.ID)) + _, err = mgr.OpenSession(ctx, sess.ID) + require.NoError(t, err) + + require.NoError(t, mgr.EndSession(ctx, sess.ID)) + + got, err := store.GetSession(ctx, sess.ID) + require.NoError(t, err) + assert.False(t, got.RehydrationActive, + "EndSession must clear the rehydration flag in the cold transition") +} + +// TestManager_OpenSession_RollbackOnRehydrationPersistFailure verifies that if +// the store.SetRehydrationActive write fails after the container is already up, +// OpenSession rolls back the container (EndChat), clears the in-memory cache, +// resets the session row to cold, and returns an error — leaving no orphaned +// active container with an unset rehydration flag. +func TestManager_OpenSession_RollbackOnRehydrationPersistFailure(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + inner, err := sqlite.Open(filepath.Join(t.TempDir(), "chats.db")) + require.NoError(t, err) + t.Cleanup(func() { _ = inner.Close() }) + + fstore := &failingStore{Store: inner} + runner := &stubRunner{} + + mgr := chat.NewManager(chat.Config{ + Store: fstore, + Runner: runner, + Clock: clock.Real(), + IdleTTL: time.Hour, + }) + + sess, err := mgr.CreateSession(ctx, chat.CreateInput{Project: "p", CreatedBy: "human:t"}) + require.NoError(t, err) + + // Seed a message so cold-reopen triggers the rehydration path. + _, err = mgr.AppendMessage(ctx, sess.ID, chat.RoleUser, `{"text":"hi"}`) + require.NoError(t, err) + + // End the session so next OpenSession is cold with a non-empty transcript. + _, err = mgr.OpenSession(ctx, sess.ID) + require.NoError(t, err) + require.NoError(t, mgr.EndSession(ctx, sess.ID)) + + // Arm the one-shot fault: next SetRehydrationActive call will fail. 
+ fstore.FailNextSetRehydration() + + _, err = mgr.OpenSession(ctx, sess.ID) + require.Error(t, err, "OpenSession must fail when the rehydration flag cannot be persisted") + + // The container that was started must have been rolled back. + require.Equal(t, int64(2), runner.endCalls.Load(), + "EndChat must be called once for the explicit EndSession, once for the rollback") + + // Session must be back to cold so the next open is a clean retry. + got, err := inner.GetSession(ctx, sess.ID) + require.NoError(t, err) + require.Equal(t, chat.StatusCold, got.Status, "failed open must leave session cold") + assert.Empty(t, got.ContainerID, "container ID must be cleared on rollback") + assert.False(t, got.RehydrationActive, "rehydration_active must not be set after failed open") +} + +// TestSetRehydrationActive_StoreAndCacheStayInSync drives many concurrent +// flips through setRehydrationActive and asserts the on-disk value equals +// the in-memory cache value once the dust settles. When the store write +// happens outside m.mu, two callers writing opposite booleans can land in +// opposite orders on disk vs cache, leaving the cache permanently desynced. +// Holding m.mu across both writes forces a single serialization point so +// whichever value commits to disk last is also the cache value on return. +// +// The store is wrapped in yieldingStore which sleeps a jittered amount +// after every SetRehydrationActive call. SQLite's UPDATE is heavyweight +// relative to the trivial cache write that follows, so without an +// explicit, variable post-store delay the cache writes drain in lockstep +// with the store commits and the race window collapses. The jitter +// scatters cache writes out of store-commit order — the schedule that +// exposes the regression. Multiple flips per goroutine compound the +// variance; iterating the outer batch a few times makes a single CI run +// likely to catch the bug. 
+func TestSetRehydrationActive_StoreAndCacheStayInSync(t *testing.T) { + inner, err := sqlite.Open(filepath.Join(t.TempDir(), "chats.db")) + require.NoError(t, err) + t.Cleanup(func() { _ = inner.Close() }) + + store := &yieldingStore{Store: inner} + runner := &stubRunner{} + mgr := chat.NewManager(chat.Config{ + Store: store, + Runner: runner, + Clock: clock.Real(), + IdleTTL: time.Hour, + }) + + sess, err := mgr.CreateSession(context.Background(), chat.CreateInput{ + Title: "ordering", + Project: "alpha", + CreatedBy: "human:test", + }) + require.NoError(t, err) + + // Several batches of 100 concurrent flips. After each batch the + // cache value must equal the persisted store value — they are + // written under the same lock, so no schedule should split them. + // Each batch is an independent observation; running enough of them + // makes a single -count=10 CI run very likely to surface a + // regression. + // + // flipErr captures any error from inside the goroutines. testifylint + // (go-require) bans require.* in goroutines because it Goexits the + // caller, not the test — flip errors are funneled out here instead. 
+ var flipErr atomic.Pointer[error] + + for batch := range 5 { + var wg sync.WaitGroup + + for i := 0; i < 100; i++ { + wg.Add(1) + + active := i%2 == 0 + + go func() { + defer wg.Done() + + if err := mgr.SetRehydrationActiveForTest(context.Background(), sess.ID, active); err != nil { + flipErr.Store(&err) + } + }() + } + + wg.Wait() + + if err := flipErr.Load(); err != nil { + require.NoError(t, *err, "setRehydrationActive flip failed inside goroutine") + } + + stored, err := store.GetSession(context.Background(), sess.ID) + require.NoError(t, err) + + cached, ok := mgr.RehydrationActiveCacheForTest(sess.ID) + require.True(t, ok, "cache must be populated after setRehydrationActive calls") + require.Equalf(t, stored.RehydrationActive, cached, + "batch %d: cache value %v diverged from stored value %v", + batch, cached, stored.RehydrationActive) + } +} + +func TestManager_HandleUsageEntry_UpdatesContextTokens(t *testing.T) { + hub := chat.NewSSEHub(64) + + store, err := sqlite.Open(filepath.Join(t.TempDir(), "chats.db")) + require.NoError(t, err) + t.Cleanup(func() { _ = store.Close() }) + + runner := &usageStreamingRunner{ + entries: []chat.LogEntry{ + { + Type: "usage", + Usage: &chat.TokenUsage{ + InputTokens: 1000, + OutputTokens: 500, + CacheReadTokens: 4000, + CacheCreateTokens: 200, + }, + Model: "claude-sonnet-4-6", + }, + }, + } + + mgr := chat.NewManager(chat.Config{ + Store: store, + Runner: runner, + Clock: clock.Real(), + IdleTTL: time.Hour, + Hub: hub, + DefaultModel: "claude-sonnet-4-6", + }) + + ctx := context.Background() + sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "t", CreatedBy: "x"}) + require.NoError(t, err) + + // Subscribe BEFORE opening so we observe the session_updated event. + events, _, _ := hub.Subscribe(sess.ID, 0) + + _, err = mgr.OpenSession(ctx, sess.ID) + require.NoError(t, err) + + // Wait for the usage event to propagate through the consumer. 
+ var got chat.SSEEvent + select { + case got = <-events: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for session_updated event") + } + + require.Equal(t, chat.SSEKindSessionUpdate, got.Kind, "first event must be the session_updated push") + require.NotNil(t, got.SessionUpdate) + // 1000 + 4000 + 200 = 5200 (output tokens NOT included in context). + assert.Equal(t, int64(5200), got.SessionUpdate.ContextTokens, + "context_tokens = input + cache_read + cache_create") + + // Wait briefly for the DB write (handleUsageEntry persists then publishes). + deadline := time.Now().Add(time.Second) + for time.Now().Before(deadline) { + s, err := store.GetSession(ctx, sess.ID) + require.NoError(t, err) + + if s.ContextTokens == 5200 { + return + } + + time.Sleep(10 * time.Millisecond) + } + + t.Fatalf("session.context_tokens never reached 5200") +} + +// usageStreamingRunner is a stub RunnerClient that delivers a canned list +// of LogEntry values through StreamLogs (in order, with a small delay so +// the consumer reliably observes them). +type usageStreamingRunner struct { + entries []chat.LogEntry +} + +func (r *usageStreamingRunner) StartChat(_ context.Context, _ chat.StartChatOpts) (string, error) { + return "container-usage", nil +} + +func (r *usageStreamingRunner) EndChat(_ context.Context, _ string) error { return nil } + +func (r *usageStreamingRunner) SendChatMessage(_ context.Context, _, _, _ string) error { + return nil +} + +func (r *usageStreamingRunner) StreamLogs(ctx context.Context, _ string, onEntry func(chat.LogEntry)) error { + for _, e := range r.entries { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + onEntry(e) + } + + <-ctx.Done() + + return ctx.Err() +} + +// newManagerWithStubsAndConfig is like newManagerWithStubs but lets the +// caller override the manager Config fields (DefaultModel, IdleTTL, etc.) +// without duplicating the store + stubRunner wiring boilerplate. 
func newManagerWithStubsAndConfig(t *testing.T, base chat.Config) (*chat.Manager, *stubRunner, chat.Store) {
	t.Helper()

	store, err := sqlite.Open(filepath.Join(t.TempDir(), "chats.db"))
	require.NoError(t, err)
	t.Cleanup(func() { _ = store.Close() })

	runner := &stubRunner{}

	// Store and Runner are always the test doubles; the caller's other
	// Config fields pass through untouched.
	base.Store = store
	base.Runner = runner

	if base.Clock == nil {
		base.Clock = clock.Real()
	}

	mgr := chat.NewManager(base)

	return mgr, runner, store
}

// TestManager_BuildResume_UsesTailOnLongSession is a regression for
// buildResume loading the oldest 600 messages instead of the newest.
// Sessions past ~600 messages would lose recent context — the "pin last 20
// turns" guarantee in transcript.Build operated on a stale prefix.
func TestManager_BuildResume_UsesTailOnLongSession(t *testing.T) {
	t.Parallel()

	mgr, _, store := newManagerWithStubs(t)
	ctx := context.Background()

	sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "long", Project: "p", CreatedBy: "human:test"})
	require.NoError(t, err)

	// Seed 650 messages directly via the store (bypasses Manager seq tracking,
	// which is intentional — we are testing the read path, not the write path).
	// maxMessagesForBuild is 600, so messages 1..50 must be excluded when
	// using the old ListMessages(0, 600) call but present when using the tail.
	const total = 650

	for i := 1; i <= total; i++ {
		_, err := store.AppendMessage(ctx, chat.Message{
			SessionID: sess.ID,
			Seq:       int64(i),
			Role:      chat.RoleAssistantText,
			Content:   fmt.Sprintf(`{"text":"msg-%d"}`, i),
			CreatedAt: time.Now().UTC().Truncate(time.Second),
		})
		require.NoError(t, err)
	}

	rc := mgr.BuildResumeForTest(ctx, sess.ID)
	require.NotNil(t, rc)
	require.NotEmpty(t, rc.Turns)

	// The most recent message must be in the resume payload.
	last := rc.Turns[len(rc.Turns)-1]
	require.Contains(t, last.Content, `msg-650`, "tail must include the newest message")
}

// TestManager_CompleteRehydration_UnknownSession verifies the error path:
// an ID with no session row surfaces chat.ErrSessionNotFound.
func TestManager_CompleteRehydration_UnknownSession(t *testing.T) {
	t.Parallel()
	mgr, _, _ := newManagerWithStubs(t)
	err := mgr.CompleteRehydration(context.Background(), "01DOES_NOT_EXIST", "summary text")
	require.Error(t, err)
	require.ErrorIs(t, err, chat.ErrSessionNotFound)
}

// TestManager_OpenSession_WorkspaceDedupesOnReopen verifies repeated
// open/end cycles do not duplicate the project in session.Workspace.
func TestManager_OpenSession_WorkspaceDedupesOnReopen(t *testing.T) {
	t.Parallel()
	mgr, _, _ := newManagerWithStubs(t)
	ctx := context.Background()

	sess, err := mgr.CreateSession(ctx, chat.CreateInput{Project: "proj", CreatedBy: "human:t"})
	require.NoError(t, err)

	for i := 0; i < 5; i++ {
		_, err := mgr.OpenSession(ctx, sess.ID)
		require.NoError(t, err)
		err = mgr.EndSession(ctx, sess.ID)
		require.NoError(t, err)
	}

	got, err := mgr.GetSession(ctx, sess.ID)
	require.NoError(t, err)
	require.Equal(t, []string{"proj"}, got.Workspace, "project must appear once regardless of reopen count")
}

// TestManager_Close_StopsAllConsumers verifies Manager.Close tears down
// every running runner-log stream (stubRunner.activeStreams returns to 0).
func TestManager_Close_StopsAllConsumers(t *testing.T) {
	t.Parallel()
	mgr, runner, _ := newManagerWithStubs(t)
	ctx := context.Background()

	sess, err := mgr.CreateSession(ctx, chat.CreateInput{Project: "p", CreatedBy: "human:t"})
	require.NoError(t, err)

	_, err = mgr.OpenSession(ctx, sess.ID)
	require.NoError(t, err)

	// Wait for StreamLogs goroutine to start and increment activeStreams.
	require.Eventually(t, func() bool {
		return runner.activeStreams.Load() == 1
	}, time.Second, time.Millisecond, "StreamLogs goroutine must start before Close")

	require.NoError(t, mgr.Close(context.Background()))
	require.Equal(t, int32(0), runner.activeStreams.Load(), "Close must stop all log streams")
}

// countingRunner is a fake RunnerClient whose StartChat behaviour is fully
// controlled by the test via the startChat func field. Used to gate cold-open
// progress on a per-test signal so we can assert that two distinct sessions
// reach the runner concurrently.
type countingRunner struct {
	startChat func(ctx context.Context, opts chat.StartChatOpts) (string, error)
}

// StartChat delegates to the test-supplied hook when set; otherwise it
// returns a deterministic per-session container ID.
func (r *countingRunner) StartChat(ctx context.Context, opts chat.StartChatOpts) (string, error) {
	if r.startChat != nil {
		return r.startChat(ctx, opts)
	}

	return "container-" + opts.SessionID, nil
}

func (r *countingRunner) EndChat(_ context.Context, _ string) error { return nil }

func (r *countingRunner) SendChatMessage(_ context.Context, _, _, _ string) error { return nil }

// StreamLogs emits nothing and simply parks until cancellation.
func (r *countingRunner) StreamLogs(ctx context.Context, _ string, _ func(chat.LogEntry)) error {
	<-ctx.Done()

	return ctx.Err()
}

// newTestManagerWithRunner constructs a chat.Manager wired to the supplied
// RunnerClient and a fresh sqlite store, with MaxConcurrent explicitly set
// to 0 (unlimited) so the limit-bounded serialisation path does not gate
// the cold-open singleflight test.
func newTestManagerWithRunner(t *testing.T, runner chat.RunnerClient) (*chat.Manager, chat.Store, func()) {
	t.Helper()

	store, err := sqlite.Open(filepath.Join(t.TempDir(), "chats.db"))
	require.NoError(t, err)

	mgr := chat.NewManager(chat.Config{
		Store:         store,
		Runner:        runner,
		Clock:         clock.Real(),
		IdleTTL:       time.Hour,
		MaxConcurrent: 0,
	})

	// Returned instead of t.Cleanup so callers control teardown ordering.
	cleanup := func() {
		_ = mgr.Close(context.Background())
		_ = store.Close()
	}

	return mgr, store, cleanup
}

// newTestManagerWithStore constructs a chat.Manager wired to the supplied
// Store (typically a wrapper around the real sqlite store that injects faults
// or instruments calls) and a stubRunner. The wrapped store is responsible
// for embedding the real chat.Store; this helper just wires it in. Used by
// tests that need to gate AppendMessage independently from the runner path.
func newTestManagerWithStore(t *testing.T, store chat.Store) (*chat.Manager, *stubRunner, func()) {
	t.Helper()

	runner := &stubRunner{}
	mgr := chat.NewManager(chat.Config{
		Store:         store,
		Runner:        runner,
		Clock:         clock.Real(),
		IdleTTL:       time.Hour,
		MaxConcurrent: 0,
	})

	// Returned instead of t.Cleanup so callers control teardown ordering.
	cleanup := func() {
		_ = mgr.Close(context.Background())
		_ = store.Close()
	}

	return mgr, runner, cleanup
}

// TestAppendMessage_UnrelatedSessionsDoNotSerialize asserts that two appends
// to two different sessions execute in parallel. The gatingStore parks the
// underlying store write on a per-session channel; the test verifies that
// both calls reach the parked point before either returns. Regression for
// the global m.mu lock in AppendMessage, which used to couple unrelated
// sessions through the seq-assign window.
func TestAppendMessage_UnrelatedSessionsDoNotSerialize(t *testing.T) {
	t.Parallel()

	innerStore, err := sqlite.Open(filepath.Join(t.TempDir(), "chats.db"))
	require.NoError(t, err)

	gate := newSessionGate()
	gating := &gatingStore{Store: innerStore, gate: gate}

	mgr, _, cleanup := newTestManagerWithStore(t, gating)
	defer cleanup()

	sess1, err := mgr.CreateSession(context.Background(), chat.CreateInput{Title: "a", CreatedBy: "human:t"})
	require.NoError(t, err)
	sess2, err := mgr.CreateSession(context.Background(), chat.CreateInput{Title: "b", CreatedBy: "human:t"})
	require.NoError(t, err)

	// Arm both gates before the goroutines start so neither append can
	// slip through before the concurrency check below.
	gate.block(sess1.ID)
	gate.block(sess2.ID)

	var wg sync.WaitGroup

	wg.Add(2)

	// Append errors are deliberately ignored — the assertion is about
	// reaching the store concurrently, not about the write results.
	go func() {
		defer wg.Done()

		_, _ = mgr.AppendMessage(context.Background(), sess1.ID, chat.RoleUser, "x")
	}()

	go func() {
		defer wg.Done()

		_, _ = mgr.AppendMessage(context.Background(), sess2.ID, chat.RoleUser, "y")
	}()

	require.Eventually(t, func() bool { return gate.waiting(sess1.ID) && gate.waiting(sess2.ID) },
		time.Second, 5*time.Millisecond,
		"both AppendMessage calls must reach the store write concurrently")

	gate.release(sess1.ID)
	gate.release(sess2.ID)
	wg.Wait()
}

// TestOpenSession_ConcurrentColdOpensRunInParallel asserts that two cold
// opens for distinct session IDs route through their own singleflight slot
// and reach the runner concurrently. Before the singleflight refactor, a
// global openMu serialised the cold-start path so the second call observed
// the first's full StartChat latency; one slow docker pull stalled every
// other cold open. With singleflight keyed on sessionID, two distinct
// sessions complete within ~one StartChat duration.
func TestOpenSession_ConcurrentColdOpensRunInParallel(t *testing.T) {
	release := make(chan struct{})

	var calls atomic.Int64

	// StartChat counts the call, then parks until the test releases it —
	// both calls can only be in flight at once if opens are parallel.
	runner := &countingRunner{
		startChat: func(_ context.Context, opts chat.StartChatOpts) (string, error) {
			calls.Add(1)
			<-release

			return "container-" + opts.SessionID, nil
		},
	}

	mgr, _, cleanup := newTestManagerWithRunner(t, runner)
	defer cleanup()

	sess1, err := mgr.CreateSession(context.Background(), chat.CreateInput{Title: "a", Project: "alpha", CreatedBy: "human:t"})
	require.NoError(t, err)
	sess2, err := mgr.CreateSession(context.Background(), chat.CreateInput{Title: "b", Project: "alpha", CreatedBy: "human:t"})
	require.NoError(t, err)

	var wg sync.WaitGroup

	wg.Add(2)

	go func() { defer wg.Done(); _, _ = mgr.OpenSession(context.Background(), sess1.ID) }()
	go func() { defer wg.Done(); _, _ = mgr.OpenSession(context.Background(), sess2.ID) }()

	require.Eventually(t, func() bool { return calls.Load() == 2 }, time.Second, 5*time.Millisecond,
		"both runner.StartChat calls must be in flight concurrently")

	close(release)
	wg.Wait()
}
diff --git a/internal/chat/reaper.go b/internal/chat/reaper.go
new file mode 100644
index 00000000..fc1d48f9
--- /dev/null
+++ b/internal/chat/reaper.go
@@ -0,0 +1,101 @@
package chat

import (
	"context"
	"sync"
	"time"
)

// IdleReaper periodically scans warm-idle sessions and ends those whose
// last_active is older than the Manager's configured IdleTTL.
type IdleReaper struct {
	mgr      *Manager
	interval time.Duration
	stopCh   chan struct{} // closed by Stop to end Run
	stopOnce sync.Once     // makes repeated Stop calls safe
}

// NewIdleReaper wires a reaper. interval should be << IdleTTL in production;
// tests pass a tiny interval for quick triggering.
func NewIdleReaper(mgr *Manager, interval time.Duration) *IdleReaper {
	return &IdleReaper{mgr: mgr, interval: interval, stopCh: make(chan struct{})}
}

// Run blocks until ctx is cancelled or Stop is called.
func (r *IdleReaper) Run(ctx context.Context) {
	ticker := time.NewTicker(r.interval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-r.stopCh:
			return
		case <-ticker.C:
			r.tick(ctx)
		}
	}
}

// Stop ends Run. Safe to call any number of times — repeated calls are no-ops.
func (r *IdleReaper) Stop() { r.stopOnce.Do(func() { close(r.stopCh) }) }

// tick runs one reaper pass: the warm-idle sweep, then the stale-rehydration sweep.
func (r *IdleReaper) tick(ctx context.Context) {
	r.sweepWarmIdle(ctx)
	r.sweepStaleRehydration(ctx)
}

// sweepWarmIdle ends sessions whose warm-idle TTL has elapsed.
// List/end failures are logged and skipped; the next tick retries.
func (r *IdleReaper) sweepWarmIdle(ctx context.Context) {
	sessions, err := r.mgr.store.ListSessions(ctx, SessionFilter{Status: StatusWarmIdle})
	if err != nil {
		r.mgr.logger.Warn("chat: reaper list failed", "error", err)

		return
	}

	cutoff := r.mgr.clk.Now().Add(-r.mgr.idleTTL)

	for _, sess := range sessions {
		if sess.LastActive.Before(cutoff) {
			if err := r.mgr.EndSession(ctx, sess.ID); err != nil {
				r.mgr.logger.Warn("chat: reaper end failed", "session_id", sess.ID, "error", err)
			}
		}
	}
}

// sweepStaleRehydration forces rehydration_active off for sessions whose
// phase has been open longer than the configured timeout. Safety net for
// agents that crashed or otherwise never reached chat_rehydration_complete
// and where the operator hasn't yet typed.
func (r *IdleReaper) sweepStaleRehydration(ctx context.Context) {
	// Timeout <= 0 disables the sweep entirely.
	if r.mgr.rehydrationTimeout <= 0 {
		return
	}

	active := true

	sessions, err := r.mgr.store.ListSessions(ctx, SessionFilter{
		RehydrationActive: &active,
		LastActiveBefore:  r.mgr.clk.Now().Add(-r.mgr.rehydrationTimeout),
	})
	if err != nil {
		r.mgr.logger.Warn("chat: reaper rehydration list failed", "error", err)

		return
	}

	for _, sess := range sessions {
		if err := r.mgr.setRehydrationActive(ctx, sess.ID, false); err != nil {
			r.mgr.logger.Warn("chat: reaper rehydration clear failed",
				"session_id", sess.ID, "error", err)

			continue
		}

		r.mgr.logger.Info("chat: rehydration phase forced off by timeout",
			"session_id", sess.ID)
	}
}
diff --git a/internal/chat/reaper_export_test.go b/internal/chat/reaper_export_test.go
new file mode 100644
index 00000000..50d580dd
--- /dev/null
+++ b/internal/chat/reaper_export_test.go
@@ -0,0 +1,8 @@
package chat

import "context"

// SweepStaleRehydrationForTest exports the private sweepStaleRehydration for testing.
+func (r *IdleReaper) SweepStaleRehydrationForTest(ctx context.Context) { + r.sweepStaleRehydration(ctx) +} diff --git a/internal/chat/reaper_test.go b/internal/chat/reaper_test.go new file mode 100644 index 00000000..1cb9eab8 --- /dev/null +++ b/internal/chat/reaper_test.go @@ -0,0 +1,228 @@ +package chat_test + +import ( + "context" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/mhersson/contextmatrix/internal/chat" + "github.com/mhersson/contextmatrix/internal/chat/sqlite" + "github.com/mhersson/contextmatrix/internal/clock" +) + +func TestIdleReaper_EndsWarmIdlePastTTL(t *testing.T) { + store, err := sqlite.Open(filepath.Join(t.TempDir(), "chats.db")) + require.NoError(t, err) + t.Cleanup(func() { _ = store.Close() }) + + fakeClock := clock.Fake(time.Date(2026, 1, 1, 12, 0, 0, 0, time.UTC)) + runner := &stubRunner{} + mgr := chat.NewManager(chat.Config{ + Store: store, + Runner: runner, + Clock: fakeClock, + IdleTTL: 30 * time.Minute, + }) + + ctx := context.Background() + sess, err := mgr.CreateSession(ctx, chat.CreateInput{Title: "t", CreatedBy: "x"}) + require.NoError(t, err) + // Set to warm-idle with last_active far in the past + sess.Status = chat.StatusWarmIdle + sess.LastActive = fakeClock.Now().Add(-2 * time.Hour) + require.NoError(t, store.UpdateSession(ctx, sess)) + + reaper := chat.NewIdleReaper(mgr, 1*time.Millisecond) + go reaper.Run(ctx) + + t.Cleanup(reaper.Stop) + + require.Eventually(t, func() bool { + got, err := store.GetSession(ctx, sess.ID) + if err != nil { + return false + } + + return got.Status == chat.StatusCold + }, 2*time.Second, 5*time.Millisecond, "reaper did not transition session to cold") + + assert.Equal(t, int64(1), runner.endCalls.Load()) +} + +// TestIdleReaper_Stop_DoubleCallSafe verifies that calling Stop twice does +// not panic. 
The reaper is plumbed through main.go's lifecycle and shutdown +// hooks can fire it more than once during graceful shutdown / signal-driven +// teardown. +func TestIdleReaper_Stop_DoubleCallSafe(t *testing.T) { + store, err := sqlite.Open(filepath.Join(t.TempDir(), "chats.db")) + require.NoError(t, err) + t.Cleanup(func() { _ = store.Close() }) + + mgr := chat.NewManager(chat.Config{ + Store: store, + Runner: &stubRunner{}, + Clock: clock.Real(), + IdleTTL: time.Hour, + }) + + reaper := chat.NewIdleReaper(mgr, time.Hour) + + // First Stop closes the channel; second Stop must be a no-op, not a panic. + assert.NotPanics(t, reaper.Stop) + assert.NotPanics(t, reaper.Stop) + assert.NotPanics(t, reaper.Stop) +} + +func TestIdleReaper_SweepStaleRehydration_FlipsTimeoutSessions(t *testing.T) { + t.Parallel() + store, err := sqlite.Open(filepath.Join(t.TempDir(), "chats.db")) + require.NoError(t, err) + t.Cleanup(func() { _ = store.Close() }) + + fakeClock := clock.Fake(time.Date(2026, 1, 1, 12, 0, 0, 0, time.UTC)) + mgr := chat.NewManager(chat.Config{ + Store: store, + Runner: &stubRunner{}, + Clock: fakeClock, + IdleTTL: 1 * time.Hour, + RehydrationTimeout: 10 * time.Minute, + }) + + ctx := context.Background() + sess, err := mgr.CreateSession(ctx, chat.CreateInput{Project: "p", CreatedBy: "human:t"}) + require.NoError(t, err) + + // Activate rehydration and push last_active 15 min into the past. 
+ require.NoError(t, mgr.SetRehydrationActiveForTest(ctx, sess.ID, true)) + sess.LastActive = fakeClock.Now().Add(-15 * time.Minute) + require.NoError(t, store.UpdateSession(ctx, sess)) + + reaper := chat.NewIdleReaper(mgr, 1*time.Millisecond) + reaper.SweepStaleRehydrationForTest(ctx) + + got, err := mgr.GetSession(ctx, sess.ID) + require.NoError(t, err) + assert.False(t, got.RehydrationActive, "stale rehydration flag should be flipped off") +} + +func TestIdleReaper_SweepStaleRehydration_LeavesRecentAlone(t *testing.T) { + t.Parallel() + store, err := sqlite.Open(filepath.Join(t.TempDir(), "chats.db")) + require.NoError(t, err) + t.Cleanup(func() { _ = store.Close() }) + + fakeClock := clock.Fake(time.Date(2026, 1, 1, 12, 0, 0, 0, time.UTC)) + mgr := chat.NewManager(chat.Config{ + Store: store, + Runner: &stubRunner{}, + Clock: fakeClock, + IdleTTL: 1 * time.Hour, + RehydrationTimeout: 10 * time.Minute, + }) + + ctx := context.Background() + sess, err := mgr.CreateSession(ctx, chat.CreateInput{Project: "p", CreatedBy: "human:t"}) + require.NoError(t, err) + + // Activate rehydration and push last_active 2 min into the past (within timeout). + require.NoError(t, mgr.SetRehydrationActiveForTest(ctx, sess.ID, true)) + sess.LastActive = fakeClock.Now().Add(-2 * time.Minute) + require.NoError(t, store.UpdateSession(ctx, sess)) + + reaper := chat.NewIdleReaper(mgr, 1*time.Millisecond) + reaper.SweepStaleRehydrationForTest(ctx) + + got, err := mgr.GetSession(ctx, sess.ID) + require.NoError(t, err) + assert.True(t, got.RehydrationActive, "recent rehydration must survive the sweep") +} + +func TestIdleReaper_SweepStaleRehydration_SkipsIfDisabled(t *testing.T) { + t.Parallel() + store, err := sqlite.Open(filepath.Join(t.TempDir(), "chats.db")) + require.NoError(t, err) + t.Cleanup(func() { _ = store.Close() }) + + fakeClock := clock.Fake(time.Date(2026, 1, 1, 12, 0, 0, 0, time.UTC)) + // RehydrationTimeout: 0 disables the sweep. 
+ mgr := chat.NewManager(chat.Config{ + Store: store, + Runner: &stubRunner{}, + Clock: fakeClock, + IdleTTL: 1 * time.Hour, + RehydrationTimeout: 0, + }) + + ctx := context.Background() + sess, err := mgr.CreateSession(ctx, chat.CreateInput{Project: "p", CreatedBy: "human:t"}) + require.NoError(t, err) + + // Activate rehydration and push far into the past. + require.NoError(t, mgr.SetRehydrationActiveForTest(ctx, sess.ID, true)) + sess.LastActive = fakeClock.Now().Add(-1 * time.Hour) + require.NoError(t, store.UpdateSession(ctx, sess)) + + reaper := chat.NewIdleReaper(mgr, 1*time.Millisecond) + reaper.SweepStaleRehydrationForTest(ctx) + + got, err := mgr.GetSession(ctx, sess.ID) + require.NoError(t, err) + assert.True(t, got.RehydrationActive, "sweep must be skipped when RehydrationTimeout is 0") +} + +func TestIdleReaper_SweepStaleRehydration_MultipleStale(t *testing.T) { + t.Parallel() + store, err := sqlite.Open(filepath.Join(t.TempDir(), "chats.db")) + require.NoError(t, err) + t.Cleanup(func() { _ = store.Close() }) + + fakeClock := clock.Fake(time.Date(2026, 1, 1, 12, 0, 0, 0, time.UTC)) + mgr := chat.NewManager(chat.Config{ + Store: store, + Runner: &stubRunner{}, + Clock: fakeClock, + IdleTTL: 1 * time.Hour, + RehydrationTimeout: 5 * time.Minute, + }) + + ctx := context.Background() + + // Create three sessions: two stale, one fresh. 
+ sess1, err := mgr.CreateSession(ctx, chat.CreateInput{Project: "p", CreatedBy: "human:t"}) + require.NoError(t, err) + require.NoError(t, mgr.SetRehydrationActiveForTest(ctx, sess1.ID, true)) + sess1.LastActive = fakeClock.Now().Add(-10 * time.Minute) + require.NoError(t, store.UpdateSession(ctx, sess1)) + + sess2, err := mgr.CreateSession(ctx, chat.CreateInput{Project: "p", CreatedBy: "human:t"}) + require.NoError(t, err) + require.NoError(t, mgr.SetRehydrationActiveForTest(ctx, sess2.ID, true)) + sess2.LastActive = fakeClock.Now().Add(-8 * time.Minute) + require.NoError(t, store.UpdateSession(ctx, sess2)) + + sess3, err := mgr.CreateSession(ctx, chat.CreateInput{Project: "p", CreatedBy: "human:t"}) + require.NoError(t, err) + require.NoError(t, mgr.SetRehydrationActiveForTest(ctx, sess3.ID, true)) + sess3.LastActive = fakeClock.Now().Add(-1 * time.Minute) + require.NoError(t, store.UpdateSession(ctx, sess3)) + + reaper := chat.NewIdleReaper(mgr, 1*time.Millisecond) + reaper.SweepStaleRehydrationForTest(ctx) + + // Check results. + got1, err := mgr.GetSession(ctx, sess1.ID) + require.NoError(t, err) + assert.False(t, got1.RehydrationActive, "sess1 (-10m) should have rehydration flipped off") + + got2, err := mgr.GetSession(ctx, sess2.ID) + require.NoError(t, err) + assert.False(t, got2.RehydrationActive, "sess2 (-8m) should have rehydration flipped off") + + got3, err := mgr.GetSession(ctx, sess3.ID) + require.NoError(t, err) + assert.True(t, got3.RehydrationActive, "sess3 (-1m) should survive") +} diff --git a/internal/chat/runner.go b/internal/chat/runner.go new file mode 100644 index 00000000..cac7049a --- /dev/null +++ b/internal/chat/runner.go @@ -0,0 +1,247 @@ +package chat + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/mhersson/contextmatrix/internal/runner" +) + +// RunnerClientConfig wires the HMAC-signed webhook client. 
+type RunnerClientConfig struct { + BaseURL string // e.g. http://contextmatrix-runner:8080 + HMACKey string // pre-shared HMAC secret + MCPAPIKey string // forwarded to chat containers as CM_MCP_API_KEY + HTTPClient *http.Client // optional; defaults to a 30s-timeout client +} + +// runnerClient implements RunnerClient by talking HMAC-signed HTTP to the +// runner's /chat/* and /message endpoints. +type runnerClient struct { + baseURL string + key string + mcpAPIKey string + httpc *http.Client +} + +// NewRunnerClient constructs a RunnerClient. If cfg.HTTPClient is nil, a +// 30-second-timeout default client is used. +func NewRunnerClient(cfg RunnerClientConfig) RunnerClient { + c := cfg.HTTPClient + if c == nil { + c = &http.Client{Timeout: 30 * time.Second} + } + + return &runnerClient{baseURL: cfg.BaseURL, key: cfg.HMACKey, mcpAPIKey: cfg.MCPAPIKey, httpc: c} +} + +type chatStartPayload struct { + SessionID string `json:"session_id"` + Project string `json:"project,omitempty"` + RepoURL string `json:"repo_url,omitempty"` + MCPAPIKey string `json:"mcp_api_key,omitempty"` + Model string `json:"model,omitempty"` + Resume *ResumeContext `json:"resume,omitempty"` +} + +type chatEndPayload struct { + SessionID string `json:"session_id"` +} + +type messagePayload struct { + SessionID string `json:"session_id"` + Content string `json:"content"` + MessageID string `json:"message_id,omitempty"` +} + +func (c *runnerClient) StartChat(ctx context.Context, opts StartChatOpts) (string, error) { + body, err := json.Marshal(chatStartPayload{ + SessionID: opts.SessionID, + Project: opts.Project, + RepoURL: opts.RepoURL, + MCPAPIKey: c.mcpAPIKey, + Model: opts.Model, + Resume: opts.Resume, + }) + if err != nil { + return "", fmt.Errorf("chat: runner: marshal StartChat payload: %w", err) + } + + resp, err := c.post(ctx, "/chat/start", body) + if err != nil { + return "", err + } + + var out struct { + ContainerID string `json:"container_id"` + } + if err := json.Unmarshal(resp, 
&out); err != nil { + return "", fmt.Errorf("chat: decode StartChat resp: %w", err) + } + + return out.ContainerID, nil +} + +func (c *runnerClient) EndChat(ctx context.Context, sessionID string) error { + body, err := json.Marshal(chatEndPayload{SessionID: sessionID}) + if err != nil { + return fmt.Errorf("chat: runner: marshal EndChat payload: %w", err) + } + + _, err = c.post(ctx, "/chat/end", body) + + return err +} + +func (c *runnerClient) SendChatMessage(ctx context.Context, sessionID, content, messageID string) error { + body, err := json.Marshal(messagePayload{SessionID: sessionID, Content: content, MessageID: messageID}) + if err != nil { + return fmt.Errorf("chat: runner: marshal SendChatMessage payload: %w", err) + } + + _, err = c.post(ctx, "/message", body) + + return err +} + +// runnerLogEntry mirrors the runner's logbroadcast.LogEntry JSON shape. +type runnerLogEntry struct { + Timestamp time.Time `json:"ts"` + SessionID string `json:"session_id,omitempty"` + Type string `json:"type"` + Content string `json:"content,omitempty"` + Usage *runnerLogUsage `json:"usage,omitempty"` + Model string `json:"model,omitempty"` +} + +// runnerLogUsage mirrors the runner's logbroadcast.TokenUsage JSON shape. +type runnerLogUsage struct { + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` + CacheReadTokens int64 `json:"cache_read_tokens"` + CacheCreateTokens int64 `json:"cache_creation_tokens"` +} + +// StreamLogs subscribes to the runner's HMAC-signed /logs?session_id= +// SSE endpoint and dispatches each parsed entry to onEntry. The HTTP client +// is constructed without a timeout for this call because the SSE connection +// is long-lived; cancellation is via ctx. 
+func (c *runnerClient) StreamLogs(ctx context.Context, sessionID string, onEntry func(LogEntry)) error { + fullURL := c.baseURL + "/logs?session_id=" + url.QueryEscape(sessionID) + + parsed, err := url.Parse(fullURL) + if err != nil { + return fmt.Errorf("chat: parse logs URL: %w", err) + } + + uri := parsed.RequestURI() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil) + if err != nil { + return fmt.Errorf("chat: build logs request: %w", err) + } + + sig, ts := runner.SignRequestHeaders(c.key, http.MethodGet, uri, nil) + + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("X-Signature-256", sig) + req.Header.Set("X-Webhook-Timestamp", ts) + + // Use a no-timeout client for the SSE stream; ctx drives cancellation. + streamClient := &http.Client{} + + resp, err := streamClient.Do(req) + if err != nil { + return fmt.Errorf("chat: /logs request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode >= 300 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 512)) + + return fmt.Errorf("chat: /logs: status %d: %s", resp.StatusCode, string(respBody)) + } + + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 0, 64*1024), 1<<20) + + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data: ") { + continue + } + + var entry runnerLogEntry + if err := json.Unmarshal([]byte(strings.TrimPrefix(line, "data: ")), &entry); err != nil { + continue + } + + out := LogEntry{ + Timestamp: entry.Timestamp, + Type: entry.Type, + Content: entry.Content, + Model: entry.Model, + } + + if entry.Usage != nil { + out.Usage = &TokenUsage{ + InputTokens: entry.Usage.InputTokens, + OutputTokens: entry.Usage.OutputTokens, + CacheReadTokens: entry.Usage.CacheReadTokens, + CacheCreateTokens: entry.Usage.CacheCreateTokens, + } + } + + onEntry(out) + } + + if err := scanner.Err(); err != nil { + return fmt.Errorf("chat: /logs scan: %w", err) + } + + return nil +} + +// post sends an HMAC-signed POST and 
returns the body on 2xx. +func (c *runnerClient) post(ctx context.Context, path string, body []byte) ([]byte, error) { + fullURL := c.baseURL + path + + parsed, err := url.Parse(fullURL) + if err != nil { + return nil, fmt.Errorf("chat: parse URL: %w", err) + } + + uri := parsed.RequestURI() // path + "?" + raw query (or just path) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("chat: build request: %w", err) + } + + sig, ts := runner.SignRequestHeaders(c.key, http.MethodPost, uri, body) + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Signature-256", sig) + req.Header.Set("X-Webhook-Timestamp", ts) + + resp, err := c.httpc.Do(req) + if err != nil { + return nil, fmt.Errorf("chat: %s: %w", path, err) + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + if resp.StatusCode >= 300 { + return nil, fmt.Errorf("chat: %s: status %d: %s", path, resp.StatusCode, string(respBody)) + } + + return respBody, nil +} diff --git a/internal/chat/runner_test.go b/internal/chat/runner_test.go new file mode 100644 index 00000000..eb164f6a --- /dev/null +++ b/internal/chat/runner_test.go @@ -0,0 +1,102 @@ +package chat_test + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/mhersson/contextmatrix/internal/chat" +) + +func TestRunnerClient_StartChat_HappyPath(t *testing.T) { + var received struct { + path string + body map[string]any + sig string + ts string + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + received.path = r.URL.Path + received.sig = r.Header.Get("X-Signature-256") + received.ts = r.Header.Get("X-Webhook-Timestamp") + body, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(body, &received.body) + + w.WriteHeader(http.StatusAccepted) + _ = 
json.NewEncoder(w).Encode(map[string]any{"ok": true, "container_id": "c-1"}) + })) + t.Cleanup(srv.Close) + + rc := chat.NewRunnerClient(chat.RunnerClientConfig{ + BaseURL: srv.URL, + HMACKey: "k", + }) + containerID, err := rc.StartChat(context.Background(), chat.StartChatOpts{ + SessionID: "S1", + Project: "alpha", + RepoURL: "https://x/y", + Model: "claude-sonnet-4-6", + Resume: &chat.ResumeContext{ + Turns: []chat.ResumeTurn{{Seq: 1, Role: "user", Content: "hi"}}, + }, + }) + require.NoError(t, err) + assert.Equal(t, "c-1", containerID) + assert.Equal(t, "/chat/start", received.path) + assert.Equal(t, "S1", received.body["session_id"]) + assert.Equal(t, "alpha", received.body["project"]) + assert.Equal(t, "claude-sonnet-4-6", received.body["model"]) + assert.NotEmpty(t, received.sig) + assert.NotEmpty(t, received.ts) + + resume, ok := received.body["resume"].(map[string]any) + require.True(t, ok, "resume should be present in payload") + + turns, ok := resume["turns"].([]any) + require.True(t, ok) + require.Len(t, turns, 1) +} + +func TestRunnerClient_EndChat_ReturnsErrorOnNon2xx(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"error":"no_container"}`)) + })) + t.Cleanup(srv.Close) + + rc := chat.NewRunnerClient(chat.RunnerClientConfig{BaseURL: srv.URL, HMACKey: "k"}) + err := rc.EndChat(context.Background(), "S1") + require.Error(t, err) +} + +func TestRunnerClient_SendChatMessage_PostsToMessage(t *testing.T) { + var ( + receivedBody map[string]any + receivedPath string + ) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedPath = r.URL.Path + body, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(body, &receivedBody) + + w.WriteHeader(http.StatusAccepted) + _, _ = w.Write([]byte(`{"ok":true}`)) + })) + t.Cleanup(srv.Close) + + rc := chat.NewRunnerClient(chat.RunnerClientConfig{BaseURL: 
srv.URL, HMACKey: "k"}) + err := rc.SendChatMessage(context.Background(), "S1", "hello", "msg-1") + require.NoError(t, err) + assert.Equal(t, "/message", receivedPath) + assert.Equal(t, "S1", receivedBody["session_id"]) + assert.Equal(t, "hello", receivedBody["content"]) + assert.Equal(t, "msg-1", receivedBody["message_id"]) +} diff --git a/internal/chat/sqlite/migrations.go b/internal/chat/sqlite/migrations.go new file mode 100644 index 00000000..b710453d --- /dev/null +++ b/internal/chat/sqlite/migrations.go @@ -0,0 +1,211 @@ +package sqlite + +import ( + "context" + "database/sql" + "fmt" + "time" +) + +// migration represents a single versioned schema change. Each Up function is +// idempotent (every statement uses IF EXISTS / IF NOT EXISTS) so the runner +// is safe to re-execute on pre-versioning databases without back-filling the +// schema_migrations rows separately. +type migration struct { + version int + up func(ctx context.Context, db *sql.DB) error +} + +var migrations = []migration{ + { + version: 1, + up: func(ctx context.Context, db *sql.DB) error { + return execAll(ctx, db, []string{ + `CREATE TABLE IF NOT EXISTS chat_sessions ( + id TEXT PRIMARY KEY, + title TEXT NOT NULL, + project TEXT, + status TEXT NOT NULL, + created_at INTEGER NOT NULL, + last_active INTEGER NOT NULL, + created_by TEXT NOT NULL, + container_id TEXT, + workspace TEXT + )`, + `CREATE INDEX IF NOT EXISTS idx_chat_sessions_last_active ON chat_sessions(last_active)`, + `CREATE INDEX IF NOT EXISTS idx_chat_sessions_status ON chat_sessions(status)`, + `CREATE TABLE IF NOT EXISTS chat_messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT NOT NULL, + seq INTEGER NOT NULL, + role TEXT NOT NULL, + content TEXT NOT NULL, + created_at INTEGER NOT NULL, + FOREIGN KEY (session_id) REFERENCES chat_sessions(id) ON DELETE CASCADE + )`, + `CREATE INDEX IF NOT EXISTS idx_chat_messages_session_seq ON chat_messages(session_id, seq)`, + }) + }, + }, + { + version: 2, + up: func(ctx 
context.Context, db *sql.DB) error { + return execAll(ctx, db, []string{ + `DROP INDEX IF EXISTS idx_chat_messages_session_seq`, + `CREATE UNIQUE INDEX IF NOT EXISTS idx_chat_messages_session_seq_unique ON chat_messages(session_id, seq)`, + }) + }, + }, + { + version: 3, + up: func(ctx context.Context, db *sql.DB) error { + // SQLite cannot ALTER TABLE ADD COLUMN IF NOT EXISTS, so we + // introspect pragma table_info first and skip already-present + // columns. This keeps the migration safe to re-run against any + // database that drifted from the versioned migration history + // (e.g. one that had v3 partially applied before crashing). + if err := addColumnIfMissing(ctx, db, "chat_sessions", "model", + `ALTER TABLE chat_sessions ADD COLUMN model TEXT NOT NULL DEFAULT ''`); err != nil { + return err + } + + if err := addColumnIfMissing(ctx, db, "chat_sessions", "context_tokens", + `ALTER TABLE chat_sessions ADD COLUMN context_tokens INTEGER NOT NULL DEFAULT 0`); err != nil { + return err + } + + if err := addColumnIfMissing(ctx, db, "chat_sessions", "context_tokens_updated_at", + `ALTER TABLE chat_sessions ADD COLUMN context_tokens_updated_at INTEGER`); err != nil { + return err + } + + if err := addColumnIfMissing(ctx, db, "chat_sessions", "rehydration_active", + `ALTER TABLE chat_sessions ADD COLUMN rehydration_active INTEGER NOT NULL DEFAULT 0`); err != nil { + return err + } + + if err := addColumnIfMissing(ctx, db, "chat_messages", "rehydration_phase", + `ALTER TABLE chat_messages ADD COLUMN rehydration_phase INTEGER NOT NULL DEFAULT 0`); err != nil { + return err + } + + return execAll(ctx, db, []string{ + `CREATE INDEX IF NOT EXISTS idx_chat_messages_phase ON chat_messages(session_id, rehydration_phase)`, + }) + }, + }, +} + +// addColumnIfMissing applies an ALTER TABLE ADD COLUMN statement only if the +// column is not already present. 
SQLite lacks IF NOT EXISTS on ADD COLUMN, +// and pre-versioning databases may have had columns added by an earlier code +// path that drifted from the migrations list. +func addColumnIfMissing(ctx context.Context, db *sql.DB, table, column, stmt string) error { + rows, err := db.QueryContext(ctx, fmt.Sprintf(`PRAGMA table_info(%q)`, table)) + if err != nil { + return fmt.Errorf("introspect %s: %w", table, err) + } + defer rows.Close() + + for rows.Next() { + var ( + cid int + name string + ctype string + notnull int + dfltValue sql.NullString + pk int + ) + + if err := rows.Scan(&cid, &name, &ctype, ¬null, &dfltValue, &pk); err != nil { + return fmt.Errorf("scan %s columns: %w", table, err) + } + + if name == column { + if err := rows.Err(); err != nil { + return fmt.Errorf("iterate %s columns: %w", table, err) + } + + return nil + } + } + + if err := rows.Err(); err != nil { + return fmt.Errorf("iterate %s columns: %w", table, err) + } + + if _, err := db.ExecContext(ctx, stmt); err != nil { + return fmt.Errorf("add column %s.%s: %w", table, column, err) + } + + return nil +} + +func migrate(ctx context.Context, db *sql.DB) error { + if _, err := db.ExecContext(ctx, `CREATE TABLE IF NOT EXISTS schema_migrations ( + version INTEGER PRIMARY KEY, + applied_at INTEGER NOT NULL + )`); err != nil { + return fmt.Errorf("chat schema_migrations table: %w", err) + } + + applied, err := loadAppliedVersions(ctx, db) + if err != nil { + return err + } + + for _, m := range migrations { + if applied[m.version] { + continue + } + + if err := m.up(ctx, db); err != nil { + return fmt.Errorf("chat migration v%d: %w", m.version, err) + } + + if _, err := db.ExecContext(ctx, + `INSERT INTO schema_migrations (version, applied_at) VALUES (?, ?)`, + m.version, time.Now().Unix(), + ); err != nil { + return fmt.Errorf("chat record migration v%d: %w", m.version, err) + } + } + + return nil +} + +func loadAppliedVersions(ctx context.Context, db *sql.DB) (map[int]bool, error) { + rows, err 
// loadAppliedVersions returns the set of migration versions already recorded
// in the schema_migrations table.
func loadAppliedVersions(ctx context.Context, db *sql.DB) (map[int]bool, error) {
	rows, err := db.QueryContext(ctx, `SELECT version FROM schema_migrations`)
	if err != nil {
		return nil, fmt.Errorf("chat schema_migrations query: %w", err)
	}
	defer rows.Close()

	seen := make(map[int]bool)

	for rows.Next() {
		var version int
		if err := rows.Scan(&version); err != nil {
			return nil, fmt.Errorf("chat schema_migrations scan: %w", err)
		}

		seen[version] = true
	}

	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("chat schema_migrations rows: %w", err)
	}

	return seen, nil
}

// execAll runs each statement in order, reporting the failing step's index
// so a broken migration pinpoints the offending SQL.
func execAll(ctx context.Context, db *sql.DB, statements []string) error {
	for step, stmt := range statements {
		if _, err := db.ExecContext(ctx, stmt); err != nil {
			return fmt.Errorf("step %d: %w", step, err)
		}
	}

	return nil
}
"idx_chat_messages_session_seq")) +} + +func TestMigrate_Wave38DB_DropsRedundantNonUniqueIndex(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "chats.db") + seedV1OnlySchema(t, dbPath) + addUniqueIndex(t, dbPath) + + s, err := Open(dbPath) + require.NoError(t, err) + t.Cleanup(func() { _ = s.Close() }) + + assert.Equal(t, []int{1, 2, 3}, appliedVersions(t, s.db)) + assert.True(t, indexExists(t, s.db, "idx_chat_messages_session_seq_unique")) + assert.False(t, indexExists(t, s.db, "idx_chat_messages_session_seq")) +} + +func TestMigrate_ReopenDoesNotDuplicateVersionRows(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "chats.db") + s1, err := Open(dbPath) + require.NoError(t, err) + require.NoError(t, s1.Close()) + + s2, err := Open(dbPath) + require.NoError(t, err) + t.Cleanup(func() { _ = s2.Close() }) + + assert.Equal(t, []int{1, 2, 3}, appliedVersions(t, s2.db)) +} + +func TestMigrate_V3_AddsRehydrationAndModelColumns(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "chats.db") + s, err := Open(dbPath) + require.NoError(t, err) + t.Cleanup(func() { _ = s.Close() }) + + assert.True(t, columnExists(t, s.db, "chat_sessions", "model")) + assert.True(t, columnExists(t, s.db, "chat_sessions", "context_tokens")) + assert.True(t, columnExists(t, s.db, "chat_sessions", "context_tokens_updated_at")) + assert.True(t, columnExists(t, s.db, "chat_sessions", "rehydration_active")) + assert.True(t, columnExists(t, s.db, "chat_messages", "rehydration_phase")) + assert.True(t, indexExists(t, s.db, "idx_chat_messages_phase")) +} + +func TestMigrate_V3_IdempotentOnPreV3DBWithPartialColumns(t *testing.T) { + // Simulate a database that drifted from the version history: v1 + v2 + // schema in place but one v3 column already exists (e.g. added by a + // buggy intermediate build). addColumnIfMissing must not error. 
+ dbPath := filepath.Join(t.TempDir(), "chats.db") + seedV1OnlySchema(t, dbPath) + addUniqueIndex(t, dbPath) + addPartialV3(t, dbPath) + + s, err := Open(dbPath) + require.NoError(t, err) + t.Cleanup(func() { _ = s.Close() }) + + assert.Equal(t, []int{1, 2, 3}, appliedVersions(t, s.db)) + assert.True(t, columnExists(t, s.db, "chat_sessions", "model")) + assert.True(t, columnExists(t, s.db, "chat_sessions", "rehydration_active")) + assert.True(t, columnExists(t, s.db, "chat_messages", "rehydration_phase")) +} + +func columnExists(t *testing.T, db *sql.DB, table, column string) bool { + t.Helper() + + rows, err := db.Query(`PRAGMA table_info(` + table + `)`) + require.NoError(t, err) + + defer rows.Close() + + for rows.Next() { + var ( + cid int + name string + ctype string + notnull int + dfltValue sql.NullString + pk int + ) + + require.NoError(t, rows.Scan(&cid, &name, &ctype, ¬null, &dfltValue, &pk)) + + if name == column { + return true + } + } + + require.NoError(t, rows.Err()) + + return false +} + +func addPartialV3(t *testing.T, dbPath string) { + t.Helper() + + db, err := sql.Open("sqlite", dbPath+"?_pragma=foreign_keys(1)&_pragma=journal_mode(WAL)&_pragma=busy_timeout(5000)") + require.NoError(t, err) + + _, err = db.Exec(`ALTER TABLE chat_sessions ADD COLUMN model TEXT NOT NULL DEFAULT ''`) + require.NoError(t, err) + + require.NoError(t, db.Close()) +} + +func appliedVersions(t *testing.T, db *sql.DB) []int { + t.Helper() + + rows, err := db.Query(`SELECT version FROM schema_migrations ORDER BY version ASC`) + require.NoError(t, err) + + defer rows.Close() + + out := []int{} + + for rows.Next() { + var v int + + require.NoError(t, rows.Scan(&v)) + + out = append(out, v) + } + + require.NoError(t, rows.Err()) + + return out +} + +func indexExists(t *testing.T, db *sql.DB, name string) bool { + t.Helper() + + var n int + + err := db.QueryRow(`SELECT COUNT(*) FROM sqlite_master WHERE type='index' AND name=?`, name).Scan(&n) + require.NoError(t, err) + + 
return n > 0 +} + +func seedV1OnlySchema(t *testing.T, dbPath string) { + t.Helper() + + db, err := sql.Open("sqlite", dbPath+"?_pragma=foreign_keys(1)&_pragma=journal_mode(WAL)&_pragma=busy_timeout(5000)") + require.NoError(t, err) + + for _, stmt := range []string{ + `CREATE TABLE chat_sessions ( + id TEXT PRIMARY KEY, + title TEXT NOT NULL, + project TEXT, + status TEXT NOT NULL, + created_at INTEGER NOT NULL, + last_active INTEGER NOT NULL, + created_by TEXT NOT NULL, + container_id TEXT, + workspace TEXT + )`, + `CREATE INDEX idx_chat_sessions_last_active ON chat_sessions(last_active)`, + `CREATE INDEX idx_chat_sessions_status ON chat_sessions(status)`, + `CREATE TABLE chat_messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT NOT NULL, + seq INTEGER NOT NULL, + role TEXT NOT NULL, + content TEXT NOT NULL, + created_at INTEGER NOT NULL, + FOREIGN KEY (session_id) REFERENCES chat_sessions(id) ON DELETE CASCADE + )`, + `CREATE INDEX idx_chat_messages_session_seq ON chat_messages(session_id, seq)`, + } { + _, err := db.Exec(stmt) + require.NoError(t, err) + } + + require.NoError(t, db.Close()) +} + +func addUniqueIndex(t *testing.T, dbPath string) { + t.Helper() + + db, err := sql.Open("sqlite", dbPath+"?_pragma=foreign_keys(1)&_pragma=journal_mode(WAL)&_pragma=busy_timeout(5000)") + require.NoError(t, err) + + _, err = db.Exec(`CREATE UNIQUE INDEX idx_chat_messages_session_seq_unique ON chat_messages(session_id, seq)`) + require.NoError(t, err) + + require.NoError(t, db.Close()) +} diff --git a/internal/chat/sqlite/store.go b/internal/chat/sqlite/store.go new file mode 100644 index 00000000..c5e276ae --- /dev/null +++ b/internal/chat/sqlite/store.go @@ -0,0 +1,399 @@ +package sqlite + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "time" + + _ "modernc.org/sqlite" // register sqlite driver + + "github.com/mhersson/contextmatrix/internal/chat" +) + +// compile-time assertion that *Store 
satisfies chat.Store. +var _ chat.Store = (*Store)(nil) + +// Store is the SQLite-backed implementation of chat.Store. +type Store struct { + db *sql.DB +} + +// Open opens (or creates) the SQLite database at path and applies the +// schema migrations. Parent directories are created as needed. +func Open(path string) (*Store, error) { + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return nil, fmt.Errorf("chat: ensure db dir: %w", err) + } + + db, err := sql.Open("sqlite", path+"?_pragma=foreign_keys(1)&_pragma=journal_mode(WAL)&_pragma=busy_timeout(5000)") + if err != nil { + return nil, fmt.Errorf("chat: open sqlite: %w", err) + } + + // SQLite is single-writer regardless of pool size; serialisation across + // writers happens at the manager level (chat.Manager.mu held across + // AppendMessage). MaxOpenConns > 1 lets concurrent readers (ListMessages, + // MaxSeq, GetSession) avoid queueing behind a writer when WAL is on. + db.SetMaxOpenConns(5) + + if err := migrate(context.Background(), db); err != nil { + _ = db.Close() + + return nil, err + } + + return &Store{db: db}, nil +} + +// Close releases the underlying database. +func (s *Store) Close() error { return s.db.Close() } + +// sessionColumns lists every column read by scanSession in the exact order +// the SELECT statement projects them. Kept as a single source of truth so +// new fields don't drift between GetSession, ListSessions, and scanSession. 
const sessionColumns = `id, title, project, status, created_at, last_active, created_by,
	container_id, workspace, model, context_tokens, context_tokens_updated_at, rehydration_active`

// CreateSession inserts a new session row. Workspace is serialised to a JSON
// blob; empty Project/ContainerID are stored as NULL via nullIf. Timestamps
// are persisted as Unix seconds (scanSession converts them back to UTC).
func (s *Store) CreateSession(ctx context.Context, sess chat.Session) error {
	workspaceJSON, err := json.Marshal(sess.Workspace)
	if err != nil {
		return fmt.Errorf("chat: marshal workspace: %w", err)
	}

	_, err = s.db.ExecContext(ctx, `INSERT INTO chat_sessions
		(id, title, project, status, created_at, last_active, created_by, container_id, workspace, model)
		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		sess.ID, sess.Title, nullIf(sess.Project), string(sess.Status),
		sess.CreatedAt.Unix(), sess.LastActive.Unix(), sess.CreatedBy,
		nullIf(sess.ContainerID), string(workspaceJSON), sess.Model,
	)
	if err != nil {
		return fmt.Errorf("chat: insert session: %w", err)
	}

	return nil
}

// GetSession fetches one session row by id. A missing row surfaces as
// chat.ErrSessionNotFound via scanSession's sql.ErrNoRows mapping.
func (s *Store) GetSession(ctx context.Context, id string) (chat.Session, error) {
	row := s.db.QueryRowContext(ctx, `SELECT `+sessionColumns+` FROM chat_sessions WHERE id = ?`, id)

	return scanSession(row)
}

// ListSessions returns sessions matching f, most-recently-active first.
// Each non-zero filter field appends exactly one AND clause plus its
// placeholder argument, in the same order, so q and args stay in sync.
func (s *Store) ListSessions(ctx context.Context, f chat.SessionFilter) ([]chat.Session, error) {
	// "WHERE 1=1" lets every filter below append " AND ..." unconditionally.
	q := `SELECT ` + sessionColumns + ` FROM chat_sessions WHERE 1=1`

	var args []any

	if f.Project != "" {
		q += " AND project = ?"

		args = append(args, f.Project)
	}

	if f.Status != "" {
		q += " AND status = ?"

		args = append(args, string(f.Status))
	}

	if f.CreatedBy != "" {
		q += " AND created_by = ?"

		args = append(args, f.CreatedBy)
	}

	if f.RehydrationActive != nil {
		q += " AND rehydration_active = ?"

		// SQLite has no boolean type; the column stores 0/1.
		if *f.RehydrationActive {
			args = append(args, 1)
		} else {
			args = append(args, 0)
		}
	}

	if !f.LastActiveBefore.IsZero() {
		q += " AND last_active < ?"

		args = append(args, f.LastActiveBefore.Unix())
	}

	q += " ORDER BY last_active DESC"

	if f.Limit > 0 {
		q += " LIMIT ?"

		args = append(args, f.Limit)
	}

	rows, err := s.db.QueryContext(ctx, q, args...)
	if err != nil {
		return nil, fmt.Errorf("chat: list sessions: %w", err)
	}

	defer rows.Close()

	var out []chat.Session

	for rows.Next() {
		sess, err := scanSession(rows)
		if err != nil {
			return nil, err
		}

		out = append(out, sess)
	}

	return out, rows.Err()
}

// UpdateSession rewrites the mutable core columns of a session row.
// context_tokens and rehydration_active are deliberately NOT listed here —
// they have dedicated targeted setters (UpdateContextTokens,
// SetRehydrationActive) below. Returns chat.ErrSessionNotFound when the id
// matches no row.
func (s *Store) UpdateSession(ctx context.Context, sess chat.Session) error {
	workspaceJSON, err := json.Marshal(sess.Workspace)
	if err != nil {
		return fmt.Errorf("chat: marshal workspace: %w", err)
	}

	res, err := s.db.ExecContext(ctx, `UPDATE chat_sessions SET
		title=?, project=?, status=?, last_active=?, container_id=?, workspace=?, model=?
		WHERE id=?`,
		sess.Title, nullIf(sess.Project), string(sess.Status),
		sess.LastActive.Unix(), nullIf(sess.ContainerID), string(workspaceJSON),
		sess.Model, sess.ID)
	if err != nil {
		return fmt.Errorf("chat: update session: %w", err)
	}

	n, _ := res.RowsAffected()
	if n == 0 {
		return chat.ErrSessionNotFound
	}

	return nil
}

// SetRehydrationActive flips the rehydration_active flag on a session row.
// Targeted update avoids scribbling the entire session, which the consumer
// path would otherwise have to do via UpdateSession.
func (s *Store) SetRehydrationActive(ctx context.Context, sessionID string, active bool) error {
	// Store the bool as 0/1; SQLite has no native boolean column type.
	flag := 0
	if active {
		flag = 1
	}

	res, err := s.db.ExecContext(ctx,
		`UPDATE chat_sessions SET rehydration_active = ? WHERE id = ?`,
		flag, sessionID)
	if err != nil {
		return fmt.Errorf("chat: set rehydration_active: %w", err)
	}

	n, _ := res.RowsAffected()
	if n == 0 {
		return chat.ErrSessionNotFound
	}

	return nil
}

// UpdateContextTokens stamps the latest Claude usage block onto the session
// row. Called from the runner-log consumer when a "usage" event arrives.
+func (s *Store) UpdateContextTokens(ctx context.Context, sessionID string, tokens int64, updatedAt time.Time) error { + res, err := s.db.ExecContext(ctx, + `UPDATE chat_sessions SET context_tokens = ?, context_tokens_updated_at = ? WHERE id = ?`, + tokens, updatedAt.Unix(), sessionID) + if err != nil { + return fmt.Errorf("chat: update context_tokens: %w", err) + } + + n, _ := res.RowsAffected() + if n == 0 { + return chat.ErrSessionNotFound + } + + return nil +} + +func (s *Store) DeleteSession(ctx context.Context, id string) error { + _, err := s.db.ExecContext(ctx, `DELETE FROM chat_sessions WHERE id = ?`, id) + if err != nil { + return fmt.Errorf("chat: delete session: %w", err) + } + + return nil +} + +func (s *Store) AppendMessage(ctx context.Context, m chat.Message) (int64, error) { + phase := 0 + if m.RehydrationPhase { + phase = 1 + } + + _, err := s.db.ExecContext(ctx, `INSERT INTO chat_messages + (session_id, seq, role, content, created_at, rehydration_phase) + VALUES (?, ?, ?, ?, ?, ?)`, + m.SessionID, m.Seq, string(m.Role), m.Content, m.CreatedAt.Unix(), phase) + if err != nil { + return 0, fmt.Errorf("chat: append message: %w", err) + } + + return m.Seq, nil +} + +func (s *Store) MaxSeq(ctx context.Context, sessionID string) (int64, error) { + var maxSeq sql.NullInt64 + + err := s.db.QueryRowContext(ctx, + `SELECT MAX(seq) FROM chat_messages WHERE session_id = ?`, + sessionID, + ).Scan(&maxSeq) + if err != nil { + return 0, fmt.Errorf("chat: max seq: %w", err) + } + + return maxSeq.Int64, nil +} + +func (s *Store) ListMessages(ctx context.Context, sessionID string, sinceSeq int64, limit int) ([]chat.Message, error) { + rows, err := s.db.QueryContext(ctx, `SELECT id, session_id, seq, role, content, created_at, rehydration_phase + FROM chat_messages + WHERE session_id = ? AND seq > ? 
+ ORDER BY seq ASC LIMIT ?`, sessionID, sinceSeq, limit) + if err != nil { + return nil, fmt.Errorf("chat: list messages: %w", err) + } + + defer rows.Close() + + var out []chat.Message + + for rows.Next() { + var ( + m chat.Message + createdAt int64 + role string + phase int + ) + + if err := rows.Scan(&m.ID, &m.SessionID, &m.Seq, &role, &m.Content, &createdAt, &phase); err != nil { + return nil, err + } + + m.Role = chat.Role(role) + m.CreatedAt = time.Unix(createdAt, 0).UTC() + m.RehydrationPhase = phase != 0 + out = append(out, m) + } + + return out, rows.Err() +} + +func (s *Store) ListMessagesTail(ctx context.Context, sessionID string, limit int) ([]chat.Message, error) { + if limit <= 0 { + return nil, nil + } + + rows, err := s.db.QueryContext(ctx, ` + SELECT id, session_id, seq, role, content, created_at, rehydration_phase + FROM ( + SELECT id, session_id, seq, role, content, created_at, rehydration_phase + FROM chat_messages + WHERE session_id = ? + ORDER BY seq DESC + LIMIT ? + ) + ORDER BY seq ASC + `, sessionID, limit) + if err != nil { + return nil, fmt.Errorf("list tail: %w", err) + } + + defer rows.Close() + + var out []chat.Message + + for rows.Next() { + var ( + m chat.Message + createdAt int64 + role string + phase int + ) + + if err := rows.Scan(&m.ID, &m.SessionID, &m.Seq, &role, &m.Content, &createdAt, &phase); err != nil { + return nil, err + } + + m.Role = chat.Role(role) + m.CreatedAt = time.Unix(createdAt, 0).UTC() + m.RehydrationPhase = phase != 0 + out = append(out, m) + } + + return out, rows.Err() +} + +// nullIf returns a sql.NullString that is NULL when s is empty. 
func nullIf(s string) sql.NullString {
	if s == "" {
		return sql.NullString{}
	}

	return sql.NullString{String: s, Valid: true}
}

// scanner abstracts over *sql.Row and *sql.Rows so scanSession can decode
// both the single-row path (GetSession) and the iteration path (ListSessions).
type scanner interface {
	Scan(dest ...any) error
}

// scanSession decodes one chat_sessions row projected in sessionColumns
// order. Maps sql.ErrNoRows to chat.ErrSessionNotFound so callers never see
// driver-level sentinels; Unix-second columns come back as UTC times.
func scanSession(sc scanner) (chat.Session, error) {
	var (
		s                           chat.Session
		project, containerID, model sql.NullString
		workspaceJSON               sql.NullString
		createdAt, lastActive       int64
		status                      string
		contextTokens               int64
		contextTokensUpdatedAt      sql.NullInt64
		rehydrationActive           int
	)

	// Scan destinations must match sessionColumns exactly, in order.
	if err := sc.Scan(
		&s.ID, &s.Title, &project, &status, &createdAt, &lastActive, &s.CreatedBy,
		&containerID, &workspaceJSON,
		&model, &contextTokens, &contextTokensUpdatedAt, &rehydrationActive,
	); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return chat.Session{}, chat.ErrSessionNotFound
		}

		return chat.Session{}, fmt.Errorf("chat: scan session: %w", err)
	}

	// NULLable text columns decode to "" via the zero-valued NullString.
	s.Project = project.String
	s.ContainerID = containerID.String
	s.Status = chat.Status(status)
	s.CreatedAt = time.Unix(createdAt, 0).UTC()
	s.LastActive = time.Unix(lastActive, 0).UTC()
	s.Model = model.String
	s.ContextTokens = contextTokens

	if contextTokensUpdatedAt.Valid {
		s.ContextTokensUpdatedAt = time.Unix(contextTokensUpdatedAt.Int64, 0).UTC()
	}

	s.RehydrationActive = rehydrationActive != 0

	// Workspace is a JSON blob; treat NULL, "" and "null" all as absent.
	if workspaceJSON.Valid && workspaceJSON.String != "" && workspaceJSON.String != "null" {
		if err := json.Unmarshal([]byte(workspaceJSON.String), &s.Workspace); err != nil {
			return chat.Session{}, fmt.Errorf("chat: unmarshal workspace: %w", err)
		}
	}

	return s, nil
}
diff --git a/internal/chat/sqlite/store_test.go b/internal/chat/sqlite/store_test.go
new file mode 100644
index 00000000..2aae4a75
--- /dev/null
+++ b/internal/chat/sqlite/store_test.go
@@ -0,0 +1,154 @@
package sqlite

import (
	"context"
	"fmt"
	"path/filepath"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/mhersson/contextmatrix/internal/chat"
)

// TestOpen_CreatesSchema round-trips a session through CreateSession and
// GetSession against a freshly-created database file.
func TestOpen_CreatesSchema(t *testing.T) {
	t.Helper()
	dbPath := filepath.Join(t.TempDir(), "chats.db")
	s, err := Open(dbPath)
	require.NoError(t, err)
	t.Cleanup(func() { _ = s.Close() })

	ctx := context.Background()
	sess := chat.Session{
		ID:         chat.NewID(),
		Title:      "test",
		Project:    "alpha",
		Status:     chat.StatusCold,
		CreatedAt:  time.Now().UTC().Truncate(time.Second),
		LastActive: time.Now().UTC().Truncate(time.Second),
		CreatedBy:  "human:web-abc",
	}
	require.NoError(t, s.CreateSession(ctx, sess))

	got, err := s.GetSession(ctx, sess.ID)
	require.NoError(t, err)
	assert.Equal(t, sess.Title, got.Title)
	assert.Equal(t, sess.Project, got.Project)
	assert.Equal(t, sess.Status, got.Status)
	assert.Equal(t, sess.CreatedBy, got.CreatedBy)
}

// TestOpen_IsIdempotent re-opens the same database file; migrations must be
// safe to apply to an already-migrated schema.
func TestOpen_IsIdempotent(t *testing.T) {
	t.Helper()
	dbPath := filepath.Join(t.TempDir(), "chats.db")
	s1, err := Open(dbPath)
	require.NoError(t, err)
	require.NoError(t, s1.Close())

	s2, err := Open(dbPath)
	require.NoError(t, err)
	t.Cleanup(func() { _ = s2.Close() })
}

// TestAppendAndList_Messages checks seq-ordered listing and the sinceSeq
// cursor semantics (strictly greater than).
func TestAppendAndList_Messages(t *testing.T) {
	t.Helper()
	dbPath := filepath.Join(t.TempDir(), "chats.db")
	s, err := Open(dbPath)
	require.NoError(t, err)
	t.Cleanup(func() { _ = s.Close() })

	ctx := context.Background()
	sess := chat.Session{
		ID: chat.NewID(), Title: "t", Status: chat.StatusActive,
		CreatedAt: time.Now().UTC(), LastActive: time.Now().UTC(),
		CreatedBy: "human:web-x",
	}
	require.NoError(t, s.CreateSession(ctx, sess))

	for i, body := range []string{"hello", "world", "claude"} {
		seq, err := s.AppendMessage(ctx, chat.Message{
			SessionID: sess.ID,
			Seq:       int64(i + 1),
			Role:      chat.RoleUser,
			Content:   `{"text":"` + body + `"}`,
			CreatedAt: time.Now().UTC(),
		})
		require.NoError(t, err)
		assert.Equal(t, int64(i+1), seq)
	}

	msgs, err := s.ListMessages(ctx, sess.ID, 0, 100)
	require.NoError(t, err)
	require.Len(t, msgs, 3)
	assert.Equal(t, int64(1), msgs[0].Seq)
	assert.Equal(t, int64(3), msgs[2].Seq)

	// sinceSeq=1 must exclude seq 1 itself (strict > comparison).
	msgs2, err := s.ListMessages(ctx, sess.ID, 1, 100)
	require.NoError(t, err)
	require.Len(t, msgs2, 2)
	assert.Equal(t, int64(2), msgs2[0].Seq)
}

// TestDeleteSession_CascadesMessages relies on the schema's foreign-key
// cascade: deleting the session row must also remove its messages.
func TestDeleteSession_CascadesMessages(t *testing.T) {
	t.Helper()
	dbPath := filepath.Join(t.TempDir(), "chats.db")
	s, err := Open(dbPath)
	require.NoError(t, err)
	t.Cleanup(func() { _ = s.Close() })

	ctx := context.Background()
	sess := chat.Session{
		ID: chat.NewID(), Title: "t", Status: chat.StatusCold,
		CreatedAt: time.Now().UTC(), LastActive: time.Now().UTC(), CreatedBy: "x",
	}
	require.NoError(t, s.CreateSession(ctx, sess))
	_, err = s.AppendMessage(ctx, chat.Message{SessionID: sess.ID, Seq: 1, Role: chat.RoleUser, Content: "{}", CreatedAt: time.Now().UTC()})
	require.NoError(t, err)

	require.NoError(t, s.DeleteSession(ctx, sess.ID))
	msgs, err := s.ListMessages(ctx, sess.ID, 0, 100)
	require.NoError(t, err)
	assert.Empty(t, msgs)
}

func TestStore_ListMessagesTail_ReturnsNewestNInChronologicalOrder(t *testing.T) {
	t.Parallel()
	dbPath := filepath.Join(t.TempDir(), "chats.db")
	store, err := Open(dbPath)
	require.NoError(t, err)
	t.Cleanup(func() { _ = store.Close() })

	ctx := context.Background()
	sessionID := chat.NewID()
	require.NoError(t, store.CreateSession(ctx, chat.Session{
		ID:         sessionID,
		Title:      "tail-test",
		Status:     chat.StatusCold,
		CreatedAt:  time.Now().UTC().Truncate(time.Second),
		LastActive: time.Now().UTC().Truncate(time.Second),
		CreatedBy:  "human:test",
	}))

	// Insert 50 messages with seq 1..50.
	for i := 1; i <= 50; i++ {
		_, err := store.AppendMessage(ctx, chat.Message{
			SessionID: sessionID,
			Seq:       int64(i),
			Role:      chat.RoleUser,
			Content:   fmt.Sprintf(`{"text":"m%d"}`, i),
			CreatedAt: time.Now().UTC().Truncate(time.Second),
		})
		require.NoError(t, err)
	}

	msgs, err := store.ListMessagesTail(ctx, sessionID, 10)
	require.NoError(t, err)
	require.Len(t, msgs, 10)

	// Newest 10 are seq 41..50, returned in ASC order.
	for i, m := range msgs {
		require.Equal(t, int64(41+i), m.Seq, "row %d", i)
	}
}
diff --git a/internal/chat/sse.go b/internal/chat/sse.go
new file mode 100644
index 00000000..7b3c548d
--- /dev/null
+++ b/internal/chat/sse.go
@@ -0,0 +1,238 @@
package chat

import (
	"fmt"
	"sync"
	"time"
)

// SSEEventKind is the discriminator on SSEEvent. The wire format uses
// SSE's "event:" header to route different kinds to different client-side
// listeners.
type SSEEventKind string

// maxSubscribersPerSession is the maximum number of concurrent SSE subscribers
// allowed per chat session. A single browser tab opens one subscriber; the
// cap is generous enough for legitimate concurrent tabs/devtools but prevents
// unbounded growth from leaky clients.
const maxSubscribersPerSession = 32

const (
	// SSEKindMessage is a transcript message append; Seq/Role/Content carry it.
	SSEKindMessage SSEEventKind = "message"
	// SSEKindSessionUpdate is a session metadata change (context_tokens,
	// rehydration_active, model). SessionUpdate carries the payload.
	SSEKindSessionUpdate SSEEventKind = "session_updated"
)

// SSEEvent is the shape pushed to browser subscribers. The default kind
// (empty) is treated as a message event for backwards compatibility with
// the original SSE wire (which had no kind discriminator).
type SSEEvent struct {
	// Kind is routed via the SSE "event:" header, not the JSON body.
	Kind             SSEEventKind   `json:"-"`
	Seq              int64          `json:"seq,omitempty"`
	Role             Role           `json:"role,omitempty"`
	Content          string         `json:"content,omitempty"`
	RehydrationPhase bool           `json:"rehydration_phase,omitempty"`
	SessionUpdate    *SessionUpdate `json:"session_update,omitempty"`
}

// SessionUpdate is the payload of an SSEKindSessionUpdate event. Zero-valued
// fields mean "unchanged" — the client merges these into its session view.
type SessionUpdate struct {
	ContextTokens          int64     `json:"context_tokens,omitempty"`
	ContextTokensUpdatedAt time.Time `json:"context_tokens_updated_at,omitempty"`
	Model                  string    `json:"model,omitempty"`
	RehydrationActive      *bool     `json:"rehydration_active,omitempty"`
}

// subscriber wraps one receiver's buffered event channel.
type subscriber struct {
	ch chan SSEEvent
}

// sessionHub holds one session's replay ring and live subscribers.
// mu guards ring and subs; Publish's fan-out runs under this lock.
type sessionHub struct {
	mu   sync.Mutex
	ring []SSEEvent // replay buffer, bounded to cap (oldest evicted first)
	cap  int
	subs map[*subscriber]struct{}
}

// SSEHub manages per-session ring buffers and subscriber fan-out.
type SSEHub struct {
	mu      sync.Mutex // guards perSess map structure only
	bufCap  int
	perSess map[string]*sessionHub

	// Callbacks (optional). Fired outside the hub's internal mutex but the
	// sessionHub mutex is held during the count check, so the callbacks see
	// a consistent snapshot of "last subscriber departed" / "first or Nth
	// subscriber arrived".
	OnLastUnsubscribe func(sessionID string)
	OnSubscribe       func(sessionID string)
}

// NewSSEHub creates an SSEHub with the given ring buffer capacity per session.
func NewSSEHub(bufCap int) *SSEHub {
	return &SSEHub{bufCap: bufCap, perSess: make(map[string]*sessionHub)}
}

// hub returns the per-session hub, lazily creating it on first use.
func (h *SSEHub) hub(sessionID string) *sessionHub {
	h.mu.Lock()
	defer h.mu.Unlock()

	sh, ok := h.perSess[sessionID]
	if !ok {
		sh = &sessionHub{cap: h.bufCap, subs: make(map[*subscriber]struct{})}
		h.perSess[sessionID] = sh
	}

	return sh
}

// Publish appends an event to the session's ring buffer and pushes to all
// subscribers. Slow subscribers drop events rather than blocking the producer.
+// +// The fan-out runs under the per-session lock so a concurrent Unsubscribe / +// Drop cannot close a subscriber channel between our copy-out and our send. +// Sends are non-blocking (buffered channels + default branch), so holding the +// lock here does not couple producer throughput to subscriber draining. +func (h *SSEHub) Publish(sessionID string, e SSEEvent) { + if e.Kind == "" { + e.Kind = SSEKindMessage + } + + sh := h.hub(sessionID) + sh.mu.Lock() + defer sh.mu.Unlock() + + // Only persistent transcript events go into the replay ring. Session + // updates are pure state push — late subscribers should fetch fresh + // state via GET /api/chats/{id} rather than seeing a stale update. + if e.Kind == SSEKindMessage { + sh.ring = append(sh.ring, e) + if len(sh.ring) > sh.cap { + sh.ring = sh.ring[len(sh.ring)-sh.cap:] + } + } + + for s := range sh.subs { + select { + case s.ch <- e: + default: + // slow subscriber — drop rather than block the producer + } + } +} + +// PublishSessionUpdate fans out a session-metadata change event. Convenience +// wrapper around Publish that sets the kind discriminator. +func (h *SSEHub) PublishSessionUpdate(sessionID string, u SessionUpdate) { + h.Publish(sessionID, SSEEvent{ + Kind: SSEKindSessionUpdate, + SessionUpdate: &u, + }) +} + +// Subscribe returns a channel for live events and the replay slice of buffered +// events with Seq > sinceSeq. The returned channel must be passed back to +// Unsubscribe to release resources. +// +// Returns an error if the per-session subscriber count would exceed +// maxSubscribersPerSession. This caps memory and goroutine growth from leaky +// clients without affecting normal browser usage (one tab = one subscriber). +// +// The OnSubscribe callback fires inside the per-session lock so that the +// (Unsubscribe → OnLastUnsubscribe → Subscribe → OnSubscribe) sequence is +// strict — a fast resubscribe cannot race the prior unsubscribe's callback +// and leave a stale grace timer. 
+func (h *SSEHub) Subscribe(sessionID string, sinceSeq int64) (<-chan SSEEvent, []SSEEvent, error) { + sh := h.hub(sessionID) + sh.mu.Lock() + defer sh.mu.Unlock() + + if len(sh.subs) >= maxSubscribersPerSession { + return nil, nil, fmt.Errorf("chat: sse: session %q subscriber cap (%d) reached", sessionID, maxSubscribersPerSession) + } + + s := &subscriber{ch: make(chan SSEEvent, 64)} + sh.subs[s] = struct{}{} + + var replay []SSEEvent + + for _, e := range sh.ring { + if e.Seq > sinceSeq { + replay = append(replay, e) + } + } + + if h.OnSubscribe != nil { + h.OnSubscribe(sessionID) + } + + return s.ch, replay, nil +} + +// Drop unconditionally releases the per-session ring buffer and closes every +// live subscriber. Used by Manager.DeleteSession so the hub's memory does not +// grow without bound across session churn. Idempotent: a second Drop, or a +// Drop on a never-seen session, is a no-op. +func (h *SSEHub) Drop(sessionID string) { + h.mu.Lock() + sh, ok := h.perSess[sessionID] + + if !ok { + h.mu.Unlock() + + return + } + + delete(h.perSess, sessionID) + h.mu.Unlock() + + sh.mu.Lock() + defer sh.mu.Unlock() + + for s := range sh.subs { + close(s.ch) + delete(sh.subs, s) + } + + sh.ring = nil +} + +// Unsubscribe removes a subscriber and closes its channel. +// +// Lookup-only: if the session has already been Drop'd, the perSess entry is +// gone and this is a no-op. (The streamChat HTTP handler defers Unsubscribe +// after Subscribe; when a concurrent DeleteSession runs Drop, the handler's +// receive loop sees the channel close and returns, then the deferred +// Unsubscribe arrives — without this lookup-only guard it would resurrect +// the per-session entry and defeat Drop's cleanup.) +// +// OnLastUnsubscribe fires inside the per-session lock — see Subscribe for +// the ordering rationale. 
func (h *SSEHub) Unsubscribe(sessionID string, ch <-chan SSEEvent) {
	h.mu.Lock()
	sh, ok := h.perSess[sessionID]
	h.mu.Unlock()

	if !ok {
		return
	}

	sh.mu.Lock()
	defer sh.mu.Unlock()

	// Linear scan is fine: subscriber counts are capped at
	// maxSubscribersPerSession (32).
	for s := range sh.subs {
		if s.ch == ch {
			delete(sh.subs, s)
			close(s.ch)

			if len(sh.subs) == 0 && h.OnLastUnsubscribe != nil {
				h.OnLastUnsubscribe(sessionID)
			}

			return
		}
	}
}
diff --git a/internal/chat/sse_internal_test.go b/internal/chat/sse_internal_test.go
new file mode 100644
index 00000000..c008abcd
--- /dev/null
+++ b/internal/chat/sse_internal_test.go
@@ -0,0 +1,243 @@
package chat

import (
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestSSEHub_Drop_RemovesSessionEntry exercises the explicit
// lifecycle method: once Drop fires for a sessionID, the per-session
// ring buffer + subscriber map is freed (perSess no longer holds it).
func TestSSEHub_Drop_RemovesSessionEntry(t *testing.T) {
	hub := NewSSEHub(8)
	hub.Publish("S1", SSEEvent{Seq: 1, Content: "x"})

	hub.mu.Lock()
	_, present := hub.perSess["S1"]
	hub.mu.Unlock()
	require.True(t, present, "publish must create the per-session entry")

	hub.Drop("S1")

	hub.mu.Lock()
	_, present = hub.perSess["S1"]
	hub.mu.Unlock()
	assert.False(t, present, "Drop must remove the per-session entry")
}

// TestSSEHub_Drop_ClosesLiveSubscribers ensures that a Drop while a subscriber
// is still attached closes the subscriber's channel so the SSE handler can
// exit promptly (it shouldn't be left blocked on a channel that nothing will
// ever publish to again).
func TestSSEHub_Drop_ClosesLiveSubscribers(t *testing.T) {
	hub := NewSSEHub(8)
	ch, _, _ := hub.Subscribe("S1", 0)

	hub.Drop("S1")

	// Channel must be closed (a closed channel yields the zero value with ok=false).
	_, ok := <-ch
	assert.False(t, ok, "Drop must close subscriber channels")
}

// TestSSEHub_Drop_Idempotent ensures Drop on a session that was never seen
// or has already been dropped is a no-op.
func TestSSEHub_Drop_Idempotent(t *testing.T) {
	hub := NewSSEHub(8)

	hub.Drop("never-existed") // must not panic
	hub.Publish("S1", SSEEvent{Seq: 1})
	hub.Drop("S1")
	hub.Drop("S1") // second drop is a no-op
}

// TestSSEHub_PerSessNoLeak_ManyDrops verifies that after a flurry of
// Subscribe + Publish + Drop + Unsubscribe cycles the perSess map size returns
// to zero. The earlier version of this test never exercised Subscribe before
// Drop, which made it vacuous: Unsubscribe used to lazy-create a fresh perSess
// entry, defeating Drop. The Subscribe + deferred-Unsubscribe pattern is what
// the streamChat HTTP handler actually does.
func TestSSEHub_PerSessNoLeak_ManyDrops(t *testing.T) {
	hub := NewSSEHub(8)

	for i := range 100 {
		id := "session-" + itoa(i)
		ch, _, _ := hub.Subscribe(id, 0)
		hub.Publish(id, SSEEvent{Seq: 1, Content: "x"})
		hub.Drop(id)
		// Simulate the handler's deferred Unsubscribe firing AFTER Drop closed
		// the channel. With a lazy-create Unsubscribe this resurrects perSess[id].
		hub.Unsubscribe(id, ch)
	}

	hub.mu.Lock()
	defer hub.mu.Unlock()

	assert.Empty(t, hub.perSess, "perSess must be empty after every session dropped")
}

// TestSSEHub_UnsubscribeAfterDropIsNoOp verifies the specific defer-after-Drop
// pattern from the streamChat handler: Drop closes the subscriber channel and
// removes the perSess entry; the handler's deferred Unsubscribe must NOT
// resurrect the entry.
func TestSSEHub_UnsubscribeAfterDropIsNoOp(t *testing.T) {
	hub := NewSSEHub(8)
	ch, _, _ := hub.Subscribe("S-drop", 0)

	hub.Drop("S-drop")

	// Drop should have closed ch already.
	_, ok := <-ch
	require.False(t, ok, "Drop must close subscriber channel")

	hub.Unsubscribe("S-drop", ch)

	hub.mu.Lock()
	_, present := hub.perSess["S-drop"]
	hub.mu.Unlock()
	assert.False(t, present, "Unsubscribe after Drop must not resurrect the perSess entry")
}

// TestSSEHub_PublishUnsubscribeRace_NoPanic stresses Publish against concurrent
// Subscribe/Unsubscribe to catch the send-on-closed-channel race. Without the
// fix, Publish copies subscribers under the lock then sends after releasing it;
// an interleaved Unsubscribe closes the channel between those two steps and
// the subsequent send panics. With the fan-out held under the per-session lock,
// the panic is impossible.
func TestSSEHub_PublishUnsubscribeRace_NoPanic(t *testing.T) {
	hub := NewSSEHub(64)

	stop := make(chan struct{})

	var wg sync.WaitGroup

	// Four publishers hammer the session...
	for range 4 {
		wg.Add(1)

		go func() {
			defer wg.Done()

			for {
				select {
				case <-stop:
					return
				default:
					hub.Publish("S-race", SSEEvent{Seq: 1, Content: "x"})
				}
			}
		}()
	}

	// ...while four goroutines churn subscribe/unsubscribe pairs.
	for range 4 {
		wg.Add(1)

		go func() {
			defer wg.Done()

			for {
				select {
				case <-stop:
					return
				default:
					ch, _, _ := hub.Subscribe("S-race", 0)
					hub.Unsubscribe("S-race", ch)
				}
			}
		}()
	}

	time.Sleep(200 * time.Millisecond)
	close(stop)
	wg.Wait()
}

// TestSSEHub_OnLastUnsubscribeHoldsLockAgainstSubscribe is the regression
// test for the grace-timer race: a concurrent Subscribe must be blocked while
// OnLastUnsubscribe is running. Without the fix the callback fired outside
// the per-session mutex, so a fast resubscribe could fire OnSubscribe BEFORE
// the OnLastUnsubscribe that was supposed to seed the grace timer to be
// cancelled by that OnSubscribe — leaving a stale 30s timer.
func TestSSEHub_OnLastUnsubscribeHoldsLockAgainstSubscribe(t *testing.T) {
	hub := NewSSEHub(8)

	gate := make(chan struct{})
	resumed := make(chan struct{})
	subscribeFired := make(chan struct{}, 1)

	hub.OnLastUnsubscribe = func(string) {
		// Hold the per-session lock long enough for a racing Subscribe to
		// arrive. If callbacks fire outside the lock, the racing Subscribe
		// proceeds immediately and the assertion below fails.
		<-gate
		close(resumed)
	}

	var watching bool

	hub.OnSubscribe = func(string) {
		if !watching {
			return
		}

		select {
		case subscribeFired <- struct{}{}:
		default:
		}
	}

	ch, _, _ := hub.Subscribe("S1", 0)
	watching = true

	done := make(chan struct{})

	go func() {
		hub.Unsubscribe("S1", ch)
		close(done)
	}()

	// Wait until Unsubscribe is well inside OnLastUnsubscribe (a brief sleep
	// is enough; the goroutine will be parked at <-gate).
	deadline := time.NewTimer(50 * time.Millisecond)
	defer deadline.Stop()

	go func() {
		<-deadline.C
		// Racing Subscribe — it MUST block until OnLastUnsubscribe returns.
		_, _, _ = hub.Subscribe("S1", 0)
	}()

	// Give the racing Subscribe a chance to register if it weren't blocked.
	select {
	case <-subscribeFired:
		t.Fatal("Subscribe callback fired while OnLastUnsubscribe was still running — callbacks are not under the per-session lock")
	case <-time.After(100 * time.Millisecond):
		// Good — racing Subscribe is properly blocked.
	}

	close(gate)
	<-resumed
	<-done
}

// itoa formats a non-negative int without importing strconv; only ever
// called here with loop indices >= 0.
func itoa(i int) string {
	if i == 0 {
		return "0"
	}

	var (
		buf [20]byte
		pos = len(buf)
	)

	for i > 0 {
		pos--
		buf[pos] = byte('0' + i%10)
		i /= 10
	}

	return string(buf[pos:])
}
diff --git a/internal/chat/sse_test.go b/internal/chat/sse_test.go
new file mode 100644
index 00000000..abe7d096
--- /dev/null
+++ b/internal/chat/sse_test.go
@@ -0,0 +1,146 @@
package chat_test

import (
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/mhersson/contextmatrix/internal/chat"
)

// TestSSEHub_PublishFansOut: one publish must reach every live subscriber.
func TestSSEHub_PublishFansOut(t *testing.T) {
	hub := chat.NewSSEHub(128)
	sub1, _, _ := hub.Subscribe("S1", 0)
	sub2, _, _ := hub.Subscribe("S1", 0)

	t.Cleanup(func() { hub.Unsubscribe("S1", sub1); hub.Unsubscribe("S1", sub2) })

	hub.Publish("S1", chat.SSEEvent{Seq: 1, Role: chat.RoleUser, Content: "{}"})

	got1, ok := readSSEWithTimeout(t, sub1, 500*time.Millisecond)
	require.True(t, ok, "sub1 must receive event")
	assert.Equal(t, int64(1), got1.Seq)
	got2, ok := readSSEWithTimeout(t, sub2, 500*time.Millisecond)
	require.True(t, ok, "sub2 must receive event")
	assert.Equal(t, int64(1), got2.Seq)
}

// TestSSEHub_ReplaysFromSinceSeq: replay is strictly greater-than sinceSeq.
func TestSSEHub_ReplaysFromSinceSeq(t *testing.T) {
	hub := chat.NewSSEHub(128)
	hub.Publish("S1", chat.SSEEvent{Seq: 1, Content: "a"})
	hub.Publish("S1", chat.SSEEvent{Seq: 2, Content: "b"})
	hub.Publish("S1", chat.SSEEvent{Seq: 3, Content: "c"})

	sub, replay, _ := hub.Subscribe("S1", 1)

	t.Cleanup(func() { hub.Unsubscribe("S1", sub) })
	require.Len(t, replay, 2, "replay should include seq 2, 3 (since_seq=1)")
	assert.Equal(t, int64(2), replay[0].Seq)
	assert.Equal(t, int64(3), replay[1].Seq)
}

// TestSSEHub_RingBufferEvicts: oldest events fall out once cap is exceeded.
func TestSSEHub_RingBufferEvicts(t *testing.T) {
	hub := chat.NewSSEHub(2)
	hub.Publish("S1", chat.SSEEvent{Seq: 1})
	hub.Publish("S1", chat.SSEEvent{Seq: 2})
	hub.Publish("S1", chat.SSEEvent{Seq: 3})

	_, replay, _ :=
hub.Subscribe("S1", 0)
	require.Len(t, replay, 2)
	assert.Equal(t, int64(2), replay[0].Seq)
	assert.Equal(t, int64(3), replay[1].Seq)
}

// TestSSEHub_PerSessionIsolation: events published to one session must not
// reach another session's subscribers.
func TestSSEHub_PerSessionIsolation(t *testing.T) {
	hub := chat.NewSSEHub(128)
	subA, _, _ := hub.Subscribe("A", 0)
	subB, _, _ := hub.Subscribe("B", 0)

	t.Cleanup(func() { hub.Unsubscribe("A", subA); hub.Unsubscribe("B", subB) })

	hub.Publish("A", chat.SSEEvent{Seq: 1, Content: "from-A"})

	_, ok := readSSEWithTimeout(t, subA, 200*time.Millisecond)
	assert.True(t, ok, "A subscriber must receive A's events")
	_, ok2 := readSSEWithTimeout(t, subB, 100*time.Millisecond)
	assert.False(t, ok2, "B subscriber must not receive A's events")
}

// TestSSEHub_SubscriberCap: the 33rd subscriber on one session is rejected
// (maxSubscribersPerSession = 32).
func TestSSEHub_SubscriberCap(t *testing.T) {
	t.Parallel()

	hub := chat.NewSSEHub(128)

	var chs []<-chan chat.SSEEvent

	for i := range 32 {
		ch, _, err := hub.Subscribe("cap-session", 0)
		require.NoError(t, err, "subscriber %d should succeed", i)

		chs = append(chs, ch)
	}

	t.Cleanup(func() {
		for _, ch := range chs {
			hub.Unsubscribe("cap-session", ch)
		}
	})

	_, _, err := hub.Subscribe("cap-session", 0)
	require.Error(t, err, "33rd subscriber must be rejected")
	require.Contains(t, err.Error(), "subscriber cap")
}

// readSSEWithTimeout receives one event from ch, or reports false after d.
func readSSEWithTimeout(t *testing.T, ch <-chan chat.SSEEvent, d time.Duration) (chat.SSEEvent, bool) {
	t.Helper()

	select {
	case e := <-ch:
		return e, true
	case <-time.After(d):
		return chat.SSEEvent{}, false
	}
}

func TestSSEHub_LastUnsubscribe_FiresCallback(t *testing.T) {
	hub := chat.NewSSEHub(64)

	var fired bool

	hub.OnLastUnsubscribe = func(sessionID string) {
		if sessionID == "S1" {
			fired = true
		}
	}
	sub, _, _ := hub.Subscribe("S1", 0)
	hub.Unsubscribe("S1", sub)
	assert.True(t, fired, "callback should fire on last unsubscribe")
}

func TestSSEHub_AdditionalSubsDontFire(t *testing.T) {
	hub := chat.NewSSEHub(64)

	var fires int

	hub.OnLastUnsubscribe = func(string) { fires++ }
	sub1, _, _ := hub.Subscribe("S1", 0)
	sub2, _, _ := hub.Subscribe("S1", 0)
	hub.Unsubscribe("S1", sub1)
	assert.Equal(t, 0, fires, "callback fires only on last subscriber leaving")
	hub.Unsubscribe("S1", sub2)
	assert.Equal(t, 1, fires)
}

func TestSSEHub_OnSubscribe_Fires(t *testing.T) {
	hub := chat.NewSSEHub(64)

	var got string

	hub.OnSubscribe = func(sessionID string) { got = sessionID }
	sub, _, _ := hub.Subscribe("S7", 0)

	t.Cleanup(func() { hub.Unsubscribe("S7", sub) })
	assert.Equal(t, "S7", got)
}
diff --git a/internal/chat/store.go b/internal/chat/store.go
new file mode 100644
index 00000000..e8907f7b
--- /dev/null
+++ b/internal/chat/store.go
@@ -0,0 +1,64 @@
package chat

import (
	"context"
	"errors"
	"time"
)

// ErrSessionNotFound is returned when a session ID has no row.
var ErrSessionNotFound = errors.New("chat: session not found")

// Store persists chat sessions and messages. Implementations must be
// safe for concurrent use.
type Store interface {
	CreateSession(ctx context.Context, s Session) error
	GetSession(ctx context.Context, id string) (Session, error)
	ListSessions(ctx context.Context, filter SessionFilter) ([]Session, error)
	UpdateSession(ctx context.Context, s Session) error
	DeleteSession(ctx context.Context, id string) error

	// SetRehydrationActive flips the rehydration_active flag on a session
	// row without rewriting the rest of the columns. Returns
	// ErrSessionNotFound if no row matches.
	SetRehydrationActive(ctx context.Context, sessionID string, active bool) error

	// UpdateContextTokens stamps the context-window usage from the most
	// recent Claude turn onto the session row. updatedAt is the runner-side
	// timestamp of the usage event. Returns ErrSessionNotFound if no row
	// matches.
	UpdateContextTokens(ctx context.Context, sessionID string, tokens int64, updatedAt time.Time) error

	// AppendMessage inserts one transcript row; the caller assigns Seq and
	// the returned seq echoes it on success.
	AppendMessage(ctx context.Context, m Message) (int64, error)
	// ListMessages returns up to limit messages with seq strictly greater
	// than sinceSeq, in ascending seq order.
	ListMessages(ctx context.Context, sessionID string, sinceSeq int64, limit int) ([]Message, error)

	// ListMessagesTail returns the newest limit messages for sessionID in
	// chronological (ASC) order. Used by buildResume so rehydration payloads
	// reflect recent context rather than oldest. limit <= 0 returns nil.
	ListMessagesTail(ctx context.Context, sessionID string, limit int) ([]Message, error)

	// MaxSeq returns the largest seq for a session, or 0 if no messages exist.
	// Used by the Manager to seed monotonic seq assignment after restart
	// without scanning the full transcript.
	MaxSeq(ctx context.Context, sessionID string) (int64, error)

	Close() error
}

// SessionFilter narrows ListSessions.
type SessionFilter struct {
	Project   string
	Status    Status
	CreatedBy string
	// RehydrationActive, when non-nil, restricts results to rows where
	// rehydration_active matches the pointed value. Used by the reaper
	// to find sessions whose rehydration phase needs forcing off.
	RehydrationActive *bool
	// LastActiveBefore, when non-zero, restricts results to rows where
	// last_active is strictly older than this time. Used by the reaper
	// alongside RehydrationActive to find stale phases.
	LastActiveBefore time.Time
	// Limit, when > 0, caps the number of rows returned (ORDER BY
	// last_active DESC). Zero means no limit.
	Limit int
}
diff --git a/internal/chat/transcript/transcript.go b/internal/chat/transcript/transcript.go
new file mode 100644
index 00000000..d6173374
--- /dev/null
+++ b/internal/chat/transcript/transcript.go
@@ -0,0 +1,325 @@
// Package transcript builds the rehydration payload sent to the runner on a
// cold-reopen. It filters role-typed messages from SQLite into the bounded
// shape Claude can ingest from /run/cm-chat/resume.jsonl: drop noise (thinking,
// stderr, system, prior rehydration turns), summarise tool_result bodies, and
// truncate to fit a configurable token budget while always preserving the
// first user turn (the original goal) and the last K turns (recent context).
//
// The package is intentionally free of dependencies on the surrounding chat
// package — Build operates on its own Message type so callers can convert
// once and unit-test the filtering rules in isolation.
package transcript

import "strings"

const (
	// DefaultBudgetTokens is the fallback when BuildOpts.BudgetTokens is zero.
	DefaultBudgetTokens = 40000

	// MaxTurns is the absolute upper bound on turns in a single resume.
	// Acts as a hard cap so a runaway transcript never produces an
	// unbounded payload even when the budget would allow more.
	MaxTurns = 500

	// MaxContentBytes caps each ResumeTurn.Content. The transcript pkg
	// is the last line of defence; manager.go already caps persisted
	// content at 32KB, but a 32KB tool_call line is still wasteful for
	// rehydration purposes.
	MaxContentBytes = 4 * 1024

	// AlwaysKeepLastK ensures the most recent turns survive truncation
	// so the agent always sees how the conversation actually ended.
	AlwaysKeepLastK = 20

	// truncationMarker is appended to ResumeTurn.Content when it exceeds
	// MaxContentBytes. Suffix is part of the cap.
	truncationMarker = " … [truncated]"
)

// Role string constants used by Message inputs and ResumeTurn outputs.
const (
	RoleUser              = "user"
	RoleAssistantText     = "assistant_text"
	RoleAssistantThinking = "assistant_thinking"
	RoleToolCall          = "tool_call"
	RoleToolResult        = "tool_result"
	RoleToolResultSummary = "tool_result_summary"
	RoleStderr            = "stderr"
	RoleSystem            = "system"
)

// Message is one persisted transcript entry in the input to Build. It
// mirrors the load-bearing fields of chat.Message; callers convert their
// type into this one before invoking Build.
type Message struct {
	Seq              int64
	Role             string
	Content          string
	RehydrationPhase bool
}

// ResumeContext is the rehydration payload CM passes to the runner on a
// cold-open. The runner writes it to /run/cm-chat/resume.jsonl inside the
// container; the entrypoint instructs Claude to read it before greeting
// the operator.
type ResumeContext struct {
	Turns   []ResumeTurn `json:"turns"`
	Clipped bool         `json:"clipped"`
	OrigSeq int64        `json:"original_seq"`
}

// ResumeTurn is one filtered, possibly summarized transcript entry in the
// rehydration payload. Roles: "user", "assistant_text", "tool_call",
// "tool_result_summary" (tool_result bodies are collapsed to a one-liner
// outcome by the transcript builder).
type ResumeTurn struct {
	Seq     int64  `json:"seq"`
	Role    string `json:"role"`
	Content string `json:"content"`
}

// BuildOpts carries the knobs the manager passes from config.
type BuildOpts struct {
	// BudgetTokens caps the rough token-count estimate of the produced
	// payload. Zero means use DefaultBudgetTokens.
	BudgetTokens int
}

// Build assembles a ResumeContext from a chronological transcript slice.
// Returns nil when there is nothing worth resuming (empty input or every
// message filtered out). Caller is expected to treat nil as "skip the
// rehydration path; start a fresh agent.".
+func Build(msgs []Message, opts BuildOpts) *ResumeContext { + if len(msgs) == 0 { + return nil + } + + budget := opts.BudgetTokens + if budget <= 0 { + budget = DefaultBudgetTokens + } + + origSeq := msgs[len(msgs)-1].Seq + + turns := make([]ResumeTurn, 0, len(msgs)) + + for _, m := range msgs { + turn, ok := filterMessage(m) + if !ok { + continue + } + + turns = append(turns, turn) + } + + if len(turns) == 0 { + return nil + } + + clipped := false + + turns, hardClipped := applyHardTurnCap(turns) + if hardClipped { + clipped = true + } + + turns, budgetClipped := applyBudget(turns, budget) + if budgetClipped { + clipped = true + } + + return &ResumeContext{ + Turns: turns, + Clipped: clipped, + OrigSeq: origSeq, + } +} + +// filterMessage maps one persisted Message to a ResumeTurn, or skips it. +// Roles dropped: assistant_thinking, stderr, system. Rehydration-phase +// messages are always dropped (anti-pollution on the 2nd+ reopen). +func filterMessage(m Message) (ResumeTurn, bool) { + if m.RehydrationPhase { + return ResumeTurn{}, false + } + + switch m.Role { + case RoleUser, RoleAssistantText, RoleToolCall: + return ResumeTurn{ + Seq: m.Seq, + Role: m.Role, + Content: capContent(m.Content), + }, true + + case RoleToolResult: + return ResumeTurn{ + Seq: m.Seq, + Role: RoleToolResultSummary, + Content: summarizeToolResult(m.Content), + }, true + } + + return ResumeTurn{}, false +} + +// applyHardTurnCap enforces MaxTurns by preserving the first user turn (if +// any) and the most recent MaxTurns-1 turns. +func applyHardTurnCap(turns []ResumeTurn) ([]ResumeTurn, bool) { + if len(turns) <= MaxTurns { + return turns, false + } + + firstUserIdx := indexOfFirstUser(turns) + tailStart := len(turns) - MaxTurns + // Reserve one slot at the front for the first user turn so we can + // always include it. 
+ if firstUserIdx >= 0 && firstUserIdx < tailStart { + tailStart = len(turns) - (MaxTurns - 1) + + return append([]ResumeTurn{turns[firstUserIdx]}, turns[tailStart:]...), true + } + + // No user turn in the early section, or the first user turn is + // already inside the kept tail — just take the tail. + return turns[tailStart:], true +} + +// applyBudget drops oldest "middle" turns until the rough token estimate +// fits within budget. The first user turn and the last AlwaysKeepLastK +// turns are pinned. If the pinned set alone exceeds budget, we accept +// that — never refuse to build. +func applyBudget(turns []ResumeTurn, budget int) ([]ResumeTurn, bool) { + if budget <= 0 { + return turns, false + } + + total := 0 + for _, t := range turns { + total += estimateTokens(t.Content) + } + + if total <= budget { + return turns, false + } + + firstUserIdx := indexOfFirstUser(turns) + + keepLastFrom := max(len(turns)-AlwaysKeepLastK, 0) + + dropped := make(map[int]bool) + + for i := range turns { + if total <= budget { + break + } + + if i == firstUserIdx { + continue + } + + if i >= keepLastFrom { + break // we've hit the always-kept tail; stop dropping. + } + + dropped[i] = true + total -= estimateTokens(turns[i].Content) + } + + if len(dropped) == 0 { + return turns, false + } + + out := make([]ResumeTurn, 0, len(turns)-len(dropped)) + + for i, t := range turns { + if dropped[i] { + continue + } + + out = append(out, t) + } + + return out, true +} + +// indexOfFirstUser returns the position of the first "user" role turn, or -1 +// if there is none. +func indexOfFirstUser(turns []ResumeTurn) int { + for i, t := range turns { + if t.Role == RoleUser { + return i + } + } + + return -1 +} + +// estimateTokens returns a rough token count for a content string. Anthropic +// tokens average ~4 bytes for English prose; we use ceil-divide so empty +// strings stay at zero but a single character counts as one token. 
+func estimateTokens(s string) int { + if s == "" { + return 0 + } + + return (len(s) + 3) / 4 +} + +// capContent enforces the per-turn MaxContentBytes cap, appending the +// truncation marker. Truncation respects UTF-8 rune boundaries so the +// marker is not glued onto a partial multi-byte sequence. +func capContent(s string) string { + if len(s) <= MaxContentBytes { + return s + } + + cut := max(MaxContentBytes-len(truncationMarker), 0) + + // Back up to a rune start. + for cut > 0 && (s[cut]&0xC0) == 0x80 { + cut-- + } + + return s[:cut] + truncationMarker +} + +// summarizeToolResult collapses a tool_result body into a one-liner. The +// agent does not need to re-see the original payload — it can re-Read / +// re-run the producing tool_call if the content matters. The summary +// preserves the load-bearing signal (success vs. failure) and the tail +// of any error text. +func summarizeToolResult(content string) string { + s := strings.TrimSpace(content) + if s == "" { + return "→ ok" + } + + if looksLikeError(s) { + tail := s + if len(tail) > 200 { + tail = tail[len(tail)-200:] + } + + return "→ failed: " + tail + } + + return "→ ok" +} + +// looksLikeError heuristically classifies a tool_result body as a failure. +// We err on the side of "ok" — a noisy success that happens to contain the +// word "error" is preferable to mis-labelling a clean success as failure. 
+func looksLikeError(s string) bool { + lower := strings.ToLower(s) + for _, needle := range []string{ + "error:", "fatal:", "exit code 1", "exit code 2", "exit code 3", + "exit status 1", "exit status 2", "exit status 3", + "permission denied", "not found", "no such file", + } { + if strings.Contains(lower, needle) { + return true + } + } + + return false +} diff --git a/internal/chat/transcript/transcript_test.go b/internal/chat/transcript/transcript_test.go new file mode 100644 index 00000000..93456b91 --- /dev/null +++ b/internal/chat/transcript/transcript_test.go @@ -0,0 +1,253 @@ +package transcript + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBuild_EmptyTranscriptReturnsNil(t *testing.T) { + got := Build(nil, BuildOpts{}) + assert.Nil(t, got, "no messages should produce no ResumeContext") + + got = Build([]Message{}, BuildOpts{}) + assert.Nil(t, got, "empty slice should produce no ResumeContext") +} + +func TestBuild_FiltersByRole(t *testing.T) { + in := []Message{ + {Seq: 1, Role: RoleUser, Content: "hello"}, + {Seq: 2, Role: RoleAssistantThinking, Content: "thinking aloud"}, + {Seq: 3, Role: RoleAssistantText, Content: "hi back"}, + {Seq: 4, Role: RoleToolCall, Content: "Bash: ls"}, + {Seq: 5, Role: RoleToolResult, Content: "file1\nfile2\n"}, + {Seq: 6, Role: RoleStderr, Content: "container plumbing"}, + {Seq: 7, Role: RoleSystem, Content: "system boilerplate"}, + } + + got := Build(in, BuildOpts{}) + require.NotNil(t, got) + + roles := rolesOf(got.Turns) + assert.Equal(t, + []string{"user", "assistant_text", "tool_call", "tool_result_summary"}, + roles, + "only user/assistant_text/tool_call/tool_result_summary should survive") +} + +func TestBuild_ToolResultSummarized_OK(t *testing.T) { + in := []Message{ + {Seq: 1, Role: RoleUser, Content: "list files"}, + {Seq: 2, Role: RoleToolCall, Content: "Bash: ls"}, + {Seq: 3, Role: RoleToolResult, Content: strings.Repeat("a", 
5000)}, + } + + got := Build(in, BuildOpts{}) + require.NotNil(t, got) + require.Len(t, got.Turns, 3) + assert.Equal(t, RoleToolResultSummary, got.Turns[2].Role) + assert.Equal(t, "→ ok", got.Turns[2].Content) +} + +func TestBuild_ToolResultSummarized_Failed(t *testing.T) { + in := []Message{ + {Seq: 1, Role: RoleUser, Content: "do thing"}, + {Seq: 2, Role: RoleToolCall, Content: "Bash: gh repo clone foo/bar"}, + {Seq: 3, Role: RoleToolResult, Content: "fatal: error: repository not found"}, + } + + got := Build(in, BuildOpts{}) + require.NotNil(t, got) + require.Len(t, got.Turns, 3) + assert.True(t, strings.HasPrefix(got.Turns[2].Content, "→ failed:"), + "tool_result with error indicator should start with '→ failed:'; got %q", got.Turns[2].Content) + assert.Contains(t, got.Turns[2].Content, "repository not found") +} + +func TestBuild_RehydrationPhaseExcluded(t *testing.T) { + in := []Message{ + {Seq: 1, Role: RoleUser, Content: "original turn"}, + {Seq: 2, Role: RoleAssistantText, Content: "rehydration narration", RehydrationPhase: true}, + {Seq: 3, Role: RoleToolCall, Content: "Bash: git clone", RehydrationPhase: true}, + {Seq: 4, Role: RoleToolResult, Content: "ok", RehydrationPhase: true}, + {Seq: 5, Role: RoleAssistantText, Content: "real reply"}, + {Seq: 6, Role: RoleUser, Content: "follow up"}, + } + + got := Build(in, BuildOpts{}) + require.NotNil(t, got) + + seqs := seqsOf(got.Turns) + assert.Equal(t, []int64{1, 5, 6}, seqs, + "rehydration_phase=TRUE messages must be excluded from the resume payload") +} + +func TestBuild_AlwaysKeepsFirstUserAndLastK(t *testing.T) { + in := make([]Message, 0, 50) + in = append(in, Message{Seq: 1, Role: RoleUser, Content: "the original goal"}) + + for seq := int64(2); seq <= 50; seq++ { + role := RoleAssistantText + if seq%2 == 0 { + role = RoleUser + } + + in = append(in, Message{Seq: seq, Role: role, Content: strings.Repeat("x", 1000)}) + } + + got := Build(in, BuildOpts{BudgetTokens: 6000}) + require.NotNil(t, got) + 
require.True(t, got.Clipped, "should mark Clipped when truncated") + + seqs := seqsOf(got.Turns) + assert.Equal(t, int64(1), seqs[0], "first user turn must be preserved at position 0") + + lastTwenty := seqs[len(seqs)-20:] + + expectedLast := make([]int64, 0, 20) + for s := int64(31); s <= 50; s++ { + expectedLast = append(expectedLast, s) + } + + assert.Equal(t, expectedLast, lastTwenty, "last 20 turns must be preserved") +} + +func TestBuild_TruncatesOverBudget(t *testing.T) { + in := make([]Message, 0, 600) + + for seq := int64(1); seq <= 600; seq++ { + role := RoleAssistantText + if seq == 1 { + role = RoleUser + } + + in = append(in, Message{Seq: seq, Role: role, Content: strings.Repeat("y", 400)}) + } + + got := Build(in, BuildOpts{BudgetTokens: 40000}) + require.NotNil(t, got) + require.True(t, got.Clipped) + + totalTokens := 0 + for _, turn := range got.Turns { + totalTokens += estimateTokens(turn.Content) + } + + assert.LessOrEqual(t, totalTokens, 40000, "kept turns must fit within the budget") + + require.NotEmpty(t, got.Turns) + assert.Equal(t, int64(1), got.Turns[0].Seq) + assert.Equal(t, int64(600), got.Turns[len(got.Turns)-1].Seq) +} + +func TestBuild_HardTurnCap(t *testing.T) { + in := make([]Message, 0, 700) + in = append(in, Message{Seq: 1, Role: RoleUser, Content: "first goal"}) + + for seq := int64(2); seq <= 700; seq++ { + in = append(in, Message{Seq: seq, Role: RoleAssistantText, Content: "tiny"}) + } + + got := Build(in, BuildOpts{BudgetTokens: 10_000_000}) + require.NotNil(t, got) + require.True(t, got.Clipped, "should mark Clipped at hard turn cap") + assert.LessOrEqual(t, len(got.Turns), MaxTurns, "must respect MaxTurns hard cap") + assert.Equal(t, int64(1), got.Turns[0].Seq, "first user turn must be preserved at the hard cap") +} + +func TestBuild_HardContentSizeCap(t *testing.T) { + huge := strings.Repeat("z", MaxContentBytes*2) + + in := []Message{ + {Seq: 1, Role: RoleUser, Content: "ok"}, + {Seq: 2, Role: RoleAssistantText, Content: 
huge}, + } + + got := Build(in, BuildOpts{}) + require.NotNil(t, got) + require.Len(t, got.Turns, 2) + assert.LessOrEqual(t, len(got.Turns[1].Content), MaxContentBytes, + "per-content hard cap must be enforced") + assert.Contains(t, got.Turns[1].Content, truncationMarker) +} + +func TestBuild_OrigSeqIsLastInputSeq(t *testing.T) { + in := []Message{ + {Seq: 5, Role: RoleUser, Content: "a"}, + {Seq: 9, Role: RoleAssistantText, Content: "b"}, + {Seq: 42, Role: RoleUser, Content: "c"}, + } + + got := Build(in, BuildOpts{}) + require.NotNil(t, got) + assert.Equal(t, int64(42), got.OrigSeq, "OrigSeq must equal the max seq of the input") +} + +func TestBuild_AllRehydrationPhase_ReturnsNil(t *testing.T) { + msgs := []Message{ + {Seq: 1, Role: RoleAssistantThinking, Content: "x", RehydrationPhase: true}, + {Seq: 2, Role: RoleToolCall, Content: "y", RehydrationPhase: true}, + } + out := Build(msgs, BuildOpts{BudgetTokens: 1000}) + assert.Nil(t, out, "all-rehydration-phase messages should produce nil") +} + +func TestBuild_ExactlyAtBudget(t *testing.T) { + // Two small messages well within budget. Clipped should be false. + msgs := []Message{ + {Seq: 1, Role: RoleUser, Content: "hello"}, + {Seq: 2, Role: RoleAssistantText, Content: "hi"}, + } + out := Build(msgs, BuildOpts{BudgetTokens: 1_000_000}) + require.NotNil(t, out) + require.Len(t, out.Turns, 2) + assert.False(t, out.Clipped, "messages well within budget should not be clipped") +} + +func TestBuild_FirstUserAndLastKCollision(t *testing.T) { + // 5 messages: msg 1 is user; the K=20 last-tail trivially includes all 5 + // including msg 1. Verify msg 1 appears exactly once (no duplication from the + // first-user pin). 
+ msgs := []Message{ + {Seq: 1, Role: RoleUser, Content: "first user"}, + {Seq: 2, Role: RoleAssistantText, Content: "a"}, + {Seq: 3, Role: RoleUser, Content: "b"}, + {Seq: 4, Role: RoleAssistantText, Content: "c"}, + {Seq: 5, Role: RoleUser, Content: "d"}, + } + out := Build(msgs, BuildOpts{BudgetTokens: 1_000_000}) + require.NotNil(t, out) + require.Len(t, out.Turns, 5) + + // Verify Seq 1 appears exactly once. + firstUserCount := 0 + + for _, m := range out.Turns { + if m.Seq == 1 { + firstUserCount++ + } + } + + assert.Equal(t, 1, firstUserCount, + "first user message must appear exactly once (no duplication from pin)") +} + +func rolesOf(turns []ResumeTurn) []string { + out := make([]string, len(turns)) + for i, t := range turns { + out[i] = t.Role + } + + return out +} + +func seqsOf(turns []ResumeTurn) []int64 { + out := make([]int64, len(turns)) + for i, t := range turns { + out[i] = t.Seq + } + + return out +} diff --git a/internal/chat/types.go b/internal/chat/types.go new file mode 100644 index 00000000..07020a3c --- /dev/null +++ b/internal/chat/types.go @@ -0,0 +1,129 @@ +package chat + +import ( + "crypto/rand" + "encoding/base32" + "encoding/binary" + "time" + + "github.com/mhersson/contextmatrix/internal/chat/transcript" +) + +// Status is the lifecycle state of a chat session. +type Status string + +const ( + StatusCold Status = "cold" + StatusActive Status = "active" + StatusWarmIdle Status = "warm-idle" + StatusEnding Status = "ending" +) + +func (s Status) String() string { return string(s) } + +// ParseStatus reports whether s is a valid Status. +func ParseStatus(s string) (Status, bool) { + switch Status(s) { + case StatusCold, StatusActive, StatusWarmIdle, StatusEnding: + return Status(s), true + } + + return "", false +} + +// Role is the kind of message in a transcript. 
+type Role string + +const ( + RoleUser Role = "user" + RoleAssistantText Role = "assistant_text" + RoleAssistantThinking Role = "assistant_thinking" + RoleToolCall Role = "tool_call" + RoleToolResult Role = "tool_result" + RoleStderr Role = "stderr" + RoleSystem Role = "system" +) + +// Session is the persisted shape of a chat session row. +type Session struct { + ID string `json:"id"` + Title string `json:"title"` + Project string `json:"project,omitempty"` + Status Status `json:"status"` + CreatedAt time.Time `json:"created_at"` + LastActive time.Time `json:"last_active"` + CreatedBy string `json:"created_by"` + ContainerID string `json:"container_id,omitempty"` + Workspace []string `json:"workspace,omitempty"` + Model string `json:"model,omitempty"` + ContextTokens int64 `json:"context_tokens,omitempty"` + ContextTokensUpdatedAt time.Time `json:"context_tokens_updated_at"` + RehydrationActive bool `json:"rehydration_active,omitempty"` +} + +// Message is a single persisted transcript entry. +type Message struct { + ID int64 `json:"id"` + SessionID string `json:"session_id"` + Seq int64 `json:"seq"` + Role Role `json:"role"` + Content string `json:"content"` // JSON envelope, opaque to the store + CreatedAt time.Time `json:"created_at"` + RehydrationPhase bool `json:"rehydration_phase,omitempty"` +} + +// LogEntry is a parsed event from the runner's /logs SSE stream. The Type +// values mirror the runner's logbroadcast.LogEntry.Type vocabulary: "text", +// "thinking", "tool_call", "stderr", "system", "user", "usage". The chat +// package translates Type → Role when bridging into the transcript. "usage" +// entries are metadata (Claude stream-json usage block) and carry token +// counts in Usage; they do NOT become transcript entries. +type LogEntry struct { + Timestamp time.Time + Type string + Content string + Usage *TokenUsage + Model string +} + +// TokenUsage carries the per-turn context window accounting reported by +// Claude in its stream-json output. 
The sum of all four fields approximates +// the prompt size Claude actually processed; the UI typically displays +// InputTokens + CacheReadTokens + CacheCreateTokens as "context used.". +type TokenUsage struct { + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` + CacheReadTokens int64 `json:"cache_read_tokens"` + CacheCreateTokens int64 `json:"cache_creation_tokens"` +} + +// ResumeContext is the rehydration payload CM passes to the runner on a +// cold-open. The runner writes it to /run/cm-chat/resume.jsonl inside the +// container; the entrypoint instructs Claude to read it before greeting +// the operator. Defined in the transcript subpackage so its filtering +// logic stays free of import cycles; aliased here for the rest of the +// chat package and external callers. +type ResumeContext = transcript.ResumeContext + +// ResumeTurn is one filtered, possibly summarized transcript entry in the +// rehydration payload. +type ResumeTurn = transcript.ResumeTurn + +// NewID returns a 26-char ULID-shaped identifier (48-bit Unix-millis prefix +// + 80 random bits, encoded with the standard base32 alphabet). It is +// time-sortable but uses RFC4648 base32 (A-Z2-7), not Crockford. +func NewID() string { + var b [16]byte + + // Encode the 48-bit Unix millisecond timestamp into the leading 6 bytes + // in big-endian order. We build an 8-byte buffer and skip the top two + // bytes (which are unused for timestamps well past Y2K). 
+	var tsBuf [8]byte
+	binary.BigEndian.PutUint64(tsBuf[:], uint64(time.Now().UnixMilli()))
+	copy(b[0:6], tsBuf[2:])
+
+	_, _ = rand.Read(b[6:])
+	enc := base32.StdEncoding.WithPadding(base32.NoPadding)
+
+	return enc.EncodeToString(b[:])[:26]
+}
diff --git a/internal/chat/types_test.go b/internal/chat/types_test.go
new file mode 100644
index 00000000..b61e2259
--- /dev/null
+++ b/internal/chat/types_test.go
@@ -0,0 +1,41 @@
+package chat
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestStatus_String(t *testing.T) {
+	assert.Equal(t, "cold", StatusCold.String())
+	assert.Equal(t, "active", StatusActive.String())
+	assert.Equal(t, "warm-idle", StatusWarmIdle.String())
+	assert.Equal(t, "ending", StatusEnding.String())
+}
+
+func TestStatus_Parse(t *testing.T) {
+	cases := []struct {
+		in   string
+		want Status
+		ok   bool
+	}{
+		{"cold", StatusCold, true},
+		{"active", StatusActive, true},
+		{"warm-idle", StatusWarmIdle, true},
+		{"ending", StatusEnding, true},
+		{"garbage", "", false},
+	}
+	for _, tc := range cases {
+		got, ok := ParseStatus(tc.in)
+		assert.Equal(t, tc.ok, ok, "parse %q", tc.in)

+		if ok {
+			assert.Equal(t, tc.want, got)
+		}
+	}
+}
+
+func TestNewID_IsULID(t *testing.T) {
+	id := NewID()
+	assert.Len(t, id, 26, "ULID-shaped ID is 26 chars in RFC4648 base32")
+}
diff --git a/internal/config/config.go b/internal/config/config.go
index 79138a6c..0d1dc7a6 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -125,6 +125,52 @@ type TaskSkillsConfig struct {
 	GitRemoteURL string `yaml:"git_remote_url"`
 }
 
+// ChatConfig configures the global chat panel feature.
+type ChatConfig struct {
+	// DBPath is the SQLite file path for chat sessions and transcripts.
+	// Defaults to $XDG_STATE_HOME/contextmatrix/chats.db, falling back to
+	// ~/.local/state/contextmatrix/chats.db.
+	DBPath string `yaml:"db_path"`
+
+	// IdleTTL is how long a chat container survives after the browser
+	// disconnects. Default: 1h.
+	IdleTTL time.Duration `yaml:"idle_ttl"`
+
+	// MaxConcurrent caps the number of simultaneously-running chat
+	// containers. Default: 8.
+	MaxConcurrent int `yaml:"max_concurrent"`
+
+	// DefaultModel is the Claude model ID used when a chat is created
+	// without an explicit selection. Must be a key in Models. Default:
+	// "claude-sonnet-4-6".
+	DefaultModel string `yaml:"default_model"`
+
+	// Models is the allowlist of selectable models for new chats, keyed
+	// by model ID. The values carry the human label shown in the picker
+	// and the context-window denominator used by the UI usage indicator.
+	Models map[string]ChatModelConfig `yaml:"models"`
+
+	// ResumeBudgetTokens caps the rough token estimate the transcript
+	// builder will fit into the rehydration payload on cold-reopen.
+	// Default: 40000.
+	ResumeBudgetTokens int `yaml:"resume_budget_tokens"`
+
+	// RehydrationTimeout forces the per-session rehydration phase off
+	// after this duration, even if the agent never called
+	// chat_rehydration_complete and the user never typed. Default: 10m.
+	RehydrationTimeout time.Duration `yaml:"rehydration_timeout"`
+}
+
+// ChatModelConfig is one entry in ChatConfig.Models.
+type ChatModelConfig struct {
+	// Label is the human-readable name shown in the picker, e.g. "Sonnet 4.6".
+	Label string `yaml:"label"`
+
+	// MaxTokens is the context-window denominator used by the UI usage
+	// indicator. The picker also surfaces it (e.g. "(200k context)").
+	MaxTokens int64 `yaml:"max_tokens"`
+}
+
 // Config holds the application configuration.
 type Config struct {
 	Port int `yaml:"port"`
@@ -149,6 +195,7 @@ type Config struct {
 	LogLevel string `yaml:"log_level"` // "debug"/"info"/"warn"/"error", default "info"
 	AdminPort int `yaml:"admin_port"` // 0 = disabled
 	AdminBindAddr string `yaml:"admin_bind_addr"` // listen address for admin server (pprof + /metrics); default "127.0.0.1"
+	Chat ChatConfig `yaml:"chat"`
 }
 
 // defaults returns a Config with default values.
@@ -337,6 +384,48 @@ func (c *Config) Validate() error { c.AdminBindAddr = "127.0.0.1" } + // Chat: applyChatDefaults turns zero values into safe defaults during + // Load. Run it again here so callers that bypass Load (tests, embedded + // uses) still get defaults applied; the function is idempotent. Then + // reject negatives — a negative IdleTTL would have the reaper end every + // session immediately and a negative MaxConcurrent would reject every + // open. + applyChatDefaults(c) + + if c.Chat.IdleTTL <= 0 { + return fmt.Errorf("chat.idle_ttl must be positive (got %s)", c.Chat.IdleTTL) + } + + if c.Chat.MaxConcurrent < 0 { + return fmt.Errorf("chat.max_concurrent must be >= 0 (got %d)", c.Chat.MaxConcurrent) + } + + if c.Chat.ResumeBudgetTokens < 0 { + return fmt.Errorf("chat.resume_budget_tokens must be >= 0 (got %d)", c.Chat.ResumeBudgetTokens) + } + + if c.Chat.RehydrationTimeout <= 0 { + return fmt.Errorf("chat.rehydration_timeout must be positive (got %s)", c.Chat.RehydrationTimeout) + } + + if c.Chat.DefaultModel == "" { + return fmt.Errorf("chat.default_model is required") + } + + if _, ok := c.Chat.Models[c.Chat.DefaultModel]; !ok { + return fmt.Errorf("chat.default_model %q is not in chat.models", c.Chat.DefaultModel) + } + + for id, m := range c.Chat.Models { + if m.Label == "" { + return fmt.Errorf("chat.models[%q].label is required", id) + } + + if m.MaxTokens <= 0 { + return fmt.Errorf("chat.models[%q].max_tokens must be positive (got %d)", id, m.MaxTokens) + } + } + return nil } @@ -374,6 +463,7 @@ func Load(path string) (*Config, error) { data, err := os.ReadFile(path) if err != nil { if os.IsNotExist(err) { + applyChatDefaults(cfg) applyEnvOverrides(cfg) if err := resolvePaths(cfg, path); err != nil { @@ -394,6 +484,7 @@ func Load(path string) (*Config, error) { return nil, fmt.Errorf("parse config: %w", err) } + applyChatDefaults(cfg) applyEnvOverrides(cfg) if err := resolvePaths(cfg, path); err != nil { @@ -441,6 +532,47 @@ func 
resolvePaths(cfg *Config, configPath string) error { return nil } +// applyChatDefaults sets Chat fields that were not supplied by YAML. +func applyChatDefaults(cfg *Config) { + if cfg.Chat.IdleTTL == 0 { + cfg.Chat.IdleTTL = time.Hour + } + + if cfg.Chat.MaxConcurrent == 0 { + cfg.Chat.MaxConcurrent = 8 + } + + if cfg.Chat.DBPath == "" { + state := os.Getenv("XDG_STATE_HOME") + if state == "" { + home, _ := os.UserHomeDir() + state = filepath.Join(home, ".local", "state") + } + + cfg.Chat.DBPath = filepath.Join(state, "contextmatrix", "chats.db") + } + + if cfg.Chat.ResumeBudgetTokens == 0 { + cfg.Chat.ResumeBudgetTokens = 40000 + } + + if cfg.Chat.RehydrationTimeout == 0 { + cfg.Chat.RehydrationTimeout = 10 * time.Minute + } + + if len(cfg.Chat.Models) == 0 { + cfg.Chat.Models = map[string]ChatModelConfig{ + "claude-sonnet-4-6": {Label: "Sonnet 4.6", MaxTokens: 1000000}, + "claude-opus-4-7": {Label: "Opus 4.7", MaxTokens: 1000000}, + "claude-haiku-4-5-20251001": {Label: "Haiku 4.5", MaxTokens: 200000}, + } + } + + if cfg.Chat.DefaultModel == "" { + cfg.Chat.DefaultModel = "claude-sonnet-4-6" + } +} + // applyEnvOverrides applies environment variable overrides to the config. 
func applyEnvOverrides(cfg *Config) { if v := os.Getenv("CONTEXTMATRIX_PORT"); v != "" { @@ -602,6 +734,26 @@ func applyEnvOverrides(cfg *Config) { if v := os.Getenv("CONTEXTMATRIX_ADMIN_BIND_ADDR"); v != "" { cfg.AdminBindAddr = v } + + if v := os.Getenv("CONTEXTMATRIX_CHAT_DB_PATH"); v != "" { + cfg.Chat.DBPath = v + } + + if v := os.Getenv("CONTEXTMATRIX_CHAT_IDLE_TTL"); v != "" { + if d, err := time.ParseDuration(v); err == nil { + cfg.Chat.IdleTTL = d + } else { + slog.Warn("ignoring invalid CONTEXTMATRIX_CHAT_IDLE_TTL", "value", v, "error", err) + } + } + + if v := os.Getenv("CONTEXTMATRIX_CHAT_MAX_CONCURRENT"); v != "" { + if n, err := strconv.Atoi(v); err == nil { + cfg.Chat.MaxConcurrent = n + } else { + slog.Warn("ignoring invalid CONTEXTMATRIX_CHAT_MAX_CONCURRENT", "value", v, "error", err) + } + } } // HeartbeatDuration parses HeartbeatTimeout as a time.Duration. diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 6348696a..9cae7880 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -422,6 +422,55 @@ func TestValidate_InvalidHeartbeatTimeout(t *testing.T) { } } +func TestValidate_RejectsNegativeChatIdleTTL(t *testing.T) { + cfg := &Config{ + Boards: BoardsConfig{Dir: "/some/path"}, + HeartbeatTimeout: "30m", + GitHub: GitHubConfig{AuthMode: "pat", PAT: GitHubPATConfig{Token: "x"}}, + Chat: ChatConfig{IdleTTL: -time.Minute, MaxConcurrent: 5}, + } + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "chat.idle_ttl") +} + +func TestValidate_AcceptsZeroChatIdleTTL(t *testing.T) { + // Zero IdleTTL means "use the default" — applyChatDefaults bumps it + // inside Validate so callers that bypass Load still get the default. 
+ cfg := &Config{ + Boards: BoardsConfig{Dir: "/some/path"}, + HeartbeatTimeout: "30m", + GitHub: GitHubConfig{AuthMode: "pat", PAT: GitHubPATConfig{Token: "x"}}, + Chat: ChatConfig{IdleTTL: 0, MaxConcurrent: 5}, + } + require.NoError(t, cfg.Validate()) + assert.Equal(t, time.Hour, cfg.Chat.IdleTTL, "Validate must apply the default IdleTTL") +} + +func TestValidate_RejectsNegativeChatMaxConcurrent(t *testing.T) { + cfg := &Config{ + Boards: BoardsConfig{Dir: "/some/path"}, + HeartbeatTimeout: "30m", + GitHub: GitHubConfig{AuthMode: "pat", PAT: GitHubPATConfig{Token: "x"}}, + Chat: ChatConfig{IdleTTL: time.Hour, MaxConcurrent: -1}, + } + err := cfg.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "chat.max_concurrent") +} + +func TestValidate_AcceptsZeroChatMaxConcurrent(t *testing.T) { + // MaxConcurrent=0 means "unlimited" per the existing applyChatDefaults + // semantics; only negative values are rejected. + cfg := &Config{ + Boards: BoardsConfig{Dir: "/some/path"}, + HeartbeatTimeout: "30m", + GitHub: GitHubConfig{AuthMode: "pat", PAT: GitHubPATConfig{Token: "x"}}, + Chat: ChatConfig{IdleTTL: time.Hour, MaxConcurrent: 0}, + } + assert.NoError(t, cfg.Validate()) +} + func TestValidate_ValidConfig(t *testing.T) { cfg := &Config{ Boards: BoardsConfig{Dir: "/some/path"}, @@ -2150,3 +2199,107 @@ github: {auth_mode: "pat", pat: {token: "x"}} assert.Equal(t, "https://github.com/x/y.git", cfg.TaskSkills.GitRemoteURL) assert.True(t, cfg.TaskSkills.GitCloneOnEmpty) } + +// ---------- Chat config tests ---------- + +func TestLoadConfig_ChatDefaults(t *testing.T) { + t.Helper() + dir := t.TempDir() + cfgPath := filepath.Join(dir, "config.yaml") + require.NoError(t, os.WriteFile(cfgPath, []byte("boards:\n dir: /tmp/boards\ngithub:\n auth_mode: \"pat\"\n pat:\n token: \"ghp_test\"\n"), 0o644)) + + cfg, err := Load(cfgPath) + require.NoError(t, err) + assert.Equal(t, time.Hour, cfg.Chat.IdleTTL, "default idle TTL should be 1h") + assert.Equal(t, 8, 
cfg.Chat.MaxConcurrent, "default max concurrent should be 8 (multi-pane headroom)") + assert.NotEmpty(t, cfg.Chat.DBPath, "default db path should be derived") +} + +func TestLoadConfig_ChatEnvOverrides(t *testing.T) { + dir := t.TempDir() + boardsDir := t.TempDir() + path := writeConfigFile(t, dir, ` +boards: {dir: `+boardsDir+`} +github: {auth_mode: "pat", pat: {token: "x"}} +`) + + t.Setenv("CONTEXTMATRIX_CHAT_DB_PATH", "/var/lib/contextmatrix/chats.db") + t.Setenv("CONTEXTMATRIX_CHAT_IDLE_TTL", "30m") + t.Setenv("CONTEXTMATRIX_CHAT_MAX_CONCURRENT", "10") + + cfg, err := Load(path) + require.NoError(t, err) + + assert.Equal(t, "/var/lib/contextmatrix/chats.db", cfg.Chat.DBPath) + assert.Equal(t, 30*time.Minute, cfg.Chat.IdleTTL) + assert.Equal(t, 10, cfg.Chat.MaxConcurrent) +} + +func TestLoadConfig_ChatInvalidIdleTTL_Ignored(t *testing.T) { + dir := t.TempDir() + boardsDir := t.TempDir() + path := writeConfigFile(t, dir, ` +boards: {dir: `+boardsDir+`} +github: {auth_mode: "pat", pat: {token: "x"}} +`) + + t.Setenv("CONTEXTMATRIX_CHAT_IDLE_TTL", "notaduration") + + cfg, err := Load(path) + require.NoError(t, err) + // Should retain the default value since the env override was invalid. + assert.Equal(t, time.Hour, cfg.Chat.IdleTTL) +} + +func TestLoadConfig_ChatInvalidMaxConcurrent_Ignored(t *testing.T) { + dir := t.TempDir() + boardsDir := t.TempDir() + path := writeConfigFile(t, dir, ` +boards: {dir: `+boardsDir+`} +github: {auth_mode: "pat", pat: {token: "x"}} +`) + + t.Setenv("CONTEXTMATRIX_CHAT_MAX_CONCURRENT", "abc") + + cfg, err := Load(path) + require.NoError(t, err) + // Should retain the default value since the env override was invalid. 
+ assert.Equal(t, 8, cfg.Chat.MaxConcurrent) +} + +func TestLoadConfig_ChatYAML(t *testing.T) { + dir := t.TempDir() + boardsDir := t.TempDir() + path := writeConfigFile(t, dir, ` +boards: {dir: `+boardsDir+`} +github: {auth_mode: "pat", pat: {token: "x"}} +chat: + db_path: /custom/chats.db + idle_ttl: 2h + max_concurrent: 3 +`) + + cfg, err := Load(path) + require.NoError(t, err) + + assert.Equal(t, "/custom/chats.db", cfg.Chat.DBPath) + assert.Equal(t, 2*time.Hour, cfg.Chat.IdleTTL) + assert.Equal(t, 3, cfg.Chat.MaxConcurrent) +} + +func TestLoadConfig_ChatDBPath_XDGStateHome(t *testing.T) { + dir := t.TempDir() + boardsDir := t.TempDir() + path := writeConfigFile(t, dir, ` +boards: {dir: `+boardsDir+`} +github: {auth_mode: "pat", pat: {token: "x"}} +`) + + stateDir := filepath.Join(dir, "state") + t.Setenv("XDG_STATE_HOME", stateDir) + + cfg, err := Load(path) + require.NoError(t, err) + + assert.Equal(t, filepath.Join(stateDir, "contextmatrix", "chats.db"), cfg.Chat.DBPath) +} From 53d235cf77859c644847236b88f7e083e6131907 Mon Sep 17 00:00:00 2001 From: Morten Hersson Date: Fri, 15 May 2026 06:49:27 +0200 Subject: [PATCH 2/6] feat(api): chat REST endpoints and MCP rehydration tool - /api/chats CRUD + messages, SSE stream, end/reopen/delete - chat_rehydration_complete MCP tool gated on X-CM-Chat-Session - Runner reconciliation on startup syncs tracker against active sessions - Chat live-data fan-out via in-process bus to SSE consumers --- cmd/contextmatrix/main.go | 214 ++++++++++- internal/api/chats.go | 469 ++++++++++++++++++++++++ internal/api/chats_test.go | 494 ++++++++++++++++++++++++++ internal/api/router.go | 25 +- internal/mcp/chat_rehydration.go | 73 ++++ internal/mcp/chat_rehydration_test.go | 149 ++++++++ internal/mcp/mcpcontext/mcpcontext.go | 24 ++ internal/mcp/server.go | 42 ++- internal/mcp/server_test.go | 2 +- internal/runner/client.go | 3 + internal/runner/endsession_test.go | 14 + internal/runner/reconcile.go | 142 +++++++- 
internal/runner/reconcile_test.go | 229 +++++++++++- test/integration/client_test.go | 54 +++ test/integration/scenarios_test.go | 88 +++++ 15 files changed, 1983 insertions(+), 39 deletions(-) create mode 100644 internal/api/chats.go create mode 100644 internal/api/chats_test.go create mode 100644 internal/mcp/chat_rehydration.go create mode 100644 internal/mcp/chat_rehydration_test.go create mode 100644 internal/mcp/mcpcontext/mcpcontext.go diff --git a/cmd/contextmatrix/main.go b/cmd/contextmatrix/main.go index f07cf2ec..e0cf2f9a 100644 --- a/cmd/contextmatrix/main.go +++ b/cmd/contextmatrix/main.go @@ -25,6 +25,8 @@ import ( githubauth "github.com/mhersson/contextmatrix-githubauth" "github.com/mhersson/contextmatrix/internal/api" + "github.com/mhersson/contextmatrix/internal/chat" + chatsqlite "github.com/mhersson/contextmatrix/internal/chat/sqlite" "github.com/mhersson/contextmatrix/internal/clock" "github.com/mhersson/contextmatrix/internal/config" "github.com/mhersson/contextmatrix/internal/events" @@ -282,15 +284,142 @@ func main() { runner.StartEndSessionSubscriber(ctx, bus, svc, runnerClient, slog.Default()) slog.Info("end-session subscriber started") + } + + // Chat: SQLite store + manager + SSE hub + idle reaper + warm-idle grace timer. + chatStore, err := chatsqlite.Open(cfg.Chat.DBPath) + if err != nil { + slog.Error("failed to open chat store", "path", cfg.Chat.DBPath, "error", err) + cancel() + os.Exit(1) //nolint:gocritic // cancel called explicitly above + } + defer chatStore.Close() + + slog.Info("chat store opened", "path", cfg.Chat.DBPath) + + var chatRunner chat.RunnerClient + if cfg.Runner.Enabled { + chatRunner = chat.NewRunnerClient(chat.RunnerClientConfig{ + BaseURL: cfg.Runner.URL, + HMACKey: cfg.Runner.APIKey, + MCPAPIKey: cfg.MCPAPIKey, + }) + } else { + // Nil runner causes nil-pointer panics at call sites. Use a no-op stub + // that returns an error on every operation — chat features require runner. 
+ chatRunner = chatRunnerDisabled{} + } + + chatHub := chat.NewSSEHub(128) + + chatMgr := chat.NewManager(chat.Config{ + Store: chatStore, + Runner: chatRunner, + Clock: clock.Real(), + IdleTTL: cfg.Chat.IdleTTL, + MaxConcurrent: cfg.Chat.MaxConcurrent, + Hub: chatHub, + ResumeBudgetTokens: cfg.Chat.ResumeBudgetTokens, + RehydrationTimeout: cfg.Chat.RehydrationTimeout, + DefaultModel: cfg.Chat.DefaultModel, + ResolveRepoURL: func(rctx context.Context, project string) (string, error) { + p, err := svc.GetProject(rctx, project) + if err != nil { + return "", err + } + + if p.Repo != "" { + return p.Repo, nil + } + + repos := p.EffectiveRepos() + if len(repos) > 0 { + return repos[0].URL, nil + } + + return "", nil + }, + }) + go chat.NewIdleReaper(chatMgr, time.Minute).Run(ctx) + + // 30s grace timer: last subscriber drop → flip session to warm-idle. + // A new subscriber within 30s cancels the flip. + var graceTimers sync.Map // sessionID → *time.Timer + + chatHub.OnLastUnsubscribe = func(sessionID string) { + if existing, ok := graceTimers.LoadAndDelete(sessionID); ok { + existing.(*time.Timer).Stop() + } + + timer := time.AfterFunc(30*time.Second, func() { + // If the entry is still in the map it means no new subscriber + // arrived during the grace window — proceed with warm-idle. + if _, loaded := graceTimers.LoadAndDelete(sessionID); !loaded { + return + } + + if err := chatMgr.MarkWarmIdle(ctx, sessionID); err != nil { + slog.Warn("chat: warm-idle transition failed", "session_id", sessionID, "error", err) + } + }) + graceTimers.Store(sessionID, timer) + } + chatHub.OnSubscribe = func(sessionID string) { + if t, ok := graceTimers.LoadAndDelete(sessionID); ok { + t.(*time.Timer).Stop() + } + // A browser subscriber is a strong "I want this chat" signal. 
+ // Reattach the runner-log consumer if one isn't already bridging + // /logs for this session — covers the case where CM restarted + // while runner containers stayed alive, stranding their consumer + // goroutines. No-op on cold/ending sessions. + if err := chatMgr.Reattach(ctx, sessionID); err != nil { + slog.Warn("chat: reattach on subscribe failed", + "session_id", sessionID, "error", err) + } + } + + slog.Info("chat manager initialized", "idle_ttl", cfg.Chat.IdleTTL, "max_concurrent", cfg.Chat.MaxConcurrent) + + // Resume runner-log consumers for sessions that survived a CM restart. + // Without this, active/warm-idle sessions stay marked alive in the DB + // while their consumer goroutines are gone (in-memory state lost), so + // the UI can't see runner output even though the container is still + // up. Reattach is idempotent and tolerant of dead containers — the + // consumer exits on first /logs error and the reconcile sweep below + // will flip orphaned sessions to cold. + go func() { + rctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + for _, status := range []chat.Status{chat.StatusActive, chat.StatusWarmIdle} { + sessions, err := chatMgr.ListSessions(rctx, chat.SessionFilter{Status: status}) + if err != nil { + slog.Warn("chat: startup reattach list failed", + "status", status, "error", err) + + continue + } + for _, s := range sessions { + if err := chatMgr.Reattach(rctx, s.ID); err != nil { + slog.Warn("chat: startup reattach failed", + "session_id", s.ID, "error", err) + } + } + } + }() + + // Card + chat reconcile sweep: a single ticker fetches /containers once + // per tick and feeds both reconcilers. Two separate tickers used to + // produce identically-signed HMAC GETs back to back; the runner's + // replay cache rejected the second as a duplicate. 
The chat reconciler + // flips active/warm-idle sessions whose runner container has + // disappeared (claude crash, runner restart, OOM, manual docker kill) + // to cold so the UI can reopen. + if cfg.Runner.Enabled { reconcileInterval := cfg.Runner.ReconcileIntervalDuration() - // The sweep takes the CardService (CardLookup) and the runner client - // (ReconcileClient: ListContainers + EndSession + Kill). It uses the - // runner's Docker state as the authoritative "is this container - // running?" input and the card store as the authoritative "should it - // be?" — see internal/runner/reconcile.go for why we no longer gate - // on card.runner_status. - runner.StartReconciliationSweep(ctx, svc, runnerClient, reconcileInterval, slog.Default()) + runner.StartReconciliationSweep(ctx, svc, chatReconcilerAdapter{mgr: chatMgr}, runnerClient, reconcileInterval, slog.Default()) if reconcileInterval > 0 { slog.Info("runner reconciliation sweep started", "interval", reconcileInterval) @@ -310,7 +439,7 @@ func main() { slog.Info("session log manager initialized") // Create MCP server - mcpSrv := mcpserver.NewServer(svc, cfg.WorkflowSkillsDir) + mcpSrv := mcpserver.NewServer(svc, cfg.WorkflowSkillsDir, chatMgr) mcpHandler := mcpserver.NewHandler(mcpSrv, cfg.MCPAPIKey) if cfg.MCPAPIKey != "" { @@ -344,6 +473,9 @@ func main() { Theme: cfg.Theme, Version: buildVersion(), MCPHandler: mcpHandler, + ChatManager: chatMgr, + ChatHub: chatHub, + ChatConfig: &cfg.Chat, }) slog.Info("MCP server registered", "endpoint", "/mcp") @@ -466,6 +598,15 @@ func main() { slog.Error("session manager shutdown error", "error", err) } + if chatMgr != nil { + chatCloseCtx, chatCloseCancel := context.WithTimeout(context.Background(), 5*time.Second) + if err := chatMgr.Close(chatCloseCtx); err != nil { + slog.Warn("chat manager close failed", "error", err) + } + + chatCloseCancel() + } + // Phase 3: signal the rest of the app (timeout checker, syncers' // periodic loops, runner subscribers) to wind down. 
slog.Info("shutdown: phase=ctx_cancel") @@ -642,3 +783,60 @@ func newSPAHandler(apiHandler http.Handler, fsys fs.FS) http.Handler { fileServer.ServeHTTP(w, r) }) } + +// chatRunnerDisabled is a no-op RunnerClient used when the runner integration +// is disabled. Every operation returns an error so callers receive a clear +// "runner not enabled" message rather than a nil-pointer panic. +type chatRunnerDisabled struct{} + +func (chatRunnerDisabled) StartChat(_ context.Context, _ chat.StartChatOpts) (string, error) { + return "", fmt.Errorf("chat: runner not enabled") +} + +func (chatRunnerDisabled) EndChat(_ context.Context, _ string) error { + return fmt.Errorf("chat: runner not enabled") +} + +func (chatRunnerDisabled) SendChatMessage(_ context.Context, _, _, _ string) error { + return fmt.Errorf("chat: runner not enabled") +} + +func (chatRunnerDisabled) StreamLogs(ctx context.Context, _ string, _ func(chat.LogEntry)) error { + <-ctx.Done() + + return ctx.Err() +} + +// chatReconcilerAdapter adapts *chat.Manager to the runner.ChatReconciler +// surface. Keeps the chat package free of any runner-facing type while still +// letting the reconcile sweep enumerate orphan sessions and flip them cold. 
+type chatReconcilerAdapter struct { + mgr *chat.Manager +} + +func (a chatReconcilerAdapter) ListActiveChatSessions(ctx context.Context) ([]runner.ChatSessionRef, error) { + active, err := a.mgr.ListSessions(ctx, chat.SessionFilter{Status: chat.StatusActive}) + if err != nil { + return nil, fmt.Errorf("list active: %w", err) + } + + warm, err := a.mgr.ListSessions(ctx, chat.SessionFilter{Status: chat.StatusWarmIdle}) + if err != nil { + return nil, fmt.Errorf("list warm-idle: %w", err) + } + + out := make([]runner.ChatSessionRef, 0, len(active)+len(warm)) + for _, s := range active { + out = append(out, runner.ChatSessionRef{ID: s.ID, Status: string(s.Status)}) + } + + for _, s := range warm { + out = append(out, runner.ChatSessionRef{ID: s.ID, Status: string(s.Status)}) + } + + return out, nil +} + +func (a chatReconcilerAdapter) EndChatSession(ctx context.Context, id string) error { + return a.mgr.EndSession(ctx, id) +} diff --git a/internal/api/chats.go b/internal/api/chats.go new file mode 100644 index 00000000..426ca8ac --- /dev/null +++ b/internal/api/chats.go @@ -0,0 +1,469 @@ +package api + +import ( + "encoding/json" + "errors" + "net/http" + "sort" + "strconv" + "strings" + "time" + + "github.com/mhersson/contextmatrix/internal/chat" + "github.com/mhersson/contextmatrix/internal/config" + "github.com/mhersson/contextmatrix/internal/ctxlog" +) + +const ( + ErrCodeChatNotFound = "CHAT_NOT_FOUND" + ErrCodeTooManyChats = "TOO_MANY_CHATS" + ErrCodeInvalidModel = "INVALID_MODEL" +) + +type chatHandlers struct { + mgr *chat.Manager + hub *chat.SSEHub + chat *config.ChatConfig +} + +func newChatHandlers(mgr *chat.Manager, hub *chat.SSEHub, chatCfg *config.ChatConfig) *chatHandlers { + return &chatHandlers{mgr: mgr, hub: hub, chat: chatCfg} +} + +// agentIDForChat returns the caller identity, defaulting to "human:web" when +// the X-Agent-ID header is absent (same fallback pattern used elsewhere in api/). 
+func agentIDForChat(r *http.Request) string { + id := r.Header.Get("X-Agent-ID") + if id == "" { + return "human:web" + } + + return id +} + +func (h *chatHandlers) listChats(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + + f := chat.SessionFilter{ + Project: q.Get("project"), + CreatedBy: q.Get("created_by"), + Limit: listChatsDefaultLimit, + } + if st := q.Get("status"); st != "" { + s, ok := chat.ParseStatus(st) + if !ok { + writeError(w, http.StatusBadRequest, ErrCodeBadRequest, "invalid status", st) + + return + } + + f.Status = s + } + + if v := q.Get("limit"); v != "" { + n, err := strconv.Atoi(v) + if err != nil || n < 1 { + writeError(w, http.StatusBadRequest, ErrCodeBadRequest, "invalid limit", v) + + return + } + + if n > listChatsMaxLimit { + n = listChatsMaxLimit + } + + f.Limit = n + } + + sessions, err := h.mgr.ListSessions(r.Context(), f) + if err != nil { + handleChatError(w, r, err) + + return + } + // Always return a slice — never nil — so JSON serializes as [] not null. 
+ if sessions == nil { + sessions = []chat.Session{} + } + + writeJSON(w, http.StatusOK, sessions) +} + +func (h *chatHandlers) createChat(w http.ResponseWriter, r *http.Request) { + var body struct { + Title string `json:"title"` + Project string `json:"project"` + Model string `json:"model"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + writeError(w, http.StatusBadRequest, ErrCodeBadRequest, "invalid body", err.Error()) + + return + } + + model := body.Model + if model == "" && h.chat != nil { + model = h.chat.DefaultModel + } + + if model != "" && h.chat != nil { + if _, ok := h.chat.Models[model]; !ok { + writeError(w, http.StatusBadRequest, ErrCodeInvalidModel, + "model not in allowlist", model) + + return + } + } + + sess, err := h.mgr.CreateSession(r.Context(), chat.CreateInput{ + Title: body.Title, + Project: body.Project, + CreatedBy: agentIDForChat(r), + Model: model, + }) + if err != nil { + handleChatError(w, r, err) + + return + } + + writeJSON(w, http.StatusCreated, sess) +} + +// chatModelEntry is the picker-facing shape returned by listModels. +type chatModelEntry struct { + ID string `json:"id"` + Label string `json:"label"` + MaxTokens int64 `json:"max_tokens"` +} + +// listModels exposes the configured chat model allowlist + default for the +// frontend picker. Response shape mirrors what NewChatDialog consumes. 
+func (h *chatHandlers) listModels(w http.ResponseWriter, _ *http.Request) { + type response struct { + Models []chatModelEntry `json:"models"` + Default string `json:"default"` + } + + if h.chat == nil { + writeJSON(w, http.StatusOK, response{Models: []chatModelEntry{}, Default: ""}) + + return + } + + models := make([]chatModelEntry, 0, len(h.chat.Models)) + for id, m := range h.chat.Models { + models = append(models, chatModelEntry{ID: id, Label: m.Label, MaxTokens: m.MaxTokens}) + } + + sort.Slice(models, func(i, j int) bool { return models[i].ID < models[j].ID }) + + writeJSON(w, http.StatusOK, response{Models: models, Default: h.chat.DefaultModel}) +} + +func (h *chatHandlers) getChat(w http.ResponseWriter, r *http.Request) { + sess, err := h.mgr.GetSession(r.Context(), r.PathValue("id")) + if err != nil { + handleChatError(w, r, err) + + return + } + + writeJSON(w, http.StatusOK, sess) +} + +func (h *chatHandlers) deleteChat(w http.ResponseWriter, r *http.Request) { + if err := h.mgr.DeleteSession(r.Context(), r.PathValue("id")); err != nil { + handleChatError(w, r, err) + + return + } + + w.WriteHeader(http.StatusNoContent) +} + +func (h *chatHandlers) openChat(w http.ResponseWriter, r *http.Request) { + sess, err := h.mgr.OpenSession(r.Context(), r.PathValue("id")) + if err != nil { + handleChatError(w, r, err) + + return + } + + writeJSON(w, http.StatusOK, sess) +} + +func (h *chatHandlers) endChat(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + if err := h.mgr.EndSession(r.Context(), id); err != nil { + handleChatError(w, r, err) + + return + } + + sess, err := h.mgr.GetSession(r.Context(), id) + if err != nil { + handleChatError(w, r, err) + + return + } + + writeJSON(w, http.StatusOK, sess) +} + +func (h *chatHandlers) sendMessage(w http.ResponseWriter, r *http.Request) { + var body struct { + Content string `json:"content"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + writeError(w, http.StatusBadRequest, 
ErrCodeBadRequest, "invalid body", err.Error()) + + return + } + + if strings.TrimSpace(body.Content) == "" { + writeError(w, http.StatusBadRequest, ErrCodeBadRequest, "content is required", "") + + return + } + + if len(body.Content) > 8192 { + writeError(w, http.StatusRequestEntityTooLarge, ErrCodeContentTooLarge, "message too long", "") + + return + } + + msgID, err := h.mgr.SendUserMessage(r.Context(), r.PathValue("id"), body.Content) + if err != nil { + handleChatError(w, r, err) + + return + } + + writeJSON(w, http.StatusAccepted, map[string]any{"ok": true, "message_id": msgID}) +} + +const ( + listChatsDefaultLimit = 500 + listChatsMaxLimit = 5000 + + listMessagesDefaultLimit = 200 + listMessagesMaxLimit = 1000 +) + +func (h *chatHandlers) listMessages(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + if _, err := h.mgr.GetSession(r.Context(), id); err != nil { + handleChatError(w, r, err) + + return + } + + q := r.URL.Query() + + var sinceSeq int64 + + if v := q.Get("since_seq"); v != "" { + n, err := strconv.ParseInt(v, 10, 64) + if err != nil || n < 0 { + writeError(w, http.StatusBadRequest, ErrCodeBadRequest, "invalid since_seq", v) + + return + } + + sinceSeq = n + } + + limit := listMessagesDefaultLimit + + if v := q.Get("limit"); v != "" { + n, err := strconv.Atoi(v) + if err != nil || n <= 0 { + writeError(w, http.StatusBadRequest, ErrCodeBadRequest, "invalid limit", v) + + return + } + + if n > listMessagesMaxLimit { + n = listMessagesMaxLimit + } + + limit = n + } + + msgs, err := h.mgr.ListMessages(r.Context(), id, sinceSeq, limit) + if err != nil { + handleChatError(w, r, err) + + return + } + + if msgs == nil { + msgs = []chat.Message{} + } + + writeJSON(w, http.StatusOK, map[string]any{"messages": msgs}) +} + +func (h *chatHandlers) patchChat(w http.ResponseWriter, r *http.Request) { + var body struct { + Title *string `json:"title,omitempty"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + 
writeError(w, http.StatusBadRequest, ErrCodeBadRequest, "invalid body", err.Error()) + + return + } + + sess, err := h.mgr.GetSession(r.Context(), r.PathValue("id")) + if err != nil { + handleChatError(w, r, err) + + return + } + + if body.Title != nil { + sess.Title = *body.Title + } + + if err := h.mgr.UpdateSessionMetadata(r.Context(), sess); err != nil { + handleChatError(w, r, err) + + return + } + + writeJSON(w, http.StatusOK, sess) +} + +func (h *chatHandlers) streamChat(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + + // Validate the session exists before subscribing — the hub lazy-creates a + // per-session ring buffer + subscriber set on first Subscribe, so an + // unguarded handler would let any GET against an unknown id permanently + // grow perSess. + if _, err := h.mgr.GetSession(r.Context(), id); err != nil { + handleChatError(w, r, err) + + return + } + + since, _ := strconv.ParseInt(r.URL.Query().Get("since_seq"), 10, 64) + + // Clear the server's WriteTimeout for this long-lived SSE connection; + // without this the connection is severed after the global write deadline. + if err := http.NewResponseController(w).SetWriteDeadline(time.Time{}); err != nil { + ctxlog.Logger(r.Context()).Warn("chat SSE: could not clear write deadline; connection will drop on WriteTimeout", + "error", err) + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + + ch, replay, err := h.hub.Subscribe(id, since) + if err != nil { + writeError(w, http.StatusTooManyRequests, ErrCodeTooManyChats, err.Error(), "") + + return + } + + defer h.hub.Unsubscribe(id, ch) + + flusher, _ := w.(http.Flusher) + + // Flush a connected comment immediately so EventSource.onopen fires + // before any chat event is published. Critical for browsers behind + // proxies that buffer until the first body byte. 
+ if _, err := w.Write([]byte(": connected\n\n")); err != nil { + return + } + + if flusher != nil { + flusher.Flush() + } + + for _, e := range replay { + writeChatSSEEvent(w, e) + } + + if flusher != nil { + flusher.Flush() + } + + ticker := time.NewTicker(15 * time.Second) + defer ticker.Stop() + + for { + select { + case <-r.Context().Done(): + return + case <-ticker.C: + if _, err := w.Write([]byte(": keepalive\n\n")); err != nil { + return + } + + if flusher != nil { + flusher.Flush() + } + case e, ok := <-ch: + if !ok { + return + } + + writeChatSSEEvent(w, e) + + if flusher != nil { + flusher.Flush() + } + } + } +} + +// writeChatSSEEvent serialises one event onto the SSE wire. Different event +// kinds use the SSE "event:" header so the browser's EventSource can route +// to different listeners (transcript messages vs. session-state pushes). +// The default kind ("") is treated as the message wire (backwards-compatible +// with clients written before Wave 4 of the rehydration feature). +func writeChatSSEEvent(w http.ResponseWriter, e chat.SSEEvent) { + switch e.Kind { + case chat.SSEKindSessionUpdate: + _, _ = w.Write([]byte("event: session_updated\n")) + + if e.SessionUpdate != nil { + b, _ := json.Marshal(e.SessionUpdate) + _, _ = w.Write([]byte("data: ")) + _, _ = w.Write(b) + _, _ = w.Write([]byte("\n\n")) + } else { + _, _ = w.Write([]byte("data: {}\n\n")) + } + default: + // Backwards-compatible default: emit data without an event: header + // so older clients listening on the unnamed message stream keep + // working. rehydration_phase is included so the UI can group + // agent rehydration messages distinctly from normal turns. 
+ b, _ := json.Marshal(struct { + Seq int64 `json:"seq"` + Role chat.Role `json:"role"` + Content string `json:"content"` + RehydrationPhase bool `json:"rehydration_phase,omitempty"` + }{Seq: e.Seq, Role: e.Role, Content: e.Content, RehydrationPhase: e.RehydrationPhase}) + _, _ = w.Write([]byte("data: ")) + _, _ = w.Write(b) + _, _ = w.Write([]byte("\n\n")) + } +} + +func handleChatError(w http.ResponseWriter, r *http.Request, err error) { + ctxlog.Logger(r.Context()).Error("chat error", "error", err) + + switch { + case errors.Is(err, chat.ErrSessionNotFound): + writeError(w, http.StatusNotFound, ErrCodeChatNotFound, "chat session not found", "") + case errors.Is(err, chat.ErrTooManyConcurrent): + writeError(w, http.StatusTooManyRequests, ErrCodeTooManyChats, "concurrent chat limit reached", "") + default: + writeError(w, http.StatusInternalServerError, ErrCodeInternalError, "internal error", "") + } +} diff --git a/internal/api/chats_test.go b/internal/api/chats_test.go new file mode 100644 index 00000000..5202fa33 --- /dev/null +++ b/internal/api/chats_test.go @@ -0,0 +1,494 @@ +package api + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "path/filepath" + "strconv" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/mhersson/contextmatrix/internal/chat" + "github.com/mhersson/contextmatrix/internal/chat/sqlite" + "github.com/mhersson/contextmatrix/internal/clock" + "github.com/mhersson/contextmatrix/internal/config" +) + +type chatStubRunner struct{} + +func (chatStubRunner) StartChat(_ context.Context, opts chat.StartChatOpts) (string, error) { + return "container-" + opts.SessionID, nil +} + +func (chatStubRunner) EndChat(_ context.Context, _ string) error { return nil } +func (chatStubRunner) SendChatMessage(_ context.Context, _, _, _ string) error { return nil } +func (chatStubRunner) StreamLogs(ctx context.Context, _ string, _ 
func(chat.LogEntry)) error { + <-ctx.Done() + + return ctx.Err() +} + +type fixtureOpts struct { + chatConfig config.ChatConfig +} + +func defaultFixtureOpts() fixtureOpts { + return fixtureOpts{ + chatConfig: config.ChatConfig{ + DefaultModel: "claude-sonnet-4-6", + Models: map[string]config.ChatModelConfig{ + "claude-sonnet-4-6": {Label: "Sonnet 4.6", MaxTokens: 1000000}, + }, + }, + } +} + +func jsonReq(t *testing.T, method, path, body string) *http.Request { + t.Helper() + + req := httptest.NewRequest(method, path, bytes.NewBufferString(body)) + req.Header.Set("X-Agent-ID", "human:web-x") + req.Header.Set("Content-Type", "application/json") + + return req +} + +func newChatFixture(t *testing.T, opts fixtureOpts) (*http.ServeMux, *chat.Manager) { + t.Helper() + store, err := sqlite.Open(filepath.Join(t.TempDir(), "chats.db")) + require.NoError(t, err) + t.Cleanup(func() { _ = store.Close() }) + + chatCfg := opts.chatConfig + mgr := chat.NewManager(chat.Config{ + Store: store, + Runner: chatStubRunner{}, + Clock: clock.Real(), + IdleTTL: time.Hour, + DefaultModel: chatCfg.DefaultModel, + }) + hub := chat.NewSSEHub(64) + mux := http.NewServeMux() + chh := newChatHandlers(mgr, hub, &chatCfg) + mux.HandleFunc("GET /api/chats/models", chh.listModels) + mux.HandleFunc("GET /api/chats", chh.listChats) + mux.HandleFunc("POST /api/chats", chh.createChat) + mux.HandleFunc("GET /api/chats/{id}", chh.getChat) + mux.HandleFunc("DELETE /api/chats/{id}", chh.deleteChat) + mux.HandleFunc("PATCH /api/chats/{id}", chh.patchChat) + mux.HandleFunc("POST /api/chats/{id}/open", chh.openChat) + mux.HandleFunc("POST /api/chats/{id}/end", chh.endChat) + mux.HandleFunc("POST /api/chats/{id}/messages", chh.sendMessage) + mux.HandleFunc("GET /api/chats/{id}/messages", chh.listMessages) + mux.HandleFunc("GET /api/chats/{id}/stream", chh.streamChat) + + return mux, mgr +} + +func TestCreateChat_Success(t *testing.T) { + mux, _ := newChatFixture(t, defaultFixtureOpts()) + req := 
httptest.NewRequest(http.MethodPost, "/api/chats", + bytes.NewBufferString(`{"title":"t","project":"alpha"}`)) + req.Header.Set("X-Agent-ID", "human:web-x") + + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + require.Equal(t, http.StatusCreated, w.Code) + + var sess chat.Session + require.NoError(t, json.NewDecoder(w.Body).Decode(&sess)) + assert.Equal(t, "t", sess.Title) + assert.Equal(t, "alpha", sess.Project) + assert.Equal(t, chat.StatusCold, sess.Status) +} + +func TestGetChat_NotFound(t *testing.T) { + mux, _ := newChatFixture(t, defaultFixtureOpts()) + req := httptest.NewRequest(http.MethodGet, "/api/chats/missing", nil) + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + assert.Equal(t, http.StatusNotFound, w.Code) +} + +func TestListChats_EmptyReturnsArray(t *testing.T) { + mux, _ := newChatFixture(t, defaultFixtureOpts()) + req := httptest.NewRequest(http.MethodGet, "/api/chats", nil) + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "[]\n", w.Body.String()) +} + +func TestSendMessage_OpensColdSession(t *testing.T) { + mux, mgr := newChatFixture(t, defaultFixtureOpts()) + sess, err := mgr.CreateSession(context.Background(), + chat.CreateInput{Title: "", CreatedBy: "human:web-x"}) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/api/chats/"+sess.ID+"/messages", + bytes.NewBufferString(`{"content":"hello"}`)) + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + assert.Equal(t, http.StatusAccepted, w.Code) +} + +func TestDeleteChat_Success(t *testing.T) { + mux, mgr := newChatFixture(t, defaultFixtureOpts()) + sess, err := mgr.CreateSession(context.Background(), + chat.CreateInput{Title: "to-del", CreatedBy: "human:web-x"}) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodDelete, "/api/chats/"+sess.ID, nil) + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + assert.Equal(t, http.StatusNoContent, w.Code) +} + +// 
TestEndChat_ReturnsColdSession verifies that POST /api/chats/{id}/end +// returns 200 with the fresh (cold) session body. The frontend depends on +// this body to update its local state without an extra getChat call; an +// empty 2xx response would also have made the client's response.json() +// call throw and surface as "Failed to end session" in the UI. +func TestEndChat_ReturnsColdSession(t *testing.T) { + mux, mgr := newChatFixture(t, defaultFixtureOpts()) + + sess, err := mgr.CreateSession(context.Background(), + chat.CreateInput{Title: "to-end", CreatedBy: "human:web-x"}) + require.NoError(t, err) + + // Drive the session active so EndSession has work to do. + _, err = mgr.OpenSession(context.Background(), sess.ID) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/api/chats/"+sess.ID+"/end", nil) + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, "body=%s", w.Body.String()) + + var got chat.Session + + require.NoError(t, json.NewDecoder(w.Body).Decode(&got)) + assert.Equal(t, sess.ID, got.ID) + assert.Equal(t, chat.StatusCold, got.Status) + assert.Empty(t, got.ContainerID, "ended session must not carry a container_id") +} + +func TestPatchChat_UpdatesTitle(t *testing.T) { + mux, mgr := newChatFixture(t, defaultFixtureOpts()) + sess, err := mgr.CreateSession(context.Background(), + chat.CreateInput{Title: "old", CreatedBy: "human:web-x"}) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPatch, "/api/chats/"+sess.ID, + bytes.NewBufferString(`{"title":"new"}`)) + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) + + var got chat.Session + require.NoError(t, json.NewDecoder(w.Body).Decode(&got)) + assert.Equal(t, "new", got.Title) +} + +func TestListChats_InvalidStatusBadRequest(t *testing.T) { + mux, _ := newChatFixture(t, defaultFixtureOpts()) + req := httptest.NewRequest(http.MethodGet, "/api/chats?status=bogus", nil) + w := 
httptest.NewRecorder() + mux.ServeHTTP(w, req) + assert.Equal(t, http.StatusBadRequest, w.Code) +} + +// TestStreamChat_ConnectedBeforeAnyEvent verifies that the SSE handler +// flushes a ": connected\n\n" comment immediately on subscribe so browsers +// (and proxies that buffer until first body byte) see onopen fire even +// when no events are pending. Without this, the chat status dot stays grey +// and the UI thinks the stream is disconnected. +func TestStreamChat_ConnectedBeforeAnyEvent(t *testing.T) { + mux, mgr := newChatFixture(t, defaultFixtureOpts()) + sess, err := mgr.CreateSession(context.Background(), + chat.CreateInput{Title: "", CreatedBy: "human:web-x"}) + require.NoError(t, err) + + srv := httptest.NewServer(mux) + defer srv.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, + srv.URL+"/api/chats/"+sess.ID+"/stream", nil) + require.NoError(t, err) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + + defer resp.Body.Close() + + require.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "text/event-stream", resp.Header.Get("Content-Type")) + + done := make(chan string, 1) + + go func() { + reader := bufio.NewReader(resp.Body) + + line, err := reader.ReadString('\n') + if err != nil { + done <- "" + + return + } + + done <- line + }() + + select { + case line := <-done: + assert.True(t, strings.HasPrefix(line, ": connected"), + "expected first SSE line to be `: connected`, got %q", line) + case <-time.After(500 * time.Millisecond): + t.Fatal("did not receive `: connected` within 500ms — handler is not flushing before blocking") + } +} + +// TestStreamChat_UnknownSession_404 verifies that GET .../stream against a +// session that does not exist returns 404 without creating a hub entry. 
The +// SSE hub used to lazily create a per-session ring buffer on first reference, +// so any GET against an unknown id would grow perSess permanently. The +// handler must validate the session exists before subscribing. +func TestStreamChat_UnknownSession_404(t *testing.T) { + mux, _ := newChatFixture(t, defaultFixtureOpts()) + + // Bounded deadline so the test fails fast if the handler subscribes and + // blocks instead of returning 404 immediately. + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + req := httptest.NewRequest(http.MethodGet, "/api/chats/never-existed/stream", nil).WithContext(ctx) + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + + assert.Equal(t, http.StatusNotFound, w.Code, + "streamChat must 404 on unknown session, not silently create a hub entry") +} + +func TestListMessages_EmptyEnvelope(t *testing.T) { + mux, mgr := newChatFixture(t, defaultFixtureOpts()) + + sess, err := mgr.CreateSession(context.Background(), + chat.CreateInput{Title: "t", CreatedBy: "human:web-x"}) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodGet, "/api/chats/"+sess.ID+"/messages", nil) + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) + + var body struct { + Messages []chat.Message `json:"messages"` + } + require.NoError(t, json.NewDecoder(w.Body).Decode(&body)) + assert.NotNil(t, body.Messages, "messages must be [] not null") + assert.Empty(t, body.Messages) +} + +func TestListMessages_FiltersSinceSeqExclusively(t *testing.T) { + mux, mgr := newChatFixture(t, defaultFixtureOpts()) + + ctx := context.Background() + sess, err := mgr.CreateSession(ctx, + chat.CreateInput{Title: "t", CreatedBy: "human:web-x"}) + require.NoError(t, err) + + for i := range 5 { + _, err := mgr.AppendMessage(ctx, sess.ID, chat.RoleAssistantText, + `{"text":"m`+strconv.Itoa(i)+`"}`) + require.NoError(t, err) + } + + req := httptest.NewRequest(http.MethodGet, + 
"/api/chats/"+sess.ID+"/messages?since_seq=2", nil) + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) + + var body struct { + Messages []chat.Message `json:"messages"` + } + require.NoError(t, json.NewDecoder(w.Body).Decode(&body)) + require.Len(t, body.Messages, 3) + assert.Equal(t, int64(3), body.Messages[0].Seq) + assert.Equal(t, int64(5), body.Messages[2].Seq) +} + +func TestListMessages_RespectsLimit(t *testing.T) { + mux, mgr := newChatFixture(t, defaultFixtureOpts()) + + ctx := context.Background() + sess, err := mgr.CreateSession(ctx, + chat.CreateInput{Title: "t", CreatedBy: "human:web-x"}) + require.NoError(t, err) + + for i := range 5 { + _, err := mgr.AppendMessage(ctx, sess.ID, chat.RoleAssistantText, + `{"text":"m`+strconv.Itoa(i)+`"}`) + require.NoError(t, err) + } + + req := httptest.NewRequest(http.MethodGet, + "/api/chats/"+sess.ID+"/messages?limit=3", nil) + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) + + var body struct { + Messages []chat.Message `json:"messages"` + } + require.NoError(t, json.NewDecoder(w.Body).Decode(&body)) + require.Len(t, body.Messages, 3) + assert.Equal(t, int64(1), body.Messages[0].Seq, "oldest-first ordering") + assert.Equal(t, int64(3), body.Messages[2].Seq) +} + +func TestListMessages_ClampsLimitToMax(t *testing.T) { + mux, mgr := newChatFixture(t, defaultFixtureOpts()) + + ctx := context.Background() + sess, err := mgr.CreateSession(ctx, + chat.CreateInput{Title: "t", CreatedBy: "human:web-x"}) + require.NoError(t, err) + + for range 1001 { + _, err := mgr.AppendMessage(ctx, sess.ID, chat.RoleAssistantText, `{"text":"m"}`) + require.NoError(t, err) + } + + req := httptest.NewRequest(http.MethodGet, + "/api/chats/"+sess.ID+"/messages?limit=99999", nil) + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) + + var body struct { + Messages []chat.Message `json:"messages"` + } + 
require.NoError(t, json.NewDecoder(w.Body).Decode(&body)) + assert.Len(t, body.Messages, 1000, "limit must be clamped to maxLimit") +} + +func TestListMessages_UnknownSessionReturns404(t *testing.T) { + mux, _ := newChatFixture(t, defaultFixtureOpts()) + + req := httptest.NewRequest(http.MethodGet, "/api/chats/no-such/messages", nil) + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + assert.Equal(t, http.StatusNotFound, w.Code) +} + +func TestSendMessage_TooLong(t *testing.T) { + mux, mgr := newChatFixture(t, defaultFixtureOpts()) + sess, err := mgr.CreateSession(context.Background(), + chat.CreateInput{Title: "", CreatedBy: "human:web-x"}) + require.NoError(t, err) + + long := make([]byte, 8193) + for i := range long { + long[i] = 'x' + } + + body, _ := json.Marshal(map[string]string{"content": string(long)}) + req := httptest.NewRequest(http.MethodPost, "/api/chats/"+sess.ID+"/messages", + bytes.NewReader(body)) + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + assert.Equal(t, http.StatusRequestEntityTooLarge, w.Code) +} + +func TestListModels(t *testing.T) { + t.Parallel() + mux, _ := newChatFixture(t, fixtureOpts{chatConfig: config.ChatConfig{ + DefaultModel: "claude-sonnet-4-6", + Models: map[string]config.ChatModelConfig{ + "claude-haiku-4-5-20251001": {Label: "Haiku 4.5", MaxTokens: 200000}, + "claude-opus-4-7": {Label: "Opus 4.7", MaxTokens: 1000000}, + "claude-sonnet-4-6": {Label: "Sonnet 4.6", MaxTokens: 1000000}, + }, + }}) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, httptest.NewRequest("GET", "/api/chats/models", nil)) + require.Equal(t, 200, rec.Code) + + var body struct { + Models []struct { + ID, Label string + MaxTokens int64 + } `json:"models"` + Default string `json:"default"` + } + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &body)) + require.Equal(t, "claude-sonnet-4-6", body.Default) + require.Len(t, body.Models, 3) + // Sorted by ID. 
+ require.Equal(t, "claude-haiku-4-5-20251001", body.Models[0].ID) + require.Equal(t, "claude-opus-4-7", body.Models[1].ID) + require.Equal(t, "claude-sonnet-4-6", body.Models[2].ID) +} + +func TestCreateChat_Model_RoundTrip(t *testing.T) { + t.Parallel() + mux, _ := newChatFixture(t, defaultFixtureOpts()) + body := `{"title":"x","project":"p","model":"claude-sonnet-4-6"}` + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, jsonReq(t, "POST", "/api/chats", body)) + require.Equal(t, 201, rec.Code) + + var sess chat.Session + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &sess)) + require.Equal(t, "claude-sonnet-4-6", sess.Model) +} + +func TestCreateChat_InvalidModel(t *testing.T) { + t.Parallel() + mux, _ := newChatFixture(t, defaultFixtureOpts()) + body := `{"title":"x","project":"p","model":"gpt-5"}` + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, jsonReq(t, "POST", "/api/chats", body)) + require.Equal(t, 400, rec.Code) + require.Contains(t, rec.Body.String(), "INVALID_MODEL") +} + +func TestListModels_NilConfig(t *testing.T) { + t.Parallel() + // Create a fixture with nil chat config by manually building without the chatConfig. 
+ t.Helper() + store, err := sqlite.Open(filepath.Join(t.TempDir(), "chats.db")) + require.NoError(t, err) + t.Cleanup(func() { _ = store.Close() }) + + mgr := chat.NewManager(chat.Config{ + Store: store, + Runner: chatStubRunner{}, + Clock: clock.Real(), + IdleTTL: time.Hour, + DefaultModel: "claude-sonnet-4-6", + }) + hub := chat.NewSSEHub(64) + mux := http.NewServeMux() + chh := newChatHandlers(mgr, hub, nil) + mux.HandleFunc("GET /api/chats/models", chh.listModels) + + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, httptest.NewRequest("GET", "/api/chats/models", nil)) + require.Equal(t, 200, rec.Code) + require.Contains(t, rec.Body.String(), `"models":[]`) + require.Contains(t, rec.Body.String(), `"default":""`) +} diff --git a/internal/api/router.go b/internal/api/router.go index 9727700f..91b06150 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -17,6 +17,7 @@ import ( githubauth "github.com/mhersson/contextmatrix-githubauth" "github.com/mhersson/contextmatrix/internal/board" + "github.com/mhersson/contextmatrix/internal/chat" "github.com/mhersson/contextmatrix/internal/config" "github.com/mhersson/contextmatrix/internal/ctxlog" "github.com/mhersson/contextmatrix/internal/events" @@ -96,6 +97,9 @@ type RouterConfig struct { Version string // build version string for display MCPHandler http.Handler // optional; registered at POST/GET/DELETE /mcp when set RefreshRegistry *refresh.Registry // optional; tracks in-flight KB refresh jobs + ChatManager *chat.Manager // optional; enables /api/chats routes + ChatHub *chat.SSEHub // optional; required when ChatManager is set + ChatConfig *config.ChatConfig // optional; carries model allowlist for /api/chats endpoints } // NewRouter creates a new HTTP router with all API routes registered. 
@@ -216,6 +220,22 @@ func NewRouter(cfg RouterConfig) http.Handler { mux.HandleFunc("GET /api/v1/cards/{project}/{id}/autonomous", rh.getCardAutonomous) } + // Chat routes — registered only when both the manager and hub are wired. + if cfg.ChatManager != nil && cfg.ChatHub != nil { + chh := newChatHandlers(cfg.ChatManager, cfg.ChatHub, cfg.ChatConfig) + mux.HandleFunc("GET /api/chats", chh.listChats) + mux.HandleFunc("POST /api/chats", chh.createChat) + mux.HandleFunc("GET /api/chats/{id}", chh.getChat) + mux.HandleFunc("PATCH /api/chats/{id}", chh.patchChat) + mux.HandleFunc("DELETE /api/chats/{id}", chh.deleteChat) + mux.HandleFunc("POST /api/chats/{id}/open", chh.openChat) + mux.HandleFunc("POST /api/chats/{id}/end", chh.endChat) + mux.HandleFunc("POST /api/chats/{id}/messages", chh.sendMessage) + mux.HandleFunc("GET /api/chats/{id}/messages", chh.listMessages) + mux.HandleFunc("GET /api/chats/{id}/stream", chh.streamChat) + mux.HandleFunc("GET /api/chats/models", chh.listModels) + } + // MCP server routes — registered on the inner mux so they share the // same middleware chain as every other route (recovery, requestID, // observe, bodyLimit, ...). @@ -346,8 +366,11 @@ func observe(next http.Handler) http.Handler { // SSE streams would pollute the REST latency histogram and the // path label set — skip them entirely for metrics. MCP Streamable // HTTP GET /mcp is a long-lived SSE connection for the same reason. + // Chat session SSE streams (/api/chats/{id}/stream) follow the same + // pattern — match by suffix since the id is variable. 
if r.URL.Path == "/api/events" || r.URL.Path == "/api/runner/logs" || - (r.Method == http.MethodGet && r.URL.Path == "/mcp") { + (r.Method == http.MethodGet && r.URL.Path == "/mcp") || + (r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, "/api/chats/") && strings.HasSuffix(r.URL.Path, "/stream")) { return } diff --git a/internal/mcp/chat_rehydration.go b/internal/mcp/chat_rehydration.go new file mode 100644 index 00000000..056219f3 --- /dev/null +++ b/internal/mcp/chat_rehydration.go @@ -0,0 +1,73 @@ +package mcp + +import ( + "context" + "fmt" + "strings" + + "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/mhersson/contextmatrix/internal/chat" + "github.com/mhersson/contextmatrix/internal/mcp/mcpcontext" +) + +// chatRehydrationCompleteInput is the agent-facing argument shape. +type chatRehydrationCompleteInput struct { + SessionID string `json:"session_id" jsonschema:"required,id of the chat session being resumed"` + Summary string `json:"summary" jsonschema:"required,one-paragraph summary of where the conversation left off (becomes the first visible message of the resumed chat)"` +} + +// chatRehydrationCompleteOutput acknowledges the call. +type chatRehydrationCompleteOutput struct { + OK bool `json:"ok"` +} + +// registerChatRehydrationComplete adds the chat_rehydration_complete MCP tool. +// The agent calls this once it has read /run/cm-chat/resume.jsonl and +// re-established workspace state. The call: +// +// - flips chat_sessions.rehydration_active off so the UI un-collapses the +// restoration block and stops showing the "Restoring workspace…" spinner; +// - persists `summary` as a normal (non-phase) assistant_text message so +// the operator sees an anchor point ("ready to continue") in the thread. +// +// Idempotent: a second call with the flag already off is a successful no-op. 
+func registerChatRehydrationComplete(server *mcp.Server, mgr *chat.Manager) { + mcp.AddTool(server, &mcp.Tool{ + Name: "chat_rehydration_complete", + Description: "Signal that the chat-mode rehydration phase is complete. " + + "Call this exactly once per resumed chat session after reading " + + "/run/cm-chat/resume.jsonl and re-establishing any workspace state " + + "(re-cloning repos, restoring branches, etc.). The summary argument " + + "becomes the first visible message of the resumed chat — keep it " + + "to one short paragraph.", + }, buildChatRehydrationCompleteTool(mgr)) +} + +// buildChatRehydrationCompleteTool returns the handler closure for the +// chat_rehydration_complete tool. Extracted for direct unit testing. +func buildChatRehydrationCompleteTool(mgr *chat.Manager) func(context.Context, *mcp.CallToolRequest, chatRehydrationCompleteInput) (*mcp.CallToolResult, chatRehydrationCompleteOutput, error) { + return func(ctx context.Context, _ *mcp.CallToolRequest, in chatRehydrationCompleteInput) (*mcp.CallToolResult, chatRehydrationCompleteOutput, error) { + if strings.TrimSpace(in.SessionID) == "" { + return nil, chatRehydrationCompleteOutput{}, fmt.Errorf("chat_rehydration_complete: session_id is required") + } + + if strings.TrimSpace(in.Summary) == "" { + return nil, chatRehydrationCompleteOutput{}, fmt.Errorf("chat_rehydration_complete: summary is required") + } + + // Gate the call to the caller's own session. Chat-container callers + // forward CM_CHAT_SESSION via X-CM-Chat-Session; the middleware stashes + // it into ctx. Empty caller means the header was absent (card-mode + // worker, human curl) so we skip the check. 
+ if caller := mcpcontext.ChatSession(ctx); caller != "" && caller != in.SessionID { + return nil, chatRehydrationCompleteOutput{}, fmt.Errorf("chat_rehydration_complete: session mismatch: caller=%s session_id=%s", caller, in.SessionID) + } + + if err := mgr.CompleteRehydration(ctx, in.SessionID, in.Summary); err != nil { + return nil, chatRehydrationCompleteOutput{}, fmt.Errorf("chat_rehydration_complete: %w", err) + } + + return nil, chatRehydrationCompleteOutput{OK: true}, nil + } +} diff --git a/internal/mcp/chat_rehydration_test.go b/internal/mcp/chat_rehydration_test.go new file mode 100644 index 00000000..043896f2 --- /dev/null +++ b/internal/mcp/chat_rehydration_test.go @@ -0,0 +1,149 @@ +package mcp + +import ( + "context" + "path/filepath" + "strings" + "testing" + "time" + + _ "modernc.org/sqlite" + + "github.com/stretchr/testify/require" + + "github.com/mhersson/contextmatrix/internal/chat" + "github.com/mhersson/contextmatrix/internal/chat/sqlite" + "github.com/mhersson/contextmatrix/internal/clock" + "github.com/mhersson/contextmatrix/internal/mcp/mcpcontext" +) + +// chatTestDeps holds the store so tests can manipulate rehydration_active +// independently of the manager. +type chatTestDeps struct { + store chat.Store +} + +// setRehydrationActive flips the rehydration flag directly via the store, +// bypassing the manager's in-memory cache to simulate a just-opened session. +func (d *chatTestDeps) setRehydrationActive(ctx context.Context, sessionID string, active bool) error { + return d.store.SetRehydrationActive(ctx, sessionID, active) +} + +// newTestChatManager creates a chat.Manager backed by a real SQLite store and +// a no-op stub runner, suitable for unit-testing tool handlers. 
+func newTestChatManager(t *testing.T) (*chat.Manager, *chatTestDeps) { + t.Helper() + + store, err := sqlite.Open(filepath.Join(t.TempDir(), "chats.db")) + require.NoError(t, err) + t.Cleanup(func() { _ = store.Close() }) + + mgr := chat.NewManager(chat.Config{ + Store: store, + Runner: &chatStubRunner{}, + Clock: clock.Real(), + IdleTTL: time.Hour, + }) + + return mgr, &chatTestDeps{store: store} +} + +// chatStubRunner is the minimal RunnerClient stub needed for chat manager tests +// in the mcp package. It satisfies the chat.RunnerClient interface without any +// real behaviour — we never actually start containers in these tests. +type chatStubRunner struct{} + +func (r *chatStubRunner) StartChat(_ context.Context, _ chat.StartChatOpts) (string, error) { + return "stub-container", nil +} + +func (r *chatStubRunner) EndChat(_ context.Context, _ string) error { return nil } + +func (r *chatStubRunner) SendChatMessage(_ context.Context, _, _, _ string) error { return nil } + +func (r *chatStubRunner) StreamLogs(ctx context.Context, _ string, _ func(chat.LogEntry)) error { + <-ctx.Done() + + return ctx.Err() +} + +// --- Tests --- + +func TestChatRehydrationCompleteTool_UnknownSession(t *testing.T) { + t.Parallel() + mgr, _ := newTestChatManager(t) + tool := buildChatRehydrationCompleteTool(mgr) + _, _, err := tool(context.Background(), nil, chatRehydrationCompleteInput{ + SessionID: "01UNKNOWN", + Summary: "x", + }) + require.Error(t, err) + require.Contains(t, strings.ToLower(err.Error()), "session not found") +} + +func TestChatRehydrationCompleteTool_HappyPath(t *testing.T) { + t.Parallel() + mgr, deps := newTestChatManager(t) + ctx := context.Background() + + sess, err := mgr.CreateSession(ctx, chat.CreateInput{Project: "p", CreatedBy: "human:t"}) + require.NoError(t, err) + + // Flip the flag on via the store directly, simulating a cold reopen. 
+ require.NoError(t, deps.setRehydrationActive(ctx, sess.ID, true)) + + tool := buildChatRehydrationCompleteTool(mgr) + _, out, err := tool(ctx, nil, chatRehydrationCompleteInput{ + SessionID: sess.ID, + Summary: "restored.", + }) + require.NoError(t, err) + require.True(t, out.OK) + + got, err := mgr.GetSession(ctx, sess.ID) + require.NoError(t, err) + require.False(t, got.RehydrationActive, "flag must flip off") +} + +func TestChatRehydrationCompleteTool_AlreadyInactive(t *testing.T) { + t.Parallel() + mgr, _ := newTestChatManager(t) + ctx := context.Background() + + sess, err := mgr.CreateSession(ctx, chat.CreateInput{Project: "p", CreatedBy: "human:t"}) + require.NoError(t, err) + + // rehydration_active is false by default — call should be a no-op success. + tool := buildChatRehydrationCompleteTool(mgr) + _, out, err := tool(ctx, nil, chatRehydrationCompleteInput{ + SessionID: sess.ID, + Summary: "no-op", + }) + require.NoError(t, err) + require.True(t, out.OK) +} + +func TestChatRehydrationCompleteTool_CrossSessionRejected(t *testing.T) { + t.Parallel() + mgr, deps := newTestChatManager(t) + ctx := context.Background() + // Two sessions; both rehydration-active. + a, err := mgr.CreateSession(ctx, chat.CreateInput{Project: "p", CreatedBy: "human:t"}) + require.NoError(t, err) + b, err := mgr.CreateSession(ctx, chat.CreateInput{Project: "p", CreatedBy: "human:t"}) + require.NoError(t, err) + require.NoError(t, deps.setRehydrationActive(ctx, a.ID, true)) + require.NoError(t, deps.setRehydrationActive(ctx, b.ID, true)) + // Agent in container A calls the tool with B's session_id. + ctxA := mcpcontext.WithChatSession(ctx, a.ID) + tool := buildChatRehydrationCompleteTool(mgr) + _, _, err = tool(ctxA, nil, chatRehydrationCompleteInput{ + SessionID: b.ID, + Summary: "I am evil", + }) + require.Error(t, err) + require.Contains(t, strings.ToLower(err.Error()), "session mismatch") + // B's rehydration_active must still be true. 
+ got, _ := mgr.GetSession(ctx, b.ID) + require.True(t, got.RehydrationActive) +} diff --git a/internal/mcp/mcpcontext/mcpcontext.go b/internal/mcp/mcpcontext/mcpcontext.go new file mode 100644 index 00000000..d75ad3c4 --- /dev/null +++ b/internal/mcp/mcpcontext/mcpcontext.go @@ -0,0 +1,24 @@ +// Package mcpcontext exposes typed context helpers shared by the MCP HTTP +// transport (server.go middleware) and the tool handlers. The session +// helper is set by chatSessionHeaderMiddleware when a chat-container +// caller forwards its X-CM-Chat-Session header; tools may use it to gate +// session-scoped operations to the caller's own session. +package mcpcontext + +import "context" + +type chatSessionKey struct{} + +// WithChatSession returns a derived context carrying the chat session ID +// the caller claims to belong to. Set by middleware; read by tools. +func WithChatSession(ctx context.Context, sessionID string) context.Context { + return context.WithValue(ctx, chatSessionKey{}, sessionID) +} + +// ChatSession returns the chat session ID stashed by WithChatSession. +// Empty string when the header was absent (card-mode, human curl). +func ChatSession(ctx context.Context) string { + v, _ := ctx.Value(chatSessionKey{}).(string) + + return v +} diff --git a/internal/mcp/server.go b/internal/mcp/server.go index a5a5b7f8..cdb35b00 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -14,12 +14,16 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/mhersson/contextmatrix/internal/chat" "github.com/mhersson/contextmatrix/internal/ctxlog" + "github.com/mhersson/contextmatrix/internal/mcp/mcpcontext" "github.com/mhersson/contextmatrix/internal/service" ) // NewServer creates a configured MCP server with all tools and prompts registered. -func NewServer(svc *service.CardService, workflowSkillsDir string) *mcp.Server { +// chatMgr may be nil when chat is disabled; chat-specific tools register only +// when it is non-nil. 
+func NewServer(svc *service.CardService, workflowSkillsDir string, chatMgr *chat.Manager) *mcp.Server { server := mcp.NewServer( &mcp.Implementation{ Name: "contextmatrix", @@ -31,6 +35,10 @@ func NewServer(svc *service.CardService, workflowSkillsDir string) *mcp.Server { registerTools(server, svc, workflowSkillsDir) registerPrompts(server, svc, workflowSkillsDir) + if chatMgr != nil { + registerChatRehydrationComplete(server, chatMgr) + } + return server } @@ -40,11 +48,13 @@ func NewServer(svc *service.CardService, workflowSkillsDir string) *mcp.Server { // // Middleware order (outermost → innermost): // -// mcpAuthMiddleware → clearWriteDeadlineForStreaming → mcpRequestInfoMiddleware → SDK handler +// mcpAuthMiddleware → clearWriteDeadlineForStreaming → chatSessionHeaderMiddleware → mcpRequestInfoMiddleware → SDK handler // // Unauthenticated probes are rejected before body inspection. The write-deadline -// tweak stays outermost to apply to all authenticated requests. Request info -// extraction runs just before the SDK so it can read the body once. +// tweak stays outermost to apply to all authenticated requests. The chat-session +// header is stashed into context after auth so only authenticated callers can +// set it. Request info extraction runs just before the SDK so it can read the +// body once. func NewHandler(server *mcp.Server, apiKey string) http.Handler { handler := mcp.NewStreamableHTTPHandler( func(_ *http.Request) *mcp.Server { return server }, @@ -54,8 +64,12 @@ func NewHandler(server *mcp.Server, apiKey string) http.Handler { ) // Innermost: SDK handler wrapped with request-info extraction. infoWrapped := mcpRequestInfoMiddleware(handler) + // Above info: stash X-CM-Chat-Session into the request context so tool + // handlers can gate session-scoped operations to the calling chat + // container's own session. + sessionWrapped := chatSessionHeaderMiddleware(infoWrapped) // Middle: write-deadline clearing for long-lived GET SSE streams. 
- wrapped := clearWriteDeadlineForStreaming(infoWrapped) + wrapped := clearWriteDeadlineForStreaming(sessionWrapped) if apiKey == "" { return wrapped } @@ -63,6 +77,24 @@ func NewHandler(server *mcp.Server, apiKey string) http.Handler { return mcpAuthMiddleware(wrapped, apiKey) } +// chatSessionHeaderMiddleware reads the X-CM-Chat-Session header (forwarded by +// chat-container entrypoints) and stashes the value into the request context +// via mcpcontext.WithChatSession. Session-scoped MCP tools +// (chat_rehydration_complete) compare this against the in-RPC session_id to +// reject cross-session calls from a compromised or malicious caller. +// +// Empty header (card-mode worker, human curl) leaves the context untouched so +// existing flows are unaffected. +func chatSessionHeaderMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if h := r.Header.Get("X-CM-Chat-Session"); h != "" { + r = r.WithContext(mcpcontext.WithChatSession(r.Context(), h)) + } + + next.ServeHTTP(w, r) + }) +} + // mcpRequestInfoMiddleware reads the JSON-RPC body to populate the MCPCall // stored in context by the outer observe middleware. It is best-effort: any // read or parse error is swallowed so logging never breaks MCP traffic. 
diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go index 3fc0481c..7a01c332 100644 --- a/internal/mcp/server_test.go +++ b/internal/mcp/server_test.go @@ -100,7 +100,7 @@ func setupMCP(t *testing.T) *testEnv { } // Create MCP server and connect in-memory - server := NewServer(svc, workflowSkillsDir) + server := NewServer(svc, workflowSkillsDir, nil) ctx, cancel := context.WithCancel(context.Background()) diff --git a/internal/runner/client.go b/internal/runner/client.go index 8eae73b7..ae3960ee 100644 --- a/internal/runner/client.go +++ b/internal/runner/client.go @@ -105,6 +105,7 @@ type ContainerInfo struct { ContainerID string ContainerName string CardID string + SessionID string Project string State string StartedAt time.Time @@ -118,6 +119,7 @@ type containerInfoWire struct { ContainerID string `json:"container_id"` ContainerName string `json:"container_name,omitempty"` CardID string `json:"card_id"` + SessionID string `json:"session_id,omitempty"` Project string `json:"project"` State string `json:"state"` StartedAt string `json:"started_at"` @@ -227,6 +229,7 @@ func (c *Client) ListContainers(ctx context.Context) ([]ContainerInfo, error) { ContainerID: c.ContainerID, ContainerName: c.ContainerName, CardID: c.CardID, + SessionID: c.SessionID, Project: c.Project, State: c.State, StartedAt: started, diff --git a/internal/runner/endsession_test.go b/internal/runner/endsession_test.go index f7cc3af9..80f062ef 100644 --- a/internal/runner/endsession_test.go +++ b/internal/runner/endsession_test.go @@ -55,6 +55,7 @@ type fakeClient struct { killErr error listResult []runner.ContainerInfo listErr error + listCount int } func (f *fakeClient) EndSession(_ context.Context, p runner.EndSessionPayload) error { @@ -83,6 +84,8 @@ func (f *fakeClient) ListContainers(_ context.Context) ([]runner.ContainerInfo, f.mu.Lock() defer f.mu.Unlock() + f.listCount++ + if f.listErr != nil { return nil, f.listErr } @@ -93,6 +96,17 @@ func (f *fakeClient) 
ListContainers(_ context.Context) ([]runner.ContainerInfo, return out, nil } +// ListCount returns the number of ListContainers calls observed since the +// fake was created. Used by tests that need to assert a single sweep tick +// makes exactly one /containers round-trip to avoid hitting the runner's +// HMAC replay cache. +func (f *fakeClient) ListCount() int { + f.mu.Lock() + defer f.mu.Unlock() + + return f.listCount +} + func (f *fakeClient) Calls() []runner.EndSessionPayload { f.mu.Lock() defer f.mu.Unlock() diff --git a/internal/runner/reconcile.go b/internal/runner/reconcile.go index 6aeca804..251c730a 100644 --- a/internal/runner/reconcile.go +++ b/internal/runner/reconcile.go @@ -63,13 +63,20 @@ type CardLookup interface { // StartReconciliationSweep launches a ticker goroutine that periodically asks // the runner for every labeled container and decides, per container, whether -// it should still be running. A container is killed if: +// it should still be running. A card container is killed if: // // 1. CM has no card matching (project, card_id) — deleted or renamed out // from under the container. // 2. The card's state is terminal (done / not_planned) — the work is over. // 3. The container is older than ContainerMaxAge — runaway cap. // +// When chatMgr is non-nil the same tick also reconciles chat sessions: any +// CM-side active or warm-idle session whose SessionID is missing from the +// runner's container list is flipped to cold. Card and chat reconcile share +// the single /containers fetch per tick — splitting them across two tickers +// would produce two identically-signed HMAC GETs back to back and the runner's +// replay cache would reject the second as a duplicate. +// // Notably: the sweep does NOT consult the card's runner_status field. 
That // field is a CM-side bookkeeping convenience that has repeatedly drifted // away from Docker reality (the runner's reportCompleted/reportFailure @@ -80,7 +87,7 @@ type CardLookup interface { // // Blocks only until the goroutine is scheduled; returns immediately. An // interval of 0 disables the sweep entirely. -func StartReconciliationSweep(ctx context.Context, svc CardLookup, client ReconcileClient, interval time.Duration, logger *slog.Logger) { +func StartReconciliationSweep(ctx context.Context, svc CardLookup, chatMgr ChatReconciler, client ReconcileClient, interval time.Duration, logger *slog.Logger) { if logger == nil { logger = slog.Default() } @@ -109,7 +116,7 @@ func StartReconciliationSweep(ctx context.Context, svc CardLookup, client Reconc // Run an initial sweep immediately so containers orphaned by a CM // restart (events published while CM was down are never delivered) // are cleaned up without having to wait a full interval. - runReconcileSweep(ctx, svc, client, maxAge, logger) + runReconcileSweep(ctx, svc, chatMgr, client, maxAge, logger) ticker := time.NewTicker(interval) defer ticker.Stop() @@ -120,21 +127,21 @@ func StartReconciliationSweep(ctx context.Context, svc CardLookup, client Reconc return case <-ticker.C: - runReconcileSweep(ctx, svc, client, maxAge, logger) + runReconcileSweep(ctx, svc, chatMgr, client, maxAge, logger) } } }() } -// runReconcileSweep asks the runner for its current container list and kills -// every container whose card says it should no longer be running. Safe to -// call ad-hoc from tests. +// runReconcileSweep asks the runner for its current container list once, +// then runs both the card-kill loop and (if chatMgr != nil) the chat +// orphan reconcile against that single list. Safe to call ad-hoc from +// tests. // // Every tick logs scanned/killed — including 0/0 — so "is the sweep actually -// running?" is answerable from a single grep. 
The old implementation only -// logged when killed > 0, which made debugging a stuck container impossible -// without adding ad-hoc instrumentation. -func runReconcileSweep(ctx context.Context, svc CardLookup, client ReconcileClient, maxAge time.Duration, logger *slog.Logger) { +// running?" is answerable from a single grep. The chat reconcile emits its +// own tick log line when chatMgr is non-nil. +func runReconcileSweep(ctx context.Context, svc CardLookup, chatMgr ChatReconciler, client ReconcileClient, maxAge time.Duration, logger *slog.Logger) { containers, err := client.ListContainers(ctx) if err != nil { logger.Warn("reconcile sweep: runner list failed (skipping tick)", "error", err) @@ -167,6 +174,10 @@ func runReconcileSweep(ctx context.Context, svc CardLookup, client ReconcileClie "scanned", len(containers), "killed", killed, ) + + if chatMgr != nil { + reconcileChatSessions(ctx, chatMgr, containers, logger) + } } // decideKill runs the three-rule authoritative check against a single @@ -179,6 +190,16 @@ func runReconcileSweep(ctx context.Context, svc CardLookup, client ReconcileClie // time, passed through explicitly instead of re-read from the package var so // a concurrent test mutation cannot race the sweep's per-tick check. func decideKill(ctx context.Context, svc CardLookup, c ContainerInfo, maxAge time.Duration, logger *slog.Logger) (string, bool) { + // Chat containers are reconciled by RunChatReconcileSweep against + // chat.Manager's session store — not by the card sweep. Routing them + // through decideKill would fire /end-session with an empty CardID, + // which the runner rejects with HTTP 400. The chat sweep's input is + // CM-authoritative (chat DB) cross-referenced with the same + // /containers list, so chat orphans are not the card sweep's problem. 
+ if c.CardID == "" { + return "", false + } + // KB-refresh containers use a synthetic card_id and are managed by the // internal/refresh registry janitor, not by the card-state sweep — they // also have the runner's own container_timeout as a last-resort cap. The @@ -254,3 +275,102 @@ func truncate(id string) string { return id } + +// ChatSessionRef is the subset of chat.Session the reconcile sweep reads. +// Keeping the type local to the runner package means reconcile.go does not +// import the chat package — the wiring in main.go adapts chat.Manager to +// the ChatReconciler interface. +type ChatSessionRef struct { + ID string + Status string +} + +// ChatReconciler is the chat-manager surface the chat reconcile sweep needs: +// enumerate non-cold sessions and flip orphans to cold. The interface keeps +// reconcile.go decoupled from chat.Manager's full type surface and lets tests +// inject a small fake. +type ChatReconciler interface { + // ListActiveChatSessions returns every chat session whose CM-side + // status is active or warm-idle. Cold and ending sessions are excluded + // — cold has no runner container by definition, and ending is a + // transient transition the sweep should not race against. + ListActiveChatSessions(ctx context.Context) ([]ChatSessionRef, error) + // EndChatSession flips a session to cold and clears its container_id, + // matching the user-initiated End Session path. Idempotent on cold. + EndChatSession(ctx context.Context, id string) error +} + +// RunChatReconcileSweep is the standalone chat-reconcile entrypoint kept for +// unit tests that want to exercise the chat path without standing up a card +// sweep. Production wiring goes through StartReconciliationSweep, which folds +// the chat reconcile into the same tick that drives card reconcile so both +// share one /containers fetch. 
+// +// A failed /containers call skips the tick — better to leave live sessions +// alone than to flip every one to cold because the runner briefly couldn't +// answer. +func RunChatReconcileSweep(ctx context.Context, chatMgr ChatReconciler, client ContainerLister, logger *slog.Logger) { + if chatMgr == nil || client == nil { + return + } + + if logger == nil { + logger = slog.Default() + } + + containers, err := client.ListContainers(ctx) + if err != nil { + logger.Warn("chat reconcile sweep: runner list failed (skipping tick)", "error", err) + + return + } + + reconcileChatSessions(ctx, chatMgr, containers, logger) +} + +// reconcileChatSessions cross-references CM's active/warm-idle chat sessions +// with a pre-fetched runner container list. Any CM session whose SessionID is +// missing from the list is treated as an orphan and flipped to cold via +// EndChatSession. Silent on the happy path. Per-session EndChatSession +// errors are logged but do not abort the rest of the sweep. +func reconcileChatSessions(ctx context.Context, chatMgr ChatReconciler, containers []ContainerInfo, logger *slog.Logger) { + runnerHas := make(map[string]bool, len(containers)) + + for _, c := range containers { + if c.SessionID != "" { + runnerHas[c.SessionID] = true + } + } + + sessions, err := chatMgr.ListActiveChatSessions(ctx) + if err != nil { + logger.Warn("chat reconcile sweep: list sessions failed (skipping tick)", "error", err) + + return + } + + var orphaned int + + for _, s := range sessions { + if runnerHas[s.ID] { + continue + } + + logger.Warn("chat reconcile sweep: orphan session, flipping to cold", + "session_id", s.ID, "from_status", s.Status) + + if err := chatMgr.EndChatSession(ctx, s.ID); err != nil { + logger.Warn("chat reconcile sweep: EndChatSession failed", + "session_id", s.ID, "error", err) + + continue + } + + orphaned++ + } + + logger.Info("chat reconcile sweep tick", + "scanned", len(sessions), + "orphaned", orphaned, + ) +} diff --git 
a/internal/runner/reconcile_test.go b/internal/runner/reconcile_test.go index 9416a2bb..ca173e15 100644 --- a/internal/runner/reconcile_test.go +++ b/internal/runner/reconcile_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "sync" "testing" "time" @@ -45,7 +46,7 @@ func TestReconciliationSweep_TerminalCardKillsContainer(t *testing.T) { }, } - runner.StartReconciliationSweep(ctx, cg, fc, 30*time.Millisecond, discardLogger()) + runner.StartReconciliationSweep(ctx, cg, nil, fc, 30*time.Millisecond, discardLogger()) waitForKillCalls(t, fc, 1) @@ -77,7 +78,7 @@ func TestReconciliationSweep_SkipsNonTerminalCard(t *testing.T) { }, } - runner.StartReconciliationSweep(ctx, cg, fc, 30*time.Millisecond, discardLogger()) + runner.StartReconciliationSweep(ctx, cg, nil, fc, 30*time.Millisecond, discardLogger()) time.Sleep(150 * time.Millisecond) assert.Empty(t, fc.KillCalls(), "sweep must not kill in-progress card's container") @@ -103,7 +104,7 @@ func TestReconciliationSweep_MissingCardKillsContainer(t *testing.T) { }, } - runner.StartReconciliationSweep(ctx, cg, fc, 30*time.Millisecond, discardLogger()) + runner.StartReconciliationSweep(ctx, cg, nil, fc, 30*time.Millisecond, discardLogger()) waitForKillCalls(t, fc, 1) @@ -137,7 +138,7 @@ func TestReconciliationSweep_SkipsKBRefreshContainer(t *testing.T) { }, } - runner.StartReconciliationSweep(ctx, cg, fc, 30*time.Millisecond, discardLogger()) + runner.StartReconciliationSweep(ctx, cg, nil, fc, 30*time.Millisecond, discardLogger()) time.Sleep(150 * time.Millisecond) assert.Empty(t, fc.KillCalls(), @@ -173,7 +174,7 @@ func TestReconciliationSweep_KBRefreshSkipsAgeCap(t *testing.T) { }, } - runner.StartReconciliationSweep(ctx, cg, fc, 30*time.Millisecond, discardLogger()) + runner.StartReconciliationSweep(ctx, cg, nil, fc, 30*time.Millisecond, discardLogger()) time.Sleep(150 * time.Millisecond) assert.Empty(t, fc.KillCalls(), @@ -208,7 +209,7 @@ func TestReconciliationSweep_AgeCapKillsRunawayContainer(t *testing.T) { 
}, } - runner.StartReconciliationSweep(ctx, cg, fc, 30*time.Millisecond, discardLogger()) + runner.StartReconciliationSweep(ctx, cg, nil, fc, 30*time.Millisecond, discardLogger()) waitForKillCalls(t, fc, 1) } @@ -228,7 +229,7 @@ func TestReconciliationSweep_ZeroIntervalDisabled(t *testing.T) { }, } - runner.StartReconciliationSweep(ctx, cg, fc, 0, discardLogger()) + runner.StartReconciliationSweep(ctx, cg, nil, fc, 0, discardLogger()) time.Sleep(100 * time.Millisecond) assert.Empty(t, fc.KillCalls(), "sweep must be a no-op at interval=0") @@ -253,7 +254,7 @@ func TestReconciliationSweep_RunsImmediatelyOnStart(t *testing.T) { // Interval well above the assertion deadline — if the first sweep waits // for the ticker, waitForKillCalls will time out. - runner.StartReconciliationSweep(ctx, cg, fc, 10*time.Second, discardLogger()) + runner.StartReconciliationSweep(ctx, cg, nil, fc, 10*time.Second, discardLogger()) waitForKillCalls(t, fc, 1) } @@ -268,7 +269,7 @@ func TestReconciliationSweep_RunnerListFailureSkipsTick(t *testing.T) { cg := &fakeCardGetter{cards: map[string]*board.Card{}} fc := &fakeClient{listErr: errors.New("runner unreachable")} - runner.StartReconciliationSweep(ctx, cg, fc, 30*time.Millisecond, discardLogger()) + runner.StartReconciliationSweep(ctx, cg, nil, fc, 30*time.Millisecond, discardLogger()) time.Sleep(150 * time.Millisecond) // Not kill and not panic — the ListContainers error just skips the tick. 
@@ -283,7 +284,7 @@ func TestReconciliationSweep_MissingClient_NoPanic(t *testing.T) { cg := &fakeCardGetter{} - runner.StartReconciliationSweep(ctx, cg, nil, 30*time.Millisecond, discardLogger()) + runner.StartReconciliationSweep(ctx, cg, nil, nil, 30*time.Millisecond, discardLogger()) time.Sleep(50 * time.Millisecond) } @@ -303,7 +304,7 @@ func TestReconciliationSweep_TransientCardErrorLeavesContainerAlone(t *testing.T }, } - runner.StartReconciliationSweep(ctx, cg, fc, 30*time.Millisecond, discardLogger()) + runner.StartReconciliationSweep(ctx, cg, nil, fc, 30*time.Millisecond, discardLogger()) time.Sleep(150 * time.Millisecond) assert.Empty(t, fc.KillCalls(), "transient card-store error must not trigger a kill") @@ -330,7 +331,7 @@ func TestReconciliationSweep_StorageNotFoundErrorIsKill(t *testing.T) { }, } - runner.StartReconciliationSweep(ctx, cg, fc, 30*time.Millisecond, discardLogger()) + runner.StartReconciliationSweep(ctx, cg, nil, fc, 30*time.Millisecond, discardLogger()) waitForKillCalls(t, fc, 1) } @@ -350,7 +351,209 @@ func TestReconciliationSweep_WrappedStorageNotFoundErrorIsKill(t *testing.T) { }, } - runner.StartReconciliationSweep(ctx, cg, fc, 30*time.Millisecond, discardLogger()) + runner.StartReconciliationSweep(ctx, cg, nil, fc, 30*time.Millisecond, discardLogger()) waitForKillCalls(t, fc, 1) } + +// TestReconciliationSweep_SkipsChatContainers guards the boundary between the +// card-mode sweep and the chat-mode sweep: after Wave 2.2, /containers also +// reports chat containers (LabelSessionID, no LabelCardID). The card sweep +// must skip those rows — calling decideKill on a chat container with an empty +// CardID would route a malformed /end-session against the runner. 
+func TestReconciliationSweep_SkipsChatContainers(t *testing.T) { + ctx := t.Context() + + cg := &fakeCardGetter{cards: map[string]*board.Card{}} + fc := &fakeClient{ + listResult: []runner.ContainerInfo{ + { + ContainerID: "chat-ctr-1", + SessionID: "S-active", + Project: "proj", + State: "running", + StartedAt: time.Now().Add(-5 * time.Minute), + }, + }, + } + + runner.StartReconciliationSweep(ctx, cg, nil, fc, 30*time.Millisecond, discardLogger()) + + time.Sleep(150 * time.Millisecond) + assert.Empty(t, fc.KillCalls(), + "card sweep must skip chat containers (those carry SessionID, not CardID)") + assert.Empty(t, fc.Calls(), + "card sweep must not call /end-session for chat containers either") +} + +// fakeChatReconciler implements runner.ChatReconciler for tests. +type fakeChatReconciler struct { + mu sync.Mutex + active []chatSessionStub + warm []chatSessionStub + ended []string + endError error +} + +// chatSessionStub mirrors the subset of chat.Session that the reconcile +// sweep reads — using a local stub keeps the test file independent of the +// chat package's full type surface. 
+type chatSessionStub struct { + ID string + Status string +} + +func (f *fakeChatReconciler) ListActiveChatSessions(_ context.Context) ([]runner.ChatSessionRef, error) { + f.mu.Lock() + defer f.mu.Unlock() + + out := make([]runner.ChatSessionRef, 0, len(f.active)+len(f.warm)) + for _, s := range f.active { + out = append(out, runner.ChatSessionRef{ID: s.ID, Status: s.Status}) + } + + for _, s := range f.warm { + out = append(out, runner.ChatSessionRef{ID: s.ID, Status: s.Status}) + } + + return out, nil +} + +func (f *fakeChatReconciler) EndChatSession(_ context.Context, id string) error { + f.mu.Lock() + defer f.mu.Unlock() + + f.ended = append(f.ended, id) + + return f.endError +} + +func (f *fakeChatReconciler) endedCalls() []string { + f.mu.Lock() + defer f.mu.Unlock() + + out := make([]string, len(f.ended)) + copy(out, f.ended) + + return out +} + +// TestChatReconcileSweep_FlipsOrphanToCold is the central guarantee of Wave 2.3: +// if CM thinks a chat session is active or warm-idle but the runner has no +// container for it, the sweep flips that session to cold. Without this, +// stranded sessions persist forever after a runner restart or crash. 
+func TestChatReconcileSweep_FlipsOrphanToCold(t *testing.T) { + ctx := t.Context() + + fcr := &fakeChatReconciler{ + active: []chatSessionStub{ + {ID: "S-live", Status: "active"}, + {ID: "S-orphan", Status: "active"}, + }, + warm: []chatSessionStub{ + {ID: "S-warm-orphan", Status: "warm-idle"}, + }, + } + fc := &fakeClient{ + listResult: []runner.ContainerInfo{ + { + ContainerID: "card-ctr", + CardID: "C-001", + Project: "proj", + State: "running", + StartedAt: time.Now(), + }, + { + ContainerID: "chat-ctr", + SessionID: "S-live", + Project: "proj", + State: "running", + StartedAt: time.Now(), + }, + }, + } + + runner.RunChatReconcileSweep(ctx, fcr, fc, discardLogger()) + + ended := fcr.endedCalls() + require.ElementsMatch(t, []string{"S-orphan", "S-warm-orphan"}, ended, + "both orphan sessions must be ended; the one with a live runner container must be left alone") +} + +// TestChatReconcileSweep_NoOpWhenAllSessionsHaveContainers confirms that when +// every active/warm session matches a runner container, no EndChatSession +// calls fire. Reconcile must be silent on the happy path. +func TestChatReconcileSweep_NoOpWhenAllSessionsHaveContainers(t *testing.T) { + ctx := t.Context() + + fcr := &fakeChatReconciler{ + active: []chatSessionStub{ + {ID: "S-1", Status: "active"}, + {ID: "S-2", Status: "active"}, + }, + } + fc := &fakeClient{ + listResult: []runner.ContainerInfo{ + {ContainerID: "c1", SessionID: "S-1", Project: "proj", State: "running", StartedAt: time.Now()}, + {ContainerID: "c2", SessionID: "S-2", Project: "proj", State: "running", StartedAt: time.Now()}, + }, + } + + runner.RunChatReconcileSweep(ctx, fcr, fc, discardLogger()) + + assert.Empty(t, fcr.endedCalls(), + "happy path: every CM session has a runner container, no end calls expected") +} + +// TestChatReconcileSweep_RunnerListErrorSkipsTick guards the safety property: +// a transient /containers failure must NOT cause CM to flip every chat +// session to cold. 
Better to skip a tick than to nuke live sessions. +func TestChatReconcileSweep_RunnerListErrorSkipsTick(t *testing.T) { + ctx := t.Context() + + fcr := &fakeChatReconciler{ + active: []chatSessionStub{{ID: "S-1", Status: "active"}}, + } + fc := &fakeClient{ + listErr: errors.New("runner unreachable"), + } + + runner.RunChatReconcileSweep(ctx, fcr, fc, discardLogger()) + + assert.Empty(t, fcr.endedCalls(), + "runner-list error must skip the tick, not flip every session to cold") +} + +// TestReconciliationSweep_SingleContainersFetchPerTick is the regression guard +// for the HMAC replay-cache 409 we hit in dev: card and chat sweeps were +// firing simultaneously on two tickers, each calling /containers with the +// same signed payload — the runner's replay cache rejected the second one +// as "duplicate request". Both reconcilers must now share a single +// ListContainers round-trip per tick. +func TestReconciliationSweep_SingleContainersFetchPerTick(t *testing.T) { + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + cg := &fakeCardGetter{cards: map[string]*board.Card{}} + fcr := &fakeChatReconciler{ + active: []chatSessionStub{{ID: "S-1", Status: "active"}}, + } + fc := &fakeClient{ + listResult: []runner.ContainerInfo{ + {ContainerID: "chat-ctr", SessionID: "S-1", Project: "proj", State: "running", StartedAt: time.Now()}, + }, + } + + // Long interval: only the initial tick fires within the test window, + // so a stable assertion on ListCount==1 is possible. If both sweeps + // fetched separately the count would be 2. 
+ runner.StartReconciliationSweep(ctx, cg, fcr, fc, time.Hour, discardLogger()) + + require.Eventually(t, func() bool { + return fc.ListCount() >= 1 + }, time.Second, 5*time.Millisecond, "expected initial sweep tick") + + time.Sleep(50 * time.Millisecond) + assert.Equal(t, 1, fc.ListCount(), + "single tick must fetch /containers exactly once; separate fetches hit the runner's HMAC replay cache as duplicates") +} diff --git a/test/integration/client_test.go b/test/integration/client_test.go index 736577dd..162862a6 100644 --- a/test/integration/client_test.go +++ b/test/integration/client_test.go @@ -126,6 +126,60 @@ func (c *cmClient) getCard(t *testing.T, project, cardID string) cardSnapshot { return card } +func (c *cmClient) patch(t *testing.T, path string, body any, into any) (int, string) { + t.Helper() + + var buf bytes.Buffer + if body != nil { + if err := json.NewEncoder(&buf).Encode(body); err != nil { + t.Fatalf("patch encode %s: %v", path, err) + } + } + + req, err := http.NewRequest(http.MethodPatch, c.baseURL+path, &buf) + if err != nil { + t.Fatalf("patch req %s: %v", path, err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Agent-ID", "human:harness") + req.Header.Set("X-Requested-With", "contextmatrix") + + resp, err := c.hc.Do(req) + if err != nil { + t.Fatalf("patch do %s: %v", path, err) + } + defer resp.Body.Close() + + raw, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + if into != nil && resp.StatusCode < 400 && len(raw) > 0 { + if err := json.Unmarshal(raw, into); err != nil { + t.Fatalf("patch decode %s: %v body=%s", path, err, raw) + } + } + + return resp.StatusCode, string(raw) +} + +func (c *cmClient) deleteReq(t *testing.T, path string) int { + t.Helper() + + req, err := http.NewRequest(http.MethodDelete, c.baseURL+path, nil) + if err != nil { + t.Fatalf("delete req %s: %v", path, err) + } + req.Header.Set("X-Agent-ID", "human:harness") + req.Header.Set("X-Requested-With", "contextmatrix") + + resp, err 
:= c.hc.Do(req) + if err != nil { + t.Fatalf("delete do %s: %v", path, err) + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + return resp.StatusCode +} + // pollUntil retries fn until it returns true or the deadline expires. func pollUntil(ctx context.Context, t *testing.T, label string, fn func() bool) { t.Helper() diff --git a/test/integration/scenarios_test.go b/test/integration/scenarios_test.go index 834a35ab..29269fb2 100644 --- a/test/integration/scenarios_test.go +++ b/test/integration/scenarios_test.go @@ -4,6 +4,7 @@ package integration_test import ( "context" + "net/http" "testing" "time" ) @@ -16,6 +17,7 @@ func TestIntegrationHarness(t *testing.T) { t.Run("HeartbeatTimeout", testHeartbeatTimeoutStub) t.Run("PromoteHITLToAuto", testPromoteHITLToAutoStub) t.Run("IdleWatchdog", testIdleWatchdogStub) + t.Run("Chat", testChatStub) } func testAutonomousStub(t *testing.T) { @@ -185,3 +187,89 @@ func testIdleWatchdogStub(t *testing.T) { return len(dockerListByScenario(scenarioID)) == 0 }) } + +// testChatStub validates the global-chat REST API end-to-end (create, get, +// list, patch, delete) against a real CM binary with a live SQLite store. +// +// What this covers: router wiring, handler logic, SQLite persistence, and +// correct HTTP status codes for all five chat lifecycle operations. +// +// What this defers (follow-up work): +// - Sending a message and receiving SSE events — requires the stub-worker to +// accept /chat/start and the SSE bridge to pump events through. +// - Reopen flow (cold → active → cold) — same dependency on the runner stub. +func testChatStub(t *testing.T) { + scenarioID := "chat" + project := "harness" + + s := bootScenario(t, scenarioID, project) + + // Create chat — no project field means cross-project / no container clone. 
+ var created struct { + ID string `json:"id"` + Title string `json:"title"` + Status string `json:"status"` + } + status, body := s.client.postRaw(t, "/api/chats", map[string]any{"title": "smoke"}, &created) + if status != http.StatusCreated { + t.Fatalf("create chat: HTTP %d body=%s", status, body) + } + if created.ID == "" || created.Status != "cold" { + t.Fatalf("create chat returned unexpected payload: %+v", created) + } + + // Get chat by ID. + var got struct { + ID string `json:"id"` + Status string `json:"status"` + Title string `json:"title"` + } + statusGet := s.client.get(t, "/api/chats/"+created.ID, &got) + if statusGet != http.StatusOK { + t.Fatalf("get chat: HTTP %d", statusGet) + } + if got.ID != created.ID { + t.Fatalf("get chat id mismatch: %s vs %s", got.ID, created.ID) + } + + // List chats — the newly created one must appear. + var list []map[string]any + statusList := s.client.get(t, "/api/chats", &list) + if statusList != http.StatusOK { + t.Fatalf("list chats: HTTP %d", statusList) + } + found := false + for _, c := range list { + if c["id"] == created.ID { + found = true + break + } + } + if !found { + t.Fatalf("list chats: created id %s not in list", created.ID) + } + + // PATCH title. + var patched struct { + Title string `json:"title"` + } + statusPatch, patchBody := s.client.patch(t, "/api/chats/"+created.ID, map[string]any{"title": "renamed"}, &patched) + if statusPatch != http.StatusOK { + t.Fatalf("patch chat: HTTP %d body=%s", statusPatch, patchBody) + } + if patched.Title != "renamed" { + t.Fatalf("patch chat: title not updated: %q", patched.Title) + } + + // DELETE chat. + statusDel := s.client.deleteReq(t, "/api/chats/"+created.ID) + if statusDel != http.StatusNoContent { + t.Fatalf("delete chat: HTTP %d", statusDel) + } + + // GET after delete must return 404. 
+ statusGone := s.client.get(t, "/api/chats/"+created.ID, nil) + if statusGone != http.StatusNotFound { + t.Fatalf("get deleted chat: expected 404 got %d", statusGone) + } +} From acebc99439b87f4bc040f119489adc8ce72f5573 Mon Sep 17 00:00:00 2001 From: Morten Hersson Date: Fri, 15 May 2026 06:49:27 +0200 Subject: [PATCH 3/6] feat(web): multi-pane chat UI with persistent layout - Up to 4 simultaneously open chats with LRU eviction toast - Drag-and-drop from sidebar onto pane drop overlays - Per-pane resize, focus, accent stripe from idColor hash - Tab title stays at ContextMatrix; last_chat_id persisted per pane - Mobile single-pane mode with the full ChatThread header --- web/CLAUDE.md | 122 +++- web/eslint.config.js | 6 + web/package-lock.json | 11 + web/package.json | 1 + web/src/App.tsx | 6 + web/src/api/client.ts | 74 +++ web/src/components/CardPanel/CardChat.tsx | 337 ++--------- web/src/components/ChatLayout/ChatLayout.tsx | 185 ++++++ .../ChatLayout/ChatLayoutContext.tsx | 27 + web/src/components/ChatLayout/ChatPane.tsx | 103 ++++ .../components/ChatLayout/EmptyPanePicker.tsx | 90 +++ .../ChatLayout/PaneAccentStripe.tsx | 12 + web/src/components/ChatLayout/PaneHeader.tsx | 186 ++++++ web/src/components/ChatLayout/index.ts | 18 + web/src/components/ChatLayout/types.ts | 26 + .../components/ChatPanel/ChatPanel.test.tsx | 45 ++ web/src/components/ChatPanel/ChatPanel.tsx | 282 +++++++++ web/src/components/ChatPanel/index.ts | 2 + web/src/components/RunnerConsole/logUtils.ts | 17 +- .../components/Sidebar/ChatSection.test.tsx | 77 +++ web/src/components/Sidebar/ChatSection.tsx | 153 +++++ web/src/components/Sidebar/Sidebar.test.tsx | 6 +- web/src/components/Sidebar/Sidebar.tsx | 6 +- web/src/hooks/useChatLayout.test.tsx | 333 +++++++++++ web/src/hooks/useChatLayout.ts | 406 +++++++++++++ web/src/hooks/useChatLiveData.ts | 56 ++ web/src/hooks/useChatSessions.test.tsx | 36 ++ web/src/hooks/useChatSessions.ts | 68 +++ web/src/hooks/useChatStream.test.tsx | 144 +++++ 
web/src/hooks/useChatStream.ts | 160 ++++++ web/src/index.css | 544 ++++++++++++++++++ web/src/pages/Chat/ChatHeaderInfo.tsx | 60 ++ web/src/pages/Chat/ChatPage.tsx | 229 ++++++++ web/src/pages/Chat/ChatThread.tsx | 335 +++++++++++ web/src/pages/Chat/NewChatDialog.tsx | 243 ++++++++ web/src/types/index.ts | 48 ++ web/src/utils/chatModels.ts | 61 ++ web/src/utils/colorHash.ts | 15 + 38 files changed, 4230 insertions(+), 300 deletions(-) create mode 100644 web/src/components/ChatLayout/ChatLayout.tsx create mode 100644 web/src/components/ChatLayout/ChatLayoutContext.tsx create mode 100644 web/src/components/ChatLayout/ChatPane.tsx create mode 100644 web/src/components/ChatLayout/EmptyPanePicker.tsx create mode 100644 web/src/components/ChatLayout/PaneAccentStripe.tsx create mode 100644 web/src/components/ChatLayout/PaneHeader.tsx create mode 100644 web/src/components/ChatLayout/index.ts create mode 100644 web/src/components/ChatLayout/types.ts create mode 100644 web/src/components/ChatPanel/ChatPanel.test.tsx create mode 100644 web/src/components/ChatPanel/ChatPanel.tsx create mode 100644 web/src/components/ChatPanel/index.ts create mode 100644 web/src/components/Sidebar/ChatSection.test.tsx create mode 100644 web/src/components/Sidebar/ChatSection.tsx create mode 100644 web/src/hooks/useChatLayout.test.tsx create mode 100644 web/src/hooks/useChatLayout.ts create mode 100644 web/src/hooks/useChatLiveData.ts create mode 100644 web/src/hooks/useChatSessions.test.tsx create mode 100644 web/src/hooks/useChatSessions.ts create mode 100644 web/src/hooks/useChatStream.test.tsx create mode 100644 web/src/hooks/useChatStream.ts create mode 100644 web/src/pages/Chat/ChatHeaderInfo.tsx create mode 100644 web/src/pages/Chat/ChatPage.tsx create mode 100644 web/src/pages/Chat/ChatThread.tsx create mode 100644 web/src/pages/Chat/NewChatDialog.tsx create mode 100644 web/src/utils/chatModels.ts create mode 100644 web/src/utils/colorHash.ts diff --git a/web/CLAUDE.md b/web/CLAUDE.md 
index 9b6be2c4..f3003841 100644 --- a/web/CLAUDE.md +++ b/web/CLAUDE.md @@ -20,8 +20,9 @@ `docs/gotchas.md`). - `vite.config.ts` must proxy `/api` → `http://localhost:8080` for dev mode. - No `localStorage` usage except: theme preference, palette preference - (`palette` key), human agent ID, last selected project, collapsed - column/card state. + (`palette` key), human agent ID, last selected project, last chat id + (`last_chat_id`), chat section collapse (`sidebar.chat_section_collapsed`), + multi-pane chat layout (`chat_layout`), collapsed column/card state. - Theme state is managed via `ThemeProvider` (in `web/src/hooks/useTheme.ts`) wrapping the app root. Components consume it with `useTheme()`. The markdown editor (`@uiw/react-md-editor`) receives `data-color-mode={theme}` so it @@ -278,6 +279,123 @@ rule, which forbids reading or writing refs during render. Mounting into an already-HITL card lands on the `chat` tab and starts with the rail expanded via the initial `useState(isHITLRunning)` — no transition needed. +## Global Chat tab + +The Chat tab (`/chat`, `/chat/:id`) hosts long-lived chat sessions distinct +from card-scoped HITL chats. `useChatStream(sessionID)` owns transcript +state for the active session via `useRingBuffer(2000)`. The hook pairs the +SSE `/api/chats/{id}/stream` subscription with a REST bootstrap from +`GET /api/chats/{id}/messages?since_seq=0&limit=1000`: + +1. On mount or `sessionID` change, the buffer is cleared and the REST + bootstrap is fetched first. +2. The last bootstrap `seq` is recorded; the SSE subscription opens with + `since_seq=` so it only delivers strictly newer events. +3. Replay overlap (SSE events whose `seq` falls inside the REST window) is + deduped on the client — the seam is gapless without double messages. + +`last_chat_id` localStorage key tracks the focused pane's chat. 
In the +multi-pane layout (see next section), `useChatLayout` writes the key +whenever focus moves; `ChatThread` only writes it in non-embedded (mobile +single-pane) mode. This preserves backward compat with external readers +that expect a single "current chat" pointer. + +## Multi-pane chat layout + +`/chat` renders a tiled layout of up to 4 simultaneously open chats. The +shell is `ChatLayout` (`web/src/components/ChatLayout/`), composed with +`react-resizable-panels` `PanelGroup`s for the column + row splits. State +lives in `useChatLayout` (`web/src/hooks/useChatLayout.ts`) and is exposed +to descendants via `ChatLayoutContext`. + +### Layout model + +Panes are addressed by `Slot` ('TL' | 'BL' | 'TR' | 'BR'). The hook +normalizes the layout so: + +- 1 pane → only `TL` is occupied (full-width). +- 2 panes → `TL` + `TR` (vertical split). +- 3 panes → either left or right column has a horizontal split (whichever + column held the focused pane when the 3rd pane was opened). +- 4 panes → 2×2 grid. + +Closing a pane runs `normalize()` to collapse the layout (e.g. closing +`TL` promotes `BL → TL`). Column- and row-percentages persist as +`{ col, leftRow, rightRow }` and are clamped 20–80% by the resizable- +panels library. + +### Mutations + +- `openInNewPane(id)` / `openInFocused(id)` — same implementation in v1: + sidebar clicks always auto-tile into a new pane (per the captured build + prompt — no `Cmd`-click distinction). If the chat is already open in + another pane, focus that pane instead of opening a duplicate. +- `swapPaneChat(slot, id)` — drop semantics. If the dropped chat is + already in another pane, the two panes' contents **swap** (same-chat- + twice = swap). If not, the target pane's contents are replaced. +- `splitFromPane(slot)` / `cancelEmptyPane(slot)` — manual "+ split" + button creates an empty pane with a "Pick a chat" picker; Esc cancels. 
+- `closePane(slot)` — removes the pane (the chat session itself is **not** + deleted; the End / Reopen / Delete actions live on `ChatThread`'s + non-embedded header and are reachable on mobile or by closing other + panes down to one). +- `focus(slot)` — stamps `lastFocusedAt[slot] = Date.now()` for LRU. + +### 5th-chat policy: LRU eviction with undo + +When `openInNewPane` is called and 4 panes are open, the hook evicts the +pane with the smallest `lastFocusedAt` stamp, calls `onLRUEvict({ +victimSlot, victimChatId, incomingChatId, snapshot })`, and `ChatPage` +shows a `chat-evict-toast` (bottom-center, 6s) with an Undo button. Undo +calls `restoreSnapshot(snapshot)` to atomically revert. + +### Persistence + +- `localStorage.chat_layout`: `{ panes, focused, sizes, lastFocusedAt }`, + debounced 300ms. On mount, `loadPersisted()` filters persisted chat IDs + against the current `availableChats` list (dropping stale ids). +- `localStorage.last_chat_id`: written by `useChatLayout` whenever focus + moves (focused pane's chat id only). +- Server-side deletes are reconciled via an effect that watches + `availableChats`: ids no longer in the list are removed from panes. + +### Drag-and-drop from the sidebar + +`ChatSection` lives outside `ChatLayoutProvider`'s subtree (the sidebar +renders above the route outlet). To let pane drop-overlays show the +incoming chat name, `ChatSection` dispatches `cm:chat-drag-start` / +`cm:chat-drag-end` custom events; `ChatPage` listens and forwards to +`layout.setDragging(...)`. Touch devices skip `draggable=true` to avoid +hijacking scroll gestures (`!isTouchDevice()` guard). + +### Routing + +- `/chat` — hydrates the layout from `chat_layout`, renders `ChatLayout`. +- `/chat/:id` — **additive** deep link. The id is opened as a new pane on + top of the hydrated layout (LRU evicts the 5th), then `ChatPage` + redirects to `/chat` so refresh doesn't re-trigger. 
Uses the in-render + state-marker pattern (`prevDeepLinkId !== deepLinkId`) — not `useEffect` + — so the navigate happens synchronously with the prop change. +- `/chat?new=1` — opens `NewChatDialog`. + +### Mobile (`< 768px`) + +`useMediaQuery('(min-width: 768px)')` toggles single-pane mode. The hook's +state persists across resizes; only one pane (the focused one) is +rendered. Sidebar drag is disabled on touch devices. The full +`ChatThread` (with its End / Reopen / Delete header) is rendered +*non-embedded* on mobile so all chat actions stay reachable. + +### Visual tokens + +All chat-pane CSS lives at the end of `web/src/index.css` under the +"Multi-pane chat layout" header. Mirrors CardPanel: 36px header, mono +font, `.bf-rail-tab` typography, `--bg0` body, `--bg1` header bg, +`--bg3` borders, `--aqua` focused-state glow + resize-handle hover. +Per-chat 2px accent stripe colored by `idColor(chatId)` from +`web/src/utils/colorHash.ts` (shared with RunnerConsole). Drop target +uses the **static glow** variant (no pulse animation). 
+ ## Runner Console The Runner Console is a live log panel that streams output from diff --git a/web/eslint.config.js b/web/eslint.config.js index 5e6b472f..ea570ebe 100644 --- a/web/eslint.config.js +++ b/web/eslint.config.js @@ -19,5 +19,11 @@ export default defineConfig([ ecmaVersion: 2020, globals: globals.browser, }, + rules: { + '@typescript-eslint/no-unused-vars': [ + 'error', + { argsIgnorePattern: '^_', varsIgnorePattern: '^_', caughtErrorsIgnorePattern: '^_' }, + ], + }, }, ]) diff --git a/web/package-lock.json b/web/package-lock.json index 8f32ad24..7c808e3f 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -15,6 +15,7 @@ "@uiw/react-md-editor": "^4.1.0", "react": "^19.2.6", "react-dom": "^19.2.6", + "react-resizable-panels": "^3.0.6", "react-router-dom": "^7.15.0", "tailwindcss": "^4.2.2" }, @@ -5151,6 +5152,16 @@ "react": ">=18" } }, + "node_modules/react-resizable-panels": { + "version": "3.0.6", + "resolved": "https://registry.npmjs.org/react-resizable-panels/-/react-resizable-panels-3.0.6.tgz", + "integrity": "sha512-b3qKHQ3MLqOgSS+FRYKapNkJZf5EQzuf6+RLiq1/IlTHw99YrZ2NJZLk4hQIzTnnIkRg2LUqyVinu6YWWpUYew==", + "license": "MIT", + "peerDependencies": { + "react": "^16.14.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 || ^19.0.0-rc", + "react-dom": "^16.14.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 || ^19.0.0-rc" + } + }, "node_modules/react-router": { "version": "7.15.0", "resolved": "https://registry.npmjs.org/react-router/-/react-router-7.15.0.tgz", diff --git a/web/package.json b/web/package.json index c9b5999d..fbdb6f60 100644 --- a/web/package.json +++ b/web/package.json @@ -37,6 +37,7 @@ "@uiw/react-md-editor": "^4.1.0", "react": "^19.2.6", "react-dom": "^19.2.6", + "react-resizable-panels": "^3.0.6", "react-router-dom": "^7.15.0", "tailwindcss": "^4.2.2" }, diff --git a/web/src/App.tsx b/web/src/App.tsx index 1d4611a8..aba90622 100644 --- a/web/src/App.tsx +++ b/web/src/App.tsx @@ -15,6 +15,9 @@ import type { ProjectConfig } from './types'; const 
ProjectShell = lazy(() => import('./components/ProjectShell').then((m) => ({ default: m.ProjectShell })) ); +const ChatPage = lazy(() => + import('./pages/Chat/ChatPage').then((m) => ({ default: m.ChatPage })) +); const AllProjectsDashboard = lazy(() => import('./components/AllProjectsDashboard').then((m) => ({ default: m.AllProjectsDashboard })) ); @@ -56,6 +59,7 @@ function AppInner() {
setNewProjectOpen(true)} + onNewChat={() => navigate('/chat?new=1')} mobileOpen={mobileOpen} onMobileClose={onMobileClose} /> @@ -67,6 +71,8 @@ function AppInner() { } /> } /> } /> + } /> + } /> } /> diff --git a/web/src/api/client.ts b/web/src/api/client.ts index f2fd7775..06091d55 100644 --- a/web/src/api/client.ts +++ b/web/src/api/client.ts @@ -19,6 +19,10 @@ import type { RefreshPlan, RefreshJobStatus, RefreshStatusResponse, + ChatSession, + ChatStatus, + ChatMessage, + ChatModelList, } from '../types'; const BASE_URL = '/api'; @@ -371,6 +375,76 @@ class APIClient { { method: 'GET' }, ); } + + // Chat + async listChats(filter: { project?: string; status?: ChatStatus } = {}): Promise { + const q = new URLSearchParams(); + if (filter.project) q.set('project', filter.project); + if (filter.status) q.set('status', filter.status); + const qs = q.toString(); + return this.request(`/chats${qs ? `?${qs}` : ''}`); + } + + async createChat(body: { title?: string; project?: string; model?: string }): Promise { + return this.request('/chats', { + method: 'POST', + body: JSON.stringify(body), + }); + } + + async listChatModels(): Promise { + return this.request('/chats/models'); + } + + async getChat(id: string): Promise { + return this.request(`/chats/${encodeURIComponent(id)}`); + } + + async patchChat(id: string, body: { title?: string }): Promise { + return this.request(`/chats/${encodeURIComponent(id)}`, { + method: 'PATCH', + body: JSON.stringify(body), + }); + } + + async deleteChat(id: string): Promise { + return this.request(`/chats/${encodeURIComponent(id)}`, { method: 'DELETE' }); + } + + async openChat(id: string): Promise { + return this.request(`/chats/${encodeURIComponent(id)}/open`, { + method: 'POST', + body: JSON.stringify({}), + }); + } + + async endChat(id: string): Promise { + return this.request(`/chats/${encodeURIComponent(id)}/end`, { + method: 'POST', + body: JSON.stringify({}), + }); + } + + async sendChatMessage(id: string, content: string): 
Promise<{ ok: boolean; message_id: string }> { + return this.request<{ ok: boolean; message_id: string }>(`/chats/${encodeURIComponent(id)}/messages`, { + method: 'POST', + body: JSON.stringify({ content }), + }); + } + + async listChatMessages( + id: string, + sinceSeq: number, + limit: number, + ): Promise<{ messages: ChatMessage[] }> { + const qs = new URLSearchParams({ + since_seq: String(sinceSeq), + limit: String(limit), + }); + return this.request<{ messages: ChatMessage[] }>( + `/chats/${encodeURIComponent(id)}/messages?${qs.toString()}`, + ); + } } export const api = new APIClient(); diff --git a/web/src/components/CardPanel/CardChat.tsx b/web/src/components/CardPanel/CardChat.tsx index b72a260e..ce414e65 100644 --- a/web/src/components/CardPanel/CardChat.tsx +++ b/web/src/components/CardPanel/CardChat.tsx @@ -1,18 +1,8 @@ -import { Suspense, lazy, useCallback, useId, useLayoutEffect, useRef, useState } from 'react'; -import { flushSync } from 'react-dom'; +import { useState } from 'react'; import type { Card, LogEntry } from '../../types'; import { api, isAPIError } from '../../api/client'; import { ConfirmModal } from '../ConfirmModal/ConfirmModal'; - -// Lazy-load the markdown previewer so the chat panel doesn't pay the -// bundle cost until the user opens an HITL session. The chat markdown -// styling is fully driven by CSS custom properties scoped to :root and -// [data-theme="light"] (see .wmde-markdown rules in index.css), so dark -// /light switches automatically without data-color-mode. -const MarkdownPreview = lazy(() => import('@uiw/react-markdown-preview')); - -const MAX_MESSAGE_LENGTH = 8000; -const NEAR_BOTTOM_THRESHOLD = 50; +import { ChatPanel } from '../ChatPanel'; interface CardChatProps { card: Card; @@ -20,211 +10,90 @@ interface CardChatProps { } /** - * Two-channel chat panel. Agent output renders as left-aligned bubbles; - * human replies render as right-aligned bubbles. Newlines are preserved - * via `white-space: pre-wrap`. 
The Send button only lives here — never - * duplicate it in the panel header. + * Card-bound chat wrapper. Composes the generic ChatPanel primitive with + * card-specific bits: the promote-to-autonomous confirm modal, the + * read-only footer text logic (different message for promoted vs ended), + * and the api.sendCardMessage / api.promoteCardToAutonomous calls. * - * The transcript stays visible whenever the parent mounts this component. - * When HITL is no longer active (runner stopped or card promoted to - * autonomous) the compose row and Switch-to-Autonomous button are replaced - * by a thin read-only footer so the conversation is preserved while input - * is closed. + * The transcript and filter bar remain visible even when the session is + * not active (stopped or promoted), with the compose row replaced by a + * read-only footer. */ export function CardChat({ card, cardLogs }: CardChatProps) { - const [message, setMessage] = useState(''); - const [sending, setSending] = useState(false); const [promoting, setPromoting] = useState(false); const [confirmOpen, setConfirmOpen] = useState(false); - const [error, setError] = useState(null); - const [showText, setShowText] = useState(true); - const [showToolCalls, setShowToolCalls] = useState(false); - const [showThinking, setShowThinking] = useState(false); - const messageId = useId(); - const logContainerRef = useRef(null); - const textareaRef = useRef(null); - const userScrolledUpRef = useRef(false); - - const handleLogScroll = useCallback(() => { - const el = logContainerRef.current; - if (!el) return; - const distanceFromBottom = el.scrollHeight - el.scrollTop - el.clientHeight; - userScrolledUpRef.current = distanceFromBottom > NEAR_BOTTOM_THRESHOLD; - }, []); - - // useLayoutEffect pins the scroll before paint so the new content lands at - // the bottom on the same frame, matching VirtualLogList. 
- useLayoutEffect(() => { - const el = logContainerRef.current; - if (!el) return; - if (userScrolledUpRef.current) return; - el.scrollTop = el.scrollHeight; - }, [cardLogs]); + const [promoteError, setPromoteError] = useState(null); const hitlActive = card.runner_status === 'running' && !card.autonomous; - const filteredLogs = cardLogs.filter((entry) => { - if (entry.type === 'text') return showText; - if (entry.type === 'tool_call') return showToolCalls; - if (entry.type === 'thinking') return showThinking; - return true; - }); - - const isOverLimit = message.length > MAX_MESSAGE_LENGTH; - const canSend = message.trim().length > 0 && !sending && !isOverLimit; - - const handleKeyDown = (e: React.KeyboardEvent) => { - if (e.key === 'Enter' && !e.shiftKey) { - e.preventDefault(); - if (canSend) void handleSend(); - } - }; - - const handleSend = async () => { - const content = message.trim(); - if (!content || sending || isOverLimit) return; - setSending(true); + const handleSend = async (content: string) => { try { await api.sendCardMessage(card.project, card.id, content); - setMessage(''); - setError(null); } catch (err) { - setError(isAPIError(err) ? err.error : 'Failed to send message'); - } finally { - // Browsers drop focus() calls against a disabled input. setSending(false) - // only queues the flip — flushSync commits it before the imperative focus - // so the user can keep typing without re-clicking the textarea. - flushSync(() => setSending(false)); - textareaRef.current?.focus(); + // Rethrow as Error so ChatPanel's internal error display shows the + // API error message. Preserve the original via `cause` so devtools / + // future telemetry can still see the underlying APIError. + const msg = isAPIError(err) ? 
err.error : 'Failed to send message'; + throw new Error(msg, { cause: err }); } }; const handlePromoteConfirm = async () => { setConfirmOpen(false); setPromoting(true); - setError(null); + setPromoteError(null); try { await api.promoteCardToAutonomous(card.project, card.id); } catch (err) { - setError(isAPIError(err) ? err.error : 'Failed to promote session'); + setPromoteError(isAPIError(err) ? err.error : 'Failed to promote session'); } finally { setPromoting(false); } }; - return ( -
- {/* Filter bar */} -
- - - -
- - {/* Log column */} -
+
- - {hitlActive ? ( - <> - {/* Compose */} -
- -