diff --git a/backend/server/chat/tests.py b/backend/server/chat/tests.py index b78e3852..ab3d47e8 100644 --- a/backend/server/chat/tests.py +++ b/backend/server/chat/tests.py @@ -164,6 +164,19 @@ class ChatViewSetToolValidationBoundaryTests(TestCase): ) ) + def test_likely_location_reply_heuristic_positive_case(self): + self.assertTrue(ChatViewSet._is_likely_location_reply("london")) + + def test_likely_location_reply_heuristic_negative_question(self): + self.assertFalse(ChatViewSet._is_likely_location_reply("where should I go?")) + + def test_likely_location_reply_heuristic_negative_long_sentence(self): + self.assertFalse( + ChatViewSet._is_likely_location_reply( + "I am not sure what city yet, maybe something with beaches and nice museums" + ) + ) + class ChatViewSetSearchPlacesClarificationTests(APITransactionTestCase): @patch("chat.views.execute_tool") @@ -242,3 +255,95 @@ class ChatViewSetSearchPlacesClarificationTests(APITransactionTestCase): ) self.assertIsNotNone(clarifying_message) self.assertIn("specific location", clarifying_message.content.lower()) + + @patch("chat.views.execute_tool") + @patch("chat.views.stream_chat_completion") + @patch("integrations.utils.auto_profile.update_auto_preference_profile") + def test_missing_location_retry_uses_user_reply_and_avoids_clarification_loop( + self, + _mock_auto_profile, + mock_stream_chat_completion, + mock_execute_tool, + ): + user = User.objects.create_user( + username="chat-location-retry-user", + email="chat-location-retry-user@example.com", + password="password123", + ) + self.client.force_authenticate(user=user) + + conversation_response = self.client.post( + "/api/chat/conversations/", + {"title": "Location Retry Test"}, + format="json", + ) + self.assertEqual(conversation_response.status_code, 201) + conversation_id = conversation_response.json()["id"] + + async def first_stream(*args, **kwargs): + yield 'data: {"tool_calls": [{"index": 0, "id": "call_1", "type": "function", "function": {"name": "search_places", "arguments": "{}"}}]}\n\n' + yield "data: [DONE]\n\n" + + async def second_stream(*args, **kwargs): + yield 'data: {"content": "Great, here are top spots in London."}\n\n' + yield "data: [DONE]\n\n" + + mock_stream_chat_completion.side_effect = [first_stream(), second_stream()] + mock_execute_tool.side_effect = [ + {"error": "location is required"}, + {"results": [{"name": "British Museum"}]}, + ] + + response = self.client.post( + f"/api/chat/conversations/{conversation_id}/send_message/", + {"message": "london"}, + format="json", + ) + + self.assertEqual(response.status_code, 200) + chunks = [ + chunk.decode("utf-8") + if isinstance(chunk, (bytes, bytearray)) + else str(chunk) + for chunk in response.streaming_content + ] + payload_lines = [ + chunk.strip()[len("data: ") :] + for chunk in chunks + if chunk.strip().startswith("data: ") + ] + json_payloads = [ + json.loads(payload) for payload in payload_lines if payload != "[DONE]" + ] + + self.assertTrue(any("tool_result" in payload for payload in json_payloads)) + self.assertTrue( + any( + payload.get("content") == "Great, here are top spots in London." + for payload in json_payloads + ) + ) + self.assertFalse( + any( + "specific location" in payload.get("content", "").lower() + for payload in json_payloads + ) + ) + self.assertFalse( + any(payload.get("error_category") for payload in json_payloads) + ) + + self.assertEqual(mock_execute_tool.call_count, 2) + self.assertEqual( + mock_execute_tool.call_args_list[1].kwargs.get("location"), "london" + ) + + assistant_messages = user.chat_conversations.get( + id=conversation_id + ).messages.filter(role="assistant") + self.assertFalse( + any( + "specific location" in message.content.lower() + for message in assistant_messages + ) + ) diff --git a/backend/server/chat/views/__init__.py b/backend/server/chat/views/__init__.py index 07a23791..22b7bafc 100644 --- a/backend/server/chat/views/__init__.py +++ b/backend/server/chat/views/__init__.py @@ -198,6 +198,27 @@ class ChatViewSet(viewsets.ModelViewSet): "activities, or lodging." ) + @staticmethod + def _is_likely_location_reply(user_content): + if not isinstance(user_content, str): + return False + + normalized = user_content.strip() + if not normalized: + return False + + if normalized.endswith("?"): + return False + + if len(normalized) > 80: + return False + + parts = normalized.split() + if len(parts) > 6: + return False + + return bool(re.search(r"[a-z0-9]", normalized, re.IGNORECASE)) + @action(detail=True, methods=["post"]) def send_message(self, request, pk=None): # Auto-learn preferences from user's travel history @@ -412,6 +433,34 @@ class ChatViewSet(viewsets.ModelViewSet): **arguments, ) + tool_call_for_history = tool_call + + if self._is_search_places_missing_location_required_error( + function_name, + result, + ) and self._is_likely_location_reply(user_content): + retry_arguments = dict(arguments) + retry_arguments["location"] = user_content + retry_result = await sync_to_async( + execute_tool, + thread_sensitive=True, + )( + function_name, + request.user, + **retry_arguments, + ) + + if not self._is_required_param_tool_error(retry_result): + result = retry_result + tool_call_for_history = { + **tool_call, + "function": { + **function_payload, + "name": function_name, + "arguments": json.dumps(retry_arguments), + }, + } + if self._is_required_param_tool_error(result): assistant_message_kwargs = { "conversation": conversation, @@ -480,19 +529,19 @@ class ChatViewSet(viewsets.ModelViewSet): result_content = serialize_tool_result(result) - successful_tool_calls.append(tool_call) + successful_tool_calls.append(tool_call_for_history) tool_message_payload = { "conversation": conversation, "role": "tool", "content": result_content, - "tool_call_id": tool_call.get("id"), + "tool_call_id": tool_call_for_history.get("id"), "name": function_name, } successful_tool_messages.append(tool_message_payload) successful_tool_chat_entries.append( { "role": "tool", - "tool_call_id": tool_call.get("id"), + "tool_call_id": tool_call_for_history.get("id"), "name": function_name, "content": result_content, } @@ -500,7 +549,7 @@ class ChatViewSet(viewsets.ModelViewSet): tool_event = { "tool_result": { - "tool_call_id": tool_call.get("id"), + "tool_call_id": tool_call_for_history.get("id"), "name": function_name, "result": result, }