Skip to content

Commit 5408d3a

Browse files
yzhliutqchen
authored andcommitted
[rpc] use callback func to do send & recv (#4147)
* [rpc] use callback func to do send & recv. don't get fd from sock as it is deprecated in java * fix java build * fix min/max macro define in windows * keep the old rpc setup for py * add doc for CallbackChannel
1 parent a740423 commit 5408d3a

File tree

16 files changed

+165
-146
lines changed

16 files changed

+165
-146
lines changed

jvm/core/src/main/java/ml/dmlc/tvm/rpc/ConnectProxyServerProcessor.java

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ public class ConnectProxyServerProcessor implements ServerProcessor {
3030
private final String host;
3131
private final int port;
3232
private final String key;
33-
private final SocketFileDescriptorGetter socketFileDescriptorGetter;
3433

3534
private volatile Socket currSocket = new Socket();
3635
private Runnable callback;
@@ -40,14 +39,11 @@ public class ConnectProxyServerProcessor implements ServerProcessor {
4039
* @param host Proxy server host.
4140
* @param port Proxy server port.
4241
* @param key Proxy server key.
43-
* @param sockFdGetter Method to get file descriptor from Java socket.
4442
*/
45-
public ConnectProxyServerProcessor(String host, int port, String key,
46-
SocketFileDescriptorGetter sockFdGetter) {
43+
public ConnectProxyServerProcessor(String host, int port, String key) {
4744
this.host = host;
4845
this.port = port;
4946
this.key = "server:" + key;
50-
socketFileDescriptorGetter = sockFdGetter;
5147
}
5248

5349
/**
@@ -70,8 +66,8 @@ public void setStartTimeCallback(Runnable callback) {
7066
try {
7167
SocketAddress address = new InetSocketAddress(host, port);
7268
currSocket.connect(address, 6000);
73-
InputStream in = currSocket.getInputStream();
74-
OutputStream out = currSocket.getOutputStream();
69+
final InputStream in = currSocket.getInputStream();
70+
final OutputStream out = currSocket.getOutputStream();
7571
out.write(Utils.toBytes(RPC.RPC_MAGIC));
7672
out.write(Utils.toBytes(key.length()));
7773
out.write(Utils.toBytes(key));
@@ -91,11 +87,10 @@ public void setStartTimeCallback(Runnable callback) {
9187
if (callback != null) {
9288
callback.run();
9389
}
94-
final int sockFd = socketFileDescriptorGetter.get(currSocket);
95-
if (sockFd != -1) {
96-
new NativeServerLoop(sockFd).run();
97-
System.err.println("Finish serving " + address);
98-
}
90+
91+
SocketChannel sockChannel = new SocketChannel(currSocket);
92+
new NativeServerLoop(sockChannel.getFsend(), sockChannel.getFrecv()).run();
93+
System.err.println("Finish serving " + address);
9994
} catch (Throwable e) {
10095
e.printStackTrace();
10196
throw new RuntimeException(e);

jvm/core/src/main/java/ml/dmlc/tvm/rpc/ConnectTrackerServerProcessor.java

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
*/
3838
public class ConnectTrackerServerProcessor implements ServerProcessor {
3939
private ServerSocket server;
40-
private final SocketFileDescriptorGetter socketFileDescriptorGetter;
4140
private final String trackerHost;
4241
private final int trackerPort;
4342
// device key
@@ -62,10 +61,11 @@ public class ConnectTrackerServerProcessor implements ServerProcessor {
6261
* @param trackerHost Tracker host.
6362
* @param trackerPort Tracker port.
6463
* @param key Device key.
65-
* @param sockFdGetter Method to get file descriptor from Java socket.
64+
* @param watchdog watch for timeout, etc.
65+
* @throws java.io.IOException when socket fails to open.
6666
*/
6767
public ConnectTrackerServerProcessor(String trackerHost, int trackerPort, String key,
68-
SocketFileDescriptorGetter sockFdGetter, RPCWatchdog watchdog) throws IOException {
68+
RPCWatchdog watchdog) throws IOException {
6969
while (true) {
7070
try {
7171
this.server = new ServerSocket(serverPort);
@@ -81,7 +81,6 @@ public ConnectTrackerServerProcessor(String trackerHost, int trackerPort, String
8181
}
8282
}
8383
System.err.println("using port: " + serverPort);
84-
this.socketFileDescriptorGetter = sockFdGetter;
8584
this.trackerHost = trackerHost;
8685
this.trackerPort = trackerPort;
8786
this.key = key;
@@ -163,11 +162,9 @@ public String getMatchKey() {
163162
System.err.println("Connection from " + socket.getRemoteSocketAddress().toString());
164163
// received timeout in seconds
165164
watchdog.startTimeout(timeout * 1000);
166-
final int sockFd = socketFileDescriptorGetter.get(socket);
167-
if (sockFd != -1) {
168-
new NativeServerLoop(sockFd).run();
169-
System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString());
170-
}
165+
SocketChannel sockChannel = new SocketChannel(socket);
166+
new NativeServerLoop(sockChannel.getFsend(), sockChannel.getFrecv()).run();
167+
System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString());
171168
Utils.closeQuietly(socket);
172169
} catch (ConnectException e) {
173170
// if the tracker connection failed, wait a bit before retrying

jvm/core/src/main/java/ml/dmlc/tvm/rpc/NativeServerLoop.java

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,22 +28,25 @@
2828
* Call native ServerLoop on socket file descriptor.
2929
*/
3030
public class NativeServerLoop implements Runnable {
31-
private final int sockFd;
31+
private final Function fsend;
32+
private final Function frecv;
3233

3334
/**
3435
* Constructor for NativeServerLoop.
35-
* @param nativeSockFd native socket file descriptor.
36+
* @param fsend socket.send function.
37+
* @param frecv socket.recv function.
3638
*/
37-
public NativeServerLoop(final int nativeSockFd) {
38-
sockFd = nativeSockFd;
39+
public NativeServerLoop(final Function fsend, final Function frecv) {
40+
this.fsend = fsend;
41+
this.frecv = frecv;
3942
}
4043

4144
@Override public void run() {
4245
File tempDir = null;
4346
try {
4447
tempDir = serverEnv();
4548
System.err.println("starting server loop...");
46-
RPC.getApi("_ServerLoop").pushArg(sockFd).invoke();
49+
RPC.getApi("_ServerLoop").pushArg(fsend).pushArg(frecv).invoke();
4750
System.err.println("done server loop...");
4851
} catch (IOException e) {
4952
e.printStackTrace();

jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ public void upload(byte[] data, String target) {
200200
* Upload file to remote runtime temp folder.
201201
* @param data The file in local to upload.
202202
* @param target The path in remote.
203+
* @throws java.io.IOException for network failure.
203204
*/
204205
public void upload(File data, String target) throws IOException {
205206
byte[] blob = getBytesFromFile(data);
@@ -209,6 +210,7 @@ public void upload(File data, String target) throws IOException {
209210
/**
210211
* Upload file to remote runtime temp folder.
211212
* @param data The file in local to upload.
213+
* @throws java.io.IOException for network failure.
212214
*/
213215
public void upload(File data) throws IOException {
214216
upload(data, data.getName());

jvm/core/src/main/java/ml/dmlc/tvm/rpc/Server.java

Lines changed: 3 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -17,31 +17,12 @@
1717

1818
package ml.dmlc.tvm.rpc;
1919

20-
import sun.misc.SharedSecrets;
21-
22-
import java.io.FileDescriptor;
23-
import java.io.FileInputStream;
2420
import java.io.IOException;
25-
import java.io.InputStream;
26-
import java.net.Socket;
2721

2822
/**
2923
* RPC Server.
3024
*/
3125
public class Server {
32-
private static SocketFileDescriptorGetter defaultSocketFdGetter
33-
= new SocketFileDescriptorGetter() {
34-
@Override public int get(Socket socket) {
35-
try {
36-
InputStream is = socket.getInputStream();
37-
FileDescriptor fd = ((FileInputStream) is).getFD();
38-
return SharedSecrets.getJavaIOFileDescriptorAccess().get(fd);
39-
} catch (IOException e) {
40-
e.printStackTrace();
41-
return -1;
42-
}
43-
}
44-
};
4526
private final WorkerThread worker;
4627

4728
private static class WorkerThread extends Thread {
@@ -72,35 +53,10 @@ public void terminate() {
7253
/**
7354
* Start a standalone server.
7455
* @param serverPort Port.
75-
* @param socketFdGetter Method to get system file descriptor of the server socket.
76-
* @throws IOException if failed to bind localhost:port.
77-
*/
78-
public Server(int serverPort, SocketFileDescriptorGetter socketFdGetter) throws IOException {
79-
worker = new WorkerThread(new StandaloneServerProcessor(serverPort, socketFdGetter));
80-
}
81-
82-
/**
83-
* Start a standalone server.
84-
* Use sun.misc.SharedSecrets.getJavaIOFileDescriptorAccess
85-
* to get file descriptor for the socket.
86-
* @param serverPort Port.
8756
* @throws IOException if failed to bind localhost:port.
8857
*/
8958
public Server(int serverPort) throws IOException {
90-
this(serverPort, defaultSocketFdGetter);
91-
}
92-
93-
/**
94-
* Start a server connected to proxy.
95-
* @param proxyHost The proxy server host.
96-
* @param proxyPort The proxy server port.
97-
* @param key The key to identify the server.
98-
* @param socketFdGetter Method to get system file descriptor of the server socket.
99-
*/
100-
public Server(String proxyHost, int proxyPort, String key,
101-
SocketFileDescriptorGetter socketFdGetter) {
102-
worker = new WorkerThread(
103-
new ConnectProxyServerProcessor(proxyHost, proxyPort, key, socketFdGetter));
59+
worker = new WorkerThread(new StandaloneServerProcessor(serverPort));
10460
}
10561

10662
/**
@@ -112,7 +68,8 @@ public Server(String proxyHost, int proxyPort, String key,
11268
* @param key The key to identify the server.
11369
*/
11470
public Server(String proxyHost, int proxyPort, String key) {
115-
this(proxyHost, proxyPort, key, defaultSocketFdGetter);
71+
worker = new WorkerThread(
72+
new ConnectProxyServerProcessor(proxyHost, proxyPort, key));
11673
}
11774

11875
/**
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package ml.dmlc.tvm.rpc;
2+
3+
import ml.dmlc.tvm.Function;
4+
import ml.dmlc.tvm.TVMValue;
5+
import ml.dmlc.tvm.TVMValueBytes;
6+
7+
import java.io.IOException;
8+
import java.net.Socket;
9+
10+
public class SocketChannel {
11+
private final Socket socket;
12+
13+
SocketChannel(Socket sock) {
14+
socket = sock;
15+
}
16+
17+
private Function fsend = Function.convertFunc(new Function.Callback() {
18+
@Override public Object invoke(TVMValue... args) {
19+
byte[] data = args[0].asBytes();
20+
try {
21+
socket.getOutputStream().write(data);
22+
} catch (IOException e) {
23+
e.printStackTrace();
24+
return -1;
25+
}
26+
return data.length;
27+
}
28+
});
29+
30+
private Function frecv = Function.convertFunc(new Function.Callback() {
31+
@Override public Object invoke(TVMValue... args) {
32+
long size = args[0].asLong();
33+
try {
34+
return new TVMValueBytes(Utils.recvAll(socket.getInputStream(), (int) size));
35+
} catch (IOException e) {
36+
e.printStackTrace();
37+
return -1;
38+
}
39+
}
40+
});
41+
42+
public Function getFsend() {
43+
return fsend;
44+
}
45+
46+
public Function getFrecv() {
47+
return frecv;
48+
}
49+
}

jvm/core/src/main/java/ml/dmlc/tvm/rpc/SocketFileDescriptorGetter.java

Lines changed: 0 additions & 32 deletions
This file was deleted.

jvm/core/src/main/java/ml/dmlc/tvm/rpc/StandaloneServerProcessor.java

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,9 @@
2828
*/
2929
public class StandaloneServerProcessor implements ServerProcessor {
3030
private final ServerSocket server;
31-
private final SocketFileDescriptorGetter socketFileDescriptorGetter;
3231

33-
public StandaloneServerProcessor(int serverPort,
34-
SocketFileDescriptorGetter sockFdGetter) throws IOException {
32+
public StandaloneServerProcessor(int serverPort) throws IOException {
3533
this.server = new ServerSocket(serverPort);
36-
this.socketFileDescriptorGetter = sockFdGetter;
3734
}
3835

3936
@Override public void terminate() {
@@ -46,9 +43,9 @@ public StandaloneServerProcessor(int serverPort,
4643

4744
@Override public void run() {
4845
try {
49-
Socket socket = server.accept();
50-
InputStream in = socket.getInputStream();
51-
OutputStream out = socket.getOutputStream();
46+
final Socket socket = server.accept();
47+
final InputStream in = socket.getInputStream();
48+
final OutputStream out = socket.getOutputStream();
5249
int magic = Utils.wrapBytes(Utils.recvAll(in, 4)).getInt();
5350
if (magic != RPC.RPC_MAGIC) {
5451
Utils.closeQuietly(socket);
@@ -66,12 +63,10 @@ public StandaloneServerProcessor(int serverPort,
6663
out.write(Utils.toBytes(serverKey));
6764
}
6865

66+
SocketChannel sockChannel = new SocketChannel(socket);
6967
System.err.println("Connection from " + socket.getRemoteSocketAddress().toString());
70-
final int sockFd = socketFileDescriptorGetter.get(socket);
71-
if (sockFd != -1) {
72-
new NativeServerLoop(sockFd).run();
73-
System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString());
74-
}
68+
new NativeServerLoop(sockChannel.getFsend(), sockChannel.getFrecv()).run();
69+
System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString());
7570
Utils.closeQuietly(socket);
7671
} catch (Throwable e) {
7772
e.printStackTrace();

jvm/core/src/test/java/ml/dmlc/tvm/contrib/GraphRuntimeTest.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717

1818
package ml.dmlc.tvm.contrib;
1919

20-
import ml.dmlc.tvm.*;
20+
import ml.dmlc.tvm.Module;
21+
import ml.dmlc.tvm.NDArray;
22+
import ml.dmlc.tvm.TVMContext;
23+
import ml.dmlc.tvm.TestUtils;
2124
import ml.dmlc.tvm.rpc.Client;
2225
import ml.dmlc.tvm.rpc.RPCSession;
2326
import ml.dmlc.tvm.rpc.Server;

jvm/pom.xml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,8 @@
164164
<artifactId>maven-compiler-plugin</artifactId>
165165
<version>3.3</version>
166166
<configuration>
167-
<source>1.6</source>
168-
<target>1.6</target>
167+
<source>1.7</source>
168+
<target>1.7</target>
169169
<encoding>UTF-8</encoding>
170170
</configuration>
171171
</plugin>

0 commit comments

Comments
 (0)