Revove user sessions when OAuth tokens are revoked (#1354)

OAuth tokens are revoked when the user revoke access to the OAuth app from the provider settings.
Some OAuth providers also revoke OAuth tokens when the user changes password.
This commit is contained in:
Solomon
2025-11-28 12:48:56 +00:00
committed by Gabriel Gordon-Hall
parent 1c380c7085
commit e4e129a4e7
18 changed files with 841 additions and 59 deletions

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "\n UPDATE oauth_handoffs\n SET\n status = 'authorized',\n error_code = NULL,\n user_id = $2,\n session_id = $3,\n app_code_hash = $4,\n authorized_at = NOW()\n WHERE id = $1\n ",
"query": "\n UPDATE oauth_handoffs\n SET\n status = 'authorized',\n error_code = NULL,\n user_id = $2,\n session_id = $3,\n app_code_hash = $4,\n encrypted_provider_tokens = $5,\n authorized_at = NOW()\n WHERE id = $1\n ",
"describe": {
"columns": [],
"parameters": {
@@ -8,10 +8,11 @@
"Uuid",
"Uuid",
"Uuid",
"Text",
"Text"
]
},
"nullable": []
},
"hash": "128bb938e490a07d9b567f483f1e8f1b004a267c32cfe14bc88c752f61fcc083"
"hash": "11eede7c3a324ffa6266ee5c3fe3fdb2bd3b9e894fcabeece1e8d2201d18dcc6"
}

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "\n INSERT INTO oauth_handoffs (\n provider,\n state,\n return_to,\n app_challenge,\n expires_at\n )\n VALUES ($1, $2, $3, $4, $5)\n RETURNING\n id AS \"id!\",\n provider AS \"provider!\",\n state AS \"state!\",\n return_to AS \"return_to!\",\n app_challenge AS \"app_challenge!\",\n app_code_hash AS \"app_code_hash?\",\n status AS \"status!\",\n error_code AS \"error_code?\",\n expires_at AS \"expires_at!\",\n authorized_at AS \"authorized_at?\",\n redeemed_at AS \"redeemed_at?\",\n user_id AS \"user_id?\",\n session_id AS \"session_id?\",\n created_at AS \"created_at!\",\n updated_at AS \"updated_at!\"\n ",
"query": "\n INSERT INTO oauth_handoffs (\n provider,\n state,\n return_to,\n app_challenge,\n expires_at\n )\n VALUES ($1, $2, $3, $4, $5)\n RETURNING\n id AS \"id!\",\n provider AS \"provider!\",\n state AS \"state!\",\n return_to AS \"return_to!\",\n app_challenge AS \"app_challenge!\",\n app_code_hash AS \"app_code_hash?\",\n status AS \"status!\",\n error_code AS \"error_code?\",\n expires_at AS \"expires_at!\",\n authorized_at AS \"authorized_at?\",\n redeemed_at AS \"redeemed_at?\",\n user_id AS \"user_id?\",\n session_id AS \"session_id?\",\n encrypted_provider_tokens AS \"encrypted_provider_tokens?\",\n created_at AS \"created_at!\",\n updated_at AS \"updated_at!\"\n ",
"describe": {
"columns": [
{
@@ -70,11 +70,16 @@
},
{
"ordinal": 13,
"name": "encrypted_provider_tokens?",
"type_info": "Text"
},
{
"ordinal": 14,
"name": "created_at!",
"type_info": "Timestamptz"
},
{
"ordinal": 14,
"ordinal": 15,
"name": "updated_at!",
"type_info": "Timestamptz"
}
@@ -102,9 +107,10 @@
true,
true,
true,
true,
false,
false
]
},
"hash": "4297d2fa8fd3d037243b8794a5ccfc33af057bcb6c9dc1ac601f82bb65130721"
"hash": "56d467122fa8b6599dc8821f65c2b191f378c9a76d3707d63d8cee1ef31fe4ba"
}

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT\n id AS \"id!\",\n provider AS \"provider!\",\n state AS \"state!\",\n return_to AS \"return_to!\",\n app_challenge AS \"app_challenge!\",\n app_code_hash AS \"app_code_hash?\",\n status AS \"status!\",\n error_code AS \"error_code?\",\n expires_at AS \"expires_at!\",\n authorized_at AS \"authorized_at?\",\n redeemed_at AS \"redeemed_at?\",\n user_id AS \"user_id?\",\n session_id AS \"session_id?\",\n created_at AS \"created_at!\",\n updated_at AS \"updated_at!\"\n FROM oauth_handoffs\n WHERE id = $1\n ",
"query": "\n SELECT\n id AS \"id!\",\n provider AS \"provider!\",\n state AS \"state!\",\n return_to AS \"return_to!\",\n app_challenge AS \"app_challenge!\",\n app_code_hash AS \"app_code_hash?\",\n status AS \"status!\",\n error_code AS \"error_code?\",\n expires_at AS \"expires_at!\",\n authorized_at AS \"authorized_at?\",\n redeemed_at AS \"redeemed_at?\",\n user_id AS \"user_id?\",\n session_id AS \"session_id?\",\n encrypted_provider_tokens AS \"encrypted_provider_tokens?\",\n created_at AS \"created_at!\",\n updated_at AS \"updated_at!\"\n FROM oauth_handoffs\n WHERE id = $1\n ",
"describe": {
"columns": [
{
@@ -70,11 +70,16 @@
},
{
"ordinal": 13,
"name": "encrypted_provider_tokens?",
"type_info": "Text"
},
{
"ordinal": 14,
"name": "created_at!",
"type_info": "Timestamptz"
},
{
"ordinal": 14,
"ordinal": 15,
"name": "updated_at!",
"type_info": "Timestamptz"
}
@@ -98,9 +103,10 @@
true,
true,
true,
true,
false,
false
]
},
"hash": "b4ca0d7fada2acae624ec6a26fdf0354f3d4c1e0d24a6685bfdb8d594c882430"
"hash": "577b1dc54aeefe702c74a56776544a391429b561b76d36d59673e410d5d78576"
}

