diff --git a/backend/server/chat/agent_tools.py b/backend/server/chat/agent_tools.py index 76b7e319..16545ec4 100644 --- a/backend/server/chat/agent_tools.py +++ b/backend/server/chat/agent_tools.py @@ -196,6 +196,10 @@ def search_places( "category": category, "results": results, } + except requests.HTTPError as exc: + if exc.response is not None and exc.response.status_code == 429: + return {"error": f"Places API request failed: {exc}", "retryable": False} + return {"error": f"Places API request failed: {exc}"} except requests.RequestException as exc: return {"error": f"Places API request failed: {exc}"} except (TypeError, ValueError) as exc: diff --git a/backend/server/chat/tests.py b/backend/server/chat/tests.py index ab0142e5..60491476 100644 --- a/backend/server/chat/tests.py +++ b/backend/server/chat/tests.py @@ -1,6 +1,8 @@ import json from unittest import mock -from unittest.mock import patch +from unittest.mock import MagicMock, patch + +import requests as _requests from django.contrib.auth import get_user_model from django.test import TestCase @@ -11,6 +13,7 @@ from chat.agent_tools import ( add_to_itinerary, execute_tool, get_trip_details, + search_places, web_search, ) from chat.views import ChatViewSet @@ -997,3 +1000,218 @@ class ChatViewSetToolExecutionFailureLoopTests(APITransactionTestCase): for payload in json_payloads ) ) + + +class SearchPlaces429NonRetryableTests(TestCase): + """search_places must return retryable=False on HTTP 429.""" + + def test_429_response_marks_result_non_retryable(self): + mock_response = MagicMock() + mock_response.status_code = 429 + http_error = _requests.HTTPError(response=mock_response) + + with patch("chat.agent_tools.requests.get", side_effect=http_error): + result = search_places( + user=None, + location="Paris, France", + ) + + self.assertIn("error", result) + self.assertFalse( + result.get("retryable", True), + "429 error must set retryable=False to prevent retry spiral", + ) + + def test_non_429_http_error_is_retryable_by_default(self): + mock_response = MagicMock() + mock_response.status_code = 500 + http_error = _requests.HTTPError(response=mock_response) + + with patch("chat.agent_tools.requests.get", side_effect=http_error): + result = search_places( + user=None, + location="Paris, France", + ) + + self.assertIn("error", result) + self.assertTrue( + result.get("retryable", True), + "Non-429 HTTP errors should remain retryable (default=True)", + ) + + def test_generic_request_exception_is_retryable_by_default(self): + conn_error = _requests.ConnectionError("timeout") + + with patch("chat.agent_tools.requests.get", side_effect=conn_error): + result = search_places( + user=None, + location="Paris, France", + ) + + self.assertIn("error", result) + self.assertTrue( + result.get("retryable", True), + "Generic RequestException should remain retryable", + ) + + +class GetWeatherCoordFallbackTests(APITransactionTestCase): + """get_weather lat/lng required param should be retried with collection location coords.""" + + @patch("chat.views.execute_tool") + @patch("chat.views.stream_chat_completion") + @patch("integrations.utils.auto_profile.update_auto_preference_profile") + def test_get_weather_retries_with_collection_coordinates( + self, + _mock_auto_profile, + mock_stream_chat_completion, + mock_execute_tool, + ): + user = User.objects.create_user( + username="weather-coord-user", + email="weather-coord-user@example.com", + password="password123", + ) + self.client.force_authenticate(user=user) + + collection = Collection.objects.create( + user_id=user.id, + name="Paris Trip", + ) + paris_location = Location.objects.create( + user_id=user.id, + name="Paris", + latitude=48.8566, + longitude=2.3522, + ) + collection.locations.add(paris_location) + + conversation_response = self.client.post( + "/api/chat/conversations/", + {"title": "Weather Coord Fallback Test"}, + format="json", + ) + self.assertEqual(conversation_response.status_code, 201) + conversation_id = conversation_response.json()["id"] + + async def weather_stream(*args, **kwargs): + # LLM calls get_weather without coordinates + yield 'data: {"tool_calls": [{"index": 0, "id": "call_w1", "type": "function", "function": {"name": "get_weather", "arguments": "{}"}}]}\n\n' + yield "data: [DONE]\n\n" + + async def success_stream(*args, **kwargs): + yield 'data: {"content": "The weather in Paris is sunny."}\n\n' + yield "data: [DONE]\n\n" + + mock_stream_chat_completion.side_effect = [weather_stream(), success_stream()] + mock_execute_tool.side_effect = [ + # First call: no lat/lon + {"error": "latitude and longitude are required"}, + # Retry call: with injected coords from collection — succeeds + {"temperature": 22, "condition": "sunny", "location": "Paris"}, + ] + + response = self.client.post( + f"/api/chat/conversations/{conversation_id}/send_message/", + { + "message": "What's the weather like?", + "collection_id": str(collection.id), + }, + format="json", + ) + + self.assertEqual(response.status_code, 200) + chunks = [ + chunk.decode("utf-8") + if isinstance(chunk, (bytes, bytearray)) + else str(chunk) + for chunk in response.streaming_content + ] + payload_lines = [ + chunk.strip()[len("data: ") :] + for chunk in chunks + if chunk.strip().startswith("data: ") + ] + json_payloads = [json.loads(p) for p in payload_lines if p != "[DONE]"] + + # Verify the retry happened with coordinates (execute_tool called twice) + self.assertEqual( + mock_execute_tool.call_count, + 2, + "Expected exactly 2 execute_tool calls: initial + coord retry", + ) + # Verify no tool_execution_error surfaced to the user + self.assertFalse( + any( + payload.get("error_category") == "tool_execution_error" + for payload in json_payloads + ), + "Should not emit tool_execution_error when coord retry succeeds", + ) + + # Verify coordinates were passed in the retry call + retry_kwargs = mock_execute_tool.call_args_list[1][1] + self.assertAlmostEqual(retry_kwargs.get("latitude"), 48.8566, places=3) + self.assertAlmostEqual(retry_kwargs.get("longitude"), 2.3522, places=3) + + @patch("chat.views.execute_tool") + @patch("chat.views.stream_chat_completion") + @patch("integrations.utils.auto_profile.update_auto_preference_profile") + def test_get_weather_missing_coords_no_collection_emits_error( + self, + _mock_auto_profile, + mock_stream_chat_completion, + mock_execute_tool, + ): + user = User.objects.create_user( + username="weather-no-collection-user", + email="weather-no-collection-user@example.com", + password="password123", + ) + self.client.force_authenticate(user=user) + + conversation_response = self.client.post( + "/api/chat/conversations/", + {"title": "Weather No Collection Test"}, + format="json", + ) + self.assertEqual(conversation_response.status_code, 201) + conversation_id = conversation_response.json()["id"] + + async def weather_stream(*args, **kwargs): + yield 'data: {"tool_calls": [{"index": 0, "id": "call_w2", "type": "function", "function": {"name": "get_weather", "arguments": "{}"}}]}\n\n' + yield "data: [DONE]\n\n" + + mock_stream_chat_completion.side_effect = weather_stream + mock_execute_tool.return_value = { + "error": "latitude and longitude are required" + } + + response = self.client.post( + f"/api/chat/conversations/{conversation_id}/send_message/", + {"message": "What's the weather like?"}, + format="json", + ) + + self.assertEqual(response.status_code, 200) + chunks = [ + chunk.decode("utf-8") + if isinstance(chunk, (bytes, bytearray)) + else str(chunk) + for chunk in response.streaming_content + ] + payload_lines = [ + chunk.strip()[len("data: ") :] + for chunk in chunks + if chunk.strip().startswith("data: ") + ] + json_payloads = [json.loads(p) for p in payload_lines if p != "[DONE]"] + + # No collection means no coord fallback — should emit tool_validation_error + self.assertTrue( + any( + payload.get("error_category") == "tool_validation_error" + for payload in json_payloads + ), + "Should emit tool_validation_error when no collection coords available", + ) diff --git a/backend/server/chat/views/__init__.py b/backend/server/chat/views/__init__.py index bea0627a..a8cfaf0c 100644 --- a/backend/server/chat/views/__init__.py +++ b/backend/server/chat/views/__init__.py @@ -248,6 +248,34 @@ class ChatViewSet(viewsets.ModelViewSet): result, ) or cls._is_search_places_geocode_error(tool_name, result) + @classmethod + def _is_get_weather_missing_latlong_error(cls, tool_name, result): + """True when get_weather was called without latitude/longitude.""" + if tool_name != "get_weather" or not cls._is_required_param_tool_error(result): + return False + + error_text = (result or {}).get("error") if isinstance(result, dict) else "" + if not isinstance(error_text, str): + return False + + normalized_error = error_text.strip().lower() + return "latitude" in normalized_error or "longitude" in normalized_error + + @staticmethod + def _extract_collection_coordinates(collection): + """Return (lat, lon) from the first geocoded location in the collection, or None.""" + if collection is None: + return None + for location in collection.locations.all(): + lat = getattr(location, "latitude", None) + lon = getattr(location, "longitude", None) + if lat is not None and lon is not None: + try: + return float(lat), float(lon) + except (TypeError, ValueError): + continue + return None + @staticmethod def _build_search_places_location_clarification_message(): return ( @@ -703,6 +731,54 @@ class ChatViewSet(viewsets.ModelViewSet): "error": "Could not search places at the provided itinerary locations" } + attempted_weather_coord_retry = False + if self._is_get_weather_missing_latlong_error( + function_name, result + ): + coords = await sync_to_async( + self._extract_collection_coordinates, + thread_sensitive=True, + )(collection) + if coords is not None: + retry_lat, retry_lon = coords + retry_arguments = dict(prepared_arguments) + retry_arguments["latitude"] = retry_lat + retry_arguments["longitude"] = retry_lon + attempted_weather_coord_retry = True + retry_result = await sync_to_async( + execute_tool, + thread_sensitive=True, + )( + function_name, + request.user, + **retry_arguments, + ) + if not self._is_required_param_tool_error( + retry_result + ) and not self._is_execution_failure_tool_error( + retry_result + ): + result = retry_result + tool_call_for_history = { + **tool_call, + "function": { + **function_payload, + "name": function_name, + "arguments": json.dumps(retry_arguments), + }, + } + + # If retry was attempted but still failed, convert to an + # execution failure — never ask the user for coordinates + # they implied via collection context. + if ( + attempted_weather_coord_retry + and self._is_required_param_tool_error(result) + ): + result = { + "error": "Could not fetch weather for the collection locations" + } + if self._is_required_param_tool_error(result): assistant_message_kwargs = { "conversation": conversation, diff --git a/frontend/src/lib/components/AITravelChat.svelte b/frontend/src/lib/components/AITravelChat.svelte index 00808d5d..d4f36a97 100644 --- a/frontend/src/lib/components/AITravelChat.svelte +++ b/frontend/src/lib/components/AITravelChat.svelte @@ -348,7 +348,9 @@ return [...next, toolResult]; } - function uniqueToolResultsByCallId(toolResults: ToolResultEntry[] | undefined): ToolResultEntry[] { + function uniqueToolResultsByCallId( + toolResults: ToolResultEntry[] | undefined + ): ToolResultEntry[] { if (!toolResults) { return []; } @@ -368,6 +370,24 @@ return unique; } + // Context-loading tools that should render at most once per message, even if + // the retry loop caused the LLM to call them multiple times. + const CONTEXT_ONLY_TOOLS = new Set(['get_trip_details', 'get_weather']); + + function deduplicateContextTools(toolResults: ToolResultEntry[]): ToolResultEntry[] { + const seenContextTool = new Set(); + return toolResults.filter((result) => { + const name = result.name; + if (name && CONTEXT_ONLY_TOOLS.has(name)) { + if (seenContextTool.has(name)) { + return false; + } + seenContextTool.add(name); + } + return true; + }); + } + function rebuildConversationMessages(rawMessages: ChatMessage[]): ChatMessage[] { const rebuilt = rawMessages.map((msg) => ({ ...msg, @@ -936,7 +956,7 @@
{msg.content}
{#if msg.role === 'assistant' && msg.tool_results}
- {#each uniqueToolResultsByCallId(msg.tool_results) as result} + {#each deduplicateContextTools(uniqueToolResultsByCallId(msg.tool_results)) as result} {#if hasPlaceResults(result)}
{#each getPlaceResults(result) as place}