import type { ExtensionAPI } from "@mariozechner/pi-coding-agent";
import {
	getDefaultWebSearchConfigPath,
	normalizeWebSearchConfig,
	readRawWebSearchConfig,
	writeWebSearchConfig,
	WebSearchConfigError,
} from "../config.ts";
import type { WebSearchConfig, WebSearchProviderConfig } from "../schema.ts";

/**
 * Partial update applied by {@link updateProviderOrThrow}.
 *
 * A field left `undefined` keeps the provider's current value. `baseUrl` is only consulted for
 * Firecrawl providers; for Firecrawl a blank `apiKey`/`baseUrl` clears that field.
 */
type ProviderPatch = {
	apiKey?: string;
	baseUrl?: string;
	fallbackProviders?: string[];
	options?: WebSearchProviderConfig["options"];
};

/** Menu labels for the "add a provider" actions, in display order. */
const ADD_PROVIDER_ACTIONS = ["Add Tavily provider", "Add Exa provider", "Add Firecrawl provider"];

/**
 * Maps an "Add … provider" menu label to the provider type it creates.
 * Any label other than the Tavily/Exa ones maps to Firecrawl.
 */
function providerTypeForAddAction(action: string): WebSearchProviderConfig["type"] {
	if (action === "Add Tavily provider") {
		return "tavily";
	}
	if (action === "Add Exa provider") {
		return "exa";
	}
	return "firecrawl";
}

/**
 * Checks the per-provider invariants a draft must satisfy before it is accepted.
 *
 * @throws Error if the name is blank, if a Firecrawl provider has neither an apiKey nor a
 *   baseUrl, or if any other provider type has a blank apiKey.
 */
function validateProviderDraftOrThrow(provider: WebSearchProviderConfig) {
	if (!provider.name.trim()) {
		throw new Error("Provider name cannot be blank.");
	}
	if (provider.type === "firecrawl") {
		const apiKey = provider.apiKey?.trim();
		const baseUrl = provider.baseUrl?.trim();
		// A self-hosted Firecrawl (custom baseUrl) may not need a key; the cloud default does.
		if (!apiKey && !baseUrl) {
			throw new Error("Firecrawl provider apiKey cannot be blank unless baseUrl is set.");
		}
		return;
	}
	if (!provider.apiKey.trim()) {
		throw new Error("Provider apiKey cannot be blank.");
	}
}

/**
 * Runs the shared config normalizer over a draft and maps its result back into the on-disk
 * `WebSearchConfig` shape (`defaultProviderName` → `defaultProvider`).
 *
 * @param path config file path, forwarded to `normalizeWebSearchConfig` (presumably for error
 *   messages — confirm in ../config.ts).
 */
function normalizeDraftConfigOrThrow(config: WebSearchConfig, path: string): WebSearchConfig {
	const normalized = normalizeWebSearchConfig(config, path);
	return {
		defaultProvider: normalized.defaultProviderName,
		providers: normalized.providers,
	};
}

/**
 * Parses a comma-separated list of provider names.
 *
 * Entries are trimmed, blank entries are dropped, and duplicates are collapsed (first occurrence
 * wins, input order is kept).
 *
 * @returns the names, or `undefined` when the list is empty.
 */
function parseFallbackProviders(value: string) {
	const items = Array.from(
		new Set(
			value
				.split(",")
				.map((item) => item.trim())
				.filter(Boolean),
		),
	);
	return items.length > 0 ? items : undefined;
}

/**
 * Parses an optional numeric prompt answer.
 *
 * Blank or cancelled (`undefined`) input yields `undefined`. Plain `Number(...)` would turn a
 * whitespace-only answer into `0` and garbage into `NaN` (not representable in JSON), so such
 * answers are rejected instead of saved. Range checks are left to the config normalizer.
 *
 * @throws Error if the answer is non-blank but not a finite number.
 */
function parseOptionalNumber(label: string, value: string | undefined): number | undefined {
	const trimmed = value?.trim();
	if (!trimmed) {
		return undefined;
	}
	const parsed = Number(trimmed);
	if (!Number.isFinite(parsed)) {
		throw new Error(`${label} must be a number, got: ${trimmed}`);
	}
	return parsed;
}

/** Renders an optional numeric option as prompt text ("" when unset). */
function formatOptionalNumber(value: number | undefined) {
	return value !== undefined ? String(value) : "";
}

/**
 * Builds a brand-new config containing a single provider, which becomes the default.
 *
 * @throws Error if the provider fails {@link validateProviderDraftOrThrow}.
 */
export function createDefaultWebSearchConfig(input: { provider: WebSearchProviderConfig }): WebSearchConfig {
	validateProviderDraftOrThrow(input.provider);
	return {
		defaultProvider: input.provider.name,
		providers: [input.provider],
	};
}

