# -*- coding: utf-8 -*-
from typing import List, Mapping, Optional, Sequence, Tuple, TypeVar, Union

from .._utils import DefaultOrderedDict, OrderedDict, deduplicate, find_one
from ..exc import UnknownEnumValue, UnknownType, ValidationError
from ..lang import ast as _ast
from ..lang.visitor import DispatchingVisitor
from ..schema import (
    Argument,
    Directive,
    EnumType,
    EnumValue,
    Field,
    GraphQLCompositeType,
    GraphQLType,
    InputObjectType,
    InputValue,
    InterfaceType,
    ListType,
    NonNullType,
    ObjectType,
    ScalarType,
    Schema,
    is_input_type,
    is_output_type,
    unwrap_type,
)
from ..schema.introspection import (
    SCHEMA_INTROSPECTION_FIELD,
    TYPE_INTROSPECTION_FIELD,
    TYPE_NAME_INTROSPECTION_FIELD,
)
T = TypeVar("T")
N = TypeVar("N", bound=_ast.Node)

# name -> name -> value
MMap = Mapping[str, Mapping[str, T]]
# name -> list of values
LMap = Mapping[str, List[T]]
# List whose entries may be None (e.g. stacks of possibly-unresolved types).
OptList = List[Optional[T]]

# A single recorded variable usage, as stored by
# ``VariablesCollector.enter_variable``:
#   - the variable node itself,
#   - the input type expected at its location (``TypeInfoVisitor.input_type``,
#     possibly a wrapping List / NonNull type, None when unknown),
#   - the input value definition at its location
#     (``TypeInfoVisitor.input_value_def``, None when unknown).
_VariableUsage = Tuple[
    _ast.Variable, Optional[GraphQLType], Optional[InputValue]
]

# owner name (operation or fragment) -> variable name -> usage
VariableUsages = MMap[_VariableUsage]
def _peek(
    lst: Sequence[T], count: int = 1, default: Optional[T] = None
) -> Optional[T]:
    """
    Return the ``count``-th element from the end of ``lst``.

    ``count=1`` selects the last element, ``count=2`` the one before it, etc.

    Args:
        lst: Sequence to read from (not modified).
        count: 1-based position counted from the end of ``lst``.
        default: Value returned when ``lst`` has fewer than ``count``
            elements or when ``count`` is not a positive integer.

    Returns:
        The selected element, or ``default``.
    """
    # ``count < 1`` must be rejected explicitly: ``lst[-0]`` is ``lst[0]``
    # (the *first* element) and negative counts would index from the front.
    if count < 1 or count > len(lst):
        return default
    return lst[-count]
def _get_field_def(schema, parent_type, field):
    """
    Resolve the schema field definition selected by ``field`` on ``parent_type``.

    Introspection meta fields are handled first: ``__schema`` and ``__type``
    only on the query root type, ``__typename`` on any composite type.
    Otherwise the field is looked up in the ``field_map`` of object and
    interface types.

    Returns:
        The matching field definition, or None if there is none.
    """
    field_name = field.name.value

    if parent_type is schema.query_type:
        for root_meta_field in (
            SCHEMA_INTROSPECTION_FIELD,
            TYPE_INTROSPECTION_FIELD,
        ):
            if field_name == root_meta_field.name:
                return root_meta_field

    is_composite = isinstance(parent_type, GraphQLCompositeType)
    if is_composite and field_name == TYPE_NAME_INTROSPECTION_FIELD.name:
        return TYPE_NAME_INTROSPECTION_FIELD

    if not isinstance(parent_type, (ObjectType, InterfaceType)):
        return None
    return parent_type.field_map.get(field_name, None)
