Skip to content

Commit 8003a13

Browse files
committed
Handle edge-case in which we receive a partial WS header
me-no-dev#953
1 parent 8828f47 commit 8003a13

File tree

2 files changed

+481
-345
lines changed

2 files changed

+481
-345
lines changed

src/AsyncWebSocket.cpp

Lines changed: 162 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,8 @@ AsyncWebSocketClient::AsyncWebSocketClient(AsyncWebServerRequest *request, Async
483483
_clientId = _server->_getNextId();
484484
_status = WS_CONNECTED;
485485
_pstate = 0;
486+
_partialHeader = nullptr;
487+
_partialHeaderLen = 0;
486488
_lastMessageTime = millis();
487489
_keepAlivePeriod = 0;
488490
_client->setRxTimeout(0);
@@ -670,116 +672,197 @@ void AsyncWebSocketClient::_onData(void *pbuf, size_t plen)
670672
{
671673
if (!_pstate)
672674
{
673-
const uint8_t *fdata = data;
675+
ssize_t dataPayloadOffset = 0;
676+
const uint8_t *headerBuf = data;
677+
678+
// plen is backed up to initialPlen because, in case we receive a partial header, we would like to undo all of our
679+
// parsing and copy all of what we have of the header into a buffer for later use.
680+
// plen is modified during the parsing attempt, so if we don't back it up we won't know how much we need to copy.
681+
// partialHeaderLen is also backed up for the same reason.
682+
size_t initialPlen = plen;
683+
size_t partialHeaderLen = 0;
684+
685+
if (_partialHeaderLen > 0)
686+
{
687+
// We previously received a truncated header. Recover it by doing the following:
688+
// - Copy the new header chunk into the previous partial header, filling the buffer. It is allocated as a
689+
// buffer in a class field.
690+
// - Change *headerBuf to point to said buffer
691+
// - Update the length counters so that:
692+
// - The initialPlen and plen, which refer to the length of the remaining packet data, also accounts for the
693+
// previously received truncated header
694+
// - The dataPayloadOffset, which is the offset after the header at which the payload begins, so that it
695+
// refers to a point potentially before the beginning of the buffer. As we parse the header we increment it,
696+
// and we can pretty much guarantee it will go back to being positive unless there is a major bug.
697+
// - The class _partialHeaderLen is back to zero since we took ownership of the contained data.
698+
memcpy(_partialHeader + _partialHeaderLen, data,
699+
std::min(plen, (size_t)WS_MAX_HEADER_LEN - _partialHeaderLen));
700+
headerBuf = _partialHeader;
701+
initialPlen += _partialHeaderLen;
702+
plen += _partialHeaderLen;
703+
dataPayloadOffset -= _partialHeaderLen;
704+
partialHeaderLen = _partialHeaderLen;
705+
706+
_partialHeaderLen = 0;
707+
}
708+
709+
// The following series of gotos could have been a try-catch but we are likely being built with -fno-exceptions
710+
if (plen < 2)
711+
goto _exceptionHandleFailPartialHeader;
674712
_pinfo.index = 0;
675-
_pinfo.final = (fdata[0] & 0x80) != 0;
676-
_pinfo.opcode = fdata[0] & 0x0F;
677-
_pinfo.masked = (fdata[1] & 0x80) != 0;
678-
_pinfo.len = fdata[1] & 0x7F;
679-
data += 2;
713+
_pinfo.final = (headerBuf[0] & 0x80) != 0;
714+
_pinfo.opcode = headerBuf[0] & 0x0F;
715+
_pinfo.masked = (headerBuf[1] & 0x80) != 0;
716+
_pinfo.len = headerBuf[1] & 0x7F;
717+
dataPayloadOffset += 2;
680718
plen -= 2;
719+
681720
if (_pinfo.len == 126)
682721
{
683-
_pinfo.len = fdata[3] | (uint16_t)(fdata[2]) << 8;
684-
data += 2;
685-
plen -= 2;
686-
}
687-
else if (_pinfo.len == 127)
688-
{
689-
_pinfo.len = fdata[9] | (uint16_t)(fdata[8]) << 8 | (uint32_t)(fdata[7]) << 16 | (uint32_t)(fdata[6]) << 24 | (uint64_t)(fdata[5]) << 32 | (uint64_t)(fdata[4]) << 40 | (uint64_t)(fdata[3]) << 48 | (uint64_t)(fdata[2]) << 56;
690-
data += 8;
691-
plen -= 8;
692-
}
722+
if (plen < 2)
723+
goto _exceptionHandleFailPartialHeader;
693724

694-
if (_pinfo.masked)
695-
{
696-
memcpy(_pinfo.mask, data, 4);
697-
data += 4;
698-
plen -= 4;
725+
_pinfo.len = headerBuf[3] | (uint16_t)(headerBuf[2]) << 8;
726+
dataPayloadOffset += 2;
727+
plen -= 2;
699728
}
700729
}
701-
702-
const size_t datalen = std::min((size_t)(_pinfo.len - _pinfo.index), plen);
703-
const auto datalast = data[datalen];
730+
else if (_pinfo.len == 127)
731+
{
732+
if (plen < 8)
733+
goto _exceptionHandleFailPartialHeader;
734+
735+
_pinfo.len = headerBuf[9] | (uint16_t)(headerBuf[8]) << 8 | (uint32_t)(headerBuf[7]) << 16 |
736+
(uint32_t)(headerBuf[6]) << 24 | (uint64_t)(headerBuf[5]) << 32 | (uint64_t)(headerBuf[4]) << 40 |
737+
(uint64_t)(headerBuf[3]) << 48 | (uint64_t)(headerBuf[2]) << 56;
738+
dataPayloadOffset += 8;
739+
plen -= 8;
740+
}
704741

705742
if (_pinfo.masked)
706743
{
707-
for (size_t i = 0; i < datalen; i++)
708-
data[i] ^= _pinfo.mask[(_pinfo.index + i) % 4];
744+
if (plen < 4)
745+
goto _exceptionHandleFailPartialHeader;
746+
747+
memcpy(_pinfo.mask, headerBuf + dataPayloadOffset + partialHeaderLen, 4);
748+
dataPayloadOffset += 4;
749+
plen -= 4;
709750
}
710751

711-
if ((datalen + _pinfo.index) < _pinfo.len)
752+
// Yes I know the control flow here isn't 100% legible but we must support -fno-exceptions.
753+
// If we got to this point it means we did NOT receive a truncated header, therefore we can skip the exception
754+
// handling.
755+
// Control flow resumes after the following block.
756+
goto _headerParsingSuccessful;
757+
758+
// We DID receive a truncated header:
759+
// - We copy it to our buffer and set the _partialHeaderLen
760+
// - We return early
761+
// This will trigger the partial recovery at the next call of this method, once more data is received and we have
762+
// a full header.
763+
_exceptionHandleFailPartialHeader:
764+
{
765+
if (initialPlen <= WS_MAX_HEADER_LEN)
766+
{
767+
// If initialPlen > WS_MAX_HEADER_LEN there must be something wrong with this code. It should never happen but
768+
// but it's better safe than sorry.
769+
memcpy(_partialHeader, headerBuf, initialPlen * sizeof(uint8_t));
770+
_partialHeaderLen = initialPlen;
771+
}
772+
else
712773
{
713-
_pstate = 1;
774+
DEBUGF("[AsyncWebSocketClient::_onData] initialPlen (= %d) > WS_MAX_HEADER_LEN (= %d)\n", initialPlen,
775+
WS_MAX_HEADER_LEN);
776+
}
777+
return;
778+
}
779+
780+
_headerParsingSuccessful:
781+
782+
data += dataPayloadOffset;
783+
}
784+
785+
const size_t datalen = std::min((size_t)(_pinfo.len - _pinfo.index), plen);
786+
const auto datalast = data[datalen];
787+
788+
if (_pinfo.masked)
789+
{
790+
for (size_t i = 0; i < datalen; i++)
791+
data[i] ^= _pinfo.mask[(_pinfo.index + i) % 4];
792+
}
793+
794+
if ((datalen + _pinfo.index) < _pinfo.len)
795+
{
796+
_pstate = 1;
714797

715-
if (_pinfo.index == 0)
798+
if (_pinfo.index == 0)
799+
{
800+
if (_pinfo.opcode)
716801
{
717-
if (_pinfo.opcode)
718-
{
719-
_pinfo.message_opcode = _pinfo.opcode;
720-
_pinfo.num = 0;
721-
}
722-
else
723-
_pinfo.num += 1;
802+
_pinfo.message_opcode = _pinfo.opcode;
803+
_pinfo.num = 0;
724804
}
725-
_server->_handleEvent(this, WS_EVT_DATA, (void *)&_pinfo, (uint8_t *)data, datalen);
726-
727-
_pinfo.index += datalen;
805+
else
806+
_pinfo.num += 1;
728807
}
729-
else if ((datalen + _pinfo.index) == _pinfo.len)
808+
_server->_handleEvent(this, WS_EVT_DATA, (void *)&_pinfo, (uint8_t *)data, datalen);
809+
810+
_pinfo.index += datalen;
811+
}
812+
else if ((datalen + _pinfo.index) == _pinfo.len)
813+
{
814+
_pstate = 0;
815+
if (_pinfo.opcode == WS_DISCONNECT)
730816
{
731-
_pstate = 0;
732-
if (_pinfo.opcode == WS_DISCONNECT)
817+
if (datalen)
733818
{
734-
if (datalen)
735-
{
736-
uint16_t reasonCode = (uint16_t)(data[0] << 8) + data[1];
737-
char *reasonString = (char *)(data + 2);
738-
if (reasonCode > 1001)
739-
{
740-
_server->_handleEvent(this, WS_EVT_ERROR, (void *)&reasonCode, (uint8_t *)reasonString, strlen(reasonString));
741-
}
742-
}
743-
if (_status == WS_DISCONNECTING)
744-
{
745-
_status = WS_DISCONNECTED;
746-
_client->close(true);
747-
}
748-
else
819+
uint16_t reasonCode = (uint16_t)(data[0] << 8) + data[1];
820+
char *reasonString = (char *)(data + 2);
821+
if (reasonCode > 1001)
749822
{
750-
_status = WS_DISCONNECTING;
751-
_client->ackLater();
752-
_queueControl(new AsyncWebSocketControl(WS_DISCONNECT, data, datalen));
823+
_server->_handleEvent(this, WS_EVT_ERROR, (void *)&reasonCode, (uint8_t *)reasonString, strlen(reasonString));
753824
}
754825
}
755-
else if (_pinfo.opcode == WS_PING)
826+
if (_status == WS_DISCONNECTING)
756827
{
757-
_queueControl(new AsyncWebSocketControl(WS_PONG, data, datalen));
828+
_status = WS_DISCONNECTED;
829+
_client->close(true);
758830
}
759-
else if (_pinfo.opcode == WS_PONG)
831+
else
760832
{
761-
if (datalen != AWSC_PING_PAYLOAD_LEN || memcmp(AWSC_PING_PAYLOAD, data, AWSC_PING_PAYLOAD_LEN) != 0)
762-
_server->_handleEvent(this, WS_EVT_PONG, NULL, data, datalen);
763-
}
764-
else if (_pinfo.opcode < 8)
765-
{ // continuation or text/binary frame
766-
_server->_handleEvent(this, WS_EVT_DATA, (void *)&_pinfo, data, datalen);
833+
_status = WS_DISCONNECTING;
834+
_client->ackLater();
835+
_queueControl(new AsyncWebSocketControl(WS_DISCONNECT, data, datalen));
767836
}
768837
}
769-
else
838+
else if (_pinfo.opcode == WS_PING)
770839
{
771-
// os_printf("frame error: len: %u, index: %llu, total: %llu\n", datalen, _pinfo.index, _pinfo.len);
772-
// what should we do?
773-
break;
840+
_queueControl(new AsyncWebSocketControl(WS_PONG, data, datalen));
841+
}
842+
else if (_pinfo.opcode == WS_PONG)
843+
{
844+
if (datalen != AWSC_PING_PAYLOAD_LEN || memcmp(AWSC_PING_PAYLOAD, data, AWSC_PING_PAYLOAD_LEN) != 0)
845+
_server->_handleEvent(this, WS_EVT_PONG, NULL, data, datalen);
774846
}
847+
else if (_pinfo.opcode < 8)
848+
{ // continuation or text/binary frame
849+
_server->_handleEvent(this, WS_EVT_DATA, (void *)&_pinfo, data, datalen);
850+
}
851+
}
852+
else
853+
{
854+
// os_printf("frame error: len: %u, index: %llu, total: %llu\n", datalen, _pinfo.index, _pinfo.len);
855+
// what should we do?
856+
break;
857+
}
775858

776-
// restore byte as _handleEvent may have added a null terminator i.e., data[len] = 0;
777-
if (datalen > 0)
778-
data[datalen] = datalast;
859+
// restore byte as _handleEvent may have added a null terminator i.e., data[len] = 0;
860+
if (datalen > 0)
861+
data[datalen] = datalast;
779862

780-
data += datalen;
781-
plen -= datalen;
782-
}
863+
data += datalen;
864+
plen -= datalen;
865+
}
783866
}
784867

785868
size_t AsyncWebSocketClient::printf(const char *format, ...)

0 commit comments

Comments
 (0)