/**
 * Returns a copy of `config` whose default provider is `providerName`.
 *
 * @throws Error if no provider has that name.
 */
export function setDefaultProviderOrThrow(config: WebSearchConfig, providerName: string): WebSearchConfig {
	if (!config.providers.some((provider) => provider.name === providerName)) {
		throw new Error(`Unknown provider: ${providerName}`);
	}
	return { ...config, defaultProvider: providerName };
}

/**
 * Renames a provider and rewrites every reference to it (the default provider and all
 * `fallbackProviders` lists).
 *
 * @throws Error if `nextName` is blank, `currentName` is unknown, or `nextName` is already used by
 *   a different provider.
 */
export function renameProviderOrThrow(
	config: WebSearchConfig,
	currentName: string,
	nextName: string,
): WebSearchConfig {
	if (!nextName.trim()) {
		throw new Error("Provider name cannot be blank.");
	}
	if (!config.providers.some((provider) => provider.name === currentName)) {
		throw new Error(`Unknown provider: ${currentName}`);
	}
	if (config.providers.some((provider) => provider.name === nextName && provider.name !== currentName)) {
		throw new Error(`Duplicate provider name: ${nextName}`);
	}
	return {
		defaultProvider: config.defaultProvider === currentName ? nextName : config.defaultProvider,
		providers: config.providers.map((provider) => ({
			...provider,
			name: provider.name === currentName ? nextName : provider.name,
			fallbackProviders: provider.fallbackProviders?.map((name) => (name === currentName ? nextName : name)),
		})),
	};
}

/**
 * Applies `patch` to the named provider and returns the updated config.
 *
 * Fields absent from the patch keep their current values. API keys and (for Firecrawl) base URLs
 * are trimmed. For Firecrawl, a blank apiKey or baseUrl removes that field. For other types, a
 * blank apiKey is rejected.
 *
 * @throws Error if the provider is unknown or the result fails validation.
 */
export function updateProviderOrThrow(
	config: WebSearchConfig,
	providerName: string,
	patch: ProviderPatch,
): WebSearchConfig {
	const existing = config.providers.find((provider) => provider.name === providerName);
	if (!existing) {
		throw new Error(`Unknown provider: ${providerName}`);
	}

	let nextProvider: WebSearchProviderConfig;
	if (existing.type === "firecrawl") {
		// Blank (after trimming) means "clear this field"; undefined means "keep the current value".
		const nextBaseUrl = patch.baseUrl !== undefined ? patch.baseUrl.trim() || undefined : existing.baseUrl;
		const nextApiKey = patch.apiKey !== undefined ? patch.apiKey.trim() || undefined : existing.apiKey;
		const nextFallbackProviders = patch.fallbackProviders ?? existing.fallbackProviders;
		const nextOptions = patch.options ?? existing.options;
		nextProvider = {
			name: existing.name,
			type: existing.type,
			...(nextApiKey ? { apiKey: nextApiKey } : {}),
			...(nextBaseUrl ? { baseUrl: nextBaseUrl } : {}),
			...(nextFallbackProviders ? { fallbackProviders: nextFallbackProviders } : {}),
			...(nextOptions ? { options: nextOptions } : {}),
		};
	} else {
		if (patch.apiKey !== undefined && !patch.apiKey.trim()) {
			throw new Error("Provider apiKey cannot be blank.");
		}
		nextProvider = {
			...existing,
			apiKey: patch.apiKey !== undefined ? patch.apiKey.trim() : existing.apiKey,
			fallbackProviders: patch.fallbackProviders ?? existing.fallbackProviders,
			options: patch.options ?? existing.options,
		};
	}

	validateProviderDraftOrThrow(nextProvider);
	return {
		...config,
		providers: config.providers.map((provider) => (provider.name === providerName ? nextProvider : provider)),
	};
}

/**
 * Removes a provider and strips its name from every remaining provider's `fallbackProviders`, so
 * no dangling references are left behind (mirroring how {@link renameProviderOrThrow} rewrites
 * them).
 *
 * @throws Error if the provider is unknown, is the only provider, or is the current default.
 */
export function removeProviderOrThrow(config: WebSearchConfig, providerName: string): WebSearchConfig {
	if (!config.providers.some((provider) => provider.name === providerName)) {
		throw new Error(`Unknown provider: ${providerName}`);
	}
	if (config.providers.length === 1) {
		throw new Error("Cannot remove the last provider.");
	}
	if (config.defaultProvider === providerName) {
		throw new Error("Cannot remove the default provider before selecting a new default.");
	}
	return {
		...config,
		providers: config.providers
			.filter((provider) => provider.name !== providerName)
			.map((provider) => {
				if (!provider.fallbackProviders?.includes(providerName)) {
					return provider;
				}
				const remaining = provider.fallbackProviders.filter((name) => name !== providerName);
				return { ...provider, fallbackProviders: remaining.length > 0 ? remaining : undefined };
			}),
	};
}

