Source code for py_gql.schema.schema_visitor

# -*- coding: utf-8 -*-

from typing import Dict, Optional, TypeVar

from .._utils import deprecated, map_and_filter
from .directives import SPECIFIED_DIRECTIVES
from .scalars import SPECIFIED_SCALAR_TYPES
from .schema import Schema
from .types import (
    Argument,
    Directive,
    EnumType,
    EnumValue,
    Field,
    InputField,
    InputObjectType,
    InterfaceType,
    NamedType,
    ObjectType,
    ScalarType,
    UnionType,
)


# Generic type variable whose upper bound is ``type``.
# NOTE(review): not referenced anywhere in the visible part of this module.
TType = TypeVar("TType", bound=type)


class SchemaVisitor(object):
    """
    Walk a schema and optionally rewrite the elements encountered.

    Customise the traversal by subclassing and overriding the relevant
    ``on_*`` hooks. The default hooks are what descend into child elements
    (fields, arguments, enum values, ...), so overrides should normally
    delegate to the superclass implementation.

    Every hook *must* return the (possibly modified) element. Returning
    ``None`` removes the element from its context; for instance, ``None``
    from :meth:`on_field` drops the field from the owning
    :class:`py_gql.schema.ObjectType`.

    Specified types (scalars, introspection) and specified directives are
    skipped.
    """

    def on_schema(self, schema: Schema) -> Schema:
        """
        Visit every non-specified type and directive of ``schema``.

        Each type is routed to the hook matching its kind and each directive
        to :meth:`on_directive`. Whenever a hook returns something other than
        the very object it was given (including ``None``), the result is
        recorded under the original name. The recorded changes are then
        handed to ``schema._replace_types_and_directives`` and the same
        ``schema`` object is returned.

        Raises:
            TypeError: if a type matches none of the known kinds.
        """
        # Checked in order, mirroring an isinstance chain. Hook names are
        # resolved on ``self`` only once a kind has matched.
        hooks_by_kind = (
            (ObjectType, "on_object"),
            (InterfaceType, "on_interface"),
            (InputObjectType, "on_input_object"),
            (ScalarType, "on_scalar"),
            (UnionType, "on_union"),
            (EnumType, "on_enum"),
        )

        type_changes = {}  # type: Dict[str, Optional[NamedType]]
        directive_changes = {}  # type: Dict[str, Optional[Directive]]

        for current_type in schema.types.values():
            # Introspection types ("__"-prefixed) and specified scalars are
            # never visited.
            if (
                current_type.name.startswith("__")
                or current_type in SPECIFIED_SCALAR_TYPES
            ):
                continue

            for kind, hook_name in hooks_by_kind:
                if isinstance(current_type, kind):
                    replacement = getattr(self, hook_name)(
                        current_type
                    )  # type: Optional[NamedType]
                    break
            else:
                raise TypeError(type(current_type))

            # Identity, not equality: only a different object (or None)
            # counts as a change.
            if replacement is not current_type:
                type_changes[current_type.name] = replacement

        for current_directive in schema.directives.values():
            if current_directive in SPECIFIED_DIRECTIVES:
                continue

            new_directive = self.on_directive(current_directive)
            if new_directive is not current_directive:
                directive_changes[current_directive.name] = new_directive

        schema._replace_types_and_directives(type_changes, directive_changes)
        return schema

    def on_scalar(self, scalar_type: ScalarType) -> Optional[ScalarType]:
        """Hook for scalar types; the default returns the type untouched."""
        return scalar_type

    def on_object(self, object_type: ObjectType) -> Optional[ObjectType]:
        """
        Hook for object types.

        Passes the type's fields through :meth:`on_field` (via
        ``map_and_filter``). If the outcome differs from the current fields a
        new ``ObjectType`` is built around it, copying the remaining
        attributes; otherwise the original instance is returned.
        """
        visited_fields = map_and_filter(self.on_field, object_type.fields)
        fields_changed = visited_fields != object_type.fields
        if not fields_changed:
            return object_type

        return ObjectType(
            object_type.name,
            visited_fields,
            interfaces=object_type.interfaces,
            default_resolver=object_type.default_resolver,
            description=object_type.description,
            nodes=object_type.nodes,
        )

    def on_field(self, field: Field) -> Optional[Field]:
        """
        Hook for object and interface fields.

        Passes the field's arguments through :meth:`on_argument` (via
        ``map_and_filter``). If the outcome differs from the current
        arguments a new ``Field`` is built around it, copying the remaining
        attributes; otherwise the original instance is returned.
        """
        visited_args = map_and_filter(self.on_argument, field.arguments)
        args_changed = visited_args != field.arguments
        if not args_changed:
            return field

        return Field(
            field.name,
            field.type,
            args=visited_args,
            description=field.description,
            deprecation_reason=field.deprecation_reason,
            resolver=field.resolver,
            subscription_resolver=field.subscription_resolver,
            node=field.node,
            python_name=field.python_name,
        )

    def on_argument(self, argument: Argument) -> Optional[Argument]:
        """Hook for arguments; the default returns the argument untouched."""
        return argument

    def on_interface(
        self, interface_type: InterfaceType
    ) -> Optional[InterfaceType]:
        """
        Hook for interface types.

        Passes the interface's fields through :meth:`on_field` (via
        ``map_and_filter``). If the outcome differs from the current fields a
        new ``InterfaceType`` is built around it, copying the remaining
        attributes; otherwise the original instance is returned.
        """
        visited_fields = map_and_filter(self.on_field, interface_type.fields)
        fields_changed = visited_fields != interface_type.fields
        if not fields_changed:
            return interface_type

        return InterfaceType(
            interface_type.name,
            visited_fields,
            resolve_type=interface_type.resolve_type,
            description=interface_type.description,
            nodes=interface_type.nodes,
        )

    def on_union(self, union_type: UnionType) -> Optional[UnionType]:
        """Hook for union types; the default returns the type untouched."""
        return union_type

    def on_enum(self, enum_type: EnumType) -> Optional[EnumType]:
        """
        Hook for enum types.

        Passes the enum's values through :meth:`on_enum_value` (via
        ``map_and_filter``). If the outcome differs from the current values a
        new ``EnumType`` is built around it, copying the remaining
        attributes; otherwise the original instance is returned.
        """
        visited_values = map_and_filter(self.on_enum_value, enum_type.values)
        values_changed = visited_values != enum_type.values
        if not values_changed:
            return enum_type

        return EnumType(
            enum_type.name,
            values=visited_values,
            description=enum_type.description,
            nodes=enum_type.nodes,
        )

    def on_enum_value(self, enum_value: EnumValue) -> Optional[EnumValue]:
        """Hook for enum values; the default returns the value untouched."""
        return enum_value

    def on_input_object(
        self, input_object_type: InputObjectType
    ) -> Optional[InputObjectType]:
        """
        Hook for input object types.

        Passes the type's fields through :meth:`on_input_field` (via
        ``map_and_filter``). If the outcome differs from the current fields a
        new ``InputObjectType`` is built around it, copying the remaining
        attributes; otherwise the original instance is returned.
        """
        visited_fields = map_and_filter(
            self.on_input_field, input_object_type.fields
        )
        fields_changed = visited_fields != input_object_type.fields
        if not fields_changed:
            return input_object_type

        return InputObjectType(
            input_object_type.name,
            visited_fields,
            description=input_object_type.description,
            nodes=input_object_type.nodes,
        )

    def on_input_field(self, field: InputField) -> Optional[InputField]:
        """Hook for input fields; the default returns the field untouched."""
        return field

    def on_directive(self, directive: Directive) -> Optional[Directive]:
        """
        Hook for directive definitions.

        Passes the directive's arguments through :meth:`on_argument` (via
        ``map_and_filter``). If the outcome differs from the current
        arguments a new ``Directive`` is built around it, copying the
        remaining attributes; otherwise the original instance is returned.
        """
        visited_args = map_and_filter(self.on_argument, directive.arguments)
        args_changed = visited_args != directive.arguments
        if not args_changed:
            return directive

        return Directive(
            directive.name,
            directive.locations,
            args=visited_args,
            description=directive.description,
            node=directive.node,
        )

    # Legacy hook names, kept as deprecated wrappers around the functions
    # defined above in this class body.
    on_field_definition = deprecated(
        "This method has been deprecated, use on_field instead."
    )(on_field)

    on_input_field_definition = deprecated(
        "This method has been deprecated, use on_input_field instead."
    )(on_input_field)

    on_argument_definition = deprecated(
        "This method has been deprecated, use on_argument instead."
    )(on_argument)