#include "tcp.h"

/* NOTE(review): the names of the original system headers were lost when this file was
 * flattened; this list is reconstructed from the APIs the file uses. */
#include <arpa/inet.h>
#include <errno.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <poll.h>
#include <pthread.h>
#include <signal.h>
#include <stdbool.h>
#include <stdio.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <unistd.h>

/* Legacy (blocking) error reply used by process_connection(): writes a zero u64 followed by
 * the error code in network byte order, then returns from the enclosing void function. */
#define report_error(socket, error) \
	do { \
		LDEBUGF("error on socket %i: %i", (socket), (error)); \
		write((socket), &zerozu, sizeof(zerozu)); \
		enum commd_error ecpy_ = htonl(error); \
		write((socket), &ecpy_, sizeof(ecpy_)); \
		return; \
	} while (0)

/* Stores `err` through the enclosing function's `enum commd_error *e` and returns 0. */
#define return_error(err) \
	do { \
		*e = (err); \
		return 0; \
	} while (0)

/* Switches a poll-driven connection to the error-reply state (wait for POLLOUT, then send
 * the error code in network byte order) and returns from the enclosing void function. */
#define process_error(pfd_p, state_p, error) \
	do { \
		(pfd_p)->events &= ~POLLIN; \
		(pfd_p)->events |= POLLOUT; \
		(state_p)->err = htonl(error); \
		(state_p)->progress = PROG_ERRR; \
		return; \
	} while (0)

struct hash_set channel_pools;
struct hash_set channels;
struct hash_set users;

/* Scratch lookup keys for hset_at()/hset_remove(); only .id is filled in. */
static struct tcp_user varusr = {0};
static struct tcp_channel varchnl = {0};
/* Lowest candidate for the next generated user id. */
static u64 lalloc_usr = 0;
/* UDP socket used to send control packets to channel threads. */
static int udpctlfd = 0;
/* Eight zero bytes: the leading word of a legacy error reply. */
static const u64 zerozu = 0;
/* Set by exit_tcp() and polled by both accept loops. It is written from what appears to be
 * a signal handler, so it must be a volatile sig_atomic_t rather than a plain bool. */
static volatile sig_atomic_t should_exit = 0;

/* Dumps the user and channel tables to stderr; does nothing unless built with DBG.
 * The int argument is ignored (the signature presumably allows use as a signal handler). */
void print_state(int _)
{
	(void)_;
#ifdef DBG
	fputs("printing server state.\n hash_set users {\n", stderr);
	struct hash_set_iter iter;
	for (hseti_begin(&users, &iter); !hseti_end(&iter); hseti_next(&iter)) {
		struct tcp_user *curr = hseti_get(&iter);
		fprintf(stderr, "\ttcp_user {.id=%zu, .permissions=%u, .channel=%zu}\n",
			curr->id, curr->permissions, curr->joined_channel);
	}
	fputs("}\nhash_set channels {\n", stderr);
	for (hseti_begin(&channels, &iter); !hseti_end(&iter); hseti_next(&iter)) {
		struct tcp_channel *curr = hseti_get(&iter);
		struct sockaddr_in addr = {0};
		socklen_t addrlen = sizeof(addr);
		getsockname(curr->fd, (struct sockaddr *)&addr, &addrlen);
		/* the port is printed as stored, i.e. in network byte order */
		fprintf(stderr, "\ttcp_channel {.id=%zu, .fd=%u, .owner=%zu} [port=%hu]\n",
			curr->id, curr->fd, curr->owner, addr.sin_port);
	}
	fputs("}\n", stderr);
#endif
}

/* Requests shutdown of the TCP loops; the int argument is ignored.
 * NOTE(review): LINFOF is probably not async-signal-safe if this is a signal handler. */
void exit_tcp(int _)
{
	(void)_;
	should_exit = 1;
	LINFOF("EXITING SERVER, setting `should_exit (%p)` to %i", (void *)&should_exit,
	       (int)should_exit);
}

/* hset equality callback for users: 0 when the ids match, 1 otherwise (strcmp-style). */
static int tcp_user_cmp(const struct tcp_user *a, const struct tcp_user *b)
{
	return a->id != b->id;
}

/* hset equality callback for channels: 0 when the ids match, 1 otherwise (strcmp-style). */
static int tcp_channel_cmp(const struct tcp_channel *a, const struct tcp_channel *b)
{
	return a->id != b->id;
}

