Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@ import org.apache.spark.SparkException
* The `rpcAddress` may be null, in which case the endpoint is registered via a client-only
* connection and can only be reached via the client that sent the endpoint reference.
*
* @param rpcAddress The socket address of the endpoint.
* @param rpcAddress The socket address of the endpoint. It's `null` when this address pointing to
* an endpoint in a client `NettyRpcEnv`.
* @param name Name of the endpoint.
*/
private[spark] case class RpcEndpointAddress(val rpcAddress: RpcAddress, val name: String) {
private[spark] case class RpcEndpointAddress(rpcAddress: RpcAddress, name: String) {

require(name != null, "RpcEndpoint name must be provided.")

Expand Down
119 changes: 96 additions & 23 deletions core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap
import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.server._
import org.apache.spark.rpc._
import org.apache.spark.serializer.{JavaSerializer, JavaSerializerInstance}
import org.apache.spark.util.{ThreadUtils, Utils}
import org.apache.spark.serializer.{JavaSerializer, JavaSerializerInstance, SerializationStream}
import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, ThreadUtils, Utils}

private[netty] class NettyRpcEnv(
val conf: SparkConf,
Expand Down Expand Up @@ -189,7 +189,7 @@ private[netty] class NettyRpcEnv(
}
} else {
// Message to a remote RPC endpoint.
postToOutbox(message.receiver, OneWayOutboxMessage(serialize(message)))
postToOutbox(message.receiver, OneWayOutboxMessage(message.serialize(this)))
}
}

