Source code for aside.boilerplate.observable

"""Contains implementation of observable-related classes."""

import re
import uuid
from bisect import bisect
from collections.abc import Mapping, MutableSet
from enum import Enum, auto
from functools import partial, reduce, wraps
from operator import getitem
from typing import (
    Callable,
    Dict,
    Iterable,
    Iterator,
    List,
    Optional,
    Pattern,
    Set,
    Type,
    Union,
)

import attr
from typing_extensions import Protocol, runtime_checkable

from .attributes import attrib, attrs

# Public API of this module: the names exported by `from ... import *`.
__all__ = [
    "observable",
    "Observable",
    "ObservableCollection",
    "Event",
    "EventType",
    "Observer",
    "uuid_attrib",
]


def uuid_str() -> str:
    """Return a new random (version 4) UUID rendered as a string.

    The text form is the standard hyphenated 36-character representation
    produced by ``str(uuid.UUID)``.
    """
    random_id = uuid.uuid4()
    return str(random_id)
# Field factory for a UUID-string attribute: the default value comes from
# `uuid_str` (a fresh random UUID per instance), and `attr.setters.frozen` is
# passed as the on_setattr hook so the field cannot be reassigned after
# construction. NOTE(review): assumes the project-local `attrib` forwards
# these keyword arguments to `attr.ib` — confirm in `.attributes`.
uuid_attrib = partial(attrib, factory=uuid_str, on_setattr=attr.setters.frozen)
class EventType(Enum):
    """Kinds of change an observable can report in an `Event`."""

    # An attribute/element was assigned a value equal to the one it held.
    REASSIGN = auto()
    # An attribute was assigned a value different from the one it held.
    CHANGE = auto()
    # An element was added to a collection.
    ADD = auto()
    # An element was removed from a collection.
    DISCARD = auto()
@attrs
class Event:
    """Describes the cause of the observable event."""

    # Slash-separated path from `obj` down to the attribute/element that changed.
    attr_name: str
    # Observable reporting the event: the emitter itself, or an ancestor that
    # forwarded it (in which case `attr_name` was prefixed with the child's path).
    obj: "Observable"
    # What kind of change happened.
    event_type: EventType

    def get_nested_object(self, level: Optional[int] = None):
        """Walk from ``obj`` along the attribute path and return what is reached.

        Args:
            level: How many path segments to follow, interpreted like a slice
                end: ``None`` follows the whole path and negative values stop
                that many segments before the end.

        Returns:
            ``obj`` subscripted (``[]``) with each selected segment in turn;
            ``obj`` itself when no segment is selected.
        """
        current = self.obj
        for segment in self.split_attr_path[:level]:
            current = current[segment]
        return current

    @property
    def split_attr_path(self):
        """Path segments obtained by splitting ``attr_name`` on ``"/"``."""
        return self.attr_name.split("/")
# ToDo: replace `Optional[Any]` with `None` after # [this PR](https://github.com/bloomberg/attrs-strict/pull/66) # gets merged. # FIXME: attrs_strict won't let partial function pass if we specify Callable prototype ObserverCallable = Callable # [[Event], Optional[Any]]
@attrs
class Observer:
    """Contains information about a subscribed observer."""

    # Called with the `Event` whenever an event matches this observer.
    callback: ObserverCallable
    # Compiled pattern that must fully match the event's `attr_name`.
    regexp: Pattern
    # Sort key: the owner's observer list is kept in ascending `order`,
    # which is also the order in which callbacks are invoked.
    order: int
    # Event kinds this observer receives.
    event_types: Set[EventType]
