Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion R/pkg/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ Depends:
R (>= 3.0),
methods,
Suggests:
testthat
testthat,
survival
Description: R frontend for Spark
License: Apache License (== 2.0)
Collate:
Expand Down
3 changes: 2 additions & 1 deletion R/pkg/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ exportMethods("glm",
"predict",
"summary",
"kmeans",
"fitted")
"fitted",
"survreg")

# Job group lifecycle management methods
export("setJobGroup",
Expand Down
4 changes: 4 additions & 0 deletions R/pkg/R/generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -1175,3 +1175,7 @@ setGeneric("kmeans")
#' @rdname fitted
#' @export
setGeneric("fitted")

#' @rdname survreg
#' @export
setGeneric("survreg", function(formula, data) { standardGeneric("survreg") })
33 changes: 33 additions & 0 deletions R/pkg/R/mllib.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,34 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFram
return(new("PipelineModel", model = model))
})

#' Fit an accelerated failure time (AFT) survival regression model.
#'
#' Fit an accelerated failure time (AFT) survival regression model, similarly to R's survreg().
#'
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', ':', '+', and '-'.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Document that . is not supported.

#' @param data DataFrame for training.
#' @return a fitted MLlib model
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

provide a @seealso link to \url{https://cran.r-project.org/web/packages/survival/}

#' @rdname survreg
#' @export
#' @examples
#'\dontrun{
#' sc <- sparkR.init()
#' sqlContext <- sparkRSQL.init(sc)
#' library(survival)
#' data(ovarian)
#' df <- createDataFrame(sqlContext, ovarian)
#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, df)
#' summary(model)
#'}
setMethod("survreg", signature(formula = "formula", data = "DataFrame"),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We only support "weibull" distribution in AFTSurvivalRegression currently, so we don't need arguments dist like R's survreg until we supporting more distributions.

function(formula, data) {
formula <- paste(deparse(formula), collapse = "")
model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"fitAFTSurvivalRegression", formula, data@sdf)
return(new("PipelineModel", model = model))
})

#' Make predictions from a model
#'
#' Makes predictions from a model produced by glm(), similarly to R's predict().
Expand Down Expand Up @@ -135,6 +163,11 @@ setMethod("summary", signature(object = "PipelineModel"),
colnames(coefficients) <- unlist(features)
rownames(coefficients) <- 1:k
return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster)))
} else if (modelName == "AFTSurvivalRegressionModel") {
coefficients <- as.matrix(unlist(coefficients))
colnames(coefficients) <- c("Value")
rownames(coefficients) <- unlist(features)
return(list(coefficients = coefficients))
} else {
stop(paste("Unsupported model", modelName, sep = " "))
}
Expand Down
21 changes: 21 additions & 0 deletions R/pkg/inst/tests/testthat/test_mllib.R
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,24 @@ test_that("kmeans", {
cluster <- summary.model$cluster
expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1))
})

test_that("SparkR::survreg vs survival::survreg", {
library(survival)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q: Would the test fail if we don't have survival installed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, because we have added required library at Suggests: in DESCRIPTION.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually, suggests means it will not be installed automatically, which I think is right - I don't think we should require survival or e1071 to use SparkR
http://r-pkgs.had.co.nz/description.html

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, we should not require survival installed when using SparkR.

data(ovarian)
df <- suppressWarnings(createDataFrame(sqlContext, ovarian))

model <- SparkR::survreg(Surv(futime, fustat) ~ ecog_ps + rx, df)
stats <- summary(model)
coefs <- as.vector(stats$coefficients[, 1][1:3])
scale <- exp(stats$coefficients[, 1][4])

rModel <- survival::survreg(Surv(futime, fustat) ~ ecog.ps + rx, ovarian)
rCoefs <- as.vector(coef(rModel))
rScale <- rModel$scale

expect_true(all(abs(rCoefs - coefs) < 1e-4))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

expect_equal(coefs, rCoefs, tolerance = 1e-4), which generates better error messages

expect_true(abs(rScale - scale) < 1e-4)
expect_true(all(
rownames(stats$coefficients) ==
c("(Intercept)", "ecog_ps", "rx", "Log(scale)")))
})
56 changes: 55 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@

package org.apache.spark.ml.api.r

import org.apache.spark.SparkException
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
import org.apache.spark.ml.feature.{RFormula, VectorAssembler}
import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
import org.apache.spark.ml.regression._
import org.apache.spark.sql.DataFrame

