import csv
import json
import logging
import os
import sqlite3
import tarfile
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from glob import iglob
from io import TextIOWrapper
from logging import getLogger
from pathlib import Path
from typing import (Any, Callable, Dict, List, Optional, Set, Tuple, TypedDict,
Union)
from urllib.parse import urlparse
from dateutil.parser import parse as date_parser
from pydantic import BaseModel
from shapely.geometry import box, shape
from tqdm import tqdm
from ..if_exists import DownloadIfExistsOpts
from ..session import Session
from . import datetime_utils
from .resumable_downloader import ResumableDownloader
# Module-level logger used by every processing step in this file.
log = getLogger(__name__)

# Loose aliases for decoded JSON documents. GeoJSON is the same type as
# JsonDict; the second name only documents intent at the use site.
JsonDict = Dict[str, Any]
GeoJSON = JsonDict

#: Common assets will be put into `_common` and only downloaded once.
COMMON_ASSET_NAMES = [
    'documentation', 'readme', 'test_split', 'train_split', 'validation_split',
]
class CatalogDownloaderConfig(BaseModel):
    """
    Configuration model & validator for CatalogDownloader.

    `dataset_id`, `output_dir` and `mlhub_api_session` have no default and
    must be supplied; every other field is optional.
    """

    class Config:
        # allow field types pydantic has no built-in validator for
        # (presumably required for `Session` -- confirm).
        arbitrary_types_allowed = True

    api_key: Optional[str] = None
    # bounding box; the values are unpacked into shapely's `box()` by the
    # bbox filter step. Any length of float tuple/list is accepted by the
    # annotation (the previous `Tuple[float]` described a 1-tuple only).
    bbox: Optional[Union[Tuple[float, ...], List[float]]] = None
    # when True, only the catalog is fetched and unarchived (no asset db,
    # no asset download).
    catalog_only: bool = False
    # allow-list: collection id -> asset keys (an empty list keeps every
    # asset key of that collection).
    collection_filter: Optional[Dict[str, List[str]]] = None
    dataset_id: str
    if_exists: DownloadIfExistsOpts = DownloadIfExistsOpts.resume
    # geojson object; must contain a `geometry` member (checked by
    # CatalogDownloader.__init__).
    intersects: Optional[GeoJSON] = None
    output_dir: Path
    # optional separate root for downloaded assets; when unset, assets are
    # saved below `output_dir / dataset_id`.
    asset_output_dir: Optional[Path] = None
    profile: Optional[str] = None
    #: Requests session for mlhub api calls.
    mlhub_api_session: Session
    # a single datetime, or a (start, end) tuple describing a range.
    temporal_query: Optional[Union[datetime, Tuple[datetime, datetime]]] = None
class AssetRecord(TypedDict):
    """
    A stac_assets db record.

    Keys mirror the columns of the `assets` sqlite table, plus `rowid`.
    """
    rowid: Optional[int]
    asset_key: Optional[str]
    # save path relative to the downloader's asset directory.
    asset_save_path: Optional[str]
    asset_url: str
    # json-encoded bbox of the owning item (None for collection-level assets).
    bbox_json: Optional[str]
    collection_id: Optional[str]
    # True when the asset key is one of COMMON_ASSET_NAMES.
    common_asset: bool
    # The three datetime values are stored exactly as they appear in the
    # stac item properties (strings), kept in TEXT columns, and only parsed
    # when the temporal filter runs -- hence `str`, not `datetime`.
    single_datetime: Optional[str]
    start_datetime: Optional[str]
    end_datetime: Optional[str]
    filtered: bool
    geometry_json: Optional[str]
    # None marks a collection-level asset.
    item_id: Optional[str]
class CatalogDownloader():
    """
    Downloads a dataset's stac catalog and, unless `catalog_only` is set,
    its assets. Calling the instance runs the processing steps in order.
    """
    config: CatalogDownloaderConfig
    # csv writer bound to `err_report`.
    err_writer: Any
    # open handle of the error report csv (one row per failed asset download).
    err_report: TextIOWrapper
    err_report_path: Path
    # path of the downloaded catalog archive (.tar.gz).
    catalog_file: Path
    # `output_dir / dataset_id`: where the catalog is unarchived.
    work_dir: Path
    # root directory for downloaded assets, the sqlite db and the error report.
    asset_dir: Path
    db_conn: sqlite3.Connection
    db_cur: sqlite3.Cursor
