import contextlib
import copy
import hashlib
import inspect
import logging
import threading
import django.apps
from django.conf import settings
from django.db import connections, DEFAULT_DB_ALIAS, models, NotSupportedError, router, transaction
from django.db.models.expressions import Col
from django.db.models.fields.related import RelatedField
from django.db.models.sql import Query
from django.db.models.sql.datastructures import BaseTable
from django.db.utils import ProgrammingError
# Logger for install / uninstall / enable / disable / prune progress messages
LOGGER = logging.getLogger('pgtrigger')
# Sentinel meaning "argument not supplied", for parameters where None is a
# meaningful value (e.g. SoftDelete's "value", which renders None as NULL)
_unset = object()
# Postgres only allows identifiers to be 63 chars max. Since "pgtrigger_"
# is the prefix for trigger names, and since an additional "_" and
# 5 character hash is added, the user-defined name of the trigger can only
# be 47 chars (63 - 10 - 1 - 5).
# NOTE: We can do something more sophisticated later by allowing users
# to name their triggers and then hashing the names when actually creating
# the triggers.
MAX_NAME_LENGTH = 47
# Installation states for a trigger. Trigger.get_installation_status returns
# INSTALLED, UNINSTALLED or OUTDATED; PRUNE is not returned by it.
INSTALLED = 'INSTALLED'
UNINSTALLED = 'UNINSTALLED'
OUTDATED = 'OUTDATED'
PRUNE = 'PRUNE'
# All registered triggers, keyed by trigger URI
# ("app_label.ModelName:trigger_name") and mapping to (model, trigger) tuples
registry = {}
# Thread-local state for dynamically ignored triggers. Its "value" attribute
# (created lazily) holds the set of "table:pgid" URIs ignored in this thread.
_ignore = threading.local()
def _quote(label):
    """Wrap ``label`` in double quotes unless it already starts with one.

    Labels that already begin with a quote (e.g. ``'schema"."table'`` style
    table names) are returned untouched.
    """
    return label if label.startswith('"') else f'"{label}"'
def _get_database(model):
    """
    Return the database alias used for a trigger / model pair.

    The alias is whatever the configured database routers choose as the
    write database for ``model``, falling back to ``DEFAULT_DB_ALIAS`` when
    the routers return nothing.
    """
    alias = router.db_for_write(model)
    return alias if alias else DEFAULT_DB_ALIAS
def _postgres_databases(databases):
    """Filter an iterable of database aliases down to the PostgreSQL ones.

    Order of the input is preserved.
    """
    postgres_aliases = []
    for alias in databases:
        if connections[alias].vendor == 'postgresql':
            postgres_aliases.append(alias)
    return postgres_aliases
def _get_connection(model):
    """
    Return the Django connection used for a trigger / model pair.

    The connection belongs to the write database chosen by the database
    router config for ``model``.
    """
    alias = _get_database(model)
    return connections[alias]
def _get_model(table):
    """Find the installed Django model whose table is ``table``.

    Table names are compared after quoting both sides, so ``foo`` and
    ``"foo"`` match. Returns ``None`` when no installed model uses the table.
    """
    wanted = _quote(table)
    candidates = (
        candidate
        for candidate in django.apps.apps.get_models()
        if _quote(candidate._meta.db_table) == wanted
    )
    return next(candidates, None)
def _is_concurrent_statement(sql):
    """
    Return True for ``CREATE ... CONCURRENTLY`` statements, which Postgres
    refuses to run inside a transaction block.

    Empty or ``None`` SQL is never concurrent.
    """
    if not sql:
        return False
    normalized = sql.strip().lower()
    return normalized.startswith('create') and 'concurrently' in normalized
def _inject_pgtrigger_ignore(execute, sql, params, many, context):  # pragma: no cover
    """
    Connection execution wrapper that prefixes each statement with a
    ``SET LOCAL pgtrigger.ignore`` assignment.

    The value is a Postgres array literal of the URIs in ``_ignore.value``.
    The installed ``_pgtrigger_should_ignore`` function reads this setting,
    so triggers listed in it skip their body.
    """
    # Named cursors get "NO SCROLL CURSOR WITHOUT HOLD FOR" prepended by
    # psycopg2, which would turn the combined statement into invalid SQL, and
    # this cannot be overridden. Django only names cursors for iterators and
    # other read-only statements, so skipping them is considered safe.
    #
    # Concurrent index creation cannot be combined with setting a local
    # variable either, so those statements are left alone as well.
    skip_injection = context['cursor'].name or _is_concurrent_statement(sql)
    if not skip_injection:
        ignored_uris = ",".join(_ignore.value)
        sql = "SET LOCAL pgtrigger.ignore='{" + ignored_uris + "}';" + sql
    return execute(sql, params, many, context)
def register(*triggers):
    """
    Register the given triggers with the wrapped Model class.

    Args:
        *triggers (`pgtrigger.Trigger`): Trigger objects to register.

    Returns:
        callable: A decorator that calls ``trigger.register(model)`` for each
        trigger, in order, and returns the model class unchanged.

    Examples:
        Register by decorating a model::

            @pgtrigger.register(
                pgtrigger.Protect(
                    name="append_only",
                    operation=(pgtrigger.Update | pgtrigger.Delete)
                )
            )
            class MyModel(models.Model):
                pass

        Register by calling functionally::

            pgtrigger.register(trigger_object)(MyModel)
    """

    def _attach_triggers(model_cls):
        for pending in triggers:
            pending.register(model_cls)
        return model_cls

    return _attach_triggers
