feat(ai): implement agent-redesign plan with enhanced AI travel features
Phase 1 - Configuration Infrastructure (WS1): - Add instance-level AI env vars (VOYAGE_AI_PROVIDER, VOYAGE_AI_MODEL, VOYAGE_AI_API_KEY) - Implement fallback chain: user key → instance key → error - Add UserAISettings model for per-user provider/model preferences - Enhance provider catalog with instance_configured and user_configured flags - Optimize provider catalog to avoid N+1 queries Phase 1 - User Preference Learning (WS2): - Add Travel Preferences tab to Settings page - Improve preference formatting in system prompt with emoji headers - Add multi-user preference aggregation for shared collections Phase 2 - Day-Level Suggestions Modal (WS3): - Create ItinerarySuggestionModal with 3-step flow (category → filters → results) - Add AI suggestions button to itinerary Add dropdown - Support restaurant, activity, event, and lodging categories - Backend endpoint POST /api/chat/suggestions/day/ with context-aware prompts Phase 3 - Collection-Level Chat Improvements (WS4): - Inject collection context (destination, dates) into chat system prompt - Add quick action buttons for common queries - Add 'Add to itinerary' button on search_places results - Update chat UI with travel-themed branding and improved tool result cards Phase 3 - Web Search Capability (WS5): - Add web_search agent tool using DuckDuckGo - Support location_context parameter for biased results - Handle rate limiting gracefully Phase 4 - Extensibility Architecture (WS6): - Implement decorator-based @agent_tool registry - Convert existing tools to use decorators - Add GET /api/chat/capabilities/ endpoint for tool discovery - Refactor execute_tool() to use registry pattern
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import inspect
|
||||
import logging
|
||||
from datetime import date as date_cls
|
||||
|
||||
@@ -10,117 +11,50 @@ from adventures.models import Collection, CollectionItineraryItem, Location
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
AGENT_TOOLS = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search_places",
|
||||
"description": "Search for places of interest near a location. Returns tourist attractions, restaurants, hotels, etc.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "Location name or address to search near",
|
||||
},
|
||||
"category": {
|
||||
"type": "string",
|
||||
"enum": ["tourism", "food", "lodging"],
|
||||
"description": "Category of places",
|
||||
},
|
||||
"radius": {
|
||||
"type": "number",
|
||||
"description": "Search radius in km (default 10)",
|
||||
},
|
||||
_REGISTERED_TOOLS = {}
|
||||
_TOOL_SCHEMAS = []
|
||||
|
||||
|
||||
def agent_tool(name: str, description: str, parameters: dict):
|
||||
"""Decorator to register a function as an agent tool."""
|
||||
|
||||
def decorator(func):
|
||||
_REGISTERED_TOOLS[name] = func
|
||||
|
||||
required = [k for k, v in parameters.items() if v.get("required", False)]
|
||||
props = {
|
||||
k: {kk: vv for kk, vv in v.items() if kk != "required"}
|
||||
for k, v in parameters.items()
|
||||
}
|
||||
|
||||
schema = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": props,
|
||||
"required": required,
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "list_trips",
|
||||
"description": "List the user's trip collections with dates and descriptions",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_trip_details",
|
||||
"description": "Get full details of a trip including all itinerary items, locations, transportation, and lodging",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"collection_id": {
|
||||
"type": "string",
|
||||
"description": "UUID of the collection/trip",
|
||||
}
|
||||
},
|
||||
"required": ["collection_id"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "add_to_itinerary",
|
||||
"description": "Add a new location to a trip's itinerary on a specific date",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"collection_id": {
|
||||
"type": "string",
|
||||
"description": "UUID of the collection/trip",
|
||||
},
|
||||
"name": {"type": "string", "description": "Name of the location"},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": "Description of why to visit",
|
||||
},
|
||||
"latitude": {
|
||||
"type": "number",
|
||||
"description": "Latitude coordinate",
|
||||
},
|
||||
"longitude": {
|
||||
"type": "number",
|
||||
"description": "Longitude coordinate",
|
||||
},
|
||||
"date": {
|
||||
"type": "string",
|
||||
"description": "Date in YYYY-MM-DD format",
|
||||
},
|
||||
"location_address": {
|
||||
"type": "string",
|
||||
"description": "Full address of the location",
|
||||
},
|
||||
},
|
||||
"required": ["collection_id", "name", "latitude", "longitude"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get temperature/weather data for a location on specific dates",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"latitude": {"type": "number", "description": "Latitude"},
|
||||
"longitude": {"type": "number", "description": "Longitude"},
|
||||
"dates": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "List of dates in YYYY-MM-DD format",
|
||||
},
|
||||
},
|
||||
"required": ["latitude", "longitude", "dates"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
_TOOL_SCHEMAS.append(schema)
|
||||
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def get_tool_schemas() -> list:
|
||||
"""Return all registered tool schemas for LLM."""
|
||||
return _TOOL_SCHEMAS.copy()
|
||||
|
||||
|
||||
def get_registered_tools() -> dict:
|
||||
"""Return all registered tool functions."""
|
||||
return _REGISTERED_TOOLS.copy()
|
||||
|
||||
|
||||
NOMINATIM_URL = "https://nominatim.openstreetmap.org/search"
|
||||
OVERPASS_URL = "https://overpass-api.de/api/interpreter"
|
||||
@@ -162,14 +96,39 @@ def _parse_address(tags):
|
||||
return ", ".join([p for p in parts if p])
|
||||
|
||||
|
||||
def search_places(user, **kwargs):
|
||||
@agent_tool(
|
||||
name="search_places",
|
||||
description="Search for places of interest near a location. Returns tourist attractions, restaurants, hotels, etc.",
|
||||
parameters={
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "Location name or address to search near",
|
||||
"required": True,
|
||||
},
|
||||
"category": {
|
||||
"type": "string",
|
||||
"enum": ["tourism", "food", "lodging"],
|
||||
"description": "Category of places",
|
||||
},
|
||||
"radius": {
|
||||
"type": "number",
|
||||
"description": "Search radius in km (default 10)",
|
||||
},
|
||||
},
|
||||
)
|
||||
def search_places(
|
||||
user,
|
||||
location: str | None = None,
|
||||
category: str = "tourism",
|
||||
radius: float = 10,
|
||||
):
|
||||
try:
|
||||
location_name = kwargs.get("location")
|
||||
location_name = location
|
||||
if not location_name:
|
||||
return {"error": "location is required"}
|
||||
|
||||
category = kwargs.get("category") or "tourism"
|
||||
radius_km = float(kwargs.get("radius") or 10)
|
||||
category = category or "tourism"
|
||||
radius_km = float(radius or 10)
|
||||
radius_meters = max(500, min(int(radius_km * 1000), 50000))
|
||||
|
||||
geocode_resp = requests.get(
|
||||
@@ -240,7 +199,12 @@ def search_places(user, **kwargs):
|
||||
return {"error": "An unexpected error occurred during place search"}
|
||||
|
||||
|
||||
def list_trips(user, **kwargs):
|
||||
@agent_tool(
|
||||
name="list_trips",
|
||||
description="List the user's trip collections with dates and descriptions",
|
||||
parameters={},
|
||||
)
|
||||
def list_trips(user):
|
||||
try:
|
||||
collections = Collection.objects.filter(user=user).prefetch_related("locations")
|
||||
trips = []
|
||||
@@ -265,9 +229,87 @@ def list_trips(user, **kwargs):
|
||||
return {"error": "An unexpected error occurred while listing trips"}
|
||||
|
||||
|
||||
def get_trip_details(user, **kwargs):
|
||||
@agent_tool(
|
||||
name="web_search",
|
||||
description="Search the web for current information about destinations, events, prices, weather, or any real-time travel information. Use this when you need up-to-date information that may not be in your training data.",
|
||||
parameters={
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search query (e.g., 'best restaurants Paris 2024', 'weather Tokyo March')",
|
||||
"required": True,
|
||||
},
|
||||
"location_context": {
|
||||
"type": "string",
|
||||
"description": "Optional location to bias search results (e.g., 'Paris, France')",
|
||||
},
|
||||
},
|
||||
)
|
||||
def web_search(user, query: str, location_context: str | None = None) -> dict:
|
||||
"""
|
||||
Search the web for current information about destinations, events, prices, etc.
|
||||
|
||||
Args:
|
||||
user: The user making the request (for auth/logging)
|
||||
query: The search query
|
||||
location_context: Optional location to bias results
|
||||
|
||||
Returns:
|
||||
dict with 'results' list or 'error' string
|
||||
"""
|
||||
if not query:
|
||||
return {"error": "query is required", "results": []}
|
||||
|
||||
try:
|
||||
from duckduckgo_search import DDGS # type: ignore[import-not-found]
|
||||
|
||||
full_query = query
|
||||
if location_context:
|
||||
full_query = f"{query} {location_context}"
|
||||
|
||||
with DDGS() as ddgs:
|
||||
results = list(ddgs.text(full_query, max_results=5))
|
||||
|
||||
formatted = []
|
||||
for result in results:
|
||||
formatted.append(
|
||||
{
|
||||
"title": result.get("title", ""),
|
||||
"snippet": result.get("body", ""),
|
||||
"url": result.get("href", ""),
|
||||
}
|
||||
)
|
||||
|
||||
return {"results": formatted}
|
||||
|
||||
except ImportError:
|
||||
return {
|
||||
"error": "Web search is not available (duckduckgo-search not installed)",
|
||||
"results": [],
|
||||
}
|
||||
except Exception as exc:
|
||||
error_str = str(exc).lower()
|
||||
if "rate" in error_str or "limit" in error_str:
|
||||
return {
|
||||
"error": "Search rate limit reached. Please wait a moment and try again.",
|
||||
"results": [],
|
||||
}
|
||||
logger.error("Web search error: %s", exc)
|
||||
return {"error": "Web search failed. Please try again.", "results": []}
|
||||
|
||||
|
||||
@agent_tool(
|
||||
name="get_trip_details",
|
||||
description="Get full details of a trip including all itinerary items, locations, transportation, and lodging",
|
||||
parameters={
|
||||
"collection_id": {
|
||||
"type": "string",
|
||||
"description": "UUID of the collection/trip",
|
||||
"required": True,
|
||||
}
|
||||
},
|
||||
)
|
||||
def get_trip_details(user, collection_id: str | None = None):
|
||||
try:
|
||||
collection_id = kwargs.get("collection_id")
|
||||
if not collection_id:
|
||||
return {"error": "collection_id is required"}
|
||||
|
||||
@@ -354,16 +396,55 @@ def get_trip_details(user, **kwargs):
|
||||
return {"error": "An unexpected error occurred while fetching trip details"}
|
||||
|
||||
|
||||
def add_to_itinerary(user, **kwargs):
|
||||
@agent_tool(
|
||||
name="add_to_itinerary",
|
||||
description="Add a new location to a trip's itinerary on a specific date",
|
||||
parameters={
|
||||
"collection_id": {
|
||||
"type": "string",
|
||||
"description": "UUID of the collection/trip",
|
||||
"required": True,
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Name of the location",
|
||||
"required": True,
|
||||
},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": "Description of why to visit",
|
||||
},
|
||||
"latitude": {
|
||||
"type": "number",
|
||||
"description": "Latitude coordinate",
|
||||
"required": True,
|
||||
},
|
||||
"longitude": {
|
||||
"type": "number",
|
||||
"description": "Longitude coordinate",
|
||||
"required": True,
|
||||
},
|
||||
"date": {
|
||||
"type": "string",
|
||||
"description": "Date in YYYY-MM-DD format",
|
||||
},
|
||||
"location_address": {
|
||||
"type": "string",
|
||||
"description": "Full address of the location",
|
||||
},
|
||||
},
|
||||
)
|
||||
def add_to_itinerary(
|
||||
user,
|
||||
collection_id: str | None = None,
|
||||
name: str | None = None,
|
||||
latitude: float | None = None,
|
||||
longitude: float | None = None,
|
||||
description: str | None = None,
|
||||
date: str | None = None,
|
||||
location_address: str | None = None,
|
||||
):
|
||||
try:
|
||||
collection_id = kwargs.get("collection_id")
|
||||
name = kwargs.get("name")
|
||||
latitude = kwargs.get("latitude")
|
||||
longitude = kwargs.get("longitude")
|
||||
description = kwargs.get("description")
|
||||
location_address = kwargs.get("location_address")
|
||||
date = kwargs.get("date")
|
||||
|
||||
if not collection_id or not name or latitude is None or longitude is None:
|
||||
return {
|
||||
"error": "collection_id, name, latitude, and longitude are required"
|
||||
@@ -479,16 +560,34 @@ def _fetch_temperature_for_date(latitude, longitude, date_value):
|
||||
}
|
||||
|
||||
|
||||
def get_weather(user, **kwargs):
|
||||
@agent_tool(
|
||||
name="get_weather",
|
||||
description="Get temperature/weather data for a location on specific dates",
|
||||
parameters={
|
||||
"latitude": {"type": "number", "description": "Latitude", "required": True},
|
||||
"longitude": {
|
||||
"type": "number",
|
||||
"description": "Longitude",
|
||||
"required": True,
|
||||
},
|
||||
"dates": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "List of dates in YYYY-MM-DD format",
|
||||
"required": True,
|
||||
},
|
||||
},
|
||||
)
|
||||
def get_weather(user, latitude=None, longitude=None, dates=None):
|
||||
try:
|
||||
raw_latitude = kwargs.get("latitude")
|
||||
raw_longitude = kwargs.get("longitude")
|
||||
raw_latitude = latitude
|
||||
raw_longitude = longitude
|
||||
if raw_latitude is None or raw_longitude is None:
|
||||
return {"error": "latitude and longitude are required"}
|
||||
|
||||
latitude = float(raw_latitude)
|
||||
longitude = float(raw_longitude)
|
||||
dates = kwargs.get("dates") or []
|
||||
dates = dates or []
|
||||
|
||||
if not isinstance(dates, list) or not dates:
|
||||
return {"error": "dates must be a non-empty list"}
|
||||
@@ -509,44 +608,24 @@ def get_weather(user, **kwargs):
|
||||
return {"error": "An unexpected error occurred while fetching weather data"}
|
||||
|
||||
|
||||
ALLOWED_KWARGS = {
|
||||
"search_places": {"location", "category", "radius"},
|
||||
"list_trips": set(),
|
||||
"get_trip_details": {"collection_id"},
|
||||
"add_to_itinerary": {
|
||||
"collection_id",
|
||||
"name",
|
||||
"description",
|
||||
"latitude",
|
||||
"longitude",
|
||||
"date",
|
||||
"location_address",
|
||||
},
|
||||
"get_weather": {"latitude", "longitude", "dates"},
|
||||
}
|
||||
|
||||
|
||||
def execute_tool(tool_name, user, **kwargs):
|
||||
tool_map = {
|
||||
"search_places": search_places,
|
||||
"list_trips": list_trips,
|
||||
"get_trip_details": get_trip_details,
|
||||
"add_to_itinerary": add_to_itinerary,
|
||||
"get_weather": get_weather,
|
||||
}
|
||||
|
||||
tool_fn = tool_map.get(tool_name)
|
||||
if not tool_fn:
|
||||
if tool_name not in _REGISTERED_TOOLS:
|
||||
return {"error": f"Unknown tool: {tool_name}"}
|
||||
|
||||
allowed = ALLOWED_KWARGS.get(tool_name, set())
|
||||
tool_fn = _REGISTERED_TOOLS[tool_name]
|
||||
|
||||
sig = inspect.signature(tool_fn)
|
||||
allowed = set(sig.parameters.keys()) - {"user"}
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if k in allowed}
|
||||
|
||||
try:
|
||||
return tool_fn(user, **filtered_kwargs)
|
||||
except Exception:
|
||||
logger.exception("Tool execution failed: %s", tool_name)
|
||||
return {"error": "An unexpected error occurred while executing the tool"}
|
||||
return tool_fn(user=user, **filtered_kwargs)
|
||||
except Exception as exc:
|
||||
logger.exception("Tool %s failed", tool_name)
|
||||
return {"error": str(exc)}
|
||||
|
||||
|
||||
AGENT_TOOLS = get_tool_schemas()
|
||||
|
||||
|
||||
def serialize_tool_result(result):
|
||||
|
||||
@@ -2,11 +2,23 @@ import json
|
||||
import logging
|
||||
|
||||
import litellm
|
||||
from django.conf import settings
|
||||
|
||||
from integrations.models import UserAPIKey
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PROVIDER_MODEL_PREFIX = {
|
||||
"openai": "openai",
|
||||
"anthropic": "anthropic",
|
||||
"gemini": "gemini",
|
||||
"ollama": "ollama",
|
||||
"groq": "groq",
|
||||
"mistral": "mistral",
|
||||
"github_models": "github",
|
||||
"openrouter": "openrouter",
|
||||
}
|
||||
|
||||
CHAT_PROVIDER_CONFIG = {
|
||||
"openai": {
|
||||
"label": "OpenAI",
|
||||
@@ -59,12 +71,83 @@ CHAT_PROVIDER_CONFIG = {
|
||||
"opencode_zen": {
|
||||
"label": "OpenCode Zen",
|
||||
"needs_api_key": True,
|
||||
"default_model": "openai/gpt-4o-mini",
|
||||
# Chosen from OpenCode Zen compatible OpenAI-routed models per
|
||||
# opencode_zen connection research (see .memory/research).
|
||||
"default_model": "openai/gpt-5-nano",
|
||||
"api_base": "https://opencode.ai/zen/v1",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _is_model_override_compatible(provider, provider_config, model):
|
||||
"""Validate model/provider compatibility when strict checks are safe.
|
||||
|
||||
For providers with a custom api_base gateway, skip strict prefix checks since
|
||||
gateway routing may legitimately accept cross-provider prefixes.
|
||||
"""
|
||||
if not model or provider_config.get("api_base"):
|
||||
return True
|
||||
|
||||
if "/" not in model:
|
||||
return True
|
||||
|
||||
expected_prefix = PROVIDER_MODEL_PREFIX.get(provider)
|
||||
if not expected_prefix:
|
||||
default_model = provider_config.get("default_model") or ""
|
||||
if "/" in default_model:
|
||||
expected_prefix = default_model.split("/", 1)[0]
|
||||
|
||||
if not expected_prefix:
|
||||
return True
|
||||
|
||||
return model.startswith(f"{expected_prefix}/")
|
||||
|
||||
|
||||
def _safe_error_payload(exc):
|
||||
exceptions = getattr(litellm, "exceptions", None)
|
||||
not_found_cls = getattr(exceptions, "NotFoundError", tuple())
|
||||
auth_cls = getattr(exceptions, "AuthenticationError", tuple())
|
||||
rate_limit_cls = getattr(exceptions, "RateLimitError", tuple())
|
||||
bad_request_cls = getattr(exceptions, "BadRequestError", tuple())
|
||||
timeout_cls = getattr(exceptions, "Timeout", tuple())
|
||||
api_connection_cls = getattr(exceptions, "APIConnectionError", tuple())
|
||||
|
||||
if isinstance(exc, not_found_cls):
|
||||
return {
|
||||
"error": "The selected model is unavailable for this provider. Choose a different model and try again.",
|
||||
"error_category": "model_not_found",
|
||||
}
|
||||
|
||||
if isinstance(exc, auth_cls):
|
||||
return {
|
||||
"error": "Authentication with the model provider failed. Verify your API key in Settings and try again.",
|
||||
"error_category": "authentication_failed",
|
||||
}
|
||||
|
||||
if isinstance(exc, rate_limit_cls):
|
||||
return {
|
||||
"error": "The model provider rate limit was reached. Please wait and try again.",
|
||||
"error_category": "rate_limited",
|
||||
}
|
||||
|
||||
if isinstance(exc, bad_request_cls):
|
||||
return {
|
||||
"error": "The model provider rejected this request. Check your selected model and try again.",
|
||||
"error_category": "invalid_request",
|
||||
}
|
||||
|
||||
if isinstance(exc, timeout_cls) or isinstance(exc, api_connection_cls):
|
||||
return {
|
||||
"error": "Unable to reach the model provider right now. Please try again.",
|
||||
"error_category": "provider_unreachable",
|
||||
}
|
||||
|
||||
return {
|
||||
"error": "An error occurred while processing your request. Please try again.",
|
||||
"error_category": "unknown_error",
|
||||
}
|
||||
|
||||
|
||||
def _safe_get(obj, key, default=None):
|
||||
if obj is None:
|
||||
return default
|
||||
@@ -90,9 +173,20 @@ def is_chat_provider_available(provider_id):
|
||||
return normalized_provider in CHAT_PROVIDER_CONFIG
|
||||
|
||||
|
||||
def get_provider_catalog():
|
||||
def get_provider_catalog(user=None):
|
||||
seen = set()
|
||||
catalog = []
|
||||
user_key_providers = set()
|
||||
instance_provider = (
|
||||
_normalize_provider_id(settings.VOYAGE_AI_PROVIDER)
|
||||
if settings.VOYAGE_AI_PROVIDER
|
||||
else None
|
||||
)
|
||||
instance_has_key = bool(settings.VOYAGE_AI_API_KEY)
|
||||
if user:
|
||||
user_key_providers = set(
|
||||
UserAPIKey.objects.filter(user=user).values_list("provider", flat=True)
|
||||
)
|
||||
|
||||
for provider_id in getattr(litellm, "provider_list", []):
|
||||
normalized_provider = _normalize_provider_id(provider_id)
|
||||
@@ -110,6 +204,9 @@ def get_provider_catalog():
|
||||
"needs_api_key": provider_config["needs_api_key"],
|
||||
"default_model": provider_config["default_model"],
|
||||
"api_base": provider_config["api_base"],
|
||||
"instance_configured": instance_has_key
|
||||
and normalized_provider == instance_provider,
|
||||
"user_configured": normalized_provider in user_key_providers,
|
||||
}
|
||||
)
|
||||
continue
|
||||
@@ -122,6 +219,9 @@ def get_provider_catalog():
|
||||
"needs_api_key": None,
|
||||
"default_model": None,
|
||||
"api_base": None,
|
||||
"instance_configured": instance_has_key
|
||||
and normalized_provider == instance_provider,
|
||||
"user_configured": normalized_provider in user_key_providers,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -141,6 +241,9 @@ def get_provider_catalog():
|
||||
"needs_api_key": provider_config["needs_api_key"],
|
||||
"default_model": provider_config["default_model"],
|
||||
"api_base": provider_config["api_base"],
|
||||
"instance_configured": instance_has_key
|
||||
and normalized_provider == instance_provider,
|
||||
"user_configured": normalized_provider in user_key_providers,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -154,9 +257,55 @@ def get_llm_api_key(user, provider):
|
||||
key_obj = UserAPIKey.objects.get(user=user, provider=normalized_provider)
|
||||
return key_obj.get_api_key()
|
||||
except UserAPIKey.DoesNotExist:
|
||||
if normalized_provider == _normalize_provider_id(settings.VOYAGE_AI_PROVIDER):
|
||||
instance_api_key = (settings.VOYAGE_AI_API_KEY or "").strip()
|
||||
if instance_api_key:
|
||||
return instance_api_key
|
||||
return None
|
||||
|
||||
|
||||
def _format_interests(interests):
|
||||
if isinstance(interests, list):
|
||||
return ", ".join(interests)
|
||||
return interests
|
||||
|
||||
|
||||
def get_aggregated_preferences(collection):
|
||||
"""Aggregate preferences from collection owner and shared users."""
|
||||
from integrations.models import UserRecommendationPreferenceProfile
|
||||
|
||||
users = [collection.user] + list(collection.shared_with.all())
|
||||
preferences = []
|
||||
|
||||
for member in users:
|
||||
try:
|
||||
profile = UserRecommendationPreferenceProfile.objects.get(user=member)
|
||||
user_prefs = []
|
||||
|
||||
if profile.cuisines:
|
||||
user_prefs.append(f"cuisines: {profile.cuisines}")
|
||||
if profile.interests:
|
||||
user_prefs.append(f"interests: {_format_interests(profile.interests)}")
|
||||
if profile.trip_style:
|
||||
user_prefs.append(f"style: {profile.trip_style}")
|
||||
if profile.notes:
|
||||
user_prefs.append(f"notes: {profile.notes}")
|
||||
|
||||
if user_prefs:
|
||||
preferences.append(f"- **{member.username}**: {', '.join(user_prefs)}")
|
||||
except UserRecommendationPreferenceProfile.DoesNotExist:
|
||||
continue
|
||||
|
||||
if preferences:
|
||||
return (
|
||||
"\n\n## Party Preferences\n"
|
||||
+ "\n".join(preferences)
|
||||
+ "\n\nNote: Consider all travelers' preferences when making recommendations."
|
||||
)
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def get_system_prompt(user, collection=None):
|
||||
"""Build the system prompt with user context."""
|
||||
from integrations.models import UserRecommendationPreferenceProfile
|
||||
@@ -181,26 +330,37 @@ When modifying itineraries:
|
||||
|
||||
Be conversational, helpful, and enthusiastic about travel. Keep responses concise but informative."""
|
||||
|
||||
try:
|
||||
profile = UserRecommendationPreferenceProfile.objects.get(user=user)
|
||||
prefs = []
|
||||
if profile.cuisines:
|
||||
prefs.append(f"Cuisine preferences: {profile.cuisines}")
|
||||
if profile.interests:
|
||||
prefs.append(f"Interests: {profile.interests}")
|
||||
if profile.trip_style:
|
||||
prefs.append(f"Travel style: {profile.trip_style}")
|
||||
if profile.notes:
|
||||
prefs.append(f"Additional notes: {profile.notes}")
|
||||
if prefs:
|
||||
base_prompt += "\n\nUser preferences:\n" + "\n".join(prefs)
|
||||
except UserRecommendationPreferenceProfile.DoesNotExist:
|
||||
pass
|
||||
if collection and collection.shared_with.count() > 0:
|
||||
base_prompt += get_aggregated_preferences(collection)
|
||||
else:
|
||||
try:
|
||||
profile = UserRecommendationPreferenceProfile.objects.get(user=user)
|
||||
preference_lines = []
|
||||
|
||||
if profile.cuisines:
|
||||
preference_lines.append(
|
||||
f"🍽️ **Cuisine Preferences**: {profile.cuisines}"
|
||||
)
|
||||
if profile.interests:
|
||||
preference_lines.append(
|
||||
f"🎯 **Interests**: {_format_interests(profile.interests)}"
|
||||
)
|
||||
if profile.trip_style:
|
||||
preference_lines.append(f"✈️ **Travel Style**: {profile.trip_style}")
|
||||
if profile.notes:
|
||||
preference_lines.append(f"📝 **Additional Notes**: {profile.notes}")
|
||||
|
||||
if preference_lines:
|
||||
base_prompt += "\n\n## Traveler Preferences\n" + "\n".join(
|
||||
preference_lines
|
||||
)
|
||||
except UserRecommendationPreferenceProfile.DoesNotExist:
|
||||
pass
|
||||
|
||||
return base_prompt
|
||||
|
||||
|
||||
async def stream_chat_completion(user, messages, provider, tools=None):
|
||||
async def stream_chat_completion(user, messages, provider, tools=None, model=None):
|
||||
"""Stream a chat completion using LiteLLM.
|
||||
|
||||
Yields SSE-formatted strings.
|
||||
@@ -215,6 +375,7 @@ async def stream_chat_completion(user, messages, provider, tools=None):
|
||||
return
|
||||
|
||||
api_key = get_llm_api_key(user, normalized_provider)
|
||||
|
||||
if provider_config["needs_api_key"] and not api_key:
|
||||
payload = {
|
||||
"error": f"No API key found for provider: {normalized_provider}. Please add one in Settings."
|
||||
@@ -222,14 +383,31 @@ async def stream_chat_completion(user, messages, provider, tools=None):
|
||||
yield f"data: {json.dumps(payload)}\n\n"
|
||||
return
|
||||
|
||||
if not _is_model_override_compatible(normalized_provider, provider_config, model):
|
||||
payload = {
|
||||
"error": "The selected model is incompatible with this provider. Choose a model for the selected provider and try again.",
|
||||
"error_category": "invalid_model_for_provider",
|
||||
}
|
||||
yield f"data: {json.dumps(payload)}\n\n"
|
||||
return
|
||||
|
||||
completion_kwargs = {
|
||||
"model": provider_config["default_model"],
|
||||
"model": model
|
||||
or (
|
||||
settings.VOYAGE_AI_MODEL
|
||||
if normalized_provider
|
||||
== _normalize_provider_id(settings.VOYAGE_AI_PROVIDER)
|
||||
and settings.VOYAGE_AI_MODEL
|
||||
else None
|
||||
)
|
||||
or provider_config["default_model"],
|
||||
"messages": messages,
|
||||
"tools": tools,
|
||||
"tool_choice": "auto" if tools else None,
|
||||
"stream": True,
|
||||
"api_key": api_key,
|
||||
}
|
||||
if tools:
|
||||
completion_kwargs["tools"] = tools
|
||||
completion_kwargs["tool_choice"] = "auto"
|
||||
if provider_config["api_base"]:
|
||||
completion_kwargs["api_base"] = provider_config["api_base"]
|
||||
|
||||
@@ -271,6 +449,7 @@ async def stream_chat_completion(user, messages, provider, tools=None):
|
||||
yield f"data: {json.dumps(chunk_data)}\n\n"
|
||||
|
||||
yield "data: [DONE]\n\n"
|
||||
except Exception:
|
||||
except Exception as exc:
|
||||
logger.exception("LLM streaming error")
|
||||
yield f"data: {json.dumps({'error': 'An error occurred while processing your request. Please try again.'})}\n\n"
|
||||
payload = _safe_error_payload(exc)
|
||||
yield f"data: {json.dumps(payload)}\n\n"
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
from django.urls import include, path
|
||||
from rest_framework.routers import DefaultRouter
|
||||
|
||||
from .views import ChatProviderCatalogViewSet, ChatViewSet
|
||||
from .views import (
|
||||
CapabilitiesView,
|
||||
ChatProviderCatalogViewSet,
|
||||
ChatViewSet,
|
||||
DaySuggestionsView,
|
||||
)
|
||||
|
||||
router = DefaultRouter()
|
||||
router.register(r"conversations", ChatViewSet, basename="chat-conversation")
|
||||
@@ -11,4 +16,6 @@ router.register(
|
||||
|
||||
urlpatterns = [
|
||||
path("", include(router.urls)),
|
||||
path("capabilities/", CapabilitiesView.as_view(), name="chat-capabilities"),
|
||||
path("suggestions/day/", DaySuggestionsView.as_view(), name="chat-day-suggestions"),
|
||||
]
|
||||
|
||||
328
backend/server/chat/views/__init__.py
Normal file
328
backend/server/chat/views/__init__.py
Normal file
@@ -0,0 +1,328 @@
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
from adventures.models import Collection
|
||||
from django.http import StreamingHttpResponse
|
||||
from rest_framework import status, viewsets
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
|
||||
from ..agent_tools import AGENT_TOOLS, execute_tool, serialize_tool_result
|
||||
from ..llm_client import (
|
||||
get_provider_catalog,
|
||||
get_system_prompt,
|
||||
is_chat_provider_available,
|
||||
stream_chat_completion,
|
||||
)
|
||||
from ..models import ChatConversation, ChatMessage
|
||||
from ..serializers import ChatConversationSerializer
|
||||
|
||||
|
||||
class ChatViewSet(viewsets.ModelViewSet):
|
||||
serializer_class = ChatConversationSerializer
|
||||
permission_classes = [IsAuthenticated]
|
||||
|
||||
def get_queryset(self):
|
||||
return ChatConversation.objects.filter(user=self.request.user).prefetch_related(
|
||||
"messages"
|
||||
)
|
||||
|
||||
def list(self, request, *args, **kwargs):
|
||||
conversations = self.get_queryset().only("id", "title", "updated_at")
|
||||
data = [
|
||||
{
|
||||
"id": str(conversation.id),
|
||||
"title": conversation.title,
|
||||
"updated_at": conversation.updated_at,
|
||||
}
|
||||
for conversation in conversations
|
||||
]
|
||||
return Response(data)
|
||||
|
||||
def create(self, request, *args, **kwargs):
|
||||
conversation = ChatConversation.objects.create(
|
||||
user=request.user,
|
||||
title=(request.data.get("title") or "").strip(),
|
||||
)
|
||||
serializer = self.get_serializer(conversation)
|
||||
return Response(serializer.data, status=status.HTTP_201_CREATED)
|
||||
|
||||
def _build_llm_messages(self, conversation, user, system_prompt=None):
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt or get_system_prompt(user),
|
||||
}
|
||||
]
|
||||
for message in conversation.messages.all().order_by("created_at"):
|
||||
payload = {
|
||||
"role": message.role,
|
||||
"content": message.content,
|
||||
}
|
||||
if message.role == "assistant" and message.tool_calls:
|
||||
payload["tool_calls"] = message.tool_calls
|
||||
if message.role == "tool":
|
||||
payload["tool_call_id"] = message.tool_call_id
|
||||
payload["name"] = message.name
|
||||
messages.append(payload)
|
||||
return messages
|
||||
|
||||
def _async_to_sync_generator(self, async_gen):
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
yield loop.run_until_complete(async_gen.__anext__())
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
finally:
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
loop.close()
|
||||
|
||||
@staticmethod
|
||||
def _merge_tool_call_delta(accumulator, tool_calls_delta):
|
||||
for idx, tool_call in enumerate(tool_calls_delta or []):
|
||||
idx = tool_call.get("index", idx)
|
||||
while len(accumulator) <= idx:
|
||||
accumulator.append(
|
||||
{
|
||||
"id": None,
|
||||
"type": "function",
|
||||
"function": {"name": "", "arguments": ""},
|
||||
}
|
||||
)
|
||||
|
||||
current = accumulator[idx]
|
||||
if tool_call.get("id"):
|
||||
current["id"] = tool_call.get("id")
|
||||
if tool_call.get("type"):
|
||||
current["type"] = tool_call.get("type")
|
||||
|
||||
function_data = tool_call.get("function") or {}
|
||||
if function_data.get("name"):
|
||||
current["function"]["name"] = function_data.get("name")
|
||||
if function_data.get("arguments"):
|
||||
current["function"]["arguments"] += function_data.get("arguments")
|
||||
|
||||
@action(detail=True, methods=["post"])
|
||||
def send_message(self, request, pk=None):
|
||||
conversation = self.get_object()
|
||||
user_content = (request.data.get("message") or "").strip()
|
||||
if not user_content:
|
||||
return Response(
|
||||
{"error": "message is required"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
provider = (request.data.get("provider") or "openai").strip().lower()
|
||||
model = (request.data.get("model") or "").strip() or None
|
||||
collection_id = request.data.get("collection_id")
|
||||
collection_name = request.data.get("collection_name")
|
||||
start_date = request.data.get("start_date")
|
||||
end_date = request.data.get("end_date")
|
||||
destination = request.data.get("destination")
|
||||
if not is_chat_provider_available(provider):
|
||||
return Response(
|
||||
{"error": f"Provider is not available for chat: {provider}."},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
context_parts = []
|
||||
if collection_name:
|
||||
context_parts.append(f"Trip: {collection_name}")
|
||||
if destination:
|
||||
context_parts.append(f"Destination: {destination}")
|
||||
if start_date and end_date:
|
||||
context_parts.append(f"Dates: {start_date} to {end_date}")
|
||||
|
||||
collection = None
|
||||
if collection_id:
|
||||
try:
|
||||
requested_collection = Collection.objects.get(id=collection_id)
|
||||
if (
|
||||
requested_collection.user == request.user
|
||||
or requested_collection.shared_with.filter(
|
||||
id=request.user.id
|
||||
).exists()
|
||||
):
|
||||
collection = requested_collection
|
||||
except Collection.DoesNotExist:
|
||||
pass
|
||||
|
||||
system_prompt = get_system_prompt(request.user, collection)
|
||||
if context_parts:
|
||||
system_prompt += "\n\n## Trip Context\n" + "\n".join(context_parts)
|
||||
|
||||
ChatMessage.objects.create(
|
||||
conversation=conversation,
|
||||
role="user",
|
||||
content=user_content,
|
||||
)
|
||||
conversation.save(update_fields=["updated_at"])
|
||||
|
||||
if not conversation.title:
|
||||
conversation.title = user_content[:120]
|
||||
conversation.save(update_fields=["title", "updated_at"])
|
||||
|
||||
llm_messages = self._build_llm_messages(
|
||||
conversation,
|
||||
request.user,
|
||||
system_prompt=system_prompt,
|
||||
)
|
||||
|
||||
MAX_TOOL_ITERATIONS = 10
|
||||
|
||||
async def event_stream():
|
||||
current_messages = list(llm_messages)
|
||||
encountered_error = False
|
||||
tool_iterations = 0
|
||||
|
||||
while tool_iterations < MAX_TOOL_ITERATIONS:
|
||||
content_chunks = []
|
||||
tool_calls_accumulator = []
|
||||
|
||||
async for chunk in stream_chat_completion(
|
||||
request.user,
|
||||
current_messages,
|
||||
provider,
|
||||
tools=AGENT_TOOLS,
|
||||
model=model,
|
||||
):
|
||||
if not chunk.startswith("data: "):
|
||||
yield chunk
|
||||
continue
|
||||
|
||||
payload = chunk[len("data: ") :].strip()
|
||||
if payload == "[DONE]":
|
||||
continue
|
||||
|
||||
yield chunk
|
||||
|
||||
try:
|
||||
data = json.loads(payload)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
if data.get("error"):
|
||||
encountered_error = True
|
||||
break
|
||||
|
||||
if data.get("content"):
|
||||
content_chunks.append(data["content"])
|
||||
|
||||
if data.get("tool_calls"):
|
||||
self._merge_tool_call_delta(
|
||||
tool_calls_accumulator,
|
||||
data["tool_calls"],
|
||||
)
|
||||
|
||||
if encountered_error:
|
||||
break
|
||||
|
||||
assistant_content = "".join(content_chunks)
|
||||
|
||||
if tool_calls_accumulator:
|
||||
assistant_with_tools = {
|
||||
"role": "assistant",
|
||||
"content": assistant_content,
|
||||
"tool_calls": tool_calls_accumulator,
|
||||
}
|
||||
current_messages.append(assistant_with_tools)
|
||||
|
||||
await sync_to_async(
|
||||
ChatMessage.objects.create, thread_sensitive=True
|
||||
)(
|
||||
conversation=conversation,
|
||||
role="assistant",
|
||||
content=assistant_content,
|
||||
tool_calls=tool_calls_accumulator,
|
||||
)
|
||||
await sync_to_async(conversation.save, thread_sensitive=True)(
|
||||
update_fields=["updated_at"]
|
||||
)
|
||||
|
||||
for tool_call in tool_calls_accumulator:
|
||||
function_payload = tool_call.get("function") or {}
|
||||
function_name = function_payload.get("name") or ""
|
||||
raw_arguments = function_payload.get("arguments") or "{}"
|
||||
|
||||
try:
|
||||
arguments = json.loads(raw_arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
if not isinstance(arguments, dict):
|
||||
arguments = {}
|
||||
|
||||
result = await sync_to_async(
|
||||
execute_tool, thread_sensitive=True
|
||||
)(
|
||||
function_name,
|
||||
request.user,
|
||||
**arguments,
|
||||
)
|
||||
result_content = serialize_tool_result(result)
|
||||
|
||||
current_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.get("id"),
|
||||
"name": function_name,
|
||||
"content": result_content,
|
||||
}
|
||||
)
|
||||
|
||||
await sync_to_async(
|
||||
ChatMessage.objects.create, thread_sensitive=True
|
||||
)(
|
||||
conversation=conversation,
|
||||
role="tool",
|
||||
content=result_content,
|
||||
tool_call_id=tool_call.get("id"),
|
||||
name=function_name,
|
||||
)
|
||||
await sync_to_async(conversation.save, thread_sensitive=True)(
|
||||
update_fields=["updated_at"]
|
||||
)
|
||||
|
||||
tool_event = {
|
||||
"tool_result": {
|
||||
"tool_call_id": tool_call.get("id"),
|
||||
"name": function_name,
|
||||
"result": result,
|
||||
}
|
||||
}
|
||||
yield f"data: {json.dumps(tool_event)}\n\n"
|
||||
|
||||
continue
|
||||
|
||||
await sync_to_async(ChatMessage.objects.create, thread_sensitive=True)(
|
||||
conversation=conversation,
|
||||
role="assistant",
|
||||
content=assistant_content,
|
||||
)
|
||||
await sync_to_async(conversation.save, thread_sensitive=True)(
|
||||
update_fields=["updated_at"]
|
||||
)
|
||||
yield "data: [DONE]\n\n"
|
||||
break
|
||||
|
||||
response = StreamingHttpResponse(
|
||||
streaming_content=self._async_to_sync_generator(event_stream()),
|
||||
content_type="text/event-stream",
|
||||
)
|
||||
response["Cache-Control"] = "no-cache"
|
||||
response["X-Accel-Buffering"] = "no"
|
||||
return response
|
||||
|
||||
|
||||
class ChatProviderCatalogViewSet(viewsets.ViewSet):
|
||||
permission_classes = [IsAuthenticated]
|
||||
|
||||
def list(self, request):
|
||||
return Response(get_provider_catalog(user=request.user))
|
||||
|
||||
|
||||
from .capabilities import CapabilitiesView
|
||||
from .day_suggestions import DaySuggestionsView
|
||||
24
backend/server/chat/views/capabilities.py
Normal file
24
backend/server/chat/views/capabilities.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from chat.agent_tools import get_tool_schemas
|
||||
|
||||
|
||||
class CapabilitiesView(APIView):
|
||||
permission_classes = [IsAuthenticated]
|
||||
|
||||
def get(self, request):
|
||||
"""Return available AI capabilities/tools."""
|
||||
tools = get_tool_schemas()
|
||||
return Response(
|
||||
{
|
||||
"tools": [
|
||||
{
|
||||
"name": tool["function"]["name"],
|
||||
"description": tool["function"]["description"],
|
||||
}
|
||||
for tool in tools
|
||||
]
|
||||
}
|
||||
)
|
||||
215
backend/server/chat/views/day_suggestions.py
Normal file
215
backend/server/chat/views/day_suggestions.py
Normal file
@@ -0,0 +1,215 @@
|
||||
import json
|
||||
import re
|
||||
|
||||
import litellm
|
||||
from django.shortcuts import get_object_or_404
|
||||
from rest_framework import status
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from adventures.models import Collection
|
||||
from chat.agent_tools import search_places
|
||||
from chat.llm_client import (
|
||||
get_llm_api_key,
|
||||
get_system_prompt,
|
||||
is_chat_provider_available,
|
||||
)
|
||||
|
||||
|
||||
class DaySuggestionsView(APIView):
|
||||
permission_classes = [IsAuthenticated]
|
||||
|
||||
def post(self, request):
|
||||
collection_id = request.data.get("collection_id")
|
||||
date = request.data.get("date")
|
||||
category = request.data.get("category")
|
||||
filters = request.data.get("filters", {}) or {}
|
||||
location_context = request.data.get("location_context", "")
|
||||
|
||||
if not all([collection_id, date, category]):
|
||||
return Response(
|
||||
{"error": "collection_id, date, and category are required"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
valid_categories = ["restaurant", "activity", "event", "lodging"]
|
||||
if category not in valid_categories:
|
||||
return Response(
|
||||
{"error": f"category must be one of: {', '.join(valid_categories)}"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
collection = get_object_or_404(Collection, id=collection_id)
|
||||
if (
|
||||
collection.user != request.user
|
||||
and not collection.shared_with.filter(id=request.user.id).exists()
|
||||
):
|
||||
return Response(
|
||||
{"error": "You don't have access to this collection"},
|
||||
status=status.HTTP_403_FORBIDDEN,
|
||||
)
|
||||
|
||||
location = location_context or self._get_collection_location(collection)
|
||||
system_prompt = get_system_prompt(request.user, collection)
|
||||
provider = "openai"
|
||||
|
||||
if not is_chat_provider_available(provider):
|
||||
return Response(
|
||||
{
|
||||
"error": "AI suggestions are not available. Please configure an API key."
|
||||
},
|
||||
status=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
)
|
||||
|
||||
try:
|
||||
places_context = self._get_places_context(request.user, category, location)
|
||||
prompt = self._build_prompt(
|
||||
category=category,
|
||||
filters=filters,
|
||||
location=location,
|
||||
date=date,
|
||||
collection=collection,
|
||||
places_context=places_context,
|
||||
)
|
||||
|
||||
suggestions = self._get_suggestions_from_llm(
|
||||
system_prompt=system_prompt,
|
||||
user_prompt=prompt,
|
||||
user=request.user,
|
||||
provider=provider,
|
||||
)
|
||||
return Response({"suggestions": suggestions}, status=status.HTTP_200_OK)
|
||||
except Exception:
|
||||
return Response(
|
||||
{"error": "Failed to generate suggestions. Please try again."},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
def _get_collection_location(self, collection):
|
||||
for loc in collection.locations.select_related("city", "country").all():
|
||||
if loc.city:
|
||||
city_name = getattr(loc.city, "name", str(loc.city))
|
||||
country_name = getattr(loc.country, "name", "") if loc.country else ""
|
||||
return ", ".join([x for x in [city_name, country_name] if x])
|
||||
if loc.location:
|
||||
return loc.location
|
||||
if loc.name:
|
||||
return loc.name
|
||||
return "Unknown location"
|
||||
|
||||
def _build_prompt(
|
||||
self,
|
||||
category,
|
||||
filters,
|
||||
location,
|
||||
date,
|
||||
collection,
|
||||
places_context="",
|
||||
):
|
||||
category_prompts = {
|
||||
"restaurant": f"Find restaurant recommendations for {location}",
|
||||
"activity": f"Find activity recommendations for {location}",
|
||||
"event": f"Find event recommendations for {location} around {date}",
|
||||
"lodging": f"Find lodging recommendations for {location}",
|
||||
}
|
||||
|
||||
prompt = category_prompts.get(
|
||||
category, f"Find {category} recommendations for {location}"
|
||||
)
|
||||
|
||||
if filters:
|
||||
filter_parts = []
|
||||
if filters.get("cuisine_type"):
|
||||
filter_parts.append(f"cuisine type: {filters['cuisine_type']}")
|
||||
if filters.get("price_range"):
|
||||
filter_parts.append(f"price range: {filters['price_range']}")
|
||||
if filters.get("dietary"):
|
||||
filter_parts.append(f"dietary restrictions: {filters['dietary']}")
|
||||
if filters.get("activity_type"):
|
||||
filter_parts.append(f"type: {filters['activity_type']}")
|
||||
if filters.get("duration"):
|
||||
filter_parts.append(f"duration: {filters['duration']}")
|
||||
if filters.get("event_type"):
|
||||
filter_parts.append(f"event type: {filters['event_type']}")
|
||||
if filters.get("lodging_type"):
|
||||
filter_parts.append(f"lodging type: {filters['lodging_type']}")
|
||||
amenities = filters.get("amenities")
|
||||
if isinstance(amenities, list) and amenities:
|
||||
filter_parts.append(
|
||||
f"amenities: {', '.join(str(x) for x in amenities)}"
|
||||
)
|
||||
|
||||
if filter_parts:
|
||||
prompt += f" with these preferences: {', '.join(filter_parts)}"
|
||||
|
||||
prompt += f". The trip date is {date}."
|
||||
|
||||
if collection.start_date or collection.end_date:
|
||||
prompt += (
|
||||
" Collection trip window: "
|
||||
f"{collection.start_date or 'unknown'} to {collection.end_date or 'unknown'}."
|
||||
)
|
||||
|
||||
if places_context:
|
||||
prompt += f" Nearby place context: {places_context}."
|
||||
|
||||
prompt += (
|
||||
" Return 3-5 specific suggestions as a JSON array."
|
||||
" Each suggestion should have: name, description, why_fits, category, location, rating, price_level."
|
||||
" Return ONLY valid JSON, no markdown, no surrounding text."
|
||||
)
|
||||
return prompt
|
||||
|
||||
def _get_places_context(self, user, category, location):
|
||||
tool_category_map = {
|
||||
"restaurant": "food",
|
||||
"activity": "tourism",
|
||||
"event": "tourism",
|
||||
"lodging": "lodging",
|
||||
}
|
||||
result = search_places(
|
||||
user,
|
||||
location=location,
|
||||
category=tool_category_map.get(category, "tourism"),
|
||||
radius=8,
|
||||
)
|
||||
if result.get("error"):
|
||||
return ""
|
||||
|
||||
entries = []
|
||||
for place in result.get("results", [])[:5]:
|
||||
name = place.get("name")
|
||||
address = place.get("address") or ""
|
||||
if name:
|
||||
entries.append(f"{name} ({address})" if address else name)
|
||||
return "; ".join(entries)
|
||||
|
||||
def _get_suggestions_from_llm(self, system_prompt, user_prompt, user, provider):
|
||||
api_key = get_llm_api_key(user, provider)
|
||||
if not api_key:
|
||||
raise ValueError("No API key available")
|
||||
|
||||
response = litellm.completion(
|
||||
model="gpt-4o-mini",
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
api_key=api_key,
|
||||
temperature=0.7,
|
||||
max_tokens=1000,
|
||||
)
|
||||
|
||||
content = (response.choices[0].message.content or "").strip()
|
||||
try:
|
||||
json_match = re.search(r"\[.*\]", content, re.DOTALL)
|
||||
parsed = (
|
||||
json.loads(json_match.group())
|
||||
if json_match
|
||||
else json.loads(content or "[]")
|
||||
)
|
||||
suggestions = parsed if isinstance(parsed, list) else [parsed]
|
||||
return suggestions[:5]
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
@@ -0,0 +1,52 @@
|
||||
# Generated by Django 5.2.12 on 2026-03-08
|
||||
|
||||
import django.db.models.deletion
|
||||
import uuid
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("integrations", "0007_userapikey_userrecommendationpreferenceprofile"),
|
||||
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="UserAISettings",
|
||||
fields=[
|
||||
(
|
||||
"id",
|
||||
models.UUIDField(
|
||||
default=uuid.uuid4,
|
||||
editable=False,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
),
|
||||
),
|
||||
(
|
||||
"preferred_provider",
|
||||
models.CharField(blank=True, max_length=100, null=True),
|
||||
),
|
||||
(
|
||||
"preferred_model",
|
||||
models.CharField(blank=True, max_length=100, null=True),
|
||||
),
|
||||
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||
("updated_at", models.DateTimeField(auto_now=True)),
|
||||
(
|
||||
"user",
|
||||
models.OneToOneField(
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
related_name="ai_settings",
|
||||
to=settings.AUTH_USER_MODEL,
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"verbose_name": "User AI Settings",
|
||||
"verbose_name_plural": "User AI Settings",
|
||||
},
|
||||
),
|
||||
]
|
||||
@@ -124,3 +124,23 @@ class UserRecommendationPreferenceProfile(models.Model):
|
||||
notes = models.TextField(blank=True, null=True)
|
||||
created_at = models.DateTimeField(auto_now_add=True)
|
||||
updated_at = models.DateTimeField(auto_now=True)
|
||||
|
||||
|
||||
class UserAISettings(models.Model):
|
||||
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
|
||||
user = models.OneToOneField(
|
||||
settings.AUTH_USER_MODEL,
|
||||
on_delete=models.CASCADE,
|
||||
related_name="ai_settings",
|
||||
)
|
||||
preferred_provider = models.CharField(max_length=100, blank=True, null=True)
|
||||
preferred_model = models.CharField(max_length=100, blank=True, null=True)
|
||||
created_at = models.DateTimeField(auto_now_add=True)
|
||||
updated_at = models.DateTimeField(auto_now=True)
|
||||
|
||||
class Meta:
|
||||
verbose_name = "User AI Settings"
|
||||
verbose_name_plural = "User AI Settings"
|
||||
|
||||
def __str__(self):
|
||||
return f"AI Settings for {self.user.username}"
|
||||
|
||||
@@ -3,6 +3,7 @@ from django.db import IntegrityError
|
||||
from .models import (
|
||||
EncryptionConfigurationError,
|
||||
ImmichIntegration,
|
||||
UserAISettings,
|
||||
UserAPIKey,
|
||||
UserRecommendationPreferenceProfile,
|
||||
)
|
||||
@@ -98,3 +99,16 @@ class UserRecommendationPreferenceProfileSerializer(serializers.ModelSerializer)
|
||||
"updated_at",
|
||||
]
|
||||
read_only_fields = ["id", "created_at", "updated_at"]
|
||||
|
||||
|
||||
class UserAISettingsSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = UserAISettings
|
||||
fields = [
|
||||
"id",
|
||||
"preferred_provider",
|
||||
"preferred_model",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
]
|
||||
read_only_fields = ["id", "created_at", "updated_at"]
|
||||
|
||||
@@ -6,6 +6,7 @@ from integrations.views import (
|
||||
StravaIntegrationView,
|
||||
WandererIntegrationViewSet,
|
||||
UserAPIKeyViewSet,
|
||||
UserAISettingsViewSet,
|
||||
UserRecommendationPreferenceProfileViewSet,
|
||||
)
|
||||
|
||||
@@ -22,6 +23,7 @@ router.register(
|
||||
UserRecommendationPreferenceProfileViewSet,
|
||||
basename="user-recommendation-preferences",
|
||||
)
|
||||
router.register(r"ai-settings", UserAISettingsViewSet, basename="user-ai-settings")
|
||||
|
||||
# Include the router URLs
|
||||
urlpatterns = [
|
||||
|
||||
@@ -4,3 +4,4 @@ from .strava_view import StravaIntegrationView
|
||||
from .wanderer_view import WandererIntegrationViewSet
|
||||
from .user_api_key_view import UserAPIKeyViewSet
|
||||
from .recommendation_profile_view import UserRecommendationPreferenceProfileViewSet
|
||||
from .ai_settings_view import UserAISettingsViewSet
|
||||
|
||||
39
backend/server/integrations/views/ai_settings_view.py
Normal file
39
backend/server/integrations/views/ai_settings_view.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from rest_framework import status, viewsets
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
|
||||
from integrations.models import UserAISettings
|
||||
from integrations.serializers import UserAISettingsSerializer
|
||||
|
||||
|
||||
class UserAISettingsViewSet(viewsets.ModelViewSet):
|
||||
serializer_class = UserAISettingsSerializer
|
||||
permission_classes = [IsAuthenticated]
|
||||
|
||||
def get_queryset(self):
|
||||
return UserAISettings.objects.filter(user=self.request.user)
|
||||
|
||||
def list(self, request, *args, **kwargs):
|
||||
instance = self.get_queryset().first()
|
||||
if not instance:
|
||||
return Response([], status=status.HTTP_200_OK)
|
||||
serializer = self.get_serializer(instance)
|
||||
return Response([serializer.data], status=status.HTTP_200_OK)
|
||||
|
||||
def perform_create(self, serializer):
|
||||
existing = UserAISettings.objects.filter(user=self.request.user).first()
|
||||
if existing:
|
||||
for field, value in serializer.validated_data.items():
|
||||
setattr(existing, field, value)
|
||||
existing.save()
|
||||
self._upserted_instance = existing
|
||||
return
|
||||
|
||||
self._upserted_instance = serializer.save(user=self.request.user)
|
||||
|
||||
def create(self, request, *args, **kwargs):
|
||||
serializer = self.get_serializer(data=request.data)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
self.perform_create(serializer)
|
||||
output = self.get_serializer(self._upserted_instance)
|
||||
return Response(output.data, status=status.HTTP_200_OK)
|
||||
@@ -403,6 +403,11 @@ OSRM_BASE_URL = getenv("OSRM_BASE_URL", "https://router.project-osrm.org")
|
||||
|
||||
FIELD_ENCRYPTION_KEY = getenv("FIELD_ENCRYPTION_KEY", "")
|
||||
|
||||
# Voyage AI Configuration
|
||||
VOYAGE_AI_PROVIDER = getenv("VOYAGE_AI_PROVIDER", "openai")
|
||||
VOYAGE_AI_MODEL = getenv("VOYAGE_AI_MODEL", "gpt-4o-mini")
|
||||
VOYAGE_AI_API_KEY = getenv("VOYAGE_AI_API_KEY", "")
|
||||
|
||||
DJANGO_MCP_ENDPOINT = getenv("DJANGO_MCP_ENDPOINT", "api/mcp")
|
||||
DJANGO_MCP_AUTHENTICATION_CLASSES = [
|
||||
"rest_framework.authentication.TokenAuthentication",
|
||||
|
||||
@@ -34,3 +34,4 @@ requests>=2.32.5
|
||||
cryptography>=46.0.5
|
||||
django-mcp-server>=0.5.7
|
||||
litellm>=1.72.3
|
||||
duckduckgo-search>=4.0.0
|
||||
|
||||
Reference in New Issue
Block a user