354 lines
12 KiB
TypeScript
354 lines
12 KiB
TypeScript
import type { AgentMessage } from "@mariozechner/pi-agent-core";
|
|
import type { ExtensionAPI, ExtensionContext } from "@mariozechner/pi-coding-agent";
|
|
import { adjustPolicyForZone } from "./src/config.ts";
|
|
import { deserializeLatestSnapshot, serializeSnapshot, SNAPSHOT_ENTRY_TYPE, type RuntimeSnapshot } from "./src/persist.ts";
|
|
import { createEmptyLedger } from "./src/ledger.ts";
|
|
import { pruneContextMessages } from "./src/prune.ts";
|
|
import { createContextManagerRuntime } from "./src/runtime.ts";
|
|
import { registerContextCommands } from "./src/commands.ts";
|
|
import { buildBranchSummaryFromEntries, buildCompactionSummaryFromPreparation } from "./src/summaries.ts";
|
|
|
|
/** Agent message roles whose messages are ingested into the context ledger. */
type TrackedRole = "user" | "assistant" | "toolResult";

/** Agent messages carrying one of the tracked roles (user, assistant, tool result). */
type TrackedMessage = Extract<AgentMessage, { role: TrackedRole }>;

/** One entry of the current session branch, derived from the session manager's `getBranch()` return type. */
type BranchEntry = ReturnType<ExtensionContext["sessionManager"]["getBranch"]>[number];
|
|
|
|
/**
 * Type guard for content parts tagged `type: "text"`.
 *
 * Only the tag is checked: `text` is optional and may still be a non-string at runtime,
 * so callers must verify it before use.
 */
function isTextPart(part: unknown): part is { type: "text"; text?: string } {
  return typeof part === "object" && part !== null && "type" in part && (part as { type?: unknown }).type === "text";
}

/**
 * Flattens message content into plain text.
 *
 * - A string is returned unchanged (it is not trimmed).
 * - An array contributes the string `text` of each text part, joined with "\n", and the
 *   joined result is trimmed. Non-text parts (images, tool calls, ...) and text parts without
 *   a string `text` are skipped entirely, so they do not leave blank lines between real text.
 * - Any other value yields "".
 *
 * @param content - Message content of unknown shape.
 * @returns The extracted text, or "" when there is none.
 */
function toText(content: unknown): string {
  if (typeof content === "string") return content;
  if (!Array.isArray(content)) return "";

  const texts: string[] = [];
  for (const part of content) {
    // Skip rather than map to "": an empty string would still be joined and produce a blank line.
    if (isTextPart(part) && typeof part.text === "string") {
      texts.push(part.text);
    }
  }

  return texts.join("\n").trim();
}
|
|
|
|
/** Narrows a branch entry to a chat-message entry (`type === "message"`). */
function isMessageEntry(entry: BranchEntry): entry is Extract<BranchEntry, { type: "message" }> {
  const { type } = entry;
  return type === "message";
}

/** Narrows a branch entry to a compaction entry (`type === "compaction"`). */
function isCompactionEntry(entry: BranchEntry): entry is Extract<BranchEntry, { type: "compaction" }> {
  const { type } = entry;
  return type === "compaction";
}

/** Narrows a branch entry to a branch-summary entry (`type === "branch_summary"`). */
function isBranchSummaryEntry(entry: BranchEntry): entry is Extract<BranchEntry, { type: "branch_summary" }> {
  const { type } = entry;
  return type === "branch_summary";
}

/** True for user, assistant, and tool-result messages; every other role is ignored. */
function isTrackedMessage(message: AgentMessage): message is TrackedMessage {
  switch (message.role) {
    case "user":
    case "assistant":
    case "toolResult":
      return true;
    default:
      return false;
  }
}
|
|
|
|
/**
 * Builds the starting snapshot used when nothing persisted applies:
 * "balanced" mode, "green" zone, and a fresh ledger from `createEmptyLedger()`.
 */