/**
 * Validates `nextProvider` and inserts it. Any provider with the same name is replaced, and the
 * inserted provider is placed last.
 *
 * @throws Error if the provider fails {@link validateProviderDraftOrThrow}.
 */
function upsertProviderOrThrow(config: WebSearchConfig, nextProvider: WebSearchProviderConfig): WebSearchConfig {
	validateProviderDraftOrThrow(nextProvider);
	const withoutSameName = config.providers.filter((provider) => provider.name !== nextProvider.name);
	return {
		...config,
		providers: [...withoutSameName, nextProvider],
	};
}

/**
 * Prompts for the numeric default options that apply to the provider's type.
 *
 * Firecrawl asks only for the search limit. Tavily also asks for the fetch text limit. Any other
 * type also asks for the fetch highlights limit. Each answer is parsed as soon as it is entered,
 * and the current value is passed as the prompt's second argument.
 *
 * @returns the options object, or `undefined` when every answer was left blank.
 * @throws Error if an answer is non-blank but not a finite number.
 */
async function promptProviderOptions(ctx: any, provider: WebSearchProviderConfig) {
	const defaultSearchLimit = parseOptionalNumber(
		"Default search limit",
		await ctx.ui.input(
			`Default search limit for ${provider.name}`,
			formatOptionalNumber(provider.options?.defaultSearchLimit),
		),
	);
	if (provider.type === "firecrawl") {
		return defaultSearchLimit !== undefined ? { defaultSearchLimit } : undefined;
	}

	const defaultFetchTextMaxCharacters = parseOptionalNumber(
		"Default fetch text max characters",
		await ctx.ui.input(
			`Default fetch text max characters for ${provider.name}`,
			formatOptionalNumber(provider.options?.defaultFetchTextMaxCharacters),
		),
	);
	if (provider.type === "tavily") {
		const options = { defaultSearchLimit, defaultFetchTextMaxCharacters };
		return Object.values(options).some((value) => value !== undefined) ? options : undefined;
	}

	const defaultFetchHighlightsMaxCharacters = parseOptionalNumber(
		"Default fetch highlights max characters",
		await ctx.ui.input(
			`Default fetch highlights max characters for ${provider.name}`,
			formatOptionalNumber(provider.options?.defaultFetchHighlightsMaxCharacters),
		),
	);
	const options = { defaultSearchLimit, defaultFetchTextMaxCharacters, defaultFetchHighlightsMaxCharacters };
	return Object.values(options).some((value) => value !== undefined) ? options : undefined;
}

/**
 * Prompts for a comma-separated list of fallback provider names.
 *
 * @returns the parsed names, or `undefined` when the answer is blank or cancelled.
 */
async function promptFallbackProviders(ctx: any, provider: WebSearchProviderConfig) {
	const value = await ctx.ui.input(
		`Fallback providers for ${provider.name} (comma-separated names)`,
		(provider.fallbackProviders ?? []).join(", "),
	);
	return parseFallbackProviders(value ?? "");
}

/**
 * Interactively builds a new provider of the given type.
 *
 * @returns the provider (not yet validated), or `undefined` if the name, or the API key for
 *   non-Firecrawl types, was left empty or cancelled.
 */
async function promptNewProvider(ctx: any, type: WebSearchProviderConfig["type"]) {
	const name = await ctx.ui.input(
		"Provider name",
		type === "tavily" ? "tavily-main" : type === "exa" ? "exa-fallback" : "firecrawl-main",
	);
	if (!name) {
		return undefined;
	}

	if (type === "firecrawl") {
		const baseUrl: string | undefined = (
			await ctx.ui.input("Firecrawl base URL (blank uses cloud default)", "")
		)?.trim();
		const apiKey: string | undefined = (
			await ctx.ui.input("Firecrawl API key (blank allowed when base URL is set)", "fc-...")
		)?.trim();
		const provider: WebSearchProviderConfig = {
			name,
			type,
			...(apiKey ? { apiKey } : {}),
			...(baseUrl ? { baseUrl } : {}),
		};
		const fallbackProviders = await promptFallbackProviders(ctx, provider);
		const options = await promptProviderOptions(ctx, provider);
		return {
			...provider,
			...(fallbackProviders ? { fallbackProviders } : {}),
			...(options ? { options } : {}),
		};
	}

	const apiKey = await ctx.ui.input(
		type === "tavily" ? "Tavily API key" : "Exa API key",
		type === "tavily" ? "tvly-..." : "exa_...",
	);
	if (!apiKey) {
		return undefined;
	}
	// Trim pasted keys; a whitespace-only key is still rejected later by validateProviderDraftOrThrow.
	const provider: WebSearchProviderConfig = { name, type, apiKey: apiKey.trim() };
	const fallbackProviders = await promptFallbackProviders(ctx, provider);
	const options = await promptProviderOptions(ctx, provider);
	return {
		...provider,
		...(fallbackProviders ? { fallbackProviders } : {}),
		...(options ? { options } : {}),
	};
}

