2022-09-19 14:05:30 +02:00

321 lines
9.6 KiB
Python

import json
import time
from typing import Dict, List, Optional, Tuple, Union
from urllib.parse import parse_qs
from flask import Flask
from flask.testing import FlaskClient
from urllib3.util import parse_url
from werkzeug.test import TestResponse
from fittrackee import db
from fittrackee.oauth2.client import create_oauth2_client
from fittrackee.oauth2.models import OAuth2Client, OAuth2Token
from fittrackee.users.models import User
from .custom_asserts import (
assert_errored_response,
assert_oauth_errored_response,
)
from .utils import (
TEST_OAUTH_CLIENT_METADATA,
random_email,
random_int,
random_string,
)
class RandomMixin:
@staticmethod
def random_string(
length: Optional[int] = None,
prefix: Optional[str] = None,
suffix: Optional[str] = None,
) -> str:
return random_string(length, prefix, suffix)
@staticmethod
def random_domain() -> str:
return random_string(prefix='https://', suffix='com')
@staticmethod
def random_email() -> str:
return random_email()
@staticmethod
def random_int(min_val: int = 0, max_val: int = 999999) -> int:
return random_int(min_val, max_val)
class OAuth2Mixin(RandomMixin):
@staticmethod
def create_oauth2_client(
user: User,
metadata: Optional[Dict] = None,
scope: Optional[str] = None,
) -> OAuth2Client:
client_metadata = (
TEST_OAUTH_CLIENT_METADATA if metadata is None else metadata
)
if scope is not None:
client_metadata['scope'] = scope
oauth_client = create_oauth2_client(client_metadata, user)
db.session.add(oauth_client)
db.session.commit()
return oauth_client
def create_oauth2_token(
self,
oauth_client: OAuth2Client,
issued_at: Optional[int] = None,
access_token_revoked_at: Optional[int] = 0,
expires_in: Optional[int] = 1000,
) -> OAuth2Token:
issued_at = issued_at if issued_at else int(time.time())
token = OAuth2Token(
client_id=oauth_client.client_id,
access_token=self.random_string(),
refresh_token=self.random_string(),
issued_at=issued_at,
access_token_revoked_at=access_token_revoked_at,
expires_in=expires_in,
)
db.session.add(token)
db.session.commit()
return token
class ApiTestCaseMixin(OAuth2Mixin, RandomMixin):
@staticmethod
def get_test_client_and_auth_token(
app: Flask, user_email: str
) -> Tuple[FlaskClient, str]:
client = app.test_client()
resp_login = client.post(
'/api/auth/login',
data=json.dumps(
dict(
email=user_email,
password='12345678',
)
),
content_type='application/json',
)
auth_token = json.loads(resp_login.data.decode())['auth_token']
return client, auth_token
@staticmethod
def authorize_client(
client: FlaskClient,
oauth_client: OAuth2Client,
auth_token: str,
scope: Optional[str] = None,
code_challenge: Optional[Dict] = None,
) -> Union[List[str], str]:
if code_challenge is None:
code_challenge = {}
response = client.post(
'/api/oauth/authorize',
data={
'client_id': oauth_client.client_id,
'confirm': True,
'response_type': 'code',
'scope': 'read' if not scope else scope,
**code_challenge,
},
headers=dict(
Authorization=f'Bearer {auth_token}',
content_type='multipart/form-data',
),
)
data = json.loads(response.data.decode())
parsed_url = parse_url(data['redirect_url'])
code = parse_qs(parsed_url.query).get('code', '')
return code
def create_oauth2_client_and_issue_token(
self, app: Flask, user: User, scope: Optional[str] = None
) -> Tuple[FlaskClient, OAuth2Client, str, str]:
client, auth_token = self.get_test_client_and_auth_token(
app, user.email
)
oauth_client = self.create_oauth2_client(user, scope=scope)
code = self.authorize_client(
client, oauth_client, auth_token, scope=scope
)
response = client.post(
'/api/oauth/token',
data={
'client_id': oauth_client.client_id,
'client_secret': oauth_client.client_secret,
'grant_type': 'authorization_code',
'code': code,
},
headers=dict(content_type='multipart/form-data'),
)
data = json.loads(response.data.decode())
return client, oauth_client, data.get('access_token'), auth_token
@staticmethod
def assert_400(
response: TestResponse,
error_message: Optional[str] = 'invalid payload',
status: Optional[str] = 'error',
) -> Dict:
return assert_errored_response(
response, 400, error_message=error_message, status=status
)
@staticmethod
def assert_401(
response: TestResponse,
error_message: Optional[str] = 'provide a valid auth token',
) -> Dict:
return assert_errored_response(
response, 401, error_message=error_message
)
@staticmethod
def assert_403(
response: TestResponse,
error_message: Optional[str] = 'you do not have permissions',
) -> Dict:
return assert_errored_response(response, 403, error_message)
@staticmethod
def assert_404(response: TestResponse) -> Dict:
return assert_errored_response(response, 404, status='not found')
@staticmethod
def assert_404_with_entity(response: TestResponse, entity: str) -> Dict:
error_message = f'{entity} does not exist'
return assert_errored_response(
response, 404, error_message=error_message, status='not found'
)
@staticmethod
def assert_404_with_message(
response: TestResponse, error_message: str
) -> Dict:
return assert_errored_response(
response, 404, error_message=error_message, status='not found'
)
@staticmethod
def assert_413(
response: TestResponse,
error_message: Optional[str] = None,
match: Optional[str] = None,
) -> Dict:
return assert_errored_response(
response,
413,
error_message=error_message,
status='fail',
match=match,
)
@staticmethod
def assert_500(
response: TestResponse,
error_message: Optional[str] = (
'error, please try again or contact the administrator'
),
status: Optional[str] = 'error',
) -> Dict:
return assert_errored_response(
response, 500, error_message=error_message, status=status
)
@staticmethod
def assert_unsupported_grant_type(response: TestResponse) -> Dict:
return assert_oauth_errored_response(
response, 400, error='unsupported_grant_type'
)
@staticmethod
def assert_invalid_client(response: TestResponse) -> Dict:
return assert_oauth_errored_response(
response,
400,
error='invalid_client',
)
@staticmethod
def assert_invalid_grant(
response: TestResponse, error_description: Optional[str] = None
) -> Dict:
return assert_oauth_errored_response(
response,
400,
error='invalid_grant',
error_description=error_description,
)
@staticmethod
def assert_invalid_request(
response: TestResponse, error_description: Optional[str] = None
) -> Dict:
return assert_oauth_errored_response(
response,
400,
error='invalid_request',
error_description=error_description,
)
@staticmethod
def assert_invalid_token(response: TestResponse) -> Dict:
return assert_oauth_errored_response(
response,
401,
error='invalid_token',
error_description=(
'The access token provided is expired, revoked, malformed, '
'or invalid for other reasons.'
),
)
@staticmethod
def assert_insufficient_scope(response: TestResponse) -> Dict:
return assert_oauth_errored_response(
response,
403,
error='insufficient_scope',
error_description=(
'The request requires higher privileges than provided by '
'the access token.'
),
)
@staticmethod
def assert_not_insufficient_scope_error(response: TestResponse) -> None:
assert response.status_code != 403
def assert_response_scope(
self, response: TestResponse, can_access: bool
) -> None:
if can_access:
self.assert_not_insufficient_scope_error(response)
else:
self.assert_insufficient_scope(response)
class CallArgsMixin:
"""call args are returned differently between Python 3.7 and 3.7+"""
@staticmethod
def get_args(call_args: Tuple) -> Tuple:
if len(call_args) == 2:
args, _ = call_args
else:
_, args, _ = call_args
return args
@staticmethod
def get_kwargs(call_args: Tuple) -> Dict:
if len(call_args) == 2:
_, kwargs = call_args
else:
_, _, kwargs = call_args
return kwargs