# Source code for py_gql.schema.schema

# -*- coding: utf-8 -*-

import copy
from collections import defaultdict
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Sequence,
    Union,
)

from ..exc import SchemaError, UnknownType
from ..lang import ast as _ast
from .directives import SPECIFIED_DIRECTIVES
from .introspection import INTROPSPECTION_TYPES
from .resolver_map import ResolverMap
from .scalars import SPECIFIED_SCALAR_TYPES
from .types import (
    Directive,
    GraphQLAbstractType,
    GraphQLType,
    InputObjectType,
    InterfaceType,
    ListType,
    NamedType,
    NonNullType,
    ObjectType,
    UnionType,
    unwrap_type,
)
from .validation import validate_schema


# Names of the built-in directives. `_build_directive_map` rejects any
# user-supplied directive that reuses one of these names with a different
# instance.
_SPECIFIED_DIRECTIVE_NAMES = [
    directive.name for directive in SPECIFIED_DIRECTIVES
]

# Built-in types (specified scalars + introspection types) which
# `Schema._replace_types_and_directives` refuses to replace or remove.
_PROTECTED_TYPES = SPECIFIED_SCALAR_TYPES + INTROPSPECTION_TYPES


# Resolvers can accept any arguments and return anything.
Resolver = Callable[..., Any]


