Source code for radiant_mlhub.models.dataset

"""Extensions of the PySTAC classes that provide convenience methods for interacting
with the Radiant MLHub API."""

from __future__ import annotations

from pathlib import Path
import concurrent.futures
from datetime import datetime

from enum import Enum
from typing import Any, Dict, Iterable, Iterator, List, Optional, Union, Tuple

from ..session import get_session

from .. import client
from . import Collection
from ..client import CatalogDownloader, CatalogDownloaderConfig
from ..if_exists import DownloadIfExistsOpts
TagOrTagList = Union[str, Iterable[str]]
TextOrTextList = Union[str, Iterable[str]]

JsonDict = Dict[str, Any]
GeoJSON = JsonDict

class Dataset:
    """Class that brings together multiple Radiant MLHub "collections" that are all considered part of
    a single "dataset". For instance, the ``bigearthnet_v1`` dataset is composed of both a source
    imagery collection (``bigearthnet_v1_source``) and a labels collection (``bigearthnet_v1_labels``).

    Attributes
    ----------
    id : str
        The dataset ID.
    title : str or None
        The title of the dataset (or ``None`` if dataset has no title).
    registry_url : str or None
        The URL to the registry page for this dataset, or ``None`` if no registry page exists.
    doi : str or None
        The DOI identifier for this dataset, or ``None`` if there is no DOI for this dataset.
    citation: str or None
        The citation information for this dataset, or ``None`` if there is no citation information.
    collection_descriptions : list of dict
        The raw collection descriptions passed to the constructor. Each one is read for its ``'id'``
        and ``'types'`` keys when :attr:`collections` is first accessed.
    session_kwargs : dict
        The ``api_key`` and/or ``profile`` passed to the constructor (only the ones that were truthy).
        They are forwarded to the requests made by :attr:`collections`, :attr:`stac_catalog_size` and
        :attr:`estimated_dataset_size`.
    """

    def __init__(
        self,
        id: str,
        collections: List[Dict[str, Any]],
        title: Optional[str] = None,
        registry: Optional[str] = None,
        doi: Optional[str] = None,
        citation: Optional[str] = None,
        *,
        api_key: Optional[str] = None,
        profile: Optional[str] = None,
        # Absorbs additional keyword arguments to protect against changes to dataset object from API
        **_: Any
    ):
        self.id = id
        self.title = title
        self.collection_descriptions = collections
        self.registry_url = registry
        self.doi = doi
        self.citation = citation

        # Only keep the credentials that were actually supplied so they can be splatted into later calls.
        self.session_kwargs: Dict[str, Any] = {}
        if api_key:
            self.session_kwargs['api_key'] = api_key
        if profile:
            self.session_kwargs['profile'] = profile

        # Populated lazily by the ``collections`` property.
        self._collections: Optional['_CollectionList'] = None

    @property
    def collections(self) -> _CollectionList:
        """List of collections associated with this dataset. The list that is returned has 2 additional
        attributes (``source_imagery`` and ``labels``) that represent the list of collections
        corresponding to each type.

        .. note::

            The result is cached on the instance (in ``self._collections``) the first time it is
            accessed, so updating ``self.collection_descriptions`` afterwards will have no effect on
            the results. Reset ``self._collections`` to ``None`` to force the collections to be
            fetched again.

        Examples
        --------

        >>> from radiant_mlhub import Dataset
        >>> dataset = Dataset.fetch('bigearthnet_v1')
        >>> len(dataset.collections)
        2
        >>> len(dataset.collections.source_imagery)
        1
        >>> len(dataset.collections.labels)
        1

        To loop through all collections

        >>> for collection in dataset.collections:
        ...     # Do something here

        To loop through only the source imagery collections:

        >>> for collection in dataset.collections.source_imagery:
        ...     # Do something here

        To loop through only the label collections:

        >>> for collection in dataset.collections.labels:
        ...     # Do something here
        """
        if self._collections is None:

            # Internal method to return a Collection along with its CollectionType
            def _fetch_collection(_collection_description: Dict[str, Any]) -> _CollectionWithType:
                return _CollectionWithType(
                    Collection.fetch(_collection_description['id'], **self.session_kwargs),
                    [CollectionType(type_) for type_ in _collection_description['types']]
                )

            # Fetch all collections and create Collection instances
            if len(self.collection_descriptions) == 1:
                # If there is only 1 collection, fetch it in the same thread
                only_description = self.collection_descriptions[0]
                collections = [_fetch_collection(only_description)]
            else:
                # If there are multiple collections, fetch them concurrently
                with concurrent.futures.ThreadPoolExecutor() as exc:
                    collections = list(exc.map(_fetch_collection, self.collection_descriptions))

            self._collections = _CollectionList(collections)

        return self._collections

    @classmethod
    def list(
        cls,
        *,
        tags: Optional[TagOrTagList] = None,
        text: Optional[TextOrTextList] = None,
        api_key: Optional[str] = None,
        profile: Optional[str] = None
    ) -> List['Dataset']:
        """Returns a list of :class:`Dataset` instances for each dataset hosted by MLHub.

        See the :ref:`Authentication` documentation for details on how authentication is handled for
        this request.

        Parameters
        ----------
        tags : A list of tags to filter datasets by. If not ``None``, only datasets
            containing all provided tags will be returned.
        text : A list of text phrases to filter datasets by. If not ``None``,
            only datasets containing all phrases will be returned.
        api_key : str
            An API key to use for this request. This will override an API key set in a profile or
            using an environment variable
        profile: str
            A profile to use when making this request.

        Returns
        -------
        datasets : List[Dataset]
        """
        return [
            cls(**d, api_key=api_key, profile=profile)
            for d in client.list_datasets(tags=tags, text=text, api_key=api_key, profile=profile)
        ]

    @classmethod
    def fetch_by_doi(cls, dataset_doi: str, *, api_key: Optional[str] = None, profile: Optional[str] = None) -> "Dataset":
        """Creates a :class:`Dataset` instance by fetching the dataset with the given DOI from the
        Radiant MLHub API.

        Parameters
        ----------
        dataset_doi : str
            The DOI of the dataset to fetch (e.g. ``10.6084/m9.figshare.12047478.v2``).
        api_key : str
            An API key to use for this request. This will override an API key set in a profile or
            using an environment variable
        profile: str
            A profile to use when making this request.

        Returns
        -------
        dataset : Dataset
        """
        return cls(
            **client.get_dataset_by_doi(dataset_doi, api_key=api_key, profile=profile),
            api_key=api_key,
            profile=profile,
        )

    @classmethod
    def fetch_by_id(cls, dataset_id: str, *, api_key: Optional[str] = None, profile: Optional[str] = None) -> 'Dataset':
        """Creates a :class:`Dataset` instance by fetching the dataset with the given ID from the
        Radiant MLHub API.

        Parameters
        ----------
        dataset_id : str
            The ID of the dataset to fetch (e.g. ``bigearthnet_v1``).
        api_key : str
            An API key to use for this request. This will override an API key set in a profile or
            using an environment variable
        profile: str
            A profile to use when making this request.

        Returns
        -------
        dataset : Dataset
        """
        return cls(
            **client.get_dataset_by_id(dataset_id, api_key=api_key, profile=profile),
            # Keep the credentials on the instance (as ``fetch`` and ``fetch_by_doi`` do) so that
            # follow-up requests such as ``collections`` are made with the same api_key/profile.
            api_key=api_key,
            profile=profile,
        )

    @classmethod
    def fetch(cls, dataset_id_or_doi: str, *, api_key: Optional[str] = None, profile: Optional[str] = None) -> 'Dataset':
        """Creates a :class:`Dataset` instance by first trying to fetch the dataset based on ID,
        then falling back to fetching by DOI.

        Parameters
        ----------
        dataset_id_or_doi : str
            The ID or DOI of the dataset to fetch (e.g. ``bigearthnet_v1``).
        api_key : str
            An API key to use for this request. This will override an API key set in a profile or
            using an environment variable
        profile: str
            A profile to use when making this request.

        Returns
        -------
        dataset : Dataset
        """
        return cls(
            **client.get_dataset(dataset_id_or_doi, api_key=api_key, profile=profile),
            api_key=api_key,
            profile=profile,
        )

    def __str__(self) -> str:
        """Return the "informal" or nicely printable string representation of an object."""
        return f'{self.id}: {self.title}'

    @property
    def stac_catalog_size(self) -> Optional[int]:
        """Size of the dataset_id.tar.gz STAC archive (bytes), or ``None`` if the catalog info does
        not report it."""
        info = client.get_catalog_info(self.id, **self.session_kwargs)
        return info.get('stac_catalog_size', None)

    @property
    def estimated_dataset_size(self) -> Optional[int]:
        """Estimated size of the entire dataset (bytes), or ``None`` if the catalog info does not
        report it."""
        info = client.get_catalog_info(self.id, **self.session_kwargs)
        return info.get('estimated_dataset_size', None)

    def download(
        self,
        output_dir: Union[Path, str] = Path.cwd(),
        *,
        catalog_only: bool = False,
        if_exists: DownloadIfExistsOpts = DownloadIfExistsOpts.resume,
        api_key: Optional[str] = None,
        profile: Optional[str] = None,
        bbox: Optional[List[float]] = None,
        intersects: Optional[GeoJSON] = None,
        datetime: Optional[Union[datetime, Tuple[datetime, datetime]]] = None,
        collection_filter: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """
        Downloads dataset's STAC catalog and all linked assets. The download may be customized and
        controlled by providing bbox, intersects, datetime, and filter options.

        Parameters
        ----------
        output_dir: str or pathlib.Path
            The directory into which the archives will be written. Defaults to current working
            directory. Note that the default is evaluated once, when the class is defined, not on
            every call.
        catalog_only: bool
            If True, the STAC catalog will be downloaded and unarchived, but no assets will be
            downloaded. Defaults to False.
        if_exists : str, optional
            Allowed values: `skip`, `overwrite`, or `resume` (default).
        bbox: Optional[List[float]]
            List representing a bounding box of coordinates, for spatial intersection filter.
            Must be in CRS EPSG:4326.
        intersects: Optional[GeoJSON]
            GeoJSON object for spatial intersects filter. Must be a parsed GeoJSON dict with a
            `geometry` property.
        datetime: Optional[datetime, Tuple[datetime, datetime]]
            Single datetime or datetime range for temporal filter.
        collection_filter: Optional[Dict[str, list]]
            Mapping of collection_id and asset keys to include (exclusively).

            examples:

            * download will only include this collection:
              `dict(ref_landcovernet_sa_v1_source_sentinel_2=[])`
            * download will only include this collection and only these asset keys:
              `dict(ref_landcovernet_sa_v1_source_sentinel_2=["B02", "B03", "B04"])`
        api_key: Optional[str]
            An API key to use for this request. This will override an API key set in a profile or
            using an environment variable. Only this argument is forwarded to the downloader; the
            credentials stored in ``self.session_kwargs`` are not consulted here.
        profile: Optional[str]
            Authentication Profile to use when making this request.

        Raises
        ------
        IOError
            If ``output_dir`` exists and is not a directory.
            If unrecoverable download errors occurred.
        ValueError
            If provided filters are incompatible, for example bbox and intersects.
        RuntimeError
            If filters result in zero assets to download.
        AssertionError
            If ``output_dir`` is falsy (e.g. ``None`` or an empty string).

        Error Reporting
        ---------------
        Any unrecoverable download errors will be logged to `{output_dir}/{dataset_id}/err_report.csv`.
        """
        # Kept as an assert for backward compatibility; be aware it is skipped under ``python -O``.
        assert output_dir
        output_path = Path(output_dir)
        if output_path.exists() and not output_path.is_dir():
            raise IOError('output_dir is not directory.')
        output_path.mkdir(exist_ok=True, parents=True)

        config = CatalogDownloaderConfig(
            catalog_only=catalog_only,
            api_key=api_key,
            bbox=bbox,
            # NOTE(review): this keyword was reconstructed - confirm the name against CatalogDownloaderConfig.
            dataset_id=self.id,
            collection_filter=collection_filter,
            if_exists=if_exists,
            intersects=intersects,
            output_dir=output_path,
            profile=profile,
            session=get_session(api_key=api_key, profile=profile),
            temporal_query=datetime,
        )
        dl = CatalogDownloader(config=config)
        dl()
class CollectionType(Enum):
    """Valid values for the type of a collection associated with a Radiant MLHub dataset."""

    # Member values are the strings that ``CollectionType(value)`` accepts when looking up a member.
    SOURCE = 'source_imagery'
    LABELS = 'labels'
class _CollectionWithType:
    """Pairs a collection with the list of :class:`CollectionType` values that apply to it."""

    def __init__(self, collection: Collection, types: List[CollectionType]):
        self.collection = collection
        # Every entry goes through the enum constructor, so both members and raw values are accepted.
        self.types = [CollectionType(entry) for entry in types]


class _CollectionList:
    """Used internally by :class:`Dataset` to create a list of collections that can also be accessed by
    type using the ``source_imagery`` and ``labels`` attributes."""

    _source_imagery: Optional[List[Collection]]
    _labels: Optional[List[Collection]]
    _collections: List[_CollectionWithType]

    def __init__(self, collections_with_type: List[_CollectionWithType]):
        self._collections = collections_with_type
        # The per-type views are built on first access and then reused.
        self._source_imagery = None
        self._labels = None

    def __iter__(self) -> Iterator[Collection]:
        # Iteration exposes the bare collections, not the typed wrappers.
        return (wrapped.collection for wrapped in self._collections)

    def __len__(self) -> int:
        return len(self._collections)

    def __getitem__(self, item: int) -> Collection:
        wrapped = self._collections[item]
        return wrapped.collection

    def __repr__(self) -> str:
        return repr(list(self))

    @property
    def source_imagery(self) -> List[Collection]:
        """Collections whose types include ``CollectionType.SOURCE`` (computed once, then cached)."""
        if self._source_imagery is not None:
            return self._source_imagery
        self._source_imagery = [
            wrapped.collection
            for wrapped in self._collections
            if CollectionType.SOURCE in wrapped.types
        ]
        return self._source_imagery

    @property
    def labels(self) -> List[Collection]:
        """Collections whose types include ``CollectionType.LABELS`` (computed once, then cached)."""
        if self._labels is not None:
            return self._labels
        self._labels = [
            wrapped.collection
            for wrapped in self._collections
            if CollectionType.LABELS in wrapped.types
        ]
        return self._labels