class _Serializable:
    """
    Mixin that makes objects deconstructible for Django migrations.

    Initialization arguments are recovered by introspecting the signature of
    ``__init__`` and reading the instance attribute of the same name for each
    parameter.
    """

    def get_init_vals(self):
        """Returns class initialization args so that they are properly serialized for migrations

        Returns:
            tuple: ``(args, kwargs)``. ``args`` holds the items of a ``*args``
            style parameter (if any). ``kwargs`` maps every other parameter
            name to the attribute of the same name.

        Raises:
            ValueError: If a parameter has no matching attribute, or if
                ``__init__`` accepts ``**kwargs`` (which cannot be recovered).
        """
        parameters = inspect.signature(self.__init__).parameters
        for key, param in parameters.items():
            if key != "self" and (
                not hasattr(self, key) or param.kind == inspect.Parameter.VAR_KEYWORD
            ):  # pragma: no cover
                raise ValueError(
                    f"Could not automatically serialize Trigger {self.__class__} for migrations."
                    ' Implement "get_init_vals()" on the trigger class. See the'
                    ' FAQ in the django-pgtrigger docs for more information.'
                )

        args = tuple(
            item
            for key, param in parameters.items()
            if param.kind == inspect.Parameter.VAR_POSITIONAL
            for item in getattr(self, key)
        )
        # Filter on each parameter's own kind, so a ``*args`` parameter never
        # also shows up as a keyword argument regardless of parameter order.
        kwargs = {
            key: getattr(self, key)
            for key, param in parameters.items()
            if key != "self" and param.kind != inspect.Parameter.VAR_POSITIONAL
        }
        return args, kwargs

    def deconstruct(self):
        """For supporting Django migrations

        Returns:
            tuple: ``(path, args, kwargs)``. The import path is rewritten from
            ``pgtrigger.core`` to the public ``pgtrigger`` package.
        """
        path = f"{self.__class__.__module__}.{self.__class__.__name__}"
        path = path.replace("pgtrigger.core", "pgtrigger")
        args, kwargs = self.get_init_vals()
        return path, args, kwargs

    def __eq__(self, other):
        # Objects of foreign types (e.g. Django's own constraints sitting next
        # to triggers in Meta.constraints) have no get_init_vals(); defer to
        # Python's default handling instead of raising AttributeError.
        if not isinstance(other, _Serializable):
            return NotImplemented
        return self.get_init_vals() == other.get_init_vals()
class _Primitive(_Serializable):
    """Boilerplate for some of the primitive operations

    Subclasses declare ``values``: the tuple of SQL keywords accepted as
    ``name``. The primitive renders as its ``name``.
    """

    def __init__(self, name):
        # Only keywords listed by the subclass are accepted
        assert name in self.values
        self.name = name

    def __str__(self):
        rendered = self.name
        return rendered
class Level(_Primitive):
    """The ``FOR EACH`` granularity of a trigger (``ROW`` or ``STATEMENT``)."""

    values = ("ROW", "STATEMENT")


#: For specifying row-level triggers (the default)
Row = Level("ROW")

#: For specifying statement-level triggers
Statement = Level("STATEMENT")
class Referencing(_Serializable):
    """For specifying the REFERENCING clause of a statement-level trigger

    Args:
        old (str, optional): Name given to the transition table of old rows.
        new (str, optional): Name given to the transition table of new rows.

    Raises:
        ValueError: If neither ``old`` nor ``new`` is provided.
    """

    def __init__(self, *, old=None, new=None):
        if not (old or new):
            raise ValueError(
                'Must provide either "old" and/or "new" to the referencing'
                ' construct of a trigger'
            )
        self.old = old
        self.new = new

    def __str__(self):
        pieces = ['REFERENCING']
        if self.old:
            pieces.append(f' OLD TABLE AS {self.old} ')
        if self.new:
            pieces.append(f' NEW TABLE AS {self.new} ')
        return ''.join(pieces)
class When(_Primitive):
    """When a trigger fires relative to its operation."""

    values = ("BEFORE", "AFTER", "INSTEAD OF")


#: For specifying ``BEFORE`` in the when clause of a trigger.
Before = When("BEFORE")

#: For specifying ``AFTER`` in the when clause of a trigger.
After = When("AFTER")

#: For specifying ``INSTEAD OF`` in the when clause of a trigger.
InsteadOf = When("INSTEAD OF")
class Operation(_Primitive):
    """A single trigger operation. Combine several with ``|``."""

    values = ("UPDATE", "DELETE", "TRUNCATE", "INSERT")

    def __or__(self, other):
        # ``a | b`` builds an Operations object rendering "a OR b"
        assert isinstance(other, Operation)
        combined = Operations(self, other)
        return combined
class Operations(Operation):
    """For providing multiple operations ``OR``ed together.

    Note that using the ``|`` operator is preferred syntax.
    """

    def __init__(self, *operations):
        # Does not call the parent __init__: instances hold a tuple of
        # operations instead of a single validated ``name``.
        assert all(isinstance(op, Operation) for op in operations)
        self.operations = operations

    def __str__(self):
        return ' OR '.join(map(str, self.operations))
#: For specifying ``UPDATE`` as the trigger operation.
Update = Operation("UPDATE")

#: For specifying ``DELETE`` as the trigger operation.
Delete = Operation("DELETE")

#: For specifying ``TRUNCATE`` as the trigger operation.
Truncate = Operation("TRUNCATE")

#: For specifying ``INSERT`` as the trigger operation.
Insert = Operation("INSERT")
class UpdateOf(Operation):
    """For specifying ``UPDATE OF`` as the trigger operation.

    Args:
        *columns (str): Column names; each is double-quoted when rendered.

    Raises:
        ValueError: If no columns are given.
    """

    def __init__(self, *columns):
        if not columns:
            raise ValueError('Must provide at least one column')
        self.columns = columns

    def __str__(self):
        quoted_columns = ', '.join(map(_quote, self.columns))
        return f'UPDATE OF {quoted_columns}'