class ValidationVisitor(DispatchingVisitor):
    """
    Base visitor for validating GraphQL documents.

    Custom validators subclass this and report problems through
    :meth:`add_error`. Raise :class:`py_gql.lang.visitor.SkipNode` to stop
    descending into the children of a node that is already invalid.

    Args:
        schema: Schema the document is validated against (known types,
            directives, etc.).
        type_info: Type information collector supplied by
            :func:`~py_gql.validation.validate`.

    Attributes:
        schema (py_gql.schema.Schema): Schema the document is validated
            against (known types, directives, etc.).
        type_info (TypeInfoVisitor): Type information collector supplied by
            :func:`~py_gql.validation.validate`.
        errors (List[ValidationError]): Errors collected so far.
    """

    def __init__(self, schema: Schema, type_info: "TypeInfoVisitor"):
        super().__init__()
        self.schema = schema
        self.type_info = type_info
        self.errors = []  # type: List[ValidationError]

    def add_error(
        self, message: str, nodes: Optional[Sequence[_ast.Node]] = None
    ) -> None:
        """
        Record a validation error.

        Args:
            message (str): Human readable description of the problem.
            nodes (Optional[List[py_gql.lang.ast.Node]]): Nodes the error
                relates to.
        """
        error = ValidationError(message, nodes)
        self.errors.append(error)

    def enter(self, node: N) -> N:
        # Dispatch through the base class but always hand back the original
        # node, whatever the dispatched handler returned.
        super(ValidationVisitor, self).enter(node)
        return node
class VariablesCollector(ValidationVisitor):
    """
    Custom validation visitor tracking all variable definitions and usages.

    Data is keyed by operation name (``""`` for an anonymous operation) or by
    fragment name. When the document is left, each operation's entry in
    ``_op_fragments`` is extended with every fragment reachable through
    nested fragment spreads, independently of the order in which fragments
    are defined in the document.
    """

    def __init__(self, schema, type_info):
        super(VariablesCollector, self).__init__(schema, type_info)

        # Name of the operation being visited, None outside of operations.
        self._op = None
        # operation -> variable name -> (node, input type, input value def)
        self._op_variables = DefaultOrderedDict(
            OrderedDict
        )  # type: VariableUsages
        # operation -> variable name -> variable definition node
        self._op_defined_variables = DefaultOrderedDict(
            OrderedDict
        )  # type: MMap[_ast.VariableDefinition]
        # operation -> names of spread fragments (transitive after
        # leave_document)
        self._op_fragments = DefaultOrderedDict(list)  # type: LMap[str]

        # Name of the fragment being visited, None outside of fragments.
        self._fragment = None
        # fragment -> variable name -> (node, input type, input value def)
        self._fragment_variables = DefaultOrderedDict(
            OrderedDict
        )  # type: VariableUsages
        # fragment -> names of fragments it spreads directly (direct
        # self-spreads are not recorded)
        self._fragment_fragments = DefaultOrderedDict(list)  # type: LMap[str]

        # True while inside a variable definition: variables met there are
        # not recorded as usages.
        self._in_var_def = False

    def enter_operation_definition(self, node):
        self._op = node.name.value if node.name else ""

    def leave_operation_definition(self, _node):
        self._op = None

    def enter_fragment_definition(self, node):
        self._fragment = node.name.value

    def leave_fragment_definition(self, _node):
        self._fragment = None

    def enter_fragment_spread(self, node):
        name = node.name.value
        if self._op is not None:
            self._op_fragments[self._op].append(name)
        elif self._fragment is not None and name != self._fragment:
            self._fragment_fragments[self._fragment].append(name)

    def enter_variable_definition(self, node):
        self._in_var_def = True
        if self._op is not None:
            name = node.variable.name.value
            self._op_defined_variables[self._op][name] = node  # type: ignore

    def leave_variable_definition(self, _node):
        self._in_var_def = False

    def enter_variable(self, node):
        var = node.name.value
        input_type = self.type_info.input_type
        input_value_def = self.type_info.input_value_def

        if self._in_var_def:
            return

        usage = (node, input_type, input_value_def)
        if self._op is not None:
            self._op_variables[self._op][var] = usage  # type: ignore
        elif self._fragment is not None:
            self._fragment_variables[self._fragment][var] = usage  # type: ignore

    def _flatten_fragments(self):
        # Breadth-first transitive closure over the fragment spread graph,
        # performed per operation. Iterating the list while appending to it
        # visits newly reached fragments too. ``seen`` prevents duplicates
        # and makes fragment cycles terminate.
        for fragments in self._op_fragments.values():
            seen = set(fragments)
            index = 0
            while index < len(fragments):
                parent = fragments[index]
                index += 1
                for child in self._fragment_fragments.get(parent, ()):
                    if child not in seen:
                        seen.add(child)
                        fragments.append(child)

    def leave_document(self, _):
        self._flatten_fragments()
