fix(chat): stop retry spirals on tool failures

This commit is contained in:
2026-03-10 18:05:34 +00:00
parent 212ce33e36
commit dbabbdf9f0
6 changed files with 490 additions and 14 deletions

View File

@@ -295,6 +295,7 @@ def web_search(user, query: str, location_context: str | None = None) -> dict:
return {
"error": "Web search is not available (duckduckgo-search not installed)",
"results": [],
"retryable": False,
}
except Exception as exc:
error_str = str(exc).lower()
@@ -637,9 +638,9 @@ def execute_tool(tool_name, user, **kwargs):
try:
return tool_fn(user=user, **filtered_kwargs)
except Exception as exc:
except Exception:
logger.exception("Tool %s failed", tool_name)
return {"error": str(exc)}
return {"error": "Tool execution failed"}
AGENT_TOOLS = get_tool_schemas()

View File

@@ -1,4 +1,5 @@
import json
from unittest import mock
from unittest.mock import patch
from django.contrib.auth import get_user_model
@@ -6,13 +7,64 @@ from django.test import TestCase
from rest_framework.test import APITransactionTestCase
from adventures.models import Collection, CollectionItineraryItem, Location
from chat.agent_tools import add_to_itinerary, get_trip_details
from chat.agent_tools import (
add_to_itinerary,
execute_tool,
get_trip_details,
web_search,
)
from chat.views import ChatViewSet
User = get_user_model()
class WebSearchToolFailureClassificationTests(TestCase):
def test_web_search_import_error_sets_retryable_false(self):
import builtins
original_import = builtins.__import__
def controlled_import(name, *args, **kwargs):
if name == "duckduckgo_search":
raise ImportError("missing dependency")
return original_import(name, *args, **kwargs)
user = User.objects.create_user(
username="chat-web-search-user",
email="chat-web-search-user@example.com",
password="password123",
)
with mock.patch("builtins.__import__", side_effect=controlled_import):
result = web_search(user, query="best restaurants")
self.assertEqual(
result.get("error"),
"Web search is not available (duckduckgo-search not installed)",
)
self.assertEqual(result.get("retryable"), False)
class ExecuteToolErrorSanitizationTests(TestCase):
def test_execute_tool_catch_all_returns_sanitized_error_message(self):
def raising_tool(user):
raise RuntimeError("sensitive backend detail")
user = User.objects.create_user(
username="chat-tool-sanitize-user",
email="chat-tool-sanitize-user@example.com",
password="password123",
)
with mock.patch.dict(
"chat.agent_tools._REGISTERED_TOOLS", {"boom": raising_tool}
):
result = execute_tool("boom", user)
self.assertEqual(result, {"error": "Tool execution failed"})
class ChatAgentToolSharedTripAccessTests(TestCase):
def setUp(self):
self.owner = User.objects.create_user(
@@ -173,6 +225,40 @@ class ChatViewSetToolValidationBoundaryTests(TestCase):
)
)
def test_execution_failure_error_detected(self):
self.assertTrue(
ChatViewSet._is_execution_failure_tool_error(
{"error": "Web search failed. Please try again."}
)
)
def test_required_error_not_treated_as_execution_failure(self):
self.assertFalse(
ChatViewSet._is_execution_failure_tool_error(
{"error": "location is required"}
)
)
def test_search_places_geocode_error_detected_for_location_retry(self):
self.assertTrue(
ChatViewSet._is_search_places_location_retry_candidate_error(
"search_places",
{"error": "Could not geocode location: ???"},
)
)
def test_retryable_execution_failure_defaults_true(self):
self.assertTrue(
ChatViewSet._is_retryable_execution_failure({"error": "Temporary outage"})
)
def test_retryable_execution_failure_honors_false_flag(self):
self.assertFalse(
ChatViewSet._is_retryable_execution_failure(
{"error": "Not installed", "retryable": False}
)
)
def test_likely_location_reply_heuristic_positive_case(self):
self.assertTrue(ChatViewSet._is_likely_location_reply("london"))
@@ -557,3 +643,217 @@ class ChatViewSetSearchPlacesClarificationTests(APITransactionTestCase):
self.assertEqual(mock_execute_tool.call_count, 2)
second_call_kwargs = mock_execute_tool.call_args_list[1].kwargs
self.assertEqual(second_call_kwargs.get("location"), "Paris, France")
@patch("chat.views.execute_tool")
@patch("chat.views.stream_chat_completion")
@patch("integrations.utils.auto_profile.update_auto_preference_profile")
def test_geocode_failure_retries_with_trip_context_location(
self,
_mock_auto_profile,
mock_stream_chat_completion,
mock_execute_tool,
):
user = User.objects.create_user(
username="chat-geocode-retry-user",
email="chat-geocode-retry-user@example.com",
password="password123",
)
self.client.force_authenticate(user=user)
conversation_response = self.client.post(
"/api/chat/conversations/",
{"title": "Geocode Retry Test"},
format="json",
)
self.assertEqual(conversation_response.status_code, 201)
conversation_id = conversation_response.json()["id"]
async def first_stream(*args, **kwargs):
yield 'data: {"tool_calls": [{"index": 0, "id": "call_1", "type": "function", "function": {"name": "search_places", "arguments": "{}"}}]}\n\n'
yield "data: [DONE]\n\n"
async def second_stream(*args, **kwargs):
yield 'data: {"content": "Here are options in Lisbon."}\n\n'
yield "data: [DONE]\n\n"
mock_stream_chat_completion.side_effect = [first_stream(), second_stream()]
mock_execute_tool.side_effect = [
{"error": "Could not geocode location: invalid"},
{"results": [{"name": "Time Out Market"}]},
]
response = self.client.post(
f"/api/chat/conversations/{conversation_id}/send_message/",
{
"message": "Find restaurants",
"destination": "Lisbon, Portugal",
},
format="json",
)
self.assertEqual(response.status_code, 200)
chunks = [
chunk.decode("utf-8")
if isinstance(chunk, (bytes, bytearray))
else str(chunk)
for chunk in response.streaming_content
]
payload_lines = [
chunk.strip()[len("data: ") :]
for chunk in chunks
if chunk.strip().startswith("data: ")
]
json_payloads = [
json.loads(payload) for payload in payload_lines if payload != "[DONE]"
]
self.assertTrue(any("tool_result" in payload for payload in json_payloads))
self.assertTrue(
any(
payload.get("content") == "Here are options in Lisbon."
for payload in json_payloads
)
)
self.assertEqual(mock_execute_tool.call_count, 2)
second_call_kwargs = mock_execute_tool.call_args_list[1].kwargs
self.assertEqual(second_call_kwargs.get("location"), "Lisbon, Portugal")
class ChatViewSetToolExecutionFailureLoopTests(APITransactionTestCase):
@patch("chat.views.execute_tool")
@patch("chat.views.stream_chat_completion")
@patch("integrations.utils.auto_profile.update_auto_preference_profile")
def test_all_failure_rounds_stop_with_execution_error_before_tool_cap(
self,
_mock_auto_profile,
mock_stream_chat_completion,
mock_execute_tool,
):
user = User.objects.create_user(
username="chat-loop-failure-user",
email="chat-loop-failure-user@example.com",
password="password123",
)
self.client.force_authenticate(user=user)
conversation_response = self.client.post(
"/api/chat/conversations/",
{"title": "Failure Loop Test"},
format="json",
)
self.assertEqual(conversation_response.status_code, 201)
conversation_id = conversation_response.json()["id"]
async def failing_stream(*args, **kwargs):
yield 'data: {"tool_calls": [{"index": 0, "id": "call_w", "type": "function", "function": {"name": "web_search", "arguments": "{\\"query\\":\\"restaurants\\"}"}}]}\n\n'
yield "data: [DONE]\n\n"
mock_stream_chat_completion.side_effect = failing_stream
mock_execute_tool.return_value = {
"error": "Web search failed. Please try again.",
"results": [],
}
response = self.client.post(
f"/api/chat/conversations/{conversation_id}/send_message/",
{"message": "Find restaurants near me"},
format="json",
)
self.assertEqual(response.status_code, 200)
chunks = [
chunk.decode("utf-8")
if isinstance(chunk, (bytes, bytearray))
else str(chunk)
for chunk in response.streaming_content
]
payload_lines = [
chunk.strip()[len("data: ") :]
for chunk in chunks
if chunk.strip().startswith("data: ")
]
json_payloads = [
json.loads(payload) for payload in payload_lines if payload != "[DONE]"
]
self.assertTrue(
any(
payload.get("error_category") == "tool_execution_error"
for payload in json_payloads
)
)
self.assertFalse(
any(
payload.get("error_category") == "tool_loop_limit"
for payload in json_payloads
)
)
self.assertFalse(any("tool_result" in payload for payload in json_payloads))
self.assertEqual(mock_stream_chat_completion.call_count, 3)
self.assertEqual(mock_execute_tool.call_count, 3)
@patch("chat.views.execute_tool")
@patch("chat.views.stream_chat_completion")
@patch("integrations.utils.auto_profile.update_auto_preference_profile")
def test_permanent_execution_failure_stops_immediately(
self,
_mock_auto_profile,
mock_stream_chat_completion,
mock_execute_tool,
):
user = User.objects.create_user(
username="chat-permanent-failure-user",
email="chat-permanent-failure-user@example.com",
password="password123",
)
self.client.force_authenticate(user=user)
conversation_response = self.client.post(
"/api/chat/conversations/",
{"title": "Permanent Failure Test"},
format="json",
)
self.assertEqual(conversation_response.status_code, 201)
conversation_id = conversation_response.json()["id"]
async def failing_stream(*args, **kwargs):
yield 'data: {"tool_calls": [{"index": 0, "id": "call_w", "type": "function", "function": {"name": "web_search", "arguments": "{\\"query\\":\\"restaurants\\"}"}}]}\n\n'
yield "data: [DONE]\n\n"
mock_stream_chat_completion.side_effect = failing_stream
mock_execute_tool.return_value = {
"error": "Web search is not available (duckduckgo-search not installed)",
"results": [],
"retryable": False,
}
response = self.client.post(
f"/api/chat/conversations/{conversation_id}/send_message/",
{"message": "Find restaurants near me"},
format="json",
)
self.assertEqual(response.status_code, 200)
chunks = [
chunk.decode("utf-8")
if isinstance(chunk, (bytes, bytearray))
else str(chunk)
for chunk in response.streaming_content
]
payload_lines = [
chunk.strip()[len("data: ") :]
for chunk in chunks
if chunk.strip().startswith("data: ")
]
json_payloads = [
json.loads(payload) for payload in payload_lines if payload != "[DONE]"
]
self.assertTrue(
any(
payload.get("error_category") == "tool_execution_error"
for payload in json_payloads
)
)
self.assertEqual(mock_stream_chat_completion.call_count, 1)
self.assertEqual(mock_execute_tool.call_count, 1)

