diff --git a/src/backend/libpq/auth-oauth.c b/src/backend/libpq/auth-oauth.c index 3a625847f3..f213a40b65 100644 --- a/src/backend/libpq/auth-oauth.c +++ b/src/backend/libpq/auth-oauth.c @@ -24,15 +24,23 @@ #include "libpq/hba.h" #include "libpq/oauth.h" #include "libpq/sasl.h" +#include "miscadmin.h" #include "storage/fd.h" /* GUC */ char *oauth_validator_command; +static OAuthProvider* oauth_provider = NULL; + +/*---------------------------------------------------------------- + * OAuth Authentication + *---------------------------------------------------------------- + */ +static List *oauth_providers = NIL; static void oauth_get_mechanisms(Port *port, StringInfo buf); static void *oauth_init(Port *port, const char *selected_mech, const char *shadow_pass); static int oauth_exchange(void *opaq, const char *input, int inputlen, - char **output, int *outputlen, char **logdetail); + char **output, int *outputlen, const char **logdetail); /* Mechanism declaration */ const pg_be_sasl_mech pg_be_oauth_mech = { @@ -43,7 +51,6 @@ const pg_be_sasl_mech pg_be_oauth_mech = { PG_MAX_AUTH_TOKEN_LENGTH, }; - typedef enum { OAUTH_STATE_INIT = 0, @@ -62,7 +69,7 @@ struct oauth_ctx static char *sanitize_char(char c); static char *parse_kvpairs_for_auth(char **input); static void generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen); -static bool validate(Port *port, const char *auth, char **logdetail); +static bool validate(Port *port, const char *auth, const char **logdetail); static bool run_validator_command(Port *port, const char *token); static bool check_exit(FILE **fh, const char *command); static bool unset_cloexec(int fd); @@ -72,6 +79,86 @@ static bool username_ok_for_shell(const char *username); #define AUTH_KEY "auth" #define BEARER_SCHEME "Bearer " +#include "utils/memutils.h" + +/*---------------------------------------------------------------- + * OAuth Token Validator + *---------------------------------------------------------------- + */ + +/* + * RegistorOAuthProvider registers a OAuth Token Validator to be + * used for oauth token validation. It validates the token and adds the valiator + * name and it's hooks to a list of loaded token validator. The right validator's + * hooks can then be called based on the validator name specified in + * pg_hba.conf. + * + * This function should be called in _PG_init() by any extension looking to + * add a custom authentication method. + */ +void +RegistorOAuthProvider( + const char *provider_name, + OAuthProviderCheck_hook_type OAuthProviderCheck_hook, + OAuthProviderError_hook_type OAuthProviderError_hook, + OAuthProviderOptions_hook_type OAuthProviderOptions_hook +) +{ + if (!process_shared_preload_libraries_in_progress) + { + ereport(ERROR, + (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), + errmsg("RegistorOAuthProvider can only be called by a shared_preload_library"))); + return; + } + + MemoryContext oldcxt; + if (oauth_provider == NULL) + { + oldcxt = MemoryContextSwitchTo(TopMemoryContext); + oauth_provider = palloc(sizeof(OAuthProvider)); + oauth_provider->name = pstrdup(provider_name); + oauth_provider->oauth_provider_hook = OAuthProviderCheck_hook; + oauth_provider->oauth_error_hook = OAuthProviderError_hook; + oauth_provider->oauth_options_hook = OAuthProviderOptions_hook; + oauth_providers = lappend(oauth_providers, oauth_provider); + MemoryContextSwitchTo(oldcxt); + } + else + { + if (oauth_provider && oauth_provider->name) + { + ereport(ERROR, + (errmsg("OAuth provider \"%s\" is already loaded.", + oauth_provider->name))); + } + else + { + ereport(ERROR, + (errmsg("OAuth provider is already loaded."))); + } + } +} + +/* + * Returns the oauth provider (which includes it's + * callback functions) based on name specified. + */ +OAuthProvider *get_provider_by_name(const char *name) +{ + ListCell *lc; + foreach(lc, oauth_providers) + { + OAuthProvider *provider = (OAuthProvider *) lfirst(lc); + if (strcmp(provider->name, name) == 0) + { + return provider; + } + } + + return NULL; +} + static void oauth_get_mechanisms(Port *port, StringInfo buf) { @@ -102,9 +189,32 @@ oauth_init(Port *port, const char *selected_mech, const char *shadow_pass) return ctx; } +static void process_oauth_flow_type(pg_oauth_flow_type flow_type, struct oauth_ctx *ctx, char **output, int *outputlen) +{ + StringInfoData buf; + initStringInfo(&buf); + + OAuthProviderOptions *oauth_options = oauth_provider->oauth_options_hook(flow_type); + ctx->scope = oauth_options->scope; + ctx->issuer = oauth_options->issuer_url; + appendStringInfo(&buf, + "{ " + "\"status\": \"invalid_token\", " + "\"openid-configuration\": \"%s/.well-known/openid-configuration\"," + "\"scope\": \"%s\"" + "}", + oauth_options->issuer_url, + oauth_options->scope); + + *output = buf.data; + *outputlen = buf.len; + + pfree(oauth_options); +} + static int oauth_exchange(void *opaq, const char *input, int inputlen, - char **output, int *outputlen, char **logdetail) + char **output, int *outputlen, const char **logdetail) { char *p; char cbind_flag; @@ -247,11 +357,17 @@ oauth_exchange(void *opaq, const char *input, int inputlen, (errcode(ERRCODE_PROTOCOL_VIOLATION), errmsg("malformed OAUTHBEARER message"), errdetail("Message contains additional data after the final terminator."))); - - if (!validate(ctx->port, auth, logdetail)) + + /* if not Bearer, process flow_type*/ + if (strncasecmp(auth, BEARER_SCHEME, strlen(BEARER_SCHEME))) + { + process_oauth_flow_type(atoi(auth), ctx, output, outputlen); + ctx->state = OAUTH_STATE_ERROR; + return PG_SASL_EXCHANGE_CONTINUE; + } + else if(!validate(ctx->port, auth, logdetail)) { generate_error_response(ctx, output, outputlen); - ctx->state = OAUTH_STATE_ERROR; return PG_SASL_EXCHANGE_CONTINUE; } @@ -415,7 +531,7 @@ generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen) } static bool -validate(Port *port, const char *auth, char **logdetail) +validate(Port *port, const char *auth, const char **logdetail) { static const char * const b64_set = "abcdefghijklmnopqrstuvwxyz" "ABCDEFGHIJKLMNOPQRSTUVWXYZ" @@ -508,7 +624,7 @@ validate(Port *port, const char *auth, char **logdetail) return true; } - /* Make sure the validator authenticated the user. */ + /* Make sure the validator authenticated the user. */ if (!MyClientConnectionInfo.authn_id) { /* TODO: use logdetail; reduce message duplication */ @@ -518,199 +634,22 @@ validate(Port *port, const char *auth, char **logdetail) return false; } - /* Finally, check the user map. */ - ret = check_usermap(port->hba->usermap, port->user_name, - MyClientConnectionInfo.authn_id, false); + /* Finally, check the user map. */ + ret = check_usermap(port->hba->usermap, port->user_name, + MyClientConnectionInfo.authn_id, false); return (ret == STATUS_OK); } static bool run_validator_command(Port *port, const char *token) { - bool success = false; - int rc; - int pipefd[2]; - int rfd = -1; - int wfd = -1; - - StringInfoData command = { 0 }; - char *p; - FILE *fh = NULL; - - ssize_t written; - char *line = NULL; - size_t size = 0; - ssize_t len; - - Assert(oauth_validator_command); - - if (!oauth_validator_command[0]) - { - ereport(COMMERROR, - (errmsg("oauth_validator_command is not set"), - errhint("To allow OAuth authenticated connections, set " - "oauth_validator_command in postgresql.conf."))); - return false; - } - - /* - * Since popen() is unidirectional, open up a pipe for the other direction. - * Use CLOEXEC to ensure that our write end doesn't accidentally get copied - * into child processes, which would prevent us from closing it cleanly. - * - * XXX this is ugly. We should just read from the child process's stdout, - * but that's a lot more code. - * XXX by bypassing the popen API, we open the potential of process - * deadlock. Clearly document child process requirements (i.e. the child - * MUST read all data off of the pipe before writing anything). - * TODO: port to Windows using _pipe(). - */ - rc = pipe2(pipefd, O_CLOEXEC); - if (rc < 0) - { - ereport(COMMERROR, - (errcode_for_file_access(), - errmsg("could not create child pipe: %m"))); - return false; - } - - rfd = pipefd[0]; - wfd = pipefd[1]; - - /* Allow the read pipe be passed to the child. */ - if (!unset_cloexec(rfd)) - { - /* error message was already logged */ - goto cleanup; - } - - /* - * Construct the command, substituting any recognized %-specifiers: - * - * %f: the file descriptor of the input pipe - * %r: the role that the client wants to assume (port->user_name) - * %%: a literal '%' - */ - initStringInfo(&command); - - for (p = oauth_validator_command; *p; p++) - { - if (p[0] == '%') - { - switch (p[1]) - { - case 'f': - appendStringInfo(&command, "%d", rfd); - p++; - break; - case 'r': - /* - * TODO: decide how this string should be escaped. The role - * is controlled by the client, so if we don't escape it, - * command injections are inevitable. - * - * This is probably an indication that the role name needs - * to be communicated to the validator process in some other - * way. For this proof of concept, just be incredibly strict - * about the characters that are allowed in user names. - */ - if (!username_ok_for_shell(port->user_name)) - goto cleanup; - - appendStringInfoString(&command, port->user_name); - p++; - break; - case '%': - appendStringInfoChar(&command, '%'); - p++; - break; - default: - appendStringInfoChar(&command, p[0]); - } - } - else - appendStringInfoChar(&command, p[0]); - } - - /* Execute the command. */ - fh = OpenPipeStream(command.data, "re"); - /* TODO: handle failures */ - - /* We don't need the read end of the pipe anymore. */ - close(rfd); - rfd = -1; - - /* Give the command the token to validate. */ - written = write(wfd, token, strlen(token)); - if (written != strlen(token)) - { - /* TODO must loop for short writes, EINTR et al */ - ereport(COMMERROR, - (errcode_for_file_access(), - errmsg("could not write token to child pipe: %m"))); - goto cleanup; - } - - close(wfd); - wfd = -1; - - /* - * Read the command's response. - * - * TODO: getline() is probably too new to use, unfortunately. - * TODO: loop over all lines - */ - if ((len = getline(&line, &size, fh)) >= 0) - { - /* TODO: fail if the authn_id doesn't end with a newline */ - if (len > 0) - line[len - 1] = '\0'; - - set_authn_id(port, line); - } - else if (ferror(fh)) - { - ereport(COMMERROR, - (errcode_for_file_access(), - errmsg("could not read from command \"%s\": %m", - command.data))); - goto cleanup; - } - - /* Make sure the command exits cleanly. */ - if (!check_exit(&fh, command.data)) + int result = oauth_provider->oauth_provider_hook(port, token); + if(result == STATUS_OK) { - /* error message already logged */ - goto cleanup; - } - - /* Done. */ - success = true; - -cleanup: - if (line) - free(line); - - /* - * In the successful case, the pipe fds are already closed. For the error - * case, always close out the pipe before waiting for the command, to - * prevent deadlock. - */ - if (rfd >= 0) - close(rfd); - if (wfd >= 0) - close(wfd); - - if (fh) - { - Assert(!success); - check_exit(&fh, command.data); + set_authn_id(port, port->user_name); + return true; } - - if (command.data) - pfree(command.data); - - return success; + return false; } static bool @@ -780,7 +719,7 @@ username_ok_for_shell(const char *username) /* This set is borrowed from fe_utils' appendShellStringNoError(). */ static const char * const allowed = "abcdefghijklmnopqrstuvwxyz" "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - "0123456789-_./:"; + "0123456789-_./@:"; size_t span; Assert(username && username[0]); /* should have already been checked */ diff --git a/src/include/libpq/auth.h b/src/include/libpq/auth.h index b62457d57c..7b7b6ff9aa 100644 --- a/src/include/libpq/auth.h +++ b/src/include/libpq/auth.h @@ -28,6 +28,41 @@ extern void set_authn_id(Port *port, const char *id); /* Hook for plugins to get control in ClientAuthentication() */ typedef void (*ClientAuthentication_hook_type) (Port *, int); extern PGDLLIMPORT ClientAuthentication_hook_type ClientAuthentication_hook; +/* Declarations for oAuth authentication providers */ +typedef int (*OAuthProviderCheck_hook_type) (Port *, const char*); + +/* Hook for plugins to report error messages in validation_failed() */ +typedef const char * (*OAuthProviderError_hook_type) (Port *); + +/* Hook for plugins to validate oauth provider options */ +typedef bool (*OAuthProviderValidateOptions_hook_type) + (char *, char *, HbaLine *, char **); + +typedef struct OAuthProviderOptions +{ + char *issuer_url; + char *scope; +} OAuthProviderOptions; + +/* Hook for plugins to get oauth params */ +typedef OAuthProviderOptions *(*OAuthProviderOptions_hook_type) (pg_oauth_flow_type); + +typedef struct OAuthProvider +{ + const char *name; + OAuthProviderCheck_hook_type oauth_provider_hook; + OAuthProviderError_hook_type oauth_error_hook; + OAuthProviderOptions_hook_type oauth_options_hook; +} OAuthProvider; + +extern void RegistorOAuthProvider + (const char *provider_name, + OAuthProviderCheck_hook_type OAuthProviderCheck_hook, + OAuthProviderError_hook_type OAuthProviderError_hook, + OAuthProviderOptions_hook_type OAuthProviderParams_hook + ); + +extern OAuthProvider *get_provider_by_name(const char *name); #define PG_MAX_AUTH_TOKEN_LENGTH 65535 #endif /* AUTH_H */ diff --git a/src/include/libpq/libpq-be.h b/src/include/libpq/libpq-be.h index 6d452ec6d9..f7bbb9dcf4 100644 --- a/src/include/libpq/libpq-be.h +++ b/src/include/libpq/libpq-be.h @@ -68,6 +68,17 @@ typedef enum CAC_state CAC_TOOMANY } CAC_state; +/* OAuth flow types */ +typedef enum pg_oauth_flow_type +{ + OAUTH_DEVICE_CODE, + OAUTH_CLIENT_CREDENTIALS, + OAUTH_AUTH, + OAUTH_AUTH_PKCE, + OAUTH_REFRESH_TOKEN, + OAUTH_NONE +} pg_oauth_flow_type; + /* * GSSAPI specific state information diff --git a/src/interfaces/libpq/fe-auth-oauth.c b/src/interfaces/libpq/fe-auth-oauth.c index 91d2c69f16..1ba2e033c4 100644 --- a/src/interfaces/libpq/fe-auth-oauth.c +++ b/src/interfaces/libpq/fe-auth-oauth.c @@ -142,6 +142,43 @@ iddawc_request_error(PGconn *conn, struct _i_session *i, int err, const char *ms appendPQExpBuffer(&conn->errorMessage, "(%s)\n", error_code); } +static pg_oauth_flow_type oauth_get_flow_type(const char *oauthflow) +{ + pg_oauth_flow_type flow_type; + + if(!oauthflow) + { + return OAUTH_NONE; + } + + /* client_secret, device_code, auth_code_pkce, refresh_token */ + if(strcmp(oauthflow, "device_code") == 0) + { + flow_type = OAUTH_DEVICE_CODE; + } + else if(strcmp(oauthflow, "client_secret") == 0) + { + flow_type = OAUTH_CLIENT_CREDENTIALS; + } + else if(strcmp(oauthflow, "auth_code_pkce") == 0) + { + flow_type = OAUTH_AUTH_PKCE; + } + else if(strcmp(oauthflow, "refresh_token") == 0) + { + flow_type = OAUTH_REFRESH_TOKEN; + } + else if(strcmp(oauthflow, "auth_code")) + { + flow_type = OAUTH_AUTH_CODE; + } + else + { + flow_type = OAUTH_NONE; + } + return flow_type; +} + static char * get_auth_token(PGconn *conn) { @@ -150,29 +187,44 @@ get_auth_token(PGconn *conn) int err; int auth_method; bool user_prompted = false; - const char *verification_uri; - const char *user_code; - const char *access_token; - const char *token_type; - char *token = NULL; - + char *verification_uri; + char *user_code; + char *access_token; + char *refresh_token; + char *token_type; + pg_oauth_flow_type flow_type; + char *token = NULL; + uint session_response_type; + PGOAuthMsgObj oauthMsgObj; + + MemSet(&oauthMsgObj, 0x00, sizeof(PGOAuthMsgObj)); + if (!conn->oauth_discovery_uri) return strdup(""); /* ask the server for one */ - i_init_session(&session); - if (!conn->oauth_client_id) { /* We can't talk to a server without a client identifier. */ appendPQExpBufferStr(&conn->errorMessage, libpq_gettext("no oauth_client_id is set for the connection")); - goto cleanup; + return NULL; } - token_buf = createPQExpBuffer(); + i_init_session(&session); + token_buf = createPQExpBuffer(); if (!token_buf) goto cleanup; + + if(conn->oauth_bearer_token) + { + appendPQExpBufferStr(token_buf, "Bearer "); + appendPQExpBufferStr(token_buf, conn->oauth_bearer_token); + if (PQExpBufferBroken(token_buf)) + goto cleanup; + token = strdup(token_buf->data); + goto cleanup; + } err = i_set_str_parameter(&session, I_OPT_OPENID_CONFIG_ENDPOINT, conn->oauth_discovery_uri); if (err) @@ -181,6 +233,8 @@ get_auth_token(PGconn *conn) goto cleanup; } + flow_type = oauth_get_flow_type(conn->oauth_flow_type); + err = i_get_openid_config(&session); if (err) { @@ -201,18 +255,64 @@ get_auth_token(PGconn *conn) libpq_gettext("issuer does not support device authorization")); goto cleanup; } + auth_method = I_TOKEN_AUTH_METHOD_NONE; + + /* for refresh token flow, do not run auth request*/ + if(flow_type == OAUTH_REFRESH_TOKEN && conn->oauth_refresh_token) + { + err = i_set_parameter_list(&session, + I_OPT_CLIENT_ID, conn->oauth_client_id, + I_OPT_REFRESH_TOKEN, conn->oauth_refresh_token, + I_OPT_RESPONSE_TYPE, I_RESPONSE_TYPE_REFRESH_TOKEN, + I_OPT_TOKEN_METHOD, auth_method, + I_OPT_CLIENT_SECRET, conn->oauth_client_secret, + I_OPT_SCOPE, conn->oauth_scope, + I_OPT_NONE + ); + + if (err) + { + iddawc_error(conn, err, "failed to set refresh token flow parameters"); + goto cleanup; + } - err = i_set_response_type(&session, I_RESPONSE_TYPE_DEVICE_CODE); + err = i_run_token_request(&session); + if (err) + { + iddawc_request_error(conn, &session, err, + "failed to obtain token authorization with refresh token flow"); + goto cleanup; + } + + access_token = i_get_str_parameter(&session, I_OPT_ACCESS_TOKEN); + token_type = i_get_str_parameter(&session, I_OPT_TOKEN_TYPE); + + if (!access_token || !token_type || strcasecmp(token_type, "Bearer")) + { + appendPQExpBufferStr(&conn->errorMessage, + libpq_gettext("issuer did not provide a bearer token")); + goto cleanup; + } + + appendPQExpBufferStr(token_buf, "Bearer "); + appendPQExpBufferStr(token_buf, access_token); + + if (PQExpBufferBroken(token_buf)) + goto cleanup; + + token = strdup(token_buf->data); + return token; + } + + //default device flow + session_response_type = I_RESPONSE_TYPE_DEVICE_CODE; + err = i_set_response_type(&session, session_response_type); if (err) { iddawc_error(conn, err, "failed to set device code response type"); goto cleanup; } - auth_method = I_TOKEN_AUTH_METHOD_NONE; - if (conn->oauth_client_secret && *conn->oauth_client_secret) - auth_method = I_TOKEN_AUTH_METHOD_SECRET_BASIC; - err = i_set_parameter_list(&session, I_OPT_CLIENT_ID, conn->oauth_client_id, I_OPT_CLIENT_SECRET, conn->oauth_client_secret, @@ -225,7 +325,7 @@ get_auth_token(PGconn *conn) iddawc_error(conn, err, "failed to set client identifier"); goto cleanup; } - + err = i_run_device_auth_request(&session); if (err) { @@ -278,14 +378,15 @@ get_auth_token(PGconn *conn) if (!user_prompted) { + oauthMsgObj.verification_uri = verification_uri; + oauthMsgObj.user_code = user_code; + conn->oauthNoticeHooks.noticeRecArg = (void*) &oauthMsgObj; + /* * Now that we know the token endpoint isn't broken, give the user * the login instructions. - */ - pqInternalNotice(&conn->noticeHooks, - "Visit %s and enter the code: %s", - verification_uri, user_code); - + */ + pqInternalOAuthNotice(&conn->oauthNoticeHooks, ""); user_prompted = true; } @@ -300,7 +401,7 @@ get_auth_token(PGconn *conn) * A slow_down error requires us to permanently increase our retry * interval by five seconds. RFC 8628, Sec. 3.5. */ - if (!strcmp(error_code, "slow_down")) + //if (!strcmp(error_code, "slow_down")) { interval += 5; i_set_int_parameter(&session, I_OPT_DEVICE_AUTH_INTERVAL, interval); @@ -323,6 +424,14 @@ get_auth_token(PGconn *conn) access_token = i_get_str_parameter(&session, I_OPT_ACCESS_TOKEN); token_type = i_get_str_parameter(&session, I_OPT_TOKEN_TYPE); + refresh_token = i_get_str_parameter(&session, I_OPT_REFRESH_TOKEN); + + if(refresh_token) + { + MemSet(&oauthMsgObj, 0x00, sizeof(PGOAuthMsgObj)); + oauthMsgObj.refresh_token = refresh_token; + pqInternalOAuthNotice(&conn->oauthNoticeHooks, ""); + } if (!access_token || !token_type || strcasecmp(token_type, "Bearer")) { @@ -358,6 +467,8 @@ client_initial_response(PGconn *conn) PQExpBuffer discovery_buf = NULL; char *token = NULL; char *response = NULL; + pg_oauth_flow_type flow_type; + char oauth_flow_str[3]; token_buf = createPQExpBuffer(); if (!token_buf) @@ -385,8 +496,26 @@ client_initial_response(PGconn *conn) token = get_auth_token(conn); if (!token) goto cleanup; - + + if(strcmp(token, "") == 0) + { + flow_type = oauth_get_flow_type(conn->oauth_flow_type); + if(flow_type == OAUTH_NONE) + { + appendPQExpBufferStr(&conn->errorMessage, + libpq_gettext("value passed in oauth_flow_type is not valid."\ + "supported flows: client_secret, device_code, auth_code_pkce, refresh_token\n")); + goto cleanup; + } + else + { + sprintf(oauth_flow_str, "%d", flow_type); + token = strdup(oauth_flow_str); + } + } appendPQExpBuffer(token_buf, resp_format, token); +// elog(INFO, "fe-flowtype: %s", token); + if (PQExpBufferBroken(token_buf)) goto cleanup; @@ -406,6 +535,9 @@ cleanup: #define ERROR_STATUS_FIELD "status" #define ERROR_SCOPE_FIELD "scope" #define ERROR_OPENID_CONFIGURATION_FIELD "openid-configuration" +#define ERROR_ISSUER_URL_FIELD "issuer" +#define ERROR_AUTH_ENDPOINT_FIELD "authorization_endpoint" +#define ERROR_TOKEN_ENDPOINT_FIELD "token_endpoint" struct json_ctx { @@ -420,6 +552,9 @@ struct json_ctx char *status; char *scope; char *discovery_uri; + char *issuer_url; + char *auth_endpoint; + char *token_endpoint; }; #define oauth_json_has_error(ctx) \ @@ -491,6 +626,21 @@ oauth_json_object_field_start(void *state, char *name, bool isnull) ctx->target_field_name = ERROR_OPENID_CONFIGURATION_FIELD; ctx->target_field = &ctx->discovery_uri; } + else if(!strcmp(name, ERROR_ISSUER_URL_FIELD)) + { + ctx->target_field_name = ERROR_ISSUER_URL_FIELD; + ctx->target_field = &ctx->issuer_url; + } + else if(!strcmp(name, ERROR_AUTH_ENDPOINT_FIELD)) + { + ctx->target_field_name = ERROR_AUTH_ENDPOINT_FIELD; + ctx->target_field = &ctx->auth_endpoint; + } + else if(!strcmp(name, ERROR_TOKEN_ENDPOINT_FIELD)) + { + ctx->target_field_name = ERROR_TOKEN_ENDPOINT_FIELD; + ctx->target_field = &ctx->token_endpoint; + } } free(name); @@ -627,6 +777,15 @@ handle_oauth_sasl_error(PGconn *conn, char *msg, int msglen) conn->oauth_scope = ctx.scope; } + + if(ctx.issuer_url) + { + if(conn->oauth_issuer) + free(conn->oauth_issuer); + + conn->oauth_issuer = ctx.issuer_url; + } + /* TODO: missing error scope should clear any existing connection scope */ if (!ctx.status) diff --git a/src/interfaces/libpq/fe-connect.c b/src/interfaces/libpq/fe-connect.c index 64f27fee18..e6e8dc48e2 100644 --- a/src/interfaces/libpq/fe-connect.c +++ b/src/interfaces/libpq/fe-connect.c @@ -358,6 +358,18 @@ static const internalPQconninfoOption PQconninfoOptions[] = { "OAuth-Scope", "", 15, offsetof(struct pg_conn, oauth_scope)}, + {"oauth_bearer_token", NULL, NULL, NULL, + "OAuth-Bearer", "", 20, + offsetof(struct pg_conn, oauth_bearer_token)}, + + {"oauth_flow_type", NULL, NULL, NULL, + "OAuth-Flow-Type", "", 20, + offsetof(struct pg_conn, oauth_flow_type)}, + + {"oauth_refresh_token", NULL, NULL, NULL, + "OAuth-Refresh-Token", "", 40, + offsetof(struct pg_conn, oauth_refresh_token)}, + /* Terminating entry --- MUST BE LAST */ {NULL, NULL, NULL, NULL, NULL, NULL, 0} @@ -427,6 +439,7 @@ static PQconninfoOption *conninfo_find(PQconninfoOption *connOptions, const char *keyword); static void defaultNoticeReceiver(void *arg, const PGresult *res); static void defaultNoticeProcessor(void *arg, const char *message); +static void OAuthMsgObjReceiver(void *arg, const PGresult *res); static int parseServiceInfo(PQconninfoOption *options, PQExpBuffer errorMessage); static int parseServiceFile(const char *serviceFile, @@ -3926,6 +3939,7 @@ makeEmptyPGconn(void) /* install default notice hooks */ conn->noticeHooks.noticeRec = defaultNoticeReceiver; conn->noticeHooks.noticeProc = defaultNoticeProcessor; + conn->oauthNoticeHooks.noticeRec = OAuthMsgObjReceiver; conn->status = CONNECTION_BAD; conn->asyncStatus = PGASYNC_IDLE; @@ -4073,6 +4087,12 @@ freePGconn(PGconn *conn) free(conn->oauth_client_secret); if (conn->oauth_scope) free(conn->oauth_scope); + if(conn->oauth_bearer_token) + free(conn->oauth_bearer_token); + if(conn->oauth_flow_type) + free(conn->oauth_flow_type); + if(conn->oauth_refresh_token) + free(conn->oauth_refresh_token); termPQExpBuffer(&conn->errorMessage); termPQExpBuffer(&conn->workBuffer); @@ -6991,6 +7011,32 @@ defaultNoticeProcessor(void *arg, const char *message) fprintf(stderr, "%s", message); } +static void +OAuthMsgObjReceiver(void *arg, const PGresult *res) +{ + PGOAuthMsgObj *oauthMsg = (PGOAuthMsgObj *) arg; + + if(oauthMsg->message) + { + fprintf(stderr, "%s\n", oauthMsg->message); + } + + if(oauthMsg->verification_uri) + { + fprintf(stderr, "Visit: %s\n", oauthMsg->verification_uri); + } + + if(oauthMsg->user_code) + { + fprintf(stderr, "Enter: %s\n", oauthMsg->user_code); + } + + if(oauthMsg->refresh_token) + { + fprintf(stderr, "Refresh Token: %s\n", oauthMsg->refresh_token); + } +} + /* * returns a pointer to the next token or NULL if the current * token doesn't match diff --git a/src/interfaces/libpq/fe-exec.c b/src/interfaces/libpq/fe-exec.c index da229d632a..4789c1a1fe 100644 --- a/src/interfaces/libpq/fe-exec.c +++ b/src/interfaces/libpq/fe-exec.c @@ -976,6 +976,58 @@ pqInternalNotice(const PGNoticeHooks *hooks, const char *fmt,...) PQclear(res); } +/* + * pqInternalOAuthNotice - it is similar to pqInternalNotice + * except that OAuthNoticeHooks are invoked. + */ +void +pqInternalOAuthNotice(const PGOAuthNoticeHooks *hooks, const char *fmt,...) +{ + char msgBuf[1024]; + va_list args; + PGresult *res; + + if (hooks->noticeRec == NULL) + return; /* nobody home to receive notice? */ + + /* Format the message */ + va_start(args, fmt); + vsnprintf(msgBuf, sizeof(msgBuf), libpq_gettext(fmt), args); + va_end(args); + msgBuf[sizeof(msgBuf) - 1] = '\0'; /* make real sure it's terminated */ + + /* Make a PGresult to pass to the notice receiver */ + res = PQmakeEmptyPGresult(NULL, PGRES_NONFATAL_ERROR); + if (!res) + return; + res->oauthNoticeHooks = *hooks; + res->oauthNoticeHooks.noticeRecArg = hooks->noticeRecArg; + + /* + * Set up fields of notice. + */ + pqSaveMessageField(res, PG_DIAG_MESSAGE_PRIMARY, msgBuf); + pqSaveMessageField(res, PG_DIAG_SEVERITY, libpq_gettext("NOTICE")); + pqSaveMessageField(res, PG_DIAG_SEVERITY_NONLOCALIZED, "NOTICE"); + /* XXX should provide a SQLSTATE too? */ + + /* + * Result text is always just the primary message + newline. If we can't + * allocate it, substitute "out of memory", as in pqSetResultError. + */ + res->errMsg = (char *) pqResultAlloc(res, strlen(msgBuf) + 2, false); + if (res->errMsg) + sprintf(res->errMsg, "%s\n", msgBuf); + else + res->errMsg = libpq_gettext("out of memory\n"); + + /* + * Pass to receiver, then free it. + */ + res->oauthNoticeHooks.noticeRec(res->oauthNoticeHooks.noticeRecArg, res); + PQclear(res); +} + /* * pqAddTuple * add a row pointer to the PGresult structure, growing it if necessary diff --git a/src/interfaces/libpq/libpq-fe.h b/src/interfaces/libpq/libpq-fe.h index b7df3224c0..ee5b2e2b59 100644 --- a/src/interfaces/libpq/libpq-fe.h +++ b/src/interfaces/libpq/libpq-fe.h @@ -197,6 +197,9 @@ typedef struct pgNotify typedef void (*PQnoticeReceiver) (void *arg, const PGresult *res); typedef void (*PQnoticeProcessor) (void *arg, const char *message); +/* Function types for notice-handling callbacks */ +typedef void (*PQOAuthNoticeReceiver) (void *arg, const PGresult *res); + /* Print options for PQprint() */ typedef char pqbool; diff --git a/src/interfaces/libpq/libpq-int.h b/src/interfaces/libpq/libpq-int.h index ae76ae0e8f..3155d81e00 100644 --- a/src/interfaces/libpq/libpq-int.h +++ b/src/interfaces/libpq/libpq-int.h @@ -157,6 +157,24 @@ typedef struct void *noticeProcArg; } PGNoticeHooks; +typedef struct +{ + char *verification_uri; /* URI the user should go to with the user_code in order to sign in */ + char *user_code; /* used to identify the session on a secondary device */ + char *refresh_token; + char *message; /* string with instructions for the user. */ + char *response_error; /*JSON error response (400 Bad Request) */ + uint expires_in; /* number of seconds before the device_code expire */ + uint interval; /* number of seconds the client should wait between polling requests */ +} PGOAuthMsgObj; + +/* Fields needed for oauth callback handling */ +typedef struct +{ + PQOAuthNoticeReceiver noticeRec; /* OAuth notice message receiver */ + void *noticeRecArg; +} PGOAuthNoticeHooks; + typedef struct PGEvent { PGEventProc proc; /* the function to call on events */ @@ -186,6 +204,7 @@ struct pg_result * on the PGresult don't have to reference the PGconn. */ PGNoticeHooks noticeHooks; + PGOAuthNoticeHooks oauthNoticeHooks; PGEvent *events; int nEvents; int client_encoding; /* encoding id */ @@ -343,6 +362,17 @@ typedef struct pg_conn_host * found in password file. */ } pg_conn_host; +typedef enum pg_oauth_flow_type +{ + OAUTH_DEVICE_CODE, + OAUTH_CLIENT_CREDENTIALS, + OAUTH_AUTH, + OAUTH_AUTH_PKCE, + OAUTH_REFRESH_TOKEN, + OAUTH_AUTH_CODE, + OAUTH_NONE +} pg_oauth_flow_type; + /* * PGconn stores all the state data associated with a single connection * to a backend. @@ -403,6 +433,9 @@ struct pg_conn char *oauth_client_id; /* client identifier */ char *oauth_client_secret; /* client secret */ char *oauth_scope; /* access token scope */ + char *oauth_bearer_token; /* oauth token */ + char *oauth_flow_type; /* oauth flow type */ + char *oauth_refresh_token; /* refresh token */ bool oauth_want_retry; /* should we retry on failure? */ /* Optional file to write trace info to */ @@ -412,6 +445,9 @@ struct pg_conn /* Callback procedures for notice message processing */ PGNoticeHooks noticeHooks; + /* Callback procedures for notifying messages during oauth flows*/ + PGOAuthNoticeHooks oauthNoticeHooks; + /* Event procs registered via PQregisterEventProc */ PGEvent *events; /* expandable array of event data */ int nEvents; /* number of active events */ @@ -677,6 +713,7 @@ extern void pqClearAsyncResult(PGconn *conn); extern void pqSaveErrorResult(PGconn *conn); extern PGresult *pqPrepareAsyncResult(PGconn *conn); extern void pqInternalNotice(const PGNoticeHooks *hooks, const char *fmt,...) pg_attribute_printf(2, 3); +extern void pqInternalOAuthNotice(const PGOAuthNoticeHooks *hooks, const char *fmt,...); extern void pqSaveMessageField(PGresult *res, char code, const char *value); extern void pqSaveParameterStatus(PGconn *conn, const char *name,