# This is a very basic re-implementation of the reference JavaScript
# implementation that is compatible with our version of AST visitors;
# it can most likely be improved.
class TypeInfoVisitor(DispatchingVisitor):
    """
    AST visitor recursively tracking the current types and field definitions.

    All tracked types are considered with regards to the provided schema,
    however unknown types and other unexpected errors will be downgraded to
    null values in order to not crash the traversal. This leaves the consumer
    responsible for handling such cases.

    Warning:
        When using this alongside other visitors (such as when using
        :class:`py_gql.lang.visitor.ChainedVisitor`), this visitor **must**
        be the first one to visit the nodes in order for the information
        provided downstream to be accurate.

    Args:
        schema (py_gql.schema.Schema): Reference schema to extract types from

    Attributes:
        type: Current type if applicable.
        parent_type: Current parent (composite) type if applicable, i.e. the
            named type owning the selection set being visited.
        input_type: Current input type if applicable (when visiting arguments).
        parent_input_type: Current parent input type if applicable
            (when visiting input objects).
        field: Current field definition if applicable (when visiting object).
        input_value_def: Current input value definition (e.g. arg def, input field)
            if applicable.
    """

    __slots__ = (
        "_schema",
        "_type_stack",
        "_input_type_stack",
        "_parent_type_stack",
        "_field_stack",
        "_input_value_def_stack",
        "directive",
        "argument",
        "enum_value",
    )

    def __init__(self, schema):
        super(TypeInfoVisitor, self).__init__()
        self._schema = schema
        # Each stack holds None where the type / definition could not be
        # resolved, so that enter / leave pushes and pops stay balanced.
        self._type_stack = []  # type: OptList[GraphQLCompositeType]
        self._parent_type_stack = []  # type: OptList[GraphQLCompositeType]
        self._input_type_stack = []  # type: OptList[GraphQLType]
        self._field_stack = []  # type: OptList[Field]
        self._input_value_def_stack = []  # type: OptList[InputValue]
        self.directive = None  # type: Optional[Directive]
        self.argument = None  # type: Optional[Argument]
        self.enum_value = None  # type: Optional[EnumValue]

    @property
    def type(self) -> Optional[GraphQLCompositeType]:
        return _peek(self._type_stack)

    @property
    def parent_type(self) -> Optional[GraphQLCompositeType]:
        return _peek(self._parent_type_stack, 1)

    @property
    def input_type(self) -> Optional[GraphQLType]:
        return _peek(self._input_type_stack, 1)

    @property
    def parent_input_type(self) -> Optional[InputObjectType]:
        t = _peek(self._input_type_stack, 2)
        return t if isinstance(t, InputObjectType) else None

    @property
    def field(self) -> Optional[Field]:
        return _peek(self._field_stack)

    @property
    def input_value_def(self) -> Optional[InputValue]:
        return _peek(self._input_value_def_stack)

    def _get_field_def(self, node):
        parent_type = self.parent_type
        return (
            _get_field_def(self._schema, parent_type, node)
            if parent_type
            else None
        )

    def _type_from_ast(self, type_node):
        # Unknown type names are downgraded to None instead of raising.
        try:
            return self._schema.get_type_from_literal(type_node)
        except UnknownType:
            return None

    def _leave_input_value(self):
        self._input_type_stack.pop()
        self._input_value_def_stack.pop()

    def enter_selection_set(self, node):
        named_type = unwrap_type(self.type) if self.type else None
        self._parent_type_stack.append(
            named_type if isinstance(named_type, GraphQLCompositeType) else None
        )
        return node

    def leave_selection_set(self, _node):
        self._parent_type_stack.pop()

    def enter_field(self, node):
        field_def = self._get_field_def(node)
        self._field_stack.append(field_def)
        self._type_stack.append(
            field_def.type
            if field_def and is_output_type(field_def.type)
            else None
        )
        return node

    def leave_field(self, _node):
        self._type_stack.pop()
        self._field_stack.pop()

    def enter_directive(self, node):
        self.directive = self._schema.directives.get(node.name.value)
        return node

    def leave_directive(self, _node):
        self.directive = None

    def enter_operation_definition(self, node):
        type_ = {
            "query": self._schema.query_type,
            "mutation": self._schema.mutation_type,
            "subscription": self._schema.subscription_type,
        }.get(node.operation, None)
        self._type_stack.append(
            type_ if isinstance(type_, ObjectType) else None
        )
        return node

    def leave_operation_definition(self, _node):
        self._type_stack.pop()

    def enter_fragment_definition(self, node):
        type_ = self._type_from_ast(node.type_condition)
        self._type_stack.append(type_ if is_output_type(type_) else None)
        return node

    def leave_fragment_definition(self, _node):
        self._type_stack.pop()

    def enter_inline_fragment(self, node):
        if node.type_condition:
            type_ = self._type_from_ast(node.type_condition)
        else:
            # Without a type condition the fragment applies to the named
            # type of the enclosing selection, not to the (possibly List /
            # NonNull wrapped) type of the enclosing field.
            type_ = unwrap_type(self.type) if self.type else None
        self._type_stack.append(type_ if is_output_type(type_) else None)
        return node

    def leave_inline_fragment(self, _node):
        self._type_stack.pop()

    def enter_variable_definition(self, node):
        type_ = self._type_from_ast(node.type)
        self._input_type_stack.append(type_ if is_input_type(type_) else None)
        return node

    def leave_variable_definition(self, _node):
        self._input_type_stack.pop()

    def enter_argument(self, node):
        # Directive arguments take precedence over field arguments.
        ctx = self.directive or self.field
        if ctx:
            name = node.name.value
            self.argument = find_one(ctx.arguments, lambda a: a.name == name)
            self._input_value_def_stack.append(self.argument)
            self._input_type_stack.append(
                self.argument.type
                if self.argument and is_input_type(self.argument.type)
                else None
            )
        else:
            self.argument = None
            self._input_type_stack.append(None)
            self._input_value_def_stack.append(None)
        return node

    def leave_argument(self, _node):
        self.argument = None
        self._leave_input_value()

    def enter_list_value(self, node):
        # Strip exactly one NonNull and one List wrapper: items of a
        # ``[[Int]]!`` location are ``[Int]`` and items of ``[Int!]`` are
        # ``Int!``. Fully unwrapping would lose that information and break
        # checks such as variable position validation. A non-list location
        # keeps its (nullable) type, mirroring input coercion of single
        # values into lists.
        list_type = self.input_type
        if isinstance(list_type, NonNullType):
            list_type = list_type.type
        item_type = (
            list_type.type if isinstance(list_type, ListType) else list_type
        )
        self._input_type_stack.append(
            item_type
            if item_type is not None and is_input_type(item_type)
            else None
        )
        # List positions never have a default value.
        self._input_value_def_stack.append(None)
        return node

    def leave_list_value(self, _node):
        self._leave_input_value()

    def enter_object_field(self, node):
        object_type = unwrap_type(self.input_type) if self.input_type else None
        if isinstance(object_type, InputObjectType):
            name = node.name.value
            field_def = find_one(object_type.fields, lambda f: f.name == name)
            self._input_value_def_stack.append(field_def)
            self._input_type_stack.append(
                field_def.type
                if field_def and is_input_type(field_def.type)
                else None
            )
        else:
            self._input_type_stack.append(None)
            self._input_value_def_stack.append(None)
        return node

    def leave_object_field(self, _node):
        self._leave_input_value()

    def enter_enum_value(self, node):
        enum = unwrap_type(self.input_type) if self.input_type else None
        if isinstance(enum, EnumType):
            try:
                self.enum_value = enum.get_value(node.value)
            except UnknownEnumValue:
                self.enum_value = None
        else:
            # Location is not an enum (or unknown): nothing to resolve.
            self.enum_value = None
        return node

    def leave_enum_value(self, _node):
        self.enum_value = None