/* hset hash callback: the user id itself. */
static size_t tcp_user_hash(const struct tcp_user *a)
{
	return a->id;
}

/* hset hash callback: the channel id itself. */
static size_t tcp_channel_hash(const struct tcp_channel *a)
{
	return a->id;
}

/* Sets the SO_RCVTIMEO receive timeout of `fd` to `ms` milliseconds (0 = no timeout).
 * Returns the setsockopt() result. */
static int set_sock_timeout(int fd, int ms)
{
	struct timeval timeout;
	timeout.tv_sec = ms / 1000;
	timeout.tv_usec = (ms % 1000) * 1000;
	return setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout));
}

/* Creates the user/channel tables and the UDP control socket. */
static void init_statics(void)
{
	channels = hset_new(sizeof(struct tcp_channel), (hset_equal_fn)&tcp_channel_cmp,
			    (hset_hash_fn)&tcp_channel_hash);
	users = hset_new(sizeof(struct tcp_user), (hset_equal_fn)&tcp_user_cmp,
			 (hset_hash_fn)&tcp_user_hash);
	udpctlfd = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
	if (udpctlfd < 0) {
		LFAIL("UDP control socket creation failed");
	}
}

/* Creates a non-blocking, listening TCP socket bound to INADDR_ANY:`port`.
 * On failure calls LFAIL; if that returns, the socket is closed and -1 is returned. */
static int setup_socket(unsigned short port)
{
	int sock = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
	if (sock < 0)
		goto error;

	int flag = 1;
	if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &flag, sizeof flag))
		goto error;

	/* add O_NONBLOCK without clobbering the other file status flags */
	int flags = fcntl(sock, F_GETFL);
	if (flags < 0 || fcntl(sock, F_SETFL, flags | O_NONBLOCK))
		goto error;

	struct sockaddr_in localaddr = {
		.sin_family = AF_INET,
		.sin_port = htons(port),
		.sin_addr = {htonl(INADDR_ANY)}};
	if (bind(sock, (struct sockaddr *)&localaddr, sizeof(localaddr)))
		goto error;
	if (listen(sock, LISTEN_AMOUNT))
		goto error;
	return sock;

error:
	LFAIL("TCP thread failed to initialize");
	if (sock >= 0)
		close(sock);
	return -1;
}

/* Registers user `aid` with PERM_ADMIN and no joined channel. */
static void init_admin(u64 aid)
{
	struct tcp_user u = {.id = aid, .pubkey = NULL, .permissions = PERM_ADMIN, .joined_channel = 0};
	hset_insert_copy(&users, &u);
}

/* True when either 32-bit half of `id` is 0xFFFFFFFF; such ids are never handed out. */
static bool has_4bytes_0xff(u64 id)
{
	return (unsigned int)(id >> 32) == 0xFFFFFFFF || (unsigned int)(id & 0xFFFFFFFF) == 0xFFFFFFFF;
}

/* Returns the smallest id >= lalloc_usr that is not yet in `set` and has no all-ones half.
 * The id is not inserted here; the caller does that. */
static u64 get_uniq_id(struct hash_set *set)
{
	/* while (set contains the candidate || candidate has 4 bytes of ones) try the next one */
	varusr.id = lalloc_usr;
	while (hset_at(set, &varusr) != NULL || has_4bytes_0xff(varusr.id))
		++varusr.id;
	lalloc_usr = varusr.id;
	return lalloc_usr;
}

/* Returns the local port of channel `id`'s socket as stored in sin_port (network byte
 * order), or 0 if the channel does not exist or getsockname() fails. */
static unsigned short get_channel_port(u64 id)
{
	varchnl.id = id;
	struct tcp_channel *ch = hset_at(&channels, &varchnl);
	if (ch == NULL)
		return 0;
	struct sockaddr_in a = {0};
	socklen_t len = sizeof a;
	if (getsockname(ch->fd, (struct sockaddr *)&a, &len) != 0)
		return 0;
	return a.sin_port;
}

/* True when user `uid` exists and holds every bit set in `perm`. */
static bool user_has_permission(u64 uid, unsigned int perm)
{
	varusr.id = uid;
	struct tcp_user *u = hset_at(&users, &varusr);
	if (u == NULL)
		return false;
	unsigned int uperm = u->permissions;
	/* perm -> uperm must hold for every bit, i.e. no requested bit may be missing */
	return (perm & ~uperm) == 0;
}