class Schema(ResolverMap):
    """
    A GraphQL schema definition.

    This is the main container for a GraphQL schema and its related types.

    Args:
        query_type: The root query type for the schema
        mutation_type: The root mutation type for the schema
        subscription_type: The root subscription type for the schema

        directives: List of possible directives to use.
            The default, specified directives (e.g. ``@include``, ``@skip``)
            will **always** be included.

        types: List of additional supported types.
            This is only necessary for types that cannot be inferred by
            traversing the root types.

        nodes: AST node for the schema if applicable, i.e. when creating
            the schema from a GraphQL (SDL) document.

    Attributes:
        query_type (Optional[ObjectType]):
            The root query type for the schema (required).

        mutation_type (Optional[ObjectType]):
            The root mutation type for the schema (optional).

        subscription_type (Optional[ObjectType]):
            The root subscription type for the schema (optional).

        nodes (List[Union[_ast.SchemaDefinition, _ast.SchemaExtension]]):
            AST node for the schema if applicable, i.e. when creating the
            schema from a GraphQL (SDL) document.

        types (Dict[str, NamedType]):
            Mapping ``type name -> Type instance`` of all types used in the
            schema, excluding directives.

        directives (Dict[str, py_gql.schema.Directive]):
            Mapping ``directive name -> Directive instance`` of all
            directives used in the schema.

        implementations (Dict[str, List[ObjectType]]):
            Mapping of ``interface name -> [implementing object types]``.
    """

    __slots__ = (
        "query_type",
        "mutation_type",
        "subscription_type",
        "nodes",
        "_possible_types",
        "_is_valid",
        "_literal_types_cache",
        "types",
        "directives",
        "implementations",
        "resolvers",
        "subscriptions",
        "default_resolver",
        "default_resolvers",
    )

    def __init__(
        self,
        query_type: Optional[ObjectType] = None,
        mutation_type: Optional[ObjectType] = None,
        subscription_type: Optional[ObjectType] = None,
        directives: Optional[List[Directive]] = None,
        types: Optional[List[NamedType]] = None,
        nodes: Optional[
            List[Union[_ast.SchemaDefinition, _ast.SchemaExtension]]
        ] = None,
    ):
        super().__init__()
        self.query_type = query_type
        self.mutation_type = mutation_type
        self.subscription_type = subscription_type

        self.nodes = (
            nodes or []
        )  # type: List[Union[_ast.SchemaDefinition, _ast.SchemaExtension]]

        self.directives = _build_directive_map(directives or [])

        # Start from the built-in types, then add everything reachable from
        # the explicit types, the root types and the directive arguments.
        self.types = _build_type_map(
            [*(types or []), query_type, mutation_type, subscription_type],
            self.directives.values(),
            _type_map=_default_type_map(),
        )  # type: Dict[str, NamedType]

        self._invalidate_and_rebuild_caches()

    def _invalidate_and_rebuild_caches(self):
        """
        Reset every derived cache and recompute ``implementations``.

        Must be called whenever the content of ``self.types`` changes.
        """
        self._possible_types = (
            {}
        )  # type: Dict[GraphQLAbstractType, Sequence[ObjectType]]
        self._is_valid = None  # type: Optional[bool]
        self._literal_types_cache = {}  # type: Dict[_ast.Type, GraphQLType]

        self.implementations = defaultdict(
            list
        )  # type: Dict[str, List[ObjectType]]

        for type_ in self.types.values():
            if isinstance(type_, ObjectType):
                for i in type_.interfaces:
                    self.implementations[i.name].append(type_)

    def _replace_types_and_directives(
        self,
        types: Optional[Dict[str, Optional[NamedType]]] = None,
        directives: Optional[Dict[str, Optional[Directive]]] = None,
    ) -> None:
        """
        Replace or remove existing types and directives in place.

        Only names already present in the schema are handled; entries for
        unknown names are ignored. Mapping a name to ``None`` removes the
        corresponding type / directive.

        Args:
            types: Mapping ``type name -> replacement type or None``.
            directives: Mapping ``directive name -> replacement directive
                or None``.

        Raises:
            SchemaError: When trying to replace / remove a specified type or
                directive, or to replace a type with one of a different kind.
        """
        busted_cache = False
        directives_changed = False

        for type_name, new_type in (types or {}).items():
            try:
                original_type = self.types[type_name]
            except KeyError:
                continue

            if original_type in _PROTECTED_TYPES:
                raise SchemaError(
                    "Cannot replace specified type %s" % original_type
                )

            # Accumulate rather than assign: any single changed entry is
            # enough to require fixing references and rebuilding caches.
            if new_type != original_type:
                busted_cache = True

            if new_type is None:
                del self.types[type_name]
            else:
                if type(original_type) != type(new_type):
                    raise SchemaError(
                        "Cannot replace type %r with a different kind of type %r."
                        % (original_type, new_type)
                    )
                self.types[type_name] = new_type

        for directive_name, new_directive in (directives or {}).items():
            try:
                original_directive = self.directives[directive_name]
            except KeyError:
                continue

            if original_directive in SPECIFIED_DIRECTIVES:
                raise SchemaError(
                    "Cannot replace specified directive %s" % original_directive
                )

            if new_directive != original_directive:
                directives_changed = True

            if new_directive is None:
                del self.directives[directive_name]
            else:
                self.directives[directive_name] = new_directive

        # We can safely ignore the potential type error given that if the type
        # has been replaced we have checked it matches its old kind above.
        self.query_type = (
            self.types.get(self.query_type.name)  # type: ignore
            if self.query_type
            else None
        )
        self.mutation_type = (
            self.types.get(self.mutation_type.name)  # type: ignore
            if self.mutation_type
            else None
        )
        self.subscription_type = (
            self.types.get(self.subscription_type.name)  # type: ignore
            if self.subscription_type
            else None
        )

        if busted_cache:
            # Circular import
            from .fix_type_references import fix_type_references

            fix_type_references(self)
            self._invalidate_and_rebuild_caches()
        elif directives_changed:
            # Directive changes leave the type-derived caches intact but can
            # change the outcome of validation (mirrors the register_* methods).
            self._is_valid = None

    def validate(self):
        """
        Check that the schema is valid.

        The result is memoized until the schema is next mutated.

        Raises:
            :class:`~py_gql.exc.SchemaError` if the schema is invalid.
        """
        if self._is_valid is None:
            validate_schema(self)
            self._is_valid = True

    def get_type(self, name: str) -> NamedType:
        """
        Get a type by name.

        Args:
            name: Requested type name

        Returns:
            py_gql.schema.NamedType: Type instance

        Raises:
            UnknownType: if the type is not found.
        """
        try:
            return self.types[name]
        except KeyError:
            raise UnknownType(name)

    def has_type(self, name: str) -> bool:
        """
        Check if the schema contains a type with the given name.
        """
        return name in self.types

    def get_type_from_literal(self, ast_node: _ast.Type) -> GraphQLType:
        """
        Return a :class:`py_gql.schema.Type` instance for an AST type node.

        For example, if provided the parsed AST node for ``[User]``,
        a :class:`py_gql.schema.ListType` instance will be returned, containing
        the type called ``User`` found in the schema.
        If a type called ``User`` is not found in the schema, then
        :class:`~py_gql.exc.UnknownType` will be raised.

        Results are cached per AST node.

        Raises:
            :class:`~py_gql.exc.UnknownType`: if any named type is not found
            TypeError: if ``ast_node`` is not a type node.
        """
        if ast_node in self._literal_types_cache:
            return self._literal_types_cache[ast_node]

        if isinstance(ast_node, _ast.ListType):
            t1 = ListType(self.get_type_from_literal(ast_node.type))
            self._literal_types_cache[ast_node] = t1
            return t1
        elif isinstance(ast_node, _ast.NonNullType):
            t2 = NonNullType(self.get_type_from_literal(ast_node.type))
            self._literal_types_cache[ast_node] = t2
            return t2
        elif isinstance(ast_node, _ast.NamedType):
            t3 = self.get_type(ast_node.name.value)
            self._literal_types_cache[ast_node] = t3
            return t3
        raise TypeError("Invalid type node %r" % ast_node)

    def get_possible_types(
        self, abstract_type: GraphQLAbstractType
    ) -> Sequence[ObjectType]:
        """
        Get the possible implementations of an abstract type.

        Args:
            abstract_type: Abstract type to check.

        Raises:
            TypeError: when the input type is not an abstract type.

        Returns:
            List of possible types.
        """
        if abstract_type in self._possible_types:
            return self._possible_types[abstract_type]

        if isinstance(abstract_type, UnionType):
            self._possible_types[abstract_type] = abstract_type.types or []
            return self._possible_types[abstract_type]
        elif isinstance(abstract_type, InterfaceType):
            self._possible_types[abstract_type] = self.implementations.get(
                abstract_type.name, []
            )
            return self._possible_types[abstract_type]

        raise TypeError("Not an abstract type: %s" % abstract_type)

    def is_possible_type(
        self, abstract_type: GraphQLAbstractType, type_: GraphQLType
    ) -> bool:
        """
        Check that ``type_`` is a possible realization of ``abstract_type``.

        Returns:
            ``True`` if ``type_`` is valid for ``abstract_type``
        """
        if not isinstance(type_, ObjectType):
            return False

        return type_ in self.get_possible_types(abstract_type)

    def is_subtype(self, type_, super_type):
        """
        Check if a type is either equal or a subset of a super type (covariant).

        Args:
            type_ (py_gql.schema.Type): Target type.
            super_type (py_gql.schema.Type): Super type.

        Returns:
            bool:
        """
        if type_ == super_type:
            return True

        if (
            isinstance(type_, (ListType, NonNullType))
            and isinstance(super_type, (ListType, NonNullType))
            and type(type_) == type(super_type)
        ):
            return self.is_subtype(type_.type, super_type.type)

        # A non-null type is a subtype of its nullable counterpart.
        if isinstance(type_, NonNullType):
            return self.is_subtype(type_.type, super_type)

        if isinstance(type_, ListType):
            return False

        return (
            isinstance(super_type, GraphQLAbstractType)
            and isinstance(type_, ObjectType)
            and self.is_possible_type(super_type, type_)
        )

    def types_overlap(self, rhs: GraphQLType, lhs: GraphQLType) -> bool:
        """
        Determine if two composite types "overlap".

        Two composite types overlap when the Sets of possible concrete types
        for each intersect.

        This is often used to determine if a fragment of a given type could
        possibly be visited in a context of another type.

        This function is commutative.
        """
        if rhs == lhs:
            return True

        if isinstance(rhs, GraphQLAbstractType) and isinstance(
            lhs, GraphQLAbstractType
        ):
            rhs_types = self.get_possible_types(rhs)
            lhs_types = self.get_possible_types(lhs)
            return any(t in lhs_types for t in rhs_types)

        return (
            isinstance(rhs, GraphQLAbstractType)
            and self.is_possible_type(rhs, lhs)
        ) or (
            isinstance(lhs, GraphQLAbstractType)
            and self.is_possible_type(lhs, rhs)
        )

    def to_string(
        self,
        indent: Union[str, int] = 4,
        include_descriptions: bool = True,
        include_introspection: bool = False,
        include_custom_schema_directives: bool = False,
    ) -> str:
        """
        Format the schema as an SDL string.

        Refer to :class:`py_gql.sdl.ASTSchemaPrinter` for details.
        """
        from ..sdl import ASTSchemaPrinter

        return ASTSchemaPrinter(
            indent=indent,
            include_descriptions=include_descriptions,
            include_introspection=include_introspection,
            include_custom_schema_directives=include_custom_schema_directives,
        )(self)

    def register_default_resolver(
        self, typename: str, resolver: Resolver, *, allow_override: bool = False
    ) -> None:
        """
        Register a default resolver and assign it to the named object type.

        Raises:
            UnknownType: if ``typename`` is not in the schema.
            SchemaError: if ``typename`` is not an object type.
            ValueError: if the type already has a default resolver and
                ``allow_override`` is false.
        """
        super().register_default_resolver(
            typename, resolver, allow_override=allow_override
        )

        try:
            object_type = self.types[typename]
        except KeyError:
            raise UnknownType(typename)

        if not isinstance(object_type, ObjectType):
            raise SchemaError(
                'Cannot assign default resolver to %s "%s".'
                % (object_type.__class__.__name__, typename)
            )

        if object_type.default_resolver and not allow_override:
            raise ValueError(
                'Type "%s" already has a default resolver.' % (typename,)
            )

        object_type.default_resolver = resolver
        # Invalidate validation
        self._is_valid = None

    def register_resolver(
        self,
        typename: str,
        fieldname: str,
        resolver: Resolver,
        *,
        allow_override: bool = False
    ) -> None:
        """
        Register a field resolver and assign it to ``typename.fieldname``.

        For ``fieldname == "*"`` only the parent registration is performed;
        no field is modified.

        Raises:
            UnknownType: if ``typename`` is not in the schema.
            SchemaError: if ``typename`` is not an object type or the field
                does not exist.
            ValueError: if the field already has a different resolver and
                ``allow_override`` is false.
        """
        super().register_resolver(
            typename, fieldname, resolver, allow_override=allow_override
        )

        try:
            object_type = self.types[typename]
        except KeyError:
            raise UnknownType(typename)

        if not isinstance(object_type, ObjectType):
            raise SchemaError(
                'Cannot assign resolver to %s "%s".'
                % (object_type.__class__.__name__, typename)
            )

        if fieldname == "*":
            return

        try:
            field = object_type.field_map[fieldname]
        except KeyError:
            raise SchemaError(
                'Cannot assign resolver to unknown field "%s.%s".'
                % (typename, fieldname)
            )

        if (
            field.resolver is not None
            and not allow_override
            and field.resolver is not resolver
        ):
            raise ValueError(
                'Field "%s" of type "%s" already has a resolver.'
                % (fieldname, typename)
            )

        field.resolver = resolver
        # Invalidate validation
        self._is_valid = None

    def register_subscription(
        self,
        typename: str,
        fieldname: str,
        resolver: Resolver,
        *,
        allow_override: bool = False
    ) -> None:
        """
        Register a subscription resolver for ``typename.fieldname``.

        Raises:
            UnknownType: if ``typename`` is not in the schema.
            SchemaError: if ``typename`` is not an object type or the field
                does not exist.
            ValueError: if the field already has a different subscription
                resolver and ``allow_override`` is false.
        """
        super().register_subscription(
            typename, fieldname, resolver, allow_override=allow_override
        )

        try:
            object_type = self.types[typename]
        except KeyError:
            raise UnknownType(typename)

        if not isinstance(object_type, ObjectType):
            raise SchemaError(
                'Cannot assign subscription to %s "%s".'
                % (object_type.__class__.__name__, typename)
            )

        try:
            field = object_type.field_map[fieldname]
        except KeyError:
            raise SchemaError(
                'Cannot assign subscription to unknown field "%s.%s".'
                % (typename, fieldname)
            )

        if (
            field.subscription_resolver is not None
            and not allow_override
            and field.subscription_resolver is not resolver
        ):
            raise ValueError(
                'Field "%s" of type "%s" already has a subscription.'
                % (fieldname, typename)
            )

        field.subscription_resolver = resolver
        # Invalidate validation
        self._is_valid = None

    def clone(self) -> "Schema":
        """
        Create a copy of the schema.

        Every type and directive that is not built-in is shallow copied with
        :func:`copy.copy` and swapped into the clone; built-in scalars,
        introspection types and specified directives are shared. The clone
        then calls ``merge_resolvers(self)``.

        Returns:
            The cloned schema.
        """
        cloned = Schema(
            query_type=self.query_type,
            mutation_type=self.mutation_type,
            subscription_type=self.subscription_type,
            # Forward every known directive and type so that custom
            # directives and orphan types (not reachable from the root types)
            # exist in the clone and get replaced by their copies below.
            directives=list(self.directives.values()),
            types=list(self.types.values()),
            # Do not share the list object with the original schema.
            nodes=list(self.nodes),
        )

        cloned._replace_types_and_directives(
            types={
                t.name: copy.copy(t)
                for t in self.types.values()
                if t not in _PROTECTED_TYPES
            },
            directives={
                d.name: copy.copy(d)
                for d in self.directives.values()
                if d not in SPECIFIED_DIRECTIVES
            },
        )

        cloned.merge_resolvers(self)
        return cloned
