fix(chat): use itinerary context for restaurant searches
This commit is contained in:
@@ -198,6 +198,38 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
"activities, or lodging."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _trip_context_search_location(destination, itinerary_stops):
|
||||
destination_text = (destination or "").strip()
|
||||
if destination_text:
|
||||
return destination_text
|
||||
|
||||
for stop in itinerary_stops or []:
|
||||
stop_text = (stop or "").strip()
|
||||
if stop_text:
|
||||
return stop_text
|
||||
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _infer_search_places_category(user_content, prior_user_messages):
|
||||
message_parts = [(user_content or "").strip()]
|
||||
message_parts.extend(
|
||||
(content or "").strip() for content in prior_user_messages or []
|
||||
)
|
||||
normalized = " ".join(part for part in message_parts if part).lower()
|
||||
if not normalized:
|
||||
return None
|
||||
|
||||
dining_intent_pattern = (
|
||||
r"\b(restaurant|restaurants|dining|dinner|lunch|breakfast|brunch|"
|
||||
r"cafe|cafes|food|eat|eating|cuisine|meal|meals|bistro|bar|bars)\b"
|
||||
)
|
||||
if re.search(dining_intent_pattern, normalized):
|
||||
return "food"
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _is_likely_location_reply(user_content):
|
||||
if not isinstance(user_content, str):
|
||||
@@ -274,6 +306,7 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
)
|
||||
|
||||
context_parts = []
|
||||
itinerary_stops = []
|
||||
if collection_name:
|
||||
context_parts.append(f"Trip: {collection_name}")
|
||||
if destination:
|
||||
@@ -300,7 +333,6 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
"Collection UUID (use this exact collection_id for get_trip_details and add_to_itinerary): "
|
||||
f"{collection.id}"
|
||||
)
|
||||
itinerary_stops = []
|
||||
seen_stops = set()
|
||||
for location in collection.locations.select_related(
|
||||
"city", "country"
|
||||
@@ -334,6 +366,15 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
if itinerary_stops:
|
||||
context_parts.append(f"Itinerary stops: {'; '.join(itinerary_stops)}")
|
||||
|
||||
trip_context_location = self._trip_context_search_location(
|
||||
destination, itinerary_stops
|
||||
)
|
||||
prior_user_messages = list(
|
||||
conversation.messages.filter(role="user")
|
||||
.order_by("-created_at")
|
||||
.values_list("content", flat=True)[:3]
|
||||
)
|
||||
|
||||
system_prompt = get_system_prompt(request.user, collection)
|
||||
if context_parts:
|
||||
system_prompt += "\n\n## Trip Context\n" + "\n".join(context_parts)
|
||||
@@ -425,42 +466,81 @@ class ChatViewSet(viewsets.ModelViewSet):
|
||||
if not isinstance(arguments, dict):
|
||||
arguments = {}
|
||||
|
||||
result = await sync_to_async(
|
||||
execute_tool, thread_sensitive=True
|
||||
)(
|
||||
function_name,
|
||||
request.user,
|
||||
**arguments,
|
||||
)
|
||||
|
||||
prepared_arguments = dict(arguments)
|
||||
tool_call_for_history = tool_call
|
||||
|
||||
if self._is_search_places_missing_location_required_error(
|
||||
function_name,
|
||||
result,
|
||||
) and self._is_likely_location_reply(user_content):
|
||||
retry_arguments = dict(arguments)
|
||||
retry_arguments["location"] = user_content
|
||||
retry_result = await sync_to_async(
|
||||
execute_tool,
|
||||
thread_sensitive=True,
|
||||
)(
|
||||
function_name,
|
||||
request.user,
|
||||
**retry_arguments,
|
||||
)
|
||||
if function_name == "search_places":
|
||||
if not (prepared_arguments.get("category") or "").strip():
|
||||
inferred_category = self._infer_search_places_category(
|
||||
user_content,
|
||||
prior_user_messages,
|
||||
)
|
||||
if inferred_category:
|
||||
prepared_arguments["category"] = inferred_category
|
||||
|
||||
if not self._is_required_param_tool_error(retry_result):
|
||||
result = retry_result
|
||||
if prepared_arguments != arguments:
|
||||
tool_call_for_history = {
|
||||
**tool_call,
|
||||
"function": {
|
||||
**function_payload,
|
||||
"name": function_name,
|
||||
"arguments": json.dumps(retry_arguments),
|
||||
"arguments": json.dumps(prepared_arguments),
|
||||
},
|
||||
}
|
||||
|
||||
result = await sync_to_async(
|
||||
execute_tool, thread_sensitive=True
|
||||
)(
|
||||
function_name,
|
||||
request.user,
|
||||
**prepared_arguments,
|
||||
)
|
||||
|
||||
if self._is_search_places_missing_location_required_error(
|
||||
function_name,
|
||||
result,
|
||||
):
|
||||
retry_locations = []
|
||||
if trip_context_location:
|
||||
retry_locations.append(trip_context_location)
|
||||
if self._is_likely_location_reply(user_content):
|
||||
retry_locations.append(user_content)
|
||||
|
||||
seen_retry_locations = set()
|
||||
for retry_location in retry_locations:
|
||||
normalized_retry_location = (
|
||||
retry_location.strip().lower()
|
||||
)
|
||||
if (
|
||||
not normalized_retry_location
|
||||
or normalized_retry_location in seen_retry_locations
|
||||
):
|
||||
continue
|
||||
seen_retry_locations.add(normalized_retry_location)
|
||||
|
||||
retry_arguments = dict(prepared_arguments)
|
||||
retry_arguments["location"] = retry_location
|
||||
retry_result = await sync_to_async(
|
||||
execute_tool,
|
||||
thread_sensitive=True,
|
||||
)(
|
||||
function_name,
|
||||
request.user,
|
||||
**retry_arguments,
|
||||
)
|
||||
|
||||
if not self._is_required_param_tool_error(retry_result):
|
||||
result = retry_result
|
||||
tool_call_for_history = {
|
||||
**tool_call,
|
||||
"function": {
|
||||
**function_payload,
|
||||
"name": function_name,
|
||||
"arguments": json.dumps(retry_arguments),
|
||||
},
|
||||
}
|
||||
break
|
||||
|
||||
if self._is_required_param_tool_error(result):
|
||||
assistant_message_kwargs = {
|
||||
"conversation": conversation,
|
||||
|
||||
Reference in New Issue
Block a user