View File

@@ -1,14 +0,0 @@
{
"db_name": "PostgreSQL",
"query": "\n UPDATE oauth_handoffs\n SET\n status = 'redeemed',\n redeemed_at = NOW()\n WHERE id = $1\n AND status = 'authorized'\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Uuid"
]
},
"nullable": []
},
"hash": "94d0724ca8fdf2bf1c965d70ea3db976f1154439fd6299365b27d12f992e8862"
}

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT\n id AS \"id!\",\n provider AS \"provider!\",\n state AS \"state!\",\n return_to AS \"return_to!\",\n app_challenge AS \"app_challenge!\",\n app_code_hash AS \"app_code_hash?\",\n status AS \"status!\",\n error_code AS \"error_code?\",\n expires_at AS \"expires_at!\",\n authorized_at AS \"authorized_at?\",\n redeemed_at AS \"redeemed_at?\",\n user_id AS \"user_id?\",\n session_id AS \"session_id?\",\n created_at AS \"created_at!\",\n updated_at AS \"updated_at!\"\n FROM oauth_handoffs\n WHERE state = $1\n ",
"query": "\n SELECT\n id AS \"id!\",\n provider AS \"provider!\",\n state AS \"state!\",\n return_to AS \"return_to!\",\n app_challenge AS \"app_challenge!\",\n app_code_hash AS \"app_code_hash?\",\n status AS \"status!\",\n error_code AS \"error_code?\",\n expires_at AS \"expires_at!\",\n authorized_at AS \"authorized_at?\",\n redeemed_at AS \"redeemed_at?\",\n user_id AS \"user_id?\",\n session_id AS \"session_id?\",\n encrypted_provider_tokens AS \"encrypted_provider_tokens?\",\n created_at AS \"created_at!\",\n updated_at AS \"updated_at!\"\n FROM oauth_handoffs\n WHERE state = $1\n ",
"describe": {
"columns": [
{
@@ -70,11 +70,16 @@
},
{
"ordinal": 13,
"name": "encrypted_provider_tokens?",
"type_info": "Text"
},
{
"ordinal": 14,
"name": "created_at!",
"type_info": "Timestamptz"
},
{
"ordinal": 14,
"ordinal": 15,
"name": "updated_at!",
"type_info": "Timestamptz"
}
@@ -98,9 +103,10 @@
true,
true,
true,
true,
false,
false
]
},
"hash": "3a32c3e1e517a81ebf65e5ec3c80b7b557639f8041ef9a890a94f38ea6f9c3cb"
"hash": "95427f2ba8293a8aa51366aad80129a3cfdcd1b3ec4dc8298d3aa7d0c5419191"
}

View File

@@ -0,0 +1,14 @@
{
"db_name": "PostgreSQL",
"query": "\n UPDATE oauth_handoffs\n SET\n status = 'redeemed',\n encrypted_provider_tokens = NULL,\n redeemed_at = NOW()\n WHERE id = $1\n AND status = 'authorized'\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Uuid"
]
},
"nullable": []
},
"hash": "ca680e4e2a221ccaf578639b96730fa0d0fd4451d956f9dfa46670f5980c29a8"
}

View File

@@ -8,6 +8,7 @@ publish = false
anyhow = { workspace = true }
axum = { workspace = true }
axum-extra = { version = "0.10.3", features = ["typed-header"] }
aes-gcm = "0.10"
chrono = { version = "0.4", features = ["serde"] }
futures = "0.3"
async-trait = "0.1"

View File

@@ -0,0 +1,2 @@
ALTER TABLE oauth_handoffs
ADD COLUMN IF NOT EXISTS encrypted_provider_tokens TEXT;

View File

@@ -7,7 +7,8 @@ use crate::{
AppState,
activity::ActivityBroker,
auth::{
GitHubOAuthProvider, GoogleOAuthProvider, JwtService, OAuthHandoffService, ProviderRegistry,
GitHubOAuthProvider, GoogleOAuthProvider, JwtService, OAuthHandoffService,
OAuthTokenValidator, ProviderRegistry,
},
config::RemoteServerConfig,
db,
@@ -70,6 +71,9 @@ impl Server {
auth_config.public_base_url().to_string(),
));
let oauth_token_validator =
Arc::new(OAuthTokenValidator::new(pool.clone(), registry.clone()));
let api_key = std::env::var("LOOPS_EMAIL_API_KEY")
.context("LOOPS_EMAIL_API_KEY environment variable is required")?;
let mailer = Arc::new(LoopsMailer::new(api_key));
@@ -86,6 +90,7 @@ impl Server {
config.clone(),
jwt,
handoff_service,
oauth_token_validator,
mailer,
server_public_base_url,
);

View File

