JWT: separate access tokens and refresh tokens (#1315)

This commit is contained in:
Solomon
2025-11-19 18:07:12 +00:00
committed by GitHub
parent f3d963c285
commit 84454b54a1
33 changed files with 983 additions and 303 deletions

12
Cargo.lock generated
View File

@@ -2474,9 +2474,9 @@ dependencies = [
[[package]]
name = "jsonwebtoken"
version = "10.1.0"
version = "10.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d119c6924272d16f0ab9ce41f7aa0bfef9340c00b0bb7ca3dd3b263d4a9150b"
checksum = "c76e1c7d7df3e34443b3621b459b066a7b79644f059fc8b2db7070c825fd417e"
dependencies = [
"base64",
"ed25519-dalek",
@@ -3480,7 +3480,7 @@ dependencies = [
"quinn-udp",
"rustc-hash 2.1.1",
"rustls",
"socket2 0.5.10",
"socket2 0.6.1",
"thiserror 2.0.17",
"tokio",
"tracing",
@@ -3517,9 +3517,9 @@ dependencies = [
"cfg_aliases",
"libc",
"once_cell",
"socket2 0.5.10",
"socket2 0.6.1",
"tracing",
"windows-sys 0.52.0",
"windows-sys 0.60.2",
]
[[package]]
@@ -5600,7 +5600,7 @@ dependencies = [
"futures-util",
"git2",
"json-patch",
"jsonwebtoken 10.1.0",
"jsonwebtoken 10.2.0",
"open",
"regex",
"reqwest",

View File

@@ -294,15 +294,6 @@ impl LocalDeployment {
self.remote_client.clone()
}
/// Convenience method to get the current JWT auth token.
/// Returns None if the user is not authenticated.
pub async fn auth_token(&self) -> Option<String> {
self.auth_context
.get_credentials()
.await
.map(|c| c.access_token)
}
pub async fn get_login_status(&self) -> LoginStatus {
if self.auth_context.get_credentials().await.is_none() {
self.auth_context.clear_profile().await;

View File

@@ -0,0 +1,15 @@
{
"db_name": "PostgreSQL",
"query": "\n INSERT INTO revoked_refresh_tokens (token_id, user_id, revoked_reason)\n VALUES ($1, $2, 'token_rotation')\n ON CONFLICT (token_id) DO NOTHING\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Uuid",
"Uuid"
]
},
"nullable": []
},
"hash": "082aaf51a023c8ccb44002ce48287acd8ef90b0f4c8338447c6e5370ca93390b"
}

View File

@@ -0,0 +1,24 @@
{
"db_name": "PostgreSQL",
"query": "\n UPDATE auth_sessions\n SET refresh_token_id = $3,\n refresh_token_issued_at = NOW()\n WHERE id = $1\n AND refresh_token_id = $2\n RETURNING user_id\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "user_id",
"type_info": "Uuid"
}
],
"parameters": {
"Left": [
"Uuid",
"Uuid",
"Uuid"
]
},
"nullable": [
false
]
},
"hash": "2f3898ec50ee1386f87786c605069aac78d5177feaabd719b60e54f94f5f535e"
}

View File

@@ -0,0 +1,22 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT EXISTS(\n SELECT 1 FROM revoked_refresh_tokens WHERE token_id = $1\n ) as is_revoked\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "is_revoked",
"type_info": "Bool"
}
],
"parameters": {
"Left": [
"Uuid"
]
},
"nullable": [
null
]
},
"hash": "389b412ed9b76973a5b1546a24167e0b752467405f024de73101b6c12e1e05f1"
}

View File

@@ -0,0 +1,59 @@
{
"db_name": "PostgreSQL",
"query": "\n INSERT INTO auth_sessions (user_id, refresh_token_id)\n VALUES ($1, $2)\n RETURNING\n id AS \"id!\",\n user_id AS \"user_id!: Uuid\",\n created_at AS \"created_at!\",\n last_used_at AS \"last_used_at?\",\n revoked_at AS \"revoked_at?\",\n refresh_token_id AS \"refresh_token_id?\",\n refresh_token_issued_at AS \"refresh_token_issued_at?\"\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id!",
"type_info": "Uuid"
},
{
"ordinal": 1,
"name": "user_id!: Uuid",
"type_info": "Uuid"
},
{
"ordinal": 2,
"name": "created_at!",
"type_info": "Timestamptz"
},
{
"ordinal": 3,
"name": "last_used_at?",
"type_info": "Timestamptz"
},
{
"ordinal": 4,
"name": "revoked_at?",
"type_info": "Timestamptz"
},
{
"ordinal": 5,
"name": "refresh_token_id?",
"type_info": "Uuid"
},
{
"ordinal": 6,
"name": "refresh_token_issued_at?",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Uuid",
"Uuid"
]
},
"nullable": [
false,
false,
false,
true,
true,
true,
true
]
},
"hash": "4d963a12190ee1db657446ef451c5364f8f91153f7f1bb4e5abfd3f3ddbe0461"
}

View File

@@ -0,0 +1,14 @@
{
"db_name": "PostgreSQL",
"query": "\n UPDATE auth_sessions\n SET revoked_at = NOW()\n WHERE user_id = $1\n AND revoked_at IS NULL\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Uuid"
]
},
"nullable": []
},
"hash": "68422b179dc361337c65a6bd1aa455a961708b97a673d84f7af64cd252cbfdf3"
}

View File

@@ -0,0 +1,14 @@
{
"db_name": "PostgreSQL",
"query": "\n INSERT INTO revoked_refresh_tokens (token_id, user_id, revoked_reason)\n SELECT refresh_token_id, user_id, 'reuse_of_revoked_token'\n FROM auth_sessions\n WHERE user_id = $1\n AND refresh_token_id IS NOT NULL\n ON CONFLICT (token_id) DO NOTHING\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Uuid"
]
},
"nullable": []
},
"hash": "8e32d5bf86d112e2f4a16f622bd95c8f728946f01e1a994a9c66b0fac6e3ae52"
}

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT\n id AS \"id!\",\n user_id AS \"user_id!: Uuid\",\n session_secret_hash AS \"session_secret_hash?\",\n created_at AS \"created_at!\",\n last_used_at AS \"last_used_at?\",\n revoked_at AS \"revoked_at?\"\n FROM auth_sessions\n WHERE id = $1\n ",
"query": "\n SELECT\n id AS \"id!\",\n user_id AS \"user_id!: Uuid\",\n created_at AS \"created_at!\",\n last_used_at AS \"last_used_at?\",\n revoked_at AS \"revoked_at?\",\n refresh_token_id AS \"refresh_token_id?\",\n refresh_token_issued_at AS \"refresh_token_issued_at?\"\n FROM auth_sessions\n WHERE id = $1\n ",
"describe": {
"columns": [
{
@@ -15,23 +15,28 @@
},
{
"ordinal": 2,
"name": "session_secret_hash?",
"type_info": "Text"
},
{
"ordinal": 3,
"name": "created_at!",
"type_info": "Timestamptz"
},
{
"ordinal": 4,
"ordinal": 3,
"name": "last_used_at?",
"type_info": "Timestamptz"
},
{
"ordinal": 5,
"ordinal": 4,
"name": "revoked_at?",
"type_info": "Timestamptz"
},
{
"ordinal": 5,
"name": "refresh_token_id?",
"type_info": "Uuid"
},
{
"ordinal": 6,
"name": "refresh_token_issued_at?",
"type_info": "Timestamptz"
}
],
"parameters": {
@@ -42,11 +47,12 @@
"nullable": [
false,
false,
true,
false,
true,
true,
true,
true
]
},
"hash": "d12fbd108d36c817c94997744b50cafd08407c0e207e2cacd43c50d28e886b19"
"hash": "9459cf92b30943acb79f0e0f2e9421be83ce9e50e39f6b1e435b92ff70907264"
}

View File

@@ -1,15 +1,15 @@
{
"db_name": "PostgreSQL",
"query": "\n UPDATE auth_sessions\n SET session_secret_hash = $2\n WHERE id = $1\n ",
"query": "\n UPDATE auth_sessions\n SET refresh_token_id = $2,\n refresh_token_issued_at = NOW()\n WHERE id = $1\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Uuid",
"Text"
"Uuid"
]
},
"nullable": []
},
"hash": "92d13927cde8ac62cb0cfd3c3410aa4d42717d6a3a219926ddc34ca1d2520306"
"hash": "a1431ca78db627fef0eca6f573b34d65510e9333765126cbd80c943046dfaea8"
}