/* Legacy channel listing (TODO: remove): writes the channel count followed by every channel
 * id, all as network-order u64. Returns 0 with *e = ERR_SUCCESS; process_connection() then
 * writes that 0 as the list terminator. Returns 0 with *e = ERR_SERV on a short write. */
static u64 send_channels(int sockfd, enum commd_error *e)
{
	struct hash_set_iter iter;
	u64 array_length = hton64(hset_count(&channels));
	if (write(sockfd, &array_length, sizeof(array_length)) != sizeof(array_length))
		return_error(ERR_SERV);
	for (hseti_begin(&channels, &iter); !hseti_end(&iter); hseti_next(&iter)) {
		struct tcp_channel *c = hseti_get(&iter);
		u64 chid = hton64(c->id);
		if (write(sockfd, &chid, sizeof(chid)) != sizeof(chid))
			return_error(ERR_SERV);
	}
	return_error(ERR_SUCCESS);
}

/* Builds the GET_CHANNELS reply: [count][id...][0], every element a network-order u64.
 * The caller owns the returned array (free with array_free()); NULL if allocation fails. */
static u64 *get_channel_array(void)
{
	/* consider later: cache this buffer, taking care of channels added while it is written */
	struct hash_set_iter iter;
	u64 len = hset_count(&channels);
	u64 *ret = array_new(u64, len + 2);
	if (ret == NULL)
		return NULL;

	size_t i = 0;
	ret[i++] = hton64(len);
	for (hseti_begin(&channels, &iter); !hseti_end(&iter); hseti_next(&iter)) {
		struct tcp_channel *c = hseti_get(&iter);
		ret[i++] = hton64(c->id); /* same byte order as the count and send_channels() */
	}
	ret[i] = 0;
	return ret;
}

/* CMD_REGISTER: admin `auid` creates a new user with permissions `perm`.
 * Returns the new user id, or 0 with *e = ERR_ACCESS.
 * NOTE(review): requires PERM_JOIN_USER plus every requested bit — confirm intended. */
static inline u64 commd_register_process(struct commd_register *cmd, enum commd_error *e)
{
	if (!user_has_permission(cmd->auid, PERM_JOIN_USER | cmd->perm))
		return_error(ERR_ACCESS);
	struct tcp_user new_user = {
		.id = get_uniq_id(&users),
		.joined_channel = 0,
		.permissions = (unsigned int)cmd->perm};
	hset_insert_copy(&users, &new_user);
	return new_user.id;
}

/* CMD_UNREGISTER: removes user `uid` (self, or by an admin with PERM_UNREGISTER_USER).
 * Returns `uid`, or 0 with *e = ERR_ACCESS. */
static inline u64 commd_unregister_process(struct commd_unregister *cmd, enum commd_error *e)
{
	LDEBUGF("delete user %zu (admin %zu)", cmd->uid, cmd->auid);
	if (cmd->auid != cmd->uid && !user_has_permission(cmd->auid, PERM_UNREGISTER_USER))
		return_error(ERR_ACCESS);
	varusr.id = cmd->uid;
	hset_remove(&users, &varusr);
	return cmd->uid;
}

/* CMD_CREATE: spawns a channel thread owned by `uid`, waits until it publishes its socket and
 * records the channel. Returns the channel id (the thread id), or 0 with *e set to
 * ERR_ACCESS (missing PERM_ADD_CHANNEL) or ERR_SERV (thread creation failed). */
