diff --git a/src/main.rs b/src/main.rs index 5a7113a..24e2e9b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -246,13 +246,13 @@ fn config_options<'a, 'b>() -> App<'a, 'b> { .index(1) .required(true), ).arg( - Arg::with_name("algorithm") - .help("the algorithm to use for signing the JWT") + Arg::with_name("algorithms") + .help("a comma-separated list of algorithms to be used for signature validation. All algorithms need to be of the same family (HMAC, RSA, EC).") + .require_delimiter(true) .takes_value(true) - .long("alg") + .long("algs") .short("A") .possible_values(&SupportedAlgorithms::variants()) - .default_value("HS256"), ).arg( Arg::with_name("iso_dates") .help("display unix timestamps as ISO 8601 dates") @@ -264,7 +264,7 @@ fn config_options<'a, 'b>() -> App<'a, 'b> { .takes_value(true) .long("secret") .short("S") - .default_value(""), + .requires("algorithms") ).arg( Arg::with_name("json") .help("render decoded JWT as JSON") @@ -465,13 +465,6 @@ fn decode_token( JWTResult>, OutputFormat, ) { - let algorithm = translate_algorithm(SupportedAlgorithms::from_string( - matches.value_of("algorithm").unwrap(), - )); - let secret = match matches.value_of("secret").map(|s| (s, !s.is_empty())) { - Some((secret, true)) => Some(decoding_key_from_secret(&algorithm, &secret)), - _ => None, - }; let jwt = matches .value_of("jwt") .map(|value| { @@ -491,13 +484,7 @@ fn decode_token( .trim() .to_owned(); - let secret_validator = Validation { - leeway: 1000, - algorithms: vec![algorithm], - validate_exp: !matches.is_present("ignore_exp"), - ..Default::default() - }; - + // decode token without signature verification let token_data = dangerous_insecure_decode::(&jwt).map(|mut token| { if matches.is_present("iso_dates") { token.claims.convert_timestamps(); @@ -506,6 +493,31 @@ fn decode_token( token }); + // get vector of allowed algorithms from command line argument + let algorithms: Vec = match matches.values_of("algorithms") { + Some(algorithms) => algorithms + .map(|x| translate_algorithm(SupportedAlgorithms::from_string(x))) + .collect(), + None => vec![], + }; + + let secret_validator = Validation { + leeway: 1000, + algorithms: algorithms, + validate_exp: !matches.is_present("ignore_exp"), + ..Default::default() + }; + + // get the shared secret/public key to be used for signature validation + let secret = match matches.value_of("secret").map(|s| (s, !s.is_empty())) { + Some((secret, true)) => Some(decoding_key_from_secret( + &token_data.as_ref().unwrap().header.alg, // decode key according to algorithm used in the JWT + &secret, + )), + _ => None, + }; + + // return validated token, non-validated token data and output format ( match secret { Some(secret_key) => decode::(&jwt, &secret_key.unwrap(), &secret_validator), diff --git a/tests/jwt-cli.rs b/tests/jwt-cli.rs index 79f9eee..71db8f6 100644 --- a/tests/jwt-cli.rs +++ b/tests/jwt-cli.rs @@ -223,7 +223,15 @@ mod tests { let encode_matches = encode_matcher.subcommand_matches("encode").unwrap(); let encoded_token = encode_token(&encode_matches).unwrap(); let decode_matcher = config_options() - .get_matches_from_safe(vec!["jwt", "decode", "-S", "1234567890", &encoded_token]) + .get_matches_from_safe(vec![ + "jwt", + "decode", + "-S", + "1234567890", + "-A", + "HS256", + &encoded_token, + ]) .unwrap(); let decode_matches = decode_matcher.subcommand_matches("decode").unwrap(); let (decoded_token, _, _) = decode_token(&decode_matches); @@ -257,7 +265,15 @@ mod tests { let encode_matches = encode_matcher.subcommand_matches("encode").unwrap(); let encoded_token = encode_token(&encode_matches).unwrap(); let decode_matcher = config_options() - .get_matches_from_safe(vec!["jwt", "decode", "-S", "1234567890", &encoded_token]) + .get_matches_from_safe(vec![ + "jwt", + "decode", + "-S", + "1234567890", + "-A", + "HS256", + &encoded_token, + ]) .unwrap(); let decode_matches = decode_matcher.subcommand_matches("decode").unwrap(); let (decoded_token, _, _) = decode_token(&decode_matches); @@ -279,7 +295,15 @@ mod tests { let encode_matches = encode_matcher.subcommand_matches("encode").unwrap(); let encoded_token = encode_token(&encode_matches).unwrap(); let decode_matcher = config_options() - .get_matches_from_safe(vec!["jwt", "decode", "-S", "1234567890", &encoded_token]) + .get_matches_from_safe(vec![ + "jwt", + "decode", + "-S", + "1234567890", + "-A", + "HS256", + &encoded_token, + ]) .unwrap(); let decode_matches = decode_matcher.subcommand_matches("decode").unwrap(); let (decoded_token, token_data, _) = decode_token(&decode_matches); @@ -299,7 +323,15 @@ mod tests { let encode_matches = encode_matcher.subcommand_matches("encode").unwrap(); let encoded_token = encode_token(&encode_matches).unwrap(); let decode_matcher = config_options() - .get_matches_from_safe(vec!["jwt", "decode", "-S", "1234567890", &encoded_token]) + .get_matches_from_safe(vec![ + "jwt", + "decode", + "-S", + "1234567890", + "-A", + "HS256", + &encoded_token, + ]) .unwrap(); let decode_matches = decode_matcher.subcommand_matches("decode").unwrap(); let (decoded_token, _, _) = decode_token(&decode_matches); @@ -328,7 +360,15 @@ mod tests { let encode_matches = encode_matcher.subcommand_matches("encode").unwrap(); let encoded_token = encode_token(&encode_matches).unwrap(); let decode_matcher = config_options() - .get_matches_from_safe(vec!["jwt", "decode", "-S", "1234567890", &encoded_token]) + .get_matches_from_safe(vec![ + "jwt", + "decode", + "-S", + "1234567890", + "-A", + "HS256", + &encoded_token, + ]) .unwrap(); let decode_matches = decode_matcher.subcommand_matches("decode").unwrap(); let (decoded_token, _, _) = decode_token(&decode_matches); @@ -356,7 +396,15 @@ mod tests { let encode_matches = encode_matcher.subcommand_matches("encode").unwrap(); let encoded_token = encode_token(&encode_matches).unwrap(); let decode_matcher = config_options() - .get_matches_from_safe(vec!["jwt", "decode", "-S", "1234567890", &encoded_token]) + .get_matches_from_safe(vec![ + "jwt", + "decode", + "-S", + "1234567890", + "-A", + "HS256", + &encoded_token, + ]) .unwrap(); let decode_matches = decode_matcher.subcommand_matches("decode").unwrap(); let (decoded_token, _, _) = decode_token(&decode_matches); @@ -378,7 +426,15 @@ mod tests { let encode_matches = encode_matcher.subcommand_matches("encode").unwrap(); let encoded_token = encode_token(&encode_matches).unwrap(); let decode_matcher = config_options() - .get_matches_from_safe(vec!["jwt", "decode", "-S", "1234567890", &encoded_token]) + .get_matches_from_safe(vec![ + "jwt", + "decode", + "-S", + "1234567890", + "-A", + "HS512", + &encoded_token, + ]) .unwrap(); let decode_matches = decode_matcher.subcommand_matches("decode").unwrap(); let (decoded_token, _, _) = decode_token(&decode_matches); @@ -399,6 +455,8 @@ mod tests { "decode", "-S", "1234567890", + "-A", + "HS256", "--ignore-exp", &encoded_token, ]) @@ -424,7 +482,15 @@ mod tests { let encode_matches = encode_matcher.subcommand_matches("encode").unwrap(); let encoded_token = encode_token(&encode_matches).unwrap(); let decode_matcher = config_options() - .get_matches_from_safe(vec!["jwt", "decode", "-S", "1234567890", &encoded_token]) + .get_matches_from_safe(vec![ + "jwt", + "decode", + "-S", + "1234567890", + "-A", + "HS256", + &encoded_token, + ]) .unwrap(); let decode_matches = decode_matcher.subcommand_matches("decode").unwrap(); let (decoded_token, _, _) = decode_token(&decode_matches); @@ -460,7 +526,15 @@ mod tests { let encode_matches = encode_matcher.subcommand_matches("encode").unwrap(); let encoded_token = encode_token(&encode_matches).unwrap(); let decode_matcher = config_options() - .get_matches_from_safe(vec!["jwt", "decode", "-S", "1234567890", &encoded_token]) + .get_matches_from_safe(vec![ + "jwt", + "decode", + "-S", + "1234567890", + "-A", + "HS256", + &encoded_token, + ]) .unwrap(); let decode_matches = decode_matcher.subcommand_matches("decode").unwrap(); let (decoded_token, _, _) = decode_token(&decode_matches); @@ -597,6 +671,108 @@ mod tests { assert!(result.is_ok()); } + #[test] + fn encodes_and_decodes_a_token_with_multiple_algorithms() { + let body: String = "{\"field\":\"value\"}".to_string(); + let encode_matcher = config_options() + .get_matches_from_safe(vec![ + "jwt", + "encode", + "-A", + "HS256", + "--exp", + "-S", + "1234567890", + &body, + ]) + .unwrap(); + let encode_matches = encode_matcher.subcommand_matches("encode").unwrap(); + let encoded_token = encode_token(&encode_matches).unwrap(); + let decode_matcher = config_options() + .get_matches_from_safe(vec![ + "jwt", + "decode", + "-S", + "1234567890", + "-A", + "HS256,HS384,HS512", + &encoded_token, + ]) + .unwrap(); + let decode_matches = decode_matcher.subcommand_matches("decode").unwrap(); + let (result, _, _) = decode_token(&decode_matches); + + assert!(result.is_ok()); + } + + #[test] + fn encodes_and_decodes_a_token_with_invalid_algorithms_family() { + let body: String = "{\"field\":\"value\"}".to_string(); + let encode_matcher = config_options() + .get_matches_from_safe(vec![ + "jwt", + "encode", + "-A", + "HS256", + "--exp", + "-S", + "1234567890", + &body, + ]) + .unwrap(); + let encode_matches = encode_matcher.subcommand_matches("encode").unwrap(); + let encoded_token = encode_token(&encode_matches).unwrap(); + let decode_matcher = config_options() + .get_matches_from_safe(vec![ + "jwt", + "decode", + "-S", + "1234567890", + "-A", + "RS256,RS384,RS512", // invalid algorithm family + &encoded_token, + ]) + .unwrap(); + let decode_matches = decode_matcher.subcommand_matches("decode").unwrap(); + let (result, _, _) = decode_token(&decode_matches); + + assert!(result.is_err()); + } + + #[test] + fn encodes_and_decodes_a_token_with_mixed_algorithms_family() { + let body: String = "{\"field\":\"value\"}".to_string(); + let encode_matcher = config_options() + .get_matches_from_safe(vec![ + "jwt", + "encode", + "-A", + "HS256", + "--exp", + "-S", + "1234567890", + &body, + ]) + .unwrap(); + let encode_matches = encode_matcher.subcommand_matches("encode").unwrap(); + let encoded_token = encode_token(&encode_matches).unwrap(); + let decode_matcher = config_options() + .get_matches_from_safe(vec![ + "jwt", + "decode", + "-S", + "1234567890", + "-A", + "HS256,RS512", // algorithms from incompatible algorithm families + &encoded_token, + ]) + .unwrap(); + let decode_matches = decode_matcher.subcommand_matches("decode").unwrap(); + let (result, _, _) = decode_token(&decode_matches); + + assert!(result.is_err()); + } + #[test] fn encodes_and_decodes_an_rsa_token_using_key_from_file() { let body: String = "{\"field\":\"value\"}".to_string(); @@ -705,6 +881,8 @@ mod tests { "decode", "-S", "1234567890", + "-A", + "HS256", "--iso8601", &encoded_token, ])