diff --git a/hiboo/sso/oidc.py b/hiboo/sso/oidc.py index c83208f6ff57563ba35f18ef40624d9397fa6b78..db3a93db0c6f9209d6bbcc6645fa4ded837e19f5 100644 --- a/hiboo/sso/oidc.py +++ b/hiboo/sso/oidc.py @@ -4,10 +4,10 @@ Supported grants are authorization code, OpenID implicit and hybrid. It relies heavily on authlib for the OAuth/OIDC implementation. """ -from werkzeug.security import gen_salt from authlib.integrations import flask_oauth2, sqla_oauth2 from authlib.oauth2 import rfc6749 as oauth2 from authlib.oidc import core as oidc +from authlib.common import security from hiboo.sso import forms, blueprint from hiboo import models, utils, profile @@ -56,9 +56,9 @@ class Config(object): """ if "client_id" not in service.config: service.config.update( - client_id=gen_salt(24), - client_secret=gen_salt(48), - jwt_key=gen_salt(24), + client_id=security.generate_token(24), + client_secret=security.generate_token(48), + jwt_key=security.generate_token(24), jwt_alg="HS256" ) @@ -72,7 +72,7 @@ class AuthorizationCodeMixin(object): def create_authorization_code(self, client, grant_user, request): obj = AuthorizationCodeMixin.AuthorizationCode( - code=gen_salt(48), nonce=request.data.get("nonce") or "", + code=security.generate_token(48), nonce=request.data.get("nonce") or "", client_id=client.client_id, redirect_uri=request.redirect_uri, scope=request.scope, user_id=grant_user.uuid, auth_time=int(time.time()) @@ -96,28 +96,23 @@ class AuthorizationCodeMixin(object): class OpenIDMixin(object): - """ Mixin for defining OpenID grants + """ Mixin for defining OpenID grants, mostly a proxy to client methods, + either used as a grant extension for code grant, or as a direct mixin. """ def exists_nonce(self, nonce, request): return bool(utils.redis.get("nonce:{}".format(nonce))) - def get_jwt_config(self, grant): - service = grant.client.service - return { - 'key': service.config["jwt_key"], 'alg': service.config["jwt_alg"], - 'iss': flask.url_for("sso.oidc_token", service_uuid=service.uuid, _external=True), - 'exp': 3600, - } + def get_client(self, grant=None): + # In the case of AuthorizationCode, the current object is not the grant + # but a grant extension, so the client is retrieved through the grant argument + return self.request.client if grant is None else grant.client + + def get_jwt_config(self, grant=None): + return self.get_client().get_jwt_config() def generate_user_info(self, user, scope): - return oidc.UserInfo( - sub=user.uuid, - name=user.username, - prefered_username=user.username, - login=user.username, - email=user.email - ) + return self.get_client().generate_user_info(user, scope) class Client(sqla_oauth2.OAuth2ClientMixin): @@ -140,42 +135,81 @@ class Client(sqla_oauth2.OAuth2ClientMixin): # Configuration is stored in a format compatible with authlib metadata # so it only needs to be passed to the authorization server object self.client_metadata = service.config - self.authorization = flask_oauth2.AuthorizationServer( - query_client=self.query_client, - save_token=self.save_token, - app=flask.current_app - ) - self.authorization.register_grant( - Client.AuthorizationCodeGrant, [Client.OpenIDCode(require_nonce=False)] - ) + self.authorization = flask_oauth2.AuthorizationServer(query_client=self.query_client, save_token=self.save_token) + self.authorization.generate_token = self.generate_token + # Register all grant types + self.authorization.register_grant(Client.AuthorizationCodeGrant, [Client.OpenIDCode(require_nonce=False)]) self.authorization.register_grant(Client.ImplicitGrant) self.authorization.register_grant(Client.HybridGrant) + @classmethod + def get_by_service(cls, service_uuid): + service = models.Service.query.get(service_uuid) + if service and service.protocol == "oidc": + return cls(service) + def query_client(self, client_id): return self if client_id == self.client_id else None + + def get_jwt_config(self): + service = self.service + return { + 'key': service.config["jwt_key"], 'alg': service.config["jwt_alg"], + 'iss': flask.url_for("sso.oidc_token", service_uuid=service.uuid, _external=True), + 'exp': 3600, + } + + def generate_user_info(self, user, scope): + """ User info generation function used by the oidc code mixin and the userinfo endpoint + """ + return oidc.UserInfo( + sub=user.uuid, name=user.username, prefered_username=user.username, + login=user.username, email=user.email + ) + + def generate_token(self, client, grant_type, user=None, scope=None, expires_in=None, include_refresh_token=False): + """ Specific token generation function to help keep track of the profile associated with a token + """ + return dict( + client_id=self.client_id, token_type="Bearer", access_token=security.generate_token(48), + issued_at=time.time(), expires_in=expires_in or 3600, profile_uuid=user.uuid, scope=scope or "" + ) def save_token(self, token, request): - # Tokens are not saved since Hiboo supports user authentication, note - # long term app authentication. - pass + """ Save the token to redis database + """ + utils.redis.hmset("token:{}".format(token["access_token"]), token) + + def validate_token(self, request): + """ Validate then returns the current request token + """ + auth = request.headers.get("Authorization", "").split(None, 1) + if auth and len(auth) == 2 and auth[0] == "Bearer": + token = utils.decode_dict(utils.redis.hgetall("token:{}".format(auth[1]))) + if (token and token["client_id"] == self.client_id and + time.time() < (float(token["issued_at"]) + float(token["expires_in"]))): + return token @blueprint.route("/oidc/authorize/<service_uuid>", methods=["GET", "POST"]) def oidc_authorize(service_uuid): # Get the profile from user input (implies redirects) - service = models.Service.query.get(service_uuid) or flask.abort(404) - service.protocol == "oidc" or flask.abort(404) - picked = profile.get_profile(service, intent=True) or flask.abort(403) + client = Client.get_by_service(service_uuid) or flask.abort(404) + picked = profile.get_profile(client.service, intent=True) or flask.abort(403) # Generate and return the response - client = Client(service) return client.authorization.create_authorization_response(grant_user=picked) @blueprint.route("/oidc/token/<service_uuid>", methods=["POST"]) def oidc_token(service_uuid): - # Get the profile from user input (implies redirects) - service = models.Service.query.get(service_uuid) or flask.abort(404) - service.protocol == "oidc" or flask.abort(404) - # Generate and return the response - client = Client(service) + client = Client.get_by_service(service_uuid) or flask.abort(404) return client.authorization.create_token_response() + + +@blueprint.route("/oidc/userinfo/<service_uuid>", methods=["GET", "POST"]) +def oidc_userinfo(service_uuid): + client = Client.get_by_service(service_uuid) or flask.abort(404) + token = client.validate_token(flask.request) + profile = models.Profile.query.get(token["profile_uuid"]) + return client.generate_user_info(profile, token["scope"]) + \ No newline at end of file diff --git a/hiboo/utils.py b/hiboo/utils.py index c6ea608057bb707081c3e0e203808bec4e72725c..8c122e208ecf97643e85aac740b7c3030eaea7f6 100644 --- a/hiboo/utils.py +++ b/hiboo/utils.py @@ -77,6 +77,20 @@ def display_help(identifier): return result +def encode_dict(source, valid_keys=None): + return { + key.encode("utf8"): value.encode("utf8") if type(value) is str else value + for key, value in source.items() if (valid_keys is None or key in valid_keys) + } + + +def decode_dict(source): + return { + key.decode("utf8"): value.decode("utf8") if type(value) is bytes else value + for key, value in source.items() + } + + class SerializableObj(object): def __init__(self, **kwargs): self.__dict__.update(**kwargs) @@ -84,16 +98,10 @@ class SerializableObj(object): @classmethod def unserialize(cls, kwargs): - return cls(**{ - key.decode("utf8"): value.decode("utf8") if type(value) is bytes else value - for key, value in kwargs.items() - }) if kwargs else None + return cls(**decode_dict(kwargs)) if kwargs else None def serialize(self): - return { - key.encode("utf8"): value.encode("utf8") if type(value) is str else value - for key, value in self.__dict__.items() if key in self.__keys__ - } + return encode_dict(self.__dict__, self.__keys__) # Request rate limitation