/*
* Copyright (C) 2024 Jonni Liljamo <jonni@liljamo.com>
*
* This file is licensed under GPL-3.0-or-later, see NOTICE and LICENSE for
* more information.
*/
use std::{future::Future, pin::Pin};
use tonic::{Request, Status};
use tonic_async_interceptor::AsyncInterceptor;
use super::WorkerContext;
use crate::db::DatabaseHandle;
/// gRPC interceptor that authenticates incoming worker requests.
///
/// Each request's bearer token is looked up against the workers stored in the
/// database reachable through `db_handle`.
#[derive(Clone)]
pub struct AuthInterceptor {
    /// Handle used to read the registered workers and their auth tokens.
    db_handle: DatabaseHandle,
}
impl AuthInterceptor {
    /// Creates an interceptor that authenticates requests against the workers
    /// readable through `db_handle`.
    pub fn new(db_handle: DatabaseHandle) -> Self {
        Self { db_handle }
    }

    /// Authenticates a request using its `authorization` metadata.
    ///
    /// The metadata value is expected to look like `Bearer <token>`. If the
    /// `Bearer ` prefix is absent, the whole value is used as the token. The
    /// token is compared against the `auth_token` of every worker returned by
    /// `read_workers()`. On a match, a [`WorkerContext`] carrying that worker's
    /// id is inserted into the request extensions for downstream handlers.
    ///
    /// The returned future owns a clone of the database handle rather than
    /// borrowing `self`, which is what allows it to be `'static`.
    ///
    /// # Errors
    ///
    /// - `Status::unauthenticated` if the `authorization` metadata is missing,
    ///   the extracted token is empty, or no worker has a matching token.
    /// - `Status::invalid_argument` if the metadata value cannot be read as a
    ///   string.
    fn authenticate(
        &self,
        mut request: Request<()>,
    ) -> impl Future<Output = Result<Request<()>, Status>> + Send + 'static {
        // Clone before entering the async block so the future does not borrow `self`.
        let db_handle = self.db_handle.clone();
        async move {
            let header = request
                .metadata()
                .get("authorization")
                .ok_or_else(|| Status::unauthenticated("request had no authorization"))?
                .to_str()
                .map_err(|_| Status::invalid_argument("couldn't read bearer auth to string"))?;

            // Strip the scheme exactly once. `trim_start_matches` would strip it
            // repeatedly, turning "Bearer Bearer x" into "x". A value without the
            // prefix is still accepted as the raw token.
            let token = header.strip_prefix("Bearer ").unwrap_or(header);

            // Never let an empty token through: it would match any worker whose
            // stored auth token happens to be empty.
            if token.is_empty() {
                return Err(Status::unauthenticated("bad token"));
            }

            let workers = db_handle.read_workers().await;
            // NOTE(review): `==` on strings is not a constant-time comparison;
            // consider one if token-guessing via timing is a concern here.
            let worker_id = workers
                .iter()
                .find(|worker| worker.auth_token == token)
                .map(|worker| worker.id)
                .ok_or_else(|| Status::unauthenticated("bad token"))?;

            request.extensions_mut().insert(WorkerContext { worker_id });
            // The turbofish fixes the async block's error type for the `?`s above.
            Ok::<_, Status>(request)
        }
    }
}
impl AsyncInterceptor for AuthInterceptor {
type Future = Pin<Box<dyn Future<Output = Result<Request<()>, Status>> + Send + 'static>>;
fn call(&mut self, request: Request<()>) -> Self::Future {
Box::pin(self.authenticate(request))
}
}