import json
import time
from abc import ABC
from datetime import datetime
from threading import Thread
from functools import partial
import urllib3
from typing import TYPE_CHECKING, Dict, List, Optional, Set
from concurrent.futures import ThreadPoolExecutor, as_completed
from skyplane import exceptions
from skyplane.api.config import TransferConfig
from skyplane.chunk import ChunkState, Chunk
from skyplane.utils import logger, imports
from skyplane.utils.fn import do_parallel
from skyplane.api.usage import UsageClient
from skyplane.utils.definitions import GB
from skyplane.utils.retry import retry_backoff
from skyplane.cli.impl.common import print_stats_completed
if TYPE_CHECKING:
from skyplane.api.transfer_job import TransferJob
[docs]
class TransferHook(ABC):
    """Callback interface notified as a transfer moves through its lifecycle.

    Every callback raises ``NotImplementedError`` by default; subclasses
    override the callbacks they want to react to.
    """

    def on_dispatch_start(self) -> None:
        """Called when chunk dispatch begins."""
        raise NotImplementedError

    def on_chunk_dispatched(self, chunks: List[Chunk]) -> None:
        """Called with the chunks that were just dispatched for transfer."""
        raise NotImplementedError

    def on_dispatch_end(self) -> None:
        """Called once every job has finished dispatching its chunks."""
        raise NotImplementedError

    def on_chunk_completed(self, chunks: List[Chunk], region_tag: Optional[str] = None) -> None:
        """Called with chunks that finished transferring, optionally tagged with the destination region."""
        raise NotImplementedError

    def on_transfer_end(self, transfer_stats) -> None:
        """Called when the transfer has finished, with the collected transfer statistics."""
        raise NotImplementedError

    def on_transfer_error(self, error) -> None:
        """Called with the error when the transfer fails."""
        raise NotImplementedError
[docs]
class EmptyTransferHook(TransferHook):
    """Transfer hook whose callbacks all do nothing."""

    def __init__(self):
        """Create the hook; there is no state to set up."""

    def on_dispatch_start(self):
        """Ignore the start of dispatch."""

    def on_chunk_dispatched(self, chunks: List[Chunk]):
        """Ignore dispatched chunks."""

    def on_dispatch_end(self):
        """Ignore the end of dispatch."""

    def on_chunk_completed(self, chunks: List[Chunk], region_tag: Optional[str] = None):
        """Ignore completed chunks."""

    def on_transfer_end(self, transfer_stats):
        """Ignore the end of the transfer."""

    def on_transfer_error(self, error):
        """Ignore transfer errors."""
