fix(chat): use itinerary context for restaurant searches

This commit is contained in:
2026-03-10 17:12:29 +00:00
parent a023a9548c
commit 46d7704e4f
4 changed files with 231 additions and 30 deletions

View File

@@ -5,7 +5,7 @@ from django.contrib.auth import get_user_model
from django.test import TestCase
from rest_framework.test import APITransactionTestCase
from adventures.models import Collection, CollectionItineraryItem
from adventures.models import Collection, CollectionItineraryItem, Location
from chat.agent_tools import add_to_itinerary, get_trip_details
from chat.views import ChatViewSet
@@ -177,6 +177,32 @@ class ChatViewSetToolValidationBoundaryTests(TestCase):
)
)
def test_infer_search_places_category_detects_restaurant_intent(self):
self.assertEqual(
ChatViewSet._infer_search_places_category(
"Can you find restaurants for dinner?",
[],
),
"food",
)
def test_infer_search_places_category_detects_prior_message_dining_intent(self):
self.assertEqual(
ChatViewSet._infer_search_places_category(
"What are the top picks?",
["We want good cafes nearby"],
),
"food",
)
def test_infer_search_places_category_leaves_non_dining_intent_unchanged(self):
self.assertIsNone(
ChatViewSet._infer_search_places_category(
"Find great museums and landmarks",
["We love cultural attractions"],
)
)
class ChatViewSetSearchPlacesClarificationTests(APITransactionTestCase):
@patch("chat.views.execute_tool")
@@ -347,3 +373,97 @@ class ChatViewSetSearchPlacesClarificationTests(APITransactionTestCase):
for message in assistant_messages
)
)
@patch("chat.views.execute_tool")
@patch("chat.views.stream_chat_completion")
@patch("integrations.utils.auto_profile.update_auto_preference_profile")
def test_collection_context_retry_uses_destination_and_food_category_for_restaurants(
self,
_mock_auto_profile,
mock_stream_chat_completion,
mock_execute_tool,
):
user = User.objects.create_user(
username="chat-context-retry-user",
email="chat-context-retry-user@example.com",
password="password123",
)
self.client.force_authenticate(user=user)
collection = Collection.objects.create(user=user, name="Rome Food Trip")
trip_stop = Location.objects.create(
user=user,
name="Trevi Fountain",
location="Rome, Italy",
)
collection.locations.add(trip_stop)
conversation_response = self.client.post(
"/api/chat/conversations/",
{"title": "Collection Context Retry Test"},
format="json",
)
self.assertEqual(conversation_response.status_code, 201)
conversation_id = conversation_response.json()["id"]
async def first_stream(*args, **kwargs):
yield 'data: {"tool_calls": [{"index": 0, "id": "call_1", "type": "function", "function": {"name": "search_places", "arguments": "{}"}}]}\n\n'
yield "data: [DONE]\n\n"
async def second_stream(*args, **kwargs):
yield 'data: {"content": "Here are restaurant options in Rome."}\n\n'
yield "data: [DONE]\n\n"
mock_stream_chat_completion.side_effect = [first_stream(), second_stream()]
mock_execute_tool.side_effect = [
{"error": "location is required"},
{"results": [{"name": "Roscioli"}]},
]
response = self.client.post(
f"/api/chat/conversations/{conversation_id}/send_message/",
{
"message": "Find great restaurants for dinner",
"collection_id": str(collection.id),
"collection_name": collection.name,
"destination": "Rome, Italy",
},
format="json",
)
self.assertEqual(response.status_code, 200)
chunks = [
chunk.decode("utf-8")
if isinstance(chunk, (bytes, bytearray))
else str(chunk)
for chunk in response.streaming_content
]
payload_lines = [
chunk.strip()[len("data: ") :]
for chunk in chunks
if chunk.strip().startswith("data: ")
]
json_payloads = [
json.loads(payload) for payload in payload_lines if payload != "[DONE]"
]
self.assertTrue(any("tool_result" in payload for payload in json_payloads))
self.assertTrue(
any(
payload.get("content") == "Here are restaurant options in Rome."
for payload in json_payloads
)
)
self.assertFalse(
any(
"specific location" in payload.get("content", "").lower()
for payload in json_payloads
)
)
self.assertEqual(mock_execute_tool.call_count, 2)
first_call_kwargs = mock_execute_tool.call_args_list[0].kwargs
second_call_kwargs = mock_execute_tool.call_args_list[1].kwargs
self.assertEqual(first_call_kwargs.get("category"), "food")
self.assertEqual(second_call_kwargs.get("category"), "food")
self.assertEqual(second_call_kwargs.get("location"), "Rome, Italy")