function createDefaultSnapshot(): RuntimeSnapshot {
  const ledger = createEmptyLedger();
  const snapshot: RuntimeSnapshot = { mode: "balanced", lastZone: "green", ledger };
  return snapshot;
}
|
|
|
|
/** Returns the message's content flattened to text via `toText`, or "" if the message has no `content` field. */
function getMessageContent(message: AgentMessage): string {
  if (!("content" in message)) {
    return "";
  }
  return toText(message.content);
}

/** Returns `toolName` for tool-result messages; every other role yields `undefined`. */
function getMessageToolName(message: AgentMessage): string | undefined {
  if (message.role !== "toolResult") {
    return undefined;
  }
  return message.toolName;
}
|
|
|
|
/**
 * Maps a pruned, normalized context message back to an `AgentMessage`.
 *
 * Only tool results flagged `distilled` are rewritten: they keep every field of the original
 * message, but their content becomes a single text part holding the distilled text. All other
 * messages are returned as the original object, untouched.
 */
function rewriteContextMessage(message: { role: string; content: string; original: AgentMessage; distilled?: boolean }): AgentMessage {
  const { role, content, original, distilled } = message;
  const shouldRewrite = Boolean(distilled) && role === "toolResult";
  if (!shouldRewrite) {
    return original;
  }

  const toolResult = original as Extract<AgentMessage, { role: "toolResult" }>;
  return { ...toolResult, content: [{ type: "text", text: content }] } as AgentMessage;
}
|
|
|
|
/**
 * Walks the branch from newest to oldest and returns the first snapshot custom entry that
 * decodes, together with its index in `branch`.
 *
 * Snapshot entries that fail to decode are skipped, so an older valid snapshot can still win.
 *
 * @param branch - Entries of the current branch, indexed in the order given.
 * @returns The decoded snapshot and its position, or `undefined` if no snapshot on the branch decodes.
 */
function findLatestSnapshotState(branch: BranchEntry[]): { snapshot: RuntimeSnapshot; index: number } | undefined {
  let index = branch.length;
  while (index > 0) {
    index -= 1;
    const entry = branch[index]!;
    if (entry.type === "custom" && entry.customType === SNAPSHOT_ENTRY_TYPE) {
      const snapshot = deserializeLatestSnapshot([entry]);
      if (snapshot) {
        return { snapshot, index };
      }
    }
  }
  return undefined;
}
|
|
|
|
/**
 * Picks, from all given entries, the decodable snapshot whose session-scoped ledger items are
 * freshest.
 *
 * Freshness is the largest `timestamp` among the snapshot's `scope === "session"` items, or
 * `-Infinity` when it has none. Ties (including the case where no snapshot has session items)
 * go to the later entry in iteration order, so a snapshot without session items never replaces
 * one that has them.
 *
 * @param entries - Entries to scan; presumably in append order — verify against the session manager.
 * @returns The chosen snapshot, or `undefined` when no snapshot entry decodes.
 */
function findLatestSessionSnapshot(entries: BranchEntry[]): RuntimeSnapshot | undefined {
  let latest: RuntimeSnapshot | undefined;
  let latestFreshness = -Infinity;

  for (const entry of entries) {
    if (entry.type !== "custom" || entry.customType !== SNAPSHOT_ENTRY_TYPE) {
      continue;
    }

    const snapshot = deserializeLatestSnapshot([entry]);
    if (!snapshot) {
      continue;
    }

    // Fold instead of `Math.max(...items)`: spreading a large array into call arguments can
    // throw a RangeError. For an empty set the seed keeps freshness at -Infinity, as before.
    const freshness = snapshot.ledger.items.reduce(
      (newest, item) => (item.scope === "session" ? Math.max(newest, item.timestamp) : newest),
      -Infinity,
    );

    // `>=` lets later snapshots win ties.
    if (freshness >= latestFreshness) {
      latest = snapshot;
      latestFreshness = freshness;
    }
  }

  return latest;
}
|
|
|
|
/**
 * Builds a snapshot that keeps only the session layer of `source`: its mode (default
 * "balanced") and a deep copy of its `scope === "session"` ledger items. The zone resets to
 * "green" and the rolling summary to "". Branch-specific state is dropped.
 */
