From ff2ab9ece3f7575b5019e24e4ad8c69440df1c48 Mon Sep 17 00:00:00 2001 From: Spartan322 Date: Sun, 27 Apr 2025 19:32:01 -0400 Subject: [PATCH 1/3] Add multiplayer abstractions Add multi-platform socket abstractions Add Udp and Tcp client/server socket handlers Add RingBuffer Add PacketReaderAdapter Add PacketBuilder Add byte Marshal encode and decode Add RingBuffer tests Add IpAddress tests Add UdpServer tests Add ReliableUdpServer tess Add TcpServer tests Add Marshal tests Add mas-bandwidth/reliable@57b0c90fa68598dbf41467ad28e9247cb08928c3 --- .gitmodules | 3 + deps/SCsub | 31 + deps/reliable | 1 + .../multiplayer/lowlevel/HostnameAddress.cpp | 30 + .../multiplayer/lowlevel/HostnameAddress.hpp | 19 + .../multiplayer/lowlevel/IpAddress.cpp | 18 + .../multiplayer/lowlevel/IpAddress.hpp | 415 ++++++++++ .../multiplayer/lowlevel/NetworkError.hpp | 32 + .../multiplayer/lowlevel/NetworkResolver.hpp | 19 + .../lowlevel/NetworkResolverBase.cpp | 277 +++++++ .../lowlevel/NetworkResolverBase.hpp | 92 +++ .../multiplayer/lowlevel/NetworkSocket.hpp | 19 + .../lowlevel/NetworkSocketBase.hpp | 75 ++ .../multiplayer/lowlevel/PacketBuilder.cpp | 3 + .../multiplayer/lowlevel/PacketBuilder.hpp | 28 + .../multiplayer/lowlevel/PacketClient.hpp | 46 ++ .../lowlevel/PacketReaderAdapter.cpp | 6 + .../lowlevel/PacketReaderAdapter.hpp | 70 ++ .../multiplayer/lowlevel/PacketServer.cpp | 76 ++ .../multiplayer/lowlevel/PacketServer.hpp | 36 + .../multiplayer/lowlevel/PacketStream.cpp | 110 +++ .../multiplayer/lowlevel/PacketStream.hpp | 59 ++ .../lowlevel/ReliableUdpClient.cpp | 256 ++++++ .../lowlevel/ReliableUdpClient.hpp | 112 +++ .../lowlevel/ReliableUdpServer.hpp | 47 ++ .../lowlevel/StreamPacketClient.cpp | 10 + .../lowlevel/StreamPacketClient.hpp | 173 +++++ .../multiplayer/lowlevel/TcpPacketStream.cpp | 316 ++++++++ .../multiplayer/lowlevel/TcpPacketStream.hpp | 69 ++ .../multiplayer/lowlevel/TcpServer.cpp | 79 ++ .../multiplayer/lowlevel/TcpServer.hpp | 21 + .../multiplayer/lowlevel/UdpClient.cpp | 329 ++++++++ .../multiplayer/lowlevel/UdpClient.hpp | 87 +++ .../multiplayer/lowlevel/UdpServer.cpp | 114 +++ .../multiplayer/lowlevel/UdpServer.hpp | 60 ++ .../lowlevel/unix/UnixNetworkResolver.cpp | 154 ++++ .../lowlevel/unix/UnixNetworkResolver.hpp | 32 + .../multiplayer/lowlevel/unix/UnixSocket.cpp | 564 ++++++++++++++ .../multiplayer/lowlevel/unix/UnixSocket.hpp | 70 ++ .../windows/WindowsNetworkResolver.cpp | 183 +++++ .../windows/WindowsNetworkResolver.hpp | 35 + .../lowlevel/windows/WindowsSocket.cpp | 583 ++++++++++++++ .../lowlevel/windows/WindowsSocket.hpp | 75 ++ src/openvic-simulation/types/RingBuffer.hpp | 729 ++++++++++++++++++ .../utility/ErrorMacros.hpp | 176 ++++- src/openvic-simulation/utility/Marshal.hpp | 574 ++++++++++++++ src/openvic-simulation/utility/Utility.hpp | 127 ++- tests/src/multiplayer/lowlevel/IpAddress.cpp | 229 ++++++ .../lowlevel/ReliableUdpServer.cpp | 261 +++++++ tests/src/multiplayer/lowlevel/TcpServer.cpp | 249 ++++++ tests/src/multiplayer/lowlevel/UdpServer.cpp | 213 +++++ tests/src/types/Colour.cpp | 13 + tests/src/types/Date.cpp | 11 + tests/src/types/FixedPoint.cpp | 14 + tests/src/types/RingBuffer.cpp | 204 +++++ tests/src/types/Timespan.cpp | 13 + tests/src/types/Vector2.cpp | 14 + tests/src/types/Vector3.cpp | 14 + tests/src/types/Vector4.cpp | 14 + tests/src/utility/Marshal.cpp | 195 +++++ 60 files changed, 7862 insertions(+), 22 deletions(-) create mode 160000 deps/reliable create mode 100644 src/openvic-simulation/multiplayer/lowlevel/HostnameAddress.cpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/HostnameAddress.hpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/IpAddress.cpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/IpAddress.hpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/NetworkError.hpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/NetworkResolver.hpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/NetworkResolverBase.cpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/NetworkResolverBase.hpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/NetworkSocket.hpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/NetworkSocketBase.hpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/PacketBuilder.cpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/PacketBuilder.hpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/PacketClient.hpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/PacketReaderAdapter.cpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/PacketReaderAdapter.hpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/PacketServer.cpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/PacketServer.hpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/PacketStream.cpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/PacketStream.hpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/ReliableUdpClient.cpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/ReliableUdpClient.hpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/ReliableUdpServer.hpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/StreamPacketClient.cpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/StreamPacketClient.hpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/TcpPacketStream.cpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/TcpPacketStream.hpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/TcpServer.cpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/TcpServer.hpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/UdpClient.cpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/UdpClient.hpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/UdpServer.cpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/UdpServer.hpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/unix/UnixNetworkResolver.cpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/unix/UnixNetworkResolver.hpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/unix/UnixSocket.cpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/unix/UnixSocket.hpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/windows/WindowsNetworkResolver.cpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/windows/WindowsNetworkResolver.hpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/windows/WindowsSocket.cpp create mode 100644 src/openvic-simulation/multiplayer/lowlevel/windows/WindowsSocket.hpp create mode 100644 src/openvic-simulation/types/RingBuffer.hpp create mode 100644 src/openvic-simulation/utility/Marshal.hpp create mode 100644 tests/src/multiplayer/lowlevel/IpAddress.cpp create mode 100644 tests/src/multiplayer/lowlevel/ReliableUdpServer.cpp create mode 100644 tests/src/multiplayer/lowlevel/TcpServer.cpp create mode 100644 tests/src/multiplayer/lowlevel/UdpServer.cpp create mode 100644 tests/src/types/RingBuffer.cpp create mode 100644 tests/src/utility/Marshal.cpp diff --git a/.gitmodules b/.gitmodules index 8ca8012ce..1a2ea72e6 100644 --- a/.gitmodules +++ b/.gitmodules @@ -30,3 +30,6 @@ path = deps/memory url = https://github.com/foonathan/memory ignore = dirty +[submodule "deps/reliable"] + path = deps/reliable + url = https://github.com/mas-bandwidth/reliable diff --git a/deps/SCsub b/deps/SCsub index 948ebe78d..6315e9647 100644 --- a/deps/SCsub +++ b/deps/SCsub @@ -111,6 +111,36 @@ def build_memory(env): env.exposed_includes += env.memory["INCPATH"] +def build_reliable(env): + import os + + if env["dev_build"]: + env.Append(CPPDEFINES=["RELIABLE_DEBUG"]) + else: + env.Append(CPPDEFINES=["RELIABLE_RELEASE"]) + reliable_env = env.Clone() + + include_path = "reliable" + include_dir = reliable_env.Dir(include_path) + sources = [os.path.join(include_path, "reliable.c")] + env.reliable_sources = sources + library_name = "libreliable" + env["LIBSUFFIX"] + library = reliable_env.StaticLibrary(target=os.path.join(include_path, library_name), source=sources) + Default(library) + + env.reliable = {} + env.reliable["INCPATH"] = [include_dir] + + env.Append(CPPPATH=env.reliable["INCPATH"]) + if env.get("is_msvc", False): + env.Append(CXXFLAGS=["/external:I", include_dir, "/external:W0"]) + else: + env.Append(CXXFLAGS=["-isystem", include_dir]) + env.Append(LIBPATH=include_dir) + env.Prepend(LIBS=[library_name]) + + env.exposed_includes += env.reliable["INCPATH"] + def link_tbb(env): import sys if not env.get("is_msvc", False) and not env.get("use_mingw", False) and sys.platform != "darwin": @@ -123,4 +153,5 @@ build_colony(env) build_function2(env) build_std_function(env) build_memory(env) +build_reliable(env) link_tbb(env) \ No newline at end of file diff --git a/deps/reliable b/deps/reliable new file mode 160000 index 000000000..57b0c90fa --- /dev/null +++ b/deps/reliable @@ -0,0 +1 @@ +Subproject commit 57b0c90fa68598dbf41467ad28e9247cb08928c3 diff --git a/src/openvic-simulation/multiplayer/lowlevel/HostnameAddress.cpp b/src/openvic-simulation/multiplayer/lowlevel/HostnameAddress.cpp new file mode 100644 index 000000000..27865b474 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/HostnameAddress.cpp @@ -0,0 +1,30 @@ + +#include "HostnameAddress.hpp" + +#include +#include + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkResolver.hpp" + +using namespace OpenVic; + +HostnameAddress::HostnameAddress() = default; + +HostnameAddress::HostnameAddress(IpAddress const& address) : _resolved_address(address) {} + +HostnameAddress::HostnameAddress(std::string_view name_or_address) : HostnameAddress() { + std::from_chars_result result = + _resolved_address.from_chars(name_or_address.data(), name_or_address.data() + name_or_address.size()); + if (result.ec != std::errc {}) { + _resolved_address = NetworkResolver::singleton().resolve_hostname(name_or_address); + } +} + +IpAddress const& HostnameAddress::resolved_address() const { + return _resolved_address; +} + +void HostnameAddress::set_resolved_address(IpAddress const& address) { + _resolved_address = address; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/HostnameAddress.hpp b/src/openvic-simulation/multiplayer/lowlevel/HostnameAddress.hpp new file mode 100644 index 000000000..a36aff7b9 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/HostnameAddress.hpp @@ -0,0 +1,19 @@ +#pragma once + +#include + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" + +namespace OpenVic { + struct HostnameAddress { + HostnameAddress(); + HostnameAddress(IpAddress const& address); + HostnameAddress(std::string_view name_or_address); + + IpAddress const& resolved_address() const; + void set_resolved_address(IpAddress const& address); + + private: + IpAddress _resolved_address; + }; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/IpAddress.cpp b/src/openvic-simulation/multiplayer/lowlevel/IpAddress.cpp new file mode 100644 index 000000000..92c4ac31b --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/IpAddress.cpp @@ -0,0 +1,18 @@ +#include "IpAddress.hpp" + +#include "openvic-simulation/utility/Containers.hpp" + +using namespace OpenVic; + +memory::string IpAddress::to_string(bool prefer_ipv4, to_chars_option option) const { + stack_string result = to_array(prefer_ipv4, option); + if (OV_unlikely(result.empty())) { + return {}; + } + + return result; +} + +IpAddress::operator memory::string() const { + return to_string(); +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/IpAddress.hpp b/src/openvic-simulation/multiplayer/lowlevel/IpAddress.hpp new file mode 100644 index 000000000..c4bfd4e3e --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/IpAddress.hpp @@ -0,0 +1,415 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "openvic-simulation/types/StackString.hpp" +#include "openvic-simulation/utility/Containers.hpp" +#include "openvic-simulation/utility/ErrorMacros.hpp" +#include "openvic-simulation/utility/StringUtils.hpp" + +namespace OpenVic { + struct IpAddress { + constexpr IpAddress() : _field {} { + clear(); + } + + constexpr IpAddress(char const* cstring) : IpAddress(std::string_view { cstring }) {} + + constexpr IpAddress(std::string_view string) : _field {} { + clear(); + from_chars(string.data(), string.data() + string.size()); + } + + constexpr IpAddress(uint32_t a, uint32_t b, uint32_t c, uint32_t d, bool is_ipv6 = false) : _field {} { + clear(); + _valid = true; + if (is_ipv6) { + auto _set_buffer = [](uint8_t* buffer, uint32_t value) constexpr { + buffer[0] = (value >> 24) & 0xff; + buffer[1] = (value >> 16) & 0xff; + buffer[2] = (value >> 8) & 0xff; + buffer[3] = (value >> 0) & 0xff; + }; + _set_buffer(&_field._8bit[0], a); + _set_buffer(&_field._8bit[4], b); + _set_buffer(&_field._8bit[8], c); + _set_buffer(&_field._8bit[12], d); + return; + } + if (std::is_constant_evaluated()) { + _field._8bit[10] = 0xff; + _field._8bit[11] = 0xff; + } else { + _field._16bit[5] = 0xffff; + } + _field._8bit[12] = a; + _field._8bit[13] = b; + _field._8bit[14] = c; + _field._8bit[15] = d; + } + + constexpr void clear() { + if (std::is_constant_evaluated()) { + std::fill_n(_field._8bit.data(), sizeof(_field._8bit), 0); + } else { + std::memset(_field._8bit.data(), 0, sizeof(_field._8bit)); + } + _wildcard = false; + _valid = false; + } + + constexpr bool is_wildcard() const { + return _wildcard; + } + + constexpr bool is_valid() const { + return _valid; + } + + constexpr bool is_ipv4() const { + if (std::is_constant_evaluated()) { + std::array casted_fields = std::bit_cast>(_field._8bit); + return casted_fields[0] == 0 && casted_fields[1] == 0 && casted_fields[2] == 0 && casted_fields[3] == 0 && + casted_fields[4] == 0 && casted_fields[5] == 0xffff; + } else { + return (_field._32bit[0] == 0 && _field._32bit[1] == 0 && _field._16bit[4] == 0 && _field._16bit[5] == 0xffff); + } + } + + constexpr std::span get_ipv4() const { + if (std::is_constant_evaluated()) { + return std::span { _field._8bit }.subspan<12, 4>(); + } else { + OV_ERR_FAIL_COND_V_MSG( + !is_ipv4(), (std::span { _field._8bit }.subspan<12, 4>()), + "IPv4 requested, but current IP is IPv6." + ); + } + return std::span { _field._8bit }.subspan<12, 4>(); + } + + constexpr void set_ipv4(std::span ip) { + clear(); + _valid = true; + if (std::is_constant_evaluated()) { + _field._8bit[10] = 0xff; + _field._8bit[11] = 0xff; + std::copy_n(ip.data(), 4, &_field._8bit[12]); + } else { + _field._16bit[5] = 0xffff; + _field._32bit[3] = *std::bit_cast(ip.data()); + } + } + + constexpr void set_ipv4(std::initializer_list list) { + set_ipv4(std::span { list }); + } + + constexpr std::span get_ipv6() const { + return _field._8bit; + } + + constexpr void set_ipv6(std::span ip) { + clear(); + _valid = true; + for (size_t i = 0; i < _field._8bit.size(); i++) { + _field._8bit[i] = ip[i]; + } + } + + constexpr void set_ipv6(std::initializer_list list) { + set_ipv6(std::span { list }); + } + + constexpr bool operator==(IpAddress const& ip) const { + if (ip._valid != _valid) { + return false; + } + if (!_valid) { + return false; + } + if (std::is_constant_evaluated()) { + for (int i = 0; i < _field._8bit.size(); i++) { + if (_field._8bit[i] != ip._field._8bit[i]) { + return false; + } + } + } else { + for (int i = 0; i < _field._32bit.size(); i++) { + if (_field._32bit[i] != ip._field._32bit[i]) { + return false; + } + } + } + return true; + } + + enum class to_chars_option : uint8_t { // + SHORTEN_IPV6, + COMPRESS_IPV6, + EXPAND_IPV6, + }; + + inline constexpr std::to_chars_result to_chars( // + char* first, char* last, bool prefer_ipv4 = true, to_chars_option option = to_chars_option::SHORTEN_IPV6 + ) const { + if (first == nullptr || first >= last) { + return { last, std::errc::value_too_large }; + } + + if (_wildcard) { + *first = '*'; + ++first; + return { last, std::errc {} }; + } + + if (!_valid) { + return { last, std::errc {} }; + } + + std::to_chars_result result = { first, std::errc {} }; + if (is_ipv4() && prefer_ipv4) { + for (size_t i = 12; i < 16; i++) { + if (i > 12) { + if (last - result.ptr <= 1) { + return { last, std::errc::value_too_large }; + } + + *result.ptr = '.'; + ++result.ptr; + } + + result = StringUtils::to_chars(result.ptr, last, _field._8bit[i]); + if (result.ec != std::errc {}) { + return result; + } + } + return result; + } + + auto section_func = [this](size_t index) constexpr -> uint16_t { + return (_field._8bit[index * 2] << 8) + _field._8bit[index * 2 + 1]; + }; + + int32_t compress_pos = -1; + if (option == to_chars_option::COMPRESS_IPV6) { + int32_t last_compress_count = 0; + for (size_t i = 0; i < 8; i++) { + if (_field._8bit[i * 2] == 0 && _field._8bit[i * 2 + 1] == 0 && section_func(i + 1) == 0) { + int32_t compress_check = i; + do { + ++i; + } while (i < 8 && section_func(i) == 0); + + if (int32_t compress_count = i - compress_check; + compress_count >= 2 && last_compress_count < compress_count) { + compress_pos = compress_check; + last_compress_count = compress_count; + } + } + } + } + + for (size_t i = 0; i < 8; i++) { + if (compress_pos > -1 && compress_pos == i) { + if (last - result.ptr <= 2) { + return { last, std::errc::value_too_large }; + } + + result.ptr[0] = ':'; + result.ptr[1] = ':'; + result.ptr += 2; + do { + ++i; + } while (i < 8 && section_func(i) == 0); + } else if (i > 0) { + *result.ptr = ':'; + ++result.ptr; + } + + uint16_t section = section_func(i); + if (option == to_chars_option::EXPAND_IPV6) { + if (last - result.ptr < 4) { + return { last, std::errc::value_too_large }; + } + + if (section < 0xFFF) { + *result.ptr = '0'; + ++result.ptr; + } + if (section < 0xFF) { + *result.ptr = '0'; + ++result.ptr; + } + if (section < 0xF) { + *result.ptr = '0'; + ++result.ptr; + } + if (section == 0) { + *result.ptr = '0'; + ++result.ptr; + continue; + } + } + result = StringUtils::to_chars(result.ptr, last, section, 16); + if (result.ec != std::errc {}) { + return result; + } + } + return result; + } + + struct stack_string; + inline constexpr stack_string to_array( // + bool prefer_ipv4 = true, to_chars_option option = to_chars_option::SHORTEN_IPV6 + ) const; + + struct stack_string final : StackString<39> { + protected: + using StackString::StackString; + friend inline constexpr stack_string IpAddress::to_array(bool prefer_ipv4, to_chars_option option) const; + }; + + memory::string to_string(bool prefer_ipv4 = true, to_chars_option option = to_chars_option::SHORTEN_IPV6) const; + explicit operator memory::string() const; + + constexpr std::from_chars_result from_chars(char const* begin, char const* end) { + if (begin == nullptr || begin >= end) { + return { begin, std::errc::invalid_argument }; + } + + if (*begin == '*') { + _wildcard = true; + _valid = false; + return { begin + 1, std::errc {} }; + } + + size_t check_for_ipv4 = 0; + for (char const* check = begin; check < end; check++) { + if (*check == '.') { + check_for_ipv4++; + } + } + + fields_type fields { ._8bit {} }; + + std::from_chars_result result = { begin, std::errc {} }; + if (check_for_ipv4 == 3) { + if (std::is_constant_evaluated()) { + fields._8bit[10] = 0xff; + fields._8bit[11] = 0xff; + } else { + fields._16bit[5] = 0xffff; + } + for (size_t i = 0; i < 4; i++) { + if (*result.ptr == '.') { + if (i > 0 && end - result.ptr >= 1) { + ++result.ptr; + } else { + return { begin, std::errc::invalid_argument }; + } + } + + result = StringUtils::from_chars(result.ptr, end, fields._8bit[12 + i]); + if (result.ec != std::errc {}) { + return result; + } + } + _field._8bit.swap(fields._8bit); + _valid = true; + return result; + } else if (check_for_ipv4 > 0) { + return { begin, std::errc::invalid_argument }; + } + + int32_t compress_start = -1; + size_t colon_count = 0; + if (std::is_constant_evaluated()) { + fields._16bit = std::bit_cast>(fields._8bit); + } + for (size_t i = 0; i < fields._16bit.size(); i++) { + if (*result.ptr == ':') { + if (end - result.ptr <= 2) { + return { result.ptr, std::errc::invalid_argument }; + } + + ++result.ptr; + if (*result.ptr == ':') { + if (compress_start > -1 || i >= fields._16bit.size() - 2) { + return { result.ptr, std::errc::invalid_argument }; + } + + ++result.ptr; + compress_start = i; + } else { + ++colon_count; + } + } + + uint16_t big_endian_value; + result = StringUtils::from_chars(result.ptr, end, big_endian_value, 16); + fields._16bit[i] = (big_endian_value >> 8) | (big_endian_value << 8); + if (result.ec != std::errc {}) { + return result; + } + } + + if (compress_start > -1) { + constexpr size_t expected_ipv6_colons = 7; + size_t index = compress_start - 1; + // Pre-compression + std::swap_ranges(_field._16bit.data(), _field._16bit.data() + index, fields._16bit.data()); + size_t compress_end = expected_ipv6_colons - colon_count - 1; + // Compression + if (std::is_constant_evaluated()) { + std::fill_n(_field._16bit.data() + compress_start, compress_end - 1, 0); + } else { + std::memset(_field._16bit.data() + compress_start, 0, compress_end - 1); + } + // Post-compression + std::swap_ranges( + _field._16bit.data() + compress_end, fields._16bit.data() + fields._16bit.size(), + fields._16bit.data() + index + ); + } else { + if (std::is_constant_evaluated()) { + _field._16bit = std::bit_cast>(_field._8bit); + } + _field._16bit.swap(fields._16bit); + } + + if (std::is_constant_evaluated()) { + _field._8bit = std::bit_cast>(_field._16bit); + } + + _valid = true; + return result; + } + + private: + union fields_type { + std::array _8bit; + std::array _16bit; + std::array _32bit; + } _field; + + bool _wildcard = false; + bool _valid = false; + }; + + inline constexpr IpAddress::stack_string IpAddress::to_array(bool pref_ipv4, to_chars_option option) const { + stack_string str {}; + std::to_chars_result result = to_chars(str._array.data(), str._array.data() + str._array.size(), pref_ipv4, option); + str._string_size = result.ptr - str.data(); + return str; + } +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/NetworkError.hpp b/src/openvic-simulation/multiplayer/lowlevel/NetworkError.hpp new file mode 100644 index 000000000..03fdf4b59 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/NetworkError.hpp @@ -0,0 +1,32 @@ +#pragma once + +#include + +namespace OpenVic { + enum class NetworkError : uint8_t { + OK, + FAILED, + UNAVAILABLE, + UNCONFIGURED, + UNAUTHORIZED, + OUT_OF_MEMORY, + LOCKED, + EMPTY_BUFFER, + BUG, + + ALREADY_OPEN, + NOT_OPEN, + INVALID_PARAMETER, + SOCKET_ERROR, + BROADCAST_CHANGE_FAILED, + UNSUPPORTED, + BUSY, + + WOULD_BLOCK, + IS_CONNECTED, + IN_PROGRESS, + ADDRESS_INVALID_OR_UNAVAILABLE, + BUFFER_TOO_SMALL, + OTHER, + }; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/NetworkResolver.hpp b/src/openvic-simulation/multiplayer/lowlevel/NetworkResolver.hpp new file mode 100644 index 000000000..493285268 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/NetworkResolver.hpp @@ -0,0 +1,19 @@ +#pragma once + +#ifdef _WIN32 +#include "openvic-simulation/multiplayer/lowlevel/windows/WindowsNetworkResolver.hpp" +#elif defined(__unix__) || defined(__unix) || (defined(__APPLE__) && defined(__MACH__)) +#include "openvic-simulation/multiplayer/lowlevel/unix/UnixNetworkResolver.hpp" +#else +#error "NetworkResolver.hpp only supports unix or windows systems" +#endif + +namespace OpenVic { + using NetworkResolver = +#ifdef _WIN32 + WindowsNetworkResolver +#else + UnixNetworkResolver +#endif + ; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/NetworkResolverBase.cpp b/src/openvic-simulation/multiplayer/lowlevel/NetworkResolverBase.cpp new file mode 100644 index 000000000..2818af7e9 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/NetworkResolverBase.cpp @@ -0,0 +1,277 @@ +#include "NetworkResolverBase.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkResolver.hpp" +#include "openvic-simulation/utility/Containers.hpp" +#include "openvic-simulation/utility/StringUtils.hpp" + +using namespace OpenVic; + +static memory::string _get_cache_key(std::string_view p_hostname, NetworkResolverBase::Type p_type) { + memory::string key { p_hostname }; + key.insert(key.begin(), '0'); + key.reserve(p_hostname.size() + 1); + + std::to_chars_result result = StringUtils::to_chars(key.data(), key.data() + 1, static_cast(p_type)); + if (result.ec != std::errc {}) { + key = memory::string { p_hostname }; + } + + return key; +} + +NetworkResolverBase::ResolveHandler::Item::Item() : type { Type::NONE } { + clear(); +} + +void NetworkResolverBase::ResolveHandler::Item::clear() { + status.store(static_cast(ResolverStatus::NONE)); + response.clear(); + type = Type::NONE; + hostname.clear(); +} + +int32_t NetworkResolverBase::ResolveHandler::find_empty_id() const { + for (int i = 0; i < MAX_QUERIES; i++) { + if (queue[i].status.load() == static_cast(ResolverStatus::NONE)) { + return i; + } + } + return INVALID_ID; +} + +void NetworkResolverBase::ResolveHandler::resolve_queues() { + for (int i = 0; i < MAX_QUERIES; i++) { + if (queue[i].status.load() != static_cast(ResolverStatus::WAITING)) { + continue; + } + + std::unique_lock lock { mutex }; + memory::string hostname = queue[i].hostname; + Type type = queue[i].type; + + lock.unlock(); + + // We should not lock while resolving the hostname, + // only when modifying the queue. + memory::vector response = NetworkResolver::singleton()._resolve_hostname(hostname, type); + + lock.lock(); + // Could have been completed by another function, or deleted. + if (queue[i].status.load() != static_cast(ResolverStatus::WAITING)) { + continue; + } + // We might be overriding another result, but we don't care as long as the result is valid. + if (response.size()) { + memory::string key = _get_cache_key(hostname, type); + cache[key] = response; + } + queue[i].response = response; + queue[i].status.store(static_cast(response.empty() ? ResolverStatus::ERROR : ResolverStatus::DONE)); + } +} + +void NetworkResolverBase::ResolveHandler::_thread_runner(ResolveHandler& handler) { + while (!handler.should_abort.load(std::memory_order_acquire)) { + handler.semaphore.acquire(); + handler.resolve_queues(); + } +} + +NetworkResolverBase::ResolveHandler::ResolveHandler() : thread { _thread_runner, std::ref(*this) }, semaphore { 0 } { + should_abort.store(false, std::memory_order_release); +} + +NetworkResolverBase::ResolveHandler::~ResolveHandler() { + should_abort.store(true, std::memory_order_release); + semaphore.release(); + thread.join(); +} + +IpAddress NetworkResolverBase::resolve_hostname(std::string_view p_hostname, Type p_type) { + const memory::vector addresses = resolve_hostname_addresses(p_hostname, p_type); + return addresses.empty() ? IpAddress {} : addresses.front(); +} + +memory::vector NetworkResolverBase::resolve_hostname_addresses(std::string_view p_hostname, Type p_type) { + memory::string key = _get_cache_key(p_hostname, p_type); + { + std::unique_lock lock(_resolve_handler.mutex); + if (_resolve_handler.cache.contains(key)) { + return _resolve_handler.cache[key]; + } else { + // This should be run unlocked so the resolver thread can keep resolving + // other requests. + lock.unlock(); + memory::vector result = _resolve_hostname(p_hostname, p_type); + lock.lock(); + // We might be overriding another result, but we don't care as long as the result is valid. + if (result.size()) { + _resolve_handler.cache[key] = result; + } + return result; + } + } +} + +int32_t NetworkResolverBase::resolve_hostname_queue_item(std::string_view p_hostname, Type p_type) { + std::unique_lock lock(_resolve_handler.mutex); + + int32_t id = _resolve_handler.find_empty_id(); + + if (id == INVALID_ID) { + Logger::warning("Out of resolver queries"); + return id; + } + + memory::string key = _get_cache_key(p_hostname, p_type); + _resolve_handler.queue[id].hostname = p_hostname; + _resolve_handler.queue[id].type = p_type; + if (_resolve_handler.cache.contains(key)) { + _resolve_handler.queue[id].response = _resolve_handler.cache[key]; + _resolve_handler.queue[id].status.store(static_cast(ResolverStatus::DONE)); + } else { + _resolve_handler.queue[id].response = memory::vector(); + _resolve_handler.queue[id].status.store(static_cast(ResolverStatus::WAITING)); + if (_resolve_handler.thread.joinable()) { + _resolve_handler.semaphore.release(); + } else { + _resolve_handler.resolve_queues(); + } + } + + return id; +} + +NetworkResolverBase::ResolverStatus NetworkResolverBase::get_resolve_item_status(int32_t p_id) const { + OV_ERR_FAIL_INDEX_V_MSG( + p_id, MAX_QUERIES, ResolverStatus::NONE, + fmt::format( + "Too many concurrent DNS resolver queries ({}, but should be {} at most). Try performing less network requests at " + "once.", + p_id, MAX_QUERIES + ) + ); + + ResolverStatus result = static_cast(_resolve_handler.queue[p_id].status.load()); + if (result == ResolverStatus::NONE) { + Logger::error("Condition status == " _OV_STR(ResolverStatus::NONE)); + return ResolverStatus::NONE; + } + return result; +} + +memory::vector NetworkResolverBase::get_resolve_item_addresses(int32_t p_id) const { + OV_ERR_FAIL_INDEX_V_MSG( + p_id, MAX_QUERIES, memory::vector {}, + fmt::format( + "Too many concurrent DNS resolver queries ({}, but should be {} at most). Try performing less network requests at " + "once.", + p_id, MAX_QUERIES + ) + ); + + std::unique_lock lock(const_cast(_resolve_handler.mutex)); + + if (_resolve_handler.queue[p_id].status.load() != static_cast(ResolverStatus::DONE)) { + Logger::error(fmt::format("Resolve of '{}' didn't complete yet.", _resolve_handler.queue[p_id].hostname)); + return {}; + } + + return _resolve_handler.queue[p_id].response | ranges::views::filter([](IpAddress const& addr) -> bool { + return addr.is_valid(); + }) | + ranges::to>(); +} + +IpAddress NetworkResolverBase::get_resolve_item_address(int32_t p_id) const { + OV_ERR_FAIL_INDEX_V_MSG( + p_id, MAX_QUERIES, IpAddress {}, + fmt::format( + "Too many concurrent DNS resolver queries ({}, but should be {} at most). Try performing less network requests at " + "once.", + p_id, MAX_QUERIES + ) + ); + + std::unique_lock lock(const_cast(_resolve_handler.mutex)); + + if (_resolve_handler.queue[p_id].status.load() != static_cast(ResolverStatus::DONE)) { + Logger::error(fmt::format("Resolve of '{}' didn't complete yet.", _resolve_handler.queue[p_id].hostname)); + return {}; + } + + for (IpAddress const& address : _resolve_handler.queue[p_id].response) { + if (address.is_valid()) { + return address; + } + } + return IpAddress {}; +} + +void NetworkResolverBase::erase_resolve_item(int32_t p_id) { + OV_ERR_FAIL_INDEX_MSG( + p_id, MAX_QUERIES, + fmt::format( + "Too many concurrent DNS resolver queries ({}, but should be {} at most). Try performing less network requests at " + "once.", + p_id, MAX_QUERIES + ) + ); + + _resolve_handler.queue[p_id].status.store(static_cast(ResolverStatus::NONE)); +} +void NetworkResolverBase::clear_cache(std::string_view p_hostname) { + std::unique_lock lock(_resolve_handler.mutex); + + if (p_hostname.empty()) { + _resolve_handler.cache.clear(); + } else { + _resolve_handler.cache.erase(_get_cache_key(p_hostname, Type::NONE)); + _resolve_handler.cache.erase(_get_cache_key(p_hostname, Type::IPV4)); + _resolve_handler.cache.erase(_get_cache_key(p_hostname, Type::IPV6)); + _resolve_handler.cache.erase(_get_cache_key(p_hostname, Type::ANY)); + } +} + +memory::vector NetworkResolverBase::get_local_addresses() const { + OpenVic::string_map_t interfaces = get_local_interfaces(); + if (interfaces.empty()) { + return {}; + } + + memory::vector result; + size_t reserve_size = 0; + + for (std::pair const& pair : interfaces | ranges::views::reverse) { + reserve_size += pair.second.ip_addresses.size(); + } + if (reserve_size == 0) { + return {}; + } + + result.reserve(reserve_size); + for (std::pair const& pair : interfaces | ranges::views::reverse) { + for (IpAddress const& address : pair.second.ip_addresses | ranges::views::reverse) { + result.push_back(address); + } + } + return result; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/NetworkResolverBase.hpp b/src/openvic-simulation/multiplayer/lowlevel/NetworkResolverBase.hpp new file mode 100644 index 000000000..bde116518 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/NetworkResolverBase.hpp @@ -0,0 +1,92 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/types/OrderedContainers.hpp" +#include "openvic-simulation/utility/Containers.hpp" + +namespace OpenVic { + struct NetworkResolverBase { + static constexpr size_t MAX_QUERIES = 256; + static constexpr int32_t INVALID_ID = -1; + + enum class Provider : uint8_t { + UNIX, + WINDOWS, + }; + + enum class ResolverStatus : uint8_t { + NONE, + WAITING, + DONE, + ERROR, + }; + + enum class Type : uint8_t { + NONE, + IPV4, + IPV6, + ANY, + }; + + struct InterfaceInfo { + memory::string name; + memory::string name_friendly; + memory::string index; + memory::vector ip_addresses; + }; + + IpAddress resolve_hostname(std::string_view p_hostname, Type p_type = Type::ANY); + memory::vector resolve_hostname_addresses(std::string_view p_hostname, Type p_type = Type::ANY); + // async resolver hostname + int32_t resolve_hostname_queue_item(std::string_view p_hostname, Type p_type = Type::ANY); + ResolverStatus get_resolve_item_status(int32_t p_id) const; + memory::vector get_resolve_item_addresses(int32_t p_id) const; + IpAddress get_resolve_item_address(int32_t p_id) const; + + void erase_resolve_item(int32_t p_id); + + void clear_cache(std::string_view p_hostname = ""); + + memory::vector get_local_addresses() const; + virtual OpenVic::string_map_t get_local_interfaces() const = 0; + + virtual Provider provider() const = 0; + + protected: + virtual memory::vector _resolve_hostname(std::string_view p_hostname, Type p_type = Type::ANY) const = 0; + + struct ResolveHandler { + int32_t find_empty_id() const; + void resolve_queues(); + + struct Item { + Item(); + void clear(); + + memory::vector response; + memory::string hostname; + std::atomic_int32_t status; + Type type; + }; + + std::array queue; + string_map_t> cache; + std::mutex mutex; + std::thread thread; + std::counting_semaphore<> semaphore; + std::atomic_bool should_abort; + + static void _thread_runner(ResolveHandler& handler); + + ResolveHandler(); + ~ResolveHandler(); + } _resolve_handler; + }; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/NetworkSocket.hpp b/src/openvic-simulation/multiplayer/lowlevel/NetworkSocket.hpp new file mode 100644 index 000000000..5122a53e7 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/NetworkSocket.hpp @@ -0,0 +1,19 @@ +#pragma once + +#ifdef _WIN32 +#include "openvic-simulation/multiplayer/lowlevel/windows/WindowsSocket.hpp" +#elif defined(__unix__) || defined(__unix) || (defined(__APPLE__) && defined(__MACH__)) +#include "openvic-simulation/multiplayer/lowlevel/unix/UnixSocket.hpp" +#else +#error "NetworkSocket.hpp only supports unix or windows systems" +#endif + +namespace OpenVic { + using NetworkSocket = +#ifdef _WIN32 + WindowsSocket +#else + UnixSocket +#endif + ; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/NetworkSocketBase.hpp b/src/openvic-simulation/multiplayer/lowlevel/NetworkSocketBase.hpp new file mode 100644 index 000000000..6490745e7 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/NetworkSocketBase.hpp @@ -0,0 +1,75 @@ +#pragma once + +#include +#include +#include + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkResolver.hpp" +#include "openvic-simulation/types/EnumBitfield.hpp" +#include "openvic-simulation/utility/Containers.hpp" + +namespace OpenVic { + struct NetworkSocketBase { + enum class PollType : uint8_t { + IN = 1 << 0, + OUT = 1 << 1, + IN_OUT = IN | OUT, + }; + + enum class Type : uint8_t { + NONE, + TCP, + UDP, + }; + + using port_type = uint16_t; + + virtual ~NetworkSocketBase() = default; + + virtual NetworkError open(Type p_type, NetworkResolver::Type& ip_type) = 0; + virtual void close() = 0; + virtual NetworkError bind(IpAddress const& p_addr, port_type p_port) = 0; + virtual NetworkError listen(int p_max_pending) = 0; + virtual NetworkError connect_to_host(IpAddress const& p_addr, port_type p_port) = 0; + virtual NetworkError poll(PollType p_type, int timeout) const = 0; + virtual NetworkError receive(uint8_t* p_buffer, size_t p_len, int64_t& r_read) = 0; + virtual NetworkError receive_from( // + uint8_t* p_buffer, size_t p_len, int64_t& r_read, IpAddress& r_ip, port_type& r_port, bool p_peek = false + ) = 0; + virtual NetworkError send(const uint8_t* p_buffer, size_t p_len, int64_t& r_sent) = 0; + virtual NetworkError send_to( // + const uint8_t* p_buffer, size_t p_len, int64_t& r_sent, IpAddress const& p_ip, port_type p_port + ) = 0; + + memory::unique_base_ptr accept(IpAddress& r_ip, port_type& r_port) { + return memory::unique_base_ptr { static_cast(_accept(r_ip, r_port)) }; + } + + template T> + memory::unique_ptr accept_as(IpAddress& r_ip, port_type& r_port) { + return memory::unique_ptr { static_cast(_accept(r_ip, r_port)) }; + } + + virtual bool is_open() const = 0; + virtual int available_bytes() const = 0; + virtual NetworkError get_socket_address(IpAddress* r_ip, uint16_t* r_port) const = 0; + + // Returns OK if the socket option has been set successfully + virtual NetworkError set_broadcasting_enabled(bool p_enabled) = 0; + virtual void set_blocking_enabled(bool p_enabled) = 0; + virtual void set_ipv6_only_enabled(bool p_enabled) = 0; + virtual void set_tcp_no_delay_enabled(bool p_enabled) = 0; + virtual void set_reuse_address_enabled(bool p_enabled) = 0; + virtual NetworkError change_multicast_group(IpAddress const& p_multi_address, std::string_view p_if_name, bool add) = 0; + + virtual NetworkResolver::Provider provider() const = 0; + + protected: + virtual NetworkSocketBase* _accept(IpAddress& r_ip, port_type& r_port) = 0; + }; + + template<> + struct enable_bitfield : std::true_type {}; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/PacketBuilder.cpp b/src/openvic-simulation/multiplayer/lowlevel/PacketBuilder.cpp new file mode 100644 index 000000000..18dab291d --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/PacketBuilder.cpp @@ -0,0 +1,3 @@ +#include "PacketBuilder.hpp" + +template struct OpenVic::PacketBuilder<>; \ No newline at end of file diff --git a/src/openvic-simulation/multiplayer/lowlevel/PacketBuilder.hpp b/src/openvic-simulation/multiplayer/lowlevel/PacketBuilder.hpp new file mode 100644 index 000000000..55b185601 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/PacketBuilder.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include +#include +#include +#include + +#include "openvic-simulation/utility/Containers.hpp" +#include "openvic-simulation/utility/Marshal.hpp" + +namespace OpenVic { + template + struct PacketBuilder : memory::vector { + using base_type = memory::vector; + using base_type::base_type; + + template + requires requires(T const& value) { utility::encode(value, std::span {}); } + void put_back(T const& value) { + size_t value_size = utility::encode(value); + size_t last = size(); + resize(size() + value_size); + utility::encode(value, std::span { data() + last, value_size }); + } + }; + + extern template struct PacketBuilder<>; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/PacketClient.hpp b/src/openvic-simulation/multiplayer/lowlevel/PacketClient.hpp new file mode 100644 index 000000000..7c0a36772 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/PacketClient.hpp @@ -0,0 +1,46 @@ +#pragma once + +#include +#include + +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/PacketReaderAdapter.hpp" + +namespace OpenVic { + struct PacketClient { + virtual int64_t available_packets() const = 0; + virtual int64_t max_packet_size() const = 0; + + std::span get_packet() { + std::span buffer; + if (NetworkError error = _get_packet(buffer); error != NetworkError::OK) { + _last_error = error; + return {}; + } + return buffer; + } + + NetworkError set_packet(std::span buffer) { + NetworkError error = _set_packet(buffer); + if (error != NetworkError::OK) { + _last_error = error; + } + return error; + } + + NetworkError get_last_error() const { + return _last_error; + } + + PacketSpan packet_span() { + std::span span = get_packet(); + return { span.data(), span.size() }; + } + + protected: + virtual NetworkError _set_packet(std::span buffer) = 0; + virtual NetworkError _get_packet(std::span& r_set) = 0; + + mutable NetworkError _last_error = NetworkError::OK; + }; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/PacketReaderAdapter.cpp b/src/openvic-simulation/multiplayer/lowlevel/PacketReaderAdapter.cpp new file mode 100644 index 000000000..cd9d853de --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/PacketReaderAdapter.cpp @@ -0,0 +1,6 @@ +#include "PacketReaderAdapter.hpp" + +#include "openvic-simulation/utility/Containers.hpp" + +template struct OpenVic::PacketReaderAdapter>; +template struct OpenVic::PacketReaderAdapter>; diff --git a/src/openvic-simulation/multiplayer/lowlevel/PacketReaderAdapter.hpp b/src/openvic-simulation/multiplayer/lowlevel/PacketReaderAdapter.hpp new file mode 100644 index 000000000..98e196666 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/PacketReaderAdapter.hpp @@ -0,0 +1,70 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "openvic-simulation/utility/Containers.hpp" +#include "openvic-simulation/utility/ErrorMacros.hpp" +#include "openvic-simulation/utility/Marshal.hpp" + +namespace OpenVic { + template + requires requires(Container& container) { + { container.data() } -> std::same_as; + { container.size() } -> std::same_as; + } + struct PacketReaderAdapter : Container { + using base_type = Container; + using base_type::base_type; + + static constexpr std::integral_constant::max()> npos {}; + + size_t index() const { + return _read_index; + } + + void seek(size_t index) { + if (index == npos) { + _read_index = npos; + } + OV_ERR_FAIL_INDEX(index, this->size()); + _read_index = index; + } + + template + requires requires(PacketReaderAdapter* self, size_t& size) { + { + utility::decode( + std::span { self->data() + self->_read_index, self->size() - self->_read_index }, size + ) + } -> std::same_as>; + } + T read() { + size_t decode_size = 0; + T result = utility::decode( + std::span { this->data() + _read_index, this->size() - _read_index }, decode_size + ); + OV_ERR_FAIL_COND_V(decode_size > this->size(), {}); + _read_index += decode_size; + return result; + } + + private: + size_t _read_index = 0; + }; + + template + using PacketBufferEndian = PacketReaderAdapter, Endian>; + template + using PacketSpanEndian = PacketReaderAdapter, Endian>; + + using PacketBuffer = PacketBufferEndian<>; + using PacketSpan = PacketSpanEndian<>; + + extern template struct PacketReaderAdapter>; + extern template struct PacketReaderAdapter>; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/PacketServer.cpp b/src/openvic-simulation/multiplayer/lowlevel/PacketServer.cpp new file mode 100644 index 000000000..a542b0c38 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/PacketServer.cpp @@ -0,0 +1,76 @@ +#include "PacketServer.hpp" + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkResolver.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkSocket.hpp" +#include "openvic-simulation/multiplayer/lowlevel/PacketClient.hpp" +#include "openvic-simulation/utility/Containers.hpp" +#include "openvic-simulation/utility/ErrorMacros.hpp" + +using namespace OpenVic; + +PacketServer::PacketServer() : _socket { memory::make_shared() } {} + +PacketServer::~PacketServer() { + close(); +} + +NetworkSocket::port_type PacketServer::get_listening_port() const { + NetworkSocket::port_type local_port; + _socket->get_socket_address(nullptr, &local_port); + return local_port; +} + +bool PacketServer::is_connection_available() const { + OV_ERR_FAIL_COND_V(_socket == nullptr, false); + return _socket->is_open(); +} + +bool PacketServer::is_listening() const { + OV_ERR_FAIL_COND_V(_socket == nullptr, false); + return _socket->is_open(); +} + +NetworkError PacketServer::listen_to(NetworkSocket::port_type port, IpAddress const& bind_addr) { + OV_ERR_FAIL_COND_V(_socket == nullptr, NetworkError::UNAVAILABLE); + OV_ERR_FAIL_COND_V(_socket->is_open(), NetworkError::ALREADY_OPEN); + OV_ERR_FAIL_COND_V(!bind_addr.is_valid() && !bind_addr.is_wildcard(), NetworkError::INVALID_PARAMETER); + + NetworkError err; + NetworkResolver::Type ip_type = NetworkResolver::Type::ANY; + + if (bind_addr.is_valid()) { + ip_type = bind_addr.is_ipv4() ? NetworkResolver::Type::IPV4 : NetworkResolver::Type::IPV6; + } + + err = _socket->open(NetworkSocket::Type::UDP, ip_type); + + if (err != NetworkError::OK) { + return err; + } + + _socket->set_blocking_enabled(false); + _socket->set_reuse_address_enabled(true); + err = _socket->bind(bind_addr, port); + + if (err != NetworkError::OK) { + close(); + return err; + } + return NetworkError::OK; +} + +void PacketServer::close() { + if (_socket != nullptr) { + _socket->close(); + } +} + +NetworkError PacketServer::poll() { + return NetworkError::OK; +} + +memory::unique_base_ptr PacketServer::take_next_client() { + return memory::unique_base_ptr { static_cast(_take_next_client()) }; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/PacketServer.hpp b/src/openvic-simulation/multiplayer/lowlevel/PacketServer.hpp new file mode 100644 index 000000000..614578161 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/PacketServer.hpp @@ -0,0 +1,36 @@ +#pragma once + +#include +#include + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkSocket.hpp" +#include "openvic-simulation/multiplayer/lowlevel/PacketClient.hpp" +#include "openvic-simulation/utility/Containers.hpp" + +namespace OpenVic { + struct PacketServer { + PacketServer(); + ~PacketServer(); + + NetworkSocket::port_type get_listening_port() const; + bool is_listening() const; + virtual bool is_connection_available() const; + virtual NetworkError listen_to(NetworkSocket::port_type port, IpAddress const& bind_addr = "*"); + virtual void close(); + virtual NetworkError poll(); + + memory::unique_base_ptr take_next_client(); + + template T> + memory::unique_ptr take_next_client_as() { + return memory::unique_ptr { static_cast(_take_next_client()) }; + } + + protected: + virtual PacketClient* _take_next_client() = 0; + + std::shared_ptr _socket; + }; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/PacketStream.cpp b/src/openvic-simulation/multiplayer/lowlevel/PacketStream.cpp new file mode 100644 index 000000000..29a7242c7 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/PacketStream.cpp @@ -0,0 +1,110 @@ + +#include "PacketStream.hpp" + +#include +#include +#include + +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/utility/ErrorMacros.hpp" + +using namespace OpenVic; + +NetworkError BufferPacketStream::set_data(std::span buffer) { + if (buffer.empty()) { + return NetworkError::OK; + } + + if (_position + buffer.size() > _data.size()) { + _data.resize(_position + buffer.size()); + } + + std::memcpy(&_data[_position], buffer.data(), buffer.size()); + + _position += buffer.size(); + return NetworkError::OK; +} + +NetworkError BufferPacketStream::get_data(std::span buffer_to_set) { + size_t received; + get_data_no_blocking(buffer_to_set, received); + if (received != buffer_to_set.size()) { + return NetworkError::INVALID_PARAMETER; + } + + return NetworkError::OK; +} + +NetworkError BufferPacketStream::set_data_no_blocking(std::span buffer, size_t& r_sent) { + r_sent = buffer.size(); + return set_data(buffer); +} + +NetworkError BufferPacketStream::get_data_no_blocking(std::span buffer_to_set, size_t& r_received) { + if (buffer_to_set.empty()) { + r_received = 0; + return NetworkError::OK; + } + + if (_position + buffer_to_set.size() > _data.size()) { + r_received = _data.size() - _position; + if (r_received <= 0) { + r_received = 0; + return NetworkError::OK; // you got 0 + } + } else { + r_received = buffer_to_set.size(); + } + + std::memcpy(buffer_to_set.data(), &_data[_position], r_received); + + _position += r_received; + // FIXME: return what? OK or ERR_* + // return OK for now so we don't maybe return garbage + return NetworkError::OK; +} + +int64_t BufferPacketStream::available_bytes() const { + return _data.size() - _position; +} + +void BufferPacketStream::seek(size_t p_pos) { + OV_ERR_FAIL_COND(p_pos > _data.size()); + _position = p_pos; +} + +size_t BufferPacketStream::size() const { + return _data.size(); +} + +size_t BufferPacketStream::position() const { + return _position; +} + +void BufferPacketStream::resize(size_t p_size) { + _data.resize(p_size); +} + +void BufferPacketStream::set_buffer_data(std::span p_data) { + if (p_data.size() > _data.size()) { + _data.resize(p_data.size()); + } + + std::uninitialized_copy(p_data.begin(), p_data.end(), _data.begin()); +} + +std::span BufferPacketStream::get_buffer_data() const { + return _data; +} + +void BufferPacketStream::clear() { + _data.clear(); + _position = 0; +} + +BufferPacketStream BufferPacketStream::duplicate() const { + BufferPacketStream result; + result._data.resize(_data.size()); + std::uninitialized_copy(_data.begin(), _data.end(), result._data.begin()); + return result; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/PacketStream.hpp b/src/openvic-simulation/multiplayer/lowlevel/PacketStream.hpp new file mode 100644 index 000000000..1d2af1d91 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/PacketStream.hpp @@ -0,0 +1,59 @@ +#pragma once + +#include +#include +#include +#include + +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/PacketReaderAdapter.hpp" +#include "openvic-simulation/utility/Containers.hpp" + +namespace OpenVic { + struct PacketStream { + virtual NetworkError set_data(std::span buffer) = 0; + virtual NetworkError get_data(std::span buffer_to_set) = 0; + virtual NetworkError set_data_no_blocking(std::span buffer, size_t& r_sent) = 0; + virtual NetworkError get_data_no_blocking(std::span buffer_to_set, size_t& r_received) = 0; + virtual int64_t available_bytes() const = 0; + + template + PacketBufferEndian packet_buffer(size_t bytes) { + PacketBufferEndian result; + result.resize(bytes); + return packet_buffer(result); + } + + template + PacketBufferEndian packet_buffer(PacketBufferEndian& buffer_store) { + get_data(buffer_store); + return buffer_store; + } + }; + + struct BufferPacketStream : PacketStream { + NetworkError set_data(std::span buffer) override; + NetworkError get_data(std::span r_buffer) override; + + NetworkError set_data_no_blocking(std::span buffer, size_t& r_sent) override; + NetworkError get_data_no_blocking(std::span buffer_to_set, size_t& r_received) override; + + virtual int64_t available_bytes() const override; + + void seek(size_t p_pos); + size_t size() const; + size_t position() const; + void resize(size_t p_size); + + void set_buffer_data(std::span p_data); + std::span get_buffer_data() const; + + void clear(); + + BufferPacketStream duplicate() const; + + private: + memory::vector _data; + size_t _position; + }; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/ReliableUdpClient.cpp b/src/openvic-simulation/multiplayer/lowlevel/ReliableUdpClient.cpp new file mode 100644 index 000000000..7780c42f5 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/ReliableUdpClient.cpp @@ -0,0 +1,256 @@ + +#include "ReliableUdpClient.hpp" + +#include + +#include + +#include "openvic-simulation/GameManager.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/UdpClient.hpp" +#include "openvic-simulation/utility/ErrorMacros.hpp" +#include "openvic-simulation/utility/MemoryTracker.hpp" + +using namespace OpenVic; + +void ReliableUdpClient::_default_config_values(reliable_config_t& config) { + if (config.fragment_above == 0) { + config.fragment_above = 1024; + } + if (config.max_fragments == 0) { + config.max_fragments = 3; + } + if (config.fragment_size == 0) { + config.fragment_size = 1024; + } + if (config.ack_buffer_size == 0) { + config.ack_buffer_size = 256; + } + if (config.sent_packets_buffer_size == 0) { + config.sent_packets_buffer_size = 256; + } + if (config.received_packets_buffer_size == 0) { + config.received_packets_buffer_size = 256; + } + if (config.fragment_reassembly_buffer_size == 0) { + config.fragment_reassembly_buffer_size = 64; + } + if (config.rtt_smoothing_factor == 0) { + config.rtt_smoothing_factor = 0.0025f; + } + if (config.rtt_history_size == 0) { + config.rtt_history_size = 512; + } + if (config.packet_loss_smoothing_factor == 0) { + config.packet_loss_smoothing_factor = 0.1f; + } + if (config.bandwidth_smoothing_factor == 0) { + config.bandwidth_smoothing_factor = 0.1f; + } + if (config.packet_header_size == 0) { + config.packet_header_size = 48; + } + if (config.max_packet_size == 0) { + config.max_packet_size = max_packet_size() - config.packet_header_size; + } +} + +ReliableUdpClient::ReliableUdpClient(reliable_config_t config) { + _default_config_values(config); + + config.context = this; + config.transmit_packet_function = &_transmit_packet; + config.process_packet_function = &_process_packet; + + _endpoint = endpoint_pointer_type { + reliable_endpoint_create(&config, static_cast(GameManager::get_elapsed_milliseconds()) / 1000.0) + }; +#ifdef DEBUG_ENABLED + { + utility::MemoryTracker tracker {}; + tracker.on_node_allocation(nullptr, sizeof(reliable_endpoint_dummy), 0); + } +#endif + + OV_ERR_FAIL_COND(!_endpoint); +} + +thread_local NetworkError transmit_error; + +void ReliableUdpClient::_transmit_packet(void* context, uint64_t id, uint16_t sequence, uint8_t* packet_data, int packet_size) { + ReliableUdpClient* self = static_cast(context); + transmit_error = self->UdpClient::_set_packet({ packet_data, static_cast(packet_size) }); +} + +NetworkError ReliableUdpClient::_set_packet(std::span buffer) { + OV_ERR_FAIL_COND_V(!_endpoint, NetworkError::UNAVAILABLE); + + reliable_endpoint_send_packet(_endpoint.get(), buffer.data(), buffer.size_bytes()); + // Pretty sure floating point for packet tracking shouldn't cause issues + update(static_cast(GameManager::get_elapsed_milliseconds()) / 1000.0); + return transmit_error; +} + +struct PeerProcessData { + IpAddress& p_ip; + NetworkSocket::port_type& p_port; +}; + +thread_local PeerProcessData* process_peer_data = nullptr; + +int ReliableUdpClient::_process_packet(void* context, uint64_t id, uint16_t sequence, uint8_t* packet_data, int packet_size) { + ReliableUdpClient* self = static_cast(context); + std::span packet { packet_data, static_cast(packet_size) }; + + if (self->UdpClient::store_packet(process_peer_data->p_ip, process_peer_data->p_port, packet) == NetworkError::OK) { + return 1; + } + + process_peer_data = nullptr; + return 0; +} + +NetworkError ReliableUdpClient::store_packet(IpAddress p_ip, NetworkSocket::port_type p_port, std::span p_buf) { + OV_ERR_FAIL_COND_V(!_endpoint, NetworkError::UNAVAILABLE); + + PeerProcessData data = { p_ip, p_port }; + process_peer_data = &data; + reliable_endpoint_receive_packet(_endpoint.get(), p_buf.data(), p_buf.size_bytes()); + if (process_peer_data == nullptr) { + return NetworkError::OUT_OF_MEMORY; + } + process_peer_data = nullptr; + return NetworkError::OK; +} + +void ReliableUdpClient::update(double time) { + reliable_endpoint_update(_endpoint.get(), time); +} + +void ReliableUdpClient::reset() { + reliable_endpoint_reset(_endpoint.get()); +} + +ReliableUdpClient::sequence_type ReliableUdpClient::get_current_sequence_value() const { + return get_next_sequence_value() - 1; +} + +ReliableUdpClient::sequence_type ReliableUdpClient::get_next_sequence_value() const { + OV_ERR_FAIL_COND_V(!_endpoint, 0); + + return reliable_endpoint_next_packet_sequence(_endpoint.get()); +} + +std::span ReliableUdpClient::get_acknowledged_sequences() const { + OV_ERR_FAIL_COND_V(!_endpoint, {}); + + int32_t acks_count; + uint16_t* acks_array = reliable_endpoint_get_acks(_endpoint.get(), &acks_count); + OV_ERR_FAIL_COND_V(acks_count <= 0, {}); + + return { acks_array, static_cast(acks_count) }; +} + +void ReliableUdpClient::clear_acknowledged_sequences() { + OV_ERR_FAIL_COND(!_endpoint); + + reliable_endpoint_clear_acks(_endpoint.get()); +} + +float ReliableUdpClient::get_round_trip_smooth_average() const { + OV_ERR_FAIL_COND_V(!_endpoint, 0); + + return reliable_endpoint_rtt(_endpoint.get()); +} + +float ReliableUdpClient::get_round_trip_average() const { + OV_ERR_FAIL_COND_V(!_endpoint, 0); + + return reliable_endpoint_rtt_avg(_endpoint.get()); +} + +float ReliableUdpClient::get_round_trip_minimum() const { + OV_ERR_FAIL_COND_V(!_endpoint, 0); + + return reliable_endpoint_rtt_min(_endpoint.get()); +} + +float ReliableUdpClient::get_round_trip_maximum() const { + OV_ERR_FAIL_COND_V(!_endpoint, 0); + + return reliable_endpoint_rtt_max(_endpoint.get()); +} + +float ReliableUdpClient::get_jitter_average() const { + OV_ERR_FAIL_COND_V(!_endpoint, 0); + + return reliable_endpoint_jitter_avg_vs_min_rtt(_endpoint.get()); +} + +float ReliableUdpClient::get_jitter_maximum() const { + OV_ERR_FAIL_COND_V(!_endpoint, 0); + + return reliable_endpoint_jitter_max_vs_min_rtt(_endpoint.get()); +} + +float ReliableUdpClient::get_jitter_against_average_round_trip() const { + OV_ERR_FAIL_COND_V(!_endpoint, 0); + + return reliable_endpoint_jitter_stddev_vs_avg_rtt(_endpoint.get()); +} + +float ReliableUdpClient::get_packet_loss() const { + OV_ERR_FAIL_COND_V(!_endpoint, 0); + + return reliable_endpoint_packet_loss(_endpoint.get()); +} + +ReliableUdpClient::BandwidthStatistics ReliableUdpClient::get_bandwidth_statistics() const { + OV_ERR_FAIL_COND_V(!_endpoint, {}); + + BandwidthStatistics result; // NOLINT + reliable_endpoint_bandwidth( + _endpoint.get(), &result.sent_bandwidth_kbps, &result.received_bandwidth_kbps, &result.acked_bandwidth_kpbs + ); + return result; +} + +size_t ReliableUdpClient::get_packets_sent() const { + return reliable_endpoint_counters(_endpoint.get())[RELIABLE_ENDPOINT_COUNTER_NUM_PACKETS_SENT]; +} + +size_t ReliableUdpClient::get_packets_received() const { + return reliable_endpoint_counters(_endpoint.get())[RELIABLE_ENDPOINT_COUNTER_NUM_PACKETS_RECEIVED]; +} + +size_t ReliableUdpClient::get_packets_acknowledged() const { + return reliable_endpoint_counters(_endpoint.get())[RELIABLE_ENDPOINT_COUNTER_NUM_PACKETS_ACKED]; +} + +size_t ReliableUdpClient::get_packets_stale() const { + return reliable_endpoint_counters(_endpoint.get())[RELIABLE_ENDPOINT_COUNTER_NUM_PACKETS_STALE]; +} + +size_t ReliableUdpClient::get_packets_invalid() const { + return reliable_endpoint_counters(_endpoint.get())[RELIABLE_ENDPOINT_COUNTER_NUM_PACKETS_INVALID]; +} + +size_t ReliableUdpClient::get_packets_too_large_to_send() const { + return reliable_endpoint_counters(_endpoint.get())[RELIABLE_ENDPOINT_COUNTER_NUM_PACKETS_TOO_LARGE_TO_SEND]; +} + +size_t ReliableUdpClient::get_packets_too_large_to_receive() const { + return reliable_endpoint_counters(_endpoint.get())[RELIABLE_ENDPOINT_COUNTER_NUM_PACKETS_TOO_LARGE_TO_RECEIVE]; +} + +size_t ReliableUdpClient::get_fragments_sent() const { + return reliable_endpoint_counters(_endpoint.get())[RELIABLE_ENDPOINT_COUNTER_NUM_FRAGMENTS_SENT]; +} + +size_t ReliableUdpClient::get_fragments_received() const { + return reliable_endpoint_counters(_endpoint.get())[RELIABLE_ENDPOINT_COUNTER_NUM_FRAGMENTS_RECEIVED]; +} + +size_t ReliableUdpClient::get_fragments_invalid() const { + return reliable_endpoint_counters(_endpoint.get())[RELIABLE_ENDPOINT_COUNTER_NUM_FRAGMENTS_INVALID]; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/ReliableUdpClient.hpp b/src/openvic-simulation/multiplayer/lowlevel/ReliableUdpClient.hpp new file mode 100644 index 000000000..767631098 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/ReliableUdpClient.hpp @@ -0,0 +1,112 @@ +#pragma once + +#include +#include +#include + +#include + +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/UdpClient.hpp" + +namespace OpenVic { + struct ReliableUdpClient : UdpClient { + ReliableUdpClient(reliable_config_t config = { .name = "client" }); + + using sequence_type = uint16_t; + + NetworkError store_packet(IpAddress p_ip, NetworkSocket::port_type p_port, std::span p_buf) override; + + void update(double time); + void reset(); + + sequence_type get_current_sequence_value() const; + sequence_type get_next_sequence_value() const; + + std::span get_acknowledged_sequences() const; + void clear_acknowledged_sequences(); + + float get_round_trip_smooth_average() const; + float get_round_trip_average() const; + float get_round_trip_minimum() const; + float get_round_trip_maximum() const; + + float get_jitter_average() const; + float get_jitter_maximum() const; + float get_jitter_against_average_round_trip() const; + + float get_packet_loss() const; + + struct BandwidthStatistics { + float sent_bandwidth_kbps; + float received_bandwidth_kbps; + float acked_bandwidth_kpbs; + }; + BandwidthStatistics get_bandwidth_statistics() const; + + size_t get_packets_sent() const; + size_t get_packets_received() const; + size_t get_packets_acknowledged() const; + size_t get_packets_stale() const; + size_t get_packets_invalid() const; + size_t get_packets_too_large_to_send() const; + size_t get_packets_too_large_to_receive() const; + size_t get_fragments_sent() const; + size_t get_fragments_received() const; + size_t get_fragments_invalid() const; + + protected: + NetworkError _set_packet(std::span buffer) override; + + private: + static void _transmit_packet(void* context, uint64_t id, uint16_t sequence, uint8_t* packet_data, int packet_size); + static int _process_packet(void* context, uint64_t ud, uint16_t sequence, uint8_t* packet_data, int packet_size); + + void _default_config_values(reliable_config_t& config); + + // Ensures we track an accurate allocation/deallocation size for reliable_endpoint_t + struct reliable_endpoint_dummy { + void* allocator_context; + void* (*allocate_function)(void*, size_t); + void (*free_function)(void*, void*); + struct reliable_config_t config; + double time; + float rtt; + float rtt_min; + float rtt_max; + float rtt_avg; + float jitter_avg_vs_min_rtt; + float jitter_max_vs_min_rtt; + float jitter_stddev_vs_avg_rtt; + float packet_loss; + float sent_bandwidth_kbps; + float received_bandwidth_kbps; + float acked_bandwidth_kbps; + int num_acks; + uint16_t* acks; + uint16_t sequence; + float* rtt_history_buffer; + struct reliable_sequence_buffer_t* sent_packets; + struct reliable_sequence_buffer_t* received_packets; + struct reliable_sequence_buffer_t* fragment_reassembly; + uint64_t counters[RELIABLE_ENDPOINT_NUM_COUNTERS]; + }; + + struct endpoint_deleter_t { + using value_type = reliable_endpoint_t; + + void operator()(value_type* ptr) { + reliable_endpoint_destroy(ptr); +#ifdef DEBUG_ENABLED + { + utility::MemoryTracker tracker {}; + tracker.on_node_deallocation(nullptr, sizeof(reliable_endpoint_dummy), 0); + } +#endif + } + }; + using endpoint_pointer_type = std::unique_ptr; + + endpoint_pointer_type _endpoint; + }; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/ReliableUdpServer.hpp b/src/openvic-simulation/multiplayer/lowlevel/ReliableUdpServer.hpp new file mode 100644 index 000000000..8878d0bb9 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/ReliableUdpServer.hpp @@ -0,0 +1,47 @@ +#pragma once + +#include +#include + +#include + +#include "openvic-simulation/multiplayer/lowlevel/ReliableUdpClient.hpp" +#include "openvic-simulation/multiplayer/lowlevel/UdpServer.hpp" +#include "openvic-simulation/utility/Utility.hpp" + +namespace OpenVic { + struct ReliableUdpServer : UdpServer { + void set_client_config_template(reliable_config_t const& config) { + _client_config_template = config; + } + + reliable_config_t const& get_client_config_template() const { + return _client_config_template; + } + + protected: + inline ReliableUdpClient* _take_next_client() override { + return static_cast(UdpServer::_take_next_client()); + } + + inline ReliableUdpClient* _create_client() override { + reliable_config_t config = _client_config_template; + config.id = _clients.size() + _pending_clients.size() + 1; + + size_t name_length = std::strlen(config.name); + if (OV_unlikely(name_length >= 225)) { + static constexpr const char server_connection[] = "server-connection-"; + static_assert(sizeof(server_connection) - 1 < sizeof(config.name)); + + name_length = sizeof(server_connection) - 1; + std::strncpy(config.name, server_connection, name_length); + } + + std::to_chars(config.name + name_length, config.name + sizeof(config.name), config.id); + + return new ReliableUdpClient(std::move(config)); + } + + reliable_config_t _client_config_template { .name = "server-connection-" }; + }; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/StreamPacketClient.cpp b/src/openvic-simulation/multiplayer/lowlevel/StreamPacketClient.cpp new file mode 100644 index 000000000..d5fdfe69c --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/StreamPacketClient.cpp @@ -0,0 +1,10 @@ + +#include "StreamPacketClient.hpp" + +#include "openvic-simulation/multiplayer/lowlevel/PacketStream.hpp" +#include "openvic-simulation/multiplayer/lowlevel/TcpPacketStream.hpp" + +using namespace OpenVic; + +template struct OpenVic::BasicStreamPacketClient; +template struct OpenVic::BasicStreamPacketClient; diff --git a/src/openvic-simulation/multiplayer/lowlevel/StreamPacketClient.hpp b/src/openvic-simulation/multiplayer/lowlevel/StreamPacketClient.hpp new file mode 100644 index 000000000..843e22797 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/StreamPacketClient.hpp @@ -0,0 +1,173 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/PacketClient.hpp" +#include "openvic-simulation/multiplayer/lowlevel/PacketStream.hpp" +#include "openvic-simulation/multiplayer/lowlevel/TcpPacketStream.hpp" +#include "openvic-simulation/types/RingBuffer.hpp" +#include "openvic-simulation/utility/Containers.hpp" +#include "openvic-simulation/utility/ErrorMacros.hpp" +#include "openvic-simulation/utility/Marshal.hpp" + +namespace OpenVic { + template T> + struct BasicStreamPacketClient : PacketClient { + BasicStreamPacketClient(uint8_t max_buffer_power = 16) { + _ring_buffer.reserve_power(max_buffer_power); + _input_buffer.resize(int64_t(1) << max_buffer_power); + _output_buffer.resize(int64_t(1) << max_buffer_power); + } + + int64_t available_packets() const override { + _poll_buffer(); + + uint32_t remaining = _ring_buffer.size() - _ring_buffer.space() - 1; + + int offset = 0; + int count = 0; + + while (remaining >= 4) { + uint8_t lbuf[4]; + std::copy_n(_ring_buffer.begin() + offset, 4, lbuf); + size_t decode_size = 0; + uint32_t length = utility::decode(lbuf, decode_size); + OV_ERR_BREAK(decode_size != 4); + remaining -= 4; + offset += 4; + if (length > remaining) { + break; + } + remaining -= length; + offset += length; + count++; + } + + return count; + } + + int64_t max_packet_size() const override { + return _output_buffer.size(); + } + + void set_packet_stream(std::shared_ptr stream) { + if (stream != _stream) { + _ring_buffer.clear(); // Reset the ring buffer. + } + + _stream = stream; + } + + std::shared_ptr get_packet_stream() const { + return _stream; + } + + void set_input_buffer_max_size(int p_max_size) { + OV_ERR_FAIL_COND_MSG(p_max_size < 0, "Max size of input buffer size cannot be smaller than 0."); + // WARNING: May lose packets. + OV_ERR_FAIL_COND_MSG( + _ring_buffer.size() - _ring_buffer.space() - 1, "Buffer in use, resizing would cause loss of data." + ); + _ring_buffer.reserve_power(std::bit_width(std::bit_ceil(p_max_size + 4)) - 1); + _input_buffer.resize(std::bit_ceil(p_max_size + 4)); + } + + int get_input_buffer_max_size() const { + return _input_buffer.size() - 4; + } + + void set_output_buffer_max_size(int p_max_size) { + _output_buffer.resize(std::bit_ceil(p_max_size + 4)); + } + + int get_output_buffer_max_size() const { + return _output_buffer.size() - 4; + } + + private: + NetworkError _set_packet(std::span buffer) override { + OV_ERR_FAIL_COND_V(_stream == nullptr, NetworkError::UNCONFIGURED); + NetworkError err = _poll_buffer(); // won't hurt to poll here too + + if (err != NetworkError::OK) { + return err; + } + + if (buffer.size() == 0) { + return NetworkError::OK; + } + + OV_ERR_FAIL_COND_V(buffer.size() + 4 > _output_buffer.size(), NetworkError::INVALID_PARAMETER); + + utility::encode(buffer.size(), _output_buffer); + uint8_t* dst = &_output_buffer[4]; + for (int i = 0; i < buffer.size(); i++) { + dst[i] = buffer[i]; + } + + return _stream->set_data({ &_output_buffer[0], buffer.size() + 4 }); + } + + NetworkError _get_packet(std::span& r_set) override { + OV_ERR_FAIL_COND_V(_stream == nullptr, NetworkError::UNCONFIGURED); + _poll_buffer(); + + int remaining = _ring_buffer.size() - _ring_buffer.space() - 1; + OV_ERR_FAIL_COND_V(remaining < 4, NetworkError::UNAVAILABLE); + uint8_t lbuf[4]; + std::copy_n(_ring_buffer.begin(), 4, lbuf); + remaining -= 4; + size_t decode_size = 0; + uint32_t length = utility::decode(lbuf, decode_size); + OV_ERR_FAIL_COND_V(remaining < (int)length, NetworkError::UNAVAILABLE); + + OV_ERR_FAIL_COND_V(_input_buffer.size() < length, NetworkError::UNAVAILABLE); + _ring_buffer.erase(_ring_buffer.begin(), decode_size); // get rid of the decode_size + _ring_buffer.read_buffer_to(_input_buffer.data(), length); // read packet + + r_set = std::span { _input_buffer.data(), length }; + return NetworkError::OK; + } + + NetworkError _poll_buffer() const { + OV_ERR_FAIL_COND_V(_stream == nullptr, NetworkError::UNCONFIGURED); + + size_t read = 0; + OV_ERR_FAIL_COND_V(_input_buffer.size() < _ring_buffer.space(), NetworkError::UNAVAILABLE); + NetworkError err = _stream->get_data_no_blocking({ _input_buffer.data(), _ring_buffer.space() }, read); + if (err != NetworkError::OK) { + return err; + } + if (read == 0) { + return NetworkError::OK; + } + + typename decltype(_ring_buffer)::iterator last = _ring_buffer.end(); + ptrdiff_t w = std::distance(last, _ring_buffer.append(&_input_buffer[0], read)); + OV_ERR_FAIL_COND_V(w != read, NetworkError::BUG); + + return NetworkError::OK; + } + + mutable std::shared_ptr _stream; + mutable RingBuffer _ring_buffer; + mutable memory::vector _input_buffer; + mutable memory::vector _output_buffer; + }; + + using StreamPacketClient = BasicStreamPacketClient; + extern template struct BasicStreamPacketClient; + + using TcpStreamPacketClient = BasicStreamPacketClient; + extern template struct BasicStreamPacketClient; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/TcpPacketStream.cpp b/src/openvic-simulation/multiplayer/lowlevel/TcpPacketStream.cpp new file mode 100644 index 000000000..890f0cb27 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/TcpPacketStream.cpp @@ -0,0 +1,316 @@ + +#include "TcpPacketStream.hpp" + +#include +#include +#include +#include + +#include "openvic-simulation/GameManager.hpp" +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkResolver.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkSocket.hpp" +#include "openvic-simulation/utility/ErrorMacros.hpp" + +using namespace OpenVic; + +TcpPacketStream::TcpPacketStream() : _socket { std::make_shared() } {} + +TcpPacketStream::~TcpPacketStream() { + close(); +} + +void TcpPacketStream::accept_socket(std::shared_ptr p_sock, IpAddress p_host, NetworkSocket::port_type p_port) { + _socket = p_sock; + _socket->set_blocking_enabled(false); + + _timeout = GameManager::get_elapsed_milliseconds() + _timeout_seconds * 1000; + _status = Status::CONNECTED; + + _client_address = p_host; + _client_port = p_port; +} + +int64_t TcpPacketStream::available_bytes() const { + OV_ERR_FAIL_COND_V(_socket == nullptr, -1); + return _socket->available_bytes(); +} + +NetworkError TcpPacketStream::bind(NetworkSocket::port_type port, IpAddress const& address) { + OV_ERR_FAIL_COND_V(_socket != nullptr, NetworkError::UNAVAILABLE); + OV_ERR_FAIL_COND_V(_socket->is_open(), NetworkError::ALREADY_OPEN); + OV_ERR_FAIL_COND_V_MSG( + port < 0 || port > 65535, NetworkError::INVALID_PARAMETER, + "The local port number must be between 0 and 65535 (inclusive)." + ); + + NetworkResolver::Type ip_type = address.is_ipv4() ? NetworkResolver::Type::IPV4 : NetworkResolver::Type::IPV6; + if (address.is_wildcard()) { + ip_type = NetworkResolver::Type::ANY; + } + NetworkError err = _socket->open(NetworkSocket::Type::TCP, ip_type); + if (err != NetworkError::OK) { + return err; + } + _socket->set_blocking_enabled(false); + return _socket->bind(address, port); +} + +NetworkError TcpPacketStream::connect_to(IpAddress const& host, NetworkSocket::port_type port) { + OV_ERR_FAIL_COND_V(_socket == nullptr, NetworkError::UNAVAILABLE); + OV_ERR_FAIL_COND_V(_status != Status::NONE, NetworkError::ALREADY_OPEN); + OV_ERR_FAIL_COND_V(!host.is_valid(), NetworkError::INVALID_PARAMETER); + OV_ERR_FAIL_COND_V_MSG( + port < 1 || port > 65535, NetworkError::INVALID_PARAMETER, + "The remote port number must be between 1 and 65535 (inclusive)." + ); + + if (!_socket->is_open()) { + NetworkResolver::Type ip_type = host.is_ipv4() ? NetworkResolver::Type::IPV4 : NetworkResolver::Type::IPV6; + NetworkError err = _socket->open(NetworkSocket::Type::TCP, ip_type); + if (err != NetworkError::OK) { + return err; + } + _socket->set_blocking_enabled(false); + } + + _timeout = GameManager::get_elapsed_milliseconds() + _timeout_seconds * 1000; + NetworkError err = _socket->connect_to_host(host, port); + + if (err == NetworkError::OK) { + _status = Status::CONNECTED; + } else if (err == NetworkError::BUSY) { + _status = Status::CONNECTING; + } else { + Logger::error("Connection to remote host failed!"); + close(); + return err; + } + + _client_address = host; + _client_port = port; + + return NetworkError::OK; +} + +NetworkError TcpPacketStream::poll() { + switch (_status) { + using enum Status; + case CONNECTED: { + NetworkError err; + err = _socket->poll(NetworkSocket::PollType::IN, 0); + if (err == NetworkError::OK) { + // FIN received + if (_socket->available_bytes() == 0) { + close(); + return NetworkError::OK; + } + } + // Also poll write + err = _socket->poll(NetworkSocket::PollType::IN_OUT, 0); + if (err != NetworkError::OK && err != NetworkError::BUSY) { + // Got an error + close(); + _status = ERROR; + return err; + } + return NetworkError::OK; + } + case CONNECTING: break; + default: return NetworkError::OK; + } + + NetworkError err = _socket->connect_to_host(_client_address, _client_port); + + if (err == NetworkError::OK) { + _status = Status::CONNECTED; + return NetworkError::OK; + } else if (err == NetworkError::BUSY) { + // Check for connect timeout + if (GameManager::get_elapsed_milliseconds() > _timeout) { + close(); + _status = Status::ERROR; + return err; + } + // Still trying to connect + return NetworkError::OK; + } + + close(); + _status = Status::ERROR; + return err; +} + +NetworkError TcpPacketStream::wait(NetworkSocket::PollType p_type, int p_timeout) { + OV_ERR_FAIL_COND_V(_socket == nullptr || !_socket->is_open(), NetworkError::UNAVAILABLE); + return _socket->poll(p_type, p_timeout); +} + +void TcpPacketStream::close() { + if (_socket != nullptr && _socket->is_open()) { + _socket->close(); + } + + _timeout = 0; + _status = Status::NONE; + _client_address = {}; + _client_port = 0; +} + +NetworkSocket::port_type TcpPacketStream::get_bound_port() const { + NetworkSocket::port_type local_port; + _socket->get_socket_address(nullptr, &local_port); + return local_port; +} + +bool TcpPacketStream::is_connected() const { + return _status == Status::CONNECTED; +} + +IpAddress TcpPacketStream::get_connected_address() const { + return _client_address; +} + +NetworkSocket::port_type TcpPacketStream::get_connected_port() const { + return _client_port; +} + +TcpPacketStream::Status TcpPacketStream::get_status() const { + return _status; +} + +void TcpPacketStream::set_timeout_seconds(uint64_t timeout) { + _timeout_seconds = timeout; +} + +uint64_t TcpPacketStream::get_timeout_seconds() const { + return _timeout_seconds; +} + +void TcpPacketStream::set_no_delay(bool p_enabled) { + OV_ERR_FAIL_COND(_socket == nullptr || !_socket->is_open()); + _socket->set_tcp_no_delay_enabled(p_enabled); +} + +NetworkError TcpPacketStream::set_data(std::span buffer) { + size_t _ = 0; + return _set_data(buffer, _); +} + +NetworkError TcpPacketStream::set_data_no_blocking(std::span buffer, size_t& r_sent) { + return _set_data(buffer, r_sent); +} + +NetworkError TcpPacketStream::get_data(std::span buffer_to_set) { + return _get_data(buffer_to_set); +} + +NetworkError TcpPacketStream::get_data_no_blocking(std::span buffer_to_set, size_t& r_received) { + NetworkError result = _get_data(buffer_to_set); + r_received = buffer_to_set.size(); + return result; +} + +template +NetworkError TcpPacketStream::_set_data(std::span buffer, size_t& r_sent) { + OV_ERR_FAIL_COND_V(_socket == nullptr, NetworkError::UNAVAILABLE); + + if (_status != Status::CONNECTED) { + return NetworkError::FAILED; + } + + NetworkError err; + int data_to_send = buffer.size_bytes(); + const uint8_t* offset = buffer.data(); + int64_t total_sent = 0; + + while (data_to_send) { + int64_t sent_amount = 0; + err = _socket->send(offset, data_to_send, sent_amount); + + if (err != NetworkError::OK) { + if (err != NetworkError::BUSY) { + close(); + return err; + } + + if constexpr (!IsBlocking) { + r_sent = total_sent; + return NetworkError::OK; + } else { + // Block and wait for the socket to accept more data + err = _socket->poll(NetworkSocket::PollType::OUT, -1); + if (err != NetworkError::OK) { + close(); + return err; + } + } + } else { + data_to_send -= sent_amount; + offset += sent_amount; + total_sent += sent_amount; + } + } + + r_sent = total_sent; + + return NetworkError::OK; +} + +template NetworkError TcpPacketStream::_set_data(std::span buffer, size_t& r_sent); +template NetworkError TcpPacketStream::_set_data(std::span buffer, size_t& r_sent); + +template +NetworkError TcpPacketStream::_get_data(std::span& r_buffer) { + if (_status != Status::CONNECTED) { + return NetworkError::FAILED; + } + + NetworkError err; + int to_read = r_buffer.size_bytes(); + int total_read = 0; + + while (to_read) { + int64_t read = 0; + err = _socket->receive(r_buffer.data() + total_read, to_read, read); + + if (err != NetworkError::OK) { + if (err != NetworkError::BUSY) { + close(); + return err; + } + + if constexpr (!IsBlocking) { + r_buffer = { r_buffer.data(), static_cast(total_read) }; + return NetworkError::OK; + } else { + err = _socket->poll(NetworkSocket::PollType::IN, -1); + + if (err != NetworkError::OK) { + close(); + return err; + } + } + } else if (read == 0) { + close(); + r_buffer = { r_buffer.data(), static_cast(total_read) }; + return NetworkError::EMPTY_BUFFER; + } else { + to_read -= read; + total_read += read; + + if constexpr (!IsBlocking) { + r_buffer = { r_buffer.data(), static_cast(total_read) }; + return NetworkError::OK; + } + } + } + + r_buffer = { r_buffer.data(), static_cast(total_read) }; + + return NetworkError::OK; +} + +template NetworkError TcpPacketStream::_get_data(std::span& r_buffer); +template NetworkError TcpPacketStream::_get_data(std::span& r_buffer); diff --git a/src/openvic-simulation/multiplayer/lowlevel/TcpPacketStream.hpp b/src/openvic-simulation/multiplayer/lowlevel/TcpPacketStream.hpp new file mode 100644 index 000000000..29050d254 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/TcpPacketStream.hpp @@ -0,0 +1,69 @@ +#pragma once + +#include +#include +#include + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkSocket.hpp" +#include "openvic-simulation/multiplayer/lowlevel/PacketStream.hpp" + +namespace OpenVic { + struct TcpPacketStream : PacketStream { + enum class Status : uint8_t { + NONE, + CONNECTING, + CONNECTED, + ERROR, + }; + + TcpPacketStream(); + ~TcpPacketStream(); + + void accept_socket(std::shared_ptr p_sock, IpAddress p_host, NetworkSocket::port_type p_port); + + int64_t available_bytes() const override; + NetworkError bind(NetworkSocket::port_type port, IpAddress const& address = "*"); + NetworkError connect_to(IpAddress const& host, NetworkSocket::port_type port); + NetworkError poll(); + NetworkError wait(NetworkSocket::PollType p_type, int p_timeout = 0); + void close(); + NetworkSocket::port_type get_bound_port() const; + bool is_connected() const; + + IpAddress get_connected_address() const; + NetworkSocket::port_type get_connected_port() const; + Status get_status() const; + + void set_timeout_seconds(uint64_t timeout); + uint64_t get_timeout_seconds() const; + + void set_no_delay(bool p_enabled); + + NetworkError set_data(std::span buffer) override; + NetworkError set_data_no_blocking(std::span buffer, size_t& r_sent) override; + NetworkError get_data(std::span buffer_to_set) override; + NetworkError get_data_no_blocking(std::span buffer_to_set, size_t& r_received) override; + + private: + uint64_t _timeout_seconds = 30; + + Status _status = Status::NONE; + uint64_t _timeout = 0; + NetworkSocket::port_type _client_port = 0; + IpAddress _client_address; + std::shared_ptr _socket; + + template + NetworkError _set_data(std::span buffer, size_t& r_sent); + + template + NetworkError _get_data(std::span& r_buffer); + }; + + extern template NetworkError TcpPacketStream::_set_data(std::span buffer, size_t& r_sent); + extern template NetworkError TcpPacketStream::_set_data(std::span buffer, size_t& r_sent); + extern template NetworkError TcpPacketStream::_get_data(std::span& r_buffer); + extern template NetworkError TcpPacketStream::_get_data(std::span& r_buffer); +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/TcpServer.cpp b/src/openvic-simulation/multiplayer/lowlevel/TcpServer.cpp new file mode 100644 index 000000000..3981e3d5b --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/TcpServer.cpp @@ -0,0 +1,79 @@ + +#include "TcpServer.hpp" + +#include + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkResolver.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkSocket.hpp" +#include "openvic-simulation/multiplayer/lowlevel/PacketServer.hpp" +#include "openvic-simulation/multiplayer/lowlevel/StreamPacketClient.hpp" +#include "openvic-simulation/multiplayer/lowlevel/TcpPacketStream.hpp" +#include "openvic-simulation/utility/Containers.hpp" + +using namespace OpenVic; + +bool TcpServer::is_connection_available() const { + if (!PacketServer::is_connection_available()) { + return false; + } + return _socket->poll(NetworkSocket::PollType::IN, 0) == NetworkError::OK; +} + +NetworkError TcpServer::listen_to(NetworkSocket::port_type port, IpAddress const& bind_addr) { + OV_ERR_FAIL_COND_V(!_socket, NetworkError::UNAVAILABLE); + OV_ERR_FAIL_COND_V(_socket->is_open(), NetworkError::ALREADY_OPEN); + OV_ERR_FAIL_COND_V(!bind_addr.is_valid() && !bind_addr.is_wildcard(), NetworkError::INVALID_PARAMETER); + + NetworkError err; + NetworkResolver::Type ip_type = NetworkResolver::Type::ANY; + + // If the bind address is valid use its type as the socket type + if (bind_addr.is_valid()) { + ip_type = bind_addr.is_ipv4() ? NetworkResolver::Type::IPV4 : NetworkResolver::Type::IPV6; + } + + err = _socket->open(NetworkSocket::Type::TCP, ip_type); + + OV_ERR_FAIL_COND_V(err != NetworkError::OK, err); + + _socket->set_blocking_enabled(false); + _socket->set_reuse_address_enabled(true); + + err = _socket->bind(bind_addr, port); + + if (err != NetworkError::OK) { + _socket->close(); + return err; + } + + err = _socket->listen(MAX_PENDING_CONNECTIONS); + if (err != NetworkError::OK) { + _socket->close(); + } + return err; +} + +memory::unique_ptr TcpServer::take_packet_stream() { + if (!is_connection_available()) { + return nullptr; + } + + IpAddress ip; + NetworkSocket::port_type port = 0; + std::shared_ptr ns = _socket->accept_as(ip, port); + if (!ns->is_open()) { + return nullptr; + } + + memory::unique_ptr result = memory::make_unique(); + result->accept_socket(ns, ip, port); + return result; +} + +TcpStreamPacketClient* TcpServer::_take_next_client() { + memory::unique_ptr client = memory::make_unique(); + client->set_packet_stream(std::shared_ptr { take_packet_stream().release() }); + return client.release(); +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/TcpServer.hpp b/src/openvic-simulation/multiplayer/lowlevel/TcpServer.hpp new file mode 100644 index 000000000..889fa5433 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/TcpServer.hpp @@ -0,0 +1,21 @@ +#pragma once + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkSocket.hpp" +#include "openvic-simulation/multiplayer/lowlevel/PacketServer.hpp" +#include "openvic-simulation/multiplayer/lowlevel/StreamPacketClient.hpp" +#include "openvic-simulation/utility/Containers.hpp" + +namespace OpenVic { + struct TcpServer : PacketServer { + static constexpr size_t MAX_PENDING_CONNECTIONS = 8; + + bool is_connection_available() const override; + NetworkError listen_to(NetworkSocket::port_type port, IpAddress const& bind_addr = "*") override; + + memory::unique_ptr take_packet_stream(); + + protected: + TcpStreamPacketClient* _take_next_client() override; + }; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/UdpClient.cpp b/src/openvic-simulation/multiplayer/lowlevel/UdpClient.cpp new file mode 100644 index 000000000..8e194f659 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/UdpClient.cpp @@ -0,0 +1,329 @@ +#include "UdpClient.hpp" + +#include +#include +#include +#include + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkResolver.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkSocket.hpp" +#include "openvic-simulation/multiplayer/lowlevel/UdpServer.hpp" + +using namespace OpenVic; + +UdpClient::UdpClient() : _socket(std::make_shared()), _ring_buffer(16) {} // NOLINT + +UdpClient::~UdpClient() { + close(); +} + +NetworkError UdpClient::bind(NetworkSocket::port_type port, IpAddress const& address, size_t min_buffer_size) { + OV_ERR_FAIL_COND_V(_socket == nullptr, NetworkError::UNAVAILABLE); + OV_ERR_FAIL_COND_V(_socket->is_open(), NetworkError::ALREADY_OPEN); + OV_ERR_FAIL_COND_V(!address.is_valid() && !address.is_wildcard(), NetworkError::INVALID_PARAMETER); + OV_ERR_FAIL_COND_V_MSG( + port > 65535, NetworkError::INVALID_PARAMETER, "The local port number must be between 0 and 65535 (inclusive)." + ); + + NetworkError err; + NetworkResolver::Type ip_type = NetworkResolver::Type::ANY; + + if (address.is_valid()) { + ip_type = address.is_ipv4() ? NetworkResolver::Type::IPV4 : NetworkResolver::Type::IPV6; + } + + err = _socket->open(NetworkSocket::Type::UDP, ip_type); + + if (err != NetworkError::OK) { + return err; + } + + _socket->set_blocking_enabled(false); + _socket->set_broadcasting_enabled(_is_broadcasting); + err = _socket->bind(address, port); + + if (err != NetworkError::OK) { + _socket->close(); + return err; + } + if (min_buffer_size < _ring_buffer.capacity()) { + _ring_buffer.shrink_to_fit(); + } + _ring_buffer.reserve_power(std::bit_ceil(min_buffer_size)); + return NetworkError::OK; +} + +void UdpClient::close() { + if (_server) { + _server->remove_client(peer_ip, peer_port); + _server = nullptr; + _socket = std::make_shared(); + } else if (_socket) { + _socket->close(); + } + _ring_buffer.reserve_power(16); + _queue_count = 0; + _is_connected = false; +} + +NetworkError UdpClient::connect_to(IpAddress const& host, NetworkSocket::port_type port) { + OV_ERR_FAIL_COND_V(_server, NetworkError::LOCKED); + OV_ERR_FAIL_COND_V(_socket == nullptr, NetworkError::UNCONFIGURED); + OV_ERR_FAIL_COND_V(!host.is_valid(), NetworkError::INVALID_PARAMETER); + + NetworkError err; + + if (!_socket->is_open()) { + NetworkResolver::Type ip_type = host.is_ipv4() ? NetworkResolver::Type::IPV4 : NetworkResolver::Type::IPV6; + err = _socket->open(NetworkSocket::Type::UDP, ip_type); + OV_ERR_FAIL_COND_V(err != NetworkError::OK, err); + _socket->set_blocking_enabled(false); + } + + err = _socket->connect_to_host(host, port); + + // I see no reason why we should get ERR_BUSY (wouldblock/eagain) here. + // This is UDP, so connect is only used to tell the OS to which socket + // it should deliver packets when multiple are bound on the same address/port. + if (err != NetworkError::OK) { + close(); + OV_ERR_FAIL_V_MSG(err, "Unable to connect"); + } + + _is_connected = true; + + peer_ip = host; + peer_port = port; + + // Flush any packet we might still have in queue. + _ring_buffer.clear(); + return NetworkError::OK; +} + +NetworkSocket::port_type UdpClient::get_bound_port() const { + NetworkSocket::port_type local_port; + _socket->get_socket_address(nullptr, &local_port); + return local_port; +} + +IpAddress UdpClient::get_last_packet_ip() const { + return _packet_ip; +} + +NetworkSocket::port_type UdpClient::get_last_packet_port() const { + return _packet_port; +} + +bool UdpClient::is_bound() const { + return _socket != nullptr && _socket->is_open(); +} + +bool UdpClient::is_connected() const { + return _is_connected; +} + +NetworkError UdpClient::wait() { + OV_ERR_FAIL_COND_V(_socket == nullptr, NetworkError::UNAVAILABLE); + return _socket->poll(NetworkSocket::PollType::IN, -1); +} + +NetworkError UdpClient::set_destination(HostnameAddress const& addr, NetworkSocket::port_type port) { + OV_ERR_FAIL_COND_V_MSG( + _is_connected, NetworkError::UNCONFIGURED, "Destination address cannot be set for connected sockets" + ); + peer_ip = addr.resolved_address(); + peer_port = port; + return NetworkError::OK; +} + +NetworkError UdpClient::store_packet(IpAddress p_ip, NetworkSocket::port_type p_port, std::span p_buf) { + uint16_t buffer_size = p_buf.size(); + std::span ipv6 = p_ip.get_ipv6(); + + if (_ring_buffer.capacity() - _ring_buffer.size() < + buffer_size + ipv6.size_bytes() + sizeof(p_port) + sizeof(buffer_size)) { + return NetworkError::OUT_OF_MEMORY; + } + _ring_buffer.append_range(ipv6); + _ring_buffer.append((uint8_t*)&p_port, sizeof(p_port)); + _ring_buffer.append((uint8_t*)&buffer_size, sizeof(buffer_size)); + _ring_buffer.append_range(p_buf); + ++_queue_count; + return NetworkError::OK; +} + +NetworkError UdpClient::_set_packet(std::span buffer) { + OV_ERR_FAIL_COND_V(_socket == nullptr, NetworkError::UNAVAILABLE); + OV_ERR_FAIL_COND_V(!peer_ip.is_valid(), NetworkError::UNCONFIGURED); + + NetworkError err; + int64_t sent = -1; + + if (!_socket->is_open()) { + NetworkResolver::Type ip_type = peer_ip.is_ipv4() ? NetworkResolver::Type::IPV4 : NetworkResolver::Type::IPV6; + err = _socket->open(NetworkSocket::Type::UDP, ip_type); + OV_ERR_FAIL_COND_V(err != NetworkError::OK, err); + _socket->set_blocking_enabled(false); + _socket->set_broadcasting_enabled(_is_broadcasting); + } + + do { + if (_is_connected && !_server) { + err = _socket->send(buffer.data(), buffer.size(), sent); + } else { + err = _socket->send_to(buffer.data(), buffer.size(), sent, peer_ip, peer_port); + } + if (err != NetworkError::OK) { + if (err != NetworkError::BUSY) { + return err; + } else if (!_is_blocking) { + return NetworkError::BUSY; + } + // Keep trying to send full packet + continue; + } + return NetworkError::OK; + + } while (sent != buffer.size()); + + return NetworkError::OK; +} + +NetworkError UdpClient::_get_packet(std::span& r_buffer) { + NetworkError err = _poll(); + if (err != NetworkError::OK) { + return err; + } + if (_queue_count == 0) { + return NetworkError::UNAVAILABLE; + } + + uint16_t size = 0; + std::array ipv6 {}; + _ring_buffer.read_buffer_to(ipv6.data(), sizeof(ipv6)); + _packet_ip.set_ipv6(ipv6); + _ring_buffer.read_buffer_to((uint8_t*)&_packet_port, sizeof(_packet_port)); + _ring_buffer.read_buffer_to((uint8_t*)&size, sizeof(size)); + _ring_buffer.read_buffer_to(_packet_buffer.data(), size); + --_queue_count; + r_buffer = { _packet_buffer.data(), size }; + + return NetworkError::OK; +} + +NetworkError UdpClient::join_multicast_group(IpAddress addr, std::string_view interface_name) { + OV_ERR_FAIL_COND_V(_server, NetworkError::LOCKED); + OV_ERR_FAIL_COND_V(_socket == nullptr, NetworkError::UNAVAILABLE); + OV_ERR_FAIL_COND_V(!addr.is_valid(), NetworkError::INVALID_PARAMETER); + + if (!_socket->is_open()) { + NetworkResolver::Type ip_type = addr.is_ipv4() ? NetworkResolver::Type::IPV4 : NetworkResolver::Type::IPV6; + NetworkError err = _socket->open(NetworkSocket::Type::UDP, ip_type); + OV_ERR_FAIL_COND_V(err != NetworkError::OK, err); + _socket->set_blocking_enabled(false); + _socket->set_broadcasting_enabled(_is_broadcasting); + } + return _socket->change_multicast_group(addr, interface_name, true); +} + +NetworkError UdpClient::leave_multicast_group(IpAddress addr, std::string_view interface_name) { + OV_ERR_FAIL_COND_V(_server, NetworkError::LOCKED); + OV_ERR_FAIL_COND_V(_socket == nullptr, NetworkError::UNAVAILABLE); + OV_ERR_FAIL_COND_V(!_socket->is_open(), NetworkError::UNCONFIGURED); + return _socket->change_multicast_group(addr, interface_name, false); +} + +bool UdpClient::is_broadcast_enabled() const { + return _is_broadcasting; +} + +void UdpClient::set_broadcast_enabled(bool enabled) { + OV_ERR_FAIL_COND(_server); + _is_broadcasting = enabled; + if (_socket && _socket->is_open()) { + _socket->set_broadcasting_enabled(enabled); + } +} + +bool UdpClient::is_blocking() const { + return _is_blocking; +} + +void UdpClient::set_blocking(bool blocking) { + _is_blocking = blocking; +} + +int64_t UdpClient::available_packets() const { + NetworkError err = const_cast(this)->_poll(); + if (err != NetworkError::OK) { + return -1; + } + + return _queue_count; +} + +int64_t UdpClient::max_packet_size() const { + return MAX_PACKET_SIZE; +} + +NetworkError UdpClient::_poll() { + OV_ERR_FAIL_COND_V(_socket == nullptr, NetworkError::UNAVAILABLE); + + if (!_socket->is_open()) { + return NetworkError::FAILED; + } + if (_server) { + return NetworkError::OK; // Handled by UDPServer. + } + + NetworkError err; + int64_t read; + IpAddress ip; + NetworkSocket::port_type port; + + while (true) { + if (_is_connected) { + err = _socket->receive(_receive_buffer.data(), sizeof(_receive_buffer), read); + ip = peer_ip; + port = peer_port; + } else { + err = _socket->receive_from(_receive_buffer.data(), sizeof(_receive_buffer), read, ip, port); + } + + if (err != NetworkError::OK) { + if (err == NetworkError::BUSY) { + break; + } + return err; + } + + err = store_packet(ip, port, { _receive_buffer.data(), static_cast(read) }); +#ifdef TOOLS_ENABLED + if (err != NetworkError::OK) { + Logger::warning("Buffer full, dropping packets!"); + } +#endif + } + + return NetworkError::OK; +} + +void UdpClient::_connect_shared_socket( // + std::shared_ptr p_sock, IpAddress p_ip, NetworkSocket::port_type p_port, UdpServer* p_server +) { + _server = p_server; + _is_connected = true; + _socket = std::move(p_sock); + peer_ip = p_ip; + peer_port = p_port; + _packet_ip = peer_ip; + _packet_port = peer_port; +} + +void UdpClient::_disconnect_shared_socket() { + _server = nullptr; + _socket = std::make_shared(); + close(); +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/UdpClient.hpp b/src/openvic-simulation/multiplayer/lowlevel/UdpClient.hpp new file mode 100644 index 000000000..cdf986caa --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/UdpClient.hpp @@ -0,0 +1,87 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "openvic-simulation/multiplayer/lowlevel/HostnameAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkSocket.hpp" +#include "openvic-simulation/multiplayer/lowlevel/PacketClient.hpp" +#include "openvic-simulation/types/RingBuffer.hpp" + +namespace OpenVic { + struct UdpServer; + struct ReliableUdpServer; + + struct UdpClient : PacketClient { + static constexpr size_t PACKET_BUFFER_SIZE = 65536; + + static constexpr size_t MAX_PACKET_SIZE = 512 * 3; + + UdpClient(); + ~UdpClient(); + + UdpClient(UdpClient&&) = default; + UdpClient& operator=(UdpClient&&) = default; + + void close(); + NetworkSocket::port_type get_bound_port() const; + bool is_connected() const; + + NetworkError connect_to(IpAddress const& host, NetworkSocket::port_type port); + NetworkError bind( // + NetworkSocket::port_type port, IpAddress const& address, size_t min_buffer_size = PACKET_BUFFER_SIZE + ); + IpAddress get_last_packet_ip() const; + NetworkSocket::port_type get_last_packet_port() const; + bool is_bound() const; + NetworkError wait(); + NetworkError set_destination(HostnameAddress const& addr, NetworkSocket::port_type port); + virtual NetworkError store_packet(IpAddress p_ip, NetworkSocket::port_type p_port, std::span p_buf); + + NetworkError join_multicast_group(IpAddress addr, std::string_view interface_name); + NetworkError leave_multicast_group(IpAddress addr, std::string_view interface_name); + + bool is_broadcast_enabled() const; + void set_broadcast_enabled(bool enabled); + + bool is_blocking() const; + void set_blocking(bool blocking); + + int64_t available_packets() const override; + int64_t max_packet_size() const override; + + protected: + NetworkError _set_packet(std::span buffer) override; + NetworkError _get_packet(std::span& r_set) override; + + NetworkError _poll(); + + bool _is_connected = false; + bool _is_blocking = true; + bool _is_broadcasting = false; + NetworkSocket::port_type _packet_port = 0; + NetworkSocket::port_type PROPERTY_ACCESS(peer_port, protected, 0); + int32_t _queue_count = 0; + IpAddress _packet_ip; + IpAddress PROPERTY_ACCESS(peer_ip, protected); + + std::shared_ptr _socket; + UdpServer* _server = nullptr; + RingBuffer _ring_buffer; + + std::array _receive_buffer; + std::array _packet_buffer; + + friend struct UdpServer; + friend struct ReliableUdpServer; + void _connect_shared_socket( // + std::shared_ptr p_sock, IpAddress p_ip, NetworkSocket::port_type p_port, UdpServer* p_server + ); + void _disconnect_shared_socket(); + }; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/UdpServer.cpp b/src/openvic-simulation/multiplayer/lowlevel/UdpServer.cpp new file mode 100644 index 000000000..66b9242d2 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/UdpServer.cpp @@ -0,0 +1,114 @@ + +#include "UdpServer.hpp" + +#include + +#include + +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkSocket.hpp" +#include "openvic-simulation/multiplayer/lowlevel/PacketServer.hpp" +#include "openvic-simulation/multiplayer/lowlevel/UdpClient.hpp" +#include "openvic-simulation/utility/Containers.hpp" + +using namespace OpenVic; + +bool UdpServer::is_connection_available() const { + if (!PacketServer::is_connection_available()) { + return false; + } + return _pending_clients.size() > 0; +} + +void UdpServer::close() { + PacketServer::close(); + for (ClientRef& client : _clients) { + client.client->_disconnect_shared_socket(); + } + for (ClientRef& client : _pending_clients) { + client.client->_disconnect_shared_socket(); + } + _clients.clear(); + _pending_clients.clear(); +} + +UdpClient* UdpServer::_take_next_client() { + if (!is_connection_available()) { + return nullptr; + }; + + ClientRef& client = _pending_clients.front(); + _clients.emplace_back(client.client.release(), client.ip, client.port); + _pending_clients.pop_front(); + return _clients.back().client; +} + +UdpClient* UdpServer::_create_client() { + return new UdpClient(); +} + +void UdpServer::remove_client(IpAddress ip, NetworkSocket::port_type port) { + const ClientRef client { nullptr, ip, port }; + if (memory::vector>::iterator it = ranges::find(_clients, client); it != _clients.end()) { + _clients.erase(it); + } +} + +NetworkError UdpServer::poll() { + OV_ERR_FAIL_COND_V(_socket == nullptr, NetworkError::UNAVAILABLE); + OV_ERR_FAIL_COND_V(!_socket->is_open(), NetworkError::UNCONFIGURED); + + NetworkError err; + int64_t read; + IpAddress ip; + NetworkSocket::port_type port; + while (true) { + err = _socket->receive_from(_receive_buffer.data(), sizeof(_receive_buffer), read, ip, port); + if (err != NetworkError::OK) { + if (err == NetworkError::BUSY) { + break; + } + return err; + } + + ClientRef p; + p.ip = ip; + p.port = port; + UdpClient* client_ptr = nullptr; + if (decltype(_clients)::iterator it = ranges::find(_clients, p); it != _clients.end()) { + client_ptr = it->client; + } else if (decltype(_pending_clients)::iterator pend_it = ranges::find(_pending_clients, ClientRef { p }); + pend_it != _pending_clients.end()) { + client_ptr = pend_it->client.get(); + } + + if (client_ptr) { + client_ptr->store_packet(ip, port, std::span { _receive_buffer.data(), static_cast(read) }); + } else { + if (_pending_clients.size() >= _max_pending_clients) { + // Drop connection. + continue; + } + // It's a new client, add it to the pending list. + ClientRef ref { memory::unique_ptr(_create_client()), ip, port }; + ref.client->_connect_shared_socket(_socket, ip, port, this); + ref.client->store_packet(ip, port, std::span { _receive_buffer.data(), static_cast(read) }); + _pending_clients.push_back(std::move(ref)); + } + } + return NetworkError::OK; +} + +void UdpServer::set_max_pending_clients(uint32_t clients) { + OV_ERR_FAIL_COND_MSG( + clients < 0, "Max pending connections value must be a positive number (0 means refuse new connections)." + ); + _max_pending_clients = clients; + if (clients > _pending_clients.size()) { + _pending_clients.erase(_pending_clients.begin() + clients + 1, _pending_clients.end()); + } +} + +uint32_t UdpServer::get_max_pending_clients() const { + return _max_pending_clients; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/UdpServer.hpp b/src/openvic-simulation/multiplayer/lowlevel/UdpServer.hpp new file mode 100644 index 000000000..02eeda5c5 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/UdpServer.hpp @@ -0,0 +1,60 @@ +#pragma once + +#include +#include +#include + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkSocket.hpp" +#include "openvic-simulation/multiplayer/lowlevel/PacketServer.hpp" +#include "openvic-simulation/multiplayer/lowlevel/UdpClient.hpp" +#include "openvic-simulation/utility/Containers.hpp" +#include "openvic-simulation/utility/Deque.hpp" + +namespace OpenVic { + struct UdpServer : PacketServer { + bool is_connection_available() const override; + void close() override; + NetworkError poll() override; + + void remove_client(IpAddress ip, NetworkSocket::port_type port); + + void set_max_pending_clients(uint32_t clients); + uint32_t get_max_pending_clients() const; + + protected: + UdpClient* _take_next_client() override; + virtual UdpClient* _create_client(); + + template + struct ClientRef { + static constexpr std::bool_constant is_owner {}; + + std::conditional_t, UdpClient*> client = nullptr; + IpAddress ip; + NetworkSocket::port_type port = 0; + + ClientRef() = default; + ClientRef(decltype(client)&& client, IpAddress ip, NetworkSocket::port_type port = 0) + : client { std::move(client) }, ip { ip }, port { port } {} + ClientRef(ClientRef const& ref) : client { ref.client }, ip { ref.ip }, port { ref.port } {} + ClientRef(ClientRef const&) = default; + ClientRef& operator=(ClientRef const&) = default; + ClientRef(ClientRef&&) = default; + ClientRef& operator=(ClientRef&&) = default; + + template + bool operator==(ClientRef const& p_other) const { + return (ip == p_other.ip && port == p_other.port); + } + }; + + uint32_t _max_pending_clients = 16; + + memory::vector> _clients; + OpenVic::utility::deque> _pending_clients; + + std::array _receive_buffer; + }; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/unix/UnixNetworkResolver.cpp b/src/openvic-simulation/multiplayer/lowlevel/unix/UnixNetworkResolver.cpp new file mode 100644 index 000000000..06349a50a --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/unix/UnixNetworkResolver.cpp @@ -0,0 +1,154 @@ +#if defined(__unix__) || defined(__unix) || (defined(__APPLE__) && defined(__MACH__)) + +#include "UnixNetworkResolver.hpp" + +#include +#include +#include + +#include + +#include + +#include + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/types/StackString.hpp" +#include "openvic-simulation/utility/Containers.hpp" +#include "openvic-simulation/utility/StringUtils.hpp" + +#ifdef __FreeBSD__ +#include +#endif +#include + +#include +#include + +#ifdef __FreeBSD__ +#include +#endif + +#include // Order is important on OpenBSD, leave as last. + +using namespace OpenVic; + +static IpAddress _sockaddr2ip(struct sockaddr* p_addr) { + IpAddress ip; + + if (p_addr->sa_family == AF_INET) { + struct sockaddr_in* addr = (struct sockaddr_in*)p_addr; + ip.set_ipv4(std::bit_cast>(&addr->sin_addr)); + } else if (p_addr->sa_family == AF_INET6) { + struct sockaddr_in6* addr6 = (struct sockaddr_in6*)p_addr; + ip.set_ipv6(addr6->sin6_addr.s6_addr); + } + + return ip; +} + +UnixNetworkResolver UnixNetworkResolver::_singleton {}; + +string_map_t UnixNetworkResolver::get_local_interfaces() const { + string_map_t result; + + struct ifaddrs* ifAddrStruct = nullptr; + struct ifaddrs* ifa = nullptr; + int family; + + getifaddrs(&ifAddrStruct); + + for (ifa = ifAddrStruct; ifa != nullptr; ifa = ifa->ifa_next) { + if (!ifa->ifa_addr) { + continue; + } + + family = ifa->ifa_addr->sa_family; + + if (family != AF_INET && family != AF_INET6) { + continue; + } + + string_map_t::iterator it = result.find(ifa->ifa_name); + if (it == result.end()) { + InterfaceInfo info; + info.name = ifa->ifa_name; + info.name_friendly = ifa->ifa_name; + + struct stack_string : StackString::max())> { + using StackString::_array; + using StackString::_string_size; + using StackString::StackString; + } str {}; + std::to_chars_result to_chars = + StringUtils::to_chars(str._array.data(), str._array.data() + str.array_length, if_nametoindex(ifa->ifa_name)); + str._string_size = to_chars.ptr - str.data(); + + info.index = str; + auto pair = result.insert_or_assign(ifa->ifa_name, info); + OV_ERR_CONTINUE(!pair.second); + it = pair.first; + } + + InterfaceInfo& info = it.value(); + info.ip_addresses.push_back(_sockaddr2ip(ifa->ifa_addr)); + } + + if (ifAddrStruct != nullptr) { + freeifaddrs(ifAddrStruct); + } + + return result; +} + +memory::vector UnixNetworkResolver::_resolve_hostname(std::string_view p_hostname, Type p_type) const { + struct addrinfo hints; // NOLINT + struct addrinfo* result = nullptr; + + std::memset(&hints, 0, sizeof(struct addrinfo)); + if (p_type == Type::IPV4) { + hints.ai_family = AF_INET; + } else if (p_type == Type::IPV6) { + hints.ai_family = AF_INET6; + hints.ai_flags = 0; + } else { + hints.ai_family = AF_UNSPEC; + hints.ai_flags = AI_ADDRCONFIG; + } + hints.ai_flags &= ~AI_NUMERICHOST; + + int s = getaddrinfo(p_hostname.data(), nullptr, &hints, &result); + if (s != 0) { + Logger::info("getaddrinfo failed! Cannot resolve hostname."); + return {}; + } + + if (result == nullptr || result->ai_addr == nullptr) { + Logger::info("Invalid response from getaddrinfo."); + if (result) { + freeaddrinfo(result); + } + return {}; + } + + struct addrinfo* next = result; + + memory::vector result_addrs; + do { + if (next->ai_addr == nullptr) { + next = next->ai_next; + continue; + } + IpAddress ip = _sockaddr2ip(next->ai_addr); + if (ip.is_valid() && ranges::find(result_addrs, ip) == result_addrs.end()) { + result_addrs.push_back(ip); + } + next = next->ai_next; + } while (next); + + freeaddrinfo(result); + + return result_addrs; +} + +#endif diff --git a/src/openvic-simulation/multiplayer/lowlevel/unix/UnixNetworkResolver.hpp b/src/openvic-simulation/multiplayer/lowlevel/unix/UnixNetworkResolver.hpp new file mode 100644 index 000000000..c2923b81f --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/unix/UnixNetworkResolver.hpp @@ -0,0 +1,32 @@ +#pragma once + +#if !(defined(__unix__) || defined(__unix) || (defined(__APPLE__) && defined(__MACH__))) +#error "UnixNetworkResolver.hpp should only be included on unix systems" +#endif + +#include "openvic-simulation/multiplayer/lowlevel/NetworkResolverBase.hpp" +#include "openvic-simulation/types/OrderedContainers.hpp" +#include "openvic-simulation/utility/Containers.hpp" + +namespace OpenVic { + struct UnixNetworkResolver final : NetworkResolverBase { + static constexpr Provider provider_value = Provider::UNIX; + + static UnixNetworkResolver& singleton() { + return _singleton; + } + + OpenVic::string_map_t get_local_interfaces() const override; + + Provider provider() const override { + return provider_value; + } + + private: + friend NetworkResolverBase::ResolveHandler; + + memory::vector _resolve_hostname(std::string_view p_hostname, Type p_type = Type::ANY) const override; + + static UnixNetworkResolver _singleton; + }; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/unix/UnixSocket.cpp b/src/openvic-simulation/multiplayer/lowlevel/unix/UnixSocket.cpp new file mode 100644 index 000000000..a64de229b --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/unix/UnixSocket.cpp @@ -0,0 +1,564 @@ +#if defined(__unix__) || defined(__unix) || (defined(__APPLE__) && defined(__MACH__)) +#include "UnixSocket.hpp" + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkResolver.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkSocketBase.hpp" +#include "openvic-simulation/types/OrderedContainers.hpp" +#include "openvic-simulation/utility/Containers.hpp" +#include "openvic-simulation/utility/ErrorMacros.hpp" +#include "openvic-simulation/utility/StringUtils.hpp" + +#include +#include +#include +#include +#include + +// BSD calls this flag IPV6_JOIN_GROUP +#if !defined(IPV6_ADD_MEMBERSHIP) && defined(IPV6_JOIN_GROUP) +#define IPV6_ADD_MEMBERSHIP IPV6_JOIN_GROUP +#endif +#if !defined(IPV6_DROP_MEMBERSHIP) && defined(IPV6_LEAVE_GROUP) +#define IPV6_DROP_MEMBERSHIP IPV6_LEAVE_GROUP +#endif + +using namespace OpenVic; + +UnixSocket::UnixSocket() {} + +UnixSocket::~UnixSocket() { + close(); +} + +NetworkError UnixSocket::_get_socket_error() const { + switch (errno) { + case EISCONN: return NetworkError::IS_CONNECTED; + case EINPROGRESS: + case EALREADY: return NetworkError::IN_PROGRESS; +#if EAGAIN != EWOULDBLOCK + case EAGAIN: +#endif + case EWOULDBLOCK: return NetworkError::WOULD_BLOCK; + case EADDRINUSE: + case EINVAL: + case EADDRNOTAVAIL: return NetworkError::ADDRESS_INVALID_OR_UNAVAILABLE; + case EACCES: return NetworkError::UNAUTHORIZED; + case ENOBUFS: return NetworkError::BUFFER_TOO_SMALL; + default: // + Logger::info("Socket error: ", strerror(errno)); + return NetworkError::OTHER; + } +} + +void UnixSocket::_set_ip_port(struct sockaddr_storage* addr, IpAddress* r_ip, port_type* r_port) { + if (addr->ss_family == AF_INET) { + struct sockaddr_in* addr4 = (struct sockaddr_in*)addr; + if (r_ip) { + r_ip->set_ipv4(std::bit_cast>(&addr4->sin_addr.s_addr)); + } + if (r_port) { + *r_port = ntohs(addr4->sin_port); + } + } else if (addr->ss_family == AF_INET6) { + struct sockaddr_in6* addr6 = (struct sockaddr_in6*)addr; + if (r_ip) { + r_ip->set_ipv6(addr6->sin6_addr.s6_addr); + } + if (r_port) { + *r_port = ntohs(addr6->sin6_port); + } + } +} + +bool UnixSocket::_can_use_ip(IpAddress const& ip, const bool for_bind) const { + if (for_bind && !(ip.is_valid() || ip.is_wildcard())) { + return false; + } else if (!for_bind && !ip.is_valid()) { + return false; + } + // Check if socket support this IP type. + NetworkResolver::Type type = ip.is_ipv4() ? NetworkResolver::Type::IPV4 : NetworkResolver::Type::IPV6; + return !(_ip_type != NetworkResolver::Type::ANY && !ip.is_wildcard() && _ip_type != type); +} + +void UnixSocket::setup() {} + +void UnixSocket::cleanup() {} + +NetworkError UnixSocket::open(Type type, NetworkResolver::Type& ip_type) { + OV_ERR_FAIL_COND_V(is_open(), NetworkError::ALREADY_OPEN); + OV_ERR_FAIL_COND_V( + ip_type > NetworkResolver::Type::ANY || ip_type < NetworkResolver::Type::NONE, NetworkError::INVALID_PARAMETER + ); + +#if defined(__OpenBSD__) + // OpenBSD does not support dual stacking, fallback to IPv4 only. + if (ip_type == NetworkResolver::Type::ANY) { + ip_type = NetworkResolver::Type::IPV4; + } +#endif + + int family = ip_type == NetworkResolver::Type::IPV4 ? AF_INET : AF_INET6; + int protocol = type == Type::TCP ? IPPROTO_TCP : IPPROTO_UDP; + int socket_type = type == Type::TCP ? SOCK_STREAM : SOCK_DGRAM; + _sock = socket(family, socket_type, protocol); + + if (_sock == -1 && ip_type == NetworkResolver::Type::ANY) { + // Careful here, changing the referenced parameter so the caller knows that we are using an IPv4 socket + // in place of a dual stack one, and further calls to _set_sock_addr will work as expected. + ip_type = NetworkResolver::Type::IPV4; + family = AF_INET; + _sock = socket(family, socket_type, protocol); + } + + OV_ERR_FAIL_COND_V(_sock == -1, NetworkError::SOCKET_ERROR); + _ip_type = ip_type; + + if (family == AF_INET6) { + // Select IPv4 over IPv6 mapping. + set_ipv6_only_enabled(ip_type != NetworkResolver::Type::ANY); + } + + if (protocol == IPPROTO_UDP) { + // Make sure to disable broadcasting for UDP sockets. + // Depending on the OS, this option might or might not be enabled by default. Let's normalize it. + set_broadcasting_enabled(false); + } + + _is_stream = type == Type::TCP; + + // Disable descriptor sharing with subprocesses. + _set_close_exec_enabled(true); + +#if defined(SO_NOSIGPIPE) + // Disable SIGPIPE (should only be relevant to stream sockets, but seems to affect UDP too on iOS). + int par = 1; + if (setsockopt(_sock, SOL_SOCKET, SO_NOSIGPIPE, &par, sizeof(int)) != 0) { + Logger::info("Unable to turn off SIGPIPE on socket."); + } +#endif + return NetworkError::OK; +} + +void UnixSocket::close() { + if (_sock != -1) { + ::close(_sock); + } + + _sock = -1; + _ip_type = NetworkResolver::Type::NONE; + _is_stream = false; +} + +NetworkError UnixSocket::bind(IpAddress const& addr, port_type p_port) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + OV_ERR_FAIL_COND_V(!_can_use_ip(addr, true), NetworkError::INVALID_PARAMETER); + + sockaddr_storage addr_store; // NOLINT + size_t addr_size = _set_addr_storage(&addr_store, addr, p_port, _ip_type); + + if (::bind(_sock, (struct sockaddr*)&addr_store, addr_size) != 0) { + NetworkError err = _get_socket_error(); + Logger::info("Failed to bind socket. Error: ", strerror(errno)); + close(); + return err; + } + + return NetworkError::OK; +} + +NetworkError UnixSocket::listen(int max_pending) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + + if (::listen(_sock, max_pending) != 0) { + NetworkError err = _get_socket_error(); + Logger::info("Failed to listen from socket. Error: ", strerror(errno)); + close(); + return err; + } + + return NetworkError::OK; +} + +NetworkError UnixSocket::connect_to_host(IpAddress const& host, port_type port) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + OV_ERR_FAIL_COND_V(!_can_use_ip(host, false), NetworkError::INVALID_PARAMETER); + + struct sockaddr_storage addr; // NOLINT + size_t addr_size = _set_addr_storage(&addr, host, port, _ip_type); + + if (::connect(_sock, (struct sockaddr*)&addr, addr_size) != 0) { + NetworkError err = _get_socket_error(); + + switch (err) { + using enum NetworkError; + // We are already connected. + case IS_CONNECTED: return NetworkError::OK; + // Still waiting to connect, try again in a while. + case WOULD_BLOCK: + case IN_PROGRESS: return NetworkError::BUSY; + default: + Logger::info("Connection to remote host failed. Error: ", strerror(errno)); + close(); + return err; + } + } + + return NetworkError::OK; +} + +NetworkError UnixSocket::poll(PollType type, int timeout) const { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + + struct pollfd pfd { .fd = _sock, .events = POLLIN, .revents = 0 }; + + switch (type) { + using enum PollType; + case IN: pfd.events = POLLIN; break; + case OUT: pfd.events = POLLOUT; break; + case IN_OUT: pfd.events = POLLOUT | POLLIN; break; + } + + int ret = ::poll(&pfd, 1, timeout); + + if (ret < 0 || pfd.revents & POLLERR) { + NetworkError err = _get_socket_error(); + Logger::info("Error when polling socket. Error ", strerror(errno)); + return err; + } + + if (ret == 0) { + return NetworkError::BUSY; + } + + return NetworkError::OK; +} + +NetworkError UnixSocket::receive(uint8_t* buffer, size_t p_len, int64_t& r_read) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + + r_read = ::recv(_sock, buffer, p_len, 0); + + if (r_read < 0) { + NetworkError err = _get_socket_error(); + if (err == NetworkError::WOULD_BLOCK) { + return NetworkError::BUSY; + } + return err; + } + + return NetworkError::OK; +} + +NetworkError UnixSocket::receive_from( // + uint8_t* buffer, size_t p_len, int64_t& r_read, IpAddress& r_ip, port_type& r_port, bool peek +) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + + struct sockaddr_storage from; // NOLINT + socklen_t socket_length = sizeof(from); + std::memset(&from, 0, socket_length); + + r_read = ::recvfrom(_sock, buffer, p_len, peek ? MSG_PEEK : 0, (struct sockaddr*)&from, &socket_length); + + if (r_read < 0) { + NetworkError err = _get_socket_error(); + if (err == NetworkError::WOULD_BLOCK) { + return NetworkError::BUSY; + } + return err; + } + + if (from.ss_family == AF_INET) { + struct sockaddr_in* sin_from = (struct sockaddr_in*)&from; + r_ip.set_ipv4(std::bit_cast>(&sin_from->sin_addr)); + r_port = ntohs(sin_from->sin_port); + } else if (from.ss_family == AF_INET6) { + struct sockaddr_in6* s6_from = (struct sockaddr_in6*)&from; + r_ip.set_ipv6(std::bit_cast>(&s6_from->sin6_addr)); + r_port = ntohs(s6_from->sin6_port); + } else { + // Unsupported socket family, should never happen. + OV_ERR_FAIL_V(NetworkError::UNSUPPORTED); + } + + return NetworkError::OK; +} + +NetworkError UnixSocket::send(const uint8_t* buffer, size_t p_len, int64_t& r_sent) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + + int flags = 0; +#ifdef MSG_NOSIGNAL + if (_is_stream) { + flags = MSG_NOSIGNAL; + } +#endif + r_sent = ::send(_sock, buffer, p_len, flags); + + if (r_sent < 0) { + NetworkError err = _get_socket_error(); + if (err == NetworkError::WOULD_BLOCK) { + return NetworkError::BUSY; + } + return err; + } + + return NetworkError::OK; +} + +NetworkError UnixSocket::send_to(const uint8_t* buffer, size_t p_len, int64_t& r_sent, IpAddress const& ip, port_type port) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + + struct sockaddr_storage addr; // NOLINT + size_t addr_size = _set_addr_storage(&addr, ip, port, _ip_type); + r_sent = ::sendto(_sock, buffer, p_len, 0, (struct sockaddr*)&addr, addr_size); + + if (r_sent < 0) { + NetworkError err = _get_socket_error(); + if (err == NetworkError::WOULD_BLOCK) { + return NetworkError::BUSY; + } + return err; + } + + return NetworkError::OK; +} + +UnixSocket* UnixSocket::_accept(IpAddress& r_ip, port_type& r_port) { + OV_ERR_FAIL_COND_V(!is_open(), nullptr); + + struct sockaddr_storage their_addr; // NOLINT + socklen_t size = sizeof(their_addr); + int fd = ::accept(_sock, (struct sockaddr*)&their_addr, &size); + if (fd == -1) { + NetworkError err = _get_socket_error(); + Logger::info("Error when accepting socket connection. Error: ", strerror(errno)); + return nullptr; + } + + _set_ip_port(&their_addr, &r_ip, &r_port); + + UnixSocket* ns = new UnixSocket(); + ns->_sock = fd; + ns->_ip_type = _ip_type; + ns->_is_stream = _is_stream; + // Disable descriptor sharing with subprocesses. + ns->_set_close_exec_enabled(true); + ns->set_blocking_enabled(false); + return ns; +} + +bool UnixSocket::is_open() const { + return _sock != -1; +} + +int UnixSocket::available_bytes() const { + OV_ERR_FAIL_COND_V(!is_open(), -1); + + int len; + int ret = ioctl(_sock, FIONREAD, &len); + if (ret == -1) { + _get_socket_error(); + Logger::info("Error when checking available bytes on socket. Error: ", strerror(errno)); + return -1; + } + return len; +} + +NetworkError UnixSocket::get_socket_address(IpAddress* r_ip, port_type* r_port) const { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::NOT_OPEN); + + struct sockaddr_storage saddr; // NOLINT + socklen_t len = sizeof(saddr); + if (getsockname(_sock, (struct sockaddr*)&saddr, &len) != 0) { + NetworkError err = _get_socket_error(); + Logger::info("Error when reading local socket address. Error: ", strerror(errno)); + return err; + } + _set_ip_port(&saddr, r_ip, r_port); + return NetworkError::OK; +} + +NetworkError UnixSocket::set_broadcasting_enabled(bool enabled) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + // IPv6 has no broadcast support. + if (_ip_type == NetworkResolver::Type::IPV6) { + return NetworkError::UNSUPPORTED; + } + + int par = enabled ? 1 : 0; + if (setsockopt(_sock, SOL_SOCKET, SO_BROADCAST, &par, sizeof(int)) != 0) { + Logger::warning("Unable to change broadcast setting."); + return NetworkError::BROADCAST_CHANGE_FAILED; + } + return NetworkError::OK; +} + +void UnixSocket::set_blocking_enabled(bool enabled) { + OV_ERR_FAIL_COND(!is_open()); + + int ret = 0; + int opts = fcntl(_sock, F_GETFL); + if (enabled) { + ret = fcntl(_sock, F_SETFL, opts & ~O_NONBLOCK); + } else { + ret = fcntl(_sock, F_SETFL, opts | O_NONBLOCK); + } + + if (ret != 0) { + Logger::warning("Unable to change non-block mode."); + } +} + +void UnixSocket::set_ipv6_only_enabled(bool enabled) { + OV_ERR_FAIL_COND(!is_open()); + // This option is only available in IPv6 sockets. + OV_ERR_FAIL_COND(_ip_type == NetworkResolver::Type::IPV4); + + int par = enabled ? 1 : 0; + if (setsockopt(_sock, IPPROTO_IPV6, IPV6_V6ONLY, &par, sizeof(int)) != 0) { + Logger::warning("Unable to change IPv4 address mapping over IPv6 option."); + } +} + +void UnixSocket::set_tcp_no_delay_enabled(bool enabled) { + OV_ERR_FAIL_COND(!is_open()); + OV_ERR_FAIL_COND(!_is_stream); // Not TCP. + + int par = enabled ? 1 : 0; + if (setsockopt(_sock, IPPROTO_TCP, TCP_NODELAY, &par, sizeof(int)) < 0) { + Logger::warning("Unable to set TCP no delay option."); + } +} + +void UnixSocket::set_reuse_address_enabled(bool enabled) { + OV_ERR_FAIL_COND(!is_open()); + + int par = enabled ? 1 : 0; + if (setsockopt(_sock, SOL_SOCKET, SO_REUSEADDR, &par, sizeof(int)) < 0) { + Logger::warning("Unable to set socket REUSEADDR option."); + } +} + +NetworkError UnixSocket::change_multicast_group(IpAddress const& multi_address, std::string_view if_name, bool add) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + OV_ERR_FAIL_COND_V(!_can_use_ip(multi_address, false), NetworkError::INVALID_PARAMETER); + + // Need to force level and af_family to IP(v4) when using dual stacking and provided multicast group is IPv4. + NetworkResolver::Type type = + _ip_type == NetworkResolver::Type::ANY && multi_address.is_ipv4() ? NetworkResolver::Type::IPV4 : _ip_type; + // This needs to be the proper level for the multicast group, no matter if the socket is dual stacking. + int level = type == NetworkResolver::Type::IPV4 ? IPPROTO_IP : IPPROTO_IPV6; + int ret = -1; + + IpAddress if_ip; + uint32_t if_v6id = 0; + OpenVic::string_map_t if_info = NetworkResolver::singleton().get_local_interfaces(); + for (std::pair const& pair : if_info) { + NetworkResolver::InterfaceInfo const& c = pair.second; + if (c.name != if_name) { + continue; + } + + std::from_chars_result from_chars = + StringUtils::from_chars(c.index.data(), c.index.data() + c.index.size(), if_v6id); + OV_ERR_FAIL_COND_V(from_chars.ec != std::errc {}, NetworkError::SOCKET_ERROR); + if (type == NetworkResolver::Type::IPV6) { + break; // IPv6 uses index. + } + + for (IpAddress const& F : c.ip_addresses) { + if (!F.is_ipv4()) { + continue; // Wrong IP type. + } + if_ip = F; + break; + } + break; + } + + if (level == IPPROTO_IP) { + OV_ERR_FAIL_COND_V(!if_ip.is_valid(), NetworkError::INVALID_PARAMETER); + struct ip_mreq greq; // NOLINT + int sock_opt = add ? IP_ADD_MEMBERSHIP : IP_DROP_MEMBERSHIP; + std::memcpy(&greq.imr_multiaddr, multi_address.get_ipv4().data(), 4); + std::memcpy(&greq.imr_interface, if_ip.get_ipv4().data(), 4); + ret = setsockopt(_sock, level, sock_opt, (const char*)&greq, sizeof(greq)); + } else { + struct ipv6_mreq greq; // NOLINT + int sock_opt = add ? IPV6_ADD_MEMBERSHIP : IPV6_DROP_MEMBERSHIP; + std::memcpy(&greq.ipv6mr_multiaddr, multi_address.get_ipv6().data(), 16); + greq.ipv6mr_interface = if_v6id; + ret = setsockopt(_sock, level, sock_opt, (const char*)&greq, sizeof(greq)); + } + + if (ret == -1) { + return _get_socket_error(); + } + + return NetworkError::OK; +} + +size_t UnixSocket::_set_addr_storage( // + sockaddr_storage* addr, IpAddress const& ip, port_type port, NetworkResolver::Type ip_type +) { + std::memset(addr, 0, sizeof(struct sockaddr_storage)); + if (ip_type == NetworkResolver::Type::IPV6 || ip_type == NetworkResolver::Type::ANY) { // IPv6 socket. + + // IPv6 only socket with IPv4 address. + OV_ERR_FAIL_COND_V(!ip.is_wildcard() && ip_type == NetworkResolver::Type::IPV6 && ip.is_ipv4(), 0); + + struct sockaddr_in6* addr6 = (struct sockaddr_in6*)addr; + addr6->sin6_family = AF_INET6; + addr6->sin6_port = htons(port); + if (ip.is_valid()) { + std::memcpy(&addr6->sin6_addr.s6_addr, ip.get_ipv6().data(), 16); + } else { + addr6->sin6_addr = in6addr_any; + } + return sizeof(sockaddr_in6); + } else { // IPv4 socket. + + // IPv4 socket with IPv6 address. + OV_ERR_FAIL_COND_V(!ip.is_wildcard() && !ip.is_ipv4(), 0); + + struct sockaddr_in* addr4 = (struct sockaddr_in*)addr; + addr4->sin_family = AF_INET; + addr4->sin_port = htons(port); // Short, network byte order. + + if (ip.is_valid()) { + std::memcpy(&addr4->sin_addr.s_addr, ip.get_ipv4().data(), 4); + } else { + addr4->sin_addr.s_addr = INADDR_ANY; + } + + return sizeof(sockaddr_in); + } +} + +void UnixSocket::_set_close_exec_enabled(bool enabled) { + // Enable close on exec to avoid sharing with subprocesses. Off by default on Windows. + int opts = fcntl(_sock, F_GETFD); + fcntl(_sock, F_SETFD, opts | FD_CLOEXEC); +} + +#endif diff --git a/src/openvic-simulation/multiplayer/lowlevel/unix/UnixSocket.hpp b/src/openvic-simulation/multiplayer/lowlevel/unix/UnixSocket.hpp new file mode 100644 index 000000000..3a3b33aad --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/unix/UnixSocket.hpp @@ -0,0 +1,70 @@ +#pragma once + +#if !(defined(__unix__) || defined(__unix) || (defined(__APPLE__) && defined(__MACH__))) +#error "UnixSocket.hpp should only be included on unix systems" +#endif + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkResolver.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkSocketBase.hpp" + +struct sockaddr_storage; + +namespace OpenVic { + struct UnixSocket final : NetworkSocketBase { + static constexpr NetworkResolver::Provider provider_value = NetworkResolver::Provider::UNIX; + + UnixSocket(); + ~UnixSocket() override; + + NetworkError open(Type p_type, NetworkResolver::Type& ip_type) override; + void close() override; + NetworkError bind(IpAddress const& p_addr, port_type p_port) override; + NetworkError listen(int p_max_pending) override; + NetworkError connect_to_host(IpAddress const& p_addr, port_type p_port) override; + NetworkError poll(PollType p_type, int timeout) const override; + NetworkError receive(uint8_t* p_buffer, size_t p_len, int64_t& r_read) override; + NetworkError receive_from( // + uint8_t* p_buffer, size_t p_len, int64_t& r_read, IpAddress& r_ip, port_type& r_port, bool p_peek = false + ) override; + NetworkError send(const uint8_t* p_buffer, size_t p_len, int64_t& r_sent) override; + NetworkError + send_to(const uint8_t* p_buffer, size_t p_len, int64_t& r_sent, IpAddress const& p_ip, port_type p_port) override; + + bool is_open() const override; + int available_bytes() const override; + NetworkError get_socket_address(IpAddress* r_ip, port_type* r_port) const override; + + // Returns OK if the socket option has been set successfully + NetworkError set_broadcasting_enabled(bool p_enabled) override; + void set_blocking_enabled(bool p_enabled) override; + void set_ipv6_only_enabled(bool p_enabled) override; + void set_tcp_no_delay_enabled(bool p_enabled) override; + void set_reuse_address_enabled(bool p_enabled) override; + NetworkError change_multicast_group(IpAddress const& p_multi_address, std::string_view p_if_name, bool add) override; + + NetworkResolver::Provider provider() const override { + return provider_value; + } + + static void setup(); + static void cleanup(); + + protected: + UnixSocket* _accept(IpAddress& r_ip, port_type& r_port) override; + + private: + static size_t _set_addr_storage( // + sockaddr_storage* p_addr, IpAddress const& p_ip, port_type p_port, NetworkResolver::Type p_ip_type + ); + static void _set_ip_port(sockaddr_storage* addr, IpAddress* r_ip, port_type* r_port); + + NetworkError _get_socket_error() const; + bool _can_use_ip(IpAddress const& p_ip, const bool p_for_bind) const; + void _set_close_exec_enabled(bool p_enabled); + + int32_t _sock = -1; + NetworkResolver::Type _ip_type = NetworkResolver::Type::NONE; + bool _is_stream = false; + }; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/windows/WindowsNetworkResolver.cpp b/src/openvic-simulation/multiplayer/lowlevel/windows/WindowsNetworkResolver.cpp new file mode 100644 index 000000000..b24cf6354 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/windows/WindowsNetworkResolver.cpp @@ -0,0 +1,183 @@ +#ifdef _WIN32 +#include "WindowsNetworkResolver.hpp" + +#include +#include + +#include +#include + +#include +#include + +#include + +#pragma comment(lib, "Ws2_32.lib") +#pragma comment(lib, "Iphlpapi.lib") + +#define WIN32_LEAN_AND_MEAN +#include +#include +// +#include +#include +#undef WIN32_LEAN_AND_MEAN +// Thank You Microsoft +#undef SOCKET_ERROR +#undef IN +#undef OUT + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/windows/WindowsSocket.hpp" +#include "openvic-simulation/types/OrderedContainers.hpp" +#include "openvic-simulation/types/StackString.hpp" +#include "openvic-simulation/utility/Containers.hpp" +#include "openvic-simulation/utility/StringUtils.hpp" + +using namespace OpenVic; + +static IpAddress _sockaddr2ip(struct sockaddr* p_addr) { + IpAddress ip; + + if (p_addr->sa_family == AF_INET) { + struct sockaddr_in* addr = (struct sockaddr_in*)p_addr; + ip.set_ipv4(std::bit_cast>(&addr->sin_addr)); + } else if (p_addr->sa_family == AF_INET6) { + struct sockaddr_in6* addr6 = (struct sockaddr_in6*)p_addr; + ip.set_ipv6(addr6->sin6_addr.s6_addr); + } + + return ip; +} + +WindowsNetworkResolver WindowsNetworkResolver::_singleton {}; + +WindowsNetworkResolver::WindowsNetworkResolver() { + WindowsSocket::setup(); +} + +WindowsNetworkResolver::~WindowsNetworkResolver() { + WindowsSocket::cleanup(); +} + +string_map_t WindowsNetworkResolver::get_local_interfaces() const { + ULONG buf_size = 1024; + IP_ADAPTER_ADDRESSES* addrs; + + while (true) { + addrs = (IP_ADAPTER_ADDRESSES*)malloc(buf_size); + int err = GetAdaptersAddresses( + AF_UNSPEC, GAA_FLAG_SKIP_ANYCAST | GAA_FLAG_SKIP_MULTICAST | GAA_FLAG_SKIP_DNS_SERVER | GAA_FLAG_SKIP_FRIENDLY_NAME, + nullptr, addrs, &buf_size + ); + if (err == NO_ERROR) { + break; + } + free(addrs); + if (err == ERROR_BUFFER_OVERFLOW) { + continue; // Will go back and alloc the right size. + } + + OV_ERR_FAIL_V_MSG( + string_map_t {}, fmt::format("Call to GetAdaptersAddresses failed with error {}.", err) + ); + } + + string_map_t result; + IP_ADAPTER_ADDRESSES* adapter = addrs; + + while (adapter != nullptr) { + InterfaceInfo info; + info.name = adapter->AdapterName; + + size_t name_wlength = wcslen(adapter->FriendlyName); + int name_length = WideCharToMultiByte(CP_UTF8, 0, adapter->FriendlyName, name_wlength, 0, 0, nullptr, nullptr); + info.name_friendly.reserve(name_length * sizeof(char)); + WideCharToMultiByte( + CP_UTF8, 0, adapter->FriendlyName, name_wlength, info.name_friendly.data(), name_length, nullptr, nullptr + ); + + struct stack_string : StackString::max())> { + using StackString::_array; + using StackString::_string_size; + using StackString::StackString; + } str {}; + std::to_chars_result to_chars = + StringUtils::to_chars(str._array.data(), str._array.data() + str.array_length, adapter->IfIndex); + str._string_size = to_chars.ptr - str.data(); + + info.index = str; + + IP_ADAPTER_UNICAST_ADDRESS* address = adapter->FirstUnicastAddress; + while (address != nullptr) { + int family = address->Address.lpSockaddr->sa_family; + if (family != AF_INET && family != AF_INET6) { + continue; + } + info.ip_addresses.push_back(_sockaddr2ip(address->Address.lpSockaddr)); + address = address->Next; + } + adapter = adapter->Next; + // Only add interface if it has at least one IP. + if (info.ip_addresses.size() > 0) { + auto pair = result.insert_or_assign(info.name, info); + OV_ERR_CONTINUE(!pair.second); + } + } + + free(addrs); + + return result; +} + +memory::vector WindowsNetworkResolver::_resolve_hostname(std::string_view p_hostname, Type p_type) const { + struct addrinfo hints; // NOLINT + struct addrinfo* result = nullptr; + + std::memset(&hints, 0, sizeof(struct addrinfo)); + if (p_type == Type::IPV4) { + hints.ai_family = AF_INET; + } else if (p_type == Type::IPV6) { + hints.ai_family = AF_INET6; + hints.ai_flags = 0; + } else { + hints.ai_family = AF_UNSPEC; + hints.ai_flags = AI_ADDRCONFIG; + } + hints.ai_flags &= ~AI_NUMERICHOST; + + int s = getaddrinfo(p_hostname.data(), nullptr, &hints, &result); + if (s != 0) { + Logger::info("getaddrinfo failed! Cannot resolve hostname."); + return {}; + } + + if (result == nullptr || result->ai_addr == nullptr) { + Logger::info("Invalid response from getaddrinfo."); + if (result) { + freeaddrinfo(result); + } + return {}; + } + + struct addrinfo* next = result; + + memory::vector result_addrs; + do { + if (next->ai_addr == nullptr) { + next = next->ai_next; + continue; + } + IpAddress ip = _sockaddr2ip(next->ai_addr); + if (ip.is_valid() && ranges::find(result_addrs, ip) == result_addrs.end()) { + result_addrs.push_back(ip); + } + next = next->ai_next; + } while (next); + + freeaddrinfo(result); + + return result_addrs; +} + +#endif diff --git a/src/openvic-simulation/multiplayer/lowlevel/windows/WindowsNetworkResolver.hpp b/src/openvic-simulation/multiplayer/lowlevel/windows/WindowsNetworkResolver.hpp new file mode 100644 index 000000000..bf043bf2f --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/windows/WindowsNetworkResolver.hpp @@ -0,0 +1,35 @@ +#pragma once + +#if !defined(_WIN32) +#error "WindowsNetworkResolver.hpp should only be included on windows systems" +#endif + +#include "openvic-simulation/multiplayer/lowlevel/NetworkResolverBase.hpp" +#include "openvic-simulation/types/OrderedContainers.hpp" +#include "openvic-simulation/utility/Containers.hpp" + +namespace OpenVic { + struct WindowsNetworkResolver final : NetworkResolverBase { + static constexpr Provider provider_value = Provider::WINDOWS; + + WindowsNetworkResolver(); + ~WindowsNetworkResolver(); + + static WindowsNetworkResolver& singleton() { + return _singleton; + } + + OpenVic::string_map_t get_local_interfaces() const override; + + Provider provider() const override { + return provider_value; + } + + private: + friend NetworkResolverBase::ResolveHandler; + + memory::vector _resolve_hostname(std::string_view p_hostname, Type p_type = Type::ANY) const override; + + static WindowsNetworkResolver _singleton; + }; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/windows/WindowsSocket.cpp b/src/openvic-simulation/multiplayer/lowlevel/windows/WindowsSocket.cpp new file mode 100644 index 000000000..766ff1ea2 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/windows/WindowsSocket.cpp @@ -0,0 +1,583 @@ +#ifdef _WIN32 +#include "WindowsSocket.hpp" + +#include +#include +#include +#include +#include +#include +#include + +#pragma comment(lib, "Ws2_32.lib") +#pragma comment(lib, "Mswsock.lib") + +#include +// +#include +#include +#include +#include +#include + +// Thank You Microsoft +#define WIN_SOCKET_ERROR (SOCKET)(~0) +#undef SOCKET_ERROR + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkResolver.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkSocketBase.hpp" +#include "openvic-simulation/types/OrderedContainers.hpp" +#include "openvic-simulation/utility/Containers.hpp" +#include "openvic-simulation/utility/ErrorMacros.hpp" +#include "openvic-simulation/utility/StringUtils.hpp" + +// Workaround missing flag in MinGW +#if defined(__MINGW32__) && !defined(SIO_UDP_NETRESET) +#define SIO_UDP_NETRESET _WSAIOW(IOC_VENDOR, 15) +#endif + +using namespace OpenVic; + +WindowsSocket::WindowsSocket() {} + +WindowsSocket::~WindowsSocket() { + close(); +} + +NetworkError WindowsSocket::_get_socket_error() const { + int err = WSAGetLastError(); + switch (err) { + case WSAEISCONN: return NetworkError::IS_CONNECTED; + case WSAEINPROGRESS: + case WSAEALREADY: return NetworkError::IN_PROGRESS; + case WSAEWOULDBLOCK: return NetworkError::WOULD_BLOCK; + case WSAEADDRINUSE: + case WSAEADDRNOTAVAIL: return NetworkError::ADDRESS_INVALID_OR_UNAVAILABLE; + case WSAEACCES: return NetworkError::UNAUTHORIZED; + case WSAEMSGSIZE: + case WSAENOBUFS: return NetworkError::BUFFER_TOO_SMALL; + default: // + Logger::info("Socket error: ", strerror(errno)); + return NetworkError::OTHER; + } +} + +void WindowsSocket::_set_ip_port(sockaddr_storage* addr, IpAddress* r_ip, port_type* r_port) { + if (addr->ss_family == AF_INET) { + struct sockaddr_in* addr4 = (struct sockaddr_in*)addr; + if (r_ip) { + r_ip->set_ipv4(std::bit_cast>(&addr4->sin_addr.s_addr)); + } + if (r_port) { + *r_port = ntohs(addr4->sin_port); + } + } else if (addr->ss_family == AF_INET6) { + struct sockaddr_in6* addr6 = (struct sockaddr_in6*)addr; + if (r_ip) { + r_ip->set_ipv6(addr6->sin6_addr.s6_addr); + } + if (r_port) { + *r_port = ntohs(addr6->sin6_port); + } + } +} + +bool WindowsSocket::_can_use_ip(IpAddress const& p_ip, const bool p_for_bind) const { + if (p_for_bind && !(p_ip.is_valid() || p_ip.is_wildcard())) { + return false; + } else if (!p_for_bind && !p_ip.is_valid()) { + return false; + } + // Check if socket support this IP type. + NetworkResolver::Type type = p_ip.is_ipv4() ? NetworkResolver::Type::IPV4 : NetworkResolver::Type::IPV6; + return !(_ip_type != NetworkResolver::Type::ANY && !p_ip.is_wildcard() && _ip_type != type); +} + +void WindowsSocket::setup() { + WSADATA data; + WSAStartup(MAKEWORD(2, 2), &data); +} + +void WindowsSocket::cleanup() { + WSACleanup(); +} + +NetworkError WindowsSocket::open(Type sock_type, NetworkResolver::Type& ip_type) { + OV_ERR_FAIL_COND_V(is_open(), NetworkError::ALREADY_OPEN); + OV_ERR_FAIL_COND_V( + ip_type > NetworkResolver::Type::ANY || ip_type < NetworkResolver::Type::NONE, NetworkError::INVALID_PARAMETER + ); + + int family = ip_type == NetworkResolver::Type::IPV4 ? AF_INET : AF_INET6; + int protocol = sock_type == Type::TCP ? IPPROTO_TCP : IPPROTO_UDP; + int type = sock_type == Type::TCP ? SOCK_STREAM : SOCK_DGRAM; + _sock = socket(family, type, protocol); + + if (_sock == INVALID_SOCKET && ip_type == NetworkResolver::Type::ANY) { + // Careful here, changing the referenced parameter so the caller knows that we are using an IPv4 socket + // in place of a dual stack one, and further calls to _set_sock_addr will work as expected. + ip_type = NetworkResolver::Type::IPV4; + family = AF_INET; + _sock = socket(family, type, protocol); + } + + OV_ERR_FAIL_COND_V(_sock == INVALID_SOCKET, NetworkError::SOCKET_ERROR); + _ip_type = ip_type; + + if (family == AF_INET6) { + // Select IPv4 over IPv6 mapping. + set_ipv6_only_enabled(ip_type != NetworkResolver::Type::ANY); + } + + if (protocol == IPPROTO_UDP) { + // Make sure to disable broadcasting for UDP sockets. + // Depending on the OS, this option might or might not be enabled by default. Let's normalize it. + set_broadcasting_enabled(false); + } + + _is_stream = sock_type == Type::TCP; + + if (!_is_stream) { + // Disable windows feature/bug reporting WSAECONNRESET/WSAENETRESET when + // recv/recvfrom and an ICMP reply was received from a previous send/sendto. + unsigned long disable = 0; + if (ioctlsocket(_sock, SIO_UDP_CONNRESET, &disable) == -1) { + Logger::info("Unable to turn off UDP WSAECONNRESET behavior on Windows."); + } + if (ioctlsocket(_sock, SIO_UDP_NETRESET, &disable) == -1) { + // This feature seems not to be supported on wine. + Logger::info("Unable to turn off UDP WSAENETRESET behavior on Windows."); + } + } + return NetworkError::OK; +} + +void WindowsSocket::close() { + if (_sock != INVALID_SOCKET) { + closesocket(_sock); + } + + _sock = INVALID_SOCKET; + _ip_type = NetworkResolver::Type::NONE; + _is_stream = false; +} + +NetworkError WindowsSocket::bind(IpAddress const& p_addr, port_type p_port) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + OV_ERR_FAIL_COND_V(!_can_use_ip(p_addr, true), NetworkError::INVALID_PARAMETER); + + sockaddr_storage addr; // NOLINT + size_t addr_size = _set_addr_storage(&addr, p_addr, p_port, _ip_type); + + if (::bind(_sock, (struct sockaddr*)&addr, addr_size) != 0) { + NetworkError err = _get_socket_error(); + Logger::info("Failed to bind socket. Error: ", strerror(errno)); + close(); + return err; + } + + return NetworkError::OK; +} + +NetworkError WindowsSocket::listen(int p_max_pending) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + + if (::listen(_sock, p_max_pending) != 0) { + NetworkError err = _get_socket_error(); + Logger::info("Failed to listen from socket. Error: ", strerror(errno)); + close(); + return err; + } + + return NetworkError::OK; +} + +NetworkError WindowsSocket::connect_to_host(IpAddress const& p_addr, port_type p_port) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + OV_ERR_FAIL_COND_V(!_can_use_ip(p_addr, true), NetworkError::INVALID_PARAMETER); + + struct sockaddr_storage addr; // NOLINT + size_t addr_size = _set_addr_storage(&addr, p_addr, p_port, _ip_type); + + if (::WSAConnect(_sock, (struct sockaddr*)&addr, addr_size, nullptr, nullptr, nullptr, nullptr) != 0) { + NetworkError err = _get_socket_error(); + + switch (err) { + using enum NetworkError; + // We are already connected. + case IS_CONNECTED: return NetworkError::OK; + // Still waiting to connect, try again in a while. + case WOULD_BLOCK: + case IN_PROGRESS: return NetworkError::BUSY; + default: + Logger::info("Connection to remote host failed. Error: ", strerror(errno)); + close(); + return err; + } + } + + return NetworkError::OK; +} + +NetworkError WindowsSocket::poll(PollType p_type, int p_timeout) const { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + + bool ready = false; + fd_set rd, wr, ex; + fd_set* rdp = nullptr; + fd_set* wrp = nullptr; + FD_ZERO(&rd); + FD_ZERO(&wr); + FD_ZERO(&ex); + FD_SET(_sock, &ex); + struct timeval timeout = { p_timeout / 1000, (p_timeout % 1000) * 1000 }; + // For blocking operation, pass nullptr timeout pointer to select. + struct timeval* tp = nullptr; + if (p_timeout >= 0) { + // If timeout is non-negative, we want to specify the timeout instead. + tp = &timeout; + } + +// Windows loves its idiotic macros +#pragma push_macro("IN") +#undef IN +#pragma push_macro("OUT") +#undef OUT + switch (p_type) { + using enum PollType; + case IN: + FD_SET(_sock, &rd); + rdp = &rd; + break; + case OUT: + FD_SET(_sock, &wr); + wrp = ≀ + break; + case IN_OUT: + FD_SET(_sock, &rd); + FD_SET(_sock, &wr); + rdp = &rd; + wrp = ≀ + } +#pragma pop_macro("IN") +#pragma pop_macro("OUT") + // WSAPoll is broken: https://daniel.haxx.se/blog/2012/10/10/wsapoll-is-broken/. + int ret = select(1, rdp, wrp, &ex, tp); + + if (ret == WIN_SOCKET_ERROR) { + return NetworkError::SOCKET_ERROR; + } + + if (ret == 0) { + return NetworkError::BUSY; + } + + if (FD_ISSET(_sock, &ex)) { + NetworkError err = _get_socket_error(); + Logger::info("Exception when polling socket. Error: ", strerror(errno)); + return err; + } + + if (rdp && FD_ISSET(_sock, rdp)) { + ready = true; + } + if (wrp && FD_ISSET(_sock, wrp)) { + ready = true; + } + + return ready ? NetworkError::OK : NetworkError::BUSY; +} + +NetworkError WindowsSocket::receive(uint8_t* p_buffer, size_t p_len, int64_t& r_read) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + + r_read = ::recv(_sock, (char*)p_buffer, p_len, 0); + + if (r_read < 0) { + NetworkError err = _get_socket_error(); + if (err == NetworkError::WOULD_BLOCK) { + return NetworkError::BUSY; + } + return err; + } + + return NetworkError::OK; +} + +NetworkError WindowsSocket::receive_from( // + uint8_t* p_buffer, size_t p_len, int64_t& r_read, IpAddress& r_ip, port_type& r_port, bool p_peek +) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + + struct sockaddr_storage from; // NOLINT + socklen_t len = sizeof(from); + std::memset(&from, 0, len); + + r_read = ::recvfrom(_sock, (char*)p_buffer, p_len, p_peek ? MSG_PEEK : 0, (struct sockaddr*)&from, &len); + + if (r_read < 0) { + NetworkError err = _get_socket_error(); + if (err == NetworkError::WOULD_BLOCK) { + return NetworkError::BUSY; + } + return err; + } + + if (from.ss_family == AF_INET) { + struct sockaddr_in* sin_from = (struct sockaddr_in*)&from; + r_ip.set_ipv4(std::bit_cast>(&sin_from->sin_addr)); + r_port = ntohs(sin_from->sin_port); + } else if (from.ss_family == AF_INET6) { + struct sockaddr_in6* s6_from = (struct sockaddr_in6*)&from; + r_ip.set_ipv6(std::bit_cast>(&s6_from->sin6_addr)); + r_port = ntohs(s6_from->sin6_port); + } else { + // Unsupported socket family, should never happen. + OV_ERR_FAIL_V(NetworkError::UNSUPPORTED); + } + + return NetworkError::OK; +} + +NetworkError WindowsSocket::send(const uint8_t* p_buffer, size_t p_len, int64_t& r_sent) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + + int flags = 0; + r_sent = ::send(_sock, (const char*)p_buffer, p_len, flags); + + if (r_sent < 0) { + NetworkError err = _get_socket_error(); + if (err == NetworkError::WOULD_BLOCK) { + return NetworkError::BUSY; + } + return err; + } + + return NetworkError::OK; +} + +NetworkError +WindowsSocket::send_to(const uint8_t* p_buffer, size_t p_len, int64_t& r_sent, IpAddress const& p_ip, port_type p_port) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + + struct sockaddr_storage addr; // NOLINT + size_t addr_size = _set_addr_storage(&addr, p_ip, p_port, _ip_type); + r_sent = ::sendto(_sock, (const char*)p_buffer, p_len, 0, (struct sockaddr*)&addr, addr_size); + + if (r_sent < 0) { + NetworkError err = _get_socket_error(); + if (err == NetworkError::WOULD_BLOCK) { + return NetworkError::BUSY; + } + return err; + } + + return NetworkError::OK; +} + +WindowsSocket* WindowsSocket::_accept(IpAddress& r_ip, port_type& r_port) { + OV_ERR_FAIL_COND_V(!is_open(), nullptr); + + struct sockaddr_storage their_addr; // NOLINT + socklen_t size = sizeof(their_addr); + SOCKET fd = ::accept(_sock, (struct sockaddr*)&their_addr, &size); + if (fd == INVALID_SOCKET) { + NetworkError err = _get_socket_error(); + Logger::info("Error when accepting socket connection. Error: ", strerror(errno)); + return nullptr; + } + + _set_ip_port(&their_addr, &r_ip, &r_port); + + WindowsSocket* ns = new WindowsSocket(); + ns->_sock = fd; + ns->_ip_type = _ip_type; + ns->_is_stream = _is_stream; + ns->set_blocking_enabled(false); + return ns; +} + +bool WindowsSocket::is_open() const { + return _sock != INVALID_SOCKET; +} + +int WindowsSocket::available_bytes() const { + OV_ERR_FAIL_COND_V(!is_open(), -1); + + unsigned long len; + int ret = ioctlsocket(_sock, FIONREAD, &len); + if (ret == -1) { + _get_socket_error(); + Logger::info("Error when checking available bytes on socket. Error: ", strerror(errno)); + return -1; + } + return len; +} + +NetworkError WindowsSocket::get_socket_address(IpAddress* r_ip, port_type* r_port) const { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::NOT_OPEN); + + struct sockaddr_storage saddr; // NOLINT + socklen_t len = sizeof(saddr); + if (getsockname(_sock, (struct sockaddr*)&saddr, &len) != 0) { + NetworkError err = _get_socket_error(); + Logger::info("Error when reading local socket address. Error: ", strerror(errno)); + return err; + } + _set_ip_port(&saddr, r_ip, r_port); + return NetworkError::OK; +} + +NetworkError WindowsSocket::set_broadcasting_enabled(bool p_enabled) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + // IPv6 has no broadcast support. + if (_ip_type == NetworkResolver::Type::IPV6) { + return NetworkError::UNSUPPORTED; + } + + int par = p_enabled ? 1 : 0; + if (setsockopt(_sock, SOL_SOCKET, SO_BROADCAST, (const char*)&par, sizeof(int)) != 0) { + Logger::warning("Unable to change broadcast setting."); + return NetworkError::BROADCAST_CHANGE_FAILED; + } + return NetworkError::OK; +} + +void WindowsSocket::set_blocking_enabled(bool p_enabled) { + OV_ERR_FAIL_COND(!is_open()); + + int ret = 0; + unsigned long par = p_enabled ? 0 : 1; + ret = ioctlsocket(_sock, FIONBIO, &par); + if (ret != 0) { + Logger::warning("Unable to change non-block mode."); + } +} + +void WindowsSocket::set_ipv6_only_enabled(bool p_enabled) { + OV_ERR_FAIL_COND(!is_open()); + // This option is only available in IPv6 sockets. + OV_ERR_FAIL_COND(_ip_type == NetworkResolver::Type::IPV4); + + int par = p_enabled ? 1 : 0; + if (setsockopt(_sock, IPPROTO_IPV6, IPV6_V6ONLY, (const char*)&par, sizeof(int)) != 0) { + Logger::warning("Unable to change IPv4 address mapping over IPv6 option."); + } +} + +void WindowsSocket::set_tcp_no_delay_enabled(bool p_enabled) { + OV_ERR_FAIL_COND(!is_open()); + OV_ERR_FAIL_COND(!_is_stream); // Not TCP. + + int par = p_enabled ? 1 : 0; + if (setsockopt(_sock, IPPROTO_TCP, TCP_NODELAY, (const char*)&par, sizeof(int)) < 0) { + Logger::warning("Unable to set TCP no delay option."); + } +} + +void WindowsSocket::set_reuse_address_enabled(bool p_enabled) { + OV_ERR_FAIL_COND(!is_open()); + + // On Windows, enabling SO_REUSEADDR actually would also enable reuse port, very bad on TCP. Denying... + // Windows does not have this option, SO_REUSEADDR in this magical world means SO_REUSEPORT +} + +NetworkError WindowsSocket::change_multicast_group(IpAddress const& p_multi_address, std::string_view p_if_name, bool p_add) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + OV_ERR_FAIL_COND_V(!_can_use_ip(p_multi_address, false), NetworkError::INVALID_PARAMETER); + + // Need to force level and af_family to IP(v4) when using dual stacking and provided multicast group is IPv4. + NetworkResolver::Type type = + _ip_type == NetworkResolver::Type::ANY && p_multi_address.is_ipv4() ? NetworkResolver::Type::IPV4 : _ip_type; + // This needs to be the proper level for the multicast group, no matter if the socket is dual stacking. + int level = type == NetworkResolver::Type::IPV4 ? IPPROTO_IP : IPPROTO_IPV6; + int ret = -1; + + IpAddress if_ip; + uint32_t if_v6id = 0; + OpenVic::string_map_t if_info = NetworkResolver::singleton().get_local_interfaces(); + for (std::pair const& pair : if_info) { + NetworkResolver::InterfaceInfo const& c = pair.second; + if (c.name != p_if_name) { + continue; + } + + std::from_chars_result from_chars = + StringUtils::from_chars(c.index.data(), c.index.data() + c.index.size(), if_v6id); + OV_ERR_FAIL_COND_V(from_chars.ec != std::errc {}, NetworkError::SOCKET_ERROR); + if (type == NetworkResolver::Type::IPV6) { + break; // IPv6 uses index. + } + + for (IpAddress const& F : c.ip_addresses) { + if (!F.is_ipv4()) { + continue; // Wrong IP type. + } + if_ip = F; + break; + } + break; + } + + if (level == IPPROTO_IP) { + OV_ERR_FAIL_COND_V(!if_ip.is_valid(), NetworkError::INVALID_PARAMETER); + struct ip_mreq greq; // NOLINT + int sock_opt = p_add ? IP_ADD_MEMBERSHIP : IP_DROP_MEMBERSHIP; + std::memcpy(&greq.imr_multiaddr, p_multi_address.get_ipv4().data(), 4); + std::memcpy(&greq.imr_interface, if_ip.get_ipv4().data(), 4); + ret = setsockopt(_sock, level, sock_opt, (const char*)&greq, sizeof(greq)); + } else { + struct ipv6_mreq greq; // NOLINT + int sock_opt = p_add ? IPV6_ADD_MEMBERSHIP : IPV6_DROP_MEMBERSHIP; + std::memcpy(&greq.ipv6mr_multiaddr, p_multi_address.get_ipv6().data(), 16); + greq.ipv6mr_interface = if_v6id; + ret = setsockopt(_sock, level, sock_opt, (const char*)&greq, sizeof(greq)); + } + + if (ret == -1) { + return _get_socket_error(); + } + + return NetworkError::OK; +} + +size_t WindowsSocket::_set_addr_storage( + sockaddr_storage* p_addr, IpAddress const& p_ip, port_type p_port, NetworkResolver::Type p_ip_type +) { + std::memset(p_addr, 0, sizeof(struct sockaddr_storage)); + if (p_ip_type == NetworkResolver::Type::IPV6 || p_ip_type == NetworkResolver::Type::ANY) { // IPv6 socket. + + // IPv6 only socket with IPv4 address. + OV_ERR_FAIL_COND_V(!p_ip.is_wildcard() && p_ip_type == NetworkResolver::Type::IPV6 && p_ip.is_ipv4(), 0); + + struct sockaddr_in6* addr6 = (struct sockaddr_in6*)p_addr; + addr6->sin6_family = AF_INET6; + addr6->sin6_port = htons(p_port); + if (p_ip.is_valid()) { + std::memcpy(&addr6->sin6_addr.s6_addr, p_ip.get_ipv6().data(), 16); + } else { + addr6->sin6_addr = in6addr_any; + } + return sizeof(sockaddr_in6); + } else { // IPv4 socket. + + // IPv4 socket with IPv6 address. + OV_ERR_FAIL_COND_V(!p_ip.is_wildcard() && !p_ip.is_ipv4(), 0); + + struct sockaddr_in* addr4 = (struct sockaddr_in*)p_addr; + addr4->sin_family = AF_INET; + addr4->sin_port = htons(p_port); // Short, network byte order. + + if (p_ip.is_valid()) { + std::memcpy(&addr4->sin_addr.s_addr, p_ip.get_ipv4().data(), 4); + } else { + addr4->sin_addr.s_addr = INADDR_ANY; + } + + return sizeof(sockaddr_in); + } +} + +#endif diff --git a/src/openvic-simulation/multiplayer/lowlevel/windows/WindowsSocket.hpp b/src/openvic-simulation/multiplayer/lowlevel/windows/WindowsSocket.hpp new file mode 100644 index 000000000..83285caf3 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/windows/WindowsSocket.hpp @@ -0,0 +1,75 @@ +#pragma once + +#if !defined(_WIN32) +#error "WindowsSocket.hpp should only be included on windows systems" +#endif + +#include +#include + +#include + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkResolver.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkSocketBase.hpp" + +struct sockaddr_storage; + +namespace OpenVic { + struct WindowsSocket final : NetworkSocketBase { + static constexpr NetworkResolver::Provider provider_value = NetworkResolver::Provider::WINDOWS; + + WindowsSocket(); + ~WindowsSocket() override; + + NetworkError open(Type p_type, NetworkResolver::Type& ip_type) override; + void close() override; + NetworkError bind(IpAddress const& p_addr, port_type p_port) override; + NetworkError listen(int p_max_pending) override; + NetworkError connect_to_host(IpAddress const& p_addr, port_type p_port) override; + NetworkError poll(PollType p_type, int timeout) const override; + NetworkError receive(uint8_t* p_buffer, size_t p_len, int64_t& r_read) override; + NetworkError receive_from( // + uint8_t* p_buffer, size_t p_len, int64_t& r_read, IpAddress& r_ip, port_type& r_port, bool p_peek = false + ) override; + NetworkError send(const uint8_t* p_buffer, size_t p_len, int64_t& r_sent) override; + NetworkError + send_to(const uint8_t* p_buffer, size_t p_len, int64_t& r_sent, IpAddress const& p_ip, port_type p_port) override; + + bool is_open() const override; + int available_bytes() const override; + NetworkError get_socket_address(IpAddress* r_ip, port_type* r_port) const override; + + // Returns OK if the socket option has been set successfully + NetworkError set_broadcasting_enabled(bool p_enabled) override; + void set_blocking_enabled(bool p_enabled) override; + void set_ipv6_only_enabled(bool p_enabled) override; + void set_tcp_no_delay_enabled(bool p_enabled) override; + void set_reuse_address_enabled(bool p_enabled) override; + NetworkError change_multicast_group(IpAddress const& p_multi_address, std::string_view p_if_name, bool add) override; + + NetworkResolver::Provider provider() const override { + return provider_value; + } + + static void setup(); + static void cleanup(); + + protected: + WindowsSocket* _accept(IpAddress& r_ip, port_type& r_port) override; + + private: + static size_t _set_addr_storage( // + sockaddr_storage* p_addr, IpAddress const& p_ip, port_type p_port, NetworkResolver::Type p_ip_type + ); + static void _set_ip_port(sockaddr_storage* addr, IpAddress* r_ip, port_type* r_port); + + NetworkError _get_socket_error() const; + bool _can_use_ip(IpAddress const& p_ip, const bool p_for_bind) const; + + UINT_PTR _sock = (UINT_PTR)(~0); + NetworkResolver::Type _ip_type = NetworkResolver::Type::NONE; + bool _is_stream = false; + }; +} diff --git a/src/openvic-simulation/types/RingBuffer.hpp b/src/openvic-simulation/types/RingBuffer.hpp new file mode 100644 index 000000000..a85ff293e --- /dev/null +++ b/src/openvic-simulation/types/RingBuffer.hpp @@ -0,0 +1,729 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "openvic-simulation/utility/ErrorMacros.hpp" +#include "openvic-simulation/utility/Utility.hpp" + +namespace OpenVic { + template> + struct RingBuffer { + using allocator_type = Allocator; + using allocator_traits = std::allocator_traits; + using value_type = T; + using size_type = typename allocator_traits::size_type; + using difference_type = typename allocator_traits::difference_type; + using pointer = typename allocator_traits::pointer; + using const_pointer = typename allocator_traits::const_pointer; + using reference = decltype(*pointer {}); + using const_reference = decltype(*const_pointer {}); + + template + struct _iterator { + using difference_type = typename allocator_traits::difference_type; + using size_type = typename allocator_traits::difference_type; + using value_type = typename allocator_traits::value_type; + using pointer = PointerType; + using reference = decltype(*pointer {}); + using iterator_category = std::random_access_iterator_tag; + + constexpr _iterator() noexcept = default; + _iterator(pointer data, const size_type ring_offset, const size_type ring_index, const uint8_t ring_capacity) + : _data(data), _offset(ring_offset), _index(ring_index), _capacity_power(ring_capacity) {} + + constexpr operator _iterator() const { + return _iterator(_data, _offset, _index, _capacity_power); + } + + constexpr reference operator*() const { + return _data[_ring_wrap(_offset + _index, _get_capacity(_capacity_power))]; + } + + constexpr pointer operator->() const noexcept { + return &**this; + } + + _iterator operator++(int) noexcept { + _iterator copy(*this); + operator++(); + return copy; + } + + _iterator& operator++() noexcept { + ++_index; + return *this; + } + + _iterator operator--(int) noexcept { + _iterator copy(*this); + operator--(); + return copy; + } + + _iterator& operator--() noexcept { + --_index; + return *this; + } + + _iterator& operator+=(difference_type n) noexcept { + _index += n; + return *this; + } + + constexpr _iterator operator+(difference_type n) const noexcept { + return _iterator(_data, _offset, _index + n, _capacity_power); + } + + _iterator& operator-=(difference_type n) noexcept { + _index -= n; + return *this; + } + + constexpr _iterator operator-(difference_type n) const noexcept { + assert(n <= _index); + return _iterator(_data, _offset, _index - n, _capacity_power); + } + + constexpr reference operator[](difference_type n) const { + return *(*this + n); + } + + constexpr friend difference_type operator-(_iterator const& lhs, _iterator const& rhs) noexcept { + return lhs._index > rhs._index ? lhs._index - rhs._index : -(rhs._index - lhs._index); + } + + constexpr friend _iterator operator+(difference_type lhs, _iterator const& rhs) noexcept { + return rhs + lhs; + } + + constexpr friend auto operator<=>(_iterator const& lhs, _iterator const& rhs) noexcept { + return std::tie(lhs._data, lhs._offset, lhs._index) <=> std::tie(rhs._data, rhs._offset, rhs._index); + } + + constexpr friend bool operator==(_iterator const& lhs, _iterator const& rhs) noexcept { + return &*lhs == &*rhs; + } + + private: + pointer _data {}; + + // Keeping both _offset and _index around is a little redundant, + // algorithmically, but it makes it much easier to express iterator-mutating + // operations. + + // Physical index of begin(). + size_type _offset {}; + // Logical index of this iterator. + size_type _index {}; + + uint8_t _capacity_power {}; + }; + + using iterator = _iterator; + using const_iterator = _iterator; + using reverse_iterator = std::reverse_iterator; + using const_reverse_iterator = std::reverse_iterator; + + explicit RingBuffer(uint8_t power, allocator_type const& allocator) : _allocator(allocator) { + reserve_power(power); + } + explicit RingBuffer(size_type power = 0) : RingBuffer(power, allocator_type {}) {} + + ~RingBuffer() { + clear(); + _deallocate(); + } + + RingBuffer(RingBuffer const& other) + : RingBuffer(other, allocator_traits::select_on_container_copy_construction(other._allocator)) {} + + RingBuffer(RingBuffer const& other, allocator_type const& allocator) : RingBuffer(other._capacity_power, allocator) { + clear(); + + for (const_reference value : other) { + push_back(value); + } + } + + RingBuffer(RingBuffer&& other) noexcept : RingBuffer(0, std::move(other._allocator)) { + _no_alloc_swap(other); + } + + RingBuffer(RingBuffer&& other, allocator_type const& allocator) : RingBuffer(0, allocator) { + if (other._allocator == allocator) { + _no_alloc_swap(other); + } else { + for (auto& element : other) { + emplace_back(std::move(element)); + } + } + } + + RingBuffer& operator=(RingBuffer const& other) { + clear(); + + if constexpr (typename allocator_traits::propagate_on_container_copy_assignment()) { + _allocator = other._allocator; + } + + for (auto const& value : other) { + push_back(value); + } + return *this; + } + + RingBuffer& operator=(RingBuffer&& other) noexcept( + allocator_traits::propagate_on_container_move_assignment::value || + std::is_nothrow_move_constructible::value + ) { + if (allocator_traits::propagate_on_container_move_assignment::value || _allocator == other._allocator) { + // We're either getting the other's allocator or they're already the same, + // so swap data in one go. + if constexpr (typename allocator_traits::propagate_on_container_move_assignment()) { + std::swap(_allocator, other._allocator); + } + _no_alloc_swap(other); + } else { + // Different allocators and can't swap them, so move elementwise. + clear(); + for (auto& element : other) { + emplace_back(std::move(element)); + } + } + + return *this; + } + + allocator_type get_allocator() const { + return _allocator; + } + reference front() { + return at(0); + } + reference back() { + return at(size() - 1); + } + const_reference back() const { + return at(size() - 1); + } + + const_reference operator[](const size_type index) const { + return _data[_ring_wrap(_offset + index, capacity())]; + } + reference operator[](const size_type index) { + return _data[_ring_wrap(_offset + index, capacity())]; + } + + const_reference at(const size_type index) const { + if (OV_unlikely(index >= size())) { + std::abort(); + } + return (*this)[index]; + } + reference at(const size_type index) { + if (OV_unlikely(index >= size())) { + std::abort(); + } + return (*this)[index]; + } + + iterator begin() noexcept { + return iterator(&_data[0], _offset, 0, _capacity_power); + } + iterator end() noexcept { + return iterator(&_data[0], _offset, size(), _capacity_power); + } + const_iterator begin() const noexcept { + return const_iterator(&_data[0], _offset, 0, _capacity_power); + } + const_iterator end() const noexcept { + return const_iterator(&_data[0], _offset, size(), _capacity_power); + } + + const_iterator cbegin() const noexcept { + return const_cast(*this).begin(); + } + const_iterator cend() const noexcept { + return const_cast(*this).end(); + } + + reverse_iterator rbegin() noexcept { + return reverse_iterator(end()); + } + reverse_iterator rend() noexcept { + return reverse_iterator(begin()); + } + const_reverse_iterator rbegin() const noexcept { + return const_reverse_iterator(end()); + } + const_reverse_iterator rend() const noexcept { + return const_reverse_iterator(begin()); + } + + const_reverse_iterator crbegin() const noexcept { + return const_cast(*this).rbegin(); + } + const_reverse_iterator crend() const noexcept { + return const_cast(*this).rend(); + } + + bool empty() const noexcept { + return size() == 0; + } + + size_type size() const noexcept { + return _size; + } + + size_type capacity() const noexcept { + return _get_capacity(_capacity_power); + } + + size_type max_size() const noexcept { + return std::min(allocator_traits::max_size(_allocator), std::numeric_limits::max() / sizeof(value_type)); + } + + size_type space() const noexcept { + return capacity() - size(); + } + + void reserve_power(uint8_t power) { + // Will cause overflow + OV_ERR_FAIL_COND(power > 64); + + if (OV_unlikely(_capacity_power != 0) && power <= _capacity_power) { + return; + } + + size_type new_capacity = _get_capacity(power); + pointer result = _allocate(new_capacity); + if (!empty()) { + pointer last = std::uninitialized_copy_n(_data + _offset, capacity() + 1 - _offset, result); + if (_offset > size()) { + std::uninitialized_copy_n(_data, _offset - size(), last); + } + static_assert(std::is_destructible_v, "value type is destructible"); + if constexpr (!std::is_trivially_destructible_v) { + for (pointer first = _data; first != _data + _size; first++) { + allocator_traits::destroy(_allocator, first); + } + } + } + _deallocate(); + _offset = 0; + _next = size(); + _data = result; + _capacity_power = power; + } + + void shrink_to_fit() { + if (_capacity_power > 0 && empty()) { + _deallocate(); + _data = _allocate(0); + _size = 0; + _offset = 0; + _next = 0; + _capacity_power = 0; + return; + } + + size_type new_capacity = std::bit_ceil(size()); + pointer result = _allocate(new_capacity); + pointer last = std::uninitialized_copy_n(_data + _offset, capacity() + 1 - _offset, result); + if (_offset > size()) { + std::uninitialized_copy_n(_data, _offset - size(), last); + } + static_assert(std::is_destructible_v, "value type is destructible"); + if constexpr (!std::is_trivially_destructible_v) { + for (pointer first = _data; first != _data + _size; first++) { + allocator_traits::destroy(_allocator, first); + } + } + _deallocate(); + _offset = 0; + _next = size(); + _data = result; + _capacity_power = std::bit_width(new_capacity) - 1; + } + + void push_front(const_reference value) { + emplace_front(value); + } + void push_front(value_type&& value) { + emplace_front(std::move(value)); + } + + template + reference emplace_front(Args&&... args) { + if (capacity() == 0) { + // A buffer of size zero is conceptually sound, so let's support it. + return (*this)[0]; + } + + allocator_traits::construct(_allocator, &_data[_decrement(_offset)], std::forward(args)...); + + // If required, make room for next time. + if (size() == capacity()) { + pop_back(); + } + _grow_front(); + return (*this)[0]; + } + + void push_back(const_reference value) { + emplace_back(value); + } + void push_back(value_type&& value) { + emplace_back(std::move(value)); + } + + template + reference emplace_back(Args&&... args) { + if (capacity() == 0) { + // A buffer of size zero is conceptually sound, so let's support it. + return (*this)[0]; + } + + allocator_traits::construct(_allocator, &_data[_next], std::forward(args)...); + + // If required, make room for next time. + if (size() == capacity()) { + pop_front(); + } + _grow_back(); + return (*this)[size() - 1]; + } + + template + iterator append(InputIt first, InputIt last) { + using distance_type = typename std::iterator_traits::difference_type; + + const size_type _capacity = capacity(); + const distance_type distance = std::distance(first, last); + if (OV_unlikely(distance <= 0)) { + return end(); + } + + // Limit the number of elements to append at _capacity + const size_type append_count = std::min(distance, _capacity); + + // If appending would exceed capacity, remove elements from front + const size_type excess = size() + append_count > _capacity ? size() + append_count - _capacity : 0; + if (excess > 0) { + // Destroy elements that will be overwritten + if constexpr (!std::is_trivially_destructible_v) { + size_type pos = _offset; + for (size_type i = 0; i < excess; ++i) { + allocator_traits::destroy(_allocator, _data + pos); + pos = _ring_wrap(pos + 1, _capacity); + } + } + _offset = _ring_wrap(_offset + excess, _capacity); + const size_type overflow_check = _size - excess; + if (OV_unlikely(overflow_check > _size)) { + _size = 0; + } else { + _size = overflow_check; + } + } + + // Determine physical position for appending + const size_type write_pos = _ring_wrap(_offset + _size, _capacity); + size_type space_to_end = _capacity - write_pos; + + iterator result = iterator(_data, _offset, _size, _capacity_power); + if (OV_likely(append_count <= space_to_end)) { + // Single copy to the back + std::uninitialized_copy_n(first, append_count, _data + write_pos); + } else { + // Split copy: part to the end, part to the beginning + ++space_to_end; + InputIt mid = first; + std::advance(mid, space_to_end); + std::uninitialized_copy(first, mid, _data + write_pos); + std::uninitialized_copy_n(mid, append_count - space_to_end, _data); + } + + _size += append_count; + _next = _ring_wrap(_offset + _size, _capacity); + return result; + } + template + iterator append(InputIt first, size_type count) { + return append(first, first + count); + } + + template + iterator append_range(Range&& range, size_type write_size = std::numeric_limits::max()) { + if (write_size < capacity()) { + auto end = ranges::begin(range); + ranges::advance(end, std::min(ranges::distance(range), write_size)); + return append(ranges::begin(range), end); + } + return append(ranges::begin(range), ranges::end(range)); + } + + value_type read() { + if (empty()) { + if constexpr (std::is_default_constructible_v) { + return {}; + } + std::abort(); + } + value_type result = (*this)[0]; + pop_front(); + return result; + } + + std::span read_buffer_to( // + pointer r_buffer, size_type read_size = std::numeric_limits::max() + ) { + read_size = std::min(read_size, capacity()); + iterator last = begin(); + ranges::advance(last, read_size); + ranges::move(std::make_move_iterator(begin()), std::make_move_iterator(last), r_buffer); + erase(begin(), read_size); + return std::span { r_buffer, read_size }; + } + + memory::vector read_buffer(size_type read_size) { + read_size = std::min(read_size, capacity()); + iterator last = begin(); + ranges::advance(last, read_size); + memory::vector result { std::make_move_iterator(begin()), std::make_move_iterator(last) }; + erase(begin(), read_size); + return result; + } + + + ranges::subrange read_range(size_type read_size) { + read_size = std::min(read_size, capacity()); + iterator last = begin(); + ranges::advance(last, read_size); + return { begin(), last }; + } + + void pop_front() noexcept { + if (empty()) { + return; + } + + if constexpr (!std::is_trivially_destructible_v) { + allocator_traits::destroy(_allocator, &_data[_offset]); + } + _shrink_front(); + } + void pop_back() noexcept { + if (empty()) { + return; + } + + _shrink_back(); + if constexpr (!std::is_trivially_destructible_v) { + allocator_traits::destroy(_allocator, &_data[_next]); + } + } + + void clear() noexcept { + if constexpr (!std::is_trivially_destructible_v) { + if (empty()) { + return; + } + for (pointer first = _data; first != _data + _size; first++) { + allocator_traits::destroy(_allocator, first); + } + } + _size = 0; + _offset = 0; + _next = 0; + } + + iterator erase(const_iterator from, const_iterator to) noexcept( + noexcept(pop_front()) && std::is_nothrow_move_assignable::value + ) { + if (OV_unlikely(from > end() || to > end())) { + return std::bit_cast(from); + } + + if (from == to) { + return iterator(_data, _offset, from - begin(), _capacity_power); + } + + const difference_type erase_count = to - from; + if (erase_count == 0) { + return iterator(_data, _offset, from - begin(), _capacity_power); + } + + const difference_type leading = from - begin(); + const difference_type trailing = end() - to; + const size_type _capacity = capacity(); + + iterator result = iterator(_data, _offset, from - begin(), _capacity_power); + + if (leading <= trailing) { + // Shift elements from the front towards the erasure point + const size_type dest_pos = _ring_wrap(_offset, _capacity); + const size_type src_pos = _ring_wrap(_offset + erase_count, _capacity); + size_type count = leading; + + // Move elements in one or two segments depending on wrap-around + if (dest_pos <= src_pos || src_pos + leading <= _capacity + 1) { + std::move(_data + src_pos, _data + src_pos + leading, _data + dest_pos); + } else { + size_type first_segment = _capacity + 1 - src_pos; + std::move(_data + src_pos, _data + src_pos + first_segment, _data + dest_pos); + std::move(_data, _data + (leading - first_segment), _data + dest_pos + first_segment); + } + + // Destroy elements at the front + if constexpr (!std::is_trivially_destructible_v) { + for (size_type i = 0; i < erase_count; ++i) { + allocator_traits::destroy(_allocator, _data + _offset); + _offset = _ring_wrap(_offset + 1, _capacity); + } + } else { + _offset = _ring_wrap(_offset + erase_count, _capacity); + } + _size -= erase_count; + } else if (trailing >= 0) { + // Shift elements from the back towards the erasure point + const size_type dest_pos = _ring_wrap(_offset + leading, _capacity); + const size_type src_pos = _ring_wrap(_offset + leading + erase_count, _capacity); + + // Move elements in one or two segments depending on wrap-around + if (dest_pos <= src_pos || src_pos + trailing <= _capacity + 1) { + std::move(_data + src_pos, _data + src_pos + trailing, _data + dest_pos); + } else { + const size_type first_segment = _capacity + 1 - src_pos; + std::move(_data + src_pos, _data + src_pos + first_segment, _data + dest_pos); + std::move(_data, _data + (trailing - first_segment), _data + dest_pos + first_segment); + } + + // Destroy elements at the back + if constexpr (!std::is_trivially_destructible_v) { + for (size_type i = 0; i < erase_count; ++i) { + _next = _ring_wrap(_next - 1, _capacity); + allocator_traits::destroy(_allocator, _data + _next); + } + } else { + _next = _next - erase_count; + } + _size += erase_count; + _offset -= erase_count + 1; + } + + return result; + } + iterator erase(const_iterator pos, size_type count) noexcept(noexcept(erase(pos, pos + count))) { + const_iterator last = pos; + std::advance(last, count); + return erase(pos, last); + } + iterator erase(const_iterator pos) noexcept(noexcept(erase(pos, 1))) { + return erase(pos, 1); + } + + void swap(RingBuffer& other) noexcept { + if constexpr (typename allocator_traits::propagate_on_container_swap()) { + std::swap(_allocator, other._allocator); + } + _no_alloc_swap(other); + } + + friend auto operator<=>(RingBuffer const& lhs, RingBuffer const& rhs) { + return std::lexicographical_compare_three_way(lhs.begin(), lhs.end(), rhs.begin(), rhs.end()); + } + + friend bool operator==(RingBuffer const& lhs, RingBuffer const& rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + + return std::equal(lhs.begin(), lhs.end(), rhs.begin()); + } + + private: + void _no_alloc_swap(RingBuffer& other) { + std::swap(_data, other._data); + std::swap(_next, other._next); + std::swap(_offset, other._offset); + std::swap(_size, other._size); + std::swap(_capacity_power, other._capacity_power); + } + + size_type _increment(size_type& index, size_type amount = 1) { + return index = (index + amount - 1U < capacity() ? index + amount : 0); + } + + size_type _decrement(size_type& index, size_type amount = 1) { + return index = ((index - amount) + 1U > 0 ? index - amount : capacity()); + } + + void _grow_front() { + _decrement(_offset); + ++_size; + } + + void _grow_back() { + _increment(_next); + ++_size; + } + + void _shrink_front() { + _increment(_offset); + --_size; + } + + void _shrink_back() { + _decrement(_next); + --_size; + } + + pointer _allocate(const size_type new_capacity) { + return allocator_traits::allocate(_allocator, new_capacity + 1); + } + + void _deallocate() { + allocator_traits::deallocate(_allocator, _data, capacity() + 1); + } + + static constexpr size_type _ring_wrap(const size_type ring_index, const size_type ring_capacity) { + return (ring_index <= ring_capacity) ? ring_index : ring_index - ring_capacity - 1; + } + + static constexpr size_type _get_capacity(const uint8_t capacity_power) { + return (1U << static_cast(capacity_power)); + } + + // The start of the dynamically allocated backing array. + pointer _data = nullptr; + // The next position to write to for push_back(). + size_type _next = 0U; + + // Start of the ring buffer in data_. + size_type _offset = 0U; + // The number of elements in the ring buffer (distance between begin() and + // end()). + size_type _size = 0U; + // The power used to calculate the capacity of the ring buffer + uint8_t _capacity_power = 0U; + + // The allocator is used to allocate memory, and to construct and destroy + // elements. + [[no_unique_address]] allocator_type _allocator {}; + }; +} diff --git a/src/openvic-simulation/utility/ErrorMacros.hpp b/src/openvic-simulation/utility/ErrorMacros.hpp index c8ffb7a31..2cd2d7b19 100644 --- a/src/openvic-simulation/utility/ErrorMacros.hpp +++ b/src/openvic-simulation/utility/ErrorMacros.hpp @@ -3,10 +3,50 @@ #include "openvic-simulation/utility/Logger.hpp" // IWYU pragma: keep for macros #include "openvic-simulation/utility/Utility.hpp" +/** + * Try using `ERR_FAIL_COND_MSG`. + * Only use this macro if more complex error detection or recovery is required. + * + * Prints `m_msg`, and the current function returns. + */ +#define OV_ERR_FAIL_MSG(m_msg) \ + if (true) { \ + ::OpenVic::Logger::error("Method/function failed. ", m_msg); \ + return; \ + } else \ + ((void)0) + +/** + * Try using `ERR_FAIL_COND_V_MSG` or `ERR_FAIL_V_MSG`. + * Only use this macro if more complex error detection or recovery is required, and + * there is no sensible error message. + * + * The current function returns `m_retval`. + */ +#define OV_ERR_FAIL_V(m_retval) \ + if (true) { \ + ::OpenVic::Logger::error("Method/function failed. Returning: " _OV_STR(m_retval)); \ + return m_retval; \ + } else \ + ((void)0) + +/** + * Try using `ERR_FAIL_COND_V_MSG`. + * Only use this macro if more complex error detection or recovery is required. + * + * Prints `m_msg`, and the current function returns `m_retval`. + */ +#define OV_ERR_FAIL_V_MSG(m_retval, m_msg) \ + if (true) { \ + ::OpenVic::Logger::error("Method/function failed. Returning: " _OV_STR(m_retval), " ", m_msg); \ + return m_retval; \ + } else \ + ((void)0) + /** * Try using `ERR_FAIL_COND_MSG`. * Only use this macro if there is no sensible error message. - * If checking for null use ERR_FAIL_NULL_MSG instead. + * If checking for null use OV_ERR_FAIL_NULL_MSG instead. * If checking index bounds use ERR_FAIL_INDEX_MSG instead. * * Ensures `m_cond` is false. @@ -23,7 +63,7 @@ * Ensures `m_cond` is false. * If `m_cond` is true, prints `m_msg` and the current function returns. * - * If checking for null use ERR_FAIL_NULL_MSG instead. + * If checking for null use OV_ERR_FAIL_NULL_MSG instead. * If checking index bounds use ERR_FAIL_INDEX_MSG instead. */ #define OV_ERR_FAIL_COND_MSG(m_cond, m_msg) \ @@ -36,7 +76,7 @@ /** * Try using `ERR_FAIL_COND_V_MSG`. * Only use this macro if there is no sensible error message. - * If checking for null use ERR_FAIL_NULL_V_MSG instead. + * If checking for null use OV_ERR_FAIL_NULL_V_MSG instead. * If checking index bounds use ERR_FAIL_INDEX_V_MSG instead. * * Ensures `m_cond` is false. @@ -44,9 +84,7 @@ */ #define OV_ERR_FAIL_COND_V(m_cond, m_retval) \ if (OV_unlikely(m_cond)) { \ - ::OpenVic::Logger::error( \ - "Condition \"" _OV_STR(m_cond) "\" is true. Returning: " _OV_STR(m_retval) \ - ); \ + ::OpenVic::Logger::error("Condition \"" _OV_STR(m_cond) "\" is true. Returning: " _OV_STR(m_retval)); \ return m_retval; \ } else \ ((void)0) @@ -55,14 +93,136 @@ * Ensures `m_cond` is false. * If `m_cond` is true, prints `m_msg` and the current function returns `m_retval`. * - * If checking for null use ERR_FAIL_NULL_V_MSG instead. + * If checking for null use OV_ERR_FAIL_NULL_V_MSG instead. * If checking index bounds use ERR_FAIL_INDEX_V_MSG instead. */ #define OV_ERR_FAIL_COND_V_MSG(m_cond, m_retval, m_msg) \ if (OV_unlikely(m_cond)) { \ + ::OpenVic::Logger::error("Condition \"" _OV_STR(m_cond) "\" is true. Returning: " _OV_STR(m_retval) " ", m_msg); \ + return m_retval; \ + } else \ + ((void)0) + +/** + * Try using `ERR_FAIL_INDEX_MSG`. + * Only use this macro if there is no sensible error message. + * + * Ensures an integer index `m_index` is less than `m_size` and greater than or equal to 0. + * If not, the current function returns. + */ +#define OV_ERR_FAIL_INDEX(m_index, m_size) \ + if (OV_unlikely((m_index) < 0 || (m_index) >= (m_size))) { \ + ::OpenVic::Logger::error( \ + "Index " _OV_STR(m_index) " = ", m_size, " is out of bounds (" _OV_STR(m_size) " = ", m_size, ")." \ + ); \ + return; \ + } else \ + ((void)0) + +/** + * Ensures an integer index `m_index` is less than `m_size` and greater than or equal to 0. + * If not, prints `m_msg` and the current function returns. + */ +#define OV_ERR_FAIL_INDEX_MSG(m_index, m_size, m_msg) \ + if (OV_unlikely((m_index) < 0 || (m_index) >= (m_size))) { \ ::OpenVic::Logger::error( \ - "Condition \"" _OV_STR(m_cond) "\" is true. Returning: " _OV_STR(m_retval) " ", m_msg \ + "Index " _OV_STR(m_index) " = ", m_size, " is out of bounds (" _OV_STR(m_size) " = ", m_size, "). ", m_msg \ ); \ + return; \ + } else \ + ((void)0) + +/** + * Try using `ERR_FAIL_INDEX_V_MSG`. + * Only use this macro if there is no sensible error message. + * + * Ensures an integer index `m_index` is less than `m_size` and greater than or equal to 0. + * If not, the current function returns `m_retval`. + */ +#define OV_ERR_FAIL_INDEX_V(m_index, m_size, m_retval) \ + if (OV_unlikely((m_index) < 0 || (m_index) >= (m_size))) { \ + ::OpenVic::Logger::error( \ + "Index " _OV_STR(m_index) " = ", m_size, " is out of bounds (" _OV_STR(m_size) " = ", m_size, \ + "). Returning: " _OV_STR(m_retval) \ + ); \ + return m_retval; \ + } else \ + ((void)0) + +/** + * Ensures an integer index `m_index` is less than `m_size` and greater than or equal to 0. + * If not, prints `m_msg` and the current function returns `m_retval`. + */ +#define OV_ERR_FAIL_INDEX_V_MSG(m_index, m_size, m_retval, m_msg) \ + if (OV_unlikely((m_index) < 0 || (m_index) >= (m_size))) { \ + ::OpenVic::Logger::error( \ + "Index " _OV_STR(m_index) " = ", m_size, " is out of bounds (" _OV_STR(m_size) " = ", m_size, \ + "). Returning: " _OV_STR(m_retval) " ", m_msg \ + ); \ + return m_retval; \ + } else \ + ((void)0) + +#define OV_ERR_CONTINUE(m_cond) \ + if (OV_unlikely(m_cond)) { \ + ::OpenVic::Logger::error("Condition \"" _OV_STR(m_cond) "\" is true. Continuing."); \ + continue; \ + } else \ + ((void)0) + +#define OV_ERR_BREAK(m_cond) \ + if (OV_unlikely(m_cond)) { \ + ::OpenVic::Logger::error("Condition \"" _OV_STR(m_cond) "\" is true. Breaking."); \ + break; \ + } else \ + ((void)0) + +/** + * Try using `ERR_FAIL_NULL_MSG`. + * Only use this macro if there is no sensible error message. + * + * Ensures a pointer `m_param` is not null. + * If it is null, the current function returns. + */ +#define OV_ERR_FAIL_NULL(m_param) \ + if (OV_unlikely(m_param == nullptr)) { \ + ::OpenVic::Logger::error("Parameter \"" _OV_STR(m_param) "\" is null."); \ + return; \ + } else \ + ((void)0) + +/** + * Ensures a pointer `m_param` is not null. + * If it is null, prints `m_msg` and the current function returns. + */ +#define OV_ERR_FAIL_NULL_MSG(m_param, m_msg) \ + if (OV_unlikely(m_param == nullptr)) { \ + ::OpenVic::Logger::error("Parameter \"" _OV_STR(m_param) "\" is null.", m_msg); \ + return; \ + } else \ + ((void)0) + +/** + * Try using `ERR_FAIL_NULL_V_MSG`. + * Only use this macro if there is no sensible error message. + * + * Ensures a pointer `m_param` is not null. + * If it is null, the current function returns `m_retval`. + */ +#define OV_ERR_FAIL_NULL_V(m_param, m_retval) \ + if (OV_unlikely(m_param == nullptr)) { \ + ::OpenVic::Logger::error("Parameter \"" _OV_STR(m_param) "\" is null."); \ + return m_retval; \ + } else \ + ((void)0) + +/** + * Ensures a pointer `m_param` is not null. + * If it is null, prints `m_msg` and the current function returns `m_retval`. + */ +#define OV_ERR_FAIL_NULL_V_MSG(m_param, m_retval, m_msg) \ + if (OV_unlikely(m_param == nullptr)) { \ + ::OpenVic::Logger::error("Parameter \"" _OV_STR(m_param) "\" is null.", m_msg); \ return m_retval; \ } else \ ((void)0) diff --git a/src/openvic-simulation/utility/Marshal.hpp b/src/openvic-simulation/utility/Marshal.hpp new file mode 100644 index 000000000..e064da459 --- /dev/null +++ b/src/openvic-simulation/utility/Marshal.hpp @@ -0,0 +1,574 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "openvic-simulation/types/Colour.hpp" +#include "openvic-simulation/types/Date.hpp" +#include "openvic-simulation/types/Vector.hpp" +#include "openvic-simulation/types/fixed_point/FixedPoint.hpp" +#include "openvic-simulation/utility/Deque.hpp" +#include "openvic-simulation/utility/ErrorMacros.hpp" +#include "openvic-simulation/utility/Utility.hpp" + +namespace OpenVic::utility { + OV_ALWAYS_INLINE uint32_t halfbits_to_floatbits(uint16_t p_half) { + uint16_t h_exp, h_sig; + uint32_t f_sgn, f_exp, f_sig; + + h_exp = (p_half & 0x7c00u); + f_sgn = ((uint32_t)p_half & 0x8000u) << 16; + switch (h_exp) { + case 0x0000u: /* 0 or subnormal */ + h_sig = (p_half & 0x03ffu); + /* Signed zero */ + if (h_sig == 0) { + return f_sgn; + } + /* Subnormal */ + h_sig <<= 1; + while ((h_sig & 0x0400u) == 0) { + h_sig <<= 1; + h_exp++; + } + f_exp = ((uint32_t)(127 - 15 - h_exp)) << 23; + f_sig = ((uint32_t)(h_sig & 0x03ffu)) << 13; + return f_sgn + f_exp + f_sig; + case 0x7c00u: /* inf or NaN */ + /* All-ones exponent and a copy of the significand */ + return f_sgn + 0x7f800000u + (((uint32_t)(p_half & 0x03ffu)) << 13); + default: /* normalized */ + /* Just need to adjust the exponent and shift */ + return f_sgn + (((uint32_t)(p_half & 0x7fffu) + 0x1c000u) << 13); + } + } + + OV_ALWAYS_INLINE float halfptr_to_float(const uint16_t* p_half) { + return std::bit_cast(halfbits_to_floatbits(*p_half)); + } + + OV_ALWAYS_INLINE float half_to_float(const uint16_t p_half) { + return halfptr_to_float(&p_half); + } + + OV_ALWAYS_INLINE uint16_t make_half_float(float p_value) { + uint32_t x = std::bit_cast(p_value); + uint32_t sign = (unsigned short)(x >> 31); + uint32_t mantissa; + uint32_t exponent; + uint16_t hf; + + // get mantissa + mantissa = x & ((1 << 23) - 1); + // get exponent bits + exponent = x & (0xFF << 23); + if (exponent >= 0x47800000) { + // check if the original single precision float number is a NaN + if (mantissa && (exponent == (0xFF << 23))) { + // we have a single precision NaN + mantissa = (1 << 23) - 1; + } else { + // 16-bit half-float representation stores number as Inf + mantissa = 0; + } + hf = (((uint16_t)sign) << 15) | (uint16_t)((0x1F << 10)) | (uint16_t)(mantissa >> 13); + } + // check if exponent is <= -15 + else if (exponent <= 0x38000000) { + /* + // store a denorm half-float value or zero + exponent = (0x38000000 - exponent) >> 23; + mantissa >>= (14 + exponent); + + hf = (((uint16_t)sign) << 15) | (uint16_t)(mantissa); + */ + hf = 0; // denormals do not work for 3D, convert to zero + } else { + hf = (((uint16_t)sign) << 15) | (uint16_t)((exponent - 0x38000000) >> 13) | (uint16_t)(mantissa >> 13); + } + + return hf; + } + + struct half_float { + float value; + constexpr half_float() = default; + constexpr half_float(float value) : value(value) {} + + constexpr operator float&() { + return value; + } + + constexpr operator float const&() const { + return value; + } + + constexpr operator bool() const { + return value; + } + + constexpr half_float operator+() const { + return +value; + } + + constexpr half_float operator-() const { + return -value; + } + + constexpr half_float operator!() const { + return !value; + } + + constexpr half_float operator+(half_float const& rhs) const { + return value + rhs.value; + } + + constexpr half_float operator-(half_float const& rhs) const { + return value - rhs.value; + } + + constexpr half_float operator*(half_float const& rhs) const { + return value * rhs.value; + } + + constexpr half_float operator/(half_float const& rhs) const { + return value / rhs.value; + } + + constexpr half_float& operator+=(half_float const& rhs) { + value += rhs.value; + return *this; + } + + constexpr half_float& operator-=(half_float const& rhs) { + value -= rhs.value; + return *this; + } + + constexpr half_float& operator*=(half_float const& rhs) { + value *= rhs.value; + return *this; + } + + constexpr half_float& operator/=(half_float const& rhs) { + value /= rhs.value; + return *this; + } + + constexpr bool operator==(half_float const& rhs) const { + return value == rhs.value; + } + + constexpr auto operator<=>(half_float const& rhs) const { + return value <=> rhs.value; + } + }; + + static_assert(std::endian::native != std::endian::big || std::endian::native != std::endian::little, "HOW?!?!?!"); + + template + concept HasEncode = requires(T const& value, std::span span) { + { value.template encode(span) } -> std::same_as; + }; + + template + concept Encodable = + HasEncode || utility::specialization_of || std::integral || std::floating_point || + std::is_enum_v || std::convertible_to || std::same_as || + std::same_as || std::same_as || IsColour || std::same_as || + std::convertible_to> || utility::specialization_of || + utility::specialization_of || utility::specialization_of_span || + utility::specialization_of || utility::specialization_of || + utility::specialization_of || utility::specialization_of || + utility::specialization_of || (std::is_empty_v && std::is_default_constructible_v); + + template + static constexpr std::integral_constant endian_tag {}; + + template + requires Encodable || Encodable> + static inline size_t encode( // + T const& value, std::span span = {}, std::integral_constant endian = {} + ) { + if constexpr (std::is_const_v || std::is_volatile_v || std::is_reference_v) { + return encode, Endian>(value, span); + } else if constexpr (HasEncode) { + return value.template encode(span); + } else if constexpr (utility::specialization_of) { + const size_t index = value.index(); + const size_t index_size = encode(index); + const size_t size = // + std::visit( + [](auto& arg) -> size_t { + return encode(arg); + }, + value + ) + + index_size; + + if (span.empty()) { + return size; + } else if (span.size() < size) { + return 0; + } + + size_t offset = encode(index, span); + std::visit( + [&](auto& arg) { + encode(arg, span.subspan(offset)); + }, + value + ); + + return size; + } else if constexpr (std::same_as) { + return encode(static_cast(value), span); + } else if constexpr (std::integral) { + using unsigned_type = std::make_unsigned_t; + + if (!span.empty()) { + if constexpr (sizeof(T) == 1) { + *span.data() = *std::bit_cast(&value); + } else { + T copy; + if constexpr (std::endian::native != Endian) { + copy = std::bit_cast(utility::byteswap(std::bit_cast(value))); + } else { + copy = value; + } + decltype(span)::pointer ptr = span.data(); + for (size_t i = 0; i < sizeof(T); i++) { + *ptr = *std::bit_cast(©) & 0xFF; + ptr++; + *std::bit_cast(©) >>= 8; + } + } + return span.size() >= sizeof(T) ? sizeof(T) : 0; + } + + return sizeof(T); + } else if constexpr (std::floating_point) { + using int_type = std::conditional_t, uint32_t, uint64_t>; + + return encode(std::bit_cast(value), span); + } else if constexpr (std::is_enum_v) { + using underlying_type = std::underlying_type_t; + + return encode(std::bit_cast(value), span); + } else if constexpr (std::convertible_to) { + const std::string_view str = static_cast(value); + const size_t str_size = str.size(); + const size_t length_size = encode(str_size); + const size_t size = str_size + length_size; + + if (span.empty()) { + return size; + } else if (span.size() < size) { + return 0; + } + + uint8_t* ptr = span.data() + encode(str_size, span); + std::copy_n(reinterpret_cast(str.data()), str_size, ptr); + + return size; + } else if constexpr (std::same_as) { + return encode(make_half_float(value.value), span); + } else if constexpr (std::same_as) { + return encode(value.to_int(), span); + } else if constexpr (std::same_as) { + return encode(value.get_timespan(), span); + } else if constexpr (IsColour) { + return encode(static_cast(value), span); + } else if constexpr (std::same_as) { + return encode(value.get_raw_value(), span); + } else if constexpr (std::convertible_to>) { + const std::span inner_span = static_cast>(value); + const size_t span_size = inner_span.size(); + const size_t length_size = encode(span_size); + const size_t size = span_size + length_size; + + if (span.empty()) { + return size; + } else if (span.size() < size) { + return 0; + } + + uint8_t* ptr = span.data() + encode(span_size, span); + std::copy_n(inner_span.data(), span_size, ptr); + + return size; + } else if constexpr (utility::specialization_of || + utility::specialization_of || utility::specialization_of_span) { + const size_t collection_size = value.size(); + const size_t length_size = encode(collection_size); + const size_t size = + collection_size * encode(value.empty() ? typename T::value_type() : value[0]) + + length_size; + + if (span.empty()) { + return size; + } else if (span.size() < size) { + return 0; + } + + size_t offset = encode(collection_size, span); + for (auto const& element : value) { + offset += encode(element, span.subspan(offset), endian_tag); + } + + return size; + } else if constexpr (utility::specialization_of || utility::specialization_of) { + static size_t size = std::apply( + [](auto&&... args) { + return (encode(args) + ...); + }, + value + ); + + if (span.empty()) { + return size; + } + + size_t offset = 0; + std::apply( + [&](auto&&... args) { + (( // + offset += encode(args, span.subspan(offset)) + ), + ...); + }, + value + ); + + return size; + } else if constexpr (utility::specialization_of) { + if (span.empty()) { + return encode(value.x) + encode(value.y); + } + + size_t offset = encode(value.x, span); + return offset + encode(value.y, span.subspan(offset)); + } else if constexpr (utility::specialization_of) { + if (span.empty()) { + return encode(value.x) + encode(value.y) + + encode(value.z); + } + + size_t offset = encode(value.x, span); + offset += encode(value.y, span.subspan(offset)); + return offset + encode(value.z, span.subspan(offset)); + } else if constexpr (utility::specialization_of) { + if (span.empty()) { + return encode(value.x) + encode(value.y) + + encode(value.z) + encode(value.w); + } + + size_t offset = encode(value.x, span); + offset += encode(value.y, span.subspan(offset)); + offset += encode(value.z, span.subspan(offset)); + return offset + encode(value.w, span.subspan(offset)); + } else if constexpr (std::is_empty_v && std::is_default_constructible_v) { + return 0; + } + } + + template + using index_t = std::integral_constant; + + template + constexpr index_t index_v {}; + + template + using indexes_t = std::variant...>; + + template + constexpr indexes_t get_index(std::index_sequence, std::size_t I) { + constexpr indexes_t retvals[] = { index_v... }; + return retvals[I]; + } + + template + constexpr auto get_index(std::size_t I) { + return get_index(std::make_index_sequence {}, I); + } + + template + concept HasDecode = requires(std::span span, size_t& r_decode_count) { + { T::template decode(span, r_decode_count) } -> std::same_as; + }; + + template + concept Decodable = HasDecode || utility::specialization_of || std::integral || + std::floating_point || std::is_enum_v || std::convertible_to || + std::is_constructible_v || std::same_as || std::same_as || + std::same_as || IsColour || std::same_as || + std::is_constructible_v || utility::specialization_of || + utility::specialization_of || utility::specialization_of || + utility::specialization_of || utility::specialization_of_std_array_of || + utility::specialization_of || utility::specialization_of || + utility::specialization_of || (std::is_empty_v && std::is_default_constructible_v); + + template + requires Decodable || Decodable> + static inline T decode( // + std::span span, size_t& r_decode_count, std::integral_constant endian = {} + ) { + if constexpr (std::is_const_v || std::is_volatile_v) { + return decode, Endian>(span, r_decode_count); + } else if constexpr (HasDecode) { + return T::template decode(span, r_decode_count); + } else if constexpr (utility::specialization_of) { + auto index = get_index>(decode(span, r_decode_count)); + size_t offset = r_decode_count; + T value = std::visit( + [&](auto INDEX) -> T { + using assigned_type = std::variant_alternative_t; + return decode(span.subspan(offset), r_decode_count); + }, + index + ); + r_decode_count = offset + r_decode_count; + return value; + } else if constexpr (std::same_as) { + return static_cast(decode(span, r_decode_count)); + } else if constexpr (std::integral) { + using unsigned_type = std::make_unsigned_t; + + r_decode_count = sizeof(T); + if (OV_unlikely(span.size() < r_decode_count)) { + r_decode_count = 0; + OV_ERR_FAIL_V({}); + } + + if constexpr (sizeof(T) == 1) { + return std::bit_cast(span[0]); + } else { + decltype(span)::pointer ptr = span.data(); + T u = 0; + for (int i = 0; i < r_decode_count; i++) { + T b = *ptr; + b <<= (i * 8); + u |= b; + ptr++; + } + if constexpr (std::endian::native != Endian) { + u = utility::byteswap(u); + } + return u; + } + } else if constexpr (std::floating_point) { + using int_type = std::conditional_t, uint32_t, uint64_t>; + + return std::bit_cast(decode(span, r_decode_count)); + } else if constexpr (std::is_enum_v) { + using underlying_type = std::underlying_type_t; + + return std::bit_cast(decode(span, r_decode_count)); + } else if constexpr (std::convertible_to) { + uint32_t size = decode(span, r_decode_count); + OV_ERR_FAIL_INDEX_V(size, span.size(), {}); + std::string_view view { reinterpret_cast(span.data() + r_decode_count), size }; + r_decode_count += size; + return static_cast(view); + } else if constexpr (std::is_constructible_v) { + return T { decode(span, r_decode_count) }; + } else if constexpr (std::same_as) { + return { half_to_float(decode(span, r_decode_count)) }; + } else if constexpr (std::same_as) { + return T { decode(span, r_decode_count) }; + } else if constexpr (std::same_as) { + return T { decode(span, r_decode_count) }; + } else if constexpr (IsColour) { + return T::from_integer(decode(span, r_decode_count)); + } else if constexpr (std::same_as) { + return T::parse_raw(decode(span, r_decode_count)); + } else if constexpr (utility::specialization_of || + utility::specialization_of) { + uint32_t size = decode(span, r_decode_count); + OV_ERR_FAIL_INDEX_V(size, span.size(), T {}); + T value; + if constexpr (requires { value.reserve(size); }) { + value.reserve(size); + } + size_t offset = r_decode_count; + for (size_t index = 0; index < size; index++) { + value.emplace_back(decode(span.subspan(offset), r_decode_count)); + offset += r_decode_count; + } + r_decode_count = offset; + return value; + } else if constexpr (std::is_constructible_v) { + uint32_t size = decode(span, r_decode_count); + OV_ERR_FAIL_INDEX_V(size, span.size(), (T { (uint8_t*)nullptr, (uint8_t*)nullptr })); + T value // + { // + reinterpret_cast(span.data() + r_decode_count), // + reinterpret_cast(span.data() + r_decode_count + size) + }; + r_decode_count += size; + return value; + } else if constexpr (utility::specialization_of || utility::specialization_of) { + union { + T tuple; + bool dummy; + } value = { .dummy = false }; + size_t offset = 0; + std::apply( + [&](auto&... args) { + ((args = decode, Endian>(span.subspan(offset), r_decode_count), + offset += r_decode_count), + ...); + }, + value.tuple + ); + r_decode_count = offset; + + return std::move(value.tuple); + } else if constexpr (utility::specialization_of_std_array_of) { + T value; + OV_ERR_FAIL_INDEX_V(value.size(), span.size(), {}); + uint32_t size = decode(span, r_decode_count); + std::uninitialized_copy_n(span.subspan(r_decode_count).data(), size, value.data()); + r_decode_count += size; + return value; + } else if constexpr (utility::specialization_of) { + T value; + value.x = decode(span, r_decode_count); + size_t offset = r_decode_count; + value.y = decode(span.subspan(offset), r_decode_count); + r_decode_count += offset; + return value; + } else if constexpr (utility::specialization_of) { + T value; + value.x = decode(span, r_decode_count); + size_t offset = r_decode_count; + value.y = decode(span.subspan(offset), r_decode_count); + offset += r_decode_count; + value.z = decode(span.subspan(offset), r_decode_count); + r_decode_count += offset; + return value; + } else if constexpr (utility::specialization_of) { + T value; + value.x = decode(span, r_decode_count); + size_t offset = r_decode_count; + value.y = decode(span.subspan(offset), r_decode_count); + offset += r_decode_count; + value.z = decode(span.subspan(offset), r_decode_count); + offset += r_decode_count; + value.w = decode(span.subspan(offset), r_decode_count); + r_decode_count += offset; + return value; + } else if constexpr (std::is_empty_v && std::is_default_constructible_v) { + r_decode_count = 0; + return T(); + } + } +} diff --git a/src/openvic-simulation/utility/Utility.hpp b/src/openvic-simulation/utility/Utility.hpp index 29a6df79f..31a51cc42 100644 --- a/src/openvic-simulation/utility/Utility.hpp +++ b/src/openvic-simulation/utility/Utility.hpp @@ -1,10 +1,14 @@ #pragma once +#include +#include #include #include #include +#include #include #include +#include #include #include @@ -123,13 +127,47 @@ namespace OpenVic::utility { template class Template> concept not_specialization_of = !specialization_of; - template class Template, typename... Args> + template class Template, typename... Args> void _derived_from_specialization_impl(Template const&); - template class Template> - concept is_derived_from_specialization_of = requires(T const& t) { - _derived_from_specialization_impl