diff --git a/Cargo.lock b/Cargo.lock index 8a062d75a5..1f23a1de55 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5141,6 +5141,7 @@ dependencies = [ "serde-big-array", "serde_json", "serde_yaml", + "sha1", "sha2 0.10.9", "shellexpand", "socketioxide", diff --git a/Cargo.toml b/Cargo.toml index df843a7e2d..b1c4e5644c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -56,6 +56,9 @@ argon2 = "0.5" rand = "0.10" dirs = "5" sha2 = "0.10" +# Legacy SHA-1 only used for Tencent COS HMAC-SHA1 signing (yuanbao +# channel media upload). Not used for any new security-sensitive work. +sha1 = "0.10" hmac = "0.12" # Archive extraction for the Node.js runtime bootstrap. Unix Node # distributions ship as .tar.xz, Windows as .zip. `xz2` with `static` diff --git a/app/src-tauri/Cargo.lock b/app/src-tauri/Cargo.lock index 74ff66efb9..18ec6d2660 100644 --- a/app/src-tauri/Cargo.lock +++ b/app/src-tauri/Cargo.lock @@ -5286,6 +5286,7 @@ dependencies = [ "serde", "serde_json", "serde_yaml", + "sha1", "sha2 0.10.9", "shellexpand", "socketioxide", diff --git a/app/src/components/channels/ChannelSelector.tsx b/app/src/components/channels/ChannelSelector.tsx index 6e41f598a0..770587146b 100644 --- a/app/src/components/channels/ChannelSelector.tsx +++ b/app/src/components/channels/ChannelSelector.tsx @@ -5,6 +5,7 @@ import { useT } from '../../lib/i18n/I18nContext'; import { useAppSelector } from '../../store/hooks'; import type { ChannelConnectionStatus, ChannelDefinition, ChannelType } from '../../types/channels'; import ChannelStatusBadge from './ChannelStatusBadge'; +import YuanbaoIcon from './YuanbaoIcon'; interface ChannelSelectorProps { definitions: ChannelDefinition[]; @@ -12,17 +13,28 @@ interface ChannelSelectorProps { onSelectChannel: (channel: ChannelType) => void; } +// Emoji icons for channels rendered as plain text. `yuanbao` is handled +// separately with a branded SVG (see `YuanbaoIcon`). 
const CHANNEL_ICONS: Record = { telegram: '✈️', discord: '🎮', web: '🌐', + yuanbao: '🟡', mcp: '🔌', }; +const renderChannelIcon = (icon: string) => + icon === 'yuanbao' ? ( + + ) : ( + {CHANNEL_ICONS[icon] ?? ''} + ); + /** Virtual (static) tabs that are not backed by a ChannelDefinition from the core. */ const VIRTUAL_TABS: { id: ChannelType; display_name: string }[] = [ { id: 'mcp', display_name: 'MCP Servers' }, ]; + const CHANNEL_STATUS_PRIORITY: ChannelConnectionStatus[] = [ 'connected', 'connecting', @@ -84,7 +96,7 @@ const ChannelSelector = ({ : 'border-stone-200 dark:border-neutral-800 bg-stone-50 dark:bg-neutral-800/60 text-stone-600 dark:text-neutral-300 hover:border-stone-300 dark:hover:border-neutral-700' }`}> - {CHANNEL_ICONS[def.icon] ?? ''} + {renderChannelIcon(def.icon)} {def.display_name} @@ -105,7 +117,7 @@ const ChannelSelector = ({ ? 'border-primary-500/60 bg-primary-50 dark:bg-primary-500/15 text-primary-600 dark:text-primary-300' : 'border-stone-200 dark:border-neutral-800 bg-stone-50 dark:bg-neutral-800/60 text-stone-600 dark:text-neutral-300 hover:border-stone-300 dark:hover:border-neutral-700' }`}> - {CHANNEL_ICONS[tab.id] ?? ''} + {renderChannelIcon(tab.id)} {tab.display_name} ); diff --git a/app/src/components/channels/ChannelSetupModal.tsx b/app/src/components/channels/ChannelSetupModal.tsx index cf81b7c7e2..38c58baafe 100644 --- a/app/src/components/channels/ChannelSetupModal.tsx +++ b/app/src/components/channels/ChannelSetupModal.tsx @@ -9,7 +9,12 @@ import { useT } from '../../lib/i18n/I18nContext'; import type { ChannelDefinition, ChannelType } from '../../types/channels'; import DiscordConfig from './DiscordConfig'; import TelegramConfig from './TelegramConfig'; +import YuanbaoConfig from './YuanbaoConfig'; +import YuanbaoIcon from './YuanbaoIcon'; +// Emoji icons for channels rendered as plain text. `yuanbao` is handled +// separately with a branded SVG (see `YuanbaoIcon`) — matches the +// rendering used in `ChannelSelector`. 
const CHANNEL_ICONS: Record = { telegram: '\u2708\uFE0F', discord: '\uD83C\uDFAE', @@ -29,6 +34,8 @@ function ChannelConfigContent({ definition }: { definition: ChannelDefinition }) return ; case 'discord': return ; + case 'yuanbao': + return ; default: return (

@@ -62,7 +69,8 @@ export default function ChannelSetupModal({ definition, onClose }: ChannelSetupM if (e.target === e.currentTarget) onClose(); }; - const icon = CHANNEL_ICONS[definition.icon] ?? ''; + const emojiIcon = CHANNEL_ICONS[definition.icon] ?? ''; + const isYuanbao = definition.icon === 'yuanbao'; const modalContent = (

- {icon && {icon}} + {isYuanbao ? ( + + ) : ( + emojiIcon && {emojiIcon} + )}

diff --git a/app/src/components/channels/YuanbaoConfig.tsx b/app/src/components/channels/YuanbaoConfig.tsx new file mode 100644 index 0000000000..3ba1f4198f --- /dev/null +++ b/app/src/components/channels/YuanbaoConfig.tsx @@ -0,0 +1,316 @@ +import debug from 'debug'; +import { useCallback, useEffect, useState } from 'react'; + +import { AUTH_MODE_LABELS } from '../../lib/channels/definitions'; +import { useT } from '../../lib/i18n/I18nContext'; +import { channelConnectionsApi } from '../../services/api/channelConnectionsApi'; +import { + disconnectChannelConnection, + setChannelConnectionStatus, + upsertChannelConnection, +} from '../../store/channelConnectionsSlice'; +import { useAppDispatch, useAppSelector } from '../../store/hooks'; +import type { ChannelConnectionStatus, ChannelDefinition } from '../../types/channels'; +import { restartCoreProcess } from '../../utils/tauriCommands/core'; +import ChannelFieldInput from './ChannelFieldInput'; +import ChannelStatusBadge from './ChannelStatusBadge'; + +const log = debug('channels:yuanbao'); + +interface YuanbaoConfigProps { + definition: ChannelDefinition; +} + +const YuanbaoConfig = ({ definition }: YuanbaoConfigProps) => { + const { t } = useT(); + const dispatch = useAppDispatch(); + const channelConnections = useAppSelector(state => state.channelConnections); + + const [busy, setBusy] = useState(false); + const [fieldValues, setFieldValues] = useState>({ + app_key: '', + app_secret: '', + }); + // Per-field inline validation errors, keyed by field.key. + const [fieldErrors, setFieldErrors] = useState>({}); + + const updateField = useCallback((fieldKey: string, value: string) => { + setFieldValues(prev => ({ ...prev, [fieldKey]: value })); + // Clear the error for this field as the user types. 
+ setFieldErrors(prev => { + if (!prev[fieldKey]) return prev; + const next = { ...prev }; + delete next[fieldKey]; + return next; + }); + }, []); + + const spec = definition.auth_modes[0]; + + // On mount, reset any stale 'connecting' state persisted from a previous session. + useEffect(() => { + if (!spec) return; + const conn = channelConnections.connections.yuanbao?.[spec.mode]; + if (conn?.status === 'connecting') { + dispatch( + setChannelConnectionStatus({ + channel: 'yuanbao', + authMode: spec.mode, + status: 'disconnected', + }) + ); + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []); + + // All useCallback hooks must be called unconditionally. + const handleConnect = useCallback(() => { + console.log('[YuanbaoConfig] handleConnect: 1.entry, spec=', spec); + if (!spec) { + console.warn('[YuanbaoConfig] handleConnect: aborted — spec is null'); + return; + } + + const errors: Record = {}; + for (const field of spec.fields) { + const empty = !fieldValues[field.key]?.trim(); + if (field.required && empty) { + errors[field.key] = t('channels.yuanbao.fieldRequired').replace('{field}', field.label); + } + } + if (Object.keys(errors).length > 0) { + console.warn('[YuanbaoConfig] handleConnect: 2.validation FAILED', errors); + setFieldErrors(errors); + return; + } + console.log('[YuanbaoConfig] handleConnect: 2.validation passed'); + + setFieldErrors({}); + setBusy(true); + + dispatch( + setChannelConnectionStatus({ channel: 'yuanbao', authMode: spec.mode, status: 'connecting' }) + ); + + const credentials: Record = {}; + for (const field of spec.fields) { + const val = fieldValues[field.key]?.trim() ?? 
''; + if (val) credentials[field.key] = val; + } + console.log( + '[YuanbaoConfig] handleConnect: 3.dispatched connecting, credential keys=', + Object.keys(credentials) + ); + + void (async () => { + try { + console.log( + '[YuanbaoConfig] handleConnect: 4.before channels_connect RPC, authMode=', + spec.mode + ); + log('connecting yuanbao via %s', spec.mode); + const result = await channelConnectionsApi.connectChannel('yuanbao', { + authMode: spec.mode, + credentials, + }); + console.log('[YuanbaoConfig] handleConnect: 5.RPC returned', result); + log('connect result: %o', result); + + // Only treat explicit "connected" as success. Any other status + // (e.g. "pending_auth" if a future auth flow gets added) must + // surface as an error instead of silently dispatching connected. + if (result.status !== 'connected') { + const msg = t('channels.yuanbao.unexpectedStatus').replace( + '{status}', + result.status ?? '' + ); + console.warn('[YuanbaoConfig] handleConnect: 6.unexpected status', result.status); + dispatch( + setChannelConnectionStatus({ + channel: 'yuanbao', + authMode: spec.mode, + status: 'error', + lastError: msg, + }) + ); + return; + } + + if (result.restart_required) { + console.log( + '[YuanbaoConfig] handleConnect: 6.restart_required=true, calling restartCoreProcess' + ); + log('restart required after connect — restarting core process'); + try { + await restartCoreProcess(); + console.log( + '[YuanbaoConfig] handleConnect: 7.restartCoreProcess resolved, dispatching connected' + ); + dispatch( + upsertChannelConnection({ + channel: 'yuanbao', + authMode: spec.mode, + patch: { + status: 'connected', + lastError: undefined, + capabilities: ['read', 'write'], + }, + }) + ); + } catch (restartErr) { + const msg = restartErr instanceof Error ? 
restartErr.message : String(restartErr); + console.error('[YuanbaoConfig] handleConnect: 7.restartCoreProcess FAILED', restartErr); + log('core restart failed: %s', msg); + dispatch( + setChannelConnectionStatus({ + channel: 'yuanbao', + authMode: spec.mode, + status: 'error', + lastError: t('channels.telegram.savedRestartRequired'), + }) + ); + } + } else { + console.log( + '[YuanbaoConfig] handleConnect: 6.restart_required=false, dispatching connected' + ); + dispatch( + upsertChannelConnection({ + channel: 'yuanbao', + authMode: spec.mode, + patch: { status: 'connected', lastError: undefined, capabilities: ['read', 'write'] }, + }) + ); + } + } catch (e) { + const msg = e instanceof Error ? e.message : String(e); + console.error('[YuanbaoConfig] handleConnect: X.caught error', e); + dispatch( + setChannelConnectionStatus({ + channel: 'yuanbao', + authMode: spec.mode, + status: 'error', + lastError: msg, + }) + ); + } finally { + console.log('[YuanbaoConfig] handleConnect: 8.finally, setBusy(false)'); + setBusy(false); + } + })(); + }, [dispatch, fieldValues, spec, t]); + + const handleDisconnect = useCallback(() => { + if (!spec) return; + setBusy(true); + void (async () => { + try { + log('disconnecting yuanbao'); + await channelConnectionsApi.disconnectChannel('yuanbao', spec.mode); + dispatch(disconnectChannelConnection({ channel: 'yuanbao', authMode: spec.mode })); + } catch (e) { + const msg = e instanceof Error ? e.message : String(e); + dispatch( + setChannelConnectionStatus({ + channel: 'yuanbao', + authMode: spec.mode, + status: 'error', + lastError: msg, + }) + ); + } finally { + setBusy(false); + } + })(); + }, [dispatch, spec]); + + if (!spec) return null; + + const connection = channelConnections.connections.yuanbao?.[spec.mode]; + const status: ChannelConnectionStatus = connection?.status ?? 'disconnected'; + + return ( +
+
+
+
+

+ {AUTH_MODE_LABELS[spec.mode] ?? spec.mode} +

+

{spec.description}

+ {connection?.lastError && ( +

{connection.lastError}

+ )} +
+ +
+ + {spec.fields.length > 0 && ( +
+ {spec.fields.map(field => { + return ( +
+ updateField(field.key, val)} + disabled={busy} + /> + {fieldErrors[field.key] && ( +

+ {fieldErrors[field.key]} +

+ )} +
+ ); + })} +
+ )} + +
+ + +
+
+
+ ); +}; + +export default YuanbaoConfig; diff --git a/app/src/components/channels/YuanbaoIcon.tsx b/app/src/components/channels/YuanbaoIcon.tsx new file mode 100644 index 0000000000..5a1261b9b9 --- /dev/null +++ b/app/src/components/channels/YuanbaoIcon.tsx @@ -0,0 +1,56 @@ +import { useId } from 'react'; + +interface YuanbaoIconProps { + /** + * Tailwind size + color overrides. Defaults to a 20px box, matching + * the visual weight of the channel-row emojis it sits next to. + */ + className?: string; +} + +/** + * Brand mark for the Yuanbao channel. Inlined as an SVG component so it + * can be tinted / sized via Tailwind without round-tripping through an + * `` element. `clipPath` ids are generated with `useId` so multiple + * instances on the same page (channel selector + setup modal) don't + * collide in the DOM. + */ +const YuanbaoIcon = ({ className = 'w-5 h-5' }: YuanbaoIconProps) => { + const clipId = useId(); + return ( + + ); +}; + +export default YuanbaoIcon; diff --git a/app/src/components/channels/__tests__/ChannelSetupModal.test.tsx b/app/src/components/channels/__tests__/ChannelSetupModal.test.tsx new file mode 100644 index 0000000000..ae521baa31 --- /dev/null +++ b/app/src/components/channels/__tests__/ChannelSetupModal.test.tsx @@ -0,0 +1,72 @@ +import { screen } from '@testing-library/react'; +import { describe, expect, it, vi } from 'vitest'; + +import { FALLBACK_DEFINITIONS } from '../../../lib/channels/definitions'; +import { renderWithProviders } from '../../../test/test-utils'; +import type { ChannelDefinition } from '../../../types/channels'; +import ChannelSetupModal from '../ChannelSetupModal'; + +// YuanbaoConfig pulls in API + Tauri helpers we don't need for the routing +// branches under test — stub it so we only assert ChannelSetupModal's own +// behavior (icon branch + yuanbao switch case). +vi.mock('../YuanbaoConfig', () => ({ + default: () =>
Yuanbao Config
, +})); + +vi.mock('../TelegramConfig', () => ({ + default: () =>
Telegram Config
, +})); + +vi.mock('../DiscordConfig', () => ({ + default: () =>
Discord Config
, +})); + +const yuanbaoDef: ChannelDefinition = { + id: 'yuanbao', + display_name: '元宝', + description: '通过元宝(Yuanbao)机器人收发消息。', + icon: 'yuanbao', + auth_modes: [ + { + mode: 'api_key', + description: '提供元宝开放平台的 AppID 和 AppSecret。', + fields: [], + auth_action: undefined, + }, + ], + capabilities: ['send_text', 'receive_text'], +}; + +describe('ChannelSetupModal', () => { + it('renders the YuanbaoConfig body and brand SVG icon for the yuanbao channel', () => { + renderWithProviders( {}} />); + // Header title + body routing both exercised. + expect(screen.getByText('元宝')).toBeInTheDocument(); + expect(screen.getByTestId('yuanbao-config')).toBeInTheDocument(); + // YuanbaoIcon emits an aria-hidden SVG in the header; the emoji-based + // fallback should NOT also render for yuanbao. + const dialog = screen.getByRole('dialog'); + expect(dialog.querySelector('svg[aria-hidden="true"]')).not.toBeNull(); + }); + + it('renders the emoji icon and TelegramConfig body for the telegram channel', () => { + const telegramDef = FALLBACK_DEFINITIONS.find(d => d.id === 'telegram')!; + renderWithProviders( {}} />); + expect(screen.getByTestId('telegram-config')).toBeInTheDocument(); + // Emoji branch produces a span sibling to the title. 
+ expect(screen.getByText('\u2708\uFE0F')).toBeInTheDocument(); + }); + + it('falls back to the unavailable-channel message for an unknown channel id', () => { + const unknown: ChannelDefinition = { ...yuanbaoDef, id: 'unknown', display_name: 'Unknown' }; + renderWithProviders( {}} />); + expect(screen.getByText(/Configuration for/i)).toBeInTheDocument(); + }); + + it('invokes onClose when the Escape key is pressed', () => { + const onClose = vi.fn(); + renderWithProviders(); + document.dispatchEvent(new KeyboardEvent('keydown', { key: 'Escape' })); + expect(onClose).toHaveBeenCalledTimes(1); + }); +}); diff --git a/app/src/components/channels/__tests__/YuanbaoConfig.test.tsx b/app/src/components/channels/__tests__/YuanbaoConfig.test.tsx new file mode 100644 index 0000000000..404f44d6cf --- /dev/null +++ b/app/src/components/channels/__tests__/YuanbaoConfig.test.tsx @@ -0,0 +1,287 @@ +import { fireEvent, screen, waitFor } from '@testing-library/react'; +import { afterEach, describe, expect, it, vi } from 'vitest'; + +import { channelConnectionsApi } from '../../../services/api/channelConnectionsApi'; +import { setChannelConnectionStatus } from '../../../store/channelConnectionsSlice'; +import { createTestStore, renderWithProviders } from '../../../test/test-utils'; +import type { ChannelDefinition } from '../../../types/channels'; +import { restartCoreProcess } from '../../../utils/tauriCommands/core'; +import YuanbaoConfig from '../YuanbaoConfig'; + +vi.mock('../../../services/api/channelConnectionsApi', () => ({ + channelConnectionsApi: { connectChannel: vi.fn(), disconnectChannel: vi.fn() }, +})); + +vi.mock('../../../utils/tauriCommands/core', () => ({ restartCoreProcess: vi.fn() })); + +// Mirrors the backend yuanbao_definition() in +// src/openhuman/channels/controllers/definitions.rs — kept inline because +// the frontend fallback definitions list does not (yet) include yuanbao. 
+const yuanbaoDef: ChannelDefinition = { + id: 'yuanbao', + display_name: '元宝', + description: '通过元宝(Yuanbao)机器人收发消息。', + icon: 'yuanbao', + auth_modes: [ + { + mode: 'api_key', + description: '提供元宝开放平台的 AppID 和 AppSecret。', + fields: [ + { + key: 'app_key', + label: 'AppID', + field_type: 'string', + required: true, + placeholder: '元宝开放平台 AppID', + }, + { + key: 'app_secret', + label: 'AppSecret', + field_type: 'secret', + required: true, + placeholder: '元宝开放平台 AppSecret', + }, + ], + auth_action: undefined, + }, + ], + capabilities: ['send_text', 'receive_text', 'typing'], +}; + +afterEach(() => { + vi.clearAllMocks(); +}); + +describe('YuanbaoConfig', () => { + it('renders the api_key mode label, description, and credential fields', () => { + renderWithProviders(); + expect(screen.getByText('Use your own API Key')).toBeInTheDocument(); + expect(screen.getByText(/AppID 和 AppSecret/)).toBeInTheDocument(); + expect(screen.getByPlaceholderText('元宝开放平台 AppID')).toBeInTheDocument(); + expect(screen.getByPlaceholderText('元宝开放平台 AppSecret')).toBeInTheDocument(); + }); + + it('shows a Connect and a (disabled) Disconnect button by default', () => { + renderWithProviders(); + expect(screen.getByText('Connect')).toBeInTheDocument(); + const disconnect = screen.getByText('Disconnect'); + expect(disconnect).toBeDisabled(); + }); + + it('returns null when the definition has no auth modes', () => { + const empty: ChannelDefinition = { ...yuanbaoDef, auth_modes: [] }; + const { container } = renderWithProviders(); + expect(container.firstChild).toBeNull(); + }); + + it('shows inline validation errors when required fields are empty and clears them on input', () => { + renderWithProviders(); + fireEvent.click(screen.getByText('Connect')); + + // Two required fields → two inline error messages. 
+ const appKeyError = screen + .getAllByText(/AppID/) + .filter(node => node.className.includes('text-coral')); + expect(appKeyError.length).toBeGreaterThan(0); + expect(channelConnectionsApi.connectChannel).not.toHaveBeenCalled(); + + // Typing into a field clears that field's error (covers updateField + // branch that mutates fieldErrors). + fireEvent.change(screen.getByPlaceholderText('元宝开放平台 AppID'), { + target: { value: 'app-key-123' }, + }); + expect( + screen.queryAllByText(/AppID/).filter(node => node.className.includes('text-coral')).length + ).toBe(0); + }); + + it('connects successfully and dispatches connected when restart is not required', async () => { + vi.mocked(channelConnectionsApi.connectChannel).mockResolvedValue({ + status: 'connected', + restart_required: false, + }); + + const { store } = renderWithProviders(); + fireEvent.change(screen.getByPlaceholderText('元宝开放平台 AppID'), { + target: { value: 'app-key-123' }, + }); + fireEvent.change(screen.getByPlaceholderText('元宝开放平台 AppSecret'), { + target: { value: 'app-secret-xyz' }, + }); + fireEvent.click(screen.getByText('Connect')); + + await waitFor(() => { + expect(channelConnectionsApi.connectChannel).toHaveBeenCalledWith('yuanbao', { + authMode: 'api_key', + credentials: { app_key: 'app-key-123', app_secret: 'app-secret-xyz' }, + }); + }); + await waitFor(() => { + const conn = store.getState().channelConnections.connections.yuanbao?.api_key; + expect(conn?.status).toBe('connected'); + expect(conn?.capabilities).toEqual(['read', 'write']); + }); + expect(restartCoreProcess).not.toHaveBeenCalled(); + }); + + it('calls restartCoreProcess and dispatches connected when restart_required=true', async () => { + vi.mocked(channelConnectionsApi.connectChannel).mockResolvedValue({ + status: 'connected', + restart_required: true, + }); + vi.mocked(restartCoreProcess).mockResolvedValue(); + + const { store } = renderWithProviders(); + fireEvent.change(screen.getByPlaceholderText('元宝开放平台 AppID'), { + 
target: { value: 'app-key-123' }, + }); + fireEvent.change(screen.getByPlaceholderText('元宝开放平台 AppSecret'), { + target: { value: 'app-secret-xyz' }, + }); + fireEvent.click(screen.getByText('Connect')); + + await waitFor(() => { + expect(restartCoreProcess).toHaveBeenCalledTimes(1); + }); + await waitFor(() => { + const conn = store.getState().channelConnections.connections.yuanbao?.api_key; + expect(conn?.status).toBe('connected'); + }); + }); + + it('marks the channel as error when restartCoreProcess throws after a successful connect', async () => { + vi.mocked(channelConnectionsApi.connectChannel).mockResolvedValue({ + status: 'connected', + restart_required: true, + }); + vi.mocked(restartCoreProcess).mockRejectedValue(new Error('core restart failed')); + + const { store } = renderWithProviders(); + fireEvent.change(screen.getByPlaceholderText('元宝开放平台 AppID'), { + target: { value: 'app-key-123' }, + }); + fireEvent.change(screen.getByPlaceholderText('元宝开放平台 AppSecret'), { + target: { value: 'app-secret-xyz' }, + }); + fireEvent.click(screen.getByText('Connect')); + + await waitFor(() => { + const conn = store.getState().channelConnections.connections.yuanbao?.api_key; + expect(conn?.status).toBe('error'); + expect(conn?.lastError).toBeTruthy(); + }); + }); + + it('surfaces an error when the backend returns a non-connected status', async () => { + vi.mocked(channelConnectionsApi.connectChannel).mockResolvedValue({ + status: 'pending_auth', + restart_required: false, + }); + + const { store } = renderWithProviders(); + fireEvent.change(screen.getByPlaceholderText('元宝开放平台 AppID'), { + target: { value: 'app-key-123' }, + }); + fireEvent.change(screen.getByPlaceholderText('元宝开放平台 AppSecret'), { + target: { value: 'app-secret-xyz' }, + }); + fireEvent.click(screen.getByText('Connect')); + + await waitFor(() => { + const conn = store.getState().channelConnections.connections.yuanbao?.api_key; + expect(conn?.status).toBe('error'); + 
expect(conn?.lastError).toContain('pending_auth'); + }); + }); + + it('captures connect failures from the API and dispatches an error status', async () => { + vi.mocked(channelConnectionsApi.connectChannel).mockRejectedValue( + new Error('invalid credentials') + ); + + const { store } = renderWithProviders(); + fireEvent.change(screen.getByPlaceholderText('元宝开放平台 AppID'), { + target: { value: 'app-key-123' }, + }); + fireEvent.change(screen.getByPlaceholderText('元宝开放平台 AppSecret'), { + target: { value: 'app-secret-xyz' }, + }); + fireEvent.click(screen.getByText('Connect')); + + await waitFor(() => { + const conn = store.getState().channelConnections.connections.yuanbao?.api_key; + expect(conn?.status).toBe('error'); + expect(conn?.lastError).toBe('invalid credentials'); + }); + }); + + it('disconnects an active channel via the API and clears the connection', async () => { + const store = createTestStore(); + store.dispatch( + setChannelConnectionStatus({ channel: 'yuanbao', authMode: 'api_key', status: 'connected' }) + ); + vi.mocked(channelConnectionsApi.disconnectChannel).mockResolvedValue(); + + renderWithProviders(, { store }); + + // Status is connected → Reconnect label appears on the primary button. 
+ expect(screen.getByText('Reconnect')).toBeInTheDocument(); + const disconnect = screen.getByText('Disconnect'); + expect(disconnect).not.toBeDisabled(); + fireEvent.click(disconnect); + + await waitFor(() => { + expect(channelConnectionsApi.disconnectChannel).toHaveBeenCalledWith('yuanbao', 'api_key'); + }); + await waitFor(() => { + const conn = store.getState().channelConnections.connections.yuanbao?.api_key; + expect(conn?.status).toBe('disconnected'); + }); + }); + + it('reports an error status when the disconnect API call fails', async () => { + const store = createTestStore(); + store.dispatch( + setChannelConnectionStatus({ channel: 'yuanbao', authMode: 'api_key', status: 'connected' }) + ); + vi.mocked(channelConnectionsApi.disconnectChannel).mockRejectedValue( + new Error('rpc unreachable') + ); + + renderWithProviders(, { store }); + fireEvent.click(screen.getByText('Disconnect')); + + await waitFor(() => { + const conn = store.getState().channelConnections.connections.yuanbao?.api_key; + expect(conn?.status).toBe('error'); + expect(conn?.lastError).toBe('rpc unreachable'); + }); + }); + + it('resets a stale "connecting" status from a previous session on mount', () => { + const store = createTestStore(); + store.dispatch( + setChannelConnectionStatus({ channel: 'yuanbao', authMode: 'api_key', status: 'connecting' }) + ); + + renderWithProviders(, { store }); + + const conn = store.getState().channelConnections.connections.yuanbao?.api_key; + expect(conn?.status).toBe('disconnected'); + }); + + it('renders the last error message when the connection is in an error state', () => { + const store = createTestStore(); + store.dispatch( + setChannelConnectionStatus({ + channel: 'yuanbao', + authMode: 'api_key', + status: 'error', + lastError: 'sign verification failed', + }) + ); + + renderWithProviders(, { store }); + expect(screen.getByText('sign verification failed')).toBeInTheDocument(); + }); +}); diff --git 
a/app/src/components/channels/__tests__/YuanbaoIcon.test.tsx b/app/src/components/channels/__tests__/YuanbaoIcon.test.tsx new file mode 100644 index 0000000000..067535c224 --- /dev/null +++ b/app/src/components/channels/__tests__/YuanbaoIcon.test.tsx @@ -0,0 +1,37 @@ +import { render } from '@testing-library/react'; +import { describe, expect, it } from 'vitest'; + +import YuanbaoIcon from '../YuanbaoIcon'; + +describe('YuanbaoIcon', () => { + it('renders an inline SVG with the default size class', () => { + const { container } = render(); + const svg = container.querySelector('svg'); + expect(svg).not.toBeNull(); + expect(svg).toHaveAttribute('aria-hidden', 'true'); + expect(svg?.getAttribute('class')).toContain('w-5'); + expect(svg?.getAttribute('class')).toContain('h-5'); + }); + + it('applies a custom className override', () => { + const { container } = render(); + const svg = container.querySelector('svg'); + expect(svg?.getAttribute('class')).toBe('w-10 h-10 text-amber-500'); + }); + + it('generates a unique clipPath id per instance so duplicate icons do not collide', () => { + const { container } = render( + <> + + + + ); + const clips = container.querySelectorAll('clipPath'); + expect(clips.length).toBe(2); + const id1 = clips[0].getAttribute('id'); + const id2 = clips[1].getAttribute('id'); + expect(id1).toBeTruthy(); + expect(id2).toBeTruthy(); + expect(id1).not.toBe(id2); + }); +}); diff --git a/app/src/components/skills/skillIcons.tsx b/app/src/components/skills/skillIcons.tsx index b683fb1a23..1c173308e9 100644 --- a/app/src/components/skills/skillIcons.tsx +++ b/app/src/components/skills/skillIcons.tsx @@ -2,6 +2,8 @@ import type { ReactNode } from 'react'; import type { IconType } from 'react-icons'; import { FaDiscord, FaGlobe, FaTelegramPlane } from 'react-icons/fa'; import { IoChatbubble } from 'react-icons/io5'; + +import YuanbaoIcon from '../channels/YuanbaoIcon'; import { LuBlocks, LuBot, @@ -81,6 +83,14 @@ export const CHANNEL_ICONS: Record = 
{ iconClassName="text-[#34C759]" /> ), + yuanbao: ( + + + + ), }; const CATEGORY_META: Record< diff --git a/app/src/lib/i18n/chunks/ar-3.ts b/app/src/lib/i18n/chunks/ar-3.ts index d2462092dc..b669127999 100644 --- a/app/src/lib/i18n/chunks/ar-3.ts +++ b/app/src/lib/i18n/chunks/ar-3.ts @@ -398,6 +398,9 @@ const ar3: TranslationMap = { 'channels.web.displayName': 'Web', 'channels.web.description': 'Chat via the built-in web UI.', 'channels.web.authMode.managed_dm.description': 'Use the embedded web chat — no setup required.', + 'channels.yuanbao.connecting': 'Connecting…', + 'channels.yuanbao.fieldRequired': '{field} is required', + 'channels.yuanbao.unexpectedStatus': 'Unexpected connection status: {status}', }; export default ar3; diff --git a/app/src/lib/i18n/chunks/bn-3.ts b/app/src/lib/i18n/chunks/bn-3.ts index 0c9a8b5a25..7b205b18fa 100644 --- a/app/src/lib/i18n/chunks/bn-3.ts +++ b/app/src/lib/i18n/chunks/bn-3.ts @@ -401,6 +401,9 @@ const bn3: TranslationMap = { 'channels.web.displayName': 'Web', 'channels.web.description': 'Chat via the built-in web UI.', 'channels.web.authMode.managed_dm.description': 'Use the embedded web chat — no setup required.', + 'channels.yuanbao.connecting': 'Connecting…', + 'channels.yuanbao.fieldRequired': '{field} is required', + 'channels.yuanbao.unexpectedStatus': 'Unexpected connection status: {status}', }; export default bn3; diff --git a/app/src/lib/i18n/chunks/de-3.ts b/app/src/lib/i18n/chunks/de-3.ts index 42a08fa11a..f5a1b94626 100644 --- a/app/src/lib/i18n/chunks/de-3.ts +++ b/app/src/lib/i18n/chunks/de-3.ts @@ -413,6 +413,9 @@ const de3: TranslationMap = { 'channels.web.description': 'Chatte über die integrierte Web-Oberfläche.', 'channels.web.authMode.managed_dm.description': 'Nutze den eingebetteten Web-Chat — keine Einrichtung erforderlich.', + 'channels.yuanbao.connecting': 'Connecting…', + 'channels.yuanbao.fieldRequired': '{field} is required', + 'channels.yuanbao.unexpectedStatus': 'Unexpected connection status: 
{status}', }; export default de3; diff --git a/app/src/lib/i18n/chunks/en-3.ts b/app/src/lib/i18n/chunks/en-3.ts index 1f9291f86a..a91e09aed2 100644 --- a/app/src/lib/i18n/chunks/en-3.ts +++ b/app/src/lib/i18n/chunks/en-3.ts @@ -401,6 +401,9 @@ const en3: TranslationMap = { 'channels.web.displayName': 'Web', 'channels.web.description': 'Chat via the built-in web UI.', 'channels.web.authMode.managed_dm.description': 'Use the embedded web chat — no setup required.', + 'channels.yuanbao.connecting': 'Connecting…', + 'channels.yuanbao.fieldRequired': '{field} is required', + 'channels.yuanbao.unexpectedStatus': 'Unexpected connection status: {status}', }; export default en3; diff --git a/app/src/lib/i18n/chunks/es-3.ts b/app/src/lib/i18n/chunks/es-3.ts index 14b8efa238..cf1529a587 100644 --- a/app/src/lib/i18n/chunks/es-3.ts +++ b/app/src/lib/i18n/chunks/es-3.ts @@ -406,6 +406,9 @@ const es3: TranslationMap = { 'channels.web.displayName': 'Web', 'channels.web.description': 'Chat via the built-in web UI.', 'channels.web.authMode.managed_dm.description': 'Use the embedded web chat — no setup required.', + 'channels.yuanbao.connecting': 'Connecting…', + 'channels.yuanbao.fieldRequired': '{field} is required', + 'channels.yuanbao.unexpectedStatus': 'Unexpected connection status: {status}', }; export default es3; diff --git a/app/src/lib/i18n/chunks/fr-3.ts b/app/src/lib/i18n/chunks/fr-3.ts index 9435715197..43c164b289 100644 --- a/app/src/lib/i18n/chunks/fr-3.ts +++ b/app/src/lib/i18n/chunks/fr-3.ts @@ -407,6 +407,9 @@ const fr3: TranslationMap = { 'channels.web.displayName': 'Web', 'channels.web.description': 'Chat via the built-in web UI.', 'channels.web.authMode.managed_dm.description': 'Use the embedded web chat — no setup required.', + 'channels.yuanbao.connecting': 'Connecting…', + 'channels.yuanbao.fieldRequired': '{field} is required', + 'channels.yuanbao.unexpectedStatus': 'Unexpected connection status: {status}', }; export default fr3; diff --git 
a/app/src/lib/i18n/chunks/hi-3.ts b/app/src/lib/i18n/chunks/hi-3.ts index 0ab30ff9e4..c70c3ac249 100644 --- a/app/src/lib/i18n/chunks/hi-3.ts +++ b/app/src/lib/i18n/chunks/hi-3.ts @@ -403,6 +403,9 @@ const hi3: TranslationMap = { 'channels.web.displayName': 'Web', 'channels.web.description': 'Chat via the built-in web UI.', 'channels.web.authMode.managed_dm.description': 'Use the embedded web chat — no setup required.', + 'channels.yuanbao.connecting': 'Connecting…', + 'channels.yuanbao.fieldRequired': '{field} is required', + 'channels.yuanbao.unexpectedStatus': 'Unexpected connection status: {status}', }; export default hi3; diff --git a/app/src/lib/i18n/chunks/id-3.ts b/app/src/lib/i18n/chunks/id-3.ts index 72b9d5c9fc..7d2afd70ad 100644 --- a/app/src/lib/i18n/chunks/id-3.ts +++ b/app/src/lib/i18n/chunks/id-3.ts @@ -406,6 +406,9 @@ const id3: TranslationMap = { 'channels.web.displayName': 'Web', 'channels.web.description': 'Chat via the built-in web UI.', 'channels.web.authMode.managed_dm.description': 'Use the embedded web chat — no setup required.', + 'channels.yuanbao.connecting': 'Connecting…', + 'channels.yuanbao.fieldRequired': '{field} is required', + 'channels.yuanbao.unexpectedStatus': 'Unexpected connection status: {status}', }; export default id3; diff --git a/app/src/lib/i18n/chunks/it-3.ts b/app/src/lib/i18n/chunks/it-3.ts index 2a918283f3..758d48122f 100644 --- a/app/src/lib/i18n/chunks/it-3.ts +++ b/app/src/lib/i18n/chunks/it-3.ts @@ -406,6 +406,9 @@ const it3: TranslationMap = { 'channels.web.displayName': 'Web', 'channels.web.description': 'Chat via the built-in web UI.', 'channels.web.authMode.managed_dm.description': 'Use the embedded web chat — no setup required.', + 'channels.yuanbao.connecting': 'Connecting…', + 'channels.yuanbao.fieldRequired': '{field} is required', + 'channels.yuanbao.unexpectedStatus': 'Unexpected connection status: {status}', }; export default it3; diff --git a/app/src/lib/i18n/chunks/ko-3.ts 
b/app/src/lib/i18n/chunks/ko-3.ts index 7543d247db..e96e89d7b1 100644 --- a/app/src/lib/i18n/chunks/ko-3.ts +++ b/app/src/lib/i18n/chunks/ko-3.ts @@ -403,5 +403,8 @@ const ko3: TranslationMap = { 'channels.web.displayName': 'Web', 'channels.web.description': 'Chat via the built-in web UI.', 'channels.web.authMode.managed_dm.description': 'Use the embedded web chat — no setup required.', + 'channels.yuanbao.connecting': 'Connecting…', + 'channels.yuanbao.fieldRequired': '{field} is required', + 'channels.yuanbao.unexpectedStatus': 'Unexpected connection status: {status}', }; export default ko3; diff --git a/app/src/lib/i18n/chunks/pt-3.ts b/app/src/lib/i18n/chunks/pt-3.ts index b43cf61bd1..bcf4ef442b 100644 --- a/app/src/lib/i18n/chunks/pt-3.ts +++ b/app/src/lib/i18n/chunks/pt-3.ts @@ -405,6 +405,9 @@ const pt3: TranslationMap = { 'channels.web.displayName': 'Web', 'channels.web.description': 'Chat via the built-in web UI.', 'channels.web.authMode.managed_dm.description': 'Use the embedded web chat — no setup required.', + 'channels.yuanbao.connecting': 'Connecting…', + 'channels.yuanbao.fieldRequired': '{field} is required', + 'channels.yuanbao.unexpectedStatus': 'Unexpected connection status: {status}', }; export default pt3; diff --git a/app/src/lib/i18n/chunks/ru-3.ts b/app/src/lib/i18n/chunks/ru-3.ts index d4c47e3ee7..a94c12d08b 100644 --- a/app/src/lib/i18n/chunks/ru-3.ts +++ b/app/src/lib/i18n/chunks/ru-3.ts @@ -402,6 +402,9 @@ const ru3: TranslationMap = { 'channels.web.displayName': 'Web', 'channels.web.description': 'Chat via the built-in web UI.', 'channels.web.authMode.managed_dm.description': 'Use the embedded web chat — no setup required.', + 'channels.yuanbao.connecting': 'Connecting…', + 'channels.yuanbao.fieldRequired': '{field} is required', + 'channels.yuanbao.unexpectedStatus': 'Unexpected connection status: {status}', }; export default ru3; diff --git a/app/src/lib/i18n/chunks/zh-CN-3.ts b/app/src/lib/i18n/chunks/zh-CN-3.ts index 
0e8a794db8..b288b3db50 100644 --- a/app/src/lib/i18n/chunks/zh-CN-3.ts +++ b/app/src/lib/i18n/chunks/zh-CN-3.ts @@ -396,6 +396,9 @@ const zhCN3: TranslationMap = { 'channels.web.displayName': 'Web', 'channels.web.description': '通过内置的 Web UI 聊天。', 'channels.web.authMode.managed_dm.description': '使用嵌入式 Web 聊天 — 无需设置。', + 'channels.yuanbao.connecting': '连接中…', + 'channels.yuanbao.fieldRequired': '{field} 不能为空', + 'channels.yuanbao.unexpectedStatus': '意外的连接状态:{status}', }; export default zhCN3; diff --git a/app/src/lib/i18n/en.ts b/app/src/lib/i18n/en.ts index ede7298b3c..8ab165f2d1 100644 --- a/app/src/lib/i18n/en.ts +++ b/app/src/lib/i18n/en.ts @@ -1473,6 +1473,9 @@ const en: TranslationMap = { 'channels.web.displayName': 'Web', 'channels.web.description': 'Chat via the built-in web UI.', 'channels.web.authMode.managed_dm.description': 'Use the embedded web chat — no setup required.', + 'channels.yuanbao.connecting': 'Connecting…', + 'channels.yuanbao.fieldRequired': '{field} is required', + 'channels.yuanbao.unexpectedStatus': 'Unexpected connection status: {status}', 'chat.unsubscribeApproval.approve': 'Approve & Unsubscribe', 'chat.unsubscribeApproval.approved': '✓ Successfully unsubscribed.', 'chat.unsubscribeApproval.denied': '✕ Request denied.', diff --git a/app/src/lib/i18n/ko.ts b/app/src/lib/i18n/ko.ts index 3d21380ba3..fd371bad0b 100644 --- a/app/src/lib/i18n/ko.ts +++ b/app/src/lib/i18n/ko.ts @@ -1301,6 +1301,9 @@ const ko: TranslationMap = { 'channels.telegram.savedRestartRequired': '채널이 저장되었습니다. 
활성화하려면 앱을 다시 시작하세요.', 'channels.web.alwaysAvailable': '항상 사용 가능', + 'channels.yuanbao.connecting': 'Connecting…', + 'channels.yuanbao.fieldRequired': '{field} is required', + 'channels.yuanbao.unexpectedStatus': 'Unexpected connection status: {status}', 'chat.unsubscribeApproval.approve': '승인 및 구독 취소', 'chat.unsubscribeApproval.approved': '✓ 구독 취소가 완료되었습니다.', 'chat.unsubscribeApproval.denied': '✕ 요청이 거부되었습니다.', diff --git a/app/src/store/__tests__/channelConnectionsSlice.test.ts b/app/src/store/__tests__/channelConnectionsSlice.test.ts index 3474bb43de..1c545d4956 100644 --- a/app/src/store/__tests__/channelConnectionsSlice.test.ts +++ b/app/src/store/__tests__/channelConnectionsSlice.test.ts @@ -177,6 +177,31 @@ describe('channelConnectionsSlice', () => { }); }); + it('lazily initialises a channel modes bucket when persisted state is missing the key', () => { + // Simulates a rehydrated state from before yuanbao existed: the channel + // key is absent so `state.connections.yuanbao` is undefined. Without + // `ensureChannelModes()` the first upsert would crash on + // `state.connections[channel][authMode]`. See `ensureChannelModes` in + // channelConnectionsSlice.ts. 
+ const migrated = reducer(undefined, completeBreakingMigration()); + const partial = { + ...migrated, + connections: { ...migrated.connections, yuanbao: undefined as never }, + }; + + const next = reducer( + partial, + upsertChannelConnection({ + channel: 'yuanbao', + authMode: 'api_key', + patch: { status: 'connected' }, + }) + ); + + expect(next.connections.yuanbao).toBeDefined(); + expect(next.connections.yuanbao.api_key?.status).toBe('connected'); + }); + it('clears stale lastError when patch explicitly sets undefined', () => { const withError = reducer( undefined, diff --git a/app/src/store/channelConnectionsSlice.ts b/app/src/store/channelConnectionsSlice.ts index faee8db9c5..b785894590 100644 --- a/app/src/store/channelConnectionsSlice.ts +++ b/app/src/store/channelConnectionsSlice.ts @@ -34,9 +34,21 @@ const initialState: ChannelConnectionsState = { // MCP Servers tab is a virtual channel — no auth-mode connections, // but must be present to satisfy Record. mcp: makeEmptyChannelModes(), + yuanbao: makeEmptyChannelModes(), }, }; +// Lazy-init a channel's mode bucket on first write. Protects writes against +// rehydrated state from older app versions (or any channel added after a user +// first persisted state) where `state.connections[channel]` is `undefined` +// because redux-persist's default `autoMergeLevel1` reconciler does not deep- +// merge into `connections`. +function ensureChannelModes(state: ChannelConnectionsState, channel: ChannelType): void { + if (!state.connections[channel]) { + state.connections[channel] = makeEmptyChannelModes(); + } +} + function touchConnection( existing: ChannelConnection | undefined, patch: Partial & { channel: ChannelType; authMode: ChannelAuthMode } @@ -68,11 +80,12 @@ const channelConnectionsSlice = createSlice({ // explicit initialisation here, the first `upsertChannelConnection` // for either channel would crash on `state.connections[channel]` // being undefined. Pin them by default so the migration is total. 
state.connections.lark = makeEmptyChannelModes(); state.connections.dingtalk = makeEmptyChannelModes(); // MCP virtual channel must be present in persisted states migrated from // before PR #2276 or the Record shape is incomplete. state.connections.mcp = makeEmptyChannelModes(); + state.connections.yuanbao = makeEmptyChannelModes(); state.defaultMessagingChannel = 'telegram'; state.migrationCompleted = true; state.schemaVersion = SCHEMA_VERSION; @@ -91,6 +104,7 @@ const channelConnectionsSlice = createSlice({ }> ) { const { channel, authMode, patch } = action.payload; + ensureChannelModes(state, channel); const existing = state.connections[channel][authMode]; state.connections[channel][authMode] = touchConnection(existing, { channel, @@ -109,6 +123,7 @@ const channelConnectionsSlice = createSlice({ }> ) { const { channel, authMode, status, lastError } = action.payload; + ensureChannelModes(state, channel); const existing = state.connections[channel][authMode]; state.connections[channel][authMode] = touchConnection(existing, { channel, @@ -123,6 +138,7 @@ const channelConnectionsSlice = createSlice({ action: PayloadAction<{ channel: ChannelType; authMode: ChannelAuthMode }> ) { const { channel, authMode } = action.payload; + ensureChannelModes(state, channel); state.connections[channel][authMode] = touchConnection(state.connections[channel][authMode], { channel, authMode, diff --git a/app/src/types/channels.ts b/app/src/types/channels.ts index ad3be4cf16..05ca949fa3 100644 --- a/app/src/types/channels.ts +++ b/app/src/types/channels.ts @@ -1,4 +1,4 @@ -export type ChannelType = 'telegram' | 'discord' | 'web' | 'lark' | 'dingtalk' | 'mcp'; +export type ChannelType = 'telegram' | 'discord' | 'web' | 'lark' | 'dingtalk' | 'mcp' | 'yuanbao'; export type ChannelAuthMode = 'managed_dm' | 'oauth' | 'bot_token' | 'api_key'; diff --git a/docs/TEST-COVERAGE-MATRIX.md b/docs/TEST-COVERAGE-MATRIX.md index 
9674728600..1b4f995c16 100644 --- a/docs/TEST-COVERAGE-MATRIX.md +++ b/docs/TEST-COVERAGE-MATRIX.md @@ -344,6 +344,7 @@ Canonical mapping of every product feature to its test source(s). Drives gap-fil | 10.1.2 | WhatsApp Connection | WD | `app/test/e2e/specs/whatsapp-flow.spec.ts` (this PR) | ✅ | Was ❌ | | 10.1.3 | Gmail Connection | WD | `gmail-flow.spec.ts` | ✅ | | | 10.1.4 | Slack Connection | WD | `app/test/e2e/specs/slack-flow.spec.ts` (this PR) | ✅ | Was ❌ | +| 10.1.5 | Yuanbao Connection | RU | `src/openhuman/channels/providers/yuanbao/` (this PR), `src/openhuman/channels/controllers/ops.rs::tests::connect_yuanbao_*` (this PR), `src/openhuman/channels/runtime/startup.rs::yuanbao_secret_tests` (this PR) | 🟡 | New API-key channel for Tencent Yuanbao. RU covers sign-token preflight (valid/invalid creds, env-override cluster routing), credentials store hydration (incl. stale app_key guard), and WS reconnect/shutdown. No WDIO spec yet — connect-flow UI is rendered via the generic `ChannelSetupModal` already exercised by other channel flow specs. | ### 10.2 Authentication & Authorization diff --git a/src/openhuman/channels/commands.rs b/src/openhuman/channels/commands.rs index bd12e19343..105f40e119 100644 --- a/src/openhuman/channels/commands.rs +++ b/src/openhuman/channels/commands.rs @@ -17,6 +17,7 @@ use super::telegram::TelegramChannel; use super::whatsapp::WhatsAppChannel; #[cfg(feature = "whatsapp-web")] use super::whatsapp_web::WhatsAppWebChannel; +use super::yuanbao::YuanbaoChannel; use super::Channel; use crate::openhuman::config::Config; use anyhow::Result; @@ -235,6 +236,13 @@ pub async fn doctor_channels(config: Config) -> Result<()> { )); } + if let Some(ref yb) = config.channels_config.yuanbao { + match YuanbaoChannel::new(yb.clone()) { + Ok(ch) => channels.push(("Yuanbao", Arc::new(ch))), + Err(e) => tracing::warn!("Yuanbao config invalid, skipping: {}", e), + } + } + if channels.is_empty() { println!("No real-time channels configured. 
Configure channels in the web UI."); return Ok(()); diff --git a/src/openhuman/channels/controllers/definitions.rs b/src/openhuman/channels/controllers/definitions.rs index 445437e193..ccb7baad9f 100644 --- a/src/openhuman/channels/controllers/definitions.rs +++ b/src/openhuman/channels/controllers/definitions.rs @@ -160,6 +160,7 @@ pub fn all_channel_definitions() -> Vec { imessage_definition(), lark_definition(), dingtalk_definition(), + yuanbao_definition(), ] } @@ -444,6 +445,44 @@ fn dingtalk_definition() -> ChannelDefinition { } } +fn yuanbao_definition() -> ChannelDefinition { + // Endpoint URLs (api_domain / ws_domain) are not user-facing — the + // channel derives them from the `env` field of `YuanbaoConfig` + // (default: production). Advanced users can override via TOML. + ChannelDefinition { + id: "yuanbao", + display_name: "元宝", + description: "通过元宝(Yuanbao)机器人收发消息。", + icon: "yuanbao", + auth_modes: vec![AuthModeSpec { + mode: ChannelAuthMode::ApiKey, + description: "提供元宝开放平台的 AppID 和 AppSecret。", + fields: vec![ + FieldRequirement { + key: "app_key", + label: "AppID", + field_type: "string", + required: true, + placeholder: "元宝开放平台 AppID", + }, + FieldRequirement { + key: "app_secret", + label: "AppSecret", + field_type: "secret", + required: true, + placeholder: "元宝开放平台 AppSecret", + }, + ], + auth_action: None, + }], + capabilities: vec![ + ChannelCapability::SendText, + ChannelCapability::ReceiveText, + ChannelCapability::Typing, + ], + } +} + #[cfg(test)] #[path = "definitions_tests.rs"] mod tests; diff --git a/src/openhuman/channels/controllers/ops.rs b/src/openhuman/channels/controllers/ops.rs index cf5a3a3296..1df9722463 100644 --- a/src/openhuman/channels/controllers/ops.rs +++ b/src/openhuman/channels/controllers/ops.rs @@ -6,6 +6,8 @@ use serde_json::{json, Value}; use crate::api::config::{app_env_from_env, effective_backend_api_url, is_staging_app_env}; use crate::api::jwt::get_session_token; use crate::api::rest::BackendOAuthClient; +use 
crate::openhuman::channels::providers::yuanbao::sign::SignManager; +use crate::openhuman::channels::providers::yuanbao::YuanbaoConfig; use crate::openhuman::config::{Config, DiscordConfig, IMessageConfig, TelegramConfig}; use crate::openhuman::credentials; use crate::rpc::RpcOutcome; @@ -108,6 +110,89 @@ fn parse_optional_bool(value: Option<&Value>) -> Option { } } +/// Read a required non-empty Yuanbao credential field from the connect-channel +/// payload. Returns the trimmed value or an error naming the missing field. +fn require_yuanbao_field( + creds_map: &serde_json::Map, + key: &str, +) -> Result { + creds_map + .get(key) + .and_then(|v| v.as_str()) + .map(str::trim) + .filter(|s| !s.is_empty()) + .map(|s| s.to_string()) + .ok_or_else(|| format!("missing required {key}")) +} + +/// Build the **effective** Yuanbao config that will be used for both +/// preflight verification and persistence. +/// +/// Starts from the existing TOML (so manually-installed deployments keep +/// any custom routes), overlays the client-supplied endpoint overrides +/// (`env` / `api_domain` / `ws_domain` / `route_env`), then calls +/// `apply_env_defaults` so the verifier hits the correct cluster — e.g. a +/// user submitting `env = "pre"` is verified against the pre-release +/// sign-token endpoint instead of the default prod one. +/// +/// `app_secret` is intentionally left empty: the runtime loads it from +/// the encrypted credentials store at startup, never from `config.toml`. 
+fn build_effective_yuanbao_config( + base: YuanbaoConfig, + creds_map: &serde_json::Map, + app_key: String, +) -> YuanbaoConfig { + let opt_string = |key: &str| -> Option { + creds_map + .get(key) + .and_then(|v| v.as_str()) + .map(str::trim) + .filter(|s| !s.is_empty()) + .map(|s| s.to_string()) + }; + + let mut cfg = base; + cfg.app_key = app_key; + cfg.app_secret = String::new(); + if let Some(env) = opt_string("env") { + cfg.env = env; + } + if let Some(api_domain) = opt_string("api_domain") { + cfg.api_domain = api_domain; + } + if let Some(ws_domain) = opt_string("ws_domain") { + cfg.ws_domain = ws_domain; + } + if let Some(route_env) = opt_string("route_env") { + cfg.route_env = route_env; + } + cfg.apply_env_defaults(); + cfg +} + +/// Verify Yuanbao credentials against the `sign-token` endpoint before any +/// persistence so invalid `app_key` / `app_secret` surface the upstream API +/// error to the user instead of silently succeeding. +/// +/// Takes the **effective** `YuanbaoConfig` already built from the client's +/// overrides + TOML defaults, so the verifier targets whatever cluster the +/// runtime will use after restart. +async fn verify_yuanbao_credentials( + yb_cfg: &YuanbaoConfig, + app_secret: &str, +) -> Result<(), String> { + SignManager::new(reqwest::Client::new()) + .get_token( + &yb_cfg.app_key, + app_secret, + &yb_cfg.api_domain, + &yb_cfg.route_env, + ) + .await + .map_err(|e| format!("yuanbao credential verification failed: {e}"))?; + Ok(()) +} + /// List all available channel definitions. pub async fn list_channels() -> Result>, String> { Ok(RpcOutcome::new(all_channel_definitions(), vec![])) @@ -160,6 +245,21 @@ pub async fn connect_channel( def.validate_credentials(auth_mode, creds_map)?; + // Yuanbao: build the effective config (with any client-supplied + // endpoint overrides applied) once, verify against THAT cluster, and + // reuse the same config for persistence below. 
This prevents the + // verifier from validating against prod while the runtime then + // reconnects to a pre-release cluster after restart. + let mut prebuilt_yuanbao_config: Option = None; + if channel_id == "yuanbao" && auth_mode == ChannelAuthMode::ApiKey { + let app_key = require_yuanbao_field(creds_map, "app_key")?; + let app_secret = require_yuanbao_field(creds_map, "app_secret")?; + let base = config.channels_config.yuanbao.clone().unwrap_or_default(); + let effective = build_effective_yuanbao_config(base, creds_map, app_key); + verify_yuanbao_credentials(&effective, &app_secret).await?; + prebuilt_yuanbao_config = Some(effective); + } + // iMessage is local-only (no credentials): persist channels_config + return connected. if channel_id == "imessage" && auth_mode == ChannelAuthMode::ManagedDm { let allowed_contacts = parse_allowed_users(creds_map.get("allowed_contacts")); @@ -332,6 +432,27 @@ pub async fn connect_channel( mention_only, "[discord] connect_channel: wrote channels_config.discord; restart core for listener to load token" ); + } else if channel_id == "yuanbao" && auth_mode == ChannelAuthMode::ApiKey { + // Reuse the effective config built above (with `env` / `api_domain` + // / `ws_domain` / `route_env` overrides already applied and + // `app_secret` already cleared) so persistence and verification + // can never diverge. 
+ let yb_config = prebuilt_yuanbao_config + .take() + .ok_or_else(|| "yuanbao config was not verified before persistence".to_string())?; + + let mut persisted = config.clone(); + persisted.channels_config.yuanbao = Some(yb_config); + + persisted + .save() + .await + .map_err(|e| format!("failed to persist yuanbao config.toml: {e}"))?; + + tracing::info!( + target: "openhuman::channels", + "[yuanbao] connect_channel: wrote channels_config.yuanbao (secret stored in credentials); restart core for WS listener" + ); } Ok(RpcOutcome::single_log( @@ -402,6 +523,18 @@ pub async fn disconnect_channel( "[imessage] disconnect_channel: cleared channels_config.imessage" ); } + } else if channel_id == "yuanbao" && auth_mode == ChannelAuthMode::ApiKey { + let mut persisted = config.clone(); + if persisted.channels_config.yuanbao.take().is_some() { + persisted + .save() + .await + .map_err(|e| format!("failed to clear yuanbao config.toml: {e}"))?; + tracing::info!( + target: "openhuman::channels", + "[yuanbao] disconnect_channel: cleared channels_config.yuanbao" + ); + } } Ok(RpcOutcome::single_log( @@ -507,6 +640,9 @@ pub async fn connected_channel_slugs(config: &Config) -> Result, Str if cc.imessage.is_some() { slugs.insert("imessage".to_string()); } + if cc.yuanbao.is_some() { + slugs.insert("yuanbao".to_string()); + } if cc.irc.is_some() { slugs.insert("irc".to_string()); } diff --git a/src/openhuman/channels/controllers/ops_tests.rs b/src/openhuman/channels/controllers/ops_tests.rs index aa19ab4d97..446ae5d43a 100644 --- a/src/openhuman/channels/controllers/ops_tests.rs +++ b/src/openhuman/channels/controllers/ops_tests.rs @@ -1,4 +1,5 @@ use super::*; +use crate::openhuman::channels::providers::yuanbao::YuanbaoConfig; use tempfile::tempdir; fn isolated_test_config() -> (tempfile::TempDir, Config) { @@ -482,3 +483,271 @@ async fn connected_channel_slugs_empty_when_nothing_configured() { "fresh config should yield no channels: {slugs:?}" ); } + +// ── Yuanbao channel credential verification 
──────────────────── +// Issue: connect_channel for yuanbao previously stored creds and returned +// "connected" without ever calling the upstream sign-token endpoint, so +// random input (e.g. app_key=12) showed as Connected in the UI. The fix +// calls `/api/v5/robotLogic/sign-token` and propagates the API error. + +/// Build a Config pre-pointed at a mock `api_domain` so the verification +/// step hits the wiremock server instead of the live prod URL. +fn yuanbao_test_config(mock_uri: &str) -> (tempfile::TempDir, Config) { + let (tmp, mut config) = isolated_test_config(); + config.channels_config.yuanbao = Some(YuanbaoConfig { + api_domain: mock_uri.to_string(), + ..Default::default() + }); + (tmp, config) +} + +#[tokio::test] +async fn connect_yuanbao_rejects_invalid_credentials() { + use wiremock::matchers::{method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/api/v5/robotLogic/sign-token")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "code": 40001, + "msg": "invalid signature", + }))) + .mount(&server) + .await; + + let (_tmp, config) = yuanbao_test_config(&server.uri()); + let err = connect_channel( + &config, + "yuanbao", + ChannelAuthMode::ApiKey, + serde_json::json!({ "app_key": "12", "app_secret": "12" }), + ) + .await + .expect_err("invalid yuanbao credentials should fail"); + + assert!( + err.contains("yuanbao credential verification failed") && err.contains("invalid signature"), + "expected upstream API msg in error, got: {err}" + ); + + // Nothing should be persisted on failure: no TOML write, no credential row. + let raw = tokio::fs::read_to_string(&config.config_path).await.ok(); + if let Some(text) = raw { + let parsed: toml::Value = toml::from_str(&text).expect("config parses"); + // The mock api_domain we pre-loaded is allowed to be present, but + // app_key / app_secret must NOT have been written. 
+ if let Some(yb) = parsed + .get("channels_config") + .and_then(|v| v.get("yuanbao")) + .and_then(toml::Value::as_table) + { + assert_ne!( + yb.get("app_key").and_then(toml::Value::as_str), + Some("12"), + "app_key must not be persisted when verification fails" + ); + } + } +} + +#[tokio::test] +async fn connect_yuanbao_persists_when_credentials_valid() { + use wiremock::matchers::{method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/api/v5/robotLogic/sign-token")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "code": 0, + "data": { + "token": "tok-abc", + "bot_id": "bot-123", + "product": "yuanbao", + "source": "openhuman", + "duration": 3600, + } + }))) + .mount(&server) + .await; + + let (_tmp, config) = yuanbao_test_config(&server.uri()); + let result = connect_channel( + &config, + "yuanbao", + ChannelAuthMode::ApiKey, + serde_json::json!({ "app_key": "real-key", "app_secret": "real-secret" }), + ) + .await + .expect("valid yuanbao credentials should succeed"); + + assert_eq!(result.value.status, "connected"); + assert!(result.value.restart_required); + + let raw = tokio::fs::read_to_string(&config.config_path) + .await + .expect("config should be persisted"); + let parsed: toml::Value = toml::from_str(&raw).expect("config parses"); + let yb = parsed + .get("channels_config") + .and_then(|v| v.get("yuanbao")) + .and_then(toml::Value::as_table) + .expect("channels_config.yuanbao persisted"); + assert_eq!( + yb.get("app_key").and_then(toml::Value::as_str), + Some("real-key") + ); + // The plaintext `app_secret` must NOT be persisted in TOML — the + // runtime loads it from the encrypted credentials store instead. 
+ let toml_secret = yb.get("app_secret").and_then(toml::Value::as_str); + assert!( + toml_secret.is_none() || toml_secret == Some(""), + "app_secret must not be persisted in plaintext TOML, got {toml_secret:?}" + ); + + // The credentials store should contain the secret so startup can recover it. + let auth = crate::openhuman::credentials::AuthService::from_config(&config); + let profile = auth + .get_profile("channel:yuanbao:api_key", None) + .expect("credentials lookup succeeds") + .expect("yuanbao credentials stored"); + assert_eq!( + profile.metadata.get("app_secret").map(String::as_str), + Some("real-secret") + ); + assert_eq!( + profile.metadata.get("app_key").map(String::as_str), + Some("real-key") + ); +} + +#[tokio::test] +async fn connect_yuanbao_verifies_against_overridden_api_domain() { + // Regression: previously, `verify_yuanbao_credentials` rebuilt the + // YuanbaoConfig from `config.channels_config.yuanbao` alone and + // ignored the `api_domain` / `env` / `route_env` overrides on the + // connect-channel payload. A user submitting `env = "pre"` could + // pass verification against PROD and then fail after restart when + // the persisted override took effect. + // + // Here the base TOML's `api_domain` deliberately points at an + // unreachable URL — verification only succeeds if the override + // supplied in `creds_map` is what actually gets used. 
+ use wiremock::matchers::{header, method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/api/v5/robotLogic/sign-token")) + .and(header("X-Route-Env", "canary")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "code": 0, + "data": { + "token": "tok-override", + "bot_id": "bot-1", + "product": "yuanbao", + "source": "openhuman", + "duration": 3600, + } + }))) + .mount(&server) + .await; + + let (_tmp, mut config) = isolated_test_config(); + // Base TOML points to a black hole so the test fails immediately if + // the verifier ignores the override. + config.channels_config.yuanbao = Some(YuanbaoConfig { + api_domain: "http://127.0.0.1:1".to_string(), + ..Default::default() + }); + + let mock_uri = server.uri(); + let result = connect_channel( + &config, + "yuanbao", + ChannelAuthMode::ApiKey, + serde_json::json!({ + "app_key": "k", + "app_secret": "s", + "api_domain": mock_uri.clone(), + "route_env": "canary", + }), + ) + .await + .expect("override should be applied before verify"); + + assert_eq!(result.value.status, "connected"); + + // The override should also have been persisted (single source of + // truth between verify and persist). 
+ let raw = tokio::fs::read_to_string(&config.config_path) + .await + .expect("config should be persisted"); + let parsed: toml::Value = toml::from_str(&raw).expect("config parses"); + let yb = parsed + .get("channels_config") + .and_then(|v| v.get("yuanbao")) + .and_then(toml::Value::as_table) + .expect("channels_config.yuanbao persisted"); + assert_eq!( + yb.get("api_domain").and_then(toml::Value::as_str), + Some(mock_uri.as_str()), + ); + assert_eq!( + yb.get("route_env").and_then(toml::Value::as_str), + Some("canary"), + ); +} + +#[tokio::test] +async fn connect_yuanbao_persists_env_override() { + use wiremock::matchers::{method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/api/v5/robotLogic/sign-token")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "code": 0, + "data": { + "token": "tok-pre", + "bot_id": "bot-456", + "product": "yuanbao", + "source": "openhuman", + "duration": 3600, + } + }))) + .mount(&server) + .await; + + let (_tmp, config) = yuanbao_test_config(&server.uri()); + connect_channel( + &config, + "yuanbao", + ChannelAuthMode::ApiKey, + serde_json::json!({ + "app_key": "k", + "app_secret": "s", + "env": "pre", + "route_env": "canary", + }), + ) + .await + .expect("valid yuanbao credentials should succeed"); + + let raw = tokio::fs::read_to_string(&config.config_path) + .await + .expect("config should be persisted"); + let parsed: toml::Value = toml::from_str(&raw).expect("config parses"); + let yb = parsed + .get("channels_config") + .and_then(|v| v.get("yuanbao")) + .and_then(toml::Value::as_table) + .expect("channels_config.yuanbao persisted"); + assert_eq!(yb.get("env").and_then(toml::Value::as_str), Some("pre")); + assert_eq!( + yb.get("route_env").and_then(toml::Value::as_str), + Some("canary") + ); +} diff --git a/src/openhuman/channels/mod.rs b/src/openhuman/channels/mod.rs index 2101542973..41f9679504 
100644 --- a/src/openhuman/channels/mod.rs +++ b/src/openhuman/channels/mod.rs @@ -34,6 +34,7 @@ pub use providers::web; pub use providers::whatsapp; #[cfg(feature = "whatsapp-web")] pub use providers::whatsapp_web; +pub use providers::yuanbao; pub use cli::CliChannel; pub use dingtalk::DingTalkChannel; @@ -54,6 +55,7 @@ pub use traits::{Channel, SendMessage}; pub use whatsapp::WhatsAppChannel; #[cfg(feature = "whatsapp-web")] pub use whatsapp_web::WhatsAppWebChannel; +pub use yuanbao::YuanbaoChannel; pub use commands::doctor_channels; pub use controllers::{ChannelAuthMode, ChannelDefinition}; diff --git a/src/openhuman/channels/providers/mod.rs b/src/openhuman/channels/providers/mod.rs index 34bae4714a..d6844be06d 100644 --- a/src/openhuman/channels/providers/mod.rs +++ b/src/openhuman/channels/providers/mod.rs @@ -19,3 +19,4 @@ pub mod web; pub mod whatsapp; #[cfg(feature = "whatsapp-web")] pub mod whatsapp_web; +pub mod yuanbao; diff --git a/src/openhuman/channels/providers/yuanbao/channel.rs b/src/openhuman/channels/providers/yuanbao/channel.rs new file mode 100644 index 0000000000..b597dfe79b --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/channel.rs @@ -0,0 +1,705 @@ +//! Channel facade for the Yuanbao provider. +//! +//! This module owns the OpenHuman [`Channel`] implementation and keeps +//! provider wiring out of `mod.rs`. Protocol decoding, transport, inbound +//! filtering, and outbound sending remain delegated to sibling modules. 
+ +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +use async_trait::async_trait; +use tokio::sync::{mpsc, watch, Mutex as TokioMutex}; +use tokio::task::JoinHandle; +use tracing::{info, warn}; + +use crate::openhuman::channels::traits::{Channel, ChannelMessage, SendMessage}; + +use super::config::YuanbaoConfig; +use super::connection::{InboundEvent, YuanbaoConnection}; +use super::ids::{shorten_account_id, shorten_reply_target}; +use super::inbound::{InboundPipeline, PipelineOutcome, PipelineState}; +use super::outbound::OutboundSender; +use super::proto::decode_push_msg; +use super::sign::SignManager; +use super::{splitter, types}; + +/// Reply Heartbeat keepalive interval. The yuanbao gateway expects the +/// bot to ping (`SendPrivateHeartbeat RUNNING`) at this cadence so the +/// "正在输入" indicator stays alive for long-running responses. +const REPLY_HEARTBEAT_INTERVAL_SECS: u64 = 2; + +/// Hard ceiling on the in-memory shortened-recipient → original-recipient +/// map. Each entry is two short strings (~80 B), so 4096 distinct senders +/// give ~320 KB — plenty for any realistic chat load and small enough +/// that we can blow the whole map away when we hit the cap instead of +/// dragging in an LRU dependency. See `register_recipient_alias`. +const RECIPIENT_ALIAS_CAP: usize = 4096; + +/// The yuanbao channel — owns one WebSocket and one inbound pipeline. +pub struct YuanbaoChannel { + config: YuanbaoConfig, + connection: Arc, + outbound: Arc, + pipeline: Arc, + shutdown_tx: watch::Sender, + /// Holds the inbound receiver between `new()` and the first `listen()` call. + /// + /// `Channel::listen` takes `&self`, so we can't move the receiver out of + /// a field. Use a `Mutex>` so the first listener takes ownership + /// and subsequent calls fail cleanly. + inbound_rx: parking_lot::Mutex>>, + /// Per-recipient Reply Heartbeat keepalive tasks (started on `start_typing`). 
+ heartbeat_tasks: TokioMutex>>, + /// Reverse lookup table from shortened recipient ids (the ones we + /// emit on `ChannelMessage.sender` / `reply_target`) back to the + /// original server-recognized ids that outbound `send_c2c_message` + /// / `send_group_message` must use as `to_account` / `group_code`. + /// + /// Why this exists: yuanbao uids are ~64-char hashes, and + /// `super::ids::shorten_account_id` rewrites them as + /// `_` so the conversation store's per-thread + /// JSONL filenames stay under filesystem `NAME_MAX`. Without this + /// table the agent loop sends replies addressed to the shortened + /// hash, which the yuanbao gateway silently drops because no such + /// user exists. See `register_recipient_alias` / `resolve_recipient`. + recipient_aliases: TokioMutex>, +} + +impl YuanbaoChannel { + /// Build a channel from a validated config. Returns an error if the + /// config is missing required fields (so misconfiguration surfaces + /// at startup, not on the first inbound message). + pub fn new(mut config: YuanbaoConfig) -> anyhow::Result { + config.apply_env_defaults(); + config.validate()?; + let (shutdown_tx, _shutdown_rx) = watch::channel(false); + let (inbound_tx, inbound_rx) = mpsc::unbounded_channel::(); + + // SignManager is only useful when we have an app_secret — without + // it we'd never call the sign endpoint anyway. + let sign_manager: Option> = if !config.app_secret.is_empty() { + Some(SignManager::new(reqwest::Client::new())) + } else { + None + }; + + let connection = YuanbaoConnection::new(config.clone(), inbound_tx, sign_manager.clone()); + let outbound = Arc::new(OutboundSender::new( + Arc::clone(&connection), + sign_manager.clone(), + config.app_key.clone(), + config.bot_id.clone(), + )); + // PipelineState's `from_account` is used by the echo-guard stage to + // drop self-sent messages. 
We feed it the static config value here + // (which may be empty); the canonical server-issued bot_id only + // becomes known after sign-token, so this is a known minor gap — + // echo guard will simply not fire when bot_id isn't statically set. + let pipeline_state = PipelineState::new(&config, config.bot_id.clone()); + let pipeline = Arc::new(InboundPipeline::new(pipeline_state)); + + Ok(Self { + config, + connection, + outbound, + pipeline, + shutdown_tx, + inbound_rx: parking_lot::Mutex::new(Some(inbound_rx)), + heartbeat_tasks: TokioMutex::new(HashMap::new()), + recipient_aliases: TokioMutex::new(HashMap::new()), + }) + } + + /// Record a `shortened → original` recipient mapping so the outbound + /// side can recover the server-recognized id when the agent loop + /// addresses a reply with the shortened sender / reply_target it + /// received on `ChannelMessage`. + /// + /// No-op when the two are equal (uid is short enough to skip + /// shortening, or this is the `g:` group-target case where the + /// inner code is short). When the map crosses `RECIPIENT_ALIAS_CAP` + /// we clear it — the next inbound message from each active sender + /// re-populates the entry it needs, and stale entries from idle + /// conversations are fine to lose. + async fn register_recipient_alias(&self, shortened: &str, original: &str) { + if shortened == original { + return; + } + let mut m = self.recipient_aliases.lock().await; + if m.len() >= RECIPIENT_ALIAS_CAP { + warn!( + "[yuanbao] recipient alias map hit cap ({}), clearing", + RECIPIENT_ALIAS_CAP + ); + m.clear(); + } + m.insert(shortened.to_string(), original.to_string()); + } + + /// Look up the server-recognized recipient for a (possibly + /// shortened) inbound id. Falls back to the input unchanged when + /// nothing is registered — which keeps the previous behavior for + /// recipients that don't go through `shorten_account_id` (short + /// uids, group codes, `imessage`-style ids). 
+ async fn resolve_recipient(&self, recipient: &str) -> String { + let m = self.recipient_aliases.lock().await; + m.get(recipient) + .cloned() + .unwrap_or_else(|| recipient.to_string()) + } + + fn split_message(&self, text: &str) -> Vec { + splitter::split_markdown(text, self.config.max_message_length) + } + + async fn start_heartbeat_task(&self, recipient: &str) { + let mut tasks = self.heartbeat_tasks.lock().await; + if tasks.contains_key(recipient) { + return; + } + let outbound = Arc::clone(&self.outbound); + let target = recipient.to_string(); + let handle = tokio::spawn(async move { + let mut interval = + tokio::time::interval(Duration::from_secs(REPLY_HEARTBEAT_INTERVAL_SECS)); + interval.tick().await; // skip first tick (start_typing already sent RUNNING) + loop { + interval.tick().await; + if let Err(e) = outbound.start_heartbeat(&target).await { + // Connection bouncing — bail out of this loop; the + // next start_typing call will spawn a new one. + warn!( + "[yuanbao] reply heartbeat send failed: {} — stopping loop", + e + ); + return; + } + } + }); + tasks.insert(recipient.to_string(), handle); + } + + async fn stop_heartbeat_task(&self, recipient: &str) { + let mut tasks = self.heartbeat_tasks.lock().await; + if let Some(handle) = tasks.remove(recipient) { + handle.abort(); + } + } +} + +#[async_trait] +impl Channel for YuanbaoChannel { + fn name(&self) -> &str { + "yuanbao" + } + + async fn send(&self, message: &SendMessage) -> anyhow::Result<()> { + let chunks = self.split_message(&message.content); + let ref_msg_id = message.thread_ts.as_deref(); + let recipient = self.resolve_recipient(&message.recipient).await; + for chunk in &chunks { + self.outbound + .send_text(&recipient, chunk, ref_msg_id) + .await?; + } + Ok(()) + } + + fn supports_draft_updates(&self) -> bool { + // Routes turns through the streaming code path even though Yuanbao + // itself has no edit-message capability. 
We accept the UX cost (no + // progressive rendering — the reply appears all at once in + // `finalize_draft`) in exchange for streaming's tolerance of + // malformed `usage` chunks; the non-streaming parser fails the + // whole turn when an upstream LLM returns string-typed token counts. + true + } + + async fn send_draft(&self, message: &SendMessage) -> anyhow::Result> { + // Marker id so dispatch spins up the progress consumer task; + // nothing is sent to the user here. Real content goes out in + // `finalize_draft`. See `supports_draft_updates` for rationale. + Ok(Some(format!("yb-draft:{}", message.recipient))) + } + + async fn update_draft( + &self, + _recipient: &str, + _message_id: &str, + _text: &str, + ) -> anyhow::Result<()> { + Ok(()) + } + + async fn finalize_draft( + &self, + recipient: &str, + _message_id: &str, + text: &str, + thread_ts: Option<&str>, + ) -> anyhow::Result<()> { + let chunks = self.split_message(text); + let recipient = self.resolve_recipient(recipient).await; + for chunk in &chunks { + self.outbound + .send_text(&recipient, chunk, thread_ts) + .await?; + } + Ok(()) + } + + async fn listen(&self, tx: mpsc::Sender) -> anyhow::Result<()> { + // Take the inbound receiver. A second listener would just exit early. + let mut inbound_rx = match self.inbound_rx.lock().take() { + Some(rx) => rx, + None => { + warn!("[yuanbao] listen() called twice — second call exits"); + return Ok(()); + } + }; + + let conn = Arc::clone(&self.connection); + let shutdown_rx = self.shutdown_tx.subscribe(); + let conn_task = tokio::spawn(async move { + conn.run(shutdown_rx).await; + }); + + info!("[yuanbao] channel listening — pipeline ready"); + let mut shutdown_rx2 = self.shutdown_tx.subscribe(); + loop { + tokio::select! 
{ + _ = shutdown_rx2.changed() => { + info!("[yuanbao] listen loop received shutdown"); + break; + } + event = inbound_rx.recv() => { + match event { + Some(InboundEvent::Push(frame)) => { + self.dispatch_push(frame, &tx).await; + } + Some(InboundEvent::Kickout(reason)) => { + warn!("[yuanbao] kickout: {} — stopping listen loop", reason); + break; + } + None => { + warn!("[yuanbao] inbound channel closed"); + break; + } + } + } + } + } + + let _ = self.shutdown_tx.send(true); + conn_task.abort(); + Ok(()) + } + + async fn health_check(&self) -> bool { + self.connection.is_connected() + } + + async fn start_typing(&self, recipient: &str) -> anyhow::Result<()> { + // Send RUNNING immediately, then spawn a 2s keepalive so the + // indicator doesn't expire while we generate. + let recipient = self.resolve_recipient(recipient).await; + self.outbound.start_heartbeat(&recipient).await?; + self.start_heartbeat_task(&recipient).await; + Ok(()) + } + + async fn stop_typing(&self, recipient: &str) -> anyhow::Result<()> { + let recipient = self.resolve_recipient(recipient).await; + self.stop_heartbeat_task(&recipient).await; + self.outbound.stop_heartbeat(&recipient).await?; + Ok(()) + } + + fn supports_reactions(&self) -> bool { + false + } +} + +impl YuanbaoChannel { + async fn dispatch_push(&self, frame: types::ConnFrame, tx: &mpsc::Sender) { + // The Yuanbao gateway pushes inbound messages with `cmd_type=Push`; + // the actual `cmd` word is decided server-side and varies (mirrors + // hermes-agent `yuanbao.py::_handle_received_frame` which routes + // purely on cmd_type). The connection layer has already filtered + // out non-push frames before we get here, so every frame we see + // should be a candidate for the inbound pipeline. + if frame.data.is_empty() { + tracing::trace!("[yuanbao] empty push body cmd={} — skipping", frame.cmd); + return; + } + // Some push frames wrap the biz body in an extra + // `PushMsg { cmd, module, msg_id, data }` envelope; others (e.g. 
+ // cmd="inbound_message", module="yuanbao_openclaw_proxy") put the + // InboundMessagePush bytes directly in `ConnMsg.data` with the + // ConnMsg.head already carrying cmd/module. Mirrors plugin + // client.ts::onPush (l. 813): try PushMsg first, but only accept + // it when it has a non-empty cmd or module; otherwise treat the + // raw frame.data as the biz body. + let unwrapped: Option> = match decode_push_msg(&frame.data) { + Ok(p) if (!p.cmd.is_empty() || !p.module.is_empty()) && !p.data.is_empty() => { + info!( + "[yuanbao] push envelope decoded: cmd={} module={} msg_id={} biz_len={}", + p.cmd, + p.module, + p.msg_id, + p.data.len() + ); + Some(p.data) + } + _ => { + info!( + "[yuanbao] push has no PushMsg envelope — treating ConnMsg.data as biz body (conn_cmd={} module={} len={})", + frame.cmd, + frame.module, + frame.data.len() + ); + None + } + }; + let biz_body: &[u8] = unwrapped.as_deref().unwrap_or(&frame.data); + let outcome = self.pipeline.process(biz_body).await; + match outcome { + PipelineOutcome::Dispatch(ctx) => { + // Shorten ids at the channel boundary so the composite thread_id + // derived downstream (channel:yuanbao__) + // stays under filesystem NAME_MAX once hex-encoded for the + // per-thread JSONL filename. Yuanbao internals (echo guard, + // access control, owner-command check) keep the original + // `from_account` — see `super::ids` for the format and rationale. + let original_from = ctx.msg.from_account.clone(); + let original_reply_target = ctx.source.reply_target(); + let short_sender = shorten_account_id(&original_from); + let short_reply_target = shorten_reply_target(&original_reply_target); + // Remember the original ids so the outbound side can + // recover them when the agent loop addresses a reply + // with the shortened values it sees here. 
+ self.register_recipient_alias(&short_sender, &original_from) + .await; + self.register_recipient_alias(&short_reply_target, &original_reply_target) + .await; + let msg = ChannelMessage { + id: ctx.msg.msg_id.clone(), + sender: short_sender, + reply_target: short_reply_target, + content: if ctx.text.is_empty() && !ctx.image_urls.is_empty() { + // Surface image URLs as content so downstream tools have something to work with. + ctx.image_urls.join("\n") + } else { + ctx.text.clone() + }, + channel: "yuanbao".into(), + timestamp: ctx.msg.msg_time as u64, + thread_ts: None, + }; + if tx.send(msg).await.is_err() { + warn!("[yuanbao] dispatch receiver gone — dropping message"); + } + } + PipelineOutcome::Filtered(reason) => { + tracing::trace!("[yuanbao] filtered at {reason}"); + } + PipelineOutcome::Failed(err) => { + // Intentionally omit the raw biz payload — it can carry + // user content / PII. The decoder error already encodes + // the structural reason; only the length is safe to log. 
+ warn!( + "[yuanbao] pipeline error: {err} | biz_len={}", + biz_body.len() + ); + } + } + } +} + +impl Drop for YuanbaoChannel { + fn drop(&mut self) { + let _ = self.shutdown_tx.send(true); + } +} + +#[cfg(test)] +mod tests { + use crate::openhuman::channels::traits::Channel; + + use super::*; + + fn good_cfg() -> YuanbaoConfig { + let mut c = YuanbaoConfig::default(); + c.app_key = "ak".into(); + c.ws_domain = "wss://example".into(); + c.token = "tok".into(); + c + } + + #[test] + fn channel_construction_validates() { + let ch = YuanbaoChannel::new(good_cfg()).unwrap(); + assert_eq!(ch.name(), "yuanbao"); + } + + #[test] + fn invalid_config_rejected() { + let mut c = YuanbaoConfig::default(); + c.app_key = "ak".into(); + // missing ws_domain + assert!(YuanbaoChannel::new(c).is_err()); + } + + #[test] + fn split_short_message_returns_one() { + let ch = YuanbaoChannel::new(good_cfg()).unwrap(); + let chunks = ch.split_message("hello"); + assert_eq!(chunks, vec!["hello"]); + } + + #[test] + fn split_respects_newlines() { + let mut c = good_cfg(); + c.max_message_length = 12; + let ch = YuanbaoChannel::new(c).unwrap(); + let chunks = ch.split_message("line one\nline two\nline three"); + assert!(chunks.len() >= 2); + // No chunk exceeds the limit. 
+ for chunk in &chunks { + assert!(chunk.len() <= 12, "chunk too long: {chunk:?}"); + } + } + + #[tokio::test] + async fn resolve_recipient_returns_input_when_no_alias_registered() { + let ch = YuanbaoChannel::new(good_cfg()).unwrap(); + assert_eq!(ch.resolve_recipient("short_uid").await, "short_uid"); + } + + #[tokio::test] + async fn register_and_resolve_dm_alias_recovers_original_uid() { + let ch = YuanbaoChannel::new(good_cfg()).unwrap(); + let original = "x".repeat(64); + let shortened = shorten_account_id(&original); + assert_ne!(shortened, original, "test premise: should actually shorten"); + ch.register_recipient_alias(&shortened, &original).await; + assert_eq!(ch.resolve_recipient(&shortened).await, original); + } + + #[tokio::test] + async fn register_recipient_alias_is_noop_for_equal_pair() { + let ch = YuanbaoChannel::new(good_cfg()).unwrap(); + // Short uid that wouldn't be shortened — caller still hands us + // (s, s); we should silently skip and not eat a map slot. + ch.register_recipient_alias("short", "short").await; + let m = ch.recipient_aliases.lock().await; + assert!(m.is_empty()); + } + + #[tokio::test] + async fn resolve_recipient_preserves_group_prefix_via_alias() { + let ch = YuanbaoChannel::new(good_cfg()).unwrap(); + let long_group_code = "g".repeat(64); + let original = format!("g:{long_group_code}"); + let shortened = shorten_reply_target(&original); + assert_ne!(shortened, original); + assert!(shortened.starts_with("g:")); + ch.register_recipient_alias(&shortened, &original).await; + assert_eq!(ch.resolve_recipient(&shortened).await, original); + } + + #[tokio::test] + async fn alias_map_clears_when_cap_is_hit() { + let ch = YuanbaoChannel::new(good_cfg()).unwrap(); + // Pre-fill up to the cap with distinct entries. 
+ for i in 0..RECIPIENT_ALIAS_CAP { + ch.register_recipient_alias(&format!("s{i}"), &format!("o{i}")) + .await; + } + assert_eq!(ch.recipient_aliases.lock().await.len(), RECIPIENT_ALIAS_CAP); + // One more entry must trigger a clear, then insert the new entry. + ch.register_recipient_alias("new_short", "new_original") + .await; + let m = ch.recipient_aliases.lock().await; + assert_eq!(m.len(), 1); + assert_eq!(m.get("new_short").map(String::as_str), Some("new_original")); + } + + // ─── trivial trait methods ───────────────────────────────────── + + #[test] + fn supports_draft_updates_is_true() { + let ch = YuanbaoChannel::new(good_cfg()).unwrap(); + assert!(ch.supports_draft_updates()); + } + + #[test] + fn supports_reactions_is_false() { + let ch = YuanbaoChannel::new(good_cfg()).unwrap(); + assert!(!ch.supports_reactions()); + } + + #[tokio::test] + async fn send_draft_returns_marker_id() { + let ch = YuanbaoChannel::new(good_cfg()).unwrap(); + let msg = SendMessage::new("ignored", "user-42"); + let id = ch.send_draft(&msg).await.unwrap(); + assert_eq!(id.as_deref(), Some("yb-draft:user-42")); + } + + #[tokio::test] + async fn update_draft_is_a_noop_ok() { + let ch = YuanbaoChannel::new(good_cfg()).unwrap(); + assert!(ch.update_draft("user-42", "any-id", "text").await.is_ok()); + } + + #[tokio::test] + async fn health_check_is_false_when_socket_not_connected() { + // Real connect requires a WebSocket; we only verify the + // disconnected default here. The connected branch is exercised + // by `connection::tests::set_state_connected_flips_is_connected_flag`. 
+ let ch = YuanbaoChannel::new(good_cfg()).unwrap(); + assert!(!ch.health_check().await); + } + + // ─── dispatch_push branches ──────────────────────────────────── + + fn make_push_frame(cmd: &str, data: Vec) -> types::ConnFrame { + types::ConnFrame { + cmd_type: super::super::proto_constants::cmd_type::PUSH, + cmd: cmd.into(), + module: "yuanbao_openclaw_proxy".into(), + seq_no: 0, + msg_id: String::new(), + need_ack: false, + status: 0, + data, + } + } + + #[tokio::test] + async fn dispatch_push_empty_body_is_skipped() { + let ch = YuanbaoChannel::new(good_cfg()).unwrap(); + let (tx, mut rx) = mpsc::channel::(4); + let frame = make_push_frame("noop", Vec::new()); + ch.dispatch_push(frame, &tx).await; + // No message should reach the sender. + assert!(rx.try_recv().is_err()); + } + + #[tokio::test] + async fn dispatch_push_garbage_body_does_not_dispatch() { + // Body is not a valid protobuf push *and* not valid JSON → Failed. + // dispatch_push should log + swallow, not propagate panic. + let ch = YuanbaoChannel::new(good_cfg()).unwrap(); + let (tx, mut rx) = mpsc::channel::(4); + let frame = make_push_frame("inbound_message", vec![0xFF, 0xFF, 0xFF, 0xFF]); + ch.dispatch_push(frame, &tx).await; + assert!(rx.try_recv().is_err()); + } + + #[tokio::test] + async fn dispatch_push_dm_text_reaches_listener() { + // Build a minimal `InboundMessagePush` directly in ConnFrame.data + // (no PushMsg envelope), with a single TIMTextElem so the pipeline + // dispatches. + use super::super::proto::{encode_msg_body_element, encode_varint}; + let elem = types::MsgBodyElement { + msg_type: "TIMTextElem".into(), + msg_content: types::MsgContent { + text: Some("hello".into()), + ..Default::default() + }, + }; + let elem_bytes = encode_msg_body_element(&elem); + + // Hand-roll an InboundMessagePush so we don't depend on a helper: + // field 2 = from_account, field 3 = to_account, field 12 = msg_id, + // field 13 = repeated MsgBodyElement. 
+ let mut biz = Vec::new(); + let put_string = |fnum: u32, s: &str, b: &mut Vec| { + encode_varint(((fnum as u64) << 3) | 2, b); + encode_varint(s.len() as u64, b); + b.extend_from_slice(s.as_bytes()); + }; + put_string(2, "alice", &mut biz); + put_string(3, "bot1", &mut biz); + put_string(12, "mid-x", &mut biz); + encode_varint(((13u64) << 3) | 2, &mut biz); + encode_varint(elem_bytes.len() as u64, &mut biz); + biz.extend_from_slice(&elem_bytes); + + // Disable group_at_required and use open dm_access so the + // pipeline passes all stages for this DM. + let mut cfg = good_cfg(); + cfg.dm_access = "open".into(); + cfg.bot_id = "bot1".into(); + let ch = YuanbaoChannel::new(cfg).unwrap(); + + let frame = make_push_frame("inbound_message", biz); + let (tx, mut rx) = mpsc::channel::(4); + ch.dispatch_push(frame, &tx).await; + let msg = rx.try_recv().expect("dispatch should produce one message"); + assert_eq!(msg.id, "mid-x"); + assert_eq!(msg.content, "hello"); + assert_eq!(msg.channel, "yuanbao"); + } + + #[tokio::test] + async fn dispatch_push_filtered_by_dedup_does_not_double_dispatch() { + use super::super::proto::{encode_msg_body_element, encode_varint}; + let elem = types::MsgBodyElement { + msg_type: "TIMTextElem".into(), + msg_content: types::MsgContent { + text: Some("dup".into()), + ..Default::default() + }, + }; + let elem_bytes = encode_msg_body_element(&elem); + let mut biz = Vec::new(); + let put_string = |fnum: u32, s: &str, b: &mut Vec| { + encode_varint(((fnum as u64) << 3) | 2, b); + encode_varint(s.len() as u64, b); + b.extend_from_slice(s.as_bytes()); + }; + put_string(2, "alice", &mut biz); + put_string(3, "bot1", &mut biz); + put_string(12, "dup-id", &mut biz); + encode_varint(((13u64) << 3) | 2, &mut biz); + encode_varint(elem_bytes.len() as u64, &mut biz); + biz.extend_from_slice(&elem_bytes); + + let mut cfg = good_cfg(); + cfg.dm_access = "open".into(); + cfg.bot_id = "bot1".into(); + let ch = YuanbaoChannel::new(cfg).unwrap(); + let (tx, mut 
rx) = mpsc::channel::(4); + ch.dispatch_push(make_push_frame("inbound_message", biz.clone()), &tx) + .await; + assert!(rx.try_recv().is_ok(), "first should dispatch"); + ch.dispatch_push(make_push_frame("inbound_message", biz), &tx) + .await; + assert!(rx.try_recv().is_err(), "second (same id) should dedup"); + } + + // ─── heartbeat task lifecycle ────────────────────────────────── + + #[tokio::test] + async fn start_heartbeat_task_inserts_and_stop_removes() { + let ch = YuanbaoChannel::new(good_cfg()).unwrap(); + ch.start_heartbeat_task("recipient-1").await; + assert!( + ch.heartbeat_tasks.lock().await.contains_key("recipient-1"), + "should have spawned a task for recipient-1" + ); + // Second start for same recipient is a no-op (does not double-spawn). + ch.start_heartbeat_task("recipient-1").await; + assert_eq!(ch.heartbeat_tasks.lock().await.len(), 1); + + ch.stop_heartbeat_task("recipient-1").await; + assert!(ch.heartbeat_tasks.lock().await.is_empty()); + // Stopping a recipient with no task is also a no-op. + ch.stop_heartbeat_task("never-started").await; + } +} diff --git a/src/openhuman/channels/providers/yuanbao/config.rs b/src/openhuman/channels/providers/yuanbao/config.rs new file mode 100644 index 0000000000..3f79888d67 --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/config.rs @@ -0,0 +1,250 @@ +//! Yuanbao channel configuration. +//! +//! Loaded from `ChannelsConfig.yuanbao` (TOML) and validated before the +//! channel is started. Mirrors the Python `YuanbaoAdapter` configuration +//! surface (hermes-agent `gateway/platforms/yuanbao.py`). + +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +use super::errors::YuanbaoError; + +/// Production environment endpoints (default). +const PROD_API_DOMAIN: &str = "https://bot.yuanbao.tencent.com"; +const PROD_WS_URL: &str = "wss://bot-wss.yuanbao.tencent.com/wss/connection"; +/// Pre-release environment endpoints. Opt in via `env = "pre"` in TOML. 
+const PRE_API_DOMAIN: &str = "https://bot-pre.yuanbao.tencent.com"; +const PRE_WS_URL: &str = "wss://bot-wss-pre.yuanbao.tencent.com/wss/connection"; + +/// User-facing config for the Yuanbao channel. +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +pub struct YuanbaoConfig { + /// Application key (`X-ID` header / AuthBind biz_id). + pub app_key: String, + /// Application secret — used by the token-sign endpoint. + pub app_secret: String, + /// Bot account ID (uid for AuthBind). Optional — when empty, derived + /// from the AuthBindRsp payload after the first handshake. + #[serde(default)] + pub bot_id: String, + /// Environment selector for endpoint defaults: `"prod"` (default) or `"pre"`. + /// Only consulted when `api_domain` / `ws_domain` are empty. + #[serde(default = "default_env")] + pub env: String, + /// API base URL. Empty by default — derived from `env` at channel start. + /// Set explicitly in TOML to point at a custom deployment. + #[serde(default)] + pub api_domain: String, + /// WebSocket base URL. Empty by default — derived from `env` at channel + /// start. Set explicitly in TOML to point at a custom deployment. + #[serde(default)] + pub ws_domain: String, + /// Optional `route_env` header (canary routing). + #[serde(default)] + pub route_env: String, + /// Optional pre-provisioned token. When empty, the channel calls + /// `api_domain/api/token/sign` with `(app_key, app_secret)` to fetch one. + #[serde(default)] + pub token: String, + /// Plugin/bot version reported in `AuthBindReq.DeviceInfo.bot_version`. + #[serde(default = "default_bot_version")] + pub bot_version: String, + /// Optional bot display name — used by the `@bot` mention guard. + #[serde(default)] + pub bot_name: String, + + /// DM access policy: `open` / `allowlist` / `closed`. + #[serde(default = "default_dm_policy")] + pub dm_access: String, + /// Group access policy: `open` / `allowlist` / `closed`. 
+ #[serde(default = "default_group_policy")] + pub group_access: String, + /// When `dm_access = "allowlist"`, only these UIDs may DM the bot. + #[serde(default)] + pub allowed_users: Vec, + /// When `group_access = "allowlist"`, only these group codes are allowed. + #[serde(default)] + pub allowed_groups: Vec, + /// Owner UID — receives elevated `/admin` commands. + #[serde(default)] + pub owner_id: String, + + /// Group messages must `@bot` to be processed (recommended). + #[serde(default = "default_true")] + pub group_at_required: bool, + + /// Maximum WS heartbeat interval override (seconds). 0 = use server-driven default. + #[serde(default)] + pub heartbeat_interval_secs: u64, + /// Reconnect retry budget — 0 means use the default cap (100). + #[serde(default)] + pub max_reconnect_attempts: u32, + + /// Per-message body length cap before splitting (UTF-8 bytes). + #[serde(default = "default_max_msg_len")] + pub max_message_length: usize, + /// Maximum inbound media file size in MiB. + #[serde(default = "default_max_media_mb")] + pub max_media_mb: u32, +} + +impl Default for YuanbaoConfig { + fn default() -> Self { + Self { + app_key: String::new(), + app_secret: String::new(), + bot_id: String::new(), + env: default_env(), + api_domain: String::new(), + ws_domain: String::new(), + route_env: String::new(), + token: String::new(), + bot_version: default_bot_version(), + bot_name: String::new(), + dm_access: default_dm_policy(), + group_access: default_group_policy(), + allowed_users: Vec::new(), + allowed_groups: Vec::new(), + owner_id: String::new(), + group_at_required: true, + heartbeat_interval_secs: 0, + max_reconnect_attempts: 0, + max_message_length: default_max_msg_len(), + max_media_mb: default_max_media_mb(), + } + } +} + +impl YuanbaoConfig { + /// Fill empty `api_domain` / `ws_domain` from the configured `env`. The + /// UI only collects `app_key` + `app_secret`; endpoints are derived + /// here so the renderer never needs to know about them. 
TOML values + /// take precedence (when non-empty), so existing deployments and + /// custom routes keep working. + pub fn apply_env_defaults(&mut self) { + let env = self.env.as_str(); + if self.api_domain.is_empty() { + self.api_domain = match env { + "pre" => PRE_API_DOMAIN.into(), + _ => PROD_API_DOMAIN.into(), + }; + } + if self.ws_domain.is_empty() { + self.ws_domain = match env { + "pre" => PRE_WS_URL.into(), + _ => PROD_WS_URL.into(), + }; + } + } + + /// Validate required fields. Called at channel construction time so + /// misconfiguration surfaces early in `start_channels`, not after a + /// failed WebSocket handshake. + pub fn validate(&self) -> Result<(), YuanbaoError> { + if self.app_key.is_empty() { + return Err(YuanbaoError::Config("`app_key` is required".into())); + } + if self.ws_domain.is_empty() { + return Err(YuanbaoError::Config("`ws_domain` is required".into())); + } + if self.token.is_empty() && self.app_secret.is_empty() { + return Err(YuanbaoError::Config( + "either `token` or `app_secret` must be set".into(), + )); + } + if self.api_domain.is_empty() && self.token.is_empty() { + return Err(YuanbaoError::Config( + "`api_domain` is required when `token` is not pre-provisioned".into(), + )); + } + Ok(()) + } +} + +fn default_bot_version() -> String { + "openhuman/0.1.0".into() +} + +fn default_env() -> String { + "prod".into() +} + +fn default_dm_policy() -> String { + "open".into() +} + +fn default_group_policy() -> String { + "allowlist".into() +} + +fn default_true() -> bool { + true +} + +fn default_max_msg_len() -> usize { + 4500 +} + +fn default_max_media_mb() -> u32 { + 50 +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn default_config_is_invalid() { + let cfg = YuanbaoConfig::default(); + assert!(cfg.validate().is_err()); + } + + #[test] + fn validate_requires_app_key() { + let mut cfg = YuanbaoConfig::default(); + cfg.ws_domain = "wss://example".into(); + cfg.token = "tok".into(); + assert!(cfg.validate().is_err()); 
+ cfg.app_key = "ak".into(); + assert!(cfg.validate().is_ok()); + } + + #[test] + fn validate_requires_token_or_secret() { + let mut cfg = YuanbaoConfig::default(); + cfg.app_key = "ak".into(); + cfg.ws_domain = "wss://example".into(); + cfg.api_domain = "https://api".into(); + assert!(cfg.validate().is_err()); + cfg.app_secret = "secret".into(); + assert!(cfg.validate().is_ok()); + } + + #[test] + fn apply_env_defaults_fills_prod_when_empty() { + let mut cfg = YuanbaoConfig::default(); + assert_eq!(cfg.env, "prod"); + cfg.apply_env_defaults(); + assert_eq!(cfg.api_domain, PROD_API_DOMAIN); + assert_eq!(cfg.ws_domain, PROD_WS_URL); + } + + #[test] + fn apply_env_defaults_respects_pre_env() { + let mut cfg = YuanbaoConfig::default(); + cfg.env = "pre".into(); + cfg.apply_env_defaults(); + assert_eq!(cfg.api_domain, PRE_API_DOMAIN); + assert_eq!(cfg.ws_domain, PRE_WS_URL); + } + + #[test] + fn apply_env_defaults_preserves_explicit_overrides() { + let mut cfg = YuanbaoConfig::default(); + cfg.api_domain = "https://custom.example".into(); + cfg.ws_domain = "wss://custom.example".into(); + cfg.apply_env_defaults(); + assert_eq!(cfg.api_domain, "https://custom.example"); + assert_eq!(cfg.ws_domain, "wss://custom.example"); + } +} diff --git a/src/openhuman/channels/providers/yuanbao/connection.rs b/src/openhuman/channels/providers/yuanbao/connection.rs new file mode 100644 index 0000000000..c906533656 --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/connection.rs @@ -0,0 +1,772 @@ +//! Yuanbao WebSocket connection manager. +//! +//! Owns one WebSocket to the gateway and runs: +//! 1. token sign-fetch (via [`SignManager`]) → `auth-bind` handshake +//! 2. periodic `ping` heartbeats +//! 3. inbound frame fan-out (decoded `ConnFrame` → mpsc) +//! 4. outbound request/response correlation via per-`msg_id` oneshot +//! 5. exponential-backoff reconnect with a no-retry close-code allowlist +//! +//! All public APIs are `&self` so the connection can be wrapped in +//! 
`Arc<…>` and shared between the listen loop and outbound senders. + +use std::collections::HashMap; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::Arc; +use std::time::Duration; + +use futures_util::{SinkExt, StreamExt}; +use parking_lot::Mutex as ParkingMutex; +use tokio::net::TcpStream; +use tokio::sync::{mpsc, oneshot, watch, Mutex}; +use tokio::time; +use tokio_tungstenite::{ + connect_async, tungstenite::protocol::Message, MaybeTlsStream, WebSocketStream, +}; +use tracing::{error, info, warn}; +use uuid::Uuid; + +use super::config::YuanbaoConfig; +use super::errors::{YuanbaoError, NO_RECONNECT_CLOSE_CODES}; +use super::proto::{ + decode_auth_bind_rsp, decode_conn_msg, encode_auth_bind, encode_ping, encode_push_ack, +}; +use super::proto_constants::*; +use super::sign::SignManager; +use super::types::{Account, ConnFrame, ConnectionState}; + +type WsSender = + futures_util::stream::SplitSink>, Message>; + +/// One inbound event delivered to the listen loop. +pub enum InboundEvent { + /// A regular biz push. + Push(ConnFrame), + /// Server told us we were kicked off. + Kickout(String), +} + +/// In-flight outbound request awaiting a matching `Response` frame. +type PendingMap = HashMap>; + +/// Long-lived connection manager. 
+pub struct YuanbaoConnection { + config: YuanbaoConfig, + state: ParkingMutex, + is_connected: AtomicBool, + msg_id_seq: AtomicU64, + sender: Mutex>, + inbound_tx: mpsc::UnboundedSender, + account: ParkingMutex, + sign_manager: Option>, + pending: ParkingMutex, +} + +impl YuanbaoConnection { + pub fn new( + config: YuanbaoConfig, + inbound_tx: mpsc::UnboundedSender, + sign_manager: Option>, + ) -> Arc { + let initial_account = Account { + uid: config.bot_id.clone(), + ..Default::default() + }; + Arc::new(Self { + config, + state: ParkingMutex::new(ConnectionState::Disconnected), + is_connected: AtomicBool::new(false), + msg_id_seq: AtomicU64::new(1), + sender: Mutex::new(None), + inbound_tx, + account: ParkingMutex::new(initial_account), + sign_manager, + pending: ParkingMutex::new(HashMap::new()), + }) + } + + pub fn is_connected(&self) -> bool { + self.is_connected.load(Ordering::Relaxed) + } + + pub fn state(&self) -> ConnectionState { + *self.state.lock() + } + + fn set_state(&self, new: ConnectionState) { + *self.state.lock() = new; + self.is_connected + .store(matches!(new, ConnectionState::Connected), Ordering::Relaxed); + } + + /// Current account info (best-effort — empty fields until auth-bind succeeds). + pub fn account(&self) -> Account { + self.account.lock().clone() + } + + fn update_account(&self, f: impl FnOnce(&mut Account)) { + let mut g = self.account.lock(); + f(&mut g); + } + + /// Per-process monotonic application msg_id. + pub fn next_msg_id(&self, prefix: &str) -> String { + let n = self.msg_id_seq.fetch_add(1, Ordering::Relaxed); + format!("{prefix}_{n}") + } + + /// Send a raw binary frame. Returns `NotConnected` if the connection + /// isn't currently up. 
+ pub async fn send_frame(&self, data: Vec) -> Result<(), YuanbaoError> { + let mut guard = self.sender.lock().await; + match guard.as_mut() { + Some(s) => s + .send(Message::Binary(data)) + .await + .map_err(|e| YuanbaoError::WebSocket(e.to_string())), + None => Err(YuanbaoError::NotConnected), + } + } + + /// Send an already-encoded `ConnMsg` (alias of `send_frame`). + pub async fn send_conn_msg(&self, frame_bytes: Vec) -> Result<(), YuanbaoError> { + self.send_frame(frame_bytes).await + } + + /// Send a request and wait for the matching `Response` (correlated by + /// `msg_id`). Times out after `timeout` and removes the pending entry. + pub async fn send_and_wait( + &self, + msg_id: &str, + frame_bytes: Vec, + timeout: Duration, + ) -> Result { + let (tx, rx) = oneshot::channel(); + { + let mut p = self.pending.lock(); + p.insert(msg_id.to_string(), tx); + } + if let Err(e) = self.send_frame(frame_bytes).await { + self.pending.lock().remove(msg_id); + return Err(e); + } + match tokio::time::timeout(timeout, rx).await { + Ok(Ok(frame)) => Ok(frame), + Ok(Err(_)) => { + self.pending.lock().remove(msg_id); + Err(YuanbaoError::SendFailed(format!( + "correlator channel closed for msg_id={msg_id}" + ))) + } + Err(_) => { + self.pending.lock().remove(msg_id); + Err(YuanbaoError::Timeout(format!("msg_id={msg_id}"))) + } + } + } + + /// Trigger a graceful shutdown. + pub async fn shutdown(&self) { + let mut guard = self.sender.lock().await; + if let Some(mut s) = guard.take() { + let _ = s.send(Message::Close(None)).await; + let _ = s.close().await; + } + // Drop all pending waiters so callers stop hanging. + let mut pending = self.pending.lock(); + pending.clear(); + self.set_state(ConnectionState::Disconnected); + } + + /// Main reconnection loop. Returns when `shutdown` flips to `true`. 
+ pub async fn run(self: Arc, mut shutdown: watch::Receiver) { + let max_attempts = if self.config.max_reconnect_attempts > 0 { + self.config.max_reconnect_attempts + } else { + MAX_RECONNECT_ATTEMPTS + }; + let mut attempt: u32 = 0; + + loop { + if *shutdown.borrow() { + info!("[yuanbao] shutdown signaled, stopping connection loop"); + self.shutdown().await; + return; + } + if attempt >= max_attempts { + error!("[yuanbao] giving up after {} reconnect attempts", attempt); + return; + } + + self.set_state(if attempt == 0 { + ConnectionState::Connecting + } else { + ConnectionState::Reconnecting + }); + + let outcome = self.connect_once(&mut shutdown).await; + match outcome { + Ok(Some(code)) if NO_RECONNECT_CLOSE_CODES.contains(&code) => { + error!("[yuanbao] no-reconnect close code {} — stopping", code); + return; + } + Ok(close_code) => info!("[yuanbao] connection closed (code={:?})", close_code), + Err(e) => warn!("[yuanbao] connection error: {}", e), + } + + self.set_state(ConnectionState::Disconnected); + *self.sender.lock().await = None; + self.pending.lock().clear(); + + // `connect_once` may have returned because shutdown fired inside + // its read loop. In that case we must not sleep through the + // reconnect backoff — exit immediately so stop is responsive. + if *shutdown.borrow() { + info!("[yuanbao] shutdown signaled, stopping connection loop"); + self.shutdown().await; + return; + } + + attempt += 1; + let delay = backoff_seconds(attempt); + info!( + "[yuanbao] reconnecting in {}s (attempt {}/{})", + delay, attempt, max_attempts + ); + tokio::select! { + _ = time::sleep(Duration::from_secs(delay)) => {} + _ = shutdown.changed() => { + info!("[yuanbao] shutdown received during backoff"); + self.shutdown().await; + return; + } + } + } + } + + async fn connect_once( + &self, + shutdown: &mut watch::Receiver, + ) -> Result, YuanbaoError> { + // Resolve token (may hit the sign endpoint). 
+ let (token, bot_id, source) = self.resolve_token().await?; + if !bot_id.is_empty() { + self.update_account(|a| { + if a.uid.is_empty() { + a.uid = bot_id.clone(); + } + }); + } + + let url = &self.config.ws_domain; + info!("[yuanbao] connecting to {}", url); + let (ws_stream, _resp) = connect_async(url) + .await + .map_err(|e| YuanbaoError::WebSocket(e.to_string()))?; + + let (sender, mut receiver) = ws_stream.split(); + *self.sender.lock().await = Some(sender); + info!("[yuanbao] WebSocket connected — sending auth-bind"); + + self.set_state(ConnectionState::Authenticating); + self.send_auth_bind(&token, &bot_id, &source).await?; + + // Wait for auth-bind response. + let auth_timeout = Duration::from_secs(AUTH_TIMEOUT_SECS); + let auth_msg = tokio::time::timeout(auth_timeout, receiver.next()) + .await + .map_err(|_| YuanbaoError::AuthTimeout)? + .ok_or_else(|| YuanbaoError::WebSocket("closed during auth-bind".into()))? + .map_err(|e| YuanbaoError::WebSocket(e.to_string()))?; + + self.handle_auth_response(&auth_msg)?; + self.set_state(ConnectionState::Connected); + info!("[yuanbao] auth-bind successful, entering read loop"); + + let ping_secs = if self.config.heartbeat_interval_secs > 0 { + self.config.heartbeat_interval_secs + } else { + PING_INTERVAL_SECS + }; + let mut ping_interval = time::interval(Duration::from_secs(ping_secs)); + ping_interval.tick().await; // skip first tick + + let mut close_code: Option = None; + let mut consecutive_ping_failures: u32 = 0; + + loop { + tokio::select! 
{ + _ = shutdown.changed() => { + info!("[yuanbao] shutdown received in read loop"); + return Ok(None); + } + _ = ping_interval.tick() => { + let msg_id = self.next_msg_id("ping"); + let frame = encode_ping(&msg_id); + if let Err(e) = self.send_frame(frame).await { + warn!("[yuanbao] ping send failed: {}", e); + consecutive_ping_failures += 1; + if consecutive_ping_failures >= HEARTBEAT_TIMEOUT_THRESHOLD { + warn!( + "[yuanbao] {} consecutive ping failures — dropping", + consecutive_ping_failures + ); + break; + } + } else { + consecutive_ping_failures = 0; + } + } + msg = receiver.next() => { + match msg { + Some(Ok(Message::Binary(data))) => self.handle_binary(data).await, + Some(Ok(Message::Close(frame))) => { + close_code = frame.map(|f| u16::from(f.code)); + info!("[yuanbao] received close frame: {:?}", close_code); + break; + } + Some(Ok(Message::Ping(payload))) => { + let mut guard = self.sender.lock().await; + if let Some(s) = guard.as_mut() { + let _ = s.send(Message::Pong(payload)).await; + } + } + Some(Ok(_)) => {} + Some(Err(e)) => { + warn!("[yuanbao] websocket read error: {}", e); + break; + } + None => { + info!("[yuanbao] websocket stream ended"); + break; + } + } + } + } + } + + Ok(close_code) + } + + async fn resolve_token(&self) -> Result<(String, String, String), YuanbaoError> { + let cfg = &self.config; + if !cfg.token.is_empty() { + // Pre-signed token: no source returned by the sign endpoint. + // Mirrors yuanbao-openclaw-plugin's static-token branch, which + // returns source="bot". 
+ return Ok((cfg.token.clone(), cfg.bot_id.clone(), String::new())); + } + let mgr = self + .sign_manager + .as_ref() + .ok_or_else(|| YuanbaoError::AuthFailed("no token and no SignManager".into()))?; + if cfg.app_secret.is_empty() { + return Err(YuanbaoError::AuthFailed( + "app_secret required to sign".into(), + )); + } + let entry = mgr + .get_token( + &cfg.app_key, + &cfg.app_secret, + &cfg.api_domain, + &cfg.route_env, + ) + .await?; + Ok((entry.token, entry.bot_id, entry.source)) + } + + async fn send_auth_bind( + &self, + token: &str, + bot_id: &str, + source: &str, + ) -> Result<(), YuanbaoError> { + let cfg = &self.config; + let uid = if bot_id.is_empty() { + self.account.lock().uid.clone() + } else { + bot_id.to_string() + }; + let msg_id = format!("auth_{}", Uuid::new_v4()); + // Auth-bind payload aligned with yuanbao-openclaw-plugin: + // biz_id = "ybBot" (server rejects raw app_key with 40011). + // source comes from the sign endpoint response; fall back to + // "bot" when missing (matches the plugin's static-token branch + // and `data.source || "bot"` resolution). 
+ let resolved_source = if source.is_empty() { "bot" } else { source }; + let frame = encode_auth_bind( + "ybBot", + &uid, + resolved_source, + token, + &msg_id, + env!("CARGO_PKG_VERSION"), + std::env::consts::OS, + &cfg.bot_version, + &cfg.route_env, + ); + self.send_frame(frame).await + } + + fn handle_auth_response(&self, msg: &Message) -> Result<(), YuanbaoError> { + let data = match msg { + Message::Binary(b) => b, + _ => { + return Err(YuanbaoError::AuthFailed( + "expected binary auth-bind response".into(), + )) + } + }; + let frame = decode_conn_msg(data)?; + if frame.cmd != cmd::AUTH_BIND { + return Err(YuanbaoError::AuthFailed(format!( + "unexpected cmd in auth response: {:?}", + frame.cmd + ))); + } + if frame.status != 0 { + return Err(YuanbaoError::AuthFailed(format!( + "auth rejected: status={}", + frame.status + ))); + } + // Body carries code/message/connect_id — back-fill the account. + if !frame.data.is_empty() { + let rsp = decode_auth_bind_rsp(&frame.data)?; + if rsp.code != 0 { + return Err(YuanbaoError::AuthFailed(format!( + "auth-bind code={} message={}", + rsp.code, rsp.message + ))); + } + if !rsp.connect_id.is_empty() { + self.update_account(|a| a.connect_id = rsp.connect_id.clone()); + info!("[yuanbao] auth-bind connect_id={}", rsp.connect_id); + } + } + Ok(()) + } + + async fn handle_binary(&self, data: Vec) { + let frame = match decode_conn_msg(&data) { + Ok(f) => f, + Err(e) => { + warn!("[yuanbao] failed to decode binary frame: {}", e); + return; + } + }; + + info!( + "[yuanbao] rx cmd={} module={} cmd_type={} seq={} msg_id={} data_len={}", + frame.cmd, + frame.module, + frame.cmd_type, + frame.seq_no, + frame.msg_id, + frame.data.len() + ); + + // Responses → match against pending requests via msg_id. 
+ if frame.cmd_type == cmd_type::RESPONSE { + if !frame.msg_id.is_empty() { + if let Some(tx) = self.pending.lock().remove(&frame.msg_id) { + let _ = tx.send(frame); + return; + } + } + info!( + "[yuanbao] response with no waiter cmd={} msg_id={}", + frame.cmd, frame.msg_id + ); + return; + } + + // For server-driven pushes, ACK first when the head asks for it. + if frame.cmd_type == cmd_type::PUSH && frame.need_ack { + let ack = encode_push_ack(&frame); + if let Err(e) = self.send_frame(ack).await { + warn!("[yuanbao] failed to send PushAck: {}", e); + } + } + + // Handle conn-level builtin pushes inline. + if frame.cmd == cmd::KICKOUT { + let reason = String::from_utf8_lossy(&frame.data).into_owned(); + warn!("[yuanbao] kickout received: {}", reason); + let _ = self.inbound_tx.send(InboundEvent::Kickout(reason)); + return; + } + if frame.cmd == cmd::UPDATE_META { + return; + } + + if frame.cmd_type != cmd_type::PUSH { + info!( + "[yuanbao] dropping non-push frame cmd_type={} cmd={}", + frame.cmd_type, frame.cmd + ); + return; + } + + info!( + "[yuanbao] push forwarded to listener cmd={} module={} seq={}", + frame.cmd, frame.module, frame.seq_no + ); + if self.inbound_tx.send(InboundEvent::Push(frame)).is_err() { + error!("[yuanbao] inbound channel closed — listener gone"); + } + } +} + +/// Backoff schedule used by `run()`. After the configured table is +/// exhausted we cap at the last entry forever (until the attempt budget +/// trips). Indexing is 1-based so attempt=1 → table[0]. 
+fn backoff_seconds(attempt: u32) -> u64 { + let idx = attempt.saturating_sub(1) as usize; + if idx < RECONNECT_DELAYS.len() { + RECONNECT_DELAYS[idx] + } else { + *RECONNECT_DELAYS.last().unwrap_or(&60) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn backoff_follows_schedule() { + assert_eq!(backoff_seconds(1), 1); + assert_eq!(backoff_seconds(2), 2); + assert_eq!(backoff_seconds(3), 5); + assert_eq!(backoff_seconds(6), 60); + assert_eq!(backoff_seconds(100), 60); + assert_eq!(backoff_seconds(0), 1); + } + + fn cfg() -> YuanbaoConfig { + let mut c = YuanbaoConfig::default(); + c.app_key = "ak".into(); + c.ws_domain = "wss://example".into(); + c.token = "tok".into(); + c.bot_id = "bot1".into(); + c + } + + #[tokio::test] + async fn pending_correlator_times_out() { + let (tx, _rx) = mpsc::unbounded_channel(); + let conn = YuanbaoConnection::new(cfg(), tx, None); + let err = conn + .send_and_wait("missing_id", vec![1, 2, 3], Duration::from_millis(20)) + .await + .unwrap_err(); + // Without a connected socket, send_frame fails first → SendFailed/NotConnected. + assert!(matches!( + err, + YuanbaoError::NotConnected | YuanbaoError::Timeout(_) | YuanbaoError::SendFailed(_) + )); + } + + #[tokio::test] + async fn account_back_fill_picks_up_uid() { + let (tx, _rx) = mpsc::unbounded_channel(); + let conn = YuanbaoConnection::new(cfg(), tx, None); + assert_eq!(conn.account().uid, "bot1"); + conn.update_account(|a| a.connect_id = "cid_xyz".into()); + assert_eq!(conn.account().connect_id, "cid_xyz"); + } + + #[tokio::test] + async fn next_msg_id_is_monotonic_and_prefixed() { + let (tx, _rx) = mpsc::unbounded_channel(); + let conn = YuanbaoConnection::new(cfg(), tx, None); + let a = conn.next_msg_id("pfx"); + let b = conn.next_msg_id("pfx"); + assert!(a.starts_with("pfx_")); + assert!(b.starts_with("pfx_")); + // Suffix is monotonically increasing. 
+ let na: u64 = a.strip_prefix("pfx_").unwrap().parse().unwrap(); + let nb: u64 = b.strip_prefix("pfx_").unwrap().parse().unwrap(); + assert!(nb > na); + } + + #[tokio::test] + async fn initial_state_is_disconnected() { + let (tx, _rx) = mpsc::unbounded_channel(); + let conn = YuanbaoConnection::new(cfg(), tx, None); + assert_eq!(conn.state(), ConnectionState::Disconnected); + assert!(!conn.is_connected()); + } + + #[tokio::test] + async fn set_state_connected_flips_is_connected_flag() { + let (tx, _rx) = mpsc::unbounded_channel(); + let conn = YuanbaoConnection::new(cfg(), tx, None); + conn.set_state(ConnectionState::Connected); + assert_eq!(conn.state(), ConnectionState::Connected); + assert!(conn.is_connected()); + conn.set_state(ConnectionState::Reconnecting); + assert_eq!(conn.state(), ConnectionState::Reconnecting); + assert!(!conn.is_connected()); + } + + #[tokio::test] + async fn send_frame_without_socket_returns_not_connected() { + let (tx, _rx) = mpsc::unbounded_channel(); + let conn = YuanbaoConnection::new(cfg(), tx, None); + let err = conn.send_frame(vec![1, 2, 3]).await.unwrap_err(); + assert!(matches!(err, YuanbaoError::NotConnected)); + let err2 = conn.send_conn_msg(vec![4]).await.unwrap_err(); + assert!(matches!(err2, YuanbaoError::NotConnected)); + } + + #[tokio::test] + async fn shutdown_clears_pending_and_sets_disconnected() { + let (tx, _rx) = mpsc::unbounded_channel(); + let conn = YuanbaoConnection::new(cfg(), tx, None); + conn.set_state(ConnectionState::Connected); + // Drop a phantom pending entry then shutdown. 
+ let (phantom_tx, _phantom_rx) = oneshot::channel(); + conn.pending.lock().insert("ghost".into(), phantom_tx); + conn.shutdown().await; + assert_eq!(conn.state(), ConnectionState::Disconnected); + assert!(!conn.is_connected()); + assert!(conn.pending.lock().is_empty()); + } + + #[test] + fn backoff_caps_at_last_entry_for_huge_attempts() { + let last = *RECONNECT_DELAYS.last().unwrap(); + assert_eq!(backoff_seconds(RECONNECT_DELAYS.len() as u32 + 5), last); + } + + #[tokio::test] + async fn resolve_token_uses_static_token_when_present() { + let (tx, _rx) = mpsc::unbounded_channel(); + let conn = YuanbaoConnection::new(cfg(), tx, None); + let (token, bot_id, source) = conn.resolve_token().await.unwrap(); + assert_eq!(token, "tok"); + assert_eq!(bot_id, "bot1"); + assert_eq!(source, ""); + } + + #[tokio::test] + async fn resolve_token_without_token_and_without_sign_manager_errors() { + let (tx, _rx) = mpsc::unbounded_channel(); + let mut c = cfg(); + c.token = String::new(); + let conn = YuanbaoConnection::new(c, tx, None); + match conn.resolve_token().await.unwrap_err() { + YuanbaoError::AuthFailed(m) => assert!(m.contains("no token"), "got {m}"), + other => panic!("unexpected {other:?}"), + } + } + + #[tokio::test] + async fn resolve_token_with_sign_manager_but_no_app_secret_errors() { + let (tx, _rx) = mpsc::unbounded_channel(); + let mut c = cfg(); + c.token = String::new(); + c.app_secret = String::new(); + let mgr = SignManager::new(reqwest::Client::new()); + let conn = YuanbaoConnection::new(c, tx, Some(mgr)); + match conn.resolve_token().await.unwrap_err() { + YuanbaoError::AuthFailed(m) => assert!(m.contains("app_secret"), "got {m}"), + other => panic!("unexpected {other:?}"), + } + } + + #[tokio::test] + async fn send_auth_bind_without_socket_returns_not_connected() { + let (tx, _rx) = mpsc::unbounded_channel(); + let conn = YuanbaoConnection::new(cfg(), tx, None); + let err = conn.send_auth_bind("tok", "bot1", "bot").await.unwrap_err(); + 
assert!(matches!(err, YuanbaoError::NotConnected)); + } + + #[tokio::test] + async fn send_auth_bind_falls_back_to_account_uid_when_bot_id_empty() { + let (tx, _rx) = mpsc::unbounded_channel(); + let conn = YuanbaoConnection::new(cfg(), tx, None); + // bot_id="" → reads from account.uid (which was seeded from cfg.bot_id="bot1") + let err = conn.send_auth_bind("tok", "", "").await.unwrap_err(); + assert!(matches!(err, YuanbaoError::NotConnected)); + // Account uid still in place. + assert_eq!(conn.account().uid, "bot1"); + } + + #[test] + fn handle_auth_response_rejects_non_binary_message() { + let (tx, _rx) = mpsc::unbounded_channel(); + let conn = YuanbaoConnection::new(cfg(), tx, None); + let msg = Message::Text("nope".into()); + match conn.handle_auth_response(&msg).unwrap_err() { + YuanbaoError::AuthFailed(m) => { + assert!(m.contains("binary"), "got {m}") + } + other => panic!("unexpected {other:?}"), + } + } + + #[test] + fn handle_auth_response_rejects_undecodable_binary() { + let (tx, _rx) = mpsc::unbounded_channel(); + let conn = YuanbaoConnection::new(cfg(), tx, None); + // Wholly invalid wire data — decode_conn_msg fails. + let msg = Message::Binary(vec![0xFF, 0xFF, 0xFF, 0xFF]); + let err = conn.handle_auth_response(&msg).unwrap_err(); + // Either Proto decode error or some other surface — must not be Ok. + assert!( + !matches!(err, YuanbaoError::AuthFailed(_) if format!("{err:?}").contains("binary")) + ); + } + + /// Regression guard for the post-`connect_once` shutdown short-circuit: + /// once shutdown is signaled, `run()` must not block on the reconnect + /// backoff. We force connect_once to fail synchronously (invalid WS URL), + /// then signal shutdown — total runtime must be well under the first + /// backoff slot (`backoff_seconds(1) == 1s`). 
+ #[tokio::test] + async fn run_exits_promptly_after_shutdown_signal() { + use std::time::Instant; + let (tx, _rx) = mpsc::unbounded_channel(); + let mut c = cfg(); + // tokio-tungstenite rejects the URL synchronously — connect_once + // returns Err in microseconds, putting `run()` on the post-connect + // cleanup path that the fix targets. + c.ws_domain = "not-a-valid-ws-url".to_string(); + c.max_reconnect_attempts = 100; + let conn = YuanbaoConnection::new(c, tx, None); + let (sd_tx, sd_rx) = watch::channel(false); + + let handle = tokio::spawn(conn.clone().run(sd_rx)); + // Let `run()` enter the loop and attempt connect_once at least once. + time::sleep(Duration::from_millis(20)).await; + + let started = Instant::now(); + sd_tx.send(true).unwrap(); + + // The first reconnect backoff slot is 1s. Without responsive + // shutdown handling, run() would sleep through it before checking + // the flag. 500ms gives us comfortable headroom while staying + // far enough below the backoff to detect a regression. + let res = time::timeout(Duration::from_millis(500), handle).await; + res.expect("run() did not exit within 500ms of shutdown signal") + .expect("run() task panicked"); + assert!( + started.elapsed() < Duration::from_millis(500), + "run() took {:?} to exit after shutdown — backoff was not skipped", + started.elapsed() + ); + } + + #[tokio::test] + async fn handle_binary_with_garbage_does_not_panic() { + let (tx, _rx) = mpsc::unbounded_channel(); + let conn = YuanbaoConnection::new(cfg(), tx, None); + // Should silently log + return — no panic. + conn.handle_binary(vec![0xFF, 0xFF, 0xFF, 0xFF]).await; + } +} diff --git a/src/openhuman/channels/providers/yuanbao/cos.rs b/src/openhuman/channels/providers/yuanbao/cos.rs new file mode 100644 index 0000000000..e7aa40fec0 --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/cos.rs @@ -0,0 +1,567 @@ +//! Tencent COS upload — HMAC-SHA1 signing and `genUploadInfo` flow. +//! +//! 
Split out of `media.rs` to stay under the 500-line per-file ceiling. +//! Reference: . + +use std::time::{SystemTime, UNIX_EPOCH}; + +use hmac::{Hmac, Mac}; +use sha1::{Digest, Sha1}; +use tracing::{debug, info}; + +use super::errors::YuanbaoError; +use super::media::{guess_mime_type, is_image, parse_image_size, ImageDims}; + +const UPLOAD_INFO_PATH: &str = "/api/resource/genUploadInfo"; +const COS_USE_ACCELERATE: bool = true; + +type HmacSha1 = Hmac; + +fn hmac_sha1_hex(key: &[u8], msg: &[u8]) -> String { + let mut mac = HmacSha1::new_from_slice(key).expect("HMAC accepts any key length"); + mac.update(msg); + hex::encode(mac.finalize().into_bytes()) +} + +fn sha1_hex(msg: &[u8]) -> String { + let mut hasher = Sha1::new(); + hasher.update(msg); + hex::encode(hasher.finalize()) +} + +#[derive(Debug, Clone)] +pub struct CosSignInput<'a> { + pub method: &'a str, + /// URL-encoded path with leading `/`. + pub path: &'a str, + pub params: &'a [(&'a str, &'a str)], + pub headers: &'a [(&'a str, &'a str)], + pub secret_id: &'a str, + pub secret_key: &'a str, + pub start_time: u64, + pub expire_seconds: u64, +} + +/// Build the COS `Authorization` header value. +pub fn cos_sign(input: &CosSignInput<'_>) -> String { + let q_sign_time = format!( + "{};{}", + input.start_time, + input.start_time + input.expire_seconds + ); + + // Step 1 — SignKey = HMAC-SHA1(SecretKey, q-sign-time). + let sign_key = hmac_sha1_hex(input.secret_key.as_bytes(), q_sign_time.as_bytes()); + + // Step 2 — HttpString. Names lower-cased, values URL-encoded. 
+ let mut params: Vec<(String, String)> = input + .params + .iter() + .map(|(k, v)| (k.to_ascii_lowercase(), url_encode(v))) + .collect(); + params.sort(); + let mut headers: Vec<(String, String)> = input + .headers + .iter() + .map(|(k, v)| (k.to_ascii_lowercase(), url_encode(v))) + .collect(); + headers.sort(); + + let url_param_list = params + .iter() + .map(|(k, _)| k.as_str()) + .collect::>() + .join(";"); + let url_params = params + .iter() + .map(|(k, v)| format!("{k}={v}")) + .collect::>() + .join("&"); + let header_list = headers + .iter() + .map(|(k, _)| k.as_str()) + .collect::>() + .join(";"); + let header_str = headers + .iter() + .map(|(k, v)| format!("{k}={v}")) + .collect::>() + .join("&"); + + let http_string = format!( + "{}\n{}\n{}\n{}\n", + input.method.to_ascii_lowercase(), + input.path, + url_params, + header_str + ); + + // Step 3 — StringToSign. + let sha1_of_http = sha1_hex(http_string.as_bytes()); + let string_to_sign = format!("sha1\n{q_sign_time}\n{sha1_of_http}\n"); + + // Step 4 — Signature. 
+ let signature = hmac_sha1_hex(sign_key.as_bytes(), string_to_sign.as_bytes()); + + format!( + "q-sign-algorithm=sha1&q-ak={sid}&q-sign-time={t}&q-key-time={t}\ + &q-header-list={hl}&q-url-param-list={pl}&q-signature={sig}", + sid = input.secret_id, + t = q_sign_time, + hl = header_list, + pl = url_param_list, + sig = signature + ) +} + +fn url_encode(s: &str) -> String { + urlencoding::encode(s).into_owned() +} + +fn encode_cos_key(key: &str) -> String { + key.split('/') + .map(|seg| urlencoding::encode(seg).into_owned()) + .collect::>() + .join("/") +} + +#[derive(Debug, Clone, Default)] +pub struct CosCredentials { + pub bucket: String, + pub region: String, + pub location: String, + pub secret_id: String, + pub secret_key: String, + pub session_token: String, + pub start_time: u64, + pub expired_time: u64, + pub resource_url: String, +} + +#[derive(Debug, Clone)] +pub struct UploadResult { + pub url: String, + pub uuid: String, + pub size: u64, + pub width: u32, + pub height: u32, +} + +/// Fetch COS upload credentials from the yuanbao gateway. 
+pub async fn get_cos_credentials( + http: &reqwest::Client, + api_domain: &str, + app_key: &str, + bot_id: &str, + token: &str, + route_env: &str, + filename: &str, +) -> Result { + let upload_url = format!( + "{}/{}", + api_domain.trim_end_matches('/'), + UPLOAD_INFO_PATH.trim_start_matches('/') + ); + let body = serde_json::json!({ + "fileName": filename, + "fileId": uuid::Uuid::new_v4().simple().to_string(), + "docFrom": "localDoc", + "docOpenId": "", + }); + let mut req = http + .post(&upload_url) + .header("Content-Type", "application/json") + .header("X-Token", token) + .header("X-ID", if bot_id.is_empty() { app_key } else { bot_id }) + .header("X-Source", "web"); + if !route_env.is_empty() { + req = req.header("X-Route-Env", route_env); + } + let resp = req + .json(&body) + .send() + .await + .map_err(|e| YuanbaoError::Connection(format!("genUploadInfo: {e}")))?; + if !resp.status().is_success() { + return Err(YuanbaoError::Media(format!( + "genUploadInfo HTTP {}", + resp.status() + ))); + } + let payload: serde_json::Value = resp + .json() + .await + .map_err(|e| YuanbaoError::Media(format!("genUploadInfo body parse: {e}")))?; + if let Some(code) = payload.get("code").and_then(|c| c.as_i64()) { + if code != 0 { + return Err(YuanbaoError::Media(format!( + "genUploadInfo code={code}, msg={}", + payload.get("msg").and_then(|m| m.as_str()).unwrap_or("") + ))); + } + } + let data = payload.get("data").unwrap_or(&payload); + let get_str = |k: &str| -> String { + data.get(k) + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string() + }; + let get_u64 = |k: &str| -> u64 { data.get(k).and_then(|v| v.as_u64()).unwrap_or(0) }; + + Ok(CosCredentials { + bucket: get_str("bucketName"), + region: get_str("region"), + location: get_str("location"), + secret_id: get_str("encryptTmpSecretId"), + secret_key: get_str("encryptTmpSecretKey"), + session_token: get_str("encryptToken"), + start_time: get_u64("startTime"), + expired_time: get_u64("expiredTime"), + resource_url: 
get_str("resourceUrl"), + }) +} + +/// PUT a file to COS using credentials returned by `get_cos_credentials`. +pub async fn upload_to_cos( + http: &reqwest::Client, + creds: &CosCredentials, + data: &[u8], + filename: &str, + mut content_type: String, +) -> Result { + if creds.secret_id.is_empty() || creds.secret_key.is_empty() || creds.location.is_empty() { + return Err(YuanbaoError::Media( + "COS credentials missing secret_id / secret_key / location".into(), + )); + } + if content_type.is_empty() || content_type == "application/octet-stream" { + content_type = if is_image(filename, "") { + guess_mime_type(filename).to_string() + } else { + "application/octet-stream".into() + }; + } + + let cos_host = if COS_USE_ACCELERATE { + format!("{}.cos.accelerate.myqcloud.com", creds.bucket) + } else { + format!("{}.cos.{}.myqcloud.com", creds.bucket, creds.region) + }; + let encoded_key = encode_cos_key(&creds.location); + let cos_url = format!("https://{cos_host}/{}", encoded_key.trim_start_matches('/')); + + let now = unix_now(); + let start = if creds.start_time != 0 { + creds.start_time + } else { + now + }; + let expire = if creds.expired_time > now { + creds.expired_time - now + } else { + 3600 + }; + + let headers_for_sign: Vec<(&str, &str)> = vec![ + ("host", cos_host.as_str()), + ("content-type", content_type.as_str()), + ("x-cos-security-token", creds.session_token.as_str()), + ]; + let path = format!("/{}", encoded_key.trim_start_matches('/')); + let sig = cos_sign(&CosSignInput { + method: "put", + path: &path, + params: &[], + headers: &headers_for_sign, + secret_id: &creds.secret_id, + secret_key: &creds.secret_key, + start_time: start, + expire_seconds: expire, + }); + + info!( + "[yuanbao] COS PUT bucket={} key={} size={}", + creds.bucket, + creds.location, + data.len() + ); + let resp = http + .put(&cos_url) + .header("Authorization", sig) + .header("Content-Type", content_type.as_str()) + .header("x-cos-security-token", &creds.session_token) + 
.body(data.to_vec()) + .send() + .await + .map_err(|e| YuanbaoError::Connection(format!("COS PUT: {e}")))?; + if !resp.status().is_success() { + return Err(YuanbaoError::Media(format!( + "COS PUT HTTP {}", + resp.status() + ))); + } + + let dims = if content_type.starts_with("image/") { + parse_image_size(data).unwrap_or(ImageDims { + width: 0, + height: 0, + }) + } else { + ImageDims { + width: 0, + height: 0, + } + }; + + let uuid = { + let mut h = Sha1::new(); + h.update(data); + hex::encode(h.finalize()) + }; + let url = if creds.resource_url.is_empty() { + cos_url + } else { + creds.resource_url.clone() + }; + debug!("[yuanbao] COS upload ok url={url}"); + Ok(UploadResult { + url, + uuid, + size: data.len() as u64, + width: dims.width, + height: dims.height, + }) +} + +fn unix_now() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_secs()) + .unwrap_or(0) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn cos_sign_is_deterministic() { + let s = cos_sign(&CosSignInput { + method: "put", + path: "/test/file.bin", + params: &[], + headers: &[("host", "bucket.cos.example.com")], + secret_id: "AKID", + secret_key: "SK", + start_time: 1_700_000_000, + expire_seconds: 3600, + }); + let s2 = cos_sign(&CosSignInput { + method: "put", + path: "/test/file.bin", + params: &[], + headers: &[("host", "bucket.cos.example.com")], + secret_id: "AKID", + secret_key: "SK", + start_time: 1_700_000_000, + expire_seconds: 3600, + }); + assert_eq!(s, s2); + assert!(s.starts_with("q-sign-algorithm=sha1")); + assert!(s.contains("q-ak=AKID")); + assert!(s.contains("q-sign-time=1700000000;1700003600")); + } + + #[test] + fn cos_sign_changes_with_path() { + let base = CosSignInput { + method: "put", + path: "/a", + params: &[], + headers: &[("host", "h")], + secret_id: "AKID", + secret_key: "SK", + start_time: 1_700_000_000, + expire_seconds: 3600, + }; + let s1 = cos_sign(&base); + let s2 = cos_sign(&CosSignInput { path: "/b", ..base }); + 
assert_ne!(s1, s2); + } + + #[test] + fn cos_sign_lowercases_method_and_includes_url_params() { + let s = cos_sign(&CosSignInput { + method: "PUT", // mixed case → should be lowercased into sig + path: "/k", + params: &[("Foo", "Bar Baz")], // url-encoded value + headers: &[("Host", "h")], + secret_id: "AKID", + secret_key: "SK", + start_time: 1_700_000_000, + expire_seconds: 600, + }); + assert!(s.contains("q-url-param-list=foo")); + // header list also lowercased + assert!(s.contains("q-header-list=host")); + } + + fn ok_credentials_body(bucket: &str, location: &str) -> serde_json::Value { + serde_json::json!({ + "code": 0, + "data": { + "bucketName": bucket, + "region": "ap-shanghai", + "location": location, + "encryptTmpSecretId": "AKID", + "encryptTmpSecretKey": "SECRET", + "encryptToken": "session-token", + "startTime": 1_700_000_000u64, + "expiredTime": 1_700_003_600u64, + "resourceUrl": "https://cdn.example/r", + } + }) + } + + #[tokio::test] + async fn get_cos_credentials_parses_data_block() { + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("POST")) + .and(wiremock::matchers::path(UPLOAD_INFO_PATH)) + .and(wiremock::matchers::header("X-Token", "tok")) + .and(wiremock::matchers::header("X-Source", "web")) + .respond_with( + wiremock::ResponseTemplate::new(200) + .set_body_json(ok_credentials_body("bkt-1", "k/v/file.png")), + ) + .mount(&server) + .await; + let http = reqwest::Client::new(); + let creds = get_cos_credentials(&http, &server.uri(), "appk", "bot", "tok", "", "file.png") + .await + .unwrap(); + assert_eq!(creds.bucket, "bkt-1"); + assert_eq!(creds.region, "ap-shanghai"); + assert_eq!(creds.location, "k/v/file.png"); + assert_eq!(creds.secret_id, "AKID"); + assert_eq!(creds.secret_key, "SECRET"); + assert_eq!(creds.session_token, "session-token"); + assert_eq!(creds.resource_url, "https://cdn.example/r"); + assert_eq!(creds.start_time, 1_700_000_000); + assert_eq!(creds.expired_time, 
1_700_003_600); + } + + #[tokio::test] + async fn get_cos_credentials_falls_back_to_app_key_for_xid_when_bot_id_empty() { + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("POST")) + .and(wiremock::matchers::path(UPLOAD_INFO_PATH)) + .and(wiremock::matchers::header("X-ID", "appk")) + .respond_with( + wiremock::ResponseTemplate::new(200) + .set_body_json(ok_credentials_body("bkt", "loc")), + ) + .mount(&server) + .await; + let http = reqwest::Client::new(); + let creds = get_cos_credentials(&http, &server.uri(), "appk", "", "tok", "", "f") + .await + .unwrap(); + assert_eq!(creds.bucket, "bkt"); + } + + #[tokio::test] + async fn get_cos_credentials_sends_route_env_header_when_non_empty() { + let server = wiremock::MockServer::start().await; + // Bind the matcher to both the upload-info path AND the header so + // this test fails if a future refactor routes the call elsewhere + // but happens to still attach `X-Route-Env: canary` somewhere. 
+ wiremock::Mock::given(wiremock::matchers::method("POST")) + .and(wiremock::matchers::path(UPLOAD_INFO_PATH)) + .and(wiremock::matchers::header("X-Route-Env", "canary")) + .respond_with( + wiremock::ResponseTemplate::new(200) + .set_body_json(ok_credentials_body("bkt", "loc")), + ) + .mount(&server) + .await; + let http = reqwest::Client::new(); + get_cos_credentials(&http, &server.uri(), "appk", "bot", "tok", "canary", "f") + .await + .expect("should send canary header"); + } + + #[tokio::test] + async fn get_cos_credentials_surfaces_http_error() { + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("POST")) + .respond_with(wiremock::ResponseTemplate::new(500).set_body_string("boom")) + .mount(&server) + .await; + let http = reqwest::Client::new(); + let err = get_cos_credentials(&http, &server.uri(), "appk", "bot", "tok", "", "f") + .await + .unwrap_err(); + match err { + YuanbaoError::Media(m) => assert!(m.contains("HTTP 500"), "got {m}"), + other => panic!("expected Media error, got {other:?}"), + } + } + + #[tokio::test] + async fn get_cos_credentials_surfaces_non_zero_business_code() { + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("POST")) + .respond_with( + wiremock::ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "code": 4001, + "msg": "quota", + })), + ) + .mount(&server) + .await; + let http = reqwest::Client::new(); + let err = get_cos_credentials(&http, &server.uri(), "appk", "bot", "tok", "", "f") + .await + .unwrap_err(); + match err { + YuanbaoError::Media(m) => { + assert!(m.contains("code=4001"), "got {m}"); + assert!(m.contains("quota"), "got {m}"); + } + other => panic!("expected Media error, got {other:?}"), + } + } + + #[tokio::test] + async fn upload_to_cos_rejects_missing_credentials() { + let http = reqwest::Client::new(); + // empty credentials → fail without making any HTTP call + let bad = CosCredentials::default(); 
+ let err = upload_to_cos( + &http, + &bad, + b"data", + "f.bin", + "application/octet-stream".into(), + ) + .await + .unwrap_err(); + match err { + YuanbaoError::Media(m) => assert!(m.contains("credentials missing"), "got {m}"), + other => panic!("expected Media error, got {other:?}"), + } + } + + // NOTE: upload_to_cos always targets `.cos.accelerate.myqcloud.com` + // which we cannot redirect at the reqwest layer without DNS hacks, so we + // only cover the guard branch (missing creds) above. The PUT body itself + // is exercised by integration tests, not unit tests. + + #[test] + fn encode_cos_key_keeps_slashes_but_escapes_segments() { + assert_eq!(encode_cos_key("plain/file.png"), "plain/file.png"); + assert_eq!(encode_cos_key("a b/c d.png"), "a%20b/c%20d.png"); + } +} diff --git a/src/openhuman/channels/providers/yuanbao/errors.rs b/src/openhuman/channels/providers/yuanbao/errors.rs new file mode 100644 index 0000000000..de0d3d2335 --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/errors.rs @@ -0,0 +1,61 @@ +//! Yuanbao channel error types. + +use thiserror::Error; + +/// Close codes from the yuanbao gateway that indicate the connection +/// must **not** be retried (auth failure, kicked off, etc.). +/// +/// Mirrors `NO_RECONNECT_CLOSE_CODES` in hermes-agent `yuanbao.py`. +pub const NO_RECONNECT_CLOSE_CODES: &[u16] = &[4012, 4013, 4014, 4018, 4019, 4021]; + +/// Auth-related response codes that mean "credentials are bad" — surface +/// to the user, don't auto-retry. +pub const AUTH_FAILED_CODES: &[u32] = &[40001, 40002, 40003]; + +/// Auth-related codes that are transient — retry with backoff. 
+pub const AUTH_RETRYABLE_CODES: &[u32] = &[40010, 40011]; + +#[derive(Debug, Error)] +pub enum YuanbaoError { + #[error("protocol encode error: {0}")] + ProtoEncode(String), + + #[error("protocol decode error: {0}")] + ProtoDecode(String), + + #[error("not connected")] + NotConnected, + + #[error("connection closed: code={code}, reason={reason}")] + ConnectionClosed { code: u16, reason: String }, + + #[error("WebSocket error: {0}")] + WebSocket(String), + + #[error("HTTP/connection error: {0}")] + Connection(String), + + #[error("auth-bind failed: {0}")] + AuthFailed(String), + + #[error("auth-bind timeout")] + AuthTimeout, + + #[error("login timeout")] + LoginTimeout, + + #[error("request timeout: {0}")] + Timeout(String), + + #[error("send-message failed: {0}")] + SendFailed(String), + + #[error("media error: {0}")] + Media(String), + + #[error("invalid message: {0}")] + InvalidMessage(String), + + #[error("config error: {0}")] + Config(String), +} diff --git a/src/openhuman/channels/providers/yuanbao/ids.rs b/src/openhuman/channels/providers/yuanbao/ids.rs new file mode 100644 index 0000000000..4375e561b6 --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/ids.rs @@ -0,0 +1,115 @@ +//! Account-id shortening for yuanbao. +//! +//! Yuanbao uids (`from_account`) are 64-char hashes assigned by the platform. +//! The composite `ChannelMessage` thread_id that downstream consumers derive +//! from `sender` and `reply_target` (`channel:yuanbao__`) +//! becomes ~145 chars. After the conversation store hex-encodes that for the +//! per-thread JSONL filename it grows to ~296 chars, exceeding `NAME_MAX` +//! (255 bytes) on ext4/HFS+/APFS/NTFS — writes fail with `ENAMETOOLONG` and +//! channel history is lost. +//! +//! Rather than push the filesystem limit into shared `ConversationStore` code, +//! we shorten yuanbao-specific ids at the channel boundary. Internal yuanbao +//! state (echo guard, access control, owner-command check) keeps the original +//! 
`from_account` — only the value emitted on `ChannelMessage.sender` / +//! `ChannelMessage.reply_target` is shortened. +//! +//! Format: `_`. +//! The 8-char prefix keeps logs roughly groupable for the same user; the +//! sha256 suffix guarantees uniqueness across uids that share a prefix. + +use sha2::{Digest, Sha256}; + +/// Max raw account-id length before the shortening kicks in. +/// +/// Anything shorter is passed through unchanged so short upstream-style ids +/// (e.g. numeric ids, future protocol changes) keep their natural form. +const ACCOUNT_ID_PASSTHROUGH_MAX: usize = 24; + +/// Shorten a yuanbao account id for use in `ChannelMessage.sender` / +/// `ChannelMessage.reply_target`. See module docs for rationale. +pub(super) fn shorten_account_id(uid: &str) -> String { + if uid.len() <= ACCOUNT_ID_PASSTHROUGH_MAX { + return uid.to_string(); + } + let prefix: String = uid.chars().take(8).collect(); + let digest = Sha256::digest(uid.as_bytes()); + format!("{prefix}_{:.16x}", digest) +} + +/// Shorten a yuanbao `reply_target`, preserving the `g:` shape +/// used for group chats. The `g:` discriminator is required by outbound +/// routing (see [`super::types::InboundContext::reply_target`]). 
+pub(super) fn shorten_reply_target(reply_target: &str) -> String { + if let Some(group_code) = reply_target.strip_prefix("g:") { + format!("g:{}", shorten_account_id(group_code)) + } else { + shorten_account_id(reply_target) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn passes_short_ids_through_unchanged() { + assert_eq!(shorten_account_id("123456"), "123456"); + assert_eq!(shorten_account_id(""), ""); + let exactly_max = "a".repeat(ACCOUNT_ID_PASSTHROUGH_MAX); + assert_eq!(shorten_account_id(&exactly_max), exactly_max); + } + + #[test] + fn shortens_long_ids_to_prefix_plus_hash() { + let long_uid = "a".repeat(64); + let shortened = shorten_account_id(&long_uid); + assert_eq!(shortened.len(), 8 + 1 + 16, "8 prefix + '_' + 16 hex"); + assert!(shortened.starts_with("aaaaaaaa_")); + } + + #[test] + fn shortening_is_deterministic_and_collision_resistant() { + let a = "f".repeat(64); + let mut b = a.clone(); + b.replace_range(63..64, "e"); // differ in last char only + let sa = shorten_account_id(&a); + let sb = shorten_account_id(&b); + assert_eq!(sa, shorten_account_id(&a), "deterministic"); + assert_ne!(sa, sb, "different uids hash to different ids"); + } + + #[test] + fn group_reply_target_preserves_g_prefix() { + let short_group = shorten_reply_target("g:short_group"); + assert_eq!(short_group, "g:short_group"); + + let long_code = "a".repeat(64); + let long_group = format!("g:{long_code}"); + let shortened = shorten_reply_target(&long_group); + assert!(shortened.starts_with("g:aaaaaaaa_")); + assert_eq!(shortened.len(), 2 + 8 + 1 + 16); + } + + #[test] + fn dm_reply_target_shortens_like_account_id() { + let uid = "z".repeat(64); + assert_eq!(shorten_reply_target(&uid), shorten_account_id(&uid)); + } + + #[test] + fn shortened_thread_id_fits_under_name_max() { + // Simulate the worst case: long uid for sender + reply_target. 
+ let uid = "f".repeat(64); + let sender = shorten_account_id(&uid); + let reply_target = shorten_account_id(&uid); + let thread_id = format!("channel:yuanbao_{sender}_{reply_target}"); + // hex-encoded filename used by ConversationStore (`.jsonl`). + let hex_name_len = thread_id.len() * 2 + ".jsonl".len(); + // NAME_MAX on common filesystems is 255 bytes. + assert!( + hex_name_len <= 255, + "shortened thread_id hex filename ({hex_name_len} bytes) must fit under NAME_MAX (255)" + ); + } +} diff --git a/src/openhuman/channels/providers/yuanbao/inbound.rs b/src/openhuman/channels/providers/yuanbao/inbound.rs new file mode 100644 index 0000000000..3491a24cc6 --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/inbound.rs @@ -0,0 +1,623 @@ +//! Inbound message pipeline (17 stages). +//! +//! Mirrors `InboundPipeline` in hermes-agent `gateway/platforms/yuanbao.py`. +//! Each stage runs in order; any of them can short-circuit by +//! returning `Skip(reason)`. `Abort(_)` propagates an error. + +use std::collections::VecDeque; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use async_trait::async_trait; +use tokio::sync::RwLock; +use tracing::{debug, trace}; + +use super::config::YuanbaoConfig; +use super::errors::YuanbaoError; +use super::proto::{decode_inbound_json, decode_inbound_push}; +use super::proto_constants::*; +use super::types::*; + +/// Shared per-channel state that survives across messages. 
+pub struct PipelineState { + pub bot_id: String, + pub bot_name: String, + pub owner_id: String, + pub dm_access: AccessPolicy, + pub group_access: AccessPolicy, + pub allowed_users: Vec, + pub allowed_groups: Vec, + pub group_at_required: bool, + pub home_chat: RwLock>, + pub dedup: RwLock, +} + +impl PipelineState { + pub fn new(cfg: &YuanbaoConfig, bot_id: String) -> Arc { + Arc::new(Self { + bot_id, + bot_name: cfg.bot_name.clone(), + owner_id: cfg.owner_id.clone(), + dm_access: AccessPolicy::parse(&cfg.dm_access), + group_access: AccessPolicy::parse(&cfg.group_access), + allowed_users: cfg.allowed_users.clone(), + allowed_groups: cfg.allowed_groups.clone(), + group_at_required: cfg.group_at_required, + home_chat: RwLock::new(None), + dedup: RwLock::new(DedupCache::new(DEDUP_CAPACITY, DEDUP_TTL_SECS)), + }) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum AccessPolicy { + Open, + Allowlist, + Closed, +} + +impl AccessPolicy { + fn parse(s: &str) -> Self { + match s.to_ascii_lowercase().as_str() { + "open" => Self::Open, + "closed" | "disabled" | "none" => Self::Closed, + _ => Self::Allowlist, + } + } +} + +/// Mutable context passed through every inbound stage. +#[derive(Debug, Clone)] +pub struct PipelineCtx { + pub msg: InboundMessage, + pub source: Source, + pub text: String, + pub image_urls: Vec, + pub is_at_bot: bool, + pub is_owner_command: bool, + pub kind: MessageKind, +} + +/// Outcome of a single inbound stage invocation. +#[derive(Debug)] +pub enum MwResult { + Continue, + Skip(&'static str), + Abort(YuanbaoError), +} + +/// Final outcome of the whole pipeline. +#[derive(Debug)] +pub enum PipelineOutcome { + Dispatch(PipelineCtx), + Filtered(&'static str), + Failed(YuanbaoError), +} + +#[async_trait] +pub trait Middleware: Send + Sync { + fn name(&self) -> &'static str; + async fn process(&self, state: &PipelineState, ctx: &mut PipelineCtx) -> MwResult; +} + +/// LRU-like dedup cache with TTL. 
+pub struct DedupCache { + capacity: usize, + ttl: Duration, + order: VecDeque<(String, Instant)>, + index: std::collections::HashSet, +} + +impl DedupCache { + pub fn new(capacity: usize, ttl_secs: u64) -> Self { + Self { + capacity, + ttl: Duration::from_secs(ttl_secs), + order: VecDeque::with_capacity(capacity), + index: std::collections::HashSet::with_capacity(capacity), + } + } + + /// Returns `true` if `id` has been seen within the TTL window. Inserts it otherwise. + pub fn check_and_insert(&mut self, id: &str) -> bool { + self.evict_expired(); + if self.index.contains(id) { + return true; + } + if self.order.len() >= self.capacity { + if let Some((old, _)) = self.order.pop_front() { + self.index.remove(&old); + } + } + self.order.push_back((id.to_string(), Instant::now())); + self.index.insert(id.to_string()); + false + } + + fn evict_expired(&mut self) { + let now = Instant::now(); + while let Some((_, ts)) = self.order.front() { + if now.duration_since(*ts) > self.ttl { + if let Some((old, _)) = self.order.pop_front() { + self.index.remove(&old); + } + } else { + break; + } + } + } +} + +// ───── Individual inbound stages ──────────────────────────────────── + +struct DecodeMw; +struct ExtractFieldsMw; +struct RecallGuardMw; +struct DedupMw; +struct SkipSelfMw; +struct ChatRoutingMw; +struct AccessGuardMw; +struct AutoSetHomeMw; +struct ExtractContentMw; +struct PlaceholderFilterMw; +struct OwnerCommandMw; +struct BuildSourceMw; +struct GroupAtGuardMw; +struct GroupAttributionMw; +struct ClassifyMsgTypeMw; +struct QuoteContextMw; +struct MediaResolveMw; + +#[async_trait] +impl Middleware for DecodeMw { + fn name(&self) -> &'static str { + "decode" + } + async fn process(&self, _s: &PipelineState, _c: &mut PipelineCtx) -> MwResult { + // Decoding happens before we build a PipelineCtx — this MW is a placeholder + // so the stage list still has 17 entries (mirrors hermes-agent). 
+ MwResult::Continue + } +} + +#[async_trait] +impl Middleware for ExtractFieldsMw { + fn name(&self) -> &'static str { + "extract_fields" + } + async fn process(&self, _s: &PipelineState, _c: &mut PipelineCtx) -> MwResult { + MwResult::Continue + } +} + +#[async_trait] +impl Middleware for RecallGuardMw { + fn name(&self) -> &'static str { + "recall_guard" + } + async fn process(&self, _s: &PipelineState, c: &mut PipelineCtx) -> MwResult { + if c.msg.is_recall() { + c.kind = MessageKind::Recall; + return MwResult::Skip("recall_guard"); + } + MwResult::Continue + } +} + +#[async_trait] +impl Middleware for DedupMw { + fn name(&self) -> &'static str { + "dedup" + } + async fn process(&self, s: &PipelineState, c: &mut PipelineCtx) -> MwResult { + if c.msg.msg_id.is_empty() { + return MwResult::Continue; // nothing to dedup on + } + let mut cache = s.dedup.write().await; + if cache.check_and_insert(&c.msg.msg_id) { + return MwResult::Skip("dedup"); + } + MwResult::Continue + } +} + +#[async_trait] +impl Middleware for SkipSelfMw { + fn name(&self) -> &'static str { + "skip_self" + } + async fn process(&self, s: &PipelineState, c: &mut PipelineCtx) -> MwResult { + if !s.bot_id.is_empty() && c.msg.from_account == s.bot_id { + return MwResult::Skip("skip_self"); + } + MwResult::Continue + } +} + +#[async_trait] +impl Middleware for ChatRoutingMw { + fn name(&self) -> &'static str { + "chat_routing" + } + async fn process(&self, _s: &PipelineState, c: &mut PipelineCtx) -> MwResult { + c.source.is_group = c.msg.is_group(); + c.source.group_code = c.msg.group_code.clone(); + c.source.from_account = c.msg.from_account.clone(); + c.source.sender_nickname = c.msg.sender_nickname.clone(); + MwResult::Continue + } +} + +#[async_trait] +impl Middleware for AccessGuardMw { + fn name(&self) -> &'static str { + "access_guard" + } + async fn process(&self, s: &PipelineState, c: &mut PipelineCtx) -> MwResult { + let (policy, allow_list, key) = if c.source.is_group { + (s.group_access, 
&s.allowed_groups, &c.source.group_code) + } else { + (s.dm_access, &s.allowed_users, &c.source.from_account) + }; + let pass = match policy { + AccessPolicy::Open => true, + AccessPolicy::Closed => false, + AccessPolicy::Allowlist => allow_list.iter().any(|u| u == "*" || u == key), + }; + if !pass { + return MwResult::Skip("access_guard"); + } + MwResult::Continue + } +} + +#[async_trait] +impl Middleware for AutoSetHomeMw { + fn name(&self) -> &'static str { + "auto_set_home" + } + async fn process(&self, s: &PipelineState, c: &mut PipelineCtx) -> MwResult { + if !c.source.is_group { + let mut home = s.home_chat.write().await; + if home.is_none() { + *home = Some(c.source.reply_target()); + } + } + MwResult::Continue + } +} + +#[async_trait] +impl Middleware for ExtractContentMw { + fn name(&self) -> &'static str { + "extract_content" + } + async fn process(&self, _s: &PipelineState, c: &mut PipelineCtx) -> MwResult { + c.text = c.msg.extract_text(); + c.image_urls = c.msg.extract_image_urls(); + MwResult::Continue + } +} + +#[async_trait] +impl Middleware for PlaceholderFilterMw { + fn name(&self) -> &'static str { + "placeholder_filter" + } + async fn process(&self, _s: &PipelineState, c: &mut PipelineCtx) -> MwResult { + let trimmed = c.text.trim(); + let is_placeholder = trimmed == "[image]" || trimmed == "[file]" || trimmed == "[图片]"; + if (trimmed.is_empty() || is_placeholder) && c.image_urls.is_empty() { + return MwResult::Skip("placeholder_filter"); + } + MwResult::Continue + } +} + +#[async_trait] +impl Middleware for OwnerCommandMw { + fn name(&self) -> &'static str { + "owner_command" + } + async fn process(&self, s: &PipelineState, c: &mut PipelineCtx) -> MwResult { + if !s.owner_id.is_empty() + && c.msg.from_account == s.owner_id + && c.text.trim_start().starts_with('/') + { + c.is_owner_command = true; + } + MwResult::Continue + } +} + +#[async_trait] +impl Middleware for BuildSourceMw { + fn name(&self) -> &'static str { + "build_source" + } + 
async fn process(&self, _s: &PipelineState, _c: &mut PipelineCtx) -> MwResult { + // Source already populated by ChatRoutingMw. + MwResult::Continue + } +} + +#[async_trait] +impl Middleware for GroupAtGuardMw { + fn name(&self) -> &'static str { + "group_at_guard" + } + async fn process(&self, s: &PipelineState, c: &mut PipelineCtx) -> MwResult { + if !c.source.is_group || !s.group_at_required { + return MwResult::Continue; + } + let by_name = !s.bot_name.is_empty() && c.text.contains(&format!("@{}", s.bot_name)); + let by_id = !s.bot_id.is_empty() && c.text.contains(&format!("@{}", s.bot_id)); + let by_mention = + !s.bot_id.is_empty() && c.text.contains(&format!("[at|userId:{}]", s.bot_id)); + c.is_at_bot = by_name || by_id || by_mention; + if !c.is_at_bot && !c.is_owner_command { + return MwResult::Skip("group_at_guard"); + } + MwResult::Continue + } +} + +#[async_trait] +impl Middleware for GroupAttributionMw { + fn name(&self) -> &'static str { + "group_attribution" + } + async fn process(&self, s: &PipelineState, c: &mut PipelineCtx) -> MwResult { + // Strip `@bot` from text and the TIM `[at|userId:…]` markup. 
+ if c.source.is_group && c.is_at_bot { + if !s.bot_name.is_empty() { + c.text = c.text.replace(&format!("@{}", s.bot_name), ""); + } + if !s.bot_id.is_empty() { + c.text = c.text.replace(&format!("@{}", s.bot_id), ""); + c.text = c.text.replace(&format!("[at|userId:{}]", s.bot_id), ""); + } + c.text = c.text.trim().to_string(); + } + MwResult::Continue + } +} + +#[async_trait] +impl Middleware for ClassifyMsgTypeMw { + fn name(&self) -> &'static str { + "classify_msg_type" + } + async fn process(&self, _s: &PipelineState, c: &mut PipelineCtx) -> MwResult { + let has_text = !c.text.is_empty(); + let has_image = !c.image_urls.is_empty(); + let has_file = c.msg.msg_body.iter().any(|el| el.msg_type == tim::FILE); + let has_sound = c.msg.msg_body.iter().any(|el| el.msg_type == tim::SOUND); + c.kind = match (has_text, has_image, has_file, has_sound) { + (_, true, _, _) if has_text => MessageKind::Mixed, + (_, true, _, _) => MessageKind::Image, + (_, _, true, _) => MessageKind::File, + (_, _, _, true) => MessageKind::Voice, + _ => MessageKind::Text, + }; + MwResult::Continue + } +} + +#[async_trait] +impl Middleware for QuoteContextMw { + fn name(&self) -> &'static str { + "quote_context" + } + async fn process(&self, _s: &PipelineState, c: &mut PipelineCtx) -> MwResult { + // The cloud_custom_data field carries a JSON quote envelope; for + // now we just leave the raw payload accessible via `msg.cloud_custom_data` + // for downstream tools. Full parsing is intentionally deferred — + // hermes-agent does it lazily too. + let _ = c; + MwResult::Continue + } +} + +#[async_trait] +impl Middleware for MediaResolveMw { + fn name(&self) -> &'static str { + "media_resolve" + } + async fn process(&self, _s: &PipelineState, _c: &mut PipelineCtx) -> MwResult { + // ybres:// resource URLs would be resolved here. Currently URLs + // arrive pre-resolved from the server; expand later if needed. + MwResult::Continue + } +} + +/// Composite pipeline = ordered Vec of inbound stages. 
+pub struct InboundPipeline { + state: Arc, + stages: Vec>, +} + +impl InboundPipeline { + pub fn new(state: Arc) -> Self { + let stages: Vec> = vec![ + Box::new(DecodeMw), + Box::new(ExtractFieldsMw), + Box::new(RecallGuardMw), + Box::new(DedupMw), + Box::new(SkipSelfMw), + Box::new(ChatRoutingMw), + Box::new(AccessGuardMw), + Box::new(AutoSetHomeMw), + Box::new(ExtractContentMw), + Box::new(PlaceholderFilterMw), + Box::new(OwnerCommandMw), + Box::new(BuildSourceMw), + Box::new(GroupAtGuardMw), + Box::new(GroupAttributionMw), + Box::new(ClassifyMsgTypeMw), + Box::new(QuoteContextMw), + Box::new(MediaResolveMw), + ]; + Self { state, stages } + } + + /// Decode a biz push body, run it through every stage, return the outcome. + /// + /// The yuanbao gateway may push the biz body as either protobuf + /// (`InboundMessagePush`) or a JSON string with the same field shape + /// (snake_case + `log_ext.trace_id`). We sniff the first non-whitespace + /// byte to pick the decoder — `{` means JSON, anything else is treated + /// as protobuf. Mirrors plugin gateway.ts::wsPushToInboundMessage + /// (l. 288), which tries protobuf first and falls back to JSON. 
+ pub async fn process(&self, biz_body: &[u8]) -> PipelineOutcome { + let is_json = biz_body + .iter() + .find(|b| !b.is_ascii_whitespace()) + .map(|b| *b == b'{') + .unwrap_or(false); + + let msg = if is_json { + match decode_inbound_json(biz_body) { + Ok(m) => m, + Err(e) => return PipelineOutcome::Failed(e), + } + } else { + match decode_inbound_push(biz_body) { + Ok(m) => m, + Err(e) => return PipelineOutcome::Failed(e), + } + }; + let mut ctx = PipelineCtx { + msg, + source: Source::default(), + text: String::new(), + image_urls: Vec::new(), + is_at_bot: false, + is_owner_command: false, + kind: MessageKind::Text, + }; + for stage in &self.stages { + match stage.process(&self.state, &mut ctx).await { + MwResult::Continue => { + trace!("[yuanbao:inbound] {} pass", stage.name()); + } + MwResult::Skip(reason) => { + debug!("[yuanbao:inbound] {} filtered ({})", stage.name(), reason); + return PipelineOutcome::Filtered(reason); + } + MwResult::Abort(err) => return PipelineOutcome::Failed(err), + } + } + PipelineOutcome::Dispatch(ctx) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn cfg(bot_id: &str) -> YuanbaoConfig { + let mut c = YuanbaoConfig::default(); + c.app_key = "ak".into(); + c.ws_domain = "wss://x".into(); + c.token = "tok".into(); + c.bot_id = bot_id.into(); + c.bot_name = "bot".into(); + c.dm_access = "open".into(); + c.group_access = "open".into(); + c + } + + fn ctx_with(msg: InboundMessage) -> PipelineCtx { + PipelineCtx { + msg, + source: Source::default(), + text: String::new(), + image_urls: Vec::new(), + is_at_bot: false, + is_owner_command: false, + kind: MessageKind::Text, + } + } + + #[tokio::test] + async fn dedup_skips_repeat() { + let state = PipelineState::new(&cfg("bot1"), "bot1".into()); + let mw = DedupMw; + let msg = InboundMessage { + msg_id: "m1".into(), + ..Default::default() + }; + let mut c1 = ctx_with(msg.clone()); + assert!(matches!( + mw.process(&state, &mut c1).await, + MwResult::Continue + )); + let mut c2 = 
ctx_with(msg); + assert!(matches!( + mw.process(&state, &mut c2).await, + MwResult::Skip(_) + )); + } + + #[tokio::test] + async fn access_guard_open() { + let state = PipelineState::new(&cfg("bot1"), "bot1".into()); + let mw = AccessGuardMw; + let mut c = ctx_with(InboundMessage { + from_account: "alice".into(), + ..Default::default() + }); + c.source.is_group = false; + c.source.from_account = "alice".into(); + assert!(matches!( + mw.process(&state, &mut c).await, + MwResult::Continue + )); + } + + #[tokio::test] + async fn full_dm_dispatch() { + let mut config = cfg("bot1"); + config.group_at_required = false; + let state = PipelineState::new(&config, "bot1".into()); + let pipeline = InboundPipeline::new(state); + let msg = InboundMessage { + from_account: "alice".into(), + to_account: "bot1".into(), + msg_id: "hi".into(), + msg_body: vec![MsgBodyElement { + msg_type: "TIMTextElem".into(), + msg_content: MsgContent { + text: Some("hello".into()), + ..Default::default() + }, + }], + ..Default::default() + }; + let body = crate::openhuman::channels::providers::yuanbao::proto::encode_msg_body_element( + &msg.msg_body[0], + ); + // Synthesize an InboundMessagePush from scratch: + use crate::openhuman::channels::providers::yuanbao::proto; + let mut buf = Vec::new(); + let mut put_str = |fnum: u32, s: &str, b: &mut Vec| { + proto::encode_varint(((fnum as u64) << 3) | 2, b); + proto::encode_varint(s.len() as u64, b); + b.extend_from_slice(s.as_bytes()); + }; + put_str(2, &msg.from_account, &mut buf); + put_str(3, &msg.to_account, &mut buf); + put_str(12, &msg.msg_id, &mut buf); + proto::encode_varint(((13u64) << 3) | 2, &mut buf); + proto::encode_varint(body.len() as u64, &mut buf); + buf.extend_from_slice(&body); + + let outcome = pipeline.process(&buf).await; + assert!( + matches!(outcome, PipelineOutcome::Dispatch(_)), + "got {:?}", + outcome + ); + } +} diff --git a/src/openhuman/channels/providers/yuanbao/media.rs 
b/src/openhuman/channels/providers/yuanbao/media.rs new file mode 100644 index 0000000000..271be114ba --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/media.rs @@ -0,0 +1,619 @@ +//! Media helpers — MIME mapping, byte-level image dimension parsing, +//! download with size cap, and TIM `msg_body` builders. +//! +//! Tencent COS upload lives in [`super::cos`] to keep both files under +//! the 500-line ceiling. + +use super::errors::YuanbaoError; +use super::types::MsgBodyElement; + +// ─── MIME / image-format mapping ─────────────────────────────────── + +pub fn guess_mime_type(filename: &str) -> &'static str { + let ext = filename + .rsplit_once('.') + .map(|(_, e)| e.to_ascii_lowercase()) + .unwrap_or_default(); + match ext.as_str() { + "jpg" | "jpeg" => "image/jpeg", + "png" => "image/png", + "gif" => "image/gif", + "webp" => "image/webp", + "bmp" => "image/bmp", + "heic" => "image/heic", + "tiff" => "image/tiff", + "ico" => "image/x-icon", + "pdf" => "application/pdf", + "doc" => "application/msword", + "docx" => "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "xls" => "application/vnd.ms-excel", + "xlsx" => "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + "ppt" => "application/vnd.ms-powerpoint", + "pptx" => "application/vnd.openxmlformats-officedocument.presentationml.presentation", + "txt" => "text/plain", + "zip" => "application/zip", + "tar" => "application/x-tar", + "gz" => "application/gzip", + "mp3" => "audio/mpeg", + "mp4" => "video/mp4", + "wav" => "audio/wav", + "ogg" => "audio/ogg", + "webm" => "video/webm", + _ => "application/octet-stream", + } +} + +pub fn is_image(filename: &str, mime_type: &str) -> bool { + if mime_type.starts_with("image/") { + return true; + } + guess_mime_type(filename).starts_with("image/") +} + +/// Map a MIME type to the TIM `image_format` enum. 
+pub fn image_format_code(mime: &str) -> u32 { + match mime { + "image/jpeg" | "image/jpg" => 1, + "image/gif" => 2, + "image/png" => 3, + "image/bmp" => 4, + "image/webp" | "image/heic" | "image/tiff" => 255, + _ => 255, + } +} + +// ─── Pure-bytes image dimension parsing ───────────────────────────── + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ImageDims { + pub width: u32, + pub height: u32, +} + +pub fn parse_image_size(data: &[u8]) -> Option { + parse_png(data) + .or_else(|| parse_jpeg(data)) + .or_else(|| parse_gif(data)) + .or_else(|| parse_webp(data)) +} + +fn parse_png(buf: &[u8]) -> Option { + if buf.len() < 24 || &buf[..4] != b"\x89PNG" { + return None; + } + let w = u32::from_be_bytes(buf[16..20].try_into().ok()?); + let h = u32::from_be_bytes(buf[20..24].try_into().ok()?); + Some(ImageDims { + width: w, + height: h, + }) +} + +fn parse_jpeg(buf: &[u8]) -> Option { + if buf.len() < 4 || buf[0] != 0xFF || buf[1] != 0xD8 { + return None; + } + let mut i = 2usize; + while i + 9 < buf.len() { + if buf[i] != 0xFF { + i += 1; + continue; + } + let marker = buf[i + 1]; + if marker == 0xC0 || marker == 0xC2 { + let h = u16::from_be_bytes(buf[i + 5..i + 7].try_into().ok()?); + let w = u16::from_be_bytes(buf[i + 7..i + 9].try_into().ok()?); + return Some(ImageDims { + width: w as u32, + height: h as u32, + }); + } + if i + 3 >= buf.len() { + break; + } + let seg_len = u16::from_be_bytes(buf[i + 2..i + 4].try_into().ok()?) 
as usize; + i += 2 + seg_len; + } + None +} + +fn parse_gif(buf: &[u8]) -> Option { + if buf.len() < 10 { + return None; + } + let sig = &buf[..6]; + if sig != b"GIF87a" && sig != b"GIF89a" { + return None; + } + let w = u16::from_le_bytes(buf[6..8].try_into().ok()?); + let h = u16::from_le_bytes(buf[8..10].try_into().ok()?); + Some(ImageDims { + width: w as u32, + height: h as u32, + }) +} + +fn parse_webp(buf: &[u8]) -> Option { + if buf.len() < 16 || &buf[..4] != b"RIFF" || &buf[8..12] != b"WEBP" { + return None; + } + let chunk = &buf[12..16]; + if chunk == b"VP8 " { + if buf.len() >= 30 && buf[23] == 0x9D && buf[24] == 0x01 && buf[25] == 0x2A { + let w = u16::from_le_bytes(buf[26..28].try_into().ok()?) & 0x3FFF; + let h = u16::from_le_bytes(buf[28..30].try_into().ok()?) & 0x3FFF; + return Some(ImageDims { + width: w as u32, + height: h as u32, + }); + } + } else if chunk == b"VP8L" { + if buf.len() >= 25 && buf[20] == 0x2F { + let bits = u32::from_le_bytes(buf[21..25].try_into().ok()?); + let w = (bits & 0x3FFF) + 1; + let h = ((bits >> 14) & 0x3FFF) + 1; + return Some(ImageDims { + width: w, + height: h, + }); + } + } else if chunk == b"VP8X" && buf.len() >= 30 { + let w = (buf[24] as u32 | ((buf[25] as u32) << 8) | ((buf[26] as u32) << 16)) + 1; + let h = (buf[27] as u32 | ((buf[28] as u32) << 8) | ((buf[29] as u32) << 16)) + 1; + return Some(ImageDims { + width: w, + height: h, + }); + } + None +} + +// ─── HTTP download with size cap ──────────────────────────────────── + +pub async fn download_url( + http: &reqwest::Client, + url: &str, + max_size_mb: u64, +) -> Result<(Vec, String), YuanbaoError> { + let limit = max_size_mb.saturating_mul(1024 * 1024); + + if let Ok(head) = http.head(url).send().await { + if let Some(len) = head + .headers() + .get(reqwest::header::CONTENT_LENGTH) + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse::().ok()) + { + if len > limit { + return Err(YuanbaoError::Media(format!( + "remote file too large: {len} > limit 
{limit}" + ))); + } + } + } + + let resp = http + .get(url) + .send() + .await + .map_err(|e| YuanbaoError::Connection(format!("download {url}: {e}")))?; + if !resp.status().is_success() { + return Err(YuanbaoError::Media(format!( + "download HTTP {} for {url}", + resp.status() + ))); + } + let ct = resp + .headers() + .get(reqwest::header::CONTENT_TYPE) + .and_then(|v| v.to_str().ok()) + .unwrap_or("") + .split(';') + .next() + .unwrap_or("") + .trim() + .to_string(); + + let bytes = resp + .bytes() + .await + .map_err(|e| YuanbaoError::Media(format!("read body: {e}")))?; + if bytes.len() as u64 > limit { + return Err(YuanbaoError::Media(format!( + "downloaded file exceeds limit: {} > {}", + bytes.len(), + limit + ))); + } + Ok((bytes.to_vec(), ct)) +} + +// ─── TIM msg_body builders ────────────────────────────────────────── + +/// Build a TIM `TIMImageElem` `msg_body` ready to send. +pub fn build_image_msg_body( + url: &str, + uuid: Option<&str>, + filename: Option<&str>, + size: u32, + width: u32, + height: u32, + mime_type: &str, +) -> Vec { + use super::types::{ImageInfo, MsgContent}; + let uuid_str = uuid + .map(|s| s.to_string()) + .or_else(|| filename.map(|s| s.to_string())) + .unwrap_or_else(|| "image".to_string()); + let format = if mime_type.is_empty() { + 255 + } else { + image_format_code(mime_type) + }; + vec![MsgBodyElement { + msg_type: "TIMImageElem".into(), + msg_content: MsgContent { + uuid: Some(uuid_str), + image_format: Some(format), + image_info_array: vec![ImageInfo { + image_type: 1, + size, + width, + height, + url: url.to_string(), + }], + ..Default::default() + }, + }] +} + +/// Build a TIM `TIMFileElem` `msg_body` ready to send. 
+pub fn build_file_msg_body( + url: &str, + filename: &str, + uuid: Option<&str>, + size: u32, +) -> Vec { + use super::types::MsgContent; + let uuid_str = uuid + .map(|s| s.to_string()) + .unwrap_or_else(|| filename.to_string()); + vec![MsgBodyElement { + msg_type: "TIMFileElem".into(), + msg_content: MsgContent { + uuid: Some(uuid_str), + file_name: Some(filename.to_string()), + file_size: Some(size), + url: Some(url.to_string()), + ..Default::default() + }, + }] +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn png_dims_parse() { + let png = [ + 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, 0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, + 0x44, 0x52, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x06, + ]; + let d = parse_image_size(&png).expect("png parse"); + assert_eq!(d.width, 1); + assert_eq!(d.height, 1); + } + + #[test] + fn gif_dims_parse() { + let gif = b"GIF89a\x40\x01\xF0\x00rest"; + let d = parse_image_size(gif).expect("gif parse"); + assert_eq!(d.width, 320); + assert_eq!(d.height, 240); + } + + #[test] + fn guess_mime_basic() { + assert_eq!(guess_mime_type("foo.png"), "image/png"); + assert_eq!(guess_mime_type("doc.pdf"), "application/pdf"); + assert_eq!(guess_mime_type("blob"), "application/octet-stream"); + } + + #[test] + fn is_image_works() { + assert!(is_image("a.jpeg", "")); + assert!(is_image("noext", "image/png")); + assert!(!is_image("a.pdf", "")); + } + + #[test] + fn image_format_code_matrix() { + assert_eq!(image_format_code("image/png"), 3); + assert_eq!(image_format_code("image/jpeg"), 1); + assert_eq!(image_format_code("image/gif"), 2); + assert_eq!(image_format_code("image/bmp"), 4); + assert_eq!(image_format_code("image/webp"), 255); + assert_eq!(image_format_code("application/pdf"), 255); + } + + // ─── extended MIME / image-format tests ───────────────────── + + #[test] + fn guess_mime_handles_uppercase_extension() { + assert_eq!(guess_mime_type("PHOTO.JPG"), "image/jpeg"); + assert_eq!(guess_mime_type("Doc.PDF"), 
"application/pdf"); + } + + #[test] + fn guess_mime_covers_office_audio_video_archive_types() { + assert_eq!(guess_mime_type("file.jpg"), "image/jpeg"); + assert_eq!(guess_mime_type("file.jpeg"), "image/jpeg"); + assert_eq!(guess_mime_type("file.gif"), "image/gif"); + assert_eq!(guess_mime_type("file.webp"), "image/webp"); + assert_eq!(guess_mime_type("file.bmp"), "image/bmp"); + assert_eq!(guess_mime_type("file.heic"), "image/heic"); + assert_eq!(guess_mime_type("file.tiff"), "image/tiff"); + assert_eq!(guess_mime_type("file.ico"), "image/x-icon"); + assert_eq!(guess_mime_type("file.doc"), "application/msword"); + assert_eq!( + guess_mime_type("file.docx"), + "application/vnd.openxmlformats-officedocument.wordprocessingml.document" + ); + assert_eq!(guess_mime_type("file.xls"), "application/vnd.ms-excel"); + assert_eq!( + guess_mime_type("file.xlsx"), + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" + ); + assert_eq!(guess_mime_type("file.ppt"), "application/vnd.ms-powerpoint"); + assert_eq!( + guess_mime_type("file.pptx"), + "application/vnd.openxmlformats-officedocument.presentationml.presentation" + ); + assert_eq!(guess_mime_type("file.txt"), "text/plain"); + assert_eq!(guess_mime_type("file.zip"), "application/zip"); + assert_eq!(guess_mime_type("file.tar"), "application/x-tar"); + assert_eq!(guess_mime_type("file.gz"), "application/gzip"); + assert_eq!(guess_mime_type("file.mp3"), "audio/mpeg"); + assert_eq!(guess_mime_type("file.mp4"), "video/mp4"); + assert_eq!(guess_mime_type("file.wav"), "audio/wav"); + assert_eq!(guess_mime_type("file.ogg"), "audio/ogg"); + assert_eq!(guess_mime_type("file.webm"), "video/webm"); + } + + #[test] + fn image_format_code_jpg_alias_and_heic_tiff() { + assert_eq!(image_format_code("image/jpg"), 1); + assert_eq!(image_format_code("image/heic"), 255); + assert_eq!(image_format_code("image/tiff"), 255); + assert_eq!(image_format_code(""), 255); + } + + // ─── parse_image_size — JPEG / WEBP / negative paths 
──────── + + #[test] + fn jpeg_dims_from_sof0_marker() { + // SOI + filler APP0 segment + SOF0 marker carrying h=2, w=3. + // parse_jpeg's read uses buf[i+5..i+9] and gates on `i + 9 < buf.len()`, + // so trailing pad bytes are required (one tail byte makes 17 < 18 true). + let mut jpeg = vec![0xFF, 0xD8]; + // APP0 (0xFFE0) with len=4 → 2 bytes payload + jpeg.extend_from_slice(&[0xFF, 0xE0, 0x00, 0x04, 0x00, 0x00]); + // SOF0: 0xFF 0xC0 len(11) precision(8) height(2 BE) width(3 BE) + jpeg.extend_from_slice(&[0xFF, 0xC0, 0x00, 0x0B, 0x08, 0x00, 0x02, 0x00, 0x03]); + // trailing pad so the `i + 9 < buf.len()` loop guard accepts the SOF0 + // entry on the second iteration. + jpeg.push(0xFF); + let d = parse_image_size(&jpeg).expect("jpeg parse"); + assert_eq!(d.width, 3); + assert_eq!(d.height, 2); + } + + #[test] + fn jpeg_too_short_returns_none() { + assert!(parse_image_size(&[0xFF, 0xD8]).is_none()); + } + + #[test] + fn jpeg_wrong_magic_returns_none() { + let buf = [0xCA, 0xFE, 0xBA, 0xBE, 0, 0, 0, 0, 0, 0]; + assert!(parse_image_size(&buf).is_none()); + } + + #[test] + fn webp_vp8x_dims_parse() { + // RIFF size WEBP VP8X flags(4) padding(3) w-1(LE24) h-1(LE24) + // w=320 → 319 little-endian = [0x3F, 0x01, 0x00]; h=240 → 239 = [0xEF, 0x00, 0x00] + let mut buf = b"RIFF\x00\x00\x00\x00WEBPVP8X".to_vec(); + buf.extend_from_slice(&[0u8; 8]); // flags + reserved + buf.extend_from_slice(&[0x3F, 0x01, 0x00, 0xEF, 0x00, 0x00]); + let d = parse_image_size(&buf).expect("webp vp8x parse"); + assert_eq!(d.width, 320); + assert_eq!(d.height, 240); + } + + #[test] + fn webp_too_short_returns_none() { + let buf = b"RIFF\0\0\0\0WEBPVP8"; + assert!(parse_image_size(buf).is_none()); + } + + #[test] + fn webp_unsupported_chunk_returns_none() { + let mut buf = b"RIFF\0\0\0\0WEBPXXXX".to_vec(); + buf.extend_from_slice(&[0u8; 30]); + assert!(parse_image_size(&buf).is_none()); + } + + #[test] + fn png_short_or_wrong_magic_returns_none() { + assert!(parse_image_size(&[0x89, 0x50, 0x4E, 
0x47]).is_none()); // too short + let mut buf = vec![0xFF; 24]; + buf[..4].copy_from_slice(&[0x89, 0x50, 0x4F, 0x47]); // wrong magic + assert!(parse_image_size(&buf).is_none()); + } + + #[test] + fn gif_too_short_or_wrong_sig_returns_none() { + assert!(parse_image_size(b"GIF87").is_none()); + assert!(parse_image_size(b"NOTGIFEXT").is_none()); + } + + #[test] + fn parse_image_size_empty_returns_none() { + assert!(parse_image_size(&[]).is_none()); + } + + // ─── msg_body builders ────────────────────────────────────── + + #[test] + fn build_image_msg_body_uses_uuid_when_present() { + let body = build_image_msg_body( + "https://x/cat.png", + Some("uuid-1"), + Some("cat.png"), + 1024, + 800, + 600, + "image/png", + ); + assert_eq!(body.len(), 1); + let el = &body[0]; + assert_eq!(el.msg_type, "TIMImageElem"); + assert_eq!(el.msg_content.uuid.as_deref(), Some("uuid-1")); + assert_eq!(el.msg_content.image_format, Some(3)); // png + assert_eq!(el.msg_content.image_info_array.len(), 1); + let info = &el.msg_content.image_info_array[0]; + assert_eq!(info.image_type, 1); + assert_eq!(info.size, 1024); + assert_eq!(info.width, 800); + assert_eq!(info.height, 600); + assert_eq!(info.url, "https://x/cat.png"); + } + + #[test] + fn build_image_msg_body_falls_back_to_filename_then_default_uuid() { + let with_filename = build_image_msg_body( + "https://x/", + None, + Some("only-name.png"), + 0, + 0, + 0, + "image/png", + ); + assert_eq!( + with_filename[0].msg_content.uuid.as_deref(), + Some("only-name.png") + ); + + let default_id = build_image_msg_body("https://x/", None, None, 0, 0, 0, "image/png"); + assert_eq!(default_id[0].msg_content.uuid.as_deref(), Some("image")); + } + + #[test] + fn build_image_msg_body_treats_empty_mime_as_format_255() { + let body = build_image_msg_body("https://x/cat.jpg", None, None, 0, 0, 0, ""); + assert_eq!(body[0].msg_content.image_format, Some(255)); + } + + #[test] + fn build_file_msg_body_uses_filename_when_uuid_missing() { + let body = 
build_file_msg_body("https://x/doc.pdf", "doc.pdf", None, 2048); + assert_eq!(body.len(), 1); + let el = &body[0]; + assert_eq!(el.msg_type, "TIMFileElem"); + assert_eq!(el.msg_content.uuid.as_deref(), Some("doc.pdf")); + assert_eq!(el.msg_content.file_name.as_deref(), Some("doc.pdf")); + assert_eq!(el.msg_content.file_size, Some(2048)); + assert_eq!(el.msg_content.url.as_deref(), Some("https://x/doc.pdf")); + } + + #[test] + fn build_file_msg_body_prefers_explicit_uuid() { + let body = build_file_msg_body("https://x/y.pdf", "y.pdf", Some("uuid-y"), 0); + assert_eq!(body[0].msg_content.uuid.as_deref(), Some("uuid-y")); + } + + // ─── download_url (wiremock) ──────────────────────────────── + + #[tokio::test] + async fn download_url_returns_bytes_and_content_type() { + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("HEAD")) + .respond_with(wiremock::ResponseTemplate::new(200).insert_header("Content-Length", "3")) + .mount(&server) + .await; + wiremock::Mock::given(wiremock::matchers::method("GET")) + .respond_with( + wiremock::ResponseTemplate::new(200) + .insert_header("Content-Type", "image/png; charset=binary") + .set_body_bytes(vec![1u8, 2, 3]), + ) + .mount(&server) + .await; + let http = reqwest::Client::new(); + let (bytes, mime) = download_url(&http, &server.uri(), 10).await.unwrap(); + assert_eq!(bytes, vec![1, 2, 3]); + assert_eq!(mime, "image/png"); + } + + #[tokio::test] + async fn download_url_rejects_oversize_from_head_content_length() { + let server = wiremock::MockServer::start().await; + // HEAD reports a very large file → reject BEFORE GET. 
+ wiremock::Mock::given(wiremock::matchers::method("HEAD")) + .respond_with(wiremock::ResponseTemplate::new(200).insert_header( + "Content-Length", + (10u64 * 1024 * 1024 + 1).to_string().as_str(), + )) + .mount(&server) + .await; + let http = reqwest::Client::new(); + let err = download_url(&http, &server.uri(), 10).await.unwrap_err(); + match err { + YuanbaoError::Media(m) => assert!(m.contains("too large"), "got {m}"), + other => panic!("expected Media error, got {other:?}"), + } + } + + #[tokio::test] + async fn download_url_rejects_when_body_exceeds_limit() { + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("HEAD")) + .respond_with(wiremock::ResponseTemplate::new(200)) + .mount(&server) + .await; + wiremock::Mock::given(wiremock::matchers::method("GET")) + .respond_with( + wiremock::ResponseTemplate::new(200).set_body_bytes(vec![0u8; 2 * 1024 * 1024]), // 2 MiB + ) + .mount(&server) + .await; + let http = reqwest::Client::new(); + let err = download_url(&http, &server.uri(), 1).await.unwrap_err(); + match err { + YuanbaoError::Media(m) => assert!(m.contains("exceeds limit"), "got {m}"), + other => panic!("expected Media error, got {other:?}"), + } + } + + #[tokio::test] + async fn download_url_surfaces_http_error_status() { + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("HEAD")) + .respond_with(wiremock::ResponseTemplate::new(404)) + .mount(&server) + .await; + wiremock::Mock::given(wiremock::matchers::method("GET")) + .respond_with(wiremock::ResponseTemplate::new(404)) + .mount(&server) + .await; + let http = reqwest::Client::new(); + let err = download_url(&http, &server.uri(), 10).await.unwrap_err(); + match err { + YuanbaoError::Media(m) => assert!(m.contains("HTTP 404"), "got {m}"), + other => panic!("expected Media error, got {other:?}"), + } + } +} diff --git a/src/openhuman/channels/providers/yuanbao/mod.rs 
b/src/openhuman/channels/providers/yuanbao/mod.rs new file mode 100644 index 0000000000..23e8ad99d0 --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/mod.rs @@ -0,0 +1,29 @@ +//! Yuanbao (元宝) channel provider. +//! +//! This module is intentionally export-focused. Operational code lives in +//! sibling modules: +//! - [`channel`] wires the provider into the generic OpenHuman `Channel` trait. +//! - [`connection`] owns the WebSocket transport and request correlator. +//! - [`inbound`] owns inbound filtering/extraction. +//! - [`outbound`] owns Yuanbao send/query calls. +//! - [`proto`] / [`proto_biz`] / [`wire`] own hand-written protobuf codecs. + +pub mod channel; +pub mod config; +pub mod connection; +pub mod cos; +pub mod errors; +pub mod ids; +pub mod inbound; +pub mod media; +pub mod outbound; +pub mod proto; +pub mod proto_biz; +pub mod proto_constants; +pub mod sign; +pub mod splitter; +pub mod types; +pub mod wire; + +pub use channel::YuanbaoChannel; +pub use config::YuanbaoConfig; diff --git a/src/openhuman/channels/providers/yuanbao/outbound.rs b/src/openhuman/channels/providers/yuanbao/outbound.rs new file mode 100644 index 0000000000..6f7ec1cd85 --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/outbound.rs @@ -0,0 +1,449 @@ +//! Outbound message sender. +//! +//! Translates high-level `send_text` / `send_image` / heartbeat calls +//! into encoded ConnMsg frames and pushes them through the shared +//! `YuanbaoConnection`. The recipient string uses the convention +//! `g:` for groups, raw `` for DMs. +//! +//! For the few request kinds where we care about the response body +//! (notably `QueryGroupInfo`, `GetGroupMemberList`, and `SendXxxMessage`'s +//! `code/msg_id` echo) we use the connection-level pending-acks +//! correlator and return parsed results to the caller. Heartbeats are +//! fire-and-forget — the response is never inspected. 
+ +use std::sync::Arc; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use tracing::debug; + +use super::connection::YuanbaoConnection; +use super::cos::{get_cos_credentials, upload_to_cos}; +use super::errors::YuanbaoError; +use super::media::{build_file_msg_body, build_image_msg_body, download_url, parse_image_size}; +use super::proto_biz::{ + decode_get_group_member_list_rsp, decode_query_group_info_rsp, encode_get_group_member_list, + encode_query_group_info, encode_send_c2c_message, encode_send_group_heartbeat, + encode_send_group_message, encode_send_private_heartbeat, +}; +use super::proto_constants::{ws_heartbeat, DEFAULT_SEND_TIMEOUT_SECS}; +use super::sign::SignManager; +use super::types::{GroupInfo, GroupMemberListPage, MsgBodyElement}; + +const GROUP_PREFIX: &str = "g:"; +/// Wait-for-response timeout on queries like `QueryGroupInfo`. +const QUERY_TIMEOUT_SECS: u64 = 10; + +/// Parsed addressing target. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Target<'a> { + Dm(&'a str), + Group(&'a str), +} + +impl<'a> Target<'a> { + pub fn parse(recipient: &'a str) -> Self { + if let Some(rest) = recipient.strip_prefix(GROUP_PREFIX) { + Self::Group(rest) + } else { + Self::Dm(recipient) + } + } +} + +pub struct OutboundSender { + conn: Arc, + /// Sign-token cache holding the server-issued `bot_id`. Populated as a + /// side effect of `connection`'s sign+auth-bind flow. The `bot_id` here + /// is what `yuanbao_openclaw_proxy` expects in the outbound + /// `from_account` field — config-only fallbacks like `app_key` get + /// silently accepted (status=0) but never routed to a real conv id. + sign_manager: Option>, + /// Lookup key for `sign_manager.cached(app_key)`. + app_key: String, + /// User-supplied bot id override; empty when not set. Only used when + /// the sign cache hasn't been primed yet (e.g. send-before-auth races). 
+ config_bot_id: String, + http: reqwest::Client, +} + +impl OutboundSender { + pub fn new( + conn: Arc, + sign_manager: Option>, + app_key: String, + config_bot_id: String, + ) -> Self { + Self { + conn, + sign_manager, + app_key, + config_bot_id, + http: reqwest::Client::new(), + } + } + + /// Resolve the `from_account` to put on the next outbound frame. + /// Prefers the server-issued `bot_id` cached after sign-token / auth-bind + /// — matches hermes-agent `_bot_id = token_data["bot_id"]` (yuanbao.py:400). + async fn resolve_from_account(&self) -> String { + if let Some(sign) = &self.sign_manager { + if let Some(entry) = sign.cached(&self.app_key).await { + if !entry.bot_id.is_empty() { + return entry.bot_id; + } + } + } + self.config_bot_id.clone() + } + + /// Send a plain-text message. Returns the client-side `msg_id`. + pub async fn send_text( + &self, + recipient: &str, + text: &str, + ref_msg_id: Option<&str>, + ) -> Result { + let body = vec![MsgBodyElement { + msg_type: "TIMTextElem".into(), + msg_content: super::types::MsgContent { + text: Some(text.to_string()), + ..Default::default() + }, + }]; + self.send_body(recipient, body, ref_msg_id).await + } + + /// Send an image by an already-uploaded (COS or other) URL. + #[allow(clippy::too_many_arguments)] + pub async fn send_image_url( + &self, + recipient: &str, + url: &str, + size: u32, + width: u32, + height: u32, + mime_type: &str, + ) -> Result { + let body = build_image_msg_body(url, None, None, size, width, height, mime_type); + self.send_body(recipient, body, None).await + } + + /// End-to-end image send: download from URL → upload to COS → send + /// as a `TIMImageElem`. Returns the outbound `msg_id`. + /// + /// `app_key` / `bot_id` / `token` / `api_domain` / `route_env` come + /// from the channel config; pass them in rather than reaching back + /// through the conn to keep this fn easy to unit-test. 
+ #[allow(clippy::too_many_arguments)] + pub async fn send_image_from_url( + &self, + recipient: &str, + source_url: &str, + app_key: &str, + bot_id: &str, + token: &str, + api_domain: &str, + route_env: &str, + max_size_mb: u64, + ) -> Result { + let (bytes, mime) = download_url(&self.http, source_url, max_size_mb).await?; + let dims = parse_image_size(&bytes); + let width = dims.as_ref().map(|d| d.width).unwrap_or(0); + let height = dims.as_ref().map(|d| d.height).unwrap_or(0); + + let filename = extract_filename(source_url); + let creds = get_cos_credentials( + &self.http, api_domain, app_key, bot_id, token, route_env, &filename, + ) + .await?; + let upload = upload_to_cos(&self.http, &creds, &bytes, &filename, mime.clone()).await?; + + let final_width = if upload.width > 0 { + upload.width + } else { + width + }; + let final_height = if upload.height > 0 { + upload.height + } else { + height + }; + let body = build_image_msg_body( + &upload.url, + Some(&upload.uuid), + Some(&filename), + upload.size as u32, + final_width, + final_height, + &mime, + ); + self.send_body(recipient, body, None).await + } + + /// Send a file by URL. + pub async fn send_file_url( + &self, + recipient: &str, + url: &str, + file_name: &str, + size: u32, + ) -> Result { + let body = build_file_msg_body(url, file_name, None, size); + self.send_body(recipient, body, None).await + } + + /// Send a pre-built `msg_body`. Waits up to `DEFAULT_SEND_TIMEOUT_SECS` + /// for the server response so the caller learns about delivery + /// failures (rate-limit, banned content, etc.) instead of getting a + /// silent drop. 
+ pub async fn send_body( + &self, + recipient: &str, + msg_body: Vec, + ref_msg_id: Option<&str>, + ) -> Result { + let msg_id = self.next_msg_id(); + let target = Target::parse(recipient); + let from_account = self.resolve_from_account().await; + let frame = match target { + Target::Dm(uid) => encode_send_c2c_message( + uid, + &from_account, + &msg_body, + &msg_id, + random_u32(), + "", + "", + ), + Target::Group(group_code) => { + let random = format!("{}", random_u32()); + encode_send_group_message( + group_code, + &from_account, + &msg_body, + &msg_id, + "", + &random, + ref_msg_id.unwrap_or(""), + "", + ) + } + }; + + let timeout = Duration::from_secs(DEFAULT_SEND_TIMEOUT_SECS); + match self.conn.send_and_wait(&msg_id, frame, timeout).await { + Ok(resp) => { + if resp.status != 0 { + return Err(YuanbaoError::SendFailed(format!( + "server status={} cmd={}", + resp.status, resp.cmd + ))); + } + debug!("[outbound] ack msg_id={msg_id} target={:?}", target); + Ok(msg_id) + } + // If the correlator isn't usable yet (NotConnected etc.) bubble up. + Err(e) => Err(e), + } + } + + /// Send a "thinking" heartbeat (RUNNING) — fire-and-forget. + pub async fn start_heartbeat(&self, recipient: &str) -> Result<(), YuanbaoError> { + self.send_heartbeat(recipient, ws_heartbeat::RUNNING).await + } + + /// Send a "done" heartbeat (FINISH) — fire-and-forget. 
+ pub async fn stop_heartbeat(&self, recipient: &str) -> Result<(), YuanbaoError> { + self.send_heartbeat(recipient, ws_heartbeat::FINISH).await + } + + async fn send_heartbeat(&self, recipient: &str, heartbeat: u32) -> Result<(), YuanbaoError> { + let req_id = self.conn.next_msg_id("hb"); + let from_account = self.resolve_from_account().await; + let frame = match Target::parse(recipient) { + Target::Dm(uid) => { + encode_send_private_heartbeat(&req_id, &from_account, uid, heartbeat) + } + Target::Group(group_code) => { + let now_ms = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_millis() as u64) + .unwrap_or(0); + encode_send_group_heartbeat(&req_id, &from_account, group_code, heartbeat, now_ms) + } + }; + // Fire-and-forget — we don't care about the heartbeat ack. + self.conn.send_conn_msg(frame).await + } + + /// Query group info and wait for the server's reply. + pub async fn query_group_info(&self, group_code: &str) -> Result { + let req_id = self.conn.next_msg_id("qgi"); + let frame = encode_query_group_info(&req_id, group_code); + let resp = self + .conn + .send_and_wait(&req_id, frame, Duration::from_secs(QUERY_TIMEOUT_SECS)) + .await?; + decode_query_group_info_rsp(&resp.data) + } + + /// Fetch one page of group members. Use `offset=0, limit=100` for the + /// first page; the response carries `next_offset` for pagination. + pub async fn query_group_members( + &self, + group_code: &str, + offset: u32, + limit: u32, + ) -> Result { + let req_id = self.conn.next_msg_id("qgm"); + let frame = encode_get_group_member_list(&req_id, group_code, offset, limit); + let resp = self + .conn + .send_and_wait(&req_id, frame, Duration::from_secs(QUERY_TIMEOUT_SECS)) + .await?; + decode_get_group_member_list_rsp(&resp.data) + } + + fn next_msg_id(&self) -> String { + // Use a stable prefix so logs can be grepped across send paths. 
+ self.conn.next_msg_id("om") + } +} + +fn random_u32() -> u32 { + rand::random::() +} + +/// Best-effort file name extraction from a URL — uses the URL's path +/// component (so the host is never picked up as a filename) and falls +/// back to "file" if there's nothing usable. +fn extract_filename(url_str: &str) -> String { + if let Ok(parsed) = url::Url::parse(url_str) { + if let Some(segments) = parsed.path_segments() { + if let Some(last) = segments.filter(|s| !s.is_empty()).last() { + return last.to_string(); + } + } + return "file".to_string(); + } + // Non-URL input (relative path, raw filename, etc.) — fall back to + // last non-empty `/`-delimited segment. + let without_query = url_str.split('?').next().unwrap_or(url_str); + let last = without_query + .rsplit('/') + .find(|s| !s.is_empty()) + .unwrap_or(""); + if last.is_empty() { + "file".to_string() + } else { + last.to_string() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn target_parse_dm() { + match Target::parse("user_42") { + Target::Dm(uid) => assert_eq!(uid, "user_42"), + _ => panic!("should be DM"), + } + } + + #[test] + fn target_parse_group() { + match Target::parse("g:room_99") { + Target::Group(c) => assert_eq!(c, "room_99"), + _ => panic!("should be group"), + } + } + + #[test] + fn target_parse_empty_dm() { + assert!(matches!(Target::parse(""), Target::Dm(""))); + } + + #[test] + fn extract_filename_strips_query() { + assert_eq!(extract_filename("https://x.com/a/b/cat.png"), "cat.png"); + assert_eq!( + extract_filename("https://x.com/a/b/cat.png?sig=abc"), + "cat.png" + ); + assert_eq!(extract_filename("https://x.com/"), "file"); + assert_eq!(extract_filename(""), "file"); + } + + #[test] + fn extract_filename_from_bare_path() { + // Not a valid URL → fall back to last non-empty `/`-segment. + assert_eq!(extract_filename("/var/log/foo.bin"), "foo.bin"); + // Trailing slash gets skipped; last non-empty segment wins. 
+ assert_eq!(extract_filename("/var/log/"), "log"); + // Plain filename with no slashes. + assert_eq!(extract_filename("plain.txt"), "plain.txt"); + } + + fn make_conn(cfg: super::super::config::YuanbaoConfig) -> Arc { + let (tx, _rx) = tokio::sync::mpsc::unbounded_channel(); + YuanbaoConnection::new(cfg, tx, None) + } + + fn base_cfg() -> super::super::config::YuanbaoConfig { + let mut c = super::super::config::YuanbaoConfig::default(); + c.app_key = "ak".into(); + c.ws_domain = "wss://x".into(); + c.token = "tok".into(); + c.bot_id = "cfg-bot".into(); + c + } + + #[tokio::test] + async fn resolve_from_account_uses_config_bot_id_when_no_sign_manager() { + let conn = make_conn(base_cfg()); + let sender = OutboundSender::new(conn, None, "ak".into(), "cfg-bot".into()); + assert_eq!(sender.resolve_from_account().await, "cfg-bot"); + } + + #[tokio::test] + async fn resolve_from_account_uses_sign_cache_when_bot_id_present() { + let conn = make_conn(base_cfg()); + let mgr = super::super::sign::SignManager::new(reqwest::Client::new()); + // Seed the cache with a bot_id keyed on the same app_key. + mgr.set_cached_for_test( + "ak", + super::super::sign::TokenEntry { + token: "tok".into(), + bot_id: "server-bot".into(), + product: String::new(), + source: "bot".into(), + expire_ts: u64::MAX / 2, + }, + ) + .await; + let sender = OutboundSender::new(conn, Some(mgr), "ak".into(), "fallback-bot".into()); + // Sign cache hit → use server bot_id, not the fallback. 
+ assert_eq!(sender.resolve_from_account().await, "server-bot"); + } + + #[tokio::test] + async fn resolve_from_account_falls_back_when_sign_cache_bot_id_empty() { + let conn = make_conn(base_cfg()); + let mgr = super::super::sign::SignManager::new(reqwest::Client::new()); + mgr.set_cached_for_test( + "ak", + super::super::sign::TokenEntry { + token: "tok".into(), + bot_id: String::new(), + product: String::new(), + source: String::new(), + expire_ts: u64::MAX / 2, + }, + ) + .await; + let sender = OutboundSender::new(conn, Some(mgr), "ak".into(), "fallback-bot".into()); + assert_eq!(sender.resolve_from_account().await, "fallback-bot"); + } +} diff --git a/src/openhuman/channels/providers/yuanbao/proto.rs b/src/openhuman/channels/providers/yuanbao/proto.rs new file mode 100644 index 0000000000..e678f74227 --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/proto.rs @@ -0,0 +1,914 @@ +//! Yuanbao WebSocket ConnMsg envelope + built-in protocol commands +//! (auth-bind, ping, push-ack) + TIM `MsgBodyElement` codecs. +//! +//! Each WebSocket binary frame carries one full `ConnMsg` protobuf +//! message; **no extra length prefix is needed** (the WS frame boundary +//! delimits one message). Verified against the hermes-agent Python +//! reference (yuanbao_proto.py) and the TypeScript openclaw plugin. +//! +//! Business-layer codecs (send-message / heartbeat / group query) live +//! in [`super::proto_biz`]. Wire-format primitives (varint, FieldValue, +//! parse_fields) live in [`super::wire`]. + +use super::errors::YuanbaoError; +use super::proto_constants::*; +use super::types::*; +use super::wire::{ + encode_field_bytes, encode_field_string, encode_field_varint, get_bytes, get_repeated_bytes, + get_string, get_varint, next_seq_no, parse_fields, FieldValue, +}; + +// Re-export wire primitives for downstream callers (tests, tools). 
+pub use super::wire::{decode_varint, encode_varint}; + +// ─── ConnMsg envelope ────────────────────────────────────────────── +// +// message Head { +// uint32 cmd_type = 1; +// string cmd = 2; +// uint32 seq_no = 3; +// string msg_id = 4; +// string module = 5; +// bool need_ack = 6; +// int32 status = 10; +// } +// message ConnMsg { +// Head head = 1; +// bytes data = 2; +// } + +fn encode_head( + cmd_type: u32, + cmd: &str, + seq_no: u32, + msg_id: &str, + module: &str, + need_ack: bool, + status: u32, +) -> Vec { + let mut buf = Vec::with_capacity(64); + if cmd_type != 0 { + encode_field_varint(1, cmd_type as u64, &mut buf); + } + if !cmd.is_empty() { + encode_field_string(2, cmd, &mut buf); + } + if seq_no != 0 { + encode_field_varint(3, seq_no as u64, &mut buf); + } + if !msg_id.is_empty() { + encode_field_string(4, msg_id, &mut buf); + } + if !module.is_empty() { + encode_field_string(5, module, &mut buf); + } + if need_ack { + encode_field_varint(6, 1, &mut buf); + } + if status != 0 { + encode_field_varint(10, status as u64, &mut buf); + } + buf +} + +/// Encode a full `ConnMsg` frame (ready to send as a binary WS frame). +pub fn encode_conn_msg( + cmd_type: u32, + cmd: &str, + seq_no: u32, + msg_id: &str, + module: &str, + data: &[u8], +) -> Vec { + let head = encode_head(cmd_type, cmd, seq_no, msg_id, module, false, 0); + let mut buf = Vec::with_capacity(head.len() + data.len() + 16); + encode_field_bytes(1, &head, &mut buf); + if !data.is_empty() { + encode_field_bytes(2, data, &mut buf); + } + buf +} + +/// Decode a `ConnMsg` frame received from the gateway. +pub fn decode_conn_msg(data: &[u8]) -> Result { + let fields = parse_fields(data)?; + let head_bytes = get_bytes(&fields, 1); + let payload = get_bytes(&fields, 2); + let head_fields = if head_bytes.is_empty() { + Vec::new() + } else { + parse_fields(&head_bytes)? 
+ }; + Ok(ConnFrame { + cmd_type: get_varint(&head_fields, 1) as u32, + cmd: get_string(&head_fields, 2), + seq_no: get_varint(&head_fields, 3) as u32, + msg_id: get_string(&head_fields, 4), + module: get_string(&head_fields, 5), + need_ack: get_varint(&head_fields, 6) != 0, + status: get_varint(&head_fields, 10) as u32, + data: payload, + }) +} + +// ─── Built-in protocol commands ──────────────────────────────────── + +/// `AuthBindReq` — first request after the WebSocket opens. +#[allow(clippy::too_many_arguments)] +pub fn encode_auth_bind( + biz_id: &str, + uid: &str, + source: &str, + token: &str, + msg_id: &str, + app_version: &str, + operation_system: &str, + bot_version: &str, + route_env: &str, +) -> Vec { + let mut auth_buf = Vec::with_capacity(uid.len() + source.len() + token.len() + 16); + encode_field_string(1, uid, &mut auth_buf); + encode_field_string(2, source, &mut auth_buf); + encode_field_string(3, token, &mut auth_buf); + + let mut dev_buf = Vec::with_capacity(64); + if !app_version.is_empty() { + encode_field_string(1, app_version, &mut dev_buf); + } + if !operation_system.is_empty() { + encode_field_string(2, operation_system, &mut dev_buf); + } + encode_field_string(10, OPENHUMAN_INSTANCE_ID, &mut dev_buf); + if !bot_version.is_empty() { + encode_field_string(24, bot_version, &mut dev_buf); + } + + let mut req_buf = Vec::with_capacity(auth_buf.len() + dev_buf.len() + biz_id.len() + 16); + encode_field_string(1, biz_id, &mut req_buf); + encode_field_bytes(2, &auth_buf, &mut req_buf); + encode_field_bytes(3, &dev_buf, &mut req_buf); + if !route_env.is_empty() { + encode_field_string(5, route_env, &mut req_buf); + } + + encode_conn_msg( + cmd_type::REQUEST, + cmd::AUTH_BIND, + next_seq_no(), + msg_id, + module::CONN_ACCESS, + &req_buf, + ) +} + +pub fn encode_ping(msg_id: &str) -> Vec { + encode_conn_msg( + cmd_type::REQUEST, + cmd::PING, + next_seq_no(), + msg_id, + module::CONN_ACCESS, + &[], + ) +} + +/// Decoded `AuthBindRsp` body. 
+/// +/// message AuthBindRsp { +/// int32 code = 1; +/// string message = 2; +/// string connect_id = 3; +/// } +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct AuthBindRsp { + pub code: i32, + pub message: String, + pub connect_id: String, +} + +/// Parse an `AuthBindRsp` from the biz payload (`ConnMsg.data`). +pub fn decode_auth_bind_rsp(data: &[u8]) -> Result { + let fields = parse_fields(data)?; + Ok(AuthBindRsp { + code: get_varint(&fields, 1) as i32, + message: get_string(&fields, 2), + connect_id: get_string(&fields, 3), + }) +} + +pub fn encode_push_ack(original: &ConnFrame) -> Vec { + encode_conn_msg( + cmd_type::PUSH_ACK, + &original.cmd, + next_seq_no(), + &original.msg_id, + &original.module, + &[], + ) +} + +// ─── MsgBodyElement (TIM) encoding ───────────────────────────────── + +pub fn encode_msg_content(c: &MsgContent) -> Vec { + let mut buf = Vec::with_capacity(64); + if let Some(ref v) = c.text { + if !v.is_empty() { + encode_field_string(1, v, &mut buf); + } + } + if let Some(ref v) = c.uuid { + if !v.is_empty() { + encode_field_string(2, v, &mut buf); + } + } + if let Some(v) = c.image_format { + if v != 0 { + encode_field_varint(3, v as u64, &mut buf); + } + } + if let Some(ref v) = c.data { + if !v.is_empty() { + encode_field_string(4, v, &mut buf); + } + } + if let Some(ref v) = c.desc { + if !v.is_empty() { + encode_field_string(5, v, &mut buf); + } + } + if let Some(ref v) = c.ext { + if !v.is_empty() { + encode_field_string(6, v, &mut buf); + } + } + if let Some(ref v) = c.sound { + if !v.is_empty() { + encode_field_string(7, v, &mut buf); + } + } + for img in &c.image_info_array { + let mut ib = Vec::with_capacity(48); + if img.image_type != 0 { + encode_field_varint(1, img.image_type as u64, &mut ib); + } + if img.size != 0 { + encode_field_varint(2, img.size as u64, &mut ib); + } + if img.width != 0 { + encode_field_varint(3, img.width as u64, &mut ib); + } + if img.height != 0 { + encode_field_varint(4, img.height as u64, 
&mut ib); + } + if !img.url.is_empty() { + encode_field_string(5, &img.url, &mut ib); + } + encode_field_bytes(8, &ib, &mut buf); + } + if let Some(v) = c.index { + if v != 0 { + encode_field_varint(9, v as u64, &mut buf); + } + } + if let Some(ref v) = c.url { + if !v.is_empty() { + encode_field_string(10, v, &mut buf); + } + } + if let Some(v) = c.file_size { + if v != 0 { + encode_field_varint(11, v as u64, &mut buf); + } + } + if let Some(ref v) = c.file_name { + if !v.is_empty() { + encode_field_string(12, v, &mut buf); + } + } + buf +} + +fn decode_msg_content(data: &[u8]) -> Result { + let fields = parse_fields(data)?; + let mut c = MsgContent::default(); + for (n, v) in &fields { + match (*n, v) { + (1, FieldValue::Bytes(b)) => c.text = Some(String::from_utf8_lossy(b).into_owned()), + (2, FieldValue::Bytes(b)) => c.uuid = Some(String::from_utf8_lossy(b).into_owned()), + (3, FieldValue::Varint(x)) => c.image_format = Some(*x as u32), + (4, FieldValue::Bytes(b)) => c.data = Some(String::from_utf8_lossy(b).into_owned()), + (5, FieldValue::Bytes(b)) => c.desc = Some(String::from_utf8_lossy(b).into_owned()), + (6, FieldValue::Bytes(b)) => c.ext = Some(String::from_utf8_lossy(b).into_owned()), + (7, FieldValue::Bytes(b)) => c.sound = Some(String::from_utf8_lossy(b).into_owned()), + (8, FieldValue::Bytes(b)) => { + let ifields = parse_fields(b)?; + let mut info = ImageInfo { + image_type: get_varint(&ifields, 1) as u32, + size: get_varint(&ifields, 2) as u32, + width: get_varint(&ifields, 3) as u32, + height: get_varint(&ifields, 4) as u32, + url: get_string(&ifields, 5), + }; + if info.image_type != 0 || !info.url.is_empty() { + if info.image_type == 0 { + info.image_type = 1; + } + c.image_info_array.push(info); + } + } + (9, FieldValue::Varint(x)) => c.index = Some(*x as u32), + (10, FieldValue::Bytes(b)) => c.url = Some(String::from_utf8_lossy(b).into_owned()), + (11, FieldValue::Varint(x)) => c.file_size = Some(*x as u32), + (12, FieldValue::Bytes(b)) => { + 
c.file_name = Some(String::from_utf8_lossy(b).into_owned()) + } + _ => {} + } + } + Ok(c) +} + +pub fn encode_msg_body_element(el: &MsgBodyElement) -> Vec { + let mut buf = Vec::with_capacity(64); + if !el.msg_type.is_empty() { + encode_field_string(1, &el.msg_type, &mut buf); + } + let content = encode_msg_content(&el.msg_content); + if !content.is_empty() { + encode_field_bytes(2, &content, &mut buf); + } + buf +} + +fn decode_msg_body_element(data: &[u8]) -> Result { + let fields = parse_fields(data)?; + let content_bytes = get_bytes(&fields, 2); + let content = if content_bytes.is_empty() { + MsgContent::default() + } else { + decode_msg_content(&content_bytes)? + }; + Ok(MsgBodyElement { + msg_type: get_string(&fields, 1), + msg_content: content, + }) +} + +// ─── PushMsg envelope (cmd_type=Push inner wrapper) ──────────────── +// +// message PushMsg { +// string cmd = 1; +// string module = 2; +// string msg_id = 3; +// bytes data = 4; // ← actual biz body (e.g. InboundMessagePush) +// } +// +// The yuanbao gateway wraps every downstream push in this envelope +// *inside* `ConnMsg.data`. Mirrors plugin client.ts::onPush which +// decodes PushMsg before handing `data` to the business decoder. 
+ +#[derive(Debug, Default)] +pub struct PushMsg { + pub cmd: String, + pub module: String, + pub msg_id: String, + pub data: Vec, +} + +pub fn decode_push_msg(data: &[u8]) -> Result { + let fields = parse_fields(data)?; + Ok(PushMsg { + cmd: get_string(&fields, 1), + module: get_string(&fields, 2), + msg_id: get_string(&fields, 3), + data: get_bytes(&fields, 4), + }) +} + +// ─── InboundMessagePush decode ───────────────────────────────────── + +pub fn decode_inbound_push(data: &[u8]) -> Result { + let fields = parse_fields(data)?; + + let mut msg_body = Vec::new(); + for b in get_repeated_bytes(&fields, 13) { + msg_body.push(decode_msg_body_element(&b)?); + } + + let mut recalls = Vec::new(); + for b in get_repeated_bytes(&fields, 17) { + let f = parse_fields(&b)?; + recalls.push(ImMsgSeq { + msg_seq: get_varint(&f, 1) as u32, + msg_id: get_string(&f, 2), + }); + } + + let log_ext_bytes = get_bytes(&fields, 20); + let trace_id = if log_ext_bytes.is_empty() { + String::new() + } else { + get_string(&parse_fields(&log_ext_bytes)?, 1) + }; + + Ok(InboundMessage { + callback_command: get_string(&fields, 1), + from_account: get_string(&fields, 2), + to_account: get_string(&fields, 3), + sender_nickname: get_string(&fields, 4), + group_id: get_string(&fields, 5), + group_code: get_string(&fields, 6), + group_name: get_string(&fields, 7), + msg_seq: get_varint(&fields, 8) as u32, + msg_random: get_varint(&fields, 9) as u32, + msg_time: get_varint(&fields, 10) as u32, + msg_key: get_string(&fields, 11), + msg_id: get_string(&fields, 12), + msg_body, + cloud_custom_data: get_string(&fields, 14), + event_time: get_varint(&fields, 15) as u32, + bot_owner_id: get_string(&fields, 16), + recall_msg_seq_list: recalls, + claw_msg_type: get_varint(&fields, 18) as u32, + private_from_group_code: get_string(&fields, 19), + trace_id, + }) +} + +// ─── InboundMessagePush JSON decode ──────────────────────────────── +// +// The yuanbao gateway sometimes (depending on backend account 
config / +// source channel) pushes `inbound_message` as a JSON string instead of +// protobuf. The shape matches `InboundMessagePush` field-for-field +// (snake_case), with `log_ext.trace_id` nested. Mirrors plugin +// gateway.ts::decodeFromRawDataJson (l. 238). + +pub fn decode_inbound_json(data: &[u8]) -> Result { + let v: serde_json::Value = serde_json::from_slice(data) + .map_err(|e| YuanbaoError::ProtoDecode(format!("json parse failed: {e}")))?; + + let obj = v + .as_object() + .ok_or_else(|| YuanbaoError::ProtoDecode("json root is not an object".into()))?; + + let get_str = |k: &str| -> String { + obj.get(k) + .and_then(|x| x.as_str()) + .unwrap_or("") + .to_string() + }; + let get_u32 = |k: &str| -> u32 { obj.get(k).and_then(|x| x.as_u64()).unwrap_or(0) as u32 }; + + let msg_body = obj + .get("msg_body") + .and_then(|v| v.as_array()) + .map(|arr| arr.iter().map(decode_msg_body_element_json).collect()) + .unwrap_or_default(); + + let recall_msg_seq_list = obj + .get("recall_msg_seq_list") + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter() + .map(|e| ImMsgSeq { + msg_seq: e.get("msg_seq").and_then(|v| v.as_u64()).unwrap_or(0) as u32, + msg_id: e + .get("msg_id") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(), + }) + .collect() + }) + .unwrap_or_default(); + + let trace_id = obj + .get("log_ext") + .and_then(|v| v.get("trace_id")) + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + Ok(InboundMessage { + callback_command: get_str("callback_command"), + from_account: get_str("from_account"), + to_account: get_str("to_account"), + sender_nickname: get_str("sender_nickname"), + group_id: get_str("group_id"), + group_code: get_str("group_code"), + group_name: get_str("group_name"), + msg_seq: get_u32("msg_seq"), + msg_random: get_u32("msg_random"), + msg_time: get_u32("msg_time"), + msg_key: get_str("msg_key"), + msg_id: get_str("msg_id"), + msg_body, + cloud_custom_data: get_str("cloud_custom_data"), + event_time: 
get_u32("event_time"), + bot_owner_id: get_str("bot_owner_id"), + recall_msg_seq_list, + claw_msg_type: get_u32("claw_msg_type"), + private_from_group_code: get_str("private_from_group_code"), + trace_id, + }) +} + +fn decode_msg_body_element_json(v: &serde_json::Value) -> MsgBodyElement { + let msg_type = v + .get("msg_type") + .and_then(|x| x.as_str()) + .unwrap_or("") + .to_string(); + let mc = v.get("msg_content").and_then(|x| x.as_object()); + + let str_field = |k: &str| -> Option { + mc.and_then(|m| m.get(k)) + .and_then(|x| x.as_str()) + .filter(|s| !s.is_empty()) + .map(|s| s.to_string()) + }; + let u32_field = |k: &str| -> Option { + mc.and_then(|m| m.get(k)) + .and_then(|x| x.as_u64()) + .map(|n| n as u32) + }; + + let image_info_array = mc + .and_then(|m| m.get("image_info_array")) + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter() + .map(|e| ImageInfo { + image_type: e + .get("type") + .or_else(|| e.get("image_type")) + .and_then(|x| x.as_u64()) + .unwrap_or(0) as u32, + size: e.get("size").and_then(|x| x.as_u64()).unwrap_or(0) as u32, + width: e.get("width").and_then(|x| x.as_u64()).unwrap_or(0) as u32, + height: e.get("height").and_then(|x| x.as_u64()).unwrap_or(0) as u32, + url: e + .get("url") + .and_then(|x| x.as_str()) + .unwrap_or("") + .to_string(), + }) + .collect() + }) + .unwrap_or_default(); + + MsgBodyElement { + msg_type, + msg_content: MsgContent { + text: str_field("text"), + uuid: str_field("uuid"), + image_format: u32_field("image_format"), + data: str_field("data"), + desc: str_field("desc"), + ext: str_field("ext"), + sound: str_field("sound"), + image_info_array, + index: u32_field("index"), + url: str_field("url"), + file_size: u32_field("file_size"), + file_name: str_field("file_name"), + }, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn conn_msg_roundtrip() { + let buf = encode_conn_msg( + cmd_type::REQUEST, + cmd::PING, + 42, + "mid-1", + module::CONN_ACCESS, + b"payload", + ); + let frame = 
decode_conn_msg(&buf).unwrap(); + assert_eq!(frame.cmd_type, cmd_type::REQUEST); + assert_eq!(frame.cmd, cmd::PING); + assert_eq!(frame.seq_no, 42); + assert_eq!(frame.msg_id, "mid-1"); + assert_eq!(frame.module, module::CONN_ACCESS); + assert_eq!(frame.data, b"payload"); + } + + #[test] + fn auth_bind_smoke() { + let buf = encode_auth_bind( + "biz", + "uid", + "openclaw", + "tok", + "mid", + "1.0", + "linux", + "openhuman/0.1.0", + "", + ); + let frame = decode_conn_msg(&buf).unwrap(); + assert_eq!(frame.cmd, cmd::AUTH_BIND); + assert_eq!(frame.module, module::CONN_ACCESS); + assert!(!frame.data.is_empty()); + } + + #[test] + fn push_ack_mirrors_original() { + let original = ConnFrame { + cmd_type: cmd_type::PUSH, + cmd: "some_push".into(), + module: "yuanbao_openclaw_proxy".into(), + seq_no: 99, + msg_id: "mid-abc".into(), + need_ack: true, + status: 0, + data: vec![], + }; + let buf = encode_push_ack(&original); + let frame = decode_conn_msg(&buf).unwrap(); + assert_eq!(frame.cmd_type, cmd_type::PUSH_ACK); + assert_eq!(frame.cmd, original.cmd); + assert_eq!(frame.module, original.module); + assert_eq!(frame.msg_id, original.msg_id); + } + + #[test] + fn msg_body_element_roundtrip() { + let el = MsgBodyElement { + msg_type: "TIMTextElem".into(), + msg_content: MsgContent { + text: Some("hello 元宝".into()), + ..Default::default() + }, + }; + let buf = encode_msg_body_element(&el); + let got = decode_msg_body_element(&buf).unwrap(); + assert_eq!(got, el); + } + + #[test] + fn image_element_roundtrip() { + let el = MsgBodyElement { + msg_type: "TIMImageElem".into(), + msg_content: MsgContent { + uuid: Some("abc123".into()), + image_format: Some(3), + image_info_array: vec![ImageInfo { + image_type: 1, + size: 1024, + width: 800, + height: 600, + url: "https://example/img.png".into(), + }], + ..Default::default() + }, + }; + let buf = encode_msg_body_element(&el); + let got = decode_msg_body_element(&buf).unwrap(); + assert_eq!(got, el); + } + + // ─── 
decode_auth_bind_rsp ───────────────────────────────────── + + fn build_auth_bind_rsp_bytes(code: u64, message: &str, connect_id: &str) -> Vec { + let mut buf = Vec::new(); + if code != 0 { + encode_field_varint(1, code, &mut buf); + } + if !message.is_empty() { + encode_field_string(2, message, &mut buf); + } + if !connect_id.is_empty() { + encode_field_string(3, connect_id, &mut buf); + } + buf + } + + #[test] + fn decode_auth_bind_rsp_happy_path() { + let body = build_auth_bind_rsp_bytes(0, "ok", "conn-42"); + let r = decode_auth_bind_rsp(&body).unwrap(); + assert_eq!(r.code, 0); + assert_eq!(r.message, "ok"); + assert_eq!(r.connect_id, "conn-42"); + } + + #[test] + fn decode_auth_bind_rsp_with_error_code() { + let body = build_auth_bind_rsp_bytes(40011, "rejected", ""); + let r = decode_auth_bind_rsp(&body).unwrap(); + assert_eq!(r.code, 40011); + assert_eq!(r.message, "rejected"); + assert!(r.connect_id.is_empty()); + } + + #[test] + fn decode_auth_bind_rsp_on_empty_returns_default() { + let r = decode_auth_bind_rsp(&[]).unwrap(); + assert_eq!(r, AuthBindRsp::default()); + } + + // ─── decode_push_msg ────────────────────────────────────────── + + #[test] + fn decode_push_msg_extracts_all_fields() { + let inner_payload = vec![0xCA, 0xFE, 0xBA, 0xBE]; + let mut buf = Vec::new(); + encode_field_string(1, "inbound_message", &mut buf); + encode_field_string(2, "yuanbao_openclaw_proxy", &mut buf); + encode_field_string(3, "pm-1", &mut buf); + encode_field_bytes(4, &inner_payload, &mut buf); + + let pm = decode_push_msg(&buf).unwrap(); + assert_eq!(pm.cmd, "inbound_message"); + assert_eq!(pm.module, "yuanbao_openclaw_proxy"); + assert_eq!(pm.msg_id, "pm-1"); + assert_eq!(pm.data, inner_payload); + } + + #[test] + fn decode_push_msg_on_empty_returns_defaults() { + let pm = decode_push_msg(&[]).unwrap(); + assert!(pm.cmd.is_empty()); + assert!(pm.module.is_empty()); + assert!(pm.msg_id.is_empty()); + assert!(pm.data.is_empty()); + } + + // ─── decode_inbound_push 
(protobuf) ─────────────────────────── + + #[test] + fn decode_inbound_push_dm_with_text_body() { + // Build a minimal DM push: from/to/sender_nickname + one TIMTextElem. + let text_elem = MsgBodyElement { + msg_type: "TIMTextElem".into(), + msg_content: MsgContent { + text: Some("hello".into()), + ..Default::default() + }, + }; + let elem_bytes = encode_msg_body_element(&text_elem); + + let mut log_ext = Vec::new(); + encode_field_string(1, "trace-123", &mut log_ext); + + let mut buf = Vec::new(); + encode_field_string(1, "C2CMsg", &mut buf); + encode_field_string(2, "user_42", &mut buf); + encode_field_string(3, "bot_1", &mut buf); + encode_field_string(4, "Alice", &mut buf); + encode_field_varint(8, 7, &mut buf); + encode_field_varint(9, 123, &mut buf); + encode_field_varint(10, 1_700_000_000, &mut buf); + encode_field_string(12, "mid-abc", &mut buf); + encode_field_bytes(13, &elem_bytes, &mut buf); + encode_field_varint(15, 1_700_000_001, &mut buf); + encode_field_bytes(20, &log_ext, &mut buf); + + let m = decode_inbound_push(&buf).unwrap(); + assert_eq!(m.callback_command, "C2CMsg"); + assert_eq!(m.from_account, "user_42"); + assert_eq!(m.to_account, "bot_1"); + assert_eq!(m.sender_nickname, "Alice"); + assert_eq!(m.msg_seq, 7); + assert_eq!(m.msg_random, 123); + assert_eq!(m.msg_time, 1_700_000_000); + assert_eq!(m.msg_id, "mid-abc"); + assert_eq!(m.event_time, 1_700_000_001); + assert_eq!(m.trace_id, "trace-123"); + assert_eq!(m.msg_body.len(), 1); + assert_eq!(m.msg_body[0].msg_content.text.as_deref(), Some("hello")); + assert!(m.recall_msg_seq_list.is_empty()); + } + + #[test] + fn decode_inbound_push_group_with_recall_list() { + let mut recall_entry = Vec::new(); + encode_field_varint(1, 99, &mut recall_entry); + encode_field_string(2, "old-msg-id", &mut recall_entry); + + let mut buf = Vec::new(); + encode_field_string(1, "GroupSysMsg", &mut buf); + encode_field_string(5, "gid-x", &mut buf); + encode_field_string(6, "gcode-y", &mut buf); + 
encode_field_string(7, "Room", &mut buf); + encode_field_bytes(17, &recall_entry, &mut buf); + encode_field_string(19, "g-private-code", &mut buf); + + let m = decode_inbound_push(&buf).unwrap(); + assert_eq!(m.callback_command, "GroupSysMsg"); + assert_eq!(m.group_id, "gid-x"); + assert_eq!(m.group_code, "gcode-y"); + assert_eq!(m.group_name, "Room"); + assert_eq!(m.private_from_group_code, "g-private-code"); + assert_eq!(m.recall_msg_seq_list.len(), 1); + assert_eq!(m.recall_msg_seq_list[0].msg_seq, 99); + assert_eq!(m.recall_msg_seq_list[0].msg_id, "old-msg-id"); + assert!(m.trace_id.is_empty(), "no log_ext => empty trace_id"); + } + + // ─── decode_inbound_json ────────────────────────────────────── + + #[test] + fn decode_inbound_json_full_dm_shape() { + let json = serde_json::json!({ + "callback_command": "C2CMsg", + "from_account": "user_42", + "to_account": "bot_1", + "sender_nickname": "Alice", + "msg_seq": 7, + "msg_random": 123, + "msg_time": 1_700_000_000, + "msg_id": "mid-1", + "msg_body": [ + { + "msg_type": "TIMTextElem", + "msg_content": { "text": "hi" } + }, + { + "msg_type": "TIMImageElem", + "msg_content": { + "uuid": "u-1", + "image_format": 1, + "image_info_array": [ + { "type": 1, "size": 100, "width": 10, "height": 20, "url": "https://x/i.png" } + ] + } + } + ], + "recall_msg_seq_list": [{ "msg_seq": 9, "msg_id": "old" }], + "log_ext": { "trace_id": "trace-json" } + }); + let m = decode_inbound_json(json.to_string().as_bytes()).unwrap(); + assert_eq!(m.callback_command, "C2CMsg"); + assert_eq!(m.from_account, "user_42"); + assert_eq!(m.msg_id, "mid-1"); + assert_eq!(m.msg_body.len(), 2); + assert_eq!(m.msg_body[0].msg_content.text.as_deref(), Some("hi")); + let img = &m.msg_body[1].msg_content; + assert_eq!(img.uuid.as_deref(), Some("u-1")); + assert_eq!(img.image_info_array.len(), 1); + assert_eq!(img.image_info_array[0].url, "https://x/i.png"); + assert_eq!(m.recall_msg_seq_list.len(), 1); + assert_eq!(m.recall_msg_seq_list[0].msg_seq, 9); 
+ assert_eq!(m.trace_id, "trace-json"); + } + + #[test] + fn decode_inbound_json_rejects_non_object_root() { + let err = decode_inbound_json(b"[1,2,3]").unwrap_err(); + match err { + YuanbaoError::ProtoDecode(m) => assert!(m.contains("not an object"), "got {m}"), + other => panic!("unexpected {other:?}"), + } + } + + #[test] + fn decode_inbound_json_rejects_invalid_json() { + let err = decode_inbound_json(b"not json").unwrap_err(); + match err { + YuanbaoError::ProtoDecode(m) => assert!(m.contains("json parse"), "got {m}"), + other => panic!("unexpected {other:?}"), + } + } + + #[test] + fn decode_msg_body_element_json_handles_image_type_alias() { + // Some payloads use `image_type` (snake_case) instead of `type`. + let v = serde_json::json!({ + "msg_type": "TIMImageElem", + "msg_content": { + "image_info_array": [ + { "image_type": 2, "size": 50, "width": 5, "height": 6, "url": "u" } + ] + } + }); + let el = decode_msg_body_element_json(&v); + assert_eq!(el.msg_type, "TIMImageElem"); + assert_eq!(el.msg_content.image_info_array.len(), 1); + assert_eq!(el.msg_content.image_info_array[0].image_type, 2); + } + + #[test] + fn decode_msg_content_image_info_with_only_image_type_zero_defaults_to_one() { + // When `image_type` is 0 but url is present, decoder bumps to 1. 
+ let mut ib = Vec::new(); + encode_field_varint(2, 64, &mut ib); + encode_field_string(5, "https://x/y.png", &mut ib); + let mut content = Vec::new(); + encode_field_bytes(8, &ib, &mut content); + let mut elem = Vec::new(); + encode_field_string(1, "TIMImageElem", &mut elem); + encode_field_bytes(2, &content, &mut elem); + let got = decode_msg_body_element(&elem).unwrap(); + assert_eq!(got.msg_content.image_info_array.len(), 1); + assert_eq!(got.msg_content.image_info_array[0].image_type, 1); + assert_eq!(got.msg_content.image_info_array[0].url, "https://x/y.png"); + } +} diff --git a/src/openhuman/channels/providers/yuanbao/proto_biz.rs b/src/openhuman/channels/providers/yuanbao/proto_biz.rs new file mode 100644 index 0000000000..daf91f517d --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/proto_biz.rs @@ -0,0 +1,638 @@ +//! Business-layer protobuf codecs (biz payloads inside `ConnMsg.data`). +//! +//! Kept separate from `proto.rs` to stay under the 500-line ceiling and +//! to isolate the "openclaw biz protocol" surface from the lower-level +//! ConnMsg envelope. 
+ +use super::errors::YuanbaoError; +use super::proto::{decode_conn_msg, encode_conn_msg, encode_msg_body_element}; +use super::proto_constants::*; +use super::types::*; +use super::wire::{ + encode_field_bytes as put_bytes_field, encode_field_string as put_string_field, + encode_field_varint as put_varint_field, get_bytes, get_repeated_bytes, get_string, get_varint, + next_seq_no, parse_fields, +}; + +// ─── SendC2CMessageReq ──────────────────────────────────────────── +// +// 1: msg_id (string) 5: msg_body (repeated MsgBodyElement) +// 2: to_account 6: group_code (DM-from-group) +// 3: from_account 7: msg_seq +// 4: msg_random 8: log_ext + +#[allow(clippy::too_many_arguments)] +fn encode_send_c2c_req( + msg_id: &str, + to_account: &str, + from_account: &str, + msg_random: u32, + msg_body: &[MsgBodyElement], + group_code: &str, + msg_seq: Option, + trace_id: &str, +) -> Vec { + let mut buf = Vec::with_capacity(128); + if !msg_id.is_empty() { + put_string_field(1, msg_id, &mut buf); + } + put_string_field(2, to_account, &mut buf); + if !from_account.is_empty() { + put_string_field(3, from_account, &mut buf); + } + if msg_random != 0 { + put_varint_field(4, msg_random as u64, &mut buf); + } + for el in msg_body { + let el_bytes = encode_msg_body_element(el); + put_bytes_field(5, &el_bytes, &mut buf); + } + if !group_code.is_empty() { + put_string_field(6, group_code, &mut buf); + } + if let Some(seq) = msg_seq { + put_varint_field(7, seq, &mut buf); + } + if !trace_id.is_empty() { + // log_ext is field 8 with a nested {1: trace_id} + let mut log = Vec::new(); + put_string_field(1, trace_id, &mut log); + put_bytes_field(8, &log, &mut buf); + } + buf +} + +/// Encode a full C2C send request as a `ConnMsg` ready to send over WS. 
+pub fn encode_send_c2c_message( + to_account: &str, + from_account: &str, + msg_body: &[MsgBodyElement], + msg_id: &str, + msg_random: u32, + group_code: &str, + trace_id: &str, +) -> Vec { + let body = encode_send_c2c_req( + msg_id, + to_account, + from_account, + msg_random, + msg_body, + group_code, + None, + trace_id, + ); + let req_id = if msg_id.is_empty() { + format!("c2c_{}", next_seq_no()) + } else { + msg_id.to_string() + }; + encode_conn_msg( + cmd_type::REQUEST, + biz_cmd::SEND_C2C_MESSAGE, + next_seq_no(), + &req_id, + module::BIZ_PKG, + &body, + ) +} + +// ─── SendGroupMessageReq ─────────────────────────────────────────── +// +// 1: msg_id 5: random (string) +// 2: group_code 6: msg_body (repeated) +// 3: from_account 7: ref_msg_id +// 4: to_account 8: msg_seq +// 9: log_ext + +#[allow(clippy::too_many_arguments)] +fn encode_send_group_req( + msg_id: &str, + group_code: &str, + from_account: &str, + to_account: &str, + random: &str, + msg_body: &[MsgBodyElement], + ref_msg_id: &str, + trace_id: &str, +) -> Vec { + let mut buf = Vec::with_capacity(128); + if !msg_id.is_empty() { + put_string_field(1, msg_id, &mut buf); + } + put_string_field(2, group_code, &mut buf); + if !from_account.is_empty() { + put_string_field(3, from_account, &mut buf); + } + if !to_account.is_empty() { + put_string_field(4, to_account, &mut buf); + } + if !random.is_empty() { + put_string_field(5, random, &mut buf); + } + for el in msg_body { + let el_bytes = encode_msg_body_element(el); + put_bytes_field(6, &el_bytes, &mut buf); + } + if !ref_msg_id.is_empty() { + put_string_field(7, ref_msg_id, &mut buf); + } + if !trace_id.is_empty() { + let mut log = Vec::new(); + put_string_field(1, trace_id, &mut log); + put_bytes_field(9, &log, &mut buf); + } + buf +} + +#[allow(clippy::too_many_arguments)] +pub fn encode_send_group_message( + group_code: &str, + from_account: &str, + msg_body: &[MsgBodyElement], + msg_id: &str, + to_account: &str, + random: &str, + ref_msg_id: &str, 
+ trace_id: &str, +) -> Vec { + let body = encode_send_group_req( + msg_id, + group_code, + from_account, + to_account, + random, + msg_body, + ref_msg_id, + trace_id, + ); + let req_id = if msg_id.is_empty() { + format!("grp_{}", next_seq_no()) + } else { + msg_id.to_string() + }; + encode_conn_msg( + cmd_type::REQUEST, + biz_cmd::SEND_GROUP_MESSAGE, + next_seq_no(), + &req_id, + module::BIZ_PKG, + &body, + ) +} + +// ─── Heartbeats ──────────────────────────────────────────────────── + +pub fn encode_send_private_heartbeat( + req_id: &str, + from_account: &str, + to_account: &str, + heartbeat: u32, +) -> Vec { + let mut body = Vec::with_capacity(48); + put_string_field(1, from_account, &mut body); + put_string_field(2, to_account, &mut body); + put_varint_field(3, heartbeat as u64, &mut body); + encode_conn_msg( + cmd_type::REQUEST, + biz_cmd::SEND_PRIVATE_HEARTBEAT, + next_seq_no(), + req_id, + module::BIZ_PKG, + &body, + ) +} + +pub fn encode_send_group_heartbeat( + req_id: &str, + from_account: &str, + group_code: &str, + heartbeat: u32, + send_time_ms: u64, +) -> Vec { + let mut body = Vec::with_capacity(64); + put_string_field(1, from_account, &mut body); + put_string_field(2, "", &mut body); // to_account empty for group + put_string_field(3, group_code, &mut body); + put_varint_field(4, send_time_ms, &mut body); + put_varint_field(5, heartbeat as u64, &mut body); + encode_conn_msg( + cmd_type::REQUEST, + biz_cmd::SEND_GROUP_HEARTBEAT, + next_seq_no(), + req_id, + module::BIZ_PKG, + &body, + ) +} + +// ─── QueryGroupInfo ──────────────────────────────────────────────── + +pub fn encode_query_group_info(req_id: &str, group_code: &str) -> Vec { + let mut body = Vec::with_capacity(16 + group_code.len()); + put_string_field(1, group_code, &mut body); + encode_conn_msg( + cmd_type::REQUEST, + biz_cmd::QUERY_GROUP_INFO, + next_seq_no(), + req_id, + module::BIZ_PKG, + &body, + ) +} + +/// Try to narrow a varint into a smaller integer type, returning +/// 
`YuanbaoError::ProtoDecode` (instead of silently truncating) when +/// the upstream value is out of range. Used to harden response decoders +/// against malformed / adversarial input. +fn varint_to_i32(value: u64, field_label: &str) -> Result { + i32::try_from(value) + .map_err(|_| YuanbaoError::ProtoDecode(format!("{field_label} out of i32 range: {value}"))) +} + +fn varint_to_u32(value: u64, field_label: &str) -> Result { + u32::try_from(value) + .map_err(|_| YuanbaoError::ProtoDecode(format!("{field_label} out of u32 range: {value}"))) +} + +pub fn decode_query_group_info_rsp(data: &[u8]) -> Result { + let fields = parse_fields(data)?; + let mut info = GroupInfo { + code: varint_to_i32(get_varint(&fields, 1), "GroupInfoRsp.code")?, + message: get_string(&fields, 2), + ..Default::default() + }; + let gi_bytes = get_bytes(&fields, 3); + if !gi_bytes.is_empty() { + let gi = parse_fields(&gi_bytes)?; + info.group_name = get_string(&gi, 1); + info.owner_id = get_string(&gi, 2); + info.owner_nickname = get_string(&gi, 3); + info.member_count = varint_to_u32(get_varint(&gi, 4), "GroupInfo.member_count")?; + } + Ok(info) +} + +// ─── GetGroupMemberList ──────────────────────────────────────────── + +pub fn encode_get_group_member_list( + req_id: &str, + group_code: &str, + offset: u32, + limit: u32, +) -> Vec { + let mut body = Vec::with_capacity(32 + group_code.len()); + put_string_field(1, group_code, &mut body); + if offset != 0 { + put_varint_field(2, offset as u64, &mut body); + } + put_varint_field(3, limit as u64, &mut body); + encode_conn_msg( + cmd_type::REQUEST, + biz_cmd::GET_GROUP_MEMBER_LIST, + next_seq_no(), + req_id, + module::BIZ_PKG, + &body, + ) +} + +pub fn decode_get_group_member_list_rsp(data: &[u8]) -> Result { + let fields = parse_fields(data)?; + let mut members = Vec::new(); + for b in get_repeated_bytes(&fields, 3) { + let m = parse_fields(&b)?; + members.push(GroupMember { + user_id: get_string(&m, 1), + nickname: get_string(&m, 2), + role: 
varint_to_u32(get_varint(&m, 3), "GroupMember.role")?, + join_time: varint_to_u32(get_varint(&m, 4), "GroupMember.join_time")?, + name_card: get_string(&m, 5), + }); + } + Ok(GroupMemberListPage { + code: varint_to_i32(get_varint(&fields, 1), "GroupMemberListRsp.code")?, + message: get_string(&fields, 2), + members, + next_offset: varint_to_u32(get_varint(&fields, 4), "GroupMemberListRsp.next_offset")?, + is_complete: get_varint(&fields, 5) != 0, + }) +} + +// ─── Generic biz response code helper ────────────────────────────── + +/// Decode the `code` and `message` from a biz response. +/// +/// All biz responses share the convention: field 1 = code, field 2 = message. +pub fn decode_biz_rsp_code(data: &[u8]) -> Result<(i32, String), YuanbaoError> { + let fields = parse_fields(data)?; + Ok(( + varint_to_i32(get_varint(&fields, 1), "BizRsp.code")?, + get_string(&fields, 2), + )) +} + +/// Decode a `ConnMsg` and return the typed biz response code + frame for +/// the request/response correlator. 
+pub fn decode_response_envelope(frame_bytes: &[u8]) -> Result { + decode_conn_msg(frame_bytes) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn text_body(s: &str) -> Vec { + vec![MsgBodyElement { + msg_type: "TIMTextElem".into(), + msg_content: MsgContent { + text: Some(s.into()), + ..Default::default() + }, + }] + } + + #[test] + fn c2c_encode_smoke() { + let buf = encode_send_c2c_message("uid_alice", "uid_bot", &text_body("hi"), "", 0, "", ""); + let frame = decode_conn_msg(&buf).unwrap(); + assert_eq!(frame.cmd, biz_cmd::SEND_C2C_MESSAGE); + assert_eq!(frame.module, module::BIZ_PKG); + assert!(!frame.data.is_empty()); + } + + #[test] + fn group_encode_smoke() { + let buf = encode_send_group_message( + "group_42", + "uid_bot", + &text_body("hello"), + "", + "", + "rand", + "", + "", + ); + let frame = decode_conn_msg(&buf).unwrap(); + assert_eq!(frame.cmd, biz_cmd::SEND_GROUP_MESSAGE); + } + + #[test] + fn private_heartbeat_smoke() { + let buf = + encode_send_private_heartbeat("hb_1", "uid_bot", "uid_user", ws_heartbeat::RUNNING); + let frame = decode_conn_msg(&buf).unwrap(); + assert_eq!(frame.cmd, biz_cmd::SEND_PRIVATE_HEARTBEAT); + assert_eq!(frame.msg_id, "hb_1"); + } + + #[test] + fn query_group_info_roundtrip() { + let buf = encode_query_group_info("qgi_1", "group_99"); + let frame = decode_conn_msg(&buf).unwrap(); + assert_eq!(frame.cmd, biz_cmd::QUERY_GROUP_INFO); + assert_eq!(frame.msg_id, "qgi_1"); + + // Simulate response payload: code=0, message="ok", group_name="g", owner=… + let mut gi = Vec::new(); + put_string_field(1, "TestGroup", &mut gi); + put_string_field(2, "owner_uid", &mut gi); + put_string_field(3, "OwnerNick", &mut gi); + put_varint_field(4, 42, &mut gi); + let mut rsp = Vec::new(); + put_varint_field(1, 0, &mut rsp); + put_string_field(2, "ok", &mut rsp); + put_bytes_field(3, &gi, &mut rsp); + + let parsed = decode_query_group_info_rsp(&rsp).unwrap(); + assert_eq!(parsed.code, 0); + assert_eq!(parsed.group_name, "TestGroup"); + 
assert_eq!(parsed.owner_id, "owner_uid"); + assert_eq!(parsed.member_count, 42); + } + + // ─── encode_send_c2c branches ────────────────────────────────── + + #[test] + fn c2c_encode_with_msg_id_msg_random_group_code_trace_id() { + // Hit the branches: msg_id non-empty, msg_random != 0, group_code + // non-empty, trace_id non-empty. + let buf = encode_send_c2c_message( + "uid_alice", + "uid_bot", + &text_body("hi"), + "mid-1", + 42, + "gcode-x", + "trace-1", + ); + let frame = decode_conn_msg(&buf).unwrap(); + assert_eq!(frame.cmd, biz_cmd::SEND_C2C_MESSAGE); + assert_eq!(frame.msg_id, "mid-1"); + // Re-parse the biz body and check the fields we encoded show up. + let f = parse_fields(&frame.data).unwrap(); + assert_eq!(get_string(&f, 1), "mid-1"); + assert_eq!(get_string(&f, 2), "uid_alice"); + assert_eq!(get_string(&f, 3), "uid_bot"); + assert_eq!(get_varint(&f, 4), 42); + assert_eq!(get_string(&f, 6), "gcode-x"); + // log_ext (field 8) carries nested {1: trace_id} + let log_ext = get_bytes(&f, 8); + assert!(!log_ext.is_empty()); + let inner = parse_fields(&log_ext).unwrap(); + assert_eq!(get_string(&inner, 1), "trace-1"); + } + + #[test] + fn c2c_encode_generates_synthetic_req_id_when_msg_id_empty() { + // msg_id empty branch — req_id falls back to `c2c_`. 
+ let buf = encode_send_c2c_message("uid_alice", "uid_bot", &text_body("hi"), "", 0, "", ""); + let frame = decode_conn_msg(&buf).unwrap(); + assert!( + frame.msg_id.starts_with("c2c_"), + "expected synthetic req_id starting with c2c_, got {}", + frame.msg_id + ); + } + + // ─── encode_send_group branches ──────────────────────────────── + + #[test] + fn group_encode_with_all_optional_fields() { + let buf = encode_send_group_message( + "group_42", + "uid_bot", + &text_body("hello"), + "mid-g", + "uid_to", + "rand_x", + "ref-msg-99", + "trace-g", + ); + let frame = decode_conn_msg(&buf).unwrap(); + assert_eq!(frame.cmd, biz_cmd::SEND_GROUP_MESSAGE); + assert_eq!(frame.msg_id, "mid-g"); + let f = parse_fields(&frame.data).unwrap(); + assert_eq!(get_string(&f, 1), "mid-g"); + assert_eq!(get_string(&f, 2), "group_42"); + assert_eq!(get_string(&f, 3), "uid_bot"); + assert_eq!(get_string(&f, 4), "uid_to"); + assert_eq!(get_string(&f, 5), "rand_x"); + assert_eq!(get_string(&f, 7), "ref-msg-99"); + let log_ext = get_bytes(&f, 9); + let inner = parse_fields(&log_ext).unwrap(); + assert_eq!(get_string(&inner, 1), "trace-g"); + } + + #[test] + fn group_encode_generates_synthetic_req_id_when_msg_id_empty() { + let buf = + encode_send_group_message("group_x", "uid_bot", &text_body("hi"), "", "", "", "", ""); + let frame = decode_conn_msg(&buf).unwrap(); + assert!( + frame.msg_id.starts_with("grp_"), + "expected synthetic req_id starting with grp_, got {}", + frame.msg_id + ); + } + + // ─── encode_send_group_heartbeat ─────────────────────────────── + + #[test] + fn group_heartbeat_encodes_send_time_and_heartbeat() { + let buf = encode_send_group_heartbeat( + "hb_g_1", + "uid_bot", + "group_42", + ws_heartbeat::RUNNING, + 1_700_000_123, + ); + let frame = decode_conn_msg(&buf).unwrap(); + assert_eq!(frame.cmd, biz_cmd::SEND_GROUP_HEARTBEAT); + assert_eq!(frame.msg_id, "hb_g_1"); + let f = parse_fields(&frame.data).unwrap(); + assert_eq!(get_string(&f, 1), "uid_bot"); + 
assert_eq!(get_string(&f, 2), ""); // to_account empty for group + assert_eq!(get_string(&f, 3), "group_42"); + assert_eq!(get_varint(&f, 4), 1_700_000_123); + assert_eq!(get_varint(&f, 5), ws_heartbeat::RUNNING as u64); + } + + // ─── encode_get_group_member_list ────────────────────────────── + + #[test] + fn get_group_member_list_omits_offset_when_zero() { + let buf = encode_get_group_member_list("qgm_1", "group_42", 0, 100); + let frame = decode_conn_msg(&buf).unwrap(); + assert_eq!(frame.cmd, biz_cmd::GET_GROUP_MEMBER_LIST); + let f = parse_fields(&frame.data).unwrap(); + assert_eq!(get_string(&f, 1), "group_42"); + // offset (field 2) skipped when 0 + assert_eq!(get_varint(&f, 2), 0); + assert_eq!(get_varint(&f, 3), 100); + } + + #[test] + fn get_group_member_list_includes_offset_when_nonzero() { + let buf = encode_get_group_member_list("qgm_2", "group_42", 200, 50); + let frame = decode_conn_msg(&buf).unwrap(); + let f = parse_fields(&frame.data).unwrap(); + assert_eq!(get_varint(&f, 2), 200); + assert_eq!(get_varint(&f, 3), 50); + } + + // ─── decode_biz_rsp_code + decode_response_envelope ──────────── + + #[test] + fn decode_biz_rsp_code_reads_code_and_message() { + let mut buf = Vec::new(); + put_varint_field(1, 4002, &mut buf); + put_string_field(2, "rate limited", &mut buf); + let (code, msg) = decode_biz_rsp_code(&buf).unwrap(); + assert_eq!(code, 4002); + assert_eq!(msg, "rate limited"); + } + + #[test] + fn decode_biz_rsp_code_on_empty_returns_defaults() { + let (code, msg) = decode_biz_rsp_code(&[]).unwrap(); + assert_eq!(code, 0); + assert!(msg.is_empty()); + } + + #[test] + fn decode_response_envelope_extracts_frame() { + let original = encode_conn_msg( + cmd_type::RESPONSE, + biz_cmd::SEND_C2C_MESSAGE, + 1, + "mid-r", + module::BIZ_PKG, + &[0xAA, 0xBB], + ); + let frame = decode_response_envelope(&original).unwrap(); + assert_eq!(frame.cmd, biz_cmd::SEND_C2C_MESSAGE); + assert_eq!(frame.msg_id, "mid-r"); + assert_eq!(frame.data, vec![0xAA, 
0xBB]); + } + + #[test] + fn group_member_list_decode() { + let mut m1 = Vec::new(); + put_string_field(1, "uid_a", &mut m1); + put_string_field(2, "Alice", &mut m1); + put_varint_field(3, 2, &mut m1); + let mut rsp = Vec::new(); + put_varint_field(1, 0, &mut rsp); + put_string_field(2, "ok", &mut rsp); + put_bytes_field(3, &m1, &mut rsp); + put_varint_field(4, 100, &mut rsp); + put_varint_field(5, 1, &mut rsp); + + let page = decode_get_group_member_list_rsp(&rsp).unwrap(); + assert_eq!(page.members.len(), 1); + assert_eq!(page.members[0].user_id, "uid_a"); + assert_eq!(page.members[0].role, 2); + assert_eq!(page.next_offset, 100); + assert!(page.is_complete); + } + + /// Adversarial input: a varint that overflows i32. The decoder must + /// surface `YuanbaoError::ProtoDecode` instead of silently truncating + /// (which would corrupt the `code` field returned to callers). + #[test] + fn decode_biz_rsp_code_rejects_varint_out_of_i32_range() { + let mut buf = Vec::new(); + put_varint_field(1, u64::MAX, &mut buf); + put_string_field(2, "ok", &mut buf); + match decode_biz_rsp_code(&buf).unwrap_err() { + YuanbaoError::ProtoDecode(m) => { + assert!( + m.contains("out of i32 range"), + "expected i32 overflow message, got: {m}" + ); + } + other => panic!("expected ProtoDecode, got {other:?}"), + } + } + + /// Same guard applied to the group-member-list `next_offset` field — + /// an oversized varint must produce a structured decode error, not a + /// silent `as u32` wrap that would mis-paginate subsequent fetches. 
+ #[test] + fn decode_group_member_list_rejects_varint_out_of_u32_range() { + let mut rsp = Vec::new(); + put_varint_field(1, 0, &mut rsp); + put_string_field(2, "ok", &mut rsp); + put_varint_field(4, u64::from(u32::MAX) + 1, &mut rsp); + match decode_get_group_member_list_rsp(&rsp).unwrap_err() { + YuanbaoError::ProtoDecode(m) => { + assert!( + m.contains("out of u32 range"), + "expected u32 overflow message, got: {m}" + ); + } + other => panic!("expected ProtoDecode, got {other:?}"), + } + } +} diff --git a/src/openhuman/channels/providers/yuanbao/proto_constants.rs b/src/openhuman/channels/providers/yuanbao/proto_constants.rs new file mode 100644 index 0000000000..bde9fdb58e --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/proto_constants.rs @@ -0,0 +1,90 @@ +//! Yuanbao WebSocket protocol constants. +//! +//! Values mirror `gateway/platforms/yuanbao_proto.py` in hermes-agent +//! (the authoritative reference implementation). + +/// `ConnMsg.Head.cmd_type` enum. +pub mod cmd_type { + /// Upstream request. + pub const REQUEST: u32 = 0; + /// Response to a previous upstream request. + pub const RESPONSE: u32 = 1; + /// Downstream push from server. + pub const PUSH: u32 = 2; + /// ACK reply to a downstream push. + pub const PUSH_ACK: u32 = 3; +} + +/// Built-in command words used in `ConnMsg.Head.cmd`. +pub mod cmd { + pub const AUTH_BIND: &str = "auth-bind"; + pub const PING: &str = "ping"; + pub const KICKOUT: &str = "kickout"; + pub const UPDATE_META: &str = "update-meta"; +} + +/// Module / service names used in `ConnMsg.Head.module`. +pub mod module { + pub const CONN_ACCESS: &str = "conn_access"; + /// Short name of the openclaw biz module (matches TS client). + pub const BIZ_PKG: &str = "yuanbao_openclaw_proxy"; +} + +/// Business command words (`ConnMsg.Head.cmd` when module=BIZ_PKG). 
+/// +/// Note: there is intentionally no constant for the inbound push cmd — +/// the yuanbao gateway uses several cmd words for inbound messages and +/// the routing is purely by `cmd_type=Push` (see `connection.rs` / +/// `mod.rs::dispatch_push`). +pub mod biz_cmd { + pub const SEND_C2C_MESSAGE: &str = "send_c2c_message"; + pub const SEND_GROUP_MESSAGE: &str = "send_group_message"; + pub const SEND_PRIVATE_HEARTBEAT: &str = "send_private_heartbeat"; + pub const SEND_GROUP_HEARTBEAT: &str = "send_group_heartbeat"; + pub const QUERY_GROUP_INFO: &str = "query_group_info"; + pub const GET_GROUP_MEMBER_LIST: &str = "get_group_member_list"; +} + +/// Reply Heartbeat status enum (`heartbeat` field of `Send*HeartbeatReq`). +pub mod ws_heartbeat { + /// Bot is currently producing output. + pub const RUNNING: u32 = 1; + /// Bot has finished its turn. + pub const FINISH: u32 = 2; +} + +/// TIM `msg_type` string constants for `MsgBodyElement.msg_type`. +pub mod tim { + pub const TEXT: &str = "TIMTextElem"; + pub const IMAGE: &str = "TIMImageElem"; + pub const FILE: &str = "TIMFileElem"; + pub const SOUND: &str = "TIMSoundElem"; + pub const VIDEO: &str = "TIMVideoFileElem"; + pub const FACE: &str = "TIMFaceElem"; + pub const CUSTOM: &str = "TIMCustomElem"; +} + +/// Fixed instance id reported in `AuthBindReq.DeviceInfo.instance_id` and +/// the `X-Instance-Id` HTTP header. Mirrors `OPENCLAW_ID = 16` used by +/// `yuanbao-openclaw-plugin` (`src/access/ws/conn-codec.ts`) — the server +/// keys some checks off this value, so it must match the value the sign +/// endpoint sees when the token is minted. +pub const OPENHUMAN_INSTANCE_ID: &str = "16"; + +/// Reconnect backoff schedule (seconds). Mirrors hermes-agent. +pub const RECONNECT_DELAYS: &[u64] = &[1, 2, 5, 10, 30, 60]; +pub const MAX_RECONNECT_ATTEMPTS: u32 = 100; + +/// Ping interval (seconds). Server-driven; this is the upper bound. 
+pub const PING_INTERVAL_SECS: u64 = 30; +/// Number of consecutive ping timeouts before the connection is dropped. +pub const HEARTBEAT_TIMEOUT_THRESHOLD: u32 = 2; +/// Per-request biz timeout (seconds). +pub const DEFAULT_SEND_TIMEOUT_SECS: u64 = 30; +/// Auth-bind handshake timeout (seconds). +pub const AUTH_TIMEOUT_SECS: u64 = 15; + +/// Inbound dedup TTL — drop a `msg_id` we've already seen within this window. +pub const DEDUP_TTL_SECS: u64 = 300; +/// LRU-style cap on the dedup table. +pub const DEDUP_CAPACITY: usize = 10_000; diff --git a/src/openhuman/channels/providers/yuanbao/sign.rs b/src/openhuman/channels/providers/yuanbao/sign.rs new file mode 100644 index 0000000000..03791c6b14 --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/sign.rs @@ -0,0 +1,629 @@ +//! Token sign manager — talks to `/api/v5/robotLogic/sign-token` to +//! exchange `(app_key, app_secret)` for a short-lived WS token + bot_id. +//! +//! Mirrors hermes-agent `SignManager` (yuanbao.py 641-881). Implements: +//! - per-app_key tokio `Mutex` to coalesce concurrent refresh attempts +//! - 60-second early-refresh margin to avoid using a token that's +//! about to expire mid-handshake +//! - retry on `code=10099` up to 3 times +//! +//! Signature scheme (TS plugin compatible): +//! plain = nonce + timestamp + app_key + app_secret +//! signature = HMAC-SHA256(key = app_secret, msg = plain) as lower-hex + +use std::collections::HashMap; +use std::sync::Arc; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use chrono::FixedOffset; +use hmac::{Hmac, Mac}; +use sha2::Sha256; +use tokio::sync::Mutex; +use tracing::{info, warn}; + +use super::errors::YuanbaoError; + +const SIGN_PATH: &str = "/api/v5/robotLogic/sign-token"; +const RETRYABLE_CODE: i64 = 10099; +const MAX_RETRIES: usize = 3; +const RETRY_DELAY_MS: u64 = 1_000; +/// Treat as expiring this many seconds before actual expiry so a fresh +/// token is fetched before the running one dies mid-request. 
+const CACHE_REFRESH_MARGIN_SECS: u64 = 60;
+/// Timeout (seconds) applied to each sign-token HTTP request.
+const HTTP_TIMEOUT_SECS: u64 = 10;
+/// Token lifetime (seconds) assumed when the response omits `duration`.
+const DEFAULT_DURATION_SECS: u64 = 3600;
+
+/// One cached token entry.
+#[derive(Debug, Clone)]
+pub struct TokenEntry {
+    pub token: String,
+    pub bot_id: String,
+    pub product: String,
+    pub source: String,
+    /// Seconds-since-epoch when this token expires server-side.
+    pub expire_ts: u64,
+}
+
+impl TokenEntry {
+    /// `true` while the token still has more than
+    /// `CACHE_REFRESH_MARGIN_SECS` seconds left before `expire_ts`.
+    pub fn is_valid(&self) -> bool {
+        let now = unix_now();
+        self.expire_ts > now + CACHE_REFRESH_MARGIN_SECS
+    }
+
+    /// Signed number of seconds until `expire_ts`; zero or negative once
+    /// the token has expired.
+    pub fn seconds_remaining(&self) -> i64 {
+        self.expire_ts as i64 - unix_now() as i64
+    }
+}
+
+type HmacSha256 = Hmac<Sha256>;
+
+/// Compute the `signature` field for the sign-token API.
+///
+/// The signed message is `nonce + timestamp + app_key + app_secret`, keyed
+/// with `app_secret`; the MAC is returned as lower-case hex.
+pub fn compute_signature(nonce: &str, timestamp: &str, app_key: &str, app_secret: &str) -> String {
+    let plain = format!("{nonce}{timestamp}{app_key}{app_secret}");
+    let mut mac =
+        HmacSha256::new_from_slice(app_secret.as_bytes()).expect("HMAC accepts any key length");
+    mac.update(plain.as_bytes());
+    hex::encode(mac.finalize().into_bytes())
+}
+
+/// Build Beijing-time ISO-8601 timestamp without milliseconds.
+/// Format: `2006-01-02T15:04:05+08:00`.
+pub fn build_timestamp() -> String {
+    let bj_offset = FixedOffset::east_opt(8 * 3600).expect("valid offset");
+    let now = chrono::Utc::now().with_timezone(&bj_offset);
+    now.format("%Y-%m-%dT%H:%M:%S+08:00").to_string()
+}
+
+/// Generate a 32-char hex nonce (16 random bytes, hex-encoded).
+pub fn generate_nonce() -> String {
+    let mut bytes = [0u8; 16];
+    for b in &mut bytes {
+        *b = rand::random::<u8>();
+    }
+    hex::encode(bytes)
+}
+
+/// Process-wide token manager. One instance is built per `YuanbaoChannel`
+/// and shared with the connection layer; the per-app_key Mutex makes it
+/// safe to have multiple connections sharing this manager.
+pub struct SignManager {
+    http: reqwest::Client,
+    /// Per-app_key refresh mutexes — coalesce concurrent refresh attempts.
+    locks: Mutex<HashMap<String, Arc<Mutex<()>>>>,
+    /// Token cache keyed by app_key.
+    cache: Mutex<HashMap<String, TokenEntry>>,
+}
+
+impl SignManager {
+    /// Wrap `http` in a manager with empty lock and token tables.
+    pub fn new(http: reqwest::Client) -> Arc<Self> {
+        Arc::new(Self {
+            http,
+            locks: Mutex::new(HashMap::new()),
+            cache: Mutex::new(HashMap::new()),
+        })
+    }
+
+    /// Look up a cached token without touching the network.
+    ///
+    /// Entries that fail `TokenEntry::is_valid` are reported as a miss
+    /// (they are not removed from the table here).
+    pub async fn cached(&self, app_key: &str) -> Option<TokenEntry> {
+        let cache = self.cache.lock().await;
+        cache.get(app_key).cloned().filter(|e| e.is_valid())
+    }
+
+    /// Test-only: inject a cache entry without touching the sign endpoint.
+    #[cfg(test)]
+    pub(crate) async fn set_cached_for_test(&self, app_key: &str, entry: TokenEntry) {
+        self.cache.lock().await.insert(app_key.to_string(), entry);
+    }
+
+    /// Get a valid token, fetching one if the cache is empty or stale.
+    pub async fn get_token(
+        &self,
+        app_key: &str,
+        app_secret: &str,
+        api_domain: &str,
+        route_env: &str,
+    ) -> Result<TokenEntry, YuanbaoError> {
+        if let Some(entry) = self.cached(app_key).await {
+            info!(
+                "[yuanbao/sign] using cached token ({}s remaining)",
+                entry.seconds_remaining()
+            );
+            return Ok(entry);
+        }
+        self.refresh(app_key, app_secret, api_domain, route_env)
+            .await
+    }
+
+    /// Force-refresh: drop the cache entry and re-fetch.
+    pub async fn force_refresh(
+        &self,
+        app_key: &str,
+        app_secret: &str,
+        api_domain: &str,
+        route_env: &str,
+    ) -> Result<TokenEntry, YuanbaoError> {
+        {
+            let mut cache = self.cache.lock().await;
+            cache.remove(app_key);
+        }
+        warn!(
+            "[yuanbao/sign] force-refresh app_key=****{}",
+            suffix(app_key)
+        );
+        self.refresh(app_key, app_secret, api_domain, route_env)
+            .await
+    }
+
+    /// Fetch a token under the per-app_key refresh lock and store it in
+    /// the cache.
+    async fn refresh(
+        &self,
+        app_key: &str,
+        app_secret: &str,
+        api_domain: &str,
+        route_env: &str,
+    ) -> Result<TokenEntry, YuanbaoError> {
+        let lock = self.get_refresh_lock(app_key).await;
+        let _g = lock.lock().await;
+
+        // Double-checked locking: another task may have refreshed while we waited.
+        if let Some(entry) = self.cached(app_key).await {
+            return Ok(entry);
+        }
+
+        let entry = self
+            .fetch_with_retry(app_key, app_secret, api_domain, route_env)
+            .await?;
+        let mut cache = self.cache.lock().await;
+        cache.insert(app_key.to_string(), entry.clone());
+        Ok(entry)
+    }
+
+    /// Return the refresh mutex for `app_key`, creating it on first use.
+    async fn get_refresh_lock(&self, app_key: &str) -> Arc<Mutex<()>> {
+        let mut locks = self.locks.lock().await;
+        locks
+            .entry(app_key.to_string())
+            .or_insert_with(|| Arc::new(Mutex::new(())))
+            .clone()
+    }
+
+    /// POST to the sign-token endpoint, retrying transport errors and
+    /// `RETRYABLE_CODE` responses up to `MAX_RETRIES` extra times.
+    ///
+    /// Non-2xx HTTP statuses, unparsable bodies and any other business
+    /// code fail immediately with `YuanbaoError::AuthFailed`; an exhausted
+    /// transport retry budget surfaces the last `YuanbaoError::Connection`.
+    async fn fetch_with_retry(
+        &self,
+        app_key: &str,
+        app_secret: &str,
+        api_domain: &str,
+        route_env: &str,
+    ) -> Result<TokenEntry, YuanbaoError> {
+        let url = format!("{}{}", api_domain.trim_end_matches('/'), SIGN_PATH);
+        let mut last_err: Option<YuanbaoError> = None;
+
+        for attempt in 0..=MAX_RETRIES {
+            // Nonce, timestamp and signature are regenerated on every attempt.
+            let nonce = generate_nonce();
+            let timestamp = build_timestamp();
+            let signature = compute_signature(&nonce, &timestamp, app_key, app_secret);
+            let payload = serde_json::json!({
+                "app_key": app_key,
+                "nonce": nonce,
+                "signature": signature,
+                "timestamp": timestamp,
+            });
+
+            let mut req = self
+                .http
+                .post(&url)
+                .timeout(Duration::from_secs(HTTP_TIMEOUT_SECS))
+                .header("Content-Type", "application/json")
+                .header("X-AppVersion", "openhuman/0.1.0")
+                .header("X-OperationSystem", "linux")
+                .header(
+                    "X-Instance-Id",
+                    super::proto_constants::OPENHUMAN_INSTANCE_ID,
+                )
+                .header("X-Bot-Version", "openhuman/0.1.0");
+            if !route_env.is_empty() {
+                req = req.header("X-Route-Env", route_env);
+            }
+
+            info!(
+                "[yuanbao/sign] POST {}{}",
+                url,
+                if attempt > 0 {
+                    format!(" (retry {attempt}/{MAX_RETRIES})")
+                } else {
+                    String::new()
+                }
+            );
+
+            let resp = match req.json(&payload).send().await {
+                Ok(r) => r,
+                Err(e) => {
+                    last_err = Some(YuanbaoError::Connection(format!("sign-token: {e}")));
+                    if attempt < MAX_RETRIES {
+                        tokio::time::sleep(Duration::from_millis(RETRY_DELAY_MS)).await;
+                        continue;
+                    }
+                    break;
+                }
+            };
+
+            if !resp.status().is_success() {
+                let status = resp.status();
+                let body = resp.text().await.unwrap_or_default();
+                // Only the first 200 chars of the body go into the error.
+                return Err(YuanbaoError::AuthFailed(format!(
+                    "sign-token HTTP {status}: {}",
+                    body.chars().take(200).collect::<String>()
+                )));
+            }
+
+            let json: serde_json::Value = resp
+                .json()
+                .await
+                .map_err(|e| YuanbaoError::AuthFailed(format!("sign-token body: {e}")))?;
+
+            // A missing or non-integer `code` is treated as 0 (success).
+            let code = json.get("code").and_then(|c| c.as_i64()).unwrap_or(0);
+            if code == 0 {
+                let data = match json.get("data") {
+                    Some(d) if d.is_object() => d,
+                    _ => {
+                        return Err(YuanbaoError::AuthFailed(
+                            "sign-token response missing 'data'".into(),
+                        ))
+                    }
+                };
+                let duration = data
+                    .get("duration")
+                    .and_then(|v| v.as_u64())
+                    .unwrap_or(DEFAULT_DURATION_SECS);
+                let entry = TokenEntry {
+                    token: data
+                        .get("token")
+                        .and_then(|v| v.as_str())
+                        .unwrap_or("")
+                        .to_string(),
+                    bot_id: data
+                        .get("bot_id")
+                        .and_then(|v| v.as_str())
+                        .unwrap_or("")
+                        .to_string(),
+                    product: data
+                        .get("product")
+                        .and_then(|v| v.as_str())
+                        .unwrap_or("")
+                        .to_string(),
+                    source: data
+                        .get("source")
+                        .and_then(|v| v.as_str())
+                        .unwrap_or("")
+                        .to_string(),
+                    // `duration` is server-controlled; saturate instead of
+                    // overflowing on an absurdly large value.
+                    expire_ts: unix_now().saturating_add(duration),
+                };
+                info!(
+                    "[yuanbao/sign] success: bot_id={} duration={}s",
+                    entry.bot_id, duration
+                );
+                return Ok(entry);
+            }
+
+            if code == RETRYABLE_CODE && attempt < MAX_RETRIES {
+                warn!(
+                    "[yuanbao/sign] retryable code={code}, retrying in {RETRY_DELAY_MS}ms (attempt {})",
+                    attempt + 1
+                );
+                tokio::time::sleep(Duration::from_millis(RETRY_DELAY_MS)).await;
+                continue;
+            }
+
+            let msg = json
+                .get("msg")
+                .and_then(|v| v.as_str())
+                .unwrap_or("")
+                .to_string();
+            return Err(YuanbaoError::AuthFailed(format!(
+                "sign-token code={code} msg={msg}"
+            )));
+        }
+
+        // Only reached via `break` after the final transport failure.
+        Err(last_err.unwrap_or_else(|| {
+            YuanbaoError::AuthFailed("sign-token max retries exceeded".into())
+        }))
+    }
+
+    /// Drop all per-app_key locks. Called on channel shutdown to avoid
+    /// leaking entries across reconnects within the same process.
+    pub async fn clear_locks(&self) {
+        let mut locks = self.locks.lock().await;
+        locks.clear();
+    }
+}
+
+/// Current wall-clock time as whole seconds since the Unix epoch; falls
+/// back to 0 if the system clock reports a time before the epoch.
+fn unix_now() -> u64 {
+    SystemTime::now()
+        .duration_since(UNIX_EPOCH)
+        .map(|d| d.as_secs())
+        .unwrap_or(0)
+}
+
+/// Return at most the last 4 bytes of `s` for log redaction.
+///
+/// Strings of 4 bytes or fewer are returned whole. For longer strings the
+/// cut point is moved forward to the next UTF-8 char boundary, so a
+/// non-ASCII `app_key` yields a slightly shorter suffix instead of a
+/// slicing panic.
+fn suffix(s: &str) -> &str {
+    if s.len() <= 4 {
+        return s;
+    }
+    let mut start = s.len() - 4;
+    // `s.len()` is always a char boundary, so this loop terminates.
+    while !s.is_char_boundary(start) {
+        start += 1;
+    }
+    &s[start..]
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn signature_matches_python_reference() {
+        // Reproducible vector — hand-computed:
+        //   plain = "n123" + "2026-05-19T22:00:00+08:00" + "app_k" + "secret"
+        //   sig   = HMAC-SHA256(key="secret", msg=plain) as lower hex
+        let sig = compute_signature("n123", "2026-05-19T22:00:00+08:00", "app_k", "secret");
+        // We don't pin the exact bytes (would require running Python to confirm) —
+        // instead verify the contract: same inputs → same output, 64-char hex.
+        assert_eq!(sig.len(), 64);
+        assert!(sig.chars().all(|c| c.is_ascii_hexdigit()));
+        let sig2 = compute_signature("n123", "2026-05-19T22:00:00+08:00", "app_k", "secret");
+        assert_eq!(sig, sig2);
+    }
+
+    #[test]
+    fn signature_varies_with_inputs() {
+        let s1 = compute_signature("n1", "t", "ak", "sk");
+        let s2 = compute_signature("n2", "t", "ak", "sk");
+        let s3 = compute_signature("n1", "t2", "ak", "sk");
+        let s4 = compute_signature("n1", "t", "ak2", "sk");
+        let s5 = compute_signature("n1", "t", "ak", "sk2");
+        let all = [&s1, &s2, &s3, &s4, &s5];
+        for (i, a) in all.iter().enumerate() {
+            for (j, b) in all.iter().enumerate() {
+                if i != j {
+                    assert_ne!(a, b, "inputs {i} vs {j} should differ");
+                }
+            }
+        }
+    }
+
+    #[test]
+    fn nonce_is_32_char_hex() {
+        let n = generate_nonce();
+        assert_eq!(n.len(), 32);
+        assert!(n.chars().all(|c| c.is_ascii_hexdigit()));
+    }
+
+    #[test]
+    fn timestamp_matches_beijing_format() {
+        let t = build_timestamp();
+        // 2006-01-02T15:04:05+08:00 → length 25
+        assert_eq!(t.len(), 25);
+        assert!(t.ends_with("+08:00"));
+        assert_eq!(&t[4..5], "-");
+        assert_eq!(&t[7..8], "-");
+        assert_eq!(&t[10..11], "T");
+
assert_eq!(&t[13..14], ":"); + } + + #[test] + fn token_entry_is_valid_only_with_margin() { + let mut e = TokenEntry { + token: "t".into(), + bot_id: "b".into(), + product: String::new(), + source: String::new(), + expire_ts: unix_now() + 120, + }; + assert!(e.is_valid()); + e.expire_ts = unix_now() + 30; // less than 60s margin + assert!(!e.is_valid()); + e.expire_ts = unix_now().saturating_sub(10); + assert!(!e.is_valid()); + } + + #[tokio::test] + async fn cache_returns_entry_when_valid() { + let mgr = SignManager::new(reqwest::Client::new()); + let entry = TokenEntry { + token: "tok".into(), + bot_id: "bot".into(), + product: String::new(), + source: String::new(), + expire_ts: unix_now() + 600, + }; + mgr.cache.lock().await.insert("ak".into(), entry.clone()); + let got = mgr.cached("ak").await.expect("cache hit"); + assert_eq!(got.token, "tok"); + } + + #[tokio::test] + async fn cache_drops_expired_entry() { + let mgr = SignManager::new(reqwest::Client::new()); + mgr.cache.lock().await.insert( + "ak".into(), + TokenEntry { + token: "tok".into(), + bot_id: "bot".into(), + product: String::new(), + source: String::new(), + expire_ts: unix_now() + 10, // under margin + }, + ); + assert!(mgr.cached("ak").await.is_none()); + } + + #[test] + fn token_entry_seconds_remaining_is_signed() { + let e_future = TokenEntry { + token: "t".into(), + bot_id: "b".into(), + product: String::new(), + source: String::new(), + expire_ts: unix_now() + 300, + }; + assert!(e_future.seconds_remaining() >= 290); + let e_past = TokenEntry { + expire_ts: unix_now().saturating_sub(60), + ..e_future + }; + assert!(e_past.seconds_remaining() <= 0); + } + + #[test] + fn suffix_redacts_to_last_4_chars() { + assert_eq!(suffix(""), ""); + assert_eq!(suffix("a"), "a"); + assert_eq!(suffix("abcd"), "abcd"); + assert_eq!(suffix("abcdef"), "cdef"); + assert_eq!(suffix("0123456789"), "6789"); + } + + // ─── refresh / fetch_with_retry via wiremock ──────────────── + + fn ok_body(token: &str, bot_id: 
&str, duration_secs: u64) -> serde_json::Value { + serde_json::json!({ + "code": 0, + "msg": "ok", + "data": { + "token": token, + "bot_id": bot_id, + "product": "prod1", + "source": "src1", + "duration": duration_secs, + } + }) + } + + #[tokio::test] + async fn get_token_fetches_and_caches_on_first_call() { + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("POST")) + .and(wiremock::matchers::path(SIGN_PATH)) + .respond_with( + wiremock::ResponseTemplate::new(200).set_body_json(ok_body("tok-1", "bot-1", 7200)), + ) + .mount(&server) + .await; + let mgr = SignManager::new(reqwest::Client::new()); + let e = mgr + .get_token("ak", "sk", &server.uri(), "") + .await + .expect("token"); + assert_eq!(e.token, "tok-1"); + assert_eq!(e.bot_id, "bot-1"); + assert!(e.expire_ts > unix_now() + 7000); + + // Second call should hit the cache (still works even if server stops). + let cached = mgr.cached("ak").await.expect("cached"); + assert_eq!(cached.token, "tok-1"); + } + + #[tokio::test] + async fn get_token_retries_on_code_10099_then_succeeds() { + let server = wiremock::MockServer::start().await; + // First two requests return code=10099, third returns code=0. 
+ wiremock::Mock::given(wiremock::matchers::method("POST")) + .and(wiremock::matchers::path(SIGN_PATH)) + .respond_with( + wiremock::ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "code": 10099, + "msg": "try again", + })), + ) + .up_to_n_times(2) + .mount(&server) + .await; + wiremock::Mock::given(wiremock::matchers::method("POST")) + .and(wiremock::matchers::path(SIGN_PATH)) + .respond_with( + wiremock::ResponseTemplate::new(200).set_body_json(ok_body("tok-r", "bot-r", 600)), + ) + .mount(&server) + .await; + let mgr = SignManager::new(reqwest::Client::new()); + let e = mgr.refresh("ak", "sk", &server.uri(), "").await.unwrap(); + assert_eq!(e.token, "tok-r"); + assert_eq!(e.bot_id, "bot-r"); + } + + #[tokio::test] + async fn get_token_surfaces_http_error_as_auth_failed() { + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("POST")) + .respond_with(wiremock::ResponseTemplate::new(401).set_body_string("Unauthorized")) + .mount(&server) + .await; + let mgr = SignManager::new(reqwest::Client::new()); + let err = mgr + .get_token("ak", "sk", &server.uri(), "") + .await + .unwrap_err(); + match err { + YuanbaoError::AuthFailed(m) => assert!(m.contains("HTTP 401"), "got {m}"), + other => panic!("expected AuthFailed, got {other:?}"), + } + } + + #[tokio::test] + async fn get_token_fails_on_non_zero_business_code() { + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("POST")) + .respond_with( + wiremock::ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "code": 40001, + "msg": "bad secret", + })), + ) + .mount(&server) + .await; + let mgr = SignManager::new(reqwest::Client::new()); + let err = mgr + .get_token("ak", "sk", &server.uri(), "") + .await + .unwrap_err(); + match err { + YuanbaoError::AuthFailed(m) => { + assert!(m.contains("code=40001"), "got {m}"); + assert!(m.contains("bad secret"), "got {m}"); + } + other => panic!("expected 
AuthFailed, got {other:?}"), + } + } + + #[tokio::test] + async fn force_refresh_evicts_cache_and_refetches() { + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("POST")) + .and(wiremock::matchers::path(SIGN_PATH)) + .respond_with( + wiremock::ResponseTemplate::new(200).set_body_json(ok_body("tok-a", "bot", 600)), + ) + .up_to_n_times(1) + .mount(&server) + .await; + wiremock::Mock::given(wiremock::matchers::method("POST")) + .and(wiremock::matchers::path(SIGN_PATH)) + .respond_with( + wiremock::ResponseTemplate::new(200).set_body_json(ok_body("tok-b", "bot", 600)), + ) + .mount(&server) + .await; + let mgr = SignManager::new(reqwest::Client::new()); + let first = mgr.get_token("ak", "sk", &server.uri(), "").await.unwrap(); + assert_eq!(first.token, "tok-a"); + let second = mgr + .force_refresh("ak", "sk", &server.uri(), "to_env") + .await + .unwrap(); + assert_eq!(second.token, "tok-b"); + } + + #[tokio::test] + async fn clear_locks_drops_all_per_app_key_mutexes() { + let mgr = SignManager::new(reqwest::Client::new()); + // Prime the locks map. + let _ = mgr.get_refresh_lock("ak-1").await; + let _ = mgr.get_refresh_lock("ak-2").await; + assert_eq!(mgr.locks.lock().await.len(), 2); + mgr.clear_locks().await; + assert!(mgr.locks.lock().await.is_empty()); + } +} diff --git a/src/openhuman/channels/providers/yuanbao/splitter.rs b/src/openhuman/channels/providers/yuanbao/splitter.rs new file mode 100644 index 0000000000..88784a7814 --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/splitter.rs @@ -0,0 +1,209 @@ +//! Fence-aware Markdown splitter. +//! +//! When a long AI response is split into N chunks for the Yuanbao +//! `max_message_length` cap, we must not break inside: +//! - a fenced code block (``` … ``` or ~~~ … ~~~) +//! - a Markdown table row (lines starting with `|`) +//! - a list-continuation block +//! +//! Strategy: walk the input by line, tracking fence/table state, and +//! 
emit a chunk every time adding the next line would push the buffer
+//! past the cap **and** the cap boundary is safe (currently: not inside
+//! a fenced code block — table and list state is not tracked yet). If a
+//! single line is itself longer than the cap, hard-split at a char boundary.
+
+/// Split `text` into chunks no larger than `cap_bytes` (utf-8 byte count),
+/// preserving fenced code blocks where possible.
+///
+/// A fenced block is never broken between lines, so a chunk that contains
+/// a long fence can exceed the cap. Chunks are right-trimmed, and chunks
+/// that are empty after trimming are dropped. Input that already fits is
+/// returned unchanged as a single chunk.
+pub fn split_markdown(text: &str, cap_bytes: usize) -> Vec<String> {
+    if text.len() <= cap_bytes {
+        return vec![text.to_string()];
+    }
+    let cap = cap_bytes.max(1);
+    // Reserve a small headroom so the trailing newline / final char fits
+    // when we flush. For very small caps fall back to no margin so callers
+    // testing tight bounds (cap=20) still get chunks under the cap.
+    let safe_cap = if cap >= 32 {
+        cap.saturating_sub(16)
+    } else {
+        cap
+    };
+
+    let mut chunks: Vec<String> = Vec::new();
+    let mut buf = String::with_capacity(cap);
+    let mut in_fence = false;
+    let mut fence_marker: Option<String> = None;
+
+    for line in text.split_inclusive('\n') {
+        let trimmed = line.trim_end_matches('\n');
+        let starts_fence = is_fence(trimmed);
+
+        // If this single line is wider than the cap, we must hard-split it.
+        if line.len() > cap {
+            flush(&mut chunks, &mut buf);
+            chunks.extend(hard_split(line, cap));
+            continue;
+        }
+
+        let candidate_len = buf.len() + line.len();
+        if candidate_len > safe_cap && !buf.is_empty() && safe_to_break(in_fence) {
+            flush(&mut chunks, &mut buf);
+        }
+        buf.push_str(line);
+
+        if let Some(marker) = starts_fence {
+            if let Some(open) = &fence_marker {
+                // Only the same marker kind that opened the fence closes it.
+                if marker == *open {
+                    in_fence = false;
+                    fence_marker = None;
+                }
+            } else {
+                in_fence = true;
+                fence_marker = Some(marker);
+            }
+        }
+    }
+    flush(&mut chunks, &mut buf);
+
+    // Drop empty trailing chunks (can happen if input ends on newline).
+    chunks.retain(|c| !c.trim().is_empty());
+    chunks
+}
+
+/// Move the right-trimmed contents of `buf` into `chunks` and clear `buf`;
+/// does nothing when `buf` is empty.
+fn flush(chunks: &mut Vec<String>, buf: &mut String) {
+    if !buf.is_empty() {
+        chunks.push(buf.trim_end().to_string());
+        buf.clear();
+    }
+}
+
+/// A chunk boundary is allowed only outside a fenced code block.
+fn safe_to_break(in_fence: bool) -> bool {
+    !in_fence
+}
+
+/// If `line` opens or closes a fenced code block, return the marker text
+/// (e.g. "```" or "~~~"). A line that contains a fence in the middle is
+/// NOT a fence marker; only lines that *start* with three or more
+/// backticks/tildes count.
+fn is_fence(line: &str) -> Option<String> {
+    let trimmed = line.trim_start();
+    // Anything after the marker (e.g. a language tag) is allowed and ignored.
+    if trimmed.starts_with("```") {
+        return Some("```".into());
+    }
+    if trimmed.starts_with("~~~") {
+        return Some("~~~".into());
+    }
+    None
+}
+
+/// Last-resort splitter for a line that's wider than the cap.
+///
+/// Cuts at the largest char boundary at or below `cap` bytes. When `cap`
+/// is smaller than the next character, that single character is emitted
+/// on its own (and therefore exceeds the cap).
+fn hard_split(line: &str, cap: usize) -> Vec<String> {
+    let mut out = Vec::new();
+    let mut remaining = line;
+    while !remaining.is_empty() {
+        if remaining.len() <= cap {
+            out.push(remaining.to_string());
+            break;
+        }
+        let mut idx = cap;
+        while idx > 0 && !remaining.is_char_boundary(idx) {
+            idx -= 1;
+        }
+        if idx == 0 {
+            // pathological — emit one char at a time
+            let take = remaining
+                .char_indices()
+                .nth(1)
+                .map(|(i, _)| i)
+                .unwrap_or(remaining.len());
+            let (chunk, rest) = remaining.split_at(take);
+            out.push(chunk.to_string());
+            remaining = rest;
+        } else {
+            let (chunk, rest) = remaining.split_at(idx);
+            out.push(chunk.to_string());
+            remaining = rest;
+        }
+    }
+    out
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn short_input_returns_one_chunk() {
+        let r = split_markdown("hello", 100);
+        assert_eq!(r, vec!["hello"]);
+    }
+
+    #[test]
+    fn splits_on_newlines_respecting_cap() {
+        let input = "a\n".repeat(100);
+        let r = split_markdown(&input, 20);
+        assert!(r.len() > 1);
+        for c in &r {
+            assert!(c.len() <= 20, "chunk too long: {c:?}");
+        }
+    }
+
+
#[test] + fn preserves_fenced_code_block() { + let input = "intro line\n\ + ```rust\n\ + fn long_function_a() -> u32 { 42 }\n\ + fn long_function_b() -> u32 { 43 }\n\ + fn long_function_c() -> u32 { 44 }\n\ + ```\n\ + trailing text"; + let chunks = split_markdown(input, 80); + // Find the chunk(s) containing the fence — they must not split mid-fence. + let mut open = 0; + for c in &chunks { + for line in c.lines() { + if is_fence(line).is_some() { + open += 1; + } + } + } + // The fence must appear as balanced pairs. + assert_eq!(open % 2, 0, "unbalanced fences after split: {chunks:#?}"); + } + + #[test] + fn hard_split_very_long_line() { + let line = "x".repeat(500); + let r = split_markdown(&line, 100); + for c in &r { + assert!(c.len() <= 100, "chunk too long: {}", c.len()); + } + assert_eq!(r.join("").len(), 500); + } + + #[test] + fn unicode_safe_hard_split() { + let line = "中".repeat(200); // each char is 3 bytes → 600 total + let r = split_markdown(&line, 50); + for c in &r { + assert!(c.len() <= 50, "chunk too long: {}", c.len()); + // verify it's valid utf-8 by reading it + for ch in c.chars() { + assert!(ch == '中'); + } + } + } + + #[test] + fn is_fence_detects_backticks() { + assert_eq!(is_fence("```").as_deref(), Some("```")); + assert_eq!(is_fence("```rust").as_deref(), Some("```")); + assert_eq!(is_fence("~~~").as_deref(), Some("~~~")); + assert_eq!(is_fence("text").as_deref(), None); + assert_eq!(is_fence("``").as_deref(), None); + } +} diff --git a/src/openhuman/channels/providers/yuanbao/types.rs b/src/openhuman/channels/providers/yuanbao/types.rs new file mode 100644 index 0000000000..b2215dc2e7 --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/types.rs @@ -0,0 +1,441 @@ +//! Shared domain types for the Yuanbao channel. +//! +//! Field naming follows the upstream Yuanbao protocol (`from_account`, +//! `group_code`, `msg_id`, etc.) so that the protobuf decoders, the +//! inbound pipeline, and the outbound encoders can all share the +//! 
same `InboundMessage` / `MsgBodyElement` shapes without re-mapping. +//! +//! Source of truth: `gateway/platforms/yuanbao_proto.py` in +//! hermes-agent (lines 415-705). + +use serde::{Deserialize, Serialize}; + +/// Decoded ConnMsg envelope (head + payload). +#[derive(Debug, Clone)] +pub struct ConnFrame { + /// CmdType (`CMD_TYPE`): Request=0, Response=1, Push=2, PushAck=3. + pub cmd_type: u32, + /// Command word, e.g. `"auth-bind"`, `"ping"`, `"send_c2c_message"`. + pub cmd: String, + /// Module / service name, e.g. `"conn_access"` or `"yuanbao_openclaw_proxy"`. + pub module: String, + /// Per-message sequence number. + pub seq_no: u32, + /// Application-level message id. + pub msg_id: String, + /// Whether the server expects an ACK. + pub need_ack: bool, + /// Status code (head.status, field 10). + pub status: u32, + /// Biz payload bytes (ConnMsg.data, field 2). + pub data: Vec, +} + +/// One element of the TIM-style `msg_body` array (e.g. text, image, file). +#[derive(Debug, Clone, Default, PartialEq)] +pub struct MsgBodyElement { + /// `"TIMTextElem"`, `"TIMImageElem"`, `"TIMFileElem"`, `"TIMSoundElem"`, … + pub msg_type: String, + pub msg_content: MsgContent, +} + +/// Generic union of all TIM `msg_content` shapes (text/image/file/sound). +/// +/// Only the fields relevant to the active `msg_type` are populated; the +/// rest stay at their `Default`. +#[derive(Debug, Clone, Default, PartialEq)] +pub struct MsgContent { + /// Field 1 — text payload. + pub text: Option, + /// Field 2 — file uuid (MD5 for images/files). + pub uuid: Option, + /// Field 3 — image format code (1=JPEG, 2=GIF, 3=PNG, 4=BMP, 255=WEBP). + pub image_format: Option, + /// Field 4 — raw inline data (rarely used; usually `url` is set instead). + pub data: Option, + /// Field 5 — element description. + pub desc: Option, + /// Field 6 — extension JSON / blob. + pub ext: Option, + /// Field 7 — voice payload identifier. 
+ pub sound: Option, + /// Field 8 — repeated `ImageInfo` for the image element. + pub image_info_array: Vec, + /// Field 9 — element index within a multi-image message. + pub index: Option, + /// Field 10 — resource URL. + pub url: Option, + /// Field 11 — file size in bytes. + pub file_size: Option, + /// Field 12 — file name. + pub file_name: Option, +} + +/// Per-resolution image variant. `type` is 1=original, 2=large, 3=thumb. +#[derive(Debug, Clone, Default, PartialEq)] +pub struct ImageInfo { + pub image_type: u32, + pub size: u32, + pub width: u32, + pub height: u32, + pub url: String, +} + +/// A single recall entry in `recall_msg_seq_list` (InboundMessagePush field 17). +#[derive(Debug, Clone, Default, PartialEq)] +pub struct ImMsgSeq { + pub msg_seq: u32, + pub msg_id: String, +} + +/// A decoded `InboundMessagePush` biz payload — what the yuanbao gateway +/// pushes down to us for every incoming message. +#[derive(Debug, Clone, Default)] +pub struct InboundMessage { + pub callback_command: String, + pub from_account: String, + pub to_account: String, + pub sender_nickname: String, + /// Empty string for DMs, group ID for group messages. + pub group_id: String, + /// Empty string for DMs, group code (canonical group ref) for group messages. + pub group_code: String, + pub group_name: String, + pub msg_seq: u32, + pub msg_random: u32, + /// Server-side message timestamp (seconds since epoch). + pub msg_time: u32, + pub msg_key: String, + /// Stable application-level message ID. + pub msg_id: String, + pub msg_body: Vec, + pub cloud_custom_data: String, + pub event_time: u32, + pub bot_owner_id: String, + pub recall_msg_seq_list: Vec, + pub claw_msg_type: u32, + pub private_from_group_code: String, + pub trace_id: String, +} + +impl InboundMessage { + /// Whether this is a group message. + pub fn is_group(&self) -> bool { + !self.group_code.is_empty() + } + + /// Whether the message looks like a recall notification. 
+ pub fn is_recall(&self) -> bool { + !self.recall_msg_seq_list.is_empty() + } + + /// Routing key — group_code for groups, sender uid for DMs. + pub fn chat_id(&self) -> &str { + if self.is_group() { + &self.group_code + } else { + &self.from_account + } + } + + /// Concatenated text content (joins all `TIMTextElem`s). + pub fn extract_text(&self) -> String { + let mut out = String::new(); + for el in &self.msg_body { + if el.msg_type == "TIMTextElem" { + if let Some(ref t) = el.msg_content.text { + if !out.is_empty() { + out.push('\n'); + } + out.push_str(t); + } + } + } + out + } + + /// All image URLs in the message (from `TIMImageElem` elements). + pub fn extract_image_urls(&self) -> Vec { + let mut urls = Vec::new(); + for el in &self.msg_body { + if el.msg_type == "TIMImageElem" { + for info in &el.msg_content.image_info_array { + if !info.url.is_empty() { + urls.push(info.url.clone()); + } + } + if let Some(ref url) = el.msg_content.url { + if !url.is_empty() && !urls.contains(url) { + urls.push(url.clone()); + } + } + } + } + urls + } +} + +/// High-level classification produced by the inbound pipeline. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum MessageKind { + #[default] + Text, + Image, + File, + Voice, + Mixed, + /// Recall notification — handled by `RecallGuard`, never dispatched. + Recall, +} + +/// Where the message came from — used by the outbound side to address replies. +#[derive(Debug, Clone, Default)] +pub struct Source { + pub from_account: String, + pub sender_nickname: String, + pub group_code: String, + /// `true` for group chats, `false` for DMs. + pub is_group: bool, +} + +impl Source { + /// Stable string for `ChannelMessage.sender` / `reply_target` — + /// `g:` for groups, raw uid for DMs. This format also + /// round-trips through `parse_recipient` in `outbound.rs`. 
+ pub fn reply_target(&self) -> String { + if self.is_group { + format!("g:{}", self.group_code) + } else { + self.from_account.clone() + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn text_elem(s: &str) -> MsgBodyElement { + MsgBodyElement { + msg_type: "TIMTextElem".into(), + msg_content: MsgContent { + text: Some(s.into()), + ..Default::default() + }, + } + } + + fn image_elem(info_urls: &[&str], inline_url: Option<&str>) -> MsgBodyElement { + MsgBodyElement { + msg_type: "TIMImageElem".into(), + msg_content: MsgContent { + image_info_array: info_urls + .iter() + .map(|u| ImageInfo { + image_type: 1, + url: (*u).into(), + ..Default::default() + }) + .collect(), + url: inline_url.map(String::from), + ..Default::default() + }, + } + } + + #[test] + fn dm_is_not_group() { + let m = InboundMessage { + from_account: "alice".into(), + ..Default::default() + }; + assert!(!m.is_group()); + assert_eq!(m.chat_id(), "alice"); + } + + #[test] + fn group_is_group_and_chat_id_is_group_code() { + let m = InboundMessage { + group_code: "grp_42".into(), + from_account: "alice".into(), + ..Default::default() + }; + assert!(m.is_group()); + assert_eq!(m.chat_id(), "grp_42"); + } + + #[test] + fn is_recall_iff_recall_list_non_empty() { + let mut m = InboundMessage::default(); + assert!(!m.is_recall()); + m.recall_msg_seq_list.push(ImMsgSeq { + msg_seq: 7, + msg_id: "x".into(), + }); + assert!(m.is_recall()); + } + + #[test] + fn extract_text_concatenates_text_elements() { + let m = InboundMessage { + msg_body: vec![ + text_elem("hello"), + text_elem("world"), + image_elem(&[], None), + ], + ..Default::default() + }; + assert_eq!(m.extract_text(), "hello\nworld"); + } + + #[test] + fn extract_text_ignores_text_none_and_non_text() { + let m = InboundMessage { + msg_body: vec![ + MsgBodyElement { + msg_type: "TIMTextElem".into(), + msg_content: MsgContent::default(), // text: None + }, + image_elem(&["https://x/y.png"], None), + ], + ..Default::default() + }; + 
assert_eq!(m.extract_text(), ""); + } + + #[test] + fn extract_text_on_empty_msg_body_returns_empty() { + let m = InboundMessage::default(); + assert_eq!(m.extract_text(), ""); + } + + #[test] + fn extract_image_urls_from_image_info_array() { + let m = InboundMessage { + msg_body: vec![image_elem(&["https://a/1.png", "https://a/2.png"], None)], + ..Default::default() + }; + assert_eq!( + m.extract_image_urls(), + vec!["https://a/1.png".to_string(), "https://a/2.png".into()] + ); + } + + #[test] + fn extract_image_urls_falls_back_to_inline_url_field() { + let m = InboundMessage { + msg_body: vec![image_elem(&[], Some("https://a/inline.png"))], + ..Default::default() + }; + assert_eq!( + m.extract_image_urls(), + vec!["https://a/inline.png".to_string()] + ); + } + + #[test] + fn extract_image_urls_dedups_inline_when_already_in_info_array() { + let dup = "https://a/dup.png"; + let m = InboundMessage { + msg_body: vec![image_elem(&[dup], Some(dup))], + ..Default::default() + }; + assert_eq!(m.extract_image_urls(), vec![dup.to_string()]); + } + + #[test] + fn extract_image_urls_skips_empty_url_in_info_array() { + let m = InboundMessage { + msg_body: vec![image_elem(&[""], None)], + ..Default::default() + }; + assert!(m.extract_image_urls().is_empty()); + } + + #[test] + fn extract_image_urls_ignores_text_elements() { + let m = InboundMessage { + msg_body: vec![text_elem("hi"), image_elem(&["https://a/1.png"], None)], + ..Default::default() + }; + assert_eq!(m.extract_image_urls(), vec!["https://a/1.png".to_string()]); + } + + #[test] + fn source_reply_target_dm_is_raw_uid() { + let s = Source { + from_account: "uid_alice".into(), + is_group: false, + ..Default::default() + }; + assert_eq!(s.reply_target(), "uid_alice"); + } + + #[test] + fn source_reply_target_group_uses_g_prefix() { + let s = Source { + group_code: "grp_42".into(), + is_group: true, + ..Default::default() + }; + assert_eq!(s.reply_target(), "g:grp_42"); + } + + #[test] + fn 
message_kind_default_is_text() { + assert_eq!(MessageKind::default(), MessageKind::Text); + } +} + +/// Group metadata returned by `QueryGroupInfo`. +#[derive(Debug, Clone, Default, PartialEq)] +pub struct GroupInfo { + pub code: i32, + pub message: String, + pub group_name: String, + pub owner_id: String, + pub owner_nickname: String, + pub member_count: u32, +} + +/// One member returned by `GetGroupMemberList`. +#[derive(Debug, Clone, Default, PartialEq)] +pub struct GroupMember { + pub user_id: String, + pub nickname: String, + /// 0=member, 1=admin, 2=owner. + pub role: u32, + pub join_time: u32, + pub name_card: String, +} + +/// Paginated result of `GetGroupMemberList`. +#[derive(Debug, Clone, Default)] +pub struct GroupMemberListPage { + pub code: i32, + pub message: String, + pub members: Vec, + pub next_offset: u32, + pub is_complete: bool, +} + +/// Cached account info — populated after `auth-bind` succeeds. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct Account { + /// Bot user-id (used as `from_account` in outbound messages). + pub uid: String, + /// Display name (best-effort; may be empty until first inbound message). + pub nickname: String, + /// Server-assigned connection id (`AuthBindRsp.connect_id`, field 3). + pub connect_id: String, +} + +/// Connection state machine (matches task list spec). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ConnectionState { + Disconnected, + Connecting, + Authenticating, + Connected, + Reconnecting, +} diff --git a/src/openhuman/channels/providers/yuanbao/wire.rs b/src/openhuman/channels/providers/yuanbao/wire.rs new file mode 100644 index 0000000000..6c5582c8c6 --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/wire.rs @@ -0,0 +1,387 @@ +//! Hand-rolled protobuf wire-format primitives. +//! +//! Only varints, length-delimited bytes, and the two fixed-width forms +//! are supported — that's everything the yuanbao protocol uses. Kept +//! 
separate from `proto.rs` so the latter stays under 500 lines and +//! reads as a "schema" file. + +use std::sync::atomic::{AtomicU32, Ordering}; + +use super::errors::YuanbaoError; + +/// Global per-process sequence number for ConnMsg head.seq_no. +static SEQ: AtomicU32 = AtomicU32::new(1); + +pub fn next_seq_no() -> u32 { + SEQ.fetch_add(1, Ordering::Relaxed) +} + +pub const WT_VARINT: u8 = 0; +pub const WT_LEN: u8 = 2; + +// ─── Varint ───────────────────────────────────────────────────────── + +pub fn encode_varint(mut value: u64, buf: &mut Vec) { + loop { + let mut byte = (value & 0x7F) as u8; + value >>= 7; + if value != 0 { + byte |= 0x80; + } + buf.push(byte); + if value == 0 { + break; + } + } +} + +pub fn decode_varint(data: &[u8], pos: usize) -> Result<(u64, usize), YuanbaoError> { + let mut value: u64 = 0; + let mut shift: u32 = 0; + let mut i = pos; + loop { + if i >= data.len() { + return Err(YuanbaoError::ProtoDecode("truncated varint".into())); + } + let byte = data[i]; + value |= ((byte & 0x7F) as u64) << shift; + i += 1; + if byte & 0x80 == 0 { + return Ok((value, i - pos)); + } + shift += 7; + if shift >= 64 { + return Err(YuanbaoError::ProtoDecode("varint too long".into())); + } + } +} + +// ─── Field encoders ──────────────────────────────────────────────── + +pub fn encode_field_varint(field: u32, value: u64, buf: &mut Vec) { + encode_varint(((field as u64) << 3) | WT_VARINT as u64, buf); + encode_varint(value, buf); +} + +pub fn encode_field_bytes(field: u32, data: &[u8], buf: &mut Vec) { + encode_varint(((field as u64) << 3) | WT_LEN as u64, buf); + encode_varint(data.len() as u64, buf); + buf.extend_from_slice(data); +} + +pub fn encode_field_string(field: u32, s: &str, buf: &mut Vec) { + encode_field_bytes(field, s.as_bytes(), buf); +} + +// ─── Field parsing ────────────────────────────────────────────────── + +#[derive(Debug)] +pub enum FieldValue { + Varint(u64), + Bytes(Vec), + Fixed32(u32), + Fixed64(u64), +} + +pub fn 
parse_fields(data: &[u8]) -> Result, YuanbaoError> { + let mut out = Vec::new(); + let mut pos = 0; + while pos < data.len() { + let (tag, n) = decode_varint(data, pos)?; + pos += n; + let field = (tag >> 3) as u32; + let wire = (tag & 0x07) as u8; + match wire { + WT_VARINT => { + let (v, n) = decode_varint(data, pos)?; + pos += n; + out.push((field, FieldValue::Varint(v))); + } + WT_LEN => { + let (len, n) = decode_varint(data, pos)?; + pos += n; + // Use checked conversions / arithmetic — a crafted oversize + // varint length would otherwise overflow `usize` on 32-bit + // targets and panic during slicing. + let len_usize = usize::try_from(len).map_err(|_| { + YuanbaoError::ProtoDecode(format!( + "len field {field} too large for platform: {len}" + )) + })?; + let end = pos.checked_add(len_usize).ok_or_else(|| { + YuanbaoError::ProtoDecode(format!( + "len field {field} overflows position: pos={pos} len={len}" + )) + })?; + if end > data.len() { + return Err(YuanbaoError::ProtoDecode(format!( + "truncated len field {field}: need {len} have {}", + data.len() - pos + ))); + } + out.push((field, FieldValue::Bytes(data[pos..end].to_vec()))); + pos = end; + } + 1 => { + if pos + 8 > data.len() { + return Err(YuanbaoError::ProtoDecode("truncated fixed64".into())); + } + let v = u64::from_le_bytes(data[pos..pos + 8].try_into().unwrap()); + pos += 8; + out.push((field, FieldValue::Fixed64(v))); + } + 5 => { + if pos + 4 > data.len() { + return Err(YuanbaoError::ProtoDecode("truncated fixed32".into())); + } + let v = u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap()); + pos += 4; + out.push((field, FieldValue::Fixed32(v))); + } + other => { + return Err(YuanbaoError::ProtoDecode(format!( + "unsupported wire type {other} at field {field}" + ))); + } + } + } + Ok(out) +} + +pub fn get_string(fields: &[(u32, FieldValue)], num: u32) -> String { + for (n, v) in fields { + if *n == num { + if let FieldValue::Bytes(b) = v { + return 
String::from_utf8_lossy(b).into_owned(); + } + } + } + String::new() +} + +pub fn get_varint(fields: &[(u32, FieldValue)], num: u32) -> u64 { + for (n, v) in fields { + if *n == num { + if let FieldValue::Varint(x) = v { + return *x; + } + } + } + 0 +} + +pub fn get_bytes(fields: &[(u32, FieldValue)], num: u32) -> Vec { + for (n, v) in fields { + if *n == num { + if let FieldValue::Bytes(b) = v { + return b.clone(); + } + } + } + Vec::new() +} + +pub fn get_repeated_bytes(fields: &[(u32, FieldValue)], num: u32) -> Vec> { + fields + .iter() + .filter_map(|(n, v)| match v { + FieldValue::Bytes(b) if *n == num => Some(b.clone()), + _ => None, + }) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn varint_roundtrip() { + for &v in &[0u64, 1, 127, 128, 300, 16384, u32::MAX as u64, u64::MAX] { + let mut buf = Vec::new(); + encode_varint(v, &mut buf); + let (got, n) = decode_varint(&buf, 0).unwrap(); + assert_eq!(got, v, "varint roundtrip failed for {v}"); + assert_eq!(n, buf.len()); + } + } + + #[test] + fn varint_truncated_errors() { + let buf = vec![0x80, 0x80]; // continuation bit set but no end + assert!(decode_varint(&buf, 0).is_err()); + } + + #[test] + fn field_roundtrip() { + let mut buf = Vec::new(); + encode_field_varint(1, 42, &mut buf); + encode_field_string(2, "hello", &mut buf); + encode_field_bytes(3, b"\x00\x01\x02", &mut buf); + + let fields = parse_fields(&buf).unwrap(); + assert_eq!(get_varint(&fields, 1), 42); + assert_eq!(get_string(&fields, 2), "hello"); + assert_eq!(get_bytes(&fields, 3), vec![0, 1, 2]); + } + + #[test] + fn unknown_field_skipped_gracefully() { + let mut buf = Vec::new(); + encode_field_varint(99, 123, &mut buf); + encode_field_string(1, "wanted", &mut buf); + let fields = parse_fields(&buf).unwrap(); + assert_eq!(get_string(&fields, 1), "wanted"); + assert_eq!(get_string(&fields, 2), ""); // missing field returns default + } + + #[test] + fn seq_numbers_are_monotonic() { + let a = next_seq_no(); + let b 
= next_seq_no(); + assert!(b > a); + } + + #[test] + fn varint_too_long_errors() { + // 11 continuation bytes overflows the 64-bit shift guard. + let buf = vec![0x80; 11]; + match decode_varint(&buf, 0).unwrap_err() { + YuanbaoError::ProtoDecode(m) => assert!(m.contains("too long"), "got {m}"), + other => panic!("unexpected {other:?}"), + } + } + + #[test] + fn parse_fields_truncated_bytes_field_errors() { + // Field 1 (wire type 2) declaring length 5 but only 1 byte of payload. + let buf = vec![ + (1 << 3) | 2, // tag: field=1, wire=2 + 5, // claimed len + b'a', + ]; + match parse_fields(&buf).unwrap_err() { + YuanbaoError::ProtoDecode(m) => assert!(m.contains("truncated"), "got {m}"), + other => panic!("unexpected {other:?}"), + } + } + + #[test] + fn parse_fields_oversize_len_field_errors_without_panic() { + // Field 1 (wire type 2) with a varint length encoding `u64::MAX` — + // previously this would attempt `pos + len as usize`, overflowing + // on 32-bit and slicing past the buffer on 64-bit. Now it must + // return a structured decode error. 
+ let mut buf = Vec::new(); + buf.push((1 << 3) | 2); // tag: field=1, wire=2 + encode_varint(u64::MAX, &mut buf); // adversarial length + match parse_fields(&buf).unwrap_err() { + YuanbaoError::ProtoDecode(m) => { + assert!( + m.contains("too large") || m.contains("overflows") || m.contains("truncated"), + "expected overflow/truncation error, got {m}" + ); + } + other => panic!("unexpected {other:?}"), + } + } + + #[test] + fn parse_fields_reads_fixed64() { + let mut buf = Vec::new(); + buf.push((1 << 3) | 1); // tag: field=1, wire=1 (fixed64) + buf.extend_from_slice(&0x1122_3344_5566_7788u64.to_le_bytes()); + let f = parse_fields(&buf).unwrap(); + match f[0].1 { + FieldValue::Fixed64(v) => assert_eq!(v, 0x1122_3344_5566_7788), + ref other => panic!("expected Fixed64 got {other:?}"), + } + } + + #[test] + fn parse_fields_truncated_fixed64_errors() { + let mut buf = Vec::new(); + buf.push((1 << 3) | 1); + buf.extend_from_slice(&[0u8; 4]); // only 4/8 bytes + match parse_fields(&buf).unwrap_err() { + YuanbaoError::ProtoDecode(m) => assert!(m.contains("fixed64"), "got {m}"), + other => panic!("unexpected {other:?}"), + } + } + + #[test] + fn parse_fields_reads_fixed32() { + let mut buf = Vec::new(); + buf.push((1 << 3) | 5); // tag: field=1, wire=5 (fixed32) + buf.extend_from_slice(&0xCAFEBABEu32.to_le_bytes()); + let f = parse_fields(&buf).unwrap(); + match f[0].1 { + FieldValue::Fixed32(v) => assert_eq!(v, 0xCAFEBABE), + ref other => panic!("expected Fixed32 got {other:?}"), + } + } + + #[test] + fn parse_fields_truncated_fixed32_errors() { + let mut buf = Vec::new(); + buf.push((1 << 3) | 5); + buf.extend_from_slice(&[0u8; 2]); // only 2/4 bytes + match parse_fields(&buf).unwrap_err() { + YuanbaoError::ProtoDecode(m) => assert!(m.contains("fixed32"), "got {m}"), + other => panic!("unexpected {other:?}"), + } + } + + #[test] + fn parse_fields_unsupported_wire_type_errors() { + // wire type 3 (start group) is not supported. 
+ let buf = vec![(1 << 3) | 3]; + match parse_fields(&buf).unwrap_err() { + YuanbaoError::ProtoDecode(m) => { + assert!(m.contains("unsupported wire type 3"), "got {m}") + } + other => panic!("unexpected {other:?}"), + } + } + + #[test] + fn get_string_returns_empty_when_field_is_varint() { + // Field 1 exists but encoded as varint, not bytes — get_string must + // skip past it and return the default. + let mut buf = Vec::new(); + encode_field_varint(1, 7, &mut buf); + let fields = parse_fields(&buf).unwrap(); + assert_eq!(get_string(&fields, 1), ""); + } + + #[test] + fn get_varint_returns_zero_when_field_is_bytes() { + let mut buf = Vec::new(); + encode_field_string(1, "not a varint", &mut buf); + let fields = parse_fields(&buf).unwrap(); + assert_eq!(get_varint(&fields, 1), 0); + } + + #[test] + fn get_bytes_returns_empty_when_field_is_varint() { + let mut buf = Vec::new(); + encode_field_varint(1, 7, &mut buf); + let fields = parse_fields(&buf).unwrap(); + assert!(get_bytes(&fields, 1).is_empty()); + } + + #[test] + fn get_repeated_bytes_collects_multiple_same_field() { + let mut buf = Vec::new(); + encode_field_string(1, "a", &mut buf); + encode_field_string(1, "bb", &mut buf); + encode_field_string(2, "c", &mut buf); // different field — should be skipped + encode_field_string(1, "ddd", &mut buf); + let fields = parse_fields(&buf).unwrap(); + let got = get_repeated_bytes(&fields, 1); + assert_eq!(got.len(), 3); + assert_eq!(got[0], b"a"); + assert_eq!(got[1], b"bb"); + assert_eq!(got[2], b"ddd"); + } +} diff --git a/src/openhuman/channels/runtime/startup.rs b/src/openhuman/channels/runtime/startup.rs index 9138f9a5b9..2ef3064212 100644 --- a/src/openhuman/channels/runtime/startup.rs +++ b/src/openhuman/channels/runtime/startup.rs @@ -28,6 +28,7 @@ use crate::openhuman::channels::traits; use crate::openhuman::channels::whatsapp::WhatsAppChannel; #[cfg(feature = "whatsapp-web")] use crate::openhuman::channels::whatsapp_web::WhatsAppWebChannel; +use 
crate::openhuman::channels::yuanbao::YuanbaoChannel; use crate::openhuman::channels::Channel; use crate::openhuman::config::Config; use crate::openhuman::context::channels_prompt::build_system_prompt; @@ -498,6 +499,14 @@ pub async fn start_channels(config: Config) -> Result<()> { ))); } + if let Some(ref yb) = config.channels_config.yuanbao { + let yb_cfg = resolve_yuanbao_app_secret(yb.clone(), &config); + match YuanbaoChannel::new(yb_cfg) { + Ok(ch) => channels.push(Arc::new(ch)), + Err(e) => tracing::warn!("[channels] yuanbao config invalid: {e}"), + } + } + if channels.is_empty() { println!("No channels configured. Set up channels in the web UI."); return Ok(()); @@ -633,3 +642,145 @@ pub async fn start_channels(config: Config) -> Result<()> { Ok(()) } + +/// Best-effort fill of `yb_cfg.app_secret` from the encrypted credentials +/// store when TOML doesn't already carry one. +/// +/// `app_secret` is intentionally not persisted in `config.toml` (see the +/// `yuanbao` branch in `controllers/ops.rs`). Existing TOML values still +/// win so manually-installed deployments don't break. Returns the +/// (possibly-modified) config; logging is the only side effect on failure. +/// +/// The stored secret is **only** copied when the stored profile's +/// `app_key` matches `yb_cfg.app_key`. Without that guard, editing +/// `app_key` in `config.toml` would silently pair a fresh key with a +/// stale secret on next startup, and the channel would fail auth until +/// the user reconnected or cleared credentials manually. 
+fn resolve_yuanbao_app_secret( + mut yb_cfg: crate::openhuman::channels::providers::yuanbao::YuanbaoConfig, + config: &Config, +) -> crate::openhuman::channels::providers::yuanbao::YuanbaoConfig { + if !yb_cfg.app_secret.is_empty() { + return yb_cfg; + } + let auth = crate::openhuman::credentials::AuthService::from_config(config); + match auth.get_profile("channel:yuanbao:api_key", None) { + Ok(Some(profile)) => { + let stored_app_key = profile.metadata.get("app_key").map(String::as_str); + if stored_app_key != Some(yb_cfg.app_key.as_str()) { + tracing::warn!( + "[channels] yuanbao stored credentials are for a different app_key (toml={:?}, store={:?}); reconnect the channel to refresh the secret", + yb_cfg.app_key, + stored_app_key, + ); + } else if let Some(secret) = profile.metadata.get("app_secret") { + yb_cfg.app_secret = secret.clone(); + } + } + Ok(None) => { + tracing::warn!( + "[channels] yuanbao credentials missing — connect the channel again from the UI" + ); + } + Err(e) => { + tracing::warn!("[channels] failed to load yuanbao credentials: {e}"); + } + } + yb_cfg +} + +#[cfg(test)] +mod yuanbao_secret_tests { + use super::*; + use crate::openhuman::channels::providers::yuanbao::YuanbaoConfig; + use crate::openhuman::credentials::AuthService; + use std::collections::HashMap; + use tempfile::tempdir; + + fn isolated_config() -> (tempfile::TempDir, Config) { + let tmp = tempdir().expect("tempdir"); + let mut config = Config::default(); + config.workspace_dir = tmp.path().join("workspace"); + config.config_path = tmp.path().join("config.toml"); + std::fs::create_dir_all(&config.workspace_dir).expect("workspace dir"); + (tmp, config) + } + + #[test] + fn loads_app_secret_from_credentials_when_toml_empty() { + let (_tmp, config) = isolated_config(); + // Pre-write the credentials the same way `connect_channel` does: + // metadata under the `channel:yuanbao:api_key` provider key. 
+ let auth = AuthService::from_config(&config); + let mut metadata = HashMap::new(); + metadata.insert("app_key".to_string(), "ak".to_string()); + metadata.insert("app_secret".to_string(), "from-credentials".to_string()); + auth.store_provider_token("channel:yuanbao:api_key", "default", "", metadata, true) + .expect("store credentials"); + + let yb = YuanbaoConfig { + app_key: "ak".into(), + app_secret: String::new(), + ..Default::default() + }; + let resolved = resolve_yuanbao_app_secret(yb, &config); + assert_eq!(resolved.app_secret, "from-credentials"); + } + + #[test] + fn preserves_existing_toml_secret_without_consulting_store() { + // No credentials in the store at all — resolver must still leave + // the TOML-supplied secret untouched. + let (_tmp, config) = isolated_config(); + let yb = YuanbaoConfig { + app_key: "ak".into(), + app_secret: "from-toml".into(), + ..Default::default() + }; + let resolved = resolve_yuanbao_app_secret(yb, &config); + assert_eq!(resolved.app_secret, "from-toml"); + } + + #[test] + fn returns_empty_secret_when_neither_toml_nor_credentials_have_one() { + let (_tmp, config) = isolated_config(); + let yb = YuanbaoConfig { + app_key: "ak".into(), + app_secret: String::new(), + ..Default::default() + }; + let resolved = resolve_yuanbao_app_secret(yb, &config); + // Surfaces empty so the downstream `YuanbaoChannel::new` validate() + // step can fail clearly, instead of attempting auth with a stale value. + assert_eq!(resolved.app_secret, ""); + } + + #[test] + fn skips_hydration_when_stored_profile_has_different_app_key() { + // Reproduces the stale-secret hazard: user changed `app_key` in + // `config.toml` (e.g. swapped to a different bot) but the + // credentials store still has the old key's profile. The resolver + // must NOT graft the old secret onto the new key. 
+ let (_tmp, config) = isolated_config(); + let auth = AuthService::from_config(&config); + let mut metadata = HashMap::new(); + metadata.insert("app_key".to_string(), "OLD-KEY".to_string()); + metadata.insert( + "app_secret".to_string(), + "old-key-secret-do-not-use".to_string(), + ); + auth.store_provider_token("channel:yuanbao:api_key", "default", "", metadata, true) + .expect("store credentials"); + + let yb = YuanbaoConfig { + app_key: "NEW-KEY".into(), + app_secret: String::new(), + ..Default::default() + }; + let resolved = resolve_yuanbao_app_secret(yb, &config); + assert_eq!( + resolved.app_secret, "", + "stale profile keyed to OLD-KEY must not hydrate NEW-KEY's secret", + ); + } +} diff --git a/src/openhuman/config/schema/channels.rs b/src/openhuman/config/schema/channels.rs index 7489da8d7a..60d2c17318 100644 --- a/src/openhuman/config/schema/channels.rs +++ b/src/openhuman/config/schema/channels.rs @@ -23,6 +23,7 @@ pub struct ChannelsConfig { pub lark: Option, pub dingtalk: Option, pub qq: Option, + pub yuanbao: Option, #[serde(default = "default_channel_message_timeout_secs")] pub message_timeout_secs: u64, /// The user's preferred *external* channel for proactive messages @@ -61,6 +62,7 @@ impl ChannelsConfig { || self.lark.is_some() || self.dingtalk.is_some() || self.qq.is_some() + || self.yuanbao.is_some() || self.matrix.is_some() || self.whatsapp.is_some() } @@ -85,6 +87,7 @@ impl Default for ChannelsConfig { lark: None, dingtalk: None, qq: None, + yuanbao: None, message_timeout_secs: default_channel_message_timeout_secs(), active_channel: None, }