Skip to content

entry_collections

EntryCollection

Backend-agnostic base class for querying collections of EntryResources.

Source code in optimade/server/entry_collections/entry_collections.py
class EntryCollection(ABC):
    """Backend-agnostic base class for querying collections of
    [`EntryResource`][optimade.models.entries.EntryResource]s."""

    def __init__(
        self,
        resource_cls: EntryResource,
        resource_mapper: BaseResourceMapper,
        transformer: Transformer,
    ):
        """Initialize the collection for the given parameters.

        Parameters:
            resource_cls (EntryResource): The `EntryResource` model
                that is stored by the collection.
            resource_mapper (BaseResourceMapper): A resource mapper
                object that handles aliases and format changes between
                deserialization and response.
            transformer (Transformer): The Lark `Transformer` used to
                interpret the filter.

        """
        self.parser = LarkParser()
        self.resource_cls = resource_cls
        self.resource_mapper = resource_mapper
        self.transformer = transformer

        self.provider_prefix = CONFIG.provider.prefix
        self.provider_fields = [
            field if isinstance(field, str) else field["name"]
            for field in CONFIG.provider_fields.get(resource_mapper.ENDPOINT, [])
        ]

        self._all_fields: Set[str] = None

    @abstractmethod
    def __len__(self) -> int:
        """Returns the total number of entries in the collection."""

    @abstractmethod
    def insert(self, data: List[EntryResource]) -> None:
        """Add the given entries to the underlying database.

        Arguments:
            data: The entry resource objects to add to the database.

        """

    @abstractmethod
    def count(self, **kwargs) -> int:
        """Returns the number of entries matching the query specified
        by the keyword arguments.

        Parameters:
            kwargs (dict): Query parameters as keyword arguments.

        """

    def find(
        self, params: Union[EntryListingQueryParams, SingleEntryQueryParams]
    ) -> Tuple[
        Union[List[EntryResource], EntryResource, None], int, bool, Set[str], Set[str]
    ]:
        """
        Fetches results and indicates if more data is available.

        Also gives the total number of data available in the absence of `page_limit`.
        See [`EntryListingQueryParams`][optimade.server.query_params.EntryListingQueryParams]
        for more information.

        Parameters:
            params: Entry listing URL query params.

        Returns:
            A tuple of various relevant values:
            (`results`, `data_returned`, `more_data_available`, `exclude_fields`, `include_fields`).

        """
        criteria = self.handle_query_params(params)
        single_entry = isinstance(params, SingleEntryQueryParams)
        response_fields = criteria.pop("fields")

        results, data_returned, more_data_available = self._run_db_query(
            criteria, single_entry
        )

        if single_entry:
            results = results[0] if results else None

            if data_returned > 1:
                raise NotFound(
                    detail=f"Instead of a single entry, {data_returned} entries were found",
                )

        exclude_fields = self.all_fields - response_fields
        include_fields = (
            response_fields - self.resource_mapper.TOP_LEVEL_NON_ATTRIBUTES_FIELDS
        )

        bad_optimade_fields = set()
        bad_provider_fields = set()
        for field in include_fields:
            if field not in self.resource_mapper.ALL_ATTRIBUTES:
                if field.startswith("_"):
                    if any(
                        field.startswith(f"_{prefix}_")
                        for prefix in self.resource_mapper.SUPPORTED_PREFIXES
                    ):
                        bad_provider_fields.add(field)
                else:
                    bad_optimade_fields.add(field)

        if bad_provider_fields:
            warnings.warn(
                message=f"Unrecognised field(s) for this provider requested in `response_fields`: {bad_provider_fields}.",
                category=UnknownProviderProperty,
            )

        if bad_optimade_fields:
            raise BadRequest(
                detail=f"Unrecognised OPTIMADE field(s) in requested `response_fields`: {bad_optimade_fields}."
            )

        if results:
            results = self.resource_mapper.deserialize(results)

        return (
            results,
            data_returned,
            more_data_available,
            exclude_fields,
            include_fields,
        )

    @abstractmethod
    def _run_db_query(
        self, criteria: Dict[str, Any], single_entry: bool = False
    ) -> Tuple[List[Dict[str, Any]], int, bool]:
        """Run the query on the backend and collect the results.

        Arguments:
            criteria: A dictionary representation of the query parameters.
            single_entry: Whether or not the caller is expecting a single entry response.

        Returns:
            The list of entries from the database (without any re-mapping), the total number of
            entries matching the query and a boolean for whether or not there is more data available.

        """

    @property
    def all_fields(self) -> Set[str]:
        """Get the set of all fields handled in this collection,
        from attribute fields in the schema, provider fields and top-level OPTIMADE fields.

        The set of all fields are lazily created and then cached.
        This means the set is created the first time the property is requested and then cached.

        Returns:
            All fields handled in this collection.

        """
        if not self._all_fields:
            # All OPTIMADE fields
            self._all_fields = (
                self.resource_mapper.TOP_LEVEL_NON_ATTRIBUTES_FIELDS.copy()
            )
            self._all_fields |= self.get_attribute_fields()
            # All provider-specific fields
            self._all_fields |= {
                f"_{self.provider_prefix}_{field_name}"
                for field_name in self.provider_fields
            }

        return self._all_fields

    def get_attribute_fields(self) -> Set[str]:
        """Get the set of attribute fields

        Return only the _first-level_ attribute fields from the schema of the resource class,
        resolving references along the way if needed.

        Note:
            It is not needed to take care of other special OpenAPI schema keys than `allOf`,
            since only `allOf` will be found in this context.
            Other special keys can be found in [the Swagger documentation](https://swagger.io/docs/specification/data-models/oneof-anyof-allof-not/).

        Returns:
            Property names.

        """

        schema = self.resource_cls.schema()
        attributes = schema["properties"]["attributes"]
        if "allOf" in attributes:
            allOf = attributes.pop("allOf")
            for dict_ in allOf:
                attributes.update(dict_)
        if "$ref" in attributes:
            path = attributes["$ref"].split("/")[1:]
            attributes = schema.copy()
            while path:
                next_key = path.pop(0)
                attributes = attributes[next_key]
        return set(attributes["properties"].keys())

    def handle_query_params(
        self, params: Union[EntryListingQueryParams, SingleEntryQueryParams]
    ) -> Dict[str, Any]:
        """Parse and interpret the backend-agnostic query parameter models into a dictionary
        that can be used by the specific backend.

        Note:
            Currently this method returns the pymongo interpretation of the parameters,
            which will need modification for modified for other backends.

        Parameters:
            params: The initialized query parameter model from the server.

        Raises:
            Forbidden: If too large of a page limit is provided.
            BadRequest: If an invalid request is made, e.g., with incorrect fields
                or response format.

        Returns:
            A dictionary representation of the query parameters.

        """
        cursor_kwargs = {}

        # filter
        if getattr(params, "filter", False):
            cursor_kwargs["filter"] = self.transformer.transform(
                self.parser.parse(params.filter)
            )
        else:
            cursor_kwargs["filter"] = {}

        # response_format
        if (
            getattr(params, "response_format", False)
            and params.response_format != "json"
        ):
            raise BadRequest(
                detail=f"Response format {params.response_format} is not supported, please use response_format='json'"
            )

        # page_limit
        if getattr(params, "page_limit", False):
            limit = params.page_limit
            if limit > CONFIG.page_limit_max:
                raise Forbidden(
                    detail=f"Max allowed page_limit is {CONFIG.page_limit_max}, you requested {limit}",
                )
            cursor_kwargs["limit"] = limit
        else:
            cursor_kwargs["limit"] = CONFIG.page_limit

        # response_fields
        cursor_kwargs["projection"] = {
            f"{self.resource_mapper.get_backend_field(f)}": True
            for f in self.all_fields
        }

        if getattr(params, "response_fields", False):
            response_fields = set(params.response_fields.split(","))
            response_fields |= self.resource_mapper.get_required_fields()
        else:
            response_fields = self.all_fields.copy()

        cursor_kwargs["fields"] = response_fields

        # sort
        if getattr(params, "sort", False):
            cursor_kwargs["sort"] = self.parse_sort_params(params.sort)

        # page_offset and page_number
        if getattr(params, "page_offset", False):
            if getattr(params, "page_number", False):
                warnings.warn(
                    message="Only one of the query parameters 'page_number' and 'page_offset' should be set - 'page_number' will be ignored.",
                    category=QueryParamNotUsed,
                )

            cursor_kwargs["skip"] = params.page_offset
        elif getattr(params, "page_number", False):
            if isinstance(params.page_number, int):
                cursor_kwargs["skip"] = (params.page_number - 1) * cursor_kwargs[
                    "limit"
                ]

        return cursor_kwargs

    def parse_sort_params(self, sort_params: str) -> Tuple[Tuple[str, int]]:
        """Handles any sort parameters passed to the collection,
        resolving aliases and dealing with any invalid fields.

        Raises:
            BadRequest: if an invalid sort is requested.

        Returns:
            A tuple of tuples containing the aliased field name and
            sort direction encoded as 1 (ascending) or -1 (descending).

        """
        sort_spec = []
        for field in sort_params.split(","):
            sort_dir = 1
            if field.startswith("-"):
                field = field[1:]
                sort_dir = -1
            aliased_field = self.resource_mapper.get_backend_field(field)
            sort_spec.append((aliased_field, sort_dir))

        unknown_fields = [
            field
            for field, _ in sort_spec
            if self.resource_mapper.get_optimade_field(field) not in self.all_fields
        ]

        if unknown_fields:
            error_detail = "Unable to sort on unknown field{} '{}'".format(
                "s" if len(unknown_fields) > 1 else "",
                "', '".join(unknown_fields),
            )

            # If all unknown fields are "other" provider-specific, then only provide a warning
            if all(
                (
                    re.match(r"_[a-z_0-9]+_[a-z_0-9]*", field)
                    and not field.startswith(f"_{self.provider_prefix}_")
                )
                for field in unknown_fields
            ):
                warnings.warn(error_detail, FieldValueNotRecognized)

            # Otherwise, if all fields are unknown, or some fields are unknown and do not
            # have other provider prefixes, then return 400: Bad Request
            else:
                raise BadRequest(detail=error_detail)

        # If at least one valid field has been provided for sorting, then use that
        sort_spec = tuple(
            (field, sort_dir)
            for field, sort_dir in sort_spec
            if field not in unknown_fields
        )

        return sort_spec

