66 lines
2.1 KiB
Python
66 lines
2.1 KiB
Python
|
from typing import Optional
|
||
|
|
||
|
from authlib.oauth2 import OAuth2Request
|
||
|
from authlib.oauth2.rfc6749 import grants
|
||
|
|
||
|
from fittrackee import db
|
||
|
from fittrackee.users.models import User
|
||
|
|
||
|
from .models import OAuth2AuthorizationCode, OAuth2Client, OAuth2Token
|
||
|
|
||
|
|
||
|
class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
    """Authorization code grant persisted through ``OAuth2AuthorizationCode``.

    Implements the storage hooks of Authlib's authorization code grant
    (save / query / delete a code, resolve its user) with the application's
    database session.
    """

    # Only clients sending their credentials in the request body are accepted
    # at the token endpoint.
    TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_post']

    def save_authorization_code(
        self, code: str, request: OAuth2Request
    ) -> OAuth2AuthorizationCode:
        """Store a new authorization code for the current request.

        The record is built from the requesting client, redirect URI, scope
        and user, plus the ``code_challenge`` / ``code_challenge_method``
        values found in the request data (``None`` when absent).

        :param code: authorization code value to persist
        :param request: OAuth2 request the code is issued for
        :return: the committed ``OAuth2AuthorizationCode`` instance
        """
        challenge = request.data.get('code_challenge')
        challenge_method = request.data.get('code_challenge_method')

        record = OAuth2AuthorizationCode(
            code=code,
            client_id=request.client.client_id,
            redirect_uri=request.redirect_uri,
            scope=request.scope,
            user_id=request.user.id,
            code_challenge=challenge,
            code_challenge_method=challenge_method,
        )

        session = db.session
        session.add(record)
        session.commit()
        return record

    def query_authorization_code(
        self, code: str, client: OAuth2Client
    ) -> Optional[OAuth2AuthorizationCode]:
        """Look up a non-expired authorization code issued to ``client``.

        :param code: authorization code value received from the client
        :param client: client presenting the code
        :return: the matching record, or ``None`` when no record matches or
            the matching record is expired
        """
        lookup = OAuth2AuthorizationCode.query.filter_by(
            code=code, client_id=client.client_id
        )
        found = lookup.first()
        if not found or found.is_expired():
            return None
        return found

    def delete_authorization_code(
        self, authorization_code: OAuth2AuthorizationCode
    ) -> None:
        """Remove ``authorization_code`` from the database and commit.

        :param authorization_code: record to delete
        """
        session = db.session
        session.delete(authorization_code)
        session.commit()

    def authenticate_user(
        self, authorization_code: OAuth2AuthorizationCode
    ) -> User:
        """Return the user referenced by ``authorization_code.user_id``.

        :param authorization_code: record whose owner is looked up
        :return: result of the primary-key lookup on ``User``
        """
        owner_id = authorization_code.user_id
        return User.query.get(owner_id)
|
||
|
|
||
|
|
||
|
class RefreshTokenGrant(grants.RefreshTokenGrant):
    """Refresh token grant backed by the ``OAuth2Token`` model."""

    def authenticate_refresh_token(
        self, refresh_token: str
    ) -> Optional[OAuth2Token]:
        """Find the active token record matching ``refresh_token``.

        The return annotation is ``Optional[OAuth2Token]`` (not a string):
        the method hands back the stored token row, which is what
        ``authenticate_user`` receives as its ``credential`` argument.

        :param refresh_token: refresh token value sent by the client
        :return: the matching ``OAuth2Token`` when one exists and
            ``is_refresh_token_active()`` is truthy, otherwise ``None``
        """
        token = OAuth2Token.query.filter_by(
            refresh_token=refresh_token
        ).first()
        if token and token.is_refresh_token_active():
            return token
        return None

    def authenticate_user(self, credential: OAuth2Token) -> User:
        """Return the user referenced by ``credential.user_id``.

        :param credential: token record returned by
            ``authenticate_refresh_token``
        :return: result of the primary-key lookup on ``User``
        """
        return User.query.get(credential.user_id)
|