fix(chat): stop 429 retry spiral and add get_weather coord fallback
- search_places: detect HTTP 429 and mark the result retryable=False so the retry loop stops immediately instead of spiraling until MAX_ITERATIONS.
- get_weather: when the LLM omits the required lat/lng params, extract the collection coordinates (from the first location that has coordinates) and retry; the DB query goes through sync_to_async because the view is async.
- AITravelChat: deduplicate context-only tools (get_trip_details, get_weather) in the render pipeline so duplicate place cards no longer appear when the retry loop causes multiple get_trip_details calls.
- Tests: 5 new tests covering the 429 non-retryable path and the weather coordinate fallback; all 39 chat tests pass.
This commit is contained in:
@@ -196,6 +196,10 @@ def search_places(
|
|||||||
"category": category,
|
"category": category,
|
||||||
"results": results,
|
"results": results,
|
||||||
}
|
}
|
||||||
|
except requests.HTTPError as exc:
|
||||||
|
if exc.response is not None and exc.response.status_code == 429:
|
||||||
|
return {"error": f"Places API request failed: {exc}", "retryable": False}
|
||||||
|
return {"error": f"Places API request failed: {exc}"}
|
||||||
except requests.RequestException as exc:
|
except requests.RequestException as exc:
|
||||||
return {"error": f"Places API request failed: {exc}"}
|
return {"error": f"Places API request failed: {exc}"}
|
||||||
except (TypeError, ValueError) as exc:
|
except (TypeError, ValueError) as exc:
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
import json
|
import json
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
from unittest.mock import patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import requests as _requests
|
||||||
|
|
||||||
from django.contrib.auth import get_user_model
|
from django.contrib.auth import get_user_model
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
@@ -11,6 +13,7 @@ from chat.agent_tools import (
|
|||||||
add_to_itinerary,
|
add_to_itinerary,
|
||||||
execute_tool,
|
execute_tool,
|
||||||
get_trip_details,
|
get_trip_details,
|
||||||
|
search_places,
|
||||||
web_search,
|
web_search,
|
||||||
)
|
)
|
||||||
from chat.views import ChatViewSet
|
from chat.views import ChatViewSet
|
||||||
@@ -997,3 +1000,218 @@ class ChatViewSetToolExecutionFailureLoopTests(APITransactionTestCase):
|
|||||||
for payload in json_payloads
|
for payload in json_payloads
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SearchPlaces429NonRetryableTests(TestCase):
    """Check the ``retryable`` flag that search_places reports on request failures.

    Each test makes the patched ``requests.get`` raise a specific exception and
    inspects the returned dict. Only an HTTP 429 is expected to carry
    ``retryable=False``; for the other failures the flag must be absent or truthy.
    """

    @staticmethod
    def _http_error(status_code):
        """Build a ``requests.HTTPError`` whose response reports *status_code*."""
        fake_response = MagicMock()
        fake_response.status_code = status_code
        return _requests.HTTPError(response=fake_response)

    @staticmethod
    def _search_paris_with_failure(failure):
        """Run search_places for Paris while ``requests.get`` raises *failure*."""
        with patch("chat.agent_tools.requests.get", side_effect=failure):
            return search_places(user=None, location="Paris, France")

    def test_429_response_marks_result_non_retryable(self):
        outcome = self._search_paris_with_failure(self._http_error(429))

        self.assertIn("error", outcome)
        # A missing flag counts as retryable, so default to True before negating.
        retryable = outcome.get("retryable", True)
        self.assertFalse(
            retryable,
            "429 error must set retryable=False to prevent retry spiral",
        )

    def test_non_429_http_error_is_retryable_by_default(self):
        outcome = self._search_paris_with_failure(self._http_error(500))

        self.assertIn("error", outcome)
        retryable = outcome.get("retryable", True)
        self.assertTrue(
            retryable,
            "Non-429 HTTP errors should remain retryable (default=True)",
        )

    def test_generic_request_exception_is_retryable_by_default(self):
        outcome = self._search_paris_with_failure(
            _requests.ConnectionError("timeout")
        )

        self.assertIn("error", outcome)
        retryable = outcome.get("retryable", True)
        self.assertTrue(
            retryable,
            "Generic RequestException should remain retryable",
        )
