Skip to content

Commit e7b8724

Browse files
committed
refactor: Introduce HttpHeaderRetriever, abstraction of auth retriever
HttpHeaderRetriever returns HeaderMap. So, C8YJwtRetriever now returns HeaderMap. The retrieved JWT token is included in the map. In the future, there will be C8YBasicAuthRetriever, which will also returns HeaderMap. Signed-off-by: Rina Fujino <rina.fujino.23@gmail.com>
1 parent 897c102 commit e7b8724

File tree

19 files changed

+246
-155
lines changed

19 files changed

+246
-155
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/common/download/src/download.rs

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use log::warn;
1111
use nix::sys::statvfs;
1212
pub use partial_response::InvalidResponseError;
1313
use reqwest::header;
14+
use reqwest::header::HeaderMap;
1415
use reqwest::Client;
1516
use reqwest::Identity;
1617
use serde::Deserialize;
@@ -20,8 +21,6 @@ use std::fs::File;
2021
use std::io::Seek;
2122
use std::io::SeekFrom;
2223
use std::io::Write;
23-
#[cfg(target_os = "linux")]
24-
use std::os::unix::prelude::AsRawFd;
2524
use std::path::Path;
2625
use std::path::PathBuf;
2726
use std::time::Duration;
@@ -31,6 +30,8 @@ use tedge_utils::file::FileError;
3130
use nix::fcntl::fallocate;
3231
#[cfg(target_os = "linux")]
3332
use nix::fcntl::FallocateFlags;
33+
#[cfg(target_os = "linux")]
34+
use std::os::unix::prelude::AsRawFd;
3435

