diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/ConnectProxyServerProcessor.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/ConnectProxyServerProcessor.java index 2fc97f65aca4..04888f568be3 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/ConnectProxyServerProcessor.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/ConnectProxyServerProcessor.java @@ -30,7 +30,6 @@ public class ConnectProxyServerProcessor implements ServerProcessor { private final String host; private final int port; private final String key; - private final SocketFileDescriptorGetter socketFileDescriptorGetter; private volatile Socket currSocket = new Socket(); private Runnable callback; @@ -40,14 +39,11 @@ public class ConnectProxyServerProcessor implements ServerProcessor { * @param host Proxy server host. * @param port Proxy server port. * @param key Proxy server key. - * @param sockFdGetter Method to get file descriptor from Java socket. */ - public ConnectProxyServerProcessor(String host, int port, String key, - SocketFileDescriptorGetter sockFdGetter) { + public ConnectProxyServerProcessor(String host, int port, String key) { this.host = host; this.port = port; this.key = "server:" + key; - socketFileDescriptorGetter = sockFdGetter; } /** @@ -70,8 +66,8 @@ public void setStartTimeCallback(Runnable callback) { try { SocketAddress address = new InetSocketAddress(host, port); currSocket.connect(address, 6000); - InputStream in = currSocket.getInputStream(); - OutputStream out = currSocket.getOutputStream(); + final InputStream in = currSocket.getInputStream(); + final OutputStream out = currSocket.getOutputStream(); out.write(Utils.toBytes(RPC.RPC_MAGIC)); out.write(Utils.toBytes(key.length())); out.write(Utils.toBytes(key)); @@ -91,11 +87,10 @@ public void setStartTimeCallback(Runnable callback) { if (callback != null) { callback.run(); } - final int sockFd = socketFileDescriptorGetter.get(currSocket); - if (sockFd != -1) { - new NativeServerLoop(sockFd).run(); - System.err.println("Finish serving " + address); - } + + SocketChannel sockChannel = new SocketChannel(currSocket); + new NativeServerLoop(sockChannel.getFsend(), sockChannel.getFrecv()).run(); + System.err.println("Finish serving " + address); } catch (Throwable e) { e.printStackTrace(); throw new RuntimeException(e); diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/ConnectTrackerServerProcessor.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/ConnectTrackerServerProcessor.java index 47881eb350c3..c449bb18a565 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/ConnectTrackerServerProcessor.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/ConnectTrackerServerProcessor.java @@ -37,7 +37,6 @@ */ public class ConnectTrackerServerProcessor implements ServerProcessor { private ServerSocket server; - private final SocketFileDescriptorGetter socketFileDescriptorGetter; private final String trackerHost; private final int trackerPort; // device key @@ -62,10 +61,11 @@ public class ConnectTrackerServerProcessor implements ServerProcessor { * @param trackerHost Tracker host. * @param trackerPort Tracker port. * @param key Device key. - * @param sockFdGetter Method to get file descriptor from Java socket. + * @param watchdog watch for timeout, etc. + * @throws java.io.IOException when socket fails to open. */ public ConnectTrackerServerProcessor(String trackerHost, int trackerPort, String key, - SocketFileDescriptorGetter sockFdGetter, RPCWatchdog watchdog) throws IOException { + RPCWatchdog watchdog) throws IOException { while (true) { try { this.server = new ServerSocket(serverPort); @@ -81,7 +81,6 @@ public ConnectTrackerServerProcessor(String trackerHost, int trackerPort, String } } System.err.println("using port: " + serverPort); - this.socketFileDescriptorGetter = sockFdGetter; this.trackerHost = trackerHost; this.trackerPort = trackerPort; this.key = key; @@ -163,11 +162,9 @@ public String getMatchKey() { System.err.println("Connection from " + socket.getRemoteSocketAddress().toString()); // received timeout in seconds watchdog.startTimeout(timeout * 1000); - final int sockFd = socketFileDescriptorGetter.get(socket); - if (sockFd != -1) { - new NativeServerLoop(sockFd).run(); - System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString()); - } + SocketChannel sockChannel = new SocketChannel(socket); + new NativeServerLoop(sockChannel.getFsend(), sockChannel.getFrecv()).run(); + System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString()); Utils.closeQuietly(socket); } catch (ConnectException e) { // if the tracker connection failed, wait a bit before retrying diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/NativeServerLoop.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/NativeServerLoop.java index 255dabb438d5..697ce45fa04f 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/NativeServerLoop.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/NativeServerLoop.java @@ -28,14 +28,17 @@ * Call native ServerLoop on socket file descriptor. */ public class NativeServerLoop implements Runnable { - private final int sockFd; + private final Function fsend; + private final Function frecv; /** * Constructor for NativeServerLoop. - * @param nativeSockFd native socket file descriptor. + * @param fsend socket.send function. + * @param frecv socket.recv function. */ - public NativeServerLoop(final int nativeSockFd) { - sockFd = nativeSockFd; + public NativeServerLoop(final Function fsend, final Function frecv) { + this.fsend = fsend; + this.frecv = frecv; } @Override public void run() { @@ -43,7 +46,7 @@ public NativeServerLoop(final int nativeSockFd) { try { tempDir = serverEnv(); System.err.println("starting server loop..."); - RPC.getApi("_ServerLoop").pushArg(sockFd).invoke(); + RPC.getApi("_ServerLoop").pushArg(fsend).pushArg(frecv).invoke(); System.err.println("done server loop..."); } catch (IOException e) { e.printStackTrace(); diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java index 8ebf188b0667..278ef9fe8eef 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java @@ -200,6 +200,7 @@ public void upload(byte[] data, String target) { * Upload file to remote runtime temp folder. * @param data The file in local to upload. * @param target The path in remote. + * @throws java.io.IOException for network failure. */ public void upload(File data, String target) throws IOException { byte[] blob = getBytesFromFile(data); @@ -209,6 +210,7 @@ public void upload(File data, String target) throws IOException { /** * Upload file to remote runtime temp folder. * @param data The file in local to upload. + * @throws java.io.IOException for network failure. */ public void upload(File data) throws IOException { upload(data, data.getName()); diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Server.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Server.java index c81faa0ca999..a9ea2d89a62c 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Server.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Server.java @@ -17,31 +17,12 @@ package ml.dmlc.tvm.rpc; -import sun.misc.SharedSecrets; - -import java.io.FileDescriptor; -import java.io.FileInputStream; import java.io.IOException; -import java.io.InputStream; -import java.net.Socket; /** * RPC Server. */ public class Server { - private static SocketFileDescriptorGetter defaultSocketFdGetter - = new SocketFileDescriptorGetter() { - @Override public int get(Socket socket) { - try { - InputStream is = socket.getInputStream(); - FileDescriptor fd = ((FileInputStream) is).getFD(); - return SharedSecrets.getJavaIOFileDescriptorAccess().get(fd); - } catch (IOException e) { - e.printStackTrace(); - return -1; - } - } - }; private final WorkerThread worker; private static class WorkerThread extends Thread { @@ -72,35 +53,10 @@ public void terminate() { /** * Start a standalone server. * @param serverPort Port. - * @param socketFdGetter Method to get system file descriptor of the server socket. - * @throws IOException if failed to bind localhost:port. - */ - public Server(int serverPort, SocketFileDescriptorGetter socketFdGetter) throws IOException { - worker = new WorkerThread(new StandaloneServerProcessor(serverPort, socketFdGetter)); - } - - /** - * Start a standalone server. - * Use sun.misc.SharedSecrets.getJavaIOFileDescriptorAccess - * to get file descriptor for the socket. - * @param serverPort Port. * @throws IOException if failed to bind localhost:port. */ public Server(int serverPort) throws IOException { - this(serverPort, defaultSocketFdGetter); - } - - /** - * Start a server connected to proxy. - * @param proxyHost The proxy server host. - * @param proxyPort The proxy server port. - * @param key The key to identify the server. - * @param socketFdGetter Method to get system file descriptor of the server socket. - */ - public Server(String proxyHost, int proxyPort, String key, - SocketFileDescriptorGetter socketFdGetter) { - worker = new WorkerThread( - new ConnectProxyServerProcessor(proxyHost, proxyPort, key, socketFdGetter)); + worker = new WorkerThread(new StandaloneServerProcessor(serverPort)); } /** @@ -112,7 +68,8 @@ public Server(String proxyHost, int proxyPort, String key, * @param key The key to identify the server. */ public Server(String proxyHost, int proxyPort, String key) { - this(proxyHost, proxyPort, key, defaultSocketFdGetter); + worker = new WorkerThread( + new ConnectProxyServerProcessor(proxyHost, proxyPort, key)); } /** diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/SocketChannel.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/SocketChannel.java new file mode 100644 index 000000000000..e72581b2358f --- /dev/null +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/SocketChannel.java @@ -0,0 +1,49 @@ +package ml.dmlc.tvm.rpc; + +import ml.dmlc.tvm.Function; +import ml.dmlc.tvm.TVMValue; +import ml.dmlc.tvm.TVMValueBytes; + +import java.io.IOException; +import java.net.Socket; + +public class SocketChannel { + private final Socket socket; + + SocketChannel(Socket sock) { + socket = sock; + } + + private Function fsend = Function.convertFunc(new Function.Callback() { + @Override public Object invoke(TVMValue... args) { + byte[] data = args[0].asBytes(); + try { + socket.getOutputStream().write(data); + } catch (IOException e) { + e.printStackTrace(); + return -1; + } + return data.length; + } + }); + + private Function frecv = Function.convertFunc(new Function.Callback() { + @Override public Object invoke(TVMValue... args) { + long size = args[0].asLong(); + try { + return new TVMValueBytes(Utils.recvAll(socket.getInputStream(), (int) size)); + } catch (IOException e) { + e.printStackTrace(); + return -1; + } + } + }); + + public Function getFsend() { + return fsend; + } + + public Function getFrecv() { + return frecv; + } +} diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/SocketFileDescriptorGetter.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/SocketFileDescriptorGetter.java deleted file mode 100644 index 4c35f720009d..000000000000 --- a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/SocketFileDescriptorGetter.java +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package ml.dmlc.tvm.rpc; - -import java.net.Socket; - -/** - * Interface for defining different socket fd getter. - */ -public interface SocketFileDescriptorGetter { - /** - * Get native socket file descriptor. - * @param socket Java socket. - * @return native socket fd. - */ - public int get(Socket socket); -} diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/StandaloneServerProcessor.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/StandaloneServerProcessor.java index 06e3303d1523..2d2303d3fe8a 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/StandaloneServerProcessor.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/StandaloneServerProcessor.java @@ -28,12 +28,9 @@ */ public class StandaloneServerProcessor implements ServerProcessor { private final ServerSocket server; - private final SocketFileDescriptorGetter socketFileDescriptorGetter; - public StandaloneServerProcessor(int serverPort, - SocketFileDescriptorGetter sockFdGetter) throws IOException { + public StandaloneServerProcessor(int serverPort) throws IOException { this.server = new ServerSocket(serverPort); - this.socketFileDescriptorGetter = sockFdGetter; } @Override public void terminate() { @@ -46,9 +43,9 @@ public StandaloneServerProcessor(int serverPort, @Override public void run() { try { - Socket socket = server.accept(); - InputStream in = socket.getInputStream(); - OutputStream out = socket.getOutputStream(); + final Socket socket = server.accept(); + final InputStream in = socket.getInputStream(); + final OutputStream out = socket.getOutputStream(); int magic = Utils.wrapBytes(Utils.recvAll(in, 4)).getInt(); if (magic != RPC.RPC_MAGIC) { Utils.closeQuietly(socket); @@ -66,12 +63,10 @@ public StandaloneServerProcessor(int serverPort, out.write(Utils.toBytes(serverKey)); } + SocketChannel sockChannel = new SocketChannel(socket); System.err.println("Connection from " + socket.getRemoteSocketAddress().toString()); - final int sockFd = socketFileDescriptorGetter.get(socket); - if (sockFd != -1) { - new NativeServerLoop(sockFd).run(); - System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString()); - } + new NativeServerLoop(sockChannel.getFsend(), sockChannel.getFrecv()).run(); + System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString()); Utils.closeQuietly(socket); } catch (Throwable e) { e.printStackTrace(); diff --git a/jvm/core/src/test/java/ml/dmlc/tvm/contrib/GraphRuntimeTest.java b/jvm/core/src/test/java/ml/dmlc/tvm/contrib/GraphRuntimeTest.java index d719eb6f61e7..a29402867381 100644 --- a/jvm/core/src/test/java/ml/dmlc/tvm/contrib/GraphRuntimeTest.java +++ b/jvm/core/src/test/java/ml/dmlc/tvm/contrib/GraphRuntimeTest.java @@ -17,7 +17,10 @@ package ml.dmlc.tvm.contrib; -import ml.dmlc.tvm.*; +import ml.dmlc.tvm.Module; +import ml.dmlc.tvm.NDArray; +import ml.dmlc.tvm.TVMContext; +import ml.dmlc.tvm.TestUtils; import ml.dmlc.tvm.rpc.Client; import ml.dmlc.tvm.rpc.RPCSession; import ml.dmlc.tvm.rpc.Server; diff --git a/jvm/pom.xml b/jvm/pom.xml index 99cfe0d7b5ec..150c3a00a894 100644 --- a/jvm/pom.xml +++ b/jvm/pom.xml @@ -164,8 +164,8 @@ maven-compiler-plugin 3.3 - 1.6 - 1.6 + 1.7 + 1.7 UTF-8 diff --git a/python/tvm/rpc/tracker.py b/python/tvm/rpc/tracker.py index f31625cd34ed..b9b29a7fe4a1 100644 --- a/python/tvm/rpc/tracker.py +++ b/python/tvm/rpc/tracker.py @@ -230,7 +230,7 @@ def call_handler(self, args): port, matchkey = args[2] self.pending_matchkeys.add(matchkey) # got custom address (from rpc server) - if args[3] is not None: + if len(args) >= 4 and args[3] is not None: value = (self, args[3], port, matchkey) else: value = (self, self._addr[0], port, matchkey) diff --git a/src/common/socket.h b/src/common/socket.h index 2a2d9166a134..39bcff863c10 100644 --- a/src/common/socket.h +++ b/src/common/socket.h @@ -27,8 +27,10 @@ #define TVM_COMMON_SOCKET_H_ #if defined(_WIN32) +#define NOMINMAX #include #include +#undef NOMINMAX using ssize_t = int; #ifdef _MSC_VER #pragma comment(lib, "Ws2_32.lib") diff --git a/src/runtime/rpc/rpc_event_impl.cc b/src/runtime/rpc/rpc_event_impl.cc index 7a142f3373db..3f4782693d8a 100644 --- a/src/runtime/rpc/rpc_event_impl.cc +++ b/src/runtime/rpc/rpc_event_impl.cc @@ -29,32 +29,14 @@ namespace tvm { namespace runtime { -class CallbackChannel final : public RPCChannel { - public: - explicit CallbackChannel(PackedFunc fsend) - : fsend_(fsend) {} - - size_t Send(const void* data, size_t size) final { - TVMByteArray bytes; - bytes.data = static_cast(data); - bytes.size = size; - uint64_t ret = fsend_(bytes); - return static_cast(ret); - } - - size_t Recv(void* data, size_t size) final { - LOG(FATAL) << "Do not allow explicit receive for"; - return 0; - } - - private: - PackedFunc fsend_; -}; - PackedFunc CreateEventDrivenServer(PackedFunc fsend, std::string name, std::string remote_key) { - std::unique_ptr ch(new CallbackChannel(fsend)); + static PackedFunc frecv([](TVMArgs args, TVMRetValue* rv) { + LOG(FATAL) << "Do not allow explicit receive"; + return 0; + }); + std::unique_ptr ch(new CallbackChannel(fsend, frecv)); std::shared_ptr sess = RPCSession::Create(std::move(ch), name, remote_key); return PackedFunc([sess](TVMArgs args, TVMRetValue* rv) { diff --git a/src/runtime/rpc/rpc_session.cc b/src/runtime/rpc/rpc_session.cc index f235ec8e8f0c..39db150bd3a0 100644 --- a/src/runtime/rpc/rpc_session.cc +++ b/src/runtime/rpc/rpc_session.cc @@ -36,6 +36,7 @@ #include #include "rpc_session.h" #include "../../common/ring_buffer.h" +#include "../../common/socket.h" namespace tvm { namespace runtime { @@ -1260,5 +1261,26 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf, return PackedFunc(ftimer); } +size_t CallbackChannel::Send(const void* data, size_t size) { + TVMByteArray bytes; + bytes.data = static_cast(data); + bytes.size = size; + int64_t n = fsend_(bytes); + if (n == -1) { + common::Socket::Error("CallbackChannel::Send"); + } + return static_cast(n); +} + +size_t CallbackChannel::Recv(void* data, size_t size) { + TVMRetValue ret = frecv_(size); + if (ret.type_code() != kBytes) { + common::Socket::Error("CallbackChannel::Recv"); + } + std::string* bytes = ret.ptr(); + memcpy(static_cast(data), bytes->c_str(), bytes->length()); + return bytes->length(); +} + } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h index bc0bc8fe5918..d982f68bcb6e 100644 --- a/src/runtime/rpc/rpc_session.h +++ b/src/runtime/rpc/rpc_session.h @@ -87,7 +87,7 @@ class RPCChannel { */ virtual size_t Send(const void* data, size_t size) = 0; /*! -e * \brief Recv data from channel. + * \brief Recv data from channel. * * \param data The data pointer. * \param size The size fo the data. @@ -253,6 +253,37 @@ class RPCSession { std::string remote_key_; }; +/*! + * \brief RPC channel which callback + * frontend (Python/Java/etc.)'s send & recv function + */ +class CallbackChannel final : public RPCChannel { + public: + explicit CallbackChannel(PackedFunc fsend, PackedFunc frecv) + : fsend_(std::move(fsend)), frecv_(std::move(frecv)) {} + + ~CallbackChannel() {} + /*! + * \brief Send data over to the channel. + * \param data The data pointer. + * \param size The size fo the data. + * \return The actual bytes sent. + */ + size_t Send(const void* data, size_t size) final; + /*! + * \brief Recv data from channel. + * + * \param data The data pointer. + * \param size The size fo the data. + * \return The actual bytes received. + */ + size_t Recv(void* data, size_t size) final; + + private: + PackedFunc fsend_; + PackedFunc frecv_; +}; + /*! * \brief Wrap a timer function to measure the time cost of a given packed function. * \param f The function argument. diff --git a/src/runtime/rpc/rpc_socket_impl.cc b/src/runtime/rpc/rpc_socket_impl.cc index 16528bcc68a1..65d37531159f 100644 --- a/src/runtime/rpc/rpc_socket_impl.cc +++ b/src/runtime/rpc/rpc_socket_impl.cc @@ -36,7 +36,7 @@ class SockChannel final : public RPCChannel { : sock_(sock) {} ~SockChannel() { if (!sock_.BadSocket()) { - sock_.Close(); + sock_.Close(); } } size_t Send(const void* data, size_t size) final { @@ -109,12 +109,25 @@ void RPCServerLoop(int sockfd) { "SockServerLoop", "")->ServerLoop(); } +void RPCServerLoop(PackedFunc fsend, PackedFunc frecv) { + RPCSession::Create(std::unique_ptr( + new CallbackChannel(fsend, frecv)), + "SockServerLoop", "")->ServerLoop(); +} + TVM_REGISTER_GLOBAL("rpc._Connect") .set_body_typed(RPCClientConnect); TVM_REGISTER_GLOBAL("rpc._ServerLoop") .set_body([](TVMArgs args, TVMRetValue* rv) { - RPCServerLoop(args[0]); + if (args.size() == 1) { + RPCServerLoop(args[0]); + } else { + CHECK_EQ(args.size(), 2); + RPCServerLoop( + args[0].operator tvm::runtime::PackedFunc(), + args[1].operator tvm::runtime::PackedFunc()); + } }); } // namespace runtime } // namespace tvm