74 lines
2.4 KiB
Python

import time
from typing import Optional
from authlib.oauth2 import OAuth2Request
from authlib.oauth2.rfc6749 import grants
from fittrackee import db
from fittrackee.users.models import User
from .models import OAuth2AuthorizationCode, OAuth2Client, OAuth2Token
class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
    """Authorization code grant that keeps its codes in the database.

    Supplies the storage hooks of Authlib's authorization code grant using
    the ``OAuth2AuthorizationCode`` and ``User`` models and ``db.session``.
    """

    # Authlib setting: client authentication methods this grant accepts at
    # the token endpoint.
    TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_post']

    def save_authorization_code(
        self, code: str, request: OAuth2Request
    ) -> OAuth2AuthorizationCode:
        """Store a newly issued authorization code and return its record.

        The client id, redirect URI, scope and user id are copied from
        ``request``; the two code-challenge values are looked up in
        ``request.data`` with ``.get``, so they may be absent.
        """
        challenge_fields = {
            field: request.data.get(field)
            for field in ('code_challenge', 'code_challenge_method')
        }
        record = OAuth2AuthorizationCode(
            code=code,
            client_id=request.client.client_id,
            redirect_uri=request.redirect_uri,
            scope=request.scope,
            user_id=request.user.id,
            **challenge_fields,
        )
        session = db.session
        session.add(record)
        session.commit()
        return record

    def query_authorization_code(
        self, code: str, client: OAuth2Client
    ) -> Optional[OAuth2AuthorizationCode]:
        """Return the stored code matching ``code`` and ``client``.

        ``None`` is returned when no row matches or when the matching row
        reports itself as expired.
        """
        candidate = OAuth2AuthorizationCode.query.filter_by(
            code=code, client_id=client.client_id
        ).first()
        if not candidate or candidate.is_expired():
            return None
        return candidate

    def delete_authorization_code(
        self, authorization_code: OAuth2AuthorizationCode
    ) -> None:
        """Delete ``authorization_code`` and commit the change."""
        session = db.session
        session.delete(authorization_code)
        session.commit()

    def authenticate_user(
        self, authorization_code: OAuth2AuthorizationCode
    ) -> User:
        """Look up the user referenced by ``authorization_code.user_id``."""
        # NOTE(review): the result of ``User.query.get`` is returned as-is;
        # confirm callers cope with a missing user.
        return User.query.get(authorization_code.user_id)
class RefreshTokenGrant(grants.RefreshTokenGrant):
    """Refresh token grant backed by the ``OAuth2Token`` model.

    Supplies the lookup and revocation hooks of Authlib's refresh token
    grant using ``OAuth2Token``, ``User`` and ``db.session``.
    """

    # Authlib setting: client authentication methods this grant accepts at
    # the token endpoint.
    TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_post']
    # Authlib setting: also issue a new refresh token when refreshing.
    INCLUDE_NEW_REFRESH_TOKEN = True

    def authenticate_refresh_token(
        self, refresh_token: str
    ) -> Optional[OAuth2Token]:
        """Return the token row owning ``refresh_token`` if still usable.

        ``None`` is returned when no row has this refresh token or when the
        row's ``is_refresh_token_active()`` check fails.
        """
        token = OAuth2Token.query.filter_by(
            refresh_token=refresh_token
        ).first()
        if token and token.is_refresh_token_active():
            return token
        return None

    def authenticate_user(self, credential: OAuth2Token) -> User:
        """Look up the user referenced by ``credential.user_id``."""
        # NOTE(review): the result of ``User.query.get`` is returned as-is;
        # confirm callers cope with a missing user.
        return User.query.get(credential.user_id)

    def revoke_old_credential(self, credential: OAuth2Token) -> None:
        """Stamp ``credential`` as revoked at the current time and commit."""
        # NOTE(review): ``time.time()`` yields a float; confirm the
        # ``access_token_revoked_at`` column accepts it as intended.
        credential.access_token_revoked_at = time.time()
        db.session.commit()