@runtime_checkable
class Observable(Protocol):
    """Common protocol/interface for all observable classes.

    An observable holds a ``uuid`` and a list of subscribed observers. It emits
    `Event`s to those observers and then hands every emitted event to its
    owner's forwarding function (if any), so changes bubble up a tree of
    observables.
    """

    uuid: str
    __observers__: List[Observer]
    _owner_handler: Optional[Callable] = None

    @property
    def __owner_handler__(self) -> Optional[Callable]:
        """Access parent's event forwarding function. Needed for event propagation."""
        return self._owner_handler

    @__owner_handler__.setter
    def __owner_handler__(self, value: Optional[Callable]) -> None:
        """Attach (a callable) or detach (``None``) the parent's forwarding function.

        Raises:
            AttributeError: when attaching while another owner is still
                attached; the current owner must assign ``None`` first.
        """
        if value is not None and self._owner_handler is not None:
            raise AttributeError("Observable object is owned by somebody else!")
        self._owner_handler = value

    def __forward_event__(self, event: Event, child_path):
        """Emit event received from child adjusting event's fields.

        The child's ``attr_name`` is prefixed with ``child_path`` (the child's
        name/key inside this object) and ``obj`` is replaced by ``self``.
        """
        self.__emit_event__(
            Event(
                attr_name=f"{child_path}/{event.attr_name}",
                obj=self,
                event_type=event.event_type,
            )
        )

    def __emit_event__(self, event: Event) -> None:
        """Call all matching subscribers' callbacks, then forward to the owner.

        An observer matches when its regexp fully matches ``event.attr_name``
        and ``event.event_type`` is among its event types. Observers are called
        in list order, i.e. ascending ``order``.
        """
        # Iterate over a snapshot: a callback may call `subscribe`, which
        # inserts into the list; an insertion before the current position would
        # shift it and make the loop call the same observer a second time.
        for observer in tuple(self.__observers__):
            if (
                observer.regexp.fullmatch(event.attr_name)
                and event.event_type in observer.event_types
            ):
                observer.callback(event)
        owner_handler = self.__owner_handler__
        if owner_handler is not None:
            owner_handler(event)

    def subscribe(
        self,
        callback,
        regexp: str = ".*",
        order: int = 0,
        event_types: Optional[Iterable[EventType]] = None,
    ) -> None:
        """Subscribe to changes of Observable object.

        Args:
            callback: Called with the `Event` for every matching event.
            regexp: Pattern that must fully match the event's ``attr_name``
                (a slash-separated path); the default matches everything.
            order: Observers with a lower ``order`` are called first; equal
                orders keep subscription order.
            event_types: Event kinds to receive; ``None`` means all four kinds.
        """
        if event_types is None:
            event_types = {
                EventType.REASSIGN,
                EventType.CHANGE,
                EventType.ADD,
                EventType.DISCARD,
            }
        subscriber = Observer(
            callback=callback,
            regexp=re.compile(regexp),
            order=order,
            event_types=set(event_types),
        )
        # `bisect` (right) keeps the list sorted and places ties after existing ones.
        position = bisect([el.order for el in self.__observers__], subscriber.order)
        self.__observers__.insert(position, subscriber)

    def drop_observers(self):
        """Drop list of tracked observers, as if no subscription were ever made."""
        self.__observers__ = []
def wrap_init(old_init: Callable) -> Callable:
    """Wrap an ``__init__`` so that new instances start out observable.

    The returned initializer gives the instance an empty observer list, runs
    the original initializer, and then registers the instance as owner of
    every attribute named in ``__observable_attrs__`` so that those children
    forward their events through ``__forward_event__``.
    """

    @wraps(old_init)
    def new_init(self, *args, **kwargs) -> None:
        self.__observers__ = []
        old_init(self, *args, **kwargs)
        for attr_name in self.__observable_attrs__:
            child = getattr(self, attr_name)
            child.__owner_handler__ = partial(
                self.__forward_event__, child_path=attr_name
            )

    return new_init
def wrap_setattr(old_setattr: Callable) -> Callable:
    """Wrap ``__setattr__`` so that assignments to attrs fields emit events.

    Names not in ``__attrs_names__`` are delegated to ``old_setattr`` untouched.
    For fields, the returned setter emits ``REASSIGN`` when the new value equals
    the old one and ``CHANGE`` otherwise. For fields listed in
    ``__observable_attrs__`` it also moves ownership from the old child to the
    new one so that the new child's events are forwarded.
    """

    @wraps(old_setattr)
    def observable_setattr(self, name, value) -> None:
        if name not in self.__attrs_names__:
            old_setattr(self, name, value)
            return
        old_value = getattr(self, name)
        is_observable = name in self.__observable_attrs__
        old_handler = None
        if is_observable:
            # Detach the old child before assigning (the new value may be the
            # very same object, which must be detachable before re-attaching).
            old_handler = old_value.__owner_handler__
            old_value.__owner_handler__ = None
        try:
            old_setattr(self, name, value)
        except BaseException:
            # The assignment did not happen (e.g. a validator or on_setattr
            # hook raised): re-attach the old child so it keeps forwarding
            # its events to us, then propagate the error.
            if is_observable:
                old_value.__owner_handler__ = old_handler
            raise
        new_value = getattr(self, name)
        if old_value == new_value:
            event_type = EventType.REASSIGN
        else:
            event_type = EventType.CHANGE
        self.__emit_event__(Event(attr_name=name, obj=self, event_type=event_type))
        if is_observable:
            new_value.__owner_handler__ = partial(
                self.__forward_event__, child_path=name
            )

    return observable_setattr
