// (file-viewer header: 392 lines, 13 KiB, TypeScript)
import type { ExtensionAPI } from "@mariozechner/pi-coding-agent";
import {
  getDefaultWebSearchConfigPath,
  normalizeWebSearchConfig,
  readRawWebSearchConfig,
  writeWebSearchConfig,
  WebSearchConfigError,
} from "../config.ts";
import type { WebSearchConfig, WebSearchProviderConfig } from "../schema.ts";

/**
 * Partial update accepted by `updateProviderOrThrow`.
 *
 * Every field is optional; a field left `undefined` keeps the provider's
 * current value.
 */
type ProviderPatch = {
  /** Replacement API key. */
  apiKey?: string;
  /** Replacement base URL; only read for providers of type "firecrawl". */
  baseUrl?: string;
  /** Replacement list of fallback provider names. */
  fallbackProviders?: string[];
  /** Replacement provider options. */
  options?: WebSearchProviderConfig["options"];
};

/**
 * Checks the minimum fields a provider draft needs before it is stored.
 *
 * - Every provider needs a non-blank `name`.
 * - A "firecrawl" provider needs a non-blank `apiKey` or a non-blank `baseUrl`.
 * - Every other provider type needs a non-blank `apiKey`.
 *
 * Fields are read with optional chaining so that a draft missing `name` or
 * `apiKey` at runtime fails with the descriptive error below rather than a
 * `TypeError`.
 *
 * @param provider Provider draft to check.
 * @throws Error describing the first failed check.
 */
function validateProviderDraftOrThrow(provider: WebSearchProviderConfig): void {
  if (!provider.name?.trim()) {
    throw new Error("Provider name cannot be blank.");
  }

  if (provider.type === "firecrawl") {
    const apiKey = provider.apiKey?.trim();
    const baseUrl = provider.baseUrl?.trim();
    if (!apiKey && !baseUrl) {
      throw new Error("Firecrawl provider apiKey cannot be blank unless baseUrl is set.");
    }
    return;
  }

  if (!provider.apiKey?.trim()) {
    throw new Error("Provider apiKey cannot be blank.");
  }
}

/**
 * Runs a draft config through `normalizeWebSearchConfig` and converts the
 * result back into the `WebSearchConfig` shape that the command writes to disk.
 *
 * @param config Draft configuration to check.
 * @param path Config file path, forwarded to `normalizeWebSearchConfig`
 *   (presumably used for error reporting — confirm in ../config.ts).
 * @returns A config built from the normalized `defaultProviderName` and `providers`.
 * @throws Whatever `normalizeWebSearchConfig` throws for a draft it rejects.
 */
function normalizeDraftConfigOrThrow(config: WebSearchConfig, path: string): WebSearchConfig {
  const { defaultProviderName, providers } = normalizeWebSearchConfig(config, path);
  return {
    defaultProvider: defaultProviderName,
    providers,
  };
}

/**
 * Parses a comma-separated list of provider names typed by the user.
 *
 * Entries are trimmed, blank entries are dropped, and repeated names are
 * collapsed to their first occurrence (listing the same fallback twice has
 * no purpose).
 *
 * @param value Raw text, e.g. `"exa-fallback, firecrawl-main"`.
 * @returns The names in first-seen order, or `undefined` when none remain.
 */
function parseFallbackProviders(value: string): string[] | undefined {
  const names = value
    .split(",")
    .map((item) => item.trim())
    .filter(Boolean);
  // A Set keeps insertion order, so this de-duplicates without reordering.
  const unique = Array.from(new Set(names));
  return unique.length > 0 ? unique : undefined;
}

/**
 * Builds the first config from a single provider, which also becomes the
 * default provider.
 *
 * @param input Object holding the provider to seed the config with.
 * @returns A config whose `defaultProvider` is the provider's name and whose
 *   `providers` list contains only that provider.
 * @throws Error when the provider fails `validateProviderDraftOrThrow`.
 */
export function createDefaultWebSearchConfig(input: { provider: WebSearchProviderConfig }): WebSearchConfig {
  const { provider } = input;
  validateProviderDraftOrThrow(provider);
  return {
    defaultProvider: provider.name,
    providers: [provider],
  };
}

/**
 * Returns a copy of `config` whose default provider is `providerName`.
 *
 * The input config is not mutated; the `providers` array is shared with it.
 *
 * @param config Current configuration.
 * @param providerName Name of an already-configured provider.
 * @throws Error when no provider with that name exists.
 */
