feat(ai): implement agent-redesign plan with enhanced AI travel features
Phase 1 - Configuration Infrastructure (WS1):
- Add instance-level AI env vars (VOYAGE_AI_PROVIDER, VOYAGE_AI_MODEL, VOYAGE_AI_API_KEY)
- Implement fallback chain: user key → instance key → error
- Add UserAISettings model for per-user provider/model preferences
- Enhance provider catalog with instance_configured and user_configured flags
- Optimize provider catalog to avoid N+1 queries

Phase 1 - User Preference Learning (WS2):
- Add Travel Preferences tab to Settings page
- Improve preference formatting in system prompt with emoji headers
- Add multi-user preference aggregation for shared collections

Phase 2 - Day-Level Suggestions Modal (WS3):
- Create ItinerarySuggestionModal with 3-step flow (category → filters → results)
- Add AI suggestions button to itinerary Add dropdown
- Support restaurant, activity, event, and lodging categories
- Backend endpoint POST /api/chat/suggestions/day/ with context-aware prompts

Phase 3 - Collection-Level Chat Improvements (WS4):
- Inject collection context (destination, dates) into chat system prompt
- Add quick action buttons for common queries
- Add 'Add to itinerary' button on search_places results
- Update chat UI with travel-themed branding and improved tool result cards

Phase 3 - Web Search Capability (WS5):
- Add web_search agent tool using DuckDuckGo
- Support location_context parameter for biased results
- Handle rate limiting gracefully

