fix(chat): use itinerary context for restaurant searches
This commit is contained in:
@@ -67,7 +67,7 @@ Run in this order:
|
||||
- Security: handle CSRF tokens via `/auth/csrf/` and `X-CSRFToken`
|
||||
- Chat providers: dynamic catalog from `GET /api/chat/providers/`; configured in `CHAT_PROVIDER_CONFIG`
|
||||
- Chat model override: dropdown selector fed by `GET /api/chat/providers/{provider}/models/`; persisted in `localStorage` key `voyage_chat_model_prefs`; backend accepts optional `model` param in `send_message`
|
||||
- Chat context: collection chats inject collection UUID + multi-stop itinerary context; system prompt guides `get_trip_details`-first reasoning and confirms only before first `add_to_itinerary`; `search_places` prompt guard requires the LLM to have a concrete location string before calling the tool (asks clarifying question otherwise)
|
||||
- Chat context: collection chats inject collection UUID + multi-stop itinerary context; system prompt guides `get_trip_details`-first reasoning and confirms only before first `add_to_itinerary`; `search_places` has a deterministic context-retry fallback — when the LLM omits `location`, the backend retries using the trip destination or first itinerary stop before asking the user for clarification; a dining-intent heuristic infers `category="food"` from user messages when the LLM omits category for restaurant/dining requests
|
||||
- Chat tool output: `role=tool` messages hidden from display; tool outputs render as concise summaries; persisted tool rows reconstructed on reload via `rebuildConversationMessages()`
|
||||
- Chat error surfacing: `_safe_error_payload()` maps LiteLLM exceptions to sanitized user-safe categories (never forwards raw `exc.message`)
|
||||
- Invalid tool calls (missing required args) are detected and short-circuited with a user-visible error — not replayed into history
|
||||
|
||||
@@ -102,7 +102,8 @@ def _parse_address(tags):
|
||||
description=(
|
||||
"Search for places of interest near a location. "
|
||||
"Required: provide a non-empty 'location' string (city, neighborhood, or address). "
|
||||
"Returns tourist attractions, restaurants, hotels, etc."
|
||||
"Use category='food' for restaurants/dining, category='tourism' for attractions, "
|
||||
"and category='lodging' for hotels/stays."
|
||||
),
|
||||
parameters={
|
||||
"location": {
|
||||
@@ -113,7 +114,7 @@ def _parse_address(tags):
|
||||
"category": {
|
||||
"type": "string",
|
||||
"enum": ["tourism", "food", "lodging"],
|
||||
"description": "Category of places",
|
||||
"description": "Place type: food (restaurants/dining), tourism (attractions), lodging (hotels/stays)",
|
||||
},
|
||||
"radius": {
|
||||
"type": "number",
|
||||
|
||||
@@ -5,7 +5,7 @@ from django.contrib.auth import get_user_model
|
||||
from django.test import TestCase
|
||||
from rest_framework.test import APITransactionTestCase
|
||||
|
||||
from adventures.models import Collection, CollectionItineraryItem
|
||||
from adventures.models import Collection, CollectionItineraryItem, Location
|
||||
from chat.agent_tools import add_to_itinerary, get_trip_details
|
||||
from chat.views import ChatViewSet
|
||||
|
||||
@@ -177,6 +177,32 @@ class ChatViewSetToolValidationBoundaryTests(TestCase):
|
||||
)
|
||||
)
|
||||
|
||||
def test_infer_search_places_category_detects_restaurant_intent(self):
|
||||
self.assertEqual(
|
||||
ChatViewSet._infer_search_places_category(
|
||||
"Can you find restaurants for dinner?",
|
||||
[],
|
||||
),
|
||||
"food",
|
||||
)
|
||||
|
||||
def test_infer_search_places_category_detects_prior_message_dining_intent(self):
|
||||
self.assertEqual(
|
||||
ChatViewSet._infer_search_places_category(
|
||||
"What are the top picks?",
|
||||
["We want good cafes nearby"],
|
||||
),
|
||||
"food",
|
||||
)
|
||||
|
||||
def test_infer_search_places_category_leaves_non_dining_intent_unchanged(self):
|
||||
self.assertIsNone(
|
||||
ChatViewSet._infer_search_places_category(
|
||||
"Find great museums and landmarks",
|
||||
["We love cultural attractions"],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class ChatViewSetSearchPlacesClarificationTests(APITransactionTestCase):
|
||||
@patch("chat.views.execute_tool")
|
||||
@@ -347,3 +373,97 @@ class ChatViewSetSearchPlacesClarificationTests(APITransactionTestCase):
|
||||
for message in assistant_messages
|
||||
)
|
||||
)
|
||||
|
||||
@patch("chat.views.execute_tool")
|
||||
@patch("chat.views.stream_chat_completion")
|
||||
@patch("integrations.utils.auto_profile.update_auto_preference_profile")
|
||||
def test_collection_context_retry_uses_destination_and_food_category_for_restaurants(
|
||||
self,
|
||||
_mock_auto_profile,
|
||||
mock_stream_chat_completion,
|
||||
mock_execute_tool,
|
||||
):
|
||||
user = User.objects.create_user(
|
||||
username="chat-context-retry-user",
|
||||
email="chat-context-retry-user@example.com",
|
||||
password="password123",
|
||||
)
|
||||
self.client.force_authenticate(user=user)
|
||||
|
||||
collection = Collection.objects.create(user=user, name="Rome Food Trip")
|
||||
trip_stop = Location.objects.create(
|
||||
user=user,
|
||||
name="Trevi Fountain",
|
||||
location="Rome, Italy",
|
||||
)
|
||||
collection.locations.add(trip_stop)
|
||||
|
||||
conversation_response = self.client.post(
|
||||
"/api/chat/conversations/",
|
||||
{"title": "Collection Context Retry Test"},
|
||||
format="json",
|
||||
)
|
||||
self.assertEqual(conversation_response.status_code, 201)
|
||||
conversation_id = conversation_response.json()["id"]
|
||||
|
||||
async def first_stream(*args, **kwargs):
|
||||
yield 'data: {"tool_calls": [{"index": 0, "id": "call_1", "type": "function", "function": {"name": "search_places", "arguments": "{}"}}]}\n\n'
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def second_stream(*args, **kwargs):
|
||||
yield 'data: {"content": "Here are restaurant options in Rome."}\n\n'
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
mock_stream_chat_completion.side_effect = [first_stream(), second_stream()]
|
||||
mock_execute_tool.side_effect = [
|
||||
{"error": "location is required"},
|
||||
{"results": [{"name": "Roscioli"}]},
|
||||
]
|
||||
|
||||
response = self.client.post(
|
||||
f"/api/chat/conversations/{conversation_id}/send_message/",
|
||||
{
|
||||
"message": "Find great restaurants for dinner",
|
||||
"collection_id": str(collection.id),
|
||||
"collection_name": collection.name,
|
||||
"destination": "Rome, Italy",
|
||||
},
|
||||
format="json",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
chunks = [
|
||||
chunk.decode("utf-8")
|
||||
if isinstance(chunk, (bytes, bytearray))
|
||||
else str(chunk)
|
||||
for chunk in response.streaming_content
|
||||
]
|
||||
payload_lines = [
|
||||
chunk.strip()[len("data: ") :]
|
||||
for chunk in chunks
|
||||
if chunk.strip().startswith("data: ")
|
||||
]
|
||||
json_payloads = [
|
||||
json.loads(payload) for payload in payload_lines if payload != "[DONE]"
|
||||
]
|
||||
|
||||
self.assertTrue(any("tool_result" in payload for payload in json_payloads))
|
||||
self.assertTrue(
|
||||
any(
|
||||
payload.get("content") == "Here are restaurant options in Rome."
|
||||
for payload in json_payloads
|
||||
)
|
||||
)
|
||||
self.assertFalse(
|
||||
any(
|
||||
"specific location" in payload.get("content", "").lower()
|
||||
for payload in json_payloads
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(mock_execute_tool.call_count, 2)
|
||||
first_call_kwargs = mock_execute_tool.call_args_list[0].kwargs
|
||||
second_call_kwargs = mock_execute_tool.call_args_list[1].kwargs
|
||||
self.assertEqual(first_call_kwargs.get("category"), "food")
|
||||
self.assertEqual(second_call_kwargs.get("category"), "food")
|
||||
self.assertEqual(second_call_kwargs.get("location"), "Rome, Italy")
|
||||
|
||||
@@ -198,6 +198,38 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
"activities, or lodging."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _trip_context_search_location(destination, itinerary_stops):
|
||||
destination_text = (destination or "").strip()
|
||||
if destination_text:
|
||||
return destination_text
|
||||
|
||||
for stop in itinerary_stops or []:
|
||||
stop_text = (stop or "").strip()
|
||||
if stop_text:
|
||||
return stop_text
|
||||
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _infer_search_places_category(user_content, prior_user_messages):
|
||||
message_parts = [(user_content or "").strip()]
|
||||
message_parts.extend(
|
||||
(content or "").strip() for content in prior_user_messages or []
|
||||
)
|
||||
normalized = " ".join(part for part in message_parts if part).lower()
|
||||
if not normalized:
|
||||
return None
|
||||
|
||||
dining_intent_pattern = (
|
||||
r"\b(restaurant|restaurants|dining|dinner|lunch|breakfast|brunch|"
|
||||
r"cafe|cafes|food|eat|eating|cuisine|meal|meals|bistro|bar|bars)\b"
|
||||
)
|
||||
if re.search(dining_intent_pattern, normalized):
|
||||
return "food"
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _is_likely_location_reply(user_content):
|
||||
if not isinstance(user_content, str):
|
||||
@@ -274,6 +306,7 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
)
|
||||
|
||||
context_parts = []
|
||||
itinerary_stops = []
|
||||
if collection_name:
|
||||
context_parts.append(f"Trip: {collection_name}")
|
||||
if destination:
|
||||
@@ -300,7 +333,6 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
"Collection UUID (use this exact collection_id for get_trip_details and add_to_itinerary): "
|
||||
f"{collection.id}"
|
||||
)
|
||||
itinerary_stops = []
|
||||
seen_stops = set()
|
||||
for location in collection.locations.select_related(
|
||||
"city", "country"
|
||||
@@ -334,6 +366,15 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
if itinerary_stops:
|
||||
context_parts.append(f"Itinerary stops: {'; '.join(itinerary_stops)}")
|
||||
|
||||
trip_context_location = self._trip_context_search_location(
|
||||
destination, itinerary_stops
|
||||
)
|
||||
prior_user_messages = list(
|
||||
conversation.messages.filter(role="user")
|
||||
.order_by("-created_at")
|
||||
.values_list("content", flat=True)[:3]
|
||||
)
|
||||
|
||||
system_prompt = get_system_prompt(request.user, collection)
|
||||
if context_parts:
|
||||
system_prompt += "\n\n## Trip Context\n" + "\n".join(context_parts)
|
||||
@@ -425,42 +466,81 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
if not isinstance(arguments, dict):
|
||||
arguments = {}
|
||||
|
||||
result = await sync_to_async(
|
||||
execute_tool, thread_sensitive=True
|
||||
)(
|
||||
function_name,
|
||||
request.user,
|
||||
**arguments,
|
||||
)
|
||||
|
||||
prepared_arguments = dict(arguments)
|
||||
tool_call_for_history = tool_call
|
||||
|
||||
if self._is_search_places_missing_location_required_error(
|
||||
function_name,
|
||||
result,
|
||||
) and self._is_likely_location_reply(user_content):
|
||||
retry_arguments = dict(arguments)
|
||||
retry_arguments["location"] = user_content
|
||||
retry_result = await sync_to_async(
|
||||
execute_tool,
|
||||
thread_sensitive=True,
|
||||
)(
|
||||
function_name,
|
||||
request.user,
|
||||
**retry_arguments,
|
||||
)
|
||||
if function_name == "search_places":
|
||||
if not (prepared_arguments.get("category") or "").strip():
|
||||
inferred_category = self._infer_search_places_category(
|
||||
user_content,
|
||||
prior_user_messages,
|
||||
)
|
||||
if inferred_category:
|
||||
prepared_arguments["category"] = inferred_category
|
||||
|
||||
if not self._is_required_param_tool_error(retry_result):
|
||||
result = retry_result
|
||||
if prepared_arguments != arguments:
|
||||
tool_call_for_history = {
|
||||
**tool_call,
|
||||
"function": {
|
||||
**function_payload,
|
||||
"name": function_name,
|
||||
"arguments": json.dumps(retry_arguments),
|
||||
"arguments": json.dumps(prepared_arguments),
|
||||
},
|
||||
}
|
||||
|
||||
result = await sync_to_async(
|
||||
execute_tool, thread_sensitive=True
|
||||
)(
|
||||
function_name,
|
||||
request.user,
|
||||
**prepared_arguments,
|
||||
)
|
||||
|
||||
if self._is_search_places_missing_location_required_error(
|
||||
function_name,
|
||||
result,
|
||||
):
|
||||
retry_locations = []
|
||||
if trip_context_location:
|
||||
retry_locations.append(trip_context_location)
|
||||
if self._is_likely_location_reply(user_content):
|
||||
retry_locations.append(user_content)
|
||||
|
||||
seen_retry_locations = set()
|
||||
for retry_location in retry_locations:
|
||||
normalized_retry_location = (
|
||||
retry_location.strip().lower()
|
||||
)
|
||||
if (
|
||||
not normalized_retry_location
|
||||
or normalized_retry_location in seen_retry_locations
|
||||
):
|
||||
continue
|
||||
seen_retry_locations.add(normalized_retry_location)
|
||||
|
||||
retry_arguments = dict(prepared_arguments)
|
||||
retry_arguments["location"] = retry_location
|
||||
retry_result = await sync_to_async(
|
||||
execute_tool,
|
||||
thread_sensitive=True,
|
||||
)(
|
||||
function_name,
|
||||
request.user,
|
||||
**retry_arguments,
|
||||
)
|
||||
|
||||
if not self._is_required_param_tool_error(retry_result):
|
||||
result = retry_result
|
||||
tool_call_for_history = {
|
||||
**tool_call,
|
||||
"function": {
|
||||
**function_payload,
|
||||
"name": function_name,
|
||||
"arguments": json.dumps(retry_arguments),
|
||||
},
|
||||
}
|
||||
break
|
||||
|
||||
if self._is_required_param_tool_error(result):
|
||||
assistant_message_kwargs = {
|
||||
"conversation": conversation,
|
||||
|
||||
Reference in New Issue
Block a user