class Condition(_Serializable):
    """For specifying free-form SQL in the condition of a trigger.

    Subclasses may set ``sql`` as a class attribute instead of passing it.

    Raises:
        ValueError: If no SQL is available from either source.
    """

    sql = None

    def __init__(self, sql=None):
        # An explicit argument wins; otherwise the class-level default is used
        self.sql = sql if sql else self.sql
        if not self.sql:
            raise ValueError('Must provide SQL to condition')

    def resolve(self, model):
        """Return the condition SQL; ``model`` is unused for free-form SQL."""
        return self.sql
class _OldNewQuery(Query):
    """
    A special Query object for referencing the ``OLD`` and ``NEW`` variables in a
    trigger. Only used by the `pgtrigger.Q` object.

    Filter keys must start with ``old__`` or ``new__``; the prefix selects the
    row alias and the remainder is resolved as a normal field lookup.
    """

    def build_lookup(self, lookups, lhs, rhs):
        # Django does not allow custom lookups on foreign keys, even though
        # IS [NOT] DISTINCT FROM is a completely valid comparison for them.
        # Trick django into being able to apply these lookups to related
        # fields by presenting the column as a plain integer field.
        # NOTE(review): the IntegerField stand-in assumes integer-like keys
        # when the right-hand side is a literal value; confirm for UUID keys.
        if lookups in (['df'], ['ndf']) and isinstance(lhs.output_field, RelatedField):
            lhs = copy.deepcopy(lhs)
            lhs.output_field = models.IntegerField(null=lhs.output_field.null)
        return super().build_lookup(lookups, lhs, rhs)

    def build_filter(self, filter_expr, *args, **kwargs):
        # Nested Q objects are decomposed by Django and come back through here
        # as individual (key, value) filter expressions.
        if isinstance(filter_expr, Q):
            return super().build_filter(filter_expr, *args, **kwargs)

        if filter_expr[0].startswith('old__'):
            alias = 'OLD'
        elif filter_expr[0].startswith('new__'):
            alias = 'NEW'
        else:  # pragma: no cover
            raise ValueError('Filter expression on trigger.Q object must reference old__ or new__')

        # Strip the 5-character "old__" / "new__" prefix
        filter_expr = (filter_expr[0][5:], filter_expr[1])
        node, _ = super().build_filter(filter_expr, *args, **kwargs)

        # Register the row variable as a table alias and point every column
        # reference of the built lookups at it instead of the model's table.
        self.alias_map[alias] = BaseTable(alias, alias)
        for child in node.children:
            child.lhs = Col(
                alias=alias,
                target=child.lhs.target,
                output_field=child.lhs.output_field,
            )

        return node, {alias}
