This commit is contained in:
alex wiesner
2026-03-13 20:15:22 +00:00
parent e68c95b2dd
commit c4d39f2812
33 changed files with 2383 additions and 162 deletions

View File

@@ -1,4 +1,5 @@
import os
from decimal import Decimal, InvalidOperation, ROUND_HALF_UP
from django.db.models import Q
from .models import (
@@ -345,6 +346,8 @@ class CalendarLocationSerializer(serializers.ModelSerializer):
class LocationSerializer(CustomModelSerializer):
name = serializers.CharField(required=True)
location = serializers.CharField(required=False, allow_blank=True, allow_null=True)
images = serializers.SerializerMethodField()
visits = VisitSerializer(many=True, read_only=False, required=False)
attachments = AttachmentSerializer(many=True, read_only=True)
@@ -426,6 +429,19 @@ class LocationSerializer(CustomModelSerializer):
# Filter out None values from the serialized data
return [image for image in serializer.data if image is not None]
@staticmethod
def _truncate_to_model_max_length(value, field_name):
if value is None:
return value
max_length = Location._meta.get_field(field_name).max_length
return value[:max_length]
def validate_name(self, value):
return self._truncate_to_model_max_length(value, "name")
def validate_location(self, value):
return self._truncate_to_model_max_length(value, "location")
def validate_collections(self, collections):
"""Validate that collections are compatible with the location being created/updated"""
@@ -511,6 +527,33 @@ class LocationSerializer(CustomModelSerializer):
category_data["name"] = name
return category_data
@staticmethod
def _normalize_coordinate_input(value):
if value in (None, ""):
return value
try:
coordinate = Decimal(str(value))
except (InvalidOperation, TypeError, ValueError):
return value
return coordinate.quantize(Decimal("0.000001"), rounding=ROUND_HALF_UP)
def to_internal_value(self, data):
if self.instance is None:
normalized_data = data.copy()
for field_name in ("latitude", "longitude"):
if field_name not in normalized_data:
continue
normalized_data[field_name] = self._normalize_coordinate_input(
normalized_data.get(field_name)
)
data = normalized_data
return super().to_internal_value(data)
def get_or_create_category(self, category_data):
user = self.context["request"].user

View File

@@ -23,6 +23,7 @@ from adventures.models import (
Note,
Transportation,
)
from adventures.utils.weather import fetch_daily_temperature
User = get_user_model()
@@ -61,7 +62,11 @@ class WeatherViewTests(APITestCase):
mock_fetch_temperature.return_value = {
"date": future_date,
"available": True,
"temperature_low_c": 19.0,
"temperature_high_c": 26.0,
"temperature_c": 22.5,
"is_estimate": False,
"source": "forecast",
}
response = self.client.post(
@@ -73,18 +78,44 @@ class WeatherViewTests(APITestCase):
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json()["results"][0]["date"], future_date)
self.assertTrue(response.json()["results"][0]["available"])
self.assertEqual(response.json()["results"][0]["temperature_low_c"], 19.0)
self.assertEqual(response.json()["results"][0]["temperature_high_c"], 26.0)
self.assertFalse(response.json()["results"][0]["is_estimate"])
self.assertEqual(response.json()["results"][0]["source"], "forecast")
self.assertEqual(response.json()["results"][0]["temperature_c"], 22.5)
mock_fetch_temperature.assert_called_once_with(future_date, 12.34, 56.78)
@patch("adventures.views.weather_view.requests.get")
def test_daily_temperatures_far_future_returns_unavailable_when_upstream_has_no_data(
@patch("adventures.utils.weather.requests.get")
def test_daily_temperatures_far_future_uses_historical_estimate(
self, mock_requests_get
):
future_date = (timezone.now().date() + timedelta(days=3650)).isoformat()
mocked_response = Mock()
mocked_response.raise_for_status.return_value = None
mocked_response.json.return_value = {"daily": {}}
mock_requests_get.return_value = mocked_response
archive_no_data = Mock()
archive_no_data.raise_for_status.return_value = None
archive_no_data.json.return_value = {"daily": {}}
forecast_no_data = Mock()
forecast_no_data.raise_for_status.return_value = None
forecast_no_data.json.return_value = {"daily": {}}
historical_data = Mock()
historical_data.raise_for_status.return_value = None
historical_data.json.return_value = {
"daily": {
"temperature_2m_max": [15.0, 18.0, 20.0],
"temperature_2m_min": [7.0, 9.0, 11.0],
}
}
call_sequence = [archive_no_data, forecast_no_data, historical_data]
def mock_get(*args, **kwargs):
if call_sequence:
return call_sequence.pop(0)
return historical_data
mock_requests_get.side_effect = mock_get
response = self.client.post(
"/api/weather/daily-temperatures/",
@@ -93,13 +124,17 @@ class WeatherViewTests(APITestCase):
)
self.assertEqual(response.status_code, 200)
self.assertEqual(
response.json()["results"][0],
{"date": future_date, "available": False, "temperature_c": None},
)
self.assertEqual(mock_requests_get.call_count, 2)
result = response.json()["results"][0]
self.assertTrue(result["available"])
self.assertEqual(result["date"], future_date)
self.assertEqual(result["temperature_low_c"], 9.0)
self.assertEqual(result["temperature_high_c"], 17.7)
self.assertEqual(result["temperature_c"], 13.3)
self.assertTrue(result["is_estimate"])
self.assertEqual(result["source"], "historical_estimate")
self.assertGreaterEqual(mock_requests_get.call_count, 3)
@patch("adventures.views.weather_view.requests.get")
@patch("adventures.utils.weather.requests.get")
def test_daily_temperatures_accepts_zero_lat_lon(self, mock_requests_get):
today = timezone.now().date().isoformat()
mocked_response = Mock()
@@ -121,9 +156,43 @@ class WeatherViewTests(APITestCase):
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json()["results"][0]["date"], today)
self.assertTrue(response.json()["results"][0]["available"])
self.assertEqual(response.json()["results"][0]["temperature_low_c"], 10.0)
self.assertEqual(response.json()["results"][0]["temperature_high_c"], 20.0)
self.assertFalse(response.json()["results"][0]["is_estimate"])
self.assertEqual(response.json()["results"][0]["source"], "archive")
self.assertEqual(response.json()["results"][0]["temperature_c"], 15.0)
class WeatherHelperTests(TestCase):
@patch("adventures.utils.weather.requests.get")
def test_fetch_daily_temperature_returns_unavailable_when_all_sources_fail(
self, mock_requests_get
):
mocked_response = Mock()
mocked_response.raise_for_status.return_value = None
mocked_response.json.return_value = {"daily": {}}
mock_requests_get.return_value = mocked_response
result = fetch_daily_temperature(
date=(timezone.now().date() + timedelta(days=6000)).isoformat(),
latitude=40.7128,
longitude=-74.0060,
)
self.assertEqual(
result,
{
"date": result["date"],
"available": False,
"temperature_low_c": None,
"temperature_high_c": None,
"temperature_c": None,
"is_estimate": False,
"source": None,
},
)
class MCPAuthTests(APITestCase):
def test_mcp_unauthenticated_access_is_rejected(self):
unauthenticated_client = APIClient()
@@ -131,6 +200,52 @@ class MCPAuthTests(APITestCase):
self.assertIn(response.status_code, [401, 403])
class LocationPayloadHardeningTests(APITestCase):
def setUp(self):
self.user = User.objects.create_user(
username="location-hardening-user",
email="location-hardening@example.com",
password="password123",
)
self.client.force_authenticate(user=self.user)
def test_create_location_truncates_overlong_name_and_location(self):
overlong_name = "N" * 250
overlong_location = "L" * 250
response = self.client.post(
"/api/locations/",
{
"name": overlong_name,
"location": overlong_location,
"is_public": False,
},
format="json",
)
self.assertEqual(response.status_code, 201)
self.assertEqual(len(response.data["name"]), 200)
self.assertEqual(len(response.data["location"]), 200)
self.assertEqual(response.data["name"], overlong_name[:200])
self.assertEqual(response.data["location"], overlong_location[:200])
def test_create_location_accepts_high_precision_coordinates(self):
response = self.client.post(
"/api/locations/",
{
"name": "Precision test",
"is_public": False,
"latitude": 51.5007292,
"longitude": -0.1246254,
},
format="json",
)
self.assertEqual(response.status_code, 201)
self.assertEqual(response.data["latitude"], "51.500729")
self.assertEqual(response.data["longitude"], "-0.124625")
class CollectionViewSetTests(APITestCase):
def setUp(self):
self.owner = User.objects.create_user(

View File

@@ -0,0 +1,172 @@
import logging
from datetime import date as date_cls
import requests
logger = logging.getLogger(__name__)
OPEN_METEO_ARCHIVE_URL = "https://archive-api.open-meteo.com/v1/archive"
OPEN_METEO_FORECAST_URL = "https://api.open-meteo.com/v1/forecast"
HISTORICAL_YEARS_BACK = 5
HISTORICAL_WINDOW_DAYS = 7
def _base_payload(date: str) -> dict:
return {
"date": date,
"available": False,
"temperature_low_c": None,
"temperature_high_c": None,
"temperature_c": None,
"is_estimate": False,
"source": None,
}
def _coerce_temperature(max_values, min_values):
if not max_values or not min_values:
return None
try:
low = float(min_values[0])
high = float(max_values[0])
except (TypeError, ValueError, IndexError):
return None
avg = (low + high) / 2
return {
"temperature_low_c": round(low, 1),
"temperature_high_c": round(high, 1),
"temperature_c": round(avg, 1),
}
def _request_daily_range(
url: str, latitude: float, longitude: float, start_date: str, end_date: str
):
try:
response = requests.get(
url,
params={
"latitude": latitude,
"longitude": longitude,
"start_date": start_date,
"end_date": end_date,
"daily": "temperature_2m_max,temperature_2m_min",
"timezone": "UTC",
},
timeout=8,
)
response.raise_for_status()
return response.json()
except requests.RequestException:
return None
except ValueError:
return None
def _fetch_direct_temperature(date: str, latitude: float, longitude: float):
for source, url in (
("archive", OPEN_METEO_ARCHIVE_URL),
("forecast", OPEN_METEO_FORECAST_URL),
):
data = _request_daily_range(url, latitude, longitude, date, date)
if not data:
continue
daily = data.get("daily") or {}
temperatures = _coerce_temperature(
daily.get("temperature_2m_max") or [],
daily.get("temperature_2m_min") or [],
)
if not temperatures:
continue
return {
**temperatures,
"available": True,
"is_estimate": False,
"source": source,
}
return None
def _fetch_historical_estimate(date: str, latitude: float, longitude: float):
try:
target_date = date_cls.fromisoformat(date)
except ValueError:
return None
all_max: list[float] = []
all_min: list[float] = []
for years_back in range(1, HISTORICAL_YEARS_BACK + 1):
year = target_date.year - years_back
try:
same_day = target_date.replace(year=year)
except ValueError:
# Leap-day fallback: use Feb 28 for non-leap years
same_day = target_date.replace(year=year, day=28)
start = same_day.fromordinal(same_day.toordinal() - HISTORICAL_WINDOW_DAYS)
end = same_day.fromordinal(same_day.toordinal() + HISTORICAL_WINDOW_DAYS)
data = _request_daily_range(
OPEN_METEO_ARCHIVE_URL,
latitude,
longitude,
start.isoformat(),
end.isoformat(),
)
if not data:
continue
daily = data.get("daily") or {}
max_values = daily.get("temperature_2m_max") or []
min_values = daily.get("temperature_2m_min") or []
pair_count = min(len(max_values), len(min_values))
for index in range(pair_count):
try:
all_max.append(float(max_values[index]))
all_min.append(float(min_values[index]))
except (TypeError, ValueError):
continue
if not all_max or not all_min:
return None
avg_max = sum(all_max) / len(all_max)
avg_min = sum(all_min) / len(all_min)
avg = (avg_max + avg_min) / 2
return {
"available": True,
"temperature_low_c": round(avg_min, 1),
"temperature_high_c": round(avg_max, 1),
"temperature_c": round(avg, 1),
"is_estimate": True,
"source": "historical_estimate",
}
def fetch_daily_temperature(date: str, latitude: float, longitude: float):
payload = _base_payload(date)
direct = _fetch_direct_temperature(date, latitude, longitude)
if direct:
return {**payload, **direct}
historical_estimate = _fetch_historical_estimate(date, latitude, longitude)
if historical_estimate:
return {**payload, **historical_estimate}
logger.info(
"No weather data available for date=%s lat=%s lon=%s",
date,
latitude,
longitude,
)
return payload

View File

@@ -1,22 +1,17 @@
import hashlib
import logging
from datetime import date as date_cls
import requests
from django.core.cache import cache
from rest_framework import status, viewsets
from rest_framework.decorators import action
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
logger = logging.getLogger(__name__)
from adventures.utils.weather import fetch_daily_temperature
class WeatherViewSet(viewsets.ViewSet):
permission_classes = [IsAuthenticated]
OPEN_METEO_ARCHIVE_URL = "https://archive-api.open-meteo.com/v1/archive"
OPEN_METEO_FORECAST_URL = "https://api.open-meteo.com/v1/forecast"
CACHE_TIMEOUT_SECONDS = 60 * 60 * 6
MAX_DAYS_PER_REQUEST = 60
@@ -39,7 +34,15 @@ class WeatherViewSet(viewsets.ViewSet):
for entry in days:
if not isinstance(entry, dict):
results.append(
{"date": None, "available": False, "temperature_c": None}
{
"date": None,
"available": False,
"temperature_low_c": None,
"temperature_high_c": None,
"temperature_c": None,
"is_estimate": False,
"source": None,
}
)
continue
@@ -49,14 +52,30 @@ class WeatherViewSet(viewsets.ViewSet):
if not date or latitude is None or longitude is None:
results.append(
{"date": date, "available": False, "temperature_c": None}
{
"date": date,
"available": False,
"temperature_low_c": None,
"temperature_high_c": None,
"temperature_c": None,
"is_estimate": False,
"source": None,
}
)
continue
parsed_date = self._parse_date(date)
if parsed_date is None:
results.append(
{"date": date, "available": False, "temperature_c": None}
{
"date": date,
"available": False,
"temperature_low_c": None,
"temperature_high_c": None,
"temperature_c": None,
"is_estimate": False,
"source": None,
}
)
continue
@@ -65,7 +84,15 @@ class WeatherViewSet(viewsets.ViewSet):
lon = float(longitude)
except (TypeError, ValueError):
results.append(
{"date": date, "available": False, "temperature_c": None}
{
"date": date,
"available": False,
"temperature_low_c": None,
"temperature_high_c": None,
"temperature_c": None,
"is_estimate": False,
"source": None,
}
)
continue
@@ -82,57 +109,9 @@ class WeatherViewSet(viewsets.ViewSet):
return Response({"results": results}, status=status.HTTP_200_OK)
def _fetch_daily_temperature(self, date: str, latitude: float, longitude: float):
base_payload = {
"date": date,
"available": False,
"temperature_c": None,
}
for url in (self.OPEN_METEO_ARCHIVE_URL, self.OPEN_METEO_FORECAST_URL):
try:
response = requests.get(
url,
params={
"latitude": latitude,
"longitude": longitude,
"start_date": date,
"end_date": date,
"daily": "temperature_2m_max,temperature_2m_min",
"timezone": "UTC",
},
timeout=8,
)
response.raise_for_status()
data = response.json()
except requests.RequestException:
continue
except ValueError:
continue
daily = data.get("daily") or {}
max_values = daily.get("temperature_2m_max") or []
min_values = daily.get("temperature_2m_min") or []
if not max_values or not min_values:
continue
try:
avg = (float(max_values[0]) + float(min_values[0])) / 2
except (TypeError, ValueError, IndexError):
continue
return {
"date": date,
"available": True,
"temperature_c": round(avg, 1),
}
logger.info(
"No weather data available for date=%s lat=%s lon=%s",
date,
latitude,
longitude,
return fetch_daily_temperature(
date=date, latitude=latitude, longitude=longitude
)
return base_payload
def _cache_key(self, date: str, latitude: float, longitude: float) -> str:
rounded_lat = round(latitude, 3)

File diff suppressed because it is too large Load Diff

View File

@@ -14,7 +14,9 @@ DEFAULT_SYSTEM_PROMPT = """You are a helpful travel planning assistant for the V
Your capabilities:
- Search for interesting places (restaurants, tourist attractions, hotels) near any location
- View and manage the user's trip collections and itineraries
- Add new locations to trip itineraries
- Add, move, and remove itinerary items
- Update itinerary location details
- Add/manage lodging and transportation entries in the trip
- Check weather/temperature data for travel dates
When suggesting places:
@@ -23,8 +25,8 @@ When suggesting places:
- Group suggestions logically (by area, by type, by day)
When modifying itineraries:
- Confirm with the user before the first add_to_itinerary action in a conversation
- After the user clearly approves adding items (for example: "yes", "go ahead", "add them", "just add things there"), stop re-confirming and call add_to_itinerary directly for subsequent additions in that conversation
- Confirm with the user before the first mutating itinerary action in a conversation (add, move, remove, or update).
- After the user clearly approves itinerary changes (for example: "yes", "go ahead", "add them", "just add things there"), stop re-confirming and proceed directly for subsequent itinerary changes in that conversation.
- Suggest logical ordering based on geography
- Consider travel time between locations

View File

@@ -1,22 +1,34 @@
import json
from datetime import date
from unittest import mock
from unittest.mock import MagicMock, patch
import requests as _requests
from django.contrib.auth import get_user_model
from django.contrib.contenttypes.models import ContentType
from django.test import TestCase
from rest_framework.test import APITransactionTestCase
from adventures.models import Collection, CollectionItineraryItem, Location
from adventures.models import Collection, CollectionItineraryItem, Location, Visit
from chat.agent_tools import (
add_to_itinerary,
execute_tool,
add_lodging,
add_transportation,
get_trip_details,
move_itinerary_item,
remove_itinerary_item,
remove_lodging,
remove_transportation,
search_places,
update_location_details,
update_lodging,
update_transportation,
web_search,
)
from chat.views import ChatViewSet
from chat.views.day_suggestions import DaySuggestionsView
User = get_user_model()
@@ -245,6 +257,209 @@ class ChatAgentToolSharedTripAccessTests(TestCase):
self.assertEqual(itinerary_result, {"error": "Trip not found"})
class ChatAgentToolItineraryManagementTests(TestCase):
def setUp(self):
self.owner = User.objects.create_user(
username="chat-itinerary-owner",
email="chat-itinerary-owner@example.com",
password="password123",
)
self.shared_user = User.objects.create_user(
username="chat-itinerary-shared",
email="chat-itinerary-shared@example.com",
password="password123",
)
self.non_member = User.objects.create_user(
username="chat-itinerary-non-member",
email="chat-itinerary-non-member@example.com",
password="password123",
)
self.collection = Collection.objects.create(
user=self.owner,
name="Assistant Managed Trip",
)
self.collection.shared_with.add(self.shared_user)
self.location = Location.objects.create(
user=self.owner,
name="Existing Stop",
latitude=48.8566,
longitude=2.3522,
)
self.collection.locations.add(self.location)
self.location_content_type = ContentType.objects.get_for_model(Location)
self.day1_item = CollectionItineraryItem.objects.create(
collection=self.collection,
content_type=self.location_content_type,
object_id=self.location.id,
date="2026-06-01",
order=0,
is_global=False,
)
def test_move_itinerary_item_allows_shared_user(self):
result = move_itinerary_item(
self.shared_user,
collection_id=str(self.collection.id),
itinerary_item_id=str(self.day1_item.id),
date="2026-06-02",
order=0,
)
self.assertTrue(result.get("success"))
self.day1_item.refresh_from_db()
self.assertEqual(self.day1_item.date.isoformat(), "2026-06-02")
self.assertEqual(self.day1_item.order, 0)
def test_remove_itinerary_item_removes_matching_visit_for_locations(self):
Visit.objects.create(
location=self.location,
start_date="2026-06-01T10:00:00Z",
end_date="2026-06-01T12:00:00Z",
)
result = remove_itinerary_item(
self.owner,
collection_id=str(self.collection.id),
itinerary_item_id=str(self.day1_item.id),
)
self.assertTrue(result.get("success"))
self.assertEqual(result.get("deleted_visit_count"), 1)
self.assertFalse(
CollectionItineraryItem.objects.filter(id=self.day1_item.id).exists()
)
self.assertEqual(Visit.objects.filter(location=self.location).count(), 0)
def test_update_location_details_scoped_to_collection(self):
outsider_location = Location.objects.create(
user=self.owner,
name="Outside",
latitude=1.0,
longitude=1.0,
)
denied = update_location_details(
self.shared_user,
collection_id=str(self.collection.id),
location_id=str(outsider_location.id),
name="Should fail",
)
self.assertEqual(denied, {"error": "Location not found in this trip"})
allowed = update_location_details(
self.shared_user,
collection_id=str(self.collection.id),
location_id=str(self.location.id),
name="Updated Stop",
latitude=40.7128,
longitude=-74.0060,
)
self.assertTrue(allowed.get("success"))
self.location.refresh_from_db()
self.assertEqual(self.location.name, "Updated Stop")
self.assertAlmostEqual(float(self.location.latitude), 40.7128, places=4)
self.assertAlmostEqual(float(self.location.longitude), -74.0060, places=4)
def test_lodging_management_tools_allow_shared_user(self):
created = add_lodging(
self.shared_user,
collection_id=str(self.collection.id),
name="River Hotel",
type="hotel",
location="Paris",
check_in="2026-06-02T15:00:00Z",
check_out="2026-06-04T11:00:00Z",
latitude=48.85,
longitude=2.35,
itinerary_date="2026-06-02",
)
self.assertTrue(created.get("success"))
lodging_id = created["lodging"]["id"]
self.assertIsNotNone(created.get("itinerary_item"))
updated = update_lodging(
self.shared_user,
collection_id=str(self.collection.id),
lodging_id=lodging_id,
name="River Hotel Updated",
location="Paris Center",
)
self.assertTrue(updated.get("success"))
self.assertEqual(updated["lodging"]["name"], "River Hotel Updated")
removed = remove_lodging(
self.shared_user,
collection_id=str(self.collection.id),
lodging_id=lodging_id,
)
self.assertTrue(removed.get("success"))
self.assertGreaterEqual(removed.get("removed_itinerary_items", 0), 1)
def test_transportation_management_tools_allow_shared_user(self):
created = add_transportation(
self.shared_user,
collection_id=str(self.collection.id),
name="Train to Lyon",
type="train",
date="2026-06-03T09:00:00Z",
end_date="2026-06-03T11:00:00Z",
from_location="Paris",
to_location="Lyon",
origin_latitude=48.8566,
origin_longitude=2.3522,
destination_latitude=45.7640,
destination_longitude=4.8357,
itinerary_date="2026-06-03",
)
self.assertTrue(created.get("success"))
transportation_id = created["transportation"]["id"]
updated = update_transportation(
self.shared_user,
collection_id=str(self.collection.id),
transportation_id=transportation_id,
to_location="Lyon Part-Dieu",
)
self.assertTrue(updated.get("success"))
self.assertEqual(updated["transportation"]["to_location"], "Lyon Part-Dieu")
removed = remove_transportation(
self.shared_user,
collection_id=str(self.collection.id),
transportation_id=transportation_id,
)
self.assertTrue(removed.get("success"))
self.assertGreaterEqual(removed.get("removed_itinerary_items", 0), 1)
def test_management_tools_deny_non_member(self):
move_result = move_itinerary_item(
self.non_member,
collection_id=str(self.collection.id),
itinerary_item_id=str(self.day1_item.id),
date="2026-06-02",
)
remove_result = remove_itinerary_item(
self.non_member,
collection_id=str(self.collection.id),
itinerary_item_id=str(self.day1_item.id),
)
update_result = update_location_details(
self.non_member,
collection_id=str(self.collection.id),
location_id=str(self.location.id),
name="No access",
)
self.assertEqual(move_result, {"error": "Trip not found"})
self.assertEqual(remove_result, {"error": "Trip not found"})
self.assertEqual(update_result, {"error": "Trip not found"})
class ChatViewSetToolValidationBoundaryTests(TestCase):
def test_trip_context_destination_summary_normalizes_to_first_segment(self):
self.assertEqual(
@@ -1286,3 +1501,337 @@ class GetWeatherCoordFallbackTests(APITransactionTestCase):
),
"Should emit tool_validation_error when no collection coords available",
)
@patch("chat.views.execute_tool")
@patch("chat.views.stream_chat_completion")
@patch("integrations.utils.auto_profile.update_auto_preference_profile")
def test_get_weather_missing_coords_and_dates_retries_with_collection_dates(
self,
_mock_auto_profile,
mock_stream_chat_completion,
mock_execute_tool,
):
user = User.objects.create_user(
username="weather-coord-dates-user",
email="weather-coord-dates-user@example.com",
password="password123",
)
self.client.force_authenticate(user=user)
collection = Collection.objects.create(
user_id=user.id,
name="Vienna Trip",
start_date=date(2026, 6, 10),
end_date=date(2026, 6, 12),
)
vienna_location = Location.objects.create(
user_id=user.id,
name="Vienna",
latitude=48.2082,
longitude=16.3738,
)
collection.locations.add(vienna_location)
conversation_response = self.client.post(
"/api/chat/conversations/",
{"title": "Weather Coord+Dates Fallback Test"},
format="json",
)
self.assertEqual(conversation_response.status_code, 201)
conversation_id = conversation_response.json()["id"]
async def weather_stream(*args, **kwargs):
yield 'data: {"tool_calls": [{"index": 0, "id": "call_w3", "type": "function", "function": {"name": "get_weather", "arguments": "{}"}}]}\n\n'
yield "data: [DONE]\n\n"
async def success_stream(*args, **kwargs):
yield 'data: {"content": "Vienna forecast loaded."}\n\n'
yield "data: [DONE]\n\n"
mock_stream_chat_completion.side_effect = [weather_stream(), success_stream()]
mock_execute_tool.side_effect = [
{"error": "latitude and longitude are required"},
{
"location": "Vienna",
"forecast": [
{"date": "2026-06-10", "temperature": 24, "condition": "sunny"}
],
},
]
response = self.client.post(
f"/api/chat/conversations/{conversation_id}/send_message/",
{
"message": "What's the weather for my trip?",
"collection_id": str(collection.id),
},
format="json",
)
self.assertEqual(response.status_code, 200)
chunks = [
chunk.decode("utf-8")
if isinstance(chunk, (bytes, bytearray))
else str(chunk)
for chunk in response.streaming_content
]
payload_lines = [
chunk.strip()[len("data: ") :]
for chunk in chunks
if chunk.strip().startswith("data: ")
]
json_payloads = [json.loads(p) for p in payload_lines if p != "[DONE]"]
self.assertEqual(mock_execute_tool.call_count, 2)
retry_kwargs = mock_execute_tool.call_args_list[1][1]
self.assertAlmostEqual(retry_kwargs.get("latitude"), 48.2082, places=3)
self.assertAlmostEqual(retry_kwargs.get("longitude"), 16.3738, places=3)
self.assertEqual(
retry_kwargs.get("dates"),
["2026-06-10", "2026-06-11", "2026-06-12"],
)
self.assertFalse(
any(
payload.get("error_category") == "tool_execution_error"
for payload in json_payloads
)
)
@patch("chat.views.execute_tool")
@patch("chat.views.stream_chat_completion")
@patch("integrations.utils.auto_profile.update_auto_preference_profile")
def test_get_weather_missing_dates_only_remains_validation_error(
self,
_mock_auto_profile,
mock_stream_chat_completion,
mock_execute_tool,
):
user = User.objects.create_user(
username="weather-missing-dates-user",
email="weather-missing-dates-user@example.com",
password="password123",
)
self.client.force_authenticate(user=user)
collection = Collection.objects.create(
user_id=user.id,
name="Berlin Trip",
)
berlin_location = Location.objects.create(
user_id=user.id,
name="Berlin",
latitude=52.52,
longitude=13.405,
)
collection.locations.add(berlin_location)
conversation_response = self.client.post(
"/api/chat/conversations/",
{"title": "Weather Missing Dates Validation Test"},
format="json",
)
self.assertEqual(conversation_response.status_code, 201)
conversation_id = conversation_response.json()["id"]
async def weather_stream(*args, **kwargs):
yield (
'data: {"tool_calls": [{"index": 0, "id": "call_w4", "type": '
'"function", "function": {"name": "get_weather", "arguments": '
'"{\\"latitude\\":52.52,\\"longitude\\":13.405}"}}]}\n\n'
)
yield "data: [DONE]\n\n"
mock_stream_chat_completion.side_effect = weather_stream
mock_execute_tool.return_value = {"error": "dates is required"}
response = self.client.post(
f"/api/chat/conversations/{conversation_id}/send_message/",
{
"message": "What's the weather there?",
"collection_id": str(collection.id),
},
format="json",
)
self.assertEqual(response.status_code, 200)
payload_lines = [
(
chunk.decode("utf-8")
if isinstance(chunk, (bytes, bytearray))
else str(chunk)
).strip()[len("data: ") :]
for chunk in response.streaming_content
if (
(
chunk.decode("utf-8")
if isinstance(chunk, (bytes, bytearray))
else str(chunk)
)
.strip()
.startswith("data: ")
)
]
json_payloads = [json.loads(p) for p in payload_lines if p != "[DONE]"]
self.assertEqual(mock_execute_tool.call_count, 1)
self.assertTrue(
any(
payload.get("error_category") == "tool_validation_error"
for payload in json_payloads
)
)
self.assertFalse(
any(
payload.get("error_category") == "tool_execution_error"
for payload in json_payloads
)
)
self.assertFalse(
any(
"Could not fetch weather for the collection locations"
in payload.get("error", "")
for payload in json_payloads
)
)
class DaySuggestionsCoordinateEnrichmentTests(TestCase):
def setUp(self):
self.view = DaySuggestionsView()
def test_enriches_suggestion_with_coordinates_from_place_context(self):
suggestions = [
{
"name": "Roscioli",
"location": "Via dei Giubbonari, Rome",
"description": "Classic Roman spot",
}
]
place_candidates = [
{
"name": "Roscioli",
"address": "Via dei Giubbonari, Rome",
"latitude": 41.8933,
"longitude": 12.4722,
}
]
enriched = self.view._enrich_suggestions_with_coordinates(
suggestions,
place_candidates,
)
self.assertEqual(len(enriched), 1)
self.assertEqual(enriched[0]["latitude"], 41.8933)
self.assertEqual(enriched[0]["longitude"], 12.4722)
def test_preserves_existing_suggestion_coordinates(self):
suggestions = [
{
"name": "Known Place",
"location": "Somewhere",
"latitude": 10.5,
"longitude": 20.5,
}
]
place_candidates = [
{
"name": "Known Place",
"address": "Somewhere",
"latitude": 1.0,
"longitude": 2.0,
}
]
enriched = self.view._enrich_suggestions_with_coordinates(
suggestions,
place_candidates,
)
self.assertEqual(enriched[0]["latitude"], 10.5)
self.assertEqual(enriched[0]["longitude"], 20.5)
def test_enriches_coordinates_with_token_based_name_matching(self):
suggestions = [
{
"name": "Borough food market",
"location": "South Bank",
"description": "Popular food destination",
}
]
place_candidates = [
{
"name": "Borough Market",
"address": "8 Southwark St, London SE1 1TL",
"latitude": 51.5055,
"longitude": -0.0904,
}
]
enriched = self.view._enrich_suggestions_with_coordinates(
suggestions,
place_candidates,
)
self.assertEqual(len(enriched), 1)
self.assertEqual(enriched[0]["latitude"], 51.5055)
self.assertEqual(enriched[0]["longitude"], -0.0904)
def test_falls_back_to_coordinate_match_when_best_text_match_has_no_coordinates(
self,
):
suggestions = [
{
"name": "Sunset Bar",
"location": "Pier 7",
"description": "Cocktail spot",
}
]
place_candidates = [
{
"name": "Sunset Bar",
"address": "Pier 7",
"latitude": None,
"longitude": None,
},
{
"name": "Harbor Walk",
"address": "Pier 7, Lisbon",
"latitude": 38.7072,
"longitude": -9.1366,
},
]
enriched = self.view._enrich_suggestions_with_coordinates(
suggestions,
place_candidates,
)
self.assertEqual(len(enriched), 1)
self.assertEqual(enriched[0]["latitude"], 38.7072)
self.assertEqual(enriched[0]["longitude"], -9.1366)
def test_does_not_inject_null_coordinates_when_no_coordinate_match_exists(self):
suggestions = [
{
"name": "Skyline View",
"location": "Hilltop",
}
]
place_candidates = [
{
"name": "Skyline View",
"address": "Hilltop",
"latitude": None,
"longitude": None,
}
]
enriched = self.view._enrich_suggestions_with_coordinates(
suggestions,
place_candidates,
)
self.assertEqual(len(enriched), 1)
self.assertNotIn("latitude", enriched[0])
self.assertNotIn("longitude", enriched[0])

View File

@@ -2,10 +2,12 @@ import asyncio
import json
import logging
import re
from datetime import timedelta
from asgiref.sync import sync_to_async
from adventures.models import Collection
from django.http import StreamingHttpResponse
from django.utils import timezone
from integrations.models import UserAISettings
from rest_framework import status, viewsets
from rest_framework.decorators import action
@@ -276,6 +278,33 @@ class ChatViewSet(viewsets.ModelViewSet):
continue
return None
@staticmethod
def _derive_weather_dates_from_collection(collection, max_days=7):
"""Derive a bounded weather date list from collection dates, or fallback to today."""
today = timezone.localdate()
if collection is None:
return [today.isoformat()]
start_date = getattr(collection, "start_date", None)
end_date = getattr(collection, "end_date", None)
if start_date and end_date:
range_start = min(start_date, end_date)
range_end = max(start_date, end_date)
day_count = min((range_end - range_start).days + 1, max_days)
return [
(range_start + timedelta(days=offset)).isoformat()
for offset in range(day_count)
]
if start_date:
return [start_date.isoformat()]
if end_date:
return [end_date.isoformat()]
return [today.isoformat()]
@staticmethod
def _build_search_places_location_clarification_message():
return (
@@ -744,6 +773,12 @@ class ChatViewSet(viewsets.ModelViewSet):
retry_arguments = dict(prepared_arguments)
retry_arguments["latitude"] = retry_lat
retry_arguments["longitude"] = retry_lon
if not retry_arguments.get("dates"):
retry_arguments["dates"] = (
self._derive_weather_dates_from_collection(
collection
)
)
attempted_weather_coord_retry = True
retry_result = await sync_to_async(
execute_tool,
@@ -774,6 +809,10 @@ class ChatViewSet(viewsets.ModelViewSet):
if (
attempted_weather_coord_retry
and self._is_required_param_tool_error(result)
and self._is_get_weather_missing_latlong_error(
function_name,
result,
)
):
result = {
"error": "Could not fetch weather for the collection locations"

View File

@@ -72,7 +72,12 @@ class DaySuggestionsView(APIView):
)
try:
places_context = self._get_places_context(request.user, category, location)
place_candidates = self._fetch_place_candidates(
request.user,
category,
location,
)
places_context = self._build_places_context(place_candidates)
prompt = self._build_prompt(
category=category,
filters=filters,
@@ -89,17 +94,30 @@ class DaySuggestionsView(APIView):
provider=provider,
model=model,
)
suggestions = self._enrich_suggestions_with_coordinates(
suggestions,
place_candidates,
)
return Response({"suggestions": suggestions}, status=status.HTTP_200_OK)
except Exception as exc:
logger.exception("Failed to generate day suggestions")
payload = _safe_error_payload(exc)
status_code = {
error_category = (
payload.get("error_category") if isinstance(payload, dict) else None
)
status_code_map = {
"model_not_found": status.HTTP_400_BAD_REQUEST,
"authentication_failed": status.HTTP_401_UNAUTHORIZED,
"rate_limited": status.HTTP_429_TOO_MANY_REQUESTS,
"invalid_request": status.HTTP_400_BAD_REQUEST,
"provider_unreachable": status.HTTP_503_SERVICE_UNAVAILABLE,
}.get(payload.get("error_category"), status.HTTP_500_INTERNAL_SERVER_ERROR)
}
status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
if isinstance(error_category, str):
status_code = status_code_map.get(
error_category,
status.HTTP_500_INTERNAL_SERVER_ERROR,
)
return Response(
payload,
status=status_code,
@@ -176,11 +194,12 @@ class DaySuggestionsView(APIView):
prompt += (
" Return 3-5 specific suggestions as a JSON array."
" Each suggestion should have: name, description, why_fits, category, location, rating, price_level."
" Include latitude and longitude when known from nearby-place context."
" Return ONLY valid JSON, no markdown, no surrounding text."
)
return prompt
def _get_places_context(self, user, category, location):
def _fetch_place_candidates(self, user, category, location):
tool_category_map = {
"restaurant": "food",
"activity": "tourism",
@@ -194,24 +213,190 @@ class DaySuggestionsView(APIView):
radius=8,
)
if not isinstance(result, dict):
return ""
return []
if result.get("error"):
return ""
return []
raw_results = result.get("results")
if not isinstance(raw_results, list):
return []
return [entry for entry in raw_results if isinstance(entry, dict)]
def _build_places_context(self, place_candidates):
if not isinstance(place_candidates, list):
return ""
entries = []
for place in raw_results[:5]:
if not isinstance(place, dict):
continue
for place in place_candidates[:5]:
name = place.get("name")
address = place.get("address") or ""
if name:
entries.append(f"{name} ({address})" if address else name)
latitude = place.get("latitude")
longitude = place.get("longitude")
if not name:
continue
details = [name]
if address:
details.append(address)
if latitude is not None and longitude is not None:
details.append(f"lat={latitude}")
details.append(f"lon={longitude}")
entries.append(" | ".join(details))
return "; ".join(entries)
def _tokenize_text(self, value):
normalized = self._normalize_text(value)
if not normalized:
return set()
return set(re.findall(r"[a-z0-9]+", normalized))
def _normalize_text(self, value):
if not isinstance(value, str):
return ""
return value.strip().lower()
def _extract_suggestion_identity(self, suggestion):
if not isinstance(suggestion, dict):
return "", ""
name = self._normalize_text(
suggestion.get("name")
or suggestion.get("title")
or suggestion.get("place_name")
or suggestion.get("venue")
)
location_text = self._normalize_text(
suggestion.get("location")
or suggestion.get("address")
or suggestion.get("neighborhood")
)
return name, location_text
def _best_place_match(self, suggestion, place_candidates):
suggestion_name, suggestion_location = self._extract_suggestion_identity(
suggestion
)
if not suggestion_name and not suggestion_location:
return None
suggestion_name_tokens = self._tokenize_text(suggestion_name)
suggestion_location_tokens = self._tokenize_text(suggestion_location)
def has_coordinates(candidate):
return (
candidate.get("latitude") is not None
and candidate.get("longitude") is not None
)
best_candidate = None
best_score = -1
best_coordinate_candidate = None
best_coordinate_score = -1
for candidate in place_candidates:
candidate_name = self._normalize_text(candidate.get("name"))
candidate_address = self._normalize_text(candidate.get("address"))
candidate_name_tokens = self._tokenize_text(candidate_name)
candidate_address_tokens = self._tokenize_text(candidate_address)
score = 0
if suggestion_name and candidate_name:
if suggestion_name == candidate_name:
score += 4
elif (
suggestion_name in candidate_name
or candidate_name in suggestion_name
):
score += 2
shared_name_tokens = suggestion_name_tokens & candidate_name_tokens
if len(shared_name_tokens) >= 2:
score += 3
elif len(shared_name_tokens) == 1:
score += 1
if suggestion_location and candidate_address:
if suggestion_location == candidate_address:
score += 2
elif (
suggestion_location in candidate_address
or candidate_address in suggestion_location
):
score += 1
shared_location_tokens = (
suggestion_location_tokens & candidate_address_tokens
)
if len(shared_location_tokens) >= 2:
score += 2
elif len(shared_location_tokens) == 1:
score += 1
if score > best_score:
best_score = score
best_candidate = candidate
elif (
score == best_score
and best_candidate is not None
and not has_coordinates(best_candidate)
and has_coordinates(candidate)
):
best_candidate = candidate
if has_coordinates(candidate) and score > best_coordinate_score:
best_coordinate_score = score
best_coordinate_candidate = candidate
if best_score <= 0:
return None
if has_coordinates(best_candidate):
return best_candidate
# Bounded fallback: if the strongest text match has no coordinates,
# accept the best coordinate-bearing candidate only with a
# reasonably strong lexical overlap score.
if best_coordinate_score >= 2:
return best_coordinate_candidate
return best_candidate
def _enrich_suggestions_with_coordinates(self, suggestions, place_candidates):
if not isinstance(suggestions, list) or not isinstance(place_candidates, list):
return suggestions
enriched = []
for suggestion in suggestions:
if not isinstance(suggestion, dict):
continue
if (
suggestion.get("latitude") is not None
and suggestion.get("longitude") is not None
):
enriched.append(suggestion)
continue
matched_place = self._best_place_match(suggestion, place_candidates)
if not matched_place:
enriched.append(suggestion)
continue
if (
matched_place.get("latitude") is None
or matched_place.get("longitude") is None
):
enriched.append(suggestion)
continue
merged = dict(suggestion)
merged["latitude"] = matched_place.get("latitude")
merged["longitude"] = matched_place.get("longitude")
merged["location"] = merged.get("location") or matched_place.get("address")
enriched.append(merged)
return enriched
def _resolve_provider_and_model(self, request):
request_provider = (request.data.get("provider") or "").strip().lower() or None
request_model = (request.data.get("model") or "").strip() or None
@@ -262,7 +447,7 @@ class DaySuggestionsView(APIView):
if not api_key:
raise ValueError("No API key available")
provider_config = CHAT_PROVIDER_CONFIG.get(provider, {})
provider_config = CHAT_PROVIDER_CONFIG.get(provider or "", {})
resolved_model = normalize_gateway_model(
provider,
model or provider_config.get("default_model"),