3536
fn default_backoff() -> ExponentialBackoff {
3637
// Default retry is an exponential retry with a limit of 15 minutes total.
@@ -49,8 +50,8 @@ fn default_backoff() -> ExponentialBackoff {
4950
#[serde(deny_unknown_fields)]
5051
pub struct DownloadInfo {
5152
pub url: String,
52-
#[serde(skip_serializing_if = "Option::is_none")]
53-
pub auth: Option<String>,
53+
#[serde(skip)]
54+
pub headers: HeaderMap,
5455
}
5556

5657
impl From<&str> for DownloadInfo {
@@ -64,14 +65,14 @@ impl DownloadInfo {
6465
pub fn new(url: &str) -> Self {
6566
Self {
6667
url: url.into(),
67-
auth: None,
68+
headers: HeaderMap::new(),
6869
}
6970
}
7071

7172
/// Creates new [`DownloadInfo`] from a URL with authentication.
72-
pub fn with_auth(self, auth: &str) -> Self {
73+
pub fn with_headers(self, header_map: HeaderMap) -> Self {
7374
Self {
74-
auth: Some(auth.into()),
75+
headers: header_map,
7576
..self
7677
}
7778
}
@@ -369,8 +370,8 @@ impl Downloader {
369370

370371
let operation = || async {
371372
let mut request = self.client.get(url.url());
372-
if let Some(header_value) = &url.auth {
373-
request = request.header("Authorization", header_value)
373+
for (key, value) in &url.headers {
374+
request = request.header(key, value)
374375
}
375376

376377
if range_start != 0 {
@@ -467,6 +468,7 @@ fn try_pre_allocate_space(file: &File, path: &Path, file_len: u64) -> Result<(),
467468
#[allow(deprecated)]
468469
mod tests {
469470
use super::*;
471+
use hyper::header::AUTHORIZATION;
470472
use std::io::Write;
471473
use tempfile::tempdir;
472474
use tempfile::NamedTempFile;
@@ -908,10 +910,12 @@ mod tests {
908910
}
909911
};
910912

911-
// applying token if `with_token` = true
913+
// applying http auth header
912914
let url = {
913915
if with_token {
914-
url.with_auth("Bearer token")
916+
let mut headers = HeaderMap::new();
917+
headers.append(AUTHORIZATION, "Bearer token".parse().unwrap());
918+
url.with_headers(headers)
915919
} else {
916920
url
917921
}

crates/core/c8y_api/src/http_proxy.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use mqtt_channel::PubChannel;
55
use mqtt_channel::StreamExt;
66
use mqtt_channel::Topic;
77
use mqtt_channel::TopicFilter;
8+
use reqwest::header::HeaderMap;
89
use reqwest::Url;
910
use std::collections::HashMap;
1011
use std::time::Duration;
@@ -27,7 +28,7 @@ pub struct C8yEndPoint {
2728
c8y_host: String,
2829
c8y_mqtt_host: String,
2930
pub device_id: String,
30-
pub token: Option<String>,
31+
pub headers: HeaderMap,
3132
devices_internal_id: HashMap<String, String>,
3233
}
3334

@@ -37,7 +38,7 @@ impl C8yEndPoint {
3738
c8y_host: c8y_host.into(),
3839
c8y_mqtt_host: c8y_mqtt_host.into(),
3940
device_id: device_id.into(),
40-
token: None,
41+
headers: HeaderMap::new(),
4142
devices_internal_id: HashMap::new(),
4243
}
4344
}

crates/extensions/c8y_auth_proxy/src/actor.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use axum::async_trait;
2-
use c8y_http_proxy::credentials::AuthResult;
3-
use c8y_http_proxy::credentials::AuthRetriever;
2+
use c8y_http_proxy::credentials::HttpHeaderResult;
3+
use c8y_http_proxy::credentials::HttpHeaderRetriever;
44
use camino::Utf8PathBuf;
55
use futures::channel::mpsc;
66
use futures::StreamExt;
@@ -38,14 +38,14 @@ impl C8yAuthProxyBuilder {
3838
pub fn try_from_config(
3939
config: &TEdgeConfig,
4040
c8y_profile: Option<&str>,
41-
auth: &mut impl Service<(), AuthResult>,
41+
header_retriever: &mut impl Service<(), HttpHeaderResult>,
4242
) -> anyhow::Result<Self> {
4343
let reqwest_client = config.cloud_root_certs().client();
4444
let c8y = config.c8y.try_get(c8y_profile)?;
4545
let app_data = AppData {
4646
is_https: true,
4747
host: c8y.http.or_config_not_set()?.to_string(),
48-
token_manager: TokenManager::new(AuthRetriever::new(auth)).shared(),
48+
token_manager: TokenManager::new(HttpHeaderRetriever::new(header_retriever)).shared(),
4949
client: reqwest_client,
5050
};
5151
let bind = &c8y.proxy.bind;

crates/extensions/c8y_auth_proxy/src/server.rs

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use futures::Sink;
2323
use futures::SinkExt;
2424
use futures::Stream;
2525
use futures::StreamExt;
26+
use hyper::header::AUTHORIZATION;
2627
use hyper::header::HOST;
2728
use hyper::HeaderMap;
2829
use reqwest::Method;
@@ -235,7 +236,7 @@ async fn connect_to_websocket(
235236
for (name, value) in headers {
236237
req = req.header(name.as_str(), value);
237238
}
238-
req = req.header("Authorization", auth_value);
239+
req = req.header(AUTHORIZATION, auth_value);
239240
let req = req
240241
.uri(uri)
241242
.header(HOST, host.without_scheme.as_ref())
@@ -404,10 +405,10 @@ async fn respond_to(
404405
None => "",
405406
};
406407
let auth: fn(reqwest::RequestBuilder, &str) -> reqwest::RequestBuilder =
407-
if headers.contains_key("Authorization") {
408+
if headers.contains_key(AUTHORIZATION) {
408409
|req, _auth_value| req
409410
} else {
410-
|req, auth_value| req.header("Authorization", auth_value)
411+
|req, auth_value| req.header(AUTHORIZATION, auth_value)
411412
};
412413
headers.remove(HOST);
413414

@@ -496,12 +497,13 @@ mod tests {
496497
use axum::body::Bytes;
497498
use axum::headers::authorization::Bearer;
498499
use axum::headers::Authorization;
500+
use axum::http::header::AUTHORIZATION;
499501
use axum::http::Request;
500502
use axum::middleware::Next;
501503
use axum::TypedHeader;
502-
use c8y_http_proxy::credentials::AuthRequest;
503-
use c8y_http_proxy::credentials::AuthResult;
504-
use c8y_http_proxy::credentials::AuthRetriever;
504+
use c8y_http_proxy::credentials::HttpHeaderRequest;
505+
use c8y_http_proxy::credentials::HttpHeaderResult;
506+
use c8y_http_proxy::credentials::HttpHeaderRetriever;
505507
use camino::Utf8PathBuf;
506508
use futures::channel::mpsc;
507509
use futures::future::ready;
@@ -1113,7 +1115,7 @@ mod tests {
11131115
let state = AppData {
11141116
is_https: false,
11151117
host: target_host.into(),
1116-
token_manager: TokenManager::new(AuthRetriever::new(&mut retriever)).shared(),
1118+
token_manager: TokenManager::new(HttpHeaderRetriever::new(&mut retriever)).shared(),
11171119
client: reqwest::Client::new(),
11181120
};
11191121
let trust_store = ca_dir
@@ -1147,16 +1149,22 @@ mod tests {
11471149

11481150
#[async_trait]
11491151
impl Server for IterJwtRetriever {
1150-
type Request = AuthRequest;
1151-
type Response = AuthResult;
1152+
type Request = HttpHeaderRequest;
1153+
type Response = HttpHeaderResult;
11521154

11531155
fn name(&self) -> &str {
11541156
"IterJwtRetriever"
11551157
}
11561158

11571159
async fn handle(&mut self, _request: Self::Request) -> Self::Response {
1158-
let auth_value = format!("Bearer {}", self.tokens.next().unwrap());
1159-
Ok(auth_value)
1160+
let mut header_map = HeaderMap::new();
1161+
header_map.insert(
1162+
AUTHORIZATION,
1163+
format!("Bearer {}", self.tokens.next().unwrap())
1164+
.parse()
1165+
.unwrap(),
1166+
);
1167+
Ok(header_map)
11601168
}
11611169
}
11621170

crates/extensions/c8y_auth_proxy/src/tokens.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
use anyhow::Context;
2+
use hyper::header::AUTHORIZATION;
13
use std::sync::Arc;
24

3-
use c8y_http_proxy::credentials::AuthRetriever;
5+
use c8y_http_proxy::credentials::HttpHeaderRetriever;
46
use tokio::sync::Mutex;
57

68
#[derive(Clone)]
@@ -16,12 +18,12 @@ impl SharedTokenManager {
1618
}
1719

1820
pub struct TokenManager {
19-
recv: AuthRetriever,
21+
recv: HttpHeaderRetriever,
2022
cached: Option<Arc<str>>,
2123
}
2224

2325
impl TokenManager {
24-
pub fn new(recv: AuthRetriever) -> Self {
26+
pub fn new(recv: HttpHeaderRetriever) -> Self {
2527
Self { recv, cached: None }
2628
}
2729

@@ -41,7 +43,11 @@ impl TokenManager {
4143
}
4244

4345
async fn refresh(&mut self) -> Result<Arc<str>, anyhow::Error> {
44-
self.cached = Some(self.recv.await_response(()).await??.into());
46+
let header_map = self.recv.await_response(()).await??;
47+
let auth_header_value = header_map
48+
.get(AUTHORIZATION)
49+
.context("Authorization is missing from header")?;
50+
self.cached = Some(auth_header_value.to_str()?.into());
4551
Ok(self.cached.as_ref().unwrap().clone())
4652
}
4753
}

crates/extensions/c8y_firmware_manager/src/actor.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use c8y_api::smartrest::message::collect_smartrest_messages;
1212
use c8y_api::smartrest::message::get_smartrest_template_id;
1313
use c8y_api::smartrest::smartrest_deserializer::SmartRestFirmwareRequest;
1414
use c8y_api::smartrest::smartrest_deserializer::SmartRestRequestGeneric;
15-
use c8y_http_proxy::credentials::AuthRetriever;
15+
use c8y_http_proxy::credentials::HttpHeaderRetriever;
1616
use log::error;
1717
use log::info;
1818
use log::warn;
@@ -84,7 +84,7 @@ impl FirmwareManagerActor {
8484
config: FirmwareManagerConfig,
8585
input_receiver: LoggingReceiver<FirmwareInput>,
8686
mqtt_publisher: DynSender<MqttMessage>,
87-
auth_retriever: AuthRetriever,
87+
header_retriever: HttpHeaderRetriever,
8888
download_sender: ClientMessageBox<IdDownloadRequest, IdDownloadResult>,
8989
progress_sender: DynSender<OperationOutcome>,
9090
) -> Self {
@@ -93,7 +93,7 @@ impl FirmwareManagerActor {
9393
worker: FirmwareManagerWorker::new(
9494
config,
9595
mqtt_publisher,
96-
auth_retriever,
96+
header_retriever,
9797
download_sender,
9898
progress_sender,
9999
),

crates/extensions/c8y_firmware_manager/src/lib.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ mod tests;
1010

1111
use actor::FirmwareInput;
1212
use actor::FirmwareManagerActor;
13-
use c8y_http_proxy::credentials::AuthResult;
14-
use c8y_http_proxy::credentials::AuthRetriever;
13+
use c8y_http_proxy::credentials::HttpHeaderResult;
14+
use c8y_http_proxy::credentials::HttpHeaderRetriever;
1515
pub use config::*;
1616
use tedge_actors::futures::channel::mpsc;
1717
use tedge_actors::Builder;
@@ -39,7 +39,7 @@ pub struct FirmwareManagerBuilder {
3939
config: FirmwareManagerConfig,
4040
input_receiver: LoggingReceiver<FirmwareInput>,
4141
mqtt_publisher: DynSender<MqttMessage>,
42-
jwt_retriever: AuthRetriever,
42+
header_retriever: HttpHeaderRetriever,
4343
download_sender: ClientMessageBox<IdDownloadRequest, IdDownloadResult>,
4444
progress_sender: DynSender<OperationOutcome>,
4545
signal_sender: mpsc::Sender<RuntimeRequest>,
@@ -49,7 +49,7 @@ impl FirmwareManagerBuilder {
4949
pub fn try_new(
5050
config: FirmwareManagerConfig,
5151
mqtt_actor: &mut (impl MessageSource<MqttMessage, TopicFilter> + MessageSink<MqttMessage>),
52-
jwt_actor: &mut impl Service<(), AuthResult>,
52+
header_actor: &mut impl Service<(), HttpHeaderResult>,
5353
downloader_actor: &mut impl Service<IdDownloadRequest, IdDownloadResult>,
5454
) -> Result<FirmwareManagerBuilder, FileError> {
5555
Self::init(&config.data_dir)?;
@@ -65,14 +65,14 @@ impl FirmwareManagerBuilder {
6565

6666
mqtt_actor.connect_sink(Self::subscriptions(&config.c8y_prefix), &mqtt_sender);
6767
let mqtt_publisher = mqtt_actor.get_sender();
68-
let jwt_retriever = AuthRetriever::new(jwt_actor);
68+
let header_retriever = HttpHeaderRetriever::new(header_actor);
6969
let download_sender = ClientMessageBox::new(downloader_actor);
7070
let progress_sender = input_sender.into();
7171
Ok(Self {
7272
config,
7373
input_receiver,
7474
mqtt_publisher,
75-
jwt_retriever,
75+
header_retriever,
7676
download_sender,
7777
progress_sender,
7878
signal_sender,
@@ -110,7 +110,7 @@ impl Builder<FirmwareManagerActor> for FirmwareManagerBuilder {
110110
self.config,
111111
self.input_receiver,
112112
self.mqtt_publisher,
113-
self.jwt_retriever,
113+
self.header_retriever,
114114
self.download_sender,
115115
self.progress_sender,
116116
))

0 commit comments

Comments
 (0)