@@ -629,6 +629,154 @@ async def runner():
629
629
630
630
class Test_UV_TCP (_TestTCP , tb .UVTestCase ):
631
631
632
+ def test_create_server_buffered_1 (self ):
633
+ SIZE = 123123
634
+
635
+ class Proto (asyncio .BaseProtocol ):
636
+ def connection_made (self , tr ):
637
+ self .tr = tr
638
+ self .recvd = b''
639
+ self .data = bytearray (50 )
640
+ self .buf = memoryview (self .data )
641
+
642
+ def get_buffer (self ):
643
+ return self .buf
644
+
645
+ def buffer_updated (self , nbytes ):
646
+ self .recvd += self .buf [:nbytes ]
647
+ if self .recvd == b'a' * SIZE :
648
+ self .tr .write (b'hello' )
649
+
650
+ def eof_received (self ):
651
+ pass
652
+
653
+ async def test ():
654
+ port = tb .find_free_port ()
655
+ srv = await self .loop .create_server (Proto , '127.0.0.1' , port )
656
+
657
+ s = socket .socket (socket .AF_INET )
658
+ with s :
659
+ s .setblocking (False )
660
+ await self .loop .sock_connect (s , ('127.0.0.1' , port ))
661
+ await self .loop .sock_sendall (s , b'a' * SIZE )
662
+ d = await self .loop .sock_recv (s , 100 )
663
+ self .assertEqual (d , b'hello' )
664
+
665
+ srv .close ()
666
+ await srv .wait_closed ()
667
+
668
+ self .loop .run_until_complete (test ())
669
+
670
+ def test_create_server_buffered_2 (self ):
671
+ class ProtoExc (asyncio .BaseProtocol ):
672
+ def __init__ (self ):
673
+ self ._lost_exc = None
674
+
675
+ def get_buffer (self ):
676
+ 1 / 0
677
+
678
+ def buffer_updated (self , nbytes ):
679
+ pass
680
+
681
+ def connection_lost (self , exc ):
682
+ self ._lost_exc = exc
683
+
684
+ def eof_received (self ):
685
+ pass
686
+
687
+ class ProtoZeroBuf1 (asyncio .BaseProtocol ):
688
+ def __init__ (self ):
689
+ self ._lost_exc = None
690
+
691
+ def get_buffer (self ):
692
+ return bytearray (0 )
693
+
694
+ def buffer_updated (self , nbytes ):
695
+ pass
696
+
697
+ def connection_lost (self , exc ):
698
+ self ._lost_exc = exc
699
+
700
+ def eof_received (self ):
701
+ pass
702
+
703
+ class ProtoZeroBuf2 (asyncio .BaseProtocol ):
704
+ def __init__ (self ):
705
+ self ._lost_exc = None
706
+
707
+ def get_buffer (self ):
708
+ return memoryview (bytearray (0 ))
709
+
710
+ def buffer_updated (self , nbytes ):
711
+ pass
712
+
713
+ def connection_lost (self , exc ):
714
+ self ._lost_exc = exc
715
+
716
+ def eof_received (self ):
717
+ pass
718
+
719
+ class ProtoUpdatedError (asyncio .BaseProtocol ):
720
+ def __init__ (self ):
721
+ self ._lost_exc = None
722
+
723
+ def get_buffer (self ):
724
+ return memoryview (bytearray (100 ))
725
+
726
+ def buffer_updated (self , nbytes ):
727
+ raise RuntimeError ('oups' )
728
+
729
+ def connection_lost (self , exc ):
730
+ self ._lost_exc = exc
731
+
732
+ def eof_received (self ):
733
+ pass
734
+
735
+ async def test (proto_factory , exc_type , exc_re ):
736
+ port = tb .find_free_port ()
737
+ proto = proto_factory ()
738
+ srv = await self .loop .create_server (
739
+ lambda : proto , '127.0.0.1' , port )
740
+
741
+ try :
742
+ s = socket .socket (socket .AF_INET )
743
+ with s :
744
+ s .setblocking (False )
745
+ await self .loop .sock_connect (s , ('127.0.0.1' , port ))
746
+ await self .loop .sock_sendall (s , b'a' )
747
+ d = await self .loop .sock_recv (s , 100 )
748
+ if not d :
749
+ raise ConnectionResetError
750
+ except ConnectionResetError :
751
+ pass
752
+ else :
753
+ self .fail ("server didn't abort the connection" )
754
+ return
755
+ finally :
756
+ srv .close ()
757
+ await srv .wait_closed ()
758
+
759
+ if proto ._lost_exc is None :
760
+ self .fail ("connection_lost() was not called" )
761
+ return
762
+
763
+ with self .assertRaisesRegex (exc_type , exc_re ):
764
+ raise proto ._lost_exc
765
+
766
+ self .loop .set_exception_handler (lambda loop , ctx : None )
767
+
768
+ self .loop .run_until_complete (
769
+ test (ProtoExc , RuntimeError , 'unhandled error .* get_buffer' ))
770
+
771
+ self .loop .run_until_complete (
772
+ test (ProtoZeroBuf1 , RuntimeError , 'unhandled error .* get_buffer' ))
773
+
774
+ self .loop .run_until_complete (
775
+ test (ProtoZeroBuf2 , RuntimeError , 'unhandled error .* get_buffer' ))
776
+
777
+ self .loop .run_until_complete (
778
+ test (ProtoUpdatedError , RuntimeError , r'^oups$' ))
779
+
632
780
def test_transport_get_extra_info (self ):
633
781
# This tests is only for uvloop. asyncio should pass it
634
782
# too in Python 3.6.
0 commit comments