feat(ai): implement agent-redesign plan with enhanced AI travel features

Phase 1 - Configuration Infrastructure (WS1):
- Add instance-level AI env vars (VOYAGE_AI_PROVIDER, VOYAGE_AI_MODEL, VOYAGE_AI_API_KEY)
- Implement fallback chain: user key → instance key → error
- Add UserAISettings model for per-user provider/model preferences
- Enhance provider catalog with instance_configured and user_configured flags
- Optimize provider catalog to avoid N+1 queries

Phase 1 - User Preference Learning (WS2):
- Add Travel Preferences tab to Settings page
- Improve preference formatting in system prompt with emoji headers
- Add multi-user preference aggregation for shared collections

Phase 2 - Day-Level Suggestions Modal (WS3):
- Create ItinerarySuggestionModal with 3-step flow (category → filters → results)
- Add AI suggestions button to itinerary Add dropdown
- Support restaurant, activity, event, and lodging categories
- Backend endpoint POST /api/chat/suggestions/day/ with context-aware prompts

Phase 3 - Collection-Level Chat Improvements (WS4):
- Inject collection context (destination, dates) into chat system prompt
- Add quick action buttons for common queries
- Add 'Add to itinerary' button on search_places results
- Update chat UI with travel-themed branding and improved tool result cards

Phase 3 - Web Search Capability (WS5):
- Add web_search agent tool using DuckDuckGo
- Support location_context parameter for biased results
- Handle rate limiting gracefully

Phase 4 - Extensibility Architecture (WS6):
- Implement decorator-based @agent_tool registry
- Convert existing tools to use decorators
- Add GET /api/chat/capabilities/ endpoint for tool discovery
- Refactor execute_tool() to use registry pattern
This commit is contained in:
2026-03-08 23:53:14 +00:00
parent 246b081d97
commit 9d5681b1ef
22 changed files with 2358 additions and 255 deletions

View File

@@ -1,4 +1,5 @@
import json
import inspect
import logging
from datetime import date as date_cls
@@ -10,117 +11,50 @@ from adventures.models import Collection, CollectionItineraryItem, Location
logger = logging.getLogger(__name__)
AGENT_TOOLS = [
{
"type": "function",
"function": {
"name": "search_places",
"description": "Search for places of interest near a location. Returns tourist attractions, restaurants, hotels, etc.",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "Location name or address to search near",
},
"category": {
"type": "string",
"enum": ["tourism", "food", "lodging"],
"description": "Category of places",
},
"radius": {
"type": "number",
"description": "Search radius in km (default 10)",
},
_REGISTERED_TOOLS = {}
_TOOL_SCHEMAS = []
def agent_tool(name: str, description: str, parameters: dict):
"""Decorator to register a function as an agent tool."""
def decorator(func):
_REGISTERED_TOOLS[name] = func
required = [k for k, v in parameters.items() if v.get("required", False)]
props = {
k: {kk: vv for kk, vv in v.items() if kk != "required"}
for k, v in parameters.items()
}
schema = {
"type": "function",
"function": {
"name": name,
"description": description,
"parameters": {
"type": "object",
"properties": props,
"required": required,
},
"required": ["location"],
},
},
},
{
"type": "function",
"function": {
"name": "list_trips",
"description": "List the user's trip collections with dates and descriptions",
"parameters": {"type": "object", "properties": {}},
},
},
{
"type": "function",
"function": {
"name": "get_trip_details",
"description": "Get full details of a trip including all itinerary items, locations, transportation, and lodging",
"parameters": {
"type": "object",
"properties": {
"collection_id": {
"type": "string",
"description": "UUID of the collection/trip",
}
},
"required": ["collection_id"],
},
},
},
{
"type": "function",
"function": {
"name": "add_to_itinerary",
"description": "Add a new location to a trip's itinerary on a specific date",
"parameters": {
"type": "object",
"properties": {
"collection_id": {
"type": "string",
"description": "UUID of the collection/trip",
},
"name": {"type": "string", "description": "Name of the location"},
"description": {
"type": "string",
"description": "Description of why to visit",
},
"latitude": {
"type": "number",
"description": "Latitude coordinate",
},
"longitude": {
"type": "number",
"description": "Longitude coordinate",
},
"date": {
"type": "string",
"description": "Date in YYYY-MM-DD format",
},
"location_address": {
"type": "string",
"description": "Full address of the location",
},
},
"required": ["collection_id", "name", "latitude", "longitude"],
},
},
},
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get temperature/weather data for a location on specific dates",
"parameters": {
"type": "object",
"properties": {
"latitude": {"type": "number", "description": "Latitude"},
"longitude": {"type": "number", "description": "Longitude"},
"dates": {
"type": "array",
"items": {"type": "string"},
"description": "List of dates in YYYY-MM-DD format",
},
},
"required": ["latitude", "longitude", "dates"],
},
},
},
]
}
_TOOL_SCHEMAS.append(schema)
return func
return decorator
def get_tool_schemas() -> list:
    """Return a shallow copy of every tool schema registered for the LLM."""
    return list(_TOOL_SCHEMAS)
def get_registered_tools() -> dict:
    """Return a shallow copy of the name -> callable tool registry."""
    return dict(_REGISTERED_TOOLS)
