From 46d7704e4fac57934ca8cfebd4e7f9898fc146bf Mon Sep 17 00:00:00 2001 From: alex Date: Tue, 10 Mar 2026 17:12:29 +0000 Subject: [PATCH] fix(chat): use itinerary context for restaurant searches --- AGENTS.md | 2 +- backend/server/chat/agent_tools.py | 5 +- backend/server/chat/tests.py | 122 +++++++++++++++++++++++- backend/server/chat/views/__init__.py | 132 +++++++++++++++++++++----- 4 files changed, 231 insertions(+), 30 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index f0822b25..48994c4c 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -67,7 +67,7 @@ Run in this order: - Security: handle CSRF tokens via `/auth/csrf/` and `X-CSRFToken` - Chat providers: dynamic catalog from `GET /api/chat/providers/`; configured in `CHAT_PROVIDER_CONFIG` - Chat model override: dropdown selector fed by `GET /api/chat/providers/{provider}/models/`; persisted in `localStorage` key `voyage_chat_model_prefs`; backend accepts optional `model` param in `send_message` -- Chat context: collection chats inject collection UUID + multi-stop itinerary context; system prompt guides `get_trip_details`-first reasoning and confirms only before first `add_to_itinerary`; `search_places` prompt guard requires the LLM to have a concrete location string before calling the tool (asks clarifying question otherwise) +- Chat context: collection chats inject collection UUID + multi-stop itinerary context; system prompt guides `get_trip_details`-first reasoning and confirms only before first `add_to_itinerary`; `search_places` has a deterministic context-retry fallback — when the LLM omits `location`, the backend retries using the trip destination or first itinerary stop before asking the user for clarification; a dining-intent heuristic infers `category="food"` from user messages when the LLM omits category for restaurant/dining requests - Chat tool output: `role=tool` messages hidden from display; tool outputs render as concise summaries; persisted tool rows reconstructed on reload via `rebuildConversationMessages()` - Chat error surfacing: `_safe_error_payload()` maps LiteLLM exceptions to sanitized user-safe categories (never forwards raw `exc.message`) - Invalid tool calls (missing required args) are detected and short-circuited with a user-visible error — not replayed into history diff --git a/backend/server/chat/agent_tools.py b/backend/server/chat/agent_tools.py index 964e5bf8..881d6395 100644 --- a/backend/server/chat/agent_tools.py +++ b/backend/server/chat/agent_tools.py @@ -102,7 +102,8 @@ def _parse_address(tags): description=( "Search for places of interest near a location. " "Required: provide a non-empty 'location' string (city, neighborhood, or address). " - "Returns tourist attractions, restaurants, hotels, etc." + "Use category='food' for restaurants/dining, category='tourism' for attractions, " + "and category='lodging' for hotels/stays." ), parameters={ "location": { @@ -113,7 +114,7 @@ def _parse_address(tags): "category": { "type": "string", "enum": ["tourism", "food", "lodging"], - "description": "Category of places", + "description": "Place type: food (restaurants/dining), tourism (attractions), lodging (hotels/stays)", }, "radius": { "type": "number", diff --git a/backend/server/chat/tests.py b/backend/server/chat/tests.py index ab3d47e8..f59b03ae 100644 --- a/backend/server/chat/tests.py +++ b/backend/server/chat/tests.py @@ -5,7 +5,7 @@ from django.contrib.auth import get_user_model from django.test import TestCase from rest_framework.test import APITransactionTestCase -from adventures.models import Collection, CollectionItineraryItem +from adventures.models import Collection, CollectionItineraryItem, Location from chat.agent_tools import add_to_itinerary, get_trip_details from chat.views import ChatViewSet @@ -177,6 +177,32 @@ class ChatViewSetToolValidationBoundaryTests(TestCase): ) ) + def test_infer_search_places_category_detects_restaurant_intent(self): + self.assertEqual( + ChatViewSet._infer_search_places_category( + "Can you find restaurants for dinner?", + [], + ), + "food", + ) + + def test_infer_search_places_category_detects_prior_message_dining_intent(self): + self.assertEqual( + ChatViewSet._infer_search_places_category( + "What are the top picks?", + ["We want good cafes nearby"], + ), + "food", + ) + + def test_infer_search_places_category_leaves_non_dining_intent_unchanged(self): + self.assertIsNone( + ChatViewSet._infer_search_places_category( + "Find great museums and landmarks", + ["We love cultural attractions"], + ) + ) + class ChatViewSetSearchPlacesClarificationTests(APITransactionTestCase): @patch("chat.views.execute_tool") @@ -347,3 +373,97 @@ class ChatViewSetSearchPlacesClarificationTests(APITransactionTestCase): for message in assistant_messages ) ) + + @patch("chat.views.execute_tool") + @patch("chat.views.stream_chat_completion") + @patch("integrations.utils.auto_profile.update_auto_preference_profile") + def test_collection_context_retry_uses_destination_and_food_category_for_restaurants( + self, + _mock_auto_profile, + mock_stream_chat_completion, + mock_execute_tool, + ): + user = User.objects.create_user( + username="chat-context-retry-user", + email="chat-context-retry-user@example.com", + password="password123", + ) + self.client.force_authenticate(user=user) + + collection = Collection.objects.create(user=user, name="Rome Food Trip") + trip_stop = Location.objects.create( + user=user, + name="Trevi Fountain", + location="Rome, Italy", + ) + collection.locations.add(trip_stop) + + conversation_response = self.client.post( + "/api/chat/conversations/", + {"title": "Collection Context Retry Test"}, + format="json", + ) + self.assertEqual(conversation_response.status_code, 201) + conversation_id = conversation_response.json()["id"] + + async def first_stream(*args, **kwargs): + yield 'data: {"tool_calls": [{"index": 0, "id": "call_1", "type": "function", "function": {"name": "search_places", "arguments": "{}"}}]}\n\n' + yield "data: [DONE]\n\n" + + async def second_stream(*args, **kwargs): + yield 'data: {"content": "Here are restaurant options in Rome."}\n\n' + yield "data: [DONE]\n\n" + + mock_stream_chat_completion.side_effect = [first_stream(), second_stream()] + mock_execute_tool.side_effect = [ + {"error": "location is required"}, + {"results": [{"name": "Roscioli"}]}, + ] + + response = self.client.post( + f"/api/chat/conversations/{conversation_id}/send_message/", + { + "message": "Find great restaurants for dinner", + "collection_id": str(collection.id), + "collection_name": collection.name, + "destination": "Rome, Italy", + }, + format="json", + ) + + self.assertEqual(response.status_code, 200) + chunks = [ + chunk.decode("utf-8") + if isinstance(chunk, (bytes, bytearray)) + else str(chunk) + for chunk in response.streaming_content + ] + payload_lines = [ + chunk.strip()[len("data: ") :] + for chunk in chunks + if chunk.strip().startswith("data: ") + ] + json_payloads = [ + json.loads(payload) for payload in payload_lines if payload != "[DONE]" + ] + + self.assertTrue(any("tool_result" in payload for payload in json_payloads)) + self.assertTrue( + any( + payload.get("content") == "Here are restaurant options in Rome." + for payload in json_payloads + ) + ) + self.assertFalse( + any( + "specific location" in payload.get("content", "").lower() + for payload in json_payloads + ) + ) + + self.assertEqual(mock_execute_tool.call_count, 2) + first_call_kwargs = mock_execute_tool.call_args_list[0].kwargs + second_call_kwargs = mock_execute_tool.call_args_list[1].kwargs + self.assertEqual(first_call_kwargs.get("category"), "food") + self.assertEqual(second_call_kwargs.get("category"), "food") + self.assertEqual(second_call_kwargs.get("location"), "Rome, Italy") diff --git a/backend/server/chat/views/__init__.py b/backend/server/chat/views/__init__.py index 22b7bafc..948af457 100644 --- a/backend/server/chat/views/__init__.py +++ b/backend/server/chat/views/__init__.py @@ -198,6 +198,38 @@ class ChatViewSet(viewsets.ModelViewSet): "activities, or lodging." ) + @staticmethod + def _trip_context_search_location(destination, itinerary_stops): + destination_text = (destination or "").strip() + if destination_text: + return destination_text + + for stop in itinerary_stops or []: + stop_text = (stop or "").strip() + if stop_text: + return stop_text + + return "" + + @staticmethod + def _infer_search_places_category(user_content, prior_user_messages): + message_parts = [(user_content or "").strip()] + message_parts.extend( + (content or "").strip() for content in prior_user_messages or [] + ) + normalized = " ".join(part for part in message_parts if part).lower() + if not normalized: + return None + + dining_intent_pattern = ( + r"\b(restaurant|restaurants|dining|dinner|lunch|breakfast|brunch|" + r"cafe|cafes|food|eat|eating|cuisine|meal|meals|bistro|bar|bars)\b" + ) + if re.search(dining_intent_pattern, normalized): + return "food" + + return None + @staticmethod def _is_likely_location_reply(user_content): if not isinstance(user_content, str): @@ -274,6 +306,7 @@ class ChatViewSet(viewsets.ModelViewSet): ) context_parts = [] + itinerary_stops = [] if collection_name: context_parts.append(f"Trip: {collection_name}") if destination: @@ -300,7 +333,6 @@ class ChatViewSet(viewsets.ModelViewSet): "Collection UUID (use this exact collection_id for get_trip_details and add_to_itinerary): " f"{collection.id}" ) - itinerary_stops = [] seen_stops = set() for location in collection.locations.select_related( "city", "country" @@ -334,6 +366,15 @@ class ChatViewSet(viewsets.ModelViewSet): if itinerary_stops: context_parts.append(f"Itinerary stops: {'; '.join(itinerary_stops)}") + trip_context_location = self._trip_context_search_location( + destination, itinerary_stops + ) + prior_user_messages = list( + conversation.messages.filter(role="user") + .order_by("-created_at") + .values_list("content", flat=True)[:3] + ) + system_prompt = get_system_prompt(request.user, collection) if context_parts: system_prompt += "\n\n## Trip Context\n" + "\n".join(context_parts) @@ -425,42 +466,81 @@ class ChatViewSet(viewsets.ModelViewSet): if not isinstance(arguments, dict): arguments = {} - result = await sync_to_async( - execute_tool, thread_sensitive=True - )( - function_name, - request.user, - **arguments, - ) - + prepared_arguments = dict(arguments) tool_call_for_history = tool_call - if self._is_search_places_missing_location_required_error( - function_name, - result, - ) and self._is_likely_location_reply(user_content): - retry_arguments = dict(arguments) - retry_arguments["location"] = user_content - retry_result = await sync_to_async( - execute_tool, - thread_sensitive=True, - )( - function_name, - request.user, - **retry_arguments, - ) + if function_name == "search_places": + if not (prepared_arguments.get("category") or "").strip(): + inferred_category = self._infer_search_places_category( + user_content, + prior_user_messages, + ) + if inferred_category: + prepared_arguments["category"] = inferred_category - if not self._is_required_param_tool_error(retry_result): - result = retry_result + if prepared_arguments != arguments: tool_call_for_history = { **tool_call, "function": { **function_payload, "name": function_name, - "arguments": json.dumps(retry_arguments), + "arguments": json.dumps(prepared_arguments), }, } + result = await sync_to_async( + execute_tool, thread_sensitive=True + )( + function_name, + request.user, + **prepared_arguments, + ) + + if self._is_search_places_missing_location_required_error( + function_name, + result, + ): + retry_locations = [] + if trip_context_location: + retry_locations.append(trip_context_location) + if self._is_likely_location_reply(user_content): + retry_locations.append(user_content) + + seen_retry_locations = set() + for retry_location in retry_locations: + normalized_retry_location = ( + retry_location.strip().lower() + ) + if ( + not normalized_retry_location + or normalized_retry_location in seen_retry_locations + ): + continue + seen_retry_locations.add(normalized_retry_location) + + retry_arguments = dict(prepared_arguments) + retry_arguments["location"] = retry_location + retry_result = await sync_to_async( + execute_tool, + thread_sensitive=True, + )( + function_name, + request.user, + **retry_arguments, + ) + + if not self._is_required_param_tool_error(retry_result): + result = retry_result + tool_call_for_history = { + **tool_call, + "function": { + **function_payload, + "name": function_name, + "arguments": json.dumps(retry_arguments), + }, + } + break + if self._is_required_param_tool_error(result): assistant_message_kwargs = { "conversation": conversation,