Source code for skyplane.replicate.replicator_client

import json
import pickle
import time
import uuid
import math
from datetime import datetime
from functools import partial
from typing import Dict, List, Optional, Tuple, Iterable
import nacl.secret
import nacl.utils

import pandas as pd
from rich.progress import Progress, SpinnerColumn, TextColumn, TimeRemainingColumn, DownloadColumn, BarColumn, TransferSpeedColumn
import urllib3

from skyplane import GB, MB, gateway_docker_image, tmp_log_dir
from skyplane import exceptions
from skyplane.chunk import Chunk, ChunkRequest, ChunkState
from skyplane.compute.aws.aws_cloud_provider import AWSCloudProvider
from skyplane.compute.azure.azure_cloud_provider import AzureCloudProvider
from skyplane.compute.cloud_providers import CloudProvider
from skyplane.compute.gcp.gcp_cloud_provider import GCPCloudProvider
from skyplane.compute.server import Server, ServerState
from skyplane.obj_store.object_store_interface import ObjectStoreInterface
from skyplane.replicate.profiler import status_df_to_traceevent
from skyplane.replicate.replication_plan import ReplicationJob, ReplicationTopology, ReplicationTopologyGateway
from skyplane.utils import logger
from skyplane.utils.fn import PathLike, do_parallel
from skyplane.utils.timer import Timer


def refresh_instance_list(
    provider: CloudProvider, region_list: Iterable[str] = (), instance_filter=None, n=-1
) -> Dict[str, List[Server]]:
    if instance_filter is None:
        instance_filter = {"tags": {"skyplane": "true"}}
    results = do_parallel(
        lambda region: provider.get_matching_instances(region=region, **instance_filter),
        region_list,
        spinner=True,
        n=n,
        desc="Querying clouds for active instances",
    )
    return {r: ilist for r, ilist in results if ilist}

class ReplicatorClient:
    def __init__(
        self,
        topology: ReplicationTopology,
        gateway_docker_image: str = gateway_docker_image(),
        aws_instance_class: Optional[str] = "m5.4xlarge",  # set to None to disable AWS
        azure_instance_class: Optional[str] = "Standard_D2_v5",  # set to None to disable Azure
        gcp_instance_class: Optional[str] = "n2-standard-16",  # set to None to disable GCP
        gcp_use_premium_network: bool = True,
    ):
        self.http_pool = urllib3.PoolManager(retries=urllib3.Retry(total=3))
        self.topology = topology
        self.gateway_docker_image = gateway_docker_image
        self.aws_instance_class = aws_instance_class
        self.azure_instance_class = azure_instance_class
        self.gcp_instance_class = gcp_instance_class
        self.gcp_use_premium_network = gcp_use_premium_network

        # provisioning
        self.aws = AWSCloudProvider()
        self.azure = AzureCloudProvider()
        self.gcp = GCPCloudProvider()
        self.bound_nodes: Dict[ReplicationTopologyGateway, Server] = {}
        self.temp_nodes: List[Server] = []  # saving nodes that are not yet bound so they can be deprovisioned later

        # logging
        self.transfer_dir = tmp_log_dir / "transfer_logs" / datetime.now().strftime("%Y%m%d_%H%M%S")
        self.transfer_dir.mkdir(exist_ok=True, parents=True)
        logger.open_log_file(self.transfer_dir / "client.log")

        # upload requests
        self.multipart_upload_requests = []
    def provision_gateways(
        self,
        reuse_instances=False,
        log_dir: Optional[PathLike] = None,
        authorize_ssh_pub_key: Optional[PathLike] = None,
        use_bbr=True,
        use_compression=True,
        use_e2ee=True,
        use_socket_tls=False,
    ):
        regions_to_provision = [node.region for node in self.topology.gateway_nodes]
        aws_regions_to_provision = [r for r in regions_to_provision if r.startswith("aws:")]
        azure_regions_to_provision = [r for r in regions_to_provision if r.startswith("azure:")]
        gcp_regions_to_provision = [r for r in regions_to_provision if r.startswith("gcp:")]

        assert (
            len(aws_regions_to_provision) == 0 or self.aws.auth.enabled()
        ), "AWS credentials not configured but job provisions AWS gateways"
        assert (
            len(azure_regions_to_provision) == 0 or self.azure.auth.enabled()
        ), "Azure credentials not configured but job provisions Azure gateways"
        assert (
            len(gcp_regions_to_provision) == 0 or self.gcp.auth.enabled()
        ), "GCP credentials not configured but job provisions GCP gateways"

        # reuse existing instances
        if reuse_instances:
            if self.aws.auth.enabled():
                aws_instance_filter = {
                    "tags": {"skyplane": "true"},
                    "instance_type": self.aws_instance_class,
                    "state": [ServerState.PENDING, ServerState.RUNNING],
                }
                current_aws_instances = refresh_instance_list(
                    self.aws, set([r.split(":")[1] for r in aws_regions_to_provision]), aws_instance_filter
                )
                for r, ilist in current_aws_instances.items():
                    for i in ilist:
                        if f"aws:{r}" in aws_regions_to_provision:
                            aws_regions_to_provision.remove(f"aws:{r}")
            else:
                current_aws_instances = {}

            if self.azure.auth.enabled():
                azure_instance_filter = {
                    "tags": {"skyplane": "true"},
                    "instance_type": self.azure_instance_class,
                    "state": [ServerState.PENDING, ServerState.RUNNING],
                }
                current_azure_instances = refresh_instance_list(
                    self.azure, set([r.split(":")[1] for r in azure_regions_to_provision]), azure_instance_filter
                )
                for r, ilist in current_azure_instances.items():
                    for i in ilist:
                        if f"azure:{r}" in azure_regions_to_provision:
                            azure_regions_to_provision.remove(f"azure:{r}")
            else:
                current_azure_instances = {}

            if self.gcp.auth.enabled():
                gcp_instance_filter = {
                    "tags": {"skyplane": "true"},
                    "instance_type": self.gcp_instance_class,
                    "state": [ServerState.PENDING, ServerState.RUNNING],
                }
                current_gcp_instances = refresh_instance_list(
                    self.gcp, set([r.split(":")[1] for r in gcp_regions_to_provision]), gcp_instance_filter
                )
                for r, ilist in current_gcp_instances.items():
                    for i in ilist:
                        if f"gcp:{r}" in gcp_regions_to_provision:
                            gcp_regions_to_provision.remove(f"gcp:{r}")
            else:
                current_gcp_instances = {}

        # init clouds
        jobs = []
        if aws_regions_to_provision:
            jobs.append(partial(self.aws.create_iam, attach_policy_arn="arn:aws:iam::aws:policy/AmazonS3FullAccess"))
            for r in set(aws_regions_to_provision):

                def init_aws_vpc(r):
                    self.aws.make_vpc(r)
                    self.aws.authorize_client(r, "0.0.0.0/0")

                jobs.append(partial(init_aws_vpc, r.split(":")[1]))
                jobs.append(partial(self.aws.ensure_keyfile_exists, r.split(":")[1]))
        if azure_regions_to_provision:
            jobs.append(self.azure.create_ssh_key)
            jobs.append(self.azure.set_up_resource_group)
        if gcp_regions_to_provision:
            jobs.append(self.gcp.create_ssh_key)
            jobs.append(self.gcp.configure_skyplane_network)
            jobs.append(self.gcp.configure_skyplane_firewall)
        do_parallel(lambda fn: fn(), jobs, spinner=True, spinner_persist=True, desc="Initializing cloud keys")

        # provision instances
        def provision_gateway_instance(region: str) -> Server:
            provider, subregion = region.split(":")
            if provider == "aws":
                assert self.aws.auth.enabled()
                server = self.aws.provision_instance(subregion, self.aws_instance_class)
            elif provider == "azure":
                assert self.azure.auth.enabled()
                server = self.azure.provision_instance(subregion, self.azure_instance_class)
            elif provider == "gcp":
                assert self.gcp.auth.enabled()
                # todo specify network tier in ReplicationTopology
                server = self.gcp.provision_instance(subregion, self.gcp_instance_class, premium_network=self.gcp_use_premium_network)
            else:
                raise NotImplementedError(f"Unknown provider {provider}")
            server.enable_auto_shutdown()
            self.temp_nodes.append(server)
            return server

        results = do_parallel(
            provision_gateway_instance,
            list(aws_regions_to_provision + azure_regions_to_provision + gcp_regions_to_provision),
            spinner=True,
            spinner_persist=True,
            desc="Provisioning gateway instances",
        )
        instances_by_region = {
            r: [instance for instance_region, instance in results if instance_region == r] for r in set(regions_to_provision)
        }

        # add existing instances
        if reuse_instances:
            for r, ilist in current_aws_instances.items():
                if f"aws:{r}" not in instances_by_region:
                    instances_by_region[f"aws:{r}"] = []
                instances_by_region[f"aws:{r}"].extend(ilist)
                self.temp_nodes.extend(ilist)
            for r, ilist in current_azure_instances.items():
                if f"azure:{r}" not in instances_by_region:
                    instances_by_region[f"azure:{r}"] = []
                instances_by_region[f"azure:{r}"].extend(ilist)
                self.temp_nodes.extend(ilist)
            for r, ilist in current_gcp_instances.items():
                if f"gcp:{r}" not in instances_by_region:
                    instances_by_region[f"gcp:{r}"] = []
                instances_by_region[f"gcp:{r}"].extend(ilist)
                self.temp_nodes.extend(ilist)

        # bind instances to nodes
        for node in self.topology.gateway_nodes:
            instance = instances_by_region[node.region].pop()
            self.bound_nodes[node] = instance
            self.temp_nodes.remove(instance)

        # Firewall rules
        # todo add firewall rules for Azure
        public_ips = [self.bound_nodes[n].public_ip() for n in self.topology.gateway_nodes]
        authorize_ip_jobs = []
        authorize_ip_jobs.extend(
            [partial(self.aws.add_ips_to_security_group, r.split(":")[1], public_ips) for r in set(aws_regions_to_provision)]
        )
        if gcp_regions_to_provision:
            authorize_ip_jobs.append(partial(self.gcp.add_ips_to_firewall, public_ips))
        do_parallel(lambda fn: fn(), authorize_ip_jobs, spinner=True, desc="Applying firewall rules")

        # generate E2EE key
        if use_e2ee:
            e2ee_key_bytes = nacl.utils.random(nacl.secret.SecretBox.KEY_SIZE)
        else:
            e2ee_key_bytes = None

        # setup instances
        def setup(args: Tuple[Server, Dict[str, int], bool, bool]):
            server, outgoing_ports, am_source, am_sink = args
            if log_dir:
                server.init_log_files(log_dir)
            if authorize_ssh_pub_key:
                server.copy_public_key(authorize_ssh_pub_key)
            server.start_gateway(
                outgoing_ports,
                gateway_docker_image=self.gateway_docker_image,
                use_bbr=use_bbr,
                use_compression=use_compression,
                e2ee_key_bytes=e2ee_key_bytes if (am_source or am_sink) else None,
                use_socket_tls=use_socket_tls,
            )

        args = []
        sources = self.topology.source_instances()
        sinks = self.topology.sink_instances()
        for node, server in self.bound_nodes.items():
            setup_args = {
                self.bound_nodes[n].public_ip(): v
                for n, v in self.topology.get_outgoing_paths(node).items()
                if isinstance(n, ReplicationTopologyGateway)
            }
            args.append((server, setup_args, node in sources, node in sinks))
        do_parallel(setup, args, n=-1, spinner=True, spinner_persist=True, desc="Installing gateway package")
    def deprovision_gateways(self):
        # This is a good place to tear down Security Groups and the instance since this is invoked by CLI too.
        def deprovision_gateway_instance(server: Server):
            if server.instance_state() == ServerState.RUNNING:
                server.terminate_instance()
                logger.fs.warning(f"Deprovisioned {server.uuid()}")

        # Clear IPs from security groups
        # todo remove firewall rules for Azure
        public_ips = [i.public_ip() for i in self.bound_nodes.values()] + [i.public_ip() for i in self.temp_nodes]
        aws_regions = [node.region for node in self.topology.gateway_nodes if node.region.startswith("aws:")]
        aws_jobs = [partial(self.aws.remove_ips_from_security_group, r.split(":")[1], public_ips) for r in set(aws_regions)]
        gcp_regions = [node.region for node in self.topology.gateway_nodes if node.region.startswith("gcp:")]
        gcp_jobs = [partial(self.gcp.remove_ips_from_firewall, public_ips)] if gcp_regions else []
        do_parallel(lambda fn: fn(), aws_jobs + gcp_jobs, desc="Removing firewall rules")

        # Terminate instances
        instances = list(self.bound_nodes.values()) + self.temp_nodes
        logger.fs.warning(f"Deprovisioning {len(instances)} instances")
        if any(i.provider == "azure" for i in instances):
            logger.warning(
                f"NOTE: Azure is very slow to terminate instances. Consider using --reuse-instances and then deprovisioning the instances manually with `skyplane deprovision`."
            )
        do_parallel(deprovision_gateway_instance, instances, n=-1, spinner=True, spinner_persist=True, desc="Deprovisioning instances")
        self.temp_nodes = []
        logger.fs.info("Deprovisioned instances")
    def run_replication_plan(
        self,
        job: ReplicationJob,
        multipart_enabled: bool,
        multipart_min_threshold_mb: int,
        multipart_min_size_mb: int,
        multipart_max_chunks: int,
    ) -> ReplicationJob:
        assert job.source_region.split(":")[0] in [
            "aws",
            "azure",
            "gcp",
        ], f"Only AWS, Azure, and GCP are supported, but got {job.source_region}"
        assert job.dest_region.split(":")[0] in [
            "aws",
            "azure",
            "gcp",
        ], f"Only AWS, Azure, and GCP are supported, but got {job.dest_region}"

        with Progress(
            SpinnerColumn(),
            TextColumn("Preparing replication plan{task.description}"),
            transient=True,
        ) as progress:
            prepare_task = progress.add_task("", total=None)

            # pre-fetch instance IPs for all gateways
            progress.update(prepare_task, description=": Fetching instance IPs")
            gateway_ips: Dict[Server, str] = {s: s.public_ip() for s in self.bound_nodes.values()}

            # make list of chunks
            progress.update(prepare_task, description=": Creating list of chunks for transfer")
            chunks = []
            idx = 0
            for (src_object, dest_object) in job.transfer_pairs:
                if job.random_chunk_size_mb:
                    chunks.append(
                        Chunk(
                            src_key=src_object.key,
                            dest_key=dest_object.key,
                            chunk_id=idx,
                            file_offset_bytes=0,
                            chunk_length_bytes=job.random_chunk_size_mb * MB,
                        )
                    )
                    idx += 1
                elif multipart_enabled and src_object.size > multipart_min_threshold_mb * MB:
                    # determine number of chunks via the following algorithm:
                    chunk_size_bytes = int(multipart_min_size_mb * MB)
                    num_chunks = math.ceil(src_object.size / chunk_size_bytes)
                    if num_chunks > multipart_max_chunks:
                        chunk_size_bytes = int(src_object.size / multipart_max_chunks)
                        chunk_size_bytes = math.ceil(chunk_size_bytes / MB) * MB  # round to next largest MB
                        num_chunks = math.ceil(src_object.size / chunk_size_bytes)

                    # TODO: potentially do this in a separate thread, and/or after chunks sent
                    obj_store_interface = ObjectStoreInterface.create(job.dest_region, job.dest_bucket)
                    logger.fs.info(f"Initiate multipart upload on {dest_object}")
                    upload_id = obj_store_interface.initiate_multipart_upload(dest_object.key)

                    offset = 0
                    part_num = 1
                    parts = []
                    for chunk in range(num_chunks):
                        # size is min(chunk_size, remaining data)
                        file_size_bytes = min(chunk_size_bytes, src_object.size - offset)
                        assert file_size_bytes > 0, f"File size <= 0 {file_size_bytes}"
                        chunks.append(
                            Chunk(
                                src_key=src_object.key,
                                dest_key=dest_object.key,
                                chunk_id=idx,
                                file_offset_bytes=offset,
                                chunk_length_bytes=file_size_bytes,
                                part_number=part_num,
                                upload_id=upload_id,
                            )
                        )
                        parts.append(part_num)
                        idx += 1
                        part_num += 1
                        offset += chunk_size_bytes

                    # add multipart upload request
                    self.multipart_upload_requests.append(
                        {
                            "region": job.dest_region,
                            "bucket": job.dest_bucket,
                            "upload_id": upload_id,
                            "key": dest_object.key,
                            "parts": parts,
                        }
                    )
                # transfer entire object
                else:
                    chunk = Chunk(
                        src_key=src_object.key,
                        dest_key=dest_object.key,
                        chunk_id=idx,
                        file_offset_bytes=0,
                        chunk_length_bytes=src_object.size,
                    )
                    chunks.append(chunk)
                    idx += 1

            # partition chunks into roughly equal-sized batches (by bytes)
            def partition(items: List[Chunk], n_batches: int) -> List[List[Chunk]]:
                batches = [[] for _ in range(n_batches)]
                items.sort(key=lambda c: c.chunk_length_bytes, reverse=True)
                batch_sizes = [0 for _ in range(n_batches)]
                for item in items:
                    min_batch = batch_sizes.index(min(batch_sizes))
                    batches[min_batch].append(item)
                    batch_sizes[min_batch] += item.chunk_length_bytes
                return batches

            progress.update(prepare_task, description=": Partitioning chunks into batches")
            src_instances = [self.bound_nodes[n] for n in self.topology.source_instances()]
            chunk_batches = partition(chunks, len(src_instances))
            assert (len(chunk_batches) == (len(src_instances) - 1)) or (
                len(chunk_batches) == len(src_instances)
            ), f"{len(chunk_batches)} batches, expected {len(src_instances)}"
            for batch_idx, batch in enumerate(chunk_batches):
                logger.fs.info(f"Batch {batch_idx} size: {sum(c.chunk_length_bytes for c in batch)} with {len(batch)} chunks")

            # make list of ChunkRequests
            with Timer("Building chunk requests"):
                progress.update(prepare_task, description=": Building list of chunk requests")
                chunk_requests_sharded: Dict[int, List[ChunkRequest]] = {}
                for batch_idx, batch in enumerate(chunk_batches):
                    chunk_requests_sharded[batch_idx] = []
                    for chunk in batch:
                        chunk_requests_sharded[batch_idx].append(
                            ChunkRequest(
                                chunk=chunk,
                                src_region=job.source_region,
                                dst_region=job.dest_region,
                                src_type="object_store" if job.dest_bucket else "random",
                                dst_type="object_store" if job.dest_bucket else "save_local",
                                src_random_size_mb=job.random_chunk_size_mb,
                                src_object_store_bucket=job.source_bucket,
                                dst_object_store_bucket=job.dest_bucket,
                            )
                        )
                    logger.fs.debug(f"Batch {batch_idx} size: {sum(c.chunk_length_bytes for c in batch)} with {len(batch)} chunks")

                # send chunk requests to start gateways in parallel
                progress.update(prepare_task, description=": Dispatching chunk requests to source gateways")

                def send_chunk_requests(args: Tuple[Server, List[ChunkRequest]]):
                    hop_instance, chunk_requests = args
                    while chunk_requests:
                        batch, chunk_requests = chunk_requests[: 1024 * 16], chunk_requests[1024 * 16 :]
                        reply = self.http_pool.request(
                            "POST",
                            f"{hop_instance.gateway_api_url}/api/v1/chunk_requests",
                            body=json.dumps([c.as_dict() for c in batch]).encode("utf-8"),
                            headers={"Content-Type": "application/json"},
                        )
                        if reply.status != 200:
                            raise Exception(
                                f"Failed to send chunk requests to gateway instance {hop_instance.instance_name()}: {reply.data.decode('utf-8')}"
                            )
                        logger.fs.debug(
                            f"Sent {len(batch)} chunk requests to {hop_instance.instance_name()}, {len(chunk_requests)} remaining"
                        )

                start_instances = list(zip(src_instances, chunk_requests_sharded.values()))
                do_parallel(send_chunk_requests, start_instances, n=-1)

            job.chunk_requests = [cr for crlist in chunk_requests_sharded.values() for cr in crlist]
            return job
    def get_chunk_status_log_df(self) -> pd.DataFrame:
        def get_chunk_status(args):
            node, instance = args
            reply = self.http_pool.request("GET", f"{instance.gateway_api_url}/api/v1/chunk_status_log")
            if reply.status != 200:
                raise Exception(
                    f"Failed to get chunk status from gateway instance {instance.instance_name()}: {reply.data.decode('utf-8')}"
                )
            logs = []
            for log_entry in json.loads(reply.data.decode("utf-8"))["chunk_status_log"]:
                log_entry["region"] = node.region
                log_entry["instance"] = node.instance
                log_entry["time"] = datetime.fromisoformat(log_entry["time"])
                log_entry["state"] = ChunkState.from_str(log_entry["state"])
                logs.append(log_entry)
            return logs

        rows = []
        for result in do_parallel(get_chunk_status, self.bound_nodes.items(), n=-1, return_args=False):
            rows.extend(result)
        return pd.DataFrame(rows)
    def check_error_logs(self) -> Dict[str, List[str]]:
        def get_error_logs(args):
            _, instance = args
            reply = self.http_pool.request("GET", f"{instance.gateway_api_url}/api/v1/errors")
            if reply.status != 200:
                raise Exception(f"Failed to get error logs from gateway instance {instance.instance_name()}: {reply.data.decode('utf-8')}")
            return json.loads(reply.data.decode("utf-8"))["errors"]

        errors: Dict[str, List[str]] = {}
        for (_, instance), result in do_parallel(get_error_logs, self.bound_nodes.items(), n=-1):
            errors[instance] = result
        return errors
    def monitor_transfer(
        self,
        job: ReplicationJob,
        show_spinner=False,
        log_interval_s: Optional[float] = None,
        time_limit_seconds: Optional[float] = None,
        cleanup_gateway: bool = True,
        save_log: bool = True,
        write_profile: bool = False,
        write_socket_profile: bool = False,  # slow but useful for debugging
        copy_gateway_logs: bool = False,
        multipart: bool = False,  # multipart object uploads/downloads
    ) -> Optional[Dict]:
        assert job.chunk_requests is not None
        total_bytes = sum([cr.chunk.chunk_length_bytes for cr in job.chunk_requests])
        last_log = None

        sources = self.topology.source_instances()
        source_regions = set(s.region for s in sources)
        sinks = self.topology.sink_instances()
        sink_regions = set(s.region for s in sinks)

        completed_chunk_ids = []

        if save_log:
            (self.transfer_dir / "job.pkl").write_bytes(pickle.dumps(job))

        try:
            with Progress(
                SpinnerColumn(),
                TextColumn("Transfer progress{task.description}"),
                BarColumn(),
                DownloadColumn(binary_units=True),
                TransferSpeedColumn(),
                TimeRemainingColumn(),
                disable=not show_spinner,
            ) as progress:
                copy_task = progress.add_task("", total=total_bytes)
                with Timer() as t:
                    while True:
                        # refresh shutdown status by running noop
                        do_parallel(lambda i: i.run_command("echo 1"), self.bound_nodes.values(), n=-1)

                        # check for errors and exit if there are any (while setting debug flags)
                        errors = self.check_error_logs()
                        if any(errors.values()):
                            copy_gateway_logs = True
                            write_profile = True
                            write_socket_profile = True
                            return {
                                "errors": errors,
                                "monitor_status": "error",
                            }

                        log_df = self.get_chunk_status_log_df()
                        if log_df.empty:
                            logger.warning("No chunk status log entries yet")
                            time.sleep(0.5)
                            continue

                        is_complete_rec = (
                            lambda row: row["state"] == ChunkState.upload_complete
                            and row["instance"] in [s.instance for s in sinks]
                            and row["region"] in [s.region for s in sinks]
                        )
                        sink_status_df = log_df[log_df.apply(is_complete_rec, axis=1)]
                        completed_status = sink_status_df.groupby("chunk_id").apply(
                            lambda x: set(x["region"].unique()) == set(sink_regions)
                        )
                        completed_chunk_ids = completed_status[completed_status].index
                        completed_bytes = sum(
                            [cr.chunk.chunk_length_bytes for cr in job.chunk_requests if cr.chunk.chunk_id in completed_chunk_ids]
                        )

                        # update progress bar
                        total_runtime_s = (log_df.time.max() - log_df.time.min()).total_seconds()
                        throughput_gbits = completed_bytes * 8 / GB / total_runtime_s if total_runtime_s > 0 else 0.0

                        # make log line
                        progress.update(
                            copy_task,
                            description=f" ({len(completed_chunk_ids)} of {len(job.chunk_requests)} chunks)",
                            completed=completed_bytes,
                        )

                        if len(completed_chunk_ids) == len(job.chunk_requests):
                            if multipart:
                                # Complete multi-part uploads
                                def complete_upload(req):
                                    obj_store_interface = ObjectStoreInterface.create(req["region"], req["bucket"])
                                    succ = obj_store_interface.complete_multipart_upload(req["key"], req["upload_id"])
                                    if not succ:
                                        raise ValueError(f"Failed to complete upload {req['upload_id']}")

                                do_parallel(
                                    complete_upload,
                                    self.multipart_upload_requests,
                                    n=-1,
                                    desc="Completing multipart uploads",
                                    spinner=False,
                                )
                            return dict(
                                completed_chunk_ids=completed_chunk_ids,
                                total_runtime_s=total_runtime_s,
                                throughput_gbits=throughput_gbits,
                                monitor_status="completed",
                            )
                        elif time_limit_seconds is not None and t.elapsed > time_limit_seconds or t.elapsed > 600 and completed_bytes == 0:
                            logger.error("Transfer timed out without progress, please check the debug log!")
                            logger.fs.error("Transfer timed out! Please retry.")
                            logger.error(f"Please share debug logs from: {self.transfer_dir}")
                            return dict(
                                completed_chunk_ids=completed_chunk_ids,
                                total_runtime_s=total_runtime_s,
                                throughput_gbits=throughput_gbits,
                                monitor_status="timed_out",
                            )
                        else:
                            current_time = datetime.now()
                            if log_interval_s and (not last_log or (current_time - last_log).seconds > float(log_interval_s)):
                                last_log = current_time
                            time.sleep(0.01 if show_spinner else 0.25)
        # always run cleanup, even if there's an exception
        finally:
            with Progress(
                SpinnerColumn(),
                TextColumn("Cleaning up after transfer{task.description}"),
                transient=True,
            ) as progress:
                cleanup_task = progress.add_task("", total=None)

                # get compression ratio information from destination gateways using "/api/v1/profile/compression"
                progress.update(cleanup_task, description=": Getting compression ratio information")
                total_sent_compressed, total_sent_uncompressed = 0, 0
                for gateway in {v for v in self.bound_nodes.values() if v.region_tag in source_regions}:
                    stats = self.http_pool.request("GET", f"{gateway.gateway_api_url}/api/v1/profile/compression")
                    if stats.status == 200:
                        stats = json.loads(stats.data.decode("utf-8"))
                        total_sent_compressed += stats.get("compressed_bytes_sent", 0)
                        total_sent_uncompressed += stats.get("uncompressed_bytes_sent", 0)
                compression_ratio = total_sent_compressed / total_sent_uncompressed if total_sent_uncompressed > 0 else 0
                if compression_ratio > 0:
                    logger.fs.info(f"Total compressed bytes sent: {total_sent_compressed / GB:.2f}GB")
                    logger.fs.info(f"Total uncompressed bytes sent: {total_sent_uncompressed / GB:.2f}GB")
                    logger.fs.info(f"Compression ratio: {compression_ratio}")
                    progress.console.print(f"[bold yellow]Compression saved {(1. - compression_ratio)*100.:.2f}% of egress fees")

                if copy_gateway_logs:

                    def copy_log(instance):
                        instance.run_command("sudo docker logs -t skyplane_gateway 2> /tmp/gateway.stderr > /tmp/gateway.stdout")
                        instance.download_file("/tmp/gateway.stdout", self.transfer_dir / f"gateway_{instance.uuid()}.stdout")
                        instance.download_file("/tmp/gateway.stderr", self.transfer_dir / f"gateway_{instance.uuid()}.stderr")

                    progress.update(cleanup_task, description=": Copying gateway logs")
                    do_parallel(copy_log, self.bound_nodes.values(), n=-1)

                if write_profile:
                    progress.update(cleanup_task, description=": Writing chunk profiles")
                    chunk_status_df = self.get_chunk_status_log_df()
                    (self.transfer_dir / "chunk_status_df.csv").write_text(chunk_status_df.to_csv(index=False))
                    traceevent = status_df_to_traceevent(chunk_status_df)
                    profile_out = self.transfer_dir / f"traceevent_{uuid.uuid4()}.json"
                    profile_out.parent.mkdir(parents=True, exist_ok=True)
                    profile_out.write_text(json.dumps(traceevent))

                if write_socket_profile:

                    def write_socket_profile(instance):
                        receiver_reply = self.http_pool.request("GET", f"{instance.gateway_api_url}/api/v1/profile/socket/receiver")
                        text = receiver_reply.data.decode("utf-8")
                        if receiver_reply.status != 200:
                            logger.fs.error(
                                f"Failed to get receiver socket profile from {instance.gateway_api_url}: {receiver_reply.status} {text}"
                            )
                        (self.transfer_dir / f"receiver_socket_profile_{instance.uuid()}.json").write_text(text)

                    progress.update(cleanup_task, description=": Writing socket profiles")
                    do_parallel(write_socket_profile, self.bound_nodes.values(), n=-1)

                if cleanup_gateway:

                    def fn(s: Server):
                        try:
                            self.http_pool.request("POST", f"{s.gateway_api_url}/api/v1/shutdown")
                        except:
                            return  # ignore connection errors since server may be shutting down

                    do_parallel(fn, self.bound_nodes.values(), n=-1)
                    progress.update(cleanup_task, description=": Shutting down gateways")
    @staticmethod
    def verify_transfer_prefix(job: ReplicationJob, dest_prefix: str):
        """Check that all objects to copy are present in the destination"""
        dst_interface = ObjectStoreInterface.create(job.dest_region, job.dest_bucket)

        # algorithm: check all expected keys are present in the destination
        # by iteratively removing found keys from list_objects from a
        # precomputed dictionary of keys to check.
        dst_keys = {dst_o.key: src_o for src_o, dst_o in job.transfer_pairs}
        for obj in dst_interface.list_objects(dest_prefix):
            # check metadata (src.size == dst.size) && (src.modified <= dst.modified)
            src_obj = dst_keys.get(obj.key)
            if src_obj and src_obj.size == obj.size and src_obj.last_modified <= obj.last_modified:
                del dst_keys[obj.key]
        if dst_keys:
            raise exceptions.TransferFailedException(
                f"{len(dst_keys)} objects failed verification",
                [obj.key for obj in dst_keys.values()],
            )
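
The module above exposes the full transfer lifecycle but ships no usage example. The sketch below is one plausible way to drive it end to end, assuming a ReplicationTopology and a ReplicationJob have already been constructed by the planner/CLI layers (which are not part of this module); the run_transfer helper and the multipart threshold values are hypothetical illustrations, not part of the Skyplane API.

# Hypothetical driver sketch (not part of this module): provision gateways,
# dispatch a replication job, monitor it to completion, then deprovision.
from skyplane.replicate.replicator_client import ReplicatorClient


def run_transfer(topology, job):
    rc = ReplicatorClient(topology)  # default instance classes from __init__ above
    try:
        # Start gateway VMs, apply firewall rules, and launch the gateway Docker image.
        rc.provision_gateways(reuse_instances=False, use_bbr=True, use_compression=True, use_e2ee=True)
        # Shard the job into chunks (multipart for large objects) and dispatch
        # chunk requests to the source gateways. Threshold values are illustrative.
        job = rc.run_replication_plan(
            job,
            multipart_enabled=True,
            multipart_min_threshold_mb=128,
            multipart_min_size_mb=64,
            multipart_max_chunks=9990,
        )
        # Poll the gateways until every chunk reaches upload_complete at the sinks.
        stats = rc.monitor_transfer(job, show_spinner=True, multipart=True)
        if stats and stats.get("monitor_status") == "completed":
            print(f"Transfer finished at {stats['throughput_gbits']:.2f} Gbit/s")
    finally:
        # Always release gateway VMs and firewall rules, even on failure.
        rc.deprovision_gateways()

Note on the multipart path in run_replication_plan: chunking starts from multipart_min_size_mb-sized parts and, if an object would need more than multipart_max_chunks parts, the chunk size is grown (rounded up to the next MB) so the part count stays within the cap.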