@@ -4,6 +4,7 @@ use anyhow::Error as AnyhowError;
use chrono::{DateTime, Duration, Utc};
use rand::{Rng, distr::Alphanumeric};
use reqwest::StatusCode;
use secrecy::ExposeSecret;
use sha2::{Digest, Sha256};
use sqlx::PgPool;
use thiserror::Error;
@@ -118,6 +119,10 @@ impl OAuthHandoffService {
}
}
pub fn providers(&self) -> Arc<ProviderRegistry> {
Arc::clone(&self.providers)
}
pub async fn initiate(
&self,
provider: &str,
@@ -264,14 +269,36 @@ impl OAuthHandoffService {
let user_profile = self.fetch_user_with_retries(&provider, &grant).await?;
let user = self.upsert_identity(&provider, &user_profile).await?;
let provider_token_details = crate::auth::ProviderTokenDetails {
provider: provider.name().to_string(),
access_token: grant.access_token.expose_secret().to_string(),
refresh_token: grant
.refresh_token
.as_ref()
.map(|t| t.expose_secret().to_string()),
expires_at: grant.expires_in.map(|d| (Utc::now() + d).timestamp()),
};
let session_repo = AuthSessionRepository::new(&self.pool);
let session_record = session_repo.create(user.id, None).await?;
let app_code = generate_app_code();
let app_code_hash = hash_sha256_hex(&app_code);
repo.mark_authorized(record.id, user.id, session_record.id, &app_code_hash)
.await?;
let encrypted_tokens = self
.jwt
.encrypt_provider_tokens(&provider_token_details)
.map_err(|e| HandoffError::Failed(format!("Failed to encrypt provider token: {e}")))?;
repo.mark_authorized(
record.id,
user.id,
session_record.id,
&app_code_hash,
Some(encrypted_tokens),
)
.await?;
configure_user_scope(user.id, user.username.as_deref(), Some(user.email.as_str()));
@@ -319,6 +346,9 @@ impl OAuthHandoffService {
let user_id = record
.user_id
.ok_or_else(|| HandoffError::Failed("missing_user".into()))?;
let encrypted_provider_tokens = record
.encrypted_provider_tokens
.ok_or_else(|| HandoffError::Failed("missing_encrypted_provider_tokens".into()))?;
let session_repo = AuthSessionRepository::new(&self.pool);
let session = session_repo.get(session_id).await?;
@@ -338,7 +368,11 @@ impl OAuthHandoffService {
.ensure_personal_org_and_admin_membership(user.id, user.username.as_deref())
.await?;
let tokens = self.jwt.generate_tokens(&session, &user)?;
let provider_token = self
.jwt
.decrypt_provider_tokens(&encrypted_provider_tokens)?;
let tokens = self.jwt.generate_tokens(&session, &user, provider_token)?;
session_repo
.set_current_refresh_token(session.id, tokens.refresh_token_id)

View File

@@ -1,13 +1,25 @@
use std::{collections::HashSet, sync::Arc};
use aes_gcm::{
Aes256Gcm, Key, Nonce,
aead::{Aead, AeadCore, KeyInit, OsRng},
};
use base64::{
Engine as _,
engine::general_purpose::{STANDARD, URL_SAFE_NO_PAD},
};
use chrono::{DateTime, Duration as ChronoDuration, Utc};
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
use secrecy::{ExposeSecret, SecretString};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use thiserror::Error;
use uuid::Uuid;
use crate::db::{auth::AuthSession, users::User};
use crate::{
auth::provider::ProviderTokenDetails,
db::{auth::AuthSession, users::User},
};
pub const ACCESS_TOKEN_TTL_SECONDS: i64 = 120;
pub const REFRESH_TOKEN_TTL_DAYS: i64 = 365;
@@ -27,6 +39,10 @@ pub enum JwtError {
SessionRevoked,
#[error("token type mismatch")]
InvalidTokenType,
#[error("encryption error")]
EncryptionError,
#[error("serialization error")]
SerializationError,
#[error(transparent)]
Jwt(#[from] jsonwebtoken::errors::Error),
}
@@ -48,6 +64,7 @@ pub struct RefreshTokenClaims {
pub iat: i64,
pub exp: i64,
pub aud: String,
pub provider_tokens_blob: String, // Encrypted JSON blob containing provider tokens
}
#[derive(Debug, Clone)]
@@ -62,18 +79,20 @@ pub struct RefreshTokenDetails {
pub user_id: Uuid,
pub session_id: Uuid,
pub refresh_token_id: Uuid,
pub provider_token_details: ProviderTokenDetails,
}
#[derive(Clone)]
pub struct JwtService {
secret: Arc<SecretString>,
pub secret: Arc<SecretString>,
}
#[derive(Debug, Clone)]
pub struct TokenPair {
pub struct Tokens {
pub access_token: String,
pub refresh_token: String,
pub refresh_token_id: Uuid,
pub encrypted_provider_tokens: String,
}
impl JwtService {
@@ -87,7 +106,8 @@ impl JwtService {
&self,
session: &AuthSession,
user: &User,
) -> Result<TokenPair, JwtError> {
provider_token: ProviderTokenDetails,
) -> Result<Tokens, JwtError> {
let now = Utc::now();
let refresh_token_id = Uuid::new_v4();
@@ -101,6 +121,8 @@ impl JwtService {
aud: "access".to_string(),
};
let encrypted_provider_tokens = self.encrypt_provider_tokens(&provider_token)?;
// Refresh token, long-lived (~1 year)
let refresh_exp = now + ChronoDuration::days(REFRESH_TOKEN_TTL_DAYS);
let refresh_claims = RefreshTokenClaims {
@@ -110,6 +132,7 @@ impl JwtService {
iat: now.timestamp(),
exp: refresh_exp.timestamp(),
aud: "refresh".to_string(),
provider_tokens_blob: encrypted_provider_tokens.clone(),
};
let encoding_key = EncodingKey::from_base64_secret(self.secret.expose_secret())?;
@@ -126,10 +149,11 @@ impl JwtService {
&encoding_key,
)?;
Ok(TokenPair {
Ok(Tokens {
access_token,
refresh_token,
refresh_token_id,
encrypted_provider_tokens,
})
}
@@ -186,11 +210,80 @@ impl JwtService {
let decoding_key = DecodingKey::from_base64_secret(self.secret.expose_secret())?;
let data = decode::<RefreshTokenClaims>(token, &decoding_key, &validation)?;
let claims = data.claims;
let provider_token_details = self.decrypt_provider_tokens(&claims.provider_tokens_blob)?;
Ok(RefreshTokenDetails {
user_id: claims.sub,
session_id: claims.session_id,
refresh_token_id: claims.jti,
provider_token_details,
})
}
pub fn decrypt_provider_tokens(
&self,
provider_tokens_blob: &str,
) -> Result<ProviderTokenDetails, JwtError> {
let decrypted = self.decrypt_data(provider_tokens_blob)?;
let decrypted_str = String::from_utf8_lossy(&decrypted);
serde_json::from_str(&decrypted_str).map_err(|_| JwtError::InvalidToken)
}
pub fn encrypt_provider_tokens(
&self,
provider_tokens: &ProviderTokenDetails,
) -> Result<String, JwtError> {
let json =
serde_json::to_string(provider_tokens).map_err(|_| JwtError::SerializationError)?;
self.encrypt_data(json.as_bytes())
}
fn encrypt_data(&self, data: &[u8]) -> Result<String, JwtError> {
let key_bytes = self.derive_key()?;
let key = Key::<Aes256Gcm>::from(key_bytes);
let cipher = Aes256Gcm::new(&key);
let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
let ciphertext = cipher
.encrypt(&nonce, data)
.map_err(|_| JwtError::EncryptionError)?;
let mut combined = nonce.to_vec();
combined.extend_from_slice(&ciphertext);
Ok(URL_SAFE_NO_PAD.encode(combined))
}
fn decrypt_data(&self, encrypted: &str) -> Result<Vec<u8>, JwtError> {
let decoded = URL_SAFE_NO_PAD
.decode(encrypted)
.map_err(|_| JwtError::InvalidToken)?;
const NONCE_SIZE: usize = 12; // 96 bits for AES-256-GCM
if decoded.len() < NONCE_SIZE {
return Err(JwtError::InvalidToken);
}
let key_bytes = self.derive_key()?;
let key = Key::<Aes256Gcm>::from(key_bytes);
let cipher = Aes256Gcm::new(&key);
let nonce_bytes: [u8; NONCE_SIZE] = decoded[..NONCE_SIZE]
.try_into()
.map_err(|_| JwtError::InvalidToken)?;
let nonce = Nonce::from(nonce_bytes);
let ciphertext = &decoded[NONCE_SIZE..];
cipher
.decrypt(&nonce, ciphertext)
.map_err(|_| JwtError::EncryptionError)
}
fn derive_key(&self) -> Result<[u8; 32], JwtError> {
let secret_bytes = STANDARD
.decode(self.secret.expose_secret())
.map_err(|_| JwtError::InvalidSecret)?;
let mut hasher = Sha256::new();
hasher.update(&secret_bytes);
Ok(hasher.finalize().into())
}
}

View File

@@ -1,9 +1,13 @@
mod handoff;
mod jwt;
mod middleware;
mod oauth_token_validator;
mod provider;
pub use handoff::{CallbackResult, HandoffError, OAuthHandoffService};
pub use jwt::{JwtError, JwtService};
pub use middleware::{RequestContext, require_session};
pub use provider::{GitHubOAuthProvider, GoogleOAuthProvider, ProviderRegistry};
pub use oauth_token_validator::{OAuthTokenValidationError, OAuthTokenValidator};
pub use provider::{
GitHubOAuthProvider, GoogleOAuthProvider, ProviderRegistry, ProviderTokenDetails,
};

View File

@@ -0,0 +1,148 @@
use std::sync::Arc;
use sqlx::PgPool;
use tracing::{info, warn};
use crate::{
auth::{
ProviderTokenDetails,
provider::{ProviderRegistry, TokenValidationError, VALIDATE_TOKEN_MAX_RETRIES},
},
db::{
auth::AuthSessionRepository,
oauth_accounts::{OAuthAccountError, OAuthAccountRepository},
},
};
#[derive(Debug, thiserror::Error)]
pub enum OAuthTokenValidationError {
#[error("failed to fetch OAuth accounts for user")]
FetchAccountsFailed(OAuthAccountError),
#[error("provider account no longer linked to user")]
ProviderAccountNotLinked,
#[error("OAuth provider token validation failed")]
ProviderTokenValidationFailed,
#[error("temporary failure validating provider token: {0}")]
ValidationUnavailable(String),
}
pub struct OAuthTokenValidator {
pool: PgPool,
provider_registry: Arc<ProviderRegistry>,
}
impl OAuthTokenValidator {
pub fn new(pool: PgPool, provider_registry: Arc<ProviderRegistry>) -> Self {
Self {
pool,
provider_registry,
}
}
// Check if the OAuth provider token is still valid, refresh if possible
// Revoke all sessions if provider has revoked the OAuth token
pub async fn validate(
&self,
provider_token_details: ProviderTokenDetails,
user_id: uuid::Uuid,
session_id: uuid::Uuid,
) -> Result<ProviderTokenDetails, OAuthTokenValidationError> {
match self
.verify_inner(provider_token_details, user_id, session_id)
.await
{
Ok(updated_token_details) => Ok(updated_token_details),
Err(err) => {
match &err {
OAuthTokenValidationError::ProviderAccountNotLinked
| OAuthTokenValidationError::ProviderTokenValidationFailed
| OAuthTokenValidationError::FetchAccountsFailed(_) => {
let session_repo = AuthSessionRepository::new(&self.pool);
if let Err(e) = session_repo.revoke_all_user_sessions(user_id).await {
warn!(
user_id = %user_id,
error = %e,
"Failed to revoke all user sessions after OAuth token validation failure"
);
}
}
OAuthTokenValidationError::ValidationUnavailable(_) => (),
};
Err(err)
}
}
}
async fn verify_inner(
&self,
mut provider_token_details: ProviderTokenDetails,
user_id: uuid::Uuid,
session_id: uuid::Uuid,
) -> Result<ProviderTokenDetails, OAuthTokenValidationError> {
let oauth_account_repo = OAuthAccountRepository::new(&self.pool);
let accounts = match oauth_account_repo.list_by_user(user_id).await {
Ok(accounts) => accounts,
Err(err) => {
warn!(
user_id = %user_id,
error = %err,
"Failed to fetch OAuth accounts for user"
);
return Err(OAuthTokenValidationError::FetchAccountsFailed(err));
}
};
let account_exists = accounts
.iter()
.any(|a| a.provider == provider_token_details.provider);
if !account_exists {
warn!(
user_id = %user_id,
provider = %provider_token_details.provider,
"Provider account no longer linked to user, revoking sessions"
);
return Err(OAuthTokenValidationError::ProviderAccountNotLinked);
}
let Some(provider) = self.provider_registry.get(&provider_token_details.provider) else {
warn!(
user_id = %user_id,
provider = %provider_token_details.provider,
"OAuth provider not found in registry, revoking all sessions"
);
return Err(OAuthTokenValidationError::ProviderTokenValidationFailed);
};
match provider
.validate_token(&provider_token_details, VALIDATE_TOKEN_MAX_RETRIES)
.await
{
Ok(Some(updated_token_details)) => {
provider_token_details = updated_token_details;
}
Ok(None) => {}
Err(TokenValidationError::InvalidOrRevoked) => {
info!(
user_id = %user_id,
provider = %provider_token_details.provider,
session_id = %session_id,
"OAuth provider reported token as invalid or revoked"
);
return Err(OAuthTokenValidationError::ProviderTokenValidationFailed);
}
Err(TokenValidationError::Temporary(reason)) => {
warn!(
user_id = %user_id,
provider = %provider_token_details.provider,
session_id = %session_id,
error = %reason,
"OAuth provider validation temporarily unavailable"
);
return Err(OAuthTokenValidationError::ValidationUnavailable(reason));
}
}
Ok(provider_token_details)
}
}

View File

@@ -5,11 +5,17 @@ use async_trait::async_trait;
use chrono::Duration;
use reqwest::Client;
use secrecy::{ExposeSecret, SecretString};
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tracing::info;
use url::Url;
const USER_AGENT: &str = "VibeKanbanRemote/1.0";
const TOKEN_EXPIRATION_LEEWAY_SECONDS: i64 = 20;
pub const VALIDATE_TOKEN_MAX_RETRIES: u32 = 3;
const RETRY_INTERVAL_SECONDS: u64 = 2;
#[derive(Debug, Clone)]
pub struct AuthorizationGrant {
pub access_token: SecretString,
@@ -29,6 +35,28 @@ pub struct ProviderUser {
pub avatar_url: Option<String>,
}
#[derive(Debug, Error)]
pub enum TokenValidationError {
#[error("provider token invalid or revoked")]
InvalidOrRevoked,
#[error("provider validation temporarily unavailable: {0}")]
Temporary(String),
}
impl TokenValidationError {
fn temporary(message: impl Into<String>) -> Self {
Self::Temporary(message.into())
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderTokenDetails {
pub provider: String,
pub access_token: String,
pub refresh_token: Option<String>,
pub expires_at: Option<i64>,
}
#[async_trait]
pub trait AuthorizationProvider: Send + Sync {
fn name(&self) -> &'static str;
@@ -36,6 +64,11 @@ pub trait AuthorizationProvider: Send + Sync {
fn authorize_url(&self, state: &str, redirect_uri: &str) -> Result<Url>;
async fn exchange_code(&self, code: &str, redirect_uri: &str) -> Result<AuthorizationGrant>;
async fn fetch_user(&self, access_token: &SecretString) -> Result<ProviderUser>;
async fn validate_token(
&self,
token_details: &ProviderTokenDetails,
max_retries: u32,
) -> Result<Option<ProviderTokenDetails>, TokenValidationError>;
}
#[derive(Default)]
@@ -232,6 +265,99 @@ impl AuthorizationProvider for GitHubOAuthProvider {
avatar_url: user.avatar_url,
})
}
async fn validate_token(
&self,
token_details: &ProviderTokenDetails,
max_retries: u32,
) -> Result<Option<ProviderTokenDetails>, TokenValidationError> {
let mut attempt = 0;
let access_token = SecretString::new(token_details.access_token.clone().into_boxed_str());
loop {
attempt += 1;
let response = match self
.client
.get("https://api.github.com/rate_limit")
.header(
"Authorization",
format!("Bearer {}", access_token.expose_secret()),
)
.header("Accept", "application/vnd.github+json")
.send()
.await
{
Ok(resp) => resp,
Err(err) => {
if attempt >= max_retries {
return Err(TokenValidationError::temporary(format!(
"request failed: {err}"
)));
}
tokio::time::sleep(tokio::time::Duration::from_secs(RETRY_INTERVAL_SECONDS))
.await;
continue;
}
};
match response.status() {
reqwest::StatusCode::OK => {
// GitHub tokens don't expire
return Ok(None);
}
reqwest::StatusCode::UNAUTHORIZED => {
return Err(TokenValidationError::InvalidOrRevoked);
}
reqwest::StatusCode::FORBIDDEN => {
// Check if rate limited
let rate_limit_remaining = response
.headers()
.get("x-ratelimit-remaining")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse::<i32>().ok())
.unwrap_or(1);
if rate_limit_remaining == 0 {
if attempt <= max_retries {
// Get reset time and wait
if let Some(reset_str) = response
.headers()
.get("x-ratelimit-reset")
.and_then(|v| v.to_str().ok())
&& let Ok(reset_time) = reset_str.parse::<i64>()
{
let now = chrono::Utc::now().timestamp();
let wait_seconds = (reset_time - now).clamp(0, 5);
tokio::time::sleep(tokio::time::Duration::from_secs(
wait_seconds as u64,
))
.await;
continue;
}
}
return Err(TokenValidationError::temporary("rate limited by GitHub"));
} else {
return Err(TokenValidationError::temporary(
"access forbidden during validation",
));
}
}
status => {
if status.is_server_error() && attempt <= max_retries {
tokio::time::sleep(tokio::time::Duration::from_secs(
RETRY_INTERVAL_SECONDS,
))
.await;
continue;
}
return Err(TokenValidationError::temporary(format!(
"unexpected validation status: {status}"
)));
}
}
}
}
}
pub struct GoogleOAuthProvider {
@@ -249,6 +375,92 @@ impl GoogleOAuthProvider {
client_secret,
})
}
async fn try_refresh_access_token(
&self,
refresh_token: &str,
) -> Result<ProviderTokenDetails, TokenValidationError> {
let response = match self
.client
.post("https://oauth2.googleapis.com/token")
.form(&[
("client_id", self.client_id.as_str()),
("client_secret", self.client_secret.expose_secret()),
("refresh_token", refresh_token),
("grant_type", "refresh_token"),
])
.send()
.await
{
Ok(resp) => resp,
Err(err) => {
return Err(TokenValidationError::temporary(format!(
"refresh request failed: {err}"
)));
}
};
match response.status() {
reqwest::StatusCode::OK => {
#[derive(Debug, Deserialize)]
struct RefreshResponse {
access_token: String,
expires_in: i64,
#[serde(default)]
refresh_token: Option<String>,
}
let refresh_data: RefreshResponse = response
.json()
.await
.map_err(|err| TokenValidationError::temporary(format!("{err}")))?;
let expires_at = chrono::Utc::now().timestamp() + refresh_data.expires_in;
let new_refresh_token = refresh_data
.refresh_token
.unwrap_or_else(|| refresh_token.to_string());
Ok(ProviderTokenDetails {
provider: self.name().to_string(),
access_token: refresh_data.access_token,
refresh_token: Some(new_refresh_token),
expires_at: Some(expires_at),
})
}
reqwest::StatusCode::BAD_REQUEST => Err(TokenValidationError::InvalidOrRevoked),
status if status.is_server_error() => Err(TokenValidationError::temporary(format!(
"token refresh server error: {status}"
))),
status => Err(TokenValidationError::temporary(format!(
"unexpected token refresh status: {status}"
))),
}
}
async fn refresh_token(
&self,
refresh_token: &str,
max_retries: u32,
) -> Result<ProviderTokenDetails, TokenValidationError> {
let mut attempt = 0;
loop {
attempt += 1;
match self.try_refresh_access_token(refresh_token).await {
Ok(new_token_details) => return Ok(new_token_details),
Err(TokenValidationError::InvalidOrRevoked) => {
return Err(TokenValidationError::InvalidOrRevoked);
}
Err(TokenValidationError::Temporary(err)) => {
if attempt >= max_retries {
return Err(TokenValidationError::Temporary(err));
}
tokio::time::sleep(tokio::time::Duration::from_secs(RETRY_INTERVAL_SECONDS))
.await;
}
}
}
}
}
#[derive(Debug, Deserialize)]
@@ -386,4 +598,95 @@ impl AuthorizationProvider for GoogleOAuthProvider {
avatar_url: profile.picture,
})
}
async fn validate_token(
&self,
token_details: &ProviderTokenDetails,
max_retries: u32,
) -> Result<Option<ProviderTokenDetails>, TokenValidationError> {
let mut attempt = 0;
let access_token = SecretString::new(token_details.access_token.clone().into_boxed_str());
loop {
attempt += 1;
if let Some(expires_at) = token_details.expires_at
&& let now = chrono::Utc::now().timestamp()
&& now >= expires_at - TOKEN_EXPIRATION_LEEWAY_SECONDS
{
let Some(refresh_token) = &token_details.refresh_token else {
return Err(TokenValidationError::InvalidOrRevoked);
};
info!("Token expired, attempting refresh for Google OAuth");
return self
.refresh_token(refresh_token, max_retries)
.await
.map(Some);
}
let response = match self
.client
.get("https://www.googleapis.com/oauth2/v2/tokeninfo")
.query(&[("access_token", access_token.expose_secret())])
.send()
.await
{
Ok(resp) => resp,
Err(err) => {
if attempt >= max_retries {
return Err(TokenValidationError::temporary(format!(
"tokeninfo request failed: {err}"
)));
}
tokio::time::sleep(tokio::time::Duration::from_secs(RETRY_INTERVAL_SECONDS))
.await;
continue;
}
};
match response.status() {
reqwest::StatusCode::OK => {
return Ok(None);
}
reqwest::StatusCode::BAD_REQUEST => {
let Some(refresh_token) = &token_details.refresh_token else {
return Err(TokenValidationError::InvalidOrRevoked);
};
info!("Token expired during validation, attempting refresh");
return self
.refresh_token(refresh_token, max_retries)
.await
.map(Some);
}
reqwest::StatusCode::TOO_MANY_REQUESTS => {
if attempt >= max_retries {
return Err(TokenValidationError::temporary(
"rate limited by Google".to_string(),
));
}
tokio::time::sleep(tokio::time::Duration::from_secs(RETRY_INTERVAL_SECONDS))
.await;
}
status if status.is_server_error() => {
if attempt >= max_retries {
return Err(TokenValidationError::temporary(format!(
"google tokeninfo server error: {status}"
)));
}
tokio::time::sleep(tokio::time::Duration::from_secs(RETRY_INTERVAL_SECONDS))
.await;
}
status => {
if attempt >= max_retries {
return Err(TokenValidationError::temporary(format!(
"unexpected tokeninfo status: {status}"
)));
}
tokio::time::sleep(tokio::time::Duration::from_secs(RETRY_INTERVAL_SECONDS))
.await;
}
}
}
}
}

View File

@@ -68,6 +68,7 @@ pub struct OAuthHandoff {
pub redeemed_at: Option<DateTime<Utc>>,
pub user_id: Option<Uuid>,
pub session_id: Option<Uuid>,
pub encrypted_provider_tokens: Option<String>,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
@@ -112,21 +113,22 @@ impl<'a> OAuthHandoffRepository<'a> {
)
VALUES ($1, $2, $3, $4, $5)
RETURNING
id AS "id!",
provider AS "provider!",
state AS "state!",
return_to AS "return_to!",
app_challenge AS "app_challenge!",
app_code_hash AS "app_code_hash?",
status AS "status!",
error_code AS "error_code?",
expires_at AS "expires_at!",
authorized_at AS "authorized_at?",
redeemed_at AS "redeemed_at?",
user_id AS "user_id?",
session_id AS "session_id?",
created_at AS "created_at!",
updated_at AS "updated_at!"
id AS "id!",
provider AS "provider!",
state AS "state!",
return_to AS "return_to!",
app_challenge AS "app_challenge!",
app_code_hash AS "app_code_hash?",
status AS "status!",
error_code AS "error_code?",
expires_at AS "expires_at!",
authorized_at AS "authorized_at?",
redeemed_at AS "redeemed_at?",
user_id AS "user_id?",
session_id AS "session_id?",
encrypted_provider_tokens AS "encrypted_provider_tokens?",
created_at AS "created_at!",
updated_at AS "updated_at!"
"#,
data.provider,
data.state,
@@ -156,7 +158,8 @@ impl<'a> OAuthHandoffRepository<'a> {
authorized_at AS "authorized_at?",
redeemed_at AS "redeemed_at?",
user_id AS "user_id?",
session_id AS "session_id?",
session_id AS "session_id?",
encrypted_provider_tokens AS "encrypted_provider_tokens?",
created_at AS "created_at!",
updated_at AS "updated_at!"
FROM oauth_handoffs
@@ -186,7 +189,8 @@ impl<'a> OAuthHandoffRepository<'a> {
authorized_at AS "authorized_at?",
redeemed_at AS "redeemed_at?",
user_id AS "user_id?",
session_id AS "session_id?",
session_id AS "session_id?",
encrypted_provider_tokens AS "encrypted_provider_tokens?",
created_at AS "created_at!",
updated_at AS "updated_at!"
FROM oauth_handoffs
@@ -228,6 +232,7 @@ impl<'a> OAuthHandoffRepository<'a> {
user_id: Uuid,
session_id: Uuid,
app_code_hash: &str,
encrypted_provider_tokens: Option<String>,
) -> Result<(), OAuthHandoffError> {
sqlx::query!(
r#"
@@ -238,13 +243,15 @@ impl<'a> OAuthHandoffRepository<'a> {
user_id = $2,
session_id = $3,
app_code_hash = $4,
encrypted_provider_tokens = $5,
authorized_at = NOW()
WHERE id = $1
"#,
id,
user_id,
session_id,
app_code_hash
app_code_hash,
encrypted_provider_tokens
)
.execute(self.pool)
.await?;
@@ -257,6 +264,7 @@ impl<'a> OAuthHandoffRepository<'a> {
UPDATE oauth_handoffs
SET
status = 'redeemed',
encrypted_provider_tokens = NULL,
redeemed_at = NOW()
WHERE id = $1
AND status = 'authorized'

View File

@@ -10,10 +10,11 @@ use utils::api::oauth::{TokenRefreshRequest, TokenRefreshResponse};
use crate::{
AppState,
auth::JwtError,
auth::{JwtError, OAuthTokenValidationError},
db::{
auth::{AuthSessionError, AuthSessionRepository},
identity_errors::IdentityError,
oauth_accounts::OAuthAccountError,
users::UserRepository,
},
};
@@ -32,6 +33,10 @@ pub enum TokenRefreshError {
TokenExpired,
#[error("refresh token reused - possible token theft")]
TokenReuseDetected,
#[error("provider token has been revoked")]
ProviderTokenRevoked,
#[error("temporary failure validating provider token")]
ProviderValidationUnavailable(String),
#[error(transparent)]
Jwt(#[from] JwtError),
#[error(transparent)]
@@ -42,6 +47,23 @@ pub enum TokenRefreshError {
Identity(#[from] IdentityError),
}
impl From<OAuthTokenValidationError> for TokenRefreshError {
fn from(err: OAuthTokenValidationError) -> Self {
match err {
OAuthTokenValidationError::ProviderAccountNotLinked
| OAuthTokenValidationError::ProviderTokenValidationFailed => {
TokenRefreshError::ProviderTokenRevoked
}
OAuthTokenValidationError::FetchAccountsFailed(inner) => match inner {
OAuthAccountError::Database(db_err) => TokenRefreshError::Database(db_err),
},
OAuthTokenValidationError::ValidationUnavailable(reason) => {
TokenRefreshError::ProviderValidationUnavailable(reason)
}
}
}
}
pub async fn refresh_token(
State(state): State<AppState>,
Json(payload): Json<TokenRefreshRequest>,
@@ -79,10 +101,20 @@ pub async fn refresh_token(
return Err(TokenRefreshError::TokenReuseDetected);
}
// Check if provider has revoked the OAuth token
let provider_token_details = state
.oauth_token_validator()
.validate(
token_details.provider_token_details.clone(),
token_details.user_id,
token_details.session_id,
)
.await?;
let user_repo = UserRepository::new(state.pool());
let user = user_repo.fetch_user(token_details.user_id).await?;
let tokens = jwt_service.generate_tokens(&session, &user)?;
let tokens = jwt_service.generate_tokens(&session, &user, provider_token_details)?;
let old_token_id = token_details.refresh_token_id;
let new_token_id = tokens.refresh_token_id;
@@ -123,9 +155,27 @@ impl IntoResponse for TokenRefreshError {
TokenRefreshError::TokenReuseDetected => {
(StatusCode::UNAUTHORIZED, "token_reuse_detected")
}
TokenRefreshError::ProviderTokenRevoked => {
(StatusCode::UNAUTHORIZED, "provider_token_revoked")
}
TokenRefreshError::ProviderValidationUnavailable(ref reason) => {
warn!(
reason = reason.as_str(),
"Provider validation temporarily unavailable during refresh"
);
(
StatusCode::SERVICE_UNAVAILABLE,
"provider_validation_unavailable",
)
}
TokenRefreshError::Jwt(_) => (StatusCode::UNAUTHORIZED, "invalid_token"),
TokenRefreshError::Identity(_) => (StatusCode::UNAUTHORIZED, "identity_error"),
TokenRefreshError::Database(_) | TokenRefreshError::SessionError(_) => {
TokenRefreshError::Database(ref err) => {
tracing::error!(error = %err, "Database error during token refresh");
(StatusCode::INTERNAL_SERVER_ERROR, "internal_error")
}
TokenRefreshError::SessionError(ref err) => {
tracing::error!(error = %err, "Session error during token refresh");
(StatusCode::INTERNAL_SERVER_ERROR, "internal_error")
}
};

View File

@@ -4,7 +4,7 @@ use sqlx::PgPool;
use crate::{
activity::ActivityBroker,
auth::{JwtService, OAuthHandoffService},
auth::{JwtService, OAuthHandoffService, OAuthTokenValidator, ProviderRegistry},
config::RemoteServerConfig,
mail::Mailer,
};
@@ -17,16 +17,19 @@ pub struct AppState {
pub jwt: Arc<JwtService>,
pub mailer: Arc<dyn Mailer>,
pub server_public_base_url: String,
handoff: Arc<OAuthHandoffService>,
pub handoff: Arc<OAuthHandoffService>,
pub oauth_token_validator: Arc<OAuthTokenValidator>,
}
impl AppState {
#[allow(clippy::too_many_arguments)]
pub fn new(
pool: PgPool,
broker: ActivityBroker,
config: RemoteServerConfig,
jwt: Arc<JwtService>,
handoff: Arc<OAuthHandoffService>,
oauth_token_validator: Arc<OAuthTokenValidator>,
mailer: Arc<dyn Mailer>,
server_public_base_url: String,
) -> Self {
@@ -38,6 +41,7 @@ impl AppState {
mailer,
server_public_base_url,
handoff,
oauth_token_validator,
}
}
@@ -60,4 +64,12 @@ impl AppState {
pub fn handoff(&self) -> Arc<OAuthHandoffService> {
Arc::clone(&self.handoff)
}
pub fn providers(&self) -> Arc<ProviderRegistry> {
self.handoff.providers()
}
pub fn oauth_token_validator(&self) -> Arc<OAuthTokenValidator> {
Arc::clone(&self.oauth_token_validator)
}
}