static inline u64 commd_create_process(struct commd_create *cmd, enum commd_error *e)
{
	if (!user_has_permission(cmd->uid, PERM_ADD_CHANNEL))
		return_error(ERR_ACCESS);

	u64 chid;
	int sock = -1;
	{
		/* The channel thread stores its socket in `sock` and signals sock_cond.
		 * NOTE(review): arg, the mutex and the condvar live on this stack frame; this
		 * assumes thread_loop stops using them once it has published the socket. */
		pthread_mutex_t sock_mx = PTHREAD_MUTEX_INITIALIZER;
		pthread_cond_t sock_cond = PTHREAD_COND_INITIALIZER;
		struct thread_loop_arg arg = {
			.owner = cmd->uid,
			.sock_dest = &sock,
			.sock_mx = &sock_mx,
			.sock_ready_cond = &sock_cond};
		pthread_mutex_lock(&sock_mx);
		chid = spawn_channel(&arg);
		/* without a thread nobody would ever set `sock`, so don't wait in that case */
		if (chid != 0) {
			while (sock == -1)
				pthread_cond_wait(&sock_cond, &sock_mx);
		}
		pthread_mutex_unlock(&sock_mx);
	}
	if (chid == 0)
		return_error(ERR_SERV);

	struct tcp_channel new_channel = {.id = chid, .owner = cmd->uid, .name = NULL, .fd = sock};
	hset_insert_copy(&channels, &new_channel);
	return chid;
}

/* CMD_DELETE: removes channel `chid` (by its owner or a user with PERM_UNADD_CHANNEL).
 * Returns the channel id, or 0 with *e = ERR_PARAM / ERR_ACCESS.
 * NOTE(review): only the table entry is removed; the channel thread and its socket are not
 * told to stop here — confirm that is handled elsewhere. */
static inline u64 commd_delete_process(struct commd_delete *cmd, enum commd_error *e)
{
	LDEBUGF("received command%p", (void *)cmd);
	varchnl.id = cmd->chid;
	struct tcp_channel *c = hset_at(&channels, &varchnl);
	if (c == NULL)
		return_error(ERR_PARAM);
	if (cmd->uid != c->owner && !user_has_permission(cmd->uid, PERM_UNADD_CHANNEL))
		return_error(ERR_ACCESS);
	hset_remove(&channels, &varchnl);
	return varchnl.id;
}

/* CMD_JOIN: tells channel `chid` that user `juid` joins (self, or by a user with
 * PERM_JOIN_USER). Returns the channel port (network order), or 0 with *e set.
 * NOTE(review): operation_id is sent via htonl() but user_id in host order — confirm this
 * matches what the channel thread expects. */
static inline u64 commd_join_process(struct commd_join *cmd, enum commd_error *e)
{
	if (cmd->uid != cmd->juid && !user_has_permission(cmd->uid, PERM_JOIN_USER))
		return_error(ERR_ACCESS);
	struct kv_system_packet packet = {
		.magic_bytes = SYS_PACKET_MAGIC_BYTES,
		.operation_id = htonl(SYS_JOIN),
		.user_id = cmd->juid};
	if (!sendto_channel(cmd->chid, &packet, TCP_MAX_WAIT_MS, TCP_MAX_RETRIES))
		return_error(ERR_SERV);
	return (u64)get_channel_port(cmd->chid);
}

/* CMD_LEAVE: tells channel `chid` that user `luid` leaves (self, or kicked by a user with
 * PERM_KICK_USER). Returns 1, or 0 with *e set. */
static inline u64 commd_leave_process(struct commd_leave *cmd, enum commd_error *e)
{
	if (cmd->uid != cmd->luid && !user_has_permission(cmd->uid, PERM_KICK_USER))
		return_error(ERR_ACCESS);
	struct kv_system_packet packet = {
		.magic_bytes = SYS_PACKET_MAGIC_BYTES,
		.operation_id = htonl(SYS_LEAVE),
		.user_id = cmd->luid};
	if (!sendto_channel(cmd->chid, &packet, TCP_MAX_WAIT_MS, TCP_MAX_RETRIES))
		return_error(ERR_SERV);
	return 1;
}

/* Converts the u64 fields of a received command from network to host byte order.
 * `type` must already be validated (< CMD_LAST). */
static void commd_ntoh(enum commd_type type, struct commd *cmd)
{
	struct commd_conv *conv = (struct commd_conv *)cmd;
	switch (type) {
	case CMD_LEAVE:
	case CMD_JOIN:
		conv->_3 = ntoh64(conv->_3);
		FALLTHROUGH;
	case CMD_UNREGISTER:
	case CMD_REGISTER:
	case CMD_DELETE:
		conv->_2 = ntoh64(conv->_2);
		FALLTHROUGH;
	case CMD_CREATE:
	case CMD_GET_PORT:
		conv->_1 = ntoh64(conv->_1);
		FALLTHROUGH;
	case CMD_GET_CHANNELS:
		break;
	case CMD_LAST:
		UNREACHABLE; /* rejected while parsing the type */
	}
}

