diff --git a/crates/ezidam/src/guards.rs b/crates/ezidam/src/guards.rs index d59cbb2..f4cbe90 100644 --- a/crates/ezidam/src/guards.rs +++ b/crates/ezidam/src/guards.rs @@ -1,7 +1,9 @@ mod completed_setup; +mod jwt; mod need_setup; mod refresh_token; +pub use self::jwt::*; pub use completed_setup::CompletedSetup; pub use need_setup::NeedSetup; pub use refresh_token::RefreshToken; diff --git a/crates/ezidam/src/guards/jwt.rs b/crates/ezidam/src/guards/jwt.rs new file mode 100644 index 0000000..12fb4d8 --- /dev/null +++ b/crates/ezidam/src/guards/jwt.rs @@ -0,0 +1,126 @@ +use crate::database::Database; +use jwt::database::Key; +use jwt::{JwtClaims, PrivateKey}; +use rocket::http::Status; +use rocket::request::Outcome; +use rocket::tokio::task; +use rocket::Request; + +mod admin; +mod user; + +pub use admin::JwtAdmin; +use id::KeyID; +pub use user::JwtUser; + +#[derive(Debug)] +pub enum Error { + GetDatabase, + Keys(jwt::Error), + JwtParsing(jwt::Error), + NoSigningKey, + NonExistentKey(String), + RevokedKey(KeyID), + ImportKey(jwt::Error), + JwtValidation(jwt::Error), + BlockingTask(String), +} + +pub async fn get_jwt( + request: &Request<'_>, + get_admin: Option, +) -> Result, Outcome> { + // Get jwt + let jwt = match request + .cookies() + .get("access_token") + .map(|cookie| cookie.value()) + { + Some(jwt) => jwt, + None => { + return Err(Outcome::Forward(())); + } + }; + + // Get database + let db = match request.guard::<&Database>().await { + Outcome::Success(database) => database, + Outcome::Failure(e) => return Err(Outcome::Failure((e.0, Error::GetDatabase))), + Outcome::Forward(f) => return Err(Outcome::Forward(f)), + }; + + // Get keys + let keys = match Key::get_all(&**db, Some(false)).await { + Ok(keys) => keys, + Err(e) => { + return Err(Outcome::Failure(( + Status::InternalServerError, + Error::Keys(e), + ))) + } + }; + + let jwt = jwt.to_string(); + match task::spawn_blocking(move || -> Result, Error> { + // Parse jwt + let parsed_jwt = jwt::parse(&jwt).map_err(Error::JwtParsing)?; + + // Get key id + let jwk_id = parsed_jwt + .header() + .key_id + .as_deref() + .ok_or(Error::NoSigningKey)?; + + // Get key + let key = keys + .iter() + .find(|&key| key.key_id().as_ref() == jwk_id) + .ok_or_else(|| Error::NonExistentKey(jwk_id.into()))?; + + // If key has been revoked + if key.is_revoked() { + return Err(Error::RevokedKey(key.key_id().to_owned())); + } + + // Import private key + let private_key = + PrivateKey::from_der(key.private_der(), key.key_id()).map_err(Error::ImportKey)?; + + // Validate jwt and get claims + let jwt_claims = private_key + .validate_jwt_extract_claims(&parsed_jwt) + .map_err(Error::JwtValidation)?; + + // Is specific kind of user required? + match get_admin { + // Yes, need to get specific kind of user + Some(get_admin) => { + if jwt_claims.is_admin == get_admin { + Ok(Some(jwt_claims)) + } else { + Ok(None) + } + } + // No, any user is good + None => Ok(Some(jwt_claims)), + } + }) + .await + { + Ok(result) => match result { + Ok(claims) => { + // Return jwt claims + Ok(claims) + } + Err(e) => Err(Outcome::Failure((Status::InternalServerError, e))), + }, + Err(e) => { + // Failed to run blocking task + Err(Outcome::Failure(( + Status::InternalServerError, + Error::BlockingTask(e.to_string()), + ))) + } + } +} diff --git a/crates/ezidam/src/guards/jwt/admin.rs b/crates/ezidam/src/guards/jwt/admin.rs new file mode 100644 index 0000000..c82ae08 --- /dev/null +++ b/crates/ezidam/src/guards/jwt/admin.rs @@ -0,0 +1,20 @@ +use jwt::JwtClaims; +use rocket::request::{FromRequest, Outcome}; +use rocket::Request; + +pub struct JwtAdmin(pub JwtClaims); + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for JwtAdmin { + type Error = super::Error; + + async fn from_request(request: &'r Request<'_>) -> Outcome { + match super::get_jwt(request, Some(true)).await { + Ok(jwt_claims) => match jwt_claims { + Some(jwt_claims) => Outcome::Success(JwtAdmin(jwt_claims)), + None => Outcome::Forward(()), + }, + Err(e) => return e, + } + } +} diff --git a/crates/ezidam/src/guards/jwt/user.rs b/crates/ezidam/src/guards/jwt/user.rs new file mode 100644 index 0000000..e5b7281 --- /dev/null +++ b/crates/ezidam/src/guards/jwt/user.rs @@ -0,0 +1,20 @@ +use jwt::JwtClaims; +use rocket::request::{FromRequest, Outcome}; +use rocket::Request; + +pub struct JwtUser(pub JwtClaims); + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for JwtUser { + type Error = super::Error; + + async fn from_request(request: &'r Request<'_>) -> Outcome { + match super::get_jwt(request, None).await { + Ok(jwt_claims) => match jwt_claims { + Some(jwt_claims) => Outcome::Success(JwtUser(jwt_claims)), + None => Outcome::Forward(()), + }, + Err(e) => return e, + } + } +} diff --git a/crates/ezidam/src/guards/refresh_token.rs b/crates/ezidam/src/guards/refresh_token.rs index ecfadb9..b1153dd 100644 --- a/crates/ezidam/src/guards/refresh_token.rs +++ b/crates/ezidam/src/guards/refresh_token.rs @@ -5,7 +5,7 @@ pub struct RefreshToken(pub String); #[rocket::async_trait] impl<'r> FromRequest<'r> for RefreshToken { - type Error = (); + type Error = std::convert::Infallible; async fn from_request(request: &'r Request<'_>) -> Outcome { match request.cookies().get("refresh_token") { diff --git a/crates/ezidam/src/routes/oauth.rs b/crates/ezidam/src/routes/oauth.rs index 46361e3..d0398ca 100644 --- a/crates/ezidam/src/routes/oauth.rs +++ b/crates/ezidam/src/routes/oauth.rs @@ -2,8 +2,8 @@ use authorize::*; use redirect::*; use rocket::{routes, Route}; -mod authorize; -mod redirect; +pub mod authorize; +pub mod redirect; pub fn routes() -> Vec { routes![ diff --git a/crates/ezidam/src/routes/root.rs b/crates/ezidam/src/routes/root.rs index 95ceb6b..c79914e 100644 --- a/crates/ezidam/src/routes/root.rs +++ b/crates/ezidam/src/routes/root.rs @@ -5,7 +5,15 @@ use settings::Settings; use users::User; pub fn routes() -> Vec { - routes![logo, avatar, homepage, redirect_to_setup, logout] + routes![ + logo, + avatar, + homepage, + homepage_user, + homepage_redirect, + redirect_to_setup, + logout + ] } #[get("/logo")] @@ -44,6 +52,7 @@ mod test { #[get("/avatar/?")] async fn avatar( mut db: Connection, + _user: JwtUser, user_id: RocketUserID, size: Option, ) -> Result { @@ -87,12 +96,26 @@ async fn redirect_to_setup(_setup: NeedSetup) -> Redirect { } #[get("/", rank = 2)] -async fn homepage() -> Page { +async fn homepage(admin: JwtAdmin) -> Page { + println!("{:?}", admin.0); Page::Homepage(content::Homepage { - abc: "string".to_string(), + abc: "admin".to_string(), }) } +#[get("/", rank = 3)] +async fn homepage_user(user: JwtUser) -> Page { + println!("{:?}", user.0); + Page::Homepage(content::Homepage { + abc: "user".to_string(), + }) +} + +#[get("/", rank = 4)] +async fn homepage_redirect() -> Redirect { + Redirect::to(uri!(super::oauth::authorize::authorize_ezidam)) +} + #[post("/logout")] async fn logout( mut db: Connection, diff --git a/crates/jwt/src/error.rs b/crates/jwt/src/error.rs index 54a9746..cf34a18 100644 --- a/crates/jwt/src/error.rs +++ b/crates/jwt/src/error.rs @@ -23,4 +23,7 @@ pub enum Error { #[error("Failed to create JWT: `{0}`")] JwtCreation(#[from] jwt_compact::CreationError), + + #[error("Failed to validate JWT: `{0}`")] + JwtValidation(#[from] jwt_compact::ValidationError), } diff --git a/crates/jwt/src/key/private.rs b/crates/jwt/src/key/private.rs index b5bcf86..bd4ab08 100644 --- a/crates/jwt/src/key/private.rs +++ b/crates/jwt/src/key/private.rs @@ -1,7 +1,7 @@ use crate::{Error, JwtClaims}; use id::KeyID; use jwt_compact::alg::{Rsa, RsaPrivateKey, StrongKey}; -use jwt_compact::{AlgorithmExt, Claims, Header}; +use jwt_compact::{AlgorithmExt, Claims, Header, TimeOptions, Token, UntrustedToken}; use rsa::pkcs8::der::zeroize::Zeroizing; use rsa::pkcs8::{DecodePrivateKey, EncodePrivateKey}; @@ -40,6 +40,23 @@ impl PrivateKey { ) -> Result { Ok(Rsa::ps256().token(header, &claims, &self.key)?) } + + pub fn validate_jwt_extract_claims(&self, token: &UntrustedToken) -> Result { + // Verify signature + let token: Token = Rsa::ps256() + .validate_integrity(token, &self.key) + .map_err(Error::JwtValidation)?; + + // Validate additional conditions + let time_options = TimeOptions::default(); + token + .claims() + .validate_expiration(&time_options) + .map_err(Error::JwtValidation)?; + + // Return claims + Ok(token.claims().custom.clone()) + } } #[cfg(test)] diff --git a/crates/jwt/src/lib.rs b/crates/jwt/src/lib.rs index 4949aab..b5d628a 100644 --- a/crates/jwt/src/lib.rs +++ b/crates/jwt/src/lib.rs @@ -12,3 +12,4 @@ pub use claims::JwtClaims; pub use error::Error; pub use key::generate; pub use key::{PrivateKey, PublicKey}; +pub use token::parse;