fix(chat): improve OpenCode Zen integration and error handling
- Fetch models dynamically from OpenCode Zen API (36 models vs 5 hardcoded) - Add function calling support check before using tools - Add retry logic (num_retries=2) for transient failures - Improve logging for debugging API calls and errors - Update system prompt for multi-stop itinerary context - Clean up unused imports in frontend components - Remove deleted views.py (moved to views/__init__.py)
This commit is contained in:
@@ -329,6 +329,11 @@ When modifying itineraries:
|
||||
- Suggest logical ordering based on geography
|
||||
- Consider travel time between locations
|
||||
|
||||
When chat context includes a trip collection:
|
||||
- Treat context as itinerary-wide (potentially multiple stops), not a single destination
|
||||
- Use get_trip_details first when you need complete collection context before searching for places
|
||||
- Ground place searches in trip stops and dates from the provided trip context
|
||||
|
||||
Be conversational, helpful, and enthusiastic about travel. Keep responses concise but informative."""
|
||||
|
||||
if collection and collection.shared_with.count() > 0:
|
||||
@@ -389,8 +394,8 @@ async def stream_chat_completion(user, messages, provider, tools=None, model=Non
|
||||
yield f"data: {json.dumps(payload)}\n\n"
|
||||
return
|
||||
|
||||
completion_kwargs = {
|
||||
"model": model
|
||||
resolved_model = (
|
||||
model
|
||||
or (
|
||||
settings.VOYAGE_AI_MODEL
|
||||
if normalized_provider
|
||||
@@ -398,10 +403,34 @@ async def stream_chat_completion(user, messages, provider, tools=None, model=Non
|
||||
and settings.VOYAGE_AI_MODEL
|
||||
else None
|
||||
)
|
||||
or provider_config["default_model"],
|
||||
or provider_config["default_model"]
|
||||
)
|
||||
|
||||
if tools and not litellm.supports_function_calling(model=resolved_model):
|
||||
logger.warning(
|
||||
"Model %s does not support function calling, disabling tools",
|
||||
resolved_model,
|
||||
)
|
||||
tools = None
|
||||
|
||||
logger.info(
|
||||
"Chat request: provider=%s, model=%s, has_tools=%s",
|
||||
normalized_provider,
|
||||
resolved_model,
|
||||
bool(tools),
|
||||
)
|
||||
logger.debug(
|
||||
"API base: %s, messages count: %s",
|
||||
provider_config.get("api_base"),
|
||||
len(messages),
|
||||
)
|
||||
|
||||
completion_kwargs = {
|
||||
"model": resolved_model,
|
||||
"messages": messages,
|
||||
"stream": True,
|
||||
"api_key": api_key,
|
||||
"num_retries": 2,
|
||||
}
|
||||
if tools:
|
||||
completion_kwargs["tools"] = tools
|
||||
@@ -448,6 +477,7 @@ async def stream_chat_completion(user, messages, provider, tools=None, model=Non
|
||||
|
||||
yield "data: [DONE]\n\n"
|
||||
except Exception as exc:
|
||||
logger.error("LiteLLM error: %s: %s", type(exc).__name__, str(exc)[:200])
|
||||
logger.exception("LLM streaming error")
|
||||
payload = _safe_error_payload(exc)
|
||||
yield f"data: {json.dumps(payload)}\n\n"
|
||||
|
||||
@@ -1,281 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
from django.http import StreamingHttpResponse
|
||||
from rest_framework import status, viewsets
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
|
||||
from .agent_tools import AGENT_TOOLS, execute_tool, serialize_tool_result
|
||||
from .llm_client import (
|
||||
get_provider_catalog,
|
||||
get_system_prompt,
|
||||
is_chat_provider_available,
|
||||
stream_chat_completion,
|
||||
)
|
||||
from .models import ChatConversation, ChatMessage
|
||||
from .serializers import ChatConversationSerializer
|
||||
|
||||
|
||||
class ChatViewSet(viewsets.ModelViewSet):
|
||||
serializer_class = ChatConversationSerializer
|
||||
permission_classes = [IsAuthenticated]
|
||||
|
||||
def get_queryset(self):
|
||||
return ChatConversation.objects.filter(user=self.request.user).prefetch_related(
|
||||
"messages"
|
||||
)
|
||||
|
||||
def list(self, request, *args, **kwargs):
|
||||
conversations = self.get_queryset().only("id", "title", "updated_at")
|
||||
data = [
|
||||
{
|
||||
"id": str(conversation.id),
|
||||
"title": conversation.title,
|
||||
"updated_at": conversation.updated_at,
|
||||
}
|
||||
for conversation in conversations
|
||||
]
|
||||
return Response(data)
|
||||
|
||||
def create(self, request, *args, **kwargs):
|
||||
conversation = ChatConversation.objects.create(
|
||||
user=request.user,
|
||||
title=(request.data.get("title") or "").strip(),
|
||||
)
|
||||
serializer = self.get_serializer(conversation)
|
||||
return Response(serializer.data, status=status.HTTP_201_CREATED)
|
||||
|
||||
def _build_llm_messages(self, conversation, user):
|
||||
messages = [{"role": "system", "content": get_system_prompt(user)}]
|
||||
for message in conversation.messages.all().order_by("created_at"):
|
||||
payload = {
|
||||
"role": message.role,
|
||||
"content": message.content,
|
||||
}
|
||||
if message.role == "assistant" and message.tool_calls:
|
||||
payload["tool_calls"] = message.tool_calls
|
||||
if message.role == "tool":
|
||||
payload["tool_call_id"] = message.tool_call_id
|
||||
payload["name"] = message.name
|
||||
messages.append(payload)
|
||||
return messages
|
||||
|
||||
def _async_to_sync_generator(self, async_gen):
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
yield loop.run_until_complete(async_gen.__anext__())
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
finally:
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
loop.close()
|
||||
|
||||
@staticmethod
|
||||
def _merge_tool_call_delta(accumulator, tool_calls_delta):
|
||||
for idx, tool_call in enumerate(tool_calls_delta or []):
|
||||
idx = tool_call.get("index", idx)
|
||||
while len(accumulator) <= idx:
|
||||
accumulator.append(
|
||||
{
|
||||
"id": None,
|
||||
"type": "function",
|
||||
"function": {"name": "", "arguments": ""},
|
||||
}
|
||||
)
|
||||
|
||||
current = accumulator[idx]
|
||||
if tool_call.get("id"):
|
||||
current["id"] = tool_call.get("id")
|
||||
if tool_call.get("type"):
|
||||
current["type"] = tool_call.get("type")
|
||||
|
||||
function_data = tool_call.get("function") or {}
|
||||
if function_data.get("name"):
|
||||
current["function"]["name"] = function_data.get("name")
|
||||
if function_data.get("arguments"):
|
||||
current["function"]["arguments"] += function_data.get("arguments")
|
||||
|
||||
@action(detail=True, methods=["post"])
|
||||
def send_message(self, request, pk=None):
|
||||
conversation = self.get_object()
|
||||
user_content = (request.data.get("message") or "").strip()
|
||||
if not user_content:
|
||||
return Response(
|
||||
{"error": "message is required"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
provider = (request.data.get("provider") or "openai").strip().lower()
|
||||
if not is_chat_provider_available(provider):
|
||||
return Response(
|
||||
{"error": f"Provider is not available for chat: {provider}."},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
ChatMessage.objects.create(
|
||||
conversation=conversation,
|
||||
role="user",
|
||||
content=user_content,
|
||||
)
|
||||
conversation.save(update_fields=["updated_at"])
|
||||
|
||||
if not conversation.title:
|
||||
conversation.title = user_content[:120]
|
||||
conversation.save(update_fields=["title", "updated_at"])
|
||||
|
||||
llm_messages = self._build_llm_messages(conversation, request.user)
|
||||
|
||||
MAX_TOOL_ITERATIONS = 10
|
||||
|
||||
async def event_stream():
|
||||
current_messages = list(llm_messages)
|
||||
encountered_error = False
|
||||
tool_iterations = 0
|
||||
|
||||
while tool_iterations < MAX_TOOL_ITERATIONS:
|
||||
content_chunks = []
|
||||
tool_calls_accumulator = []
|
||||
|
||||
async for chunk in stream_chat_completion(
|
||||
request.user,
|
||||
current_messages,
|
||||
provider,
|
||||
tools=AGENT_TOOLS,
|
||||
):
|
||||
if not chunk.startswith("data: "):
|
||||
yield chunk
|
||||
continue
|
||||
|
||||
payload = chunk[len("data: ") :].strip()
|
||||
if payload == "[DONE]":
|
||||
continue
|
||||
|
||||
yield chunk
|
||||
|
||||
try:
|
||||
data = json.loads(payload)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
if data.get("error"):
|
||||
encountered_error = True
|
||||
break
|
||||
|
||||
if data.get("content"):
|
||||
content_chunks.append(data["content"])
|
||||
|
||||
if data.get("tool_calls"):
|
||||
self._merge_tool_call_delta(
|
||||
tool_calls_accumulator,
|
||||
data["tool_calls"],
|
||||
)
|
||||
|
||||
if encountered_error:
|
||||
break
|
||||
|
||||
assistant_content = "".join(content_chunks)
|
||||
|
||||
if tool_calls_accumulator:
|
||||
assistant_with_tools = {
|
||||
"role": "assistant",
|
||||
"content": assistant_content,
|
||||
"tool_calls": tool_calls_accumulator,
|
||||
}
|
||||
current_messages.append(assistant_with_tools)
|
||||
|
||||
await sync_to_async(
|
||||
ChatMessage.objects.create, thread_sensitive=True
|
||||
)(
|
||||
conversation=conversation,
|
||||
role="assistant",
|
||||
content=assistant_content,
|
||||
tool_calls=tool_calls_accumulator,
|
||||
)
|
||||
await sync_to_async(conversation.save, thread_sensitive=True)(
|
||||
update_fields=["updated_at"]
|
||||
)
|
||||
|
||||
for tool_call in tool_calls_accumulator:
|
||||
function_payload = tool_call.get("function") or {}
|
||||
function_name = function_payload.get("name") or ""
|
||||
raw_arguments = function_payload.get("arguments") or "{}"
|
||||
|
||||
try:
|
||||
arguments = json.loads(raw_arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
if not isinstance(arguments, dict):
|
||||
arguments = {}
|
||||
|
||||
result = await sync_to_async(
|
||||
execute_tool, thread_sensitive=True
|
||||
)(
|
||||
function_name,
|
||||
request.user,
|
||||
**arguments,
|
||||
)
|
||||
result_content = serialize_tool_result(result)
|
||||
|
||||
current_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.get("id"),
|
||||
"name": function_name,
|
||||
"content": result_content,
|
||||
}
|
||||
)
|
||||
|
||||
await sync_to_async(
|
||||
ChatMessage.objects.create, thread_sensitive=True
|
||||
)(
|
||||
conversation=conversation,
|
||||
role="tool",
|
||||
content=result_content,
|
||||
tool_call_id=tool_call.get("id"),
|
||||
name=function_name,
|
||||
)
|
||||
await sync_to_async(conversation.save, thread_sensitive=True)(
|
||||
update_fields=["updated_at"]
|
||||
)
|
||||
|
||||
tool_event = {
|
||||
"tool_result": {
|
||||
"tool_call_id": tool_call.get("id"),
|
||||
"name": function_name,
|
||||
"result": result,
|
||||
}
|
||||
}
|
||||
yield f"data: {json.dumps(tool_event)}\n\n"
|
||||
|
||||
continue
|
||||
|
||||
await sync_to_async(ChatMessage.objects.create, thread_sensitive=True)(
|
||||
conversation=conversation,
|
||||
role="assistant",
|
||||
content=assistant_content,
|
||||
)
|
||||
await sync_to_async(conversation.save, thread_sensitive=True)(
|
||||
update_fields=["updated_at"]
|
||||
)
|
||||
yield "data: [DONE]\n\n"
|
||||
break
|
||||
|
||||
response = StreamingHttpResponse(
|
||||
streaming_content=self._async_to_sync_generator(event_stream()),
|
||||
content_type="text/event-stream",
|
||||
)
|
||||
response["Cache-Control"] = "no-cache"
|
||||
response["X-Accel-Buffering"] = "no"
|
||||
return response
|
||||
|
||||
|
||||
class ChatProviderCatalogViewSet(viewsets.ViewSet):
|
||||
permission_classes = [IsAuthenticated]
|
||||
|
||||
def list(self, request):
|
||||
return Response(get_provider_catalog())
|
||||
@@ -163,6 +163,41 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
except Collection.DoesNotExist:
|
||||
pass
|
||||
|
||||
if collection:
|
||||
itinerary_stops = []
|
||||
seen_stops = set()
|
||||
for location in collection.locations.select_related(
|
||||
"city", "country"
|
||||
).all():
|
||||
city_name = (getattr(location.city, "name", "") or "").strip()
|
||||
country_name = (getattr(location.country, "name", "") or "").strip()
|
||||
|
||||
if city_name or country_name:
|
||||
stop_label = (
|
||||
f"{city_name}, {country_name}"
|
||||
if city_name and country_name
|
||||
else city_name or country_name
|
||||
)
|
||||
stop_key = f"geo:{city_name.lower()}|{country_name.lower()}"
|
||||
else:
|
||||
fallback_name = (location.location or location.name or "").strip()
|
||||
if not fallback_name:
|
||||
continue
|
||||
stop_label = fallback_name
|
||||
stop_key = f"name:{fallback_name.lower()}"
|
||||
|
||||
if stop_key in seen_stops:
|
||||
continue
|
||||
|
||||
seen_stops.add(stop_key)
|
||||
itinerary_stops.append(stop_label)
|
||||
|
||||
if len(itinerary_stops) >= 8:
|
||||
break
|
||||
|
||||
if itinerary_stops:
|
||||
context_parts.append(f"Itinerary stops: {'; '.join(itinerary_stops)}")
|
||||
|
||||
system_prompt = get_system_prompt(request.user, collection)
|
||||
if context_parts:
|
||||
system_prompt += "\n\n## Trip Context\n" + "\n".join(context_parts)
|
||||
@@ -338,7 +373,7 @@ class ChatProviderCatalogViewSet(viewsets.ViewSet):
|
||||
@action(detail=True, methods=["get"])
|
||||
def models(self, request, pk=None):
|
||||
"""Fetch available models from a provider's API."""
|
||||
from chat.llm_client import get_llm_api_key
|
||||
from chat.llm_client import CHAT_PROVIDER_CONFIG, get_llm_api_key
|
||||
|
||||
provider = (pk or "").lower()
|
||||
|
||||
@@ -414,8 +449,38 @@ class ChatProviderCatalogViewSet(viewsets.ViewSet):
|
||||
pass
|
||||
return Response({"models": []})
|
||||
|
||||
if provider in ["opencode_zen"]:
|
||||
return Response({"models": ["openai/gpt-5-nano"]})
|
||||
if provider == "opencode_zen":
|
||||
import requests
|
||||
|
||||
config = CHAT_PROVIDER_CONFIG.get("opencode_zen", {})
|
||||
api_base = config.get("api_base", "https://opencode.ai/zen/v1")
|
||||
response = requests.get(
|
||||
f"{api_base}/models",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if response.ok:
|
||||
data = response.json()
|
||||
raw_models = (
|
||||
data.get("data", data) if isinstance(data, dict) else data
|
||||
)
|
||||
model_ids = []
|
||||
for model_entry in raw_models:
|
||||
if not isinstance(model_entry, dict):
|
||||
continue
|
||||
|
||||
model_id = model_entry.get("id") or model_entry.get("model_id")
|
||||
if model_id:
|
||||
model_ids.append(model_id)
|
||||
|
||||
return Response({"models": sorted(set(model_ids))})
|
||||
|
||||
logger.warning(
|
||||
"OpenCode Zen models fetch failed with status %s",
|
||||
response.status_code,
|
||||
)
|
||||
return Response({"models": []})
|
||||
|
||||
return Response({"models": []})
|
||||
except Exception as exc:
|
||||
|
||||
Reference in New Issue
Block a user