import asyncio
|
|
import json
|
|
import logging
|
|
import re
|
|
from datetime import timedelta
|
|
|
|
from asgiref.sync import sync_to_async
|
|
from adventures.models import Collection
|
|
from django.http import StreamingHttpResponse
|
|
from django.utils import timezone
|
|
from integrations.models import UserAISettings
|
|
from rest_framework import status, viewsets
|
|
from rest_framework.decorators import action
|
|
from rest_framework.permissions import IsAuthenticated
|
|
from rest_framework.response import Response
|
|
|
|
from ..agent_tools import AGENT_TOOLS, execute_tool, serialize_tool_result
|
|
from ..llm_client import (
|
|
get_provider_catalog,
|
|
get_system_prompt,
|
|
is_chat_provider_available,
|
|
stream_chat_completion,
|
|
)
|
|
from ..models import ChatConversation, ChatMessage
|
|
from ..serializers import ChatConversationSerializer
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ChatViewSet(viewsets.ModelViewSet):
|
|
serializer_class = ChatConversationSerializer
|
|
permission_classes = [IsAuthenticated]
|
|
|
|
def get_queryset(self):
|
|
return ChatConversation.objects.filter(user=self.request.user).prefetch_related(
|
|
"messages"
|
|
)
|
|
|
|
def list(self, request, *args, **kwargs):
|
|
conversations = self.get_queryset().only("id", "title", "updated_at")
|
|
data = [
|
|
{
|
|
"id": str(conversation.id),
|
|
"title": conversation.title,
|
|
"updated_at": conversation.updated_at,
|
|
}
|
|
for conversation in conversations
|
|
]
|
|
return Response(data)
|
|
|
|
def create(self, request, *args, **kwargs):
|
|
conversation = ChatConversation.objects.create(
|
|
user=request.user,
|
|
title=(request.data.get("title") or "").strip(),
|
|
)
|
|
serializer = self.get_serializer(conversation)
|
|
return Response(serializer.data, status=status.HTTP_201_CREATED)
|
|
|
|
    def _build_llm_messages(self, conversation, user, system_prompt=None):
        """Rebuild the provider message list (system prompt + history) for a conversation.

        Tool messages that recorded validation ("<param> is required") or
        execution failures are dropped from the history, and each assistant
        message's ``tool_calls`` are filtered down to the calls whose tool
        result survived that drop — providers reject histories where a
        tool_call id has no matching tool reply.

        Args:
            conversation: ChatConversation whose messages are replayed.
            user: user passed to get_system_prompt when no explicit prompt given.
            system_prompt: optional pre-built system prompt (overrides default).

        Returns:
            A list of OpenAI-style message dicts, starting with the system message.
        """
        ordered_messages = list(conversation.messages.all().order_by("created_at"))
        # Tool-call ids whose stored tool result is a real (non-error) payload;
        # only these may be referenced from assistant tool_calls below.
        valid_tool_call_ids = {
            message.tool_call_id
            for message in ordered_messages
            if message.role == "tool"
            and message.tool_call_id
            and not self._is_required_param_tool_error_message_content(message.content)
            and not self._is_execution_failure_tool_error_message_content(
                message.content
            )
        }

        messages = [
            {
                "role": "system",
                "content": system_prompt or get_system_prompt(user),
            }
        ]
        for message in ordered_messages:
            # Skip stored tool results that recorded a missing-argument error.
            if (
                message.role == "tool"
                and self._is_required_param_tool_error_message_content(message.content)
            ):
                continue
            # Skip stored tool results that recorded a runtime failure.
            if (
                message.role == "tool"
                and self._is_execution_failure_tool_error_message_content(
                    message.content
                )
            ):
                continue

            payload = {
                "role": message.role,
                "content": message.content,
            }
            if message.role == "assistant" and message.tool_calls:
                # Keep only tool calls whose result survived the filtering above.
                filtered_tool_calls = [
                    tool_call
                    for tool_call in message.tool_calls
                    if (tool_call or {}).get("id") in valid_tool_call_ids
                ]
                if filtered_tool_calls:
                    payload["tool_calls"] = filtered_tool_calls
            if message.role == "tool":
                payload["tool_call_id"] = message.tool_call_id
                payload["name"] = message.name
            messages.append(payload)
        return messages