function createSessionFallbackSnapshot(source?: RuntimeSnapshot): RuntimeSnapshot {
  const mode = source?.mode ?? "balanced";
  const sourceItems = source?.ledger.items ?? [];
  const sessionOnly = sourceItems.filter((item) => item.scope === "session");

  const ledger = { items: structuredClone(sessionOnly), rollingSummary: "" };
  return { mode, lastZone: "green", ledger };
}
|
|
|
|
/**
 * Replaces the session-scoped ledger items of `base` with those of `latestSessionSnapshot`.
 *
 * If the latest session snapshot is absent or has no session items, `base` itself is returned
 * (same object, not copied). Otherwise a new snapshot is returned holding deep copies of base's
 * non-session items followed by the overlay's session items; every other field comes from `base`.
 */
function overlaySessionLayer(base: RuntimeSnapshot, latestSessionSnapshot?: RuntimeSnapshot): RuntimeSnapshot {
  type LedgerItem = RuntimeSnapshot["ledger"]["items"][number];
  const isSessionScoped = (item: LedgerItem): boolean => item.scope === "session";

  const overlayItems = latestSessionSnapshot?.ledger.items.filter(isSessionScoped) ?? [];
  if (!overlayItems.length) return base;

  const retainedBaseItems = base.ledger.items.filter((item) => !isSessionScoped(item));
  const mergedItems = [...structuredClone(retainedBaseItems), ...structuredClone(overlayItems)];

  return {
    ...base,
    ledger: { ...base.ledger, items: mergedItems },
  };
}
|
|
|
|
/**
 * Extension entry point: connects the context-manager runtime to pi's session lifecycle.
 *
 * Handlers registered here:
 * - `session_start` / `session_tree`: rebuild runtime state from the current branch and arm the
 *   one-shot resume packet. `session_tree` also records and persists a branch summary that is not
 *   yet on the branch.
 * - `tool_result`: ingests tool output into the runtime as it arrives.
 * - `turn_end`: rebuilds from the branch, feeds observed token usage to the runtime, persists a
 *   snapshot entry, and refreshes the status line.
 * - `context`: prunes/distills outgoing messages and prepends either the armed resume packet or the
 *   regular context packet as a hidden custom message.
 * - `session_before_compact` / `session_before_tree`: supply locally built summaries. If building
 *   one throws, the user is warned and the handler returns nothing (presumably letting pi fall back
 *   to its default summarization).
 * - `session_compact`: records the compaction summary, persists a snapshot, and arms the resume packet.
 *
 * @param pi - Extension API used to register commands, subscribe to events, and append entries.
 */
