DEVELOPMENT ENVIRONMENT

~liljamo/canwa

ref: 3a2ad3d4f4147f683ac78acab88597433a1430a2 canwa/src/main.rs -rw-r--r-- 3.5 KiB
3a2ad3d4 — Jonni Liljamo: feat: optionally read tokens from files (a month ago)
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
/*
 * Copyright (C) 2025 Jonni Liljamo <jonni@liljamo.com>
 *
 * This file is licensed under AGPL-3.0-or-later, see NOTICE and LICENSE for
 * more information.
 */

use std::sync::Arc;

use axum::{
    Json, Router,
    http::HeaderMap,
    response::{Html, IntoResponse},
    routing::{get, post},
};
use clap::Parser;
use reqwest::StatusCode;
use serde::Deserialize;
use tokio::net::TcpListener;
use tower_http::trace::TraceLayer;
use tracing::{error, info};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

mod config;
use config::Config;
mod state;
use state::State;

mod service;

const LICENSE_HTML: &str = include_str!("../static/license.html");

// Command-line arguments for the binary, parsed by clap's derive API.
// NOTE: the `///` comments on the fields are rendered by clap as `--help` text, so their
// wording must stay as-is; plain `//` comments are used here so the struct itself gets no
// `about` text.
#[derive(Parser)]
#[command(version)]
struct Args {
    /// Config file location.
    #[arg(short = 'c', long = "config", default_value = "./canwa.toml")]
    config: String,
    /// Don't run the main program, useful for config validation.
    #[arg(long = "dry-run")]
    dry_run: bool,
}

#[tokio::main]
async fn main() {
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
                format!(
                    "{}=debug,tower_http=debug,axum::rejection=trace",
                    env!("CARGO_CRATE_NAME")
                )
                .into()
            }),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    let args: Args = Args::parse();
    let config: Config = Config::from_path(&args.config).await.unwrap();
    let state: Arc<State> = Arc::new(State::from_config(&config).unwrap());

    if args.dry_run {
        std::process::exit(0);
    }

    let router = Router::new()
        .route("/", get(|| async { "canwa" }))
        .route(
            "/license",
            get(|| async { (StatusCode::OK, Html::from(LICENSE_HTML)) }),
        )
        .route(
            "/message",
            post({
                let shared_state = Arc::clone(&state);
                move |headers, body| message(shared_state, headers, body)
            }),
        )
        .layer(TraceLayer::new_for_http())
        .with_state(state);

    let addr = format!("{}:{}", config.interface, config.port);
    info!(msg="serving http", %addr);
    let listener = TcpListener::bind(addr).await.unwrap();
    axum::serve(listener, router).await.unwrap();
}

/// JSON request body accepted by the `POST /message` handler.
///
/// Both fields are required strings; the handler forwards them unchanged to each selected
/// service's `send` call as the title and the message text.
#[derive(Deserialize)]
struct MessageForm {
    /// Title of the notification.
    title: String,
    /// Body text of the notification.
    message: String,
}

async fn message(
    state: Arc<State>,
    headers: HeaderMap,
    Json(message): Json<MessageForm>,
) -> impl IntoResponse {
    let token = match headers.get("Authorization") {
        Some(token) => match token.to_str() {
            Ok(token) => token,
            Err(_) => {
                return (StatusCode::UNAUTHORIZED, "unauthorized");
            }
        },
        None => {
            return (StatusCode::UNAUTHORIZED, "unauthorized");
        }
    };

    let notifier = match state.notifiers.iter().find(|(_k, v)| v.token == token) {
        Some(n) => n,
        None => return (StatusCode::UNAUTHORIZED, "unauthorized"),
    };

    info!(msg = "message", notifier = notifier.0);

    for (_k, v) in state
        .services
        .iter()
        .filter(|(k, _v)| notifier.1.services.contains(k))
    {
        match v.send(&message.title, &message.message).await {
            Ok(_) => {}
            Err(err) => {
                error!(msg = "message sending failed", ?err);
                return (StatusCode::INTERNAL_SERVER_ERROR, "failed to send message");
            }
        }
    }

    (StatusCode::OK, "")
}