import json import logging import litellm from asgiref.sync import sync_to_async from django.conf import settings from integrations.models import UserAPIKey logger = logging.getLogger(__name__) PROVIDER_MODEL_PREFIX = { "openai": "openai", "anthropic": "anthropic", "gemini": "gemini", "ollama": "ollama", "groq": "groq", "mistral": "mistral", "github_models": "github", "openrouter": "openrouter", } CHAT_PROVIDER_CONFIG = { "openai": { "label": "OpenAI", "needs_api_key": True, "default_model": "gpt-4o", "api_base": None, }, "anthropic": { "label": "Anthropic", "needs_api_key": True, "default_model": "anthropic/claude-sonnet-4-20250514", "api_base": None, }, "gemini": { "label": "Google Gemini", "needs_api_key": True, "default_model": "gemini/gemini-2.0-flash", "api_base": None, }, "ollama": { "label": "Ollama", "needs_api_key": True, "default_model": "ollama/llama3.1", "api_base": None, }, "groq": { "label": "Groq", "needs_api_key": True, "default_model": "groq/llama-3.3-70b-versatile", "api_base": None, }, "mistral": { "label": "Mistral", "needs_api_key": True, "default_model": "mistral/mistral-large-latest", "api_base": None, }, "github_models": { "label": "GitHub Models", "needs_api_key": True, "default_model": "github/gpt-4o", "api_base": None, }, "openrouter": { "label": "OpenRouter", "needs_api_key": True, "default_model": "openrouter/auto", "api_base": None, }, "opencode_zen": { "label": "OpenCode Zen", "needs_api_key": True, # Chosen from OpenCode Zen compatible OpenAI-routed models per # opencode_zen connection research (see .memory/research). "default_model": "openai/gpt-5-nano", "api_base": "https://opencode.ai/zen/v1", }, } def _is_model_override_compatible(provider, provider_config, model): """Validate model/provider compatibility when strict checks are safe. For providers with a custom api_base gateway, skip strict prefix checks since gateway routing may legitimately accept cross-provider prefixes. """ if not model or provider_config.get("api_base"): return True if "/" not in model: return True expected_prefix = PROVIDER_MODEL_PREFIX.get(provider) if not expected_prefix: default_model = provider_config.get("default_model") or "" if "/" in default_model: expected_prefix = default_model.split("/", 1)[0] if not expected_prefix: return True return model.startswith(f"{expected_prefix}/") def _safe_error_payload(exc): exceptions = getattr(litellm, "exceptions", None) not_found_cls = getattr(exceptions, "NotFoundError", tuple()) auth_cls = getattr(exceptions, "AuthenticationError", tuple()) rate_limit_cls = getattr(exceptions, "RateLimitError", tuple()) bad_request_cls = getattr(exceptions, "BadRequestError", tuple()) timeout_cls = getattr(exceptions, "Timeout", tuple()) api_connection_cls = getattr(exceptions, "APIConnectionError", tuple()) if isinstance(exc, not_found_cls): return { "error": "The selected model is unavailable for this provider. Choose a different model and try again.", "error_category": "model_not_found", } if isinstance(exc, auth_cls): return { "error": "Authentication with the model provider failed. Verify your API key in Settings and try again.", "error_category": "authentication_failed", } if isinstance(exc, rate_limit_cls): return { "error": "The model provider rate limit was reached. Please wait and try again.", "error_category": "rate_limited", } if isinstance(exc, bad_request_cls): return { "error": "The model provider rejected this request. Check your selected model and try again.", "error_category": "invalid_request", } if isinstance(exc, timeout_cls) or isinstance(exc, api_connection_cls): return { "error": "Unable to reach the model provider right now. Please try again.", "error_category": "provider_unreachable", } return { "error": "An error occurred while processing your request. Please try again.", "error_category": "unknown_error", } def _safe_get(obj, key, default=None): if obj is None: return default if isinstance(obj, dict): return obj.get(key, default) return getattr(obj, key, default) def _normalize_provider_id(provider_id): value = str(provider_id or "").strip() lowered = value.lower() if lowered.startswith("llmproviders."): return lowered.split(".", 1)[1] return lowered def normalize_gateway_model(provider_id, model): normalized_provider = _normalize_provider_id(provider_id) normalized_model = str(model or "").strip() if not normalized_model: return None if normalized_provider == "opencode_zen" and "/" not in normalized_model: return f"openai/{normalized_model}" return normalized_model def _default_provider_label(provider_id): return provider_id.replace("_", " ").title() def is_chat_provider_available(provider_id): normalized_provider = _normalize_provider_id(provider_id) return normalized_provider in CHAT_PROVIDER_CONFIG def get_provider_catalog(user=None): seen = set() catalog = [] user_key_providers = set() instance_provider = ( _normalize_provider_id(settings.VOYAGE_AI_PROVIDER) if settings.VOYAGE_AI_PROVIDER else None ) instance_has_key = bool(settings.VOYAGE_AI_API_KEY) if user: user_key_providers = set( UserAPIKey.objects.filter(user=user).values_list("provider", flat=True) ) for provider_id in getattr(litellm, "provider_list", []): normalized_provider = _normalize_provider_id(provider_id) if not normalized_provider or normalized_provider in seen: continue seen.add(normalized_provider) provider_config = CHAT_PROVIDER_CONFIG.get(normalized_provider) if provider_config: catalog.append( { "id": normalized_provider, "label": provider_config["label"], "available_for_chat": True, "needs_api_key": provider_config["needs_api_key"], "default_model": provider_config["default_model"], "api_base": provider_config["api_base"], "instance_configured": instance_has_key and normalized_provider == instance_provider, "user_configured": normalized_provider in user_key_providers, } ) continue catalog.append( { "id": normalized_provider, "label": _default_provider_label(normalized_provider), "available_for_chat": False, "needs_api_key": None, "default_model": None, "api_base": None, "instance_configured": instance_has_key and normalized_provider == instance_provider, "user_configured": normalized_provider in user_key_providers, } ) # Include app-supported OpenAI-compatible aliases that are not part of # LiteLLM's native provider list (for example OpenCode Zen). for provider_id, provider_config in CHAT_PROVIDER_CONFIG.items(): normalized_provider = _normalize_provider_id(provider_id) if not normalized_provider or normalized_provider in seen: continue seen.add(normalized_provider) catalog.append( { "id": normalized_provider, "label": provider_config["label"], "available_for_chat": True, "needs_api_key": provider_config["needs_api_key"], "default_model": provider_config["default_model"], "api_base": provider_config["api_base"], "instance_configured": instance_has_key and normalized_provider == instance_provider, "user_configured": normalized_provider in user_key_providers, } ) return catalog def get_llm_api_key(user, provider): """Get the user's API key for the given provider.""" normalized_provider = _normalize_provider_id(provider) try: key_obj = UserAPIKey.objects.get(user=user, provider=normalized_provider) return key_obj.get_api_key() except UserAPIKey.DoesNotExist: if normalized_provider == _normalize_provider_id(settings.VOYAGE_AI_PROVIDER): instance_api_key = (settings.VOYAGE_AI_API_KEY or "").strip() if instance_api_key: return instance_api_key return None def _format_interests(interests): if isinstance(interests, list): return ", ".join(interests) return interests def get_aggregated_preferences(collection): """Aggregate preferences from collection owner and shared users.""" from integrations.models import UserRecommendationPreferenceProfile users = [collection.user] + list(collection.shared_with.all()) preferences = [] for member in users: try: profile = UserRecommendationPreferenceProfile.objects.get(user=member) user_prefs = [] if profile.cuisines: user_prefs.append(f"cuisines: {profile.cuisines}") if profile.interests: user_prefs.append(f"interests: {_format_interests(profile.interests)}") if profile.trip_style: user_prefs.append(f"style: {profile.trip_style}") if profile.notes: user_prefs.append(f"notes: {profile.notes}") if user_prefs: preferences.append(f"- **{member.username}**: {', '.join(user_prefs)}") except UserRecommendationPreferenceProfile.DoesNotExist: continue if preferences: return ( "\n\n## Party Preferences\n" + "\n".join(preferences) + "\n\nNote: Consider all travelers' preferences when making recommendations." ) return "" def get_system_prompt(user, collection=None): """Build the system prompt with user context.""" from integrations.models import UserRecommendationPreferenceProfile base_prompt = """You are a helpful travel planning assistant for the Voyage travel app. You help users discover places, plan trips, and organize their itineraries. Your capabilities: - Search for interesting places (restaurants, tourist attractions, hotels) near any location - View and manage the user's trip collections and itineraries - Add new locations to trip itineraries - Check weather/temperature data for travel dates When suggesting places: - Be specific with names, addresses, and why a place is worth visiting - Consider the user's travel dates and weather conditions - Group suggestions logically (by area, by type, by day) When modifying itineraries: - Always confirm with the user before adding items - Suggest logical ordering based on geography - Consider travel time between locations When chat context includes a trip collection: - Treat context as itinerary-wide (potentially multiple stops), not a single destination - Use get_trip_details first when you need complete collection context before searching for places - Ground place searches in trip stops and dates from the provided trip context Be conversational, helpful, and enthusiastic about travel. Keep responses concise but informative.""" if collection and collection.shared_with.count() > 0: base_prompt += get_aggregated_preferences(collection) else: try: profile = UserRecommendationPreferenceProfile.objects.get(user=user) if profile.interests or profile.trip_style or profile.notes: base_prompt += "\n\n## Traveler Preferences\n" base_prompt += "*(Automatically inferred from travel history)*\n\n" if profile.interests: interests_str = ( ", ".join(profile.interests) if isinstance(profile.interests, list) else str(profile.interests) ) base_prompt += f"🎯 **Interests**: {interests_str}\n" if profile.trip_style: base_prompt += f"✈️ **Travel Style**: {profile.trip_style}\n" if profile.notes: base_prompt += f"📍 **Patterns**: {profile.notes}\n" except UserRecommendationPreferenceProfile.DoesNotExist: pass return base_prompt async def stream_chat_completion(user, messages, provider, tools=None, model=None): """Stream a chat completion using LiteLLM. Yields SSE-formatted strings. """ normalized_provider = _normalize_provider_id(provider) provider_config = CHAT_PROVIDER_CONFIG.get(normalized_provider) if not provider_config: payload = { "error": f"Provider is not available for chat: {normalized_provider}." } yield f"data: {json.dumps(payload)}\n\n" return api_key = await sync_to_async(get_llm_api_key)(user, normalized_provider) if provider_config["needs_api_key"] and not api_key: payload = { "error": f"No API key found for provider: {normalized_provider}. Please add one in Settings." } yield f"data: {json.dumps(payload)}\n\n" return if not _is_model_override_compatible(normalized_provider, provider_config, model): payload = { "error": "The selected model is incompatible with this provider. Choose a model for the selected provider and try again.", "error_category": "invalid_model_for_provider", } yield f"data: {json.dumps(payload)}\n\n" return resolved_model = ( model or ( settings.VOYAGE_AI_MODEL if normalized_provider == _normalize_provider_id(settings.VOYAGE_AI_PROVIDER) and settings.VOYAGE_AI_MODEL else None ) or provider_config["default_model"] ) resolved_model = normalize_gateway_model(normalized_provider, resolved_model) if tools and not litellm.supports_function_calling(model=resolved_model): logger.warning( "Model %s does not support function calling, disabling tools", resolved_model, ) tools = None logger.info( "Chat request: provider=%s, model=%s, has_tools=%s", normalized_provider, resolved_model, bool(tools), ) logger.debug( "API base: %s, messages count: %s", provider_config.get("api_base"), len(messages), ) completion_kwargs = { "model": resolved_model, "messages": messages, "stream": True, "api_key": api_key, "num_retries": 2, } if tools: completion_kwargs["tools"] = tools completion_kwargs["tool_choice"] = "auto" if provider_config["api_base"]: completion_kwargs["api_base"] = provider_config["api_base"] try: response = await litellm.acompletion(**completion_kwargs) async for chunk in response: choices = _safe_get(chunk, "choices", []) or [] if not choices: continue delta = _safe_get(choices[0], "delta") if not delta: continue chunk_data = {} content = _safe_get(delta, "content") if content: chunk_data["content"] = content tool_calls = _safe_get(delta, "tool_calls") or [] if tool_calls: serialized = [] for tool_call in tool_calls: function = _safe_get(tool_call, "function") serialized.append( { "id": _safe_get(tool_call, "id"), "type": _safe_get(tool_call, "type"), "function": { "name": _safe_get(function, "name", "") or "", "arguments": _safe_get(function, "arguments", "") or "", }, } ) chunk_data["tool_calls"] = serialized if chunk_data: yield f"data: {json.dumps(chunk_data)}\n\n" yield "data: [DONE]\n\n" except Exception as exc: logger.error("LiteLLM error: %s: %s", type(exc).__name__, str(exc)[:200]) logger.exception("LLM streaming error") payload = _safe_error_payload(exc) yield f"data: {json.dumps(payload)}\n\n"