diff --git a/backend/server/chat/tests.py b/backend/server/chat/tests.py index f59b03ae..7b3aa43a 100644 --- a/backend/server/chat/tests.py +++ b/backend/server/chat/tests.py @@ -120,6 +120,15 @@ class ChatAgentToolSharedTripAccessTests(TestCase): class ChatViewSetToolValidationBoundaryTests(TestCase): + def test_trip_context_destination_summary_normalizes_to_first_segment(self): + self.assertEqual( + ChatViewSet._trip_context_search_location( + "A; B; +1 more", + ["Fallback City"], + ), + "A", + ) + def test_dates_is_required_matches_required_param_short_circuit(self): self.assertTrue( ChatViewSet._is_required_param_tool_error({"error": "dates is required"}) @@ -467,3 +476,84 @@ class ChatViewSetSearchPlacesClarificationTests(APITransactionTestCase): self.assertEqual(first_call_kwargs.get("category"), "food") self.assertEqual(second_call_kwargs.get("category"), "food") self.assertEqual(second_call_kwargs.get("location"), "Rome, Italy") + + @patch("chat.views.execute_tool") + @patch("chat.views.stream_chat_completion") + @patch("integrations.utils.auto_profile.update_auto_preference_profile") + def test_trip_context_retry_uses_normalized_summary_destination_for_search_places( + self, + _mock_auto_profile, + mock_stream_chat_completion, + mock_execute_tool, + ): + user = User.objects.create_user( + username="chat-summary-retry-user", + email="chat-summary-retry-user@example.com", + password="password123", + ) + self.client.force_authenticate(user=user) + + conversation_response = self.client.post( + "/api/chat/conversations/", + {"title": "Summary Destination Retry Test"}, + format="json", + ) + self.assertEqual(conversation_response.status_code, 201) + conversation_id = conversation_response.json()["id"] + + async def first_stream(*args, **kwargs): + yield 'data: {"tool_calls": [{"index": 0, "id": "call_1", "type": "function", "function": {"name": "search_places", "arguments": "{}"}}]}\n\n' + yield "data: [DONE]\n\n" + + async def second_stream(*args, **kwargs): + yield 'data: {"content": "I found places in Paris."}\n\n' + yield "data: [DONE]\n\n" + + mock_stream_chat_completion.side_effect = [first_stream(), second_stream()] + mock_execute_tool.side_effect = [ + {"error": "location is required"}, + {"results": [{"name": "Louvre Museum"}]}, + ] + + response = self.client.post( + f"/api/chat/conversations/{conversation_id}/send_message/", + { + "message": "Find top attractions", + "destination": "Paris, France; Rome, Italy; +1 more", + }, + format="json", + ) + + self.assertEqual(response.status_code, 200) + chunks = [ + chunk.decode("utf-8") + if isinstance(chunk, (bytes, bytearray)) + else str(chunk) + for chunk in response.streaming_content + ] + payload_lines = [ + chunk.strip()[len("data: ") :] + for chunk in chunks + if chunk.strip().startswith("data: ") + ] + json_payloads = [ + json.loads(payload) for payload in payload_lines if payload != "[DONE]" + ] + + self.assertTrue(any("tool_result" in payload for payload in json_payloads)) + self.assertTrue( + any( + payload.get("content") == "I found places in Paris." + for payload in json_payloads + ) + ) + self.assertFalse( + any( + "specific location" in payload.get("content", "").lower() + for payload in json_payloads + ) + ) + + self.assertEqual(mock_execute_tool.call_count, 2) + second_call_kwargs = mock_execute_tool.call_args_list[1].kwargs + self.assertEqual(second_call_kwargs.get("location"), "Paris, France") diff --git a/backend/server/chat/views/__init__.py b/backend/server/chat/views/__init__.py index 948af457..92186082 100644 --- a/backend/server/chat/views/__init__.py +++ b/backend/server/chat/views/__init__.py @@ -199,8 +199,31 @@ class ChatViewSet(viewsets.ModelViewSet): ) @staticmethod - def _trip_context_search_location(destination, itinerary_stops): + def _normalize_trip_context_destination(destination): destination_text = (destination or "").strip() + if not destination_text: + return "" + + if ";" not in destination_text: + if re.fullmatch(r"\+\d+\s+more", destination_text, re.IGNORECASE): + return "" + return destination_text + + for segment in destination_text.split(";"): + segment_text = segment.strip() + if not segment_text: + continue + + if re.fullmatch(r"\+\d+\s+more", segment_text, re.IGNORECASE): + continue + + return segment_text + + return "" + + @classmethod + def _trip_context_search_location(cls, destination, itinerary_stops): + destination_text = cls._normalize_trip_context_destination(destination) if destination_text: return destination_text @@ -309,8 +332,6 @@ class ChatViewSet(viewsets.ModelViewSet): itinerary_stops = [] if collection_name: context_parts.append(f"Trip: {collection_name}") - if destination: - context_parts.append(f"Destination: {destination}") if start_date and end_date: context_parts.append(f"Dates: {start_date} to {end_date}") @@ -369,6 +390,8 @@ class ChatViewSet(viewsets.ModelViewSet): trip_context_location = self._trip_context_search_location( destination, itinerary_stops ) + if trip_context_location: + context_parts.append(f"Destination: {trip_context_location}") prior_user_messages = list( conversation.messages.filter(role="user") .order_by("-created_at")