diff --git a/Cargo.lock b/Cargo.lock index c41772e..e2fbccb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -139,6 +139,7 @@ dependencies = [ "tokio", "tracing-subscriber", "url", + "void", ] [[package]] @@ -2365,6 +2366,12 @@ version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5a972e5669d67ba988ce3dc826706fb0a8b01471c088cb0b6110b805cc36aed" +[[package]] +name = "void" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d" + [[package]] name = "want" version = "0.3.0" diff --git a/Cargo.toml b/Cargo.toml index 7fe7e51..eb687c7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ serde_yaml = "0.8" serde = "1.0" systemd = "0.8" thiserror = "1.0" +void = "1" [dependencies.matrix-sdk] git = "https://github.com/matrix-org/matrix-rust-sdk" diff --git a/src/bot.rs b/src/bot.rs index b2a6387..aa97d32 100644 --- a/src/bot.rs +++ b/src/bot.rs @@ -100,7 +100,7 @@ impl BadNewsBot { const KEY_MESSAGE: &str = "MESSAGE"; if let Some(unit) = record.get(KEY_UNIT) { - if !self.config.units.contains(unit) { + if !self.config.units.iter().map(|u| &u.name).any(|name| name == unit) { return; } diff --git a/src/config.rs b/src/config.rs index 9d99a37..0a06c3a 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,8 +1,12 @@ use matrix_sdk::identifiers::RoomId; -use serde::Deserialize; -use std::collections::HashSet; +use serde::de::{self, MapAccess, Visitor}; +use serde::{Deserialize, Deserializer}; +use std::fmt; +use std::marker::PhantomData; use std::path::PathBuf; +use std::str::FromStr; use url::Url; +use void::Void; /// Holds the configuration for the bot. #[derive(Clone, Deserialize)] @@ -19,5 +23,71 @@ pub struct Config { /// invitations to this room. pub room_id: RoomId, /// Units to watch for logs - pub units: HashSet, + pub units: Vec, +} + +/// Holds a single unit's configuration. +#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] +#[serde(from = "SerializedUnit")] +pub struct Unit { + /// Can be serialized from a string only instead of a map. + pub name: String, + /// Regex to filter each line read from the unit's logs. + pub filter: Option, // FIXME: regex +} + +#[derive(Debug, Deserialize)] +#[serde(transparent)] +struct SerializedUnit(#[serde(deserialize_with = "unit_name_or_struct")] Unit); + +impl From for Unit { + fn from(s: SerializedUnit) -> Self { + s.0 + } +} + +impl FromStr for Unit { + type Err = Void; + + fn from_str(s: &str) -> Result { + Ok(Unit { + name: s.to_string(), + filter: None, + }) + } +} + +fn unit_name_or_struct<'de, T, D>(deserializer: D) -> Result +where + T: Deserialize<'de> + FromStr, + D: Deserializer<'de>, +{ + struct StringOrStruct(PhantomData T>); + + impl<'de, T> Visitor<'de> for StringOrStruct + where + T: Deserialize<'de> + FromStr, + { + type Value = T; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("string or map") + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + Ok(FromStr::from_str(value).unwrap()) + } + + fn visit_map(self, map: M) -> Result + where + M: MapAccess<'de>, + { + Deserialize::deserialize(de::value::MapAccessDeserializer::new(map)) + } + } + + deserializer.deserialize_any(StringOrStruct(PhantomData)) }