fix(chat): add saved AI defaults and harden suggestions

This commit is contained in:
2026-03-09 20:32:13 +00:00
parent 21954df3ee
commit bb54503235
38 changed files with 3949 additions and 105 deletions

View File

@@ -1,10 +1,12 @@
import asyncio
import json
import logging
import re
from asgiref.sync import sync_to_async
from adventures.models import Collection
from django.http import StreamingHttpResponse
from integrations.models import UserAISettings
from rest_framework import status, viewsets
from rest_framework.decorators import action
from rest_framework.permissions import IsAuthenticated
@@ -53,19 +55,40 @@ class ChatViewSet(viewsets.ModelViewSet):
return Response(serializer.data, status=status.HTTP_201_CREATED)
def _build_llm_messages(self, conversation, user, system_prompt=None):
ordered_messages = list(conversation.messages.all().order_by("created_at"))
valid_tool_call_ids = {
message.tool_call_id
for message in ordered_messages
if message.role == "tool"
and message.tool_call_id
and not self._is_required_param_tool_error_message_content(message.content)
}
messages = [
{
"role": "system",
"content": system_prompt or get_system_prompt(user),
}
]
for message in conversation.messages.all().order_by("created_at"):
for message in ordered_messages:
if (
message.role == "tool"
and self._is_required_param_tool_error_message_content(message.content)
):
continue
payload = {
"role": message.role,
"content": message.content,
}
if message.role == "assistant" and message.tool_calls:
payload["tool_calls"] = message.tool_calls
filtered_tool_calls = [
tool_call
for tool_call in message.tool_calls
if (tool_call or {}).get("id") in valid_tool_call_ids
]
if filtered_tool_calls:
payload["tool_calls"] = filtered_tool_calls
if message.role == "tool":
payload["tool_call_id"] = message.tool_call_id
payload["name"] = message.name
@@ -109,6 +132,50 @@ class ChatViewSet(viewsets.ModelViewSet):
if function_data.get("arguments"):
current["function"]["arguments"] += function_data.get("arguments")
@staticmethod
def _is_required_param_tool_error(result):
if not isinstance(result, dict):
return False
error_text = result.get("error")
if not isinstance(error_text, str):
return False
normalized_error = error_text.strip().lower()
if normalized_error in {"location is required", "query is required"}:
return True
return bool(
re.fullmatch(
r"[a-z0-9_,\-\s]+ (is|are) required",
normalized_error,
)
)
@classmethod
def _is_required_param_tool_error_message_content(cls, content):
    """Return True if *content* is the JSON-serialized form of a
    required-parameter tool error (see ``_is_required_param_tool_error``).

    Non-string content and content that is not valid JSON never
    qualify.
    """
    if not isinstance(content, str):
        return False
    try:
        payload = json.loads(content)
    except json.JSONDecodeError:
        return False
    else:
        return cls._is_required_param_tool_error(payload)
