fix: persist assistant tool errors and markdown
This commit is contained in:
@@ -8,6 +8,7 @@ import requests as _requests
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.test import TestCase
|
||||
from rest_framework.exceptions import PermissionDenied, ValidationError
|
||||
from rest_framework.test import APITransactionTestCase
|
||||
|
||||
from adventures.models import Collection, CollectionItineraryItem, Location, Visit
|
||||
@@ -459,6 +460,63 @@ class ChatAgentToolItineraryManagementTests(TestCase):
|
||||
self.assertEqual(remove_result, {"error": "Trip not found"})
|
||||
self.assertEqual(update_result, {"error": "Trip not found"})
|
||||
|
||||
@patch("chat.agent_tools.reorder_itinerary_items")
|
||||
def test_move_itinerary_item_returns_validation_error_from_reorder(
|
||||
self,
|
||||
mock_reorder_itinerary_items,
|
||||
):
|
||||
mock_reorder_itinerary_items.side_effect = ValidationError(
|
||||
{
|
||||
"items": (
|
||||
"Item abc date 2026-05-30 is before collection start date "
|
||||
"2026-06-01."
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
self.collection.start_date = date(2026, 6, 1)
|
||||
self.collection.save(update_fields=["start_date"])
|
||||
|
||||
result = move_itinerary_item(
|
||||
self.owner,
|
||||
collection_id=str(self.collection.id),
|
||||
itinerary_item_id=str(self.day1_item.id),
|
||||
date="2026-06-02",
|
||||
order=0,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
result,
|
||||
{
|
||||
"error": (
|
||||
"Item abc date 2026-05-30 is before collection start date "
|
||||
"2026-06-01."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@patch("chat.agent_tools.reorder_itinerary_items")
|
||||
def test_move_itinerary_item_returns_permission_error_from_reorder(
|
||||
self,
|
||||
mock_reorder_itinerary_items,
|
||||
):
|
||||
mock_reorder_itinerary_items.side_effect = PermissionDenied(
|
||||
"You do not have permission to modify items in this collection."
|
||||
)
|
||||
|
||||
result = move_itinerary_item(
|
||||
self.owner,
|
||||
collection_id=str(self.collection.id),
|
||||
itinerary_item_id=str(self.day1_item.id),
|
||||
date="2026-06-02",
|
||||
order=0,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
result,
|
||||
{"error": "You do not have permission to modify items in this collection."},
|
||||
)
|
||||
|
||||
|
||||
class ChatViewSetToolValidationBoundaryTests(TestCase):
|
||||
def test_trip_context_destination_summary_normalizes_to_first_segment(self):
|
||||
@@ -589,6 +647,70 @@ class ChatViewSetToolValidationBoundaryTests(TestCase):
|
||||
|
||||
|
||||
class ChatViewSetSearchPlacesClarificationTests(APITransactionTestCase):
|
||||
@patch("chat.views.stream_chat_completion")
|
||||
@patch("integrations.utils.auto_profile.update_auto_preference_profile")
|
||||
def test_collection_context_includes_uuid_guidance_for_existing_item_tools(
|
||||
self,
|
||||
_mock_auto_profile,
|
||||
mock_stream_chat_completion,
|
||||
):
|
||||
user = User.objects.create_user(
|
||||
username="chat-context-tool-guidance-user",
|
||||
email="chat-context-tool-guidance-user@example.com",
|
||||
password="password123",
|
||||
)
|
||||
self.client.force_authenticate(user=user)
|
||||
|
||||
collection = Collection.objects.create(user=user, name="Paris Trip")
|
||||
|
||||
conversation_response = self.client.post(
|
||||
"/api/chat/conversations/",
|
||||
{"title": "Context Tool Guidance Test"},
|
||||
format="json",
|
||||
)
|
||||
self.assertEqual(conversation_response.status_code, 201)
|
||||
conversation_id = conversation_response.json()["id"]
|
||||
|
||||
async def stream_noop(*args, **kwargs):
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
mock_stream_chat_completion.side_effect = stream_noop
|
||||
|
||||
response = self.client.post(
|
||||
f"/api/chat/conversations/{conversation_id}/send_message/",
|
||||
{
|
||||
"message": "Move my first stop to day 2",
|
||||
"collection_id": str(collection.id),
|
||||
"collection_name": collection.name,
|
||||
},
|
||||
format="json",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
list(response.streaming_content)
|
||||
|
||||
stream_call_messages = mock_stream_chat_completion.call_args.args[1]
|
||||
system_message = next(
|
||||
message for message in stream_call_messages if message["role"] == "system"
|
||||
)
|
||||
normalized_system_content = system_message["content"].lower()
|
||||
self.assertIn(
|
||||
(
|
||||
"use this exact collection_id for get_trip_details, "
|
||||
"add_to_itinerary, move_itinerary_item, remove_itinerary_item, "
|
||||
"and update_location_details"
|
||||
),
|
||||
normalized_system_content,
|
||||
)
|
||||
self.assertIn(
|
||||
(
|
||||
"call get_trip_details first before move_itinerary_item, "
|
||||
"remove_itinerary_item, or update_location_details when exact IDs "
|
||||
"are needed"
|
||||
).lower(),
|
||||
normalized_system_content,
|
||||
)
|
||||
|
||||
@patch("chat.views.execute_tool")
|
||||
@patch("chat.views.stream_chat_completion")
|
||||
@patch("integrations.utils.auto_profile.update_auto_preference_profile")
|
||||
@@ -1073,6 +1195,56 @@ class ChatViewSetSearchPlacesClarificationTests(APITransactionTestCase):
|
||||
|
||||
|
||||
class ChatViewSetToolExecutionFailureLoopTests(APITransactionTestCase):
|
||||
@patch("chat.views.execute_tool")
|
||||
@patch("chat.views.stream_chat_completion")
|
||||
@patch("integrations.utils.auto_profile.update_auto_preference_profile")
|
||||
def test_required_param_terminal_error_is_persisted_as_assistant_message(
|
||||
self,
|
||||
_mock_auto_profile,
|
||||
mock_stream_chat_completion,
|
||||
mock_execute_tool,
|
||||
):
|
||||
user = User.objects.create_user(
|
||||
username="chat-required-error-persist-user",
|
||||
email="chat-required-error-persist-user@example.com",
|
||||
password="password123",
|
||||
)
|
||||
self.client.force_authenticate(user=user)
|
||||
|
||||
conversation_response = self.client.post(
|
||||
"/api/chat/conversations/",
|
||||
{"title": "Required Error Persistence Test"},
|
||||
format="json",
|
||||
)
|
||||
self.assertEqual(conversation_response.status_code, 201)
|
||||
conversation_id = conversation_response.json()["id"]
|
||||
|
||||
async def validation_error_stream(*args, **kwargs):
|
||||
yield 'data: {"tool_calls": [{"index": 0, "id": "call_move", "type": "function", "function": {"name": "move_itinerary_item", "arguments": "{}"}}]}\n\n'
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
mock_stream_chat_completion.side_effect = validation_error_stream
|
||||
mock_execute_tool.return_value = {
|
||||
"error": "collection_id, itinerary_item_id, and date are required"
|
||||
}
|
||||
|
||||
response = self.client.post(
|
||||
f"/api/chat/conversations/{conversation_id}/send_message/",
|
||||
{"message": "Move that stop"},
|
||||
format="json",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
list(response.streaming_content)
|
||||
|
||||
conversation = user.chat_conversations.get(id=conversation_id)
|
||||
self.assertTrue(
|
||||
conversation.messages.filter(
|
||||
role="assistant",
|
||||
content__contains="attempted to call 'move_itinerary_item' without required arguments",
|
||||
).exists()
|
||||
)
|
||||
|
||||
@patch("chat.views.execute_tool")
|
||||
@patch("chat.views.stream_chat_completion")
|
||||
@patch("integrations.utils.auto_profile.update_auto_preference_profile")
|
||||
@@ -1145,6 +1317,14 @@ class ChatViewSetToolExecutionFailureLoopTests(APITransactionTestCase):
|
||||
self.assertEqual(mock_stream_chat_completion.call_count, 3)
|
||||
self.assertEqual(mock_execute_tool.call_count, 3)
|
||||
|
||||
conversation = user.chat_conversations.get(id=conversation_id)
|
||||
self.assertTrue(
|
||||
conversation.messages.filter(
|
||||
role="assistant",
|
||||
content__contains="could not complete 'web_search'",
|
||||
).exists()
|
||||
)
|
||||
|
||||
@patch("chat.views.execute_tool")
|
||||
@patch("chat.views.stream_chat_completion")
|
||||
@patch("integrations.utils.auto_profile.update_auto_preference_profile")
|
||||
|
||||
Reference in New Issue
Block a user