Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions src/datajoint/autopopulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,34 @@ def _populate1(

self._upstream = Diagram.trace(self & dict(key))

# If strict_provenance is on, push the active-make context so the
# runtime gates in expression.cursor / table.insert can check this
# make()'s reads and writes. The context is popped in the finally
# block below.
strict_token = None
if self.connection._config.get("strict_provenance", False):
from .provenance import push_strict_make_context
from .user_tables import Part

allowed_tables = set(self._upstream._cascade_restrictions.keys()) | {self.full_table_name}
# Add Part tables of self to the allowed set. Use class __dict__
# (not dir/getattr) to avoid triggering descriptors like the
# _JobsDescriptor that lazy-declares the ~~ job table.
for cls in type(self).__mro__:
for attr_name, attr in cls.__dict__.items():
if attr_name.startswith("_"):
continue
if isinstance(attr, type) and issubclass(attr, Part):
# Instantiate to get full_table_name resolved against
# this schema. The Part class is already attached via
# @schema decoration of the master.
try:
part_ftn = attr().full_table_name
allowed_tables.add(part_ftn)
except Exception:
pass
strict_token = push_strict_make_context(self, frozenset(allowed_tables), dict(key))

try:
if not is_generator:
make(dict(key), **(make_kwargs or {}))
Expand Down Expand Up @@ -719,6 +747,11 @@ def _populate1(
# access raises a clear error rather than silently using a
# stale trace from the previous make() call.
self._upstream = None
# Pop the strict-make context, if any.
if strict_token is not None:
from .provenance import pop_strict_make_context

pop_strict_make_context(strict_token)

def progress(self, *restrictions: Any, display: bool = False) -> tuple[int, int]:
"""
Expand Down
6 changes: 6 additions & 0 deletions src/datajoint/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -1242,6 +1242,12 @@ def cursor(self, as_dict=False):
cursor
Database query cursor.
"""
# Strict-provenance read gate. No-op outside make() or when the
# config flag is off. See src/datajoint/provenance.py.
from .provenance import assert_read_allowed

assert_read_allowed(self)

sql = self.make_sql()
logger.debug(sql)
return self.connection.query(sql, as_dict=as_dict)
Expand Down
206 changes: 206 additions & 0 deletions src/datajoint/provenance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
"""
Runtime gates for ``dj.config["strict_provenance"]``.

When the flag is enabled, this module's context (set by ``AutoPopulate._populate_one``)
tracks which tables and primary key the currently-executing ``make()`` is
allowed to read and write. The read gate in :func:`assert_read_allowed`
fires inside ``QueryExpression.cursor``. The write gate has two parts: the
target check in :func:`assert_write_allowed` fires inside ``Table.insert``
(before rows are materialized), and the per-row key-consistency check in
:func:`assert_row_key_allowed` fires inside ``Table._insert_rows`` as each row
is materialized — so the gate never consumes the caller's ``rows`` iterable.

The contract is documented in
``datajoint-docs/src/reference/specs/provenance.md`` §3.

Implementation note: the active-make context is stored in a
``contextvars.ContextVar`` so it propagates correctly across threads
that share the parent's context (e.g. the populate-in-subprocess path
which uses ``multiprocessing`` workers, each of which inherits its
parent's contextvar binding at fork time).
"""

from __future__ import annotations

from contextvars import ContextVar
from typing import TYPE_CHECKING, Optional, Tuple

from .errors import DataJointError

if TYPE_CHECKING:
from .table import Table


# Active context: (the target table, the set of allowed full table names, the current key dict)
_active_strict_make: ContextVar[Optional[Tuple["Table", frozenset[str], dict]]] = ContextVar(
"_dj_active_strict_make", default=None
)


def push_strict_make_context(target: "Table", allowed_tables: frozenset[str], key: dict):
"""
Push a strict-make context for the duration of one ``make()`` invocation.

Returns a token that the caller must pass to :func:`pop_strict_make_context`
in a ``finally`` block.
"""
return _active_strict_make.set((target, allowed_tables, key))


def pop_strict_make_context(token) -> None:
"""Pop the strict-make context using a token from :func:`push_strict_make_context`."""
_active_strict_make.reset(token)


def get_active_context():
"""Return the currently-active strict-make context, or None."""
return _active_strict_make.get()


def _base_tables(query_expression) -> set[str]:
"""
Return the set of base-table SQL names that a QueryExpression reads from.

For a single-table expression (FreeTable / Table / restricted variants),
returns ``{full_table_name}``. For compound expressions (joins,
projections of joins), traverses ``support`` recursively.
"""
# FreeTable / Table: has full_table_name directly
ftn = getattr(query_expression, "full_table_name", None)
if isinstance(ftn, str):
return {ftn}

bases: set[str] = set()
support = getattr(query_expression, "_support", None) or []
for s in support:
if isinstance(s, str):
# Direct table name in the support list
bases.add(s)
else:
# Subquery — recurse
bases.update(_base_tables(s))
return bases


def assert_read_allowed(query_expression) -> None:
"""
Verify a fetch is allowed under the active strict-make context.

Called from ``QueryExpression.cursor`` before SQL is issued. No-op when
no strict-make context is active (i.e. outside ``make()`` or when
``strict_provenance`` is False).

Allowed reads:

- Any table in the active context's ``allowed_tables`` set. The set is
built from ``self.upstream`` (the ancestor graph) plus the target
table and its Parts.

Anything else raises ``DataJointError``.

Known limitation (will sharpen in a follow-up): the check does not
distinguish reads that came *through* ``self.upstream`` from reads of
the same ancestor via a direct expression. Both are allowed if the
table is in the allowed set. The intent is to catch reads from
*undeclared* dependencies; tightening the "must come through
``self.upstream``" path requires propagating an attribution marker
through QueryExpression composition and is deferred.
"""
ctx = _active_strict_make.get()
if ctx is None:
return # strict mode off, or outside make()

_target, allowed_tables, _key = ctx
bases = _base_tables(query_expression)
if not bases:
return # nothing to check (e.g. dj.U expressions)

disallowed = bases - allowed_tables
if disallowed:
raise DataJointError(
f"strict_provenance=True: read from undeclared table(s) "
f"{sorted(disallowed)} is not permitted inside make(). "
f"Use self.upstream[T] for declared ancestors, or declare a "
f"foreign-key dependency on the table you want to read."
)


def assert_write_allowed(target_table) -> None:
"""
Verify the *target* of an insert is allowed under the active strict-make context.

Called from ``Table.insert`` after the existing ``_allow_insert`` check and
before any rows are materialized. No-op when no strict-make context is active.

Allowed targets:

- The current ``make()`` target (``self``) or one of its Part tables.

Per-row key consistency is checked separately by :func:`assert_row_key_allowed`
as rows are materialized, so this gate never consumes the caller's ``rows``
iterable — a one-shot generator must survive to reach ``insert``.

Raises ``DataJointError`` if the target is not permitted.
"""
ctx = _active_strict_make.get()
if ctx is None:
return

make_target, _allowed_tables, _key = ctx

# Target must be `make_target` (self) or one of its Parts.
target_name = getattr(target_table, "full_table_name", None)
target_set = {make_target.full_table_name}
# Collect Part tables of make_target via class __dict__ (not dir/getattr,
# which would trigger descriptors like the _JobsDescriptor).
from .user_tables import Part # local import to avoid circular dep

for cls in type(make_target).__mro__:
for attr_name, attr in cls.__dict__.items():
if attr_name.startswith("_"):
continue
if isinstance(attr, type) and issubclass(attr, Part):
try:
part_ftn = attr().full_table_name
target_set.add(part_ftn)
except Exception:
pass

if target_name not in target_set:
raise DataJointError(
f"strict_provenance=True: insert into {target_name!r} is not permitted "
f"inside make() for {make_target.full_table_name!r}. Only the target "
f"table and its Part tables may be written."
)


def assert_row_key_allowed(row) -> None:
"""
Verify a single insert row's key columns match the active ``make()`` key.

Called per row from ``Table._insert_rows`` as rows are materialized, so the
check sees a concrete row without the write gate having to consume the
caller's ``rows`` iterable. No-op when no strict-make context is active or
when ``row`` is not a dict (numpy records / bare sequences carry no field
names to check by — same as the previous behavior).

Raises ``DataJointError`` on a mismatch.
"""
ctx = _active_strict_make.get()
if ctx is None:
return
if not isinstance(row, dict):
return
_make_target, _allowed_tables, key = ctx
_check_row_key(row, key)


def _check_row_key(row: dict, current_key: dict) -> None:
"""Raise if any row attribute overlapping with the current key has a different value."""
for k, v in current_key.items():
if k in row and row[k] != v:
raise DataJointError(
f"strict_provenance=True: inserted row's {k!r}={row[k]!r} does not "
f"match the current make() key's {k!r}={v!r}. Inserts must be "
f"consistent with the key being populated."
)
11 changes: 11 additions & 0 deletions src/datajoint/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
"database.database_prefix": "DJ_DATABASE_PREFIX",
"database.create_tables": "DJ_CREATE_TABLES",
"loglevel": "DJ_LOG_LEVEL",
"strict_provenance": "DJ_STRICT_PROVENANCE",
"display.diagram_direction": "DJ_DIAGRAM_DIRECTION",
}

Expand Down Expand Up @@ -361,6 +362,16 @@ class Config(BaseSettings):
"*New in 2.2.3.*",
)

