|
1 | 1 | use super::{rejection::*, FromRequestParts}; |
2 | 2 | use async_trait::async_trait; |
3 | | -use http::request::Parts; |
| 3 | +use http::{request::Parts, Uri}; |
4 | 4 | use serde::de::DeserializeOwned; |
5 | 5 |
|
6 | 6 | /// Extractor that deserializes query strings into some type. |
@@ -55,10 +55,38 @@ where |
55 | 55 | type Rejection = QueryRejection; |
56 | 56 |
|
57 | 57 | async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> { |
58 | | - let query = parts.uri.query().unwrap_or_default(); |
59 | | - let value = |
| 58 | + Self::try_from_uri(&parts.uri) |
| 59 | + } |
| 60 | +} |
| 61 | + |
| 62 | +impl<T> Query<T> |
| 63 | +where |
| 64 | + T: DeserializeOwned, |
| 65 | +{ |
| 66 | + /// Attempts to construct a [`Query`] from a reference to a [`Uri`]. |
| 67 | + /// |
| 68 | + /// # Example |
| 69 | + /// ``` |
| 70 | + /// use axum::extract::Query; |
| 71 | + /// use http::Uri; |
| 72 | + /// use serde::Deserialize; |
| 73 | + /// |
| 74 | + /// #[derive(Deserialize)] |
| 75 | + /// struct ExampleParams { |
| 76 | + /// foo: String, |
| 77 | + /// bar: u32, |
| 78 | + /// } |
| 79 | + /// |
| 80 | + /// let uri: Uri = "http://example.com/path?foo=hello&bar=42".parse().unwrap(); |
| 81 | + /// let result: Query<ExampleParams> = Query::try_from_uri(&uri).unwrap(); |
| 82 | + /// assert_eq!(result.foo, String::from("hello")); |
| 83 | + /// assert_eq!(result.bar, 42); |
| 84 | + /// ``` |
| 85 | + pub fn try_from_uri(value: &Uri) -> Result<Self, QueryRejection> { |
| 86 | + let query = value.query().unwrap_or_default(); |
| 87 | + let params = |
60 | 88 | serde_urlencoded::from_str(query).map_err(FailedToDeserializeQueryString::from_err)?; |
61 | | - Ok(Query(value)) |
| 89 | + Ok(Query(params)) |
62 | 90 | } |
63 | 91 | } |
64 | 92 |
|
@@ -137,4 +165,32 @@ mod tests { |
137 | 165 | let res = client.get("/?n=hi").send().await; |
138 | 166 | assert_eq!(res.status(), StatusCode::BAD_REQUEST); |
139 | 167 | } |
| 168 | + |
| 169 | + #[test] |
| 170 | + fn test_try_from_uri() { |
| 171 | + #[derive(Deserialize)] |
| 172 | + struct TestQueryParams { |
| 173 | + foo: String, |
| 174 | + bar: u32, |
| 175 | + } |
| 176 | + let uri: Uri = "http://example.com/path?foo=hello&bar=42".parse().unwrap(); |
| 177 | + let result: Query<TestQueryParams> = Query::try_from_uri(&uri).unwrap(); |
| 178 | + assert_eq!(result.foo, String::from("hello")); |
| 179 | + assert_eq!(result.bar, 42); |
| 180 | + } |
| 181 | + |
| 182 | + #[test] |
| 183 | + fn test_try_from_uri_with_invalid_query() { |
| 184 | + #[derive(Deserialize)] |
| 185 | + struct TestQueryParams { |
| 186 | + _foo: String, |
| 187 | + _bar: u32, |
| 188 | + } |
| 189 | + let uri: Uri = "http://example.com/path?foo=hello&bar=invalid" |
| 190 | + .parse() |
| 191 | + .unwrap(); |
| 192 | + let result: Result<Query<TestQueryParams>, _> = Query::try_from_uri(&uri); |
| 193 | + |
| 194 | + assert!(result.is_err()); |
| 195 | + } |
140 | 196 | } |
0 commit comments