summaryrefslogtreecommitdiff
path: root/src/dns/packet/buffer.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/dns/packet/buffer.rs')
-rw-r--r--src/dns/packet/buffer.rs227
1 files changed, 227 insertions, 0 deletions
diff --git a/src/dns/packet/buffer.rs b/src/dns/packet/buffer.rs
new file mode 100644
index 0000000..058156e
--- /dev/null
+++ b/src/dns/packet/buffer.rs
@@ -0,0 +1,227 @@
+use crate::Result;
+
+pub struct PacketBuffer {
+ pub buf: Vec<u8>,
+ pub pos: usize,
+ pub size: usize,
+}
+
+impl PacketBuffer {
+ pub fn new(buf: Vec<u8>) -> Self {
+ Self {
+ size: buf.len(),
+ buf,
+ pos: 0,
+ }
+ }
+
+ pub fn pos(&self) -> usize {
+ self.pos
+ }
+
+ pub fn step(&mut self, steps: usize) -> Result<()> {
+ self.pos += steps;
+
+ Ok(())
+ }
+
+ pub fn seek(&mut self, pos: usize) -> Result<()> {
+ self.pos = pos;
+
+ Ok(())
+ }
+
+ pub fn read(&mut self) -> Result<u8> {
+ if self.pos >= self.size {
+ return Err("Tried to read past end of buffer".into());
+ }
+ let res = self.buf[self.pos];
+ self.pos += 1;
+ Ok(res)
+ }
+
+ pub fn get(&mut self, pos: usize) -> Result<u8> {
+ if pos >= self.size {
+ return Err("Tried to read past end of buffer".into());
+ }
+ Ok(self.buf[pos])
+ }
+
+ pub fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]> {
+ if start + len >= self.size {
+ return Err("Tried to read past end of buffer".into());
+ }
+ Ok(&self.buf[start..start + len])
+ }
+
+ pub fn read_u16(&mut self) -> Result<u16> {
+ let res = ((self.read()? as u16) << 8) | (self.read()? as u16);
+
+ Ok(res)
+ }
+
+ pub fn read_u32(&mut self) -> Result<u32> {
+ let res = ((self.read()? as u32) << 24)
+ | ((self.read()? as u32) << 16)
+ | ((self.read()? as u32) << 8)
+ | (self.read()? as u32);
+
+ Ok(res)
+ }
+
+ pub fn read_qname(&mut self, outstr: &mut String) -> Result<()> {
+ let mut pos = self.pos();
+ let mut jumped = false;
+
+ let mut delim = "";
+ let max_jumps = 5;
+ let mut jumps_performed = 0;
+ loop {
+ // Dns Packets are untrusted data, so we need to be paranoid. Someone
+ // can craft a packet with a cycle in the jump instructions. This guards
+ // against such packets.
+ if jumps_performed > max_jumps {
+ return Err(format!("Limit of {max_jumps} jumps exceeded").into());
+ }
+
+ let len = self.get(pos)?;
+
+ if (len & 0xC0) == 0xC0 {
+ if !jumped {
+ self.seek(pos + 2)?;
+ }
+
+ let b2 = self.get(pos + 1)? as u16;
+ let offset = (((len as u16) ^ 0xC0) << 8) | b2;
+ pos = offset as usize;
+ jumped = true;
+ jumps_performed += 1;
+ continue;
+ }
+
+ pos += 1;
+
+ if len == 0 {
+ break;
+ }
+
+ outstr.push_str(delim);
+
+ let str_buffer = self.get_range(pos, len as usize)?;
+ outstr.push_str(&String::from_utf8_lossy(str_buffer).to_lowercase());
+
+ delim = ".";
+
+ pos += len as usize;
+ }
+
+ if !jumped {
+ self.seek(pos)?;
+ }
+
+ Ok(())
+ }
+
+ pub fn read_string(&mut self, outstr: &mut String) -> Result<()> {
+ let len = self.read()?;
+
+ self.read_string_n(outstr, len)?;
+
+ Ok(())
+ }
+
+ pub fn read_string_n(&mut self, outstr: &mut String, len: u8) -> Result<()> {
+ let mut pos = self.pos;
+
+ let str_buffer = self.get_range(pos, len as usize)?;
+
+ let mut i = 0;
+ for b in str_buffer {
+ let c = *b as char;
+ if c == '\0' {
+ break;
+ }
+ outstr.push(c);
+ i += 1;
+ }
+
+ pos += i;
+ self.seek(pos)?;
+
+ Ok(())
+ }
+
+ pub fn write(&mut self, val: u8) -> Result<()> {
+ if self.size < self.pos {
+ self.size = self.pos;
+ }
+
+ if self.buf.len() <= self.size {
+ self.buf.resize(self.size + 1, 0x00);
+ }
+
+ self.buf[self.pos] = val;
+ self.pos += 1;
+ Ok(())
+ }
+
+ pub fn write_u8(&mut self, val: u8) -> Result<()> {
+ self.write(val)?;
+
+ Ok(())
+ }
+
+ pub fn write_u16(&mut self, val: u16) -> Result<()> {
+ self.write((val >> 8) as u8)?;
+ self.write((val & 0xFF) as u8)?;
+
+ Ok(())
+ }
+
+ pub fn write_u32(&mut self, val: u32) -> Result<()> {
+ self.write(((val >> 24) & 0xFF) as u8)?;
+ self.write(((val >> 16) & 0xFF) as u8)?;
+ self.write(((val >> 8) & 0xFF) as u8)?;
+ self.write((val & 0xFF) as u8)?;
+
+ Ok(())
+ }
+
+ pub fn write_qname(&mut self, qname: &str) -> Result<()> {
+ for label in qname.split('.') {
+ let len = label.len();
+
+ self.write_u8(len as u8)?;
+ for b in label.as_bytes() {
+ self.write_u8(*b)?;
+ }
+ }
+
+ if !qname.is_empty() {
+ self.write_u8(0)?;
+ }
+
+ Ok(())
+ }
+
+ pub fn write_string(&mut self, text: &str) -> Result<()> {
+ for b in text.as_bytes() {
+ self.write_u8(*b)?;
+ }
+
+ Ok(())
+ }
+
+ pub fn set(&mut self, pos: usize, val: u8) -> Result<()> {
+ self.buf[pos] = val;
+
+ Ok(())
+ }
+
+ pub fn set_u16(&mut self, pos: usize, val: u16) -> Result<()> {
+ self.set(pos, (val >> 8) as u8)?;
+ self.set(pos + 1, (val & 0xFF) as u8)?;
+
+ Ok(())
+ }
+}