/*
* Copyright (C) 2024 Jonni Liljamo <jonni@liljamo.com>
*
* This file is licensed under GPL-3.0-or-later, see NOTICE and LICENSE for
* more information.
*/
use std::sync::Arc;
use emerwen_protocol::{
emerwen_protocol_server::{EmerwenProtocol, EmerwenProtocolServer},
SetTargetStateRequest, TargetsResponse,
};
use tokio::sync::{mpsc, RwLock};
use tonic::{transport::Server, Request, Response, Status};
use tonic_async_interceptor::async_interceptor;
use tracing::info;
use crate::{db::DatabaseHandle, worker::Worker};
mod auth;
use auth::AuthInterceptor;
/// Identity of the worker that issued a gRPC request.
///
/// `get_targets` looks this up in the request extensions. Presumably the
/// `auth` interceptor inserts it for each incoming request — TODO confirm.
#[derive(Clone)]
struct WorkerContext {
    /// Id of the requesting worker; matched against `Worker::id`.
    worker_id: u32,
}
/// Control messages for a running server, delivered through the channel whose
/// receiver is passed to [`run`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ServerMessage {
    /// Re-read all workers from the database and replace the in-memory worker
    /// list that the gRPC service answers from.
    ReloadWorkers,
}
pub async fn run(
addr: String,
db_handle: DatabaseHandle,
mut server_rx: mpsc::Receiver<ServerMessage>,
) -> Result<(), Box<dyn std::error::Error>> {
let workers = Arc::new(RwLock::new(db_handle.read_workers().await));
let workers_two = workers.clone();
let db_handle_two = db_handle.clone();
tokio::spawn(async move {
let workers = workers.clone();
while let Some(msg) = server_rx.recv().await {
match msg {
ServerMessage::ReloadWorkers => {
let new_workers = db_handle_two.read_workers().await;
let mut write_guard = workers.write().await;
*write_guard = new_workers;
}
}
}
});
Server::builder()
.layer(async_interceptor(AuthInterceptor::new(db_handle.clone())))
.add_service(EmerwenProtocolServer::new(
MasterServer::new(db_handle.clone(), workers_two).await,
))
.serve(addr.parse()?)
.await?;
Ok(())
}
/// Implementation of the `EmerwenProtocol` gRPC service on the master side.
struct MasterServer {
    /// Database handle. It is stored but not read by any method in this
    /// module, hence the leading underscore.
    _db_handle: DatabaseHandle,
    /// Worker list shared with the reload task spawned in `run`, which swaps in
    /// a fresh list from the database on `ServerMessage::ReloadWorkers`.
    workers: Arc<RwLock<Vec<Worker>>>,
}
impl MasterServer {
async fn new(db_handle: DatabaseHandle, workers: Arc<RwLock<Vec<Worker>>>) -> Self {
Self {
_db_handle: db_handle,
workers,
}
}
}
#[tonic::async_trait]
impl EmerwenProtocol for MasterServer {
    /// Returns the targets of the worker that sent the request.
    ///
    /// The caller is identified by the `WorkerContext` found in the request
    /// extensions, and its targets are copied from the shared worker list.
    ///
    /// # Errors
    ///
    /// Returns `Status::internal` when the request has no `WorkerContext`, or
    /// when no loaded worker has the requesting worker's id.
    async fn get_targets(&self, request: Request<()>) -> Result<Response<TargetsResponse>, Status> {
        let worker_id = request
            .extensions()
            .get::<WorkerContext>()
            .map(|ctx| ctx.worker_id)
            .ok_or_else(|| Status::internal("no worker id in request"))?;

        // Scope the read guard so it is released before building the response.
        let targets = {
            let guard = self.workers.read().await;
            guard
                .iter()
                .find(|candidate| candidate.id == worker_id)
                .map(|found| found.targets.clone())
                .ok_or_else(|| Status::internal("no worker exists with request worker id"))?
        };

        Ok(Response::new(TargetsResponse { targets }))
    }

    /// Receives a target state update from a worker.
    ///
    /// Not implemented yet: the request payload is only logged at `info`
    /// level, and an empty response is always returned.
    async fn set_target_state(
        &self,
        request: Request<SetTargetStateRequest>,
    ) -> Result<Response<()>, Status> {
        // TODO: Implement, following the same pattern as `get_targets`.
        let payload = request.into_inner();
        info!("{:?}", payload);
        Ok(Response::new(()))
    }
}