From 376ffb60c12346750bb69aa52ee04e3cedf3794a Mon Sep 17 00:00:00 2001 From: Ethan Zhang Date: Wed, 7 Aug 2024 00:57:27 +0800 Subject: [PATCH] Add support for allowing request_refresh_jwt hooks with json response --- dash/dash-renderer/src/actions/api.js | 36 ++++-- .../renderer/test_request_hooks.py | 108 +++++++++++++++++- 2 files changed, 130 insertions(+), 14 deletions(-) diff --git a/dash/dash-renderer/src/actions/api.js b/dash/dash-renderer/src/actions/api.js index aada02d5d4..ff2101d6ac 100644 --- a/dash/dash-renderer/src/actions/api.js +++ b/dash/dash-renderer/src/actions/api.js @@ -1,8 +1,8 @@ -import {mergeDeepRight, once} from 'ramda'; -import {getCSRFHeader, handleAsyncError, addHttpHeaders} from '../actions'; -import {urlBase} from './utils'; -import {MAX_AUTH_RETRIES} from './constants'; -import {JWT_EXPIRED_MESSAGE, STATUS} from '../constants/constants'; +import { mergeDeepRight, once } from 'ramda'; +import { getCSRFHeader, handleAsyncError, addHttpHeaders } from '../actions'; +import { urlBase } from './utils'; +import { MAX_AUTH_RETRIES } from './constants'; +import { JWT_EXPIRED_MESSAGE, STATUS } from '../constants/constants'; /* eslint-disable-next-line no-console */ const logWarningOnce = once(console.warn); @@ -28,11 +28,11 @@ function POST(path, fetchConfig, body = {}) { ); } -const request = {GET, POST}; +const request = { GET, POST }; export default function apiThunk(endpoint, method, store, id, body) { return async (dispatch, getState) => { - let {config, hooks} = getState(); + let { config, hooks } = getState(); let newHeaders = null; const url = `${urlBase(config)}${endpoint}`; @@ -48,7 +48,7 @@ export default function apiThunk(endpoint, method, store, id, body) { dispatch({ type: store, - payload: {id, status: 'loading'} + payload: { id, status: 'loading' } }); try { @@ -71,8 +71,24 @@ export default function apiThunk(endpoint, method, store, id, body) { res.status === STATUS.BAD_REQUEST ) { if (hooks.request_refresh_jwt) { - const body = await res.text(); - if (body.includes(JWT_EXPIRED_MESSAGE)) { + let body; + try { + body = await res.text(); + } catch (e) { + body = await res.json(); + } + let jwtExpired = false; + if (typeof body == 'string') { + jwtExpired = body.includes(JWT_EXPIRED_MESSAGE); + } else if (typeof body == 'object') { + for (const key in body) { + if (body[key] && typeof body[key] == 'string' && body[key].includes(JWT_EXPIRED_MESSAGE)) { + jwtExpired = true; + break; + } + } + } + if (jwtExpired) { const newJwt = await hooks.request_refresh_jwt( config.fetch.headers.Authorization.substr( 'Bearer '.length diff --git a/tests/integration/renderer/test_request_hooks.py b/tests/integration/renderer/test_request_hooks.py index 7f707cf823..06115c8811 100644 --- a/tests/integration/renderer/test_request_hooks.py +++ b/tests/integration/renderer/test_request_hooks.py @@ -1,13 +1,13 @@ -import json import functools +import json + import flask import pytest - from flaky import flaky +from werkzeug.exceptions import HTTPException -from dash import Dash, Output, Input, html, dcc +from dash import Dash, Input, Output, dcc, html from dash.types import RendererHooks -from werkzeug.exceptions import HTTPException def test_rdrh001_request_hooks(dash_duo): @@ -327,3 +327,103 @@ def test_rdrh004_layout_hooks(dash_duo): dash_duo.wait_for_text_to_equal("#layout", "layout_post generated this text") assert dash_duo.get_logs() == [] + + +@flaky(max_runs=3) +@pytest.mark.parametrize("expiry_code", [401, 400]) +def test_rdrh003_refresh_jwt_json(expiry_code, dash_duo): + app = Dash(__name__) + + app.index_string = """ + + + {%metas%} + {%title%} + {%favicon%} + {%css%} + + +
Testing custom DashRenderer
+ {%app_entry%} + +
With request hooks
+ + """ + + app.layout = html.Div( + [ + dcc.Input(id="input", value="initial value"), + html.Div(html.Div([html.Div(id="output-1"), html.Div(id="output-token")])), + ] + ) + + @app.callback(Output("output-1", "children"), [Input("input", "value")]) + def update_output(value): + return value + + required_jwt_len = 0 + + # test with an auth layer that requires a JWT with a certain length + def protect_route(func): + @functools.wraps(func) + def wrap(*args, **kwargs): + try: + if flask.request.method == "OPTIONS": + return func(*args, **kwargs) + token = flask.request.headers.environ.get("HTTP_AUTHORIZATION") + if required_jwt_len and ( + not token or len(token) != required_jwt_len + len("Bearer ") + ): + response = flask.jsonify({'error': 'JWT Expired ' + str(token)}) + flask.abort(response, expiry_code) + except HTTPException: + return flask.jsonify({'error': "JWT Expired " + str(token)}), expiry_code + return func(*args, **kwargs) + + return wrap + + # wrap all API calls with auth. + for name, method in ( + (x, app.server.view_functions[x]) + for x in app.routes + if x in app.server.view_functions + ): + app.server.view_functions[name] = protect_route(method) + + dash_duo.start_server(app) + + _in = dash_duo.find_element("#input") + dash_duo.clear_input(_in) + + required_jwt_len = 1 + + _in.send_keys("fired request") + + dash_duo.wait_for_text_to_equal("#output-1", "fired request") + dash_duo.wait_for_text_to_equal("#output-token", ".") + + required_jwt_len = 2 + + dash_duo.clear_input(_in) + _in.send_keys("fired request again") + + dash_duo.wait_for_text_to_equal("#output-1", "fired request again") + dash_duo.wait_for_text_to_equal("#output-token", "..") + + assert len(dash_duo.get_logs()) == 2