def observable(cls: type) -> Union[Type[Observable], type]:
    """Make the attributes of an attrs-style `Observable` class also observable.

    The class must list `Observable` directly among its bases. The decorator:

    * adds ``obj[name]`` access to attrs fields (used when walking event paths),
    * records the attrs field names and which of them are typed as direct
      `Observable` subclasses,
    * wraps ``__init__`` and ``__setattr__`` so that assignments emit events
      and child observables forward their events to the instance.

    Returns:
        The same class, modified in place.
    """
    assert Observable in cls.__bases__

    def observable_getitem(self, name):
        # Item access by field name, so `Event.get_nested_object` can walk paths.
        if name in self.__attrs_names__:
            return getattr(self, name)
        raise AttributeError(name)

    def is_observable_type(field_type) -> bool:
        # Untyped fields (type None), typing generics such as `List[int]` or
        # `Optional[X]`, and string forward references have no `__bases__`;
        # treat them as non-observable instead of crashing the decorator.
        return Observable in getattr(field_type, "__bases__", ())

    cls.__getitem__ = observable_getitem
    cls.__observable_attrs__ = frozenset(
        field.name for field in cls.__attrs_attrs__ if is_observable_type(field.type)
    )
    cls.__attrs_names__ = frozenset(field.name for field in cls.__attrs_attrs__)
    cls.__init__ = wrap_init(cls.__init__)
    cls.__setattr__ = wrap_setattr(cls.__setattr__)
    return cls
# pylint: disable=too-many-ancestors
class ObservableCollection(MutableSet, Mapping, Observable):
    """Observable collection of Observable objects.

    Elements are keyed by their ``uuid``. Iteration and ``[]`` use UUID strings
    as keys, while ``add``/``discard`` take the elements themselves. The
    collection becomes the owner of each element it holds, so element events
    are re-emitted with the element's UUID as the first path segment.
    """

    def __init__(self, init_vals: Optional[Iterable[Observable]] = None):
        """Construct collection and populate it from init_vals, if any."""
        self._storage: Dict[str, Observable] = {}
        self.__observers__: List[Observer] = []
        if init_vals is not None:
            for elem in init_vals:
                self.add(elem)

    def add(self, value: Observable) -> None:
        """Add value to collection, overwriting old value with same UUID, if any.

        Emits ``REASSIGN`` when the stored element equals ``value``. Otherwise
        emits ``ADD``, preceded by ``DISCARD`` when a non-equal element with the
        same UUID is replaced.
        """
        was_already_in = value.uuid in self._storage
        if was_already_in:
            # Detach the stored element so that `value` (possibly the very same
            # object) can be attached below without an ownership conflict.
            self._storage[value.uuid].__owner_handler__ = None
            if self._storage[value.uuid] == value:
                event_type = EventType.REASSIGN
            else:
                self.discard(value)
                event_type = EventType.ADD
        else:
            event_type = EventType.ADD
        self._storage[value.uuid] = value
        value.__owner_handler__ = partial(self.__forward_event__, child_path=value.uuid)
        self.__emit_event__(
            Event(attr_name=value.uuid, obj=self, event_type=event_type)
        )

    def discard(self, value: Observable) -> None:
        """Discard object with same UUID from collection (no-op if absent)."""
        if value.uuid in self._storage:
            self._storage[value.uuid].__owner_handler__ = None
            self._storage.pop(value.uuid)
            self.__emit_event__(
                Event(attr_name=value.uuid, obj=self, event_type=EventType.DISCARD)
            )

    def __contains__(self, x: Union[str, Observable]) -> bool:
        """Check if an object with the same UUID (or the given UUID key) is in collection.

        Accepting UUID strings keeps membership consistent with iteration,
        which yields keys: the inherited `Set`/`Mapping` mixins (``<=``,
        ``keys()``, ...) test membership with those keys.
        """
        key = x if isinstance(x, str) else x.uuid
        return key in self._storage

    def __eq__(self, other):
        """Compare with another collection by UUID keys and element values.

        Other operands fall back to the inherited `Set` comparison.
        """
        if isinstance(other, ObservableCollection):
            return self._storage == other._storage
        return super().__eq__(other)

    def __getitem__(self, k: str):
        """Get object by its UUID."""
        return self._storage[k]

    def __len__(self) -> int:
        """Return number of objects in collection."""
        return len(self._storage)

    def __iter__(self) -> Iterator[str]:
        """Return iterator over the UUIDs (keys) of objects in collection."""
        return iter(self._storage.keys())

    def __repr__(self):
        """Return string representation of collection and its contents.

        Can be `eval`ed to get equal collection.
        """
        return (
            f"ObservableCollection(["
            f"{', '.join(repr(elem) for elem in self._storage.values())}"
            f"])"
        )