From 9d5681b1ef62b756b968ca19b0b563423c37f4b0 Mon Sep 17 00:00:00 2001 From: alex Date: Sun, 8 Mar 2026 23:53:14 +0000 Subject: [PATCH] feat(ai): implement agent-redesign plan with enhanced AI travel features MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 1 - Configuration Infrastructure (WS1): - Add instance-level AI env vars (VOYAGE_AI_PROVIDER, VOYAGE_AI_MODEL, VOYAGE_AI_API_KEY) - Implement fallback chain: user key β†’ instance key β†’ error - Add UserAISettings model for per-user provider/model preferences - Enhance provider catalog with instance_configured and user_configured flags - Optimize provider catalog to avoid N+1 queries Phase 1 - User Preference Learning (WS2): - Add Travel Preferences tab to Settings page - Improve preference formatting in system prompt with emoji headers - Add multi-user preference aggregation for shared collections Phase 2 - Day-Level Suggestions Modal (WS3): - Create ItinerarySuggestionModal with 3-step flow (category β†’ filters β†’ results) - Add AI suggestions button to itinerary Add dropdown - Support restaurant, activity, event, and lodging categories - Backend endpoint POST /api/chat/suggestions/day/ with context-aware prompts Phase 3 - Collection-Level Chat Improvements (WS4): - Inject collection context (destination, dates) into chat system prompt - Add quick action buttons for common queries - Add 'Add to itinerary' button on search_places results - Update chat UI with travel-themed branding and improved tool result cards Phase 3 - Web Search Capability (WS5): - Add web_search agent tool using DuckDuckGo - Support location_context parameter for biased results - Handle rate limiting gracefully Phase 4 - Extensibility Architecture (WS6): - Implement decorator-based @agent_tool registry - Convert existing tools to use decorators - Add GET /api/chat/capabilities/ endpoint for tool discovery - Refactor execute_tool() to use registry pattern --- backend/server/chat/agent_tools.py | 401 +++++++++------ backend/server/chat/llm_client.py | 225 +++++++- backend/server/chat/urls.py | 9 +- backend/server/chat/views/__init__.py | 328 ++++++++++++ backend/server/chat/views/capabilities.py | 24 + backend/server/chat/views/day_suggestions.py | 215 ++++++++ .../migrations/0008_useraisettings.py | 52 ++ backend/server/integrations/models.py | 20 + backend/server/integrations/serializers.py | 14 + backend/server/integrations/urls.py | 2 + backend/server/integrations/views/__init__.py | 1 + .../integrations/views/ai_settings_view.py | 39 ++ backend/server/main/settings.py | 5 + backend/server/requirements.txt | 1 + .../src/lib/components/AITravelChat.svelte | 486 ++++++++++++++++-- .../CollectionItineraryPlanner.svelte | 76 ++- .../ItinerarySuggestionModal.svelte | 442 ++++++++++++++++ frontend/src/lib/types.ts | 10 + frontend/src/locales/en.json | 50 +- .../src/routes/collections/[id]/+page.svelte | 32 +- frontend/src/routes/settings/+page.server.ts | 27 +- frontend/src/routes/settings/+page.svelte | 154 +++++- 22 files changed, 2358 insertions(+), 255 deletions(-) create mode 100644 backend/server/chat/views/__init__.py create mode 100644 backend/server/chat/views/capabilities.py create mode 100644 backend/server/chat/views/day_suggestions.py create mode 100644 backend/server/integrations/migrations/0008_useraisettings.py create mode 100644 backend/server/integrations/views/ai_settings_view.py create mode 100644 frontend/src/lib/components/collections/ItinerarySuggestionModal.svelte diff --git a/backend/server/chat/agent_tools.py b/backend/server/chat/agent_tools.py index 8dbbeb24..0767e391 100644 --- a/backend/server/chat/agent_tools.py +++ b/backend/server/chat/agent_tools.py @@ -1,4 +1,5 @@ import json +import inspect import logging from datetime import date as date_cls @@ -10,117 +11,50 @@ from adventures.models import Collection, CollectionItineraryItem, Location logger = logging.getLogger(__name__) -AGENT_TOOLS = [ - { - "type": "function", - "function": { - "name": "search_places", - "description": "Search for places of interest near a location. Returns tourist attractions, restaurants, hotels, etc.", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "Location name or address to search near", - }, - "category": { - "type": "string", - "enum": ["tourism", "food", "lodging"], - "description": "Category of places", - }, - "radius": { - "type": "number", - "description": "Search radius in km (default 10)", - }, +_REGISTERED_TOOLS = {} +_TOOL_SCHEMAS = [] + + +def agent_tool(name: str, description: str, parameters: dict): + """Decorator to register a function as an agent tool.""" + + def decorator(func): + _REGISTERED_TOOLS[name] = func + + required = [k for k, v in parameters.items() if v.get("required", False)] + props = { + k: {kk: vv for kk, vv in v.items() if kk != "required"} + for k, v in parameters.items() + } + + schema = { + "type": "function", + "function": { + "name": name, + "description": description, + "parameters": { + "type": "object", + "properties": props, + "required": required, }, - "required": ["location"], }, - }, - }, - { - "type": "function", - "function": { - "name": "list_trips", - "description": "List the user's trip collections with dates and descriptions", - "parameters": {"type": "object", "properties": {}}, - }, - }, - { - "type": "function", - "function": { - "name": "get_trip_details", - "description": "Get full details of a trip including all itinerary items, locations, transportation, and lodging", - "parameters": { - "type": "object", - "properties": { - "collection_id": { - "type": "string", - "description": "UUID of the collection/trip", - } - }, - "required": ["collection_id"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "add_to_itinerary", - "description": "Add a new location to a trip's itinerary on a specific date", - "parameters": { - "type": "object", - "properties": { - "collection_id": { - "type": "string", - "description": "UUID of the collection/trip", - }, - "name": {"type": "string", "description": "Name of the location"}, - "description": { - "type": "string", - "description": "Description of why to visit", - }, - "latitude": { - "type": "number", - "description": "Latitude coordinate", - }, - "longitude": { - "type": "number", - "description": "Longitude coordinate", - }, - "date": { - "type": "string", - "description": "Date in YYYY-MM-DD format", - }, - "location_address": { - "type": "string", - "description": "Full address of the location", - }, - }, - "required": ["collection_id", "name", "latitude", "longitude"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "get_weather", - "description": "Get temperature/weather data for a location on specific dates", - "parameters": { - "type": "object", - "properties": { - "latitude": {"type": "number", "description": "Latitude"}, - "longitude": {"type": "number", "description": "Longitude"}, - "dates": { - "type": "array", - "items": {"type": "string"}, - "description": "List of dates in YYYY-MM-DD format", - }, - }, - "required": ["latitude", "longitude", "dates"], - }, - }, - }, -] + } + _TOOL_SCHEMAS.append(schema) + + return func + + return decorator + + +def get_tool_schemas() -> list: + """Return all registered tool schemas for LLM.""" + return _TOOL_SCHEMAS.copy() + + +def get_registered_tools() -> dict: + """Return all registered tool functions.""" + return _REGISTERED_TOOLS.copy() + NOMINATIM_URL = "https://nominatim.openstreetmap.org/search" OVERPASS_URL = "https://overpass-api.de/api/interpreter" @@ -162,14 +96,39 @@ def _parse_address(tags): return ", ".join([p for p in parts if p]) -def search_places(user, **kwargs): +@agent_tool( + name="search_places", + description="Search for places of interest near a location. Returns tourist attractions, restaurants, hotels, etc.", + parameters={ + "location": { + "type": "string", + "description": "Location name or address to search near", + "required": True, + }, + "category": { + "type": "string", + "enum": ["tourism", "food", "lodging"], + "description": "Category of places", + }, + "radius": { + "type": "number", + "description": "Search radius in km (default 10)", + }, + }, +) +def search_places( + user, + location: str | None = None, + category: str = "tourism", + radius: float = 10, +): try: - location_name = kwargs.get("location") + location_name = location if not location_name: return {"error": "location is required"} - category = kwargs.get("category") or "tourism" - radius_km = float(kwargs.get("radius") or 10) + category = category or "tourism" + radius_km = float(radius or 10) radius_meters = max(500, min(int(radius_km * 1000), 50000)) geocode_resp = requests.get( @@ -240,7 +199,12 @@ def search_places(user, **kwargs): return {"error": "An unexpected error occurred during place search"} -def list_trips(user, **kwargs): +@agent_tool( + name="list_trips", + description="List the user's trip collections with dates and descriptions", + parameters={}, +) +def list_trips(user): try: collections = Collection.objects.filter(user=user).prefetch_related("locations") trips = [] @@ -265,9 +229,87 @@ def list_trips(user, **kwargs): return {"error": "An unexpected error occurred while listing trips"} -def get_trip_details(user, **kwargs): +@agent_tool( + name="web_search", + description="Search the web for current information about destinations, events, prices, weather, or any real-time travel information. Use this when you need up-to-date information that may not be in your training data.", + parameters={ + "query": { + "type": "string", + "description": "The search query (e.g., 'best restaurants Paris 2024', 'weather Tokyo March')", + "required": True, + }, + "location_context": { + "type": "string", + "description": "Optional location to bias search results (e.g., 'Paris, France')", + }, + }, +) +def web_search(user, query: str, location_context: str | None = None) -> dict: + """ + Search the web for current information about destinations, events, prices, etc. + + Args: + user: The user making the request (for auth/logging) + query: The search query + location_context: Optional location to bias results + + Returns: + dict with 'results' list or 'error' string + """ + if not query: + return {"error": "query is required", "results": []} + + try: + from duckduckgo_search import DDGS # type: ignore[import-not-found] + + full_query = query + if location_context: + full_query = f"{query} {location_context}" + + with DDGS() as ddgs: + results = list(ddgs.text(full_query, max_results=5)) + + formatted = [] + for result in results: + formatted.append( + { + "title": result.get("title", ""), + "snippet": result.get("body", ""), + "url": result.get("href", ""), + } + ) + + return {"results": formatted} + + except ImportError: + return { + "error": "Web search is not available (duckduckgo-search not installed)", + "results": [], + } + except Exception as exc: + error_str = str(exc).lower() + if "rate" in error_str or "limit" in error_str: + return { + "error": "Search rate limit reached. Please wait a moment and try again.", + "results": [], + } + logger.error("Web search error: %s", exc) + return {"error": "Web search failed. Please try again.", "results": []} + + +@agent_tool( + name="get_trip_details", + description="Get full details of a trip including all itinerary items, locations, transportation, and lodging", + parameters={ + "collection_id": { + "type": "string", + "description": "UUID of the collection/trip", + "required": True, + } + }, +) +def get_trip_details(user, collection_id: str | None = None): try: - collection_id = kwargs.get("collection_id") if not collection_id: return {"error": "collection_id is required"} @@ -354,16 +396,55 @@ def get_trip_details(user, **kwargs): return {"error": "An unexpected error occurred while fetching trip details"} -def add_to_itinerary(user, **kwargs): +@agent_tool( + name="add_to_itinerary", + description="Add a new location to a trip's itinerary on a specific date", + parameters={ + "collection_id": { + "type": "string", + "description": "UUID of the collection/trip", + "required": True, + }, + "name": { + "type": "string", + "description": "Name of the location", + "required": True, + }, + "description": { + "type": "string", + "description": "Description of why to visit", + }, + "latitude": { + "type": "number", + "description": "Latitude coordinate", + "required": True, + }, + "longitude": { + "type": "number", + "description": "Longitude coordinate", + "required": True, + }, + "date": { + "type": "string", + "description": "Date in YYYY-MM-DD format", + }, + "location_address": { + "type": "string", + "description": "Full address of the location", + }, + }, +) +def add_to_itinerary( + user, + collection_id: str | None = None, + name: str | None = None, + latitude: float | None = None, + longitude: float | None = None, + description: str | None = None, + date: str | None = None, + location_address: str | None = None, +): try: - collection_id = kwargs.get("collection_id") - name = kwargs.get("name") - latitude = kwargs.get("latitude") - longitude = kwargs.get("longitude") - description = kwargs.get("description") - location_address = kwargs.get("location_address") - date = kwargs.get("date") - if not collection_id or not name or latitude is None or longitude is None: return { "error": "collection_id, name, latitude, and longitude are required" @@ -479,16 +560,34 @@ def _fetch_temperature_for_date(latitude, longitude, date_value): } -def get_weather(user, **kwargs): +@agent_tool( + name="get_weather", + description="Get temperature/weather data for a location on specific dates", + parameters={ + "latitude": {"type": "number", "description": "Latitude", "required": True}, + "longitude": { + "type": "number", + "description": "Longitude", + "required": True, + }, + "dates": { + "type": "array", + "items": {"type": "string"}, + "description": "List of dates in YYYY-MM-DD format", + "required": True, + }, + }, +) +def get_weather(user, latitude=None, longitude=None, dates=None): try: - raw_latitude = kwargs.get("latitude") - raw_longitude = kwargs.get("longitude") + raw_latitude = latitude + raw_longitude = longitude if raw_latitude is None or raw_longitude is None: return {"error": "latitude and longitude are required"} latitude = float(raw_latitude) longitude = float(raw_longitude) - dates = kwargs.get("dates") or [] + dates = dates or [] if not isinstance(dates, list) or not dates: return {"error": "dates must be a non-empty list"} @@ -509,44 +608,24 @@ def get_weather(user, **kwargs): return {"error": "An unexpected error occurred while fetching weather data"} -ALLOWED_KWARGS = { - "search_places": {"location", "category", "radius"}, - "list_trips": set(), - "get_trip_details": {"collection_id"}, - "add_to_itinerary": { - "collection_id", - "name", - "description", - "latitude", - "longitude", - "date", - "location_address", - }, - "get_weather": {"latitude", "longitude", "dates"}, -} - - def execute_tool(tool_name, user, **kwargs): - tool_map = { - "search_places": search_places, - "list_trips": list_trips, - "get_trip_details": get_trip_details, - "add_to_itinerary": add_to_itinerary, - "get_weather": get_weather, - } - - tool_fn = tool_map.get(tool_name) - if not tool_fn: + if tool_name not in _REGISTERED_TOOLS: return {"error": f"Unknown tool: {tool_name}"} - allowed = ALLOWED_KWARGS.get(tool_name, set()) + tool_fn = _REGISTERED_TOOLS[tool_name] + + sig = inspect.signature(tool_fn) + allowed = set(sig.parameters.keys()) - {"user"} filtered_kwargs = {k: v for k, v in kwargs.items() if k in allowed} try: - return tool_fn(user, **filtered_kwargs) - except Exception: - logger.exception("Tool execution failed: %s", tool_name) - return {"error": "An unexpected error occurred while executing the tool"} + return tool_fn(user=user, **filtered_kwargs) + except Exception as exc: + logger.exception("Tool %s failed", tool_name) + return {"error": str(exc)} + + +AGENT_TOOLS = get_tool_schemas() def serialize_tool_result(result): diff --git a/backend/server/chat/llm_client.py b/backend/server/chat/llm_client.py index 5042ba9a..2549df58 100644 --- a/backend/server/chat/llm_client.py +++ b/backend/server/chat/llm_client.py @@ -2,11 +2,23 @@ import json import logging import litellm +from django.conf import settings from integrations.models import UserAPIKey logger = logging.getLogger(__name__) +PROVIDER_MODEL_PREFIX = { + "openai": "openai", + "anthropic": "anthropic", + "gemini": "gemini", + "ollama": "ollama", + "groq": "groq", + "mistral": "mistral", + "github_models": "github", + "openrouter": "openrouter", +} + CHAT_PROVIDER_CONFIG = { "openai": { "label": "OpenAI", @@ -59,12 +71,83 @@ CHAT_PROVIDER_CONFIG = { "opencode_zen": { "label": "OpenCode Zen", "needs_api_key": True, - "default_model": "openai/gpt-4o-mini", + # Chosen from OpenCode Zen compatible OpenAI-routed models per + # opencode_zen connection research (see .memory/research). + "default_model": "openai/gpt-5-nano", "api_base": "https://opencode.ai/zen/v1", }, } +def _is_model_override_compatible(provider, provider_config, model): + """Validate model/provider compatibility when strict checks are safe. + + For providers with a custom api_base gateway, skip strict prefix checks since + gateway routing may legitimately accept cross-provider prefixes. + """ + if not model or provider_config.get("api_base"): + return True + + if "/" not in model: + return True + + expected_prefix = PROVIDER_MODEL_PREFIX.get(provider) + if not expected_prefix: + default_model = provider_config.get("default_model") or "" + if "/" in default_model: + expected_prefix = default_model.split("/", 1)[0] + + if not expected_prefix: + return True + + return model.startswith(f"{expected_prefix}/") + + +def _safe_error_payload(exc): + exceptions = getattr(litellm, "exceptions", None) + not_found_cls = getattr(exceptions, "NotFoundError", tuple()) + auth_cls = getattr(exceptions, "AuthenticationError", tuple()) + rate_limit_cls = getattr(exceptions, "RateLimitError", tuple()) + bad_request_cls = getattr(exceptions, "BadRequestError", tuple()) + timeout_cls = getattr(exceptions, "Timeout", tuple()) + api_connection_cls = getattr(exceptions, "APIConnectionError", tuple()) + + if isinstance(exc, not_found_cls): + return { + "error": "The selected model is unavailable for this provider. Choose a different model and try again.", + "error_category": "model_not_found", + } + + if isinstance(exc, auth_cls): + return { + "error": "Authentication with the model provider failed. Verify your API key in Settings and try again.", + "error_category": "authentication_failed", + } + + if isinstance(exc, rate_limit_cls): + return { + "error": "The model provider rate limit was reached. Please wait and try again.", + "error_category": "rate_limited", + } + + if isinstance(exc, bad_request_cls): + return { + "error": "The model provider rejected this request. Check your selected model and try again.", + "error_category": "invalid_request", + } + + if isinstance(exc, timeout_cls) or isinstance(exc, api_connection_cls): + return { + "error": "Unable to reach the model provider right now. Please try again.", + "error_category": "provider_unreachable", + } + + return { + "error": "An error occurred while processing your request. Please try again.", + "error_category": "unknown_error", + } + + def _safe_get(obj, key, default=None): if obj is None: return default @@ -90,9 +173,20 @@ def is_chat_provider_available(provider_id): return normalized_provider in CHAT_PROVIDER_CONFIG -def get_provider_catalog(): +def get_provider_catalog(user=None): seen = set() catalog = [] + user_key_providers = set() + instance_provider = ( + _normalize_provider_id(settings.VOYAGE_AI_PROVIDER) + if settings.VOYAGE_AI_PROVIDER + else None + ) + instance_has_key = bool(settings.VOYAGE_AI_API_KEY) + if user: + user_key_providers = set( + UserAPIKey.objects.filter(user=user).values_list("provider", flat=True) + ) for provider_id in getattr(litellm, "provider_list", []): normalized_provider = _normalize_provider_id(provider_id) @@ -110,6 +204,9 @@ def get_provider_catalog(): "needs_api_key": provider_config["needs_api_key"], "default_model": provider_config["default_model"], "api_base": provider_config["api_base"], + "instance_configured": instance_has_key + and normalized_provider == instance_provider, + "user_configured": normalized_provider in user_key_providers, } ) continue @@ -122,6 +219,9 @@ def get_provider_catalog(): "needs_api_key": None, "default_model": None, "api_base": None, + "instance_configured": instance_has_key + and normalized_provider == instance_provider, + "user_configured": normalized_provider in user_key_providers, } ) @@ -141,6 +241,9 @@ def get_provider_catalog(): "needs_api_key": provider_config["needs_api_key"], "default_model": provider_config["default_model"], "api_base": provider_config["api_base"], + "instance_configured": instance_has_key + and normalized_provider == instance_provider, + "user_configured": normalized_provider in user_key_providers, } ) @@ -154,9 +257,55 @@ def get_llm_api_key(user, provider): key_obj = UserAPIKey.objects.get(user=user, provider=normalized_provider) return key_obj.get_api_key() except UserAPIKey.DoesNotExist: + if normalized_provider == _normalize_provider_id(settings.VOYAGE_AI_PROVIDER): + instance_api_key = (settings.VOYAGE_AI_API_KEY or "").strip() + if instance_api_key: + return instance_api_key return None +def _format_interests(interests): + if isinstance(interests, list): + return ", ".join(interests) + return interests + + +def get_aggregated_preferences(collection): + """Aggregate preferences from collection owner and shared users.""" + from integrations.models import UserRecommendationPreferenceProfile + + users = [collection.user] + list(collection.shared_with.all()) + preferences = [] + + for member in users: + try: + profile = UserRecommendationPreferenceProfile.objects.get(user=member) + user_prefs = [] + + if profile.cuisines: + user_prefs.append(f"cuisines: {profile.cuisines}") + if profile.interests: + user_prefs.append(f"interests: {_format_interests(profile.interests)}") + if profile.trip_style: + user_prefs.append(f"style: {profile.trip_style}") + if profile.notes: + user_prefs.append(f"notes: {profile.notes}") + + if user_prefs: + preferences.append(f"- **{member.username}**: {', '.join(user_prefs)}") + except UserRecommendationPreferenceProfile.DoesNotExist: + continue + + if preferences: + return ( + "\n\n## Party Preferences\n" + + "\n".join(preferences) + + "\n\nNote: Consider all travelers' preferences when making recommendations." + ) + + return "" + + def get_system_prompt(user, collection=None): """Build the system prompt with user context.""" from integrations.models import UserRecommendationPreferenceProfile @@ -181,26 +330,37 @@ When modifying itineraries: Be conversational, helpful, and enthusiastic about travel. Keep responses concise but informative.""" - try: - profile = UserRecommendationPreferenceProfile.objects.get(user=user) - prefs = [] - if profile.cuisines: - prefs.append(f"Cuisine preferences: {profile.cuisines}") - if profile.interests: - prefs.append(f"Interests: {profile.interests}") - if profile.trip_style: - prefs.append(f"Travel style: {profile.trip_style}") - if profile.notes: - prefs.append(f"Additional notes: {profile.notes}") - if prefs: - base_prompt += "\n\nUser preferences:\n" + "\n".join(prefs) - except UserRecommendationPreferenceProfile.DoesNotExist: - pass + if collection and collection.shared_with.count() > 0: + base_prompt += get_aggregated_preferences(collection) + else: + try: + profile = UserRecommendationPreferenceProfile.objects.get(user=user) + preference_lines = [] + + if profile.cuisines: + preference_lines.append( + f"🍽️ **Cuisine Preferences**: {profile.cuisines}" + ) + if profile.interests: + preference_lines.append( + f"🎯 **Interests**: {_format_interests(profile.interests)}" + ) + if profile.trip_style: + preference_lines.append(f"✈️ **Travel Style**: {profile.trip_style}") + if profile.notes: + preference_lines.append(f"πŸ“ **Additional Notes**: {profile.notes}") + + if preference_lines: + base_prompt += "\n\n## Traveler Preferences\n" + "\n".join( + preference_lines + ) + except UserRecommendationPreferenceProfile.DoesNotExist: + pass return base_prompt -async def stream_chat_completion(user, messages, provider, tools=None): +async def stream_chat_completion(user, messages, provider, tools=None, model=None): """Stream a chat completion using LiteLLM. Yields SSE-formatted strings. @@ -215,6 +375,7 @@ async def stream_chat_completion(user, messages, provider, tools=None): return api_key = get_llm_api_key(user, normalized_provider) + if provider_config["needs_api_key"] and not api_key: payload = { "error": f"No API key found for provider: {normalized_provider}. Please add one in Settings." @@ -222,14 +383,31 @@ async def stream_chat_completion(user, messages, provider, tools=None): yield f"data: {json.dumps(payload)}\n\n" return + if not _is_model_override_compatible(normalized_provider, provider_config, model): + payload = { + "error": "The selected model is incompatible with this provider. Choose a model for the selected provider and try again.", + "error_category": "invalid_model_for_provider", + } + yield f"data: {json.dumps(payload)}\n\n" + return + completion_kwargs = { - "model": provider_config["default_model"], + "model": model + or ( + settings.VOYAGE_AI_MODEL + if normalized_provider + == _normalize_provider_id(settings.VOYAGE_AI_PROVIDER) + and settings.VOYAGE_AI_MODEL + else None + ) + or provider_config["default_model"], "messages": messages, - "tools": tools, - "tool_choice": "auto" if tools else None, "stream": True, "api_key": api_key, } + if tools: + completion_kwargs["tools"] = tools + completion_kwargs["tool_choice"] = "auto" if provider_config["api_base"]: completion_kwargs["api_base"] = provider_config["api_base"] @@ -271,6 +449,7 @@ async def stream_chat_completion(user, messages, provider, tools=None): yield f"data: {json.dumps(chunk_data)}\n\n" yield "data: [DONE]\n\n" - except Exception: + except Exception as exc: logger.exception("LLM streaming error") - yield f"data: {json.dumps({'error': 'An error occurred while processing your request. Please try again.'})}\n\n" + payload = _safe_error_payload(exc) + yield f"data: {json.dumps(payload)}\n\n" diff --git a/backend/server/chat/urls.py b/backend/server/chat/urls.py index 220fec94..dd0c0522 100644 --- a/backend/server/chat/urls.py +++ b/backend/server/chat/urls.py @@ -1,7 +1,12 @@ from django.urls import include, path from rest_framework.routers import DefaultRouter -from .views import ChatProviderCatalogViewSet, ChatViewSet +from .views import ( + CapabilitiesView, + ChatProviderCatalogViewSet, + ChatViewSet, + DaySuggestionsView, +) router = DefaultRouter() router.register(r"conversations", ChatViewSet, basename="chat-conversation") @@ -11,4 +16,6 @@ router.register( urlpatterns = [ path("", include(router.urls)), + path("capabilities/", CapabilitiesView.as_view(), name="chat-capabilities"), + path("suggestions/day/", DaySuggestionsView.as_view(), name="chat-day-suggestions"), ] diff --git a/backend/server/chat/views/__init__.py b/backend/server/chat/views/__init__.py new file mode 100644 index 00000000..cde13be9 --- /dev/null +++ b/backend/server/chat/views/__init__.py @@ -0,0 +1,328 @@ +import asyncio +import json + +from asgiref.sync import sync_to_async +from adventures.models import Collection +from django.http import StreamingHttpResponse +from rest_framework import status, viewsets +from rest_framework.decorators import action +from rest_framework.permissions import IsAuthenticated +from rest_framework.response import Response + +from ..agent_tools import AGENT_TOOLS, execute_tool, serialize_tool_result +from ..llm_client import ( + get_provider_catalog, + get_system_prompt, + is_chat_provider_available, + stream_chat_completion, +) +from ..models import ChatConversation, ChatMessage +from ..serializers import ChatConversationSerializer + + +class ChatViewSet(viewsets.ModelViewSet): + serializer_class = ChatConversationSerializer + permission_classes = [IsAuthenticated] + + def get_queryset(self): + return ChatConversation.objects.filter(user=self.request.user).prefetch_related( + "messages" + ) + + def list(self, request, *args, **kwargs): + conversations = self.get_queryset().only("id", "title", "updated_at") + data = [ + { + "id": str(conversation.id), + "title": conversation.title, + "updated_at": conversation.updated_at, + } + for conversation in conversations + ] + return Response(data) + + def create(self, request, *args, **kwargs): + conversation = ChatConversation.objects.create( + user=request.user, + title=(request.data.get("title") or "").strip(), + ) + serializer = self.get_serializer(conversation) + return Response(serializer.data, status=status.HTTP_201_CREATED) + + def _build_llm_messages(self, conversation, user, system_prompt=None): + messages = [ + { + "role": "system", + "content": system_prompt or get_system_prompt(user), + } + ] + for message in conversation.messages.all().order_by("created_at"): + payload = { + "role": message.role, + "content": message.content, + } + if message.role == "assistant" and message.tool_calls: + payload["tool_calls"] = message.tool_calls + if message.role == "tool": + payload["tool_call_id"] = message.tool_call_id + payload["name"] = message.name + messages.append(payload) + return messages + + def _async_to_sync_generator(self, async_gen): + loop = asyncio.new_event_loop() + try: + while True: + try: + yield loop.run_until_complete(async_gen.__anext__()) + except StopAsyncIteration: + break + finally: + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.close() + + @staticmethod + def _merge_tool_call_delta(accumulator, tool_calls_delta): + for idx, tool_call in enumerate(tool_calls_delta or []): + idx = tool_call.get("index", idx) + while len(accumulator) <= idx: + accumulator.append( + { + "id": None, + "type": "function", + "function": {"name": "", "arguments": ""}, + } + ) + + current = accumulator[idx] + if tool_call.get("id"): + current["id"] = tool_call.get("id") + if tool_call.get("type"): + current["type"] = tool_call.get("type") + + function_data = tool_call.get("function") or {} + if function_data.get("name"): + current["function"]["name"] = function_data.get("name") + if function_data.get("arguments"): + current["function"]["arguments"] += function_data.get("arguments") + + @action(detail=True, methods=["post"]) + def send_message(self, request, pk=None): + conversation = self.get_object() + user_content = (request.data.get("message") or "").strip() + if not user_content: + return Response( + {"error": "message is required"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + provider = (request.data.get("provider") or "openai").strip().lower() + model = (request.data.get("model") or "").strip() or None + collection_id = request.data.get("collection_id") + collection_name = request.data.get("collection_name") + start_date = request.data.get("start_date") + end_date = request.data.get("end_date") + destination = request.data.get("destination") + if not is_chat_provider_available(provider): + return Response( + {"error": f"Provider is not available for chat: {provider}."}, + status=status.HTTP_400_BAD_REQUEST, + ) + + context_parts = [] + if collection_name: + context_parts.append(f"Trip: {collection_name}") + if destination: + context_parts.append(f"Destination: {destination}") + if start_date and end_date: + context_parts.append(f"Dates: {start_date} to {end_date}") + + collection = None + if collection_id: + try: + requested_collection = Collection.objects.get(id=collection_id) + if ( + requested_collection.user == request.user + or requested_collection.shared_with.filter( + id=request.user.id + ).exists() + ): + collection = requested_collection + except Collection.DoesNotExist: + pass + + system_prompt = get_system_prompt(request.user, collection) + if context_parts: + system_prompt += "\n\n## Trip Context\n" + "\n".join(context_parts) + + ChatMessage.objects.create( + conversation=conversation, + role="user", + content=user_content, + ) + conversation.save(update_fields=["updated_at"]) + + if not conversation.title: + conversation.title = user_content[:120] + conversation.save(update_fields=["title", "updated_at"]) + + llm_messages = self._build_llm_messages( + conversation, + request.user, + system_prompt=system_prompt, + ) + + MAX_TOOL_ITERATIONS = 10 + + async def event_stream(): + current_messages = list(llm_messages) + encountered_error = False + tool_iterations = 0 + + while tool_iterations < MAX_TOOL_ITERATIONS: + content_chunks = [] + tool_calls_accumulator = [] + + async for chunk in stream_chat_completion( + request.user, + current_messages, + provider, + tools=AGENT_TOOLS, + model=model, + ): + if not chunk.startswith("data: "): + yield chunk + continue + + payload = chunk[len("data: ") :].strip() + if payload == "[DONE]": + continue + + yield chunk + + try: + data = json.loads(payload) + except json.JSONDecodeError: + continue + + if data.get("error"): + encountered_error = True + break + + if data.get("content"): + content_chunks.append(data["content"]) + + if data.get("tool_calls"): + self._merge_tool_call_delta( + tool_calls_accumulator, + data["tool_calls"], + ) + + if encountered_error: + break + + assistant_content = "".join(content_chunks) + + if tool_calls_accumulator: + assistant_with_tools = { + "role": "assistant", + "content": assistant_content, + "tool_calls": tool_calls_accumulator, + } + current_messages.append(assistant_with_tools) + + await sync_to_async( + ChatMessage.objects.create, thread_sensitive=True + )( + conversation=conversation, + role="assistant", + content=assistant_content, + tool_calls=tool_calls_accumulator, + ) + await sync_to_async(conversation.save, thread_sensitive=True)( + update_fields=["updated_at"] + ) + + for tool_call in tool_calls_accumulator: + function_payload = tool_call.get("function") or {} + function_name = function_payload.get("name") or "" + raw_arguments = function_payload.get("arguments") or "{}" + + try: + arguments = json.loads(raw_arguments) + except json.JSONDecodeError: + arguments = {} + if not isinstance(arguments, dict): + arguments = {} + + result = await sync_to_async( + execute_tool, thread_sensitive=True + )( + function_name, + request.user, + **arguments, + ) + result_content = serialize_tool_result(result) + + current_messages.append( + { + "role": "tool", + "tool_call_id": tool_call.get("id"), + "name": function_name, + "content": result_content, + } + ) + + await sync_to_async( + ChatMessage.objects.create, thread_sensitive=True + )( + conversation=conversation, + role="tool", + content=result_content, + tool_call_id=tool_call.get("id"), + name=function_name, + ) + await sync_to_async(conversation.save, thread_sensitive=True)( + update_fields=["updated_at"] + ) + + tool_event = { + "tool_result": { + "tool_call_id": tool_call.get("id"), + "name": function_name, + "result": result, + } + } + yield f"data: {json.dumps(tool_event)}\n\n" + + continue + + await sync_to_async(ChatMessage.objects.create, thread_sensitive=True)( + conversation=conversation, + role="assistant", + content=assistant_content, + ) + await sync_to_async(conversation.save, thread_sensitive=True)( + update_fields=["updated_at"] + ) + yield "data: [DONE]\n\n" + break + + response = StreamingHttpResponse( + streaming_content=self._async_to_sync_generator(event_stream()), + content_type="text/event-stream", + ) + response["Cache-Control"] = "no-cache" + response["X-Accel-Buffering"] = "no" + return response + + +class ChatProviderCatalogViewSet(viewsets.ViewSet): + permission_classes = [IsAuthenticated] + + def list(self, request): + return Response(get_provider_catalog(user=request.user)) + + +from .capabilities import CapabilitiesView +from .day_suggestions import DaySuggestionsView diff --git a/backend/server/chat/views/capabilities.py b/backend/server/chat/views/capabilities.py new file mode 100644 index 00000000..75cabb22 --- /dev/null +++ b/backend/server/chat/views/capabilities.py @@ -0,0 +1,24 @@ +from rest_framework.permissions import IsAuthenticated +from rest_framework.response import Response +from rest_framework.views import APIView + +from chat.agent_tools import get_tool_schemas + + +class CapabilitiesView(APIView): + permission_classes = [IsAuthenticated] + + def get(self, request): + """Return available AI capabilities/tools.""" + tools = get_tool_schemas() + return Response( + { + "tools": [ + { + "name": tool["function"]["name"], + "description": tool["function"]["description"], + } + for tool in tools + ] + } + ) diff --git a/backend/server/chat/views/day_suggestions.py b/backend/server/chat/views/day_suggestions.py new file mode 100644 index 00000000..67711836 --- /dev/null +++ b/backend/server/chat/views/day_suggestions.py @@ -0,0 +1,215 @@ +import json +import re + +import litellm +from django.shortcuts import get_object_or_404 +from rest_framework import status +from rest_framework.permissions import IsAuthenticated +from rest_framework.response import Response +from rest_framework.views import APIView + +from adventures.models import Collection +from chat.agent_tools import search_places +from chat.llm_client import ( + get_llm_api_key, + get_system_prompt, + is_chat_provider_available, +) + + +class DaySuggestionsView(APIView): + permission_classes = [IsAuthenticated] + + def post(self, request): + collection_id = request.data.get("collection_id") + date = request.data.get("date") + category = request.data.get("category") + filters = request.data.get("filters", {}) or {} + location_context = request.data.get("location_context", "") + + if not all([collection_id, date, category]): + return Response( + {"error": "collection_id, date, and category are required"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + valid_categories = ["restaurant", "activity", "event", "lodging"] + if category not in valid_categories: + return Response( + {"error": f"category must be one of: {', '.join(valid_categories)}"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + collection = get_object_or_404(Collection, id=collection_id) + if ( + collection.user != request.user + and not collection.shared_with.filter(id=request.user.id).exists() + ): + return Response( + {"error": "You don't have access to this collection"}, + status=status.HTTP_403_FORBIDDEN, + ) + + location = location_context or self._get_collection_location(collection) + system_prompt = get_system_prompt(request.user, collection) + provider = "openai" + + if not is_chat_provider_available(provider): + return Response( + { + "error": "AI suggestions are not available. Please configure an API key." + }, + status=status.HTTP_503_SERVICE_UNAVAILABLE, + ) + + try: + places_context = self._get_places_context(request.user, category, location) + prompt = self._build_prompt( + category=category, + filters=filters, + location=location, + date=date, + collection=collection, + places_context=places_context, + ) + + suggestions = self._get_suggestions_from_llm( + system_prompt=system_prompt, + user_prompt=prompt, + user=request.user, + provider=provider, + ) + return Response({"suggestions": suggestions}, status=status.HTTP_200_OK) + except Exception: + return Response( + {"error": "Failed to generate suggestions. Please try again."}, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + def _get_collection_location(self, collection): + for loc in collection.locations.select_related("city", "country").all(): + if loc.city: + city_name = getattr(loc.city, "name", str(loc.city)) + country_name = getattr(loc.country, "name", "") if loc.country else "" + return ", ".join([x for x in [city_name, country_name] if x]) + if loc.location: + return loc.location + if loc.name: + return loc.name + return "Unknown location" + + def _build_prompt( + self, + category, + filters, + location, + date, + collection, + places_context="", + ): + category_prompts = { + "restaurant": f"Find restaurant recommendations for {location}", + "activity": f"Find activity recommendations for {location}", + "event": f"Find event recommendations for {location} around {date}", + "lodging": f"Find lodging recommendations for {location}", + } + + prompt = category_prompts.get( + category, f"Find {category} recommendations for {location}" + ) + + if filters: + filter_parts = [] + if filters.get("cuisine_type"): + filter_parts.append(f"cuisine type: {filters['cuisine_type']}") + if filters.get("price_range"): + filter_parts.append(f"price range: {filters['price_range']}") + if filters.get("dietary"): + filter_parts.append(f"dietary restrictions: {filters['dietary']}") + if filters.get("activity_type"): + filter_parts.append(f"type: {filters['activity_type']}") + if filters.get("duration"): + filter_parts.append(f"duration: {filters['duration']}") + if filters.get("event_type"): + filter_parts.append(f"event type: {filters['event_type']}") + if filters.get("lodging_type"): + filter_parts.append(f"lodging type: {filters['lodging_type']}") + amenities = filters.get("amenities") + if isinstance(amenities, list) and amenities: + filter_parts.append( + f"amenities: {', '.join(str(x) for x in amenities)}" + ) + + if filter_parts: + prompt += f" with these preferences: {', '.join(filter_parts)}" + + prompt += f". The trip date is {date}." + + if collection.start_date or collection.end_date: + prompt += ( + " Collection trip window: " + f"{collection.start_date or 'unknown'} to {collection.end_date or 'unknown'}." + ) + + if places_context: + prompt += f" Nearby place context: {places_context}." + + prompt += ( + " Return 3-5 specific suggestions as a JSON array." + " Each suggestion should have: name, description, why_fits, category, location, rating, price_level." + " Return ONLY valid JSON, no markdown, no surrounding text." + ) + return prompt + + def _get_places_context(self, user, category, location): + tool_category_map = { + "restaurant": "food", + "activity": "tourism", + "event": "tourism", + "lodging": "lodging", + } + result = search_places( + user, + location=location, + category=tool_category_map.get(category, "tourism"), + radius=8, + ) + if result.get("error"): + return "" + + entries = [] + for place in result.get("results", [])[:5]: + name = place.get("name") + address = place.get("address") or "" + if name: + entries.append(f"{name} ({address})" if address else name) + return "; ".join(entries) + + def _get_suggestions_from_llm(self, system_prompt, user_prompt, user, provider): + api_key = get_llm_api_key(user, provider) + if not api_key: + raise ValueError("No API key available") + + response = litellm.completion( + model="gpt-4o-mini", + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + api_key=api_key, + temperature=0.7, + max_tokens=1000, + ) + + content = (response.choices[0].message.content or "").strip() + try: + json_match = re.search(r"\[.*\]", content, re.DOTALL) + parsed = ( + json.loads(json_match.group()) + if json_match + else json.loads(content or "[]") + ) + suggestions = parsed if isinstance(parsed, list) else [parsed] + return suggestions[:5] + except json.JSONDecodeError: + return [] diff --git a/backend/server/integrations/migrations/0008_useraisettings.py b/backend/server/integrations/migrations/0008_useraisettings.py new file mode 100644 index 00000000..f3c18f9e --- /dev/null +++ b/backend/server/integrations/migrations/0008_useraisettings.py @@ -0,0 +1,52 @@ +# Generated by Django 5.2.12 on 2026-03-08 + +import django.db.models.deletion +import uuid +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("integrations", "0007_userapikey_userrecommendationpreferenceprofile"), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name="UserAISettings", + fields=[ + ( + "id", + models.UUIDField( + default=uuid.uuid4, + editable=False, + primary_key=True, + serialize=False, + ), + ), + ( + "preferred_provider", + models.CharField(blank=True, max_length=100, null=True), + ), + ( + "preferred_model", + models.CharField(blank=True, max_length=100, null=True), + ), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ( + "user", + models.OneToOneField( + on_delete=django.db.models.deletion.CASCADE, + related_name="ai_settings", + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "verbose_name": "User AI Settings", + "verbose_name_plural": "User AI Settings", + }, + ), + ] diff --git a/backend/server/integrations/models.py b/backend/server/integrations/models.py index 0029308d..45754903 100644 --- a/backend/server/integrations/models.py +++ b/backend/server/integrations/models.py @@ -124,3 +124,23 @@ class UserRecommendationPreferenceProfile(models.Model): notes = models.TextField(blank=True, null=True) created_at = models.DateTimeField(auto_now_add=True) updated_at = models.DateTimeField(auto_now=True) + + +class UserAISettings(models.Model): + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + user = models.OneToOneField( + settings.AUTH_USER_MODEL, + on_delete=models.CASCADE, + related_name="ai_settings", + ) + preferred_provider = models.CharField(max_length=100, blank=True, null=True) + preferred_model = models.CharField(max_length=100, blank=True, null=True) + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) + + class Meta: + verbose_name = "User AI Settings" + verbose_name_plural = "User AI Settings" + + def __str__(self): + return f"AI Settings for {self.user.username}" diff --git a/backend/server/integrations/serializers.py b/backend/server/integrations/serializers.py index 7b893ad2..49b56d1c 100644 --- a/backend/server/integrations/serializers.py +++ b/backend/server/integrations/serializers.py @@ -3,6 +3,7 @@ from django.db import IntegrityError from .models import ( EncryptionConfigurationError, ImmichIntegration, + UserAISettings, UserAPIKey, UserRecommendationPreferenceProfile, ) @@ -98,3 +99,16 @@ class UserRecommendationPreferenceProfileSerializer(serializers.ModelSerializer) "updated_at", ] read_only_fields = ["id", "created_at", "updated_at"] + + +class UserAISettingsSerializer(serializers.ModelSerializer): + class Meta: + model = UserAISettings + fields = [ + "id", + "preferred_provider", + "preferred_model", + "created_at", + "updated_at", + ] + read_only_fields = ["id", "created_at", "updated_at"] diff --git a/backend/server/integrations/urls.py b/backend/server/integrations/urls.py index 181bebca..7cfc33f9 100644 --- a/backend/server/integrations/urls.py +++ b/backend/server/integrations/urls.py @@ -6,6 +6,7 @@ from integrations.views import ( StravaIntegrationView, WandererIntegrationViewSet, UserAPIKeyViewSet, + UserAISettingsViewSet, UserRecommendationPreferenceProfileViewSet, ) @@ -22,6 +23,7 @@ router.register( UserRecommendationPreferenceProfileViewSet, basename="user-recommendation-preferences", ) +router.register(r"ai-settings", UserAISettingsViewSet, basename="user-ai-settings") # Include the router URLs urlpatterns = [ diff --git a/backend/server/integrations/views/__init__.py b/backend/server/integrations/views/__init__.py index f757b06d..36a08ebf 100644 --- a/backend/server/integrations/views/__init__.py +++ b/backend/server/integrations/views/__init__.py @@ -4,3 +4,4 @@ from .strava_view import StravaIntegrationView from .wanderer_view import WandererIntegrationViewSet from .user_api_key_view import UserAPIKeyViewSet from .recommendation_profile_view import UserRecommendationPreferenceProfileViewSet +from .ai_settings_view import UserAISettingsViewSet diff --git a/backend/server/integrations/views/ai_settings_view.py b/backend/server/integrations/views/ai_settings_view.py new file mode 100644 index 00000000..10e1afab --- /dev/null +++ b/backend/server/integrations/views/ai_settings_view.py @@ -0,0 +1,39 @@ +from rest_framework import status, viewsets +from rest_framework.permissions import IsAuthenticated +from rest_framework.response import Response + +from integrations.models import UserAISettings +from integrations.serializers import UserAISettingsSerializer + + +class UserAISettingsViewSet(viewsets.ModelViewSet): + serializer_class = UserAISettingsSerializer + permission_classes = [IsAuthenticated] + + def get_queryset(self): + return UserAISettings.objects.filter(user=self.request.user) + + def list(self, request, *args, **kwargs): + instance = self.get_queryset().first() + if not instance: + return Response([], status=status.HTTP_200_OK) + serializer = self.get_serializer(instance) + return Response([serializer.data], status=status.HTTP_200_OK) + + def perform_create(self, serializer): + existing = UserAISettings.objects.filter(user=self.request.user).first() + if existing: + for field, value in serializer.validated_data.items(): + setattr(existing, field, value) + existing.save() + self._upserted_instance = existing + return + + self._upserted_instance = serializer.save(user=self.request.user) + + def create(self, request, *args, **kwargs): + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + self.perform_create(serializer) + output = self.get_serializer(self._upserted_instance) + return Response(output.data, status=status.HTTP_200_OK) diff --git a/backend/server/main/settings.py b/backend/server/main/settings.py index 40554746..0a7d656c 100644 --- a/backend/server/main/settings.py +++ b/backend/server/main/settings.py @@ -403,6 +403,11 @@ OSRM_BASE_URL = getenv("OSRM_BASE_URL", "https://router.project-osrm.org") FIELD_ENCRYPTION_KEY = getenv("FIELD_ENCRYPTION_KEY", "") +# Voyage AI Configuration +VOYAGE_AI_PROVIDER = getenv("VOYAGE_AI_PROVIDER", "openai") +VOYAGE_AI_MODEL = getenv("VOYAGE_AI_MODEL", "gpt-4o-mini") +VOYAGE_AI_API_KEY = getenv("VOYAGE_AI_API_KEY", "") + DJANGO_MCP_ENDPOINT = getenv("DJANGO_MCP_ENDPOINT", "api/mcp") DJANGO_MCP_AUTHENTICATION_CLASSES = [ "rest_framework.authentication.TokenAuthentication", diff --git a/backend/server/requirements.txt b/backend/server/requirements.txt index 034f4916..b736fe16 100644 --- a/backend/server/requirements.txt +++ b/backend/server/requirements.txt @@ -34,3 +34,4 @@ requests>=2.32.5 cryptography>=46.0.5 django-mcp-server>=0.5.7 litellm>=1.72.3 +duckduckgo-search>=4.0.0 diff --git a/frontend/src/lib/components/AITravelChat.svelte b/frontend/src/lib/components/AITravelChat.svelte index a1b8a0d3..e20c4fd7 100644 --- a/frontend/src/lib/components/AITravelChat.svelte +++ b/frontend/src/lib/components/AITravelChat.svelte @@ -1,8 +1,22 @@ + + + + + diff --git a/frontend/src/lib/types.ts b/frontend/src/lib/types.ts index 3b02346a..c364098b 100644 --- a/frontend/src/lib/types.ts +++ b/frontend/src/lib/types.ts @@ -575,6 +575,16 @@ export type ChatProviderCatalogEntry = { api_base: string | null; }; +export type UserRecommendationPreferenceProfile = { + id: string; + cuisines: string | null; + interests: string[]; + trip_style: string | null; + notes: string | null; + created_at: string; + updated_at: string; +}; + export type CollectionItineraryDay = { id: string; collection: string; // UUID of the collection diff --git a/frontend/src/locales/en.json b/frontend/src/locales/en.json index 5d3c05d1..db2632e9 100644 --- a/frontend/src/locales/en.json +++ b/frontend/src/locales/en.json @@ -44,7 +44,41 @@ "send": "Send", "delete_conversation": "Delete Conversation", "connection_error": "Connection error. Please try again.", - "no_api_key": "No API key found. Please add one in Settings." + "no_api_key": "No API key found. Please add one in Settings.", + "model_label": "Model", + "model_placeholder": "Default model" + }, + "travel_assistant": "Travel Assistant", + "quick_actions": "Quick actions", + "add_to_itinerary": "Add to Itinerary", + "add_to_which_day": "Add \"{placeName}\" to which day?", + "added_successfully": "Added to itinerary!", + "suggestions": { + "title": "AI Suggestions", + "for_date": "for {date}", + "select_category": "What would you like suggestions for?", + "category_restaurant": "Restaurant", + "category_activity": "Activity", + "category_event": "Event", + "category_lodging": "Lodging", + "surprise_me": "Surprise me!", + "refine_filters": "Refine your preferences", + "cuisine_type": "Cuisine type", + "price_range": "Price range", + "dietary": "Dietary restrictions", + "activity_type": "Activity type", + "duration": "Duration", + "event_type": "Event type", + "time_preference": "Time preference", + "lodging_type": "Lodging type", + "amenities": "Amenities", + "get_suggestions": "Get Suggestions", + "loading": "Finding great options...", + "no_results": "No suggestions found. Try adjusting your filters.", + "try_again": "Try different filters", + "add_to_day": "Add to this day", + "why_fits": "Why it's a great fit", + "error": "Failed to get suggestions. Please try again." }, "about": { "about": "About", @@ -782,7 +816,19 @@ "travel_agent_help_title": "How to use the travel agent", "travel_agent_help_body": "Open a collection and switch to Recommendations to interact with the travel agent for place suggestions.", "travel_agent_help_open_collections": "Open Collections", - "travel_agent_help_setup_guide": "Travel agent setup guide" + "travel_agent_help_setup_guide": "Travel agent setup guide", + "travel_preferences": "Travel Preferences", + "travel_preferences_desc": "Customize your travel preferences for better AI recommendations", + "cuisines": "Favorite Cuisines", + "cuisines_placeholder": "e.g., Italian, Japanese, Mexican...", + "interests": "Travel Interests", + "interests_placeholder": "e.g., hiking, museums, beaches, nightlife...", + "trip_style": "Travel Style", + "trip_style_placeholder": "e.g., adventure, luxury, budget, cultural", + "notes": "Additional Notes", + "notes_placeholder": "Any other preferences or considerations for your trips...", + "preferences_saved": "Preferences saved successfully!", + "preferences_save_error": "Failed to save preferences" }, "collection": { "collection_created": "Collection created successfully!", diff --git a/frontend/src/routes/collections/[id]/+page.svelte b/frontend/src/routes/collections/[id]/+page.svelte index 38425cf1..4fccd48d 100644 --- a/frontend/src/routes/collections/[id]/+page.svelte +++ b/frontend/src/routes/collections/[id]/+page.svelte @@ -256,6 +256,29 @@ // Enforce recommendations visibility only for owner/shared users $: availableViews.recommendations = !!canModifyCollection; + function deriveCollectionDestination(current: Collection | null): string | undefined { + if (!current?.locations?.length) { + return undefined; + } + + const firstLocation = current.locations.find((loc) => + Boolean(loc.city?.name || loc.country?.name || loc.location || loc.name) + ); + if (!firstLocation) { + return undefined; + } + + const cityName = firstLocation.city?.name; + const countryName = firstLocation.country?.name; + if (cityName && countryName) { + return `${cityName}, ${countryName}`; + } + + return cityName || countryName || firstLocation.location || firstLocation.name || undefined; + } + + $: collectionDestination = deriveCollectionDestination(collection); + // Build calendar events from collection visits type TimezoneMode = 'event' | 'local'; @@ -1261,7 +1284,14 @@ {#if currentView === 'recommendations'}
- +
{/if} diff --git a/frontend/src/routes/settings/+page.server.ts b/frontend/src/routes/settings/+page.server.ts index 2f6249de..58e54339 100644 --- a/frontend/src/routes/settings/+page.server.ts +++ b/frontend/src/routes/settings/+page.server.ts @@ -1,7 +1,7 @@ import { fail, redirect, type Actions } from '@sveltejs/kit'; import type { PageServerLoad } from '../$types'; const PUBLIC_SERVER_URL = process.env['PUBLIC_SERVER_URL']; -import type { ImmichIntegration, User } from '$lib/types'; +import type { ImmichIntegration, User, UserRecommendationPreferenceProfile } from '$lib/types'; import { fetchCSRFToken } from '$lib/index.server'; const endpoint = PUBLIC_SERVER_URL || 'http://localhost:8000'; @@ -95,11 +95,25 @@ export const load: PageServerLoad = async (event) => { let apiKeys: UserAPIKey[] = []; let apiKeysConfigError: string | null = null; - let apiKeysFetch = await fetch(`${endpoint}/api/integrations/api-keys/`, { - headers: { - Cookie: `sessionid=${sessionId}` - } - }); + let [apiKeysFetch, recommendationPreferencesFetch] = await Promise.all([ + fetch(`${endpoint}/api/integrations/api-keys/`, { + headers: { + Cookie: `sessionid=${sessionId}` + } + }), + fetch(`${endpoint}/api/integrations/recommendation-preferences/`, { + headers: { + Cookie: `sessionid=${sessionId}` + } + }) + ]); + + let recommendationProfile: UserRecommendationPreferenceProfile | null = null; + if (recommendationPreferencesFetch.ok) { + const recommendationProfiles = + (await recommendationPreferencesFetch.json()) as UserRecommendationPreferenceProfile[]; + recommendationProfile = recommendationProfiles[0] ?? null; + } if (apiKeysFetch.ok) { apiKeys = (await apiKeysFetch.json()) as UserAPIKey[]; @@ -131,6 +145,7 @@ export const load: PageServerLoad = async (event) => { stravaUserEnabled, apiKeys, apiKeysConfigError, + recommendationProfile, wandererEnabled, wandererExpired } diff --git a/frontend/src/routes/settings/+page.svelte b/frontend/src/routes/settings/+page.svelte index b71d7275..338e0302 100644 --- a/frontend/src/routes/settings/+page.svelte +++ b/frontend/src/routes/settings/+page.svelte @@ -28,6 +28,16 @@ usage_required: boolean; }; + type UserRecommendationPreferenceProfile = { + id: string; + cuisines: string | null; + interests: string[]; + trip_style: string | null; + notes: string | null; + created_at: string; + updated_at: string; + }; + let new_email: string = ''; let public_url: string = data.props.publicUrl; let immichIntegration = data.props.immichIntegration; @@ -50,6 +60,13 @@ let newApiKeyValue = ''; let isSavingApiKey = false; let deletingApiKeyId: string | null = null; + let recommendationProfile: UserRecommendationPreferenceProfile | null = null; + let cuisinesValue = ''; + let interestsValue = ''; + let tripStyleValue = ''; + let notesValue = ''; + let isSavingPreferences = false; + let savePreferencesError = ''; let mcpToken: string | null = null; let isLoadingMcpToken = false; let activeSection: string = 'profile'; @@ -127,12 +144,23 @@ { id: 'emails', icon: 'πŸ“§', label: () => $t('settings.emails') }, { id: 'integrations', icon: 'πŸ”—', label: () => $t('settings.integrations') }, { id: 'ai_api_keys', icon: 'πŸ€–', label: () => $t('settings.ai_api_keys') }, + { id: 'travel_preferences', icon: '🧭', label: () => $t('settings.travel_preferences') }, { id: 'import_export', icon: 'πŸ“¦', label: () => $t('settings.backup_restore') }, { id: 'admin', icon: 'βš™οΈ', label: () => $t('settings.admin') }, { id: 'advanced', icon: 'πŸ› οΈ', label: () => $t('settings.advanced') } ]; onMount(async () => { + recommendationProfile = + (data.props as { recommendationProfile?: UserRecommendationPreferenceProfile | null }) + .recommendationProfile ?? null; + if (recommendationProfile) { + cuisinesValue = recommendationProfile.cuisines ?? ''; + interestsValue = (recommendationProfile.interests || []).join(', '); + tripStyleValue = recommendationProfile.trip_style ?? ''; + notesValue = recommendationProfile.notes ?? ''; + } + void loadProviderCatalog(); if (browser) { @@ -555,6 +583,45 @@ } } + async function savePreferences(event: SubmitEvent) { + event.preventDefault(); + savePreferencesError = ''; + isSavingPreferences = true; + + try { + const res = await fetch('/api/integrations/recommendation-preferences/', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + cuisines: cuisinesValue.trim() || null, + interests: interestsValue + .split(',') + .map((s) => s.trim()) + .filter(Boolean), + trip_style: tripStyleValue.trim() || null, + notes: notesValue.trim() || null + }) + }); + + if (!res.ok) { + savePreferencesError = $t('settings.preferences_save_error'); + addToast('error', $t('settings.preferences_save_error')); + return; + } + + recommendationProfile = (await res.json()) as UserRecommendationPreferenceProfile; + interestsValue = (recommendationProfile.interests || []).join(', '); + addToast('success', $t('settings.preferences_saved')); + } catch { + savePreferencesError = $t('settings.preferences_save_error'); + addToast('error', $t('settings.preferences_save_error')); + } finally { + isSavingPreferences = false; + } + } + function getMaskedMcpToken(token: string): string { if (token.length <= 8) { return 'β€’β€’β€’β€’β€’β€’β€’β€’'; @@ -1642,9 +1709,9 @@

{$t('settings.add_api_key')}

- + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ + {#if savePreferencesError} +
+ {savePreferencesError} +
+ {/if} + + +
+ + {/if} + {#if activeSection === 'import_export'}