/* Dispatches a host-order command to its handler. Returns the handler's result; on failure
 * returns 0 and sets *e. CMD_GET_CHANNELS must be handled by the caller. */
static u64 process_cmd(enum commd_type type, struct commd *cmd, enum commd_error *NONNULL e)
{
	switch (type) {
	case CMD_REGISTER:
		return commd_register_process((struct commd_register *)cmd, e);
	case CMD_UNREGISTER:
		return commd_unregister_process((struct commd_unregister *)cmd, e);
	case CMD_CREATE:
		return commd_create_process((struct commd_create *)cmd, e);
	case CMD_DELETE:
		return commd_delete_process((struct commd_delete *)cmd, e);
	case CMD_JOIN:
		return commd_join_process((struct commd_join *)cmd, e);
	case CMD_LEAVE:
		return commd_leave_process((struct commd_leave *)cmd, e);
	case CMD_GET_PORT:
		return (u64)get_channel_port(((struct commd_get_port *)cmd)->cihd);
	case CMD_GET_CHANNELS:
		UNREACHABLE; /* callers handle this type themselves */
	case CMD_LAST:
		UNREACHABLE; /* rejected while parsing the type */
	default:
		return_error(ERR_TYPE);
	}
}

/* Legacy blocking handler used by tcp_loop(): reads one command (type + payload) from
 * `sockfd`, executes it and writes the u64 result, or a legacy error reply. */
static void process_connection(int sockfd)
{
	LDEBUG("Processing connection");
	/* TODO: protection against blocking reads (irrelevant once everything uses poll) */
	enum commd_type type;
	if (read(sockfd, &type, sizeof(type)) != sizeof(type))
		report_error(sockfd, ERR_INVAL);
	type = ntohl(type);
	/* `type` comes from the network and indexes commd_size_lookup: validate it first */
	if (type < 0 || type >= CMD_LAST)
		report_error(sockfd, ERR_TYPE);

	struct commd cmd;
	memset(&cmd, 0, sizeof(cmd));
	i64 commd_size = commd_size_lookup[type];
	if (read(sockfd, &cmd, commd_size) != commd_size)
		report_error(sockfd, ERR_INVAL);
	commd_ntoh(type, &cmd);

	enum commd_error e = ERR_SUCCESS;
	u64 cmd_status;
	if (type == CMD_GET_CHANNELS)
		cmd_status = send_channels(sockfd, &e); /* process_cmd() must not see this type */
	else
		cmd_status = process_cmd(type, &cmd, &e);
	cmd_status = hton64(cmd_status);
	if (e != ERR_SUCCESS)
		report_error(sockfd, e);
	write(sockfd, &cmd_status, sizeof(cmd_status));
}

/* Starts a channel thread running thread_loop(arg). Returns the thread id, which doubles as
 * the channel id, or 0 if the thread could not be created. */
u64 spawn_channel(struct thread_loop_arg *arg)
{
	pthread_t thread;
	int rc = pthread_create(&thread, NULL, thread_loop, arg);
	if (rc != 0) {
		/* `thread` is indeterminate on failure; never hand it out */
		LWARNF("pthread_create failed: %s (errno %i)", strerror(rc), rc);
		return 0;
	}
	return (u64)thread;
}

/* Sends `packet` to channel `chid`'s socket over the UDP control socket, making up to
 * repeat + 1 attempts. With wait_ack_ms == 0 the packet is sent on every attempt and true
 * is returned. Otherwise each attempt waits up to wait_ack_ms for a SYS_ACK reply and the
 * function stops at the first ACK. Returns false if the channel is unknown or no ACK came. */
