Skip to content
Snippets Groups Projects
Commit 699c1e28 authored by kaiyou's avatar kaiyou
Browse files

Store authorization codes in redis

parent 9e221411
No related branches found
No related tags found
No related merge requests found
......@@ -16,6 +16,7 @@ def create_app_from_config(config):
utils.login.init_app(app)
utils.login.user_loader(models.User.get)
utils.migrate.init_app(app, models.db)
utils.redis.init_app(app)
# Initialize debugging tools
if app.config.get("DEBUG"):
......
......@@ -7,6 +7,7 @@ from hiboo.sso import forms, blueprint
from hiboo import models, utils, profile
import flask
import time
class Config(object):
......@@ -17,7 +18,7 @@ class Config(object):
def derive_form(cls, form):
""" Add required fields to a form.
"""
return type('NewForm', (forms.OIDCForm, form), {})
return tmodelsype('NewForm', (forms.OIDCForm, form), {})
@classmethod
def populate_service(cls, form, service):
......@@ -54,22 +55,18 @@ class Config(object):
)
class Client(models.db.Model, sqla_oauth2.OAuth2ClientMixin):
class Client(sqla_oauth2.OAuth2ClientMixin):
""" OIDC client that only supports authorization code, implicit and
hybrid flows.
"""
scope = "openid"
expire_after = 3600
def __init__(self, service):
self.service = service
super(Client, self).__init__(
client_id=service.config["client_id"],
client_secret=service.config["client_secret"],
client_metadata=service.config
)
# The authorization server is specific to a client
self.client_id = service.config["client_id"]
self.client_secret = service.config["client_secret"]
self.client_metadata = service.config
self.authorization = flask_oauth2.AuthorizationServer(
query_client=self.query_client,
save_token=self.save_token,
......@@ -85,7 +82,6 @@ class Client(models.db.Model, sqla_oauth2.OAuth2ClientMixin):
return self if client_id == self.client_id else None
def save_token(self, token, request):
# TODO: atm we do not save any token
pass
def get_jwt_config(self):
......@@ -93,7 +89,7 @@ class Client(models.db.Model, sqla_oauth2.OAuth2ClientMixin):
'key': self.service.config["jwt_key"],
'alg': self.service.config["jwt_alg"],
'iss': flask.url_for("sso.oidc_token", service_uuid=self.service.uuid, _external=True),
'exp': self.expire_after,
'exp': 3600,
}
@classmethod
......@@ -110,43 +106,35 @@ class Client(models.db.Model, sqla_oauth2.OAuth2ClientMixin):
@classmethod
def exists_nonce(cls, nonce, request):
return bool(AuthorizationCode.query.filter_by(
nonce=nonce, client_id=request.client_id).first()
)
return bool(utils.redis.get("nonce:{}".format(nonce)))
class AuthorizationCode(models.db.Model, sqla_oauth2.OAuth2AuthorizationCodeMixin):
class AuthorizationCode(utils.SerializableObj, sqla_oauth2.OAuth2AuthorizationCodeMixin):
""" Authorization code object for storage
"""
__tablename__ = "oidc_authorization_code"
user_id = models.db.Column(models.db.Text())
@classmethod
def create(cls, client, grant_user, request):
code = gen_salt(48) # TODO
authorization_code = AuthorizationCode(
code=code, nonce=request.data.get('nonce'),
client_id=client.client_id,
redirect_uri=request.redirect_uri,
scope=request.scope,
user_id=grant_user.uuid
obj = cls(
code=gen_salt(48), nonce=request.data.get("nonce") or "",
client_id=client.client_id, redirect_uri=request.redirect_uri,
scope=request.scope, user_id=grant_user.uuid,
auth_time=int(time.time())
)
models.db.session.add(authorization_code)
models.db.session.commit()
return code
utils.redis.hmset("code:{}".format(obj.code), obj.serialize())
if obj.nonce:
utils.redis.set("nonce:{}".format(obj.nonce), obj.code)
return obj.code
@classmethod
def get(cls, code, client):
return AuthorizationCode.query.filter_by(
client_id=client.client_id,
code=code
).first()
obj = cls.unserialize(utils.redis.hgetall("code:{}".format(code)))
if obj and obj.client_id == client.client_id:
return obj
@classmethod
def delete(cls, authorization_code):
models.db.session.delete(authorization_code)
models.db.session.commit()
utils.redis.delete("code:{}".format(authorization_code))
class AuthorizationCodeGrant(oauth2.grants.AuthorizationCodeGrant):
......
......@@ -3,6 +3,7 @@ import flask_login
import flask_migrate
import flask_babel
import flask_limiter
import flask_redis
from werkzeug.contrib import fixers
from werkzeug import routing
......@@ -75,6 +76,25 @@ def display_help(identifier):
return result
class SerializableObj(object):
def __init__(self, **kwargs):
self.__dict__.update(**kwargs)
self.__keys__ = list(kwargs.keys())
@classmethod
def unserialize(cls, kwargs):
return cls(**{
key.decode("utf8"): value.decode("utf8") if type(value) is bytes else value
for key, value in kwargs.items()
}) if kwargs else None
def serialize(self):
return {
key.encode("utf8"): value.encode("utf8") if type(value) is str else value
for key, value in self.__dict__.items() if key in self.__keys__
}
# Request rate limitation
limiter = flask_limiter.Limiter(key_func=lambda: current_user.id)
......@@ -90,3 +110,7 @@ def get_locale():
# Data migrate
migrate = flask_migrate.Migrate()
# Redis storage
redis = flask_redis.FlaskRedis()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment