import { Text } from "@mariozechner/pi-tui";
import {
  listAvailableModelReferences,
  normalizeAvailableModelReference,
  resolveChildModel,
} from "./models.ts";
import type { ModelSelection } from "./models.ts";
import {
  SubagentParamsSchema,
  type SubagentRunResult,
  type SubagentToolDetails,
} from "./schema.ts";
import { createProgressFormatter } from "./progress.mjs";

/** Upper bound on the number of tasks accepted by one parallel invocation. */
const MAX_PARALLEL_TASKS = 8;
/** Upper bound on the number of parallel child runs in flight at once. */
const MAX_CONCURRENCY = 4;

/** The three mutually exclusive invocation shapes of the subagent tool. */
type SubagentMode = "single" | "parallel" | "chain";

/**
 * Maps `items` through the async `fn`, keeping at most `concurrency` calls in
 * flight. Each result is stored at the same index as its input.
 *
 * If any call rejects, the workers stop picking up new items and the returned
 * promise rejects with that error. Calls already in flight are not cancelled.
 *
 * @param items - Inputs to process.
 * @param concurrency - Desired parallelism; clamped to `[1, items.length]`.
 * @param fn - Worker invoked with each item and its index.
 * @returns The results, in input order.
 */
async function mapWithConcurrencyLimit<TIn, TOut>(
  items: readonly TIn[],
  concurrency: number,
  fn: (item: TIn, index: number) => Promise<TOut>,
): Promise<TOut[]> {
  const limit = Math.max(1, Math.min(concurrency, items.length || 1));
  const results = new Array<TOut>(items.length);
  let nextIndex = 0;
  let failed = false;
  await Promise.all(
    Array.from({ length: limit }, async () => {
      // `nextIndex++` runs synchronously between awaits, so each index is
      // claimed by exactly one worker.
      while (!failed && nextIndex < items.length) {
        const index = nextIndex++;
        try {
          results[index] = await fn(items[index], index);
        } catch (error) {
          // Prevent the remaining workers from starting further items.
          failed = true;
          throw error;
        }
      }
    }),
  );
  return results;
}

/** True when a child run exited non-zero or stopped with an error/abort. */
function isFailure(result: Pick<SubagentRunResult, "exitCode" | "stopReason">): boolean {
  return result.exitCode !== 0 || result.stopReason === "error" || result.stopReason === "aborted";
}

/** Builds the `details` payload attached to every tool result and update. */
function makeDetails(mode: SubagentMode, results: SubagentRunResult[]): SubagentToolDetails {
  return { mode, results };
}

/** Builds an error tool result carrying `text` and no child results. */
function makeErrorResult(text: string, mode: SubagentMode) {
  return {
    content: [{ type: "text" as const, text }],
    details: makeDetails(mode, []),
    isError: true,
  };
}

/** Extracts a human-readable message from any thrown value. */
function errorMessage(error: unknown): string {
  return error instanceof Error ? error.message : String(error);
}

/**
 * Creates the `subagent` tool, which delegates work to child sessions in one of
 * three modes:
 * - single: `params.task` runs once;
 * - parallel: `params.tasks` run with bounded concurrency;
 * - chain: `params.chain` steps run in order, with `{previous}` in each step's
 *   task replaced by the prior step's final text.
 *
 * Every model reference (top-level and per task/step) is validated against the
 * available models and normalized in place on `params` before anything runs.
 *
 * @param deps - Optional overrides, mainly for tests/hooks. `runSingleTask` is
 *   what actually executes a child run; without it, every run fails with an
 *   explicit error.
 */
