Source code for skyplane.api.tracker

import json
import time
from abc import ABC
from datetime import datetime
from threading import Thread
from functools import partial

import urllib3
from typing import TYPE_CHECKING, Dict, List, Optional, Set

from concurrent.futures import ThreadPoolExecutor, as_completed

from skyplane import exceptions
from skyplane.api.config import TransferConfig
from skyplane.chunk import ChunkState, Chunk
from skyplane.utils import logger, imports
from skyplane.utils.fn import do_parallel
from skyplane.api.usage import UsageClient
from skyplane.utils.definitions import GB
from skyplane.utils.retry import retry_backoff

from skyplane.cli.impl.common import print_stats_completed

if TYPE_CHECKING:
    from skyplane.api.transfer_job import TransferJob


class TransferHook(ABC):
    """Hook that shows transfer related stats"""

    def on_dispatch_start(self):
        """Starting the dispatch job"""
        raise NotImplementedError()

    def on_chunk_dispatched(self, chunks: List[Chunk]):
        """Dispatching data chunks to transfer"""
        raise NotImplementedError()

    def on_dispatch_end(self):
        """Ending the dispatch job"""
        raise NotImplementedError()

    def on_chunk_completed(self, chunks: List[Chunk], region_tag: Optional[str] = None):
        """Chunks are all transferred"""
        raise NotImplementedError()

    def on_transfer_end(self, transfer_stats):
        """Ending the transfer job"""
        raise NotImplementedError()

    def on_transfer_error(self, error):
        """Showing the transfer error if it fails"""
        raise NotImplementedError()
class EmptyTransferHook(TransferHook):
    """Empty transfer hook that does nothing"""

    def __init__(self):
        return

    def on_dispatch_start(self):
        return

    def on_chunk_dispatched(self, chunks: List[Chunk]):
        return

    def on_dispatch_end(self):
        return

    def on_chunk_completed(self, chunks: List[Chunk], region_tag: Optional[str] = None):
        return

    def on_transfer_end(self, transfer_stats):
        return

    def on_transfer_error(self, error):
        return
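
A minimal sketch (not part of this module) of how a caller might subclass TransferHook to receive progress callbacks. The LoggingTransferHook name and its print-based reporting are hypothetical; the sketch relies on the imports at the top of this module and on Chunk.chunk_length_bytes, which the tracker itself uses below.

class LoggingTransferHook(TransferHook):
    """Hypothetical hook sketch that tallies dispatched and completed bytes."""

    def __init__(self):
        self.dispatched_bytes = 0
        self.completed_bytes = 0

    def on_dispatch_start(self):
        print("dispatch started")

    def on_chunk_dispatched(self, chunks: List[Chunk]):
        # chunks are handed to the hook as they are dispatched to the gateways
        self.dispatched_bytes += sum(c.chunk_length_bytes for c in chunks)

    def on_dispatch_end(self):
        print(f"dispatch done, {self.dispatched_bytes} bytes queued")

    def on_chunk_completed(self, chunks: List[Chunk], region_tag: Optional[str] = None):
        self.completed_bytes += sum(c.chunk_length_bytes for c in chunks)
        print(f"{self.completed_bytes}/{self.dispatched_bytes} bytes complete ({region_tag})")

    def on_transfer_end(self, transfer_stats):
        print(f"transfer finished: {transfer_stats}")

    def on_transfer_error(self, error):
        print(f"transfer failed: {error}")
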
class TransferProgressTracker(Thread):
    """Tracks transfer progress in one transfer session"""

    def __init__(self, dataplane, jobs: List["TransferJob"], transfer_config: TransferConfig, hooks: TransferHook):
        """
        :param dataplane: dataplane that starts the transfer
        :type dataplane: Dataplane
        :param jobs: list of transfer jobs launched in parallel
        :type jobs: List
        :param transfer_config: the configuration during the transfer
        :type transfer_config: TransferConfig
        :param hooks: the hook that shows transfer related stats
        :type hooks: TransferHook
        """
        super().__init__()
        self.dataplane = dataplane
        self.jobs = {job.uuid: job for job in jobs}
        self.transfer_config = transfer_config
        if hooks is None:
            self.hooks = EmptyTransferHook()
        else:
            self.hooks = hooks

        # log job details
        logger.fs.debug(f"[TransferProgressTracker] Using dataplane {dataplane}")
        logger.fs.debug(f"[TransferProgressTracker] Initialized with {len(jobs)} jobs:")
        for job_uuid, job in self.jobs.items():
            logger.fs.debug(f"[TransferProgressTracker] * {job_uuid}: {job}")
        logger.fs.debug(f"[TransferProgressTracker] Transfer config: {transfer_config}")

        # transfer state
        self.job_chunk_requests: Dict[str, Dict[str, Chunk]] = {}
        self.job_pending_chunk_ids: Dict[str, Dict[str, Set[str]]] = {}
        self.job_complete_chunk_ids: Dict[str, Dict[str, Set[str]]] = {}
        self.errors: Optional[Dict[str, List[str]]] = None

        # http_pool
        self.http_pool = urllib3.PoolManager(retries=urllib3.Retry(total=9))

    def __str__(self):
        return f"TransferProgressTracker({self.dataplane}, {self.jobs})"
    def run(self):
        """Dispatch and start the transfer jobs"""
        src_cloud_provider = self.dataplane.topology.src_region_tag.split(":")[0]
        dst_cloud_provider = self.dataplane.topology.dest_region_tags[0].split(":")[0]
        args = {
            "cmd": ",".join([job.__class__.__name__ for job in self.jobs.values()]),
            "recursive": ",".join([str(job.recursive) for job in self.jobs.values()]),
            "multipart": self.transfer_config.multipart_enabled,
            # "instances_per_region": 1,  # TODO: read this from config file
            # "src_instance_type": getattr(self.transfer_config, f"{src_cloud_provider}_instance_class"),
            # "dst_instance_type": getattr(self.transfer_config, f"{dst_cloud_provider}_instance_class"),
            # "src_spot_instance": getattr(self.transfer_config, f"{src_cloud_provider}_use_spot_instances"),
            # "dst_spot_instance": getattr(self.transfer_config, f"{dst_cloud_provider}_use_spot_instances"),
        }

        # TODO: eventually jobs should be able to be concurrently dispatched and executed
        # however this will require being able to handle conflicting multipart upload ids

        # initialize everything first
        for job_uuid, job in self.jobs.items():
            self.job_chunk_requests[job_uuid] = {}
            self.job_pending_chunk_ids[job_uuid] = {region: set() for region in self.dataplane.topology.dest_region_tags}
            self.job_complete_chunk_ids[job_uuid] = {region: set() for region in self.dataplane.topology.dest_region_tags}

        session_start_timestamp_ms = int(time.time() * 1000)
        for job_uuid, job in self.jobs.items():
            # pre-dispatch chunks to begin pre-buffering chunks
            try:
                chunk_stream = job.dispatch(self.dataplane, transfer_config=self.transfer_config)
                logger.fs.debug(f"[TransferProgressTracker] Dispatching job {job.uuid}")
                for chunk in chunk_stream:
                    chunks_dispatched = [chunk]
                    # TODO: check chunk ID
                    self.job_chunk_requests[job_uuid][chunk.chunk_id] = chunk
                    assert job_uuid in self.job_chunk_requests and chunk.chunk_id in self.job_chunk_requests[job_uuid]
                    self.hooks.on_chunk_dispatched(chunks_dispatched)
                    for region in self.dataplane.topology.dest_region_tags:
                        self.job_pending_chunk_ids[job_uuid][region].add(chunk.chunk_id)
                logger.fs.debug(
                    f"[TransferProgressTracker] Job {job.uuid} dispatched with {len(self.job_chunk_requests[job_uuid])} chunk requests"
                )
            except Exception as e:
                UsageClient.log_exception(
                    "dispatch job",
                    e,
                    args,
                    self.dataplane.topology.src_region_tag,
                    self.dataplane.topology.dest_region_tags[0],  # TODO: support multiple destinations
                    session_start_timestamp_ms,
                )
                raise e

        self.hooks.on_dispatch_end()

        def monitor_single_dst_helper(dst_region):
            start_time = time.time()
            try:
                self.monitor_transfer(dst_region)
            except exceptions.SkyplaneGatewayException as err:
                reformat_err = Exception(err.pretty_print_str()[37:])
                UsageClient.log_exception(
                    "monitor transfer",
                    reformat_err,
                    args,
                    self.dataplane.topology.src_region_tag,
                    dst_region,
                    session_start_timestamp_ms,
                )
                raise err
            except Exception as e:
                UsageClient.log_exception(
                    "monitor transfer", e, args, self.dataplane.topology.src_region_tag, dst_region, session_start_timestamp_ms
                )
                raise e
            end_time = time.time()
            runtime_s = end_time - start_time
            # transfer successfully completed
            transfer_stats = {
                "dst_region": dst_region,
                "total_runtime_s": round(runtime_s, 4),
            }
            return transfer_stats

        results = []
        dest_regions = self.dataplane.topology.dest_region_tags
        with ThreadPoolExecutor(max_workers=len(dest_regions)) as executor:
            e2e_start_time = time.time()
            try:
                future_list = [executor.submit(monitor_single_dst_helper, dest) for dest in dest_regions]
                for future in as_completed(future_list):
                    results.append(future.result())
            except Exception as e:
                raise e
            e2e_end_time = time.time()

        transfer_stats = {
            "total_runtime_s": e2e_end_time - e2e_start_time,
            "throughput_gbits": self.query_bytes_dispatched() / (e2e_end_time - e2e_start_time) / GB * 8,
        }
        self.hooks.on_transfer_end(transfer_stats)

        int(time.time())
        try:
            for job in self.jobs.values():
                logger.fs.debug(f"[TransferProgressTracker] Finalizing job {job.uuid}")
                job.finalize()
        except Exception as e:
            UsageClient.log_exception(
                "finalize job",
                e,
                args,
                self.dataplane.topology.src_region_tag,
                self.dataplane.topology.dest_region_tags[0],
                session_start_timestamp_ms,
            )
            raise e
        int(time.time())

        # verify transfer
        try:
            for job in self.jobs.values():
                logger.fs.debug(f"[TransferProgressTracker] Verifying job {job.uuid}")
                job.verify()
        except Exception as e:
            UsageClient.log_exception(
                "verify job",
                e,
                args,
                self.dataplane.topology.src_region_tag,
                self.dataplane.topology.dest_region_tags[0],
                session_start_timestamp_ms,
            )
            raise e

        # transfer successfully completed
        UsageClient.log_transfer(
            transfer_stats,
            args,
            self.dataplane.topology.src_region_tag,
            self.dataplane.topology.dest_region_tags,
            session_start_timestamp_ms,
        )
        print_stats_completed(total_runtime_s=transfer_stats["total_runtime_s"], throughput_gbits=transfer_stats["throughput_gbits"])
    @imports.inject("pandas")
    def monitor_transfer(pd, self, region_tag):
        """Monitor the transfer by copying remote gateway logs and show transfer stats by hooks"""
        # todo implement transfer monitoring to update job_complete_chunk_ids and job_pending_chunk_ids while the transfer is in progress
        # regions that are sinks for specific region tag
        # TODO: should eventually map bucket to list of instances
        sinks = [n for nodes in self.dataplane.topology.sink_instances(region_tag).values() for n in nodes]
        while any([len(self.job_pending_chunk_ids[job_uuid][region_tag]) > 0 for job_uuid in self.job_pending_chunk_ids]):
            # refresh shutdown status by running noop
            do_parallel(lambda i: i.run_command("echo 1"), self.dataplane.bound_nodes.values(), n=8)

            # check for errors and exit if there are any (while setting debug flags)
            errors = self.dataplane.check_error_logs()
            if any(errors.values()):
                logger.warning("Copying gateway logs...")
                self.dataplane.copy_gateway_logs()
                self.errors = errors
                raise exceptions.SkyplaneGatewayException("Transfer failed with errors", errors)

            log_df = pd.DataFrame(self._query_chunk_status())
            if log_df.empty:
                logger.warning("No chunk status log entries yet")
                time.sleep(0.05)
                continue

            # TODO: have visualization for completion across all destinations
            is_complete_rec = (
                lambda row: row["state"] == ChunkState.complete
                and row["instance"] in [s.gateway_id for s in sinks]
                # and row["region_tag"] in region_sinks
            )
            sink_status_df = log_df[log_df.apply(is_complete_rec, axis=1)]
            completed_chunk_ids = list(set(sink_status_df.chunk_id.unique()))

            # update job_complete_chunk_ids and job_pending_chunk_ids
            # TODO: do chunk-tracking per-destination
            for job_uuid, job in self.jobs.items():
                try:
                    job_complete_chunk_ids = set(
                        chunk_id for chunk_id in completed_chunk_ids if self._chunk_to_job_map[chunk_id] == job_uuid
                    )
                except Exception as e:
                    raise e
                new_chunk_ids = (
                    self.job_complete_chunk_ids[job_uuid][region_tag]
                    .union(job_complete_chunk_ids)
                    .difference(self.job_complete_chunk_ids[job_uuid][region_tag])
                )
                completed_chunks = []
                for id in new_chunk_ids:
                    assert (
                        job_uuid in self.job_chunk_requests and id in self.job_chunk_requests[job_uuid]
                    ), f"Missing chunk id {id} for job {job_uuid}: {self.job_chunk_requests}"
                for id in new_chunk_ids:
                    completed_chunks.append(self.job_chunk_requests[job_uuid][id])
                self.hooks.on_chunk_completed(completed_chunks, region_tag)
                self.job_complete_chunk_ids[job_uuid][region_tag] = self.job_complete_chunk_ids[job_uuid][region_tag].union(
                    job_complete_chunk_ids
                )
                self.job_pending_chunk_ids[job_uuid][region_tag] = self.job_pending_chunk_ids[job_uuid][region_tag].difference(
                    job_complete_chunk_ids
                )

            # sleep
            time.sleep(0.05)
    @property
    # TODO: this is a very slow function, but we can't cache it since self.job_chunk_requests changes over time
    # do not call it more often than necessary
    def _chunk_to_job_map(self):
        return {chunk_id: job_uuid for job_uuid, cr_dict in self.job_chunk_requests.items() for chunk_id in cr_dict.keys()}

    def http_pool_request(self, instance):
        return self.http_pool.request("GET", f"{instance.gateway_api_url}/api/v1/chunk_status_log")

    def _query_chunk_status(self):
        def get_chunk_status(args):
            node, instance = args
            # reply = self.http_pool.request("GET", f"{instance.gateway_api_url}/api/v1/chunk_status_log")
            reply = retry_backoff(partial(self.http_pool_request, instance))
            if reply.status != 200:
                raise Exception(
                    f"Failed to get chunk status from gateway instance {instance.instance_name()}: {reply.data.decode('utf-8')}"
                )
            logs = []
            for log_entry in json.loads(reply.data.decode("utf-8"))["chunk_status_log"]:
                log_entry["region_tag"] = node.region_tag
                log_entry["instance"] = node.gateway_id
                log_entry["time"] = datetime.fromisoformat(log_entry["time"])
                log_entry["state"] = ChunkState.from_str(log_entry["state"])
                logs.append(log_entry)
            return logs

        rows = []
        for result in do_parallel(get_chunk_status, self.dataplane.bound_nodes.items(), n=8, return_args=False):
            rows.extend(result)
        return rows

    def is_complete(self, region_tag: str):
        """Return whether the transfer to the given destination region is complete"""
        return all([len(self.job_pending_chunk_ids[job_uuid][region_tag]) == 0 for job_uuid in self.jobs.keys()])
    def query_bytes_remaining(self, region_tag: Optional[str] = None):
        """Query the total number of bytes remaining in all the transfer jobs"""
        if region_tag is None:
            assert len(self.dataplane.topology.dest_region_tags) == 1, "Must specify region_tag if there are multiple regions"
            region_tag = self.dataplane.topology.dest_region_tags[0]
        if len(self.job_chunk_requests) == 0:
            return None
        bytes_remaining_per_job = {}
        for job_uuid in self.job_pending_chunk_ids.keys():
            bytes_remaining_per_job[job_uuid] = sum(
                [
                    cr.chunk_length_bytes
                    for cr in self.job_chunk_requests[job_uuid].values()
                    if cr.chunk_id in self.job_pending_chunk_ids[job_uuid][region_tag]
                ]
            )
        logger.fs.debug(f"[TransferProgressTracker] Bytes remaining per job: {bytes_remaining_per_job}")
        print(f"[TransferProgressTracker] Bytes remaining per job: {bytes_remaining_per_job}")
        return sum(bytes_remaining_per_job.values())
    def query_bytes_dispatched(self):
        """Query the total number of bytes dispatched to chunks ready for transfer"""
        if len(self.job_chunk_requests) == 0:
            return 0
        bytes_total_per_job = {}
        for job_uuid in self.job_complete_chunk_ids.keys():
            bytes_total_per_job[job_uuid] = sum([cr.chunk_length_bytes for cr in self.job_chunk_requests[job_uuid].values()])
        return sum(bytes_total_per_job.values())
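
A minimal usage sketch, not from the Skyplane codebase: it assumes a provisioned dataplane and a prepared list of TransferJob objects (their construction is handled elsewhere in the API), and it reuses the hypothetical LoggingTransferHook from the sketch above. The run_with_progress name is hypothetical; only the TransferProgressTracker constructor, Thread methods, and query_bytes_remaining come from this module.

def run_with_progress(dataplane, jobs, transfer_config):
    # construct the tracker and let Thread.start() invoke run(), which dispatches and monitors the jobs
    tracker = TransferProgressTracker(dataplane, jobs, transfer_config, hooks=LoggingTransferHook())
    tracker.start()
    dst_region = dataplane.topology.dest_region_tags[0]
    while tracker.is_alive():
        # query_bytes_remaining returns None until chunks have been dispatched
        remaining = tracker.query_bytes_remaining(dst_region)
        if remaining is not None:
            print(f"{remaining / GB:.2f} GB remaining to {dst_region}")
        time.sleep(5)
    tracker.join()
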