Source code for radiant_mlhub.client.catalog_downloader

import csv
import os
import threading
import tarfile
import sqlite3
import json
from glob import iglob

from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from io import TextIOWrapper
from logging import getLogger
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union, Any
from urllib.parse import urlparse
from dateutil.parser import parse as date_parser

from pydantic import BaseModel
from shapely.geometry import box, shape
from tqdm import tqdm

from ..if_exists import DownloadIfExistsOpts
from ..session import Session
from .resumable_downloader import ResumableDownloader

log = getLogger(__name__)

JsonDict = Dict[str, Any]
GeoJSON = JsonDict

COMMON_ASSET_NAMES = [
    'documentation',
    'readme',
    'test_split',
    'train_split',
    'validation_split',
]
"""Common assets will be put into `_common` and only downloaded once."""


class CatalogDownloaderConfig(BaseModel):
    """
    Configuration model & validator for CatalogDownloader.
    """

    class Config:
        arbitrary_types_allowed = True

    api_key: Optional[str] = None
    bbox: Optional[Union[Tuple[float], List[float]]] = None
    catalog_only: bool = False
    collection_filter: Optional[Dict[str, List[str]]] = None
    dataset_id: str
    if_exists: DownloadIfExistsOpts = DownloadIfExistsOpts.resume
    intersects: Optional[GeoJSON] = None
    output_dir: Path
    profile: Optional[str] = None
    session: Session
    temporal_query: Optional[Union[datetime, Tuple[datetime, datetime]]] = None
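

# Illustrative sketch (not part of the library source): one way a downloader
# configuration might be put together. The dataset id, output directory, bbox,
# and collection filter below are hypothetical placeholders; the Session is
# whatever your application already uses to talk to MLHub.
def _example_config(session: Session) -> CatalogDownloaderConfig:
    return CatalogDownloaderConfig(
        dataset_id='some_dataset_id',                 # hypothetical dataset id
        output_dir=Path('./mlhub-data'),              # where the catalog and assets are written
        session=session,
        bbox=[-122.6, 37.2, -121.8, 38.0],            # optional spatial filter (W, S, E, N)
        collection_filter={'some_collection': []},    # hypothetical allow-list: whole collection
        if_exists=DownloadIfExistsOpts.resume,        # resume partially downloaded files
    )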


class AssetRecord(BaseModel):
    """
    A stac_assets db record.
    """

    class Config:
        arbitrary_types_allowed = True

    rowid: Optional[int] = None
    asset_key: Optional[str] = None
    asset_save_path: Optional[str] = None
    asset_url: Optional[str] = None
    bbox_json: Optional[str] = None
    collection_id: Optional[str] = None
    common_asset: bool = False
    single_datetime: Optional[datetime] = None
    start_datetime: Optional[datetime] = None
    end_datetime: Optional[datetime] = None
    filtered: bool = False
    geometry_json: Optional[str] = None
    item_id: Optional[str] = None
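

# Illustrative sketch (not part of the library source): an AssetRecord roughly as
# _create_asset_list_step would populate it for a single item-level asset. All
# values are hypothetical; note how asset_save_path mirrors the
# <collection_id>/<item_id>/<asset_key><ext> layout built by _asset_save_path.
def _example_asset_record() -> AssetRecord:
    return AssetRecord(
        collection_id='some_collection',
        item_id='some_item_id',
        asset_key='labels',
        asset_url='https://example.com/some_item_id/labels.geojson',
        asset_save_path='some_collection/some_item_id/labels.geojson',
        common_asset=False,
        filtered=False,
    )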


class CatalogDownloader:
    config: CatalogDownloaderConfig
    err_report: TextIOWrapper
    err_report_path: Path
    catalog_file: Path
    work_dir: Path
    db_conn: sqlite3.Connection
    db_cur: sqlite3.Cursor

    def __init__(self, config: CatalogDownloaderConfig):
        if config.bbox is not None and config.intersects is not None:
            raise ValueError('Provide either the bbox or the intersects option (not both).')
        if config.intersects:
            if 'geometry' not in config.intersects:
                raise ValueError('intersects must be geojson with a geometry property')
        self.config = config
        self.work_dir = config.output_dir / config.dataset_id
        self.work_dir.mkdir(exist_ok=True, parents=True)
        self.err_report_path = self.work_dir / 'err_report.csv'

    def _fetch_unfiltered_count(self) -> int:
        self.db_cur.execute(
            """
            SELECT COUNT(DISTINCT asset_save_path) FROM assets WHERE filtered = 0
            """
        )
        (total_count, ) = self.db_cur.fetchone()
        return int(total_count)

    def _mark_asset_filtered(self, row_id: int) -> None:
        self.db_cur.execute(
            """
            UPDATE assets SET filtered = 1 WHERE rowid = ?
            """,
            [row_id]
        )
        self.db_conn.commit()

    def _fetch_catalog_step(self) -> None:
        """
        Fetch the stac catalog archive and save it to disk.
        Sets the path to the stac catalog .tar.gz.
        """
        c = self.config
        out_file = c.output_dir / f'{c.dataset_id}.tar.gz'
        dl = ResumableDownloader(
            session=c.session,
            url=f'/catalog/{c.dataset_id}',
            out_file=out_file,
            if_exists=c.if_exists,
            desc=f'{c.dataset_id}: fetch stac catalog',
            disable_progress_bar=False,
        )
        dl.run()
        assert out_file.exists()
        self.catalog_file = out_file

    def _unarchive_catalog_step(self) -> None:
        """
        Unarchive the stac catalog .tar.gz.
        In `skip` or `resume` mode, existing files are not overwritten.
        """
        c = self.config
        msg = f'unarchive {self.catalog_file.name}'
        log.info(msg)
        with tarfile.open(self.catalog_file, 'r:gz') as archive:
            if self.config.if_exists == DownloadIfExistsOpts.overwrite:
                archive.extractall(path=c.output_dir)
            else:
                members = archive.getmembers()
                for tar_info in tqdm(members, desc=msg):
                    if (c.output_dir / tar_info.name).exists():
                        continue
                    archive.extract(tar_info, path=c.output_dir)
        assert (self.work_dir / 'catalog.json').exists()

    def _create_asset_list_step(self) -> None:
        """
        Scan the stac catalog and extract the asset list into tabular format.
        Creates a table in the sqlite db.
        """
        msg = 'create stac asset list'
        log.info(msg)

        def _asset_save_path(rec: AssetRecord) -> Path:
            """
            Transform an asset into a local save path. This filesystem layout is
            the same as MLHub's collection archive .tar.gz files.
            """
            c = self.config
            ext = Path(str(urlparse(rec.asset_url).path)).suffix
            base_path = c.output_dir / c.dataset_id / rec.collection_id  # type: ignore
            asset_filename = f'{rec.asset_key}{ext}'
            if rec.item_id is None:
                # this is a collection level asset
                return base_path / asset_filename
            if rec.common_asset:
                # common assets: save to the _common dir (at the collection
                # level) instead of in every item subdir.
                return base_path / '_common' / asset_filename
            return base_path / rec.item_id / asset_filename

        def _insert_asset_rec(rec: AssetRecord) -> None:
            self.db_cur.execute(
                """
                INSERT INTO assets (
                    collection_id, item_id, asset_key, asset_url, asset_save_path,
                    filtered, common_asset, bbox_json, geometry_json,
                    single_datetime, start_datetime, end_datetime
                )
                VALUES (
                    :collection_id, :item_id, :asset_key, :asset_url, :asset_save_path,
                    :filtered, :common_asset, :bbox_json, :geometry_json,
                    :single_datetime, :start_datetime, :end_datetime
                );
                """,
                rec.dict()
            )
            self.db_conn.commit()

        def _handle_item(stac_item: JsonDict) -> None:
            item_id = stac_item['id']
            assets = stac_item['assets']
            props = stac_item['properties']
            common_meta = props.get('common_metadata', dict())
            bbox = stac_item.get('bbox', None)
            geometry = stac_item.get('geometry', None)
            if geometry and not bbox:
                raise RuntimeError(f'item {item_id} has no bbox, but has geometry')
            for k, v in assets.items():
                rec = AssetRecord(
                    collection_id=stac_item['collection'],
                    item_id=item_id,
                    asset_key=k,
                    common_asset=k in COMMON_ASSET_NAMES,
                    asset_url=v['href'],
                    bbox_json=json.dumps(bbox) if bbox else None,
                    geometry_json=json.dumps(geometry) if geometry else None,
                    single_datetime=props.get('datetime', None),
                    start_datetime=common_meta.get('start_datetime', None),
                    end_datetime=common_meta.get('end_datetime', None),
                )
                asset_save_path = _asset_save_path(rec).relative_to(self.work_dir)
                rec.asset_save_path = str(asset_save_path)
                _insert_asset_rec(rec)

        def _handle_collection(stac_collection: JsonDict) -> None:
            collection_id = stac_collection['id']
            assets = stac_collection.get('assets', None)
            if assets is None:
                return
            for k, v in assets.items():
                rec = AssetRecord(
                    collection_id=collection_id,
                    asset_key=k,
                    asset_url=v['href'],
                )
                asset_save_path = _asset_save_path(rec).relative_to(self.work_dir)
                rec.asset_save_path = str(asset_save_path)
                _insert_asset_rec(rec)

        json_srcs = iglob(str(self.work_dir / '**/*.json'), recursive=True)
        for json_src in json_srcs:
            p = Path(json_src)
            if p.name == 'catalog.json':
                continue
            with open(json_src) as json_fh:
                stac_item = json.load(json_fh)
                stac_type = stac_item.get('type', None)
                if p.name == 'collection.json' or stac_type == 'Collection':
                    _handle_collection(stac_item)
                else:
                    _handle_item(stac_item)
        log.info(f'{self._fetch_unfiltered_count()} unique assets in stac catalog.')

    def _filter_collections_step(self) -> None:
        """
        Iterate through the filters and mark entries in the assets table as
        `filtered`. Filter is an allow-list. Only matching collection_ids and,
        optionally, asset keys will be included.
""" f = self.config.collection_filter if f is None: return desc = 'filter by collection ids and asset keys' log.info(desc) total_asset_ct = self._fetch_unfiltered_count() self.db_cur.execute( """ SELECT rowid, collection_id, asset_key FROM assets WHERE filtered = 0 AND item_id IS NOT NULL """ ) progress = tqdm(total=total_asset_ct, desc=desc) progress_value = 0 row_ids_to_filter = set() while True: rows = self.db_cur.fetchmany() if not rows: progress.update(total_asset_ct) break progress_value += len(rows) progress.update(progress_value) for row_tuple in rows: (row_id, collection_id, asset_key) = row_tuple filtered = True if collection_id in f: # collection_id is a key in the filter (allow list) filter_asset_keys = f[collection_id] if not filter_asset_keys: # no asset keys, so include because of collection id filtered = False else: # check each asset key if asset_key in filter_asset_keys: # include asset because it's key appears in filter (allow list) filtered = False if filtered: row_ids_to_filter.add(row_id) for row_id in row_ids_to_filter: self._mark_asset_filtered(row_id) total_asset_ct = self._fetch_unfiltered_count() if total_asset_ct == 0: raise RuntimeError( f'after filtering collections_ids and asset keys, zero assets to download. filter: {filter}' ) log.info(f'{total_asset_ct} assets after collection filter.') def _filter_bbox_step(self) -> None: """ Filter items by bounding box intersection. Marks items in the assets table as `filtered` if they do not intersect. """ desc = 'filter by bounding box' if self.config.bbox is None: return bbox_polygon_query = box(*self.config.bbox) log.info(desc) total_asset_ct = self._fetch_unfiltered_count() self.db_cur.execute( """ SELECT rowid, item_id, bbox_json FROM assets WHERE filtered = 0 AND item_id IS NOT NULL ORDER BY item_id """ ) progress = tqdm(total=total_asset_ct, desc=desc) progress_value = 0 row_ids_to_filter = set() while True: rows = self.db_cur.fetchmany() if not rows: progress.update(total_asset_ct) break progress_value += len(rows) progress.update(progress_value) # cache the bboxs, which belong to items, not to the asset. the # results are ordered by item_id, and since we're within # db_cur.fetchmany(), this is a cache with bounded size. item_bbox_cache: Dict[str, bool] = dict() for row_tuple in rows: (row_id, item_id, bbox_json) = row_tuple if not bbox_json: log.warning(f'item missing bbox: {item_id}') continue hit = item_bbox_cache.get(item_id, None) if hit is None: bbox = json.loads(bbox_json) item_bbox_polygon = box(*bbox) hit = bbox_polygon_query.intersects(item_bbox_polygon) item_bbox_cache[item_id] = hit if not hit: row_ids_to_filter.add(row_id) for row_id in row_ids_to_filter: self._mark_asset_filtered(row_id) total_asset_ct = self._fetch_unfiltered_count() if total_asset_ct == 0: raise RuntimeError( f'after filtering by bounding box, zero assets to download. filter: {filter}' ) log.info(f'{total_asset_ct} assets after bounding box filter.') def _filter_intersects_step(self) -> None: """ Filter items by geojson vs. bounding box intersection. Marks items in the assets table as `filtered` if they do not intersect. 
""" f = self.config.intersects if f is None: return desc = 'filter by intersects' log.info(desc) intersects_shape_query = shape(f['geometry']) total_asset_ct = self._fetch_unfiltered_count() self.db_cur.execute( """ SELECT rowid, item_id, bbox_json FROM assets WHERE filtered = 0 AND item_id IS NOT NULL ORDER BY item_id """ ) progress = tqdm(total=total_asset_ct, desc=desc) progress_value = 0 row_ids_to_filter = set() while True: rows = self.db_cur.fetchmany() if not rows: progress.update(total_asset_ct) break progress_value += len(rows) progress.update(progress_value) # cache the spatial join test, which belong to items, not to the # asset. the results are ordered by item_id, and since we're within # db_cur.fetchmany(), here we maintain a cache with bounded size. item_intersects_cache: Dict[str, bool] = dict() for row_tuple in rows: (row_id, item_id, bbox_json) = row_tuple if not bbox_json: log.warning(f'item missing bbox: {item_id}') continue hit = item_intersects_cache.get(item_id, None) if hit is None: bbox = json.loads(bbox_json) item_bbox_polygon = box(*bbox) hit = intersects_shape_query.intersects(item_bbox_polygon) item_intersects_cache[item_id] = hit if not hit: row_ids_to_filter.add(row_id) for row_id in row_ids_to_filter: self._mark_asset_filtered(row_id) total_asset_ct = self._fetch_unfiltered_count() if total_asset_ct == 0: raise RuntimeError( f'after filtering by intersects, zero assets to download. filter: {filter}' ) log.info(f'{total_asset_ct} assets after intersects filter.') def _filter_temporal_step(self) -> None: """ Filter items by temporal query. Marks items in the assets table as `filtered` if they do not fall in the temporal range or single day. """ def one_to_one_check(d1: datetime, d2: datetime) -> bool: """ Compare day for each. """ return d1.day == d2.day def one_to_range_check(d1: datetime, d2: Tuple[datetime, datetime]) -> bool: """ Compare single datetime with date range. """ (d2_start, d2_end) = d2 return d1 >= d2_start and d1 <= d2_end def range_to_range_check(d1: Tuple[datetime, datetime], d2: Tuple[datetime, datetime]) -> bool: """ Compare two date ranges. 
""" (d1_start, d1_end) = d1 (d2_start, d2_end) = d2 if d1_start >= d2_start and d1_start <= d2_end: return True if d1_end >= d2_start and d1_start <= d2_end: return True return False q = self.config.temporal_query if q is None: return desc = 'filter by temporal query' log.info(desc) total_asset_ct = self._fetch_unfiltered_count() self.db_cur.execute( """ SELECT rowid, item_id, single_datetime, start_datetime, end_datetime FROM assets WHERE filtered = 0 and item_id IS NOT NULL """ ) progress = tqdm(total=total_asset_ct, desc=desc) progress_value = 0 row_ids_to_filter = set() while True: rows = self.db_cur.fetchmany() if not rows: progress.update(total_asset_ct) break progress_value += len(rows) progress.update(progress_value) for row_tuple in rows: (row_id, item_id, single_datetime, start_datetime, end_datetime) = row_tuple filtered = False # inspect the stac item for datetime properties if single_datetime: # item has single date property if isinstance(q, tuple): filtered = not one_to_range_check( date_parser(single_datetime), q ) else: filtered = not one_to_one_check( date_parser(single_datetime), q ) else: # item has date range properties start = date_parser(start_datetime) end = date_parser(end_datetime) if not start or not end: # cannot process date range, just skip forward and log a warning log.warn(f'cannot compare to missing date range for: {item_id}') next if isinstance(q, tuple): filtered = not range_to_range_check((start, end), q) else: filtered = not one_to_range_check(q, (start, end)) if filtered: row_ids_to_filter.add(row_id) for row_id in row_ids_to_filter: self._mark_asset_filtered(row_id) total_asset_ct = self._fetch_unfiltered_count() if total_asset_ct == 0: raise RuntimeError( f'after filtering by temporal query, zero assets to download. filter: {filter}' ) log.info(f'{total_asset_ct} assets after temporal filter.') def _asset_download_step(self) -> None: """ Download all assets assets table, which are not marked as filtered. Manage thread pool, build error_log. """ def _download_asset_worker( asset_url: str, out_file: Path, if_exists: DownloadIfExistsOpts, ) -> None: """ Download asset worker function (will be called in multithreaded context). Returns Path (out_file) on success, or raises Exception on error. Warning: if the asset url is scheme s3://, it will be transformed to https://{bucket_name}.s3.amazonaws.com . """ log.debug(f'(thread id: {threading.get_ident()}) {asset_url} -> {out_file}') if not out_file.parent.exists(): out_file.parent.mkdir(exist_ok=True, parents=True) if 's3://' in asset_url: # workaround for some datasets, e.g. spacenet # * use https instead of s3 (avoid adding boto3 dependency) # * use https://bucket.s3.amazonaws.com because the region is unknown u = urlparse(asset_url) bucket_name = u.netloc cleaned_url = asset_url.replace( f's3://{bucket_name}', f'https://{bucket_name}.s3.amazonaws.com' ) else: cleaned_url = asset_url dl = ResumableDownloader( url=cleaned_url, out_file=out_file, if_exists=if_exists, desc=f'fetch {asset_url}' ) dl.run() assert out_file.exists(), f'failed to create {out_file}' # create set of unique asset save paths to process in threads, to avoid # having subthreads attempt to write to same file (ex: with _common assets). 
        self.db_cur.execute("""
            SELECT asset_save_path, asset_url, collection_id, item_id, asset_key
            FROM assets
            WHERE filtered = 0
        """)
        asset_list = list()
        uniq_asset_save_path = set()
        while True:
            rows = self.db_cur.fetchmany()
            if not rows:
                break
            for row_tuple in rows:
                (asset_save_path, asset_url, collection_id, item_id, asset_key) = row_tuple
                if asset_save_path not in uniq_asset_save_path:
                    asset_rec = AssetRecord(
                        asset_save_path=asset_save_path,
                        asset_url=asset_url,
                        collection_id=collection_id,  # TODO yank?
                        item_id=item_id,  # TODO: yank?
                        asset_key=asset_key,
                    )
                    asset_list.append(asset_rec)
                    uniq_asset_save_path.add(asset_save_path)

        if 'PYTEST_CURRENT_TEST' in os.environ:
            # vcr.py does not work with multithreaded `requests`, so bail out here
            # and consider it a 'dry run'.
            return

        self._finalize_db()

        with ThreadPoolExecutor() as executor:
            future_to_asset_record = {
                executor.submit(
                    _download_asset_worker,
                    **dict(
                        asset_url=r.asset_url,
                        out_file=self.work_dir / r.asset_save_path,  # type: ignore
                        if_exists=self.config.if_exists,
                    )
                ): r
                for r in asset_list
            }
            for future in tqdm(
                as_completed(future_to_asset_record),
                desc='download assets',
                total=len(asset_list)
            ):
                asset_rec = future_to_asset_record[future]
                try:
                    future.result()
                except Exception as e:
                    # write a line to err_report in the format:
                    # (Error, Dataset ID, Collection ID, Item ID, Asset Key, Asset URL).
                    r = asset_rec
                    self.err_writer.writerow([
                        str(e),
                        self.config.dataset_id,
                        r.collection_id,
                        r.item_id,
                        r.asset_key,
                        r.asset_url
                    ])
                    log.exception(e)

    def _init_db(self) -> None:
        db_path = self.work_dir / 'mlhub_stac_assets.db'
        if db_path.exists():
            db_path.unlink()
        self.db_conn = sqlite3.connect(
            db_path,
            detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES
        )
        self.db_cur = self.db_conn.cursor()
        self.db_cur.arraysize = 1024
        self.db_cur.execute("""
            CREATE TABLE assets (
                collection_id TEXT,
                item_id TEXT,
                asset_key TEXT,
                asset_url TEXT,
                asset_save_path TEXT,
                filtered BOOLEAN,
                common_asset BOOLEAN,
                bbox_json TEXT,
                geometry_json TEXT,
                single_datetime TEXT,
                start_datetime TEXT,
                end_datetime TEXT
            )
        """)

    def _finalize_db(self) -> None:
        if not self.config.catalog_only:
            self.db_cur.close()
            self.db_conn.close()

    def __call__(self) -> None:
        """
        Create and run functions for each processing step.
        """
        c = self.config
        self.err_report = open(self.err_report_path, 'w')
        self.err_writer = csv.writer(self.err_report, quoting=csv.QUOTE_MINIMAL)
        steps: List[Callable[[], None]] = []
        steps.append(self._fetch_catalog_step)
        steps.append(self._unarchive_catalog_step)
        if not c.catalog_only:
            self._init_db()
            steps.append(self._create_asset_list_step)
            # conditional step for collection/item key filter
            if c.collection_filter:
                steps.append(self._filter_collections_step)
            # conditional step for temporal filter
            if c.temporal_query:
                steps.append(self._filter_temporal_step)
            # conditional step for bounding box spatial filter
            if c.bbox:
                steps.append(self._filter_bbox_step)
            # conditional step for geojson spatial filter
            if c.intersects:
                steps.append(self._filter_intersects_step)
            # final step: asset downloading
            steps.append(self._asset_download_step)

        # call each step
        for step in steps:
            step()

        # inspect the error report
        self.err_report.flush()
        self.err_report.close()
        if os.path.getsize(self.err_report_path) > 0:
            msg = f'asset download error(s) were logged to {self.err_report.name}'
            log.error(msg)
            raise IOError(msg)

        if c.catalog_only:
            log.info(f'catalog saved to {self.work_dir}')
        else:
            log.info(f'assets saved to {self.work_dir}')
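

# Illustrative usage sketch (not part of the library source): how the downloader
# might be driven end to end when this module is run directly. `get_session` is
# assumed to be importable from the radiant_mlhub package and to pick up the API
# key from the environment; the dataset id and output directory are hypothetical
# placeholders.
if __name__ == '__main__':
    from radiant_mlhub import get_session  # assumed top-level session helper

    example_config = CatalogDownloaderConfig(
        dataset_id='some_dataset_id',       # hypothetical dataset id
        output_dir=Path('./mlhub-data'),
        session=get_session(),
        catalog_only=False,                 # also download assets, not just the stac catalog
    )
    CatalogDownloader(config=example_config)()  # runs fetch, unarchive, filter, and download steps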