diff --git a/biobricks/brick.py b/biobricks/brick.py
index fe5d318..0db5424 100644
--- a/biobricks/brick.py
+++ b/biobricks/brick.py
@@ -5,7 +5,7 @@ from .config import bblib, token
 from .logger import logger
 import os, urllib.request as request, functools, shutil, yaml
-from .downloader import download_out
+from .dvc_fetcher import DVCFetcher
 from urllib.parse import urlparse
 import sys
 from .checks import check_url_available, check_token, check_safe_git_repo
@@ -114,7 +114,12 @@ def path(self):
     def _relpath(self):
         "get the path to this brick relative to bblib"
         return self.urlpath() / self.commit
-    
+
+    def get_dvc_lock(self):
+        """get the dvc.lock file for this brick"""
+        with open(self.path() / "dvc.lock") as f:
+            return yaml.safe_load(f)
+
     def install(self):
         "install this brick"
         logger.info(f"running checks on brick")
@@ -132,18 +137,7 @@ def install(self):
         cmd(f"git clone {self.remote} {self._relpath()}", cwd = bblib())
         cmd(f"git checkout {self.commit}", cwd = self.path())
 
-        outs = []
-        with open(self.path() / "dvc.lock") as f:
-            dvc_lock = yaml.safe_load(f)
-            stages = [stage for stage in dvc_lock.get('stages', []).values()]
-            outs = [out for stage in stages for out in stage.get('outs', [])]
-
-        brick_outs = [out for out in outs if out.get('path').startswith('brick')]
-        for out in brick_outs:
-            md5 = out.get('md5')
-            relpath = out.get('path')
-            dest_path = self.path() / relpath
-            download_out(md5, dest_path)
+        DVCFetcher().fetch_outs(self)
 
         logger.info(f"\033[94m{self.url()}\033[0m succesfully downloaded to BioBricks library.")
         return self
diff --git a/biobricks/downloader.py b/biobricks/downloader.py
index c39cb5a..6842ce4 100644
--- a/biobricks/downloader.py
+++ b/biobricks/downloader.py
@@ -7,50 +7,65 @@ from pathlib import Path
 from tqdm import tqdm # Import tqdm for the progress bar
 
-def _download_outdir(url, dest_path: Path):
-    with requests.get(url, headers={'BBToken': token()}, stream=True) as r:
-        r.raise_for_status()
-        for o in r.json():
-            download_out(o['md5'], dest_path / o['relpath'])
-
-def _download_outfile(url, path: Path, bytes=None):
-
-    with requests.get(url, headers={'BBToken': token()}, stream=True) as r:
-        r.raise_for_status()
-        total_size = bytes if bytes else int(r.headers.get('content-length', 0))
-        progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True, desc=f"Downloading file")
-        with open(path, 'wb') as f:
-            for chunk in r.iter_content(chunk_size=8192):
-                f.write(chunk)
-                progress_bar.update(len(chunk))
-        progress_bar.close()
-        if total_size != 0 and progress_bar.n != total_size:
-            logger.error("ERROR, something went wrong")
-
-
-def download_out(md5, dest: Path, url_prefix="https://dvc.biobricks.ai/files/md5/", bytes=None):
-
-    # make parent directories
-    dest.parent.mkdir(parents=True, exist_ok=True)
-    cache_path = bblib() / 'cache' / md5[:2] / md5[2:]
-    cache_path.parent.mkdir(parents=True, exist_ok=True)
-
-    remote_url = url_prefix + md5[:2] + "/" + md5[2:]
-
-    if md5.endswith('.dir'):
-        logger.info(f"downloading directory {remote_url} to {dest}")
-        return _download_outdir(remote_url, dest)
-
-    if not cache_path.exists():
-        logger.info(f"downloading file {remote_url} to {cache_path}")
-        _download_outfile(remote_url, cache_path, bytes)
+from dataclasses import dataclass
+
+@dataclass
+class Downloader:
+    remote_url_prefix: str = 'https://dvc.biobricks.ai/files/md5/'
+
+    def _md5_to_remote_url(self, md5):
+        return self.remote_url_prefix + md5[:2] + "/" + md5[2:]
+
+    def _download_outdir(self, url, dest_path: Path):
+        with requests.get(url, headers={'BBToken': token()}, stream=True) as r:
+            r.raise_for_status()
+            for o in r.json():
+                self.download_out(o['md5'], dest_path / o['relpath'])
+
+    def _download_outfile(self, url, path: Path, bytes=None):
+        with requests.get(url, headers={'BBToken': token()}, stream=True) as r:
+            r.raise_for_status()
+            total_size = bytes if bytes else int(r.headers.get('content-length', 0))
+            progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True, desc="Downloading file")
+            with open(path, 'wb') as f:
+                for chunk in r.iter_content(chunk_size=8192):
+                    f.write(chunk)
+                    progress_bar.update(len(chunk))
+            progress_bar.close()
+            if total_size != 0 and progress_bar.n != total_size:
+                logger.error("ERROR, something went wrong")
 
-    dest.unlink(missing_ok=True) # remove the symlink if it exists
-    if not biobricks.checks.can_symlink():
-        logger.warning(f"you are not able to make symlinks cache-files will be copied to bricks. This is an inefficient use of disk space.")
-        shutil.copy(cache_path, dest)
-    else:
-        os.symlink(cache_path, dest)
+
+    def download_out(self, md5, dest: Path, bytes=None):
-
+        # make parent directories
+        dest.parent.mkdir(parents=True, exist_ok=True)
+        cache_path = bblib() / 'cache' / md5[:2] / md5[2:]
+        cache_path.parent.mkdir(parents=True, exist_ok=True)
+
+        remote_url = self._md5_to_remote_url(md5)
+
+        if md5.endswith('.dir'):
+            logger.info(f"downloading directory {remote_url} to {dest}")
+            return self._download_outdir(remote_url, dest)
+
+        if not cache_path.exists():
+            logger.info(f"downloading file {remote_url} to {cache_path}")
+            self._download_outfile(remote_url, cache_path, bytes)
+
+        dest.unlink(missing_ok=True) # remove the symlink if it exists
+        if not biobricks.checks.can_symlink():
+            logger.warning("unable to create symlinks; cache files will be copied into bricks, which is an inefficient use of disk space.")
+            shutil.copy(cache_path, dest)
+        else:
+            os.symlink(cache_path, dest)
+
+    def download_by_prefix(self, outs, prefix, path):
+        brick_outs = [out for out in outs if out.get('path').startswith(prefix)]
+        for out in brick_outs:
+            md5 = out.get('md5')
+            relpath = out.get('path')
+            dest_path = path / relpath
+            self.download_out(md5, dest_path)
diff --git a/biobricks/dvc_fetcher.py b/biobricks/dvc_fetcher.py
new file mode 100644
index 0000000..6e0bb5d
--- /dev/null
+++ b/biobricks/dvc_fetcher.py
@@ -0,0 +1,187 @@
+import biobricks.checks
+import biobricks.config
+from biobricks.logger import logger
+
+
+from dataclasses import dataclass, field
+import requests, threading, shutil, os
+import signal
+from tqdm import tqdm
+from pathlib import Path
+
+def signal_handler(signum, frame, interrupt_event):
+    interrupt_event.set()
+    logger.info("Interrupt signal received. Attempting to terminate downloads gracefully...")
+
+class PositionManager:
+    def __init__(self):
+        self.available_positions = []
+        self.lock = threading.Lock()
+        self.max_position = 0
+
+    def get_position(self):
+        with self.lock:
+            if self.available_positions:
+                return self.available_positions.pop(0)
+            else:
+                self.max_position += 1
+                return self.max_position
+
+    def release_position(self, position):
+        with self.lock:
+            self.available_positions.append(position)
+            self.available_positions.sort()
+
+class DownloadThread(threading.Thread):
+
+    def __init__(self, url, total_progress_bar, path, headers, position_manager, semaphore, interrupt_event):
+        super(DownloadThread, self).__init__()
+        self.url = url
+        self.total_progress_bar = total_progress_bar
+        self.path = path
+        self.headers = headers
+        self.position_manager = position_manager
+        self.semaphore = semaphore
+        self.interrupt_event = interrupt_event
+
+    def run(self):
+        position = self.position_manager.get_position()
+        self.path.parent.mkdir(parents=True, exist_ok=True)
+        try:
+            response = requests.get(self.url, stream=True, headers=self.headers)
+            response.raise_for_status()
+            total_size = int(response.headers.get('content-length', 0))
+            block_size = 1024
+
+            with tqdm(total=total_size, unit='iB', unit_scale=True, disable=False, desc=str(self.path), position=position, leave=False) as progress:
+                with open(self.path, 'wb') as file:
+                    for data in response.iter_content(chunk_size=block_size):
+                        if self.interrupt_event.is_set(): # Check if the thread should stop
+                            logger.info(f"Stopping download of {self.url}")
+                            return # Exit the thread gracefully
+                        if data:
+                            file.write(data)
+                            progress.update(len(data))
+                            self.total_progress_bar.update(len(data))
+        finally:
+            self.semaphore.release() # Release the semaphore when the thread is done
+            self.position_manager.release_position(position)
+
+@dataclass
+class DownloadManager:
+    skip_existing: bool = False
+    progress_bar: tqdm = None
+    active_threads: int = 0
+    # default_factory gives each manager its own Event; a bare threading.Event()
+    # default would be shared across every DownloadManager instance
+    interrupt_event: threading.Event = field(default_factory=threading.Event)
+
+    def download_files(self, urls, paths, total_size, max_threads=4):
+        signal.signal(signal.SIGINT, lambda signum, frame: signal_handler(signum, frame, self.interrupt_event))
+
+        self.progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True, position=0, desc="Overall Progress")
+        position_manager = PositionManager()
+        semaphore = threading.Semaphore(max_threads)
+        threads = []
+
+        for url, path in zip(urls, paths):
+            semaphore.acquire() # Block until a semaphore permit is available
+            if self.interrupt_event.is_set():
+                logger.info("Download process interrupted. Waiting for ongoing downloads to complete...")
+                semaphore.release()
+                break
+            thread = DownloadThread(url, self.progress_bar, path, {'BBToken': biobricks.config.token()}, position_manager, semaphore, self.interrupt_event)
+            threads.append(thread)
+            thread.start()
+
+        for thread in threads:
+            thread.join()
+
+        self.progress_bar.close()
+        print(f"\n{len(paths)} files downloaded successfully!")
+
+
+class DVCFetcher:
+
+    def __init__(self, remote_url_prefix: str = 'https://dvc.biobricks.ai/files/md5/'):
+        self.cache_path = biobricks.config.bblib() / "cache"
+        self.remote_url_prefix: str = remote_url_prefix
+
+    def _remote_url_to_cache_path(self, remote_url):
+        return self.cache_path / remote_url.split('/')[-2] / remote_url.split('/')[-1]
+
+    def _md5_to_remote_url(self, md5):
+        return self.remote_url_prefix + md5[:2] + "/" + md5[2:]
+
+    # TODO: this needs a better error-handling strategy. What if the internet goes out? What if it's a completely wrong file?
+    def _expand_outdir(self, remote_url, path: Path) -> list[dict]:
+        """Return a list of {'md5': ..., 'path': ...} dicts for a given directory-out; skip it on error."""
+        try:
+            with requests.get(remote_url, headers={'BBToken': biobricks.config.token()}, stream=True) as r:
+                r.raise_for_status()
+                return [{'md5': o['md5'], 'path': path / o['relpath']} for o in r.json()]
+        except requests.exceptions.HTTPError as e:
+            logger.warning(f"Failed to fetch {remote_url}: {e}") # Log the error
+            return [] # Return an empty list to skip this directory-out
+
+    def _find_all_dirouts(self, dir_outs=None) -> tuple[list, list]:
+        dir_outs = list(dir_outs or []) # copy to avoid mutating the caller's list or a shared default
+        urls, paths = [], []
+
+        while dir_outs:
+            current_dir_out = dir_outs.pop()
+            current_dir_path = Path(current_dir_out['path'])
+            expanded_outs = self._expand_outdir(self._md5_to_remote_url(current_dir_out['md5']), current_dir_path)
+
+            # Separate file and directory outputs
+            file_outs = [out for out in expanded_outs if not out['md5'].endswith('.dir')]
+            more_dir_outs = [out for out in expanded_outs if out['md5'].endswith('.dir')]
+
+            # Update lists with file outputs
+            urls.extend(self._md5_to_remote_url(out['md5']) for out in file_outs)
+            paths.extend(out['path'] for out in file_outs)
+
+            # Add new directory outputs to be processed
+            dir_outs.extend(more_dir_outs)
+
+        return urls, paths
+
+    def _link_cache_to_brick(self, cache_path, brick_path):
+        "create a symlink from cache_path to brick_path; copy the file if symlinks are not supported."
+        if not cache_path.exists():
+            logger.warning(f"cache file {cache_path} does not exist")
+            return
+
+        brick_path.unlink(missing_ok=True)
+        brick_path.parent.mkdir(parents=True, exist_ok=True)
+
+        if not biobricks.checks.can_symlink():
+            logger.warning("unable to create symlinks; cache files will be copied into bricks, which is an inefficient use of disk space.")
+            shutil.copy(cache_path, brick_path)
+        else:
+            os.symlink(cache_path, brick_path)
+
+    def fetch_outs(self, brick, prefixes=('brick/', 'data/')) -> tuple[list, list, int]:
+        dvc_lock = brick.get_dvc_lock()
+        stages = [stage for stage in dvc_lock.get('stages', {}).values()]
+        all_outs = [out for stage in stages for out in stage.get('outs', [])]
+
+        has_prefix = lambda x: any(x.get('path').startswith(prefix) for prefix in prefixes)
+        outs = [o for o in all_outs if has_prefix(o)]
+        total_size = sum(o.get('size', 0) for o in outs)
+
+        dir_outs = [o for o in outs if o.get('md5').endswith('.dir')]
+        dir_urls, dir_paths = self._find_all_dirouts(dir_outs)
+
+        file_outs = [o for o in outs if not o.get('md5').endswith('.dir')]
+        urls = dir_urls + [self._md5_to_remote_url(o.get('md5')) for o in file_outs]
+        paths = dir_paths + [o.get('path') for o in file_outs]
+
+        # download files
+        cache_paths = [self._remote_url_to_cache_path(url) for url in urls]
+        downloader = DownloadManager()
+        downloader.download_files(urls, cache_paths, total_size)
+
+        # build a symlink between each cache_path and its corresponding path
+        brick_paths = [brick.path() / path for path in paths]
+        for cache_path, brick_path in zip(cache_paths, brick_paths):
+            self._link_cache_to_brick(cache_path, brick_path)
+
+        return urls, paths, total_size
diff --git a/tests/test_bricks.py b/tests/test_bricks.py
index f9b7656..deeef41 100644
--- a/tests/test_bricks.py
+++ b/tests/test_bricks.py
@@ -2,20 +2,37 @@ from unittest.mock import patch
 from pathlib import Path
 from biobricks import Brick
-from biobricks.config import write_config, init_bblib
+from biobricks.config import write_config, init_bblib, token
 import tempfile
 import pandas as pd
 import sqlite3
 import os
 
 class BrickTests(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.token = os.environ.get("BIOBRICKS_TEST_TOKEN") or token()
+
+        # Create a temporary directory for the whole class
+        cls.class_temp_dir = tempfile.TemporaryDirectory()
+        cls.temp_biobricks_config_path = Path(cls.class_temp_dir.name) / "biobricks_config_temp.json"
+
+        # Patch the biobricks.config.biobricks_config_path static method
+        cls.patcher = patch('biobricks.config.biobricks_config_path', return_value=cls.temp_biobricks_config_path)
+        cls.mock_biobricks_config_path = cls.patcher.start()
+
+    @classmethod
+    def tearDownClass(cls):
+        # Stop the patch after all tests in the class
+        cls.patcher.stop()
+        # Clean up the temporary directory for the class
+        cls.class_temp_dir.cleanup()
+
     def setUp(self):
         self.tempdir = tempfile.TemporaryDirectory()
         bblib = Path(f"{self.tempdir.name}/biobricks")
         bblib.mkdir(exist_ok=True,parents=True)
-        token = os.environ.get("BIOBRICKS_TEST_TOKEN")
-        config = { "BBLIB": f"{bblib}", "TOKEN": token }
+        config = { "BBLIB": f"{bblib}", "TOKEN": BrickTests.token }
         write_config(config)
         init_bblib()
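
A minimal sketch of how the refactored pieces are expected to compose after this patch, assuming a configured BBLIB and token; the `brick` object itself would come from Brick's existing constructors, which this diff does not touch, so it is shown as resolved elsewhere:

    from biobricks.dvc_fetcher import DVCFetcher

    # fetch every dvc.lock out with a 'brick/' or 'data/' prefix into the
    # shared cache, then symlink (or copy) each file into the brick tree
    fetcher = DVCFetcher()  # defaults to https://dvc.biobricks.ai/files/md5/
    urls, paths, total_size = fetcher.fetch_outs(brick)  # `brick` resolved elsewhere
    print(f"{len(paths)} outs ({total_size} bytes) linked under {brick.path()}")

This is the same call Brick.install() now makes internally via DVCFetcher().fetch_outs(self), so installing a brick exercises the whole path: dvc.lock parsing, directory-out expansion, threaded cache downloads, and cache-to-brick linking.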