fix(chat): retry search_places using user location reply
This commit is contained in:
@@ -164,6 +164,19 @@ class ChatViewSetToolValidationBoundaryTests(TestCase):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_likely_location_reply_heuristic_positive_case(self):
|
||||||
|
self.assertTrue(ChatViewSet._is_likely_location_reply("london"))
|
||||||
|
|
||||||
|
def test_likely_location_reply_heuristic_negative_question(self):
|
||||||
|
self.assertFalse(ChatViewSet._is_likely_location_reply("where should I go?"))
|
||||||
|
|
||||||
|
def test_likely_location_reply_heuristic_negative_long_sentence(self):
|
||||||
|
self.assertFalse(
|
||||||
|
ChatViewSet._is_likely_location_reply(
|
||||||
|
"I am not sure what city yet, maybe something with beaches and nice museums"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ChatViewSetSearchPlacesClarificationTests(APITransactionTestCase):
|
class ChatViewSetSearchPlacesClarificationTests(APITransactionTestCase):
|
||||||
@patch("chat.views.execute_tool")
|
@patch("chat.views.execute_tool")
|
||||||
@@ -242,3 +255,95 @@ class ChatViewSetSearchPlacesClarificationTests(APITransactionTestCase):
|
|||||||
)
|
)
|
||||||
self.assertIsNotNone(clarifying_message)
|
self.assertIsNotNone(clarifying_message)
|
||||||
self.assertIn("specific location", clarifying_message.content.lower())
|
self.assertIn("specific location", clarifying_message.content.lower())
|
||||||
|
|
||||||
|
@patch("chat.views.execute_tool")
|
||||||
|
@patch("chat.views.stream_chat_completion")
|
||||||
|
@patch("integrations.utils.auto_profile.update_auto_preference_profile")
|
||||||
|
def test_missing_location_retry_uses_user_reply_and_avoids_clarification_loop(
|
||||||
|
self,
|
||||||
|
_mock_auto_profile,
|
||||||
|
mock_stream_chat_completion,
|
||||||
|
mock_execute_tool,
|
||||||
|
):
|
||||||
|
user = User.objects.create_user(
|
||||||
|
username="chat-location-retry-user",
|
||||||
|
email="chat-location-retry-user@example.com",
|
||||||
|
password="password123",
|
||||||
|
)
|
||||||
|
self.client.force_authenticate(user=user)
|
||||||
|
|
||||||
|
conversation_response = self.client.post(
|
||||||
|
"/api/chat/conversations/",
|
||||||
|
{"title": "Location Retry Test"},
|
||||||
|
format="json",
|
||||||
|
)
|
||||||
|
self.assertEqual(conversation_response.status_code, 201)
|
||||||
|
conversation_id = conversation_response.json()["id"]
|
||||||
|
|
||||||
|
async def first_stream(*args, **kwargs):
|
||||||
|
yield 'data: {"tool_calls": [{"index": 0, "id": "call_1", "type": "function", "function": {"name": "search_places", "arguments": "{}"}}]}\n\n'
|
||||||
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
|
async def second_stream(*args, **kwargs):
|
||||||
|
yield 'data: {"content": "Great, here are top spots in London."}\n\n'
|
||||||
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
|
mock_stream_chat_completion.side_effect = [first_stream(), second_stream()]
|
||||||
|
mock_execute_tool.side_effect = [
|
||||||
|
{"error": "location is required"},
|
||||||
|
{"results": [{"name": "British Museum"}]},
|
||||||
|
]
|
||||||
|
|
||||||
|
response = self.client.post(
|
||||||
|
f"/api/chat/conversations/{conversation_id}/send_message/",
|
||||||
|
{"message": "london"},
|
||||||
|
format="json",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
chunks = [
|
||||||
|
chunk.decode("utf-8")
|
||||||
|
if isinstance(chunk, (bytes, bytearray))
|
||||||
|
else str(chunk)
|
||||||
|
for chunk in response.streaming_content
|
||||||
|
]
|
||||||
|
payload_lines = [
|
||||||
|
chunk.strip()[len("data: ") :]
|
||||||
|
for chunk in chunks
|
||||||
|
if chunk.strip().startswith("data: ")
|
||||||
|
]
|
||||||
|
json_payloads = [
|
||||||
|
json.loads(payload) for payload in payload_lines if payload != "[DONE]"
|
||||||
|
]
|
||||||
|
|
||||||
|
self.assertTrue(any("tool_result" in payload for payload in json_payloads))
|
||||||
|
self.assertTrue(
|
||||||
|
any(
|
||||||
|
payload.get("content") == "Great, here are top spots in London."
|
||||||
|
for payload in json_payloads
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertFalse(
|
||||||
|
any(
|
||||||
|
"specific location" in payload.get("content", "").lower()
|
||||||
|
for payload in json_payloads
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertFalse(
|
||||||
|
any(payload.get("error_category") for payload in json_payloads)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(mock_execute_tool.call_count, 2)
|
||||||
|
self.assertEqual(
|
||||||
|
mock_execute_tool.call_args_list[1].kwargs.get("location"), "london"
|
||||||
|
)
|
||||||
|
|
||||||
|
assistant_messages = user.chat_conversations.get(
|
||||||
|
id=conversation_id
|
||||||
|
).messages.filter(role="assistant")
|
||||||
|
self.assertFalse(
|
||||||
|
any(
|
||||||
|
"specific location" in message.content.lower()
|
||||||
|
for message in assistant_messages
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|||||||
@@ -198,6 +198,27 @@ class ChatViewSet(viewsets.ModelViewSet):
|
|||||||
"activities, or lodging."
|
"activities, or lodging."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_likely_location_reply(user_content):
|
||||||
|
if not isinstance(user_content, str):
|
||||||
|
return False
|
||||||
|
|
||||||
|
normalized = user_content.strip()
|
||||||
|
if not normalized:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if normalized.endswith("?"):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if len(normalized) > 80:
|
||||||
|
return False
|
||||||
|
|
||||||
|
parts = normalized.split()
|
||||||
|
if len(parts) > 6:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return bool(re.search(r"[a-z0-9]", normalized, re.IGNORECASE))
|
||||||
|
|
||||||
@action(detail=True, methods=["post"])
|
@action(detail=True, methods=["post"])
|
||||||
def send_message(self, request, pk=None):
|
def send_message(self, request, pk=None):
|
||||||
# Auto-learn preferences from user's travel history
|
# Auto-learn preferences from user's travel history
|
||||||
@@ -412,6 +433,34 @@ class ChatViewSet(viewsets.ModelViewSet):
|
|||||||
**arguments,
|
**arguments,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tool_call_for_history = tool_call
|
||||||
|
|
||||||
|
if self._is_search_places_missing_location_required_error(
|
||||||
|
function_name,
|
||||||
|
result,
|
||||||
|
) and self._is_likely_location_reply(user_content):
|
||||||
|
retry_arguments = dict(arguments)
|
||||||
|
retry_arguments["location"] = user_content
|
||||||
|
retry_result = await sync_to_async(
|
||||||
|
execute_tool,
|
||||||
|
thread_sensitive=True,
|
||||||
|
)(
|
||||||
|
function_name,
|
||||||
|
request.user,
|
||||||
|
**retry_arguments,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self._is_required_param_tool_error(retry_result):
|
||||||
|
result = retry_result
|
||||||
|
tool_call_for_history = {
|
||||||
|
**tool_call,
|
||||||
|
"function": {
|
||||||
|
**function_payload,
|
||||||
|
"name": function_name,
|
||||||
|
"arguments": json.dumps(retry_arguments),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
if self._is_required_param_tool_error(result):
|
if self._is_required_param_tool_error(result):
|
||||||
assistant_message_kwargs = {
|
assistant_message_kwargs = {
|
||||||
"conversation": conversation,
|
"conversation": conversation,
|
||||||
@@ -480,19 +529,19 @@ class ChatViewSet(viewsets.ModelViewSet):
|
|||||||
|
|
||||||
result_content = serialize_tool_result(result)
|
result_content = serialize_tool_result(result)
|
||||||
|
|
||||||
successful_tool_calls.append(tool_call)
|
successful_tool_calls.append(tool_call_for_history)
|
||||||
tool_message_payload = {
|
tool_message_payload = {
|
||||||
"conversation": conversation,
|
"conversation": conversation,
|
||||||
"role": "tool",
|
"role": "tool",
|
||||||
"content": result_content,
|
"content": result_content,
|
||||||
"tool_call_id": tool_call.get("id"),
|
"tool_call_id": tool_call_for_history.get("id"),
|
||||||
"name": function_name,
|
"name": function_name,
|
||||||
}
|
}
|
||||||
successful_tool_messages.append(tool_message_payload)
|
successful_tool_messages.append(tool_message_payload)
|
||||||
successful_tool_chat_entries.append(
|
successful_tool_chat_entries.append(
|
||||||
{
|
{
|
||||||
"role": "tool",
|
"role": "tool",
|
||||||
"tool_call_id": tool_call.get("id"),
|
"tool_call_id": tool_call_for_history.get("id"),
|
||||||
"name": function_name,
|
"name": function_name,
|
||||||
"content": result_content,
|
"content": result_content,
|
||||||
}
|
}
|
||||||
@@ -500,7 +549,7 @@ class ChatViewSet(viewsets.ModelViewSet):
|
|||||||
|
|
||||||
tool_event = {
|
tool_event = {
|
||||||
"tool_result": {
|
"tool_result": {
|
||||||
"tool_call_id": tool_call.get("id"),
|
"tool_call_id": tool_call_for_history.get("id"),
|
||||||
"name": function_name,
|
"name": function_name,
|
||||||
"result": result,
|
"result": result,
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user