View File

@@ -1,53 +0,0 @@
{
"db_name": "PostgreSQL",
"query": "\n INSERT INTO auth_sessions (user_id, session_secret_hash)\n VALUES ($1, $2)\n RETURNING\n id AS \"id!\",\n user_id AS \"user_id!: Uuid\",\n session_secret_hash AS \"session_secret_hash?\",\n created_at AS \"created_at!\",\n last_used_at AS \"last_used_at?\",\n revoked_at AS \"revoked_at?\"\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id!",
"type_info": "Uuid"
},
{
"ordinal": 1,
"name": "user_id!: Uuid",
"type_info": "Uuid"
},
{
"ordinal": 2,
"name": "session_secret_hash?",
"type_info": "Text"
},
{
"ordinal": 3,
"name": "created_at!",
"type_info": "Timestamptz"
},
{
"ordinal": 4,
"name": "last_used_at?",
"type_info": "Timestamptz"
},
{
"ordinal": 5,
"name": "revoked_at?",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Uuid",
"Text"
]
},
"nullable": [
false,
false,
true,
false,
true,
true
]
},
"hash": "f40c7ea0e0692e2ee7eead2027260104616026d32f312f8633236cc9438cd958"
}

View File

@@ -0,0 +1,15 @@
ALTER TABLE auth_sessions ADD COLUMN IF NOT EXISTS refresh_token_id UUID;
ALTER TABLE auth_sessions ADD COLUMN IF NOT EXISTS refresh_token_issued_at TIMESTAMPTZ;
CREATE INDEX IF NOT EXISTS idx_auth_sessions_refresh_id
ON auth_sessions (refresh_token_id);
CREATE TABLE IF NOT EXISTS revoked_refresh_tokens (
token_id UUID PRIMARY KEY,
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
revoked_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
revoked_reason TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_revoked_tokens_user
ON revoked_refresh_tokens (user_id);

View File

@@ -30,7 +30,6 @@ use crate::{
},
};
const SESSION_SECRET_LENGTH: usize = 48;
const STATE_LENGTH: usize = 48;
const APP_CODE_LENGTH: usize = 48;
const HANDOFF_TTL: i64 = 10; // minutes
@@ -93,6 +92,7 @@ pub enum CallbackResult {
#[derive(Debug, Clone)]
pub struct RedeemResponse {
pub access_token: String,
pub refresh_token: String,
}
pub struct OAuthHandoffService {
@@ -321,7 +321,7 @@ impl OAuthHandoffService {
.ok_or_else(|| HandoffError::Failed("missing_user".into()))?;
let session_repo = AuthSessionRepository::new(&self.pool);
let mut session = session_repo.get(session_id).await?;
let session = session_repo.get(session_id).await?;
if session.revoked_at.is_some() {
return Err(HandoffError::Denied);
}
@@ -331,13 +331,6 @@ impl OAuthHandoffService {
return Err(HandoffError::Denied);
}
let session_secret = generate_session_secret();
let session_secret_hash = self.jwt.hash_session_secret(&session_secret)?;
session_repo
.update_secret(session.id, &session_secret_hash)
.await?;
session.session_secret_hash = Some(session_secret_hash.clone());
let user_repo = UserRepository::new(&self.pool);
let user = user_repo.fetch_user(user_id).await?;
let org_repo = OrganizationRepository::new(&self.pool);
@@ -345,14 +338,20 @@ impl OAuthHandoffService {
.ensure_personal_org_and_admin_membership(user.id, user.username.as_deref())
.await?;
let token = self.jwt.encode(&session, &user, &session_secret)?;
let tokens = self.jwt.generate_tokens(&session, &user)?;
session_repo
.set_current_refresh_token(session.id, tokens.refresh_token_id)
.await?;
session_repo.touch(session.id).await?;
repo.mark_redeemed(record.id).await?;
configure_user_scope(user.id, user.username.as_deref(), Some(user.email.as_str()));
Ok(RedeemResponse {
access_token: token,
access_token: tokens.access_token,
refresh_token: tokens.refresh_token,
})
}
@@ -498,14 +497,6 @@ fn generate_app_code() -> String {
.collect()
}
fn generate_session_secret() -> String {
rand::rng()
.sample_iter(&Alphanumeric)
.take(SESSION_SECRET_LENGTH)
.map(char::from)
.collect()
}
fn ensure_email(provider: &str, profile: &ProviderUser) -> String {
if let Some(email) = profile.email.clone() {
return email;

View File

@@ -1,19 +1,17 @@
use std::{collections::HashSet, sync::Arc};
use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD};
use chrono::Utc;
use hmac::{Hmac, Mac};
use chrono::{DateTime, Duration as ChronoDuration, Utc};
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
use secrecy::{ExposeSecret, SecretString};
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use subtle::ConstantTimeEq;
use thiserror::Error;
use uuid::Uuid;
use crate::db::{auth::AuthSession, users::User};
type HmacSha256 = Hmac<Sha256>;
pub const ACCESS_TOKEN_TTL_SECONDS: i64 = 120;
pub const REFRESH_TOKEN_TTL_DAYS: i64 = 365;
const DEFAULT_JWT_LEEWAY_SECONDS: u64 = 60;
#[derive(Debug, Error)]
pub enum JwtError {
@@ -21,23 +19,49 @@ pub enum JwtError {
InvalidToken,
#[error("invalid jwt secret")]
InvalidSecret,
#[error("token expired")]
TokenExpired,
#[error("refresh token reused - possible theft detected")]
TokenReuseDetected,
#[error("session revoked")]
SessionRevoked,
#[error("token type mismatch")]
InvalidTokenType,
#[error(transparent)]
Jwt(#[from] jsonwebtoken::errors::Error),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtClaims {
pub struct AccessTokenClaims {
pub sub: Uuid,
pub session_id: Uuid,
pub nonce: String,
pub iat: i64,
pub exp: i64,
pub aud: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RefreshTokenClaims {
pub sub: Uuid,
pub session_id: Uuid,
pub jti: Uuid,
pub iat: i64,
pub exp: i64,
pub aud: String,
}
#[derive(Debug, Clone)]
pub struct JwtIdentity {
pub struct AccessTokenDetails {
pub user_id: Uuid,
pub session_id: Uuid,
pub nonce: String,
pub expires_at: DateTime<Utc>,
}
#[derive(Debug, Clone)]
pub struct RefreshTokenDetails {
pub user_id: Uuid,
pub session_id: Uuid,
pub refresh_token_id: Uuid,
}
#[derive(Clone)]
@@ -45,6 +69,13 @@ pub struct JwtService {
secret: Arc<SecretString>,
}
#[derive(Debug, Clone)]
pub struct TokenPair {
pub access_token: String,
pub refresh_token: String,
pub refresh_token_id: Uuid,
}
impl JwtService {
pub fn new(secret: SecretString) -> Self {
Self {
@@ -52,71 +83,114 @@ impl JwtService {
}
}
pub fn encode(
pub fn generate_tokens(
&self,
session: &AuthSession,
user: &User,
session_secret: &str,
) -> Result<String, JwtError> {
let claims = JwtClaims {
) -> Result<TokenPair, JwtError> {
let now = Utc::now();
let refresh_token_id = Uuid::new_v4();
// Access token, short-lived (~2 minutes)
let access_exp = now + ChronoDuration::seconds(ACCESS_TOKEN_TTL_SECONDS);
let access_claims = AccessTokenClaims {
sub: user.id,
session_id: session.id,
nonce: session_secret.to_string(),
iat: Utc::now().timestamp(),
iat: now.timestamp(),
exp: access_exp.timestamp(),
aud: "access".to_string(),
};
// Refresh token, long-lived (~1 year)
let refresh_exp = now + ChronoDuration::days(REFRESH_TOKEN_TTL_DAYS);
let refresh_claims = RefreshTokenClaims {
sub: user.id,
session_id: session.id,
jti: refresh_token_id,
iat: now.timestamp(),
exp: refresh_exp.timestamp(),
aud: "refresh".to_string(),
};
let encoding_key = EncodingKey::from_base64_secret(self.secret.expose_secret())?;
let token = encode(&Header::new(Algorithm::HS256), &claims, &encoding_key)?;
Ok(token)
let access_token = encode(
&Header::new(Algorithm::HS256),
&access_claims,
&encoding_key,
)?;
let refresh_token = encode(
&Header::new(Algorithm::HS256),
&refresh_claims,
&encoding_key,
)?;
Ok(TokenPair {
access_token,
refresh_token,
refresh_token_id,
})
}
pub fn decode(&self, token: &str) -> Result<JwtIdentity, JwtError> {
pub fn decode_access_token(&self, token: &str) -> Result<AccessTokenDetails, JwtError> {
self.decode_access_token_with_leeway(token, DEFAULT_JWT_LEEWAY_SECONDS)
}
pub fn decode_access_token_with_leeway(
&self,
token: &str,
leeway_seconds: u64,
) -> Result<AccessTokenDetails, JwtError> {
if token.trim().is_empty() {
return Err(JwtError::InvalidToken);
}
let mut validation = Validation::new(Algorithm::HS256);
validation.validate_exp = false;
validation.validate_exp = true;
validation.validate_nbf = false;
validation.required_spec_claims = HashSet::from(["sub".to_string()]);
validation.set_audience(&["access"]);
validation.required_spec_claims =
HashSet::from(["sub".to_string(), "exp".to_string(), "aud".to_string()]);
validation.leeway = leeway_seconds;
let decoding_key = DecodingKey::from_base64_secret(self.secret.expose_secret())?;
let data = decode::<JwtClaims>(token, &decoding_key, &validation)?;
let data = decode::<AccessTokenClaims>(token, &decoding_key, &validation)?;
let claims = data.claims;
Ok(JwtIdentity {
let expires_at = DateTime::from_timestamp(claims.exp, 0).ok_or(JwtError::InvalidToken)?;
Ok(AccessTokenDetails {
user_id: claims.sub,
session_id: claims.session_id,
nonce: claims.nonce,
expires_at,
})
}
fn secret_key_bytes(&self) -> Result<Vec<u8>, JwtError> {
let raw = self.secret.expose_secret();
BASE64_STANDARD
.decode(raw.as_bytes())
.map_err(|_| JwtError::InvalidSecret)
}
pub fn decode_refresh_token(&self, token: &str) -> Result<RefreshTokenDetails, JwtError> {
if token.trim().is_empty() {
return Err(JwtError::InvalidToken);
}
pub fn hash_session_secret(&self, session_secret: &str) -> Result<String, JwtError> {
let key = self.secret_key_bytes()?;
let mut mac = HmacSha256::new_from_slice(&key).map_err(|_| JwtError::InvalidSecret)?;
mac.update(session_secret.as_bytes());
let digest = mac.finalize().into_bytes();
Ok(BASE64_STANDARD.encode(digest))
}
let mut validation = Validation::new(Algorithm::HS256);
validation.validate_exp = true;
validation.validate_nbf = false;
validation.set_audience(&["refresh"]);
validation.required_spec_claims = HashSet::from([
"sub".to_string(),
"exp".to_string(),
"aud".to_string(),
"jti".to_string(),
]);
validation.leeway = DEFAULT_JWT_LEEWAY_SECONDS;
pub fn verify_session_secret(
&self,
stored_hash: Option<&str>,
candidate_secret: &str,
) -> Result<bool, JwtError> {
let stored = match stored_hash {
Some(value) => value,
None => return Ok(false),
};
let candidate_hash = self.hash_session_secret(candidate_secret)?;
Ok(stored.as_bytes().ct_eq(candidate_hash.as_bytes()).into())
let decoding_key = DecodingKey::from_base64_secret(self.secret.expose_secret())?;
let data = decode::<RefreshTokenClaims>(token, &decoding_key, &validation)?;
let claims = data.claims;
Ok(RefreshTokenDetails {
user_id: claims.sub,
session_id: claims.session_id,
refresh_token_id: claims.jti,
})
}
}

View File

@@ -6,7 +6,7 @@ use axum::{
response::{IntoResponse, Response},
};
use axum_extra::headers::{Authorization, HeaderMapExt, authorization::Bearer};
use chrono::Utc;
use chrono::{DateTime, Utc};
use tracing::warn;
use uuid::Uuid;
@@ -23,7 +23,7 @@ use crate::{
pub struct RequestContext {
pub user: User,
pub session_id: Uuid,
pub session_secret: String,
pub access_token_expires_at: DateTime<Utc>,
}
pub async fn require_session(
@@ -37,10 +37,10 @@ pub async fn require_session(
};
let jwt = state.jwt();
let identity = match jwt.decode(&bearer) {
Ok(identity) => identity,
let identity = match jwt.decode_access_token(&bearer) {
Ok(details) => details,
Err(error) => {
warn!(?error, "failed to decode session token");
warn!(?error, "failed to decode access token");
return StatusCode::UNAUTHORIZED.into_response();
}
};
@@ -57,17 +57,14 @@ pub async fn require_session(
warn!(?error, "failed to load session");
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
Err(_) => {
warn!("failed to load session for unknown reason");
return StatusCode::UNAUTHORIZED.into_response();
}
};
let secrets_match = jwt
.verify_session_secret(session.session_secret_hash.as_deref(), &identity.nonce)
.unwrap_or(false);
if session.revoked_at.is_some() || !secrets_match {
warn!(
"session `{}` rejected (revoked or rotated)",
identity.session_id
);
if session.revoked_at.is_some() {
warn!("session `{}` rejected (revoked)", identity.session_id);
return StatusCode::UNAUTHORIZED.into_response();
}
@@ -104,7 +101,7 @@ pub async fn require_session(
req.extensions_mut().insert(RequestContext {
user,
session_id: session.id,
session_secret: identity.nonce,
access_token_expires_at: identity.expires_at,
});
match session_repo.touch(session.id).await {

View File

@@ -4,6 +4,6 @@ mod middleware;
mod provider;
pub use handoff::{CallbackResult, HandoffError, OAuthHandoffService};
pub use jwt::{JwtError, JwtIdentity, JwtService};
pub use jwt::{JwtError, JwtService};
pub use middleware::{RequestContext, require_session};
pub use provider::{GitHubOAuthProvider, GoogleOAuthProvider, ProviderRegistry};

View File

@@ -8,6 +8,14 @@ use uuid::Uuid;
pub enum AuthSessionError {
#[error("auth session not found")]
NotFound,
#[error("refresh token reused - possible theft detected")]
TokenReuseDetected,
#[error("token has been revoked")]
TokenRevoked,
#[error("token has expired")]
TokenExpired,
#[error("invalid token")]
InvalidToken,
#[error(transparent)]
Database(#[from] sqlx::Error),
}
@@ -16,10 +24,11 @@ pub enum AuthSessionError {
pub struct AuthSession {
pub id: Uuid,
pub user_id: Uuid,
pub session_secret_hash: Option<String>,
pub created_at: DateTime<Utc>,
pub last_used_at: Option<DateTime<Utc>>,
pub revoked_at: Option<DateTime<Utc>>,
pub refresh_token_id: Option<Uuid>,
pub refresh_token_issued_at: Option<DateTime<Utc>>,
}
pub const MAX_SESSION_INACTIVITY_DURATION: Duration = Duration::days(365);
@@ -36,23 +45,24 @@ impl<'a> AuthSessionRepository<'a> {
pub async fn create(
&self,
user_id: Uuid,
session_secret_hash: Option<&str>,
refresh_token_id: Option<Uuid>,
) -> Result<AuthSession, AuthSessionError> {
query_as!(
AuthSession,
r#"
INSERT INTO auth_sessions (user_id, session_secret_hash)
INSERT INTO auth_sessions (user_id, refresh_token_id)
VALUES ($1, $2)
RETURNING
id AS "id!",
user_id AS "user_id!: Uuid",
session_secret_hash AS "session_secret_hash?",
created_at AS "created_at!",
last_used_at AS "last_used_at?",
revoked_at AS "revoked_at?"
id AS "id!",
user_id AS "user_id!: Uuid",
created_at AS "created_at!",
last_used_at AS "last_used_at?",
revoked_at AS "revoked_at?",
refresh_token_id AS "refresh_token_id?",
refresh_token_issued_at AS "refresh_token_issued_at?"
"#,
user_id,
session_secret_hash
refresh_token_id
)
.fetch_one(self.pool)
.await
@@ -64,12 +74,13 @@ impl<'a> AuthSessionRepository<'a> {
AuthSession,
r#"
SELECT
id AS "id!",
user_id AS "user_id!: Uuid",
session_secret_hash AS "session_secret_hash?",
created_at AS "created_at!",
last_used_at AS "last_used_at?",
revoked_at AS "revoked_at?"
id AS "id!",
user_id AS "user_id!: Uuid",
created_at AS "created_at!",
last_used_at AS "last_used_at?",
revoked_at AS "revoked_at?",
refresh_token_id AS "refresh_token_id?",
refresh_token_issued_at AS "refresh_token_issued_at?"
FROM auth_sessions
WHERE id = $1
"#,
@@ -98,6 +109,126 @@ impl<'a> AuthSessionRepository<'a> {
Ok(())
}
pub async fn rotate_tokens(
&self,
session_id: Uuid,
old_refresh_token_id: Uuid,
new_refresh_token_id: Uuid,
) -> Result<(), AuthSessionError> {
let mut tx = self.pool.begin().await.map_err(AuthSessionError::from)?;
let updated = sqlx::query!(
r#"
UPDATE auth_sessions
SET refresh_token_id = $3,
refresh_token_issued_at = NOW()
WHERE id = $1
AND refresh_token_id = $2
RETURNING user_id
"#,
session_id,
old_refresh_token_id,
new_refresh_token_id
)
.fetch_optional(&mut *tx)
.await
.map_err(AuthSessionError::from)?;
let Some(row) = updated else {
tx.rollback().await.map_err(AuthSessionError::from)?;
return Err(AuthSessionError::TokenReuseDetected);
};
// Revoke the old refresh token
sqlx::query!(
r#"
INSERT INTO revoked_refresh_tokens (token_id, user_id, revoked_reason)
VALUES ($1, $2, 'token_rotation')
ON CONFLICT (token_id) DO NOTHING
"#,
old_refresh_token_id,
row.user_id
)
.execute(&mut *tx)
.await
.map_err(AuthSessionError::from)?;
tx.commit().await.map_err(AuthSessionError::from)?;
Ok(())
}
pub async fn set_current_refresh_token(
&self,
session_id: Uuid,
refresh_token_id: Uuid,
) -> Result<(), AuthSessionError> {
sqlx::query!(
r#"
UPDATE auth_sessions
SET refresh_token_id = $2,
refresh_token_issued_at = NOW()
WHERE id = $1
"#,
session_id,
refresh_token_id
)
.execute(self.pool)
.await?;
Ok(())
}
pub async fn revoke_all_user_sessions(&self, user_id: Uuid) -> Result<i64, AuthSessionError> {
let mut tx = self.pool.begin().await.map_err(AuthSessionError::from)?;
sqlx::query!(
r#"
INSERT INTO revoked_refresh_tokens (token_id, user_id, revoked_reason)
SELECT refresh_token_id, user_id, 'reuse_of_revoked_token'
FROM auth_sessions
WHERE user_id = $1
AND refresh_token_id IS NOT NULL
ON CONFLICT (token_id) DO NOTHING
"#,
user_id
)
.execute(&mut *tx)
.await
.map_err(AuthSessionError::from)?;
let update_result = sqlx::query!(
r#"
UPDATE auth_sessions
SET revoked_at = NOW()
WHERE user_id = $1
AND revoked_at IS NULL
"#,
user_id
)
.execute(&mut *tx)
.await
.map_err(AuthSessionError::from)?;
tx.commit().await.map_err(AuthSessionError::from)?;
Ok(update_result.rows_affected() as i64)
}
pub async fn is_refresh_token_revoked(&self, token_id: Uuid) -> Result<bool, AuthSessionError> {
let result = sqlx::query!(
r#"
SELECT EXISTS(
SELECT 1 FROM revoked_refresh_tokens WHERE token_id = $1
) as is_revoked
"#,
token_id
)
.fetch_one(self.pool)
.await
.map_err(AuthSessionError::from)?;
Ok(result.is_revoked.unwrap_or(false))
}
pub async fn revoke(&self, session_id: Uuid) -> Result<(), AuthSessionError> {
sqlx::query!(
r#"
@@ -111,25 +242,6 @@ impl<'a> AuthSessionRepository<'a> {
.await?;
Ok(())
}
pub async fn update_secret(
&self,
session_id: Uuid,
session_secret_hash: &str,
) -> Result<(), AuthSessionError> {
sqlx::query!(
r#"
UPDATE auth_sessions
SET session_secret_hash = $2
WHERE id = $1
"#,
session_id,
session_secret_hash
)
.execute(self.pool)
.await?;
Ok(())
}
}
impl AuthSession {

View File

@@ -22,6 +22,7 @@ pub(crate) mod organization_members;
mod organizations;
mod projects;
pub mod tasks;
mod tokens;
pub fn router(state: AppState) -> Router {
let trace_layer = TraceLayer::new_for_http()
@@ -47,7 +48,8 @@ pub fn router(state: AppState) -> Router {
let v1_public = Router::<AppState>::new()
.route("/health", get(health))
.merge(oauth::public_router())
.merge(organization_members::public_router());
.merge(organization_members::public_router())
.merge(tokens::public_router());
let v1_protected = Router::<AppState>::new()
.merge(identity::router())

View File

@@ -76,6 +76,7 @@ pub async fn web_redeem(
StatusCode::OK,
Json(HandoffRedeemResponse {
access_token: result.access_token,
refresh_token: result.refresh_token,
}),
)
.into_response(),
@@ -217,6 +218,10 @@ pub async fn logout(
warn!(?error, session_id = %ctx.session_id, "failed to revoke auth session");
StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
Err(error) => {
warn!(?error, session_id = %ctx.session_id, "failed to revoke auth session");
StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
}
}

View File

@@ -0,0 +1,140 @@
use axum::{
Json, Router,
extract::State,
http::StatusCode,
response::{IntoResponse, Response},
routing::post,
};
use tracing::warn;
use utils::api::oauth::{TokenRefreshRequest, TokenRefreshResponse};
use crate::{
AppState,
auth::JwtError,
db::{
auth::{AuthSessionError, AuthSessionRepository},
identity_errors::IdentityError,
users::UserRepository,
},
};
pub fn public_router() -> Router<AppState> {
Router::new().route("/tokens/refresh", post(refresh_token))
}
#[derive(Debug, thiserror::Error)]
pub enum TokenRefreshError {
#[error("invalid refresh token")]
InvalidToken,
#[error("session has been revoked")]
SessionRevoked,
#[error("refresh token expired")]
TokenExpired,
#[error("refresh token reused - possible token theft")]
TokenReuseDetected,
#[error(transparent)]
Jwt(#[from] JwtError),
#[error(transparent)]
Database(#[from] sqlx::Error),
#[error(transparent)]
SessionError(#[from] AuthSessionError),
#[error(transparent)]
Identity(#[from] IdentityError),
}
pub async fn refresh_token(
State(state): State<AppState>,
Json(payload): Json<TokenRefreshRequest>,
) -> Result<Response, TokenRefreshError> {
let jwt_service = &state.jwt();
let session_repo = AuthSessionRepository::new(state.pool());
let token_details = match jwt_service.decode_refresh_token(&payload.refresh_token) {
Ok(details) => details,
Err(JwtError::TokenExpired) => return Err(TokenRefreshError::TokenExpired),
Err(_) => return Err(TokenRefreshError::InvalidToken),
};
let session = session_repo.get(token_details.session_id).await?;
if session.revoked_at.is_some() {
return Err(TokenRefreshError::SessionRevoked);
}
if session.refresh_token_id != Some(token_details.refresh_token_id)
|| session_repo
.is_refresh_token_revoked(token_details.refresh_token_id)
.await?
{
// Token was reused, revoke all user sessions as a security measure
let revoked_count = session_repo
.revoke_all_user_sessions(token_details.user_id)
.await?;
warn!(
user_id = %token_details.user_id,
session_id = %token_details.session_id,
revoked_sessions = revoked_count,
"Refresh token reuse detected. Revoked all user sessions."
);
return Err(TokenRefreshError::TokenReuseDetected);
}
let user_repo = UserRepository::new(state.pool());
let user = user_repo.fetch_user(token_details.user_id).await?;
let tokens = jwt_service.generate_tokens(&session, &user)?;
let old_token_id = token_details.refresh_token_id;
let new_token_id = tokens.refresh_token_id;
match session_repo
.rotate_tokens(session.id, old_token_id, new_token_id)
.await
{
Ok(_) => {}
Err(AuthSessionError::TokenReuseDetected) => {
let revoked_count = session_repo
.revoke_all_user_sessions(token_details.user_id)
.await?;
warn!(
user_id = %token_details.user_id,
session_id = %token_details.session_id,
revoked_sessions = revoked_count,
"Detected concurrent refresh attempt; revoked all user sessions"
);
return Err(TokenRefreshError::TokenReuseDetected);
}
Err(error) => return Err(TokenRefreshError::SessionError(error)),
}
Ok(Json(TokenRefreshResponse {
access_token: tokens.access_token,
refresh_token: tokens.refresh_token,
})
.into_response())
}
impl IntoResponse for TokenRefreshError {
fn into_response(self) -> Response {
let (status, error_code) = match self {
TokenRefreshError::InvalidToken => (StatusCode::UNAUTHORIZED, "invalid_token"),
TokenRefreshError::TokenExpired => (StatusCode::UNAUTHORIZED, "token_expired"),
TokenRefreshError::SessionRevoked => (StatusCode::UNAUTHORIZED, "session_revoked"),
TokenRefreshError::TokenReuseDetected => {
(StatusCode::UNAUTHORIZED, "token_reuse_detected")
}
TokenRefreshError::Jwt(_) => (StatusCode::UNAUTHORIZED, "invalid_token"),
TokenRefreshError::Identity(_) => (StatusCode::UNAUTHORIZED, "identity_error"),
TokenRefreshError::Database(_) | TokenRefreshError::SessionError(_) => {
(StatusCode::INTERNAL_SERVER_ERROR, "internal_error")
}
};
let body = serde_json::json!({
"error": error_code,
"message": self.to_string()
});
(status, Json(body)).into_response()
}
}

View File

@@ -1,13 +1,14 @@
use std::sync::Arc;
use axum::extract::ws::{Message, WebSocket};
use chrono::{DateTime, Duration as ChronoDuration, Utc};
use futures::{SinkExt, StreamExt};
use sqlx::PgPool;
use thiserror::Error;
use tokio::time::{self, MissedTickBehavior};
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
use tracing::{Span, instrument};
use utils::ws::{WS_AUTH_REFRESH_INTERVAL, WS_BULK_SYNC_THRESHOLD};
use utils::ws::{WS_AUTH_REFRESH_INTERVAL, WS_BULK_SYNC_THRESHOLD, WS_TOKEN_EXPIRY_GRACE};
use uuid::Uuid;
use super::{
@@ -17,7 +18,7 @@ use super::{
use crate::{
AppState,
activity::{ActivityBroker, ActivityEvent, ActivityStream},
auth::{JwtError, JwtIdentity, JwtService, RequestContext},
auth::{JwtError, JwtService, RequestContext},
db::{
activity::ActivityRepository,
auth::{AuthSessionError, AuthSessionRepository},
@@ -69,9 +70,9 @@ pub async fn handle(
state.jwt(),
pool.clone(),
ctx.session_id,
ctx.session_secret.clone(),
ctx.user.id,
project_id,
ctx.access_token_expires_at,
);
let mut auth_check_interval = time::interval(WS_AUTH_REFRESH_INTERVAL);
auth_check_interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
@@ -203,10 +204,9 @@ pub async fn handle(
Err(error) => {
tracing::info!(?error, "closing websocket due to auth verification error");
let message = match error {
AuthVerifyError::Revoked | AuthVerifyError::SecretMismatch => {
"authorization revoked"
}
AuthVerifyError::Revoked => "authorization revoked",
AuthVerifyError::MembershipRevoked => "project access revoked",
AuthVerifyError::Expired => "authorization expired",
AuthVerifyError::UserMismatch { .. }
| AuthVerifyError::Decode(_)
| AuthVerifyError::Session(_) => "authorization error",
@@ -269,10 +269,10 @@ struct WsAuthState {
jwt: Arc<JwtService>,
pool: PgPool,
session_id: Uuid,
session_secret: String,
expected_user_id: Uuid,
project_id: Uuid,
pending_token: Option<String>,
token_expires_at: DateTime<Utc>,
new_access_token: Option<String>,
}
impl WsAuthState {
@@ -280,48 +280,64 @@ impl WsAuthState {
jwt: Arc<JwtService>,
pool: PgPool,
session_id: Uuid,
session_secret: String,
expected_user_id: Uuid,
project_id: Uuid,
token_expires_at: DateTime<Utc>,
) -> Self {
Self {
jwt,
pool,
session_id,
session_secret,
expected_user_id,
project_id,
pending_token: None,
new_access_token: None,
token_expires_at,
}
}
fn store_token(&mut self, token: String) {
self.pending_token = Some(token);
self.new_access_token = Some(token);
}
async fn verify(&mut self) -> Result<(), AuthVerifyError> {
if let Some(token) = self.pending_token.take() {
let identity = self.jwt.decode(&token).map_err(AuthVerifyError::Decode)?;
self.apply_identity(identity).await?;
if let Some(token) = self.new_access_token.take() {
let token_details = self
.jwt
.decode_access_token_with_leeway(&token, WS_TOKEN_EXPIRY_GRACE.as_secs())
.map_err(AuthVerifyError::Decode)?;
self.apply_identity(token_details.user_id, token_details.session_id)
.await?;
self.token_expires_at = token_details.expires_at;
}
self.validate_token_expiry()?;
self.validate_session().await?;
self.validate_membership().await
}
async fn apply_identity(&mut self, identity: JwtIdentity) -> Result<(), AuthVerifyError> {
if identity.user_id != self.expected_user_id {
async fn apply_identity(
&mut self,
user_id: Uuid,
session_id: Uuid,
) -> Result<(), AuthVerifyError> {
if user_id != self.expected_user_id {
return Err(AuthVerifyError::UserMismatch {
expected: self.expected_user_id,
received: identity.user_id,
received: user_id,
});
}
self.session_id = identity.session_id;
self.session_secret = identity.nonce;
self.session_id = session_id;
self.validate_session().await
}
fn validate_token_expiry(&self) -> Result<(), AuthVerifyError> {
if self.token_expires_at + ws_leeway_duration() > Utc::now() {
return Ok(());
}
Err(AuthVerifyError::Expired)
}
async fn validate_session(&self) -> Result<(), AuthVerifyError> {
let repo = AuthSessionRepository::new(&self.pool);
let session = repo
@@ -333,14 +349,6 @@ impl WsAuthState {
return Err(AuthVerifyError::Revoked);
}
if !self
.jwt
.verify_session_secret(session.session_secret_hash.as_deref(), &self.session_secret)
.unwrap_or(false)
{
return Err(AuthVerifyError::SecretMismatch);
}
Ok(())
}
@@ -364,6 +372,10 @@ impl WsAuthState {
}
}
fn ws_leeway_duration() -> ChronoDuration {
ChronoDuration::from_std(WS_TOKEN_EXPIRY_GRACE).unwrap()
}
#[derive(Debug, Error)]
enum AuthVerifyError {
#[error(transparent)]
@@ -374,10 +386,10 @@ enum AuthVerifyError {
Session(#[from] AuthSessionError),
#[error("session revoked")]
Revoked,
#[error("session rotated")]
SecretMismatch,
#[error("organization membership revoked")]
MembershipRevoked,
#[error("access token expired")]
Expired,
}
#[allow(clippy::too_many_arguments)]

View File

@@ -152,6 +152,7 @@ impl IntoResponse for ApiError {
StatusCode::from_u16(*status).unwrap_or(StatusCode::BAD_GATEWAY),
"RemoteClientError",
),
RemoteClientError::Token(_) => (StatusCode::BAD_GATEWAY, "RemoteClientError"),
RemoteClientError::Api(code) => match code {
services::services::remote_client::HandoffErrorCode::NotFound => {
(StatusCode::NOT_FOUND, "RemoteClientError")
@@ -168,6 +169,9 @@ impl IntoResponse for ApiError {
}
_ => (StatusCode::BAD_REQUEST, "RemoteClientError"),
},
RemoteClientError::Storage(_) => {
(StatusCode::INTERNAL_SERVER_ERROR, "RemoteClientError")
}
RemoteClientError::Serde(_) | RemoteClientError::Url(_) => {
(StatusCode::BAD_REQUEST, "RemoteClientError")
}
@@ -210,6 +214,12 @@ impl IntoResponse for ApiError {
body.clone()
}
}
RemoteClientError::Token(_) => {
"Remote service returned an invalid access token. Please sign in again.".to_string()
}
RemoteClientError::Storage(_) => {
"Failed to persist credentials locally. Please retry.".to_string()
}
RemoteClientError::Api(code) => match code {
services::services::remote_client::HandoffErrorCode::NotFound => {
"The requested resource was not found.".to_string()

View File

@@ -13,6 +13,7 @@ use sha2::{Digest, Sha256};
use utils::{
api::oauth::{HandoffInitRequest, HandoffRedeemRequest, StatusResponse},
assets::config_path,
jwt::extract_expiration,
response::ApiResponse,
};
use uuid::Uuid;
@@ -119,8 +120,12 @@ async fn handoff_complete(
let redeem = client.handoff_redeem(&redeem_request).await?;
let expires_at = extract_expiration(&redeem.access_token)
.map_err(|err| ApiError::BadRequest(format!("Invalid access token: {err}")))?;
let credentials = Credentials {
access_token: redeem.access_token.clone(),
access_token: Some(redeem.access_token.clone()),
refresh_token: redeem.refresh_token.clone(),
expires_at: Some(expires_at),
};
deployment

View File

@@ -1,6 +1,6 @@
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio::sync::{Mutex as TokioMutex, OwnedMutexGuard, RwLock};
use utils::api::oauth::ProfileResponse;
use super::oauth_credentials::{Credentials, OAuthCredentials};
@@ -9,6 +9,7 @@ use super::oauth_credentials::{Credentials, OAuthCredentials};
pub struct AuthContext {
oauth: Arc<OAuthCredentials>,
profile: Arc<RwLock<Option<ProfileResponse>>>,
refresh_lock: Arc<TokioMutex<()>>,
}
impl AuthContext {
@@ -16,7 +17,11 @@ impl AuthContext {
oauth: Arc<OAuthCredentials>,
profile: Arc<RwLock<Option<ProfileResponse>>>,
) -> Self {
Self { oauth, profile }
Self {
oauth,
profile,
refresh_lock: Arc::new(TokioMutex::new(())),
}
}
pub async fn get_credentials(&self) -> Option<Credentials> {
@@ -42,4 +47,8 @@ impl AuthContext {
pub async fn clear_profile(&self) {
*self.profile.write().await = None
}
pub async fn refresh_guard(&self) -> OwnedMutexGuard<()> {
self.refresh_lock.clone().lock_owned().await
}
}

View File

@@ -1,13 +1,40 @@
use std::path::PathBuf;
use chrono::{DateTime, Duration as ChronoDuration, Utc};
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
/// OAuth credentials containing the JWT access token.
/// The access_token is a JWT from the remote OAuth service and should be treated as opaque.
/// OAuth credentials containing the JWT tokens issued by the remote OAuth service.
/// The `access_token` is short-lived; `refresh_token` allows minting a new pair.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Credentials {
pub access_token: String,
pub access_token: Option<String>,
pub refresh_token: String,
pub expires_at: Option<DateTime<Utc>>,
}
impl Credentials {
pub fn expires_soon(&self, leeway: ChronoDuration) -> bool {
match (self.access_token.as_ref(), self.expires_at.as_ref()) {
(Some(_), Some(exp)) => Utc::now() + leeway >= *exp,
_ => true,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct StoredCredentials {
refresh_token: String,
}
impl From<StoredCredentials> for Credentials {
fn from(value: StoredCredentials) -> Self {
Self {
access_token: None,
refresh_token: value.refresh_token,
expires_at: None,
}
}
}
/// Service for managing OAuth credentials (JWT tokens) in memory and persistent storage.
@@ -26,13 +53,16 @@ impl OAuthCredentials {
}
pub async fn load(&self) -> std::io::Result<()> {
let creds = self.backend.load().await?;
let creds = self.backend.load().await?.map(Credentials::from);
*self.inner.write().await = creds;
Ok(())
}
pub async fn save(&self, creds: &Credentials) -> std::io::Result<()> {
self.backend.save(creds).await?;
let stored = StoredCredentials {
refresh_token: creds.refresh_token.clone(),
};
self.backend.save(&stored).await?;
*self.inner.write().await = Some(creds.clone());
Ok(())
}
@@ -49,8 +79,8 @@ impl OAuthCredentials {
}
trait StoreBackend {
async fn load(&self) -> std::io::Result<Option<Credentials>>;
async fn save(&self, creds: &Credentials) -> std::io::Result<()>;
async fn load(&self) -> std::io::Result<Option<StoredCredentials>>;
async fn save(&self, creds: &StoredCredentials) -> std::io::Result<()>;
async fn clear(&self) -> std::io::Result<()>;
}
@@ -86,7 +116,7 @@ impl Backend {
}
impl StoreBackend for Backend {
async fn load(&self) -> std::io::Result<Option<Credentials>> {
async fn load(&self) -> std::io::Result<Option<StoredCredentials>> {
match self {
Backend::File(b) => b.load().await,
#[cfg(target_os = "macos")]
@@ -94,7 +124,7 @@ impl StoreBackend for Backend {
}
}
async fn save(&self, creds: &Credentials) -> std::io::Result<()> {
async fn save(&self, creds: &StoredCredentials) -> std::io::Result<()> {
match self {
Backend::File(b) => b.save(creds).await,
#[cfg(target_os = "macos")]
@@ -116,13 +146,13 @@ struct FileBackend {
}
impl FileBackend {
async fn load(&self) -> std::io::Result<Option<Credentials>> {
async fn load(&self) -> std::io::Result<Option<StoredCredentials>> {
if !self.path.exists() {
return Ok(None);
}
let bytes = std::fs::read(&self.path)?;
match serde_json::from_slice::<Credentials>(&bytes) {
match Self::parse_credentials(&bytes) {
Ok(creds) => Ok(Some(creds)),
Err(e) => {
tracing::warn!(?e, "failed to parse credentials file, renaming to .bad");
@@ -133,7 +163,11 @@ impl FileBackend {
}
}
async fn save(&self, creds: &Credentials) -> std::io::Result<()> {
fn parse_credentials(bytes: &[u8]) -> Result<StoredCredentials, serde_json::Error> {
serde_json::from_slice::<StoredCredentials>(bytes)
}
async fn save(&self, creds: &StoredCredentials) -> std::io::Result<()> {
let tmp = self.path.with_extension("tmp");
let file = {
@@ -149,7 +183,7 @@ impl FileBackend {
opts.open(&tmp)?
};
serde_json::to_writer_pretty(&file, &creds)?;
serde_json::to_writer_pretty(&file, creds)?;
file.sync_all()?;
drop(file);
@@ -172,14 +206,17 @@ impl KeychainBackend {
const ACCOUNT_NAME: &'static str = "default";
const ERR_SEC_ITEM_NOT_FOUND: i32 = -25300;
async fn load(&self) -> std::io::Result<Option<Credentials>> {
async fn load(&self) -> std::io::Result<Option<StoredCredentials>> {
use security_framework::passwords::get_generic_password;
match get_generic_password(Self::SERVICE_NAME, Self::ACCOUNT_NAME) {
Ok(bytes) => match serde_json::from_slice::<Credentials>(&bytes) {
Ok(bytes) => match serde_json::from_slice::<StoredCredentials>(&bytes) {
Ok(creds) => Ok(Some(creds)),
Err(e) => {
tracing::warn!(?e, "failed to parse keychain credentials; ignoring");
Err(error) => {
tracing::warn!(
?error,
"failed to parse keychain credentials; ignoring entry and requiring re-login"
);
Ok(None)
}
},
@@ -188,7 +225,7 @@ impl KeychainBackend {
}
}
async fn save(&self, creds: &Credentials) -> std::io::Result<()> {
async fn save(&self, creds: &StoredCredentials) -> std::io::Result<()> {
use security_framework::passwords::set_generic_password;
let bytes = serde_json::to_vec_pretty(creds).map_err(std::io::Error::other)?;

View File

@@ -3,6 +3,7 @@
use std::time::Duration;
use backon::{ExponentialBuilder, Retryable};
use chrono::Duration as ChronoDuration;
use remote::{
activity::ActivityResponse,
routes::tasks::{
@@ -16,23 +17,26 @@ use serde_json::Value;
use thiserror::Error;
use tracing::warn;
use url::Url;
use utils::api::{
oauth::{
HandoffInitRequest, HandoffInitResponse, HandoffRedeemRequest, HandoffRedeemResponse,
ProfileResponse,
use utils::{
api::{
oauth::{
HandoffInitRequest, HandoffInitResponse, HandoffRedeemRequest, HandoffRedeemResponse,
ProfileResponse, TokenRefreshRequest, TokenRefreshResponse,
},
organizations::{
AcceptInvitationResponse, CreateInvitationRequest, CreateInvitationResponse,
CreateOrganizationRequest, CreateOrganizationResponse, GetInvitationResponse,
GetOrganizationResponse, ListInvitationsResponse, ListMembersResponse,
ListOrganizationsResponse, Organization, RevokeInvitationRequest,
UpdateMemberRoleRequest, UpdateMemberRoleResponse, UpdateOrganizationRequest,
},
projects::{ListProjectsResponse, RemoteProject},
},
organizations::{
AcceptInvitationResponse, CreateInvitationRequest, CreateInvitationResponse,
CreateOrganizationRequest, CreateOrganizationResponse, GetInvitationResponse,
GetOrganizationResponse, ListInvitationsResponse, ListMembersResponse,
ListOrganizationsResponse, Organization, RevokeInvitationRequest, UpdateMemberRoleRequest,
UpdateMemberRoleResponse, UpdateOrganizationRequest,
},
projects::{ListProjectsResponse, RemoteProject},
jwt::extract_expiration,
};
use uuid::Uuid;
use super::auth::AuthContext;
use super::{auth::AuthContext, oauth_credentials::Credentials};
#[derive(Debug, Clone, Error)]
pub enum RemoteClientError {
@@ -50,6 +54,10 @@ pub enum RemoteClientError {
Serde(String),
#[error("url error: {0}")]
Url(String),
#[error("credentials storage error: {0}")]
Storage(String),
#[error("invalid access token: {0}")]
Token(String),
}
impl RemoteClientError {
@@ -124,6 +132,7 @@ impl Clone for RemoteClient {
impl RemoteClient {
const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
const TOKEN_REFRESH_LEEWAY_SECS: i64 = 20;
pub fn new(base_url: &str, auth_context: AuthContext) -> Result<Self, RemoteClientError> {
let base = Url::parse(base_url).map_err(|e| RemoteClientError::Url(e.to_string()))?;
@@ -139,14 +148,84 @@ impl RemoteClient {
})
}
/// Returns the token if available.
async fn require_token(&self) -> Result<String, RemoteClientError> {
let creds = self
.auth_context
.get_credentials()
/// Returns a valid access token, refreshing when it's about to expire.
fn require_token(
&self,
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<String, RemoteClientError>> + Send + '_>,
> {
Box::pin(async move {
let leeway = ChronoDuration::seconds(Self::TOKEN_REFRESH_LEEWAY_SECS);
let creds = self
.auth_context
.get_credentials()
.await
.ok_or(RemoteClientError::Auth)?;
if let Some(token) = creds.access_token.as_ref()
&& !creds.expires_soon(leeway)
{
return Ok(token.clone());
}
let refreshed = {
let _refresh_guard = self.auth_context.refresh_guard().await;
let latest = self
.auth_context
.get_credentials()
.await
.ok_or(RemoteClientError::Auth)?;
if let Some(token) = latest.access_token.as_ref()
&& !latest.expires_soon(leeway)
{
return Ok(token.clone());
}
self.refresh_credentials(&latest).await
};
match refreshed {
Ok(updated) => updated.access_token.ok_or(RemoteClientError::Auth),
Err(RemoteClientError::Auth) => {
let _ = self.auth_context.clear_credentials().await;
Err(RemoteClientError::Auth)
}
Err(err) => Err(err),
}
})
}
async fn refresh_credentials(
&self,
creds: &Credentials,
) -> Result<Credentials, RemoteClientError> {
let response = self.refresh_token_request(&creds.refresh_token).await?;
let access_token = response.access_token;
let refresh_token = response.refresh_token;
let expires_at = extract_expiration(&access_token)
.map_err(|err| RemoteClientError::Token(err.to_string()))?;
let new_creds = Credentials {
access_token: Some(access_token),
refresh_token,
expires_at: Some(expires_at),
};
self.auth_context
.save_credentials(&new_creds)
.await
.ok_or(RemoteClientError::Auth)?;
Ok(creds.access_token)
.map_err(|e| RemoteClientError::Storage(e.to_string()))?;
Ok(new_creds)
}
async fn refresh_token_request(
&self,
refresh_token: &str,
) -> Result<TokenRefreshResponse, RemoteClientError> {
let request = TokenRefreshRequest {
refresh_token: refresh_token.to_string(),
};
self.post_public("/v1/tokens/refresh", Some(&request))
.await
.map_err(|e| self.map_api_error(e))
}
/// Returns the base URL for the client.
@@ -154,6 +233,11 @@ impl RemoteClient {
self.base.as_str()
}
/// Returns a valid access token for use-cases like maintaining a websocket connection.
pub async fn access_token(&self) -> Result<String, RemoteClientError> {
self.require_token().await
}
/// Initiates an authorization-code handoff for the given provider.
pub async fn handoff_init(
&self,
@@ -187,7 +271,7 @@ impl RemoteClient {
&self,
method: reqwest::Method,
path: &str,
token: Option<&str>,
requires_auth: bool,
body: Option<&B>,
) -> Result<reqwest::Response, RemoteClientError>
where
@@ -201,8 +285,9 @@ impl RemoteClient {
(|| async {
let mut req = self.http.request(method.clone(), url.clone());
if let Some(t) = token {
req = req.bearer_auth(t);
if requires_auth {
let token = self.require_token().await?;
req = req.bearer_auth(token);
}
if let Some(b) = body {
@@ -245,7 +330,7 @@ impl RemoteClient {
T: for<'de> Deserialize<'de>,
{
let res = self
.send(reqwest::Method::GET, path, None, None::<&()>)
.send(reqwest::Method::GET, path, false, None::<&()>)
.await?;
res.json::<T>()
.await
@@ -257,7 +342,7 @@ impl RemoteClient {
T: for<'de> Deserialize<'de>,
B: Serialize,
{
let res = self.send(reqwest::Method::POST, path, None, body).await?;
let res = self.send(reqwest::Method::POST, path, false, body).await?;
res.json::<T>()
.await
.map_err(|e| RemoteClientError::Serde(e.to_string()))
@@ -268,9 +353,8 @@ impl RemoteClient {
where
T: for<'de> Deserialize<'de>,
{
let token = self.require_token().await?;
let res = self
.send(reqwest::Method::GET, path, Some(&token), None::<&()>)
.send(reqwest::Method::GET, path, true, None::<&()>)
.await?;
res.json::<T>()
.await
@@ -282,10 +366,7 @@ impl RemoteClient {
T: for<'de> Deserialize<'de>,
B: Serialize,
{
let token = self.require_token().await?;
let res = self
.send(reqwest::Method::POST, path, Some(&token), body)
.await?;
let res = self.send(reqwest::Method::POST, path, true, body).await?;
res.json::<T>()
.await
.map_err(|e| RemoteClientError::Serde(e.to_string()))
@@ -296,9 +377,8 @@ impl RemoteClient {
T: for<'de> Deserialize<'de>,
B: Serialize,
{
let token = self.require_token().await?;
let res = self
.send(reqwest::Method::PATCH, path, Some(&token), Some(body))
.send(reqwest::Method::PATCH, path, true, Some(body))
.await?;
res.json::<T>()
.await
@@ -306,8 +386,7 @@ impl RemoteClient {
}
async fn delete_authed(&self, path: &str) -> Result<(), RemoteClientError> {
let token = self.require_token().await?;
self.send(reqwest::Method::DELETE, path, Some(&token), None::<&()>)
self.send(reqwest::Method::DELETE, path, true, None::<&()>)
.await?;
Ok(())
}
@@ -503,12 +582,11 @@ impl RemoteClient {
task_id: Uuid,
request: &DeleteSharedTaskRequest,
) -> Result<SharedTaskResponse, RemoteClientError> {
let token = self.require_token().await?;
let res = self
.send(
reqwest::Method::DELETE,
&format!("/v1/tasks/{task_id}"),
Some(&token),
true,
Some(request),
)
.await?;

View File

@@ -23,7 +23,7 @@ use db::{
use processor::ActivityProcessor;
pub use publisher::SharePublisher;
use remote::{
ServerMessage,
ClientMessage, ServerMessage,
db::{tasks::SharedTask as RemoteSharedTask, users::UserData as RemoteUserData},
};
use sqlx::{Executor, Sqlite, SqlitePool};
@@ -35,7 +35,9 @@ use tokio::{
};
use tokio_tungstenite::tungstenite::Message as WsMessage;
use url::Url;
use utils::ws::{WsClient, WsConfig, WsError, WsHandler, WsResult, run_ws_client};
use utils::ws::{
WS_AUTH_REFRESH_INTERVAL, WsClient, WsConfig, WsError, WsHandler, WsResult, run_ws_client,
};
use uuid::Uuid;
use crate::{
@@ -254,13 +256,21 @@ impl RemoteSync {
let processor = self.processor.clone();
let config = self.config.clone();
let auth_ctx = self.auth_ctx.clone();
let remote_client = processor.remote_client();
let db = self.db.clone();
let (shutdown_tx, shutdown_rx) = oneshot::channel();
let join = tokio::spawn(async move {
let result =
project_watcher_task(db, processor, config, auth_ctx, project_id, shutdown_rx)
.await;
let result = project_watcher_task(
db,
processor,
config,
auth_ctx,
remote_client,
project_id,
shutdown_rx,
)
.await;
let _ = events_tx.send(ProjectWatcherEvent { project_id, result });
});
@@ -327,22 +337,27 @@ impl WsHandler for SharedWsHandler {
async fn spawn_shared_remote(
processor: ActivityProcessor,
auth_ctx: &AuthContext,
remote_client: RemoteClient,
url: Url,
close_tx: oneshot::Sender<()>,
remote_project_id: Uuid,
) -> Result<WsClient, ShareError> {
let auth_source = auth_ctx.clone();
let remote_client_clone = remote_client.clone();
let ws_config = WsConfig {
url,
ping_interval: Some(std::time::Duration::from_secs(30)),
header_factory: Some(Arc::new(move || {
let auth_source = auth_source.clone();
let remote_client_clone = remote_client_clone.clone();
Box::pin(async move {
if let Some(creds) = auth_source.get_credentials().await {
build_ws_headers(&creds.access_token)
} else {
Err(WsError::MissingAuth)
match remote_client_clone.access_token().await {
Ok(token) => build_ws_headers(&token),
Err(error) => {
tracing::warn!(
?error,
"failed to obtain access token for websocket connection"
);
Err(WsError::MissingAuth)
}
}
})
})),
@@ -356,6 +371,7 @@ async fn spawn_shared_remote(
let client = run_ws_client(handler, ws_config)
.await
.map_err(ShareError::from)?;
spawn_ws_auth_refresh_task(client.clone(), remote_client);
Ok(client)
}
@@ -365,6 +381,7 @@ async fn project_watcher_task(
processor: ActivityProcessor,
config: ShareConfig,
auth_ctx: AuthContext,
remote_client: RemoteClient,
remote_project_id: Uuid,
mut shutdown_rx: oneshot::Receiver<()>,
) -> Result<(), ShareError> {
@@ -410,7 +427,7 @@ async fn project_watcher_task(
let (close_tx, close_rx) = oneshot::channel();
let ws_connection = match spawn_shared_remote(
processor.clone(),
&auth_ctx,
remote_client.clone(),
ws_url,
close_tx,
remote_project_id,
@@ -479,6 +496,44 @@ fn build_ws_headers(access_token: &str) -> WsResult<Vec<(HeaderName, HeaderValue
Ok(headers)
}
fn spawn_ws_auth_refresh_task(client: WsClient, remote_client: RemoteClient) {
tokio::spawn(async move {
let mut close_rx = client.subscribe_close();
loop {
match remote_client.access_token().await {
Ok(token) => {
if let Err(err) = send_ws_auth_token(&client, token).await {
tracing::warn!(
?err,
"failed to send websocket auth token; stopping auth refresh"
);
break;
}
}
Err(err) => {
tracing::warn!(
?err,
"failed to obtain access token for websocket auth refresh; stopping auth refresh"
);
break;
}
}
tokio::select! {
_ = close_rx.changed() => break,
_ = sleep(WS_AUTH_REFRESH_INTERVAL) => {}
}
}
});
}
async fn send_ws_auth_token(client: &WsClient, token: String) -> Result<(), ShareError> {
let payload = serde_json::to_string(&ClientMessage::AuthToken { token })?;
client
.send(WsMessage::Text(payload.into()))
.map_err(ShareError::from)
}
#[derive(Clone)]
pub struct RemoteSyncHandle {
inner: Arc<RemoteSyncHandleInner>,

View File

@@ -48,6 +48,10 @@ impl ActivityProcessor {
}
}
pub fn remote_client(&self) -> RemoteClient {
self.remote_client.clone()
}
pub async fn process_event(&self, event: ActivityEvent) -> Result<(), ShareError> {
let mut tx = self.db.pool.begin().await?;
match event.event_type.as_str() {

View File

@@ -23,7 +23,7 @@ sentry = { version = "0.41.0", features = ["anyhow", "backtrace", "panic", "debu
sentry-tracing = { version = "0.41.0", features = ["backtrace"] }
futures-util = "0.3"
json-patch = "2.0"
jsonwebtoken = { version = "10.0.0", features = ["rust_crypto"] }
jsonwebtoken = { version = "10.2.0", features = ["rust_crypto"] }
tokio = { workspace = true }
futures = "0.3.31"
tokio-stream = { version = "0.1.17", features = ["sync"] }

View File

@@ -29,6 +29,20 @@ pub struct HandoffRedeemRequest {
#[ts(export)]
pub struct HandoffRedeemResponse {
pub access_token: String,
pub refresh_token: String,
}
#[derive(Debug, Serialize, Deserialize, Clone, TS)]
#[ts(export)]
pub struct TokenRefreshRequest {
pub refresh_token: String,
}
#[derive(Debug, Serialize, Deserialize, Clone, TS)]
#[ts(export)]
pub struct TokenRefreshResponse {
pub access_token: String,
pub refresh_token: String,
}
#[derive(Debug, Serialize, Deserialize, Clone, TS)]

26
crates/utils/src/jwt.rs Normal file
View File

@@ -0,0 +1,26 @@
use chrono::{DateTime, Utc};
use jsonwebtoken::dangerous::insecure_decode;
use serde::Deserialize;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum TokenClaimsError {
#[error("failed to decode JWT: {0}")]
Decode(#[from] jsonwebtoken::errors::Error),
#[error("missing `exp` claim in token")]
MissingExpiration,
#[error("invalid `exp` value `{0}`")]
InvalidExpiration(i64),
}
#[derive(Debug, Deserialize)]
struct ExpClaim {
exp: Option<i64>,
}
/// Extract the expiration timestamp from a JWT without verifying its signature.
pub fn extract_expiration(token: &str) -> Result<DateTime<Utc>, TokenClaimsError> {
let data = insecure_decode::<ExpClaim>(token)?;
let exp = data.claims.exp.ok_or(TokenClaimsError::MissingExpiration)?;
DateTime::from_timestamp(exp, 0).ok_or(TokenClaimsError::InvalidExpiration(exp))
}

View File

@@ -8,6 +8,7 @@ pub mod assets;
pub mod browser;
pub mod diff;
pub mod git;
pub mod jwt;
pub mod log_msg;
pub mod msg_store;
pub mod path;

View File

@@ -17,6 +17,7 @@ export type HandoffInitResponse = {
export type HandoffRedeemResponse = {
access_token: string;
refresh_token: string;
};
export type AcceptInvitationResponse = {