Files
voyage/backend/server/chat/llm_client.py
alex 9d5681b1ef feat(ai): implement agent-redesign plan with enhanced AI travel features
Phase 1 - Configuration Infrastructure (WS1):
- Add instance-level AI env vars (VOYAGE_AI_PROVIDER, VOYAGE_AI_MODEL, VOYAGE_AI_API_KEY)
- Implement fallback chain: user key → instance key → error
- Add UserAISettings model for per-user provider/model preferences
- Enhance provider catalog with instance_configured and user_configured flags
- Optimize provider catalog to avoid N+1 queries

Phase 1 - User Preference Learning (WS2):
- Add Travel Preferences tab to Settings page
- Improve preference formatting in system prompt with emoji headers
- Add multi-user preference aggregation for shared collections

Phase 2 - Day-Level Suggestions Modal (WS3):
- Create ItinerarySuggestionModal with 3-step flow (category → filters → results)
- Add AI suggestions button to itinerary Add dropdown
- Support restaurant, activity, event, and lodging categories
- Backend endpoint POST /api/chat/suggestions/day/ with context-aware prompts

Phase 3 - Collection-Level Chat Improvements (WS4):
- Inject collection context (destination, dates) into chat system prompt
- Add quick action buttons for common queries
- Add 'Add to itinerary' button on search_places results
- Update chat UI with travel-themed branding and improved tool result cards

Phase 3 - Web Search Capability (WS5):
- Add web_search agent tool using DuckDuckGo
- Support location_context parameter for biased results
- Handle rate limiting gracefully

Phase 4 - Extensibility Architecture (WS6):
- Implement decorator-based @agent_tool registry
- Convert existing tools to use decorators
- Add GET /api/chat/capabilities/ endpoint for tool discovery
- Refactor execute_tool() to use registry pattern
2026-03-08 23:53:14 +00:00

456 lines
15 KiB
Python

import json
import logging
import litellm
from django.conf import settings
from integrations.models import UserAPIKey
logger = logging.getLogger(__name__)

# Maps an app provider id to the model-name prefix expected before the "/" in
# a fully-qualified model string (e.g. "anthropic/claude-...").  Used by
# _is_model_override_compatible for strict prefix validation.  Note the id and
# prefix differ for "github_models" -> "github".
PROVIDER_MODEL_PREFIX = {
    "openai": "openai",
    "anthropic": "anthropic",
    "gemini": "gemini",
    "ollama": "ollama",
    "groq": "groq",
    "mistral": "mistral",
    "github_models": "github",
    "openrouter": "openrouter",
}
# Providers the chat feature supports.  Each entry declares:
#   label         - human-readable name shown in the UI catalog
#   needs_api_key - whether a key must be resolvable before chatting
#   default_model - model used when neither the user nor the instance override it
#   api_base      - custom gateway URL, or None for the provider's default
# Providers absent from this dict can appear in the catalog but are marked
# unavailable for chat (see get_provider_catalog).
CHAT_PROVIDER_CONFIG = {
    "openai": {
        "label": "OpenAI",
        "needs_api_key": True,
        "default_model": "gpt-4o",
        "api_base": None,
    },
    "anthropic": {
        "label": "Anthropic",
        "needs_api_key": True,
        "default_model": "anthropic/claude-sonnet-4-20250514",
        "api_base": None,
    },
    "gemini": {
        "label": "Google Gemini",
        "needs_api_key": True,
        "default_model": "gemini/gemini-2.0-flash",
        "api_base": None,
    },
    "ollama": {
        "label": "Ollama",
        "needs_api_key": True,
        "default_model": "ollama/llama3.1",
        "api_base": None,
    },
    "groq": {
        "label": "Groq",
        "needs_api_key": True,
        "default_model": "groq/llama-3.3-70b-versatile",
        "api_base": None,
    },
    "mistral": {
        "label": "Mistral",
        "needs_api_key": True,
        "default_model": "mistral/mistral-large-latest",
        "api_base": None,
    },
    "github_models": {
        "label": "GitHub Models",
        "needs_api_key": True,
        "default_model": "github/gpt-4o",
        "api_base": None,
    },
    "openrouter": {
        "label": "OpenRouter",
        "needs_api_key": True,
        "default_model": "openrouter/auto",
        "api_base": None,
    },
    "opencode_zen": {
        "label": "OpenCode Zen",
        "needs_api_key": True,
        # Chosen from OpenCode Zen compatible OpenAI-routed models per
        # opencode_zen connection research (see .memory/research).
        "default_model": "openai/gpt-5-nano",
        "api_base": "https://opencode.ai/zen/v1",
    },
}
def _is_model_override_compatible(provider, provider_config, model):
    """Validate model/provider compatibility when strict checks are safe.

    For providers with a custom api_base gateway, skip strict prefix checks since
    gateway routing may legitimately accept cross-provider prefixes.
    """
    # Nothing to check: no override, gateway provider, or unprefixed model name.
    if not model or provider_config.get("api_base") or "/" not in model:
        return True
    expected_prefix = PROVIDER_MODEL_PREFIX.get(provider)
    if not expected_prefix:
        # Fall back to the prefix implied by the provider's own default model.
        default_model = provider_config.get("default_model") or ""
        if "/" in default_model:
            expected_prefix = default_model.split("/", 1)[0]
    if not expected_prefix:
        # No prefix convention known for this provider; accept the override.
        return True
    return model.startswith(expected_prefix + "/")
