1+ /*
2+ * Licensed to the Apache Software Foundation (ASF) under one
3+ * or more contributor license agreements. See the NOTICE file
4+ * distributed with this work for additional information
5+ * regarding copyright ownership. The ASF licenses this file
6+ * to you under the Apache License, Version 2.0 (the
7+ * "License"); you may not use this file except in compliance
8+ * with the License. You may obtain a copy of the License at
9+ *
10+ * http://www.apache.org/licenses/LICENSE-2.0
11+ *
12+ * Unless required by applicable law or agreed to in writing,
13+ * software distributed under the License is distributed on an
14+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+ * KIND, either express or implied. See the License for the
16+ * specific language governing permissions and limitations
17+ * under the License.
18+ */
19+ #ifndef TVM_RUNTIME_DISCO_MESSAGE_QUEUE_H_
20+ #define TVM_RUNTIME_DISCO_MESSAGE_QUEUE_H_
21+
22+ #include < dmlc/io.h>
23+
24+ #include " ./protocol.h"
25+
26+ namespace tvm {
27+ namespace runtime {
28+
29+ class DiscoStreamMessageQueue : private dmlc ::Stream,
30+ private DiscoProtocol<DiscoStreamMessageQueue> {
31+ public:
32+ explicit DiscoStreamMessageQueue (Stream* stream) : stream_(stream) {}
33+
34+ ~DiscoStreamMessageQueue () = default ;
35+
36+ void Send (const TVMArgs& args) {
37+ RPCReference::ReturnPackedSeq (args.values , args.type_codes , args.num_args , this );
38+ CommitSendAndNotifyEnqueue ();
39+ }
40+
41+ TVMArgs Recv () {
42+ bool is_implicit_shutdown = DequeueNextPacket ();
43+ TVMValue* values = nullptr ;
44+ int * type_codes = nullptr ;
45+ int num_args = 0 ;
46+
47+ if (is_implicit_shutdown) {
48+ num_args = 2 ;
49+ values = ArenaAlloc<TVMValue>(num_args);
50+ type_codes = ArenaAlloc<int >(num_args);
51+ TVMArgsSetter setter (values, type_codes);
52+ setter (0 , static_cast <int >(DiscoAction::kShutDown ));
53+ setter (1 , 0 );
54+ } else {
55+ RPCReference::RecvPackedSeq (&values, &type_codes, &num_args, this );
56+ }
57+ return TVMArgs (values, type_codes, num_args);
58+ }
59+
60+ protected:
61+ void CommitSendAndNotifyEnqueue () {
62+ stream_->Write (write_buffer_.data (), write_buffer_.size ());
63+ write_buffer_.clear ();
64+ }
65+
66+ /* \brief Read next packet and reset unpacker
67+ *
68+ * Read the next packet into `read_buffer_`, releasing all arena
69+ * allocations performed by the unpacker and resetting the unpacker
70+ * to its initial state.
71+ *
72+ * \return A boolean value. If true, this packet should be treated
73+ * equivalently to a `DiscoAction::kShutdown` event. If false,
74+ * this packet should be unpacked.
75+ */
76+ bool DequeueNextPacket () {
77+ uint64_t packet_nbytes = 0 ;
78+ int read_size = stream_->Read (&packet_nbytes, sizeof (packet_nbytes));
79+ if (read_size == 0 ) {
80+ // Special case, connection dropped between packets. Treat as a
81+ // request to shutdown.
82+ return true ;
83+ }
84+
85+ ICHECK_EQ (read_size, sizeof (packet_nbytes))
86+ << " Stream closed without proper shutdown. Please make sure to explicitly call "
87+ " `Session::Shutdown`" ;
88+ read_buffer_.resize (packet_nbytes);
89+ read_size = stream_->Read (read_buffer_.data (), packet_nbytes);
90+ ICHECK_EQ (read_size, packet_nbytes)
91+ << " Stream closed without proper shutdown. Please make sure to explicitly call "
92+ " `Session::Shutdown`" ;
93+ read_offset_ = 0 ;
94+ this ->RecycleAll ();
95+ RPCCode code = RPCCode::kReturn ;
96+ this ->Read (&code);
97+ return false ;
98+ }
99+
100+ size_t Read (void * data, size_t size) final {
101+ std::memcpy (data, read_buffer_.data () + read_offset_, size);
102+ read_offset_ += size;
103+ ICHECK_LE (read_offset_, read_buffer_.size ());
104+ return size;
105+ }
106+
107+ size_t Write (const void * data, size_t size) final {
108+ size_t cur_size = write_buffer_.size ();
109+ write_buffer_.resize (cur_size + size);
110+ std::memcpy (write_buffer_.data () + cur_size, data, size);
111+ return size;
112+ }
113+
114+ using dmlc::Stream::Read;
115+ using dmlc::Stream::ReadArray;
116+ using dmlc::Stream::Write;
117+ using dmlc::Stream::WriteArray;
118+ friend struct RPCReference ;
119+ friend struct DiscoProtocol <DiscoStreamMessageQueue>;
120+
121+ // The read/write buffer will only be accessed by the producer thread.
122+ std::string write_buffer_;
123+ std::string read_buffer_;
124+ size_t read_offset_ = 0 ;
125+ dmlc::Stream* stream_;
126+ };
127+
128+ } // namespace runtime
129+ } // namespace tvm
130+
131+ #endif // TVM_RUNTIME_DISCO_MESSAGE_QUEUE_H_
0 commit comments