class F(models.F):
    """
    Similar to Django's ``F`` object, allows referencing the old and new
    rows in a trigger condition.

    The referenced name must start with ``old__`` or ``new__``; the prefix
    picks the row variable and the remainder is the field name.

    Raises:
        ValueError: If the name has neither prefix.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        for prefix, row_alias in (('old__', 'OLD'), ('new__', 'NEW')):
            if self.name.startswith(prefix):
                self.row_alias = row_alias
                break
        else:
            raise ValueError('F() values must reference old__ or new__')
        # Both prefixes are 5 characters long
        self.col_name = self.name[5:]

    def deconstruct(self):
        """Deconstruct with the public ``pgtrigger`` import path."""
        path, args, kwargs = super().deconstruct()
        return path.replace("pgtrigger.core", "pgtrigger"), args, kwargs

    @property
    def resolved_name(self):
        """The SQL reference, e.g. ``OLD."field"``."""
        return f'{self.row_alias}.{_quote(self.col_name)}'

    def resolve_expression(self, query=None, *args, **kwargs):
        """Resolve to a column on the OLD/NEW row alias for ``query.model``."""
        field = query.model._meta.get_field(self.col_name)
        return Col(alias=self.row_alias, target=field)
@models.fields.Field.register_lookup
class IsDistinctFrom(models.Lookup):
    """
    A custom ``IS DISTINCT FROM`` field lookup for common trigger conditions.
    For example, ``pgtrigger.Q(old__field__df=pgtrigger.F("new__field"))``.
    """

    lookup_name = 'df'

    def as_sql(self, compiler, connection):
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        # Unpack instead of using "+": depending on the Django version and the
        # expressions involved, each side's params may be a list or a tuple,
        # and list + tuple raises TypeError.
        params = [*lhs_params, *rhs_params]
        return f'{lhs} IS DISTINCT FROM {rhs}', params
@models.fields.Field.register_lookup
class IsNotDistinctFrom(models.Lookup):
    """
    A custom ``IS NOT DISTINCT FROM`` field lookup for common trigger conditions.
    For example, ``pgtrigger.Q(old__field__ndf=pgtrigger.F("new__field"))``.
    """

    lookup_name = 'ndf'

    def as_sql(self, compiler, connection):
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        # Unpack instead of using "+": depending on the Django version and the
        # expressions involved, each side's params may be a list or a tuple,
        # and list + tuple raises TypeError.
        params = [*lhs_params, *rhs_params]
        return f'{lhs} IS NOT DISTINCT FROM {rhs}', params
class Q(models.Q, Condition):
    """
    Similar to Django's ``Q`` object, allows referencing the old and new
    rows in a trigger condition.
    """

    def deconstruct(self):
        """Deconstruct with the public ``pgtrigger`` import path."""
        path, args, kwargs = super().deconstruct()
        path = path.replace("pgtrigger.core", "pgtrigger")
        return path, args, kwargs

    def resolve(self, model):
        """Render the condition as literal SQL for ``model``'s trigger.

        Parameters are interpolated client-side with the driver's ``mogrify``.
        The quotes Django puts around the ``OLD``/``NEW`` aliases are then
        removed, since these are PL/pgSQL row variables, not tables.
        """
        connection = _get_connection(model)
        query = _OldNewQuery(model)
        # Compile against the model's own connection rather than "default"
        compiler = query.get_compiler(connection=connection)
        sql, params = self.resolve_expression(query).as_sql(
            compiler=compiler,
            connection=connection,
        )

        # Close the cursor once the SQL has been rendered
        with connection.cursor() as cursor:
            rendered = cursor.mogrify(sql, params)

        # psycopg2 returns bytes from mogrify; psycopg 3 returns str
        if isinstance(rendered, bytes):
            rendered = rendered.decode()

        return rendered.replace('"OLD"', 'OLD').replace('"NEW"', 'NEW')
def _render_uninstall(table, trigger_pgid):
    """Render the SQL that drops ``trigger_pgid`` from ``table`` if present."""
    quoted_table = _quote(table)
    return f'DROP TRIGGER IF EXISTS {trigger_pgid} ON {quoted_table};'
def _drop_trigger(table, trigger_pgid, database=None):
    """
    Drop a trigger from a table.

    Args:
        table (str): Table the trigger is attached to.
        trigger_pgid (str): Postgres identifier of the trigger.
        database (str, default=None): Alias of the database to drop from.
            When omitted, the write database of the model that owns ``table``
            is used. If no installed model uses the table (e.g. the model was
            deleted), the default database is used, so routers are never
            consulted with ``None``.
    """
    if database is None:
        model = _get_model(table)
        database = _get_database(model) if model is not None else DEFAULT_DB_ALIAS

    uninstall_sql = _render_uninstall(table, trigger_pgid)
    with connections[database].cursor() as cursor:
        cursor.execute(uninstall_sql)
@contextlib.contextmanager
def _cleanup_on_exit(cleanup):
    """
    Context manager that calls ``cleanup()`` when its ``with`` block exits.

    Trigger methods (install, register, enable, ...) return this so they can be
    used as context managers that undo themselves, mostly for testing purposes.

    Args:
        cleanup (callable): Zero-argument callable run on exit.
    """
    try:
        yield
    finally:
        # Run even when the body raises; otherwise the state changed by the
        # Trigger method (an installed/registered trigger, ...) would leak.
        cleanup()
def _render_ignore_func():
    """
    Render the SQL that (re)creates ``_pgtrigger_should_ignore``.

    Trigger bodies call this function to decide whether to skip themselves.
    It reads the ``pgtrigger.ignore`` setting as a text array and returns
    whether ``"<table_name>:<trigger_name>"`` is in it. It returns FALSE when
    the setting is unset or cannot be read as an array, because errors from
    that read are swallowed by the inner exception block.

    Note: This function is global and shared by all triggers in the current
    implementation. It isn't uninstalled when triggers are uninstalled.
    """
    ignore_func_sql = '''
CREATE OR REPLACE FUNCTION _pgtrigger_should_ignore(
table_name NAME,
trigger_name NAME
)
RETURNS BOOLEAN AS $$
DECLARE
_pgtrigger_ignore TEXT[];
_result BOOLEAN;
BEGIN
BEGIN
SELECT INTO _pgtrigger_ignore
CURRENT_SETTING('pgtrigger.ignore');
EXCEPTION WHEN OTHERS THEN
END;
IF _pgtrigger_ignore IS NOT NULL THEN
SELECT CONCAT(table_name, ':', trigger_name) = ANY(_pgtrigger_ignore)
INTO _result;
RETURN _result;
ELSE
RETURN FALSE;
END IF;
END;
$$ LANGUAGE plpgsql;
'''
    return ignore_func_sql
class Trigger(_Serializable):
    """
    For specifying a free-form PL/pgSQL trigger function or for
    creating derived trigger classes.

    Every constructor argument falls back to the class attribute of the same
    name, so subclasses can declare defaults (e.g. ``when``) as attributes.
    """

    name = None
    level = Row
    when = None
    operation = None
    condition = None
    referencing = None
    func = None
    declare = None

    def __init__(
        self,
        *,
        name=None,
        level=None,
        when=None,
        operation=None,
        condition=None,
        referencing=None,
        func=None,
        declare=None,
    ):
        self.name = name or self.name
        self.level = level or self.level
        self.when = when or self.when
        self.operation = operation or self.operation
        self.condition = condition or self.condition
        self.referencing = referencing or self.referencing
        self.func = func or self.func
        self.declare = declare or self.declare

        if not self.level or not isinstance(self.level, Level):
            raise ValueError(f'Invalid "level" attribute: {self.level}')

        if not self.when or not isinstance(self.when, When):
            raise ValueError(f'Invalid "when" attribute: {self.when}')

        if not self.operation or not isinstance(self.operation, Operation):
            raise ValueError(f'Invalid "operation" attribute: {self.operation}')

        if self.level == Row and self.referencing:
            raise ValueError('Row-level triggers cannot have a "referencing" attribute')

        if not self.name:
            raise ValueError('Trigger must have "name" attribute')

        self.validate_name()

    def __str__(self):
        return self.name

    def validate_name(self):
        """Verifies the name is under the maximum length"""
        if len(self.name) > MAX_NAME_LENGTH:
            raise ValueError(f'Trigger name "{self.name}" > {MAX_NAME_LENGTH} characters.')

    def get_pgid(self, model):
        """The ID of the trigger and function object in postgres

        All objects are prefixed with "pgtrigger_" in order to be
        discovered/managed by django-pgtrigger. A 5-character hash of the
        trigger URI is appended so that equally named triggers on different
        models get distinct IDs.
        """
        model_hash = hashlib.sha1(self.get_uri(model).encode()).hexdigest()[:5]
        pgid = f'pgtrigger_{self.name}_{model_hash}'

        if len(pgid) > 63:
            raise ValueError(f'Trigger identifier "{pgid}" is greater than 63 chars')

        # NOTE - Postgres always stores names in lowercase. Ensure that all
        # generated IDs are lowercase so that we can properly do installation
        # and pruning tasks.
        return pgid.lower()

    def get_condition(self, model):
        """Returns the condition object (or None) used in the WHEN clause"""
        return self.condition

    def get_declare(self, model):
        """
        Gets the DECLARE part of the trigger function if any variables
        are used.

        Returns:
            List[tuple]: A list of variable name / type tuples that will
            be shown in the DECLARE. For example [('row_data', 'JSONB')]
        """
        return self.declare or []

    def get_func(self, model):
        """
        Returns the trigger function that comes between the BEGIN and END
        clause
        """
        if not self.func:
            raise ValueError('Must define func attribute or implement get_func')
        return self.func

    def get_uri(self, model):
        """The URI for the trigger in the registry"""
        return f'{model._meta.app_label}.{model._meta.object_name}:{self.name}'

    def register(self, *models):
        """Register model classes with the trigger

        Raises:
            ValueError: If the trigger is already registered for a model, or
                if its postgres ID collides with an already-registered one.
        """
        # Compute the unique trigger function names that are already
        # registered in order to prevent an accidental collision
        registered_function_names = {
            trigger.get_pgid(model) for model, trigger in registry.values()
        }

        for model in models:
            uri = self.get_uri(model)
            if uri in registry:
                raise ValueError(
                    f'Trigger with name "{self.name}" is already'
                    f' registered for model "{model}"'
                )

            pgid = self.get_pgid(model)
            if pgid in registered_function_names:
                raise ValueError(
                    f'Trigger with name "{self.name}" on model "{model}"'
                    ' has a trigger function name that is already taken.'
                    ' Use a different name for the trigger.'
                )

            registry[uri] = (model, self)
            # Also guard against collisions between models of this same call
            registered_function_names.add(pgid)

            # If we support migration integration, patch the constraints of
            # the model
            if getattr(settings, 'PGTRIGGER_MIGRATIONS', True):  # pragma: no branch
                model._meta.constraints = list(model._meta.constraints) + [self]
                model._meta.original_attrs["constraints"] = list(
                    model._meta.original_attrs.get("constraints", [])
                ) + [self]

        return _cleanup_on_exit(lambda: self.unregister(*models))

    def unregister(self, *models):
        """Unregister model classes with the trigger"""
        for model in models:
            del registry[self.get_uri(model)]

            # If we support migration integration, patch the constraints of
            # the model
            if getattr(settings, 'PGTRIGGER_MIGRATIONS', True):  # pragma: no branch
                model._meta.constraints.remove(self)
                model._meta.original_attrs["constraints"].remove(self)

        return _cleanup_on_exit(lambda: self.register(*models))

    def render_condition(self, model):
        """Renders the condition SQL in the trigger declaration"""
        condition = self.get_condition(model)
        resolved = condition.resolve(model).strip() if condition else ''

        if resolved:
            if not resolved.startswith('('):
                resolved = f'({resolved})'
            resolved = f'WHEN {resolved}'

        return resolved

    def render_declare(self, model):
        """Renders the DECLARE of the trigger function, if any"""
        declare = self.get_declare(model)
        if declare:
            rendered_declare = 'DECLARE \n' + '\n'.join(
                f'{var_name} {var_type};' for var_name, var_type in declare
            )
        else:
            rendered_declare = ''

        return rendered_declare

    def render_ignore(self, model):
        """
        Renders the clause that can dynamically ignore the trigger's execution
        """
        return '''
IF (_pgtrigger_should_ignore(TG_TABLE_NAME, TG_NAME) IS TRUE) THEN
IF (TG_OP = 'DELETE') THEN
RETURN OLD;
ELSE
RETURN NEW;
END IF;
END IF;
'''

    def render_func(self, model):
        """Renders the trigger function SQL statement"""
        # NOTE: this text is part of get_hash(); any change to it marks every
        # installed trigger as OUTDATED.
        return f'''
CREATE OR REPLACE FUNCTION {self.get_pgid(model)}()
RETURNS TRIGGER AS $$
{self.render_declare(model)}
BEGIN
{self.render_ignore(model)}
{self.get_func(model)}
END;
$$ LANGUAGE plpgsql;
'''

    def render_trigger(self, model):
        """Renders the trigger declaration SQL statement"""
        table = model._meta.db_table
        pgid = self.get_pgid(model)
        # NOTE: this text is part of get_hash(); any change to it marks every
        # installed trigger as OUTDATED.
        return f'''
DO $$ BEGIN
CREATE TRIGGER {pgid}
{self.when} {self.operation} ON {_quote(table)}
{self.referencing or ''}
FOR EACH {self.level} {self.render_condition(model)}
EXECUTE PROCEDURE {pgid}();
EXCEPTION
-- Ignore issues if the trigger already exists
WHEN duplicate_object THEN null;
END $$;
'''

    def render_comment(self, model):
        """Renders the trigger commment SQL statement

        pgtrigger comments the hash of the trigger in order for us to
        determine if the trigger definition has changed
        """
        pgid = self.get_pgid(model)
        trigger_hash = self.get_hash(model)
        table = model._meta.db_table
        return f"COMMENT ON TRIGGER {pgid} ON {_quote(table)} IS '{trigger_hash}'"

    def get_installation_status(self, model):
        """Returns the installation status of a trigger.

        The return type is (status, enabled), where status is one of:

        1. ``INSTALLED``: If the trigger is installed
        2. ``UNINSTALLED``: If the trigger is not installed
        3. ``OUTDATED``: If the trigger is installed but
           has been modified

        "enabled" is True if the trigger is installed and enabled or false
        if installed and disabled (or uninstalled).
        """
        connection = _get_connection(model)
        # The pgid and table are bound as parameters rather than interpolated.
        # The table is quoted exactly as in the trigger DDL, so mixed-case
        # table names resolve through the regclass cast. The two-argument
        # obj_description restricts the lookup to pg_trigger comments.
        trigger_exists_sql = '''
SELECT oid, obj_description(oid, 'pg_trigger') AS hash, tgenabled AS enabled
FROM pg_trigger
WHERE tgname = %s
AND tgrelid = %s::regclass;
'''
        params = [self.get_pgid(model), _quote(model._meta.db_table)]
        try:
            with connection.cursor() as cursor:
                cursor.execute(trigger_exists_sql, params)
                results = cursor.fetchall()
        except ProgrammingError:  # pragma: no cover
            # When the table doesn't exist yet, possibly because migrations
            # haven't been executed, a ProgrammingError will happen because
            # of an invalid regclass cast. Return 'UNINSTALLED' for this
            # case
            return (UNINSTALLED, None)

        if not results:
            return (UNINSTALLED, None)

        installed_hash, tgenabled = results[0][1], results[0][2]
        status = INSTALLED if installed_hash == self.get_hash(model) else OUTDATED
        # tgenabled 'O' means the trigger fires in "origin" mode, the state
        # set by ALTER TABLE ... ENABLE TRIGGER
        return (status, tgenabled == 'O')

    def get_hash(self, model):
        """
        Computes a hash for the trigger, which is used to
        uniquely identify its contents. The hash is computed based
        on the trigger function and declaration.

        Note: If the trigger definition includes dynamic data, such
        as the current time, the trigger hash will always change and
        appear to be out of sync.
        """
        rendered_func = self.render_func(model)
        rendered_trigger = self.render_trigger(model)
        return hashlib.sha1(f'{rendered_func} {rendered_trigger}'.encode()).hexdigest()

    def render_install(self, model):
        """Renders the full installation SQL for the trigger.

        The trigger is dropped before being recreated. ``render_trigger``
        ignores ``duplicate_object``, so without the drop an installed but
        OUTDATED trigger would keep its old timing / operation / condition
        while its comment (hash) is refreshed, making it look INSTALLED.
        """
        ignore_func = _render_ignore_func()
        rendered_func = self.render_func(model)
        rendered_uninstall = self.render_uninstall(model)
        rendered_trigger = self.render_trigger(model)
        rendered_comment = self.render_comment(model)
        return (
            f"{ignore_func}; {rendered_func}; {rendered_uninstall} "
            f"{rendered_trigger}; {rendered_comment};"
        )

    def install(self, model):
        """Installs the trigger for a model"""
        connection = _get_connection(model)
        install_sql = self.render_install(model)
        with connection.cursor() as cursor:
            cursor.execute(install_sql)
        return _cleanup_on_exit(lambda: self.uninstall(model))

    def render_uninstall(self, model):
        """Renders the SQL that drops the trigger, if it exists"""
        return _render_uninstall(model._meta.db_table, self.get_pgid(model))

    def uninstall(self, model):
        """Uninstalls the trigger for a model"""
        connection = _get_connection(model)
        uninstall_sql = self.render_uninstall(model)
        with connection.cursor() as cursor:
            cursor.execute(uninstall_sql)
        return _cleanup_on_exit(lambda: self.install(model))  # pragma: no branch

    def enable(self, model):
        """Enables the trigger for a model"""
        connection = _get_connection(model)
        with connection.cursor() as cursor:
            cursor.execute(
                f'ALTER TABLE {_quote(model._meta.db_table)}'
                f' ENABLE TRIGGER {self.get_pgid(model)};'
            )
        return _cleanup_on_exit(lambda: self.disable(model))  # pragma: no branch

    def disable(self, model):
        """Disables the trigger for a model"""
        connection = _get_connection(model)
        with connection.cursor() as cursor:
            cursor.execute(
                f'ALTER TABLE {_quote(model._meta.db_table)}'
                f' DISABLE TRIGGER {self.get_pgid(model)};'
            )
        return _cleanup_on_exit(lambda: self.enable(model))  # pragma: no branch

    @contextlib.contextmanager
    def ignore(self, model):
        """Ignores the trigger in a single thread of execution

        The execution wrapper is installed on the connection of the model's
        write database, so multi-database setups inject the ignore setting on
        the connection that actually executes the model's queries.
        """
        connection = _get_connection(model)
        with contextlib.ExitStack() as pre_execute_hook:
            # Create the table name / trigger name URI to pass down to the
            # trigger.
            ignore_uri = f'{model._meta.db_table}:{self.get_pgid(model)}'

            if not hasattr(_ignore, 'value'):
                _ignore.value = set()

            if not _ignore.value:
                # If this is the first time we are ignoring trigger execution,
                # register the pre_execute_hook
                pre_execute_hook.enter_context(
                    connection.execute_wrapper(_inject_pgtrigger_ignore)
                )

            if ignore_uri not in _ignore.value:
                try:
                    _ignore.value.add(ignore_uri)
                    yield
                finally:
                    _ignore.value.remove(ignore_uri)
            else:  # The trigger is already being ignored
                yield

        if not _ignore.value and connection.in_atomic_block:
            # We've finished all ignoring of triggers, but we are in a transaction
            # and still have a reference to the local variable. Reset it
            with connection.cursor() as cursor:
                cursor.execute('RESET pgtrigger.ignore;')

    # All following attributes are available so that the trigger works as
    # a constraint in Django's migration system.
    @property
    def contains_expressions(self):  # pragma: no cover
        return False

    def constraint_sql(self, model, schema_editor):  # pragma: no cover
        return ""

    def create_sql(self, model, schema_editor):
        if schema_editor.connection.vendor != 'postgresql':  # pragma: no cover
            raise NotSupportedError("Triggers are only supported on PostgreSQL databases.")

        sql = self.render_install(model)

        # Note: Django 3.2 introduced differences in how SQL is executed for
        # migration commands. Be sure to escape the special "%" character
        return sql.replace("%", "%%") if django.VERSION < (3, 2) else sql

    def remove_sql(self, model, schema_editor):
        sql = self.render_uninstall(model)

        # Note: Django 3.2 introduced differences in how SQL is executed for
        # migration commands. Be sure to escape the special "%" character
        return sql.replace("%", "%%") if django.VERSION < (3, 2) else sql

    def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
        """
        Triggers cannot be ran in software, so the validation check cannot happen
        like it can for other constraints.
        """
        pass

    def clone(self):
        """Returns a new instance built from this trigger's init arguments"""
        _, args, kwargs = self.deconstruct()
        return self.__class__(*args, **kwargs)
