JWT: separate access tokens and refresh tokens (#1315)
This commit is contained in:
12
Cargo.lock
generated
12
Cargo.lock
generated
@@ -2474,9 +2474,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "jsonwebtoken"
|
||||
version = "10.1.0"
|
||||
version = "10.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3d119c6924272d16f0ab9ce41f7aa0bfef9340c00b0bb7ca3dd3b263d4a9150b"
|
||||
checksum = "c76e1c7d7df3e34443b3621b459b066a7b79644f059fc8b2db7070c825fd417e"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"ed25519-dalek",
|
||||
@@ -3480,7 +3480,7 @@ dependencies = [
|
||||
"quinn-udp",
|
||||
"rustc-hash 2.1.1",
|
||||
"rustls",
|
||||
"socket2 0.5.10",
|
||||
"socket2 0.6.1",
|
||||
"thiserror 2.0.17",
|
||||
"tokio",
|
||||
"tracing",
|
||||
@@ -3517,9 +3517,9 @@ dependencies = [
|
||||
"cfg_aliases",
|
||||
"libc",
|
||||
"once_cell",
|
||||
"socket2 0.5.10",
|
||||
"socket2 0.6.1",
|
||||
"tracing",
|
||||
"windows-sys 0.52.0",
|
||||
"windows-sys 0.60.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5600,7 +5600,7 @@ dependencies = [
|
||||
"futures-util",
|
||||
"git2",
|
||||
"json-patch",
|
||||
"jsonwebtoken 10.1.0",
|
||||
"jsonwebtoken 10.2.0",
|
||||
"open",
|
||||
"regex",
|
||||
"reqwest",
|
||||
|
||||
@@ -294,15 +294,6 @@ impl LocalDeployment {
|
||||
self.remote_client.clone()
|
||||
}
|
||||
|
||||
/// Convenience method to get the current JWT auth token.
|
||||
/// Returns None if the user is not authenticated.
|
||||
pub async fn auth_token(&self) -> Option<String> {
|
||||
self.auth_context
|
||||
.get_credentials()
|
||||
.await
|
||||
.map(|c| c.access_token)
|
||||
}
|
||||
|
||||
pub async fn get_login_status(&self) -> LoginStatus {
|
||||
if self.auth_context.get_credentials().await.is_none() {
|
||||
self.auth_context.clear_profile().await;
|
||||
|
||||
15
crates/remote/.sqlx/query-082aaf51a023c8ccb44002ce48287acd8ef90b0f4c8338447c6e5370ca93390b.json
generated
Normal file
15
crates/remote/.sqlx/query-082aaf51a023c8ccb44002ce48287acd8ef90b0f4c8338447c6e5370ca93390b.json
generated
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n INSERT INTO revoked_refresh_tokens (token_id, user_id, revoked_reason)\n VALUES ($1, $2, 'token_rotation')\n ON CONFLICT (token_id) DO NOTHING\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid",
|
||||
"Uuid"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "082aaf51a023c8ccb44002ce48287acd8ef90b0f4c8338447c6e5370ca93390b"
|
||||
}
|
||||
24
crates/remote/.sqlx/query-2f3898ec50ee1386f87786c605069aac78d5177feaabd719b60e54f94f5f535e.json
generated
Normal file
24
crates/remote/.sqlx/query-2f3898ec50ee1386f87786c605069aac78d5177feaabd719b60e54f94f5f535e.json
generated
Normal file
@@ -0,0 +1,24 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n UPDATE auth_sessions\n SET refresh_token_id = $3,\n refresh_token_issued_at = NOW()\n WHERE id = $1\n AND refresh_token_id = $2\n RETURNING user_id\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "user_id",
|
||||
"type_info": "Uuid"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid",
|
||||
"Uuid",
|
||||
"Uuid"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "2f3898ec50ee1386f87786c605069aac78d5177feaabd719b60e54f94f5f535e"
|
||||
}
|
||||
22
crates/remote/.sqlx/query-389b412ed9b76973a5b1546a24167e0b752467405f024de73101b6c12e1e05f1.json
generated
Normal file
22
crates/remote/.sqlx/query-389b412ed9b76973a5b1546a24167e0b752467405f024de73101b6c12e1e05f1.json
generated
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT EXISTS(\n SELECT 1 FROM revoked_refresh_tokens WHERE token_id = $1\n ) as is_revoked\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "is_revoked",
|
||||
"type_info": "Bool"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
null
|
||||
]
|
||||
},
|
||||
"hash": "389b412ed9b76973a5b1546a24167e0b752467405f024de73101b6c12e1e05f1"
|
||||
}
|
||||
59
crates/remote/.sqlx/query-4d963a12190ee1db657446ef451c5364f8f91153f7f1bb4e5abfd3f3ddbe0461.json
generated
Normal file
59
crates/remote/.sqlx/query-4d963a12190ee1db657446ef451c5364f8f91153f7f1bb4e5abfd3f3ddbe0461.json
generated
Normal file
@@ -0,0 +1,59 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n INSERT INTO auth_sessions (user_id, refresh_token_id)\n VALUES ($1, $2)\n RETURNING\n id AS \"id!\",\n user_id AS \"user_id!: Uuid\",\n created_at AS \"created_at!\",\n last_used_at AS \"last_used_at?\",\n revoked_at AS \"revoked_at?\",\n refresh_token_id AS \"refresh_token_id?\",\n refresh_token_issued_at AS \"refresh_token_issued_at?\"\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id!",
|
||||
"type_info": "Uuid"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "user_id!: Uuid",
|
||||
"type_info": "Uuid"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "created_at!",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "last_used_at?",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "revoked_at?",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 5,
|
||||
"name": "refresh_token_id?",
|
||||
"type_info": "Uuid"
|
||||
},
|
||||
{
|
||||
"ordinal": 6,
|
||||
"name": "refresh_token_issued_at?",
|
||||
"type_info": "Timestamptz"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid",
|
||||
"Uuid"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true
|
||||
]
|
||||
},
|
||||
"hash": "4d963a12190ee1db657446ef451c5364f8f91153f7f1bb4e5abfd3f3ddbe0461"
|
||||
}
|
||||
14
crates/remote/.sqlx/query-68422b179dc361337c65a6bd1aa455a961708b97a673d84f7af64cd252cbfdf3.json
generated
Normal file
14
crates/remote/.sqlx/query-68422b179dc361337c65a6bd1aa455a961708b97a673d84f7af64cd252cbfdf3.json
generated
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n UPDATE auth_sessions\n SET revoked_at = NOW()\n WHERE user_id = $1\n AND revoked_at IS NULL\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "68422b179dc361337c65a6bd1aa455a961708b97a673d84f7af64cd252cbfdf3"
|
||||
}
|
||||
14
crates/remote/.sqlx/query-8e32d5bf86d112e2f4a16f622bd95c8f728946f01e1a994a9c66b0fac6e3ae52.json
generated
Normal file
14
crates/remote/.sqlx/query-8e32d5bf86d112e2f4a16f622bd95c8f728946f01e1a994a9c66b0fac6e3ae52.json
generated
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n INSERT INTO revoked_refresh_tokens (token_id, user_id, revoked_reason)\n SELECT refresh_token_id, user_id, 'reuse_of_revoked_token'\n FROM auth_sessions\n WHERE user_id = $1\n AND refresh_token_id IS NOT NULL\n ON CONFLICT (token_id) DO NOTHING\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "8e32d5bf86d112e2f4a16f622bd95c8f728946f01e1a994a9c66b0fac6e3ae52"
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT\n id AS \"id!\",\n user_id AS \"user_id!: Uuid\",\n session_secret_hash AS \"session_secret_hash?\",\n created_at AS \"created_at!\",\n last_used_at AS \"last_used_at?\",\n revoked_at AS \"revoked_at?\"\n FROM auth_sessions\n WHERE id = $1\n ",
|
||||
"query": "\n SELECT\n id AS \"id!\",\n user_id AS \"user_id!: Uuid\",\n created_at AS \"created_at!\",\n last_used_at AS \"last_used_at?\",\n revoked_at AS \"revoked_at?\",\n refresh_token_id AS \"refresh_token_id?\",\n refresh_token_issued_at AS \"refresh_token_issued_at?\"\n FROM auth_sessions\n WHERE id = $1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
@@ -15,23 +15,28 @@
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "session_secret_hash?",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "created_at!",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"ordinal": 3,
|
||||
"name": "last_used_at?",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 5,
|
||||
"ordinal": 4,
|
||||
"name": "revoked_at?",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 5,
|
||||
"name": "refresh_token_id?",
|
||||
"type_info": "Uuid"
|
||||
},
|
||||
{
|
||||
"ordinal": 6,
|
||||
"name": "refresh_token_issued_at?",
|
||||
"type_info": "Timestamptz"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
@@ -42,11 +47,12 @@
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true
|
||||
]
|
||||
},
|
||||
"hash": "d12fbd108d36c817c94997744b50cafd08407c0e207e2cacd43c50d28e886b19"
|
||||
"hash": "9459cf92b30943acb79f0e0f2e9421be83ce9e50e39f6b1e435b92ff70907264"
|
||||
}
|
||||
@@ -1,15 +1,15 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n UPDATE auth_sessions\n SET session_secret_hash = $2\n WHERE id = $1\n ",
|
||||
"query": "\n UPDATE auth_sessions\n SET refresh_token_id = $2,\n refresh_token_issued_at = NOW()\n WHERE id = $1\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid",
|
||||
"Text"
|
||||
"Uuid"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "92d13927cde8ac62cb0cfd3c3410aa4d42717d6a3a219926ddc34ca1d2520306"
|
||||
"hash": "a1431ca78db627fef0eca6f573b34d65510e9333765126cbd80c943046dfaea8"
|
||||
}
|
||||
@@ -1,53 +0,0 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n INSERT INTO auth_sessions (user_id, session_secret_hash)\n VALUES ($1, $2)\n RETURNING\n id AS \"id!\",\n user_id AS \"user_id!: Uuid\",\n session_secret_hash AS \"session_secret_hash?\",\n created_at AS \"created_at!\",\n last_used_at AS \"last_used_at?\",\n revoked_at AS \"revoked_at?\"\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id!",
|
||||
"type_info": "Uuid"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "user_id!: Uuid",
|
||||
"type_info": "Uuid"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "session_secret_hash?",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "created_at!",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "last_used_at?",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 5,
|
||||
"name": "revoked_at?",
|
||||
"type_info": "Timestamptz"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid",
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
true
|
||||
]
|
||||
},
|
||||
"hash": "f40c7ea0e0692e2ee7eead2027260104616026d32f312f8633236cc9438cd958"
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
ALTER TABLE auth_sessions ADD COLUMN IF NOT EXISTS refresh_token_id UUID;
|
||||
ALTER TABLE auth_sessions ADD COLUMN IF NOT EXISTS refresh_token_issued_at TIMESTAMPTZ;
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_auth_sessions_refresh_id
|
||||
ON auth_sessions (refresh_token_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS revoked_refresh_tokens (
|
||||
token_id UUID PRIMARY KEY,
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
revoked_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
revoked_reason TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_revoked_tokens_user
|
||||
ON revoked_refresh_tokens (user_id);
|
||||
@@ -30,7 +30,6 @@ use crate::{
|
||||
},
|
||||
};
|
||||
|
||||
const SESSION_SECRET_LENGTH: usize = 48;
|
||||
const STATE_LENGTH: usize = 48;
|
||||
const APP_CODE_LENGTH: usize = 48;
|
||||
const HANDOFF_TTL: i64 = 10; // minutes
|
||||
@@ -93,6 +92,7 @@ pub enum CallbackResult {
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RedeemResponse {
|
||||
pub access_token: String,
|
||||
pub refresh_token: String,
|
||||
}
|
||||
|
||||
pub struct OAuthHandoffService {
|
||||
@@ -321,7 +321,7 @@ impl OAuthHandoffService {
|
||||
.ok_or_else(|| HandoffError::Failed("missing_user".into()))?;
|
||||
|
||||
let session_repo = AuthSessionRepository::new(&self.pool);
|
||||
let mut session = session_repo.get(session_id).await?;
|
||||
let session = session_repo.get(session_id).await?;
|
||||
if session.revoked_at.is_some() {
|
||||
return Err(HandoffError::Denied);
|
||||
}
|
||||
@@ -331,13 +331,6 @@ impl OAuthHandoffService {
|
||||
return Err(HandoffError::Denied);
|
||||
}
|
||||
|
||||
let session_secret = generate_session_secret();
|
||||
let session_secret_hash = self.jwt.hash_session_secret(&session_secret)?;
|
||||
session_repo
|
||||
.update_secret(session.id, &session_secret_hash)
|
||||
.await?;
|
||||
session.session_secret_hash = Some(session_secret_hash.clone());
|
||||
|
||||
let user_repo = UserRepository::new(&self.pool);
|
||||
let user = user_repo.fetch_user(user_id).await?;
|
||||
let org_repo = OrganizationRepository::new(&self.pool);
|
||||
@@ -345,14 +338,20 @@ impl OAuthHandoffService {
|
||||
.ensure_personal_org_and_admin_membership(user.id, user.username.as_deref())
|
||||
.await?;
|
||||
|
||||
let token = self.jwt.encode(&session, &user, &session_secret)?;
|
||||
let tokens = self.jwt.generate_tokens(&session, &user)?;
|
||||
|
||||
session_repo
|
||||
.set_current_refresh_token(session.id, tokens.refresh_token_id)
|
||||
.await?;
|
||||
|
||||
session_repo.touch(session.id).await?;
|
||||
repo.mark_redeemed(record.id).await?;
|
||||
|
||||
configure_user_scope(user.id, user.username.as_deref(), Some(user.email.as_str()));
|
||||
|
||||
Ok(RedeemResponse {
|
||||
access_token: token,
|
||||
access_token: tokens.access_token,
|
||||
refresh_token: tokens.refresh_token,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -498,14 +497,6 @@ fn generate_app_code() -> String {
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn generate_session_secret() -> String {
|
||||
rand::rng()
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(SESSION_SECRET_LENGTH)
|
||||
.map(char::from)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn ensure_email(provider: &str, profile: &ProviderUser) -> String {
|
||||
if let Some(email) = profile.email.clone() {
|
||||
return email;
|
||||
|
||||
@@ -1,19 +1,17 @@
|
||||
use std::{collections::HashSet, sync::Arc};
|
||||
|
||||
use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD};
|
||||
use chrono::Utc;
|
||||
use hmac::{Hmac, Mac};
|
||||
use chrono::{DateTime, Duration as ChronoDuration, Utc};
|
||||
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
|
||||
use secrecy::{ExposeSecret, SecretString};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::Sha256;
|
||||
use subtle::ConstantTimeEq;
|
||||
use thiserror::Error;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::db::{auth::AuthSession, users::User};
|
||||
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
pub const ACCESS_TOKEN_TTL_SECONDS: i64 = 120;
|
||||
pub const REFRESH_TOKEN_TTL_DAYS: i64 = 365;
|
||||
const DEFAULT_JWT_LEEWAY_SECONDS: u64 = 60;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum JwtError {
|
||||
@@ -21,23 +19,49 @@ pub enum JwtError {
|
||||
InvalidToken,
|
||||
#[error("invalid jwt secret")]
|
||||
InvalidSecret,
|
||||
#[error("token expired")]
|
||||
TokenExpired,
|
||||
#[error("refresh token reused - possible theft detected")]
|
||||
TokenReuseDetected,
|
||||
#[error("session revoked")]
|
||||
SessionRevoked,
|
||||
#[error("token type mismatch")]
|
||||
InvalidTokenType,
|
||||
#[error(transparent)]
|
||||
Jwt(#[from] jsonwebtoken::errors::Error),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JwtClaims {
|
||||
pub struct AccessTokenClaims {
|
||||
pub sub: Uuid,
|
||||
pub session_id: Uuid,
|
||||
pub nonce: String,
|
||||
pub iat: i64,
|
||||
pub exp: i64,
|
||||
pub aud: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RefreshTokenClaims {
|
||||
pub sub: Uuid,
|
||||
pub session_id: Uuid,
|
||||
pub jti: Uuid,
|
||||
pub iat: i64,
|
||||
pub exp: i64,
|
||||
pub aud: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct JwtIdentity {
|
||||
pub struct AccessTokenDetails {
|
||||
pub user_id: Uuid,
|
||||
pub session_id: Uuid,
|
||||
pub nonce: String,
|
||||
pub expires_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RefreshTokenDetails {
|
||||
pub user_id: Uuid,
|
||||
pub session_id: Uuid,
|
||||
pub refresh_token_id: Uuid,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -45,6 +69,13 @@ pub struct JwtService {
|
||||
secret: Arc<SecretString>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TokenPair {
|
||||
pub access_token: String,
|
||||
pub refresh_token: String,
|
||||
pub refresh_token_id: Uuid,
|
||||
}
|
||||
|
||||
impl JwtService {
|
||||
pub fn new(secret: SecretString) -> Self {
|
||||
Self {
|
||||
@@ -52,71 +83,114 @@ impl JwtService {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn encode(
|
||||
pub fn generate_tokens(
|
||||
&self,
|
||||
session: &AuthSession,
|
||||
user: &User,
|
||||
session_secret: &str,
|
||||
) -> Result<String, JwtError> {
|
||||
let claims = JwtClaims {
|
||||
) -> Result<TokenPair, JwtError> {
|
||||
let now = Utc::now();
|
||||
let refresh_token_id = Uuid::new_v4();
|
||||
|
||||
// Access token, short-lived (~2 minutes)
|
||||
let access_exp = now + ChronoDuration::seconds(ACCESS_TOKEN_TTL_SECONDS);
|
||||
let access_claims = AccessTokenClaims {
|
||||
sub: user.id,
|
||||
session_id: session.id,
|
||||
nonce: session_secret.to_string(),
|
||||
iat: Utc::now().timestamp(),
|
||||
iat: now.timestamp(),
|
||||
exp: access_exp.timestamp(),
|
||||
aud: "access".to_string(),
|
||||
};
|
||||
|
||||
// Refresh token, long-lived (~1 year)
|
||||
let refresh_exp = now + ChronoDuration::days(REFRESH_TOKEN_TTL_DAYS);
|
||||
let refresh_claims = RefreshTokenClaims {
|
||||
sub: user.id,
|
||||
session_id: session.id,
|
||||
jti: refresh_token_id,
|
||||
iat: now.timestamp(),
|
||||
exp: refresh_exp.timestamp(),
|
||||
aud: "refresh".to_string(),
|
||||
};
|
||||
|
||||
let encoding_key = EncodingKey::from_base64_secret(self.secret.expose_secret())?;
|
||||
let token = encode(&Header::new(Algorithm::HS256), &claims, &encoding_key)?;
|
||||
|
||||
Ok(token)
|
||||
let access_token = encode(
|
||||
&Header::new(Algorithm::HS256),
|
||||
&access_claims,
|
||||
&encoding_key,
|
||||
)?;
|
||||
|
||||
let refresh_token = encode(
|
||||
&Header::new(Algorithm::HS256),
|
||||
&refresh_claims,
|
||||
&encoding_key,
|
||||
)?;
|
||||
|
||||
Ok(TokenPair {
|
||||
access_token,
|
||||
refresh_token,
|
||||
refresh_token_id,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn decode(&self, token: &str) -> Result<JwtIdentity, JwtError> {
|
||||
pub fn decode_access_token(&self, token: &str) -> Result<AccessTokenDetails, JwtError> {
|
||||
self.decode_access_token_with_leeway(token, DEFAULT_JWT_LEEWAY_SECONDS)
|
||||
}
|
||||
|
||||
pub fn decode_access_token_with_leeway(
|
||||
&self,
|
||||
token: &str,
|
||||
leeway_seconds: u64,
|
||||
) -> Result<AccessTokenDetails, JwtError> {
|
||||
if token.trim().is_empty() {
|
||||
return Err(JwtError::InvalidToken);
|
||||
}
|
||||
|
||||
let mut validation = Validation::new(Algorithm::HS256);
|
||||
validation.validate_exp = false;
|
||||
validation.validate_exp = true;
|
||||
validation.validate_nbf = false;
|
||||
validation.required_spec_claims = HashSet::from(["sub".to_string()]);
|
||||
validation.set_audience(&["access"]);
|
||||
validation.required_spec_claims =
|
||||
HashSet::from(["sub".to_string(), "exp".to_string(), "aud".to_string()]);
|
||||
validation.leeway = leeway_seconds;
|
||||
|
||||
let decoding_key = DecodingKey::from_base64_secret(self.secret.expose_secret())?;
|
||||
let data = decode::<JwtClaims>(token, &decoding_key, &validation)?;
|
||||
|
||||
let data = decode::<AccessTokenClaims>(token, &decoding_key, &validation)?;
|
||||
let claims = data.claims;
|
||||
Ok(JwtIdentity {
|
||||
let expires_at = DateTime::from_timestamp(claims.exp, 0).ok_or(JwtError::InvalidToken)?;
|
||||
|
||||
Ok(AccessTokenDetails {
|
||||
user_id: claims.sub,
|
||||
session_id: claims.session_id,
|
||||
nonce: claims.nonce,
|
||||
expires_at,
|
||||
})
|
||||
}
|
||||
|
||||
fn secret_key_bytes(&self) -> Result<Vec<u8>, JwtError> {
|
||||
let raw = self.secret.expose_secret();
|
||||
BASE64_STANDARD
|
||||
.decode(raw.as_bytes())
|
||||
.map_err(|_| JwtError::InvalidSecret)
|
||||
}
|
||||
pub fn decode_refresh_token(&self, token: &str) -> Result<RefreshTokenDetails, JwtError> {
|
||||
if token.trim().is_empty() {
|
||||
return Err(JwtError::InvalidToken);
|
||||
}
|
||||
|
||||
pub fn hash_session_secret(&self, session_secret: &str) -> Result<String, JwtError> {
|
||||
let key = self.secret_key_bytes()?;
|
||||
let mut mac = HmacSha256::new_from_slice(&key).map_err(|_| JwtError::InvalidSecret)?;
|
||||
mac.update(session_secret.as_bytes());
|
||||
let digest = mac.finalize().into_bytes();
|
||||
Ok(BASE64_STANDARD.encode(digest))
|
||||
}
|
||||
let mut validation = Validation::new(Algorithm::HS256);
|
||||
validation.validate_exp = true;
|
||||
validation.validate_nbf = false;
|
||||
validation.set_audience(&["refresh"]);
|
||||
validation.required_spec_claims = HashSet::from([
|
||||
"sub".to_string(),
|
||||
"exp".to_string(),
|
||||
"aud".to_string(),
|
||||
"jti".to_string(),
|
||||
]);
|
||||
validation.leeway = DEFAULT_JWT_LEEWAY_SECONDS;
|
||||
|
||||
pub fn verify_session_secret(
|
||||
&self,
|
||||
stored_hash: Option<&str>,
|
||||
candidate_secret: &str,
|
||||
) -> Result<bool, JwtError> {
|
||||
let stored = match stored_hash {
|
||||
Some(value) => value,
|
||||
None => return Ok(false),
|
||||
};
|
||||
let candidate_hash = self.hash_session_secret(candidate_secret)?;
|
||||
Ok(stored.as_bytes().ct_eq(candidate_hash.as_bytes()).into())
|
||||
let decoding_key = DecodingKey::from_base64_secret(self.secret.expose_secret())?;
|
||||
let data = decode::<RefreshTokenClaims>(token, &decoding_key, &validation)?;
|
||||
let claims = data.claims;
|
||||
|
||||
Ok(RefreshTokenDetails {
|
||||
user_id: claims.sub,
|
||||
session_id: claims.session_id,
|
||||
refresh_token_id: claims.jti,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ use axum::{
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use axum_extra::headers::{Authorization, HeaderMapExt, authorization::Bearer};
|
||||
use chrono::Utc;
|
||||
use chrono::{DateTime, Utc};
|
||||
use tracing::warn;
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -23,7 +23,7 @@ use crate::{
|
||||
pub struct RequestContext {
|
||||
pub user: User,
|
||||
pub session_id: Uuid,
|
||||
pub session_secret: String,
|
||||
pub access_token_expires_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
pub async fn require_session(
|
||||
@@ -37,10 +37,10 @@ pub async fn require_session(
|
||||
};
|
||||
|
||||
let jwt = state.jwt();
|
||||
let identity = match jwt.decode(&bearer) {
|
||||
Ok(identity) => identity,
|
||||
let identity = match jwt.decode_access_token(&bearer) {
|
||||
Ok(details) => details,
|
||||
Err(error) => {
|
||||
warn!(?error, "failed to decode session token");
|
||||
warn!(?error, "failed to decode access token");
|
||||
return StatusCode::UNAUTHORIZED.into_response();
|
||||
}
|
||||
};
|
||||
@@ -57,17 +57,14 @@ pub async fn require_session(
|
||||
warn!(?error, "failed to load session");
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
Err(_) => {
|
||||
warn!("failed to load session for unknown reason");
|
||||
return StatusCode::UNAUTHORIZED.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let secrets_match = jwt
|
||||
.verify_session_secret(session.session_secret_hash.as_deref(), &identity.nonce)
|
||||
.unwrap_or(false);
|
||||
|
||||
if session.revoked_at.is_some() || !secrets_match {
|
||||
warn!(
|
||||
"session `{}` rejected (revoked or rotated)",
|
||||
identity.session_id
|
||||
);
|
||||
if session.revoked_at.is_some() {
|
||||
warn!("session `{}` rejected (revoked)", identity.session_id);
|
||||
return StatusCode::UNAUTHORIZED.into_response();
|
||||
}
|
||||
|
||||
@@ -104,7 +101,7 @@ pub async fn require_session(
|
||||
req.extensions_mut().insert(RequestContext {
|
||||
user,
|
||||
session_id: session.id,
|
||||
session_secret: identity.nonce,
|
||||
access_token_expires_at: identity.expires_at,
|
||||
});
|
||||
|
||||
match session_repo.touch(session.id).await {
|
||||
|
||||
@@ -4,6 +4,6 @@ mod middleware;
|
||||
mod provider;
|
||||
|
||||
pub use handoff::{CallbackResult, HandoffError, OAuthHandoffService};
|
||||
pub use jwt::{JwtError, JwtIdentity, JwtService};
|
||||
pub use jwt::{JwtError, JwtService};
|
||||
pub use middleware::{RequestContext, require_session};
|
||||
pub use provider::{GitHubOAuthProvider, GoogleOAuthProvider, ProviderRegistry};
|
||||
|
||||
@@ -8,6 +8,14 @@ use uuid::Uuid;
|
||||
pub enum AuthSessionError {
|
||||
#[error("auth session not found")]
|
||||
NotFound,
|
||||
#[error("refresh token reused - possible theft detected")]
|
||||
TokenReuseDetected,
|
||||
#[error("token has been revoked")]
|
||||
TokenRevoked,
|
||||
#[error("token has expired")]
|
||||
TokenExpired,
|
||||
#[error("invalid token")]
|
||||
InvalidToken,
|
||||
#[error(transparent)]
|
||||
Database(#[from] sqlx::Error),
|
||||
}
|
||||
@@ -16,10 +24,11 @@ pub enum AuthSessionError {
|
||||
pub struct AuthSession {
|
||||
pub id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
pub session_secret_hash: Option<String>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub last_used_at: Option<DateTime<Utc>>,
|
||||
pub revoked_at: Option<DateTime<Utc>>,
|
||||
pub refresh_token_id: Option<Uuid>,
|
||||
pub refresh_token_issued_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
pub const MAX_SESSION_INACTIVITY_DURATION: Duration = Duration::days(365);
|
||||
@@ -36,23 +45,24 @@ impl<'a> AuthSessionRepository<'a> {
|
||||
pub async fn create(
|
||||
&self,
|
||||
user_id: Uuid,
|
||||
session_secret_hash: Option<&str>,
|
||||
refresh_token_id: Option<Uuid>,
|
||||
) -> Result<AuthSession, AuthSessionError> {
|
||||
query_as!(
|
||||
AuthSession,
|
||||
r#"
|
||||
INSERT INTO auth_sessions (user_id, session_secret_hash)
|
||||
INSERT INTO auth_sessions (user_id, refresh_token_id)
|
||||
VALUES ($1, $2)
|
||||
RETURNING
|
||||
id AS "id!",
|
||||
user_id AS "user_id!: Uuid",
|
||||
session_secret_hash AS "session_secret_hash?",
|
||||
created_at AS "created_at!",
|
||||
last_used_at AS "last_used_at?",
|
||||
revoked_at AS "revoked_at?"
|
||||
id AS "id!",
|
||||
user_id AS "user_id!: Uuid",
|
||||
created_at AS "created_at!",
|
||||
last_used_at AS "last_used_at?",
|
||||
revoked_at AS "revoked_at?",
|
||||
refresh_token_id AS "refresh_token_id?",
|
||||
refresh_token_issued_at AS "refresh_token_issued_at?"
|
||||
"#,
|
||||
user_id,
|
||||
session_secret_hash
|
||||
refresh_token_id
|
||||
)
|
||||
.fetch_one(self.pool)
|
||||
.await
|
||||
@@ -64,12 +74,13 @@ impl<'a> AuthSessionRepository<'a> {
|
||||
AuthSession,
|
||||
r#"
|
||||
SELECT
|
||||
id AS "id!",
|
||||
user_id AS "user_id!: Uuid",
|
||||
session_secret_hash AS "session_secret_hash?",
|
||||
created_at AS "created_at!",
|
||||
last_used_at AS "last_used_at?",
|
||||
revoked_at AS "revoked_at?"
|
||||
id AS "id!",
|
||||
user_id AS "user_id!: Uuid",
|
||||
created_at AS "created_at!",
|
||||
last_used_at AS "last_used_at?",
|
||||
revoked_at AS "revoked_at?",
|
||||
refresh_token_id AS "refresh_token_id?",
|
||||
refresh_token_issued_at AS "refresh_token_issued_at?"
|
||||
FROM auth_sessions
|
||||
WHERE id = $1
|
||||
"#,
|
||||
@@ -98,6 +109,126 @@ impl<'a> AuthSessionRepository<'a> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn rotate_tokens(
|
||||
&self,
|
||||
session_id: Uuid,
|
||||
old_refresh_token_id: Uuid,
|
||||
new_refresh_token_id: Uuid,
|
||||
) -> Result<(), AuthSessionError> {
|
||||
let mut tx = self.pool.begin().await.map_err(AuthSessionError::from)?;
|
||||
|
||||
let updated = sqlx::query!(
|
||||
r#"
|
||||
UPDATE auth_sessions
|
||||
SET refresh_token_id = $3,
|
||||
refresh_token_issued_at = NOW()
|
||||
WHERE id = $1
|
||||
AND refresh_token_id = $2
|
||||
RETURNING user_id
|
||||
"#,
|
||||
session_id,
|
||||
old_refresh_token_id,
|
||||
new_refresh_token_id
|
||||
)
|
||||
.fetch_optional(&mut *tx)
|
||||
.await
|
||||
.map_err(AuthSessionError::from)?;
|
||||
|
||||
let Some(row) = updated else {
|
||||
tx.rollback().await.map_err(AuthSessionError::from)?;
|
||||
return Err(AuthSessionError::TokenReuseDetected);
|
||||
};
|
||||
|
||||
// Revoke the old refresh token
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO revoked_refresh_tokens (token_id, user_id, revoked_reason)
|
||||
VALUES ($1, $2, 'token_rotation')
|
||||
ON CONFLICT (token_id) DO NOTHING
|
||||
"#,
|
||||
old_refresh_token_id,
|
||||
row.user_id
|
||||
)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
.map_err(AuthSessionError::from)?;
|
||||
|
||||
tx.commit().await.map_err(AuthSessionError::from)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn set_current_refresh_token(
|
||||
&self,
|
||||
session_id: Uuid,
|
||||
refresh_token_id: Uuid,
|
||||
) -> Result<(), AuthSessionError> {
|
||||
sqlx::query!(
|
||||
r#"
|
||||
UPDATE auth_sessions
|
||||
SET refresh_token_id = $2,
|
||||
refresh_token_issued_at = NOW()
|
||||
WHERE id = $1
|
||||
"#,
|
||||
session_id,
|
||||
refresh_token_id
|
||||
)
|
||||
.execute(self.pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn revoke_all_user_sessions(&self, user_id: Uuid) -> Result<i64, AuthSessionError> {
|
||||
let mut tx = self.pool.begin().await.map_err(AuthSessionError::from)?;
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO revoked_refresh_tokens (token_id, user_id, revoked_reason)
|
||||
SELECT refresh_token_id, user_id, 'reuse_of_revoked_token'
|
||||
FROM auth_sessions
|
||||
WHERE user_id = $1
|
||||
AND refresh_token_id IS NOT NULL
|
||||
ON CONFLICT (token_id) DO NOTHING
|
||||
"#,
|
||||
user_id
|
||||
)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
.map_err(AuthSessionError::from)?;
|
||||
|
||||
let update_result = sqlx::query!(
|
||||
r#"
|
||||
UPDATE auth_sessions
|
||||
SET revoked_at = NOW()
|
||||
WHERE user_id = $1
|
||||
AND revoked_at IS NULL
|
||||
"#,
|
||||
user_id
|
||||
)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
.map_err(AuthSessionError::from)?;
|
||||
|
||||
tx.commit().await.map_err(AuthSessionError::from)?;
|
||||
|
||||
Ok(update_result.rows_affected() as i64)
|
||||
}
|
||||
|
||||
pub async fn is_refresh_token_revoked(&self, token_id: Uuid) -> Result<bool, AuthSessionError> {
|
||||
let result = sqlx::query!(
|
||||
r#"
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM revoked_refresh_tokens WHERE token_id = $1
|
||||
) as is_revoked
|
||||
"#,
|
||||
token_id
|
||||
)
|
||||
.fetch_one(self.pool)
|
||||
.await
|
||||
.map_err(AuthSessionError::from)?;
|
||||
|
||||
Ok(result.is_revoked.unwrap_or(false))
|
||||
}
|
||||
|
||||
pub async fn revoke(&self, session_id: Uuid) -> Result<(), AuthSessionError> {
|
||||
sqlx::query!(
|
||||
r#"
|
||||
@@ -111,25 +242,6 @@ impl<'a> AuthSessionRepository<'a> {
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn update_secret(
|
||||
&self,
|
||||
session_id: Uuid,
|
||||
session_secret_hash: &str,
|
||||
) -> Result<(), AuthSessionError> {
|
||||
sqlx::query!(
|
||||
r#"
|
||||
UPDATE auth_sessions
|
||||
SET session_secret_hash = $2
|
||||
WHERE id = $1
|
||||
"#,
|
||||
session_id,
|
||||
session_secret_hash
|
||||
)
|
||||
.execute(self.pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl AuthSession {
|
||||
|
||||
@@ -22,6 +22,7 @@ pub(crate) mod organization_members;
|
||||
mod organizations;
|
||||
mod projects;
|
||||
pub mod tasks;
|
||||
mod tokens;
|
||||
|
||||
pub fn router(state: AppState) -> Router {
|
||||
let trace_layer = TraceLayer::new_for_http()
|
||||
@@ -47,7 +48,8 @@ pub fn router(state: AppState) -> Router {
|
||||
let v1_public = Router::<AppState>::new()
|
||||
.route("/health", get(health))
|
||||
.merge(oauth::public_router())
|
||||
.merge(organization_members::public_router());
|
||||
.merge(organization_members::public_router())
|
||||
.merge(tokens::public_router());
|
||||
|
||||
let v1_protected = Router::<AppState>::new()
|
||||
.merge(identity::router())
|
||||
|
||||
@@ -76,6 +76,7 @@ pub async fn web_redeem(
|
||||
StatusCode::OK,
|
||||
Json(HandoffRedeemResponse {
|
||||
access_token: result.access_token,
|
||||
refresh_token: result.refresh_token,
|
||||
}),
|
||||
)
|
||||
.into_response(),
|
||||
@@ -217,6 +218,10 @@ pub async fn logout(
|
||||
warn!(?error, session_id = %ctx.session_id, "failed to revoke auth session");
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
Err(error) => {
|
||||
warn!(?error, session_id = %ctx.session_id, "failed to revoke auth session");
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
140
crates/remote/src/routes/tokens.rs
Normal file
140
crates/remote/src/routes/tokens.rs
Normal file
@@ -0,0 +1,140 @@
|
||||
use axum::{
|
||||
Json, Router,
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
routing::post,
|
||||
};
|
||||
use tracing::warn;
|
||||
use utils::api::oauth::{TokenRefreshRequest, TokenRefreshResponse};
|
||||
|
||||
use crate::{
|
||||
AppState,
|
||||
auth::JwtError,
|
||||
db::{
|
||||
auth::{AuthSessionError, AuthSessionRepository},
|
||||
identity_errors::IdentityError,
|
||||
users::UserRepository,
|
||||
},
|
||||
};
|
||||
|
||||
pub fn public_router() -> Router<AppState> {
|
||||
Router::new().route("/tokens/refresh", post(refresh_token))
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum TokenRefreshError {
|
||||
#[error("invalid refresh token")]
|
||||
InvalidToken,
|
||||
#[error("session has been revoked")]
|
||||
SessionRevoked,
|
||||
#[error("refresh token expired")]
|
||||
TokenExpired,
|
||||
#[error("refresh token reused - possible token theft")]
|
||||
TokenReuseDetected,
|
||||
#[error(transparent)]
|
||||
Jwt(#[from] JwtError),
|
||||
#[error(transparent)]
|
||||
Database(#[from] sqlx::Error),
|
||||
#[error(transparent)]
|
||||
SessionError(#[from] AuthSessionError),
|
||||
#[error(transparent)]
|
||||
Identity(#[from] IdentityError),
|
||||
}
|
||||
|
||||
pub async fn refresh_token(
|
||||
State(state): State<AppState>,
|
||||
Json(payload): Json<TokenRefreshRequest>,
|
||||
) -> Result<Response, TokenRefreshError> {
|
||||
let jwt_service = &state.jwt();
|
||||
let session_repo = AuthSessionRepository::new(state.pool());
|
||||
|
||||
let token_details = match jwt_service.decode_refresh_token(&payload.refresh_token) {
|
||||
Ok(details) => details,
|
||||
Err(JwtError::TokenExpired) => return Err(TokenRefreshError::TokenExpired),
|
||||
Err(_) => return Err(TokenRefreshError::InvalidToken),
|
||||
};
|
||||
|
||||
let session = session_repo.get(token_details.session_id).await?;
|
||||
|
||||
if session.revoked_at.is_some() {
|
||||
return Err(TokenRefreshError::SessionRevoked);
|
||||
}
|
||||
|
||||
if session.refresh_token_id != Some(token_details.refresh_token_id)
|
||||
|| session_repo
|
||||
.is_refresh_token_revoked(token_details.refresh_token_id)
|
||||
.await?
|
||||
{
|
||||
// Token was reused, revoke all user sessions as a security measure
|
||||
let revoked_count = session_repo
|
||||
.revoke_all_user_sessions(token_details.user_id)
|
||||
.await?;
|
||||
warn!(
|
||||
user_id = %token_details.user_id,
|
||||
session_id = %token_details.session_id,
|
||||
revoked_sessions = revoked_count,
|
||||
"Refresh token reuse detected. Revoked all user sessions."
|
||||
);
|
||||
return Err(TokenRefreshError::TokenReuseDetected);
|
||||
}
|
||||
|
||||
let user_repo = UserRepository::new(state.pool());
|
||||
let user = user_repo.fetch_user(token_details.user_id).await?;
|
||||
|
||||
let tokens = jwt_service.generate_tokens(&session, &user)?;
|
||||
|
||||
let old_token_id = token_details.refresh_token_id;
|
||||
let new_token_id = tokens.refresh_token_id;
|
||||
|
||||
match session_repo
|
||||
.rotate_tokens(session.id, old_token_id, new_token_id)
|
||||
.await
|
||||
{
|
||||
Ok(_) => {}
|
||||
Err(AuthSessionError::TokenReuseDetected) => {
|
||||
let revoked_count = session_repo
|
||||
.revoke_all_user_sessions(token_details.user_id)
|
||||
.await?;
|
||||
warn!(
|
||||
user_id = %token_details.user_id,
|
||||
session_id = %token_details.session_id,
|
||||
revoked_sessions = revoked_count,
|
||||
"Detected concurrent refresh attempt; revoked all user sessions"
|
||||
);
|
||||
return Err(TokenRefreshError::TokenReuseDetected);
|
||||
}
|
||||
Err(error) => return Err(TokenRefreshError::SessionError(error)),
|
||||
}
|
||||
|
||||
Ok(Json(TokenRefreshResponse {
|
||||
access_token: tokens.access_token,
|
||||
refresh_token: tokens.refresh_token,
|
||||
})
|
||||
.into_response())
|
||||
}
|
||||
|
||||
impl IntoResponse for TokenRefreshError {
|
||||
fn into_response(self) -> Response {
|
||||
let (status, error_code) = match self {
|
||||
TokenRefreshError::InvalidToken => (StatusCode::UNAUTHORIZED, "invalid_token"),
|
||||
TokenRefreshError::TokenExpired => (StatusCode::UNAUTHORIZED, "token_expired"),
|
||||
TokenRefreshError::SessionRevoked => (StatusCode::UNAUTHORIZED, "session_revoked"),
|
||||
TokenRefreshError::TokenReuseDetected => {
|
||||
(StatusCode::UNAUTHORIZED, "token_reuse_detected")
|
||||
}
|
||||
TokenRefreshError::Jwt(_) => (StatusCode::UNAUTHORIZED, "invalid_token"),
|
||||
TokenRefreshError::Identity(_) => (StatusCode::UNAUTHORIZED, "identity_error"),
|
||||
TokenRefreshError::Database(_) | TokenRefreshError::SessionError(_) => {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "internal_error")
|
||||
}
|
||||
};
|
||||
|
||||
let body = serde_json::json!({
|
||||
"error": error_code,
|
||||
"message": self.to_string()
|
||||
});
|
||||
|
||||
(status, Json(body)).into_response()
|
||||
}
|
||||
}
|
||||
@@ -1,13 +1,14 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::extract::ws::{Message, WebSocket};
|
||||
use chrono::{DateTime, Duration as ChronoDuration, Utc};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use sqlx::PgPool;
|
||||
use thiserror::Error;
|
||||
use tokio::time::{self, MissedTickBehavior};
|
||||
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
|
||||
use tracing::{Span, instrument};
|
||||
use utils::ws::{WS_AUTH_REFRESH_INTERVAL, WS_BULK_SYNC_THRESHOLD};
|
||||
use utils::ws::{WS_AUTH_REFRESH_INTERVAL, WS_BULK_SYNC_THRESHOLD, WS_TOKEN_EXPIRY_GRACE};
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::{
|
||||
@@ -17,7 +18,7 @@ use super::{
|
||||
use crate::{
|
||||
AppState,
|
||||
activity::{ActivityBroker, ActivityEvent, ActivityStream},
|
||||
auth::{JwtError, JwtIdentity, JwtService, RequestContext},
|
||||
auth::{JwtError, JwtService, RequestContext},
|
||||
db::{
|
||||
activity::ActivityRepository,
|
||||
auth::{AuthSessionError, AuthSessionRepository},
|
||||
@@ -69,9 +70,9 @@ pub async fn handle(
|
||||
state.jwt(),
|
||||
pool.clone(),
|
||||
ctx.session_id,
|
||||
ctx.session_secret.clone(),
|
||||
ctx.user.id,
|
||||
project_id,
|
||||
ctx.access_token_expires_at,
|
||||
);
|
||||
let mut auth_check_interval = time::interval(WS_AUTH_REFRESH_INTERVAL);
|
||||
auth_check_interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
|
||||
@@ -203,10 +204,9 @@ pub async fn handle(
|
||||
Err(error) => {
|
||||
tracing::info!(?error, "closing websocket due to auth verification error");
|
||||
let message = match error {
|
||||
AuthVerifyError::Revoked | AuthVerifyError::SecretMismatch => {
|
||||
"authorization revoked"
|
||||
}
|
||||
AuthVerifyError::Revoked => "authorization revoked",
|
||||
AuthVerifyError::MembershipRevoked => "project access revoked",
|
||||
AuthVerifyError::Expired => "authorization expired",
|
||||
AuthVerifyError::UserMismatch { .. }
|
||||
| AuthVerifyError::Decode(_)
|
||||
| AuthVerifyError::Session(_) => "authorization error",
|
||||
@@ -269,10 +269,10 @@ struct WsAuthState {
|
||||
jwt: Arc<JwtService>,
|
||||
pool: PgPool,
|
||||
session_id: Uuid,
|
||||
session_secret: String,
|
||||
expected_user_id: Uuid,
|
||||
project_id: Uuid,
|
||||
pending_token: Option<String>,
|
||||
token_expires_at: DateTime<Utc>,
|
||||
new_access_token: Option<String>,
|
||||
}
|
||||
|
||||
impl WsAuthState {
|
||||
@@ -280,48 +280,64 @@ impl WsAuthState {
|
||||
jwt: Arc<JwtService>,
|
||||
pool: PgPool,
|
||||
session_id: Uuid,
|
||||
session_secret: String,
|
||||
expected_user_id: Uuid,
|
||||
project_id: Uuid,
|
||||
token_expires_at: DateTime<Utc>,
|
||||
) -> Self {
|
||||
Self {
|
||||
jwt,
|
||||
pool,
|
||||
session_id,
|
||||
session_secret,
|
||||
expected_user_id,
|
||||
project_id,
|
||||
pending_token: None,
|
||||
new_access_token: None,
|
||||
token_expires_at,
|
||||
}
|
||||
}
|
||||
|
||||
fn store_token(&mut self, token: String) {
|
||||
self.pending_token = Some(token);
|
||||
self.new_access_token = Some(token);
|
||||
}
|
||||
|
||||
async fn verify(&mut self) -> Result<(), AuthVerifyError> {
|
||||
if let Some(token) = self.pending_token.take() {
|
||||
let identity = self.jwt.decode(&token).map_err(AuthVerifyError::Decode)?;
|
||||
self.apply_identity(identity).await?;
|
||||
if let Some(token) = self.new_access_token.take() {
|
||||
let token_details = self
|
||||
.jwt
|
||||
.decode_access_token_with_leeway(&token, WS_TOKEN_EXPIRY_GRACE.as_secs())
|
||||
.map_err(AuthVerifyError::Decode)?;
|
||||
self.apply_identity(token_details.user_id, token_details.session_id)
|
||||
.await?;
|
||||
self.token_expires_at = token_details.expires_at;
|
||||
}
|
||||
|
||||
self.validate_token_expiry()?;
|
||||
self.validate_session().await?;
|
||||
self.validate_membership().await
|
||||
}
|
||||
|
||||
async fn apply_identity(&mut self, identity: JwtIdentity) -> Result<(), AuthVerifyError> {
|
||||
if identity.user_id != self.expected_user_id {
|
||||
async fn apply_identity(
|
||||
&mut self,
|
||||
user_id: Uuid,
|
||||
session_id: Uuid,
|
||||
) -> Result<(), AuthVerifyError> {
|
||||
if user_id != self.expected_user_id {
|
||||
return Err(AuthVerifyError::UserMismatch {
|
||||
expected: self.expected_user_id,
|
||||
received: identity.user_id,
|
||||
received: user_id,
|
||||
});
|
||||
}
|
||||
|
||||
self.session_id = identity.session_id;
|
||||
self.session_secret = identity.nonce;
|
||||
self.session_id = session_id;
|
||||
self.validate_session().await
|
||||
}
|
||||
|
||||
fn validate_token_expiry(&self) -> Result<(), AuthVerifyError> {
|
||||
if self.token_expires_at + ws_leeway_duration() > Utc::now() {
|
||||
return Ok(());
|
||||
}
|
||||
Err(AuthVerifyError::Expired)
|
||||
}
|
||||
|
||||
async fn validate_session(&self) -> Result<(), AuthVerifyError> {
|
||||
let repo = AuthSessionRepository::new(&self.pool);
|
||||
let session = repo
|
||||
@@ -333,14 +349,6 @@ impl WsAuthState {
|
||||
return Err(AuthVerifyError::Revoked);
|
||||
}
|
||||
|
||||
if !self
|
||||
.jwt
|
||||
.verify_session_secret(session.session_secret_hash.as_deref(), &self.session_secret)
|
||||
.unwrap_or(false)
|
||||
{
|
||||
return Err(AuthVerifyError::SecretMismatch);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -364,6 +372,10 @@ impl WsAuthState {
|
||||
}
|
||||
}
|
||||
|
||||
fn ws_leeway_duration() -> ChronoDuration {
|
||||
ChronoDuration::from_std(WS_TOKEN_EXPIRY_GRACE).unwrap()
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
enum AuthVerifyError {
|
||||
#[error(transparent)]
|
||||
@@ -374,10 +386,10 @@ enum AuthVerifyError {
|
||||
Session(#[from] AuthSessionError),
|
||||
#[error("session revoked")]
|
||||
Revoked,
|
||||
#[error("session rotated")]
|
||||
SecretMismatch,
|
||||
#[error("organization membership revoked")]
|
||||
MembershipRevoked,
|
||||
#[error("access token expired")]
|
||||
Expired,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
|
||||
@@ -152,6 +152,7 @@ impl IntoResponse for ApiError {
|
||||
StatusCode::from_u16(*status).unwrap_or(StatusCode::BAD_GATEWAY),
|
||||
"RemoteClientError",
|
||||
),
|
||||
RemoteClientError::Token(_) => (StatusCode::BAD_GATEWAY, "RemoteClientError"),
|
||||
RemoteClientError::Api(code) => match code {
|
||||
services::services::remote_client::HandoffErrorCode::NotFound => {
|
||||
(StatusCode::NOT_FOUND, "RemoteClientError")
|
||||
@@ -168,6 +169,9 @@ impl IntoResponse for ApiError {
|
||||
}
|
||||
_ => (StatusCode::BAD_REQUEST, "RemoteClientError"),
|
||||
},
|
||||
RemoteClientError::Storage(_) => {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "RemoteClientError")
|
||||
}
|
||||
RemoteClientError::Serde(_) | RemoteClientError::Url(_) => {
|
||||
(StatusCode::BAD_REQUEST, "RemoteClientError")
|
||||
}
|
||||
@@ -210,6 +214,12 @@ impl IntoResponse for ApiError {
|
||||
body.clone()
|
||||
}
|
||||
}
|
||||
RemoteClientError::Token(_) => {
|
||||
"Remote service returned an invalid access token. Please sign in again.".to_string()
|
||||
}
|
||||
RemoteClientError::Storage(_) => {
|
||||
"Failed to persist credentials locally. Please retry.".to_string()
|
||||
}
|
||||
RemoteClientError::Api(code) => match code {
|
||||
services::services::remote_client::HandoffErrorCode::NotFound => {
|
||||
"The requested resource was not found.".to_string()
|
||||
|
||||
@@ -13,6 +13,7 @@ use sha2::{Digest, Sha256};
|
||||
use utils::{
|
||||
api::oauth::{HandoffInitRequest, HandoffRedeemRequest, StatusResponse},
|
||||
assets::config_path,
|
||||
jwt::extract_expiration,
|
||||
response::ApiResponse,
|
||||
};
|
||||
use uuid::Uuid;
|
||||
@@ -119,8 +120,12 @@ async fn handoff_complete(
|
||||
|
||||
let redeem = client.handoff_redeem(&redeem_request).await?;
|
||||
|
||||
let expires_at = extract_expiration(&redeem.access_token)
|
||||
.map_err(|err| ApiError::BadRequest(format!("Invalid access token: {err}")))?;
|
||||
let credentials = Credentials {
|
||||
access_token: redeem.access_token.clone(),
|
||||
access_token: Some(redeem.access_token.clone()),
|
||||
refresh_token: redeem.refresh_token.clone(),
|
||||
expires_at: Some(expires_at),
|
||||
};
|
||||
|
||||
deployment
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::sync::{Mutex as TokioMutex, OwnedMutexGuard, RwLock};
|
||||
use utils::api::oauth::ProfileResponse;
|
||||
|
||||
use super::oauth_credentials::{Credentials, OAuthCredentials};
|
||||
@@ -9,6 +9,7 @@ use super::oauth_credentials::{Credentials, OAuthCredentials};
|
||||
pub struct AuthContext {
|
||||
oauth: Arc<OAuthCredentials>,
|
||||
profile: Arc<RwLock<Option<ProfileResponse>>>,
|
||||
refresh_lock: Arc<TokioMutex<()>>,
|
||||
}
|
||||
|
||||
impl AuthContext {
|
||||
@@ -16,7 +17,11 @@ impl AuthContext {
|
||||
oauth: Arc<OAuthCredentials>,
|
||||
profile: Arc<RwLock<Option<ProfileResponse>>>,
|
||||
) -> Self {
|
||||
Self { oauth, profile }
|
||||
Self {
|
||||
oauth,
|
||||
profile,
|
||||
refresh_lock: Arc::new(TokioMutex::new(())),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_credentials(&self) -> Option<Credentials> {
|
||||
@@ -42,4 +47,8 @@ impl AuthContext {
|
||||
pub async fn clear_profile(&self) {
|
||||
*self.profile.write().await = None
|
||||
}
|
||||
|
||||
pub async fn refresh_guard(&self) -> OwnedMutexGuard<()> {
|
||||
self.refresh_lock.clone().lock_owned().await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,13 +1,40 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use chrono::{DateTime, Duration as ChronoDuration, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// OAuth credentials containing the JWT access token.
|
||||
/// The access_token is a JWT from the remote OAuth service and should be treated as opaque.
|
||||
/// OAuth credentials containing the JWT tokens issued by the remote OAuth service.
|
||||
/// The `access_token` is short-lived; `refresh_token` allows minting a new pair.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Credentials {
|
||||
pub access_token: String,
|
||||
pub access_token: Option<String>,
|
||||
pub refresh_token: String,
|
||||
pub expires_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
impl Credentials {
|
||||
pub fn expires_soon(&self, leeway: ChronoDuration) -> bool {
|
||||
match (self.access_token.as_ref(), self.expires_at.as_ref()) {
|
||||
(Some(_), Some(exp)) => Utc::now() + leeway >= *exp,
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct StoredCredentials {
|
||||
refresh_token: String,
|
||||
}
|
||||
|
||||
impl From<StoredCredentials> for Credentials {
|
||||
fn from(value: StoredCredentials) -> Self {
|
||||
Self {
|
||||
access_token: None,
|
||||
refresh_token: value.refresh_token,
|
||||
expires_at: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Service for managing OAuth credentials (JWT tokens) in memory and persistent storage.
|
||||
@@ -26,13 +53,16 @@ impl OAuthCredentials {
|
||||
}
|
||||
|
||||
pub async fn load(&self) -> std::io::Result<()> {
|
||||
let creds = self.backend.load().await?;
|
||||
let creds = self.backend.load().await?.map(Credentials::from);
|
||||
*self.inner.write().await = creds;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn save(&self, creds: &Credentials) -> std::io::Result<()> {
|
||||
self.backend.save(creds).await?;
|
||||
let stored = StoredCredentials {
|
||||
refresh_token: creds.refresh_token.clone(),
|
||||
};
|
||||
self.backend.save(&stored).await?;
|
||||
*self.inner.write().await = Some(creds.clone());
|
||||
Ok(())
|
||||
}
|
||||
@@ -49,8 +79,8 @@ impl OAuthCredentials {
|
||||
}
|
||||
|
||||
trait StoreBackend {
|
||||
async fn load(&self) -> std::io::Result<Option<Credentials>>;
|
||||
async fn save(&self, creds: &Credentials) -> std::io::Result<()>;
|
||||
async fn load(&self) -> std::io::Result<Option<StoredCredentials>>;
|
||||
async fn save(&self, creds: &StoredCredentials) -> std::io::Result<()>;
|
||||
async fn clear(&self) -> std::io::Result<()>;
|
||||
}
|
||||
|
||||
@@ -86,7 +116,7 @@ impl Backend {
|
||||
}
|
||||
|
||||
impl StoreBackend for Backend {
|
||||
async fn load(&self) -> std::io::Result<Option<Credentials>> {
|
||||
async fn load(&self) -> std::io::Result<Option<StoredCredentials>> {
|
||||
match self {
|
||||
Backend::File(b) => b.load().await,
|
||||
#[cfg(target_os = "macos")]
|
||||
@@ -94,7 +124,7 @@ impl StoreBackend for Backend {
|
||||
}
|
||||
}
|
||||
|
||||
async fn save(&self, creds: &Credentials) -> std::io::Result<()> {
|
||||
async fn save(&self, creds: &StoredCredentials) -> std::io::Result<()> {
|
||||
match self {
|
||||
Backend::File(b) => b.save(creds).await,
|
||||
#[cfg(target_os = "macos")]
|
||||
@@ -116,13 +146,13 @@ struct FileBackend {
|
||||
}
|
||||
|
||||
impl FileBackend {
|
||||
async fn load(&self) -> std::io::Result<Option<Credentials>> {
|
||||
async fn load(&self) -> std::io::Result<Option<StoredCredentials>> {
|
||||
if !self.path.exists() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let bytes = std::fs::read(&self.path)?;
|
||||
match serde_json::from_slice::<Credentials>(&bytes) {
|
||||
match Self::parse_credentials(&bytes) {
|
||||
Ok(creds) => Ok(Some(creds)),
|
||||
Err(e) => {
|
||||
tracing::warn!(?e, "failed to parse credentials file, renaming to .bad");
|
||||
@@ -133,7 +163,11 @@ impl FileBackend {
|
||||
}
|
||||
}
|
||||
|
||||
async fn save(&self, creds: &Credentials) -> std::io::Result<()> {
|
||||
fn parse_credentials(bytes: &[u8]) -> Result<StoredCredentials, serde_json::Error> {
|
||||
serde_json::from_slice::<StoredCredentials>(bytes)
|
||||
}
|
||||
|
||||
async fn save(&self, creds: &StoredCredentials) -> std::io::Result<()> {
|
||||
let tmp = self.path.with_extension("tmp");
|
||||
|
||||
let file = {
|
||||
@@ -149,7 +183,7 @@ impl FileBackend {
|
||||
opts.open(&tmp)?
|
||||
};
|
||||
|
||||
serde_json::to_writer_pretty(&file, &creds)?;
|
||||
serde_json::to_writer_pretty(&file, creds)?;
|
||||
file.sync_all()?;
|
||||
drop(file);
|
||||
|
||||
@@ -172,14 +206,17 @@ impl KeychainBackend {
|
||||
const ACCOUNT_NAME: &'static str = "default";
|
||||
const ERR_SEC_ITEM_NOT_FOUND: i32 = -25300;
|
||||
|
||||
async fn load(&self) -> std::io::Result<Option<Credentials>> {
|
||||
async fn load(&self) -> std::io::Result<Option<StoredCredentials>> {
|
||||
use security_framework::passwords::get_generic_password;
|
||||
|
||||
match get_generic_password(Self::SERVICE_NAME, Self::ACCOUNT_NAME) {
|
||||
Ok(bytes) => match serde_json::from_slice::<Credentials>(&bytes) {
|
||||
Ok(bytes) => match serde_json::from_slice::<StoredCredentials>(&bytes) {
|
||||
Ok(creds) => Ok(Some(creds)),
|
||||
Err(e) => {
|
||||
tracing::warn!(?e, "failed to parse keychain credentials; ignoring");
|
||||
Err(error) => {
|
||||
tracing::warn!(
|
||||
?error,
|
||||
"failed to parse keychain credentials; ignoring entry and requiring re-login"
|
||||
);
|
||||
Ok(None)
|
||||
}
|
||||
},
|
||||
@@ -188,7 +225,7 @@ impl KeychainBackend {
|
||||
}
|
||||
}
|
||||
|
||||
async fn save(&self, creds: &Credentials) -> std::io::Result<()> {
|
||||
async fn save(&self, creds: &StoredCredentials) -> std::io::Result<()> {
|
||||
use security_framework::passwords::set_generic_password;
|
||||
|
||||
let bytes = serde_json::to_vec_pretty(creds).map_err(std::io::Error::other)?;
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use backon::{ExponentialBuilder, Retryable};
|
||||
use chrono::Duration as ChronoDuration;
|
||||
use remote::{
|
||||
activity::ActivityResponse,
|
||||
routes::tasks::{
|
||||
@@ -16,23 +17,26 @@ use serde_json::Value;
|
||||
use thiserror::Error;
|
||||
use tracing::warn;
|
||||
use url::Url;
|
||||
use utils::api::{
|
||||
oauth::{
|
||||
HandoffInitRequest, HandoffInitResponse, HandoffRedeemRequest, HandoffRedeemResponse,
|
||||
ProfileResponse,
|
||||
use utils::{
|
||||
api::{
|
||||
oauth::{
|
||||
HandoffInitRequest, HandoffInitResponse, HandoffRedeemRequest, HandoffRedeemResponse,
|
||||
ProfileResponse, TokenRefreshRequest, TokenRefreshResponse,
|
||||
},
|
||||
organizations::{
|
||||
AcceptInvitationResponse, CreateInvitationRequest, CreateInvitationResponse,
|
||||
CreateOrganizationRequest, CreateOrganizationResponse, GetInvitationResponse,
|
||||
GetOrganizationResponse, ListInvitationsResponse, ListMembersResponse,
|
||||
ListOrganizationsResponse, Organization, RevokeInvitationRequest,
|
||||
UpdateMemberRoleRequest, UpdateMemberRoleResponse, UpdateOrganizationRequest,
|
||||
},
|
||||
projects::{ListProjectsResponse, RemoteProject},
|
||||
},
|
||||
organizations::{
|
||||
AcceptInvitationResponse, CreateInvitationRequest, CreateInvitationResponse,
|
||||
CreateOrganizationRequest, CreateOrganizationResponse, GetInvitationResponse,
|
||||
GetOrganizationResponse, ListInvitationsResponse, ListMembersResponse,
|
||||
ListOrganizationsResponse, Organization, RevokeInvitationRequest, UpdateMemberRoleRequest,
|
||||
UpdateMemberRoleResponse, UpdateOrganizationRequest,
|
||||
},
|
||||
projects::{ListProjectsResponse, RemoteProject},
|
||||
jwt::extract_expiration,
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::auth::AuthContext;
|
||||
use super::{auth::AuthContext, oauth_credentials::Credentials};
|
||||
|
||||
#[derive(Debug, Clone, Error)]
|
||||
pub enum RemoteClientError {
|
||||
@@ -50,6 +54,10 @@ pub enum RemoteClientError {
|
||||
Serde(String),
|
||||
#[error("url error: {0}")]
|
||||
Url(String),
|
||||
#[error("credentials storage error: {0}")]
|
||||
Storage(String),
|
||||
#[error("invalid access token: {0}")]
|
||||
Token(String),
|
||||
}
|
||||
|
||||
impl RemoteClientError {
|
||||
@@ -124,6 +132,7 @@ impl Clone for RemoteClient {
|
||||
|
||||
impl RemoteClient {
|
||||
const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
const TOKEN_REFRESH_LEEWAY_SECS: i64 = 20;
|
||||
|
||||
pub fn new(base_url: &str, auth_context: AuthContext) -> Result<Self, RemoteClientError> {
|
||||
let base = Url::parse(base_url).map_err(|e| RemoteClientError::Url(e.to_string()))?;
|
||||
@@ -139,14 +148,84 @@ impl RemoteClient {
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the token if available.
|
||||
async fn require_token(&self) -> Result<String, RemoteClientError> {
|
||||
let creds = self
|
||||
.auth_context
|
||||
.get_credentials()
|
||||
/// Returns a valid access token, refreshing when it's about to expire.
|
||||
fn require_token(
|
||||
&self,
|
||||
) -> std::pin::Pin<
|
||||
Box<dyn std::future::Future<Output = Result<String, RemoteClientError>> + Send + '_>,
|
||||
> {
|
||||
Box::pin(async move {
|
||||
let leeway = ChronoDuration::seconds(Self::TOKEN_REFRESH_LEEWAY_SECS);
|
||||
let creds = self
|
||||
.auth_context
|
||||
.get_credentials()
|
||||
.await
|
||||
.ok_or(RemoteClientError::Auth)?;
|
||||
|
||||
if let Some(token) = creds.access_token.as_ref()
|
||||
&& !creds.expires_soon(leeway)
|
||||
{
|
||||
return Ok(token.clone());
|
||||
}
|
||||
|
||||
let refreshed = {
|
||||
let _refresh_guard = self.auth_context.refresh_guard().await;
|
||||
let latest = self
|
||||
.auth_context
|
||||
.get_credentials()
|
||||
.await
|
||||
.ok_or(RemoteClientError::Auth)?;
|
||||
if let Some(token) = latest.access_token.as_ref()
|
||||
&& !latest.expires_soon(leeway)
|
||||
{
|
||||
return Ok(token.clone());
|
||||
}
|
||||
|
||||
self.refresh_credentials(&latest).await
|
||||
};
|
||||
|
||||
match refreshed {
|
||||
Ok(updated) => updated.access_token.ok_or(RemoteClientError::Auth),
|
||||
Err(RemoteClientError::Auth) => {
|
||||
let _ = self.auth_context.clear_credentials().await;
|
||||
Err(RemoteClientError::Auth)
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn refresh_credentials(
|
||||
&self,
|
||||
creds: &Credentials,
|
||||
) -> Result<Credentials, RemoteClientError> {
|
||||
let response = self.refresh_token_request(&creds.refresh_token).await?;
|
||||
let access_token = response.access_token;
|
||||
let refresh_token = response.refresh_token;
|
||||
let expires_at = extract_expiration(&access_token)
|
||||
.map_err(|err| RemoteClientError::Token(err.to_string()))?;
|
||||
let new_creds = Credentials {
|
||||
access_token: Some(access_token),
|
||||
refresh_token,
|
||||
expires_at: Some(expires_at),
|
||||
};
|
||||
self.auth_context
|
||||
.save_credentials(&new_creds)
|
||||
.await
|
||||
.ok_or(RemoteClientError::Auth)?;
|
||||
Ok(creds.access_token)
|
||||
.map_err(|e| RemoteClientError::Storage(e.to_string()))?;
|
||||
Ok(new_creds)
|
||||
}
|
||||
|
||||
async fn refresh_token_request(
|
||||
&self,
|
||||
refresh_token: &str,
|
||||
) -> Result<TokenRefreshResponse, RemoteClientError> {
|
||||
let request = TokenRefreshRequest {
|
||||
refresh_token: refresh_token.to_string(),
|
||||
};
|
||||
self.post_public("/v1/tokens/refresh", Some(&request))
|
||||
.await
|
||||
.map_err(|e| self.map_api_error(e))
|
||||
}
|
||||
|
||||
/// Returns the base URL for the client.
|
||||
@@ -154,6 +233,11 @@ impl RemoteClient {
|
||||
self.base.as_str()
|
||||
}
|
||||
|
||||
/// Returns a valid access token for use-cases like maintaining a websocket connection.
|
||||
pub async fn access_token(&self) -> Result<String, RemoteClientError> {
|
||||
self.require_token().await
|
||||
}
|
||||
|
||||
/// Initiates an authorization-code handoff for the given provider.
|
||||
pub async fn handoff_init(
|
||||
&self,
|
||||
@@ -187,7 +271,7 @@ impl RemoteClient {
|
||||
&self,
|
||||
method: reqwest::Method,
|
||||
path: &str,
|
||||
token: Option<&str>,
|
||||
requires_auth: bool,
|
||||
body: Option<&B>,
|
||||
) -> Result<reqwest::Response, RemoteClientError>
|
||||
where
|
||||
@@ -201,8 +285,9 @@ impl RemoteClient {
|
||||
(|| async {
|
||||
let mut req = self.http.request(method.clone(), url.clone());
|
||||
|
||||
if let Some(t) = token {
|
||||
req = req.bearer_auth(t);
|
||||
if requires_auth {
|
||||
let token = self.require_token().await?;
|
||||
req = req.bearer_auth(token);
|
||||
}
|
||||
|
||||
if let Some(b) = body {
|
||||
@@ -245,7 +330,7 @@ impl RemoteClient {
|
||||
T: for<'de> Deserialize<'de>,
|
||||
{
|
||||
let res = self
|
||||
.send(reqwest::Method::GET, path, None, None::<&()>)
|
||||
.send(reqwest::Method::GET, path, false, None::<&()>)
|
||||
.await?;
|
||||
res.json::<T>()
|
||||
.await
|
||||
@@ -257,7 +342,7 @@ impl RemoteClient {
|
||||
T: for<'de> Deserialize<'de>,
|
||||
B: Serialize,
|
||||
{
|
||||
let res = self.send(reqwest::Method::POST, path, None, body).await?;
|
||||
let res = self.send(reqwest::Method::POST, path, false, body).await?;
|
||||
res.json::<T>()
|
||||
.await
|
||||
.map_err(|e| RemoteClientError::Serde(e.to_string()))
|
||||
@@ -268,9 +353,8 @@ impl RemoteClient {
|
||||
where
|
||||
T: for<'de> Deserialize<'de>,
|
||||
{
|
||||
let token = self.require_token().await?;
|
||||
let res = self
|
||||
.send(reqwest::Method::GET, path, Some(&token), None::<&()>)
|
||||
.send(reqwest::Method::GET, path, true, None::<&()>)
|
||||
.await?;
|
||||
res.json::<T>()
|
||||
.await
|
||||
@@ -282,10 +366,7 @@ impl RemoteClient {
|
||||
T: for<'de> Deserialize<'de>,
|
||||
B: Serialize,
|
||||
{
|
||||
let token = self.require_token().await?;
|
||||
let res = self
|
||||
.send(reqwest::Method::POST, path, Some(&token), body)
|
||||
.await?;
|
||||
let res = self.send(reqwest::Method::POST, path, true, body).await?;
|
||||
res.json::<T>()
|
||||
.await
|
||||
.map_err(|e| RemoteClientError::Serde(e.to_string()))
|
||||
@@ -296,9 +377,8 @@ impl RemoteClient {
|
||||
T: for<'de> Deserialize<'de>,
|
||||
B: Serialize,
|
||||
{
|
||||
let token = self.require_token().await?;
|
||||
let res = self
|
||||
.send(reqwest::Method::PATCH, path, Some(&token), Some(body))
|
||||
.send(reqwest::Method::PATCH, path, true, Some(body))
|
||||
.await?;
|
||||
res.json::<T>()
|
||||
.await
|
||||
@@ -306,8 +386,7 @@ impl RemoteClient {
|
||||
}
|
||||
|
||||
async fn delete_authed(&self, path: &str) -> Result<(), RemoteClientError> {
|
||||
let token = self.require_token().await?;
|
||||
self.send(reqwest::Method::DELETE, path, Some(&token), None::<&()>)
|
||||
self.send(reqwest::Method::DELETE, path, true, None::<&()>)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -503,12 +582,11 @@ impl RemoteClient {
|
||||
task_id: Uuid,
|
||||
request: &DeleteSharedTaskRequest,
|
||||
) -> Result<SharedTaskResponse, RemoteClientError> {
|
||||
let token = self.require_token().await?;
|
||||
let res = self
|
||||
.send(
|
||||
reqwest::Method::DELETE,
|
||||
&format!("/v1/tasks/{task_id}"),
|
||||
Some(&token),
|
||||
true,
|
||||
Some(request),
|
||||
)
|
||||
.await?;
|
||||
|
||||
@@ -23,7 +23,7 @@ use db::{
|
||||
use processor::ActivityProcessor;
|
||||
pub use publisher::SharePublisher;
|
||||
use remote::{
|
||||
ServerMessage,
|
||||
ClientMessage, ServerMessage,
|
||||
db::{tasks::SharedTask as RemoteSharedTask, users::UserData as RemoteUserData},
|
||||
};
|
||||
use sqlx::{Executor, Sqlite, SqlitePool};
|
||||
@@ -35,7 +35,9 @@ use tokio::{
|
||||
};
|
||||
use tokio_tungstenite::tungstenite::Message as WsMessage;
|
||||
use url::Url;
|
||||
use utils::ws::{WsClient, WsConfig, WsError, WsHandler, WsResult, run_ws_client};
|
||||
use utils::ws::{
|
||||
WS_AUTH_REFRESH_INTERVAL, WsClient, WsConfig, WsError, WsHandler, WsResult, run_ws_client,
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
@@ -254,13 +256,21 @@ impl RemoteSync {
|
||||
let processor = self.processor.clone();
|
||||
let config = self.config.clone();
|
||||
let auth_ctx = self.auth_ctx.clone();
|
||||
let remote_client = processor.remote_client();
|
||||
let db = self.db.clone();
|
||||
let (shutdown_tx, shutdown_rx) = oneshot::channel();
|
||||
|
||||
let join = tokio::spawn(async move {
|
||||
let result =
|
||||
project_watcher_task(db, processor, config, auth_ctx, project_id, shutdown_rx)
|
||||
.await;
|
||||
let result = project_watcher_task(
|
||||
db,
|
||||
processor,
|
||||
config,
|
||||
auth_ctx,
|
||||
remote_client,
|
||||
project_id,
|
||||
shutdown_rx,
|
||||
)
|
||||
.await;
|
||||
|
||||
let _ = events_tx.send(ProjectWatcherEvent { project_id, result });
|
||||
});
|
||||
@@ -327,22 +337,27 @@ impl WsHandler for SharedWsHandler {
|
||||
|
||||
async fn spawn_shared_remote(
|
||||
processor: ActivityProcessor,
|
||||
auth_ctx: &AuthContext,
|
||||
remote_client: RemoteClient,
|
||||
url: Url,
|
||||
close_tx: oneshot::Sender<()>,
|
||||
remote_project_id: Uuid,
|
||||
) -> Result<WsClient, ShareError> {
|
||||
let auth_source = auth_ctx.clone();
|
||||
let remote_client_clone = remote_client.clone();
|
||||
let ws_config = WsConfig {
|
||||
url,
|
||||
ping_interval: Some(std::time::Duration::from_secs(30)),
|
||||
header_factory: Some(Arc::new(move || {
|
||||
let auth_source = auth_source.clone();
|
||||
let remote_client_clone = remote_client_clone.clone();
|
||||
Box::pin(async move {
|
||||
if let Some(creds) = auth_source.get_credentials().await {
|
||||
build_ws_headers(&creds.access_token)
|
||||
} else {
|
||||
Err(WsError::MissingAuth)
|
||||
match remote_client_clone.access_token().await {
|
||||
Ok(token) => build_ws_headers(&token),
|
||||
Err(error) => {
|
||||
tracing::warn!(
|
||||
?error,
|
||||
"failed to obtain access token for websocket connection"
|
||||
);
|
||||
Err(WsError::MissingAuth)
|
||||
}
|
||||
}
|
||||
})
|
||||
})),
|
||||
@@ -356,6 +371,7 @@ async fn spawn_shared_remote(
|
||||
let client = run_ws_client(handler, ws_config)
|
||||
.await
|
||||
.map_err(ShareError::from)?;
|
||||
spawn_ws_auth_refresh_task(client.clone(), remote_client);
|
||||
|
||||
Ok(client)
|
||||
}
|
||||
@@ -365,6 +381,7 @@ async fn project_watcher_task(
|
||||
processor: ActivityProcessor,
|
||||
config: ShareConfig,
|
||||
auth_ctx: AuthContext,
|
||||
remote_client: RemoteClient,
|
||||
remote_project_id: Uuid,
|
||||
mut shutdown_rx: oneshot::Receiver<()>,
|
||||
) -> Result<(), ShareError> {
|
||||
@@ -410,7 +427,7 @@ async fn project_watcher_task(
|
||||
let (close_tx, close_rx) = oneshot::channel();
|
||||
let ws_connection = match spawn_shared_remote(
|
||||
processor.clone(),
|
||||
&auth_ctx,
|
||||
remote_client.clone(),
|
||||
ws_url,
|
||||
close_tx,
|
||||
remote_project_id,
|
||||
@@ -479,6 +496,44 @@ fn build_ws_headers(access_token: &str) -> WsResult<Vec<(HeaderName, HeaderValue
|
||||
Ok(headers)
|
||||
}
|
||||
|
||||
fn spawn_ws_auth_refresh_task(client: WsClient, remote_client: RemoteClient) {
|
||||
tokio::spawn(async move {
|
||||
let mut close_rx = client.subscribe_close();
|
||||
loop {
|
||||
match remote_client.access_token().await {
|
||||
Ok(token) => {
|
||||
if let Err(err) = send_ws_auth_token(&client, token).await {
|
||||
tracing::warn!(
|
||||
?err,
|
||||
"failed to send websocket auth token; stopping auth refresh"
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
tracing::warn!(
|
||||
?err,
|
||||
"failed to obtain access token for websocket auth refresh; stopping auth refresh"
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
tokio::select! {
|
||||
_ = close_rx.changed() => break,
|
||||
_ = sleep(WS_AUTH_REFRESH_INTERVAL) => {}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async fn send_ws_auth_token(client: &WsClient, token: String) -> Result<(), ShareError> {
|
||||
let payload = serde_json::to_string(&ClientMessage::AuthToken { token })?;
|
||||
client
|
||||
.send(WsMessage::Text(payload.into()))
|
||||
.map_err(ShareError::from)
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RemoteSyncHandle {
|
||||
inner: Arc<RemoteSyncHandleInner>,
|
||||
|
||||
@@ -48,6 +48,10 @@ impl ActivityProcessor {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn remote_client(&self) -> RemoteClient {
|
||||
self.remote_client.clone()
|
||||
}
|
||||
|
||||
pub async fn process_event(&self, event: ActivityEvent) -> Result<(), ShareError> {
|
||||
let mut tx = self.db.pool.begin().await?;
|
||||
match event.event_type.as_str() {
|
||||
|
||||
@@ -23,7 +23,7 @@ sentry = { version = "0.41.0", features = ["anyhow", "backtrace", "panic", "debu
|
||||
sentry-tracing = { version = "0.41.0", features = ["backtrace"] }
|
||||
futures-util = "0.3"
|
||||
json-patch = "2.0"
|
||||
jsonwebtoken = { version = "10.0.0", features = ["rust_crypto"] }
|
||||
jsonwebtoken = { version = "10.2.0", features = ["rust_crypto"] }
|
||||
tokio = { workspace = true }
|
||||
futures = "0.3.31"
|
||||
tokio-stream = { version = "0.1.17", features = ["sync"] }
|
||||
|
||||
@@ -29,6 +29,20 @@ pub struct HandoffRedeemRequest {
|
||||
#[ts(export)]
|
||||
pub struct HandoffRedeemResponse {
|
||||
pub access_token: String,
|
||||
pub refresh_token: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, TS)]
|
||||
#[ts(export)]
|
||||
pub struct TokenRefreshRequest {
|
||||
pub refresh_token: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, TS)]
|
||||
#[ts(export)]
|
||||
pub struct TokenRefreshResponse {
|
||||
pub access_token: String,
|
||||
pub refresh_token: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, TS)]
|
||||
|
||||
26
crates/utils/src/jwt.rs
Normal file
26
crates/utils/src/jwt.rs
Normal file
@@ -0,0 +1,26 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
use jsonwebtoken::dangerous::insecure_decode;
|
||||
use serde::Deserialize;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum TokenClaimsError {
|
||||
#[error("failed to decode JWT: {0}")]
|
||||
Decode(#[from] jsonwebtoken::errors::Error),
|
||||
#[error("missing `exp` claim in token")]
|
||||
MissingExpiration,
|
||||
#[error("invalid `exp` value `{0}`")]
|
||||
InvalidExpiration(i64),
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ExpClaim {
|
||||
exp: Option<i64>,
|
||||
}
|
||||
|
||||
/// Extract the expiration timestamp from a JWT without verifying its signature.
|
||||
pub fn extract_expiration(token: &str) -> Result<DateTime<Utc>, TokenClaimsError> {
|
||||
let data = insecure_decode::<ExpClaim>(token)?;
|
||||
let exp = data.claims.exp.ok_or(TokenClaimsError::MissingExpiration)?;
|
||||
DateTime::from_timestamp(exp, 0).ok_or(TokenClaimsError::InvalidExpiration(exp))
|
||||
}
|
||||
@@ -8,6 +8,7 @@ pub mod assets;
|
||||
pub mod browser;
|
||||
pub mod diff;
|
||||
pub mod git;
|
||||
pub mod jwt;
|
||||
pub mod log_msg;
|
||||
pub mod msg_store;
|
||||
pub mod path;
|
||||
|
||||
@@ -17,6 +17,7 @@ export type HandoffInitResponse = {
|
||||
|
||||
export type HandoffRedeemResponse = {
|
||||
access_token: string;
|
||||
refresh_token: string;
|
||||
};
|
||||
|
||||
export type AcceptInvitationResponse = {
|
||||
|
||||
Reference in New Issue
Block a user