all_fields: Set[str] property readonly

Get the set of all fields handled in this collection, from attribute fields in the schema, provider fields and top-level OPTIMADE fields.

The set of all fields are lazily created and then cached. This means the set is created the first time the property is requested and then cached.

Returns:

Type Description
Set[str]

All fields handled in this collection.

__init__(self, resource_cls, resource_mapper, transformer) special

Initialize the collection for the given parameters.

Parameters:

Name Type Description Default
resource_cls EntryResource

The EntryResource model that is stored by the collection.

required
resource_mapper BaseResourceMapper

A resource mapper object that handles aliases and format changes between deserialization and response.

required
transformer Transformer

The Lark Transformer used to interpret the filter.

required
Source code in optimade/server/entry_collections/entry_collections.py
def __init__(
    self,
    resource_cls: EntryResource,
    resource_mapper: BaseResourceMapper,
    transformer: Transformer,
):
    """Initialize the collection for the given parameters.

    Parameters:
        resource_cls (EntryResource): The `EntryResource` model
            that is stored by the collection.
        resource_mapper (BaseResourceMapper): A resource mapper
            object that handles aliases and format changes between
            deserialization and response.
        transformer (Transformer): The Lark `Transformer` used to
            interpret the filter.

    """
    self.parser = LarkParser()
    self.resource_cls = resource_cls
    self.resource_mapper = resource_mapper
    self.transformer = transformer

    self.provider_prefix = CONFIG.provider.prefix
    self.provider_fields = [
        field if isinstance(field, str) else field["name"]
        for field in CONFIG.provider_fields.get(resource_mapper.ENDPOINT, [])
    ]

    self._all_fields: Set[str] = None

