# -*- coding: utf-8 -*-
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
cast,
)
from .._string_utils import stringify_path
from .._utils import OrderedDict, apply_middlewares, is_iterable
from ..exc import (
CoercionError,
ResolverError,
ScalarSerializationError,
UnknownEnumValue,
)
from ..lang import ast as _ast
from ..schema import (
EnumType,
Field,
GraphQLAbstractType,
GraphQLCompositeType,
GraphQLType,
ListType,
NonNullType,
ObjectType,
ScalarType,
Schema,
)
from .default_resolver import default_resolver
from .instrumentation import Instrumentation
from .runtime import BlockingRuntime, Runtime
from .wrappers import (
GroupedFields,
ResolutionContext,
ResolveInfo,
ResponsePath,
)
Resolver = Callable[..., Any]
T = TypeVar("T")
G = TypeVar("G")
E = TypeVar("E", bound=Exception)
[docs]class Executor(ResolutionContext):
"""
Core executor class.
This is the core executor class implementing all of the operations necessary
to fulfill a GraphQL query or mutation as defined [in the spec](
https://spec.graphql.org/June2018/#sec-Execution).
"""
__slots__ = ResolutionContext.__slots__ + (
"instrumentation",
"runtime",
"_default_resolver",
)
def __init__(
self,
schema: Schema,
document: _ast.Document,
variables: Dict[str, Any],
context_value: Any,
*,
middlewares: Optional[Sequence[Callable[..., Any]]] = None,
instrumentation: Optional[Instrumentation] = None,
disable_introspection: bool = False,
runtime: Optional[Runtime] = None
):
super().__init__(
schema,
document,
variables,
context_value,
disable_introspection=disable_introspection,
middlewares=middlewares,
)
self.instrumentation = instrumentation or Instrumentation()
self.runtime = runtime or BlockingRuntime()
self._default_resolver = schema.default_resolver or default_resolver
[docs] def field_resolver(
self, parent_type: ObjectType, field_definition: Field
) -> Resolver:
base = (
field_definition.resolver
or parent_type.default_resolver
or self._default_resolver
)
try:
return self._resolver_cache[base]
except KeyError:
wrapped = (
self.runtime.wrap_callable(base)
if base is not self._default_resolver
else base
)
if self._middlewares:
wrapped = apply_middlewares(wrapped, self._middlewares)
self._resolver_cache[base] = wrapped
return wrapped
[docs] def resolve_type(
self, value: Any, info: ResolveInfo, abstract_type: GraphQLAbstractType,
) -> Optional[ObjectType]:
maybe_type = None # type: Optional[Union[ObjectType, str]]
if abstract_type.resolve_type is not None:
maybe_type = abstract_type.resolve_type(
value, self.context_value, info
)
else:
# Default type resolution
maybe_type = (
value.get("__typename__", None)
if isinstance(value, dict)
else getattr(value, "__typename__", None)
)
if maybe_type is None:
maybe_type = type(value).__name__
if isinstance(maybe_type, str):
return self.schema.get_type(maybe_type) # type: ignore
else:
return maybe_type
[docs] def resolve_field(
self,
parent_type: ObjectType,
parent_value: Any,
field_definition: Field,
nodes: List[_ast.Field],
path: ResponsePath,
) -> Any:
resolver = self.field_resolver(parent_type, field_definition)
node = nodes[0]
info = ResolveInfo(
field_definition, path, parent_type, nodes, self.runtime, self
)
self.instrumentation.on_field_start(
parent_value, self.context_value, info
)
def fail(err):
self.add_error(err, path, node)
self.instrumentation.on_field_end(
parent_value, self.context_value, info
)
return None
def complete(res):
self.instrumentation.on_field_end(
parent_value, self.context_value, info
)
return self.complete_value(
field_definition.type, nodes, path, info, res
)
try:
coerced_args = self.argument_values(field_definition, node)
except CoercionError as err:
return fail(err)
try:
return self.runtime.unwrap_value(
self.runtime.map_value(
self.runtime.unwrap_value(
resolver(
parent_value,
self.context_value,
info,
**coerced_args,
)
),
complete,
else_=(ResolverError, fail),
)
)
except ResolverError as err:
return fail(err)
def _iterate_fields(
self, parent_type: ObjectType, fields: GroupedFields
) -> Iterator[Tuple[str, Field, List[_ast.Field]]]:
for key, nodes in fields.items():
field_def = self.field_definition(parent_type, nodes[0].name.value)
if field_def is None:
continue
yield key, field_def, nodes
[docs] def execute_fields(
self,
parent_type: ObjectType,
root: Any,
path: ResponsePath,
fields: GroupedFields,
) -> Any:
keys = []
pending = []
for key, field_def, nodes in self._iterate_fields(parent_type, fields):
resolved = self.resolve_field(
parent_type, root, field_def, nodes, path + [key]
)
keys.append(key)
pending.append(resolved)
def _collect(done):
return OrderedDict(zip(keys, done))
return self.runtime.map_value(
self.runtime.gather_values(pending), _collect
)
[docs] def execute_fields_serially(
self,
parent_type: ObjectType,
root: Any,
path: ResponsePath,
fields: GroupedFields,
) -> Any:
resolved_fields = OrderedDict() # type: Dict[str, Any]
args = list(self._iterate_fields(parent_type, fields))
def _next():
try:
k, f, n = args.pop(0)
except IndexError:
return resolved_fields
else:
def cb(value):
resolved_fields[k] = value
return _next()
return self.runtime.map_value(
self.resolve_field(parent_type, root, f, n, path + [k]), cb
)
return _next()
[docs] def complete_list_value(
self,
inner_type: GraphQLType,
nodes: List[_ast.Field],
path: ResponsePath,
info: ResolveInfo,
resolved_value: Any,
) -> Any:
return self.runtime.gather_values(
self.complete_value(inner_type, nodes, path + [index], info, entry)
for index, entry in enumerate(resolved_value)
)
[docs] def complete_non_nullable_value(
self,
inner_type: GraphQLType,
nodes: List[_ast.Field],
path: ResponsePath,
info: ResolveInfo,
resolved_value: Any,
) -> Any:
return self.runtime.map_value(
self.complete_value(inner_type, nodes, path, info, resolved_value),
lambda r: self._handle_non_nullable_value(nodes, path, r),
)
[docs] def complete_value( # noqa: C901
self,
field_type: GraphQLType,
nodes: List[_ast.Field],
path: ResponsePath,
info: ResolveInfo,
resolved_value: Any,
) -> Any:
if isinstance(field_type, NonNullType):
return self.complete_non_nullable_value(
field_type.type, nodes, path, info, resolved_value
)
if resolved_value is None:
return None
if isinstance(field_type, ListType):
if not is_iterable(resolved_value, False):
raise RuntimeError(
'Field "%s" is a list type and resolved value should be '
"iterable" % stringify_path(path)
)
return self.complete_list_value(
field_type.type, nodes, path, info, resolved_value
)
if isinstance(field_type, ScalarType):
try:
return field_type.serialize(resolved_value)
except ScalarSerializationError as err:
raise RuntimeError(
'Field "%s" cannot be serialized as "%s": %s'
% (stringify_path(path), field_type, err)
) from err
if isinstance(field_type, EnumType):
try:
return field_type.get_name(resolved_value)
except UnknownEnumValue as err:
raise RuntimeError(
'Field "%s" cannot be serialized as "%s": %s'
% (stringify_path(path), field_type, err)
) from err
if isinstance(field_type, GraphQLCompositeType):
if isinstance(field_type, GraphQLAbstractType):
runtime_type = self.resolve_type(
resolved_value, info, field_type
)
if not isinstance(runtime_type, ObjectType):
raise RuntimeError(
'Abstract type "%s" must resolve to an ObjectType at '
'runtime for field "%s". Received "%s"'
% (field_type, stringify_path(path), runtime_type)
)
# Backup check in case of badly implemented `resolve_type`
if not self.schema.is_possible_type(field_type, runtime_type):
raise RuntimeError(
'Runtime ObjectType "%s" is not a possible type for '
'field "%s" of type "%s".'
% (runtime_type, stringify_path(path), field_type)
)
else:
runtime_type = cast(ObjectType, field_type)
return self.execute_fields(
runtime_type,
resolved_value,
path,
self.collect_fields(
runtime_type,
[
selection
for field in nodes
if field.selection_set
for selection in field.selection_set.selections
],
),
)
raise TypeError(
"Invalid field type %s at %s" % (field_type, stringify_path(path))
)
def _handle_non_nullable_value(
self, nodes: List[_ast.Field], path: ResponsePath, resolved_value: Any
) -> Any:
if resolved_value is None:
# REVIEW: Shouldn't this be a RuntimeError? As in the developer
# should never return a null non nullable field, raising explicitely
# if the query lead to this behavior could be valid outcome.
self.add_error(
ResolverError(
'Field "%s" is not nullable' % stringify_path(path),
nodes=nodes,
path=path,
)
)
return resolved_value