use crate::prelude::*;
use std::io::{self, Read, Write};

mod serialize;
mod deserialize;
mod prim;

/// A program in the binary format: a format version plus the function it carries.
pub struct Program {
    /// Format version; `Program::save` always writes `0`.
    version: u8,
    /// The function stored in (or loaded from) the binary image.
    fun: Rc<Function>,
}

/// Magic prefix (`\0`, `M`, `A`, `T`, `\n`) that `Program::load` checks for.
const PROGRAM_HEADER: [u8; 5] = *b"\0MAT\n";

impl Program {
    /// Attempts to decode a serialized program from `body`.
    ///
    /// Returns `Ok(None)` when `body` is shorter than 6 bytes or does not begin
    /// with [`PROGRAM_HEADER`]. Otherwise the image is decoded, the trailing
    /// 8-byte checksum is verified, and the stored function is returned.
    ///
    /// # Errors
    /// Fails if decoding fails or the trailing checksum does not match.
    pub fn load(body: &str) -> Result<Option<Rc<Function>>> {
        let mut input = body.as_bytes();
        // At least one byte must follow the 5-byte header.
        let has_header = input.len() >= 6 && input[..5] == PROGRAM_HEADER;
        if !has_header {
            return Ok(None);
        }
        let mut decoder = ProgramDeserializer::from(&mut input);
        let Program { fun, .. } = Self::deserialize(&mut decoder)?;
        decoder.finish()?;
        Ok(Some(fun))
    }

    /// Serializes `fun` as a version-0 program into `w`, followed by the checksum.
    ///
    /// # Errors
    /// Propagates any serialization or write failure.
    pub fn save<W: Write>(fun: Rc<Function>, w: &mut W) -> Result<()> {
        let image = Program { version: 0, fun };
        let mut encoder = ProgramSerializer::from(w);
        encoder.serialize(&image)?;
        encoder.finish()
    }
}

/// A leaf value that writes itself to, and reads itself from, a byte stream.
///
/// Implementations work on plain [`io::Result`]; `ProgramSerializer` and
/// `ProgramDeserializer` convert those errors with `error!`.
pub trait Primitive
where
    Self: Sized,
{
    /// Writes `self` to `w`.
    fn write<W: Write>(&self, w: &mut W) -> io::Result<()>;
    /// Reads a value of this type from `r`.
    fn read<R: Read>(r: &mut R) -> io::Result<Self>;
}

/// A value that can encode itself through any [`Serializer`].
pub trait Serialize
where
    Self: Sized,
{
    /// Encodes `self` into `s`.
    fn serialize<S: Serializer>(&self, s: &mut S) -> Result<()>;
}

/// A sink for encoded data, built on top of [`Primitive`] writes.
pub trait Serializer
where
    Self: Sized,
{
    /// Encodes `val` by delegating to its [`Serialize`] implementation.
    fn serialize<S: Serialize>(&mut self, val: &S) -> Result<()> {
        S::serialize(val, self)
    }
    /// Writes a single primitive value.
    fn write<P: Primitive>(&mut self, val: P) -> Result<()>;
}

/// A value that can decode itself from any [`Deserializer`].
pub trait Deserialize
where
    Self: Sized,
{
    /// Decodes a new value from `s`.
    fn deserialize<S: Deserializer>(s: &mut S) -> Result<Self>;
}

/// A source of encoded data, built on top of [`Primitive`] reads.
pub trait Deserializer
where
    Self: Sized,
{
    /// Decodes a `D` by delegating to its [`Deserialize`] implementation.
    fn deserialize<D: Deserialize>(&mut self) -> Result<D> {
        <D as Deserialize>::deserialize(self)
    }
    /// Reads a single primitive value.
    fn read<P: Primitive>(&mut self) -> Result<P>;
}

// Shorthand for `exception!(BINARY_EXCEPTION, ...)`. The tokens are forwarded
// verbatim. It is used below to build the error values returned by the
// (de)serializers.
macro_rules! error {
    ($($tok:tt)*) => (
        exception!(BINARY_EXCEPTION, $($tok)*)
    );
}

/// A [`Serializer`] that forwards bytes to `writer` while maintaining a running checksum.
pub struct ProgramSerializer<'w, W: Write> {
    /// Destination stream.
    writer: &'w mut W,
    /// Running checksum, updated by the `Write` impl and emitted by `finish`.
    checksum: u64,
}

impl<'w, W: Write> ProgramSerializer<'w, W> {
    fn finish(self) -> Result<()> {
        let bytes = self.checksum.to_le_bytes();
        self.writer.write(&bytes).map_err(|e| error!("{e}"))?;
        Ok(())
    }
}

impl<'w, W: Write> Serializer for ProgramSerializer<'w, W> {
    fn write<P: Primitive>(&mut self, val: P) -> Result<()> {
        val.write(self).map_err(|e| error!("{e}"))
    }
}

impl<'w, W: Write> Write for ProgramSerializer<'w, W> {
    /// Forwards `buf` to the inner writer and folds into the running checksum
    /// only the bytes it actually accepted.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        // The inner writer may take just a prefix of `buf` (callers such as
        // `write_all` then retry with the rest), or it may fail outright.
        // Checksumming before knowing how much was written would count retried
        // bytes twice and count failed writes at all. The trailer would then
        // disagree with the `Read` side, which sums only the bytes it received.
        let written = self.writer.write(buf)?;
        for &b in &buf[..written] {
            // Reducing before adding keeps the value below 2^32, so it never overflows.
            self.checksum %= 0xf1e3beef;
            self.checksum += u64::from(b);
        }
        Ok(written)
    }

    /// Flushes the inner writer.
    fn flush(&mut self) -> io::Result<()> {
        self.writer.flush()
    }
}

impl<'w, W: Write> From<&'w mut W> for ProgramSerializer<'w, W> {
    fn from(writer: &'w mut W) -> Self {
        Self { writer, checksum: 0xfe }
    }
}

/// A [`Deserializer`] that pulls bytes from `reader` while maintaining a running checksum.
pub struct ProgramDeserializer<'r, R: Read> {
    /// Source stream.
    reader: &'r mut R,
    /// Running checksum, updated by the `Read` impl and verified by `finish`.
    checksum: u64,
}

impl<'r, R: Read> ProgramDeserializer<'r, R> {
    fn finish(self) -> Result<()> {
        let mut bytes = [0u8; 8];
        self.reader.read_exact(&mut bytes).map_err(|e| error!("{e}"))?;
        let checksum = u64::from_le_bytes(bytes);
        if self.checksum != checksum {
            return Err(error!("checksum doesnt match"))
        }
        Ok(())
    }
}

impl<'r, R: Read> Deserializer for ProgramDeserializer<'r, R> {
    fn read<P: Primitive>(&mut self) -> Result<P> {
        P::read(self).map_err(|e| error!("{e}"))
    }
}

impl<'r, R: Read> Read for ProgramDeserializer<'r, R> {
    /// Reads from the inner reader and folds every byte actually received into
    /// the running checksum.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let received = self.reader.read(buf)?;
        let mut sum = self.checksum;
        for &byte in &buf[..received] {
            // Reducing before adding keeps the value below 2^32.
            sum = sum % 0xf1e3beef + u64::from(byte);
        }
        self.checksum = sum;
        Ok(received)
    }
}

impl<'r, R: Read> From<&'r mut R> for ProgramDeserializer<'r, R> {
    fn from(reader: &'r mut R) -> Self {
        Self { reader, checksum: 0xfe }
    }
}