summaryrefslogtreecommitdiff
path: root/src/server
diff options
context:
space:
mode:
Diffstat (limited to 'src/server')
-rw-r--r--src/server/addr.c8
-rw-r--r--src/server/addr.h4
-rw-r--r--src/server/binding.c6
-rw-r--r--src/server/resolver.c29
-rw-r--r--src/server/resolver.h3
-rw-r--r--src/server/server.c112
6 files changed, 101 insertions, 61 deletions
diff --git a/src/server/addr.c b/src/server/addr.c
index 982da13..14f44c6 100644
--- a/src/server/addr.c
+++ b/src/server/addr.c
@@ -7,12 +7,12 @@
#include "addr.h"
#include "../io/log.h"
-void create_ip_addr(char* domain, IpAddr* addr) {
+void create_ip_addr(uint8_t* domain, IpAddr* addr) {
addr->type = V4;
memcpy(&addr->data.v4.s_addr, domain, 4);
}
-void create_ip_addr6(char* domain, IpAddr* addr) {
+void create_ip_addr6(uint8_t* domain, IpAddr* addr) {
addr->type = V6;
memcpy(&addr->data.v6.__in6_u.__u6_addr8, domain, 16);
}
@@ -76,7 +76,7 @@ void print_socket_addr(SocketAddr* addr, char* buffer) {
(uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr >> 16),
(uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr >> 8),
(uint8_t) ((uint32_t)addr->data.v4.sin_addr.s_addr),
- addr->data.v4.sin_port
+ ntohs(addr->data.v4.sin_port)
);
} else {
for(int i = 0; i < 8; i++) {
@@ -85,7 +85,7 @@ void print_socket_addr(SocketAddr* addr, char* buffer) {
addr->data.v6.sin6_addr.__in6_u.__u6_addr8[i*2 + 1]
);
}
- APPEND(buffer, ":[%hu]", addr->data.v6.sin6_port);
+ APPEND(buffer, ":[%hu]", ntohs(addr->data.v6.sin6_port));
}
}
diff --git a/src/server/addr.h b/src/server/addr.h
index 173c7fd..1210850 100644
--- a/src/server/addr.h
+++ b/src/server/addr.h
@@ -20,8 +20,8 @@ typedef struct {
} data;
} IpAddr;
-void create_ip_addr(char* domain, IpAddr* addr);
-void create_ip_addr6(char* domain, IpAddr* addr);
+void create_ip_addr(uint8_t* domain, IpAddr* addr);
+void create_ip_addr6(uint8_t* domain, IpAddr* addr);
void ip_addr_any(IpAddr* addr);
void ip_addr_any6(IpAddr* addr);
diff --git a/src/server/binding.c b/src/server/binding.c
index 47c62c6..157d1d4 100644
--- a/src/server/binding.c
+++ b/src/server/binding.c
@@ -144,10 +144,10 @@ bool read_connection(Connection* connection, Packet* packet) {
}
static bool write_udp(Connection* connection, uint8_t* buf, uint16_t len) {
- //if (len > 512) {
+ if (len > 512) {
buf[2] = buf[2] | 0x03;
- // len = 512;
- // }
+ len = 512;
+ }
return write_udp_socket(
&connection->sock.udp.udp,
buf,
diff --git a/src/server/resolver.c b/src/server/resolver.c
index e05f365..a1fa82a 100644
--- a/src/server/resolver.c
+++ b/src/server/resolver.c
@@ -43,9 +43,9 @@ static bool lookup(
return true;
}
-static bool search(Question* question, Packet* result, BindingType type) {
+static bool search(Question* question, Packet* result, BindingType type, const RecordMap* map) {
IpAddr addr;
- char ip[4] = {1, 1, 1, 1};
+ uint8_t ip[4] = {1, 1, 1, 1};
create_ip_addr(ip, &addr);
uint16_t port = 53;
@@ -53,6 +53,10 @@ static bool search(Question* question, Packet* result, BindingType type) {
create_socket_addr(port, addr, &saddr);
while(1) {
+ if (record_map_get(map, question, result)) {
+ return true;
+ }
+
if (!lookup(question, result, type, saddr)) {
return false;
}
@@ -75,7 +79,7 @@ static bool search(Question* question, Packet* result, BindingType type) {
}
Packet recurse;
- if (!search(&new_question, &recurse, type)) {
+ if (!search(&new_question, &recurse, type, map)) {
return false;
}
@@ -104,7 +108,7 @@ static void push_questions(Question* from, uint8_t from_len, Question** to, uint
memcpy(*to + to_len, from, from_len * sizeof(Question));
}
-void handle_query(Packet* request, Packet* response, BindingType type) {
+void handle_query(const Packet* request, Packet* response, BindingType type, const RecordMap* map) {
memset(response, 0, sizeof(Packet));
response->header.id = request->header.id;
response->header.opcode = request->header.opcode;
@@ -121,7 +125,7 @@ void handle_query(Packet* request, Packet* response, BindingType type) {
Packet result;
memset(&result, 0, sizeof(Packet));
result.header.id = response->header.id;
- if (!search(&request->questions[i], &result, type)) {
+ if (!search(&request->questions[i], &result, type, map)) {
response->header.response = SERVFAIL;
break;
}
@@ -158,9 +162,14 @@ void handle_query(Packet* request, Packet* response, BindingType type) {
);
response->header.resource_entries += result.header.resource_entries;
- free(result.questions);
- free(result.answers);
- free(result.authorities);
- free(result.resources);
+ if (result.header.z == false) {
+ // Not from cache
+ free(result.questions);
+ free(result.answers);
+ free(result.authorities);
+ free(result.resources);
+ } else {
+ response->header.z = true;
+ }
}
-} \ No newline at end of file
+}
diff --git a/src/server/resolver.h b/src/server/resolver.h
index 79b4825..993b515 100644
--- a/src/server/resolver.h
+++ b/src/server/resolver.h
@@ -1,6 +1,7 @@
#pragma once
#include "../packet/packet.h"
+#include "../io/map.h"
#include "binding.h"
-void handle_query(Packet* request, Packet* response, BindingType type); \ No newline at end of file
+void handle_query(const Packet* request, Packet* response, BindingType type, const RecordMap* map);
diff --git a/src/server/server.c b/src/server/server.c
index c8975ee..27cc9d9 100644
--- a/src/server/server.c
+++ b/src/server/server.c
@@ -1,25 +1,67 @@
-#define _POSIX_SOURCE
+#define _POSIX_C_SOURCE 200809L
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <sys/wait.h>
#include <signal.h>
+#include <pthread.h>
#include "addr.h"
#include "server.h"
#include "resolver.h"
#include "../io/log.h"
+#include "../io/map.h"
+#include "../io/config.h"
-static pid_t udp, tcp;
+static pthread_t udp, tcp;
+static RecordMap map;
void server_init(uint16_t port, Server* server) {
INFO("Server port set to %hu", port);
create_binding(UDP, port, &server->udp);
create_binding(TCP, port, &server->tcp);
+ load_config("config", &map);
}
-static void server_listen(Binding* binding) {
+struct DnsRequest {
+ Connection connection;
+ Packet request;
+};
+
+static void* server_respond(void* arg) {
+ struct DnsRequest req = *(struct DnsRequest*) arg;
+
+ INFO("Recieved packet request ID %hu", req.request.header.id);
+
+ Packet response;
+ handle_query(&req.request, &response, req.connection.type, &map);
+
+ if (!write_connection(&req.connection, &response)) {
+ ERROR("Faled to respond to connection ID %hu: %s",
+ req.request.header.id,
+ strerror(errno)
+ );
+ }
+
+ free_packet(&req.request);
+ free_connection(&req.connection);
+
+ if (response.header.z == false) {
+ free_packet(&response);
+ } else {
+ // Dont free from config
+ free(response.questions);
+ free(response.answers);
+ free(response.authorities);
+ free(response.resources);
+ }
+
+ return NULL;
+}
+
+static void* server_listen(void* arg) {
+ Binding* binding = (Binding*) arg;
while(1) {
Connection connection;
@@ -34,67 +76,55 @@ static void server_listen(Binding* binding) {
free_connection(&connection);
continue;
}
+
+ struct DnsRequest req;
+ req.connection = connection;
+ req.request = request;
- if(fork() != 0) {
+ pthread_t thread;
+ if(pthread_create(&thread, NULL, &server_respond, &req)) {
+ ERROR("Failed to create thread for dns request ID %hu: %s",
+ request.header.id,
+ strerror(errno)
+ );
free_packet(&request);
free_connection(&connection);
continue;
}
-
- INFO("Recieved packet request ID %hu", request.header.id);
-
- Packet response;
- handle_query(&request, &response, connection.type);
-
- if (!write_connection(&connection, &response)) {
- ERROR("Failed to respond to connection ID %hu: %s", request.header.id, strerror(errno));
- }
-
- free_packet(&request);
- free_packet(&response);
- free_connection(&connection);
- exit(EXIT_SUCCESS);
}
+
+ return NULL;
}
static void signal_handler() {
printf("\n");
- kill(udp, SIGTERM);
- kill(tcp, SIGTERM);
+ pthread_kill(udp, SIGTERM);
+ pthread_kill(tcp, SIGTERM);
}
void server_run(Server* server) {
- if ((udp = fork()) == 0) {
+ if (!pthread_create(&udp, NULL, &server_listen, &server->udp)) {
INFO("Listening for connections on UDP");
- server_listen(&server->udp);
- exit(EXIT_SUCCESS);
+ } else {
+ ERROR("Failed to start UDP thread");
+ exit(EXIT_FAILURE);
}
- if ((tcp = fork()) == 0) {
+ if (!pthread_create(&tcp, NULL, &server_listen, &server->tcp)) {
INFO("Listening for connections on TCP");
- server_listen(&server->tcp);
- exit(EXIT_SUCCESS);
- }
-
- signal(SIGINT, signal_handler);
-
- int status;
- waitpid(udp, &status, 0);
- if (status == 0) {
- INFO("UDP process closed successfully");
} else {
- ERROR("UDP process failed with error code %d", status);
+ ERROR("Failed to start TCP thread");
+ exit(EXIT_FAILURE);
}
+ signal(SIGINT, signal_handler);
- waitpid(tcp, &status, 0);
- if (status == 0) {
- INFO("TCP process closed successfully");
- } else {
- ERROR("TCP process failed with error code %d", status);
- }
+ pthread_join(udp, NULL);
+ pthread_join(tcp, NULL);
}
void server_free(Server* server) {
free_binding(&server->udp);
free_binding(&server->tcp);
+ record_map_free(&map);
}
+