diff --git a/Cargo.lock b/Cargo.lock index 860b192..aae1628 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -722,6 +722,7 @@ version = "0.1.0" dependencies = [ "apps", "authorization_codes", + "base64 0.21.0", "database_pool", "erased-serde", "futures", @@ -733,6 +734,7 @@ dependencies = [ "openid", "refresh_tokens", "rocket", + "rocket_cors", "rocket_db_pools", "rocket_dyn_templates", "settings", @@ -2379,6 +2381,23 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "rocket_cors" +version = "0.6.0-alpha2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b12771b47f52e34d5d0e0e444aeba382863e73263cb9e18847e7d5b74aa2cbd0" +dependencies = [ + "http", + "log", + "regex", + "rocket", + "serde", + "serde_derive", + "unicase", + "unicase_serde", + "url", +] + [[package]] name = "rocket_db_pools" version = "0.1.0-rc.3" @@ -3362,6 +3381,25 @@ dependencies = [ "unic-common", ] +[[package]] +name = "unicase" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50f37be617794602aabbeee0be4f259dc1778fabe05e2d67ee8f79326d5cb4f6" +dependencies = [ + "version_check", +] + +[[package]] +name = "unicase_serde" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ef53697679d874d69f3160af80bc28de12730a985d57bdf2b47456ccb8b11f1" +dependencies = [ + "serde", + "unicase", +] + [[package]] name = "unicode-bidi" version = "0.3.13" diff --git a/crates/ezidam/Cargo.toml b/crates/ezidam/Cargo.toml index d738b38..d849858 100644 --- a/crates/ezidam/Cargo.toml +++ b/crates/ezidam/Cargo.toml @@ -12,6 +12,8 @@ erased-serde = "0.3" url = { workspace = true } identicon-rs = "4.0" futures = "0.3" +base64 = "0.21.0" +rocket_cors = "0.6.0-alpha2" # local crates database_pool = { path = "../database_pool" } diff --git a/crates/ezidam/src/cors.rs b/crates/ezidam/src/cors.rs new file mode 100644 index 0000000..2ebb5d5 --- /dev/null +++ b/crates/ezidam/src/cors.rs @@ -0,0 +1,15 @@ +use rocket::{Build, Rocket}; +use rocket_cors::{Cors, CorsOptions}; + +fn cors() -> Cors { + CorsOptions { + allow_credentials: true, + ..Default::default() + } + .to_cors() + .expect("Failed to configure CORS") +} + +pub fn rocket(rocket_builder: Rocket) -> Rocket { + rocket_builder.attach(cors()) +} diff --git a/crates/ezidam/src/guards.rs b/crates/ezidam/src/guards.rs index f4cbe90..e81a988 100644 --- a/crates/ezidam/src/guards.rs +++ b/crates/ezidam/src/guards.rs @@ -1,3 +1,4 @@ +mod basic_auth; mod completed_setup; mod jwt; mod need_setup; @@ -7,3 +8,4 @@ pub use self::jwt::*; pub use completed_setup::CompletedSetup; pub use need_setup::NeedSetup; pub use refresh_token::RefreshToken; +pub use basic_auth::BasicAuth; diff --git a/crates/ezidam/src/guards/basic_auth.rs b/crates/ezidam/src/guards/basic_auth.rs new file mode 100644 index 0000000..a8a4d2b --- /dev/null +++ b/crates/ezidam/src/guards/basic_auth.rs @@ -0,0 +1,67 @@ +use base64::Engine; +use rocket::http::Status; +use rocket::request::{FromRequest, Outcome}; +use rocket::Request; + +#[derive(Debug)] +pub struct BasicAuth { + pub id: String, + pub password: String, +} + +#[derive(Debug)] +pub enum BasicAuthError { + BadCount, + Invalid, + Empty, + Base64Decode, + InvalidUtf8, + Format, +} + +impl BasicAuth { + fn from_base64(raw: &str) -> Result { + // Make sure format is `Basic base64_string` + if !raw.starts_with("Basic ") { + return Err(BasicAuthError::Invalid); + } + + // Extract base64 encoded string + let (_, base64) = raw.split_once(' ').ok_or(BasicAuthError::Empty)?; + + // Decode base64 to bytes + let decoded_bytes = base64::engine::general_purpose::URL_SAFE + .decode(base64) + .map_err(|_| BasicAuthError::Base64Decode)?; + + // Convert bytes to string slice + let decoded_str = + std::str::from_utf8(&decoded_bytes).map_err(|_| BasicAuthError::InvalidUtf8)?; + + // Extract id and password + let (id, password) = decoded_str.split_once(':').ok_or(BasicAuthError::Format)?; + + Ok(Self { + id: id.to_string(), + password: password.to_string(), + }) + } +} + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for BasicAuth { + type Error = BasicAuthError; + + async fn from_request(request: &'r Request<'_>) -> Outcome { + let keys: Vec<_> = request.headers().get("Authorization").collect(); + + match keys.len() { + 0 => Outcome::Forward(()), + 1 => match BasicAuth::from_base64(keys[0]) { + Ok(auth_header) => Outcome::Success(auth_header), + Err(e) => Outcome::Failure((Status::BadRequest, e)), + }, + _ => Outcome::Failure((Status::BadRequest, BasicAuthError::BadCount)), + } + } +} diff --git a/crates/ezidam/src/lib.rs b/crates/ezidam/src/lib.rs index 3ad6e20..8faffdb 100644 --- a/crates/ezidam/src/lib.rs +++ b/crates/ezidam/src/lib.rs @@ -1,6 +1,7 @@ use rocket::{Build, Rocket}; mod cache; +mod cors; mod database; mod error; mod file_from_bytes; @@ -48,6 +49,9 @@ pub fn rocket_setup(rocket_builder: Rocket) -> Rocket { // Routes let rocket_builder = routes::routes(rocket_builder); + // CORS + let rocket_builder = cors::rocket(rocket_builder); + // Errors let rocket_builder = catchers::register(rocket_builder);