__len__(self) special

Returns the total number of entries in the collection.

Source code in optimade/server/entry_collections/entry_collections.py
@abstractmethod
def __len__(self) -> int:
    """Returns the total number of entries in the collection."""

count(self, **kwargs)

Returns the number of entries matching the query specified by the keyword arguments.

Parameters:

Name Type Description Default
kwargs dict

Query parameters as keyword arguments.

{}
Source code in optimade/server/entry_collections/entry_collections.py
@abstractmethod
def count(self, **kwargs) -> int:
    """Returns the number of entries matching the query specified
    by the keyword arguments.

    Parameters:
        kwargs (dict): Query parameters as keyword arguments.

    """

find(self, params)

Fetches results and indicates if more data is available.

Also gives the total number of data available in the absence of page_limit. See EntryListingQueryParams for more information.

Parameters:

Name Type Description Default
params Union[optimade.server.query_params.EntryListingQueryParams, optimade.server.query_params.SingleEntryQueryParams]

Entry listing URL query params.

required

Returns:

Type Description
A tuple of various relevant values

(results, data_returned, more_data_available, exclude_fields, include_fields).

Source code in optimade/server/entry_collections/entry_collections.py
def find(
    self, params: Union[EntryListingQueryParams, SingleEntryQueryParams]
) -> Tuple[
    Union[List[EntryResource], EntryResource, None], int, bool, Set[str], Set[str]
]:
    """
    Fetches results and indicates if more data is available.

    Also gives the total number of data available in the absence of `page_limit`.
    See [`EntryListingQueryParams`][optimade.server.query_params.EntryListingQueryParams]
    for more information.

    Parameters:
        params: Entry listing URL query params.

    Returns:
        A tuple of various relevant values:
        (`results`, `data_returned`, `more_data_available`, `exclude_fields`, `include_fields`).

    """
    criteria = self.handle_query_params(params)
    single_entry = isinstance(params, SingleEntryQueryParams)
    response_fields = criteria.pop("fields")

    results, data_returned, more_data_available = self._run_db_query(
        criteria, single_entry
    )

    if single_entry:
        results = results[0] if results else None

        if data_returned > 1:
            raise NotFound(
                detail=f"Instead of a single entry, {data_returned} entries were found",
            )

    exclude_fields = self.all_fields - response_fields
    include_fields = (
        response_fields - self.resource_mapper.TOP_LEVEL_NON_ATTRIBUTES_FIELDS
    )

    bad_optimade_fields = set()
    bad_provider_fields = set()
    for field in include_fields:
        if field not in self.resource_mapper.ALL_ATTRIBUTES:
            if field.startswith("_"):
                if any(
                    field.startswith(f"_{prefix}_")
                    for prefix in self.resource_mapper.SUPPORTED_PREFIXES
                ):
                    bad_provider_fields.add(field)
            else:
                bad_optimade_fields.add(field)

    if bad_provider_fields:
        warnings.warn(
            message=f"Unrecognised field(s) for this provider requested in `response_fields`: {bad_provider_fields}.",
            category=UnknownProviderProperty,
        )

    if bad_optimade_fields:
        raise BadRequest(
            detail=f"Unrecognised OPTIMADE field(s) in requested `response_fields`: {bad_optimade_fields}."
        )

    if results:
        results = self.resource_mapper.deserialize(results)

    return (
        results,
        data_returned,
        more_data_available,
        exclude_fields,
        include_fields,
    )

