diff options
Diffstat (limited to 'src/server')
-rw-r--r-- | src/server/addr.c | 8 | ||||
-rw-r--r-- | src/server/addr.h | 4 | ||||
-rw-r--r-- | src/server/binding.c | 6 | ||||
-rw-r--r-- | src/server/resolver.c | 29 | ||||
-rw-r--r-- | src/server/resolver.h | 3 | ||||
-rw-r--r-- | src/server/server.c | 112 |
6 files changed, 101 insertions, 61 deletions
diff --git a/src/server/addr.c b/src/server/addr.c index 982da13..14f44c6 100644 --- a/src/server/addr.c +++ b/src/server/addr.c @@ -7,12 +7,12 @@ #include "addr.h" #include "../io/log.h" -void create_ip_addr(char* domain, IpAddr* addr) { +void create_ip_addr(uint8_t* domain, IpAddr* addr) { addr->type = V4; memcpy(&addr->data.v4.s_addr, domain, 4); } -void create_ip_addr6(char* domain, IpAddr* addr) { +void create_ip_addr6(uint8_t* domain, IpAddr* addr) { addr->type = V6; memcpy(&addr->data.v6.__in6_u.__u6_addr8, domain, 16); } @@ -76,7 +76,7 @@ void print_socket_addr(SocketAddr* addr, char* buffer) { (uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr >> 16), (uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr >> 8), (uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr), - addr->data.v4.sin_port + ntohs(addr->data.v4.sin_port) ); } else { for(int i = 0; i < 8; i++) { @@ -85,7 +85,7 @@ void print_socket_addr(SocketAddr* addr, char* buffer) { addr->data.v6.sin6_addr.__in6_u.__u6_addr8[i*2 + 1] ); } - APPEND(buffer, ":[%hu]", addr->data.v6.sin6_port); + APPEND(buffer, ":[%hu]", ntohs(addr->data.v6.sin6_port)); } } diff --git a/src/server/addr.h b/src/server/addr.h index 173c7fd..1210850 100644 --- a/src/server/addr.h +++ b/src/server/addr.h @@ -20,8 +20,8 @@ typedef struct { } data; } IpAddr; -void create_ip_addr(char* domain, IpAddr* addr); -void create_ip_addr6(char* domain, IpAddr* addr); +void create_ip_addr(uint8_t* domain, IpAddr* addr); +void create_ip_addr6(uint8_t* domain, IpAddr* addr); void ip_addr_any(IpAddr* addr); void ip_addr_any6(IpAddr* addr); diff --git a/src/server/binding.c b/src/server/binding.c index 47c62c6..157d1d4 100644 --- a/src/server/binding.c +++ b/src/server/binding.c @@ -144,10 +144,10 @@ bool read_connection(Connection* connection, Packet* packet) { } static bool write_udp(Connection* connection, uint8_t* buf, uint16_t len) { - //if (len > 512) { + if (len > 512) { buf[2] = buf[2] | 0x03; - // len = 512; - // } + len = 512; + } return write_udp_socket( &connection->sock.udp.udp, buf, diff --git a/src/server/resolver.c b/src/server/resolver.c index e05f365..a1fa82a 100644 --- a/src/server/resolver.c +++ b/src/server/resolver.c @@ -43,9 +43,9 @@ static bool lookup( return true; } -static bool search(Question* question, Packet* result, BindingType type) { +static bool search(Question* question, Packet* result, BindingType type, const RecordMap* map) { IpAddr addr; - char ip[4] = {1, 1, 1, 1}; + uint8_t ip[4] = {1, 1, 1, 1}; create_ip_addr(ip, &addr); uint16_t port = 53; @@ -53,6 +53,10 @@ static bool search(Question* question, Packet* result, BindingType type) { create_socket_addr(port, addr, &saddr); while(1) { + if (record_map_get(map, question, result)) { + return true; + } + if (!lookup(question, result, type, saddr)) { return false; } @@ -75,7 +79,7 @@ static bool search(Question* question, Packet* result, BindingType type) { } Packet recurse; - if (!search(&new_question, &recurse, type)) { + if (!search(&new_question, &recurse, type, map)) { return false; } @@ -104,7 +108,7 @@ static void push_questions(Question* from, uint8_t from_len, Question** to, uint memcpy(*to + to_len, from, from_len * sizeof(Question)); } -void handle_query(Packet* request, Packet* response, BindingType type) { +void handle_query(const Packet* request, Packet* response, BindingType type, const RecordMap* map) { memset(response, 0, sizeof(Packet)); response->header.id = request->header.id; response->header.opcode = request->header.opcode; @@ -121,7 +125,7 @@ void handle_query(Packet* request, Packet* response, BindingType type) { Packet result; memset(&result, 0, sizeof(Packet)); result.header.id = response->header.id; - if (!search(&request->questions[i], &result, type)) { + if (!search(&request->questions[i], &result, type, map)) { response->header.response = SERVFAIL; break; } @@ -158,9 +162,14 @@ void handle_query(Packet* request, Packet* response, BindingType type) { ); response->header.resource_entries += result.header.resource_entries; - free(result.questions); - free(result.answers); - free(result.authorities); - free(result.resources); + if (result.header.z == false) { + // Not from cache + free(result.questions); + free(result.answers); + free(result.authorities); + free(result.resources); + } else { + response->header.z = true; + } } -}
\ No newline at end of file +} diff --git a/src/server/resolver.h b/src/server/resolver.h index 79b4825..993b515 100644 --- a/src/server/resolver.h +++ b/src/server/resolver.h @@ -1,6 +1,7 @@ #pragma once #include "../packet/packet.h" +#include "../io/map.h" #include "binding.h" -void handle_query(Packet* request, Packet* response, BindingType type);
\ No newline at end of file +void handle_query(const Packet* request, Packet* response, BindingType type, const RecordMap* map); diff --git a/src/server/server.c b/src/server/server.c index c8975ee..27cc9d9 100644 --- a/src/server/server.c +++ b/src/server/server.c @@ -1,25 +1,67 @@ -#define _POSIX_SOURCE +#define _POSIX_C_SOURCE 200809L #include <errno.h> #include <stdio.h> #include <stdlib.h> #include <unistd.h> #include <sys/wait.h> #include <signal.h> +#include <pthread.h> #include "addr.h" #include "server.h" #include "resolver.h" #include "../io/log.h" +#include "../io/map.h" +#include "../io/config.h" -static pid_t udp, tcp; +static pthread_t udp, tcp; +static RecordMap map; void server_init(uint16_t port, Server* server) { INFO("Server port set to %hu", port); create_binding(UDP, port, &server->udp); create_binding(TCP, port, &server->tcp); + load_config("config", &map); } -static void server_listen(Binding* binding) { +struct DnsRequest { + Connection connection; + Packet request; +}; + +static void* server_respond(void* arg) { + struct DnsRequest req = *(struct DnsRequest*) arg; + + INFO("Recieved packet request ID %hu", req.request.header.id); + + Packet response; + handle_query(&req.request, &response, req.connection.type, &map); + + if (!write_connection(&req.connection, &response)) { + ERROR("Faled to respond to connection ID %hu: %s", + req.request.header.id, + strerror(errno) + ); + } + + free_packet(&req.request); + free_connection(&req.connection); + + if (response.header.z == false) { + free_packet(&response); + } else { + // Dont free from config + free(response.questions); + free(response.answers); + free(response.authorities); + free(response.resources); + } + + return NULL; +} + +static void* server_listen(void* arg) { + Binding* binding = (Binding*) arg; while(1) { Connection connection; @@ -34,67 +76,55 @@ static void server_listen(Binding* binding) { free_connection(&connection); continue; } + + struct DnsRequest req; + req.connection = connection; + req.request = request; - if(fork() != 0) { + pthread_t thread; + if(pthread_create(&thread, NULL, &server_respond, &req)) { + ERROR("Failed to create thread for dns request ID %hu: %s", + request.header.id, + strerror(errno) + ); free_packet(&request); free_connection(&connection); continue; } - - INFO("Recieved packet request ID %hu", request.header.id); - - Packet response; - handle_query(&request, &response, connection.type); - - if (!write_connection(&connection, &response)) { - ERROR("Failed to respond to connection ID %hu: %s", request.header.id, strerror(errno)); - } - - free_packet(&request); - free_packet(&response); - free_connection(&connection); - exit(EXIT_SUCCESS); } + + return NULL; } static void signal_handler() { printf("\n"); - kill(udp, SIGTERM); - kill(tcp, SIGTERM); + pthread_kill(udp, SIGTERM); + pthread_kill(tcp, SIGTERM); } void server_run(Server* server) { - if ((udp = fork()) == 0) { + if (!pthread_create(&udp, NULL, &server_listen, &server->udp)) { INFO("Listening for connections on UDP"); - server_listen(&server->udp); - exit(EXIT_SUCCESS); + } else { + ERROR("Failed to start UDP thread"); + exit(EXIT_FAILURE); } - if ((tcp = fork()) == 0) { + if (!pthread_create(&tcp, NULL, &server_listen, &server->tcp)) { INFO("Listening for connections on TCP"); - server_listen(&server->tcp); - exit(EXIT_SUCCESS); - } - - signal(SIGINT, signal_handler); - - int status; - waitpid(udp, &status, 0); - if (status == 0) { - INFO("UDP process closed successfully"); } else { - ERROR("UDP process failed with error code %d", status); + ERROR("Failed to start TCP thread"); + exit(EXIT_FAILURE); } + signal(SIGINT, signal_handler); - waitpid(tcp, &status, 0); - if (status == 0) { - INFO("TCP process closed successfully"); - } else { - ERROR("TCP process failed with error code %d", status); - } + pthread_join(udp, NULL); + pthread_join(tcp, NULL); } void server_free(Server* server) { free_binding(&server->udp); free_binding(&server->tcp); + record_map_free(&map); } + |