class Protect(Trigger):
    """A trigger that raises an exception.

    Fires ``BEFORE`` the configured operation, so the operation is rejected.
    """

    when = Before

    def get_func(self, model):
        verb = str(self.operation).lower()
        return f'''
RAISE EXCEPTION
'pgtrigger: Cannot {verb} rows from % table',
TG_TABLE_NAME;
'''
class FSM(Trigger):
    """Enforces a finite state machine on a field.

    Supply the trigger with the "field" that transitions and then
    a list of tuples of valid transitions to the "transitions" argument.

    .. note::

        Only non-null ``CharField`` fields are currently supported.
    """

    when = Before
    operation = Update
    field = None
    transitions = None

    def __init__(self, *, name=None, condition=None, field=None, transitions=None):
        self.field = field or self.field
        self.transitions = transitions or self.transitions

        if not self.field:  # pragma: no cover
            raise ValueError('Must provide "field" for FSM')

        if not self.transitions:  # pragma: no cover
            raise ValueError('Must provide "transitions" for FSM')

        super().__init__(name=name, condition=condition)

    def get_declare(self, model):
        return [('_is_valid_transition', 'BOOLEAN')]

    def get_func(self, model):
        # Quoted column name, referenced on both the OLD and NEW rows
        col = _quote(model._meta.get_field(self.field).column)
        # Allowed transitions as a Postgres array literal of "old:new" items
        allowed = ','.join(f'{old}:{new}' for old, new in self.transitions)
        transition_uris = '{' + allowed + '}'
        return f'''
SELECT CONCAT(OLD.{col}, ':', NEW.{col}) = ANY('{transition_uris}'::text[])
INTO _is_valid_transition;
IF (_is_valid_transition IS FALSE AND OLD.{col} IS DISTINCT FROM NEW.{col}) THEN
RAISE EXCEPTION
'pgtrigger: Invalid transition of field "{self.field}" from "%" to "%" on table %',
OLD.{col},
NEW.{col},
TG_TABLE_NAME;
ELSE
RETURN NEW;
END IF;
'''  # noqa
class SoftDelete(Trigger):
    """Sets a field to a value when a delete happens.

    Supply the trigger with the "field" that will be set
    upon deletion and the "value" to which it should be set.
    The "value" defaults to ``False``.

    .. note::

        This trigger currently only supports nullable ``BooleanField``,
        ``CharField``, and ``IntField`` fields.
    """

    when = Before
    operation = Delete
    field = None
    value = False

    def __init__(self, *, name=None, condition=None, field=None, value=_unset):
        self.field = field or self.field
        # _unset lets callers pass value=None explicitly (rendered as NULL)
        self.value = value if value is not _unset else self.value

        if not self.field:  # pragma: no cover
            raise ValueError('Must provide "field" for soft delete')

        super().__init__(name=name, condition=condition)

    def get_func(self, model):
        # NOTE(review): the soft-delete column is interpolated unquoted (unlike
        # the pk column). Quoting it would change the rendered function and
        # thus the hash of already-installed triggers; columns needing quotes
        # (mixed case / reserved words) are not supported here.
        soft_field = model._meta.get_field(self.field).column
        pk_col = model._meta.pk.column

        def _render_value():
            if self.value is None:
                return 'NULL'
            elif isinstance(self.value, str):
                # Double embedded single quotes so the SQL literal stays valid
                escaped = self.value.replace("'", "''")
                return f"'{escaped}'"
            else:
                return str(self.value)

        # Re-point the delete to an UPDATE of the same row, then cancel the
        # original DELETE by returning NULL from the BEFORE trigger.
        return f'''
UPDATE {_quote(model._meta.db_table)}
SET {soft_field} = {_render_value()}
WHERE {_quote(pk_col)} = OLD.{_quote(pk_col)};
RETURN NULL;
'''
def get(*uris, database=None):
    """
    Get registered trigger objects.

    Args:
        *uris (str): URIs of triggers to get. If none are provided,
            all triggers are returned. URIs are in the format of
            ``{app_label}.{model_name}:{trigger_name}``.
        database (str, default=None): Only get triggers from this
            database. May also be an iterable of database aliases.

    Returns:
        List[tuple]: ``(model, trigger)`` pairs for the matching triggers.

    Raises:
        ValueError: If both URIs and a database are given, if a URI is empty
            or malformed, or if a URI is not registered.
    """
    if database and uris:
        raise ValueError('Cannot supply both trigger URIs and a database')

    if uris:
        for uri in uris:
            if not uri or ':' not in uri:
                raise ValueError(
                    'Trigger URI must be in the format of "app_label.model_name:trigger_name"'
                )
            if uri not in registry:
                raise ValueError(f'URI "{uri}" not found in pgtrigger registry')
        return [registry[uri] for uri in uris]

    if not database:
        # No database filter: every registered trigger matches, so there is
        # no need to consult the database routers.
        return list(registry.values())

    databases = {database} if isinstance(database, str) else set(database)
    return [
        (model, trigger)
        for model, trigger in registry.values()
        if _get_database(model) in databases
    ]
