Files
voyage/backend/server/chat/views/__init__.py
alex wiesner c4d39f2812 changes
2026-03-13 20:15:22 +00:00

1150 lines
44 KiB
Python

import asyncio
import json
import logging
import re
from datetime import timedelta
from asgiref.sync import sync_to_async
from adventures.models import Collection
from django.http import StreamingHttpResponse
from django.utils import timezone
from integrations.models import UserAISettings
from rest_framework import status, viewsets
from rest_framework.decorators import action
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from ..agent_tools import AGENT_TOOLS, execute_tool, serialize_tool_result
from ..llm_client import (
get_provider_catalog,
get_system_prompt,
is_chat_provider_available,
stream_chat_completion,
)
from ..models import ChatConversation, ChatMessage
from ..serializers import ChatConversationSerializer
logger = logging.getLogger(__name__)
class ChatViewSet(viewsets.ModelViewSet):
serializer_class = ChatConversationSerializer
permission_classes = [IsAuthenticated]
def get_queryset(self):
    """Conversations owned by the requesting user, with messages prefetched."""
    owner = self.request.user
    queryset = ChatConversation.objects.filter(user=owner)
    return queryset.prefetch_related("messages")
def list(self, request, *args, **kwargs):
    """Return a lightweight summary (id/title/updated_at) per conversation."""
    summaries = []
    for conversation in self.get_queryset().only("id", "title", "updated_at"):
        summaries.append(
            {
                "id": str(conversation.id),
                "title": conversation.title,
                "updated_at": conversation.updated_at,
            }
        )
    return Response(summaries)
def create(self, request, *args, **kwargs):
    """Create a new conversation; the title defaults to an empty string."""
    raw_title = request.data.get("title") or ""
    conversation = ChatConversation.objects.create(
        user=request.user,
        title=raw_title.strip(),
    )
    payload = self.get_serializer(conversation).data
    return Response(payload, status=status.HTTP_201_CREATED)
def _build_llm_messages(self, conversation, user, system_prompt=None):
    """Rebuild the provider-format message list from stored chat history.

    Tool messages that recorded validation or execution failures are
    skipped, and each assistant message's tool_calls list is filtered down
    to the IDs whose tool results survived — so the provider never sees a
    tool_call without a matching tool message (or vice versa).
    """
    ordered_messages = list(conversation.messages.all().order_by("created_at"))
    # IDs of tool results that are NOT failure payloads; only these may be
    # referenced by assistant tool_calls below.
    valid_tool_call_ids = {
        message.tool_call_id
        for message in ordered_messages
        if message.role == "tool"
        and message.tool_call_id
        and not self._is_required_param_tool_error_message_content(message.content)
        and not self._is_execution_failure_tool_error_message_content(
            message.content
        )
    }
    # System prompt always leads; caller may supply a context-augmented one.
    messages = [
        {
            "role": "system",
            "content": system_prompt or get_system_prompt(user),
        }
    ]
    for message in ordered_messages:
        # Drop failed tool results entirely; their paired tool_calls
        # entries are filtered out further below so history stays valid.
        if (
            message.role == "tool"
            and self._is_required_param_tool_error_message_content(message.content)
        ):
            continue
        if (
            message.role == "tool"
            and self._is_execution_failure_tool_error_message_content(
                message.content
            )
        ):
            continue
        payload = {
            "role": message.role,
            "content": message.content,
        }
        if message.role == "assistant" and message.tool_calls:
            # Keep only tool_calls whose results were kept above.
            filtered_tool_calls = [
                tool_call
                for tool_call in message.tool_calls
                if (tool_call or {}).get("id") in valid_tool_call_ids
            ]
            if filtered_tool_calls:
                payload["tool_calls"] = filtered_tool_calls
        if message.role == "tool":
            payload["tool_call_id"] = message.tool_call_id
            payload["name"] = message.name
        messages.append(payload)
    return messages
def _async_to_sync_generator(self, async_gen):
    """Synchronously drain *async_gen* on a private event loop.

    StreamingHttpResponse consumes a sync iterator, so each item of the
    async SSE generator is pulled with ``run_until_complete``. A fresh
    loop is created per call so concurrent requests never share loop
    state; the ``finally`` block runs even when the consumer stops
    iterating early (e.g. client disconnect), so the async generator's
    own cleanup code executes and the loop is always closed.
    """
    loop = asyncio.new_event_loop()
    try:
        while True:
            try:
                yield loop.run_until_complete(async_gen.__anext__())
            except StopAsyncIteration:
                break
    finally:
        try:
            # Close the source generator explicitly: shutdown_asyncgens()
            # only finalizes generators the loop has already registered,
            # which misses a generator abandoned before its first step.
            loop.run_until_complete(async_gen.aclose())
        except RuntimeError:
            # Generator already closing/closed; nothing left to clean up.
            pass
        loop.run_until_complete(loop.shutdown_asyncgens())
        loop.close()
