fix(chat): clarify missing-location search requests
This commit is contained in:
@@ -3,6 +3,7 @@ from unittest.mock import patch
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.test import TestCase
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from adventures.models import Collection, CollectionItineraryItem
|
||||
from chat.agent_tools import add_to_itinerary, get_trip_details
|
||||
@@ -146,3 +147,98 @@ class ChatViewSetToolValidationBoundaryTests(TestCase):
|
||||
json.dumps({"error": error_text})
|
||||
)
|
||||
)
|
||||
|
||||
def test_search_places_missing_location_error_detected_for_clarification(self):
|
||||
self.assertTrue(
|
||||
ChatViewSet._is_search_places_missing_location_required_error(
|
||||
"search_places",
|
||||
{"error": "location is required"},
|
||||
)
|
||||
)
|
||||
|
||||
def test_non_search_places_required_error_not_detected_for_clarification(self):
|
||||
self.assertFalse(
|
||||
ChatViewSet._is_search_places_missing_location_required_error(
|
||||
"web_search",
|
||||
{"error": "query is required"},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class ChatViewSetSearchPlacesClarificationTests(APITestCase):
|
||||
@patch("chat.views.execute_tool")
|
||||
@patch("chat.views.stream_chat_completion")
|
||||
@patch("integrations.utils.auto_profile.update_auto_preference_profile")
|
||||
def test_missing_search_place_location_streams_clarifying_content(
|
||||
self,
|
||||
_mock_auto_profile,
|
||||
mock_stream_chat_completion,
|
||||
mock_execute_tool,
|
||||
):
|
||||
user = User.objects.create_user(
|
||||
username="chat-clarify-user",
|
||||
email="chat-clarify-user@example.com",
|
||||
password="password123",
|
||||
)
|
||||
self.client.force_authenticate(user=user)
|
||||
|
||||
conversation_response = self.client.post(
|
||||
"/api/chat/conversations/",
|
||||
{"title": "Clarification Test"},
|
||||
format="json",
|
||||
)
|
||||
self.assertEqual(conversation_response.status_code, 201)
|
||||
conversation_id = conversation_response.json()["id"]
|
||||
|
||||
async def mock_stream(*args, **kwargs):
|
||||
yield 'data: {"tool_calls": [{"index": 0, "id": "call_1", "type": "function", "function": {"name": "search_places", "arguments": "{}"}}]}\n\n'
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
mock_stream_chat_completion.side_effect = mock_stream
|
||||
mock_execute_tool.return_value = {"error": "location is required"}
|
||||
|
||||
response = self.client.post(
|
||||
f"/api/chat/conversations/{conversation_id}/send_message/",
|
||||
{"message": "Find good places"},
|
||||
format="json",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response["Content-Type"], "text/event-stream")
|
||||
|
||||
chunks = [
|
||||
chunk.decode("utf-8")
|
||||
if isinstance(chunk, (bytes, bytearray))
|
||||
else str(chunk)
|
||||
for chunk in response.streaming_content
|
||||
]
|
||||
payload_lines = [
|
||||
chunk.strip()[len("data: ") :]
|
||||
for chunk in chunks
|
||||
if chunk.strip().startswith("data: ")
|
||||
]
|
||||
|
||||
done_count = sum(1 for payload in payload_lines if payload == "[DONE]")
|
||||
self.assertEqual(done_count, 1)
|
||||
|
||||
json_payloads = [
|
||||
json.loads(payload) for payload in payload_lines if payload != "[DONE]"
|
||||
]
|
||||
self.assertTrue(any("content" in payload for payload in json_payloads))
|
||||
self.assertFalse(
|
||||
any(payload.get("error_category") for payload in json_payloads)
|
||||
)
|
||||
|
||||
content_payload = next(
|
||||
payload for payload in json_payloads if "content" in payload
|
||||
)
|
||||
self.assertIn("specific location", content_payload["content"].lower())
|
||||
|
||||
clarifying_message = (
|
||||
user.chat_conversations.get(id=conversation_id)
|
||||
.messages.filter(role="assistant")
|
||||
.order_by("created_at")
|
||||
.last()
|
||||
)
|
||||
self.assertIsNotNone(clarifying_message)
|
||||
self.assertIn("specific location", clarifying_message.content.lower())
|
||||
|
||||
@@ -176,6 +176,28 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
"error_category": "tool_validation_error",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _is_search_places_missing_location_required_error(cls, tool_name, result):
|
||||
if tool_name != "search_places" or not cls._is_required_param_tool_error(
|
||||
result
|
||||
):
|
||||
return False
|
||||
|
||||
error_text = (result or {}).get("error") if isinstance(result, dict) else ""
|
||||
if not isinstance(error_text, str):
|
||||
return False
|
||||
|
||||
normalized_error = error_text.strip().lower()
|
||||
return "location" in normalized_error
|
||||
|
||||
@staticmethod
|
||||
def _build_search_places_location_clarification_message():
|
||||
return (
|
||||
"Could you share the specific location you'd like me to search near "
|
||||
"(city, neighborhood, or address)? I can also focus on food, "
|
||||
"activities, or lodging."
|
||||
)
|
||||
|
||||
@action(detail=True, methods=["post"])
|
||||
def send_message(self, request, pk=None):
|
||||
# Auto-learn preferences from user's travel history
|
||||
@@ -411,6 +433,33 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
thread_sensitive=True,
|
||||
)(**tool_message)
|
||||
|
||||
if self._is_search_places_missing_location_required_error(
|
||||
function_name,
|
||||
result,
|
||||
):
|
||||
clarification_content = self._build_search_places_location_clarification_message()
|
||||
await sync_to_async(
|
||||
ChatMessage.objects.create,
|
||||
thread_sensitive=True,
|
||||
)(
|
||||
conversation=conversation,
|
||||
role="assistant",
|
||||
content=clarification_content,
|
||||
)
|
||||
|
||||
await sync_to_async(
|
||||
conversation.save,
|
||||
thread_sensitive=True,
|
||||
)(update_fields=["updated_at"])
|
||||
|
||||
yield (
|
||||
"data: "
|
||||
f"{json.dumps({'content': clarification_content})}"
|
||||
"\n\n"
|
||||
)
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
await sync_to_async(
|
||||
conversation.save,
|
||||
thread_sensitive=True,
|
||||
|
||||
Reference in New Issue
Block a user