# Source code for py_gql.utilities.collect_fields

# -*- coding: utf-8 -*-
import fnmatch
import re
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Mapping,
    Optional,
    Sequence,
    Set,
    TypeVar,
    Union,
)
from typing.re import Pattern

from .._utils import OrderedDict
from ..lang import ast
from ..schema import (
    GraphQLAbstractType,
    IncludeDirective,
    ObjectType,
    Schema,
    SkipDirective,
)
from ..utilities import directive_arguments


T = TypeVar("T")

# Maps a response name (the alias if present, otherwise the field name) to
# every ``ast.Field`` node collected under it, in the order encountered.
GroupedFields = Dict[str, List[ast.Field]]

# Callable taking a field / inline fragment / fragment spread node and
# returning a bool. NOTE(review): presumably an include/skip predicate; it is
# not used by the code visible in this module — confirm against callers.
IncludeCallable = Callable[
    [Union[ast.Field, ast.InlineFragment, ast.FragmentSpread]], bool
]

# Backward-compatible alias for the original, misspelled name.
InclueCallable = IncludeCallable


def collect_fields(
    schema: Schema,
    object_type: ObjectType,
    selections: Sequence[ast.Selection],
    fragments: Mapping[str, ast.FragmentDefinition],
    variables: Mapping[str, Any],
    _seen_fragments: Optional[Set[str]] = None,
) -> GroupedFields:
    """Group the fields of a selection set by response name for an object type.

    Fields are keyed by their response name in the order they are
    encountered. Inline fragments and fragment spreads are flattened into the
    result when their ``@skip`` / ``@include`` directives allow it and their
    type condition applies to ``object_type``. Each named fragment is
    expanded at most once across the whole recursive collection.

    Args:
        schema: Schema used to resolve fragment type conditions.
        object_type: Object type the selections are collected for.
        selections: Selections to collect from.
        fragments: Fragment definitions of the document, keyed by name. This
            is meant to run on validated documents, so a spread fragment that
            is missing (and not skipped or already seen) raises ``KeyError``.
        variables: Operation variables used to evaluate ``@skip`` /
            ``@include``.
        _seen_fragments: Names of fragments already expanded; shared (and
            mutated) across recursive calls.

    Returns:
        Ordered mapping of response name to the list of collected field nodes.
    """
    if _seen_fragments is None:
        # Do not use ``_seen_fragments or set()``: recursive calls pass the
        # shared set, which may still be empty, and replacing it with a fresh
        # set would hide fragments expanded deeper down from the caller.
        _seen_fragments = set()
    grouped_fields = OrderedDict()  # type: GroupedFields

    for selection in selections:
        if isinstance(selection, ast.Field):
            if _skip_selection(selection, variables):
                continue

            key = selection.response_name
            if key not in grouped_fields:
                grouped_fields[key] = []
            grouped_fields[key].append(selection)

        elif isinstance(selection, ast.InlineFragment):
            if _skip_selection(
                selection, variables
            ) or not _fragment_type_applies(schema, object_type, selection):
                continue

            _merge(
                collect_fields(
                    schema,
                    object_type,
                    selection.selection_set.selections,
                    fragments,
                    variables,
                    _seen_fragments,
                ),
                into=grouped_fields,
            )

        elif isinstance(selection, ast.FragmentSpread):
            name = selection.name.value
            if _skip_selection(selection, variables) or name in _seen_fragments:
                continue

            # This should usually be used after validation (given a document
            # and schema are required) and so we expect fragments to be present.
            fragment = fragments[name]
            if not _fragment_type_applies(schema, object_type, fragment):
                continue

            # Mark the fragment before recursing (as the spec's CollectFields
            # does) so that nested spreads of the same fragment are skipped.
            _seen_fragments.add(name)
            _merge(
                collect_fields(
                    schema,
                    object_type,
                    fragment.selection_set.selections,
                    fragments,
                    variables,
                    _seen_fragments,
                ),
                into=grouped_fields,
            )

    return grouped_fields


