|
| 1 | +require "uri" |
| 2 | +require "jwt" |
| 3 | +require "jwt/jwks" |
| 4 | +require "placeos-models/authority" |
| 5 | +require "placeos-models/user" |
| 6 | +require "placeos-models/user_jwt" |
| 7 | + |
| 8 | +module PlaceOS::Api |
| 9 | + # Helper to authenticate using an MS token |
| 10 | + # * check the token is valid |
| 11 | + module Utils::MSTokenExchange |
| 12 | + extend self |
| 13 | + |
| 14 | + enum TokenVersion |
| 15 | + V1 |
| 16 | + V2 |
| 17 | + end |
| 18 | + |
| 19 | + record PeekInfo, |
| 20 | + aud_raw : String, |
| 21 | + aud_host : String, |
| 22 | + email : String?, |
| 23 | + tid : String?, |
| 24 | + iss : String?, |
| 25 | + iss_host : String?, |
| 26 | + version : TokenVersion, |
| 27 | + kid : String? do |
| 28 | + # Basic heuristic to detect Microsoft Entra / Azure AD issuers |
| 29 | + def is_ms_token? : Bool |
| 30 | + iss_val = iss_host |
| 31 | + return false unless iss_val |
| 32 | + iss_val = iss_val.downcase |
| 33 | + iss_val.ends_with?("microsoftonline.com") || |
| 34 | + iss_val.ends_with?("sts.windows.net") || |
| 35 | + iss_val.ends_with?("login.windows.net") || |
| 36 | + iss_val.ends_with?("login.chinacloudapi.cn") || # China cloud |
| 37 | + iss_val.ends_with?("login.microsoftonline.de") || # Germany |
| 38 | + iss_val.ends_with?("login.partner.microsoftonline.cn") || # 21V |
| 39 | + iss_val.ends_with?("login-us.microsoftonline.com") # GCC/DoD |
| 40 | + end |
| 41 | + |
| 42 | + def token_endpoint : URI? |
| 43 | + case version |
| 44 | + in .v1? |
| 45 | + URI.parse("https://login.microsoftonline.com/#{tid}/oauth2/token") |
| 46 | + in .v2? |
| 47 | + URI.parse("https://login.microsoftonline.com/#{tid}/oauth2/v2.0/token") |
| 48 | + end |
| 49 | + end |
| 50 | + end |
| 51 | + |
| 52 | + # ---------- Peek (safe decode, no signature validation) ---------- |
| 53 | + |
| 54 | + def peek_token_info(token : String) : PeekInfo |
| 55 | + payload, header = JWT.decode(token, verify: false, validate: false) |
| 56 | + |
| 57 | + aud_raw = payload["aud"]?.try(&.as_s) || raise "missing aud" |
| 58 | + iss = payload["iss"]?.try(&.as_s) || raise "missing iss" |
| 59 | + email = payload["upn"]?.try(&.as_s) |
| 60 | + tid = payload["tid"]?.try(&.as_s) |
| 61 | + kid = header["kid"]?.try(&.as_s) |
| 62 | + |
| 63 | + version = detect_token_version(payload, iss) |
| 64 | + aud_host = extract_aud_host(aud_raw) |
| 65 | + iss_host = extract_issuer_host(iss) |
| 66 | + |
| 67 | + PeekInfo.new( |
| 68 | + aud_raw: aud_raw, |
| 69 | + aud_host: aud_host, |
| 70 | + email: email, |
| 71 | + tid: tid, |
| 72 | + iss: iss, |
| 73 | + iss_host: iss_host, |
| 74 | + version: version, |
| 75 | + kid: kid |
| 76 | + ) |
| 77 | + end |
| 78 | + |
| 79 | + # obtain MS Graph API token - this is a simple way to validate its authenticity |
| 80 | + def obtain_place_user(token : String, token_info : PeekInfo? = nil) : Model::User? |
| 81 | + info = token_info || peek_token_info(token) |
| 82 | + tenant = info.tid |
| 83 | + email = info.email |
| 84 | + return unless tenant && email |
| 85 | + oauth = Model::OAuthAuthentication.find_by?(client_id: info.aud_host) |
| 86 | + return unless oauth |
| 87 | + |
| 88 | + # ensure Tenant ID matches our authentication source |
| 89 | + return unless oauth.token_url.includes?(tenant) |
| 90 | + |
| 91 | + # validate the MS token |
| 92 | + payload = validate_token_with_jwks(token, token_info: info) |
| 93 | + |
| 94 | + # find the place user or create a new one |
| 95 | + user = Model::User.find_by?(authority_id: oauth.authority_id, email: email.downcase) || create_place_user(oauth, payload) |
| 96 | + |
| 97 | + # ensure there is a valid MS Graph API access token in place |
| 98 | + # as we maybe attempting to perform graph actions on behalf of the user |
| 99 | + ensure_valid_token(oauth, user, token, info) |
| 100 | + |
| 101 | + # return the user |
| 102 | + user |
| 103 | + end |
| 104 | + |
| 105 | + def create_place_user(oauth : Model::OAuthAuthentication, payload : JSON::Any) : Model::User |
| 106 | + Model::User.create!( |
| 107 | + name: payload["name"].as_s, |
| 108 | + last_name: payload["family_name"].as_s, |
| 109 | + first_name: payload["given_name"].as_s, |
| 110 | + email: Model::Email.new(payload["upn"].as_s), |
| 111 | + authority_id: oauth.authority_id |
| 112 | + ) |
| 113 | + end |
| 114 | + |
| 115 | + def ensure_valid_token(oauth : Model::OAuthAuthentication, user : Model::User, token : String, token_info : PeekInfo) |
| 116 | + # return if there is an existing token and valid |
| 117 | + existing = Api::Users.get_user_token(user, oauth.authority.as(Model::Authority)) rescue nil |
| 118 | + return if existing |
| 119 | + |
| 120 | + # if not existing or refresh failed, get a token using this token and on behalf of |
| 121 | + # https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-on-behalf-of-flow#example |
| 122 | + form = URI::Params.build do |form| |
| 123 | + form.add "grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer" |
| 124 | + form.add "client_id", oauth.client_id |
| 125 | + form.add "client_secret", oauth.client_secret |
| 126 | + form.add "assertion", token |
| 127 | + form.add "scope", oauth.scope |
| 128 | + form.add "requested_token_use", "on_behalf_of" |
| 129 | + end |
| 130 | + |
| 131 | + uri = token_info.token_endpoint |
| 132 | + |
| 133 | + client = HTTP::Client.new(uri, tls: true) |
| 134 | + client.basic_auth(oauth.client_id, oauth.client_secret) |
| 135 | + response = HTTP::Client.post( |
| 136 | + uri.request_target, |
| 137 | + headers: HTTP::Headers{ |
| 138 | + "Accept" => "application/json", |
| 139 | + }, |
| 140 | + form: form |
| 141 | + ) |
| 142 | + |
| 143 | + if !response.success? |
| 144 | + Log.warn { "failed with #{response.status_code} to obtain token on behalf of #{user.name} (#{user.id})\nbody: #{response.body}" } |
| 145 | + return |
| 146 | + end |
| 147 | + |
| 148 | + # update the user model with the graph API access token |
| 149 | + token = OAuth2::AccessToken.from_json(response.body) |
| 150 | + user.access_token = token.access_token |
| 151 | + user.refresh_token = token.refresh_token if token.refresh_token |
| 152 | + user.expires_at = Time.utc.to_unix + token.expires_in.not_nil! |
| 153 | + user.save! |
| 154 | + end |
| 155 | + |
| 156 | + def detect_token_version(payload : JSON::Any, iss : String) : TokenVersion |
| 157 | + ver = payload["ver"]?.try &.as_s? |
| 158 | + return TokenVersion::V2 if ver == "2.0" || iss.includes?("/v2.0") |
| 159 | + TokenVersion::V1 |
| 160 | + end |
| 161 | + |
| 162 | + # ---------- Audience Parsing ---------- |
| 163 | + |
| 164 | + def extract_aud_host(aud_raw : String) : String |
| 165 | + begin |
| 166 | + uri = URI.parse(aud_raw) |
| 167 | + uri.host || aud_raw |
| 168 | + rescue |
| 169 | + aud_raw |
| 170 | + end |
| 171 | + end |
| 172 | + |
| 173 | + # ---------- Issuer Parsing ---------- |
| 174 | + |
| 175 | + def extract_issuer_host(iss_raw : String) : String? |
| 176 | + begin |
| 177 | + uri = URI.parse(iss_raw) |
| 178 | + uri.host |
| 179 | + rescue |
| 180 | + nil |
| 181 | + end |
| 182 | + end |
| 183 | + |
| 184 | + # ---------- Validation (JWKS) ---------- |
| 185 | + |
| 186 | + class_getter jwks : JWT::JWKS { JWT::JWKS.new } |
| 187 | + |
| 188 | + def validate_token_with_jwks( |
| 189 | + token : String, |
| 190 | + token_info : PeekInfo? = nil, |
| 191 | + ) : JSON::Any |
| 192 | + info = token_info || peek_token_info(token) |
| 193 | + jwks = MSTokenExchange.jwks |
| 194 | + payload = jwks.validate( |
| 195 | + token, |
| 196 | + validate_claims: true |
| 197 | + ) || raise "token validation failed" |
| 198 | + |
| 199 | + payload |
| 200 | + end |
| 201 | + end |
| 202 | +end |
0 commit comments