/**
 * Registers the `/web-search-config` command, which edits the web-search provider config file
 * interactively and writes the normalized result back to disk.
 */
export function registerWebSearchConfigCommand(pi: ExtensionAPI) {
	pi.registerCommand("web-search-config", {
		description: "Configure Tavily/Exa/Firecrawl providers for web_search and web_fetch",
		handler: async (_args, ctx) => {
			const path = getDefaultWebSearchConfigPath();

			// Normalizes the draft, writes it, and tells the user where it went.
			const saveConfig = async (draft: WebSearchConfig) => {
				const normalizedConfig = normalizeDraftConfigOrThrow(draft, path);
				await writeWebSearchConfig(path, normalizedConfig);
				ctx.ui.notify(`Saved web-search config to ${path}`, "info");
			};

			let config: WebSearchConfig;
			try {
				config = await readRawWebSearchConfig(path);
			} catch (error) {
				if (!(error instanceof WebSearchConfigError)) {
					throw error;
				}
				// NOTE(review): this branch runs for any WebSearchConfigError. If that includes a file
				// that exists but fails validation (not only a missing file — confirm in ../config.ts),
				// the flow below overwrites it with a fresh config.
				const createType = await ctx.ui.select("Create initial provider", [...ADD_PROVIDER_ACTIONS]);
				if (!createType) {
					return;
				}
				const provider = await promptNewProvider(ctx, providerTypeForAddAction(createType));
				if (!provider) {
					return;
				}
				config = createDefaultWebSearchConfig({ provider });
				// Persist right away, so cancelling the follow-up menu does not discard the new provider.
				await saveConfig(config);
			}

			const action = await ctx.ui.select("Web search config", [
				"Set default provider",
				...ADD_PROVIDER_ACTIONS,
				"Edit provider",
				"Remove provider",
			]);
			if (!action) {
				return;
			}

			if (action === "Set default provider") {
				const nextDefault = await ctx.ui.select(
					"Choose default provider",
					config.providers.map((provider) => provider.name),
				);
				if (!nextDefault) {
					return;
				}
				config = setDefaultProviderOrThrow(config, nextDefault);
			} else if (ADD_PROVIDER_ACTIONS.includes(action)) {
				const provider = await promptNewProvider(ctx, providerTypeForAddAction(action));
				if (!provider) {
					return;
				}
				config = upsertProviderOrThrow(config, provider);
			} else if (action === "Edit provider") {
				const providerName = await ctx.ui.select(
					"Choose provider",
					config.providers.map((provider) => provider.name),
				);
				if (!providerName) {
					return;
				}
				// Safe: providerName was picked from config.providers just above.
				const existing = config.providers.find((provider) => provider.name === providerName)!;
				const nextName = await ctx.ui.input("Provider name", existing.name);
				if (!nextName) {
					return;
				}
				config = renameProviderOrThrow(config, existing.name, nextName);
				// Safe: renameProviderOrThrow succeeded, so a provider named nextName exists.
				const renamed = config.providers.find((provider) => provider.name === nextName)!;
				const fallbackProviders = await promptFallbackProviders(ctx, renamed);
				const nextOptions = await promptProviderOptions(ctx, renamed);

				if (renamed.type === "firecrawl") {
					const nextBaseUrl = await ctx.ui.input(
						"Firecrawl base URL (blank uses cloud default)",
						renamed.baseUrl ?? "",
					);
					const nextApiKey = await ctx.ui.input(
						`API key for ${renamed.name} (blank allowed when base URL is set)`,
						renamed.apiKey ?? "",
					);
					config = updateProviderOrThrow(config, nextName, {
						apiKey: nextApiKey,
						baseUrl: nextBaseUrl,
						fallbackProviders,
						options: nextOptions,
					});
				} else {
					const nextApiKey = await ctx.ui.input(`API key for ${renamed.name}`, renamed.apiKey);
					if (!nextApiKey) {
						return;
					}
					config = updateProviderOrThrow(config, nextName, {
						apiKey: nextApiKey,
						fallbackProviders,
						options: nextOptions,
					});
				}
			} else if (action === "Remove provider") {
				const providerName = await ctx.ui.select(
					"Choose provider to remove",
					config.providers.map((provider) => provider.name),
				);
				if (!providerName) {
					return;
				}
				config = removeProviderOrThrow(config, providerName);
			}

			await saveConfig(config);
		},
	});
}