def collect_fields_untyped(
    selections: Sequence[ast.Selection],
    fragments: Mapping[str, ast.FragmentDefinition],
    variables: Mapping[str, Any],
    _seen_fragments: Optional[Set[str]] = None,
) -> GroupedFields:
    """Group the fields of a selection set by response name, ignoring types.

    Works like ``collect_fields`` but does not check fragment type conditions
    and tolerates unvalidated documents. Spreads of unknown fragments are
    ignored, and each named fragment is expanded at most once, so cyclic
    fragment spreads terminate instead of recursing forever.

    Args:
        selections: Selections to collect from.
        fragments: Fragment definitions of the document, keyed by name.
        variables: Operation variables used to evaluate ``@skip`` /
            ``@include``.
        _seen_fragments: Names of fragments already expanded; shared (and
            mutated) across recursive calls.

    Returns:
        Ordered mapping of response name to the list of collected field nodes.
    """
    if _seen_fragments is None:
        # Do not use ``_seen_fragments or set()``: recursive calls pass the
        # shared set, which may still be empty, and replacing it with a fresh
        # set would both hide deeper expansions from the caller and break
        # cycle detection.
        _seen_fragments = set()
    grouped_fields = OrderedDict()  # type: GroupedFields

    for selection in selections:
        if isinstance(selection, ast.Field):
            if _skip_selection(selection, variables):
                continue

            key = selection.response_name
            if key not in grouped_fields:
                grouped_fields[key] = []
            grouped_fields[key].append(selection)

        elif isinstance(selection, ast.InlineFragment):
            if _skip_selection(selection, variables):
                continue

            _merge(
                collect_fields_untyped(
                    selection.selection_set.selections,
                    fragments,
                    variables,
                    _seen_fragments,
                ),
                into=grouped_fields,
            )

        elif isinstance(selection, ast.FragmentSpread):
            name = selection.name.value
            if _skip_selection(selection, variables) or name in _seen_fragments:
                continue

            try:
                fragment = fragments[name]
            except KeyError:
                # As we don't typecheck or validate, this could go through
                # invalid fragments.
                continue

            # Mark the fragment *before* recursing: the document is not
            # validated, so a fragment may (indirectly) spread itself and
            # would otherwise recurse without bound.
            _seen_fragments.add(name)
            _merge(
                collect_fields_untyped(
                    fragment.selection_set.selections,
                    fragments,
                    variables,
                    _seen_fragments,
                ),
                into=grouped_fields,
            )

    return grouped_fields


def _merge(groups: Dict[str, List[T]], *, into: Dict[str, List[T]]) -> None:
    """Extend, in place, each list of ``into`` with the matching list of ``groups``.

    A key missing from ``into`` first receives a new empty list, so the lists
    held by ``groups`` are copied element-wise rather than shared.
    """
    for name, items in groups.items():
        into.setdefault(name, []).extend(items)


def _fragment_type_applies(
    schema: Schema,
    object_type: ObjectType,
    fragment: Union[ast.InlineFragment, ast.FragmentDefinition],
) -> bool:
    """Tell whether ``fragment``'s type condition matches ``object_type``.

    A fragment without a type condition always applies. Otherwise the
    condition is resolved through ``schema``. It matches when it equals
    ``object_type``, or when it is an abstract type for which ``schema``
    reports ``object_type`` as a possible type.
    """
    condition = fragment.type_condition
    if not condition:
        return True

    candidate = schema.get_type_from_literal(condition)
    is_same = candidate == object_type
    return is_same or (
        isinstance(candidate, GraphQLAbstractType)
        and schema.is_possible_type(candidate, object_type)
    )


