diff --git a/tokio-postgres/src/error/mod.rs b/tokio-postgres/src/error/mod.rs index 888bf7f23..619ae5b4d 100644 --- a/tokio-postgres/src/error/mod.rs +++ b/tokio-postgres/src/error/mod.rs @@ -83,6 +83,7 @@ pub struct DbError { file: Option>, line: Option, routine: Option>, + statement: Option>, } impl DbError { @@ -192,6 +193,7 @@ impl DbError { file, line, routine, + statement: None, }) } @@ -242,6 +244,43 @@ impl DbError { self.position.as_ref() } + /// Format the position with an arrow and at most one context line + /// before and after the error. + pub fn format_position(&self) -> Option { + let (sql, pos) = match self.position()? { + ErrorPosition::Original(idx) => (self.statement.as_deref()?, *idx), + ErrorPosition::Internal { position, query } => (query.as_str(), *position), + }; + // This should not fail as long as postgres gives us a valid byte index. + let (before, after) = sql.split_at_checked(pos.saturating_sub(1) as usize)?; + + // Don't use `.lines()` because it removes the last line if it is empty. + // `.split('\n')` always returns at least one item. + let before: Vec<&str> = before.trim_start().split('\n').collect(); + let after: Vec<&str> = after.trim_end().split('\n').collect(); + + // `before.len().saturating_sub(2)..` is always in range, so unwrap would also work. + let mut out = before + .get(before.len().saturating_sub(2)..) + .unwrap_or_default() + .join("\n"); + + // `after` always has at least one item, so unwrap would also work. + out.push_str(after.first().copied().unwrap_or_default()); + + // `before` always has at least one item, so unwrap would also work. + // Count chars because we care about the printed width with monospace font. + // This is not perfect, but good enough. + let indent = before.last().copied().unwrap_or_default().chars().count(); + out = format!("{out}\n{:width$}^", "", width = indent); + + if let Some(after_str) = after.get(1).copied() { + out = format!("{out}\n{after_str}") + } + + Some(out) + } + /// An indication of the context in which the error occurred. /// /// Presently this includes a call stack traceback of active procedural @@ -316,6 +355,9 @@ impl fmt::Display for DbError { if let Some(hint) = &self.hint { write!(fmt, "\nHINT: {}", hint)?; } + if let Some(sql) = self.format_position() { + write!(fmt, "\n{}", sql)?; + } Ok(()) } } @@ -647,6 +689,13 @@ impl Error { Error::new(Kind::RowCount { expected, got }) } + pub(crate) fn with_statement(mut self, sql: &str) -> Error { + if let Kind::Db(x) = &mut self.0.kind { + x.statement = Some(sql.to_owned().into_boxed_str()); + } + self + } + #[cfg(feature = "runtime")] pub(crate) fn connect(e: io::Error) -> Error { Error::new(Kind::Connect(e)) diff --git a/tokio-postgres/src/prepare.rs b/tokio-postgres/src/prepare.rs index 8df64b086..a3c475100 100644 --- a/tokio-postgres/src/prepare.rs +++ b/tokio-postgres/src/prepare.rs @@ -66,7 +66,11 @@ pub async fn prepare( let buf = encode(client, &name, query, types)?; let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; - match responses.next().await? { + match responses + .next() + .await + .map_err(|e| e.with_statement(query))? + { Message::ParseComplete => {} _ => return Err(Error::unexpected_message()), }