def install(*uris, database=None):
    """
    Install triggers.

    Args:
        *uris (str): URIs of triggers to install. If none are provided,
            every trigger that is not already ``INSTALLED`` is installed and
            orphaned triggers are pruned afterwards.
        database (str, default=None): Only install triggers from this
            database.
    """

    def _needs_install(model, trigger):
        status, _enabled = trigger.get_installation_status(model)
        return status != INSTALLED

    selected = get(*uris, database=database)
    if not uris:
        selected = [pair for pair in selected if _needs_install(*pair)]

    for model, trigger in selected:
        LOGGER.info(
            f'pgtrigger: Installing {trigger} trigger'
            f' for {model._meta.db_table} table'
            f' on {_get_database(model)} database.'
        )
        trigger.install(model)

    if not uris:  # pragma: no branch
        prune(database=database)
def get_prune_list(database=None):
    """Return triggers that will be pruned upon next full install

    Args:
        database (str, default=None): Only return results from this
            database. Defaults to returning results from all databases

    Returns:
        List[tuple]: ``(table, trigger_pgid, enabled, database_alias)`` for
        every pgtrigger-managed trigger found in a PostgreSQL database that
        is not registered for any model.
    """
    installed = {
        (_quote(model._meta.db_table), trigger.get_pgid(model)) for model, trigger in get()
    }

    if isinstance(database, str):
        aliases = [database]
    else:
        aliases = database or settings.DATABASES

    prune_list = []
    for alias in _postgres_databases(aliases):
        with connections[alias].cursor() as cursor:
            # "_" is a LIKE wildcard; escape it so only names that really
            # start with "pgtrigger_" are considered ours (and thus prunable).
            cursor.execute(
                'SELECT tgrelid::regclass, tgname, tgenabled'
                ' FROM pg_trigger'
                " WHERE tgname LIKE 'pgtrigger!_%' ESCAPE '!'"
            )
            rows = set(cursor.fetchall())

        prune_list += [
            (table, pgid, enabled == 'O', alias)
            for table, pgid, enabled in rows
            if (_quote(table), pgid) not in installed
        ]

    return prune_list
