fix(chat): normalize itinerary search locations
This commit is contained in:
@@ -120,6 +120,15 @@ class ChatAgentToolSharedTripAccessTests(TestCase):
|
||||
|
||||
|
||||
class ChatViewSetToolValidationBoundaryTests(TestCase):
|
||||
def test_trip_context_destination_summary_normalizes_to_first_segment(self):
|
||||
self.assertEqual(
|
||||
ChatViewSet._trip_context_search_location(
|
||||
"A; B; +1 more",
|
||||
["Fallback City"],
|
||||
),
|
||||
"A",
|
||||
)
|
||||
|
||||
def test_dates_is_required_matches_required_param_short_circuit(self):
|
||||
self.assertTrue(
|
||||
ChatViewSet._is_required_param_tool_error({"error": "dates is required"})
|
||||
@@ -467,3 +476,84 @@ class ChatViewSetSearchPlacesClarificationTests(APITransactionTestCase):
|
||||
self.assertEqual(first_call_kwargs.get("category"), "food")
|
||||
self.assertEqual(second_call_kwargs.get("category"), "food")
|
||||
self.assertEqual(second_call_kwargs.get("location"), "Rome, Italy")
|
||||
|
||||
@patch("chat.views.execute_tool")
|
||||
@patch("chat.views.stream_chat_completion")
|
||||
@patch("integrations.utils.auto_profile.update_auto_preference_profile")
|
||||
def test_trip_context_retry_uses_normalized_summary_destination_for_search_places(
|
||||
self,
|
||||
_mock_auto_profile,
|
||||
mock_stream_chat_completion,
|
||||
mock_execute_tool,
|
||||
):
|
||||
user = User.objects.create_user(
|
||||
username="chat-summary-retry-user",
|
||||
email="chat-summary-retry-user@example.com",
|
||||
password="password123",
|
||||
)
|
||||
self.client.force_authenticate(user=user)
|
||||
|
||||
conversation_response = self.client.post(
|
||||
"/api/chat/conversations/",
|
||||
{"title": "Summary Destination Retry Test"},
|
||||
format="json",
|
||||
)
|
||||
self.assertEqual(conversation_response.status_code, 201)
|
||||
conversation_id = conversation_response.json()["id"]
|
||||
|
||||
async def first_stream(*args, **kwargs):
|
||||
yield 'data: {"tool_calls": [{"index": 0, "id": "call_1", "type": "function", "function": {"name": "search_places", "arguments": "{}"}}]}\n\n'
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def second_stream(*args, **kwargs):
|
||||
yield 'data: {"content": "I found places in Paris."}\n\n'
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
mock_stream_chat_completion.side_effect = [first_stream(), second_stream()]
|
||||
mock_execute_tool.side_effect = [
|
||||
{"error": "location is required"},
|
||||
{"results": [{"name": "Louvre Museum"}]},
|
||||
]
|
||||
|
||||
response = self.client.post(
|
||||
f"/api/chat/conversations/{conversation_id}/send_message/",
|
||||
{
|
||||
"message": "Find top attractions",
|
||||
"destination": "Paris, France; Rome, Italy; +1 more",
|
||||
},
|
||||
format="json",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
chunks = [
|
||||
chunk.decode("utf-8")
|
||||
if isinstance(chunk, (bytes, bytearray))
|
||||
else str(chunk)
|
||||
for chunk in response.streaming_content
|
||||
]
|
||||
payload_lines = [
|
||||
chunk.strip()[len("data: ") :]
|
||||
for chunk in chunks
|
||||
if chunk.strip().startswith("data: ")
|
||||
]
|
||||
json_payloads = [
|
||||
json.loads(payload) for payload in payload_lines if payload != "[DONE]"
|
||||
]
|
||||
|
||||
self.assertTrue(any("tool_result" in payload for payload in json_payloads))
|
||||
self.assertTrue(
|
||||
any(
|
||||
payload.get("content") == "I found places in Paris."
|
||||
for payload in json_payloads
|
||||
)
|
||||
)
|
||||
self.assertFalse(
|
||||
any(
|
||||
"specific location" in payload.get("content", "").lower()
|
||||
for payload in json_payloads
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(mock_execute_tool.call_count, 2)
|
||||
second_call_kwargs = mock_execute_tool.call_args_list[1].kwargs
|
||||
self.assertEqual(second_call_kwargs.get("location"), "Paris, France")
|
||||
|
||||
@@ -199,8 +199,31 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _trip_context_search_location(destination, itinerary_stops):
|
||||
def _normalize_trip_context_destination(destination):
|
||||
destination_text = (destination or "").strip()
|
||||
if not destination_text:
|
||||
return ""
|
||||
|
||||
if ";" not in destination_text:
|
||||
if re.fullmatch(r"\+\d+\s+more", destination_text, re.IGNORECASE):
|
||||
return ""
|
||||
return destination_text
|
||||
|
||||
for segment in destination_text.split(";"):
|
||||
segment_text = segment.strip()
|
||||
if not segment_text:
|
||||
continue
|
||||
|
||||
if re.fullmatch(r"\+\d+\s+more", segment_text, re.IGNORECASE):
|
||||
continue
|
||||
|
||||
return segment_text
|
||||
|
||||
return ""
|
||||
|
||||
@classmethod
|
||||
def _trip_context_search_location(cls, destination, itinerary_stops):
|
||||
destination_text = cls._normalize_trip_context_destination(destination)
|
||||
if destination_text:
|
||||
return destination_text
|
||||
|
||||
@@ -309,8 +332,6 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
itinerary_stops = []
|
||||
if collection_name:
|
||||
context_parts.append(f"Trip: {collection_name}")
|
||||
if destination:
|
||||
context_parts.append(f"Destination: {destination}")
|
||||
if start_date and end_date:
|
||||
context_parts.append(f"Dates: {start_date} to {end_date}")
|
||||
|
||||
@@ -369,6 +390,8 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
trip_context_location = self._trip_context_search_location(
|
||||
destination, itinerary_stops
|
||||
)
|
||||
if trip_context_location:
|
||||
context_parts.append(f"Destination: {trip_context_location}")
|
||||
prior_user_messages = list(
|
||||
conversation.messages.filter(role="user")
|
||||
.order_by("-created_at")
|
||||
|
||||
Reference in New Issue
Block a user