diff --git a/src/tool.ts b/src/tool.ts index dd68f74..cf7dbce 100644 --- a/src/tool.ts +++ b/src/tool.ts @@ -54,11 +54,18 @@ function makeErrorResult(text: string, mode: "single" | "parallel" | "chain") { }; } +import type { ModelSelection } from "./models.ts"; + export function createSubagentTool(deps: { listAvailableModelReferences?: typeof listAvailableModelReferences; normalizeAvailableModelReference?: typeof normalizeAvailableModelReference; parameters?: typeof SubagentParamsSchema; - resolveChildModel?: typeof resolveChildModel; + // Compatibility: accept injected resolveChildModel functions with either the + // new API ({ callModel?, presetModel? }) or the older test/hooks API + // ({ taskModel?, topLevelModel? }). We adapt at callsite below. + resolveChildModel?: + | typeof resolveChildModel + | ((input: { taskModel?: string; topLevelModel?: string }) => ModelSelection); runSingleTask?: (input: { cwd: string; meta: Record; @@ -129,6 +136,23 @@ export function createSubagentTool(deps: { step.model = normalizedStepModel; } + const callResolveChildModel = (input: { + callModel?: string; + presetModel?: string; + taskModel?: string; + topLevelModel?: string; + }) => { + // If an injected resolveChildModel exists, call it with the older-shape + // keys (taskModel/topLevelModel) for compatibility. Otherwise, use the + // internal resolveChildModel which expects { callModel, presetModel }. + if (deps.resolveChildModel) { + const injected = deps.resolveChildModel as unknown as (arg: { taskModel?: string; topLevelModel?: string }) => unknown; + return injected({ taskModel: input.callModel ?? input.taskModel, topLevelModel: input.presetModel ?? input.topLevelModel }) as ModelSelection; + } + + return resolveChildModel({ callModel: input.callModel ?? input.taskModel, presetModel: input.presetModel ?? input.topLevelModel }); + }; + const runTask = async (input: { task: string; cwd?: string; @@ -137,9 +161,7 @@ export function createSubagentTool(deps: { step?: number; mode: "single" | "parallel" | "chain"; }) => { - const model = (deps.resolveChildModel ?? resolveChildModel)({ - // compatibility: newer resolveChildModel expects { callModel, presetModel } - // older hooks/tests used { taskModel, topLevelModel } + const model = callResolveChildModel({ callModel: input.taskModel, presetModel: params.model, taskModel: input.taskModel,