309 lines
11 KiB
TypeScript
309 lines
11 KiB
TypeScript
import { Text } from "@mariozechner/pi-tui";
|
|
import {
|
|
listAvailableModelReferences,
|
|
normalizeAvailableModelReference,
|
|
resolveChildModel,
|
|
} from "./models.ts";
|
|
import {
|
|
SubagentParamsSchema,
|
|
type SubagentRunResult,
|
|
type SubagentToolDetails,
|
|
} from "./schema.ts";
|
|
import { createProgressFormatter } from "./progress.mjs";
|
|
|
|
/** How many parallel subagent tasks are allowed to execute simultaneously. */
const MAX_CONCURRENCY = 4;

/** Upper bound on the number of tasks a single parallel subagent call may contain. */
const MAX_PARALLEL_TASKS = 8;
|
|
|
|
/**
 * Map `items` through the async `fn`, running at most `concurrency` calls at once.
 *
 * Results are returned in input order regardless of completion order. If any
 * call rejects, the returned promise rejects with that error and no further
 * items are started. Calls already in flight cannot be cancelled and are left
 * to settle on their own.
 *
 * @param items Inputs to process.
 * @param concurrency Maximum number of simultaneous calls. Values below 1 (and
 *   NaN) are treated as 1. Values above `items.length` are capped to it, and
 *   fractional values are rounded down.
 * @param fn Async mapper, called with each item and its index.
 * @returns The mapped values, indexed like `items`.
 */
async function mapWithConcurrencyLimit<TIn, TOut>(
  items: TIn[],
  concurrency: number,
  fn: (item: TIn, index: number) => Promise<TOut>,
): Promise<TOut[]> {
  // Guard NaN explicitly: Math.min/Math.max propagate NaN, and Array.from with
  // length NaN would spawn zero workers and silently return an array of holes.
  const requested = Number.isNaN(concurrency) ? 1 : Math.floor(concurrency);
  const limit = Math.max(1, Math.min(requested, items.length || 1));
  const results = new Array<TOut>(items.length);
  let nextIndex = 0;
  // Set once any call rejects, so idle workers stop taking new items.
  let failed = false;

  const worker = async (): Promise<void> => {
    while (!failed && nextIndex < items.length) {
      // Claim an index synchronously; workers only interleave at `await`, so
      // each index is claimed by exactly one worker.
      const index = nextIndex++;
      try {
        results[index] = await fn(items[index], index);
      } catch (error) {
        failed = true;
        throw error;
      }
    }
  };

  await Promise.all(Array.from({ length: limit }, worker));
  return results;
}
|
|
|
|
/**
 * Decide whether a subagent run should be reported as failed.
 *
 * A run fails when its exit code is anything other than 0, or when it stopped
 * with reason "error" or "aborted".
 */
function isFailure(result: Pick<SubagentRunResult, "exitCode" | "stopReason">) {
  if (result.exitCode !== 0) return true;

  switch (result.stopReason) {
    case "error":
    case "aborted":
      return true;
    default:
      return false;
  }
}
|
|
|
|
/**
 * Package the execution mode and the run results into the tool's `details` payload.
 * The `results` array is stored by reference, not copied.
 */
function makeDetails(
  mode: "single" | "parallel" | "chain",
  results: SubagentRunResult[],
): SubagentToolDetails {
  const details: SubagentToolDetails = { mode, results };
  return details;
}
|
|
|
|
/**
 * Build a failed tool result: a single text message, empty run details for
 * the given mode, and `isError` set.
 */
function makeErrorResult(text: string, mode: "single" | "parallel" | "chain") {
  const content = [{ type: "text" as const, text }];
  const details = makeDetails(mode, []);
  return { content, details, isError: true };
}
|
|
|
|
import type { ModelSelection } from "./models.ts";
|
|
|
|
/** Convert an arbitrary thrown value into a human-readable message. */
function describeSubagentError(error: unknown): string {
  return error instanceof Error ? error.message : String(error);
}

/**
 * Replace every `{previous}` placeholder in a chain step's task text with the
 * previous step's output.
 *
 * A replacer function is used instead of a replacement string so that `$&`,
 * `$'`, `` $` `` and `$$` sequences in the previous output are inserted
 * literally instead of being interpreted as replacement patterns.
 */
function fillPreviousPlaceholder(template: string, previous: string): string {
  return template.replaceAll("{previous}", () => previous);
}

/**
 * Build the `subagent` tool definition.
 *
 * Each call runs in exactly one of three modes, chosen by its parameters:
 * - `task`: a single task;
 * - `tasks`: up to MAX_PARALLEL_TASKS tasks, at most MAX_CONCURRENCY at a time;
 * - `chain`: sequential steps. `{previous}` in a step's task is replaced with
 *   the previous step's final text, and the chain stops at the first failed step.
 *
 * Before anything runs, every model reference (the top-level `model` and any
 * per-task or per-step `model`) is validated against the models listed from
 * `ctx.modelRegistry`. Normalized values are written back onto `params`.
 *
 * @param deps Optional overrides for model listing, normalization and
 *   resolution, the parameter schema, and the child runner. `runSingleTask`
 *   performs each child run. Without it, every run reports an explicit
 *   configuration error.
 */
