Source code for skyplane.api.tracker

import functools
import json
import time
from abc import ABC
from datetime import datetime
from threading import Thread

import urllib3
from typing import TYPE_CHECKING, Dict, List, Optional, Set

from skyplane import exceptions
from skyplane.api.config import TransferConfig
from skyplane.chunk import ChunkRequest, ChunkState, Chunk
from skyplane.utils import logger, imports
from skyplane.utils.fn import do_parallel
from skyplane.api.usage import UsageClient
from skyplane.utils.definitions import GB

if TYPE_CHECKING:
    from skyplane.api.transfer_job import TransferJob


class TransferHook(ABC):
    """Hook that shows transfer-related stats"""

    def on_dispatch_start(self):
        """Starting the dispatch job"""
        raise NotImplementedError()

    def on_chunk_dispatched(self, chunks: List[Chunk]):
        """Dispatching data chunks to transfer"""
        raise NotImplementedError()

    def on_dispatch_end(self):
        """Ending the dispatch job"""
        raise NotImplementedError()

    def on_chunk_completed(self, chunks: List[Chunk]):
        """Chunks are all transferred"""
        raise NotImplementedError()

    def on_transfer_end(self, transfer_stats):
        """Ending the transfer job"""
        raise NotImplementedError()

    def on_transfer_error(self, error):
        """Showing the transfer error if it fails"""
        raise NotImplementedError()
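

# A minimal sketch of a concrete hook, assuming only the TransferHook interface
# above; this class and its print-based reporting are illustrative, not part of
# Skyplane's API.
class PrintingTransferHook(TransferHook):
    """Example hook that reports transfer progress to stdout."""

    def on_dispatch_start(self):
        print("dispatch started")

    def on_chunk_dispatched(self, chunks: List[Chunk]):
        print(f"dispatched {len(chunks)} chunk(s)")

    def on_dispatch_end(self):
        print("dispatch finished")

    def on_chunk_completed(self, chunks: List[Chunk]):
        print(f"completed {len(chunks)} chunk(s)")

    def on_transfer_end(self, transfer_stats):
        print(f"transfer finished: {transfer_stats}")

    def on_transfer_error(self, error):
        print(f"transfer failed: {error}")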


class EmptyTransferHook(TransferHook):
    """Empty transfer hook that does nothing"""

    def __init__(self):
        return

    def on_dispatch_start(self):
        return

    def on_chunk_dispatched(self, chunks: List[Chunk]):
        return

    def on_dispatch_end(self):
        return

    def on_chunk_completed(self, chunks: List[Chunk]):
        return

    def on_transfer_end(self, transfer_stats):
        return

    def on_transfer_error(self, error):
        return


class TransferProgressTracker(Thread):
    """Tracks transfer progress in one transfer session"""

    def __init__(self, dataplane, jobs: List["TransferJob"], transfer_config: TransferConfig, hooks: TransferHook):
        """
        :param dataplane: dataplane that starts the transfer
        :type dataplane: Dataplane
        :param jobs: list of transfer jobs launched in parallel
        :type jobs: List
        :param transfer_config: the configuration during the transfer
        :type transfer_config: TransferConfig
        :param hooks: the hook that shows transfer related stats
        :type hooks: TransferHook
        """
        super().__init__()
        self.dataplane = dataplane
        self.jobs = {job.uuid: job for job in jobs}
        self.transfer_config = transfer_config
        if hooks is None:
            self.hooks = EmptyTransferHook()
        else:
            self.hooks = hooks

        # log job details
        logger.fs.debug(f"[TransferProgressTracker] Using dataplane {dataplane}")
        logger.fs.debug(f"[TransferProgressTracker] Initialized with {len(jobs)} jobs:")
        for job_uuid, job in self.jobs.items():
            logger.fs.debug(f"[TransferProgressTracker] * {job_uuid}: {job}")
        logger.fs.debug(f"[TransferProgressTracker] Transfer config: {transfer_config}")

        # transfer state
        self.job_chunk_requests: Dict[str, Dict[str, ChunkRequest]] = {}
        self.job_pending_chunk_ids: Dict[str, Set[str]] = {}
        self.job_complete_chunk_ids: Dict[str, Set[str]] = {}
        self.errors: Optional[Dict[str, List[str]]] = None

        # http_pool
        self.http_pool = urllib3.PoolManager(retries=urllib3.Retry(total=3))

    def __str__(self):
        return f"TransferProgressTracker({self.dataplane}, {self.jobs})"
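
    # A minimal usage sketch (illustrative only): the tracker is a Thread, so a
    # caller with an already-provisioned dataplane and a list of transfer jobs
    # (both hypothetical names here) would start it and wait for completion:
    #
    #     tracker = TransferProgressTracker(dataplane, [job], transfer_config, hooks=None)
    #     tracker.start()   # invokes run() below in a background thread
    #     tracker.join()    # blocks until dispatch, monitoring, finalize, and verify finish
    #
    # Passing hooks=None falls back to EmptyTransferHook, so all progress
    # callbacks become no-ops.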

    def run(self):
        """Dispatch and start the transfer jobs"""
        src_cloud_provider = self.dataplane.src_region_tag.split(":")[0]
        dst_cloud_provider = self.dataplane.dst_region_tag.split(":")[0]
        args = {
            "cmd": ",".join([job.__class__.__name__ for job in self.jobs.values()]),
            "recursive": ",".join([str(job.recursive) for job in self.jobs.values()]),
            "multipart": self.transfer_config.multipart_enabled,
            "instances_per_region": self.dataplane.max_instances,
            "src_instance_type": getattr(self.transfer_config, f"{src_cloud_provider}_instance_class"),
            "dst_instance_type": getattr(self.transfer_config, f"{dst_cloud_provider}_instance_class"),
            "src_spot_instance": getattr(self.transfer_config, f"{src_cloud_provider}_use_spot_instances"),
            "dst_spot_instance": getattr(self.transfer_config, f"{dst_cloud_provider}_use_spot_instances"),
        }
        session_start_timestamp_ms = int(time.time() * 1000)
        try:
            # pre-dispatch chunks to begin pre-buffering chunks
            cr_streams = {
                job_uuid: job.dispatch(self.dataplane, transfer_config=self.transfer_config)
                for job_uuid, job in self.jobs.items()
            }
            for job_uuid, job in self.jobs.items():
                logger.fs.debug(f"[TransferProgressTracker] Dispatching job {job.uuid}")
                self.job_chunk_requests[job_uuid] = {}
                self.job_pending_chunk_ids[job_uuid] = set()
                self.job_complete_chunk_ids[job_uuid] = set()
                for cr in cr_streams[job_uuid]:
                    chunks_dispatched = [cr.chunk]
                    self.job_chunk_requests[job_uuid][cr.chunk.chunk_id] = cr
                    self.job_pending_chunk_ids[job_uuid].add(cr.chunk.chunk_id)
                    self.hooks.on_chunk_dispatched(chunks_dispatched)
                logger.fs.debug(
                    f"[TransferProgressTracker] Job {job.uuid} dispatched with {len(self.job_chunk_requests[job_uuid])} chunk requests"
                )
        except Exception as e:
            UsageClient.log_exception(
                "dispatch job", e, args, self.dataplane.src_region_tag, self.dataplane.dst_region_tag, session_start_timestamp_ms
            )
            raise e

        self.hooks.on_dispatch_end()

        # Record only the transfer time
        start_time = int(time.time())
        try:
            self.monitor_transfer()
        except exceptions.SkyplaneGatewayException as err:
            reformat_err = Exception(err.pretty_print_str()[37:])
            UsageClient.log_exception(
                "monitor transfer",
                reformat_err,
                args,
                self.dataplane.src_region_tag,
                self.dataplane.dst_region_tag,
                session_start_timestamp_ms,
            )
            raise err
        except Exception as e:
            UsageClient.log_exception(
                "monitor transfer", e, args, self.dataplane.src_region_tag, self.dataplane.dst_region_tag, session_start_timestamp_ms
            )
            raise e
        end_time = int(time.time())

        try:
            for job in self.jobs.values():
                logger.fs.debug(f"[TransferProgressTracker] Finalizing job {job.uuid}")
                job.finalize()
        except Exception as e:
            UsageClient.log_exception(
                "finalize job", e, args, self.dataplane.src_region_tag, self.dataplane.dst_region_tag, session_start_timestamp_ms
            )
            raise e

        try:
            for job in self.jobs.values():
                logger.fs.debug(f"[TransferProgressTracker] Verifying job {job.uuid}")
                job.verify()
        except Exception as e:
            UsageClient.log_exception(
                "verify job", e, args, self.dataplane.src_region_tag, self.dataplane.dst_region_tag, session_start_timestamp_ms
            )
            raise e

        # transfer successfully completed
        transfer_stats = {
            "total_runtime_s": end_time - start_time,
            "throughput_gbits": self.query_bytes_dispatched() / (end_time - start_time) / GB * 8,
        }
        self.hooks.on_transfer_end(transfer_stats)
        UsageClient.log_transfer(
            transfer_stats, args, self.dataplane.src_region_tag, self.dataplane.dst_region_tag, session_start_timestamp_ms
        )
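
    # Worked example of the throughput calculation above (a sketch; the exact
    # value of GB comes from skyplane.utils.definitions): transferring 64 GB of
    # completed bytes in 128 s gives
    #
    #     throughput_gbits = (64 * GB) / 128 / GB * 8 = 4.0
    #
    # i.e. bytes are first converted to gigabytes per second, then multiplied
    # by 8 to report gigabits per second.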

    @imports.inject("pandas")
    def monitor_transfer(pd, self):
        """Monitor the transfer by copying remote gateway logs and showing transfer stats via hooks"""
        # todo implement transfer monitoring to update job_complete_chunk_ids and job_pending_chunk_ids while the transfer is in progress
        sinks = self.dataplane.topology.sink_instances()
        sink_regions = set([sink.region for sink in sinks])
        while any([len(self.job_pending_chunk_ids[job_uuid]) > 0 for job_uuid in self.job_pending_chunk_ids]):
            # refresh shutdown status by running noop
            do_parallel(lambda i: i.run_command("echo 1"), self.dataplane.bound_nodes.values(), n=-1)

            # check for errors and exit if there are any (while setting debug flags)
            errors = self.dataplane.check_error_logs()
            if any(errors.values()):
                logger.warning("Copying gateway logs...")
                self.dataplane.copy_logs()
                self.errors = errors
                raise exceptions.SkyplaneGatewayException("Transfer failed with errors", errors)

            log_df = pd.DataFrame(self._query_chunk_status())
            if log_df.empty:
                logger.warning("No chunk status log entries yet")
                time.sleep(0.05)
                continue

            is_complete_rec = (
                lambda row: row["state"] == ChunkState.upload_complete
                and row["instance"] in [s.instance for s in sinks]
                and row["region"] in [s.region for s in sinks]
            )
            sink_status_df = log_df[log_df.apply(is_complete_rec, axis=1)]
            completed_status = sink_status_df.groupby("chunk_id").apply(
                lambda x: set(x["region"].unique()) == set(sink_regions)
            )
            completed_chunk_ids = completed_status[completed_status].index

            # update job_complete_chunk_ids and job_pending_chunk_ids
            for job_uuid, job in self.jobs.items():
                job_complete_chunk_ids = set(
                    chunk_id for chunk_id in completed_chunk_ids if self._chunk_to_job_map[chunk_id] == job_uuid
                )
                new_chunk_ids = (
                    self.job_complete_chunk_ids[job_uuid].union(job_complete_chunk_ids).difference(self.job_complete_chunk_ids[job_uuid])
                )
                completed_chunks = []
                for id in new_chunk_ids:
                    completed_chunks.append(self.job_chunk_requests[job_uuid][id].chunk)
                self.hooks.on_chunk_completed(completed_chunks)
                self.job_complete_chunk_ids[job_uuid] = self.job_complete_chunk_ids[job_uuid].union(job_complete_chunk_ids)
                self.job_pending_chunk_ids[job_uuid] = self.job_pending_chunk_ids[job_uuid].difference(job_complete_chunk_ids)

            # sleep
            time.sleep(0.05)
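
    # A self-contained sketch of the completion check above, with made-up data:
    # a chunk counts as complete only once every sink region has logged
    # upload_complete for it.
    #
    #     import pandas as pd
    #     sink_regions = {"aws:us-east-1", "gcp:us-central1-a"}
    #     df = pd.DataFrame(
    #         [
    #             {"chunk_id": "c0", "region": "aws:us-east-1"},
    #             {"chunk_id": "c0", "region": "gcp:us-central1-a"},
    #             {"chunk_id": "c1", "region": "aws:us-east-1"},
    #         ]
    #     )
    #     done = df.groupby("chunk_id").apply(lambda x: set(x["region"].unique()) == sink_regions)
    #     list(done[done].index)  # -> ["c0"]; "c1" is still missing a sink region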

    @property
    @functools.lru_cache(maxsize=1)
    def _chunk_to_job_map(self):
        return {chunk_id: job_uuid for job_uuid, cr_dict in self.job_chunk_requests.items() for chunk_id in cr_dict.keys()}

    def _query_chunk_status(self):
        def get_chunk_status(args):
            node, instance = args
            reply = self.http_pool.request("GET", f"{instance.gateway_api_url}/api/v1/chunk_status_log")
            if reply.status != 200:
                raise Exception(
                    f"Failed to get chunk status from gateway instance {instance.instance_name()}: {reply.data.decode('utf-8')}"
                )
            logs = []
            for log_entry in json.loads(reply.data.decode("utf-8"))["chunk_status_log"]:
                log_entry["region"] = node.region
                log_entry["instance"] = node.instance
                log_entry["time"] = datetime.fromisoformat(log_entry["time"])
                log_entry["state"] = ChunkState.from_str(log_entry["state"])
                logs.append(log_entry)
            return logs

        rows = []
        for result in do_parallel(get_chunk_status, self.dataplane.bound_nodes.items(), n=-1, return_args=False):
            rows.extend(result)
        return rows

    @property
    def is_complete(self):
        """Return whether the transfer is complete"""
        return all([len(self.job_pending_chunk_ids[job_uuid]) == 0 for job_uuid in self.jobs.keys()])
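
    # Shape of a row returned by _query_chunk_status(): each entry is a dict
    # parsed from a gateway's /api/v1/chunk_status_log response and annotated
    # with the reporting node. The field values below are illustrative, but the
    # keys follow from the parsing code above and the groupby in
    # monitor_transfer():
    #
    #     {"chunk_id": "c0", "state": ChunkState.upload_complete,
    #      "time": datetime(2023, 1, 1, 0, 0, 0), "region": "aws:us-east-1", "instance": 0}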

    def query_bytes_remaining(self):
        """Query the total number of bytes remaining in all the transfer jobs"""
        if len(self.job_chunk_requests) == 0:
            return None
        bytes_remaining_per_job = {}
        for job_uuid in self.job_pending_chunk_ids.keys():
            bytes_remaining_per_job[job_uuid] = sum(
                [
                    cr.chunk.chunk_length_bytes
                    for cr in self.job_chunk_requests[job_uuid].values()
                    if cr.chunk.chunk_id in self.job_pending_chunk_ids[job_uuid]
                ]
            )
        logger.fs.debug(f"[TransferProgressTracker] Bytes remaining per job: {bytes_remaining_per_job}")
        return sum(bytes_remaining_per_job.values())

    def query_bytes_dispatched(self):
        """Query the total number of bytes dispatched to chunks ready for transfer"""
        if len(self.job_chunk_requests) == 0:
            return 0
        bytes_total_per_job = {}
        for job_uuid in self.job_complete_chunk_ids.keys():
            bytes_total_per_job[job_uuid] = sum(
                [
                    cr.chunk.chunk_length_bytes
                    for cr in self.job_chunk_requests[job_uuid].values()
                    if cr.chunk.chunk_id in self.job_complete_chunk_ids[job_uuid]
                ]
            )
        return sum(bytes_total_per_job.values())
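

# A minimal progress-reporting sketch built on the query helpers above. The
# polling interval and print-based output are illustrative choices, not part of
# Skyplane's API; `tracker` is assumed to be an already-started
# TransferProgressTracker.
def _example_report_progress(tracker: TransferProgressTracker, poll_interval_s: float = 1.0):
    """Print bytes done/remaining until the tracker reports completion."""
    while not tracker.is_complete:
        done = tracker.query_bytes_dispatched()  # sums chunk sizes over completed chunk ids
        remaining = tracker.query_bytes_remaining()  # None until chunks have been dispatched
        if remaining is not None:
            print(f"{done / GB:.2f} GB done, {remaining / GB:.2f} GB remaining")
        time.sleep(poll_interval_s)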