summaryrefslogtreecommitdiff
path: root/src/api/chat.rs
diff options
context:
space:
mode:
authortylerm <tylerm@tylerm.dev>2023-08-22 04:16:31 +0000
committertylerm <tylerm@tylerm.dev>2023-08-22 04:16:31 +0000
commitedbbdf72c78536c48357a86181bbf6897fc52074 (patch)
tree91d91e9dfb77ae3b7d75f4348c01bba59d0f13dc /src/api/chat.rs
parentallow port env (diff)
parentfinish dms (diff)
downloadxssbook-edbbdf72c78536c48357a86181bbf6897fc52074.tar.gz
xssbook-edbbdf72c78536c48357a86181bbf6897fc52074.tar.bz2
xssbook-edbbdf72c78536c48357a86181bbf6897fc52074.zip
Merge pull request 'dms are cool' (#1) from dev into main
Reviewed-on: https://g.tylerm.dev/tylerm/xssbook/pulls/1
Diffstat (limited to 'src/api/chat.rs')
-rw-r--r--src/api/chat.rs512
1 files changed, 512 insertions, 0 deletions
diff --git a/src/api/chat.rs b/src/api/chat.rs
new file mode 100644
index 0000000..02bdfbd
--- /dev/null
+++ b/src/api/chat.rs
@@ -0,0 +1,512 @@
+use std::collections::HashMap;
+
+use axum::{response::Response, Router, routing::{post, patch, delete, get}, extract::{ws::Message, WebSocketUpgrade}};
+use serde::Deserialize;
+use tokio::sync::{Mutex, mpsc::{Sender, self}};
+use crate::{
+ public::docs::{EndpointDocumentation, EndpointMethod},
+ types::{
+ extract::{AuthorizedUser, Check, CheckResult, Database, Json, Log},
+ http::ResponseCode,
+ chat::{ChatRoom, ChatEvent}, user::User,
+ },
+};
+use std::collections::hash_map::Values;
+use lazy_static::lazy_static;
+
+lazy_static!(
+ static ref CONNECTIONS: Mutex<HashMap<u64, ConnectionPool>> = Mutex::new(HashMap::new());
+);
+
+struct ConnectionPool {
+ inner: HashMap<usize, Sender<ChatEvent>>,
+ index: usize
+}
+
+impl ConnectionPool {
+ fn new() -> Self {
+ Self {
+ inner: HashMap::new(),
+ index: 0
+ }
+ }
+
+ fn add(&mut self, send: Sender<ChatEvent>) -> usize {
+ let idx = self.index;
+ self.index += 1;
+ self.inner.insert(idx, send);
+ idx
+ }
+
+ fn del(&mut self, idx: &usize) {
+ self.inner.remove(idx);
+ }
+
+ fn values(&self) -> Values<'_, usize, Sender<ChatEvent>> {
+ self.inner.values()
+ }
+}
+
+async fn send_event(event: ChatEvent, room: &ChatRoom) {
+ for user in &room.users {
+ let lock = CONNECTIONS.lock().await;
+ let Some(connection) = lock.get(&user) else {
+ continue
+ };
+ for channel in connection.values() {
+ channel.send(event.clone()).await.ok();
+ }
+ }
+}
+
+pub const CHAT_LIST: EndpointDocumentation = EndpointDocumentation {
+ uri: "/api/chat/list",
+ method: EndpointMethod::Post,
+ description: "Returns the rooms you are in",
+ body: None,
+ responses: &[
+ (201, "Returns rooms in a list"),
+ (400, "Body does not match parameters"),
+ (401, "Unauthorized"),
+ (500, "Failed to retrieve rooms"),
+ ],
+ cookie: Some("auth"),
+};
+
+async fn list (
+ AuthorizedUser(user): AuthorizedUser,
+ Database(db): Database,
+ _: Log
+) -> Response {
+ let Ok(rooms) = ChatRoom::from_user_id(&db, user.user_id) else {
+ return ResponseCode::InternalServerError.text("Failed to retrieve rooms")
+ };
+
+ let Ok(json) = serde_json::to_string(&rooms) else {
+ return ResponseCode::InternalServerError.text("Failed to retrieve rooms")
+ };
+
+ ResponseCode::Success.json(&json)
+}
+
+pub const CHAT_CREATE: EndpointDocumentation = EndpointDocumentation {
+ uri: "/api/chat/create",
+ method: EndpointMethod::Post,
+ description: "Creates a new room",
+ body: Some(
+ r#"
+ {
+ "name" : "Funny memes"
+ }
+ "#,
+ ),
+ responses: &[
+ (201, "Successfully created room"),
+ (400, "Body does not match parameters"),
+ (401, "Unauthorized"),
+ (500, "Failed to create room"),
+ ],
+ cookie: Some("auth"),
+};
+
+#[derive(Deserialize)]
+struct RoomCreateRequest {
+ name: String,
+}
+
+impl Check for RoomCreateRequest {
+ fn check(&self) -> CheckResult {
+ Self::assert_length(
+ &self.name,
+ 1,
+ 255,
+ "Room names must be between 1-255 characters long",
+ )?;
+ Ok(())
+ }
+}
+
+async fn create (
+ AuthorizedUser(user): AuthorizedUser,
+ Database(db): Database,
+ Json(body): Json<RoomCreateRequest>,
+) -> Response {
+ let Ok(room) = ChatRoom::new(&db, vec![user.user_id], body.name) else {
+ return ResponseCode::InternalServerError.text("Failed to create room")
+ };
+
+ for user in &room.users {
+ send_event(ChatEvent::Add {
+ user_id: *user,
+ room: room.clone()
+ }, &room).await;
+ }
+
+ let Ok(json) = serde_json::to_string(&room) else {
+ return ResponseCode::InternalServerError.text("Failed to create room")
+ };
+
+ ResponseCode::Created.json(&json)
+}
+
+pub const CHAT_ADD: EndpointDocumentation = EndpointDocumentation {
+ uri: "/api/chat/add",
+ method: EndpointMethod::Patch,
+ description: "Adds a user to a room",
+ body: Some(
+ r#"
+ {
+ "room_id": 69,
+ "email" : "joebide@house.gov"
+ }
+ "#,
+ ),
+ responses: &[
+ (201, "Successfully added user"),
+ (400, "Body does not match parameters"),
+ (401, "Unauthorized"),
+ (500, "Failed to add user"),
+ ],
+ cookie: Some("auth"),
+};
+
+#[derive(Deserialize)]
+struct AddUserRequest {
+ room_id: u64,
+ email: String,
+}
+
+impl Check for AddUserRequest {
+ fn check(&self) -> CheckResult {
+ Ok(())
+ }
+}
+
+async fn add (
+ AuthorizedUser(user): AuthorizedUser,
+ Database(db): Database,
+ Json(body): Json<AddUserRequest>,
+) -> Response {
+
+ let Ok(to_add) = User::from_email(&db, &body.email) else {
+ return ResponseCode::BadRequest.text("User does not exist")
+ };
+
+ let Ok(mut room) = ChatRoom::from_user_and_room_id(&db, user.user_id, body.room_id) else {
+ return ResponseCode::BadRequest.text("Room doesnt exist or you are not in it")
+ };
+
+ if room.users.contains(&to_add.user_id) {
+ return ResponseCode::BadRequest.text("User is already in the room")
+ }
+
+ let Ok(success) = room.add_user(&db, to_add.user_id) else {
+ return ResponseCode::InternalServerError.text("Failed to add user")
+ };
+
+ if !success {
+ return ResponseCode::BadRequest.text("User is already in the room")
+ }
+
+ room.users.push(to_add.user_id);
+
+ send_event(ChatEvent::Add {
+ user_id: to_add.user_id,
+ room: room.clone()
+ }, &room).await;
+
+ ResponseCode::Success.text("Successfully added user")
+}
+
+pub const CHAT_LEAVE: EndpointDocumentation = EndpointDocumentation {
+ uri: "/api/chat/leave",
+ method: EndpointMethod::Delete,
+ description: "Leaves a room",
+ body: Some(
+ r#"
+ {
+ "room_id": 69
+ }
+ "#,
+ ),
+ responses: &[
+ (201, "Successfully left room"),
+ (400, "Body does not match parameters"),
+ (401, "Unauthorized"),
+ (500, "Failed to leave a room"),
+ ],
+ cookie: Some("auth"),
+};
+
+#[derive(Deserialize)]
+struct LeaveRoomRequest {
+ room_id: u64,
+}
+
+impl Check for LeaveRoomRequest {
+ fn check(&self) -> CheckResult {
+ Ok(())
+ }
+}
+
+async fn leave (
+ AuthorizedUser(user): AuthorizedUser,
+ Database(db): Database,
+ Json(body): Json<LeaveRoomRequest>,
+) -> Response {
+
+ let Ok(room) = ChatRoom::from_user_and_room_id(&db, user.user_id, body.room_id) else {
+ return ResponseCode::BadRequest.text("Room doesnt exist or you are not in it")
+ };
+
+ let Ok(success) = room.remove_user(&db, user.user_id) else {
+ return ResponseCode::InternalServerError.text("Failed to leave room")
+ };
+
+ if !success {
+ return ResponseCode::BadRequest.text("You are currently not in this room (how did this happen?)")
+ }
+
+ send_event(ChatEvent::Leave {
+ user_id: user.user_id,
+ room_id: room.room_id
+ }, &room).await;
+
+ ResponseCode::Success.text("Successfully left room")
+}
+
+pub const CHAT_SEND: EndpointDocumentation = EndpointDocumentation {
+ uri: "/api/chat/send",
+ method: EndpointMethod::Post,
+ description: "Send a message to a room",
+ body: Some(
+ r#"
+ {
+ "room_id": 420,
+ "content" : "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
+ }
+ "#,
+ ),
+ responses: &[
+ (201, "Successfully sent message"),
+ (400, "Body does not match parameters"),
+ (401, "Unauthorized"),
+ (500, "Failed to send message"),
+ ],
+ cookie: Some("auth"),
+};
+
+#[derive(Deserialize)]
+struct SendMessageRequest {
+ room_id: u64,
+ content: String
+}
+
+impl Check for SendMessageRequest {
+ fn check(&self) -> CheckResult {
+ Self::assert_length(
+ &self.content,
+ 1,
+ 500,
+ "Messages must be between 1-500 length"
+ )?;
+ Ok(())
+ }
+}
+
+async fn send (
+ AuthorizedUser(user): AuthorizedUser,
+ Database(db): Database,
+ Json(body): Json<SendMessageRequest>,
+) -> Response {
+
+ let Ok(room) = ChatRoom::from_user_and_room_id(&db, user.user_id, body.room_id) else {
+ return ResponseCode::BadRequest.text("Room doesnt exist or you are not in it")
+ };
+
+ let Ok(msg) = room.send_message(&db, user.user_id, body.content) else {
+ return ResponseCode::InternalServerError.text("Failed to send message")
+ };
+
+ send_event(ChatEvent::Message {
+ user_id: msg.user_id,
+ room_id: msg.room_id,
+ message_id: msg.message_id,
+ content: msg.content,
+ date: msg.date
+ }, &room).await;
+
+ ResponseCode::Created.text("Successfully sent message")
+}
+
+pub const CHAT_LOAD: EndpointDocumentation = EndpointDocumentation {
+ uri: "/api/chat/load",
+ method: EndpointMethod::Post,
+ description: "Get a page of historic room messages starting before given message id",
+ body: Some(
+ r#"
+ {
+ "room_id": 69,
+ "newest_msg": 400,
+ "page": 3
+ }
+ "#,
+ ),
+ responses: &[
+ (201, "Successfully sent message"),
+ (400, "Body does not match parameters"),
+ (401, "Unauthorized"),
+ (500, "Failed to send message"),
+ ],
+ cookie: Some("auth"),
+};
+
+#[derive(Deserialize)]
+struct LoadMessagesRequest {
+ room_id: u64,
+ newest_msg: u64,
+ page: u64
+}
+
+impl Check for LoadMessagesRequest {
+ fn check(&self) -> CheckResult {
+ Ok(())
+ }
+}
+
+async fn load (
+ AuthorizedUser(user): AuthorizedUser,
+ Database(db): Database,
+ Json(body): Json<LoadMessagesRequest>,
+) -> Response {
+
+ let Ok(room) = ChatRoom::from_user_and_room_id(&db, user.user_id, body.room_id) else {
+ return ResponseCode::BadRequest.text("Room doesnt exist or you are not in it")
+ };
+
+ let Ok(msgs) = room.load_old_chat_messages(&db, body.newest_msg, body.page) else {
+ return ResponseCode::InternalServerError.text("Failed to load messages")
+ };
+
+ let Ok(json) = serde_json::to_string(&msgs) else {
+ return ResponseCode::InternalServerError.text("Failed to load messages")
+ };
+
+ ResponseCode::Created.json(&json)
+}
+
+pub const CHAT_TYPING: EndpointDocumentation = EndpointDocumentation {
+ uri: "/api/chat/typing",
+ method: EndpointMethod::Post,
+ description: "Set if your typing in a given room",
+ body: Some(
+ r#"
+ {
+ "room_id": 69,
+ }
+ "#,
+ ),
+ responses: &[
+ (201, "Successfully sent typing indicator"),
+ (400, "Body does not match parameters"),
+ (401, "Unauthorized"),
+ (500, "Failed to send typing indicator"),
+ ],
+ cookie: Some("auth"),
+};
+
+#[derive(Deserialize)]
+struct TypingRequest {
+ room_id: u64,
+}
+
+impl Check for TypingRequest {
+ fn check(&self) -> CheckResult {
+ Ok(())
+ }
+}
+
+async fn typing (
+ AuthorizedUser(user): AuthorizedUser,
+ Database(db): Database,
+ Json(body): Json<TypingRequest>,
+) -> Response {
+
+ let Ok(room) = ChatRoom::from_user_and_room_id(&db, user.user_id, body.room_id) else {
+ return ResponseCode::BadRequest.text("Room doesnt exist or you are not in it")
+ };
+
+ send_event(ChatEvent::Typing {
+ user_id: user.user_id,
+ room_id: room.room_id,
+ }, &room).await;
+
+ ResponseCode::Success.text("Successfully sent typing indicator")
+}
+
+pub const CHAT_CONNECT: EndpointDocumentation = EndpointDocumentation {
+ uri: "/api/chat/connect",
+ method: EndpointMethod::Get,
+ description: "Start a websocket connection for chat events",
+ body: None,
+ responses: &[],
+ cookie: Some("auth"),
+};
+
+async fn connect (
+ AuthorizedUser(user): AuthorizedUser,
+ ws: WebSocketUpgrade
+) -> Response {
+ ws.on_upgrade(|mut ws| async move {
+ let user = user;
+ let (send, mut recv) = mpsc::channel::<ChatEvent>(20);
+ let id: usize;
+ {
+ let mut lock = CONNECTIONS.lock().await;
+ match lock.get_mut(&user.user_id) {
+ Some(pool) => {
+ id = pool.add(send);
+ },
+ None => {
+ let mut pool = ConnectionPool::new();
+ id = pool.add(send);
+ lock.insert(user.user_id, pool);
+ }
+ };
+ }
+ loop {
+ tokio::select! {
+ m = ws.recv() => {
+ let Some(Ok(_)) = m else {
+ break;
+ };
+ }
+ s = recv.recv() => {
+ let Some(msg) = s else {
+ break;
+ };
+ if let Ok(string) = serde_json::to_string(&msg) {
+ ws.send(Message::Text(string)).await.ok();
+ }
+ }
+ }
+ }
+
+ let mut lock = CONNECTIONS.lock().await;
+ if let Some(conn) = lock.get_mut(&user.user_id) {
+ conn.del(&id);
+ };
+ })
+}
+
+pub fn router() -> Router {
+ Router::new()
+ .route("/create", post(create))
+ .route("/list", post(list))
+ .route("/add", patch(add))
+ .route("/leave", delete(leave))
+ .route("/send", post(send))
+ .route("/load", post(load))
+ .route("/typing", post(typing))
+ .route("/connect", get(connect))
+}