From c6cd7ff67c795b3608e234fe948044d42ff2b325 Mon Sep 17 00:00:00 2001 From: Sam Date: Fri, 27 May 2022 13:28:26 +0200 Subject: [PATCH] API - init OAuth server and oauth clients creation --- fittrackee/__init__.py | 12 ++ .../versions/24_84d840ce853b_add_oauth.py | 76 ++++++++ fittrackee/oauth2/__init__.py | 0 fittrackee/oauth2/client.py | 35 ++++ fittrackee/oauth2/config.py | 14 ++ fittrackee/oauth2/grants.py | 65 +++++++ fittrackee/oauth2/models.py | 59 +++++++ fittrackee/oauth2/routes.py | 53 ++++++ fittrackee/oauth2/server.py | 16 ++ fittrackee/tests/conftest.py | 1 + fittrackee/tests/mixins.py | 9 + fittrackee/tests/oauth2/__init__.py | 0 fittrackee/tests/oauth2/test_oauth2_client.py | 137 +++++++++++++++ fittrackee/tests/oauth2/test_oauth2_models.py | 36 ++++ fittrackee/tests/oauth2/test_oauth2_routes.py | 163 ++++++++++++++++++ fittrackee/tests/utils.py | 4 + fittrackee/users/models.py | 3 + poetry.lock | 17 +- pyproject.toml | 1 + 19 files changed, 700 insertions(+), 1 deletion(-) create mode 100644 fittrackee/migrations/versions/24_84d840ce853b_add_oauth.py create mode 100644 fittrackee/oauth2/__init__.py create mode 100644 fittrackee/oauth2/client.py create mode 100644 fittrackee/oauth2/config.py create mode 100644 fittrackee/oauth2/grants.py create mode 100644 fittrackee/oauth2/models.py create mode 100644 fittrackee/oauth2/routes.py create mode 100644 fittrackee/oauth2/server.py create mode 100644 fittrackee/tests/oauth2/__init__.py create mode 100644 fittrackee/tests/oauth2/test_oauth2_client.py create mode 100644 fittrackee/tests/oauth2/test_oauth2_models.py create mode 100644 fittrackee/tests/oauth2/test_oauth2_routes.py diff --git a/fittrackee/__init__.py b/fittrackee/__init__.py index 9fcc0772..4d0dafcc 100644 --- a/fittrackee/__init__.py +++ b/fittrackee/__init__.py @@ -16,6 +16,7 @@ from flask_dramatiq import Dramatiq from flask_migrate import Migrate from flask_sqlalchemy import SQLAlchemy from sqlalchemy.exc import ProgrammingError +from werkzeug.middleware.proxy_fix import ProxyFix from fittrackee.emails.email import EmailService from fittrackee.request import CustomRequest @@ -64,6 +65,11 @@ def create_app(init_email: bool = True) -> Flask: migrate.init_app(app, db) dramatiq.init_app(app) + # set oauth2 + from fittrackee.oauth2.config import config_oauth + + config_oauth(app) + # set up email if 'EMAIL_URL' is initialized if init_email: if app.config['EMAIL_URL']: @@ -95,6 +101,7 @@ def create_app(init_email: bool = True) -> Flask: pass from .application.app_config import config_blueprint # noqa + from .oauth2.routes import oauth_blueprint # noqa from .users.auth import auth_blueprint # noqa from .users.users import users_blueprint # noqa from .workouts.records import records_blueprint # noqa @@ -103,6 +110,7 @@ def create_app(init_email: bool = True) -> Flask: from .workouts.workouts import workouts_blueprint # noqa app.register_blueprint(auth_blueprint, url_prefix='/api') + app.register_blueprint(oauth_blueprint, url_prefix='/api') app.register_blueprint(config_blueprint, url_prefix='/api') app.register_blueprint(records_blueprint, url_prefix='/api') app.register_blueprint(sports_blueprint, url_prefix='/api') @@ -153,4 +161,8 @@ def create_app(init_email: bool = True) -> Flask: else: return render_template('index.html') + # to get headers, especially 'X-Forwarded-Proto' for scheme needed by + # Authlib, when the application is running behind a proxy server + app.wsgi_app = ProxyFix(app.wsgi_app) # type: ignore + return app diff --git a/fittrackee/migrations/versions/24_84d840ce853b_add_oauth.py b/fittrackee/migrations/versions/24_84d840ce853b_add_oauth.py new file mode 100644 index 00000000..7bcdac55 --- /dev/null +++ b/fittrackee/migrations/versions/24_84d840ce853b_add_oauth.py @@ -0,0 +1,76 @@ +"""add OAuth 2.0 + +Revision ID: 84d840ce853b +Revises: 5e3a3a31c432 +Create Date: 2022-05-27 10:54:02.284543 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '84d840ce853b' +down_revision = '5e3a3a31c432' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('oauth2_client', + sa.Column('client_id', sa.String(length=48), nullable=True), + sa.Column('client_secret', sa.String(length=120), nullable=True), + sa.Column('client_id_issued_at', sa.Integer(), nullable=False), + sa.Column('client_secret_expires_at', sa.Integer(), nullable=False), + sa.Column('client_metadata', sa.Text(), nullable=True), + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_oauth2_client_client_id'), 'oauth2_client', ['client_id'], unique=False) + op.create_table('oauth2_code', + sa.Column('code', sa.String(length=120), nullable=False), + sa.Column('client_id', sa.String(length=48), nullable=True), + sa.Column('redirect_uri', sa.Text(), nullable=True), + sa.Column('response_type', sa.Text(), nullable=True), + sa.Column('scope', sa.Text(), nullable=True), + sa.Column('nonce', sa.Text(), nullable=True), + sa.Column('auth_time', sa.Integer(), nullable=False), + sa.Column('code_challenge', sa.Text(), nullable=True), + sa.Column('code_challenge_method', sa.String(length=48), nullable=True), + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('code') + ) + op.create_table('oauth2_token', + sa.Column('client_id', sa.String(length=48), nullable=True), + sa.Column('token_type', sa.String(length=40), nullable=True), + sa.Column('access_token', sa.String(length=255), nullable=False), + sa.Column('refresh_token', sa.String(length=255), nullable=True), + sa.Column('scope', sa.Text(), nullable=True), + sa.Column('issued_at', sa.Integer(), nullable=False), + sa.Column('access_token_revoked_at', sa.Integer(), nullable=False), + sa.Column('refresh_token_revoked_at', sa.Integer(), nullable=False), + sa.Column('expires_in', sa.Integer(), nullable=False), + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('access_token') + ) + op.create_index(op.f('ix_oauth2_token_refresh_token'), 'oauth2_token', ['refresh_token'], unique=False) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_oauth2_token_refresh_token'), table_name='oauth2_token') + op.drop_table('oauth2_token') + op.drop_table('oauth2_code') + op.drop_index(op.f('ix_oauth2_client_client_id'), table_name='oauth2_client') + op.drop_table('oauth2_client') + # ### end Alembic commands ### diff --git a/fittrackee/oauth2/__init__.py b/fittrackee/oauth2/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fittrackee/oauth2/client.py b/fittrackee/oauth2/client.py new file mode 100644 index 00000000..e07e6639 --- /dev/null +++ b/fittrackee/oauth2/client.py @@ -0,0 +1,35 @@ +from time import time +from typing import Dict + +from werkzeug.security import gen_salt + +from fittrackee.oauth2.models import OAuth2Client +from fittrackee.users.models import User + + +def create_oauth_client(metadata: Dict, user: User) -> OAuth2Client: + """ + Create oauth client for 3rd-party applications. + + Only Authorization Code Grant with 'client_secret_post' as method + is supported. + """ + client_metadata = { + 'client_name': metadata['client_name'], + 'client_uri': metadata['client_uri'], + 'redirect_uris': metadata['redirect_uris'], + 'scope': metadata['scope'], + 'grant_types': ['authorization_code'], + 'response_types': ['code'], + 'token_endpoint_auth_method': 'client_secret_post', + } + client_id = gen_salt(24) + client_id_issued_at = int(time()) + client = OAuth2Client( + client_id=client_id, + client_id_issued_at=client_id_issued_at, + user_id=user.id, + ) + client.set_client_metadata(client_metadata) + client.client_secret = gen_salt(48) + return client diff --git a/fittrackee/oauth2/config.py b/fittrackee/oauth2/config.py new file mode 100644 index 00000000..8f7bb4ea --- /dev/null +++ b/fittrackee/oauth2/config.py @@ -0,0 +1,14 @@ +from authlib.oauth2.rfc7636 import CodeChallenge +from flask import Flask + +from .grants import AuthorizationCodeGrant +from .server import authorization_server + + +def config_oauth(app: Flask) -> None: + authorization_server.init_app(app) + + # supported grants + authorization_server.register_grant( + AuthorizationCodeGrant, [CodeChallenge(required=True)] + ) diff --git a/fittrackee/oauth2/grants.py b/fittrackee/oauth2/grants.py new file mode 100644 index 00000000..e99f8e38 --- /dev/null +++ b/fittrackee/oauth2/grants.py @@ -0,0 +1,65 @@ +from typing import Optional + +from authlib.oauth2 import OAuth2Request +from authlib.oauth2.rfc6749 import grants + +from fittrackee import db +from fittrackee.users.models import User + +from .models import OAuth2AuthorizationCode, OAuth2Client, OAuth2Token + + +class AuthorizationCodeGrant(grants.AuthorizationCodeGrant): + TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_post'] + + def save_authorization_code( + self, code: str, request: OAuth2Request + ) -> OAuth2AuthorizationCode: + code_challenge = request.data.get('code_challenge') + code_challenge_method = request.data.get('code_challenge_method') + auth_code = OAuth2AuthorizationCode( + code=code, + client_id=request.client.client_id, + redirect_uri=request.redirect_uri, + scope=request.scope, + user_id=request.user.id, + code_challenge=code_challenge, + code_challenge_method=code_challenge_method, + ) + db.session.add(auth_code) + db.session.commit() + return auth_code + + def query_authorization_code( + self, code: str, client: OAuth2Client + ) -> Optional[OAuth2AuthorizationCode]: + auth_code = OAuth2AuthorizationCode.query.filter_by( + code=code, client_id=client.client_id + ).first() + if auth_code and not auth_code.is_expired(): + return auth_code + return None + + def delete_authorization_code( + self, authorization_code: OAuth2AuthorizationCode + ) -> None: + db.session.delete(authorization_code) + db.session.commit() + + def authenticate_user( + self, authorization_code: OAuth2AuthorizationCode + ) -> User: + return User.query.get(authorization_code.user_id) + + +class RefreshTokenGrant(grants.RefreshTokenGrant): + def authenticate_refresh_token(self, refresh_token: str) -> Optional[str]: + token = OAuth2Token.query.filter_by( + refresh_token=refresh_token + ).first() + if token and token.is_refresh_token_active(): + return token + return None + + def authenticate_user(self, credential: OAuth2Token) -> User: + return User.query.get(credential.user_id) diff --git a/fittrackee/oauth2/models.py b/fittrackee/oauth2/models.py new file mode 100644 index 00000000..aa3641d7 --- /dev/null +++ b/fittrackee/oauth2/models.py @@ -0,0 +1,59 @@ +import time +from typing import Dict + +from authlib.integrations.sqla_oauth2 import ( + OAuth2AuthorizationCodeMixin, + OAuth2ClientMixin, + OAuth2TokenMixin, +) +from sqlalchemy.ext.declarative import DeclarativeMeta + +from fittrackee import db + +BaseModel: DeclarativeMeta = db.Model + + +class OAuth2Client(BaseModel, OAuth2ClientMixin): + __tablename__ = 'oauth2_client' + + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column( + db.Integer, db.ForeignKey('users.id', ondelete='CASCADE') + ) + user = db.relationship('User') + + def serialize(self) -> Dict: + return { + 'client_id': self.client_id, + 'client_secret': self.client_secret, + 'id': self.id, + 'name': self.client_name, + 'redirect_uris': self.redirect_uris, + 'website': self.client_uri, + } + + +class OAuth2AuthorizationCode(BaseModel, OAuth2AuthorizationCodeMixin): + __tablename__ = 'oauth2_code' + + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column( + db.Integer, db.ForeignKey('users.id', ondelete='CASCADE') + ) + user = db.relationship('User') + + +class OAuth2Token(BaseModel, OAuth2TokenMixin): + __tablename__ = 'oauth2_token' + + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column( + db.Integer, db.ForeignKey('users.id', ondelete='CASCADE') + ) + user = db.relationship('User') + + def is_refresh_token_active(self) -> bool: + if self.revoked: + return False + expires_at = self.issued_at + self.expires_in * 2 + return expires_at >= time.time() diff --git a/fittrackee/oauth2/routes.py b/fittrackee/oauth2/routes.py new file mode 100644 index 00000000..f913e097 --- /dev/null +++ b/fittrackee/oauth2/routes.py @@ -0,0 +1,53 @@ +from typing import Dict, Tuple, Union + +from flask import Blueprint, request + +from fittrackee import db +from fittrackee.responses import HttpResponse, InvalidPayloadErrorResponse +from fittrackee.users.decorators import authenticate +from fittrackee.users.models import User + +from .client import create_oauth_client + +oauth_blueprint = Blueprint('oauth', __name__) + +EXPECTED_METADATA_KEYS = [ + 'client_name', + 'client_uri', + 'redirect_uris', + 'scope', +] + + +@oauth_blueprint.route('/oauth/apps', methods=['POST']) +@authenticate +def create_client(auth_user: User) -> Union[HttpResponse, Tuple[Dict, int]]: + client_metadata = request.get_json() + if not client_metadata: + return InvalidPayloadErrorResponse( + message='OAuth client metadata missing' + ) + + missing_keys = [ + key + for key in EXPECTED_METADATA_KEYS + if key not in client_metadata.keys() + ] + if missing_keys: + return InvalidPayloadErrorResponse( + message=( + 'OAuth client metadata missing keys: ' + f'{", ".join(missing_keys)}' + ) + ) + + new_client = create_oauth_client(client_metadata, auth_user) + db.session.add(new_client) + db.session.commit() + return ( + { + 'status': 'created', + 'data': {'client': new_client.serialize()}, + }, + 201, + ) diff --git a/fittrackee/oauth2/server.py b/fittrackee/oauth2/server.py new file mode 100644 index 00000000..2ec8ff6e --- /dev/null +++ b/fittrackee/oauth2/server.py @@ -0,0 +1,16 @@ +from authlib.integrations.flask_oauth2 import AuthorizationServer +from authlib.integrations.sqla_oauth2 import ( + create_query_client_func, + create_save_token_func, +) + +from fittrackee import db + +from .models import OAuth2Client, OAuth2Token + +query_client = create_query_client_func(db.session, OAuth2Client) +save_token = create_save_token_func(db.session, OAuth2Token) +authorization_server = AuthorizationServer( + query_client=query_client, + save_token=save_token, +) diff --git a/fittrackee/tests/conftest.py b/fittrackee/tests/conftest.py index 2b643709..e8e959da 100644 --- a/fittrackee/tests/conftest.py +++ b/fittrackee/tests/conftest.py @@ -10,6 +10,7 @@ os.environ['DATABASE_URL'] = os.environ['DATABASE_TEST_URL'] TEMP_FOLDER = '/tmp/FitTrackee' os.environ['UPLOAD_FOLDER'] = TEMP_FOLDER os.environ['APP_LOG'] = TEMP_FOLDER + '/fittrackee.log' +os.environ['AUTHLIB_INSECURE_TRANSPORT'] = '1' pytest_plugins = [ 'fittrackee.tests.fixtures.fixtures_app', diff --git a/fittrackee/tests/mixins.py b/fittrackee/tests/mixins.py index 22eb89e6..63c7c8a1 100644 --- a/fittrackee/tests/mixins.py +++ b/fittrackee/tests/mixins.py @@ -1,4 +1,5 @@ import json +from random import randint from typing import Any, Dict, Optional, Tuple from flask import Flask @@ -18,10 +19,18 @@ class RandomMixin: ) -> str: return random_string(length, prefix, suffix) + @staticmethod + def random_domain() -> str: + return random_string(prefix='https://', suffix='com') + @staticmethod def random_email() -> str: return random_email() + @staticmethod + def random_int(min_val: int = 0, max_val: int = 999999) -> int: + return randint(min_val, max_val) + class ApiTestCaseMixin(RandomMixin): @staticmethod diff --git a/fittrackee/tests/oauth2/__init__.py b/fittrackee/tests/oauth2/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fittrackee/tests/oauth2/test_oauth2_client.py b/fittrackee/tests/oauth2/test_oauth2_client.py new file mode 100644 index 00000000..b8638f16 --- /dev/null +++ b/fittrackee/tests/oauth2/test_oauth2_client.py @@ -0,0 +1,137 @@ +from time import time +from typing import Dict +from unittest.mock import patch + +from flask import Flask + +from fittrackee.oauth2.client import create_oauth_client +from fittrackee.oauth2.models import OAuth2Client +from fittrackee.users.models import User + +from ..utils import random_domain, random_string + +TEST_METADATA = { + 'client_name': random_string(), + 'client_uri': random_string(), + 'redirect_uris': [random_domain()], + 'scope': 'read write', +} + + +class TestCreateOAuth2Client: + def test_it_creates_oauth_client(self, app: Flask, user_1: User) -> None: + oauth_client = create_oauth_client(TEST_METADATA, user_1) + + assert isinstance(oauth_client, OAuth2Client) + + def test_oauth_client_id_is_generated_with_gen_salt( + self, app: Flask, user_1: User + ) -> None: + client_id = random_string() + with patch( + 'fittrackee.oauth2.client.gen_salt', return_value=client_id + ): + + oauth_client = create_oauth_client(TEST_METADATA, user_1) + + assert oauth_client.client_id == client_id + + def test_oauth_client_client_id_issued_at_is_initialized( + self, app: Flask, user_1: User + ) -> None: + client_id_issued_at = int(time()) + with patch( + 'fittrackee.oauth2.client.time', return_value=client_id_issued_at + ): + oauth_client = create_oauth_client(TEST_METADATA, user_1) + + assert oauth_client.client_id_issued_at == client_id_issued_at + + def test_oauth_client_has_expected_name( + self, app: Flask, user_1: User + ) -> None: + client_name = random_string() + client_metadata: Dict = {**TEST_METADATA, 'client_name': client_name} + + oauth_client = create_oauth_client(client_metadata, user_1) + + assert oauth_client.client_name == client_name + + def test_oauth_client_has_expected_client_uri( + self, app: Flask, user_1: User + ) -> None: + client_uri = random_domain() + client_metadata: Dict = {**TEST_METADATA, 'client_uri': client_uri} + + oauth_client = create_oauth_client(client_metadata, user_1) + + assert oauth_client.client_uri == client_uri + + def test_oauth_client_has_expected_grant_types( + self, app: Flask, user_1: User + ) -> None: + oauth_client = create_oauth_client(TEST_METADATA, user_1) + + assert oauth_client.grant_types == ['authorization_code'] + + def test_oauth_client_has_expected_redirect_uris( + self, app: Flask, user_1: User + ) -> None: + redirect_uris = [random_domain()] + client_metadata: Dict = { + **TEST_METADATA, + 'redirect_uris': redirect_uris, + } + + oauth_client = create_oauth_client(client_metadata, user_1) + + assert oauth_client.redirect_uris == redirect_uris + + def test_oauth_client_has_expected_response_types( + self, app: Flask, user_1: User + ) -> None: + response_types = ['code'] + client_metadata: Dict = { + **TEST_METADATA, + 'response_types': response_types, + } + + oauth_client = create_oauth_client(client_metadata, user_1) + + assert oauth_client.response_types == ['code'] + + def test_oauth_client_has_expected_scope( + self, app: Flask, user_1: User + ) -> None: + scope = 'profile' + client_metadata: Dict = {**TEST_METADATA, 'scope': scope} + + oauth_client = create_oauth_client(client_metadata, user_1) + + assert oauth_client.scope == scope + + def test_oauth_client_has_expected_token_endpoint_auth_method( + self, app: Flask, user_1: User + ) -> None: + oauth_client = create_oauth_client(TEST_METADATA, user_1) + + assert oauth_client.token_endpoint_auth_method == 'client_secret_post' + + def test_when_auth_method_is_not_none_oauth_client_secret_is_generated( + self, app: Flask, user_1: User + ) -> None: + client_secret = random_string() + with patch( + 'fittrackee.oauth2.client.gen_salt', return_value=client_secret + ): + + oauth_client = create_oauth_client(TEST_METADATA, user_1) + + assert oauth_client.client_secret == client_secret + + def test_it_creates_oauth_client_for_given_user( + self, app: Flask, user_1: User + ) -> None: + oauth_client = create_oauth_client(TEST_METADATA, user_1) + + assert oauth_client.user_id == user_1.id diff --git a/fittrackee/tests/oauth2/test_oauth2_models.py b/fittrackee/tests/oauth2/test_oauth2_models.py new file mode 100644 index 00000000..000ceb77 --- /dev/null +++ b/fittrackee/tests/oauth2/test_oauth2_models.py @@ -0,0 +1,36 @@ +from flask import Flask + +from fittrackee.oauth2.models import OAuth2Client + +from ..mixins import RandomMixin + + +class TestOAuthClientSerialize(RandomMixin): + def test_it_returns_oauth_client(self, app: Flask) -> None: + oauth_client = OAuth2Client( + id=self.random_int(), + client_id=self.random_string(), + client_id_issued_at=self.random_int(), + ) + oauth_client.set_client_metadata( + { + 'client_name': self.random_string(), + 'redirect_uris': [self.random_string()], + 'client_uri': self.random_domain(), + } + ) + + serialized_oauth_client = oauth_client.serialize() + + assert serialized_oauth_client['client_id'] == oauth_client.client_id + assert ( + serialized_oauth_client['client_secret'] + == oauth_client.client_secret + ) + assert serialized_oauth_client['id'] == oauth_client.id + assert serialized_oauth_client['name'] == oauth_client.client_name + assert ( + serialized_oauth_client['redirect_uris'] + == oauth_client.redirect_uris + ) + assert serialized_oauth_client['website'] == oauth_client.client_uri diff --git a/fittrackee/tests/oauth2/test_oauth2_routes.py b/fittrackee/tests/oauth2/test_oauth2_routes.py new file mode 100644 index 00000000..169d1074 --- /dev/null +++ b/fittrackee/tests/oauth2/test_oauth2_routes.py @@ -0,0 +1,163 @@ +import json +from typing import List, Union +from unittest.mock import patch + +import pytest +from flask import Flask + +from fittrackee.oauth2.models import OAuth2Client +from fittrackee.users.models import User + +from ..mixins import ApiTestCaseMixin +from ..utils import random_domain, random_string + +TEST_METADATA = { + 'client_name': random_string(), + 'client_uri': random_domain(), + 'redirect_uris': [random_domain()], + 'scope': 'read write', +} + + +class TestOAuthClientCreation(ApiTestCaseMixin): + route = '/api/oauth/apps' + + def test_it_returns_error_when_no_user_authenticated( + self, app: Flask, user_1: User + ) -> None: + client, auth_token = self.get_test_client_and_auth_token( + app, user_1.email + ) + + response = client.post( + self.route, + data=json.dumps(TEST_METADATA), + content_type='application/json', + ) + + self.assert_401(response) + + def test_it_returns_error_when_no_metadata_provided( + self, app: Flask, user_1: User + ) -> None: + client, auth_token = self.get_test_client_and_auth_token( + app, user_1.email + ) + + response = client.post( + self.route, + data=json.dumps(dict()), + content_type='application/json', + headers=dict(Authorization=f'Bearer {auth_token}'), + ) + + self.assert_400( + response, error_message='OAuth client metadata missing' + ) + + @pytest.mark.parametrize( + 'missing_key', + [ + 'client_name', + 'client_uri', + 'redirect_uris', + 'scope', + ], + ) + def test_it_returns_error_when_metadata_key_is_missing( + self, app: Flask, user_1: User, missing_key: str + ) -> None: + metadata = TEST_METADATA.copy() + del metadata[missing_key] + client, auth_token = self.get_test_client_and_auth_token( + app, user_1.email + ) + + response = client.post( + self.route, + data=json.dumps(metadata), + content_type='application/json', + headers=dict(Authorization=f'Bearer {auth_token}'), + ) + + self.assert_400( + response, + error_message=f'OAuth client metadata missing keys: {missing_key}', + ) + + def test_it_creates_oauth_client(self, app: Flask, user_1: User) -> None: + client, auth_token = self.get_test_client_and_auth_token( + app, user_1.email + ) + + response = client.post( + self.route, + data=json.dumps(TEST_METADATA), + content_type='application/json', + headers=dict(Authorization=f'Bearer {auth_token}'), + ) + + assert response.status_code == 201 + oauth_client = OAuth2Client.query.first() + assert oauth_client is not None + + def test_it_returns_serialized_oauth_client( + self, app: Flask, user_1: User + ) -> None: + client, auth_token = self.get_test_client_and_auth_token( + app, user_1.email + ) + client_id = self.random_string() + client_secret = self.random_string() + with patch( + 'fittrackee.oauth2.client.gen_salt', + side_effect=[client_id, client_secret], + ): + + response = client.post( + self.route, + data=json.dumps(TEST_METADATA), + content_type='application/json', + headers=dict(Authorization=f'Bearer {auth_token}'), + ) + + data = json.loads(response.data.decode()) + assert data['data']['client']['client_id'] == client_id + assert data['data']['client']['client_secret'] == client_secret + assert data['data']['client']['name'] == TEST_METADATA['client_name'] + assert ( + data['data']['client']['redirect_uris'] + == TEST_METADATA['redirect_uris'] + ) + assert data['data']['client']['website'] == TEST_METADATA['client_uri'] + + @pytest.mark.parametrize( + 'input_key,expected_value', + [ + ('grant_types', ['authorization_code']), + ('response_types', ['code']), + ('token_endpoint_auth_method', 'client_secret_post'), + ], + ) + def test_it_always_create_oauth_client_with_authorization_grant( + self, + app: Flask, + user_1: User, + input_key: str, + expected_value: Union[List, str], + ) -> None: + client, auth_token = self.get_test_client_and_auth_token( + app, user_1.email + ) + + client.post( + self.route, + data=json.dumps( + {**TEST_METADATA, input_key: self.random_string()} + ), + content_type='application/json', + headers=dict(Authorization=f'Bearer {auth_token}'), + ) + + oauth_client = OAuth2Client.query.first() + assert getattr(oauth_client, input_key) == expected_value diff --git a/fittrackee/tests/utils.py b/fittrackee/tests/utils.py index a93db3cf..a2bfb460 100644 --- a/fittrackee/tests/utils.py +++ b/fittrackee/tests/utils.py @@ -24,6 +24,10 @@ def random_string( ) +def random_domain() -> str: + return random_string(prefix='https://', suffix='.com') + + def random_email() -> str: return random_string(suffix='@example.com') diff --git a/fittrackee/users/models.py b/fittrackee/users/models.py index 19aadbe1..5ad33f8b 100644 --- a/fittrackee/users/models.py +++ b/fittrackee/users/models.py @@ -111,6 +111,9 @@ class User(BaseModel): new_password, current_app.config.get('BCRYPT_LOG_ROUNDS') ).decode() + def get_user_id(self) -> int: + return self.id + @hybrid_property def workouts_count(self) -> int: return Workout.query.filter(Workout.user_id == self.id).count() diff --git a/poetry.lock b/poetry.lock index ff8f4c4b..d420fcde 100644 --- a/poetry.lock +++ b/poetry.lock @@ -64,6 +64,17 @@ docs = ["furo", "sphinx", "zope.interface", "sphinx-notfound-page"] tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "cloudpickle"] tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "cloudpickle"] +[[package]] +name = "authlib" +version = "1.0.1" +description = "The ultimate Python library in building OAuth and OpenID Connect servers and clients." +category = "main" +optional = false +python-versions = "*" + +[package.dependencies] +cryptography = ">=3.2" + [[package]] name = "babel" version = "2.10.1" @@ -1466,7 +1477,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest- [metadata] lock-version = "1.1" python-versions = "^3.7" -content-hash = "e3bc153f918e57b66b117505b29466555da5ed642b0bbdc58bb5bcb9bd2ea688" +content-hash = "84bb7759a500c94555fb8d8e5ca25bd98b287f6c53c80eb259bcc3c0d31b1031" [metadata.files] alabaster = [ @@ -1493,6 +1504,10 @@ attrs = [ {file = "attrs-21.4.0-py2.py3-none-any.whl", hash = "sha256:2d27e3784d7a565d36ab851fe94887c5eccd6a463168875832a1be79c82828b4"}, {file = "attrs-21.4.0.tar.gz", hash = "sha256:626ba8234211db98e869df76230a137c4c40a12d72445c45d5f5b716f076e2fd"}, ] +authlib = [ + {file = "Authlib-1.0.1-py2.py3-none-any.whl", hash = "sha256:1286e2d5ef5bfe5a11cc2d0a0d1031f0393f6ce4d61f5121cfe87fa0054e98bd"}, + {file = "Authlib-1.0.1.tar.gz", hash = "sha256:6e74a4846ac36dfc882b3cc2fbd3d9eb410a627f2f2dc11771276655345223b1"}, +] babel = [ {file = "Babel-2.10.1-py3-none-any.whl", hash = "sha256:3f349e85ad3154559ac4930c3918247d319f21910d5ce4b25d439ed8693b98d2"}, {file = "Babel-2.10.1.tar.gz", hash = "sha256:98aeaca086133efb3e1e2aad0396987490c8425929ddbcfe0550184fdc54cd13"}, diff --git a/pyproject.toml b/pyproject.toml index ef2b778a..023bac9f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ staticmap = "^0.5.4" SQLAlchemy = "1.4.36" pyOpenSSL = "^22.0" ua-parser = "^0.10.0" +Authlib = "^1.0.1" [tool.poetry.dev-dependencies] black = "^22.3"