diff --git a/crates/youki/src/logger.rs b/crates/youki/src/logger.rs index 7647c0ee..1f7753e3 100644 --- a/crates/youki/src/logger.rs +++ b/crates/youki/src/logger.rs @@ -8,6 +8,8 @@ use std::io::Write; use std::path::PathBuf; use std::str::FromStr; +const LOG_LEVEL_ENV_NAME: &str = "YOUKI_LOG_LEVEL"; + /// If in debug mode, default level is debug to get maximum logging #[cfg(debug_assertions)] const DEFAULT_LOG_LEVEL: &str = "debug"; @@ -27,13 +29,7 @@ pub fn init( log_file: Option, log_format: Option, ) -> Result<()> { - let filter: Cow = if log_debug_flag { - "debug".into() - } else if let Ok(level) = std::env::var("YOUKI_LOG_LEVEL") { - level.into() - } else { - DEFAULT_LOG_LEVEL.into() - }; + let log_level = detect_log_level(log_debug_flag); let formatter = match log_format.as_deref() { None | Some(LOG_FORMAT_TEXT) => text_write, Some(LOG_FORMAT_JSON) => json_write, @@ -51,7 +47,7 @@ pub fn init( env_logger::Target::Stderr }; env_logger::Builder::new() - .filter_level(LevelFilter::from_str(filter.as_ref()).context("failed to parse log level")?) + .filter_level(log_level.context("failed to parse log level")?) .format(formatter) .target(target) .init(); @@ -59,6 +55,17 @@ pub fn init( Ok(()) } +fn detect_log_level(is_debug: bool) -> Result { + let filter: Cow = if is_debug { + "debug".into() + } else if let Ok(level) = std::env::var(LOG_LEVEL_ENV_NAME) { + level.into() + } else { + DEFAULT_LOG_LEVEL.into() + }; + Ok(LevelFilter::from_str(filter.as_ref())?) +} + fn json_write(f: &mut F, record: &log::Record) -> std::io::Result<()> where F: Write, @@ -93,3 +100,56 @@ where Ok(()) } + +#[cfg(test)] +mod tests { + use serial_test::serial; + + use super::*; + use std::env; + struct LogLevelGuard { + original_level: Option, + } + + impl LogLevelGuard { + fn new(level: &str) -> Result { + let original_level = env::var(LOG_LEVEL_ENV_NAME).ok(); + env::set_var(LOG_LEVEL_ENV_NAME, level); + Ok(Self { original_level }) + } + } + impl Drop for LogLevelGuard { + fn drop(self: &mut LogLevelGuard) { + if let Some(level) = self.original_level.as_ref() { + env::set_var(LOG_LEVEL_ENV_NAME, level); + } else { + env::remove_var(LOG_LEVEL_ENV_NAME); + } + } + } + + #[test] + fn test_detect_log_level_is_debug() { + let _guard = LogLevelGuard::new("error").unwrap(); + assert_eq!(detect_log_level(true).unwrap(), LevelFilter::Debug) + } + + #[test] + #[serial] + fn test_detect_log_level_default() { + let _guard = LogLevelGuard::new("error").unwrap(); + env::remove_var(LOG_LEVEL_ENV_NAME); + if cfg!(debug_assertions) { + assert_eq!(detect_log_level(false).unwrap(), LevelFilter::Debug) + } else { + assert_eq!(detect_log_level(false).unwrap(), LevelFilter::Warn) + } + } + + #[test] + #[serial] + fn test_detect_log_level_from_env() { + let _guard = LogLevelGuard::new("error").unwrap(); + assert_eq!(detect_log_level(false).unwrap(), LevelFilter::Error) + } +}