use anyhow::{Context, Result};
use directories::ProjectDirs;
use serde::{Deserialize, Serialize};
use std::{path::PathBuf, time::Duration};
5
6#[derive(Debug, Serialize, Deserialize)]
7pub struct Config {
8 pub aws: AwsConfig,
9 pub app: AppConfig,
10}
11
12#[derive(Debug, Serialize, Deserialize)]
13pub struct AwsConfig {
14 pub region: Option<String>,
15 pub workgroup: Option<String>,
16 pub output_location: String,
17 pub catalog: Option<String>,
18 pub database: Option<String>,
19 pub profile: Option<String>,
20}
21
22#[derive(Debug, Serialize, Deserialize)]
23pub struct AppConfig {
24 #[serde(with = "humantime_serde")]
25 pub query_reuse_time: Duration,
26 pub max_rows: usize,
27 #[serde(default = "default_history_size")]
29 pub history_size: i32,
30 #[serde(default)]
32 pub history_fields: Option<Vec<String>>,
33 #[serde(default)]
35 pub inspect_fields: Option<Vec<String>>,
36}
37
/// History length used when `history_size` is missing from the config file.
const DEFAULT_HISTORY_SIZE: i32 = 20;

/// Serde default for `AppConfig::history_size`.
fn default_history_size() -> i32 {
    DEFAULT_HISTORY_SIZE
}
41
42impl Default for Config {
43 fn default() -> Self {
44 Self {
45 aws: AwsConfig {
46 region: Some("eu-west-1".to_string()),
47 workgroup: Some("primary".to_string()),
48 output_location: "s3://athena-query-results/".to_string(),
49 catalog: Some("AwsDataCatalog".to_string()),
50 database: None,
51 profile: None,
52 },
53 app: AppConfig {
54 query_reuse_time: Duration::from_secs(3600), max_rows: 1000,
56 history_size: 20,
57 history_fields: None,
58 inspect_fields: None,
59 },
60 }
61 }
62}
63
64impl Config {
65 pub fn load() -> Result<Self> {
66 let config_path = get_config_path()?;
67
68 println!("Looking for config at: {}", config_path.display());
69
70 if !config_path.exists() {
71 println!("Config file not found, creating default");
72 let config = Config::default();
73 std::fs::create_dir_all(config_path.parent().unwrap())?;
74 std::fs::write(&config_path, toml::to_string_pretty(&config)?)?;
75 return Ok(config);
76 }
77
78 println!("Loading config from: {}", config_path.display());
79 let config = config::Config::builder()
80 .add_source(config::File::from(config_path))
81 .build()?;
82
83 let config: Config = config.try_deserialize()?;
84 println!("Loaded workgroup: {:?}", config.aws.workgroup);
85
86 Ok(config)
87 }
88}
89
90fn get_config_path() -> Result<PathBuf> {
91 if let Ok(home) = std::env::var("HOME") {
93 return Ok(PathBuf::from(home).join(".config/athena-cli/config.toml"));
94 }
95
96 let proj_dirs = ProjectDirs::from("com", "your-org", "athena-cli")
98 .ok_or_else(|| anyhow::anyhow!("Could not determine config directory"))?;
99
100 Ok(proj_dirs.config_dir().join("config.toml"))
101}