def _build_directive_map(maybe_directives: List[Any]) -> Dict[str, Directive]:
    """
    Build the ``name -> Directive`` mapping for a schema.

    The specified directives are always included; ``maybe_directives`` are
    validated and added on top of them.

    Raises:
        SchemaError: If an entry is not a :class:`Directive`, tries to
            override a specified directive with a different instance, or
            reuses the name of another directive with a different instance.
    """
    directives = {
        d.name: d for d in SPECIFIED_DIRECTIVES
    }  # type: Dict[str, Directive]

    for value in maybe_directives:
        if not isinstance(value, Directive):
            raise SchemaError(
                'Expected directive but got "%r" of type "%s"'
                % (value, type(value))
            )

        name = value.name

        if name in _SPECIFIED_DIRECTIVE_NAMES:
            if value is not directives[name]:
                raise SchemaError(
                    'Cannot override specified directive "%s"' % name
                )
            continue

        if name in directives:
            if value is not directives[name]:
                raise SchemaError('Duplicate directive "%s"' % name)
            continue

        directives[name] = value

    return directives


def _default_type_map() -> Dict[str, NamedType]:
    """Return a fresh map containing the specified scalars and introspection types."""
    types = {}  # type: Dict[str, NamedType]
    types.update({t.name: t for t in SPECIFIED_SCALAR_TYPES})
    types.update({t.name: t for t in INTROPSPECTION_TYPES})
    return types


