diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/ConnectProxyServerProcessor.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/ConnectProxyServerProcessor.java
index 2fc97f65aca4..04888f568be3 100644
--- a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/ConnectProxyServerProcessor.java
+++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/ConnectProxyServerProcessor.java
@@ -30,7 +30,6 @@ public class ConnectProxyServerProcessor implements ServerProcessor {
private final String host;
private final int port;
private final String key;
- private final SocketFileDescriptorGetter socketFileDescriptorGetter;
private volatile Socket currSocket = new Socket();
private Runnable callback;
@@ -40,14 +39,11 @@ public class ConnectProxyServerProcessor implements ServerProcessor {
* @param host Proxy server host.
* @param port Proxy server port.
* @param key Proxy server key.
- * @param sockFdGetter Method to get file descriptor from Java socket.
*/
- public ConnectProxyServerProcessor(String host, int port, String key,
- SocketFileDescriptorGetter sockFdGetter) {
+ public ConnectProxyServerProcessor(String host, int port, String key) {
this.host = host;
this.port = port;
this.key = "server:" + key;
- socketFileDescriptorGetter = sockFdGetter;
}
/**
@@ -70,8 +66,8 @@ public void setStartTimeCallback(Runnable callback) {
try {
SocketAddress address = new InetSocketAddress(host, port);
currSocket.connect(address, 6000);
- InputStream in = currSocket.getInputStream();
- OutputStream out = currSocket.getOutputStream();
+ final InputStream in = currSocket.getInputStream();
+ final OutputStream out = currSocket.getOutputStream();
out.write(Utils.toBytes(RPC.RPC_MAGIC));
out.write(Utils.toBytes(key.length()));
out.write(Utils.toBytes(key));
@@ -91,11 +87,10 @@ public void setStartTimeCallback(Runnable callback) {
if (callback != null) {
callback.run();
}
- final int sockFd = socketFileDescriptorGetter.get(currSocket);
- if (sockFd != -1) {
- new NativeServerLoop(sockFd).run();
- System.err.println("Finish serving " + address);
- }
+
+ SocketChannel sockChannel = new SocketChannel(currSocket);
+ new NativeServerLoop(sockChannel.getFsend(), sockChannel.getFrecv()).run();
+ System.err.println("Finish serving " + address);
} catch (Throwable e) {
e.printStackTrace();
throw new RuntimeException(e);
diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/ConnectTrackerServerProcessor.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/ConnectTrackerServerProcessor.java
index 47881eb350c3..c449bb18a565 100644
--- a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/ConnectTrackerServerProcessor.java
+++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/ConnectTrackerServerProcessor.java
@@ -37,7 +37,6 @@
*/
public class ConnectTrackerServerProcessor implements ServerProcessor {
private ServerSocket server;
- private final SocketFileDescriptorGetter socketFileDescriptorGetter;
private final String trackerHost;
private final int trackerPort;
// device key
@@ -62,10 +61,11 @@ public class ConnectTrackerServerProcessor implements ServerProcessor {
* @param trackerHost Tracker host.
* @param trackerPort Tracker port.
* @param key Device key.
- * @param sockFdGetter Method to get file descriptor from Java socket.
+ * @param watchdog watch for timeout, etc.
+ * @throws java.io.IOException when socket fails to open.
*/
public ConnectTrackerServerProcessor(String trackerHost, int trackerPort, String key,
- SocketFileDescriptorGetter sockFdGetter, RPCWatchdog watchdog) throws IOException {
+ RPCWatchdog watchdog) throws IOException {
while (true) {
try {
this.server = new ServerSocket(serverPort);
@@ -81,7 +81,6 @@ public ConnectTrackerServerProcessor(String trackerHost, int trackerPort, String
}
}
System.err.println("using port: " + serverPort);
- this.socketFileDescriptorGetter = sockFdGetter;
this.trackerHost = trackerHost;
this.trackerPort = trackerPort;
this.key = key;
@@ -163,11 +162,9 @@ public String getMatchKey() {
System.err.println("Connection from " + socket.getRemoteSocketAddress().toString());
// received timeout in seconds
watchdog.startTimeout(timeout * 1000);
- final int sockFd = socketFileDescriptorGetter.get(socket);
- if (sockFd != -1) {
- new NativeServerLoop(sockFd).run();
- System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString());
- }
+ SocketChannel sockChannel = new SocketChannel(socket);
+ new NativeServerLoop(sockChannel.getFsend(), sockChannel.getFrecv()).run();
+ System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString());
Utils.closeQuietly(socket);
} catch (ConnectException e) {
// if the tracker connection failed, wait a bit before retrying
diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/NativeServerLoop.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/NativeServerLoop.java
index 255dabb438d5..697ce45fa04f 100644
--- a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/NativeServerLoop.java
+++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/NativeServerLoop.java
@@ -28,14 +28,17 @@
* Call native ServerLoop on socket file descriptor.
*/
public class NativeServerLoop implements Runnable {
- private final int sockFd;
+ private final Function fsend;
+ private final Function frecv;
/**
* Constructor for NativeServerLoop.
- * @param nativeSockFd native socket file descriptor.
+ * @param fsend socket.send function.
+ * @param frecv socket.recv function.
*/
- public NativeServerLoop(final int nativeSockFd) {
- sockFd = nativeSockFd;
+ public NativeServerLoop(final Function fsend, final Function frecv) {
+ this.fsend = fsend;
+ this.frecv = frecv;
}
@Override public void run() {
@@ -43,7 +46,7 @@ public NativeServerLoop(final int nativeSockFd) {
try {
tempDir = serverEnv();
System.err.println("starting server loop...");
- RPC.getApi("_ServerLoop").pushArg(sockFd).invoke();
+ RPC.getApi("_ServerLoop").pushArg(fsend).pushArg(frecv).invoke();
System.err.println("done server loop...");
} catch (IOException e) {
e.printStackTrace();
diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java
index 8ebf188b0667..278ef9fe8eef 100644
--- a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java
+++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java
@@ -200,6 +200,7 @@ public void upload(byte[] data, String target) {
* Upload file to remote runtime temp folder.
* @param data The file in local to upload.
* @param target The path in remote.
+ * @throws java.io.IOException for network failure.
*/
public void upload(File data, String target) throws IOException {
byte[] blob = getBytesFromFile(data);
@@ -209,6 +210,7 @@ public void upload(File data, String target) throws IOException {
/**
* Upload file to remote runtime temp folder.
* @param data The file in local to upload.
+ * @throws java.io.IOException for network failure.
*/
public void upload(File data) throws IOException {
upload(data, data.getName());
diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Server.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Server.java
index c81faa0ca999..a9ea2d89a62c 100644
--- a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Server.java
+++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Server.java
@@ -17,31 +17,12 @@
package ml.dmlc.tvm.rpc;
-import sun.misc.SharedSecrets;
-
-import java.io.FileDescriptor;
-import java.io.FileInputStream;
import java.io.IOException;
-import java.io.InputStream;
-import java.net.Socket;
/**
* RPC Server.
*/
public class Server {
- private static SocketFileDescriptorGetter defaultSocketFdGetter
- = new SocketFileDescriptorGetter() {
- @Override public int get(Socket socket) {
- try {
- InputStream is = socket.getInputStream();
- FileDescriptor fd = ((FileInputStream) is).getFD();
- return SharedSecrets.getJavaIOFileDescriptorAccess().get(fd);
- } catch (IOException e) {
- e.printStackTrace();
- return -1;
- }
- }
- };
private final WorkerThread worker;
private static class WorkerThread extends Thread {
@@ -72,35 +53,10 @@ public void terminate() {
/**
* Start a standalone server.
* @param serverPort Port.
- * @param socketFdGetter Method to get system file descriptor of the server socket.
- * @throws IOException if failed to bind localhost:port.
- */
- public Server(int serverPort, SocketFileDescriptorGetter socketFdGetter) throws IOException {
- worker = new WorkerThread(new StandaloneServerProcessor(serverPort, socketFdGetter));
- }
-
- /**
- * Start a standalone server.
- * Use sun.misc.SharedSecrets.getJavaIOFileDescriptorAccess
- * to get file descriptor for the socket.
- * @param serverPort Port.
* @throws IOException if failed to bind localhost:port.
*/
public Server(int serverPort) throws IOException {
- this(serverPort, defaultSocketFdGetter);
- }
-
- /**
- * Start a server connected to proxy.
- * @param proxyHost The proxy server host.
- * @param proxyPort The proxy server port.
- * @param key The key to identify the server.
- * @param socketFdGetter Method to get system file descriptor of the server socket.
- */
- public Server(String proxyHost, int proxyPort, String key,
- SocketFileDescriptorGetter socketFdGetter) {
- worker = new WorkerThread(
- new ConnectProxyServerProcessor(proxyHost, proxyPort, key, socketFdGetter));
+ worker = new WorkerThread(new StandaloneServerProcessor(serverPort));
}
/**
@@ -112,7 +68,8 @@ public Server(String proxyHost, int proxyPort, String key,
* @param key The key to identify the server.
*/
public Server(String proxyHost, int proxyPort, String key) {
- this(proxyHost, proxyPort, key, defaultSocketFdGetter);
+ worker = new WorkerThread(
+ new ConnectProxyServerProcessor(proxyHost, proxyPort, key));
}
/**
diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/SocketChannel.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/SocketChannel.java
new file mode 100644
index 000000000000..e72581b2358f
--- /dev/null
+++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/SocketChannel.java
@@ -0,0 +1,49 @@
+package ml.dmlc.tvm.rpc;
+
+import ml.dmlc.tvm.Function;
+import ml.dmlc.tvm.TVMValue;
+import ml.dmlc.tvm.TVMValueBytes;
+
+import java.io.IOException;
+import java.net.Socket;
+
+public class SocketChannel {
+ private final Socket socket;
+
+ SocketChannel(Socket sock) {
+ socket = sock;
+ }
+
+ private Function fsend = Function.convertFunc(new Function.Callback() {
+ @Override public Object invoke(TVMValue... args) {
+ byte[] data = args[0].asBytes();
+ try {
+ socket.getOutputStream().write(data);
+ } catch (IOException e) {
+ e.printStackTrace();
+ return -1;
+ }
+ return data.length;
+ }
+ });
+
+ private Function frecv = Function.convertFunc(new Function.Callback() {
+ @Override public Object invoke(TVMValue... args) {
+ long size = args[0].asLong();
+ try {
+ return new TVMValueBytes(Utils.recvAll(socket.getInputStream(), (int) size));
+ } catch (IOException e) {
+ e.printStackTrace();
+ return -1;
+ }
+ }
+ });
+
+ public Function getFsend() {
+ return fsend;
+ }
+
+ public Function getFrecv() {
+ return frecv;
+ }
+}
diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/SocketFileDescriptorGetter.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/SocketFileDescriptorGetter.java
deleted file mode 100644
index 4c35f720009d..000000000000
--- a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/SocketFileDescriptorGetter.java
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package ml.dmlc.tvm.rpc;
-
-import java.net.Socket;
-
-/**
- * Interface for defining different socket fd getter.
- */
-public interface SocketFileDescriptorGetter {
- /**
- * Get native socket file descriptor.
- * @param socket Java socket.
- * @return native socket fd.
- */
- public int get(Socket socket);
-}
diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/StandaloneServerProcessor.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/StandaloneServerProcessor.java
index 06e3303d1523..2d2303d3fe8a 100644
--- a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/StandaloneServerProcessor.java
+++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/StandaloneServerProcessor.java
@@ -28,12 +28,9 @@
*/
public class StandaloneServerProcessor implements ServerProcessor {
private final ServerSocket server;
- private final SocketFileDescriptorGetter socketFileDescriptorGetter;
- public StandaloneServerProcessor(int serverPort,
- SocketFileDescriptorGetter sockFdGetter) throws IOException {
+ public StandaloneServerProcessor(int serverPort) throws IOException {
this.server = new ServerSocket(serverPort);
- this.socketFileDescriptorGetter = sockFdGetter;
}
@Override public void terminate() {
@@ -46,9 +43,9 @@ public StandaloneServerProcessor(int serverPort,
@Override public void run() {
try {
- Socket socket = server.accept();
- InputStream in = socket.getInputStream();
- OutputStream out = socket.getOutputStream();
+ final Socket socket = server.accept();
+ final InputStream in = socket.getInputStream();
+ final OutputStream out = socket.getOutputStream();
int magic = Utils.wrapBytes(Utils.recvAll(in, 4)).getInt();
if (magic != RPC.RPC_MAGIC) {
Utils.closeQuietly(socket);
@@ -66,12 +63,10 @@ public StandaloneServerProcessor(int serverPort,
out.write(Utils.toBytes(serverKey));
}
+ SocketChannel sockChannel = new SocketChannel(socket);
System.err.println("Connection from " + socket.getRemoteSocketAddress().toString());
- final int sockFd = socketFileDescriptorGetter.get(socket);
- if (sockFd != -1) {
- new NativeServerLoop(sockFd).run();
- System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString());
- }
+ new NativeServerLoop(sockChannel.getFsend(), sockChannel.getFrecv()).run();
+ System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString());
Utils.closeQuietly(socket);
} catch (Throwable e) {
e.printStackTrace();
diff --git a/jvm/core/src/test/java/ml/dmlc/tvm/contrib/GraphRuntimeTest.java b/jvm/core/src/test/java/ml/dmlc/tvm/contrib/GraphRuntimeTest.java
index d719eb6f61e7..a29402867381 100644
--- a/jvm/core/src/test/java/ml/dmlc/tvm/contrib/GraphRuntimeTest.java
+++ b/jvm/core/src/test/java/ml/dmlc/tvm/contrib/GraphRuntimeTest.java
@@ -17,7 +17,10 @@
package ml.dmlc.tvm.contrib;
-import ml.dmlc.tvm.*;
+import ml.dmlc.tvm.Module;
+import ml.dmlc.tvm.NDArray;
+import ml.dmlc.tvm.TVMContext;
+import ml.dmlc.tvm.TestUtils;
import ml.dmlc.tvm.rpc.Client;
import ml.dmlc.tvm.rpc.RPCSession;
import ml.dmlc.tvm.rpc.Server;
diff --git a/jvm/pom.xml b/jvm/pom.xml
index 99cfe0d7b5ec..150c3a00a894 100644
--- a/jvm/pom.xml
+++ b/jvm/pom.xml
@@ -164,8 +164,8 @@
maven-compiler-plugin
3.3
- 1.6
- 1.6
+ 1.7
+ 1.7
UTF-8
diff --git a/python/tvm/rpc/tracker.py b/python/tvm/rpc/tracker.py
index f31625cd34ed..b9b29a7fe4a1 100644
--- a/python/tvm/rpc/tracker.py
+++ b/python/tvm/rpc/tracker.py
@@ -230,7 +230,7 @@ def call_handler(self, args):
port, matchkey = args[2]
self.pending_matchkeys.add(matchkey)
# got custom address (from rpc server)
- if args[3] is not None:
+ if len(args) >= 4 and args[3] is not None:
value = (self, args[3], port, matchkey)
else:
value = (self, self._addr[0], port, matchkey)
diff --git a/src/common/socket.h b/src/common/socket.h
index 2a2d9166a134..39bcff863c10 100644
--- a/src/common/socket.h
+++ b/src/common/socket.h
@@ -27,8 +27,10 @@
#define TVM_COMMON_SOCKET_H_
#if defined(_WIN32)
+#define NOMINMAX
#include
#include
+#undef NOMINMAX
using ssize_t = int;
#ifdef _MSC_VER
#pragma comment(lib, "Ws2_32.lib")
diff --git a/src/runtime/rpc/rpc_event_impl.cc b/src/runtime/rpc/rpc_event_impl.cc
index 7a142f3373db..3f4782693d8a 100644
--- a/src/runtime/rpc/rpc_event_impl.cc
+++ b/src/runtime/rpc/rpc_event_impl.cc
@@ -29,32 +29,14 @@
namespace tvm {
namespace runtime {
-class CallbackChannel final : public RPCChannel {
- public:
- explicit CallbackChannel(PackedFunc fsend)
- : fsend_(fsend) {}
-
- size_t Send(const void* data, size_t size) final {
- TVMByteArray bytes;
- bytes.data = static_cast(data);
- bytes.size = size;
- uint64_t ret = fsend_(bytes);
- return static_cast(ret);
- }
-
- size_t Recv(void* data, size_t size) final {
- LOG(FATAL) << "Do not allow explicit receive for";
- return 0;
- }
-
- private:
- PackedFunc fsend_;
-};
-
PackedFunc CreateEventDrivenServer(PackedFunc fsend,
std::string name,
std::string remote_key) {
- std::unique_ptr ch(new CallbackChannel(fsend));
+ static PackedFunc frecv([](TVMArgs args, TVMRetValue* rv) {
+ LOG(FATAL) << "Do not allow explicit receive";
+ return 0;
+ });
+ std::unique_ptr ch(new CallbackChannel(fsend, frecv));
std::shared_ptr sess =
RPCSession::Create(std::move(ch), name, remote_key);
return PackedFunc([sess](TVMArgs args, TVMRetValue* rv) {
diff --git a/src/runtime/rpc/rpc_session.cc b/src/runtime/rpc/rpc_session.cc
index f235ec8e8f0c..39db150bd3a0 100644
--- a/src/runtime/rpc/rpc_session.cc
+++ b/src/runtime/rpc/rpc_session.cc
@@ -36,6 +36,7 @@
#include
#include "rpc_session.h"
#include "../../common/ring_buffer.h"
+#include "../../common/socket.h"
namespace tvm {
namespace runtime {
@@ -1260,5 +1261,26 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf,
return PackedFunc(ftimer);
}
+size_t CallbackChannel::Send(const void* data, size_t size) {
+ TVMByteArray bytes;
+ bytes.data = static_cast(data);
+ bytes.size = size;
+ int64_t n = fsend_(bytes);
+ if (n == -1) {
+ common::Socket::Error("CallbackChannel::Send");
+ }
+ return static_cast(n);
+}
+
+size_t CallbackChannel::Recv(void* data, size_t size) {
+ TVMRetValue ret = frecv_(size);
+ if (ret.type_code() != kBytes) {
+ common::Socket::Error("CallbackChannel::Recv");
+ }
+ std::string* bytes = ret.ptr();
+ memcpy(static_cast(data), bytes->c_str(), bytes->length());
+ return bytes->length();
+}
+
} // namespace runtime
} // namespace tvm
diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h
index bc0bc8fe5918..d982f68bcb6e 100644
--- a/src/runtime/rpc/rpc_session.h
+++ b/src/runtime/rpc/rpc_session.h
@@ -87,7 +87,7 @@ class RPCChannel {
*/
virtual size_t Send(const void* data, size_t size) = 0;
/*!
-e * \brief Recv data from channel.
+ * \brief Recv data from channel.
*
* \param data The data pointer.
* \param size The size fo the data.
@@ -253,6 +253,37 @@ class RPCSession {
std::string remote_key_;
};
+/*!
+ * \brief RPC channel which callback
+ * frontend (Python/Java/etc.)'s send & recv function
+ */
+class CallbackChannel final : public RPCChannel {
+ public:
+ explicit CallbackChannel(PackedFunc fsend, PackedFunc frecv)
+ : fsend_(std::move(fsend)), frecv_(std::move(frecv)) {}
+
+ ~CallbackChannel() {}
+ /*!
+ * \brief Send data over to the channel.
+ * \param data The data pointer.
+ * \param size The size fo the data.
+ * \return The actual bytes sent.
+ */
+ size_t Send(const void* data, size_t size) final;
+ /*!
+ * \brief Recv data from channel.
+ *
+ * \param data The data pointer.
+ * \param size The size fo the data.
+ * \return The actual bytes received.
+ */
+ size_t Recv(void* data, size_t size) final;
+
+ private:
+ PackedFunc fsend_;
+ PackedFunc frecv_;
+};
+
/*!
* \brief Wrap a timer function to measure the time cost of a given packed function.
* \param f The function argument.
diff --git a/src/runtime/rpc/rpc_socket_impl.cc b/src/runtime/rpc/rpc_socket_impl.cc
index 16528bcc68a1..65d37531159f 100644
--- a/src/runtime/rpc/rpc_socket_impl.cc
+++ b/src/runtime/rpc/rpc_socket_impl.cc
@@ -36,7 +36,7 @@ class SockChannel final : public RPCChannel {
: sock_(sock) {}
~SockChannel() {
if (!sock_.BadSocket()) {
- sock_.Close();
+ sock_.Close();
}
}
size_t Send(const void* data, size_t size) final {
@@ -109,12 +109,25 @@ void RPCServerLoop(int sockfd) {
"SockServerLoop", "")->ServerLoop();
}
+void RPCServerLoop(PackedFunc fsend, PackedFunc frecv) {
+ RPCSession::Create(std::unique_ptr(
+ new CallbackChannel(fsend, frecv)),
+ "SockServerLoop", "")->ServerLoop();
+}
+
TVM_REGISTER_GLOBAL("rpc._Connect")
.set_body_typed(RPCClientConnect);
TVM_REGISTER_GLOBAL("rpc._ServerLoop")
.set_body([](TVMArgs args, TVMRetValue* rv) {
- RPCServerLoop(args[0]);
+ if (args.size() == 1) {
+ RPCServerLoop(args[0]);
+ } else {
+ CHECK_EQ(args.size(), 2);
+ RPCServerLoop(
+ args[0].operator tvm::runtime::PackedFunc(),
+ args[1].operator tvm::runtime::PackedFunc());
+ }
});
} // namespace runtime
} // namespace tvm