@@ -19,7 +19,7 @@ package org.apache.spark.deploy.rest
1919
2020import java .io .{DataOutputStream , File }
2121import java .net .InetSocketAddress
22- import javax .servlet .http .{HttpServlet , HttpServletResponse , HttpServletRequest }
22+ import javax .servlet .http .{HttpServlet , HttpServletRequest , HttpServletResponse }
2323
2424import scala .io .Source
2525
@@ -48,6 +48,15 @@ private[spark] class StandaloneRestServer(master: Master, host: String, requeste
4848 import StandaloneRestServer ._
4949
5050 private var _server : Option [Server ] = None
51+ private val basePrefix = s " / $PROTOCOL_VERSION/submissions "
52+
53+ // A mapping from servlets to the URL prefixes they are responsible for
54+ private val servletToPrefix = Map [StandaloneRestServlet , String ](
55+ new SubmitRequestServlet (master) -> s " $basePrefix/create/* " ,
56+ new KillRequestServlet (master) -> s " $basePrefix/kill/* " ,
57+ new StatusRequestServlet (master) -> s " $basePrefix/status/* " ,
58+ new ErrorServlet -> " /"
59+ )
5160
5261 /** Start the server and return the bound port. */
5362 def start (): Int = {
@@ -58,28 +67,19 @@ private[spark] class StandaloneRestServer(master: Master, host: String, requeste
5867 }
5968
6069 /**
61- * Set up the mapping from contexts to the appropriate servlets:
62- * (1) submit requests should be directed to /create
63- * (2) kill requests should be directed to /kill
64- * (3) status requests should be directed to /status
70+ * Map the servlets to their corresponding contexts and attach them to a server.
6571 * Return a 2-tuple of the started server and the bound port.
6672 */
6773 private def doStart (startPort : Int ): (Server , Int ) = {
6874 val server = new Server (new InetSocketAddress (host, requestedPort))
6975 val threadPool = new QueuedThreadPool
7076 threadPool.setDaemon(true )
7177 server.setThreadPool(threadPool)
72- val pathPrefix = s " / $PROTOCOL_VERSION/submissions "
7378 val mainHandler = new ServletContextHandler
7479 mainHandler.setContextPath(" /" )
75- mainHandler.addServlet(
76- new ServletHolder (new SubmitRequestServlet (master)), s " $pathPrefix/create " )
77- mainHandler.addServlet(
78- new ServletHolder (new KillRequestServlet (master)), s " $pathPrefix/kill/* " )
79- mainHandler.addServlet(
80- new ServletHolder (new StatusRequestServlet (master)), s " $pathPrefix/status/* " )
81- mainHandler.addServlet(
82- new ServletHolder (new ErrorServlet ), " /" )
80+ servletToPrefix.foreach { case (servlet, prefix) =>
81+ mainHandler.addServlet(new ServletHolder (servlet), prefix)
82+ }
8383 server.setHandler(mainHandler)
8484 server.start()
8585 val boundPort = server.getConnectors()(0 ).getLocalPort
@@ -93,6 +93,7 @@ private[spark] class StandaloneRestServer(master: Master, host: String, requeste
9393
9494private object StandaloneRestServer {
9595 val PROTOCOL_VERSION = StandaloneRestClient .PROTOCOL_VERSION
96+ val SC_UNKNOWN_PROTOCOL_VERSION = 468
9697}
9798
9899/**
@@ -257,7 +258,6 @@ private[spark] class SubmitRequestServlet(master: Master) extends StandaloneRest
257258 responseServlet : HttpServletResponse ): Unit = {
258259 val requestMessageJson = Source .fromInputStream(requestServlet.getInputStream).mkString
259260 val requestMessage = SubmitRestProtocolMessage .fromJson(requestMessageJson)
260- .asInstanceOf [SubmitRestProtocolRequest ]
261261 val responseMessage = handleSubmit(requestMessage, responseServlet)
262262 responseServlet.setContentType(" application/json" )
263263 responseServlet.setCharacterEncoding(" utf-8" )
@@ -268,8 +268,13 @@ private[spark] class SubmitRequestServlet(master: Master) extends StandaloneRest
268268 out.close()
269269 }
270270
271+ /**
272+ * Handle a submit request by first validating the request message, then submitting the
273+ * application using the parameters specified in the message. If the message is not of
274+ * the expected type, return error to the client.
275+ */
271276 private def handleSubmit (
272- requestMessage : SubmitRestProtocolRequest ,
277+ requestMessage : SubmitRestProtocolMessage ,
273278 responseServlet : HttpServletResponse ): SubmitRestProtocolResponse = {
274279 // The response should have already been validated on the client.
275280 // In case this is not true, validate it ourselves to avoid potential NPEs.
@@ -293,8 +298,7 @@ private[spark] class SubmitRequestServlet(master: Master) extends StandaloneRest
293298 submitResponse
294299 case unexpected =>
295300 responseServlet.setStatus(HttpServletResponse .SC_BAD_REQUEST )
296- handleError(
297- s " Received message of unexpected type ${Utils .getFormattedClassName(unexpected)}. " )
301+ handleError(s " Received message of unexpected type ${unexpected.messageType}. " )
298302 }
299303 }
300304
@@ -366,23 +370,36 @@ private[spark] class SubmitRequestServlet(master: Master) extends StandaloneRest
366370 */
367371private [spark] class ErrorServlet extends StandaloneRestServlet {
368372 private val expectedVersion = StandaloneRestServer .PROTOCOL_VERSION
373+
374+ /** Service a faulty request by returning an appropriate error message to the client. */
369375 protected override def service (
370376 request : HttpServletRequest ,
371377 response : HttpServletResponse ): Unit = {
378+ response.setStatus(HttpServletResponse .SC_BAD_REQUEST )
372379 val path = request.getPathInfo
373- val parts = path.stripPrefix(" /" ).split(" /" )
374- if (parts.nonEmpty) {
375- val version = parts.head
376- if (version != expectedVersion) {
377- response.setStatus(HttpServletResponse .SC_BAD_REQUEST )
378- val error = handleError(s " Incompatible protocol version $version" )
379- sendResponse(error, response)
380- return
380+ val parts = path.stripPrefix(" /" ).split(" /" ).toSeq
381+ var msg =
382+ parts match {
383+ case Nil =>
384+ // http://host:port/
385+ " Missing protocol version."
386+ case `expectedVersion` :: Nil =>
387+ // http://host:port/correct-version
388+ " Missing the /submissions prefix."
389+ case `expectedVersion` :: " submissions" :: Nil =>
390+ // http://host:port/correct-version/submissions
391+ " Missing an action: please specify one of /create, /kill, or /status."
392+ case unknownVersion :: _ =>
393+ // http://host:port/unknown-version/*
394+ // Use a special response code in case the client wants to retry with a different version
395+ response.setStatus(StandaloneRestServer .SC_UNKNOWN_PROTOCOL_VERSION )
396+ s " Unknown protocol version ' $unknownVersion'. "
397+ case _ =>
398+ // never reached
399+ s " Malformed path $path. "
381400 }
382- }
383- response.setStatus(HttpServletResponse .SC_BAD_REQUEST )
384- val error = handleError(
385- s " Unexpected path $path: Please submit requests through / $expectedVersion/submissions/ " )
401+ msg += s " Please submit requests through http://[host]:[port]/ $expectedVersion/submissions/... "
402+ val error = handleError(msg)
386403 sendResponse(error, response)
387404 }
388405}
0 commit comments