export function setDefaultProviderOrThrow(config: WebSearchConfig, providerName: string): WebSearchConfig {
  const isKnown = config.providers.some((provider) => provider.name === providerName);
  if (!isKnown) {
    throw new Error(`Unknown provider: ${providerName}`);
  }
  return { ...config, defaultProvider: providerName };
}

/**
 * Renames a provider and rewrites every reference to it: the config's
 * `defaultProvider` and each provider's `fallbackProviders` list.
 *
 * The input config is not mutated. Providers without a `fallbackProviders`
 * list stay without one (no explicit `undefined` key is added).
 *
 * @param config Current configuration.
 * @param currentName Name of the provider to rename.
 * @param nextName New name; stored exactly as given (only the blank check trims).
 * @throws Error when `nextName` is blank, when `currentName` is not a
 *   configured provider, or when another provider already uses `nextName`.
 */
export function renameProviderOrThrow(
  config: WebSearchConfig,
  currentName: string,
  nextName: string,
): WebSearchConfig {
  if (!nextName.trim()) {
    throw new Error("Provider name cannot be blank.");
  }
  // Same error as the sibling helpers, instead of silently returning an unchanged copy.
  if (!config.providers.some((provider) => provider.name === currentName)) {
    throw new Error(`Unknown provider: ${currentName}`);
  }
  // Renaming a provider to its own name is allowed and is a no-op.
  if (nextName !== currentName && config.providers.some((provider) => provider.name === nextName)) {
    throw new Error(`Duplicate provider name: ${nextName}`);
  }

  const rename = (name: string): string => (name === currentName ? nextName : name);

  return {
    ...config,
    defaultProvider: config.defaultProvider === currentName ? nextName : config.defaultProvider,
    providers: config.providers.map((provider) => {
      const renamed = { ...provider, name: rename(provider.name) };
      return provider.fallbackProviders
        ? { ...renamed, fallbackProviders: provider.fallbackProviders.map(rename) }
        : renamed;
    }),
  };
}

/**
 * Applies a partial update to one provider and returns the updated config.
 *
 * Patch semantics:
 * - A field left `undefined` keeps the provider's current value. Because of
 *   that, `fallbackProviders` and `options` cannot be cleared through this
 *   function — only replaced.
 * - String fields are trimmed before being stored.
 * - For "firecrawl" providers a blank `apiKey` or `baseUrl` removes that
 *   field; the provider is rebuilt from name, type, apiKey, baseUrl,
 *   fallbackProviders and options only.
 * - For every other provider type a blank `apiKey` is rejected, and
 *   `patch.baseUrl` is not read.
 *
 * The input config is not mutated.
 *
 * @param config Current configuration.
 * @param providerName Name of the provider to update.
 * @param patch Fields to change.
 * @throws Error when the provider is unknown, when a non-firecrawl `apiKey`
 *   is blank, or when the result fails `validateProviderDraftOrThrow`.
 */
export function updateProviderOrThrow(
  config: WebSearchConfig,
  providerName: string,
  patch: ProviderPatch,
): WebSearchConfig {
  const existing = config.providers.find((provider) => provider.name === providerName);
  if (!existing) {
    throw new Error(`Unknown provider: ${providerName}`);
  }

  let nextProvider: WebSearchProviderConfig;
  if (existing.type === "firecrawl") {
    // Patched strings are trimmed; a blank result means "remove this field".
    const nextApiKey = patch.apiKey != null ? patch.apiKey.trim() || undefined : existing.apiKey;
    const nextBaseUrl = patch.baseUrl != null ? patch.baseUrl.trim() || undefined : existing.baseUrl;
    const nextFallbackProviders = patch.fallbackProviders ?? existing.fallbackProviders;
    const nextOptions = patch.options ?? existing.options;

    nextProvider = {
      name: existing.name,
      type: existing.type,
      ...(nextApiKey ? { apiKey: nextApiKey } : {}),
      ...(nextBaseUrl ? { baseUrl: nextBaseUrl } : {}),
      ...(nextFallbackProviders ? { fallbackProviders: nextFallbackProviders } : {}),
      ...(nextOptions ? { options: nextOptions } : {}),
    };
  } else {
    const patchedApiKey = patch.apiKey?.trim();
    if (patchedApiKey === "") {
      throw new Error("Provider apiKey cannot be blank.");
    }

    nextProvider = {
      ...existing,
      apiKey: patchedApiKey ?? existing.apiKey,
      fallbackProviders: patch.fallbackProviders ?? existing.fallbackProviders,
      options: patch.options ?? existing.options,
    };
  }

  validateProviderDraftOrThrow(nextProvider);

  return {
    ...config,
    providers: config.providers.map((provider) => (provider.name === providerName ? nextProvider : provider)),
  };
}

