fix(chat): stop retry spirals on tool failures
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
from unittest import mock
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
@@ -6,13 +7,64 @@ from django.test import TestCase
|
||||
from rest_framework.test import APITransactionTestCase
|
||||
|
||||
from adventures.models import Collection, CollectionItineraryItem, Location
|
||||
from chat.agent_tools import add_to_itinerary, get_trip_details
|
||||
from chat.agent_tools import (
|
||||
add_to_itinerary,
|
||||
execute_tool,
|
||||
get_trip_details,
|
||||
web_search,
|
||||
)
|
||||
from chat.views import ChatViewSet
|
||||
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
|
||||
class WebSearchToolFailureClassificationTests(TestCase):
|
||||
def test_web_search_import_error_sets_retryable_false(self):
|
||||
import builtins
|
||||
|
||||
original_import = builtins.__import__
|
||||
|
||||
def controlled_import(name, *args, **kwargs):
|
||||
if name == "duckduckgo_search":
|
||||
raise ImportError("missing dependency")
|
||||
return original_import(name, *args, **kwargs)
|
||||
|
||||
user = User.objects.create_user(
|
||||
username="chat-web-search-user",
|
||||
email="chat-web-search-user@example.com",
|
||||
password="password123",
|
||||
)
|
||||
|
||||
with mock.patch("builtins.__import__", side_effect=controlled_import):
|
||||
result = web_search(user, query="best restaurants")
|
||||
|
||||
self.assertEqual(
|
||||
result.get("error"),
|
||||
"Web search is not available (duckduckgo-search not installed)",
|
||||
)
|
||||
self.assertEqual(result.get("retryable"), False)
|
||||
|
||||
|
||||
class ExecuteToolErrorSanitizationTests(TestCase):
|
||||
def test_execute_tool_catch_all_returns_sanitized_error_message(self):
|
||||
def raising_tool(user):
|
||||
raise RuntimeError("sensitive backend detail")
|
||||
|
||||
user = User.objects.create_user(
|
||||
username="chat-tool-sanitize-user",
|
||||
email="chat-tool-sanitize-user@example.com",
|
||||
password="password123",
|
||||
)
|
||||
|
||||
with mock.patch.dict(
|
||||
"chat.agent_tools._REGISTERED_TOOLS", {"boom": raising_tool}
|
||||
):
|
||||
result = execute_tool("boom", user)
|
||||
|
||||
self.assertEqual(result, {"error": "Tool execution failed"})
|
||||
|
||||
|
||||
class ChatAgentToolSharedTripAccessTests(TestCase):
|
||||
def setUp(self):
|
||||
self.owner = User.objects.create_user(
|
||||
@@ -173,6 +225,40 @@ class ChatViewSetToolValidationBoundaryTests(TestCase):
|
||||
)
|
||||
)
|
||||
|
||||
def test_execution_failure_error_detected(self):
|
||||
self.assertTrue(
|
||||
ChatViewSet._is_execution_failure_tool_error(
|
||||
{"error": "Web search failed. Please try again."}
|
||||
)
|
||||
)
|
||||
|
||||
def test_required_error_not_treated_as_execution_failure(self):
|
||||
self.assertFalse(
|
||||
ChatViewSet._is_execution_failure_tool_error(
|
||||
{"error": "location is required"}
|
||||
)
|
||||
)
|
||||
|
||||
def test_search_places_geocode_error_detected_for_location_retry(self):
|
||||
self.assertTrue(
|
||||
ChatViewSet._is_search_places_location_retry_candidate_error(
|
||||
"search_places",
|
||||
{"error": "Could not geocode location: ???"},
|
||||
)
|
||||
)
|
||||
|
||||
def test_retryable_execution_failure_defaults_true(self):
|
||||
self.assertTrue(
|
||||
ChatViewSet._is_retryable_execution_failure({"error": "Temporary outage"})
|
||||
)
|
||||
|
||||
def test_retryable_execution_failure_honors_false_flag(self):
|
||||
self.assertFalse(
|
||||
ChatViewSet._is_retryable_execution_failure(
|
||||
{"error": "Not installed", "retryable": False}
|
||||
)
|
||||
)
|
||||
|
||||
def test_likely_location_reply_heuristic_positive_case(self):
|
||||
self.assertTrue(ChatViewSet._is_likely_location_reply("london"))
|
||||
|
||||
@@ -557,3 +643,217 @@ class ChatViewSetSearchPlacesClarificationTests(APITransactionTestCase):
|
||||
self.assertEqual(mock_execute_tool.call_count, 2)
|
||||
second_call_kwargs = mock_execute_tool.call_args_list[1].kwargs
|
||||
self.assertEqual(second_call_kwargs.get("location"), "Paris, France")
|
||||
|
||||
@patch("chat.views.execute_tool")
|
||||
@patch("chat.views.stream_chat_completion")
|
||||
@patch("integrations.utils.auto_profile.update_auto_preference_profile")
|
||||
def test_geocode_failure_retries_with_trip_context_location(
|
||||
self,
|
||||
_mock_auto_profile,
|
||||
mock_stream_chat_completion,
|
||||
mock_execute_tool,
|
||||
):
|
||||
user = User.objects.create_user(
|
||||
username="chat-geocode-retry-user",
|
||||
email="chat-geocode-retry-user@example.com",
|
||||
password="password123",
|
||||
)
|
||||
self.client.force_authenticate(user=user)
|
||||
|
||||
conversation_response = self.client.post(
|
||||
"/api/chat/conversations/",
|
||||
{"title": "Geocode Retry Test"},
|
||||
format="json",
|
||||
)
|
||||
self.assertEqual(conversation_response.status_code, 201)
|
||||
conversation_id = conversation_response.json()["id"]
|
||||
|
||||
async def first_stream(*args, **kwargs):
|
||||
yield 'data: {"tool_calls": [{"index": 0, "id": "call_1", "type": "function", "function": {"name": "search_places", "arguments": "{}"}}]}\n\n'
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def second_stream(*args, **kwargs):
|
||||
yield 'data: {"content": "Here are options in Lisbon."}\n\n'
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
mock_stream_chat_completion.side_effect = [first_stream(), second_stream()]
|
||||
mock_execute_tool.side_effect = [
|
||||
{"error": "Could not geocode location: invalid"},
|
||||
{"results": [{"name": "Time Out Market"}]},
|
||||
]
|
||||
|
||||
response = self.client.post(
|
||||
f"/api/chat/conversations/{conversation_id}/send_message/",
|
||||
{
|
||||
"message": "Find restaurants",
|
||||
"destination": "Lisbon, Portugal",
|
||||
},
|
||||
format="json",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
chunks = [
|
||||
chunk.decode("utf-8")
|
||||
if isinstance(chunk, (bytes, bytearray))
|
||||
else str(chunk)
|
||||
for chunk in response.streaming_content
|
||||
]
|
||||
payload_lines = [
|
||||
chunk.strip()[len("data: ") :]
|
||||
for chunk in chunks
|
||||
if chunk.strip().startswith("data: ")
|
||||
]
|
||||
json_payloads = [
|
||||
json.loads(payload) for payload in payload_lines if payload != "[DONE]"
|
||||
]
|
||||
|
||||
self.assertTrue(any("tool_result" in payload for payload in json_payloads))
|
||||
self.assertTrue(
|
||||
any(
|
||||
payload.get("content") == "Here are options in Lisbon."
|
||||
for payload in json_payloads
|
||||
)
|
||||
)
|
||||
self.assertEqual(mock_execute_tool.call_count, 2)
|
||||
second_call_kwargs = mock_execute_tool.call_args_list[1].kwargs
|
||||
self.assertEqual(second_call_kwargs.get("location"), "Lisbon, Portugal")
|
||||
|
||||
|
||||
class ChatViewSetToolExecutionFailureLoopTests(APITransactionTestCase):
|
||||
@patch("chat.views.execute_tool")
|
||||
@patch("chat.views.stream_chat_completion")
|
||||
@patch("integrations.utils.auto_profile.update_auto_preference_profile")
|
||||
def test_all_failure_rounds_stop_with_execution_error_before_tool_cap(
|
||||
self,
|
||||
_mock_auto_profile,
|
||||
mock_stream_chat_completion,
|
||||
mock_execute_tool,
|
||||
):
|
||||
user = User.objects.create_user(
|
||||
username="chat-loop-failure-user",
|
||||
email="chat-loop-failure-user@example.com",
|
||||
password="password123",
|
||||
)
|
||||
self.client.force_authenticate(user=user)
|
||||
|
||||
conversation_response = self.client.post(
|
||||
"/api/chat/conversations/",
|
||||
{"title": "Failure Loop Test"},
|
||||
format="json",
|
||||
)
|
||||
self.assertEqual(conversation_response.status_code, 201)
|
||||
conversation_id = conversation_response.json()["id"]
|
||||
|
||||
async def failing_stream(*args, **kwargs):
|
||||
yield 'data: {"tool_calls": [{"index": 0, "id": "call_w", "type": "function", "function": {"name": "web_search", "arguments": "{\\"query\\":\\"restaurants\\"}"}}]}\n\n'
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
mock_stream_chat_completion.side_effect = failing_stream
|
||||
mock_execute_tool.return_value = {
|
||||
"error": "Web search failed. Please try again.",
|
||||
"results": [],
|
||||
}
|
||||
|
||||
response = self.client.post(
|
||||
f"/api/chat/conversations/{conversation_id}/send_message/",
|
||||
{"message": "Find restaurants near me"},
|
||||
format="json",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
chunks = [
|
||||
chunk.decode("utf-8")
|
||||
if isinstance(chunk, (bytes, bytearray))
|
||||
else str(chunk)
|
||||
for chunk in response.streaming_content
|
||||
]
|
||||
payload_lines = [
|
||||
chunk.strip()[len("data: ") :]
|
||||
for chunk in chunks
|
||||
if chunk.strip().startswith("data: ")
|
||||
]
|
||||
json_payloads = [
|
||||
json.loads(payload) for payload in payload_lines if payload != "[DONE]"
|
||||
]
|
||||
|
||||
self.assertTrue(
|
||||
any(
|
||||
payload.get("error_category") == "tool_execution_error"
|
||||
for payload in json_payloads
|
||||
)
|
||||
)
|
||||
self.assertFalse(
|
||||
any(
|
||||
payload.get("error_category") == "tool_loop_limit"
|
||||
for payload in json_payloads
|
||||
)
|
||||
)
|
||||
self.assertFalse(any("tool_result" in payload for payload in json_payloads))
|
||||
self.assertEqual(mock_stream_chat_completion.call_count, 3)
|
||||
self.assertEqual(mock_execute_tool.call_count, 3)
|
||||
|
||||
@patch("chat.views.execute_tool")
|
||||
@patch("chat.views.stream_chat_completion")
|
||||
@patch("integrations.utils.auto_profile.update_auto_preference_profile")
|
||||
def test_permanent_execution_failure_stops_immediately(
|
||||
self,
|
||||
_mock_auto_profile,
|
||||
mock_stream_chat_completion,
|
||||
mock_execute_tool,
|
||||
):
|
||||
user = User.objects.create_user(
|
||||
username="chat-permanent-failure-user",
|
||||
email="chat-permanent-failure-user@example.com",
|
||||
password="password123",
|
||||
)
|
||||
self.client.force_authenticate(user=user)
|
||||
|
||||
conversation_response = self.client.post(
|
||||
"/api/chat/conversations/",
|
||||
{"title": "Permanent Failure Test"},
|
||||
format="json",
|
||||
)
|
||||
self.assertEqual(conversation_response.status_code, 201)
|
||||
conversation_id = conversation_response.json()["id"]
|
||||
|
||||
async def failing_stream(*args, **kwargs):
|
||||
yield 'data: {"tool_calls": [{"index": 0, "id": "call_w", "type": "function", "function": {"name": "web_search", "arguments": "{\\"query\\":\\"restaurants\\"}"}}]}\n\n'
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
mock_stream_chat_completion.side_effect = failing_stream
|
||||
mock_execute_tool.return_value = {
|
||||
"error": "Web search is not available (duckduckgo-search not installed)",
|
||||
"results": [],
|
||||
"retryable": False,
|
||||
}
|
||||
|
||||
response = self.client.post(
|
||||
f"/api/chat/conversations/{conversation_id}/send_message/",
|
||||
{"message": "Find restaurants near me"},
|
||||
format="json",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
chunks = [
|
||||
chunk.decode("utf-8")
|
||||
if isinstance(chunk, (bytes, bytearray))
|
||||
else str(chunk)
|
||||
for chunk in response.streaming_content
|
||||
]
|
||||
payload_lines = [
|
||||
chunk.strip()[len("data: ") :]
|
||||
for chunk in chunks
|
||||
if chunk.strip().startswith("data: ")
|
||||
]
|
||||
json_payloads = [
|
||||
json.loads(payload) for payload in payload_lines if payload != "[DONE]"
|
||||
]
|
||||
|
||||
self.assertTrue(
|
||||
any(
|
||||
payload.get("error_category") == "tool_execution_error"
|
||||
for payload in json_payloads
|
||||
)
|
||||
)
|
||||
self.assertEqual(mock_stream_chat_completion.call_count, 1)
|
||||
self.assertEqual(mock_execute_tool.call_count, 1)
|
||||
|
||||
Reference in New Issue
Block a user