Files
voyage/backend/server/chat/views/__init__.py
alex 21954df3ee fix(chat): improve OpenCode Zen integration and error handling
- Fetch models dynamically from OpenCode Zen API (36 models vs 5 hardcoded)
- Add function calling support check before using tools
- Add retry logic (num_retries=2) for transient failures
- Improve logging for debugging API calls and errors
- Update system prompt for multi-stop itinerary context
- Clean up unused imports in frontend components
- Remove deleted views.py (moved to views/__init__.py)
2026-03-09 16:11:14 +00:00

496 lines
18 KiB
Python

import asyncio
import json
import logging
from asgiref.sync import sync_to_async
from adventures.models import Collection
from django.http import StreamingHttpResponse
from rest_framework import status, viewsets
from rest_framework.decorators import action
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from ..agent_tools import AGENT_TOOLS, execute_tool, serialize_tool_result
from ..llm_client import (
get_provider_catalog,
get_system_prompt,
is_chat_provider_available,
stream_chat_completion,
)
from ..models import ChatConversation, ChatMessage
from ..serializers import ChatConversationSerializer
logger = logging.getLogger(__name__)
class ChatViewSet(viewsets.ModelViewSet):
serializer_class = ChatConversationSerializer
permission_classes = [IsAuthenticated]
def get_queryset(self):
return ChatConversation.objects.filter(user=self.request.user).prefetch_related(
"messages"
)
def list(self, request, *args, **kwargs):
conversations = self.get_queryset().only("id", "title", "updated_at")
data = [
{
"id": str(conversation.id),
"title": conversation.title,
"updated_at": conversation.updated_at,
}
for conversation in conversations
]
return Response(data)
def create(self, request, *args, **kwargs):
conversation = ChatConversation.objects.create(
user=request.user,
title=(request.data.get("title") or "").strip(),
)
serializer = self.get_serializer(conversation)
return Response(serializer.data, status=status.HTTP_201_CREATED)
def _build_llm_messages(self, conversation, user, system_prompt=None):
messages = [
{
"role": "system",
"content": system_prompt or get_system_prompt(user),
}
]
for message in conversation.messages.all().order_by("created_at"):
payload = {
"role": message.role,
"content": message.content,
}
if message.role == "assistant" and message.tool_calls:
payload["tool_calls"] = message.tool_calls
if message.role == "tool":
payload["tool_call_id"] = message.tool_call_id
payload["name"] = message.name
messages.append(payload)
return messages
def _async_to_sync_generator(self, async_gen):
loop = asyncio.new_event_loop()
try:
while True:
try:
yield loop.run_until_complete(async_gen.__anext__())
except StopAsyncIteration:
break
finally:
loop.run_until_complete(loop.shutdown_asyncgens())
loop.close()
@staticmethod
def _merge_tool_call_delta(accumulator, tool_calls_delta):
for idx, tool_call in enumerate(tool_calls_delta or []):
idx = tool_call.get("index", idx)
while len(accumulator) <= idx:
accumulator.append(
{
"id": None,
"type": "function",
"function": {"name": "", "arguments": ""},
}
)
current = accumulator[idx]
if tool_call.get("id"):
current["id"] = tool_call.get("id")
if tool_call.get("type"):
current["type"] = tool_call.get("type")
function_data = tool_call.get("function") or {}
if function_data.get("name"):
current["function"]["name"] = function_data.get("name")
if function_data.get("arguments"):
current["function"]["arguments"] += function_data.get("arguments")
@action(detail=True, methods=["post"])
def send_message(self, request, pk=None):
# Auto-learn preferences from user's travel history
from integrations.utils.auto_profile import update_auto_preference_profile
try:
update_auto_preference_profile(request.user)
except Exception as exc:
logger.warning("Auto-profile update failed: %s", exc)
# Continue anyway - not critical
conversation = self.get_object()
user_content = (request.data.get("message") or "").strip()
if not user_content:
return Response(
{"error": "message is required"},
status=status.HTTP_400_BAD_REQUEST,
)
provider = (request.data.get("provider") or "openai").strip().lower()
model = (request.data.get("model") or "").strip() or None
collection_id = request.data.get("collection_id")
collection_name = request.data.get("collection_name")
start_date = request.data.get("start_date")
end_date = request.data.get("end_date")
destination = request.data.get("destination")
if not is_chat_provider_available(provider):
return Response(
{"error": f"Provider is not available for chat: {provider}."},
status=status.HTTP_400_BAD_REQUEST,
)
context_parts = []
if collection_name:
context_parts.append(f"Trip: {collection_name}")
if destination:
context_parts.append(f"Destination: {destination}")
if start_date and end_date:
context_parts.append(f"Dates: {start_date} to {end_date}")
collection = None
if collection_id:
try:
requested_collection = Collection.objects.get(id=collection_id)
if (
requested_collection.user == request.user
or requested_collection.shared_with.filter(
id=request.user.id
).exists()
):
collection = requested_collection
except Collection.DoesNotExist:
pass
if collection:
itinerary_stops = []
seen_stops = set()
for location in collection.locations.select_related(
"city", "country"
).all():
city_name = (getattr(location.city, "name", "") or "").strip()
country_name = (getattr(location.country, "name", "") or "").strip()
if city_name or country_name:
stop_label = (
f"{city_name}, {country_name}"
if city_name and country_name
else city_name or country_name
)
stop_key = f"geo:{city_name.lower()}|{country_name.lower()}"
else:
fallback_name = (location.location or location.name or "").strip()
if not fallback_name:
continue
stop_label = fallback_name
stop_key = f"name:{fallback_name.lower()}"
if stop_key in seen_stops:
continue
seen_stops.add(stop_key)
itinerary_stops.append(stop_label)
if len(itinerary_stops) >= 8:
break
if itinerary_stops:
context_parts.append(f"Itinerary stops: {'; '.join(itinerary_stops)}")
system_prompt = get_system_prompt(request.user, collection)
if context_parts:
system_prompt += "\n\n## Trip Context\n" + "\n".join(context_parts)
ChatMessage.objects.create(
conversation=conversation,
role="user",
content=user_content,
)
conversation.save(update_fields=["updated_at"])
if not conversation.title:
conversation.title = user_content[:120]
conversation.save(update_fields=["title", "updated_at"])
llm_messages = self._build_llm_messages(
conversation,
request.user,
system_prompt=system_prompt,
)
MAX_TOOL_ITERATIONS = 10
async def event_stream():
current_messages = list(llm_messages)
encountered_error = False
tool_iterations = 0
while tool_iterations < MAX_TOOL_ITERATIONS:
content_chunks = []
tool_calls_accumulator = []
async for chunk in stream_chat_completion(
request.user,
current_messages,
provider,
tools=AGENT_TOOLS,
model=model,
):
if not chunk.startswith("data: "):
yield chunk
continue
payload = chunk[len("data: ") :].strip()
if payload == "[DONE]":
continue
yield chunk
try:
data = json.loads(payload)
except json.JSONDecodeError:
continue
if data.get("error"):
encountered_error = True
break
if data.get("content"):
content_chunks.append(data["content"])
if data.get("tool_calls"):
self._merge_tool_call_delta(
tool_calls_accumulator,
data["tool_calls"],
)
if encountered_error:
break
assistant_content = "".join(content_chunks)
if tool_calls_accumulator:
assistant_with_tools = {
"role": "assistant",
"content": assistant_content,
"tool_calls": tool_calls_accumulator,
}
current_messages.append(assistant_with_tools)
await sync_to_async(
ChatMessage.objects.create, thread_sensitive=True
)(
conversation=conversation,
role="assistant",
content=assistant_content,
tool_calls=tool_calls_accumulator,
)
await sync_to_async(conversation.save, thread_sensitive=True)(
update_fields=["updated_at"]
)
for tool_call in tool_calls_accumulator:
function_payload = tool_call.get("function") or {}
function_name = function_payload.get("name") or ""
raw_arguments = function_payload.get("arguments") or "{}"
try:
arguments = json.loads(raw_arguments)
except json.JSONDecodeError:
arguments = {}
if not isinstance(arguments, dict):
arguments = {}
result = await sync_to_async(
execute_tool, thread_sensitive=True
)(
function_name,
request.user,
**arguments,
)
result_content = serialize_tool_result(result)
current_messages.append(
{
"role": "tool",
"tool_call_id": tool_call.get("id"),
"name": function_name,
"content": result_content,
}
)
await sync_to_async(
ChatMessage.objects.create, thread_sensitive=True
)(
conversation=conversation,
role="tool",
content=result_content,
tool_call_id=tool_call.get("id"),
name=function_name,
)
await sync_to_async(conversation.save, thread_sensitive=True)(
update_fields=["updated_at"]
)
tool_event = {
"tool_result": {
"tool_call_id": tool_call.get("id"),
"name": function_name,
"result": result,
}
}
yield f"data: {json.dumps(tool_event)}\n\n"
continue
await sync_to_async(ChatMessage.objects.create, thread_sensitive=True)(
conversation=conversation,
role="assistant",
content=assistant_content,
)
await sync_to_async(conversation.save, thread_sensitive=True)(
update_fields=["updated_at"]
)
yield "data: [DONE]\n\n"
break
response = StreamingHttpResponse(
streaming_content=self._async_to_sync_generator(event_stream()),
content_type="text/event-stream",
)
response["Cache-Control"] = "no-cache"
response["X-Accel-Buffering"] = "no"
return response
class ChatProviderCatalogViewSet(viewsets.ViewSet):
permission_classes = [IsAuthenticated]
def list(self, request):
return Response(get_provider_catalog(user=request.user))
@action(detail=True, methods=["get"])
def models(self, request, pk=None):
"""Fetch available models from a provider's API."""
from chat.llm_client import CHAT_PROVIDER_CONFIG, get_llm_api_key
provider = (pk or "").lower()
api_key = get_llm_api_key(request.user, provider)
if not api_key:
return Response(
{"error": "No API key configured for this provider"},
status=status.HTTP_403_FORBIDDEN,
)
try:
if provider == "openai":
import openai
client = openai.OpenAI(api_key=api_key)
models = client.models.list()
chat_models = [
model.id
for model in models
if any(prefix in model.id for prefix in ["gpt-", "o1-", "chatgpt"])
]
return Response({"models": sorted(set(chat_models), reverse=True)})
if provider in ["anthropic", "claude"]:
return Response(
{
"models": [
"claude-sonnet-4-20250514",
"claude-opus-4-20250514",
"claude-3-5-sonnet-20241022",
"claude-3-5-haiku-20241022",
"claude-3-haiku-20240307",
]
}
)
if provider in ["gemini", "google"]:
return Response(
{
"models": [
"gemini-2.0-flash",
"gemini-1.5-pro",
"gemini-1.5-flash",
"gemini-1.5-flash-8b",
]
}
)
if provider in ["groq"]:
return Response(
{
"models": [
"llama-3.3-70b-versatile",
"llama-3.1-70b-versatile",
"llama-3.1-8b-instant",
"mixtral-8x7b-32768",
]
}
)
if provider in ["ollama"]:
import requests
try:
response = requests.get(
"http://localhost:11434/api/tags", timeout=5
)
if response.ok:
data = response.json()
models = [item["name"] for item in data.get("models", [])]
return Response({"models": models})
except Exception:
pass
return Response({"models": []})
if provider == "opencode_zen":
import requests
config = CHAT_PROVIDER_CONFIG.get("opencode_zen", {})
api_base = config.get("api_base", "https://opencode.ai/zen/v1")
response = requests.get(
f"{api_base}/models",
headers={"Authorization": f"Bearer {api_key}"},
timeout=10,
)
if response.ok:
data = response.json()
raw_models = (
data.get("data", data) if isinstance(data, dict) else data
)
model_ids = []
for model_entry in raw_models:
if not isinstance(model_entry, dict):
continue
model_id = model_entry.get("id") or model_entry.get("model_id")
if model_id:
model_ids.append(model_id)
return Response({"models": sorted(set(model_ids))})
logger.warning(
"OpenCode Zen models fetch failed with status %s",
response.status_code,
)
return Response({"models": []})
return Response({"models": []})
except Exception as exc:
logger.error("Failed to fetch models for %s: %s", provider, exc)
return Response(
{"error": f"Failed to fetch models: {str(exc)}"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
from .capabilities import CapabilitiesView
from .day_suggestions import DaySuggestionsView