def _safe_error_payload(exc):
    """Translate a LiteLLM exception into a user-safe SSE error payload.

    Exception classes are looked up defensively (missing attributes become an
    empty tuple, which isinstance never matches) so that version drift in
    litellm degrades gracefully to the generic message.
    """
    exceptions = getattr(litellm, "exceptions", None)

    def lookup(name):
        return getattr(exceptions, name, tuple())

    # Ordered (classinfo, category, message) table; first isinstance match wins.
    error_table = [
        (
            lookup("NotFoundError"),
            "model_not_found",
            "The selected model is unavailable for this provider. Choose a different model and try again.",
        ),
        (
            lookup("AuthenticationError"),
            "authentication_failed",
            "Authentication with the model provider failed. Verify your API key in Settings and try again.",
        ),
        (
            lookup("RateLimitError"),
            "rate_limited",
            "The model provider rate limit was reached. Please wait and try again.",
        ),
        (
            lookup("BadRequestError"),
            "invalid_request",
            "The model provider rejected this request. Check your selected model and try again.",
        ),
        (
            (lookup("Timeout"), lookup("APIConnectionError")),
            "provider_unreachable",
            "Unable to reach the model provider right now. Please try again.",
        ),
    ]
    for classinfo, category, message in error_table:
        if isinstance(exc, classinfo):
            return {"error": message, "error_category": category}
    return {
        "error": "An error occurred while processing your request. Please try again.",
        "error_category": "unknown_error",
    }
def _safe_get(obj, key, default=None):
if obj is None:
return default
if isinstance(obj, dict):
return obj.get(key, default)
return getattr(obj, key, default)
def _normalize_provider_id(provider_id):
value = str(provider_id or "").strip()
lowered = value.lower()
if lowered.startswith("llmproviders."):
return lowered.split(".", 1)[1]
return lowered
def _default_provider_label(provider_id):
return provider_id.replace("_", " ").title()
def is_chat_provider_available(provider_id):
    """Return True when the normalized provider id has a chat configuration."""
    return _normalize_provider_id(provider_id) in CHAT_PROVIDER_CONFIG
def get_provider_catalog(user=None):
    """Build the provider catalog shown in Settings.

    Combines LiteLLM's native provider list with app-supported aliases from
    CHAT_PROVIDER_CONFIG (e.g. OpenCode Zen).  Each entry carries chat
    availability plus `instance_configured` / `user_configured` flags.  User
    keys are fetched in one query up front to avoid N+1 lookups.

    Args:
        user: Optional user whose stored API keys set `user_configured`.

    Returns:
        list[dict]: One entry per unique normalized provider id.
    """
    seen = set()
    catalog = []
    user_key_providers = set()
    instance_provider = (
        _normalize_provider_id(settings.VOYAGE_AI_PROVIDER)
        if settings.VOYAGE_AI_PROVIDER
        else None
    )
    instance_has_key = bool(settings.VOYAGE_AI_API_KEY)
    if user:
        user_key_providers = set(
            UserAPIKey.objects.filter(user=user).values_list("provider", flat=True)
        )

    def build_entry(provider_id, provider_config):
        # Single place that shapes a catalog entry.  Providers without a chat
        # config get null capability fields but keep the configured-key flags.
        if provider_config:
            entry = {
                "id": provider_id,
                "label": provider_config["label"],
                "available_for_chat": True,
                "needs_api_key": provider_config["needs_api_key"],
                "default_model": provider_config["default_model"],
                "api_base": provider_config["api_base"],
            }
        else:
            entry = {
                "id": provider_id,
                "label": _default_provider_label(provider_id),
                "available_for_chat": False,
                "needs_api_key": None,
                "default_model": None,
                "api_base": None,
            }
        entry["instance_configured"] = (
            instance_has_key and provider_id == instance_provider
        )
        entry["user_configured"] = provider_id in user_key_providers
        return entry

    for provider_id in getattr(litellm, "provider_list", []):
        normalized_provider = _normalize_provider_id(provider_id)
        if not normalized_provider or normalized_provider in seen:
            continue
        seen.add(normalized_provider)
        catalog.append(
            build_entry(
                normalized_provider, CHAT_PROVIDER_CONFIG.get(normalized_provider)
            )
        )
    # Include app-supported OpenAI-compatible aliases that are not part of
    # LiteLLM's native provider list (for example OpenCode Zen).
    for provider_id, provider_config in CHAT_PROVIDER_CONFIG.items():
        normalized_provider = _normalize_provider_id(provider_id)
        if not normalized_provider or normalized_provider in seen:
            continue
        seen.add(normalized_provider)
        catalog.append(build_entry(normalized_provider, provider_config))
    return catalog
