2424import requests
2525
2626from google .api_core import rest_streaming
27+ from google .api import http_pb2
28+ from google .api import httpbody_pb2
2729from google .protobuf import duration_pb2
2830from google .protobuf import timestamp_pb2
31+ from google .protobuf .json_format import MessageToJson
2932
3033
3134__protobuf__ = proto .module (package = __name__ )
@@ -98,7 +101,10 @@ def _parse_responses(self, responses: List[proto.Message]) -> bytes:
98101 # json.dumps returns a string surrounded with quotes that need to be stripped
99102 # in order to be an actual JSON.
100103 json_responses = [
101- self ._response_message_cls .to_json (r ).strip ('"' ) for r in responses
104+ self ._response_message_cls .to_json (r ).strip ('"' )
105+ if issubclass (self ._response_message_cls , proto .Message )
106+ else MessageToJson (r ).strip ('"' )
107+ for r in responses
102108 ]
103109 logging .info (f"Sending JSON stream: { json_responses } " )
104110 ret_val = "[{}]" .format ("," .join (json_responses ))
@@ -114,103 +120,220 @@ def iter_content(self, *args, **kwargs):
114120 )
115121
116122
117- @pytest .mark .parametrize ("random_split" , [False ])
118- def test_next_simple (random_split ):
119- responses = [EchoResponse (content = "hello world" ), EchoResponse (content = "yes" )]
123+ @pytest .mark .parametrize (
124+ "random_split,resp_message_is_proto_plus" ,
125+ [(False , True ), (False , False )],
126+ )
127+ def test_next_simple (random_split , resp_message_is_proto_plus ):
128+ if resp_message_is_proto_plus :
129+ response_type = EchoResponse
130+ responses = [EchoResponse (content = "hello world" ), EchoResponse (content = "yes" )]
131+ else :
132+ response_type = httpbody_pb2 .HttpBody
133+ responses = [
134+ httpbody_pb2 .HttpBody (content_type = "hello world" ),
135+ httpbody_pb2 .HttpBody (content_type = "yes" ),
136+ ]
137+
120138 resp = ResponseMock (
121- responses = responses , random_split = random_split , response_cls = EchoResponse
139+ responses = responses , random_split = random_split , response_cls = response_type
122140 )
123- itr = rest_streaming .ResponseIterator (resp , EchoResponse )
141+ itr = rest_streaming .ResponseIterator (resp , response_type )
124142 assert list (itr ) == responses
125143
126144
127- @pytest .mark .parametrize ("random_split" , [True , False ])
128- def test_next_nested (random_split ):
129- responses = [
130- Song (title = "some song" , composer = Composer (given_name = "some name" )),
131- Song (title = "another song" , date_added = datetime .datetime (2021 , 12 , 17 )),
132- ]
145+ @pytest .mark .parametrize (
146+ "random_split,resp_message_is_proto_plus" ,
147+ [
148+ (True , True ),
149+ (False , True ),
150+ (True , False ),
151+ (False , False ),
152+ ],
153+ )
154+ def test_next_nested (random_split , resp_message_is_proto_plus ):
155+ if resp_message_is_proto_plus :
156+ response_type = Song
157+ responses = [
158+ Song (title = "some song" , composer = Composer (given_name = "some name" )),
159+ Song (title = "another song" , date_added = datetime .datetime (2021 , 12 , 17 )),
160+ ]
161+ else :
162+ # Although `http_pb2.HttpRule`` is used in the response, any response message
163+ # can be used which meets this criteria for the test of having a nested field.
164+ response_type = http_pb2 .HttpRule
165+ responses = [
166+ http_pb2 .HttpRule (
167+ selector = "some selector" ,
168+ custom = http_pb2 .CustomHttpPattern (kind = "some kind" ),
169+ ),
170+ http_pb2 .HttpRule (
171+ selector = "another selector" ,
172+ custom = http_pb2 .CustomHttpPattern (path = "some path" ),
173+ ),
174+ ]
133175 resp = ResponseMock (
134- responses = responses , random_split = random_split , response_cls = Song
176+ responses = responses , random_split = random_split , response_cls = response_type
135177 )
136- itr = rest_streaming .ResponseIterator (resp , Song )
178+ itr = rest_streaming .ResponseIterator (resp , response_type )
137179 assert list (itr ) == responses
138180
139181
140- @pytest .mark .parametrize ("random_split" , [True , False ])
141- def test_next_stress (random_split ):
182+ @pytest .mark .parametrize (
183+ "random_split,resp_message_is_proto_plus" ,
184+ [
185+ (True , True ),
186+ (False , True ),
187+ (True , False ),
188+ (False , False ),
189+ ],
190+ )
191+ def test_next_stress (random_split , resp_message_is_proto_plus ):
142192 n = 50
143- responses = [
144- Song (title = "title_%d" % i , composer = Composer (given_name = "name_%d" % i ))
145- for i in range (n )
146- ]
193+ if resp_message_is_proto_plus :
194+ response_type = Song
195+ responses = [
196+ Song (title = "title_%d" % i , composer = Composer (given_name = "name_%d" % i ))
197+ for i in range (n )
198+ ]
199+ else :
200+ response_type = http_pb2 .HttpRule
201+ responses = [
202+ http_pb2 .HttpRule (
203+ selector = "selector_%d" % i ,
204+ custom = http_pb2 .CustomHttpPattern (path = "path_%d" % i ),
205+ )
206+ for i in range (n )
207+ ]
147208 resp = ResponseMock (
148- responses = responses , random_split = random_split , response_cls = Song
209+ responses = responses , random_split = random_split , response_cls = response_type
149210 )
150- itr = rest_streaming .ResponseIterator (resp , Song )
211+ itr = rest_streaming .ResponseIterator (resp , response_type )
151212 assert list (itr ) == responses
152213
153214
154- @pytest .mark .parametrize ("random_split" , [True , False ])
155- def test_next_escaped_characters_in_string (random_split ):
156- composer_with_relateds = Composer ()
157- relateds = ["Artist A" , "Artist B" ]
158- composer_with_relateds .relateds = relateds
159-
160- responses = [
161- Song (title = 'ti"tle\n foo\t bar{}' , composer = Composer (given_name = "name\n \n \n " )),
162- Song (
163- title = '{"this is weird": "totally"}' , composer = Composer (given_name = "\\ {}\\ " )
164- ),
165- Song (title = '\\ {"key": ["value",]}\\ ' , composer = composer_with_relateds ),
166- ]
215+ @pytest .mark .parametrize (
216+ "random_split,resp_message_is_proto_plus" ,
217+ [
218+ (True , True ),
219+ (False , True ),
220+ (True , False ),
221+ (False , False ),
222+ ],
223+ )
224+ def test_next_escaped_characters_in_string (random_split , resp_message_is_proto_plus ):
225+ if resp_message_is_proto_plus :
226+ response_type = Song
227+ composer_with_relateds = Composer ()
228+ relateds = ["Artist A" , "Artist B" ]
229+ composer_with_relateds .relateds = relateds
230+
231+ responses = [
232+ Song (
233+ title = 'ti"tle\n foo\t bar{}' , composer = Composer (given_name = "name\n \n \n " )
234+ ),
235+ Song (
236+ title = '{"this is weird": "totally"}' ,
237+ composer = Composer (given_name = "\\ {}\\ " ),
238+ ),
239+ Song (title = '\\ {"key": ["value",]}\\ ' , composer = composer_with_relateds ),
240+ ]
241+ else :
242+ response_type = http_pb2 .Http
243+ responses = [
244+ http_pb2 .Http (
245+ rules = [
246+ http_pb2 .HttpRule (
247+ selector = 'ti"tle\n foo\t bar{}' ,
248+ custom = http_pb2 .CustomHttpPattern (kind = "name\n \n \n " ),
249+ )
250+ ]
251+ ),
252+ http_pb2 .Http (
253+ rules = [
254+ http_pb2 .HttpRule (
255+ selector = '{"this is weird": "totally"}' ,
256+ custom = http_pb2 .CustomHttpPattern (kind = "\\ {}\\ " ),
257+ )
258+ ]
259+ ),
260+ http_pb2 .Http (
261+ rules = [
262+ http_pb2 .HttpRule (
263+ selector = '\\ {"key": ["value",]}\\ ' ,
264+ custom = http_pb2 .CustomHttpPattern (kind = "\\ {}\\ " ),
265+ )
266+ ]
267+ ),
268+ ]
167269 resp = ResponseMock (
168- responses = responses , random_split = random_split , response_cls = Song
270+ responses = responses , random_split = random_split , response_cls = response_type
169271 )
170- itr = rest_streaming .ResponseIterator (resp , Song )
272+ itr = rest_streaming .ResponseIterator (resp , response_type )
171273 assert list (itr ) == responses
172274
173275
174- def test_next_not_array ():
276+ @pytest .mark .parametrize ("response_type" , [EchoResponse , httpbody_pb2 .HttpBody ])
277+ def test_next_not_array (response_type ):
175278 with patch .object (
176279 ResponseMock , "iter_content" , return_value = iter ('{"hello": 0}' )
177280 ) as mock_method :
178-
179- resp = ResponseMock (responses = [], response_cls = EchoResponse )
180- itr = rest_streaming .ResponseIterator (resp , EchoResponse )
281+ resp = ResponseMock (responses = [], response_cls = response_type )
282+ itr = rest_streaming .ResponseIterator (resp , response_type )
181283 with pytest .raises (ValueError ):
182284 next (itr )
183285 mock_method .assert_called_once ()
184286
185287
186- def test_cancel ():
288+ @pytest .mark .parametrize ("response_type" , [EchoResponse , httpbody_pb2 .HttpBody ])
289+ def test_cancel (response_type ):
187290 with patch .object (ResponseMock , "close" , return_value = None ) as mock_method :
188- resp = ResponseMock (responses = [], response_cls = EchoResponse )
189- itr = rest_streaming .ResponseIterator (resp , EchoResponse )
291+ resp = ResponseMock (responses = [], response_cls = response_type )
292+ itr = rest_streaming .ResponseIterator (resp , response_type )
190293 itr .cancel ()
191294 mock_method .assert_called_once ()
192295
193296
194- def test_check_buffer ():
297+ @pytest .mark .parametrize (
298+ "response_type,return_value" ,
299+ [
300+ (EchoResponse , bytes ('[{"content": "hello"}, {' , "utf-8" )),
301+ (httpbody_pb2 .HttpBody , bytes ('[{"content_type": "hello"}, {' , "utf-8" )),
302+ ],
303+ )
304+ def test_check_buffer (response_type , return_value ):
195305 with patch .object (
196306 ResponseMock ,
197307 "_parse_responses" ,
198- return_value = bytes ( '[{"content": "hello"}, {' , "utf-8" ) ,
308+ return_value = return_value ,
199309 ):
200- resp = ResponseMock (responses = [], response_cls = EchoResponse )
201- itr = rest_streaming .ResponseIterator (resp , EchoResponse )
310+ resp = ResponseMock (responses = [], response_cls = response_type )
311+ itr = rest_streaming .ResponseIterator (resp , response_type )
202312 with pytest .raises (ValueError ):
203313 next (itr )
204314 next (itr )
205315
206316
207- def test_next_html ():
317+ @pytest .mark .parametrize ("response_type" , [EchoResponse , httpbody_pb2 .HttpBody ])
318+ def test_next_html (response_type ):
208319 with patch .object (
209320 ResponseMock , "iter_content" , return_value = iter ("<!DOCTYPE html><html></html>" )
210321 ) as mock_method :
211-
212- resp = ResponseMock (responses = [], response_cls = EchoResponse )
213- itr = rest_streaming .ResponseIterator (resp , EchoResponse )
322+ resp = ResponseMock (responses = [], response_cls = response_type )
323+ itr = rest_streaming .ResponseIterator (resp , response_type )
214324 with pytest .raises (ValueError ):
215325 next (itr )
216326 mock_method .assert_called_once ()
327+
328+
329+ def test_invalid_response_class ():
330+ class SomeClass :
331+ pass
332+
333+ resp = ResponseMock (responses = [], response_cls = SomeClass )
334+ response_iterator = rest_streaming .ResponseIterator (resp , SomeClass )
335+ with pytest .raises (
336+ ValueError ,
337+ match = "Response message class must be a subclass of proto.Message or google.protobuf.message.Message" ,
338+ ):
339+ response_iterator ._grab ()
0 commit comments