def __init__(self, config: CatalogDownloaderConfig):
    """
    Validate the spatial options and prepare the output directories.

    Raises:
        ValueError: if both `bbox` and `intersects` are given, or if
            `intersects` has no `geometry` member.
    """
    if config.bbox is not None and config.intersects is not None:
        raise ValueError('Provide either bbox or intersects option (not both)')
    if config.intersects and 'geometry' not in config.intersects:
        raise ValueError('intersects must be geojson with a geometry property')

    self.config = config
    self.work_dir = config.output_dir / config.dataset_id
    if config.asset_output_dir:
        # assets go to their own root, separate from the catalog.
        self.asset_dir = config.asset_output_dir / config.dataset_id
        self.asset_dir.mkdir(exist_ok=True, parents=True)
    else:
        # assets are saved next to the unarchived catalog.
        self.asset_dir = self.work_dir
    self.work_dir.mkdir(exist_ok=True, parents=True)
    self.err_report_path = self.asset_dir / 'err_report.csv'
    logging.basicConfig(level=logging.INFO)
def _fetch_unfiltered_count(self) -> int:
self.db_cur.execute(
"""
SELECT COUNT(DISTINCT asset_save_path)
FROM assets WHERE filtered = 0
"""
)
(total_count, ) = self.db_cur.fetchone()
return int(total_count)
def _mark_assets_filtered(self, row_ids: Set[int]) -> None:
in_clause = ','.join([str(row_id) for row_id in row_ids])
self.db_cur.execute(
f"""
UPDATE assets
SET filtered = 1
WHERE rowid in ( { in_clause } )
""",
)
def _fetch_catalog_step(self) -> None:
    """
    Download the stac catalog archive for the configured dataset into
    `output_dir` and remember its path in `self.catalog_file`.
    """
    cfg = self.config
    archive_path = cfg.output_dir / f'{cfg.dataset_id}.tar.gz'
    ResumableDownloader(
        session=cfg.mlhub_api_session,
        url=f'/catalog/{cfg.dataset_id}',
        out_file=archive_path,
        if_exists=cfg.if_exists,
        desc=f'{cfg.dataset_id}: fetch stac catalog',
        disable_progress_bar=False,
    ).run()
    # the archive must be on disk once the downloader has run.
    assert archive_path.exists()
    self.catalog_file = archive_path
def _unarchive_catalog_step(self) -> None:
    """
    Extract the catalog archive (.tar.gz) into `output_dir`.

    With `if_exists == overwrite` everything is extracted; in any other
    mode, members whose target path already exists are left untouched.
    """
    dest_dir = self.config.output_dir
    desc = f'unarchive {self.catalog_file.name}'
    log.info('%s ...', desc)
    overwrite = self.config.if_exists == DownloadIfExistsOpts.overwrite
    # NOTE(review): member paths are extracted as-is, without validating
    # that they stay below `dest_dir` -- confirm the archive source is trusted.
    with tarfile.open(self.catalog_file, 'r:gz') as archive:
        if overwrite:
            archive.extractall(path=dest_dir)
        else:
            for member in tqdm(archive.getmembers(), desc=desc):
                if not (dest_dir / member.name).exists():
                    archive.extract(member, path=dest_dir)
    # the archive is expected to unpack a catalog.json into the work dir.
    assert (self.work_dir / 'catalog.json').exists()
