Improve auth
This commit is contained in:
@@ -1,10 +1,14 @@
|
||||
use axum::{
|
||||
async_trait,
|
||||
body::Body,
|
||||
extract::FromRequestParts,
|
||||
http::{request::Parts, StatusCode},
|
||||
http::{request::Parts, StatusCode, Request},
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::PgPool;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
@@ -15,6 +19,7 @@ pub struct Claims {
|
||||
pub exp: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AuthUser {
|
||||
pub user_id: Uuid,
|
||||
pub email: String,
|
||||
@@ -29,32 +34,12 @@ where
|
||||
type Rejection = StatusCode;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
|
||||
let headers = &parts.headers;
|
||||
|
||||
let auth_header = headers
|
||||
.get("authorization")
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let token = auth_header
|
||||
.strip_prefix("Bearer ")
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let jwt_secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| "your-secret-key".to_string());
|
||||
|
||||
let claims = decode::<Claims>(
|
||||
token,
|
||||
&DecodingKey::from_secret(jwt_secret.as_ref()),
|
||||
&Validation::default(),
|
||||
)
|
||||
.map_err(|_| StatusCode::UNAUTHORIZED)?
|
||||
.claims;
|
||||
|
||||
Ok(AuthUser {
|
||||
user_id: claims.user_id,
|
||||
email: claims.email,
|
||||
is_admin: claims.is_admin,
|
||||
})
|
||||
// Get user from request extensions (set by auth middleware)
|
||||
parts
|
||||
.extensions
|
||||
.get::<AuthUser>()
|
||||
.cloned()
|
||||
.ok_or(StatusCode::UNAUTHORIZED)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,3 +72,58 @@ pub fn hash_password(password: &str) -> Result<String, bcrypt::BcryptError> {
|
||||
pub fn verify_password(password: &str, hash: &str) -> Result<bool, bcrypt::BcryptError> {
|
||||
bcrypt::verify(password, hash)
|
||||
}
|
||||
|
||||
// Auth middleware that requires authentication for all routes
|
||||
pub async fn auth_middleware(
|
||||
mut request: Request<Body>,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let headers = request.headers();
|
||||
|
||||
let auth_header = headers
|
||||
.get("authorization")
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let token = auth_header
|
||||
.strip_prefix("Bearer ")
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let jwt_secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| "your-secret-key".to_string());
|
||||
|
||||
let claims = decode::<Claims>(
|
||||
token,
|
||||
&DecodingKey::from_secret(jwt_secret.as_ref()),
|
||||
&Validation::default(),
|
||||
)
|
||||
.map_err(|_| StatusCode::UNAUTHORIZED)?
|
||||
.claims;
|
||||
|
||||
// Get database pool from request extensions
|
||||
let pool = request
|
||||
.extensions()
|
||||
.get::<PgPool>()
|
||||
.ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
// Verify user exists in database
|
||||
let user_exists = sqlx::query!(
|
||||
"SELECT id FROM users WHERE id = $1",
|
||||
claims.user_id
|
||||
)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
if user_exists.is_none() {
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
// Add user info to request extensions for handlers to access
|
||||
request.extensions_mut().insert(AuthUser {
|
||||
user_id: claims.user_id,
|
||||
email: claims.email,
|
||||
is_admin: claims.is_admin,
|
||||
});
|
||||
|
||||
Ok(next.run(request).await)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use axum::{
|
||||
extract::Extension,
|
||||
middleware,
|
||||
response::Json as ResponseJson,
|
||||
routing::{get, post},
|
||||
Json, Router,
|
||||
@@ -12,7 +13,7 @@ mod auth;
|
||||
mod models;
|
||||
mod routes;
|
||||
|
||||
use auth::hash_password;
|
||||
use auth::{auth_middleware, hash_password};
|
||||
use models::ApiResponse;
|
||||
use routes::{health, projects, tasks, users};
|
||||
|
||||
@@ -51,13 +52,24 @@ async fn main() -> anyhow::Result<()> {
|
||||
tracing::warn!("Failed to create admin account: {}", e);
|
||||
}
|
||||
|
||||
let app = Router::new()
|
||||
// Public routes (no auth required)
|
||||
let public_routes = Router::new()
|
||||
.route("/", get(|| async { "Bloop API" }))
|
||||
.route("/health", get(health::health_check))
|
||||
.route("/echo", post(echo_handler))
|
||||
.merge(users::public_users_router());
|
||||
|
||||
// Protected routes (auth required)
|
||||
let protected_routes = Router::new()
|
||||
.merge(projects::projects_router())
|
||||
.merge(tasks::tasks_router())
|
||||
.merge(users::users_router())
|
||||
.merge(users::protected_users_router())
|
||||
.layer(Extension(pool.clone()))
|
||||
.layer(middleware::from_fn(auth_middleware));
|
||||
|
||||
let app = Router::new()
|
||||
.merge(public_routes)
|
||||
.merge(protected_routes)
|
||||
.layer(Extension(pool))
|
||||
.layer(CorsLayer::permissive());
|
||||
|
||||
|
||||
@@ -13,7 +13,10 @@ use chrono::Utc;
|
||||
use crate::models::{ApiResponse, project::{Project, CreateProject, UpdateProject}};
|
||||
use crate::auth::AuthUser;
|
||||
|
||||
pub async fn get_projects(Extension(pool): Extension<PgPool>) -> Result<ResponseJson<ApiResponse<Vec<Project>>>, StatusCode> {
|
||||
pub async fn get_projects(
|
||||
auth: AuthUser,
|
||||
Extension(pool): Extension<PgPool>
|
||||
) -> Result<ResponseJson<ApiResponse<Vec<Project>>>, StatusCode> {
|
||||
match sqlx::query_as!(
|
||||
Project,
|
||||
"SELECT id, name, owner_id, created_at, updated_at FROM projects ORDER BY created_at DESC"
|
||||
@@ -34,6 +37,7 @@ pub async fn get_projects(Extension(pool): Extension<PgPool>) -> Result<Response
|
||||
}
|
||||
|
||||
pub async fn get_project(
|
||||
auth: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
Extension(pool): Extension<PgPool>
|
||||
) -> Result<ResponseJson<ApiResponse<Project>>, StatusCode> {
|
||||
|
||||
@@ -14,6 +14,7 @@ use crate::models::{ApiResponse, task::{Task, CreateTask, UpdateTask, TaskStatus
|
||||
use crate::auth::AuthUser;
|
||||
|
||||
pub async fn get_project_tasks(
|
||||
auth: AuthUser,
|
||||
Path(project_id): Path<Uuid>,
|
||||
Extension(pool): Extension<PgPool>
|
||||
) -> Result<ResponseJson<ApiResponse<Vec<Task>>>, StatusCode> {
|
||||
@@ -41,6 +42,7 @@ pub async fn get_project_tasks(
|
||||
}
|
||||
|
||||
pub async fn get_task(
|
||||
auth: AuthUser,
|
||||
Path((project_id, task_id)): Path<(Uuid, Uuid)>,
|
||||
Extension(pool): Extension<PgPool>
|
||||
) -> Result<ResponseJson<ApiResponse<Task>>, StatusCode> {
|
||||
|
||||
@@ -292,9 +292,29 @@ pub async fn get_current_user(
|
||||
}
|
||||
}
|
||||
|
||||
pub fn users_router() -> Router {
|
||||
pub async fn check_auth_status(
|
||||
auth: AuthUser,
|
||||
) -> ResponseJson<ApiResponse<serde_json::Value>> {
|
||||
ResponseJson(ApiResponse {
|
||||
success: true,
|
||||
data: Some(serde_json::json!({
|
||||
"authenticated": true,
|
||||
"user_id": auth.user_id,
|
||||
"email": auth.email,
|
||||
"is_admin": auth.is_admin
|
||||
})),
|
||||
message: Some("User is authenticated".to_string()),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn public_users_router() -> Router {
|
||||
Router::new()
|
||||
.route("/auth/login", post(login))
|
||||
}
|
||||
|
||||
pub fn protected_users_router() -> Router {
|
||||
Router::new()
|
||||
.route("/auth/status", get(check_auth_status))
|
||||
.route("/auth/me", get(get_current_user))
|
||||
.route("/users", get(get_users).post(create_user))
|
||||
.route("/users/:id", get(get_user).put(update_user).delete(delete_user))
|
||||
|
||||
Reference in New Issue
Block a user