NOMINATIM_URL = "https://nominatim.openstreetmap.org/search"
OVERPASS_URL = "https://overpass-api.de/api/interpreter"
@@ -162,14 +96,39 @@ def _parse_address(tags):
return ", ".join([p for p in parts if p])
def search_places(user, **kwargs):
@agent_tool(
name="search_places",
description="Search for places of interest near a location. Returns tourist attractions, restaurants, hotels, etc.",
parameters={
"location": {
"type": "string",
"description": "Location name or address to search near",
"required": True,
},
"category": {
"type": "string",
"enum": ["tourism", "food", "lodging"],
"description": "Category of places",
},
"radius": {
"type": "number",
"description": "Search radius in km (default 10)",
},
},
)
def search_places(
user,
location: str | None = None,
category: str = "tourism",
radius: float = 10,
):
try:
location_name = kwargs.get("location")
location_name = location
if not location_name:
return {"error": "location is required"}
category = kwargs.get("category") or "tourism"
radius_km = float(kwargs.get("radius") or 10)
category = category or "tourism"
radius_km = float(radius or 10)
radius_meters = max(500, min(int(radius_km * 1000), 50000))
geocode_resp = requests.get(
@@ -240,7 +199,12 @@ def search_places(user, **kwargs):
return {"error": "An unexpected error occurred during place search"}
def list_trips(user, **kwargs):
@agent_tool(
name="list_trips",
description="List the user's trip collections with dates and descriptions",
parameters={},
)
def list_trips(user):
try:
collections = Collection.objects.filter(user=user).prefetch_related("locations")
trips = []
@@ -265,9 +229,87 @@ def list_trips(user, **kwargs):
return {"error": "An unexpected error occurred while listing trips"}
def get_trip_details(user, **kwargs):
@agent_tool(
    name="web_search",
    description="Search the web for current information about destinations, events, prices, weather, or any real-time travel information. Use this when you need up-to-date information that may not be in your training data.",
    parameters={
        "query": {
            "type": "string",
            "description": "The search query (e.g., 'best restaurants Paris 2024', 'weather Tokyo March')",
            "required": True,
        },
        "location_context": {
            "type": "string",
            "description": "Optional location to bias search results (e.g., 'Paris, France')",
        },
    },
)
def web_search(user, query: str, location_context: str | None = None) -> dict:
    """Search the web via DuckDuckGo for current travel-related information.

    Args:
        user: The user making the request (for auth/logging).
        query: The search query.
        location_context: Optional location appended to the query to bias
            results toward a place.

    Returns:
        dict with a 'results' list of {title, snippet, url} entries, plus an
        'error' string when the search could not be performed.
    """
    if not query:
        return {"error": "query is required", "results": []}

    try:
        # Imported lazily so the tool degrades gracefully when the optional
        # duckduckgo-search dependency is not installed.
        from duckduckgo_search import DDGS  # type: ignore[import-not-found]

        full_query = f"{query} {location_context}" if location_context else query

        with DDGS() as ddgs:
            raw_hits = list(ddgs.text(full_query, max_results=5))

        formatted = [
            {
                "title": hit.get("title", ""),
                "snippet": hit.get("body", ""),
                "url": hit.get("href", ""),
            }
            for hit in raw_hits
        ]
        return {"results": formatted}
    except ImportError:
        return {
            "error": "Web search is not available (duckduckgo-search not installed)",
            "results": [],
        }
    except Exception as exc:
        # DDG rate limiting surfaces as a generic exception; detect it from
        # the message so the user gets an actionable hint instead of a 500.
        lowered = str(exc).lower()
        if "rate" in lowered or "limit" in lowered:
            return {
                "error": "Search rate limit reached. Please wait a moment and try again.",
                "results": [],
            }
        logger.error("Web search error: %s", exc)
        return {"error": "Web search failed. Please try again.", "results": []}
@agent_tool(
name="get_trip_details",
description="Get full details of a trip including all itinerary items, locations, transportation, and lodging",
parameters={
"collection_id": {
"type": "string",
"description": "UUID of the collection/trip",
"required": True,
}
},
)
def get_trip_details(user, collection_id: str | None = None):
try:
collection_id = kwargs.get("collection_id")
if not collection_id:
return {"error": "collection_id is required"}
@@ -354,16 +396,55 @@ def get_trip_details(user, **kwargs):
return {"error": "An unexpected error occurred while fetching trip details"}
def add_to_itinerary(user, **kwargs):
@agent_tool(
name="add_to_itinerary",
description="Add a new location to a trip's itinerary on a specific date",
parameters={
"collection_id": {
"type": "string",
"description": "UUID of the collection/trip",
"required": True,
},
"name": {
"type": "string",
"description": "Name of the location",
"required": True,
},
"description": {
"type": "string",
"description": "Description of why to visit",
},
"latitude": {
"type": "number",
"description": "Latitude coordinate",
"required": True,
},
"longitude": {
"type": "number",
"description": "Longitude coordinate",
"required": True,
},
"date": {
"type": "string",
"description": "Date in YYYY-MM-DD format",
},
"location_address": {
"type": "string",
"description": "Full address of the location",
},
},
)
def add_to_itinerary(
user,
collection_id: str | None = None,
name: str | None = None,
latitude: float | None = None,
longitude: float | None = None,
description: str | None = None,
date: str | None = None,
location_address: str | None = None,
):
try:
collection_id = kwargs.get("collection_id")
name = kwargs.get("name")
latitude = kwargs.get("latitude")
longitude = kwargs.get("longitude")
description = kwargs.get("description")
location_address = kwargs.get("location_address")
date = kwargs.get("date")
if not collection_id or not name or latitude is None or longitude is None:
return {
"error": "collection_id, name, latitude, and longitude are required"
@@ -479,16 +560,34 @@ def _fetch_temperature_for_date(latitude, longitude, date_value):
}
def get_weather(user, **kwargs):
@agent_tool(
name="get_weather",
description="Get temperature/weather data for a location on specific dates",
parameters={
"latitude": {"type": "number", "description": "Latitude", "required": True},
"longitude": {
"type": "number",
"description": "Longitude",
"required": True,
},
"dates": {
"type": "array",
"items": {"type": "string"},
"description": "List of dates in YYYY-MM-DD format",
"required": True,
},
},
)
def get_weather(user, latitude=None, longitude=None, dates=None):
try:
raw_latitude = kwargs.get("latitude")
raw_longitude = kwargs.get("longitude")
raw_latitude = latitude
raw_longitude = longitude
if raw_latitude is None or raw_longitude is None:
return {"error": "latitude and longitude are required"}
latitude = float(raw_latitude)
longitude = float(raw_longitude)
dates = kwargs.get("dates") or []
dates = dates or []
if not isinstance(dates, list) or not dates:
return {"error": "dates must be a non-empty list"}
@@ -509,44 +608,24 @@ def get_weather(user, **kwargs):
return {"error": "An unexpected error occurred while fetching weather data"}
ALLOWED_KWARGS = {
"search_places": {"location", "category", "radius"},
"list_trips": set(),
"get_trip_details": {"collection_id"},
"add_to_itinerary": {
"collection_id",
"name",
"description",
"latitude",
"longitude",
"date",
"location_address",
},
"get_weather": {"latitude", "longitude", "dates"},
}
def execute_tool(tool_name, user, **kwargs):
tool_map = {
"search_places": search_places,
"list_trips": list_trips,
"get_trip_details": get_trip_details,
"add_to_itinerary": add_to_itinerary,
"get_weather": get_weather,
}
tool_fn = tool_map.get(tool_name)
if not tool_fn:
if tool_name not in _REGISTERED_TOOLS:
return {"error": f"Unknown tool: {tool_name}"}
allowed = ALLOWED_KWARGS.get(tool_name, set())
tool_fn = _REGISTERED_TOOLS[tool_name]
sig = inspect.signature(tool_fn)
allowed = set(sig.parameters.keys()) - {"user"}
filtered_kwargs = {k: v for k, v in kwargs.items() if k in allowed}
try:
return tool_fn(user, **filtered_kwargs)
except Exception:
logger.exception("Tool execution failed: %s", tool_name)
return {"error": "An unexpected error occurred while executing the tool"}
return tool_fn(user=user, **filtered_kwargs)
except Exception as exc:
logger.exception("Tool %s failed", tool_name)
return {"error": str(exc)}
AGENT_TOOLS = get_tool_schemas()
def serialize_tool_result(result):

View File

@@ -2,11 +2,23 @@ import json
import logging
import litellm
from django.conf import settings
from integrations.models import UserAPIKey
logger = logging.getLogger(__name__)
PROVIDER_MODEL_PREFIX = {
"openai": "openai",
"anthropic": "anthropic",
"gemini": "gemini",
"ollama": "ollama",
"groq": "groq",
"mistral": "mistral",
"github_models": "github",
"openrouter": "openrouter",
}
CHAT_PROVIDER_CONFIG = {
"openai": {
"label": "OpenAI",
@@ -59,12 +71,83 @@ CHAT_PROVIDER_CONFIG = {
"opencode_zen": {
"label": "OpenCode Zen",
"needs_api_key": True,
"default_model": "openai/gpt-4o-mini",
# Chosen from OpenCode Zen compatible OpenAI-routed models per
# opencode_zen connection research (see .memory/research).
"default_model": "openai/gpt-5-nano",
"api_base": "https://opencode.ai/zen/v1",
},
}
def _is_model_override_compatible(provider, provider_config, model):
"""Validate model/provider compatibility when strict checks are safe.
For providers with a custom api_base gateway, skip strict prefix checks since
gateway routing may legitimately accept cross-provider prefixes.
"""
if not model or provider_config.get("api_base"):
return True
if "/" not in model:
return True
expected_prefix = PROVIDER_MODEL_PREFIX.get(provider)
if not expected_prefix:
default_model = provider_config.get("default_model") or ""
if "/" in default_model:
expected_prefix = default_model.split("/", 1)[0]
if not expected_prefix:
return True
return model.startswith(f"{expected_prefix}/")
def _safe_error_payload(exc):
    """Translate a LiteLLM exception into a safe, user-facing error payload.

    Exception classes absent from the installed litellm version degrade to
    empty tuples, which isinstance() treats as "never matches".
    """
    exceptions = getattr(litellm, "exceptions", None)

    def lookup(name):
        # Missing class -> empty tuple so the isinstance check is a no-op.
        return getattr(exceptions, name, tuple())

    # Checked in order; first matching class wins.
    categorized = [
        (
            lookup("NotFoundError"),
            "model_not_found",
            "The selected model is unavailable for this provider. Choose a different model and try again.",
        ),
        (
            lookup("AuthenticationError"),
            "authentication_failed",
            "Authentication with the model provider failed. Verify your API key in Settings and try again.",
        ),
        (
            lookup("RateLimitError"),
            "rate_limited",
            "The model provider rate limit was reached. Please wait and try again.",
        ),
        (
            lookup("BadRequestError"),
            "invalid_request",
            "The model provider rejected this request. Check your selected model and try again.",
        ),
        (
            (lookup("Timeout"), lookup("APIConnectionError")),
            "provider_unreachable",
            "Unable to reach the model provider right now. Please try again.",
        ),
    ]
    for classes, category, message in categorized:
        if isinstance(exc, classes):
            return {"error": message, "error_category": category}
    return {
        "error": "An error occurred while processing your request. Please try again.",
        "error_category": "unknown_error",
    }
def _safe_get(obj, key, default=None):
if obj is None:
return default
@@ -90,9 +173,20 @@ def is_chat_provider_available(provider_id):
return normalized_provider in CHAT_PROVIDER_CONFIG
def get_provider_catalog():
def get_provider_catalog(user=None):
seen = set()
catalog = []
user_key_providers = set()
instance_provider = (
_normalize_provider_id(settings.VOYAGE_AI_PROVIDER)
if settings.VOYAGE_AI_PROVIDER
else None
)
instance_has_key = bool(settings.VOYAGE_AI_API_KEY)
if user:
user_key_providers = set(
UserAPIKey.objects.filter(user=user).values_list("provider", flat=True)
)
for provider_id in getattr(litellm, "provider_list", []):
normalized_provider = _normalize_provider_id(provider_id)
@@ -110,6 +204,9 @@ def get_provider_catalog():
"needs_api_key": provider_config["needs_api_key"],
"default_model": provider_config["default_model"],
"api_base": provider_config["api_base"],
"instance_configured": instance_has_key
and normalized_provider == instance_provider,
"user_configured": normalized_provider in user_key_providers,
}
)
continue
@@ -122,6 +219,9 @@ def get_provider_catalog():
"needs_api_key": None,
"default_model": None,
"api_base": None,
"instance_configured": instance_has_key
and normalized_provider == instance_provider,
"user_configured": normalized_provider in user_key_providers,
}
)
@@ -141,6 +241,9 @@ def get_provider_catalog():
"needs_api_key": provider_config["needs_api_key"],
"default_model": provider_config["default_model"],
"api_base": provider_config["api_base"],
"instance_configured": instance_has_key
and normalized_provider == instance_provider,
"user_configured": normalized_provider in user_key_providers,
}
)
@@ -154,9 +257,55 @@ def get_llm_api_key(user, provider):
key_obj = UserAPIKey.objects.get(user=user, provider=normalized_provider)
return key_obj.get_api_key()
except UserAPIKey.DoesNotExist:
if normalized_provider == _normalize_provider_id(settings.VOYAGE_AI_PROVIDER):
instance_api_key = (settings.VOYAGE_AI_API_KEY or "").strip()
if instance_api_key:
return instance_api_key
return None
def _format_interests(interests):
if isinstance(interests, list):
return ", ".join(interests)
return interests
def get_aggregated_preferences(collection):
    """Build a "Party Preferences" prompt section covering every traveler.

    Combines the preference profile of the collection owner with those of all
    users the collection is shared with. Travelers without a profile are
    silently skipped. Returns an empty string when nobody has preferences.
    """
    from integrations.models import UserRecommendationPreferenceProfile

    travelers = [collection.user, *collection.shared_with.all()]
    lines = []
    for traveler in travelers:
        try:
            profile = UserRecommendationPreferenceProfile.objects.get(user=traveler)
        except UserRecommendationPreferenceProfile.DoesNotExist:
            continue
        # Pair each populated profile field with its prompt label.
        fragments = []
        if profile.cuisines:
            fragments.append(f"cuisines: {profile.cuisines}")
        if profile.interests:
            fragments.append(f"interests: {_format_interests(profile.interests)}")
        if profile.trip_style:
            fragments.append(f"style: {profile.trip_style}")
        if profile.notes:
            fragments.append(f"notes: {profile.notes}")
        if fragments:
            lines.append(f"- **{traveler.username}**: {', '.join(fragments)}")
    if not lines:
        return ""
    return (
        "\n\n## Party Preferences\n"
        + "\n".join(lines)
        + "\n\nNote: Consider all travelers' preferences when making recommendations."
    )