export function createSubagentTool(deps: {
  listAvailableModelReferences?: typeof listAvailableModelReferences;
  normalizeAvailableModelReference?: typeof normalizeAvailableModelReference;
  parameters?: typeof SubagentParamsSchema;
  // Compatibility: accept injected resolveChildModel functions with either the
  // new API ({ callModel?, presetModel? }) or the older test/hooks API
  // ({ taskModel?, topLevelModel? }). We adapt at callsite below.
  resolveChildModel?:
    | typeof resolveChildModel
    | ((input: { taskModel?: string; topLevelModel?: string }) => ModelSelection);
  runSingleTask?: (input: {
    cwd: string;
    meta: Record<string, unknown>;
    onEvent?: (event: any) => void;
  }) => Promise<SubagentRunResult>;
} = {}) {
  return {
    name: "subagent",
    label: "Subagent",
    description: "Delegate tasks to generic subagents running in separate child sessions.",
    parameters: deps.parameters ?? SubagentParamsSchema,

    async execute(_toolCallId: string, params: any, _signal: AbortSignal | undefined, onUpdate: any, ctx: any) {
      const hasSingle = Boolean(params.task);
      const hasParallel = Boolean(params.tasks?.length);
      const hasChain = Boolean(params.chain?.length);
      const modeCount = Number(hasSingle) + Number(hasParallel) + Number(hasChain);
      const mode: SubagentMode = hasParallel ? "parallel" : hasChain ? "chain" : "single";
      if (modeCount !== 1) {
        return makeErrorResult("Provide exactly one mode: single, parallel, or chain.", "single");
      }

      const availableModelReferences = (deps.listAvailableModelReferences ?? listAvailableModelReferences)(
        ctx.modelRegistry,
      );
      const availableModelsText = availableModelReferences.join(", ") || "(none)";
      const normalizeModelReference = (requestedModel?: string) =>
        (deps.normalizeAvailableModelReference ?? normalizeAvailableModelReference)(
          requestedModel,
          availableModelReferences,
        );

      if (availableModelReferences.length === 0) {
        return makeErrorResult(
          "No available models are configured. Configure at least one model before using subagent.",
          mode,
        );
      }

      const topLevelModel = normalizeModelReference(params.model);
      if (!topLevelModel) {
        const message =
          typeof params.model !== "string" || params.model.trim().length === 0
            ? `Subagent requires a top-level model chosen from the available models: ${availableModelsText}`
            : `Invalid top-level model "${params.model}". Choose one of the available models: ${availableModelsText}`;
        return makeErrorResult(message, mode);
      }
      // Models are normalized in place so later code (and anything else that
      // holds `params`) sees the canonical references.
      params.model = topLevelModel;

      for (const [index, task] of (params.tasks ?? []).entries()) {
        if (task.model === undefined) continue;
        const normalizedTaskModel = normalizeModelReference(task.model);
        if (!normalizedTaskModel) {
          return makeErrorResult(
            `Invalid model for parallel task ${index + 1}: "${task.model}". Choose one of the available models: ${availableModelsText}`,
            mode,
          );
        }
        task.model = normalizedTaskModel;
      }

      for (const [index, step] of (params.chain ?? []).entries()) {
        if (step.model === undefined) continue;
        const normalizedStepModel = normalizeModelReference(step.model);
        if (!normalizedStepModel) {
          return makeErrorResult(
            `Invalid model for chain step ${index + 1}: "${step.model}". Choose one of the available models: ${availableModelsText}`,
            mode,
          );
        }
        step.model = normalizedStepModel;
      }

      const callResolveChildModel = (input: {
        callModel?: string;
        presetModel?: string;
        taskModel?: string;
        topLevelModel?: string;
      }): ModelSelection => {
        const callModel = input.callModel ?? input.taskModel;
        const presetModel = input.presetModel ?? input.topLevelModel;
        if (deps.resolveChildModel) {
          // The injected function may use either API, and we cannot tell which
          // at runtime, so pass both key sets. Each implementation reads its own.
          const injected = deps.resolveChildModel as unknown as (arg: {
            callModel?: string;
            presetModel?: string;
            taskModel?: string;
            topLevelModel?: string;
          }) => ModelSelection;
          return injected({ callModel, presetModel, taskModel: callModel, topLevelModel: presetModel });
        }
        return resolveChildModel({ callModel, presetModel });
      };

      const runTask = async (input: {
        task: string;
        cwd?: string;
        taskModel?: string;
        taskIndex?: number;
        step?: number;
        mode: SubagentMode;
      }): Promise<SubagentRunResult> => {
        const model = callResolveChildModel({
          callModel: input.taskModel,
          presetModel: params.model,
        });
        const runSingleTask = deps.runSingleTask;
        if (!runSingleTask) {
          throw new Error("subagent: no runSingleTask implementation was provided to createSubagentTool.");
        }
        const cwd = input.cwd ?? ctx.cwd;
        // One formatter per child run, so progress state never mixes between tasks.
        const progressFormatter = createProgressFormatter();
        return runSingleTask({
          cwd,
          onEvent(event) {
            const text = progressFormatter.format(event);
            if (!text) return;
            onUpdate?.({
              content: [{ type: "text", text }],
              details: makeDetails(input.mode, []),
            });
          },
          meta: {
            mode: input.mode,
            taskIndex: input.taskIndex,
            step: input.step,
            task: input.task,
            cwd,
            requestedModel: model.requestedModel,
            resolvedModel: model.resolvedModel,
          },
        });
      };

      if (hasSingle) {
        try {
          const result = await runTask({
            task: params.task,
            cwd: params.cwd,
            mode: "single",
          });
          return {
            content: [{ type: "text" as const, text: result.finalText }],
            details: makeDetails("single", [result]),
            isError: isFailure(result),
          };
        } catch (error) {
          return makeErrorResult(errorMessage(error), "single");
        }
      }

      if (hasParallel) {
        const tasks: any[] = params.tasks;
        if (tasks.length > MAX_PARALLEL_TASKS) {
          return makeErrorResult(
            `Too many parallel tasks (${tasks.length}). Max is ${MAX_PARALLEL_TASKS}.`,
            "parallel",
          );
        }

        // Sparse array of results indexed by task position, filled as tasks finish.
        const liveResults: SubagentRunResult[] = [];
        let results: SubagentRunResult[];
        try {
          results = await mapWithConcurrencyLimit(tasks, MAX_CONCURRENCY, async (task: any, index) => {
            const result = await runTask({
              task: task.task,
              cwd: task.cwd,
              taskModel: task.model,
              taskIndex: index,
              mode: "parallel",
            });
            liveResults[index] = result;
            const finished = liveResults.filter(Boolean);
            onUpdate?.({
              content: [{ type: "text", text: `Parallel: ${finished.length}/${tasks.length} finished` }],
              details: makeDetails("parallel", finished),
            });
            return result;
          });
        } catch (error) {
          // Mirror single mode: report the failure as a tool error and keep
          // whatever results already completed.
          return {
            content: [{ type: "text" as const, text: errorMessage(error) }],
            details: makeDetails("parallel", liveResults.filter(Boolean)),
            isError: true,
          };
        }

        const successCount = results.filter((result) => !isFailure(result)).length;
        const summary = results
          .map(
            (result, index) =>
              `[task ${index + 1}] ${isFailure(result) ? "failed" : "completed"}: ${result.finalText || "(no output)"}`,
          )
          .join("\n\n");
        return {
          content: [{ type: "text" as const, text: `Parallel: ${successCount}/${results.length} succeeded\n\n${summary}` }],
          details: makeDetails("parallel", results),
          isError: successCount !== results.length,
        };
      }

      const chain: any[] = params.chain;
      const results: SubagentRunResult[] = [];
      let previous = "";
      for (let index = 0; index < chain.length; index += 1) {
        const item = chain[index];
        const stepNumber = index + 1;
        // Use a function replacer: with a string replacement, `$&`, `$'`, `$$`,
        // etc. in the previous step's output would be treated as patterns.
        const task = item.task.replaceAll("{previous}", () => previous);
        let result: SubagentRunResult;
        try {
          result = await runTask({
            task,
            cwd: item.cwd,
            taskModel: item.model,
            step: stepNumber,
            mode: "chain",
          });
        } catch (error) {
          return {
            content: [{ type: "text" as const, text: `Chain stopped at step ${stepNumber}: ${errorMessage(error)}` }],
            details: makeDetails("chain", results),
            isError: true,
          };
        }
        onUpdate?.({
          content: [{ type: "text", text: `Chain: completed step ${stepNumber}/${chain.length}` }],
          details: makeDetails("chain", [...results, result]),
        });
        results.push(result);
        if (isFailure(result)) {
          return {
            content: [
              {
                type: "text" as const,
                text: `Chain stopped at step ${stepNumber}: ${result.finalText || result.stopReason || "failed"}`,
              },
            ],
            details: makeDetails("chain", results),
            isError: true,
          };
        }
        previous = result.finalText ?? "";
      }

      const finalResult = results[results.length - 1];
      return {
        content: [{ type: "text" as const, text: finalResult?.finalText ?? "" }],
        details: makeDetails("chain", results),
      };
    },

    renderCall(args: any) {
      if (args.tasks?.length) return new Text(`subagent parallel (${args.tasks.length} tasks)`, 0, 0);
      if (args.chain?.length) return new Text(`subagent chain (${args.chain.length} steps)`, 0, 0);
      return new Text("subagent", 0, 0);
    },

    renderResult(result: { content: Array<{ type: string; text?: string }> }) {
      const first = result.content[0];
      return new Text(first?.type === "text" ? first.text ?? "" : "", 0, 0);
    },
  };
}