feat(ai): implement agent-redesign plan with enhanced AI travel features
Phase 1 - Configuration Infrastructure (WS1): - Add instance-level AI env vars (VOYAGE_AI_PROVIDER, VOYAGE_AI_MODEL, VOYAGE_AI_API_KEY) - Implement fallback chain: user key → instance key → error - Add UserAISettings model for per-user provider/model preferences - Enhance provider catalog with instance_configured and user_configured flags - Optimize provider catalog to avoid N+1 queries Phase 1 - User Preference Learning (WS2): - Add Travel Preferences tab to Settings page - Improve preference formatting in system prompt with emoji headers - Add multi-user preference aggregation for shared collections Phase 2 - Day-Level Suggestions Modal (WS3): - Create ItinerarySuggestionModal with 3-step flow (category → filters → results) - Add AI suggestions button to itinerary Add dropdown - Support restaurant, activity, event, and lodging categories - Backend endpoint POST /api/chat/suggestions/day/ with context-aware prompts Phase 3 - Collection-Level Chat Improvements (WS4): - Inject collection context (destination, dates) into chat system prompt - Add quick action buttons for common queries - Add 'Add to itinerary' button on search_places results - Update chat UI with travel-themed branding and improved tool result cards Phase 3 - Web Search Capability (WS5): - Add web_search agent tool using DuckDuckGo - Support location_context parameter for biased results - Handle rate limiting gracefully Phase 4 - Extensibility Architecture (WS6): - Implement decorator-based @agent_tool registry - Convert existing tools to use decorators - Add GET /api/chat/capabilities/ endpoint for tool discovery - Refactor execute_tool() to use registry pattern
This commit is contained in:
328
backend/server/chat/views/__init__.py
Normal file
328
backend/server/chat/views/__init__.py
Normal file
@@ -0,0 +1,328 @@
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
from adventures.models import Collection
|
||||
from django.http import StreamingHttpResponse
|
||||
from rest_framework import status, viewsets
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
|
||||
from ..agent_tools import AGENT_TOOLS, execute_tool, serialize_tool_result
|
||||
from ..llm_client import (
|
||||
get_provider_catalog,
|
||||
get_system_prompt,
|
||||
is_chat_provider_available,
|
||||
stream_chat_completion,
|
||||
)
|
||||
from ..models import ChatConversation, ChatMessage
|
||||
from ..serializers import ChatConversationSerializer
|
||||
|
||||
|
||||
class ChatViewSet(viewsets.ModelViewSet):
|
||||
serializer_class = ChatConversationSerializer
|
||||
permission_classes = [IsAuthenticated]
|
||||
|
||||
def get_queryset(self):
|
||||
return ChatConversation.objects.filter(user=self.request.user).prefetch_related(
|
||||
"messages"
|
||||
)
|
||||
|
||||
def list(self, request, *args, **kwargs):
|
||||
conversations = self.get_queryset().only("id", "title", "updated_at")
|
||||
data = [
|
||||
{
|
||||
"id": str(conversation.id),
|
||||
"title": conversation.title,
|
||||
"updated_at": conversation.updated_at,
|
||||
}
|
||||
for conversation in conversations
|
||||
]
|
||||
return Response(data)
|
||||
|
||||
def create(self, request, *args, **kwargs):
|
||||
conversation = ChatConversation.objects.create(
|
||||
user=request.user,
|
||||
title=(request.data.get("title") or "").strip(),
|
||||
)
|
||||
serializer = self.get_serializer(conversation)
|
||||
return Response(serializer.data, status=status.HTTP_201_CREATED)
|
||||
|
||||
def _build_llm_messages(self, conversation, user, system_prompt=None):
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt or get_system_prompt(user),
|
||||
}
|
||||
]
|
||||
for message in conversation.messages.all().order_by("created_at"):
|
||||
payload = {
|
||||
"role": message.role,
|
||||
"content": message.content,
|
||||
}
|
||||
if message.role == "assistant" and message.tool_calls:
|
||||
payload["tool_calls"] = message.tool_calls
|
||||
if message.role == "tool":
|
||||
payload["tool_call_id"] = message.tool_call_id
|
||||
payload["name"] = message.name
|
||||
messages.append(payload)
|
||||
return messages
|
||||
|
||||
def _async_to_sync_generator(self, async_gen):
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
yield loop.run_until_complete(async_gen.__anext__())
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
finally:
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
loop.close()
|
||||
|
||||
@staticmethod
|
||||
def _merge_tool_call_delta(accumulator, tool_calls_delta):
|
||||
for idx, tool_call in enumerate(tool_calls_delta or []):
|
||||
idx = tool_call.get("index", idx)
|
||||
while len(accumulator) <= idx:
|
||||
accumulator.append(
|
||||
{
|
||||
"id": None,
|
||||
"type": "function",
|
||||
"function": {"name": "", "arguments": ""},
|
||||
}
|
||||
)
|
||||
|
||||
current = accumulator[idx]
|
||||
if tool_call.get("id"):
|
||||
current["id"] = tool_call.get("id")
|
||||
if tool_call.get("type"):
|
||||
current["type"] = tool_call.get("type")
|
||||
|
||||
function_data = tool_call.get("function") or {}
|
||||
if function_data.get("name"):
|
||||
current["function"]["name"] = function_data.get("name")
|
||||
if function_data.get("arguments"):
|
||||
current["function"]["arguments"] += function_data.get("arguments")
|
||||
|
||||
@action(detail=True, methods=["post"])
|
||||
def send_message(self, request, pk=None):
|
||||
conversation = self.get_object()
|
||||
user_content = (request.data.get("message") or "").strip()
|
||||
if not user_content:
|
||||
return Response(
|
||||
{"error": "message is required"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
provider = (request.data.get("provider") or "openai").strip().lower()
|
||||
model = (request.data.get("model") or "").strip() or None
|
||||
collection_id = request.data.get("collection_id")
|
||||
collection_name = request.data.get("collection_name")
|
||||
start_date = request.data.get("start_date")
|
||||
end_date = request.data.get("end_date")
|
||||
destination = request.data.get("destination")
|
||||
if not is_chat_provider_available(provider):
|
||||
return Response(
|
||||
{"error": f"Provider is not available for chat: {provider}."},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
context_parts = []
|
||||
if collection_name:
|
||||
context_parts.append(f"Trip: {collection_name}")
|
||||
if destination:
|
||||
context_parts.append(f"Destination: {destination}")
|
||||
if start_date and end_date:
|
||||
context_parts.append(f"Dates: {start_date} to {end_date}")
|
||||
|
||||
collection = None
|
||||
if collection_id:
|
||||
try:
|
||||
requested_collection = Collection.objects.get(id=collection_id)
|
||||
if (
|
||||
requested_collection.user == request.user
|
||||
or requested_collection.shared_with.filter(
|
||||
id=request.user.id
|
||||
).exists()
|
||||
):
|
||||
collection = requested_collection
|
||||
except Collection.DoesNotExist:
|
||||
pass
|
||||
|
||||
system_prompt = get_system_prompt(request.user, collection)
|
||||
if context_parts:
|
||||
system_prompt += "\n\n## Trip Context\n" + "\n".join(context_parts)
|
||||
|
||||
ChatMessage.objects.create(
|
||||
conversation=conversation,
|
||||
role="user",
|
||||
content=user_content,
|
||||
)
|
||||
conversation.save(update_fields=["updated_at"])
|
||||
|
||||
if not conversation.title:
|
||||
conversation.title = user_content[:120]
|
||||
conversation.save(update_fields=["title", "updated_at"])
|
||||
|
||||
llm_messages = self._build_llm_messages(
|
||||
conversation,
|
||||
request.user,
|
||||
system_prompt=system_prompt,
|
||||
)
|
||||
|
||||
MAX_TOOL_ITERATIONS = 10
|
||||
|
||||
async def event_stream():
|
||||
current_messages = list(llm_messages)
|
||||
encountered_error = False
|
||||
tool_iterations = 0
|
||||
|
||||
while tool_iterations < MAX_TOOL_ITERATIONS:
|
||||
content_chunks = []
|
||||
tool_calls_accumulator = []
|
||||
|
||||
async for chunk in stream_chat_completion(
|
||||
request.user,
|
||||
current_messages,
|
||||
provider,
|
||||
tools=AGENT_TOOLS,
|
||||
model=model,
|
||||
):
|
||||
if not chunk.startswith("data: "):
|
||||
yield chunk
|
||||
continue
|
||||
|
||||
payload = chunk[len("data: ") :].strip()
|
||||
if payload == "[DONE]":
|
||||
continue
|
||||
|
||||
yield chunk
|
||||
|
||||
try:
|
||||
data = json.loads(payload)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
if data.get("error"):
|
||||
encountered_error = True
|
||||
break
|
||||
|
||||
if data.get("content"):
|
||||
content_chunks.append(data["content"])
|
||||
|
||||
if data.get("tool_calls"):
|
||||
self._merge_tool_call_delta(
|
||||
tool_calls_accumulator,
|
||||
data["tool_calls"],
|
||||
)
|
||||
|
||||
if encountered_error:
|
||||
break
|
||||
|
||||
assistant_content = "".join(content_chunks)
|
||||
|
||||
if tool_calls_accumulator:
|
||||
assistant_with_tools = {
|
||||
"role": "assistant",
|
||||
"content": assistant_content,
|
||||
"tool_calls": tool_calls_accumulator,
|
||||
}
|
||||
current_messages.append(assistant_with_tools)
|
||||
|
||||
await sync_to_async(
|
||||
ChatMessage.objects.create, thread_sensitive=True
|
||||
)(
|
||||
conversation=conversation,
|
||||
role="assistant",
|
||||
content=assistant_content,
|
||||
tool_calls=tool_calls_accumulator,
|
||||
)
|
||||
await sync_to_async(conversation.save, thread_sensitive=True)(
|
||||
update_fields=["updated_at"]
|
||||
)
|
||||
|
||||
for tool_call in tool_calls_accumulator:
|
||||
function_payload = tool_call.get("function") or {}
|
||||
function_name = function_payload.get("name") or ""
|
||||
raw_arguments = function_payload.get("arguments") or "{}"
|
||||
|
||||
try:
|
||||
arguments = json.loads(raw_arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
if not isinstance(arguments, dict):
|
||||
arguments = {}
|
||||
|
||||
result = await sync_to_async(
|
||||
execute_tool, thread_sensitive=True
|
||||
)(
|
||||
function_name,
|
||||
request.user,
|
||||
**arguments,
|
||||
)
|
||||
result_content = serialize_tool_result(result)
|
||||
|
||||
current_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.get("id"),
|
||||
"name": function_name,
|
||||
"content": result_content,
|
||||
}
|
||||
)
|
||||
|
||||
await sync_to_async(
|
||||
ChatMessage.objects.create, thread_sensitive=True
|
||||
)(
|
||||
conversation=conversation,
|
||||
role="tool",
|
||||
content=result_content,
|
||||
tool_call_id=tool_call.get("id"),
|
||||
name=function_name,
|
||||
)
|
||||
await sync_to_async(conversation.save, thread_sensitive=True)(
|
||||
update_fields=["updated_at"]
|
||||
)
|
||||
|
||||
tool_event = {
|
||||
"tool_result": {
|
||||
"tool_call_id": tool_call.get("id"),
|
||||
"name": function_name,
|
||||
"result": result,
|
||||
}
|
||||
}
|
||||
yield f"data: {json.dumps(tool_event)}\n\n"
|
||||
|
||||
continue
|
||||
|
||||
await sync_to_async(ChatMessage.objects.create, thread_sensitive=True)(
|
||||
conversation=conversation,
|
||||
role="assistant",
|
||||
content=assistant_content,
|
||||
)
|
||||
await sync_to_async(conversation.save, thread_sensitive=True)(
|
||||
update_fields=["updated_at"]
|
||||
)
|
||||
yield "data: [DONE]\n\n"
|
||||
break
|
||||
|
||||
response = StreamingHttpResponse(
|
||||
streaming_content=self._async_to_sync_generator(event_stream()),
|
||||
content_type="text/event-stream",
|
||||
)
|
||||
response["Cache-Control"] = "no-cache"
|
||||
response["X-Accel-Buffering"] = "no"
|
||||
return response
|
||||
|
||||
|
||||
class ChatProviderCatalogViewSet(viewsets.ViewSet):
|
||||
permission_classes = [IsAuthenticated]
|
||||
|
||||
def list(self, request):
|
||||
return Response(get_provider_catalog(user=request.user))
|
||||
|
||||
|
||||
from .capabilities import CapabilitiesView
|
||||
from .day_suggestions import DaySuggestionsView
|
||||
24
backend/server/chat/views/capabilities.py
Normal file
24
backend/server/chat/views/capabilities.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from chat.agent_tools import get_tool_schemas
|
||||
|
||||
|
||||
class CapabilitiesView(APIView):
|
||||
permission_classes = [IsAuthenticated]
|
||||
|
||||
def get(self, request):
|
||||
"""Return available AI capabilities/tools."""
|
||||
tools = get_tool_schemas()
|
||||
return Response(
|
||||
{
|
||||
"tools": [
|
||||
{
|
||||
"name": tool["function"]["name"],
|
||||
"description": tool["function"]["description"],
|
||||
}
|
||||
for tool in tools
|
||||
]
|
||||
}
|
||||
)
|
||||
215
backend/server/chat/views/day_suggestions.py
Normal file
215
backend/server/chat/views/day_suggestions.py
Normal file
@@ -0,0 +1,215 @@
|
||||
import json
|
||||
import re
|
||||
|
||||
import litellm
|
||||
from django.shortcuts import get_object_or_404
|
||||
from rest_framework import status
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from adventures.models import Collection
|
||||
from chat.agent_tools import search_places
|
||||
from chat.llm_client import (
|
||||
get_llm_api_key,
|
||||
get_system_prompt,
|
||||
is_chat_provider_available,
|
||||
)
|
||||
|
||||
|
||||
class DaySuggestionsView(APIView):
|
||||
permission_classes = [IsAuthenticated]
|
||||
|
||||
def post(self, request):
|
||||
collection_id = request.data.get("collection_id")
|
||||
date = request.data.get("date")
|
||||
category = request.data.get("category")
|
||||
filters = request.data.get("filters", {}) or {}
|
||||
location_context = request.data.get("location_context", "")
|
||||
|
||||
if not all([collection_id, date, category]):
|
||||
return Response(
|
||||
{"error": "collection_id, date, and category are required"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
valid_categories = ["restaurant", "activity", "event", "lodging"]
|
||||
if category not in valid_categories:
|
||||
return Response(
|
||||
{"error": f"category must be one of: {', '.join(valid_categories)}"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
collection = get_object_or_404(Collection, id=collection_id)
|
||||
if (
|
||||
collection.user != request.user
|
||||
and not collection.shared_with.filter(id=request.user.id).exists()
|
||||
):
|
||||
return Response(
|
||||
{"error": "You don't have access to this collection"},
|
||||
status=status.HTTP_403_FORBIDDEN,
|
||||
)
|
||||
|
||||
location = location_context or self._get_collection_location(collection)
|
||||
system_prompt = get_system_prompt(request.user, collection)
|
||||
provider = "openai"
|
||||
|
||||
if not is_chat_provider_available(provider):
|
||||
return Response(
|
||||
{
|
||||
"error": "AI suggestions are not available. Please configure an API key."
|
||||
},
|
||||
status=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
)
|
||||
|
||||
try:
|
||||
places_context = self._get_places_context(request.user, category, location)
|
||||
prompt = self._build_prompt(
|
||||
category=category,
|
||||
filters=filters,
|
||||
location=location,
|
||||
date=date,
|
||||
collection=collection,
|
||||
places_context=places_context,
|
||||
)
|
||||
|
||||
suggestions = self._get_suggestions_from_llm(
|
||||
system_prompt=system_prompt,
|
||||
user_prompt=prompt,
|
||||
user=request.user,
|
||||
provider=provider,
|
||||
)
|
||||
return Response({"suggestions": suggestions}, status=status.HTTP_200_OK)
|
||||
except Exception:
|
||||
return Response(
|
||||
{"error": "Failed to generate suggestions. Please try again."},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
def _get_collection_location(self, collection):
|
||||
for loc in collection.locations.select_related("city", "country").all():
|
||||
if loc.city:
|
||||
city_name = getattr(loc.city, "name", str(loc.city))
|
||||
country_name = getattr(loc.country, "name", "") if loc.country else ""
|
||||
return ", ".join([x for x in [city_name, country_name] if x])
|
||||
if loc.location:
|
||||
return loc.location
|
||||
if loc.name:
|
||||
return loc.name
|
||||
return "Unknown location"
|
||||
|
||||
def _build_prompt(
|
||||
self,
|
||||
category,
|
||||
filters,
|
||||
location,
|
||||
date,
|
||||
collection,
|
||||
places_context="",
|
||||
):
|
||||
category_prompts = {
|
||||
"restaurant": f"Find restaurant recommendations for {location}",
|
||||
"activity": f"Find activity recommendations for {location}",
|
||||
"event": f"Find event recommendations for {location} around {date}",
|
||||
"lodging": f"Find lodging recommendations for {location}",
|
||||
}
|
||||
|
||||
prompt = category_prompts.get(
|
||||
category, f"Find {category} recommendations for {location}"
|
||||
)
|
||||
|
||||
if filters:
|
||||
filter_parts = []
|
||||
if filters.get("cuisine_type"):
|
||||
filter_parts.append(f"cuisine type: {filters['cuisine_type']}")
|
||||
if filters.get("price_range"):
|
||||
filter_parts.append(f"price range: {filters['price_range']}")
|
||||
if filters.get("dietary"):
|
||||
filter_parts.append(f"dietary restrictions: {filters['dietary']}")
|
||||
if filters.get("activity_type"):
|
||||
filter_parts.append(f"type: {filters['activity_type']}")
|
||||
if filters.get("duration"):
|
||||
filter_parts.append(f"duration: {filters['duration']}")
|
||||
if filters.get("event_type"):
|
||||
filter_parts.append(f"event type: {filters['event_type']}")
|
||||
if filters.get("lodging_type"):
|
||||
filter_parts.append(f"lodging type: {filters['lodging_type']}")
|
||||
amenities = filters.get("amenities")
|
||||
if isinstance(amenities, list) and amenities:
|
||||
filter_parts.append(
|
||||
f"amenities: {', '.join(str(x) for x in amenities)}"
|
||||
)
|
||||
|
||||
if filter_parts:
|
||||
prompt += f" with these preferences: {', '.join(filter_parts)}"
|
||||
|
||||
prompt += f". The trip date is {date}."
|
||||
|
||||
if collection.start_date or collection.end_date:
|
||||
prompt += (
|
||||
" Collection trip window: "
|
||||
f"{collection.start_date or 'unknown'} to {collection.end_date or 'unknown'}."
|
||||
)
|
||||
|
||||
if places_context:
|
||||
prompt += f" Nearby place context: {places_context}."
|
||||
|
||||
prompt += (
|
||||
" Return 3-5 specific suggestions as a JSON array."
|
||||
" Each suggestion should have: name, description, why_fits, category, location, rating, price_level."
|
||||
" Return ONLY valid JSON, no markdown, no surrounding text."
|
||||
)
|
||||
return prompt
|
||||
|
||||
def _get_places_context(self, user, category, location):
|
||||
tool_category_map = {
|
||||
"restaurant": "food",
|
||||
"activity": "tourism",
|
||||
"event": "tourism",
|
||||
"lodging": "lodging",
|
||||
}
|
||||
result = search_places(
|
||||
user,
|
||||
location=location,
|
||||
category=tool_category_map.get(category, "tourism"),
|
||||
radius=8,
|
||||
)
|
||||
if result.get("error"):
|
||||
return ""
|
||||
|
||||
entries = []
|
||||
for place in result.get("results", [])[:5]:
|
||||
name = place.get("name")
|
||||
address = place.get("address") or ""
|
||||
if name:
|
||||
entries.append(f"{name} ({address})" if address else name)
|
||||
return "; ".join(entries)
|
||||
|
||||
def _get_suggestions_from_llm(self, system_prompt, user_prompt, user, provider):
|
||||
api_key = get_llm_api_key(user, provider)
|
||||
if not api_key:
|
||||
raise ValueError("No API key available")
|
||||
|
||||
response = litellm.completion(
|
||||
model="gpt-4o-mini",
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
api_key=api_key,
|
||||
temperature=0.7,
|
||||
max_tokens=1000,
|
||||
)
|
||||
|
||||
content = (response.choices[0].message.content or "").strip()
|
||||
try:
|
||||
json_match = re.search(r"\[.*\]", content, re.DOTALL)
|
||||
parsed = (
|
||||
json.loads(json_match.group())
|
||||
if json_match
|
||||
else json.loads(content or "[]")
|
||||
)
|
||||
suggestions = parsed if isinstance(parsed, list) else [parsed]
|
||||
return suggestions[:5]
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
Reference in New Issue
Block a user