def _build_type_map(
    types: Iterable[Optional[GraphQLType]],
    directives: Optional[Iterable[Directive]] = None,
    _type_map: Optional[Dict[str, NamedType]] = None,
) -> Dict[str, NamedType]:
    """
    Collect every named type reachable from ``types`` and ``directives``.

    Wrapping types are unwrapped and ``None`` entries are skipped. Union
    members, implemented interfaces, field and argument types and input
    field types are followed transitively, as are directive argument types.

    Args:
        types: Starting types.
        directives: Directives whose argument types should be collected.
        _type_map: Existing map to extend in place (a new one is created
            when omitted).

    Returns:
        The ``name -> NamedType`` map (``_type_map`` itself when provided).

    Raises:
        SchemaError: If an entry does not unwrap to a named type, or two
            different type instances share the same name.
    """
    type_map = (
        _type_map if _type_map is not None else {}
    )  # type: Dict[str, NamedType]

    _collect_types(types, type_map)

    if directives:
        directive_types = []  # type: List[GraphQLType]
        for directive in directives:
            directive_types.extend(
                [arg.type for arg in directive.arguments or []]
            )
        _collect_types(directive_types, type_map)

    return type_map


def _collect_types(
    roots: Iterable[Optional[GraphQLType]], type_map: Dict[str, NamedType]
) -> None:
    """Add ``roots`` and every type reachable from them to ``type_map``."""
    # Explicit stack instead of recursion so long chains of nested types
    # cannot exceed the interpreter recursion limit. Entries are pushed in
    # reverse so types are visited in the same depth-first pre-order a
    # recursive walk would use, keeping insertion order and error order.
    stack = list(roots)
    stack.reverse()

    while stack:
        type_ = stack.pop()
        if type_ is None:
            continue

        inner_type = unwrap_type(type_)
        if not isinstance(inner_type, NamedType):
            raise SchemaError(
                'Expected NamedType but got "%s" of type %s'
                % (inner_type, type(inner_type))
            )

        name = inner_type.name
        if name in type_map:
            if inner_type is not type_map[name]:
                raise SchemaError('Duplicate type "%s"' % name)
            continue

        type_map[name] = inner_type
        stack.extend(reversed(_child_types(inner_type)))


def _child_types(type_: NamedType) -> List[GraphQLType]:
    """List the types directly referenced by a named type, in declaration order."""
    children = []  # type: List[GraphQLType]

    if isinstance(type_, UnionType):
        # Same guard as `Schema.get_possible_types` for empty/None members.
        children.extend(type_.types or [])

    if isinstance(type_, ObjectType):
        children.extend(type_.interfaces)

    if isinstance(type_, (ObjectType, InterfaceType)):
        for field in type_.fields:
            children.append(field.type)
            children.extend([arg.type for arg in field.arguments or []])

    if isinstance(type_, InputObjectType):
        for input_field in type_.fields:
            children.append(input_field.type)

    return children