def _skip_selection(
    node: Union[ast.Field, ast.InlineFragment, ast.FragmentSpread],
    variables: Mapping[str, Any],
) -> bool:
    """Decide whether ``node`` is excluded by its ``@skip`` / ``@include`` directives.

    Both directives' arguments are always looked up. The node is excluded
    when ``@skip(if: ...)`` is present and truthy, or when
    ``@include(if: ...)`` is present and falsy.
    """
    skip_args = directive_arguments(SkipDirective, node, variables=variables)
    include_args = directive_arguments(
        IncludeDirective, node, variables=variables
    )

    # An absent directive means "not skipped" / "included" respectively.
    skip_flag = skip_args["if"] if skip_args is not None else False
    include_flag = include_args["if"] if include_args is not None else True

    return skip_flag or not include_flag


def selected_fields(
    field: ast.Field,
    *,
    fragments: Mapping[str, ast.FragmentDefinition],
    variables: Mapping[str, Any],
    maxdepth: Optional[int] = 1,
    pattern: Optional[Union[str, Pattern]] = None,
    _path: Optional[List[str]] = None
) -> List[str]:
    """
    Extract a list of field paths from an object field and provided fragments.

    If ``maxdepth`` is 0 or higher than 1, subfields will be traversed
    recursively and exposed as a ``/`` separated path.

    For example, considering the root field of the following document:

    .. code-block:: graphql

        query {
            field {
                foo {
                    bar {
                        baz
                    }
                }
            }
        }

    Calling ``selected_fields(..., maxdepth=0)`` will yield
    ``['foo', 'foo/bar', 'foo/bar/baz']``.

    When several fields share a response name (e.g. ``foo { a } foo { b }``),
    the sub-selections of all of them are traversed together, so both
    ``foo/a`` and ``foo/b`` are reported.

    Args:
        field: Root field
        fragments: Document fragments
        variables: Operation variables
        maxdepth: Control how deep the traversal should go. If set to 0
            (or ``None``), then traversal will proceed as deep as possible.
        pattern: Filter string used to control which fields are returned.
            If this is passed as a string, it will be compiled into a regex
            through the :py:mod:`fnmatch` module.
        _path: Already traversed path, used for recursive calls.

    Returns:
        List of selected field paths.
    """
    if field.selection_set is None:
        return []

    # Compile once here; the helper only ever receives a compiled pattern.
    if isinstance(pattern, str):
        pattern = re.compile(fnmatch.translate(pattern))

    return _selected_field_paths(
        field.selection_set.selections,
        fragments=fragments,
        variables=variables,
        maxdepth=maxdepth,
        pattern=pattern,
        path=_path or [],
    )


def _selected_field_paths(
    selections: Sequence[ast.Selection],
    *,
    fragments: Mapping[str, ast.FragmentDefinition],
    variables: Mapping[str, Any],
    maxdepth: Optional[int],
    pattern: Optional[Pattern],
    path: List[str]
) -> List[str]:
    """Collect ``/``-joined field paths for ``selections`` nested under ``path``."""
    fieldnames = []  # type: List[str]
    collected = collect_fields_untyped(selections, fragments, variables)

    for fields in collected.values():
        # Paths use the field name, not the alias. NOTE(review): aliased
        # copies of one field therefore yield duplicate paths — confirm intended.
        child_path = [*path, fields[0].name.value]
        joined = "/".join(child_path)
        if pattern is None or pattern.match(joined):
            fieldnames.append(joined)

        if (not maxdepth) or len(path) < (maxdepth - 1):
            # Fields sharing a response name are merged during execution, so
            # traverse the sub-selections of every one of them, not just the
            # first (which previously dropped e.g. ``foo/b`` in
            # ``foo { a } foo { b }``).
            child_selections = [
                sub_selection
                for child_field in fields
                if child_field.selection_set is not None
                for sub_selection in child_field.selection_set.selections
            ]
            if child_selections:
                fieldnames.extend(
                    _selected_field_paths(
                        child_selections,
                        fragments=fragments,
                        variables=variables,
                        maxdepth=maxdepth,
                        pattern=pattern,
                        path=child_path,
                    )
                )

    return fieldnames