diff --git a/.gitmodules b/.gitmodules index 8ca8012ce..1a2ea72e6 100644 --- a/.gitmodules +++ b/.gitmodules @@ -30,3 +30,6 @@ path = deps/memory url = https://github.com/foonathan/memory ignore = dirty +[submodule "deps/reliable"] + path = deps/reliable + url = https://github.com/mas-bandwidth/reliable diff --git a/deps/SCsub b/deps/SCsub index 948ebe78d..6315e9647 100644 --- a/deps/SCsub +++ b/deps/SCsub @@ -111,6 +111,36 @@ def build_memory(env): env.exposed_includes += env.memory["INCPATH"] +def build_reliable(env): + import os + + if env["dev_build"]: + env.Append(CPPDEFINES=["RELIABLE_DEBUG"]) + else: + env.Append(CPPDEFINES=["RELIABLE_RELEASE"]) + reliable_env = env.Clone() + + include_path = "reliable" + include_dir = reliable_env.Dir(include_path) + sources = [os.path.join(include_path, "reliable.c")] + env.reliable_sources = sources + library_name = "libreliable" + env["LIBSUFFIX"] + library = reliable_env.StaticLibrary(target=os.path.join(include_path, library_name), source=sources) + Default(library) + + env.reliable = {} + env.reliable["INCPATH"] = [include_dir] + + env.Append(CPPPATH=env.reliable["INCPATH"]) + if env.get("is_msvc", False): + env.Append(CXXFLAGS=["/external:I", include_dir, "/external:W0"]) + else: + env.Append(CXXFLAGS=["-isystem", include_dir]) + env.Append(LIBPATH=include_dir) + env.Prepend(LIBS=[library_name]) + + env.exposed_includes += env.reliable["INCPATH"] + def link_tbb(env): import sys if not env.get("is_msvc", False) and not env.get("use_mingw", False) and sys.platform != "darwin": @@ -123,4 +153,5 @@ build_colony(env) build_function2(env) build_std_function(env) build_memory(env) +build_reliable(env) link_tbb(env) \ No newline at end of file diff --git a/deps/reliable b/deps/reliable new file mode 160000 index 000000000..57b0c90fa --- /dev/null +++ b/deps/reliable @@ -0,0 +1 @@ +Subproject commit 57b0c90fa68598dbf41467ad28e9247cb08928c3 diff --git a/src/openvic-simulation/GameManager.cpp b/src/openvic-simulation/GameManager.cpp index 2962a7a62..08820f644 100644 --- a/src/openvic-simulation/GameManager.cpp +++ b/src/openvic-simulation/GameManager.cpp @@ -1,11 +1,13 @@ #include "GameManager.hpp" #include - #include #include #include "openvic-simulation/dataloader/Dataloader.hpp" +#include "openvic-simulation/multiplayer/ClientManager.hpp" +#include "openvic-simulation/multiplayer/HostManager.hpp" +#include "openvic-simulation/utility/Containers.hpp" #include "openvic-simulation/utility/Logger.hpp" using namespace OpenVic; @@ -252,6 +254,54 @@ bool GameManager::update_clock() { return instance_manager->update_clock(); } +void GameManager::create_client() { + client_manager = memory::make_unique(this); + chat_manager = memory::make_unique(client_manager.get()); +} + +void GameManager::create_host(memory::string session_name) { + host_manager = memory::make_unique(this); + if (!session_name.empty()) { + host_manager->get_host_session().set_game_name(session_name); + } +} + +void GameManager::threaded_poll_network() { + const int MAX_POLL_WAIT_MSEC = 100; + const int SLEEP_DURATION_USEC = 1000; + + if (host_manager) { + const uint64_t time = GameManager::get_elapsed_milliseconds(); + while (!(host_manager->poll() > 0) && (GameManager::get_elapsed_milliseconds() - time) < MAX_POLL_WAIT_MSEC) { + std::this_thread::sleep_for(std::chrono::microseconds { SLEEP_DURATION_USEC }); + } + } + + if (client_manager) { + const uint64_t time = GameManager::get_elapsed_milliseconds(); + while (!(client_manager->poll() > 0) && (GameManager::get_elapsed_milliseconds() - time) < MAX_POLL_WAIT_MSEC) { + std::this_thread::sleep_for(std::chrono::microseconds { SLEEP_DURATION_USEC }); + } + + if (host_manager) { + // TODO: create local ClientManager that doesn't send network data to HostManager + // In the case that client_manager sends something, host_manager may handle it + const uint64_t time = GameManager::get_elapsed_milliseconds(); + while (!(host_manager->poll() > 0) && (GameManager::get_elapsed_milliseconds() - time) < MAX_POLL_WAIT_MSEC / 4) { + std::this_thread::sleep_for(std::chrono::microseconds { SLEEP_DURATION_USEC / 4 }); + } + } + } +} + +void GameManager::delete_client() { + client_manager.reset(); +} + +void GameManager::delete_host() { + host_manager.reset(); +} + uint64_t GameManager::get_elapsed_microseconds() { return get_elapsed_usec_time_callback(); } diff --git a/src/openvic-simulation/GameManager.hpp b/src/openvic-simulation/GameManager.hpp index 64b2f4704..f19c1393a 100644 --- a/src/openvic-simulation/GameManager.hpp +++ b/src/openvic-simulation/GameManager.hpp @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include "openvic-simulation/DefinitionManager.hpp" @@ -9,11 +10,18 @@ #include "openvic-simulation/dataloader/Dataloader.hpp" #include "openvic-simulation/misc/GameRulesManager.hpp" #include "openvic-simulation/gen/commit_info.gen.hpp" +#include "openvic-simulation/multiplayer/ChatManager.hpp" +#include "openvic-simulation/multiplayer/ClientManager.hpp" +#include "openvic-simulation/multiplayer/HostManager.hpp" +#include "openvic-simulation/player/PlayerManager.hpp" #include "openvic-simulation/utility/ForwardableSpan.hpp" #include namespace OpenVic { + struct HostManager; + struct ClientManager; + struct GameManager { using elapsed_time_getter_func_t = fu2::function_base; @@ -31,6 +39,12 @@ namespace OpenVic { bool PROPERTY_CUSTOM_PREFIX(definitions_loaded, are); bool PROPERTY_CUSTOM_PREFIX(mod_descriptors_loaded, are); + memory::unique_ptr host_manager; + memory::unique_ptr client_manager; + memory::unique_ptr chat_manager; + + bool _get_mod_dependencies(Mod const* mod, memory::vector& load_list); + public: GameManager( InstanceManager::gamestate_updated_func_t new_gamestate_updated_callback, @@ -47,6 +61,30 @@ namespace OpenVic { return instance_manager ? &*instance_manager : nullptr; } + inline HostManager* get_host_manager() { + return host_manager.get(); + } + + inline HostManager const* get_host_manager() const { + return host_manager.get(); + } + + inline ClientManager* get_client_manager() { + return client_manager.get(); + } + + inline ClientManager const* get_client_manager() const { + return client_manager.get(); + } + + inline ChatManager* get_chat_manager() { + return chat_manager.get(); + } + + inline ChatManager const* get_chat_manager() const { + return chat_manager.get(); + } + inline bool set_base_path(Dataloader::path_span_t base_path) { OV_ERR_FAIL_COND_V_MSG(base_path.size() > 1, false, "Too many base paths were provided, only one should be set."); OV_ERR_FAIL_COND_V_MSG(!dataloader.set_roots(base_path, {}), false, "Failed to set Dataloader's base path"); @@ -73,6 +111,15 @@ namespace OpenVic { bool update_clock(); + void create_client(); + void create_host(memory::string session_name = ""); + + // DO NOT RUN ON MAIN THREAD + void threaded_poll_network(); + + void delete_client(); + void delete_host(); + static constexpr std::string_view get_commit_hash() { return SIM_COMMIT_HASH; } diff --git a/src/openvic-simulation/multiplayer/BaseMultiplayerManager.cpp b/src/openvic-simulation/multiplayer/BaseMultiplayerManager.cpp new file mode 100644 index 000000000..febd141c3 --- /dev/null +++ b/src/openvic-simulation/multiplayer/BaseMultiplayerManager.cpp @@ -0,0 +1,20 @@ +#include "BaseMultiplayerManager.hpp" + +#include "openvic-simulation/multiplayer/PacketType.hpp" +#include "openvic-simulation/utility/ErrorMacros.hpp" + +using namespace OpenVic; + +BaseMultiplayerManager::BaseMultiplayerManager(GameManager* game_manager) : game_manager(game_manager) { + packet_cache.reserve_power(15); +} + +bool BaseMultiplayerManager::broadcast_packet(PacketType const& type, PacketType::argument_type argument) { + OV_ERR_FAIL_COND_V(!PacketType::is_valid_type(type), false); + return true; +} + +bool BaseMultiplayerManager::send_packet(client_id_type client_id, PacketType const& type, PacketType::argument_type argument) { + OV_ERR_FAIL_COND_V(!PacketType::is_valid_type(type), false); + return true; +} diff --git a/src/openvic-simulation/multiplayer/BaseMultiplayerManager.hpp b/src/openvic-simulation/multiplayer/BaseMultiplayerManager.hpp new file mode 100644 index 000000000..e73f639de --- /dev/null +++ b/src/openvic-simulation/multiplayer/BaseMultiplayerManager.hpp @@ -0,0 +1,75 @@ +#pragma once + +#include + +#include "openvic-simulation/multiplayer/Constants.hpp" +#include "openvic-simulation/multiplayer/HostSession.hpp" +#include "openvic-simulation/multiplayer/PacketType.hpp" +#include "openvic-simulation/multiplayer/lowlevel/Constants.hpp" +#include "openvic-simulation/types/RingBuffer.hpp" +#include "openvic-simulation/utility/Containers.hpp" +#include "openvic-simulation/utility/Getters.hpp" + +namespace OpenVic { + struct GameManager; + + struct BaseMultiplayerManager { + BaseMultiplayerManager(GameManager* game_manager = nullptr); + virtual ~BaseMultiplayerManager() = default; + + using client_id_type = OpenVic::client_id_type; + using sequence_type = reliable_udp_sequence_type; + + static constexpr client_id_type HOST_ID = MP_HOST_ID; + + virtual bool broadcast_packet(PacketType const& type, PacketType::argument_type argument); + virtual bool send_packet(client_id_type client_id, PacketType const& type, PacketType::argument_type argument); + virtual int64_t poll() = 0; + virtual void close() = 0; + + enum class ConnectionType : uint8_t { HOST, CLIENT }; + + virtual ConnectionType get_connection_type() const = 0; + + template + T* cast_as() { + if (get_connection_type() == T::type_tag) { + return static_cast(this); + } + return nullptr; + } + + template + T const* cast_as() const { + if (get_connection_type() == T::type_tag) { + return static_cast(this); + } + return nullptr; + } + + PacketSpan get_last_raw_packet() { + return last_raw_packet; + } + + protected: + bool PROPERTY_ACCESS(in_lobby, protected, false); + GameManager* PROPERTY_PTR_ACCESS(game_manager, protected); + HostSession PROPERTY_ACCESS(host_session, protected); + + RingBuffer packet_cache; + + struct PacketCacheIndex { + decltype(packet_cache)::const_iterator begin; + decltype(packet_cache)::const_iterator end; + constexpr bool is_valid() const { + return begin != end; + } + }; + + memory::vector last_raw_packet; + + friend bool PacketTypes::send_raw_packet_process_callback( // + BaseMultiplayerManager* multiplayer_manager, PacketSpan packet + ); + }; +} diff --git a/src/openvic-simulation/multiplayer/ChatManager.cpp b/src/openvic-simulation/multiplayer/ChatManager.cpp new file mode 100644 index 000000000..09a576a82 --- /dev/null +++ b/src/openvic-simulation/multiplayer/ChatManager.cpp @@ -0,0 +1,124 @@ +#include "ChatManager.hpp" + +#include +#include + +#include "openvic-simulation/GameManager.hpp" +#include "openvic-simulation/multiplayer/ClientManager.hpp" +#include "openvic-simulation/multiplayer/PacketType.hpp" +#include "openvic-simulation/utility/ErrorMacros.hpp" + +using namespace OpenVic; + +ChatGroup::ChatGroup(index_type index, memory::vector&& clients) + : index { index }, clients { std::move(clients) } {} + +ChatManager::ChatManager(ClientManager* client_manager) : client_manager { client_manager } {} + +bool ChatManager::send_private_message(BaseMultiplayerManager::client_id_type to, memory::string&& message) { + MessageData data { MessageType::PRIVATE, std::move(message), to }; + + ChatMessageLog const& log = log_message(client_manager->get_player()->get_client_id(), std::move(data)); + bool was_sent = client_manager->broadcast_packet(PacketTypes::send_chat_message, &log); + return was_sent; +} + +bool ChatManager::send_private_message(BaseMultiplayerManager::client_id_type to, std::string_view message) { + return send_private_message(to, memory::string { message }); +} + +bool ChatManager::send_public_message(memory::string&& message) { + MessageData data { MessageType::PUBLIC, std::move(message) }; + + ChatMessageLog const& log = log_message(client_manager->get_player()->get_client_id(), std::move(data)); + bool was_sent = client_manager->broadcast_packet(PacketTypes::send_chat_message, &log); + return was_sent; +} + +bool ChatManager::send_public_message(std::string_view message) { + return send_public_message(memory::string { message }); +} + +bool ChatManager::send_group_message(ChatGroup const& group, memory::string&& message) { + MessageData data { MessageType::GROUP, std::move(message), group.get_index() }; + + ChatMessageLog const& log = log_message(client_manager->get_player()->get_client_id(), std::move(data)); + bool was_sent = client_manager->broadcast_packet(PacketTypes::send_chat_message, &log); + return was_sent; +} + +bool ChatManager::send_group_message(ChatGroup::index_type group_id, memory::string&& message) { + OV_ERR_FAIL_INDEX_V(group_id, groups.size(), false); + return send_group_message(groups[group_id], std::move(message)); +} + +bool ChatManager::send_group_message(ChatGroup const& group, std::string_view message) { + return send_group_message(group, memory::string { message }); +} + +bool ChatManager::send_group_message(ChatGroup::index_type group_id, std::string_view message) { + return send_group_message(group_id, memory::string { message }); +} + +ChatMessageLog const& ChatManager::log_message( // + BaseMultiplayerManager::client_id_type from, MessageData&& message, + std::chrono::time_point timestamp +) { + ChatMessageLog const& log = messages.emplace_back( + from, std::move(message), std::chrono::time_point_cast(timestamp).time_since_epoch().count() + ); + message_logged(log); + return log; +} + +ChatMessageLog const* ChatManager::_log_message(ChatMessageLog&& log) { + std::chrono::time_point cur = std::chrono::system_clock::now(); + int64_t current_epoch = std::chrono::time_point_cast(cur).time_since_epoch().count(); + OV_ERR_FAIL_COND_V_MSG(log.timestamp > current_epoch, nullptr, "Invalid chat message log, cannot log future timestamps."); + + ChatMessageLog const& moved_log = messages.emplace_back(std::move(log)); + message_logged(moved_log); + return &moved_log; +} + +memory::vector const& ChatManager::get_message_logs() const { + return messages; +} + +void ChatManager::create_group(memory::vector&& clients) { + client_manager->broadcast_packet(PacketTypes::add_chat_group, clients); +} + +void ChatManager::_create_group(memory::vector&& clients) { + groups.push_back({ groups.size(), std::move(clients) }); + ChatGroup const& last = groups.back(); + group_created(last); +} + +void ChatManager::set_group(ChatGroup::index_type group_id, memory::vector&& clients) { + OV_ERR_FAIL_INDEX(group_id, groups.size()); + client_manager->broadcast_packet(PacketTypes::modify_chat_group, PacketChatGroupModifyData { group_id, clients }); +} + +void ChatManager::_set_group(ChatGroup::index_type group_id, memory::vector&& clients) { + ChatGroup& group = groups[group_id]; + std::swap(group.clients, clients); + group_modified(group, clients); +} + +ChatGroup const& ChatManager::get_group(ChatGroup::index_type group_index) const { + return groups.at(group_index); +} + +void ChatManager::delete_group(ChatGroup::index_type group_id) { + OV_ERR_FAIL_INDEX(group_id, groups.size()); + client_manager->broadcast_packet( + PacketTypes::delete_chat_group, + PacketType::argument_type { std::in_place_index, group_id } + ); +} + +void ChatManager::_delete_group(ChatGroup::index_type group_id) { + groups.erase(groups.begin() + group_id); + group_deleted(group_id); +} diff --git a/src/openvic-simulation/multiplayer/ChatManager.hpp b/src/openvic-simulation/multiplayer/ChatManager.hpp new file mode 100644 index 000000000..501410276 --- /dev/null +++ b/src/openvic-simulation/multiplayer/ChatManager.hpp @@ -0,0 +1,176 @@ +#pragma once + +#include +#include +#include + +#include "openvic-simulation/multiplayer/Constants.hpp" +#include "openvic-simulation/multiplayer/PacketType.hpp" +#include "openvic-simulation/types/Signal.hpp" +#include "openvic-simulation/utility/Marshal.hpp" + +namespace OpenVic { + struct GameManager; + struct ChatManager; + struct ChatMessageLog; + struct BaseMultiplayerManager; + struct ClientManager; + + struct ChatGroup { + using index_type = size_t; + + ChatGroup(ChatGroup const&) = delete; + ChatGroup& operator=(ChatGroup const&) = delete; + ChatGroup(ChatGroup&&) = default; + ChatGroup& operator=(ChatGroup&& lhs) = default; + + std::span get_clients() const { + return clients; + } + + operator index_type() const { + return index; + } + + private: + friend struct ChatManager; + ChatGroup(index_type index, memory::vector&& clients); + + memory::vector clients; + index_type PROPERTY(index); + }; + + struct ChatManager { + enum class MessageType : uint8_t { NONE, PRIVATE, GROUP, PUBLIC }; + + struct MessageData { + MessageType type = MessageType::NONE; + memory::string message; + union { + client_id_type to_client = MP_HOST_ID; + ChatGroup::index_type to_group; + }; + + template + size_t encode(std::span span) const { + size_t offset = utility::encode(type, span, utility::endian_tag); + if (type != MessageType::PUBLIC) { + offset += // + utility::encode( + to_client, // + span.empty() ? span : span.subspan(offset) // + ); + } + + return utility::encode(message, span.empty() ? span : span.subspan(offset), utility::endian_tag) + + offset; + } + + template + static MessageData decode(std::span span, size_t& r_decode_count) { + MessageType type = utility::decode(span, r_decode_count); + size_t offset = r_decode_count; + + // Depending on type, either client_id_type or ChatGroup::index_type + decltype(to_client) to_index = MP_HOST_ID; + if (type != MessageType::PUBLIC) { + to_index = utility::decode(span.subspan(offset), r_decode_count); + offset += r_decode_count; + } + + MessageData data { + type, utility::decode(span.subspan(offset), r_decode_count), to_index // + }; + r_decode_count += offset; + return data; + } + }; + + ChatManager(ClientManager* client_manager = nullptr); + + bool send_private_message(client_id_type to, memory::string&& message); + bool send_private_message(client_id_type to, std::string_view message); + bool send_public_message(memory::string&& message); + bool send_public_message(std::string_view message); + bool send_group_message(ChatGroup const& group, memory::string&& message); + bool send_group_message(ChatGroup::index_type group_id, memory::string&& message); + bool send_group_message(ChatGroup const& group, std::string_view message); + bool send_group_message(ChatGroup::index_type group_id, std::string_view message); + + signal_property message_logged; + + ChatMessageLog const& log_message( + client_id_type from, MessageData&& message, + std::chrono::time_point timestamp = std::chrono::system_clock::now() + ); + memory::vector const& get_message_logs() const; + + signal_property group_created; + signal_property& /* old clients */> + group_modified; + signal_property group_deleted; + + void create_group(memory::vector&& clients); + void set_group(ChatGroup::index_type group, memory::vector&& clients); + ChatGroup const& get_group(ChatGroup::index_type group_index) const; + void delete_group(ChatGroup::index_type group_id); + + private: + ClientManager* PROPERTY_PTR(client_manager); + + memory::deque groups; + memory::vector messages; + + friend bool PacketTypes::send_chat_message_process_callback( // + BaseMultiplayerManager* multiplayer_manager, PacketSpan packet + ); + ChatMessageLog const* _log_message(ChatMessageLog&& log); + + friend bool PacketTypes::add_chat_group_process_callback( // + BaseMultiplayerManager* multiplayer_manager, PacketSpan packet + ); + void _create_group(memory::vector&& clients); + + friend bool PacketTypes::modify_chat_group_process_callback( // + BaseMultiplayerManager* multiplayer_manager, PacketSpan packet + ); + void _set_group(ChatGroup::index_type group, memory::vector&& clients); + + friend bool PacketTypes::delete_chat_group_process_callback( // + BaseMultiplayerManager* multiplayer_manager, PacketSpan packet + ); + void _delete_group(ChatGroup::index_type group_id); + }; + + struct ChatMessageLog { + client_id_type from_id = MP_HOST_ID; + ChatManager::MessageData data; + int64_t timestamp = 0; + + ChatMessageLog() = default; + ChatMessageLog(client_id_type from, ChatManager::MessageData data, int64_t timestamp) + : from_id { from }, data { data }, timestamp { timestamp } {} + + template + size_t encode(std::span span) const { + size_t offset = utility::encode(from_id, span, utility::endian_tag); + offset += utility::encode(data, span.empty() ? span : span.subspan(offset), utility::endian_tag); + offset += utility::encode(timestamp, span.empty() ? span : span.subspan(offset), utility::endian_tag); + return offset; + } + + template + static ChatMessageLog decode(std::span span, size_t& r_decode_count) { + size_t offset = 0; + + uint16_t from_id = utility::decode(span, r_decode_count); + offset += r_decode_count; + decltype(data) data = utility::decode(span.subspan(offset), r_decode_count); + offset += r_decode_count; + decltype(timestamp) timestamp = utility::decode(span.subspan(offset), r_decode_count); + r_decode_count = offset; + + return { std::move(from_id), std::move(data), std::move(timestamp) }; + } + }; +} diff --git a/src/openvic-simulation/multiplayer/ClientManager.cpp b/src/openvic-simulation/multiplayer/ClientManager.cpp new file mode 100644 index 000000000..7bc2d588e --- /dev/null +++ b/src/openvic-simulation/multiplayer/ClientManager.cpp @@ -0,0 +1,196 @@ +#include "ClientManager.hpp" + +#include +#include +#include +#include +#include + +#include "openvic-simulation/GameManager.hpp" +#include "openvic-simulation/multiplayer/PacketType.hpp" +#include "openvic-simulation/multiplayer/lowlevel/HostnameAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkSocket.hpp" +#include "openvic-simulation/multiplayer/lowlevel/PacketBuilder.hpp" +#include "openvic-simulation/multiplayer/lowlevel/ReliableUdpClient.hpp" +#include "openvic-simulation/utility/Containers.hpp" +#include "openvic-simulation/utility/ErrorMacros.hpp" +#include "openvic-simulation/utility/Marshal.hpp" + +using namespace OpenVic; + +bool ClientManager::connect_to(HostnameAddress const& address, NetworkSocket::port_type port) { + NetworkError err = client.connect_to(address.resolved_address(), port); + OV_ERR_FAIL_COND_V(err != NetworkError::OK, false); + PacketBuilder builder; + builder.put_back(1); + return client.set_packet(builder) == NetworkError::OK; +} + +template +bool ClientManager::_send_packet_to_host(auto const& argument) { + PacketBuilder builder; + builder.put_back(type.packet_id); + if constexpr (requires { type.create_packet(this, argument, builder); }) { + if (!type.create_packet(this, argument, builder)) { + return false; + } + } else { + if (!type.create_packet(this, &argument, builder)) { + return false; + } + } + + if (!add_packet_to_cache(builder)) { + return false; + } + + client.set_packet(builder); + return true; +} + +bool ClientManager::broadcast_packet(PacketType const& type, PacketType::argument_type argument) { + if (!BaseMultiplayerManager::broadcast_packet(type, argument)) { + return false; + } + + OV_ERR_FAIL_COND_V(!type.can_client_send, false); + return _send_packet_to_host(BroadcastData { type.packet_id, argument }); +} + +bool ClientManager::send_packet(client_id_type client_id, PacketType const& type, PacketType::argument_type argument) { + if (!BaseMultiplayerManager::send_packet(client_id, type, argument)) { + return false; + } + + OV_ERR_FAIL_COND_V(!type.can_client_send, false); + return _send_packet_to_host(RetransmitData { client_id, type.packet_id, argument }); +} + +int64_t ClientManager::poll() { + int64_t result = client.available_packets(); + if (result < 1) { + return result; + } + + PacketSpan span = client.packet_span(); + if (player == nullptr) { + OV_ERR_FAIL_COND_V(client.get_current_sequence_value() != 0, -1); + client_id_type client_id = span.read(); + OV_ERR_FAIL_COND_V(client_id == INVALID_CLIENT_ID, -1); + player = host_session.add_player(client_id, ""); + return poll(); + } + + decltype(PacketType::packet_id) packet_id = span.read(); + OV_ERR_FAIL_COND_V(!PacketType::is_valid_type(packet_id), -1); + + PacketType const& packet_type = PacketType::get_type_by_id(packet_id); + if (!packet_type.process_packet(this, span)) { + // TODO: packet processing failed + return -1; + } + // TODO: packet was processed + + return result; +} + +void ClientManager::close() { + client.close(); +} + +bool ClientManager::connect_to_resource_server(std::optional port) { + if (is_running_as_host()) { + return true; + } + + NetworkError err = resource_client.connect_to(client.get_peer_ip(), port.value_or(client.get_peer_port() + 1)); + OV_ERR_FAIL_COND_V(err != NetworkError::OK, false); + return true; +} + +std::future ClientManager::poll_resource_server() { + if (is_running_as_host()) { + return {}; + } + + if (resource_client.poll() != NetworkError::OK || resource_client.get_status() != TcpPacketStream::Status::CONNECTED) { + return {}; + } + + PacketBuffer buffer = resource_client.packet_buffer(sizeof(size_t)); + size_t buffer_size = buffer.read(); + buffer.clear(); + buffer.resize(buffer_size); + resource_client.get_data(buffer); + std::string_view str = buffer.read(); + std::filesystem::path path = str; + + bool is_check = buffer.read(); + if (is_check) { + return std::async(std::launch::async, [&] { + std::error_code ec; + + PacketBuilder<> pb; + pb.put_back(utility::encode(str) + utility::encode(true)); + pb.put_back(str); + + if (!std::filesystem::is_directory(path.parent_path(), ec)) { + pb.put_back(false); + OV_ERR_FAIL_COND(!ec); + return; + } + + // TODO: get checksum of file + if (!std::filesystem::is_regular_file(path, ec)) { + pb.put_back(false); + OV_ERR_FAIL_COND(!ec); + return; + } + + pb.put_back(true); + }); + } + + size_t index = buffer.index(); + std::span data { &buffer.data()[index], buffer.size() - index }; + + return std::async(std::launch::async, [&] { + std::error_code ec; + + OV_ERR_FAIL_COND(!std::filesystem::create_directories(path.parent_path(), ec) || !ec); + + { + std::FILE* file = utility::fopen(path.c_str(), "wb"); + OV_ERR_FAIL_COND(file == nullptr); + + size_t write_count = std::fwrite(data.data(), sizeof(uint8_t), data.size(), file); + std::fclose(file); + OV_ERR_FAIL_COND(write_count != buffer_size); + } + }); +} + +bool ClientManager::is_running_as_host() const { + return game_manager->get_host_manager() != nullptr; +} + +bool ClientManager::add_packet_to_cache(std::span bytes) { + OV_ERR_FAIL_COND_V(bytes.size_bytes() > ReliableUdpClient::MAX_PACKET_SIZE, false); + decltype(packet_cache)::iterator begin = packet_cache.append(bytes.data(), bytes.size_bytes()); + decltype(packet_cache)::iterator end = packet_cache.end(); + OV_ERR_FAIL_COND_V(begin == end, false); + return sequence_to_index.try_emplace(client.get_next_sequence_value(), PacketCacheIndex { begin, end }).second; +} + +memory::vector ClientManager::get_packet_cache(sequence_type sequence_value) { + decltype(sequence_to_index)::iterator it = sequence_to_index.find(sequence_value); + OV_ERR_FAIL_COND_V(it == sequence_to_index.end(), {}); + return { it.value().begin, it.value().end }; +} + +void ClientManager::remove_from_cache(sequence_type sequence_value) { + decltype(sequence_to_index)::iterator it = sequence_to_index.find(sequence_value); + OV_ERR_FAIL_COND(it == sequence_to_index.end()); + sequence_to_index.unordered_erase(it); +} diff --git a/src/openvic-simulation/multiplayer/ClientManager.hpp b/src/openvic-simulation/multiplayer/ClientManager.hpp new file mode 100644 index 000000000..44add5965 --- /dev/null +++ b/src/openvic-simulation/multiplayer/ClientManager.hpp @@ -0,0 +1,54 @@ +#pragma once + +#include +#include + +#include "openvic-simulation/multiplayer/BaseMultiplayerManager.hpp" +#include "openvic-simulation/multiplayer/Constants.hpp" +#include "openvic-simulation/multiplayer/lowlevel/HostnameAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/ReliableUdpClient.hpp" +#include "openvic-simulation/multiplayer/lowlevel/TcpPacketStream.hpp" +#include "openvic-simulation/utility/Containers.hpp" + +namespace OpenVic { + struct GameManager; + + struct ClientManager final : BaseMultiplayerManager { + using BaseMultiplayerManager::BaseMultiplayerManager; + + bool connect_to(HostnameAddress const& address, socket_port_type port); + + bool broadcast_packet(PacketType const& type, PacketType::argument_type argument) override; + bool send_packet(client_id_type client_id, PacketType const& type, PacketType::argument_type argument) override; + int64_t poll() override; + void close() override; + + bool connect_to_resource_server(std::optional port = std::nullopt); + std::future poll_resource_server(); + + bool is_running_as_host() const; + + static constexpr ConnectionType type_tag = ConnectionType::CLIENT; + inline constexpr ConnectionType get_connection_type() const override { + return type_tag; + } + + static constexpr client_id_type INVALID_CLIENT_ID = MP_INVALID_CLIENT_ID; + + private: + ReliableUdpClient PROPERTY_REF(client); + TcpPacketStream resource_client; + Player const* PROPERTY(player, &Player::INVALID_PLAYER); + + friend bool PacketTypes::update_host_session_process_callback(BaseMultiplayerManager* game_manager, PacketSpan packet); + + ordered_map sequence_to_index; + + bool add_packet_to_cache(std::span bytes); + memory::vector get_packet_cache(sequence_type sequence_value); + void remove_from_cache(sequence_type sequence_value); + + template + bool _send_packet_to_host(auto const& argument); + }; +} diff --git a/src/openvic-simulation/multiplayer/Constants.hpp b/src/openvic-simulation/multiplayer/Constants.hpp new file mode 100644 index 000000000..95dce976b --- /dev/null +++ b/src/openvic-simulation/multiplayer/Constants.hpp @@ -0,0 +1,10 @@ +#pragma once + +#include +#include + +namespace OpenVic { + using client_id_type = uint64_t; + static constexpr client_id_type MP_HOST_ID = static_cast(~0); + static constexpr client_id_type MP_INVALID_CLIENT_ID = std::numeric_limits::max() - 1; +} diff --git a/src/openvic-simulation/multiplayer/HostManager.cpp b/src/openvic-simulation/multiplayer/HostManager.cpp new file mode 100644 index 000000000..3c799ec6c --- /dev/null +++ b/src/openvic-simulation/multiplayer/HostManager.cpp @@ -0,0 +1,366 @@ +#include "HostManager.hpp" + +#include +#include +#include +#include +#include +#include + +#include "openvic-simulation/multiplayer/BaseMultiplayerManager.hpp" +#include "openvic-simulation/multiplayer/ClientManager.hpp" +#include "openvic-simulation/multiplayer/HostSession.hpp" +#include "openvic-simulation/multiplayer/PacketType.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/PacketBuilder.hpp" +#include "openvic-simulation/multiplayer/lowlevel/PacketReaderAdapter.hpp" +#include "openvic-simulation/multiplayer/lowlevel/ReliableUdpClient.hpp" +#include "openvic-simulation/multiplayer/lowlevel/ReliableUdpServer.hpp" +#include "openvic-simulation/multiplayer/lowlevel/TcpPacketStream.hpp" +#include "openvic-simulation/types/OrderedContainers.hpp" +#include "openvic-simulation/utility/Containers.hpp" +#include "openvic-simulation/utility/ErrorMacros.hpp" +#include "openvic-simulation/utility/TslHelper.hpp" + +using namespace OpenVic; + +HostManager::HostManager(GameManager* game_manager) : BaseMultiplayerManager(game_manager) { + host_session.session_changed.connect(&HostManager::_on_session_changed, this); +} + +HostManager::~HostManager() { + this->disconnect_all(); +} + +void HostManager::_on_session_changed() { + broadcast_packet(PacketTypes::update_host_session, &host_session); +} + +bool HostManager::listen(NetworkSocket::port_type port, HostnameAddress const& bind_address) { + return server.listen_to(port, bind_address.resolved_address()) == NetworkError::OK; +} + +bool HostManager::broadcast_packet(PacketType const& type, PacketType::argument_type argument) { + if (!BaseMultiplayerManager::broadcast_packet(type, argument)) { + return false; + } + + PacketBuilder builder; + builder.put_back(type.packet_id); + if (!type.create_packet(this, argument, builder)) { + return false; + } + + return broadcast_data(builder); +} + +bool HostManager::send_packet(client_id_type client_id, PacketType const& type, PacketType::argument_type argument) { + if (!BaseMultiplayerManager::send_packet(client_id, type, argument)) { + return false; + } + + decltype(clients)::iterator it = clients.find(client_id); + OV_ERR_FAIL_COND_V(it == clients.end(), false); + + PacketBuilder builder; + builder.put_back(type.packet_id); + + if (!type.create_packet(this, argument, builder)) { + return false; + } + + PacketCacheIndex index = cache_packet_for(client_id, builder); + if (!index.is_valid()) { + return false; + } + + if (it.value().client->set_packet(builder) != NetworkError::OK) { + return false; + } + + return true; +} + +int64_t HostManager::poll() { + if (server.poll() != NetworkError::OK && !server.is_connection_available()) { + return -1; + } + + { + PacketBuilder builder; + while (server.is_connection_available()) { + size_t client_id = clients.size(); + clients.try_emplace(client_id, ClientValue { server.take_next_client_as() }); + OV_ERR_CONTINUE(clients.back().second.client->packet_span().read() != 1); + builder.put_back(client_id); + clients.back().second.client->set_packet(builder); + builder.clear(); + } + } + + int64_t result = 0; + for (std::pair const& pair : clients) { + ReliableUdpClient& client = *pair.second.client; + + while (client.available_packets() > 0) { + PacketSpan span = client.packet_span(); + decltype(PacketType::packet_id) packet_id; + packet_id = span.read(); + OV_ERR_CONTINUE(!PacketType::is_valid_type(packet_id)); + + PacketType const& packet_type = PacketType::get_type_by_id(packet_id); + if (!packet_type.process_packet(this, span)) { + // TODO: packet processing failed + continue; + } + result += 1; + // TODO: packet was processed + } + } + return result; +} + +void HostManager::close() { + server.close(); +} + +size_t HostManager::send_resource(std::string_view path, bool recursive) { + if (resource_clients.empty()) { + resource_clients.reserve(clients.size()); + while (resource_server.is_connection_available()) { + resource_clients.push_back(resource_server.take_packet_stream()); + OV_ERR_CONTINUE(resource_clients.back()->get_status() != TcpPacketStream::Status::CONNECTED); + } + } + + std::error_code ec; + std::filesystem::path filepath = std::filesystem::current_path(ec); + OV_ERR_FAIL_COND_V(!ec, 0); + + filepath /= path; + while (std::filesystem::is_symlink(filepath, ec) && ec == std::error_code {}) { + filepath = std::filesystem::read_symlink(filepath, ec); + OV_ERR_FAIL_COND_V(!ec, 0); + } + OV_ERR_FAIL_COND_V(!ec, 0); + + const bool path_exists = std::filesystem::exists(filepath, ec); + OV_ERR_FAIL_COND_V(!ec, 0); + OV_ERR_FAIL_COND_V(!path_exists, 0); + + memory::vector file_bytes; + auto send_to_clients = [&](std::filesystem::directory_entry const& entry) -> size_t { + uintmax_t size = entry.file_size(ec); + OV_ERR_FAIL_COND_V(!ec, 0); + + file_bytes.resize(size); + { + std::FILE* file = utility::fopen(entry.path().c_str(), "rb"); + OV_ERR_FAIL_COND_V(file == nullptr, 0); + + size_t read_count = std::fread(file_bytes.data(), sizeof(uint8_t), size, file); + std::fclose(file); + OV_ERR_FAIL_COND_V(read_count != size, 0); + } + + for (memory::unique_ptr& client : resource_clients) { + OV_ERR_CONTINUE(client->set_data(file_bytes) != NetworkError::OK); + } + return 1; + }; + + if (std::filesystem::is_regular_file(filepath, ec)) { + std::filesystem::directory_entry const& entry { filepath, ec }; + OV_ERR_FAIL_COND_V(!ec, 0); + return send_to_clients(entry); + } else { + OV_ERR_FAIL_COND_V(!ec, 0); + } + + size_t file_count = 0; + for (std::filesystem::recursive_directory_iterator it { filepath, ec }; it != decltype(it)(); ++it) { + std::filesystem::directory_entry const& entry = *it; + OV_ERR_CONTINUE(!ec); + + if (entry.is_directory(ec)) { + if (!recursive) { + it.disable_recursion_pending(); + } + continue; + } else { + OV_ERR_CONTINUE(!ec); + } + + if (!entry.is_regular_file(ec)) { + Logger::warning("Only regular files can be sent as a resource, '", filepath, "' is not a regular file."); + continue; + } else { + OV_ERR_CONTINUE(!ec); + } + + const size_t sent_files = send_to_clients(entry); + OV_ERR_CONTINUE(sent_files == 0); + file_count += sent_files; + } + return file_count; +} + +std::future HostManager::check_for_resource(client_id_type client_id, memory::string const& check_path) { + OV_ERR_FAIL_INDEX_V(client_id, resource_clients.size(), {}); + + using namespace std::chrono_literals; + static constexpr std::chrono::duration SLEEP_DURATION = 1000us; + + TcpPacketStream& client = *resource_clients[client_id]; + + PacketBuilder<> pb; + pb.put_back(check_path); + pb.put_back(true); + client.set_data(pb); + + return std::async(std::launch::async, [&]() -> bool { + while (client.poll() == NetworkError::OK && client.get_status() != TcpPacketStream::Status::CONNECTED) { + std::this_thread::sleep_for(SLEEP_DURATION); + } + + if (client.get_status() != TcpPacketStream::Status::CONNECTED) { + return false; + } + + // TODO: support multiple simultaneous resource checks + + std::string_view str; + PacketBuffer buffer; + buffer.resize(sizeof(uint64_t)); + while (str != check_path) { + if (client.poll() == NetworkError::OK) { + std::this_thread::sleep_for(SLEEP_DURATION); + continue; + } else if (client.get_status() != TcpPacketStream::Status::CONNECTED) { + return false; + } + + client.get_data(buffer); + uint64_t buffer_size = buffer.read(); + if (buffer_size > buffer.size()) { + buffer.resize(buffer_size); + client.get_data(buffer); + } else { + client.get_data({ buffer.data(), buffer_size }); + } + str = buffer.read(); + } + + return buffer.read(); + }); +} + +bool HostManager::broadcast_data(std::span bytes) { + PacketCacheIndex index = add_packet_to_cache(bytes); + + if (!index.is_valid()) { + return false; + } + + bool has_succeeded = true; + for (std::pair pair : mutable_iterator(clients)) { + ReliableUdpClient& client = *pair.second.client; + if (client.set_packet(bytes) == NetworkError::OK) { + has_succeeded &= pair.second.add_cache_index(index).is_valid(); + continue; + } + has_succeeded = false; + } + return has_succeeded; +} + +bool HostManager::send_data(client_id_type client_id, std::span bytes) { + decltype(clients)::iterator it = clients.find(client_id); + OV_ERR_FAIL_COND_V(it == clients.end(), false); + + PacketCacheIndex index = cache_packet_for(client_id, bytes); + if (index.is_valid()) { + return false; + } + + if (it.value().client->set_packet(bytes) != NetworkError::OK) { + return false; + } + return true; +} + +HostSession& HostManager::get_host_session() { + return host_session; +} + +ReliableUdpClient* HostManager::get_client_by_id(client_id_type client_id) { + decltype(clients)::iterator it = clients.find(client_id); + OV_ERR_FAIL_COND_V(it == clients.end(), nullptr); + return it.value().client.get(); +} + +ReliableUdpClient const* HostManager::get_client_by_id(client_id_type client_id) const { + decltype(clients)::const_iterator it = clients.find(client_id); + OV_ERR_FAIL_COND_V(it == clients.end(), nullptr); + return it.value().client.get(); +} + +HostManager::ClientIterable HostManager::client_iterable() { + return ClientIterable { clients }; +} + +HostManager::ClientIterable HostManager::client_iterable() const { + return ClientIterable { clients }; +} + +HostManager::PacketCacheIndex HostManager::ClientValue::add_cache_index(PacketCacheIndex const& index) { + OV_ERR_FAIL_COND_V(!index.is_valid(), index); + + std::pair emplace_result = + sequence_to_index.try_emplace(client->get_next_sequence_value(), index); + + OV_ERR_FAIL_COND_V(!emplace_result.second, (PacketCacheIndex { index.end, index.end })); + return emplace_result.first.value(); +} + +void HostManager::ClientValue::remove_packet(sequence_type sequence_value) { + decltype(sequence_to_index)::iterator it = sequence_to_index.find(sequence_value); + OV_ERR_FAIL_COND(it == sequence_to_index.end()); + sequence_to_index.unordered_erase(it); +} + +HostManager::PacketCacheIndex HostManager::cache_packet_for(client_id_type client_id, std::span bytes) { + decltype(clients)::iterator it = clients.find(client_id); + OV_ERR_FAIL_COND_V(it == clients.end(), (PacketCacheIndex { packet_cache.end(), packet_cache.end() })); + + PacketCacheIndex index = add_packet_to_cache(bytes); + OV_ERR_FAIL_COND_V(!index.is_valid(), index); + + return it.value().add_cache_index(index); +} + +HostManager::PacketCacheIndex HostManager::add_packet_to_cache(std::span bytes) { + OV_ERR_FAIL_COND_V( + bytes.size_bytes() > ReliableUdpClient::MAX_PACKET_SIZE, (PacketCacheIndex { packet_cache.end(), packet_cache.end() }) + ); + + decltype(packet_cache)::iterator begin = packet_cache.append(bytes.data(), bytes.size_bytes()); + decltype(packet_cache)::iterator end = packet_cache.end(); + return { begin, end }; +} + +memory::vector HostManager::get_packet_cache_for(client_id_type client_id, sequence_type sequence_value) { + decltype(clients)::iterator client_it = clients.find(client_id); + OV_ERR_FAIL_COND_V(client_it == clients.end(), {}); + + decltype(client_it.value().sequence_to_index)::iterator it = client_it.value().sequence_to_index.find(sequence_value); + OV_ERR_FAIL_COND_V(it == client_it.value().sequence_to_index.end(), {}); + return { it.value().begin, it.value().end }; +} + +void HostManager::remove_from_cache_for(client_id_type client_id, sequence_type sequence_value) { + decltype(clients)::iterator it = clients.find(client_id); + OV_ERR_FAIL_COND(it == clients.end()); + + it.value().remove_packet(sequence_value); +} diff --git a/src/openvic-simulation/multiplayer/HostManager.hpp b/src/openvic-simulation/multiplayer/HostManager.hpp new file mode 100644 index 000000000..c1efcdf7c --- /dev/null +++ b/src/openvic-simulation/multiplayer/HostManager.hpp @@ -0,0 +1,210 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "openvic-simulation/multiplayer/BaseMultiplayerManager.hpp" +#include "openvic-simulation/multiplayer/lowlevel/HostnameAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkSocket.hpp" +#include "openvic-simulation/multiplayer/lowlevel/ReliableUdpServer.hpp" +#include "openvic-simulation/multiplayer/lowlevel/TcpPacketStream.hpp" +#include "openvic-simulation/multiplayer/lowlevel/TcpServer.hpp" +#include "openvic-simulation/types/OrderedContainers.hpp" +#include "openvic-simulation/types/Signal.hpp" +#include "openvic-simulation/utility/Containers.hpp" + +namespace OpenVic { + struct HostSession; + struct ReliableUdpClient; + + struct HostManager final : BaseMultiplayerManager, observer { + HostManager(GameManager* game_manager = nullptr); + ~HostManager(); + + bool listen(NetworkSocket::port_type port, HostnameAddress const& bind_address = IpAddress { "*" }); + + bool broadcast_packet(PacketType const& type, PacketType::argument_type argument) override; + bool send_packet(client_id_type client_id, PacketType const& type, PacketType::argument_type argument) override; + int64_t poll() override; + void close() override; + + size_t send_resource(std::string_view path, bool recursive = true); + + std::future check_for_resource(client_id_type client_id, memory::string const& check_path); + + static constexpr ConnectionType type_tag = ConnectionType::HOST; + inline constexpr ConnectionType get_connection_type() const override { + return type_tag; + } + + bool broadcast_data(std::span bytes); + bool send_data(client_id_type client_id, std::span bytes); + + HostSession& get_host_session(); + + ReliableUdpClient* get_client_by_id(client_id_type client_id); + ReliableUdpClient const* get_client_by_id(client_id_type client_id) const; + + template + struct ClientIterable; + + ClientIterable client_iterable(); + ClientIterable client_iterable() const; + + private: + ReliableUdpServer PROPERTY_REF(server); + TcpServer resource_server; + + struct ClientValue { + memory::unique_ptr client; + ordered_map sequence_to_index; + + PacketCacheIndex add_cache_index(PacketCacheIndex const& index); + void remove_packet(sequence_type sequence_value); + }; + ordered_map clients; + + HostManager::PacketCacheIndex add_packet_to_cache(std::span bytes); + + HostManager::PacketCacheIndex cache_packet_for(client_id_type client_id, std::span bytes); + memory::vector get_packet_cache_for(client_id_type client_id, sequence_type sequence_value); + void remove_from_cache_for(client_id_type client_id, sequence_type sequence_value); + + void _on_session_changed(); + + memory::vector> resource_clients; + + public: + template + struct ClientIterable { + class iterator final + : protected std::conditional_t { + friend struct ClientIterable; + using base_type = std::conditional_t; + + iterator(const base_type& other) noexcept : base_type(other) {} + + public: + using client_pointer = std::conditional_t< + IsConst, const decltype(decltype(clients)::value_type::second_type::client)::pointer, + decltype(decltype(clients)::value_type::second_type::client)::pointer>; + + using pair_type = std::pair; + + using difference_type = decltype(clients)::difference_type; + + iterator() = default; + + template + requires OtherConst + iterator(const ClientIterable::iterator& other) : base_type(other) {} + + iterator(const iterator& other) = default; + iterator(iterator&& other) = default; + iterator& operator=(const iterator& other) = default; + iterator& operator=(iterator&& other) = default; + + iterator operator++() { + ++(*this); + return *this; + } + + iterator operator--() { + --(*this); + return *this; + } + + iterator operator++(int) { + iterator tmp { *this }; + ++(*this); + return tmp; + } + + iterator operator--(int) { + iterator tmp { *this }; + --(*this); + return tmp; + } + + pair_type::first_type key() const { + return base_type::key(); + } + + pair_type::second_type value() const { + return base_type::value().client.get(); + } + + pair_type operator*() const { + return { key(), value() }; + } + + client_pointer operator->() const { + return value(); + } + + pair_type operator[](difference_type n) const { + return *(*this + n); + } + + iterator& operator+=(difference_type n) { + *this += n; + return *this; + } + + iterator& operator-=(difference_type n) { + *this -= n; + return *this; + } + + iterator operator+(difference_type n) const { + iterator tmp { *this }; + tmp += n; + return tmp; + } + + iterator operator-(difference_type n) const { + iterator tmp { *this }; + tmp -= n; + return tmp; + } + + friend iterator operator+(difference_type lhs, iterator const& rhs) { + return rhs + lhs; + } + + difference_type operator-(iterator const& rhs) const { + return *this - rhs; + } + + bool operator==(iterator const& rhs) const { + return *this == rhs; + } + + auto operator<=>(iterator const& rhs) const { + return *this <=> rhs; + } + }; + + iterator begin() const { + return iterator { source.begin() }; + } + + iterator end() const { + return iterator { source.end() }; + } + + size_t size() const { + return source.size(); + } + + private: + friend struct HostManager; + std::conditional_t source; + ClientIterable(decltype(source) source) : source { source } {} + }; + }; +} diff --git a/src/openvic-simulation/multiplayer/HostSession.cpp b/src/openvic-simulation/multiplayer/HostSession.cpp new file mode 100644 index 000000000..cd3736642 --- /dev/null +++ b/src/openvic-simulation/multiplayer/HostSession.cpp @@ -0,0 +1,58 @@ +#include "HostSession.hpp" + +#include + +#include "openvic-simulation/multiplayer/Constants.hpp" +#include "openvic-simulation/player/PlayerManager.hpp" +#include "openvic-simulation/types/Signal.hpp" +#include "openvic-simulation/utility/Containers.hpp" +#include "openvic-simulation/utility/ErrorMacros.hpp" + +using namespace OpenVic; + +HostSession::HostSession(memory::string game_name) : game_name { game_name } {} + +HostSession::~HostSession() { + this->disconnect_all(); +} + +HostSession::HostSession(HostSession&& other) { + std::swap(game_name, other.game_name); + std::swap(players, other.players); + other.session_changed(); +} + +HostSession& HostSession::operator=(HostSession&& other) { + std::swap(game_name, other.game_name); + std::swap(players, other.players); + session_changed(); + return *this; +} + +void HostSession::set_game_name(memory::string new_game_name) { + game_name = new_game_name; + session_changed(); +} + +Player* HostSession::add_player(client_id_type client_id, memory::string player_name) { + std::pair success = + players.insert(std::make_pair(client_id, Player { client_id, player_name })); + OV_ERR_FAIL_COND_V(!success.second, nullptr); + success.first.value().player_changed.connect([this]() { + session_changed(); + }); + session_changed(); + return &success.first.value(); +} + +bool HostSession::remove_player(client_id_type client_id) { + OV_ERR_FAIL_COND_V(players.unordered_erase(client_id) != 1, false); + session_changed(); + return true; +} + +Player* HostSession::get_player_by(client_id_type client_id) { + decltype(players)::iterator it = players.find(client_id); + OV_ERR_FAIL_COND_V(it == players.end(), nullptr); + return &it.value(); +} diff --git a/src/openvic-simulation/multiplayer/HostSession.hpp b/src/openvic-simulation/multiplayer/HostSession.hpp new file mode 100644 index 000000000..41d444dd2 --- /dev/null +++ b/src/openvic-simulation/multiplayer/HostSession.hpp @@ -0,0 +1,56 @@ +#pragma once + +#include +#include + +#include "openvic-simulation/multiplayer/Constants.hpp" +#include "openvic-simulation/player/PlayerManager.hpp" +#include "openvic-simulation/types/OrderedContainers.hpp" +#include "openvic-simulation/types/Signal.hpp" +#include "openvic-simulation/utility/Containers.hpp" +#include "openvic-simulation/utility/Getters.hpp" +#include "openvic-simulation/utility/Marshal.hpp" + +namespace OpenVic { + struct Player; + + struct HostSession : observer { + signal_property session_changed; + + HostSession(memory::string game_name = "HostSession"); + ~HostSession(); + + HostSession(HostSession const&) = delete; + HostSession& operator=(HostSession const&) = delete; + HostSession(HostSession&&); + HostSession& operator=(HostSession&& other); + + template + size_t encode(std::span span) const { + size_t offset = utility::encode(game_name, span); + offset += utility::encode(players, span.empty() ? span : span.subspan(offset)); + return offset; + } + + template + static HostSession decode(std::span span, size_t& r_decode_count) { + HostSession result { utility::decode(span, r_decode_count) }; + size_t offset = r_decode_count; + decltype(players) players = utility::decode(span.subspan(offset), r_decode_count); + r_decode_count += offset; + result.players.swap(players); + return result; + } + + void set_game_name(memory::string new_game_name); + + Player* add_player(client_id_type client_id, memory::string player_name); + bool remove_player(client_id_type client_id); + + Player* get_player_by(client_id_type client_id); + + private: + memory::string PROPERTY(game_name); + vector_ordered_map PROPERTY(players); + }; +} diff --git a/src/openvic-simulation/multiplayer/PacketType.cpp b/src/openvic-simulation/multiplayer/PacketType.cpp new file mode 100644 index 000000000..6c8c1bcf6 --- /dev/null +++ b/src/openvic-simulation/multiplayer/PacketType.cpp @@ -0,0 +1,244 @@ +#include "PacketType.hpp" + +#include + +#include "openvic-simulation/GameManager.hpp" +#include "openvic-simulation/misc/GameAction.hpp" +#include "openvic-simulation/multiplayer/BaseMultiplayerManager.hpp" +#include "openvic-simulation/multiplayer/ClientManager.hpp" +#include "openvic-simulation/multiplayer/HostManager.hpp" +#include "openvic-simulation/multiplayer/HostSession.hpp" +#include "openvic-simulation/multiplayer/lowlevel/PacketBuilder.hpp" +#include "openvic-simulation/utility/Containers.hpp" +#include "openvic-simulation/utility/ErrorMacros.hpp" + +using namespace OpenVic; + +bool PacketTypes::retransmit_packet_send_callback( + BaseMultiplayerManager const* multiplayer_manager, PacketType::argument_type argument, PacketBuilder<>& builder +) { + RetransmitData const* const* data_ptr = std::get_if(&argument); + OV_ERR_FAIL_NULL_V(data_ptr, false); + + RetransmitData const& data = **data_ptr; + OV_ERR_FAIL_COND_V(!PacketType::is_valid_type(data.packet_id), false); + + PacketType const& packet_type = PacketType::get_type_by_id(data.packet_id); + OV_ERR_FAIL_COND_V(!packet_type.can_client_send, false); + + builder.put_back(data.client_id); + builder.put_back(data.packet_id); + return packet_type.create_packet(multiplayer_manager, data.argument, builder); +} + +bool PacketTypes::retransmit_packet_process_callback(BaseMultiplayerManager* multiplayer_manager, PacketSpan packet) { + uint64_t client_id = packet.read(); + + size_t start_index = packet.index(); + decltype(PacketType::packet_id) packet_id = packet.read(); + + OV_ERR_FAIL_COND_V(!PacketType::is_valid_type(packet_id), false); + + PacketType const& packet_type = PacketType::get_type_by_id(packet_id); + OV_ERR_FAIL_COND_V(!packet_type.can_client_send, false); + + HostManager* host_manager = multiplayer_manager->cast_as(); + OV_ERR_FAIL_NULL_V(host_manager, false); + + return host_manager->send_data(client_id, packet.subspan(start_index)); +} + +bool PacketTypes::broadcast_packet_send_callback( + BaseMultiplayerManager const* multiplayer_manager, PacketType::argument_type argument, PacketBuilder<>& builder +) { + BroadcastData const* const* data_ptr = std::get_if(&argument); + OV_ERR_FAIL_NULL_V(data_ptr, false); + + BroadcastData const& data = **data_ptr; + OV_ERR_FAIL_COND_V(!PacketType::is_valid_type(data.packet_id), false); + + PacketType const& packet_type = PacketType::get_type_by_id(data.packet_id); + OV_ERR_FAIL_COND_V(!packet_type.can_client_send, false); + + builder.put_back(data.packet_id); + return packet_type.create_packet(multiplayer_manager, data.argument, builder); +} + +bool PacketTypes::broadcast_packet_process_callback(BaseMultiplayerManager* multiplayer_manager, PacketSpan packet) { + size_t start_index = packet.index(); + + decltype(PacketType::packet_id) packet_id = packet.read(); + OV_ERR_FAIL_COND_V(!PacketType::is_valid_type(packet_id), false); + + PacketType const& packet_type = PacketType::get_type_by_id(packet_id); + OV_ERR_FAIL_COND_V(!packet_type.can_client_send, false); + + HostManager* host_manager = multiplayer_manager->cast_as(); + OV_ERR_FAIL_NULL_V(host_manager, false); + + return host_manager->broadcast_data(packet.subspan(start_index)); +} + +bool PacketTypes::send_raw_packet_send_callback( + BaseMultiplayerManager const* multiplayer_manager, PacketType::argument_type argument, PacketBuilder<>& builder +) { + std::span const* data = std::get_if(&argument); + OV_ERR_FAIL_NULL_V(data, false); + + builder.put_back>(*data); + return true; +} + +bool PacketTypes::send_raw_packet_process_callback(BaseMultiplayerManager* multiplayer_manager, PacketSpan packet) { + std::span span = packet.read>(); + multiplayer_manager->last_raw_packet.resize(span.size()); + std::uninitialized_copy_n(span.begin(), span.size(), multiplayer_manager->last_raw_packet.begin()); + return true; +} + +bool PacketTypes::update_host_session_send_callback( + BaseMultiplayerManager const* multiplayer_manager, PacketType::argument_type argument, PacketBuilder<>& builder +) { + HostSession const* const* session_ptr = std::get_if(&argument); + OV_ERR_FAIL_NULL_V(session_ptr, false); + + HostSession const& session = **session_ptr; + builder.put_back(session); + return true; +} + +bool PacketTypes::update_host_session_process_callback(BaseMultiplayerManager* multiplayer_manager, PacketSpan packet) { + ClientManager* client_manager = multiplayer_manager->cast_as(); + OV_ERR_FAIL_NULL_V(client_manager, false); + + client_manager->host_session = packet.readhost_session)>(); + return true; +} + +bool PacketTypes::execute_game_action_send_callback( + BaseMultiplayerManager const* multiplayer_manager, PacketType::argument_type argument, PacketBuilder<>& builder +) { + game_action_t const* const* action_ptr = std::get_if(&argument); + OV_ERR_FAIL_NULL_V(action_ptr, false); + + builder.put_back(**action_ptr); + return true; +} + +bool PacketTypes::execute_game_action_process_callback(BaseMultiplayerManager* multiplayer_manager, PacketSpan packet) { + game_action_t action = packet.read(); + + multiplayer_manager->get_game_manager()->get_instance_manager()->queue_game_action(action.first, std::move(action).second); + return true; +} + +bool PacketTypes::send_chat_message_send_callback( + BaseMultiplayerManager const* multiplayer_manager, PacketType::argument_type argument, PacketBuilder<>& builder +) { + ChatMessageLog const* const* message_ptr = std::get_if(&argument); + OV_ERR_FAIL_NULL_V(message_ptr, false); + + ChatMessageLog const& message = **message_ptr; + builder.put_back(message); + return true; +} + +bool PacketTypes::send_chat_message_process_callback(BaseMultiplayerManager* multiplayer_manager, PacketSpan packet) { + GameManager* game_manager = multiplayer_manager->get_game_manager(); + ClientManager* client_manager = game_manager->get_client_manager(); + + size_t index = packet.index(); + uint16_t from_id = packet.read(); + if (client_manager && client_manager->get_player()->get_client_id() == from_id) { + return true; + } + + packet.seek(index); + game_manager->get_chat_manager()->_log_message(packet.read()); + return true; +} + +bool PacketTypes::add_chat_group_send_callback( + BaseMultiplayerManager const* multiplayer_manager, PacketType::argument_type argument, PacketBuilder<>& builder +) { + std::span const* clients_ptr = std::get_if(&argument); + OV_ERR_FAIL_NULL_V(clients_ptr, false); + + builder.put_back(*clients_ptr); + return true; +} + +bool PacketTypes::add_chat_group_process_callback(BaseMultiplayerManager* multiplayer_manager, PacketSpan packet) { + multiplayer_manager->get_game_manager()->get_chat_manager()->_create_group(packet.read>()); + return true; +} + +bool PacketTypes::modify_chat_group_send_callback( + BaseMultiplayerManager const* multiplayer_manager, PacketType::argument_type argument, PacketBuilder<>& builder +) { + PacketChatGroupModifyData const* modify_data_ptr = std::get_if(&argument); + OV_ERR_FAIL_NULL_V(modify_data_ptr, false); + + builder.put_back(*modify_data_ptr); + return true; +} + +bool PacketTypes::modify_chat_group_process_callback(BaseMultiplayerManager* multiplayer_manager, PacketSpan packet) { + ChatGroup::index_type group_index = packet.read(); + + multiplayer_manager->get_game_manager()->get_chat_manager()->_set_group( + group_index, packet.read>() + ); + return true; +} + +bool PacketTypes::delete_chat_group_send_callback( + BaseMultiplayerManager const* multiplayer_manager, PacketType::argument_type argument, PacketBuilder<>& builder +) { + ChatGroup::index_type const* group_id_ptr = std::get_if(&argument); + OV_ERR_FAIL_NULL_V(group_id_ptr, false); + + builder.put_back(*group_id_ptr); + return true; +} + +bool PacketTypes::delete_chat_group_process_callback(BaseMultiplayerManager* multiplayer_manager, PacketSpan packet) { + multiplayer_manager->get_game_manager()->get_chat_manager()->_delete_group(packet.read()); + return true; +} + +bool PacketTypes::send_battleplan_send_callback( + BaseMultiplayerManager const* multiplayer_manager, PacketType::argument_type argument, PacketBuilder<>& builder +) { + // TODO: needs battleplan implementation + return true; +} + +bool PacketTypes::send_battleplan_process_callback(BaseMultiplayerManager* multiplayer_manager, PacketSpan packet) { + // TODO: need battleplan implementation + return true; +} + +bool PacketTypes::notify_player_left_send_callback( + BaseMultiplayerManager const* multiplayer_manager, PacketType::argument_type argument, PacketBuilder<>& builder +) { + // TODO: needs player left notify implementation + return true; +} + +bool PacketTypes::notify_player_left_process_callback(BaseMultiplayerManager* multiplayer_manager, PacketSpan packet) { + // TODO: needs player left notify implementation + return true; +} + +bool PacketTypes::set_ready_status_send_callback( + BaseMultiplayerManager const* multiplayer_manager, PacketType::argument_type argument, PacketBuilder<>& builder +) { + // TODO: needs player's set ready status implementation + return true; +} + +bool PacketTypes::set_ready_status_process_callback(BaseMultiplayerManager* multiplayer_manager, PacketSpan packet) { + // TODO: needs player's set ready status implementation + return true; +} diff --git a/src/openvic-simulation/multiplayer/PacketType.hpp b/src/openvic-simulation/multiplayer/PacketType.hpp new file mode 100644 index 000000000..806ee71b2 --- /dev/null +++ b/src/openvic-simulation/multiplayer/PacketType.hpp @@ -0,0 +1,107 @@ +#pragma once + +#include +#include +#include +#include + +#include "openvic-simulation/misc/GameAction.hpp" +#include "openvic-simulation/multiplayer/lowlevel/PacketBuilder.hpp" +#include "openvic-simulation/multiplayer/lowlevel/PacketReaderAdapter.hpp" + +namespace OpenVic { + struct BaseMultiplayerManager; + + using PacketChatGroupModifyData = std::pair>; + + // F_TYPE(packet_name, can_client_send, packet_data_type) + // F(packet_name, can_client_send) +#define PACKET_LIST(F, F_TYPE) \ + F_TYPE(retransmit_packet, true, RetransmitData const*) \ + F_TYPE(broadcast_packet, true, BroadcastData const*) \ + F_TYPE(send_raw_packet, true, std::span) \ + F_TYPE(update_host_session, false, HostSession const*) \ + F_TYPE(execute_game_action, false, game_action_t const*) \ + F_TYPE(send_chat_message, true, ChatMessageLog const*) \ + F_TYPE(add_chat_group, true, std::span) \ + F_TYPE(modify_chat_group, true, PacketChatGroupModifyData) \ + F_TYPE(delete_chat_group, true, size_t) \ + F(send_battleplan, true) \ + F_TYPE(notify_player_left, false, uint64_t) \ + F_TYPE(set_ready_status, true, bool) + + struct RetransmitData; + struct BroadcastData; + struct HostSession; + struct ChatMessageLog; + + struct PacketType { +#define DEFINE_TYPE(_, _2, DATA_TYPE) DATA_TYPE, +#define IGNORE(_, _2) std::monostate, + + using argument_type = std::variant; + +#undef DEFINE_TYPE +#undef IGNORE + + using packet_type_send_callback_t = std::add_pointer_t& builder + )>; + using packet_type_process_callback_t = + std::add_pointer_t; + + uint16_t packet_id; + bool can_client_send; + packet_type_send_callback_t create_packet; + packet_type_process_callback_t process_packet; + + static constexpr bool is_valid_type(decltype(packet_id) id); + static constexpr PacketType const& get_type_by_id(decltype(packet_id) id); + + constexpr operator decltype(packet_id)() const { + return packet_id; + } + }; + + namespace PacketTypes { +#define DEFINE_PACKET_TYPE(NAME, CAN_CLIENT_SEND, ...) \ + bool NAME##_send_callback( \ + BaseMultiplayerManager const* multiplayer_manager, PacketType::argument_type argument, PacketBuilder<>& builder \ + ); \ + bool NAME##_process_callback(BaseMultiplayerManager* multiplayer_manager, PacketSpan packet); \ + static constexpr PacketType NAME = { __COUNTER__, CAN_CLIENT_SEND, &NAME##_send_callback, &NAME##_process_callback }; + + PACKET_LIST(DEFINE_PACKET_TYPE, DEFINE_PACKET_TYPE) + +#undef DEFINE_PACKET_TYPE + +#define DEFINE_PACKET_TYPE_IN_ARRAY(NAME, _, ...) &NAME, + + static constexpr std::array _packet_types = std::to_array({ + PACKET_LIST(DEFINE_PACKET_TYPE_IN_ARRAY, DEFINE_PACKET_TYPE_IN_ARRAY) // + }); + +#undef DEFINE_PACKET_TYPE_IN_ARRAY + }; + + constexpr bool PacketType::is_valid_type(decltype(PacketType::packet_id) id) { + return id < PacketTypes::_packet_types.size(); + } + + constexpr PacketType const& PacketType::get_type_by_id(decltype(PacketType::packet_id) id) { + return *PacketTypes::_packet_types[id]; + } + +#undef PACKET_LIST + + struct RetransmitData { + uint64_t client_id; + decltype(PacketType::packet_id) packet_id; + PacketType::argument_type argument; + }; + + struct BroadcastData { + decltype(PacketType::packet_id) packet_id; + PacketType::argument_type argument; + }; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/Constants.hpp b/src/openvic-simulation/multiplayer/lowlevel/Constants.hpp new file mode 100644 index 000000000..bd46735da --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/Constants.hpp @@ -0,0 +1,8 @@ +#pragma once + +#include + +namespace OpenVic { + using socket_port_type = uint16_t; + using reliable_udp_sequence_type = uint16_t; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/HostnameAddress.cpp b/src/openvic-simulation/multiplayer/lowlevel/HostnameAddress.cpp new file mode 100644 index 000000000..27865b474 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/HostnameAddress.cpp @@ -0,0 +1,30 @@ + +#include "HostnameAddress.hpp" + +#include +#include + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkResolver.hpp" + +using namespace OpenVic; + +HostnameAddress::HostnameAddress() = default; + +HostnameAddress::HostnameAddress(IpAddress const& address) : _resolved_address(address) {} + +HostnameAddress::HostnameAddress(std::string_view name_or_address) : HostnameAddress() { + std::from_chars_result result = + _resolved_address.from_chars(name_or_address.data(), name_or_address.data() + name_or_address.size()); + if (result.ec != std::errc {}) { + _resolved_address = NetworkResolver::singleton().resolve_hostname(name_or_address); + } +} + +IpAddress const& HostnameAddress::resolved_address() const { + return _resolved_address; +} + +void HostnameAddress::set_resolved_address(IpAddress const& address) { + _resolved_address = address; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/HostnameAddress.hpp b/src/openvic-simulation/multiplayer/lowlevel/HostnameAddress.hpp new file mode 100644 index 000000000..a36aff7b9 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/HostnameAddress.hpp @@ -0,0 +1,19 @@ +#pragma once + +#include + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" + +namespace OpenVic { + struct HostnameAddress { + HostnameAddress(); + HostnameAddress(IpAddress const& address); + HostnameAddress(std::string_view name_or_address); + + IpAddress const& resolved_address() const; + void set_resolved_address(IpAddress const& address); + + private: + IpAddress _resolved_address; + }; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/IpAddress.cpp b/src/openvic-simulation/multiplayer/lowlevel/IpAddress.cpp new file mode 100644 index 000000000..92c4ac31b --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/IpAddress.cpp @@ -0,0 +1,18 @@ +#include "IpAddress.hpp" + +#include "openvic-simulation/utility/Containers.hpp" + +using namespace OpenVic; + +memory::string IpAddress::to_string(bool prefer_ipv4, to_chars_option option) const { + stack_string result = to_array(prefer_ipv4, option); + if (OV_unlikely(result.empty())) { + return {}; + } + + return result; +} + +IpAddress::operator memory::string() const { + return to_string(); +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/IpAddress.hpp b/src/openvic-simulation/multiplayer/lowlevel/IpAddress.hpp new file mode 100644 index 000000000..c4bfd4e3e --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/IpAddress.hpp @@ -0,0 +1,415 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "openvic-simulation/types/StackString.hpp" +#include "openvic-simulation/utility/Containers.hpp" +#include "openvic-simulation/utility/ErrorMacros.hpp" +#include "openvic-simulation/utility/StringUtils.hpp" + +namespace OpenVic { + struct IpAddress { + constexpr IpAddress() : _field {} { + clear(); + } + + constexpr IpAddress(char const* cstring) : IpAddress(std::string_view { cstring }) {} + + constexpr IpAddress(std::string_view string) : _field {} { + clear(); + from_chars(string.data(), string.data() + string.size()); + } + + constexpr IpAddress(uint32_t a, uint32_t b, uint32_t c, uint32_t d, bool is_ipv6 = false) : _field {} { + clear(); + _valid = true; + if (is_ipv6) { + auto _set_buffer = [](uint8_t* buffer, uint32_t value) constexpr { + buffer[0] = (value >> 24) & 0xff; + buffer[1] = (value >> 16) & 0xff; + buffer[2] = (value >> 8) & 0xff; + buffer[3] = (value >> 0) & 0xff; + }; + _set_buffer(&_field._8bit[0], a); + _set_buffer(&_field._8bit[4], b); + _set_buffer(&_field._8bit[8], c); + _set_buffer(&_field._8bit[12], d); + return; + } + if (std::is_constant_evaluated()) { + _field._8bit[10] = 0xff; + _field._8bit[11] = 0xff; + } else { + _field._16bit[5] = 0xffff; + } + _field._8bit[12] = a; + _field._8bit[13] = b; + _field._8bit[14] = c; + _field._8bit[15] = d; + } + + constexpr void clear() { + if (std::is_constant_evaluated()) { + std::fill_n(_field._8bit.data(), sizeof(_field._8bit), 0); + } else { + std::memset(_field._8bit.data(), 0, sizeof(_field._8bit)); + } + _wildcard = false; + _valid = false; + } + + constexpr bool is_wildcard() const { + return _wildcard; + } + + constexpr bool is_valid() const { + return _valid; + } + + constexpr bool is_ipv4() const { + if (std::is_constant_evaluated()) { + std::array casted_fields = std::bit_cast>(_field._8bit); + return casted_fields[0] == 0 && casted_fields[1] == 0 && casted_fields[2] == 0 && casted_fields[3] == 0 && + casted_fields[4] == 0 && casted_fields[5] == 0xffff; + } else { + return (_field._32bit[0] == 0 && _field._32bit[1] == 0 && _field._16bit[4] == 0 && _field._16bit[5] == 0xffff); + } + } + + constexpr std::span get_ipv4() const { + if (std::is_constant_evaluated()) { + return std::span { _field._8bit }.subspan<12, 4>(); + } else { + OV_ERR_FAIL_COND_V_MSG( + !is_ipv4(), (std::span { _field._8bit }.subspan<12, 4>()), + "IPv4 requested, but current IP is IPv6." + ); + } + return std::span { _field._8bit }.subspan<12, 4>(); + } + + constexpr void set_ipv4(std::span ip) { + clear(); + _valid = true; + if (std::is_constant_evaluated()) { + _field._8bit[10] = 0xff; + _field._8bit[11] = 0xff; + std::copy_n(ip.data(), 4, &_field._8bit[12]); + } else { + _field._16bit[5] = 0xffff; + _field._32bit[3] = *std::bit_cast(ip.data()); + } + } + + constexpr void set_ipv4(std::initializer_list list) { + set_ipv4(std::span { list }); + } + + constexpr std::span get_ipv6() const { + return _field._8bit; + } + + constexpr void set_ipv6(std::span ip) { + clear(); + _valid = true; + for (size_t i = 0; i < _field._8bit.size(); i++) { + _field._8bit[i] = ip[i]; + } + } + + constexpr void set_ipv6(std::initializer_list list) { + set_ipv6(std::span { list }); + } + + constexpr bool operator==(IpAddress const& ip) const { + if (ip._valid != _valid) { + return false; + } + if (!_valid) { + return false; + } + if (std::is_constant_evaluated()) { + for (int i = 0; i < _field._8bit.size(); i++) { + if (_field._8bit[i] != ip._field._8bit[i]) { + return false; + } + } + } else { + for (int i = 0; i < _field._32bit.size(); i++) { + if (_field._32bit[i] != ip._field._32bit[i]) { + return false; + } + } + } + return true; + } + + enum class to_chars_option : uint8_t { // + SHORTEN_IPV6, + COMPRESS_IPV6, + EXPAND_IPV6, + }; + + inline constexpr std::to_chars_result to_chars( // + char* first, char* last, bool prefer_ipv4 = true, to_chars_option option = to_chars_option::SHORTEN_IPV6 + ) const { + if (first == nullptr || first >= last) { + return { last, std::errc::value_too_large }; + } + + if (_wildcard) { + *first = '*'; + ++first; + return { last, std::errc {} }; + } + + if (!_valid) { + return { last, std::errc {} }; + } + + std::to_chars_result result = { first, std::errc {} }; + if (is_ipv4() && prefer_ipv4) { + for (size_t i = 12; i < 16; i++) { + if (i > 12) { + if (last - result.ptr <= 1) { + return { last, std::errc::value_too_large }; + } + + *result.ptr = '.'; + ++result.ptr; + } + + result = StringUtils::to_chars(result.ptr, last, _field._8bit[i]); + if (result.ec != std::errc {}) { + return result; + } + } + return result; + } + + auto section_func = [this](size_t index) constexpr -> uint16_t { + return (_field._8bit[index * 2] << 8) + _field._8bit[index * 2 + 1]; + }; + + int32_t compress_pos = -1; + if (option == to_chars_option::COMPRESS_IPV6) { + int32_t last_compress_count = 0; + for (size_t i = 0; i < 8; i++) { + if (_field._8bit[i * 2] == 0 && _field._8bit[i * 2 + 1] == 0 && section_func(i + 1) == 0) { + int32_t compress_check = i; + do { + ++i; + } while (i < 8 && section_func(i) == 0); + + if (int32_t compress_count = i - compress_check; + compress_count >= 2 && last_compress_count < compress_count) { + compress_pos = compress_check; + last_compress_count = compress_count; + } + } + } + } + + for (size_t i = 0; i < 8; i++) { + if (compress_pos > -1 && compress_pos == i) { + if (last - result.ptr <= 2) { + return { last, std::errc::value_too_large }; + } + + result.ptr[0] = ':'; + result.ptr[1] = ':'; + result.ptr += 2; + do { + ++i; + } while (i < 8 && section_func(i) == 0); + } else if (i > 0) { + *result.ptr = ':'; + ++result.ptr; + } + + uint16_t section = section_func(i); + if (option == to_chars_option::EXPAND_IPV6) { + if (last - result.ptr < 4) { + return { last, std::errc::value_too_large }; + } + + if (section < 0xFFF) { + *result.ptr = '0'; + ++result.ptr; + } + if (section < 0xFF) { + *result.ptr = '0'; + ++result.ptr; + } + if (section < 0xF) { + *result.ptr = '0'; + ++result.ptr; + } + if (section == 0) { + *result.ptr = '0'; + ++result.ptr; + continue; + } + } + result = StringUtils::to_chars(result.ptr, last, section, 16); + if (result.ec != std::errc {}) { + return result; + } + } + return result; + } + + struct stack_string; + inline constexpr stack_string to_array( // + bool prefer_ipv4 = true, to_chars_option option = to_chars_option::SHORTEN_IPV6 + ) const; + + struct stack_string final : StackString<39> { + protected: + using StackString::StackString; + friend inline constexpr stack_string IpAddress::to_array(bool prefer_ipv4, to_chars_option option) const; + }; + + memory::string to_string(bool prefer_ipv4 = true, to_chars_option option = to_chars_option::SHORTEN_IPV6) const; + explicit operator memory::string() const; + + constexpr std::from_chars_result from_chars(char const* begin, char const* end) { + if (begin == nullptr || begin >= end) { + return { begin, std::errc::invalid_argument }; + } + + if (*begin == '*') { + _wildcard = true; + _valid = false; + return { begin + 1, std::errc {} }; + } + + size_t check_for_ipv4 = 0; + for (char const* check = begin; check < end; check++) { + if (*check == '.') { + check_for_ipv4++; + } + } + + fields_type fields { ._8bit {} }; + + std::from_chars_result result = { begin, std::errc {} }; + if (check_for_ipv4 == 3) { + if (std::is_constant_evaluated()) { + fields._8bit[10] = 0xff; + fields._8bit[11] = 0xff; + } else { + fields._16bit[5] = 0xffff; + } + for (size_t i = 0; i < 4; i++) { + if (*result.ptr == '.') { + if (i > 0 && end - result.ptr >= 1) { + ++result.ptr; + } else { + return { begin, std::errc::invalid_argument }; + } + } + + result = StringUtils::from_chars(result.ptr, end, fields._8bit[12 + i]); + if (result.ec != std::errc {}) { + return result; + } + } + _field._8bit.swap(fields._8bit); + _valid = true; + return result; + } else if (check_for_ipv4 > 0) { + return { begin, std::errc::invalid_argument }; + } + + int32_t compress_start = -1; + size_t colon_count = 0; + if (std::is_constant_evaluated()) { + fields._16bit = std::bit_cast>(fields._8bit); + } + for (size_t i = 0; i < fields._16bit.size(); i++) { + if (*result.ptr == ':') { + if (end - result.ptr <= 2) { + return { result.ptr, std::errc::invalid_argument }; + } + + ++result.ptr; + if (*result.ptr == ':') { + if (compress_start > -1 || i >= fields._16bit.size() - 2) { + return { result.ptr, std::errc::invalid_argument }; + } + + ++result.ptr; + compress_start = i; + } else { + ++colon_count; + } + } + + uint16_t big_endian_value; + result = StringUtils::from_chars(result.ptr, end, big_endian_value, 16); + fields._16bit[i] = (big_endian_value >> 8) | (big_endian_value << 8); + if (result.ec != std::errc {}) { + return result; + } + } + + if (compress_start > -1) { + constexpr size_t expected_ipv6_colons = 7; + size_t index = compress_start - 1; + // Pre-compression + std::swap_ranges(_field._16bit.data(), _field._16bit.data() + index, fields._16bit.data()); + size_t compress_end = expected_ipv6_colons - colon_count - 1; + // Compression + if (std::is_constant_evaluated()) { + std::fill_n(_field._16bit.data() + compress_start, compress_end - 1, 0); + } else { + std::memset(_field._16bit.data() + compress_start, 0, compress_end - 1); + } + // Post-compression + std::swap_ranges( + _field._16bit.data() + compress_end, fields._16bit.data() + fields._16bit.size(), + fields._16bit.data() + index + ); + } else { + if (std::is_constant_evaluated()) { + _field._16bit = std::bit_cast>(_field._8bit); + } + _field._16bit.swap(fields._16bit); + } + + if (std::is_constant_evaluated()) { + _field._8bit = std::bit_cast>(_field._16bit); + } + + _valid = true; + return result; + } + + private: + union fields_type { + std::array _8bit; + std::array _16bit; + std::array _32bit; + } _field; + + bool _wildcard = false; + bool _valid = false; + }; + + inline constexpr IpAddress::stack_string IpAddress::to_array(bool pref_ipv4, to_chars_option option) const { + stack_string str {}; + std::to_chars_result result = to_chars(str._array.data(), str._array.data() + str._array.size(), pref_ipv4, option); + str._string_size = result.ptr - str.data(); + return str; + } +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/NetworkError.hpp b/src/openvic-simulation/multiplayer/lowlevel/NetworkError.hpp new file mode 100644 index 000000000..03fdf4b59 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/NetworkError.hpp @@ -0,0 +1,32 @@ +#pragma once + +#include + +namespace OpenVic { + enum class NetworkError : uint8_t { + OK, + FAILED, + UNAVAILABLE, + UNCONFIGURED, + UNAUTHORIZED, + OUT_OF_MEMORY, + LOCKED, + EMPTY_BUFFER, + BUG, + + ALREADY_OPEN, + NOT_OPEN, + INVALID_PARAMETER, + SOCKET_ERROR, + BROADCAST_CHANGE_FAILED, + UNSUPPORTED, + BUSY, + + WOULD_BLOCK, + IS_CONNECTED, + IN_PROGRESS, + ADDRESS_INVALID_OR_UNAVAILABLE, + BUFFER_TOO_SMALL, + OTHER, + }; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/NetworkResolver.hpp b/src/openvic-simulation/multiplayer/lowlevel/NetworkResolver.hpp new file mode 100644 index 000000000..493285268 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/NetworkResolver.hpp @@ -0,0 +1,19 @@ +#pragma once + +#ifdef _WIN32 +#include "openvic-simulation/multiplayer/lowlevel/windows/WindowsNetworkResolver.hpp" +#elif defined(__unix__) || defined(__unix) || (defined(__APPLE__) && defined(__MACH__)) +#include "openvic-simulation/multiplayer/lowlevel/unix/UnixNetworkResolver.hpp" +#else +#error "NetworkResolver.hpp only supports unix or windows systems" +#endif + +namespace OpenVic { + using NetworkResolver = +#ifdef _WIN32 + WindowsNetworkResolver +#else + UnixNetworkResolver +#endif + ; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/NetworkResolverBase.cpp b/src/openvic-simulation/multiplayer/lowlevel/NetworkResolverBase.cpp new file mode 100644 index 000000000..2818af7e9 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/NetworkResolverBase.cpp @@ -0,0 +1,277 @@ +#include "NetworkResolverBase.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkResolver.hpp" +#include "openvic-simulation/utility/Containers.hpp" +#include "openvic-simulation/utility/StringUtils.hpp" + +using namespace OpenVic; + +static memory::string _get_cache_key(std::string_view p_hostname, NetworkResolverBase::Type p_type) { + memory::string key { p_hostname }; + key.insert(key.begin(), '0'); + key.reserve(p_hostname.size() + 1); + + std::to_chars_result result = StringUtils::to_chars(key.data(), key.data() + 1, static_cast(p_type)); + if (result.ec != std::errc {}) { + key = memory::string { p_hostname }; + } + + return key; +} + +NetworkResolverBase::ResolveHandler::Item::Item() : type { Type::NONE } { + clear(); +} + +void NetworkResolverBase::ResolveHandler::Item::clear() { + status.store(static_cast(ResolverStatus::NONE)); + response.clear(); + type = Type::NONE; + hostname.clear(); +} + +int32_t NetworkResolverBase::ResolveHandler::find_empty_id() const { + for (int i = 0; i < MAX_QUERIES; i++) { + if (queue[i].status.load() == static_cast(ResolverStatus::NONE)) { + return i; + } + } + return INVALID_ID; +} + +void NetworkResolverBase::ResolveHandler::resolve_queues() { + for (int i = 0; i < MAX_QUERIES; i++) { + if (queue[i].status.load() != static_cast(ResolverStatus::WAITING)) { + continue; + } + + std::unique_lock lock { mutex }; + memory::string hostname = queue[i].hostname; + Type type = queue[i].type; + + lock.unlock(); + + // We should not lock while resolving the hostname, + // only when modifying the queue. + memory::vector response = NetworkResolver::singleton()._resolve_hostname(hostname, type); + + lock.lock(); + // Could have been completed by another function, or deleted. + if (queue[i].status.load() != static_cast(ResolverStatus::WAITING)) { + continue; + } + // We might be overriding another result, but we don't care as long as the result is valid. + if (response.size()) { + memory::string key = _get_cache_key(hostname, type); + cache[key] = response; + } + queue[i].response = response; + queue[i].status.store(static_cast(response.empty() ? ResolverStatus::ERROR : ResolverStatus::DONE)); + } +} + +void NetworkResolverBase::ResolveHandler::_thread_runner(ResolveHandler& handler) { + while (!handler.should_abort.load(std::memory_order_acquire)) { + handler.semaphore.acquire(); + handler.resolve_queues(); + } +} + +NetworkResolverBase::ResolveHandler::ResolveHandler() : thread { _thread_runner, std::ref(*this) }, semaphore { 0 } { + should_abort.store(false, std::memory_order_release); +} + +NetworkResolverBase::ResolveHandler::~ResolveHandler() { + should_abort.store(true, std::memory_order_release); + semaphore.release(); + thread.join(); +} + +IpAddress NetworkResolverBase::resolve_hostname(std::string_view p_hostname, Type p_type) { + const memory::vector addresses = resolve_hostname_addresses(p_hostname, p_type); + return addresses.empty() ? IpAddress {} : addresses.front(); +} + +memory::vector NetworkResolverBase::resolve_hostname_addresses(std::string_view p_hostname, Type p_type) { + memory::string key = _get_cache_key(p_hostname, p_type); + { + std::unique_lock lock(_resolve_handler.mutex); + if (_resolve_handler.cache.contains(key)) { + return _resolve_handler.cache[key]; + } else { + // This should be run unlocked so the resolver thread can keep resolving + // other requests. + lock.unlock(); + memory::vector result = _resolve_hostname(p_hostname, p_type); + lock.lock(); + // We might be overriding another result, but we don't care as long as the result is valid. + if (result.size()) { + _resolve_handler.cache[key] = result; + } + return result; + } + } +} + +int32_t NetworkResolverBase::resolve_hostname_queue_item(std::string_view p_hostname, Type p_type) { + std::unique_lock lock(_resolve_handler.mutex); + + int32_t id = _resolve_handler.find_empty_id(); + + if (id == INVALID_ID) { + Logger::warning("Out of resolver queries"); + return id; + } + + memory::string key = _get_cache_key(p_hostname, p_type); + _resolve_handler.queue[id].hostname = p_hostname; + _resolve_handler.queue[id].type = p_type; + if (_resolve_handler.cache.contains(key)) { + _resolve_handler.queue[id].response = _resolve_handler.cache[key]; + _resolve_handler.queue[id].status.store(static_cast(ResolverStatus::DONE)); + } else { + _resolve_handler.queue[id].response = memory::vector(); + _resolve_handler.queue[id].status.store(static_cast(ResolverStatus::WAITING)); + if (_resolve_handler.thread.joinable()) { + _resolve_handler.semaphore.release(); + } else { + _resolve_handler.resolve_queues(); + } + } + + return id; +} + +NetworkResolverBase::ResolverStatus NetworkResolverBase::get_resolve_item_status(int32_t p_id) const { + OV_ERR_FAIL_INDEX_V_MSG( + p_id, MAX_QUERIES, ResolverStatus::NONE, + fmt::format( + "Too many concurrent DNS resolver queries ({}, but should be {} at most). Try performing less network requests at " + "once.", + p_id, MAX_QUERIES + ) + ); + + ResolverStatus result = static_cast(_resolve_handler.queue[p_id].status.load()); + if (result == ResolverStatus::NONE) { + Logger::error("Condition status == " _OV_STR(ResolverStatus::NONE)); + return ResolverStatus::NONE; + } + return result; +} + +memory::vector NetworkResolverBase::get_resolve_item_addresses(int32_t p_id) const { + OV_ERR_FAIL_INDEX_V_MSG( + p_id, MAX_QUERIES, memory::vector {}, + fmt::format( + "Too many concurrent DNS resolver queries ({}, but should be {} at most). Try performing less network requests at " + "once.", + p_id, MAX_QUERIES + ) + ); + + std::unique_lock lock(const_cast(_resolve_handler.mutex)); + + if (_resolve_handler.queue[p_id].status.load() != static_cast(ResolverStatus::DONE)) { + Logger::error(fmt::format("Resolve of '{}' didn't complete yet.", _resolve_handler.queue[p_id].hostname)); + return {}; + } + + return _resolve_handler.queue[p_id].response | ranges::views::filter([](IpAddress const& addr) -> bool { + return addr.is_valid(); + }) | + ranges::to>(); +} + +IpAddress NetworkResolverBase::get_resolve_item_address(int32_t p_id) const { + OV_ERR_FAIL_INDEX_V_MSG( + p_id, MAX_QUERIES, IpAddress {}, + fmt::format( + "Too many concurrent DNS resolver queries ({}, but should be {} at most). Try performing less network requests at " + "once.", + p_id, MAX_QUERIES + ) + ); + + std::unique_lock lock(const_cast(_resolve_handler.mutex)); + + if (_resolve_handler.queue[p_id].status.load() != static_cast(ResolverStatus::DONE)) { + Logger::error(fmt::format("Resolve of '{}' didn't complete yet.", _resolve_handler.queue[p_id].hostname)); + return {}; + } + + for (IpAddress const& address : _resolve_handler.queue[p_id].response) { + if (address.is_valid()) { + return address; + } + } + return IpAddress {}; +} + +void NetworkResolverBase::erase_resolve_item(int32_t p_id) { + OV_ERR_FAIL_INDEX_MSG( + p_id, MAX_QUERIES, + fmt::format( + "Too many concurrent DNS resolver queries ({}, but should be {} at most). Try performing less network requests at " + "once.", + p_id, MAX_QUERIES + ) + ); + + _resolve_handler.queue[p_id].status.store(static_cast(ResolverStatus::NONE)); +} +void NetworkResolverBase::clear_cache(std::string_view p_hostname) { + std::unique_lock lock(_resolve_handler.mutex); + + if (p_hostname.empty()) { + _resolve_handler.cache.clear(); + } else { + _resolve_handler.cache.erase(_get_cache_key(p_hostname, Type::NONE)); + _resolve_handler.cache.erase(_get_cache_key(p_hostname, Type::IPV4)); + _resolve_handler.cache.erase(_get_cache_key(p_hostname, Type::IPV6)); + _resolve_handler.cache.erase(_get_cache_key(p_hostname, Type::ANY)); + } +} + +memory::vector NetworkResolverBase::get_local_addresses() const { + OpenVic::string_map_t interfaces = get_local_interfaces(); + if (interfaces.empty()) { + return {}; + } + + memory::vector result; + size_t reserve_size = 0; + + for (std::pair const& pair : interfaces | ranges::views::reverse) { + reserve_size += pair.second.ip_addresses.size(); + } + if (reserve_size == 0) { + return {}; + } + + result.reserve(reserve_size); + for (std::pair const& pair : interfaces | ranges::views::reverse) { + for (IpAddress const& address : pair.second.ip_addresses | ranges::views::reverse) { + result.push_back(address); + } + } + return result; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/NetworkResolverBase.hpp b/src/openvic-simulation/multiplayer/lowlevel/NetworkResolverBase.hpp new file mode 100644 index 000000000..bde116518 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/NetworkResolverBase.hpp @@ -0,0 +1,92 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/types/OrderedContainers.hpp" +#include "openvic-simulation/utility/Containers.hpp" + +namespace OpenVic { + struct NetworkResolverBase { + static constexpr size_t MAX_QUERIES = 256; + static constexpr int32_t INVALID_ID = -1; + + enum class Provider : uint8_t { + UNIX, + WINDOWS, + }; + + enum class ResolverStatus : uint8_t { + NONE, + WAITING, + DONE, + ERROR, + }; + + enum class Type : uint8_t { + NONE, + IPV4, + IPV6, + ANY, + }; + + struct InterfaceInfo { + memory::string name; + memory::string name_friendly; + memory::string index; + memory::vector ip_addresses; + }; + + IpAddress resolve_hostname(std::string_view p_hostname, Type p_type = Type::ANY); + memory::vector resolve_hostname_addresses(std::string_view p_hostname, Type p_type = Type::ANY); + // async resolver hostname + int32_t resolve_hostname_queue_item(std::string_view p_hostname, Type p_type = Type::ANY); + ResolverStatus get_resolve_item_status(int32_t p_id) const; + memory::vector get_resolve_item_addresses(int32_t p_id) const; + IpAddress get_resolve_item_address(int32_t p_id) const; + + void erase_resolve_item(int32_t p_id); + + void clear_cache(std::string_view p_hostname = ""); + + memory::vector get_local_addresses() const; + virtual OpenVic::string_map_t get_local_interfaces() const = 0; + + virtual Provider provider() const = 0; + + protected: + virtual memory::vector _resolve_hostname(std::string_view p_hostname, Type p_type = Type::ANY) const = 0; + + struct ResolveHandler { + int32_t find_empty_id() const; + void resolve_queues(); + + struct Item { + Item(); + void clear(); + + memory::vector response; + memory::string hostname; + std::atomic_int32_t status; + Type type; + }; + + std::array queue; + string_map_t> cache; + std::mutex mutex; + std::thread thread; + std::counting_semaphore<> semaphore; + std::atomic_bool should_abort; + + static void _thread_runner(ResolveHandler& handler); + + ResolveHandler(); + ~ResolveHandler(); + } _resolve_handler; + }; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/NetworkSocket.hpp b/src/openvic-simulation/multiplayer/lowlevel/NetworkSocket.hpp new file mode 100644 index 000000000..5122a53e7 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/NetworkSocket.hpp @@ -0,0 +1,19 @@ +#pragma once + +#ifdef _WIN32 +#include "openvic-simulation/multiplayer/lowlevel/windows/WindowsSocket.hpp" +#elif defined(__unix__) || defined(__unix) || (defined(__APPLE__) && defined(__MACH__)) +#include "openvic-simulation/multiplayer/lowlevel/unix/UnixSocket.hpp" +#else +#error "NetworkSocket.hpp only supports unix or windows systems" +#endif + +namespace OpenVic { + using NetworkSocket = +#ifdef _WIN32 + WindowsSocket +#else + UnixSocket +#endif + ; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/NetworkSocketBase.hpp b/src/openvic-simulation/multiplayer/lowlevel/NetworkSocketBase.hpp new file mode 100644 index 000000000..2fad549d1 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/NetworkSocketBase.hpp @@ -0,0 +1,76 @@ +#pragma once + +#include +#include +#include + +#include "openvic-simulation/multiplayer/lowlevel/Constants.hpp" +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkResolver.hpp" +#include "openvic-simulation/types/EnumBitfield.hpp" +#include "openvic-simulation/utility/Containers.hpp" + +namespace OpenVic { + struct NetworkSocketBase { + enum class PollType : uint8_t { + IN = 1 << 0, + OUT = 1 << 1, + IN_OUT = IN | OUT, + }; + + enum class Type : uint8_t { + NONE, + TCP, + UDP, + }; + + using port_type = socket_port_type; + + virtual ~NetworkSocketBase() = default; + + virtual NetworkError open(Type p_type, NetworkResolver::Type& ip_type) = 0; + virtual void close() = 0; + virtual NetworkError bind(IpAddress const& p_addr, port_type p_port) = 0; + virtual NetworkError listen(int p_max_pending) = 0; + virtual NetworkError connect_to_host(IpAddress const& p_addr, port_type p_port) = 0; + virtual NetworkError poll(PollType p_type, int timeout) const = 0; + virtual NetworkError receive(uint8_t* p_buffer, size_t p_len, int64_t& r_read) = 0; + virtual NetworkError receive_from( // + uint8_t* p_buffer, size_t p_len, int64_t& r_read, IpAddress& r_ip, port_type& r_port, bool p_peek = false + ) = 0; + virtual NetworkError send(const uint8_t* p_buffer, size_t p_len, int64_t& r_sent) = 0; + virtual NetworkError send_to( // + const uint8_t* p_buffer, size_t p_len, int64_t& r_sent, IpAddress const& p_ip, port_type p_port + ) = 0; + + memory::unique_base_ptr accept(IpAddress& r_ip, port_type& r_port) { + return memory::unique_base_ptr { static_cast(_accept(r_ip, r_port)) }; + } + + template T> + memory::unique_ptr accept_as(IpAddress& r_ip, port_type& r_port) { + return memory::unique_ptr { static_cast(_accept(r_ip, r_port)) }; + } + + virtual bool is_open() const = 0; + virtual int available_bytes() const = 0; + virtual NetworkError get_socket_address(IpAddress* r_ip, uint16_t* r_port) const = 0; + + // Returns OK if the socket option has been set successfully + virtual NetworkError set_broadcasting_enabled(bool p_enabled) = 0; + virtual void set_blocking_enabled(bool p_enabled) = 0; + virtual void set_ipv6_only_enabled(bool p_enabled) = 0; + virtual void set_tcp_no_delay_enabled(bool p_enabled) = 0; + virtual void set_reuse_address_enabled(bool p_enabled) = 0; + virtual NetworkError change_multicast_group(IpAddress const& p_multi_address, std::string_view p_if_name, bool add) = 0; + + virtual NetworkResolver::Provider provider() const = 0; + + protected: + virtual NetworkSocketBase* _accept(IpAddress& r_ip, port_type& r_port) = 0; + }; + + template<> + struct enable_bitfield : std::true_type {}; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/PacketBuilder.cpp b/src/openvic-simulation/multiplayer/lowlevel/PacketBuilder.cpp new file mode 100644 index 000000000..18dab291d --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/PacketBuilder.cpp @@ -0,0 +1,3 @@ +#include "PacketBuilder.hpp" + +template struct OpenVic::PacketBuilder<>; \ No newline at end of file diff --git a/src/openvic-simulation/multiplayer/lowlevel/PacketBuilder.hpp b/src/openvic-simulation/multiplayer/lowlevel/PacketBuilder.hpp new file mode 100644 index 000000000..55b185601 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/PacketBuilder.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include +#include +#include +#include + +#include "openvic-simulation/utility/Containers.hpp" +#include "openvic-simulation/utility/Marshal.hpp" + +namespace OpenVic { + template + struct PacketBuilder : memory::vector { + using base_type = memory::vector; + using base_type::base_type; + + template + requires requires(T const& value) { utility::encode(value, std::span {}); } + void put_back(T const& value) { + size_t value_size = utility::encode(value); + size_t last = size(); + resize(size() + value_size); + utility::encode(value, std::span { data() + last, value_size }); + } + }; + + extern template struct PacketBuilder<>; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/PacketClient.hpp b/src/openvic-simulation/multiplayer/lowlevel/PacketClient.hpp new file mode 100644 index 000000000..7c0a36772 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/PacketClient.hpp @@ -0,0 +1,46 @@ +#pragma once + +#include +#include + +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/PacketReaderAdapter.hpp" + +namespace OpenVic { + struct PacketClient { + virtual int64_t available_packets() const = 0; + virtual int64_t max_packet_size() const = 0; + + std::span get_packet() { + std::span buffer; + if (NetworkError error = _get_packet(buffer); error != NetworkError::OK) { + _last_error = error; + return {}; + } + return buffer; + } + + NetworkError set_packet(std::span buffer) { + NetworkError error = _set_packet(buffer); + if (error != NetworkError::OK) { + _last_error = error; + } + return error; + } + + NetworkError get_last_error() const { + return _last_error; + } + + PacketSpan packet_span() { + std::span span = get_packet(); + return { span.data(), span.size() }; + } + + protected: + virtual NetworkError _set_packet(std::span buffer) = 0; + virtual NetworkError _get_packet(std::span& r_set) = 0; + + mutable NetworkError _last_error = NetworkError::OK; + }; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/PacketReaderAdapter.cpp b/src/openvic-simulation/multiplayer/lowlevel/PacketReaderAdapter.cpp new file mode 100644 index 000000000..cd9d853de --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/PacketReaderAdapter.cpp @@ -0,0 +1,6 @@ +#include "PacketReaderAdapter.hpp" + +#include "openvic-simulation/utility/Containers.hpp" + +template struct OpenVic::PacketReaderAdapter>; +template struct OpenVic::PacketReaderAdapter>; diff --git a/src/openvic-simulation/multiplayer/lowlevel/PacketReaderAdapter.hpp b/src/openvic-simulation/multiplayer/lowlevel/PacketReaderAdapter.hpp new file mode 100644 index 000000000..98e196666 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/PacketReaderAdapter.hpp @@ -0,0 +1,70 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "openvic-simulation/utility/Containers.hpp" +#include "openvic-simulation/utility/ErrorMacros.hpp" +#include "openvic-simulation/utility/Marshal.hpp" + +namespace OpenVic { + template + requires requires(Container& container) { + { container.data() } -> std::same_as; + { container.size() } -> std::same_as; + } + struct PacketReaderAdapter : Container { + using base_type = Container; + using base_type::base_type; + + static constexpr std::integral_constant::max()> npos {}; + + size_t index() const { + return _read_index; + } + + void seek(size_t index) { + if (index == npos) { + _read_index = npos; + } + OV_ERR_FAIL_INDEX(index, this->size()); + _read_index = index; + } + + template + requires requires(PacketReaderAdapter* self, size_t& size) { + { + utility::decode( + std::span { self->data() + self->_read_index, self->size() - self->_read_index }, size + ) + } -> std::same_as>; + } + T read() { + size_t decode_size = 0; + T result = utility::decode( + std::span { this->data() + _read_index, this->size() - _read_index }, decode_size + ); + OV_ERR_FAIL_COND_V(decode_size > this->size(), {}); + _read_index += decode_size; + return result; + } + + private: + size_t _read_index = 0; + }; + + template + using PacketBufferEndian = PacketReaderAdapter, Endian>; + template + using PacketSpanEndian = PacketReaderAdapter, Endian>; + + using PacketBuffer = PacketBufferEndian<>; + using PacketSpan = PacketSpanEndian<>; + + extern template struct PacketReaderAdapter>; + extern template struct PacketReaderAdapter>; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/PacketServer.cpp b/src/openvic-simulation/multiplayer/lowlevel/PacketServer.cpp new file mode 100644 index 000000000..a542b0c38 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/PacketServer.cpp @@ -0,0 +1,76 @@ +#include "PacketServer.hpp" + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkResolver.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkSocket.hpp" +#include "openvic-simulation/multiplayer/lowlevel/PacketClient.hpp" +#include "openvic-simulation/utility/Containers.hpp" +#include "openvic-simulation/utility/ErrorMacros.hpp" + +using namespace OpenVic; + +PacketServer::PacketServer() : _socket { memory::make_shared() } {} + +PacketServer::~PacketServer() { + close(); +} + +NetworkSocket::port_type PacketServer::get_listening_port() const { + NetworkSocket::port_type local_port; + _socket->get_socket_address(nullptr, &local_port); + return local_port; +} + +bool PacketServer::is_connection_available() const { + OV_ERR_FAIL_COND_V(_socket == nullptr, false); + return _socket->is_open(); +} + +bool PacketServer::is_listening() const { + OV_ERR_FAIL_COND_V(_socket == nullptr, false); + return _socket->is_open(); +} + +NetworkError PacketServer::listen_to(NetworkSocket::port_type port, IpAddress const& bind_addr) { + OV_ERR_FAIL_COND_V(_socket == nullptr, NetworkError::UNAVAILABLE); + OV_ERR_FAIL_COND_V(_socket->is_open(), NetworkError::ALREADY_OPEN); + OV_ERR_FAIL_COND_V(!bind_addr.is_valid() && !bind_addr.is_wildcard(), NetworkError::INVALID_PARAMETER); + + NetworkError err; + NetworkResolver::Type ip_type = NetworkResolver::Type::ANY; + + if (bind_addr.is_valid()) { + ip_type = bind_addr.is_ipv4() ? NetworkResolver::Type::IPV4 : NetworkResolver::Type::IPV6; + } + + err = _socket->open(NetworkSocket::Type::UDP, ip_type); + + if (err != NetworkError::OK) { + return err; + } + + _socket->set_blocking_enabled(false); + _socket->set_reuse_address_enabled(true); + err = _socket->bind(bind_addr, port); + + if (err != NetworkError::OK) { + close(); + return err; + } + return NetworkError::OK; +} + +void PacketServer::close() { + if (_socket != nullptr) { + _socket->close(); + } +} + +NetworkError PacketServer::poll() { + return NetworkError::OK; +} + +memory::unique_base_ptr PacketServer::take_next_client() { + return memory::unique_base_ptr { static_cast(_take_next_client()) }; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/PacketServer.hpp b/src/openvic-simulation/multiplayer/lowlevel/PacketServer.hpp new file mode 100644 index 000000000..614578161 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/PacketServer.hpp @@ -0,0 +1,36 @@ +#pragma once + +#include +#include + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkSocket.hpp" +#include "openvic-simulation/multiplayer/lowlevel/PacketClient.hpp" +#include "openvic-simulation/utility/Containers.hpp" + +namespace OpenVic { + struct PacketServer { + PacketServer(); + ~PacketServer(); + + NetworkSocket::port_type get_listening_port() const; + bool is_listening() const; + virtual bool is_connection_available() const; + virtual NetworkError listen_to(NetworkSocket::port_type port, IpAddress const& bind_addr = "*"); + virtual void close(); + virtual NetworkError poll(); + + memory::unique_base_ptr take_next_client(); + + template T> + memory::unique_ptr take_next_client_as() { + return memory::unique_ptr { static_cast(_take_next_client()) }; + } + + protected: + virtual PacketClient* _take_next_client() = 0; + + std::shared_ptr _socket; + }; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/PacketStream.cpp b/src/openvic-simulation/multiplayer/lowlevel/PacketStream.cpp new file mode 100644 index 000000000..29a7242c7 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/PacketStream.cpp @@ -0,0 +1,110 @@ + +#include "PacketStream.hpp" + +#include +#include +#include + +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/utility/ErrorMacros.hpp" + +using namespace OpenVic; + +NetworkError BufferPacketStream::set_data(std::span buffer) { + if (buffer.empty()) { + return NetworkError::OK; + } + + if (_position + buffer.size() > _data.size()) { + _data.resize(_position + buffer.size()); + } + + std::memcpy(&_data[_position], buffer.data(), buffer.size()); + + _position += buffer.size(); + return NetworkError::OK; +} + +NetworkError BufferPacketStream::get_data(std::span buffer_to_set) { + size_t received; + get_data_no_blocking(buffer_to_set, received); + if (received != buffer_to_set.size()) { + return NetworkError::INVALID_PARAMETER; + } + + return NetworkError::OK; +} + +NetworkError BufferPacketStream::set_data_no_blocking(std::span buffer, size_t& r_sent) { + r_sent = buffer.size(); + return set_data(buffer); +} + +NetworkError BufferPacketStream::get_data_no_blocking(std::span buffer_to_set, size_t& r_received) { + if (buffer_to_set.empty()) { + r_received = 0; + return NetworkError::OK; + } + + if (_position + buffer_to_set.size() > _data.size()) { + r_received = _data.size() - _position; + if (r_received <= 0) { + r_received = 0; + return NetworkError::OK; // you got 0 + } + } else { + r_received = buffer_to_set.size(); + } + + std::memcpy(buffer_to_set.data(), &_data[_position], r_received); + + _position += r_received; + // FIXME: return what? OK or ERR_* + // return OK for now so we don't maybe return garbage + return NetworkError::OK; +} + +int64_t BufferPacketStream::available_bytes() const { + return _data.size() - _position; +} + +void BufferPacketStream::seek(size_t p_pos) { + OV_ERR_FAIL_COND(p_pos > _data.size()); + _position = p_pos; +} + +size_t BufferPacketStream::size() const { + return _data.size(); +} + +size_t BufferPacketStream::position() const { + return _position; +} + +void BufferPacketStream::resize(size_t p_size) { + _data.resize(p_size); +} + +void BufferPacketStream::set_buffer_data(std::span p_data) { + if (p_data.size() > _data.size()) { + _data.resize(p_data.size()); + } + + std::uninitialized_copy(p_data.begin(), p_data.end(), _data.begin()); +} + +std::span BufferPacketStream::get_buffer_data() const { + return _data; +} + +void BufferPacketStream::clear() { + _data.clear(); + _position = 0; +} + +BufferPacketStream BufferPacketStream::duplicate() const { + BufferPacketStream result; + result._data.resize(_data.size()); + std::uninitialized_copy(_data.begin(), _data.end(), result._data.begin()); + return result; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/PacketStream.hpp b/src/openvic-simulation/multiplayer/lowlevel/PacketStream.hpp new file mode 100644 index 000000000..1d2af1d91 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/PacketStream.hpp @@ -0,0 +1,59 @@ +#pragma once + +#include +#include +#include +#include + +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/PacketReaderAdapter.hpp" +#include "openvic-simulation/utility/Containers.hpp" + +namespace OpenVic { + struct PacketStream { + virtual NetworkError set_data(std::span buffer) = 0; + virtual NetworkError get_data(std::span buffer_to_set) = 0; + virtual NetworkError set_data_no_blocking(std::span buffer, size_t& r_sent) = 0; + virtual NetworkError get_data_no_blocking(std::span buffer_to_set, size_t& r_received) = 0; + virtual int64_t available_bytes() const = 0; + + template + PacketBufferEndian packet_buffer(size_t bytes) { + PacketBufferEndian result; + result.resize(bytes); + return packet_buffer(result); + } + + template + PacketBufferEndian packet_buffer(PacketBufferEndian& buffer_store) { + get_data(buffer_store); + return buffer_store; + } + }; + + struct BufferPacketStream : PacketStream { + NetworkError set_data(std::span buffer) override; + NetworkError get_data(std::span r_buffer) override; + + NetworkError set_data_no_blocking(std::span buffer, size_t& r_sent) override; + NetworkError get_data_no_blocking(std::span buffer_to_set, size_t& r_received) override; + + virtual int64_t available_bytes() const override; + + void seek(size_t p_pos); + size_t size() const; + size_t position() const; + void resize(size_t p_size); + + void set_buffer_data(std::span p_data); + std::span get_buffer_data() const; + + void clear(); + + BufferPacketStream duplicate() const; + + private: + memory::vector _data; + size_t _position; + }; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/ReliableUdpClient.cpp b/src/openvic-simulation/multiplayer/lowlevel/ReliableUdpClient.cpp new file mode 100644 index 000000000..7780c42f5 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/ReliableUdpClient.cpp @@ -0,0 +1,256 @@ + +#include "ReliableUdpClient.hpp" + +#include + +#include + +#include "openvic-simulation/GameManager.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/UdpClient.hpp" +#include "openvic-simulation/utility/ErrorMacros.hpp" +#include "openvic-simulation/utility/MemoryTracker.hpp" + +using namespace OpenVic; + +void ReliableUdpClient::_default_config_values(reliable_config_t& config) { + if (config.fragment_above == 0) { + config.fragment_above = 1024; + } + if (config.max_fragments == 0) { + config.max_fragments = 3; + } + if (config.fragment_size == 0) { + config.fragment_size = 1024; + } + if (config.ack_buffer_size == 0) { + config.ack_buffer_size = 256; + } + if (config.sent_packets_buffer_size == 0) { + config.sent_packets_buffer_size = 256; + } + if (config.received_packets_buffer_size == 0) { + config.received_packets_buffer_size = 256; + } + if (config.fragment_reassembly_buffer_size == 0) { + config.fragment_reassembly_buffer_size = 64; + } + if (config.rtt_smoothing_factor == 0) { + config.rtt_smoothing_factor = 0.0025f; + } + if (config.rtt_history_size == 0) { + config.rtt_history_size = 512; + } + if (config.packet_loss_smoothing_factor == 0) { + config.packet_loss_smoothing_factor = 0.1f; + } + if (config.bandwidth_smoothing_factor == 0) { + config.bandwidth_smoothing_factor = 0.1f; + } + if (config.packet_header_size == 0) { + config.packet_header_size = 48; + } + if (config.max_packet_size == 0) { + config.max_packet_size = max_packet_size() - config.packet_header_size; + } +} + +ReliableUdpClient::ReliableUdpClient(reliable_config_t config) { + _default_config_values(config); + + config.context = this; + config.transmit_packet_function = &_transmit_packet; + config.process_packet_function = &_process_packet; + + _endpoint = endpoint_pointer_type { + reliable_endpoint_create(&config, static_cast(GameManager::get_elapsed_milliseconds()) / 1000.0) + }; +#ifdef DEBUG_ENABLED + { + utility::MemoryTracker tracker {}; + tracker.on_node_allocation(nullptr, sizeof(reliable_endpoint_dummy), 0); + } +#endif + + OV_ERR_FAIL_COND(!_endpoint); +} + +thread_local NetworkError transmit_error; + +void ReliableUdpClient::_transmit_packet(void* context, uint64_t id, uint16_t sequence, uint8_t* packet_data, int packet_size) { + ReliableUdpClient* self = static_cast(context); + transmit_error = self->UdpClient::_set_packet({ packet_data, static_cast(packet_size) }); +} + +NetworkError ReliableUdpClient::_set_packet(std::span buffer) { + OV_ERR_FAIL_COND_V(!_endpoint, NetworkError::UNAVAILABLE); + + reliable_endpoint_send_packet(_endpoint.get(), buffer.data(), buffer.size_bytes()); + // Pretty sure floating point for packet tracking shouldn't cause issues + update(static_cast(GameManager::get_elapsed_milliseconds()) / 1000.0); + return transmit_error; +} + +struct PeerProcessData { + IpAddress& p_ip; + NetworkSocket::port_type& p_port; +}; + +thread_local PeerProcessData* process_peer_data = nullptr; + +int ReliableUdpClient::_process_packet(void* context, uint64_t id, uint16_t sequence, uint8_t* packet_data, int packet_size) { + ReliableUdpClient* self = static_cast(context); + std::span packet { packet_data, static_cast(packet_size) }; + + if (self->UdpClient::store_packet(process_peer_data->p_ip, process_peer_data->p_port, packet) == NetworkError::OK) { + return 1; + } + + process_peer_data = nullptr; + return 0; +} + +NetworkError ReliableUdpClient::store_packet(IpAddress p_ip, NetworkSocket::port_type p_port, std::span p_buf) { + OV_ERR_FAIL_COND_V(!_endpoint, NetworkError::UNAVAILABLE); + + PeerProcessData data = { p_ip, p_port }; + process_peer_data = &data; + reliable_endpoint_receive_packet(_endpoint.get(), p_buf.data(), p_buf.size_bytes()); + if (process_peer_data == nullptr) { + return NetworkError::OUT_OF_MEMORY; + } + process_peer_data = nullptr; + return NetworkError::OK; +} + +void ReliableUdpClient::update(double time) { + reliable_endpoint_update(_endpoint.get(), time); +} + +void ReliableUdpClient::reset() { + reliable_endpoint_reset(_endpoint.get()); +} + +ReliableUdpClient::sequence_type ReliableUdpClient::get_current_sequence_value() const { + return get_next_sequence_value() - 1; +} + +ReliableUdpClient::sequence_type ReliableUdpClient::get_next_sequence_value() const { + OV_ERR_FAIL_COND_V(!_endpoint, 0); + + return reliable_endpoint_next_packet_sequence(_endpoint.get()); +} + +std::span ReliableUdpClient::get_acknowledged_sequences() const { + OV_ERR_FAIL_COND_V(!_endpoint, {}); + + int32_t acks_count; + uint16_t* acks_array = reliable_endpoint_get_acks(_endpoint.get(), &acks_count); + OV_ERR_FAIL_COND_V(acks_count <= 0, {}); + + return { acks_array, static_cast(acks_count) }; +} + +void ReliableUdpClient::clear_acknowledged_sequences() { + OV_ERR_FAIL_COND(!_endpoint); + + reliable_endpoint_clear_acks(_endpoint.get()); +} + +float ReliableUdpClient::get_round_trip_smooth_average() const { + OV_ERR_FAIL_COND_V(!_endpoint, 0); + + return reliable_endpoint_rtt(_endpoint.get()); +} + +float ReliableUdpClient::get_round_trip_average() const { + OV_ERR_FAIL_COND_V(!_endpoint, 0); + + return reliable_endpoint_rtt_avg(_endpoint.get()); +} + +float ReliableUdpClient::get_round_trip_minimum() const { + OV_ERR_FAIL_COND_V(!_endpoint, 0); + + return reliable_endpoint_rtt_min(_endpoint.get()); +} + +float ReliableUdpClient::get_round_trip_maximum() const { + OV_ERR_FAIL_COND_V(!_endpoint, 0); + + return reliable_endpoint_rtt_max(_endpoint.get()); +} + +float ReliableUdpClient::get_jitter_average() const { + OV_ERR_FAIL_COND_V(!_endpoint, 0); + + return reliable_endpoint_jitter_avg_vs_min_rtt(_endpoint.get()); +} + +float ReliableUdpClient::get_jitter_maximum() const { + OV_ERR_FAIL_COND_V(!_endpoint, 0); + + return reliable_endpoint_jitter_max_vs_min_rtt(_endpoint.get()); +} + +float ReliableUdpClient::get_jitter_against_average_round_trip() const { + OV_ERR_FAIL_COND_V(!_endpoint, 0); + + return reliable_endpoint_jitter_stddev_vs_avg_rtt(_endpoint.get()); +} + +float ReliableUdpClient::get_packet_loss() const { + OV_ERR_FAIL_COND_V(!_endpoint, 0); + + return reliable_endpoint_packet_loss(_endpoint.get()); +} + +ReliableUdpClient::BandwidthStatistics ReliableUdpClient::get_bandwidth_statistics() const { + OV_ERR_FAIL_COND_V(!_endpoint, {}); + + BandwidthStatistics result; // NOLINT + reliable_endpoint_bandwidth( + _endpoint.get(), &result.sent_bandwidth_kbps, &result.received_bandwidth_kbps, &result.acked_bandwidth_kpbs + ); + return result; +} + +size_t ReliableUdpClient::get_packets_sent() const { + return reliable_endpoint_counters(_endpoint.get())[RELIABLE_ENDPOINT_COUNTER_NUM_PACKETS_SENT]; +} + +size_t ReliableUdpClient::get_packets_received() const { + return reliable_endpoint_counters(_endpoint.get())[RELIABLE_ENDPOINT_COUNTER_NUM_PACKETS_RECEIVED]; +} + +size_t ReliableUdpClient::get_packets_acknowledged() const { + return reliable_endpoint_counters(_endpoint.get())[RELIABLE_ENDPOINT_COUNTER_NUM_PACKETS_ACKED]; +} + +size_t ReliableUdpClient::get_packets_stale() const { + return reliable_endpoint_counters(_endpoint.get())[RELIABLE_ENDPOINT_COUNTER_NUM_PACKETS_STALE]; +} + +size_t ReliableUdpClient::get_packets_invalid() const { + return reliable_endpoint_counters(_endpoint.get())[RELIABLE_ENDPOINT_COUNTER_NUM_PACKETS_INVALID]; +} + +size_t ReliableUdpClient::get_packets_too_large_to_send() const { + return reliable_endpoint_counters(_endpoint.get())[RELIABLE_ENDPOINT_COUNTER_NUM_PACKETS_TOO_LARGE_TO_SEND]; +} + +size_t ReliableUdpClient::get_packets_too_large_to_receive() const { + return reliable_endpoint_counters(_endpoint.get())[RELIABLE_ENDPOINT_COUNTER_NUM_PACKETS_TOO_LARGE_TO_RECEIVE]; +} + +size_t ReliableUdpClient::get_fragments_sent() const { + return reliable_endpoint_counters(_endpoint.get())[RELIABLE_ENDPOINT_COUNTER_NUM_FRAGMENTS_SENT]; +} + +size_t ReliableUdpClient::get_fragments_received() const { + return reliable_endpoint_counters(_endpoint.get())[RELIABLE_ENDPOINT_COUNTER_NUM_FRAGMENTS_RECEIVED]; +} + +size_t ReliableUdpClient::get_fragments_invalid() const { + return reliable_endpoint_counters(_endpoint.get())[RELIABLE_ENDPOINT_COUNTER_NUM_FRAGMENTS_INVALID]; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/ReliableUdpClient.hpp b/src/openvic-simulation/multiplayer/lowlevel/ReliableUdpClient.hpp new file mode 100644 index 000000000..415bb5189 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/ReliableUdpClient.hpp @@ -0,0 +1,113 @@ +#pragma once + +#include +#include +#include + +#include + +#include "openvic-simulation/multiplayer/lowlevel/Constants.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/UdpClient.hpp" + +namespace OpenVic { + struct ReliableUdpClient : UdpClient { + ReliableUdpClient(reliable_config_t config = { .name = "client" }); + + using sequence_type = reliable_udp_sequence_type; + + NetworkError store_packet(IpAddress p_ip, NetworkSocket::port_type p_port, std::span p_buf) override; + + void update(double time); + void reset(); + + sequence_type get_current_sequence_value() const; + sequence_type get_next_sequence_value() const; + + std::span get_acknowledged_sequences() const; + void clear_acknowledged_sequences(); + + float get_round_trip_smooth_average() const; + float get_round_trip_average() const; + float get_round_trip_minimum() const; + float get_round_trip_maximum() const; + + float get_jitter_average() const; + float get_jitter_maximum() const; + float get_jitter_against_average_round_trip() const; + + float get_packet_loss() const; + + struct BandwidthStatistics { + float sent_bandwidth_kbps; + float received_bandwidth_kbps; + float acked_bandwidth_kpbs; + }; + BandwidthStatistics get_bandwidth_statistics() const; + + size_t get_packets_sent() const; + size_t get_packets_received() const; + size_t get_packets_acknowledged() const; + size_t get_packets_stale() const; + size_t get_packets_invalid() const; + size_t get_packets_too_large_to_send() const; + size_t get_packets_too_large_to_receive() const; + size_t get_fragments_sent() const; + size_t get_fragments_received() const; + size_t get_fragments_invalid() const; + + protected: + NetworkError _set_packet(std::span buffer) override; + + private: + static void _transmit_packet(void* context, uint64_t id, uint16_t sequence, uint8_t* packet_data, int packet_size); + static int _process_packet(void* context, uint64_t ud, uint16_t sequence, uint8_t* packet_data, int packet_size); + + void _default_config_values(reliable_config_t& config); + + // Ensures we track an accurate allocation/deallocation size for reliable_endpoint_t + struct reliable_endpoint_dummy { + void* allocator_context; + void* (*allocate_function)(void*, size_t); + void (*free_function)(void*, void*); + struct reliable_config_t config; + double time; + float rtt; + float rtt_min; + float rtt_max; + float rtt_avg; + float jitter_avg_vs_min_rtt; + float jitter_max_vs_min_rtt; + float jitter_stddev_vs_avg_rtt; + float packet_loss; + float sent_bandwidth_kbps; + float received_bandwidth_kbps; + float acked_bandwidth_kbps; + int num_acks; + uint16_t* acks; + uint16_t sequence; + float* rtt_history_buffer; + struct reliable_sequence_buffer_t* sent_packets; + struct reliable_sequence_buffer_t* received_packets; + struct reliable_sequence_buffer_t* fragment_reassembly; + uint64_t counters[RELIABLE_ENDPOINT_NUM_COUNTERS]; + }; + + struct endpoint_deleter_t { + using value_type = reliable_endpoint_t; + + void operator()(value_type* ptr) { + reliable_endpoint_destroy(ptr); +#ifdef DEBUG_ENABLED + { + utility::MemoryTracker tracker {}; + tracker.on_node_deallocation(nullptr, sizeof(reliable_endpoint_dummy), 0); + } +#endif + } + }; + using endpoint_pointer_type = std::unique_ptr; + + endpoint_pointer_type _endpoint; + }; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/ReliableUdpServer.hpp b/src/openvic-simulation/multiplayer/lowlevel/ReliableUdpServer.hpp new file mode 100644 index 000000000..8878d0bb9 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/ReliableUdpServer.hpp @@ -0,0 +1,47 @@ +#pragma once + +#include +#include + +#include + +#include "openvic-simulation/multiplayer/lowlevel/ReliableUdpClient.hpp" +#include "openvic-simulation/multiplayer/lowlevel/UdpServer.hpp" +#include "openvic-simulation/utility/Utility.hpp" + +namespace OpenVic { + struct ReliableUdpServer : UdpServer { + void set_client_config_template(reliable_config_t const& config) { + _client_config_template = config; + } + + reliable_config_t const& get_client_config_template() const { + return _client_config_template; + } + + protected: + inline ReliableUdpClient* _take_next_client() override { + return static_cast(UdpServer::_take_next_client()); + } + + inline ReliableUdpClient* _create_client() override { + reliable_config_t config = _client_config_template; + config.id = _clients.size() + _pending_clients.size() + 1; + + size_t name_length = std::strlen(config.name); + if (OV_unlikely(name_length >= 225)) { + static constexpr const char server_connection[] = "server-connection-"; + static_assert(sizeof(server_connection) - 1 < sizeof(config.name)); + + name_length = sizeof(server_connection) - 1; + std::strncpy(config.name, server_connection, name_length); + } + + std::to_chars(config.name + name_length, config.name + sizeof(config.name), config.id); + + return new ReliableUdpClient(std::move(config)); + } + + reliable_config_t _client_config_template { .name = "server-connection-" }; + }; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/StreamPacketClient.cpp b/src/openvic-simulation/multiplayer/lowlevel/StreamPacketClient.cpp new file mode 100644 index 000000000..d5fdfe69c --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/StreamPacketClient.cpp @@ -0,0 +1,10 @@ + +#include "StreamPacketClient.hpp" + +#include "openvic-simulation/multiplayer/lowlevel/PacketStream.hpp" +#include "openvic-simulation/multiplayer/lowlevel/TcpPacketStream.hpp" + +using namespace OpenVic; + +template struct OpenVic::BasicStreamPacketClient; +template struct OpenVic::BasicStreamPacketClient; diff --git a/src/openvic-simulation/multiplayer/lowlevel/StreamPacketClient.hpp b/src/openvic-simulation/multiplayer/lowlevel/StreamPacketClient.hpp new file mode 100644 index 000000000..843e22797 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/StreamPacketClient.hpp @@ -0,0 +1,173 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/PacketClient.hpp" +#include "openvic-simulation/multiplayer/lowlevel/PacketStream.hpp" +#include "openvic-simulation/multiplayer/lowlevel/TcpPacketStream.hpp" +#include "openvic-simulation/types/RingBuffer.hpp" +#include "openvic-simulation/utility/Containers.hpp" +#include "openvic-simulation/utility/ErrorMacros.hpp" +#include "openvic-simulation/utility/Marshal.hpp" + +namespace OpenVic { + template T> + struct BasicStreamPacketClient : PacketClient { + BasicStreamPacketClient(uint8_t max_buffer_power = 16) { + _ring_buffer.reserve_power(max_buffer_power); + _input_buffer.resize(int64_t(1) << max_buffer_power); + _output_buffer.resize(int64_t(1) << max_buffer_power); + } + + int64_t available_packets() const override { + _poll_buffer(); + + uint32_t remaining = _ring_buffer.size() - _ring_buffer.space() - 1; + + int offset = 0; + int count = 0; + + while (remaining >= 4) { + uint8_t lbuf[4]; + std::copy_n(_ring_buffer.begin() + offset, 4, lbuf); + size_t decode_size = 0; + uint32_t length = utility::decode(lbuf, decode_size); + OV_ERR_BREAK(decode_size != 4); + remaining -= 4; + offset += 4; + if (length > remaining) { + break; + } + remaining -= length; + offset += length; + count++; + } + + return count; + } + + int64_t max_packet_size() const override { + return _output_buffer.size(); + } + + void set_packet_stream(std::shared_ptr stream) { + if (stream != _stream) { + _ring_buffer.clear(); // Reset the ring buffer. + } + + _stream = stream; + } + + std::shared_ptr get_packet_stream() const { + return _stream; + } + + void set_input_buffer_max_size(int p_max_size) { + OV_ERR_FAIL_COND_MSG(p_max_size < 0, "Max size of input buffer size cannot be smaller than 0."); + // WARNING: May lose packets. + OV_ERR_FAIL_COND_MSG( + _ring_buffer.size() - _ring_buffer.space() - 1, "Buffer in use, resizing would cause loss of data." + ); + _ring_buffer.reserve_power(std::bit_width(std::bit_ceil(p_max_size + 4)) - 1); + _input_buffer.resize(std::bit_ceil(p_max_size + 4)); + } + + int get_input_buffer_max_size() const { + return _input_buffer.size() - 4; + } + + void set_output_buffer_max_size(int p_max_size) { + _output_buffer.resize(std::bit_ceil(p_max_size + 4)); + } + + int get_output_buffer_max_size() const { + return _output_buffer.size() - 4; + } + + private: + NetworkError _set_packet(std::span buffer) override { + OV_ERR_FAIL_COND_V(_stream == nullptr, NetworkError::UNCONFIGURED); + NetworkError err = _poll_buffer(); // won't hurt to poll here too + + if (err != NetworkError::OK) { + return err; + } + + if (buffer.size() == 0) { + return NetworkError::OK; + } + + OV_ERR_FAIL_COND_V(buffer.size() + 4 > _output_buffer.size(), NetworkError::INVALID_PARAMETER); + + utility::encode(buffer.size(), _output_buffer); + uint8_t* dst = &_output_buffer[4]; + for (int i = 0; i < buffer.size(); i++) { + dst[i] = buffer[i]; + } + + return _stream->set_data({ &_output_buffer[0], buffer.size() + 4 }); + } + + NetworkError _get_packet(std::span& r_set) override { + OV_ERR_FAIL_COND_V(_stream == nullptr, NetworkError::UNCONFIGURED); + _poll_buffer(); + + int remaining = _ring_buffer.size() - _ring_buffer.space() - 1; + OV_ERR_FAIL_COND_V(remaining < 4, NetworkError::UNAVAILABLE); + uint8_t lbuf[4]; + std::copy_n(_ring_buffer.begin(), 4, lbuf); + remaining -= 4; + size_t decode_size = 0; + uint32_t length = utility::decode(lbuf, decode_size); + OV_ERR_FAIL_COND_V(remaining < (int)length, NetworkError::UNAVAILABLE); + + OV_ERR_FAIL_COND_V(_input_buffer.size() < length, NetworkError::UNAVAILABLE); + _ring_buffer.erase(_ring_buffer.begin(), decode_size); // get rid of the decode_size + _ring_buffer.read_buffer_to(_input_buffer.data(), length); // read packet + + r_set = std::span { _input_buffer.data(), length }; + return NetworkError::OK; + } + + NetworkError _poll_buffer() const { + OV_ERR_FAIL_COND_V(_stream == nullptr, NetworkError::UNCONFIGURED); + + size_t read = 0; + OV_ERR_FAIL_COND_V(_input_buffer.size() < _ring_buffer.space(), NetworkError::UNAVAILABLE); + NetworkError err = _stream->get_data_no_blocking({ _input_buffer.data(), _ring_buffer.space() }, read); + if (err != NetworkError::OK) { + return err; + } + if (read == 0) { + return NetworkError::OK; + } + + typename decltype(_ring_buffer)::iterator last = _ring_buffer.end(); + ptrdiff_t w = std::distance(last, _ring_buffer.append(&_input_buffer[0], read)); + OV_ERR_FAIL_COND_V(w != read, NetworkError::BUG); + + return NetworkError::OK; + } + + mutable std::shared_ptr _stream; + mutable RingBuffer _ring_buffer; + mutable memory::vector _input_buffer; + mutable memory::vector _output_buffer; + }; + + using StreamPacketClient = BasicStreamPacketClient; + extern template struct BasicStreamPacketClient; + + using TcpStreamPacketClient = BasicStreamPacketClient; + extern template struct BasicStreamPacketClient; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/TcpPacketStream.cpp b/src/openvic-simulation/multiplayer/lowlevel/TcpPacketStream.cpp new file mode 100644 index 000000000..890f0cb27 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/TcpPacketStream.cpp @@ -0,0 +1,316 @@ + +#include "TcpPacketStream.hpp" + +#include +#include +#include +#include + +#include "openvic-simulation/GameManager.hpp" +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkResolver.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkSocket.hpp" +#include "openvic-simulation/utility/ErrorMacros.hpp" + +using namespace OpenVic; + +TcpPacketStream::TcpPacketStream() : _socket { std::make_shared() } {} + +TcpPacketStream::~TcpPacketStream() { + close(); +} + +void TcpPacketStream::accept_socket(std::shared_ptr p_sock, IpAddress p_host, NetworkSocket::port_type p_port) { + _socket = p_sock; + _socket->set_blocking_enabled(false); + + _timeout = GameManager::get_elapsed_milliseconds() + _timeout_seconds * 1000; + _status = Status::CONNECTED; + + _client_address = p_host; + _client_port = p_port; +} + +int64_t TcpPacketStream::available_bytes() const { + OV_ERR_FAIL_COND_V(_socket == nullptr, -1); + return _socket->available_bytes(); +} + +NetworkError TcpPacketStream::bind(NetworkSocket::port_type port, IpAddress const& address) { + OV_ERR_FAIL_COND_V(_socket != nullptr, NetworkError::UNAVAILABLE); + OV_ERR_FAIL_COND_V(_socket->is_open(), NetworkError::ALREADY_OPEN); + OV_ERR_FAIL_COND_V_MSG( + port < 0 || port > 65535, NetworkError::INVALID_PARAMETER, + "The local port number must be between 0 and 65535 (inclusive)." + ); + + NetworkResolver::Type ip_type = address.is_ipv4() ? NetworkResolver::Type::IPV4 : NetworkResolver::Type::IPV6; + if (address.is_wildcard()) { + ip_type = NetworkResolver::Type::ANY; + } + NetworkError err = _socket->open(NetworkSocket::Type::TCP, ip_type); + if (err != NetworkError::OK) { + return err; + } + _socket->set_blocking_enabled(false); + return _socket->bind(address, port); +} + +NetworkError TcpPacketStream::connect_to(IpAddress const& host, NetworkSocket::port_type port) { + OV_ERR_FAIL_COND_V(_socket == nullptr, NetworkError::UNAVAILABLE); + OV_ERR_FAIL_COND_V(_status != Status::NONE, NetworkError::ALREADY_OPEN); + OV_ERR_FAIL_COND_V(!host.is_valid(), NetworkError::INVALID_PARAMETER); + OV_ERR_FAIL_COND_V_MSG( + port < 1 || port > 65535, NetworkError::INVALID_PARAMETER, + "The remote port number must be between 1 and 65535 (inclusive)." + ); + + if (!_socket->is_open()) { + NetworkResolver::Type ip_type = host.is_ipv4() ? NetworkResolver::Type::IPV4 : NetworkResolver::Type::IPV6; + NetworkError err = _socket->open(NetworkSocket::Type::TCP, ip_type); + if (err != NetworkError::OK) { + return err; + } + _socket->set_blocking_enabled(false); + } + + _timeout = GameManager::get_elapsed_milliseconds() + _timeout_seconds * 1000; + NetworkError err = _socket->connect_to_host(host, port); + + if (err == NetworkError::OK) { + _status = Status::CONNECTED; + } else if (err == NetworkError::BUSY) { + _status = Status::CONNECTING; + } else { + Logger::error("Connection to remote host failed!"); + close(); + return err; + } + + _client_address = host; + _client_port = port; + + return NetworkError::OK; +} + +NetworkError TcpPacketStream::poll() { + switch (_status) { + using enum Status; + case CONNECTED: { + NetworkError err; + err = _socket->poll(NetworkSocket::PollType::IN, 0); + if (err == NetworkError::OK) { + // FIN received + if (_socket->available_bytes() == 0) { + close(); + return NetworkError::OK; + } + } + // Also poll write + err = _socket->poll(NetworkSocket::PollType::IN_OUT, 0); + if (err != NetworkError::OK && err != NetworkError::BUSY) { + // Got an error + close(); + _status = ERROR; + return err; + } + return NetworkError::OK; + } + case CONNECTING: break; + default: return NetworkError::OK; + } + + NetworkError err = _socket->connect_to_host(_client_address, _client_port); + + if (err == NetworkError::OK) { + _status = Status::CONNECTED; + return NetworkError::OK; + } else if (err == NetworkError::BUSY) { + // Check for connect timeout + if (GameManager::get_elapsed_milliseconds() > _timeout) { + close(); + _status = Status::ERROR; + return err; + } + // Still trying to connect + return NetworkError::OK; + } + + close(); + _status = Status::ERROR; + return err; +} + +NetworkError TcpPacketStream::wait(NetworkSocket::PollType p_type, int p_timeout) { + OV_ERR_FAIL_COND_V(_socket == nullptr || !_socket->is_open(), NetworkError::UNAVAILABLE); + return _socket->poll(p_type, p_timeout); +} + +void TcpPacketStream::close() { + if (_socket != nullptr && _socket->is_open()) { + _socket->close(); + } + + _timeout = 0; + _status = Status::NONE; + _client_address = {}; + _client_port = 0; +} + +NetworkSocket::port_type TcpPacketStream::get_bound_port() const { + NetworkSocket::port_type local_port; + _socket->get_socket_address(nullptr, &local_port); + return local_port; +} + +bool TcpPacketStream::is_connected() const { + return _status == Status::CONNECTED; +} + +IpAddress TcpPacketStream::get_connected_address() const { + return _client_address; +} + +NetworkSocket::port_type TcpPacketStream::get_connected_port() const { + return _client_port; +} + +TcpPacketStream::Status TcpPacketStream::get_status() const { + return _status; +} + +void TcpPacketStream::set_timeout_seconds(uint64_t timeout) { + _timeout_seconds = timeout; +} + +uint64_t TcpPacketStream::get_timeout_seconds() const { + return _timeout_seconds; +} + +void TcpPacketStream::set_no_delay(bool p_enabled) { + OV_ERR_FAIL_COND(_socket == nullptr || !_socket->is_open()); + _socket->set_tcp_no_delay_enabled(p_enabled); +} + +NetworkError TcpPacketStream::set_data(std::span buffer) { + size_t _ = 0; + return _set_data(buffer, _); +} + +NetworkError TcpPacketStream::set_data_no_blocking(std::span buffer, size_t& r_sent) { + return _set_data(buffer, r_sent); +} + +NetworkError TcpPacketStream::get_data(std::span buffer_to_set) { + return _get_data(buffer_to_set); +} + +NetworkError TcpPacketStream::get_data_no_blocking(std::span buffer_to_set, size_t& r_received) { + NetworkError result = _get_data(buffer_to_set); + r_received = buffer_to_set.size(); + return result; +} + +template +NetworkError TcpPacketStream::_set_data(std::span buffer, size_t& r_sent) { + OV_ERR_FAIL_COND_V(_socket == nullptr, NetworkError::UNAVAILABLE); + + if (_status != Status::CONNECTED) { + return NetworkError::FAILED; + } + + NetworkError err; + int data_to_send = buffer.size_bytes(); + const uint8_t* offset = buffer.data(); + int64_t total_sent = 0; + + while (data_to_send) { + int64_t sent_amount = 0; + err = _socket->send(offset, data_to_send, sent_amount); + + if (err != NetworkError::OK) { + if (err != NetworkError::BUSY) { + close(); + return err; + } + + if constexpr (!IsBlocking) { + r_sent = total_sent; + return NetworkError::OK; + } else { + // Block and wait for the socket to accept more data + err = _socket->poll(NetworkSocket::PollType::OUT, -1); + if (err != NetworkError::OK) { + close(); + return err; + } + } + } else { + data_to_send -= sent_amount; + offset += sent_amount; + total_sent += sent_amount; + } + } + + r_sent = total_sent; + + return NetworkError::OK; +} + +template NetworkError TcpPacketStream::_set_data(std::span buffer, size_t& r_sent); +template NetworkError TcpPacketStream::_set_data(std::span buffer, size_t& r_sent); + +template +NetworkError TcpPacketStream::_get_data(std::span& r_buffer) { + if (_status != Status::CONNECTED) { + return NetworkError::FAILED; + } + + NetworkError err; + int to_read = r_buffer.size_bytes(); + int total_read = 0; + + while (to_read) { + int64_t read = 0; + err = _socket->receive(r_buffer.data() + total_read, to_read, read); + + if (err != NetworkError::OK) { + if (err != NetworkError::BUSY) { + close(); + return err; + } + + if constexpr (!IsBlocking) { + r_buffer = { r_buffer.data(), static_cast(total_read) }; + return NetworkError::OK; + } else { + err = _socket->poll(NetworkSocket::PollType::IN, -1); + + if (err != NetworkError::OK) { + close(); + return err; + } + } + } else if (read == 0) { + close(); + r_buffer = { r_buffer.data(), static_cast(total_read) }; + return NetworkError::EMPTY_BUFFER; + } else { + to_read -= read; + total_read += read; + + if constexpr (!IsBlocking) { + r_buffer = { r_buffer.data(), static_cast(total_read) }; + return NetworkError::OK; + } + } + } + + r_buffer = { r_buffer.data(), static_cast(total_read) }; + + return NetworkError::OK; +} + +template NetworkError TcpPacketStream::_get_data(std::span& r_buffer); +template NetworkError TcpPacketStream::_get_data(std::span& r_buffer); diff --git a/src/openvic-simulation/multiplayer/lowlevel/TcpPacketStream.hpp b/src/openvic-simulation/multiplayer/lowlevel/TcpPacketStream.hpp new file mode 100644 index 000000000..29050d254 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/TcpPacketStream.hpp @@ -0,0 +1,69 @@ +#pragma once + +#include +#include +#include + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkSocket.hpp" +#include "openvic-simulation/multiplayer/lowlevel/PacketStream.hpp" + +namespace OpenVic { + struct TcpPacketStream : PacketStream { + enum class Status : uint8_t { + NONE, + CONNECTING, + CONNECTED, + ERROR, + }; + + TcpPacketStream(); + ~TcpPacketStream(); + + void accept_socket(std::shared_ptr p_sock, IpAddress p_host, NetworkSocket::port_type p_port); + + int64_t available_bytes() const override; + NetworkError bind(NetworkSocket::port_type port, IpAddress const& address = "*"); + NetworkError connect_to(IpAddress const& host, NetworkSocket::port_type port); + NetworkError poll(); + NetworkError wait(NetworkSocket::PollType p_type, int p_timeout = 0); + void close(); + NetworkSocket::port_type get_bound_port() const; + bool is_connected() const; + + IpAddress get_connected_address() const; + NetworkSocket::port_type get_connected_port() const; + Status get_status() const; + + void set_timeout_seconds(uint64_t timeout); + uint64_t get_timeout_seconds() const; + + void set_no_delay(bool p_enabled); + + NetworkError set_data(std::span buffer) override; + NetworkError set_data_no_blocking(std::span buffer, size_t& r_sent) override; + NetworkError get_data(std::span buffer_to_set) override; + NetworkError get_data_no_blocking(std::span buffer_to_set, size_t& r_received) override; + + private: + uint64_t _timeout_seconds = 30; + + Status _status = Status::NONE; + uint64_t _timeout = 0; + NetworkSocket::port_type _client_port = 0; + IpAddress _client_address; + std::shared_ptr _socket; + + template + NetworkError _set_data(std::span buffer, size_t& r_sent); + + template + NetworkError _get_data(std::span& r_buffer); + }; + + extern template NetworkError TcpPacketStream::_set_data(std::span buffer, size_t& r_sent); + extern template NetworkError TcpPacketStream::_set_data(std::span buffer, size_t& r_sent); + extern template NetworkError TcpPacketStream::_get_data(std::span& r_buffer); + extern template NetworkError TcpPacketStream::_get_data(std::span& r_buffer); +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/TcpServer.cpp b/src/openvic-simulation/multiplayer/lowlevel/TcpServer.cpp new file mode 100644 index 000000000..3981e3d5b --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/TcpServer.cpp @@ -0,0 +1,79 @@ + +#include "TcpServer.hpp" + +#include + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkResolver.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkSocket.hpp" +#include "openvic-simulation/multiplayer/lowlevel/PacketServer.hpp" +#include "openvic-simulation/multiplayer/lowlevel/StreamPacketClient.hpp" +#include "openvic-simulation/multiplayer/lowlevel/TcpPacketStream.hpp" +#include "openvic-simulation/utility/Containers.hpp" + +using namespace OpenVic; + +bool TcpServer::is_connection_available() const { + if (!PacketServer::is_connection_available()) { + return false; + } + return _socket->poll(NetworkSocket::PollType::IN, 0) == NetworkError::OK; +} + +NetworkError TcpServer::listen_to(NetworkSocket::port_type port, IpAddress const& bind_addr) { + OV_ERR_FAIL_COND_V(!_socket, NetworkError::UNAVAILABLE); + OV_ERR_FAIL_COND_V(_socket->is_open(), NetworkError::ALREADY_OPEN); + OV_ERR_FAIL_COND_V(!bind_addr.is_valid() && !bind_addr.is_wildcard(), NetworkError::INVALID_PARAMETER); + + NetworkError err; + NetworkResolver::Type ip_type = NetworkResolver::Type::ANY; + + // If the bind address is valid use its type as the socket type + if (bind_addr.is_valid()) { + ip_type = bind_addr.is_ipv4() ? NetworkResolver::Type::IPV4 : NetworkResolver::Type::IPV6; + } + + err = _socket->open(NetworkSocket::Type::TCP, ip_type); + + OV_ERR_FAIL_COND_V(err != NetworkError::OK, err); + + _socket->set_blocking_enabled(false); + _socket->set_reuse_address_enabled(true); + + err = _socket->bind(bind_addr, port); + + if (err != NetworkError::OK) { + _socket->close(); + return err; + } + + err = _socket->listen(MAX_PENDING_CONNECTIONS); + if (err != NetworkError::OK) { + _socket->close(); + } + return err; +} + +memory::unique_ptr TcpServer::take_packet_stream() { + if (!is_connection_available()) { + return nullptr; + } + + IpAddress ip; + NetworkSocket::port_type port = 0; + std::shared_ptr ns = _socket->accept_as(ip, port); + if (!ns->is_open()) { + return nullptr; + } + + memory::unique_ptr result = memory::make_unique(); + result->accept_socket(ns, ip, port); + return result; +} + +TcpStreamPacketClient* TcpServer::_take_next_client() { + memory::unique_ptr client = memory::make_unique(); + client->set_packet_stream(std::shared_ptr { take_packet_stream().release() }); + return client.release(); +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/TcpServer.hpp b/src/openvic-simulation/multiplayer/lowlevel/TcpServer.hpp new file mode 100644 index 000000000..889fa5433 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/TcpServer.hpp @@ -0,0 +1,21 @@ +#pragma once + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkSocket.hpp" +#include "openvic-simulation/multiplayer/lowlevel/PacketServer.hpp" +#include "openvic-simulation/multiplayer/lowlevel/StreamPacketClient.hpp" +#include "openvic-simulation/utility/Containers.hpp" + +namespace OpenVic { + struct TcpServer : PacketServer { + static constexpr size_t MAX_PENDING_CONNECTIONS = 8; + + bool is_connection_available() const override; + NetworkError listen_to(NetworkSocket::port_type port, IpAddress const& bind_addr = "*") override; + + memory::unique_ptr take_packet_stream(); + + protected: + TcpStreamPacketClient* _take_next_client() override; + }; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/UdpClient.cpp b/src/openvic-simulation/multiplayer/lowlevel/UdpClient.cpp new file mode 100644 index 000000000..8e194f659 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/UdpClient.cpp @@ -0,0 +1,329 @@ +#include "UdpClient.hpp" + +#include +#include +#include +#include + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkResolver.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkSocket.hpp" +#include "openvic-simulation/multiplayer/lowlevel/UdpServer.hpp" + +using namespace OpenVic; + +UdpClient::UdpClient() : _socket(std::make_shared()), _ring_buffer(16) {} // NOLINT + +UdpClient::~UdpClient() { + close(); +} + +NetworkError UdpClient::bind(NetworkSocket::port_type port, IpAddress const& address, size_t min_buffer_size) { + OV_ERR_FAIL_COND_V(_socket == nullptr, NetworkError::UNAVAILABLE); + OV_ERR_FAIL_COND_V(_socket->is_open(), NetworkError::ALREADY_OPEN); + OV_ERR_FAIL_COND_V(!address.is_valid() && !address.is_wildcard(), NetworkError::INVALID_PARAMETER); + OV_ERR_FAIL_COND_V_MSG( + port > 65535, NetworkError::INVALID_PARAMETER, "The local port number must be between 0 and 65535 (inclusive)." + ); + + NetworkError err; + NetworkResolver::Type ip_type = NetworkResolver::Type::ANY; + + if (address.is_valid()) { + ip_type = address.is_ipv4() ? NetworkResolver::Type::IPV4 : NetworkResolver::Type::IPV6; + } + + err = _socket->open(NetworkSocket::Type::UDP, ip_type); + + if (err != NetworkError::OK) { + return err; + } + + _socket->set_blocking_enabled(false); + _socket->set_broadcasting_enabled(_is_broadcasting); + err = _socket->bind(address, port); + + if (err != NetworkError::OK) { + _socket->close(); + return err; + } + if (min_buffer_size < _ring_buffer.capacity()) { + _ring_buffer.shrink_to_fit(); + } + _ring_buffer.reserve_power(std::bit_ceil(min_buffer_size)); + return NetworkError::OK; +} + +void UdpClient::close() { + if (_server) { + _server->remove_client(peer_ip, peer_port); + _server = nullptr; + _socket = std::make_shared(); + } else if (_socket) { + _socket->close(); + } + _ring_buffer.reserve_power(16); + _queue_count = 0; + _is_connected = false; +} + +NetworkError UdpClient::connect_to(IpAddress const& host, NetworkSocket::port_type port) { + OV_ERR_FAIL_COND_V(_server, NetworkError::LOCKED); + OV_ERR_FAIL_COND_V(_socket == nullptr, NetworkError::UNCONFIGURED); + OV_ERR_FAIL_COND_V(!host.is_valid(), NetworkError::INVALID_PARAMETER); + + NetworkError err; + + if (!_socket->is_open()) { + NetworkResolver::Type ip_type = host.is_ipv4() ? NetworkResolver::Type::IPV4 : NetworkResolver::Type::IPV6; + err = _socket->open(NetworkSocket::Type::UDP, ip_type); + OV_ERR_FAIL_COND_V(err != NetworkError::OK, err); + _socket->set_blocking_enabled(false); + } + + err = _socket->connect_to_host(host, port); + + // I see no reason why we should get ERR_BUSY (wouldblock/eagain) here. + // This is UDP, so connect is only used to tell the OS to which socket + // it should deliver packets when multiple are bound on the same address/port. + if (err != NetworkError::OK) { + close(); + OV_ERR_FAIL_V_MSG(err, "Unable to connect"); + } + + _is_connected = true; + + peer_ip = host; + peer_port = port; + + // Flush any packet we might still have in queue. + _ring_buffer.clear(); + return NetworkError::OK; +} + +NetworkSocket::port_type UdpClient::get_bound_port() const { + NetworkSocket::port_type local_port; + _socket->get_socket_address(nullptr, &local_port); + return local_port; +} + +IpAddress UdpClient::get_last_packet_ip() const { + return _packet_ip; +} + +NetworkSocket::port_type UdpClient::get_last_packet_port() const { + return _packet_port; +} + +bool UdpClient::is_bound() const { + return _socket != nullptr && _socket->is_open(); +} + +bool UdpClient::is_connected() const { + return _is_connected; +} + +NetworkError UdpClient::wait() { + OV_ERR_FAIL_COND_V(_socket == nullptr, NetworkError::UNAVAILABLE); + return _socket->poll(NetworkSocket::PollType::IN, -1); +} + +NetworkError UdpClient::set_destination(HostnameAddress const& addr, NetworkSocket::port_type port) { + OV_ERR_FAIL_COND_V_MSG( + _is_connected, NetworkError::UNCONFIGURED, "Destination address cannot be set for connected sockets" + ); + peer_ip = addr.resolved_address(); + peer_port = port; + return NetworkError::OK; +} + +NetworkError UdpClient::store_packet(IpAddress p_ip, NetworkSocket::port_type p_port, std::span p_buf) { + uint16_t buffer_size = p_buf.size(); + std::span ipv6 = p_ip.get_ipv6(); + + if (_ring_buffer.capacity() - _ring_buffer.size() < + buffer_size + ipv6.size_bytes() + sizeof(p_port) + sizeof(buffer_size)) { + return NetworkError::OUT_OF_MEMORY; + } + _ring_buffer.append_range(ipv6); + _ring_buffer.append((uint8_t*)&p_port, sizeof(p_port)); + _ring_buffer.append((uint8_t*)&buffer_size, sizeof(buffer_size)); + _ring_buffer.append_range(p_buf); + ++_queue_count; + return NetworkError::OK; +} + +NetworkError UdpClient::_set_packet(std::span buffer) { + OV_ERR_FAIL_COND_V(_socket == nullptr, NetworkError::UNAVAILABLE); + OV_ERR_FAIL_COND_V(!peer_ip.is_valid(), NetworkError::UNCONFIGURED); + + NetworkError err; + int64_t sent = -1; + + if (!_socket->is_open()) { + NetworkResolver::Type ip_type = peer_ip.is_ipv4() ? NetworkResolver::Type::IPV4 : NetworkResolver::Type::IPV6; + err = _socket->open(NetworkSocket::Type::UDP, ip_type); + OV_ERR_FAIL_COND_V(err != NetworkError::OK, err); + _socket->set_blocking_enabled(false); + _socket->set_broadcasting_enabled(_is_broadcasting); + } + + do { + if (_is_connected && !_server) { + err = _socket->send(buffer.data(), buffer.size(), sent); + } else { + err = _socket->send_to(buffer.data(), buffer.size(), sent, peer_ip, peer_port); + } + if (err != NetworkError::OK) { + if (err != NetworkError::BUSY) { + return err; + } else if (!_is_blocking) { + return NetworkError::BUSY; + } + // Keep trying to send full packet + continue; + } + return NetworkError::OK; + + } while (sent != buffer.size()); + + return NetworkError::OK; +} + +NetworkError UdpClient::_get_packet(std::span& r_buffer) { + NetworkError err = _poll(); + if (err != NetworkError::OK) { + return err; + } + if (_queue_count == 0) { + return NetworkError::UNAVAILABLE; + } + + uint16_t size = 0; + std::array ipv6 {}; + _ring_buffer.read_buffer_to(ipv6.data(), sizeof(ipv6)); + _packet_ip.set_ipv6(ipv6); + _ring_buffer.read_buffer_to((uint8_t*)&_packet_port, sizeof(_packet_port)); + _ring_buffer.read_buffer_to((uint8_t*)&size, sizeof(size)); + _ring_buffer.read_buffer_to(_packet_buffer.data(), size); + --_queue_count; + r_buffer = { _packet_buffer.data(), size }; + + return NetworkError::OK; +} + +NetworkError UdpClient::join_multicast_group(IpAddress addr, std::string_view interface_name) { + OV_ERR_FAIL_COND_V(_server, NetworkError::LOCKED); + OV_ERR_FAIL_COND_V(_socket == nullptr, NetworkError::UNAVAILABLE); + OV_ERR_FAIL_COND_V(!addr.is_valid(), NetworkError::INVALID_PARAMETER); + + if (!_socket->is_open()) { + NetworkResolver::Type ip_type = addr.is_ipv4() ? NetworkResolver::Type::IPV4 : NetworkResolver::Type::IPV6; + NetworkError err = _socket->open(NetworkSocket::Type::UDP, ip_type); + OV_ERR_FAIL_COND_V(err != NetworkError::OK, err); + _socket->set_blocking_enabled(false); + _socket->set_broadcasting_enabled(_is_broadcasting); + } + return _socket->change_multicast_group(addr, interface_name, true); +} + +NetworkError UdpClient::leave_multicast_group(IpAddress addr, std::string_view interface_name) { + OV_ERR_FAIL_COND_V(_server, NetworkError::LOCKED); + OV_ERR_FAIL_COND_V(_socket == nullptr, NetworkError::UNAVAILABLE); + OV_ERR_FAIL_COND_V(!_socket->is_open(), NetworkError::UNCONFIGURED); + return _socket->change_multicast_group(addr, interface_name, false); +} + +bool UdpClient::is_broadcast_enabled() const { + return _is_broadcasting; +} + +void UdpClient::set_broadcast_enabled(bool enabled) { + OV_ERR_FAIL_COND(_server); + _is_broadcasting = enabled; + if (_socket && _socket->is_open()) { + _socket->set_broadcasting_enabled(enabled); + } +} + +bool UdpClient::is_blocking() const { + return _is_blocking; +} + +void UdpClient::set_blocking(bool blocking) { + _is_blocking = blocking; +} + +int64_t UdpClient::available_packets() const { + NetworkError err = const_cast(this)->_poll(); + if (err != NetworkError::OK) { + return -1; + } + + return _queue_count; +} + +int64_t UdpClient::max_packet_size() const { + return MAX_PACKET_SIZE; +} + +NetworkError UdpClient::_poll() { + OV_ERR_FAIL_COND_V(_socket == nullptr, NetworkError::UNAVAILABLE); + + if (!_socket->is_open()) { + return NetworkError::FAILED; + } + if (_server) { + return NetworkError::OK; // Handled by UDPServer. + } + + NetworkError err; + int64_t read; + IpAddress ip; + NetworkSocket::port_type port; + + while (true) { + if (_is_connected) { + err = _socket->receive(_receive_buffer.data(), sizeof(_receive_buffer), read); + ip = peer_ip; + port = peer_port; + } else { + err = _socket->receive_from(_receive_buffer.data(), sizeof(_receive_buffer), read, ip, port); + } + + if (err != NetworkError::OK) { + if (err == NetworkError::BUSY) { + break; + } + return err; + } + + err = store_packet(ip, port, { _receive_buffer.data(), static_cast(read) }); +#ifdef TOOLS_ENABLED + if (err != NetworkError::OK) { + Logger::warning("Buffer full, dropping packets!"); + } +#endif + } + + return NetworkError::OK; +} + +void UdpClient::_connect_shared_socket( // + std::shared_ptr p_sock, IpAddress p_ip, NetworkSocket::port_type p_port, UdpServer* p_server +) { + _server = p_server; + _is_connected = true; + _socket = std::move(p_sock); + peer_ip = p_ip; + peer_port = p_port; + _packet_ip = peer_ip; + _packet_port = peer_port; +} + +void UdpClient::_disconnect_shared_socket() { + _server = nullptr; + _socket = std::make_shared(); + close(); +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/UdpClient.hpp b/src/openvic-simulation/multiplayer/lowlevel/UdpClient.hpp new file mode 100644 index 000000000..cdf986caa --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/UdpClient.hpp @@ -0,0 +1,87 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "openvic-simulation/multiplayer/lowlevel/HostnameAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkSocket.hpp" +#include "openvic-simulation/multiplayer/lowlevel/PacketClient.hpp" +#include "openvic-simulation/types/RingBuffer.hpp" + +namespace OpenVic { + struct UdpServer; + struct ReliableUdpServer; + + struct UdpClient : PacketClient { + static constexpr size_t PACKET_BUFFER_SIZE = 65536; + + static constexpr size_t MAX_PACKET_SIZE = 512 * 3; + + UdpClient(); + ~UdpClient(); + + UdpClient(UdpClient&&) = default; + UdpClient& operator=(UdpClient&&) = default; + + void close(); + NetworkSocket::port_type get_bound_port() const; + bool is_connected() const; + + NetworkError connect_to(IpAddress const& host, NetworkSocket::port_type port); + NetworkError bind( // + NetworkSocket::port_type port, IpAddress const& address, size_t min_buffer_size = PACKET_BUFFER_SIZE + ); + IpAddress get_last_packet_ip() const; + NetworkSocket::port_type get_last_packet_port() const; + bool is_bound() const; + NetworkError wait(); + NetworkError set_destination(HostnameAddress const& addr, NetworkSocket::port_type port); + virtual NetworkError store_packet(IpAddress p_ip, NetworkSocket::port_type p_port, std::span p_buf); + + NetworkError join_multicast_group(IpAddress addr, std::string_view interface_name); + NetworkError leave_multicast_group(IpAddress addr, std::string_view interface_name); + + bool is_broadcast_enabled() const; + void set_broadcast_enabled(bool enabled); + + bool is_blocking() const; + void set_blocking(bool blocking); + + int64_t available_packets() const override; + int64_t max_packet_size() const override; + + protected: + NetworkError _set_packet(std::span buffer) override; + NetworkError _get_packet(std::span& r_set) override; + + NetworkError _poll(); + + bool _is_connected = false; + bool _is_blocking = true; + bool _is_broadcasting = false; + NetworkSocket::port_type _packet_port = 0; + NetworkSocket::port_type PROPERTY_ACCESS(peer_port, protected, 0); + int32_t _queue_count = 0; + IpAddress _packet_ip; + IpAddress PROPERTY_ACCESS(peer_ip, protected); + + std::shared_ptr _socket; + UdpServer* _server = nullptr; + RingBuffer _ring_buffer; + + std::array _receive_buffer; + std::array _packet_buffer; + + friend struct UdpServer; + friend struct ReliableUdpServer; + void _connect_shared_socket( // + std::shared_ptr p_sock, IpAddress p_ip, NetworkSocket::port_type p_port, UdpServer* p_server + ); + void _disconnect_shared_socket(); + }; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/UdpServer.cpp b/src/openvic-simulation/multiplayer/lowlevel/UdpServer.cpp new file mode 100644 index 000000000..66b9242d2 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/UdpServer.cpp @@ -0,0 +1,114 @@ + +#include "UdpServer.hpp" + +#include + +#include + +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkSocket.hpp" +#include "openvic-simulation/multiplayer/lowlevel/PacketServer.hpp" +#include "openvic-simulation/multiplayer/lowlevel/UdpClient.hpp" +#include "openvic-simulation/utility/Containers.hpp" + +using namespace OpenVic; + +bool UdpServer::is_connection_available() const { + if (!PacketServer::is_connection_available()) { + return false; + } + return _pending_clients.size() > 0; +} + +void UdpServer::close() { + PacketServer::close(); + for (ClientRef& client : _clients) { + client.client->_disconnect_shared_socket(); + } + for (ClientRef& client : _pending_clients) { + client.client->_disconnect_shared_socket(); + } + _clients.clear(); + _pending_clients.clear(); +} + +UdpClient* UdpServer::_take_next_client() { + if (!is_connection_available()) { + return nullptr; + }; + + ClientRef& client = _pending_clients.front(); + _clients.emplace_back(client.client.release(), client.ip, client.port); + _pending_clients.pop_front(); + return _clients.back().client; +} + +UdpClient* UdpServer::_create_client() { + return new UdpClient(); +} + +void UdpServer::remove_client(IpAddress ip, NetworkSocket::port_type port) { + const ClientRef client { nullptr, ip, port }; + if (memory::vector>::iterator it = ranges::find(_clients, client); it != _clients.end()) { + _clients.erase(it); + } +} + +NetworkError UdpServer::poll() { + OV_ERR_FAIL_COND_V(_socket == nullptr, NetworkError::UNAVAILABLE); + OV_ERR_FAIL_COND_V(!_socket->is_open(), NetworkError::UNCONFIGURED); + + NetworkError err; + int64_t read; + IpAddress ip; + NetworkSocket::port_type port; + while (true) { + err = _socket->receive_from(_receive_buffer.data(), sizeof(_receive_buffer), read, ip, port); + if (err != NetworkError::OK) { + if (err == NetworkError::BUSY) { + break; + } + return err; + } + + ClientRef p; + p.ip = ip; + p.port = port; + UdpClient* client_ptr = nullptr; + if (decltype(_clients)::iterator it = ranges::find(_clients, p); it != _clients.end()) { + client_ptr = it->client; + } else if (decltype(_pending_clients)::iterator pend_it = ranges::find(_pending_clients, ClientRef { p }); + pend_it != _pending_clients.end()) { + client_ptr = pend_it->client.get(); + } + + if (client_ptr) { + client_ptr->store_packet(ip, port, std::span { _receive_buffer.data(), static_cast(read) }); + } else { + if (_pending_clients.size() >= _max_pending_clients) { + // Drop connection. + continue; + } + // It's a new client, add it to the pending list. + ClientRef ref { memory::unique_ptr(_create_client()), ip, port }; + ref.client->_connect_shared_socket(_socket, ip, port, this); + ref.client->store_packet(ip, port, std::span { _receive_buffer.data(), static_cast(read) }); + _pending_clients.push_back(std::move(ref)); + } + } + return NetworkError::OK; +} + +void UdpServer::set_max_pending_clients(uint32_t clients) { + OV_ERR_FAIL_COND_MSG( + clients < 0, "Max pending connections value must be a positive number (0 means refuse new connections)." + ); + _max_pending_clients = clients; + if (clients > _pending_clients.size()) { + _pending_clients.erase(_pending_clients.begin() + clients + 1, _pending_clients.end()); + } +} + +uint32_t UdpServer::get_max_pending_clients() const { + return _max_pending_clients; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/UdpServer.hpp b/src/openvic-simulation/multiplayer/lowlevel/UdpServer.hpp new file mode 100644 index 000000000..02eeda5c5 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/UdpServer.hpp @@ -0,0 +1,60 @@ +#pragma once + +#include +#include +#include + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkSocket.hpp" +#include "openvic-simulation/multiplayer/lowlevel/PacketServer.hpp" +#include "openvic-simulation/multiplayer/lowlevel/UdpClient.hpp" +#include "openvic-simulation/utility/Containers.hpp" +#include "openvic-simulation/utility/Deque.hpp" + +namespace OpenVic { + struct UdpServer : PacketServer { + bool is_connection_available() const override; + void close() override; + NetworkError poll() override; + + void remove_client(IpAddress ip, NetworkSocket::port_type port); + + void set_max_pending_clients(uint32_t clients); + uint32_t get_max_pending_clients() const; + + protected: + UdpClient* _take_next_client() override; + virtual UdpClient* _create_client(); + + template + struct ClientRef { + static constexpr std::bool_constant is_owner {}; + + std::conditional_t, UdpClient*> client = nullptr; + IpAddress ip; + NetworkSocket::port_type port = 0; + + ClientRef() = default; + ClientRef(decltype(client)&& client, IpAddress ip, NetworkSocket::port_type port = 0) + : client { std::move(client) }, ip { ip }, port { port } {} + ClientRef(ClientRef const& ref) : client { ref.client }, ip { ref.ip }, port { ref.port } {} + ClientRef(ClientRef const&) = default; + ClientRef& operator=(ClientRef const&) = default; + ClientRef(ClientRef&&) = default; + ClientRef& operator=(ClientRef&&) = default; + + template + bool operator==(ClientRef const& p_other) const { + return (ip == p_other.ip && port == p_other.port); + } + }; + + uint32_t _max_pending_clients = 16; + + memory::vector> _clients; + OpenVic::utility::deque> _pending_clients; + + std::array _receive_buffer; + }; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/unix/UnixNetworkResolver.cpp b/src/openvic-simulation/multiplayer/lowlevel/unix/UnixNetworkResolver.cpp new file mode 100644 index 000000000..06349a50a --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/unix/UnixNetworkResolver.cpp @@ -0,0 +1,154 @@ +#if defined(__unix__) || defined(__unix) || (defined(__APPLE__) && defined(__MACH__)) + +#include "UnixNetworkResolver.hpp" + +#include +#include +#include + +#include + +#include + +#include + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/types/StackString.hpp" +#include "openvic-simulation/utility/Containers.hpp" +#include "openvic-simulation/utility/StringUtils.hpp" + +#ifdef __FreeBSD__ +#include +#endif +#include + +#include +#include + +#ifdef __FreeBSD__ +#include +#endif + +#include // Order is important on OpenBSD, leave as last. + +using namespace OpenVic; + +static IpAddress _sockaddr2ip(struct sockaddr* p_addr) { + IpAddress ip; + + if (p_addr->sa_family == AF_INET) { + struct sockaddr_in* addr = (struct sockaddr_in*)p_addr; + ip.set_ipv4(std::bit_cast>(&addr->sin_addr)); + } else if (p_addr->sa_family == AF_INET6) { + struct sockaddr_in6* addr6 = (struct sockaddr_in6*)p_addr; + ip.set_ipv6(addr6->sin6_addr.s6_addr); + } + + return ip; +} + +UnixNetworkResolver UnixNetworkResolver::_singleton {}; + +string_map_t UnixNetworkResolver::get_local_interfaces() const { + string_map_t result; + + struct ifaddrs* ifAddrStruct = nullptr; + struct ifaddrs* ifa = nullptr; + int family; + + getifaddrs(&ifAddrStruct); + + for (ifa = ifAddrStruct; ifa != nullptr; ifa = ifa->ifa_next) { + if (!ifa->ifa_addr) { + continue; + } + + family = ifa->ifa_addr->sa_family; + + if (family != AF_INET && family != AF_INET6) { + continue; + } + + string_map_t::iterator it = result.find(ifa->ifa_name); + if (it == result.end()) { + InterfaceInfo info; + info.name = ifa->ifa_name; + info.name_friendly = ifa->ifa_name; + + struct stack_string : StackString::max())> { + using StackString::_array; + using StackString::_string_size; + using StackString::StackString; + } str {}; + std::to_chars_result to_chars = + StringUtils::to_chars(str._array.data(), str._array.data() + str.array_length, if_nametoindex(ifa->ifa_name)); + str._string_size = to_chars.ptr - str.data(); + + info.index = str; + auto pair = result.insert_or_assign(ifa->ifa_name, info); + OV_ERR_CONTINUE(!pair.second); + it = pair.first; + } + + InterfaceInfo& info = it.value(); + info.ip_addresses.push_back(_sockaddr2ip(ifa->ifa_addr)); + } + + if (ifAddrStruct != nullptr) { + freeifaddrs(ifAddrStruct); + } + + return result; +} + +memory::vector UnixNetworkResolver::_resolve_hostname(std::string_view p_hostname, Type p_type) const { + struct addrinfo hints; // NOLINT + struct addrinfo* result = nullptr; + + std::memset(&hints, 0, sizeof(struct addrinfo)); + if (p_type == Type::IPV4) { + hints.ai_family = AF_INET; + } else if (p_type == Type::IPV6) { + hints.ai_family = AF_INET6; + hints.ai_flags = 0; + } else { + hints.ai_family = AF_UNSPEC; + hints.ai_flags = AI_ADDRCONFIG; + } + hints.ai_flags &= ~AI_NUMERICHOST; + + int s = getaddrinfo(p_hostname.data(), nullptr, &hints, &result); + if (s != 0) { + Logger::info("getaddrinfo failed! Cannot resolve hostname."); + return {}; + } + + if (result == nullptr || result->ai_addr == nullptr) { + Logger::info("Invalid response from getaddrinfo."); + if (result) { + freeaddrinfo(result); + } + return {}; + } + + struct addrinfo* next = result; + + memory::vector result_addrs; + do { + if (next->ai_addr == nullptr) { + next = next->ai_next; + continue; + } + IpAddress ip = _sockaddr2ip(next->ai_addr); + if (ip.is_valid() && ranges::find(result_addrs, ip) == result_addrs.end()) { + result_addrs.push_back(ip); + } + next = next->ai_next; + } while (next); + + freeaddrinfo(result); + + return result_addrs; +} + +#endif diff --git a/src/openvic-simulation/multiplayer/lowlevel/unix/UnixNetworkResolver.hpp b/src/openvic-simulation/multiplayer/lowlevel/unix/UnixNetworkResolver.hpp new file mode 100644 index 000000000..c2923b81f --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/unix/UnixNetworkResolver.hpp @@ -0,0 +1,32 @@ +#pragma once + +#if !(defined(__unix__) || defined(__unix) || (defined(__APPLE__) && defined(__MACH__))) +#error "UnixNetworkResolver.hpp should only be included on unix systems" +#endif + +#include "openvic-simulation/multiplayer/lowlevel/NetworkResolverBase.hpp" +#include "openvic-simulation/types/OrderedContainers.hpp" +#include "openvic-simulation/utility/Containers.hpp" + +namespace OpenVic { + struct UnixNetworkResolver final : NetworkResolverBase { + static constexpr Provider provider_value = Provider::UNIX; + + static UnixNetworkResolver& singleton() { + return _singleton; + } + + OpenVic::string_map_t get_local_interfaces() const override; + + Provider provider() const override { + return provider_value; + } + + private: + friend NetworkResolverBase::ResolveHandler; + + memory::vector _resolve_hostname(std::string_view p_hostname, Type p_type = Type::ANY) const override; + + static UnixNetworkResolver _singleton; + }; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/unix/UnixSocket.cpp b/src/openvic-simulation/multiplayer/lowlevel/unix/UnixSocket.cpp new file mode 100644 index 000000000..a64de229b --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/unix/UnixSocket.cpp @@ -0,0 +1,564 @@ +#if defined(__unix__) || defined(__unix) || (defined(__APPLE__) && defined(__MACH__)) +#include "UnixSocket.hpp" + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkResolver.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkSocketBase.hpp" +#include "openvic-simulation/types/OrderedContainers.hpp" +#include "openvic-simulation/utility/Containers.hpp" +#include "openvic-simulation/utility/ErrorMacros.hpp" +#include "openvic-simulation/utility/StringUtils.hpp" + +#include +#include +#include +#include +#include + +// BSD calls this flag IPV6_JOIN_GROUP +#if !defined(IPV6_ADD_MEMBERSHIP) && defined(IPV6_JOIN_GROUP) +#define IPV6_ADD_MEMBERSHIP IPV6_JOIN_GROUP +#endif +#if !defined(IPV6_DROP_MEMBERSHIP) && defined(IPV6_LEAVE_GROUP) +#define IPV6_DROP_MEMBERSHIP IPV6_LEAVE_GROUP +#endif + +using namespace OpenVic; + +UnixSocket::UnixSocket() {} + +UnixSocket::~UnixSocket() { + close(); +} + +NetworkError UnixSocket::_get_socket_error() const { + switch (errno) { + case EISCONN: return NetworkError::IS_CONNECTED; + case EINPROGRESS: + case EALREADY: return NetworkError::IN_PROGRESS; +#if EAGAIN != EWOULDBLOCK + case EAGAIN: +#endif + case EWOULDBLOCK: return NetworkError::WOULD_BLOCK; + case EADDRINUSE: + case EINVAL: + case EADDRNOTAVAIL: return NetworkError::ADDRESS_INVALID_OR_UNAVAILABLE; + case EACCES: return NetworkError::UNAUTHORIZED; + case ENOBUFS: return NetworkError::BUFFER_TOO_SMALL; + default: // + Logger::info("Socket error: ", strerror(errno)); + return NetworkError::OTHER; + } +} + +void UnixSocket::_set_ip_port(struct sockaddr_storage* addr, IpAddress* r_ip, port_type* r_port) { + if (addr->ss_family == AF_INET) { + struct sockaddr_in* addr4 = (struct sockaddr_in*)addr; + if (r_ip) { + r_ip->set_ipv4(std::bit_cast>(&addr4->sin_addr.s_addr)); + } + if (r_port) { + *r_port = ntohs(addr4->sin_port); + } + } else if (addr->ss_family == AF_INET6) { + struct sockaddr_in6* addr6 = (struct sockaddr_in6*)addr; + if (r_ip) { + r_ip->set_ipv6(addr6->sin6_addr.s6_addr); + } + if (r_port) { + *r_port = ntohs(addr6->sin6_port); + } + } +} + +bool UnixSocket::_can_use_ip(IpAddress const& ip, const bool for_bind) const { + if (for_bind && !(ip.is_valid() || ip.is_wildcard())) { + return false; + } else if (!for_bind && !ip.is_valid()) { + return false; + } + // Check if socket support this IP type. + NetworkResolver::Type type = ip.is_ipv4() ? NetworkResolver::Type::IPV4 : NetworkResolver::Type::IPV6; + return !(_ip_type != NetworkResolver::Type::ANY && !ip.is_wildcard() && _ip_type != type); +} + +void UnixSocket::setup() {} + +void UnixSocket::cleanup() {} + +NetworkError UnixSocket::open(Type type, NetworkResolver::Type& ip_type) { + OV_ERR_FAIL_COND_V(is_open(), NetworkError::ALREADY_OPEN); + OV_ERR_FAIL_COND_V( + ip_type > NetworkResolver::Type::ANY || ip_type < NetworkResolver::Type::NONE, NetworkError::INVALID_PARAMETER + ); + +#if defined(__OpenBSD__) + // OpenBSD does not support dual stacking, fallback to IPv4 only. + if (ip_type == NetworkResolver::Type::ANY) { + ip_type = NetworkResolver::Type::IPV4; + } +#endif + + int family = ip_type == NetworkResolver::Type::IPV4 ? AF_INET : AF_INET6; + int protocol = type == Type::TCP ? IPPROTO_TCP : IPPROTO_UDP; + int socket_type = type == Type::TCP ? SOCK_STREAM : SOCK_DGRAM; + _sock = socket(family, socket_type, protocol); + + if (_sock == -1 && ip_type == NetworkResolver::Type::ANY) { + // Careful here, changing the referenced parameter so the caller knows that we are using an IPv4 socket + // in place of a dual stack one, and further calls to _set_sock_addr will work as expected. + ip_type = NetworkResolver::Type::IPV4; + family = AF_INET; + _sock = socket(family, socket_type, protocol); + } + + OV_ERR_FAIL_COND_V(_sock == -1, NetworkError::SOCKET_ERROR); + _ip_type = ip_type; + + if (family == AF_INET6) { + // Select IPv4 over IPv6 mapping. + set_ipv6_only_enabled(ip_type != NetworkResolver::Type::ANY); + } + + if (protocol == IPPROTO_UDP) { + // Make sure to disable broadcasting for UDP sockets. + // Depending on the OS, this option might or might not be enabled by default. Let's normalize it. + set_broadcasting_enabled(false); + } + + _is_stream = type == Type::TCP; + + // Disable descriptor sharing with subprocesses. + _set_close_exec_enabled(true); + +#if defined(SO_NOSIGPIPE) + // Disable SIGPIPE (should only be relevant to stream sockets, but seems to affect UDP too on iOS). + int par = 1; + if (setsockopt(_sock, SOL_SOCKET, SO_NOSIGPIPE, &par, sizeof(int)) != 0) { + Logger::info("Unable to turn off SIGPIPE on socket."); + } +#endif + return NetworkError::OK; +} + +void UnixSocket::close() { + if (_sock != -1) { + ::close(_sock); + } + + _sock = -1; + _ip_type = NetworkResolver::Type::NONE; + _is_stream = false; +} + +NetworkError UnixSocket::bind(IpAddress const& addr, port_type p_port) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + OV_ERR_FAIL_COND_V(!_can_use_ip(addr, true), NetworkError::INVALID_PARAMETER); + + sockaddr_storage addr_store; // NOLINT + size_t addr_size = _set_addr_storage(&addr_store, addr, p_port, _ip_type); + + if (::bind(_sock, (struct sockaddr*)&addr_store, addr_size) != 0) { + NetworkError err = _get_socket_error(); + Logger::info("Failed to bind socket. Error: ", strerror(errno)); + close(); + return err; + } + + return NetworkError::OK; +} + +NetworkError UnixSocket::listen(int max_pending) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + + if (::listen(_sock, max_pending) != 0) { + NetworkError err = _get_socket_error(); + Logger::info("Failed to listen from socket. Error: ", strerror(errno)); + close(); + return err; + } + + return NetworkError::OK; +} + +NetworkError UnixSocket::connect_to_host(IpAddress const& host, port_type port) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + OV_ERR_FAIL_COND_V(!_can_use_ip(host, false), NetworkError::INVALID_PARAMETER); + + struct sockaddr_storage addr; // NOLINT + size_t addr_size = _set_addr_storage(&addr, host, port, _ip_type); + + if (::connect(_sock, (struct sockaddr*)&addr, addr_size) != 0) { + NetworkError err = _get_socket_error(); + + switch (err) { + using enum NetworkError; + // We are already connected. + case IS_CONNECTED: return NetworkError::OK; + // Still waiting to connect, try again in a while. + case WOULD_BLOCK: + case IN_PROGRESS: return NetworkError::BUSY; + default: + Logger::info("Connection to remote host failed. Error: ", strerror(errno)); + close(); + return err; + } + } + + return NetworkError::OK; +} + +NetworkError UnixSocket::poll(PollType type, int timeout) const { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + + struct pollfd pfd { .fd = _sock, .events = POLLIN, .revents = 0 }; + + switch (type) { + using enum PollType; + case IN: pfd.events = POLLIN; break; + case OUT: pfd.events = POLLOUT; break; + case IN_OUT: pfd.events = POLLOUT | POLLIN; break; + } + + int ret = ::poll(&pfd, 1, timeout); + + if (ret < 0 || pfd.revents & POLLERR) { + NetworkError err = _get_socket_error(); + Logger::info("Error when polling socket. Error ", strerror(errno)); + return err; + } + + if (ret == 0) { + return NetworkError::BUSY; + } + + return NetworkError::OK; +} + +NetworkError UnixSocket::receive(uint8_t* buffer, size_t p_len, int64_t& r_read) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + + r_read = ::recv(_sock, buffer, p_len, 0); + + if (r_read < 0) { + NetworkError err = _get_socket_error(); + if (err == NetworkError::WOULD_BLOCK) { + return NetworkError::BUSY; + } + return err; + } + + return NetworkError::OK; +} + +NetworkError UnixSocket::receive_from( // + uint8_t* buffer, size_t p_len, int64_t& r_read, IpAddress& r_ip, port_type& r_port, bool peek +) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + + struct sockaddr_storage from; // NOLINT + socklen_t socket_length = sizeof(from); + std::memset(&from, 0, socket_length); + + r_read = ::recvfrom(_sock, buffer, p_len, peek ? MSG_PEEK : 0, (struct sockaddr*)&from, &socket_length); + + if (r_read < 0) { + NetworkError err = _get_socket_error(); + if (err == NetworkError::WOULD_BLOCK) { + return NetworkError::BUSY; + } + return err; + } + + if (from.ss_family == AF_INET) { + struct sockaddr_in* sin_from = (struct sockaddr_in*)&from; + r_ip.set_ipv4(std::bit_cast>(&sin_from->sin_addr)); + r_port = ntohs(sin_from->sin_port); + } else if (from.ss_family == AF_INET6) { + struct sockaddr_in6* s6_from = (struct sockaddr_in6*)&from; + r_ip.set_ipv6(std::bit_cast>(&s6_from->sin6_addr)); + r_port = ntohs(s6_from->sin6_port); + } else { + // Unsupported socket family, should never happen. + OV_ERR_FAIL_V(NetworkError::UNSUPPORTED); + } + + return NetworkError::OK; +} + +NetworkError UnixSocket::send(const uint8_t* buffer, size_t p_len, int64_t& r_sent) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + + int flags = 0; +#ifdef MSG_NOSIGNAL + if (_is_stream) { + flags = MSG_NOSIGNAL; + } +#endif + r_sent = ::send(_sock, buffer, p_len, flags); + + if (r_sent < 0) { + NetworkError err = _get_socket_error(); + if (err == NetworkError::WOULD_BLOCK) { + return NetworkError::BUSY; + } + return err; + } + + return NetworkError::OK; +} + +NetworkError UnixSocket::send_to(const uint8_t* buffer, size_t p_len, int64_t& r_sent, IpAddress const& ip, port_type port) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + + struct sockaddr_storage addr; // NOLINT + size_t addr_size = _set_addr_storage(&addr, ip, port, _ip_type); + r_sent = ::sendto(_sock, buffer, p_len, 0, (struct sockaddr*)&addr, addr_size); + + if (r_sent < 0) { + NetworkError err = _get_socket_error(); + if (err == NetworkError::WOULD_BLOCK) { + return NetworkError::BUSY; + } + return err; + } + + return NetworkError::OK; +} + +UnixSocket* UnixSocket::_accept(IpAddress& r_ip, port_type& r_port) { + OV_ERR_FAIL_COND_V(!is_open(), nullptr); + + struct sockaddr_storage their_addr; // NOLINT + socklen_t size = sizeof(their_addr); + int fd = ::accept(_sock, (struct sockaddr*)&their_addr, &size); + if (fd == -1) { + NetworkError err = _get_socket_error(); + Logger::info("Error when accepting socket connection. Error: ", strerror(errno)); + return nullptr; + } + + _set_ip_port(&their_addr, &r_ip, &r_port); + + UnixSocket* ns = new UnixSocket(); + ns->_sock = fd; + ns->_ip_type = _ip_type; + ns->_is_stream = _is_stream; + // Disable descriptor sharing with subprocesses. + ns->_set_close_exec_enabled(true); + ns->set_blocking_enabled(false); + return ns; +} + +bool UnixSocket::is_open() const { + return _sock != -1; +} + +int UnixSocket::available_bytes() const { + OV_ERR_FAIL_COND_V(!is_open(), -1); + + int len; + int ret = ioctl(_sock, FIONREAD, &len); + if (ret == -1) { + _get_socket_error(); + Logger::info("Error when checking available bytes on socket. Error: ", strerror(errno)); + return -1; + } + return len; +} + +NetworkError UnixSocket::get_socket_address(IpAddress* r_ip, port_type* r_port) const { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::NOT_OPEN); + + struct sockaddr_storage saddr; // NOLINT + socklen_t len = sizeof(saddr); + if (getsockname(_sock, (struct sockaddr*)&saddr, &len) != 0) { + NetworkError err = _get_socket_error(); + Logger::info("Error when reading local socket address. Error: ", strerror(errno)); + return err; + } + _set_ip_port(&saddr, r_ip, r_port); + return NetworkError::OK; +} + +NetworkError UnixSocket::set_broadcasting_enabled(bool enabled) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + // IPv6 has no broadcast support. + if (_ip_type == NetworkResolver::Type::IPV6) { + return NetworkError::UNSUPPORTED; + } + + int par = enabled ? 1 : 0; + if (setsockopt(_sock, SOL_SOCKET, SO_BROADCAST, &par, sizeof(int)) != 0) { + Logger::warning("Unable to change broadcast setting."); + return NetworkError::BROADCAST_CHANGE_FAILED; + } + return NetworkError::OK; +} + +void UnixSocket::set_blocking_enabled(bool enabled) { + OV_ERR_FAIL_COND(!is_open()); + + int ret = 0; + int opts = fcntl(_sock, F_GETFL); + if (enabled) { + ret = fcntl(_sock, F_SETFL, opts & ~O_NONBLOCK); + } else { + ret = fcntl(_sock, F_SETFL, opts | O_NONBLOCK); + } + + if (ret != 0) { + Logger::warning("Unable to change non-block mode."); + } +} + +void UnixSocket::set_ipv6_only_enabled(bool enabled) { + OV_ERR_FAIL_COND(!is_open()); + // This option is only available in IPv6 sockets. + OV_ERR_FAIL_COND(_ip_type == NetworkResolver::Type::IPV4); + + int par = enabled ? 1 : 0; + if (setsockopt(_sock, IPPROTO_IPV6, IPV6_V6ONLY, &par, sizeof(int)) != 0) { + Logger::warning("Unable to change IPv4 address mapping over IPv6 option."); + } +} + +void UnixSocket::set_tcp_no_delay_enabled(bool enabled) { + OV_ERR_FAIL_COND(!is_open()); + OV_ERR_FAIL_COND(!_is_stream); // Not TCP. + + int par = enabled ? 1 : 0; + if (setsockopt(_sock, IPPROTO_TCP, TCP_NODELAY, &par, sizeof(int)) < 0) { + Logger::warning("Unable to set TCP no delay option."); + } +} + +void UnixSocket::set_reuse_address_enabled(bool enabled) { + OV_ERR_FAIL_COND(!is_open()); + + int par = enabled ? 1 : 0; + if (setsockopt(_sock, SOL_SOCKET, SO_REUSEADDR, &par, sizeof(int)) < 0) { + Logger::warning("Unable to set socket REUSEADDR option."); + } +} + +NetworkError UnixSocket::change_multicast_group(IpAddress const& multi_address, std::string_view if_name, bool add) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + OV_ERR_FAIL_COND_V(!_can_use_ip(multi_address, false), NetworkError::INVALID_PARAMETER); + + // Need to force level and af_family to IP(v4) when using dual stacking and provided multicast group is IPv4. + NetworkResolver::Type type = + _ip_type == NetworkResolver::Type::ANY && multi_address.is_ipv4() ? NetworkResolver::Type::IPV4 : _ip_type; + // This needs to be the proper level for the multicast group, no matter if the socket is dual stacking. + int level = type == NetworkResolver::Type::IPV4 ? IPPROTO_IP : IPPROTO_IPV6; + int ret = -1; + + IpAddress if_ip; + uint32_t if_v6id = 0; + OpenVic::string_map_t if_info = NetworkResolver::singleton().get_local_interfaces(); + for (std::pair const& pair : if_info) { + NetworkResolver::InterfaceInfo const& c = pair.second; + if (c.name != if_name) { + continue; + } + + std::from_chars_result from_chars = + StringUtils::from_chars(c.index.data(), c.index.data() + c.index.size(), if_v6id); + OV_ERR_FAIL_COND_V(from_chars.ec != std::errc {}, NetworkError::SOCKET_ERROR); + if (type == NetworkResolver::Type::IPV6) { + break; // IPv6 uses index. + } + + for (IpAddress const& F : c.ip_addresses) { + if (!F.is_ipv4()) { + continue; // Wrong IP type. + } + if_ip = F; + break; + } + break; + } + + if (level == IPPROTO_IP) { + OV_ERR_FAIL_COND_V(!if_ip.is_valid(), NetworkError::INVALID_PARAMETER); + struct ip_mreq greq; // NOLINT + int sock_opt = add ? IP_ADD_MEMBERSHIP : IP_DROP_MEMBERSHIP; + std::memcpy(&greq.imr_multiaddr, multi_address.get_ipv4().data(), 4); + std::memcpy(&greq.imr_interface, if_ip.get_ipv4().data(), 4); + ret = setsockopt(_sock, level, sock_opt, (const char*)&greq, sizeof(greq)); + } else { + struct ipv6_mreq greq; // NOLINT + int sock_opt = add ? IPV6_ADD_MEMBERSHIP : IPV6_DROP_MEMBERSHIP; + std::memcpy(&greq.ipv6mr_multiaddr, multi_address.get_ipv6().data(), 16); + greq.ipv6mr_interface = if_v6id; + ret = setsockopt(_sock, level, sock_opt, (const char*)&greq, sizeof(greq)); + } + + if (ret == -1) { + return _get_socket_error(); + } + + return NetworkError::OK; +} + +size_t UnixSocket::_set_addr_storage( // + sockaddr_storage* addr, IpAddress const& ip, port_type port, NetworkResolver::Type ip_type +) { + std::memset(addr, 0, sizeof(struct sockaddr_storage)); + if (ip_type == NetworkResolver::Type::IPV6 || ip_type == NetworkResolver::Type::ANY) { // IPv6 socket. + + // IPv6 only socket with IPv4 address. + OV_ERR_FAIL_COND_V(!ip.is_wildcard() && ip_type == NetworkResolver::Type::IPV6 && ip.is_ipv4(), 0); + + struct sockaddr_in6* addr6 = (struct sockaddr_in6*)addr; + addr6->sin6_family = AF_INET6; + addr6->sin6_port = htons(port); + if (ip.is_valid()) { + std::memcpy(&addr6->sin6_addr.s6_addr, ip.get_ipv6().data(), 16); + } else { + addr6->sin6_addr = in6addr_any; + } + return sizeof(sockaddr_in6); + } else { // IPv4 socket. + + // IPv4 socket with IPv6 address. + OV_ERR_FAIL_COND_V(!ip.is_wildcard() && !ip.is_ipv4(), 0); + + struct sockaddr_in* addr4 = (struct sockaddr_in*)addr; + addr4->sin_family = AF_INET; + addr4->sin_port = htons(port); // Short, network byte order. + + if (ip.is_valid()) { + std::memcpy(&addr4->sin_addr.s_addr, ip.get_ipv4().data(), 4); + } else { + addr4->sin_addr.s_addr = INADDR_ANY; + } + + return sizeof(sockaddr_in); + } +} + +void UnixSocket::_set_close_exec_enabled(bool enabled) { + // Enable close on exec to avoid sharing with subprocesses. Off by default on Windows. + int opts = fcntl(_sock, F_GETFD); + fcntl(_sock, F_SETFD, opts | FD_CLOEXEC); +} + +#endif diff --git a/src/openvic-simulation/multiplayer/lowlevel/unix/UnixSocket.hpp b/src/openvic-simulation/multiplayer/lowlevel/unix/UnixSocket.hpp new file mode 100644 index 000000000..3a3b33aad --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/unix/UnixSocket.hpp @@ -0,0 +1,70 @@ +#pragma once + +#if !(defined(__unix__) || defined(__unix) || (defined(__APPLE__) && defined(__MACH__))) +#error "UnixSocket.hpp should only be included on unix systems" +#endif + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkResolver.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkSocketBase.hpp" + +struct sockaddr_storage; + +namespace OpenVic { + struct UnixSocket final : NetworkSocketBase { + static constexpr NetworkResolver::Provider provider_value = NetworkResolver::Provider::UNIX; + + UnixSocket(); + ~UnixSocket() override; + + NetworkError open(Type p_type, NetworkResolver::Type& ip_type) override; + void close() override; + NetworkError bind(IpAddress const& p_addr, port_type p_port) override; + NetworkError listen(int p_max_pending) override; + NetworkError connect_to_host(IpAddress const& p_addr, port_type p_port) override; + NetworkError poll(PollType p_type, int timeout) const override; + NetworkError receive(uint8_t* p_buffer, size_t p_len, int64_t& r_read) override; + NetworkError receive_from( // + uint8_t* p_buffer, size_t p_len, int64_t& r_read, IpAddress& r_ip, port_type& r_port, bool p_peek = false + ) override; + NetworkError send(const uint8_t* p_buffer, size_t p_len, int64_t& r_sent) override; + NetworkError + send_to(const uint8_t* p_buffer, size_t p_len, int64_t& r_sent, IpAddress const& p_ip, port_type p_port) override; + + bool is_open() const override; + int available_bytes() const override; + NetworkError get_socket_address(IpAddress* r_ip, port_type* r_port) const override; + + // Returns OK if the socket option has been set successfully + NetworkError set_broadcasting_enabled(bool p_enabled) override; + void set_blocking_enabled(bool p_enabled) override; + void set_ipv6_only_enabled(bool p_enabled) override; + void set_tcp_no_delay_enabled(bool p_enabled) override; + void set_reuse_address_enabled(bool p_enabled) override; + NetworkError change_multicast_group(IpAddress const& p_multi_address, std::string_view p_if_name, bool add) override; + + NetworkResolver::Provider provider() const override { + return provider_value; + } + + static void setup(); + static void cleanup(); + + protected: + UnixSocket* _accept(IpAddress& r_ip, port_type& r_port) override; + + private: + static size_t _set_addr_storage( // + sockaddr_storage* p_addr, IpAddress const& p_ip, port_type p_port, NetworkResolver::Type p_ip_type + ); + static void _set_ip_port(sockaddr_storage* addr, IpAddress* r_ip, port_type* r_port); + + NetworkError _get_socket_error() const; + bool _can_use_ip(IpAddress const& p_ip, const bool p_for_bind) const; + void _set_close_exec_enabled(bool p_enabled); + + int32_t _sock = -1; + NetworkResolver::Type _ip_type = NetworkResolver::Type::NONE; + bool _is_stream = false; + }; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/windows/WindowsNetworkResolver.cpp b/src/openvic-simulation/multiplayer/lowlevel/windows/WindowsNetworkResolver.cpp new file mode 100644 index 000000000..b24cf6354 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/windows/WindowsNetworkResolver.cpp @@ -0,0 +1,183 @@ +#ifdef _WIN32 +#include "WindowsNetworkResolver.hpp" + +#include +#include + +#include +#include + +#include +#include + +#include + +#pragma comment(lib, "Ws2_32.lib") +#pragma comment(lib, "Iphlpapi.lib") + +#define WIN32_LEAN_AND_MEAN +#include +#include +// +#include +#include +#undef WIN32_LEAN_AND_MEAN +// Thank You Microsoft +#undef SOCKET_ERROR +#undef IN +#undef OUT + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/windows/WindowsSocket.hpp" +#include "openvic-simulation/types/OrderedContainers.hpp" +#include "openvic-simulation/types/StackString.hpp" +#include "openvic-simulation/utility/Containers.hpp" +#include "openvic-simulation/utility/StringUtils.hpp" + +using namespace OpenVic; + +static IpAddress _sockaddr2ip(struct sockaddr* p_addr) { + IpAddress ip; + + if (p_addr->sa_family == AF_INET) { + struct sockaddr_in* addr = (struct sockaddr_in*)p_addr; + ip.set_ipv4(std::bit_cast>(&addr->sin_addr)); + } else if (p_addr->sa_family == AF_INET6) { + struct sockaddr_in6* addr6 = (struct sockaddr_in6*)p_addr; + ip.set_ipv6(addr6->sin6_addr.s6_addr); + } + + return ip; +} + +WindowsNetworkResolver WindowsNetworkResolver::_singleton {}; + +WindowsNetworkResolver::WindowsNetworkResolver() { + WindowsSocket::setup(); +} + +WindowsNetworkResolver::~WindowsNetworkResolver() { + WindowsSocket::cleanup(); +} + +string_map_t WindowsNetworkResolver::get_local_interfaces() const { + ULONG buf_size = 1024; + IP_ADAPTER_ADDRESSES* addrs; + + while (true) { + addrs = (IP_ADAPTER_ADDRESSES*)malloc(buf_size); + int err = GetAdaptersAddresses( + AF_UNSPEC, GAA_FLAG_SKIP_ANYCAST | GAA_FLAG_SKIP_MULTICAST | GAA_FLAG_SKIP_DNS_SERVER | GAA_FLAG_SKIP_FRIENDLY_NAME, + nullptr, addrs, &buf_size + ); + if (err == NO_ERROR) { + break; + } + free(addrs); + if (err == ERROR_BUFFER_OVERFLOW) { + continue; // Will go back and alloc the right size. + } + + OV_ERR_FAIL_V_MSG( + string_map_t {}, fmt::format("Call to GetAdaptersAddresses failed with error {}.", err) + ); + } + + string_map_t result; + IP_ADAPTER_ADDRESSES* adapter = addrs; + + while (adapter != nullptr) { + InterfaceInfo info; + info.name = adapter->AdapterName; + + size_t name_wlength = wcslen(adapter->FriendlyName); + int name_length = WideCharToMultiByte(CP_UTF8, 0, adapter->FriendlyName, name_wlength, 0, 0, nullptr, nullptr); + info.name_friendly.reserve(name_length * sizeof(char)); + WideCharToMultiByte( + CP_UTF8, 0, adapter->FriendlyName, name_wlength, info.name_friendly.data(), name_length, nullptr, nullptr + ); + + struct stack_string : StackString::max())> { + using StackString::_array; + using StackString::_string_size; + using StackString::StackString; + } str {}; + std::to_chars_result to_chars = + StringUtils::to_chars(str._array.data(), str._array.data() + str.array_length, adapter->IfIndex); + str._string_size = to_chars.ptr - str.data(); + + info.index = str; + + IP_ADAPTER_UNICAST_ADDRESS* address = adapter->FirstUnicastAddress; + while (address != nullptr) { + int family = address->Address.lpSockaddr->sa_family; + if (family != AF_INET && family != AF_INET6) { + continue; + } + info.ip_addresses.push_back(_sockaddr2ip(address->Address.lpSockaddr)); + address = address->Next; + } + adapter = adapter->Next; + // Only add interface if it has at least one IP. + if (info.ip_addresses.size() > 0) { + auto pair = result.insert_or_assign(info.name, info); + OV_ERR_CONTINUE(!pair.second); + } + } + + free(addrs); + + return result; +} + +memory::vector WindowsNetworkResolver::_resolve_hostname(std::string_view p_hostname, Type p_type) const { + struct addrinfo hints; // NOLINT + struct addrinfo* result = nullptr; + + std::memset(&hints, 0, sizeof(struct addrinfo)); + if (p_type == Type::IPV4) { + hints.ai_family = AF_INET; + } else if (p_type == Type::IPV6) { + hints.ai_family = AF_INET6; + hints.ai_flags = 0; + } else { + hints.ai_family = AF_UNSPEC; + hints.ai_flags = AI_ADDRCONFIG; + } + hints.ai_flags &= ~AI_NUMERICHOST; + + int s = getaddrinfo(p_hostname.data(), nullptr, &hints, &result); + if (s != 0) { + Logger::info("getaddrinfo failed! Cannot resolve hostname."); + return {}; + } + + if (result == nullptr || result->ai_addr == nullptr) { + Logger::info("Invalid response from getaddrinfo."); + if (result) { + freeaddrinfo(result); + } + return {}; + } + + struct addrinfo* next = result; + + memory::vector result_addrs; + do { + if (next->ai_addr == nullptr) { + next = next->ai_next; + continue; + } + IpAddress ip = _sockaddr2ip(next->ai_addr); + if (ip.is_valid() && ranges::find(result_addrs, ip) == result_addrs.end()) { + result_addrs.push_back(ip); + } + next = next->ai_next; + } while (next); + + freeaddrinfo(result); + + return result_addrs; +} + +#endif diff --git a/src/openvic-simulation/multiplayer/lowlevel/windows/WindowsNetworkResolver.hpp b/src/openvic-simulation/multiplayer/lowlevel/windows/WindowsNetworkResolver.hpp new file mode 100644 index 000000000..bf043bf2f --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/windows/WindowsNetworkResolver.hpp @@ -0,0 +1,35 @@ +#pragma once + +#if !defined(_WIN32) +#error "WindowsNetworkResolver.hpp should only be included on windows systems" +#endif + +#include "openvic-simulation/multiplayer/lowlevel/NetworkResolverBase.hpp" +#include "openvic-simulation/types/OrderedContainers.hpp" +#include "openvic-simulation/utility/Containers.hpp" + +namespace OpenVic { + struct WindowsNetworkResolver final : NetworkResolverBase { + static constexpr Provider provider_value = Provider::WINDOWS; + + WindowsNetworkResolver(); + ~WindowsNetworkResolver(); + + static WindowsNetworkResolver& singleton() { + return _singleton; + } + + OpenVic::string_map_t get_local_interfaces() const override; + + Provider provider() const override { + return provider_value; + } + + private: + friend NetworkResolverBase::ResolveHandler; + + memory::vector _resolve_hostname(std::string_view p_hostname, Type p_type = Type::ANY) const override; + + static WindowsNetworkResolver _singleton; + }; +} diff --git a/src/openvic-simulation/multiplayer/lowlevel/windows/WindowsSocket.cpp b/src/openvic-simulation/multiplayer/lowlevel/windows/WindowsSocket.cpp new file mode 100644 index 000000000..766ff1ea2 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/windows/WindowsSocket.cpp @@ -0,0 +1,583 @@ +#ifdef _WIN32 +#include "WindowsSocket.hpp" + +#include +#include +#include +#include +#include +#include +#include + +#pragma comment(lib, "Ws2_32.lib") +#pragma comment(lib, "Mswsock.lib") + +#include +// +#include +#include +#include +#include +#include + +// Thank You Microsoft +#define WIN_SOCKET_ERROR (SOCKET)(~0) +#undef SOCKET_ERROR + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkResolver.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkSocketBase.hpp" +#include "openvic-simulation/types/OrderedContainers.hpp" +#include "openvic-simulation/utility/Containers.hpp" +#include "openvic-simulation/utility/ErrorMacros.hpp" +#include "openvic-simulation/utility/StringUtils.hpp" + +// Workaround missing flag in MinGW +#if defined(__MINGW32__) && !defined(SIO_UDP_NETRESET) +#define SIO_UDP_NETRESET _WSAIOW(IOC_VENDOR, 15) +#endif + +using namespace OpenVic; + +WindowsSocket::WindowsSocket() {} + +WindowsSocket::~WindowsSocket() { + close(); +} + +NetworkError WindowsSocket::_get_socket_error() const { + int err = WSAGetLastError(); + switch (err) { + case WSAEISCONN: return NetworkError::IS_CONNECTED; + case WSAEINPROGRESS: + case WSAEALREADY: return NetworkError::IN_PROGRESS; + case WSAEWOULDBLOCK: return NetworkError::WOULD_BLOCK; + case WSAEADDRINUSE: + case WSAEADDRNOTAVAIL: return NetworkError::ADDRESS_INVALID_OR_UNAVAILABLE; + case WSAEACCES: return NetworkError::UNAUTHORIZED; + case WSAEMSGSIZE: + case WSAENOBUFS: return NetworkError::BUFFER_TOO_SMALL; + default: // + Logger::info("Socket error: ", strerror(errno)); + return NetworkError::OTHER; + } +} + +void WindowsSocket::_set_ip_port(sockaddr_storage* addr, IpAddress* r_ip, port_type* r_port) { + if (addr->ss_family == AF_INET) { + struct sockaddr_in* addr4 = (struct sockaddr_in*)addr; + if (r_ip) { + r_ip->set_ipv4(std::bit_cast>(&addr4->sin_addr.s_addr)); + } + if (r_port) { + *r_port = ntohs(addr4->sin_port); + } + } else if (addr->ss_family == AF_INET6) { + struct sockaddr_in6* addr6 = (struct sockaddr_in6*)addr; + if (r_ip) { + r_ip->set_ipv6(addr6->sin6_addr.s6_addr); + } + if (r_port) { + *r_port = ntohs(addr6->sin6_port); + } + } +} + +bool WindowsSocket::_can_use_ip(IpAddress const& p_ip, const bool p_for_bind) const { + if (p_for_bind && !(p_ip.is_valid() || p_ip.is_wildcard())) { + return false; + } else if (!p_for_bind && !p_ip.is_valid()) { + return false; + } + // Check if socket support this IP type. + NetworkResolver::Type type = p_ip.is_ipv4() ? NetworkResolver::Type::IPV4 : NetworkResolver::Type::IPV6; + return !(_ip_type != NetworkResolver::Type::ANY && !p_ip.is_wildcard() && _ip_type != type); +} + +void WindowsSocket::setup() { + WSADATA data; + WSAStartup(MAKEWORD(2, 2), &data); +} + +void WindowsSocket::cleanup() { + WSACleanup(); +} + +NetworkError WindowsSocket::open(Type sock_type, NetworkResolver::Type& ip_type) { + OV_ERR_FAIL_COND_V(is_open(), NetworkError::ALREADY_OPEN); + OV_ERR_FAIL_COND_V( + ip_type > NetworkResolver::Type::ANY || ip_type < NetworkResolver::Type::NONE, NetworkError::INVALID_PARAMETER + ); + + int family = ip_type == NetworkResolver::Type::IPV4 ? AF_INET : AF_INET6; + int protocol = sock_type == Type::TCP ? IPPROTO_TCP : IPPROTO_UDP; + int type = sock_type == Type::TCP ? SOCK_STREAM : SOCK_DGRAM; + _sock = socket(family, type, protocol); + + if (_sock == INVALID_SOCKET && ip_type == NetworkResolver::Type::ANY) { + // Careful here, changing the referenced parameter so the caller knows that we are using an IPv4 socket + // in place of a dual stack one, and further calls to _set_sock_addr will work as expected. + ip_type = NetworkResolver::Type::IPV4; + family = AF_INET; + _sock = socket(family, type, protocol); + } + + OV_ERR_FAIL_COND_V(_sock == INVALID_SOCKET, NetworkError::SOCKET_ERROR); + _ip_type = ip_type; + + if (family == AF_INET6) { + // Select IPv4 over IPv6 mapping. + set_ipv6_only_enabled(ip_type != NetworkResolver::Type::ANY); + } + + if (protocol == IPPROTO_UDP) { + // Make sure to disable broadcasting for UDP sockets. + // Depending on the OS, this option might or might not be enabled by default. Let's normalize it. + set_broadcasting_enabled(false); + } + + _is_stream = sock_type == Type::TCP; + + if (!_is_stream) { + // Disable windows feature/bug reporting WSAECONNRESET/WSAENETRESET when + // recv/recvfrom and an ICMP reply was received from a previous send/sendto. + unsigned long disable = 0; + if (ioctlsocket(_sock, SIO_UDP_CONNRESET, &disable) == -1) { + Logger::info("Unable to turn off UDP WSAECONNRESET behavior on Windows."); + } + if (ioctlsocket(_sock, SIO_UDP_NETRESET, &disable) == -1) { + // This feature seems not to be supported on wine. + Logger::info("Unable to turn off UDP WSAENETRESET behavior on Windows."); + } + } + return NetworkError::OK; +} + +void WindowsSocket::close() { + if (_sock != INVALID_SOCKET) { + closesocket(_sock); + } + + _sock = INVALID_SOCKET; + _ip_type = NetworkResolver::Type::NONE; + _is_stream = false; +} + +NetworkError WindowsSocket::bind(IpAddress const& p_addr, port_type p_port) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + OV_ERR_FAIL_COND_V(!_can_use_ip(p_addr, true), NetworkError::INVALID_PARAMETER); + + sockaddr_storage addr; // NOLINT + size_t addr_size = _set_addr_storage(&addr, p_addr, p_port, _ip_type); + + if (::bind(_sock, (struct sockaddr*)&addr, addr_size) != 0) { + NetworkError err = _get_socket_error(); + Logger::info("Failed to bind socket. Error: ", strerror(errno)); + close(); + return err; + } + + return NetworkError::OK; +} + +NetworkError WindowsSocket::listen(int p_max_pending) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + + if (::listen(_sock, p_max_pending) != 0) { + NetworkError err = _get_socket_error(); + Logger::info("Failed to listen from socket. Error: ", strerror(errno)); + close(); + return err; + } + + return NetworkError::OK; +} + +NetworkError WindowsSocket::connect_to_host(IpAddress const& p_addr, port_type p_port) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + OV_ERR_FAIL_COND_V(!_can_use_ip(p_addr, true), NetworkError::INVALID_PARAMETER); + + struct sockaddr_storage addr; // NOLINT + size_t addr_size = _set_addr_storage(&addr, p_addr, p_port, _ip_type); + + if (::WSAConnect(_sock, (struct sockaddr*)&addr, addr_size, nullptr, nullptr, nullptr, nullptr) != 0) { + NetworkError err = _get_socket_error(); + + switch (err) { + using enum NetworkError; + // We are already connected. + case IS_CONNECTED: return NetworkError::OK; + // Still waiting to connect, try again in a while. + case WOULD_BLOCK: + case IN_PROGRESS: return NetworkError::BUSY; + default: + Logger::info("Connection to remote host failed. Error: ", strerror(errno)); + close(); + return err; + } + } + + return NetworkError::OK; +} + +NetworkError WindowsSocket::poll(PollType p_type, int p_timeout) const { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + + bool ready = false; + fd_set rd, wr, ex; + fd_set* rdp = nullptr; + fd_set* wrp = nullptr; + FD_ZERO(&rd); + FD_ZERO(&wr); + FD_ZERO(&ex); + FD_SET(_sock, &ex); + struct timeval timeout = { p_timeout / 1000, (p_timeout % 1000) * 1000 }; + // For blocking operation, pass nullptr timeout pointer to select. + struct timeval* tp = nullptr; + if (p_timeout >= 0) { + // If timeout is non-negative, we want to specify the timeout instead. + tp = &timeout; + } + +// Windows loves its idiotic macros +#pragma push_macro("IN") +#undef IN +#pragma push_macro("OUT") +#undef OUT + switch (p_type) { + using enum PollType; + case IN: + FD_SET(_sock, &rd); + rdp = &rd; + break; + case OUT: + FD_SET(_sock, &wr); + wrp = ≀ + break; + case IN_OUT: + FD_SET(_sock, &rd); + FD_SET(_sock, &wr); + rdp = &rd; + wrp = ≀ + } +#pragma pop_macro("IN") +#pragma pop_macro("OUT") + // WSAPoll is broken: https://daniel.haxx.se/blog/2012/10/10/wsapoll-is-broken/. + int ret = select(1, rdp, wrp, &ex, tp); + + if (ret == WIN_SOCKET_ERROR) { + return NetworkError::SOCKET_ERROR; + } + + if (ret == 0) { + return NetworkError::BUSY; + } + + if (FD_ISSET(_sock, &ex)) { + NetworkError err = _get_socket_error(); + Logger::info("Exception when polling socket. Error: ", strerror(errno)); + return err; + } + + if (rdp && FD_ISSET(_sock, rdp)) { + ready = true; + } + if (wrp && FD_ISSET(_sock, wrp)) { + ready = true; + } + + return ready ? NetworkError::OK : NetworkError::BUSY; +} + +NetworkError WindowsSocket::receive(uint8_t* p_buffer, size_t p_len, int64_t& r_read) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + + r_read = ::recv(_sock, (char*)p_buffer, p_len, 0); + + if (r_read < 0) { + NetworkError err = _get_socket_error(); + if (err == NetworkError::WOULD_BLOCK) { + return NetworkError::BUSY; + } + return err; + } + + return NetworkError::OK; +} + +NetworkError WindowsSocket::receive_from( // + uint8_t* p_buffer, size_t p_len, int64_t& r_read, IpAddress& r_ip, port_type& r_port, bool p_peek +) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + + struct sockaddr_storage from; // NOLINT + socklen_t len = sizeof(from); + std::memset(&from, 0, len); + + r_read = ::recvfrom(_sock, (char*)p_buffer, p_len, p_peek ? MSG_PEEK : 0, (struct sockaddr*)&from, &len); + + if (r_read < 0) { + NetworkError err = _get_socket_error(); + if (err == NetworkError::WOULD_BLOCK) { + return NetworkError::BUSY; + } + return err; + } + + if (from.ss_family == AF_INET) { + struct sockaddr_in* sin_from = (struct sockaddr_in*)&from; + r_ip.set_ipv4(std::bit_cast>(&sin_from->sin_addr)); + r_port = ntohs(sin_from->sin_port); + } else if (from.ss_family == AF_INET6) { + struct sockaddr_in6* s6_from = (struct sockaddr_in6*)&from; + r_ip.set_ipv6(std::bit_cast>(&s6_from->sin6_addr)); + r_port = ntohs(s6_from->sin6_port); + } else { + // Unsupported socket family, should never happen. + OV_ERR_FAIL_V(NetworkError::UNSUPPORTED); + } + + return NetworkError::OK; +} + +NetworkError WindowsSocket::send(const uint8_t* p_buffer, size_t p_len, int64_t& r_sent) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + + int flags = 0; + r_sent = ::send(_sock, (const char*)p_buffer, p_len, flags); + + if (r_sent < 0) { + NetworkError err = _get_socket_error(); + if (err == NetworkError::WOULD_BLOCK) { + return NetworkError::BUSY; + } + return err; + } + + return NetworkError::OK; +} + +NetworkError +WindowsSocket::send_to(const uint8_t* p_buffer, size_t p_len, int64_t& r_sent, IpAddress const& p_ip, port_type p_port) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + + struct sockaddr_storage addr; // NOLINT + size_t addr_size = _set_addr_storage(&addr, p_ip, p_port, _ip_type); + r_sent = ::sendto(_sock, (const char*)p_buffer, p_len, 0, (struct sockaddr*)&addr, addr_size); + + if (r_sent < 0) { + NetworkError err = _get_socket_error(); + if (err == NetworkError::WOULD_BLOCK) { + return NetworkError::BUSY; + } + return err; + } + + return NetworkError::OK; +} + +WindowsSocket* WindowsSocket::_accept(IpAddress& r_ip, port_type& r_port) { + OV_ERR_FAIL_COND_V(!is_open(), nullptr); + + struct sockaddr_storage their_addr; // NOLINT + socklen_t size = sizeof(their_addr); + SOCKET fd = ::accept(_sock, (struct sockaddr*)&their_addr, &size); + if (fd == INVALID_SOCKET) { + NetworkError err = _get_socket_error(); + Logger::info("Error when accepting socket connection. Error: ", strerror(errno)); + return nullptr; + } + + _set_ip_port(&their_addr, &r_ip, &r_port); + + WindowsSocket* ns = new WindowsSocket(); + ns->_sock = fd; + ns->_ip_type = _ip_type; + ns->_is_stream = _is_stream; + ns->set_blocking_enabled(false); + return ns; +} + +bool WindowsSocket::is_open() const { + return _sock != INVALID_SOCKET; +} + +int WindowsSocket::available_bytes() const { + OV_ERR_FAIL_COND_V(!is_open(), -1); + + unsigned long len; + int ret = ioctlsocket(_sock, FIONREAD, &len); + if (ret == -1) { + _get_socket_error(); + Logger::info("Error when checking available bytes on socket. Error: ", strerror(errno)); + return -1; + } + return len; +} + +NetworkError WindowsSocket::get_socket_address(IpAddress* r_ip, port_type* r_port) const { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::NOT_OPEN); + + struct sockaddr_storage saddr; // NOLINT + socklen_t len = sizeof(saddr); + if (getsockname(_sock, (struct sockaddr*)&saddr, &len) != 0) { + NetworkError err = _get_socket_error(); + Logger::info("Error when reading local socket address. Error: ", strerror(errno)); + return err; + } + _set_ip_port(&saddr, r_ip, r_port); + return NetworkError::OK; +} + +NetworkError WindowsSocket::set_broadcasting_enabled(bool p_enabled) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + // IPv6 has no broadcast support. + if (_ip_type == NetworkResolver::Type::IPV6) { + return NetworkError::UNSUPPORTED; + } + + int par = p_enabled ? 1 : 0; + if (setsockopt(_sock, SOL_SOCKET, SO_BROADCAST, (const char*)&par, sizeof(int)) != 0) { + Logger::warning("Unable to change broadcast setting."); + return NetworkError::BROADCAST_CHANGE_FAILED; + } + return NetworkError::OK; +} + +void WindowsSocket::set_blocking_enabled(bool p_enabled) { + OV_ERR_FAIL_COND(!is_open()); + + int ret = 0; + unsigned long par = p_enabled ? 0 : 1; + ret = ioctlsocket(_sock, FIONBIO, &par); + if (ret != 0) { + Logger::warning("Unable to change non-block mode."); + } +} + +void WindowsSocket::set_ipv6_only_enabled(bool p_enabled) { + OV_ERR_FAIL_COND(!is_open()); + // This option is only available in IPv6 sockets. + OV_ERR_FAIL_COND(_ip_type == NetworkResolver::Type::IPV4); + + int par = p_enabled ? 1 : 0; + if (setsockopt(_sock, IPPROTO_IPV6, IPV6_V6ONLY, (const char*)&par, sizeof(int)) != 0) { + Logger::warning("Unable to change IPv4 address mapping over IPv6 option."); + } +} + +void WindowsSocket::set_tcp_no_delay_enabled(bool p_enabled) { + OV_ERR_FAIL_COND(!is_open()); + OV_ERR_FAIL_COND(!_is_stream); // Not TCP. + + int par = p_enabled ? 1 : 0; + if (setsockopt(_sock, IPPROTO_TCP, TCP_NODELAY, (const char*)&par, sizeof(int)) < 0) { + Logger::warning("Unable to set TCP no delay option."); + } +} + +void WindowsSocket::set_reuse_address_enabled(bool p_enabled) { + OV_ERR_FAIL_COND(!is_open()); + + // On Windows, enabling SO_REUSEADDR actually would also enable reuse port, very bad on TCP. Denying... + // Windows does not have this option, SO_REUSEADDR in this magical world means SO_REUSEPORT +} + +NetworkError WindowsSocket::change_multicast_group(IpAddress const& p_multi_address, std::string_view p_if_name, bool p_add) { + OV_ERR_FAIL_COND_V(!is_open(), NetworkError::UNCONFIGURED); + OV_ERR_FAIL_COND_V(!_can_use_ip(p_multi_address, false), NetworkError::INVALID_PARAMETER); + + // Need to force level and af_family to IP(v4) when using dual stacking and provided multicast group is IPv4. + NetworkResolver::Type type = + _ip_type == NetworkResolver::Type::ANY && p_multi_address.is_ipv4() ? NetworkResolver::Type::IPV4 : _ip_type; + // This needs to be the proper level for the multicast group, no matter if the socket is dual stacking. + int level = type == NetworkResolver::Type::IPV4 ? IPPROTO_IP : IPPROTO_IPV6; + int ret = -1; + + IpAddress if_ip; + uint32_t if_v6id = 0; + OpenVic::string_map_t if_info = NetworkResolver::singleton().get_local_interfaces(); + for (std::pair const& pair : if_info) { + NetworkResolver::InterfaceInfo const& c = pair.second; + if (c.name != p_if_name) { + continue; + } + + std::from_chars_result from_chars = + StringUtils::from_chars(c.index.data(), c.index.data() + c.index.size(), if_v6id); + OV_ERR_FAIL_COND_V(from_chars.ec != std::errc {}, NetworkError::SOCKET_ERROR); + if (type == NetworkResolver::Type::IPV6) { + break; // IPv6 uses index. + } + + for (IpAddress const& F : c.ip_addresses) { + if (!F.is_ipv4()) { + continue; // Wrong IP type. + } + if_ip = F; + break; + } + break; + } + + if (level == IPPROTO_IP) { + OV_ERR_FAIL_COND_V(!if_ip.is_valid(), NetworkError::INVALID_PARAMETER); + struct ip_mreq greq; // NOLINT + int sock_opt = p_add ? IP_ADD_MEMBERSHIP : IP_DROP_MEMBERSHIP; + std::memcpy(&greq.imr_multiaddr, p_multi_address.get_ipv4().data(), 4); + std::memcpy(&greq.imr_interface, if_ip.get_ipv4().data(), 4); + ret = setsockopt(_sock, level, sock_opt, (const char*)&greq, sizeof(greq)); + } else { + struct ipv6_mreq greq; // NOLINT + int sock_opt = p_add ? IPV6_ADD_MEMBERSHIP : IPV6_DROP_MEMBERSHIP; + std::memcpy(&greq.ipv6mr_multiaddr, p_multi_address.get_ipv6().data(), 16); + greq.ipv6mr_interface = if_v6id; + ret = setsockopt(_sock, level, sock_opt, (const char*)&greq, sizeof(greq)); + } + + if (ret == -1) { + return _get_socket_error(); + } + + return NetworkError::OK; +} + +size_t WindowsSocket::_set_addr_storage( + sockaddr_storage* p_addr, IpAddress const& p_ip, port_type p_port, NetworkResolver::Type p_ip_type +) { + std::memset(p_addr, 0, sizeof(struct sockaddr_storage)); + if (p_ip_type == NetworkResolver::Type::IPV6 || p_ip_type == NetworkResolver::Type::ANY) { // IPv6 socket. + + // IPv6 only socket with IPv4 address. + OV_ERR_FAIL_COND_V(!p_ip.is_wildcard() && p_ip_type == NetworkResolver::Type::IPV6 && p_ip.is_ipv4(), 0); + + struct sockaddr_in6* addr6 = (struct sockaddr_in6*)p_addr; + addr6->sin6_family = AF_INET6; + addr6->sin6_port = htons(p_port); + if (p_ip.is_valid()) { + std::memcpy(&addr6->sin6_addr.s6_addr, p_ip.get_ipv6().data(), 16); + } else { + addr6->sin6_addr = in6addr_any; + } + return sizeof(sockaddr_in6); + } else { // IPv4 socket. + + // IPv4 socket with IPv6 address. + OV_ERR_FAIL_COND_V(!p_ip.is_wildcard() && !p_ip.is_ipv4(), 0); + + struct sockaddr_in* addr4 = (struct sockaddr_in*)p_addr; + addr4->sin_family = AF_INET; + addr4->sin_port = htons(p_port); // Short, network byte order. + + if (p_ip.is_valid()) { + std::memcpy(&addr4->sin_addr.s_addr, p_ip.get_ipv4().data(), 4); + } else { + addr4->sin_addr.s_addr = INADDR_ANY; + } + + return sizeof(sockaddr_in); + } +} + +#endif diff --git a/src/openvic-simulation/multiplayer/lowlevel/windows/WindowsSocket.hpp b/src/openvic-simulation/multiplayer/lowlevel/windows/WindowsSocket.hpp new file mode 100644 index 000000000..83285caf3 --- /dev/null +++ b/src/openvic-simulation/multiplayer/lowlevel/windows/WindowsSocket.hpp @@ -0,0 +1,75 @@ +#pragma once + +#if !defined(_WIN32) +#error "WindowsSocket.hpp should only be included on windows systems" +#endif + +#include +#include + +#include + +#include "openvic-simulation/multiplayer/lowlevel/IpAddress.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkError.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkResolver.hpp" +#include "openvic-simulation/multiplayer/lowlevel/NetworkSocketBase.hpp" + +struct sockaddr_storage; + +namespace OpenVic { + struct WindowsSocket final : NetworkSocketBase { + static constexpr NetworkResolver::Provider provider_value = NetworkResolver::Provider::WINDOWS; + + WindowsSocket(); + ~WindowsSocket() override; + + NetworkError open(Type p_type, NetworkResolver::Type& ip_type) override; + void close() override; + NetworkError bind(IpAddress const& p_addr, port_type p_port) override; + NetworkError listen(int p_max_pending) override; + NetworkError connect_to_host(IpAddress const& p_addr, port_type p_port) override; + NetworkError poll(PollType p_type, int timeout) const override; + NetworkError receive(uint8_t* p_buffer, size_t p_len, int64_t& r_read) override; + NetworkError receive_from( // + uint8_t* p_buffer, size_t p_len, int64_t& r_read, IpAddress& r_ip, port_type& r_port, bool p_peek = false + ) override; + NetworkError send(const uint8_t* p_buffer, size_t p_len, int64_t& r_sent) override; + NetworkError + send_to(const uint8_t* p_buffer, size_t p_len, int64_t& r_sent, IpAddress const& p_ip, port_type p_port) override; + + bool is_open() const override; + int available_bytes() const override; + NetworkError get_socket_address(IpAddress* r_ip, port_type* r_port) const override; + + // Returns OK if the socket option has been set successfully + NetworkError set_broadcasting_enabled(bool p_enabled) override; + void set_blocking_enabled(bool p_enabled) override; + void set_ipv6_only_enabled(bool p_enabled) override; + void set_tcp_no_delay_enabled(bool p_enabled) override; + void set_reuse_address_enabled(bool p_enabled) override; + NetworkError change_multicast_group(IpAddress const& p_multi_address, std::string_view p_if_name, bool add) override; + + NetworkResolver::Provider provider() const override { + return provider_value; + } + + static void setup(); + static void cleanup(); + + protected: + WindowsSocket* _accept(IpAddress& r_ip, port_type& r_port) override; + + private: + static size_t _set_addr_storage( // + sockaddr_storage* p_addr, IpAddress const& p_ip, port_type p_port, NetworkResolver::Type p_ip_type + ); + static void _set_ip_port(sockaddr_storage* addr, IpAddress* r_ip, port_type* r_port); + + NetworkError _get_socket_error() const; + bool _can_use_ip(IpAddress const& p_ip, const bool p_for_bind) const; + + UINT_PTR _sock = (UINT_PTR)(~0); + NetworkResolver::Type _ip_type = NetworkResolver::Type::NONE; + bool _is_stream = false; + }; +} diff --git a/src/openvic-simulation/player/PlayerManager.cpp b/src/openvic-simulation/player/PlayerManager.cpp index 0f98cc8ff..cd277339a 100644 --- a/src/openvic-simulation/player/PlayerManager.cpp +++ b/src/openvic-simulation/player/PlayerManager.cpp @@ -1,7 +1,20 @@ #include "PlayerManager.hpp" +#include "openvic-simulation/multiplayer/Constants.hpp" +#include "openvic-simulation/utility/Containers.hpp" + using namespace OpenVic; -void PlayerManager::set_country(CountryInstance* instance) { +const Player Player::INVALID_PLAYER {}; + +Player::Player(client_id_type client_id, memory::string name) : client_id { client_id }, name { name } {} + +void Player::set_name(memory::string name) { + player_changed(); + this->name = name; +} + +void Player::set_country(CountryInstance* instance) { + player_changed(); country = instance; } diff --git a/src/openvic-simulation/player/PlayerManager.hpp b/src/openvic-simulation/player/PlayerManager.hpp index b873eb6d0..9dd1d6646 100644 --- a/src/openvic-simulation/player/PlayerManager.hpp +++ b/src/openvic-simulation/player/PlayerManager.hpp @@ -1,17 +1,35 @@ #pragma once +#include "openvic-simulation/multiplayer/Constants.hpp" +#include "openvic-simulation/types/Signal.hpp" #include "openvic-simulation/utility/Containers.hpp" #include "openvic-simulation/utility/Getters.hpp" namespace OpenVic { + struct GameManager; + struct ClientManager; + struct HostSession; struct CountryInstance; - struct PlayerManager { + struct Player { private: + friend struct GameManager; + friend struct ClientManager; + friend struct HostSession; + memory::string PROPERTY(name); CountryInstance* PROPERTY_PTR(country, nullptr); + client_id_type PROPERTY(client_id, MP_INVALID_CLIENT_ID); + + Player() = default; + Player(client_id_type client_id, memory::string name); public: + signal_property player_changed; + + static const Player INVALID_PLAYER; + + void set_name(memory::string name); void set_country(CountryInstance* instance); }; } diff --git a/src/openvic-simulation/types/RingBuffer.hpp b/src/openvic-simulation/types/RingBuffer.hpp new file mode 100644 index 000000000..a85ff293e --- /dev/null +++ b/src/openvic-simulation/types/RingBuffer.hpp @@ -0,0 +1,729 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "openvic-simulation/utility/ErrorMacros.hpp" +#include "openvic-simulation/utility/Utility.hpp" + +namespace OpenVic { + template> + struct RingBuffer { + using allocator_type = Allocator; + using allocator_traits = std::allocator_traits; + using value_type = T; + using size_type = typename allocator_traits::size_type; + using difference_type = typename allocator_traits::difference_type; + using pointer = typename allocator_traits::pointer; + using const_pointer = typename allocator_traits::const_pointer; + using reference = decltype(*pointer {}); + using const_reference = decltype(*const_pointer {}); + + template + struct _iterator { + using difference_type = typename allocator_traits::difference_type; + using size_type = typename allocator_traits::difference_type; + using value_type = typename allocator_traits::value_type; + using pointer = PointerType; + using reference = decltype(*pointer {}); + using iterator_category = std::random_access_iterator_tag; + + constexpr _iterator() noexcept = default; + _iterator(pointer data, const size_type ring_offset, const size_type ring_index, const uint8_t ring_capacity) + : _data(data), _offset(ring_offset), _index(ring_index), _capacity_power(ring_capacity) {} + + constexpr operator _iterator() const { + return _iterator(_data, _offset, _index, _capacity_power); + } + + constexpr reference operator*() const { + return _data[_ring_wrap(_offset + _index, _get_capacity(_capacity_power))]; + } + + constexpr pointer operator->() const noexcept { + return &**this; + } + + _iterator operator++(int) noexcept { + _iterator copy(*this); + operator++(); + return copy; + } + + _iterator& operator++() noexcept { + ++_index; + return *this; + } + + _iterator operator--(int) noexcept { + _iterator copy(*this); + operator--(); + return copy; + } + + _iterator& operator--() noexcept { + --_index; + return *this; + } + + _iterator& operator+=(difference_type n) noexcept { + _index += n; + return *this; + } + + constexpr _iterator operator+(difference_type n) const noexcept { + return _iterator(_data, _offset, _index + n, _capacity_power); + } + + _iterator& operator-=(difference_type n) noexcept { + _index -= n; + return *this; + } + + constexpr _iterator operator-(difference_type n) const noexcept { + assert(n <= _index); + return _iterator(_data, _offset, _index - n, _capacity_power); + } + + constexpr reference operator[](difference_type n) const { + return *(*this + n); + } + + constexpr friend difference_type operator-(_iterator const& lhs, _iterator const& rhs) noexcept { + return lhs._index > rhs._index ? lhs._index - rhs._index : -(rhs._index - lhs._index); + } + + constexpr friend _iterator operator+(difference_type lhs, _iterator const& rhs) noexcept { + return rhs + lhs; + } + + constexpr friend auto operator<=>(_iterator const& lhs, _iterator const& rhs) noexcept { + return std::tie(lhs._data, lhs._offset, lhs._index) <=> std::tie(rhs._data, rhs._offset, rhs._index); + } + + constexpr friend bool operator==(_iterator const& lhs, _iterator const& rhs) noexcept { + return &*lhs == &*rhs; + } + + private: + pointer _data {}; + + // Keeping both _offset and _index around is a little redundant, + // algorithmically, but it makes it much easier to express iterator-mutating + // operations. + + // Physical index of begin(). + size_type _offset {}; + // Logical index of this iterator. + size_type _index {}; + + uint8_t _capacity_power {}; + }; + + using iterator = _iterator; + using const_iterator = _iterator; + using reverse_iterator = std::reverse_iterator; + using const_reverse_iterator = std::reverse_iterator; + + explicit RingBuffer(uint8_t power, allocator_type const& allocator) : _allocator(allocator) { + reserve_power(power); + } + explicit RingBuffer(size_type power = 0) : RingBuffer(power, allocator_type {}) {} + + ~RingBuffer() { + clear(); + _deallocate(); + } + + RingBuffer(RingBuffer const& other) + : RingBuffer(other, allocator_traits::select_on_container_copy_construction(other._allocator)) {} + + RingBuffer(RingBuffer const& other, allocator_type const& allocator) : RingBuffer(other._capacity_power, allocator) { + clear(); + + for (const_reference value : other) { + push_back(value); + } + } + + RingBuffer(RingBuffer&& other) noexcept : RingBuffer(0, std::move(other._allocator)) { + _no_alloc_swap(other); + } + + RingBuffer(RingBuffer&& other, allocator_type const& allocator) : RingBuffer(0, allocator) { + if (other._allocator == allocator) { + _no_alloc_swap(other); + } else { + for (auto& element : other) { + emplace_back(std::move(element)); + } + } + } + + RingBuffer& operator=(RingBuffer const& other) { + clear(); + + if constexpr (typename allocator_traits::propagate_on_container_copy_assignment()) { + _allocator = other._allocator; + } + + for (auto const& value : other) { + push_back(value); + } + return *this; + } + + RingBuffer& operator=(RingBuffer&& other) noexcept( + allocator_traits::propagate_on_container_move_assignment::value || + std::is_nothrow_move_constructible::value + ) { + if (allocator_traits::propagate_on_container_move_assignment::value || _allocator == other._allocator) { + // We're either getting the other's allocator or they're already the same, + // so swap data in one go. + if constexpr (typename allocator_traits::propagate_on_container_move_assignment()) { + std::swap(_allocator, other._allocator); + } + _no_alloc_swap(other); + } else { + // Different allocators and can't swap them, so move elementwise. + clear(); + for (auto& element : other) { + emplace_back(std::move(element)); + } + } + + return *this; + } + + allocator_type get_allocator() const { + return _allocator; + } + reference front() { + return at(0); + } + reference back() { + return at(size() - 1); + } + const_reference back() const { + return at(size() - 1); + } + + const_reference operator[](const size_type index) const { + return _data[_ring_wrap(_offset + index, capacity())]; + } + reference operator[](const size_type index) { + return _data[_ring_wrap(_offset + index, capacity())]; + } + + const_reference at(const size_type index) const { + if (OV_unlikely(index >= size())) { + std::abort(); + } + return (*this)[index]; + } + reference at(const size_type index) { + if (OV_unlikely(index >= size())) { + std::abort(); + } + return (*this)[index]; + } + + iterator begin() noexcept { + return iterator(&_data[0], _offset, 0, _capacity_power); + } + iterator end() noexcept { + return iterator(&_data[0], _offset, size(), _capacity_power); + } + const_iterator begin() const noexcept { + return const_iterator(&_data[0], _offset, 0, _capacity_power); + } + const_iterator end() const noexcept { + return const_iterator(&_data[0], _offset, size(), _capacity_power); + } + + const_iterator cbegin() const noexcept { + return const_cast(*this).begin(); + } + const_iterator cend() const noexcept { + return const_cast(*this).end(); + } + + reverse_iterator rbegin() noexcept { + return reverse_iterator(end()); + } + reverse_iterator rend() noexcept { + return reverse_iterator(begin()); + } + const_reverse_iterator rbegin() const noexcept { + return const_reverse_iterator(end()); + } + const_reverse_iterator rend() const noexcept { + return const_reverse_iterator(begin()); + } + + const_reverse_iterator crbegin() const noexcept { + return const_cast(*this).rbegin(); + } + const_reverse_iterator crend() const noexcept { + return const_cast(*this).rend(); + } + + bool empty() const noexcept { + return size() == 0; + } + + size_type size() const noexcept { + return _size; + } + + size_type capacity() const noexcept { + return _get_capacity(_capacity_power); + } + + size_type max_size() const noexcept { + return std::min(allocator_traits::max_size(_allocator), std::numeric_limits::max() / sizeof(value_type)); + } + + size_type space() const noexcept { + return capacity() - size(); + } + + void reserve_power(uint8_t power) { + // Will cause overflow + OV_ERR_FAIL_COND(power > 64); + + if (OV_unlikely(_capacity_power != 0) && power <= _capacity_power) { + return; + } + + size_type new_capacity = _get_capacity(power); + pointer result = _allocate(new_capacity); + if (!empty()) { + pointer last = std::uninitialized_copy_n(_data + _offset, capacity() + 1 - _offset, result); + if (_offset > size()) { + std::uninitialized_copy_n(_data, _offset - size(), last); + } + static_assert(std::is_destructible_v, "value type is destructible"); + if constexpr (!std::is_trivially_destructible_v) { + for (pointer first = _data; first != _data + _size; first++) { + allocator_traits::destroy(_allocator, first); + } + } + } + _deallocate(); + _offset = 0; + _next = size(); + _data = result; + _capacity_power = power; + } + + void shrink_to_fit() { + if (_capacity_power > 0 && empty()) { + _deallocate(); + _data = _allocate(0); + _size = 0; + _offset = 0; + _next = 0; + _capacity_power = 0; + return; + } + + size_type new_capacity = std::bit_ceil(size()); + pointer result = _allocate(new_capacity); + pointer last = std::uninitialized_copy_n(_data + _offset, capacity() + 1 - _offset, result); + if (_offset > size()) { + std::uninitialized_copy_n(_data, _offset - size(), last); + } + static_assert(std::is_destructible_v, "value type is destructible"); + if constexpr (!std::is_trivially_destructible_v) { + for (pointer first = _data; first != _data + _size; first++) { + allocator_traits::destroy(_allocator, first); + } + } + _deallocate(); + _offset = 0; + _next = size(); + _data = result; + _capacity_power = std::bit_width(new_capacity) - 1; + } + + void push_front(const_reference value) { + emplace_front(value); + } + void push_front(value_type&& value) { + emplace_front(std::move(value)); + } + + template + reference emplace_front(Args&&... args) { + if (capacity() == 0) { + // A buffer of size zero is conceptually sound, so let's support it. + return (*this)[0]; + } + + allocator_traits::construct(_allocator, &_data[_decrement(_offset)], std::forward(args)...); + + // If required, make room for next time. + if (size() == capacity()) { + pop_back(); + } + _grow_front(); + return (*this)[0]; + } + + void push_back(const_reference value) { + emplace_back(value); + } + void push_back(value_type&& value) { + emplace_back(std::move(value)); + } + + template + reference emplace_back(Args&&... args) { + if (capacity() == 0) { + // A buffer of size zero is conceptually sound, so let's support it. + return (*this)[0]; + } + + allocator_traits::construct(_allocator, &_data[_next], std::forward(args)...); + + // If required, make room for next time. + if (size() == capacity()) { + pop_front(); + } + _grow_back(); + return (*this)[size() - 1]; + } + + template + iterator append(InputIt first, InputIt last) { + using distance_type = typename std::iterator_traits::difference_type; + + const size_type _capacity = capacity(); + const distance_type distance = std::distance(first, last); + if (OV_unlikely(distance <= 0)) { + return end(); + } + + // Limit the number of elements to append at _capacity + const size_type append_count = std::min(distance, _capacity); + + // If appending would exceed capacity, remove elements from front + const size_type excess = size() + append_count > _capacity ? size() + append_count - _capacity : 0; + if (excess > 0) { + // Destroy elements that will be overwritten + if constexpr (!std::is_trivially_destructible_v) { + size_type pos = _offset; + for (size_type i = 0; i < excess; ++i) { + allocator_traits::destroy(_allocator, _data + pos); + pos = _ring_wrap(pos + 1, _capacity); + } + } + _offset = _ring_wrap(_offset + excess, _capacity); + const size_type overflow_check = _size - excess; + if (OV_unlikely(overflow_check > _size)) { + _size = 0; + } else { + _size = overflow_check; + } + } + + // Determine physical position for appending + const size_type write_pos = _ring_wrap(_offset + _size, _capacity); + size_type space_to_end = _capacity - write_pos; + + iterator result = iterator(_data, _offset, _size, _capacity_power); + if (OV_likely(append_count <= space_to_end)) { + // Single copy to the back + std::uninitialized_copy_n(first, append_count, _data + write_pos); + } else { + // Split copy: part to the end, part to the beginning + ++space_to_end; + InputIt mid = first; + std::advance(mid, space_to_end); + std::uninitialized_copy(first, mid, _data + write_pos); + std::uninitialized_copy_n(mid, append_count - space_to_end, _data); + } + + _size += append_count; + _next = _ring_wrap(_offset + _size, _capacity); + return result; + } + template + iterator append(InputIt first, size_type count) { + return append(first, first + count); + } + + template + iterator append_range(Range&& range, size_type write_size = std::numeric_limits::max()) { + if (write_size < capacity()) { + auto end = ranges::begin(range); + ranges::advance(end, std::min(ranges::distance(range), write_size)); + return append(ranges::begin(range), end); + } + return append(ranges::begin(range), ranges::end(range)); + } + + value_type read() { + if (empty()) { + if constexpr (std::is_default_constructible_v) { + return {}; + } + std::abort(); + } + value_type result = (*this)[0]; + pop_front(); + return result; + } + + std::span read_buffer_to( // + pointer r_buffer, size_type read_size = std::numeric_limits::max() + ) { + read_size = std::min(read_size, capacity()); + iterator last = begin(); + ranges::advance(last, read_size); + ranges::move(std::make_move_iterator(begin()), std::make_move_iterator(last), r_buffer); + erase(begin(), read_size); + return std::span { r_buffer, read_size }; + } + + memory::vector read_buffer(size_type read_size) { + read_size = std::min(read_size, capacity()); + iterator last = begin(); + ranges::advance(last, read_size); + memory::vector result { std::make_move_iterator(begin()), std::make_move_iterator(last) }; + erase(begin(), read_size); + return result; + } + + + ranges::subrange read_range(size_type read_size) { + read_size = std::min(read_size, capacity()); + iterator last = begin(); + ranges::advance(last, read_size); + return { begin(), last }; + } + + void pop_front() noexcept { + if (empty()) { + return; + } + + if constexpr (!std::is_trivially_destructible_v) { + allocator_traits::destroy(_allocator, &_data[_offset]); + } + _shrink_front(); + } + void pop_back() noexcept { + if (empty()) { + return; + } + + _shrink_back(); + if constexpr (!std::is_trivially_destructible_v) { + allocator_traits::destroy(_allocator, &_data[_next]); + } + } + + void clear() noexcept { + if constexpr (!std::is_trivially_destructible_v) { + if (empty()) { + return; + } + for (pointer first = _data; first != _data + _size; first++) { + allocator_traits::destroy(_allocator, first); + } + } + _size = 0; + _offset = 0; + _next = 0; + } + + iterator erase(const_iterator from, const_iterator to) noexcept( + noexcept(pop_front()) && std::is_nothrow_move_assignable::value + ) { + if (OV_unlikely(from > end() || to > end())) { + return std::bit_cast(from); + } + + if (from == to) { + return iterator(_data, _offset, from - begin(), _capacity_power); + } + + const difference_type erase_count = to - from; + if (erase_count == 0) { + return iterator(_data, _offset, from - begin(), _capacity_power); + } + + const difference_type leading = from - begin(); + const difference_type trailing = end() - to; + const size_type _capacity = capacity(); + + iterator result = iterator(_data, _offset, from - begin(), _capacity_power); + + if (leading <= trailing) { + // Shift elements from the front towards the erasure point + const size_type dest_pos = _ring_wrap(_offset, _capacity); + const size_type src_pos = _ring_wrap(_offset + erase_count, _capacity); + size_type count = leading; + + // Move elements in one or two segments depending on wrap-around + if (dest_pos <= src_pos || src_pos + leading <= _capacity + 1) { + std::move(_data + src_pos, _data + src_pos + leading, _data + dest_pos); + } else { + size_type first_segment = _capacity + 1 - src_pos; + std::move(_data + src_pos, _data + src_pos + first_segment, _data + dest_pos); + std::move(_data, _data + (leading - first_segment), _data + dest_pos + first_segment); + } + + // Destroy elements at the front + if constexpr (!std::is_trivially_destructible_v) { + for (size_type i = 0; i < erase_count; ++i) { + allocator_traits::destroy(_allocator, _data + _offset); + _offset = _ring_wrap(_offset + 1, _capacity); + } + } else { + _offset = _ring_wrap(_offset + erase_count, _capacity); + } + _size -= erase_count; + } else if (trailing >= 0) { + // Shift elements from the back towards the erasure point + const size_type dest_pos = _ring_wrap(_offset + leading, _capacity); + const size_type src_pos = _ring_wrap(_offset + leading + erase_count, _capacity); + + // Move elements in one or two segments depending on wrap-around + if (dest_pos <= src_pos || src_pos + trailing <= _capacity + 1) { + std::move(_data + src_pos, _data + src_pos + trailing, _data + dest_pos); + } else { + const size_type first_segment = _capacity + 1 - src_pos; + std::move(_data + src_pos, _data + src_pos + first_segment, _data + dest_pos); + std::move(_data, _data + (trailing - first_segment), _data + dest_pos + first_segment); + } + + // Destroy elements at the back + if constexpr (!std::is_trivially_destructible_v) { + for (size_type i = 0; i < erase_count; ++i) { + _next = _ring_wrap(_next - 1, _capacity); + allocator_traits::destroy(_allocator, _data + _next); + } + } else { + _next = _next - erase_count; + } + _size += erase_count; + _offset -= erase_count + 1; + } + + return result; + } + iterator erase(const_iterator pos, size_type count) noexcept(noexcept(erase(pos, pos + count))) { + const_iterator last = pos; + std::advance(last, count); + return erase(pos, last); + } + iterator erase(const_iterator pos) noexcept(noexcept(erase(pos, 1))) { + return erase(pos, 1); + } + + void swap(RingBuffer& other) noexcept { + if constexpr (typename allocator_traits::propagate_on_container_swap()) { + std::swap(_allocator, other._allocator); + } + _no_alloc_swap(other); + } + + friend auto operator<=>(RingBuffer const& lhs, RingBuffer const& rhs) { + return std::lexicographical_compare_three_way(lhs.begin(), lhs.end(), rhs.begin(), rhs.end()); + } + + friend bool operator==(RingBuffer const& lhs, RingBuffer const& rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + + return std::equal(lhs.begin(), lhs.end(), rhs.begin()); + } + + private: + void _no_alloc_swap(RingBuffer& other) { + std::swap(_data, other._data); + std::swap(_next, other._next); + std::swap(_offset, other._offset); + std::swap(_size, other._size); + std::swap(_capacity_power, other._capacity_power); + } + + size_type _increment(size_type& index, size_type amount = 1) { + return index = (index + amount - 1U < capacity() ? index + amount : 0); + } + + size_type _decrement(size_type& index, size_type amount = 1) { + return index = ((index - amount) + 1U > 0 ? index - amount : capacity()); + } + + void _grow_front() { + _decrement(_offset); + ++_size; + } + + void _grow_back() { + _increment(_next); + ++_size; + } + + void _shrink_front() { + _increment(_offset); + --_size; + } + + void _shrink_back() { + _decrement(_next); + --_size; + } + + pointer _allocate(const size_type new_capacity) { + return allocator_traits::allocate(_allocator, new_capacity + 1); + } + + void _deallocate() { + allocator_traits::deallocate(_allocator, _data, capacity() + 1); + } + + static constexpr size_type _ring_wrap(const size_type ring_index, const size_type ring_capacity) { + return (ring_index <= ring_capacity) ? ring_index : ring_index - ring_capacity - 1; + } + + static constexpr size_type _get_capacity(const uint8_t capacity_power) { + return (1U << static_cast(capacity_power)); + } + + // The start of the dynamically allocated backing array. + pointer _data = nullptr; + // The next position to write to for push_back(). + size_type _next = 0U; + + // Start of the ring buffer in data_. + size_type _offset = 0U; + // The number of elements in the ring buffer (distance between begin() and + // end()). + size_type _size = 0U; + // The power used to calculate the capacity of the ring buffer + uint8_t _capacity_power = 0U; + + // The allocator is used to allocate memory, and to construct and destroy + // elements. + [[no_unique_address]] allocator_type _allocator {}; + }; +} diff --git a/src/openvic-simulation/utility/ErrorMacros.hpp b/src/openvic-simulation/utility/ErrorMacros.hpp index c8ffb7a31..2cd2d7b19 100644 --- a/src/openvic-simulation/utility/ErrorMacros.hpp +++ b/src/openvic-simulation/utility/ErrorMacros.hpp @@ -3,10 +3,50 @@ #include "openvic-simulation/utility/Logger.hpp" // IWYU pragma: keep for macros #include "openvic-simulation/utility/Utility.hpp" +/** + * Try using `ERR_FAIL_COND_MSG`. + * Only use this macro if more complex error detection or recovery is required. + * + * Prints `m_msg`, and the current function returns. + */ +#define OV_ERR_FAIL_MSG(m_msg) \ + if (true) { \ + ::OpenVic::Logger::error("Method/function failed. ", m_msg); \ + return; \ + } else \ + ((void)0) + +/** + * Try using `ERR_FAIL_COND_V_MSG` or `ERR_FAIL_V_MSG`. + * Only use this macro if more complex error detection or recovery is required, and + * there is no sensible error message. + * + * The current function returns `m_retval`. + */ +#define OV_ERR_FAIL_V(m_retval) \ + if (true) { \ + ::OpenVic::Logger::error("Method/function failed. Returning: " _OV_STR(m_retval)); \ + return m_retval; \ + } else \ + ((void)0) + +/** + * Try using `ERR_FAIL_COND_V_MSG`. + * Only use this macro if more complex error detection or recovery is required. + * + * Prints `m_msg`, and the current function returns `m_retval`. + */ +#define OV_ERR_FAIL_V_MSG(m_retval, m_msg) \ + if (true) { \ + ::OpenVic::Logger::error("Method/function failed. Returning: " _OV_STR(m_retval), " ", m_msg); \ + return m_retval; \ + } else \ + ((void)0) + /** * Try using `ERR_FAIL_COND_MSG`. * Only use this macro if there is no sensible error message. - * If checking for null use ERR_FAIL_NULL_MSG instead. + * If checking for null use OV_ERR_FAIL_NULL_MSG instead. * If checking index bounds use ERR_FAIL_INDEX_MSG instead. * * Ensures `m_cond` is false. @@ -23,7 +63,7 @@ * Ensures `m_cond` is false. * If `m_cond` is true, prints `m_msg` and the current function returns. * - * If checking for null use ERR_FAIL_NULL_MSG instead. + * If checking for null use OV_ERR_FAIL_NULL_MSG instead. * If checking index bounds use ERR_FAIL_INDEX_MSG instead. */ #define OV_ERR_FAIL_COND_MSG(m_cond, m_msg) \ @@ -36,7 +76,7 @@ /** * Try using `ERR_FAIL_COND_V_MSG`. * Only use this macro if there is no sensible error message. - * If checking for null use ERR_FAIL_NULL_V_MSG instead. + * If checking for null use OV_ERR_FAIL_NULL_V_MSG instead. * If checking index bounds use ERR_FAIL_INDEX_V_MSG instead. * * Ensures `m_cond` is false. @@ -44,9 +84,7 @@ */ #define OV_ERR_FAIL_COND_V(m_cond, m_retval) \ if (OV_unlikely(m_cond)) { \ - ::OpenVic::Logger::error( \ - "Condition \"" _OV_STR(m_cond) "\" is true. Returning: " _OV_STR(m_retval) \ - ); \ + ::OpenVic::Logger::error("Condition \"" _OV_STR(m_cond) "\" is true. Returning: " _OV_STR(m_retval)); \ return m_retval; \ } else \ ((void)0) @@ -55,14 +93,136 @@ * Ensures `m_cond` is false. * If `m_cond` is true, prints `m_msg` and the current function returns `m_retval`. * - * If checking for null use ERR_FAIL_NULL_V_MSG instead. + * If checking for null use OV_ERR_FAIL_NULL_V_MSG instead. * If checking index bounds use ERR_FAIL_INDEX_V_MSG instead. */ #define OV_ERR_FAIL_COND_V_MSG(m_cond, m_retval, m_msg) \ if (OV_unlikely(m_cond)) { \ + ::OpenVic::Logger::error("Condition \"" _OV_STR(m_cond) "\" is true. Returning: " _OV_STR(m_retval) " ", m_msg); \ + return m_retval; \ + } else \ + ((void)0) + +/** + * Try using `ERR_FAIL_INDEX_MSG`. + * Only use this macro if there is no sensible error message. + * + * Ensures an integer index `m_index` is less than `m_size` and greater than or equal to 0. + * If not, the current function returns. + */ +#define OV_ERR_FAIL_INDEX(m_index, m_size) \ + if (OV_unlikely((m_index) < 0 || (m_index) >= (m_size))) { \ + ::OpenVic::Logger::error( \ + "Index " _OV_STR(m_index) " = ", m_size, " is out of bounds (" _OV_STR(m_size) " = ", m_size, ")." \ + ); \ + return; \ + } else \ + ((void)0) + +/** + * Ensures an integer index `m_index` is less than `m_size` and greater than or equal to 0. + * If not, prints `m_msg` and the current function returns. + */ +#define OV_ERR_FAIL_INDEX_MSG(m_index, m_size, m_msg) \ + if (OV_unlikely((m_index) < 0 || (m_index) >= (m_size))) { \ ::OpenVic::Logger::error( \ - "Condition \"" _OV_STR(m_cond) "\" is true. Returning: " _OV_STR(m_retval) " ", m_msg \ + "Index " _OV_STR(m_index) " = ", m_size, " is out of bounds (" _OV_STR(m_size) " = ", m_size, "). ", m_msg \ ); \ + return; \ + } else \ + ((void)0) + +/** + * Try using `ERR_FAIL_INDEX_V_MSG`. + * Only use this macro if there is no sensible error message. + * + * Ensures an integer index `m_index` is less than `m_size` and greater than or equal to 0. + * If not, the current function returns `m_retval`. + */ +#define OV_ERR_FAIL_INDEX_V(m_index, m_size, m_retval) \ + if (OV_unlikely((m_index) < 0 || (m_index) >= (m_size))) { \ + ::OpenVic::Logger::error( \ + "Index " _OV_STR(m_index) " = ", m_size, " is out of bounds (" _OV_STR(m_size) " = ", m_size, \ + "). Returning: " _OV_STR(m_retval) \ + ); \ + return m_retval; \ + } else \ + ((void)0) + +/** + * Ensures an integer index `m_index` is less than `m_size` and greater than or equal to 0. + * If not, prints `m_msg` and the current function returns `m_retval`. + */ +#define OV_ERR_FAIL_INDEX_V_MSG(m_index, m_size, m_retval, m_msg) \ + if (OV_unlikely((m_index) < 0 || (m_index) >= (m_size))) { \ + ::OpenVic::Logger::error( \ + "Index " _OV_STR(m_index) " = ", m_size, " is out of bounds (" _OV_STR(m_size) " = ", m_size, \ + "). Returning: " _OV_STR(m_retval) " ", m_msg \ + ); \ + return m_retval; \ + } else \ + ((void)0) + +#define OV_ERR_CONTINUE(m_cond) \ + if (OV_unlikely(m_cond)) { \ + ::OpenVic::Logger::error("Condition \"" _OV_STR(m_cond) "\" is true. Continuing."); \ + continue; \ + } else \ + ((void)0) + +#define OV_ERR_BREAK(m_cond) \ + if (OV_unlikely(m_cond)) { \ + ::OpenVic::Logger::error("Condition \"" _OV_STR(m_cond) "\" is true. Breaking."); \ + break; \ + } else \ + ((void)0) + +/** + * Try using `ERR_FAIL_NULL_MSG`. + * Only use this macro if there is no sensible error message. + * + * Ensures a pointer `m_param` is not null. + * If it is null, the current function returns. + */ +#define OV_ERR_FAIL_NULL(m_param) \ + if (OV_unlikely(m_param == nullptr)) { \ + ::OpenVic::Logger::error("Parameter \"" _OV_STR(m_param) "\" is null."); \ + return; \ + } else \ + ((void)0) + +/** + * Ensures a pointer `m_param` is not null. + * If it is null, prints `m_msg` and the current function returns. + */ +#define OV_ERR_FAIL_NULL_MSG(m_param, m_msg) \ + if (OV_unlikely(m_param == nullptr)) { \ + ::OpenVic::Logger::error("Parameter \"" _OV_STR(m_param) "\" is null.", m_msg); \ + return; \ + } else \ + ((void)0) + +/** + * Try using `ERR_FAIL_NULL_V_MSG`. + * Only use this macro if there is no sensible error message. + * + * Ensures a pointer `m_param` is not null. + * If it is null, the current function returns `m_retval`. + */ +#define OV_ERR_FAIL_NULL_V(m_param, m_retval) \ + if (OV_unlikely(m_param == nullptr)) { \ + ::OpenVic::Logger::error("Parameter \"" _OV_STR(m_param) "\" is null."); \ + return m_retval; \ + } else \ + ((void)0) + +/** + * Ensures a pointer `m_param` is not null. + * If it is null, prints `m_msg` and the current function returns `m_retval`. + */ +#define OV_ERR_FAIL_NULL_V_MSG(m_param, m_retval, m_msg) \ + if (OV_unlikely(m_param == nullptr)) { \ + ::OpenVic::Logger::error("Parameter \"" _OV_STR(m_param) "\" is null.", m_msg); \ return m_retval; \ } else \ ((void)0) diff --git a/src/openvic-simulation/utility/Marshal.hpp b/src/openvic-simulation/utility/Marshal.hpp new file mode 100644 index 000000000..e064da459 --- /dev/null +++ b/src/openvic-simulation/utility/Marshal.hpp @@ -0,0 +1,574 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "openvic-simulation/types/Colour.hpp" +#include "openvic-simulation/types/Date.hpp" +#include "openvic-simulation/types/Vector.hpp" +#include "openvic-simulation/types/fixed_point/FixedPoint.hpp" +#include "openvic-simulation/utility/Deque.hpp" +#include "openvic-simulation/utility/ErrorMacros.hpp" +#include "openvic-simulation/utility/Utility.hpp" + +namespace OpenVic::utility { + OV_ALWAYS_INLINE uint32_t halfbits_to_floatbits(uint16_t p_half) { + uint16_t h_exp, h_sig; + uint32_t f_sgn, f_exp, f_sig; + + h_exp = (p_half & 0x7c00u); + f_sgn = ((uint32_t)p_half & 0x8000u) << 16; + switch (h_exp) { + case 0x0000u: /* 0 or subnormal */ + h_sig = (p_half & 0x03ffu); + /* Signed zero */ + if (h_sig == 0) { + return f_sgn; + } + /* Subnormal */ + h_sig <<= 1; + while ((h_sig & 0x0400u) == 0) { + h_sig <<= 1; + h_exp++; + } + f_exp = ((uint32_t)(127 - 15 - h_exp)) << 23; + f_sig = ((uint32_t)(h_sig & 0x03ffu)) << 13; + return f_sgn + f_exp + f_sig; + case 0x7c00u: /* inf or NaN */ + /* All-ones exponent and a copy of the significand */ + return f_sgn + 0x7f800000u + (((uint32_t)(p_half & 0x03ffu)) << 13); + default: /* normalized */ + /* Just need to adjust the exponent and shift */ + return f_sgn + (((uint32_t)(p_half & 0x7fffu) + 0x1c000u) << 13); + } + } + + OV_ALWAYS_INLINE float halfptr_to_float(const uint16_t* p_half) { + return std::bit_cast(halfbits_to_floatbits(*p_half)); + } + + OV_ALWAYS_INLINE float half_to_float(const uint16_t p_half) { + return halfptr_to_float(&p_half); + } + + OV_ALWAYS_INLINE uint16_t make_half_float(float p_value) { + uint32_t x = std::bit_cast(p_value); + uint32_t sign = (unsigned short)(x >> 31); + uint32_t mantissa; + uint32_t exponent; + uint16_t hf; + + // get mantissa + mantissa = x & ((1 << 23) - 1); + // get exponent bits + exponent = x & (0xFF << 23); + if (exponent >= 0x47800000) { + // check if the original single precision float number is a NaN + if (mantissa && (exponent == (0xFF << 23))) { + // we have a single precision NaN + mantissa = (1 << 23) - 1; + } else { + // 16-bit half-float representation stores number as Inf + mantissa = 0; + } + hf = (((uint16_t)sign) << 15) | (uint16_t)((0x1F << 10)) | (uint16_t)(mantissa >> 13); + } + // check if exponent is <= -15 + else if (exponent <= 0x38000000) { + /* + // store a denorm half-float value or zero + exponent = (0x38000000 - exponent) >> 23; + mantissa >>= (14 + exponent); + + hf = (((uint16_t)sign) << 15) | (uint16_t)(mantissa); + */ + hf = 0; // denormals do not work for 3D, convert to zero + } else { + hf = (((uint16_t)sign) << 15) | (uint16_t)((exponent - 0x38000000) >> 13) | (uint16_t)(mantissa >> 13); + } + + return hf; + } + + struct half_float { + float value; + constexpr half_float() = default; + constexpr half_float(float value) : value(value) {} + + constexpr operator float&() { + return value; + } + + constexpr operator float const&() const { + return value; + } + + constexpr operator bool() const { + return value; + } + + constexpr half_float operator+() const { + return +value; + } + + constexpr half_float operator-() const { + return -value; + } + + constexpr half_float operator!() const { + return !value; + } + + constexpr half_float operator+(half_float const& rhs) const { + return value + rhs.value; + } + + constexpr half_float operator-(half_float const& rhs) const { + return value - rhs.value; + } + + constexpr half_float operator*(half_float const& rhs) const { + return value * rhs.value; + } + + constexpr half_float operator/(half_float const& rhs) const { + return value / rhs.value; + } + + constexpr half_float& operator+=(half_float const& rhs) { + value += rhs.value; + return *this; + } + + constexpr half_float& operator-=(half_float const& rhs) { + value -= rhs.value; + return *this; + } + + constexpr half_float& operator*=(half_float const& rhs) { + value *= rhs.value; + return *this; + } + + constexpr half_float& operator/=(half_float const& rhs) { + value /= rhs.value; + return *this; + } + + constexpr bool operator==(half_float const& rhs) const { + return value == rhs.value; + } + + constexpr auto operator<=>(half_float const& rhs) const { + return value <=> rhs.value; + } + }; + + static_assert(std::endian::native != std::endian::big || std::endian::native != std::endian::little, "HOW?!?!?!"); + + template + concept HasEncode = requires(T const& value, std::span span) { + { value.template encode(span) } -> std::same_as; + }; + + template + concept Encodable = + HasEncode || utility::specialization_of || std::integral || std::floating_point || + std::is_enum_v || std::convertible_to || std::same_as || + std::same_as || std::same_as || IsColour || std::same_as || + std::convertible_to> || utility::specialization_of || + utility::specialization_of || utility::specialization_of_span || + utility::specialization_of || utility::specialization_of || + utility::specialization_of || utility::specialization_of || + utility::specialization_of || (std::is_empty_v && std::is_default_constructible_v); + + template + static constexpr std::integral_constant endian_tag {}; + + template + requires Encodable || Encodable> + static inline size_t encode( // + T const& value, std::span span = {}, std::integral_constant endian = {} + ) { + if constexpr (std::is_const_v || std::is_volatile_v || std::is_reference_v) { + return encode, Endian>(value, span); + } else if constexpr (HasEncode) { + return value.template encode(span); + } else if constexpr (utility::specialization_of) { + const size_t index = value.index(); + const size_t index_size = encode(index); + const size_t size = // + std::visit( + [](auto& arg) -> size_t { + return encode(arg); + }, + value + ) + + index_size; + + if (span.empty()) { + return size; + } else if (span.size() < size) { + return 0; + } + + size_t offset = encode(index, span); + std::visit( + [&](auto& arg) { + encode(arg, span.subspan(offset)); + }, + value + ); + + return size; + } else if constexpr (std::same_as) { + return encode(static_cast(value), span); + } else if constexpr (std::integral) { + using unsigned_type = std::make_unsigned_t; + + if (!span.empty()) { + if constexpr (sizeof(T) == 1) { + *span.data() = *std::bit_cast(&value); + } else { + T copy; + if constexpr (std::endian::native != Endian) { + copy = std::bit_cast(utility::byteswap(std::bit_cast(value))); + } else { + copy = value; + } + decltype(span)::pointer ptr = span.data(); + for (size_t i = 0; i < sizeof(T); i++) { + *ptr = *std::bit_cast(©) & 0xFF; + ptr++; + *std::bit_cast(©) >>= 8; + } + } + return span.size() >= sizeof(T) ? sizeof(T) : 0; + } + + return sizeof(T); + } else if constexpr (std::floating_point) { + using int_type = std::conditional_t, uint32_t, uint64_t>; + + return encode(std::bit_cast(value), span); + } else if constexpr (std::is_enum_v) { + using underlying_type = std::underlying_type_t; + + return encode(std::bit_cast(value), span); + } else if constexpr (std::convertible_to) { + const std::string_view str = static_cast(value); + const size_t str_size = str.size(); + const size_t length_size = encode(str_size); + const size_t size = str_size + length_size; + + if (span.empty()) { + return size; + } else if (span.size() < size) { + return 0; + } + + uint8_t* ptr = span.data() + encode(str_size, span); + std::copy_n(reinterpret_cast(str.data()), str_size, ptr); + + return size; + } else if constexpr (std::same_as) { + return encode(make_half_float(value.value), span); + } else if constexpr (std::same_as) { + return encode(value.to_int(), span); + } else if constexpr (std::same_as) { + return encode(value.get_timespan(), span); + } else if constexpr (IsColour) { + return encode(static_cast(value), span); + } else if constexpr (std::same_as) { + return encode(value.get_raw_value(), span); + } else if constexpr (std::convertible_to>) { + const std::span inner_span = static_cast>(value); + const size_t span_size = inner_span.size(); + const size_t length_size = encode(span_size); + const size_t size = span_size + length_size; + + if (span.empty()) { + return size; + } else if (span.size() < size) { + return 0; + } + + uint8_t* ptr = span.data() + encode(span_size, span); + std::copy_n(inner_span.data(), span_size, ptr); + + return size; + } else if constexpr (utility::specialization_of || + utility::specialization_of || utility::specialization_of_span) { + const size_t collection_size = value.size(); + const size_t length_size = encode(collection_size); + const size_t size = + collection_size * encode(value.empty() ? typename T::value_type() : value[0]) + + length_size; + + if (span.empty()) { + return size; + } else if (span.size() < size) { + return 0; + } + + size_t offset = encode(collection_size, span); + for (auto const& element : value) { + offset += encode(element, span.subspan(offset), endian_tag); + } + + return size; + } else if constexpr (utility::specialization_of || utility::specialization_of) { + static size_t size = std::apply( + [](auto&&... args) { + return (encode(args) + ...); + }, + value + ); + + if (span.empty()) { + return size; + } + + size_t offset = 0; + std::apply( + [&](auto&&... args) { + (( // + offset += encode(args, span.subspan(offset)) + ), + ...); + }, + value + ); + + return size; + } else if constexpr (utility::specialization_of) { + if (span.empty()) { + return encode(value.x) + encode(value.y); + } + + size_t offset = encode(value.x, span); + return offset + encode(value.y, span.subspan(offset)); + } else if constexpr (utility::specialization_of) { + if (span.empty()) { + return encode(value.x) + encode(value.y) + + encode(value.z); + } + + size_t offset = encode(value.x, span); + offset += encode(value.y, span.subspan(offset)); + return offset + encode(value.z, span.subspan(offset)); + } else if constexpr (utility::specialization_of) { + if (span.empty()) { + return encode(value.x) + encode(value.y) + + encode(value.z) + encode(value.w); + } + + size_t offset = encode(value.x, span); + offset += encode(value.y, span.subspan(offset)); + offset += encode(value.z, span.subspan(offset)); + return offset + encode(value.w, span.subspan(offset)); + } else if constexpr (std::is_empty_v && std::is_default_constructible_v) { + return 0; + } + } + + template + using index_t = std::integral_constant; + + template + constexpr index_t index_v {}; + + template + using indexes_t = std::variant...>; + + template + constexpr indexes_t get_index(std::index_sequence, std::size_t I) { + constexpr indexes_t retvals[] = { index_v... }; + return retvals[I]; + } + + template + constexpr auto get_index(std::size_t I) { + return get_index(std::make_index_sequence {}, I); + } + + template + concept HasDecode = requires(std::span span, size_t& r_decode_count) { + { T::template decode(span, r_decode_count) } -> std::same_as; + }; + + template + concept Decodable = HasDecode || utility::specialization_of || std::integral || + std::floating_point || std::is_enum_v || std::convertible_to || + std::is_constructible_v || std::same_as || std::same_as || + std::same_as || IsColour || std::same_as || + std::is_constructible_v || utility::specialization_of || + utility::specialization_of || utility::specialization_of || + utility::specialization_of || utility::specialization_of_std_array_of || + utility::specialization_of || utility::specialization_of || + utility::specialization_of || (std::is_empty_v && std::is_default_constructible_v); + + template + requires Decodable || Decodable> + static inline T decode( // + std::span span, size_t& r_decode_count, std::integral_constant endian = {} + ) { + if constexpr (std::is_const_v || std::is_volatile_v) { + return decode, Endian>(span, r_decode_count); + } else if constexpr (HasDecode) { + return T::template decode(span, r_decode_count); + } else if constexpr (utility::specialization_of) { + auto index = get_index>(decode(span, r_decode_count)); + size_t offset = r_decode_count; + T value = std::visit( + [&](auto INDEX) -> T { + using assigned_type = std::variant_alternative_t; + return decode(span.subspan(offset), r_decode_count); + }, + index + ); + r_decode_count = offset + r_decode_count; + return value; + } else if constexpr (std::same_as) { + return static_cast(decode(span, r_decode_count)); + } else if constexpr (std::integral) { + using unsigned_type = std::make_unsigned_t; + + r_decode_count = sizeof(T); + if (OV_unlikely(span.size() < r_decode_count)) { + r_decode_count = 0; + OV_ERR_FAIL_V({}); + } + + if constexpr (sizeof(T) == 1) { + return std::bit_cast(span[0]); + } else { + decltype(span)::pointer ptr = span.data(); + T u = 0; + for (int i = 0; i < r_decode_count; i++) { + T b = *ptr; + b <<= (i * 8); + u |= b; + ptr++; + } + if constexpr (std::endian::native != Endian) { + u = utility::byteswap(u); + } + return u; + } + } else if constexpr (std::floating_point) { + using int_type = std::conditional_t, uint32_t, uint64_t>; + + return std::bit_cast(decode(span, r_decode_count)); + } else if constexpr (std::is_enum_v) { + using underlying_type = std::underlying_type_t; + + return std::bit_cast(decode(span, r_decode_count)); + } else if constexpr (std::convertible_to) { + uint32_t size = decode(span, r_decode_count); + OV_ERR_FAIL_INDEX_V(size, span.size(), {}); + std::string_view view { reinterpret_cast(span.data() + r_decode_count), size }; + r_decode_count += size; + return static_cast(view); + } else if constexpr (std::is_constructible_v) { + return T { decode(span, r_decode_count) }; + } else if constexpr (std::same_as) { + return { half_to_float(decode(span, r_decode_count)) }; + } else if constexpr (std::same_as) { + return T { decode(span, r_decode_count) }; + } else if constexpr (std::same_as) { + return T { decode(span, r_decode_count) }; + } else if constexpr (IsColour) { + return T::from_integer(decode(span, r_decode_count)); + } else if constexpr (std::same_as) { + return T::parse_raw(decode(span, r_decode_count)); + } else if constexpr (utility::specialization_of || + utility::specialization_of) { + uint32_t size = decode(span, r_decode_count); + OV_ERR_FAIL_INDEX_V(size, span.size(), T {}); + T value; + if constexpr (requires { value.reserve(size); }) { + value.reserve(size); + } + size_t offset = r_decode_count; + for (size_t index = 0; index < size; index++) { + value.emplace_back(decode(span.subspan(offset), r_decode_count)); + offset += r_decode_count; + } + r_decode_count = offset; + return value; + } else if constexpr (std::is_constructible_v) { + uint32_t size = decode(span, r_decode_count); + OV_ERR_FAIL_INDEX_V(size, span.size(), (T { (uint8_t*)nullptr, (uint8_t*)nullptr })); + T value // + { // + reinterpret_cast(span.data() + r_decode_count), // + reinterpret_cast(span.data() + r_decode_count + size) + }; + r_decode_count += size; + return value; + } else if constexpr (utility::specialization_of || utility::specialization_of) { + union { + T tuple; + bool dummy; + } value = { .dummy = false }; + size_t offset = 0; + std::apply( + [&](auto&... args) { + ((args = decode, Endian>(span.subspan(offset), r_decode_count), + offset += r_decode_count), + ...); + }, + value.tuple + ); + r_decode_count = offset; + + return std::move(value.tuple); + } else if constexpr (utility::specialization_of_std_array_of) { + T value; + OV_ERR_FAIL_INDEX_V(value.size(), span.size(), {}); + uint32_t size = decode(span, r_decode_count); + std::uninitialized_copy_n(span.subspan(r_decode_count).data(), size, value.data()); + r_decode_count += size; + return value; + } else if constexpr (utility::specialization_of) { + T value; + value.x = decode(span, r_decode_count); + size_t offset = r_decode_count; + value.y = decode(span.subspan(offset), r_decode_count); + r_decode_count += offset; + return value; + } else if constexpr (utility::specialization_of) { + T value; + value.x = decode(span, r_decode_count); + size_t offset = r_decode_count; + value.y = decode(span.subspan(offset), r_decode_count); + offset += r_decode_count; + value.z = decode(span.subspan(offset), r_decode_count); + r_decode_count += offset; + return value; + } else if constexpr (utility::specialization_of) { + T value; + value.x = decode(span, r_decode_count); + size_t offset = r_decode_count; + value.y = decode(span.subspan(offset), r_decode_count); + offset += r_decode_count; + value.z = decode(span.subspan(offset), r_decode_count); + offset += r_decode_count; + value.w = decode(span.subspan(offset), r_decode_count); + r_decode_count += offset; + return value; + } else if constexpr (std::is_empty_v && std::is_default_constructible_v) { + r_decode_count = 0; + return T(); + } + } +} diff --git a/src/openvic-simulation/utility/Utility.hpp b/src/openvic-simulation/utility/Utility.hpp index 29a6df79f..c99cdd688 100644 --- a/src/openvic-simulation/utility/Utility.hpp +++ b/src/openvic-simulation/utility/Utility.hpp @@ -1,10 +1,15 @@ #pragma once +#include +#include #include #include #include +#include +#include #include #include +#include #include #include @@ -123,13 +128,47 @@ namespace OpenVic::utility { template class Template> concept not_specialization_of = !specialization_of; - template class Template, typename... Args> + template class Template, typename... Args> void _derived_from_specialization_impl(Template const&); - template class Template> - concept is_derived_from_specialization_of = requires(T const& t) { - _derived_from_specialization_impl