Source code for radiant_mlhub.models.dataset

"""Extensions of the `PySTAC <https://pystac.readthedocs.io/en/latest/>`_ classes that provide convenience methods for interacting
with the `Radiant MLHub API <https://docs.mlhub.earth/#radiant-mlhub-api>`_."""

from __future__ import annotations

import concurrent.futures
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Iterable, Iterator, List, Optional, Union

from .. import client
from . import Collection

TagOrTagList = Union[str, Iterable[str]]
TextOrTextList = Union[str, Iterable[str]]


class Dataset:
    """Class that brings together multiple Radiant MLHub "collections" that are all considered part of a single "dataset". For instance,
    the ``bigearthnet_v1`` dataset is composed of both a source imagery collection (``bigearthnet_v1_source``) and a labels collection
    (``bigearthnet_v1_labels``).

    Attributes
    ----------
    id : str
        The dataset ID.
    title : str or None
        The title of the dataset, or ``None`` if the dataset has no title.
    registry_url : str or None
        The URL to the registry page for this dataset, or ``None`` if no registry page exists.
    doi : str or None
        The DOI identifier for this dataset, or ``None`` if there is no DOI for this dataset.
    citation : str or None
        The citation information for this dataset, or ``None`` if there is no citation information.
    """

    def __init__(
            self,
            id: str,
            collections: List[Dict[str, Any]],
            title: Optional[str] = None,
            registry: Optional[str] = None,
            doi: Optional[str] = None,
            citation: Optional[str] = None,
            *,
            api_key: Optional[str] = None,
            profile: Optional[str] = None,
            # Absorbs additional keyword arguments to protect against changes to the dataset object returned by the API.
            # https://github.com/radiantearth/radiant-mlhub/issues/41
            **_: Any
    ):
        self.id = id
        self.title = title
        self.collection_descriptions = collections
        self.registry_url = registry
        self.doi = doi
        self.citation = citation

        self.session_kwargs: Dict[str, Any] = {}
        if api_key:
            self.session_kwargs['api_key'] = api_key
        if profile:
            self.session_kwargs['profile'] = profile

        self._collections: Optional['_CollectionList'] = None

    @property
    def collections(self) -> _CollectionList:
        """List of collections associated with this dataset. The returned list has 2 additional attributes (``source_imagery`` and
        ``labels``) that represent the lists of collections corresponding to each type.

        .. note::

            This property is cached on the instance, so updating ``self.collection_descriptions`` after calling ``self.collections``
            the first time will have no effect on the results. The cached value is stored on ``self._collections``.

        Examples
        --------
        >>> from radiant_mlhub import Dataset
        >>> dataset = Dataset.fetch('bigearthnet_v1')
        >>> len(dataset.collections)
        2
        >>> len(dataset.collections.source_imagery)
        1
        >>> len(dataset.collections.labels)
        1

        To loop through all collections:

        >>> for collection in dataset.collections:
        ...     # Do something here
        ...     pass

        To loop through only the source imagery collections:

        >>> for collection in dataset.collections.source_imagery:
        ...     # Do something here
        ...     pass

        To loop through only the label collections:

        >>> for collection in dataset.collections.labels:
        ...     # Do something here
        ...     pass
        """
        if self._collections is None:
            # Internal helper to return a Collection along with its CollectionType values.
            def _fetch_collection(_collection_description: Dict[str, Any]) -> _CollectionWithType:
                return _CollectionWithType(
                    Collection.fetch(_collection_description['id'], **self.session_kwargs),
                    [CollectionType(type_) for type_ in _collection_description['types']]
                )

            # Fetch all collections and create Collection instances.
            if len(self.collection_descriptions) == 1:
                # If there is only 1 collection, fetch it in the same thread.
                only_description = self.collection_descriptions[0]
                collections = [_fetch_collection(only_description)]
            else:
                # If there are multiple collections, fetch them concurrently.
                with concurrent.futures.ThreadPoolExecutor() as exc:
                    collections = list(exc.map(_fetch_collection, self.collection_descriptions))

            self._collections = _CollectionList(collections)

        return self._collections
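    # Illustrative sketch (values are hypothetical, not a documented API payload):
    # based on how ``_fetch_collection`` uses the descriptions above, each entry in
    # the ``collections`` argument needs at least an ``id`` and a ``types`` key,
    # where each type is a valid :class:`CollectionType` value.
    #
    #   >>> descriptions = [
    #   ...     {'id': 'bigearthnet_v1_source', 'types': ['source_imagery']},
    #   ...     {'id': 'bigearthnet_v1_labels', 'types': ['labels']},
    #   ... ]
    #   >>> dataset = Dataset('bigearthnet_v1', descriptions)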
    @classmethod
    def list(
            cls,
            *,
            tags: Optional[TagOrTagList] = None,
            text: Optional[TextOrTextList] = None,
            api_key: Optional[str] = None,
            profile: Optional[str] = None
    ) -> List['Dataset']:
        """Returns a list of :class:`Dataset` instances for each dataset hosted by MLHub.

        See the :ref:`Authentication` documentation for details on how authentication is handled for this request.

        Parameters
        ----------
        tags : str or Iterable[str], optional
            A tag or list of tags to filter datasets by. If not ``None``, only datasets containing all provided tags will be returned.
        text : str or Iterable[str], optional
            A text phrase or list of text phrases to filter datasets by. If not ``None``, only datasets containing all phrases will
            be returned.
        api_key : str
            An API key to use for this request. This will override an API key set in a profile or using an environment variable.
        profile : str
            A profile to use when making this request.

        Returns
        -------
        datasets : List[Dataset]
        """
        return [
            cls(**d, api_key=api_key, profile=profile)
            for d in client.list_datasets(tags=tags, text=text, api_key=api_key, profile=profile)
        ]
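    # Usage sketch for ``list`` (the tag and phrase values below are hypothetical;
    # the filters are passed straight through to ``client.list_datasets``):
    #
    #   >>> from radiant_mlhub import Dataset
    #   >>> all_datasets = Dataset.list()
    #   >>> filtered = Dataset.list(tags=['segmentation'], text=['crop'])
    #   >>> [d.id for d in filtered]  # doctest: +SKIP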
    @classmethod
    def fetch_by_doi(cls, dataset_doi: str, *, api_key: Optional[str] = None, profile: Optional[str] = None) -> "Dataset":
        """Creates a :class:`Dataset` instance by fetching the dataset with the given DOI from the Radiant MLHub API.

        Parameters
        ----------
        dataset_doi : str
            The DOI of the dataset to fetch (e.g. ``10.6084/m9.figshare.12047478.v2``).
        api_key : str
            An API key to use for this request. This will override an API key set in a profile or using an environment variable.
        profile : str
            A profile to use when making this request.

        Returns
        -------
        dataset : Dataset
        """
        return cls(
            **client.get_dataset_by_doi(dataset_doi, api_key=api_key, profile=profile),
            api_key=api_key,
            profile=profile,
        )
    @classmethod
    def fetch_by_id(cls, dataset_id: str, *, api_key: Optional[str] = None, profile: Optional[str] = None) -> 'Dataset':
        """Creates a :class:`Dataset` instance by fetching the dataset with the given ID from the Radiant MLHub API.

        Parameters
        ----------
        dataset_id : str
            The ID of the dataset to fetch (e.g. ``bigearthnet_v1``).
        api_key : str
            An API key to use for this request. This will override an API key set in a profile or using an environment variable.
        profile : str
            A profile to use when making this request.

        Returns
        -------
        dataset : Dataset
        """
        # Pass the session kwargs through to the instance so that subsequent collection
        # fetches use the same credentials (consistent with fetch() and fetch_by_doi()).
        return cls(
            **client.get_dataset_by_id(dataset_id, api_key=api_key, profile=profile),
            api_key=api_key,
            profile=profile,
        )
    @classmethod
    def fetch(cls, dataset_id_or_doi: str, *, api_key: Optional[str] = None, profile: Optional[str] = None) -> 'Dataset':
        """Creates a :class:`Dataset` instance by first trying to fetch the dataset by ID, then falling back to fetching by DOI.

        Parameters
        ----------
        dataset_id_or_doi : str
            The ID or DOI of the dataset to fetch (e.g. ``bigearthnet_v1`` or ``10.6084/m9.figshare.12047478.v2``).
        api_key : str
            An API key to use for this request. This will override an API key set in a profile or using an environment variable.
        profile : str
            A profile to use when making this request.

        Returns
        -------
        dataset : Dataset
        """
        return cls(
            **client.get_dataset(dataset_id_or_doi, api_key=api_key, profile=profile),
            api_key=api_key,
            profile=profile,
        )
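    # Usage sketch: ``fetch`` accepts either identifier form. Both example values are
    # taken from the docstrings above (a valid API key must be configured; see the
    # Authentication docs).
    #
    #   >>> by_id = Dataset.fetch('bigearthnet_v1')
    #   >>> by_doi = Dataset.fetch('10.6084/m9.figshare.12047478.v2')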
    def download(
            self,
            output_dir: Union[Path, str],
            *,
            if_exists: str = 'resume',
            api_key: Optional[str] = None,
            profile: Optional[str] = None
    ) -> List[Path]:
        """Downloads archives for all collections associated with this dataset to the given directory. Each archive will be named
        using the collection ID (e.g. ``some_collection.tar.gz``). If ``output_dir`` does not exist, it will be created.

        .. note::

            Some collections may be very large and take a significant amount of time to download, depending on your connection speed.

        Parameters
        ----------
        output_dir : str or pathlib.Path
            The directory into which the archives will be written.
        if_exists : str, optional
            How to handle an existing archive at the same location. If ``"skip"``, the download will be skipped. If ``"overwrite"``,
            the existing file will be overwritten and the entire file will be re-downloaded. If ``"resume"`` (the default), the
            existing file size will be compared to the size of the download (using the ``Content-Length`` header); if the existing
            file is smaller, then only the remaining portion will be downloaded. Otherwise, the download will be skipped.
        api_key : str
            An API key to use for this request. This will override an API key set in a profile or using an environment variable.
        profile : str
            A profile to use when making this request.

        Returns
        -------
        output_paths : List[pathlib.Path]
            List of paths to the downloaded archives.

        Raises
        ------
        IOError
            If ``output_dir`` exists and is not a directory.
        FileExistsError
            If one of the archive files already exists in ``output_dir`` and cannot be handled by the given ``if_exists`` strategy.
        """
        return [
            collection.download(output_dir, if_exists=if_exists, api_key=api_key, profile=profile)
            for collection in self.collections
        ]
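    # Usage sketch (the output directory is hypothetical): download every collection
    # archive for a dataset, resuming any partial downloads by default. The expected
    # file names follow the collection-ID naming rule described above.
    #
    #   >>> dataset = Dataset.fetch('bigearthnet_v1')
    #   >>> paths = dataset.download('data/bigearthnet')
    #   >>> [p.name for p in paths]  # doctest: +SKIP
    #   ['bigearthnet_v1_source.tar.gz', 'bigearthnet_v1_labels.tar.gz']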
    @property
    def total_archive_size(self) -> Optional[int]:
        """Gets the total size (in bytes) of the archives for all collections associated with this dataset. If no archives exist,
        returns ``None``."""
        # Since self.collections is cached on the Dataset instance, and collection.archive_size is
        # cached on each Collection, we don't bother to cache this property.
        archive_sizes = [
            collection.archive_size
            for collection in self.collections
            if collection.archive_size is not None
        ]
        return None if not archive_sizes else sum(archive_sizes)
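# Usage sketch: ``total_archive_size`` may be ``None``, so guard before formatting.
#
#   >>> dataset = Dataset.fetch('bigearthnet_v1')  # doctest: +SKIP
#   >>> size = dataset.total_archive_size
#   >>> print('unknown' if size is None else f'{size / 1e9:.2f} GB')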
class CollectionType(Enum):
    """Valid values for the type of a collection associated with a Radiant MLHub dataset."""

    SOURCE = 'source_imagery'
    LABELS = 'labels'
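# Usage sketch: the API reports collection types as plain strings, and the enum
# round-trips them (this is how ``_fetch_collection`` normalizes them above):
#
#   >>> CollectionType('labels') is CollectionType.LABELS
#   True
#   >>> CollectionType.SOURCE.value
#   'source_imagery'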
class _CollectionWithType:
    """Pairs a :class:`Collection` with its list of :class:`CollectionType` values."""

    def __init__(self, collection: Collection, types: List[CollectionType]):
        self.types = [CollectionType(type_) for type_ in types]
        self.collection = collection


class _CollectionList:
    """Used internally by :class:`Dataset` to create a list of collections that can also be accessed by type using the
    ``source_imagery`` and ``labels`` attributes."""

    _source_imagery: Optional[List[Collection]]
    _labels: Optional[List[Collection]]
    _collections: List[_CollectionWithType]

    def __init__(self, collections_with_type: List[_CollectionWithType]):
        self._collections = collections_with_type
        self._source_imagery = None
        self._labels = None

    def __iter__(self) -> Iterator[Collection]:
        for item in self._collections:
            yield item.collection

    def __len__(self) -> int:
        return len(self._collections)

    def __getitem__(self, item: int) -> Collection:
        return self._collections[item].collection

    def __repr__(self) -> str:
        return list(self.__iter__()).__repr__()

    @property
    def source_imagery(self) -> List[Collection]:
        if self._source_imagery is None:
            self._source_imagery = [
                c.collection for c in self._collections
                if any(type_ is CollectionType.SOURCE for type_ in c.types)
            ]
        return self._source_imagery

    @property
    def labels(self) -> List[Collection]:
        if self._labels is None:
            self._labels = [
                c.collection for c in self._collections
                if any(type_ is CollectionType.LABELS for type_ in c.types)
            ]
        return self._labels
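# Behavior sketch for the internal wrapper classes above (``some_collection`` is a
# hypothetical :class:`Collection` instance): iteration, indexing, and ``len`` all
# operate on plain ``Collection`` objects, while ``source_imagery`` and ``labels``
# filter by :class:`CollectionType`.
#
#   >>> cl = _CollectionList([_CollectionWithType(some_collection, [CollectionType.SOURCE])])
#   >>> len(cl) == 1 and cl[0] is some_collection
#   True
#   >>> cl.labels
#   []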