bool sendto_channel(u64 chid, struct kv_system_packet *packet, int wait_ack_ms, int repeat)
{
	bool success = wait_ack_ms == 0;
	varchnl.id = chid;
	struct tcp_channel *ch = hset_at(&channels, &varchnl);
	if (ch == NULL)
		return false;

	if (set_sock_timeout(udpctlfd, wait_ack_ms) != 0 && wait_ack_ms != 0) {
		/* without a receive timeout recvfrom() could block this thread forever */
		LWARNP("setsockopt(SO_RCVTIMEO) on UDP control socket");
		return false;
	}

	struct sockaddr_in chaddr = {0};
	socklen_t chlen = sizeof(chaddr);
	if (getsockname(ch->fd, (struct sockaddr *)&chaddr, &chlen) != 0) {
		LWARNP("getsockname on channel socket");
		return false;
	}

	do {
		if (sendto(udpctlfd, packet, KV_PACKET_SIZE, 0, (struct sockaddr *)&chaddr, chlen) < 0) {
			LWARNP("sendto channel");
			continue;
		}
		if (wait_ack_ms == 0)
			continue;

		/* separate source-address buffer so the destination is never overwritten */
		struct kv_system_packet resp;
		memset(&resp, 0, sizeof(resp));
		struct sockaddr_in from;
		socklen_t fromlen = sizeof(from);
		ssize_t got = recvfrom(udpctlfd, &resp, KV_PACKET_SIZE, 0, (struct sockaddr *)&from, &fromlen);
		if (got < 0) {
			/* errno is only meaningful after a failed call */
			if (errno != EWOULDBLOCK && errno != EAGAIN && errno != EINTR)
				LWARNP("recvfrom channel");
			continue;
		}
		/* NOTE(review): compared without ntohl(), unlike the htonl()'d outgoing op id */
		if (resp.operation_id == SYS_ACK) {
			success = true;
			break; /* acknowledged: do not deliver the command again */
		}
	} while (--repeat >= 0);
	return success;
}

/* Legacy accept loop: handles one connection at a time with process_connection().
 * NOTE: setup_socket() makes the listening socket non-blocking, so this loop spins
 * while idle; new_tcp_loop() is the poll-based replacement. */
void tcp_loop(void)
{
	init_statics();
	init_admin(ADMIN_UID);
	int sock = setup_socket(TCP_PORT);
	/* setup_socket() already listens; this repeats it to get a specific failure message */
	if (listen(sock, LISTEN_AMOUNT) != 0)
		LFAIL("listen on TCP socket failed");
	LDEBUGF("listening on port %hu", TCP_PORT);

	while (!should_exit) {
		struct sockaddr_in accept_addr;
		socklen_t addrlen = sizeof(accept_addr); /* accept() updates it; reset every time */
		int currfd = accept(sock, (struct sockaddr *)&accept_addr, &addrlen);
		if (currfd < 0)
			continue;
		LDEBUGF("accepted connection on port %hu", accept_addr.sin_port);
		process_connection(currfd); /* synchronous: one slow client blocks everyone */
		shutdown(currfd, SHUT_RDWR);
		close(currfd);
	}
	close(sock);
	close(udpctlfd);
	hset_free(&users);
	hset_free(&channels);
}

/* Accepts a pending connection on the master socket (sockets[0]) and appends a non-blocking
 * pollfd plus a fresh connection state. Takes the arrays by address because array_push()
 * may reallocate them. */
static void accept_new(struct pollfd **_sockets, struct connection_state **_states)
{
	struct pollfd *sockets = *_sockets;
	struct connection_state *states = *_states;

	LDEBUG("Accepting new connection");
	struct pollfd newsock = {accept(sockets[0].fd, NULL, NULL), POLLIN, 0};
	if (newsock.fd < 0) {
		LWARNP("Accept");
		return;
	}
	if (fcntl(newsock.fd, F_SETFL, O_NONBLOCK)) {
		LWARNF("Fcntl on fd %i failed: %s (errno %i)", newsock.fd, strerror(errno), errno);
		close(newsock.fd); /* not tracked anywhere, so close it here */
		return;
	}

	struct connection_state newstate = {
		.progress = PROG_NONE,
		.bytes_ctr = 0,
		.response = 0,
		.cmd_size = 0,
		.err = ERR_SUCCESS,
		.type = -1,
	};
	array_push(sockets, newsock);
	array_push(states, newstate);

	/* publish the (possibly reallocated) arrays back to the caller */
	*_sockets = sockets;
	*_states = states;
	LDEBUGF("Accepted new connection: fd %i. Array size: %zu", newsock.fd, array_size(sockets));
}

/* Closes the master socket and replaces it with a freshly set-up listening socket. */
static void reopen_master(int *masterfd)
{
	LWARN("Reopening master socket");
	shutdown(*masterfd, SHUT_RDWR);
	close(*masterfd);
	*masterfd = setup_socket(TCP_PORT);
}

