Files
Buffteks-Website/venv/lib/python3.12/site-packages/streamlit/web/server/oidc_mixin.py
2025-05-08 21:10:14 -05:00

109 lines
4.1 KiB
Python

# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tornado.web
from authlib.integrations.base_client import ( # type: ignore[import-untyped]
BaseApp,
BaseOAuth,
OAuth2Mixin,
OAuthError,
OpenIDMixin,
)
from authlib.integrations.requests_client import ( # type: ignore[import-untyped]
OAuth2Session,
)
from streamlit.web.server.authlib_tornado_integration import TornadoIntegration
class TornadoOAuth2App(OAuth2Mixin, OpenIDMixin, BaseApp):  # type: ignore[misc]
    """Authlib OAuth2/OpenID client app adapted to Tornado request handlers.

    Authlib's framework integrations normally keep per-flow data (keyed by the
    OAuth ``state`` value) in a server-side session. Tornado has no session, so
    this app requires the framework's cache and always passes ``session=None``
    to the framework's state helpers.
    """

    client_cls = OAuth2Session

    def load_server_metadata(self):
        """Load provider metadata and prefer the PKCE ``S256`` challenge method.

        If the provider lists ``"S256"`` in ``code_challenge_methods_supported``,
        ``code_challenge_method`` is set to ``"S256"`` in ``self.client_kwargs``.

        :return: The metadata returned by the base implementation, unchanged.
        """
        result = super().load_server_metadata()
        if "S256" in result.get("code_challenge_methods_supported", []):
            self.client_kwargs["code_challenge_method"] = "S256"
        return result

    def authorize_redirect(
        self, request_handler: tornado.web.RequestHandler, redirect_uri=None, **kwargs
    ) -> None:
        """Redirect the user agent to the provider's authorization endpoint.

        Builds the authorization URL, saves the flow data returned alongside it
        (keyed by its ``state``) so the callback can be matched later, then
        issues an HTTP 302 redirect on ``request_handler``. Nothing is returned;
        the redirect is written to the handler.

        :param request_handler: Tornado request handler to issue the redirect on.
        :param redirect_uri: Callback or redirect URI for authorization.
        :param kwargs: Extra parameters passed to ``create_authorization_url``.
        :raises RuntimeError: If no ``state`` was generated or no cache is
            configured.
        """
        auth_context = self.create_authorization_url(redirect_uri, **kwargs)
        self._save_authorize_data(redirect_uri=redirect_uri, **auth_context)
        request_handler.redirect(auth_context["url"], status=302)

    def authorize_access_token(
        self, request_handler: tornado.web.RequestHandler, **kwargs
    ):
        """Exchange the callback's authorization ``code`` for a token.

        Reads ``error``, ``code`` and ``state`` from the callback request, loads
        and then clears the flow data saved under ``state``, and fetches the
        token. If the token response has an ``id_token`` and a ``nonce`` was
        saved for this flow, the ID token is parsed and the result is added
        under ``"userinfo"``.

        :param request_handler: Tornado request handler for the callback request.
        :param kwargs: Extra parameters forwarded to ``fetch_access_token``.
            ``claims_options``, if present, is removed and passed to
            ``parse_id_token`` instead.
        :raises OAuthError: If the provider redirected back with an ``error``.
        :raises RuntimeError: If no cache is configured.
        :return: A token dict, with ``"userinfo"`` added when an ID token was
            parsed.
        """
        error = request_handler.get_argument("error", None)
        if error:
            description = request_handler.get_argument("error_description", None)
            raise OAuthError(error=error, description=description)

        params = {
            "code": request_handler.get_argument("code"),
            "state": request_handler.get_argument("state"),
        }
        state = params["state"]

        self._require_cache()
        session = None  # No Tornado session: flow data lives in the framework cache.
        claims_options = kwargs.pop("claims_options", None)
        state_data = self.framework.get_state_data(session, state)
        # Clear the saved data before the token request is made.
        self.framework.clear_state_data(session, state)
        params = self._format_state_params(state_data, params)
        token = self.fetch_access_token(**params, **kwargs)

        if "id_token" in token and "nonce" in state_data:
            userinfo = self.parse_id_token(
                token, nonce=state_data["nonce"], claims_options=claims_options
            )
            # Build a new dict instead of mutating the fetched token in place.
            token = {**token, "userinfo": userinfo}
        return token

    def _save_authorize_data(self, **kwargs):
        """Store authorization flow data in the framework cache, keyed by state.

        Authlib's integrations use a "session" to hold this data. Tornado has no
        session, so the framework's cache is used and ``session`` is ``None``.

        :param kwargs: Flow data. Must include a truthy ``state``, which is
            removed and used as the key; the remaining items are stored.
        :raises RuntimeError: If ``state`` is missing or empty, or if no cache
            is configured.
        """
        state = kwargs.pop("state", None)
        if not state:
            raise RuntimeError("Missing state value")
        self._require_cache()
        session = None
        self.framework.set_state_data(session, state, kwargs)

    def _require_cache(self) -> None:
        """Raise ``RuntimeError`` unless the framework integration has a cache.

        This is an explicit check rather than an ``assert`` so that it still
        runs under ``python -O``. Without a cache, the framework's state helpers
        would be called with a ``None`` session.
        """
        if self.framework.cache is None:
            raise RuntimeError("OAuth state cache is not configured")
class TornadoOAuth(BaseOAuth):  # type: ignore[misc]
    """Authlib ``BaseOAuth`` registry wired to the Tornado client app.

    The ``oauth2_client_cls`` and ``framework_integration_cls`` class attributes
    plug ``TornadoOAuth2App`` and ``TornadoIntegration`` into Authlib.
    """

    framework_integration_cls = TornadoIntegration
    oauth2_client_cls = TornadoOAuth2App

    def __init__(self, config=None, cache=None, fetch_token=None, update_token=None):
        """Initialise the registry.

        :param config: Provider configuration. It is stored on ``self.config``
            and not forwarded to ``BaseOAuth`` (presumably read by the
            framework integration when loading client settings; confirm).
        :param cache: Forwarded to ``BaseOAuth``.
        :param fetch_token: Forwarded to ``BaseOAuth``.
        :param update_token: Forwarded to ``BaseOAuth``.
        """
        super().__init__(
            update_token=update_token,
            fetch_token=fetch_token,
            cache=cache,
        )
        # Set after the base initialiser, as in the original ordering.
        self.config = config