# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license

"""DNS Versioned Zones."""

import collections
try:
    import threading as _threading
except ImportError:  # pragma: no cover
    import dummy_threading as _threading    # type: ignore

import dns.exception
import dns.immutable
import dns.name
import dns.rdataclass
import dns.rdataset
import dns.rdatatype
import dns.rdtypes.ANY.SOA
import dns.zone


class UseTransaction(dns.exception.DNSException):
    """To alter a versioned zone, use a transaction."""


# Backwards compatibility
Node = dns.zone.VersionedNode
ImmutableNode = dns.zone.ImmutableVersionedNode
Version = dns.zone.Version
WritableVersion = dns.zone.WritableVersion
ImmutableVersion = dns.zone.ImmutableVersion
Transaction = dns.zone.Transaction


class Zone(dns.zone.Zone):
    """A zone that keeps a sequence of versions.

    Changes are made only through write transactions (see ``writer()``);
    the direct mutation methods raise ``UseTransaction``.  Read
    transactions (see ``reader()``) are opened on a specific version.
    """

    # The lock attribute is named ``_version_lock`` everywhere it is used,
    # so the slot is declared under that same name.
    __slots__ = ['_versions', '_version_lock', '_write_txn',
                 '_write_waiters', '_write_event', '_pruning_policy',
                 '_readers']

    node_factory = Node

    def __init__(self, origin, rdclass=dns.rdataclass.IN, relativize=True,
                 pruning_policy=None):
        """Initialize a versioned zone object.

        *origin* is the origin of the zone.  It may be a ``dns.name.Name``,
        a ``str``, or ``None``.  If ``None``, then the zone's origin will
        be set by the first ``$ORIGIN`` line in a zone file.

        *rdclass*, an ``int``, the zone's rdata class; the default is class
        IN.

        *relativize*, a ``bool``, determines whether domain names are
        relativized to the zone's origin.  The default is ``True``.

        *pruning_policy*, a function taking a ``Zone`` and a ``Version`` and
        returning a ``bool``, or ``None``.  Should the version be pruned?
        If ``None``, the default policy, which retains one version, is
        used.
        """
        super().__init__(origin, rdclass, relativize)
        self._versions = collections.deque()
        self._version_lock = _threading.Lock()
        if pruning_policy is None:
            self._pruning_policy = self._default_pruning_policy
        else:
            self._pruning_policy = pruning_policy
        self._write_txn = None
        self._write_event = None
        self._write_waiters = collections.deque()
        self._readers = set()
        # Install the initial, empty version.  No write transaction exists
        # yet, so None is passed as the transaction.
        self._commit_version_unlocked(None,
                                      WritableVersion(self, replacement=True),
                                      origin)

    def reader(self, id=None, serial=None):  # pylint: disable=arguments-differ
        """Open a read transaction on a version of the zone.

        *id*, an ``int`` or ``None``.  If not ``None``, the version with
        this id is selected.

        *serial*, an ``int`` or ``None``.  If not ``None``, the newest
        version whose SOA rdataset at the origin has this serial is
        selected.

        If both are ``None``, the most recent version is selected.

        Raises ``ValueError`` if both *id* and *serial* are specified, and
        ``KeyError`` if no matching version is found.

        Returns the new read ``Transaction``.
        """
        if id is not None and serial is not None:
            raise ValueError('cannot specify both id and serial')
        with self._version_lock:
            if id is not None:
                version = None
                # Search from the newest version to the oldest.
                for v in reversed(self._versions):
                    if v.id == id:
                        version = v
                        break
                if version is None:
                    raise KeyError('version not found')
            elif serial is not None:
                if self.relativize:
                    oname = dns.name.empty
                else:
                    oname = self.origin
                version = None
                for v in reversed(self._versions):
                    n = v.nodes.get(oname)
                    if n:
                        rds = n.get_rdataset(self.rdclass, dns.rdatatype.SOA)
                        if rds and rds[0].serial == serial:
                            version = v
                            break
                if version is None:
                    raise KeyError('serial not found')
            else:
                version = self._versions[-1]
            txn = Transaction(self, False, version)
            # Track the reader so that its version (and later ones) are
            # not pruned while it is open.
            self._readers.add(txn)
            return txn

    def writer(self, replacement=False):
        """Open the write transaction, waiting for any current writer.

        *replacement* is passed through to the ``Transaction`` constructor.

        Returns the new write ``Transaction``.
        """
        event = None
        while True:
            with self._version_lock:
                # Checking event == self._write_event ensures that either
                # no one was waiting before we got lucky and found no write
                # txn, or we were the one who was waiting and got woken up.
                # This prevents "taking cuts" when creating a write txn.
                if self._write_txn is None and event == self._write_event:
                    # Creating the transaction defers version setup
                    # (i.e.  copying the nodes dictionary) until we
                    # give up the lock, so that we hold the lock as
                    # short a time as possible.  This is why we call
                    # _setup_version() below.
                    self._write_txn = Transaction(self, replacement,
                                                  make_immutable=True)
                    # give up our exclusive right to make a Transaction
                    self._write_event = None
                    break
                # Someone else is writing already, so we will have to
                # wait, but we want to do the actual wait outside the
                # lock.
                event = _threading.Event()
                self._write_waiters.append(event)
            # wait (note we gave up the lock!)
            #
            # We only wake one sleeper at a time, so it's important
            # that no event waiter can exit this method (e.g. via
            # cancellation) without returning a transaction or waking
            # someone else up.
            #
            # This is not a problem with Threading module threads as
            # they cannot be canceled, but could be an issue with trio
            # or curio tasks when we do the async version of writer().
            # I.e. we'd need to do something like:
            #
            # try:
            #     event.wait()
            # except trio.Cancelled:
            #     with self._version_lock:
            #         self._maybe_wakeup_one_waiter_unlocked()
            #     raise
            #
            event.wait()
        # Do the deferred version setup.
        self._write_txn._setup_version()
        return self._write_txn

    def _maybe_wakeup_one_waiter_unlocked(self):
        """Wake the longest-waiting writer, if any.

        The caller must hold ``_version_lock``.
        """
        if self._write_waiters:
            self._write_event = self._write_waiters.popleft()
            self._write_event.set()

    # pylint: disable=unused-argument
    def _default_pruning_policy(self, zone, version):
        """Default policy: every version offered for pruning is pruned."""
        return True
    # pylint: enable=unused-argument

    def _prune_versions_unlocked(self):
        """Drop the oldest versions the pruning policy allows.

        The caller must hold ``_version_lock``.
        """
        assert len(self._versions) > 0
        # Don't ever prune a version greater than or equal to one that
        # a reader has open.  This pins versions in memory while the
        # reader is open, and importantly lets the reader open a txn on
        # a successor version (e.g. if generating an IXFR).
        #
        # Note our definition of least_kept also ensures we do not try to
        # delete the greatest version.
        if self._readers:
            least_kept = min(txn.version.id for txn in self._readers)
        else:
            least_kept = self._versions[-1].id
        while self._versions[0].id < least_kept and \
                self._pruning_policy(self, self._versions[0]):
            self._versions.popleft()

    def set_max_versions(self, max_versions):
        """Set a pruning policy that retains up to the specified number
        of versions.

        *max_versions*, an ``int`` of at least 1, or ``None``.  If
        ``None``, the installed policy never prunes.

        Raises ``ValueError`` if *max_versions* is less than 1.
        """
        if max_versions is not None and max_versions < 1:
            raise ValueError('max versions must be at least 1')
        if max_versions is None:
            def policy(*_):
                return False
        else:
            def policy(zone, _):
                return len(zone._versions) > max_versions
        self.set_pruning_policy(policy)

    def set_pruning_policy(self, policy):
        """Set the pruning policy for the zone.

        The *policy* function takes a ``Zone`` and a ``Version`` and
        returns ``True`` if the version should be pruned, and ``False``
        otherwise.  ``None`` may also be specified for policy, in which
        case the default policy is used.

        Pruning checking proceeds from the least version and the first
        time the function returns ``False``, the checking stops.  I.e. the
        retained versions are always a consecutive sequence.
        """
        if policy is None:
            policy = self._default_pruning_policy
        with self._version_lock:
            self._pruning_policy = policy
            self._prune_versions_unlocked()

    def _end_read(self, txn):
        """Forget read transaction *txn* and prune versions."""
        with self._version_lock:
            self._readers.remove(txn)
            self._prune_versions_unlocked()

    def _end_write_unlocked(self, txn):
        """Clear the write transaction and wake one waiting writer.

        The caller must hold ``_version_lock``.
        """
        assert self._write_txn == txn
        self._write_txn = None
        self._maybe_wakeup_one_waiter_unlocked()

    def _end_write(self, txn):
        """Locked wrapper around ``_end_write_unlocked()``."""
        with self._version_lock:
            self._end_write_unlocked(txn)

    def _commit_version_unlocked(self, txn, version, origin):
        """Append *version*, make it the zone's current content, and end
        the write transaction *txn*.

        The caller must hold ``_version_lock``, except in ``__init__``.
        """
        self._versions.append(version)
        self._prune_versions_unlocked()
        self.nodes = version.nodes
        if self.origin is None:
            self.origin = origin
        # txn can be None in __init__ when we make the empty version.
        if txn is not None:
            self._end_write_unlocked(txn)

    def _commit_version(self, txn, version, origin):
        """Locked wrapper around ``_commit_version_unlocked()``."""
        with self._version_lock:
            self._commit_version_unlocked(txn, version, origin)

    def _get_next_version_id(self):
        """Return the id following the newest version's id, or 1 if there
        are no versions yet.
        """
        if self._versions:
            return self._versions[-1].id + 1
        return 1

    def find_node(self, name, create=False):
        """Find a node via the parent class.

        Raises ``UseTransaction`` if *create* is true.
        """
        if create:
            raise UseTransaction
        return super().find_node(name)

    def delete_node(self, name):
        """Not supported; always raises ``UseTransaction``."""
        raise UseTransaction

    def find_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE,
                      create=False):
        """Find an rdataset via the parent class and return it wrapped in
        a ``dns.rdataset.ImmutableRdataset``.

        Raises ``UseTransaction`` if *create* is true.
        """
        if create:
            raise UseTransaction
        rdataset = super().find_rdataset(name, rdtype, covers)
        return dns.rdataset.ImmutableRdataset(rdataset)

    def get_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE,
                     create=False):
        """Get an rdataset via the parent class and return it wrapped in
        a ``dns.rdataset.ImmutableRdataset``, or ``None`` if the parent
        class returns ``None``.

        Raises ``UseTransaction`` if *create* is true.
        """
        if create:
            raise UseTransaction
        rdataset = super().get_rdataset(name, rdtype, covers)
        # A missing rdataset is reported as None; wrapping None in
        # ImmutableRdataset would fail, so pass it through unchanged.
        if rdataset is None:
            return None
        return dns.rdataset.ImmutableRdataset(rdataset)

    def delete_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE):
        """Not supported; always raises ``UseTransaction``."""
        raise UseTransaction

    def replace_rdataset(self, name, replacement):
        """Not supported; always raises ``UseTransaction``."""
        raise UseTransaction