Files
dotfiles/.pi/agent/extensions/context-manager/index.ts
2026-04-09 23:14:57 +01:00

354 lines
12 KiB
TypeScript

import type { AgentMessage } from "@mariozechner/pi-agent-core";
import type { ExtensionAPI, ExtensionContext } from "@mariozechner/pi-coding-agent";
import { adjustPolicyForZone } from "./src/config.ts";
import { deserializeLatestSnapshot, serializeSnapshot, SNAPSHOT_ENTRY_TYPE, type RuntimeSnapshot } from "./src/persist.ts";
import { createEmptyLedger } from "./src/ledger.ts";
import { pruneContextMessages } from "./src/prune.ts";
import { createContextManagerRuntime } from "./src/runtime.ts";
import { registerContextCommands } from "./src/commands.ts";
import { buildBranchSummaryFromEntries, buildCompactionSummaryFromPreparation } from "./src/summaries.ts";
/** Agent messages whose roles (`user`, `assistant`, `toolResult`) are forwarded to `runtime.ingest` during branch replay. */
type TrackedMessage = Extract<AgentMessage, { role: "user" | "assistant" | "toolResult" }>;
/** One element of the array returned by `ctx.sessionManager.getBranch()`. */
type BranchEntry = ReturnType<ExtensionContext["sessionManager"]["getBranch"]>[number];
/**
 * Type guard for content parts tagged `type: "text"`.
 * Only the tag is inspected; `text` may still be absent or non-string.
 */
function isTextPart(part: unknown): part is { type: "text"; text?: string } {
  if (typeof part !== "object" || part === null) {
    return false;
  }
  if (!("type" in part)) {
    return false;
  }
  return (part as { type?: unknown }).type === "text";
}
/**
 * Flattens message content into plain text.
 *
 * - A string is returned unchanged (not trimmed).
 * - An array contributes one line per element: the `text` of text parts, or an
 *   empty string for anything else. The lines are joined with "\n" and the result is trimmed.
 * - Any other value yields "".
 */
function toText(content: unknown): string {
  if (typeof content === "string") {
    return content;
  }
  if (!Array.isArray(content)) {
    return "";
  }
  const lines: string[] = [];
  for (const part of content) {
    lines.push(isTextPart(part) && typeof part.text === "string" ? part.text : "");
  }
  return lines.join("\n").trim();
}
/** Narrows a branch entry to one carrying a chat message (`type: "message"`). */
function isMessageEntry(entry: BranchEntry): entry is Extract<BranchEntry, { type: "message" }> {
  const { type } = entry;
  return type === "message";
}
/** Narrows a branch entry to a compaction summary entry (`type: "compaction"`). */
function isCompactionEntry(entry: BranchEntry): entry is Extract<BranchEntry, { type: "compaction" }> {
  const { type } = entry;
  return type === "compaction";
}
/** Narrows a branch entry to a branch-summary entry (`type: "branch_summary"`). */
function isBranchSummaryEntry(entry: BranchEntry): entry is Extract<BranchEntry, { type: "branch_summary" }> {
  const { type } = entry;
  return type === "branch_summary";
}
/** True for the message roles the runtime ingests: user, assistant and tool results. */
function isTrackedMessage(message: AgentMessage): message is TrackedMessage {
  switch (message.role) {
    case "user":
    case "assistant":
    case "toolResult":
      return true;
    default:
      return false;
  }
}
/** Baseline runtime state: balanced mode, green zone, and a freshly created empty ledger. */
function createDefaultSnapshot(): RuntimeSnapshot {
  const ledger = createEmptyLedger();
  return { mode: "balanced", lastZone: "green", ledger };
}
/** Plain-text view of a message's `content` (via `toText`), or "" when the message has no `content` field. */
function getMessageContent(message: AgentMessage): string {
  if (!("content" in message)) {
    return "";
  }
  return toText(message.content);
}
/** The `toolName` of a tool-result message; `undefined` for every other role. */
function getMessageToolName(message: AgentMessage): string | undefined {
  if (message.role !== "toolResult") {
    return undefined;
  }
  return message.toolName;
}
/**
 * Maps a pruned message back to an `AgentMessage`.
 *
 * Only distilled tool results are rewritten. The result is a shallow copy of the
 * original with its content replaced by a single text part holding the distilled
 * text. Every other message is returned as the original object.
 */
