fix(chat): add saved AI defaults and harden suggestions
This commit is contained in:
@@ -98,7 +98,11 @@ def _parse_address(tags):
|
||||
|
||||
@agent_tool(
|
||||
name="search_places",
|
||||
description="Search for places of interest near a location. Returns tourist attractions, restaurants, hotels, etc.",
|
||||
description=(
|
||||
"Search for places of interest near a location. "
|
||||
"Required: provide a non-empty 'location' string (city, neighborhood, or address). "
|
||||
"Returns tourist attractions, restaurants, hotels, etc."
|
||||
),
|
||||
parameters={
|
||||
"location": {
|
||||
"type": "string",
|
||||
@@ -231,7 +235,11 @@ def list_trips(user):
|
||||
|
||||
@agent_tool(
|
||||
name="web_search",
|
||||
description="Search the web for current information about destinations, events, prices, weather, or any real-time travel information. Use this when you need up-to-date information that may not be in your training data.",
|
||||
description=(
|
||||
"Search the web for current travel information. "
|
||||
"Required: provide a non-empty 'query' string describing exactly what to look up. "
|
||||
"Use when you need up-to-date info that may not be in training data."
|
||||
),
|
||||
parameters={
|
||||
"query": {
|
||||
"type": "string",
|
||||
|
||||
@@ -165,6 +165,18 @@ def _normalize_provider_id(provider_id):
|
||||
return lowered
|
||||
|
||||
|
||||
def normalize_gateway_model(provider_id, model):
    """Return the model identifier to use for *provider_id*, or ``None``.

    The model is coerced to ``str`` and stripped of surrounding whitespace;
    a blank (or falsy) model yields ``None``. For the ``opencode_zen``
    provider, a model name that contains no ``/`` is prefixed with
    ``openai/``. Any other model name is returned as stripped.
    """
    provider_key = _normalize_provider_id(provider_id)
    model_name = str(model or "").strip()

    if model_name == "":
        return None

    # Only bare model names routed through opencode_zen get the prefix;
    # names that already carry a "<namespace>/" part are left untouched.
    needs_prefix = provider_key == "opencode_zen" and "/" not in model_name
    return f"openai/{model_name}" if needs_prefix else model_name
|
||||
|
||||
|
||||
def _default_provider_label(provider_id):
|
||||
return provider_id.replace("_", " ").title()
|
||||
|
||||
@@ -405,6 +417,7 @@ async def stream_chat_completion(user, messages, provider, tools=None, model=Non
|
||||
)
|
||||
or provider_config["default_model"]
|
||||
)
|
||||
resolved_model = normalize_gateway_model(normalized_provider, resolved_model)
|
||||
|
||||
if tools and not litellm.supports_function_calling(model=resolved_model):
|
||||
logger.warning(
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
from adventures.models import Collection
|
||||
from django.http import StreamingHttpResponse
|
||||
from integrations.models import UserAISettings
|
||||
from rest_framework import status, viewsets
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
@@ -53,19 +55,40 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
return Response(serializer.data, status=status.HTTP_201_CREATED)
|
||||
|
||||
def _build_llm_messages(self, conversation, user, system_prompt=None):
|
||||
ordered_messages = list(conversation.messages.all().order_by("created_at"))
|
||||
valid_tool_call_ids = {
|
||||
message.tool_call_id
|
||||
for message in ordered_messages
|
||||
if message.role == "tool"
|
||||
and message.tool_call_id
|
||||
and not self._is_required_param_tool_error_message_content(message.content)
|
||||
}
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt or get_system_prompt(user),
|
||||
}
|
||||
]
|
||||
for message in conversation.messages.all().order_by("created_at"):
|
||||
for message in ordered_messages:
|
||||
if (
|
||||
message.role == "tool"
|
||||
and self._is_required_param_tool_error_message_content(message.content)
|
||||
):
|
||||
continue
|
||||
|
||||
payload = {
|
||||
"role": message.role,
|
||||
"content": message.content,
|
||||
}
|
||||
if message.role == "assistant" and message.tool_calls:
|
||||
payload["tool_calls"] = message.tool_calls
|
||||
filtered_tool_calls = [
|
||||
tool_call
|
||||
for tool_call in message.tool_calls
|
||||
if (tool_call or {}).get("id") in valid_tool_call_ids
|
||||
]
|
||||
if filtered_tool_calls:
|
||||
payload["tool_calls"] = filtered_tool_calls
|
||||
if message.role == "tool":
|
||||
payload["tool_call_id"] = message.tool_call_id
|
||||
payload["name"] = message.name
|
||||
@@ -109,6 +132,50 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
if function_data.get("arguments"):
|
||||
current["function"]["arguments"] += function_data.get("arguments")
|
||||
|
||||
@staticmethod
|
||||
def _is_required_param_tool_error(result):
|
||||
if not isinstance(result, dict):
|
||||
return False
|
||||
|
||||
error_text = result.get("error")
|
||||
if not isinstance(error_text, str):
|
||||
return False
|
||||
|
||||
normalized_error = error_text.strip().lower()
|
||||
if normalized_error in {"location is required", "query is required"}:
|
||||
return True
|
||||
|
||||
return bool(
|
||||
re.fullmatch(
|
||||
r"[a-z0-9_,\-\s]+ (is|are) required",
|
||||
normalized_error,
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _is_required_param_tool_error_message_content(cls, content):
|
||||
if not isinstance(content, str):
|
||||
return False
|
||||
|
||||
try:
|
||||
parsed = json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
return False
|
||||
|
||||
return cls._is_required_param_tool_error(parsed)
|
||||
|
||||
@staticmethod
|
||||
def _build_required_param_error_event(tool_name, result):
|
||||
tool_error = result.get("error") if isinstance(result, dict) else ""
|
||||
return {
|
||||
"error": (
|
||||
"The assistant attempted to call "
|
||||
f"'{tool_name}' without required arguments ({tool_error}). "
|
||||
"Please try your message again with more specific details."
|
||||
),
|
||||
"error_category": "tool_validation_error",
|
||||
}
|
||||
|
||||
@action(detail=True, methods=["post"])
|
||||
def send_message(self, request, pk=None):
|
||||
# Auto-learn preferences from user's travel history
|
||||
@@ -128,8 +195,30 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
provider = (request.data.get("provider") or "openai").strip().lower()
|
||||
model = (request.data.get("model") or "").strip() or None
|
||||
requested_provider = (request.data.get("provider") or "").strip().lower()
|
||||
requested_model = (request.data.get("model") or "").strip() or None
|
||||
ai_settings = UserAISettings.objects.filter(user=request.user).first()
|
||||
preferred_provider = (
|
||||
(ai_settings.preferred_provider or "").strip().lower()
|
||||
if ai_settings
|
||||
else ""
|
||||
)
|
||||
preferred_model = (
|
||||
(ai_settings.preferred_model or "").strip() if ai_settings else ""
|
||||
)
|
||||
|
||||
provider = requested_provider
|
||||
if not provider and preferred_provider:
|
||||
if preferred_provider and is_chat_provider_available(preferred_provider):
|
||||
provider = preferred_provider
|
||||
|
||||
if not provider:
|
||||
provider = "openai"
|
||||
|
||||
model = requested_model
|
||||
if model is None and preferred_model and provider == preferred_provider:
|
||||
model = preferred_model
|
||||
|
||||
collection_id = request.data.get("collection_id")
|
||||
collection_name = request.data.get("collection_name")
|
||||
start_date = request.data.get("start_date")
|
||||
@@ -266,29 +355,16 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
)
|
||||
|
||||
if encountered_error:
|
||||
yield "data: [DONE]\n\n"
|
||||
break
|
||||
|
||||
assistant_content = "".join(content_chunks)
|
||||
|
||||
if tool_calls_accumulator:
|
||||
assistant_with_tools = {
|
||||
"role": "assistant",
|
||||
"content": assistant_content,
|
||||
"tool_calls": tool_calls_accumulator,
|
||||
}
|
||||
current_messages.append(assistant_with_tools)
|
||||
|
||||
await sync_to_async(
|
||||
ChatMessage.objects.create, thread_sensitive=True
|
||||
)(
|
||||
conversation=conversation,
|
||||
role="assistant",
|
||||
content=assistant_content,
|
||||
tool_calls=tool_calls_accumulator,
|
||||
)
|
||||
await sync_to_async(conversation.save, thread_sensitive=True)(
|
||||
update_fields=["updated_at"]
|
||||
)
|
||||
tool_iterations += 1
|
||||
successful_tool_calls = []
|
||||
successful_tool_messages = []
|
||||
successful_tool_chat_entries = []
|
||||
|
||||
for tool_call in tool_calls_accumulator:
|
||||
function_payload = tool_call.get("function") or {}
|
||||
@@ -309,9 +385,58 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
request.user,
|
||||
**arguments,
|
||||
)
|
||||
|
||||
if self._is_required_param_tool_error(result):
|
||||
assistant_message_kwargs = {
|
||||
"conversation": conversation,
|
||||
"role": "assistant",
|
||||
"content": assistant_content,
|
||||
}
|
||||
if successful_tool_calls:
|
||||
assistant_message_kwargs["tool_calls"] = (
|
||||
successful_tool_calls
|
||||
)
|
||||
|
||||
await sync_to_async(
|
||||
ChatMessage.objects.create, thread_sensitive=True
|
||||
)(**assistant_message_kwargs)
|
||||
|
||||
for tool_message in successful_tool_messages:
|
||||
await sync_to_async(
|
||||
ChatMessage.objects.create,
|
||||
thread_sensitive=True,
|
||||
)(**tool_message)
|
||||
|
||||
await sync_to_async(
|
||||
conversation.save,
|
||||
thread_sensitive=True,
|
||||
)(update_fields=["updated_at"])
|
||||
|
||||
logger.info(
|
||||
"Stopping chat tool loop due to required-arg tool validation error: %s (%s)",
|
||||
function_name,
|
||||
result.get("error"),
|
||||
)
|
||||
error_event = self._build_required_param_error_event(
|
||||
function_name,
|
||||
result,
|
||||
)
|
||||
yield f"data: {json.dumps(error_event)}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
result_content = serialize_tool_result(result)
|
||||
|
||||
current_messages.append(
|
||||
successful_tool_calls.append(tool_call)
|
||||
tool_message_payload = {
|
||||
"conversation": conversation,
|
||||
"role": "tool",
|
||||
"content": result_content,
|
||||
"tool_call_id": tool_call.get("id"),
|
||||
"name": function_name,
|
||||
}
|
||||
successful_tool_messages.append(tool_message_payload)
|
||||
successful_tool_chat_entries.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.get("id"),
|
||||
@@ -320,19 +445,6 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
}
|
||||
)
|
||||
|
||||
await sync_to_async(
|
||||
ChatMessage.objects.create, thread_sensitive=True
|
||||
)(
|
||||
conversation=conversation,
|
||||
role="tool",
|
||||
content=result_content,
|
||||
tool_call_id=tool_call.get("id"),
|
||||
name=function_name,
|
||||
)
|
||||
await sync_to_async(conversation.save, thread_sensitive=True)(
|
||||
update_fields=["updated_at"]
|
||||
)
|
||||
|
||||
tool_event = {
|
||||
"tool_result": {
|
||||
"tool_call_id": tool_call.get("id"),
|
||||
@@ -342,6 +454,32 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
}
|
||||
yield f"data: {json.dumps(tool_event)}\n\n"
|
||||
|
||||
assistant_with_tools = {
|
||||
"role": "assistant",
|
||||
"content": assistant_content,
|
||||
"tool_calls": successful_tool_calls,
|
||||
}
|
||||
current_messages.append(assistant_with_tools)
|
||||
current_messages.extend(successful_tool_chat_entries)
|
||||
|
||||
await sync_to_async(
|
||||
ChatMessage.objects.create, thread_sensitive=True
|
||||
)(
|
||||
conversation=conversation,
|
||||
role="assistant",
|
||||
content=assistant_content,
|
||||
tool_calls=successful_tool_calls,
|
||||
)
|
||||
for tool_message in successful_tool_messages:
|
||||
await sync_to_async(
|
||||
ChatMessage.objects.create,
|
||||
thread_sensitive=True,
|
||||
)(**tool_message)
|
||||
|
||||
await sync_to_async(conversation.save, thread_sensitive=True)(
|
||||
update_fields=["updated_at"]
|
||||
)
|
||||
|
||||
continue
|
||||
|
||||
await sync_to_async(ChatMessage.objects.create, thread_sensitive=True)(
|
||||
@@ -355,6 +493,18 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
yield "data: [DONE]\n\n"
|
||||
break
|
||||
|
||||
if tool_iterations >= MAX_TOOL_ITERATIONS:
|
||||
logger.warning(
|
||||
"Stopping chat tool loop after max iterations (%s)",
|
||||
MAX_TOOL_ITERATIONS,
|
||||
)
|
||||
payload = {
|
||||
"error": "The assistant stopped after too many tool calls. Please try again with a more specific request.",
|
||||
"error_category": "tool_loop_limit",
|
||||
}
|
||||
yield f"data: {json.dumps(payload)}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
response = StreamingHttpResponse(
|
||||
streaming_content=self._async_to_sync_generator(event_stream()),
|
||||
content_type="text/event-stream",
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import logging
|
||||
import json
|
||||
import re
|
||||
|
||||
import litellm
|
||||
from django.conf import settings
|
||||
from django.shortcuts import get_object_or_404
|
||||
from rest_framework import status
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
@@ -11,10 +13,17 @@ from rest_framework.views import APIView
|
||||
from adventures.models import Collection
|
||||
from chat.agent_tools import search_places
|
||||
from chat.llm_client import (
|
||||
CHAT_PROVIDER_CONFIG,
|
||||
_safe_error_payload,
|
||||
get_llm_api_key,
|
||||
get_system_prompt,
|
||||
is_chat_provider_available,
|
||||
normalize_gateway_model,
|
||||
)
|
||||
from integrations.models import UserAISettings
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DaySuggestionsView(APIView):
|
||||
@@ -52,7 +61,7 @@ class DaySuggestionsView(APIView):
|
||||
|
||||
location = location_context or self._get_collection_location(collection)
|
||||
system_prompt = get_system_prompt(request.user, collection)
|
||||
provider = "openai"
|
||||
provider, model = self._resolve_provider_and_model(request)
|
||||
|
||||
if not is_chat_provider_available(provider):
|
||||
return Response(
|
||||
@@ -78,12 +87,22 @@ class DaySuggestionsView(APIView):
|
||||
user_prompt=prompt,
|
||||
user=request.user,
|
||||
provider=provider,
|
||||
model=model,
|
||||
)
|
||||
return Response({"suggestions": suggestions}, status=status.HTTP_200_OK)
|
||||
except Exception:
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to generate day suggestions")
|
||||
payload = _safe_error_payload(exc)
|
||||
status_code = {
|
||||
"model_not_found": status.HTTP_400_BAD_REQUEST,
|
||||
"authentication_failed": status.HTTP_401_UNAUTHORIZED,
|
||||
"rate_limited": status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
"invalid_request": status.HTTP_400_BAD_REQUEST,
|
||||
"provider_unreachable": status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
}.get(payload.get("error_category"), status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
return Response(
|
||||
{"error": "Failed to generate suggestions. Please try again."},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
payload,
|
||||
status=status_code,
|
||||
)
|
||||
|
||||
def _get_collection_location(self, collection):
|
||||
@@ -174,31 +193,98 @@ class DaySuggestionsView(APIView):
|
||||
category=tool_category_map.get(category, "tourism"),
|
||||
radius=8,
|
||||
)
|
||||
if not isinstance(result, dict):
|
||||
return ""
|
||||
if result.get("error"):
|
||||
return ""
|
||||
|
||||
raw_results = result.get("results")
|
||||
if not isinstance(raw_results, list):
|
||||
return ""
|
||||
|
||||
entries = []
|
||||
for place in result.get("results", [])[:5]:
|
||||
for place in raw_results[:5]:
|
||||
if not isinstance(place, dict):
|
||||
continue
|
||||
name = place.get("name")
|
||||
address = place.get("address") or ""
|
||||
if name:
|
||||
entries.append(f"{name} ({address})" if address else name)
|
||||
return "; ".join(entries)
|
||||
|
||||
def _get_suggestions_from_llm(self, system_prompt, user_prompt, user, provider):
|
||||
def _resolve_provider_and_model(self, request):
    """Choose the LLM provider and model for this request.

    Provider: the first non-empty of the request body's ``provider``, the
    user's saved ``preferred_provider`` and ``settings.VOYAGE_AI_PROVIDER``.
    If that candidate is missing or not available, the settings provider is
    used when it is available; failing that, ``"openai"`` is used when it is
    available.

    Model: the first non-empty of the request body's ``model``, the user's
    saved ``preferred_model`` (only when the resolved provider equals the
    saved preferred provider), ``settings.VOYAGE_AI_MODEL`` (only when the
    resolved provider equals the settings provider), and the provider's
    ``default_model`` entry in ``CHAT_PROVIDER_CONFIG``.

    Returns:
        A ``(provider, model)`` tuple; either element may be ``None`` when
        nothing usable was found.
    """

    def _clean(value, lower=False):
        # Strip (and optionally lower-case) a possibly-missing string,
        # collapsing blank values to None.
        text = (value or "").strip()
        if lower:
            text = text.lower()
        return text or None

    body = request.data
    request_provider = _clean(body.get("provider"), lower=True)
    request_model = _clean(body.get("model"))

    user_settings = UserAISettings.objects.filter(user=request.user).first()  # type: ignore[attr-defined]
    preferred_provider = None
    preferred_model = None
    if user_settings:
        preferred_provider = _clean(user_settings.preferred_provider, lower=True)
        preferred_model = _clean(user_settings.preferred_model)

    settings_provider = _clean(settings.VOYAGE_AI_PROVIDER, lower=True)

    # NOTE(review): when the requested provider is unavailable the fallback
    # goes straight to the settings provider and does not try the user's
    # saved preferred provider - confirm this is intended.
    provider = request_provider or preferred_provider or settings_provider
    if not (provider and is_chat_provider_available(provider)):
        if is_chat_provider_available(settings_provider):
            provider = settings_provider
        else:
            provider = None
    if not (provider and is_chat_provider_available(provider)):
        if is_chat_provider_available("openai"):
            provider = "openai"

    provider_config = CHAT_PROVIDER_CONFIG.get(provider or "", {})

    # The instance-wide configured model only applies to the instance-wide
    # configured provider.
    configured_model = None
    if provider == settings_provider and settings.VOYAGE_AI_MODEL:
        configured_model = (settings.VOYAGE_AI_MODEL or "").strip()
    default_model = configured_model or provider_config.get("default_model")

    # The saved model only applies when the saved provider was the one chosen.
    model_from_user_defaults = None
    if preferred_provider and provider == preferred_provider:
        model_from_user_defaults = preferred_model

    # NOTE(review): a model named in the request is kept even if the provider
    # fell back to a different one than requested - verify against callers.
    model = request_model or model_from_user_defaults or default_model
    return provider, model
|
||||
|
||||
def _get_suggestions_from_llm(
|
||||
self, system_prompt, user_prompt, user, provider, model
|
||||
):
|
||||
api_key = get_llm_api_key(user, provider)
|
||||
if not api_key:
|
||||
raise ValueError("No API key available")
|
||||
|
||||
response = litellm.completion(
|
||||
model="gpt-4o-mini",
|
||||
messages=[
|
||||
provider_config = CHAT_PROVIDER_CONFIG.get(provider, {})
|
||||
resolved_model = normalize_gateway_model(
|
||||
provider,
|
||||
model or provider_config.get("default_model"),
|
||||
)
|
||||
if not resolved_model:
|
||||
raise ValueError("No model configured for provider")
|
||||
|
||||
completion_kwargs = {
|
||||
"model": resolved_model,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
api_key=api_key,
|
||||
temperature=0.7,
|
||||
max_tokens=1000,
|
||||
"api_key": api_key,
|
||||
"max_tokens": 1000,
|
||||
}
|
||||
|
||||
if provider_config.get("api_base"):
|
||||
completion_kwargs["api_base"] = provider_config["api_base"]
|
||||
|
||||
response = litellm.completion(
|
||||
**completion_kwargs,
|
||||
)
|
||||
|
||||
content = (response.choices[0].message.content or "").strip()
|
||||
|
||||
Reference in New Issue
Block a user