fix(chat): clarify missing-location search requests

This commit is contained in:
2026-03-10 16:26:02 +00:00
parent 1ad9d20037
commit 84384df236
5 changed files with 147 additions and 172 deletions

View File

@@ -3,6 +3,7 @@ from unittest.mock import patch
from django.contrib.auth import get_user_model
from django.test import TestCase
from rest_framework.test import APITestCase
from adventures.models import Collection, CollectionItineraryItem
from chat.agent_tools import add_to_itinerary, get_trip_details
@@ -146,3 +147,98 @@ class ChatViewSetToolValidationBoundaryTests(TestCase):
json.dumps({"error": error_text})
)
)
def test_search_places_missing_location_error_detected_for_clarification(self):
self.assertTrue(
ChatViewSet._is_search_places_missing_location_required_error(
"search_places",
{"error": "location is required"},
)
)
def test_non_search_places_required_error_not_detected_for_clarification(self):
self.assertFalse(
ChatViewSet._is_search_places_missing_location_required_error(
"web_search",
{"error": "query is required"},
)
)
class ChatViewSetSearchPlacesClarificationTests(APITestCase):
@patch("chat.views.execute_tool")
@patch("chat.views.stream_chat_completion")
@patch("integrations.utils.auto_profile.update_auto_preference_profile")
def test_missing_search_place_location_streams_clarifying_content(
self,
_mock_auto_profile,
mock_stream_chat_completion,
mock_execute_tool,
):
user = User.objects.create_user(
username="chat-clarify-user",
email="chat-clarify-user@example.com",
password="password123",
)
self.client.force_authenticate(user=user)
conversation_response = self.client.post(
"/api/chat/conversations/",
{"title": "Clarification Test"},
format="json",
)
self.assertEqual(conversation_response.status_code, 201)
conversation_id = conversation_response.json()["id"]
async def mock_stream(*args, **kwargs):
yield 'data: {"tool_calls": [{"index": 0, "id": "call_1", "type": "function", "function": {"name": "search_places", "arguments": "{}"}}]}\n\n'
yield "data: [DONE]\n\n"
mock_stream_chat_completion.side_effect = mock_stream
mock_execute_tool.return_value = {"error": "location is required"}
response = self.client.post(
f"/api/chat/conversations/{conversation_id}/send_message/",
{"message": "Find good places"},
format="json",
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response["Content-Type"], "text/event-stream")
chunks = [
chunk.decode("utf-8")
if isinstance(chunk, (bytes, bytearray))
else str(chunk)
for chunk in response.streaming_content
]
payload_lines = [
chunk.strip()[len("data: ") :]
for chunk in chunks
if chunk.strip().startswith("data: ")
]
done_count = sum(1 for payload in payload_lines if payload == "[DONE]")
self.assertEqual(done_count, 1)
json_payloads = [
json.loads(payload) for payload in payload_lines if payload != "[DONE]"
]
self.assertTrue(any("content" in payload for payload in json_payloads))
self.assertFalse(
any(payload.get("error_category") for payload in json_payloads)
)
content_payload = next(
payload for payload in json_payloads if "content" in payload
)
self.assertIn("specific location", content_payload["content"].lower())
clarifying_message = (
user.chat_conversations.get(id=conversation_id)
.messages.filter(role="assistant")
.order_by("created_at")
.last()
)
self.assertIsNotNone(clarifying_message)
self.assertIn("specific location", clarifying_message.content.lower())