Skip to content

Commit c8b5273

Browse files
BoronBGPWojciechNagorskiRob-Hague
authored
Handle timeout correctly on Socks5 Proxy (#1342)
* Add timeouts when reading from sockets in Socks5Connector * Add a Socks5 timeout test for a connection reply --------- Co-authored-by: Wojciech Nagórski <[email protected]> Co-authored-by: Rob Hague <[email protected]>
1 parent ce45129 commit c8b5273

File tree

2 files changed

+125
-9
lines changed

2 files changed

+125
-9
lines changed

src/Renci.SshNet/Connection/Socks5Connector.cs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,13 @@ protected override void HandleProxyConnect(IConnectionInfo connectionInfo, Socke
4343
};
4444
SocketAbstraction.Send(socket, greeting);
4545

46-
var socksVersion = SocketReadByte(socket);
46+
var socksVersion = SocketReadByte(socket, connectionInfo.Timeout);
4747
if (socksVersion != 0x05)
4848
{
4949
throw new ProxyException(string.Format("SOCKS Version '{0}' is not supported.", socksVersion));
5050
}
5151

52-
var authenticationMethod = SocketReadByte(socket);
52+
var authenticationMethod = SocketReadByte(socket, connectionInfo.Timeout);
5353
switch (authenticationMethod)
5454
{
5555
case 0x00:
@@ -86,13 +86,13 @@ protected override void HandleProxyConnect(IConnectionInfo connectionInfo, Socke
8686
SocketAbstraction.Send(socket, connectionRequest);
8787

8888
// Read Server SOCKS5 version
89-
if (SocketReadByte(socket) != 5)
89+
if (SocketReadByte(socket, connectionInfo.Timeout) != 5)
9090
{
9191
throw new ProxyException("SOCKS5: Version 5 is expected.");
9292
}
9393

9494
// Read response code
95-
var status = SocketReadByte(socket);
95+
var status = SocketReadByte(socket, connectionInfo.Timeout);
9696

9797
switch (status)
9898
{
@@ -119,21 +119,21 @@ protected override void HandleProxyConnect(IConnectionInfo connectionInfo, Socke
119119
}
120120

121121
// Read reserved byte
122-
if (SocketReadByte(socket) != 0)
122+
if (SocketReadByte(socket, connectionInfo.Timeout) != 0)
123123
{
124124
throw new ProxyException("SOCKS5: 0 byte is expected.");
125125
}
126126

127-
var addressType = SocketReadByte(socket);
127+
var addressType = SocketReadByte(socket, connectionInfo.Timeout);
128128
switch (addressType)
129129
{
130130
case 0x01:
131131
var ipv4 = new byte[4];
132-
_ = SocketRead(socket, ipv4, 0, 4);
132+
_ = SocketRead(socket, ipv4, 0, 4, connectionInfo.Timeout);
133133
break;
134134
case 0x04:
135135
var ipv6 = new byte[16];
136-
_ =SocketRead(socket, ipv6, 0, 16);
136+
_ =SocketRead(socket, ipv6, 0, 16, connectionInfo.Timeout);
137137
break;
138138
default:
139139
throw new ProxyException(string.Format("Address type '{0}' is not supported.", addressType));
@@ -142,7 +142,7 @@ protected override void HandleProxyConnect(IConnectionInfo connectionInfo, Socke
142142
var port = new byte[2];
143143

144144
// Read 2 bytes to be ignored
145-
_ = SocketRead(socket, port, 0, 2);
145+
_ = SocketRead(socket, port, 0, 2, connectionInfo.Timeout);
146146
}
147147

148148
/// <summary>
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Diagnostics;
4+
using System.Net;
5+
using System.Net.Sockets;
6+
7+
using Microsoft.VisualStudio.TestTools.UnitTesting;
8+
9+
using Moq;
10+
11+
using Renci.SshNet.Common;
12+
using Renci.SshNet.Tests.Common;
13+
14+
namespace Renci.SshNet.Tests.Classes.Connection
15+
{
16+
[TestClass]
17+
public class Socks5ConnectorTest_Connect_TimeoutConnectionReply : Socks5ConnectorTestBase
18+
{
19+
private ConnectionInfo _connectionInfo;
20+
private Exception _actualException;
21+
private AsyncSocketListener _proxyServer;
22+
private Socket _clientSocket;
23+
private List<byte> _bytesReceivedByProxy;
24+
private Stopwatch _stopWatch;
25+
26+
protected override void SetupData()
27+
{
28+
base.SetupData();
29+
30+
var random = new Random();
31+
32+
_connectionInfo = CreateConnectionInfo("proxyUser", "proxyPwd");
33+
_connectionInfo.Timeout = TimeSpan.FromMilliseconds(random.Next(50, 200));
34+
_stopWatch = new Stopwatch();
35+
_bytesReceivedByProxy = new List<byte>();
36+
37+
_clientSocket = SocketFactory.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
38+
39+
_proxyServer = new AsyncSocketListener(new IPEndPoint(IPAddress.Loopback, _connectionInfo.ProxyPort));
40+
_proxyServer.BytesReceived += (bytesReceived, socket) => {
41+
_bytesReceivedByProxy.AddRange(bytesReceived);
42+
43+
if (_bytesReceivedByProxy.Count == 4) {
44+
_ = socket.Send(new byte[]
45+
{
46+
// SOCKS version
47+
0x05,
48+
// Require no authentication
49+
0x00
50+
});
51+
}
52+
};
53+
_proxyServer.Start();
54+
}
55+
56+
protected override void SetupMocks()
57+
{
58+
_ = SocketFactoryMock.Setup(p => p.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
59+
.Returns(_clientSocket);
60+
}
61+
62+
protected override void TearDown()
63+
{
64+
base.TearDown();
65+
66+
_proxyServer?.Dispose();
67+
_clientSocket?.Dispose();
68+
}
69+
70+
protected override void Act()
71+
{
72+
_stopWatch.Start();
73+
74+
try
75+
{
76+
_ = Connector.Connect(_connectionInfo);
77+
Assert.Fail();
78+
}
79+
catch (SocketException ex) {
80+
_actualException = ex;
81+
}
82+
catch (SshOperationTimeoutException ex) {
83+
_actualException = ex;
84+
}
85+
finally
86+
{
87+
_stopWatch.Stop();
88+
}
89+
}
90+
91+
[TestMethod]
92+
public void ConnectShouldHaveThrownSshOperationTimeoutException() {
93+
Assert.IsNull(_actualException.InnerException);
94+
Assert.IsInstanceOfType<SshOperationTimeoutException>(_actualException);
95+
}
96+
97+
[TestMethod]
98+
public void ConnectShouldHaveRespectedTimeout()
99+
{
100+
var errorText = string.Format("Elapsed: {0}, Timeout: {1}",
101+
_stopWatch.ElapsedMilliseconds,
102+
_connectionInfo.Timeout.TotalMilliseconds);
103+
104+
// Compare elapsed time with configured timeout, allowing for a margin of error
105+
Assert.IsTrue(_stopWatch.ElapsedMilliseconds >= _connectionInfo.Timeout.TotalMilliseconds - 10, errorText);
106+
Assert.IsTrue(_stopWatch.ElapsedMilliseconds < _connectionInfo.Timeout.TotalMilliseconds + 100, errorText);
107+
}
108+
109+
[TestMethod]
110+
public void CreateOnSocketFactoryShouldHaveBeenInvokedOnce()
111+
{
112+
SocketFactoryMock.Verify(p => p.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp),
113+
Times.Once());
114+
}
115+
}
116+
}

0 commit comments

Comments
 (0)