def _create_asset_list_step(self) -> None:
    """
    Scan the unarchived stac catalog and insert one row per asset into the
    `assets` table of the sqlite db.

    Every `*.json` below the work dir is read, except files named
    `catalog.json`. A document is handled as a collection when the file is
    named `collection.json` or its `type` is `Collection`; otherwise it is
    handled as an item.

    Raises:
        RuntimeError: if an item has a geometry but no bbox.
    """
    msg = 'create stac asset list (please wait) ...'
    log.info(msg)

    # Extensions recognised by substring search (avoids urlparse on the hot
    # path). A longer extension must come before any shorter one it
    # contains: '.tiff' has to be tested before '.tif', otherwise every
    # '.tiff' url is saved with a '.tif' suffix.
    known_exts = (
        '.tiff', '.tif', '.json', '.pdf', '.png', '.jpg', '.jpeg', '.csv',
    )

    insert_sql = """
        INSERT INTO assets (
            collection_id,
            item_id,
            asset_key,
            asset_url,
            asset_save_path,
            filtered,
            common_asset,
            bbox_json,
            geometry_json,
            single_datetime,
            start_datetime,
            end_datetime
        ) VALUES (
            :collection_id,
            :item_id,
            :asset_key,
            :asset_url,
            :asset_save_path,
            :filtered,
            :common_asset,
            :bbox_json,
            :geometry_json,
            :single_datetime,
            :start_datetime,
            :end_datetime
        );
    """

    # number of rows inserted by this step, across all items/collections.
    inserted = 0

    def _asset_save_path(rec: AssetRecord) -> Path:
        """
        Transform asset into a local save path. This filesystem layout
        is the same as the mlhub's collection archive .tar.gz files.
        """
        url = rec['asset_url']
        ext = next((known for known in known_exts if known in url), None)
        if ext is None:
            # parse the url and extract the path -> file suffix (slow)
            ext = Path(str(urlparse(url).path)).suffix
            assert '.' in ext, 'File extension is not formatted correctly'
        base_path = self.asset_dir / rec['collection_id']  # type: ignore
        asset_filename = f"{rec['asset_key']}{ext}"
        if rec['item_id'] is None:
            # this is a collection level asset
            return base_path / asset_filename
        if rec['common_asset']:
            # common assets: save to _common dir (at the collection level)
            # instead of in every item subdir.
            return base_path / '_common' / asset_filename
        return base_path / rec['item_id'] / asset_filename

    def _insert_asset_rec(rec: AssetRecord) -> None:
        """Insert one record; commit after every 1000 inserted rows."""
        nonlocal inserted
        self.db_cur.execute(insert_sql, rec)
        inserted += 1
        if inserted % 1000 == 0:
            self.db_conn.commit()

    def _store(rec: AssetRecord) -> None:
        """Fill in the save path (relative to the asset dir) and insert."""
        rec['asset_save_path'] = str(
            _asset_save_path(rec).relative_to(self.asset_dir)
        )
        _insert_asset_rec(rec)

    def _handle_item(stac_item: JsonDict) -> None:
        """Insert one row for each asset of a stac item."""
        item_id = stac_item['id']
        assets = stac_item['assets']
        props = stac_item['properties']
        bbox = stac_item.get('bbox', None)
        geometry = stac_item.get('geometry', None)
        if geometry and not bbox:
            raise RuntimeError(f'item {item_id} has no bbox, but has geometry')
        # identical for every asset of the item, so encode only once.
        bbox_json = json.dumps(bbox) if bbox else None
        geometry_json = json.dumps(geometry) if geometry else None
        for k, v in assets.items():
            _store(AssetRecord(
                asset_key=k,
                asset_save_path=None,
                asset_url=v['href'],
                bbox_json=bbox_json,
                collection_id=stac_item['collection'],
                common_asset=k in COMMON_ASSET_NAMES,
                end_datetime=props.get('end_datetime', None),
                filtered=False,
                geometry_json=geometry_json,
                item_id=item_id,
                rowid=None,
                single_datetime=props.get('datetime', None),
                start_datetime=props.get('start_datetime', None),
            ))

    def _handle_collection(stac_collection: JsonDict) -> None:
        """Insert one row for each collection-level asset."""
        collection_id = stac_collection['id']
        # early out if there is a collection_filter but collection_id is not
        # a member of the collection_filter.
        collection_filter = self.config.collection_filter
        if collection_filter and collection_id not in collection_filter:
            log.warning('skipping collection %s', collection_id)
            return
        assets = stac_collection.get('assets', None)
        if assets is None:
            return
        for k, v in assets.items():
            _store(AssetRecord(
                asset_key=k,
                asset_save_path=None,
                asset_url=v['href'],
                bbox_json=None,
                collection_id=collection_id,
                common_asset=False,
                end_datetime=None,
                filtered=False,
                geometry_json=None,
                item_id=None,
                rowid=None,
                single_datetime=None,
                start_datetime=None,
            ))

    # NOTE(review): when no asset_output_dir is configured, assets are saved
    # below the same directory that is globbed here, so `.json` assets left
    # by an earlier run would be read as stac items -- confirm whether that
    # needs guarding.
    for json_src in iglob(str(self.work_dir / '**/*.json'), recursive=True):
        p = Path(json_src)
        if p.name == 'catalog.json':
            continue
        with open(json_src, encoding='utf-8') as json_fh:
            stac_doc = json.load(json_fh)
        stac_type = stac_doc.get('type', None)
        if p.name == 'collection.json' or stac_type == 'Collection':
            _handle_collection(stac_doc)
        else:
            _handle_item(stac_doc)
    log.info('%s unique assets in stac catalog.', self._fetch_unfiltered_count())
