fix(chat): use itinerary context for restaurant searches

This commit is contained in:
2026-03-10 17:12:29 +00:00
parent a023a9548c
commit 46d7704e4f
4 changed files with 231 additions and 30 deletions

View File

@@ -198,6 +198,38 @@ class ChatViewSet(viewsets.ModelViewSet):
"activities, or lodging."
)
@staticmethod
def _trip_context_search_location(destination, itinerary_stops):
destination_text = (destination or "").strip()
if destination_text:
return destination_text
for stop in itinerary_stops or []:
stop_text = (stop or "").strip()
if stop_text:
return stop_text
return ""
@staticmethod
def _infer_search_places_category(user_content, prior_user_messages):
message_parts = [(user_content or "").strip()]
message_parts.extend(
(content or "").strip() for content in prior_user_messages or []
)
normalized = " ".join(part for part in message_parts if part).lower()
if not normalized:
return None
dining_intent_pattern = (
r"\b(restaurant|restaurants|dining|dinner|lunch|breakfast|brunch|"
r"cafe|cafes|food|eat|eating|cuisine|meal|meals|bistro|bar|bars)\b"
)
if re.search(dining_intent_pattern, normalized):
return "food"
return None
@staticmethod
def _is_likely_location_reply(user_content):
if not isinstance(user_content, str):
@@ -274,6 +306,7 @@ class ChatViewSet(viewsets.ModelViewSet):
)
context_parts = []
itinerary_stops = []
if collection_name:
context_parts.append(f"Trip: {collection_name}")
if destination:
@@ -300,7 +333,6 @@ class ChatViewSet(viewsets.ModelViewSet):
"Collection UUID (use this exact collection_id for get_trip_details and add_to_itinerary): "
f"{collection.id}"
)
itinerary_stops = []
seen_stops = set()
for location in collection.locations.select_related(
"city", "country"
@@ -334,6 +366,15 @@ class ChatViewSet(viewsets.ModelViewSet):
if itinerary_stops:
context_parts.append(f"Itinerary stops: {'; '.join(itinerary_stops)}")
trip_context_location = self._trip_context_search_location(
destination, itinerary_stops
)
prior_user_messages = list(
conversation.messages.filter(role="user")
.order_by("-created_at")
.values_list("content", flat=True)[:3]
)
system_prompt = get_system_prompt(request.user, collection)
if context_parts:
system_prompt += "\n\n## Trip Context\n" + "\n".join(context_parts)
@@ -425,42 +466,81 @@ class ChatViewSet(viewsets.ModelViewSet):
if not isinstance(arguments, dict):
arguments = {}
result = await sync_to_async(
execute_tool, thread_sensitive=True
)(
function_name,
request.user,
**arguments,
)
prepared_arguments = dict(arguments)
tool_call_for_history = tool_call
if self._is_search_places_missing_location_required_error(
function_name,
result,
) and self._is_likely_location_reply(user_content):
retry_arguments = dict(arguments)
retry_arguments["location"] = user_content
retry_result = await sync_to_async(
execute_tool,
thread_sensitive=True,
)(
function_name,
request.user,
**retry_arguments,
)
if function_name == "search_places":
if not (prepared_arguments.get("category") or "").strip():
inferred_category = self._infer_search_places_category(
user_content,
prior_user_messages,
)
if inferred_category:
prepared_arguments["category"] = inferred_category
if not self._is_required_param_tool_error(retry_result):
result = retry_result
if prepared_arguments != arguments:
tool_call_for_history = {
**tool_call,
"function": {
**function_payload,
"name": function_name,
"arguments": json.dumps(retry_arguments),
"arguments": json.dumps(prepared_arguments),
},
}
result = await sync_to_async(
execute_tool, thread_sensitive=True
)(
function_name,
request.user,
**prepared_arguments,
)
if self._is_search_places_missing_location_required_error(
function_name,
result,
):
retry_locations = []
if trip_context_location:
retry_locations.append(trip_context_location)
if self._is_likely_location_reply(user_content):
retry_locations.append(user_content)
seen_retry_locations = set()
for retry_location in retry_locations:
normalized_retry_location = (
retry_location.strip().lower()
)
if (
not normalized_retry_location
or normalized_retry_location in seen_retry_locations
):
continue
seen_retry_locations.add(normalized_retry_location)
retry_arguments = dict(prepared_arguments)
retry_arguments["location"] = retry_location
retry_result = await sync_to_async(
execute_tool,
thread_sensitive=True,
)(
function_name,
request.user,
**retry_arguments,
)
if not self._is_required_param_tool_error(retry_result):
result = retry_result
tool_call_for_history = {
**tool_call,
"function": {
**function_payload,
"name": function_name,
"arguments": json.dumps(retry_arguments),
},
}
break
if self._is_required_param_tool_error(result):
assistant_message_kwargs = {
"conversation": conversation,