API - init OAuth server and oauth clients creation
This commit is contained in:
parent
c13e9e0286
commit
c6cd7ff67c
@ -16,6 +16,7 @@ from flask_dramatiq import Dramatiq
|
|||||||
from flask_migrate import Migrate
|
from flask_migrate import Migrate
|
||||||
from flask_sqlalchemy import SQLAlchemy
|
from flask_sqlalchemy import SQLAlchemy
|
||||||
from sqlalchemy.exc import ProgrammingError
|
from sqlalchemy.exc import ProgrammingError
|
||||||
|
from werkzeug.middleware.proxy_fix import ProxyFix
|
||||||
|
|
||||||
from fittrackee.emails.email import EmailService
|
from fittrackee.emails.email import EmailService
|
||||||
from fittrackee.request import CustomRequest
|
from fittrackee.request import CustomRequest
|
||||||
@ -64,6 +65,11 @@ def create_app(init_email: bool = True) -> Flask:
|
|||||||
migrate.init_app(app, db)
|
migrate.init_app(app, db)
|
||||||
dramatiq.init_app(app)
|
dramatiq.init_app(app)
|
||||||
|
|
||||||
|
# set oauth2
|
||||||
|
from fittrackee.oauth2.config import config_oauth
|
||||||
|
|
||||||
|
config_oauth(app)
|
||||||
|
|
||||||
# set up email if 'EMAIL_URL' is initialized
|
# set up email if 'EMAIL_URL' is initialized
|
||||||
if init_email:
|
if init_email:
|
||||||
if app.config['EMAIL_URL']:
|
if app.config['EMAIL_URL']:
|
||||||
@ -95,6 +101,7 @@ def create_app(init_email: bool = True) -> Flask:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
from .application.app_config import config_blueprint # noqa
|
from .application.app_config import config_blueprint # noqa
|
||||||
|
from .oauth2.routes import oauth_blueprint # noqa
|
||||||
from .users.auth import auth_blueprint # noqa
|
from .users.auth import auth_blueprint # noqa
|
||||||
from .users.users import users_blueprint # noqa
|
from .users.users import users_blueprint # noqa
|
||||||
from .workouts.records import records_blueprint # noqa
|
from .workouts.records import records_blueprint # noqa
|
||||||
@ -103,6 +110,7 @@ def create_app(init_email: bool = True) -> Flask:
|
|||||||
from .workouts.workouts import workouts_blueprint # noqa
|
from .workouts.workouts import workouts_blueprint # noqa
|
||||||
|
|
||||||
app.register_blueprint(auth_blueprint, url_prefix='/api')
|
app.register_blueprint(auth_blueprint, url_prefix='/api')
|
||||||
|
app.register_blueprint(oauth_blueprint, url_prefix='/api')
|
||||||
app.register_blueprint(config_blueprint, url_prefix='/api')
|
app.register_blueprint(config_blueprint, url_prefix='/api')
|
||||||
app.register_blueprint(records_blueprint, url_prefix='/api')
|
app.register_blueprint(records_blueprint, url_prefix='/api')
|
||||||
app.register_blueprint(sports_blueprint, url_prefix='/api')
|
app.register_blueprint(sports_blueprint, url_prefix='/api')
|
||||||
@ -153,4 +161,8 @@ def create_app(init_email: bool = True) -> Flask:
|
|||||||
else:
|
else:
|
||||||
return render_template('index.html')
|
return render_template('index.html')
|
||||||
|
|
||||||
|
# to get headers, especially 'X-Forwarded-Proto' for scheme needed by
|
||||||
|
# Authlib, when the application is running behind a proxy server
|
||||||
|
app.wsgi_app = ProxyFix(app.wsgi_app) # type: ignore
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
76
fittrackee/migrations/versions/24_84d840ce853b_add_oauth.py
Normal file
76
fittrackee/migrations/versions/24_84d840ce853b_add_oauth.py
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
"""add OAuth 2.0
|
||||||
|
|
||||||
|
Revision ID: 84d840ce853b
|
||||||
|
Revises: 5e3a3a31c432
|
||||||
|
Create Date: 2022-05-27 10:54:02.284543
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '84d840ce853b'
|
||||||
|
down_revision = '5e3a3a31c432'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('oauth2_client',
|
||||||
|
sa.Column('client_id', sa.String(length=48), nullable=True),
|
||||||
|
sa.Column('client_secret', sa.String(length=120), nullable=True),
|
||||||
|
sa.Column('client_id_issued_at', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('client_secret_expires_at', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('client_metadata', sa.Text(), nullable=True),
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('user_id', sa.Integer(), nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_oauth2_client_client_id'), 'oauth2_client', ['client_id'], unique=False)
|
||||||
|
op.create_table('oauth2_code',
|
||||||
|
sa.Column('code', sa.String(length=120), nullable=False),
|
||||||
|
sa.Column('client_id', sa.String(length=48), nullable=True),
|
||||||
|
sa.Column('redirect_uri', sa.Text(), nullable=True),
|
||||||
|
sa.Column('response_type', sa.Text(), nullable=True),
|
||||||
|
sa.Column('scope', sa.Text(), nullable=True),
|
||||||
|
sa.Column('nonce', sa.Text(), nullable=True),
|
||||||
|
sa.Column('auth_time', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('code_challenge', sa.Text(), nullable=True),
|
||||||
|
sa.Column('code_challenge_method', sa.String(length=48), nullable=True),
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('user_id', sa.Integer(), nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||||
|
sa.PrimaryKeyConstraint('id'),
|
||||||
|
sa.UniqueConstraint('code')
|
||||||
|
)
|
||||||
|
op.create_table('oauth2_token',
|
||||||
|
sa.Column('client_id', sa.String(length=48), nullable=True),
|
||||||
|
sa.Column('token_type', sa.String(length=40), nullable=True),
|
||||||
|
sa.Column('access_token', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('refresh_token', sa.String(length=255), nullable=True),
|
||||||
|
sa.Column('scope', sa.Text(), nullable=True),
|
||||||
|
sa.Column('issued_at', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('access_token_revoked_at', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('refresh_token_revoked_at', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('expires_in', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('user_id', sa.Integer(), nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||||
|
sa.PrimaryKeyConstraint('id'),
|
||||||
|
sa.UniqueConstraint('access_token')
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_oauth2_token_refresh_token'), 'oauth2_token', ['refresh_token'], unique=False)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_index(op.f('ix_oauth2_token_refresh_token'), table_name='oauth2_token')
|
||||||
|
op.drop_table('oauth2_token')
|
||||||
|
op.drop_table('oauth2_code')
|
||||||
|
op.drop_index(op.f('ix_oauth2_client_client_id'), table_name='oauth2_client')
|
||||||
|
op.drop_table('oauth2_client')
|
||||||
|
# ### end Alembic commands ###
|
0
fittrackee/oauth2/__init__.py
Normal file
0
fittrackee/oauth2/__init__.py
Normal file
35
fittrackee/oauth2/client.py
Normal file
35
fittrackee/oauth2/client.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
from time import time
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from werkzeug.security import gen_salt
|
||||||
|
|
||||||
|
from fittrackee.oauth2.models import OAuth2Client
|
||||||
|
from fittrackee.users.models import User
|
||||||
|
|
||||||
|
|
||||||
|
def create_oauth_client(metadata: Dict, user: User) -> OAuth2Client:
|
||||||
|
"""
|
||||||
|
Create oauth client for 3rd-party applications.
|
||||||
|
|
||||||
|
Only Authorization Code Grant with 'client_secret_post' as method
|
||||||
|
is supported.
|
||||||
|
"""
|
||||||
|
client_metadata = {
|
||||||
|
'client_name': metadata['client_name'],
|
||||||
|
'client_uri': metadata['client_uri'],
|
||||||
|
'redirect_uris': metadata['redirect_uris'],
|
||||||
|
'scope': metadata['scope'],
|
||||||
|
'grant_types': ['authorization_code'],
|
||||||
|
'response_types': ['code'],
|
||||||
|
'token_endpoint_auth_method': 'client_secret_post',
|
||||||
|
}
|
||||||
|
client_id = gen_salt(24)
|
||||||
|
client_id_issued_at = int(time())
|
||||||
|
client = OAuth2Client(
|
||||||
|
client_id=client_id,
|
||||||
|
client_id_issued_at=client_id_issued_at,
|
||||||
|
user_id=user.id,
|
||||||
|
)
|
||||||
|
client.set_client_metadata(client_metadata)
|
||||||
|
client.client_secret = gen_salt(48)
|
||||||
|
return client
|
14
fittrackee/oauth2/config.py
Normal file
14
fittrackee/oauth2/config.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
from authlib.oauth2.rfc7636 import CodeChallenge
|
||||||
|
from flask import Flask
|
||||||
|
|
||||||
|
from .grants import AuthorizationCodeGrant
|
||||||
|
from .server import authorization_server
|
||||||
|
|
||||||
|
|
||||||
|
def config_oauth(app: Flask) -> None:
|
||||||
|
authorization_server.init_app(app)
|
||||||
|
|
||||||
|
# supported grants
|
||||||
|
authorization_server.register_grant(
|
||||||
|
AuthorizationCodeGrant, [CodeChallenge(required=True)]
|
||||||
|
)
|
65
fittrackee/oauth2/grants.py
Normal file
65
fittrackee/oauth2/grants.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from authlib.oauth2 import OAuth2Request
|
||||||
|
from authlib.oauth2.rfc6749 import grants
|
||||||
|
|
||||||
|
from fittrackee import db
|
||||||
|
from fittrackee.users.models import User
|
||||||
|
|
||||||
|
from .models import OAuth2AuthorizationCode, OAuth2Client, OAuth2Token
|
||||||
|
|
||||||
|
|
||||||
|
class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
|
||||||
|
TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_post']
|
||||||
|
|
||||||
|
def save_authorization_code(
|
||||||
|
self, code: str, request: OAuth2Request
|
||||||
|
) -> OAuth2AuthorizationCode:
|
||||||
|
code_challenge = request.data.get('code_challenge')
|
||||||
|
code_challenge_method = request.data.get('code_challenge_method')
|
||||||
|
auth_code = OAuth2AuthorizationCode(
|
||||||
|
code=code,
|
||||||
|
client_id=request.client.client_id,
|
||||||
|
redirect_uri=request.redirect_uri,
|
||||||
|
scope=request.scope,
|
||||||
|
user_id=request.user.id,
|
||||||
|
code_challenge=code_challenge,
|
||||||
|
code_challenge_method=code_challenge_method,
|
||||||
|
)
|
||||||
|
db.session.add(auth_code)
|
||||||
|
db.session.commit()
|
||||||
|
return auth_code
|
||||||
|
|
||||||
|
def query_authorization_code(
|
||||||
|
self, code: str, client: OAuth2Client
|
||||||
|
) -> Optional[OAuth2AuthorizationCode]:
|
||||||
|
auth_code = OAuth2AuthorizationCode.query.filter_by(
|
||||||
|
code=code, client_id=client.client_id
|
||||||
|
).first()
|
||||||
|
if auth_code and not auth_code.is_expired():
|
||||||
|
return auth_code
|
||||||
|
return None
|
||||||
|
|
||||||
|
def delete_authorization_code(
|
||||||
|
self, authorization_code: OAuth2AuthorizationCode
|
||||||
|
) -> None:
|
||||||
|
db.session.delete(authorization_code)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
def authenticate_user(
|
||||||
|
self, authorization_code: OAuth2AuthorizationCode
|
||||||
|
) -> User:
|
||||||
|
return User.query.get(authorization_code.user_id)
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshTokenGrant(grants.RefreshTokenGrant):
|
||||||
|
def authenticate_refresh_token(self, refresh_token: str) -> Optional[str]:
|
||||||
|
token = OAuth2Token.query.filter_by(
|
||||||
|
refresh_token=refresh_token
|
||||||
|
).first()
|
||||||
|
if token and token.is_refresh_token_active():
|
||||||
|
return token
|
||||||
|
return None
|
||||||
|
|
||||||
|
def authenticate_user(self, credential: OAuth2Token) -> User:
|
||||||
|
return User.query.get(credential.user_id)
|
59
fittrackee/oauth2/models.py
Normal file
59
fittrackee/oauth2/models.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
import time
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from authlib.integrations.sqla_oauth2 import (
|
||||||
|
OAuth2AuthorizationCodeMixin,
|
||||||
|
OAuth2ClientMixin,
|
||||||
|
OAuth2TokenMixin,
|
||||||
|
)
|
||||||
|
from sqlalchemy.ext.declarative import DeclarativeMeta
|
||||||
|
|
||||||
|
from fittrackee import db
|
||||||
|
|
||||||
|
BaseModel: DeclarativeMeta = db.Model
|
||||||
|
|
||||||
|
|
||||||
|
class OAuth2Client(BaseModel, OAuth2ClientMixin):
|
||||||
|
__tablename__ = 'oauth2_client'
|
||||||
|
|
||||||
|
id = db.Column(db.Integer, primary_key=True)
|
||||||
|
user_id = db.Column(
|
||||||
|
db.Integer, db.ForeignKey('users.id', ondelete='CASCADE')
|
||||||
|
)
|
||||||
|
user = db.relationship('User')
|
||||||
|
|
||||||
|
def serialize(self) -> Dict:
|
||||||
|
return {
|
||||||
|
'client_id': self.client_id,
|
||||||
|
'client_secret': self.client_secret,
|
||||||
|
'id': self.id,
|
||||||
|
'name': self.client_name,
|
||||||
|
'redirect_uris': self.redirect_uris,
|
||||||
|
'website': self.client_uri,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class OAuth2AuthorizationCode(BaseModel, OAuth2AuthorizationCodeMixin):
|
||||||
|
__tablename__ = 'oauth2_code'
|
||||||
|
|
||||||
|
id = db.Column(db.Integer, primary_key=True)
|
||||||
|
user_id = db.Column(
|
||||||
|
db.Integer, db.ForeignKey('users.id', ondelete='CASCADE')
|
||||||
|
)
|
||||||
|
user = db.relationship('User')
|
||||||
|
|
||||||
|
|
||||||
|
class OAuth2Token(BaseModel, OAuth2TokenMixin):
|
||||||
|
__tablename__ = 'oauth2_token'
|
||||||
|
|
||||||
|
id = db.Column(db.Integer, primary_key=True)
|
||||||
|
user_id = db.Column(
|
||||||
|
db.Integer, db.ForeignKey('users.id', ondelete='CASCADE')
|
||||||
|
)
|
||||||
|
user = db.relationship('User')
|
||||||
|
|
||||||
|
def is_refresh_token_active(self) -> bool:
|
||||||
|
if self.revoked:
|
||||||
|
return False
|
||||||
|
expires_at = self.issued_at + self.expires_in * 2
|
||||||
|
return expires_at >= time.time()
|
53
fittrackee/oauth2/routes.py
Normal file
53
fittrackee/oauth2/routes.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
from typing import Dict, Tuple, Union
|
||||||
|
|
||||||
|
from flask import Blueprint, request
|
||||||
|
|
||||||
|
from fittrackee import db
|
||||||
|
from fittrackee.responses import HttpResponse, InvalidPayloadErrorResponse
|
||||||
|
from fittrackee.users.decorators import authenticate
|
||||||
|
from fittrackee.users.models import User
|
||||||
|
|
||||||
|
from .client import create_oauth_client
|
||||||
|
|
||||||
|
oauth_blueprint = Blueprint('oauth', __name__)
|
||||||
|
|
||||||
|
EXPECTED_METADATA_KEYS = [
|
||||||
|
'client_name',
|
||||||
|
'client_uri',
|
||||||
|
'redirect_uris',
|
||||||
|
'scope',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@oauth_blueprint.route('/oauth/apps', methods=['POST'])
|
||||||
|
@authenticate
|
||||||
|
def create_client(auth_user: User) -> Union[HttpResponse, Tuple[Dict, int]]:
|
||||||
|
client_metadata = request.get_json()
|
||||||
|
if not client_metadata:
|
||||||
|
return InvalidPayloadErrorResponse(
|
||||||
|
message='OAuth client metadata missing'
|
||||||
|
)
|
||||||
|
|
||||||
|
missing_keys = [
|
||||||
|
key
|
||||||
|
for key in EXPECTED_METADATA_KEYS
|
||||||
|
if key not in client_metadata.keys()
|
||||||
|
]
|
||||||
|
if missing_keys:
|
||||||
|
return InvalidPayloadErrorResponse(
|
||||||
|
message=(
|
||||||
|
'OAuth client metadata missing keys: '
|
||||||
|
f'{", ".join(missing_keys)}'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
new_client = create_oauth_client(client_metadata, auth_user)
|
||||||
|
db.session.add(new_client)
|
||||||
|
db.session.commit()
|
||||||
|
return (
|
||||||
|
{
|
||||||
|
'status': 'created',
|
||||||
|
'data': {'client': new_client.serialize()},
|
||||||
|
},
|
||||||
|
201,
|
||||||
|
)
|
16
fittrackee/oauth2/server.py
Normal file
16
fittrackee/oauth2/server.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
from authlib.integrations.flask_oauth2 import AuthorizationServer
|
||||||
|
from authlib.integrations.sqla_oauth2 import (
|
||||||
|
create_query_client_func,
|
||||||
|
create_save_token_func,
|
||||||
|
)
|
||||||
|
|
||||||
|
from fittrackee import db
|
||||||
|
|
||||||
|
from .models import OAuth2Client, OAuth2Token
|
||||||
|
|
||||||
|
query_client = create_query_client_func(db.session, OAuth2Client)
|
||||||
|
save_token = create_save_token_func(db.session, OAuth2Token)
|
||||||
|
authorization_server = AuthorizationServer(
|
||||||
|
query_client=query_client,
|
||||||
|
save_token=save_token,
|
||||||
|
)
|
@ -10,6 +10,7 @@ os.environ['DATABASE_URL'] = os.environ['DATABASE_TEST_URL']
|
|||||||
TEMP_FOLDER = '/tmp/FitTrackee'
|
TEMP_FOLDER = '/tmp/FitTrackee'
|
||||||
os.environ['UPLOAD_FOLDER'] = TEMP_FOLDER
|
os.environ['UPLOAD_FOLDER'] = TEMP_FOLDER
|
||||||
os.environ['APP_LOG'] = TEMP_FOLDER + '/fittrackee.log'
|
os.environ['APP_LOG'] = TEMP_FOLDER + '/fittrackee.log'
|
||||||
|
os.environ['AUTHLIB_INSECURE_TRANSPORT'] = '1'
|
||||||
|
|
||||||
pytest_plugins = [
|
pytest_plugins = [
|
||||||
'fittrackee.tests.fixtures.fixtures_app',
|
'fittrackee.tests.fixtures.fixtures_app',
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
from random import randint
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
from flask import Flask
|
from flask import Flask
|
||||||
@ -18,10 +19,18 @@ class RandomMixin:
|
|||||||
) -> str:
|
) -> str:
|
||||||
return random_string(length, prefix, suffix)
|
return random_string(length, prefix, suffix)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def random_domain() -> str:
|
||||||
|
return random_string(prefix='https://', suffix='com')
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def random_email() -> str:
|
def random_email() -> str:
|
||||||
return random_email()
|
return random_email()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def random_int(min_val: int = 0, max_val: int = 999999) -> int:
|
||||||
|
return randint(min_val, max_val)
|
||||||
|
|
||||||
|
|
||||||
class ApiTestCaseMixin(RandomMixin):
|
class ApiTestCaseMixin(RandomMixin):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
0
fittrackee/tests/oauth2/__init__.py
Normal file
0
fittrackee/tests/oauth2/__init__.py
Normal file
137
fittrackee/tests/oauth2/test_oauth2_client.py
Normal file
137
fittrackee/tests/oauth2/test_oauth2_client.py
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
from time import time
|
||||||
|
from typing import Dict
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from flask import Flask
|
||||||
|
|
||||||
|
from fittrackee.oauth2.client import create_oauth_client
|
||||||
|
from fittrackee.oauth2.models import OAuth2Client
|
||||||
|
from fittrackee.users.models import User
|
||||||
|
|
||||||
|
from ..utils import random_domain, random_string
|
||||||
|
|
||||||
|
TEST_METADATA = {
|
||||||
|
'client_name': random_string(),
|
||||||
|
'client_uri': random_string(),
|
||||||
|
'redirect_uris': [random_domain()],
|
||||||
|
'scope': 'read write',
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateOAuth2Client:
|
||||||
|
def test_it_creates_oauth_client(self, app: Flask, user_1: User) -> None:
|
||||||
|
oauth_client = create_oauth_client(TEST_METADATA, user_1)
|
||||||
|
|
||||||
|
assert isinstance(oauth_client, OAuth2Client)
|
||||||
|
|
||||||
|
def test_oauth_client_id_is_generated_with_gen_salt(
|
||||||
|
self, app: Flask, user_1: User
|
||||||
|
) -> None:
|
||||||
|
client_id = random_string()
|
||||||
|
with patch(
|
||||||
|
'fittrackee.oauth2.client.gen_salt', return_value=client_id
|
||||||
|
):
|
||||||
|
|
||||||
|
oauth_client = create_oauth_client(TEST_METADATA, user_1)
|
||||||
|
|
||||||
|
assert oauth_client.client_id == client_id
|
||||||
|
|
||||||
|
def test_oauth_client_client_id_issued_at_is_initialized(
|
||||||
|
self, app: Flask, user_1: User
|
||||||
|
) -> None:
|
||||||
|
client_id_issued_at = int(time())
|
||||||
|
with patch(
|
||||||
|
'fittrackee.oauth2.client.time', return_value=client_id_issued_at
|
||||||
|
):
|
||||||
|
oauth_client = create_oauth_client(TEST_METADATA, user_1)
|
||||||
|
|
||||||
|
assert oauth_client.client_id_issued_at == client_id_issued_at
|
||||||
|
|
||||||
|
def test_oauth_client_has_expected_name(
|
||||||
|
self, app: Flask, user_1: User
|
||||||
|
) -> None:
|
||||||
|
client_name = random_string()
|
||||||
|
client_metadata: Dict = {**TEST_METADATA, 'client_name': client_name}
|
||||||
|
|
||||||
|
oauth_client = create_oauth_client(client_metadata, user_1)
|
||||||
|
|
||||||
|
assert oauth_client.client_name == client_name
|
||||||
|
|
||||||
|
def test_oauth_client_has_expected_client_uri(
|
||||||
|
self, app: Flask, user_1: User
|
||||||
|
) -> None:
|
||||||
|
client_uri = random_domain()
|
||||||
|
client_metadata: Dict = {**TEST_METADATA, 'client_uri': client_uri}
|
||||||
|
|
||||||
|
oauth_client = create_oauth_client(client_metadata, user_1)
|
||||||
|
|
||||||
|
assert oauth_client.client_uri == client_uri
|
||||||
|
|
||||||
|
def test_oauth_client_has_expected_grant_types(
|
||||||
|
self, app: Flask, user_1: User
|
||||||
|
) -> None:
|
||||||
|
oauth_client = create_oauth_client(TEST_METADATA, user_1)
|
||||||
|
|
||||||
|
assert oauth_client.grant_types == ['authorization_code']
|
||||||
|
|
||||||
|
def test_oauth_client_has_expected_redirect_uris(
|
||||||
|
self, app: Flask, user_1: User
|
||||||
|
) -> None:
|
||||||
|
redirect_uris = [random_domain()]
|
||||||
|
client_metadata: Dict = {
|
||||||
|
**TEST_METADATA,
|
||||||
|
'redirect_uris': redirect_uris,
|
||||||
|
}
|
||||||
|
|
||||||
|
oauth_client = create_oauth_client(client_metadata, user_1)
|
||||||
|
|
||||||
|
assert oauth_client.redirect_uris == redirect_uris
|
||||||
|
|
||||||
|
def test_oauth_client_has_expected_response_types(
|
||||||
|
self, app: Flask, user_1: User
|
||||||
|
) -> None:
|
||||||
|
response_types = ['code']
|
||||||
|
client_metadata: Dict = {
|
||||||
|
**TEST_METADATA,
|
||||||
|
'response_types': response_types,
|
||||||
|
}
|
||||||
|
|
||||||
|
oauth_client = create_oauth_client(client_metadata, user_1)
|
||||||
|
|
||||||
|
assert oauth_client.response_types == ['code']
|
||||||
|
|
||||||
|
def test_oauth_client_has_expected_scope(
|
||||||
|
self, app: Flask, user_1: User
|
||||||
|
) -> None:
|
||||||
|
scope = 'profile'
|
||||||
|
client_metadata: Dict = {**TEST_METADATA, 'scope': scope}
|
||||||
|
|
||||||
|
oauth_client = create_oauth_client(client_metadata, user_1)
|
||||||
|
|
||||||
|
assert oauth_client.scope == scope
|
||||||
|
|
||||||
|
def test_oauth_client_has_expected_token_endpoint_auth_method(
|
||||||
|
self, app: Flask, user_1: User
|
||||||
|
) -> None:
|
||||||
|
oauth_client = create_oauth_client(TEST_METADATA, user_1)
|
||||||
|
|
||||||
|
assert oauth_client.token_endpoint_auth_method == 'client_secret_post'
|
||||||
|
|
||||||
|
def test_when_auth_method_is_not_none_oauth_client_secret_is_generated(
|
||||||
|
self, app: Flask, user_1: User
|
||||||
|
) -> None:
|
||||||
|
client_secret = random_string()
|
||||||
|
with patch(
|
||||||
|
'fittrackee.oauth2.client.gen_salt', return_value=client_secret
|
||||||
|
):
|
||||||
|
|
||||||
|
oauth_client = create_oauth_client(TEST_METADATA, user_1)
|
||||||
|
|
||||||
|
assert oauth_client.client_secret == client_secret
|
||||||
|
|
||||||
|
def test_it_creates_oauth_client_for_given_user(
|
||||||
|
self, app: Flask, user_1: User
|
||||||
|
) -> None:
|
||||||
|
oauth_client = create_oauth_client(TEST_METADATA, user_1)
|
||||||
|
|
||||||
|
assert oauth_client.user_id == user_1.id
|
36
fittrackee/tests/oauth2/test_oauth2_models.py
Normal file
36
fittrackee/tests/oauth2/test_oauth2_models.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
from flask import Flask
|
||||||
|
|
||||||
|
from fittrackee.oauth2.models import OAuth2Client
|
||||||
|
|
||||||
|
from ..mixins import RandomMixin
|
||||||
|
|
||||||
|
|
||||||
|
class TestOAuthClientSerialize(RandomMixin):
|
||||||
|
def test_it_returns_oauth_client(self, app: Flask) -> None:
|
||||||
|
oauth_client = OAuth2Client(
|
||||||
|
id=self.random_int(),
|
||||||
|
client_id=self.random_string(),
|
||||||
|
client_id_issued_at=self.random_int(),
|
||||||
|
)
|
||||||
|
oauth_client.set_client_metadata(
|
||||||
|
{
|
||||||
|
'client_name': self.random_string(),
|
||||||
|
'redirect_uris': [self.random_string()],
|
||||||
|
'client_uri': self.random_domain(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
serialized_oauth_client = oauth_client.serialize()
|
||||||
|
|
||||||
|
assert serialized_oauth_client['client_id'] == oauth_client.client_id
|
||||||
|
assert (
|
||||||
|
serialized_oauth_client['client_secret']
|
||||||
|
== oauth_client.client_secret
|
||||||
|
)
|
||||||
|
assert serialized_oauth_client['id'] == oauth_client.id
|
||||||
|
assert serialized_oauth_client['name'] == oauth_client.client_name
|
||||||
|
assert (
|
||||||
|
serialized_oauth_client['redirect_uris']
|
||||||
|
== oauth_client.redirect_uris
|
||||||
|
)
|
||||||
|
assert serialized_oauth_client['website'] == oauth_client.client_uri
|
163
fittrackee/tests/oauth2/test_oauth2_routes.py
Normal file
163
fittrackee/tests/oauth2/test_oauth2_routes.py
Normal file
@ -0,0 +1,163 @@
|
|||||||
|
import json
|
||||||
|
from typing import List, Union
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from flask import Flask
|
||||||
|
|
||||||
|
from fittrackee.oauth2.models import OAuth2Client
|
||||||
|
from fittrackee.users.models import User
|
||||||
|
|
||||||
|
from ..mixins import ApiTestCaseMixin
|
||||||
|
from ..utils import random_domain, random_string
|
||||||
|
|
||||||
|
TEST_METADATA = {
|
||||||
|
'client_name': random_string(),
|
||||||
|
'client_uri': random_domain(),
|
||||||
|
'redirect_uris': [random_domain()],
|
||||||
|
'scope': 'read write',
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestOAuthClientCreation(ApiTestCaseMixin):
|
||||||
|
route = '/api/oauth/apps'
|
||||||
|
|
||||||
|
def test_it_returns_error_when_no_user_authenticated(
|
||||||
|
self, app: Flask, user_1: User
|
||||||
|
) -> None:
|
||||||
|
client, auth_token = self.get_test_client_and_auth_token(
|
||||||
|
app, user_1.email
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
self.route,
|
||||||
|
data=json.dumps(TEST_METADATA),
|
||||||
|
content_type='application/json',
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assert_401(response)
|
||||||
|
|
||||||
|
def test_it_returns_error_when_no_metadata_provided(
|
||||||
|
self, app: Flask, user_1: User
|
||||||
|
) -> None:
|
||||||
|
client, auth_token = self.get_test_client_and_auth_token(
|
||||||
|
app, user_1.email
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
self.route,
|
||||||
|
data=json.dumps(dict()),
|
||||||
|
content_type='application/json',
|
||||||
|
headers=dict(Authorization=f'Bearer {auth_token}'),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assert_400(
|
||||||
|
response, error_message='OAuth client metadata missing'
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'missing_key',
|
||||||
|
[
|
||||||
|
'client_name',
|
||||||
|
'client_uri',
|
||||||
|
'redirect_uris',
|
||||||
|
'scope',
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_it_returns_error_when_metadata_key_is_missing(
|
||||||
|
self, app: Flask, user_1: User, missing_key: str
|
||||||
|
) -> None:
|
||||||
|
metadata = TEST_METADATA.copy()
|
||||||
|
del metadata[missing_key]
|
||||||
|
client, auth_token = self.get_test_client_and_auth_token(
|
||||||
|
app, user_1.email
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
self.route,
|
||||||
|
data=json.dumps(metadata),
|
||||||
|
content_type='application/json',
|
||||||
|
headers=dict(Authorization=f'Bearer {auth_token}'),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assert_400(
|
||||||
|
response,
|
||||||
|
error_message=f'OAuth client metadata missing keys: {missing_key}',
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_it_creates_oauth_client(self, app: Flask, user_1: User) -> None:
|
||||||
|
client, auth_token = self.get_test_client_and_auth_token(
|
||||||
|
app, user_1.email
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
self.route,
|
||||||
|
data=json.dumps(TEST_METADATA),
|
||||||
|
content_type='application/json',
|
||||||
|
headers=dict(Authorization=f'Bearer {auth_token}'),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 201
|
||||||
|
oauth_client = OAuth2Client.query.first()
|
||||||
|
assert oauth_client is not None
|
||||||
|
|
||||||
|
def test_it_returns_serialized_oauth_client(
|
||||||
|
self, app: Flask, user_1: User
|
||||||
|
) -> None:
|
||||||
|
client, auth_token = self.get_test_client_and_auth_token(
|
||||||
|
app, user_1.email
|
||||||
|
)
|
||||||
|
client_id = self.random_string()
|
||||||
|
client_secret = self.random_string()
|
||||||
|
with patch(
|
||||||
|
'fittrackee.oauth2.client.gen_salt',
|
||||||
|
side_effect=[client_id, client_secret],
|
||||||
|
):
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
self.route,
|
||||||
|
data=json.dumps(TEST_METADATA),
|
||||||
|
content_type='application/json',
|
||||||
|
headers=dict(Authorization=f'Bearer {auth_token}'),
|
||||||
|
)
|
||||||
|
|
||||||
|
data = json.loads(response.data.decode())
|
||||||
|
assert data['data']['client']['client_id'] == client_id
|
||||||
|
assert data['data']['client']['client_secret'] == client_secret
|
||||||
|
assert data['data']['client']['name'] == TEST_METADATA['client_name']
|
||||||
|
assert (
|
||||||
|
data['data']['client']['redirect_uris']
|
||||||
|
== TEST_METADATA['redirect_uris']
|
||||||
|
)
|
||||||
|
assert data['data']['client']['website'] == TEST_METADATA['client_uri']
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'input_key,expected_value',
|
||||||
|
[
|
||||||
|
('grant_types', ['authorization_code']),
|
||||||
|
('response_types', ['code']),
|
||||||
|
('token_endpoint_auth_method', 'client_secret_post'),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_it_always_create_oauth_client_with_authorization_grant(
|
||||||
|
self,
|
||||||
|
app: Flask,
|
||||||
|
user_1: User,
|
||||||
|
input_key: str,
|
||||||
|
expected_value: Union[List, str],
|
||||||
|
) -> None:
|
||||||
|
client, auth_token = self.get_test_client_and_auth_token(
|
||||||
|
app, user_1.email
|
||||||
|
)
|
||||||
|
|
||||||
|
client.post(
|
||||||
|
self.route,
|
||||||
|
data=json.dumps(
|
||||||
|
{**TEST_METADATA, input_key: self.random_string()}
|
||||||
|
),
|
||||||
|
content_type='application/json',
|
||||||
|
headers=dict(Authorization=f'Bearer {auth_token}'),
|
||||||
|
)
|
||||||
|
|
||||||
|
oauth_client = OAuth2Client.query.first()
|
||||||
|
assert getattr(oauth_client, input_key) == expected_value
|
@ -24,6 +24,10 @@ def random_string(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def random_domain() -> str:
|
||||||
|
return random_string(prefix='https://', suffix='.com')
|
||||||
|
|
||||||
|
|
||||||
def random_email() -> str:
|
def random_email() -> str:
|
||||||
return random_string(suffix='@example.com')
|
return random_string(suffix='@example.com')
|
||||||
|
|
||||||
|
@ -111,6 +111,9 @@ class User(BaseModel):
|
|||||||
new_password, current_app.config.get('BCRYPT_LOG_ROUNDS')
|
new_password, current_app.config.get('BCRYPT_LOG_ROUNDS')
|
||||||
).decode()
|
).decode()
|
||||||
|
|
||||||
|
def get_user_id(self) -> int:
|
||||||
|
return self.id
|
||||||
|
|
||||||
@hybrid_property
|
@hybrid_property
|
||||||
def workouts_count(self) -> int:
|
def workouts_count(self) -> int:
|
||||||
return Workout.query.filter(Workout.user_id == self.id).count()
|
return Workout.query.filter(Workout.user_id == self.id).count()
|
||||||
|
17
poetry.lock
generated
17
poetry.lock
generated
@ -64,6 +64,17 @@ docs = ["furo", "sphinx", "zope.interface", "sphinx-notfound-page"]
|
|||||||
tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "cloudpickle"]
|
tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "cloudpickle"]
|
||||||
tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "cloudpickle"]
|
tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "cloudpickle"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "authlib"
|
||||||
|
version = "1.0.1"
|
||||||
|
description = "The ultimate Python library in building OAuth and OpenID Connect servers and clients."
|
||||||
|
category = "main"
|
||||||
|
optional = false
|
||||||
|
python-versions = "*"
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
cryptography = ">=3.2"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "babel"
|
name = "babel"
|
||||||
version = "2.10.1"
|
version = "2.10.1"
|
||||||
@ -1466,7 +1477,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "1.1"
|
lock-version = "1.1"
|
||||||
python-versions = "^3.7"
|
python-versions = "^3.7"
|
||||||
content-hash = "e3bc153f918e57b66b117505b29466555da5ed642b0bbdc58bb5bcb9bd2ea688"
|
content-hash = "84bb7759a500c94555fb8d8e5ca25bd98b287f6c53c80eb259bcc3c0d31b1031"
|
||||||
|
|
||||||
[metadata.files]
|
[metadata.files]
|
||||||
alabaster = [
|
alabaster = [
|
||||||
@ -1493,6 +1504,10 @@ attrs = [
|
|||||||
{file = "attrs-21.4.0-py2.py3-none-any.whl", hash = "sha256:2d27e3784d7a565d36ab851fe94887c5eccd6a463168875832a1be79c82828b4"},
|
{file = "attrs-21.4.0-py2.py3-none-any.whl", hash = "sha256:2d27e3784d7a565d36ab851fe94887c5eccd6a463168875832a1be79c82828b4"},
|
||||||
{file = "attrs-21.4.0.tar.gz", hash = "sha256:626ba8234211db98e869df76230a137c4c40a12d72445c45d5f5b716f076e2fd"},
|
{file = "attrs-21.4.0.tar.gz", hash = "sha256:626ba8234211db98e869df76230a137c4c40a12d72445c45d5f5b716f076e2fd"},
|
||||||
]
|
]
|
||||||
|
authlib = [
|
||||||
|
{file = "Authlib-1.0.1-py2.py3-none-any.whl", hash = "sha256:1286e2d5ef5bfe5a11cc2d0a0d1031f0393f6ce4d61f5121cfe87fa0054e98bd"},
|
||||||
|
{file = "Authlib-1.0.1.tar.gz", hash = "sha256:6e74a4846ac36dfc882b3cc2fbd3d9eb410a627f2f2dc11771276655345223b1"},
|
||||||
|
]
|
||||||
babel = [
|
babel = [
|
||||||
{file = "Babel-2.10.1-py3-none-any.whl", hash = "sha256:3f349e85ad3154559ac4930c3918247d319f21910d5ce4b25d439ed8693b98d2"},
|
{file = "Babel-2.10.1-py3-none-any.whl", hash = "sha256:3f349e85ad3154559ac4930c3918247d319f21910d5ce4b25d439ed8693b98d2"},
|
||||||
{file = "Babel-2.10.1.tar.gz", hash = "sha256:98aeaca086133efb3e1e2aad0396987490c8425929ddbcfe0550184fdc54cd13"},
|
{file = "Babel-2.10.1.tar.gz", hash = "sha256:98aeaca086133efb3e1e2aad0396987490c8425929ddbcfe0550184fdc54cd13"},
|
||||||
|
@ -41,6 +41,7 @@ staticmap = "^0.5.4"
|
|||||||
SQLAlchemy = "1.4.36"
|
SQLAlchemy = "1.4.36"
|
||||||
pyOpenSSL = "^22.0"
|
pyOpenSSL = "^22.0"
|
||||||
ua-parser = "^0.10.0"
|
ua-parser = "^0.10.0"
|
||||||
|
Authlib = "^1.0.1"
|
||||||
|
|
||||||
[tool.poetry.dev-dependencies]
|
[tool.poetry.dev-dependencies]
|
||||||
black = "^22.3"
|
black = "^22.3"
|
||||||
|
Loading…
Reference in New Issue
Block a user