74 lines
2.4 KiB
Python
74 lines
2.4 KiB
Python
import time
|
|
from typing import Optional
|
|
|
|
from authlib.oauth2 import OAuth2Request
|
|
from authlib.oauth2.rfc6749 import grants
|
|
|
|
from fittrackee import db
|
|
from fittrackee.users.models import User
|
|
|
|
from .models import OAuth2AuthorizationCode, OAuth2Client, OAuth2Token
|
|
|
|
|
|
class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
    """Authorization code grant backed by the application database.

    Provides the persistence hooks of authlib's ``AuthorizationCodeGrant``
    using the ``OAuth2AuthorizationCode`` and ``User`` models.
    """

    # Only 'client_secret_post' is listed as a client authentication method
    # for the token endpoint.
    TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_post']

    def save_authorization_code(
        self, code: str, request: OAuth2Request
    ) -> OAuth2AuthorizationCode:
        """Persist a newly issued authorization code and return its record.

        The record is bound to the client, redirect URI, scope and user
        carried by ``request``. The 'code_challenge' and
        'code_challenge_method' values are taken from ``request.data`` as
        returned by its ``get`` method.
        """
        challenge_fields = {
            field: request.data.get(field)
            for field in ('code_challenge', 'code_challenge_method')
        }
        record = OAuth2AuthorizationCode(
            code=code,
            client_id=request.client.client_id,
            redirect_uri=request.redirect_uri,
            scope=request.scope,
            user_id=request.user.id,
            **challenge_fields,
        )
        db.session.add(record)
        db.session.commit()
        return record

    def query_authorization_code(
        self, code: str, client: OAuth2Client
    ) -> Optional[OAuth2AuthorizationCode]:
        """Return the stored code matching ``code`` for ``client``.

        ``None`` is returned when no matching record is found or when the
        record reports itself as expired.
        """
        matching = OAuth2AuthorizationCode.query.filter_by(
            code=code, client_id=client.client_id
        )
        record = matching.first()
        if not record or record.is_expired():
            return None
        return record

    def delete_authorization_code(
        self, authorization_code: OAuth2AuthorizationCode
    ) -> None:
        """Delete ``authorization_code`` from the database and commit."""
        session = db.session
        session.delete(authorization_code)
        session.commit()

    def authenticate_user(
        self, authorization_code: OAuth2AuthorizationCode
    ) -> User:
        """Load the user referenced by ``authorization_code.user_id``."""
        owner_id = authorization_code.user_id
        return User.query.get(owner_id)
|
|
|
|
|
|
class RefreshTokenGrant(grants.RefreshTokenGrant):
    """Refresh token grant backed by the application database.

    Provides the persistence hooks of authlib's ``RefreshTokenGrant`` using
    the ``OAuth2Token`` and ``User`` models.
    """

    # Only 'client_secret_post' is listed as a client authentication method
    # for the token endpoint.
    TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_post']
    # Ask authlib to include a new refresh token in the refresh response.
    INCLUDE_NEW_REFRESH_TOKEN = True

    def authenticate_refresh_token(
        self, refresh_token: str
    ) -> Optional[OAuth2Token]:
        """Return the stored token owning ``refresh_token``.

        The value returned is the ``OAuth2Token`` record itself (not the
        refresh token string). ``None`` is returned when no record matches
        or when the record's ``is_refresh_token_active()`` check fails.
        """
        token = OAuth2Token.query.filter_by(
            refresh_token=refresh_token
        ).first()
        if token and token.is_refresh_token_active():
            return token
        return None

    def authenticate_user(self, credential: OAuth2Token) -> User:
        """Load the user referenced by ``credential.user_id``."""
        return User.query.get(credential.user_id)

    def revoke_old_credential(self, credential: OAuth2Token) -> None:
        """Stamp ``credential`` as revoked at the current time and commit."""
        # NOTE(review): time.time() is a float; if the column is an integer
        # timestamp, consider int(time.time()) - confirm against the model.
        credential.access_token_revoked_at = time.time()
        db.session.commit()
|