// athena_cli/commands/query.rs

//! Query execution module for Athena CLI.
//!
//! This module provides functionality to:
//! - Execute SQL queries against AWS Athena
//! - Retrieve and display query results
//! - Monitor query execution status and statistics
//! - Handle result pagination and data formatting
//!
//! ## Usage Examples
//!
//! Simple query:
//!
//! ```bash
//! athena-cli query "SELECT * FROM my_table"
//! ```
//!
//! Query with database and workgroup specified:
//!
//! ```bash
//! athena-cli -d my_database -w my_workgroup query "SELECT * FROM my_table"
//! ```
//!
//! Query with custom result reuse time:
//!
//! ```bash
//! athena-cli query --reuse-time 2h "SELECT * FROM my_table"
//! ```
//!
//! Query with output location:
//!
//! ```bash
//! athena-cli --output-location s3://my-bucket/results/ query "SELECT * FROM my_table"
//! ```
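//!
//! The same entry point can also be driven programmatically. A minimal sketch, assuming
//! a `Context` and `cli::QueryArgs` have already been constructed by the CLI layer
//! (construction not shown; the module path is inferred from this file's location):
//!
//! ```rust,ignore
//! // `ctx: &Context` and `args: &QueryArgs` come from the surrounding application.
//! crate::commands::query::execute(ctx, args).await?;
//! ```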

use crate::cli;
use crate::context::Context;
use crate::validation;
use anyhow::Result;
use aws_sdk_athena::types::{
    QueryExecutionContext, QueryExecutionState, ResultConfiguration, ResultReuseByAgeConfiguration,
    ResultReuseConfiguration,
};
use aws_sdk_athena::Client;
use byte_unit::Byte;
use colored::Colorize;
use polars::prelude::*;
use std::time::Duration;

/// Executes an Athena SQL query and displays the results.
///
/// # Arguments
///
/// * `ctx` - The application context containing configuration and connection details
/// * `args` - Command line arguments including the SQL query text and reuse time
///
/// # Returns
///
/// Returns a Result indicating success or failure of the query execution
///
/// # Features
///
/// * Configurable query result reuse (caching) duration
/// * Displays query statistics including data scanned and cache status
/// * Supports pagination for large result sets
/// * Collects results into a Polars DataFrame and prints them
///
/// # Examples
///
/// Basic query example:
///
/// ```bash
/// athena-cli query "SELECT * FROM my_database.my_table LIMIT 10"
/// ```
///
/// Using result reuse/caching (30 minutes):
///
/// ```bash
/// athena-cli query --reuse-time 30m "SELECT count(*) FROM my_table"
/// ```
///
/// Query with specific database:
///
/// ```bash
/// athena-cli -d my_database query "SELECT * FROM my_table WHERE id=123"
/// ```
///
/// Query with custom workgroup and output location:
///
/// ```bash
/// athena-cli -w my_workgroup --output-location s3://my-bucket/results/ query "SELECT * FROM my_table"
/// ```
pub async fn execute(ctx: &Context, args: &cli::QueryArgs) -> Result<()> {
    println!("Executing query: {}", args.query);

    // Validate SQL syntax before sending to Athena
    if let Err(e) = validation::validate_query_syntax(&args.query) {
        println!("{}", "SQL syntax validation failed".red().bold());
        return Err(e);
    }

    let database = ctx
        .database()
        .ok_or_else(|| anyhow::anyhow!("Database name is required but was not provided"))?;

    let client = ctx.create_athena_client();

    let query_id = start_query(
        &client,
        &database,
        &args.query,
        &ctx.workgroup(),
        args.reuse_time,
        ctx.output_location()
            .as_deref()
            .unwrap_or("s3://aws-athena-query-results"),
    )
    .await?;

    println!("Query execution ID: {}", query_id);

    let df = get_query_results(&client, &query_id).await?;
    println!("Results DataFrame:");
    println!("{}", df);

    Ok(())
}

/// Starts an Athena query execution with the specified parameters and returns the execution ID.
///
/// # Arguments
///
/// * `client` - The AWS Athena SDK client
/// * `database` - The database to query against
/// * `query` - The SQL query string to execute
/// * `workgroup` - The Athena workgroup to use
/// * `reuse_duration` - Duration for which query results should be reused/cached
/// * `output_location` - S3 location where query results will be stored
///
/// # Returns
///
/// Returns a Result containing the query execution ID as a String
///
/// # Implementation Details
///
/// * Configures the query context with database and output location
/// * Sets up result reuse configuration based on the provided duration
/// * Returns the execution ID that can be used to track and retrieve results
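///
/// # Example
///
/// A minimal sketch of a call site (not compiled; assumes an already-configured
/// `aws_sdk_athena::Client`, a tokio runtime, and a hypothetical results bucket):
///
/// ```rust,ignore
/// let query_id = start_query(
///     &client,
///     "my_database",
///     "SELECT * FROM my_table LIMIT 10",
///     "primary",                        // Athena workgroup
///     Duration::from_secs(30 * 60),     // reuse cached results for up to 30 minutes
///     "s3://my-bucket/athena-results/", // hypothetical output location
/// )
/// .await?;
/// ```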
async fn start_query(
    client: &Client,
    database: &str,
    query: &str,
    workgroup: &str,
    reuse_duration: Duration,
    output_location: &str,
) -> Result<String> {
    let context = QueryExecutionContext::builder().database(database).build();

    let config = ResultConfiguration::builder()
        .output_location(output_location)
        .build();

    let result = client
        .start_query_execution()
        .result_reuse_configuration(
            ResultReuseConfiguration::builder()
                .result_reuse_by_age_configuration(
                    ResultReuseByAgeConfiguration::builder()
                        .enabled(true)
                        .max_age_in_minutes(reuse_duration.as_secs() as i32 / 60)
                        .build(),
                )
                .build(),
        )
        .query_string(query)
        .query_execution_context(context)
        .result_configuration(config)
        .work_group(workgroup)
        .send()
        .await?;

    Ok(result.query_execution_id().unwrap_or_default().to_string())
}

/// Retrieves query results and converts them to a Polars DataFrame.
///
/// # Arguments
///
/// * `client` - The AWS Athena SDK client
/// * `query_execution_id` - The execution ID of the query whose results to retrieve
///
/// # Returns
///
/// Returns a Result containing a Polars DataFrame with the query results
///
/// # Behavior
///
/// * Polls the query execution until it succeeds, fails, or is cancelled
/// * Displays query statistics including data scanned and cache status
/// * Paginates through results if they span multiple pages (100 rows per page)
/// * Converts query results to a Polars DataFrame for analysis and display
///
/// # Error Handling
///
/// * Returns an error if the query fails or is cancelled
/// * Handles partial results and pagination automatically
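///
/// # Example
///
/// A minimal sketch of a call site (not compiled; `query_id` comes from `start_query`):
///
/// ```rust,ignore
/// let df = get_query_results(&client, &query_id).await?;
/// // The returned DataFrame can be inspected like any other Polars frame.
/// println!("{} rows x {} columns", df.height(), df.width());
/// ```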
async fn get_query_results(client: &Client, query_execution_id: &str) -> Result<DataFrame> {
    // Wait for query to complete
    loop {
        let status = client
            .get_query_execution()
            .query_execution_id(query_execution_id)
            .send()
            .await?;

        if let Some(execution) = status.query_execution() {
            match execution.status().and_then(|s| s.state()) {
                Some(QueryExecutionState::Succeeded) => {
                    // Print query info once before breaking
                    if let Some(result_config) = execution.result_configuration() {
                        if let Some(output_location) = result_config.output_location() {
                            println!("Results S3 path: {}", output_location);
                        }
                    }

                    if let Some(statistics) = execution.statistics() {
                        let data_scanned = statistics.data_scanned_in_bytes().unwrap_or(0);
                        let is_cached = data_scanned == 0;
                        println!(
                            "Query cache status: {}",
                            if is_cached {
                                String::from("Results retrieved from cache")
                            } else {
                                let formatted_size = Byte::from_i64(data_scanned)
                                    .map(|b| {
                                        b.get_appropriate_unit(byte_unit::UnitType::Decimal)
                                            .to_string()
                                    })
                                    .unwrap_or_else(|| "-".to_string());
                                format!("Fresh query execution (scanned {})", formatted_size)
                            }
                        );
                    }
                    break;
                }
                Some(QueryExecutionState::Failed) | Some(QueryExecutionState::Cancelled) => {
                    let error_message = if let Some(status) = execution.status() {
                        if let Some(reason) = status.state_change_reason() {
                            format!("Query failed: {}", reason)
                        } else {
                            "Query failed or was cancelled without specific reason".to_string()
                        }
                    } else {
                        "Query failed or was cancelled".to_string()
                    };
                    return Err(anyhow::anyhow!("{}", error_message.red().bold()));
                }
                _ => {
                    // Yield to the async runtime instead of blocking the thread while polling.
                    tokio::time::sleep(Duration::from_secs(1)).await;
                    continue;
                }
            }
        }
    }

    let mut all_columns: Vec<Vec<String>> = Vec::new();
    let mut column_names: Vec<String> = Vec::new();
    let mut next_token: Option<String> = None;

    // Get first page and column names
    let mut results = client
        .get_query_results()
        .query_execution_id(query_execution_id)
        .max_results(100)
        .send()
        .await?;

    // Initialize column names from first result
    if let Some(rs) = results.result_set() {
        if let Some(first_row) = rs.rows().first() {
            column_names = first_row
                .data()
                .iter()
                .map(|d| d.var_char_value().unwrap_or_default().to_string())
                .collect();
            all_columns = vec![Vec::new(); column_names.len()];
        }
    }

    // Process results page by page
    let mut page_count = 1;
    loop {
        if let Some(rs) = results.result_set() {
            // The first page repeats the column headers as its first row; skip it.
            let start_idx = if next_token.is_none() { 1 } else { 0 };
            let rows_count = rs.rows().len().saturating_sub(start_idx);

            println!("Processing page {}: {} rows", page_count, rows_count);

            for row in rs.rows().iter().skip(start_idx) {
                for (i, data) in row.data().iter().enumerate() {
                    all_columns[i].push(data.var_char_value().unwrap_or_default().to_string());
                }
            }
        }

        next_token = results.next_token().map(|s| s.to_string());

        if next_token.is_none() {
            println!(
                "Finished processing {} pages, total rows: {}",
                page_count,
                all_columns.first().map_or(0, |c| c.len())
            );
            break;
        }

        page_count += 1;
        results = client
            .get_query_results()
            .query_execution_id(query_execution_id)
            .max_results(100)
            .next_token(next_token.as_ref().unwrap())
            .send()
            .await?;
    }

    // Convert the accumulated string columns into Polars columns and build the DataFrame.
    let columns = all_columns
        .iter()
        .zip(column_names.iter())
        .map(|(col, name)| Series::new(name.into(), col))
        .map(|s| s.into_column())
        .collect();

    Ok(DataFrame::new(columns)?)
}
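
// A minimal sketch of the column-to-DataFrame conversion used in `get_query_results`,
// written as a unit test so it can run without AWS access (assumes the same Polars API
// that the code above already relies on).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builds_dataframe_from_string_columns() {
        // Two string columns, mirroring how Athena rows are accumulated above.
        let column_names = vec!["id".to_string(), "name".to_string()];
        let all_columns = vec![
            vec!["1".to_string(), "2".to_string()],
            vec!["alice".to_string(), "bob".to_string()],
        ];

        let columns = all_columns
            .iter()
            .zip(column_names.iter())
            .map(|(col, name)| Series::new(name.as_str().into(), col).into_column())
            .collect();

        let df = DataFrame::new(columns).expect("columns of equal length form a DataFrame");
        assert_eq!(df.shape(), (2, 2));
    }
}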