Skip to content

Commit

Permalink
feat(api): add Token.user_override (and associated refactoring)
Browse files Browse the repository at this point in the history
Related: #886
  • Loading branch information
peterthomassen committed Dec 7, 2024
1 parent 66c0c1c commit 607e2f0
Show file tree
Hide file tree
Showing 9 changed files with 362 additions and 33 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# Generated by Django 5.1.3 on 2024-11-25 19:53

import django.db.models.deletion
import django.db.models.functions.comparison
import pgtrigger.compiler
import pgtrigger.migrations
from django.conf import settings
from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
("desecapi", "0041_remove_token_token_auto_policy_and_more"),
]

operations = [
migrations.RunSQL(
sql="ALTER TABLE desecapi_tokendomainpolicy DROP CONSTRAINT desecapi_tokendomainpolicy_token_id_token_user_id_fkey RESTRICT;",
reverse_sql="ALTER TABLE desecapi_tokendomainpolicy ADD FOREIGN KEY ( token_id, token_user_id ) REFERENCES desecapi_token ( id, user_id );",
),
migrations.RemoveConstraint(
model_name="token",
name="unique_id_user",
),
migrations.AlterField(
model_name="token",
name="user",
field=models.UUIDField(),
),
migrations.RenameField(
model_name="token",
old_name="user",
new_name="owner",
),
migrations.AlterField(
model_name="token",
name="owner",
field=models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="+",
to=settings.AUTH_USER_MODEL,
),
),
migrations.AddField(
model_name="token",
name="user_override",
field=models.ForeignKey(
null=True,
on_delete=django.db.models.deletion.CASCADE,
related_name="+",
to=settings.AUTH_USER_MODEL,
),
),
migrations.AddConstraint(
model_name="token",
constraint=models.CheckConstraint(
condition=models.Q(
("user_override__isnull", True),
models.Q(
("mfa__isnull", True),
("perm_manage_tokens", False),
models.Q(("user_override", models.F("owner")), _negated=True),
),
_connector="OR",
),
name="user_override_conditions",
),
),
migrations.AddField(
model_name="token",
name="user_id",
field=models.GeneratedField(
db_index=True,
db_persist=True,
expression=django.db.models.functions.comparison.Coalesce(
"user_override", "owner"
),
output_field=models.UUIDField(),
),
),
migrations.AddConstraint(
model_name="token",
constraint=models.UniqueConstraint(
fields=("id", "user_id"), name="unique_id_user"
),
),
migrations.RunSQL(
sql="ALTER TABLE desecapi_tokendomainpolicy ADD FOREIGN KEY ( token_id, token_user_id ) REFERENCES desecapi_token ( id, user_id ) DEFERRABLE INITIALLY DEFERRED;",
reverse_sql="ALTER TABLE desecapi_tokendomainpolicy DROP CONSTRAINT desecapi_tokendomainpolicy_token_id_token_user_id_fkey RESTRICT;",
),
migrations.AddField(
model_name="token",
name="user",
field=models.ForeignObject(
from_fields=["user_id"],
null=True,
on_delete=django.db.models.deletion.CASCADE,
related_name="token_set",
to=settings.AUTH_USER_MODEL,
to_fields=["id"],
),
),
pgtrigger.migrations.AddTrigger(
model_name="token",
trigger=pgtrigger.compiler.Trigger(
name="token_policy_user_id",
sql=pgtrigger.compiler.UpsertTriggerSql(
condition='WHEN (OLD."user_id" IS DISTINCT FROM (NEW."user_id"))',
func="\n IF\n OLD.user_override_id IS NOT NULL\n THEN\n RAISE EXCEPTION 'Cannot alter Token.user_override_id once set. (token.id=%)', NEW.id;\n END IF;\n UPDATE desecapi_tokendomainpolicy SET token_user_id = NEW.user_id WHERE token_id = NEW.id;\n RETURN NULL;\n ",
hash="a35df14f79206d73314376dd33ce217359b8f3dc",
operation='UPDATE OF "user_id"',
pgid="pgtrigger_token_policy_user_id_c0a5d",
table="desecapi_token",
when="AFTER",
),
),
),
pgtrigger.migrations.RemoveTrigger(
model_name="tokendomainpolicy",
name="token_user",
),
pgtrigger.migrations.AddTrigger(
model_name="tokendomainpolicy",
trigger=pgtrigger.compiler.Trigger(
name="token_user_insert",
sql=pgtrigger.compiler.UpsertTriggerSql(
func="NEW.token_user_id = (SELECT user_id FROM desecapi_token WHERE id = NEW.token_id); RETURN NEW;",
hash="37cd6136f62cfc0209565c771a1fc1b1e789ed4b",
operation="INSERT",
pgid="pgtrigger_token_user_insert_619b2",
table="desecapi_tokendomainpolicy",
when="BEFORE",
),
),
),
pgtrigger.migrations.AddTrigger(
model_name="tokendomainpolicy",
trigger=pgtrigger.compiler.Trigger(
name="token_user_update",
sql=pgtrigger.compiler.UpsertTriggerSql(
func="\n IF\n NEW.token_user_id != (SELECT user_id FROM desecapi_token WHERE id = NEW.token_id)\n THEN\n RAISE EXCEPTION 'Invalid token_user_id: %', NEW.token_user_id;\n END IF;\n RETURN NEW;\n ",
hash="f385209f06beb1f0d38376b6223721efd02baf73",
operation='UPDATE OF "token_user_id"',
pgid="pgtrigger_token_user_update_7ff54",
table="desecapi_tokendomainpolicy",
when="BEFORE",
),
),
),
]
67 changes: 63 additions & 4 deletions api/desecapi/models/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from django.core.exceptions import ValidationError
from django.db import models, transaction
from django.db.models import F, Q
from django.db.models.fields.related import ForeignObject
from django.db.models.functions import Coalesce
from django.utils import timezone
from django_prometheus.models import ExportModelOperationsMixin
from netfields import CidrAddressField, NetManager
Expand All @@ -37,7 +39,24 @@ def _allowed_subnets_default():

