use std::collections::HashMap; use axum::{response::Response, Router, routing::{post, patch, delete, get}, extract::{ws::Message, WebSocketUpgrade}}; use serde::Deserialize; use tokio::sync::{Mutex, mpsc::{Sender, self}}; use crate::{ public::docs::{EndpointDocumentation, EndpointMethod}, types::{ extract::{AuthorizedUser, Check, CheckResult, Database, Json, Log}, http::ResponseCode, chat::{ChatRoom, ChatEvent}, user::User, }, }; use std::collections::hash_map::Values; use lazy_static::lazy_static; lazy_static!( static ref CONNECTIONS: Mutex> = Mutex::new(HashMap::new()); ); struct ConnectionPool { inner: HashMap>, index: usize } impl ConnectionPool { fn new() -> Self { Self { inner: HashMap::new(), index: 0 } } fn add(&mut self, send: Sender) -> usize { let idx = self.index; self.index += 1; self.inner.insert(idx, send); idx } fn del(&mut self, idx: &usize) { self.inner.remove(idx); } fn values(&self) -> Values<'_, usize, Sender> { self.inner.values() } } async fn send_event(event: ChatEvent, room: &ChatRoom) { for user in &room.users { let lock = CONNECTIONS.lock().await; let Some(connection) = lock.get(&user) else { continue }; for channel in connection.values() { channel.send(event.clone()).await.ok(); } } } pub const CHAT_LIST: EndpointDocumentation = EndpointDocumentation { uri: "/api/chat/list", method: EndpointMethod::Post, description: "Returns the rooms you are in", body: None, responses: &[ (201, "Returns rooms in a list"), (400, "Body does not match parameters"), (401, "Unauthorized"), (500, "Failed to retrieve rooms"), ], cookie: Some("auth"), }; async fn list ( AuthorizedUser(user): AuthorizedUser, Database(db): Database, _: Log ) -> Response { let Ok(rooms) = ChatRoom::from_user_id(&db, user.user_id) else { return ResponseCode::InternalServerError.text("Failed to retrieve rooms") }; let Ok(json) = serde_json::to_string(&rooms) else { return ResponseCode::InternalServerError.text("Failed to retrieve rooms") }; ResponseCode::Success.json(&json) } pub const CHAT_CREATE: EndpointDocumentation = EndpointDocumentation { uri: "/api/chat/create", method: EndpointMethod::Post, description: "Creates a new room", body: Some( r#" { "name" : "Funny memes" } "#, ), responses: &[ (201, "Successfully created room"), (400, "Body does not match parameters"), (401, "Unauthorized"), (500, "Failed to create room"), ], cookie: Some("auth"), }; #[derive(Deserialize)] struct RoomCreateRequest { name: String, } impl Check for RoomCreateRequest { fn check(&self) -> CheckResult { Self::assert_length( &self.name, 1, 255, "Room names must be between 1-255 characters long", )?; Ok(()) } } async fn create ( AuthorizedUser(user): AuthorizedUser, Database(db): Database, Json(body): Json, ) -> Response { let Ok(room) = ChatRoom::new(&db, vec![user.user_id], body.name) else { return ResponseCode::InternalServerError.text("Failed to create room") }; for user in &room.users { send_event(ChatEvent::Add { user_id: *user, room: room.clone() }, &room).await; } let Ok(json) = serde_json::to_string(&room) else { return ResponseCode::InternalServerError.text("Failed to create room") }; ResponseCode::Created.json(&json) } pub const CHAT_ADD: EndpointDocumentation = EndpointDocumentation { uri: "/api/chat/add", method: EndpointMethod::Patch, description: "Adds a user to a room", body: Some( r#" { "room_id": 69, "email" : "joebide@house.gov" } "#, ), responses: &[ (201, "Successfully added user"), (400, "Body does not match parameters"), (401, "Unauthorized"), (500, "Failed to add user"), ], cookie: Some("auth"), }; #[derive(Deserialize)] struct AddUserRequest { room_id: u64, email: String, } impl Check for AddUserRequest { fn check(&self) -> CheckResult { Ok(()) } } async fn add ( AuthorizedUser(user): AuthorizedUser, Database(db): Database, Json(body): Json, ) -> Response { let Ok(to_add) = User::from_email(&db, &body.email) else { return ResponseCode::BadRequest.text("User does not exist") }; let Ok(mut room) = ChatRoom::from_user_and_room_id(&db, user.user_id, body.room_id) else { return ResponseCode::BadRequest.text("Room doesnt exist or you are not in it") }; if room.users.contains(&to_add.user_id) { return ResponseCode::BadRequest.text("User is already in the room") } let Ok(success) = room.add_user(&db, to_add.user_id) else { return ResponseCode::InternalServerError.text("Failed to add user") }; if !success { return ResponseCode::BadRequest.text("User is already in the room") } room.users.push(to_add.user_id); send_event(ChatEvent::Add { user_id: to_add.user_id, room: room.clone() }, &room).await; ResponseCode::Success.text("Successfully added user") } pub const CHAT_LEAVE: EndpointDocumentation = EndpointDocumentation { uri: "/api/chat/leave", method: EndpointMethod::Delete, description: "Leaves a room", body: Some( r#" { "room_id": 69 } "#, ), responses: &[ (201, "Successfully left room"), (400, "Body does not match parameters"), (401, "Unauthorized"), (500, "Failed to leave a room"), ], cookie: Some("auth"), }; #[derive(Deserialize)] struct LeaveRoomRequest { room_id: u64, } impl Check for LeaveRoomRequest { fn check(&self) -> CheckResult { Ok(()) } } async fn leave ( AuthorizedUser(user): AuthorizedUser, Database(db): Database, Json(body): Json, ) -> Response { let Ok(room) = ChatRoom::from_user_and_room_id(&db, user.user_id, body.room_id) else { return ResponseCode::BadRequest.text("Room doesnt exist or you are not in it") }; let Ok(success) = room.remove_user(&db, user.user_id) else { return ResponseCode::InternalServerError.text("Failed to leave room") }; if !success { return ResponseCode::BadRequest.text("You are currently not in this room (how did this happen?)") } send_event(ChatEvent::Leave { user_id: user.user_id, room_id: room.room_id }, &room).await; ResponseCode::Success.text("Successfully left room") } pub const CHAT_SEND: EndpointDocumentation = EndpointDocumentation { uri: "/api/chat/send", method: EndpointMethod::Post, description: "Send a message to a room", body: Some( r#" { "room_id": 420, "content" : "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" } "#, ), responses: &[ (201, "Successfully sent message"), (400, "Body does not match parameters"), (401, "Unauthorized"), (500, "Failed to send message"), ], cookie: Some("auth"), }; #[derive(Deserialize)] struct SendMessageRequest { room_id: u64, content: String } impl Check for SendMessageRequest { fn check(&self) -> CheckResult { Self::assert_length( &self.content, 1, 500, "Messages must be between 1-500 length" )?; Ok(()) } } async fn send ( AuthorizedUser(user): AuthorizedUser, Database(db): Database, Json(body): Json, ) -> Response { let Ok(room) = ChatRoom::from_user_and_room_id(&db, user.user_id, body.room_id) else { return ResponseCode::BadRequest.text("Room doesnt exist or you are not in it") }; let Ok(msg) = room.send_message(&db, user.user_id, body.content) else { return ResponseCode::InternalServerError.text("Failed to send message") }; send_event(ChatEvent::Message { user_id: msg.user_id, room_id: msg.room_id, message_id: msg.message_id, content: msg.content, date: msg.date }, &room).await; ResponseCode::Created.text("Successfully sent message") } pub const CHAT_LOAD: EndpointDocumentation = EndpointDocumentation { uri: "/api/chat/load", method: EndpointMethod::Post, description: "Get a page of historic room messages starting before given message id", body: Some( r#" { "room_id": 69, "newest_msg": 400, "page": 3 } "#, ), responses: &[ (201, "Successfully sent message"), (400, "Body does not match parameters"), (401, "Unauthorized"), (500, "Failed to send message"), ], cookie: Some("auth"), }; #[derive(Deserialize)] struct LoadMessagesRequest { room_id: u64, newest_msg: u64, page: u64 } impl Check for LoadMessagesRequest { fn check(&self) -> CheckResult { Ok(()) } } async fn load ( AuthorizedUser(user): AuthorizedUser, Database(db): Database, Json(body): Json, ) -> Response { let Ok(room) = ChatRoom::from_user_and_room_id(&db, user.user_id, body.room_id) else { return ResponseCode::BadRequest.text("Room doesnt exist or you are not in it") }; let Ok(msgs) = room.load_old_chat_messages(&db, body.newest_msg, body.page) else { return ResponseCode::InternalServerError.text("Failed to load messages") }; let Ok(json) = serde_json::to_string(&msgs) else { return ResponseCode::InternalServerError.text("Failed to load messages") }; ResponseCode::Created.json(&json) } pub const CHAT_TYPING: EndpointDocumentation = EndpointDocumentation { uri: "/api/chat/typing", method: EndpointMethod::Post, description: "Set if your typing in a given room", body: Some( r#" { "room_id": 69, } "#, ), responses: &[ (201, "Successfully sent typing indicator"), (400, "Body does not match parameters"), (401, "Unauthorized"), (500, "Failed to send typing indicator"), ], cookie: Some("auth"), }; #[derive(Deserialize)] struct TypingRequest { room_id: u64, } impl Check for TypingRequest { fn check(&self) -> CheckResult { Ok(()) } } async fn typing ( AuthorizedUser(user): AuthorizedUser, Database(db): Database, Json(body): Json, ) -> Response { let Ok(room) = ChatRoom::from_user_and_room_id(&db, user.user_id, body.room_id) else { return ResponseCode::BadRequest.text("Room doesnt exist or you are not in it") }; send_event(ChatEvent::Typing { user_id: user.user_id, room_id: room.room_id, }, &room).await; ResponseCode::Success.text("Successfully sent typing indicator") } pub const CHAT_CONNECT: EndpointDocumentation = EndpointDocumentation { uri: "/api/chat/connect", method: EndpointMethod::Get, description: "Start a websocket connection for chat events", body: None, responses: &[], cookie: Some("auth"), }; async fn connect ( AuthorizedUser(user): AuthorizedUser, ws: WebSocketUpgrade ) -> Response { ws.on_upgrade(|mut ws| async move { let user = user; let (send, mut recv) = mpsc::channel::(20); let id: usize; { let mut lock = CONNECTIONS.lock().await; match lock.get_mut(&user.user_id) { Some(pool) => { id = pool.add(send); }, None => { let mut pool = ConnectionPool::new(); id = pool.add(send); lock.insert(user.user_id, pool); } }; } loop { tokio::select! { m = ws.recv() => { let Some(Ok(_)) = m else { break; }; } s = recv.recv() => { let Some(msg) = s else { break; }; if let Ok(string) = serde_json::to_string(&msg) { ws.send(Message::Text(string)).await.ok(); } } } } let mut lock = CONNECTIONS.lock().await; if let Some(conn) = lock.get_mut(&user.user_id) { conn.del(&id); }; }) } pub fn router() -> Router { Router::new() .route("/create", post(create)) .route("/list", post(list)) .route("/add", patch(add)) .route("/leave", delete(leave)) .route("/send", post(send)) .route("/load", post(load)) .route("/typing", post(typing)) .route("/connect", get(connect)) }