def _filter_collections_step(self) -> None:
    """
    Iterate through the filters and mark entries in the assets table as `filtered`.
    Filter is an allow-list. Only matching collection_ids and optionally, asset keys,
    will be included. Only item-level assets (item_id not NULL) are examined.

    Raises:
        RuntimeError: if no unfiltered asset remains afterwards.
    """
    f = self.config.collection_filter
    if f is None:
        return
    desc = 'filter by collection ids and asset keys'
    log.info(desc)
    total_asset_ct = self._fetch_unfiltered_count()
    self.db_cur.execute(
        """
        SELECT rowid, collection_id, asset_key
        FROM assets
        WHERE filtered = 0 AND item_id IS NOT NULL
        """
    )
    row_ids_to_filter: Set[int] = set()
    # the context manager closes the progress bar even if a row fails.
    with tqdm(total=total_asset_ct, desc=desc) as progress:
        while True:
            rows = self.db_cur.fetchmany()
            if not rows:
                break
            # tqdm.update() takes an increment, not a running total.
            progress.update(len(rows))
            for row_id, collection_id, asset_key in rows:
                if collection_id in f:
                    # collection_id is a key in the filter (allow list)
                    filter_asset_keys = f[collection_id]
                    if not filter_asset_keys or asset_key in filter_asset_keys:
                        # no asset keys listed (whole collection allowed),
                        # or this asset's key is listed: keep the row.
                        continue
                row_ids_to_filter.add(row_id)
    self._mark_assets_filtered(row_ids_to_filter)
    total_asset_ct = self._fetch_unfiltered_count()
    if total_asset_ct == 0:
        raise RuntimeError(
            f'after filtering collections_ids and asset keys, zero assets to download. filter: {f}'
        )
    log.info('%s assets after collection filter.', total_asset_ct)
def _filter_bbox_step(self) -> None:
    """
    Filter items by bounding box intersection. Marks items in the assets
    table as `filtered` if they do not intersect. Rows without a bbox are
    logged and left unfiltered.

    Raises:
        RuntimeError: if no unfiltered asset remains afterwards.
    """
    desc = 'filter by bounding box'
    query_bbox = self.config.bbox
    if query_bbox is None:
        return
    bbox_polygon_query = box(*query_bbox)
    log.info(desc)
    total_asset_ct = self._fetch_unfiltered_count()
    self.db_cur.execute(
        """
        SELECT rowid, item_id, bbox_json
        FROM assets
        WHERE filtered = 0 AND item_id IS NOT NULL
        ORDER BY item_id
        """
    )
    row_ids_to_filter: Set[int] = set()
    # the context manager closes the progress bar even if a row fails.
    with tqdm(total=total_asset_ct, desc=desc) as progress:
        while True:
            rows = self.db_cur.fetchmany()
            if not rows:
                break
            # tqdm.update() takes an increment, not a running total.
            progress.update(len(rows))
            # cache the bboxs, which belong to items, not to the asset. the
            # results are ordered by item_id, and the cache is rebuilt for
            # every fetchmany() batch, so its size stays bounded.
            item_bbox_cache: Dict[str, bool] = dict()
            for row_id, item_id, bbox_json in rows:
                if not bbox_json:
                    log.warning('item missing bbox: %s', item_id)
                    continue
                hit = item_bbox_cache.get(item_id, None)
                if hit is None:
                    item_bbox_polygon = box(*json.loads(bbox_json))
                    hit = bbox_polygon_query.intersects(item_bbox_polygon)
                    item_bbox_cache[item_id] = hit
                if not hit:
                    row_ids_to_filter.add(row_id)
    self._mark_assets_filtered(row_ids_to_filter)
    total_asset_ct = self._fetch_unfiltered_count()
    if total_asset_ct == 0:
        raise RuntimeError(
            f'after filtering by bounding box, zero assets to download. filter: {query_bbox}'
        )
    log.info('%s assets after bounding box filter.', total_asset_ct)