private[r] object SparkRWrappers {
Expand Down Expand Up @@ -52,6 +53,43 @@ private[r] object SparkRWrappers {
pipeline.fit(df)
}

def fitAFTSurvivalRegression(
value: String,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change value to formula

df: DataFrame): PipelineModel = {

def formulaRewrite(value: String): (String, String) = {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rewrite Surv(futime, fustat) ~ ecog.ps + rx to a tuple of string (futime ~ ecog.ps + rx, fustat).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we move it out as a private method?

var rewrited: String = null
var censorCol: String = null

val regex = "^Surv\\s*\\(([^,]+),([^,]+)\\)\\s*\\~\\s*(.+)".r
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use triple quotes to avoid escaping \. This should work: """Surv\((\S+), (\S+)\) ~ (\S+)""".r. I think the formula passed in from R wouldn't have extra spaces.

try {
val regex(label, censor, features) = value
// TODO: Support dot operator.
if (features.contains(".")) {
throw new UnsupportedOperationException(
"Terms of survreg formula can not support dot operator.")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will not support dot operator currently, because dot will be expanded to all column names including the censorCol which is unexpected.
Although, if terms of survreg formula only include ., we can replace it by all feature column names at here. But if the formula contains operators like .:. or .:x, we can not handle it at here. So I choose to not support any kind of dot in term currently.

}
rewrited = label.trim + "~" + features
censorCol = censor.trim
} catch {
case e: MatchError =>
throw new SparkException(s"Could not parse formula: $value")
}

(rewrited, censorCol)
}

val (rewritedValue, censorCol) = formulaRewrite(value)

val formula = new RFormula().setFormula(rewritedValue)
val estimator = new AFTSurvivalRegression()
.setCensorCol(censorCol)
.setFitIntercept(formula.hasIntercept)

val pipeline = new Pipeline().setStages(Array(formula, estimator))
pipeline.fit(df)
}

def fitKMeans(
df: DataFrame,
initMode: String,
Expand Down Expand Up @@ -91,6 +129,12 @@ private[r] object SparkRWrappers {
}
case m: KMeansModel =>
m.clusterCenters.flatMap(_.toArray)
case m: AFTSurvivalRegressionModel =>
if (m.getFitIntercept) {
Array(m.intercept) ++ m.coefficients.toArray ++ Array(math.log(m.scale))
} else {
m.coefficients.toArray ++ Array(math.log(m.scale))
}
}
}

Expand Down Expand Up @@ -151,6 +195,14 @@ private[r] object SparkRWrappers {
val attrs = AttributeGroup.fromStructField(
m.summary.predictions.schema(m.summary.featuresCol))
attrs.attributes.get.map(_.name.get)
case m: AFTSurvivalRegressionModel =>
val attrs = AttributeGroup.fromStructField(
m.summary.predictions.schema(m.getFeaturesCol))
if (m.getFitIntercept) {
Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) ++ Array("Log(scale)")
} else {
attrs.attributes.get.map(_.name.get) ++ Array("Log(scale)")
}
}
}

Expand All @@ -162,6 +214,8 @@ private[r] object SparkRWrappers {
"LogisticRegressionModel"
case m: KMeansModel =>
"KMeansModel"
case m: AFTSurvivalRegressionModel =>
"AFTSurvivalRegressionModel"
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,14 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
val coefficients = Vectors.dense(parameters.slice(2, parameters.length))
val intercept = parameters(1)
val scale = math.exp(parameters(0))
val model = new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale)
copyValues(model.setParent(this))
val model = copyValues(
new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale)
.setParent(this))
// Handle possible missing or invalid prediction columns
val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol()
val summary = new AFTSurvivalRegressionSummary(
summaryModel.transform(dataset), predictionColName)
model.setSummary(summary)
}

@Since("1.6.0")
Expand Down Expand Up @@ -281,6 +287,39 @@ class AFTSurvivalRegressionModel private[ml] (
@Since("1.6.0")
def setQuantilesCol(value: String): this.type = set(quantilesCol, value)

private var trainingSummary: Option[AFTSurvivalRegressionSummary] = None

private[regression] def setSummary(summary: AFTSurvivalRegressionSummary): this.type = {
this.trainingSummary = Some(summary)
this
}

/**
* If the prediction column is set returns the current model and prediction column,
* otherwise generates a new column and sets it as the prediction column on a new copy
* of the current model.
*/
private[regression] def findSummaryModelAndPredictionCol()
: (AFTSurvivalRegressionModel, String) = {
$(predictionCol) match {
case "" =>
val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString()
(copy(ParamMap.empty).setPredictionCol(predictionColName), predictionColName)
case p => (this, p)
}
}

/**
* Gets summary of model on training set. An exception is
* thrown if `trainingSummary == None`.
*/
@Since("2.0.0")
def summary: AFTSurvivalRegressionSummary = trainingSummary.getOrElse {
throw new SparkException(
"No training summary available for this AFTSurvivalRegressionModel",
new RuntimeException())
}

@Since("1.6.0")
def predictQuantiles(features: Vector): Vector = {
// scale parameter for the Weibull distribution of lifetime
Expand Down Expand Up @@ -375,6 +414,19 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel]
}
}

/**
* :: Experimental ::
* AFT survival regression results evaluated on a dataset.
*
* @param predictions dataframe outputted by the model's `transform` method.
* @param predictionCol field in "predictions" which gives the prediction of each instance.
*/
@Experimental
@Since("2.0.0")
class AFTSurvivalRegressionSummary private[regression] (
@Since("2.0.0") @transient val predictions: DataFrame,
@Since("2.0.0") val predictionCol: String) extends Serializable

/**
* AFTAggregator computes the gradient and loss for a AFT loss function,
* as used in AFT survival regression for samples in sparse or dense vector in a online fashion.
Expand Down