export function createSubagentTool(deps: {
  listAvailableModelReferences?: typeof listAvailableModelReferences;
  normalizeAvailableModelReference?: typeof normalizeAvailableModelReference;
  parameters?: typeof SubagentParamsSchema;
  // Compatibility: accept injected resolveChildModel functions with either the
  // new API ({ callModel?, presetModel? }) or the older test/hooks API
  // ({ taskModel?, topLevelModel? }). The callsite passes both key sets so
  // either signature receives its inputs.
  resolveChildModel?:
    | typeof resolveChildModel
    | ((input: { taskModel?: string; topLevelModel?: string }) => ModelSelection);
  runSingleTask?: (input: {
    cwd: string;
    meta: Record<string, unknown>;
    onEvent?: (event: any) => void;
  }) => Promise<SubagentRunResult>;
} = {}) {
  return {
    name: "subagent",
    label: "Subagent",
    description: "Delegate tasks to generic subagents running in separate child sessions.",
    parameters: deps.parameters ?? SubagentParamsSchema,

    async execute(_toolCallId: string, params: any, _signal: AbortSignal | undefined, onUpdate: any, ctx: any) {
      const hasSingle = Boolean(params.task);
      const hasParallel = Boolean(params.tasks?.length);
      const hasChain = Boolean(params.chain?.length);
      const modeCount = Number(hasSingle) + Number(hasParallel) + Number(hasChain);
      const mode = hasParallel ? "parallel" : hasChain ? "chain" : "single";

      if (modeCount !== 1) {
        return makeErrorResult("Provide exactly one mode: single, parallel, or chain.", "single");
      }

      // ---- Validate and normalize every model reference up front. ----
      const availableModelReferences = (deps.listAvailableModelReferences ?? listAvailableModelReferences)(ctx.modelRegistry);
      const availableModelsText = availableModelReferences.join(", ") || "(none)";
      const normalizeModelReference = (requestedModel?: string) =>
        (deps.normalizeAvailableModelReference ?? normalizeAvailableModelReference)(requestedModel, availableModelReferences);

      if (availableModelReferences.length === 0) {
        return makeErrorResult(
          "No available models are configured. Configure at least one model before using subagent.",
          mode,
        );
      }

      const topLevelModel = normalizeModelReference(params.model);
      if (!topLevelModel) {
        const message =
          typeof params.model !== "string" || params.model.trim().length === 0
            ? `Subagent requires a top-level model chosen from the available models: ${availableModelsText}`
            : `Invalid top-level model "${params.model}". Choose one of the available models: ${availableModelsText}`;
        return makeErrorResult(message, mode);
      }
      // Write the normalized reference back so later reads of params.model see it.
      params.model = topLevelModel;

      // Normalize `model` in place on each entry that sets one. Returns the error
      // message for the first invalid reference, or undefined if all are valid.
      const normalizeEntryModels = (entries: any[] | undefined, label: string): string | undefined => {
        for (const [index, entry] of (entries ?? []).entries()) {
          if (entry.model === undefined) continue;

          const normalized = normalizeModelReference(entry.model);
          if (!normalized) {
            return `Invalid model for ${label} ${index + 1}: "${entry.model}". Choose one of the available models: ${availableModelsText}`;
          }
          entry.model = normalized;
        }
        return undefined;
      };

      // `??` short-circuits, so chain steps are only checked once all parallel tasks pass.
      const invalidModelMessage =
        normalizeEntryModels(params.tasks, "parallel task") ?? normalizeEntryModels(params.chain, "chain step");
      if (invalidModelMessage !== undefined) {
        return makeErrorResult(invalidModelMessage, mode);
      }

      // Resolve the child model for one task. An injected resolver receives both
      // the new and the legacy key sets, so either accepted signature works.
      const resolveModelFor = (taskModel: string | undefined) => {
        if (deps.resolveChildModel) {
          const injected = deps.resolveChildModel as unknown as (arg: {
            callModel?: string;
            presetModel?: string;
            taskModel?: string;
            topLevelModel?: string;
          }) => ModelSelection;
          return injected({ callModel: taskModel, presetModel: params.model, taskModel, topLevelModel: params.model });
        }
        return resolveChildModel({ callModel: taskModel, presetModel: params.model });
      };

      // Run one child task and forward its formatted progress events through onUpdate.
      const runTask = async (input: {
        task: string;
        cwd?: string;
        taskModel?: string;
        taskIndex?: number;
        step?: number;
        mode: "single" | "parallel" | "chain";
      }): Promise<SubagentRunResult> => {
        const runSingleTask = deps.runSingleTask;
        if (!runSingleTask) {
          // Fail with an actionable message rather than resolving to undefined
          // and crashing later on `result.finalText`.
          throw new Error("Subagent is not configured: no runSingleTask implementation was provided.");
        }

        const model = resolveModelFor(input.taskModel);
        const cwd = input.cwd ?? ctx.cwd;
        const progressFormatter = createProgressFormatter();

        return runSingleTask({
          cwd,
          onEvent(event) {
            const text = progressFormatter.format(event);
            if (!text) return;
            onUpdate?.({
              content: [{ type: "text", text }],
              details: makeDetails(input.mode, []),
            });
          },
          meta: {
            mode: input.mode,
            taskIndex: input.taskIndex,
            step: input.step,
            task: input.task,
            cwd,
            requestedModel: model.requestedModel,
            resolvedModel: model.resolvedModel,
          },
        });
      };

      // ---- Single mode ----
      if (hasSingle) {
        try {
          const result = await runTask({ task: params.task, cwd: params.cwd, mode: "single" });
          return {
            content: [{ type: "text" as const, text: result.finalText }],
            details: makeDetails("single", [result]),
            isError: isFailure(result),
          };
        } catch (error) {
          return makeErrorResult(describeSubagentError(error), "single");
        }
      }

      // ---- Parallel mode ----
      if (hasParallel) {
        if (params.tasks.length > MAX_PARALLEL_TASKS) {
          return makeErrorResult(
            `Too many parallel tasks (${params.tasks.length}). Max is ${MAX_PARALLEL_TASKS}.`,
            "parallel",
          );
        }

        // Sparse array indexed by task position; filled in as tasks finish.
        const liveResults: SubagentRunResult[] = [];
        let results: SubagentRunResult[];
        try {
          results = await mapWithConcurrencyLimit(params.tasks, MAX_CONCURRENCY, async (task: any, index) => {
            const result = await runTask({
              task: task.task,
              cwd: task.cwd,
              taskModel: task.model,
              taskIndex: index,
              mode: "parallel",
            });
            liveResults[index] = result;
            const finished = liveResults.filter(Boolean);
            onUpdate?.({
              content: [{ type: "text", text: `Parallel: ${finished.length}/${params.tasks.length} finished` }],
              details: makeDetails("parallel", finished),
            });
            return result;
          });
        } catch (error) {
          // Like single mode, report a thrown runner error as a tool error,
          // keeping the results that had already completed.
          return {
            content: [{ type: "text" as const, text: `Parallel run failed: ${describeSubagentError(error)}` }],
            details: makeDetails("parallel", liveResults.filter(Boolean)),
            isError: true,
          };
        }

        const successCount = results.filter((result) => !isFailure(result)).length;
        const summary = results
          .map((result, index) => `[task ${index + 1}] ${isFailure(result) ? "failed" : "completed"}: ${result.finalText || "(no output)"}`)
          .join("\n\n");

        return {
          content: [{ type: "text" as const, text: `Parallel: ${successCount}/${results.length} succeeded\n\n${summary}` }],
          details: makeDetails("parallel", results),
          isError: successCount !== results.length,
        };
      }

      // ---- Chain mode ----
      const results: SubagentRunResult[] = [];
      let previous = "";
      for (let index = 0; index < params.chain.length; index += 1) {
        const item = params.chain[index];
        const step = index + 1;

        let result: SubagentRunResult;
        try {
          result = await runTask({
            task: fillPreviousPlaceholder(item.task, previous),
            cwd: item.cwd,
            taskModel: item.model,
            step,
            mode: "chain",
          });
        } catch (error) {
          return {
            content: [{ type: "text" as const, text: `Chain stopped at step ${step}: ${describeSubagentError(error)}` }],
            details: makeDetails("chain", results),
            isError: true,
          };
        }

        onUpdate?.({
          content: [{ type: "text", text: `Chain: completed step ${step}/${params.chain.length}` }],
          details: makeDetails("chain", [...results, result]),
        });
        results.push(result);

        if (isFailure(result)) {
          return {
            content: [
              {
                type: "text" as const,
                text: `Chain stopped at step ${step}: ${result.finalText || result.stopReason || "failed"}`,
              },
            ],
            details: makeDetails("chain", results),
            isError: true,
          };
        }

        previous = result.finalText ?? "";
      }

      const finalResult = results[results.length - 1];
      return {
        content: [{ type: "text" as const, text: finalResult?.finalText ?? "" }],
        details: makeDetails("chain", results),
      };
    },

    renderCall(args: any) {
      if (args.tasks?.length) return new Text(`subagent parallel (${args.tasks.length} tasks)`, 0, 0);
      if (args.chain?.length) return new Text(`subagent chain (${args.chain.length} steps)`, 0, 0);
      return new Text("subagent", 0, 0);
    },

    renderResult(result: { content: Array<{ type: string; text?: string }> }) {
      const first = result.content[0];
      return new Text(first?.type === "text" ? first.text ?? "" : "", 0, 0);
    },
  };
}
|