export default function contextManager(pi: ExtensionAPI) {
  // Context window assumed until the active model reports its own.
  const defaultContextWindow = 200_000;

  const runtime = createContextManagerRuntime({
    mode: "balanced",
    contextWindow: defaultContextWindow,
  });

  // One-shot flag: while set, the next "context" event injects the resume packet instead of the regular packet.
  let pendingResumeInjection = false;

  /** Pushes the active model's context window (or the default) into the runtime. */
  const syncContextWindow = (ctx: Pick<ExtensionContext, "model">) => {
    runtime.setContextWindow(ctx.model?.contextWindow ?? defaultContextWindow);
  };

  /** Shows the snapshot's context zone in the "context-manager" status slot. */
  const updateStatus = (ctx: Pick<ExtensionContext, "ui">, snapshot = runtime.getSnapshot()) => {
    ctx.ui.setStatus("context-manager", `ctx ${snapshot.lastZone}`);
  };

  /** Formats an unknown thrown value for a user-facing notification. */
  const describeError = (error: unknown): string => (error instanceof Error ? error.message : String(error));

  /**
   * Arms the resume packet only when the runtime holds a compaction or branch summary AND the
   * resume packet has non-whitespace text; otherwise disarms it.
   */
  const armResumeInjection = () => {
    const snapshot = runtime.getSnapshot();
    pendingResumeInjection = Boolean(snapshot.lastCompactionSummary || snapshot.lastBranchSummary) && runtime.buildResumePacket().trim().length > 0;
  };

  /**
   * Re-ingests one branch entry into the runtime. Tracked messages (user/assistant/toolResult)
   * are ingested; compaction and branch-summary entries are recorded as summaries; anything else
   * is ignored.
   */
  const replayBranchEntry = (entry: BranchEntry) => {
    if (isMessageEntry(entry) && isTrackedMessage(entry.message)) {
      runtime.ingest({
        entryId: entry.id,
        role: entry.message.role,
        text: toText(entry.message.content),
        timestamp: entry.message.timestamp,
        isError: entry.message.role === "toolResult" ? entry.message.isError : undefined,
      });
      return;
    }

    if (isCompactionEntry(entry)) {
      runtime.recordCompactionSummary(entry.summary, entry.id, Date.parse(entry.timestamp));
      return;
    }

    if (isBranchSummaryEntry(entry)) {
      runtime.recordBranchSummary(entry.summary, entry.id, Date.parse(entry.timestamp));
    }
  };

  /**
   * Restores runtime state for the current branch, then refreshes the status line.
   *
   * Base state is the newest decodable snapshot on the branch, with the freshest session-scoped
   * ledger items found in *any* session entry overlaid on top. When the branch has no snapshot,
   * the base is only the session layer (mode + session items) of that session-wide snapshot,
   * or of `fallbackSnapshot` if there is none. Branch entries after the restored snapshot (or
   * the whole branch, if none was found) are then replayed.
   *
   * @param fallbackSnapshot - Used when no snapshot exists, and as the mode source with `preferRuntimeMode`.
   * @param options.preferRuntimeMode - Keep `fallbackSnapshot.mode` instead of the restored mode.
   */
  const rebuildRuntimeFromBranch = (
    ctx: Pick<ExtensionContext, "model" | "sessionManager" | "ui">,
    fallbackSnapshot: RuntimeSnapshot,
    options?: { preferRuntimeMode?: boolean },
  ) => {
    syncContextWindow(ctx);

    const branch = ctx.sessionManager.getBranch();
    const latestSessionSnapshot = findLatestSessionSnapshot(ctx.sessionManager.getEntries() as BranchEntry[]);
    const restored = findLatestSnapshotState(branch);
    const baseSnapshot = restored
      ? overlaySessionLayer(restored.snapshot, latestSessionSnapshot)
      : createSessionFallbackSnapshot(latestSessionSnapshot ?? fallbackSnapshot);

    runtime.restore({
      ...baseSnapshot,
      mode: options?.preferRuntimeMode ? fallbackSnapshot.mode : baseSnapshot.mode,
    });

    const replayEntries = restored ? branch.slice(restored.index + 1) : branch;
    for (const entry of replayEntries) {
      replayBranchEntry(entry);
    }

    updateStatus(ctx);
  };

  registerContextCommands(pi, {
    getSnapshot: runtime.getSnapshot,
    buildPacket: runtime.buildPacket,
    buildResumePacket: runtime.buildResumePacket,
    setMode: runtime.setMode,
    rebuildFromBranch: async (commandCtx) => {
      // Keep the mode the user currently has active rather than the persisted one.
      rebuildRuntimeFromBranch(commandCtx, runtime.getSnapshot(), { preferRuntimeMode: true });
      armResumeInjection();
    },
    isResumePending: () => pendingResumeInjection,
  });

  pi.on("session_start", async (_event, ctx) => {
    rebuildRuntimeFromBranch(ctx, createDefaultSnapshot());
    armResumeInjection();
  });

  pi.on("session_tree", async (event, ctx) => {
    rebuildRuntimeFromBranch(ctx, createDefaultSnapshot());

    const { summaryEntry } = event;
    if (summaryEntry) {
      // Record the summary only if the rebuilt branch does not already contain (and thus replayed) it.
      const alreadyOnBranch = ctx.sessionManager
        .getBranch()
        .some((entry) => isBranchSummaryEntry(entry) && entry.id === summaryEntry.id);
      if (!alreadyOnBranch) {
        runtime.recordBranchSummary(summaryEntry.summary, summaryEntry.id, Date.parse(summaryEntry.timestamp));
      }
    }

    armResumeInjection();

    if (summaryEntry) {
      pi.appendEntry(SNAPSHOT_ENTRY_TYPE, serializeSnapshot(runtime.getSnapshot()));
    }
  });

  pi.on("tool_result", async (event) => {
    runtime.ingest({
      entryId: event.toolCallId,
      role: "toolResult",
      text: toText(event.content),
      timestamp: Date.now(),
    });
  });

  pi.on("turn_end", async (_event, ctx) => {
    rebuildRuntimeFromBranch(ctx, runtime.getSnapshot(), { preferRuntimeMode: true });

    const tokens = ctx.getContextUsage()?.tokens;
    if (tokens !== null && tokens !== undefined) {
      runtime.observeTokens(tokens);
    }

    const snapshot = runtime.getSnapshot();
    pi.appendEntry(SNAPSHOT_ENTRY_TYPE, serializeSnapshot(snapshot));
    updateStatus(ctx, snapshot);
  });

  pi.on("context", async (event, ctx) => {
    syncContextWindow(ctx);
    const snapshot = runtime.getSnapshot();
    const policy = adjustPolicyForZone(runtime.getPolicy(), snapshot.lastZone);

    const normalized = event.messages.map((message) => ({
      role: message.role,
      content: getMessageContent(message),
      toolName: getMessageToolName(message),
      original: message,
    }));

    const pruned = pruneContextMessages(normalized, policy);
    const nextMessages = pruned.map((message) =>
      rewriteContextMessage(message as { role: string; content: string; original: AgentMessage; distilled?: boolean }),
    );

    // The resume packet is one-shot: consume the flag even when the packet comes back empty,
    // otherwise an empty resume packet would suppress the regular packet indefinitely.
    let resumeText = "";
    if (pendingResumeInjection) {
      resumeText = runtime.buildResumePacket();
      pendingResumeInjection = false;
    }
    const useResume = resumeText.trim().length > 0;
    const injectedText = useResume ? resumeText : runtime.buildPacket().text;

    if (!injectedText) {
      return { messages: nextMessages };
    }

    return {
      messages: [
        // eslint-disable-next-line @typescript-eslint/no-explicit-any -- custom-role message is not expressed through AgentMessage in this file
        {
          role: "custom",
          customType: useResume ? "context-manager.resume" : "context-manager.packet",
          content: injectedText,
          display: false,
          timestamp: Date.now(),
        } as any,
        ...nextMessages,
      ],
    };
  });

  pi.on("session_before_compact", async (event, ctx) => {
    syncContextWindow(ctx);

    try {
      return {
        compaction: {
          summary: buildCompactionSummaryFromPreparation({
            messagesToSummarize: event.preparation.messagesToSummarize,
            turnPrefixMessages: event.preparation.turnPrefixMessages,
            previousSummary: event.preparation.previousSummary,
            fileOps: event.preparation.fileOps,
            customInstructions: event.customInstructions,
          }),
          firstKeptEntryId: event.preparation.firstKeptEntryId,
          tokensBefore: event.preparation.tokensBefore,
          details: event.preparation.fileOps,
        },
      };
    } catch (error) {
      ctx.ui.notify(`context-manager compaction fallback: ${describeError(error)}`, "warning");
      return;
    }
  });

  pi.on("session_before_tree", async (event, ctx) => {
    syncContextWindow(ctx);
    if (!event.preparation.userWantsSummary) return;

    // Same fallback policy as compaction: warn and return nothing instead of failing the navigation.
    try {
      return {
        summary: {
          summary: buildBranchSummaryFromEntries({
            branchLabel: "branch handoff",
            entriesToSummarize: event.preparation.entriesToSummarize,
            customInstructions: event.preparation.customInstructions,
            replaceInstructions: event.preparation.replaceInstructions,
            commonAncestorId: event.preparation.commonAncestorId,
          }),
        },
      };
    } catch (error) {
      ctx.ui.notify(`context-manager branch summary fallback: ${describeError(error)}`, "warning");
      return;
    }
  });

  pi.on("session_compact", async (event, ctx) => {
    runtime.recordCompactionSummary(event.compactionEntry.summary, event.compactionEntry.id, Date.parse(event.compactionEntry.timestamp));
    pi.appendEntry(SNAPSHOT_ENTRY_TYPE, serializeSnapshot(runtime.getSnapshot()));
    armResumeInjection();
    updateStatus(ctx);
  });
}
|