#include #include #include #include "resolver.h" #include "addr.h" #include "binding.h" #include "../io/log.h" static bool lookup( Question* question, Packet* response, BindingType type, SocketAddr addr ) { INIT_LOG_BUFFER(log) LOGONLY(print_socket_addr(&addr, log);) TRACE("Attempting lookup on fallback dns %s", log); Connection request; if (!create_request(type, &addr, &request)) { return false; } Packet req; memset(&req, 0, sizeof(Packet)); req.header.id = response->header.id; req.header.opcode = response->header.opcode; req.header.questions = 1; req.header.recursion_desired = true; req.questions = malloc(sizeof(Question)); req.questions[0] = *question; if (!request_packet(&request, &req, response)) { free_request(&request); free(req.questions); ERROR("Failed to request fallback dns: %s", strerror(errno)); return false; } free_request(&request); free(req.questions); return true; } static bool search(Question* question, Packet* result, BindingType type, const RecordMap* map) { IpAddr addr; uint8_t ip[4] = {9,9,9,9}; create_ip_addr(ip, &addr); uint16_t port = 53; SocketAddr saddr; create_socket_addr(port, addr, &saddr); while(1) { if (record_map_get(map, question, result)) { TRACE("Found answer in user defined config"); return true; } if (!lookup(question, result, type, saddr)) { return false; } if (result->header.answers > 0 && result->header.rescode == NOERROR) { return true; } if (result->header.rescode == NXDOMAIN) { return true; } if (get_resolved_ns(result, question->domain, &addr)) { continue; } Question new_question; if (!get_unresoled_ns(result, question->domain, &new_question)) { return true; } Packet recurse; if (!search(&new_question, &recurse, type, map)) { return false; } free_question(&new_question); IpAddr random; if (!get_random_a(&recurse, &random)) { free_packet(&recurse); return true; } else { free_packet(&recurse); addr = random; } } } static void push_records(Record* from, uint8_t from_len, Record** to, uint8_t to_len) { if(from_len < 1) return; *to = realloc(*to, sizeof(Record) * (from_len + to_len)); memcpy(*to + to_len, from, from_len * sizeof(Record)); } static void push_questions(Question* from, uint8_t from_len, Question** to, uint8_t to_len) { if(from_len < 1) return; *to = realloc(*to, sizeof(Question) * (from_len + to_len)); memcpy(*to + to_len, from, from_len * sizeof(Question)); } void handle_query(const Packet* request, Packet* response, BindingType type, const RecordMap* map) { memset(response, 0, sizeof(Packet)); response->header.id = request->header.id; response->header.opcode = request->header.opcode; response->header.recursion_desired = true; response->header.recursion_available = true; response->header.response = true; if (request->header.questions < 1) { response->header.response = FORMERR; return; } for (uint16_t i = 0; i < request->header.questions; i++) { Packet result; memset(&result, 0, sizeof(Packet)); result.header.id = response->header.id; if (!search(&request->questions[i], &result, type, map)) { response->header.response = SERVFAIL; break; } push_questions( result.questions, result.header.questions, &response->questions, response->header.questions ); response->header.questions += result.header.questions; push_records( result.answers, result.header.answers, &response->answers, response->header.answers ); response->header.answers += result.header.answers; push_records( result.authorities, result.header.authoritative_entries, &response->authorities, response->header.authoritative_entries ); response->header.authoritative_entries += result.header.authoritative_entries; push_records( result.resources, result.header.resource_entries, &response->resources, response->header.resource_entries ); response->header.resource_entries += result.header.resource_entries; if (result.header.z == false) { // Not from cache free(result.questions); free(result.answers); free(result.authorities); free(result.resources); } else { response->header.z = true; } } }