@staticmethod
def _merge_tool_call_delta(accumulator, tool_calls_delta):
    """Fold one streaming tool-call delta into *accumulator* in place.

    Each delta entry targets a slot by its "index" key (falling back to
    its list position); ids and types overwrite, while function name and
    argument fragments are appended/assigned as they stream in.
    """
    for position, delta in enumerate(tool_calls_delta or []):
        slot = delta.get("index", position)
        # Grow the accumulator until the target slot exists.
        while len(accumulator) <= slot:
            accumulator.append(
                {
                    "id": None,
                    "type": "function",
                    "function": {"name": "", "arguments": ""},
                }
            )
        entry = accumulator[slot]
        delta_id = delta.get("id")
        if delta_id:
            entry["id"] = delta_id
        delta_type = delta.get("type")
        if delta_type:
            entry["type"] = delta_type
        function_delta = delta.get("function") or {}
        name_fragment = function_delta.get("name")
        if name_fragment:
            entry["function"]["name"] = name_fragment
        arguments_fragment = function_delta.get("arguments")
        if arguments_fragment:
            entry["function"]["arguments"] += arguments_fragment
@staticmethod
def _is_required_param_tool_error(result):
    """True when *result* is a tool error of the "<param> is required" shape."""
    if not isinstance(result, dict):
        return False
    error_text = result.get("error")
    if not isinstance(error_text, str):
        return False
    normalized = error_text.strip().lower()
    # Fast path for the two messages the tools are known to emit.
    if normalized in {"location is required", "query is required"}:
        return True
    pattern = r"[a-z0-9_,\-\s]+ (is|are) required"
    return re.fullmatch(pattern, normalized) is not None
@classmethod
def _is_execution_failure_tool_error(cls, result):
    """True for a non-empty tool error that is NOT a required-param error."""
    if not isinstance(result, dict):
        return False
    error_text = result.get("error")
    if not (isinstance(error_text, str) and error_text.strip()):
        return False
    return not cls._is_required_param_tool_error(result)
@staticmethod
def _is_retryable_execution_failure(result):
    """A failure is retryable unless the tool explicitly set retryable=False."""
    if isinstance(result, dict):
        return result.get("retryable", True) is not False
    return False
@classmethod
def _is_required_param_tool_error_message_content(cls, content):
    """True when *content* is JSON encoding a required-param tool error."""
    if not isinstance(content, str):
        return False
    try:
        decoded = json.loads(content)
    except json.JSONDecodeError:
        return False
    return cls._is_required_param_tool_error(decoded)
@classmethod
def _is_execution_failure_tool_error_message_content(cls, content):
    """True when *content* is JSON encoding an execution-failure tool error."""
    if not isinstance(content, str):
        return False
    try:
        decoded = json.loads(content)
    except json.JSONDecodeError:
        return False
    return cls._is_execution_failure_tool_error(decoded)
@staticmethod
def _build_required_param_error_event(tool_name, result):
    """SSE error payload for a tool call that lacked required arguments."""
    detail = result.get("error") if isinstance(result, dict) else ""
    message = (
        "The assistant attempted to call "
        f"'{tool_name}' without required arguments ({detail}). "
        "Please try your message again with more specific details."
    )
    return {"error": message, "error_category": "tool_validation_error"}
@classmethod
def _is_search_places_missing_location_required_error(cls, tool_name, result):
    """True when search_places failed validation specifically on 'location'."""
    if tool_name != "search_places":
        return False
    if not cls._is_required_param_tool_error(result):
        return False
    error_text = (result or {}).get("error") if isinstance(result, dict) else ""
    if not isinstance(error_text, str):
        return False
    return "location" in error_text.strip().lower()
@staticmethod
def _is_search_places_geocode_error(tool_name, result):
    """True when search_places failed because its location would not geocode."""
    if tool_name != "search_places" or not isinstance(result, dict):
        return False
    error_text = result.get("error")
    if not isinstance(error_text, str):
        return False
    normalized = error_text.strip().lower()
    return normalized.startswith("could not geocode location")
@classmethod
def _is_search_places_location_retry_candidate_error(cls, tool_name, result):
    """True for any search_places location failure worth retrying with context."""
    if cls._is_search_places_missing_location_required_error(tool_name, result):
        return True
    return cls._is_search_places_geocode_error(tool_name, result)
