feat(chat): add dynamic provider catalog and zen support
This commit is contained in:
@@ -7,15 +7,61 @@ from integrations.models import UserAPIKey
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PROVIDER_MODELS = {
|
||||
"openai": "gpt-4o",
|
||||
"anthropic": "anthropic/claude-sonnet-4-20250514",
|
||||
"gemini": "gemini/gemini-2.0-flash",
|
||||
"ollama": "ollama/llama3.1",
|
||||
"groq": "groq/llama-3.3-70b-versatile",
|
||||
"mistral": "mistral/mistral-large-latest",
|
||||
"github_models": "github/gpt-4o",
|
||||
"openrouter": "openrouter/auto",
|
||||
CHAT_PROVIDER_CONFIG = {
|
||||
"openai": {
|
||||
"label": "OpenAI",
|
||||
"needs_api_key": True,
|
||||
"default_model": "gpt-4o",
|
||||
"api_base": None,
|
||||
},
|
||||
"anthropic": {
|
||||
"label": "Anthropic",
|
||||
"needs_api_key": True,
|
||||
"default_model": "anthropic/claude-sonnet-4-20250514",
|
||||
"api_base": None,
|
||||
},
|
||||
"gemini": {
|
||||
"label": "Google Gemini",
|
||||
"needs_api_key": True,
|
||||
"default_model": "gemini/gemini-2.0-flash",
|
||||
"api_base": None,
|
||||
},
|
||||
"ollama": {
|
||||
"label": "Ollama",
|
||||
"needs_api_key": True,
|
||||
"default_model": "ollama/llama3.1",
|
||||
"api_base": None,
|
||||
},
|
||||
"groq": {
|
||||
"label": "Groq",
|
||||
"needs_api_key": True,
|
||||
"default_model": "groq/llama-3.3-70b-versatile",
|
||||
"api_base": None,
|
||||
},
|
||||
"mistral": {
|
||||
"label": "Mistral",
|
||||
"needs_api_key": True,
|
||||
"default_model": "mistral/mistral-large-latest",
|
||||
"api_base": None,
|
||||
},
|
||||
"github_models": {
|
||||
"label": "GitHub Models",
|
||||
"needs_api_key": True,
|
||||
"default_model": "github/gpt-4o",
|
||||
"api_base": None,
|
||||
},
|
||||
"openrouter": {
|
||||
"label": "OpenRouter",
|
||||
"needs_api_key": True,
|
||||
"default_model": "openrouter/auto",
|
||||
"api_base": None,
|
||||
},
|
||||
"opencode_zen": {
|
||||
"label": "OpenCode Zen",
|
||||
"needs_api_key": True,
|
||||
"default_model": "openai/gpt-4o-mini",
|
||||
"api_base": "https://opencode.ai/zen/v1",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -27,9 +73,82 @@ def _safe_get(obj, key, default=None):
|
||||
return getattr(obj, key, default)
|
||||
|
||||
|
||||
def _normalize_provider_id(provider_id):
|
||||
value = str(provider_id or "").strip()
|
||||
if value.startswith("LlmProviders."):
|
||||
value = value.split(".", 1)[1]
|
||||
return value.lower()
|
||||
|
||||
|
||||
def _default_provider_label(provider_id):
|
||||
return provider_id.replace("_", " ").title()
|
||||
|
||||
|
||||
def is_chat_provider_available(provider_id):
    """Return True when the (normalized) provider is configured for chat."""
    return _normalize_provider_id(provider_id) in CHAT_PROVIDER_CONFIG
|
||||
|
||||
|
||||
def _catalog_entry(provider_id, provider_config):
    """Build one catalog row for *provider_id*.

    When *provider_config* is a truthy CHAT_PROVIDER_CONFIG entry, the row is
    marked available for chat and carries the configured metadata; otherwise
    the row is a disabled placeholder with a generated label.
    """
    if provider_config:
        return {
            "id": provider_id,
            "label": provider_config["label"],
            "available_for_chat": True,
            "needs_api_key": provider_config["needs_api_key"],
            "default_model": provider_config["default_model"],
            "api_base": provider_config["api_base"],
        }
    return {
        "id": provider_id,
        "label": _default_provider_label(provider_id),
        "available_for_chat": False,
        "needs_api_key": None,
        "default_model": None,
        "api_base": None,
    }


def get_provider_catalog():
    """Return the full provider catalog as a list of dicts.

    The catalog contains every provider LiteLLM advertises (deduplicated,
    normalized) plus any app-configured providers missing from that list.
    Providers present in CHAT_PROVIDER_CONFIG are marked available for chat;
    all others are listed but disabled. Each entry has the keys: ``id``,
    ``label``, ``available_for_chat``, ``needs_api_key``, ``default_model``,
    ``api_base``.
    """
    seen = set()
    catalog = []

    # getattr guards against LiteLLM versions without `provider_list`.
    for provider_id in getattr(litellm, "provider_list", []):
        normalized_provider = _normalize_provider_id(provider_id)
        if not normalized_provider or normalized_provider in seen:
            continue
        seen.add(normalized_provider)
        catalog.append(
            _catalog_entry(
                normalized_provider, CHAT_PROVIDER_CONFIG.get(normalized_provider)
            )
        )

    # Include app-supported OpenAI-compatible aliases that are not part of
    # LiteLLM's native provider list (for example OpenCode Zen).
    for provider_id, provider_config in CHAT_PROVIDER_CONFIG.items():
        normalized_provider = _normalize_provider_id(provider_id)
        if not normalized_provider or normalized_provider in seen:
            continue
        seen.add(normalized_provider)
        catalog.append(_catalog_entry(normalized_provider, provider_config))

    return catalog
|
||||
|
||||
|
||||
def get_llm_api_key(user, provider):
|
||||
"""Get the user's API key for the given provider."""
|
||||
normalized_provider = (provider or "").strip().lower()
|
||||
normalized_provider = _normalize_provider_id(provider)
|
||||
try:
|
||||
key_obj = UserAPIKey.objects.get(user=user, provider=normalized_provider)
|
||||
return key_obj.get_api_key()
|
||||
@@ -85,26 +204,36 @@ async def stream_chat_completion(user, messages, provider, tools=None):
|
||||
|
||||
Yields SSE-formatted strings.
|
||||
"""
|
||||
normalized_provider = (provider or "").strip().lower()
|
||||
normalized_provider = _normalize_provider_id(provider)
|
||||
provider_config = CHAT_PROVIDER_CONFIG.get(normalized_provider)
|
||||
if not provider_config:
|
||||
payload = {
|
||||
"error": f"Provider is not available for chat: {normalized_provider}."
|
||||
}
|
||||
yield f"data: {json.dumps(payload)}\n\n"
|
||||
return
|
||||
|
||||
api_key = get_llm_api_key(user, normalized_provider)
|
||||
if not api_key:
|
||||
if provider_config["needs_api_key"] and not api_key:
|
||||
payload = {
|
||||
"error": f"No API key found for provider: {normalized_provider}. Please add one in Settings."
|
||||
}
|
||||
yield f"data: {json.dumps(payload)}\n\n"
|
||||
return
|
||||
|
||||
model = PROVIDER_MODELS.get(normalized_provider, "gpt-4o")
|
||||
completion_kwargs = {
|
||||
"model": provider_config["default_model"],
|
||||
"messages": messages,
|
||||
"tools": tools,
|
||||
"tool_choice": "auto" if tools else None,
|
||||
"stream": True,
|
||||
"api_key": api_key,
|
||||
}
|
||||
if provider_config["api_base"]:
|
||||
completion_kwargs["api_base"] = provider_config["api_base"]
|
||||
|
||||
try:
|
||||
response = await litellm.acompletion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto" if tools else None,
|
||||
stream=True,
|
||||
api_key=api_key,
|
||||
)
|
||||
response = await litellm.acompletion(**completion_kwargs)
|
||||
|
||||
async for chunk in response:
|
||||
choices = _safe_get(chunk, "choices", []) or []
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
from django.urls import include, path
|
||||
from rest_framework.routers import DefaultRouter
|
||||
|
||||
from .views import ChatViewSet
|
||||
from .views import ChatProviderCatalogViewSet, ChatViewSet
|
||||
|
||||
router = DefaultRouter()
|
||||
router.register(r"conversations", ChatViewSet, basename="chat-conversation")
|
||||
router.register(
|
||||
r"providers", ChatProviderCatalogViewSet, basename="chat-provider-catalog"
|
||||
)
|
||||
|
||||
urlpatterns = [
|
||||
path("", include(router.urls)),
|
||||
|
||||
@@ -9,7 +9,12 @@ from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
|
||||
from .agent_tools import AGENT_TOOLS, execute_tool, serialize_tool_result
|
||||
from .llm_client import get_system_prompt, stream_chat_completion
|
||||
from .llm_client import (
|
||||
get_provider_catalog,
|
||||
get_system_prompt,
|
||||
is_chat_provider_available,
|
||||
stream_chat_completion,
|
||||
)
|
||||
from .models import ChatConversation, ChatMessage
|
||||
from .serializers import ChatConversationSerializer
|
||||
|
||||
@@ -106,6 +111,11 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
)
|
||||
|
||||
provider = (request.data.get("provider") or "openai").strip().lower()
|
||||
if not is_chat_provider_available(provider):
|
||||
return Response(
|
||||
{"error": f"Provider is not available for chat: {provider}."},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
ChatMessage.objects.create(
|
||||
conversation=conversation,
|
||||
@@ -262,3 +272,10 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
response["Cache-Control"] = "no-cache"
|
||||
response["X-Accel-Buffering"] = "no"
|
||||
return response
|
||||
|
||||
|
||||
class ChatProviderCatalogViewSet(viewsets.ViewSet):
    """Read-only endpoint exposing the chat provider catalog to the UI."""

    # Only logged-in users may enumerate providers.
    permission_classes = [IsAuthenticated]

    def list(self, request):
        # GET /providers/ — returns the list built by get_provider_catalog().
        return Response(get_provider_catalog())
|
||||
|
||||
Reference in New Issue
Block a user