/* Shuts down and closes a client socket, then stores the complemented (negative) descriptor
 * so the poll loop knows to drop the entry. */
static void close_client(int *fdp)
{
	shutdown(*fdp, SHUT_RDWR);
	close(*fdp);
	*fdp = ~*fdp;
}

/* Non-blocking incremental read of `how_much` bytes into `where`, resuming at *bytes_read.
 * Returns true (and resets *bytes_read) once everything has arrived. On EOF or a hard error
 * the socket is closed and *fdp made negative; EAGAIN/EINTR just mean "retry later". */
static bool try_read(int *fdp, u8 *restrict where, u64 *bytes_read, u64 how_much)
{
	if (*bytes_read >= how_much) { /* nothing (left) to read, e.g. an empty payload */
		*bytes_read = 0;
		return true;
	}
	int fd = *fdp;
	i64 bytes = read(fd, where + *bytes_read, how_much - *bytes_read);
	if (bytes < 0 && (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR))
		return false; /* nothing available yet; continue on the next poll */
	if (bytes <= 0) {
		if (bytes < 0)
			LWARNP("Read on client socket");
		close_client(fdp);
		return false;
	}
	*bytes_read += (u64)bytes;
	if (*bytes_read == how_much) {
		*bytes_read = 0;
		return true;
	}
	return false; /* partial read; continue on the next poll */
}

/* Non-blocking incremental write of `how_much` bytes from `from`, resuming at
 * *bytes_written. Returns true (and resets *bytes_written) once everything is sent. On a
 * hard error the socket is closed and *fdp made negative; EAGAIN/EINTR mean "retry later".
 * NOTE(review): write() to a peer that has gone away raises SIGPIPE unless it is ignored. */
static bool try_write(int *fdp, const u8 *from, u64 *bytes_written, u64 how_much)
{
	if (*bytes_written >= how_much) {
		*bytes_written = 0;
		return true;
	}
	int fd = *fdp;
	i64 bytes = write(fd, from + *bytes_written, how_much - *bytes_written);
	if (bytes < 0) {
		if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR)
			return false; /* socket buffer full; continue on the next POLLOUT */
		LWARNP("Write on client socket");
		close_client(fdp);
		return false;
	}
	*bytes_written += (u64)bytes;
	if (*bytes_written == how_much) {
		*bytes_written = 0;
		return true;
	}
	return false; /* partial write; continue on the next poll */
}

/* Frees the channel-list reply owned by a connection in the PROG_ARRY state, if any. */
static void release_state(struct connection_state *state)
{
	if (state->progress == PROG_ARRY && state->response_array != NULL) {
		array_free(state->response_array);
		state->response_array = NULL;
	}
}

/* Advances one client connection's state machine by one step. Takes individual elements,
 * not arrays. Stages: read type -> read payload and execute -> write result, channel array
 * or error. When the connection is finished or broken the socket is closed and fd->fd made
 * negative; the caller must then call release_state() on `state` and drop the entry. */
