From 6b497bd72ffa9710be5d9fa12bb662973b77f205 Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 19 Jun 2022 18:47:42 +0200 Subject: [PATCH] API - add missing tests on Authorization Code Grant (code challenge) --- fittrackee/tests/mixins.py | 9 +- fittrackee/tests/oauth2/test_oauth2_routes.py | 133 +++++++++++++++++- 2 files changed, 138 insertions(+), 4 deletions(-) diff --git a/fittrackee/tests/mixins.py b/fittrackee/tests/mixins.py index 47827348..fffe3928 100644 --- a/fittrackee/tests/mixins.py +++ b/fittrackee/tests/mixins.py @@ -106,7 +106,10 @@ class ApiTestCaseMixin(OAuth2Mixin, RandomMixin): oauth_client: OAuth2Client, auth_token: str, scope: Optional[str] = None, + code_challenge: Optional[Dict] = None, ) -> Union[List[str], str]: + if code_challenge is None: + code_challenge = {} response = client.post( '/api/oauth/authorize', data={ @@ -114,6 +117,7 @@ class ApiTestCaseMixin(OAuth2Mixin, RandomMixin): 'confirm': True, 'response_type': 'code', 'scope': 'read' if not scope else scope, + **code_challenge, }, headers=dict( Authorization=f'Bearer {auth_token}', @@ -234,11 +238,14 @@ class ApiTestCaseMixin(OAuth2Mixin, RandomMixin): ) @staticmethod - def assert_invalid_request(response: TestResponse) -> Dict: + def assert_invalid_request( + response: TestResponse, error_description: Optional[str] = None + ) -> Dict: return assert_oauth_errored_response( response, 400, error='invalid_request', + error_description=error_description, ) @staticmethod diff --git a/fittrackee/tests/oauth2/test_oauth2_routes.py b/fittrackee/tests/oauth2/test_oauth2_routes.py index b50cf53c..19964d73 100644 --- a/fittrackee/tests/oauth2/test_oauth2_routes.py +++ b/fittrackee/tests/oauth2/test_oauth2_routes.py @@ -1,8 +1,10 @@ import json -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union from unittest.mock import patch import pytest +from authlib.common.security import generate_token +from authlib.oauth2.rfc7636 import create_s256_code_challenge from flask import Flask from werkzeug.test import TestResponse @@ -367,15 +369,82 @@ class TestOAuthClientAuthorization(ApiTestCaseMixin): ) +class TestOAuthClientAuthorizationWithCodeChallenge(ApiTestCaseMixin): + route = '/api/oauth/authorize' + + def test_it_returns_error_when_code_challenge_method_is_invalid( + self, app: Flask, user_1: User + ) -> None: + client, auth_token = self.get_test_client_and_auth_token( + app, user_1.email + ) + oauth_client = self.create_oauth_client(user_1) + code_verifier = generate_token(48) + code_challenge = create_s256_code_challenge(code_verifier) + + response = client.post( + self.route, + data={ + 'client_id': oauth_client.client_id, + 'response_type': 'code', + 'confirm': 'true', + 'code_challenge': code_challenge, + 'code_challenge_method': self.random_string(), + }, + headers=dict( + Authorization=f'Bearer {auth_token}', + content_type='multipart/form-data', + ), + ) + + self.assert_400(response, 'Unsupported "code_challenge_method"') + + def test_it_creates_authorization_code( + self, app: Flask, user_1: User + ) -> None: + client, auth_token = self.get_test_client_and_auth_token( + app, user_1.email + ) + oauth_client = self.create_oauth_client(user_1) + code_verifier = generate_token(48) + code_challenge = create_s256_code_challenge(code_verifier) + + response = client.post( + self.route, + data={ + 'client_id': oauth_client.client_id, + 'response_type': 'code', + 'confirm': 'true', + 'code_challenge': code_challenge, + 'code_challenge_method': 'S256', + }, + headers=dict( + Authorization=f'Bearer {auth_token}', + content_type='multipart/form-data', + ), + ) + + code = OAuth2AuthorizationCode.query.filter_by( + client_id=oauth_client.client_id + ).first() + assert response.status_code == 200 + data = json.loads(response.data.decode()) + assert data['redirect_url'] == ( + f'{oauth_client.get_default_redirect_uri()}?code={code.code}' + ) + + class OAuthIssueTokenTestCase(ApiTestCaseMixin): def create_authorized_oauth_client( - self, app: Flask, user_1: User + self, app: Flask, user_1: User, code_challenge: Optional[Dict] = None ) -> Tuple[OAuth2Client, Union[List[str], str]]: client, auth_token = self.get_test_client_and_auth_token( app, user_1.email ) oauth_client = self.create_oauth_client(user_1) - code = self.authorize_client(client, oauth_client, auth_token) + code = self.authorize_client( + client, oauth_client, auth_token, code_challenge=code_challenge + ) return oauth_client, code @staticmethod @@ -497,6 +566,64 @@ class TestOAuthIssueAccessToken(OAuthIssueTokenTestCase): self.assert_token_is_returned(response) +class TestOAuthIssueAccessTokenWithCodeChallenge(OAuthIssueTokenTestCase): + route = '/api/oauth/token' + + def test_it_raises_error_when_code_verifier_is_invalid( + self, app: Flask, user_1: User + ) -> None: + code_verifier = generate_token(48) + code_challenge = create_s256_code_challenge(code_verifier) + oauth_client, code = self.create_authorized_oauth_client( + app, + user_1, + { + 'code_challenge': code_challenge, + 'code_challenge_method': 'S256', + }, + ) + client = app.test_client() + + response = client.post( + self.route, + data={ + 'client_id': oauth_client.client_id, + 'client_secret': oauth_client.client_secret, + 'grant_type': 'authorization_code', + 'code': code, + 'code_verifier': self.random_string(), + }, + headers=dict(content_type='multipart/form-data'), + ) + self.assert_invalid_request(response, 'Invalid "code_verifier"') + + def test_it_returns_access_token(self, app: Flask, user_1: User) -> None: + code_verifier = generate_token(48) + oauth_client, code = self.create_authorized_oauth_client( + app, + user_1, + code_challenge={ + 'code_challenge': create_s256_code_challenge(code_verifier), + 'code_challenge_method': 'S256', + }, + ) + client = app.test_client() + + response = client.post( + self.route, + data={ + 'client_id': oauth_client.client_id, + 'client_secret': oauth_client.client_secret, + 'grant_type': 'authorization_code', + 'code': code, + 'code_verifier': code_verifier, + }, + headers=dict(content_type='multipart/form-data'), + ) + + self.assert_token_is_returned(response) + + class TestOAuthIssueRefreshToken(OAuthIssueTokenTestCase): route = '/api/oauth/token'