|
|
|
|
def _async_to_sync_generator(self, async_gen):
|
|
loop = asyncio.new_event_loop()
|
|
try:
|
|
while True:
|
|
try:
|
|
yield loop.run_until_complete(async_gen.__anext__())
|
|
except StopAsyncIteration:
|
|
break
|
|
finally:
|
|
loop.run_until_complete(loop.shutdown_asyncgens())
|
|
loop.close()
|
|
|
|
@staticmethod
|
|
def _merge_tool_call_delta(accumulator, tool_calls_delta):
|
|
for idx, tool_call in enumerate(tool_calls_delta or []):
|
|
idx = tool_call.get("index", idx)
|
|
while len(accumulator) <= idx:
|
|
accumulator.append(
|
|
{
|
|
"id": None,
|
|
"type": "function",
|
|
"function": {"name": "", "arguments": ""},
|
|
}
|
|
)
|
|
|
|
current = accumulator[idx]
|
|
if tool_call.get("id"):
|
|
current["id"] = tool_call.get("id")
|
|
if tool_call.get("type"):
|
|
current["type"] = tool_call.get("type")
|
|
|
|
function_data = tool_call.get("function") or {}
|
|
if function_data.get("name"):
|
|
current["function"]["name"] = function_data.get("name")
|
|
if function_data.get("arguments"):
|
|
current["function"]["arguments"] += function_data.get("arguments")
|
|
|
|
@staticmethod
|
|
def _is_required_param_tool_error(result):
|
|
if not isinstance(result, dict):
|
|
return False
|
|
|
|
error_text = result.get("error")
|
|
if not isinstance(error_text, str):
|
|
return False
|
|
|
|
normalized_error = error_text.strip().lower()
|
|
if normalized_error in {"location is required", "query is required"}:
|
|
return True
|
|
|
|
return bool(
|
|
re.fullmatch(
|
|
r"[a-z0-9_,\-\s]+ (is|are) required",
|
|
normalized_error,
|
|
)
|
|
)
|
|
|
|
@classmethod
|
|
def _is_execution_failure_tool_error(cls, result):
|
|
if not isinstance(result, dict):
|
|
return False
|
|
|
|
error_text = result.get("error")
|
|
if not isinstance(error_text, str) or not error_text.strip():
|
|
return False
|
|
|
|
return not cls._is_required_param_tool_error(result)
|
|
|
|
@staticmethod
|
|
def _is_retryable_execution_failure(result):
|
|
if not isinstance(result, dict):
|
|
return False
|
|
|
|
return result.get("retryable", True) is not False
|
|
|
|
@classmethod
|
|
def _is_required_param_tool_error_message_content(cls, content):
|
|
if not isinstance(content, str):
|
|
return False
|
|
|
|
try:
|
|
parsed = json.loads(content)
|
|
except json.JSONDecodeError:
|
|
return False
|
|
|
|
return cls._is_required_param_tool_error(parsed)
|
|
|
|
@classmethod
|
|
def _is_execution_failure_tool_error_message_content(cls, content):
|
|
if not isinstance(content, str):
|
|
return False
|
|
|
|
try:
|
|
parsed = json.loads(content)
|
|
except json.JSONDecodeError:
|
|
return False
|
|
|
|
return cls._is_execution_failure_tool_error(parsed)
|
|
|
|
@staticmethod
|
|
def _build_required_param_error_event(tool_name, result):
|
|
tool_error = result.get("error") if isinstance(result, dict) else ""
|
|
return {
|
|
"error": (
|
|
"The assistant attempted to call "
|
|
f"'{tool_name}' without required arguments ({tool_error}). "
|
|
"Please try your message again with more specific details."
|
|
),
|
|
"error_category": "tool_validation_error",
|
|
}
|
|
|
|
@classmethod
|
|
def _is_search_places_missing_location_required_error(cls, tool_name, result):
|
|
if tool_name != "search_places" or not cls._is_required_param_tool_error(
|
|
result
|
|
):
|
|
return False
|
|
|
|
error_text = (result or {}).get("error") if isinstance(result, dict) else ""
|
|
if not isinstance(error_text, str):
|
|
return False
|
|
|
|
normalized_error = error_text.strip().lower()
|
|
return "location" in normalized_error
|
|
|
|
@staticmethod
|
|
def _is_search_places_geocode_error(tool_name, result):
|
|
if tool_name != "search_places" or not isinstance(result, dict):
|
|
return False
|
|
|
|
error_text = result.get("error")
|
|
if not isinstance(error_text, str):
|
|
return False
|
|
|
|
return error_text.strip().lower().startswith("could not geocode location")
|
|
|
|
@classmethod
|
|
def _is_search_places_location_retry_candidate_error(cls, tool_name, result):
|
|
return cls._is_search_places_missing_location_required_error(
|
|
tool_name,
|
|
result,
|
|
) or cls._is_search_places_geocode_error(tool_name, result)
|
|
|
|
@classmethod
|
|
def _is_get_weather_missing_latlong_error(cls, tool_name, result):
|
|
"""True when get_weather was called without latitude/longitude."""
|
|
if tool_name != "get_weather" or not cls._is_required_param_tool_error(result):
|
|
return False
|
|
|
|
error_text = (result or {}).get("error") if isinstance(result, dict) else ""
|
|
if not isinstance(error_text, str):
|
|
return False
|
|
|
|
normalized_error = error_text.strip().lower()
|
|
return "latitude" in normalized_error or "longitude" in normalized_error
|
|
|
|
@staticmethod
|
|
def _extract_collection_coordinates(collection):
|
|
"""Return (lat, lon) from the first geocoded location in the collection, or None."""
|
|
if collection is None:
|
|
return None
|
|
for location in collection.locations.all():
|
|
lat = getattr(location, "latitude", None)
|
|
lon = getattr(location, "longitude", None)
|
|
if lat is not None and lon is not None:
|
|
try:
|
|
return float(lat), float(lon)
|
|
except (TypeError, ValueError):
|
|
continue
|
|
return None
|
|
|
|
@staticmethod
|
|
def _derive_weather_dates_from_collection(collection, max_days=7):
|
|
"""Derive a bounded weather date list from collection dates, or fallback to today."""
|
|
today = timezone.localdate()
|
|
if collection is None:
|
|
return [today.isoformat()]
|
|
|
|
start_date = getattr(collection, "start_date", None)
|
|
end_date = getattr(collection, "end_date", None)
|
|
|
|
if start_date and end_date:
|
|
range_start = min(start_date, end_date)
|
|
range_end = max(start_date, end_date)
|
|
day_count = min((range_end - range_start).days + 1, max_days)
|
|
return [
|
|
(range_start + timedelta(days=offset)).isoformat()
|
|
for offset in range(day_count)
|
|
]
|
|
|
|
if start_date:
|
|
return [start_date.isoformat()]
|
|
|
|
if end_date:
|
|
return [end_date.isoformat()]
|
|
|
|
return [today.isoformat()]
|
|
|
|
@staticmethod
|
|
def _build_search_places_location_clarification_message():
|
|
return (
|
|
"Could you share the specific location you'd like me to search near "
|
|
"(city, neighborhood, or address)? I can also focus on food, "
|
|
"activities, or lodging."
|
|
)
|
|
|
|
@staticmethod
|
|
def _build_tool_execution_error_event(tool_name, result):
|
|
tool_error = (
|
|
(result or {}).get("error")
|
|
if isinstance(result, dict)
|
|
else "Tool execution failed"
|
|
)
|
|
return {
|
|
"error": (
|
|
f"The assistant could not complete '{tool_name}' ({tool_error}). "
|
|
"Please try again in a moment or adjust your request."
|
|
),
|
|
"error_category": "tool_execution_error",
|
|
}
|
|
|
|
@staticmethod
|
|
def _normalize_trip_context_destination(destination):
|
|
destination_text = (destination or "").strip()
|
|
if not destination_text:
|
|
return ""
|
|
|
|
if ";" not in destination_text:
|
|
if re.fullmatch(r"\+\d+\s+more", destination_text, re.IGNORECASE):
|
|
return ""
|
|
return destination_text
|
|
|
|
for segment in destination_text.split(";"):
|
|
segment_text = segment.strip()
|
|
if not segment_text:
|
|
continue
|
|
|
|
if re.fullmatch(r"\+\d+\s+more", segment_text, re.IGNORECASE):
|
|
continue
|
|
|
|
return segment_text
|
|
|
|
return ""
|
|
|
|
@classmethod
|
|
def _trip_context_search_location(cls, destination, itinerary_stops):
|
|
destination_text = cls._normalize_trip_context_destination(destination)
|
|
if destination_text:
|
|
return destination_text
|
|
|
|
for stop in itinerary_stops or []:
|
|
stop_text = (stop or "").strip()
|
|
if stop_text:
|
|
return stop_text
|
|
|
|
return ""
|
|
|
|
@staticmethod
|
|
def _infer_search_places_category(user_content, prior_user_messages):
|
|
message_parts = [(user_content or "").strip()]
|
|
message_parts.extend(
|
|
(content or "").strip() for content in prior_user_messages or []
|
|
)
|
|
normalized = " ".join(part for part in message_parts if part).lower()
|
|
if not normalized:
|
|
return None
|
|
|
|
dining_intent_pattern = (
|
|
r"\b(restaurant|restaurants|dining|dinner|lunch|breakfast|brunch|"
|
|
r"cafe|cafes|food|eat|eating|cuisine|meal|meals|bistro|bar|bars)\b"
|
|
)
|
|
if re.search(dining_intent_pattern, normalized):
|
|
return "food"
|
|
|
|
return None
|
|
|
|
# Verbs that indicate a command/request rather than a location reply.
|
|
_COMMAND_VERBS = frozenset(
|
|
[
|
|
"find",
|
|
"search",
|
|
"show",
|
|
"get",
|
|
"look",
|
|
"give",
|
|
"tell",
|
|
"help",
|
|
"recommend",
|
|
"suggest",
|
|
"list",
|
|
"fetch",
|
|
"what",
|
|
"where",
|
|
"which",
|
|
"who",
|
|
"how",
|
|
"can",
|
|
"could",
|
|
"would",
|
|
"please",
|
|
]
|
|
)
|
|
|
|
@classmethod
|
|
def _is_likely_location_reply(cls, user_content):
|
|
if not isinstance(user_content, str):
|
|
return False
|
|
|
|
normalized = user_content.strip()
|
|
if not normalized:
|
|
return False
|
|
|
|
if normalized.endswith("?"):
|
|
return False
|
|
|
|
if len(normalized) > 80:
|
|
return False
|
|
|
|
parts = normalized.split()
|
|
if len(parts) > 6:
|
|
return False
|
|
|
|
# Exclude messages that start with a command/request verb — those are
|
|
# original requests, not replies to a "where?" clarification prompt.
|
|
first_word = parts[0].lower().rstrip(".,!;:")
|
|
if first_word in cls._COMMAND_VERBS:
|
|
return False
|
|
|
|
return bool(re.search(r"[a-z0-9]", normalized, re.IGNORECASE))
|
|
|
|
    @action(detail=True, methods=["post"])
    def send_message(self, request, pk=None):
        """Append a user message and stream the assistant reply as SSE.

        Flow: resolve provider/model (request overrides user preference,
        falls back to "openai"), build trip context from an optional
        collection, persist the user message, then stream the LLM reply.
        The embedded ``event_stream`` runs an agent loop: it executes tool
        calls, retries search_places/get_weather failures using trip context,
        and persists assistant/tool messages as it goes. The statement order
        inside the loop (DB writes vs. yields vs. early returns) is
        load-bearing — do not reorder.
        """
        # Auto-learn preferences from user's travel history
        from integrations.utils.auto_profile import update_auto_preference_profile

        try:
            update_auto_preference_profile(request.user)
        except Exception as exc:
            logger.warning("Auto-profile update failed: %s", exc)
            # Continue anyway - not critical

        conversation = self.get_object()
        user_content = (request.data.get("message") or "").strip()
        if not user_content:
            return Response(
                {"error": "message is required"},
                status=status.HTTP_400_BAD_REQUEST,
            )

        # Provider/model resolution: explicit request values win, then the
        # user's saved preference (only if that provider is available), then
        # the "openai" default. The preferred model is only honored when the
        # resolved provider matches the preferred provider.
        requested_provider = (request.data.get("provider") or "").strip().lower()
        requested_model = (request.data.get("model") or "").strip() or None
        ai_settings = UserAISettings.objects.filter(user=request.user).first()
        preferred_provider = (
            (ai_settings.preferred_provider or "").strip().lower()
            if ai_settings
            else ""
        )
        preferred_model = (
            (ai_settings.preferred_model or "").strip() if ai_settings else ""
        )

        provider = requested_provider
        if not provider and preferred_provider:
            if preferred_provider and is_chat_provider_available(preferred_provider):
                provider = preferred_provider

        if not provider:
            provider = "openai"

        model = requested_model
        if model is None and preferred_model and provider == preferred_provider:
            model = preferred_model

        collection_id = request.data.get("collection_id")
        collection_name = request.data.get("collection_name")
        start_date = request.data.get("start_date")
        end_date = request.data.get("end_date")
        destination = request.data.get("destination")
        if not is_chat_provider_available(provider):
            return Response(
                {"error": f"Provider is not available for chat: {provider}."},
                status=status.HTTP_400_BAD_REQUEST,
            )

        # Assemble the "## Trip Context" section appended to the system prompt.
        context_parts = []
        itinerary_stops = []
        if collection_name:
            context_parts.append(f"Trip: {collection_name}")
        if start_date and end_date:
            context_parts.append(f"Dates: {start_date} to {end_date}")

        # Only attach the collection if the requester owns it or it is shared
        # with them; otherwise silently ignore the id.
        collection = None
        if collection_id:
            try:
                requested_collection = Collection.objects.get(id=collection_id)
                if (
                    requested_collection.user == request.user
                    or requested_collection.shared_with.filter(
                        id=request.user.id
                    ).exists()
                ):
                    collection = requested_collection
            except Collection.DoesNotExist:
                pass

        if collection:
            context_parts.append(
                "Collection UUID (use this exact collection_id for get_trip_details and add_to_itinerary): "
                f"{collection.id}"
            )
            # Collect up to 8 de-duplicated stop labels from the collection's
            # locations, preferring city/country FKs over free-text names.
            seen_stops = set()
            for location in collection.locations.select_related(
                "city", "country"
            ).all():
                city_name = (getattr(location.city, "name", "") or "").strip()
                country_name = (getattr(location.country, "name", "") or "").strip()

                if city_name or country_name:
                    stop_label = (
                        f"{city_name}, {country_name}"
                        if city_name and country_name
                        else city_name or country_name
                    )
                    stop_key = f"geo:{city_name.lower()}|{country_name.lower()}"
                else:
                    fallback_name = (location.location or location.name or "").strip()
                    if not fallback_name:
                        continue
                    # When city/country FKs are not set, try to extract a geocodable
                    # city name from a comma-separated address string.
                    # e.g. "Little Turnstile 6, London" → "London"
                    # e.g. "Kingsway 58, London" → "London"
                    if "," in fallback_name:
                        parts = [p.strip() for p in fallback_name.split(",")]
                        # Last non-empty, non-purely-numeric segment is typically the city
                        city_hint = next(
                            (p for p in reversed(parts) if p and not p.isdigit()),
                            None,
                        )
                        stop_label = city_hint if city_hint else fallback_name
                    else:
                        stop_label = fallback_name
                    stop_key = f"name:{stop_label.lower()}"

                if stop_key in seen_stops:
                    continue

                seen_stops.add(stop_key)
                itinerary_stops.append(stop_label)

                if len(itinerary_stops) >= 8:
                    break

        if itinerary_stops:
            context_parts.append(f"Itinerary stops: {'; '.join(itinerary_stops)}")

        # Best location to fall back to when tools lack a location argument.
        trip_context_location = self._trip_context_search_location(
            destination, itinerary_stops
        )
        if trip_context_location:
            context_parts.append(f"Destination: {trip_context_location}")
        # Recent user texts feed the search_places category inference below.
        prior_user_messages = list(
            conversation.messages.filter(role="user")
            .order_by("-created_at")
            .values_list("content", flat=True)[:3]
        )

        system_prompt = get_system_prompt(request.user, collection)
        if context_parts:
            system_prompt += "\n\n## Trip Context\n" + "\n".join(context_parts)

        # Persist the user message before streaming so history is durable even
        # if the stream fails partway.
        ChatMessage.objects.create(
            conversation=conversation,
            role="user",
            content=user_content,
        )
        conversation.save(update_fields=["updated_at"])

        if not conversation.title:
            # First message doubles as the conversation title (truncated).
            conversation.title = user_content[:120]
            conversation.save(update_fields=["title", "updated_at"])

        llm_messages = self._build_llm_messages(
            conversation,
            request.user,
            system_prompt=system_prompt,
        )

        # Hard caps for the agent loop: total tool rounds, and consecutive
        # rounds in which every tool call failed.
        MAX_TOOL_ITERATIONS = 10
        MAX_ALL_FAILURE_ROUNDS = 3

        async def event_stream():
            # Agent loop: stream a completion; if it requested tools, execute
            # them, append results to the transcript, and loop; otherwise
            # persist the final assistant message and finish.
            current_messages = list(llm_messages)
            encountered_error = False
            tool_iterations = 0
            all_failure_rounds = 0

            while tool_iterations < MAX_TOOL_ITERATIONS:
                content_chunks = []
                tool_calls_accumulator = []

                async for chunk in stream_chat_completion(
                    request.user,
                    current_messages,
                    provider,
                    tools=AGENT_TOOLS,
                    model=model,
                ):
                    # Pass through non-SSE lines untouched.
                    if not chunk.startswith("data: "):
                        yield chunk
                        continue

                    payload = chunk[len("data: ") :].strip()
                    # Swallow the provider's [DONE]; we emit our own when the
                    # whole agent loop finishes.
                    if payload == "[DONE]":
                        continue

                    yield chunk

                    try:
                        data = json.loads(payload)
                    except json.JSONDecodeError:
                        continue

                    if data.get("error"):
                        encountered_error = True
                        break

                    if data.get("content"):
                        content_chunks.append(data["content"])

                    if data.get("tool_calls"):
                        self._merge_tool_call_delta(
                            tool_calls_accumulator,
                            data["tool_calls"],
                        )

                if encountered_error:
                    yield "data: [DONE]\n\n"
                    break

                assistant_content = "".join(content_chunks)

                if tool_calls_accumulator:
                    successful_tool_calls = []
                    successful_tool_messages = []
                    successful_tool_chat_entries = []
                    first_execution_failure = None
                    encountered_permanent_failure = False

                    for tool_call in tool_calls_accumulator:
                        function_payload = tool_call.get("function") or {}
                        function_name = function_payload.get("name") or ""
                        raw_arguments = function_payload.get("arguments") or "{}"

                        # Malformed/non-dict arguments degrade to an empty dict.
                        try:
                            arguments = json.loads(raw_arguments)
                        except json.JSONDecodeError:
                            arguments = {}
                        if not isinstance(arguments, dict):
                            arguments = {}

                        prepared_arguments = dict(arguments)
                        tool_call_for_history = tool_call

                        # Fill in a missing search_places category from recent
                        # user text (e.g. dining intent → "food").
                        if function_name == "search_places":
                            if not (prepared_arguments.get("category") or "").strip():
                                inferred_category = self._infer_search_places_category(
                                    user_content,
                                    prior_user_messages,
                                )
                                if inferred_category:
                                    prepared_arguments["category"] = inferred_category

                        # Keep the stored tool_call in sync with the arguments
                        # actually used, so replayed history matches reality.
                        if prepared_arguments != arguments:
                            tool_call_for_history = {
                                **tool_call,
                                "function": {
                                    **function_payload,
                                    "name": function_name,
                                    "arguments": json.dumps(prepared_arguments),
                                },
                            }

                        result = await sync_to_async(
                            execute_tool, thread_sensitive=True
                        )(
                            function_name,
                            request.user,
                            **prepared_arguments,
                        )

                        # Retry search_places location failures with locations
                        # implied by the trip context or by the user's reply.
                        attempted_location_retry = False
                        if self._is_search_places_location_retry_candidate_error(
                            function_name,
                            result,
                        ):
                            retry_locations = []
                            if trip_context_location:
                                retry_locations.append(trip_context_location)
                            if self._is_likely_location_reply(user_content):
                                retry_locations.append(user_content)

                            seen_retry_locations = set()
                            for retry_location in retry_locations:
                                normalized_retry_location = (
                                    retry_location.strip().lower()
                                )
                                if (
                                    not normalized_retry_location
                                    or normalized_retry_location in seen_retry_locations
                                ):
                                    continue
                                seen_retry_locations.add(normalized_retry_location)
                                attempted_location_retry = True

                                retry_arguments = dict(prepared_arguments)
                                retry_arguments["location"] = retry_location
                                retry_result = await sync_to_async(
                                    execute_tool,
                                    thread_sensitive=True,
                                )(
                                    function_name,
                                    request.user,
                                    **retry_arguments,
                                )

                                if not self._is_required_param_tool_error(
                                    retry_result
                                ) and not self._is_execution_failure_tool_error(
                                    retry_result
                                ):
                                    result = retry_result
                                    tool_call_for_history = {
                                        **tool_call,
                                        "function": {
                                            **function_payload,
                                            "name": function_name,
                                            "arguments": json.dumps(retry_arguments),
                                        },
                                    }
                                    break

                            # If we attempted context retries but all failed, convert
                            # to an execution failure so we never ask the user for a
                            # location they already implied via itinerary context.
                            if (
                                attempted_location_retry
                                and self._is_required_param_tool_error(result)
                            ):
                                result = {
                                    "error": "Could not search places at the provided itinerary locations"
                                }

                        # Retry get_weather missing-coordinate failures using
                        # the collection's first geocoded location and dates.
                        attempted_weather_coord_retry = False
                        if self._is_get_weather_missing_latlong_error(
                            function_name, result
                        ):
                            coords = await sync_to_async(
                                self._extract_collection_coordinates,
                                thread_sensitive=True,
                            )(collection)
                            if coords is not None:
                                retry_lat, retry_lon = coords
                                retry_arguments = dict(prepared_arguments)
                                retry_arguments["latitude"] = retry_lat
                                retry_arguments["longitude"] = retry_lon
                                if not retry_arguments.get("dates"):
                                    retry_arguments["dates"] = (
                                        self._derive_weather_dates_from_collection(
                                            collection
                                        )
                                    )
                                attempted_weather_coord_retry = True
                                retry_result = await sync_to_async(
                                    execute_tool,
                                    thread_sensitive=True,
                                )(
                                    function_name,
                                    request.user,
                                    **retry_arguments,
                                )
                                if not self._is_required_param_tool_error(
                                    retry_result
                                ) and not self._is_execution_failure_tool_error(
                                    retry_result
                                ):
                                    result = retry_result
                                    tool_call_for_history = {
                                        **tool_call,
                                        "function": {
                                            **function_payload,
                                            "name": function_name,
                                            "arguments": json.dumps(retry_arguments),
                                        },
                                    }

                            # If retry was attempted but still failed, convert to an
                            # execution failure — never ask the user for coordinates
                            # they implied via collection context.
                            if (
                                attempted_weather_coord_retry
                                and self._is_required_param_tool_error(result)
                                and self._is_get_weather_missing_latlong_error(
                                    function_name,
                                    result,
                                )
                            ):
                                result = {
                                    "error": "Could not fetch weather for the collection locations"
                                }

                        # A surviving required-param error ends the whole
                        # stream: persist what succeeded so far, then either
                        # ask the user for a location or emit an error event.
                        if self._is_required_param_tool_error(result):
                            assistant_message_kwargs = {
                                "conversation": conversation,
                                "role": "assistant",
                                "content": assistant_content,
                            }
                            if successful_tool_calls:
                                assistant_message_kwargs["tool_calls"] = (
                                    successful_tool_calls
                                )

                            await sync_to_async(
                                ChatMessage.objects.create, thread_sensitive=True
                            )(**assistant_message_kwargs)

                            for tool_message in successful_tool_messages:
                                await sync_to_async(
                                    ChatMessage.objects.create,
                                    thread_sensitive=True,
                                )(**tool_message)

                            if (
                                not attempted_location_retry
                                and self._is_search_places_missing_location_required_error(
                                    function_name,
                                    result,
                                )
                            ):
                                clarification_content = self._build_search_places_location_clarification_message()
                                await sync_to_async(
                                    ChatMessage.objects.create,
                                    thread_sensitive=True,
                                )(
                                    conversation=conversation,
                                    role="assistant",
                                    content=clarification_content,
                                )

                                await sync_to_async(
                                    conversation.save,
                                    thread_sensitive=True,
                                )(update_fields=["updated_at"])

                                yield (
                                    "data: "
                                    f"{json.dumps({'content': clarification_content})}"
                                    "\n\n"
                                )
                                yield "data: [DONE]\n\n"
                                return

                            await sync_to_async(
                                conversation.save,
                                thread_sensitive=True,
                            )(update_fields=["updated_at"])

                            logger.info(
                                "Stopping chat tool loop due to required-arg tool validation error: %s (%s)",
                                function_name,
                                result.get("error"),
                            )
                            error_event = self._build_required_param_error_event(
                                function_name,
                                result,
                            )
                            yield f"data: {json.dumps(error_event)}\n\n"
                            yield "data: [DONE]\n\n"
                            return

                        # Execution failures are tallied (first one kept for the
                        # eventual error event) but don't abort this round.
                        if self._is_execution_failure_tool_error(result):
                            if first_execution_failure is None:
                                first_execution_failure = (function_name, result)
                            if not self._is_retryable_execution_failure(result):
                                encountered_permanent_failure = True
                            continue

                        result_content = serialize_tool_result(result)

                        successful_tool_calls.append(tool_call_for_history)
                        tool_message_payload = {
                            "conversation": conversation,
                            "role": "tool",
                            "content": result_content,
                            "tool_call_id": tool_call_for_history.get("id"),
                            "name": function_name,
                        }
                        successful_tool_messages.append(tool_message_payload)
                        successful_tool_chat_entries.append(
                            {
                                "role": "tool",
                                "tool_call_id": tool_call_for_history.get("id"),
                                "name": function_name,
                                "content": result_content,
                            }
                        )

                        # Surface the tool result to the client immediately.
                        tool_event = {
                            "tool_result": {
                                "tool_call_id": tool_call_for_history.get("id"),
                                "name": function_name,
                                "result": result,
                            }
                        }
                        yield f"data: {json.dumps(tool_event)}\n\n"

                    # Every tool call in this round failed: count the round (a
                    # permanent failure exhausts the budget immediately) and
                    # either give up with an error event or let the LLM retry.
                    if not successful_tool_calls and first_execution_failure:
                        if encountered_permanent_failure:
                            all_failure_rounds = MAX_ALL_FAILURE_ROUNDS
                        else:
                            all_failure_rounds += 1

                        if all_failure_rounds >= MAX_ALL_FAILURE_ROUNDS:
                            failed_tool_name, failed_tool_result = (
                                first_execution_failure
                            )
                            error_event = self._build_tool_execution_error_event(
                                failed_tool_name,
                                failed_tool_result,
                            )
                            await sync_to_async(
                                ChatMessage.objects.create,
                                thread_sensitive=True,
                            )(
                                conversation=conversation,
                                role="assistant",
                                content=error_event["error"],
                            )
                            await sync_to_async(
                                conversation.save,
                                thread_sensitive=True,
                            )(update_fields=["updated_at"])
                            yield f"data: {json.dumps(error_event)}\n\n"
                            yield "data: [DONE]\n\n"
                            return

                        continue

                    all_failure_rounds = 0
                    tool_iterations += 1

                    # Feed the assistant turn + tool results back into the
                    # transcript for the next completion round.
                    assistant_with_tools = {
                        "role": "assistant",
                        "content": assistant_content,
                        "tool_calls": successful_tool_calls,
                    }
                    current_messages.append(assistant_with_tools)
                    current_messages.extend(successful_tool_chat_entries)

                    await sync_to_async(
                        ChatMessage.objects.create, thread_sensitive=True
                    )(
                        conversation=conversation,
                        role="assistant",
                        content=assistant_content,
                        tool_calls=successful_tool_calls,
                    )
                    for tool_message in successful_tool_messages:
                        await sync_to_async(
                            ChatMessage.objects.create,
                            thread_sensitive=True,
                        )(**tool_message)

                    await sync_to_async(conversation.save, thread_sensitive=True)(
                        update_fields=["updated_at"]
                    )

                    continue

                # No tool calls: this completion is the final assistant reply.
                await sync_to_async(ChatMessage.objects.create, thread_sensitive=True)(
                    conversation=conversation,
                    role="assistant",
                    content=assistant_content,
                )
                await sync_to_async(conversation.save, thread_sensitive=True)(
                    update_fields=["updated_at"]
                )
                yield "data: [DONE]\n\n"
                break

            if tool_iterations >= MAX_TOOL_ITERATIONS:
                logger.warning(
                    "Stopping chat tool loop after max iterations (%s)",
                    MAX_TOOL_ITERATIONS,
                )
                payload = {
                    "error": "The assistant stopped after too many tool calls. Please try again with a more specific request.",
                    "error_category": "tool_loop_limit",
                }
                yield f"data: {json.dumps(payload)}\n\n"
                yield "data: [DONE]\n\n"

        # Bridge the async generator to the sync iterator Django expects, and
        # disable caching/proxy buffering so SSE chunks flush immediately.
        response = StreamingHttpResponse(
            streaming_content=self._async_to_sync_generator(event_stream()),
            content_type="text/event-stream",
        )
        response["Cache-Control"] = "no-cache"
        response["X-Accel-Buffering"] = "no"
        return response
