changes
This commit is contained in:
@@ -1,22 +1,34 @@
|
||||
import json
|
||||
from datetime import date
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import requests as _requests
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.test import TestCase
|
||||
from rest_framework.test import APITransactionTestCase
|
||||
|
||||
from adventures.models import Collection, CollectionItineraryItem, Location
|
||||
from adventures.models import Collection, CollectionItineraryItem, Location, Visit
|
||||
from chat.agent_tools import (
|
||||
add_to_itinerary,
|
||||
execute_tool,
|
||||
add_lodging,
|
||||
add_transportation,
|
||||
get_trip_details,
|
||||
move_itinerary_item,
|
||||
remove_itinerary_item,
|
||||
remove_lodging,
|
||||
remove_transportation,
|
||||
search_places,
|
||||
update_location_details,
|
||||
update_lodging,
|
||||
update_transportation,
|
||||
web_search,
|
||||
)
|
||||
from chat.views import ChatViewSet
|
||||
from chat.views.day_suggestions import DaySuggestionsView
|
||||
|
||||
|
||||
User = get_user_model()
|
||||
@@ -245,6 +257,209 @@ class ChatAgentToolSharedTripAccessTests(TestCase):
|
||||
self.assertEqual(itinerary_result, {"error": "Trip not found"})
|
||||
|
||||
|
||||
class ChatAgentToolItineraryManagementTests(TestCase):
|
||||
def setUp(self):
|
||||
self.owner = User.objects.create_user(
|
||||
username="chat-itinerary-owner",
|
||||
email="chat-itinerary-owner@example.com",
|
||||
password="password123",
|
||||
)
|
||||
self.shared_user = User.objects.create_user(
|
||||
username="chat-itinerary-shared",
|
||||
email="chat-itinerary-shared@example.com",
|
||||
password="password123",
|
||||
)
|
||||
self.non_member = User.objects.create_user(
|
||||
username="chat-itinerary-non-member",
|
||||
email="chat-itinerary-non-member@example.com",
|
||||
password="password123",
|
||||
)
|
||||
|
||||
self.collection = Collection.objects.create(
|
||||
user=self.owner,
|
||||
name="Assistant Managed Trip",
|
||||
)
|
||||
self.collection.shared_with.add(self.shared_user)
|
||||
|
||||
self.location = Location.objects.create(
|
||||
user=self.owner,
|
||||
name="Existing Stop",
|
||||
latitude=48.8566,
|
||||
longitude=2.3522,
|
||||
)
|
||||
self.collection.locations.add(self.location)
|
||||
|
||||
self.location_content_type = ContentType.objects.get_for_model(Location)
|
||||
self.day1_item = CollectionItineraryItem.objects.create(
|
||||
collection=self.collection,
|
||||
content_type=self.location_content_type,
|
||||
object_id=self.location.id,
|
||||
date="2026-06-01",
|
||||
order=0,
|
||||
is_global=False,
|
||||
)
|
||||
|
||||
def test_move_itinerary_item_allows_shared_user(self):
|
||||
result = move_itinerary_item(
|
||||
self.shared_user,
|
||||
collection_id=str(self.collection.id),
|
||||
itinerary_item_id=str(self.day1_item.id),
|
||||
date="2026-06-02",
|
||||
order=0,
|
||||
)
|
||||
|
||||
self.assertTrue(result.get("success"))
|
||||
self.day1_item.refresh_from_db()
|
||||
self.assertEqual(self.day1_item.date.isoformat(), "2026-06-02")
|
||||
self.assertEqual(self.day1_item.order, 0)
|
||||
|
||||
def test_remove_itinerary_item_removes_matching_visit_for_locations(self):
|
||||
Visit.objects.create(
|
||||
location=self.location,
|
||||
start_date="2026-06-01T10:00:00Z",
|
||||
end_date="2026-06-01T12:00:00Z",
|
||||
)
|
||||
|
||||
result = remove_itinerary_item(
|
||||
self.owner,
|
||||
collection_id=str(self.collection.id),
|
||||
itinerary_item_id=str(self.day1_item.id),
|
||||
)
|
||||
|
||||
self.assertTrue(result.get("success"))
|
||||
self.assertEqual(result.get("deleted_visit_count"), 1)
|
||||
self.assertFalse(
|
||||
CollectionItineraryItem.objects.filter(id=self.day1_item.id).exists()
|
||||
)
|
||||
self.assertEqual(Visit.objects.filter(location=self.location).count(), 0)
|
||||
|
||||
def test_update_location_details_scoped_to_collection(self):
|
||||
outsider_location = Location.objects.create(
|
||||
user=self.owner,
|
||||
name="Outside",
|
||||
latitude=1.0,
|
||||
longitude=1.0,
|
||||
)
|
||||
|
||||
denied = update_location_details(
|
||||
self.shared_user,
|
||||
collection_id=str(self.collection.id),
|
||||
location_id=str(outsider_location.id),
|
||||
name="Should fail",
|
||||
)
|
||||
self.assertEqual(denied, {"error": "Location not found in this trip"})
|
||||
|
||||
allowed = update_location_details(
|
||||
self.shared_user,
|
||||
collection_id=str(self.collection.id),
|
||||
location_id=str(self.location.id),
|
||||
name="Updated Stop",
|
||||
latitude=40.7128,
|
||||
longitude=-74.0060,
|
||||
)
|
||||
self.assertTrue(allowed.get("success"))
|
||||
self.location.refresh_from_db()
|
||||
self.assertEqual(self.location.name, "Updated Stop")
|
||||
self.assertAlmostEqual(float(self.location.latitude), 40.7128, places=4)
|
||||
self.assertAlmostEqual(float(self.location.longitude), -74.0060, places=4)
|
||||
|
||||
def test_lodging_management_tools_allow_shared_user(self):
|
||||
created = add_lodging(
|
||||
self.shared_user,
|
||||
collection_id=str(self.collection.id),
|
||||
name="River Hotel",
|
||||
type="hotel",
|
||||
location="Paris",
|
||||
check_in="2026-06-02T15:00:00Z",
|
||||
check_out="2026-06-04T11:00:00Z",
|
||||
latitude=48.85,
|
||||
longitude=2.35,
|
||||
itinerary_date="2026-06-02",
|
||||
)
|
||||
|
||||
self.assertTrue(created.get("success"))
|
||||
lodging_id = created["lodging"]["id"]
|
||||
self.assertIsNotNone(created.get("itinerary_item"))
|
||||
|
||||
updated = update_lodging(
|
||||
self.shared_user,
|
||||
collection_id=str(self.collection.id),
|
||||
lodging_id=lodging_id,
|
||||
name="River Hotel Updated",
|
||||
location="Paris Center",
|
||||
)
|
||||
self.assertTrue(updated.get("success"))
|
||||
self.assertEqual(updated["lodging"]["name"], "River Hotel Updated")
|
||||
|
||||
removed = remove_lodging(
|
||||
self.shared_user,
|
||||
collection_id=str(self.collection.id),
|
||||
lodging_id=lodging_id,
|
||||
)
|
||||
self.assertTrue(removed.get("success"))
|
||||
self.assertGreaterEqual(removed.get("removed_itinerary_items", 0), 1)
|
||||
|
||||
def test_transportation_management_tools_allow_shared_user(self):
|
||||
created = add_transportation(
|
||||
self.shared_user,
|
||||
collection_id=str(self.collection.id),
|
||||
name="Train to Lyon",
|
||||
type="train",
|
||||
date="2026-06-03T09:00:00Z",
|
||||
end_date="2026-06-03T11:00:00Z",
|
||||
from_location="Paris",
|
||||
to_location="Lyon",
|
||||
origin_latitude=48.8566,
|
||||
origin_longitude=2.3522,
|
||||
destination_latitude=45.7640,
|
||||
destination_longitude=4.8357,
|
||||
itinerary_date="2026-06-03",
|
||||
)
|
||||
|
||||
self.assertTrue(created.get("success"))
|
||||
transportation_id = created["transportation"]["id"]
|
||||
|
||||
updated = update_transportation(
|
||||
self.shared_user,
|
||||
collection_id=str(self.collection.id),
|
||||
transportation_id=transportation_id,
|
||||
to_location="Lyon Part-Dieu",
|
||||
)
|
||||
self.assertTrue(updated.get("success"))
|
||||
self.assertEqual(updated["transportation"]["to_location"], "Lyon Part-Dieu")
|
||||
|
||||
removed = remove_transportation(
|
||||
self.shared_user,
|
||||
collection_id=str(self.collection.id),
|
||||
transportation_id=transportation_id,
|
||||
)
|
||||
self.assertTrue(removed.get("success"))
|
||||
self.assertGreaterEqual(removed.get("removed_itinerary_items", 0), 1)
|
||||
|
||||
def test_management_tools_deny_non_member(self):
|
||||
move_result = move_itinerary_item(
|
||||
self.non_member,
|
||||
collection_id=str(self.collection.id),
|
||||
itinerary_item_id=str(self.day1_item.id),
|
||||
date="2026-06-02",
|
||||
)
|
||||
remove_result = remove_itinerary_item(
|
||||
self.non_member,
|
||||
collection_id=str(self.collection.id),
|
||||
itinerary_item_id=str(self.day1_item.id),
|
||||
)
|
||||
update_result = update_location_details(
|
||||
self.non_member,
|
||||
collection_id=str(self.collection.id),
|
||||
location_id=str(self.location.id),
|
||||
name="No access",
|
||||
)
|
||||
|
||||
self.assertEqual(move_result, {"error": "Trip not found"})
|
||||
self.assertEqual(remove_result, {"error": "Trip not found"})
|
||||
self.assertEqual(update_result, {"error": "Trip not found"})
|
||||
|
||||
|
||||
class ChatViewSetToolValidationBoundaryTests(TestCase):
|
||||
def test_trip_context_destination_summary_normalizes_to_first_segment(self):
|
||||
self.assertEqual(
|
||||
@@ -1286,3 +1501,337 @@ class GetWeatherCoordFallbackTests(APITransactionTestCase):
|
||||
),
|
||||
"Should emit tool_validation_error when no collection coords available",
|
||||
)
|
||||
|
||||
@patch("chat.views.execute_tool")
|
||||
@patch("chat.views.stream_chat_completion")
|
||||
@patch("integrations.utils.auto_profile.update_auto_preference_profile")
|
||||
def test_get_weather_missing_coords_and_dates_retries_with_collection_dates(
|
||||
self,
|
||||
_mock_auto_profile,
|
||||
mock_stream_chat_completion,
|
||||
mock_execute_tool,
|
||||
):
|
||||
user = User.objects.create_user(
|
||||
username="weather-coord-dates-user",
|
||||
email="weather-coord-dates-user@example.com",
|
||||
password="password123",
|
||||
)
|
||||
self.client.force_authenticate(user=user)
|
||||
|
||||
collection = Collection.objects.create(
|
||||
user_id=user.id,
|
||||
name="Vienna Trip",
|
||||
start_date=date(2026, 6, 10),
|
||||
end_date=date(2026, 6, 12),
|
||||
)
|
||||
vienna_location = Location.objects.create(
|
||||
user_id=user.id,
|
||||
name="Vienna",
|
||||
latitude=48.2082,
|
||||
longitude=16.3738,
|
||||
)
|
||||
collection.locations.add(vienna_location)
|
||||
|
||||
conversation_response = self.client.post(
|
||||
"/api/chat/conversations/",
|
||||
{"title": "Weather Coord+Dates Fallback Test"},
|
||||
format="json",
|
||||
)
|
||||
self.assertEqual(conversation_response.status_code, 201)
|
||||
conversation_id = conversation_response.json()["id"]
|
||||
|
||||
async def weather_stream(*args, **kwargs):
|
||||
yield 'data: {"tool_calls": [{"index": 0, "id": "call_w3", "type": "function", "function": {"name": "get_weather", "arguments": "{}"}}]}\n\n'
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def success_stream(*args, **kwargs):
|
||||
yield 'data: {"content": "Vienna forecast loaded."}\n\n'
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
mock_stream_chat_completion.side_effect = [weather_stream(), success_stream()]
|
||||
mock_execute_tool.side_effect = [
|
||||
{"error": "latitude and longitude are required"},
|
||||
{
|
||||
"location": "Vienna",
|
||||
"forecast": [
|
||||
{"date": "2026-06-10", "temperature": 24, "condition": "sunny"}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
response = self.client.post(
|
||||
f"/api/chat/conversations/{conversation_id}/send_message/",
|
||||
{
|
||||
"message": "What's the weather for my trip?",
|
||||
"collection_id": str(collection.id),
|
||||
},
|
||||
format="json",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
chunks = [
|
||||
chunk.decode("utf-8")
|
||||
if isinstance(chunk, (bytes, bytearray))
|
||||
else str(chunk)
|
||||
for chunk in response.streaming_content
|
||||
]
|
||||
payload_lines = [
|
||||
chunk.strip()[len("data: ") :]
|
||||
for chunk in chunks
|
||||
if chunk.strip().startswith("data: ")
|
||||
]
|
||||
json_payloads = [json.loads(p) for p in payload_lines if p != "[DONE]"]
|
||||
|
||||
self.assertEqual(mock_execute_tool.call_count, 2)
|
||||
retry_kwargs = mock_execute_tool.call_args_list[1][1]
|
||||
self.assertAlmostEqual(retry_kwargs.get("latitude"), 48.2082, places=3)
|
||||
self.assertAlmostEqual(retry_kwargs.get("longitude"), 16.3738, places=3)
|
||||
self.assertEqual(
|
||||
retry_kwargs.get("dates"),
|
||||
["2026-06-10", "2026-06-11", "2026-06-12"],
|
||||
)
|
||||
self.assertFalse(
|
||||
any(
|
||||
payload.get("error_category") == "tool_execution_error"
|
||||
for payload in json_payloads
|
||||
)
|
||||
)
|
||||
|
||||
@patch("chat.views.execute_tool")
|
||||
@patch("chat.views.stream_chat_completion")
|
||||
@patch("integrations.utils.auto_profile.update_auto_preference_profile")
|
||||
def test_get_weather_missing_dates_only_remains_validation_error(
|
||||
self,
|
||||
_mock_auto_profile,
|
||||
mock_stream_chat_completion,
|
||||
mock_execute_tool,
|
||||
):
|
||||
user = User.objects.create_user(
|
||||
username="weather-missing-dates-user",
|
||||
email="weather-missing-dates-user@example.com",
|
||||
password="password123",
|
||||
)
|
||||
self.client.force_authenticate(user=user)
|
||||
|
||||
collection = Collection.objects.create(
|
||||
user_id=user.id,
|
||||
name="Berlin Trip",
|
||||
)
|
||||
berlin_location = Location.objects.create(
|
||||
user_id=user.id,
|
||||
name="Berlin",
|
||||
latitude=52.52,
|
||||
longitude=13.405,
|
||||
)
|
||||
collection.locations.add(berlin_location)
|
||||
|
||||
conversation_response = self.client.post(
|
||||
"/api/chat/conversations/",
|
||||
{"title": "Weather Missing Dates Validation Test"},
|
||||
format="json",
|
||||
)
|
||||
self.assertEqual(conversation_response.status_code, 201)
|
||||
conversation_id = conversation_response.json()["id"]
|
||||
|
||||
async def weather_stream(*args, **kwargs):
|
||||
yield (
|
||||
'data: {"tool_calls": [{"index": 0, "id": "call_w4", "type": '
|
||||
'"function", "function": {"name": "get_weather", "arguments": '
|
||||
'"{\\"latitude\\":52.52,\\"longitude\\":13.405}"}}]}\n\n'
|
||||
)
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
mock_stream_chat_completion.side_effect = weather_stream
|
||||
mock_execute_tool.return_value = {"error": "dates is required"}
|
||||
|
||||
response = self.client.post(
|
||||
f"/api/chat/conversations/{conversation_id}/send_message/",
|
||||
{
|
||||
"message": "What's the weather there?",
|
||||
"collection_id": str(collection.id),
|
||||
},
|
||||
format="json",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
payload_lines = [
|
||||
(
|
||||
chunk.decode("utf-8")
|
||||
if isinstance(chunk, (bytes, bytearray))
|
||||
else str(chunk)
|
||||
).strip()[len("data: ") :]
|
||||
for chunk in response.streaming_content
|
||||
if (
|
||||
(
|
||||
chunk.decode("utf-8")
|
||||
if isinstance(chunk, (bytes, bytearray))
|
||||
else str(chunk)
|
||||
)
|
||||
.strip()
|
||||
.startswith("data: ")
|
||||
)
|
||||
]
|
||||
json_payloads = [json.loads(p) for p in payload_lines if p != "[DONE]"]
|
||||
|
||||
self.assertEqual(mock_execute_tool.call_count, 1)
|
||||
self.assertTrue(
|
||||
any(
|
||||
payload.get("error_category") == "tool_validation_error"
|
||||
for payload in json_payloads
|
||||
)
|
||||
)
|
||||
self.assertFalse(
|
||||
any(
|
||||
payload.get("error_category") == "tool_execution_error"
|
||||
for payload in json_payloads
|
||||
)
|
||||
)
|
||||
self.assertFalse(
|
||||
any(
|
||||
"Could not fetch weather for the collection locations"
|
||||
in payload.get("error", "")
|
||||
for payload in json_payloads
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class DaySuggestionsCoordinateEnrichmentTests(TestCase):
|
||||
def setUp(self):
|
||||
self.view = DaySuggestionsView()
|
||||
|
||||
def test_enriches_suggestion_with_coordinates_from_place_context(self):
|
||||
suggestions = [
|
||||
{
|
||||
"name": "Roscioli",
|
||||
"location": "Via dei Giubbonari, Rome",
|
||||
"description": "Classic Roman spot",
|
||||
}
|
||||
]
|
||||
place_candidates = [
|
||||
{
|
||||
"name": "Roscioli",
|
||||
"address": "Via dei Giubbonari, Rome",
|
||||
"latitude": 41.8933,
|
||||
"longitude": 12.4722,
|
||||
}
|
||||
]
|
||||
|
||||
enriched = self.view._enrich_suggestions_with_coordinates(
|
||||
suggestions,
|
||||
place_candidates,
|
||||
)
|
||||
|
||||
self.assertEqual(len(enriched), 1)
|
||||
self.assertEqual(enriched[0]["latitude"], 41.8933)
|
||||
self.assertEqual(enriched[0]["longitude"], 12.4722)
|
||||
|
||||
def test_preserves_existing_suggestion_coordinates(self):
|
||||
suggestions = [
|
||||
{
|
||||
"name": "Known Place",
|
||||
"location": "Somewhere",
|
||||
"latitude": 10.5,
|
||||
"longitude": 20.5,
|
||||
}
|
||||
]
|
||||
place_candidates = [
|
||||
{
|
||||
"name": "Known Place",
|
||||
"address": "Somewhere",
|
||||
"latitude": 1.0,
|
||||
"longitude": 2.0,
|
||||
}
|
||||
]
|
||||
|
||||
enriched = self.view._enrich_suggestions_with_coordinates(
|
||||
suggestions,
|
||||
place_candidates,
|
||||
)
|
||||
|
||||
self.assertEqual(enriched[0]["latitude"], 10.5)
|
||||
self.assertEqual(enriched[0]["longitude"], 20.5)
|
||||
|
||||
def test_enriches_coordinates_with_token_based_name_matching(self):
|
||||
suggestions = [
|
||||
{
|
||||
"name": "Borough food market",
|
||||
"location": "South Bank",
|
||||
"description": "Popular food destination",
|
||||
}
|
||||
]
|
||||
place_candidates = [
|
||||
{
|
||||
"name": "Borough Market",
|
||||
"address": "8 Southwark St, London SE1 1TL",
|
||||
"latitude": 51.5055,
|
||||
"longitude": -0.0904,
|
||||
}
|
||||
]
|
||||
|
||||
enriched = self.view._enrich_suggestions_with_coordinates(
|
||||
suggestions,
|
||||
place_candidates,
|
||||
)
|
||||
|
||||
self.assertEqual(len(enriched), 1)
|
||||
self.assertEqual(enriched[0]["latitude"], 51.5055)
|
||||
self.assertEqual(enriched[0]["longitude"], -0.0904)
|
||||
|
||||
def test_falls_back_to_coordinate_match_when_best_text_match_has_no_coordinates(
|
||||
self,
|
||||
):
|
||||
suggestions = [
|
||||
{
|
||||
"name": "Sunset Bar",
|
||||
"location": "Pier 7",
|
||||
"description": "Cocktail spot",
|
||||
}
|
||||
]
|
||||
place_candidates = [
|
||||
{
|
||||
"name": "Sunset Bar",
|
||||
"address": "Pier 7",
|
||||
"latitude": None,
|
||||
"longitude": None,
|
||||
},
|
||||
{
|
||||
"name": "Harbor Walk",
|
||||
"address": "Pier 7, Lisbon",
|
||||
"latitude": 38.7072,
|
||||
"longitude": -9.1366,
|
||||
},
|
||||
]
|
||||
|
||||
enriched = self.view._enrich_suggestions_with_coordinates(
|
||||
suggestions,
|
||||
place_candidates,
|
||||
)
|
||||
|
||||
self.assertEqual(len(enriched), 1)
|
||||
self.assertEqual(enriched[0]["latitude"], 38.7072)
|
||||
self.assertEqual(enriched[0]["longitude"], -9.1366)
|
||||
|
||||
def test_does_not_inject_null_coordinates_when_no_coordinate_match_exists(self):
|
||||
suggestions = [
|
||||
{
|
||||
"name": "Skyline View",
|
||||
"location": "Hilltop",
|
||||
}
|
||||
]
|
||||
place_candidates = [
|
||||
{
|
||||
"name": "Skyline View",
|
||||
"address": "Hilltop",
|
||||
"latitude": None,
|
||||
"longitude": None,
|
||||
}
|
||||
]
|
||||
|
||||
enriched = self.view._enrich_suggestions_with_coordinates(
|
||||
suggestions,
|
||||
place_candidates,
|
||||
)
|
||||
|
||||
self.assertEqual(len(enriched), 1)
|
||||
self.assertNotIn("latitude", enriched[0])
|
||||
self.assertNotIn("longitude", enriched[0])
|
||||
|
||||
Reference in New Issue
Block a user