@classmethod
def _is_get_weather_missing_latlong_error(cls, tool_name, result):
    """True when get_weather was called without latitude/longitude."""
    if tool_name != "get_weather":
        return False
    if not cls._is_required_param_tool_error(result):
        return False
    error_text = (result or {}).get("error") if isinstance(result, dict) else ""
    if not isinstance(error_text, str):
        return False
    normalized = error_text.strip().lower()
    return any(axis in normalized for axis in ("latitude", "longitude"))
@staticmethod
def _extract_collection_coordinates(collection):
    """Return (lat, lon) from the first geocoded location in the collection, or None."""
    if collection is None:
        return None
    for location in collection.locations.all():
        latitude = getattr(location, "latitude", None)
        longitude = getattr(location, "longitude", None)
        if latitude is None or longitude is None:
            continue
        try:
            return float(latitude), float(longitude)
        except (TypeError, ValueError):
            # Unparseable coordinates; keep scanning later locations.
            continue
    return None
@staticmethod
def _derive_weather_dates_from_collection(collection, max_days=7):
    """Derive a bounded weather date list from collection dates, or fallback to today."""
    fallback = [timezone.localdate().isoformat()]
    if collection is None:
        return fallback
    start_date = getattr(collection, "start_date", None)
    end_date = getattr(collection, "end_date", None)
    if start_date and end_date:
        # Tolerate reversed dates; cap the range at max_days entries.
        first, last = sorted((start_date, end_date))
        span = min((last - first).days + 1, max_days)
        return [
            (first + timedelta(days=offset)).isoformat()
            for offset in range(span)
        ]
    single = start_date or end_date
    if single:
        return [single.isoformat()]
    return fallback
@staticmethod
def _build_search_places_location_clarification_message():
    """Canned assistant reply asking the user which location to search near."""
    message = (
        "Could you share the specific location you'd like me to search near "
        "(city, neighborhood, or address)? I can also focus on food, "
        "activities, or lodging."
    )
    return message
@staticmethod
def _build_tool_execution_error_event(tool_name, result):
    """SSE error payload for a tool that ran but failed."""
    if isinstance(result, dict):
        detail = (result or {}).get("error")
    else:
        detail = "Tool execution failed"
    message = (
        f"The assistant could not complete '{tool_name}' ({detail}). "
        "Please try again in a moment or adjust your request."
    )
    return {"error": message, "error_category": "tool_execution_error"}
@staticmethod
def _normalize_trip_context_destination(destination):
    """Pick the first real destination segment, dropping "+N more" placeholders."""
    text = (destination or "").strip()
    if not text:
        return ""
    placeholder = r"\+\d+\s+more"
    if ";" not in text:
        # Single segment: keep it unless it is only a "+N more" marker.
        return "" if re.fullmatch(placeholder, text, re.IGNORECASE) else text
    for raw_segment in text.split(";"):
        segment = raw_segment.strip()
        if not segment or re.fullmatch(placeholder, segment, re.IGNORECASE):
            continue
        return segment
    return ""
@classmethod
def _trip_context_search_location(cls, destination, itinerary_stops):
    """Best search location: normalized destination, else first non-empty stop."""
    normalized = cls._normalize_trip_context_destination(destination)
    if normalized:
        return normalized
    candidates = ((stop or "").strip() for stop in itinerary_stops or [])
    return next((candidate for candidate in candidates if candidate), "")
@staticmethod
def _infer_search_places_category(user_content, prior_user_messages):
    """Infer a search_places category ("food") from recent user text, else None."""
    fragments = [(user_content or "").strip()]
    for prior_content in prior_user_messages or []:
        fragments.append((prior_content or "").strip())
    haystack = " ".join(fragment for fragment in fragments if fragment).lower()
    if not haystack:
        return None
    dining_intent_pattern = (
        r"\b(restaurant|restaurants|dining|dinner|lunch|breakfast|brunch|"
        r"cafe|cafes|food|eat|eating|cuisine|meal|meals|bistro|bar|bars)\b"
    )
    return "food" if re.search(dining_intent_pattern, haystack) else None
# Verbs that indicate a command/request rather than a location reply.
# Consumed by _is_likely_location_reply: a short message whose first word
# is one of these is treated as a fresh request, not an answer to a
# "where should I search?" clarification prompt.
_COMMAND_VERBS = frozenset(
    [
        "find",
        "search",
        "show",
        "get",
        "look",
        "give",
        "tell",
        "help",
        "recommend",
        "suggest",
        "list",
        "fetch",
        "what",
        "where",
        "which",
        "who",
        "how",
        "can",
        "could",
        "would",
        "please",
    ]
)
@classmethod
def _is_likely_location_reply(cls, user_content):
    """Heuristic: does this short message look like a bare location answer?"""
    if not isinstance(user_content, str):
        return False
    text = user_content.strip()
    if not text or text.endswith("?") or len(text) > 80:
        return False
    words = text.split()
    if len(words) > 6:
        return False
    # Exclude messages that start with a command/request verb — those are
    # original requests, not replies to a "where?" clarification prompt.
    leading_word = words[0].lower().rstrip(".,!;:")
    if leading_word in cls._COMMAND_VERBS:
        return False
    return re.search(r"[a-z0-9]", text, re.IGNORECASE) is not None