def prune(database=None):
    """
    Remove any pgtrigger triggers in the database that are not used by models.

    I.e. if a model or trigger definition is deleted from a model, ensure
    it is removed from the database

    Args:
        database (str, default=None): Only prune triggers from this
            database.
    """
    for table, pgid, _enabled, alias in get_prune_list(database=database):
        LOGGER.info(
            f'pgtrigger: Pruning trigger {pgid}'
            f' for table {table} on {alias} database.'
        )
        # Drop on the database where the orphan was found. Resolving the
        # database from the table's model would fail when the model no longer
        # exists, which is the typical reason a trigger needs pruning.
        with connections[alias].cursor() as cursor:
            cursor.execute(_render_uninstall(table, pgid))
def enable(*uris, database=None):
    """
    Enables registered triggers.

    Args:
        *uris (str): URIs of triggers to enable. If none are provided,
            every installed-but-disabled trigger is enabled.
        database (str, default=None): Only enable triggers from this
            database.
    """

    def _is_disabled(model, trigger):
        # Installed-but-disabled triggers report False; uninstalled ones
        # report None and are therefore skipped.
        return trigger.get_installation_status(model)[1] is False

    selected = get(*uris, database=database)
    if not uris:
        selected = [pair for pair in selected if _is_disabled(*pair)]

    for model, trigger in selected:
        LOGGER.info(
            f'pgtrigger: Enabling {trigger} trigger'
            f' for {model._meta.db_table} table'
            f' on {_get_database(model)} database.'
        )
        trigger.enable(model)
