|
| 1 | +import type { WorkspaceClient } from "@databricks/sdk-experimental"; |
| 2 | +import { Time, TimeUnits } from "@databricks/sdk-experimental"; |
| 3 | +import type { GenieMessage } from "@databricks/sdk-experimental/dist/apis/dashboards"; |
| 4 | +import type { Waiter } from "@databricks/sdk-experimental/dist/wait"; |
| 5 | +import { createLogger } from "../../logging"; |
| 6 | +import { genieConnectorDefaults } from "./defaults"; |
| 7 | +import { pollWaiter } from "./poll-waiter"; |
| 8 | +import type { |
| 9 | + GenieAttachmentResponse, |
| 10 | + GenieConversationHistoryResponse, |
| 11 | + GenieMessageResponse, |
| 12 | + GenieStreamEvent, |
| 13 | +} from "./types"; |
| 14 | + |
| 15 | +const logger = createLogger("connectors:genie"); |
| 16 | + |
| 17 | +type CreateMessageWaiter = Waiter<GenieMessage, GenieMessage>; |
| 18 | + |
| 19 | +export interface GenieConnectorConfig { |
| 20 | + timeout?: number; |
| 21 | + maxMessages?: number; |
| 22 | +} |
| 23 | + |
| 24 | +function mapAttachments(message: GenieMessage): GenieAttachmentResponse[] { |
| 25 | + return ( |
| 26 | + message.attachments?.map((att) => ({ |
| 27 | + attachmentId: att.attachment_id, |
| 28 | + query: att.query |
| 29 | + ? { |
| 30 | + title: att.query.title, |
| 31 | + description: att.query.description, |
| 32 | + query: att.query.query, |
| 33 | + statementId: att.query.statement_id, |
| 34 | + } |
| 35 | + : undefined, |
| 36 | + text: att.text ? { content: att.text.content } : undefined, |
| 37 | + suggestedQuestions: att.suggested_questions?.questions, |
| 38 | + })) ?? [] |
| 39 | + ); |
| 40 | +} |
| 41 | + |
| 42 | +function toMessageResponse(message: GenieMessage): GenieMessageResponse { |
| 43 | + return { |
| 44 | + messageId: message.message_id, |
| 45 | + conversationId: message.conversation_id, |
| 46 | + spaceId: message.space_id, |
| 47 | + status: message.status ?? "COMPLETED", |
| 48 | + content: message.content, |
| 49 | + attachments: mapAttachments(message), |
| 50 | + error: message.error?.error, |
| 51 | + }; |
| 52 | +} |
| 53 | + |
| 54 | +export class GenieConnector { |
| 55 | + private readonly config: Required<GenieConnectorConfig>; |
| 56 | + |
| 57 | + constructor(config: GenieConnectorConfig = {}) { |
| 58 | + this.config = { |
| 59 | + timeout: config.timeout ?? genieConnectorDefaults.timeout, |
| 60 | + maxMessages: config.maxMessages ?? genieConnectorDefaults.maxMessages, |
| 61 | + }; |
| 62 | + } |
| 63 | + |
| 64 | + async startMessage( |
| 65 | + workspaceClient: WorkspaceClient, |
| 66 | + spaceId: string, |
| 67 | + content: string, |
| 68 | + conversationId: string | undefined, |
| 69 | + ): Promise<{ |
| 70 | + messageWaiter: CreateMessageWaiter; |
| 71 | + conversationId: string; |
| 72 | + messageId: string; |
| 73 | + }> { |
| 74 | + if (conversationId) { |
| 75 | + const waiter = await workspaceClient.genie.createMessage({ |
| 76 | + space_id: spaceId, |
| 77 | + conversation_id: conversationId, |
| 78 | + content, |
| 79 | + }); |
| 80 | + return { |
| 81 | + messageWaiter: waiter, |
| 82 | + conversationId, |
| 83 | + messageId: waiter.message_id ?? "", |
| 84 | + }; |
| 85 | + } |
| 86 | + const start = await workspaceClient.genie.startConversation({ |
| 87 | + space_id: spaceId, |
| 88 | + content, |
| 89 | + }); |
| 90 | + return { |
| 91 | + messageWaiter: start as unknown as CreateMessageWaiter, |
| 92 | + conversationId: start.conversation_id, |
| 93 | + messageId: start.message_id, |
| 94 | + }; |
| 95 | + } |
| 96 | + |
| 97 | + async waitForMessage( |
| 98 | + messageWaiter: CreateMessageWaiter, |
| 99 | + options?: { timeout?: number }, |
| 100 | + ): Promise<GenieMessage> { |
| 101 | + const timeout = options?.timeout ?? this.config.timeout; |
| 102 | + const waitOptions = |
| 103 | + timeout > 0 ? { timeout: new Time(timeout, TimeUnits.milliseconds) } : {}; |
| 104 | + return messageWaiter.wait(waitOptions); |
| 105 | + } |
| 106 | + |
| 107 | + async listConversationMessages( |
| 108 | + workspaceClient: WorkspaceClient, |
| 109 | + spaceId: string, |
| 110 | + conversationId: string, |
| 111 | + options?: { maxMessages?: number }, |
| 112 | + ): Promise<GenieMessageResponse[]> { |
| 113 | + const maxMessages = options?.maxMessages ?? this.config.maxMessages; |
| 114 | + const allMessages: GenieMessage[] = []; |
| 115 | + let pageToken: string | undefined; |
| 116 | + |
| 117 | + do { |
| 118 | + const response = await workspaceClient.genie.listConversationMessages({ |
| 119 | + space_id: spaceId, |
| 120 | + conversation_id: conversationId, |
| 121 | + page_size: genieConnectorDefaults.pageSize, |
| 122 | + ...(pageToken ? { page_token: pageToken } : {}), |
| 123 | + }); |
| 124 | + |
| 125 | + if (response.messages) { |
| 126 | + allMessages.push(...response.messages); |
| 127 | + } |
| 128 | + |
| 129 | + pageToken = response.next_page_token; |
| 130 | + } while (pageToken && allMessages.length < maxMessages); |
| 131 | + |
| 132 | + return allMessages.slice(0, maxMessages).reverse().map(toMessageResponse); |
| 133 | + } |
| 134 | + |
| 135 | + async getMessageAttachmentQueryResult( |
| 136 | + workspaceClient: WorkspaceClient, |
| 137 | + spaceId: string, |
| 138 | + conversationId: string, |
| 139 | + messageId: string, |
| 140 | + attachmentId: string, |
| 141 | + _signal?: AbortSignal, |
| 142 | + ): Promise<unknown> { |
| 143 | + const response = |
| 144 | + await workspaceClient.genie.getMessageAttachmentQueryResult({ |
| 145 | + space_id: spaceId, |
| 146 | + conversation_id: conversationId, |
| 147 | + message_id: messageId, |
| 148 | + attachment_id: attachmentId, |
| 149 | + }); |
| 150 | + return response.statement_response; |
| 151 | + } |
| 152 | + |
| 153 | + async *streamSendMessage( |
| 154 | + workspaceClient: WorkspaceClient, |
| 155 | + spaceId: string, |
| 156 | + content: string, |
| 157 | + conversationId: string | undefined, |
| 158 | + options?: { timeout?: number }, |
| 159 | + ): AsyncGenerator<GenieStreamEvent> { |
| 160 | + try { |
| 161 | + const { |
| 162 | + messageWaiter, |
| 163 | + conversationId: resultConversationId, |
| 164 | + messageId: resultMessageId, |
| 165 | + } = await this.startMessage( |
| 166 | + workspaceClient, |
| 167 | + spaceId, |
| 168 | + content, |
| 169 | + conversationId, |
| 170 | + ); |
| 171 | + |
| 172 | + yield { |
| 173 | + type: "message_start", |
| 174 | + conversationId: resultConversationId, |
| 175 | + messageId: resultMessageId, |
| 176 | + spaceId, |
| 177 | + }; |
| 178 | + |
| 179 | + const timeout = |
| 180 | + options?.timeout != null ? options.timeout : this.config.timeout; |
| 181 | + const waitOptions = |
| 182 | + timeout > 0 |
| 183 | + ? { timeout: new Time(timeout, TimeUnits.milliseconds) } |
| 184 | + : {}; |
| 185 | + |
| 186 | + let completedMessage!: GenieMessage; |
| 187 | + for await (const event of pollWaiter(messageWaiter, waitOptions)) { |
| 188 | + if (event.type === "progress" && event.value.status) { |
| 189 | + yield { type: "status", status: event.value.status }; |
| 190 | + } else if (event.type === "completed") { |
| 191 | + completedMessage = event.value; |
| 192 | + } |
| 193 | + } |
| 194 | + |
| 195 | + const messageResponse = toMessageResponse(completedMessage); |
| 196 | + yield { type: "message_result", message: messageResponse }; |
| 197 | + |
| 198 | + yield* this.emitQueryResults( |
| 199 | + workspaceClient, |
| 200 | + spaceId, |
| 201 | + resultConversationId, |
| 202 | + messageResponse.messageId, |
| 203 | + messageResponse, |
| 204 | + ); |
| 205 | + } catch (error) { |
| 206 | + logger.error("Genie message error: %O", error); |
| 207 | + yield { |
| 208 | + type: "error", |
| 209 | + error: error instanceof Error ? error.message : "Genie request failed", |
| 210 | + }; |
| 211 | + } |
| 212 | + } |
| 213 | + |
| 214 | + private async *emitQueryResults( |
| 215 | + workspaceClient: WorkspaceClient, |
| 216 | + spaceId: string, |
| 217 | + conversationId: string, |
| 218 | + messageId: string, |
| 219 | + messageResponse: GenieMessageResponse, |
| 220 | + ): AsyncGenerator< |
| 221 | + Extract<GenieStreamEvent, { type: "query_result" } | { type: "error" }> |
| 222 | + > { |
| 223 | + const attachments = messageResponse.attachments ?? []; |
| 224 | + for (const att of attachments) { |
| 225 | + if (!att.query?.statementId || !att.attachmentId) continue; |
| 226 | + try { |
| 227 | + const data = await this.getMessageAttachmentQueryResult( |
| 228 | + workspaceClient, |
| 229 | + spaceId, |
| 230 | + conversationId, |
| 231 | + messageId, |
| 232 | + att.attachmentId, |
| 233 | + ); |
| 234 | + yield { |
| 235 | + type: "query_result", |
| 236 | + attachmentId: att.attachmentId, |
| 237 | + statementId: att.query.statementId, |
| 238 | + data, |
| 239 | + }; |
| 240 | + } catch (error) { |
| 241 | + logger.error( |
| 242 | + "Failed to fetch query result for attachment %s: %O", |
| 243 | + att.attachmentId, |
| 244 | + error, |
| 245 | + ); |
| 246 | + yield { |
| 247 | + type: "error", |
| 248 | + error: `Failed to fetch query result for attachment ${att.attachmentId}`, |
| 249 | + }; |
| 250 | + } |
| 251 | + } |
| 252 | + } |
| 253 | + |
| 254 | + async *streamConversation( |
| 255 | + workspaceClient: WorkspaceClient, |
| 256 | + spaceId: string, |
| 257 | + conversationId: string, |
| 258 | + options?: { includeQueryResults?: boolean }, |
| 259 | + ): AsyncGenerator<GenieStreamEvent> { |
| 260 | + const includeQueryResults = options?.includeQueryResults !== false; |
| 261 | + |
| 262 | + try { |
| 263 | + const messageResponses = await this.listConversationMessages( |
| 264 | + workspaceClient, |
| 265 | + spaceId, |
| 266 | + conversationId, |
| 267 | + ); |
| 268 | + |
| 269 | + for (const messageResponse of messageResponses) { |
| 270 | + yield { type: "message_result", message: messageResponse }; |
| 271 | + } |
| 272 | + |
| 273 | + if (includeQueryResults) { |
| 274 | + const queryAttachments: Array<{ |
| 275 | + messageId: string; |
| 276 | + attachmentId: string; |
| 277 | + statementId: string; |
| 278 | + }> = []; |
| 279 | + |
| 280 | + for (const msg of messageResponses) { |
| 281 | + for (const att of msg.attachments ?? []) { |
| 282 | + if (att.query?.statementId && att.attachmentId) { |
| 283 | + queryAttachments.push({ |
| 284 | + messageId: msg.messageId, |
| 285 | + attachmentId: att.attachmentId, |
| 286 | + statementId: att.query.statementId, |
| 287 | + }); |
| 288 | + } |
| 289 | + } |
| 290 | + } |
| 291 | + |
| 292 | + const results = await Promise.allSettled( |
| 293 | + queryAttachments.map(async (att) => { |
| 294 | + const data = await this.getMessageAttachmentQueryResult( |
| 295 | + workspaceClient, |
| 296 | + spaceId, |
| 297 | + conversationId, |
| 298 | + att.messageId, |
| 299 | + att.attachmentId, |
| 300 | + ); |
| 301 | + return { |
| 302 | + attachmentId: att.attachmentId, |
| 303 | + statementId: att.statementId, |
| 304 | + data, |
| 305 | + }; |
| 306 | + }), |
| 307 | + ); |
| 308 | + |
| 309 | + for (const result of results) { |
| 310 | + if (result.status === "fulfilled") { |
| 311 | + yield { |
| 312 | + type: "query_result", |
| 313 | + attachmentId: result.value.attachmentId, |
| 314 | + statementId: result.value.statementId, |
| 315 | + data: result.value.data, |
| 316 | + }; |
| 317 | + } else { |
| 318 | + logger.error("Failed to fetch query result: %O", result.reason); |
| 319 | + yield { |
| 320 | + type: "error", |
| 321 | + error: |
| 322 | + result.reason instanceof Error |
| 323 | + ? result.reason.message |
| 324 | + : "Failed to fetch query result", |
| 325 | + }; |
| 326 | + } |
| 327 | + } |
| 328 | + } |
| 329 | + } catch (error) { |
| 330 | + logger.error("Genie getConversation error: %O", error); |
| 331 | + yield { |
| 332 | + type: "error", |
| 333 | + error: |
| 334 | + error instanceof Error |
| 335 | + ? error.message |
| 336 | + : "Failed to fetch conversation", |
| 337 | + }; |
| 338 | + } |
| 339 | + } |
| 340 | + |
| 341 | + async sendMessage( |
| 342 | + workspaceClient: WorkspaceClient, |
| 343 | + spaceId: string, |
| 344 | + content: string, |
| 345 | + conversationId: string | undefined, |
| 346 | + ): Promise<GenieMessageResponse> { |
| 347 | + const { messageWaiter, conversationId: resultConversationId } = |
| 348 | + await this.startMessage( |
| 349 | + workspaceClient, |
| 350 | + spaceId, |
| 351 | + content, |
| 352 | + conversationId, |
| 353 | + ); |
| 354 | + const completedMessage = await this.waitForMessage(messageWaiter); |
| 355 | + const messageResponse = toMessageResponse(completedMessage); |
| 356 | + return { |
| 357 | + ...messageResponse, |
| 358 | + conversationId: resultConversationId, |
| 359 | + }; |
| 360 | + } |
| 361 | + |
| 362 | + async getConversation( |
| 363 | + workspaceClient: WorkspaceClient, |
| 364 | + spaceId: string, |
| 365 | + conversationId: string, |
| 366 | + ): Promise<GenieConversationHistoryResponse> { |
| 367 | + const messages = await this.listConversationMessages( |
| 368 | + workspaceClient, |
| 369 | + spaceId, |
| 370 | + conversationId, |
| 371 | + ); |
| 372 | + return { |
| 373 | + conversationId, |
| 374 | + spaceId, |
| 375 | + messages, |
| 376 | + }; |
| 377 | + } |
| 378 | +} |
0 commit comments