def get_system_prompt(user, collection=None):
"""Build the system prompt with user context."""
from integrations.models import UserRecommendationPreferenceProfile
@@ -181,26 +330,37 @@ When modifying itineraries:
Be conversational, helpful, and enthusiastic about travel. Keep responses concise but informative."""
try:
profile = UserRecommendationPreferenceProfile.objects.get(user=user)
prefs = []
if profile.cuisines:
prefs.append(f"Cuisine preferences: {profile.cuisines}")
if profile.interests:
prefs.append(f"Interests: {profile.interests}")
if profile.trip_style:
prefs.append(f"Travel style: {profile.trip_style}")
if profile.notes:
prefs.append(f"Additional notes: {profile.notes}")
if prefs:
base_prompt += "\n\nUser preferences:\n" + "\n".join(prefs)
except UserRecommendationPreferenceProfile.DoesNotExist:
pass
if collection and collection.shared_with.count() > 0:
base_prompt += get_aggregated_preferences(collection)
else:
try:
profile = UserRecommendationPreferenceProfile.objects.get(user=user)
preference_lines = []
if profile.cuisines:
preference_lines.append(
f"🍽️ **Cuisine Preferences**: {profile.cuisines}"
)
if profile.interests:
preference_lines.append(
f"🎯 **Interests**: {_format_interests(profile.interests)}"
)
if profile.trip_style:
preference_lines.append(f"✈️ **Travel Style**: {profile.trip_style}")
if profile.notes:
preference_lines.append(f"📝 **Additional Notes**: {profile.notes}")
if preference_lines:
base_prompt += "\n\n## Traveler Preferences\n" + "\n".join(
preference_lines
)
except UserRecommendationPreferenceProfile.DoesNotExist:
pass
return base_prompt
async def stream_chat_completion(user, messages, provider, tools=None):
async def stream_chat_completion(user, messages, provider, tools=None, model=None):
"""Stream a chat completion using LiteLLM.
Yields SSE-formatted strings.
@@ -215,6 +375,7 @@ async def stream_chat_completion(user, messages, provider, tools=None):
return
api_key = get_llm_api_key(user, normalized_provider)
if provider_config["needs_api_key"] and not api_key:
payload = {
"error": f"No API key found for provider: {normalized_provider}. Please add one in Settings."
@@ -222,14 +383,31 @@ async def stream_chat_completion(user, messages, provider, tools=None):
yield f"data: {json.dumps(payload)}\n\n"
return
if not _is_model_override_compatible(normalized_provider, provider_config, model):
payload = {
"error": "The selected model is incompatible with this provider. Choose a model for the selected provider and try again.",
"error_category": "invalid_model_for_provider",
}
yield f"data: {json.dumps(payload)}\n\n"
return
completion_kwargs = {
"model": provider_config["default_model"],
"model": model
or (
settings.VOYAGE_AI_MODEL
if normalized_provider
== _normalize_provider_id(settings.VOYAGE_AI_PROVIDER)
and settings.VOYAGE_AI_MODEL
else None
)
or provider_config["default_model"],
"messages": messages,
"tools": tools,
"tool_choice": "auto" if tools else None,
"stream": True,
"api_key": api_key,
}
if tools:
completion_kwargs["tools"] = tools
completion_kwargs["tool_choice"] = "auto"
if provider_config["api_base"]:
completion_kwargs["api_base"] = provider_config["api_base"]
@@ -271,6 +449,7 @@ async def stream_chat_completion(user, messages, provider, tools=None):
yield f"data: {json.dumps(chunk_data)}\n\n"
yield "data: [DONE]\n\n"
except Exception:
except Exception as exc:
logger.exception("LLM streaming error")
yield f"data: {json.dumps({'error': 'An error occurred while processing your request. Please try again.'})}\n\n"
payload = _safe_error_payload(exc)
yield f"data: {json.dumps(payload)}\n\n"

View File

@@ -1,7 +1,12 @@
from django.urls import include, path
from rest_framework.routers import DefaultRouter
from .views import ChatProviderCatalogViewSet, ChatViewSet
from .views import (
CapabilitiesView,
ChatProviderCatalogViewSet,
ChatViewSet,
DaySuggestionsView,
)
router = DefaultRouter()
router.register(r"conversations", ChatViewSet, basename="chat-conversation")
@@ -11,4 +16,6 @@ router.register(
urlpatterns = [
path("", include(router.urls)),
path("capabilities/", CapabilitiesView.as_view(), name="chat-capabilities"),
path("suggestions/day/", DaySuggestionsView.as_view(), name="chat-day-suggestions"),
]

View File

@@ -0,0 +1,328 @@
import asyncio
import json
from asgiref.sync import sync_to_async
from adventures.models import Collection
from django.http import StreamingHttpResponse
from rest_framework import status, viewsets
from rest_framework.decorators import action
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from ..agent_tools import AGENT_TOOLS, execute_tool, serialize_tool_result
from ..llm_client import (
get_provider_catalog,
get_system_prompt,
is_chat_provider_available,
stream_chat_completion,
)
from ..models import ChatConversation, ChatMessage
from ..serializers import ChatConversationSerializer
class ChatViewSet(viewsets.ModelViewSet):
    """CRUD plus SSE streaming chat for a user's AI conversations.

    The `send_message` action streams the assistant reply as Server-Sent
    Events and runs an agent loop: when the model requests tool calls, the
    tools are executed server-side, their results are appended to the
    conversation, and the model is re-invoked — up to MAX_TOOL_ITERATIONS
    rounds.

    Bug fix vs. previous revision: `tool_iterations` was initialized and
    compared against MAX_TOOL_ITERATIONS but never incremented, so the cap
    could never fire and a model that kept requesting tools would loop
    forever. The counter is now incremented each round, and a terminal
    `[DONE]` event is emitted if the budget is exhausted so SSE clients do
    not hang waiting for more data.
    """

    serializer_class = ChatConversationSerializer
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        # Users only ever see their own conversations; prefetch avoids N+1
        # queries when message history is serialized.
        return ChatConversation.objects.filter(user=self.request.user).prefetch_related(
            "messages"
        )

    def list(self, request, *args, **kwargs):
        """Return a lightweight summary (id, title, updated_at) per conversation."""
        conversations = self.get_queryset().only("id", "title", "updated_at")
        data = [
            {
                "id": str(conversation.id),
                "title": conversation.title,
                "updated_at": conversation.updated_at,
            }
            for conversation in conversations
        ]
        return Response(data)

    def create(self, request, *args, **kwargs):
        """Create an empty conversation for the caller (title optional)."""
        conversation = ChatConversation.objects.create(
            user=request.user,
            title=(request.data.get("title") or "").strip(),
        )
        serializer = self.get_serializer(conversation)
        return Response(serializer.data, status=status.HTTP_201_CREATED)

    def _build_llm_messages(self, conversation, user, system_prompt=None):
        """Convert stored chat history into the LLM chat-message format.

        Assistant messages carry their recorded tool_calls and tool messages
        carry tool_call_id/name so the provider can correlate tool results.
        """
        messages = [
            {
                "role": "system",
                "content": system_prompt or get_system_prompt(user),
            }
        ]
        for message in conversation.messages.all().order_by("created_at"):
            payload = {
                "role": message.role,
                "content": message.content,
            }
            if message.role == "assistant" and message.tool_calls:
                payload["tool_calls"] = message.tool_calls
            if message.role == "tool":
                payload["tool_call_id"] = message.tool_call_id
                payload["name"] = message.name
            messages.append(payload)
        return messages

    def _async_to_sync_generator(self, async_gen):
        """Drive an async generator from sync code for StreamingHttpResponse."""
        loop = asyncio.new_event_loop()
        try:
            while True:
                try:
                    yield loop.run_until_complete(async_gen.__anext__())
                except StopAsyncIteration:
                    break
        finally:
            # Finalize any pending async generators before closing the loop.
            loop.run_until_complete(loop.shutdown_asyncgens())
            loop.close()

    @staticmethod
    def _merge_tool_call_delta(accumulator, tool_calls_delta):
        """Merge streamed tool-call deltas into `accumulator` in place.

        Providers stream tool calls in fragments keyed by index; names
        overwrite, argument strings concatenate.
        """
        for idx, tool_call in enumerate(tool_calls_delta or []):
            idx = tool_call.get("index", idx)
            while len(accumulator) <= idx:
                accumulator.append(
                    {
                        "id": None,
                        "type": "function",
                        "function": {"name": "", "arguments": ""},
                    }
                )
            current = accumulator[idx]
            if tool_call.get("id"):
                current["id"] = tool_call.get("id")
            if tool_call.get("type"):
                current["type"] = tool_call.get("type")
            function_data = tool_call.get("function") or {}
            if function_data.get("name"):
                current["function"]["name"] = function_data.get("name")
            if function_data.get("arguments"):
                current["function"]["arguments"] += function_data.get("arguments")

    @action(detail=True, methods=["post"])
    def send_message(self, request, pk=None):
        """Append a user message and stream the assistant reply as SSE.

        Optional request fields (collection_id/name, dates, destination) are
        folded into the system prompt as trip context. Tool calls requested
        by the model are executed and fed back until the model produces a
        final answer or the tool budget runs out.
        """
        conversation = self.get_object()
        user_content = (request.data.get("message") or "").strip()
        if not user_content:
            return Response(
                {"error": "message is required"},
                status=status.HTTP_400_BAD_REQUEST,
            )

        provider = (request.data.get("provider") or "openai").strip().lower()
        model = (request.data.get("model") or "").strip() or None
        collection_id = request.data.get("collection_id")
        collection_name = request.data.get("collection_name")
        start_date = request.data.get("start_date")
        end_date = request.data.get("end_date")
        destination = request.data.get("destination")

        if not is_chat_provider_available(provider):
            return Response(
                {"error": f"Provider is not available for chat: {provider}."},
                status=status.HTTP_400_BAD_REQUEST,
            )

        context_parts = []
        if collection_name:
            context_parts.append(f"Trip: {collection_name}")
        if destination:
            context_parts.append(f"Destination: {destination}")
        if start_date and end_date:
            context_parts.append(f"Dates: {start_date} to {end_date}")

        # Only attach the collection when the caller owns it or it is shared
        # with them; otherwise silently ignore the id.
        collection = None
        if collection_id:
            try:
                requested_collection = Collection.objects.get(id=collection_id)
                if (
                    requested_collection.user == request.user
                    or requested_collection.shared_with.filter(
                        id=request.user.id
                    ).exists()
                ):
                    collection = requested_collection
            except Collection.DoesNotExist:
                pass

        system_prompt = get_system_prompt(request.user, collection)
        if context_parts:
            system_prompt += "\n\n## Trip Context\n" + "\n".join(context_parts)

        ChatMessage.objects.create(
            conversation=conversation,
            role="user",
            content=user_content,
        )
        conversation.save(update_fields=["updated_at"])
        if not conversation.title:
            # First message doubles as the conversation title.
            conversation.title = user_content[:120]
            conversation.save(update_fields=["title", "updated_at"])

        llm_messages = self._build_llm_messages(
            conversation,
            request.user,
            system_prompt=system_prompt,
        )

        MAX_TOOL_ITERATIONS = 10

        async def event_stream():
            current_messages = list(llm_messages)
            encountered_error = False
            tool_iterations = 0
            while tool_iterations < MAX_TOOL_ITERATIONS:
                # FIX: increment the round counter; previously it was never
                # incremented, so the MAX_TOOL_ITERATIONS cap could not fire.
                tool_iterations += 1
                content_chunks = []
                tool_calls_accumulator = []
                async for chunk in stream_chat_completion(
                    request.user,
                    current_messages,
                    provider,
                    tools=AGENT_TOOLS,
                    model=model,
                ):
                    if not chunk.startswith("data: "):
                        yield chunk
                        continue
                    payload = chunk[len("data: ") :].strip()
                    if payload == "[DONE]":
                        # Suppress per-round DONE; we emit our own terminal one.
                        continue
                    yield chunk
                    try:
                        data = json.loads(payload)
                    except json.JSONDecodeError:
                        continue
                    if data.get("error"):
                        encountered_error = True
                        break
                    if data.get("content"):
                        content_chunks.append(data["content"])
                    if data.get("tool_calls"):
                        self._merge_tool_call_delta(
                            tool_calls_accumulator,
                            data["tool_calls"],
                        )
                if encountered_error:
                    break

                assistant_content = "".join(content_chunks)
                if tool_calls_accumulator:
                    # Persist the assistant turn with its tool calls, execute
                    # each tool, persist the tool results, then loop to let
                    # the model continue with the results in context.
                    assistant_with_tools = {
                        "role": "assistant",
                        "content": assistant_content,
                        "tool_calls": tool_calls_accumulator,
                    }
                    current_messages.append(assistant_with_tools)
                    await sync_to_async(
                        ChatMessage.objects.create, thread_sensitive=True
                    )(
                        conversation=conversation,
                        role="assistant",
                        content=assistant_content,
                        tool_calls=tool_calls_accumulator,
                    )
                    await sync_to_async(conversation.save, thread_sensitive=True)(
                        update_fields=["updated_at"]
                    )

                    for tool_call in tool_calls_accumulator:
                        function_payload = tool_call.get("function") or {}
                        function_name = function_payload.get("name") or ""
                        raw_arguments = function_payload.get("arguments") or "{}"
                        try:
                            arguments = json.loads(raw_arguments)
                        except json.JSONDecodeError:
                            arguments = {}
                        if not isinstance(arguments, dict):
                            arguments = {}
                        result = await sync_to_async(
                            execute_tool, thread_sensitive=True
                        )(
                            function_name,
                            request.user,
                            **arguments,
                        )
                        result_content = serialize_tool_result(result)
                        current_messages.append(
                            {
                                "role": "tool",
                                "tool_call_id": tool_call.get("id"),
                                "name": function_name,
                                "content": result_content,
                            }
                        )
                        await sync_to_async(
                            ChatMessage.objects.create, thread_sensitive=True
                        )(
                            conversation=conversation,
                            role="tool",
                            content=result_content,
                            tool_call_id=tool_call.get("id"),
                            name=function_name,
                        )
                        await sync_to_async(conversation.save, thread_sensitive=True)(
                            update_fields=["updated_at"]
                        )
                        tool_event = {
                            "tool_result": {
                                "tool_call_id": tool_call.get("id"),
                                "name": function_name,
                                "result": result,
                            }
                        }
                        yield f"data: {json.dumps(tool_event)}\n\n"
                    continue

                # Final answer without tool calls: persist and finish.
                await sync_to_async(ChatMessage.objects.create, thread_sensitive=True)(
                    conversation=conversation,
                    role="assistant",
                    content=assistant_content,
                )
                await sync_to_async(conversation.save, thread_sensitive=True)(
                    update_fields=["updated_at"]
                )
                yield "data: [DONE]\n\n"
                break
            else:
                # Tool budget exhausted without a final answer — terminate
                # the stream so SSE clients do not hang.
                yield "data: [DONE]\n\n"

        response = StreamingHttpResponse(
            streaming_content=self._async_to_sync_generator(event_stream()),
            content_type="text/event-stream",
        )
        response["Cache-Control"] = "no-cache"
        # Disable proxy buffering (nginx) so events flush immediately.
        response["X-Accel-Buffering"] = "no"
        return response
class ChatProviderCatalogViewSet(viewsets.ViewSet):
    """Read-only endpoint exposing the chat provider catalog."""

    permission_classes = [IsAuthenticated]

    def list(self, request):
        """Return providers annotated with the caller's configuration state."""
        catalog = get_provider_catalog(user=request.user)
        return Response(catalog)
from .capabilities import CapabilitiesView
from .day_suggestions import DaySuggestionsView

View File

@@ -0,0 +1,24 @@
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework.views import APIView
from chat.agent_tools import get_tool_schemas
class CapabilitiesView(APIView):
    """Expose the AI agent's registered tools to authenticated clients."""

    permission_classes = [IsAuthenticated]

    def get(self, request):
        """Return the name and description of every registered agent tool."""
        capabilities = []
        for schema in get_tool_schemas():
            fn = schema["function"]
            capabilities.append(
                {"name": fn["name"], "description": fn["description"]}
            )
        return Response({"tools": capabilities})

View File

@@ -0,0 +1,215 @@
import json
import logging
import re

import litellm
from django.conf import settings
from django.shortcuts import get_object_or_404
from rest_framework import status
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework.views import APIView

from adventures.models import Collection
from chat.agent_tools import search_places
from chat.llm_client import (
    get_llm_api_key,
    get_system_prompt,
    is_chat_provider_available,
)
class DaySuggestionsView(APIView):
    """Generate AI place suggestions for one day of a collection itinerary.

    POST body:
        collection_id: UUID of the target Collection (required).
        date: ISO date string the suggestions are for (required).
        category: one of "restaurant", "activity", "event", "lodging" (required).
        filters: optional dict of category-specific preferences.
        location_context: optional location string overriding the location
            derived from the collection's saved locations.

    Responses:
        200 {"suggestions": [...]} on success (list may be empty),
        400 on missing or invalid parameters,
        403 when the user neither owns nor is shared on the collection,
        503 when no chat provider/API key is configured,
        500 when suggestion generation fails unexpectedly.
    """

    permission_classes = [IsAuthenticated]

    VALID_CATEGORIES = ("restaurant", "activity", "event", "lodging")

    # Scalar filter keys and the label used for each in the LLM prompt.
    _FILTER_LABELS = (
        ("cuisine_type", "cuisine type"),
        ("price_range", "price range"),
        ("dietary", "dietary restrictions"),
        ("activity_type", "type"),
        ("duration", "duration"),
        ("event_type", "event type"),
        ("lodging_type", "lodging type"),
    )

    def post(self, request):
        collection_id = request.data.get("collection_id")
        date = request.data.get("date")
        category = request.data.get("category")
        filters = request.data.get("filters", {}) or {}
        location_context = request.data.get("location_context", "")

        if not all([collection_id, date, category]):
            return Response(
                {"error": "collection_id, date, and category are required"},
                status=status.HTTP_400_BAD_REQUEST,
            )
        if category not in self.VALID_CATEGORIES:
            return Response(
                {"error": f"category must be one of: {', '.join(self.VALID_CATEGORIES)}"},
                status=status.HTTP_400_BAD_REQUEST,
            )

        collection = get_object_or_404(Collection, id=collection_id)
        # Owners and shared-with users may both request suggestions.
        if (
            collection.user != request.user
            and not collection.shared_with.filter(id=request.user.id).exists()
        ):
            return Response(
                {"error": "You don't have access to this collection"},
                status=status.HTTP_403_FORBIDDEN,
            )

        location = location_context or self._get_collection_location(collection)
        system_prompt = get_system_prompt(request.user, collection)

        # Use the instance-configured provider rather than hard-coding it;
        # falls back to "openai", matching VOYAGE_AI_PROVIDER's default.
        provider = getattr(settings, "VOYAGE_AI_PROVIDER", "openai") or "openai"
        if not is_chat_provider_available(provider):
            return Response(
                {
                    "error": "AI suggestions are not available. Please configure an API key."
                },
                status=status.HTTP_503_SERVICE_UNAVAILABLE,
            )

        try:
            places_context = self._get_places_context(request.user, category, location)
            prompt = self._build_prompt(
                category=category,
                filters=filters,
                location=location,
                date=date,
                collection=collection,
                places_context=places_context,
            )
            suggestions = self._get_suggestions_from_llm(
                system_prompt=system_prompt,
                user_prompt=prompt,
                user=request.user,
                provider=provider,
            )
            return Response({"suggestions": suggestions}, status=status.HTTP_200_OK)
        except Exception:
            # Log the full traceback for operators; keep the client-facing
            # message generic so provider details do not leak.
            logging.getLogger(__name__).exception(
                "Failed to generate day suggestions for collection %s",
                collection_id,
            )
            return Response(
                {"error": "Failed to generate suggestions. Please try again."},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )

    def _get_collection_location(self, collection):
        """Derive a human-readable location string from the collection.

        For each saved location, prefers "City, Country", then the
        free-text location field, then the name; falls back to a generic
        placeholder when nothing usable is found.
        """
        for loc in collection.locations.select_related("city", "country").all():
            if loc.city:
                city_name = getattr(loc.city, "name", str(loc.city))
                country_name = getattr(loc.country, "name", "") if loc.country else ""
                return ", ".join([x for x in [city_name, country_name] if x])
            if loc.location:
                return loc.location
            if loc.name:
                return loc.name
        return "Unknown location"

    def _build_prompt(
        self,
        category,
        filters,
        location,
        date,
        collection,
        places_context="",
    ):
        """Compose the user prompt sent to the LLM for this suggestion run."""
        category_prompts = {
            "restaurant": f"Find restaurant recommendations for {location}",
            "activity": f"Find activity recommendations for {location}",
            "event": f"Find event recommendations for {location} around {date}",
            "lodging": f"Find lodging recommendations for {location}",
        }
        prompt = category_prompts.get(
            category, f"Find {category} recommendations for {location}"
        )
        if filters:
            filter_parts = [
                f"{label}: {filters[key]}"
                for key, label in self._FILTER_LABELS
                if filters.get(key)
            ]
            amenities = filters.get("amenities")
            if isinstance(amenities, list) and amenities:
                filter_parts.append(
                    f"amenities: {', '.join(str(x) for x in amenities)}"
                )
            if filter_parts:
                prompt += f" with these preferences: {', '.join(filter_parts)}"
        prompt += f". The trip date is {date}."
        if collection.start_date or collection.end_date:
            prompt += (
                " Collection trip window: "
                f"{collection.start_date or 'unknown'} to {collection.end_date or 'unknown'}."
            )
        if places_context:
            prompt += f" Nearby place context: {places_context}."
        # Instruct the model to emit bare JSON so _get_suggestions_from_llm
        # can parse the reply without stripping markdown fences.
        prompt += (
            " Return 3-5 specific suggestions as a JSON array."
            " Each suggestion should have: name, description, why_fits, category, location, rating, price_level."
            " Return ONLY valid JSON, no markdown, no surrounding text."
        )
        return prompt

    def _get_places_context(self, user, category, location):
        """Look up nearby places to ground the LLM's suggestions.

        Returns a semicolon-joined "Name (address)" string of up to five
        results, or "" when the place search fails or returns an error.
        """
        tool_category_map = {
            "restaurant": "food",
            "activity": "tourism",
            "event": "tourism",
            "lodging": "lodging",
        }
        result = search_places(
            user,
            location=location,
            category=tool_category_map.get(category, "tourism"),
            radius=8,
        )
        if result.get("error"):
            return ""
        entries = []
        for place in result.get("results", [])[:5]:
            name = place.get("name")
            address = place.get("address") or ""
            if name:
                entries.append(f"{name} ({address})" if address else name)
        return "; ".join(entries)

    def _get_suggestions_from_llm(self, system_prompt, user_prompt, user, provider):
        """Call the configured LLM and parse its JSON-array response.

        Raises ValueError when no API key is available for *provider*.
        Returns at most five suggestions; an unparseable reply yields [].
        """
        api_key = get_llm_api_key(user, provider)
        if not api_key:
            raise ValueError("No API key available")
        # Use the instance-configured model; defaults to "gpt-4o-mini",
        # matching VOYAGE_AI_MODEL's default.
        model = getattr(settings, "VOYAGE_AI_MODEL", "gpt-4o-mini") or "gpt-4o-mini"
        response = litellm.completion(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            api_key=api_key,
            temperature=0.7,
            max_tokens=1000,
        )
        content = (response.choices[0].message.content or "").strip()
        try:
            # Best-effort extraction: grab the outermost [...] span in case
            # the model wrapped the JSON in prose despite instructions.
            json_match = re.search(r"\[.*\]", content, re.DOTALL)
            parsed = (
                json.loads(json_match.group())
                if json_match
                else json.loads(content or "[]")
            )
            suggestions = parsed if isinstance(parsed, list) else [parsed]
            return suggestions[:5]
        except json.JSONDecodeError:
            return []

View File

@@ -0,0 +1,52 @@
# Generated by Django 5.2.12 on 2026-03-08
import django.db.models.deletion
import uuid
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
    """Create the UserAISettings table (one row per user).

    Mirrors integrations.models.UserAISettings: UUID primary key,
    OneToOne to the user model, and nullable provider/model overrides.
    """

    dependencies = [
        ("integrations", "0007_userapikey_userrecommendationpreferenceprofile"),
        # Resolves AUTH_USER_MODEL even when a custom user model is used.
        migrations.swappable_dependency(settings.AUTH_USER_MODEL),
    ]
    operations = [
        migrations.CreateModel(
            name="UserAISettings",
            fields=[
                (
                    "id",
                    models.UUIDField(
                        default=uuid.uuid4,
                        editable=False,
                        primary_key=True,
                        serialize=False,
                    ),
                ),
                (
                    "preferred_provider",
                    models.CharField(blank=True, max_length=100, null=True),
                ),
                (
                    "preferred_model",
                    models.CharField(blank=True, max_length=100, null=True),
                ),
                ("created_at", models.DateTimeField(auto_now_add=True)),
                ("updated_at", models.DateTimeField(auto_now=True)),
                (
                    # OneToOne enforces at most one settings row per user;
                    # rows are removed when the owning user is deleted.
                    "user",
                    models.OneToOneField(
                        on_delete=django.db.models.deletion.CASCADE,
                        related_name="ai_settings",
                        to=settings.AUTH_USER_MODEL,
                    ),
                ),
            ],
            options={
                "verbose_name": "User AI Settings",
                "verbose_name_plural": "User AI Settings",
            },
        ),
    ]

View File

@@ -124,3 +124,23 @@ class UserRecommendationPreferenceProfile(models.Model):
notes = models.TextField(blank=True, null=True)
created_at = models.DateTimeField(auto_now_add=True)
updated_at = models.DateTimeField(auto_now=True)
class UserAISettings(models.Model):
    """Per-user AI chat preferences (provider and model overrides).

    At most one row exists per user (enforced by the OneToOne field).
    Blank/null values presumably mean "use the instance-level defaults"
    (VOYAGE_AI_PROVIDER / VOYAGE_AI_MODEL) — confirm against the
    llm_client fallback chain.
    """

    # UUID primary key instead of an auto-incrementing integer.
    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
    # Rows are deleted along with their owning user.
    user = models.OneToOneField(
        settings.AUTH_USER_MODEL,
        on_delete=models.CASCADE,
        related_name="ai_settings",
    )
    # Optional provider identifier (e.g. a key from the provider catalog).
    preferred_provider = models.CharField(max_length=100, blank=True, null=True)
    # Optional model name to use with the preferred provider.
    preferred_model = models.CharField(max_length=100, blank=True, null=True)
    created_at = models.DateTimeField(auto_now_add=True)
    updated_at = models.DateTimeField(auto_now=True)

    class Meta:
        verbose_name = "User AI Settings"
        verbose_name_plural = "User AI Settings"

    def __str__(self):
        return f"AI Settings for {self.user.username}"

View File

@@ -3,6 +3,7 @@ from django.db import IntegrityError
from .models import (
EncryptionConfigurationError,
ImmichIntegration,
UserAISettings,
UserAPIKey,
UserRecommendationPreferenceProfile,
)
@@ -98,3 +99,16 @@ class UserRecommendationPreferenceProfileSerializer(serializers.ModelSerializer)
"updated_at",
]
read_only_fields = ["id", "created_at", "updated_at"]
class UserAISettingsSerializer(serializers.ModelSerializer):
    """Serializes a user's AI provider/model preferences.

    The owning user is never exposed or accepted here — the view assigns
    it from the authenticated request. Only the two preference fields are
    writable; id and timestamps are read-only.
    """

    class Meta:
        model = UserAISettings
        fields = [
            "id",
            "preferred_provider",
            "preferred_model",
            "created_at",
            "updated_at",
        ]
        read_only_fields = ["id", "created_at", "updated_at"]

View File

@@ -6,6 +6,7 @@ from integrations.views import (
StravaIntegrationView,
WandererIntegrationViewSet,
UserAPIKeyViewSet,
UserAISettingsViewSet,
UserRecommendationPreferenceProfileViewSet,
)
@@ -22,6 +23,7 @@ router.register(
UserRecommendationPreferenceProfileViewSet,
basename="user-recommendation-preferences",
)
router.register(r"ai-settings", UserAISettingsViewSet, basename="user-ai-settings")
# Include the router URLs
urlpatterns = [

View File

@@ -4,3 +4,4 @@ from .strava_view import StravaIntegrationView
from .wanderer_view import WandererIntegrationViewSet
from .user_api_key_view import UserAPIKeyViewSet
from .recommendation_profile_view import UserRecommendationPreferenceProfileViewSet
from .ai_settings_view import UserAISettingsViewSet

View File

@@ -0,0 +1,39 @@
from rest_framework import status, viewsets
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from integrations.models import UserAISettings
from integrations.serializers import UserAISettingsSerializer
class UserAISettingsViewSet(viewsets.ModelViewSet):
    """Singleton-per-user CRUD for AI provider/model preferences.

    ``list`` returns a zero- or one-element array, and ``create`` upserts
    the caller's settings row, so clients can always POST without first
    checking whether a row exists.
    """

    serializer_class = UserAISettingsSerializer
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        # Users may only ever see or modify their own settings row.
        return UserAISettings.objects.filter(user=self.request.user)

    def list(self, request, *args, **kwargs):
        """Return the user's settings wrapped in a list ([] when unset)."""
        instance = self.get_queryset().first()
        if not instance:
            return Response([], status=status.HTTP_200_OK)
        serializer = self.get_serializer(instance)
        return Response([serializer.data], status=status.HTTP_200_OK)

    def perform_create(self, serializer):
        # Atomic upsert: avoids the get-then-save race in which two
        # concurrent requests could both miss the existing row and the
        # second insert would violate the OneToOne(user) constraint.
        self._upserted_instance, _created = UserAISettings.objects.update_or_create(
            user=self.request.user,
            defaults=serializer.validated_data,
        )

    def create(self, request, *args, **kwargs):
        """Upsert the caller's settings; always responds 200 with the row."""
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        self.perform_create(serializer)
        output = self.get_serializer(self._upserted_instance)
        return Response(output.data, status=status.HTTP_200_OK)

View File

@@ -403,6 +403,11 @@ OSRM_BASE_URL = getenv("OSRM_BASE_URL", "https://router.project-osrm.org")
FIELD_ENCRYPTION_KEY = getenv("FIELD_ENCRYPTION_KEY", "")
# Voyage AI Configuration
VOYAGE_AI_PROVIDER = getenv("VOYAGE_AI_PROVIDER", "openai")
VOYAGE_AI_MODEL = getenv("VOYAGE_AI_MODEL", "gpt-4o-mini")
VOYAGE_AI_API_KEY = getenv("VOYAGE_AI_API_KEY", "")
DJANGO_MCP_ENDPOINT = getenv("DJANGO_MCP_ENDPOINT", "api/mcp")
DJANGO_MCP_AUTHENTICATION_CLASSES = [
"rest_framework.authentication.TokenAuthentication",

View File

@@ -34,3 +34,4 @@ requests>=2.32.5
cryptography>=46.0.5
django-mcp-server>=0.5.7
litellm>=1.72.3
duckduckgo-search>=4.0.0