/**
 * Removes a provider and drops its name from the `fallbackProviders` list of
 * every remaining provider, so no list keeps pointing at a provider that no
 * longer exists. A list that becomes empty is removed entirely.
 *
 * The input config is not mutated.
 *
 * @param config Current configuration.
 * @param providerName Name of the provider to remove.
 * @throws Error when the provider is unknown, when it is the only provider,
 *   or when it is the current default provider.
 */
export function removeProviderOrThrow(config: WebSearchConfig, providerName: string): WebSearchConfig {
  if (!config.providers.some((provider) => provider.name === providerName)) {
    throw new Error(`Unknown provider: ${providerName}`);
  }
  if (config.providers.length <= 1) {
    throw new Error("Cannot remove the last provider.");
  }
  if (config.defaultProvider === providerName) {
    throw new Error("Cannot remove the default provider before selecting a new default.");
  }

  const remaining = config.providers
    .filter((provider) => provider.name !== providerName)
    .map((provider): WebSearchProviderConfig => {
      const current = provider.fallbackProviders;
      if (!current || !current.includes(providerName)) {
        return provider;
      }
      const kept = current.filter((name) => name !== providerName);
      // Split the old list off so an emptied list can be omitted instead of stored as [].
      const { fallbackProviders: _previous, ...rest } = provider;
      return kept.length > 0 ? { ...rest, fallbackProviders: kept } : rest;
    });

  return { ...config, providers: remaining };
}

/**
 * Adds a provider to the config, replacing any existing provider that has the
 * same name. The new entry is always placed at the end of the list.
 *
 * The input config is not mutated.
 *
 * @param config Current configuration.
 * @param nextProvider Provider to add or replace.
 * @throws Error when the provider fails `validateProviderDraftOrThrow`.
 */
function upsertProviderOrThrow(config: WebSearchConfig, nextProvider: WebSearchProviderConfig): WebSearchConfig {
  validateProviderDraftOrThrow(nextProvider);

  const others = config.providers.filter((provider) => provider.name !== nextProvider.name);
  return {
    ...config,
    providers: [...others, nextProvider],
  };
}

/**
 * Converts the text of a numeric prompt into a number.
 *
 * @param label Human-readable field name, used in the error message.
 * @param raw Text returned by the prompt; may be missing when the prompt was dismissed.
 * @returns `undefined` for missing or blank input, otherwise the parsed number.
 * @throws Error when the input is not a finite number, so `NaN` never reaches the config.
 */
function parseOptionalNumberOrThrow(label: string, raw: string | null | undefined): number | undefined {
  const text = raw?.trim();
  if (!text) {
    return undefined;
  }
  const value = Number(text);
  if (!Number.isFinite(value)) {
    throw new Error(`${label} must be a number (received "${text}").`);
  }
  return value;
}

/**
 * Asks for the numeric options that apply to the provider's type:
 * - "firecrawl": default search limit
 * - "tavily": default search limit, default fetch text max characters
 * - any other type: the two above plus default fetch highlights max characters
 *
 * Each prompt is pre-filled with the provider's current value, if any, and
 * each answer is checked as soon as it is given.
 *
 * @param ctx Command context; only `ctx.ui.input(title, text)` is used.
 * @param provider Provider whose options are being edited.
 * @returns The options object, or `undefined` when every answer was blank.
 * @throws Error when an answer is not a finite number.
 */
