Source code for skyplane.api.transfer_job

import json
import math
import queue
import sys
import threading
import time
import uuid
from abc import ABC
from collections import defaultdict
from dataclasses import dataclass, field
from queue import Queue
from typing import TYPE_CHECKING, Callable, Generator, List, Optional, Tuple, TypeVar

import urllib3
from rich import print as rprint

from skyplane import exceptions
from skyplane.api.config import TransferConfig
from skyplane.chunk import Chunk, ChunkRequest
from skyplane.obj_store.azure_blob_interface import AzureBlobObject
from skyplane.obj_store.file_system_interface import FileSystemInterface
from skyplane.obj_store.gcs_interface import GCSObject
from skyplane.obj_store.object_store_interface import ObjectStoreInterface, ObjectStoreObject
from skyplane.obj_store.s3_interface import S3Object
from skyplane.utils import logger
from skyplane.utils.definitions import MB
from skyplane.utils.fn import do_parallel
from skyplane.utils.path import parse_path
from skyplane.utils.generator import batch_generator, prefetch_generator, tail_generator

if TYPE_CHECKING:
    from skyplane.api.dataplane import Dataplane

T = TypeVar("T")


class Chunker:
    """class that chunks the original files and makes the chunk requests"""

    def __init__(
        self,
        src_iface: ObjectStoreInterface or FileSystemInterface,
        dst_iface: ObjectStoreInterface or FileSystemInterface,
        transfer_config: TransferConfig,
        concurrent_multipart_chunk_threads: int = 64,
    ):
        """
        :param src_iface: source object store interface
        :type src_iface: ObjectStoreInterface
        :param dst_iface: destination object store interface
        :type dst_iface: ObjectStoreInterface
        :param transfer_config: the configuration during the transfer
        :type transfer_config: TransferConfig
        :param concurrent_multipart_chunk_threads: the maximum number of concurrent threads that dispatch multipart chunk requests (default: 64)
        :type concurrent_multipart_chunk_threads: int
        """
        self.src_iface = src_iface
        self.dst_iface = dst_iface
        self.transfer_config = transfer_config
        self.multipart_upload_requests = []
        self.concurrent_multipart_chunk_threads = concurrent_multipart_chunk_threads

    def _run_multipart_chunk_thread(
        self,
        exit_event: threading.Event,
        in_queue: "Queue[Tuple[ObjectStoreObject, ObjectStoreObject]]",
        out_queue: "Queue[Chunk]",
    ):
        """Chunks large files into many small chunks."""
        region = self.dst_iface.region_tag()
        bucket = self.dst_iface.bucket()
        while not exit_event.is_set():
            try:
                input_data = in_queue.get(block=False, timeout=0.1)
            except queue.Empty:
                continue

            # get source and destination object and then compute number of chunks
            src_object, dest_object = input_data
            mime_type = self.src_iface.get_obj_mime_type(src_object.key)
            upload_id = self.dst_iface.initiate_multipart_upload(dest_object.key, mime_type=mime_type)
            chunk_size_bytes = int(self.transfer_config.multipart_chunk_size_mb * MB)
            num_chunks = math.ceil(src_object.size / chunk_size_bytes)
            if num_chunks > self.transfer_config.multipart_max_chunks:
                chunk_size_bytes = int(src_object.size / self.transfer_config.multipart_max_chunks)
                chunk_size_bytes = math.ceil(chunk_size_bytes / MB) * MB  # round to next largest MB
                num_chunks = math.ceil(src_object.size / chunk_size_bytes)
            assert num_chunks * chunk_size_bytes >= src_object.size

            # create chunks
            offset = 0
            part_num = 1
            parts = []
            for _ in range(num_chunks):
                file_size_bytes = min(chunk_size_bytes, src_object.size - offset)
                assert file_size_bytes > 0, f"file size <= 0 {file_size_bytes}"
                chunk = Chunk(
                    src_key=src_object.key,
                    dest_key=dest_object.key,
                    chunk_id=uuid.uuid4().hex,
                    file_offset_bytes=offset,
                    chunk_length_bytes=file_size_bytes,
                    part_number=part_num,
                    upload_id=upload_id,
                )
                offset += file_size_bytes
                parts.append(part_num)
                part_num += 1
                out_queue.put(chunk)
            self.multipart_upload_requests.append(
                dict(upload_id=upload_id, key=dest_object.key, parts=parts, region=region, bucket=bucket)
            )
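    # Illustrative arithmetic (not part of the module): how a multipart transfer is split, assuming a
    # hypothetical TransferConfig with multipart_chunk_size_mb=64 and a large multipart_max_chunks.
    #
    #   a 1000 MB object -> chunk_size_bytes = 64 * MB, num_chunks = ceil(1000 / 64) = 16,
    #   i.e. 15 chunks of 64 MB followed by a final 40 MB chunk (offsets 0, 64 MB, ..., 960 MB).
    #   If num_chunks exceeded multipart_max_chunks, the chunk size would instead be raised (rounded
    #   up to a whole MB) so the object fits within the configured part-count limit.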
    def to_chunk_requests(self, gen_in: Generator[Chunk, None, None]) -> Generator[ChunkRequest, None, None]:
        """Converts a generator of chunks to a generator of chunk requests.

        :param gen_in: generator that yields chunks
        :type gen_in: Generator
        """
        src_region = self.src_iface.region_tag()
        dest_region = self.dst_iface.region_tag()
        src_bucket = self.src_iface.bucket()
        dest_bucket = self.dst_iface.bucket()
        for chunk in gen_in:
            yield ChunkRequest(
                chunk=chunk,
                src_region=src_region,
                dst_region=dest_region,
                src_object_store_bucket=src_bucket,
                dst_object_store_bucket=dest_bucket,
                src_type="object_store",
                dst_type="object_store",
            )
    @staticmethod
    def map_object_key_prefix(source_prefix: str, source_key: str, dest_prefix: str, recursive: bool = False):
        """
        map_object_key_prefix computes the mapping of a source key in a bucket prefix to the destination.
        Users invoke a transfer via the CLI, e.g. ``aws s3 cp s3://bucket/source_prefix s3://bucket/dest_prefix``.
        The CLI will query the object store for all objects in the source prefix and map them to the
        destination prefix using this function.

        :param source_prefix: source bucket folder prefix
        :type source_prefix: string
        :param source_key: source file key to map in the folder prefix
        :type source_key: string
        :param dest_prefix: destination bucket folder prefix
        :type dest_prefix: string
        :param recursive: whether to copy all the objects matching the pattern (default: False)
        :type recursive: bool
        """
        join = lambda prefix, fname: prefix + fname if prefix.endswith("/") else prefix + "/" + fname
        src_fname = source_key.split("/")[-1] if "/" in source_key and not source_key.endswith("/") else source_key
        if not recursive:
            if source_key == source_prefix:
                if dest_prefix == "" or dest_prefix == "/":
                    return src_fname
                elif dest_prefix[-1] == "/":
                    return dest_prefix + src_fname
                else:
                    return dest_prefix
            else:
                # todo: don't print output here
                rprint(f"\n:x: [bold red]In order to transfer objects using a prefix, you must use the --recursive or -r flag.[/bold red]")
                rprint(f"[yellow]If you meant to transfer a single object, pass the full source object key.[/yellow]")
                rprint(f"[bright_black]Try running: [bold]skyplane {' '.join(sys.argv[1:])} --recursive[/bold][/bright_black]")
                raise exceptions.MissingObjectException("Encountered a recursive transfer without the --recursive flag.") from None
        else:
            if source_prefix == "" or source_prefix == "/":
                if dest_prefix == "" or dest_prefix == "/":
                    return source_key
                else:
                    return join(dest_prefix, source_key)
            else:
                # catch special case: map_object_key_prefix("foo", "foobar/baz.txt", "", recursive=True)
                if not source_key.startswith(source_prefix + "/" if not source_prefix.endswith("/") else source_prefix):
                    rprint(f"\n:x: [bold red]The source key {source_key} does not start with the source prefix {source_prefix}[/bold red]")
                    raise exceptions.MissingObjectException(f"Source key {source_key} does not start with source prefix {source_prefix}")
                if dest_prefix == "" or dest_prefix == "/":
                    return source_key[len(source_prefix) :]
                else:
                    src_path_after_prefix = source_key[len(source_prefix) :]
                    src_path_after_prefix = src_path_after_prefix[1:] if src_path_after_prefix.startswith("/") else src_path_after_prefix
                    return join(dest_prefix, src_path_after_prefix)
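    # Illustrative examples (not part of the module) of the mapping above; the prefixes and keys are
    # hypothetical:
    #
    #   map_object_key_prefix("foo/a.txt", "foo/a.txt", "bar/")                -> "bar/a.txt"
    #   map_object_key_prefix("foo/", "foo/sub/a.txt", "bar", recursive=True)  -> "bar/sub/a.txt"
    #   map_object_key_prefix("", "a.txt", "", recursive=True)                 -> "a.txt"
    #   map_object_key_prefix("foo/", "foo/sub/a.txt", "bar")                  -> raises MissingObjectException
    #                                                                             (prefix transfer without --recursive)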
    def transfer_pair_generator(
        self,
        src_prefix: str,
        dst_prefix: str,
        recursive: bool,
        prefilter_fn: Optional[Callable[[ObjectStoreObject or FileSystemInterface], bool]] = None,
    ) -> Generator[Tuple[ObjectStoreObject, ObjectStoreObject], None, None]:
        """Query source region and return list of objects to transfer.

        :param src_prefix: source bucket folder prefix
        :type src_prefix: string
        :param dst_prefix: destination bucket folder prefix
        :type dst_prefix: string
        :param recursive: if true, will copy objects at folder prefix recursively
        :type recursive: bool
        :param prefilter_fn: filters out objects whose prefixes do not match the filter function (default: None)
        :type prefilter_fn: Callable[[ObjectStoreObject], bool]
        """
        if not self.src_iface.bucket_exists():
            raise exceptions.MissingBucketException(f"Source bucket {self.src_iface.path()} does not exist or is not readable.")
        if not self.dst_iface.bucket_exists():
            raise exceptions.MissingBucketException(f"Destination bucket {self.dst_iface.path()} does not exist or is not readable.")

        # query all source region objects
        logger.fs.debug(f"Querying objects in {self.src_iface.path()}")
        n_objs = 0
        for obj in self.src_iface.list_objects(src_prefix):
            if prefilter_fn is None or prefilter_fn(obj):
                try:
                    dest_key = self.map_object_key_prefix(src_prefix, obj.key, dst_prefix, recursive=recursive)
                except exceptions.MissingObjectException as e:
                    logger.fs.exception(e)
                    raise e from None

                # make destination object
                dest_provider, dest_region = self.dst_iface.region_tag().split(":")
                if dest_provider == "aws":
                    dest_obj = S3Object(dest_provider, self.dst_iface.bucket(), dest_key)
                elif dest_provider == "azure":
                    dest_obj = AzureBlobObject(dest_provider, self.dst_iface.bucket(), dest_key)
                elif dest_provider == "gcp":
                    dest_obj = GCSObject(dest_provider, self.dst_iface.bucket(), dest_key)
                else:
                    raise ValueError(f"Invalid dest_region {dest_region}, unknown provider")
                n_objs += 1
                yield obj, dest_obj

        if n_objs == 0:
            logger.error("Specified object does not exist.\n")
            raise exceptions.MissingObjectException("No objects were found in the specified prefix")
    def chunk(
        self, transfer_pair_generator: Generator[Tuple[ObjectStoreObject, ObjectStoreObject], None, None]
    ) -> Generator[Chunk, None, None]:
        """Break transfer list into chunks.

        :param transfer_pair_generator: generator of pairs of objects to transfer
        :type transfer_pair_generator: Generator
        """
        multipart_send_queue: Queue[Tuple[ObjectStoreObject, ObjectStoreObject]] = Queue()
        multipart_chunk_queue: Queue[Chunk] = Queue()
        multipart_exit_event = threading.Event()
        multipart_chunk_threads = []

        # start chunking threads
        if self.transfer_config.multipart_enabled:
            for _ in range(self.concurrent_multipart_chunk_threads):
                t = threading.Thread(
                    target=self._run_multipart_chunk_thread,
                    args=(multipart_exit_event, multipart_send_queue, multipart_chunk_queue),
                    daemon=False,
                )
                t.start()
                multipart_chunk_threads.append(t)

        # begin chunking loop
        for src_obj, dst_obj in transfer_pair_generator:
            if self.transfer_config.multipart_enabled and src_obj.size > self.transfer_config.multipart_threshold_mb * MB:
                multipart_send_queue.put((src_obj, dst_obj))
            else:
                yield Chunk(
                    src_key=src_obj.key,
                    dest_key=dst_obj.key,
                    chunk_id=uuid.uuid4().hex,
                    chunk_length_bytes=src_obj.size,
                )

            if self.transfer_config.multipart_enabled:
                # drain multipart chunk queue and yield with updated chunk IDs
                while not multipart_chunk_queue.empty():
                    yield multipart_chunk_queue.get()

        if self.transfer_config.multipart_enabled:
            # wait for the chunking threads to process all queued transfer pairs before signaling exit;
            # setting the exit event while the send queue is non-empty would drop pending multipart chunks
            while not multipart_send_queue.empty():
                time.sleep(0.1)
            # signal all chunking threads to exit and wait for them to finish
            multipart_exit_event.set()
            for thread in multipart_chunk_threads:
                thread.join()
            # drain any remaining chunks from the multipart chunk queue
            while not multipart_chunk_queue.empty():
                yield multipart_chunk_queue.get()
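# Illustrative sketch (not part of the module): how the Chunker stages compose into a pipeline.
# The bucket prefixes below are hypothetical.
#
#   chunker = Chunker(src_iface, dst_iface, TransferConfig())
#   pairs = chunker.transfer_pair_generator("data/", "backup/", recursive=True)
#   chunks = chunker.chunk(pairs)                        # small objects become one chunk; large ones are split by the multipart threads
#   chunk_requests = chunker.to_chunk_requests(chunks)   # each chunk is tagged with src/dst region and bucket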
@dataclass
class TransferJob(ABC):
    """
    transfer job with transfer configurations

    :param src_path: source full path
    :type src_path: str
    :param dst_path: destination full path
    :type dst_path: str
    :param recursive: if true, will transfer objects at folder prefix recursively (default: False)
    :type recursive: bool
    :param requester_pays: if set, will support requester pays buckets. (default: False)
    :type requester_pays: bool
    :param uuid: the uuid of one single transfer job
    :type uuid: str
    """

    src_path: str
    dst_path: str
    recursive: bool = False
    requester_pays: bool = False
    uuid: str = field(init=False, default_factory=lambda: str(uuid.uuid4()))

    @property
    def src_prefix(self) -> Optional[str]:
        """Return the source prefix"""
        if not hasattr(self, "_src_prefix"):
            self._src_prefix = parse_path(self.src_path)[2]
        return self._src_prefix

    @property
    def src_iface(self) -> ObjectStoreInterface or FileSystemInterface:
        """Return the source object store interface"""
        if not hasattr(self, "_src_iface"):
            provider_src, bucket_src, path_src = parse_path(self.src_path)
            if provider_src in ("local", "nfs"):
                self._src_iface = FileSystemInterface.create(f"{provider_src}:infer", path_src)
            else:
                self._src_iface = ObjectStoreInterface.create(f"{provider_src}:infer", bucket_src)
            if self.requester_pays:
                self._src_iface.set_requester_bool(True)
        return self._src_iface

    @property
    def dst_prefix(self) -> Optional[str]:
        """Return the destination prefix"""
        if not hasattr(self, "_dst_prefix"):
            self._dst_prefix = parse_path(self.dst_path)[2]
        return self._dst_prefix

    @property
    def dst_iface(self) -> ObjectStoreInterface:
        """Return the destination object store interface"""
        if not hasattr(self, "_dst_iface"):
            provider_dst, bucket_dst, _ = parse_path(self.dst_path)
            self._dst_iface = ObjectStoreInterface.create(f"{provider_dst}:infer", bucket_dst)
        return self._dst_iface
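    # Illustrative sketch (not part of the module): how paths resolve to prefixes and interfaces.
    # The example paths and the exact provider strings returned by parse_path are assumptions.
    #
    #   job = CopyJob("s3://my-bucket/data/", "gs://other-bucket/data/")
    #   job.src_prefix   # "data/", the third element of parse_path(src_path)
    #   job.src_iface    # ObjectStoreInterface.create("aws:infer", "my-bucket"), cached after first access
    #   job.dst_iface    # ObjectStoreInterface.create("gcp:infer", "other-bucket")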
    def dispatch(self, dataplane: "Dataplane", **kwargs) -> Generator[ChunkRequest, None, None]:
        """Dispatch transfer job to specified gateways."""
        raise NotImplementedError("Dispatch not implemented")

    def finalize(self):
        """Complete the multipart upload requests"""
        raise NotImplementedError("Finalize not implemented")

    def verify(self):
        """Verifies the transfer completed, otherwise raises TransferFailedException."""
        raise NotImplementedError("Verify not implemented")

    def gen_transfer_pairs(
        self, chunker: Optional[Chunker] = None
    ) -> Generator[Tuple[ObjectStoreObject or FileSystemInterface, ObjectStoreObject or FileSystemInterface], None, None]:
        raise NotImplementedError("Generate transfer pairs not implemented")

    def size_gb(self):
        """Return the size of the transfer in GB"""
        total_size = 0
        for src_obj, _ in self.gen_transfer_pairs():
            total_size += src_obj.size
        return total_size / 1e9

    @classmethod
    def _pre_filter_fn(cls, obj: ObjectStoreObject) -> bool:
        """Optionally filter source objects before they are transferred.

        :meta private:
        :param obj: source object to be transferred
        :type obj: ObjectStoreObject
        """
        return True
@dataclass
class CopyJob(TransferJob):
    """copy job that copies the source objects to the destination

    :param transfer_list: transfer list for later verification
    :type transfer_list: list
    :param multipart_transfer_list: multipart transfer list for later verification
    :type multipart_transfer_list: list
    """

    transfer_list: list = field(default_factory=list)
    multipart_transfer_list: list = field(default_factory=list)

    @property
    def http_pool(self):
        """http connection pool"""
        if not hasattr(self, "_http_pool"):
            self._http_pool = urllib3.PoolManager(retries=urllib3.Retry(total=3))
        return self._http_pool
    def gen_transfer_pairs(
        self, chunker: Optional[Chunker] = None
    ) -> Generator[Tuple[ObjectStoreObject or FileSystemInterface, ObjectStoreObject or FileSystemInterface], None, None]:
        """Generate transfer pairs for the transfer job.

        :param chunker: chunker that makes the chunk requests
        :type chunker: Chunker
        """
        if chunker is None:  # used for external access to transfer pair list
            logger.fs.debug("Generating transfer pairs for external access, {} -> {}".format(self.src_iface, self.dst_iface))
            chunker = Chunker(self.src_iface, self.dst_iface, TransferConfig())
        yield from chunker.transfer_pair_generator(self.src_prefix, self.dst_prefix, self.recursive, self._pre_filter_fn)
    def dispatch(
        self,
        dataplane: "Dataplane",
        transfer_config: TransferConfig,
        dispatch_batch_size: int = 100,  # 6.4 GB worth of chunks
    ) -> Generator[ChunkRequest, None, None]:
        """Dispatch transfer job to specified gateways.

        :param dataplane: dataplane that starts the transfer job
        :type dataplane: Dataplane
        :param transfer_config: the configuration during the transfer
        :type transfer_config: TransferConfig
        :param dispatch_batch_size: maximum size of the buffer to temporarily store the generators (default: 100)
        :type dispatch_batch_size: int
        """
        chunker = Chunker(self.src_iface, self.dst_iface, transfer_config)
        transfer_pair_generator = self.gen_transfer_pairs(chunker)
        gen_transfer_list = tail_generator(transfer_pair_generator, self.transfer_list)
        chunks = chunker.chunk(gen_transfer_list)
        chunk_requests = chunker.to_chunk_requests(chunks)
        batches = batch_generator(
            prefetch_generator(chunk_requests, buffer_size=dispatch_batch_size * 32), batch_size=dispatch_batch_size
        )

        # dispatch chunk requests
        src_gateways = dataplane.source_gateways()
        bytes_dispatched = [0] * len(src_gateways)
        n_multiparts = 0
        start = time.time()
        for batch in batches:
            end = time.time()
            logger.fs.debug(f"Queried {len(batch)} chunks in {end - start:.2f} seconds")
            # pick the source gateway with the fewest bytes dispatched so far
            min_idx = bytes_dispatched.index(min(bytes_dispatched))
            server = src_gateways[min_idx]
            n_bytes = sum([cr.chunk.chunk_length_bytes for cr in batch])
            bytes_dispatched[min_idx] += n_bytes
            start = time.time()
            reply = self.http_pool.request(
                "POST",
                f"{server.gateway_api_url}/api/v1/chunk_requests",
                body=json.dumps([c.as_dict() for c in batch]).encode("utf-8"),
                headers={"Content-Type": "application/json"},
            )
            end = time.time()
            if reply.status != 200:
                raise Exception(f"Failed to dispatch chunk requests {server.instance_name()}: {reply.data.decode('utf-8')}")
            logger.fs.debug(
                f"Dispatched {len(batch)} chunk requests to {server.instance_name()} ({n_bytes} bytes) in {end - start:.2f} seconds"
            )
            yield from batch

            # copy new multipart transfers to the multipart transfer list
            updated_len = len(chunker.multipart_upload_requests)
            self.multipart_transfer_list.extend(chunker.multipart_upload_requests[n_multiparts:updated_len])
            n_multiparts = updated_len
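    # Illustrative sketch (not part of the module): a CopyJob is typically driven by consuming
    # dispatch() to completion, then finalizing multipart uploads and verifying. The provisioned
    # `dataplane` object is assumed to come from the surrounding skyplane.api package.
    #
    #   job = CopyJob("s3://src-bucket/data/", "gs://dst-bucket/data/", recursive=True)
    #   for chunk_request in job.dispatch(dataplane, transfer_config=TransferConfig()):
    #       pass             # chunk requests have now been POSTed to the source gateways
    #   job.finalize()       # complete outstanding multipart uploads
    #   job.verify()         # raises TransferFailedException if any object is missing or mismatched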
    def finalize(self):
        """Complete the multipart upload requests"""
        groups = defaultdict(list)
        for req in self.multipart_transfer_list:
            if "region" not in req or "bucket" not in req:
                raise Exception(f"Invalid multipart upload request: {req}")
            groups[(req["region"], req["bucket"])].append(req)
        for key, group in groups.items():
            region, bucket = key
            batch_len = max(1, len(group) // 128)
            batches = [group[i : i + batch_len] for i in range(0, len(group), batch_len)]
            obj_store_interface = ObjectStoreInterface.create(region, bucket)

            def complete_fn(batch):
                for req in batch:
                    obj_store_interface.complete_multipart_upload(req["key"], req["upload_id"])

            do_parallel(complete_fn, batches, n=-1)
    def verify(self):
        """Verify the integrity of the transferred destination objects"""
        dst_keys = {dst_o.key: src_o for src_o, dst_o in self.transfer_list}
        for obj in self.dst_iface.list_objects(self.dst_prefix):
            # check metadata (src.size == dst.size) && (src.modified <= dst.modified)
            src_obj = dst_keys.get(obj.key)
            if src_obj and src_obj.size == obj.size and src_obj.last_modified <= obj.last_modified:
                del dst_keys[obj.key]
        if dst_keys:
            failed_keys = [obj.key for obj in dst_keys.values()]
            raise exceptions.TransferFailedException(f"{len(dst_keys)} objects failed verification {failed_keys}")
@dataclass
class SyncJob(CopyJob):
    """sync job that copies only the source objects that do not already exist in the destination bucket"""
    def gen_transfer_pairs(
        self, chunker: Optional[Chunker] = None
    ) -> Generator[Tuple[ObjectStoreObject or FileSystemInterface, ObjectStoreObject or FileSystemInterface], None, None]:
        """Generate transfer pairs for the transfer job.

        :param chunker: chunker that makes the chunk requests
        :type chunker: Chunker
        """
        if chunker is None:  # used for external access to transfer pair list
            chunker = Chunker(self.src_iface, self.dst_iface, TransferConfig())
        transfer_pair_gen = chunker.transfer_pair_generator(self.src_prefix, self.dst_prefix, self.recursive, self._pre_filter_fn)

        # enrich destination objects with metadata
        for src_obj, dest_obj in self._enrich_dest_objs(transfer_pair_gen, self.dst_prefix):
            if self._post_filter_fn(src_obj, dest_obj):
                yield src_obj, dest_obj
    def _enrich_dest_objs(
        self, transfer_pairs: Generator[Tuple[ObjectStoreObject, ObjectStoreObject], None, None], dest_prefix: str
    ) -> Generator[Tuple[ObjectStoreObject, ObjectStoreObject], None, None]:
        """
        For skyplane sync, we enrich dest obj metadata with our existing dest obj metadata from the dest bucket following a query.

        :meta private:
        :param transfer_pairs: generator of transfer pairs
        :type transfer_pairs: Generator
        """
        logger.fs.debug(f"Querying objects in {self.dst_iface.bucket()}")
        if not hasattr(self, "_found_dest_objs"):
            self._found_dest_objs = {obj.key: obj for obj in self.dst_iface.list_objects(dest_prefix)}
        for src_obj, dest_obj in transfer_pairs:
            if dest_obj.key in self._found_dest_objs:
                dest_obj.size = self._found_dest_objs[dest_obj.key].size
                dest_obj.last_modified = self._found_dest_objs[dest_obj.key].last_modified
            yield src_obj, dest_obj

    @classmethod
    def _post_filter_fn(cls, src_obj: ObjectStoreObject, dest_obj: ObjectStoreObject) -> bool:
        """Optionally filter destination objects after they are transferred.

        :param src_obj: source object to be transferred
        :type src_obj: ObjectStoreObject
        :param dest_obj: destination object transferred
        :type dest_obj: ObjectStoreObject
        """
        return not dest_obj.exists or (src_obj.last_modified > dest_obj.last_modified or src_obj.size != dest_obj.size)
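# Illustrative decision table (not part of the module) for SyncJob._post_filter_fn, following the
# boolean expression above:
#
#   destination object missing (dest_obj.exists is False)               -> transfer
#   source modified after destination (src newer by last_modified)      -> transfer
#   sizes differ (src_obj.size != dest_obj.size)                        -> transfer
#   destination exists, is at least as new, and sizes match             -> skip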