Expand Down Expand Up @@ -224,7 +224,7 @@ private[netty] class NettyRpcEnv(
}(ThreadUtils.sameThread)
dispatcher.postLocalMessage(message, p)
} else {
val rpcMessage = RpcOutboxMessage(serialize(message),
val rpcMessage = RpcOutboxMessage(message.serialize(this),
onFailure,
(client, response) => onSuccess(deserialize[Any](client, response)))
postToOutbox(message.receiver, rpcMessage)
Expand Down Expand Up @@ -253,6 +253,13 @@ private[netty] class NettyRpcEnv(
javaSerializerInstance.serialize(content)
}

/**
* Returns [[SerializationStream]] that forwards the serialized bytes to `out`.
*/
private[netty] def serializeStream(out: OutputStream): SerializationStream = {
javaSerializerInstance.serializeStream(out)
}

private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: ByteBuffer): T = {
NettyRpcEnv.currentClient.withValue(client) {
deserialize { () =>
Expand Down Expand Up @@ -480,16 +487,13 @@ private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
*/
private[netty] class NettyRpcEndpointRef(
@transient private val conf: SparkConf,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed _address and _name to save some bytes.

endpointAddress: RpcEndpointAddress,
@transient @volatile private var nettyEnv: NettyRpcEnv)
extends RpcEndpointRef(conf) with Serializable with Logging {
private val endpointAddress: RpcEndpointAddress,
@transient @volatile private var nettyEnv: NettyRpcEnv) extends RpcEndpointRef(conf) {

@transient @volatile var client: TransportClient = _

private val _address = if (endpointAddress.rpcAddress != null) endpointAddress else null
private val _name = endpointAddress.name

override def address: RpcAddress = if (_address != null) _address.rpcAddress else null
override def address: RpcAddress =
if (endpointAddress.rpcAddress != null) endpointAddress.rpcAddress else null

private def readObject(in: ObjectInputStream): Unit = {
in.defaultReadObject()
Expand All @@ -501,34 +505,103 @@ private[netty] class NettyRpcEndpointRef(
out.defaultWriteObject()
}

override def name: String = _name
override def name: String = endpointAddress.name

override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
nettyEnv.ask(RequestMessage(nettyEnv.address, this, message), timeout)
nettyEnv.ask(new RequestMessage(nettyEnv.address, this, message), timeout)
}

override def send(message: Any): Unit = {
require(message != null, "Message is null")
nettyEnv.send(RequestMessage(nettyEnv.address, this, message))
nettyEnv.send(new RequestMessage(nettyEnv.address, this, message))
}

override def toString: String = s"NettyRpcEndpointRef(${_address})"

def toURI: URI = new URI(_address.toString)
override def toString: String = s"NettyRpcEndpointRef(${endpointAddress})"

final override def equals(that: Any): Boolean = that match {
case other: NettyRpcEndpointRef => _address == other._address
case other: NettyRpcEndpointRef => endpointAddress == other.endpointAddress
case _ => false
}

final override def hashCode(): Int = if (_address == null) 0 else _address.hashCode()
final override def hashCode(): Int =
if (endpointAddress == null) 0 else endpointAddress.hashCode()
}

/**
* The message that is sent from the sender to the receiver.
*
* @param senderAddress the sender address. It's `null` if this message is from a client
* `NettyRpcEnv`.
* @param receiver the receiver of this message.
* @param content the message content.
*/
private[netty] case class RequestMessage(
senderAddress: RpcAddress, receiver: NettyRpcEndpointRef, content: Any)
private[netty] class RequestMessage(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed case to make RequestMessage non-serializable to avoid using Java serialization occasionally.

val senderAddress: RpcAddress,
val receiver: NettyRpcEndpointRef,
val content: Any) {

/** Manually serialize [[RequestMessage]] to minimize the size. */
def serialize(nettyEnv: NettyRpcEnv): ByteBuffer = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't you just want to implement the standard Java serialization mechanism here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's different. If I just implement writeObject and call Java serialization APIs to write RequestMessage, it will write the full class name RequestMessage and a serialization id which are not needed by all RPC messages.

val bos = new ByteBufferOutputStream()
val out = new DataOutputStream(bos)
try {
writeRpcAddress(out, senderAddress)
writeRpcAddress(out, receiver.address)
out.writeUTF(receiver.name)
val s = nettyEnv.serializeStream(out)
try {
s.writeObject(content)
} finally {
s.close()
}
} finally {
out.close()
}
bos.toByteBuffer
}

private def writeRpcAddress(out: DataOutputStream, rpcAddress: RpcAddress): Unit = {
if (rpcAddress == null) {
out.writeBoolean(false)
} else {
out.writeBoolean(true)
out.writeUTF(rpcAddress.host)
out.writeInt(rpcAddress.port)
}
}

override def toString: String = s"RequestMessage($senderAddress, $receiver, $content)"
}

private[netty] object RequestMessage {

private def readRpcAddress(in: DataInputStream): RpcAddress = {
val hasRpcAddress = in.readBoolean()
if (hasRpcAddress) {
RpcAddress(in.readUTF(), in.readInt())
} else {
null
}
}

def apply(nettyEnv: NettyRpcEnv, client: TransportClient, bytes: ByteBuffer): RequestMessage = {
val bis = new ByteBufferInputStream(bytes)
val in = new DataInputStream(bis)
try {
val senderAddress = readRpcAddress(in)
val endpointAddress = RpcEndpointAddress(readRpcAddress(in), in.readUTF())
val ref = new NettyRpcEndpointRef(nettyEnv.conf, endpointAddress, nettyEnv)
ref.client = client
new RequestMessage(
senderAddress,
ref,
// The remaining bytes in `bytes` are the message content.
nettyEnv.deserialize(client, bytes))
} finally {
in.close()
}
}
}

/**
* A response that indicates some failure happens in the receiver side.
Expand Down Expand Up @@ -574,10 +647,10 @@ private[netty] class NettyRpcHandler(
val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
assert(addr != null)
val clientAddr = RpcAddress(addr.getHostString, addr.getPort)
val requestMessage = nettyEnv.deserialize[RequestMessage](client, message)
val requestMessage = RequestMessage(nettyEnv, client, message)
if (requestMessage.senderAddress == null) {
// Create a new message with the socket address of the client as the sender.
RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content)
new RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content)
} else {
// The remote RpcEnv listens to some port, we should also fire a RemoteProcessConnected for
// the listening address
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@

package org.apache.spark.rpc.netty

import org.scalatest.mock.MockitoSugar

import org.apache.spark._
import org.apache.spark.network.client.TransportClient
import org.apache.spark.rpc._

class NettyRpcEnvSuite extends RpcEnvSuite {
class NettyRpcEnvSuite extends RpcEnvSuite with MockitoSugar {

override def createRpcEnv(
conf: SparkConf,
Expand Down Expand Up @@ -53,4 +56,32 @@ class NettyRpcEnvSuite extends RpcEnvSuite {
}
}

test("RequestMessage serialization") {
def assertRequestMessageEquals(expected: RequestMessage, actual: RequestMessage): Unit = {
assert(expected.senderAddress === actual.senderAddress)
assert(expected.receiver === actual.receiver)
assert(expected.content === actual.content)
}

val nettyEnv = env.asInstanceOf[NettyRpcEnv]
val client = mock[TransportClient]
val senderAddress = RpcAddress("locahost", 12345)
val receiverAddress = RpcEndpointAddress("localhost", 54321, "test")
val receiver = new NettyRpcEndpointRef(nettyEnv.conf, receiverAddress, nettyEnv)

val msg = new RequestMessage(senderAddress, receiver, "foo")
assertRequestMessageEquals(
msg,
RequestMessage(nettyEnv, client, msg.serialize(nettyEnv)))

val msg2 = new RequestMessage(null, receiver, "foo")
assertRequestMessageEquals(
msg2,
RequestMessage(nettyEnv, client, msg2.serialize(nettyEnv)))

val msg3 = new RequestMessage(senderAddress, receiver, null)
assertRequestMessageEquals(
msg3,
RequestMessage(nettyEnv, client, msg3.serialize(nettyEnv)))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite {
val env = mock(classOf[NettyRpcEnv])
val sm = mock(classOf[StreamManager])
when(env.deserialize(any(classOf[TransportClient]), any(classOf[ByteBuffer]))(any()))
.thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null))
.thenReturn(new RequestMessage(RpcAddress("localhost", 12345), null, null))

test("receive") {
val dispatcher = mock(classOf[Dispatcher])
Expand Down