fix(chat): stop retry spirals on tool failures
This commit is contained in:
@@ -62,6 +62,9 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
if message.role == "tool"
|
||||
and message.tool_call_id
|
||||
and not self._is_required_param_tool_error_message_content(message.content)
|
||||
and not self._is_execution_failure_tool_error_message_content(
|
||||
message.content
|
||||
)
|
||||
}
|
||||
|
||||
messages = [
|
||||
@@ -76,6 +79,13 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
and self._is_required_param_tool_error_message_content(message.content)
|
||||
):
|
||||
continue
|
||||
if (
|
||||
message.role == "tool"
|
||||
and self._is_execution_failure_tool_error_message_content(
|
||||
message.content
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
payload = {
|
||||
"role": message.role,
|
||||
@@ -152,6 +162,24 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _is_execution_failure_tool_error(cls, result):
|
||||
if not isinstance(result, dict):
|
||||
return False
|
||||
|
||||
error_text = result.get("error")
|
||||
if not isinstance(error_text, str) or not error_text.strip():
|
||||
return False
|
||||
|
||||
return not cls._is_required_param_tool_error(result)
|
||||
|
||||
@staticmethod
|
||||
def _is_retryable_execution_failure(result):
|
||||
if not isinstance(result, dict):
|
||||
return False
|
||||
|
||||
return result.get("retryable", True) is not False
|
||||
|
||||
@classmethod
|
||||
def _is_required_param_tool_error_message_content(cls, content):
|
||||
if not isinstance(content, str):
|
||||
@@ -164,6 +192,18 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
|
||||
return cls._is_required_param_tool_error(parsed)
|
||||
|
||||
@classmethod
|
||||
def _is_execution_failure_tool_error_message_content(cls, content):
|
||||
if not isinstance(content, str):
|
||||
return False
|
||||
|
||||
try:
|
||||
parsed = json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
return False
|
||||
|
||||
return cls._is_execution_failure_tool_error(parsed)
|
||||
|
||||
@staticmethod
|
||||
def _build_required_param_error_event(tool_name, result):
|
||||
tool_error = result.get("error") if isinstance(result, dict) else ""
|
||||
@@ -190,6 +230,24 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
normalized_error = error_text.strip().lower()
|
||||
return "location" in normalized_error
|
||||
|
||||
@staticmethod
|
||||
def _is_search_places_geocode_error(tool_name, result):
|
||||
if tool_name != "search_places" or not isinstance(result, dict):
|
||||
return False
|
||||
|
||||
error_text = result.get("error")
|
||||
if not isinstance(error_text, str):
|
||||
return False
|
||||
|
||||
return error_text.strip().lower().startswith("could not geocode location")
|
||||
|
||||
@classmethod
|
||||
def _is_search_places_location_retry_candidate_error(cls, tool_name, result):
|
||||
return cls._is_search_places_missing_location_required_error(
|
||||
tool_name,
|
||||
result,
|
||||
) or cls._is_search_places_geocode_error(tool_name, result)
|
||||
|
||||
@staticmethod
|
||||
def _build_search_places_location_clarification_message():
|
||||
return (
|
||||
@@ -198,6 +256,21 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
"activities, or lodging."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_tool_execution_error_event(tool_name, result):
|
||||
tool_error = (
|
||||
(result or {}).get("error")
|
||||
if isinstance(result, dict)
|
||||
else "Tool execution failed"
|
||||
)
|
||||
return {
|
||||
"error": (
|
||||
f"The assistant could not complete '{tool_name}' ({tool_error}). "
|
||||
"Please try again in a moment or adjust your request."
|
||||
),
|
||||
"error_category": "tool_execution_error",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _normalize_trip_context_destination(destination):
|
||||
destination_text = (destination or "").strip()
|
||||
@@ -420,11 +493,13 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
)
|
||||
|
||||
MAX_TOOL_ITERATIONS = 10
|
||||
MAX_ALL_FAILURE_ROUNDS = 3
|
||||
|
||||
async def event_stream():
|
||||
current_messages = list(llm_messages)
|
||||
encountered_error = False
|
||||
tool_iterations = 0
|
||||
all_failure_rounds = 0
|
||||
|
||||
while tool_iterations < MAX_TOOL_ITERATIONS:
|
||||
content_chunks = []
|
||||
@@ -472,10 +547,11 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
assistant_content = "".join(content_chunks)
|
||||
|
||||
if tool_calls_accumulator:
|
||||
tool_iterations += 1
|
||||
successful_tool_calls = []
|
||||
successful_tool_messages = []
|
||||
successful_tool_chat_entries = []
|
||||
first_execution_failure = None
|
||||
encountered_permanent_failure = False
|
||||
|
||||
for tool_call in tool_calls_accumulator:
|
||||
function_payload = tool_call.get("function") or {}
|
||||
@@ -519,7 +595,7 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
**prepared_arguments,
|
||||
)
|
||||
|
||||
if self._is_search_places_missing_location_required_error(
|
||||
if self._is_search_places_location_retry_candidate_error(
|
||||
function_name,
|
||||
result,
|
||||
):
|
||||
@@ -552,7 +628,11 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
**retry_arguments,
|
||||
)
|
||||
|
||||
if not self._is_required_param_tool_error(retry_result):
|
||||
if not self._is_required_param_tool_error(
|
||||
retry_result
|
||||
) and not self._is_execution_failure_tool_error(
|
||||
retry_result
|
||||
):
|
||||
result = retry_result
|
||||
tool_call_for_history = {
|
||||
**tool_call,
|
||||
@@ -630,6 +710,13 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
if self._is_execution_failure_tool_error(result):
|
||||
if first_execution_failure is None:
|
||||
first_execution_failure = (function_name, result)
|
||||
if not self._is_retryable_execution_failure(result):
|
||||
encountered_permanent_failure = True
|
||||
continue
|
||||
|
||||
result_content = serialize_tool_result(result)
|
||||
|
||||
successful_tool_calls.append(tool_call_for_history)
|
||||
@@ -659,6 +746,41 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
}
|
||||
yield f"data: {json.dumps(tool_event)}\n\n"
|
||||
|
||||
if not successful_tool_calls and first_execution_failure:
|
||||
if encountered_permanent_failure:
|
||||
all_failure_rounds = MAX_ALL_FAILURE_ROUNDS
|
||||
else:
|
||||
all_failure_rounds += 1
|
||||
|
||||
if all_failure_rounds >= MAX_ALL_FAILURE_ROUNDS:
|
||||
failed_tool_name, failed_tool_result = (
|
||||
first_execution_failure
|
||||
)
|
||||
error_event = self._build_tool_execution_error_event(
|
||||
failed_tool_name,
|
||||
failed_tool_result,
|
||||
)
|
||||
await sync_to_async(
|
||||
ChatMessage.objects.create,
|
||||
thread_sensitive=True,
|
||||
)(
|
||||
conversation=conversation,
|
||||
role="assistant",
|
||||
content=error_event["error"],
|
||||
)
|
||||
await sync_to_async(
|
||||
conversation.save,
|
||||
thread_sensitive=True,
|
||||
)(update_fields=["updated_at"])
|
||||
yield f"data: {json.dumps(error_event)}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
continue
|
||||
|
||||
all_failure_rounds = 0
|
||||
tool_iterations += 1
|
||||
|
||||
assistant_with_tools = {
|
||||
"role": "assistant",
|
||||
"content": assistant_content,
|
||||
|
||||
Reference in New Issue
Block a user