async function promptProviderOptions(ctx: any, provider: WebSearchProviderConfig) {
  const ask = async (label: string, current: number | undefined): Promise<number | undefined> => {
    const answer = await ctx.ui.input(`${label} for ${provider.name}`, current !== undefined ? String(current) : "");
    return parseOptionalNumberOrThrow(label, answer);
  };

  const defaultSearchLimit = await ask("Default search limit", provider.options?.defaultSearchLimit);

  if (provider.type === "firecrawl") {
    return defaultSearchLimit !== undefined ? { defaultSearchLimit } : undefined;
  }

  const defaultFetchTextMaxCharacters = await ask(
    "Default fetch text max characters",
    provider.options?.defaultFetchTextMaxCharacters,
  );

  if (provider.type === "tavily") {
    const options = { defaultSearchLimit, defaultFetchTextMaxCharacters };
    return Object.values(options).some((value) => value !== undefined) ? options : undefined;
  }

  const defaultFetchHighlightsMaxCharacters = await ask(
    "Default fetch highlights max characters",
    provider.options?.defaultFetchHighlightsMaxCharacters,
  );

  const options = {
    defaultSearchLimit,
    defaultFetchTextMaxCharacters,
    defaultFetchHighlightsMaxCharacters,
  };
  return Object.values(options).some((value) => value !== undefined) ? options : undefined;
}

/**
 * Asks for the provider's fallback list as comma-separated names, pre-filled
 * with the current list.
 *
 * @param ctx Command context; only `ctx.ui.input(title, text)` is used.
 * @param provider Provider whose fallback list is being edited.
 * @returns The parsed names, or `undefined` when the answer is blank or missing.
 */
async function promptFallbackProviders(ctx: any, provider: WebSearchProviderConfig) {
  const current = (provider.fallbackProviders ?? []).join(", ");
  const answer = await ctx.ui.input(`Fallback providers for ${provider.name} (comma-separated names)`, current);
  return parseFallbackProviders(answer ?? "");
}

/**
 * Walks the user through creating a provider of the given type.
 *
 * Prompt order: name; then for "firecrawl" base URL and API key, for other
 * types API key only; then fallback providers; then provider options.
 * Name, API key and base URL are trimmed before being stored, so stray
 * whitespace from a paste does not end up in the config.
 *
 * @param ctx Command context; only `ctx.ui.input(title, text)` is used here.
 * @param type Provider type to create.
 * @returns The new provider, or `undefined` when the name prompt (or, for
 *   non-firecrawl types, the API key prompt) returns nothing.
 */
async function promptNewProvider(ctx: any, type: WebSearchProviderConfig["type"]) {
  const suggestedName = type === "tavily" ? "tavily-main" : type === "exa" ? "exa-fallback" : "firecrawl-main";
  const rawName = await ctx.ui.input("Provider name", suggestedName);
  if (!rawName) {
    return undefined;
  }
  // A whitespace-only name becomes "" here and is rejected later by validation.
  const name: string = rawName.trim();

  let provider: WebSearchProviderConfig;
  if (type === "firecrawl") {
    const baseUrl = (await ctx.ui.input("Firecrawl base URL (blank uses cloud default)", ""))?.trim();
    const apiKey = (await ctx.ui.input("Firecrawl API key (blank allowed when base URL is set)", "fc-..."))?.trim();
    provider = {
      name,
      type,
      ...(apiKey ? { apiKey } : {}),
      ...(baseUrl ? { baseUrl } : {}),
    };
  } else {
    const rawApiKey = await ctx.ui.input(
      type === "tavily" ? "Tavily API key" : "Exa API key",
      type === "tavily" ? "tvly-..." : "exa_...",
    );
    if (!rawApiKey) {
      return undefined;
    }
    provider = { name, type, apiKey: rawApiKey.trim() };
  }

  const fallbackProviders = await promptFallbackProviders(ctx, provider);
  const options = await promptProviderOptions(ctx, provider);
  return {
    ...provider,
    ...(fallbackProviders ? { fallbackProviders } : {}),
    ...(options ? { options } : {}),
  };
}

/** Menu labels that add a provider, mapped to the provider type they create. */
const ADD_PROVIDER_ACTIONS = new Map<string, WebSearchProviderConfig["type"]>([
  ["Add Tavily provider", "tavily"],
  ["Add Exa provider", "exa"],
  ["Add Firecrawl provider", "firecrawl"],
]);

/**
 * Registers the `web-search-config` command.
 *
 * The handler:
 * 1. Reads the raw config from the default path. If that fails with a
 *    `WebSearchConfigError`, it walks the user through creating an initial
 *    provider instead; any other error is rethrown.
 * 2. Shows the action menu (set default / add / edit / remove) and applies
 *    the chosen action to an in-memory copy of the config.
 * 3. Normalizes the result, writes it to the config path and shows a
 *    notification.
 *
 * Dismissing any prompt returns immediately without writing anything. Errors
 * thrown by the `*OrThrow` helpers are not caught here.
 *
 * @param pi Extension API used to register the command.
 */