get_attribute_fields(self)

Get the set of attribute fields

Return only the first-level attribute fields from the schema of the resource class, resolving references along the way if needed.

Note

It is not needed to take care of other special OpenAPI schema keys than allOf, since only allOf will be found in this context. Other special keys can be found in the Swagger documentation.

Returns:

Type Description
Set[str]

Property names.

Source code in optimade/server/entry_collections/entry_collections.py
def get_attribute_fields(self) -> Set[str]:
    """Get the set of attribute fields

    Return only the _first-level_ attribute fields from the schema of the resource class,
    resolving references along the way if needed.

    Note:
        It is not needed to take care of other special OpenAPI schema keys than `allOf`,
        since only `allOf` will be found in this context.
        Other special keys can be found in [the Swagger documentation](https://swagger.io/docs/specification/data-models/oneof-anyof-allof-not/).

    Returns:
        Property names.

    """

    schema = self.resource_cls.schema()
    attributes = schema["properties"]["attributes"]
    if "allOf" in attributes:
        allOf = attributes.pop("allOf")
        for dict_ in allOf:
            attributes.update(dict_)
    if "$ref" in attributes:
        path = attributes["$ref"].split("/")[1:]
        attributes = schema.copy()
        while path:
            next_key = path.pop(0)
            attributes = attributes[next_key]
    return set(attributes["properties"].keys())