Phase 4 - Extensibility Architecture (WS6):
- Implement decorator-based @agent_tool registry
- Convert existing tools to use decorators
- Add GET /api/chat/capabilities/ endpoint for tool discovery
- Refactor execute_tool() to use registry pattern
This commit is contained in:
@@ -2,11 +2,23 @@ import json
|
||||
import logging
|
||||
|
||||
import litellm
|
||||
from django.conf import settings
|
||||
|
||||
from integrations.models import UserAPIKey
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PROVIDER_MODEL_PREFIX = {
|
||||
"openai": "openai",
|
||||
"anthropic": "anthropic",
|
||||
"gemini": "gemini",
|
||||
"ollama": "ollama",
|
||||
"groq": "groq",
|
||||
"mistral": "mistral",
|
||||
"github_models": "github",
|
||||
"openrouter": "openrouter",
|
||||
}
|
||||
|
||||
CHAT_PROVIDER_CONFIG = {
|
||||
"openai": {
|
||||
"label": "OpenAI",
|
||||
@@ -59,12 +71,83 @@ CHAT_PROVIDER_CONFIG = {
|
||||
"opencode_zen": {
|
||||
"label": "OpenCode Zen",
|
||||
"needs_api_key": True,
|
||||
"default_model": "openai/gpt-4o-mini",
|
||||
# Chosen from OpenCode Zen compatible OpenAI-routed models per
|
||||
# opencode_zen connection research (see .memory/research).
|
||||
"default_model": "openai/gpt-5-nano",
|
||||
"api_base": "https://opencode.ai/zen/v1",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _is_model_override_compatible(provider, provider_config, model):
|
||||
"""Validate model/provider compatibility when strict checks are safe.
|
||||
|
||||
For providers with a custom api_base gateway, skip strict prefix checks since
|
||||
gateway routing may legitimately accept cross-provider prefixes.
|
||||
"""
|
||||
if not model or provider_config.get("api_base"):
|
||||
return True
|
||||
|
||||
if "/" not in model:
|
||||
return True
|
||||
|
||||
expected_prefix = PROVIDER_MODEL_PREFIX.get(provider)
|
||||
if not expected_prefix:
|
||||
default_model = provider_config.get("default_model") or ""
|
||||
if "/" in default_model:
|
||||
expected_prefix = default_model.split("/", 1)[0]
|
||||
|
||||
if not expected_prefix:
|
||||
return True
|
||||
|
||||
return model.startswith(f"{expected_prefix}/")
|
||||
|
||||
|
||||
def _safe_error_payload(exc):
    """Map a LiteLLM exception to a user-safe error payload.

    Classes are resolved via getattr so that exception types absent from the
    installed litellm version simply never match (an empty tuple makes the
    corresponding isinstance() check a no-op).
    """
    exc_module = getattr(litellm, "exceptions", None)

    def _cls(name):
        return getattr(exc_module, name, tuple())

    # Ordered (classinfo, message, category) table; first match wins.
    dispatch = (
        (
            _cls("NotFoundError"),
            "The selected model is unavailable for this provider. Choose a different model and try again.",
            "model_not_found",
        ),
        (
            _cls("AuthenticationError"),
            "Authentication with the model provider failed. Verify your API key in Settings and try again.",
            "authentication_failed",
        ),
        (
            _cls("RateLimitError"),
            "The model provider rate limit was reached. Please wait and try again.",
            "rate_limited",
        ),
        (
            _cls("BadRequestError"),
            "The model provider rejected this request. Check your selected model and try again.",
            "invalid_request",
        ),
        (
            (_cls("Timeout"), _cls("APIConnectionError")),
            "Unable to reach the model provider right now. Please try again.",
            "provider_unreachable",
        ),
    )

    for classinfo, message, category in dispatch:
        if isinstance(exc, classinfo):
            return {"error": message, "error_category": category}

    return {
        "error": "An error occurred while processing your request. Please try again.",
        "error_category": "unknown_error",
    }
|
||||
|
||||
|
||||
def _safe_get(obj, key, default=None):
|
||||
if obj is None:
|
||||
return default
|
||||
@@ -90,9 +173,20 @@ def is_chat_provider_available(provider_id):
|
||||
return normalized_provider in CHAT_PROVIDER_CONFIG
|
||||
|
||||
|
||||
def get_provider_catalog():
|
||||
def get_provider_catalog(user=None):
|
||||
seen = set()
|
||||
catalog = []
|
||||
user_key_providers = set()
|
||||
instance_provider = (
|
||||
_normalize_provider_id(settings.VOYAGE_AI_PROVIDER)
|
||||
if settings.VOYAGE_AI_PROVIDER
|
||||
else None
|
||||
)
|
||||
instance_has_key = bool(settings.VOYAGE_AI_API_KEY)
|
||||
if user:
|
||||
user_key_providers = set(
|
||||
UserAPIKey.objects.filter(user=user).values_list("provider", flat=True)
|
||||
)
|
||||
|
||||
for provider_id in getattr(litellm, "provider_list", []):
|
||||
normalized_provider = _normalize_provider_id(provider_id)
|
||||
@@ -110,6 +204,9 @@ def get_provider_catalog():
|
||||
"needs_api_key": provider_config["needs_api_key"],
|
||||
"default_model": provider_config["default_model"],
|
||||
"api_base": provider_config["api_base"],
|
||||
"instance_configured": instance_has_key
|
||||
and normalized_provider == instance_provider,
|
||||
"user_configured": normalized_provider in user_key_providers,
|
||||
}
|
||||
)
|
||||
continue
|
||||
@@ -122,6 +219,9 @@ def get_provider_catalog():
|
||||
"needs_api_key": None,
|
||||
"default_model": None,
|
||||
"api_base": None,
|
||||
"instance_configured": instance_has_key
|
||||
and normalized_provider == instance_provider,
|
||||
"user_configured": normalized_provider in user_key_providers,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -141,6 +241,9 @@ def get_provider_catalog():
|
||||
"needs_api_key": provider_config["needs_api_key"],
|
||||
"default_model": provider_config["default_model"],
|
||||
"api_base": provider_config["api_base"],
|
||||
"instance_configured": instance_has_key
|
||||
and normalized_provider == instance_provider,
|
||||
"user_configured": normalized_provider in user_key_providers,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -154,9 +257,55 @@ def get_llm_api_key(user, provider):
|
||||
key_obj = UserAPIKey.objects.get(user=user, provider=normalized_provider)
|
||||
return key_obj.get_api_key()
|
||||
except UserAPIKey.DoesNotExist:
|
||||
if normalized_provider == _normalize_provider_id(settings.VOYAGE_AI_PROVIDER):
|
||||
instance_api_key = (settings.VOYAGE_AI_API_KEY or "").strip()
|
||||
if instance_api_key:
|
||||
return instance_api_key
|
||||
return None
|
||||
|
||||
|
||||
def _format_interests(interests):
|
||||
if isinstance(interests, list):
|
||||
return ", ".join(interests)
|
||||
return interests
|
||||
|
||||
|
||||
def get_aggregated_preferences(collection):
    """Aggregate preferences from collection owner and shared users.

    Builds a markdown "Party Preferences" section covering the collection
    owner plus every user the collection is shared with, in that order.
    Members without a preference profile (or with an empty one) are skipped.
    Returns an empty string when no member contributes any preferences.
    """
    from integrations.models import UserRecommendationPreferenceProfile

    members = [collection.user] + list(collection.shared_with.all())

    # Fetch all profiles in a single query to avoid N+1 lookups per member.
    profiles_by_user = {
        profile.user_id: profile
        for profile in UserRecommendationPreferenceProfile.objects.filter(
            user__in=members
        )
    }

    preferences = []
    for member in members:
        profile = profiles_by_user.get(member.pk)
        if profile is None:
            continue

        user_prefs = []
        if profile.cuisines:
            user_prefs.append(f"cuisines: {profile.cuisines}")
        if profile.interests:
            user_prefs.append(f"interests: {_format_interests(profile.interests)}")
        if profile.trip_style:
            user_prefs.append(f"style: {profile.trip_style}")
        if profile.notes:
            user_prefs.append(f"notes: {profile.notes}")

        if user_prefs:
            preferences.append(f"- **{member.username}**: {', '.join(user_prefs)}")

    if not preferences:
        return ""

    return (
        "\n\n## Party Preferences\n"
        + "\n".join(preferences)
        + "\n\nNote: Consider all travelers' preferences when making recommendations."
    )
|
||||
|
||||
|
||||
def get_system_prompt(user, collection=None):
|
||||
"""Build the system prompt with user context."""
|
||||
from integrations.models import UserRecommendationPreferenceProfile
|
||||
@@ -181,26 +330,37 @@ When modifying itineraries:
|
||||
|
||||
Be conversational, helpful, and enthusiastic about travel. Keep responses concise but informative."""
|
||||
|
||||
try:
|
||||
profile = UserRecommendationPreferenceProfile.objects.get(user=user)
|
||||
prefs = []
|
||||
if profile.cuisines:
|
||||
prefs.append(f"Cuisine preferences: {profile.cuisines}")
|
||||
if profile.interests:
|
||||
prefs.append(f"Interests: {profile.interests}")
|
||||
if profile.trip_style:
|
||||
prefs.append(f"Travel style: {profile.trip_style}")
|
||||
if profile.notes:
|
||||
prefs.append(f"Additional notes: {profile.notes}")
|
||||
if prefs:
|
||||
base_prompt += "\n\nUser preferences:\n" + "\n".join(prefs)
|
||||
except UserRecommendationPreferenceProfile.DoesNotExist:
|
||||
pass
|
||||
if collection and collection.shared_with.count() > 0:
|
||||
base_prompt += get_aggregated_preferences(collection)
|
||||
else:
|
||||
try:
|
||||
profile = UserRecommendationPreferenceProfile.objects.get(user=user)
|
||||
preference_lines = []
|
||||
|
||||
if profile.cuisines:
|
||||
preference_lines.append(
|
||||
f"🍽️ **Cuisine Preferences**: {profile.cuisines}"
|
||||
)
|
||||
if profile.interests:
|
||||
preference_lines.append(
|
||||
f"🎯 **Interests**: {_format_interests(profile.interests)}"
|
||||
)
|
||||
if profile.trip_style:
|
||||
preference_lines.append(f"✈️ **Travel Style**: {profile.trip_style}")
|
||||
if profile.notes:
|
||||
preference_lines.append(f"📝 **Additional Notes**: {profile.notes}")
|
||||
|
||||
if preference_lines:
|
||||
base_prompt += "\n\n## Traveler Preferences\n" + "\n".join(
|
||||
preference_lines
|
||||
)
|
||||
except UserRecommendationPreferenceProfile.DoesNotExist:
|
||||
pass
|
||||
|
||||
return base_prompt
|
||||
|
||||
|
||||
async def stream_chat_completion(user, messages, provider, tools=None):
|
||||
async def stream_chat_completion(user, messages, provider, tools=None, model=None):
|
||||
"""Stream a chat completion using LiteLLM.
|
||||
|
||||
Yields SSE-formatted strings.
|
||||
@@ -215,6 +375,7 @@ async def stream_chat_completion(user, messages, provider, tools=None):
|
||||
return
|
||||
|
||||
api_key = get_llm_api_key(user, normalized_provider)
|
||||
|
||||
if provider_config["needs_api_key"] and not api_key:
|
||||
payload = {
|
||||
"error": f"No API key found for provider: {normalized_provider}. Please add one in Settings."
|
||||
@@ -222,14 +383,31 @@ async def stream_chat_completion(user, messages, provider, tools=None):
|
||||
yield f"data: {json.dumps(payload)}\n\n"
|
||||
return
|
||||
|
||||
if not _is_model_override_compatible(normalized_provider, provider_config, model):
|
||||
payload = {
|
||||
"error": "The selected model is incompatible with this provider. Choose a model for the selected provider and try again.",
|
||||
"error_category": "invalid_model_for_provider",
|
||||
}
|
||||
yield f"data: {json.dumps(payload)}\n\n"
|
||||
return
|
||||
|
||||
completion_kwargs = {
|
||||
"model": provider_config["default_model"],
|
||||
"model": model
|
||||
or (
|
||||
settings.VOYAGE_AI_MODEL
|
||||
if normalized_provider
|
||||
== _normalize_provider_id(settings.VOYAGE_AI_PROVIDER)
|
||||
and settings.VOYAGE_AI_MODEL
|
||||
else None
|
||||
)
|
||||
or provider_config["default_model"],
|
||||
"messages": messages,
|
||||
"tools": tools,
|
||||
"tool_choice": "auto" if tools else None,
|
||||
"stream": True,
|
||||
"api_key": api_key,
|
||||
}
|
||||
if tools:
|
||||
completion_kwargs["tools"] = tools
|
||||
completion_kwargs["tool_choice"] = "auto"
|
||||
if provider_config["api_base"]:
|
||||
completion_kwargs["api_base"] = provider_config["api_base"]
|
||||
|
||||
@@ -271,6 +449,7 @@ async def stream_chat_completion(user, messages, provider, tools=None):
|
||||
yield f"data: {json.dumps(chunk_data)}\n\n"
|
||||
|
||||
yield "data: [DONE]\n\n"
|
||||
except Exception:
|
||||
except Exception as exc:
|
||||
logger.exception("LLM streaming error")
|
||||
yield f"data: {json.dumps({'error': 'An error occurred while processing your request. Please try again.'})}\n\n"
|
||||
payload = _safe_error_payload(exc)
|
||||
yield f"data: {json.dumps(payload)}\n\n"
|
||||
|
||||
Reference in New Issue
Block a user