API - init OAuth server and oauth clients creation
This commit is contained in:
parent
c13e9e0286
commit
c6cd7ff67c
@ -16,6 +16,7 @@ from flask_dramatiq import Dramatiq
|
||||
from flask_migrate import Migrate
|
||||
from flask_sqlalchemy import SQLAlchemy
|
||||
from sqlalchemy.exc import ProgrammingError
|
||||
from werkzeug.middleware.proxy_fix import ProxyFix
|
||||
|
||||
from fittrackee.emails.email import EmailService
|
||||
from fittrackee.request import CustomRequest
|
||||
@ -64,6 +65,11 @@ def create_app(init_email: bool = True) -> Flask:
|
||||
migrate.init_app(app, db)
|
||||
dramatiq.init_app(app)
|
||||
|
||||
# set oauth2
|
||||
from fittrackee.oauth2.config import config_oauth
|
||||
|
||||
config_oauth(app)
|
||||
|
||||
# set up email if 'EMAIL_URL' is initialized
|
||||
if init_email:
|
||||
if app.config['EMAIL_URL']:
|
||||
@ -95,6 +101,7 @@ def create_app(init_email: bool = True) -> Flask:
|
||||
pass
|
||||
|
||||
from .application.app_config import config_blueprint # noqa
|
||||
from .oauth2.routes import oauth_blueprint # noqa
|
||||
from .users.auth import auth_blueprint # noqa
|
||||
from .users.users import users_blueprint # noqa
|
||||
from .workouts.records import records_blueprint # noqa
|
||||
@ -103,6 +110,7 @@ def create_app(init_email: bool = True) -> Flask:
|
||||
from .workouts.workouts import workouts_blueprint # noqa
|
||||
|
||||
app.register_blueprint(auth_blueprint, url_prefix='/api')
|
||||
app.register_blueprint(oauth_blueprint, url_prefix='/api')
|
||||
app.register_blueprint(config_blueprint, url_prefix='/api')
|
||||
app.register_blueprint(records_blueprint, url_prefix='/api')
|
||||
app.register_blueprint(sports_blueprint, url_prefix='/api')
|
||||
@ -153,4 +161,8 @@ def create_app(init_email: bool = True) -> Flask:
|
||||
else:
|
||||
return render_template('index.html')
|
||||
|
||||
# to get headers, especially 'X-Forwarded-Proto' for scheme needed by
|
||||
# Authlib, when the application is running behind a proxy server
|
||||
app.wsgi_app = ProxyFix(app.wsgi_app) # type: ignore
|
||||
|
||||
return app
|
||||
|
76
fittrackee/migrations/versions/24_84d840ce853b_add_oauth.py
Normal file
76
fittrackee/migrations/versions/24_84d840ce853b_add_oauth.py
Normal file
@ -0,0 +1,76 @@
|
||||
"""add OAuth 2.0
|
||||
|
||||
Revision ID: 84d840ce853b
|
||||
Revises: 5e3a3a31c432
|
||||
Create Date: 2022-05-27 10:54:02.284543
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '84d840ce853b'
|
||||
down_revision = '5e3a3a31c432'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('oauth2_client',
|
||||
sa.Column('client_id', sa.String(length=48), nullable=True),
|
||||
sa.Column('client_secret', sa.String(length=120), nullable=True),
|
||||
sa.Column('client_id_issued_at', sa.Integer(), nullable=False),
|
||||
sa.Column('client_secret_expires_at', sa.Integer(), nullable=False),
|
||||
sa.Column('client_metadata', sa.Text(), nullable=True),
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('user_id', sa.Integer(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_oauth2_client_client_id'), 'oauth2_client', ['client_id'], unique=False)
|
||||
op.create_table('oauth2_code',
|
||||
sa.Column('code', sa.String(length=120), nullable=False),
|
||||
sa.Column('client_id', sa.String(length=48), nullable=True),
|
||||
sa.Column('redirect_uri', sa.Text(), nullable=True),
|
||||
sa.Column('response_type', sa.Text(), nullable=True),
|
||||
sa.Column('scope', sa.Text(), nullable=True),
|
||||
sa.Column('nonce', sa.Text(), nullable=True),
|
||||
sa.Column('auth_time', sa.Integer(), nullable=False),
|
||||
sa.Column('code_challenge', sa.Text(), nullable=True),
|
||||
sa.Column('code_challenge_method', sa.String(length=48), nullable=True),
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('user_id', sa.Integer(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('code')
|
||||
)
|
||||
op.create_table('oauth2_token',
|
||||
sa.Column('client_id', sa.String(length=48), nullable=True),
|
||||
sa.Column('token_type', sa.String(length=40), nullable=True),
|
||||
sa.Column('access_token', sa.String(length=255), nullable=False),
|
||||
sa.Column('refresh_token', sa.String(length=255), nullable=True),
|
||||
sa.Column('scope', sa.Text(), nullable=True),
|
||||
sa.Column('issued_at', sa.Integer(), nullable=False),
|
||||
sa.Column('access_token_revoked_at', sa.Integer(), nullable=False),
|
||||
sa.Column('refresh_token_revoked_at', sa.Integer(), nullable=False),
|
||||
sa.Column('expires_in', sa.Integer(), nullable=False),
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('user_id', sa.Integer(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('access_token')
|
||||
)
|
||||
op.create_index(op.f('ix_oauth2_token_refresh_token'), 'oauth2_token', ['refresh_token'], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f('ix_oauth2_token_refresh_token'), table_name='oauth2_token')
|
||||
op.drop_table('oauth2_token')
|
||||
op.drop_table('oauth2_code')
|
||||
op.drop_index(op.f('ix_oauth2_client_client_id'), table_name='oauth2_client')
|
||||
op.drop_table('oauth2_client')
|
||||
# ### end Alembic commands ###
|
0
fittrackee/oauth2/__init__.py
Normal file
0
fittrackee/oauth2/__init__.py
Normal file
35
fittrackee/oauth2/client.py
Normal file
35
fittrackee/oauth2/client.py
Normal file
@ -0,0 +1,35 @@
|
||||
from time import time
|
||||
from typing import Dict
|
||||
|
||||
from werkzeug.security import gen_salt
|
||||
|
||||
from fittrackee.oauth2.models import OAuth2Client
|
||||
from fittrackee.users.models import User
|
||||
|
||||
|
||||
def create_oauth_client(metadata: Dict, user: User) -> OAuth2Client:
|
||||
"""
|
||||
Create oauth client for 3rd-party applications.
|
||||
|
||||
Only Authorization Code Grant with 'client_secret_post' as method
|
||||
is supported.
|
||||
"""
|
||||
client_metadata = {
|
||||
'client_name': metadata['client_name'],
|
||||
'client_uri': metadata['client_uri'],
|
||||
'redirect_uris': metadata['redirect_uris'],
|
||||
'scope': metadata['scope'],
|
||||
'grant_types': ['authorization_code'],
|
||||
'response_types': ['code'],
|
||||
'token_endpoint_auth_method': 'client_secret_post',
|
||||
}
|
||||
client_id = gen_salt(24)
|
||||
client_id_issued_at = int(time())
|
||||
client = OAuth2Client(
|
||||
client_id=client_id,
|
||||
client_id_issued_at=client_id_issued_at,
|
||||
user_id=user.id,
|
||||
)
|
||||
client.set_client_metadata(client_metadata)
|
||||
client.client_secret = gen_salt(48)
|
||||
return client
|
14
fittrackee/oauth2/config.py
Normal file
14
fittrackee/oauth2/config.py
Normal file
@ -0,0 +1,14 @@
|
||||
from authlib.oauth2.rfc7636 import CodeChallenge
|
||||
from flask import Flask
|
||||
|
||||
from .grants import AuthorizationCodeGrant
|
||||
from .server import authorization_server
|
||||
|
||||
|
||||
def config_oauth(app: Flask) -> None:
|
||||
authorization_server.init_app(app)
|
||||
|
||||
# supported grants
|
||||
authorization_server.register_grant(
|
||||
AuthorizationCodeGrant, [CodeChallenge(required=True)]
|
||||
)
|
65
fittrackee/oauth2/grants.py
Normal file
65
fittrackee/oauth2/grants.py
Normal file
@ -0,0 +1,65 @@
|
||||
from typing import Optional
|
||||
|
||||
from authlib.oauth2 import OAuth2Request
|
||||
from authlib.oauth2.rfc6749 import grants
|
||||
|
||||
from fittrackee import db
|
||||
from fittrackee.users.models import User
|
||||
|
||||
from .models import OAuth2AuthorizationCode, OAuth2Client, OAuth2Token
|
||||
|
||||
|
||||
class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
|
||||
TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_post']
|
||||
|
||||
def save_authorization_code(
|
||||
self, code: str, request: OAuth2Request
|
||||
) -> OAuth2AuthorizationCode:
|
||||
code_challenge = request.data.get('code_challenge')
|
||||
code_challenge_method = request.data.get('code_challenge_method')
|
||||
auth_code = OAuth2AuthorizationCode(
|
||||
code=code,
|
||||
client_id=request.client.client_id,
|
||||
redirect_uri=request.redirect_uri,
|
||||
scope=request.scope,
|
||||
user_id=request.user.id,
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method=code_challenge_method,
|
||||
)
|
||||
db.session.add(auth_code)
|
||||
db.session.commit()
|
||||
return auth_code
|
||||
|
||||
def query_authorization_code(
|
||||
self, code: str, client: OAuth2Client
|
||||
) -> Optional[OAuth2AuthorizationCode]:
|
||||
auth_code = OAuth2AuthorizationCode.query.filter_by(
|
||||
code=code, client_id=client.client_id
|
||||
).first()
|
||||
if auth_code and not auth_code.is_expired():
|
||||
return auth_code
|
||||
return None
|
||||
|
||||
def delete_authorization_code(
|
||||
self, authorization_code: OAuth2AuthorizationCode
|
||||
) -> None:
|
||||
db.session.delete(authorization_code)
|
||||
db.session.commit()
|
||||
|
||||
def authenticate_user(
|
||||
self, authorization_code: OAuth2AuthorizationCode
|
||||
) -> User:
|
||||
return User.query.get(authorization_code.user_id)
|
||||
|
||||
|
||||
class RefreshTokenGrant(grants.RefreshTokenGrant):
|
||||
def authenticate_refresh_token(self, refresh_token: str) -> Optional[str]:
|
||||
token = OAuth2Token.query.filter_by(
|
||||
refresh_token=refresh_token
|
||||
).first()
|
||||
if token and token.is_refresh_token_active():
|
||||
return token
|
||||
return None
|
||||
|
||||
def authenticate_user(self, credential: OAuth2Token) -> User:
|
||||
return User.query.get(credential.user_id)
|
59
fittrackee/oauth2/models.py
Normal file
59
fittrackee/oauth2/models.py
Normal file
@ -0,0 +1,59 @@
|
||||
import time
|
||||
from typing import Dict
|
||||
|
||||
from authlib.integrations.sqla_oauth2 import (
|
||||
OAuth2AuthorizationCodeMixin,
|
||||
OAuth2ClientMixin,
|
||||
OAuth2TokenMixin,
|
||||
)
|
||||
from sqlalchemy.ext.declarative import DeclarativeMeta
|
||||
|
||||
from fittrackee import db
|
||||
|
||||
BaseModel: DeclarativeMeta = db.Model
|
||||
|
||||
|
||||
class OAuth2Client(BaseModel, OAuth2ClientMixin):
|
||||
__tablename__ = 'oauth2_client'
|
||||
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
user_id = db.Column(
|
||||
db.Integer, db.ForeignKey('users.id', ondelete='CASCADE')
|
||||
)
|
||||
user = db.relationship('User')
|
||||
|
||||
def serialize(self) -> Dict:
|
||||
return {
|
||||
'client_id': self.client_id,
|
||||
'client_secret': self.client_secret,
|
||||
'id': self.id,
|
||||
'name': self.client_name,
|
||||
'redirect_uris': self.redirect_uris,
|
||||
'website': self.client_uri,
|
||||
}
|
||||
|
||||
|
||||
class OAuth2AuthorizationCode(BaseModel, OAuth2AuthorizationCodeMixin):
|
||||
__tablename__ = 'oauth2_code'
|
||||
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
user_id = db.Column(
|
||||
db.Integer, db.ForeignKey('users.id', ondelete='CASCADE')
|
||||
)
|
||||
user = db.relationship('User')
|
||||
|
||||
|
||||
class OAuth2Token(BaseModel, OAuth2TokenMixin):
|
||||
__tablename__ = 'oauth2_token'
|
||||
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
user_id = db.Column(
|
||||
db.Integer, db.ForeignKey('users.id', ondelete='CASCADE')
|
||||
)
|
||||
user = db.relationship('User')
|
||||
|
||||
def is_refresh_token_active(self) -> bool:
|
||||
if self.revoked:
|
||||
return False
|
||||
expires_at = self.issued_at + self.expires_in * 2
|
||||
return expires_at >= time.time()
|
53
fittrackee/oauth2/routes.py
Normal file
53
fittrackee/oauth2/routes.py
Normal file
@ -0,0 +1,53 @@
|
||||
from typing import Dict, Tuple, Union
|
||||
|
||||
from flask import Blueprint, request
|
||||
|
||||
from fittrackee import db
|
||||
from fittrackee.responses import HttpResponse, InvalidPayloadErrorResponse
|
||||
from fittrackee.users.decorators import authenticate
|
||||
from fittrackee.users.models import User
|
||||
|
||||
from .client import create_oauth_client
|
||||
|
||||
oauth_blueprint = Blueprint('oauth', __name__)
|
||||
|
||||
EXPECTED_METADATA_KEYS = [
|
||||
'client_name',
|
||||
'client_uri',
|
||||
'redirect_uris',
|
||||
'scope',
|
||||
]
|
||||
|
||||
|
||||
@oauth_blueprint.route('/oauth/apps', methods=['POST'])
|
||||
@authenticate
|
||||
def create_client(auth_user: User) -> Union[HttpResponse, Tuple[Dict, int]]:
|
||||
client_metadata = request.get_json()
|
||||
if not client_metadata:
|
||||
return InvalidPayloadErrorResponse(
|
||||
message='OAuth client metadata missing'
|
||||
)
|
||||
|
||||
missing_keys = [
|
||||
key
|
||||
for key in EXPECTED_METADATA_KEYS
|
||||
if key not in client_metadata.keys()
|
||||
]
|
||||
if missing_keys:
|
||||
return InvalidPayloadErrorResponse(
|
||||
message=(
|
||||
'OAuth client metadata missing keys: '
|
||||
f'{", ".join(missing_keys)}'
|
||||
)
|
||||
)
|
||||
|
||||
new_client = create_oauth_client(client_metadata, auth_user)
|
||||
db.session.add(new_client)
|
||||
db.session.commit()
|
||||
return (
|
||||
{
|
||||
'status': 'created',
|
||||
'data': {'client': new_client.serialize()},
|
||||
},
|
||||
201,
|
||||
)
|
16
fittrackee/oauth2/server.py
Normal file
16
fittrackee/oauth2/server.py
Normal file
@ -0,0 +1,16 @@
|
||||
from authlib.integrations.flask_oauth2 import AuthorizationServer
|
||||
from authlib.integrations.sqla_oauth2 import (
|
||||
create_query_client_func,
|
||||
create_save_token_func,
|
||||
)
|
||||
|
||||
from fittrackee import db
|
||||
|
||||
from .models import OAuth2Client, OAuth2Token
|
||||
|
||||
query_client = create_query_client_func(db.session, OAuth2Client)
|
||||
save_token = create_save_token_func(db.session, OAuth2Token)
|
||||
authorization_server = AuthorizationServer(
|
||||
query_client=query_client,
|
||||
save_token=save_token,
|
||||
)
|
@ -10,6 +10,7 @@ os.environ['DATABASE_URL'] = os.environ['DATABASE_TEST_URL']
|
||||
TEMP_FOLDER = '/tmp/FitTrackee'
|
||||
os.environ['UPLOAD_FOLDER'] = TEMP_FOLDER
|
||||
os.environ['APP_LOG'] = TEMP_FOLDER + '/fittrackee.log'
|
||||
os.environ['AUTHLIB_INSECURE_TRANSPORT'] = '1'
|
||||
|
||||
pytest_plugins = [
|
||||
'fittrackee.tests.fixtures.fixtures_app',
|
||||
|
@ -1,4 +1,5 @@
|
||||
import json
|
||||
from random import randint
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from flask import Flask
|
||||
@ -18,10 +19,18 @@ class RandomMixin:
|
||||
) -> str:
|
||||
return random_string(length, prefix, suffix)
|
||||
|
||||
@staticmethod
|
||||
def random_domain() -> str:
|
||||
return random_string(prefix='https://', suffix='com')
|
||||
|
||||
@staticmethod
|
||||
def random_email() -> str:
|
||||
return random_email()
|
||||
|
||||
@staticmethod
|
||||
def random_int(min_val: int = 0, max_val: int = 999999) -> int:
|
||||
return randint(min_val, max_val)
|
||||
|
||||
|
||||
class ApiTestCaseMixin(RandomMixin):
|
||||
@staticmethod
|
||||
|
0
fittrackee/tests/oauth2/__init__.py
Normal file
0
fittrackee/tests/oauth2/__init__.py
Normal file
137
fittrackee/tests/oauth2/test_oauth2_client.py
Normal file
137
fittrackee/tests/oauth2/test_oauth2_client.py
Normal file
@ -0,0 +1,137 @@
|
||||
from time import time
|
||||
from typing import Dict
|
||||
from unittest.mock import patch
|
||||
|
||||
from flask import Flask
|
||||
|
||||
from fittrackee.oauth2.client import create_oauth_client
|
||||
from fittrackee.oauth2.models import OAuth2Client
|
||||
from fittrackee.users.models import User
|
||||
|
||||
from ..utils import random_domain, random_string
|
||||
|
||||
TEST_METADATA = {
|
||||
'client_name': random_string(),
|
||||
'client_uri': random_string(),
|
||||
'redirect_uris': [random_domain()],
|
||||
'scope': 'read write',
|
||||
}
|
||||
|
||||
|
||||
class TestCreateOAuth2Client:
|
||||
def test_it_creates_oauth_client(self, app: Flask, user_1: User) -> None:
|
||||
oauth_client = create_oauth_client(TEST_METADATA, user_1)
|
||||
|
||||
assert isinstance(oauth_client, OAuth2Client)
|
||||
|
||||
def test_oauth_client_id_is_generated_with_gen_salt(
|
||||
self, app: Flask, user_1: User
|
||||
) -> None:
|
||||
client_id = random_string()
|
||||
with patch(
|
||||
'fittrackee.oauth2.client.gen_salt', return_value=client_id
|
||||
):
|
||||
|
||||
oauth_client = create_oauth_client(TEST_METADATA, user_1)
|
||||
|
||||
assert oauth_client.client_id == client_id
|
||||
|
||||
def test_oauth_client_client_id_issued_at_is_initialized(
|
||||
self, app: Flask, user_1: User
|
||||
) -> None:
|
||||
client_id_issued_at = int(time())
|
||||
with patch(
|
||||
'fittrackee.oauth2.client.time', return_value=client_id_issued_at
|
||||
):
|
||||
oauth_client = create_oauth_client(TEST_METADATA, user_1)
|
||||
|
||||
assert oauth_client.client_id_issued_at == client_id_issued_at
|
||||
|
||||
def test_oauth_client_has_expected_name(
|
||||
self, app: Flask, user_1: User
|
||||
) -> None:
|
||||
client_name = random_string()
|
||||
client_metadata: Dict = {**TEST_METADATA, 'client_name': client_name}
|
||||
|
||||
oauth_client = create_oauth_client(client_metadata, user_1)
|
||||
|
||||
assert oauth_client.client_name == client_name
|
||||
|
||||
def test_oauth_client_has_expected_client_uri(
|
||||
self, app: Flask, user_1: User
|
||||
) -> None:
|
||||
client_uri = random_domain()
|
||||
client_metadata: Dict = {**TEST_METADATA, 'client_uri': client_uri}
|
||||
|
||||
oauth_client = create_oauth_client(client_metadata, user_1)
|
||||
|
||||
assert oauth_client.client_uri == client_uri
|
||||
|
||||
def test_oauth_client_has_expected_grant_types(
|
||||
self, app: Flask, user_1: User
|
||||
) -> None:
|
||||
oauth_client = create_oauth_client(TEST_METADATA, user_1)
|
||||
|
||||
assert oauth_client.grant_types == ['authorization_code']
|
||||
|
||||
def test_oauth_client_has_expected_redirect_uris(
|
||||
self, app: Flask, user_1: User
|
||||
) -> None:
|
||||
redirect_uris = [random_domain()]
|
||||
client_metadata: Dict = {
|
||||
**TEST_METADATA,
|
||||
'redirect_uris': redirect_uris,
|
||||
}
|
||||
|
||||
oauth_client = create_oauth_client(client_metadata, user_1)
|
||||
|
||||
assert oauth_client.redirect_uris == redirect_uris
|
||||
|
||||
def test_oauth_client_has_expected_response_types(
|
||||
self, app: Flask, user_1: User
|
||||
) -> None:
|
||||
response_types = ['code']
|
||||
client_metadata: Dict = {
|
||||
**TEST_METADATA,
|
||||
'response_types': response_types,
|
||||
}
|
||||
|
||||
oauth_client = create_oauth_client(client_metadata, user_1)
|
||||
|
||||
assert oauth_client.response_types == ['code']
|
||||
|
||||
def test_oauth_client_has_expected_scope(
|
||||
self, app: Flask, user_1: User
|
||||
) -> None:
|
||||
scope = 'profile'
|
||||
client_metadata: Dict = {**TEST_METADATA, 'scope': scope}
|
||||
|
||||
oauth_client = create_oauth_client(client_metadata, user_1)
|
||||
|
||||
assert oauth_client.scope == scope
|
||||
|
||||
def test_oauth_client_has_expected_token_endpoint_auth_method(
|
||||
self, app: Flask, user_1: User
|
||||
) -> None:
|
||||
oauth_client = create_oauth_client(TEST_METADATA, user_1)
|
||||
|
||||
assert oauth_client.token_endpoint_auth_method == 'client_secret_post'
|
||||
|
||||
def test_when_auth_method_is_not_none_oauth_client_secret_is_generated(
|
||||
self, app: Flask, user_1: User
|
||||
) -> None:
|
||||
client_secret = random_string()
|
||||
with patch(
|
||||
'fittrackee.oauth2.client.gen_salt', return_value=client_secret
|
||||
):
|
||||
|
||||
oauth_client = create_oauth_client(TEST_METADATA, user_1)
|
||||
|
||||
assert oauth_client.client_secret == client_secret
|
||||
|
||||
def test_it_creates_oauth_client_for_given_user(
|
||||
self, app: Flask, user_1: User
|
||||
) -> None:
|
||||
oauth_client = create_oauth_client(TEST_METADATA, user_1)
|
||||
|
||||
assert oauth_client.user_id == user_1.id
|
36
fittrackee/tests/oauth2/test_oauth2_models.py
Normal file
36
fittrackee/tests/oauth2/test_oauth2_models.py
Normal file
@ -0,0 +1,36 @@
|
||||
from flask import Flask
|
||||
|
||||
from fittrackee.oauth2.models import OAuth2Client
|
||||
|
||||
from ..mixins import RandomMixin
|
||||
|
||||
|
||||
class TestOAuthClientSerialize(RandomMixin):
|
||||
def test_it_returns_oauth_client(self, app: Flask) -> None:
|
||||
oauth_client = OAuth2Client(
|
||||
id=self.random_int(),
|
||||
client_id=self.random_string(),
|
||||
client_id_issued_at=self.random_int(),
|
||||
)
|
||||
oauth_client.set_client_metadata(
|
||||
{
|
||||
'client_name': self.random_string(),
|
||||
'redirect_uris': [self.random_string()],
|
||||
'client_uri': self.random_domain(),
|
||||
}
|
||||
)
|
||||
|
||||
serialized_oauth_client = oauth_client.serialize()
|
||||
|
||||
assert serialized_oauth_client['client_id'] == oauth_client.client_id
|
||||
assert (
|
||||
serialized_oauth_client['client_secret']
|
||||
== oauth_client.client_secret
|
||||
)
|
||||
assert serialized_oauth_client['id'] == oauth_client.id
|
||||
assert serialized_oauth_client['name'] == oauth_client.client_name
|
||||
assert (
|
||||
serialized_oauth_client['redirect_uris']
|
||||
== oauth_client.redirect_uris
|
||||
)
|
||||
assert serialized_oauth_client['website'] == oauth_client.client_uri
|
163
fittrackee/tests/oauth2/test_oauth2_routes.py
Normal file
163
fittrackee/tests/oauth2/test_oauth2_routes.py
Normal file
@ -0,0 +1,163 @@
|
||||
import json
|
||||
from typing import List, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from fittrackee.oauth2.models import OAuth2Client
|
||||
from fittrackee.users.models import User
|
||||
|
||||
from ..mixins import ApiTestCaseMixin
|
||||
from ..utils import random_domain, random_string
|
||||
|
||||
TEST_METADATA = {
|
||||
'client_name': random_string(),
|
||||
'client_uri': random_domain(),
|
||||
'redirect_uris': [random_domain()],
|
||||
'scope': 'read write',
|
||||
}
|
||||
|
||||
|
||||
class TestOAuthClientCreation(ApiTestCaseMixin):
|
||||
route = '/api/oauth/apps'
|
||||
|
||||
def test_it_returns_error_when_no_user_authenticated(
|
||||
self, app: Flask, user_1: User
|
||||
) -> None:
|
||||
client, auth_token = self.get_test_client_and_auth_token(
|
||||
app, user_1.email
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
self.route,
|
||||
data=json.dumps(TEST_METADATA),
|
||||
content_type='application/json',
|
||||
)
|
||||
|
||||
self.assert_401(response)
|
||||
|
||||
def test_it_returns_error_when_no_metadata_provided(
|
||||
self, app: Flask, user_1: User
|
||||
) -> None:
|
||||
client, auth_token = self.get_test_client_and_auth_token(
|
||||
app, user_1.email
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
self.route,
|
||||
data=json.dumps(dict()),
|
||||
content_type='application/json',
|
||||
headers=dict(Authorization=f'Bearer {auth_token}'),
|
||||
)
|
||||
|
||||
self.assert_400(
|
||||
response, error_message='OAuth client metadata missing'
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'missing_key',
|
||||
[
|
||||
'client_name',
|
||||
'client_uri',
|
||||
'redirect_uris',
|
||||
'scope',
|
||||
],
|
||||
)
|
||||
def test_it_returns_error_when_metadata_key_is_missing(
|
||||
self, app: Flask, user_1: User, missing_key: str
|
||||
) -> None:
|
||||
metadata = TEST_METADATA.copy()
|
||||
del metadata[missing_key]
|
||||
client, auth_token = self.get_test_client_and_auth_token(
|
||||
app, user_1.email
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
self.route,
|
||||
data=json.dumps(metadata),
|
||||
content_type='application/json',
|
||||
headers=dict(Authorization=f'Bearer {auth_token}'),
|
||||
)
|
||||
|
||||
self.assert_400(
|
||||
response,
|
||||
error_message=f'OAuth client metadata missing keys: {missing_key}',
|
||||
)
|
||||
|
||||
def test_it_creates_oauth_client(self, app: Flask, user_1: User) -> None:
|
||||
client, auth_token = self.get_test_client_and_auth_token(
|
||||
app, user_1.email
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
self.route,
|
||||
data=json.dumps(TEST_METADATA),
|
||||
content_type='application/json',
|
||||
headers=dict(Authorization=f'Bearer {auth_token}'),
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
oauth_client = OAuth2Client.query.first()
|
||||
assert oauth_client is not None
|
||||
|
||||
def test_it_returns_serialized_oauth_client(
|
||||
self, app: Flask, user_1: User
|
||||
) -> None:
|
||||
client, auth_token = self.get_test_client_and_auth_token(
|
||||
app, user_1.email
|
||||
)
|
||||
client_id = self.random_string()
|
||||
client_secret = self.random_string()
|
||||
with patch(
|
||||
'fittrackee.oauth2.client.gen_salt',
|
||||
side_effect=[client_id, client_secret],
|
||||
):
|
||||
|
||||
response = client.post(
|
||||
self.route,
|
||||
data=json.dumps(TEST_METADATA),
|
||||
content_type='application/json',
|
||||
headers=dict(Authorization=f'Bearer {auth_token}'),
|
||||
)
|
||||
|
||||
data = json.loads(response.data.decode())
|
||||
assert data['data']['client']['client_id'] == client_id
|
||||
assert data['data']['client']['client_secret'] == client_secret
|
||||
assert data['data']['client']['name'] == TEST_METADATA['client_name']
|
||||
assert (
|
||||
data['data']['client']['redirect_uris']
|
||||
== TEST_METADATA['redirect_uris']
|
||||
)
|
||||
assert data['data']['client']['website'] == TEST_METADATA['client_uri']
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'input_key,expected_value',
|
||||
[
|
||||
('grant_types', ['authorization_code']),
|
||||
('response_types', ['code']),
|
||||
('token_endpoint_auth_method', 'client_secret_post'),
|
||||
],
|
||||
)
|
||||
def test_it_always_create_oauth_client_with_authorization_grant(
|
||||
self,
|
||||
app: Flask,
|
||||
user_1: User,
|
||||
input_key: str,
|
||||
expected_value: Union[List, str],
|
||||
) -> None:
|
||||
client, auth_token = self.get_test_client_and_auth_token(
|
||||
app, user_1.email
|
||||
)
|
||||
|
||||
client.post(
|
||||
self.route,
|
||||
data=json.dumps(
|
||||
{**TEST_METADATA, input_key: self.random_string()}
|
||||
),
|
||||
content_type='application/json',
|
||||
headers=dict(Authorization=f'Bearer {auth_token}'),
|
||||
)
|
||||
|
||||
oauth_client = OAuth2Client.query.first()
|
||||
assert getattr(oauth_client, input_key) == expected_value
|
@ -24,6 +24,10 @@ def random_string(
|
||||
)
|
||||
|
||||
|
||||
def random_domain() -> str:
|
||||
return random_string(prefix='https://', suffix='.com')
|
||||
|
||||
|
||||
def random_email() -> str:
|
||||
return random_string(suffix='@example.com')
|
||||
|
||||
|
@ -111,6 +111,9 @@ class User(BaseModel):
|
||||
new_password, current_app.config.get('BCRYPT_LOG_ROUNDS')
|
||||
).decode()
|
||||
|
||||
def get_user_id(self) -> int:
|
||||
return self.id
|
||||
|
||||
@hybrid_property
|
||||
def workouts_count(self) -> int:
|
||||
return Workout.query.filter(Workout.user_id == self.id).count()
|
||||
|
17
poetry.lock
generated
17
poetry.lock
generated
@ -64,6 +64,17 @@ docs = ["furo", "sphinx", "zope.interface", "sphinx-notfound-page"]
|
||||
tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "cloudpickle"]
|
||||
tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "cloudpickle"]
|
||||
|
||||
[[package]]
|
||||
name = "authlib"
|
||||
version = "1.0.1"
|
||||
description = "The ultimate Python library in building OAuth and OpenID Connect servers and clients."
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
|
||||
[package.dependencies]
|
||||
cryptography = ">=3.2"
|
||||
|
||||
[[package]]
|
||||
name = "babel"
|
||||
version = "2.10.1"
|
||||
@ -1466,7 +1477,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-
|
||||
[metadata]
|
||||
lock-version = "1.1"
|
||||
python-versions = "^3.7"
|
||||
content-hash = "e3bc153f918e57b66b117505b29466555da5ed642b0bbdc58bb5bcb9bd2ea688"
|
||||
content-hash = "84bb7759a500c94555fb8d8e5ca25bd98b287f6c53c80eb259bcc3c0d31b1031"
|
||||
|
||||
[metadata.files]
|
||||
alabaster = [
|
||||
@ -1493,6 +1504,10 @@ attrs = [
|
||||
{file = "attrs-21.4.0-py2.py3-none-any.whl", hash = "sha256:2d27e3784d7a565d36ab851fe94887c5eccd6a463168875832a1be79c82828b4"},
|
||||
{file = "attrs-21.4.0.tar.gz", hash = "sha256:626ba8234211db98e869df76230a137c4c40a12d72445c45d5f5b716f076e2fd"},
|
||||
]
|
||||
authlib = [
|
||||
{file = "Authlib-1.0.1-py2.py3-none-any.whl", hash = "sha256:1286e2d5ef5bfe5a11cc2d0a0d1031f0393f6ce4d61f5121cfe87fa0054e98bd"},
|
||||
{file = "Authlib-1.0.1.tar.gz", hash = "sha256:6e74a4846ac36dfc882b3cc2fbd3d9eb410a627f2f2dc11771276655345223b1"},
|
||||
]
|
||||
babel = [
|
||||
{file = "Babel-2.10.1-py3-none-any.whl", hash = "sha256:3f349e85ad3154559ac4930c3918247d319f21910d5ce4b25d439ed8693b98d2"},
|
||||
{file = "Babel-2.10.1.tar.gz", hash = "sha256:98aeaca086133efb3e1e2aad0396987490c8425929ddbcfe0550184fdc54cd13"},
|
||||
|
@ -41,6 +41,7 @@ staticmap = "^0.5.4"
|
||||
SQLAlchemy = "1.4.36"
|
||||
pyOpenSSL = "^22.0"
|
||||
ua-parser = "^0.10.0"
|
||||
Authlib = "^1.0.1"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
black = "^22.3"
|
||||
|
Loading…
Reference in New Issue
Block a user