@action(detail=True, methods=["post"])
def send_message(self, request, pk=None):
    """Append a user message and stream the assistant reply as SSE.

    Flow: validate input → resolve provider/model (request override, then
    user preference, then "openai") → assemble trip context from the
    optional collection → persist the user message → run a tool-calling
    loop against the LLM, retrying location/coordinate failures from trip
    context, persisting each round, and streaming events to the client.
    """
    # Auto-learn preferences from user's travel history
    from integrations.utils.auto_profile import update_auto_preference_profile

    try:
        update_auto_preference_profile(request.user)
    except Exception as exc:
        logger.warning("Auto-profile update failed: %s", exc)
        # Continue anyway - not critical
    conversation = self.get_object()
    user_content = (request.data.get("message") or "").strip()
    if not user_content:
        return Response(
            {"error": "message is required"},
            status=status.HTTP_400_BAD_REQUEST,
        )
    # Provider/model resolution: explicit request values win; otherwise the
    # user's saved preference is used only when that provider is available,
    # and the preferred model only when the provider actually matches.
    requested_provider = (request.data.get("provider") or "").strip().lower()
    requested_model = (request.data.get("model") or "").strip() or None
    ai_settings = UserAISettings.objects.filter(user=request.user).first()
    preferred_provider = (
        (ai_settings.preferred_provider or "").strip().lower()
        if ai_settings
        else ""
    )
    preferred_model = (
        (ai_settings.preferred_model or "").strip() if ai_settings else ""
    )
    provider = requested_provider
    if not provider and preferred_provider:
        if preferred_provider and is_chat_provider_available(preferred_provider):
            provider = preferred_provider
    if not provider:
        provider = "openai"
    model = requested_model
    if model is None and preferred_model and provider == preferred_provider:
        model = preferred_model
    collection_id = request.data.get("collection_id")
    collection_name = request.data.get("collection_name")
    start_date = request.data.get("start_date")
    end_date = request.data.get("end_date")
    destination = request.data.get("destination")
    if not is_chat_provider_available(provider):
        return Response(
            {"error": f"Provider is not available for chat: {provider}."},
            status=status.HTTP_400_BAD_REQUEST,
        )
    # Build the "Trip Context" section appended to the system prompt.
    context_parts = []
    itinerary_stops = []
    if collection_name:
        context_parts.append(f"Trip: {collection_name}")
    if start_date and end_date:
        context_parts.append(f"Dates: {start_date} to {end_date}")
    collection = None
    if collection_id:
        try:
            requested_collection = Collection.objects.get(id=collection_id)
            # Only expose collections the user owns or is shared into.
            if (
                requested_collection.user == request.user
                or requested_collection.shared_with.filter(
                    id=request.user.id
                ).exists()
            ):
                collection = requested_collection
        except Collection.DoesNotExist:
            pass
    if collection:
        context_parts.append(
            "Collection UUID (use this exact collection_id for get_trip_details and add_to_itinerary): "
            f"{collection.id}"
        )
        # Collect up to 8 deduplicated stop labels from the collection's
        # locations, preferring city/country FKs over free-text names.
        seen_stops = set()
        for location in collection.locations.select_related(
            "city", "country"
        ).all():
            city_name = (getattr(location.city, "name", "") or "").strip()
            country_name = (getattr(location.country, "name", "") or "").strip()
            if city_name or country_name:
                stop_label = (
                    f"{city_name}, {country_name}"
                    if city_name and country_name
                    else city_name or country_name
                )
                stop_key = f"geo:{city_name.lower()}|{country_name.lower()}"
            else:
                fallback_name = (location.location or location.name or "").strip()
                if not fallback_name:
                    continue
                # When city/country FKs are not set, try to extract a geocodable
                # city name from a comma-separated address string.
                # e.g. "Little Turnstile 6, London" → "London"
                # e.g. "Kingsway 58, London" → "London"
                if "," in fallback_name:
                    parts = [p.strip() for p in fallback_name.split(",")]
                    # Last non-empty, non-purely-numeric segment is typically the city
                    city_hint = next(
                        (p for p in reversed(parts) if p and not p.isdigit()),
                        None,
                    )
                    stop_label = city_hint if city_hint else fallback_name
                else:
                    stop_label = fallback_name
                stop_key = f"name:{stop_label.lower()}"
            if stop_key in seen_stops:
                continue
            seen_stops.add(stop_key)
            itinerary_stops.append(stop_label)
            if len(itinerary_stops) >= 8:
                break
        if itinerary_stops:
            context_parts.append(f"Itinerary stops: {'; '.join(itinerary_stops)}")
    trip_context_location = self._trip_context_search_location(
        destination, itinerary_stops
    )
    if trip_context_location:
        context_parts.append(f"Destination: {trip_context_location}")
    # Recent user messages feed the search_places category inference.
    prior_user_messages = list(
        conversation.messages.filter(role="user")
        .order_by("-created_at")
        .values_list("content", flat=True)[:3]
    )
    system_prompt = get_system_prompt(request.user, collection)
    if context_parts:
        system_prompt += "\n\n## Trip Context\n" + "\n".join(context_parts)
    ChatMessage.objects.create(
        conversation=conversation,
        role="user",
        content=user_content,
    )
    conversation.save(update_fields=["updated_at"])
    if not conversation.title:
        # First message seeds the conversation title.
        conversation.title = user_content[:120]
        conversation.save(update_fields=["title", "updated_at"])
    llm_messages = self._build_llm_messages(
        conversation,
        request.user,
        system_prompt=system_prompt,
    )
    # Hard caps: total tool rounds, and consecutive all-failed rounds.
    MAX_TOOL_ITERATIONS = 10
    MAX_ALL_FAILURE_ROUNDS = 3

    async def event_stream():
        # One LLM round per loop iteration; tool results are appended to
        # current_messages and the model is re-invoked until it answers
        # without tool calls (or a cap/error path terminates the stream).
        current_messages = list(llm_messages)
        encountered_error = False
        tool_iterations = 0
        all_failure_rounds = 0
        while tool_iterations < MAX_TOOL_ITERATIONS:
            content_chunks = []
            tool_calls_accumulator = []
            async for chunk in stream_chat_completion(
                request.user,
                current_messages,
                provider,
                tools=AGENT_TOOLS,
                model=model,
            ):
                # Forward every chunk to the client while also parsing
                # "data: " frames to accumulate content and tool calls.
                if not chunk.startswith("data: "):
                    yield chunk
                    continue
                payload = chunk[len("data: ") :].strip()
                if payload == "[DONE]":
                    continue
                yield chunk
                try:
                    data = json.loads(payload)
                except json.JSONDecodeError:
                    continue
                if data.get("error"):
                    encountered_error = True
                    break
                if data.get("content"):
                    content_chunks.append(data["content"])
                if data.get("tool_calls"):
                    self._merge_tool_call_delta(
                        tool_calls_accumulator,
                        data["tool_calls"],
                    )
            if encountered_error:
                yield "data: [DONE]\n\n"
                break
            assistant_content = "".join(content_chunks)
            if tool_calls_accumulator:
                successful_tool_calls = []
                successful_tool_messages = []
                successful_tool_chat_entries = []
                first_execution_failure = None
                encountered_permanent_failure = False
                for tool_call in tool_calls_accumulator:
                    function_payload = tool_call.get("function") or {}
                    function_name = function_payload.get("name") or ""
                    raw_arguments = function_payload.get("arguments") or "{}"
                    try:
                        arguments = json.loads(raw_arguments)
                    except json.JSONDecodeError:
                        arguments = {}
                    if not isinstance(arguments, dict):
                        arguments = {}
                    prepared_arguments = dict(arguments)
                    tool_call_for_history = tool_call
                    # Fill in a missing search category from recent user intent.
                    if function_name == "search_places":
                        if not (prepared_arguments.get("category") or "").strip():
                            inferred_category = self._infer_search_places_category(
                                user_content,
                                prior_user_messages,
                            )
                            if inferred_category:
                                prepared_arguments["category"] = inferred_category
                    if prepared_arguments != arguments:
                        # History must record the arguments actually executed.
                        tool_call_for_history = {
                            **tool_call,
                            "function": {
                                **function_payload,
                                "name": function_name,
                                "arguments": json.dumps(prepared_arguments),
                            },
                        }
                    result = await sync_to_async(
                        execute_tool, thread_sensitive=True
                    )(
                        function_name,
                        request.user,
                        **prepared_arguments,
                    )
                    # Retry search_places location failures with trip-context
                    # locations (destination, or a bare location-like reply).
                    attempted_location_retry = False
                    if self._is_search_places_location_retry_candidate_error(
                        function_name,
                        result,
                    ):
                        retry_locations = []
                        if trip_context_location:
                            retry_locations.append(trip_context_location)
                        if self._is_likely_location_reply(user_content):
                            retry_locations.append(user_content)
                        seen_retry_locations = set()
                        for retry_location in retry_locations:
                            normalized_retry_location = (
                                retry_location.strip().lower()
                            )
                            if (
                                not normalized_retry_location
                                or normalized_retry_location in seen_retry_locations
                            ):
                                continue
                            seen_retry_locations.add(normalized_retry_location)
                            attempted_location_retry = True
                            retry_arguments = dict(prepared_arguments)
                            retry_arguments["location"] = retry_location
                            retry_result = await sync_to_async(
                                execute_tool,
                                thread_sensitive=True,
                            )(
                                function_name,
                                request.user,
                                **retry_arguments,
                            )
                            if not self._is_required_param_tool_error(
                                retry_result
                            ) and not self._is_execution_failure_tool_error(
                                retry_result
                            ):
                                result = retry_result
                                tool_call_for_history = {
                                    **tool_call,
                                    "function": {
                                        **function_payload,
                                        "name": function_name,
                                        "arguments": json.dumps(retry_arguments),
                                    },
                                }
                                break
                        # If we attempted context retries but all failed, convert
                        # to an execution failure so we never ask the user for a
                        # location they already implied via itinerary context.
                        if (
                            attempted_location_retry
                            and self._is_required_param_tool_error(result)
                        ):
                            result = {
                                "error": "Could not search places at the provided itinerary locations"
                            }
                    # Retry get_weather with coordinates/dates derived from
                    # the collection when lat/long were omitted.
                    attempted_weather_coord_retry = False
                    if self._is_get_weather_missing_latlong_error(
                        function_name, result
                    ):
                        coords = await sync_to_async(
                            self._extract_collection_coordinates,
                            thread_sensitive=True,
                        )(collection)
                        if coords is not None:
                            retry_lat, retry_lon = coords
                            retry_arguments = dict(prepared_arguments)
                            retry_arguments["latitude"] = retry_lat
                            retry_arguments["longitude"] = retry_lon
                            if not retry_arguments.get("dates"):
                                retry_arguments["dates"] = (
                                    self._derive_weather_dates_from_collection(
                                        collection
                                    )
                                )
                            attempted_weather_coord_retry = True
                            retry_result = await sync_to_async(
                                execute_tool,
                                thread_sensitive=True,
                            )(
                                function_name,
                                request.user,
                                **retry_arguments,
                            )
                            if not self._is_required_param_tool_error(
                                retry_result
                            ) and not self._is_execution_failure_tool_error(
                                retry_result
                            ):
                                result = retry_result
                                tool_call_for_history = {
                                    **tool_call,
                                    "function": {
                                        **function_payload,
                                        "name": function_name,
                                        "arguments": json.dumps(retry_arguments),
                                    },
                                }
                        # If retry was attempted but still failed, convert to an
                        # execution failure — never ask the user for coordinates
                        # they implied via collection context.
                        if (
                            attempted_weather_coord_retry
                            and self._is_required_param_tool_error(result)
                            and self._is_get_weather_missing_latlong_error(
                                function_name,
                                result,
                            )
                        ):
                            result = {
                                "error": "Could not fetch weather for the collection locations"
                            }
                    # Required-arg validation failure: persist what succeeded
                    # so far, then either ask for a location or emit an error
                    # event, and end the stream.
                    if self._is_required_param_tool_error(result):
                        assistant_message_kwargs = {
                            "conversation": conversation,
                            "role": "assistant",
                            "content": assistant_content,
                        }
                        if successful_tool_calls:
                            assistant_message_kwargs["tool_calls"] = (
                                successful_tool_calls
                            )
                        await sync_to_async(
                            ChatMessage.objects.create, thread_sensitive=True
                        )(**assistant_message_kwargs)
                        for tool_message in successful_tool_messages:
                            await sync_to_async(
                                ChatMessage.objects.create,
                                thread_sensitive=True,
                            )(**tool_message)
                        if (
                            not attempted_location_retry
                            and self._is_search_places_missing_location_required_error(
                                function_name,
                                result,
                            )
                        ):
                            clarification_content = self._build_search_places_location_clarification_message()
                            await sync_to_async(
                                ChatMessage.objects.create,
                                thread_sensitive=True,
                            )(
                                conversation=conversation,
                                role="assistant",
                                content=clarification_content,
                            )
                            await sync_to_async(
                                conversation.save,
                                thread_sensitive=True,
                            )(update_fields=["updated_at"])
                            yield (
                                "data: "
                                f"{json.dumps({'content': clarification_content})}"
                                "\n\n"
                            )
                            yield "data: [DONE]\n\n"
                            return
                        await sync_to_async(
                            conversation.save,
                            thread_sensitive=True,
                        )(update_fields=["updated_at"])
                        logger.info(
                            "Stopping chat tool loop due to required-arg tool validation error: %s (%s)",
                            function_name,
                            result.get("error"),
                        )
                        error_event = self._build_required_param_error_event(
                            function_name,
                            result,
                        )
                        yield f"data: {json.dumps(error_event)}\n\n"
                        yield "data: [DONE]\n\n"
                        return
                    # Execution failure: remember the first one (for error
                    # reporting) and skip this call; permanent failures stop
                    # the retry budget immediately.
                    if self._is_execution_failure_tool_error(result):
                        if first_execution_failure is None:
                            first_execution_failure = (function_name, result)
                        if not self._is_retryable_execution_failure(result):
                            encountered_permanent_failure = True
                        continue
                    # Success: record for persistence and stream the result.
                    result_content = serialize_tool_result(result)
                    successful_tool_calls.append(tool_call_for_history)
                    tool_message_payload = {
                        "conversation": conversation,
                        "role": "tool",
                        "content": result_content,
                        "tool_call_id": tool_call_for_history.get("id"),
                        "name": function_name,
                    }
                    successful_tool_messages.append(tool_message_payload)
                    successful_tool_chat_entries.append(
                        {
                            "role": "tool",
                            "tool_call_id": tool_call_for_history.get("id"),
                            "name": function_name,
                            "content": result_content,
                        }
                    )
                    tool_event = {
                        "tool_result": {
                            "tool_call_id": tool_call_for_history.get("id"),
                            "name": function_name,
                            "result": result,
                        }
                    }
                    yield f"data: {json.dumps(tool_event)}\n\n"
                # Every tool call in this round failed: count the failed
                # round (or exhaust the budget on a permanent failure) and
                # either give up with an error event or loop again.
                if not successful_tool_calls and first_execution_failure:
                    if encountered_permanent_failure:
                        all_failure_rounds = MAX_ALL_FAILURE_ROUNDS
                    else:
                        all_failure_rounds += 1
                    if all_failure_rounds >= MAX_ALL_FAILURE_ROUNDS:
                        failed_tool_name, failed_tool_result = (
                            first_execution_failure
                        )
                        error_event = self._build_tool_execution_error_event(
                            failed_tool_name,
                            failed_tool_result,
                        )
                        await sync_to_async(
                            ChatMessage.objects.create,
                            thread_sensitive=True,
                        )(
                            conversation=conversation,
                            role="assistant",
                            content=error_event["error"],
                        )
                        await sync_to_async(
                            conversation.save,
                            thread_sensitive=True,
                        )(update_fields=["updated_at"])
                        yield f"data: {json.dumps(error_event)}\n\n"
                        yield "data: [DONE]\n\n"
                        return
                    continue
                # At least one tool succeeded: persist the round and feed
                # the results back to the model for the next iteration.
                all_failure_rounds = 0
                tool_iterations += 1
                assistant_with_tools = {
                    "role": "assistant",
                    "content": assistant_content,
                    "tool_calls": successful_tool_calls,
                }
                current_messages.append(assistant_with_tools)
                current_messages.extend(successful_tool_chat_entries)
                await sync_to_async(
                    ChatMessage.objects.create, thread_sensitive=True
                )(
                    conversation=conversation,
                    role="assistant",
                    content=assistant_content,
                    tool_calls=successful_tool_calls,
                )
                for tool_message in successful_tool_messages:
                    await sync_to_async(
                        ChatMessage.objects.create,
                        thread_sensitive=True,
                    )(**tool_message)
                await sync_to_async(conversation.save, thread_sensitive=True)(
                    update_fields=["updated_at"]
                )
                continue
            # No tool calls: final assistant answer — persist and finish.
            await sync_to_async(ChatMessage.objects.create, thread_sensitive=True)(
                conversation=conversation,
                role="assistant",
                content=assistant_content,
            )
            await sync_to_async(conversation.save, thread_sensitive=True)(
                update_fields=["updated_at"]
            )
            yield "data: [DONE]\n\n"
            break
        if tool_iterations >= MAX_TOOL_ITERATIONS:
            logger.warning(
                "Stopping chat tool loop after max iterations (%s)",
                MAX_TOOL_ITERATIONS,
            )
            payload = {
                "error": "The assistant stopped after too many tool calls. Please try again with a more specific request.",
                "error_category": "tool_loop_limit",
            }
            yield f"data: {json.dumps(payload)}\n\n"
            yield "data: [DONE]\n\n"

    response = StreamingHttpResponse(
        streaming_content=self._async_to_sync_generator(event_stream()),
        content_type="text/event-stream",
    )
    response["Cache-Control"] = "no-cache"
    # Disable nginx proxy buffering so SSE chunks flush immediately.
    response["X-Accel-Buffering"] = "no"
    return response
