diff --git a/smdebug/core/access_layer/s3.py b/smdebug/core/access_layer/s3.py index 3e46fc407..4c45b6f8a 100644 --- a/smdebug/core/access_layer/s3.py +++ b/smdebug/core/access_layer/s3.py @@ -1,9 +1,12 @@ # Standard Library +import io import os import re +import tempfile # Third Party import boto3 +from boto3.s3.transfer import TransferConfig # First Party from smdebug.core.access_layer.base import TSAccessBase @@ -28,6 +31,10 @@ def __init__( self.s3 = boto3.resource("s3", region_name=get_region()) self.s3_client = boto3.client("s3", region_name=get_region()) + # Set the desired multipart threshold value (5GB) + MB = 1024 ** 2 + self.transfer_config = TransferConfig(multipart_threshold=5 * MB) + # check if the bucket exists buckets = [bucket["Name"] for bucket in self.s3_client.list_buckets()["Buckets"]] if self.bucket_name not in buckets: @@ -39,17 +46,12 @@ def _init_data(self): else: self.data = "" - def _init_data(self): - if self.binary: - self.data = bytearray() - else: - self.data = "" - def open(self, bucket_name, mode): raise NotImplementedError def write(self, _data): start = len(self.data) + self.data += _data length = len(_data) return [start, length] @@ -57,8 +59,25 @@ def write(self, _data): def close(self): if self.flushed: return - key = self.s3.Object(self.bucket_name, self.key_name) - key.put(Body=self.data) + if self.binary: + self.logger.debug( + f"Sagemaker-Debugger: Writing binary data to s3://{os.path.join(self.bucket_name, self.key_name)}" + ) + self.s3_client.upload_fileobj( + io.BytesIO(self.data), self.bucket_name, self.key_name, Config=self.transfer_config + ) + else: + f = tempfile.NamedTemporaryFile(mode="w+") + self.logger.debug( + f"Sagemaker-Debugger: Writing string data to s3://{os.path.join(self.bucket_name, self.key_name)}" + ) + + f.write(self.data) + f.flush() + self.s3_client.upload_file( + f.name, self.bucket_name, self.key_name, Config=self.transfer_config + ) + self.logger.debug( f"Sagemaker-Debugger: Wrote {len(self.data)} bytes to file " f"s3://{os.path.join(self.bucket_name, self.key_name)}"