1
0
Fork 0
Univerxel/src/core/net/Context.hpp

165 lines
4.8 KiB
C++

#pragma once
#include "../../core/net/data.hpp"
#include "../../core/data/mem.hpp"
#include <picoquic/picoquic.h>
#include <cassert>
#include <cstdint>
#include <forward_list>
#include <iterator>
#include <string>
#include <vector>
namespace net {
/// Abstract stream: the state shared by every stream context, i.e. its numeric id
struct stream_ctx {
    stream_ctx(uint64_t id) : stream_id{id} {}

    uint64_t stream_id;

    /// Stream id classification helpers (declared here, implemented out of line)
    static bool IsUnidirId(uint64_t id);
    static bool IsServerId(uint64_t id);
    static bool IsClientId(uint64_t id);

    /// Two stream contexts compare equal when they carry the same stream id
    friend bool operator==(const stream_ctx &lhs, const stream_ctx &rhs) {
        const bool same_id = lhs.stream_id == rhs.stream_id;
        return same_id;
    }
};
/// Outgoing stream
struct out_stream_ctx: stream_ctx {
out_stream_ctx(uint64_t id, uint8_t queue_id, const data::out_buffer& buf):
stream_ctx(id), buffer(buf), queue_id(queue_id) { }
data::out_buffer buffer;
uint8_t queue_id;
};
/// Incoming stream: stream id plus an `in_vector` buffer for the stream's data
struct in_stream_ctx: stream_ctx {
    in_stream_ctx(uint64_t id) : stream_ctx{id} {}

    data::in_vector buffer;
};
/// Per-packet information flags; each enumerator is a distinct bit value
/// (scoped enum: combining flags needs explicit casts or user-provided operators)
enum class PacketFlags {
    NONE     = 0x0,
    DATAGRAM = 0x1,
    TINY     = 0x2,
};
/// Abstract QUIC context
class Context {
public:
static constexpr auto MAX_SOCKET_COUNT = 2;
Context(uint32_t max_connections, char const *cert, char const *key,
picoquic_stream_data_cb_fn callback, void* callback_ctx,
char const *tickets, char const *tokens);
/// As client
Context(char const* tickets, char const* tokens):
Context(1, NULL, NULL, NULL, NULL, tickets, tokens) { }
/// As server
Context(uint32_t max_connections, char const *cert, char const *key,
picoquic_stream_data_cb_fn callback, void* callback_ctx):
Context(max_connections, cert, key, callback, callback_ctx, NULL, NULL) { }
~Context();
/// Act as pull break
virtual bool isRunning() const = 0;
constexpr picoquic_quic_t *getHandle() {
assert(quic != nullptr);
return quic;
}
protected:
/// Start listening
void openSockets(int port, int family);
/// Read-write on sockets and notify callbacks
void pull(uint64_t maxDelay, uint64_t maxIn, uint64_t maxOut);
private:
picoquic_quic_t *quic = nullptr;
char const* ticket_store_filename = nullptr; // MAYBE: for 0-RTT
char const* token_store_filename = nullptr;
int nb_sockets = 0;
SOCKET_TYPE sockets[MAX_SOCKET_COUNT];
int sockets_family[MAX_SOCKET_COUNT];
uint16_t socket_port = 0;
};
/// Abstract QUIC connection.
/// Wraps a picoquic connection handle together with per-queue outgoing
/// unidirectional streams and the list of incoming streams.
class Connection {
public:
    /// @param cnx picoquic connection handle
    /// @param is_client selects the client- or server-initiated stream id space
    /// @param queues number of outgoing queues to allocate
    Connection(picoquic_cnx_t *cnx, bool is_client, uint8_t queues):
        cnx(cnx), is_client(is_client), outgoing(queues) { }
    ~Connection();

    /// Setup client connection
    void setup(picoquic_quic_t* ctx, sockaddr* addr, const char* sni, picoquic_stream_data_cb_fn cb_fn, void *cb_ctx);
    void setHandle(picoquic_cnx_t *c);
    /// True when this object wraps the picoquic handle `c`.
    /// Not constexpr: Connection can never exist in a constant expression.
    bool contains(picoquic_cnx_t *c) const { return cnx == c; }
    /// Two connections are equal when they wrap the same picoquic handle
    bool operator==(const Connection &other) const { return other.contains(cnx); }

    /// Close connection
    void release(uint16_t reason);
    void setCallback(picoquic_stream_data_cb_fn cb_fn, void *ctx);
    uint16_t getErrorCode(bool is_app);
    std::string getAddress();

    /// Send reliable data.
    /// The `size` bytes at `ptr` are copied (memcpy), so the caller keeps ownership of ptr.
    void sendCopy(const uint8_t* ptr, size_t size, uint8_t queue = 0, size_t queue_size = 0);
    /// Send reliable data.
    /// Takes ownership of the view.
    void send(const data::out_view &view, uint8_t queue = 0, size_t queue_size = 0);
    /// Send reliable data.
    /// `buffer` must stay valid until its handle is freed.
    void send(const data::out_buffer& buffer, uint8_t queue = 0, size_t queue_size = 0);

    /// Send unreliable data: the raw object bytes of `data`.
    /// NOTE(review): only meaningful for trivially copyable D whose layout both peers agree on.
    template <typename D>
    void emit(const D &data) { emit(reinterpret_cast<const uint8_t*>(&data), sizeof(D)); }
    /// Send unreliable data
    void emit(const uint8_t *ptr, size_t size);

    /// Register incoming stream
    in_stream_ctx *receive(uint64_t streamId);
    /// Close outgoing stream
    void close(out_stream_ctx *str);
    /// Close incoming stream
    void close(in_stream_ctx *str);
    /// Reset stream
    void reset(uint64_t id);

    /// Number of elements in `list`, counted by walking begin()..end().
    /// Works for std::forward_list, which has no size(); O(n).
    template<typename C>
    static constexpr size_t GetSize(const C& list) {
        return static_cast<size_t>(std::distance(list.begin(), list.end()));
    }
    /// Number of streams currently held by outgoing queue `id`
    /// (std::out_of_range for an unknown queue index, via vector::at).
    size_t queueSize(uint8_t id) const { return GetSize(outgoing.at(id).streams); }

private:
    picoquic_cnx_t *cnx = nullptr;
    bool is_client = false;
    /// One outgoing queue: a `pending` buffer and its open streams
    struct queue_ctx {
        data::out_buffer pending = data::out_buffer();
        std::forward_list<out_stream_ctx> streams;
    };
    std::vector<queue_ctx> outgoing;
    uint64_t outgoingCounter = 0;
    /// Allocate the next locally-initiated unidirectional stream id:
    /// counter * 4 + PICOQUIC_STREAM_ID_UNIDIR + initiator bit (client or server).
    uint64_t nextOutgoingId() {
        return (outgoingCounter++) * 4 + PICOQUIC_STREAM_ID_UNIDIR + (is_client ?
            PICOQUIC_STREAM_ID_CLIENT_INITIATED : PICOQUIC_STREAM_ID_SERVER_INITIATED);
    }
    std::forward_list<in_stream_ctx> incoming;
    // MAYBE: add bidirectional streams
};
}