Commit 767af86

Address comments
1 parent 168a81e commit 767af86

2 files changed, +14 -13 lines changed

R/pkg/R/SQLContext.R

Lines changed: 12 additions & 11 deletions
@@ -147,7 +147,7 @@ getDefaultSqlSource <- function() {
   l[["spark.sql.sources.default"]]
 }
 
-writeToTempFileInArrow <- function(rdf, numPartitions) {
+writeToFileInArrow <- function(fileName, rdf, numPartitions) {
   requireNamespace1 <- requireNamespace
 
   # For some reasons, Arrow R API requires to load 'defer_parent' which is from 'withr' package.
@@ -186,7 +186,6 @@ writeToTempFileInArrow <- function(rdf, numPartitions) {
       list(rdf)
     }
 
-    fileName <- tempfile(pattern = "spark-arrow", fileext = ".tmp")
     stream_writer <- NULL
     tryCatch({
       for (rdf_slice in rdf_slices) {
@@ -209,7 +208,6 @@ writeToTempFileInArrow <- function(rdf, numPartitions) {
       }
     })
 
-    fileName
   } else {
     stop("'arrow' package should be installed.")
   }
@@ -258,8 +256,9 @@ createDataFrame <- function(data, schema = NULL, samplingRatio = 1.0,
                             numPartitions = NULL) {
   sparkSession <- getSparkSession()
   arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]] == "true"
-  shouldUseArrow <- FALSE
+  useArrow <- FALSE
   firstRow <- NULL
+
   if (is.data.frame(data)) {
       # get the names of columns, they will be put into RDD
       if (is.null(schema)) {
@@ -278,16 +277,18 @@ createDataFrame <- function(data, schema = NULL, samplingRatio = 1.0,
 
       args <- list(FUN = list, SIMPLIFY = FALSE, USE.NAMES = FALSE)
       if (arrowEnabled) {
-        shouldUseArrow <- tryCatch({
+        useArrow <- tryCatch({
           stopifnot(length(data) > 0)
           dataHead <- head(data, 1)
           checkTypeRequirementForArrow(data, schema)
-          fileName <- writeToTempFileInArrow(data, numPartitions)
-          tryCatch(
+          fileName <- tempfile(pattern = "spark-arrow", fileext = ".tmp")
+          tryCatch({
+            writeToFileInArrow(fileName, data, numPartitions)
             jrddInArrow <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
                                        "readArrowStreamFromFile",
                                        sparkSession,
-                                       fileName),
+                                       fileName)
+          },
           finally = {
             file.remove(fileName)
           })
@@ -304,7 +305,7 @@ createDataFrame <- function(data, schema = NULL, samplingRatio = 1.0,
         })
       }
 
-      if (!shouldUseArrow) {
+      if (!useArrow) {
         # Convert data into a list of rows. Each row is a list.
         # drop factors and wrap lists
         data <- setNames(as.list(data), NULL)
@@ -320,7 +321,7 @@ createDataFrame <- function(data, schema = NULL, samplingRatio = 1.0,
       }
   }
 
-  if (shouldUseArrow) {
+  if (useArrow) {
     rdd <- jrddInArrow
   } else if (is.list(data)) {
     sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
@@ -369,7 +370,7 @@ createDataFrame <- function(data, schema = NULL, samplingRatio = 1.0,
 
   stopifnot(class(schema) == "structType")
 
-  if (shouldUseArrow) {
+  if (useArrow) {
     sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
                        "toDataFrame", rdd, schema$jobj, sparkSession)
   } else {
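
For reference, a minimal sketch (not part of the diff) of the calling pattern this change settles on, assuming a live SparkR session and the internals shown above (writeToFileInArrow, callJStatic, sparkSession) plus an in-scope data.frame rdf and numPartitions: the caller now creates the temporary file, the renamed writeToFileInArrow only fills it, and cleanup stays in the caller's finally block.

  # Sketch only: caller-managed temp file, matching the renamed writeToFileInArrow().
  fileName <- tempfile(pattern = "spark-arrow", fileext = ".tmp")
  tryCatch({
    # Stream the R data.frame into 'fileName' in Arrow format; the helper no
    # longer creates or returns the file itself.
    writeToFileInArrow(fileName, rdf, numPartitions)
    # Hand the file to the JVM, which reads it back as an RDD of Arrow batches.
    jrddInArrow <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
                               "readArrowStreamFromFile",
                               sparkSession,
                               fileName)
  },
  finally = {
    # The caller, not the helper, removes the temp file.
    file.remove(fileName)
  })

This keeps the file's lifetime visible at the call site, so the finally clause can remove it even when the Arrow write or the JVM read fails.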

sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala

Lines changed: 2 additions & 2 deletions
@@ -250,8 +250,8 @@ private[sql] object SQLUtils extends Logging {
   }
 
   /**
-   * R callable function to read a file in Arrow stream format and create a `DataFrame`
-   * from an RDD.
+   * R callable function to create a `DataFrame` from a `JavaRDD` of serialized
+   * ArrowRecordBatches.
    */
   def toDataFrame(
       arrowBatchRDD: JavaRDD[Array[Byte]],
