From 1dad24badcbcc8968c3ada76dc5cedcb0a7164cd Mon Sep 17 00:00:00 2001 From: Timshel Date: Wed, 25 Jun 2025 12:30:48 +0200 Subject: [PATCH] Create a separate sso_client --- src/main.rs | 1 + src/sso.rs | 303 +++++++++------------------------------------- src/sso_client.rs | 264 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 323 insertions(+), 245 deletions(-) create mode 100644 src/sso_client.rs diff --git a/src/main.rs b/src/main.rs index f61e339b..8fbd3453 100644 --- a/src/main.rs +++ b/src/main.rs @@ -57,6 +57,7 @@ mod http_client; mod mail; mod ratelimit; mod sso; +mod sso_client; mod util; use crate::api::core::two_factor::duo_oidc::purge_duo_contexts; diff --git a/src/sso.rs b/src/sso.rs index bf998fe7..4f7ed86a 100644 --- a/src/sso.rs +++ b/src/sso.rs @@ -1,21 +1,11 @@ use chrono::Utc; use derive_more::{AsRef, Deref, Display, From}; use regex::Regex; -use std::borrow::Cow; use std::time::Duration; use url::Url; use mini_moka::sync::Cache; use once_cell::sync::Lazy; -use openidconnect::core::{ - CoreClient, CoreIdTokenVerifier, CoreProviderMetadata, CoreResponseType, CoreUserInfoClaims, -}; -use openidconnect::reqwest; -use openidconnect::{ - AccessToken, AuthDisplay, AuthPrompt, AuthenticationFlow, AuthorizationCode, AuthorizationRequest, ClientId, - ClientSecret, CsrfToken, EndpointNotSet, EndpointSet, Nonce, OAuth2TokenResponse, PkceCodeChallenge, - PkceCodeVerifier, RefreshToken, ResponseType, Scope, -}; use crate::{ api::ApiResult, @@ -25,6 +15,7 @@ use crate::{ models::{Device, SsoNonce, User}, DbConn, }, + sso_client::Client, CONFIG, }; @@ -33,30 +24,10 @@ pub static FAKE_IDENTIFIER: &str = "Vaultwarden"; static AC_CACHE: Lazy> = Lazy::new(|| Cache::builder().max_capacity(1000).time_to_live(Duration::from_secs(10 * 60)).build()); -static CLIENT_CACHE_KEY: Lazy = Lazy::new(|| "sso-client".to_string()); -static CLIENT_CACHE: Lazy> = Lazy::new(|| { - Cache::builder().max_capacity(1).time_to_live(Duration::from_secs(CONFIG.sso_client_cache_expiration())).build() -}); - static SSO_JWT_ISSUER: Lazy = Lazy::new(|| format!("{}|sso", CONFIG.domain_origin())); pub static NONCE_EXPIRATION: Lazy = Lazy::new(|| chrono::TimeDelta::try_minutes(10).unwrap()); -trait AuthorizationRequestExt<'a> { - fn add_extra_params>, V: Into>>(self, params: Vec<(N, V)>) -> Self; -} - -impl<'a, AD: AuthDisplay, P: AuthPrompt, RT: ResponseType> AuthorizationRequestExt<'a> - for AuthorizationRequest<'a, AD, P, RT> -{ - fn add_extra_params>, V: Into>>(mut self, params: Vec<(N, V)>) -> Self { - for (key, value) in params { - self = self.add_extra_param(key, value); - } - self - } -} - #[derive( Clone, Debug, @@ -180,91 +151,6 @@ fn decode_token_claims(token_name: &str, token: &str) -> ApiResult, -} - -impl Client { - // Call the OpenId discovery endpoint to retrieve configuration - async fn _get_client() -> ApiResult { - let client_id = ClientId::new(CONFIG.sso_client_id()); - let client_secret = ClientSecret::new(CONFIG.sso_client_secret()); - - let issuer_url = CONFIG.sso_issuer_url()?; - - let http_client = match reqwest::ClientBuilder::new().redirect(reqwest::redirect::Policy::none()).build() { - Err(err) => err!(format!("Failed to build http client: {err}")), - Ok(client) => client, - }; - - let provider_metadata = match CoreProviderMetadata::discover_async(issuer_url, &http_client).await { - Err(err) => err!(format!("Failed to discover OpenID provider: {err}")), - Ok(metadata) => metadata, - }; - - let base_client = CoreClient::from_provider_metadata(provider_metadata, client_id, Some(client_secret)); - - let token_uri = match base_client.token_uri() { - Some(uri) => uri.clone(), - None => err!("Failed to discover token_url, cannot proceed"), - }; - - let user_info_url = match base_client.user_info_url() { - Some(url) => url.clone(), - None => err!("Failed to discover user_info url, cannot proceed"), - }; - - let core_client = base_client - .set_redirect_uri(CONFIG.sso_redirect_url()?) - .set_token_uri(token_uri) - .set_user_info_url(user_info_url); - - Ok(Client { - http_client, - core_client, - }) - } - - // Simple cache to prevent recalling the discovery endpoint each time - async fn cached() -> ApiResult { - if CONFIG.sso_client_cache_expiration() > 0 { - match CLIENT_CACHE.get(&*CLIENT_CACHE_KEY) { - Some(client) => Ok(client), - None => Self::_get_client().await.inspect(|client| { - debug!("Inserting new client in cache"); - CLIENT_CACHE.insert(CLIENT_CACHE_KEY.clone(), client.clone()); - }), - } - } else { - Self::_get_client().await - } - } - - async fn user_info(&self, access_token: AccessToken) -> ApiResult { - match self.core_client.user_info(access_token, None).request_async(&self.http_client).await { - Err(err) => err!(format!("Request to user_info endpoint failed: {err}")), - Ok(user_info) => Ok(user_info), - } - } - - fn vw_id_token_verifier(&self) -> CoreIdTokenVerifier<'_> { - let mut verifier = self.core_client.id_token_verifier(); - if let Some(regex_str) = CONFIG.sso_audience_trusted() { - match Regex::new(®ex_str) { - Ok(regex) => { - verifier = verifier.set_other_audience_verifier_fn(move |aud| regex.is_match(aud)); - } - Err(err) => { - error!("Failed to parse SSO_AUDIENCE_TRUSTED={regex_str} regex: {err}"); - } - } - } - verifier - } -} - pub fn deocde_state(base64_state: String) -> ApiResult { let state = match data_encoding::BASE64.decode(base64_state.as_bytes()) { Ok(vec) => match String::from_utf8(vec) { @@ -278,7 +164,6 @@ pub fn deocde_state(base64_state: String) -> ApiResult { } // The `nonce` allow to protect against replay attacks -// The `state` is encoded using base64 to ensure no issue with providers (It contains the Organization identifier). // redirect_uri from: https://github.com/bitwarden/server/blob/main/src/Identity/IdentityServer/ApiClient.cs pub async fn authorize_url( state: OIDCState, @@ -286,9 +171,6 @@ pub async fn authorize_url( raw_redirect_uri: &str, mut conn: DbConn, ) -> ApiResult { - let scopes = CONFIG.sso_scopes_vec().into_iter().map(Scope::new); - let base64_state = data_encoding::BASE64.encode(state.to_string().as_bytes()); - let redirect_uri = match client_id { "web" | "browser" => format!("{}/sso-connector.html", CONFIG.domain()), "desktop" | "mobile" => "bitwarden://sso-callback".to_string(), @@ -302,30 +184,8 @@ pub async fn authorize_url( _ => err!(format!("Unsupported client {client_id}")), }; - let client = Client::cached().await?; - let mut auth_req = client - .core_client - .authorize_url( - AuthenticationFlow::::AuthorizationCode, - || CsrfToken::new(base64_state), - Nonce::new_random, - ) - .add_scopes(scopes) - .add_extra_params(CONFIG.sso_authorize_extra_params_vec()?); - - let verifier = if CONFIG.sso_pkce() { - let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); - auth_req = auth_req.set_pkce_challenge(pkce_challenge); - Some(pkce_verifier.into_secret()) - } else { - None - }; - - let (auth_url, _, nonce) = auth_req.url(); - - let sso_nonce = SsoNonce::new(state, nonce.secret().clone(), verifier, redirect_uri); - sso_nonce.save(&mut conn).await?; - + let (auth_url, nonce) = Client::authorize_url(state, redirect_uri).await?; + nonce.save(&mut conn).await?; Ok(auth_url) } @@ -406,6 +266,8 @@ async fn decode_code_claims(code: &str, conn: &mut DbConn) -> ApiResult<(OIDCCod // The `nonce` will ensure that the user is authorized only once. // We return only the `UserInformation` to force calling `redeem` to obtain the `refresh_token`. pub async fn exchange_code(wrapped_code: &str, conn: &mut DbConn) -> ApiResult { + use openidconnect::OAuth2TokenResponse; + let (code, state) = decode_code_claims(wrapped_code, conn).await?; if let Some(authenticated_user) = AC_CACHE.get(&state) { @@ -418,87 +280,53 @@ pub async fn exchange_code(wrapped_code: &str, conn: &mut DbConn) -> ApiResult err!(format!("Invalid state cannot retrieve nonce")), Some(nonce) => nonce, }; - let mut exchange = client.core_client.exchange_code(oidc_code); + let client = Client::cached().await?; + let (token_response, id_claims) = client.exchange_code(code, nonce).await?; - if CONFIG.sso_pkce() { - match nonce.verifier { - None => err!(format!("Missing verifier in the DB nonce table")), - Some(secret) => exchange = exchange.set_pkce_verifier(PkceCodeVerifier::new(secret)), - } + let user_info = client.user_info(token_response.access_token().to_owned()).await?; + + let email = match id_claims.email().or(user_info.email()) { + None => err!("Neither id token nor userinfo contained an email"), + Some(e) => e.to_string().to_lowercase(), + }; + + let email_verified = id_claims.email_verified().or(user_info.email_verified()); + + let user_name = id_claims.preferred_username().map(|un| un.to_string()); + + let refresh_token = token_response.refresh_token().map(|t| t.secret()); + if refresh_token.is_none() && CONFIG.sso_scopes_vec().contains(&"offline_access".to_string()) { + error!("Scope offline_access is present but response contain no refresh_token"); } - match exchange.request_async(&client.http_client).await { - Ok(token_response) => { - let user_info = client.user_info(token_response.access_token().to_owned()).await?; - let oidc_nonce = Nonce::new(nonce.nonce.clone()); + let identifier = OIDCIdentifier::new(id_claims.issuer(), id_claims.subject()); - let id_token = match token_response.extra_fields().id_token() { - None => err!("Token response did not contain an id_token"), - Some(token) => token, - }; + let authenticated_user = AuthenticatedUser { + refresh_token: refresh_token.cloned(), + access_token: token_response.access_token().secret().clone(), + expires_in: token_response.expires_in(), + identifier: identifier.clone(), + email: email.clone(), + email_verified, + user_name: user_name.clone(), + }; - if CONFIG.sso_debug_tokens() { - debug!("Id token: {}", id_token.to_string()); - debug!("Access token: {}", token_response.access_token().secret()); - debug!("Refresh token: {:?}", token_response.refresh_token().map(|t| t.secret())); - debug!("Expiration time: {:?}", token_response.expires_in()); - } + debug!("Authentified user {authenticated_user:?}"); - let id_claims = match id_token.claims(&client.vw_id_token_verifier(), &oidc_nonce) { - Ok(claims) => claims, - Err(err) => { - if CONFIG.sso_client_cache_expiration() > 0 { - CLIENT_CACHE.invalidate(&*CLIENT_CACHE_KEY); - } - err!(format!("Could not read id_token claims, {err}")); - } - }; + AC_CACHE.insert(state.clone(), authenticated_user); - let email = match id_claims.email().or(user_info.email()) { - None => err!("Neither id token nor userinfo contained an email"), - Some(e) => e.to_string().to_lowercase(), - }; - let email_verified = id_claims.email_verified().or(user_info.email_verified()); - - let user_name = user_info.preferred_username().map(|un| un.to_string()); - - let refresh_token = token_response.refresh_token().map(|t| t.secret()); - if refresh_token.is_none() && CONFIG.sso_scopes_vec().contains(&"offline_access".to_string()) { - error!("Scope offline_access is present but response contain no refresh_token"); - } - - let identifier = OIDCIdentifier::new(id_claims.issuer(), id_claims.subject()); - - let authenticated_user = AuthenticatedUser { - refresh_token: refresh_token.cloned(), - access_token: token_response.access_token().secret().clone(), - expires_in: token_response.expires_in(), - identifier: identifier.clone(), - email: email.clone(), - email_verified, - user_name: user_name.clone(), - }; - - AC_CACHE.insert(state.clone(), authenticated_user.clone()); - - Ok(UserInformation { - state, - identifier, - email, - email_verified, - user_name, - }) - } - Err(err) => err!(format!("Failed to contact token endpoint: {err}")), - } + Ok(UserInformation { + state, + identifier, + email, + email_verified, + user_name, + }) } // User has passed 2FA flow we can delete `nonce` and clear the cache. @@ -594,27 +422,17 @@ pub async fn exchange_refresh_token( let exp = refresh_claims.exp; match refresh_claims.token { Some(TokenWrapper::Refresh(refresh_token)) => { - let rt = RefreshToken::new(refresh_token); - - let client = Client::cached().await?; - - let token_response = - match client.core_client.exchange_refresh_token(&rt).request_async(&client.http_client).await { - Err(err) => err!(format!("Request to exchange_refresh_token endpoint failed: {:?}", err)), - Ok(token_response) => token_response, - }; - // Use new refresh_token if returned - let rolled_refresh_token = - token_response.refresh_token().map(|token| token.secret()).unwrap_or(rt.secret()); + let (new_refresh_token, access_token, expires_in) = + Client::exchange_refresh_token(refresh_token.clone()).await?; create_auth_tokens( device, user, client_id, - Some(rolled_refresh_token.clone()), - token_response.access_token().secret().clone(), - token_response.expires_in(), + new_refresh_token.or(Some(refresh_token)), + access_token, + expires_in, ) } Some(TokenWrapper::Access(access_token)) => { @@ -625,24 +443,19 @@ pub async fn exchange_refresh_token( err_silent!("Access token is close to expiration but we have no refresh token") } - let client = Client::cached().await?; - match client.user_info(AccessToken::new(access_token.clone())).await { - Err(err) => { - err_silent!(format!("Failed to retrieve user info, token has probably been invalidated: {err}")) - } - Ok(_) => { - let access_claims = auth::LoginJwtClaims::new( - device, - user, - now.timestamp(), - exp, - AuthMethod::Sso.scope_vec(), - client_id, - now, - ); - _create_auth_tokens(device, None, access_claims, access_token) - } - } + Client::check_validaty(access_token.clone()).await?; + + let access_claims = auth::LoginJwtClaims::new( + device, + user, + now.timestamp(), + exp, + AuthMethod::Sso.scope_vec(), + client_id, + now, + ); + + _create_auth_tokens(device, None, access_claims, access_token) } None => err!("No token present while in SSO"), } diff --git a/src/sso_client.rs b/src/sso_client.rs new file mode 100644 index 00000000..f3aa667c --- /dev/null +++ b/src/sso_client.rs @@ -0,0 +1,264 @@ +use regex::Regex; +use std::borrow::Cow; +use std::time::Duration; +use url::Url; + +use mini_moka::sync::Cache; +use once_cell::sync::Lazy; +use openidconnect::core::*; +use openidconnect::reqwest; +use openidconnect::*; + +use crate::{ + api::{ApiResult, EmptyResult}, + db::models::SsoNonce, + sso::{OIDCCode, OIDCState}, + CONFIG, +}; + +static CLIENT_CACHE_KEY: Lazy = Lazy::new(|| "sso-client".to_string()); +static CLIENT_CACHE: Lazy> = Lazy::new(|| { + Cache::builder().max_capacity(1).time_to_live(Duration::from_secs(CONFIG.sso_client_cache_expiration())).build() +}); + +/// OpenID Connect Core client. +pub type CustomClient = openidconnect::Client< + EmptyAdditionalClaims, + CoreAuthDisplay, + CoreGenderClaim, + CoreJweContentEncryptionAlgorithm, + CoreJsonWebKey, + CoreAuthPrompt, + StandardErrorResponse, + CoreTokenResponse, + CoreTokenIntrospectionResponse, + CoreRevocableToken, + CoreRevocationErrorResponse, + EndpointSet, + EndpointNotSet, + EndpointNotSet, + EndpointNotSet, + EndpointSet, + EndpointSet, +>; + +#[derive(Clone)] +pub struct Client { + pub http_client: reqwest::Client, + pub core_client: CustomClient, +} + +impl Client { + // Call the OpenId discovery endpoint to retrieve configuration + async fn _get_client() -> ApiResult { + let client_id = ClientId::new(CONFIG.sso_client_id()); + let client_secret = ClientSecret::new(CONFIG.sso_client_secret()); + + let issuer_url = CONFIG.sso_issuer_url()?; + + let http_client = match reqwest::ClientBuilder::new().redirect(reqwest::redirect::Policy::none()).build() { + Err(err) => err!(format!("Failed to build http client: {err}")), + Ok(client) => client, + }; + + let provider_metadata = match CoreProviderMetadata::discover_async(issuer_url, &http_client).await { + Err(err) => err!(format!("Failed to discover OpenID provider: {err}")), + Ok(metadata) => metadata, + }; + + let base_client = CoreClient::from_provider_metadata(provider_metadata, client_id, Some(client_secret)); + + let token_uri = match base_client.token_uri() { + Some(uri) => uri.clone(), + None => err!("Failed to discover token_url, cannot proceed"), + }; + + let user_info_url = match base_client.user_info_url() { + Some(url) => url.clone(), + None => err!("Failed to discover user_info url, cannot proceed"), + }; + + let core_client = base_client + .set_redirect_uri(CONFIG.sso_redirect_url()?) + .set_token_uri(token_uri) + .set_user_info_url(user_info_url); + + Ok(Client { + http_client, + core_client, + }) + } + + // Simple cache to prevent recalling the discovery endpoint each time + pub async fn cached() -> ApiResult { + if CONFIG.sso_client_cache_expiration() > 0 { + match CLIENT_CACHE.get(&*CLIENT_CACHE_KEY) { + Some(client) => Ok(client), + None => Self::_get_client().await.inspect(|client| { + debug!("Inserting new client in cache"); + CLIENT_CACHE.insert(CLIENT_CACHE_KEY.clone(), client.clone()); + }), + } + } else { + Self::_get_client().await + } + } + + pub fn invalidate() { + if CONFIG.sso_client_cache_expiration() > 0 { + CLIENT_CACHE.invalidate(&*CLIENT_CACHE_KEY); + } + } + + // The `state` is encoded using base64 to ensure no issue with providers (It contains the Organization identifier). + pub async fn authorize_url(state: OIDCState, redirect_uri: String) -> ApiResult<(Url, SsoNonce)> { + let scopes = CONFIG.sso_scopes_vec().into_iter().map(Scope::new); + let base64_state = data_encoding::BASE64.encode(state.to_string().as_bytes()); + + let client = Self::cached().await?; + let mut auth_req = client + .core_client + .authorize_url( + AuthenticationFlow::::AuthorizationCode, + || CsrfToken::new(base64_state), + Nonce::new_random, + ) + .add_scopes(scopes) + .add_extra_params(CONFIG.sso_authorize_extra_params_vec()?); + + let verifier = if CONFIG.sso_pkce() { + let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); + auth_req = auth_req.set_pkce_challenge(pkce_challenge); + Some(pkce_verifier.into_secret()) + } else { + None + }; + + let (auth_url, _, nonce) = auth_req.url(); + Ok((auth_url, SsoNonce::new(state, nonce.secret().clone(), verifier, redirect_uri))) + } + + pub async fn exchange_code( + &self, + code: OIDCCode, + nonce: SsoNonce, + ) -> ApiResult<( + StandardTokenResponse< + IdTokenFields< + EmptyAdditionalClaims, + EmptyExtraTokenFields, + CoreGenderClaim, + CoreJweContentEncryptionAlgorithm, + CoreJwsSigningAlgorithm, + >, + CoreTokenType, + >, + IdTokenClaims, + )> { + let oidc_code = AuthorizationCode::new(code.to_string()); + + let mut exchange = self.core_client.exchange_code(oidc_code); + + if CONFIG.sso_pkce() { + match nonce.verifier { + None => err!(format!("Missing verifier in the DB nonce table")), + Some(secret) => exchange = exchange.set_pkce_verifier(PkceCodeVerifier::new(secret.clone())), + } + } + + match exchange.request_async(&self.http_client).await { + Err(err) => err!(format!("Failed to contact token endpoint: {:?}", err)), + Ok(token_response) => { + let oidc_nonce = Nonce::new(nonce.nonce); + + let id_token = match token_response.extra_fields().id_token() { + None => err!("Token response did not contain an id_token"), + Some(token) => token, + }; + + if CONFIG.sso_debug_tokens() { + debug!("Id token: {}", id_token.to_string()); + debug!("Access token: {}", token_response.access_token().secret()); + debug!("Refresh token: {:?}", token_response.refresh_token().map(|t| t.secret())); + debug!("Expiration time: {:?}", token_response.expires_in()); + } + + let id_claims = match id_token.claims(&self.vw_id_token_verifier(), &oidc_nonce) { + Ok(claims) => claims.clone(), + Err(err) => { + Self::invalidate(); + err!(format!("Could not read id_token claims, {err}")); + } + }; + + Ok((token_response, id_claims)) + } + } + } + + pub async fn user_info(&self, access_token: AccessToken) -> ApiResult { + match self.core_client.user_info(access_token, None).request_async(&self.http_client).await { + Err(err) => err!(format!("Request to user_info endpoint failed: {err}")), + Ok(user_info) => Ok(user_info), + } + } + + pub async fn check_validaty(access_token: String) -> EmptyResult { + let client = Client::cached().await?; + match client.user_info(AccessToken::new(access_token)).await { + Err(err) => { + err_silent!(format!("Failed to retrieve user info, token has probably been invalidated: {err}")) + } + Ok(_) => Ok(()), + } + } + + pub fn vw_id_token_verifier(&self) -> CoreIdTokenVerifier<'_> { + let mut verifier = self.core_client.id_token_verifier(); + if let Some(regex_str) = CONFIG.sso_audience_trusted() { + match Regex::new(®ex_str) { + Ok(regex) => { + verifier = verifier.set_other_audience_verifier_fn(move |aud| regex.is_match(aud)); + } + Err(err) => { + error!("Failed to parse SSO_AUDIENCE_TRUSTED={regex_str} regex: {err}"); + } + } + } + verifier + } + + pub async fn exchange_refresh_token( + refresh_token: String, + ) -> ApiResult<(Option, String, Option)> { + let rt = RefreshToken::new(refresh_token); + + let client = Client::cached().await?; + let token_response = + match client.core_client.exchange_refresh_token(&rt).request_async(&client.http_client).await { + Err(err) => err!(format!("Request to exchange_refresh_token endpoint failed: {:?}", err)), + Ok(token_response) => token_response, + }; + + Ok(( + token_response.refresh_token().map(|token| token.secret().clone()), + token_response.access_token().secret().clone(), + token_response.expires_in(), + )) + } +} + +trait AuthorizationRequestExt<'a> { + fn add_extra_params>, V: Into>>(self, params: Vec<(N, V)>) -> Self; +} + +impl<'a, AD: AuthDisplay, P: AuthPrompt, RT: ResponseType> AuthorizationRequestExt<'a> + for AuthorizationRequest<'a, AD, P, RT> +{ + fn add_extra_params>, V: Into>>(mut self, params: Vec<(N, V)>) -> Self { + for (key, value) in params { + self = self.add_extra_param(key, value); + } + self + } +}