export function registerWebSearchConfigCommand(pi: ExtensionAPI) {
  pi.registerCommand("web-search-config", {
    description: "Configure Tavily/Exa/Firecrawl providers for web_search and web_fetch",
    handler: async (_args, ctx) => {
      const path = getDefaultWebSearchConfigPath();
      const addActionLabels = Array.from(ADD_PROVIDER_ACTIONS.keys());

      let config: WebSearchConfig;
      try {
        config = await readRawWebSearchConfig(path);
      } catch (error) {
        if (!(error instanceof WebSearchConfigError)) {
          throw error;
        }

        const createType = await ctx.ui.select("Create initial provider", addActionLabels);
        if (!createType) {
          return;
        }

        // Any label other than the Tavily/Exa ones falls back to "firecrawl".
        const provider = await promptNewProvider(ctx, ADD_PROVIDER_ACTIONS.get(createType) ?? "firecrawl");
        if (!provider) {
          return;
        }
        config = createDefaultWebSearchConfig({ provider });
      }

      // NOTE(review): a config freshly created above is only saved if the user
      // also completes an action below; dismissing this menu discards it.
      // Confirm whether that is intended.
      const action = await ctx.ui.select("Web search config", [
        "Set default provider",
        ...addActionLabels,
        "Edit provider",
        "Remove provider",
      ]);
      if (!action) {
        return;
      }

      const addType = ADD_PROVIDER_ACTIONS.get(action);

      if (action === "Set default provider") {
        const nextDefault = await ctx.ui.select(
          "Choose default provider",
          config.providers.map((provider) => provider.name),
        );
        if (!nextDefault) {
          return;
        }
        config = setDefaultProviderOrThrow(config, nextDefault);
      } else if (addType) {
        const provider = await promptNewProvider(ctx, addType);
        if (!provider) {
          return;
        }
        config = upsertProviderOrThrow(config, provider);
      } else if (action === "Edit provider") {
        const providerName = await ctx.ui.select(
          "Choose provider",
          config.providers.map((provider) => provider.name),
        );
        if (!providerName) {
          return;
        }

        const existing = config.providers.find((provider) => provider.name === providerName);
        if (!existing) {
          throw new Error(`Unknown provider: ${providerName}`);
        }
        const nextName = await ctx.ui.input("Provider name", existing.name);
        if (!nextName) {
          return;
        }

        // Rename first so the remaining prompts and the update use the new name.
        config = renameProviderOrThrow(config, existing.name, nextName);
        const renamed = config.providers.find((provider) => provider.name === nextName);
        if (!renamed) {
          throw new Error(`Unknown provider: ${nextName}`);
        }

        // NOTE(review): a blank answer yields `undefined` for either value,
        // which updateProviderOrThrow treats as "keep the current value", so
        // this flow cannot clear fallback providers or options.
        const fallbackProviders = await promptFallbackProviders(ctx, renamed);
        const nextOptions = await promptProviderOptions(ctx, renamed);

        if (renamed.type === "firecrawl") {
          const nextBaseUrl = await ctx.ui.input(
            "Firecrawl base URL (blank uses cloud default)",
            renamed.baseUrl ?? "",
          );
          const nextApiKey = await ctx.ui.input(
            `API key for ${renamed.name} (blank allowed when base URL is set)`,
            renamed.apiKey ?? "",
          );
          config = updateProviderOrThrow(config, nextName, {
            apiKey: nextApiKey,
            baseUrl: nextBaseUrl,
            fallbackProviders,
            options: nextOptions,
          });
        } else {
          const nextApiKey = await ctx.ui.input(`API key for ${renamed.name}`, renamed.apiKey);
          if (!nextApiKey) {
            return;
          }
          config = updateProviderOrThrow(config, nextName, {
            apiKey: nextApiKey,
            fallbackProviders,
            options: nextOptions,
          });
        }
      } else if (action === "Remove provider") {
        const providerName = await ctx.ui.select(
          "Choose provider to remove",
          config.providers.map((provider) => provider.name),
        );
        if (!providerName) {
          return;
        }
        config = removeProviderOrThrow(config, providerName);
      }

      const normalizedConfig = normalizeDraftConfigOrThrow(config, path);
      await writeWebSearchConfig(path, normalizedConfig);
      ctx.ui.notify(`Saved web-search config to ${path}`, "info");
    },
  });
}