handle_query_params(self, params)

Parse and interpret the backend-agnostic query parameter models into a dictionary that can be used by the specific backend.

Note

Currently this method returns the pymongo interpretation of the parameters, which will need modification for modified for other backends.

Parameters:

Name Type Description Default
params Union[optimade.server.query_params.EntryListingQueryParams, optimade.server.query_params.SingleEntryQueryParams]

The initialized query parameter model from the server.

required

Exceptions:

Type Description
Forbidden

If too large of a page limit is provided.

BadRequest

If an invalid request is made, e.g., with incorrect fields or response format.

Returns:

Type Description
Dict[str, Any]

A dictionary representation of the query parameters.

Source code in optimade/server/entry_collections/entry_collections.py
def handle_query_params(
    self, params: Union[EntryListingQueryParams, SingleEntryQueryParams]
) -> Dict[str, Any]:
    """Parse and interpret the backend-agnostic query parameter models into a dictionary
    that can be used by the specific backend.

    Note:
        Currently this method returns the pymongo interpretation of the parameters,
        which will need modification for modified for other backends.

    Parameters:
        params: The initialized query parameter model from the server.

    Raises:
        Forbidden: If too large of a page limit is provided.
        BadRequest: If an invalid request is made, e.g., with incorrect fields
            or response format.

    Returns:
        A dictionary representation of the query parameters.

    """
    cursor_kwargs = {}

    # filter
    if getattr(params, "filter", False):
        cursor_kwargs["filter"] = self.transformer.transform(
            self.parser.parse(params.filter)
        )
    else:
        cursor_kwargs["filter"] = {}

    # response_format
    if (
        getattr(params, "response_format", False)
        and params.response_format != "json"
    ):
        raise BadRequest(
            detail=f"Response format {params.response_format} is not supported, please use response_format='json'"
        )

    # page_limit
    if getattr(params, "page_limit", False):
        limit = params.page_limit
        if limit > CONFIG.page_limit_max:
            raise Forbidden(
                detail=f"Max allowed page_limit is {CONFIG.page_limit_max}, you requested {limit}",
            )
        cursor_kwargs["limit"] = limit
    else:
        cursor_kwargs["limit"] = CONFIG.page_limit

    # response_fields
    cursor_kwargs["projection"] = {
        f"{self.resource_mapper.get_backend_field(f)}": True
        for f in self.all_fields
    }

    if getattr(params, "response_fields", False):
        response_fields = set(params.response_fields.split(","))
        response_fields |= self.resource_mapper.get_required_fields()
    else:
        response_fields = self.all_fields.copy()

    cursor_kwargs["fields"] = response_fields

    # sort
    if getattr(params, "sort", False):
        cursor_kwargs["sort"] = self.parse_sort_params(params.sort)

    # page_offset and page_number
    if getattr(params, "page_offset", False):
        if getattr(params, "page_number", False):
            warnings.warn(
                message="Only one of the query parameters 'page_number' and 'page_offset' should be set - 'page_number' will be ignored.",
                category=QueryParamNotUsed,
            )

        cursor_kwargs["skip"] = params.page_offset
    elif getattr(params, "page_number", False):
        if isinstance(params.page_number, int):
            cursor_kwargs["skip"] = (params.page_number - 1) * cursor_kwargs[
                "limit"
            ]

    return cursor_kwargs

insert(self, data)

Add the given entries to the underlying database.

Parameters:

Name Type Description Default
data List[optimade.models.entries.EntryResource]

The entry resource objects to add to the database.

