use rusqlite::OptionalExtension; use tracing::instrument; use crate::{database, types::like::Like}; pub fn init() -> Result<(), rusqlite::Error> { let sql = " CREATE TABLE IF NOT EXISTS likes ( user_id INTEGER NOT NULL, post_id INTEGER NOT NULL, FOREIGN KEY(user_id) REFERENCES users(user_id), FOREIGN KEY(post_id) REFERENCES posts(post_id), PRIMARY KEY (user_id, post_id) ); "; let conn = database::connect()?; conn.execute(sql, ())?; Ok(()) } #[instrument()] pub fn get_like_count(post_id: u64) -> Result, rusqlite::Error> { tracing::trace!("Retrieving like count"); let conn = database::connect()?; let mut stmt = conn.prepare("SELECT COUNT(post_id) FROM likes WHERE post_id = ?")?; let row = stmt .query_row([post_id], |row| { let row = row.get(0)?; Ok(row) }) .optional()?; Ok(row) } #[instrument()] pub fn get_liked(user_id: u64, post_id: u64) -> Result { tracing::trace!("Retrieving if liked"); let conn = database::connect()?; let mut stmt = conn.prepare("SELECT * FROM likes WHERE user_id = ? AND post_id = ?")?; let liked = stmt.query_row([user_id, post_id], |_| Ok(())).optional()?; Ok(liked.is_some()) } #[instrument()] pub fn add_liked(user_id: u64, post_id: u64) -> Result { tracing::trace!("Adding like"); let conn = database::connect()?; let mut stmt = conn.prepare("INSERT OR REPLACE INTO likes (user_id, post_id) VALUES (?,?)")?; let changes = stmt.execute([user_id, post_id])?; Ok(changes == 1) } #[instrument()] pub fn remove_liked(user_id: u64, post_id: u64) -> Result { tracing::trace!("Removing like"); let conn = database::connect()?; let mut stmt = conn.prepare("DELETE FROM likes WHERE user_id = ? AND post_id = ?;")?; let changes = stmt.execute((user_id, post_id))?; Ok(changes == 1) } #[instrument()] pub fn get_all_likes() -> Result, rusqlite::Error> { tracing::trace!("Retrieving comments page"); let conn = database::connect()?; let mut stmt = conn.prepare("SELECT * FROM likes")?; let row = stmt.query_map([], |row| { let like = Like { user_id: row.get(0)?, post_id: row.get(1)?, }; Ok(like) })?; Ok(row.into_iter().flatten().collect()) }