@staticmethod
def _build_required_param_error_event(tool_name, result):
tool_error = result.get("error") if isinstance(result, dict) else ""
return {
"error": (
"The assistant attempted to call "
f"'{tool_name}' without required arguments ({tool_error}). "
"Please try your message again with more specific details."
),
"error_category": "tool_validation_error",
}
@action(detail=True, methods=["post"])
def send_message(self, request, pk=None):
# Auto-learn preferences from user's travel history
@@ -128,8 +195,30 @@ class ChatViewSet(viewsets.ModelViewSet):
status=status.HTTP_400_BAD_REQUEST,
)
provider = (request.data.get("provider") or "openai").strip().lower()
model = (request.data.get("model") or "").strip() or None
requested_provider = (request.data.get("provider") or "").strip().lower()
requested_model = (request.data.get("model") or "").strip() or None
ai_settings = UserAISettings.objects.filter(user=request.user).first()
preferred_provider = (
(ai_settings.preferred_provider or "").strip().lower()
if ai_settings
else ""
)
preferred_model = (
(ai_settings.preferred_model or "").strip() if ai_settings else ""
)
provider = requested_provider
if not provider and preferred_provider:
if preferred_provider and is_chat_provider_available(preferred_provider):
provider = preferred_provider
if not provider:
provider = "openai"
model = requested_model
if model is None and preferred_model and provider == preferred_provider:
model = preferred_model
collection_id = request.data.get("collection_id")
collection_name = request.data.get("collection_name")
start_date = request.data.get("start_date")
@@ -266,29 +355,16 @@ class ChatViewSet(viewsets.ModelViewSet):
)
if encountered_error:
yield "data: [DONE]\n\n"
break
assistant_content = "".join(content_chunks)
if tool_calls_accumulator:
assistant_with_tools = {
"role": "assistant",
"content": assistant_content,
"tool_calls": tool_calls_accumulator,
}
current_messages.append(assistant_with_tools)
await sync_to_async(
ChatMessage.objects.create, thread_sensitive=True
)(
conversation=conversation,
role="assistant",
content=assistant_content,
tool_calls=tool_calls_accumulator,
)
await sync_to_async(conversation.save, thread_sensitive=True)(
update_fields=["updated_at"]
)
tool_iterations += 1
successful_tool_calls = []
successful_tool_messages = []
successful_tool_chat_entries = []
for tool_call in tool_calls_accumulator:
function_payload = tool_call.get("function") or {}
@@ -309,9 +385,58 @@ class ChatViewSet(viewsets.ModelViewSet):
request.user,
**arguments,
)
if self._is_required_param_tool_error(result):
assistant_message_kwargs = {
"conversation": conversation,
"role": "assistant",
"content": assistant_content,
}
if successful_tool_calls:
assistant_message_kwargs["tool_calls"] = (
successful_tool_calls
)
await sync_to_async(
ChatMessage.objects.create, thread_sensitive=True
)(**assistant_message_kwargs)
for tool_message in successful_tool_messages:
await sync_to_async(
ChatMessage.objects.create,
thread_sensitive=True,
)(**tool_message)
await sync_to_async(
conversation.save,
thread_sensitive=True,
)(update_fields=["updated_at"])
logger.info(
"Stopping chat tool loop due to required-arg tool validation error: %s (%s)",
function_name,
result.get("error"),
)
error_event = self._build_required_param_error_event(
function_name,
result,
)
yield f"data: {json.dumps(error_event)}\n\n"
yield "data: [DONE]\n\n"
return
result_content = serialize_tool_result(result)
current_messages.append(
successful_tool_calls.append(tool_call)
tool_message_payload = {
"conversation": conversation,
"role": "tool",
"content": result_content,
"tool_call_id": tool_call.get("id"),
"name": function_name,
}
successful_tool_messages.append(tool_message_payload)
successful_tool_chat_entries.append(
{
"role": "tool",
"tool_call_id": tool_call.get("id"),
@@ -320,19 +445,6 @@ class ChatViewSet(viewsets.ModelViewSet):
}
)
await sync_to_async(
ChatMessage.objects.create, thread_sensitive=True
)(
conversation=conversation,
role="tool",
content=result_content,
tool_call_id=tool_call.get("id"),
name=function_name,
)
await sync_to_async(conversation.save, thread_sensitive=True)(
update_fields=["updated_at"]
)
tool_event = {
"tool_result": {
"tool_call_id": tool_call.get("id"),
@@ -342,6 +454,32 @@ class ChatViewSet(viewsets.ModelViewSet):
}
yield f"data: {json.dumps(tool_event)}\n\n"
assistant_with_tools = {
"role": "assistant",
"content": assistant_content,
"tool_calls": successful_tool_calls,
}
current_messages.append(assistant_with_tools)
current_messages.extend(successful_tool_chat_entries)
await sync_to_async(
ChatMessage.objects.create, thread_sensitive=True
)(
conversation=conversation,
role="assistant",
content=assistant_content,
tool_calls=successful_tool_calls,
)
for tool_message in successful_tool_messages:
await sync_to_async(
ChatMessage.objects.create,
thread_sensitive=True,
)(**tool_message)
await sync_to_async(conversation.save, thread_sensitive=True)(
update_fields=["updated_at"]
)
continue
await sync_to_async(ChatMessage.objects.create, thread_sensitive=True)(
@@ -355,6 +493,18 @@ class ChatViewSet(viewsets.ModelViewSet):
yield "data: [DONE]\n\n"
break
if tool_iterations >= MAX_TOOL_ITERATIONS:
logger.warning(
"Stopping chat tool loop after max iterations (%s)",
MAX_TOOL_ITERATIONS,
)
payload = {
"error": "The assistant stopped after too many tool calls. Please try again with a more specific request.",
"error_category": "tool_loop_limit",
}
yield f"data: {json.dumps(payload)}\n\n"
yield "data: [DONE]\n\n"
response = StreamingHttpResponse(
streaming_content=self._async_to_sync_generator(event_stream()),
content_type="text/event-stream",