|
||||||
|
|
||||||
|
|
||||||
|
class GetWeatherCoordFallbackTests(APITransactionTestCase):
    """Drive send_message with a get_weather tool call that has no coordinates.

    ``execute_tool`` and ``stream_chat_completion`` are patched in ``chat.views``,
    so these tests only observe how the view reacts to the "latitude and
    longitude are required" tool error, with and without a collection attached
    to the request.
    """

    def _login_new_user(self, username):
        """Create a user named *username* and authenticate the test client as it."""
        account = User.objects.create_user(
            username=username,
            email=f"{username}@example.com",
            password="password123",
        )
        self.client.force_authenticate(user=account)
        return account

    def _open_conversation(self, title):
        """Create a conversation through the API and return its id."""
        created = self.client.post(
            "/api/chat/conversations/",
            {"title": title},
            format="json",
        )
        self.assertEqual(created.status_code, 201)
        return created.json()["id"]

    @staticmethod
    def _decode_sse_payloads(response):
        """Drain the streaming response and parse every ``data: `` JSON payload."""
        # Consume the whole stream first; the JSON parsing happens afterwards.
        text_chunks = []
        for raw in response.streaming_content:
            if isinstance(raw, (bytes, bytearray)):
                text_chunks.append(raw.decode("utf-8"))
            else:
                text_chunks.append(str(raw))

        bodies = []
        for text in text_chunks:
            stripped = text.strip()
            if stripped.startswith("data: "):
                bodies.append(stripped[len("data: "):])

        # The "[DONE]" sentinel is not JSON and is skipped.
        return [json.loads(body) for body in bodies if body != "[DONE]"]

    @patch("chat.views.execute_tool")
    @patch("chat.views.stream_chat_completion")
    @patch("integrations.utils.auto_profile.update_auto_preference_profile")
    def test_get_weather_retries_with_collection_coordinates(
        self,
        _mock_auto_profile,
        mock_stream_chat_completion,
        mock_execute_tool,
    ):
        owner = self._login_new_user("weather-coord-user")

        trip = Collection.objects.create(user_id=owner.id, name="Paris Trip")
        paris = Location.objects.create(
            user_id=owner.id,
            name="Paris",
            latitude=48.8566,
            longitude=2.3522,
        )
        trip.locations.add(paris)

        conversation_id = self._open_conversation("Weather Coord Fallback Test")

        async def weather_stream(*args, **kwargs):
            # The model asks for get_weather with an empty argument object.
            yield 'data: {"tool_calls": [{"index": 0, "id": "call_w1", "type": "function", "function": {"name": "get_weather", "arguments": "{}"}}]}\n\n'
            yield "data: [DONE]\n\n"

        async def success_stream(*args, **kwargs):
            yield 'data: {"content": "The weather in Paris is sunny."}\n\n'
            yield "data: [DONE]\n\n"

        mock_stream_chat_completion.side_effect = [weather_stream(), success_stream()]

        missing_coords_result = {"error": "latitude and longitude are required"}
        weather_result = {"temperature": 22, "condition": "sunny", "location": "Paris"}
        # First element answers the initial call, second answers the retry.
        mock_execute_tool.side_effect = [missing_coords_result, weather_result]

        response = self.client.post(
            f"/api/chat/conversations/{conversation_id}/send_message/",
            {
                "message": "What's the weather like?",
                "collection_id": str(trip.id),
            },
            format="json",
        )

        self.assertEqual(response.status_code, 200)
        json_payloads = self._decode_sse_payloads(response)

        # Two calls means the view retried once after the missing-coords error.
        self.assertEqual(
            mock_execute_tool.call_count,
            2,
            "Expected exactly 2 execute_tool calls: initial + coord retry",
        )

        surfaced_execution_error = any(
            payload.get("error_category") == "tool_execution_error"
            for payload in json_payloads
        )
        self.assertFalse(
            surfaced_execution_error,
            "Should not emit tool_execution_error when coord retry succeeds",
        )

        # The retry's keyword arguments must carry the stored location coords.
        _, retry_kwargs = mock_execute_tool.call_args_list[1]
        self.assertAlmostEqual(retry_kwargs.get("latitude"), 48.8566, places=3)
        self.assertAlmostEqual(retry_kwargs.get("longitude"), 2.3522, places=3)

    @patch("chat.views.execute_tool")
    @patch("chat.views.stream_chat_completion")
    @patch("integrations.utils.auto_profile.update_auto_preference_profile")
    def test_get_weather_missing_coords_no_collection_emits_error(
        self,
        _mock_auto_profile,
        mock_stream_chat_completion,
        mock_execute_tool,
    ):
        self._login_new_user("weather-no-collection-user")
        conversation_id = self._open_conversation("Weather No Collection Test")

        async def weather_stream(*args, **kwargs):
            yield 'data: {"tool_calls": [{"index": 0, "id": "call_w2", "type": "function", "function": {"name": "get_weather", "arguments": "{}"}}]}\n\n'
            yield "data: [DONE]\n\n"

        # Passing the function (not a generator) yields a fresh stream per call.
        mock_stream_chat_completion.side_effect = weather_stream
        mock_execute_tool.return_value = {
            "error": "latitude and longitude are required"
        }

        # No collection_id is sent, so there are no coordinates to fall back on.
        response = self.client.post(
            f"/api/chat/conversations/{conversation_id}/send_message/",
            {"message": "What's the weather like?"},
            format="json",
        )

        self.assertEqual(response.status_code, 200)
        json_payloads = self._decode_sse_payloads(response)

        surfaced_validation_error = any(
            payload.get("error_category") == "tool_validation_error"
            for payload in json_payloads
        )
        self.assertTrue(
            surfaced_validation_error,
            "Should emit tool_validation_error when no collection coords available",
        )
