diff options
author | tylerm <tylerm@tylerm.dev> | 2023-08-22 04:16:31 +0000 |
---|---|---|
committer | tylerm <tylerm@tylerm.dev> | 2023-08-22 04:16:31 +0000 |
commit | edbbdf72c78536c48357a86181bbf6897fc52074 (patch) | |
tree | 91d91e9dfb77ae3b7d75f4348c01bba59d0f13dc /src/api | |
parent | allow port env (diff) | |
parent | finish dms (diff) | |
download | xssbook-edbbdf72c78536c48357a86181bbf6897fc52074.tar.gz xssbook-edbbdf72c78536c48357a86181bbf6897fc52074.tar.bz2 xssbook-edbbdf72c78536c48357a86181bbf6897fc52074.zip |
Merge pull request 'dms are cool' (#1) from dev into main
Reviewed-on: https://g.tylerm.dev/tylerm/xssbook/pulls/1
Diffstat (limited to 'src/api')
-rw-r--r-- | src/api/chat.rs | 512 | ||||
-rw-r--r-- | src/api/mod.rs | 33 |
2 files changed, 518 insertions, 27 deletions
diff --git a/src/api/chat.rs b/src/api/chat.rs new file mode 100644 index 0000000..02bdfbd --- /dev/null +++ b/src/api/chat.rs @@ -0,0 +1,512 @@ +use std::collections::HashMap; + +use axum::{response::Response, Router, routing::{post, patch, delete, get}, extract::{ws::Message, WebSocketUpgrade}}; +use serde::Deserialize; +use tokio::sync::{Mutex, mpsc::{Sender, self}}; +use crate::{ + public::docs::{EndpointDocumentation, EndpointMethod}, + types::{ + extract::{AuthorizedUser, Check, CheckResult, Database, Json, Log}, + http::ResponseCode, + chat::{ChatRoom, ChatEvent}, user::User, + }, +}; +use std::collections::hash_map::Values; +use lazy_static::lazy_static; + +lazy_static!( + static ref CONNECTIONS: Mutex<HashMap<u64, ConnectionPool>> = Mutex::new(HashMap::new()); +); + +struct ConnectionPool { + inner: HashMap<usize, Sender<ChatEvent>>, + index: usize +} + +impl ConnectionPool { + fn new() -> Self { + Self { + inner: HashMap::new(), + index: 0 + } + } + + fn add(&mut self, send: Sender<ChatEvent>) -> usize { + let idx = self.index; + self.index += 1; + self.inner.insert(idx, send); + idx + } + + fn del(&mut self, idx: &usize) { + self.inner.remove(idx); + } + + fn values(&self) -> Values<'_, usize, Sender<ChatEvent>> { + self.inner.values() + } +} + +async fn send_event(event: ChatEvent, room: &ChatRoom) { + for user in &room.users { + let lock = CONNECTIONS.lock().await; + let Some(connection) = lock.get(&user) else { + continue + }; + for channel in connection.values() { + channel.send(event.clone()).await.ok(); + } + } +} + +pub const CHAT_LIST: EndpointDocumentation = EndpointDocumentation { + uri: "/api/chat/list", + method: EndpointMethod::Post, + description: "Returns the rooms you are in", + body: None, + responses: &[ + (201, "Returns rooms in a list"), + (400, "Body does not match parameters"), + (401, "Unauthorized"), + (500, "Failed to retrieve rooms"), + ], + cookie: Some("auth"), +}; + +async fn list ( + AuthorizedUser(user): AuthorizedUser, + Database(db): Database, + _: Log +) -> Response { + let Ok(rooms) = ChatRoom::from_user_id(&db, user.user_id) else { + return ResponseCode::InternalServerError.text("Failed to retrieve rooms") + }; + + let Ok(json) = serde_json::to_string(&rooms) else { + return ResponseCode::InternalServerError.text("Failed to retrieve rooms") + }; + + ResponseCode::Success.json(&json) +} + +pub const CHAT_CREATE: EndpointDocumentation = EndpointDocumentation { + uri: "/api/chat/create", + method: EndpointMethod::Post, + description: "Creates a new room", + body: Some( + r#" + { + "name" : "Funny memes" + } + "#, + ), + responses: &[ + (201, "Successfully created room"), + (400, "Body does not match parameters"), + (401, "Unauthorized"), + (500, "Failed to create room"), + ], + cookie: Some("auth"), +}; + +#[derive(Deserialize)] +struct RoomCreateRequest { + name: String, +} + +impl Check for RoomCreateRequest { + fn check(&self) -> CheckResult { + Self::assert_length( + &self.name, + 1, + 255, + "Room names must be between 1-255 characters long", + )?; + Ok(()) + } +} + +async fn create ( + AuthorizedUser(user): AuthorizedUser, + Database(db): Database, + Json(body): Json<RoomCreateRequest>, +) -> Response { + let Ok(room) = ChatRoom::new(&db, vec![user.user_id], body.name) else { + return ResponseCode::InternalServerError.text("Failed to create room") + }; + + for user in &room.users { + send_event(ChatEvent::Add { + user_id: *user, + room: room.clone() + }, &room).await; + } + + let Ok(json) = serde_json::to_string(&room) else { + return ResponseCode::InternalServerError.text("Failed to create room") + }; + + ResponseCode::Created.json(&json) +} + +pub const CHAT_ADD: EndpointDocumentation = EndpointDocumentation { + uri: "/api/chat/add", + method: EndpointMethod::Patch, + description: "Adds a user to a room", + body: Some( + r#" + { + "room_id": 69, + "email" : "joebide@house.gov" + } + "#, + ), + responses: &[ + (201, "Successfully added user"), + (400, "Body does not match parameters"), + (401, "Unauthorized"), + (500, "Failed to add user"), + ], + cookie: Some("auth"), +}; + +#[derive(Deserialize)] +struct AddUserRequest { + room_id: u64, + email: String, +} + +impl Check for AddUserRequest { + fn check(&self) -> CheckResult { + Ok(()) + } +} + +async fn add ( + AuthorizedUser(user): AuthorizedUser, + Database(db): Database, + Json(body): Json<AddUserRequest>, +) -> Response { + + let Ok(to_add) = User::from_email(&db, &body.email) else { + return ResponseCode::BadRequest.text("User does not exist") + }; + + let Ok(mut room) = ChatRoom::from_user_and_room_id(&db, user.user_id, body.room_id) else { + return ResponseCode::BadRequest.text("Room doesnt exist or you are not in it") + }; + + if room.users.contains(&to_add.user_id) { + return ResponseCode::BadRequest.text("User is already in the room") + } + + let Ok(success) = room.add_user(&db, to_add.user_id) else { + return ResponseCode::InternalServerError.text("Failed to add user") + }; + + if !success { + return ResponseCode::BadRequest.text("User is already in the room") + } + + room.users.push(to_add.user_id); + + send_event(ChatEvent::Add { + user_id: to_add.user_id, + room: room.clone() + }, &room).await; + + ResponseCode::Success.text("Successfully added user") +} + +pub const CHAT_LEAVE: EndpointDocumentation = EndpointDocumentation { + uri: "/api/chat/leave", + method: EndpointMethod::Delete, + description: "Leaves a room", + body: Some( + r#" + { + "room_id": 69 + } + "#, + ), + responses: &[ + (201, "Successfully left room"), + (400, "Body does not match parameters"), + (401, "Unauthorized"), + (500, "Failed to leave a room"), + ], + cookie: Some("auth"), +}; + +#[derive(Deserialize)] +struct LeaveRoomRequest { + room_id: u64, +} + +impl Check for LeaveRoomRequest { + fn check(&self) -> CheckResult { + Ok(()) + } +} + +async fn leave ( + AuthorizedUser(user): AuthorizedUser, + Database(db): Database, + Json(body): Json<LeaveRoomRequest>, +) -> Response { + + let Ok(room) = ChatRoom::from_user_and_room_id(&db, user.user_id, body.room_id) else { + return ResponseCode::BadRequest.text("Room doesnt exist or you are not in it") + }; + + let Ok(success) = room.remove_user(&db, user.user_id) else { + return ResponseCode::InternalServerError.text("Failed to leave room") + }; + + if !success { + return ResponseCode::BadRequest.text("You are currently not in this room (how did this happen?)") + } + + send_event(ChatEvent::Leave { + user_id: user.user_id, + room_id: room.room_id + }, &room).await; + + ResponseCode::Success.text("Successfully left room") +} + +pub const CHAT_SEND: EndpointDocumentation = EndpointDocumentation { + uri: "/api/chat/send", + method: EndpointMethod::Post, + description: "Send a message to a room", + body: Some( + r#" + { + "room_id": 420, + "content" : "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + } + "#, + ), + responses: &[ + (201, "Successfully sent message"), + (400, "Body does not match parameters"), + (401, "Unauthorized"), + (500, "Failed to send message"), + ], + cookie: Some("auth"), +}; + +#[derive(Deserialize)] +struct SendMessageRequest { + room_id: u64, + content: String +} + +impl Check for SendMessageRequest { + fn check(&self) -> CheckResult { + Self::assert_length( + &self.content, + 1, + 500, + "Messages must be between 1-500 length" + )?; + Ok(()) + } +} + +async fn send ( + AuthorizedUser(user): AuthorizedUser, + Database(db): Database, + Json(body): Json<SendMessageRequest>, +) -> Response { + + let Ok(room) = ChatRoom::from_user_and_room_id(&db, user.user_id, body.room_id) else { + return ResponseCode::BadRequest.text("Room doesnt exist or you are not in it") + }; + + let Ok(msg) = room.send_message(&db, user.user_id, body.content) else { + return ResponseCode::InternalServerError.text("Failed to send message") + }; + + send_event(ChatEvent::Message { + user_id: msg.user_id, + room_id: msg.room_id, + message_id: msg.message_id, + content: msg.content, + date: msg.date + }, &room).await; + + ResponseCode::Created.text("Successfully sent message") +} + +pub const CHAT_LOAD: EndpointDocumentation = EndpointDocumentation { + uri: "/api/chat/load", + method: EndpointMethod::Post, + description: "Get a page of historic room messages starting before given message id", + body: Some( + r#" + { + "room_id": 69, + "newest_msg": 400, + "page": 3 + } + "#, + ), + responses: &[ + (201, "Successfully sent message"), + (400, "Body does not match parameters"), + (401, "Unauthorized"), + (500, "Failed to send message"), + ], + cookie: Some("auth"), +}; + +#[derive(Deserialize)] +struct LoadMessagesRequest { + room_id: u64, + newest_msg: u64, + page: u64 +} + +impl Check for LoadMessagesRequest { + fn check(&self) -> CheckResult { + Ok(()) + } +} + +async fn load ( + AuthorizedUser(user): AuthorizedUser, + Database(db): Database, + Json(body): Json<LoadMessagesRequest>, +) -> Response { + + let Ok(room) = ChatRoom::from_user_and_room_id(&db, user.user_id, body.room_id) else { + return ResponseCode::BadRequest.text("Room doesnt exist or you are not in it") + }; + + let Ok(msgs) = room.load_old_chat_messages(&db, body.newest_msg, body.page) else { + return ResponseCode::InternalServerError.text("Failed to load messages") + }; + + let Ok(json) = serde_json::to_string(&msgs) else { + return ResponseCode::InternalServerError.text("Failed to load messages") + }; + + ResponseCode::Created.json(&json) +} + +pub const CHAT_TYPING: EndpointDocumentation = EndpointDocumentation { + uri: "/api/chat/typing", + method: EndpointMethod::Post, + description: "Set if your typing in a given room", + body: Some( + r#" + { + "room_id": 69, + } + "#, + ), + responses: &[ + (201, "Successfully sent typing indicator"), + (400, "Body does not match parameters"), + (401, "Unauthorized"), + (500, "Failed to send typing indicator"), + ], + cookie: Some("auth"), +}; + +#[derive(Deserialize)] +struct TypingRequest { + room_id: u64, +} + +impl Check for TypingRequest { + fn check(&self) -> CheckResult { + Ok(()) + } +} + +async fn typing ( + AuthorizedUser(user): AuthorizedUser, + Database(db): Database, + Json(body): Json<TypingRequest>, +) -> Response { + + let Ok(room) = ChatRoom::from_user_and_room_id(&db, user.user_id, body.room_id) else { + return ResponseCode::BadRequest.text("Room doesnt exist or you are not in it") + }; + + send_event(ChatEvent::Typing { + user_id: user.user_id, + room_id: room.room_id, + }, &room).await; + + ResponseCode::Success.text("Successfully sent typing indicator") +} + +pub const CHAT_CONNECT: EndpointDocumentation = EndpointDocumentation { + uri: "/api/chat/connect", + method: EndpointMethod::Get, + description: "Start a websocket connection for chat events", + body: None, + responses: &[], + cookie: Some("auth"), +}; + +async fn connect ( + AuthorizedUser(user): AuthorizedUser, + ws: WebSocketUpgrade +) -> Response { + ws.on_upgrade(|mut ws| async move { + let user = user; + let (send, mut recv) = mpsc::channel::<ChatEvent>(20); + let id: usize; + { + let mut lock = CONNECTIONS.lock().await; + match lock.get_mut(&user.user_id) { + Some(pool) => { + id = pool.add(send); + }, + None => { + let mut pool = ConnectionPool::new(); + id = pool.add(send); + lock.insert(user.user_id, pool); + } + }; + } + loop { + tokio::select! { + m = ws.recv() => { + let Some(Ok(_)) = m else { + break; + }; + } + s = recv.recv() => { + let Some(msg) = s else { + break; + }; + if let Ok(string) = serde_json::to_string(&msg) { + ws.send(Message::Text(string)).await.ok(); + } + } + } + } + + let mut lock = CONNECTIONS.lock().await; + if let Some(conn) = lock.get_mut(&user.user_id) { + conn.del(&id); + }; + }) +} + +pub fn router() -> Router { + Router::new() + .route("/create", post(create)) + .route("/list", post(list)) + .route("/add", patch(add)) + .route("/leave", delete(leave)) + .route("/send", post(send)) + .route("/load", post(load)) + .route("/typing", post(typing)) + .route("/connect", get(connect)) +} diff --git a/src/api/mod.rs b/src/api/mod.rs index 8b631c8..0c01ea0 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,33 +1,21 @@ use crate::types::extract::{RouterURI, self}; -use axum::{ - error_handling::HandleErrorLayer, - BoxError, Extension, Router, middleware, -}; -use tower::ServiceBuilder; -use tower_governor::{ - errors::display_error, governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor, - GovernorLayer, -}; +pub mod chat; pub mod admin; pub mod auth; pub mod posts; pub mod users; pub use auth::RegistrationRequet; +use axum::{Extension, Router, middleware}; pub fn router() -> Router { - let governor_conf = Box::new( - GovernorConfigBuilder::default() - .burst_size(15) - .per_second(1) - .key_extractor(SmartIpKeyExtractor) - .finish() - .expect("Failed to create rate limiter"), - ); - Router::new() .nest( + "/chat", + chat::router().layer(Extension(RouterURI("/api/chat"))), + ) + .nest( "/admin", admin::router().layer(Extension(RouterURI("/api/admin"))), ) @@ -43,14 +31,5 @@ pub fn router() -> Router { "/posts", posts::router().layer(Extension(RouterURI("/api/posts"))), ) - .layer( - ServiceBuilder::new() - .layer(HandleErrorLayer::new(|e: BoxError| async move { - display_error(e) - })) - .layer(GovernorLayer { - config: Box::leak(governor_conf), - }), - ) .layer(middleware::from_fn(extract::connect)) } |