[docs]
class TransferProgressTracker(Thread):
    """Tracks transfer progress in one transfer session.

    The thread body (:meth:`run`) dispatches the chunks of every job, polls the
    gateways until each destination region has received every chunk, then
    finalizes and verifies the jobs. Progress is reported through the
    configured :class:`TransferHook`.
    """

    def __init__(self, dataplane, jobs: List["TransferJob"], transfer_config: TransferConfig, hooks: Optional[TransferHook]):
        """
        :param dataplane: dataplane that starts the transfer
        :type dataplane: Dataplane
        :param jobs: list of transfer jobs launched in parallel
        :type jobs: List
        :param transfer_config: the configuration during the transfer
        :type transfer_config: TransferConfig
        :param hooks: the hook that shows transfer related stats; ``None`` selects a no-op hook
        :type hooks: TransferHook
        """
        super().__init__()
        self.dataplane = dataplane
        self.jobs = {job.uuid: job for job in jobs}
        self.transfer_config = transfer_config
        self.hooks = EmptyTransferHook() if hooks is None else hooks

        # log job details
        logger.fs.debug(f"[TransferProgressTracker] Using dataplane {dataplane}")
        logger.fs.debug(f"[TransferProgressTracker] Initialized with {len(jobs)} jobs:")
        for job_uuid, job in self.jobs.items():
            logger.fs.debug(f"[TransferProgressTracker] * {job_uuid}: {job}")
        logger.fs.debug(f"[TransferProgressTracker] Transfer config: {transfer_config}")

        # transfer state
        # job uuid -> chunk id -> dispatched chunk
        self.job_chunk_requests: Dict[str, Dict[str, Chunk]] = {}
        # job uuid -> destination region tag -> chunk ids not yet seen complete at a sink
        self.job_pending_chunk_ids: Dict[str, Dict[str, Set[str]]] = {}
        # job uuid -> destination region tag -> chunk ids seen complete at a sink
        self.job_complete_chunk_ids: Dict[str, Dict[str, Set[str]]] = {}
        # gateway error logs captured when monitoring aborts; None until then
        self.errors: Optional[Dict[str, List[str]]] = None

        # http_pool
        self.http_pool = urllib3.PoolManager(retries=urllib3.Retry(total=9))

    def __str__(self):
        return f"TransferProgressTracker({self.dataplane}, {self.jobs})"

    def run(self):
        """Dispatch and start the transfer jobs.

        Stages, in order: initialise per-job state, dispatch every job's
        chunks, monitor all destination regions in parallel, notify the hook
        that the transfer ended, finalize every job, verify every job, and
        finally report the transfer statistics. A failure in any stage is
        reported through ``UsageClient.log_exception`` and re-raised.
        """
        topology = self.dataplane.topology
        src_region = topology.src_region_tag
        dest_regions = topology.dest_region_tags
        # Usage logging for dispatch/finalize/verify only reports the first destination.
        # TODO: support multiple destinations
        primary_dest = dest_regions[0]

        args = {
            "cmd": ",".join(job.__class__.__name__ for job in self.jobs.values()),
            "recursive": ",".join(str(job.recursive) for job in self.jobs.values()),
            "multipart": self.transfer_config.multipart_enabled,
            # TODO: also report instances_per_region, per-provider instance types and
            # spot-instance usage once they can be read from the config
        }

        # TODO: eventually jobs should be able to be concurrently dispatched and executed
        # however this will require being able to handle conflicting multipart uploads ids

        # initialize everything first
        for job_uuid in self.jobs:
            self.job_chunk_requests[job_uuid] = {}
            self.job_pending_chunk_ids[job_uuid] = {region: set() for region in dest_regions}
            self.job_complete_chunk_ids[job_uuid] = {region: set() for region in dest_regions}

        session_start_timestamp_ms = int(time.time() * 1000)

        # NOTE(review): hooks.on_dispatch_start() is never invoked here although the
        # hook interface declares it; left unchanged to avoid altering hook behaviour.
        # pre-dispatch chunks to begin pre-buffering chunks
        try:
            for job_uuid, job in self.jobs.items():
                self._dispatch_job(job_uuid, job)
        except Exception as e:
            UsageClient.log_exception("dispatch job", e, args, src_region, primary_dest, session_start_timestamp_ms)
            raise
        self.hooks.on_dispatch_end()

        # monitor every destination region concurrently; the first failure propagates
        with ThreadPoolExecutor(max_workers=len(dest_regions)) as executor:
            e2e_start_time = time.time()
            futures = [
                executor.submit(self._monitor_destination, dest, args, session_start_timestamp_ms) for dest in dest_regions
            ]
            for future in as_completed(futures):
                dest_stats = future.result()
                logger.fs.debug(f"[TransferProgressTracker] Destination transfer stats: {dest_stats}")
        e2e_end_time = time.time()

        total_runtime_s = e2e_end_time - e2e_start_time
        # guard against a zero elapsed time (e.g. nothing to transfer on a coarse clock)
        throughput_gbits = self.query_bytes_dispatched() / total_runtime_s / GB * 8 if total_runtime_s > 0 else 0.0
        transfer_stats = {
            "total_runtime_s": total_runtime_s,
            "throughput_gbits": throughput_gbits,
        }
        self._notify_transfer_end(transfer_stats)

        try:
            for job in self.jobs.values():
                logger.fs.debug(f"[TransferProgressTracker] Finalizing job {job.uuid}")
                job.finalize()
        except Exception as e:
            UsageClient.log_exception("finalize job", e, args, src_region, primary_dest, session_start_timestamp_ms)
            raise

        # verify transfer
        try:
            for job in self.jobs.values():
                logger.fs.debug(f"[TransferProgressTracker] Verifying job {job.uuid}")
                job.verify()
        except Exception as e:
            UsageClient.log_exception("verify job", e, args, src_region, primary_dest, session_start_timestamp_ms)
            raise

        # transfer successfully completed
        UsageClient.log_transfer(transfer_stats, args, src_region, dest_regions, session_start_timestamp_ms)
        print_stats_completed(total_runtime_s=transfer_stats["total_runtime_s"], throughput_gbits=transfer_stats["throughput_gbits"])

    def _dispatch_job(self, job_uuid, job):
        """Dispatch one job and record each chunk as pending for every destination region."""
        chunk_stream = job.dispatch(self.dataplane, transfer_config=self.transfer_config)
        logger.fs.debug(f"[TransferProgressTracker] Dispatching job {job.uuid}")
        chunk_requests = self.job_chunk_requests[job_uuid]
        pending_by_region = self.job_pending_chunk_ids[job_uuid]
        for chunk in chunk_stream:
            # TODO: check chunk ID
            chunk_requests[chunk.chunk_id] = chunk
            self.hooks.on_chunk_dispatched([chunk])
            for pending in pending_by_region.values():
                pending.add(chunk.chunk_id)
        logger.fs.debug(f"[TransferProgressTracker] Job {job.uuid} dispatched with {len(chunk_requests)} chunk requests")

    def _monitor_destination(self, dst_region, args, session_start_timestamp_ms):
        """Monitor one destination region until it is complete and return its runtime stats.

        Failures are reported through ``UsageClient.log_exception`` and re-raised.
        """
        src_region = self.dataplane.topology.src_region_tag
        start_time = time.time()
        try:
            self.monitor_transfer(dst_region)
        except exceptions.SkyplaneGatewayException as err:
            # Report the message without its first 37 characters
            # (presumably a fixed-length header of pretty_print_str -- TODO confirm).
            reformat_err = Exception(err.pretty_print_str()[37:])
            UsageClient.log_exception("monitor transfer", reformat_err, args, src_region, dst_region, session_start_timestamp_ms)
            raise
        except Exception as e:
            UsageClient.log_exception("monitor transfer", e, args, src_region, dst_region, session_start_timestamp_ms)
            raise
        return {
            "dst_region": dst_region,
            "total_runtime_s": round(time.time() - start_time, 4),
        }

    def _notify_transfer_end(self, transfer_stats):
        """Invoke ``hooks.on_transfer_end`` with the stats when the hook accepts them.

        The hook interface declares ``on_transfer_end(transfer_stats)``, but this
        tracker used to call it with no arguments, so hooks written against that
        call may take none. The signature is inspected to serve both forms.
        """
        import inspect

        hook = self.hooks.on_transfer_end
        try:
            signature = inspect.signature(hook)
        except (TypeError, ValueError):
            signature = None  # signature unavailable: follow the declared interface

        accepts_stats = True
        if signature is not None:
            try:
                signature.bind(transfer_stats)
            except TypeError:
                accepts_stats = False

        if accepts_stats:
            hook(transfer_stats)
        else:
            hook()

    @imports.inject("pandas")
    def monitor_transfer(pd, self, region_tag):
        """Monitor the transfer to ``region_tag`` until no chunk is pending there.

        Each poll checks the gateways for errors, fetches their chunk status
        logs, and moves chunks reported complete by a sink instance of the
        region from pending to complete, notifying ``hooks.on_chunk_completed``.

        :param region_tag: destination region to monitor
        :raises exceptions.SkyplaneGatewayException: if any gateway reports errors
        """
        # todo implement transfer monitoring to update job_complete_chunk_ids and job_pending_chunk_ids while the transfer is in progress
        # regions that are sinks for specific region tag
        # TODO: should eventualy map bucket to list of instances
        sinks = [n for nodes in self.dataplane.topology.sink_instances(region_tag).values() for n in nodes]
        # built once: membership is tested for every log row on every poll
        sink_gateway_ids = {sink.gateway_id for sink in sinks}

        def is_complete_at_sink(row):
            return row["state"] == ChunkState.complete and row["instance"] in sink_gateway_ids

        while any(self.job_pending_chunk_ids[job_uuid][region_tag] for job_uuid in self.job_pending_chunk_ids):
            # refresh shutdown status by running noop
            do_parallel(lambda i: i.run_command("echo 1"), self.dataplane.bound_nodes.values(), n=8)

            # check for errors and exit if there are any (while setting debug flags)
            errors = self.dataplane.check_error_logs()
            if any(errors.values()):
                logger.warning("Copying gateway logs...")
                self.dataplane.copy_gateway_logs()
                self.errors = errors
                raise exceptions.SkyplaneGatewayException("Transfer failed with errors", errors)

            log_df = pd.DataFrame(self._query_chunk_status())
            if log_df.empty:
                logger.warning("No chunk status log entries yet")
                time.sleep(0.05)
                continue

            # TODO: have visualization for completition across all destinations
            sink_status_df = log_df[log_df.apply(is_complete_at_sink, axis=1)]
            completed_chunk_ids = set(sink_status_df.chunk_id.unique())

            # The map is rebuilt on every property access, so fetch it once per poll.
            chunk_to_job = self._chunk_to_job_map

            # update job_complete_chunk_ids and job_pending_chunk_ids
            # TODO: do chunk-tracking per-destination
            for job_uuid in self.jobs:
                job_done = {chunk_id for chunk_id in completed_chunk_ids if chunk_to_job[chunk_id] == job_uuid}
                already_done = self.job_complete_chunk_ids[job_uuid][region_tag]
                chunk_requests = self.job_chunk_requests[job_uuid]
                completed_chunks = [chunk_requests[chunk_id] for chunk_id in job_done - already_done]
                self.hooks.on_chunk_completed(completed_chunks, region_tag)
                # rebind to fresh sets rather than mutating in place
                self.job_complete_chunk_ids[job_uuid][region_tag] = already_done | job_done
                self.job_pending_chunk_ids[job_uuid][region_tag] = self.job_pending_chunk_ids[job_uuid][region_tag] - job_done

            # sleep
            time.sleep(0.05)

    @property
    def _chunk_to_job_map(self):
        """Map every dispatched chunk id to the uuid of the job that owns it.

        TODO: this is a very slow function, but we can't cache it since
        self.job_chunk_requests changes over time; do not call it more often
        than necessary.
        """
        return {chunk_id: job_uuid for job_uuid, cr_dict in self.job_chunk_requests.items() for chunk_id in cr_dict}

    def http_pool_request(self, instance):
        """GET the chunk status log of a gateway instance through the shared HTTP pool."""
        return self.http_pool.request("GET", f"{instance.gateway_api_url}/api/v1/chunk_status_log")

    def _query_chunk_status(self):
        """Fetch the chunk status log of every bound gateway and return all entries as one list.

        Each entry is annotated with the node's ``region_tag`` and ``instance``
        (gateway id); its ``time`` and ``state`` fields are parsed.
        """

        def get_chunk_status(args):
            node, instance = args
            reply = retry_backoff(partial(self.http_pool_request, instance))
            if reply.status != 200:
                raise Exception(
                    f"Failed to get chunk status from gateway instance {instance.instance_name()}: {reply.data.decode('utf-8')}"
                )
            logs = []
            for log_entry in json.loads(reply.data.decode("utf-8"))["chunk_status_log"]:
                log_entry["region_tag"] = node.region_tag
                log_entry["instance"] = node.gateway_id
                log_entry["time"] = datetime.fromisoformat(log_entry["time"])
                log_entry["state"] = ChunkState.from_str(log_entry["state"])
                logs.append(log_entry)
            return logs

        rows = []
        for result in do_parallel(get_chunk_status, self.dataplane.bound_nodes.items(), n=8, return_args=False):
            rows.extend(result)
        return rows

    @property
    def is_complete(self) -> bool:
        """Return if the transfer is complete.

        True only when every job has been initialised by :meth:`run` and has no
        pending chunk in any destination region. This used to be a property
        with a required ``region_tag`` parameter, which made every access raise
        ``TypeError``.
        """
        for job_uuid in self.jobs:
            pending_by_region = self.job_pending_chunk_ids.get(job_uuid)
            if pending_by_region is None:
                # job state not initialised yet, so nothing has been transferred
                return False
            if any(pending_by_region.values()):
                return False
        return True

    def query_bytes_remaining(self, region_tag: Optional[str] = None):
        """Query the total number of bytes remaining in all the transfer jobs.

        :param region_tag: destination region to query; may be omitted only when
            the transfer has exactly one destination region
        :returns: pending bytes for the region, or ``None`` if no job state exists yet
        :raises ValueError: if ``region_tag`` is omitted and there is not exactly one region
        """
        if len(self.job_chunk_requests) == 0:
            return None

        if region_tag is None:
            # Region tags are the keys of the inner dicts; the outer keys are job uuids.
            regions = {region for pending_by_region in self.job_pending_chunk_ids.values() for region in pending_by_region}
            if len(regions) != 1:
                raise ValueError("Must specify region_tag if there are multiple regions")
            (region_tag,) = regions

        bytes_remaining_per_job = {}
        for job_uuid, pending_by_region in self.job_pending_chunk_ids.items():
            pending = pending_by_region[region_tag]
            bytes_remaining_per_job[job_uuid] = sum(
                cr.chunk_length_bytes for cr in self.job_chunk_requests[job_uuid].values() if cr.chunk_id in pending
            )
        logger.fs.debug(f"[TransferProgressTracker] Bytes remaining per job: {bytes_remaining_per_job}")
        return sum(bytes_remaining_per_job.values())

    def query_bytes_dispatched(self):
        """Query the total number of bytes dispatched to chunks ready for transfer.

        :returns: sum of ``chunk_length_bytes`` over every dispatched chunk (0 if none)
        """
        return sum(cr.chunk_length_bytes for chunk_requests in self.job_chunk_requests.values() for cr in chunk_requests.values())