required
Source code in optimade/server/entry_collections/entry_collections.py
@abstractmethod
def insert(self, data: List[EntryResource]) -> None:
    """Add the given entries to the underlying database.

    Arguments:
        data: The entry resource objects to add to the database.

    """

parse_sort_params(self, sort_params)

Handles any sort parameters passed to the collection, resolving aliases and dealing with any invalid fields.

Exceptions:

Type Description
BadRequest

if an invalid sort is requested.

Returns:

Type Description
Tuple[Tuple[str, int]]

A tuple of tuples containing the aliased field name and sort direction encoded as 1 (ascending) or -1 (descending).

Source code in optimade/server/entry_collections/entry_collections.py
def parse_sort_params(self, sort_params: str) -> Tuple[Tuple[str, int]]:
    """Handles any sort parameters passed to the collection,
    resolving aliases and dealing with any invalid fields.

    Raises:
        BadRequest: if an invalid sort is requested.

    Returns:
        A tuple of tuples containing the aliased field name and
        sort direction encoded as 1 (ascending) or -1 (descending).

    """
    sort_spec = []
    for field in sort_params.split(","):
        sort_dir = 1
        if field.startswith("-"):
            field = field[1:]
            sort_dir = -1
        aliased_field = self.resource_mapper.get_backend_field(field)
        sort_spec.append((aliased_field, sort_dir))

    unknown_fields = [
        field
        for field, _ in sort_spec
        if self.resource_mapper.get_optimade_field(field) not in self.all_fields
    ]

    if unknown_fields:
        error_detail = "Unable to sort on unknown field{} '{}'".format(
            "s" if len(unknown_fields) > 1 else "",
            "', '".join(unknown_fields),
        )

        # If all unknown fields are "other" provider-specific, then only provide a warning
        if all(
            (
                re.match(r"_[a-z_0-9]+_[a-z_0-9]*", field)
                and not field.startswith(f"_{self.provider_prefix}_")
            )
            for field in unknown_fields
        ):
            warnings.warn(error_detail, FieldValueNotRecognized)

        # Otherwise, if all fields are unknown, or some fields are unknown and do not
        # have other provider prefixes, then return 400: Bad Request
        else:
            raise BadRequest(detail=error_detail)

    # If at least one valid field has been provided for sorting, then use that
    sort_spec = tuple(
        (field, sort_dir)
        for field, sort_dir in sort_spec
        if field not in unknown_fields
    )

    return sort_spec

create_collection(name, resource_cls, resource_mapper)

Create an EntryCollection of the configured type, depending on the value of CONFIG.database_backend.

Parameters:

Name Type Description Default
name str

The collection name.

required
resource_cls EntryResource

The type of entry resource to be stored within the collection.

required
resource_mapper BaseResourceMapper

The associated resource mapper for that entry resource type.

required

Returns:

Type Description
EntryCollection

The created EntryCollection.

Source code in optimade/server/entry_collections/entry_collections.py
def create_collection(
    name: str, resource_cls: EntryResource, resource_mapper: BaseResourceMapper
) -> "EntryCollection":
    """Create an `EntryCollection` of the configured type, depending on the value of
    `CONFIG.database_backend`.

    Arguments:
        name: The collection name.
        resource_cls: The type of entry resource to be stored within the collection.
        resource_mapper: The associated resource mapper for that entry resource type.

    Returns:
        The created `EntryCollection`.

    """
    if CONFIG.database_backend in (
        SupportedBackend.MONGODB,
        SupportedBackend.MONGOMOCK,
    ):
        from optimade.server.entry_collections.mongo import MongoCollection

        return MongoCollection(
            name=name,
            resource_cls=resource_cls,
            resource_mapper=resource_mapper,
        )

    if CONFIG.database_backend is SupportedBackend.ELASTIC:
        from optimade.server.entry_collections.elasticsearch import ElasticCollection

        return ElasticCollection(
            name=name,
            resource_cls=resource_cls,
            resource_mapper=resource_mapper,
        )

    raise NotImplementedError(
        f"The database backend {CONFIG.database_backend!r} is not implemented"
    )