class ChatProviderCatalogViewSet(viewsets.ViewSet):
    """Read-only catalog of chat providers and their available models."""

    permission_classes = [IsAuthenticated]

    def list(self, request):
        # Catalog of providers as seen by this user (keys, availability).
        return Response(get_provider_catalog(user=request.user))

    @action(detail=True, methods=["get"])
    def models(self, request, pk=None):
        """Fetch available models from a provider's API."""
        from chat.llm_client import CHAT_PROVIDER_CONFIG, get_llm_api_key

        provider = (pk or "").lower()
        api_key = get_llm_api_key(request.user, provider)
        # NOTE(review): this up-front key check also gates "ollama", which
        # typically runs locally without an API key — confirm
        # get_llm_api_key returns something truthy for local providers.
        if not api_key:
            return Response(
                {"error": "No API key configured for this provider"},
                status=status.HTTP_403_FORBIDDEN,
            )
        try:
            if provider == "openai":
                # Live listing, filtered to chat-capable model id prefixes.
                import openai

                client = openai.OpenAI(api_key=api_key)
                models = client.models.list()
                chat_models = [
                    model.id
                    for model in models
                    if any(prefix in model.id for prefix in ["gpt-", "o1-", "chatgpt"])
                ]
                return Response({"models": sorted(set(chat_models), reverse=True)})
            if provider in ["anthropic", "claude"]:
                # Static list; Anthropic model ids are hard-coded here.
                return Response(
                    {
                        "models": [
                            "claude-sonnet-4-20250514",
                            "claude-opus-4-20250514",
                            "claude-3-5-sonnet-20241022",
                            "claude-3-5-haiku-20241022",
                            "claude-3-haiku-20240307",
                        ]
                    }
                )
            if provider in ["gemini", "google"]:
                # Static list of Gemini model ids.
                return Response(
                    {
                        "models": [
                            "gemini-2.0-flash",
                            "gemini-1.5-pro",
                            "gemini-1.5-flash",
                            "gemini-1.5-flash-8b",
                        ]
                    }
                )
            if provider in ["groq"]:
                # Static list of Groq-hosted model ids.
                return Response(
                    {
                        "models": [
                            "llama-3.3-70b-versatile",
                            "llama-3.1-70b-versatile",
                            "llama-3.1-8b-instant",
                            "mixtral-8x7b-32768",
                        ]
                    }
                )
            if provider in ["ollama"]:
                # Best-effort query of a local Ollama daemon; any failure
                # (daemon down, timeout) degrades to an empty list.
                import requests

                try:
                    response = requests.get(
                        "http://localhost:11434/api/tags", timeout=5
                    )
                    if response.ok:
                        data = response.json()
                        models = [item["name"] for item in data.get("models", [])]
                        return Response({"models": models})
                except Exception:
                    pass
                return Response({"models": []})
            if provider == "opencode_zen":
                import requests

                config = CHAT_PROVIDER_CONFIG.get("opencode_zen", {})
                api_base = config.get("api_base", "https://opencode.ai/zen/v1")
                response = requests.get(
                    f"{api_base}/models",
                    headers={"Authorization": f"Bearer {api_key}"},
                    timeout=10,
                )
                if response.ok:
                    data = response.json()
                    # Payload may be either {"data": [...]} or a bare list.
                    raw_models = (
                        data.get("data", data) if isinstance(data, dict) else data
                    )
                    model_ids = []
                    for model_entry in raw_models:
                        if not isinstance(model_entry, dict):
                            continue
                        model_id = model_entry.get("id") or model_entry.get("model_id")
                        if model_id:
                            model_ids.append(model_id)
                    return Response({"models": sorted(set(model_ids))})
                logger.warning(
                    "OpenCode Zen models fetch failed with status %s",
                    response.status_code,
                )
                return Response({"models": []})
            # Unknown provider: empty catalog rather than an error.
            return Response({"models": []})
        except Exception as exc:
            logger.error("Failed to fetch models for %s: %s", provider, exc)
            return Response(
                {"error": f"Failed to fetch models: {str(exc)}"},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )
from .capabilities import CapabilitiesView
from .day_suggestions import DaySuggestionsView