From a760e21a193fedec2dd2aa33a816067c6cda1569 Mon Sep 17 00:00:00 2001 From: David Meijer Date: Fri, 12 Dec 2025 16:39:22 +0100 Subject: [PATCH 01/34] ENH: add download buttons to query and enrichment results --- src/client/src/components/WorkspaceQuery.tsx | 71 ++++++++++++++++++-- 1 file changed, 65 insertions(+), 6 deletions(-) diff --git a/src/client/src/components/WorkspaceQuery.tsx b/src/client/src/components/WorkspaceQuery.tsx index 8198c97..b4fdd31 100644 --- a/src/client/src/components/WorkspaceQuery.tsx +++ b/src/client/src/components/WorkspaceQuery.tsx @@ -21,6 +21,7 @@ import { DataGrid, GridColDef, GridPaginationModel } from "@mui/x-data-grid"; import NotificationsRoundedIcon from "@mui/icons-material/NotificationsRounded"; import UploadFileIcon from "@mui/icons-material/UploadFile"; import SettingsIcon from "@mui/icons-material/Settings"; +import DownloadIcon from "@mui/icons-material/Download"; import { useTheme } from "@mui/material/styles"; import { Link as RouterLink } from "react-router-dom"; import { useNotifications } from "../components/NotificationProvider"; @@ -74,6 +75,42 @@ export const WorkspaceQuery: React.FC = ({ session, setSess const [enrichmentError, setEnrichmentError] = React.useState(null); const [enrichmentResult, setEnrichmentResult] = React.useState(null); + const handleDownloadQueryResults = () => { + if (!queryResult || queryResult.rows.length === 0) return; + + const tsvHeader = queryResult.columns.join("\t"); + const tsvRows = queryResult.rows.map((row) => + queryResult.columns.map((col) => row[col] ?? 
"").join("\t") + ); + const tsvContent = [tsvHeader, ...tsvRows].join("\n"); + + const blob = new Blob([tsvContent], { type: "text/tab-separated-values" }); + const url = URL.createObjectURL(blob); + const a = document.createElement("a"); + a.href = url; + a.download = `query_results_session_${session.sessionId}.tsv`; + a.click(); + URL.revokeObjectURL(url); + } + + const handleDownloadEnrichmentResults = () => { + if (!enrichmentResult || enrichmentResult.items.length === 0) return; + + const tsvHeader = ["id", "schema", "key", "value", "adjusted_p_value"].join("\t"); + const tsvRows = enrichmentResult.items.map((item) => + [item.id, item.schema, item.key, item.value, item.adjusted_p_value].join("\t") + ); + const tsvContent = [tsvHeader, ...tsvRows].join("\n"); + + const blob = new Blob([tsvContent], { type: "text/tab-separated-values" }); + const url = URL.createObjectURL(blob); + const a = document.createElement("a"); + a.href = url; + a.download = `enrichment_results_session_${session.sessionId}.tsv`; + a.click(); + URL.revokeObjectURL(url); + } + const [paginationModelEnrichment, setPaginationModelEnrichment] = React.useState({ pageSize: 10, page: 0, @@ -612,9 +649,20 @@ export const WorkspaceQuery: React.FC = ({ session, setSess - - Query results - + + + Query results + + + 0 ? "pointer" : "not-allowed", + }} + /> + + {rows.length === 0 && !queryLoading ? ( {queryResult ? "No results returned." : "Run a query to see results."} @@ -671,9 +719,20 @@ export const WorkspaceQuery: React.FC = ({ session, setSess - - Enrichment results - + + + Enrichment results + + + 0 ? "pointer" : "not-allowed", + }} + /> + + {enrichmentLoading ? 
( From 83dfef2a66157ea7173b984188812d30413e702f Mon Sep 17 00:00:00 2001 From: David Meijer Date: Fri, 12 Dec 2025 16:59:22 +0100 Subject: [PATCH 02/34] ENH: allow downloading of scatter data embedding space --- .../src/components/ViewEmbeddingSpace.tsx | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/src/client/src/components/ViewEmbeddingSpace.tsx b/src/client/src/components/ViewEmbeddingSpace.tsx index 15acf10..1a2a5a9 100644 --- a/src/client/src/components/ViewEmbeddingSpace.tsx +++ b/src/client/src/components/ViewEmbeddingSpace.tsx @@ -1,8 +1,10 @@ import React from "react"; import Box from "@mui/material/Box"; import Alert from "@mui/material/Alert"; +import Tooltip from "@mui/material/Tooltip"; import Typography from "@mui/material/Typography"; import CircularProgress from "@mui/material/CircularProgress"; +import DownloadIcon from "@mui/icons-material/Download"; import { ScatterChart } from "@mui/x-charts/ScatterChart"; import Stack from "@mui/material/Stack"; import ToggleButton from "@mui/material/ToggleButton"; @@ -90,6 +92,38 @@ export const ViewEmbeddingSpace: React.FC = ({ return Array.from(byKind.values()); }, [points, parentById]); + // Download scatter series as TSV + const handleDownloadEmbeddingData = () => { + if (!points || points.length === 0) return; + + const header = ["id", "kind", "name", "x", "y"]; + const rows = points.map((p) => { + const parent = parentById.get(p.parent_id) ?? null; + const parentName = parent ? parent.name : "unknown"; + const childIds = parent?.retrofingerprints.map((fp) => fp.id) || []; + const childIdx = childIds.indexOf(p.child_id); + return [ + p.child_id, + p.kind, + `Readout ${childIdx >= 0 ? 
`${childIdx + 1}` : ""} ${parentName} `, + p.x.toString(), + p.y.toString(), + ]; + }); + const tsvContent = [header, ...rows] + .map((row) => row.join("\t")) + .join("\n"); + const blob = new Blob([tsvContent], { type: "text/tab-separated-values" }); + const url = URL.createObjectURL(blob); + const link = document.createElement("a"); + link.href = url; + link.download = `embedding_space_${session.sessionId}.tsv`; + document.body.appendChild(link); + link.click(); + document.body.removeChild(link); + URL.revokeObjectURL(url); + }; + // Changes only when the contents of items changes, not just because of polling const itemsKey = React.useMemo(() => { if (!session.items || session.items.length === 0) return ""; @@ -253,6 +287,15 @@ export const ViewEmbeddingSpace: React.FC = ({ }} > + + + Date: Fri, 12 Dec 2025 18:08:06 +0100 Subject: [PATCH 03/34] UPD: convert msa to svg for download --- src/client/src/components/ViewMsa.tsx | 141 ++++++++++++++++++++++++++ 1 file changed, 141 insertions(+) diff --git a/src/client/src/components/ViewMsa.tsx b/src/client/src/components/ViewMsa.tsx index 7b42934..36ee8ed 100644 --- a/src/client/src/components/ViewMsa.tsx +++ b/src/client/src/components/ViewMsa.tsx @@ -11,6 +11,7 @@ import RefreshIcon from "@mui/icons-material/Refresh"; import PaletteIcon from "@mui/icons-material/Palette"; import VisibilityOffIcon from "@mui/icons-material/VisibilityOff"; import SettingsIcon from "@mui/icons-material/Settings"; +import DownloadIcon from "@mui/icons-material/Download"; import { useNotifications } from "./NotificationProvider"; import { Session, MsaSettings, MsaState, MsaSequence, PrimarySequence } from "../features/session/types"; import { runMsa } from "../features/views/api"; @@ -35,6 +36,69 @@ const toDisplayName = (name: string | null): string | null => { return alphanumeric.slice(0, 3).toUpperCase(); }; +const toHex = (v: number) => v.toString(16).padStart(2, "0"); + +const hslToRgb = (h: number, s: number, l: number) => { + // h: 
0–360, s/l: 0–1 + const c = (1 - Math.abs(2 * l - 1)) * s; + const x = c * (1 - Math.abs(((h / 60) % 2) - 1)); + const m = l - c / 2; + const pick = (hp: number) => + hp < 60 ? [c, x, 0] : + hp < 120 ? [x, c, 0] : + hp < 180 ? [0, c, x] : + hp < 240 ? [0, x, c] : + hp < 300 ? [x, 0, c] : + [c, 0, x]; + const [r1, g1, b1] = pick(h); + return [ + Math.round((r1 + m) * 255), + Math.round((g1 + m) * 255), + Math.round((b1 + m) * 255), + ]; +}; + +const normalizeColor = (raw: string | undefined | null) => { + if (!raw) return "#f5f5f5"; + const c = raw.trim(); + + // Already hex (#rgb, #rrggbb, #rrggbbaa) + if (/^#([0-9a-f]{3}|[0-9a-f]{6}|[0-9a-f]{8})$/i.test(c)) return c; + + // rgba()/rgb() + const rgba = c.match(/^rgba?\(\s*([\d.]+)\s*,\s*([\d.]+)\s*,\s*([\d.]+)(?:\s*,\s*([\d.]+))?\s*\)$/i); + if (rgba) { + const [r, g, b, aRaw] = rgba.slice(1).map(Number); + const a = isNaN(aRaw) ? 1 : Math.max(0, Math.min(1, aRaw)); + // composite over white to avoid transparency issues + const blend = (v: number) => Math.round((1 - a) * 255 + a * v); + return `#${toHex(blend(r))}${toHex(blend(g))}${toHex(blend(b))}`; + } + + // hsla()/hsl() + const hsla = c.match(/^hsla?\(\s*([\d.]+)(?:deg)?\s*,\s*([\d.]+)%\s*,\s*([\d.]+)%(?:\s*,\s*([\d.]+))?\s*\)$/i); + if (hsla) { + const h = Number(hsla[1]); + const s = Number(hsla[2]) / 100; + const l = Number(hsla[3]) / 100; + const a = hsla[4] === undefined ? 
1 : Math.max(0, Math.min(1, Number(hsla[4]))); + const [r, g, b] = hslToRgb(h, s, l); + const blend = (v: number) => Math.round((1 - a) * 255 + a * v); + return `#${toHex(blend(r))}${toHex(blend(g))}${toHex(blend(b))}`; + } + + // Fallback + return "#f5f5f5"; +}; + +const escapeSvgText = (value: string) => + value + .replace(/&/g, "&") + .replace(//g, ">") + .replace(/"/g, """) + .replace(/'/g, "'"); + interface SortableRowProps { row: MsaSequence; labelWidth: number; @@ -618,6 +682,77 @@ export const ViewMsa: React.FC = ({ } } + const buildMsaSvg = () => { + const visible = msa.filter(row => !hiddenIds.has(row.id as string)); + if (visible.length === 0) return ""; + + const motifPx = 19.1955; + const rowHeight = 13.8131; + const labelW = 80; + const padding = 10; + const textPadding = 6; + + const svgWidth = padding * 2 + labelW + motifPx * msaLength; + const svgHeight = padding * 2 + rowHeight * visible.length; + + const rowsSvg = visible + .map((row, rIdx) => { + const y = padding + rIdx * rowHeight; + const labelText = escapeSvgText(row.name || row.id || "row"); + const cells = row.sequence + .slice(0, msaLength) + .map((motif, cIdx) => { + const x = padding + labelW + cIdx * motifPx; + const isPad = (motif.id ?? 
"").startsWith("pad-"); + if (isPad) return ""; + const fill = normalizeColor(session.settings.motifColorPalette[motif.name || ""]); + const text = escapeSvgText( + toDisplayName(motif.displayName || motif.name || null) || "UNK" + ); + return ` + + + ${text} + `; + }) + .join(""); + return ` + + ${labelText} + ${cells} + `; + }) + .join(""); + + return ` + + + ${rowsSvg} + `; + }; + + const handleDownloadMsaSvg = () => { + const svg = buildMsaSvg(); + if (!svg) { + pushNotification("No visible sequences to download.", "warning"); + return; + } + const blob = new Blob([svg], { type: "image/svg+xml;charset=utf-8" }); + const url = URL.createObjectURL(blob); + const link = document.createElement("a"); + link.href = url; + link.download = `msa_${session.sessionId}.svg`; + document.body.appendChild(link); + link.click(); + document.body.removeChild(link); + URL.revokeObjectURL(url); + }; + return ( @@ -718,6 +853,12 @@ export const ViewMsa: React.FC = ({ sx={{ cursor: "pointer" }} /> + + + From 5a9d403bbdae4451597dec4069c9b4089438350e Mon Sep 17 00:00:00 2001 From: David Meijer Date: Fri, 12 Dec 2025 20:06:09 +0100 Subject: [PATCH 04/34] UPD: protected name space for proteinogenic amino acids --- src/client/src/components/ViewMsa.tsx | 90 ++++++++++++++++++++++++--- 1 file changed, 81 insertions(+), 9 deletions(-) diff --git a/src/client/src/components/ViewMsa.tsx b/src/client/src/components/ViewMsa.tsx index 36ee8ed..8e1f0d5 100644 --- a/src/client/src/components/ViewMsa.tsx +++ b/src/client/src/components/ViewMsa.tsx @@ -29,11 +29,80 @@ import { import { CSS } from "@dnd-kit/utilities"; import DragIndicatorIcon from "@mui/icons-material/DragIndicator"; -// Helper: turn a name into a display name by keeping only alphanumerics, capitalizing, and taking up to 3 characters -const toDisplayName = (name: string | null): string | null => { - if (!name) return null; - const alphanumeric = name.replace(/[^a-z0-9]/gi, ""); - return alphanumeric.slice(0, 3).toUpperCase(); 
+export const PROTECTED_NAME_TO_CODE: Record = { + ALANINE: "ALA", + CYSTEINE: "CYS", + ASPARTICACID: "ASP", + GLUTAMICACID: "GLU", + PHENYLALANINE: "PHE", + GLYCINE: "GLY", + HISTIDINE: "HIS", + ISOLEUCINE: "ILE", + LYSINE: "LYS", + LEUCINE: "LEU", + METHIONINE: "MET", + ASPARAGINE: "ASN", + PROLINE: "PRO", + GLUTAMINE: "GLN", + ARGININE: "ARG", + SERINE: "SER", + THREONINE: "THR", + VALINE: "VAL", + TRYPTOPHAN: "TRP", + TYROSINE: "TYR", +}; + + +export const makeToDisplayName = (protectedNameToCode: Record) => { + const norm = (s: string) => s.replace(/[^a-z0-9]/gi, "").toUpperCase(); + + // normalize protected names + reserve protected codes + const prot = new Map( + Object.entries(protectedNameToCode).map(([k, v]) => [norm(k), norm(v)]) + ); + const reserved = new Set(Array.from(prot.values())); // e.g. ALA, GLY + const used = new Set(reserved); // block others from taking them + const cache = new Map(); // per-name stability + + const candidates = (s: string) => { + const out: string[] = []; + if (s.length >= 3) { + out.push(s.slice(0, 3)); // ABC + for (let i = 3; i < s.length; i++) out.push(s[0] + s[1] + s[i]); // AB? + for (let i = 2; i < s.length; i++) out.push(s[0] + s[i - 1] + s[i]); // A?? 
+ } + if (s.length >= 2) out.push(s.slice(0, 2)); // AB + if (s.length >= 1) out.push(s[0]); // A + // de-dupe in order + const seen = new Set(); + return out.filter(c => c.length <= 3 && !seen.has(c) && (seen.add(c), true)); + }; + + return (name: string | null): string | null => { + if (!name) return null; + const s = norm(name); + if (!s) return null; + + const hit = cache.get(s); + if (hit) return hit; + + // ONLY protected full names get protected 3-letter codes + const canonical = prot.get(s); + if (canonical) { + cache.set(s, canonical); + return canonical; + } + + // don’t let non-protected names steal reserved AA codes + for (const c of candidates(s)) { + if (!used.has(c)) { + used.add(c); + cache.set(s, c); + return c; + } + } + return null; + }; }; const toHex = (v: number) => v.toString(16).padStart(2, "0"); @@ -280,6 +349,11 @@ export const ViewMsa: React.FC = ({ const [colorPaletteDialogOpen, setColorPaletteDialogOpen] = React.useState(false); const [msaSettingsDialogOpen, setMsaSettingsDialogOpen] = React.useState(false); + const toDisplayName = React.useMemo( + () => makeToDisplayName(PROTECTED_NAME_TO_CODE), + [session.sessionId] + ) + const handleColorPaletteSave = (newMap: Record) => { setSession(prev => ({ ...prev, @@ -706,9 +780,7 @@ export const ViewMsa: React.FC = ({ const isPad = (motif.id ?? 
"").startsWith("pad-"); if (isPad) return ""; const fill = normalizeColor(session.settings.motifColorPalette[motif.name || ""]); - const text = escapeSvgText( - toDisplayName(motif.displayName || motif.name || null) || "UNK" - ); + const text = escapeSvgText(toDisplayName(motif.name || null) || "UNK"); return ` = ({ arrow > Date: Fri, 12 Dec 2025 20:23:42 +0100 Subject: [PATCH 05/34] UPD: add line per frow for msa output as svg --- src/client/src/components/ViewMsa.tsx | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/client/src/components/ViewMsa.tsx b/src/client/src/components/ViewMsa.tsx index 8e1f0d5..515c3fc 100644 --- a/src/client/src/components/ViewMsa.tsx +++ b/src/client/src/components/ViewMsa.tsx @@ -768,11 +768,15 @@ export const ViewMsa: React.FC = ({ const svgWidth = padding * 2 + labelW + motifPx * msaLength; const svgHeight = padding * 2 + rowHeight * visible.length; + const lineSpan = Math.max(0, msaLength - 1) * motifPx; const rowsSvg = visible .map((row, rIdx) => { const y = padding + rIdx * rowHeight; const labelText = escapeSvgText(row.name || row.id || "row"); + const lineX1 = padding + labelW; + const lineX2 = lineX1 + lineSpan + motifPx; + const lineY = y + rowHeight / 2; const cells = row.sequence .slice(0, msaLength) .map((motif, cIdx) => { @@ -796,6 +800,10 @@ export const ViewMsa: React.FC = ({ ${labelText} + ${lineSpan > 0 + ? 
`` + : ""} ${cells} `; }) From f26bdc0be8a52953fb84d52225c53dfe7b02c5be Mon Sep 17 00:00:00 2001 From: David Meijer Date: Mon, 15 Dec 2025 14:00:12 +0100 Subject: [PATCH 06/34] ENH: add record level parsing switch --- .../components/DialogImportGeneCluster.tsx | 30 +++++++++++++++++++ src/client/src/components/WorkspaceUpload.tsx | 6 +++- src/client/src/features/jobs/api.ts | 10 +++++-- src/client/src/features/jobs/types.ts | 1 + src/server/routes/jobs.py | 23 ++++++++++---- 5 files changed, 61 insertions(+), 9 deletions(-) diff --git a/src/client/src/components/DialogImportGeneCluster.tsx b/src/client/src/components/DialogImportGeneCluster.tsx index 3d97c4b..af7b03a 100644 --- a/src/client/src/components/DialogImportGeneCluster.tsx +++ b/src/client/src/components/DialogImportGeneCluster.tsx @@ -9,12 +9,16 @@ type DialogImportGeneClusterProps = { open: boolean; onClose: () => void; onImport: (files: File[]) => void; + readoutLevel: "rec" | "gene"; + setReadoutLevel: (level: "rec" | "gene") => void; } export const DialogImportGeneCluster: React.FC = ({ open, onClose, onImport, + readoutLevel, + setReadoutLevel, }) => { const [gbkFiles, setGbkFiles] = React.useState([]); const canImport = gbkFiles.length > 0; @@ -47,6 +51,32 @@ export const DialogImportGeneCluster: React.FC = (  output files for best compatibility. + + Choose readout level:  + setReadoutLevel("rec")} + sx={{ + fontWeight: readoutLevel === "rec" ? "bold" : "normal", + color: readoutLevel === "rec" ? 'primary.main' : 'inherit', + }} + > + record (record level) + +  or  + setReadoutLevel("gene")} + sx={{ + fontWeight: readoutLevel === "gene" ? "bold" : "normal", + color: readoutLevel === "gene" ? 
'primary.main' : 'inherit', + }} + > + gene (gene level) + + - - - - - - {localEntries.map(({ id, key, color }) => ( - - handleKeyChange(id, e.target.value)} - sx={{ flexGrow: 1 }} // fill up available space - /> - handleColorChange(id, e.target.value)} - inputProps={{ maxLength: 7, placeholder: "#RRGGBB" }} - /> - - handleRemoveEntry(id)} - /> - - ))} - - - ) -} diff --git a/src/client/src/components/DialogImportGeneCluster.tsx b/src/client/src/components/DialogImportGeneCluster.tsx deleted file mode 100644 index af7b03a..0000000 --- a/src/client/src/components/DialogImportGeneCluster.tsx +++ /dev/null @@ -1,98 +0,0 @@ -import React from "react"; -import Stack from "@mui/material/Stack"; -import Typography from "@mui/material/Typography"; -import Button from "@mui/material/Button"; -import MuiLink from "@mui/material/Link"; -import { DialogWindow } from "../components/DialogWindow"; - -type DialogImportGeneClusterProps = { - open: boolean; - onClose: () => void; - onImport: (files: File[]) => void; - readoutLevel: "rec" | "gene"; - setReadoutLevel: (level: "rec" | "gene") => void; -} - -export const DialogImportGeneCluster: React.FC = ({ - open, - onClose, - onImport, - readoutLevel, - setReadoutLevel, -}) => { - const [gbkFiles, setGbkFiles] = React.useState([]); - const canImport = gbkFiles.length > 0; - - const reset = () => setGbkFiles([]); - - const handleImport = () => { - onImport(gbkFiles); - reset(); - onClose(); - } - - return ( - - - - Select one or more GenBank files (.gbk, .gb, .genbank) containing gene cluster data to import into your workspace. Make sure the files are  - - antiSMASH - -  output files for best compatibility. - - - Choose readout level:  - setReadoutLevel("rec")} - sx={{ - fontWeight: readoutLevel === "rec" ? "bold" : "normal", - color: readoutLevel === "rec" ? 'primary.main' : 'inherit', - }} - > - record (record level) - -  or  - setReadoutLevel("gene")} - sx={{ - fontWeight: readoutLevel === "gene" ? 
"bold" : "normal", - color: readoutLevel === "gene" ? 'primary.main' : 'inherit', - }} - > - gene (gene level) - - - - {gbkFiles.length > 0 && ( - - {gbkFiles.length} file(s) selected - - )} - - - ) -} diff --git a/src/client/src/components/DialogMsaSettings.tsx b/src/client/src/components/DialogMsaSettings.tsx deleted file mode 100644 index f21da26..0000000 --- a/src/client/src/components/DialogMsaSettings.tsx +++ /dev/null @@ -1,109 +0,0 @@ -import React from "react"; -import Box from "@mui/material/Box"; -import FormControl from "@mui/material/FormControl"; -import Stack from "@mui/material/Stack"; -import Typography from "@mui/material/Typography"; -import Select, { SelectChangeEvent } from "@mui/material/Select"; -import MenuItem from "@mui/material/MenuItem"; -import { DialogWindow } from "../components/DialogWindow"; -import { MsaSettings, AlignmentType } from "../features/session/types"; - -type DialogMsaSettingsProps = { - open: boolean; - onClose: () => void; - settings: MsaSettings; - onSave: (newSettings: MsaSettings) => void; -} - -export const DialogMsaSettings: React.FC = ({ - open, - onClose, - settings, - onSave, -}) => { - const [localSettings, setLocalSettings] = React.useState(settings); - const [dirty, setDirty] = React.useState(false); - - // Sync local state with remote session updates when not actively editing - React.useEffect(() => { - // If dialog is closed, always sync to the latest saved settings - if (!open) { - setLocalSettings(settings); - setDirty(false); - return; - } - - // When open but not dirty, accept incoming updates (e.g., from polling) - if (!dirty) { - setLocalSettings(settings); - } - }, [settings, open, dirty]) - - // Whenever the dialog opens, start from the latest saved settings - React.useEffect(() => { - if (open) { - setLocalSettings(settings); - setDirty(false); - } - }, [open, settings]) - - // Handle alignment type change - const handleAlignmentTypeChange = (event: SelectChangeEvent) => { - const newType = 
event.target.value as AlignmentType; - setLocalSettings((prev) => ({ - ...prev, - alignmentType: newType, - })); - setDirty(true); - } - - const handleCancel = () => { - setLocalSettings(settings); - setDirty(false); - onClose(); - } - - const handleSave = () => { - onSave(localSettings); - setDirty(false); - onClose(); - } - - return ( - - - - - Alignment type - - - Choose between global (Needleman-Wunsch) and local (Smith-Waterman) alignment algorithms. - - - - - - - - ) -} diff --git a/src/client/src/components/DialogQuerySettings.tsx b/src/client/src/components/DialogQuerySettings.tsx deleted file mode 100644 index 5bd1845..0000000 --- a/src/client/src/components/DialogQuerySettings.tsx +++ /dev/null @@ -1,81 +0,0 @@ -import React from "react"; -import Box from "@mui/material/Box"; -import FormControl from "@mui/material/FormControl"; -import FormControlLabel from "@mui/material/FormControlLabel"; -import Radio from "@mui/material/Radio"; -import RadioGroup from "@mui/material/RadioGroup"; -import Stack from "@mui/material/Stack"; -import Typography from "@mui/material/Typography"; -import { DialogWindow } from "../components/DialogWindow"; -import { QuerySearchSpace, QuerySettings } from "../features/session/types"; - -type DialogQuerySettingsProps = { - open: boolean; - onClose: () => void; - settings: QuerySettings; - onSave: (newSettings: QuerySettings) => void; -} - -export const DialogQuerySettings: React.FC = ({ - open, - onClose, - settings, - onSave, -}) => { - const [localSettings, setLocalSettings] = React.useState(settings); - const [dirty, setDirty] = React.useState(false); - - React.useEffect(() => { - setLocalSettings(settings); - }, [settings]) - - const handleSave = () => { - onSave(localSettings); - onClose(); - } - - const handleSearchSpaceChange = (value: QuerySearchSpace) => { - setLocalSettings((prev) => ({ - ...prev, - searchSpace: value, - })); - setDirty(true); - } - - return ( - - - - - Search space - - - Choose whether to search 
against both compounds and biosynthetic gene clusters (BGCs), or limit the search to only one of these categories. - Search space impacts both the results returned and the enrichment calculations. Some annotations are only available for one type of item and - will appear as significant only when that item type is included in the search space (e.g., chemical classes for compounds). - - - handleSearchSpaceChange(e.target.value as QuerySearchSpace)} - > - } label="Full (compounds & BGCs)" /> - } label="Compounds only" /> - } label="BGCs only" /> - - - - - - ) -} diff --git a/src/client/src/components/DialogViewItem.tsx b/src/client/src/components/DialogViewItem.tsx deleted file mode 100644 index ccdbfa2..0000000 --- a/src/client/src/components/DialogViewItem.tsx +++ /dev/null @@ -1,234 +0,0 @@ -import React from "react"; -import Alert from "@mui/material/Alert"; -import Stack from "@mui/material/Stack"; -import Typography from "@mui/material/Typography"; -import FormControl from "@mui/material/FormControl"; -import Select from "@mui/material/Select"; -import MenuItem from "@mui/material/MenuItem"; -import MuiLink from "@mui/material/Link"; -import CircularProgress from "@mui/material/CircularProgress"; -import { useNotifications } from "./NotificationProvider"; -import { DialogWindow } from "../components/DialogWindow"; -import { SessionItem } from "../features/session/types"; -import { SvgViewer } from "../components/SvgViewer"; -import { drawCompoundItem, drawGeneClusterItem } from "../features/drawing/api"; - -type DialogViewItemProps = { - open: boolean; - item?: SessionItem | null; - onClose: () => void; -} - -export const DialogViewItem: React.FC = ({ - open, - item, - onClose, -}) => { - const { pushNotification } = useNotifications(); - const [loading, setLoading] = React.useState(false); - const [errorMsg, setErrorMsg] = React.useState(null); - - const [initializedItemId, setInitializedItemId] = React.useState(null); - const [selectPrimarySequenceId, 
setSelectPrimarySequenceId] = React.useState(""); - const [svg, setSvg] = React.useState(null); - - const generateSvg = React.useCallback(async (primarySequenceId: string) => { - if (!item) { - setSvg(null); - setErrorMsg(null); - return; - } - - // Get Item primarySequence for primarySequenceId - const primarySequence = item.primarySequences?.find(seq => seq.id === primarySequenceId); - if (!primarySequence) { - pushNotification("Selected primary sequence not found in item.", "error"); - setSvg(null); - setErrorMsg(null); - return; - } - - setLoading(true); - - try { - if (item.kind === "compound") { - const taggedParentSmiles = item.taggedSmiles; - - // Check if taggedParentSmiles is available - if (!taggedParentSmiles) { - pushNotification("No tagged SMILES available for this compound item.", "error"); - setSvg(null); - return; - } - - // Call drawing API - const drawingSvg = await drawCompoundItem( - taggedParentSmiles, - primarySequence - ); - setSvg(drawingSvg); - setErrorMsg(null); - } else if (item.kind === "gene_cluster") { - // Call drawing API - const drawingSvg = await drawGeneClusterItem(item.fileContent || ""); - setSvg(drawingSvg); - setErrorMsg(null); - } else { - const errorMsg = "SVG drawing not supported for this item type."; - pushNotification(errorMsg, "error"); - setSvg(null); - setErrorMsg(errorMsg); - return; - } - } catch (error) { - let errorMsg = "Error generating SVG drawing"; - const errorBody = (error as any)?.body as string | undefined - if (errorBody) { - try { - const parsed = JSON.parse(errorBody); - if (typeof parsed?.error === "string") errorMsg = `${errorMsg}: ${parsed.error}`; - } catch (e) { - errorMsg = `${errorMsg}.` - } - } - pushNotification(errorMsg, "error"); - setSvg(null); - setErrorMsg(errorMsg); - } finally { - setLoading(false); - } - }, [item]); - - const initializePrimarySequence = React.useCallback(() => { - if (item && item.primarySequences && item.primarySequences.length > 0) { - const firstSeqId = 
item.primarySequences[0].id; - setSelectPrimarySequenceId(firstSeqId); - generateSvg(firstSeqId); - } else { - setSelectPrimarySequenceId(""); - setSvg(null); - setErrorMsg(null); - } - }, [item, generateSvg]); - - // Initialize primary sequence selection when item changes - // Also avoid re-initializing if the same item is passed again - React.useEffect(() => { - if (!item) { - setInitializedItemId(null); - setSelectPrimarySequenceId(""); - setSvg(null); - setErrorMsg(null); - return; - } - - // No redraw if same item - if (initializedItemId === item.id) { - return; - } - - // New item is passed - setInitializedItemId(item.id); - initializePrimarySequence(); - }, [item, initializedItemId, initializePrimarySequence]); - - const handlePrimarySequenceChange = (sequenceId: string) => { - setSelectPrimarySequenceId(sequenceId); - generateSvg(sequenceId); - } - - return ( - - { loading ? ( - - - - ) : errorMsg ? ( - {errorMsg} - ) : !item ? ( - - No item selected. - - ) : svg === null || svg.length === 0 ? ( - - No SVG drawing available for this item. - - ) : item.kind === "compound" ? ( - - - To get started, select a primary sequence to map onto the input compound structure. - A downloadable SVG will be generated showing the mapping below the selector upon successful mapping. - All structures in this view are drawn using  - - PIKAChU - -  . Please cite PIKAChU if you use these drawings in your work. - - - - - {svg && ( - {}} - onElementClick={() => {}} - height={600} - /> - )} - - ) : item.kind === "gene_cluster" ? ( - - - - RAIChU - -  is used to generate the gene cluster visualization below. - The SVG rendered below is included for informative purposes and does not reflect the exact encoding mechanism of the gene clustering by RetroMol. - This viewer serves as a wrapper around the RAIChU SVG generation API. For more information on RAIChU, please refer to the link provided. 
- Substrate predictions for non-ribosomal peptide (NRP) A-domains and polyketide synthase (PKS) acyltransferase (AT)-domains are taken directly from the  - - antiSMASH - -  output. - This wrapper viewer around RAIChU currently does not use any of the PARAS substrate specificity predictions provided and used by RetroMol for similarity searches. - Additionally, this wrapper view around RAIChU only provided a visualization of the full region readout, not of individual candidate clusters found within the region. - Please cite RAIChU if you use these drawings in your work. - - {svg && ( - {}} - onElementClick={() => {}} - height={600} - /> - )} - - ) : ( - - No preview available for this item type. - - )} - - - ) -} diff --git a/src/client/src/components/ItemKindChip.tsx b/src/client/src/components/ItemKindChip.tsx deleted file mode 100644 index 19379a7..0000000 --- a/src/client/src/components/ItemKindChip.tsx +++ /dev/null @@ -1,27 +0,0 @@ -import React from "react"; -import Chip from "@mui/material/Chip"; - -interface ItemKindChipProps { - itemKind: string; -} - -export const ItemKindChip: React.FC = ({ itemKind }) => { - const labelMap: Record = { - "compound": "Compound", - "gene_cluster": "BGC", - }; - - const colorMap: Record = { - "compound": "primary", - "gene_cluster": "secondary", - }; - - return ( - - ); -} diff --git a/src/client/src/components/ScoreBar.tsx b/src/client/src/components/ScoreBar.tsx deleted file mode 100644 index 6471fa0..0000000 --- a/src/client/src/components/ScoreBar.tsx +++ /dev/null @@ -1,64 +0,0 @@ -import React from "react"; -import Box from "@mui/material/Box"; -import Stack from "@mui/material/Stack"; -import Tooltip from "@mui/material/Tooltip"; -import Typography from "@mui/material/Typography"; -import { alpha } from "@mui/material/styles"; - -type ScoreBarProps = { - score: number; // value between 0 and 1 - getTooltipTitle: (value: number) => string; - getScoreColor: (theme: any, value: number) => string; - tooltipPosition?: 
"top" | "bottom" | "left" | "right"; - width?: number | string; - height?: number | string; -} - -export const ScoreBar: React.FC = ({ - score, - getTooltipTitle, - getScoreColor, - tooltipPosition = "top", - width = "100%", - height = 16 -}) => { - const value = Math.max(0, Math.min(1, score)); - - return ( - - - - { - const t = theme.vars || theme; - return { - height: height, - width: width, - flexShrink: 0, - borderRadius: 999, - overflow: "hidden", - backgroundColor: alpha("#000000", 0.1), - ...theme.applyStyles("dark", { backgroundColor: alpha("#ffffff", 0.1) }) - }}} - > - { - const t = theme.vars || theme; - const barColor = getScoreColor(t, value); - return { - width: `${value * 100}%`, - height: "100%", - backgroundColor: barColor, - transition: "width 200ms ease-out", - }; - }} - /> - - - {(value * 100).toFixed(1)}% - - - - - ) -} diff --git a/src/client/src/components/SvgViewer.tsx b/src/client/src/components/SvgViewer.tsx deleted file mode 100644 index 9f9f696..0000000 --- a/src/client/src/components/SvgViewer.tsx +++ /dev/null @@ -1,230 +0,0 @@ -import React from "react"; -import Box from "@mui/material/Box"; -import Paper from "@mui/material/Paper"; -import Slider from "@mui/material/Slider"; -import Stack from "@mui/material/Stack"; -import Tooltip from "@mui/material/Tooltip"; -import Typography from "@mui/material/Typography"; -import AddIcon from "@mui/icons-material/Add"; -import RemoveIcon from "@mui/icons-material/Remove"; -import FitScreenIcon from "@mui/icons-material/FitScreen"; -import DownloadIcon from "@mui/icons-material/Download"; - -export interface SvgViewerProps { - svg: string; // raw SVG string - initialZoom?: number; // initial zoom level (default is 1) - minZoom?: number; - maxZoom?: number; - zoomStep?: number; - onZoomChange?: (zoom: number) => void; - onElementClick?: (info: { - targetId: string | null; - clientX: number; - clientY: number; - }) => void; - height?: number | string; // optional fixed height - 
downloadFileName?: string; -} - -export const SvgViewer: React.FC = ({ - svg, - initialZoom = 1, - minZoom = 0.25, - maxZoom = 4, - zoomStep = 0.25, - onZoomChange, - onElementClick, - height = 400, - downloadFileName = "image.svg", -}) => { - const [zoom, setZoom] = React.useState(initialZoom); - const [isPanning, setIsPanning] = React.useState(false); - const [pan, setPan] = React.useState({ x: 0, y: 0 }); - const [panStart, setPanStart] = React.useState<{ x: number; y: number } | null>(null); - const [panOrigin, setPanOrigin] = React.useState<{ x: number; y: number }>({ x: 0, y: 0 }); - - const containerRef = React.useRef(null); - - const handleFit = () => { - // naive fit: reset zoom and pan - setZoom(initialZoom); - setPan({ x: 0, y: 0 }); - setPanOrigin({ x: 0, y: 0 }); - } - - // Reset zoom and pan whenever SVG changes - React.useEffect(() => { - handleFit(); - }, [svg, initialZoom]); - - // Notify parent about zoom changes - React.useEffect(() => { - onZoomChange?.(zoom); - }, [zoom, onZoomChange]) - - const clampZoom = (z: number) => Math.min(maxZoom, Math.max(minZoom, z)); - - const handleZoomIn = () => { - setZoom((z) => clampZoom(z + zoomStep)); - } - - const handleZoomOut = () => { - setZoom((z) => clampZoom(z - zoomStep)); - } - - const handleZoomSliderChange = (event: Event, value: number | number[]) => { - if (typeof value === "number") { - setZoom(value); - } - } - - const handleMouseDown = (e: React.MouseEvent) => { - if (e.button !== 0) return; - setIsPanning(true); - setPanStart({ x: e.clientX, y: e.clientY }); - setPanOrigin({ x: pan.x, y: pan.y }); - } - - const handleMouseMove = (e: React.MouseEvent) => { - if (!isPanning || !panStart) return; - const dx = e.clientX - panStart.x; - const dy = e.clientY - panStart.y; - setPan({ x: panOrigin.x + dx, y: panOrigin.y + dy }); - } - - const endPan = () => { - setIsPanning(false); - setPanStart(null); - } - - const handleClick = (e: React.MouseEvent) => { - if (!onElementClick) return; - - // Do not 
treat clicks that were actual pans as element clicks - if (isPanning || panStart !== null) return; - - const target = e.target as HTMLElement | null; - const targetId = target?.id ?? null; - - onElementClick({ - targetId, - clientX: e.clientX, - clientY: e.clientY, - }); - } - - const handleDownload = () => { - const blob = new Blob([svg], { type: "image/svg+xml;charset=utf-8" }); - const url = URL.createObjectURL(blob); - const link = document.createElement("a"); - link.href = url; - link.download = downloadFileName; - document.body.appendChild(link); - link.click(); - document.body.removeChild(link); - URL.revokeObjectURL(url); - } - - return ( - - - {/* Toolbar */} - `1px solid ${theme.palette.divider}`, - }} - > - - - { - // prevent zooming out beyond minZoom - if (zoom <= minZoom) { return; } - handleZoomOut(); - }} - /> - - - - { - // prevent zooming in beyond maxZoom - if (zoom >= maxZoom) { return; } - handleZoomIn(); - }} - /> - - - - - - - - - - {Math.round(zoom * 100)}% - - - - - {/* Viewport */} - theme.palette.background.default, - userSelect: isPanning ? 
"none" : "auto", - }} - > - - - - - ) -} diff --git a/src/client/src/components/ViewEmbeddingSpace.tsx b/src/client/src/components/ViewEmbeddingSpace.tsx deleted file mode 100644 index 1a2a5a9..0000000 --- a/src/client/src/components/ViewEmbeddingSpace.tsx +++ /dev/null @@ -1,392 +0,0 @@ -import React from "react"; -import Box from "@mui/material/Box"; -import Alert from "@mui/material/Alert"; -import Tooltip from "@mui/material/Tooltip"; -import Typography from "@mui/material/Typography"; -import CircularProgress from "@mui/material/CircularProgress"; -import DownloadIcon from "@mui/icons-material/Download"; -import { ScatterChart } from "@mui/x-charts/ScatterChart"; -import Stack from "@mui/material/Stack"; -import ToggleButton from "@mui/material/ToggleButton"; -import ToggleButtonGroup from "@mui/material/ToggleButtonGroup"; -import Divider from "@mui/material/Divider"; -import { useNotifications } from "../components/NotificationProvider"; -import { EmbeddingVisualizationType, Session, SessionItem } from "../features/session/types"; -import { EmbeddingPoint } from "../features/views/types"; -import { getEmbeddingSpace } from "../features/views/api"; - -const labelMap: Record = { - "compound": "Compound", - "gene_cluster": "BGC", -} - -interface ViewEmbeddingSpaceProps { - session: Session; - setSession: (updated: (prev: Session) => Session) => void; -} - -export const ViewEmbeddingSpace: React.FC = ({ - session, - setSession, -}) => { - const { pushNotification } = useNotifications(); - - const [points, setPoints] = React.useState(null); - const [loading, setLoading] = React.useState(false); - const [error, setError] = React.useState(null); - - // Container ref + size for square embedding space - const embeddingSpaceContainerRef = React.useRef(null); - const [embeddingSpaceSize, setEmbeddingSpaceSize] = React.useState(null); - - // Map item ID -> name for tooltips/labels - const parentById = React.useMemo(() => { - const map = new Map(); - for (const item of 
session.items) { - map.set(item.id, item); - } - return map; - }, [session.items]); - - const embeddingMethod = session.settings.embeddingVisualizationType; - - // Map child ID -> score for tooltips - const scoreByChildId = React.useMemo(() => { - const map = new Map(); - for (const item of session.items) { - for (const fp of item.retrofingerprints || []) { - map.set(fp.id, fp.score); - } - } - return map; - }, [session.items]); - - // Construct scatter series from points - const scatterSeries = React.useMemo(() => { - if (!points) return []; - - // kind -> array of points for that kin - const byKind = new Map(); - - for (const p of points) { - const parent = parentById.get(p.parent_id) ?? null; - const parentName = parent ? parent.name : "unknown"; - const childIds = parent?.retrofingerprints.map((fp) => fp.id) || []; - const childIdx = childIds.indexOf(p.child_id); - const key = p.kind || "unknown"; - - // Initialize kind entry if not present - if (!byKind.has(key)) { - byKind.set(key, { label: key, data: [] }); - } - - // Add point to appropriate kind series - byKind.get(key)!.data.push({ - x: p.x, - y: p.y, - id: p.child_id, - name: `Readout ${childIdx >= 0 ? `${childIdx + 1}` : ""} ${parentName} `, - }); - } - - return Array.from(byKind.values()); - }, [points, parentById]); - - // Download scatter series as TSV - const handleDownloadEmbeddingData = () => { - if (!points || points.length === 0) return; - - const header = ["id", "kind", "nane", "x", "y"]; - const rows = points.map((p) => { - const parent = parentById.get(p.parent_id) ?? null; - const parentName = parent ? parent.name : "unknown"; - const childIds = parent?.retrofingerprints.map((fp) => fp.id) || []; - const childIdx = childIds.indexOf(p.child_id); - return [ - p.child_id, - p.kind, - `Readout ${childIdx >= 0 ? 
`${childIdx + 1}` : ""} ${parentName} `, - p.x.toString(), - p.y.toString(), - ]; - }); - const tsvContent = [header, ...rows] - .map((row) => row.join("\t")) - .join("\n"); - const blob = new Blob([tsvContent], { type: "text/tab-separated-values" }); - const url = URL.createObjectURL(blob); - const link = document.createElement("a"); - link.href = url; - link.download = `embedding_space_${session.sessionId}.tsv`; - document.body.appendChild(link); - link.click(); - document.body.removeChild(link); - URL.revokeObjectURL(url); - }; - - // Changes only when the contents of items changes, not just because of polling - const itemsKey = React.useMemo(() => { - if (!session.items || session.items.length === 0) return ""; - - return session.items - .map((item) => { - const fpIds = (item.retrofingerprints || []).map((fp) => fp.id).join(","); - return `${item.id}:${item.status}:${item.updatedAt}:${fpIds}`; - }) - .join("|") - }, [session.items]); - - // Filtered items with retrofingerprints - const itemsWithFingerprints = React.useMemo(() => { - return session.items?.filter((item) => item.retrofingerprints && item.retrofingerprints.length > 0) || []; - }, [session.items]); - - // Fetch embedding space points when session changes - React.useEffect(() => { - // If less than 2 items, no embedding space - if (!itemsWithFingerprints || itemsWithFingerprints.length < 2) { - setPoints(null); - setLoading(false); - setError(`At least three readouts are required to view the embedding space.`); - return; - } - - let cancelled = false; - - setLoading(true); - - getEmbeddingSpace(session.sessionId, itemsWithFingerprints, embeddingMethod) - .then((pts) => { - if (cancelled) return; - setPoints(pts); - }) - .catch((err) => { - if (cancelled) return; - pushNotification(`Failed to load embedding space: ${err.message}`, "error"); - setError(err.message); - }) - .finally(() => { - if (!cancelled) { - setLoading(false); - } - }); - - return () => { - cancelled = true; - } - }, [itemsKey, 
embeddingMethod]); // rerun when items change or method toggles - - // Adjust embedding space size on container resize - React.useLayoutEffect(() => { - if (typeof ResizeObserver === "undefined") return; - - const element = embeddingSpaceContainerRef.current; - if (!element) return; - - const observer = new ResizeObserver((entries) => { - for (const entry of entries) { - const width = entry.contentRect.width; - setEmbeddingSpaceSize((prev) => (prev === width ? prev : width)); - } - }); - - observer.observe(element); - return () => observer.disconnect(); - }, [points]) - - // No items at all - if (!itemsWithFingerprints || itemsWithFingerprints.length === 0) { - return ( - - - No items in this session yet. Add compounds or gene clusters to see the embedding space. - - - ) - } - - // Calculate axes limits and add some padding - const axisLimits = React.useMemo(() => { - if (!points || points.length === 0) { - return { xMin: undefined, xMax: undefined, yMin: undefined, yMax: undefined }; - } - - let minX = points[0].x; - let maxX = points[0].x; - let minY = points[0].y; - let maxY = points[0].y; - for (const p of points) { - if (p.x < minX) minX = p.x; - if (p.x > maxX) maxX = p.x; - if (p.y < minY) minY = p.y; - if (p.y > maxY) maxY = p.y; - } - - const xRange = maxX - minX; - const yRange = maxY - minY; - - // Use the larger range for both axes - const baseRange = Math.max(xRange, yRange) || 1; // avoid 0 - const paddingFactor = 0.2; // 20% padding - const fullRange = baseRange * (1 + paddingFactor); - - const xCenter = (minX + maxX) / 2; - const yCenter = (minY + maxY) / 2; - - const halfRange = fullRange / 2; - - return { - xMin: xCenter - halfRange, - xMax: xCenter + halfRange, - yMin: yCenter - halfRange, - yMax: yCenter + halfRange, - }; - }, [points]); - - const handleEmbeddingToggle = (_: React.MouseEvent, next: EmbeddingVisualizationType) => { - if (!next || next === embeddingMethod) return; - setSession((prev) => ({ - ...prev, - settings: { - 
...prev.settings, - embeddingVisualizationType: next, - }, - })); - }; - - return ( - - {loading ? ( - - - - ) : error ? ( - - - {error} - - - ) : !loading && !error && points && points.length > 0 ? ( - - - - - - - - PCA - UMAP - - - - {embeddingSpaceSize && ( - ({ - label: labelMap[group.label] || group.label, - data: group.data, - markerSize: 5, - valueFormatter: (_value, context) => { - const idx = context.dataIndex; - const point = group.data[idx]; - const score = scoreByChildId.get(point.id); - return score != null - ? `${point.name} (score ${(score * 100).toFixed(1)}%)` - : point.name; - }, - }))} - xAxis={[{ - label: "Dimension 1", - disableLine: false, - disableTicks: true, - min: axisLimits.xMin, - max: axisLimits.xMax, - }]} - yAxis={[{ - label: "Dimension 2", - disableLine: false, - disableTicks: true, - min: axisLimits.yMin, - max: axisLimits.yMax, - }]} - grid={{ horizontal: false, vertical: false }} - sx={{ - // Hide any remaining tick labels just in case - "& .MuiChartsAxis-tickLabel": { - display: "none", - }, - }} - /> - )} - - - - - - - - How to read this plot - - - Each point represents a biosynthetic fingerprint readout. Distances reflect structural similarity in the chosen - embedding space (PCA or UMAP). Switch methods to explore how local neighborhoods change with a linear vs. - non-linear reduction. - - - Hover a point to see which import it belongs to and its score. Points from compounds and BGCs are grouped - separately so mixed datasets stay interpretable. - - - - ) : ( - - No embedding points to display. 
- - )} - - ) -} diff --git a/src/client/src/components/ViewMsa.tsx b/src/client/src/components/ViewMsa.tsx deleted file mode 100644 index 00c45f5..0000000 --- a/src/client/src/components/ViewMsa.tsx +++ /dev/null @@ -1,1102 +0,0 @@ -import React from "react"; -import Box from "@mui/material/Box"; -import Chip from "@mui/material/Chip"; -import Stack from "@mui/material/Stack"; -import Tooltip from "@mui/material/Tooltip"; -import Typography from "@mui/material/Typography"; -import Button from "@mui/material/Button"; -import ZoomOutIcon from "@mui/icons-material/ZoomOut"; -import ZoomInIcon from "@mui/icons-material/ZoomIn"; -import RefreshIcon from "@mui/icons-material/Refresh"; -import PaletteIcon from "@mui/icons-material/Palette"; -import VisibilityOffIcon from "@mui/icons-material/VisibilityOff"; -import SettingsIcon from "@mui/icons-material/Settings"; -import DownloadIcon from "@mui/icons-material/Download"; -import { useNotifications } from "./NotificationProvider"; -import { Session, MsaSettings, MsaState, MsaSequence, PrimarySequence } from "../features/session/types"; -import { runMsa } from "../features/views/api"; -import { DialogColorPalette } from "./DialogColorPalette"; -import { DialogMsaSettings } from "./DialogMsaSettings"; - -// Imports for dragging and dropping rows and motifs -import { DndContext, DragEndEvent } from "@dnd-kit/core"; -import { - SortableContext, - useSortable, - verticalListSortingStrategy, - horizontalListSortingStrategy, -} from "@dnd-kit/sortable"; -import { CSS } from "@dnd-kit/utilities"; -import DragIndicatorIcon from "@mui/icons-material/DragIndicator"; - -export const PROTECTED_NAME_TO_CODE: Record = { - ALANINE: "ALA", - CYSTEINE: "CYS", - ASPARTICACID: "ASP", - GLUTAMICACID: "GLU", - PHENYLALANINE: "PHE", - GLYCINE: "GLY", - HISTIDINE: "HIS", - ISOLEUCINE: "ILE", - LYSINE: "LYS", - LEUCINE: "LEU", - METHIONINE: "MET", - ASPARAGINE: "ASN", - PROLINE: "PRO", - GLUTAMINE: "GLN", - ARGININE: "ARG", - SERINE: "SER", - 
THREONINE: "THR", - VALINE: "VAL", - TRYPTOPHAN: "TRP", - TYROSINE: "TYR", -}; - - -export const makeToDisplayName = (protectedNameToCode: Record) => { - const norm = (s: string) => s.replace(/[^a-z0-9]/gi, "").toUpperCase(); - - // normalize protected names + reserve protected codes - const prot = new Map( - Object.entries(protectedNameToCode).map(([k, v]) => [norm(k), norm(v)]) - ); - const reserved = new Set(Array.from(prot.values())); // e.g. ALA, GLY - const used = new Set(reserved); // block others from taking them - const cache = new Map(); // per-name stability - - const candidates = (s: string) => { - const out: string[] = []; - if (s.length >= 3) { - out.push(s.slice(0, 3)); // ABC - for (let i = 3; i < s.length; i++) out.push(s[0] + s[1] + s[i]); // AB? - for (let i = 2; i < s.length; i++) out.push(s[0] + s[i - 1] + s[i]); // A?? - } - if (s.length >= 2) out.push(s.slice(0, 2)); // AB - if (s.length >= 1) out.push(s[0]); // A - // de-dupe in order - const seen = new Set(); - return out.filter(c => c.length <= 3 && !seen.has(c) && (seen.add(c), true)); - }; - - return (name: string | null): string | null => { - if (!name) return null; - const s = norm(name); - if (!s) return null; - - const hit = cache.get(s); - if (hit) return hit; - - // ONLY protected full names get protected 3-letter codes - const canonical = prot.get(s); - if (canonical) { - cache.set(s, canonical); - return canonical; - } - - // don’t let non-protected names steal reserved AA codes - for (const c of candidates(s)) { - if (!used.has(c)) { - used.add(c); - cache.set(s, c); - return c; - } - } - return null; - }; -}; - -const toHex = (v: number) => v.toString(16).padStart(2, "0"); - -const hslToRgb = (h: number, s: number, l: number) => { - // h: 0–360, s/l: 0–1 - const c = (1 - Math.abs(2 * l - 1)) * s; - const x = c * (1 - Math.abs(((h / 60) % 2) - 1)); - const m = l - c / 2; - const pick = (hp: number) => - hp < 60 ? [c, x, 0] : - hp < 120 ? [x, c, 0] : - hp < 180 ? 
[0, c, x] : - hp < 240 ? [0, x, c] : - hp < 300 ? [x, 0, c] : - [c, 0, x]; - const [r1, g1, b1] = pick(h); - return [ - Math.round((r1 + m) * 255), - Math.round((g1 + m) * 255), - Math.round((b1 + m) * 255), - ]; -}; - -const normalizeColor = (raw: string | undefined | null) => { - if (!raw) return "#f5f5f5"; - const c = raw.trim(); - - // Already hex (#rgb, #rrggbb, #rrggbbaa) - if (/^#([0-9a-f]{3}|[0-9a-f]{6}|[0-9a-f]{8})$/i.test(c)) return c; - - // rgba()/rgb() - const rgba = c.match(/^rgba?\(\s*([\d.]+)\s*,\s*([\d.]+)\s*,\s*([\d.]+)(?:\s*,\s*([\d.]+))?\s*\)$/i); - if (rgba) { - const [r, g, b, aRaw] = rgba.slice(1).map(Number); - const a = isNaN(aRaw) ? 1 : Math.max(0, Math.min(1, aRaw)); - // composite over white to avoid transparency issues - const blend = (v: number) => Math.round((1 - a) * 255 + a * v); - return `#${toHex(blend(r))}${toHex(blend(g))}${toHex(blend(b))}`; - } - - // hsla()/hsl() - const hsla = c.match(/^hsla?\(\s*([\d.]+)(?:deg)?\s*,\s*([\d.]+)%\s*,\s*([\d.]+)%(?:\s*,\s*([\d.]+))?\s*\)$/i); - if (hsla) { - const h = Number(hsla[1]); - const s = Number(hsla[2]) / 100; - const l = Number(hsla[3]) / 100; - const a = hsla[4] === undefined ? 
1 : Math.max(0, Math.min(1, Number(hsla[4]))); - const [r, g, b] = hslToRgb(h, s, l); - const blend = (v: number) => Math.round((1 - a) * 255 + a * v); - return `#${toHex(blend(r))}${toHex(blend(g))}${toHex(blend(b))}`; - } - - // Fallback - return "#f5f5f5"; -}; - -const escapeSvgText = (value: string) => - value - .replace(/&/g, "&") - .replace(//g, ">") - .replace(/"/g, """) - .replace(/'/g, "'"); - -interface SortableRowProps { - row: MsaSequence; - labelWidth: number; - motifWidth: number; - zoom: number; - centerId: string | null; - session: Session; - onSetCenter: (id: string) => void; - onHideRow: (id: string) => void; - children: React.ReactNode; // the motifs -} - -const SortableRow: React.FC = ({ - row, - labelWidth, - centerId, - motifWidth, - zoom, - session, - onSetCenter, - onHideRow, - children, -}) => { - const rowSortableId = `row:${row.id}`; - const { attributes, listeners, setNodeRef, transform, transition } = useSortable({ id: rowSortableId }); - - const isCenter = centerId === row.id; - - return ( - <> - onSetCenter(row.id as string)} - sx={{ - transform: transform ? CSS.Transform.toString(transform) : undefined, - transition, - m: 0, - p: 0, - height: 20, - fontWeight: 600, - zIndex: 100, - cursor: "pointer", - display: "flex", - alignItems: "center", - justifyContent: "space-between", - px: 1, - borderRadius: 1, - backgroundColor: isCenter ? "warning.light" : "background.paper", - border: isCenter ? "1px solid" : "1px solid transparent", - borderColor: isCenter ? 
"warning.main" : "transparent", - }} - > - e.stopPropagation()} // don't trigger center selection on drag - > - - - - - - {row.name || row.id} - - - - {isCenter && ( - - )} - { - e.stopPropagation(); - onHideRow(row.id as string); - }} - /> - - - - - {/* Motifs */} - {children} - - {/* Row line */} - - - ) -} - -interface SortableMotifCellProps { - id: string; - children: React.ReactNode; -} - -const SortableMotifCell: React.FC = ({ id, children }) => { - const { attributes, listeners, setNodeRef, transform, transition } = useSortable({ - id, - animateLayoutChanges: () => false, - }) - - return ( - - {children} - - ) -} - -interface ViewMsaProps { - session: Session; - setSession: (updated: (prev: Session) => Session) => void; -} - -export const ViewMsa: React.FC = ({ - session, - setSession, -}) => { - const { pushNotification } = useNotifications(); - - const [zoom, setZoom] = React.useState(1); - const handleZoomIn = () => setZoom(z => Math.min(z + 0.1, 3)); - const handleZoomOut = () => setZoom(z => Math.max(z - 0.1, 0.5)); - const [colorPaletteDialogOpen, setColorPaletteDialogOpen] = React.useState(false); - const [msaSettingsDialogOpen, setMsaSettingsDialogOpen] = React.useState(false); - - const toDisplayName = React.useMemo( - () => makeToDisplayName(PROTECTED_NAME_TO_CODE), - [session.sessionId] - ) - - const handleColorPaletteSave = (newMap: Record) => { - setSession(prev => ({ - ...prev, - settings: { - ...prev.settings, - motifColorPalette: newMap, - }, - })); - pushNotification("Palette saved", "success"); - }; - - const handleMsaSettingsSave = (newSettings: MsaSettings) => { - setSession(prev => ({ - ...prev, - settings: { - ...prev.settings, - msaSettings: newSettings, - }, - })); - pushNotification("MSA settings saved", "success"); - } - - const emptyMsaState = React.useCallback((): MsaState => ({ - aligned: false, - centerId: null, - sequences: [], - }), []); - - const padSequence = React.useCallback( - (sequence: PrimarySequence["sequence"], 
targetLength: number, baseId: string) => { - // Always clone, so callers never share references - const seq = sequence.map((m, idx ) => ({ - ...m, - // keep existing cellKey if present, otherwise initialize - cellKey: (m as any).cellKey ?? `${baseId}-motif-${m.id ?? "idx"}-${idx}`, - })); - - if (targetLength <= 0) return []; - if (seq.length > targetLength) { - return seq.slice(0, targetLength); - } - if (seq.length >= targetLength) return seq; - - const paddingNeeded = targetLength - seq.length; - const paddingMotifs = Array.from({ length: paddingNeeded }, (_, i) => ({ - id: `pad-${baseId}-${i}`, - name: null, - displayName: null, - tags: [], - smiles: null, - morganfingerprint2048r2: null, - // Give pads a unique cellKey as well - cellKey: `pad-${baseId}-${i}`, - })); - - return [...seq, ...paddingMotifs]; - }, - [] - ); - - const normalizeMsaState = React.useCallback(() => { - setSession(prev => { - const prevState = prev.msaState ?? emptyMsaState(); - const order = new Map(prevState.sequences.map((seq, idx) => [seq.id, idx])); - - const collected: MsaSequence[] = []; - prev.items.forEach(item => { - item.primarySequences.forEach(ps => { - if (ps.sequence.length === 0) return; - const seqId = `${item.id}::${ps.id}`; - const existing = prevState.sequences.find(s => s.id === seqId); - const baseSeq = existing?.sequence ?? ps.sequence; - collected.push({ - id: seqId, - itemId: item.id, - primarySequenceId: ps.id, - name: item.name, - // sequence: existing?.sequence ?? ps.sequence, - sequence: baseSeq.map((m, idx) => ({ - ...m, - // keep existing cellKey if present, otherwise initialize - cellKey: (m as any).cellKey ?? `${seqId}-motif-${m.id ?? "idx"}-${idx}`, - })), - hidden: existing?.hidden ?? false, - }); - }); - }); - - // Preserve prior order; new sequences append - collected.sort((a, b) => { - const idxA = order.has(a.id) ? order.get(a.id)! : Number.MAX_SAFE_INTEGER; - const idxB = order.has(b.id) ? order.get(b.id)! 
: Number.MAX_SAFE_INTEGER; - return idxA - idxB; - }); - - if (collected.length === 0) { - const nextState = emptyMsaState(); - if (JSON.stringify(prevState) === JSON.stringify(nextState)) { - return prev; - } - return { ...prev, msaState: nextState }; - } - - const maxLength = Math.max(...collected.map(seq => seq.sequence.length), 0); - const padded = collected.map(seq => ({ - ...seq, - sequence: padSequence(seq.sequence, maxLength, seq.id), - })); - - const nextCenter = prevState.centerId && padded.some(s => s.id === prevState.centerId) - ? prevState.centerId - : null; - - const nextState: MsaState = { - ...prevState, - sequences: padded, - centerId: nextCenter, - }; - - // Lightweight equality check - const sameLength = - prevState.sequences.length === nextState.sequences.length && - prevState.centerId === nextState.centerId && - prevState.aligned === nextState.aligned; - if (sameLength) { - let identical = true; - for (let i = 0; i < nextState.sequences.length; i++) { - const prevSeq = prevState.sequences[i]; - const nextSeq = nextState.sequences[i]; - if (!prevSeq || prevSeq.id !== nextSeq.id || prevSeq.hidden !== nextSeq.hidden || prevSeq.name !== nextSeq.name) { - identical = false; - break; - } - if (prevSeq.sequence.length !== nextSeq.sequence.length) { - identical = false; - break; - } - if (JSON.stringify(prevSeq.sequence) !== JSON.stringify(nextSeq.sequence)) { - identical = false; - break; - } - } - if (identical) return prev; - } - - return { ...prev, msaState: nextState }; - }); - }, [emptyMsaState, padSequence, setSession]); - - const itemsSignature = React.useMemo(() => { - // Only include fields that matter for which sequences exist / their raw sequences. - // Avoids firing when backend just updates job status, timestamps, etc. 
- return session.items - .map(item => { - const seqSig = item.primarySequences - .map(ps => `${ps.id}:${ps.sequence.length}`) - .join("|"); - return `${item.id}:${seqSig}`; - }) - .join(";"); - }, [session.items]); - - React.useEffect(() => { - normalizeMsaState(); - }, [normalizeMsaState, itemsSignature]); // itemsSignature instead of session.items - - const msaState = session.msaState ?? emptyMsaState(); - const msa = msaState.sequences; - const hiddenIds = new Set(msa.filter(seq => seq.hidden).map(seq => seq.id)); - const centerId = msaState.centerId; - const [aligning, setAligning] = React.useState(false); - - const updateMsaState = React.useCallback( - (updater: (state: MsaState) => MsaState) => { - setSession(prev => { - const nextState = updater(prev.msaState ?? emptyMsaState()); - return { ...prev, msaState: nextState }; - }); - }, - [emptyMsaState, setSession] - ); - - const handleDragEnd = (event: DragEndEvent) => { - const { active, over } = event; - if (!over) return; - - const activeId = String(active.id); - const overId = String(over.id); - if (activeId === overId) return; - - // Row-level drag: ids like "row:" - if (activeId.startsWith("row:") && overId.startsWith("row:")) { - const fromRowId = activeId.slice("row:".length); - const toRowId = overId.slice("row:".length); - - updateMsaState(state => { - const sequences = [...state.sequences]; - const fromIndex = sequences.findIndex(s => s.id === fromRowId); - const toIndex = sequences.findIndex(s => s.id === toRowId); - if (fromIndex === -1 || toIndex === -1) return state; - - const [moved] = sequences.splice(fromIndex, 1); - sequences.splice(toIndex, 0, moved); - - return { - ...state, - aligned: false, // manual reorder breaks alignment - sequences, - } - }); - - return; - } - - // Motif level drag depends on motif cellKey ids - // Find which row cells belong to via search - updateMsaState(state => { - let rowIndex = -1; - let fromCol = -1; - let toCol = -1; - - state.sequences.forEach((row, rIdx) 
=> { - const seq = row.sequence as any[]; - const aIdx = seq.findIndex(m => m.cellKey === activeId); - const oIdx = seq.findIndex(m => m.cellKey === overId); - if (aIdx !== -1 && oIdx !== -1) { - rowIndex = rIdx; - fromCol = aIdx; - toCol = oIdx; - } - }) - - if (rowIndex === -1 || fromCol === -1 || toCol === -1) return state; - - const row = state.sequences[rowIndex]; - const seq = [...row.sequence]; - - const [movedMotif] = seq.splice(fromCol, 1); - seq.splice(toCol, 0, movedMotif); - - const newSequences = state.sequences.map((s, idx) => - idx === rowIndex ? { ...s, sequence: seq } : s - ); - - return { - ...state, - aligned: false, // manual reorder breaks alignment - sequences: newSequences, - }; - }) - - } - - const handleHideRow = (id: string) => { - updateMsaState(state => ({ - ...state, - centerId: state.centerId === id ? null : state.centerId, - sequences: state.sequences.map(seq => - seq.id === id ? { ...seq, hidden: true } : seq - ), - })); - }; - - const handleResetHidden = () => { - updateMsaState(state => ({ - ...state, - sequences: state.sequences.map(seq => ({ ...seq, hidden: false })), - })); - }; - - const handleSetCenter = (id: string) => { - updateMsaState(state => ({ - ...state, - centerId: id, - })); - }; - - if (msa.length === 0) { - return ( - - No primary sequences available in session items to display MSA. - - ); - } - - const visibleRows = msa.filter(row => !hiddenIds.has(row.id as string)); - const msaLength = visibleRows.length > 0 - ? Math.max(...visibleRows.map(row => row.sequence.length)) - : 0; - const motifWidth = 50 * zoom; - const labelWidth = 250; - const colTemplate = `${labelWidth}px repeat(${msaLength}, ${motifWidth}px) 1fr`; - - const stripPads = (seq: PrimarySequence["sequence"]) => - seq.filter(motif => !(motif.id ?? 
"").startsWith("pad-")); - - const handleAlign = async () => { - if (!centerId) { - pushNotification("Please select a center sequence for alignment.", "warning"); - return; - } - - const currentVisible = msa.filter(row => !hiddenIds.has(row.id as string)); - - if (!currentVisible.some(r => r.id === centerId)) { - pushNotification("Center sequence is hidden. Please unhide it before aligning.", "warning"); - return; - } - - if (currentVisible.length < 2) { - pushNotification("At least two sequences must be visible to perform alignment.", "warning"); - return; - } - - // Build a canonical, *ungapped* base sequence per id, - // using msaState if present, falling back to session.items. - const baseById = new Map< - string, - { baseSeq: PrimarySequence["sequence"]; name: string | null; itemId?: string; primarySequenceId?: string } - >(); - - session.items.forEach(item => { - item.primarySequences.forEach(ps => { - const seqId = `${item.id}::${ps.id}`; - const existing = msaState.sequences.find(s => s.id === seqId); - - // Prefer the current msaState sequence (stripped of our pads), - // otherwise fall back to the original primary sequence. - const rawSeq = existing ? stripPads(existing.sequence) : stripPads(ps.sequence); - - baseById.set(seqId, { - // Deep clone: runMsa cannot mutate our state - baseSeq: rawSeq.map(m => ({ ...m })), - name: item.name, - itemId: item.id, - primarySequenceId: ps.id, - }); - }); - }); - - setAligning(true); - try { - const result = await runMsa({ - primarySequences: currentVisible.map(seq => { - const base = baseById.get(seq.id); - if (!base) { - // Shouldn't happen, but be defensive - return { - id: seq.id, - name: seq.name, - sequence: stripPads(seq.sequence).map(m => ({ ...m })), - }; - } - return { - id: seq.id, - name: base.name ?? 
seq.name, - sequence: base.baseSeq.map(m => ({ ...m })), // clone again for safety - }; - }), - centerId, - msaSettings: session.settings.msaSettings, - }); - - const alignedSequences = result.alignedSequences; - const alignedLength = Math.max(...alignedSequences.map(seq => seq.sequence.length), 0); - - updateMsaState(state => { - const alignedIds = new Set(alignedSequences.map(seq => seq.id)); - const nextSequences: MsaSequence[] = []; - - alignedSequences.forEach(seq => { - const base = baseById.get(seq.id); - const existing = state.sequences.find(s => s.id === seq.id); - - nextSequences.push({ - ...(existing ?? {}), - id: seq.id, - itemId: base?.itemId ?? (existing as MsaSequence | undefined)?.itemId ?? seq.id, - primarySequenceId: base?.primarySequenceId ?? (existing as MsaSequence | undefined)?.primarySequenceId ?? seq.id, - name: base?.name ?? seq.name, - // padSequence now clones internally, so this is a fresh array - sequence: padSequence(seq.sequence, alignedLength, seq.id), - hidden: false, - }); - }); - - // Keep previously hidden sequences (not part of this alignment) - state.sequences - .filter(seq => seq.hidden && !alignedIds.has(seq.id)) - .forEach(seq => { - const base = baseById.get(seq.id); - const rawSeq = base ? 
base.baseSeq : stripPads(seq.sequence); - nextSequences.push({ - ...seq, - sequence: padSequence(rawSeq, alignedLength, seq.id), - hidden: true, - }); - }); - - return { - ...state, - aligned: true, - centerId, - sequences: nextSequences, - }; - }); - - pushNotification("Alignment completed successfully.", "success"); - } catch (error) { - pushNotification(`An error occurred during alignment: ${error}`, "error"); - } finally { - setAligning(false); - } - } - - const buildMsaSvg = () => { - const visible = msa.filter(row => !hiddenIds.has(row.id as string)); - if (visible.length === 0) return ""; - - const motifPx = 19.1955; - const rowHeight = 13.8131; - const labelW = 80; - const padding = 10; - const textPadding = 6; - - const svgWidth = padding * 2 + labelW + motifPx * msaLength; - const svgHeight = padding * 2 + rowHeight * visible.length; - const lineSpan = Math.max(0, msaLength - 1) * motifPx; - - const rowsSvg = visible - .map((row, rIdx) => { - const y = padding + rIdx * rowHeight; - const labelText = escapeSvgText(row.name || row.id || "row"); - const lineX1 = padding + labelW; - const lineX2 = lineX1 + lineSpan + motifPx; - const lineY = y + rowHeight / 2; - const cells = row.sequence - .slice(0, msaLength) - .map((motif, cIdx) => { - const x = padding + labelW + cIdx * motifPx; - const isPad = (motif.id ?? "").startsWith("pad-"); - if (isPad) return ""; - const fill = normalizeColor(session.settings.motifColorPalette[motif.name || ""]); - const text = escapeSvgText(toDisplayName(motif.name || null) || "UNK"); - return ` - - - ${text} - `; - }) - .join(""); - return ` - - ${labelText} - ${lineSpan > 0 - ? 
`` - : ""} - ${cells} - `; - }) - .join(""); - - return ` - - - ${rowsSvg} - `; - }; - - const handleDownloadMsaSvg = () => { - const svg = buildMsaSvg(); - if (!svg) { - pushNotification("No visible sequences to download.", "warning"); - return; - } - const blob = new Blob([svg], { type: "image/svg+xml;charset=utf-8" }); - const url = URL.createObjectURL(blob); - const link = document.createElement("a"); - link.href = url; - link.download = `msa_${session.sessionId}.svg`; - document.body.appendChild(link); - link.click(); - document.body.removeChild(link); - URL.revokeObjectURL(url); - }; - - return ( - - - - Click a row label to choose the center sequence to align all other sequences against. Press the Align button to perform the alignment. - Hidden rows are excluded from alignment, but can be reset using the Reset hidden button. - The state of the alignment is saved in the session and can be revisited later. - Any manual changes made to the alignment are local and do not affect the readouts for querying. - - - - - - - setMsaSettingsDialogOpen(true)} - sx={{ - cursor: "pointer", - color: "text.primary", - transform: msaSettingsDialogOpen ? "rotate(180deg)" : "rotate(0deg)", - transition: "transform 0.4s ease", - }} - /> - - {msaState.aligned ? ( - - ) : ( - - )} - - {/* Toolbar */} - - - - - - - - - - setZoom(1)} - sx={{ cursor: "pointer" }} - /> - - - setColorPaletteDialogOpen(true)} - sx={{ cursor: "pointer" }} - /> - - - - - - - - {visibleRows.length > 0 ? ( - - {/* MSA display */} - - - {/* Row-level sortable context (vertical) */} - `row:${row.id}`)} - strategy={verticalListSortingStrategy} - > - - {visibleRows.map((row, rowIndex) => ( - - - (m as any).cellKey)} - strategy={horizontalListSortingStrategy} - > - {row.sequence.map((motif, colIndex) => { - const cellKey = (motif as any).cellKey as string; - const isPad = (motif.id ?? "").startsWith("pad-"); - - return ( - - {!isPad ? 
( - - - - - - ) : ( - - - - )} - - ) - })} - - - - ))} - - - - - - ) : ( - - All sequences are hidden. Reset hidden to show them again. - - )} - - - {/* Color palette dialog */} - setColorPaletteDialogOpen(false)} - colorMap={session.settings.motifColorPalette} - onSave={handleColorPaletteSave} - /> - - {/* MSA settings dialog */} - setMsaSettingsDialogOpen(false)} - settings={session.settings.msaSettings} - onSave={handleMsaSettingsSave} - /> - - ); -} diff --git a/src/client/src/components/Workspace.tsx b/src/client/src/components/Workspace.tsx deleted file mode 100644 index 1e578fc..0000000 --- a/src/client/src/components/Workspace.tsx +++ /dev/null @@ -1,203 +0,0 @@ -import React from "react"; -import Box from "@mui/material/Box"; -import Stack from "@mui/material/Stack"; -import Fade from "@mui/material/Fade"; -import { alpha } from "@mui/material/styles"; -import { Routes, Route, useNavigate } from "react-router-dom"; -import { useNotifications } from "../components/NotificationProvider"; -import { useOverlay } from "../components/OverlayProvider"; -import { Session } from "../features/session/types"; -import { getSession, refreshSession, saveSession } from "../features/session/api"; -import { WorkspaceNavbar } from "./WorkspaceNavbar"; -import { WorkspaceSideMenu } from "./WorkspaceSideMenu"; -import { WorkspaceHeader } from "./WorkspaceHeader"; -import { WorkspaceDashboard } from "./WorkspaceDashboard"; -import { WorkspaceUpload } from "./WorkspaceUpload"; -import { WorkspaceExplore } from "./WorkspaceExplore"; -import { WorkspaceQuery } from "./WorkspaceQuery"; - -export const Workspace: React.FC = () => { - const { showOverlay, hideOverlay } = useOverlay(); - const { pushNotification } = useNotifications(); - const navigate = useNavigate(); - const [loading, setLoading] = React.useState(true); - const [session, setSession] = React.useState(null); - - // Track source of last update to session; prevents race conditions - const lastUpdatedSourceRef = 
React.useRef<"local" | "remote" | null>(null); - - const setSessionLocal = React.useCallback( - (updater: (prev: Session) => Session) => { - lastUpdatedSourceRef.current = "local"; - setSession((prev) => (prev ? updater(prev) : prev)); - }, - [] - ); - - const setSessionRemote = React.useCallback( - (next: Session) => { - lastUpdatedSourceRef.current = "remote"; - setSession(next); - }, - [] - ); - - // Retrieve session from the server on component mount - React.useEffect(() => { - let alive = true; - setLoading(true); - getSession() - .then(sess => { - if (!alive) return; - setSessionRemote(sess); - }) - .catch(err => { - console.error("Error loading session:", err); - navigate("/notfound"); - }) - .finally(() => { if (alive) setLoading(false); }) - return () => { alive = false; } - }, [navigate]) - - // Subscribe to server-sent events - React.useEffect(() => { - if (!session?.sessionId) return; - - let alive = true; - - // Debounce refreshes to avoid spamming if multiple events arrive quickly - let refreshTimer: number | null = null; - const scheduleRefresh = () => { - if (!alive) return; - if (refreshTimer !== null) return; // already scheduled - refreshTimer = window.setTimeout(() => { - refreshTimer = null; - refreshSession(session.sessionId) - .then((fresh) => { - if (!alive) return; - setSessionRemote(fresh); - }) - .catch((err) => { - const msg = err instanceof Error ? err.message : String(err); - pushNotification(`Failed to refresh session: ${msg}`, "error"); - }); - }, 250); // based on chattiness of events - }; - - // Create SSE connection - const SSE_BASE = process.env.REACT_APP_SSE_BASE ?? 
""; - console.log(`SSE connecting to ${SSE_BASE}/api/sessionEvents?sessionId=${session.sessionId}`); - const es = new EventSource( - `${SSE_BASE}/api/sessionEvents?sessionId=${encodeURIComponent(session.sessionId)}` - ); - - es.addEventListener("hello", () => { - // Refresh once on connect so we're in sync - scheduleRefresh(); - }); - - // Backend should emit this when job/item status/results changed - es.addEventListener("item_updated", () => { - scheduleRefresh(); - }); - - // Backend should emit this when session is merged/saved - es.addEventListener("session_merged", () => { - scheduleRefresh(); - }); - - es.addEventListener("keepalive", () => { - // No-op; just to keep connection alive - }); - - es.onopen = () => { - console.log("SSE connection opened for session events"); - }; - - es.onmessage = (e) => { - // Generic message handler; shouldn't be called if specific events are handled - console.debug("SSE message received:", e.data); - }; - - es.onerror = (err) => { - // Fires on transient disconnects too; EventSource will retry automatically - console.debug("SSE transient error; will retry", { - readyState: es.readyState, // 0=CONNECTING, 1=OPEN, 2=CLOSED - }); - // Don't close here; don't spam notifications - }; - - return () => { - alive = false; - if (refreshTimer !== null) window.clearTimeout(refreshTimer); - es.close(); - }; - // }, [session?.sessionId, setSessionRemote, pushNotification]); - }, [session?.sessionId, setSessionRemote]); - - // Overlay follows loading state - React.useEffect(() => { - if (loading) showOverlay(); else hideOverlay(); - }, [loading, showOverlay, hideOverlay]) - - // Auto-save session whenever session changes - React.useEffect(() => { - if (!session) return; - if (lastUpdatedSourceRef.current !== "local") return; // only auto-save if local changes - - // Reset source so we don't double-save - lastUpdatedSourceRef.current = null; - - saveSession(session).catch(err => { - const msg = err instanceof Error ? 
err.message : String(err); - pushNotification(`Failed to save session: ${msg}`, "error"); - }) - }, [session]); - - if (!session && !loading) { - // Hard failure; couldn't load session at all - return null; - } - - // Determine what to show - const showContent = !!session && !loading; - - return ( - - - - ({ - flexGrow: 1, - backgroundColor: theme.vars - ? `rgba(${theme.vars.palette.background.defaultChannel} / 1)` - : alpha(theme.palette.background.default, 1), - overflow: "auto", - })} - > - - - - {/* Actual workspace content fades in once session is ready */} - {session && ( - - - - } /> - } /> - } /> - } /> - - - - )} - - - - - ) -} diff --git a/src/client/src/components/WorkspaceExplore.tsx b/src/client/src/components/WorkspaceExplore.tsx deleted file mode 100644 index 0688b43..0000000 --- a/src/client/src/components/WorkspaceExplore.tsx +++ /dev/null @@ -1,133 +0,0 @@ -import React from "react"; -import Box from "@mui/material/Box"; -import Card from "@mui/material/Card"; -import CardContent from "@mui/material/CardContent"; -import MuiLink from "@mui/material/Link"; -import Stack from "@mui/material/Stack"; -import Tab from "@mui/material/Tab"; -import Tabs from "@mui/material/Tabs"; -import Typography from "@mui/material/Typography"; -import NotificationsRoundedIcon from "@mui/icons-material/NotificationsRounded"; -import { useTheme } from "@mui/material/styles"; -import { Link as RouterLink } from "react-router-dom"; -import { Session } from "../features/session/types"; -import { ViewEmbeddingSpace } from "./ViewEmbeddingSpace"; -import { ViewMsa } from "./ViewMsa"; - -type WorkspaceExploreProps = { - session: Session; - setSession: (updated: (prev: Session) => Session) => void; -} - -type ExploreView = "msa" | "embedding"; - -export const WorkspaceExplore: React.FC = ({ session, setSession }) => { - const theme = useTheme(); - const [view, setView] = React.useState("msa"); - // Helper to switch views - const handleViewChange = (_event: 
React.SyntheticEvent, newValue: ExploreView) => { - setView(newValue); - } - - return ( - - - - - Getting started - - - Items from the  - - Upload tab -   - marked as Ready can used here for analysis. Biosynthetic fingerprints parsed from your imports can be - visualized in a reduced dimensional space and individual clusters can be enriched for annotations. You can - enrich your data further by querying any item against the BioNexus database in the  - - Query tab -  and importing additional hits into your workspace. Keep an eye on  - for updates on your - imports. - - - - - - - - - Explore imports - - - Visualize multiple sequence alignments and embedding spaces of biosynthetic fingerprints for items in your - workspace. You can switch between the different views using the tabs below. - - - - - - - - - {view === "embedding" && ( - - )} - {view === "msa" && ( - - )} - - - - ) -} diff --git a/src/client/src/components/WorkspaceItemCard.tsx b/src/client/src/components/WorkspaceItemCard.tsx deleted file mode 100644 index 08ee430..0000000 --- a/src/client/src/components/WorkspaceItemCard.tsx +++ /dev/null @@ -1,324 +0,0 @@ -import React from "react"; -import Box from "@mui/material/Box"; -import Stack from "@mui/material/Stack"; -import Typography from "@mui/material/Typography"; -import Checkbox from "@mui/material/Checkbox"; -import Chip from "@mui/material/Chip"; -import IconButton from "@mui/material/IconButton"; -import TextField from "@mui/material/TextField"; -import Tooltip from "@mui/material/Tooltip"; -import ScienceIcon from "@mui/icons-material/Science"; -import BiotechIcon from "@mui/icons-material/Biotech"; -import DeleteIcon from "@mui/icons-material/Delete"; -import CircularProgress from "@mui/material/CircularProgress"; -import EditIcon from "@mui/icons-material/Edit"; -import VisibilityIcon from "@mui/icons-material/Visibility"; -import { SessionItem } from "../features/session/types"; -import { alpha } from "@mui/material/styles"; -import type { Theme 
} from "@mui/material/styles"; -import { ScoreBar } from "./ScoreBar"; -import { ItemKindChip } from "./ItemKindChip"; - -type WorkspaceItemCardProps = { - item: SessionItem; - selected: boolean; - onToggleSelect: (id: string) => void; - onView: (id: string) => void; - onDelete: (id: string) => void; - onRename: (id: string, newName: string) => void; -} - -// Helper to format "X ago" -function formatUpdatedAgo(updatedAt?: number): string { - if (!updatedAt) return "Never updated"; - const now = Date.now(); - const diffMs = now - updatedAt; - if (diffMs < 0) return "just now"; - - const diffSec = Math.floor(diffMs / 1000); - if (diffSec < 5) return "just now"; - if (diffSec < 60) return `${diffSec}s ago`; - - const diffMin = Math.floor(diffSec / 60); - if (diffMin < 60) return `${diffMin}m ago`; - - const diffHours = Math.floor(diffMin / 60); - if (diffHours < 24) return `${diffHours}h ago`; - - const diffDays = Math.floor(diffHours / 24); - return `${diffDays}d ago`; -} - -function getScoreColor(theme: Theme, value: number): string { - const t = theme.vars || theme; - if (value < 0.5) { return t.palette.error.main }; - if (value < 0.9) { return t.palette.warning.main }; - return t.palette.success.main; -} - -function getScoreTooltip(value: number): string { - if (value < 0.5) { return "Low score: results may be incomplete or noisey" }; - if (value < 0.9) { return "Moderate score: results should be fairly reliable" }; - return "High score: results are likely very reliable"; -} - -export const WorkspaceItemCard: React.FC = ({ - item, - selected, - onToggleSelect, - onView, - onDelete, - onRename, -}) => { - const isCompound = item.kind === "compound"; // there are only two types: "compound" and "gene_cluster" - - // Tick every 15s so "X ago" updates - const [, forceTick] = React.useState(0); - React.useEffect(() => { - const id = window.setInterval(() => { - forceTick(n => n + 1); - }, 5000); - return () => { window.clearInterval(id); } - }, []) - - const isQueued = 
item.status === "queued"; - const showSpinner = item.status === "processing"; - const isError = item.status === "error"; - const isDone = item.status === "done"; - - // Rename state - const [isEditing, setIsEditing] = React.useState(false); - const [draftName, setDraftName] = React.useState(item.name); - - // Keep draft in sync when item.name changes from server - React.useEffect(() => { - if (!isEditing) { - setDraftName(item.name); - } - }, [item.name, isEditing]); - - const startEditing = (e: React.MouseEvent) => { - e.stopPropagation(); - setDraftName(item.name); - setIsEditing(true); - } - - const cancelEditing = () => { - setDraftName(item.name); - setIsEditing(false); - } - - const commitEditing = () => { - const trimmed = draftName.trim(); - // No-op if unchanged or empty - if (!trimmed || trimmed === item.name) { - setIsEditing(false); - setDraftName(item.name); - return; - } - onRename(item.id, trimmed); - setIsEditing(false); - } - - const handleNameKeyDown = (e: React.KeyboardEvent) => { - if (e.key === "Enter") { - e.preventDefault(); - commitEditing(); - } else if (e.key === "Escape") { - e.preventDefault(); - cancelEditing(); - } - } - - return ( - onToggleSelect(item.id)} - direction="column" - sx={(theme) => { - const t = theme.vars || theme; - return { - borderRadius: 1, - border: `1px solid ${selected ? t.palette.primary.main : "transparent"}`, - p: 1.5, - display: "flex", - gap: 1.5, - cursor: "pointer", - "&:hover": { boxShadow: 10 }, - backgroundColor: selected ? alpha("#000000", 0.04) : alpha("#000000", 0.02), - ...theme.applyStyles("dark", { backgroundColor: selected ? alpha("#ffffff", 0.06) : alpha("#ffffff", 0.03) }), - } - }} - > - { - const t = theme.vars || theme; - return { - display: "flex", - alignItems: "flex-start", - justifyContent: "space-between", - gap: 1.5, - } - }} - > - - { - e.stopPropagation(); - onToggleSelect(item.id); - }} - /> - - {isCompound ? ( - - ) : ( - - )} - - - - - - {isEditing ? 
( - e.stopPropagation()} - onChange={(e) => setDraftName(e.target.value)} - onKeyDown={handleNameKeyDown} - onBlur={commitEditing} - inputProps={{ - style: { fontSize: "0.875rem", fontWeight: 500 }, - }} - /> - ) : ( - <> - - {item.name} - - {isDone && ( - - )} - - )} - - - - - Status updated {formatUpdatedAgo(item.updatedAt)} - - - - - - {isQueued && ( - - )} - - {showSpinner && ()} - - {isDone && ( - - )} - - {isError && ( - - - - )} - - { - e.stopPropagation(); - onView(item.id); - }} - disabled={!isDone} - > - - - - { - e.stopPropagation(); - onDelete(item.id); - }} - > - - - - - - - {isDone && ( - - {item.retrofingerprints.map((fp, idx) => ( - - - {`Readout ${idx + 1}`} - - - - ))} - - )} - - - ) -} diff --git a/src/client/src/components/WorkspaceQuery.tsx b/src/client/src/components/WorkspaceQuery.tsx deleted file mode 100644 index b4fdd31..0000000 --- a/src/client/src/components/WorkspaceQuery.tsx +++ /dev/null @@ -1,834 +0,0 @@ -import React from "react"; -import Box from "@mui/material/Box"; -import Card from "@mui/material/Card"; -import CardContent from "@mui/material/CardContent"; -import Chip from "@mui/material/Chip"; -import Grid from "@mui/material/Grid"; -import MuiLink from "@mui/material/Link"; -import Tooltip from "@mui/material/Tooltip"; -import Typography from "@mui/material/Typography"; -import Button from "@mui/material/Button"; -import MenuItem from "@mui/material/MenuItem"; -import Select from "@mui/material/Select"; -import FormControl from "@mui/material/FormControl"; -import InputLabel from "@mui/material/InputLabel"; -import TextField from "@mui/material/TextField"; -import InputAdornment from "@mui/material/InputAdornment"; -import Stack from "@mui/material/Stack"; -import CircularProgress from "@mui/material/CircularProgress"; -import Alert from "@mui/material/Alert"; -import { DataGrid, GridColDef, GridPaginationModel } from "@mui/x-data-grid"; -import NotificationsRoundedIcon from "@mui/icons-material/NotificationsRounded"; -import 
UploadFileIcon from "@mui/icons-material/UploadFile"; -import SettingsIcon from "@mui/icons-material/Settings"; -import DownloadIcon from "@mui/icons-material/Download"; -import { useTheme } from "@mui/material/styles"; -import { Link as RouterLink } from "react-router-dom"; -import { useNotifications } from "../components/NotificationProvider"; -import { Session, SessionItem, QuerySettings } from "../features/session/types"; -import { runQuery } from "../features/query/api"; -import type { QueryResult } from "../features/query/types"; -import { runEnrichment } from "../features/views/api"; -import type { EnrichmentResult } from "../features/views/types"; -import { ItemKindChip } from "./ItemKindChip"; -import { importCompoundById } from "../features/jobs/api"; -import { DialogQuerySettings } from "./DialogQuerySettings"; - -interface WorkspaceQueryProps { - session: Session; - setSession: (updated: (prev: Session) => Session) => void; -} - -const formatSource = (value: string | null | undefined) => { - if (!value) return "Unknown"; - if (value.toLowerCase() === "mibig") return "MIBiG"; - if (value.toLowerCase() === "npatlas") return "NPAtlas"; - return value; -} - -const buildExtIdUrl = (source: string | undefined, extId: string) => { - if (source?.toLowerCase() === "mibig") return `https://mibig.secondarymetabolites.org/repository/${extId}/index.html#r1c1`; - if (source?.toLowerCase() === "npatlas") return `https://www.npatlas.org/explore/compounds/${extId}`; - return null; -} - -const columnNameMap: Record = { - identifier: "Readout ID", - type: "Type", - source: "Source", - ext_id: "External ID", - name: "Name", - score: "Similarity", -}; - -export const WorkspaceQuery: React.FC = ({ session, setSession }) => { - const theme = useTheme(); - const { pushNotification } = useNotifications(); - - const [settingsDialogOpen, setSettingsDialogOpen] = React.useState(false); - - const [queryLoading, setQueryLoading] = React.useState(false); - const [queryError, 
setQueryError] = React.useState(null); - const [queryResult, setQueryResult] = React.useState(null); - - const [enrichmentLoading, setEnrichmentLoading] = React.useState(false); - const [enrichmentError, setEnrichmentError] = React.useState(null); - const [enrichmentResult, setEnrichmentResult] = React.useState(null); - - const handleDownloadQueryResults = () => { - if (!queryResult || queryResult.rows.length === 0) return; - - const tsvHeader = queryResult.columns.join("\t"); - const tsvRows = queryResult.rows.map((row) => - queryResult.columns.map((col) => row[col] ?? "").join("\t") - ); - const tsvContent = [tsvHeader, ...tsvRows].join("\n"); - - const blob = new Blob([tsvContent], { type: "text/tab-separated-values" }); - const url = URL.createObjectURL(blob); - const a = document.createElement("a"); - a.href = url; - a.download = `query_results_session_${session.sessionId}.tsv`; - a.click(); - URL.revokeObjectURL(url); - } - - const handleDownloadEnrichmentResults = () => { - if (!enrichmentResult || enrichmentResult.items.length === 0) return; - - const tsvHeader = ["id", "schema", "key", "value", "adjusted_p_value"].join("\t"); - const tsvRows = enrichmentResult.items.map((item) => - [item.id, item.schema, item.key, item.value, item.adjusted_p_value].join("\t") - ); - const tsvContent = [tsvHeader, ...tsvRows].join("\n"); - - const blob = new Blob([tsvContent], { type: "text/tab-separated-values" }); - const url = URL.createObjectURL(blob); - const a = document.createElement("a"); - a.href = url; - a.download = `enrichment_results_session_${session.sessionId}.tsv`; - a.click(); - URL.revokeObjectURL(url); - } - - const [paginationModelEnrichment, setPaginationModelEnrichment] = React.useState({ - pageSize: 10, - page: 0, - }) - - const [rowCount, setRowCount] = React.useState(0); - const [paginationModelQuery, setPaginationModelQuery] = React.useState({ - pageSize: 10, - page: 0, - }) - - const [importCooldown, setImportCooldown] = React.useState(false); - 
const IMPORT_COOLDOWN_MS = 3000; - - // Helper to build deps for import service - const deps = React.useMemo( - () => ({ - setSession, - pushNotification, - sessionId: session.sessionId, - }), - [setSession, session.sessionId] - ) - - // Items that have at least one fingerprint/readout - const itemsWithFingerprints = React.useMemo(() => { - return (session.items || []).filter((item) => item.retrofingerprints && item.retrofingerprints.length > 0); - }, [session.items]); - - const similarityThreshold = session.settings.querySettings?.similarityThreshold ?? 0.7; - const searchSpace = session.settings.querySettings?.searchSpace ?? "only_compounds"; - const [thresholdInput, setThresholdInput] = React.useState(() => Math.round(similarityThreshold * 100).toString()); - - React.useEffect(() => { - setThresholdInput(Math.round(similarityThreshold * 100).toString()); - }, [similarityThreshold]); - - const querySettings = React.useMemo( - () => (session.settings.querySettings || { - similarityThreshold: 0.7, - searchSpace: "only_compounds", - annotationFilters: [], - }), - [similarityThreshold, searchSpace] - ); - - // Flatten retrofingerprints for selection list - const fingerprintOptions = React.useMemo(() => { - const options: { itemId: string; fpId: string; label: string }[] = []; - for (const item of itemsWithFingerprints) { - item.retrofingerprints!.forEach((fp, idx) => { - options.push({ - itemId: item.id, - fpId: fp.id, - label: `${item.name} (readout ${idx + 1})`, - }); - }); - } - return options; - }, [itemsWithFingerprints]); - - const [selectedItemId, setSelectedItemId] = React.useState(() => fingerprintOptions[0]?.itemId || ""); - const [selectedFingerprintId, setSelectedFingerprintId] = React.useState(() => fingerprintOptions[0]?.fpId || ""); - - // Keep selection in sync when options change - React.useEffect(() => { - if (!fingerprintOptions.length) { - setSelectedItemId(""); - setSelectedFingerprintId(""); - return; - } - if (!fingerprintOptions.find((opt) => 
opt.fpId === selectedFingerprintId)) { - setSelectedItemId(fingerprintOptions[0].itemId); - setSelectedFingerprintId(fingerprintOptions[0].fpId); - } - }, [fingerprintOptions, selectedFingerprintId]); - - const handleItemChange = (value: string) => { - setSelectedItemId(value); - const firstFp = fingerprintOptions.find((opt) => opt.itemId === value); - setSelectedFingerprintId(firstFp?.fpId || ""); - }; - - const handleFingerprintChange = (value: string) => { - const match = fingerprintOptions.find((opt) => opt.fpId === value); - if (match) { - setSelectedItemId(match.itemId); - setSelectedFingerprintId(match.fpId); - } - }; - - const handleThresholdChange = (value: string) => { - setThresholdInput(value); - const numeric = Number(value); - if (!Number.isFinite(numeric)) return; - const clamped = Math.min(100, Math.max(0, numeric)); - setSession((prev) => ({ - ...prev, - settings: { - ...prev.settings, - querySettings: { - ...prev.settings.querySettings, - similarityThreshold: clamped / 100, - }, - }, - })); - }; - - const handleQuerySettingsSave = (newSettings: QuerySettings) => { - setSession((prev) => ({ - ...prev, - settings: { - ...prev.settings, - querySettings: newSettings, - }, - })); - setSettingsDialogOpen(false); - }; - - // Second fetch for enrichment analysis for results - const fetchEnrichmentForResults = React.useCallback( - async (result: QueryResult) => { - if (!result.rows || result.rows.length === 0) return; - - try { - setEnrichmentLoading(true); - setEnrichmentError(null); - - // Get selected fingerprint data - const selectedItem = session.items?.find((item) => item.id === selectedItemId); - const selectedFingerprint = selectedItem?.retrofingerprints?.find((fp) => fp.id === selectedFingerprintId); - if (!selectedFingerprint || !selectedFingerprint) { - setEnrichmentError("Selected readout data not found."); - return; - } - - const retrofingerprint512 = selectedFingerprint.retrofingerprint512; - - const data = await runEnrichment({ - 
retrofingerprint512, - querySettings: querySettings, - }); - setEnrichmentResult(data); - } catch (err: any) { - setEnrichmentError(err.message || "Failed to run enrichment analysis"); - pushNotification(`Enrichment analysis failed: ${err.message}`, "error"); - } finally { - setEnrichmentLoading(false); - } - }, - [session.items, selectedItemId, selectedFingerprintId, querySettings, pushNotification] - ) - - // Core fetch function that wires page + pageSize -> limit + offset - const fetchQueryResults = React.useCallback( - async (model: GridPaginationModel): Promise => { - if (!selectedItemId || !selectedFingerprintId) { - pushNotification("Select an item and readout to query.", "warning"); - return null; - } - - // Get the selected fingerprint data - const selectedItem = session.items?.find((item) => item.id === selectedItemId); - const selectedFingerprint = selectedItem?.retrofingerprints?.find((fp) => fp.id === selectedFingerprintId); - if (!selectedFingerprint || !selectedFingerprint) { - pushNotification("Selected readout data not found.", "error"); - return null; - } - - setQueryLoading(true); - setQueryError(null); - - try { - const limit = model.pageSize; - const offset = model.page * model.pageSize; - - const data = await runQuery({ - name: "cross_modal_retrieval", - params: { - retrofingerprint512: selectedFingerprint.retrofingerprint512, - querySettings, - }, - paging: { limit, offset }, - order: { column: "score", dir: "desc" }, - }); - - setQueryResult(data); - - const total = - (data as any).totalCount ?? - offset + data.rows.length + (data.rows.length === limit ? 
1 : 0); - - setRowCount(total); - - return data; - } catch (err: any) { - setQueryError(err.message || "Failed to run query."); - pushNotification(`Query failed: ${err.message}`, "error"); - return null; - } finally { - setQueryLoading(false); - } - }, - [selectedItemId, selectedFingerprintId, session.items] - ) - - // USer clicks "Run query" -> reset to first page and fetch - const handleRunQuery = async () => { - const firstPageModel: GridPaginationModel = { - page: 0, - pageSize: paginationModelQuery.pageSize, - }; - setPaginationModelQuery(firstPageModel); - - const data = await fetchQueryResults(firstPageModel); - - // Only run enrichment when button is clicked - if (data && data.rows && data.rows.length > 0) { - void fetchEnrichmentForResults(data); - } else { - setEnrichmentResult(null); - } - }; - - // DataGrid pagination event -> fetch that page from backend - const handlePaginationModelChange = (newModel: GridPaginationModel) => { - setPaginationModelQuery(newModel); - void fetchQueryResults(newModel); - } - - // Row action handler - const handleRowAction = React.useCallback( - (row: any) => { - if (importCooldown) { - pushNotification("Please wait a moment before importing another item.", "info"); - return; - } - - // Start global cooldown - setImportCooldown(true); - setTimeout(() => setImportCooldown(false), IMPORT_COOLDOWN_MS); - - // Extract necessary info from the row - const type = row.type as string | undefined; - const identifier = row.identifier as number | undefined; - - // Only allow type "compound" for now - if (type !== "compound") { - pushNotification(`Import for item type "${type}" is not supported yet.`, "warning"); - return; - } - - // Check if both are present - if (!type || !identifier) { - pushNotification("Cannot import item: missing type or identifier.", "error"); - return; - } - - // Use importCompoundById to import the compound - importCompoundById(deps, identifier).then((item) => { - if (item) { - pushNotification(`Imported item 
"${item.name}" into your workspace`, "success"); - } else { - pushNotification("Failed to import item", "error"); - } - }).catch((err) => { - const msg = err instanceof Error ? err.message : String(err); - pushNotification(`Import failed: ${msg}`, "error"); - }); - }, - [importCooldown, deps] - ) - - // Define columns based on result columns - const columns: GridColDef[] = React.useMemo(() => { - if (!queryResult) return []; - - // return result.columns.map((col) => { - const baseColumns: GridColDef[] = queryResult.columns.map((col) => { - const base: GridColDef = { - field: col, - headerName: columnNameMap[col] || col, - flex: 1, - sortable: true, - resizable: true, - } - - // For "type" column, render ItemKindChip - if (col === "type") { - base.renderCell = (params) => { - const itemKind = params.value as string | undefined; - if (!itemKind) return null; - return ; - } - } - - // Format source values - if (col === "source") { - base.valueGetter = (value) => { - return formatSource(value as string | undefined); - } - } - - // Render ext_id as clickable outlink - if (col === "ext_id") { - base.renderCell = (params) => { - const extId = params.value as string | undefined; - if (!extId) return null; - - const source = params.row.source as string | undefined; - const url = buildExtIdUrl(source, extId); - - return ( - - {extId} - - ) - } - } - - // Format score as percentage with 2 decimals - if (col === "score") { - base.valueFormatter = (v: number) => { - return (v * 100).toFixed(2) + "%"; - } - } - - return base; - }) - - // Add extra column with Import action button - const actionColumn: GridColDef = { - field: "actions", - headerName: "", - sortable: false, - filterable: false, - resizable: false, - disableColumnMenu: true, - align: "center", - headerAlign: "center", - width: 50, - renderCell: (params) => { - const isDisabled = importCooldown || params.row.type !== "compound"; - - // Only show import button for compounds for now - return ( - params.row.type === 
"compound" ? ( - - - { if (!isDisabled) handleRowAction(params.row) }} - /> - - - ) : null - ) - }, - } - - return [...baseColumns, actionColumn]; - }, [queryResult, handleRowAction, importCooldown]); - - // Map result rows to DataGrid rows with unique IDs - const rows = React.useMemo(() => { - if (!queryResult) return []; - return queryResult.rows.map((row, idx) => ({ - id: idx, - ...row, - })); - }, [queryResult]); - - return ( - - - - - Getting started - - - Pick one of your imported items from the  - - Upload tab - -  and query it against the BioNexus database for compounds and BGCs (biosynthetic gene clusters). The results - data frame shows nearest neighbors based on fingerprint similarity. The enrichment results show overrepresented - annotations among the hits compared to the background database. - You can - insert query hits into your workspace for further analysis in the  - - Explore tab - . Keep an eye on - for updates on your queries. - - - - - - Item - - - - - - - Readout - - - - - handleThresholdChange(e.target.value)} - inputProps={{ min: 0, max: 100, step: 1 }} - InputProps={{ - endAdornment: %, - }} - helperText="Minimum similarity for hits (0-100%)" - disabled={!fingerprintOptions.length} - /> - - - - - setSettingsDialogOpen(true)} - sx={{ - cursor: "pointer", - color: "text.primary", - transform: settingsDialogOpen ? "rotate(180deg)" : "rotate(0deg)", - transition: "transform 0.4s ease", - }} - /> - - {queryLoading && } - - - {!fingerprintOptions.length && ( - - No items with readouts found. Import items in Upload, wait for readouts, then try again. - - )} - - {queryError && {queryError}} - - - - - theme.spacing(2) }}> - - - - - - Query results - - - 0 ? "pointer" : "not-allowed", - }} - /> - - - {rows.length === 0 && !queryLoading ? ( - - {queryResult ? "No results returned." : "Run a query to see results."} - - ) : ( - row.id} - density="compact" - localeText={{ - noRowsLabel: queryResult ? "No rows returned." 
: "Run a query to see results.", - }} - sx={{ - borderRadius: 0, - "& .MuiIconButton-root": { - backgroundColor: "transparent !important", - border: "none !important", - boxShadow: "none", - }, - "& .MuiIconButton-root:hover": { - backgroundColor: "transparent !important", - }, - "& .MuiButtonBase-root": { - boxShadow: "none", - }, - }} - slotProps={{ - basePagination: { - material: { labelDisplayedRows: ({ from, to, count }) => `${from}-${to}` } - }, - filterPanel: { - sx: { - "& .MuiInputLabel-root": { - backgroundColor: "background.paper", - paddingLeft: 0.5, - paddingRight: 0.5, - }, - }, - }, - }} - /> - )} - - - - - - - - - Enrichment results - - - 0 ? "pointer" : "not-allowed", - }} - /> - - - {enrichmentLoading ? ( - - - Running enrichment analysis... - - ) : enrichmentError ? ( - {enrichmentError} - ) : enrichmentResult ? ( - enrichmentResult.items.length > 0 ? ( - ({ - id: item.id, - significant: item.adjusted_p_value < 0.05, - schema: item.schema, - key: item.key, - value: item.value, - adjusted_p_value: item.adjusted_p_value, - }))} - pagination - paginationModel={paginationModelEnrichment} - onPaginationModelChange={setPaginationModelEnrichment} - pageSizeOptions={[10, 25, 50]} - columns={[ - { field: "significant", headerName: "Significant", flex: 1, sortable: true, resizable: true, - valueFormatter: (v: boolean) => v ? "Yes" : "No", - renderCell: (params) => ( - params.value ? 
- : - - ) - }, - { field: "schema", headerName: "Schema", flex: 1, sortable: true, resizable: true, - valueFormatter: (v: string) => v.toUpperCase(), - }, - { field: "key", headerName: "Key", flex: 1, sortable: true, resizable: true, - valueFormatter: (v: string) => v.toUpperCase(), - }, - { field: "value", headerName: "Value", flex: 1, sortable: true, resizable: true, - valueFormatter: (v: string) => v.toUpperCase(), - }, - { field: "adjusted_p_value", headerName: "Adjusted P-value", flex: 1, sortable: true, resizable: true, - valueFormatter: (v: number) => v.toExponential(3), - }, - ]} - sx={{ - borderRadius: 0, - "& .MuiIconButton-root": { - backgroundColor: "transparent !important", - border: "none !important", - boxShadow: "none", - }, - "& .MuiIconButton-root:hover": { - backgroundColor: "transparent !important", - }, - "& .MuiButtonBase-root": { - boxShadow: "none", - }, - }} - slotProps={{ - filterPanel: { - sx: { - "& .MuiInputLabel-root": { - backgroundColor: "background.paper", - paddingLeft: 0.5, - paddingRight: 0.5, - }, - }, - }, - }} - localeText={{ - noRowsLabel: "No enrichment results found.", - }} - /> - ) : ( - - No significant enrichment found among the query results. - - ) - ) : ( - - Enrichment results will appear here after running a query. 
- - )} - - - - - - setSettingsDialogOpen(false)} - settings={session.settings.querySettings} - onSave={handleQuerySettingsSave} - /> - - ) -} diff --git a/src/client/src/components/Hero.tsx b/src/client/src/components/home/Hero.tsx similarity index 100% rename from src/client/src/components/Hero.tsx rename to src/client/src/components/home/Hero.tsx diff --git a/src/client/src/components/HomeAppBar.tsx b/src/client/src/components/home/HomeAppBar.tsx similarity index 97% rename from src/client/src/components/HomeAppBar.tsx rename to src/client/src/components/home/HomeAppBar.tsx index 5b39c75..0b3bf07 100644 --- a/src/client/src/components/HomeAppBar.tsx +++ b/src/client/src/components/home/HomeAppBar.tsx @@ -15,10 +15,10 @@ import Toolbar from "@mui/material/Toolbar"; import MenuIcon from "@mui/icons-material/Menu"; import CloseRoundedIcon from "@mui/icons-material/CloseRounded"; import visuallyHidden from "@mui/utils/visuallyHidden"; -import ColorModeIconDropdown from "../theme/ColorModeIconDropdown"; +import ColorModeIconDropdown from "../../theme/ColorModeIconDropdown"; import { useNavigate } from "react-router-dom"; -import { createSession, getSession } from "../features/session/api"; -import { createCookie } from "../features/session/utils"; +import { createSession, getSession } from "../../features/session/api"; +import { createCookie } from "../../features/session/utils"; // Custom styling for the toolbar const StyledToolbar = styled(Toolbar)(({ theme }) => ({ diff --git a/src/client/src/components/DialogWindow.tsx b/src/client/src/components/shared/DialogWindow.tsx similarity index 98% rename from src/client/src/components/DialogWindow.tsx rename to src/client/src/components/shared/DialogWindow.tsx index f5b0f2b..5cad010 100644 --- a/src/client/src/components/DialogWindow.tsx +++ b/src/client/src/components/shared/DialogWindow.tsx @@ -4,7 +4,6 @@ import Dialog from "@mui/material/Dialog"; import DialogTitle from "@mui/material/DialogTitle"; import 
DialogContent from "@mui/material/DialogContent"; import DialogActions from "@mui/material/DialogActions"; -import IconButton from "@mui/material/IconButton"; import Chip from "@mui/material/Chip"; import Button, { ButtonProps } from "@mui/material/Button"; import CloseIcon from "@mui/icons-material/Close"; diff --git a/src/client/src/components/Footer.tsx b/src/client/src/components/shared/Footer.tsx similarity index 100% rename from src/client/src/components/Footer.tsx rename to src/client/src/components/shared/Footer.tsx diff --git a/src/client/src/components/BuildVersion.tsx b/src/client/src/components/workspace/BuildVersion.tsx similarity index 89% rename from src/client/src/components/BuildVersion.tsx rename to src/client/src/components/workspace/BuildVersion.tsx index 76fa923..82fb56d 100644 --- a/src/client/src/components/BuildVersion.tsx +++ b/src/client/src/components/workspace/BuildVersion.tsx @@ -1,6 +1,6 @@ import React from "react"; import Box from "@mui/material/Box"; -import packageJson from "../../package.json"; +import packageJson from "../../../package.json"; export const BuildVersion: React.FC = () => { return ( diff --git a/src/client/src/components/MenuButton.tsx b/src/client/src/components/workspace/MenuButton.tsx similarity index 100% rename from src/client/src/components/MenuButton.tsx rename to src/client/src/components/workspace/MenuButton.tsx diff --git a/src/client/src/components/MenuContent.tsx b/src/client/src/components/workspace/MenuContent.tsx similarity index 90% rename from src/client/src/components/MenuContent.tsx rename to src/client/src/components/workspace/MenuContent.tsx index d8acf8e..730f0ea 100644 --- a/src/client/src/components/MenuContent.tsx +++ b/src/client/src/components/workspace/MenuContent.tsx @@ -8,7 +8,6 @@ import Stack from "@mui/material/Stack"; import ExploreIcon from "@mui/icons-material/Explore"; import HomeRoundedIcon from "@mui/icons-material/HomeRounded"; import UploadFileIcon from 
"@mui/icons-material/UploadFile"; -import QueryStatsIcon from "@mui/icons-material/QueryStats"; import { useNavigate, useLocation } from "react-router-dom"; const mainListItems = [ @@ -23,14 +22,9 @@ const mainListItems = [ to: `/dashboard/upload` }, { - text: "Explore", + text: "Discovery", icon: , - to: `/dashboard/explore` - }, - { - text: "Query", - icon: , - to: `/dashboard/query` + to: `/dashboard/discovery` }, ] diff --git a/src/client/src/components/NavbarBreadcrumbs.tsx b/src/client/src/components/workspace/NavbarBreadcrumbs.tsx similarity index 100% rename from src/client/src/components/NavbarBreadcrumbs.tsx rename to src/client/src/components/workspace/NavbarBreadcrumbs.tsx diff --git a/src/client/src/components/NotificationDrawer.tsx b/src/client/src/components/workspace/NotificationDrawer.tsx similarity index 100% rename from src/client/src/components/NotificationDrawer.tsx rename to src/client/src/components/workspace/NotificationDrawer.tsx diff --git a/src/client/src/components/NotificationProvider.tsx b/src/client/src/components/workspace/NotificationProvider.tsx similarity index 96% rename from src/client/src/components/NotificationProvider.tsx rename to src/client/src/components/workspace/NotificationProvider.tsx index 3f76b36..77cf328 100644 --- a/src/client/src/components/NotificationProvider.tsx +++ b/src/client/src/components/workspace/NotificationProvider.tsx @@ -1,5 +1,5 @@ import React from "react"; -import type { NotificationSeverity } from "../features/notifications/types"; +import type { NotificationSeverity } from "../../features/notifications/types"; // Extend the Notification interface with a "level" property export interface Notification { diff --git a/src/client/src/components/OverlayProvider.tsx b/src/client/src/components/workspace/OverlayProvider.tsx similarity index 100% rename from src/client/src/components/OverlayProvider.tsx rename to src/client/src/components/workspace/OverlayProvider.tsx diff --git 
a/src/client/src/components/UserIconDropdown.tsx b/src/client/src/components/workspace/UserIconDropdown.tsx similarity index 97% rename from src/client/src/components/UserIconDropdown.tsx rename to src/client/src/components/workspace/UserIconDropdown.tsx index a4fda3b..a8c91c1 100644 --- a/src/client/src/components/UserIconDropdown.tsx +++ b/src/client/src/components/workspace/UserIconDropdown.tsx @@ -12,8 +12,8 @@ import MenuItem from "@mui/material/MenuItem"; import Tooltip from "@mui/material/Tooltip"; import AccountCircleIcon from "@mui/icons-material/AccountCircle"; import { useNotifications } from "./NotificationProvider"; -import { deleteSession } from "../features/session/api"; -import { deleteCookie, getCookie } from "../features/session/utils"; +import { deleteSession } from "../../features/session/api"; +import { deleteCookie, getCookie } from "../../features/session/utils"; export const UserIconDropdown: React.FC = (props) => { const { pushNotification } = useNotifications(); diff --git a/src/client/src/components/workspace/Workspace.tsx b/src/client/src/components/workspace/Workspace.tsx new file mode 100644 index 0000000..73c3345 --- /dev/null +++ b/src/client/src/components/workspace/Workspace.tsx @@ -0,0 +1,143 @@ +import React from "react"; +import Box from "@mui/material/Box"; +import Stack from "@mui/material/Stack"; +import Fade from "@mui/material/Fade"; +import { alpha } from "@mui/material/styles"; +import { Routes, Route, useNavigate } from "react-router-dom"; +import { useNotifications } from "./NotificationProvider"; +import { useOverlay } from "./OverlayProvider"; +import { Session } from "../../features/session/types"; +import { getSession, refreshSession } from "../../features/session/api"; +import { WorkspaceNavbar } from "./WorkspaceNavbar"; +import { WorkspaceSideMenu } from "./WorkspaceSideMenu"; +import { WorkspaceHeader } from "./WorkspaceHeader"; +import { WorkspaceHome } from "./tabs/home/WorkspaceHome"; +import { 
WorkspaceUpload } from "./tabs/upload/WorkspaceUpload"; +import { WorkspaceDiscovery } from "./tabs/discovery/WorkspaceDiscovery"; + +export const Workspace: React.FC = () => { + const { showOverlay, hideOverlay } = useOverlay(); + const { pushNotification } = useNotifications(); + const navigate = useNavigate(); + + const [loading, setLoading] = React.useState(true); + const [session, setSession] = React.useState(null); + + // Load session on mount + React.useEffect(() => { + let alive = true; + setLoading(true); + + getSession() + .then((sess) => { + if (!alive) return; + setSession(sess); + }) + .catch((err) => { + console.error("Error loading session:", err); + navigate("/notfound"); + }) + .finally(() => { + if (!alive) return; + setLoading(false); + }); + + return () => { alive = false; }; + }, [navigate]); + + // Overlay follows loading state + React.useEffect(() => { + if (loading) showOverlay(); + else hideOverlay(); + }, [loading, showOverlay, hideOverlay]); + + // SSE: refresh session when server says something changed + React.useEffect(() => { + if (!session?.sessionId) return; + + let alive = true; + let refreshTimer: number | null = null; + + const scheduleRefresh = () => { + if (!alive) return; + if (refreshTimer !== null) return; + + refreshTimer = window.setTimeout(() => { + refreshTimer = null; + + refreshSession(session.sessionId) + .then((fresh) => { + if (!alive) return; + setSession(fresh); + }) + .catch((err) => { + const msg = err instanceof Error ? err.message : String(err); + pushNotification(`Failed to refresh session: ${msg}`, "error"); + }); + }, 250); + }; + + const SSE_BASE = process.env.REACT_APP_SSE_BASE ?? 
""; + const es = new EventSource( + `${SSE_BASE}/api/sessionEvents?sessionId=${encodeURIComponent(session.sessionId)}` + ); + + // Attach scheduleRefresh to all relevant events + es.addEventListener("hello", scheduleRefresh); + es.addEventListener("item_updated", scheduleRefresh); + es.addEventListener("session_merged", scheduleRefresh); + + es.onopen = () => { + // EventSource retries automatically; avoid spamming notifications + }; + + return () => { + alive = false; + if (refreshTimer !== null) window.clearTimeout(refreshTimer); + es.close(); + }; + }, [session?.sessionId, pushNotification]); + + // Determine what to show + const showContent = !!session && !loading; + + if (!session && !loading) return null; + + return ( + + + + ({ + flexGrow: 1, + backgroundColor: theme.vars + ? `rgba(${theme.vars.palette.background.defaultChannel} / 1)` + : alpha(theme.palette.background.default, 1), + overflow: "auto", + })} + > + + + + {/* Actual workspace content fades in once session is ready */} + {session && ( + + + + } /> + } /> + } /> + + + + )} + + + + + ) +}; diff --git a/src/client/src/components/WorkspaceControls.tsx b/src/client/src/components/workspace/WorkspaceControls.tsx similarity index 93% rename from src/client/src/components/WorkspaceControls.tsx rename to src/client/src/components/workspace/WorkspaceControls.tsx index 54d2959..9e9db80 100644 --- a/src/client/src/components/WorkspaceControls.tsx +++ b/src/client/src/components/workspace/WorkspaceControls.tsx @@ -4,7 +4,7 @@ import NotificationsRoundedIcon from "@mui/icons-material/NotificationsRounded"; import { MenuButton } from "./MenuButton"; import { useNotifications } from "./NotificationProvider"; import { UserIconDropdown } from "./UserIconDropdown"; -import ColorModeIconDropdown from "../theme/ColorModeIconDropdown"; +import ColorModeIconDropdown from "../../theme/ColorModeIconDropdown"; interface WorksspaceControlsProps { handleDrawerOpen: () => void; diff --git 
a/src/client/src/components/WorkspaceHeader.tsx b/src/client/src/components/workspace/WorkspaceHeader.tsx similarity index 100% rename from src/client/src/components/WorkspaceHeader.tsx rename to src/client/src/components/workspace/WorkspaceHeader.tsx diff --git a/src/client/src/components/WorkspaceNavbar.tsx b/src/client/src/components/workspace/WorkspaceNavbar.tsx similarity index 100% rename from src/client/src/components/WorkspaceNavbar.tsx rename to src/client/src/components/workspace/WorkspaceNavbar.tsx diff --git a/src/client/src/components/WorkspaceSideMenu.tsx b/src/client/src/components/workspace/WorkspaceSideMenu.tsx similarity index 100% rename from src/client/src/components/WorkspaceSideMenu.tsx rename to src/client/src/components/workspace/WorkspaceSideMenu.tsx diff --git a/src/client/src/components/WorkspaceSideMenuMobile.tsx b/src/client/src/components/workspace/WorkspaceSideMenuMobile.tsx similarity index 100% rename from src/client/src/components/WorkspaceSideMenuMobile.tsx rename to src/client/src/components/workspace/WorkspaceSideMenuMobile.tsx diff --git a/src/client/src/components/workspace/tabs/discovery/WorkspaceDiscovery.tsx b/src/client/src/components/workspace/tabs/discovery/WorkspaceDiscovery.tsx new file mode 100644 index 0000000..57651d1 --- /dev/null +++ b/src/client/src/components/workspace/tabs/discovery/WorkspaceDiscovery.tsx @@ -0,0 +1,81 @@ +import React from "react"; +import Box from "@mui/material/Box"; +import Card from "@mui/material/Card"; +import CardContent from "@mui/material/CardContent"; +import Typography from "@mui/material/Typography"; +import MuiLink from "@mui/material/Link"; +import NotificationsRoundedIcon from "@mui/icons-material/NotificationsRounded"; +import { useTheme } from "@mui/material/styles"; +import { useNotifications } from "../../NotificationProvider"; +import { Link as RouterLink } from "react-router-dom"; +import { Session } from "../../../../features/session/types"; + +type 
WorkspaceDiscoveryProps = { + session: Session; + setSession: React.Dispatch>; +}; + +export const WorkspaceDiscovery: React.FC = ({ session, setSession }) => { + const theme = useTheme(); + const { pushNotification } = useNotifications(); + + // Wrap parent setter (Session | null) into the deps shape (Session-only functional updater) + const setSessionSafe = React.useCallback( + (updater: (prev: Session) => Session) => { + setSession((prev) => (prev ? updater(prev) : prev)); + }, + [setSession] + ); + + // Helper to build deps for import service + const deps = React.useMemo( + () => ({ + setSession: setSessionSafe, + pushNotification, + sessionId: session.sessionId, + }), + [setSessionSafe, pushNotification, session.sessionId] + ); + + return ( + + + + + Getting started + + + Here you can use uploaded items from the  + + Upload tab + +  for cross-modal retrieval against the BioNexus database. + + + + + + ); +}; diff --git a/src/client/src/components/WorkspaceDashboard.tsx b/src/client/src/components/workspace/tabs/home/WorkspaceHome.tsx similarity index 94% rename from src/client/src/components/WorkspaceDashboard.tsx rename to src/client/src/components/workspace/tabs/home/WorkspaceHome.tsx index b938ece..c4fd6c1 100644 --- a/src/client/src/components/WorkspaceDashboard.tsx +++ b/src/client/src/components/workspace/tabs/home/WorkspaceHome.tsx @@ -7,7 +7,7 @@ import Typography from "@mui/material/Typography"; import { useTheme } from "@mui/material/styles"; import { Link as RouterLink } from "react-router-dom"; -export const WorkspaceDashboard: React.FC = () => { +export const WorkspaceHome: React.FC = () => { const theme = useTheme(); return ( @@ -49,12 +49,12 @@ export const WorkspaceDashboard: React.FC = () => { You can use the biosynthetic fingerprints parsed from you data for exploratory data analysis in the  - Explore tab + Discovery tab . 
Exploratory data analysis allows you to retrieve similar biosynthetic fingerprints and their associated metadata from the database using your uploaded data. diff --git a/src/client/src/components/DialogImportCompound.tsx b/src/client/src/components/workspace/tabs/upload/DialogImportCompound.tsx similarity index 69% rename from src/client/src/components/DialogImportCompound.tsx rename to src/client/src/components/workspace/tabs/upload/DialogImportCompound.tsx index 33341a6..37746d8 100644 --- a/src/client/src/components/DialogImportCompound.tsx +++ b/src/client/src/components/workspace/tabs/upload/DialogImportCompound.tsx @@ -1,4 +1,5 @@ import React from "react"; +import Chip from "@mui/material/Chip"; import Box from "@mui/material/Box"; import Button from "@mui/material/Button"; import FormControlLabel from "@mui/material/FormControlLabel"; @@ -7,20 +8,21 @@ import Switch from "@mui/material/Switch"; import TextField from "@mui/material/TextField"; import Typography from "@mui/material/Typography"; import Autocomplete from "@mui/material/Autocomplete"; -import { useNotifications } from "../components/NotificationProvider"; -import { DialogWindow } from "./DialogWindow"; -import { runQuery } from "../features/query/api"; +import { useNotifications } from "../../NotificationProvider"; +import { DialogWindow } from "../../../shared/DialogWindow"; type CompoundOption = { name: string; smiles: string; + databaseName: string; + databaseIdentifier: string; } type DialogImportCompoundProps = { open: boolean; onClose: () => void; - onImportSingle: (compound: { name: string; smiles: string }) => void; - onImportBatch: (file: File) => void; + onImportSingle: (compound: { name: string; smiles: string; matchStereochemistry: boolean }) => void; + onImportBatch: (file: File, matchStereochemistry: boolean) => void; } export const DialogImportCompound: React.FC = ({ @@ -36,6 +38,9 @@ export const DialogImportCompound: React.FC = ({ const [compoundSmiles, setCompoundSmiles] = 
React.useState(""); const [batchFile, setBatchFile] = React.useState(null); + // Stereochemistry toggle + const [matchStereochemistry, setMatchStereochemistry] = React.useState(true); + // Autocomplete state const [options, setOptions] = React.useState([]); const [loading, setLoading] = React.useState(false); @@ -52,20 +57,36 @@ export const DialogImportCompound: React.FC = ({ setCompoundSmiles(""); setBatchFile(null); setOptions([]); - } + }; const handleImport = () => { if (mode === "single") { onImportSingle({ name: compoundName.trim(), smiles: compoundSmiles.trim(), + matchStereochemistry, }); } else if (batchFile) { - onImportBatch(batchFile); + onImportBatch(batchFile, matchStereochemistry); } reset(); onClose(); - } + }; + + async function searchCompoundByName(q: string) { + const params = new URLSearchParams({ + q, + limit: "10", + }); + + const res = await fetch(`/api/searchCompound?${params.toString()}`); + + if (!res.ok) { + throw new Error(`Search failed: ${res.status}`); + }; + + return await res.json(); + }; // Debounced search when user types a compound name React.useEffect(() => { @@ -74,16 +95,12 @@ export const DialogImportCompound: React.FC = ({ if (!q) { setOptions([]); return; - } + }; const handle = setTimeout(async () => { setLoading(true); try { - const res = await runQuery({ - name: "search_compound_by_name", - params: { q }, - paging: { limit: 10 }, // size of suggestion list - }); + const res = await searchCompoundByName(q); const rows = (res.rows || []) as CompoundOption[]; setOptions(rows); @@ -115,16 +132,28 @@ export const DialogImportCompound: React.FC = ({ Enter a single compound identifier & SMILES, or upload a CSV/TSV that contains a column called "name" and a column called "smiles". - - setMode(e.target.checked ? "batch" : "single")} - /> - } - label={mode === "batch" ? "Batch import from file" : "Single compound import"} - /> + + + setMatchStereochemistry(e.target.checked)} + /> + } + label={matchStereochemistry ? 
"Match stereochemistry" : "Ignore stereochemistry"} + /> + + setMode(e.target.checked ? "batch" : "single")} + /> + } + label={mode === "batch" ? "Batch import from file" : "Single compound import"} + /> + {/* Single compound input */} {mode === "single" && ( @@ -155,6 +184,26 @@ export const DialogImportCompound: React.FC = ({ setCompoundSmiles(value.smiles); // auto-fill from DB } }} + renderOption={(props, option) => { + if (typeof option == "string") { + return
  • {option}
  • ; + } + + return ( +
  • + + + + {option.name} + + +
  • + ) + }} renderInput={(params) => ( = ({ - ) -} + ); +}; diff --git a/src/client/src/components/workspace/tabs/upload/WorkspaceItemCard.tsx b/src/client/src/components/workspace/tabs/upload/WorkspaceItemCard.tsx new file mode 100644 index 0000000..e9d23d3 --- /dev/null +++ b/src/client/src/components/workspace/tabs/upload/WorkspaceItemCard.tsx @@ -0,0 +1,238 @@ +import React from "react"; +import Box from "@mui/material/Box"; +import Stack from "@mui/material/Stack"; +import Typography from "@mui/material/Typography"; +import Checkbox from "@mui/material/Checkbox"; +import Chip from "@mui/material/Chip"; +import IconButton from "@mui/material/IconButton"; +import Tooltip from "@mui/material/Tooltip"; +import DeleteIcon from "@mui/icons-material/Delete"; +import CircularProgress from "@mui/material/CircularProgress"; +import { Gauge } from "@mui/x-charts/Gauge"; +import { SessionItem } from "../../../../features/session/types"; +import { alpha } from "@mui/material/styles"; +import type { Theme } from "@mui/material/styles"; + +function getScoreColor(theme: Theme, value: number): string { + const t = theme.vars || theme; + if (value < 0.5) { return t.palette.error.main }; + if (value < 0.9) { return t.palette.warning.main }; + return t.palette.success.main; +}; + +type WorkspaceItemCardProps = { + item: SessionItem; + selected: boolean; + disabled?: boolean; + onToggleSelect: (id: string) => void; + onDelete: (id: string) => void; +}; + +// Helper to format "X ago" +function formatUpdatedAgo(updatedAt?: number): string { + if (!updatedAt) return "Never updated"; + const now = Date.now(); + const diffMs = now - updatedAt; + if (diffMs < 0) return "just now"; + + const diffSec = Math.floor(diffMs / 1000); + if (diffSec < 5) return "just now"; + if (diffSec < 60) return `${diffSec}s ago`; + + const diffMin = Math.floor(diffSec / 60); + if (diffMin < 60) return `${diffMin}m ago`; + + const diffHours = Math.floor(diffMin / 60); + if (diffHours < 24) return 
`${diffHours}h ago`; + + const diffDays = Math.floor(diffHours / 24); + return `${diffDays}d ago`; +}; + +export const WorkspaceItemCard: React.FC = ({ + item, + selected, + disabled = false, + onToggleSelect, + onDelete, +}) => { + const isCompound = item.kind === "compound"; // there are only two types: "compound" and "cluster" + const itemScore = typeof item.score === "number" ? item.score : 0.0; + + // Tick every 15s so "X ago" updates + const [, forceTick] = React.useState(0); + React.useEffect(() => { + const id = window.setInterval(() => forceTick(n => n + 1), 5000); + return () => { window.clearInterval(id); } + }, []) + + const isQueued = item.status === "queued"; + const showSpinner = item.status === "processing"; + const isError = item.status === "error"; + const isDone = item.status === "done"; + + const handleToggle = (e?: React.SyntheticEvent) => { + if (e) e.stopPropagation(); + if (disabled) return; + onToggleSelect(item.id); + }; + + const handleDelete = (e?: React.SyntheticEvent) => { + if (e) e.stopPropagation(); + if (disabled) return; + onDelete(item.id); + }; + + return ( + { + const t = theme.vars || theme; + return { + borderRadius: 1, + border: `1px solid ${selected ? t.palette.primary.main : "transparent"}`, + p: 1.5, + display: "flex", + gap: 1.5, + cursor: "pointer", + "&:hover": { boxShadow: 10 }, + backgroundColor: selected ? alpha("#000000", 0.04) : alpha("#000000", 0.02), + ...theme.applyStyles("dark", { backgroundColor: selected ? 
alpha("#ffffff", 0.06) : alpha("#ffffff", 0.03) }), + } + }} + > + + + { + e.stopPropagation(); + onToggleSelect(item.id); + }} + /> + + getScoreColor(theme, item.score!), + transition: "stroke-dashoffset 0.3s ease", + }, + }} + text={({ value }) => `${value}%`} + /> + + + + + + {item.name} + + + + + + Status updated {formatUpdatedAgo(item.updatedAt)} + + + + + + + {disabled && ( + <> + + + + )} + + {isCompound && ( + + )} + + {isQueued && ( + + )} + + {showSpinner && ()} + + {isDone && ( + + )} + + {isError && ( + + + + )} + + { + e.stopPropagation(); + if (disabled) return; + onDelete(item.id); + }} + > + + + + + + ); +}; diff --git a/src/client/src/components/WorkspaceUpload.tsx b/src/client/src/components/workspace/tabs/upload/WorkspaceUpload.tsx similarity index 58% rename from src/client/src/components/WorkspaceUpload.tsx rename to src/client/src/components/workspace/tabs/upload/WorkspaceUpload.tsx index f11f307..34d0c6a 100644 --- a/src/client/src/components/WorkspaceUpload.tsx +++ b/src/client/src/components/workspace/tabs/upload/WorkspaceUpload.tsx @@ -1,33 +1,26 @@ import React from "react"; -import Box from '@mui/material/Box'; -import Button from '@mui/material/Button'; -import Card from '@mui/material/Card'; -import CardContent from '@mui/material/CardContent'; -import Stack from '@mui/material/Stack'; -import Typography from '@mui/material/Typography'; +import Box from "@mui/material/Box"; +import Button from "@mui/material/Button"; +import Card from "@mui/material/Card"; +import CardContent from "@mui/material/CardContent"; +import Stack from "@mui/material/Stack"; +import Typography from "@mui/material/Typography"; import MuiLink from "@mui/material/Link"; -import NotificationsRoundedIcon from '@mui/icons-material/NotificationsRounded'; +import NotificationsRoundedIcon from "@mui/icons-material/NotificationsRounded"; import { useTheme } from "@mui/material/styles"; -import { useNotifications } from "../components/NotificationProvider"; +import { 
useNotifications } from "../../NotificationProvider"; import { Link as RouterLink } from "react-router-dom"; import { DialogImportCompound } from "./DialogImportCompound"; -import { DialogImportGeneCluster } from "./DialogImportGeneCluster"; -import { DialogViewItem } from "./DialogViewItem"; import { WorkspaceItemCard } from "./WorkspaceItemCard"; -import { Session } from "../features/session/types"; -import { NewCompoundJob } from "../features/jobs/types"; -import { - MAX_ITEMS, - importCompound, - importCompoundsBatch, - importGeneClustersBatch, -} from "../features/jobs/api"; - -// const MAX_ITEMS = 200; +import { Session } from "../../../../features/session/types"; +import { deleteSessionItem } from "../../../../features/session/api"; +import { NewCompoundJob } from "../../../../features/jobs/types"; +import { MAX_ITEMS, importCompound, importCompoundsBatch } from "../../../../features/jobs/api"; + const MAX_FILE_SIZE_MB = 2; const MAX_FILE_SIZE_BYTES = MAX_FILE_SIZE_MB * 1024 * 1024; -async function parseCompoundFile(file: File): Promise { +async function parseCompoundFile(file: File, matchStereochemistry: boolean): Promise { const text = await file.text(); // Normalize newlines and split lines @@ -42,7 +35,7 @@ async function parseCompoundFile(file: File): Promise { delimiter = "\t"; } else if (file.name.endsWith(".csv")) { delimiter = ","; - } + }; const headers = headerLine.split(delimiter).map(h => h.trim().toLowerCase()); const nameIdx = headers.indexOf("name"); @@ -50,7 +43,7 @@ async function parseCompoundFile(file: File): Promise { if (nameIdx === -1 || smilesIdx === -1) { throw new Error("File must contain 'name' and 'smiles' columns in the header."); - } + }; const compounds: NewCompoundJob[] = []; @@ -64,152 +57,134 @@ async function parseCompoundFile(file: File): Promise { if (!name || !smiles) continue; - compounds.push({ name, smiles }); - } + compounds.push({ name, smiles, matchStereochemistry }); + }; return compounds; -} +}; type 
WorkspaceUploadProps = { session: Session; - setSession: (updated: (prev: Session) => Session) => void; -} + setSession: React.Dispatch>; +}; export const WorkspaceUpload: React.FC = ({ session, setSession }) => { const theme = useTheme(); const { pushNotification } = useNotifications(); const [openCompounds, setOpenCompounds] = React.useState(false); - const [openGeneClusters, setOpenGeneClusters] = React.useState(false); - const [openView, setOpenView] = React.useState(false); - const [viewingItemId, setViewingItemId] = React.useState(null); - const [selectedIds, setSelectedIds] = React.useState>(new Set()); + const [deletingIds, setDeletingIds] = React.useState>(new Set()); + + // Clean up deletingIds when session items change + React.useEffect(() => { + setDeletingIds((prev) => { + const liveIds = new Set(session.items.map(it => it.id)); + const next = new Set(); + + prev.forEach((id) => { + if (liveIds.has(id)) next.add(id); + }); + + return next; + }); + }, [session.items]); - const [readoutLevel, setReadoutLevel] = React.useState<"rec" | "gene">("gene"); + // Wrap parent setter (Session | null) into the deps shape (Session-only functional updater) + const setSessionSafe = React.useCallback( + (updater: (prev: Session) => Session) => { + setSession((prev) => (prev ? updater(prev) : prev)); + }, + [setSession] + ); // Helper to build deps for import service const deps = React.useMemo( () => ({ - setSession, + setSession: setSessionSafe, pushNotification, sessionId: session.sessionId, }), - [setSession, session.sessionId] - ) - - // Renaming helper - const handleRenameItem = (id: string, newName: string) => { - setSession((prev) => ({ - ...prev, - items: prev.items.map((item) => - item.id === id - ? 
{ ...item, name: newName } - : item - ), - })) - } - - // Viewing helper - const handleViewItem = (id: string) => { - setViewingItemId(id); - setOpenView(true); - } + [setSessionSafe, pushNotification, session.sessionId] + ); // Selection helpers const toggleSelectItem = (id: string) => { + // Prevent toggling if deleting + if (deletingIds.has(id)) return; + setSelectedIds(prev => { const next = new Set(prev); if (next.has(id)) next.delete(id); else next.add(id); return next; }) - } - - const handleDeleteItem = (id: string) => { - setSession(prev => ({ - ...prev, - items: prev.items.filter(item => item.id !== id), - })) - setSelectedIds(prev => { - const next = new Set(prev); - next.delete(id); - return next; - }) - } + }; const handleSelectAll = () => { if (!session.items.length) return; setSelectedIds(new Set(session.items.map(item => item.id))); - } + }; const handleClearSelection = () => { setSelectedIds(new Set()); - } - - const handleDeleteSelected = () => { - if (selectedIds.size === 0) return; - const toDelete = new Set(selectedIds); - - setSession(prev => ({ - ...prev, - items: prev.items.filter(item => !toDelete.has(item.id)), - })) - setSelectedIds(new Set()); - } - - const handleImportSingleCompound = async({ name, smiles}: { name: string; smiles: string }) => { - await importCompound(deps, { name, smiles }); - } + }; - const handleImportBatchCompounds = async (file: File) => { - if (file.size > MAX_FILE_SIZE_BYTES) { - pushNotification(`The file "${file.name}" exceeds the maximum size of ${MAX_FILE_SIZE_MB} MB and was not imported.`, "error"); - return; - } + const handleDeleteItem = async (id: string) => { + // Mark as deleting (UI only) + setDeletingIds(prev => new Set(prev).add(id)); try { - const compounds = await parseCompoundFile(file); - await importCompoundsBatch(deps, compounds); + await deleteSessionItem(session.sessionId, id); + // DO NOTHING ELSE + // SSE refresh will remove item from session } catch (err) { const msg = err instanceof Error 
? err.message : String(err); - pushNotification(`Failed to parse compound file: ${msg}`, "error"); + pushNotification(`Failed to delete item: ${msg}`, "error"); + setDeletingIds(prev => { + const n = new Set(prev); + n.delete(id); + return n; + }); } - } + }; - const handleImportGeneClusters = async (files: File[]) => { - if (!files.length) return; + const handleDeleteSelected = async () => { + if (selectedIds.size === 0) return; + + const ids = Array.from(selectedIds); + setSelectedIds(new Set()); // UI-only - const oversized = files.filter(f => f.size > MAX_FILE_SIZE_BYTES); - if (oversized.length > 0) { - pushNotification(`The following files exceed the maximum size of ${MAX_FILE_SIZE_MB} MB and were not imported: ${oversized.map(f => f.name).join(", ")}`, "error"); - - // Keep only files within size limit - files = files.filter(f => f.size <= MAX_FILE_SIZE_BYTES); - } + try { + for (const id of ids) { + await deleteSessionItem(session.sessionId, id); + } + // SSE will update the session state + } catch (err) { + const msg = err instanceof Error ? 
err.message : String(err); + pushNotification(`Failed to delete selected items from session: ${msg}`, "error"); + }; + }; - // Check if any files remain after filtering on file size - if (files.length === 0) { - pushNotification("No gene cluster files to import after size check.", "warning"); + // Import compound handlers + const handleImportSingleCompound = async({ name, smiles, matchStereochemistry}: { name: string; smiles: string; matchStereochemistry: boolean }) => { + await importCompound(deps, { name, smiles, matchStereochemistry }); + }; + + const handleImportBatchCompounds = async (file: File, matchStereochemistry: boolean) => { + if (file.size > MAX_FILE_SIZE_BYTES) { + pushNotification(`The file "${file.name}" exceeds the maximum size of ${MAX_FILE_SIZE_MB} MB and was not imported.`, "error"); return; - } + }; - let payloads: { name: string; fileContent: string }[] = []; try { - payloads = await Promise.all( - files.map(async (file) => ({ - name: file.name, - fileContent: await file.text(), - })) - ) + const compounds = await parseCompoundFile(file, matchStereochemistry); + await importCompoundsBatch(deps, compounds); } catch (err) { const msg = err instanceof Error ? err.message : String(err); - pushNotification(`Failed to read gene cluster files: ${msg}`, "error"); - return; - } - - await importGeneClustersBatch(deps, payloads, readoutLevel); - } + pushNotification(`Failed to parse compound file: ${msg}`, "error"); + }; + }; // Selection states const anySelected = selectedIds.size > 0; @@ -239,15 +214,15 @@ export const WorkspaceUpload: React.FC = ({ session, setSe Getting started - To get started, you can import compounds and gene clusters into your workspace. Use the buttons below to upload your data files. After importing, you can visualize and analyze your data within the  + In this tab you can import compounds and biosynthetic gene clusters (BGCs) into your workspace. Use the buttons below to upload your data files. 
After importing, you can visualize and analyze your data within the  - Explore tab + Discovery tab . A maximum of {MAX_ITEMS} items can be imported into the workspace. Keep an eye on for updates on your queries. @@ -256,8 +231,9 @@ export const WorkspaceUpload: React.FC = ({ session, setSe -
    @@ -270,23 +246,6 @@ export const WorkspaceUpload: React.FC = ({ session, setSe onImportBatch={handleImportBatchCompounds} /> - setOpenGeneClusters(false)} - onImport={handleImportGeneClusters} - readoutLevel={readoutLevel} - setReadoutLevel={setReadoutLevel} - /> - - item.id === viewingItemId) ?? null} - onClose={() => { - setOpenView(false); - setViewingItemId(null); - }} - /> - {session.items.length > 0 && ( @@ -334,10 +293,9 @@ export const WorkspaceUpload: React.FC = ({ session, setSe key={item.id} item={item} selected={selectedIds.has(item.id)} + disabled={deletingIds.has(item.id)} onToggleSelect={toggleSelectItem} onDelete={handleDeleteItem} - onView={handleViewItem} - onRename={handleRenameItem} /> ))} @@ -346,5 +304,5 @@ export const WorkspaceUpload: React.FC = ({ session, setSe )} - ) -} + ); +}; diff --git a/src/client/src/features/drawing/api.ts b/src/client/src/features/drawing/api.ts deleted file mode 100644 index edec7db..0000000 --- a/src/client/src/features/drawing/api.ts +++ /dev/null @@ -1,31 +0,0 @@ -import { postJson } from "../http"; -import { PrimarySequence } from "../session/types"; -import { ItemDrawingResultSchema } from "./types"; - -export async function drawCompoundItem( - taggedParentSmiles: string, - primarySequence: PrimarySequence -): Promise { - const data = await postJson( - "/api/drawCompoundItem", - { - taggedParentSmiles, - primarySequence - }, - ItemDrawingResultSchema - ) - return data.svg; -} - -export async function drawGeneClusterItem( - fileContent: string -): Promise { - const data = await postJson( - "/api/drawGeneClusterItem", - { - fileContent - }, - ItemDrawingResultSchema - ) - return data.svg; -} diff --git a/src/client/src/features/drawing/types.ts b/src/client/src/features/drawing/types.ts deleted file mode 100644 index 42ffa20..0000000 --- a/src/client/src/features/drawing/types.ts +++ /dev/null @@ -1,7 +0,0 @@ -import { z } from "zod"; - -export const ItemDrawingResultSchema = z.object({ - svg: z.string(), 
-}); - -export type ItemDrawingResult = z.infer; diff --git a/src/client/src/features/jobs/api.ts b/src/client/src/features/jobs/api.ts index d03a5ff..26efe24 100644 --- a/src/client/src/features/jobs/api.ts +++ b/src/client/src/features/jobs/api.ts @@ -1,17 +1,16 @@ import { postJson } from "../http"; -import type { WorkspaceImportDeps, NewCompoundJob, NewGeneClusterJob } from "./types"; -import type { Session, SessionItem, CompoundItem, GeneClusterItem } from "../session/types"; +import type { WorkspaceImportDeps, NewCompoundJob } from "./types"; +import type { Session, SessionItem, CompoundItem } from "../session/types"; import { saveSession } from "../session/api"; -import { runQuery } from "../query/api"; import { z } from "zod"; -export const MAX_ITEMS = 200; +export const MAX_ITEMS = 20; const SubmitJobRespSchema = z.object({ ok: z.boolean(), elapsed_ms: z.number().int().nonnegative(), status: z.string().optional(), -}).partial() // we don't actually use the response body here +}).partial(); export async function submitCompoundJob( sessionId: string, @@ -24,144 +23,94 @@ export async function submitCompoundJob( itemId: item.id, name: item.name, smiles: item.smiles, + matchStereochemistry: item.matchStereochemistry, }, SubmitJobRespSchema - ) -} + ); +}; -export async function submitGeneClusterJob( - sessionId: string, - item: GeneClusterItem, - readoutLevel: "rec" | "gene", -): Promise { - await postJson( - "/api/submitGeneCluster", - { - sessionId, - itemId: item.id, - name: item.name, - fileContent: item.fileContent, - readoutLevel: readoutLevel, - }, - SubmitJobRespSchema - ) -} - -// General helper to add items to the session with capacity checks -async function addItemsToSession( +export async function importCompoundsBatch( deps: WorkspaceImportDeps, - buildItems: (prev: Session, remainingSlots: number) => { updated: Session; newItems: SessionItem[] } -): Promise<{ nextSession: Session | null; newItems: SessionItem[] }> { - // Load the current session - 
const { setSession, pushNotification } = deps; + compounds: NewCompoundJob[], +): Promise { + const { pushNotification, setSession, sessionId } = deps; + + if (!compounds.length) { + pushNotification("No compounds to import", "warning"); + return []; + }; let nextSession: Session | null = null; let newItems: SessionItem[] = []; - setSession(prev => { + // Update local session (queued items) + setSession((prev) => { const existingCount = prev.items.length; const remainingSlots = MAX_ITEMS - existingCount; if (remainingSlots <= 0) { - pushNotification(`Workspace is full. Maximum of ${MAX_ITEMS} items reached.`, "warning"); + pushNotification(`Session already has maximum of ${MAX_ITEMS} items`, "warning"); nextSession = prev; newItems = []; return prev; - } - - const { updated, newItems: createdItems } = buildItems(prev, remainingSlots); - - if (createdItems.length === 0) { - nextSession = prev; - newItems = []; - return prev; - } - - nextSession = updated; - newItems = createdItems; - return updated; - }); - - return { nextSession, newItems }; -} - -// Submit compounds -export async function importCompoundsBatch( - deps: WorkspaceImportDeps, - compounds: NewCompoundJob[] -): Promise { - const { pushNotification, setSession, sessionId } = deps; - - if (!compounds.length) { - pushNotification("No valid compounds to import", "warning"); - return []; - } + }; - const { nextSession, newItems } = await addItemsToSession(deps, (prev, remainingSlots) => { - const limitedCompounds = - compounds.length > remainingSlots ? compounds.slice(0, remainingSlots) : compounds; + const limited = compounds.length > remainingSlots ? 
compounds.slice(0, remainingSlots) : compounds; - if (limitedCompounds.length < compounds.length) { - pushNotification( - `Only ${remainingSlots} compounds were imported due to workspace limit of ${MAX_ITEMS} items.`, - "warning" - ); - } + if (limited.length < compounds.length) { + pushNotification(`Only importing ${limited.length} compounds to avoid exceeding maximum of ${MAX_ITEMS} items`, "warning"); + }; - const createdItems: SessionItem[] = limitedCompounds.map(({ name, smiles }) => ({ + const createdItems: SessionItem[] = limited.map(({ name, smiles, matchStereochemistry }) => ({ id: crypto.randomUUID(), kind: "compound", name, smiles, - taggedSmiles: null, - retrofingerprints: [], - primarySequences: [], + matchStereochemistry, status: "queued", errorMessage: null, updatedAt: Date.now(), + // optional fields + score: null, + payload: null, })); - return { - updated: { - ...prev, - items: [...prev.items, ...createdItems], - }, - newItems: createdItems, - }; - }); + const updated: Session = { ...prev, items: [...prev.items, ...createdItems] }; - if (!nextSession || newItems.length === 0) { - // Nothing to do - return []; - } + nextSession = updated; + newItems = createdItems; + return updated; + }); + + if (!nextSession || newItems.length === 0) return []; - // Save session before submitting jobs + // Persist session BEFORE submitting jobs try { await saveSession(nextSession); } catch (err) { const msg = err instanceof Error ? err.message : String(err); - pushNotification(`Failed to save session: ${msg}`, "error"); - - const newIds = new Set(newItems.map(ni => ni.id)); + pushNotification(`Failed to save session before importing compounds: ${msg}`, "error"); + + const newIds = new Set(newItems.map((it) => it.id)); - setSession(prev => ({ + setSession((prev) => ({ ...prev, - items: prev.items.map(it => + items: prev.items.map((it) => newIds.has(it.id) ? 
{ ...it, status: "error", - errorMessage: `Failed to save session: ${msg}`, + errorMessage: "Failed to save session before importing compound", updatedAt: Date.now(), } : it - ), + ) })); return []; } - // Submit jobs for each new compound (sequential) + // Submit jobs sequentially for (const item of newItems) { try { await submitCompoundJob(sessionId, item as CompoundItem); @@ -169,9 +118,10 @@ export async function importCompoundsBatch( const msg = err instanceof Error ? err.message : String(err); pushNotification(`Failed to submit job for compound "${item.name}": ${msg}`, "error"); - setSession(prev => ({ + // Mark item as error + setSession((prev) => ({ ...prev, - items: prev.items.map(it => + items: prev.items.map((it) => it.id === item.id ? { ...it, @@ -180,13 +130,13 @@ export async function importCompoundsBatch( updatedAt: Date.now(), } : it - ), + ) })); - } - } + }; + }; return newItems; -} +}; // Single compound import wrapper export async function importCompound( @@ -195,155 +145,4 @@ export async function importCompound( ): Promise { const items = await importCompoundsBatch(deps, [payload]); return items[0] ?? null; -} - -// Wrapper around importCompound; construct payload by first retrieving compound job info from server -export async function importCompoundById( - deps: WorkspaceImportDeps, - compoundId: number, -): Promise { - const { pushNotification } = deps; - - // Retrieve compound info from server - let compoundInfo: { name: string; smiles: string }; - try { - compoundInfo = await runQuery({ - name: "compound_info_by_id", - params: { "compound_id": compoundId }, - paramSchema: z.object({ - compound_id: z.number().int().positive(), - }), - }).then(res => { - const rows = (res.rows || []) as { name: string; smiles: string }[]; - if (rows.length === 0) { - throw new Error("No compound found with the given ID"); - } - return { name: rows[0].name, smiles: rows[0].smiles }; - }); - } catch (err) { - const msg = err instanceof Error ? 
err.message : String(err); - pushNotification(`Failed to retrieve compound info: ${msg}`, "error"); - return null; - } - - // Import compound - const item = await importCompound(deps, { - name: compoundInfo.name, - smiles: compoundInfo.smiles, - }); - - return item; -} - -// Submit gene clusters -export async function importGeneClustersBatch( - deps: WorkspaceImportDeps, - clusters: NewGeneClusterJob[], - readoutLevel: "rec" | "gene", -): Promise { - const { pushNotification, setSession, sessionId } = deps; - - if (!clusters.length) { - pushNotification("No gene clusters to import", "warning"); - return []; - } - - const { nextSession, newItems } = await addItemsToSession(deps, (prev, remainingSlots) => { - const limited = - clusters.length > remainingSlots ? clusters.slice(0, remainingSlots) : clusters; - - if (limited.length < clusters.length) { - pushNotification( - `Only ${remainingSlots} gene clusters were imported due to workspace limit of ${MAX_ITEMS} items.`, - "warning" - ); - } - - const createdItems: SessionItem[] = limited.map(({ name, fileContent }) => ({ - id: crypto.randomUUID(), - kind: "gene_cluster", - name, - fileContent, - retrofingerprints: [], - primarySequences: [], - status: "queued", - errorMessage: null, - updatedAt: Date.now(), - })); - - return { - updated: { - ...prev, - items: [...prev.items, ...createdItems], - }, - newItems: createdItems, - }; - }); - - if (!nextSession || newItems.length === 0) { - return []; - } - - // Save session before submitting jobs - try { - await saveSession(nextSession); - } catch (err) { - const msg = err instanceof Error ? err.message : String(err); - pushNotification(`Failed to save session: ${msg}`, "error"); - - const newIds = new Set(newItems.map(ni => ni.id)); - - setSession(prev => ({ - ...prev, - items: prev.items.map(it => - newIds.has(it.id) - ? 
{ - ...it, - status: "error", - errorMessage: `Failed to save session: ${msg}`, - updatedAt: Date.now(), - } - : it - ), - })); - - return []; - } - - // Submit jobs for each new gene cluster (sequential) - for (const item of newItems) { - try { - await submitGeneClusterJob(sessionId, item as GeneClusterItem, readoutLevel); - } catch (err) { - const msg = err instanceof Error ? err.message : String(err); - pushNotification(`Failed to submit job for gene cluster "${item.name}": ${msg}`, "error"); - - setSession(prev => ({ - ...prev, - items: prev.items.map(it => - it.id === item.id - ? { - ...it, - status: "error", - errorMessage: `Failed to submit job: ${msg}`, - updatedAt: Date.now(), - } - : it - ), - })); - } - } - - return newItems; -} - -// Single gene cluster import wrapper -export async function importGeneCluster( - deps: WorkspaceImportDeps, - payload: NewGeneClusterJob, - readoutLevel: "rec" | "gene", -): Promise { - const items = await importGeneClustersBatch(deps, [payload], readoutLevel); - return items[0] ?? 
null; -} - +}; diff --git a/src/client/src/features/jobs/types.ts b/src/client/src/features/jobs/types.ts index b84043a..1ea6c76 100644 --- a/src/client/src/features/jobs/types.ts +++ b/src/client/src/features/jobs/types.ts @@ -2,19 +2,11 @@ import type { Session } from "../session/types"; import type { NotificationSeverity } from "../../features/notifications/types"; import { z } from "zod"; +export type SetSession = (updater: (prev: Session) => Session) => void; + export const WorkspaceImportDepsSchema = z.object({ - setSession: z.function().args( - z.function() - .args(z.custom()) - .returns(z.custom()) - ) - .returns(z.void()), - pushNotification: z.function() - .args( - z.string(), - z.custom() - ) - .returns(z.void()), + setSession: z.custom(), + pushNotification: z.function().args(z.string(), z.custom()).returns(z.void()), sessionId: z.string().min(1, "Session ID cannot be empty"), }) @@ -26,12 +18,7 @@ export const BaseNewJobSchema = z.object({ export const NewCompoundJobSchema = BaseNewJobSchema.extend({ smiles: z.string().min(1, "SMILES cannot be empty"), -}) - -export const NewGeneClusterJobSchema = BaseNewJobSchema.extend({ - fileContent: z.string().min(1, "File content cannot be empty"), - readoutLevel: z.enum(["rec", "gene"]), + matchStereochemistry: z.boolean().default(false), }) export type NewCompoundJob = z.infer; -export type NewGeneClusterJob = z.infer; diff --git a/src/client/src/features/query/api.ts b/src/client/src/features/query/api.ts deleted file mode 100644 index 076ab70..0000000 --- a/src/client/src/features/query/api.ts +++ /dev/null @@ -1,57 +0,0 @@ -import { postJson } from "../http"; -import { QueryResult, QueryResultSchema, CreateQueryResultRespSchema, GetQueryResultRespSchema } from "./types"; -import { z } from "zod"; - -// Shared option types -export const OrderDirSchema = z.enum(["asc", "desc"]).default("desc"); -export type OrderDir = z.infer; - -export const OrderSchema = z.object({ - column: z.string(), // server whitelists 
per-query - dir: OrderDirSchema.optional(), -}) -export type Order = z.infer; - -export const PagingSchema = z.object({ - limit: z.number().int().nonnegative().optional(), - offset: z.number().int().nonnegative().optional(), -}); -export type Paging = z.infer; - -// Generic payload schema -const QueryPayloadSchema = z.object({ - name: z.string(), - params: z.record(z.unknown()), - paging: PagingSchema.optional(), - order: OrderSchema.optional(), -}); -export type QueryPayload = z.infer; - -/** - * Generic query runner. If a `paramSchema` is provided, function will validate - * the params before sending. Otherwise, it will send params as-is (server will - * coerce/validate). - */ -export async function runQuery

    >(args: { - name: string; - params: z.input

    | Record; - paramSchema?: P; - paging?: Paging; - order?: Order; -}): Promise { - const safeParams = args.paramSchema - ? args.paramSchema.parse(args.params) - : (args.params as Record); - - const payload: QueryPayload = QueryPayloadSchema.parse({ - name: args.name, - params: safeParams, - paging: args.paging, - order: args.order, - }) - - // Server response is exactly same as QueryResultSchema - const data = await postJson("/api/query", payload, QueryResultSchema) - - return data; -} diff --git a/src/client/src/features/query/types.ts b/src/client/src/features/query/types.ts deleted file mode 100644 index 9fa28b8..0000000 --- a/src/client/src/features/query/types.ts +++ /dev/null @@ -1,17 +0,0 @@ -import { z } from "zod"; - -export const QueryResultSchema = z.object({ - name: z.string(), - columns: z.array(z.string()), - rows: z.array(z.record(z.unknown())), - rowCount: z.number().int().nonnegative(), - limit: z.number().int().nonnegative(), - offset: z.number().int().nonnegative(), - elapsed_ms: z.number().nonnegative(), -}); - -export type QueryResult = z.infer; - -// Simple response wrapper -export const CreateQueryResultRespSchema = z.object({ result: QueryResultSchema }); -export const GetQueryResultRespSchema = z.object({ result: QueryResultSchema }); \ No newline at end of file diff --git a/src/client/src/features/session/api.ts b/src/client/src/features/session/api.ts index 30c75d0..f584d04 100644 --- a/src/client/src/features/session/api.ts +++ b/src/client/src/features/session/api.ts @@ -3,31 +3,37 @@ import { Session, initSession, SessionSchema, CreateSessionRespSchema, GetSessio import { getCookie } from "./utils"; import { z } from "zod"; +const OkRespSchema = z.object({ ok: z.boolean() }).partial(); + export async function createSession(): Promise { const newSession = initSession(); const data = await postJson("/api/createSession", { session: newSession }, CreateSessionRespSchema); return data.sessionId; -} +}; export async function 
getSession(sessionIdArg?: string): Promise { const sessionId = sessionIdArg ?? getCookie("sessionId"); if (!sessionId) throw new Error("No sessionId provided or found in cookies"); const data = await postJson("/api/getSession", { sessionId }, GetSessionRespSchema); return data.session; -} +}; export async function refreshSession(sessionId: string): Promise { return getSession(sessionId); -} +}; export async function saveSession(session: Session): Promise { // Runtime validate before sending (especially useful because session is user-mutated in UI) SessionSchema.parse(session); await postJson("/api/saveSession", { session }, z.unknown()); -} +}; export async function deleteSession(): Promise { const sessionId = getCookie("sessionId"); if (!sessionId) throw new Error("No sessionId cookie found"); await postJson("/api/deleteSession", { sessionId }, z.unknown()); -} +}; + +export async function deleteSessionItem(sessionId: string, itemId: string): Promise { + await postJson("/api/deleteSessionItem", { sessionId, itemId }, OkRespSchema); +}; diff --git a/src/client/src/features/session/types.ts b/src/client/src/features/session/types.ts index c793a2a..c87cc95 100644 --- a/src/client/src/features/session/types.ts +++ b/src/client/src/features/session/types.ts @@ -1,130 +1,37 @@ -import { defaultMotifColorMap } from "./utils"; import { z } from "zod"; export const BaseItemSchema = z.object({ id: z.string(), name: z.string(), // display name + score: z.number().min(0).max(1).nullable().optional(), + payload: z.record(z.any()).nullable().optional(), status: z.enum(["queued", "processing", "done", "error"]).default("queued"), errorMessage: z.string().nullable().optional(), updatedAt: z.number().nonnegative().default(() => Date.now()), -}) - -export const PrimarySequenceMotifSchema = z.object({ - id: z.string(), - name: z.string().nullable().optional(), - displayName: z.string().nullable().optional(), - - // Structural information - tags: 
z.array(z.number().int().nonnegative()).default([]), - smiles: z.string().nullable().optional(), - morganfingerprint2048r2: z.string().length(512).nullable().optional(), -}).refine( - ({ smiles, morganfingerprint2048r2 }) => - (smiles == null && morganfingerprint2048r2 == null) || - (smiles != null && morganfingerprint2048r2 != null), - { message: "Both 'smiles' and 'morganfingerprint2048r2' must be provided together or both be null",} -) - -export type PrimarySequenceMotif = z.output; - -export const PrimarySequenceSchema = z.object({ - id: z.string(), - name: z.string().nullable().optional(), - parentSmilesTagged: z.string().nullable().optional(), - sequence: z.array(PrimarySequenceMotifSchema).min(1), -}) - -export type PrimarySequence = z.output; - -export const MsaSequenceSchema = PrimarySequenceSchema.extend({ - itemId: z.string(), - primarySequenceId: z.string(), - hidden: z.boolean().default(false), }); -export type MsaSequence = z.output; - -export const MsaStateSchema = z.object({ - aligned: z.boolean().default(false), - centerId: z.string().nullable().optional(), - sequences: z.array(MsaSequenceSchema).default([]), -}); - -export type MsaState = z.output; - -export const BaseFingerprintSchema = z.object({ - id: z.string(), - retrofingerprint512: z.string().length(128), - score: z.number().min(0).max(1), -}) - -export const CompoundRetrofingerprintSchema = BaseFingerprintSchema.extend({}); -export const GeneClusterRetrofingerprintSchema = BaseFingerprintSchema.extend({}); - export const CompoundItemSchema = BaseItemSchema.extend({ kind: z.literal("compound"), smiles: z.string(), - taggedSmiles: z.string().nullable().optional(), - retrofingerprints: z.array(CompoundRetrofingerprintSchema).default([]), - primarySequences: z.array(PrimarySequenceSchema).default([]), -}) + matchStereochemistry: z.boolean(), +}); -export const GeneClusterSchema = BaseItemSchema.extend({ - kind: z.literal("gene_cluster"), +export const ClusterSchema = BaseItemSchema.extend({ + 
kind: z.literal("cluster"), fileContent: z.string(), - retrofingerprints: z.array(GeneClusterRetrofingerprintSchema).default([]), - primarySequences: z.array(PrimarySequenceSchema).default([]), -}) +}); -export const SessionItemSchema = z.discriminatedUnion("kind", [ - CompoundItemSchema, - GeneClusterSchema, -]) +export const SessionItemSchema = z.discriminatedUnion("kind", [CompoundItemSchema,ClusterSchema]); export type CompoundItem = z.output; -export type GeneClusterItem = z.output; +export type ClusterItem = z.output; export type SessionItem = z.output; -export const AlignmentTypeSchema = z.enum(["global", "local"]); -export const EmbeddingVisualizationTypeSchema = z.enum(["pca", "umap"]); -export const QuerySearchSpaceSchema = z.enum(["only_compounds", "only_gene_clusters", "both"]); -export const AnnotationFilterSchema = z.object({scheme: z.string(), key: z.string(), value: z.string()}); - -export type AlignmentType = z.output; -export type EmbeddingVisualizationType = z.output; -export type QuerySearchSpace = z.output; -export type AnnotationFilter = z.output; - -export const MsaSettingsSchema = z.object({ - alignmentType: AlignmentTypeSchema.default("global"), -}); - -export type MsaSettings = z.output; - -export const QuerySettingsSchema = z.object({ - similarityThreshold: z.number().min(0).max(1).default(0.7), - searchSpace: QuerySearchSpaceSchema.default("both"), - annotationFilters: z.array(AnnotationFilterSchema).default([]), -}); - -export type QuerySettings = z.output; - -export const SessionSettingsSchema = z.object({ - motifColorPalette: z.record(z.string()).default(() => defaultMotifColorMap()), - embeddingVisualizationType: EmbeddingVisualizationTypeSchema.default("pca"), - msaSettings: MsaSettingsSchema.default(() => ({})), - querySettings: QuerySettingsSchema.default(() => ({})), -}); - -export type SessionSettings = z.output; - export const SessionSchema = z.object({ sessionId: z.string().default(() => crypto.randomUUID()), created: 
z.number().nonnegative().default(() => Date.now()), items: z.array(SessionItemSchema).default([]), - settings: SessionSettingsSchema.default(() => ({})), - msaState: MsaStateSchema.default(() => ({})), -}) +}); export type Session = z.output; @@ -135,4 +42,4 @@ export const GetSessionRespSchema = z.object({ session: SessionSchema }); export function initSession(): Session { const newSession = SessionSchema.parse({}); return newSession; -} +}; diff --git a/src/client/src/features/session/utils.ts b/src/client/src/features/session/utils.ts index c203276..d3daf18 100644 --- a/src/client/src/features/session/utils.ts +++ b/src/client/src/features/session/utils.ts @@ -8,7 +8,7 @@ export function createCookie(name: string, value: string, days?: number, path: s } cookieStr += `; path=${path}`; document.cookie = cookieStr; -} +}; export function getCookie(name: string): string | null { const match = document.cookie @@ -16,66 +16,11 @@ export function getCookie(name: string): string | null { .map(pair => pair.split("=")) .find(([key]) => key === name); return match ? 
match[1] : null; -} +}; export function deleteCookie(name: string, path: string = "/"): void { document.cookie = `${encodeURIComponent(name)}=; ` + `expires=Thu, 01 Jan 1970 00:00:00 GMT; ` + `path=${path};`; -} - -function parseColor(color: string, alpha: number): string { - // HEX case: "#RGB" or "#RRGGBB" - if (color.startsWith("#")) { - let hex = color.replace(/^#/, ""); - // expand shorthand (#abc → aabbcc) - if (hex.length === 3) { - hex = hex.split("").map(c => c + c).join(""); - } - // parse r, g, b - const r = parseInt(hex.slice(0, 2), 16); - const g = parseInt(hex.slice(2, 4), 16); - const b = parseInt(hex.slice(4, 6), 16); - return `rgba(${r}, ${g}, ${b}, ${alpha})`; - }; - - // HSL case: "hsl(h, s%, l%)" - const hsl = color.match( - /hsl\(\s*([\d.]+)(?:deg)?\s*,\s*([\d.]+)%\s*,\s*([\d.]+)%\s*\)/ - ); - if (hsl) { - const h = hsl[1]; - const s = hsl[2]; - const l = hsl[3]; - return `hsla(${h}, ${s}%, ${l}%, ${alpha})`; - }; - - throw new Error(`Unsupported color format: ${color}`); -}; - -export const defaultMotifColorMap = (): Record => { - const newColorMap: Record = {}; - - const baseColors: Record<"A"|"B"|"C"|"D", string> = { - A: "#e74c3c ", // red - B: "#27ae60", // green - C: "#2980b9", // blue - D: "#f39c12 ", // orange - }; - - for (const key of Object.keys(baseColors) as Array) { - const color = baseColors[key]; - // plain (opaque) base - newColorMap[key] = color; - - // numbered variants 1→15 → alpha = 1/15…15/15 - for (let i = 1; i <= 15; i++) { - const alpha = 1 - (i / 15); - const alphaRounded = Math.round(alpha * 1000) / 1000; - newColorMap[`${key}${i}`] = parseColor(color, alphaRounded); - } - } - - return newColorMap; }; diff --git a/src/client/src/features/views/api.ts b/src/client/src/features/views/api.ts deleted file mode 100644 index 0247286..0000000 --- a/src/client/src/features/views/api.ts +++ /dev/null @@ -1,61 +0,0 @@ -import { postJson } from "../http"; -import { MsaSettings, SessionItem, SessionSettings } from 
"../session/types"; -import { - EmbeddingPoint, - GetEmbeddingSpaceRespSchema, - GetEnrichmentResultRespSchema, - GetMsaResultRespSchema, - MsaResult -} from "./types"; -import { EnrichmentResult } from "./types"; -import { QuerySettings } from "../session/types"; -import { PrimarySequence } from "../session/types"; - -export async function getEmbeddingSpace( - sessionId: string, - sessionItems: SessionItem[], - method: SessionSettings["embeddingVisualizationType"] = "pca" -): Promise { - const data = await postJson( - "/api/getEmbeddingSpace", - { - sessionId: sessionId, - items: sessionItems, - method, - }, - GetEmbeddingSpaceRespSchema - ) - return data.points; -} - -export async function runEnrichment({ retrofingerprint512, querySettings }: { - retrofingerprint512: any; - querySettings: QuerySettings; -}): Promise { - const data = await postJson( - "/api/enrich", - { - retrofingerprint512, - querySettings, - }, - GetEnrichmentResultRespSchema - ) - return data.result; -} - -export async function runMsa({ primarySequences, centerId, msaSettings }: { - primarySequences: PrimarySequence[]; - centerId?: string; - msaSettings?: MsaSettings; -}): Promise { - const data = await postJson( - "/api/runMsa", - { - primarySequences, - centerId, - msaSettings, - }, - GetMsaResultRespSchema - ) - return data.result; -} diff --git a/src/client/src/features/views/types.ts b/src/client/src/features/views/types.ts deleted file mode 100644 index 02b574f..0000000 --- a/src/client/src/features/views/types.ts +++ /dev/null @@ -1,43 +0,0 @@ -import { PrimarySequenceSchema } from "../session/types"; -import { z } from "zod"; - -export const EmbeddingPointSchema = z.object({ - parent_id: z.string(), // corresponds to SessionItem id - child_id: z.string(), // corresponds to fingerprint item id - kind: z.enum(["compound", "gene_cluster"]), // corresponds to SessionItem kind - x: z.number(), - y: z.number(), -}) - -export const GetEmbeddingSpaceRespSchema = z.object({ points: 
z.array(EmbeddingPointSchema) }) - -export type EmbeddingPoint = z.output; - -const EnrichmentItemSchema = z.object({ - id: z.string(), - schema: z.string(), - key: z.string(), - value: z.string(), - p_value: z.number(), - adjusted_p_value: z.number(), -}) - -export type EnrichmentItem = z.infer; - -export const EnrichmentResultSchema = z.object({ - items: z.array(EnrichmentItemSchema), -}) - -export type EnrichmentResult = z.infer; - -// Simple response wrapper -export const CreateEnrichmentResultRespSchema = z.object({ result: EnrichmentResultSchema }); -export const GetEnrichmentResultRespSchema = z.object({ result: EnrichmentResultSchema }) - -export const MsaResultSchema = z.object({ - alignedSequences: z.array(PrimarySequenceSchema), -}) - -export type MsaResult = z.infer; - -export const GetMsaResultRespSchema = z.object({ result: MsaResultSchema }) \ No newline at end of file diff --git a/src/client/src/pages/Dashboard.tsx b/src/client/src/pages/Dashboard.tsx index 69f395d..4edee00 100644 --- a/src/client/src/pages/Dashboard.tsx +++ b/src/client/src/pages/Dashboard.tsx @@ -1,12 +1,12 @@ import CssBaseline from "@mui/material/CssBaseline"; import AppTheme from "../theme/AppTheme"; -import { OverlayProvider } from "../components/OverlayProvider"; -import { NotificationProvider } from "../components/NotificationProvider"; -import { Workspace } from "../components/Workspace"; +import { OverlayProvider } from "../components/workspace/OverlayProvider"; +import { NotificationProvider } from "../components/workspace/NotificationProvider"; +import { Workspace } from "../components/workspace/Workspace"; interface DashboardProps { disableCustomTheme?: boolean; -} +}; export default function Dashboard(props: DashboardProps) { return ( @@ -18,5 +18,5 @@ export default function Dashboard(props: DashboardProps) { - ) -} + ); +}; diff --git a/src/client/src/pages/Home.tsx b/src/client/src/pages/Home.tsx index 452cbfc..8b48b58 100644 --- a/src/client/src/pages/Home.tsx +++ 
b/src/client/src/pages/Home.tsx @@ -1,9 +1,9 @@ import Box from "@mui/material/Box"; import CssBaseline from "@mui/material/CssBaseline"; import AppTheme from "../theme/AppTheme"; -import HomeAppBar from "../components/HomeAppBar"; -import Hero from "../components/Hero"; -import Footer from "../components/Footer"; +import HomeAppBar from "../components/home/HomeAppBar"; +import Hero from "../components/home/Hero"; +import Footer from "../components/shared/Footer"; export default function Home(props: { disableCustomTheme?: boolean }) { return ( diff --git a/src/client/src/pages/NotFound.tsx b/src/client/src/pages/NotFound.tsx index 8f5951b..24fe563 100644 --- a/src/client/src/pages/NotFound.tsx +++ b/src/client/src/pages/NotFound.tsx @@ -6,7 +6,7 @@ import CssBaseline from "@mui/material/CssBaseline"; import Stack from "@mui/material/Stack"; import Typography from "@mui/material/Typography"; import AppTheme from "../theme/AppTheme"; -import Footer from "../components/Footer"; +import Footer from "../components/shared/Footer"; // Helper function to pad numbers with a leading zero if needed const pad = (num: number): string => num.toString().padStart(2, "0") diff --git a/src/server/app.py b/src/server/app.py index b97a9c0..e465f90 100644 --- a/src/server/app.py +++ b/src/server/app.py @@ -12,23 +12,12 @@ blp_delete_session, blp_get_session, blp_save_session, + blp_delete_item, ) from routes.session_store import get_or_init_app_start_epoch -from routes.query import dsn_from_env, blp as query_blp -from routes.jobs import ( - blp_submit_compound, - blp_submit_gene_cluster, -) -from routes.views import ( - blp_get_embedding_space, - blp_enrich, - blp_run_msa, -) -from routes.drawing import ( - blp_draw_compound_item, - blp_draw_gene_cluster_item, -) from routes.events import blp_events +from routes.database import dsn_from_env +from routes.compound import blp_search_compound, blp_submit_compound # Initialize the Flask app @@ -159,12 +148,7 @@ def ready() -> 
tuple[dict[str, str], int]: app.register_blueprint(blp_delete_session) app.register_blueprint(blp_get_session) app.register_blueprint(blp_save_session) -app.register_blueprint(query_blp) +app.register_blueprint(blp_delete_item) +app.register_blueprint(blp_search_compound) app.register_blueprint(blp_submit_compound) -app.register_blueprint(blp_submit_gene_cluster) -app.register_blueprint(blp_get_embedding_space) -app.register_blueprint(blp_enrich) -app.register_blueprint(blp_run_msa) -app.register_blueprint(blp_draw_compound_item) -app.register_blueprint(blp_draw_gene_cluster_item) app.register_blueprint(blp_events) diff --git a/src/server/routes/_retromol.py b/src/server/routes/_retromol.py deleted file mode 100644 index 18b85ce..0000000 --- a/src/server/routes/_retromol.py +++ /dev/null @@ -1,115 +0,0 @@ -"""Custom implementation for RetroMol linear backbone readouts.""" - -from typing import Any - -from networkx import connected_components -from retromol.io import Result -from retromol.readout import ( - _graphs_with_metadata, - _is_path_component, - _longest_path_approx, - _monomer_nodes_at_level, - _order_nodes_along_path, - _payload_from_order, -) - - -def retromol_linear_readout( - result: Result, - require_identified: bool = True, - mode: str = "all", # "all" | "best_per_level" | "global_best" - nesting_depth: int | None = None, -) -> dict[str, Any]: - """ - Linear backbone readouts. - - :param result: RetroMol Result object - :param require_identified: if ``True``, only consider monomer nodes with an - assigned identity. If ``False``, consider all monomer nodes - :param nesting_depth: maximum nesting level to analyze. If ``None``, all depths are included - - when ``nesting_depth`` is ``None``: iterate all graphs in DFS order and return - a structure identical to the previous version (keys and shapes), but each - entry now also includes a ``depth`` field for clarity. 
- - set ``nesting_depth = k`` to restrict analysis to graphs at that **true** - nesting level (root = 0, its children = 1, etc.). - :param mode: determines the aggregation mode of readouts - Supported values: - - ``"all"``: return all depth levels and paths. - - :returns: - Depending on the selected ``mode``: - - **mode = "all"** - Returns: - ``{"levels": [ - {"dfs_index": int, "depth": int, - "strict_paths": [payload, ...], - "fallback": payload_or_None}, - ... - ]}`` - """ - metas = _graphs_with_metadata(result.graph) - if nesting_depth is not None: - metas = [m for m in metas if m["depth"] == nesting_depth] - if not metas: - msg = f"No graphs at nesting_depth={nesting_depth}." - if mode == "global_best": - return { - "dfs_index": -1, - "depth": nesting_depth, - "strict_path": False, - "backbone": {"n_monomers": 0, "ordered_monomers": []}, - "notes": msg, - } - else: - return {"levels": [], "notes": msg} - - # Per-graph analysis - entries: list[dict[str, Any]] = [] - for m in metas: - G = m["graph"] - dfs_idx = m["dfs_index"] - depth = m["depth"] - - parent_smiles_tagged = G.graph["smiles"] - - monomer_nodes = _monomer_nodes_at_level(G, require_identified) - if not monomer_nodes: - entries.append( - { - "dfs_index": dfs_idx, - "depth": depth, - "strict_paths": [], - "fallback": None, - } - ) - continue - - MG = G.subgraph(monomer_nodes).copy() - comps = list(connected_components(MG)) - - strict_payloads: list[dict[str, Any]] = [] - for comp in comps: - nodes = list(comp) - if _is_path_component(MG, nodes): - order = _order_nodes_along_path(MG, nodes) - strict_payloads.append(_payload_from_order(G, order)) - - fallback_payload = None - if not strict_payloads and comps: - largest = max(comps, key=len) - approx_order = _longest_path_approx(MG, list(largest)) - fallback_payload = _payload_from_order(G, approx_order) - - entries.append( - { - "parent_smiles_tagged": parent_smiles_tagged, - "dfs_index": dfs_idx, - "depth": depth, - "strict_paths": strict_payloads, - 
"fallback": fallback_payload, - } - ) - - entries.sort(key=lambda e: e["dfs_index"]) - return {"levels": entries} \ No newline at end of file diff --git a/src/server/routes/compound.py b/src/server/routes/compound.py new file mode 100644 index 0000000..918d895 --- /dev/null +++ b/src/server/routes/compound.py @@ -0,0 +1,190 @@ +"""Blueprints for compound-related API endpoints.""" + +from __future__ import annotations + +import time + +from flask import Blueprint, current_app, jsonify, request +from sqlalchemy import select +from bionexus.db.models import Compound, Reference +from retromol.model.rules import RuleSet +from retromol.model.submission import Submission +from retromol.model.result import Result +from retromol.pipelines.parsing import run_retromol + +from routes.session_store import load_session_with_items, update_item +from routes.database import SessionLocal + +blp_search_compound = Blueprint("search_compound", __name__) +blp_submit_compound = Blueprint("submit_compound", __name__) + +DEFAULT_LIMIT = 10 +MAX_LIMIT = 50 + + +@blp_search_compound.get("/api/searchCompound") +def search_compound_by_name(): + """ + Autocomplete endpoint for compounds by name-like query. 
+ """ + q = (request.args.get("q") or "").strip() + if not q: + return jsonify({"rows": [], "rowCount": 0}), 200 + + try: + limit = int(request.args.get("limit", DEFAULT_LIMIT)) + except ValueError: + limit = DEFAULT_LIMIT + limit = max(1, min(MAX_LIMIT, limit)) + + like = f"%{q}%" + + stmt = ( + select( + Reference.name, + Reference.database_name, + Reference.database_identifier, + Compound.smiles, + ) + .join(Reference.compounds) + .where(Reference.name.ilike(like)) + .order_by(Reference.name.asc()) + .limit(limit) + ) + + with SessionLocal() as session: + rows = session.execute(stmt).all() + + out = [ + { + "name": name, + "smiles": smiles, + "databaseName": database_name, + "databaseIdentifier": database_identifier, + } + for (name, database_name, database_identifier, smiles) in rows + if name and smiles and database_name and database_identifier + ] + + return jsonify({"rows": out, "rowCount": len(out)}), 200 + + +def _set_item_status_inplace(item: dict, status: str, error_message: str | None = None) -> None: + """ + Update the status and error message of an item in place. + + :param item: the item dictionary to update + :param status: the new status string + :param error_message: optional error message string + """ + item["status"] = status + item["updatedAt"] = int(time.time() * 1000) + + if error_message is not None: + item["errorMessage"] = error_message + else: + if "errorMessage" in item: + item["errorMessage"] = None + + +@blp_submit_compound.post("/api/submitCompound") +def submit_compound(): + """ + payload = request.get_json(force=True) or {} + + session_id = payload.get("sessionId") + item_id = payload.get("itemId") + name = payload.get("name") + smiles = payload.get("smiles") + + Endpoint to submit a compound by SMILES string. 
+ """ + payload = request.get_json(force=True) or {} + session_id = payload.get("sessionId") + item_id = payload.get("itemId") + name = payload.get("name") + smiles = payload.get("smiles") + match_stereochemistry = payload.get("matchStereochemistry", False) + + current_app.logger.info(f"submit_compound called: session_id={session_id} item_id={item_id}") + + if not session_id or not item_id: + current_app.logger.warning("submit_compound: missing sessionId or itemId") + return jsonify({"error": "Missing sessionId or itemId"}), 400 + + # Validate session + item exists and kind is correct + full_sess = load_session_with_items(session_id) + if full_sess is None: + current_app.logger.warning(f"submit_compound: session not found: {session_id}") + return jsonify({"error": "Session not found"}), 404 + + item = next((it for it in full_sess.get("items", []) if it.get("id") == item_id), None) + if item is None: + current_app.logger.warning(f"submit_compound: item not found: {item_id}") + return jsonify({"error": "Item not found"}), 404 + + if item.get("kind") != "compound": + current_app.logger.warning(f"submit_compound: wrong kind={item.get('kind')}") + return jsonify({"error": "Item is not a compound"}), 400 + + t0 = time.time() + + # Set status=processing early on this item only + def mark_processing(it: dict) -> None: + """ + Update item details and mark as processing. 
+ + :param it: the item dictionary to update + """ + it["name"] = name or it.get("name") + it["smiles"] = smiles or it.get("smiles") + _set_item_status_inplace(it, "processing") + + ok = update_item(session_id, item_id, mark_processing) + if not ok: + current_app.logger.warning(f"submit_compound: failed to mark item as processing: {item_id}") + return jsonify({"error": "Item not found during update"}), 404 + + try: + # Heavy work + ruleset = RuleSet.load_default(match_stereochemistry=match_stereochemistry) + submission = Submission(smiles, props={}) + result: Result = run_retromol(submission, ruleset) + coverage = result.calculate_coverage() + result_as_dict = result.to_dict() + + # Set final status=done and store results on this item only + def mark_done(it: dict) -> None: + it["name"] = name or it.get("name") + it["smiles"] = smiles or it.get("smiles") + it["matchStereochemistry"] = match_stereochemistry + it["score"] = coverage + it["payload"] = result_as_dict + _set_item_status_inplace(it, "done") + + update_item(session_id, item_id, mark_done) + + except Exception as e: + current_app.logger.exception(f"submit_compound: error for item_id={item_id}") + + def mark_error(it: dict) -> None: + _set_item_status_inplace(it, "error", error_message=str(e)) + + update_item(session_id, item_id, mark_error) + + elapsed = int((time.time() - t0) * 1000) + return jsonify({ + "ok": False, + "status": "error", + "elapsed_ms": elapsed, + "error": str(e), + }), 500 + + elapsed = int((time.time() - t0) * 1000) + current_app.logger.info(f"submit_compound: finished item_id={item_id} elapsed_ms={elapsed}") + + return jsonify({ + "ok": True, + "status": "done", + "elapsed_ms": elapsed, + }), 200 diff --git a/src/server/routes/database.py b/src/server/routes/database.py new file mode 100644 index 0000000..435cec7 --- /dev/null +++ b/src/server/routes/database.py @@ -0,0 +1,35 @@ +"""Database connection setup using SQLAlchemy.""" + +import os + +from sqlalchemy.orm import sessionmaker 
+from sqlalchemy import create_engine + + +def dsn_from_env() -> str: + """ + Construct the Postgres DSN from environment variables. + + :return: the Postgres DSN string + """ + dsn = os.getenv("DATABASE_URL") + if dsn: + # If plain postgresql:// URL was provided, force psycopg v3 driver + if dsn.startswith("postgresql://"): + dsn = dsn.replace("postgresql://", "postgresql+psycopg://", 1) + # When explicitly using psycopg2, upgrade it too (optional but helpful) + if dsn.startswith("postgresql+psycopg2://"): + dsn = dsn.replace("postgresql+psycopg2://", "postgresql+psycopg://", 1) + return dsn + + host = os.getenv("DB_HOST", "db") + port = os.getenv("DB_PORT", "5432") + name = os.getenv("DB_NAME", "bionexus") + user = os.getenv("DB_USER", "app_ro") + pwd = os.getenv("DB_PASS") or os.getenv("DB_PASSWORD", "apppass_ro") + + return f"postgresql+psycopg://{user}:{pwd}@{host}:{port}/{name}" + + +engine = create_engine(dsn_from_env(), pool_pre_ping=True) +SessionLocal = sessionmaker(bind=engine, autoflush=False, expire_on_commit=False) diff --git a/src/server/routes/drawing.py b/src/server/routes/drawing.py deleted file mode 100644 index c12dc7d..0000000 --- a/src/server/routes/drawing.py +++ /dev/null @@ -1,532 +0,0 @@ -"""Module for handling drawing-related routes.""" - -import re -from enum import Enum -from dataclasses import dataclass - -from flask import Blueprint, request, jsonify -from retromol.chem import smiles_to_mol -from raichu.run_raichu import draw_cluster -from raichu.antismash import get_nrps_pks_modules -from pikachu.general import read_smiles -from pikachu.drawing.drawing import Drawer, Options - - -blp_draw_compound_item = Blueprint("draw_compound_item", __name__) -blp_draw_gene_cluster_item = Blueprint("draw_gene_cluster_item", __name__) - - -class Palette(Enum): - Orange = (230, 159, 0) - SkyBlue = (86, 180, 233) - Green = (3, 158, 115) - Yellow = (240, 228, 66) - Blue = (0, 114, 178) - Red = (213, 95, 0) - Pink = (204, 121, 167) - - def hex(self, 
alpha: float) -> str: - """ - Get hex representation of the color with specified alpha transparency. - - :param alpha: alpha transparency (0.0 to 1.0) - :return: hex color string with alpha - """ - return f"#{self.value[0]:02x}{self.value[1]:02x}{self.value[2]:02x}{int(alpha * 255):02x}" - - def normalize(self, min_val: float = 0.0, max_val: float = 255.0) -> tuple[float, float, float]: - """ - Get normalized RGB tuple of the color. - - :param min_val: minimum value for normalization - :param max_val: maximum value for normalization - :return: normalized RGB tuple - """ - r, g, b = self.value - return ( - (r - min_val) / (max_val - min_val), - (g - min_val) / (max_val - min_val), - (b - min_val) / (max_val - min_val), - ) - - -def hex_to_rgb_tuple(hex_str: str) -> tuple[float, float, float]: - """ - Convert hex color string to normalized RGB tuple. - - :param hex_str: hex color string (e.g. "#ff5733" or "#ff5733ff") - :return: normalized RGB tuple - """ - hex_str = hex_str.lstrip("#") - if len(hex_str) == 6: - r, g, b = int(hex_str[0:2], 16), int(hex_str[2:4], 16), int(hex_str[4:6], 16) - elif len(hex_str) == 8: - r, g, b = int(hex_str[0:2], 16), int(hex_str[2:4], 16), int(hex_str[4:6], 16) - # alpha = int(hex_str[6:8], 16) # Alpha is ignored in this function - else: - raise ValueError(f"Invalid hex color string: {hex_str}") - return (r / 255.0, g / 255.0, b / 255.0) - - -@dataclass -class Highlight: - """ - Class representing a highlight for drawing. - - :param tags: set of integer tags to highlight - :param color: RGBA color tuple for the highlight - """ - - display_name: str - tags: set[int] - color: tuple[float, float, float] - - -def rgba_to_hex(rgba: tuple[float, float, float]) -> str: - """ - Convert RGBA tuple to hex color string. 
- - :param rgba: RGBA color tuple - :return: tuple of hex color string - """ - r, g, b = rgba - r_i = int(r * 255) - g_i = int(g * 255) - b_i = int(b * 255) - return f"#{r_i:02x}{g_i:02x}{b_i:02x}" - - -def extract_svg_body(svg_str: str) -> tuple[float, float, str]: - """ - Extract width, height and inner content from an SVG string. - - :param svg_str: SVG string - :return: tuple of (width, height, inner SVG content) - :raises ValueError: if SVG tag or dimensions cannot be parsed - """ - # Find opening tag - m = re.search(r"]*>", svg_str) - if not m: - raise ValueError("No tag found in RDKit/PIKAChU SVG") - - svg_open_tag = m.group(0) - inner = svg_str[m.end():] - # Remove closing tag - inner = inner.replace("", "") - - width_match = re.search(r'width=(["\'])([\d.]+)(?:px)?\1', svg_open_tag) - height_match = re.search(r'height=(["\'])([\d.]+)(?:px)?\1', svg_open_tag) - - if not width_match or not height_match: - raise ValueError("Could not parse width/height from SVG header") - - width = float(width_match.group(2)) - height = float(height_match.group(2)) - - return width, height, inner - - -def build_compound_scheme_svg( - mol1_width: float, - mol1_height: float, - mol1_inner_svg: str, - mol2_width: float, - mol2_height: float, - mol2_inner_svg: str, - highlights: list[Highlight], - arrow_labels: list[str] | None = None, -) -> str: - """ - Build a full SVG with: [structure] --arrow--> [structure] --arrow--> [primary sequence] - - :param mol1_width: width of single molecule drawing - :param mol1_height: height of single molecule drawing - :param mol1_inner_svg: inner SVG content of molecule (no outer tags) - :param mol2_width: width of second molecule drawing - :param mol2_height: height of second molecule drawing - :param mol2_inner_svg: inner SVG content of second molecule (no outer tags) - :param highlights: motif info for the primary sequence panel - :param arrow_labels: optional labels above the arrows, e.g. 
["step 1", "step 2"] - """ - arrow_labels = arrow_labels or ["", ""] - - # Layout constants - PADDING = 20.0 - H_GAP = 25.0 # gap between elements - ARROW_LEN = 120.0 - AVG_CHAR_WIDTH = 8.0 - MIN_BOX_WIDTH = 40.0 - H_TEXT_PADDING = 0.0 - SEQ_BOX_HEIGHT = 24.0 - SEQ_BOX_GAP = 4.0 - - n_motifs = len(highlights) - if n_motifs > 0: - seq_panel_height = ( - 2 * PADDING + - n_motifs * SEQ_BOX_HEIGHT + - (n_motifs - 1) * SEQ_BOX_GAP - ) - else: - seq_panel_height = 2 * PADDING + SEQ_BOX_HEIGHT - - max_mol_height = max(mol1_height, mol2_height) - content_height = max(max_mol_height, seq_panel_height) - total_height = content_height + 2 * PADDING - - # X positions - mol1_x = PADDING - mol1_y = (total_height - mol1_height) / 2.0 - - arrow1_x1 = mol1_x + mol1_width + H_GAP - arrow1_x2 = arrow1_x1 + ARROW_LEN - - mol2_x = arrow1_x2 + H_GAP - mol2_y = (total_height - mol2_height) / 2.0 - - arrow2_x1 = mol2_x + mol2_width + H_GAP - arrow2_x2 = arrow2_x1 + ARROW_LEN - - seq_x = arrow2_x2 + H_GAP - seq_y = (total_height - seq_panel_height) / 2.0 - - text_widths = [ - len(h.display_name) * AVG_CHAR_WIDTH + H_TEXT_PADDING - for h in highlights - ] if highlights else [MIN_BOX_WIDTH] - - widest_box = max(max(text_widths), MIN_BOX_WIDTH) - - seq_panel_width = widest_box + 2 * PADDING - - total_width = seq_x + seq_panel_width + PADDING - - arrow_y = total_height / 2.0 - arrow_label_offset = 12.0 # distance above arrow for text - - # Assemble SVG - svg_parts: list[str] = [] - - svg_parts.append( - f'' - ) - - # First molecule - svg_parts.append( - f'' - f'{mol1_inner_svg}' - '' - ) - - # Second molecule - svg_parts.append( - f'' - f'{mol2_inner_svg}' - '' - ) - - # Arrow 1 - head_len = 8.0 - head_width = 6.0 - line1_x2 = arrow1_x2 - head_len # line end before arrowhead - - svg_parts.append( - f'' - ) - - tip_x1 = arrow1_x2 - tip_y1 = arrow_y - base_x1 = line1_x2 - base_y1a = arrow_y - head_width / 2.0 - base_y1b = arrow_y + head_width / 2.0 - - svg_parts.append( - f'' - ) - - if 
arrow_labels[0]: - mid_x1 = (arrow1_x1 + arrow1_x2) / 2.0 - svg_parts.append( - f'{arrow_labels[0]}' - ) - - # Arrow 2 - line2_x2 = arrow2_x2 - head_len # line end before arrowhead - - svg_parts.append( - f'' - ) - - tip_x2 = arrow2_x2 - tip_y2 = arrow_y - base_x2 = line2_x2 - base_y2a = arrow_y - head_width / 2.0 - base_y2b = arrow_y + head_width / 2.0 - - svg_parts.append( - f'' - ) - - if arrow_labels[1]: - mid_x2 = (arrow2_x1 + arrow2_x2) / 2.0 - svg_parts.append( - f'{arrow_labels[1]}' - ) - - # Primary sequence panel on the right (vertical legend) - # Draw background panel (optional, can remove if you don’t want it) - title_text = "primary sequence" - title_font_size = 14.0 - - title_x = seq_x + seq_panel_width / 2.0 - title_y = seq_y - 6.0 - - svg_parts.append( - f'{title_text}' - ) - - svg_parts.append( - f'' - ) - - current_y = seq_y + PADDING - for h in highlights: - fill_hex = rgba_to_hex(h.color) - # Box - box_width = widest_box - svg_parts.append( - f'' - ) - # Text in middle - text_x = seq_x + PADDING + 4 - text_y = current_y + SEQ_BOX_HEIGHT / 2.0 - svg_parts.append( - f'{h.display_name}' - ) - - current_y += SEQ_BOX_HEIGHT + SEQ_BOX_GAP - - svg_parts.append("") - - return "".join(svg_parts) - - -def draw_structure_with_pikachu(drawer: Drawer) -> str: - """ - Draw a molecular structure using Pikachu and return the SVG string. 
- - :param drawer: Drawer object with the molecular structure - :return: SVG string of the drawn structure - """ - drawer.flip_y_axis() - drawer.move_to_positive_coords() - drawer.convert_to_int() - - min_x = 1e9 - max_x = -1e9 - min_y = 1e9 - max_y = -1e9 - - # First pass: get bounds - for atom in drawer.structure.graph: - if atom.draw.positioned: - x = atom.draw.position.x - y = atom.draw.position.y - if x < min_x: min_x = x - if x > max_x: max_x = x - if y < min_y: min_y = y - if y > max_y: max_y = y - - padding = drawer.options.padding - width = max_x - min_x + 2 * padding - height = max_y - min_y + 2 * padding - - # Second pass: shift coordinates into viewbox - shift_x = padding - min_x - shift_y = padding - min_y - - for atom in drawer.structure.graph: - if atom.draw.positioned: - atom.draw.position.x += shift_x - atom.draw.position.y += shift_y - - svg_string = f"""""" - svg_string += drawer.svg_style - svg_string += drawer.draw_svg(annotation=None, numbered_atoms=None) - svg_string += "" - - return svg_string - - -def highlight_pikachu_atoms(tagged_smiles: str, highlights: list[Highlight]) -> str: - """ - Highlight atoms in a molecule using Pikachu based on provided tags and colors. 
- - :param tagged_smiles: SMILES string of the tagged molecule - :param highlights: list of Highlight objects specifying tags and colors - :return: SVG string of the drawn molecule with highlights - """ - # We have to translate isotope-stored tags from the SMILES to atom.nr in PIKAChU - # PIKAChU also labels hydrogen atoms, which we have to ignore for the highlights - - # Create mapping: isotope tag to idx (read order in SMILES) - tag_to_idx = {} - tagged_parent = smiles_to_mol(tagged_smiles) - for atom in tagged_parent.GetAtoms(): - idx = atom.GetIdx() - tag = atom.GetIsotope() - if tag > 0: - tag_to_idx[tag] = idx - - # First map every tag to its corresponding Highlight's color - color_map = {} - for highlight in highlights: - for tag in highlight.tags: - if tag in tag_to_idx: - atom_idx = tag_to_idx[tag] - color_map[atom_idx + 1] = rgba_to_hex(highlight.color) - - # Now create a lookup of every non-hydogen atom - # PIKAChU reads the SMILES string from left-to-right so we can just count - structure = read_smiles(tagged_smiles) - non_h_count = 0 - for atom in structure.get_atoms(): - if atom.type == "H": - continue - - non_h_count += 1 - atom.draw.colour = color_map.get(non_h_count, "black") - - options = Options() - drawer = Drawer(structure, options=options, coords_only=True, kekulise=True) - mol_svg_str = draw_structure_with_pikachu(drawer) - - return mol_svg_str - - -def draw_highlights( - tagged_parent_smiles: str, - tagged_subparent_smiles: str, - highlights: list[Highlight], -) -> None: - """ - Draw highlights on a molecule given its tagged SMILES representation. 
- - :param tagged_parent_smiles: SMILES string of the tagged parent molecule - :param tagged_subparent_smiles: SMILES string of the tagged subparent molecule - :param highlights: list of Highlight objects specifying tags and alpha values - :return: SVG string of the drawn molecule with highlights - :raises ValueError: if an unknown drawing engine is specified - """ - mol1_svg_str = highlight_pikachu_atoms(tagged_parent_smiles, highlights) - mol2_svg_str = highlight_pikachu_atoms(tagged_subparent_smiles, highlights) - - # Draw other SVG elements and inject structures - mol1_w, mol1_h, mol1_inner = extract_svg_body(mol1_svg_str) - mol2_w, mol2_h, mol2_inner = extract_svg_body(mol2_svg_str) - - arrow_labels = ["preprocess", "sequence"] - svg_str = build_compound_scheme_svg( - mol1_width=mol1_w, - mol1_height=mol1_h, - mol1_inner_svg=mol1_inner, - mol2_width=mol2_w, - mol2_height=mol2_h, - mol2_inner_svg=mol2_inner, - highlights=highlights, - arrow_labels=arrow_labels, - ) - - return svg_str - - -@blp_draw_compound_item.post("/api/drawCompoundItem") -def draw_compound_item(): - """ - Endpoint to handle drawing of compound items. 
- - :return: JSON response with query results or error message - """ - payload = request.get_json(force=True) or {} - - # Check required fields - tagged_parent_smiles = payload.get("taggedParentSmiles", None) - primary_sequence = payload.get("primarySequence", None) - tagged_subparent_smiles = primary_sequence.get("parentSmilesTagged", None) - - if ( - tagged_parent_smiles is None - or primary_sequence is None - or tagged_subparent_smiles is None - ): - return jsonify({"svg": ""}), 500 - - # Parse out highlights - highlights: list[Highlight] = [] - palette = [c.normalize() for c in Palette] - for motif_idx, motif in enumerate(primary_sequence.get("sequence", [])): - display_name = motif.get("name", None) - if not display_name: - display_name = f"motif {motif_idx + 1}" - - color = palette[motif_idx % len(palette)] - tags = motif.get("tags", []) - highlights.append(Highlight( - display_name=display_name, - tags=set(tags), - color=color) - ) - - # Draw highlights - svg_str = draw_highlights( - tagged_parent_smiles=tagged_parent_smiles, - tagged_subparent_smiles=tagged_subparent_smiles, - highlights=highlights, - ) - - return jsonify({"svg": svg_str}), 200 - - -@blp_draw_gene_cluster_item.post("/api/drawGeneClusterItem") -def draw_gene_cluster_item(): - """ - Endpoint to handle drawing of gene cluster items. 
- - :return: JSON response with query results or error message - """ - payload = request.get_json(force=True) or {} - - fileContent = payload.get("fileContent", None) - if fileContent is None: - return jsonify({"svg": "", "error": "Missing fileContent"}), 500 - - try: - modules = get_nrps_pks_modules(fileContent, file_mode="file_content") - cluster_repr = modules.make_raichu_cluster() - svg_str = draw_cluster(cluster_repr, out_file=None, colour_by_module=False) - except Exception as e: - return jsonify({"svg": "", "error": str(e)}), 500 - - return jsonify({"svg": svg_str}), 200 diff --git a/src/server/routes/events.py b/src/server/routes/events.py index f60eddf..6ef0b41 100644 --- a/src/server/routes/events.py +++ b/src/server/routes/events.py @@ -5,7 +5,7 @@ import redis from flask import Blueprint, Response, request, stream_with_context -from routes.session_store import REDIS_URL, publish_session_event, load_session_meta +from routes.session_store import REDIS_URL, load_session_meta blp_events = Blueprint("events", __name__) diff --git a/src/server/routes/helpers.py b/src/server/routes/helpers.py deleted file mode 100644 index 6a8e66d..0000000 --- a/src/server/routes/helpers.py +++ /dev/null @@ -1,97 +0,0 @@ -"""Module providing helper functions for endpoints.""" - -import uuid -from typing import Any - -import numpy as np - - -BITS_PER_HEX_DIGIT = 4 - - -def get_unique_identifier() -> str: - """ - Generate a unique identifier string. - - :return: unique identifier as a string - """ - return str(uuid.uuid4()) - - -def bits_to_hex(bits: np.ndarray, n_bits: int = 512) -> str: - """ - Convert a numpy array of bits (0/1 ints) into a hexadecimal string representation. 
- - :param bits: numpy array of shape (n_bits,) or (1, n_bits) with values 0 or 1 - :param n_bits: expected number of bits (default is 512) - :return: hexadecimal string representation - :raises ValueError: if input array shape is incorrect or contains invalid values - """ - arr = np.asarray(bits, dtype=np.int8).reshape(-1) - - if arr.size != n_bits: - raise ValueError(f"Input array must have shape ({n_bits},) or (1, {n_bits})") - - if n_bits % BITS_PER_HEX_DIGIT != 0: - raise ValueError("Number of bits must be a multiple of 4 in order to convert to hexadecimal") - - # Guard that values are actually 0/1 - if not np.isin(arr, (0, 1)).all(): - raise ValueError("Input array must only contain 0 and 1 values") - - bitstring = "".join("1" if b else "0" for b in arr) # n_bits characters - hex_len = n_bits // BITS_PER_HEX_DIGIT - hexstr = format(int(bitstring, 2), f"0{hex_len}x") - - return hexstr - - -def hex_to_bits(hexstr: str, n_bits: int = 512) -> np.ndarray: - """ - Convert a hexadecimal string representation into a numpy array of bits (0/1 ints). 
- - :param hexstr: hexadecimal string representation - :param n_bits: expected number of bits (default is 512) - :return: numpy array of shape (n_bits,) with values 0 or 1 - :raises ValueError: if input string is invalid or does not match expected bit length - """ - hexstr = hexstr.strip().lower() - - # Basic hex validation - if not hexstr: - raise ValueError("Input hexadecimal string is empty") - if any(c not in "0123456789abcdef" for c in hexstr): - raise ValueError("Input string contains non-hexadecimal characters") - - inferred_bits = len(hexstr) * BITS_PER_HEX_DIGIT - if inferred_bits != n_bits: - raise ValueError(f"Input hexadecimal string must represent {n_bits} bits (length {n_bits // BITS_PER_HEX_DIGIT})") - - bit_int = int(hexstr, 16) - bitstring = bin(bit_int)[2:].zfill(n_bits) # binary string of length n_bits - - return np.fromiter((1 if b == "1" else 0 for b in bitstring), dtype=np.int8, count=n_bits) - - -def kmerize_sequence(sequence: list[Any], k: int) -> list[list[Any]]: - """ - Generate k-mers from a given sequence (forward and backward). 
- - :param sequence: list of elements (e.g., amino acids) - :param k: length of each k-mer - :return: list of k-mer strings - """ - kmers = [] - seq_length = len(sequence) - - # Forward k-mers - for i in range(seq_length - k + 1): - kmer = sequence[i:i + k] - kmers.append(kmer) - - # Backward k-mers - for i in range(seq_length - k, -1, -1): - kmer = sequence[i:i + k] - kmers.append(kmer) - - return kmers diff --git a/src/server/routes/jobs.py b/src/server/routes/jobs.py deleted file mode 100644 index 61f8188..0000000 --- a/src/server/routes/jobs.py +++ /dev/null @@ -1,644 +0,0 @@ -"""Module for defining job endpoints.""" - -import tempfile -import time -import dataclasses -from typing import Literal - -import numpy as np -from flask import Blueprint, current_app, request, jsonify -from retromol.api import run_retromol -from retromol.chem import ( - smiles_to_mol, - get_tags_mol, - mol_to_smiles, - mol_to_fpr, -) -from retromol.fingerprint import ( - FingerprintGenerator, - NameSimilarityConfig, - polyketide_family_of, - polyketide_ancestors_of, -) -from retromol.io import Input as RetroMolInput -from retromol.rules import get_path_default_matching_rules -from retromol.helpers import blake64_hex -from routes._retromol import retromol_linear_readout -from biocracker.antismash import parse_region_gbk_file -from biocracker.readout import NRPSModuleReadout, PKSModuleReadout, linear_readouts as biocracker_linear_readouts -from biocracker.text_mining import get_default_tokenspecs, mine_virtual_tokens - -from routes.helpers import bits_to_hex, get_unique_identifier, kmerize_sequence -from routes.models_registry import get_cache_dir, get_paras_model -from routes.session_store import load_session_with_items, update_item - -blp_submit_compound = Blueprint("submit_compound", __name__) -blp_submit_gene_cluster = Blueprint("submit_gene_cluster", __name__) - - -COLLAPSE_BY_NAME = { - "glycosylation": ["glycosyltransferase"], - "methylation": ["methyltransferase"], - 
"siderophore": ["siderophore"], -} - - -def _set_item_status_inplace(item: dict, status: str, error_message: str | None = None) -> None: - """ - Update the status and error message of an item in place. - - :param item: the item dictionary to update - :param status: the new status string - :param error_message: optional error message string - """ - item["status"] = status - item["updatedAt"] = int(time.time() * 1000) - - if error_message is not None: - item["errorMessage"] = error_message - else: - if "errorMessage" in item: - item["errorMessage"] = None - - -def _stable_name_token(nm: str) -> str: - # Deterministic, case-insensitive token for raw names / name-groups - return f"NM:{blake64_hex('NAMEGROUP:' + (nm or '').lower())}" - - -def add_name_group(generator: FingerprintGenerator, name: str) -> None: - """ - Ensure there is a name-based group for `name` in the generator. - - If such a group already exists, this is a no-op (except ensuring - the name is in collapse_by_name). If it doesn't exist, we clone - an existing name-group as a template and add a new one. - """ - # 1) Does a name-group for this name already exist? 
- for g in generator.groups: - if getattr(g, "kind", None) == "name" and getattr(g, "name_key", None) == name: - # Make sure it's in collapse_by_name - if name not in (generator.collapse_by_name or []): - generator.collapse_by_name.append(name) - return - - # 2) Find a template name-group to clone - try: - template = next(g for g in generator.groups if getattr(g, "kind", None) == "name") - except StopIteration: - raise RuntimeError("No existing name-based Group found to clone as a template") - - # 3) Build a new token for this name-group - token_fine = _stable_name_token(name) - - # 4) Clone the template and modify fields; adjust to your actual Group schema if needed - new_group = dataclasses.replace( - template, - name_key=name, - token_fine=token_fine, - ) - - # 5) Append to generator state - generator.groups.append(new_group) - if name not in (generator.collapse_by_name or []): - generator.collapse_by_name.append(name) - - # 6) Invalidate caches, since group set changed - generator._assign_cache.clear() - generator._token_bytes_cache.clear() - - -def _setup_fingerprint_generator() -> FingerprintGenerator: - """ - Setup and return a FingerprintGenerator instance. - - :return: FingerprintGenerator instance - """ - path_default_matching_rules = get_path_default_matching_rules() - collapse_by_name: list[str] = list(COLLAPSE_BY_NAME.keys()) - cfg = NameSimilarityConfig( - # family_of=polyketide_family_of, - # family_repeat_scale=1, - ancestors_of=polyketide_ancestors_of, - ancestor_repeat_scale=1, - symmetric=True, - ) - generator = FingerprintGenerator( - matching_rules_yaml=path_default_matching_rules, - collapse_by_name=collapse_by_name, - name_similarity=cfg - ) - add_name_group(generator, "siderophore") - return generator - - -def _compute_compound(generator: FingerprintGenerator, smiles: str) -> tuple[str, list[float], list[str], list[dict]]: - """ - Compute 512-bit fingerprint for a compound given its SMILES. 
- - :param generator: the fingerprint generator instance - :param smiles: the SMILES string of the compound - :return: tuple of (tagged smiles string, list of coverage values, list of fingerprint hex strings, list of linear readouts) - """ - # Parse compound with RetroMol - input_data = RetroMolInput(cid="compound", repr=smiles) - result = run_retromol(input_data) - - # Retrieve tagged SMILES from result - tagged_smiles_input = result.get_input_smiles(remove_tags=False) - - # Calculate coverage - cov = result.best_total_coverage() - - # Calculate linear readouts - readout = retromol_linear_readout(result, require_identified=False) # TODO: change to "global_best" mode - linear_readouts = [] - for level_idx, level in enumerate(readout["levels"]): - - # We need the tagged parent SMILES for visualization - parent_smiles_tagged = level.get("parent_smiles_tagged", None) - if not parent_smiles_tagged: - raise ValueError("Missing parent_smiles_tagged in level data") - - for path_idx, path in enumerate(level["strict_paths"]): - ms = path["ordered_monomers"] - if len(ms) <= 2: continue # skip too short - - ms_fwd = [] - for m in ms: - m_id = get_unique_identifier() - m_name = m.get("identity", "unknown") - m_display_name = None - - # Process motif SMILES, if any - tagged_smiles = m.get("smiles", None) - if tagged_smiles: - mol = smiles_to_mol(tagged_smiles) - tags: list[int] = get_tags_mol(mol) - clean_smiles = mol_to_smiles(mol, remove_tags=True) - clean_mol = smiles_to_mol(clean_smiles) - fp_bits = mol_to_fpr(clean_mol, rad=2, nbs=2048).reshape(-1).astype(np.uint8) - morgan_fp_hex = bits_to_hex(fp_bits, n_bits=2048) - else: - tags = [] - clean_smiles = None - morgan_fp_hex = None - - ms_fwd.append({ - "id": m_id, - "name": m_name, - "displayName": m_display_name, - "tags": tags, - "smiles": clean_smiles, - "morganfingerprint2048r2": morgan_fp_hex, - }) - - # Get other direction - ms_rev = list(reversed(ms_fwd)) - - linear_readouts.append({ - "id": get_unique_identifier(), 
- "name": f"level{level_idx}_path{path_idx}_fwd", - "parentSmilesTagged": parent_smiles_tagged, - "sequence": ms_fwd, - }) - linear_readouts.append({ - "id": get_unique_identifier(), - "name": f"level{level_idx}_path{path_idx}_rev", - "parentSmilesTagged": parent_smiles_tagged, - "sequence": ms_rev, - }) - - # Generate retrofingerprints - fps: np.ndarray = generator.fingerprint_from_result(result, num_bits=512, counted=False) # shape [N, 512] where N>=1 - - # Check linear readouts for any family tokens to add - family_tokens = set() - for lr in linear_readouts: - for monomer in lr["sequence"]: - monomer_id = monomer.get("name", None) - siderophore_related = [ - "N-(5-aminopentyl)hydroxylamine" - ] - if monomer_id in siderophore_related: - family_tokens.add("siderophore") - - # Create fingerprint for just family tokens - if family_tokens: - kmers = [[(token_name, None)] for token_name in family_tokens] - family_fp: np.ndarray = generator.fingerprint_from_kmers(kmers, num_bits=512, counted=False) - - # If bits of family token fingerprints are not yet flipped in fps, flip them now - for fp in fps: - # if there are bit sin family_fp not set in fp, set them - for bit_idx in np.where(family_fp)[0]: - if not fp[bit_idx]: - fp[bit_idx] = True - - # Convert retrofingerprints to hex strings - fp_hex_strings = [bits_to_hex(fp) for fp in fps] if len(fps) > 0 else [np.zeros((512,), dtype=bool)] - - return ( - tagged_smiles_input, - [cov for _ in range(len(fp_hex_strings))], - fp_hex_strings, - linear_readouts - ) - - -def _compute_gene_cluster( - generator: FingerprintGenerator, - itemId: str, - gbk_str: str, - readout_level: Literal["rec", "gene"] = "gene", -) -> tuple[list[float], list[str], list[dict]]: - """ - Dummy function to compute a 512-bit fingerprint as a hex string (128 chars). 
- - :param generator: the fingerprint generator instance - :param itemId: the ID of the gene cluster item - :param gbk_str: the GenBank file content as a string - :param readout_level: the readout level, either "rec" or "gene" - :return: tuple of (list of average prediction values, list of fingerprint hex strings, list of linear readouts) - """ - # Write gbk_str to a temporary file - with tempfile.NamedTemporaryFile(delete=True, suffix=".gbk") as temp_gbk_file: - temp_gbk_file.write(gbk_str.encode("utf-8")) - temp_gbk_file.flush() - gbk_path = temp_gbk_file.name - - # Parse gene cluster file - targets = parse_region_gbk_file(gbk_path, top_level="cand_cluster") # 'region' or 'cand_cluster' top level - - # Configure tokenspecs - tokenspecs = get_default_tokenspecs() - - # Generate readouts - level = readout_level # "rec" or "gene" - avg_pred_vals, fps, linear_readouts = [], [], [] - for target_idx, target in enumerate(targets): - pred_vals = [] - raw_kmers = [] - - # Mine for tokenspecs (i.e., family tokens) - for mined_tokenspec in mine_virtual_tokens(target, tokenspecs): - if token_spec := mined_tokenspec.get("token"): - print(f" Found token_spec: {token_spec}") - for token_name, values in COLLAPSE_BY_NAME.items(): - if token_spec in values: - raw_kmers.append([(token_name, None)]) - - # Optionally load PARAS model - paras_model = get_paras_model() - - # Extract module kmers - for readout in biocracker_linear_readouts( - target, - model=paras_model, - cache_dir_override=get_cache_dir(), - level=level, - pred_threshold=0.1 - ): - kmer, linear_readout = [], [] - for module in readout["readout"]: - match module: - case PKSModuleReadout(module_type="PKS_A") as m: - kmer.append(("A", None)) - pred_vals.append(1.0) - linear_readout.append({ - "id": get_unique_identifier(), - "name": "A", - "displayName": None, - "tags": [], - "smiles": None, - "morganfingerprint2048r2": None, - }) - case PKSModuleReadout(module_type="PKS_B") as m: - kmer.append(("B", None)) - 
pred_vals.append(1.0) - linear_readout.append({ - "id": get_unique_identifier(), - "name": "B", - "displayName": None, - "tags": [], - "smiles": None, - "morganfingerprint2048r2": None, - }) - case PKSModuleReadout(module_type="PKS_C") as m: - kmer.append(("C", None)) - pred_vals.append(1.0) - linear_readout.append({ - "id": get_unique_identifier(), - "name": "C", - "displayName": None, - "tags": [], - "smiles": None, - "morganfingerprint2048r2": None, - }) - case PKSModuleReadout(module_type="PKS_D") as m: - kmer.append(("D", None)) - pred_vals.append(1.0) - linear_readout.append({ - "id": get_unique_identifier(), - "name": "D", - "displayName": None, - "tags": [], - "smiles": None, - "morganfingerprint2048r2": None, - }) - case PKSModuleReadout(module_type="UNCLASSIFIED") as m: - kmer.append(("A", None)) - pred_vals.append(1.0) - linear_readout.append({ - "id": get_unique_identifier(), - "name": "A", - "displayName": None, - "tags": [], - "smiles": None, - "morganfingerprint2048r2": None, - }) - case NRPSModuleReadout() as m: - substrate_name = m.get("substrate_name", None) - substrate_smiles = m.get("substrate_smiles", None) - substrate_score = m.get("score", 0.0) - if substrate_score is None: - substrate_score = 0.0 - kmer.append((substrate_name, substrate_smiles)) - pred_vals.append(substrate_score) - - # Calculate fingerprint for SMILES if present - if substrate_smiles: - clean_mol = smiles_to_mol(substrate_smiles) - fp_bits = mol_to_fpr(clean_mol, rad=2, nbs=2048).reshape(-1).astype(np.uint8) - morgan_fp_hex = bits_to_hex(fp_bits, n_bits=2048) - else: - morgan_fp_hex = None - - linear_readout.append({ - "id": get_unique_identifier(), - "name": substrate_name or "unknown", - "displayName": None, - "tags": [], - "smiles": substrate_smiles, - "morganfingerprint2048r2": morgan_fp_hex, - }) - case _: - print(module) - raise ValueError("Unknown module readout type") - - if len(kmer) > 0: - raw_kmers.append(kmer) - - # if len(linear_readout) >= 2: # skip too short 
- linear_readouts.append({ - "id": get_unique_identifier(), - "name": f"{itemId}_readout_{len(linear_readouts)+1}_{level}_{target_idx}", - "parentSmilesTagged": None, - "sequence": linear_readout, - }) - - # Mine for kmers of lengths 1 to 3 - kmers = [] - kmer_lengths = [1, 2, 3] - for k in kmer_lengths: - for raw_kmer in raw_kmers: - kmers.extend(kmerize_sequence(raw_kmer, k)) - - # Generate fingerprint - fp: np.ndarray = generator.fingerprint_from_kmers(kmers, num_bits=512, counted=False) - - # Convert to hex string - fp_hex_string = bits_to_hex(fp) - - # Calculate average prediction value - avg_pred_val = float(np.mean(pred_vals)) if len(pred_vals) > 0 else 0.0 - - avg_pred_vals.append(avg_pred_val) - fps.append(fp_hex_string) - - return avg_pred_vals, fps, linear_readouts - - -@blp_submit_compound.post("/api/submitCompound") -def submit_compound() -> tuple[dict[str, str], int]: - """ - Endpoint to submit a compound for processing. - - Expected JSON body: - - sessionId: str - - itemId: str - - name: str - - smiles: str - - :return: a tuple containing a JSON response and HTTP status code - """ - payload = request.get_json(force=True) or {} - - session_id = payload.get("sessionId") - item_id = payload.get("itemId") - name = payload.get("name") - smiles = payload.get("smiles") - - current_app.logger.info(f"submit_compound called: session_id={session_id} item_id={item_id}") - - if not session_id or not item_id: - current_app.logger.warning("submit_compound: missing sessionId or itemId") - return jsonify({"error": "Missing sessionId or itemId"}), 400 - - # Validate session + item exists and kind is correct - full_sess = load_session_with_items(session_id) - if full_sess is None: - current_app.logger.warning(f"submit_compound: session not found: {session_id}") - return jsonify({"error": "Session not found"}), 404 - - item = next((it for it in full_sess.get("items", []) if it.get("id") == item_id), None) - if item is None: - 
current_app.logger.warning(f"submit_compound: item not found: {item_id}") - return jsonify({"error": "Item not found"}), 404 - - if item.get("kind") != "compound": - current_app.logger.warning(f"submit_compound: wrong kind={item.get('kind')}") - return jsonify({"error": "Item is not a compound"}), 400 - - t0 = time.time() - - # Set status=processing early on this item only - def mark_processing(it: dict) -> None: - """ - Update item details and mark as processing. - - :param it: the item dictionary to update - """ - it["name"] = name or it.get("name") - it["smiles"] = smiles or it.get("smiles") - _set_item_status_inplace(it, "processing") - - ok = update_item(session_id, item_id, mark_processing) - if not ok: - current_app.logger.warning(f"submit_compound: failed to mark item as processing: {item_id}") - return jsonify({"error": "Item not found during update"}), 404 - - try: - # Heavy work - generator = _setup_fingerprint_generator() - ( - tagged_smiles, - coverages, - fp_hex_strings, - linear_readout - ) = _compute_compound(generator, smiles) - - # Set final status=done and store results on this item only - def mark_done(it: dict) -> None: - it["name"] = name or it.get("name") - it["smiles"] = smiles or it.get("smiles") - it["taggedSmiles"] = tagged_smiles - it["retrofingerprints"] = [ - { - "id": get_unique_identifier(), - "retrofingerprint512": fp_hex, - "score": cov, - } - for cov, fp_hex in zip(coverages, fp_hex_strings, strict=True) - ] - it["primarySequences"] = linear_readout - _set_item_status_inplace(it, "done") - - update_item(session_id, item_id, mark_done) - - except Exception as e: - current_app.logger.exception(f"submit_compound: error for item_id={item_id}") - - def mark_error(it: dict) -> None: - _set_item_status_inplace(it, "error", error_message=str(e)) - - update_item(session_id, item_id, mark_error) - - elapsed = int((time.time() - t0) * 1000) - return jsonify({ - "ok": False, - "status": "error", - "elapsed_ms": elapsed, - "error": str(e), - 
}), 500 - - elapsed = int((time.time() - t0) * 1000) - current_app.logger.info(f"submit_compound: finished item_id={item_id} elapsed_ms={elapsed}") - - return jsonify({ - "ok": True, - "status": "done", - "elapsed_ms": elapsed, - }), 200 - - -@blp_submit_gene_cluster.post("/api/submitGeneCluster") -def submit_gene_cluster() -> tuple[dict[str, str], int]: - """ - Endpoint to submit a gene cluster for processing. - - Expected JSON body: - - sessionId: str - - itemId: str - - name: str - - fileContent: str - - :return: a tuple containing a JSON response and HTTP status code - """ - payload = request.get_json(force=True) or {} - - session_id = payload.get("sessionId") - item_id = payload.get("itemId") - name = payload.get("name") - file_content = payload.get("fileContent") - readout_level = payload.get("readoutLevel", None) - - if readout_level not in ("rec", "gene"): - current_app.logger.error(f"submit_gene_cluster: invalid readoutLevel={readout_level}") - return jsonify({"error": "Invalid readoutLevel; must be 'rec' or 'gene'"}), 400 - - current_app.logger.info(f"submit_gene_cluster called: session_id={session_id} item_id={item_id}") - - if not session_id or not item_id: - current_app.logger.warning("submit_gene_cluster: missing sessionId or itemId") - return jsonify({"error": "Missing sessionId or itemId"}), 400 - - # Validate session + item exists and kind is correct - full_sess = load_session_with_items(session_id) - if full_sess is None: - current_app.logger.warning(f"submit_gene_cluster: session not found: {session_id}") - return jsonify({"error": "Session not found"}), 404 - - item = next((it for it in full_sess.get("items", []) if it.get("id") == item_id), None) - if item is None: - current_app.logger.warning(f"submit_gene_cluster: item not found: {item_id}") - return jsonify({"error": "Item not found"}), 404 - - if item.get("kind") != "gene_cluster": - current_app.logger.warning(f"submit_gene_cluster: wrong kind={item.get('kind')}") - return jsonify({"error": 
"Item is not a gene cluster"}), 400 - - t0 = time.time() - - # Set status=processing early on this item only - def mark_processing(it: dict) -> None: - """ - Update item details and mark as processing. - - :param it: the item dictionary to update - """ - it["name"] = name or it.get("name") - it["fileContent"] = file_content or it.get("fileContent") - _set_item_status_inplace(it, "processing") - - ok = update_item(session_id, item_id, mark_processing) - if not ok: - current_app.logger.warning(f"submit_gene_cluster: failed to mark item as processing: {item_id}") - return jsonify({"error": "Item not found during update"}), 404 - - try: - # Heavy work - generator = _setup_fingerprint_generator() - scores, fp_hex_strings, readout = _compute_gene_cluster(generator, item_id, file_content, readout_level) - - # Set final status=done and store results on this item only - def mark_done(it: dict) -> None: - it["name"] = name or it.get("name") - it["fileContent"] = file_content or it.get("fileContent") - it["retrofingerprints"] = [ - { - "id": get_unique_identifier(), - "retrofingerprint512": fp_hex, - "score": score, - } - for idx, (score, fp_hex) in enumerate(zip(scores, fp_hex_strings, strict=True), start=1) - ] - it["primarySequences"] = readout - _set_item_status_inplace(it, "done") - - update_item(session_id, item_id, mark_done) - - except Exception as e: - current_app.logger.exception(f"submit_gene_cluster: error for item_id={item_id}") - - def mark_error(it: dict) -> None: - _set_item_status_inplace(it, "error", error_message=str(e)) - - update_item(session_id, item_id, mark_error) - - elapsed = int((time.time() - t0) * 1000) - return jsonify({ - "ok": False, - "status": "error", - "elapsed_ms": elapsed, - "error": str(e), - }), 500 - - elapsed = int((time.time() - t0) * 1000) - current_app.logger.info(f"submit_gene_cluster: finished item_id={item_id} elapsed_ms={elapsed}") - - return jsonify({ - "ok": True, - "status": "done", - "elapsed_ms": elapsed, - }), 200 diff 
--git a/src/server/routes/models_registry.py b/src/server/routes/models_registry.py deleted file mode 100644 index 364dda5..0000000 --- a/src/server/routes/models_registry.py +++ /dev/null @@ -1,52 +0,0 @@ -"""Module for loading and caching machine learning models used in the application.""" - -from pathlib import Path -import os - -from flask import current_app -import joblib - - -CACHE_DIR = os.environ.get("CACHE_DIR", "/app/cache") -PARAS_MODEL_PATH = os.environ.get("PARAS_MODEL_PATH", None) -_model_cache: dict[str, object | None] = {} - - -# Make sure cache directory exists -os.makedirs(CACHE_DIR, exist_ok=True) - - -def get_cache_dir() -> str: - """ - Get the cache directory path. - - :return: the cache directory path - """ - return CACHE_DIR - - -def get_paras_model() -> object | None: - """ - Load and return the PARAS model from disk, caching it in memory. - - :return: the loaded PARAS model, or None if not found - """ - # Check if model is already cached - if "paras" in _model_cache: - return _model_cache["paras"] - - # Check if model path is defined - if PARAS_MODEL_PATH: - # Model path is defined; attempt to load the model - path = Path(PARAS_MODEL_PATH) - if path.is_file(): - current_app.logger.info(f"Loading PARAS model from {path}") - _model_cache["paras"] = joblib.load(path) - else: - current_app.logger.warning(f"PARAS model not found at {path}; letting BioCracker download into {CACHE_DIR}") - _model_cache["paras"] = None - return _model_cache["paras"] - else: - # Model path is not defined - current_app.logger.warning("PARAS_MODEL_PATH not set; letting BioCracker download into CACHE_DIR") - return None diff --git a/src/server/routes/query.py b/src/server/routes/query.py deleted file mode 100644 index 0248904..0000000 --- a/src/server/routes/query.py +++ /dev/null @@ -1,196 +0,0 @@ -"""Module for defining database query endpoints.""" - -import os -import time - -import psycopg -from flask import Blueprint, request, jsonify -from psycopg import sql 
-from pgvector.psycopg import register_vector - -from routes.query_registry import QUERIES - - -blp = Blueprint("query", __name__) - - -# Knobs -DEFAULT_LIMIT = 500 -MAX_OFFSET = 50_000 -STATEMENT_TIMEOUT_MS = 3000 # 3 seconds - - -def dsn_from_env() -> str: - """ - Construct the Postgres DSN from environment variables. - - :return: the Postgres DSN string - """ - dsn = os.getenv("DATABASE_URL") - if dsn: - return dsn - host = os.getenv("DB_HOST", "db") - port = os.getenv("DB_PORT", "5432") - name = os.getenv("DB_NAME", "bionexus") - user = os.getenv("DB_USER", "app_ro") - pwd = os.getenv("DB_PASS") or os.getenv("DB_PASSWORD", "apppass_ro") - return f"postgresql://{user}:{pwd}@{host}:{port}/{name}" - - -def coerce_params(spec: dict[str, type], data: dict) -> tuple[dict, str | None]: - """ - Coerce and validate parameters from the request data according to the spec. - - :param spec: a dictionary mapping parameter names to expected types - :param data: the input data dictionary - :return: a tuple containing the coerced parameters dictionary and an error message (or None - """ - out = {} - for k, typ in spec.items(): - if k not in data: - return {}, f"Missing param: {k}" - v = data[k] - try: - if typ is float: out[k] = float(v) - elif typ is int: out[k] = int(v) - elif typ is str: out[k] = str(v) - else: out[k] = v - except Exception: - return {}, f"Invalid type for {k}" - return out, None - - -def execute_named_query( - name: str, - params: dict | None = None, - paging: dict | None = None, - order: dict | None = None, -) -> dict: - """ - Execute a predefined database query with parameters, paging, and ordering. 
- - :param name: the name of the predefined query - :param params: a dictionary of query parameters - :param paging: a dictionary with 'limit' and 'offset' for paging - :param order: a dictionary with 'column' and 'dir' for ordering - :return: a dictionary with query results - :raises ValueError: if there is a parameter validation error - :raises TimeoutError: if the query times out - :raises RuntimeError: if there is a database error - """ - params = params or {} - paging = paging or {} - order = order or {} - - if not name or name not in QUERIES: - raise ValueError("Invalid or missing query name") - - qinfo = QUERIES[name] - - # Validate required/optional params - req_spec = qinfo.get("required", {}) - opt_spec = qinfo.get("optional", {}) - typed, err = coerce_params(req_spec, params) - if err: - raise ValueError(err) - # Optional params (coerce if present) - for k, typ in opt_spec.items(): - if k in params: - try: - if typ is float: typed[k] = float(params[k]) - elif typ is int: typed[k] = int(params[k]) - elif typ is str: typed[k] = str(params[k]) - else: typed[k] = params[k] - except Exception: - raise ValueError(f"Invalid type for {k}") - - # Preprocess params for specific queries - preprocess = qinfo.get("preprocess_params") - if preprocess: - try: - typed = preprocess(typed) or typed - except Exception as e: - raise ValueError(f"Parameter preprocessing error: {str(e)}") - - # Paging - limit = int(paging.get("limit", DEFAULT_LIMIT)) - offset = int(paging.get("offset", 0)) - offset = max(0, min(MAX_OFFSET, offset)) - typed["limit"] = limit - typed["offset"] = offset - - # Order-by (whitelisted) - allowed_cols = qinfo.get("allowed_order_cols", set()) - order_col = (order.get("column") if isinstance(order, dict) else None) or qinfo.get("default_order_col") - if order_col not in allowed_cols: - order_col = qinfo.get("default_order_col") - order_dir = (order.get("dir", qinfo.get("default_order_dir", "ASC")) if isinstance(order, dict) else 
qinfo.get("default_order_dir", "ASC")) - order_dir = "DESC" if str(order_dir).upper().startswith("D") else "ASC" - - # Render final SQL safely for the ORDER BY identifier - base_sql = qinfo["sql"] - rendered = ( - base_sql.format( - order_col=sql.Identifier(order_col).as_string(psycopg.connect(dsn_from_env())), - order_dir=order_dir, - ).rstrip().rstrip(";") - + f" LIMIT %(limit)s OFFSET %(offset)s" - ) - - # Exec (read-only, short timeout, public schema) - dsn = dsn_from_env() - t0 = time.time() - try: - with psycopg.connect( - dsn, - options=f"-c statement_timeout={STATEMENT_TIMEOUT_MS} " - f"-c idle_in_transaction_session_timeout={STATEMENT_TIMEOUT_MS} " - f"-c search_path=public", - ) as conn: - register_vector(conn) - with conn.cursor() as cur: - cur.execute(rendered, typed) - rows = cur.fetchall() - cols = [d.name for d in cur.description] - except psycopg.errors.QueryCanceled: - raise TimeoutError(f"Query timeout (>{STATEMENT_TIMEOUT_MS} ms)") - except Exception as e: - raise RuntimeError(f"Database error: {str(e)}") - - elapsed = int((time.time() - t0) * 1000) - out_rows = [dict(zip(cols, r)) for r in rows] - - return { - "name": name, - "columns": cols, - "rows": out_rows, - "rowCount": len(out_rows), - "limit": limit, - "offset": offset, - "elapsed_ms": elapsed, - } - - -@blp.post("/api/query") -def run_query(): - """ - Run a predefined database query with parameters, paging, and ordering. 
- - :return: JSON response with query results or error message - """ - payload = request.get_json(force=True) or {} - name = payload.get("name") - params = payload.get("params", {}) - paging = payload.get("paging", {}) - order = payload.get("order", {}) - - try: - result = execute_named_query(name, params=params, paging=paging, order=order) - except ValueError as e: - return jsonify({"error": str(e)}), 400 - except TimeoutError as e: - return jsonify({"error": str(e)}), 408 - except RuntimeError as e: - return jsonify({"error": str(e)}), 400 - - return jsonify(result), 200 diff --git a/src/server/routes/query_registry.py b/src/server/routes/query_registry.py deleted file mode 100644 index 111bbe9..0000000 --- a/src/server/routes/query_registry.py +++ /dev/null @@ -1,217 +0,0 @@ -"""Module defining the available SQL queries for the query registry.""" - - -from pgvector import Vector - -from routes.helpers import hex_to_bits - - -def preprocess_cross_modal_params(typed: dict) -> dict: - """ - Preprocess parameters for the cross-modal retrieval query. 
- - :param typed: the typed parameters dictionary - :return: the preprocessed parameters dictionary - """ - fp_hex_string = typed["retrofingerprint512"] - fp = hex_to_bits(fp_hex_string) - fp = [float(x) for x in fp] - typed["qv"] = Vector(fp) - - query_settings = typed.get("querySettings", {}) - similarity_threshold = query_settings.get("similarityThreshold", 0.0) - typed["similarity_threshold"] = similarity_threshold - - search_space = query_settings.get("searchSpace", "only_compounds") - typed["search_space"] = search_space - - return typed - - -QUERIES = { - "binned_coverage": { - "sql": """ - SELECT - width_bucket(LEAST(coverage, 1 - 1e-12), 0.0, 1.0, 20) AS bin_id, - 0.05 * (width_bucket(LEAST(coverage, 1 - 1e-12), 0.0, 1.0, 20) - 1) AS bin_start, - 0.05 * width_bucket(LEAST(coverage, 1 - 1e-12), 0.0, 1.0, 20) AS bin_end, - COUNT(*) AS count - FROM retromol_compound - WHERE coverage IS NOT NULL - GROUP BY bin_id - ORDER BY {order_col} {order_dir} - """, - "allowed_order_cols": {"bin_id", "bin_start", "bin_end", "count"}, - "default_order_col": "bin_start", - "default_order_dir": "ASC", - "required": {}, - "optional": {}, - }, - "fingerprint_source_counts": { - "sql": """ - SELECT cpr.source, COUNT(*) AS count_per_source - FROM retrofingerprint AS rfp - JOIN retromol_compound AS rcp ON rfp.retromol_compound_id = rcp.id - JOIN compound_record AS cpr ON rfp.retromol_compound_id = cpr.compound_id - WHERE rcp.coverage >= 0.95 - GROUP BY cpr.source - ORDER BY {order_col} {order_dir} - """, - "allowed_order_cols": {"source", "count_per_source"}, - "default_order_col": "count_per_source", - "default_order_dir": "DESC", - "required": {}, - "optional": {}, - }, - "search_compound_by_name": { - "sql": """ - SELECT cpr.name, MIN(cp.smiles) as smiles - FROM compound as cp - JOIN compound_record as cpr ON cp.id = cpr.compound_id - WHERE cpr.name ILIKE %(q)s || '%%' - GROUP BY cpr.name - ORDER BY {order_col} {order_dir} - """, - "allowed_order_cols": {"name", "smiles"}, - 
"default_order_col": "name", - "default_order_dir": "ASC", - "required": {"q": str}, - "optional": {}, - }, - "cross_modal_retrieval": { - "sql": """ - SELECT - rf.id AS identifier, - CASE - WHEN rf.retromol_compound_id IS NOT NULL AND rf.biocracker_genbank_id IS NULL THEN 'compound' - WHEN rf.biocracker_genbank_id IS NOT NULL AND rf.retromol_compound_id IS NULL THEN 'gene_cluster' - ELSE 'unknown' - END AS type, - COALESCE(cr.source, gr.source) AS source, - COALESCE(cr.ext_id, gr.ext_id) AS ext_id, - COALESCE(cr.name, CONCAT('Region ', gr.ext_id::text)) AS name, - (1.0 - (rf.fp_retro_b512_vec_binary <=> %(qv)s)) AS score - FROM retrofingerprint AS rf - LEFT JOIN retromol_compound rmc ON rmc.id = rf.retromol_compound_id - LEFT JOIN compound c ON c.id = rmc.compound_id - LEFT join compound_record cr ON cr.compound_id = c.id - LEFT JOIN biocracker_genbank bg ON bg.id = rf.biocracker_genbank_id - LEFT JOIN genbank_region gr ON gr.id = bg.genbank_region_id - WHERE vector_norm(rf.fp_retro_b512_vec_binary) > 0 - AND vector_norm(%(qv)s) > 0 - AND (1.0 - (rf.fp_retro_b512_vec_binary <=> %(qv)s)) >= %(similarity_threshold)s - AND ( - %(search_space)s = 'both' - OR (%(search_space)s = 'only_compounds' AND rf.retromol_compound_id IS NOT NULL AND rf.biocracker_genbank_id IS NULL) - OR (%(search_space)s = 'only_gene_clusters' AND rf.biocracker_genbank_id IS NOT NULL AND rf.retromol_compound_id IS NULL) - ) - ORDER BY {order_col} {order_dir} - """, - "allowed_order_cols": {"identifier", "name", "source", "ext_id", "score"}, - "default_order_col": "score", - "default_order_dir": "DESC", - "required": { "retrofingerprint512": str, "querySettings": dict }, - "optional": {}, - "preprocess_params": preprocess_cross_modal_params, - }, - "compound_info_by_id": { - "sql": """ - SELECT - cr.name, - c.smiles - FROM retrofingerprint AS rfp - JOIN retromol_compound AS rmc - ON rfp.retromol_compound_id = rmc.id - JOIN compound AS c - ON rmc.compound_id = c.id - JOIN compound_record AS cr - 
ON c.id = cr.compound_id - WHERE rfp.id = %(compound_id)s; - """, - "allowed_order_cols": set(), - "default_order_col": "", - "default_order_dir": "ASC", - "required": { "compound_id": int }, - "optional": {}, - }, - "annotation_counts_full": { - "sql": """ - SELECT - scheme, - key, - value, - COUNT(*) AS annotation_count, - COUNT(DISTINCT compound_id) AS n_compounds, - COUNT(DISTINCT genbank_region_id) AS n_genbank_regions - FROM annotation - GROUP BY scheme, key, value - ORDER BY {order_col} {order_dir} - """, - "allowed_order_cols": { - "scheme", "key", "value", - "annotation_count", "n_compounds", "n_genbank_regions" - }, - "default_order_col": "annotation_count", - "default_order_dir": "DESC", - "required": {}, - "optional": {}, - }, - "annotation_counts_subset": { - "sql": """ - SELECT - scheme, - key, - value, - COUNT(*) AS annotation_count, - COUNT(DISTINCT compound_id) AS n_compounds, - COUNT(DISTINCT genbank_region_id) AS n_genbank_regions - FROM annotation - WHERE - compound_id = ANY(%(compound_ids)s::bigint[]) - OR genbank_region_id = ANY(%(genbank_region_ids)s::bigint[]) - GROUP BY scheme, key, value - ORDER BY {order_col} {order_dir} - """, - "allowed_order_cols": { - "scheme", "key", "value", - "annotation_count", "n_compounds", "n_genbank_regions" - }, - "default_order_col": "annotation_count", - "default_order_dir": "DESC", - "required": { - "compound_ids": list, - "genbank_region_ids": list, - }, - "optional": {}, - }, - "retrieve_items_by_fingerprint_ids": { - "sql": """ - SELECT DISTINCT - rmc.compound_id, - bkg.genbank_region_id - FROM retrofingerprint AS rf - LEFT JOIN retromol_compound AS rmc - ON rf.retromol_compound_id = rmc.id - LEFT JOIN biocracker_genbank AS bkg - ON rf.biocracker_genbank_id = bkg.id - WHERE rf.id = ANY(%(rf_ids)s::bigint[]); - """, - "allowed_order_cols": set(), - "default_order_col": "", - "default_order_dir": "ASC", - "required": { "rf_ids": list }, - "optional": {}, - }, - "target_counts": { - "sql": """ - SELECT - 
(SELECT COUNT(DISTINCT compound_id) FROM annotation WHERE compound_id IS NOT NULL) AS n_compounds, - (SELECT COUNT(DISTINCT genbank_region_id) FROM annotation WHERE genbank_region_id IS NOT NULL) AS n_genbank_regions - """, - "allowed_order_cols": set(), - "default_order_col": "", - "default_order_dir": "ASC", - "required": {}, - "optional": {}, - }, -} \ No newline at end of file diff --git a/src/server/routes/session.py b/src/server/routes/session.py index 0c76e2b..6248650 100644 --- a/src/server/routes/session.py +++ b/src/server/routes/session.py @@ -9,6 +9,7 @@ load_session_with_items, merge_session_from_client, count_sessions, + delete_item as redis_delete_item, ) @@ -16,6 +17,7 @@ blp_delete_session = Blueprint("delete_session", __name__) blp_get_session = Blueprint("get_session", __name__) blp_save_session = Blueprint("save_session", __name__) +blp_delete_item = Blueprint("delete_item", __name__) @blp_create_session.post("/api/createSession") @@ -131,3 +133,27 @@ def save_session() -> tuple[dict[str, str], int]: session_id = new_session.get("sessionId") return jsonify({"sessionId": session_id}), 200 + + +@blp_delete_item.post("/api/deleteSessionItem") +def delete_item() -> tuple[dict[str, str], int]: + """ + Delete a single item from a session. + + :return: a tuple containing a dictionary with the operation status and an HTTP status code. 
+ """ + payload = request.get_json(force=True) or {} + session_id = payload.get("sessionId") + item_id = payload.get("itemId") + + if not isinstance(session_id, str) or not session_id: + return {"error": "Missing or invalid sessionId"}, 400 + + if not isinstance(item_id, str) or not item_id: + return {"error": "Missing or invalid itemId"}, 400 + + ok = redis_delete_item(session_id, item_id) + if not ok: + return {"error": "Session or item not found"}, 404 + + return jsonify({"ok": True}), 200 diff --git a/src/server/routes/session_store.py b/src/server/routes/session_store.py index feab83e..e1a1957 100644 --- a/src/server/routes/session_store.py +++ b/src/server/routes/session_store.py @@ -21,6 +21,19 @@ EVENTS_CHANNEL_PREFIX = "session_events:" +# Fields that are owned by the server and should not be overwritten by client data +# Item 'name' and 'updatedAt' are client-editable +SERVER_OWNED_FIELDS = { + "name", + "score", + "payload", + "status", + "errorMessage", + "smiles", + "fileContent", +} + + def _event_channel(session_id: str) -> str: """ Get the Redis Pub/Sub channel name for session events. @@ -343,21 +356,47 @@ def update_item(session_id: str, item_id: str, mutator: Callable[[dict[str, Any] }) return True + + +def delete_item(session_id: str, item_id: str) -> None: + """ + Delete a single item from a session (both item blob and its id in session meta). + + :return: True if deleted, False if session/item not found + .. 
note:: publishes a session_merged event so clients refresh via SSE + """ + meta = load_session_meta(session_id) + if meta is None: + return False + item_ids = meta.get("items", []) or [] + if not isinstance(item_ids, list): + item_ids = [] -# Fields that are owned by the server and should not be overwritten by client data -SERVER_OWNED_FIELDS = { - "status", - "errorMessage", - "retrofingerprints", - "retrofingerprint512", - "morganfingerprint2048r2", - "primarySequences", - "smiles", - "taggedSmiles", - "coverage", - "updatedAt", -} + if item_id not in item_ids: + return False + + # Remove id from meta list + item_ids = [x for x in item_ids if x != item_id] + meta["items"] = item_ids + + # Delete item blob + redis_client.delete(_item_key(session_id, item_id)) + + # Save updated session meta + redis_client.set( + _session_key(session_id), + json.dumps(meta), + ex=SESSION_TTL_SECONDS, + ) + + # Tell SSE clients to refresh + publish_session_event(session_id, { + "type": "session_merged", + "deletedItemId": item_id, + }) + + return True def merge_session_from_client(new_session: dict[str, Any]) -> None: diff --git a/src/server/routes/views.py b/src/server/routes/views.py deleted file mode 100644 index 40d38a1..0000000 --- a/src/server/routes/views.py +++ /dev/null @@ -1,624 +0,0 @@ -"""Module for handling view requests.""" - -import math -import re -import time - -import numpy as np -import umap -from sklearn.decomposition import PCA -from flask import Blueprint, current_app, request, jsonify - -from versalign.aligner import setup_aligner -from versalign.msa import calc_msa -from versalign.printing import format_alignment -from versalign.scoring import create_substituion_matrix_dynamically - -from retromol.chem import calc_tanimoto_similarity - -from routes.helpers import hex_to_bits, get_unique_identifier -from routes.query import execute_named_query - - -blp_get_embedding_space = Blueprint("get_embedding_space", __name__) -blp_enrich = Blueprint("enrich", __name__) 
-blp_run_msa = Blueprint("run_msa", __name__) - - -def _log_hypergeom_probability(a: int, b: int, c: int, d: int) -> float: - """ - Compute log probability of a 2x2 table under the hypergeometric model. - - :param a: count in cell (1,1) - :param b: count in cell (1,2) - :param c: count in cell (2,1) - :param d: count in cell (2,2) - :return: log probability - """ - total = a + b + c + d - return ( - math.lgamma(a + b + 1) - - math.lgamma(a + 1) - - math.lgamma(b + 1) - + math.lgamma(c + d + 1) - - math.lgamma(c + 1) - - math.lgamma(d + 1) - - math.lgamma(total + 1) - + math.lgamma(a + c + 1) - + math.lgamma(b + d + 1) - ) - - -def _fisher_exact_two_sided(a: int, b: int, c: int, d: int) -> float: - """ - Return two-sided Fisher's exact test p-value for a 2x2 table. - - :param a: count in cell (1,1) - :param b: count in cell (1,2) - :param c: count in cell (2,1) - :param d: count in cell (2,2) - :return: two-sided p-value - """ - if min(a, b, c, d) < 0: - raise ValueError("Fisher's exact test counts must be non-negative") - - r1 = a + b - r2 = c + d - c1 = a + c - - min_a = max(0, c1 - r2) - max_a = min(r1, c1) - - obs_log_prob = _log_hypergeom_probability(a, b, c, d) - p_sum = 0.0 - for x in range(min_a, max_a + 1): - y = r1 - x - z = c1 - x - w = r2 - z - if y < 0 or z < 0 or w < 0: - continue - log_prob = _log_hypergeom_probability(x, y, z, w) - if log_prob <= obs_log_prob + 1e-12: - p_sum += math.exp(log_prob) - - return min(p_sum, 1.0) - - -@blp_get_embedding_space.post("/api/getEmbeddingSpace") -def get_embedding_space() -> tuple[dict[str, str], int]: - """ - Handle POST requests to retrieve embedding space information. 
- - :return: a tuple containing an empty dictionary and HTTP status code 200 - """ - payload = request.get_json(force=True) or {} - - session_id = payload.get("sessionId") - items = payload.get("items", []) - method = (payload.get("method") or "umap").lower() - - current_app.logger.info(f"get_embedding_space called: session_id={session_id} items_count={len(items)}") - - if not session_id or not items: - current_app.logger.warning("get_embedding_space: missing sessionId or items") - return jsonify({"error": "Missing sessionId or items"}), 400 - - # Filter out any item that does not have required fields for item - required_fields_item = {"id", "kind", "retrofingerprints"} - items = [item for item in items if required_fields_item.issubset(item.keys())] - - # Filter out any item that does not have required fields for fingerprint item - required_fields_fp = {"id", "retrofingerprint512", "score"} - for item in items: - item["retrofingerprints"] = [fp_item for fp_item in item["retrofingerprints"] if required_fields_fp.issubset(fp_item.keys())] - - t0 = time.time() - - # Gather all "kind" types; if both "compound" and "gene_cluster" are present, set reduce_fp to True - reduce_fp = False - kinds = set(item["kind"] for item in items) - if "compound" in kinds and "gene_cluster" in kinds: - reduce_fp = True - - try: - # Decode retrofingerprints - kinds, parent_ids, child_ids, fps = [], [], [], [] - for item in items: - for fp_item in item["retrofingerprints"]: - kinds.append(item["kind"]) - parent_ids.append(item["id"]) - child_ids.append(fp_item["id"]) - fps.append(hex_to_bits(fp_item["retrofingerprint512"])) - - # Handle case with no retrofingerprints - if len(fps) == 0: - # Return empty points - points = [] - - else: - # Convert to numpy array - fps = np.array(fps) - - # Reduce dimensionality if needed - if reduce_fp: - # Remove every bit that is not set in "gene_cluster" retrofingerprints - gene_cluster_fps = fps[[i for i, kind in enumerate(kinds) if kind == 
"gene_cluster"]] - bits_to_keep = np.any(gene_cluster_fps, axis=0) - fps = fps[:, bits_to_keep] - - # Reduce dimensionality using selected method - n_samples = fps.shape[0] - if n_samples == 1: - # Single point: put it at the origin (jitter will be applied later so might not be exactly at origin) - reduced = np.zeros((1, 2)) - elif n_samples <= 3: - # UMAP's spectral step is fragile for very small N - # Put points on unit circle evenly spaced - angles = np.linspace(0, 2 * np.pi, n_samples, endpoint=False) - reduced = np.stack([np.cos(angles), np.sin(angles)], axis=1) - else: - if method == "pca": - # PCA - pca = PCA(n_components=2, random_state=42) - reduced = pca.fit_transform(fps) - else: - # Default to UMAP - n_neighbors = min(15, n_samples - 1) - reducer = umap.UMAP( - n_components=2, - n_neighbors=n_neighbors, - random_state=42, - metric="cosine" - ) - reduced = reducer.fit_transform(fps) - - points = [ - { - "parent_id": parent_id, - "child_id": child_id, - "kind": kind, - "x": float(reduced[i, 0]), - "y": float(reduced[i, 1]), - } for i, (kind, parent_id, child_id) in enumerate(zip(kinds, parent_ids, child_ids)) - ] - except Exception as e: - current_app.logger.error(f"get_embedding_space: error processing items: {e}") - return jsonify({"error": "Error processing items"}), 500 - - elapsed = int((time.time() - t0) * 1000) - current_app.logger.info(f"get_embedding_space: finished session_id={session_id} elapsed_ms={elapsed}") - - return jsonify({ - "ok": True, - "status": "done", - "elapsed_ms": elapsed, - "points": points - }), 200 - - -@blp_enrich.post("/api/enrich") -def run_enrichment() -> tuple[dict[str, str], int]: - """ - Handle POST requests to run enrichment analysis. 
- - :return: a tuple containing an empty dictionary and HTTP status code 200 - """ - payload = request.get_json(force=True) or {} - - fp_hex_string = payload.get("retrofingerprint512") - query_settings = payload.get("querySettings", {}) - search_space = query_settings.get("searchSpace", "only_compounds") - - # Guard against missing fingerprint - if not fp_hex_string: - return jsonify({"error": "Missing retrofingerprint512"}), 400 - - t0 = time.time() - - result = execute_named_query( - name="cross_modal_retrieval", - params={ - "retrofingerprint512": fp_hex_string, - "querySettings": query_settings, - }, - paging={}, - order={}, - ) - - # If num_rows is equal to max_limit_in_group, we know there are more results - # Throw error and ask user to up the score threshold - max_limit_in_group = 1000 - num_rows = len(result["rows"]) - if num_rows >= max_limit_in_group: - return jsonify({ - "error": f"Too many items in in-group (>={max_limit_in_group}). Please increase the score threshold in query settings and try again." 
- }), 400 - - # Get unique IDs - readout_ids = set([item["identifier"] for item in result["rows"]]) - - # Map those to compounds and genbank regions - item_ids = execute_named_query( - name="retrieve_items_by_fingerprint_ids", - params={"rf_ids": list(readout_ids)}, - paging={ "limit": 1_000_000_000 }, # use high limit otherwise default limit of 1000 applies - order={}, - ) - compound_ids: set[int] = set() - genbank_region_ids: set[int] = set() - for row in item_ids["rows"]: - compound_id = row["compound_id"] - genbank_region_id = row["genbank_region_id"] - if compound_id and not genbank_region_id: compound_ids.add(compound_id) - elif genbank_region_id and not compound_id: genbank_region_ids.add(genbank_region_id) - # Ignore rows that have both or neither - - # Filter for search space - if search_space == "only_compounds": - genbank_region_ids = set() - elif search_space == "only_gene_clusters": - compound_ids = set() - - # If somehow no targets, return empty result - if not compound_ids and not genbank_region_ids: - return jsonify({ - "ok": True, - "status": "done", - "elapsed_ms": int((time.time() - t0) * 1000), - "result": { - "querySettings": query_settings, - "items": [], - } - }), 200 - - subset_total_targets = len(compound_ids) + len(genbank_region_ids) - - # Get all annotation counts - ann_full = execute_named_query( - name="annotation_counts_full", - params={}, - paging={ "limit": 1_000_000_000 }, # use high limit otherwise default limit of 1000 applies - order={}, - ) - - # Get annotation counts for subset - ann_subset = execute_named_query( - name="annotation_counts_subset", - params={ - "compound_ids": list(compound_ids), - "genbank_region_ids": list(genbank_region_ids), - }, - paging={ "limit": 1_000_000_000 }, # use high limit otherwise default limit of 1000 applies - order={}, - ) - - # Total number of targets in universe (all compounds + all genbank regions) - bg_counts = execute_named_query( - name="target_counts", - params={}, - paging={}, - 
order={}, - ) - bg_row = bg_counts["rows"][0] - - if search_space == "only_compounds": - background_total_targets = int(bg_row["n_compounds"]) - elif search_space == "only_gene_clusters": - background_total_targets = int(bg_row["n_genbank_regions"]) - else: - background_total_targets = int(bg_row["n_compounds"]) + int(bg_row["n_genbank_regions"]) - - # Do statistical enrichment analysis (Fisher's exact test) - full_rows = ann_full.get("rows", []) - subset_rows = ann_subset.get("rows", []) - enrichment_candidates: list[dict] = [] - - if ( - subset_rows - and full_rows - and subset_total_targets > 0 - and background_total_targets > subset_total_targets - ): - def _ann_key(row: dict) -> tuple: - return (row.get("scheme"), row.get("key"), row.get("value")) - - full_lookup = {_ann_key(row): row for row in full_rows} - - for row in subset_rows: - key = _ann_key(row) - base_row = full_lookup.get(key) - if not base_row: - continue - - # Determine subset counts based on search space - row_compounds = int(row.get("n_compounds", 0)) - row_regions = int(row.get("n_genbank_regions", 0)) - if search_space == "only_compounds": - subset_with = row_compounds - elif search_space == "only_gene_clusters": - subset_with = row_regions - else: - subset_with = row_compounds + row_regions - if subset_with <= 0: - continue - - # Determine background counts based on search space - base_compounds = int(base_row.get("n_compounds", 0)) - base_regions = int(base_row.get("n_genbank_regions", 0)) - if search_space == "only_compounds": - background_with = base_compounds - elif search_space == "only_gene_clusters": - background_with = base_regions - else: - background_with = base_compounds + base_regions - if background_with <= 0: - continue - - # 2x2 per target: - # a = subset tarets WITH this annotation - # b = subset targets WITHOUT this annotation - # c = background-only targets WITH this annotation - # d = background-only targets WITHOUT this annotation - a = subset_with - b = 
subset_total_targets - a - - background_only_total = background_total_targets - subset_total_targets - c = background_with - a - d = background_only_total - c - - if min(a, b, c, d) < 0: - continue - - p_value = _fisher_exact_two_sided(a, b, c, d) - - enrichment_candidates.append({ - "id": f"{row['scheme']}::{row['key']}::{row['value']}", - "schema": row["scheme"], - "key": row["key"], - "value": row["value"], - "subset_count": a, - "background_count": background_with, - "p_value": p_value, - }) - - # Multiple hypothesis correction (Benjamini-Hochberg) - # Need to do this to take into account the number of tests performed) - candidate_count = len(enrichment_candidates) - if candidate_count > 0: - # Sort by raw p-value - sorted_indices = sorted( - range(candidate_count), - key=lambda idx: enrichment_candidates[idx]["p_value"] - ) - - # Calculate adjusted p-values using Benjamini-Hochberg procedure - cumulative_min = 1.0 - adjusted_values = [1.0] * candidate_count - for order_idx in range(candidate_count - 1, -1, -1): - candidate_idx = sorted_indices[order_idx] - rank = order_idx + 1 - raw_p = enrichment_candidates[candidate_idx]["p_value"] - adjusted = (raw_p * candidate_count) / rank - if adjusted < cumulative_min: - cumulative_min = adjusted - adjusted_values[candidate_idx] = cumulative_min if cumulative_min < 1.0 else 1.0 - for idx, adj_p in enumerate(adjusted_values): - enrichment_candidates[idx]["adjusted_p_value"] = adj_p - - elapsed = int((time.time() - t0) * 1000) - - # Sort results by adjusted p-value and p-value - items = [] - if enrichment_candidates: - items = sorted( - ( - { - "id": candidate["id"], - "schema": candidate["schema"], - "key": candidate["key"], - "value": candidate["value"], - "p_value": candidate["p_value"], - "adjusted_p_value": candidate.get("adjusted_p_value", candidate["p_value"]), - } - for candidate in enrichment_candidates - ), - key=lambda entry: (entry["adjusted_p_value"], entry["p_value"]), - ) - - return jsonify({ - "ok": True, 
- "status": "done", - "elapsed_ms": elapsed, - "result": { - "querySettings": query_settings, - "items": items, - } - }), 200 - - -def motif_compare(a: dict, b: dict) -> float: - """ - Score function for sequence alignment. - - :param a: first item - :param b: second item - :return: similarity score in [0, 1] - """ - max_score = 1.0 - - def _base_unit(name: str | None) -> str | None: - if name is None: - return None - stripped = str(name).strip() - if not stripped: - return None - # Remove trailing digits (e.g., "A1" -> "A") - return re.sub(r"\d+$", "", stripped).upper() - - def _unit_similarity(unit_a: str | None, unit_b: str | None) -> float | None: - if unit_a is None or unit_b is None: - return None - if unit_a == unit_b: - return max_score - high_pairs = {("A", "B"), ("B", "A"), ("B", "C"), ("C", "B"), ("C", "D"), ("D", "C")} - if (unit_a, unit_b) in high_pairs: - return (2.0 / 3.0) * max_score - mid_pairs = {("A", "C"), ("C", "A"), ("B", "D"), ("D", "B")} - if (unit_a, unit_b) in mid_pairs: - return (1.0 / 3.0) * max_score - return None - - # Handle string inputs (e.g., gap "-") first to avoid attribute errors - if isinstance(a, str) or isinstance(b, str): - return 1.0 if a == b else 0.0 - - if not (isinstance(a, dict) and isinstance(b, dict)): - return 0.0 - - unit_a = _base_unit(a.get("name")) - unit_b = _base_unit(b.get("name")) - unit_score = _unit_similarity(unit_a, unit_b) - if unit_score is not None: - return unit_score - - fp_a_hex = a.get("morganfingerprint2048r2") - fp_b_hex = b.get("morganfingerprint2048r2") - if fp_a_hex and fp_b_hex: - try: - fp_a = hex_to_bits(fp_a_hex, n_bits=2048) - fp_b = hex_to_bits(fp_b_hex, n_bits=2048) - except Exception: - return 0.0 - tanimoto = float(calc_tanimoto_similarity(fp_a, fp_b)) - return max(0.0, min(max_score, tanimoto * max_score)) - - return 0.0 - - -def label_motif (r: dict | str) -> str: - """ - Label function for sequence alignment. 
- - :param r: item - :return: name of the item - """ - if isinstance(r, str): # for gap as "-" - return r - return r.get("name") - - -@blp_run_msa.post("/api/runMsa") -def run_msa() -> tuple[dict[str, str], int]: - """ - Handle POST requests to run multiple sequence alignment (MSA). - - :return: a tuple containing an empty dictionary and HTTP status code 200 - """ - payload = request.get_json(force=True) or {} - - primary_sequences = payload.get("primarySequences", []) - center_id = payload.get("centerId", None) - settings = payload.get("msaSettings", {}) - - # Determine index of center sequence if provided - center_id_index = None - if center_id is not None: - for i, seq in enumerate(primary_sequences): - if seq.get("id") == center_id: - center_id_index = i - break - - current_app.logger.info(f"run_msa called: primary_sequences_count={len(primary_sequences)} center_id_index={center_id_index}") - - if not primary_sequences: - current_app.logger.warning("run_msa: missing primarySequences") - return jsonify({"error": "Missing primarySequences"}), 400 - - t0 = time.time() - - try: - gap_symbol = "-" # representation of gap in alignment - - # Remove current gaps from sequences; in-place modification - for seq in primary_sequences: - seq["sequence"] = [x for x in seq["sequence"] if not x["id"].startswith("pad-")] - - # Gather unique set of motifs - curr_motif_names = set() - motifs = [gap_symbol] # gap symbol - motif_by_name: dict[str | None, dict] = {} - for seq in primary_sequences: - for motif in seq["sequence"]: - name = motif.get("name") if isinstance(motif, dict) else motif - if name not in curr_motif_names: - curr_motif_names.add(name) - motifs.append(motif) - # Keep richest motif seen for this name - if isinstance(motif, dict): - prev = motif_by_name.get(name) - if prev is None or (prev.get("morganfingerprint2048r2") is None and motif.get("morganfingerprint2048r2") is not None): - motif_by_name[name] = motif - - # Construct substitution matrix - sm, _ = 
create_substituion_matrix_dynamically(motifs, compare=motif_compare, label_fn=label_motif) - - # Setup aligner - alignment_type = settings.get("alignmentType", "global").lower() - aligner = setup_aligner(sm, alignment_type, label_fn=label_motif) - - # Multiple sequence alignment - seqs = [seq["sequence"] for seq in primary_sequences] - lbls = [seq["id"] for seq in primary_sequences] - msa, order = calc_msa(aligner, seqs, gap_repr=gap_symbol, center_star=center_id_index) - - # Replace old sequences with aligned sequences; in-place modification - aligned_sequences = [] - for i, aligned_seq in zip(order, msa): - seq = primary_sequences[i] - seq["sequence"] = [] - def _normalize_motif(raw_motif: dict | str, name: str | None) -> dict: - if isinstance(raw_motif, dict): - base = raw_motif.copy() - else: - base = motif_by_name.get(name, {}) if name is not None else {} - return { - **base, - "id": base.get("id") or get_unique_identifier(), - "name": name, - "displayName": base.get("displayName"), - "tags": base.get("tags", []), - "smiles": base.get("smiles"), - "morganfingerprint2048r2": base.get("morganfingerprint2048r2"), - } - - for motif in aligned_seq: - if motif == gap_symbol: - seq["sequence"].append({ - "id": f"pad-{get_unique_identifier()}", - "name": None, - "displayName": None, - "tags": [], - "smiles": None, - "morganfingerprint2048r2": None, - }) - else: - name = motif.get("name") if isinstance(motif, dict) else motif - seq["sequence"].append(_normalize_motif(motif, name)) - aligned_sequences.append(seq) - - except Exception as e: - current_app.logger.error(f"run_msa: error preparing substitution matrix: {e}") - return jsonify({"error": "Error preparing substitution matrix"}), 500 - - elapsed = int((time.time() - t0) * 1000) - - current_app.logger.info(f"run_msa: finished primary_sequences_count={len(primary_sequences)} elapsed_ms={elapsed} center_id_index={center_id_index}") - - return jsonify({ - "ok": True, - "status": "done", - "elapsed_ms": elapsed, - 
"result": { - "alignedSequences": aligned_sequences, - } - }), 200 From 5fdcfe0853592d7d3384a79095edbd46e6d95272 Mon Sep 17 00:00:00 2001 From: David Meijer Date: Tue, 13 Jan 2026 00:06:46 +0100 Subject: [PATCH 11/34] WIP: refactoring MSA --- .../workspace/NotificationDrawer.tsx | 149 +++--- .../tabs/discovery/QueryResultView.tsx | 441 ++++++++++++++++ .../workspace/tabs/discovery/SortableItem.tsx | 37 ++ .../workspace/tabs/discovery/SortableRow.tsx | 185 +++++++ .../tabs/discovery/WorkspaceDiscovery.tsx | 172 +++++- .../tabs/upload/DialogImportCompound.tsx | 14 +- src/server/app.py | 2 + src/server/helpers/__init__.py | 0 src/server/helpers/guid.py | 12 + src/server/helpers/ncbi.py | 97 ++++ src/server/routes/compound.py | 9 +- src/server/routes/query.py | 494 ++++++++++++++++++ 12 files changed, 1517 insertions(+), 95 deletions(-) create mode 100644 src/client/src/components/workspace/tabs/discovery/QueryResultView.tsx create mode 100644 src/client/src/components/workspace/tabs/discovery/SortableItem.tsx create mode 100644 src/client/src/components/workspace/tabs/discovery/SortableRow.tsx create mode 100644 src/server/helpers/__init__.py create mode 100644 src/server/helpers/guid.py create mode 100644 src/server/helpers/ncbi.py create mode 100644 src/server/routes/query.py diff --git a/src/client/src/components/workspace/NotificationDrawer.tsx b/src/client/src/components/workspace/NotificationDrawer.tsx index 631ad0d..4ddab11 100644 --- a/src/client/src/components/workspace/NotificationDrawer.tsx +++ b/src/client/src/components/workspace/NotificationDrawer.tsx @@ -55,86 +55,85 @@ export const NotificationDrawer: React.FC = ({ open, ha display: "flex", }} > - - - - Notifications - - - - - - - {notifications.length === 0 ? ( - - No notifications to show. 
- - ) : ( - + - {notifications.slice(0).reverse().map((notification) => ( - Notifications + + - - {notification.content} - - } - secondaryTypographyProps={{ component: "div" }} // render wrapper as div instead of p - /> - - ))} - - )} - + /> + + + + + {notifications.length === 0 ? ( + + No notifications to show. + + ) : ( + + {notifications.slice(0).reverse().map((notification) => ( + + + {notification.content} + + } + secondaryTypographyProps={{ component: "div" }} // render wrapper as div instead of p + /> + + ))} + + )} + ) diff --git a/src/client/src/components/workspace/tabs/discovery/QueryResultView.tsx b/src/client/src/components/workspace/tabs/discovery/QueryResultView.tsx new file mode 100644 index 0000000..22ca9b9 --- /dev/null +++ b/src/client/src/components/workspace/tabs/discovery/QueryResultView.tsx @@ -0,0 +1,441 @@ +import React from "react"; +import Typography from "@mui/material/Typography"; +import Box from "@mui/material/Box"; +import Chip from "@mui/material/Chip"; +import Stack from "@mui/material/Stack"; +import Tooltip from "@mui/material/Tooltip"; +import ZoomOutIcon from "@mui/icons-material/ZoomOut"; +import ZoomInIcon from "@mui/icons-material/ZoomIn"; +import RefreshIcon from "@mui/icons-material/Refresh"; +import DownloadIcon from "@mui/icons-material/Download"; +import { SortableRow } from "./SortableRow"; +import { SortableItem } from "./SortableItem"; + +// Imports for dragging and dropping rows and motifs +import { DndContext, DragEndEvent } from "@dnd-kit/core"; +import { + SortableContext, + arrayMove, + verticalListSortingStrategy, + horizontalListSortingStrategy, +} from "@dnd-kit/sortable"; + +export type SequenceItem = { + id: string; + isGap: boolean; + name: string | null; + smiles: string | null; +}; + +export type Reference = { + name: string; + database_name: string; + database_identifier: string; +}; + +export type MsaItem = { + id: string; + name?: string; + alignment_score: number | null; + cosine_score: number | 
null; + sequence: SequenceItem[]; + references: Reference[]; +}; + +export type QueryResult = { + msa: MsaItem[]; +}; + +type QueryResultViewProps = { + result: QueryResult; +}; + +export const PROTECTED_NAME_TO_CODE: Record = { + ALANINE: "ALA", + CYSTEINE: "CYS", + ASPARTICACID: "ASP", + GLUTAMICACID: "GLU", + PHENYLALANINE: "PHE", + GLYCINE: "GLY", + HISTIDINE: "HIS", + ISOLEUCINE: "ILE", + LYSINE: "LYS", + LEUCINE: "LEU", + METHIONINE: "MET", + ASPARAGINE: "ASN", + PROLINE: "PRO", + GLUTAMINE: "GLN", + ARGININE: "ARG", + SERINE: "SER", + THREONINE: "THR", + VALINE: "VAL", + TRYPTOPHAN: "TRP", + TYROSINE: "TYR", +}; + + +export const renderChiralSuperscripts = (label: string) => { + // Split into normal text + ^R/^S tokens + const parts = label.split(/(\^[RS])/g).filter(Boolean); + + return ( + <> + {parts.map((p, i) => { + if (p === "^R" || p === "^S") { + return ( + + {p.slice(1)} + + ); + } + return {p}; + })} + + ); +}; + +const isPolyketideMotif = (s: string | null | undefined) => { + if (!s) return false; + return /^[A-D](\^[RS])*(\d+)?(\^[RS])*$/i.test(s.trim()); +}; + +export const makeToDisplayName = (protectedNameToCode: Record) => { + const norm = (s: string) => s.replace(/[^a-z0-9]/gi, "").toUpperCase(); + + // normalize protected names + reserve protected codes + const prot = new Map( + Object.entries(protectedNameToCode).map(([k, v]) => [norm(k), norm(v)]) + ); + const reserved = new Set(Array.from(prot.values())); // e.g. ALA, GLY + const used = new Set(reserved); // block others from taking them + const cache = new Map(); // per-name stability + + const candidates = (s: string) => { + const out: string[] = []; + if (s.length >= 3) { + out.push(s.slice(0, 3)); // ABC + for (let i = 3; i < s.length; i++) out.push(s[0] + s[1] + s[i]); // AB? + for (let i = 2; i < s.length; i++) out.push(s[0] + s[i - 1] + s[i]); // A?? 
+ } + if (s.length >= 2) out.push(s.slice(0, 2)); // AB + if (s.length >= 1) out.push(s[0]); // A + // de-dupe in order + const seen = new Set(); + return out.filter(c => c.length <= 3 && !seen.has(c) && (seen.add(c), true)); + }; + + return (name: string | null): string | null => { + if (!name) return null; + const s = norm(name); + if (!s) return null; + + const hit = cache.get(s); + if (hit) return hit; + + // ONLY protected full names get protected 3-letter codes + const canonical = prot.get(s); + if (canonical) { + cache.set(s, canonical); + return canonical; + } + + // don’t let non-protected names steal reserved AA codes + for (const c of candidates(s)) { + if (!used.has(c)) { + used.add(c); + cache.set(s, c); + return c; + } + } + return null; + }; +}; + +function parseColor(color: string, alpha: number): string { + // HEX case: "#RGB" or "#RRGGBB" + if (color.startsWith("#")) { + let hex = color.replace(/^#/, ""); + // expand shorthand (#abc → aabbcc) + if (hex.length === 3) { + hex = hex.split("").map(c => c + c).join(""); + } + // parse r, g, b + const r = parseInt(hex.slice(0, 2), 16); + const g = parseInt(hex.slice(2, 4), 16); + const b = parseInt(hex.slice(4, 6), 16); + return `rgba(${r}, ${g}, ${b}, ${alpha})`; + }; + + // HSL case: "hsl(h, s%, l%)" + const hsl = color.match( + /hsl\(\s*([\d.]+)(?:deg)?\s*,\s*([\d.]+)%\s*,\s*([\d.]+)%\s*\)/ + ); + if (hsl) { + const h = hsl[1]; + const s = hsl[2]; + const l = hsl[3]; + return `hsla(${h}, ${s}%, ${l}%, ${alpha})`; + }; + + throw new Error(`Unsupported color format: ${color}`); +}; + +export const canonicalMotifKey = (s: string): string => { + return s.trim().replace(/\s+/g, "").replace(/\^[RS]/g, ""); +}; + +export const defaultMotifColorMap = (): Record => { + const newColorMap: Record = {}; + + const baseColors: Record<"A"|"B"|"C"|"D", string> = { + A: "#e74c3c", // red + B: "#27ae60", // green + C: "#2980b9", // blue + D: "#f39c12", // orange + }; + + for (const key of Object.keys(baseColors) as 
Array) { + const color = baseColors[key]; + // plain (opaque) base + newColorMap[key] = color; + + // numbered variants 1->15 -> alpha = 1/15...15/15 + for (let i = 1; i <= 15; i++) { + const alpha = 1 - (i / 15); + const alphaRounded = Math.round(alpha * 1000) / 1000; + newColorMap[`${key}${i}`] = parseColor(color, alphaRounded); + }; + }; + + return newColorMap; +}; + +const getMotifColor = (name: string): string | null => { + const colorMap = defaultMotifColorMap(); + const key = canonicalMotifKey(name); + return colorMap[key] || null; +}; + +const renderChipLabel = ( + rawName: string | null, + toDisplayName: (name: string | null) => string | null +): React.ReactNode => { + const raw = rawName || ""; + const displayLabel = isPolyketideMotif(raw) + ? raw + : (toDisplayName(raw) || "X"); + return renderChiralSuperscripts(displayLabel); +}; + +const renderTooltipLabel = ( + rawName: string | null, + toDisplayName: (name: string | null) => string | null +): React.ReactNode => { + const raw = rawName || ""; + + // Polyketide: show the short display code in tooltip + if (isPolyketideMotif(raw)) { + return toDisplayName(raw) || raw; // fallback to raw if code not available + } + + // Non-polyketide: show full original name in tooltip + return renderChiralSuperscripts(raw || "Unknown motif"); +}; + + +export const QueryResultView: React.FC = ({ result }) => { + // Keep order locally + const [msa, setMsa] = React.useState(result.msa); + + // Zoom + const [zoom, setZoom] = React.useState(1.0); + const handleZoomIn = () => setZoom(z => Math.min(z + 0.1, 3.0)); + const handleZoomOut = () => setZoom(z => Math.max(z - 0.1, 0.5)); + const handleZoomReset = () => setZoom(1.0); + + const msaLength = result.msa.length > 0 ? 
Math.max(...result.msa.map((r) => r.sequence.length)) : 0; + const motifWidth = 50 * zoom; + const labelWidth = 250; + const colTemplate = `${labelWidth}px repeat(${msaLength}, ${motifWidth}px) 1fr`; + + const toDisplayName = React.useMemo(() => makeToDisplayName(PROTECTED_NAME_TO_CODE), []); + + // If a new result comes in, refresh local state + React.useEffect(() => { + setMsa(result.msa); + }, [result]); + + // Handle drag end + const handleDragEnd = React.useCallback((event: DragEndEvent) => { + const { active, over } = event; + if (!over) return; + + const activeId = active.id as string; + const overId = over.id as string; + if (activeId === overId) return; + + setMsa((prev) => { + // Row-level drag: both IDs must be row IDs + const rowIds = new Set(prev.map((r) => r.id)); + const isRowDrag = rowIds.has(activeId) && rowIds.has(overId); + + if (isRowDrag) { + const fromIndex = prev.findIndex((r) => r.id === activeId); + const toIndex = prev.findIndex((r) => r.id === overId); + if (fromIndex === -1 || toIndex === -1) return prev; + return arrayMove(prev, fromIndex, toIndex); + }; + + // Item-level drag: both IDs must be in the SAME row + let rowIndex = -1; + let fromCol = -1 + let toCol = -1; + + for (let r = 0; r < prev.length; r++) { + const seq = prev[r].sequence; + const aIdx = seq.findIndex((s) => s.id === activeId); + const oIdx = seq.findIndex((s) => s.id === overId); + + // Only reorder if both items are within the SAME row + if (aIdx !== -1 && oIdx !== -1) { + rowIndex = r; + fromCol = aIdx; + toCol = oIdx; + break; + }; + }; + + if (rowIndex === -1) return prev; + + const row = prev[rowIndex]; + const newSeq = arrayMove([...row.sequence], fromCol, toCol); + + return prev.map((r, idx) => idx === rowIndex ? { ...r, sequence: newSeq } : r); + }); + }, []); + + return ( +

    + + Query results + + + {/* Toolbar */} + + + + + + {}} sx={{ cursor: "not-allowed" }} /> + + + + + + + item.id)} + strategy={verticalListSortingStrategy} + > + + {msa.map((row) => ( + + + item.id)} + strategy={horizontalListSortingStrategy} + > + {row.sequence.map(item => ( + + {item.isGap ? ( + + + + ) : ( + + + + + + )} + + ))} + + + + ))} + + + + + +
    + ); +}; diff --git a/src/client/src/components/workspace/tabs/discovery/SortableItem.tsx b/src/client/src/components/workspace/tabs/discovery/SortableItem.tsx new file mode 100644 index 0000000..133ecf0 --- /dev/null +++ b/src/client/src/components/workspace/tabs/discovery/SortableItem.tsx @@ -0,0 +1,37 @@ +import React from "react"; +import Box from "@mui/material/Box"; +import { CSS } from "@dnd-kit/utilities"; +import { useSortable } from "@dnd-kit/sortable"; + +interface SortableItemProps { + id: string; + children: React.ReactNode; + disabled?: boolean; // if true, cannot dragt THIS item +}; + +export const SortableItem: React.FC = ({ id, children, disabled }) => { + const { attributes, listeners, setNodeRef, transform, transition } = useSortable({ + id, + disabled, + animateLayoutChanges: () => false, + }) + + return ( + + {children} + + ); +}; \ No newline at end of file diff --git a/src/client/src/components/workspace/tabs/discovery/SortableRow.tsx b/src/client/src/components/workspace/tabs/discovery/SortableRow.tsx new file mode 100644 index 0000000..adc8148 --- /dev/null +++ b/src/client/src/components/workspace/tabs/discovery/SortableRow.tsx @@ -0,0 +1,185 @@ +import React from "react"; +import Box from "@mui/material/Box"; +import Stack from "@mui/material/Stack"; +import Typography from "@mui/material/Typography"; +import Tooltip from "@mui/material/Tooltip"; +import { CSS } from "@dnd-kit/utilities"; +import { useSortable } from "@dnd-kit/sortable"; +import DragIndicatorIcon from "@mui/icons-material/DragIndicator"; +import { MsaItem, Reference } from "./QueryResultView"; + +interface SortableRowProps { + row: MsaItem; + labelWidth: number; + children: React.ReactNode; +}; + +const fmt = (v: number | null, digits: number) => + v == null || Number.isNaN(v) ? 
"" : v.toFixed(digits); + +const referenceToUrl = (ref: Reference): string | null => { + switch (ref.database_name.toLowerCase()) { + case "npatlas": + return `https://www.npatlas.org/explore/compounds/${ref.database_identifier}`; + default: + return null; + }; +}; + +export const SortableRow: React.FC = ({ + row, + labelWidth, + children, +}) => { + const { attributes, listeners, setNodeRef, transform, transition } = useSortable({ id: row.id }); + + // Setup outlink to reference for row + const [ref, setRef] = React.useState(null); + const url = React.useMemo(() => (ref ? referenceToUrl(ref) : null), [ref]); + + // Set reference on mount + React.useEffect(() => { + if (row.references && row.references.length > 0) { + setRef(row.references[0]); + } else { + setRef(null); + } + }, [row.references]); + + const alignText = fmt(row.alignment_score, 2); + const cosineText = fmt(row.cosine_score, 2); + const scoreBlockWidth = 40; + + return ( + <> + + e.stopPropagation()} // don't trigger center selection on drag + > + + + + + { + if (url) e.stopPropagation(); // prevent row drag / selection + }} + sx={{ + fontWeight: 600, + maxWidth: labelWidth - 100, + lineHeight: "20px", + overflow: "hidden", + textOverflow: "ellipsis", + whiteSpace: "nowrap", + zIndex: 101, + userSelect: "none", + + // link-only styling + textDecoration: url ? "underline" : "none", + cursor: url ? "pointer" : "default", + color: "inherit", + + "&:hover": url + ? 
{ color: "primary.main" } + : undefined, + }} + > + {row.name || row.id} + + + + {/* Scores */} + + + + {alignText} + + + + + {cosineText} + + + + + + + {/* Motifs */} + {children} + + {/* Row line */} + + + ); +}; \ No newline at end of file diff --git a/src/client/src/components/workspace/tabs/discovery/WorkspaceDiscovery.tsx b/src/client/src/components/workspace/tabs/discovery/WorkspaceDiscovery.tsx index 57651d1..f3d0dac 100644 --- a/src/client/src/components/workspace/tabs/discovery/WorkspaceDiscovery.tsx +++ b/src/client/src/components/workspace/tabs/discovery/WorkspaceDiscovery.tsx @@ -4,11 +4,22 @@ import Card from "@mui/material/Card"; import CardContent from "@mui/material/CardContent"; import Typography from "@mui/material/Typography"; import MuiLink from "@mui/material/Link"; -import NotificationsRoundedIcon from "@mui/icons-material/NotificationsRounded"; +import Stack from "@mui/material/Stack"; +import InputLabel from "@mui/material/InputLabel"; +import FormControl from "@mui/material/FormControl"; +import MenuItem from "@mui/material/MenuItem"; +import Button from "@mui/material/Button"; +import CircularProgress from '@mui/material/CircularProgress'; +import Alert from "@mui/material/Alert"; +import Checkbox from "@mui/material/Checkbox"; +import FormGroup from "@mui/material/FormGroup"; +import FormControlLabel from "@mui/material/FormControlLabel"; import { useTheme } from "@mui/material/styles"; import { useNotifications } from "../../NotificationProvider"; import { Link as RouterLink } from "react-router-dom"; import { Session } from "../../../../features/session/types"; +import { Select } from "@mui/material"; +import { QueryResult, QueryResultView } from "./QueryResultView"; type WorkspaceDiscoveryProps = { session: Session; @@ -19,6 +30,16 @@ export const WorkspaceDiscovery: React.FC = ({ session, const theme = useTheme(); const { pushNotification } = useNotifications(); + // Query state + const [selectedItemId, setSelectedItemId] = 
React.useState(""); + const [queryLoading, setQueryLoading] = React.useState(false); + const [queryError, setQueryError] = React.useState(null); + const [queryResult, setQueryResult] = React.useState(null); + + // Query settings + const [queryAgainstCompounds, setQueryAgainstCompounds] = React.useState(true); + const [queryAgainstClusters, setQueryAgainstClusters] = React.useState(true); + // Wrap parent setter (Session | null) into the deps shape (Session-only functional updater) const setSessionSafe = React.useCallback( (updater: (prev: Session) => Session) => { @@ -37,6 +58,53 @@ export const WorkspaceDiscovery: React.FC = ({ session, [setSessionSafe, pushNotification, session.sessionId] ); + // Memoized alert based on query state + const alert = React.useMemo(() => { + if (queryError) { + return { severity: "error" as const, text: queryError }; + } + if (queryResult) { + return { severity: "success" as const, text: "Query complete! See results below." }; + } + return { + severity: "info" as const, + text: "Select an item and click “Run query” to see results here.", + }; + }, [queryError, queryResult]); + + // Post query + async function queryItem(itemId: string): Promise { + const params = new URLSearchParams({ + sessionId: session.sessionId, + itemId, + queryAgainstCompounds: String(queryAgainstCompounds), + queryAgainstClusters: String(queryAgainstClusters) + }); + const res = await fetch(`/api/queryItem?${params.toString()}`); + if (!res.ok) { throw new Error(`Query failed: ${res.status}`); }; + return await res.json(); + }; + + // Handler to run query (dummy implementation) + const handleRunQuery = async () => { + if (!selectedItemId) return; + + setQueryLoading(true); + setQueryError(null); + setQueryResult(null); + + try { + const result = await queryItem(selectedItemId); + setQueryResult(result); + pushNotification("Query completed successfully!", "success"); + } catch (error) { + setQueryError("Failed to run query. 
Please try again."); + pushNotification("Failed to run query.", "error"); + } finally { + setQueryLoading(false); + } + }; + return ( = ({ session,  for cross-modal retrieval against the BioNexus database. + + + + {selectedItemId ? "Item to use for querying" : "Select an item to use for querying"} + + + + + + + setQueryAgainstClusters(e.target.checked)} + /> + } + label="Query against clusters" + /> + setQueryAgainstCompounds(e.target.checked)} + /> + } + label="Query against compounds" + /> + + + + + + {queryLoading && } + + + {alert.text} + +
    + + {queryResult && ( + + + + + + )} ); }; diff --git a/src/client/src/components/workspace/tabs/upload/DialogImportCompound.tsx b/src/client/src/components/workspace/tabs/upload/DialogImportCompound.tsx index 37746d8..46606b2 100644 --- a/src/client/src/components/workspace/tabs/upload/DialogImportCompound.tsx +++ b/src/client/src/components/workspace/tabs/upload/DialogImportCompound.tsx @@ -74,17 +74,9 @@ export const DialogImportCompound: React.FC = ({ }; async function searchCompoundByName(q: string) { - const params = new URLSearchParams({ - q, - limit: "10", - }); - + const params = new URLSearchParams({q, limit: "10"}); const res = await fetch(`/api/searchCompound?${params.toString()}`); - - if (!res.ok) { - throw new Error(`Search failed: ${res.status}`); - }; - + if (!res.ok) { throw new Error(`Search failed: ${res.status}`); }; return await res.json(); }; @@ -101,7 +93,6 @@ export const DialogImportCompound: React.FC = ({ setLoading(true); try { const res = await searchCompoundByName(q); - const rows = (res.rows || []) as CompoundOption[]; setOptions(rows); } catch (err) { @@ -131,6 +122,7 @@ export const DialogImportCompound: React.FC = ({ Enter a single compound identifier & SMILES, or upload a CSV/TSV that contains a column called "name" and a column called "smiles". + Begin typing a compound name to see autocomplete suggestions from our database. Selecting one will automatically fill in a valid SMILES. 
diff --git a/src/server/app.py b/src/server/app.py index e465f90..dd55dc1 100644 --- a/src/server/app.py +++ b/src/server/app.py @@ -18,6 +18,7 @@ from routes.events import blp_events from routes.database import dsn_from_env from routes.compound import blp_search_compound, blp_submit_compound +from routes.query import blp_query_item # Initialize the Flask app @@ -151,4 +152,5 @@ def ready() -> tuple[dict[str, str], int]: app.register_blueprint(blp_delete_item) app.register_blueprint(blp_search_compound) app.register_blueprint(blp_submit_compound) +app.register_blueprint(blp_query_item) app.register_blueprint(blp_events) diff --git a/src/server/helpers/__init__.py b/src/server/helpers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/server/helpers/guid.py b/src/server/helpers/guid.py new file mode 100644 index 0000000..a23f810 --- /dev/null +++ b/src/server/helpers/guid.py @@ -0,0 +1,12 @@ +"""Helpers for generating GUIDs.""" + +import uuid + + +def generate_guid() -> str: + """ + Generate a new GUID. + + :return: a string representation of a new GUID + """ + return str(uuid.uuid4()) diff --git a/src/server/helpers/ncbi.py b/src/server/helpers/ncbi.py new file mode 100644 index 0000000..9961430 --- /dev/null +++ b/src/server/helpers/ncbi.py @@ -0,0 +1,97 @@ +"""Helpers to interact with NCBI APIs.""" + +import time +import requests +from typing import Any + + +def nuccore_to_gcf( + nuccore_acc: str, + *, + api_key: str | None = None, + email: str | None = None, + tool: str = "bionexus", + timeout: float = 15.0, + retries: int = 3, + sleep_between: float = 0.34, # ~3 requests/sec (NCBI-safe) + +) -> str | None: + """ + Resolve a nuccore accession to a RefSeq assembly accession (GCF_*). 
+ + :param nuccore_acc: nuccore accession + :param api_key: NCBI API key (optional) + :param email: contact email (optional) + :param tool: tool name for NCBI eutils (default: "bionexus") + :param timeout: request timeout in seconds (default: 15.0) + :param retries: number of retries for requests (default: 3) + :param sleep_between: sleep time between requests in seconds (default: 0.34) + :return: GCF_XXXXXXXX.X assembly accession, or None if not found + """ + session = requests.Session() + + base_params = {"retmode": "json", "tool": tool} + if api_key: + base_params["api_key"] = api_key + if email: + base_params["email"] = email + + def _get(url: str, params: dict) -> dict[str, Any]: + """ + Helper to perform GET request with retries. + + :param url: request URL + :param params: request parameters + :return: JSON response as dictionary + """ + last_err = None + for _ in range(retries): + try: + r = session.get(url, params=params, timeout=timeout) + r.raise_for_status() + return r.json() + except Exception as e: + last_err = e + time.sleep(sleep_between) + raise last_err + + # elink: nuccore -> assembly UID + elink_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/elink.fcgi" + elink_params = { + **base_params, + "dbfrom": "nuccore", + "db": "assembly", + "id": nuccore_acc, + } + + data = _get(elink_url, elink_params) + + linksets = data.get("linksets") or [] + linksetdbs = (linksets[0].get("linksetdbs") if linksets else []) or [] + + assembly_uid = None + for db in linksetdbs: + if db.get("dbto") == "assembly" and db.get("links"): + assembly_uid = db["links"][0] + break + + if not assembly_uid: + return None + + time.sleep(sleep_between) + + # esummary: assembly UID -> GCF accession + esum_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esummary.fcgi" + esum_params = { + **base_params, + "db": "assembly", + "id": assembly_uid, + } + + summary = _get(esum_url, esum_params) + doc = summary.get("result", {}).get(str(assembly_uid)) + + if not doc: + return 
None + + return doc.get("assemblyaccession") diff --git a/src/server/routes/compound.py b/src/server/routes/compound.py index 918d895..9c802d6 100644 --- a/src/server/routes/compound.py +++ b/src/server/routes/compound.py @@ -90,14 +90,7 @@ def _set_item_status_inplace(item: dict, status: str, error_message: str | None @blp_submit_compound.post("/api/submitCompound") def submit_compound(): """ - payload = request.get_json(force=True) or {} - - session_id = payload.get("sessionId") - item_id = payload.get("itemId") - name = payload.get("name") - smiles = payload.get("smiles") - - Endpoint to submit a compound by SMILES string. + Submit a compound for processing. """ payload = request.get_json(force=True) or {} session_id = payload.get("sessionId") diff --git a/src/server/routes/query.py b/src/server/routes/query.py new file mode 100644 index 0000000..c06d2d5 --- /dev/null +++ b/src/server/routes/query.py @@ -0,0 +1,494 @@ +"""Query endpoint routes.""" + +import warnings +import uuid +from dataclasses import dataclass + +import sqlalchemy as sa +from flask import Blueprint, current_app, jsonify, request +from Bio import BiopythonDeprecationWarning +from rdkit.DataStructs.cDataStructs import ExplicitBitVect + +from retromol.model.result import Result +from retromol.model.rules import RuleSet +from retromol.model.reaction_graph import MolNode +from retromol.chem.mol import smiles_to_mol +from retromol.chem.fingerprint import mol_to_morgan_fingerprint, calculate_tanimoto_similarity +from retromol.fingerprint.fingerprint import FingerprintGenerator + +from biocracker.query.modules import LinearReadout, PKSModule, NRPSModule, PKSExtenderUnit + +from bionexus.db.models import CandidateCluster, Compound, Reference + +from versalign.aligner import setup_aligner +from versalign.scoring import create_substitution_matrix_dynamically +from versalign.docking import dock_against_target + +from routes.session_store import load_item +from routes.database import SessionLocal +from 
helpers.ncbi import nuccore_to_gcf + +warnings.filterwarnings("ignore", category=BiopythonDeprecationWarning) + + +blp_query_item = Blueprint("query_item", __name__) + + +RULESET = RuleSet.load_default() +GENERATOR = FingerprintGenerator(RULESET.matching_rules) + + +def get_compound_references(s, compound_id: int) -> list[Reference]: + """ + Return Reference rows linked to a compound via compound_reference. + """ + stmt = ( + sa.select(Reference) + .join(Reference.compounds) + .where(Compound.id == compound_id) + .order_by(Reference.database_name.asc(), Reference.database_identifier.asc()) + ) + return list(s.scalars(stmt).all()) + + +def get_cluster_references(s, cluster_id: int) -> list[Reference]: + """ + Return Reference rows linked to a cluster via reference_candidate_cluster. + """ + stmt = ( + sa.select(Reference) + .join(Reference.candidate_clusters) + .where(CandidateCluster.id == cluster_id) + .order_by(Reference.database_name.asc(), Reference.database_identifier.asc()) + ) + return list(s.scalars(stmt).all()) + + +@dataclass(frozen=True) +class SequenceItem: + """ + Represents an item in a biosynthetic sequence (e.g., NRPS/PKS module or monomer). + + :var name: name of the item + :var morgan_fp: Morgan fingerprint of the item's structure (if applicable) + """ + + name: str + morgan_fp: ExplicitBitVect | None = None + + def __hash__(self) -> int: + return hash((self.name, self.morgan_fp.ToBitString() if self.morgan_fp else None)) + + @classmethod + def from_nrps_module(cls, mod: NRPSModule) -> "SequenceItem": + """ + Create a SequenceItem from an NRPS module. 
+ """ + if mod.substrate.smiles is not None: + name = mod.substrate.name + smiles = mod.substrate.smiles + if smiles == "O=NN(O)CCC[C@H](N)(C(=O)O": # graminine fix (fixed in >=2.0.1 versions of BioCracker) + smiles = "O=NN(O)CCC[C@H](N)(C(=O)O)" + mol = smiles_to_mol(smiles) + morgan_fp = mol_to_morgan_fingerprint(mol, radius=2, num_bits=2048, use_chirality=False) + return cls(name, morgan_fp) + else: + return cls("Unknown") + + @classmethod + def from_pks_module(cls, mod: PKSModule) -> "SequenceItem": + """ + Create a SequenceItem from a PKS module. + """ + match mod.substrate.extender_unit: + case PKSExtenderUnit.PKS_A: name = "PKS_A" + case PKSExtenderUnit.PKS_B: name = "PKS_B" + case PKSExtenderUnit.PKS_C: name = "PKS_C" + case PKSExtenderUnit.PKS_D: name = "PKS_D" + case _: name = "PKS_A" + return cls(name) + + @classmethod + def from_molnode(cls, node: MolNode) -> "SequenceItem": + """ + Create a SequenceItem from a MolNode. + """ + if node.is_identified: + rule = node.identity.matched_rule + name = rule.name + mol = smiles_to_mol(rule.smiles) + morgan_fp = mol_to_morgan_fingerprint(mol, radius=2, num_bits=2048, use_chirality=False) + return cls(name, morgan_fp) + else: + return cls("Unknown") + + +def item_compare(a: SequenceItem | str, b: SequenceItem | str) -> float: + """ + Compare two SequenceItems or gap representations. 
+ """ + if a == "-" or b == "-": + return 0.0 # gap penalty + + elif isinstance(a, SequenceItem) and isinstance(b, SequenceItem): + pks_a = {'PKS_A', 'A2'} + pks_b = {'PKS_B', 'B2', 'B6'} + pks_d = {'PKS_D', 'D6'} + pks_mod_names = {"PKS_A", "PKS_B", "PKS_C", "PKS_D", "B2", "D6", "A2", "B6"} + if a.name in pks_a and b.name in pks_a: + return 1.0 + elif a.name in pks_b and b.name in pks_b: + return 1.0 + elif a.name in pks_d and b.name in pks_d: + return 1.0 + elif a.name in pks_mod_names or b.name in pks_mod_names: + # Could be correct, but we have no info + return 0.5 + + elif a.name == "Unknown" or b.name == "Unknown": + # Could be correct, but we have no info + return 0.5 + + elif a.morgan_fp is not None and b.morgan_fp is not None: + return calculate_tanimoto_similarity(a.morgan_fp, b.morgan_fp) + + return -2.0 + +def label_fn (r: SequenceItem | str) -> str: + """ + Label function for sequence items. + """ + return str(hash(r)) if isinstance(r, SequenceItem) else r + + +@blp_query_item.get("/api/queryItem") +def query_item(): + """ + Query endpoint for compounds by name-like query. 
+ """ + session_id = request.args.get("sessionId", "").strip() + item_id = request.args.get("itemId", "").strip() + if not session_id: + return jsonify({"error": "Missing sessionId"}), 400 + if not item_id: + return jsonify({"error": "Missing itemId"}), 400 + + query_against_compounds = request.args.get("queryAgainstCompounds", "true").lower() == "true" + query_against_clusters = request.args.get("queryAgainstClusters", "true").lower() == "true" + # if both set to false, return error + if not query_against_compounds and not query_against_clusters: + return jsonify({"error": "At least one of queryAgainstCompounds or queryAgainstClusters must be true"}), 400 + + # Retrieve item from session store + item = load_item(session_id, item_id) + if item is None: + return jsonify({"error": "Item not found"}), 404 + + # Load Result + payload_as_dict = item.get("payload", None) + if payload_as_dict is None: + return jsonify({"error": "No payload found in item"}), 404 + payload: Result = Result.from_dict(payload_as_dict) + + # Create fingerprints for query + retromol_fp_counted = GENERATOR.fingerprint_from_result(payload, num_bits=1024, counted=True) + retromol_fp_counted = retromol_fp_counted.astype(float).tolist() + # retromol_fp_binary = [float(int(x > 0)) for x in retromol_fp_counted] # currently not used; default is counted fingerprints + + # Retrieve primary sequence from payload + linear_readouts = payload.linear_readout.paths + linear_readout = max(linear_readouts, key=lambda x: len(x)) + current_app.logger.debug(f"best linear readout has {len(linear_readout)} module(s)") + + # Parse linear_readout into sequence of SequenceItems + seq1: list[SequenceItem] = [SequenceItem.from_molnode(n) for n in linear_readout] + + # ANN query against compounds and/or clusters + keep_top = 1000 + + with SessionLocal() as s: + # Works for pgvector 0.8.0+ + s.execute(sa.text("SET LOCAL hnsw.iterative_scan = strict_order")) + # increase how far it is allowed to scan + s.execute(sa.text("SET 
LOCAL hnsw.max_scan_tuples = 1000000")) + # optional: allow more memory for scanning + s.execute(sa.text("SET LOCAL hnsw.scan_mem_multiplier = 2")) + # increase ef_search for better accuracy + s.execute(sa.text("SET LOCAL hnsw.ef_search = 1000")) + + if query_against_clusters: + dist = CandidateCluster.retromol_fp_counted_by_region.cosine_distance(retromol_fp_counted).label("dist") + stmt = ( + sa.select(CandidateCluster, dist) + .where( + CandidateCluster.retromol_fp_counted_by_region.is_not(None), + # CandidateCluster.file_name.ilike("BGC%"), + ) + .order_by(dist.asc()) + .limit(keep_top if (query_against_clusters and not query_against_compounds) else keep_top//2) + ) + cluster_rows = s.execute(stmt).all() + else: + cluster_rows = [] + + if query_against_compounds: + dist = Compound.retromol_fp_counted.cosine_distance(retromol_fp_counted).label("dist") + stmt = ( + sa.select(Compound, dist) + .where( + Compound.retromol_fp_counted.is_not(None), + ) + .order_by(dist.asc()) + .limit(keep_top if (query_against_compounds and not query_against_clusters) else keep_top//2) + ) + compound_rows = s.execute(stmt).all() + else: + compound_rows = [] + + # Rerank cluster rows through docking alignment + best_clusters = [] + for cluster, cosine_dist in cluster_rows: + rec = LinearReadout.from_dict(cluster.biocracker) + + # Assembly seq2 from rec + seq2: list[list[SequenceItem]] = [] + by_orf = False + if not by_orf: subs = [("seq", rec.biosynthetic_order(by_orf=by_orf))] + else: subs = rec.biosynthetic_order(by_orf=by_orf) + for _, mods in subs: + seq2_sub = [] + for mod in mods: + if isinstance(mod, NRPSModule): seq2_sub.append(SequenceItem.from_nrps_module(mod)) + elif isinstance(mod, PKSModule): seq2_sub.append(SequenceItem.from_pks_module(mod)) + else: raise ValueError(f"unknown module type: {type(mod)}") + seq2.append(seq2_sub) + + if len(seq2): + # Dynamically create scoring matrix + items = ["-"] + items.extend(seq1) + for seq2_sub in seq2: + items.extend(seq2_sub) + 
unique_items = list(set(items)) + sm, _ = create_substitution_matrix_dynamically(unique_items, compare=item_compare, label_fn=label_fn) + aligner = setup_aligner( + sm, + "global", + target_internal_open_gap_score=-5.0, + target_left_open_gap_score=-5.0, + target_right_open_gap_score=-5.0, + query_internal_open_gap_score=-5.0, + query_left_open_gap_score=-5.0, + query_right_open_gap_score=-5.0, + label_fn=label_fn, + ) + aln = dock_against_target( + aligner=aligner, + target=seq1, + candidates=seq2, + gap_repr="-", + allow_block_reverse=True, + strategy="nonoverlap", + ) + alignment_score = aln.total_score # the higher the score the bigger and stronge the match between the two; favors long matches + + # Penalize unmatched parts; if we are using shorter blocks + # TODO + + if len(best_clusters) < keep_top or alignment_score > best_clusters[-1][0]: + best_clusters.append((alignment_score, 1.0 - cosine_dist, cluster, aln, seq2)) + best_clusters.sort(key=lambda x: x[0], reverse=True) + if len(best_clusters) > keep_top: + best_clusters.pop() + + + # Rerank compound rows through docking alignment + best_compounds = [] + for compound, cosine_dist in compound_rows: + rec = Result.from_dict(compound.retromol) + + # Assembly seq2 from rec + # Retrieve primary sequence from payload + compound_readouts = rec.linear_readout.paths + compound_readout = max(compound_readouts, key=lambda x: len(x)) + seq2: list[SequenceItem] = [[SequenceItem.from_molnode(n) for n in compound_readout]] + + if len(seq2): + # Dynamically create scoring matrix + items = ["-"] + items.extend(seq1) + for seq2_sub in seq2: + items.extend(seq2_sub) + unique_items = list(set(items)) + sm, _ = create_substitution_matrix_dynamically(unique_items, compare=item_compare, label_fn=label_fn) + aligner = setup_aligner( + sm, + "global", + target_internal_open_gap_score=-5.0, + target_left_open_gap_score=-5.0, + target_right_open_gap_score=-5.0, + query_internal_open_gap_score=-5.0, + query_left_open_gap_score=-5.0, 
+ query_right_open_gap_score=-5.0, + label_fn=label_fn, + ) + aln = dock_against_target( + aligner=aligner, + target=seq1, + candidates=seq2, + gap_repr="-", + allow_block_reverse=True, + strategy="nonoverlap", + ) + alignment_score = aln.total_score # the higher the score the bigger and stronge the match between the two; favors long matches + + # Penalize unmatched parts; if we are using shorter blocks + # TODO + + if len(best_compounds) < keep_top or alignment_score > best_compounds[-1][0]: + best_compounds.append((alignment_score, 1.0 - cosine_dist, compound, aln, seq2)) + best_compounds.sort(key=lambda x: x[0], reverse=True) + if len(best_compounds) > keep_top: + best_compounds.pop() + + + # Sort first on alignment score, then on cosine score + best_clusters.sort(key=lambda x: (x[0], x[1]), reverse=True) + best_compounds.sort(key=lambda x: (x[0], x[1]), reverse=True) + + # combine lists and resort, add little tag to indicate source + combined = [] + for alignment_score, cosine_score, cluster, aln, blocks in best_clusters: + combined.append((alignment_score, cosine_score, cluster, aln, blocks, "cluster")) + for alignment_score, cosine_score, compound, aln, blocks in best_compounds: + combined.append((alignment_score, cosine_score, compound, aln, blocks, "compound")) + combined.sort(key=lambda x: (x[0], x[1]), reverse=True) + + # Format response + msa: list[list[dict]] = [] + + # print(linear_readout) # MolNodes + msa_item = { + "id": str(uuid.uuid4()), + "name": "Query", + "alignment_score": None, + "cosine_score": None, + "sequence": [], + "references": [], + } + for x in linear_readout: + msa_item["sequence"].append({ + "id": str(uuid.uuid4()), + "isGap": False, + "name": x.identity.matched_rule.name if x.is_identified else None, + "smiles": x.identity.matched_rule.smiles if x.is_identified else None, + }) + msa.append(msa_item) + + for i, (alignment_score, cosine_score, cluster, aln, blocks, source) in enumerate(combined[:20], 1): + + if source == "compound": 
+ + # get compound references + with SessionLocal() as s: + refs = get_compound_references(s, cluster.id) + + if refs: + name = refs[0].name + else: + name = "Unnamed compound" + + msa_item = { + "id": str(uuid.uuid4()), + "name": name, + "alignment_score": round(alignment_score, 3), + "cosine_score": round(cosine_score, 3), + "sequence": [ + { + "id": str(uuid.uuid4()), + "isGap": True, + "name": None, + "smiles": None, + } + for _ in range(len(linear_readout)) + ], + "references": [{ + "name": ref.name, + "database_name": ref.database_name, + "database_identifier": ref.database_identifier, + } for ref in refs], + } + + + if source == "cluster": + + # get cluster references + with SessionLocal() as s: + refs = get_cluster_references(s, cluster.id) + + msa_item = { + "id": str(uuid.uuid4()), + "name": cluster.file_name, + "alignment_score": round(alignment_score, 3), + "cosine_score": round(cosine_score, 3), + "sequence": [ + { + "id": str(uuid.uuid4()), + "isGap": True, + "name": None, + "smiles": None, + } + for _ in range(len(linear_readout)) + ], + "references": [{ + "name": ref.name, + "database_name": ref.database_name, + "database_identifier": ref.database_identifier, + } for ref in refs], + } + + try: + # sort placements by start position + placements = sorted(aln.placements, key=lambda p: p.start) + for placement in placements: + + is_reversed = placement.reversed + + # get real identities instead of hashes of aligned blocks + block_idx = placement.block_idx + block = blocks[block_idx] + if is_reversed: + block = list(reversed(block)) + # gap_inds = [i for i, x in enumerate(placement.block_aln) if x == "-"] + # print(len(placement.block_aln), len(block), gap_inds) + + placement_items = {} + placement_count = 0 + for x in placement.block_aln: + if x == "-": + continue + placement_items[placement_count] = block[placement_count] + placement_count += 1 + + start = placement.start + end = placement.end + it = 0 + for idx in range(start, end + 1): + name = 
placement_items[it].name + if name.startswith("PKS_"): + name = name.strip("PKS_") + msa_item["sequence"][idx] = { + "id": str(uuid.uuid4()), + "isGap": False, + # "name": None, # could be filled in if needed + "name": name, + "smiles": None, # could be filled in if needed + } + it += 1 + except Exception as e: + pass + + msa.append(msa_item) + + # For now just return error + return jsonify({"msa": msa}), 200 From aafd3f0f8beb1d4d70e20fedba8bd7bf68892d9f Mon Sep 17 00:00:00 2001 From: David Meijer Date: Thu, 15 Jan 2026 03:51:41 +0100 Subject: [PATCH 12/34] UPD: correct visualization block alignment --- .../tabs/discovery/QueryResultView.tsx | 241 +++++++--- .../workspace/tabs/discovery/SortableItem.tsx | 37 -- .../workspace/tabs/discovery/SortableRow.tsx | 74 +-- .../tabs/upload/DialogImportCompound.tsx | 15 +- .../tabs/upload/WorkspaceItemCard.tsx | 54 ++- .../workspace/tabs/upload/WorkspaceUpload.tsx | 15 +- src/client/src/features/jobs/api.ts | 1 - src/client/src/features/session/api.ts | 5 +- src/client/src/features/session/types.ts | 1 - src/server/app.py | 2 +- src/server/routes/query/__init__.py | 0 .../routes/{query.py => query/_query.py} | 239 ++++++---- src/server/routes/query/align.py | 133 ++++++ src/server/routes/query/featurize.py | 145 ++++++ src/server/routes/query/pipeline.py | 442 ++++++++++++++++++ src/server/routes/query/retrieve.py | 122 +++++ src/server/routes/query/seq.py | 185 ++++++++ src/server/routes/query_service.py | 54 +++ src/server/routes/session.py | 4 + src/server/routes/session_store.py | 18 +- 20 files changed, 1535 insertions(+), 252 deletions(-) delete mode 100644 src/client/src/components/workspace/tabs/discovery/SortableItem.tsx create mode 100644 src/server/routes/query/__init__.py rename src/server/routes/{query.py => query/_query.py} (70%) create mode 100644 src/server/routes/query/align.py create mode 100644 src/server/routes/query/featurize.py create mode 100644 src/server/routes/query/pipeline.py create mode 100644 
src/server/routes/query/retrieve.py create mode 100644 src/server/routes/query/seq.py create mode 100644 src/server/routes/query_service.py diff --git a/src/client/src/components/workspace/tabs/discovery/QueryResultView.tsx b/src/client/src/components/workspace/tabs/discovery/QueryResultView.tsx index 22ca9b9..e56694b 100644 --- a/src/client/src/components/workspace/tabs/discovery/QueryResultView.tsx +++ b/src/client/src/components/workspace/tabs/discovery/QueryResultView.tsx @@ -8,8 +8,8 @@ import ZoomOutIcon from "@mui/icons-material/ZoomOut"; import ZoomInIcon from "@mui/icons-material/ZoomIn"; import RefreshIcon from "@mui/icons-material/Refresh"; import DownloadIcon from "@mui/icons-material/Download"; +import ExchangeIcon from "@mui/icons-material/SwapHoriz"; import { SortableRow } from "./SortableRow"; -import { SortableItem } from "./SortableItem"; // Imports for dragging and dropping rows and motifs import { DndContext, DragEndEvent } from "@dnd-kit/core"; @@ -17,7 +17,6 @@ import { SortableContext, arrayMove, verticalListSortingStrategy, - horizontalListSortingStrategy, } from "@dnd-kit/sortable"; export type SequenceItem = { @@ -27,6 +26,12 @@ export type SequenceItem = { smiles: string | null; }; +export type Sequence = { + id: string; + name: string | null; + sequence: SequenceItem[]; +}; + export type Reference = { name: string; database_name: string; @@ -38,7 +43,7 @@ export type MsaItem = { name?: string; alignment_score: number | null; cosine_score: number | null; - sequence: SequenceItem[]; + sequence: Sequence[]; references: Reference[]; }; @@ -242,18 +247,36 @@ const renderTooltipLabel = ( return renderChiralSuperscripts(raw || "Unknown motif"); }; - export const QueryResultView: React.FC = ({ result }) => { // Keep order locally const [msa, setMsa] = React.useState(result.msa); + // Invert order of motifs in msa + const invertMsaMotifOrder = () => { + setMsa((prev) => + prev.map((row) => ({ + ...row, + sequence: [...row.sequence] + .reverse() + 
.map((seq) => ({ ...seq, sequence: [...seq.sequence].reverse() })), + }))) + }; + // Zoom const [zoom, setZoom] = React.useState(1.0); const handleZoomIn = () => setZoom(z => Math.min(z + 0.1, 3.0)); const handleZoomOut = () => setZoom(z => Math.max(z - 0.1, 0.5)); const handleZoomReset = () => setZoom(1.0); - const msaLength = result.msa.length > 0 ? Math.max(...result.msa.map((r) => r.sequence.length)) : 0; + const sequenceLength = (seqs: Sequence[]) => + seqs.reduce((sum, seq) => sum + seq.sequence.length, 0); + + const msaLength = + result.msa.length > 0 + ? Math.max(...result.msa.map((r) => sequenceLength(r.sequence))) + : 0; + + // const msaLength = result.msa.length > 0 ? Math.max(...result.msa.map((r) => r.sequence.length)) : 0; const motifWidth = 50 * zoom; const labelWidth = 250; const colTemplate = `${labelWidth}px repeat(${msaLength}, ${motifWidth}px) 1fr`; @@ -335,6 +358,7 @@ export const QueryResultView: React.FC = ({ result }) => { + {}} sx={{ cursor: "not-allowed" }} /> @@ -345,91 +369,170 @@ export const QueryResultView: React.FC = ({ result }) => { width: "100%", overflowX: "auto", overflowY: "hidden", - pb: 2, + py: 2, }} > item.id)} + items={msa.map((item) => item.id)} strategy={verticalListSortingStrategy} > {msa.map((row) => ( - - - item.id)} - strategy={horizontalListSortingStrategy} + + + {row.sequence.map((subseq) => { + const allGaps = subseq.sequence.every((it) => it.isGap); + + return ( - {row.sequence.map(item => ( - - {item.isGap ? ( + + {subseq.name || subseq.id} + + + {subseq.sequence.map((item) => + item.isGap ? ( + + + ) : ( + + + Motif name is{" "} + {isPolyketideMotif(item.name) + ? 
renderChipLabel(item.name, toDisplayName) + : renderTooltipLabel(item.name, toDisplayName)} + + } + arrow > - - - ) : ( - - - - - - )} - - ))} - - - + + + ) + )} + + + )})} + + ))} diff --git a/src/client/src/components/workspace/tabs/discovery/SortableItem.tsx b/src/client/src/components/workspace/tabs/discovery/SortableItem.tsx deleted file mode 100644 index 133ecf0..0000000 --- a/src/client/src/components/workspace/tabs/discovery/SortableItem.tsx +++ /dev/null @@ -1,37 +0,0 @@ -import React from "react"; -import Box from "@mui/material/Box"; -import { CSS } from "@dnd-kit/utilities"; -import { useSortable } from "@dnd-kit/sortable"; - -interface SortableItemProps { - id: string; - children: React.ReactNode; - disabled?: boolean; // if true, cannot dragt THIS item -}; - -export const SortableItem: React.FC = ({ id, children, disabled }) => { - const { attributes, listeners, setNodeRef, transform, transition } = useSortable({ - id, - disabled, - animateLayoutChanges: () => false, - }) - - return ( - - {children} - - ); -}; \ No newline at end of file diff --git a/src/client/src/components/workspace/tabs/discovery/SortableRow.tsx b/src/client/src/components/workspace/tabs/discovery/SortableRow.tsx index adc8148..d505b6c 100644 --- a/src/client/src/components/workspace/tabs/discovery/SortableRow.tsx +++ b/src/client/src/components/workspace/tabs/discovery/SortableRow.tsx @@ -11,11 +11,12 @@ import { MsaItem, Reference } from "./QueryResultView"; interface SortableRowProps { row: MsaItem; labelWidth: number; + columnTemplate: string; children: React.ReactNode; }; -const fmt = (v: number | null, digits: number) => - v == null || Number.isNaN(v) ? "" : v.toFixed(digits); +const fmt = (v: number | null, digits: number) => + v == null || Number.isNaN(v) ? 
"" : v.toFixed(digits); const referenceToUrl = (ref: Reference): string | null => { switch (ref.database_name.toLowerCase()) { @@ -23,20 +24,28 @@ const referenceToUrl = (ref: Reference): string | null => { return `https://www.npatlas.org/explore/compounds/${ref.database_identifier}`; default: return null; - }; + } }; export const SortableRow: React.FC = ({ row, labelWidth, + columnTemplate, children, }) => { - const { attributes, listeners, setNodeRef, transform, transition } = useSortable({ id: row.id }); + const { + attributes, + listeners, + setNodeRef, + setActivatorNodeRef, + transform, + transition, + } = useSortable({ id: row.id }); // Setup outlink to reference for row const [ref, setRef] = React.useState(null); const url = React.useMemo(() => (ref ? referenceToUrl(ref) : null), [ref]); - + // Set reference on mount React.useEffect(() => { if (row.references && row.references.length > 0) { @@ -51,12 +60,19 @@ export const SortableRow: React.FC = ({ const scoreBlockWidth = 40; return ( - <> + = ({ }} > = ({ alignItems: "center", cursor: "grab", }} - onClick={e => e.stopPropagation()} // don't trigger center selection on drag + onClick={(e) => e.stopPropagation()} // don't trigger center selection on drag > - + = ({ href={url ?? undefined} target={url ? "_blank" : undefined} rel={url ? "noopener noreferrer" : undefined} - onClick={(e) => { + onClick={(e: React.MouseEvent) => { if (url) e.stopPropagation(); // prevent row drag / selection }} sx={{ @@ -110,9 +132,7 @@ export const SortableRow: React.FC = ({ cursor: url ? "pointer" : "default", color: "inherit", - "&:hover": url - ? { color: "primary.main" } - : undefined, + "&:hover": url ? 
{ color: "primary.main" } : undefined, }} > {row.name || row.id} @@ -130,11 +150,7 @@ export const SortableRow: React.FC = ({ lineHeight: 1, }} > - + = ({ {alignText} - + = ({ - + {/* Motifs */} {children} {/* Row line */} - - + /> */} + ); -}; \ No newline at end of file +}; diff --git a/src/client/src/components/workspace/tabs/upload/DialogImportCompound.tsx b/src/client/src/components/workspace/tabs/upload/DialogImportCompound.tsx index 46606b2..5da46e2 100644 --- a/src/client/src/components/workspace/tabs/upload/DialogImportCompound.tsx +++ b/src/client/src/components/workspace/tabs/upload/DialogImportCompound.tsx @@ -10,6 +10,7 @@ import Typography from "@mui/material/Typography"; import Autocomplete from "@mui/material/Autocomplete"; import { useNotifications } from "../../NotificationProvider"; import { DialogWindow } from "../../../shared/DialogWindow"; +import { o } from "framer-motion/dist/types.d-DagZKalS"; type CompoundOption = { name: string; @@ -177,12 +178,18 @@ export const DialogImportCompound: React.FC = ({ } }} renderOption={(props, option) => { - if (typeof option == "string") { - return
  • {option}
  • ; - } + const { key, ...optionProps } = props as React.HTMLAttributes & { key: React.Key }; + + if (typeof option === "string") { + return ( +
  • + {option} +
  • + ); + }; return ( -
  • +
  • = ({ - + = ({ height={70} innerRadius="70%" outerRadius="100%" - sx={{ - "& text": { - fontSize: "0.65rem", - fontWeight: 600, - }, - "& .MuiGauge-valueArc": { - fill: (theme) => getScoreColor(theme, item.score!), - transition: "stroke-dashoffset 0.3s ease", - }, - }} + sx={{ + minWidth: 70, + "& text": { + fontSize: "0.65rem", + fontWeight: 600, + }, + "& .MuiGauge-valueArc": { + fill: (theme) => getScoreColor(theme, item.score!), + transition: "stroke-dashoffset 0.3s ease", + }, + }} text={({ value }) => `${value}%`} /> - - - + + + {item.name} @@ -164,7 +171,18 @@ export const WorkspaceItemCard: React.FC = ({ - + {disabled && ( <> diff --git a/src/client/src/components/workspace/tabs/upload/WorkspaceUpload.tsx b/src/client/src/components/workspace/tabs/upload/WorkspaceUpload.tsx index 34d0c6a..d98bc51 100644 --- a/src/client/src/components/workspace/tabs/upload/WorkspaceUpload.tsx +++ b/src/client/src/components/workspace/tabs/upload/WorkspaceUpload.tsx @@ -166,6 +166,17 @@ export const WorkspaceUpload: React.FC = ({ session, setSe }; }; + // Open dailogs + const handleOpenCompounds = (event: React.MouseEvent) => { + event.currentTarget.blur(); // prevents 'Blocked aria-hidden on an element' warning + setOpenCompounds(true); + }; + + const handleOpenBGCs = (event: React.MouseEvent) => { + // event.currentTarget.blur(); // prevents 'Blocked aria-hidden on an element' warning + // Not implemented yet + }; + // Import compound handlers const handleImportSingleCompound = async({ name, smiles, matchStereochemistry}: { name: string; smiles: string; matchStereochemistry: boolean }) => { await importCompound(deps, { name, smiles, matchStereochemistry }); @@ -228,11 +239,11 @@ export const WorkspaceUpload: React.FC = ({ session, setSe - - diff --git a/src/client/src/features/jobs/api.ts b/src/client/src/features/jobs/api.ts index 26efe24..23a5232 100644 --- a/src/client/src/features/jobs/api.ts +++ b/src/client/src/features/jobs/api.ts @@ -72,7 +72,6 @@ export async 
function importCompoundsBatch( updatedAt: Date.now(), // optional fields score: null, - payload: null, })); const updated: Session = { ...prev, items: [...prev.items, ...createdItems] }; diff --git a/src/client/src/features/session/api.ts b/src/client/src/features/session/api.ts index f584d04..0ceb12e 100644 --- a/src/client/src/features/session/api.ts +++ b/src/client/src/features/session/api.ts @@ -15,6 +15,7 @@ export async function getSession(sessionIdArg?: string): Promise { const sessionId = sessionIdArg ?? getCookie("sessionId"); if (!sessionId) throw new Error("No sessionId provided or found in cookies"); const data = await postJson("/api/getSession", { sessionId }, GetSessionRespSchema); + // Dont have to sanitize session: postJson already validates with GetSessionRespSchema return data.session; }; @@ -24,8 +25,8 @@ export async function refreshSession(sessionId: string): Promise { export async function saveSession(session: Session): Promise { // Runtime validate before sending (especially useful because session is user-mutated in UI) - SessionSchema.parse(session); - await postJson("/api/saveSession", { session }, z.unknown()); + const sanitized = SessionSchema.parse(session); + await postJson("/api/saveSession", { session: sanitized }, z.unknown()); }; export async function deleteSession(): Promise { diff --git a/src/client/src/features/session/types.ts b/src/client/src/features/session/types.ts index c87cc95..03abd94 100644 --- a/src/client/src/features/session/types.ts +++ b/src/client/src/features/session/types.ts @@ -4,7 +4,6 @@ export const BaseItemSchema = z.object({ id: z.string(), name: z.string(), // display name score: z.number().min(0).max(1).nullable().optional(), - payload: z.record(z.any()).nullable().optional(), status: z.enum(["queued", "processing", "done", "error"]).default("queued"), errorMessage: z.string().nullable().optional(), updatedAt: z.number().nonnegative().default(() => Date.now()), diff --git a/src/server/app.py 
b/src/server/app.py index dd55dc1..8e87a62 100644 --- a/src/server/app.py +++ b/src/server/app.py @@ -18,7 +18,7 @@ from routes.events import blp_events from routes.database import dsn_from_env from routes.compound import blp_search_compound, blp_submit_compound -from routes.query import blp_query_item +from routes.query_service import blp_query_item # Initialize the Flask app diff --git a/src/server/routes/query/__init__.py b/src/server/routes/query/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/server/routes/query.py b/src/server/routes/query/_query.py similarity index 70% rename from src/server/routes/query.py rename to src/server/routes/query/_query.py index c06d2d5..be539c7 100644 --- a/src/server/routes/query.py +++ b/src/server/routes/query/_query.py @@ -197,11 +197,17 @@ def query_item(): # Retrieve primary sequence from payload linear_readouts = payload.linear_readout.paths - linear_readout = max(linear_readouts, key=lambda x: len(x)) - current_app.logger.debug(f"best linear readout has {len(linear_readout)} module(s)") + linear_readouts.sort(key=lambda x: len(x), reverse=True) + seq1_blocks: list[list[SequenceItem]] = [] + for readout in linear_readouts: + seq1_blocks.append([SequenceItem.from_molnode(n) for n in readout]) + # Flatten seq1 for alignment + seq1: list[SequenceItem] = [] + for block in seq1_blocks: + seq1.extend(block) - # Parse linear_readout into sequence of SequenceItems - seq1: list[SequenceItem] = [SequenceItem.from_molnode(n) for n in linear_readout] + + # NOTE: are we going to retrieve clusters/compounds based on a every block or combined? 
# ANN query against compounds and/or clusters keep_top = 1000 @@ -223,6 +229,7 @@ def query_item(): .where( CandidateCluster.retromol_fp_counted_by_region.is_not(None), # CandidateCluster.file_name.ilike("BGC%"), + # CandidateCluster.file_name.ilike("BGC0000336"), ) .order_by(dist.asc()) .limit(keep_top if (query_against_clusters and not query_against_compounds) else keep_top//2) @@ -252,7 +259,7 @@ def query_item(): # Assembly seq2 from rec seq2: list[list[SequenceItem]] = [] - by_orf = False + by_orf = True if not by_orf: subs = [("seq", rec.biosynthetic_order(by_orf=by_orf))] else: subs = rec.biosynthetic_order(by_orf=by_orf) for _, mods in subs: @@ -263,7 +270,7 @@ def query_item(): else: raise ValueError(f"unknown module type: {type(mod)}") seq2.append(seq2_sub) - if len(seq2): + if any(len(seq2_sub) for seq2_sub in seq2): # Dynamically create scoring matrix items = ["-"] items.extend(seq1) @@ -275,11 +282,11 @@ def query_item(): sm, "global", target_internal_open_gap_score=-5.0, - target_left_open_gap_score=-5.0, - target_right_open_gap_score=-5.0, + target_left_open_gap_score=-2.5, + target_right_open_gap_score=-2.5, query_internal_open_gap_score=-5.0, - query_left_open_gap_score=-5.0, - query_right_open_gap_score=-5.0, + query_left_open_gap_score=-2.5, + query_right_open_gap_score=-2.5, label_fn=label_fn, ) aln = dock_against_target( @@ -301,7 +308,6 @@ def query_item(): if len(best_clusters) > keep_top: best_clusters.pop() - # Rerank compound rows through docking alignment best_compounds = [] for compound, cosine_dist in compound_rows: @@ -310,14 +316,16 @@ def query_item(): # Assembly seq2 from rec # Retrieve primary sequence from payload compound_readouts = rec.linear_readout.paths - compound_readout = max(compound_readouts, key=lambda x: len(x)) - seq2: list[SequenceItem] = [[SequenceItem.from_molnode(n) for n in compound_readout]] + compound_readouts.sort(key=lambda x: len(x), reverse=True) + seq2_blocks: list[list[SequenceItem]] = [] + for readout in 
compound_readouts: + seq2_blocks.append([SequenceItem.from_molnode(n) for n in readout]) - if len(seq2): + if any(len(seq2_sub) for seq2_sub in seq2_blocks): # Dynamically create scoring matrix items = ["-"] items.extend(seq1) - for seq2_sub in seq2: + for seq2_sub in seq2_blocks: items.extend(seq2_sub) unique_items = list(set(items)) sm, _ = create_substitution_matrix_dynamically(unique_items, compare=item_compare, label_fn=label_fn) @@ -325,17 +333,17 @@ def query_item(): sm, "global", target_internal_open_gap_score=-5.0, - target_left_open_gap_score=-5.0, - target_right_open_gap_score=-5.0, + target_left_open_gap_score=-2.5, + target_right_open_gap_score=-2.5, query_internal_open_gap_score=-5.0, - query_left_open_gap_score=-5.0, - query_right_open_gap_score=-5.0, + query_left_open_gap_score=-2.5, + query_right_open_gap_score=-2.5, label_fn=label_fn, ) aln = dock_against_target( aligner=aligner, target=seq1, - candidates=seq2, + candidates=seq2_blocks, gap_repr="-", allow_block_reverse=True, strategy="nonoverlap", @@ -346,7 +354,7 @@ def query_item(): # TODO if len(best_compounds) < keep_top or alignment_score > best_compounds[-1][0]: - best_compounds.append((alignment_score, 1.0 - cosine_dist, compound, aln, seq2)) + best_compounds.append((alignment_score, 1.0 - cosine_dist, compound, aln, seq2_blocks)) best_compounds.sort(key=lambda x: x[0], reverse=True) if len(best_compounds) > keep_top: best_compounds.pop() @@ -376,17 +384,35 @@ def query_item(): "sequence": [], "references": [], } - for x in linear_readout: + # for x in linear_readout: + # subseq.append({ + # "id": str(uuid.uuid4()), + # "isGap": False, + # "name": x.identity.matched_rule.name if x.is_identified else None, + # "smiles": x.identity.matched_rule.smiles if x.is_identified else None, + # }) + for i, block in enumerate(seq1_blocks): + subseq = [] + for x in block: + subseq.append({ + "id": str(uuid.uuid4()), + "isGap": False, + "name": x.name, + "smiles": None, + }) msa_item["sequence"].append({ 
"id": str(uuid.uuid4()), - "isGap": False, - "name": x.identity.matched_rule.name if x.is_identified else None, - "smiles": x.identity.matched_rule.smiles if x.is_identified else None, + "name": f"primary sequence {i + 1}", + "sequence": subseq, }) msa.append(msa_item) + # NOTE: if nothing from seq2 block alignst to a block original alignemtn(target) the sequence should remain emptty, not full of gaps + for i, (alignment_score, cosine_score, cluster, aln, blocks, source) in enumerate(combined[:20], 1): + msa_item = None + if source == "compound": # get compound references @@ -403,15 +429,7 @@ def query_item(): "name": name, "alignment_score": round(alignment_score, 3), "cosine_score": round(cosine_score, 3), - "sequence": [ - { - "id": str(uuid.uuid4()), - "isGap": True, - "name": None, - "smiles": None, - } - for _ in range(len(linear_readout)) - ], + "sequence": [], "references": [{ "name": ref.name, "database_name": ref.database_name, @@ -419,6 +437,21 @@ def query_item(): } for ref in refs], } + for i, block in enumerate(seq1_blocks): + msa_item["sequence"].append({ + "id": str(uuid.uuid4()), + "name": f"primary sequence {i + 1}", + "sequence": [ + { + "id": str(uuid.uuid4()), + "isGap": True, + "name": None, + "smiles": None, + } + for _ in range(len(block)) + ], + }) + if source == "cluster": @@ -431,64 +464,100 @@ def query_item(): "name": cluster.file_name, "alignment_score": round(alignment_score, 3), "cosine_score": round(cosine_score, 3), - "sequence": [ - { - "id": str(uuid.uuid4()), - "isGap": True, - "name": None, - "smiles": None, - } - for _ in range(len(linear_readout)) - ], + "sequence": [], "references": [{ "name": ref.name, "database_name": ref.database_name, "database_identifier": ref.database_identifier, } for ref in refs], } - - try: - # sort placements by start position - placements = sorted(aln.placements, key=lambda p: p.start) - for placement in placements: - - is_reversed = placement.reversed - - # get real identities instead of hashes 
of aligned blocks - block_idx = placement.block_idx - block = blocks[block_idx] - if is_reversed: - block = list(reversed(block)) - # gap_inds = [i for i, x in enumerate(placement.block_aln) if x == "-"] - # print(len(placement.block_aln), len(block), gap_inds) - - placement_items = {} - placement_count = 0 - for x in placement.block_aln: - if x == "-": - continue - placement_items[placement_count] = block[placement_count] - placement_count += 1 - - start = placement.start - end = placement.end - it = 0 - for idx in range(start, end + 1): - name = placement_items[it].name - if name.startswith("PKS_"): - name = name.strip("PKS_") - msa_item["sequence"][idx] = { - "id": str(uuid.uuid4()), - "isGap": False, - # "name": None, # could be filled in if needed - "name": name, - "smiles": None, # could be filled in if needed - } - it += 1 - except Exception as e: - pass - msa.append(msa_item) + for i, block in enumerate(seq1_blocks): + msa_item["sequence"].append({ + "id": str(uuid.uuid4()), + "name": f"gene {i + 1}", + "sequence": [ + { + "id": str(uuid.uuid4()), + "isGap": True, + "name": None, + "smiles": None, + } + for _ in range(len(block)) + ], + }) + + if msa_item: + try: + # sort placements by start position + print(aln) + placements = sorted(aln.placements, key=lambda p: p.start) + for placement in placements: + + is_reversed = placement.reversed + + # get real identities instead of hashes of aligned blocks + block_idx = placement.block_idx # THIS IS THE INDEX OF THE BLOCK IN THE CANDIDATE SEQUENCE + block = blocks[block_idx] + if is_reversed: + block = list(reversed(block)) + # gap_inds = [i for i, x in enumerate(placement.block_aln) if x == "-"] + # print(len(placement.block_aln), len(block), gap_inds) + + print("block_idx", block_idx) + + placement_items = {} + placement_count = 0 + for x in placement.block_aln: + if x == "-": + continue + placement_items[placement_count] = block[placement_count] + placement_count += 1 + print("placement count", 
placement_count) + + # NOTE: GAPS COULD BE INTRODUCED IN BOTH QUERY AND TARGET!!!!! DURING ALIGNMENT + # NOTE: WHY ARENT THE SUGARS MATCHING/ALIGNING WITH ERYTHROMYCIN + # NOTE: APPEND UNMATCHED PARTS TO THE END OF THE ALIGNMENT AS EXTRA BLOCKS, NEED TO CHECK PADDING AFTERWARDS + + start = placement.start + end = placement.end + print("start-end", start, end) + it = 0 + for idx in range(start, end + 1): + name = placement_items[it].name + if name.startswith("PKS_"): + name = name.strip("PKS_") + offset = start + msa_item["sequence"][0]["sequence"][idx] = { + "id": str(uuid.uuid4()), + "isGap": False, + # "name": None, # could be filled in if needed + "name": name, + "smiles": None, # could be filled in if needed + } + it += 1 + + for block_idx in aln.unused_blocks: + unused_block = blocks[block_idx] + msa_item["sequence"].append({ + "id": str(uuid.uuid4()), + "name": f"additional sequence {block_idx + 1}", + "sequence": [], + }) + for x in unused_block: + msa_item["sequence"][-1]["sequence"].append({ + "id": str(uuid.uuid4()), + "isGap": False, + "name": x.name, + "smiles": None, + }) + + except Exception as e: + pass + + msa.append(msa_item) # For now just return error + # return jsonify({"msa": msa}), 200 return jsonify({"msa": msa}), 200 + diff --git a/src/server/routes/query/align.py b/src/server/routes/query/align.py new file mode 100644 index 0000000..9d1c37c --- /dev/null +++ b/src/server/routes/query/align.py @@ -0,0 +1,133 @@ +"""Module for aligning sequence items and creating MSA.""" + +from dataclasses import dataclass +from typing import Any + +from retromol.chem.fingerprint import calculate_tanimoto_similarity + +from versalign.aligner import Aligner, setup_aligner +from versalign.scoring import create_substitution_matrix_dynamically +from versalign.docking import DockingResult, dock_against_target + +from routes.query.seq import ( + DISPLAY_NAME_UNIDENTIFIED, + SequenceItem, + Gap, + NonGap, + SequenceItemReadout, +) + + +@dataclass(frozen=True) 
+class MSAResult: + """ + Data structure representing the MSA result of a query. + """ + def to_dict(self) -> dict[str, Any]: + """ + Convert the MSAResult to a dictionary. + + :return: a dictionary representation of the MSAResult + """ + return { + "msa": [], + } + + +def item_compare_fn(a: SequenceItem, b: SequenceItem) -> float: + """ + Compare two SequenceItems or Gaps for sorting. + """ + # Deal with gaps first + if isinstance(a, Gap) or isinstance(b, Gap): + return 0.0 + + # Both items are non-gaps at this point + if isinstance(a, NonGap) and isinstance(b, NonGap): + if a.morgan_fp is not None and b.morgan_fp is not None: + return calculate_tanimoto_similarity(a.morgan_fp, b.morgan_fp) + + if a.display_name == DISPLAY_NAME_UNIDENTIFIED or b.display_name == DISPLAY_NAME_UNIDENTIFIED: + return 0.0 # could be correct, but we don't know + + if a.display_name == b.display_name: # NOTE: this is a display name, not unique + return 1.0 + + return -2.0 + + +def item_label_fn(item: SequenceItem) -> str: + """ + Generate a label for a SequenceItem or Gap. + + :param item: SequenceItem or Gap + :return: label string + """ + return str(hash(item)) + + +def _setup_aligner( + readout1: SequenceItemReadout, + readout2: SequenceItemReadout, +) -> Aligner: + """ + Setup an Aligner for two SequenceItemReadouts. 
+ + :param readout1: first SequenceItemReadout + :param readout2: second SequenceItemReadout + :return: configured Aligner + """ + readout1_items = readout1.flatten_items() + readout2_items = readout2.flatten_items() + unique_items = list(set(readout1_items + readout2_items + [Gap()])) + sm, _ = create_substitution_matrix_dynamically( + unique_items, + compare=item_compare_fn, + label_fn=item_label_fn + ) + + aligner = setup_aligner( + sm, + "global", + target_internal_open_gap_score=-5.0, + target_left_open_gap_score=-2.5, + target_right_open_gap_score=-2.5, + query_internal_open_gap_score=-5.0, + query_left_open_gap_score=-2.5, + query_right_open_gap_score=-2.5, + label_fn=item_label_fn, + ) + + return aligner + + +def score_by_alignment( + query: SequenceItemReadout, + items: list[SequenceItemReadout] +) -> tuple[list[DockingResult], list[float]]: + """ + Rerank nearest neighbors based on more accurate scoring. + + :param query: SequenceItemReadout of the query item + :param items: list of SequenceItemReadouts to be scored against the query + :return: tuple of list of DockingResults and their corresponding scores + """ + aln_results = [] + aln_scores: list[float] = [] + + for item in items: + aligner = _setup_aligner(query, item) + + aln: DockingResult = dock_against_target( + aligner=aligner, + target=query.flatten_items(), + candidates=item.blocks, + gap_repr=Gap.alignment_representation(), + allow_block_reverse=True, + strategy="nonoverlap", + ) + + aln_results.append(aln) + aln_scores.append(aln.total_score) + + return aln_results, aln_scores diff --git a/src/server/routes/query/featurize.py b/src/server/routes/query/featurize.py new file mode 100644 index 0000000..b24eee3 --- /dev/null +++ b/src/server/routes/query/featurize.py @@ -0,0 +1,145 @@ +"""Featurization utilities for query items.""" + +from typing import Any, Literal + +from retromol.model.result import Result +from retromol.model.reaction_graph import MolNode +from retromol.model.rules import 
RuleSet +from retromol.fingerprint.fingerprint import FingerprintGenerator + +from biocracker.query.modules import LinearReadout + +from routes.query.seq import NonGap, SequenceItemReadout + + +RULESET = RuleSet.load_default() +GENERATOR = FingerprintGenerator(RULESET.matching_rules) + +FP_COUNTED = True +FP_SIZE = 1024 + + +def calculate_payload_fingerprint( + payload_type: Literal["cluster", "compound"], + payload: Result | LinearReadout, +) -> list[float]: + """ + Calculate the fingerprint for the given payload based on its type. + + :param payload_type: the type of the payload ("cluster" or "compound") + :param payload: the payload object (Result or LinearReadout) + :return: the calculated fingerprint as a sequence of floats + :raises ValueError: if the payload_type is unsupported + :raises AssertionError: if the payload type does not match the expected class + """ + match payload_type: + case "cluster": + assert isinstance(payload, LinearReadout), f"expected LinearReadout payload, got {type(payload)}" + fp = GENERATOR.fingerprint_from_biocracker_readout(payload, by_orf=False, num_bits=FP_SIZE, counted=FP_COUNTED) + case "compound": + assert isinstance(payload, Result), f"expected Result payload, got {type(payload)}" + fp = GENERATOR.fingerprint_from_result(payload, num_bits=FP_SIZE, counted=FP_COUNTED) + case _: + raise ValueError(f"unsupported payload_type: {payload_type}") + + return fp + + +def _format_readout_compound(payload: Result) -> SequenceItemReadout: + """ + Format the readout for a compound payload. 
+ + :param payload: the compound payload object (Result) + :return: the formatted readout as a sequence of sequences of SequenceItem + """ + linear_readouts: list[list[MolNode]] = payload.linear_readout.paths + + formatted_blocks = [] + for path in linear_readouts: + formatted_block = [NonGap.from_retromol_molnode(n) for n in path] + formatted_blocks.append(formatted_block) + + return SequenceItemReadout(blocks=formatted_blocks) + + +def _format_readout_cluster(payload: LinearReadout) -> SequenceItemReadout: + """ + Format the readout for a cluster payload. + + :param payload: the cluster payload object (LinearReadout) + :return: the formatted readout as a sequence of sequences of SequenceItem + """ + formatted_blocks = [] + for orf_name, orf in payload.biosynthetic_order(by_orf=True): + formatted_block = [NonGap.from_biocracker_module(m) for m in orf] + formatted_blocks.append(formatted_block) + + return SequenceItemReadout(blocks=formatted_blocks) + + +def format_payload_readout( + payload_type: Literal["cluster", "compound"], + payload: Result | LinearReadout, +) -> SequenceItemReadout: + """ + Format the readout for the given payload based on its type. 
+ + :param payload_type: the type of the payload ("cluster" or "compound") + :param payload: the payload object (Result or LinearReadout) + :return: the formatted readout as a sequence of sequences of SequenceItem + :raises ValueError: if the payload_type is unsupported + :raises AssertionError: if the payload type does not match the expected class + """ + match payload_type: + case "cluster": + assert isinstance(payload, LinearReadout), f"expected LinearReadout payload, got {type(payload)}" + query_seq = _format_readout_cluster(payload) + case "compound": + assert isinstance(payload, Result), f"expected Result payload, got {type(payload)}" + query_seq = _format_readout_compound(payload) + case _: + raise ValueError(f"unsupported payload_type: {payload_type}") + + return query_seq + + +def load_payload( + payload_type: Literal["cluster", "compound"], + payload_blob: dict[str, Any], +) -> Result | LinearReadout: + """ + Load the payload object from its blob representation based on its type. + + :param payload_type: the type of the payload ("cluster" or "compound") + :param payload_blob: the payload data as a dictionary + :return: the loaded payload object (Result or LinearReadout) + :raises ValueError: if the payload_type is unsupported + """ + match payload_type: + case "cluster": + payload = LinearReadout.from_dict(payload_blob) + case "compound": + payload = Result.from_dict(payload_blob) + case _: + raise ValueError(f"unsupported payload_type: {payload_type}") + + return payload + + +def featurize_item( + payload_type: Literal["cluster", "compound"], + payload_blob: dict[str, Any], +) -> tuple[list[float], SequenceItemReadout]: + """ + Featurize the given payload based on its type. 
+ + :param payload_type: the type of the payload ("cluster" or "compound") + :param payload_blob: the payload data as a dictionary + :return: a tuple containing the feature vector and query blocks + :raises ValueError: if the payload_type is unsupported + """ + payload = load_payload(payload_type, payload_blob) + query_vec = calculate_payload_fingerprint(payload_type, payload) + query_seq = format_payload_readout(payload_type, payload) + + return query_vec, query_seq diff --git a/src/server/routes/query/pipeline.py b/src/server/routes/query/pipeline.py new file mode 100644 index 0000000..45058d5 --- /dev/null +++ b/src/server/routes/query/pipeline.py @@ -0,0 +1,442 @@ +"""Pipeline for cross-modal retrieval.""" + +import uuid +from dataclasses import dataclass +from typing import Any + +from flask import current_app + +from routes.query.align import MSAResult, score_by_alignment, item_label_fn +from routes.query.featurize import featurize_item +from routes.query.retrieve import ann_search +from routes.query.seq import DISPLAY_NAME_UNIDENTIFIED, Gap, SequenceItemReadout +from routes.query.featurize import load_payload, format_payload_readout + +from versalign.docking import DockingResult, DockPlacement + +from bionexus.db.models import CandidateCluster, Compound + + +@dataclass(frozen=True) +class _InsKey: + """ + Key to uniquely identify an insertion column in docking results. 
+ + :var result_idx: index of the docking result + :var placement_idx: index of the placement within the docking result + :var col_in_region: column index within the insertion region + :var anchor: target position anchor for the insertion + """ + + result_idx: int + placement_idx: int + col_in_region: int + anchor: int # insertion occurs AFTER this target position; -1 means before target[0] + + +def _slice_alignment_to_target_region( + center_aln: list[str], + block_aln: list[str], + start: int, + end: int, + gap_repr: str, +) -> tuple[list[str], list[str]]: + """ + Slice (center_aln, block_aln) down to the columns that map to target coordinates + [start, end] inclusive, while keeping insertion columns (center token == gap_repr) + that occur while inside the region. + + :param center_aln: list of hashes representing the center alignment SequenceItems + :param block_aln: list of hashes representing the block alignment SequenceItems + :param start: start target position (inclusive) + :param end: end target position (inclusive) + :param gap_repr: hash string representing of a gap + """ + if len(center_aln) != len(block_aln): + raise ValueError("alignment lenght mismatch") + + out_c: list[str] = [] + out_b: list[str] = [] + + target_pos = -1 + in_region = False + + for c_tok, b_tok in zip(center_aln, block_aln): + if c_tok != gap_repr: + target_pos += 1 + in_region = (start <= target_pos <= end) + + if start <= target_pos <= end: + out_c.append(c_tok) + out_b.append(b_tok) + + if target_pos > end: + break + + else: + if in_region: + out_c.append(c_tok) + out_b.append(b_tok) + + return out_c, out_b + + +def merge_dockings_into_global_alignment( + target: list[str], + dockings: list[DockingResult], + gap_repr: str, +) -> tuple[list[list[str]], list[list[int | None]]]: + """ + Merge multiple docking results into a global MSAResult. 
+ + :param target: the tokenized target SequenceItemReadout + :param dockings: list of DockingResult to merge + :param gap_repr: string representation for gaps + :return: tuple of (aligned rows, block maps) + :raises ValueError: if target blocks are empty + """ + if not target: + raise ValueError("target blocks are empty") + + n = len(target) + + # Collect ALL insertion columns across all dockings, anchored to a target boundary + # anchor = j means "insertion column occurs after target position j" + # anchor = -1 means "before target[0]" + insertions_by_anchor: dict[int, list[_InsKey]] = {a: [] for a in range(-1, n)} + # We also need to later map each insertion column identity -> global column index + inskey_to_global_col: dict[_InsKey, int] = {} + + # Also precompute mapping of target positions -> global columns (once built) + targetpos_to_global_col: dict[int, int] = {} + + # To place insertions deterministically, we'll sort them by: + # (result_idx, placement start, placement_idx, col_in_region) + # We need placement start; capture it in a side map + placement_start: dict[tuple[int, int], int] = {} + + for ri, dr in enumerate(dockings): + placements = sorted(dr.placements, key=lambda p: (p.start, p.end)) + for pi, p in enumerate(placements): + placement_start[(ri, pi)] = p.start + + reg_center, reg_block = _slice_alignment_to_target_region( + center_aln=p.center_aln, + block_aln=p.block_aln, + start=p.start, + end=p.end, + gap_repr=gap_repr, + ) + + # Walk region columns and anchor insertions. + # We anchor insertion columns to "after the last consumed target position" + # Initialize target_pos to p.start - 1 so that the first consumed target sets it to p.start + tpos = p.start - 1 + for ci, c_tok in enumerate(reg_center): + if c_tok != gap_repr: + tpos += 1 + else: + # Insertion after tpos (which is in [p.start-1, .. 
p.end-1]) + # If tpos == p.start-1, that's an insertion before the first consumed symbol in the region + anchor = tpos + if anchor < -1: anchor = -1 + if anchor > n - 1: anchor = n - 1 + insertions_by_anchor[anchor].append(_InsKey(ri, pi, ci, anchor=anchor)) + + # Sort insertions at each anchor deterministically + for anchor, keys in insertions_by_anchor.items(): + keys.sort( + key=lambda k: ( + k.result_idx, + placement_start.get((k.result_idx, k.placement_idx), 10**9), + k.placement_idx, + k.col_in_region, + ) + ) + + # Build the global aligned center, assigning global column indices + aligned_center: list[str] = [] + + # Insertions before target[0] (anchor -1) + for k in insertions_by_anchor[-1]: + inskey_to_global_col[k] = len(aligned_center) + aligned_center.append(gap_repr) + + # For each target pos j: emit target[j], then insertion anchored at j + for j in range(n): + targetpos_to_global_col[j] = len(aligned_center) + aligned_center.append(target[j]) + + for k in insertions_by_anchor[j]: + inskey_to_global_col[k] = len(aligned_center) + aligned_center.append(gap_repr) + + aligned_target = aligned_center[:] # same content; separate name for readability + L = len(aligned_center) + + # Helper to write a placement into a row, with "max score wins" on collisions + def _project_one_placement_into_row( + row: list[str], + score_row: list[float], + block_map: list[int | None], + ri: int, + pi: int, + p: DockPlacement, + ) -> None: + """ + Project one placement into the given row, updating in-place. 
+ + :param row: list of strings representing the row to update + :param score_row: list of floats representing the scores for each column in the row + :param block_map: mapping from block indices to global column indices + :param ri: index of the docking result + :param pi: index of the placement within the docking result + :param p: DockPlacement to project + """ + reg_center, reg_block = _slice_alignment_to_target_region( + center_aln=p.center_aln, + block_aln=p.block_aln, + start=p.start, + end=p.end, + gap_repr=gap_repr, + ) + + tpos = p.start - 1 + for ci, (c_tok, b_tok) in enumerate(zip(reg_center, reg_block)): + if c_tok != gap_repr: + tpos += 1 + gcol = targetpos_to_global_col[tpos] + else: + # tpos should naturally be in [p.start-1, .. p.end-1], so we don't need clamping here + if not (-1 <= tpos <= n - 1): + raise ValueError("unexpected target position for insertion column") + key = _InsKey(ri, pi, ci, anchor=tpos) + # Because we sorted+assigned by identity, this must exist + gcol = inskey_to_global_col.get(key, None) + if gcol is None: + # Extremely defensive fallback: skip if we somehow didn't register it + continue + + if b_tok == gap_repr: + continue + + # Resolve collisions within the same row + if score_row[gcol] < float(p.score): + row[gcol] = b_tok + score_row[gcol] = float(p.score) + # Mark block owernship + block_map[gcol] = p.block_idx + + # Build rows\ + row_ids = list(range(len(dockings))) + rows: list[list[str]] = [] + block_maps: list[list[int | None]] = [] + + for ri, dr in enumerate(dockings): + placements = sorted(dr.placements, key=lambda p: (p.start, p.end)) + row = [gap_repr] * L + score_row = [float("-inf")] * L + block_map = [None] * L + for pi, p in enumerate(placements): + _project_one_placement_into_row( + row=row, + score_row=score_row, + block_map=block_map, + ri=ri, + pi=pi, + p=p, + ) + + rows.append(row) + block_maps.append(block_map) + + # First row should be the target + rows = [aligned_target] + rows + block_maps = 
def cross_modal_retrieval(
    payload_type: str,
    payload_blob: dict[str, Any],
    query_against_clusters: bool,
    query_against_compounds: bool,
    top_k: int = 20,
) -> MSAResult:
    """
    Perform cross-modal retrieval given an item payload.

    :param payload_type: type of the payload ("cluster" or "compound")
    :param payload_blob: the actual payload data
    :param query_against_clusters: whether to query against clusters
    :param query_against_compounds: whether to query against compounds
    :param top_k: number of top results to return
    :return: MSAResult containing the retrieval results
    :raises ValueError: if no nearest neighbors found or alignment fails
    """
    # Featurize query
    query_vec, query_blocks = featurize_item(payload_type, payload_blob)

    # ANN with query_vec; returns nearest neighbors with cosine DISTANCE
    nns: list[tuple[CandidateCluster | Compound, float]] = ann_search(
        query_vec,
        query_against_clusters=query_against_clusters,
        query_against_compounds=query_against_compounds,
    )
    current_app.logger.debug(f"found {len(nns)} nearest neighbors")

    # Featurize nearest neighbors as SequenceItemReadout with cosine SCORE (1 - distance)
    nns_featurized: list[SequenceItemReadout] = []
    nns_cosine_scores: list[float] = []
    for item, distance in nns:
        assert isinstance(item, (CandidateCluster, Compound)), "expected item to be CandidateCluster or Compound"
        item_type = "cluster" if isinstance(item, CandidateCluster) else "compound"
        item_blob = item.biocracker if item_type == "cluster" else item.retromol
        item_payload = load_payload(item_type, item_blob)
        item_readout = format_payload_readout(item_type, item_payload)
        nns_featurized.append(item_readout)
        nns_cosine_scores.append(1.0 - distance)

    if not nns_featurized or not query_blocks:
        raise ValueError("no nearest neighbors found or query blocks are empty")

    # Rerank nearest neighbors by alignment
    aln_results, aln_scores = score_by_alignment(query_blocks, nns_featurized)

    if not aln_results or not aln_scores:
        raise ValueError("alignment scoring failed; no results or scores obtained")

    # Get top K; first sorted on aln_scores, then on nns_cosine_scores
    top_k_indices = sorted(
        range(len(aln_scores)),
        key=lambda i: (aln_scores[i], nns_cosine_scores[i]),
        reverse=True,
    )[:top_k]
    current_app.logger.debug(f"top k indices: {top_k_indices}")

    top_k_nns_featurized = [nns_featurized[i] for i in top_k_indices]
    top_k_aln_results = [aln_results[i] for i in top_k_indices]
    top_k_aln_scores = [aln_scores[i] for i in top_k_indices]
    top_k_cosine_scores = [nns_cosine_scores[i] for i in top_k_indices]

    current_app.logger.debug(f"found {len(top_k_nns_featurized)} top-k nearest neighbors after reranking")
    current_app.logger.debug(f"top scores for first nearest neighbor: aln {top_k_aln_scores[0]}, cosine {top_k_cosine_scores[0]}")

    # Merge top K dockings into global alignment.
    # rows[0] is the query/target row; rows[i] (i >= 1) corresponds to
    # top_k_aln_results[i - 1] (and block_maps[i] to the same docking).
    rows, block_maps = merge_dockings_into_global_alignment(
        target=[item_label_fn(item) for item in query_blocks.flatten_items()],
        dockings=top_k_aln_results,
        gap_repr=Gap.alignment_representation(),
    )

    msa_result = {"msa": []}

    # Format rows[0] as target
    mapping = {item_label_fn(item): item for item in query_blocks.flatten_items()}
    msa_item = {
        "id": str(uuid.uuid4()),
        "name": "Query",
        "alignment_score": None,
        "cosine_score": None,
        "sequence": [
            {
                "id": str(uuid.uuid4()),
                "name": "query primary sequence",
                "sequence": [
                    {
                        "id": str(uuid.uuid4()),
                        "isGap": (tok == Gap.alignment_representation()),
                        "name": mapping[tok].display_name if tok in mapping else DISPLAY_NAME_UNIDENTIFIED,
                        "smiles": None,
                    }
                    for tok in rows[0]
                ],
            }
        ],
        "references": [],
    }
    msa_result["msa"].append(msa_item)

    max_len = len(rows[0])

    # Translate the tokenized SequenceItems back to their original SequenceItems.
    # BUGFIX: the previous loop ran `range(1, len(rows[1:]))`, which silently
    # dropped the last retrieved row, and it indexed the top-k lists with the
    # row index instead of the docking index (off by one), so every result
    # carried the scores/readout of its neighbor.
    for i in range(1, len(rows)):  # skip target row at index 0
        row = rows[i]
        block_map = block_maps[i]
        j = i - 1  # docking index into the top-k lists

        # Map alignment tokens back to the retrieved readout's items
        readout = top_k_nns_featurized[j]
        mapping = {item_label_fn(item): item for item in readout.flatten_items()}

        msa_item = {
            "id": str(uuid.uuid4()),
            "name": f"Result {i}",
            "alignment_score": top_k_aln_scores[j],
            "cosine_score": top_k_cosine_scores[j],
            "sequence": [],
            "references": [],
        }

        # Create subseq per block idx, in order of appearance in block_map
        block_idx_to_subseq: dict[int, list[dict[str, Any]]] = {}
        block_order: list[int] = []

        for col_idx, tok in enumerate(row):
            block_idx = block_map[col_idx]
            if block_idx is None:
                continue  # skip columns not owned by a block

            if block_idx not in block_idx_to_subseq:
                block_idx_to_subseq[block_idx] = []
                block_order.append(block_idx)

            if tok == Gap.alignment_representation():
                seq_item = {
                    "id": str(uuid.uuid4()),
                    "isGap": True,
                    "name": Gap().display_name,
                    "smiles": None,
                }
            else:
                obj = mapping.get(tok)
                seq_item = {
                    "id": str(uuid.uuid4()),
                    "isGap": False,
                    "name": obj.display_name if obj is not None else DISPLAY_NAME_UNIDENTIFIED,
                    "smiles": None,
                }

            block_idx_to_subseq[block_idx].append(seq_item)

        msa_item["sequence"] = [
            {
                "id": str(uuid.uuid4()),
                "name": f"retrieved primary sequence block {block_idx}",
                "sequence": block_idx_to_subseq[block_idx],
            }
            for block_idx in block_order
        ]

        # Pad with trailing gaps so every row spans the full alignment length
        cum_len = sum(len(subseq["sequence"]) for subseq in msa_item["sequence"])
        if cum_len < max_len:
            len_diff = max_len - cum_len
            msa_item["sequence"].append(
                {
                    "id": str(uuid.uuid4()),
                    "name": "padding gap",
                    "sequence": [
                        {
                            "id": str(uuid.uuid4()),
                            "isGap": True,
                            "name": Gap().display_name,
                            "smiles": None,
                        }
                        for _ in range(len_diff)
                    ],
                }
            )

        msa_result["msa"].append(msa_item)

    return msa_result
+ + :param session: the SQLAlchemy session + :param model: the SQLAlchemy model to query + :param vector_col: the vector column to search against + :param query_vec: the query vector + :param where: additional filtering conditions + :param limit: the maximum number of results to return + :return: a list of tuples of (model instance, distance) + """ + dist = vector_col.cosine_distance(query_vec).label("dist") + stmt = ( + sa.select(model, dist) + .where(*where) + .order_by(dist.asc()) + .limit(limit) + ) + return session.execute(stmt).all() + + +def ann_search( + query_vec: list[float], + query_against_clusters: bool, + query_against_compounds: bool, +) -> list[tuple[CandidateCluster | Compound, float]]: + """ + Perform an approximate nearest neighbor search against clusters and/or compounds. + + :param query_vec: the query vector + :param query_against_clusters: whether to query against candidate clusters + :param query_against_compounds: whether to query against compounds + :return: a list of tuples of (model instance, distance) + """ + if not query_against_clusters and not query_against_compounds: + return [] + + only_one = query_against_clusters ^ query_against_compounds + per_type_limit = ANN_SEARCH_RADIUS if only_one else ANN_SEARCH_RADIUS // 2 + + with SessionLocal() as session: + _set_local(session, HNSW_SETTINGS) + + cluster_rows = ( + _ann_query( + session=session, + model=CandidateCluster, + vector_col=CandidateCluster.retromol_fp_counted_by_region, + query_vec=query_vec, + where=[CandidateCluster.retromol_fp_counted_by_region.is_not(None)], + limit=per_type_limit, + ) + if query_against_clusters + else [] + ) + + compound_rows = ( + _ann_query( + session=session, + model=Compound, + vector_col=Compound.retromol_fp_counted, + query_vec=query_vec, + where=[Compound.retromol_fp_counted.is_not(None)], + limit=per_type_limit, + ) + if query_against_compounds + else [] + ) + + combined = cluster_rows + compound_rows + combined.sort(key=lambda x: x[1]) # sort 
by distance + + # Should not have more than ANN_SEARCH_RADIUS items now, but just in case + return combined[:ANN_SEARCH_RADIUS] diff --git a/src/server/routes/query/seq.py b/src/server/routes/query/seq.py new file mode 100644 index 0000000..1f8a26f --- /dev/null +++ b/src/server/routes/query/seq.py @@ -0,0 +1,185 @@ +"""Module defining sequence item data structures for query results.""" + +from dataclasses import dataclass, field + +from rdkit.DataStructs.cDataStructs import ExplicitBitVect + +from retromol.model.reaction_graph import MolNode +from retromol.chem.mol import smiles_to_mol +from retromol.chem.fingerprint import mol_to_morgan_fingerprint + +from biocracker.query.modules import ( + Module, + PKSModule, + NRPSModule, + PKSSubstrate, + PKSExtenderUnit, + NRPSSubstrate, +) + +DEFAULT_GAP_REPR = "-" + +MORGAN_RADIUS = 2 +MORGAN_SIZE = 2048 + +DISPLAY_NAME_UNIDENTIFIED = "unknown" + + +@dataclass(frozen=True) +class SequenceItem: + """ + Base class for sequence items in query results. + """ + + ... + + +@dataclass(frozen=True) +class Gap(SequenceItem): + """ + Gap sequence item representing an unknown or missing module. + """ + + display_name: str = DEFAULT_GAP_REPR + + def __str__(self) -> str: + """ + String representation of the SequenceItem. + + :return: string representation + """ + return self.display_name + + def __hash__(self) -> int: + """ + Hash function for SequenceItem. + + :return: hash value + """ + return hash(self.display_name) + + @classmethod + def alignment_representation(cls) -> str: + """ + Representation used in alignments. + + :return: alignment representation string + """ + return str(hash(cls(DEFAULT_GAP_REPR))) + + +@dataclass(frozen=True) +class NonGap(SequenceItem): + """ + Non-gap sequence item representing a module or identified molecule. + + :var display_name: Name to display for the item. + :var morgan_fp: Morgan fingerprint of the molecule. 
+ """ + + display_name: str + morgan_fp: ExplicitBitVect | None = None + family_tokens: list[str] | None = field(default_factory=list) + ancestor_tokens: list[str] | None = field(default_factory=list) + + def __hash__(self) -> int: + """ + Hash based on display name and Morgan fingerprint. + + :return: hash value + """ + return hash(( + self.display_name, + self.morgan_fp.ToBitString() if self.morgan_fp else None, + "|".join(self.family_tokens) if self.family_tokens else None, + "|".join(self.ancestor_tokens) if self.ancestor_tokens else None, + )) + + @classmethod + def from_retromol_molnode(cls, n: MolNode) -> "SequenceItem": + """ + Create a SequenceItem from a RetroMol MolNode. + + :param n: RetroMol MolNode + :return: SequenceItem + """ + assert isinstance(n, MolNode), f"expected RetroMol MolNode, got {type(molnode)}" + + if n.is_identified: + matched_rule = n.identity.matched_rule + display_name = matched_rule.name + mol = smiles_to_mol(matched_rule.smiles) + morgan_fp = mol_to_morgan_fingerprint(mol, radius=MORGAN_RADIUS, num_bits=MORGAN_SIZE) + family_tokens = list(matched_rule.family_tokens) + ancestor_tokens = list(matched_rule.ancestor_tokens) + + return cls( + display_name=display_name, + morgan_fp=morgan_fp, + family_tokens=family_tokens, + ancestor_tokens=ancestor_tokens, + ) + else: + display_name = DISPLAY_NAME_UNIDENTIFIED + mol = smiles_to_mol(n.smiles) + morgan_fp = mol_to_morgan_fingerprint(mol, radius=MORGAN_RADIUS, num_bits=MORGAN_SIZE) + + return cls(display_name=display_name, morgan_fp=morgan_fp) + + @classmethod + def from_biocracker_module(cls, m: Module) -> "SequenceItem": + """ + Create a SequenceItem from a BioCracker module. 
+ + :param m: BioCracker Module + :return: SequenceItem + """ + assert isinstance(m, Module), f"expected BioCracker Module, got {type(module)}" + + match m: + case PKSModule(substrate=PKSSubstrate(extender_unit=PKSExtenderUnit.PKS_A)): + return cls(display_name="A", ancestor_tokens=["PKS", "A"]) + case PKSModule(substrate=PKSSubstrate(extender_unit=PKSExtenderUnit.PKS_B)): + return cls(display_name="B", ancestor_tokens=["PKS", "B"]) + case PKSModule(substrate=PKSSubstrate(extender_unit=PKSExtenderUnit.PKS_C)): + return cls(display_name="C", ancestor_tokens=["PKS", "C"]) + case PKSModule(substrate=PKSSubstrate(extender_unit=PKSExtenderUnit.PKS_D)): + return cls(display_name="D", ancestor_tokens=["PKS", "D"]) + case PKSModule(substrate=PKSSubstrate(extender_unit=PKSExtenderUnit.UNCLASSIFIED)): + return cls(display_name="A", ancestor_tokens=["PKS", "A"]) + case NRPSModule(substrate=NRPSSubstrate(smiles=None)): + return cls(display_name=DISPLAY_NAME_UNIDENTIFIED, ancestor_tokens=["NRPS"]) + case NRPSModule(substrate=NRPSSubstrate(name=name, smiles=smiles)): + display_name = name if name is not None else DISPLAY_NAME_UNIDENTIFIED + + # Graminine SMILES fix (fixed in >=2.0.1 versions of BioCracker) + if smiles == "O=NN(O)CCC[C@H](N)(C(=O)O": + smiles = "O=NN(O)CCC[C@H](N)C(=O)O" + + mol = smiles_to_mol(smiles) + morgan_fp = mol_to_morgan_fingerprint(mol, radius=MORGAN_RADIUS, num_bits=MORGAN_SIZE) + return cls(display_name=display_name, morgan_fp=morgan_fp, ancestor_tokens=["NRPS"]) + case _: + raise NotImplementedError(f"BioCracker module type {type(m)} not supported yet") + + +@dataclass(frozen=True) +class SequenceItemReadout: + """ + Readout of sequence items in query results. + + :var blocks: list of blocks, where each block is a list of SequenceItems + """ + + blocks: list[list[SequenceItem]] + + def flatten_items(self) -> list[SequenceItem]: + """ + Flatten the blocks into a single list of SequenceItems. 
+ + :return: flattened list of SequenceItems + """ + # Sort blocks on size; longer blocks first + blocks_sorted = sorted(self.blocks, key=lambda b: len(b), reverse=True) + + return [item for block in blocks_sorted for item in block] diff --git a/src/server/routes/query_service.py b/src/server/routes/query_service.py new file mode 100644 index 0000000..1698dde --- /dev/null +++ b/src/server/routes/query_service.py @@ -0,0 +1,54 @@ +"""Query service API endpoint.""" + +from flask import Blueprint, current_app, request, jsonify + +from routes.session_store import load_item +from routes.query.pipeline import cross_modal_retrieval + + +blp_query_item = Blueprint("query_item", __name__) + + +@blp_query_item.get("/api/queryItem") +def query_item(): + """ + Endpoint to query a specific item (cluster or compound). + """ + session_id = request.args.get("sessionId", "").strip() + item_id = request.args.get("itemId", "").strip() + if not session_id: + return jsonify({"error": "Missing sessionId"}), 400 + if not item_id: + return jsonify({"error": "Missing itemId"}), 400 + + query_against_clusters = request.args.get("queryAgainstClusters", "true").lower() == "true" + query_against_compounds = request.args.get("queryAgainstCompounds", "true").lower() == "true" + current_app.logger.debug(f"query_against_compounds: {query_against_compounds}") + current_app.logger.debug(f"query_against_clusters: {query_against_clusters}") + if not query_against_clusters and not query_against_compounds: + return jsonify({"error": "At least one of queryAgainstClusters or queryAgainstCompounds must be true"}), 400 + + # Retrieve item from session store + item = load_item(session_id, item_id) + if item is None: + return jsonify({"error": "Item not found"}), 404 + + payload_type = item.get("kind", None) + payload_blob = item.get("payload", None) + if not payload_type or payload_type not in ["cluster", "compound"]: + return jsonify({"error": "Invalid item kind"}), 400 + if not payload_blob: + return 
def strip_property_from_dict(d: dict[str, Any], prop: str) -> dict[str, Any]:
    """
    Return a shallow copy of the dictionary with the given top-level property removed.

    Note: this is NOT recursive (the previous docstring claimed it was) —
    nested dictionaries are left untouched. The input dictionary is not mutated.

    :param d: the dictionary to process
    :param prop: the property name to strip
    :return: a new dictionary without the property
    """
    out = dict(d)
    # pop with a default handles the missing-key case; no membership check needed.
    out.pop(prop, None)
    return out
@@ -437,7 +452,8 @@ def merge_session_from_client(new_session: dict[str, Any]) -> None: if old_item is None: # New item: accept as-is (client owns everything initially) - merged_items.append(new_item) + item = strip_property_from_dict(new_item, "payload") # remove payload if present + merged_items.append(item) new_item_ids.append(item_id) else: # Existing item: merge client fields into old item, presevering server-owned fields From 9f080d67476b035c775577aebe8ae9988e9b077a Mon Sep 17 00:00:00 2001 From: David Meijer Date: Thu, 15 Jan 2026 11:59:49 +0100 Subject: [PATCH 13/34] UPD: correct docking alignment --- .../tabs/discovery/QueryResultView.tsx | 6 +- .../workspace/tabs/upload/WorkspaceUpload.tsx | 45 +- src/server/routes/query/_query.py | 563 ------------------ src/server/routes/query/align.py | 327 +++++++++- src/server/routes/query/featurize.py | 10 +- src/server/routes/query/pipeline.py | 134 +---- src/server/routes/query/retrieve.py | 25 +- src/server/routes/query/seq.py | 2 + src/server/routes/query_service.py | 5 +- 9 files changed, 422 insertions(+), 695 deletions(-) delete mode 100644 src/server/routes/query/_query.py diff --git a/src/client/src/components/workspace/tabs/discovery/QueryResultView.tsx b/src/client/src/components/workspace/tabs/discovery/QueryResultView.tsx index e56694b..2f74ad4 100644 --- a/src/client/src/components/workspace/tabs/discovery/QueryResultView.tsx +++ b/src/client/src/components/workspace/tabs/discovery/QueryResultView.tsx @@ -204,9 +204,9 @@ export const defaultMotifColorMap = (): Record => { // plain (opaque) base newColorMap[key] = color; - // numbered variants 1->15 -> alpha = 1/15...15/15 - for (let i = 1; i <= 15; i++) { - const alpha = 1 - (i / 15); + // numbered variants 1->20 -> alpha = 1/20...20/20 + for (let i = 1; i <= 20; i++) { + const alpha = 1 - (i / 20); const alphaRounded = Math.round(alpha * 1000) / 1000; newColorMap[`${key}${i}`] = parseColor(color, alphaRounded); }; diff --git 
a/src/client/src/components/workspace/tabs/upload/WorkspaceUpload.tsx b/src/client/src/components/workspace/tabs/upload/WorkspaceUpload.tsx index d98bc51..5ba7d4e 100644 --- a/src/client/src/components/workspace/tabs/upload/WorkspaceUpload.tsx +++ b/src/client/src/components/workspace/tabs/upload/WorkspaceUpload.tsx @@ -6,14 +6,16 @@ import CardContent from "@mui/material/CardContent"; import Stack from "@mui/material/Stack"; import Typography from "@mui/material/Typography"; import MuiLink from "@mui/material/Link"; +import Tooltip from "@mui/material/Tooltip"; import NotificationsRoundedIcon from "@mui/icons-material/NotificationsRounded"; +import RefreshIcon from "@mui/icons-material/Refresh"; import { useTheme } from "@mui/material/styles"; import { useNotifications } from "../../NotificationProvider"; import { Link as RouterLink } from "react-router-dom"; import { DialogImportCompound } from "./DialogImportCompound"; import { WorkspaceItemCard } from "./WorkspaceItemCard"; import { Session } from "../../../../features/session/types"; -import { deleteSessionItem } from "../../../../features/session/api"; +import { deleteSessionItem, refreshSession } from "../../../../features/session/api"; import { NewCompoundJob } from "../../../../features/jobs/types"; import { MAX_ITEMS, importCompound, importCompoundsBatch } from "../../../../features/jobs/api"; @@ -75,6 +77,7 @@ export const WorkspaceUpload: React.FC = ({ session, setSe const [openCompounds, setOpenCompounds] = React.useState(false); const [selectedIds, setSelectedIds] = React.useState>(new Set()); const [deletingIds, setDeletingIds] = React.useState>(new Set()); + const [refreshSpinKey, setRefreshSpinKey] = React.useState(0); // to force re-mount of refresh icon // Clean up deletingIds when session items change React.useEffect(() => { @@ -108,6 +111,19 @@ export const WorkspaceUpload: React.FC = ({ session, setSe [setSessionSafe, pushNotification, session.sessionId] ); + // Helper for manual refresh + 
const handleManualRefresh = async () => { + setRefreshSpinKey((k) => k + 1); // trigger re-mount of icon to restart animation + try { + const fresh = await refreshSession(session.sessionId); + setSession(() => fresh); + pushNotification("Workspace session refreshed successfully!", "success"); + } catch (err) { + const msg = err instanceof Error ? err.message : String(err); + pushNotification(`Failed to refresh session: ${msg}`, "error"); + }; + }; + // Selection helpers const toggleSelectItem = (id: string) => { // Prevent toggling if deleting @@ -266,9 +282,30 @@ export const WorkspaceUpload: React.FC = ({ session, setSe justifyContent="space-between" sx={{ mb: 1.5 }} > - - Workspace items ({session.items.length}/{MAX_ITEMS}) - + + + Workspace items ({session.items.length}/{MAX_ITEMS}) + + + 0 ? "refresh-spin 0.6s linear" : "none", + cursor: "pointer", + color: (theme.vars || theme).palette.text.secondary, + "&:hover": { + color: (theme.vars || theme).palette.text.primary, + }, + }} + onClick={handleManualRefresh} + /> + + + {gbkFiles.length > 0 && ( + + {gbkFiles.length} file(s) selected + + )} + + + ); +}; diff --git a/src/client/src/components/workspace/tabs/upload/WorkspaceUpload.tsx b/src/client/src/components/workspace/tabs/upload/WorkspaceUpload.tsx index 669a1e7..c699139 100644 --- a/src/client/src/components/workspace/tabs/upload/WorkspaceUpload.tsx +++ b/src/client/src/components/workspace/tabs/upload/WorkspaceUpload.tsx @@ -13,11 +13,12 @@ import { useTheme } from "@mui/material/styles"; import { useNotifications } from "../../NotificationProvider"; import { Link as RouterLink } from "react-router-dom"; import { DialogImportCompound } from "./DialogImportCompound"; +import { DialogImportCluster } from "./DialogImportCluster"; import { WorkspaceItemCard } from "./WorkspaceItemCard"; import { Session } from "../../../../features/session/types"; import { deleteSessionItem, refreshSession } from "../../../../features/session/api"; import { NewCompoundJob } 
from "../../../../features/jobs/types"; -import { MAX_ITEMS, importCompound, importCompoundsBatch } from "../../../../features/jobs/api"; +import { MAX_ITEMS, importCompound, importCompoundsBatch, importClustersBatch } from "../../../../features/jobs/api"; const MAX_FILE_SIZE_MB = 2; const MAX_FILE_SIZE_BYTES = MAX_FILE_SIZE_MB * 1024 * 1024; @@ -75,6 +76,8 @@ export const WorkspaceUpload: React.FC = ({ session, setSe const { pushNotification } = useNotifications(); const [openCompounds, setOpenCompounds] = React.useState(false); + const [openClusters, setOpenClusters] = React.useState(false); + const [selectedIds, setSelectedIds] = React.useState>(new Set()); const [deletingIds, setDeletingIds] = React.useState>(new Set()); const [refreshSpinKey, setRefreshSpinKey] = React.useState(0); // to force re-mount of refresh icon @@ -189,8 +192,8 @@ export const WorkspaceUpload: React.FC = ({ session, setSe }; const handleOpenBGCs = (event: React.MouseEvent) => { - // event.currentTarget.blur(); // prevents 'Blocked aria-hidden on an element' warning - console.log("dialog for BGC import not implemented yet"); + event.currentTarget.blur(); // prevents 'Blocked aria-hidden on an element' warning + setOpenClusters(true); }; const handleViewItem = (itemId: string) => { @@ -218,6 +221,40 @@ export const WorkspaceUpload: React.FC = ({ session, setSe }; }; + // Import cluster handler + const handleImportClusters = async (files: File[]) => { + if (!files.length) return; + + const oversized = files.filter(f => f.size > MAX_FILE_SIZE_BYTES); + if (oversized.length > 0) { + pushNotification(`Some files exceed the maximum size of ${MAX_FILE_SIZE_MB} MB and were not imported: ${oversized.map(f => f.name).join(", ")}`, "error"); + + // Keep only files within size limit + files = files.filter(f => f.size <= MAX_FILE_SIZE_BYTES); + }; + + // Check if any files remain after filtering on file size + if (files.length === 0) { + pushNotification("No valid files to import after size 
filtering.", "warning"); + return; + }; + + let payloads: { name: string; fileContent: string }[] = []; + + try { + payloads = await Promise.all( + files.map(async (file) => ({ + name: file.name, + fileContent: await file.text(), + })) + ) + await importClustersBatch(deps, payloads); + } catch (err) { + const msg = err instanceof Error ? err.message : String(err); + pushNotification(`Failed to import BGC files: ${msg}`, "error"); + }; + }; + // Selection states const anySelected = selectedIds.size > 0; const allSelected = session.items.length > 0 && selectedIds.size === session.items.length; @@ -264,7 +301,7 @@ export const WorkspaceUpload: React.FC = ({ session, setSe Import compounds - @@ -278,6 +315,12 @@ export const WorkspaceUpload: React.FC = ({ session, setSe onImportBatch={handleImportBatchCompounds} /> + setOpenClusters(false)} + onImport={handleImportClusters} + /> + {session.items.length > 0 && ( diff --git a/src/client/src/features/jobs/api.ts b/src/client/src/features/jobs/api.ts index 23a5232..666a74f 100644 --- a/src/client/src/features/jobs/api.ts +++ b/src/client/src/features/jobs/api.ts @@ -1,6 +1,6 @@ import { postJson } from "../http"; import type { WorkspaceImportDeps, NewCompoundJob } from "./types"; -import type { Session, SessionItem, CompoundItem } from "../session/types"; +import type { Session, SessionItem, CompoundItem, ClusterItem } from "../session/types"; import { saveSession } from "../session/api"; import { z } from "zod"; @@ -29,6 +29,23 @@ export async function submitCompoundJob( ); }; +export async function submitClusterJob( + sessionId: string, + item: ClusterItem, +): Promise { + await postJson( + "/api/submitCluster", + { + sessionId, + itemId: item.id, + name: item.name, + fileContent: item.fileContent, + }, + SubmitJobRespSchema + ); +}; + +// Batch compound import export async function importCompoundsBatch( deps: WorkspaceImportDeps, compounds: NewCompoundJob[], @@ -107,7 +124,7 @@ export async function 
importCompoundsBatch( })); return []; - } + }; // Submit jobs sequentially for (const item of newItems) { @@ -145,3 +162,109 @@ export async function importCompound( const items = await importCompoundsBatch(deps, [payload]); return items[0] ?? null; }; + +// Batch cluster import +export async function importClustersBatch( + deps: WorkspaceImportDeps, + clusters: { name: string; fileContent: string }[], +): Promise { + const { pushNotification, setSession, sessionId } = deps; + + if (!clusters.length) { + pushNotification("No clusters to import", "warning"); + return []; + }; + + let nextSession: Session | null = null; + let newItems: SessionItem[] = []; + + // Update local session (queued items) + setSession((prev) => { + const existingCount = prev.items.length; + const remainingSlots = MAX_ITEMS - existingCount; + + if (remainingSlots <= 0) { + pushNotification(`Session already has maximum of ${MAX_ITEMS} items`, "warning"); + nextSession = prev; + newItems = []; + return prev; + }; + + const limited = clusters.length > remainingSlots ? clusters.slice(0, remainingSlots) : clusters; + + if (limited.length < clusters.length) { + pushNotification(`Only importing ${limited.length} clusters to avoid exceeding maximum of ${MAX_ITEMS} items`, "warning"); + }; + + const createdItems: SessionItem[] = limited.map(({ name, fileContent }) => ({ + id: crypto.randomUUID(), + kind: "cluster", + name, + fileContent, + status: "queued", + errorMessage: null, + updatedAt: Date.now(), + })); + + const updated: Session = { ...prev, items: [...prev.items, ...createdItems] }; + + nextSession = updated; + newItems = createdItems; + return updated; + }); + + if (!nextSession || newItems.length === 0) return []; + + // Persist session BEFORE submitting jobs + try { + await saveSession(nextSession); + } catch (err) { + const msg = err instanceof Error ? 
err.message : String(err); + pushNotification(`Failed to save session before importing clusters: ${msg}`, "error"); + + const newIds = new Set(newItems.map((it) => it.id)); + + setSession((prev) => ({ + ...prev, + items: prev.items.map((it) => + newIds.has(it.id) + ? { + ...it, + status: "error", + errorMessage: "Failed to save session before importing cluster", + updatedAt: Date.now(), + } + : it + ) + })); + + return []; + }; + + // Submit jobs sequentially + for (const item of newItems) { + try { + await submitClusterJob(sessionId, item as ClusterItem); + } catch (err) { + const msg = err instanceof Error ? err.message : String(err); + pushNotification(`Failed to submit job for cluster "${item.name}": ${msg}`, "error"); + + // Mark item as error + setSession((prev) => ({ + ...prev, + items: prev.items.map((it) => + it.id === item.id + ? { + ...it, + status: "error", + errorMessage: `Failed to submit job: ${msg}`, + updatedAt: Date.now(), + } + : it + ) + })); + }; + }; + + return newItems; +}; diff --git a/src/client/src/features/session/types.ts b/src/client/src/features/session/types.ts index 03abd94..67a24e4 100644 --- a/src/client/src/features/session/types.ts +++ b/src/client/src/features/session/types.ts @@ -15,15 +15,15 @@ export const CompoundItemSchema = BaseItemSchema.extend({ matchStereochemistry: z.boolean(), }); -export const ClusterSchema = BaseItemSchema.extend({ +export const ClusterItemSchema = BaseItemSchema.extend({ kind: z.literal("cluster"), fileContent: z.string(), }); -export const SessionItemSchema = z.discriminatedUnion("kind", [CompoundItemSchema,ClusterSchema]); +export const SessionItemSchema = z.discriminatedUnion("kind", [CompoundItemSchema, ClusterItemSchema]); export type CompoundItem = z.output; -export type ClusterItem = z.output; +export type ClusterItem = z.output; export type SessionItem = z.output; export const SessionSchema = z.object({ diff --git a/src/server/app.py b/src/server/app.py index 8e87a62..be78205 100644 --- 
a/src/server/app.py +++ b/src/server/app.py @@ -18,6 +18,7 @@ from routes.events import blp_events from routes.database import dsn_from_env from routes.compound import blp_search_compound, blp_submit_compound +from routes.cluster import blp_submit_cluster from routes.query_service import blp_query_item @@ -152,5 +153,6 @@ def ready() -> tuple[dict[str, str], int]: app.register_blueprint(blp_delete_item) app.register_blueprint(blp_search_compound) app.register_blueprint(blp_submit_compound) +app.register_blueprint(blp_submit_cluster) app.register_blueprint(blp_query_item) app.register_blueprint(blp_events) diff --git a/src/server/routes/cluster.py b/src/server/routes/cluster.py new file mode 100644 index 0000000..1f8b056 --- /dev/null +++ b/src/server/routes/cluster.py @@ -0,0 +1,165 @@ +"""Blueprints for cluster-related API endpoints.""" + +import time +import tempfile +import os + +from flask import Blueprint, current_app, jsonify, request + +from biocracker.io.readers import load_regions +from biocracker.io.options import AntiSmashOptions +from biocracker.inference.registry import register_domain_model, register_gene_model +from biocracker.pipelines.annotate_region import annotate_region +from biocracker.query.modules import NRPSModule, PKSModule, linear_readout as biocracker_linear_readout + +from routes.session_store import load_session_with_items, update_item +from routes.models_registry import get_paras_model, get_pfam_models + +blp_submit_cluster = Blueprint("submit_cluster", __name__) + + +def _set_item_status_inplace(item: dict, status: str, error_message: str | None = None) -> None: + """ + Update the status and error message of an item in place. 
+ + :param item: the item dictionary to update + :param status: the new status string + :param error_message: optional error message string + """ + item["status"] = status + item["updatedAt"] = int(time.time() * 1000) + + if error_message is not None: + item["errorMessage"] = error_message + else: + if "errorMessage" in item: + item["errorMessage"] = None + + +@blp_submit_cluster.post("/api/submitCluster") +def submit_cluster(): + """ + Submit a cluster for processing. + """ + payload = request.get_json(force=True) or {} + session_id = payload.get("sessionId") + item_id = payload.get("itemId") + name = payload.get("name") + file_content = payload.get("fileContent") + + current_app.logger.info(f"submit_cluster called: session_id={session_id} item_id={item_id}") + + if not session_id or not item_id: + current_app.logger.warning("submit_cluster: missing sessionId or itemId") + return jsonify({"error": "Missing sessionId or itemId"}), 400 + + # Validate session + item exists and kind is correct + full_sess = load_session_with_items(session_id) + if full_sess is None: + current_app.logger.warning(f"submit_cluster: session not found: {session_id}") + return jsonify({"error": "Session not found"}), 404 + + item = next((it for it in full_sess.get("items", []) if it.get("id") == item_id), None) + if item is None: + current_app.logger.warning(f"submit_cluster: item not found: {item_id}") + return jsonify({"error": "Item not found"}), 404 + + if item.get("kind") != "cluster": + current_app.logger.warning(f"submit_cluster: wrong kind={item.get('kind')}") + return jsonify({"error": "Item is not a cluster"}), 400 + + t0 = time.time() + + # Set status=processing early on this item only + def mark_processing(it: dict) -> None: + """ + Update item details and mark as processing. 
+ + :param it: the item dictionary to update + """ + it["name"] = name or it.get("name") + it["fileContent"] = file_content or it.get("fileContent") + _set_item_status_inplace(it, "processing") + + ok = update_item(session_id, item_id, mark_processing) + if not ok: + current_app.logger.warning(f"submit_cluster: failed to mark item as processing: {item_id}") + return jsonify({"error": "Item not found during update"}), 404 + + tmp_path = None + try: + options = AntiSmashOptions(readout_level="cand_cluster") + + paras_model = get_paras_model() + if paras_model: + register_domain_model(paras_model) + + pfam_models = get_pfam_models() + print(f"PFAM models loaded: {pfam_models}") + for pfam_model in pfam_models or []: + register_gene_model(pfam_model) + + # Heavy work + # Write file content to a temporary file + with tempfile.NamedTemporaryFile(mode="w", suffix=".gbk", delete=True) as tmp: + tmp.write(file_content or "") + tmp.flush() + tmp_path = tmp.name + regions = load_regions(tmp_path, options=options) + + # TODO: only saving readout of last candidate cluster found... 
need to keep all of them and create individual items for each + for region in regions: + annotate_region(region) + readout = biocracker_linear_readout(region) + + module_scores: list[float] = [] + for m in readout.modules: + if isinstance(m, NRPSModule): + if s := m.substrate: + module_scores.append(s.score) + else: + module_scores.append(0.0) + elif isinstance(m, PKSModule): + # Readout from antiSMASH GBK is always confident + module_scores.append(1.0) + else: + current_app.logger.warning(f"submit_cluster: unknown module type: {type(m)}") + module_scores.append(0.0) + + score: float = sum(module_scores) / len(module_scores) if module_scores else 0.0 + result_as_dict: dict = readout.to_dict() + + # Set final status=done and store results on this item only + def mark_done(it: dict) -> None: + it["name"] = name or it.get("name") + it["fileContent"] = file_content or it.get("fileContent") + it["score"] = score + it["payload"] = result_as_dict + _set_item_status_inplace(it, "done") + + update_item(session_id, item_id, mark_done) + + except Exception as e: + current_app.logger.exception(f"submit_cluster: error for item_id={item_id}") + + def mark_error(it: dict) -> None: + _set_item_status_inplace(it, "error", error_message=str(e)) + + update_item(session_id, item_id, mark_error) + + elapsed = int((time.time() - t0) * 1000) + return jsonify({ + "ok": False, + "status": "error", + "elapsed_ms": elapsed, + "error": str(e), + }), 500 + + elapsed = int((time.time() - t0) * 1000) + current_app.logger.info(f"submit_cluster: finished item_id={item_id} elapsed_ms={elapsed}") + + return jsonify({ + "ok": True, + "status": "done", + "elapsed_ms": elapsed, + }), 200 diff --git a/src/server/routes/compound.py b/src/server/routes/compound.py index 9c802d6..3ff0c6c 100644 --- a/src/server/routes/compound.py +++ b/src/server/routes/compound.py @@ -1,7 +1,5 @@ """Blueprints for compound-related API endpoints.""" -from __future__ import annotations - import time from flask import 
Blueprint, current_app, jsonify, request diff --git a/src/server/routes/models_registry.py b/src/server/routes/models_registry.py new file mode 100644 index 0000000..b841530 --- /dev/null +++ b/src/server/routes/models_registry.py @@ -0,0 +1,88 @@ +"""Module for loading and caching machine learning models used in the application.""" + +from pathlib import Path +import os + +from flask import current_app + +from biocracker.inference.model_paras import ParasModel +from biocracker.inference.model_pfam import PfamModel + + +CACHE_DIR = os.environ.get("CACHE_DIR", "/app/cache") +_model_cache: dict[str, object | None] = {} + +PARAS_MODEL_PATH = os.environ.get("PARAS_MODEL_PATH", None) +PFAM_HMM_DIR_PATH = os.environ.get("PFAM_HMM_DIR_PATH", None) + + +# Make sure cache directory exists +os.makedirs(CACHE_DIR, exist_ok=True) + + +def get_cache_dir() -> str: + """ + Get the cache directory path. + + :return: the cache directory path + """ + return CACHE_DIR + + +def get_paras_model() -> ParasModel | None: + """ + Load and return the PARAS model from disk, caching it in memory. 
+ + :return: the loaded PARAS model, or None if not found + """ + # Check if model is already cached + if "paras" in _model_cache: + return _model_cache["paras"] + + # Check if model path is defined + if PARAS_MODEL_PATH: + # Model path is defined; attempt to load the model + path = Path(PARAS_MODEL_PATH) + if path.is_file(): + current_app.logger.info(f"Loading PARAS model from {path}") + _model_cache["paras"] = ParasModel(threshold=0.1, keep_top=3, cache_dir=get_cache_dir(), model_path=path) + else: + current_app.logger.warning(f"PARAS model not found at {path}; letting BioCracker download into {CACHE_DIR}") + _model_cache["paras"] = None + return _model_cache["paras"] + else: + # Model path is not defined + current_app.logger.warning("PARAS_MODEL_PATH not set; letting BioCracker download into CACHE_DIR") + return None + + +def get_pfam_models() -> list[PfamModel] | None: + """ + Load and return Pfam models from disk, caching them in memory. + + :return: the loaded Pfam models, or None if not found + """ + # Check if model is already cached + if "pfam" in _model_cache: + return _model_cache["pfam"] + + # Check if HMM directory path is defined + if PFAM_HMM_DIR_PATH: + path = Path(PFAM_HMM_DIR_PATH) + if path.is_dir(): + current_app.logger.info(f"Loading Pfam models from {path}") + hmm_paths = [f for f in path.glob("*.hmm") if f.is_file()] + pfam_models = [] + for hmm_path in hmm_paths: + current_app.logger.info(f"Loading Pfam model from {hmm_path}") + pfam_model = PfamModel(hmm_path=hmm_path, label=hmm_path.stem) + pfam_models.append(pfam_model) + _model_cache["pfam"] = pfam_models + else: + current_app.logger.warning(f"Pfam HMM directory not found at {path}; letting BioCracker download into {CACHE_DIR}") + _model_cache["pfam"] = None + return _model_cache["pfam"] + else: + # HMM directory path is not defined + current_app.logger.warning("PFAM_HMM_DIR_PATH not set") + return None diff --git a/src/server/routes/query/align.py b/src/server/routes/query/align.py 
index 56f941c..e29df5f 100644 --- a/src/server/routes/query/align.py +++ b/src/server/routes/query/align.py @@ -214,12 +214,18 @@ def from_alignment( # Query row: split into original blocks for display q_map = {label_fn(it): it for it in query_readout.flatten_items()} + + if query_readout.kind == "compound": + # We sort on length descending for compounds + block_order = sorted( + range(len(query_readout.blocks)), + key=lambda i: len(query_readout.blocks[i]), + reverse=True, + ) + else: + # Keep original order for clusters + block_order = list(range(len(query_readout.blocks))) - block_order = sorted( - range(len(query_readout.blocks)), - key=lambda i: len(query_readout.blocks[i]), - reverse=True, - ) target_block_indices = [ bidx for bidx in block_order for _ in query_readout.blocks[bidx] ] @@ -390,15 +396,20 @@ def item_compare_fn(a: SequenceItem, b: SequenceItem) -> float: a_fam_toks = set(a.family_tokens) b_fam_toks = set(b.family_tokens) fam_tok_overlap = a_fam_toks.intersection(b_fam_toks) + fam_tok_differs = a_fam_toks.symmetric_difference(b_fam_toks) # Compare ancestor tokens a_anc_toks = set(a.ancestor_tokens) b_anc_toks = set(b.ancestor_tokens) anc_tok_overlap = a_anc_toks.intersection(b_anc_toks) + anc_tok_differs = a_anc_toks.symmetric_difference(b_anc_toks) tok_overlap = fam_tok_overlap.union(anc_tok_overlap) score += 0.5 * len(tok_overlap) + tok_differs = fam_tok_differs.union(anc_tok_differs) + score -= 0.5 * len(tok_differs) + if a.morgan_fp is not None and b.morgan_fp is not None: score += calculate_tanimoto_similarity(a.morgan_fp, b.morgan_fp) return score diff --git a/src/server/routes/query/featurize.py b/src/server/routes/query/featurize.py index 4e2e762..071984d 100644 --- a/src/server/routes/query/featurize.py +++ b/src/server/routes/query/featurize.py @@ -45,7 +45,7 @@ def calculate_payload_fingerprint( return fp -def _format_readout_compound(payload: Result) -> SequenceItemReadout: +def _format_readout_compound(payload: Result, fragment: bool 
= False) -> SequenceItemReadout: """ Format the readout for a compound payload. @@ -54,12 +54,17 @@ def _format_readout_compound(payload: Result) -> SequenceItemReadout: """ linear_readouts: list[list[MolNode]] = payload.linear_readout.paths + if fragment: + # Flatten linear_readouts + linear_readouts = [[node] for path in linear_readouts for node in path] + formatted_blocks = [] for path in linear_readouts: formatted_block = [NonGap.from_retromol_molnode(n) for n in path] formatted_blocks.append(formatted_block) return SequenceItemReadout( + kind="compound", block_ids=[f"structural readout {i+1}" for i in range(len(formatted_blocks))], blocks=formatted_blocks, ) @@ -86,6 +91,7 @@ def _format_readout_cluster(payload: LinearReadout) -> SequenceItemReadout: formatted_blocks.append(formatted_block) return SequenceItemReadout( + kind="cluster", block_ids=block_ids, blocks=formatted_blocks, ) diff --git a/src/server/routes/query/pipeline.py b/src/server/routes/query/pipeline.py index ae08f3b..b686654 100644 --- a/src/server/routes/query/pipeline.py +++ b/src/server/routes/query/pipeline.py @@ -17,6 +17,12 @@ from bionexus.db.models import CandidateCluster, Compound +# Turn off BiopythonDeprecationWarning warnings +import warnings +from Bio import BiopythonDeprecationWarning +warnings.simplefilter("ignore", BiopythonDeprecationWarning) + + @dataclass(frozen=True) class _InsKey: """ diff --git a/src/server/routes/query/retrieve.py b/src/server/routes/query/retrieve.py index 8848180..af1d25a 100644 --- a/src/server/routes/query/retrieve.py +++ b/src/server/routes/query/retrieve.py @@ -10,7 +10,7 @@ from routes.database import SessionLocal -ANN_SEARCH_RADIUS = 1000 +ANN_SEARCH_RADIUS = 2000 # increasing this will increase latency, but also improve results HNSW_SETTINGS = { "hnsw.iterative_scan": "strict_order", diff --git a/src/server/routes/query/seq.py b/src/server/routes/query/seq.py index 028f438..49923b6 100644 --- a/src/server/routes/query/seq.py +++ 
b/src/server/routes/query/seq.py @@ -1,6 +1,7 @@ """Module defining sequence item data structures for query results.""" from dataclasses import dataclass, field +from typing import Literal from rdkit.DataStructs.cDataStructs import ExplicitBitVect @@ -195,6 +196,8 @@ def from_biocracker_module(cls, m: Module) -> "SequenceItem": mol = smiles_to_mol(smiles) morgan_fp = mol_to_morgan_fingerprint(mol, radius=MORGAN_RADIUS, num_bits=MORGAN_SIZE) return cls(display_name=display_name, morgan_fp=morgan_fp, ancestor_tokens=["NRPS"]) + case NRPSModule(substrate=None): + return cls(display_name=DISPLAY_NAME_UNIDENTIFIED, ancestor_tokens=["NRPS"]) case _: raise NotImplementedError(f"BioCracker module type {type(m)} not supported yet") @@ -214,10 +217,12 @@ class SequenceItemReadout: """ Readout of sequence items in query results. + :var kind: either "compound" or "cluster" :var block_ids: list of block identifiers for display purposes :var blocks: list of blocks, where each block is a list of SequenceItems """ + kind: Literal["compound", "cluster"] block_ids: list[str] # only for display purposes blocks: list[list[SequenceItem]] @@ -227,7 +232,9 @@ def flatten_items(self) -> list[SequenceItem]: :return: flattened list of SequenceItems """ - # Sort blocks on size; longer blocks first - blocks_sorted = sorted(self.blocks, key=lambda b: len(b), reverse=True) + blocks = self.blocks + if self.kind == "compound": + # Only for compounds: sort blocks on size; longer blocks first + blocks = sorted(blocks, key=lambda b: len(b), reverse=True) - return [item for block in blocks_sorted for item in block] + return [item for block in blocks for item in block] \ No newline at end of file diff --git a/src/server/routes/query_service.py b/src/server/routes/query_service.py index d8b9e22..0d6a381 100644 --- a/src/server/routes/query_service.py +++ b/src/server/routes/query_service.py @@ -30,18 +30,23 @@ def query_item(): return jsonify({"error": "At least one of queryAgainstClusters or 
queryAgainstCompounds must be true"}), 400 # Retrieve item from session store + current_app.logger.info(f"Retrieving query item: session_id={session_id} item_id={item_id}") item = load_item(session_id, item_id) if item is None: return jsonify({"error": "Item not found"}), 404 + current_app.logger.info(f"Loaded item for querying: session_id={session_id} item_id={item_id} kind={item.get('kind')}") payload_type = item.get("kind", None) payload_blob = item.get("payload", None) if not payload_type or payload_type not in ["cluster", "compound"]: + current_app.logger.error(f"Invalid item kind for querying: {payload_type}") return jsonify({"error": "Invalid item kind"}), 400 if not payload_blob: + current_app.logger.error("Missing item payload for querying") return jsonify({"error": "Missing item payload"}), 400 try: + current_app.logger.info(f"Starting cross-modal retrieval for item_id={item_id}") msa_result: MSAResult = cross_modal_retrieval( payload_type=payload_type, payload_blob=payload_blob, @@ -49,7 +54,7 @@ def query_item(): query_against_compounds=query_against_compounds, ) except ValueError as e: - current_app.logger.error(f"error during cross-modal retrieval: {e}") + current_app.logger.error(f"Error during cross-modal retrieval: {e}") return jsonify({"error": str(e)}), 500 return jsonify(msa_result.to_dict()), 200 From b736c2b7e399bf3e78d7f470ddb192045f412ed4 Mon Sep 17 00:00:00 2001 From: David Meijer Date: Mon, 19 Jan 2026 22:28:56 +0100 Subject: [PATCH 20/34] UPD: per candidate cluster saving --- src/server/routes/cluster.py | 60 +++++++++++++++++++++++++++++++----- 1 file changed, 53 insertions(+), 7 deletions(-) diff --git a/src/server/routes/cluster.py b/src/server/routes/cluster.py index 1f8b056..ac6c789 100644 --- a/src/server/routes/cluster.py +++ b/src/server/routes/cluster.py @@ -12,11 +12,14 @@ from biocracker.pipelines.annotate_region import annotate_region from biocracker.query.modules import NRPSModule, PKSModule, linear_readout as 
biocracker_linear_readout -from routes.session_store import load_session_with_items, update_item +from routes.session_store import load_session_with_items, update_item, save_item, publish_session_event from routes.models_registry import get_paras_model, get_pfam_models +from helpers.guid import generate_guid blp_submit_cluster = Blueprint("submit_cluster", __name__) +MAX_ITEMS = int(os.getenv("MAX_ITEMS", "20")) + def _set_item_status_inplace(item: dict, status: str, error_message: str | None = None) -> None: """ @@ -107,8 +110,26 @@ def mark_processing(it: dict) -> None: tmp_path = tmp.name regions = load_regions(tmp_path, options=options) - # TODO: only saving readout of last candidate cluster found... need to keep all of them and create individual items for each - for region in regions: + if not regions: + raise ValueError("No candidate clusters found") + + existing_count = len(full_sess.get("items", []) or []) + remaining_slots = MAX_ITEMS - existing_count + if remaining_slots < 0: + remaining_slots = 0 + + max_clusters = 1 + remaining_slots + if len(regions) > max_clusters: + current_app.logger.warning( + "submit_cluster: truncating candidate clusters to %s due to max items limit", + max_clusters, + ) + + base_name = name or item.get("name") or "Cluster" + file_blob = file_content or item.get("fileContent") + + results: list[dict] = [] + for region in regions[:max_clusters]: annotate_region(region) readout = biocracker_linear_readout(region) @@ -128,17 +149,42 @@ def mark_processing(it: dict) -> None: score: float = sum(module_scores) / len(module_scores) if module_scores else 0.0 result_as_dict: dict = readout.to_dict() + results.append({"score": score, "payload": result_as_dict}) + + def _candidate_name(idx: int, total: int) -> str: + if total <= 1: + return base_name + return f"{base_name} (candidate cluster {idx})" # Set final status=done and store results on this item only def mark_done(it: dict) -> None: - it["name"] = name or it.get("name") - 
it["fileContent"] = file_content or it.get("fileContent") - it["score"] = score - it["payload"] = result_as_dict + it["name"] = _candidate_name(1, len(results)) + it["fileContent"] = file_blob + it["score"] = results[0]["score"] + it["payload"] = results[0]["payload"] _set_item_status_inplace(it, "done") update_item(session_id, item_id, mark_done) + extra_results = results[1:] + if extra_results: + now_ms = int(time.time() * 1000) + for idx, result in enumerate(extra_results, start=2): + new_item = { + "id": generate_guid(), + "kind": "cluster", + "name": _candidate_name(idx, len(results)), + "fileContent": file_blob, + "status": "done", + "errorMessage": None, + "updatedAt": now_ms, + "score": result["score"], + "payload": result["payload"], + } + save_item(session_id, new_item) + + publish_session_event(session_id, {"type": "session_merged"}) + except Exception as e: current_app.logger.exception(f"submit_cluster: error for item_id={item_id}") From bdbacb43db6e3e74249e7c07f1a43d181409c21c Mon Sep 17 00:00:00 2001 From: David Meijer Date: Mon, 19 Jan 2026 22:39:41 +0100 Subject: [PATCH 21/34] ENH: add compound/cluster labels to workspace items --- .../tabs/upload/WorkspaceItemCard.tsx | 24 +++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/src/client/src/components/workspace/tabs/upload/WorkspaceItemCard.tsx b/src/client/src/components/workspace/tabs/upload/WorkspaceItemCard.tsx index 0cb5da4..8fea346 100644 --- a/src/client/src/components/workspace/tabs/upload/WorkspaceItemCard.tsx +++ b/src/client/src/components/workspace/tabs/upload/WorkspaceItemCard.tsx @@ -156,12 +156,18 @@ export const WorkspaceItemCard: React.FC = ({ - + {item.name} @@ -198,6 +204,20 @@ export const WorkspaceItemCard: React.FC = ({ )} + {isCompound ? 
( + + ) : ( + + )} + {isCompound && ( Date: Wed, 21 Jan 2026 16:18:11 +0100 Subject: [PATCH 22/34] UPD: row info dialog popup --- .../tabs/discovery/DialogRowInfo.tsx | 225 +++++++++++++++ .../tabs/discovery/QueryResultView.tsx | 18 +- .../workspace/tabs/discovery/SortableRow.tsx | 273 +++++++++--------- .../tabs/discovery/WorkspaceDiscovery.tsx | 13 +- .../tabs/upload/DialogImportCompound.tsx | 1 - src/server/app.py | 6 +- src/server/routes/compound.py | 2 + src/server/routes/info.py | 52 ++++ src/server/routes/query/align.py | 55 +--- src/server/routes/query/featurize.py | 22 +- src/server/routes/query/pipeline.py | 9 +- src/server/routes/query/seq.py | 3 + 12 files changed, 469 insertions(+), 210 deletions(-) create mode 100644 src/client/src/components/workspace/tabs/discovery/DialogRowInfo.tsx create mode 100644 src/server/routes/info.py diff --git a/src/client/src/components/workspace/tabs/discovery/DialogRowInfo.tsx b/src/client/src/components/workspace/tabs/discovery/DialogRowInfo.tsx new file mode 100644 index 0000000..57fa9f6 --- /dev/null +++ b/src/client/src/components/workspace/tabs/discovery/DialogRowInfo.tsx @@ -0,0 +1,225 @@ +import React from "react"; +import Typography from "@mui/material/Typography"; +import Chip from "@mui/material/Chip"; +import Alert from "@mui/material/Alert"; +import CircularProgress from "@mui/material/CircularProgress"; +import { DialogWindow } from "../../../shared/DialogWindow"; +import { MsaRow } from "./QueryResultView"; + +type DialogRowInfoProps = { + open: boolean; + onClose: () => void; + msaRow: MsaRow; +}; + +type Annotation = { + scheme: string; + key: string; + value: string; +}; + +type Reference = { + name: string; + database_name: string; + database_identifier: string; +}; + +type RowInfoResponse = { + annotations: Annotation[]; + references: Reference[]; +}; + +function referenceToUrl(ref: Reference): string | null { + const { database_name, database_identifier } = ref; + + switch 
(database_name.toLowerCase()) { + case "npatlas": + return `https://www.npatlas.org/explore/compounds/${database_identifier}` + default: + return null; + }; +}; + +function referenceLabel(ref: Reference): string { + return `${ref.name} (${ref.database_name}: ${ref.database_identifier})`; +}; + +export const DialogRowInfo: React.FC = ({ + open, + onClose, + msaRow, +}) => { + const kind = msaRow.kind ?? null; + const dbId = msaRow.db_id ?? null; + + const canFetch = open && kind != null && dbId != null; + + const [loading, setLoading] = React.useState(false); + const [error, setError] = React.useState(null); + const [data, setData] = React.useState(null); + + const cacheRef = React.useRef(new Map()); + + const cacheKey = kind && dbId != null ? `${kind}:${dbId}` : null; + + const sortedAnnotations = React.useMemo(() => { + if (!data?.annotations) return []; + return [...data.annotations].sort((a, b) => { + const s = a.scheme.localeCompare(b.scheme); + if (s !== 0) return s; + + const k = a.key.localeCompare(b.key); + if (k !== 0) return k; + + return a.value.localeCompare(b.value); + }); + }, [data?.annotations]); + + const sortedReferences = React.useMemo(() => { + if (!data?.references) return []; + return [...data.references].sort((a, b) => { + const d = a.database_name.localeCompare(b.database_name); + if (d !== 0) return d; + + const i = a.database_identifier.localeCompare(b.database_identifier); + if (i !== 0) return i; + + return a.name.localeCompare(b.name); + }); + }, [data?.references]); + + React.useEffect(() => { + if (!canFetch || !cacheKey) { + setLoading(false); + setError(null); + setData(null); + return; + } + + // Serve from cache immediately if present + const cached = cacheRef.current.get(cacheKey); + if (cached) { + setLoading(false); + setError(null); + setData(cached); + return; + } + + // Otherwise fetch + const controller = new AbortController(); + let alive = true; + + (async () => { + try { + setLoading(true); + setError(null); + + const 
res = await fetch( + `/api/itemInfo?kind=${encodeURIComponent(kind!)}&db_id=${dbId}`, + { + method: "GET", + signal: controller.signal, + headers: { Accept: "application/json" }, + } + ); + + if (!res.ok) { + throw new Error(`Error fetching data: ${res.status} ${res.statusText}`); + } + + const json = (await res.json()) as RowInfoResponse; + if (!alive) return; + + // Store in cache + cacheRef.current.set(cacheKey, json); + + setData(json); + } catch (e: any) { + if (e?.name === "AbortError") return; + if (!alive) return; + setError(e?.message ?? "Unknown error"); + } finally { + if (!alive) return; + setLoading(false); + } + })(); + + return () => { + alive = false; + controller.abort(); + }; + }, [canFetch, cacheKey, kind, dbId]); + + return ( + + {!canFetch || (!data?.annotations.length && !data?.references.length) && ( + + No additional information available for this row. + + )} + + {data?.references.length ? ( + <> + + References + + {sortedReferences.map((reference) => { + const url = referenceToUrl(reference); + + return ( + + ); + })} + + ) : null} + + {data?.annotations.length ? 
( + <> + + Annotations + + {sortedAnnotations.map((annotation, index) => ( + + ))} + + ) : null} + + {loading && ( + + )} + + {error && ( + + {error} + + )} + + ); +}; diff --git a/src/client/src/components/workspace/tabs/discovery/QueryResultView.tsx b/src/client/src/components/workspace/tabs/discovery/QueryResultView.tsx index 7c8f039..00e17cd 100644 --- a/src/client/src/components/workspace/tabs/discovery/QueryResultView.tsx +++ b/src/client/src/components/workspace/tabs/discovery/QueryResultView.tsx @@ -32,24 +32,19 @@ export type Sequence = { sequence: SequenceItem[]; }; -export type Reference = { - name: string; - database_name: string; - database_identifier: string; -}; - -export type MsaItem = { +export type MsaRow= { id: string; name?: string; + kind?: "compound" | "cluster" | null; + db_id?: number | null; alignment_score: number | null; cosine_score: number | null; match_score: number | null; sequence: Sequence[]; - references: Reference[]; }; export type QueryResult = { - msa: MsaItem[]; + msa: MsaRow[]; }; type QueryResultViewProps = { @@ -250,7 +245,7 @@ const renderTooltipLabel = ( export const QueryResultView: React.FC = ({ result }) => { // Keep order locally - const [msa, setMsa] = React.useState(result.msa); + const [msa, setMsa] = React.useState(result.msa); // Invert order of motifs in msa const invertMsaMotifOrder = () => { @@ -387,7 +382,7 @@ export const QueryResultView: React.FC = ({ result }) => { py: 1, }} > - {msa.map((row) => ( + {msa.map((row, rowIndex) => ( = ({ result }) => { row={row} labelWidth={labelWidth} columnTemplate={colTemplate} + hasRowInfo={rowIndex > 0} > {row.sequence.map((subseq) => { const allGaps = subseq.sequence.every((it) => it.isGap); diff --git a/src/client/src/components/workspace/tabs/discovery/SortableRow.tsx b/src/client/src/components/workspace/tabs/discovery/SortableRow.tsx index 5813e8f..3693c40 100644 --- a/src/client/src/components/workspace/tabs/discovery/SortableRow.tsx +++ 
b/src/client/src/components/workspace/tabs/discovery/SortableRow.tsx @@ -6,32 +6,27 @@ import Tooltip from "@mui/material/Tooltip"; import { CSS } from "@dnd-kit/utilities"; import { useSortable } from "@dnd-kit/sortable"; import DragIndicatorIcon from "@mui/icons-material/DragIndicator"; -import { MsaItem, Reference } from "./QueryResultView"; +import InfoOutlineIcon from "@mui/icons-material/InfoOutline"; +import { MsaRow } from "./QueryResultView"; +import { DialogRowInfo } from "./DialogRowInfo"; interface SortableRowProps { - row: MsaItem; + row: MsaRow; labelWidth: number; columnTemplate: string; children: React.ReactNode; + hasRowInfo: boolean; }; const fmt = (v: number | null, digits: number) => v == null || Number.isNaN(v) ? "" : v.toFixed(digits); -const referenceToUrl = (ref: Reference): string | null => { - switch (ref.database_name.toLowerCase()) { - case "npatlas": - return `https://www.npatlas.org/explore/compounds/${ref.database_identifier}`; - default: - return null; - } -}; - export const SortableRow: React.FC = ({ row, labelWidth, columnTemplate, children, + hasRowInfo, }) => { const { attributes, @@ -42,169 +37,159 @@ export const SortableRow: React.FC = ({ transition, } = useSortable({ id: row.id }); - // Setup outlink to reference for row - const [ref, setRef] = React.useState(null); - const url = React.useMemo(() => (ref ? 
referenceToUrl(ref) : null), [ref]); - - // Set reference on mount - React.useEffect(() => { - if (row.references && row.references.length > 0) { - setRef(row.references[0]); - } else { - setRef(null); - } - }, [row.references]); - const alignText = fmt(row.alignment_score, 2); const cosineText = fmt(row.cosine_score, 2); const matchText = fmt(row.match_score, 2); const scoreBlockWidth = 40; + const [openRowInfo, setOpenRowInfo] = React.useState(false); + + const handleRowInfo = (event: React.MouseEvent) => { + setOpenRowInfo(true); + }; + return ( - + <> e.stopPropagation()} // don't trigger center selection on drag - > - - - - - ) => { - if (url) e.stopPropagation(); // prevent row drag / selection - }} - sx={{ - fontWeight: 600, - maxWidth: labelWidth - 100, - lineHeight: "20px", - overflow: "hidden", - textOverflow: "ellipsis", - whiteSpace: "nowrap", - zIndex: 101, - userSelect: "none", - - // link-only styling - textDecoration: url ? "underline" : "none", - cursor: url ? "pointer" : "default", - color: "inherit", - - "&:hover": url ? 
{ color: "primary.main" } : undefined, - }} - > - {row.name || row.id} - - - - {/* Scores */} - e.stopPropagation()} // don't trigger center selection on drag > - - + + {hasRowInfo && ( + + - {alignText} - - - - - {cosineText} - - - + /> + + )} + + - {matchText} + {row.name || row.id} + + {/* Scores */} + + + + {alignText} + + + + + {cosineText} + + + + + {matchText} + + + - - + - {/* Motifs */} - {children} + {/* Motifs */} + {children} + - {/* Row line */} - {/* */} - + setOpenRowInfo(false)} + msaRow={row} + /> + ); }; diff --git a/src/client/src/components/workspace/tabs/discovery/WorkspaceDiscovery.tsx b/src/client/src/components/workspace/tabs/discovery/WorkspaceDiscovery.tsx index f3d0dac..f82b7d6 100644 --- a/src/client/src/components/workspace/tabs/discovery/WorkspaceDiscovery.tsx +++ b/src/client/src/components/workspace/tabs/discovery/WorkspaceDiscovery.tsx @@ -30,6 +30,9 @@ export const WorkspaceDiscovery: React.FC = ({ session, const theme = useTheme(); const { pushNotification } = useNotifications(); + // Check if session has items + const hasItems = session.items.length > 0; + // Query state const [selectedItemId, setSelectedItemId] = React.useState(""); const [queryLoading, setQueryLoading] = React.useState(false); @@ -143,7 +146,7 @@ export const WorkspaceDiscovery: React.FC = ({ session, - + = ({ session, "&.MuiInputLabel-shrink": { transform: "translate(14px, -9px) scale(0.75)" }, }} > - {selectedItemId ? "Item to use for querying" : "Select an item to use for querying"} + {!hasItems + ? "No items available to select" + : selectedItemId + ? 
"Item to use for querying" + : "Select an item to use for querying"} setSelectedItemId(e.target.value)} + disabled={!hasItems || queryLoading} + MenuProps={{ + PaperProps: { + sx: { + "& .MuiMenuItem-root": { + userSelect: "none", + borderRadius: 0, + }, + }, + }, + }} + sx={{ + "& .MuiSelect-select": { userSelect: "none" }, + "& .MuiSelect-select:focus": { backgroundColor: "transparent" }, + "& .MuiSelect-select:focus-visible": { outline: "none" }, + "&.MuiInputBase-root": { height: 44 }, + }} + > + {session.items.map((item) => ( + + {item.name} + + ))} + + + + { + const next = Number(e.target.value); + if (!Number.isNaN(next)) { + const clamped = Math.min(100, Math.max(0, next)); + setThresholdPct(clamped); + } + }} + disabled={queryLoading} + inputProps={{ min: 0, max: 100, step: 1 }} + helperText="Percentage of the query self-alignment score used as cutoff." + InputProps={{ + endAdornment: %, + }} + /> + + + + setQueryAgainstClusters(e.target.checked)} + /> + } + label="Query against clusters" + /> + setQueryAgainstCompounds(e.target.checked)} + /> + } + label="Query against compounds" + /> + + + + + + {queryLoading && } + + + {alert.text} + + + + + {queryResult && ( + + + + + Enrichment results + + + + + + + + + + + + + + Self-alignment score: {queryResult.summary.self_alignment_score.toFixed(2)}. Cutoff:{" "} + {queryResult.summary.alignment_threshold.toFixed(2)}. Background counts use all + database items for the selected types. + + + {queryResult.warnings.map((warning) => ( + + {warning} + + ))} + + {queryResult.results.length === 0 ? ( + + No enriched labels detected for this threshold. + + ) : ( + + + + + + Label + In-group + Background + P-value + Adj. P + + + + {queryResult.results.map((result) => { + const label = labelText(result); + const isSignificant = result.p_adjusted <= SIGNIFICANCE_ALPHA; + return ( + + + 0.05)" + } + arrow + > + + {isSignificant ? 
( + + ) : ( + + )} + + + + + + + {label} + + + + + {formatCount(result.in_group_count, queryResult.summary.in_group)} + + + {formatCount(result.background_count, queryResult.summary.population_total)} + + {formatPValue(result.p_value)} + {formatPValue(result.p_adjusted)} + + ); + })} + +
    +
    + )} +
    +
    + )} + + ); +}; diff --git a/src/client/src/components/workspace/tabs/enrichment/types.ts b/src/client/src/components/workspace/tabs/enrichment/types.ts new file mode 100644 index 0000000..735d898 --- /dev/null +++ b/src/client/src/components/workspace/tabs/enrichment/types.ts @@ -0,0 +1,32 @@ +export type EnrichmentLabel = { + scheme: string; + key: string; + value: string; +}; + +export type EnrichmentResult = { + label: EnrichmentLabel; + p_value: number; + p_adjusted: number; + in_group_count: number; + background_count: number; + in_group_fraction: number; + background_fraction: number; +}; + +export type EnrichmentSummary = { + neighbors_requested: number; + total_neighbors: number; + population_total: number; + in_group: number; + out_group: number; + threshold_pct: number; + self_alignment_score: number; + alignment_threshold: number; +}; + +export type EnrichmentResponse = { + summary: EnrichmentSummary; + warnings: string[]; + results: EnrichmentResult[]; +}; diff --git a/src/server/routes/query/enrichment.py b/src/server/routes/query/enrichment.py new file mode 100644 index 0000000..0734b0b --- /dev/null +++ b/src/server/routes/query/enrichment.py @@ -0,0 +1,296 @@ +"""Enrichment analysis for query items using annotation labels.""" + +from __future__ import annotations + +from collections import defaultdict +import math +from typing import Any + +import sqlalchemy as sa +from sqlalchemy.orm import selectinload +from flask import current_app + +from bionexus.db.models import CandidateCluster, Compound + +from routes.database import SessionLocal +from routes.query.align import score_by_alignment +from routes.query.featurize import featurize_item, format_payload_readout, load_payload +from routes.query.retrieve import ann_search, ANN_SEARCH_RADIUS + + +LabelKey = tuple[str, str, str] + + +def _log_comb(n: int, k: int) -> float: + if k < 0 or k > n: + return float("-inf") + return math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1) + + +def 
_log_hypergeom_pmf(k: int, K: int, N: int, n: int) -> float: + if k < 0 or k > K or k > n or n > N: + return float("-inf") + return _log_comb(K, k) + _log_comb(N - K, n - k) - _log_comb(N, n) + + +def _hypergeom_sf(k: int, K: int, N: int, n: int) -> float: + max_i = min(K, n) + if k > max_i: + return 0.0 + log_terms = [_log_hypergeom_pmf(i, K, N, n) for i in range(k, max_i + 1)] + max_log = max(log_terms) + total = sum(math.exp(term - max_log) for term in log_terms) + return math.exp(max_log) * total + + +def _benjamini_hochberg(pvals: list[float]) -> list[float]: + m = len(pvals) + if m == 0: + return [] + order = sorted(range(m), key=lambda i: pvals[i]) + adjusted = [0.0] * m + min_adj = 1.0 + for rank in range(m, 0, -1): + idx = order[rank - 1] + adj = pvals[idx] * m / rank + if adj < min_adj: + min_adj = adj + adjusted[idx] = min(min_adj, 1.0) + return adjusted + + +def _annotation_labels(obj: Any) -> set[LabelKey]: + labels: set[LabelKey] = set() + annotations = getattr(obj, "annotations", None) or [] + for ann in annotations: + scheme = getattr(ann, "scheme", None) + key = getattr(ann, "key", None) + value = getattr(ann, "value", None) + if scheme is None or key is None or value is None: + continue + labels.add((str(scheme), str(key), str(value))) + return labels + + +def _label_counts_for_model( + session: Any, + model: type[CandidateCluster] | type[Compound], +) -> dict[LabelKey, int]: + rel = sa.inspect(model).relationships.get("annotations") + ann_attr = getattr(model, "annotations", None) + if rel is None or ann_attr is None: + return {} + + ann_cls = rel.mapper.class_ + stmt = ( + sa.select( + ann_cls.scheme, + ann_cls.key, + ann_cls.value, + sa.func.count(sa.distinct(model.id)), + ) + .select_from(model) + .join(ann_attr) + .group_by(ann_cls.scheme, ann_cls.key, ann_cls.value) + ) + rows = session.execute(stmt).all() + return { + (str(scheme), str(key), str(value)): int(count) + for scheme, key, value, count in rows + } + + +def 
_population_label_counts( + query_against_clusters: bool, + query_against_compounds: bool, +) -> tuple[dict[LabelKey, int], int]: + counts: dict[LabelKey, int] = defaultdict(int) + total = 0 + + with SessionLocal() as session: + if query_against_clusters: + total += int(session.execute( + sa.select(sa.func.count(CandidateCluster.id)) + ).scalar_one() or 0) + for label, count in _label_counts_for_model(session, CandidateCluster).items(): + counts[label] += count + + if query_against_compounds: + total += int(session.execute( + sa.select(sa.func.count(Compound.id)) + ).scalar_one() or 0) + for label, count in _label_counts_for_model(session, Compound).items(): + counts[label] += count + + return counts, total + + +def _load_annotation_map( + items: list[CandidateCluster | Compound], +) -> dict[tuple[str, int], set[LabelKey]]: + cluster_ids = [i.id for i in items if isinstance(i, CandidateCluster)] + compound_ids = [i.id for i in items if isinstance(i, Compound)] + + labels_by_key: dict[tuple[str, int], set[LabelKey]] = {} + + with SessionLocal() as session: + if cluster_ids: + clusters = session.execute( + sa.select(CandidateCluster) + .where(CandidateCluster.id.in_(cluster_ids)) + .options(selectinload(CandidateCluster.annotations)) + ).scalars().all() + for cluster in clusters: + labels_by_key[("cluster", cluster.id)] = _annotation_labels(cluster) + + if compound_ids: + compounds = session.execute( + sa.select(Compound) + .where(Compound.id.in_(compound_ids)) + .options(selectinload(Compound.annotations)) + ).scalars().all() + for compound in compounds: + labels_by_key[("compound", compound.id)] = _annotation_labels(compound) + + return labels_by_key + + +def enrichment_study( + payload_type: str, + payload_blob: dict[str, Any], + query_against_clusters: bool, + query_against_compounds: bool, + threshold_pct: float, +) -> dict[str, Any]: + if threshold_pct < 0.0 or threshold_pct > 100.0: + raise ValueError("threshold_pct must be between 0 and 100") + + query_vec, 
query_blocks = featurize_item(payload_type, payload_blob) + + nns: list[tuple[CandidateCluster | Compound, float]] = ann_search( + query_vec, + query_against_clusters=query_against_clusters, + query_against_compounds=query_against_compounds, + ) + current_app.logger.debug("enrichment: found %s nearest neighbors", len(nns)) + + if not nns or not query_blocks: + raise ValueError("no nearest neighbors found or query blocks are empty") + + nns_featurized = [] + retrieved_items: list[CandidateCluster | Compound] = [] + retrieved_keys: list[tuple[str, int] | None] = [] + + for item, _distance in nns: + assert isinstance(item, (CandidateCluster, Compound)) + item_type = "cluster" if isinstance(item, CandidateCluster) else "compound" + item_blob = item.biocracker if item_type == "cluster" else item.retromol + item_payload = load_payload(item_type, item_blob) + item_db_id = getattr(item, "id", None) + item_readout = format_payload_readout(item_type, item_payload, item_db_id) + nns_featurized.append(item_readout) + retrieved_items.append(item) + retrieved_keys.append((item_type, item_db_id) if item_db_id is not None else None) + + if not nns_featurized: + raise ValueError("failed to featurize nearest neighbors") + + _, self_scores, _ = score_by_alignment(query_blocks, [query_blocks]) + if not self_scores: + raise ValueError("failed to compute self alignment score") + + self_score = float(self_scores[0]) + if self_score <= 0.0: + raise ValueError("self alignment score is non-positive; threshold invalid") + + alignment_results = score_by_alignment(query_blocks, nns_featurized) + _aln_results, aln_scores, _match_scores = alignment_results + + threshold_score = self_score * (threshold_pct / 100.0) + + in_group_indices = [i for i, score in enumerate(aln_scores) if score >= threshold_score] + in_group_set = set(in_group_indices) + + total_neighbors = len(retrieved_items) + in_group_count = len(in_group_set) + + warnings: list[str] = [] + if in_group_count == 0: + 
warnings.append("No candidates meet the alignment threshold; enrichment cannot be computed.") + if in_group_count == total_neighbors and total_neighbors > 0: + warnings.append( + "All candidates are in the in-group; enrichment results may be unreliable." + ) + + labels_by_key = _load_annotation_map(retrieved_items) + population_counts, population_total = _population_label_counts( + query_against_clusters=query_against_clusters, + query_against_compounds=query_against_compounds, + ) + out_group_count = population_total - in_group_count + if out_group_count < 0: + warnings.append("In-group exceeds population size; check background query scope.") + out_group_count = 0 + if population_total == 0: + warnings.append("No items available in the database for the selected types.") + + in_group_counts: dict[LabelKey, int] = defaultdict(int) + + for idx, key in enumerate(retrieved_keys): + if key is None: + continue + labels = labels_by_key.get(key, set()) + for label in labels: + if idx in in_group_set: + in_group_counts[label] += 1 + + if not in_group_counts: + warnings.append("No annotations found in the in-group.") + if not population_counts: + warnings.append("No annotations found in the database background.") + + results: list[dict[str, Any]] = [] + + if in_group_count > 0 and population_counts: + pvals: list[float] = [] + labels: list[LabelKey] = [] + for label, hits_in_group in in_group_counts.items(): + total_hits = population_counts.get(label, 0) + if total_hits <= 0: + continue + if hits_in_group <= 0: + continue + pval = _hypergeom_sf(hits_in_group, total_hits, population_total, in_group_count) + pvals.append(pval) + labels.append(label) + + adjusted = _benjamini_hochberg(pvals) + for (scheme, key, value), pval, padj in zip(labels, pvals, adjusted): + in_hits = in_group_counts.get((scheme, key, value), 0) + total_hits = population_counts.get((scheme, key, value), 0) + results.append({ + "label": {"scheme": scheme, "key": key, "value": value}, + "p_value": pval, + 
"p_adjusted": padj, + "in_group_count": in_hits, + "background_count": total_hits, + "in_group_fraction": in_hits / in_group_count if in_group_count else 0.0, + "background_fraction": total_hits / population_total if population_total else 0.0, + }) + + results.sort(key=lambda r: (r["p_adjusted"], r["p_value"])) + + return { + "summary": { + "neighbors_requested": ANN_SEARCH_RADIUS, + "total_neighbors": total_neighbors, + "population_total": population_total, + "in_group": in_group_count, + "out_group": out_group_count, + "threshold_pct": threshold_pct, + "self_alignment_score": self_score, + "alignment_threshold": threshold_score, + }, + "warnings": warnings, + "results": results, + } diff --git a/src/server/routes/query_service.py b/src/server/routes/query_service.py index 0d6a381..d94e4a3 100644 --- a/src/server/routes/query_service.py +++ b/src/server/routes/query_service.py @@ -4,6 +4,7 @@ from routes.session_store import load_item from routes.query.pipeline import cross_modal_retrieval +from routes.query.enrichment import enrichment_study from routes.query.align import MSAResult @@ -58,3 +59,62 @@ def query_item(): return jsonify({"error": str(e)}), 500 return jsonify(msa_result.to_dict()), 200 + + +@blp_query_item.get("/api/enrichment") +def enrichment(): + """ + Endpoint to run annotation enrichment on nearest neighbors. 
+ """ + session_id = request.args.get("sessionId", "").strip() + item_id = request.args.get("itemId", "").strip() + if not session_id: + return jsonify({"error": "Missing sessionId"}), 400 + if not item_id: + return jsonify({"error": "Missing itemId"}), 400 + + threshold_raw = request.args.get("thresholdPct", "80").strip() + try: + threshold_pct = float(threshold_raw) + except ValueError: + return jsonify({"error": "thresholdPct must be a number"}), 400 + if threshold_pct < 0 or threshold_pct > 100: + return jsonify({"error": "thresholdPct must be between 0 and 100"}), 400 + + query_against_clusters = request.args.get("queryAgainstClusters", "true").lower() == "true" + query_against_compounds = request.args.get("queryAgainstCompounds", "true").lower() == "true" + if not query_against_clusters and not query_against_compounds: + return jsonify({"error": "At least one of queryAgainstClusters or queryAgainstCompounds must be true"}), 400 + + # Retrieve item from session store + current_app.logger.info(f"Retrieving enrichment item: session_id={session_id} item_id={item_id}") + item = load_item(session_id, item_id) + if item is None: + return jsonify({"error": "Item not found"}), 404 + current_app.logger.info( + f"Loaded item for enrichment: session_id={session_id} item_id={item_id} kind={item.get('kind')}" + ) + + payload_type = item.get("kind", None) + payload_blob = item.get("payload", None) + if not payload_type or payload_type not in ["cluster", "compound"]: + current_app.logger.error(f"Invalid item kind for enrichment: {payload_type}") + return jsonify({"error": "Invalid item kind"}), 400 + if not payload_blob: + current_app.logger.error("Missing item payload for enrichment") + return jsonify({"error": "Missing item payload"}), 400 + + try: + current_app.logger.info(f"Starting enrichment study for item_id={item_id}") + result = enrichment_study( + payload_type=payload_type, + payload_blob=payload_blob, + query_against_clusters=query_against_clusters, + 
query_against_compounds=query_against_compounds, + threshold_pct=threshold_pct, + ) + except ValueError as e: + current_app.logger.error(f"Error during enrichment study: {e}") + return jsonify({"error": str(e)}), 500 + + return jsonify(result), 200 From f013b9af869fa44e5899a15edf5130ee79038794 Mon Sep 17 00:00:00 2001 From: David Meijer Date: Thu, 22 Jan 2026 00:38:38 +0100 Subject: [PATCH 28/34] DOC: update documentation upload tab --- .../workspace/tabs/upload/WorkspaceUpload.tsx | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/client/src/components/workspace/tabs/upload/WorkspaceUpload.tsx b/src/client/src/components/workspace/tabs/upload/WorkspaceUpload.tsx index 77f9a5e..887b630 100644 --- a/src/client/src/components/workspace/tabs/upload/WorkspaceUpload.tsx +++ b/src/client/src/components/workspace/tabs/upload/WorkspaceUpload.tsx @@ -286,9 +286,19 @@ export const WorkspaceUpload: React.FC = ({ session, setSe color={(theme.vars || theme).palette.primary.main} sx={{ fontWeight: "500" }} > - Discovery tab + Discovery - . A maximum of {MAX_ITEMS} items can be imported into the workspace. Keep an eye on for updates on your queries. +  and  + + Enrichment + +  tabs. A maximum of {MAX_ITEMS} items can be imported into the workspace. Keep an eye on for updates on your queries. 
From 2eca65dc96646a01f0f8774c18609b1d301dfc74 Mon Sep 17 00:00:00 2001 From: David Meijer Date: Mon, 26 Jan 2026 11:44:21 +0100 Subject: [PATCH 29/34] FIX: correct import of type --- .../tabs/discovery/WorkspaceDiscovery.tsx | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/src/client/src/components/workspace/tabs/discovery/WorkspaceDiscovery.tsx b/src/client/src/components/workspace/tabs/discovery/WorkspaceDiscovery.tsx index f82b7d6..6b78871 100644 --- a/src/client/src/components/workspace/tabs/discovery/WorkspaceDiscovery.tsx +++ b/src/client/src/components/workspace/tabs/discovery/WorkspaceDiscovery.tsx @@ -19,7 +19,8 @@ import { useNotifications } from "../../NotificationProvider"; import { Link as RouterLink } from "react-router-dom"; import { Session } from "../../../../features/session/types"; import { Select } from "@mui/material"; -import { QueryResult, QueryResultView } from "./QueryResultView"; +import { QueryResultView } from "./QueryResultView"; +import { QueryResult } from "./types"; type WorkspaceDiscoveryProps = { session: Session; @@ -51,16 +52,6 @@ export const WorkspaceDiscovery: React.FC = ({ session, [setSession] ); - // Helper to build deps for import service - const deps = React.useMemo( - () => ({ - setSession: setSessionSafe, - pushNotification, - sessionId: session.sessionId, - }), - [setSessionSafe, pushNotification, session.sessionId] - ); - // Memoized alert based on query state const alert = React.useMemo(() => { if (queryError) { @@ -80,6 +71,7 @@ export const WorkspaceDiscovery: React.FC = ({ session, const params = new URLSearchParams({ sessionId: session.sessionId, itemId, + queryAgainstUserUploads: String(queryAgainstUserUploads), queryAgainstCompounds: String(queryAgainstCompounds), queryAgainstClusters: String(queryAgainstClusters) }); @@ -206,7 +198,7 @@ export const WorkspaceDiscovery: React.FC = ({ session, control={ setQueryAgainstClusters(e.target.checked)} /> } @@ -216,7 +208,7 @@ 
export const WorkspaceDiscovery: React.FC = ({ session, control={ setQueryAgainstCompounds(e.target.checked)} /> } From aaf3e4c264b8b074004b38f6e6e5c41606b21157 Mon Sep 17 00:00:00 2001 From: David Meijer Date: Mon, 26 Jan 2026 12:31:56 +0100 Subject: [PATCH 30/34] UPD: include user uploads in alignment; refactor look --- .../tabs/discovery/WorkspaceDiscovery.tsx | 26 +++- src/server/routes/query/align.py | 112 +++++++++++++++--- src/server/routes/query/pipeline.py | 76 ++++++++++-- src/server/routes/query_service.py | 28 ++++- 4 files changed, 210 insertions(+), 32 deletions(-) diff --git a/src/client/src/components/workspace/tabs/discovery/WorkspaceDiscovery.tsx b/src/client/src/components/workspace/tabs/discovery/WorkspaceDiscovery.tsx index 6b78871..b87a4d0 100644 --- a/src/client/src/components/workspace/tabs/discovery/WorkspaceDiscovery.tsx +++ b/src/client/src/components/workspace/tabs/discovery/WorkspaceDiscovery.tsx @@ -43,6 +43,7 @@ export const WorkspaceDiscovery: React.FC = ({ session, // Query settings const [queryAgainstCompounds, setQueryAgainstCompounds] = React.useState(true); const [queryAgainstClusters, setQueryAgainstClusters] = React.useState(true); + const [queryAgainstUserUploads, setQueryAgainstUserUploads] = React.useState(false); // Wrap parent setter (Session | null) into the deps shape (Session-only functional updater) const setSessionSafe = React.useCallback( @@ -134,7 +135,7 @@ export const WorkspaceDiscovery: React.FC = ({ session, > Upload tab -  for cross-modal retrieval against the BioNexus database. +  for cross-modal retrieval against the BioNexus database or other items in your workspace. 
@@ -198,7 +199,10 @@ export const WorkspaceDiscovery: React.FC = ({ session, control={ setQueryAgainstClusters(e.target.checked)} /> } @@ -208,12 +212,28 @@ export const WorkspaceDiscovery: React.FC = ({ session, control={ setQueryAgainstCompounds(e.target.checked)} /> } label="Query against compounds" /> + setQueryAgainstUserUploads(e.target.checked)} + /> + } + label="Query against user uploads" + /> diff --git a/src/server/routes/query/align.py b/src/server/routes/query/align.py index 890852f..8ad0e98 100644 --- a/src/server/routes/query/align.py +++ b/src/server/routes/query/align.py @@ -87,7 +87,7 @@ class MSARow: :var sequence: list of MSASequenceBlocks in the row :var alignment_score: optional alignment score :var cosine_score: optional cosine similarity score - :var match_score: ratio of items aligned against target + :var match_score: ratio of item tokens visible in the alignment :var id: unique identifier """ @@ -146,7 +146,7 @@ def from_alignment( # rows[0] is query_readout, rows[1:] correspond to retrieved_readouts query_readout: SequenceItemReadout, retrieved_readouts: list[SequenceItemReadout], - retrieved_items: list[Compound | CandidateCluster], + retrieved_items: list[Compound | CandidateCluster | None], retrieved_alignment_scores: list[float], retrieved_cosine_scores: list[float], retrieved_match_scores: list[float], @@ -155,6 +155,8 @@ def from_alignment( gap_repr: str, display_name_unidentified: str, gap_display_name: str, + retrieved_row_names: list[str | None] | None = None, + query_name: str | None = None, ) -> "MSAResult": """ Create an MSAResult from alignment data. 
@@ -163,7 +165,7 @@ def from_alignment( :param block_maps: per-row ownership of each global col -> block_idx or None :param query_readout: SequenceItemReadout for the query row :param retrieved_readouts: list of SequenceItemReadouts for retrieved rows - :param retrieved_items: list of retrieved Compound or CandidateCluster items + :param retrieved_items: list of retrieved items (db models or None for uploads) :param retrieved_alignment_scores: list of alignment scores for retrieved rows :param retrieved_cosine_scores: list of cosine similarity scores for retrieved rows :param retrieved_match_scores: list of match scores for retrieved rows @@ -171,6 +173,8 @@ def from_alignment( :param gap_repr: string representation used for gaps in the alignment :param display_name_unidentified: display name for unidentified items :param gap_display_name: display name for gaps + :param retrieved_row_names: optional list of display names for retrieved rows + :param query_name: optional display name for the query row :return: constructed MSAResult """ if not rows: @@ -238,8 +242,9 @@ def from_alignment( sequence=current_tokens, )) + display_query_name = (query_name or "").strip() query_row = MSARow( - name="Query", + name=f"Query: {display_query_name}" if display_query_name else "Query", kind=None, db_id=None, alignment_score=None, @@ -267,6 +272,9 @@ def from_alignment( if len(retrieved_match_scores) != len(rows) - 1: raise ValueError("retrieved_match_scores/rows length mismatch") + + if retrieved_row_names is not None and len(retrieved_row_names) != len(rows) - 1: + raise ValueError("retrieved_row_names/rows length mismatch") for ridx in range(1, len(rows)): row_tokens = rows[ridx] @@ -311,21 +319,29 @@ def from_alignment( sequence=current_tokens, )) - # Retrieve references - with SessionLocal() as session: - item_type = Compound if isinstance(item, Compound) else CandidateCluster - refs = get_references(session, item_type, item.id) + row_name = None + if retrieved_row_names is 
not None: + row_name = retrieved_row_names[ridx - 1] - if refs: - name = refs[0].name - else: - if isinstance(item, Compound): - name = f"Compound {item.id}" + if row_name is None and item is not None: + # Retrieve references + with SessionLocal() as session: + item_type = Compound if isinstance(item, Compound) else CandidateCluster + refs = get_references(session, item_type, item.id) + + if refs: + row_name = refs[0].name else: - name = f"Cluster {item.file_name}" + if isinstance(item, Compound): + row_name = f"Compound {item.id}" + else: + row_name = f"Cluster {item.file_name}" + + if row_name is None: + row_name = "Uploaded item" result.msa.append(MSARow( - name=name, + name=row_name, kind=readout.kind, db_id=readout.db_id, alignment_score=retrieved_alignment_scores[ridx - 1], @@ -459,6 +475,8 @@ def score_by_alignment( aln_scores: list[float] = [] match_scores: list[float] = [] + gap_repr = Gap.alignment_representation() + for item in items: aligner = _setup_aligner(query, item) @@ -466,15 +484,15 @@ def score_by_alignment( aligner=aligner, target=query.flatten_items(), candidates=item.blocks, - gap_repr=Gap.alignment_representation(), + gap_repr=gap_repr, mask_repr=Mask.alignment_representation(), allow_block_reverse=True, ) # Get full length of item cum_len = len(item.flatten_items()) - aligned_items = cum_len - sum(len(item.blocks[block_idx]) for block_idx in aln.unused_blocks) - match_score = aligned_items / cum_len if cum_len > 0 else 0.0 + visible_items = _count_visible_tokens_in_docking(aln, gap_repr) + match_score = visible_items / cum_len if cum_len > 0 else 0.0 # Penalize unaligned regions unaligned_items = 0 @@ -491,3 +509,61 @@ def score_by_alignment( match_scores.append(match_score) return aln_results, aln_scores, match_scores + + +def _count_visible_tokens_in_docking(docking: DockingResult, gap_repr: str) -> int: + """ + Count item tokens that will be visible in the alignment for a docking result. 
+ + This mirrors the region slicing logic used in the MSA merge and respects + collision resolution within a single row (max score wins on target columns). + """ + visible = 0 + placements = sorted(docking.placements, key=lambda p: (p.start, p.end)) + + # Count insertion tokens (unique columns, so no collision handling needed). + for placement in placements: + target_pos = -1 + in_region = False + prefix_anchor = placement.start - 1 + + for c_tok, b_tok in zip(placement.center_aln, placement.block_aln): + if c_tok != gap_repr: + target_pos += 1 + if target_pos > placement.end: + break + in_region = placement.start <= target_pos <= placement.end + else: + if (in_region or target_pos == prefix_anchor) and b_tok != gap_repr: + visible += 1 + + # Count target-column tokens with collision handling (max score wins). + score_by_pos: dict[int, float] = {} + has_token_by_pos: dict[int, bool] = {} + + for placement in placements: + target_pos = -1 + in_region = False + score = float(placement.score) + + for c_tok, b_tok in zip(placement.center_aln, placement.block_aln): + if c_tok != gap_repr: + target_pos += 1 + if target_pos > placement.end: + break + in_region = placement.start <= target_pos <= placement.end + if not in_region: + continue + + current_score = score_by_pos.get(target_pos, float("-inf")) + if score > current_score: + score_by_pos[target_pos] = score + has_token_by_pos[target_pos] = (b_tok != gap_repr) + else: + continue + + for has_token in has_token_by_pos.values(): + if has_token: + visible += 1 + + return visible diff --git a/src/server/routes/query/pipeline.py b/src/server/routes/query/pipeline.py index bd928b2..8589e7a 100644 --- a/src/server/routes/query/pipeline.py +++ b/src/server/routes/query/pipeline.py @@ -1,6 +1,7 @@ """Pipeline for cross-modal retrieval.""" import uuid +import math from dataclasses import dataclass from typing import Any @@ -50,7 +51,7 @@ def _slice_alignment_to_target_region( """ Slice (center_aln, block_aln) down to the 
columns that map to target coordinates [start, end] inclusive, while keeping insertion columns (center token == gap_repr) - that occur while inside the region. + that occur while inside the region and immediately before the region start. :param center_aln: list of hashes representing the center alignment SequenceItems :param block_aln: list of hashes representing the block alignment SequenceItems @@ -66,21 +67,21 @@ def _slice_alignment_to_target_region( target_pos = -1 in_region = False + prefix_anchor = start - 1 for c_tok, b_tok in zip(center_aln, block_aln): if c_tok != gap_repr: target_pos += 1 - in_region = (start <= target_pos <= end) + if target_pos > end: + break - if start <= target_pos <= end: + in_region = (start <= target_pos <= end) + if in_region: out_c.append(c_tok) out_b.append(b_tok) - if target_pos > end: - break - else: - if in_region: + if in_region or target_pos == prefix_anchor: out_c.append(c_tok) out_b.append(b_tok) @@ -267,11 +268,42 @@ def _project_one_placement_into_row( return rows, block_maps +def _cosine_similarity(a: list[float], b: list[float]) -> float: + """ + Compute cosine similarity between two vectors. 
+ + :param a: first vector + :param b: second vector + :return: cosine similarity in [-1, 1] + """ + if a is None or b is None: + return 0.0 + try: + if len(a) == 0 or len(b) == 0: + return 0.0 + except TypeError: + return 0.0 + if len(a) != len(b): + current_app.logger.warning("cosine similarity length mismatch: %s vs %s", len(a), len(b)) + dot = 0.0 + norm_a = 0.0 + norm_b = 0.0 + for x, y in zip(a, b): + dot += x * y + norm_a += x * x + norm_b += y * y + if norm_a <= 0.0 or norm_b <= 0.0: + return 0.0 + return dot / (math.sqrt(norm_a) * math.sqrt(norm_b)) + + def cross_modal_retrieval( payload_type: str, payload_blob: dict[str, Any], query_against_clusters: bool, query_against_compounds: bool, + user_uploads: list[dict[str, Any]] | None = None, + query_name: str | None = None, top_k: int = 20, ) -> MSAResult: """ @@ -281,6 +313,8 @@ def cross_modal_retrieval( :param payload_blob: the actual payload data :param query_against_clusters: whether to query against clusters :param query_against_compounds: whether to query against compounds + :param user_uploads: optional list of session items to include in retrieval + :param query_name: optional display name for the query row :param top_k: number of top results to return :return: MSAResult containing the retrieval results :raises ValueError: if no nearest neighbors found or alignment fails @@ -300,7 +334,8 @@ def cross_modal_retrieval( # Featurize nearest neighbors as SequenceItemReadout with cosine SCORE (1 - distance) nns_featurized: list[SequenceItemReadout] = [] nns_cosine_scores: list[float] = [] - retrieved_items: list[CandidateCluster | Compound] = [] + retrieved_items: list[CandidateCluster | Compound | None] = [] + retrieved_names: list[str | None] = [] for item, distance in nns: assert isinstance(item, (CandidateCluster, Compound)), "expected item to be CandidateCluster or Compound" item_type = "cluster" if isinstance(item, CandidateCluster) else "compound" @@ -317,6 +352,28 @@ def cross_modal_retrieval( 
nns_featurized.append(item_readout) nns_cosine_scores.append(1.0 - distance) retrieved_items.append(item) + retrieved_names.append(None) + + # Include user uploads (session items) if provided + for upload in user_uploads or []: + if not isinstance(upload, dict): + continue + upload_kind = upload.get("kind") + upload_payload = upload.get("payload") + if upload_kind not in ("cluster", "compound"): + continue + if not upload_payload: + continue + try: + upload_vec, upload_readout = featurize_item(upload_kind, upload_payload) + except Exception as exc: + current_app.logger.warning("failed to featurize user upload: %s", exc) + continue + cosine_score = _cosine_similarity(query_vec, upload_vec) + nns_featurized.append(upload_readout) + nns_cosine_scores.append(cosine_score) + retrieved_items.append(None) + retrieved_names.append(upload.get("name") or "Uploaded item") if not nns_featurized or not query_blocks: raise ValueError("no nearest neighbors found or query blocks are empty") @@ -334,6 +391,7 @@ def cross_modal_retrieval( top_k_nns_featurized = [nns_featurized[i] for i in top_k_indices] top_k_retrieved_items = [retrieved_items[i] for i in top_k_indices] + top_k_retrieved_names = [retrieved_names[i] for i in top_k_indices] top_k_aln_results = [aln_results[i] for i in top_k_indices] top_k_aln_scores = [aln_scores[i] for i in top_k_indices] top_k_cosine_scores = [nns_cosine_scores[i] for i in top_k_indices] @@ -358,10 +416,12 @@ def cross_modal_retrieval( retrieved_alignment_scores=top_k_aln_scores, retrieved_cosine_scores=top_k_cosine_scores, retrieved_match_scores=top_k_match_scores, + retrieved_row_names=top_k_retrieved_names, label_fn=item_label_fn, gap_repr=Gap.alignment_representation(), display_name_unidentified=DISPLAY_NAME_UNIDENTIFIED, gap_display_name=Gap().display_name, + query_name=query_name, ) return msa_result diff --git a/src/server/routes/query_service.py b/src/server/routes/query_service.py index d94e4a3..db4159d 100644 --- 
a/src/server/routes/query_service.py +++ b/src/server/routes/query_service.py @@ -2,7 +2,7 @@ from flask import Blueprint, current_app, request, jsonify -from routes.session_store import load_item +from routes.session_store import load_item, load_session_with_items from routes.query.pipeline import cross_modal_retrieval from routes.query.enrichment import enrichment_study from routes.query.align import MSAResult @@ -25,10 +25,14 @@ def query_item(): query_against_clusters = request.args.get("queryAgainstClusters", "true").lower() == "true" query_against_compounds = request.args.get("queryAgainstCompounds", "true").lower() == "true" + query_against_user_uploads = request.args.get("queryAgainstUserUploads", "false").lower() == "true" current_app.logger.debug(f"query_against_compounds: {query_against_compounds}") current_app.logger.debug(f"query_against_clusters: {query_against_clusters}") - if not query_against_clusters and not query_against_compounds: - return jsonify({"error": "At least one of queryAgainstClusters or queryAgainstCompounds must be true"}), 400 + current_app.logger.debug(f"query_against_user_uploads: {query_against_user_uploads}") + if not query_against_clusters and not query_against_compounds and not query_against_user_uploads: + return jsonify({ + "error": "At least one of queryAgainstClusters, queryAgainstCompounds, or queryAgainstUserUploads must be true" + }), 400 # Retrieve item from session store current_app.logger.info(f"Retrieving query item: session_id={session_id} item_id={item_id}") @@ -46,6 +50,22 @@ def query_item(): current_app.logger.error("Missing item payload for querying") return jsonify({"error": "Missing item payload"}), 400 + user_uploads: list[dict] = [] + if query_against_user_uploads: + session_blob = load_session_with_items(session_id) + for candidate in (session_blob or {}).get("items", []) or []: + if not isinstance(candidate, dict): + continue + if candidate.get("id") == item_id: + continue + if candidate.get("status") != 
"done": + continue + if candidate.get("kind") not in ("cluster", "compound"): + continue + if not candidate.get("payload"): + continue + user_uploads.append(candidate) + try: current_app.logger.info(f"Starting cross-modal retrieval for item_id={item_id}") msa_result: MSAResult = cross_modal_retrieval( @@ -53,6 +73,8 @@ def query_item(): payload_blob=payload_blob, query_against_clusters=query_against_clusters, query_against_compounds=query_against_compounds, + user_uploads=user_uploads, + query_name=item.get("name"), ) except ValueError as e: current_app.logger.error(f"Error during cross-modal retrieval: {e}") From ffa7dc42c862523b98ddf089b3ee14612e0294df Mon Sep 17 00:00:00 2001 From: David Meijer Date: Mon, 26 Jan 2026 12:39:07 +0100 Subject: [PATCH 31/34] ENH: increased match score for overlapping tokens --- src/server/routes/query/align.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/server/routes/query/align.py b/src/server/routes/query/align.py index 8ad0e98..0405021 100644 --- a/src/server/routes/query/align.py +++ b/src/server/routes/query/align.py @@ -388,7 +388,7 @@ def item_compare_fn(a: SequenceItem, b: SequenceItem) -> float: anc_tok_differs = a_anc_toks.symmetric_difference(b_anc_toks) tok_overlap = fam_tok_overlap.union(anc_tok_overlap) - score += 0.5 * len(tok_overlap) + score += 1.0 * len(tok_overlap) tok_differs = fam_tok_differs.union(anc_tok_differs) score -= 0.5 * len(tok_differs) From 803b13b644acecf4331872d213c4e700dcfdbac4 Mon Sep 17 00:00:00 2001 From: David Meijer Date: Tue, 27 Jan 2026 11:28:24 +0100 Subject: [PATCH 32/34] UPD: add outlinks for mibig --- .../src/components/workspace/tabs/discovery/DialogRowInfo.tsx | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/client/src/components/workspace/tabs/discovery/DialogRowInfo.tsx b/src/client/src/components/workspace/tabs/discovery/DialogRowInfo.tsx index f7304f4..4feca8f 100644 --- a/src/client/src/components/workspace/tabs/discovery/DialogRowInfo.tsx +++ 
b/src/client/src/components/workspace/tabs/discovery/DialogRowInfo.tsx @@ -33,6 +33,8 @@ function referenceToUrl(ref: Reference): string | null { const { database_name, database_identifier } = ref; switch (database_name.toLowerCase()) { + case "mibig": + return `https://mibig.secondarymetabolites.org/repository/${database_identifier}`; case "npatlas": return `https://www.npatlas.org/explore/compounds/${database_identifier}` default: From 0f55fc6d411edece81c4f4340a2fc6d5983dc8ab Mon Sep 17 00:00:00 2001 From: David Meijer Date: Sat, 31 Jan 2026 07:44:09 +0100 Subject: [PATCH 33/34] WIP --- src/client/src/features/jobs/api.ts | 2 +- src/server/requirements.backend.txt | 3 +- src/server/routes/cluster.py | 2 +- src/server/routes/query/align.py | 36 ++++++++++- src/server/routes/query/pipeline.py | 96 ++++++++++++----------------- src/server/routes/query/retrieve.py | 15 ++++- 6 files changed, 89 insertions(+), 65 deletions(-) diff --git a/src/client/src/features/jobs/api.ts b/src/client/src/features/jobs/api.ts index 666a74f..4dab15f 100644 --- a/src/client/src/features/jobs/api.ts +++ b/src/client/src/features/jobs/api.ts @@ -4,7 +4,7 @@ import type { Session, SessionItem, CompoundItem, ClusterItem } from "../session import { saveSession } from "../session/api"; import { z } from "zod"; -export const MAX_ITEMS = 20; +export const MAX_ITEMS = 50; const SubmitJobRespSchema = z.object({ ok: z.boolean(), diff --git a/src/server/requirements.backend.txt b/src/server/requirements.backend.txt index 12971a2..11e8e37 100644 --- a/src/server/requirements.backend.txt +++ b/src/server/requirements.backend.txt @@ -12,4 +12,5 @@ biocracker bionexus versalign scikit-learn -umap-learn \ No newline at end of file +umap-learn +tqdm diff --git a/src/server/routes/cluster.py b/src/server/routes/cluster.py index ac6c789..0d2df50 100644 --- a/src/server/routes/cluster.py +++ b/src/server/routes/cluster.py @@ -18,7 +18,7 @@ blp_submit_cluster = Blueprint("submit_cluster", __name__) 
-MAX_ITEMS = int(os.getenv("MAX_ITEMS", "20")) +MAX_ITEMS = int(os.getenv("MAX_ITEMS", "50")) def _set_item_status_inplace(item: dict, status: str, error_message: str | None = None) -> None: diff --git a/src/server/routes/query/align.py b/src/server/routes/query/align.py index 0405021..1c0db77 100644 --- a/src/server/routes/query/align.py +++ b/src/server/routes/query/align.py @@ -119,6 +119,40 @@ def to_dict(self) -> dict[str, Any]: } +def _merge_gap_columns_between_same_block( + block_map: list[int | None], +) -> list[int | None]: + """ + Treat gap-only columns between identical block indices as belonging to that block. + + This prevents one block from being split into multiple subseqs when only gaps + separate its aligned segments. + """ + if not block_map: + return block_map + + next_non_none: list[int | None] = [None] * len(block_map) + next_bidx: int | None = None + for idx in range(len(block_map) - 1, -1, -1): + bidx = block_map[idx] + if bidx is not None: + next_bidx = bidx + next_non_none[idx] = next_bidx + + merged: list[int | None] = [None] * len(block_map) + current_bidx: int | None = None + for idx, bidx in enumerate(block_map): + if bidx is None: + if current_bidx is not None and next_non_none[idx] == current_bidx: + merged[idx] = current_bidx + else: + merged[idx] = None + else: + current_bidx = bidx + merged[idx] = bidx + + return merged + @dataclass(frozen=True) class MSAResult: """ @@ -278,7 +312,7 @@ def from_alignment( for ridx in range(1, len(rows)): row_tokens = rows[ridx] - block_map = block_maps[ridx] + block_map = _merge_gap_columns_between_same_block(block_maps[ridx]) readout = retrieved_readouts[ridx - 1] item = retrieved_items[ridx - 1] diff --git a/src/server/routes/query/pipeline.py b/src/server/routes/query/pipeline.py index 8589e7a..110b023 100644 --- a/src/server/routes/query/pipeline.py +++ b/src/server/routes/query/pipeline.py @@ -2,8 +2,7 @@ import uuid import math -from dataclasses import dataclass -from typing import Any +from 
typing import Any, Sequence from flask import current_app @@ -24,23 +23,6 @@ warnings.simplefilter("ignore", BiopythonDeprecationWarning) -@dataclass(frozen=True) -class _InsKey: - """ - Key to uniquely identify an insertion column in docking results. - - :var result_idx: index of the docking result - :var placement_idx: index of the placement within the docking result - :var col_in_region: column index within the insertion region - :var anchor: target position anchor for the insertion - """ - - result_idx: int - placement_idx: int - col_in_region: int - anchor: int # insertion occurs AFTER this target position; -1 means before target[0] - - def _slice_alignment_to_target_region( center_aln: list[str], block_aln: list[str], @@ -107,26 +89,17 @@ def merge_dockings_into_global_alignment( n = len(target) - # Collect ALL insertion columns across all dockings, anchored to a target boundary - # anchor = j means "insertion column occurs after target position j" + # Collect insertion lengths across all dockings, anchored to a target boundary + # anchor = j means "insertion occurs after target position j" # anchor = -1 means "before target[0]" - insertions_by_anchor: dict[int, list[_InsKey]] = {a: [] for a in range(-1, n)} - # We also need to later map each insertion column identity -> global column index - inskey_to_global_col: dict[_InsKey, int] = {} + insertion_lengths: dict[int, int] = {a: 0 for a in range(-1, n)} - # Also precompute mapping of target positions -> global columns (once built) + # Precompute mapping of target positions -> global columns (once built) targetpos_to_global_col: dict[int, int] = {} - # To place insertions deterministically, we'll sort them by: - # (result_idx, placement start, placement_idx, col_in_region) - # We need placement start; capture it in a side map - placement_start: dict[tuple[int, int], int] = {} - - for ri, dr in enumerate(dockings): + for dr in dockings: placements = sorted(dr.placements, key=lambda p: (p.start, p.end)) - for 
pi, p in enumerate(placements): - placement_start[(ri, pi)] = p.start - + for p in placements: reg_center, reg_block = _slice_alignment_to_target_region( center_aln=p.center_aln, block_aln=p.block_aln, @@ -139,34 +112,27 @@ def merge_dockings_into_global_alignment( # We anchor insertion columns to "after the last consumed target position" # Initialize target_pos to p.start - 1 so that the first consumed target sets it to p.start tpos = p.start - 1 - for ci, c_tok in enumerate(reg_center): + insertion_offset = 0 + for c_tok in reg_center: if c_tok != gap_repr: tpos += 1 + insertion_offset = 0 else: # Insertion after tpos (which is in [p.start-1, .. p.end-1]) # If tpos == p.start-1, that's an insertion before the first consumed symbol in the region anchor = tpos if anchor < -1: anchor = -1 if anchor > n - 1: anchor = n - 1 - insertions_by_anchor[anchor].append(_InsKey(ri, pi, ci, anchor=anchor)) - - # Sort insertions at each anchor deterministically - for anchor, keys in insertions_by_anchor.items(): - keys.sort( - key=lambda k: ( - k.result_idx, - placement_start.get((k.result_idx, k.placement_idx), 10**9), - k.placement_idx, - k.col_in_region, - ) - ) + insertion_lengths[anchor] = max(insertion_lengths[anchor], insertion_offset + 1) + insertion_offset += 1 # Build the global aligned center, assigning global column indices aligned_center: list[str] = [] + insertion_cols_by_anchor: dict[int, list[int]] = {a: [] for a in range(-1, n)} # Insertions before target[0] (anchor -1) - for k in insertions_by_anchor[-1]: - inskey_to_global_col[k] = len(aligned_center) + for _ in range(insertion_lengths[-1]): + insertion_cols_by_anchor[-1].append(len(aligned_center)) aligned_center.append(gap_repr) # For each target pos j: emit target[j], then insertion anchored at j @@ -174,8 +140,8 @@ def merge_dockings_into_global_alignment( targetpos_to_global_col[j] = len(aligned_center) aligned_center.append(target[j]) - for k in insertions_by_anchor[j]: - inskey_to_global_col[k] = 
len(aligned_center) + for _ in range(insertion_lengths[j]): + insertion_cols_by_anchor[j].append(len(aligned_center)) aligned_center.append(gap_repr) aligned_target = aligned_center[:] # same content; separate name for readability @@ -209,20 +175,25 @@ def _project_one_placement_into_row( ) tpos = p.start - 1 - for ci, (c_tok, b_tok) in enumerate(zip(reg_center, reg_block)): + insertion_offset = 0 + for c_tok, b_tok in zip(reg_center, reg_block): if c_tok != gap_repr: tpos += 1 gcol = targetpos_to_global_col[tpos] + insertion_offset = 0 else: # tpos should naturally be in [p.start-1, .. p.end-1], so we don't need clamping here if not (-1 <= tpos <= n - 1): raise ValueError("unexpected target position for insertion column") - key = _InsKey(ri, pi, ci, anchor=tpos) - # Because we sorted+assigned by identity, this must exist - gcol = inskey_to_global_col.get(key, None) - if gcol is None: - # Extremely defensive fallback: skip if we somehow didn't register it + anchor = tpos + if anchor < -1: anchor = -1 + if anchor > n - 1: anchor = n - 1 + cols = insertion_cols_by_anchor[anchor] + if insertion_offset >= len(cols): + insertion_offset += 1 continue + gcol = cols[insertion_offset] + insertion_offset += 1 if b_tok == gap_repr: # Keep block ownership for gap columns, but never override a real token @@ -304,7 +275,10 @@ def cross_modal_retrieval( query_against_compounds: bool, user_uploads: list[dict[str, Any]] | None = None, query_name: str | None = None, - top_k: int = 20, + top_k: int = 18, + ann_search_limit: int | None = None, + cluster_where: Sequence[Any] | None = None, + compound_where: Sequence[Any] | None = None, ) -> MSAResult: """ Perform cross-modal retrieval given an item payload. 
@@ -316,6 +290,9 @@ def cross_modal_retrieval( :param user_uploads: optional list of session items to include in retrieval :param query_name: optional display name for the query row :param top_k: number of top results to return + :param ann_search_limit: optional override for ANN search radius + :param cluster_where: optional extra filters for cluster ANN query + :param compound_where: optional extra filters for compound ANN query :return: MSAResult containing the retrieval results :raises ValueError: if no nearest neighbors found or alignment fails """ @@ -328,6 +305,9 @@ def cross_modal_retrieval( query_vec, query_against_clusters=query_against_clusters, query_against_compounds=query_against_compounds, + limit=ann_search_limit, + cluster_where=cluster_where, + compound_where=compound_where, ) current_app.logger.debug(f"found {len(nns)} nearest neighbors") diff --git a/src/server/routes/query/retrieve.py b/src/server/routes/query/retrieve.py index af1d25a..f0842d3 100644 --- a/src/server/routes/query/retrieve.py +++ b/src/server/routes/query/retrieve.py @@ -71,6 +71,9 @@ def ann_search( query_vec: list[float], query_against_clusters: bool, query_against_compounds: bool, + cluster_where: Sequence[Any] | None = None, + compound_where: Sequence[Any] | None = None, + limit: int | None = None, ) -> list[tuple[CandidateCluster | Compound, float]]: """ Perform an approximate nearest neighbor search against clusters and/or compounds. 
@@ -78,13 +81,19 @@ def ann_search( :param query_vec: the query vector :param query_against_clusters: whether to query against candidate clusters :param query_against_compounds: whether to query against compounds + :param cluster_where: optional extra filters for cluster query + :param compound_where: optional extra filters for compound query + :param limit: optional override for total ANN search radius :return: a list of tuples of (model instance, distance) """ if not query_against_clusters and not query_against_compounds: return [] only_one = query_against_clusters ^ query_against_compounds - per_type_limit = ANN_SEARCH_RADIUS if only_one else ANN_SEARCH_RADIUS // 2 + effective_limit = ANN_SEARCH_RADIUS if limit is None else max(1, int(limit)) + per_type_limit = effective_limit if only_one else max(1, effective_limit // 2) + cluster_filters = list(cluster_where or []) + compound_filters = list(compound_where or []) with SessionLocal() as session: _set_local(session, HNSW_SETTINGS) @@ -95,7 +104,7 @@ def ann_search( model=CandidateCluster, vector_col=CandidateCluster.retromol_fp_counted_by_region, query_vec=query_vec, - where=[CandidateCluster.retromol_fp_counted_by_region.is_not(None)], + where=[CandidateCluster.retromol_fp_counted_by_region.is_not(None), *cluster_filters], limit=per_type_limit, ) if query_against_clusters @@ -108,7 +117,7 @@ def ann_search( model=Compound, vector_col=Compound.retromol_fp_counted, query_vec=query_vec, - where=[Compound.retromol_fp_counted.is_not(None)], + where=[Compound.retromol_fp_counted.is_not(None), *compound_filters], limit=per_type_limit, ) if query_against_compounds From e0b6ff4634fe31085c957390241a4713cf48dfee Mon Sep 17 00:00:00 2001 From: David Meijer Date: Thu, 5 Feb 2026 10:26:23 -0500 Subject: [PATCH 34/34] FIX: install RDKit X11 runtime libs in backend image --- docker/backend.Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/backend.Dockerfile b/docker/backend.Dockerfile index 
1a00f4b..eb2ec63 100644 --- a/docker/backend.Dockerfile +++ b/docker/backend.Dockerfile @@ -15,7 +15,7 @@ RUN groupadd --gid $USER_GID $USERNAME \ WORKDIR /app # System deps (psycopg binary + git) -RUN apt-get update && apt-get install -y --no-install-recommends build-essential libpq-dev git && rm -rf /var/lib/apt/lists/* +RUN apt-get update && apt-get install -y --no-install-recommends build-essential libpq-dev git libxrender1 libxext6 libsm6 && rm -rf /var/lib/apt/lists/* # Copy env + requirements before env creation for caching COPY src/server/environment.backend.yml /app/