diff --git a/designsafe/apps/auth/models.py b/designsafe/apps/auth/models.py index 6db226f3f..cdecf426e 100644 --- a/designsafe/apps/auth/models.py +++ b/designsafe/apps/auth/models.py @@ -3,9 +3,11 @@ import logging import time +import requests from django.db import models from django.conf import settings from tapipy.tapis import Tapis +from tapipy.errors import BaseTapyException logger = logging.getLogger(__name__) @@ -74,13 +76,24 @@ def client(self): :return: Tapis client using refresh token. :rtype: :class:Tapis """ - return Tapis( - base_url=getattr(settings, "TAPIS_TENANT_BASEURL"), - client_id=getattr(settings, "TAPIS_CLIENT_ID"), - client_key=getattr(settings, "TAPIS_CLIENT_KEY"), - access_token=self.access_token, - refresh_token=self.refresh_token, - ) + try: + return Tapis( + base_url=getattr(settings, "TAPIS_TENANT_BASEURL"), + client_id=getattr(settings, "TAPIS_CLIENT_ID"), + client_key=getattr(settings, "TAPIS_CLIENT_KEY"), + access_token=self.access_token, + refresh_token=self.refresh_token, + ) + except BaseTapyException: + # If client cannot be instantiated, we might need to refresh tokens using an API call. + self.refresh_tokens_api() + return Tapis( + base_url=getattr(settings, "TAPIS_TENANT_BASEURL"), + client_id=getattr(settings, "TAPIS_CLIENT_ID"), + client_key=getattr(settings, "TAPIS_CLIENT_KEY"), + access_token=self.access_token, + refresh_token=self.refresh_token, + ) def update(self, **kwargs): """Bulk update model attributes""" @@ -98,6 +111,32 @@ def refresh_tokens(self): expires_in=self.client.access_token.expires_in().total_seconds(), ) + def refresh_tokens_api(self): + """ + Refresh tokens using a direct call to the Tapis API. + Used when a Tapipy client cannot be instantiated. + """ + auth = requests.auth.HTTPBasicAuth(username=settings.TAPIS_CLIENT_ID, + password=settings.TAPIS_CLIENT_KEY) + payload = {"grant_type": "refresh_token", "refresh_token": self.refresh_token} + request = requests.post(f"{settings.TAPIS_TENANT_BASEURL}/v3/oauth2/tokens", + json=payload, + auth=auth, + timeout=30) + request.raise_for_status() + + request_json = request.json() + access_token = request_json["result"]["access_token"]["access_token"] + expires_in = request_json["result"]["access_token"]["expires_in"] + refresh_token = request_json["result"]["refresh_token"]["refresh_token"] + + self.update( + created=int(time.time()), + access_token=access_token, + refresh_token=refresh_token, + expires_in=expires_in, + ) + def __str__(self): access_token_masked = self.access_token[-5:] refresh_token_masked = self.refresh_token[-5:]