Skip to content
Snippets Groups Projects
Commit 8e39ad2c authored by kaiyou's avatar kaiyou
Browse files

Support stacking intended redirect

parent a191388d
No related branches found
No related tags found
No related merge requests found
from trurt import models
from trurt import models, utils
from trurt.account import blueprint, forms
import flask_login
......@@ -12,14 +12,19 @@ def login():
user = models.User.login(form.username.data, form.password.data)
if user:
flask_login.login_user(user)
endpoint = flask.request.args.get("next", "account.login")
return flask.redirect(flask.url_for(endpoint, **flask.request.args))
return flask.redirect(utils.url_or_intent("account.status"))
else:
flask.flash("Wrong credentials")
return flask.render_template("account_login.html", form=form)
return flask.render_template("account_login.html",
action=utils.url_for(".login"), form=form)
@blueprint.route("/logout")
def logout():
flask_login.logout_user()
return flask.redirect("/")
@blueprint.route("/status")
def status():
return "ok"
......@@ -4,7 +4,7 @@
{% block subtitle %}{% endblock %}
{% block content %}
<form method="POST" action="{{ url_for("account.login") }}">
<form method="POST" action="{{ action }}">
{{ form.hidden_tag() }}
{{ form.username }}
{{ form.password }}
......
from trurt.sso import blueprint, forms
from trurt import models
from trurt import models, utils
import flask_login
import flask
@blueprint.route("/pick/<service_spn>/<return_endpoint>")
@blueprint.route("/pick")
@flask_login.login_required
def pick(service_spn, return_endpoint):
def pick():
service_spn = flask.request.args.get("service_spn") or flask.abort(404)
service = models.Service.query.filter_by(spn=service_spn).first_or_404()
profiles = models.Profile.query.filter_by(
service_id=service.id,
......@@ -16,4 +17,4 @@ def pick(service_spn, return_endpoint):
form = forms.SSOValidateForm()
return flask.render_template("sso_pick.html",
service=service, profiles=profiles, form=form,
return_endpoint=return_endpoint, args=flask.request.args)
action=utils.url_or_intent("account.status"))
......@@ -6,7 +6,7 @@ from saml2 import sigver
sigver.security_context = security_context
from trurt.sso import blueprint, forms
from trurt import models
from trurt import models, utils
from saml2 import server, saml, config, mdstore, assertion
import saml2, base64, flask, xmlsec, lxml.etree, flask_login
......@@ -95,10 +95,8 @@ class SecurityContext(sigver.SecurityContext):
@blueprint.route('/saml/<service_spn>/redirect')
def redirect(service_spn):
service = models.Service.query.filter_by(spn=service_spn).first_or_404()
return flask.redirect(flask.url_for(
"sso.pick", service_spn=service_spn,
return_endpoint="sso.reply",
**flask.request.args
return flask.redirect(utils.url_for(
"sso.pick", intent="sso.reply", service_spn=service_spn,
))
......
......@@ -5,7 +5,7 @@
{% block content %}
{% for profile in profiles %}
<form method="POST" action="{{ url_for(return_endpoint, **args) }}">
<form method="POST" action="{{ action }}">
{{ form.hidden_tag() }}
<input type="hidden" name="service_id" value="{{ service.id }}">
<input type="hidden" name="profile_id" value="{{ profile.id }}">
......
......@@ -14,10 +14,34 @@ login.login_view = "account.login"
@login.unauthorized_handler
def handle_needs_login():
return flask.redirect(
flask.url_for('account.login', next=flask.request.endpoint)
url_for('account.login', intent=flask.request.endpoint)
)
def url_for(endpoint, intent=None, *args, **kwargs):
""" Returns an url that preserves query string and supports passing on
an intent
"""
query_string = dict(flask.request.args)
query_string.update(kwargs)
if "intents" in query_string and intent is not None:
query_string["intents"] += ":" + intent
elif "intents" not in query_string and intent is not None:
query_string["intents"] = intent
return flask.url_for(endpoint, *args, **query_string)
def url_or_intent(endpoint):
""" Return the latest intent, or the endpoint url if none
"""
intents = flask.request.args.get("intents", "")
if intents:
intents = intents.split(":")
return url_for(intents[-1], intents=":".join(intents[:-1]))
else:
return flask.url_for(endpoint)
# Request rate limitation
limiter = flask_limiter.Limiter(key_func=lambda: current_user.id)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment