Skip to content

Commit e095bf9

Browse files
authored
chore: add direct connect handshake (#483)
Momento direct connect uses an http+json -> protosocket negotiation handshake. The http connection retrieves a list of socket addresses that can be used with the connection-oriented protosocket rpc protocol. The ProtosocketCacheClient queries the address list on initial connection, and updates the local list in the background. When a new connection is required, for example due to a disconnect, the local list is used to choose a new socket address for that "slot." Direct connections are lighter than grpc connections, but currently have some compromises. For example, they do not currently support graceful GOAWAY mechanism. Also, the cached local address list may be out of date with reality for a few seconds, causing some possible re-connect attempts to fail until list refresh. This connection scheme will be improved over time, and is expected to outperform grpc for both cost & latency now, and match grpc practically for resilience in the future.
1 parent 8a4f21b commit e095bf9

File tree

10 files changed

+1004
-74
lines changed

10 files changed

+1004
-74
lines changed

Cargo.lock

Lines changed: 581 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,14 @@ base64 = "0.22"
3636
derive_more = { version = "2.0.1", features = ["full"] }
3737
futures = "0"
3838
h2 = { version = "0.4" }
39+
http = { version = "1" }
3940
hyper = { version = "1.6" }
4041
log = "0.4"
4142
protosocket = "0.11.0"
4243
protosocket-prost = "0.11.0"
4344
protosocket-rpc = "0.11.0"
4445
rand = "0.9"
46+
reqwest = { version = "0.12", default-features = false, features = ["http2", "json", "rustls-tls-native-roots"] }
4547
serde = { version = "1.0", features = ["derive"] }
4648
serde_json = "1.0"
4749
thiserror = "2.0"

src/credential_provider.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ struct V1Token {
1111
pub endpoint: String,
1212
}
1313

14-
#[derive(PartialEq, Eq, Clone)]
14+
#[derive(PartialEq, Eq, Clone, Debug)]
1515
pub(crate) enum EndpointSecurity {
1616
Insecure,
1717
Unverified,
@@ -25,6 +25,7 @@ pub struct CredentialProvider {
2525
pub(crate) auth_token: String,
2626
pub(crate) control_endpoint: String,
2727
pub(crate) cache_endpoint: String,
28+
pub(crate) cache_http_endpoint: String,
2829
pub(crate) token_endpoint: String,
2930
pub(crate) endpoint_security: EndpointSecurity,
3031
}
@@ -46,6 +47,7 @@ impl Debug for CredentialProvider {
4647
.field("cache_endpoint", &self.cache_endpoint)
4748
.field("control_endpoint", &self.control_endpoint)
4849
.field("token_endpoint", &self.token_endpoint)
50+
.field("endpoint_security", &self.endpoint_security)
4951
.finish()
5052
}
5153
}
@@ -83,6 +85,11 @@ impl CredentialProvider {
8385
decode_auth_token(token_to_process)
8486
}
8587

88+
/// Returns the hostname that can be used with momento HTTP apis
89+
pub fn cache_http_endpoint(&self) -> &str {
90+
&self.cache_http_endpoint
91+
}
92+
8693
/// Returns a Credential Provider from the provided API key
8794
///
8895
/// # Arguments
@@ -175,6 +182,7 @@ fn process_v1_token(auth_token_bytes: Vec<u8>) -> MomentoResult<CredentialProvid
175182
Ok(CredentialProvider {
176183
auth_token: json.api_key,
177184
cache_endpoint: https_endpoint(get_cache_endpoint(&json.endpoint)),
185+
cache_http_endpoint: https_endpoint(get_cache_http_endpoint(&json.endpoint)),
178186
control_endpoint: https_endpoint(get_control_endpoint(&json.endpoint)),
179187
token_endpoint: https_endpoint(get_token_endpoint(&json.endpoint)),
180188
endpoint_security: EndpointSecurity::Tls,
@@ -185,6 +193,10 @@ fn get_cache_endpoint(endpoint: &str) -> String {
185193
format!("cache.{endpoint}")
186194
}
187195

196+
fn get_cache_http_endpoint(endpoint: &str) -> String {
197+
format!("api.cache.{endpoint}")
198+
}
199+
188200
fn get_control_endpoint(endpoint: &str) -> String {
189201
format!("control.{endpoint}")
190202
}
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
use std::{
2+
collections::HashMap,
3+
net::SocketAddr,
4+
sync::{Arc, Mutex},
5+
};
6+
7+
use crate::CredentialProvider;
8+
9+
// Todo: should make the connections try to balance better than random
10+
#[derive(serde::Deserialize, serde::Serialize, Debug, Default)]
11+
pub(crate) struct Addresses {
12+
#[serde(flatten)]
13+
azs: HashMap<AzId, Vec<Address>>,
14+
}
15+
16+
impl Addresses {
17+
/// Get a list of socket addresses, optionally filtered by availability zone ID.
18+
/// If an az_id is provided, only addresses in that availability zone will be returned,
19+
pub fn for_az(&self, az_id: Option<&str>) -> Vec<SocketAddr> {
20+
if let Some(az_id) = az_id {
21+
if let Some(addresses) = self.azs.get(&AzId(az_id.to_string())) {
22+
if !addresses.is_empty() {
23+
return addresses.iter().map(|a| a.socket_address).collect();
24+
}
25+
}
26+
}
27+
self.azs
28+
.values()
29+
.flat_map(|addresses| addresses.iter().map(|a| a.socket_address))
30+
.collect()
31+
}
32+
}
33+
34+
#[derive(serde::Deserialize, serde::Serialize, Debug, Eq, PartialEq, Hash)]
35+
pub(crate) struct AzId(String);
36+
37+
#[derive(serde::Deserialize, serde::Serialize, Debug)]
38+
pub(crate) struct Address {
39+
socket_address: SocketAddr,
40+
}
41+
42+
#[derive(Debug)]
43+
pub(crate) struct AddressProvider {
44+
addresses: Mutex<Arc<Addresses>>,
45+
client: reqwest::Client,
46+
credential_provider: CredentialProvider,
47+
}
48+
49+
#[derive(Debug, thiserror::Error)]
50+
pub(crate) enum RefreshError {
51+
Reqwest(#[from] reqwest::Error),
52+
Json(#[from] serde_json::Error),
53+
Uri(#[from] http::uri::InvalidUri),
54+
BadStatus((reqwest::StatusCode, String)),
55+
}
56+
impl std::fmt::Display for RefreshError {
57+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58+
match self {
59+
RefreshError::Reqwest(e) => write!(f, "Reqwest error: {e}"),
60+
RefreshError::Json(e) => write!(f, "JSON error: {e}"),
61+
RefreshError::Uri(e) => write!(f, "URI error: {e}"),
62+
RefreshError::BadStatus((status, text)) => write!(f, "Bad status: {status}, {text}"),
63+
}
64+
}
65+
}
66+
67+
impl AddressProvider {
68+
/// Looks for an address list from the provided endpoint.
69+
#[allow(clippy::expect_used)]
70+
pub fn new(credential_provider: CredentialProvider) -> Self {
71+
let client = reqwest::Client::builder()
72+
.tls_built_in_native_certs(true)
73+
.tls_built_in_root_certs(true)
74+
.build()
75+
.expect("must be able to build client");
76+
Self {
77+
addresses: Default::default(),
78+
client,
79+
credential_provider,
80+
}
81+
}
82+
83+
#[allow(clippy::expect_used)]
84+
pub fn get_addresses(&self) -> impl std::ops::Deref<Target = Addresses> {
85+
self.addresses
86+
.lock()
87+
.expect("local mutex must not be poisoned")
88+
.clone()
89+
}
90+
91+
#[allow(clippy::expect_used)]
92+
pub async fn try_refresh_addresses(&self) -> Result<(), RefreshError> {
93+
let request = self
94+
.client
95+
.get(format!(
96+
"{}/endpoints",
97+
self.credential_provider
98+
.cache_http_endpoint()
99+
.trim_end_matches('/')
100+
))
101+
.header("authorization", &self.credential_provider.auth_token)
102+
.build()?;
103+
let response = self.client.execute(request).await?;
104+
105+
if !response.status().is_success() {
106+
let status = response.status();
107+
let text = response.text().await?;
108+
return Err(RefreshError::BadStatus((status, text)));
109+
}
110+
111+
let response = response.text().await?;
112+
let addresses = match serde_json::from_str(&response) {
113+
Ok(addresses) => addresses,
114+
Err(e) => {
115+
log::warn!("error parsing address list JSON: {response}");
116+
return Err(RefreshError::Json(e));
117+
}
118+
};
119+
log::debug!("refreshed address list: {addresses:?}");
120+
let addresses = Arc::new(addresses);
121+
*self
122+
.addresses
123+
.lock()
124+
.expect("local mutex must not be poisoned") = addresses;
125+
Ok(())
126+
}
127+
}

src/protosocket/cache/cache_client.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use protosocket_rpc::client::{ConnectionPool, RpcClient};
33

44
use crate::cache::{GetRequest, SetRequest};
55
use crate::protosocket::cache::cache_client_builder::NeedsDefaultTtl;
6-
use crate::protosocket::cache::utils::ProtosocketConnectionManager;
6+
use crate::protosocket::cache::connection_manager::ProtosocketConnectionManager;
77
use crate::protosocket::cache::{Configuration, MomentoProtosocketRequest};
88
use crate::{utils, IntoBytes, MomentoError, MomentoResult, ProtosocketCacheClientBuilder};
99
use std::convert::TryInto;
@@ -103,7 +103,7 @@ impl ProtosocketCacheClient {
103103
/// # }
104104
/// ```
105105
pub fn builder() -> ProtosocketCacheClientBuilder<NeedsDefaultTtl> {
106-
ProtosocketCacheClientBuilder(NeedsDefaultTtl(()))
106+
super::cache_client_builder::initial()
107107
}
108108

109109
/// Gets an item from a Momento Cache
@@ -211,7 +211,7 @@ impl ProtosocketCacheClient {
211211
&self,
212212
) -> MomentoResult<RpcClient<CacheCommand, CacheResponse>> {
213213
let pooled_client = self.client_pool.get_connection().await.map_err(|e| {
214-
MomentoError::unknown_error("protosocket_connection", Some(e.to_string()))
214+
MomentoError::unknown_error("protosocket_connection", Some(format!("{e:?}")))
215215
})?;
216216
Ok(pooled_client.clone())
217217
}

src/protosocket/cache/cache_client_builder.rs

Lines changed: 49 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::protosocket::cache::utils::ProtosocketConnectionManager;
1+
use crate::protosocket::cache::connection_manager::ProtosocketConnectionManager;
22
use crate::protosocket::cache::Configuration;
33
use crate::{CredentialProvider, MomentoResult, ProtosocketCacheClient};
44
use momento_protos::protosocket::cache::CacheCommand;
@@ -11,11 +11,19 @@ pub type Serializer = ProstSerializer<CacheResponse, CacheCommand>;
1111

1212
/// The initial state of the ProtosocketCacheClientBuilder.
1313
#[derive(PartialEq, Eq, Clone, Debug)]
14-
pub struct ProtosocketCacheClientBuilder<State>(pub State);
14+
pub struct ProtosocketCacheClientBuilder<State> {
15+
state: State,
16+
}
17+
18+
pub(crate) fn initial() -> ProtosocketCacheClientBuilder<NeedsDefaultTtl> {
19+
ProtosocketCacheClientBuilder {
20+
state: NeedsDefaultTtl,
21+
}
22+
}
1523

1624
/// The state of the ProtosocketCacheClientBuilder when it is waiting for a default TTL.
1725
#[derive(PartialEq, Eq, Clone, Debug)]
18-
pub struct NeedsDefaultTtl(pub ());
26+
pub struct NeedsDefaultTtl;
1927

2028
/// The state of the ProtosocketCacheClientBuilder when it is waiting for a configuration.
2129
#[derive(PartialEq, Eq, Clone, Debug)]
@@ -53,7 +61,9 @@ impl ProtosocketCacheClientBuilder<NeedsDefaultTtl> {
5361
self,
5462
default_ttl: Duration,
5563
) -> ProtosocketCacheClientBuilder<NeedsConfiguration> {
56-
ProtosocketCacheClientBuilder(NeedsConfiguration { default_ttl })
64+
ProtosocketCacheClientBuilder {
65+
state: NeedsConfiguration { default_ttl },
66+
}
5767
}
5868
}
5969

@@ -63,10 +73,12 @@ impl ProtosocketCacheClientBuilder<NeedsConfiguration> {
6373
self,
6474
configuration: impl Into<Configuration>,
6575
) -> ProtosocketCacheClientBuilder<NeedsCredentialProvider> {
66-
ProtosocketCacheClientBuilder(NeedsCredentialProvider {
67-
default_ttl: self.0.default_ttl,
68-
configuration: configuration.into(),
69-
})
76+
ProtosocketCacheClientBuilder {
77+
state: NeedsCredentialProvider {
78+
default_ttl: self.state.default_ttl,
79+
configuration: configuration.into(),
80+
},
81+
}
7082
}
7183
}
7284

@@ -76,11 +88,13 @@ impl ProtosocketCacheClientBuilder<NeedsCredentialProvider> {
7688
self,
7789
credential_provider: CredentialProvider,
7890
) -> ProtosocketCacheClientBuilder<NeedsRuntime> {
79-
ProtosocketCacheClientBuilder(NeedsRuntime {
80-
default_ttl: self.0.default_ttl,
81-
configuration: self.0.configuration,
82-
credential_provider,
83-
})
91+
ProtosocketCacheClientBuilder {
92+
state: NeedsRuntime {
93+
default_ttl: self.state.default_ttl,
94+
configuration: self.state.configuration,
95+
credential_provider,
96+
},
97+
}
8498
}
8599
}
86100

@@ -90,28 +104,38 @@ impl ProtosocketCacheClientBuilder<NeedsRuntime> {
90104
self,
91105
runtime: tokio::runtime::Handle,
92106
) -> ProtosocketCacheClientBuilder<ReadyToBuild> {
93-
ProtosocketCacheClientBuilder(ReadyToBuild {
94-
default_ttl: self.0.default_ttl,
95-
runtime,
96-
credential_provider: self.0.credential_provider,
97-
configuration: self.0.configuration,
98-
})
107+
ProtosocketCacheClientBuilder {
108+
state: ReadyToBuild {
109+
default_ttl: self.state.default_ttl,
110+
runtime,
111+
credential_provider: self.state.credential_provider,
112+
configuration: self.state.configuration,
113+
},
114+
}
99115
}
100116
}
101117

102118
impl ProtosocketCacheClientBuilder<ReadyToBuild> {
103119
/// Constructs a new CacheClientBuilder in the ReadyToBuild state.
104120
pub async fn build(self) -> MomentoResult<ProtosocketCacheClient> {
105-
let client_connector =
106-
ProtosocketConnectionManager::new(self.0.credential_provider, self.0.runtime)?;
121+
let ReadyToBuild {
122+
default_ttl,
123+
credential_provider,
124+
runtime,
125+
configuration,
126+
} = self.state;
127+
let client_connector = ProtosocketConnectionManager::new(
128+
credential_provider,
129+
runtime,
130+
configuration.az_id.clone(),
131+
)?;
107132

108-
let client_pool =
109-
ConnectionPool::new(client_connector, self.0.configuration.connection_count());
133+
let client_pool = ConnectionPool::new(client_connector, configuration.connection_count());
110134

111135
Ok(ProtosocketCacheClient::new(
112136
client_pool,
113-
self.0.default_ttl,
114-
self.0.configuration,
137+
default_ttl,
138+
configuration,
115139
))
116140
}
117141
}

0 commit comments

Comments
 (0)