id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
key = models.CharField("Key", max_length=128, db_index=True, unique=True)
user = models.ForeignKey("User", on_delete=models.CASCADE)
owner = models.ForeignKey("User", on_delete=models.CASCADE, related_name="+")
user_id = models.GeneratedField(
expression=Coalesce("user_override", "owner"),
output_field=models.UUIDField(),
db_index=True,
db_persist=True,
)
user = ForeignObject(
"User",
on_delete=models.CASCADE,
from_fields=["user_id"],
to_fields=["id"],
null=True,
related_name="token_set",
)
user_override = models.ForeignKey(
"User", on_delete=models.CASCADE, null=True, related_name="+"
)
name = models.CharField("Name", blank=True, max_length=64)
last_used = models.DateTimeField(null=True, blank=True)
mfa = models.BooleanField(default=None, null=True)
Expand All @@ -59,7 +78,15 @@ def _allowed_subnets_default():

class Meta:
constraints = [
models.UniqueConstraint(fields=["id", "user"], name="unique_id_user")
models.UniqueConstraint(fields=["id", "user_id"], name="unique_id_user"),
models.CheckConstraint(
condition=Q(user_override__isnull=True)
| (
Q(perm_manage_tokens=False, mfa__isnull=True)
& ~Q(user_override=F("owner"))
),
name="user_override_conditions",
),
]
triggers = [
# Ensure that a default policy is defined when auto_policy=true
Expand All @@ -81,6 +108,25 @@ class Meta:
"""
),
),
# Update TokenDomainPolicy.token_user when .user_override changes (will fail FK if token
# has domain-specific policies). Also protect against further .user_override changes.
pgtrigger.Trigger(
name="token_policy_user_id",
operation=pgtrigger.UpdateOf("user_id"),
condition=pgtrigger.AnyChange("user_id"),
when=pgtrigger.After,
func=pgtrigger.Func(
"""
IF
OLD.user_override_id IS NOT NULL
THEN
RAISE EXCEPTION 'Cannot alter Token.user_override_id once set. (token.id=%)', NEW.id;
END IF;
UPDATE {meta.many_to_many[0].remote_field.through._meta.db_table} SET token_user_id = NEW.user_id WHERE token_id = NEW.id;
RETURN NULL;
"""
),
),
]

@property
Expand Down Expand Up @@ -206,11 +252,24 @@ class Meta:
triggers = [
# Ensure that token_user is consistent with token (to fulfill compound FK constraint, see migration)
pgtrigger.Trigger(
name="token_user",
operation=pgtrigger.Update | pgtrigger.Insert,
name="token_user_insert",
operation=pgtrigger.Insert,
when=pgtrigger.Before,
func="NEW.token_user_id = (SELECT user_id FROM desecapi_token WHERE id = NEW.token_id); RETURN NEW;",
),
pgtrigger.Trigger(
name="token_user_update",
operation=pgtrigger.UpdateOf("token_user_id"),
when=pgtrigger.Before,
func="""
IF
NEW.token_user_id != (SELECT user_id FROM desecapi_token WHERE id = NEW.token_id)
THEN
RAISE EXCEPTION 'Invalid token_user_id: %', NEW.token_user_id;
END IF;
RETURN NEW;
""",
),
# Ensure that if there is *any* domain policy for a given token, there is always one with domain=None.
pgtrigger.Trigger(
name="default_policy_primacy",
Expand Down
12 changes: 6 additions & 6 deletions api/desecapi/tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,12 +756,12 @@ def assertResponse(self, response, code=None, body=None):
if body:
self.assertJSONEqual(response.content, body)

def assertToken(self, plain, user=None):
user = user or self.owner
def assertToken(self, plain, owner=None):
owner = owner or self.owner
self.assertTrue(
any(
check_password(plain, hashed, preferred="pbkdf2_sha256_iter1")
for hashed in Token.objects.filter(user=user).values_list(
for hashed in Token.objects.filter(owner=owner).values_list(
"key", flat=True
)
)
Expand Down Expand Up @@ -1001,8 +1001,8 @@ def has_local_suffix(cls, domain_name: str):
)

@classmethod
def create_token(cls, user, **kwargs):
return Token.objects.create(user=user, **kwargs)
def create_token(cls, owner, **kwargs):
return Token.objects.create(owner=owner, **kwargs)

@classmethod
def create_user(cls, needs_captcha=False, **kwargs):
Expand Down Expand Up @@ -1351,7 +1351,7 @@ def setUpTestDataWithPdns(cls):
cls.create_rr_set(cls.other_domain, ["40.1.1.1"], type="A", ttl=456)

cls.token = cls.create_token(
user=cls.owner, perm_create_domain=True, perm_delete_domain=True
owner=cls.owner, perm_create_domain=True, perm_delete_domain=True
)

def setUp(self):
Expand Down
19 changes: 8 additions & 11 deletions api/desecapi/tests/test_token_domain_policy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from django.db import connection, transaction
from django.db.utils import IntegrityError
from django.db.utils import IntegrityError, ProgrammingError
from rest_framework import status
from rest_framework.test import APIClient

Expand Down Expand Up @@ -220,7 +220,7 @@ def test_policy_lifecycle(self):
self.assertEqual(response.data, [])

# Other token gives 404
other_token = self.create_token(user=self.create_user())
other_token = self.create_token(owner=self.create_user())
response = self.client.list_policies(other_token, using=self.token_manage)
self.assertStatus(response, status.HTTP_404_NOT_FOUND)

Expand Down Expand Up @@ -557,13 +557,9 @@ def test_token_user_consistency(self):
)
policy.save()

self.token.user = self.other_domains[0].owner
with self.assertRaises(IntegrityError):
with transaction.atomic(): # https://stackoverflow.com/a/23326971/6867099
self.token.save()

policy.delete()
self.token.save()
policy.token_user = self.other_domains[0].owner
with self.assertRaises(ProgrammingError):
policy.save()

def test_domain_owner_equals_token_user(self):
models.TokenDomainPolicy(
Expand All @@ -579,10 +575,11 @@ def test_domain_owner_equals_token_user(self):
type=None,
).save()

self.token.user = self.other_domain.owner
self.token.owner = self.other_domain.owner
with self.assertRaises(IntegrityError):
with transaction.atomic(): # https://stackoverflow.com/a/23326971/6867099
self.token.save()
connection.check_constraints() # simulate transaction commit

def test_domain_deletion_policy_cleanup(self):
domains = [None] + self.my_domains[:2]
Expand Down Expand Up @@ -630,7 +627,7 @@ def test_user_deletion(self):
# Only the default policy should be left, so get can simply get() it
policy_pk = self.token.tokendomainpolicy_set.get().pk

self.token.user.delete()
self.token.owner.delete()
self.assertFalse(models.TokenDomainPolicy.objects.filter(pk=policy_pk).exists())


Expand Down
Loading

0 comments on commit 607e2f0

Please sign in to comment.