diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION
index 0cd0d75df0f70..6d5b2b931cb08 100644
--- a/R/pkg/DESCRIPTION
+++ b/R/pkg/DESCRIPTION
@@ -11,7 +11,8 @@ Depends:
     R (>= 3.0),
     methods,
 Suggests:
-    testthat
+    testthat,
+    survival
 Description: R frontend for Spark
 License: Apache License (== 2.0)
 Collate:
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 636d39e1e9cae..dae7abafd757e 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -15,7 +15,8 @@ exportMethods("glm",
               "predict",
               "summary",
               "kmeans",
-              "fitted")
+              "fitted",
+              "survreg")
 
 # Job group lifecycle management methods
 export("setJobGroup",
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 6ad71fcb46712..25062b82705a0 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1175,3 +1175,7 @@ setGeneric("kmeans")
 #' @rdname fitted
 #' @export
 setGeneric("fitted")
+
+#' @rdname survreg
+#' @export
+setGeneric("survreg", function(formula, data) { standardGeneric("survreg") })
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 5c0d3dcf3af90..b889fd76ee6f2 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -61,6 +61,34 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFram
     return(new("PipelineModel", model = model))
   })
 
+#' Fit an accelerated failure time (AFT) survival regression model.
+#'
+#' Fit an accelerated failure time (AFT) survival regression model, similarly to R's survreg().
+#'
+#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
+#'                operators are supported, including '~', ':', '+', and '-'.
+#' @param data DataFrame for training.
+#' @return a fitted MLlib model
+#' @rdname survreg
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlContext <- sparkRSQL.init(sc)
+#' library(survival)
+#' data(ovarian)
+#' df <- createDataFrame(sqlContext, ovarian)
+#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, df)
+#' summary(model)
+#'}
+setMethod("survreg", signature(formula = "formula", data = "DataFrame"),
+          function(formula, data) {
+            formula <- paste(deparse(formula), collapse = "")
+            model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+                                 "fitAFTSurvivalRegression", formula, data@sdf)
+            return(new("PipelineModel", model = model))
+          })
+
 #' Make predictions from a model
 #'
 #' Makes predictions from a model produced by glm(), similarly to R's predict().
@@ -135,6 +163,11 @@ setMethod("summary", signature(object = "PipelineModel"),
       colnames(coefficients) <- unlist(features)
       rownames(coefficients) <- 1:k
       return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster)))
+    } else if (modelName == "AFTSurvivalRegressionModel") {
+      coefficients <- as.matrix(unlist(coefficients))
+      colnames(coefficients) <- c("Value")
+      rownames(coefficients) <- unlist(features)
+      return(list(coefficients = coefficients))
     } else {
       stop(paste("Unsupported model", modelName, sep = " "))
     }
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index e120462964d1e..b1c45b3e3ab6c 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -141,3 +141,24 @@ test_that("kmeans", {
   cluster <- summary.model$cluster
   expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1))
 })
+
+test_that("SparkR::survreg vs survival::survreg", {
+  library(survival)
+  data(ovarian)
+  df <- suppressWarnings(createDataFrame(sqlContext, ovarian))
+
+  model <- SparkR::survreg(Surv(futime, fustat) ~ ecog_ps + rx, df)
+  stats <- summary(model)
+  coefs <- as.vector(stats$coefficients[, 1][1:3])
+  scale <- exp(stats$coefficients[, 1][4])
+
+  rModel <- survival::survreg(Surv(futime, fustat) ~ ecog.ps + rx, ovarian)
+  rCoefs <- as.vector(coef(rModel))
+  rScale <- rModel$scale
+
+  expect_true(all(abs(rCoefs - coefs) < 1e-4))
+  expect_true(abs(rScale - scale) < 1e-4)
+  expect_true(all(
+    rownames(stats$coefficients) ==
+      c("(Intercept)", "ecog_ps", "rx", "Log(scale)")))
+})
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
index d23e4fc9d1f57..ec48ed96bcb32 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
@@ -17,12 +17,13 @@
 
 package org.apache.spark.ml.api.r
 
+import org.apache.spark.SparkException
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
 import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
 import org.apache.spark.ml.feature.{RFormula, VectorAssembler}
-import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
+import org.apache.spark.ml.regression._
 import org.apache.spark.sql.DataFrame
 
 private[r] object SparkRWrappers {
@@ -52,6 +53,43 @@ private[r] object SparkRWrappers {
     pipeline.fit(df)
   }
 
+  def fitAFTSurvivalRegression(
+      value: String,
+      df: DataFrame): PipelineModel = {
+
+    def formulaRewrite(value: String): (String, String) = {
+      var rewritten: String = null
+      var censorCol: String = null
+
+      val regex = "^Surv\\s*\\(([^,]+),([^,]+)\\)\\s*\\~\\s*(.+)".r
+      try {
+        val regex(label, censor, features) = value
+        // TODO: Support the dot operator.
+        if (features.contains(".")) {
+          throw new UnsupportedOperationException(
+            "The dot operator is not supported in survreg formulas.")
+        }
+        rewritten = label.trim + "~" + features
+        censorCol = censor.trim
+      } catch {
+        case e: MatchError =>
+          throw new SparkException(s"Could not parse formula: $value")
+      }
+
+      (rewritten, censorCol)
+    }
+
+    val (rewrittenValue, censorCol) = formulaRewrite(value)
+
+    val formula = new RFormula().setFormula(rewrittenValue)
+    val estimator = new AFTSurvivalRegression()
+      .setCensorCol(censorCol)
+      .setFitIntercept(formula.hasIntercept)
+
+    val pipeline = new Pipeline().setStages(Array(formula, estimator))
+    pipeline.fit(df)
+  }
+
   def fitKMeans(
       df: DataFrame,
       initMode: String,
@@ -91,6 +129,12 @@ private[r] object SparkRWrappers {
         }
       case m: KMeansModel =>
         m.clusterCenters.flatMap(_.toArray)
+      case m: AFTSurvivalRegressionModel =>
+        if (m.getFitIntercept) {
+          Array(m.intercept) ++ m.coefficients.toArray ++ Array(math.log(m.scale))
+        } else {
+          m.coefficients.toArray ++ Array(math.log(m.scale))
+        }
     }
   }
 
@@ -151,6 +195,14 @@ private[r] object SparkRWrappers {
         val attrs = AttributeGroup.fromStructField(
           m.summary.predictions.schema(m.summary.featuresCol))
         attrs.attributes.get.map(_.name.get)
+      case m: AFTSurvivalRegressionModel =>
+        val attrs = AttributeGroup.fromStructField(
+          m.summary.predictions.schema(m.getFeaturesCol))
+        if (m.getFitIntercept) {
+          Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) ++ Array("Log(scale)")
+        } else {
+          attrs.attributes.get.map(_.name.get) ++ Array("Log(scale)")
+        }
     }
   }
 
@@ -162,6 +214,8 @@ private[r] object SparkRWrappers {
         "LogisticRegressionModel"
       case m: KMeansModel =>
         "KMeansModel"
+      case m: AFTSurvivalRegressionModel =>
+        "AFTSurvivalRegressionModel"
     }
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index ba5708ab8d9bb..fd4e1b692dddd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -232,8 +232,14 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
     val coefficients = Vectors.dense(parameters.slice(2, parameters.length))
     val intercept = parameters(1)
     val scale = math.exp(parameters(0))
-    val model = new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale)
-    copyValues(model.setParent(this))
+    val model = copyValues(
+      new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale)
+        .setParent(this))
+    // Handle possible missing or invalid prediction columns
+    val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol()
+    val summary = new AFTSurvivalRegressionSummary(
+      summaryModel.transform(dataset), predictionColName)
+    model.setSummary(summary)
   }
 
   @Since("1.6.0")
@@ -281,6 +287,39 @@ class AFTSurvivalRegressionModel private[ml] (
   @Since("1.6.0")
   def setQuantilesCol(value: String): this.type = set(quantilesCol, value)
 
+  private var trainingSummary: Option[AFTSurvivalRegressionSummary] = None
+
+  private[regression] def setSummary(summary: AFTSurvivalRegressionSummary): this.type = {
+    this.trainingSummary = Some(summary)
+    this
+  }
+
+  /**
+   * If the prediction column is set, returns the current model and prediction column;
+   * otherwise generates a new column and sets it as the prediction column on a new copy
+   * of the current model.
+   */
+  private[regression] def findSummaryModelAndPredictionCol()
+    : (AFTSurvivalRegressionModel, String) = {
+    $(predictionCol) match {
+      case "" =>
+        val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString()
+        (copy(ParamMap.empty).setPredictionCol(predictionColName), predictionColName)
+      case p => (this, p)
+    }
+  }
+
+  /**
+   * Gets the summary of the model on the training set. An exception is
+   * thrown if `trainingSummary == None`.
+   */
+  @Since("2.0.0")
+  def summary: AFTSurvivalRegressionSummary = trainingSummary.getOrElse {
+    throw new SparkException(
+      "No training summary available for this AFTSurvivalRegressionModel",
+      new RuntimeException())
+  }
+
   @Since("1.6.0")
   def predictQuantiles(features: Vector): Vector = {
     // scale parameter for the Weibull distribution of lifetime
@@ -375,6 +414,19 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel]
   }
 }
 
+/**
+ * :: Experimental ::
+ * AFT survival regression results evaluated on a dataset.
+ *
+ * @param predictions DataFrame output by the model's `transform` method.
+ * @param predictionCol field in "predictions" which gives the prediction of each instance.
+ */
+@Experimental
+@Since("2.0.0")
+class AFTSurvivalRegressionSummary private[regression] (
+    @Since("2.0.0") @transient val predictions: DataFrame,
+    @Since("2.0.0") val predictionCol: String) extends Serializable
+
 /**
  * AFTAggregator computes the gradient and loss for a AFT loss function,
  * as used in AFT survival regression for samples in sparse or dense vector in a online fashion.
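
How the formula rewrite behaves: fitAFTSurvivalRegression peels the censor column out of R's Surv(time, status) ~ terms syntax and hands the remainder to RFormula. A minimal standalone sketch of what the regex yields for the formula used in the tests (plain Scala, no Spark required; the object name FormulaRewriteSketch is invented for illustration):

object FormulaRewriteSketch {
  def main(args: Array[String]): Unit = {
    // Same pattern as formulaRewrite in SparkRWrappers above.
    val regex = "^Surv\\s*\\(([^,]+),([^,]+)\\)\\s*\\~\\s*(.+)".r
    val value = "Surv(futime, fustat) ~ ecog_ps + rx"
    // Extractor pattern on a Regex; a non-matching string raises MatchError,
    // which the wrapper converts into a SparkException.
    val regex(label, censor, features) = value
    println(label.trim + "~" + features) // prints: futime~ecog_ps + rx
    println(censor.trim)                 // prints: fustat
  }
}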
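
Why the R test reads stats$coefficients[, 1][1:3] as the estimates and exp(...[4]) as the scale: getModelCoefficients and getModelFeatures must emit values and names in lockstep, [(Intercept), features..., Log(scale)] when the intercept is fitted. A sketch of that pairing (AftCoefficientLayoutSketch and all numbers are invented placeholders, not real model output):

object AftCoefficientLayoutSketch {
  def main(args: Array[String]): Unit = {
    val fitIntercept = true
    val intercept = 5.0                    // placeholder value
    val coefficients = Array(-0.06, -0.55) // placeholders for ecog_ps, rx
    val scale = 1.17                       // placeholder value
    // Mirrors getModelCoefficients for AFTSurvivalRegressionModel:
    val values =
      if (fitIntercept) Array(intercept) ++ coefficients ++ Array(math.log(scale))
      else coefficients ++ Array(math.log(scale))
    // Mirrors getModelFeatures for AFTSurvivalRegressionModel:
    val names =
      if (fitIntercept) Array("(Intercept)", "ecog_ps", "rx", "Log(scale)")
      else Array("ecog_ps", "rx", "Log(scale)")
    // The shared ordering is what lets summary() build the coefficient matrix.
    names.zip(values).foreach { case (n, v) => println(f"$n%-12s $v%10.4f") }
  }
}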
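
The summary plumbing added to AFTSurvivalRegressionModel follows the Option-backed training-summary pattern used elsewhere in spark.ml: train() attaches the summary, and the public accessor throws when none is present. A stripped-down sketch of the pattern with simplified types (SummarizedModel is invented; the real accessor throws SparkException):

// Sketch of the Option-backed training-summary pattern.
class SummarizedModel {
  private var trainingSummary: Option[String] = None

  // Called by the estimator immediately after fitting.
  def setSummary(s: String): this.type = {
    trainingSummary = Some(s)
    this
  }

  // Public accessor; throws if no summary was attached to this instance.
  def summary: String = trainingSummary.getOrElse {
    throw new IllegalStateException("No training summary available")
  }
}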