static void process_socket(struct pollfd *fd, struct connection_state *state)
{
	LDEBUGF("FD %i: progress %i, type %i", fd->fd, state->progress, state->type);
	if (fd->revents & (POLLNVAL | POLLERR | POLLHUP))
		goto pollerr;

	switch (state->progress) {
	case PROG_NONE:
		if (!try_read(&fd->fd, (u8 *)&state->type, &state->bytes_ctr, sizeof(enum commd_type)))
			return; /* continue reading on next iteration */
		state->type = ntohl(state->type);
		if (state->type < 0 || state->type >= CMD_LAST)
			process_error(fd, state, ERR_TYPE);
		state->cmd_size = commd_size_lookup[state->type];
		state->progress = PROG_TYPE;
		/* The payload is usually already buffered (and may be empty), so try to read it
		 * now instead of waiting for a POLLIN that an empty payload would never trigger. */
		FALLTHROUGH;
	case PROG_TYPE:
		if (!try_read(&fd->fd, (u8 *)&state->cmd, &state->bytes_ctr, state->cmd_size))
			return; /* continue reading on next iteration */
		commd_ntoh(state->type, &state->cmd);

		/* from now on only the reply is written */
		fd->events &= ~POLLIN;
		fd->events |= POLLOUT;

		if (state->type == CMD_GET_CHANNELS) {
			state->progress = PROG_ARRY;
			state->response_array = get_channel_array();
			if (state->response_array == NULL)
				state->err = ERR_SERV;
		} else {
			state->progress = PROG_COMD;
			/* the actual command processing is done in this call */
			state->response = hton64(process_cmd(state->type, &state->cmd, &state->err));
		}

		if (state->err != ERR_SUCCESS) {
			state->err = htonl(state->err);
			state->progress = PROG_ERRR;
		}
		return;
	case PROG_COMD:
		if (!try_write(&fd->fd, (u8 *)&state->response, &state->bytes_ctr, sizeof(u64)))
			return; /* continue writing on next iteration */
		goto socket_done;
	case PROG_ARRY:
		if (!try_write(&fd->fd, (u8 *)state->response_array, &state->bytes_ctr,
			       array_size(state->response_array) * sizeof(u64)))
			return; /* continue writing on next iteration */
		/* the array itself is freed by the caller via release_state() */
		goto socket_done;
	case PROG_ERRR:
		if (!try_write(&fd->fd, (u8 *)&state->err, &state->bytes_ctr, sizeof(enum commd_error)))
			return; /* continue writing on next iteration */
		goto socket_done;
	}

pollerr:
	LWARNF("Poll error on socket %i; closing connection (revents = %i)", fd->fd, fd->revents);
socket_done:
	close_client(&fd->fd);
}

/* poll()-based server loop. sockets[0] is the listening socket; sockets[i] / states[i]
 * (i > 0) are client connections driven by process_socket(). Runs until should_exit is set,
 * then closes every socket and frees all tables. */
void new_tcp_loop(void)
{
	LINFO("Starting TCP loop");
	init_statics();
	init_admin(ADMIN_UID);
	int master = setup_socket(TCP_PORT);

	/* states[0] is unused padding that keeps sockets[i] and states[i] aligned */
	struct pollfd *sockets = array_new(struct pollfd, 1);
	struct connection_state *states = array_new(struct connection_state, 1);
	sockets[0].fd = master;
	sockets[0].events = POLLIN;

	for (;;) {
		int numready = poll(sockets, array_size(sockets), POLL_RESTART);
		if (should_exit)
			break;
		if (numready < 0) {
			switch (errno) {
			case 0:
				break;
			case EINTR:
				LINFO("Poll call was interrupted");
				continue;
			case EAGAIN:
			case ENOMEM:
				sleep(1);
				continue;
			case EFAULT:
			case EINVAL:
				LFAILV("poll");
				break;
			default:
				/* revents are stale after a failed poll(); don't act on them */
				LWARNP("poll");
				continue;
			}
		}
#if POLL_RESTART > 0
		if (numready == 0)
			continue;
#endif
		LDEBUGF("poll: %i sockets ready", numready);

		/* master socket */
		if (sockets[0].revents != 0) {
			if (sockets[0].revents & POLLIN)
				accept_new(&sockets, &states);
			if (sockets[0].revents & (POLLNVAL | POLLHUP | POLLERR))
				reopen_master(&sockets[0].fd);
		}

		LDEBUG("Array:");
		for (size_t i = 0; i < array_size(sockets); ++i) {
			LDEBUGF("%i", sockets[i].fd);
		}

		/* Client sockets. Iterate backwards so array_pop_at() cannot shift entries that
		 * are still to be visited; index 0 (master) is skipped. Entries added by
		 * accept_new() above have revents == 0 and are skipped too. */
		for (size_t i = array_size(sockets) - 1; i > 0; --i) {
			if (sockets[i].revents == 0)
				continue;
			process_socket(&sockets[i], &states[i]);
			if (sockets[i].fd < 0) {
				release_state(&states[i]);
				array_pop_at(sockets, i);
				array_pop_at(states, i);
			}
		}
	}

	/* shutdown: close every socket and free any reply buffers still pending */
	for (size_t i = 0; i < array_size(sockets); ++i) {
		if (i > 0)
			release_state(&states[i]);
		shutdown(sockets[i].fd, SHUT_RDWR);
		close(sockets[i].fd);
	}
	close(udpctlfd);
	array_free(sockets);
	array_free(states);
	hset_free(&users);
	hset_free(&channels);
}

/* vim: set ts=8 noet: */