|
|
|
|
|
|
class ChatProviderCatalogViewSet(viewsets.ViewSet):
    """Read-only catalog of chat providers and their available models."""

    permission_classes = [IsAuthenticated]

    def list(self, request):
        # Catalog is per-user: availability depends on the user's stored keys.
        return Response(get_provider_catalog(user=request.user))

    @action(detail=True, methods=["get"])
    def models(self, request, pk=None):
        """Fetch available models from a provider's API."""
        from chat.llm_client import CHAT_PROVIDER_CONFIG, get_llm_api_key

        provider = (pk or "").lower()

        api_key = get_llm_api_key(request.user, provider)
        if not api_key:
            return Response(
                {"error": "No API key configured for this provider"},
                status=status.HTTP_403_FORBIDDEN,
            )

        try:
            if provider == "openai":
                import openai

                # Live model listing, filtered to chat-capable model families.
                client = openai.OpenAI(api_key=api_key)
                models = client.models.list()
                chat_models = [
                    model.id
                    for model in models
                    if any(prefix in model.id for prefix in ["gpt-", "o1-", "chatgpt"])
                ]
                return Response({"models": sorted(set(chat_models), reverse=True)})

            # Anthropic has no public model-listing endpoint; use a static list.
            if provider in ["anthropic", "claude"]:
                return Response(
                    {
                        "models": [
                            "claude-sonnet-4-20250514",
                            "claude-opus-4-20250514",
                            "claude-3-5-sonnet-20241022",
                            "claude-3-5-haiku-20241022",
                            "claude-3-haiku-20240307",
                        ]
                    }
                )

            # Static list for Gemini as well.
            if provider in ["gemini", "google"]:
                return Response(
                    {
                        "models": [
                            "gemini-2.0-flash",
                            "gemini-1.5-pro",
                            "gemini-1.5-flash",
                            "gemini-1.5-flash-8b",
                        ]
                    }
                )

            # Static list for Groq.
            if provider in ["groq"]:
                return Response(
                    {
                        "models": [
                            "llama-3.3-70b-versatile",
                            "llama-3.1-70b-versatile",
                            "llama-3.1-8b-instant",
                            "mixtral-8x7b-32768",
                        ]
                    }
                )

            if provider in ["ollama"]:
                import requests

                # Query the local Ollama daemon; any failure degrades to an
                # empty list rather than an error response.
                try:
                    response = requests.get(
                        "http://localhost:11434/api/tags", timeout=5
                    )
                    if response.ok:
                        data = response.json()
                        models = [item["name"] for item in data.get("models", [])]
                        return Response({"models": models})
                except Exception:
                    pass
                return Response({"models": []})

            if provider == "opencode_zen":
                import requests

                # OpenAI-compatible /models endpoint; base URL comes from config.
                config = CHAT_PROVIDER_CONFIG.get("opencode_zen", {})
                api_base = config.get("api_base", "https://opencode.ai/zen/v1")
                response = requests.get(
                    f"{api_base}/models",
                    headers={"Authorization": f"Bearer {api_key}"},
                    timeout=10,
                )

                if response.ok:
                    data = response.json()
                    # Payload may be {"data": [...]} (OpenAI shape) or a bare list.
                    raw_models = (
                        data.get("data", data) if isinstance(data, dict) else data
                    )
                    model_ids = []
                    for model_entry in raw_models:
                        if not isinstance(model_entry, dict):
                            continue

                        model_id = model_entry.get("id") or model_entry.get("model_id")
                        if model_id:
                            model_ids.append(model_id)

                    return Response({"models": sorted(set(model_ids))})

                logger.warning(
                    "OpenCode Zen models fetch failed with status %s",
                    response.status_code,
                )
                return Response({"models": []})

            # Unknown provider: empty list rather than an error.
            return Response({"models": []})
        except Exception as exc:
            logger.error("Failed to fetch models for %s: %s", provider, exc)
            return Response(
                {"error": f"Failed to fetch models: {str(exc)}"},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )
|
|
|
|
|
|
from .capabilities import CapabilitiesView
|
|
from .day_suggestions import DaySuggestionsView
|