diff --git a/fittrackee/oauth2/grants.py b/fittrackee/oauth2/grants.py index e99f8e38..85340c77 100644 --- a/fittrackee/oauth2/grants.py +++ b/fittrackee/oauth2/grants.py @@ -1,3 +1,4 @@ +import time from typing import Optional from authlib.oauth2 import OAuth2Request @@ -53,6 +54,9 @@ class AuthorizationCodeGrant(grants.AuthorizationCodeGrant): class RefreshTokenGrant(grants.RefreshTokenGrant): + TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_post'] + INCLUDE_NEW_REFRESH_TOKEN = True + def authenticate_refresh_token(self, refresh_token: str) -> Optional[str]: token = OAuth2Token.query.filter_by( refresh_token=refresh_token @@ -63,3 +67,7 @@ class RefreshTokenGrant(grants.RefreshTokenGrant): def authenticate_user(self, credential: OAuth2Token) -> User: return User.query.get(credential.user_id) + + def revoke_old_credential(self, credential: OAuth2Token) -> None: + credential.access_token_revoked_at = time.time() + db.session.commit() diff --git a/fittrackee/tests/oauth2/test_oauth2_routes.py b/fittrackee/tests/oauth2/test_oauth2_routes.py index 74e1f997..4113a773 100644 --- a/fittrackee/tests/oauth2/test_oauth2_routes.py +++ b/fittrackee/tests/oauth2/test_oauth2_routes.py @@ -1,5 +1,5 @@ import json -from typing import Dict, List, Union +from typing import Dict, List, Tuple, Union from unittest.mock import patch import pytest @@ -343,17 +343,30 @@ class TestOAuthClientAuthorization(ApiTestCaseMixin): ) -class TestOAuthIssueToken(ApiTestCaseMixin): - route = '/api/oauth/token' +class OAuthIssueTokenTestCase(ApiTestCaseMixin): + def create_authorized_oauth_client( + self, app: Flask, user_1: User + ) -> Tuple[OAuth2Client, Union[List[str], str]]: + client, auth_token = self.get_test_client_and_auth_token( + app, user_1.email + ) + oauth_client = self.create_oauth_client(user_1) + code = self.authorize_client(client, oauth_client, auth_token) + return oauth_client, code @staticmethod - def assert_token_is_returned(response: TestResponse) -> None: + def assert_token_is_returned(response: TestResponse) -> Dict: assert response.status_code == 200 data = json.loads(response.data.decode()) assert data.get('access_token') is not None assert data.get('expires_in') == 60 # test config assert data.get('refresh_token') is not None assert data.get('token_type') == 'Bearer' + return data + + +class TestOAuthIssueAccessToken(OAuthIssueTokenTestCase): + route = '/api/oauth/token' def test_it_returns_error_when_form_is_empty(self, app: Flask) -> None: client = app.test_client() @@ -369,11 +382,8 @@ class TestOAuthIssueToken(ApiTestCaseMixin): def test_it_returns_error_when_client_id_is_invalid( self, app: Flask, user_1: User ) -> None: - client, auth_token = self.get_test_client_and_auth_token( - app, user_1.email - ) - oauth_client = self.create_oauth_client(user_1) - code = self.authorize_client(client, oauth_client, auth_token) + oauth_client, code = self.create_authorized_oauth_client(app, user_1) + client = app.test_client() response = client.post( self.route, @@ -391,11 +401,8 @@ class TestOAuthIssueToken(ApiTestCaseMixin): def test_it_returns_error_when_client_secret_is_invalid( self, app: Flask, user_1: User ) -> None: - client, auth_token = self.get_test_client_and_auth_token( - app, user_1.email - ) - oauth_client = self.create_oauth_client(user_1) - code = self.authorize_client(client, oauth_client, auth_token) + oauth_client, code = self.create_authorized_oauth_client(app, user_1) + client = app.test_client() response = client.post( self.route, @@ -413,10 +420,8 @@ class TestOAuthIssueToken(ApiTestCaseMixin): def test_it_returns_error_when_client_not_authorized( self, app: Flask, user_1: User ) -> None: - client, auth_token = self.get_test_client_and_auth_token( - app, user_1.email - ) oauth_client = self.create_oauth_client(user_1) + client = app.test_client() response = client.post( self.route, @@ -434,11 +439,8 @@ class TestOAuthIssueToken(ApiTestCaseMixin): def test_it_returns_error_when_code_is_invalid( self, app: Flask, user_1: User ) -> None: - client, auth_token = self.get_test_client_and_auth_token( - app, user_1.email - ) - oauth_client = self.create_oauth_client(user_1) - self.authorize_client(client, oauth_client, auth_token) + oauth_client, code = self.create_authorized_oauth_client(app, user_1) + client = app.test_client() response = client.post( self.route, @@ -454,11 +456,8 @@ class TestOAuthIssueToken(ApiTestCaseMixin): self.assert_invalid_request(response) def test_it_returns_access_token(self, app: Flask, user_1: User) -> None: - client, auth_token = self.get_test_client_and_auth_token( - app, user_1.email - ) - oauth_client = self.create_oauth_client(user_1) - code = self.authorize_client(client, oauth_client, auth_token) + oauth_client, code = self.create_authorized_oauth_client(app, user_1) + client = app.test_client() response = client.post( self.route, @@ -474,6 +473,70 @@ class TestOAuthIssueToken(ApiTestCaseMixin): self.assert_token_is_returned(response) +class TestOAuthIssueRefreshToken(OAuthIssueTokenTestCase): + route = '/api/oauth/token' + + def generate_token( + self, app: Flask, user_1: User + ) -> Tuple[OAuth2Client, Dict]: + oauth_client, code = self.create_authorized_oauth_client(app, user_1) + response = app.test_client().post( + self.route, + data={ + 'client_id': oauth_client.client_id, + 'client_secret': oauth_client.client_secret, + 'grant_type': 'authorization_code', + 'code': code, + }, + headers=dict(content_type='multipart/form-data'), + ) + return oauth_client, json.loads(response.data.decode()) + + def test_it_returns_new_token_when_grant_is_refresh_token( + self, app: Flask, user_1: User + ) -> None: + oauth_client, token = self.generate_token(app, user_1) + client = app.test_client() + + response = client.post( + self.route, + data={ + 'client_id': oauth_client.client_id, + 'client_secret': oauth_client.client_secret, + 'grant_type': 'refresh_token', + 'refresh_token': token['refresh_token'], + }, + headers=dict(content_type='multipart/form-data'), + ) + + data = self.assert_token_is_returned(response) + assert data['access_token'] != token['access_token'] + new_token = OAuth2Token.query.filter_by( + access_token=token['access_token'] + ).first() + assert new_token is not None + + def test_it_revokes_old_token(self, app: Flask, user_1: User) -> None: + oauth_client, token = self.generate_token(app, user_1) + client = app.test_client() + + client.post( + self.route, + data={ + 'client_id': oauth_client.client_id, + 'client_secret': oauth_client.client_secret, + 'grant_type': 'refresh_token', + 'refresh_token': token['refresh_token'], + }, + headers=dict(content_type='multipart/form-data'), + ) + + revoked_token = OAuth2Token.query.filter_by( + access_token=token['access_token'] + ).first() + assert revoked_token.is_revoked() + + class TestOAuthTokenRevocation(ApiTestCaseMixin): route = '/api/oauth/revoke'