def uninstall(*uris, database=None):
    """
    Uninstalls triggers.

    Args:
        *uris (str): URIs of triggers to uninstall. If none are provided,
            every trigger that is not already ``UNINSTALLED`` is uninstalled
            and orphaned triggers are pruned afterwards.
        database (str, default=None): Only uninstall triggers from this
            database.
    """

    def _is_present(model, trigger):
        status, _enabled = trigger.get_installation_status(model)
        return status != UNINSTALLED

    selected = get(*uris, database=database)
    if not uris:
        selected = [pair for pair in selected if _is_present(*pair)]

    for model, trigger in selected:
        LOGGER.info(
            f'pgtrigger: Uninstalling {trigger} trigger'
            f' for {model._meta.db_table} table'
            f' on {_get_database(model)} database.'
        )
        trigger.uninstall(model)

    if not uris:
        prune(database=database)
def disable(*uris, database=None):
    """
    Disables triggers.

    Args:
        *uris (str): URIs of triggers to disable. If none are provided,
            every installed and enabled trigger is disabled.
        database (str, default=None): Only disable triggers from this
            database.
    """

    def _is_enabled(model, trigger):
        return trigger.get_installation_status(model)[1]

    selected = get(*uris, database=database)
    if not uris:
        selected = [pair for pair in selected if _is_enabled(*pair)]

    for model, trigger in selected:
        LOGGER.info(
            f'pgtrigger: Disabling {trigger} trigger for'
            f' {model._meta.db_table} table'
            f' on {_get_database(model)} database.'
        )
        trigger.disable(model)
@contextlib.contextmanager
def ignore(*uris):
    """
    Dynamically ignore registered triggers matching URIs from executing in
    an individual thread.

    If no URIs are provided, ignore all pgtriggers from executing in an
    individual thread.

    Examples:
        Ignore triggers in a context manager::

            with pgtrigger.ignore("my_app.Model:trigger_name"):
                # Do stuff while ignoring trigger

        Ignore multiple triggers as a decorator::

            @pgtrigger.ignore("my_app.Model:trigger_name", "my_app.Model:other_trigger")
            def my_func():
                # Do stuff while ignoring trigger
    """
    model_triggers = get(*uris)
    with contextlib.ExitStack() as ignores:
        # Each per-trigger ignore stays active until the stack unwinds
        for model, trigger in model_triggers:
            ignores.enter_context(trigger.ignore(model))
        yield