diff --git a/apps/web/src/components/chat/ModelPickerContent.tsx b/apps/web/src/components/chat/ModelPickerContent.tsx index e0d6b9afc5e..ecd440cf860 100644 --- a/apps/web/src/components/chat/ModelPickerContent.tsx +++ b/apps/web/src/components/chat/ModelPickerContent.tsx @@ -5,12 +5,27 @@ import { } from "@t3tools/contracts"; import { resolveSelectableModel } from "@t3tools/shared/model"; import { LegendList, type LegendListRef } from "@legendapp/list/react"; -import { memo, useMemo, useState, useCallback, useEffect, useLayoutEffect, useRef } from "react"; +import { + memo, + useMemo, + useState, + useCallback, + useEffect, + useLayoutEffect, + useRef, + forwardRef, + type ChangeEvent, + type KeyboardEvent, +} from "react"; import { SearchIcon } from "lucide-react"; import { ModelListRow } from "./ModelListRow"; import { ModelPickerSidebar } from "./ModelPickerSidebar"; import { isModelPickerNewModel } from "./modelPickerModelHighlights"; -import { buildModelPickerSearchText, scoreModelPickerSearch } from "./modelPickerSearch"; +import { + buildModelPickerSearchText, + normalizeModelPickerSearchQuery, + scoreModelPickerSearch, +} from "./modelPickerSearch"; import { Combobox, ComboboxEmpty, ComboboxInput, ComboboxListVirtualized } from "../ui/combobox"; import { ModelEsque } from "./providerIconUtils"; import { @@ -43,6 +58,10 @@ type ModelPickerItem = { const EMPTY_MODEL_JUMP_LABELS = new Map(); +const SEARCH_ICON = ( + +); + // Split a `${instanceId}:${slug}` combobox key back into its pieces. Slugs // can contain colons (e.g. some vendor model ids), so we only split on the // first colon — anything after that is the slug. @@ -57,6 +76,111 @@ function splitInstanceModelKey(key: string): { instanceId: ProviderInstanceId; s }; } +const ModelPickerSearchInput = memo( + forwardRef< + HTMLInputElement, + { + onSearchQueryChange: (query: string) => void; + onRequestClose: (() => void) | undefined; + onSelectHighlightedModel: () => boolean; + } + >(function ModelPickerSearchInput( + { onSearchQueryChange, onRequestClose, onSelectHighlightedModel }, + ref, + ) { + const [rawSearchQuery, setRawSearchQuery] = useState(""); + const handleChange = useCallback( + (event: ChangeEvent) => { + const nextRawSearchQuery = event.target.value; + setRawSearchQuery(nextRawSearchQuery); + onSearchQueryChange(normalizeModelPickerSearchQuery(nextRawSearchQuery)); + }, + [onSearchQueryChange], + ); + const handleKeyDown = useCallback( + (event: KeyboardEvent) => { + if (event.key === "Escape") { + event.preventDefault(); + event.stopPropagation(); + onRequestClose?.(); + return; + } + if (event.key === "Enter" && onSelectHighlightedModel()) { + (event as typeof event & { preventBaseUIHandler?: () => void }).preventBaseUIHandler?.(); + event.preventDefault(); + event.stopPropagation(); + return; + } + event.stopPropagation(); + }, + [onRequestClose, onSelectHighlightedModel], + ); + + return ( + event.stopPropagation()} + onTouchStart={(event) => event.stopPropagation()} + size="sm" + unstyled + /> + ); + }), +); + +function useModelPickerJumpShortcuts(input: { + keybindings: ResolvedKeybindingsConfig; + modelJumpShortcutContext: { + readonly terminalFocus: false; + readonly terminalOpen: boolean; + readonly modelPickerOpen: true; + }; + modelJumpModelKeys: ReadonlyArray; + handleModelSelect: (modelSlug: string, instanceId: ProviderInstanceId) => void; +}) { + const { keybindings, modelJumpShortcutContext, modelJumpModelKeys, handleModelSelect } = input; + + useEffect(() => { + const onWindowKeyDown = (event: globalThis.KeyboardEvent) => { + if (event.defaultPrevented || event.repeat) { + return; + } + + const command = resolveShortcutCommand(event, keybindings, { + platform: navigator.platform, + context: modelJumpShortcutContext, + }); + const jumpIndex = modelPickerJumpIndexFromCommand(command ?? ""); + if (jumpIndex === null) { + return; + } + + const targetModelKey = modelJumpModelKeys[jumpIndex]; + if (!targetModelKey) { + return; + } + const { instanceId, slug } = splitInstanceModelKey(targetModelKey); + event.preventDefault(); + event.stopPropagation(); + handleModelSelect(slug, instanceId); + }; + + window.addEventListener("keydown", onWindowKeyDown, true); + + return () => { + window.removeEventListener("keydown", onWindowKeyDown, true); + }; + }, [handleModelSelect, keybindings, modelJumpModelKeys, modelJumpShortcutContext]); +} + export const ModelPickerContent = memo(function ModelPickerContent(props: { /** The instance currently selected in the composer (combobox "value"). */ activeInstanceId: ProviderInstanceId; @@ -118,6 +242,9 @@ export const ModelPickerContent = memo(function ModelPickerContent(props: { [providedKeybindings], ); const updateSettings = useUpdateClientSettings(); + const handleSearchQueryChange = useCallback((nextQuery: string) => { + setSearchQuery((currentQuery) => (currentQuery === nextQuery ? currentQuery : nextQuery)); + }, []); const focusSearchInput = useCallback(() => { searchInputRef.current?.focus({ preventScroll: true }); @@ -221,7 +348,7 @@ export const ModelPickerContent = memo(function ModelPickerContent(props: { }, [modelOptionsByInstance, entryByInstanceId, readyInstanceSet]); const isLocked = props.lockedProvider !== null; - const isSearching = searchQuery.trim().length > 0; + const isSearching = searchQuery.length > 0; const lockedDisabledInstanceIds = useMemo(() => { if (!isLocked) { return undefined; @@ -261,7 +388,7 @@ export const ModelPickerContent = memo(function ModelPickerContent(props: { let result = flatModels; // Apply tokenized fuzzy search across the combined provider/model search fields. - if (searchQuery.trim()) { + if (searchQuery) { const rankedMatches = result .map((model) => ({ model, @@ -385,6 +512,14 @@ export const ModelPickerContent = memo(function ModelPickerContent(props: { }, [entryByInstanceId, getModelDisabledReason, modelOptionsByInstance, onInstanceModelChange], ); + const selectHighlightedModel = useCallback(() => { + if (!highlightedModelKeyRef.current) { + return false; + } + const { instanceId, slug } = splitInstanceModelKey(highlightedModelKeyRef.current); + handleModelSelect(slug, instanceId); + return true; + }, [handleModelSelect]); const toggleFavorite = useCallback( (instanceId: ProviderInstanceId, model: string) => { @@ -472,37 +607,12 @@ export const ModelPickerContent = memo(function ModelPickerContent(props: { return mapping.size > 0 ? mapping : EMPTY_MODEL_JUMP_LABELS; }, [keybindings, modelJumpCommandByKey, modelJumpShortcutContext]); - useEffect(() => { - const onWindowKeyDown = (event: globalThis.KeyboardEvent) => { - if (event.defaultPrevented || event.repeat) { - return; - } - - const command = resolveShortcutCommand(event, keybindings, { - platform: navigator.platform, - context: modelJumpShortcutContext, - }); - const jumpIndex = modelPickerJumpIndexFromCommand(command ?? ""); - if (jumpIndex === null) { - return; - } - - const targetModelKey = modelJumpModelKeys[jumpIndex]; - if (!targetModelKey) { - return; - } - const { instanceId, slug } = splitInstanceModelKey(targetModelKey); - event.preventDefault(); - event.stopPropagation(); - handleModelSelect(slug, instanceId); - }; - - window.addEventListener("keydown", onWindowKeyDown, true); - - return () => { - window.removeEventListener("keydown", onWindowKeyDown, true); - }; - }, [handleModelSelect, keybindings, modelJumpModelKeys, modelJumpShortcutContext]); + useModelPickerJumpShortcuts({ + keybindings, + modelJumpShortcutContext, + modelJumpModelKeys, + handleModelSelect, + }); useLayoutEffect(() => { setShowTopScrollFade(false); @@ -577,42 +687,11 @@ export const ModelPickerContent = memo(function ModelPickerContent(props: { {/* Search bar */}
- - } - value={searchQuery} - onChange={(e) => setSearchQuery(e.target.value)} - onKeyDown={(e) => { - if (e.key === "Escape") { - e.preventDefault(); - e.stopPropagation(); - props.onRequestClose?.(); - return; - } - if (e.key === "Enter" && highlightedModelKeyRef.current) { - ( - e as typeof e & { preventBaseUIHandler?: () => void } - ).preventBaseUIHandler?.(); - e.preventDefault(); - e.stopPropagation(); - const { instanceId, slug } = splitInstanceModelKey( - highlightedModelKeyRef.current, - ); - handleModelSelect(slug, instanceId); - return; - } - e.stopPropagation(); - }} - onMouseDown={(e) => e.stopPropagation()} - onTouchStart={(e) => e.stopPropagation()} - size="sm" - unstyled + onSearchQueryChange={handleSearchQueryChange} + onRequestClose={props.onRequestClose} + onSelectHighlightedModel={selectHighlightedModel} />
diff --git a/apps/web/src/components/chat/modelPickerSearch.test.ts b/apps/web/src/components/chat/modelPickerSearch.test.ts index 57a7142c6b9..987eae7fac3 100644 --- a/apps/web/src/components/chat/modelPickerSearch.test.ts +++ b/apps/web/src/components/chat/modelPickerSearch.test.ts @@ -1,6 +1,10 @@ import { describe, expect, it } from "vite-plus/test"; -import { buildModelPickerSearchText, scoreModelPickerSearch } from "./modelPickerSearch"; +import { + buildModelPickerSearchText, + normalizeModelPickerSearchQuery, + scoreModelPickerSearch, +} from "./modelPickerSearch"; describe("buildModelPickerSearchText", () => { it("builds provider-agnostic search text from generic fields", () => { @@ -15,6 +19,12 @@ describe("buildModelPickerSearchText", () => { }); }); +describe("normalizeModelPickerSearchQuery", () => { + it("normalizes casing and redundant whitespace before publishing picker searches", () => { + expect(normalizeModelPickerSearchQuery(" GPT Codex ")).toBe("gpt codex"); + }); +}); + describe("scoreModelPickerSearch", () => { it("matches typo-tolerant multi-token queries", () => { expect( diff --git a/apps/web/src/components/chat/modelPickerSearch.ts b/apps/web/src/components/chat/modelPickerSearch.ts index ff265e4a65a..2fd96acf884 100644 --- a/apps/web/src/components/chat/modelPickerSearch.ts +++ b/apps/web/src/components/chat/modelPickerSearch.ts @@ -17,6 +17,13 @@ type ModelPickerSearchableModel = { const MODEL_PICKER_FAVORITE_SCORE_BOOST = 24; +export function normalizeModelPickerSearchQuery(query: string): string { + return normalizeSearchQuery(query) + .split(/\s+/u) + .filter((token) => token.length > 0) + .join(" "); +} + function getModelPickerSearchFields(model: ModelPickerSearchableModel): string[] { return [ normalizeSearchQuery(model.name), @@ -56,9 +63,7 @@ export function scoreModelPickerSearch( model: ModelPickerSearchableModel, query: string, ): number | null { - const tokens = normalizeSearchQuery(query) - .split(/\s+/u) - .filter((token) => token.length > 0); + const tokens = normalizeModelPickerSearchQuery(query).split(/\s+/u).filter(Boolean); if (tokens.length === 0) { return 0;