diff options
Diffstat (limited to 'src/api/chat.rs')
-rw-r--r-- | src/api/chat.rs | 193 |
1 files changed, 187 insertions, 6 deletions
diff --git a/src/api/chat.rs b/src/api/chat.rs index 0f3b92e..1e56c3e 100644 --- a/src/api/chat.rs +++ b/src/api/chat.rs @@ -1,13 +1,63 @@ -use axum::{response::Response, Router, routing::{post, patch, delete}}; +use std::collections::HashMap; + +use axum::{response::Response, Router, routing::{post, patch, delete, get}, extract::{ws::Message, WebSocketUpgrade}}; use serde::Deserialize; +use tokio::sync::{Mutex, mpsc::{Sender, self}}; use crate::{ public::docs::{EndpointDocumentation, EndpointMethod}, types::{ extract::{AuthorizedUser, Check, CheckResult, Database, Json}, http::ResponseCode, - chat::ChatRoom, user::User, + chat::{ChatRoom, ChatEvent}, user::User, }, }; +use std::collections::hash_map::Values; +use lazy_static::lazy_static; + +lazy_static!( + static ref CONNECTIONS: Mutex<HashMap<u64, ConnectionPool>> = Mutex::new(HashMap::new()); +); + +struct ConnectionPool { + inner: HashMap<usize, Sender<ChatEvent>>, + index: usize +} + +impl ConnectionPool { + fn new() -> Self { + Self { + inner: HashMap::new(), + index: 0 + } + } + + fn add(&mut self, send: Sender<ChatEvent>) -> usize { + let idx = self.index; + self.index += 1; + self.inner.insert(idx, send); + idx + } + + fn del(&mut self, idx: &usize) { + self.inner.remove(idx); + } + + fn values(&self) -> Values<'_, usize, Sender<ChatEvent>> { + self.inner.values() + } +} + +async fn send_event(event: ChatEvent, room: &ChatRoom) { + for user in &room.users { + let lock = CONNECTIONS.lock().await; + let Some(connection) = lock.get(&user) else { + continue + }; + for channel in connection.values() { + channel.send(event.clone()).await.ok(); + } + } +} pub const CHAT_LIST: EndpointDocumentation = EndpointDocumentation { uri: "/api/chat/list", @@ -80,11 +130,18 @@ async fn create ( Database(db): Database, Json(body): Json<RoomCreateRequest>, ) -> Response { - let Ok(post) = ChatRoom::new(&db, vec![user.user_id], body.name) else { + let Ok(room) = ChatRoom::new(&db, vec![user.user_id], body.name) else { return ResponseCode::InternalServerError.text("Failed to create room") }; - let Ok(json) = serde_json::to_string(&post) else { + for user in &room.users { + send_event(ChatEvent::Add { + user_id: *user, + room_id: room.room_id + }, &room).await; + } + + let Ok(json) = serde_json::to_string(&room) else { return ResponseCode::InternalServerError.text("Failed to create room") }; @@ -145,6 +202,11 @@ async fn add ( if !success { return ResponseCode::BadRequest.text("User is already in the room") } + + send_event(ChatEvent::Add { + user_id: to_add.user_id, + room_id: room.room_id + }, &room).await; ResponseCode::Success.text("Successfully added user") } @@ -197,6 +259,11 @@ async fn leave ( if !success { return ResponseCode::BadRequest.text("You are currently not in this room (how did this happen?)") } + + send_event(ChatEvent::Leave { + user_id: user.user_id, + room_id: room.room_id + }, &room).await; ResponseCode::Success.text("Successfully left room") } @@ -250,9 +317,17 @@ async fn send ( return ResponseCode::BadRequest.text("Room doesnt exist or you are not in it") }; - let Ok(_msg) = room.send_message(&db, user.user_id, body.content) else { + let Ok(msg) = room.send_message(&db, user.user_id, body.content) else { return ResponseCode::InternalServerError.text("Failed to send message") }; + + send_event(ChatEvent::Message { + user_id: msg.user_id, + room_id: msg.room_id, + message_id: msg.message_id, + content: msg.content, + date: msg.date + }, &room).await; ResponseCode::Created.text("Successfully sent message") } @@ -302,7 +377,7 @@ async fn load ( return ResponseCode::BadRequest.text("Room doesnt exist or you are not in it") }; - let Ok(msgs) = room.load_old_chat_messagegs(&db, body.newest_msg, body.page) else { + let Ok(msgs) = room.load_old_chat_messages(&db, body.newest_msg, body.page) else { return ResponseCode::InternalServerError.text("Failed to load messages") }; @@ -313,6 +388,110 @@ async fn load ( ResponseCode::Created.json(&json) } +pub const CHAT_TYPING: EndpointDocumentation = EndpointDocumentation { + uri: "/api/chat/typing", + method: EndpointMethod::Post, + description: "Set if your typing in a given room", + body: Some( + r#" + { + "room_id": 69, + } + "#, + ), + responses: &[ + (201, "Successfully sent typing indicator"), + (400, "Body does not match parameters"), + (401, "Unauthorized"), + (500, "Failed to send typing indicator"), + ], + cookie: Some("auth"), +}; + +#[derive(Deserialize)] +struct TypingRequest { + room_id: u64, +} + +impl Check for TypingRequest { + fn check(&self) -> CheckResult { + Ok(()) + } +} + +async fn typing ( + AuthorizedUser(user): AuthorizedUser, + Database(db): Database, + Json(body): Json<TypingRequest>, +) -> Response { + + let Ok(room) = ChatRoom::from_user_and_room_id(&db, user.user_id, body.room_id) else { + return ResponseCode::BadRequest.text("Room doesnt exist or you are not in it") + }; + + send_event(ChatEvent::Typing { + user_id: user.user_id, + room_id: room.room_id, + }, &room).await; + + ResponseCode::Success.text("Successfully sent typing indicator") +} + +pub const CHAT_CONNECT: EndpointDocumentation = EndpointDocumentation { + uri: "/api/chat/connect", + method: EndpointMethod::Get, + description: "Start a websocket connection for chat events", + body: None, + responses: &[], + cookie: Some("auth"), +}; + +async fn connect ( + AuthorizedUser(user): AuthorizedUser, + ws: WebSocketUpgrade +) -> Response { + ws.on_upgrade(|mut ws| async move { + let user = user; + let (send, mut recv) = mpsc::channel::<ChatEvent>(20); + let id: usize; + { + let mut lock = CONNECTIONS.lock().await; + match lock.get_mut(&user.user_id) { + Some(pool) => { + id = pool.add(send); + }, + None => { + let mut pool = ConnectionPool::new(); + id = pool.add(send); + lock.insert(user.user_id, pool); + } + }; + } + loop { + tokio::select! { + m = ws.recv() => { + let Some(Ok(_)) = m else { + break; + }; + } + s = recv.recv() => { + let Some(msg) = s else { + break; + }; + if let Ok(string) = serde_json::to_string(&msg) { + ws.send(Message::Text(string)).await.ok(); + } + } + } + } + + let mut lock = CONNECTIONS.lock().await; + if let Some(conn) = lock.get_mut(&user.user_id) { + conn.del(&id); + }; + }) +} + pub fn router() -> Router { Router::new() .route("/create", post(create)) @@ -321,4 +500,6 @@ pub fn router() -> Router { .route("/leave", delete(leave)) .route("/send", post(send)) .route("/load", post(load)) + .route("/typing", post(typing)) + .route("/connect", get(connect)) } |