def _filter_intersects_step(self) -> None:
    """
    Filter items by geojson vs. bounding box intersection.
    Marks items in the assets table as `filtered` if they do not intersect.
    Rows without a bbox are logged and left unfiltered.

    Raises:
        RuntimeError: if no unfiltered asset remains afterwards.
    """
    f = self.config.intersects
    if f is None:
        return
    desc = 'filter by intersects'
    log.info(desc)
    intersects_shape_query = shape(f['geometry'])
    total_asset_ct = self._fetch_unfiltered_count()
    self.db_cur.execute(
        """
        SELECT rowid, item_id, bbox_json
        FROM assets
        WHERE filtered = 0 AND item_id IS NOT NULL
        ORDER BY item_id
        """
    )
    row_ids_to_filter: Set[int] = set()
    # the context manager closes the progress bar even if a row fails.
    with tqdm(total=total_asset_ct, desc=desc) as progress:
        while True:
            rows = self.db_cur.fetchmany()
            if not rows:
                break
            # tqdm.update() takes an increment, not a running total.
            progress.update(len(rows))
            # cache the spatial join test, which belong to items, not to the
            # asset. the results are ordered by item_id, and the cache is
            # rebuilt for every fetchmany() batch, so its size stays bounded.
            item_intersects_cache: Dict[str, bool] = dict()
            for row_id, item_id, bbox_json in rows:
                if not bbox_json:
                    log.warning('item missing bbox: %s', item_id)
                    continue
                hit = item_intersects_cache.get(item_id, None)
                if hit is None:
                    item_bbox_polygon = box(*json.loads(bbox_json))
                    hit = intersects_shape_query.intersects(item_bbox_polygon)
                    item_intersects_cache[item_id] = hit
                if not hit:
                    row_ids_to_filter.add(row_id)
    self._mark_assets_filtered(row_ids_to_filter)
    total_asset_ct = self._fetch_unfiltered_count()
    if total_asset_ct == 0:
        raise RuntimeError(
            f'after filtering by intersects, zero assets to download. filter: {f}'
        )
    log.info('%s assets after intersects filter.', total_asset_ct)
