#pragma once #include "../../core/net/data.hpp" #include "../../core/data/mem.hpp" #include #include namespace net { /// Abstract stream struct stream_ctx { stream_ctx(uint64_t id): stream_id(id) { } uint64_t stream_id; static bool IsUnidirId(uint64_t id); static bool IsServerId(uint64_t id); static bool IsClientId(uint64_t id); friend bool operator==(const stream_ctx &a, const stream_ctx &b) { return a.stream_id == b.stream_id; } }; /// Outgoing stream struct out_stream_ctx: stream_ctx { out_stream_ctx(uint64_t id, uint8_t queue_id, const data::out_buffer& buf): stream_ctx(id), buffer(buf), queue_id(queue_id) { } data::out_buffer buffer; uint8_t queue_id; }; /// Incoming stream struct in_stream_ctx: stream_ctx { in_stream_ctx(uint64_t id): stream_ctx(id) { } data::in_vector buffer; }; /// Informations about packets enum class PacketFlags { NONE = 0, DATAGRAM = 1 << 0, TINY = 1 << 1, }; /// Abstract QUIC context class Context { public: static constexpr auto MAX_SOCKET_COUNT = 2; Context(uint32_t max_connections, char const *cert, char const *key, picoquic_stream_data_cb_fn callback, void* callback_ctx, char const *tickets, char const *tokens); /// As client Context(char const* tickets, char const* tokens): Context(1, NULL, NULL, NULL, NULL, tickets, tokens) { } /// As server Context(uint32_t max_connections, char const *cert, char const *key, picoquic_stream_data_cb_fn callback, void* callback_ctx): Context(max_connections, cert, key, callback, callback_ctx, NULL, NULL) { } ~Context(); /// Act as pull break virtual bool isRunning() const = 0; constexpr picoquic_quic_t *getHandle() { assert(quic != nullptr); return quic; } protected: /// Start listening void openSockets(int port, int family); /// Read-write on sockets and notify callbacks void pull(uint64_t maxDelay, uint64_t maxIn, uint64_t maxOut); private: picoquic_quic_t *quic = nullptr; char const* ticket_store_filename = nullptr; // MAYBE: for 0-RTT char const* token_store_filename = nullptr; int nb_sockets = 0; SOCKET_TYPE sockets[MAX_SOCKET_COUNT]; int sockets_family[MAX_SOCKET_COUNT]; uint16_t socket_port = 0; }; /// Abstract QUIC connection class Connection { public: Connection(picoquic_cnx_t *cnx, bool is_client, uint8_t queues): cnx(cnx), is_client(is_client) { outgoing.resize(queues); } ~Connection(); /// Setup client connection void setup(picoquic_quic_t* ctx, sockaddr* addr, const char* sni, picoquic_stream_data_cb_fn cb_fn, void *cb_ctx); void setHandle(picoquic_cnx_t *c); constexpr bool contains(picoquic_cnx_t *c) const { return cnx == c; } constexpr bool operator==(const Connection &other) const { return other.contains(cnx); } /// Close connection void release(uint16_t reason); void setCallback(picoquic_stream_data_cb_fn cb_fn, void *ctx); uint16_t getErrorCode(bool is_app); std::string getAddress(); /// Send reliable data /// ptr is memcpy void sendCopy(const uint8_t* ptr, size_t size, uint8_t queue = 0, size_t queue_size = 0); /// Send reliable data /// take view ownership void send(const data::out_view &view, uint8_t queue = 0, size_t queue_size = 0); /// Send reliable data /// buffer must stay valid until handle is freed void send(const data::out_buffer& buffer, uint8_t queue = 0, size_t queue_size = 0); template void emit(const D &data) { emit((const uint8_t*)&data, sizeof(D)); } /// Send unreliable data void emit(const uint8_t *ptr, size_t size); /// Register incoming stream in_stream_ctx *receive(uint64_t streamId); /// Close outgoing stream void close(out_stream_ctx *str); /// Close incoming stream void close(in_stream_ctx *str); /// Reset stream void reset(uint64_t id); template static constexpr size_t GetSize(const C& list) { size_t res = 0; for (auto it = list.begin(); it != list.end(); ++it) { ++res; }; return res; } size_t queueSize(uint8_t id) const { return GetSize(outgoing.at(id).streams); } private: picoquic_cnx_t *cnx = nullptr; bool is_client; struct queue_ctx { data::out_buffer pending = data::out_buffer(); std::forward_list streams; }; std::vector outgoing; uint64_t outgoingCounter = 0; uint64_t nextOutgoingId() { return (outgoingCounter++) * 4 + PICOQUIC_STREAM_ID_UNIDIR + (is_client ? PICOQUIC_STREAM_ID_CLIENT_INITIATED : PICOQUIC_STREAM_ID_SERVER_INITIATED); } std::forward_list incoming; // MAYBE: add bidirectional streams }; }