diff --git a/src/backend/libpq/auth-oauth.c b/src/backend/libpq/auth-oauth.c index c47211132c..86f820482b 100644 --- a/src/backend/libpq/auth-oauth.c +++ b/src/backend/libpq/auth-oauth.c @@ -24,7 +24,9 @@ #include "libpq/hba.h" #include "libpq/oauth.h" #include "libpq/sasl.h" +#include "miscadmin.h" #include "storage/fd.h" +#include "utils/memutils.h" /* GUC */ char *oauth_validator_command; @@ -34,6 +36,13 @@ static void *oauth_init(Port *port, const char *selected_mech, const char *shado static int oauth_exchange(void *opaq, const char *input, int inputlen, char **output, int *outputlen, char **logdetail); +/*---------------------------------------------------------------- + * OAuth Authentication + *---------------------------------------------------------------- + */ +static List *oauth_providers = NIL; +static OAuthProvider* oauth_provider = NULL; + /* Mechanism declaration */ const pg_be_sasl_mech pg_be_oauth_mech = { oauth_get_mechanisms, @@ -63,15 +72,90 @@ static char *sanitize_char(char c); static char *parse_kvpairs_for_auth(char **input); static void generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen); static bool validate(Port *port, const char *auth, char **logdetail); -static bool run_validator_command(Port *port, const char *token); +static const char* run_validator_command(Port *port, const char *token); static bool check_exit(FILE **fh, const char *command); static bool unset_cloexec(int fd); -static bool username_ok_for_shell(const char *username); #define KVSEP 0x01 #define AUTH_KEY "auth" #define BEARER_SCHEME "Bearer " +/*---------------------------------------------------------------- + * OAuth Token Validator + *---------------------------------------------------------------- + */ + +/* + * RegisterOAuthProvider registers a OAuth Token Validator to be + * used for oauth token validation. It validates the token and adds the valiator + * name and it's hooks to a list of loaded token validator. The right validator's + * hooks can then be called based on the validator name specified in + * pg_hba.conf. + * + * This function should be called in _PG_init() by any extension looking to + * add a custom authentication method. + */ +void +RegisterOAuthProvider( + const char *provider_name, + OAuthProviderCheck_hook_type OAuthProviderCheck_hook, + OAuthProviderError_hook_type OAuthProviderError_hook +) +{ + if (!process_shared_preload_libraries_in_progress) + { + ereport(ERROR, + (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), + errmsg("RegisterOAuthProvider can only be called by a shared_preload_library"))); + return; + } + + MemoryContext oldcxt; + if (oauth_provider == NULL) + { + oldcxt = MemoryContextSwitchTo(TopMemoryContext); + oauth_provider = palloc(sizeof(OAuthProvider)); + oauth_provider->name = pstrdup(provider_name); + oauth_provider->oauth_provider_hook = OAuthProviderCheck_hook; + oauth_provider->oauth_error_hook = OAuthProviderError_hook; + oauth_providers = lappend(oauth_providers, oauth_provider); + MemoryContextSwitchTo(oldcxt); + } + else + { + if (oauth_provider && oauth_provider->name) + { + ereport(ERROR, + (errmsg("OAuth provider \"%s\" is already loaded.", + oauth_provider->name))); + } + else + { + ereport(ERROR, + (errmsg("OAuth provider is already loaded."))); + } + } +} + +/* + * Returns the oauth provider (which includes it's + * callback functions) based on name specified. + */ +OAuthProvider *get_provider_by_name(const char *name) +{ + ListCell *lc; + foreach(lc, oauth_providers) + { + OAuthProvider *provider = (OAuthProvider *) lfirst(lc); + if (strcmp(provider->name, name) == 0) + { + return provider; + } + } + + return NULL; +} + static void oauth_get_mechanisms(Port *port, StringInfo buf) { @@ -494,17 +578,17 @@ validate(Port *port, const char *auth, char **logdetail) } /* Have the validator check the token. */ - if (!run_validator_command(port, token)) + if (run_validator_command(port, token) == NULL) return false; - + if (port->hba->oauth_skip_usermap) { /* - * If the validator is our authorization authority, we're done. - * Authentication may or may not have been performed depending on the - * validator implementation; all that matters is that the validator says - * the user can log in with the target role. - */ + * If the validator is our authorization authority, we're done. + * Authentication may or may not have been performed depending on the + * validator implementation; all that matters is that the validator says + * the user can log in with the target role. + */ return true; } @@ -524,193 +608,26 @@ validate(Port *port, const char *auth, char **logdetail) return (ret == STATUS_OK); } -static bool +static const char* run_validator_command(Port *port, const char *token) { - bool success = false; - int rc; - int pipefd[2]; - int rfd = -1; - int wfd = -1; - - StringInfoData command = { 0 }; - char *p; - FILE *fh = NULL; - - ssize_t written; - char *line = NULL; - size_t size = 0; - ssize_t len; - - Assert(oauth_validator_command); - - if (!oauth_validator_command[0]) - { - ereport(COMMERROR, - (errmsg("oauth_validator_command is not set"), - errhint("To allow OAuth authenticated connections, set " - "oauth_validator_command in postgresql.conf."))); - return false; - } - - /* - * Since popen() is unidirectional, open up a pipe for the other direction. - * Use CLOEXEC to ensure that our write end doesn't accidentally get copied - * into child processes, which would prevent us from closing it cleanly. - * - * XXX this is ugly. We should just read from the child process's stdout, - * but that's a lot more code. - * XXX by bypassing the popen API, we open the potential of process - * deadlock. Clearly document child process requirements (i.e. the child - * MUST read all data off of the pipe before writing anything). - * TODO: port to Windows using _pipe(). - */ - rc = pipe2(pipefd, O_CLOEXEC); - if (rc < 0) + if(oauth_provider->oauth_provider_hook == NULL) { - ereport(COMMERROR, - (errcode_for_file_access(), - errmsg("could not create child pipe: %m"))); return false; } - rfd = pipefd[0]; - wfd = pipefd[1]; - - /* Allow the read pipe be passed to the child. */ - if (!unset_cloexec(rfd)) + char *id = oauth_provider-> + oauth_provider_hook(port, token); + if(id == NULL) { - /* error message was already logged */ - goto cleanup; - } - - /* - * Construct the command, substituting any recognized %-specifiers: - * - * %f: the file descriptor of the input pipe - * %r: the role that the client wants to assume (port->user_name) - * %%: a literal '%' - */ - initStringInfo(&command); - - for (p = oauth_validator_command; *p; p++) - { - if (p[0] == '%') - { - switch (p[1]) - { - case 'f': - appendStringInfo(&command, "%d", rfd); - p++; - break; - case 'r': - /* - * TODO: decide how this string should be escaped. The role - * is controlled by the client, so if we don't escape it, - * command injections are inevitable. - * - * This is probably an indication that the role name needs - * to be communicated to the validator process in some other - * way. For this proof of concept, just be incredibly strict - * about the characters that are allowed in user names. - */ - if (!username_ok_for_shell(port->user_name)) - goto cleanup; - - appendStringInfoString(&command, port->user_name); - p++; - break; - case '%': - appendStringInfoChar(&command, '%'); - p++; - break; - default: - appendStringInfoChar(&command, p[0]); - } - } - else - appendStringInfoChar(&command, p[0]); - } - - /* Execute the command. */ - fh = OpenPipeStream(command.data, "re"); - /* TODO: handle failures */ - - /* We don't need the read end of the pipe anymore. */ - close(rfd); - rfd = -1; - - /* Give the command the token to validate. */ - written = write(wfd, token, strlen(token)); - if (written != strlen(token)) - { - /* TODO must loop for short writes, EINTR et al */ - ereport(COMMERROR, - (errcode_for_file_access(), - errmsg("could not write token to child pipe: %m"))); - goto cleanup; - } - - close(wfd); - wfd = -1; - - /* - * Read the command's response. - * - * TODO: getline() is probably too new to use, unfortunately. - * TODO: loop over all lines - */ - if ((len = getline(&line, &size, fh)) >= 0) - { - /* TODO: fail if the authn_id doesn't end with a newline */ - if (len > 0) - line[len - 1] = '\0'; - - set_authn_id(port, line); - } - else if (ferror(fh)) - { - ereport(COMMERROR, - (errcode_for_file_access(), - errmsg("could not read from command \"%s\": %m", - command.data))); - goto cleanup; - } - - /* Make sure the command exits cleanly. */ - if (!check_exit(&fh, command.data)) - { - /* error message already logged */ - goto cleanup; - } - - /* Done. */ - success = true; - -cleanup: - if (line) - free(line); - - /* - * In the successful case, the pipe fds are already closed. For the error - * case, always close out the pipe before waiting for the command, to - * prevent deadlock. - */ - if (rfd >= 0) - close(rfd); - if (wfd >= 0) - close(wfd); - - if (fh) - { - Assert(!success); - check_exit(&fh, command.data); + ereport(LOG, + (errmsg("OAuth bearer token validation failed" ))); + return NULL; } - if (command.data) - pfree(command.data); - - return success; + set_authn_id(port, id); + + return id; } static bool @@ -769,29 +686,3 @@ unset_cloexec(int fd) return true; } - -/* - * XXX This should go away eventually and be replaced with either a proper - * escape or a different strategy for communication with the validator command. - */ -static bool -username_ok_for_shell(const char *username) -{ - /* This set is borrowed from fe_utils' appendShellStringNoError(). */ - static const char * const allowed = "abcdefghijklmnopqrstuvwxyz" - "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - "0123456789-_./:"; - size_t span; - - Assert(username && username[0]); /* should have already been checked */ - - span = strspn(username, allowed); - if (username[span] != '\0') - { - ereport(COMMERROR, - (errmsg("PostgreSQL user name contains unsafe characters and cannot be passed to the OAuth validator"))); - return false; - } - - return true; -} diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c index 333051ad3c..0bbcf231d2 100644 --- a/src/backend/libpq/auth.c +++ b/src/backend/libpq/auth.c @@ -296,8 +296,14 @@ auth_failed(Port *port, int status, const char *logdetail) errstr = gettext_noop("RADIUS authentication failed for user \"%s\""); break; case uaOAuth: - errstr = gettext_noop("OAuth bearer authentication failed for user \"%s\""); - break; + { + OAuthProvider *provider = get_provider_by_name(port->hba->oauth_provider); + if(provider->oauth_error_hook) + errstr = provider->oauth_error_hook(port); + else + errstr = gettext_noop("OAuth bearer authentication failed for user \"%s\""); + break; + } default: errstr = gettext_noop("authentication failed for user \"%s\": invalid authentication method"); break; diff --git a/src/backend/libpq/hba.c b/src/backend/libpq/hba.c index 943e78ddff..94fb5d434d 100644 --- a/src/backend/libpq/hba.c +++ b/src/backend/libpq/hba.c @@ -1663,6 +1663,14 @@ parse_hba_line(TokenizedAuthLine *tok_line, int elevel) parsedline->clientcert = clientCertFull; } + /* + * Ensure that the token validation provider name is specified as provider for oauth method. + */ + if (parsedline->auth_method == uaOAuth) + { + MANDATORY_AUTH_ARG(parsedline->oauth_provider, "provider", "oauth"); + } + return parsedline; } @@ -2095,6 +2103,31 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline, else hbaline->oauth_skip_usermap = false; } + else if (strcmp(name, "provider") == 0) + { + REQUIRE_AUTH_OPTION(uaOAuth, "provider", "oauth"); + if (hbaline->auth_method != uaOAuth) + INVALID_AUTH_OPTION("provider", gettext_noop("oauth")); + /* + * Verify that the token validation mentioned is loaded via shared_preload_libraries. + */ + if (get_provider_by_name(val) == NULL) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("cannot use oauth provider %s",val), + errhint("Load provider token validation via shared_preload_libraries."), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = psprintf("cannot use oauth provider %s", val); + + return false; + } + else + { + hbaline->oauth_provider = pstrdup(val); + } + } else { ereport(elevel, diff --git a/src/include/libpq/auth.h b/src/include/libpq/auth.h index 485e48970e..938ac399dc 100644 --- a/src/include/libpq/auth.h +++ b/src/include/libpq/auth.h @@ -44,4 +44,29 @@ extern void set_authn_id(Port *port, const char *id); typedef void (*ClientAuthentication_hook_type) (Port *, int); extern PGDLLIMPORT ClientAuthentication_hook_type ClientAuthentication_hook; +/* Declarations for oAuth authentication providers */ +typedef const char* (*OAuthProviderCheck_hook_type) (Port *, const char*); + +/* Hook for plugins to report error messages in validation_failed() */ +typedef const char * (*OAuthProviderError_hook_type) (Port *); + +/* Hook for plugins to validate oauth provider options */ +typedef bool (*OAuthProviderValidateOptions_hook_type) + (char *, char *, HbaLine *, char **); + +typedef struct OAuthProvider +{ + const char *name; + OAuthProviderCheck_hook_type oauth_provider_hook; + OAuthProviderError_hook_type oauth_error_hook; +} OAuthProvider; + +extern void RegisterOAuthProvider + (const char *provider_name, + OAuthProviderCheck_hook_type OAuthProviderCheck_hook, + OAuthProviderError_hook_type OAuthProviderError_hook + ); + +extern OAuthProvider *get_provider_by_name(const char *name); + #endif /* AUTH_H */ diff --git a/src/include/libpq/hba.h b/src/include/libpq/hba.h index c1b1313989..d65395cc22 100644 --- a/src/include/libpq/hba.h +++ b/src/include/libpq/hba.h @@ -123,6 +123,7 @@ typedef struct HbaLine char *radiusports_s; char *oauth_issuer; char *oauth_scope; + char *oauth_provider; bool oauth_skip_usermap; } HbaLine; diff --git a/src/interfaces/libpq/fe-auth-oauth.c b/src/interfaces/libpq/fe-auth-oauth.c index 91d2c69f16..61a0b80b7e 100644 --- a/src/interfaces/libpq/fe-auth-oauth.c +++ b/src/interfaces/libpq/fe-auth-oauth.c @@ -174,6 +174,16 @@ get_auth_token(PGconn *conn) if (!token_buf) goto cleanup; + if(conn->oauth_bearer_token) + { + appendPQExpBufferStr(token_buf, "Bearer "); + appendPQExpBufferStr(token_buf, conn->oauth_bearer_token); + if (PQExpBufferBroken(token_buf)) + goto cleanup; + token = strdup(token_buf->data); + goto cleanup; + } + err = i_set_str_parameter(&session, I_OPT_OPENID_CONFIG_ENDPOINT, conn->oauth_discovery_uri); if (err) { @@ -201,18 +211,22 @@ get_auth_token(PGconn *conn) libpq_gettext("issuer does not support device authorization")); goto cleanup; } + + //default device flow + int session_response_type = I_RESPONSE_TYPE_DEVICE_CODE; + auth_method = I_TOKEN_AUTH_METHOD_NONE; + if (conn->oauth_client_secret && *conn->oauth_client_secret) + { + auth_method = I_TOKEN_AUTH_METHOD_SECRET_BASIC; + } - err = i_set_response_type(&session, I_RESPONSE_TYPE_DEVICE_CODE); + err = i_set_response_type(&session, session_response_type); if (err) { iddawc_error(conn, err, "failed to set device code response type"); goto cleanup; } - auth_method = I_TOKEN_AUTH_METHOD_NONE; - if (conn->oauth_client_secret && *conn->oauth_client_secret) - auth_method = I_TOKEN_AUTH_METHOD_SECRET_BASIC; - err = i_set_parameter_list(&session, I_OPT_CLIENT_ID, conn->oauth_client_id, I_OPT_CLIENT_SECRET, conn->oauth_client_secret, @@ -250,6 +264,18 @@ get_auth_token(PGconn *conn) goto cleanup; } + if (conn->oauth_client_secret && *conn->oauth_client_secret) + { + session_response_type = I_RESPONSE_TYPE_CLIENT_CREDENTIALS; + } + + err = i_set_response_type(&session, session_response_type); + if (err) + { + iddawc_error(conn, err, "failed to set session response type"); + goto cleanup; + } + /* * Poll the token endpoint until either the user logs in and authorizes the * use of a token, or a hard failure occurs. We perform one ping _before_ diff --git a/src/interfaces/libpq/fe-connect.c b/src/interfaces/libpq/fe-connect.c index 2ff450ce05..5d804c8c0d 100644 --- a/src/interfaces/libpq/fe-connect.c +++ b/src/interfaces/libpq/fe-connect.c @@ -361,6 +361,10 @@ static const internalPQconninfoOption PQconninfoOptions[] = { "OAuth-Scope", "", 15, offsetof(struct pg_conn, oauth_scope)}, + {"oauth_bearer_token", NULL, NULL, NULL, + "OAuth-Bearer", "", 20, + offsetof(struct pg_conn, oauth_bearer_token)}, + /* Terminating entry --- MUST BE LAST */ {NULL, NULL, NULL, NULL, NULL, NULL, 0} @@ -4200,6 +4204,8 @@ freePGconn(PGconn *conn) free(conn->oauth_discovery_uri); if (conn->oauth_client_id) free(conn->oauth_client_id); + if(conn->oauth_bearer_token) + free(conn->oauth_bearer_token); if (conn->oauth_client_secret) free(conn->oauth_client_secret); if (conn->oauth_scope) diff --git a/src/interfaces/libpq/libpq-int.h b/src/interfaces/libpq/libpq-int.h index 1b4de3dff0..91e71afe14 100644 --- a/src/interfaces/libpq/libpq-int.h +++ b/src/interfaces/libpq/libpq-int.h @@ -402,6 +402,7 @@ struct pg_conn char *oauth_client_id; /* client identifier */ char *oauth_client_secret; /* client secret */ char *oauth_scope; /* access token scope */ + char *oauth_bearer_token; /* oauth token */ bool oauth_want_retry; /* should we retry on failure? */ /* Optional file to write trace info to */