def _filter_temporal_step(self) -> None:
    """
    Filter items by temporal query. Marks items in the assets table as
    `filtered` if they do not fall in the temporal range or single day.

    Items that have neither a single datetime nor a complete
    start/end range are logged and left unfiltered.

    Raises:
        RuntimeError: if no unfiltered asset remains afterwards.
    """
    q = self.config.temporal_query
    if q is None:
        return
    desc = 'filter by temporal query'
    log.info(desc)
    # the query is either one datetime or a (start, end) tuple.
    is_range_query = isinstance(q, tuple)
    total_asset_ct = self._fetch_unfiltered_count()
    self.db_cur.execute(
        """
        SELECT rowid, item_id, single_datetime, start_datetime, end_datetime
        FROM assets
        WHERE filtered = 0 and item_id IS NOT NULL
        """
    )
    row_ids_to_filter: Set[int] = set()
    # the context manager closes the progress bar even if a row fails.
    with tqdm(total=total_asset_ct, desc=desc) as progress:
        while True:
            rows = self.db_cur.fetchmany()
            if not rows:
                break
            # tqdm.update() takes an increment, not a running total.
            progress.update(len(rows))
            for row_tuple in rows:
                (row_id, item_id, single_datetime, start_datetime, end_datetime) = row_tuple
                if single_datetime:
                    # item has single date property
                    item_datetime = date_parser(single_datetime)
                    if is_range_query:
                        keep = datetime_utils.one_to_range_check(item_datetime, q)
                    else:
                        keep = datetime_utils.one_to_one_check(item_datetime, q)
                elif start_datetime and end_datetime:
                    # item has date range properties
                    item_range = (date_parser(start_datetime), date_parser(end_datetime))
                    if is_range_query:
                        keep = datetime_utils.range_to_range_check(item_range, q)
                    else:
                        keep = datetime_utils.one_to_range_check(q, item_range)
                else:
                    # the stored values must be checked before parsing: the
                    # parser raises on None instead of returning a falsy
                    # value. cannot process date range, just skip forward
                    # and log a warning.
                    log.warning('cannot compare to missing date range for: %s', item_id)
                    continue
                if not keep:
                    row_ids_to_filter.add(row_id)
    self._mark_assets_filtered(row_ids_to_filter)
    total_asset_ct = self._fetch_unfiltered_count()
    if total_asset_ct == 0:
        raise RuntimeError(
            f'after filtering by temporal query, zero assets to download. filter: {q}'
        )
    log.info('%s assets after temporal filter.', total_asset_ct)
def _asset_download_step(self) -> None:
    """
    Download every asset of the assets table that is not marked as filtered,
    once per distinct save path, using a thread pool. Failed downloads are
    written to the error report and logged; they do not stop the pool.
    """

    def _download_asset_worker(
        asset_url: str,
        out_file: Path,
        if_exists: DownloadIfExistsOpts,
    ) -> None:
        """
        Download one asset (runs inside a pool thread). Returns nothing on
        success and raises on error.

        Warning: if the asset url is scheme s3://, it will be transformed to
        https://{bucket_name}.s3.amazonaws.com .
        """
        log.debug(
            '(thread id: %s) %s -> %s',
            threading.get_ident(),
            asset_url,
            out_file,
        )
        target_dir = out_file.parent
        if not target_dir.exists():
            target_dir.mkdir(exist_ok=True, parents=True)
        cleaned_url = asset_url
        if 's3://' in asset_url:
            # workaround for some datasets, e.g. spacenet
            # * use https instead of s3 (avoid adding boto3 dependency)
            # * use https://bucket.s3.amazonaws.com because the region is unknown
            bucket_name = urlparse(asset_url).netloc
            cleaned_url = asset_url.replace(
                f's3://{bucket_name}',
                f'https://{bucket_name}.s3.amazonaws.com'
            )
        ResumableDownloader(
            url=cleaned_url,
            out_file=out_file,
            if_exists=if_exists,
            desc=f'fetch {asset_url}'
        ).run()
        assert out_file.exists(), f'failed to create {out_file}'

    # Collect one record per distinct save path, so that two threads never
    # write to the same file (ex: with _common assets).
    self.db_cur.execute("""
        SELECT asset_save_path, asset_url, collection_id, item_id, asset_key
        FROM assets
        WHERE filtered = 0
    """)
    asset_list = []
    seen_save_paths = set()
    # fetchmany() returns an empty list once the result set is exhausted.
    for batch in iter(self.db_cur.fetchmany, []):
        for save_path, url, collection_id, item_id, asset_key in batch:
            if save_path in seen_save_paths:
                continue
            seen_save_paths.add(save_path)
            asset_list.append(AssetRecord(
                asset_key=asset_key,
                asset_save_path=save_path,
                asset_url=url,
                bbox_json=None,
                collection_id=collection_id,
                common_asset=False,
                end_datetime=None,
                filtered=False,
                geometry_json=None,
                item_id=item_id,
                rowid=None,
                single_datetime=None,
                start_datetime=None,
            ))

    # the db is not needed any more once the work list is in memory.
    self._finalize_db()

    if 'PYTEST_CURRENT_TEST' in os.environ and 'MLHUB_CI' in os.environ:
        # vcr.py does not work multithreading `requests`, so bail out here
        # and consider it a 'dry run'.
        return

    with ThreadPoolExecutor() as executor:
        future_to_asset_record = {}
        for asset in asset_list:
            future = executor.submit(
                _download_asset_worker,
                asset_url=asset['asset_url'],
                out_file=self.asset_dir / asset['asset_save_path'],  # type:ignore
                if_exists=self.config.if_exists,
            )
            future_to_asset_record[future] = asset
        finished = as_completed(future_to_asset_record)
        for future in tqdm(finished, desc='download assets', total=len(asset_list)):
            failed_rec = future_to_asset_record[future]
            try:
                future.result()
            except Exception as e:
                # write a line to err_report in the format:
                # (Error code, Dataset ID, Collection ID, Item ID, Asset Key, Asset URL).
                self.err_writer.writerow([
                    str(e),
                    self.config.dataset_id,
                    failed_rec.get('collection_id'),
                    failed_rec.get('item_id'),
                    failed_rec.get('asset_key'),
                    failed_rec.get('asset_url'),
                ])
                # log with exception info, but keep the pool running.
                log.exception(str(e))