|
||||||
|
|||||||
@@ -248,6 +248,34 @@ class ChatViewSet(viewsets.ModelViewSet):
|
|||||||
result,
|
result,
|
||||||
) or cls._is_search_places_geocode_error(tool_name, result)
|
) or cls._is_search_places_geocode_error(tool_name, result)
|
||||||
|
|
||||||
|
@classmethod
def _is_get_weather_missing_latlong_error(cls, tool_name, result):
    """Report whether *result* is get_weather's missing-coordinates error.

    Returns True only when *tool_name* is ``"get_weather"``, the class's
    required-param check accepts *result*, and the result's ``"error"`` text
    mentions "latitude" or "longitude" (case-insensitive). Otherwise False.
    """
    if tool_name != "get_weather":
        return False
    if not cls._is_required_param_tool_error(result):
        return False

    # Non-dict results have no error text to inspect.
    message = result.get("error") if isinstance(result, dict) else ""
    if not isinstance(message, str):
        return False

    lowered = message.strip().lower()
    return any(keyword in lowered for keyword in ("latitude", "longitude"))
|
||||||
|
|
||||||
|
@staticmethod
def _extract_collection_coordinates(collection):
    """Pick coordinates from the first usable location of *collection*.

    Iterates ``collection.locations.all()`` and returns ``(lat, lon)`` as floats
    for the first location whose ``latitude`` and ``longitude`` are both present
    and convertible with ``float``. Returns None when *collection* is None or
    no location qualifies.
    """
    if collection is None:
        return None

    for place in collection.locations.all():
        raw_pair = (
            getattr(place, "latitude", None),
            getattr(place, "longitude", None),
        )
        if any(value is None for value in raw_pair):
            continue

        try:
            latitude, longitude = (float(value) for value in raw_pair)
        except (TypeError, ValueError):
            # Unconvertible values: try the next location instead of failing.
            continue
        return latitude, longitude

    return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _build_search_places_location_clarification_message():
