Merged
7 changes: 6 additions & 1 deletion python/ray/data/dataset.py
@@ -3369,6 +3369,7 @@ def write_bigquery(
         self,
         project_id: str,
         dataset: str,
+        max_retry_cnt: int = 10,
         ray_remote_args: Dict[str, Any] = None,
     ) -> None:
         """Write the dataset to a BigQuery dataset table.
@@ -3397,6 +3398,10 @@ def write_bigquery(
             dataset: The name of the dataset in the format of ``dataset_id.table_id``.
                 The dataset is created if it doesn't already exist. The table_id is
                 overwritten if it exists.
+            max_retry_cnt: The maximum number of times an individual block write
+                is retried due to BigQuery rate limiting errors. This isn't
+                related to Ray fault tolerance retries. The default number of
+                retries is 10.
             ray_remote_args: Kwargs passed to ray.remote in the write tasks.
         """  # noqa: E501
         if ray_remote_args is None:
@@ -3412,7 +3417,7 @@
         else:
            ray_remote_args["max_retries"] = 0

-        datasink = _BigQueryDatasink(project_id, dataset)
+        datasink = _BigQueryDatasink(project_id, dataset, max_retry_cnt=max_retry_cnt)
         self.write_datasink(datasink, ray_remote_args=ray_remote_args)

     @Deprecated
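For reference, a minimal usage sketch of the new keyword argument follows. The GCP project and BigQuery dataset names are hypothetical placeholders, and the raised retry count is illustrative only:

import ray

ds = ray.data.range(1000)

# "my-project" and "my_dataset.my_table" are placeholders; substitute your own
# GCP project ID and a BigQuery dataset in dataset_id.table_id form.
ds.write_bigquery(
    project_id="my-project",
    dataset="my_dataset.my_table",
    max_retry_cnt=20,  # allow more rate-limit retries per block write than the default 10
)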
36 changes: 26 additions & 10 deletions python/ray/data/datasource/bigquery_datasink.py
@@ -15,18 +15,24 @@

 logger = logging.getLogger(__name__)

-MAX_RETRY_CNT = 10
+DEFAULT_MAX_RETRY_CNT = 10
 RATE_LIMIT_EXCEEDED_SLEEP_TIME = 11


 class _BigQueryDatasink(Datasink):
-    def __init__(self, project_id: str, dataset: str) -> None:
+    def __init__(
+        self,
+        project_id: str,
+        dataset: str,
+        max_retry_cnt: int = DEFAULT_MAX_RETRY_CNT,
+    ) -> None:
         _check_import(self, module="google.cloud", package="bigquery")
         _check_import(self, module="google.cloud", package="bigquery_storage")
         _check_import(self, module="google.api_core", package="exceptions")

         self.project_id = project_id
         self.dataset = dataset
+        self.max_retry_cnt = max_retry_cnt

     def on_write_start(self) -> None:
         from google.api_core import exceptions
@@ -71,25 +77,35 @@ def _write_single_block(block: Block, project_id: str, dataset: str) -> None:
                 pq.write_table(block, fp, compression="SNAPPY")

                 retry_cnt = 0
-                while retry_cnt < MAX_RETRY_CNT:
+                while retry_cnt <= self.max_retry_cnt:
                     with open(fp, "rb") as source_file:
                         job = client.load_table_from_file(
                             source_file, dataset, job_config=job_config
                         )
-                    retry_cnt += 1
                     try:
                         logger.info(job.result())
                         break
                     except exceptions.Forbidden as e:
-                        logger.info("Rate limit exceeded... Sleeping to try again")
-                        logger.debug(e)
+                        retry_cnt += 1
+                        if retry_cnt > self.max_retry_cnt:
+                            break
+                        logger.info(
+                            "A block write encountered a rate limit exceeded error"
+                            + f" {retry_cnt} time(s). Sleeping to try again."
+                        )
+                        logger.debug(e)
                         time.sleep(RATE_LIMIT_EXCEEDED_SLEEP_TIME)

-                # Raise exception if retry_cnt hits MAX_RETRY_CNT
-                if retry_cnt >= MAX_RETRY_CNT:
+                # Raise exception if retry_cnt exceeds max_retry_cnt
+                if retry_cnt > self.max_retry_cnt:
+                    logger.info(
+                        f"Maximum ({self.max_retry_cnt}) retry count exceeded. Ray"
+                        + " will attempt to retry the block write via fault tolerance."
+                    )
                     raise RuntimeError(
-                        f"Write failed due to {MAX_RETRY_CNT} repeated"
-                        + " API rate limit exceeded responses"
+                        f"Write failed due to {retry_cnt}"
+                        + " repeated API rate limit exceeded responses. Consider"
+                        + " specifying the max_retry_cnt kwarg with a higher value."
                     )

         _write_single_block = cached_remote_fn(_write_single_block)
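The bounded-retry behavior above is easier to follow in isolation. Below is a simplified, standalone sketch under the assumption of a generic do_upload callable, with a plain Exception standing in for google.api_core.exceptions.Forbidden; the real datasink additionally rebuilds the load job from the block's Parquet file on every attempt:

import logging
import time

logger = logging.getLogger(__name__)

RATE_LIMIT_EXCEEDED_SLEEP_TIME = 11  # seconds, mirroring the datasink constant


def upload_with_retries(do_upload, max_retry_cnt: int = 10) -> None:
    """Call `do_upload` until it succeeds or the rate-limit retries are exhausted."""
    retry_cnt = 0
    while retry_cnt <= max_retry_cnt:
        try:
            do_upload()
            return  # success; stop retrying
        except Exception as e:  # stand-in for exceptions.Forbidden (rate limiting)
            retry_cnt += 1
            if retry_cnt > max_retry_cnt:
                break
            logger.info(
                "Rate limit exceeded %d time(s). Sleeping to try again.", retry_cnt
            )
            logger.debug(e)
            time.sleep(RATE_LIMIT_EXCEEDED_SLEEP_TIME)
    raise RuntimeError(
        f"Upload failed due to {retry_cnt} repeated rate limit exceeded responses."
    )

As in the datasink, one initial attempt plus max_retry_cnt retries are made before the error is raised, at which point Ray's task-level fault tolerance (if enabled) can retry the whole block write.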