function rewriteContextMessage(message: { role: string; content: string; original: AgentMessage; distilled?: boolean }): AgentMessage {
  if (message.distilled && message.role === "toolResult") {
    const toolResult = message.original as Extract<AgentMessage, { role: "toolResult" }>;
    return { ...toolResult, content: [{ type: "text", text: message.content }] } as AgentMessage;
  }
  return message.original;
}
/**
 * Walks the branch from newest to oldest and returns the first custom entry of
 * type `SNAPSHOT_ENTRY_TYPE` that deserializes successfully, along with its index
 * in `branch`. Entries that fail to deserialize are skipped.
 * Returns `undefined` when no usable snapshot exists.
 */
function findLatestSnapshotState(branch: BranchEntry[]): { snapshot: RuntimeSnapshot; index: number } | undefined {
  let index = branch.length;
  while (index-- > 0) {
    const entry = branch[index]!;
    if (entry.type === "custom" && entry.customType === SNAPSHOT_ENTRY_TYPE) {
      const snapshot = deserializeLatestSnapshot([entry]);
      if (snapshot) {
        return { snapshot, index };
      }
    }
  }
  return undefined;
}
/**
 * Among all snapshot entries in `entries`, picks the one whose newest
 * session-scope ledger item has the largest timestamp. The caller passes every
 * session entry, not only the active branch.
 *
 * A snapshot without session items scores -Infinity. Ties go to the later entry
 * (`>=`), so when no snapshot has session items the last decodable one wins.
 */
function findLatestSessionSnapshot(entries: BranchEntry[]): RuntimeSnapshot | undefined {
  let best: RuntimeSnapshot | undefined;
  let bestFreshness = -Infinity;
  for (const entry of entries) {
    if (entry.type === "custom" && entry.customType === SNAPSHOT_ENTRY_TYPE) {
      const candidate = deserializeLatestSnapshot([entry]);
      if (candidate) {
        // Seeding with -Infinity matches Math.max() on an empty list.
        const freshness = candidate.ledger.items
          .filter((item) => item.scope === "session")
          .reduce((newest, item) => Math.max(newest, item.timestamp), -Infinity);
        if (freshness >= bestFreshness) {
          best = candidate;
          bestFreshness = freshness;
        }
      }
    }
  }
  return best;
}
/**
 * Builds a snapshot that keeps only session-scope ledger items (deep-copied) from
 * `source`. The mode is inherited from `source` (default "balanced"); the zone is
 * reset to "green" and the rolling summary is cleared.
 */
function createSessionFallbackSnapshot(source?: RuntimeSnapshot): RuntimeSnapshot {
  const mode = source?.mode ?? "balanced";
  const inherited = source?.ledger.items ?? [];
  const sessionOnly = inherited.filter((item) => item.scope === "session");
  return {
    mode,
    lastZone: "green",
    ledger: { items: structuredClone(sessionOnly), rollingSummary: "" },
  };
}
/**
 * Replaces the session-scope ledger items of `base` with those from
 * `latestSessionSnapshot` (both sides deep-copied), keeping base's other items
 * first.
 *
 * When the overlay has no session items, `base` itself is returned unchanged.
 * All other snapshot and ledger fields come from `base`.
 */
function overlaySessionLayer(base: RuntimeSnapshot, latestSessionSnapshot?: RuntimeSnapshot): RuntimeSnapshot {
  const overlayItems = latestSessionSnapshot?.ledger.items.filter((item) => item.scope === "session") ?? [];
  if (!overlayItems.length) {
    return base;
  }
  const keptBaseItems = base.ledger.items.filter((item) => item.scope !== "session");
  const mergedItems = [...structuredClone(keptBaseItems), ...structuredClone(overlayItems)];
  return { ...base, ledger: { ...base.ledger, items: mergedItems } };
}
// Context window assumed until the active model reports its own.
const DEFAULT_CONTEXT_WINDOW = 200_000;

/**
 * Pi extension entry point for the context manager.
 *
 * Keeps a `createContextManagerRuntime` instance in sync with the active session
 * branch and wires it into pi's lifecycle hooks:
 * - `session_start`, `session_tree` and `turn_end` rebuild runtime state. They use
 *   the newest persisted snapshot on the branch plus the entries after it, with
 *   session-scope ledger items overlaid from the whole session.
 * - `context` prunes outgoing messages with `pruneContextMessages`. It prepends
 *   either a one-shot resume packet (armed after a compaction or branch summary)
 *   or the regular context packet.
 * - `session_before_compact` and `session_before_tree` supply summaries built by
 *   `./src/summaries.ts`.
 * - `turn_end`, `session_compact` and summarised `session_tree` navigations
 *   persist a `SNAPSHOT_ENTRY_TYPE` custom entry.
 *
 * @param pi Extension API used to register commands and event handlers.
 */