def get_llm_api_key(user, provider):
    """Get the user's API key for the given provider.

    Resolution order: the user's stored key wins; otherwise, when *provider*
    matches the instance-configured provider, fall back to the instance key.
    Returns None when neither is available.
    """
    normalized_provider = _normalize_provider_id(provider)
    try:
        return UserAPIKey.objects.get(
            user=user, provider=normalized_provider
        ).get_api_key()
    except UserAPIKey.DoesNotExist:
        pass
    # The instance-level key only applies to the instance-configured provider.
    if normalized_provider != _normalize_provider_id(settings.VOYAGE_AI_PROVIDER):
        return None
    instance_api_key = (settings.VOYAGE_AI_API_KEY or "").strip()
    return instance_api_key or None
def _format_interests(interests):
if isinstance(interests, list):
return ", ".join(interests)
return interests
def get_aggregated_preferences(collection):
    """Aggregate preferences from collection owner and shared users.

    Returns a markdown block listing each member's preference profile, or an
    empty string when no member has any usable preferences.
    """
    from integrations.models import UserRecommendationPreferenceProfile

    members = [collection.user, *collection.shared_with.all()]
    member_lines = []
    for member in members:
        try:
            profile = UserRecommendationPreferenceProfile.objects.get(user=member)
        except UserRecommendationPreferenceProfile.DoesNotExist:
            # Members without a profile simply don't contribute a line.
            continue
        fragments = []
        if profile.cuisines:
            fragments.append(f"cuisines: {profile.cuisines}")
        if profile.interests:
            fragments.append(f"interests: {_format_interests(profile.interests)}")
        if profile.trip_style:
            fragments.append(f"style: {profile.trip_style}")
        if profile.notes:
            fragments.append(f"notes: {profile.notes}")
        if fragments:
            member_lines.append(f"- **{member.username}**: {', '.join(fragments)}")
    if not member_lines:
        return ""
    return (
        "\n\n## Party Preferences\n"
        + "\n".join(member_lines)
        + "\n\nNote: Consider all travelers' preferences when making recommendations."
    )
def get_system_prompt(user, collection=None):
    """Build the system prompt with user context.

    For a shared collection, preferences of all members are aggregated;
    otherwise only the requesting user's preference profile (if any) is
    appended to the base prompt.

    Args:
        user: The requesting user.
        collection: Optional trip collection providing shared-party context.

    Returns:
        str: The complete system prompt.
    """
    from integrations.models import UserRecommendationPreferenceProfile

    base_prompt = """You are a helpful travel planning assistant for the Voyage travel app. You help users discover places, plan trips, and organize their itineraries.
Your capabilities:
- Search for interesting places (restaurants, tourist attractions, hotels) near any location
- View and manage the user's trip collections and itineraries
- Add new locations to trip itineraries
- Check weather/temperature data for travel dates
When suggesting places:
- Be specific with names, addresses, and why a place is worth visiting
- Consider the user's travel dates and weather conditions
- Group suggestions logically (by area, by type, by day)
When modifying itineraries:
- Always confirm with the user before adding items
- Suggest logical ordering based on geography
- Consider travel time between locations
Be conversational, helpful, and enthusiastic about travel. Keep responses concise but informative."""
    # .exists() instead of .count() > 0: we only need a boolean, and exists()
    # lets the database short-circuit rather than counting every row.
    if collection and collection.shared_with.exists():
        base_prompt += get_aggregated_preferences(collection)
        return base_prompt
    try:
        profile = UserRecommendationPreferenceProfile.objects.get(user=user)
    except UserRecommendationPreferenceProfile.DoesNotExist:
        return base_prompt
    preference_lines = []
    if profile.cuisines:
        preference_lines.append(f"🍽️ **Cuisine Preferences**: {profile.cuisines}")
    if profile.interests:
        preference_lines.append(
            f"🎯 **Interests**: {_format_interests(profile.interests)}"
        )
    if profile.trip_style:
        preference_lines.append(f"✈️ **Travel Style**: {profile.trip_style}")
    if profile.notes:
        preference_lines.append(f"📝 **Additional Notes**: {profile.notes}")
    if preference_lines:
        base_prompt += "\n\n## Traveler Preferences\n" + "\n".join(preference_lines)
    return base_prompt
