// Copyright (C) Moondance Labs Ltd.
// This file is part of Tanssi.

// Tanssi is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.

// Tanssi is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with Tanssi.  If not, see <http://www.gnu.org/licenses/>.
16

            
17
use {
18
    futures::{
19
        future::BoxFuture,
20
        stream::{FuturesUnordered, StreamExt},
21
        FutureExt,
22
    },
23
    jsonrpsee::{
24
        core::{
25
            client::{Client as JsonRpcClient, ClientT as _, Subscription},
26
            params::ArrayParams,
27
            ClientError as JsonRpseeError, JsonValue,
28
        },
29
        ws_client::WsClientBuilder,
30
    },
31
    sc_rpc_api::chain::ChainApiClient,
32
    schnellru::{ByLength, LruMap},
33
    std::sync::Arc,
34
    tokio::sync::{mpsc, oneshot},
35
};
36

            
37
/// Log target attached to every `tracing` event emitted from this file.
const LOG_TARGET: &str = "reconnecting-websocket-client-orchestrator";
38

            
39
/// Boxed future driving a single JSON-RPC request. It resolves to `Ok(())` once the response has
/// been handed to the caller, or to `Err` carrying the original request when the request has to
/// be re-sent through another connection.
type RpcRequestFuture = BoxFuture<'static, Result<(), JsonRpcRequest>>;
40

            
41
/// A Json Rpc/Rpsee request with a oneshot sender to send the request's response.
42
pub struct JsonRpcRequest {
43
    pub method: String,
44
    pub params: ArrayParams,
45
    pub response_sender: oneshot::Sender<Result<JsonValue, JsonRpseeError>>,
46
}
47

            
48
pub enum WsClientRequest {
49
    JsonRpcRequest(JsonRpcRequest),
50
    RegisterBestHeadListener(mpsc::Sender<dp_core::Header>),
51
    RegisterImportListener(mpsc::Sender<dp_core::Header>),
52
    RegisterFinalizationListener(mpsc::Sender<dp_core::Header>),
53
}
54

            
55
enum ConnectionStatus {
56
    Connected,
57
    Disconnected {
58
        failed_request: Option<JsonRpcRequest>,
59
    },
60
}
61

            
62
/// Worker that manage a WebSocket connection and handle disconnects by changing endpoint and
63
/// retrying pending requests.
64
///
65
/// Is first created with [`ReconnectingWsClientWorker::new`], which returns both a
66
/// [`ReconnectingWsClientWorker`] and an [`mpsc::Sender`] to send the requests.
67
/// [`ReconnectingWsClientWorker::run`] must the be called and the returned future queued in
68
/// a tokio executor.
69
pub struct ReconnectingWsClientWorker {
70
    urls: Vec<String>,
71
    active_client: Arc<JsonRpcClient>,
72
    active_index: usize,
73

            
74
    request_receiver: mpsc::Receiver<WsClientRequest>,
75

            
76
    imported_header_listeners: Vec<mpsc::Sender<dp_core::Header>>,
77
    finalized_header_listeners: Vec<mpsc::Sender<dp_core::Header>>,
78
    best_header_listeners: Vec<mpsc::Sender<dp_core::Header>>,
79
}
80

            
81
struct OrchestratorSubscription {
82
    import_subscription: Subscription<dp_core::Header>,
83
    finalized_subscription: Subscription<dp_core::Header>,
84
    best_subscription: Subscription<dp_core::Header>,
85
}
86

            
87
/// Connects to a ws server by cycle throught all provided urls from the starting position until
88
/// each one one was tried. Stops once a connection was succesfully made.
89
async fn connect_next_available_rpc_server(
90
    urls: &[String],
91
    starting_position: usize,
92
) -> Result<(usize, Arc<JsonRpcClient>), ()> {
93
    tracing::debug!(target: LOG_TARGET, starting_position, "Connecting to RPC server.");
94

            
95
    for (counter, url) in urls
96
        .iter()
97
        .cycle()
98
        .skip(starting_position)
99
        .take(urls.len())
100
        .enumerate()
101
    {
102
        let index = (starting_position + counter) % urls.len();
103

            
104
        tracing::info!(
105
            target: LOG_TARGET,
106
            index,
107
            url,
108
            "Trying to connect to next external orchestrator node.",
109
        );
110

            
111
        match WsClientBuilder::default().build(&url).await {
112
            Ok(ws_client) => return Ok((index, Arc::new(ws_client))),
113
            Err(err) => tracing::debug!(target: LOG_TARGET, url, ?err, "Unable to connect."),
114
        };
115
    }
116
    Err(())
117
}
118

            
119
impl ReconnectingWsClientWorker {
120
    /// Create a new worker that will connect to the provided URLs.
121
    pub async fn new(urls: Vec<String>) -> Result<(Self, mpsc::Sender<WsClientRequest>), ()> {
122
        if urls.is_empty() {
123
            return Err(());
124
        }
125

            
126
        let (active_index, active_client) = connect_next_available_rpc_server(&urls, 0).await?;
127
        let (request_sender, request_receiver) = mpsc::channel(100);
128

            
129
        Ok((
130
            Self {
131
                urls,
132
                active_client,
133
                active_index,
134
                request_receiver,
135
                best_header_listeners: vec![],
136
                imported_header_listeners: vec![],
137
                finalized_header_listeners: vec![],
138
            },
139
            request_sender,
140
        ))
141
    }
142

            
143
    /// Change RPC server for future requests.
144
    async fn connect_to_new_rpc_server(&mut self) -> Result<(), ()> {
145
        let (active_index, active_client) =
146
            connect_next_available_rpc_server(&self.urls, self.active_index + 1).await?;
147
        self.active_index = active_index;
148
        self.active_client = active_client;
149
        Ok(())
150
    }
151

            
152
    /// Send the request to the current client. If this connection becomes dead, the returned future
153
    /// will return the request so it can be sent to another client.
154
    fn send_request(
155
        &self,
156
        JsonRpcRequest {
157
            method,
158
            params,
159
            response_sender,
160
        }: JsonRpcRequest,
161
    ) -> RpcRequestFuture {
162
        let client = self.active_client.clone();
163
        async move {
164
            let response = client.request(&method, params.clone()).await;
165

            
166
            // We should only return the original request in case
167
            // the websocket connection is dead and requires a restart.
168
            // Other errors should be forwarded to the request caller.
169
            if let Err(JsonRpseeError::RestartNeeded(_)) = response {
170
                return Err(JsonRpcRequest {
171
                    method,
172
                    params,
173
                    response_sender,
174
                });
175
            }
176

            
177
            if let Err(err) = response_sender.send(response) {
178
                tracing::debug!(
179
                    target: LOG_TARGET,
180
                    ?err,
181
                    "Recipient no longer interested in request result"
182
                );
183
            }
184

            
185
            Ok(())
186
        }
187
        .boxed()
188
    }
189

            
190
    async fn get_subscriptions(&self) -> Result<OrchestratorSubscription, JsonRpseeError> {
191
        let import_subscription = <JsonRpcClient as ChainApiClient<
192
            dp_core::BlockNumber,
193
            dp_core::Hash,
194
            dp_core::Header,
195
            dp_core::SignedBlock,
196
        >>::subscribe_all_heads(&self.active_client)
197
        .await
198
        .map_err(|e| {
199
            tracing::error!(
200
                target: LOG_TARGET,
201
                ?e,
202
                "Unable to open `chain_subscribeAllHeads` subscription."
203
            );
204
            e
205
        })?;
206

            
207
        let best_subscription = <JsonRpcClient as ChainApiClient<
208
            dp_core::BlockNumber,
209
            dp_core::Hash,
210
            dp_core::Header,
211
            dp_core::SignedBlock,
212
        >>::subscribe_new_heads(&self.active_client)
213
        .await
214
        .map_err(|e| {
215
            tracing::error!(
216
                target: LOG_TARGET,
217
                ?e,
218
                "Unable to open `chain_subscribeNewHeads` subscription."
219
            );
220
            e
221
        })?;
222

            
223
        let finalized_subscription = <JsonRpcClient as ChainApiClient<
224
            dp_core::BlockNumber,
225
            dp_core::Hash,
226
            dp_core::Header,
227
            dp_core::SignedBlock,
228
        >>::subscribe_finalized_heads(&self.active_client)
229
        .await
230
        .map_err(|e| {
231
            tracing::error!(
232
                target: LOG_TARGET,
233
                ?e,
234
                "Unable to open `chain_subscribeFinalizedHeads` subscription."
235
            );
236
            e
237
        })?;
238

            
239
        Ok(OrchestratorSubscription {
240
            import_subscription,
241
            best_subscription,
242
            finalized_subscription,
243
        })
244
    }
245

            
246
    /// Handle a reconnection by fnding a new RPC server and sending all pending requests.
247
    async fn handle_reconnect(
248
        &mut self,
249
        pending_requests: &mut FuturesUnordered<RpcRequestFuture>,
250
        first_failed_request: Option<JsonRpcRequest>,
251
    ) -> Result<(), String> {
252
        let mut requests_to_retry = Vec::new();
253
        if let Some(req) = first_failed_request {
254
            requests_to_retry.push(req)
255
        }
256

            
257
        // All pending requests will return an error since the websocket connection is dead.
258
        // Draining the pending requests should be fast.
259
        while !pending_requests.is_empty() {
260
            if let Some(Err(req)) = pending_requests.next().await {
261
                requests_to_retry.push(req);
262
            }
263
        }
264

            
265
        // Connect to new RPC server if possible.
266
        if self.connect_to_new_rpc_server().await.is_err() {
267
            return Err("Unable to find valid external RPC server, shutting down.".to_string());
268
        }
269

            
270
        // Retry requests.
271
        for req in requests_to_retry.into_iter() {
272
            pending_requests.push(self.send_request(req));
273
        }
274

            
275
        // Get subscriptions from new endpoint.
276
        self.get_subscriptions().await.map_err(|e| {
277
			format!("Not able to create streams from newly connected RPC server, shutting down. err: {:?}", e)
278
		})?;
279

            
280
        Ok(())
281
    }
282

            
283
    pub async fn run(mut self) {
284
        let mut pending_requests = FuturesUnordered::new();
285
        let mut connection_status = ConnectionStatus::Connected;
286

            
287
        let Ok(mut subscriptions) = self.get_subscriptions().await else {
288
            tracing::error!(target: LOG_TARGET, "Unable to fetch subscriptions on initial connection.");
289
            return;
290
        };
291

            
292
        let mut imported_blocks_cache = LruMap::new(ByLength::new(40));
293
        let mut last_seen_finalized_num: dp_core::BlockNumber = 0;
294

            
295
        loop {
296
            // Handle reconnection.
297
            if let ConnectionStatus::Disconnected { failed_request } = connection_status {
298
                if let Err(message) = self
299
                    .handle_reconnect(&mut pending_requests, failed_request)
300
                    .await
301
                {
302
                    tracing::error!(
303
                        target: LOG_TARGET,
304
                        message,
305
                        "Unable to reconnect, stopping worker."
306
                    );
307
                    return;
308
                }
309

            
310
                connection_status = ConnectionStatus::Connected;
311
            }
312

            
313
            tokio::select! {
314
                // New request received.
315
                req = self.request_receiver.recv() => match req {
316
                    Some(WsClientRequest::JsonRpcRequest(req)) => {
317
                        pending_requests.push(self.send_request(req));
318
                    },
319
                    Some(WsClientRequest::RegisterBestHeadListener(tx)) => {
320
                        self.best_header_listeners.push(tx);
321
                    },
322
                    Some(WsClientRequest::RegisterImportListener(tx)) => {
323
                        self.imported_header_listeners.push(tx);
324
                    },
325
                    Some(WsClientRequest::RegisterFinalizationListener(tx)) => {
326
                        self.finalized_header_listeners.push(tx);
327
                    },
328
                    None => {
329
                        tracing::error!(target: LOG_TARGET, "RPC client receiver closed. Stopping RPC Worker.");
330
                        return;
331
                    }
332
                },
333
                // We poll pending request futures. If one completes with an `Err`, it means the
334
                // ws client was disconnected and we need to reconnect to a new ws client.
335
                pending = pending_requests.next(), if !pending_requests.is_empty() => {
336
                    if let Some(Err(req)) = pending {
337
                        connection_status = ConnectionStatus::Disconnected { failed_request: Some(req) };
338
                    }
339
                },
340
                import_event = subscriptions.import_subscription.next() => {
341
                    match import_event {
342
                        Some(Ok(header)) => {
343
                            let hash = header.hash();
344
                            if imported_blocks_cache.peek(&hash).is_some() {
345
                                tracing::debug!(
346
                                    target: LOG_TARGET,
347
                                    number = header.number,
348
                                    ?hash,
349
                                    "Duplicate imported block header. This might happen after switching to a new RPC node. Skipping distribution."
350
                                );
351
                                continue;
352
                            }
353
                            imported_blocks_cache.insert(hash, ());
354
                            distribute(header, &mut self.imported_header_listeners);
355
                        },
356
                        None => {
357
                            tracing::error!(target: LOG_TARGET, "Subscription closed.");
358
                            connection_status = ConnectionStatus::Disconnected { failed_request: None};
359
                        },
360
                        Some(Err(error)) => {
361
                            tracing::error!(target: LOG_TARGET, ?error, "Error in RPC subscription.");
362
                            connection_status = ConnectionStatus::Disconnected { failed_request: None};
363
                        },
364
                    }
365
                },
366
                best_header_event = subscriptions.best_subscription.next() => {
367
                    match best_header_event {
368
                        Some(Ok(header)) => distribute(header, &mut self.best_header_listeners),
369
                        None => {
370
                            tracing::error!(target: LOG_TARGET, "Subscription closed.");
371
                            connection_status = ConnectionStatus::Disconnected { failed_request: None};
372
                        },
373
                        Some(Err(error)) => {
374
                            tracing::error!(target: LOG_TARGET, ?error, "Error in RPC subscription.");
375
                            connection_status = ConnectionStatus::Disconnected { failed_request: None};
376
                        },
377
                    }
378
                }
379
                finalized_event = subscriptions.finalized_subscription.next() => {
380
                    match finalized_event {
381
                        Some(Ok(header)) if header.number > last_seen_finalized_num => {
382
                            last_seen_finalized_num = header.number;
383
                            distribute(header, &mut self.finalized_header_listeners);
384
                        },
385
                        Some(Ok(header)) => {
386
                            tracing::debug!(
387
                                target: LOG_TARGET,
388
                                number = header.number,
389
                                last_seen_finalized_num,
390
                                "Duplicate finalized block header. This might happen after switching to a new RPC node. Skipping distribution."
391
                            );
392
                        },
393
                        None => {
394
                            tracing::error!(target: LOG_TARGET, "Subscription closed.");
395
                            connection_status = ConnectionStatus::Disconnected { failed_request: None};
396
                        },
397
                        Some(Err(error)) => {
398
                            tracing::error!(target: LOG_TARGET, ?error, "Error in RPC subscription.");
399
                            connection_status = ConnectionStatus::Disconnected { failed_request: None};
400
                        },
401
                    }
402
                }
403
            }
404
        }
405
    }
406
}
407

            
408
/// Send `value` through all channels contained in `senders`.
409
/// If no one is listening to the sender, it is removed from the vector.
410
pub fn distribute<T: Clone + Send>(value: T, senders: &mut Vec<mpsc::Sender<T>>) {
411
    senders.retain_mut(|e| {
412
        match e.try_send(value.clone()) {
413
            // Receiver has been dropped, remove Sender from list.
414
            Err(mpsc::error::TrySendError::Closed(_)) => false,
415
            // Channel is full. This should not happen.
416
            // TODO: Improve error handling here
417
            // https://github.com/paritytech/cumulus/issues/1482
418
            Err(error) => {
419
                tracing::error!(target: LOG_TARGET, ?error, "Event distribution channel has reached its limit. This can lead to missed notifications.");
420
                true
421
            },
422
            _ => true,
423
        }
424
    });
425
}