export default function contextManager(pi: ExtensionAPI) {
  const runtime = createContextManagerRuntime({
    mode: "balanced",
    contextWindow: DEFAULT_CONTEXT_WINDOW,
  });
  // When true, the next `context` event injects the resume packet instead of the regular packet.
  let pendingResumeInjection = false;

  const syncContextWindow = (ctx: Pick<ExtensionContext, "model">) => {
    runtime.setContextWindow(ctx.model?.contextWindow ?? DEFAULT_CONTEXT_WINDOW);
  };

  // Mirrors the runtime's current zone into pi's status bar.
  const publishStatus = (ctx: Pick<ExtensionContext, "ui">) => {
    ctx.ui.setStatus("context-manager", `ctx ${runtime.getSnapshot().lastZone}`);
  };

  // Arm the resume packet only if a compaction/branch summary exists and the packet is non-blank.
  const armResumeInjection = () => {
    const snapshot = runtime.getSnapshot();
    pendingResumeInjection = Boolean(snapshot.lastCompactionSummary || snapshot.lastBranchSummary) && runtime.buildResumePacket().trim().length > 0;
  };

  // Feeds one persisted branch entry back into the runtime; other entry kinds are ignored.
  const replayBranchEntry = (entry: BranchEntry) => {
    if (isMessageEntry(entry) && isTrackedMessage(entry.message)) {
      runtime.ingest({
        entryId: entry.id,
        role: entry.message.role,
        text: toText(entry.message.content),
        timestamp: entry.message.timestamp,
        isError: entry.message.role === "toolResult" ? entry.message.isError : undefined,
      });
      return;
    }
    if (isCompactionEntry(entry)) {
      runtime.recordCompactionSummary(entry.summary, entry.id, Date.parse(entry.timestamp));
      return;
    }
    if (isBranchSummaryEntry(entry)) {
      runtime.recordBranchSummary(entry.summary, entry.id, Date.parse(entry.timestamp));
    }
  };

  /**
   * Rebuilds runtime state from the active branch.
   *
   * It restores the newest snapshot on the branch with the freshest session layer
   * laid over it, then replays every branch entry after that snapshot. Without a
   * branch snapshot, it starts from the session-scope items of the freshest
   * session snapshot, or of `fallbackSnapshot`, and replays the whole branch.
   *
   * With `preferRuntimeMode`, the mode is taken from `fallbackSnapshot` instead of
   * the restored snapshot.
   */
  const rebuildRuntimeFromBranch = (
    ctx: Pick<ExtensionContext, "model" | "sessionManager" | "ui">,
    fallbackSnapshot: RuntimeSnapshot,
    options?: { preferRuntimeMode?: boolean },
  ) => {
    syncContextWindow(ctx);
    const branch = ctx.sessionManager.getBranch();
    const latestSessionSnapshot = findLatestSessionSnapshot(ctx.sessionManager.getEntries() as BranchEntry[]);
    const restored = findLatestSnapshotState(branch);
    const baseSnapshot = restored
      ? overlaySessionLayer(restored.snapshot, latestSessionSnapshot)
      : createSessionFallbackSnapshot(latestSessionSnapshot ?? fallbackSnapshot);
    runtime.restore({
      ...baseSnapshot,
      mode: options?.preferRuntimeMode ? fallbackSnapshot.mode : baseSnapshot.mode,
    });
    const replayEntries = restored ? branch.slice(restored.index + 1) : branch;
    for (const entry of replayEntries) {
      replayBranchEntry(entry);
    }
    publishStatus(ctx);
  };

  registerContextCommands(pi, {
    getSnapshot: runtime.getSnapshot,
    buildPacket: runtime.buildPacket,
    buildResumePacket: runtime.buildResumePacket,
    setMode: runtime.setMode,
    rebuildFromBranch: async (commandCtx) => {
      rebuildRuntimeFromBranch(commandCtx, runtime.getSnapshot(), { preferRuntimeMode: true });
      armResumeInjection();
    },
    isResumePending: () => pendingResumeInjection,
  });

  pi.on("session_start", async (_event, ctx) => {
    rebuildRuntimeFromBranch(ctx, createDefaultSnapshot());
    armResumeInjection();
  });

  pi.on("session_tree", async (event, ctx) => {
    rebuildRuntimeFromBranch(ctx, createDefaultSnapshot());
    // Bind to a const: property narrowing on `event.summaryEntry` does not carry into the `.some` callback.
    const summaryEntry = event.summaryEntry;
    if (summaryEntry) {
      const alreadyOnBranch = ctx.sessionManager
        .getBranch()
        .some((entry) => isBranchSummaryEntry(entry) && entry.id === summaryEntry.id);
      if (!alreadyOnBranch) {
        runtime.recordBranchSummary(summaryEntry.summary, summaryEntry.id, Date.parse(summaryEntry.timestamp));
      }
    }
    armResumeInjection();
    if (summaryEntry) {
      pi.appendEntry(SNAPSHOT_ENTRY_TYPE, serializeSnapshot(runtime.getSnapshot()));
    }
  });

  pi.on("tool_result", async (event) => {
    runtime.ingest({
      entryId: event.toolCallId,
      role: "toolResult",
      text: toText(event.content),
      timestamp: Date.now(),
    });
  });

  pi.on("turn_end", async (_event, ctx) => {
    rebuildRuntimeFromBranch(ctx, runtime.getSnapshot(), { preferRuntimeMode: true });
    const usage = ctx.getContextUsage();
    if (usage?.tokens !== null && usage?.tokens !== undefined) {
      runtime.observeTokens(usage.tokens);
    }
    pi.appendEntry(SNAPSHOT_ENTRY_TYPE, serializeSnapshot(runtime.getSnapshot()));
    publishStatus(ctx);
  });

  pi.on("context", async (event, ctx) => {
    syncContextWindow(ctx);
    const snapshot = runtime.getSnapshot();
    const policy = adjustPolicyForZone(runtime.getPolicy(), snapshot.lastZone);
    const normalized = event.messages.map((message) => ({
      role: message.role,
      content: getMessageContent(message),
      toolName: getMessageToolName(message),
      original: message,
    }));
    const pruned = pruneContextMessages(normalized, policy);
    const nextMessages = pruned.map((message) =>
      rewriteContextMessage(message as { role: string; content: string; original: AgentMessage; distilled?: boolean }),
    );
    const resumeText = pendingResumeInjection ? runtime.buildResumePacket() : "";
    // The resume packet is one-shot. Clear the flag even when the packet came back
    // empty; otherwise the regular packet would be suppressed on every later turn.
    pendingResumeInjection = false;
    const packetText = resumeText ? "" : runtime.buildPacket().text;
    const injectedText = resumeText || packetText;
    if (!injectedText) {
      return { messages: nextMessages };
    }
    return {
      messages: [
        {
          role: "custom",
          customType: resumeText ? "context-manager.resume" : "context-manager.packet",
          content: injectedText,
          display: false,
          timestamp: Date.now(),
        } as any,
        ...nextMessages,
      ],
    };
  });

  pi.on("session_before_compact", async (event, ctx) => {
    syncContextWindow(ctx);
    try {
      return {
        compaction: {
          summary: buildCompactionSummaryFromPreparation({
            messagesToSummarize: event.preparation.messagesToSummarize,
            turnPrefixMessages: event.preparation.turnPrefixMessages,
            previousSummary: event.preparation.previousSummary,
            fileOps: event.preparation.fileOps,
            customInstructions: event.customInstructions,
          }),
          firstKeptEntryId: event.preparation.firstKeptEntryId,
          tokensBefore: event.preparation.tokensBefore,
          details: event.preparation.fileOps,
        },
      };
    } catch (error) {
      // Returning nothing after notifying leaves the compaction to pi's default handling.
      ctx.ui.notify(`context-manager compaction fallback: ${error instanceof Error ? error.message : String(error)}`, "warning");
      return;
    }
  });

  pi.on("session_before_tree", async (event, ctx) => {
    syncContextWindow(ctx);
    if (!event.preparation.userWantsSummary) return;
    return {
      summary: {
        summary: buildBranchSummaryFromEntries({
          branchLabel: "branch handoff",
          entriesToSummarize: event.preparation.entriesToSummarize,
          customInstructions: event.preparation.customInstructions,
          replaceInstructions: event.preparation.replaceInstructions,
          commonAncestorId: event.preparation.commonAncestorId,
        }),
      },
    };
  });

  pi.on("session_compact", async (event, ctx) => {
    runtime.recordCompactionSummary(event.compactionEntry.summary, event.compactionEntry.id, Date.parse(event.compactionEntry.timestamp));
    pi.appendEntry(SNAPSHOT_ENTRY_TYPE, serializeSnapshot(runtime.getSnapshot()));
    armResumeInjection();
    publishStatus(ctx);
  });
}