async def stream_chat_completion(user, messages, provider, tools=None, model=None):
    """Stream a chat completion using LiteLLM.

    Yields SSE-formatted strings.

    Args:
        user: The requesting user; used to resolve an API key.
        messages: Chat history in OpenAI message-dict format.
        provider: Provider id (normalized before lookup).
        tools: Optional tool/function schemas forwarded to the model; enables
            tool_choice="auto" when present.
        model: Optional model override, validated against the provider.

    Yields:
        "data: {...}\\n\\n" SSE frames containing either an error payload,
        incremental {"content"/"tool_calls"} chunks, then "data: [DONE]".
    """
    normalized_provider = _normalize_provider_id(provider)
    provider_config = CHAT_PROVIDER_CONFIG.get(normalized_provider)
    # Each validation failure emits a single SSE error frame and stops.
    if not provider_config:
        payload = {
            "error": f"Provider is not available for chat: {normalized_provider}."
        }
        yield f"data: {json.dumps(payload)}\n\n"
        return
    api_key = get_llm_api_key(user, normalized_provider)
    if provider_config["needs_api_key"] and not api_key:
        payload = {
            "error": f"No API key found for provider: {normalized_provider}. Please add one in Settings."
        }
        yield f"data: {json.dumps(payload)}\n\n"
        return
    if not _is_model_override_compatible(normalized_provider, provider_config, model):
        payload = {
            "error": "The selected model is incompatible with this provider. Choose a model for the selected provider and try again.",
            "error_category": "invalid_model_for_provider",
        }
        yield f"data: {json.dumps(payload)}\n\n"
        return
    # Model resolution order: explicit override, then the instance-wide model
    # (only when this provider is the instance-configured one), then the
    # provider's default model.
    completion_kwargs = {
        "model": model
        or (
            settings.VOYAGE_AI_MODEL
            if normalized_provider
            == _normalize_provider_id(settings.VOYAGE_AI_PROVIDER)
            and settings.VOYAGE_AI_MODEL
            else None
        )
        or provider_config["default_model"],
        "messages": messages,
        "stream": True,
        "api_key": api_key,
    }
    if tools:
        completion_kwargs["tools"] = tools
        completion_kwargs["tool_choice"] = "auto"
    if provider_config["api_base"]:
        completion_kwargs["api_base"] = provider_config["api_base"]
    try:
        response = await litellm.acompletion(**completion_kwargs)
        async for chunk in response:
            # Chunks may be dicts or objects depending on the provider
            # adapter; _safe_get tolerates both (and None).
            choices = _safe_get(chunk, "choices", []) or []
            if not choices:
                continue
            delta = _safe_get(choices[0], "delta")
            if not delta:
                continue
            chunk_data = {}
            content = _safe_get(delta, "content")
            if content:
                chunk_data["content"] = content
            tool_calls = _safe_get(delta, "tool_calls") or []
            if tool_calls:
                # Re-serialize tool calls into plain dicts so json.dumps
                # never sees provider-specific objects.
                serialized = []
                for tool_call in tool_calls:
                    function = _safe_get(tool_call, "function")
                    serialized.append(
                        {
                            "id": _safe_get(tool_call, "id"),
                            "type": _safe_get(tool_call, "type"),
                            "function": {
                                "name": _safe_get(function, "name", "") or "",
                                "arguments": _safe_get(function, "arguments", "") or "",
                            },
                        }
                    )
                chunk_data["tool_calls"] = serialized
            # Skip frames that carried neither content nor tool calls.
            if chunk_data:
                yield f"data: {json.dumps(chunk_data)}\n\n"
        yield "data: [DONE]\n\n"
    except Exception as exc:
        # Broad catch is intentional at this streaming boundary: any provider
        # failure is logged and mapped to a safe, user-facing SSE payload.
        logger.exception("LLM streaming error")
        payload = _safe_error_payload(exc)
        yield f"data: {json.dumps(payload)}\n\n"