|
def _build_search_places_location_clarification_message():
|
||||||
return (
|
return (
|
||||||
@@ -703,6 +731,54 @@ class ChatViewSet(viewsets.ModelViewSet):
|
|||||||
"error": "Could not search places at the provided itinerary locations"
|
"error": "Could not search places at the provided itinerary locations"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
attempted_weather_coord_retry = False
|
||||||
|
if self._is_get_weather_missing_latlong_error(
|
||||||
|
function_name, result
|
||||||
|
):
|
||||||
|
coords = await sync_to_async(
|
||||||
|
self._extract_collection_coordinates,
|
||||||
|
thread_sensitive=True,
|
||||||
|
)(collection)
|
||||||
|
if coords is not None:
|
||||||
|
retry_lat, retry_lon = coords
|
||||||
|
retry_arguments = dict(prepared_arguments)
|
||||||
|
retry_arguments["latitude"] = retry_lat
|
||||||
|
retry_arguments["longitude"] = retry_lon
|
||||||
|
attempted_weather_coord_retry = True
|
||||||
|
retry_result = await sync_to_async(
|
||||||
|
execute_tool,
|
||||||
|
thread_sensitive=True,
|
||||||
|
)(
|
||||||
|
function_name,
|
||||||
|
request.user,
|
||||||
|
**retry_arguments,
|
||||||
|
)
|
||||||
|
if not self._is_required_param_tool_error(
|
||||||
|
retry_result
|
||||||
|
) and not self._is_execution_failure_tool_error(
|
||||||
|
retry_result
|
||||||
|
):
|
||||||
|
result = retry_result
|
||||||
|
tool_call_for_history = {
|
||||||
|
**tool_call,
|
||||||
|
"function": {
|
||||||
|
**function_payload,
|
||||||
|
"name": function_name,
|
||||||
|
"arguments": json.dumps(retry_arguments),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# If retry was attempted but still failed, convert to an
|
||||||
|
# execution failure — never ask the user for coordinates
|
||||||
|
# they implied via collection context.
|
||||||
|
if (
|
||||||
|
attempted_weather_coord_retry
|
||||||
|
and self._is_required_param_tool_error(result)
|
||||||
|
):
|
||||||
|
result = {
|
||||||
|
"error": "Could not fetch weather for the collection locations"
|
||||||
|
}
|
||||||
|
|
||||||
if self._is_required_param_tool_error(result):
|
if self._is_required_param_tool_error(result):
|
||||||
assistant_message_kwargs = {
|
assistant_message_kwargs = {
|
||||||
"conversation": conversation,
|
"conversation": conversation,
|
||||||
|
|||||||
@@ -348,7 +348,9 @@
|
|||||||
return [...next, toolResult];
|
return [...next, toolResult];
|
||||||
}
|
}
|
||||||
|
|
||||||
function uniqueToolResultsByCallId(toolResults: ToolResultEntry[] | undefined): ToolResultEntry[] {
|
function uniqueToolResultsByCallId(
|
||||||
|
toolResults: ToolResultEntry[] | undefined
|
||||||
|
): ToolResultEntry[] {
|
||||||
if (!toolResults) {
|
if (!toolResults) {
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
@@ -368,6 +370,24 @@
|
|||||||
return unique;
|
return unique;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Tools that only load context. Each one is rendered at most once per message,
// even when the retry loop made the LLM call it several times.
const CONTEXT_ONLY_TOOLS = new Set(['get_trip_details', 'get_weather']);

/**
 * Drop repeated results of context-only tools, keeping the first occurrence of
 * each tool name. Results of any other tool, or with no name, pass through
 * untouched and in their original order.
 */
function deduplicateContextTools(toolResults: ToolResultEntry[]): ToolResultEntry[] {
	const alreadyRendered = new Set<string>();
	const kept: ToolResultEntry[] = [];

	for (const entry of toolResults) {
		const toolName = entry.name;
		const isContextTool = Boolean(toolName) && CONTEXT_ONLY_TOOLS.has(toolName);

		if (isContextTool) {
			if (alreadyRendered.has(toolName)) {
				continue;
			}
			alreadyRendered.add(toolName);
		}

		kept.push(entry);
	}

	return kept;
}
|
||||||
|
|
||||||
function rebuildConversationMessages(rawMessages: ChatMessage[]): ChatMessage[] {
|
function rebuildConversationMessages(rawMessages: ChatMessage[]): ChatMessage[] {
|
||||||
const rebuilt = rawMessages.map((msg) => ({
|
const rebuilt = rawMessages.map((msg) => ({
|
||||||
...msg,
|
...msg,
|
||||||
@@ -936,7 +956,7 @@
|
|||||||
<div class="whitespace-pre-wrap">{msg.content}</div>
|
<div class="whitespace-pre-wrap">{msg.content}</div>
|
||||||
{#if msg.role === 'assistant' && msg.tool_results}
|
{#if msg.role === 'assistant' && msg.tool_results}
|
||||||
<div class="mt-2 space-y-2">
|
<div class="mt-2 space-y-2">
|
||||||
{#each uniqueToolResultsByCallId(msg.tool_results) as result}
|
{#each deduplicateContextTools(uniqueToolResultsByCallId(msg.tool_results)) as result}
|
||||||
{#if hasPlaceResults(result)}
|
{#if hasPlaceResults(result)}
|
||||||
<div class="grid gap-2">
|
<div class="grid gap-2">
|
||||||
{#each getPlaceResults(result) as place}
|
{#each getPlaceResults(result) as place}
|
||||||
|
|||||||
Reference in New Issue
Block a user