def _init_db(self) -> None:
db_path = self.asset_dir / 'mlhub_stac_assets.db'
if db_path.exists():
db_path.unlink()
self.db_conn = sqlite3.connect(
db_path,
detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES
)
self.db_cur = self.db_conn.cursor()
self.db_cur.arraysize = 1024
self.db_cur.execute("""
CREATE TABLE assets (
collection_id TEXT,
item_id TEXT,
asset_key TEXT,
asset_url TEXT,
asset_save_path TEXT,
filtered BOOLEAN,
common_asset BOOLEAN,
bbox_json TEXT,
geometry_json TEXT,
single_datetime TEXT,
start_datetime TEXT,
end_datetime TEXT
)
""")
def _finalize_db(self) -> None:
if not self.config.catalog_only:
self.db_conn.commit()
self.db_cur.close()
self.db_conn.close()
def __call__(self) -> None:
    """
    Create and run functions for each processing step.

    The catalog is always fetched and unarchived. Unless `catalog_only` is
    set, the asset db is created, the configured filters are applied and
    the remaining assets are downloaded.

    Raises:
        IOError: if any asset download error was written to the error report.
    """
    c = self.config
    # newline='' is what the csv module expects of files handed to a writer.
    self.err_report = open(self.err_report_path, 'w', encoding='utf-8', newline='')
    try:
        self.err_writer = csv.writer(self.err_report, quoting=csv.QUOTE_MINIMAL)
        steps: List[Callable[[], None]] = []
        steps.append(self._fetch_catalog_step)
        steps.append(self._unarchive_catalog_step)
        if not c.catalog_only:
            self._init_db()
            steps.append(self._create_asset_list_step)
            # conditional step for collection/item key filter
            if c.collection_filter:
                steps.append(self._filter_collections_step)
            # conditional step for temporal filter
            if c.temporal_query:
                steps.append(self._filter_temporal_step)
            # conditional step for bounding box spatial filter
            if c.bbox:
                steps.append(self._filter_bbox_step)
            # conditional step for geojson spatial filter
            if c.intersects:
                steps.append(self._filter_intersects_step)
            # create final step for asset downloading
            steps.append(self._asset_download_step)
        # call each step
        for step in steps:
            step()
    finally:
        # close (and thereby flush) the report even when a step raised, so
        # the file handle is never leaked.
        self.err_report.close()
    # inspect the error report: any content means a download failed.
    if os.path.getsize(self.err_report_path) > 0:
        msg = f'asset download error(s) were logged to {self.err_report.name}'
        log.error(msg)
        raise IOError(msg)
    if c.catalog_only:
        log.info('catalog saved to %s', self.work_dir)
    else:
        log.info('assets saved to %s', self.asset_dir)