strict_provenance: bool = Field(
default=False,
validation_alias="DJ_STRICT_PROVENANCE",
description="If True, enforces the upstream-only convention inside make(): "
"reads must go through self.upstream[Ancestor], writes must target self "
"or self's Part tables with primary keys consistent with the current key. "
"Off by default; opt-in for deployments that need runtime provenance "
"guarantees backing downstream lineage / CDC tooling. *New in 2.3.*",
)

# Cache path for query results
query_cache: Path | None = None

Expand Down
27 changes: 25 additions & 2 deletions src/datajoint/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,10 +797,23 @@ def insert(
" To override, set keyword argument allow_direct_insert=True."
)

# Strict-provenance write gate (target check only). No-op outside make()
# or when the config flag is off. Deliberately does NOT touch `rows` —
# the per-row key-consistency check happens in `_insert_rows` as rows are
# materialized, so a one-shot iterable (generator) is not consumed here.
# See src/datajoint/provenance.py.
from .provenance import assert_write_allowed

assert_write_allowed(self)

if inspect.isclass(rows) and issubclass(rows, QueryExpression):
rows = rows() # instantiate if a class
if isinstance(rows, QueryExpression):
# insert from select - chunk_size not applicable
# insert from select - chunk_size not applicable.
# Note: this INSERT ... SELECT runs entirely server-side, so under
# strict_provenance the per-row key-consistency check does not apply
# (row values are never materialized client-side). The target check
# in assert_write_allowed above still governs which table is written.
if chunk_size is not None:
raise DataJointError("chunk_size is not supported for QueryExpression inserts")
if not ignore_extra_fields:
Expand Down Expand Up @@ -855,7 +868,17 @@ def _insert_rows(self, rows, replace, skip_duplicates, ignore_extra_fields):
"""
# collects the field list from first row (passed by reference)
field_list = []
rows = list(self.__make_row_to_insert(row, field_list, ignore_extra_fields) for row in rows)
# Strict-provenance per-row key check runs here, as each row is
# materialized — no-op outside make()/when the flag is off. Placing it in
# this single materialization point (reached by both the chunked and
# single-batch paths) avoids consuming the caller's `rows` iterable early.
from .provenance import assert_row_key_allowed

def _make_row(row):
assert_row_key_allowed(row)
return self.__make_row_to_insert(row, field_list, ignore_extra_fields)

rows = list(_make_row(row) for row in rows)
if rows:
try:
# Handle empty field_list (all-defaults insert)
Expand Down
Loading
Loading