View File

@@ -62,6 +62,9 @@ class ChatViewSet(viewsets.ModelViewSet):
if message.role == "tool"
and message.tool_call_id
and not self._is_required_param_tool_error_message_content(message.content)
and not self._is_execution_failure_tool_error_message_content(
message.content
)
}
messages = [
@@ -76,6 +79,13 @@ class ChatViewSet(viewsets.ModelViewSet):
and self._is_required_param_tool_error_message_content(message.content)
):
continue
if (
message.role == "tool"
and self._is_execution_failure_tool_error_message_content(
message.content
)
):
continue
payload = {
"role": message.role,
@@ -152,6 +162,24 @@ class ChatViewSet(viewsets.ModelViewSet):
)
)
@classmethod
def _is_execution_failure_tool_error(cls, result):
if not isinstance(result, dict):
return False
error_text = result.get("error")
if not isinstance(error_text, str) or not error_text.strip():
return False
return not cls._is_required_param_tool_error(result)
@staticmethod
def _is_retryable_execution_failure(result):
if not isinstance(result, dict):
return False
return result.get("retryable", True) is not False
@classmethod
def _is_required_param_tool_error_message_content(cls, content):
if not isinstance(content, str):
@@ -164,6 +192,18 @@ class ChatViewSet(viewsets.ModelViewSet):
return cls._is_required_param_tool_error(parsed)
@classmethod
def _is_execution_failure_tool_error_message_content(cls, content):
if not isinstance(content, str):
return False
try:
parsed = json.loads(content)
except json.JSONDecodeError:
return False
return cls._is_execution_failure_tool_error(parsed)
@staticmethod
def _build_required_param_error_event(tool_name, result):
tool_error = result.get("error") if isinstance(result, dict) else ""
@@ -190,6 +230,24 @@ class ChatViewSet(viewsets.ModelViewSet):
normalized_error = error_text.strip().lower()
return "location" in normalized_error
@staticmethod
def _is_search_places_geocode_error(tool_name, result):
if tool_name != "search_places" or not isinstance(result, dict):
return False
error_text = result.get("error")
if not isinstance(error_text, str):
return False
return error_text.strip().lower().startswith("could not geocode location")
@classmethod
def _is_search_places_location_retry_candidate_error(cls, tool_name, result):
return cls._is_search_places_missing_location_required_error(
tool_name,
result,
) or cls._is_search_places_geocode_error(tool_name, result)
@staticmethod
def _build_search_places_location_clarification_message():
return (
@@ -198,6 +256,21 @@ class ChatViewSet(viewsets.ModelViewSet):
"activities, or lodging."
)
@staticmethod
def _build_tool_execution_error_event(tool_name, result):
tool_error = (
(result or {}).get("error")
if isinstance(result, dict)
else "Tool execution failed"
)
return {
"error": (
f"The assistant could not complete '{tool_name}' ({tool_error}). "
"Please try again in a moment or adjust your request."
),
"error_category": "tool_execution_error",
}
@staticmethod
def _normalize_trip_context_destination(destination):
destination_text = (destination or "").strip()
@@ -420,11 +493,13 @@ class ChatViewSet(viewsets.ModelViewSet):
)
MAX_TOOL_ITERATIONS = 10
MAX_ALL_FAILURE_ROUNDS = 3
async def event_stream():
current_messages = list(llm_messages)
encountered_error = False
tool_iterations = 0
all_failure_rounds = 0
while tool_iterations < MAX_TOOL_ITERATIONS:
content_chunks = []
@@ -472,10 +547,11 @@ class ChatViewSet(viewsets.ModelViewSet):
assistant_content = "".join(content_chunks)
if tool_calls_accumulator:
tool_iterations += 1
successful_tool_calls = []
successful_tool_messages = []
successful_tool_chat_entries = []
first_execution_failure = None
encountered_permanent_failure = False
for tool_call in tool_calls_accumulator:
function_payload = tool_call.get("function") or {}
@@ -519,7 +595,7 @@ class ChatViewSet(viewsets.ModelViewSet):
**prepared_arguments,
)
if self._is_search_places_missing_location_required_error(
if self._is_search_places_location_retry_candidate_error(
function_name,
result,
):
@@ -552,7 +628,11 @@ class ChatViewSet(viewsets.ModelViewSet):
**retry_arguments,
)
if not self._is_required_param_tool_error(retry_result):
if not self._is_required_param_tool_error(
retry_result
) and not self._is_execution_failure_tool_error(
retry_result
):
result = retry_result
tool_call_for_history = {
**tool_call,
@@ -630,6 +710,13 @@ class ChatViewSet(viewsets.ModelViewSet):
yield "data: [DONE]\n\n"
return
if self._is_execution_failure_tool_error(result):
if first_execution_failure is None:
first_execution_failure = (function_name, result)
if not self._is_retryable_execution_failure(result):
encountered_permanent_failure = True
continue
result_content = serialize_tool_result(result)
successful_tool_calls.append(tool_call_for_history)
@@ -659,6 +746,41 @@ class ChatViewSet(viewsets.ModelViewSet):
}
yield f"data: {json.dumps(tool_event)}\n\n"
if not successful_tool_calls and first_execution_failure:
if encountered_permanent_failure:
all_failure_rounds = MAX_ALL_FAILURE_ROUNDS
else:
all_failure_rounds += 1
if all_failure_rounds >= MAX_ALL_FAILURE_ROUNDS:
failed_tool_name, failed_tool_result = (
first_execution_failure
)
error_event = self._build_tool_execution_error_event(
failed_tool_name,
failed_tool_result,
)
await sync_to_async(
ChatMessage.objects.create,
thread_sensitive=True,
)(
conversation=conversation,
role="assistant",
content=error_event["error"],
)
await sync_to_async(
conversation.save,
thread_sensitive=True,
)(update_fields=["updated_at"])
yield f"data: {json.dumps(error_event)}\n\n"
yield "data: [DONE]\n\n"
return
continue
all_failure_rounds = 0
tool_iterations += 1
assistant_with_tools = {
"role": "assistant",
"content": assistant_content,