fix: stabilize post-MVP travel-agent and itinerary workflows
This commit is contained in:
@@ -27,6 +27,7 @@ EMAIL_BACKEND='console'
|
||||
|
||||
# GOOGLE_MAPS_API_KEY='key'
|
||||
# OSRM_BASE_URL='https://router.project-osrm.org' # replace with self-host URL if needed (e.g. http://osrm:5000)
|
||||
# DJANGO_MCP_ENDPOINT='api/mcp' # optional custom MCP HTTP endpoint path
|
||||
|
||||
# ACCOUNT_EMAIL_VERIFICATION='none' # 'none', 'optional', 'mandatory' # You can change this as needed for your environment
|
||||
|
||||
|
||||
@@ -0,0 +1,104 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from django.core.management.base import BaseCommand, CommandError
|
||||
from django.core.serializers.json import DjangoJSONEncoder
|
||||
from django.utils import timezone
|
||||
|
||||
from adventures.models import Collection, CollectionItineraryItem
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
help = (
|
||||
"Export Collection and CollectionItineraryItem data to a JSON backup "
|
||||
"file before upgrades/migrations."
|
||||
)
|
||||
|
||||
def add_arguments(self, parser):
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
help="Optional output file path (default: ./collections_backup_<timestamp>.json)",
|
||||
)
|
||||
|
||||
def handle(self, *args, **options):
|
||||
backup_timestamp = timezone.now()
|
||||
timestamp = backup_timestamp.strftime("%Y%m%d_%H%M%S")
|
||||
output_path = Path(
|
||||
options.get("output") or f"collections_backup_{timestamp}.json"
|
||||
)
|
||||
|
||||
if output_path.parent and not output_path.parent.exists():
|
||||
raise CommandError(f"Output directory does not exist: {output_path.parent}")
|
||||
|
||||
collections = list(
|
||||
Collection.objects.values(
|
||||
"id",
|
||||
"user_id",
|
||||
"name",
|
||||
"description",
|
||||
"is_public",
|
||||
"is_archived",
|
||||
"start_date",
|
||||
"end_date",
|
||||
"link",
|
||||
"primary_image_id",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
)
|
||||
)
|
||||
|
||||
shared_with_map = {
|
||||
str(collection.id): list(
|
||||
collection.shared_with.values_list("id", flat=True)
|
||||
)
|
||||
for collection in Collection.objects.prefetch_related("shared_with")
|
||||
}
|
||||
for collection in collections:
|
||||
collection["shared_with_ids"] = shared_with_map.get(
|
||||
str(collection["id"]), []
|
||||
)
|
||||
|
||||
itinerary_items = list(
|
||||
CollectionItineraryItem.objects.select_related("content_type").values(
|
||||
"id",
|
||||
"collection_id",
|
||||
"content_type_id",
|
||||
"content_type__app_label",
|
||||
"content_type__model",
|
||||
"object_id",
|
||||
"date",
|
||||
"is_global",
|
||||
"order",
|
||||
"created_at",
|
||||
)
|
||||
)
|
||||
|
||||
backup_payload = {
|
||||
"backup_type": "collections_snapshot",
|
||||
"timestamp": backup_timestamp.isoformat(),
|
||||
"counts": {
|
||||
"collections": len(collections),
|
||||
"collection_itinerary_items": len(itinerary_items),
|
||||
},
|
||||
"collections": collections,
|
||||
"collection_itinerary_items": itinerary_items,
|
||||
}
|
||||
|
||||
try:
|
||||
with output_path.open("w", encoding="utf-8") as backup_file:
|
||||
json.dump(backup_payload, backup_file, indent=2, cls=DjangoJSONEncoder)
|
||||
except OSError as exc:
|
||||
raise CommandError(f"Failed to write backup file: {exc}") from exc
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise CommandError(f"Failed to serialize backup data: {exc}") from exc
|
||||
|
||||
self.stdout.write(
|
||||
self.style.SUCCESS(
|
||||
"Exported collections backup to "
|
||||
f"{output_path} "
|
||||
f"at {backup_timestamp.isoformat()} "
|
||||
f"(collections: {len(collections)}, "
|
||||
f"itinerary_items: {len(itinerary_items)})."
|
||||
)
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,16 +1,34 @@
|
||||
import json
|
||||
import tempfile
|
||||
import base64
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.core.cache import cache
|
||||
from django.core.files.uploadedfile import SimpleUploadedFile
|
||||
from django.core.management import call_command
|
||||
from django.core.management.base import CommandError
|
||||
from django.test import TestCase
|
||||
from django.utils import timezone
|
||||
from rest_framework.test import APIClient, APITestCase
|
||||
|
||||
from adventures.models import (
|
||||
Collection,
|
||||
CollectionItineraryItem,
|
||||
ContentImage,
|
||||
Lodging,
|
||||
Note,
|
||||
Transportation,
|
||||
)
|
||||
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
|
||||
class WeatherEndpointTests(APITestCase):
|
||||
class WeatherViewTests(APITestCase):
|
||||
def setUp(self):
|
||||
self.user = User.objects.create_user(
|
||||
username="weather-user",
|
||||
@@ -35,11 +53,38 @@ class WeatherEndpointTests(APITestCase):
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertIn("maximum", response.json().get("error", "").lower())
|
||||
|
||||
@patch("adventures.views.weather_view.requests.get")
|
||||
def test_daily_temperatures_future_date_returns_unavailable_without_external_call(
|
||||
self, mock_requests_get
|
||||
@patch("adventures.views.weather_view.WeatherViewSet._fetch_daily_temperature")
|
||||
def test_daily_temperatures_future_date_reaches_fetch_path(
|
||||
self, mock_fetch_temperature
|
||||
):
|
||||
future_date = (timezone.now().date() + timedelta(days=10)).isoformat()
|
||||
mock_fetch_temperature.return_value = {
|
||||
"date": future_date,
|
||||
"available": True,
|
||||
"temperature_c": 22.5,
|
||||
}
|
||||
|
||||
response = self.client.post(
|
||||
"/api/weather/daily-temperatures/",
|
||||
{"days": [{"date": future_date, "latitude": 12.34, "longitude": 56.78}]},
|
||||
format="json",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response.json()["results"][0]["date"], future_date)
|
||||
self.assertTrue(response.json()["results"][0]["available"])
|
||||
self.assertEqual(response.json()["results"][0]["temperature_c"], 22.5)
|
||||
mock_fetch_temperature.assert_called_once_with(future_date, 12.34, 56.78)
|
||||
|
||||
@patch("adventures.views.weather_view.requests.get")
|
||||
def test_daily_temperatures_far_future_returns_unavailable_when_upstream_has_no_data(
|
||||
self, mock_requests_get
|
||||
):
|
||||
future_date = (timezone.now().date() + timedelta(days=3650)).isoformat()
|
||||
mocked_response = Mock()
|
||||
mocked_response.raise_for_status.return_value = None
|
||||
mocked_response.json.return_value = {"daily": {}}
|
||||
mock_requests_get.return_value = mocked_response
|
||||
|
||||
response = self.client.post(
|
||||
"/api/weather/daily-temperatures/",
|
||||
@@ -52,7 +97,7 @@ class WeatherEndpointTests(APITestCase):
|
||||
response.json()["results"][0],
|
||||
{"date": future_date, "available": False, "temperature_c": None},
|
||||
)
|
||||
mock_requests_get.assert_not_called()
|
||||
self.assertEqual(mock_requests_get.call_count, 2)
|
||||
|
||||
@patch("adventures.views.weather_view.requests.get")
|
||||
def test_daily_temperatures_accepts_zero_lat_lon(self, mock_requests_get):
|
||||
@@ -106,3 +151,166 @@ class MCPAuthTests(APITestCase):
|
||||
unauthenticated_client = APIClient()
|
||||
response = unauthenticated_client.post("/api/mcp", {}, format="json")
|
||||
self.assertIn(response.status_code, [401, 403])
|
||||
|
||||
|
||||
class CollectionViewSetTests(APITestCase):
|
||||
def setUp(self):
|
||||
self.owner = User.objects.create_user(
|
||||
username="collection-owner",
|
||||
email="owner@example.com",
|
||||
password="password123",
|
||||
)
|
||||
self.shared_user = User.objects.create_user(
|
||||
username="collection-shared",
|
||||
email="shared@example.com",
|
||||
password="password123",
|
||||
)
|
||||
|
||||
def _create_test_image_file(self, name="test.png"):
|
||||
# 1x1 PNG
|
||||
png_bytes = base64.b64decode(
|
||||
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO7Y9x8AAAAASUVORK5CYII="
|
||||
)
|
||||
return SimpleUploadedFile(name, png_bytes, content_type="image/png")
|
||||
|
||||
def _create_collection_with_non_location_images(self):
|
||||
collection = Collection.objects.create(
|
||||
user=self.owner,
|
||||
name="Image fallback collection",
|
||||
)
|
||||
|
||||
lodging = Lodging.objects.create(
|
||||
user=self.owner,
|
||||
collection=collection,
|
||||
name="Fallback lodge",
|
||||
)
|
||||
transportation = Transportation.objects.create(
|
||||
user=self.owner,
|
||||
collection=collection,
|
||||
type="car",
|
||||
name="Fallback ride",
|
||||
)
|
||||
|
||||
lodging_content_type = ContentType.objects.get_for_model(Lodging)
|
||||
transportation_content_type = ContentType.objects.get_for_model(Transportation)
|
||||
|
||||
ContentImage.objects.create(
|
||||
user=self.owner,
|
||||
content_type=lodging_content_type,
|
||||
object_id=lodging.id,
|
||||
image=self._create_test_image_file("lodging.png"),
|
||||
is_primary=True,
|
||||
)
|
||||
ContentImage.objects.create(
|
||||
user=self.owner,
|
||||
content_type=transportation_content_type,
|
||||
object_id=transportation.id,
|
||||
image=self._create_test_image_file("transport.png"),
|
||||
is_primary=True,
|
||||
)
|
||||
|
||||
return collection
|
||||
|
||||
def test_list_includes_lodging_transportation_images_when_no_location_images(self):
|
||||
collection = self._create_collection_with_non_location_images()
|
||||
|
||||
self.client.force_authenticate(user=self.owner)
|
||||
response = self.client.get("/api/collections/")
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertGreater(len(response.data.get("results", [])), 0)
|
||||
|
||||
collection_payload = next(
|
||||
item
|
||||
for item in response.data["results"]
|
||||
if item["id"] == str(collection.id)
|
||||
)
|
||||
self.assertIn("location_images", collection_payload)
|
||||
self.assertGreater(len(collection_payload["location_images"]), 0)
|
||||
self.assertTrue(
|
||||
any(
|
||||
image.get("is_primary")
|
||||
for image in collection_payload["location_images"]
|
||||
)
|
||||
)
|
||||
|
||||
def test_shared_endpoint_includes_non_location_primary_images(self):
|
||||
collection = self._create_collection_with_non_location_images()
|
||||
collection.shared_with.add(self.shared_user)
|
||||
|
||||
self.client.force_authenticate(user=self.shared_user)
|
||||
response = self.client.get("/api/collections/shared/")
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertGreater(len(response.data), 0)
|
||||
|
||||
collection_payload = next(
|
||||
item for item in response.data if item["id"] == str(collection.id)
|
||||
)
|
||||
self.assertEqual(str(collection.id), collection_payload["id"])
|
||||
self.assertIn("location_images", collection_payload)
|
||||
self.assertGreater(len(collection_payload["location_images"]), 0)
|
||||
first_image = collection_payload["location_images"][0]
|
||||
self.assertSetEqual(
|
||||
set(first_image.keys()),
|
||||
{"id", "image", "is_primary", "user", "immich_id"},
|
||||
)
|
||||
|
||||
|
||||
class ExportCollectionsBackupCommandTests(TestCase):
|
||||
def setUp(self):
|
||||
self.user = User.objects.create_user(
|
||||
username="backup-user",
|
||||
email="backup@example.com",
|
||||
password="password123",
|
||||
)
|
||||
self.collaborator = User.objects.create_user(
|
||||
username="collab-user",
|
||||
email="collab@example.com",
|
||||
password="password123",
|
||||
)
|
||||
self.collection = Collection.objects.create(
|
||||
user=self.user,
|
||||
name="My Trip",
|
||||
description="Backup test collection",
|
||||
)
|
||||
self.collection.shared_with.add(self.collaborator)
|
||||
|
||||
note = Note.objects.create(user=self.user, name="Test item")
|
||||
note_content_type = ContentType.objects.get_for_model(Note)
|
||||
CollectionItineraryItem.objects.create(
|
||||
collection=self.collection,
|
||||
content_type=note_content_type,
|
||||
object_id=note.id,
|
||||
date=timezone.now().date(),
|
||||
is_global=False,
|
||||
order=1,
|
||||
)
|
||||
|
||||
def test_export_collections_backup_writes_expected_payload(self):
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
output_file = Path(temp_dir) / "collections_snapshot.json"
|
||||
|
||||
call_command("export_collections_backup", output=str(output_file))
|
||||
|
||||
self.assertTrue(output_file.exists())
|
||||
payload = json.loads(output_file.read_text(encoding="utf-8"))
|
||||
|
||||
self.assertEqual(payload["backup_type"], "collections_snapshot")
|
||||
self.assertIn("timestamp", payload)
|
||||
self.assertEqual(payload["counts"]["collections"], 1)
|
||||
self.assertEqual(payload["counts"]["collection_itinerary_items"], 1)
|
||||
self.assertEqual(len(payload["collections"]), 1)
|
||||
self.assertEqual(len(payload["collection_itinerary_items"]), 1)
|
||||
self.assertEqual(
|
||||
payload["collections"][0]["shared_with_ids"],
|
||||
[self.collaborator.id],
|
||||
)
|
||||
|
||||
def test_export_collections_backup_raises_for_missing_output_directory(self):
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
missing_directory = Path(temp_dir) / "missing"
|
||||
output_file = missing_directory / "collections_snapshot.json"
|
||||
|
||||
with self.assertRaises(CommandError):
|
||||
call_command("export_collections_backup", output=str(output_file))
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -60,12 +60,6 @@ class WeatherViewSet(viewsets.ViewSet):
|
||||
)
|
||||
continue
|
||||
|
||||
if parsed_date > date_cls.today():
|
||||
results.append(
|
||||
{"date": date, "available": False, "temperature_c": None}
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
lat = float(latitude)
|
||||
lon = float(longitude)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from django.db import IntegrityError
|
||||
|
||||
from .models import (
|
||||
EncryptionConfigurationError,
|
||||
ImmichIntegration,
|
||||
@@ -41,12 +43,28 @@ class UserAPIKeySerializer(serializers.ModelSerializer):
|
||||
def create(self, validated_data):
|
||||
api_key = validated_data.pop("api_key")
|
||||
user = self.context["request"].user
|
||||
instance = UserAPIKey(user=user, **validated_data)
|
||||
|
||||
provider = validated_data.get("provider")
|
||||
|
||||
try:
|
||||
instance, _ = UserAPIKey.objects.get_or_create(
|
||||
user=user,
|
||||
provider=provider,
|
||||
defaults={"encrypted_api_key": ""},
|
||||
)
|
||||
instance.set_api_key(api_key)
|
||||
except EncryptionConfigurationError as exc:
|
||||
raise serializers.ValidationError({"api_key": str(exc)}) from exc
|
||||
instance.save()
|
||||
except IntegrityError:
|
||||
# Defensive retry: in highly concurrent requests a competing create can
|
||||
# still race. Fall back to updating the existing row instead of 500.
|
||||
instance = UserAPIKey.objects.get(user=user, provider=provider)
|
||||
try:
|
||||
instance.set_api_key(api_key)
|
||||
except EncryptionConfigurationError as exc:
|
||||
raise serializers.ValidationError({"api_key": str(exc)}) from exc
|
||||
|
||||
instance.save(update_fields=["encrypted_api_key", "updated_at"])
|
||||
return instance
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
|
||||
@@ -51,3 +51,65 @@ class UserAPIKeyConfigurationTests(APITestCase):
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertIn("not configured", response.json().get("error", "").lower())
|
||||
mock_requests_get.assert_not_called()
|
||||
|
||||
|
||||
class UserAPIKeyCreateBehaviorTests(APITestCase):
|
||||
@override_settings(
|
||||
FIELD_ENCRYPTION_KEY="YWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWE="
|
||||
)
|
||||
def setUp(self):
|
||||
self.user = User.objects.create_user(
|
||||
username="api-key-create-user",
|
||||
email="apikey-create@example.com",
|
||||
password="password123",
|
||||
)
|
||||
self.client.force_authenticate(user=self.user)
|
||||
|
||||
@override_settings(
|
||||
FIELD_ENCRYPTION_KEY="YWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWE="
|
||||
)
|
||||
def test_duplicate_provider_post_updates_existing_key(self):
|
||||
first_response = self.client.post(
|
||||
"/api/integrations/api-keys/",
|
||||
{"provider": "google_maps", "api_key": "first-secret"},
|
||||
format="json",
|
||||
)
|
||||
self.assertEqual(first_response.status_code, 201)
|
||||
|
||||
second_response = self.client.post(
|
||||
"/api/integrations/api-keys/",
|
||||
{"provider": "google_maps", "api_key": "second-secret"},
|
||||
format="json",
|
||||
)
|
||||
|
||||
self.assertEqual(second_response.status_code, 201)
|
||||
|
||||
from integrations.models import UserAPIKey
|
||||
|
||||
records = UserAPIKey.objects.filter(user=self.user, provider="google_maps")
|
||||
self.assertEqual(records.count(), 1)
|
||||
self.assertEqual(records.first().get_api_key(), "second-secret")
|
||||
|
||||
@override_settings(
|
||||
FIELD_ENCRYPTION_KEY="YWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWE="
|
||||
)
|
||||
def test_provider_is_normalized_and_still_upserts(self):
|
||||
self.client.post(
|
||||
"/api/integrations/api-keys/",
|
||||
{"provider": "Google_Maps", "api_key": "first-secret"},
|
||||
format="json",
|
||||
)
|
||||
|
||||
response = self.client.post(
|
||||
"/api/integrations/api-keys/",
|
||||
{"provider": " google_maps ", "api_key": "rotated-secret"},
|
||||
format="json",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 201)
|
||||
|
||||
from integrations.models import UserAPIKey
|
||||
|
||||
records = UserAPIKey.objects.filter(user=self.user, provider="google_maps")
|
||||
self.assertEqual(records.count(), 1)
|
||||
self.assertEqual(records.first().get_api_key(), "rotated-secret")
|
||||
|
||||
38
backend/server/main/tests.py
Normal file
38
backend/server/main/tests.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from django.contrib.auth import get_user_model
|
||||
from rest_framework.authtoken.models import Token
|
||||
from rest_framework.test import APIClient, APITestCase
|
||||
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
|
||||
class MCPTokenEndpointTests(APITestCase):
|
||||
def setUp(self):
|
||||
self.user = User.objects.create_user(
|
||||
username="mcp-token-user",
|
||||
email="mcp-token@example.com",
|
||||
password="password123",
|
||||
)
|
||||
|
||||
def test_requires_authentication(self):
|
||||
unauthenticated_client = APIClient()
|
||||
response = unauthenticated_client.get("/auth/mcp-token/")
|
||||
self.assertIn(response.status_code, [401, 403])
|
||||
|
||||
def test_returns_token_for_authenticated_user(self):
|
||||
self.client.force_authenticate(user=self.user)
|
||||
response = self.client.get("/auth/mcp-token/")
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertIn("token", response.json())
|
||||
self.assertTrue(Token.objects.filter(user=self.user).exists())
|
||||
|
||||
def test_reuses_existing_token(self):
|
||||
existing_token = Token.objects.create(user=self.user)
|
||||
|
||||
self.client.force_authenticate(user=self.user)
|
||||
response = self.client.get("/auth/mcp-token/")
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response.json().get("token"), existing_token.key)
|
||||
self.assertEqual(Token.objects.filter(user=self.user).count(), 1)
|
||||
@@ -10,7 +10,12 @@ from users.views import (
|
||||
EnabledSocialProvidersView,
|
||||
DisablePasswordAuthenticationView,
|
||||
)
|
||||
from .views import get_csrf_token, get_public_url, serve_protected_media
|
||||
from .views import (
|
||||
get_csrf_token,
|
||||
get_mcp_api_token,
|
||||
get_public_url,
|
||||
serve_protected_media,
|
||||
)
|
||||
from drf_yasg.views import get_schema_view
|
||||
from drf_yasg import openapi
|
||||
from mcp_server.views import MCPServerStreamableHttpView
|
||||
@@ -48,6 +53,7 @@ urlpatterns = [
|
||||
),
|
||||
name="mcp_server_streamable_http_endpoint",
|
||||
),
|
||||
path("auth/mcp-token/", get_mcp_api_token, name="get_mcp_api_token"),
|
||||
path("auth/", include("allauth.headless.urls")),
|
||||
# Serve protected media files
|
||||
re_path(
|
||||
|
||||
@@ -5,21 +5,42 @@ from django.conf import settings
|
||||
from django.http import HttpResponse, HttpResponseForbidden
|
||||
from django.views.static import serve
|
||||
from adventures.utils.file_permissions import checkFilePermission
|
||||
from rest_framework.authentication import SessionAuthentication, TokenAuthentication
|
||||
from rest_framework.authtoken.models import Token
|
||||
from rest_framework.decorators import (
|
||||
api_view,
|
||||
authentication_classes,
|
||||
permission_classes,
|
||||
)
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
|
||||
|
||||
def get_csrf_token(request):
|
||||
csrf_token = get_token(request)
|
||||
return JsonResponse({'csrfToken': csrf_token})
|
||||
return JsonResponse({"csrfToken": csrf_token})
|
||||
|
||||
|
||||
def get_public_url(request):
|
||||
return JsonResponse({'PUBLIC_URL': getenv('PUBLIC_URL')})
|
||||
return JsonResponse({"PUBLIC_URL": getenv("PUBLIC_URL")})
|
||||
|
||||
|
||||
@api_view(["GET"])
|
||||
@authentication_classes([SessionAuthentication, TokenAuthentication])
|
||||
@permission_classes([IsAuthenticated])
|
||||
def get_mcp_api_token(request):
|
||||
token, _ = Token.objects.get_or_create(user=request.user)
|
||||
return Response({"token": token.key})
|
||||
|
||||
|
||||
protected_paths = ["images/", "attachments/"]
|
||||
|
||||
protected_paths = ['images/', 'attachments/']
|
||||
|
||||
def serve_protected_media(request, path):
|
||||
if any([path.startswith(protected_path) for protected_path in protected_paths]):
|
||||
image_id = path.split('/')[1]
|
||||
image_id = path.split("/")[1]
|
||||
user = request.user
|
||||
media_type = path.split('/')[0] + '/'
|
||||
media_type = path.split("/")[0] + "/"
|
||||
if checkFilePermission(image_id, user, media_type):
|
||||
if settings.DEBUG:
|
||||
# In debug mode, serve the file directly
|
||||
@@ -27,8 +48,8 @@ def serve_protected_media(request, path):
|
||||
else:
|
||||
# In production, use X-Accel-Redirect to serve the file using Nginx
|
||||
response = HttpResponse()
|
||||
response['Content-Type'] = ''
|
||||
response['X-Accel-Redirect'] = '/protectedMedia/' + path
|
||||
response["Content-Type"] = ""
|
||||
response["X-Accel-Redirect"] = "/protectedMedia/" + path
|
||||
return response
|
||||
else:
|
||||
return HttpResponseForbidden()
|
||||
@@ -37,6 +58,6 @@ def serve_protected_media(request, path):
|
||||
return serve(request, path, document_root=settings.MEDIA_ROOT)
|
||||
else:
|
||||
response = HttpResponse()
|
||||
response['Content-Type'] = ''
|
||||
response['X-Accel-Redirect'] = '/protectedMedia/' + path
|
||||
return response
|
||||
response["Content-Type"] = ""
|
||||
response["X-Accel-Redirect"] = "/protectedMedia/" + path
|
||||
return response
|
||||
|
||||
Reference in New Issue
Block a user