Source code for py_gql.sdl.schema_directives

# -*- coding: utf-8 -*-
"""
Schema Directives

This is largely based on the way Apollo and graphql-tools implement it,
borrowing the same idea of using visitors and treating the schema as graph.
"""

from typing import (
    Dict,
    Iterator,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
)

from .._utils import flatten
from ..exc import SDLError
from ..lang import ast as _ast
from ..schema import (
    SPECIFIED_DIRECTIVES,
    SPECIFIED_SCALAR_TYPES,
    Argument,
    Directive,
    EnumType,
    EnumValue,
    Field,
    InputField,
    InputObjectType,
    InterfaceType,
    ObjectType,
    ScalarType,
    Schema,
    SchemaVisitor,
    UnionType,
)
from ..utilities import coerce_argument_values


__all__ = ("SchemaDirective", "apply_schema_directives")


SPECIFIED_DIRECTIVE_NAMES = [x.name for x in SPECIFIED_DIRECTIVES]


T = TypeVar("T")
TType = TypeVar("TType", bound=type)

_HasDirectives = Union[
    Argument,
    EnumType,
    EnumValue,
    Field,
    InputField,
    InputObjectType,
    InterfaceType,
    ObjectType,
    ScalarType,
    Schema,
    UnionType,
]

TSchemaDirective = Type["SchemaDirective"]


def _find_directives(definition: _HasDirectives) -> List[_ast.Directive]:
    if isinstance(definition, (Field, Argument, InputField, EnumValue)):
        return definition.node.directives if definition.node is not None else []
    else:
        return list(flatten(n.directives for n in definition.nodes if n))


[docs]class SchemaDirective(SchemaVisitor): """ @directive implementation for use alongside :func:`py_gql.schema.build_schema`. You need to subclass this in order to define your own custom directives. All valid directive locations have a corresponding `on_X` method to implement from :class:`~py_gql.schema.SchemaVisitor`. The definition attributes defines how the definition will be found at runtime. A `Directive` object defines the directive inline, while a string delegates to the schema at build time by name, in which case the directive must be part of the schema it's applied to. """ definition = NotImplemented # type: Union[Directive, str] def __init__(self, args=None): self.args = args or {}
[docs]def apply_schema_directives( schema: Schema, schema_directives: Sequence[TSchemaDirective] ) -> Schema: """ Apply :class:`~py_gql.schema.SchemaDirective` implementers to a given schema. This assumes the provided schema was built from a GraphQL document and contains references to the parse node which contains the actual directive information. Each directive will be instantiated with the arguments extracted from the parse nodes (which is why we need to provide a class here and not an instance of :class:`~py_gql.schema.SchemaDirective`). Warning: Specified types (scalars, introspection) cannot be modified through schema directives. Args: schema: Schema to modify schema_directives: List of schema directives (`~py_gql.schema.SchemaDirective`). Each directive must implement the `definition` attribute. Returns: Modified schema. """ return _SchemaDirectivesApplicationVisitor( schema_directives, schema.directives ).on_schema(schema)
class _SchemaDirectivesApplicationVisitor(SchemaVisitor): def __init__( self, schema_directives: Sequence[TSchemaDirective], directives: Dict[str, Directive], ): self._defs = {} # type: Dict[str, Tuple[Directive, TSchemaDirective]] for sd in schema_directives: if not isinstance(sd, type) or not issubclass(sd, SchemaDirective): raise TypeError( 'Expected SchemaDirective subclass but got "%r"' % sd ) if isinstance(sd.definition, str): try: self._defs[sd.definition] = directives[sd.definition], sd except KeyError: raise SDLError( "Unknown schema directive %s.\n" "The definition attribute must either be an explicit " "Directive instance or a string. When using a string, a " "directive with that name must be present in the schema." % sd.definition ) else: self._defs[sd.definition.name] = sd.definition, sd def _collect_schema_directives( self, definition: _HasDirectives, loc: str ) -> Iterator[SchemaDirective]: applied = set() # type: Set[str] for node in _find_directives(definition): name = node.name.value if name in SPECIFIED_DIRECTIVE_NAMES: continue try: directive_def, schema_directive_cls = self._defs[name] except KeyError: raise SDLError('Unknown directive "@%s"' % name, [node]) if loc not in directive_def.locations: raise SDLError( 'Directive "@%s" not applicable to "%s"' % (name, loc), [node], ) if name in applied: raise SDLError('Directive "@%s" already applied' % name, [node]) args = coerce_argument_values(directive_def, node) applied.add(name) yield schema_directive_cls(args) def on_schema(self, schema: Schema) -> Schema: # Make sure the schema has all the definitions. schema.directives.update({n: d for n, (d, _) in self._defs.items()}) for sd in self._collect_schema_directives(schema, "SCHEMA"): schema = sd.on_schema(schema) return super().on_schema(schema) def on_scalar(self, scalar: ScalarType) -> Optional[ScalarType]: if scalar in SPECIFIED_SCALAR_TYPES: return scalar for sd in self._collect_schema_directives(scalar, "SCALAR"): scalar = sd.on_scalar(scalar) # type: ignore if scalar is None: return None return super().on_scalar(scalar) def on_object(self, object_type: ObjectType) -> Optional[ObjectType]: for sd in self._collect_schema_directives(object_type, "OBJECT"): object_type = sd.on_object(object_type) # type: ignore if object_type is None: return object_type return super().on_object(object_type) def on_field(self, field: Field) -> Optional[Field]: for sd in self._collect_schema_directives(field, "FIELD_DEFINITION"): field = sd.on_field(field) # type: ignore if field is None: return None return super().on_field(field) def on_argument(self, arg: Argument) -> Optional[Argument]: for sd in self._collect_schema_directives(arg, "ARGUMENT_DEFINITION"): arg = sd.on_argument(arg) # type: ignore if arg is None: return None return super().on_argument(arg) def on_interface(self, interface: InterfaceType) -> Optional[InterfaceType]: for sd in self._collect_schema_directives(interface, "INTERFACE"): interface = sd.on_interface(interface) # type: ignore if interface is None: return None return super().on_interface(interface) def on_union(self, union: UnionType) -> Optional[UnionType]: for sd in self._collect_schema_directives(union, "UNION"): union = sd.on_union(union) # type: ignore if union is None: return None return super().on_union(union) def on_enum(self, enum: EnumType) -> Optional[EnumType]: for sd in self._collect_schema_directives(enum, "ENUM"): enum = sd.on_enum(enum) # type: ignore if enum is None: return None return super().on_enum(enum) def on_enum_value(self, enum_value: EnumValue) -> Optional[EnumValue]: for sd in self._collect_schema_directives(enum_value, "ENUM_VALUE"): enum_value = sd.on_enum_value(enum_value) # type: ignore if enum_value is None: return None return super().on_enum_value(enum_value) def on_input_object( self, input_object: InputObjectType ) -> Optional[InputObjectType]: for sd in self._collect_schema_directives(input_object, "INPUT_OBJECT"): input_object = sd.on_input_object(input_object) # type: ignore if input_object is None: return None return super().on_input_object(input_object) def on_input_field(self, field: InputField) -> Optional[InputField]: for sd in self._collect_schema_directives( field, "INPUT_FIELD_DEFINITION" ): field = sd.on_input_field(field) # type: ignore if field is None: return None return super().on_input_field(field)