From 044d8b08e99ddcac3e9680ee50c53672009bf95b Mon Sep 17 00:00:00 2001 From: Jacob Champion Date: Tue, 4 May 2021 16:21:11 -0700 Subject: [PATCH v23 5/6] backend: add OAUTHBEARER SASL mechanism DO NOT USE THIS PROOF OF CONCEPT IN PRODUCTION. Implement OAUTHBEARER (RFC 7628) on the server side. This adds a new auth method, oauth, to pg_hba. Because OAuth implementations vary so wildly, and bearer token validation is heavily dependent on the issuing party, authn/z is done by communicating with an external validator module using callbacks. The module is responsible for: 1. Validate the bearer token. The correct way to do this depends on the issuer, but it generally involves either cryptographic operations to prove that the token was issued by a trusted party, or the presentation of the bearer token to some other party so that _it_ can perform validation. The command MUST maintain confidentiality of the bearer token, since in most cases it can be used just like a password. (There are ways to cryptographically bind tokens to client certificates, but they are way beyond the scope of this commit message.) If the token cannot be validated, the authorized member of the ValidatorModuleResult struct is used to indicate failure. Further authentication/authorization is pointless if the bearer token wasn't issued by someone you trust. 3. Authenticate the user, authorize the user, or both: a. To authenticate the user, use the bearer token to retrieve some trusted identifier string for the end user. The exact process for this is, again, issuer-dependent. The module wull return the authenticated identity in the authn_id member. b. To optionally authorize the user, in combination with the HBA option trust_validator_authz=1 (see below). The hard part is in determining whether the given token truly authorizes the client to use the given role, which must unfortunately be left as an exercise to the reader. This obviously requires some care, as a poorly implemented token validator may silently open the entire database to anyone with a bearer token. But it may be a more portable approach, since OAuth is designed as an authorization framework, not an authentication framework. For example, the user's bearer token could carry an "allow_superuser_access" claim, which would authorize pseudonymous database access as any role. It's then up to the OAuth system administrators to ensure that allow_superuser_access is doled out only to the proper users. c. It's possible that the user can be successfully authenticated but isn't authorized to connect. In this case, the validator module may return the authenticated ID and then fail with false authorized member. (This can make it easier to see what's going on in the Postgres logs.) The oauth method supports the following HBA options (but note that two of them are not optional, since we have no way of choosing sensible defaults): issuer: Required. The URL of the OAuth issuing party, which the client must contact to receive a bearer token. Some real-world examples as of time of writing: - https://accounts.google.com - https://login.microsoft.com/[tenant-id]/v2.0 scope: Required. The OAuth scope(s) required for the server to authenticate and/or authorize the user. This is heavily deployment-specific, but a simple example is "openid email". map: Optional. Specify a standard PostgreSQL user map; this works the same as with other auth methods such as peer. If a map is not specified, the user ID returned by the token validator must exactly match the role that's being requested (but see trust_validator_authz, below). trust_validator_authz: Optional. When set to 1, this allows the token validator to take full control of the authorization process. Standard user mapping is skipped: if the validator command succeeds, the client is allowed to connect under its desired role and no further checks are done. Several TODOs: - port to platforms other than "modern Linux/BSD" - implement more helpful handling of HBA misconfigurations - use logdetail during auth failures - allow passing the configured issuer to the oauth_validator_command, to deal with multi-issuer setups - ...and more. Co-authored-by: Daniel Gustafsson --- .cirrus.tasks.yml | 15 +- src/backend/libpq/Makefile | 1 + src/backend/libpq/auth-oauth.c | 666 ++++++++++++++++++ src/backend/libpq/auth-sasl.c | 10 +- src/backend/libpq/auth-scram.c | 4 +- src/backend/libpq/auth.c | 26 +- src/backend/libpq/hba.c | 31 +- src/backend/libpq/meson.build | 1 + src/backend/utils/misc/guc_tables.c | 12 + src/common/Makefile | 2 +- src/include/libpq/auth.h | 17 + src/include/libpq/hba.h | 6 +- src/include/libpq/oauth.h | 49 ++ src/include/libpq/sasl.h | 11 + src/test/modules/Makefile | 1 + src/test/modules/meson.build | 1 + src/test/modules/oauth_validator/.gitignore | 4 + src/test/modules/oauth_validator/Makefile | 22 + .../oauth_validator/expected/validator.out | 6 + src/test/modules/oauth_validator/meson.build | 37 + .../modules/oauth_validator/sql/validator.sql | 1 + .../modules/oauth_validator/t/001_server.pl | 79 +++ .../modules/oauth_validator/t/oauth_server.py | 114 +++ src/test/modules/oauth_validator/validator.c | 82 +++ src/test/perl/PostgreSQL/Test/Cluster.pm | 14 +- src/test/perl/PostgreSQL/Test/OAuthServer.pm | 65 ++ src/tools/pgindent/typedefs.list | 3 + 27 files changed, 1241 insertions(+), 39 deletions(-) create mode 100644 src/backend/libpq/auth-oauth.c create mode 100644 src/include/libpq/oauth.h create mode 100644 src/test/modules/oauth_validator/.gitignore create mode 100644 src/test/modules/oauth_validator/Makefile create mode 100644 src/test/modules/oauth_validator/expected/validator.out create mode 100644 src/test/modules/oauth_validator/meson.build create mode 100644 src/test/modules/oauth_validator/sql/validator.sql create mode 100644 src/test/modules/oauth_validator/t/001_server.pl create mode 100755 src/test/modules/oauth_validator/t/oauth_server.py create mode 100644 src/test/modules/oauth_validator/validator.c create mode 100644 src/test/perl/PostgreSQL/Test/OAuthServer.pm diff --git a/.cirrus.tasks.yml b/.cirrus.tasks.yml index 33646faead..95f131baa9 100644 --- a/.cirrus.tasks.yml +++ b/.cirrus.tasks.yml @@ -163,7 +163,7 @@ task: chown root:postgres /tmp/cores sysctl kern.corefile='/tmp/cores/%N.%P.core' setup_additional_packages_script: | - #pkg install -y ... + pkg install -y curl # NB: Intentionally build without -Dllvm. The freebsd image size is already # large enough to make VM startup slow, and even without llvm freebsd @@ -175,6 +175,7 @@ task: -Dcassert=true -Dinjection_points=true \ -Duuid=bsd -Dtcl_version=tcl86 -Ddtrace=auto \ -DPG_TEST_EXTRA="$PG_TEST_EXTRA" \ + -Doauth=curl \ -Dextra_lib_dirs=/usr/local/lib -Dextra_include_dirs=/usr/local/include/ \ build EOF @@ -223,6 +224,7 @@ LINUX_CONFIGURE_FEATURES: &LINUX_CONFIGURE_FEATURES >- --with-libxslt --with-llvm --with-lz4 + --with-oauth=curl --with-pam --with-perl --with-python @@ -235,6 +237,7 @@ LINUX_CONFIGURE_FEATURES: &LINUX_CONFIGURE_FEATURES >- LINUX_MESON_FEATURES: &LINUX_MESON_FEATURES >- -Dllvm=enabled + -Doauth=curl -Duuid=e2fs @@ -310,8 +313,10 @@ task: EOF setup_additional_packages_script: | - #apt-get update - #DEBIAN_FRONTEND=noninteractive apt-get -y install ... + apt-get update + DEBIAN_FRONTEND=noninteractive apt-get -y install \ + libcurl4-openssl-dev \ + libcurl4-openssl-dev:i386 \ matrix: - name: Linux - Debian Bullseye - Autoconf @@ -676,8 +681,8 @@ task: folder: $CCACHE_DIR setup_additional_packages_script: | - #apt-get update - #DEBIAN_FRONTEND=noninteractive apt-get -y install ... + apt-get update + DEBIAN_FRONTEND=noninteractive apt-get -y install libcurl4-openssl-dev ### # Test that code can be built with gcc/clang without warnings diff --git a/src/backend/libpq/Makefile b/src/backend/libpq/Makefile index 6d385fd6a4..98eb2a8242 100644 --- a/src/backend/libpq/Makefile +++ b/src/backend/libpq/Makefile @@ -15,6 +15,7 @@ include $(top_builddir)/src/Makefile.global # be-fsstubs is here for historical reasons, probably belongs elsewhere OBJS = \ + auth-oauth.o \ auth-sasl.o \ auth-scram.o \ auth.o \ diff --git a/src/backend/libpq/auth-oauth.c b/src/backend/libpq/auth-oauth.c new file mode 100644 index 0000000000..024f304e4d --- /dev/null +++ b/src/backend/libpq/auth-oauth.c @@ -0,0 +1,666 @@ +/*------------------------------------------------------------------------- + * + * auth-oauth.c + * Server-side implementation of the SASL OAUTHBEARER mechanism. + * + * See the following RFC for more details: + * - RFC 7628: https://tools.ietf.org/html/rfc7628 + * + * Portions Copyright (c) 1996-2024, PostgreSQL Global Development Group + * Portions Copyright (c) 1994, Regents of the University of California + * + * src/backend/libpq/auth-oauth.c + * + *------------------------------------------------------------------------- + */ +#include "postgres.h" + +#include +#include + +#include "common/oauth-common.h" +#include "fmgr.h" +#include "lib/stringinfo.h" +#include "libpq/auth.h" +#include "libpq/hba.h" +#include "libpq/oauth.h" +#include "libpq/sasl.h" +#include "storage/fd.h" +#include "storage/ipc.h" +#include "utils/json.h" + +/* GUC */ +char *OAuthValidatorLibrary = ""; + +static void oauth_get_mechanisms(Port *port, StringInfo buf); +static void *oauth_init(Port *port, const char *selected_mech, const char *shadow_pass); +static int oauth_exchange(void *opaq, const char *input, int inputlen, + char **output, int *outputlen, const char **logdetail); + +static void load_validator_library(void); +static void shutdown_validator_library(int code, Datum arg); + +static ValidatorModuleState *validator_module_state; +static const OAuthValidatorCallbacks *ValidatorCallbacks; + +/* Mechanism declaration */ +const pg_be_sasl_mech pg_be_oauth_mech = { + oauth_get_mechanisms, + oauth_init, + oauth_exchange, + + PG_MAX_AUTH_TOKEN_LENGTH, +}; + + +typedef enum +{ + OAUTH_STATE_INIT = 0, + OAUTH_STATE_ERROR, + OAUTH_STATE_FINISHED, +} oauth_state; + +struct oauth_ctx +{ + oauth_state state; + Port *port; + const char *issuer; + const char *scope; +}; + +static char *sanitize_char(char c); +static char *parse_kvpairs_for_auth(char **input); +static void generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen); +static bool validate(Port *port, const char *auth); + +#define KVSEP 0x01 +#define AUTH_KEY "auth" +#define BEARER_SCHEME "Bearer " + +static void +oauth_get_mechanisms(Port *port, StringInfo buf) +{ + /* Only OAUTHBEARER is supported. */ + appendStringInfoString(buf, OAUTHBEARER_NAME); + appendStringInfoChar(buf, '\0'); +} + +static void * +oauth_init(Port *port, const char *selected_mech, const char *shadow_pass) +{ + struct oauth_ctx *ctx; + + if (strcmp(selected_mech, OAUTHBEARER_NAME)) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("client selected an invalid SASL authentication mechanism"))); + + ctx = palloc0(sizeof(*ctx)); + + ctx->state = OAUTH_STATE_INIT; + ctx->port = port; + + Assert(port->hba); + ctx->issuer = port->hba->oauth_issuer; + ctx->scope = port->hba->oauth_scope; + + load_validator_library(); + + return ctx; +} + +static int +oauth_exchange(void *opaq, const char *input, int inputlen, + char **output, int *outputlen, const char **logdetail) +{ + char *p; + char cbind_flag; + char *auth; + + struct oauth_ctx *ctx = opaq; + + *output = NULL; + *outputlen = -1; + + /* + * If the client didn't include an "Initial Client Response" in the + * SASLInitialResponse message, send an empty challenge, to which the + * client will respond with the same data that usually comes in the + * Initial Client Response. + */ + if (input == NULL) + { + Assert(ctx->state == OAUTH_STATE_INIT); + + *output = pstrdup(""); + *outputlen = 0; + return PG_SASL_EXCHANGE_CONTINUE; + } + + /* + * Check that the input length agrees with the string length of the input. + */ + if (inputlen == 0) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed OAUTHBEARER message"), + errdetail("The message is empty."))); + if (inputlen != strlen(input)) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed OAUTHBEARER message"), + errdetail("Message length does not match input length."))); + + switch (ctx->state) + { + case OAUTH_STATE_INIT: + /* Handle this case below. */ + break; + + case OAUTH_STATE_ERROR: + + /* + * Only one response is valid for the client during authentication + * failure: a single kvsep. + */ + if (inputlen != 1 || *input != KVSEP) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed OAUTHBEARER message"), + errdetail("Client did not send a kvsep response."))); + + /* The (failed) handshake is now complete. */ + ctx->state = OAUTH_STATE_FINISHED; + return PG_SASL_EXCHANGE_FAILURE; + + default: + elog(ERROR, "invalid OAUTHBEARER exchange state"); + return PG_SASL_EXCHANGE_FAILURE; + } + + /* Handle the client's initial message. */ + p = pstrdup(input); + + /* + * OAUTHBEARER does not currently define a channel binding (so there is no + * OAUTHBEARER-PLUS, and we do not accept a 'p' specifier). We accept a + * 'y' specifier purely for the remote chance that a future specification + * could define one; then future clients can still interoperate with this + * server implementation. 'n' is the expected case. + */ + cbind_flag = *p; + switch (cbind_flag) + { + case 'p': + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed OAUTHBEARER message"), + errdetail("The server does not support channel binding for OAuth, but the client message includes channel binding data."))); + break; + + case 'y': /* fall through */ + case 'n': + p++; + if (*p != ',') + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed OAUTHBEARER message"), + errdetail("Comma expected, but found character \"%s\".", + sanitize_char(*p)))); + p++; + break; + + default: + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed OAUTHBEARER message"), + errdetail("Unexpected channel-binding flag %s.", + sanitize_char(cbind_flag)))); + } + + /* + * Forbid optional authzid (authorization identity). We don't support it. + */ + if (*p == 'a') + ereport(ERROR, + (errcode(ERRCODE_FEATURE_NOT_SUPPORTED), + errmsg("client uses authorization identity, but it is not supported"))); + if (*p != ',') + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed OAUTHBEARER message"), + errdetail("Unexpected attribute %s in client-first-message.", + sanitize_char(*p)))); + p++; + + /* All remaining fields are separated by the RFC's kvsep (\x01). */ + if (*p != KVSEP) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed OAUTHBEARER message"), + errdetail("Key-value separator expected, but found character %s.", + sanitize_char(*p)))); + p++; + + auth = parse_kvpairs_for_auth(&p); + if (!auth) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed OAUTHBEARER message"), + errdetail("Message does not contain an auth value."))); + + /* We should be at the end of our message. */ + if (*p) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed OAUTHBEARER message"), + errdetail("Message contains additional data after the final terminator."))); + + if (!validate(ctx->port, auth)) + { + generate_error_response(ctx, output, outputlen); + + ctx->state = OAUTH_STATE_ERROR; + return PG_SASL_EXCHANGE_CONTINUE; + } + + ctx->state = OAUTH_STATE_FINISHED; + return PG_SASL_EXCHANGE_SUCCESS; +} + +/* + * Convert an arbitrary byte to printable form. For error messages. + * + * If it's a printable ASCII character, print it as a single character. + * otherwise, print it in hex. + * + * The returned pointer points to a static buffer. + */ +static char * +sanitize_char(char c) +{ + static char buf[5]; + + if (c >= 0x21 && c <= 0x7E) + snprintf(buf, sizeof(buf), "'%c'", c); + else + snprintf(buf, sizeof(buf), "0x%02x", (unsigned char) c); + return buf; +} + +/* + * Performs syntactic validation of a key and value from the initial client + * response. (Semantic validation of interesting values must be performed + * later.) + */ +static void +validate_kvpair(const char *key, const char *val) +{ + /*----- + * From Sec 3.1: + * key = 1*(ALPHA) + */ + static const char *key_allowed_set = + "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + + size_t span; + + if (!key[0]) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed OAUTHBEARER message"), + errdetail("Message contains an empty key name."))); + + span = strspn(key, key_allowed_set); + if (key[span] != '\0') + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed OAUTHBEARER message"), + errdetail("Message contains an invalid key name."))); + + /*----- + * From Sec 3.1: + * value = *(VCHAR / SP / HTAB / CR / LF ) + * + * The VCHAR (visible character) class is large; a loop is more + * straightforward than strspn(). + */ + for (; *val; ++val) + { + if (0x21 <= *val && *val <= 0x7E) + continue; /* VCHAR */ + + switch (*val) + { + case ' ': + case '\t': + case '\r': + case '\n': + continue; /* SP, HTAB, CR, LF */ + + default: + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed OAUTHBEARER message"), + errdetail("Message contains an invalid value."))); + } + } +} + +/* + * Consumes all kvpairs in an OAUTHBEARER exchange message. If the "auth" key is + * found, its value is returned. + */ +static char * +parse_kvpairs_for_auth(char **input) +{ + char *pos = *input; + char *auth = NULL; + + /*---- + * The relevant ABNF, from Sec. 3.1: + * + * kvsep = %x01 + * key = 1*(ALPHA) + * value = *(VCHAR / SP / HTAB / CR / LF ) + * kvpair = key "=" value kvsep + * ;;gs2-header = See RFC 5801 + * client-resp = (gs2-header kvsep *kvpair kvsep) / kvsep + * + * By the time we reach this code, the gs2-header and initial kvsep have + * already been validated. We start at the beginning of the first kvpair. + */ + + while (*pos) + { + char *end; + char *sep; + char *key; + char *value; + + /* + * Find the end of this kvpair. Note that input is null-terminated by + * the SASL code, so the strchr() is bounded. + */ + end = strchr(pos, KVSEP); + if (!end) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed OAUTHBEARER message"), + errdetail("Message contains an unterminated key/value pair."))); + *end = '\0'; + + if (pos == end) + { + /* Empty kvpair, signifying the end of the list. */ + *input = pos + 1; + return auth; + } + + /* + * Find the end of the key name. + */ + sep = strchr(pos, '='); + if (!sep) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed OAUTHBEARER message"), + errdetail("Message contains a key without a value."))); + *sep = '\0'; + + /* Both key and value are now safely terminated. */ + key = pos; + value = sep + 1; + validate_kvpair(key, value); + + if (!strcmp(key, AUTH_KEY)) + { + if (auth) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed OAUTHBEARER message"), + errdetail("Message contains multiple auth values."))); + + auth = value; + } + else + { + /* + * The RFC also defines the host and port keys, but they are not + * required for OAUTHBEARER and we do not use them. Also, per Sec. + * 3.1, any key/value pairs we don't recognize must be ignored. + */ + } + + /* Move to the next pair. */ + pos = end + 1; + } + + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed OAUTHBEARER message"), + errdetail("Message did not contain a final terminator."))); + + return NULL; /* unreachable */ +} + +static void +generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen) +{ + StringInfoData buf; + StringInfoData issuer; + + /* + * The admin needs to set an issuer and scope for OAuth to work. There's + * not really a way to hide this from the user, either, because we can't + * choose a "default" issuer, so be honest in the failure message. + * + * TODO: see if there's a better place to fail, earlier than this. + */ + if (!ctx->issuer || !ctx->scope) + ereport(FATAL, + (errcode(ERRCODE_INTERNAL_ERROR), + errmsg("OAuth is not properly configured for this user"), + errdetail_log("The issuer and scope parameters must be set in pg_hba.conf."))); + + /*------ + * Build the .well-known URI based on our issuer. + * TODO: RFC 8414 defines a competing well-known URI, so we'll probably + * have to make this configurable too. + */ + initStringInfo(&issuer); + appendStringInfoString(&issuer, ctx->issuer); + appendStringInfoString(&issuer, "/.well-known/openid-configuration"); + + initStringInfo(&buf); + + /* + * TODO: note that escaping here should be belt-and-suspenders, since + * escapable characters aren't valid in either the issuer URI or the scope + * list, but the HBA doesn't enforce that yet. + */ + appendStringInfoString(&buf, "{ \"status\": \"invalid_token\", "); + + appendStringInfoString(&buf, "\"openid-configuration\": "); + escape_json(&buf, issuer.data); + pfree(issuer.data); + + appendStringInfoString(&buf, ", \"scope\": "); + escape_json(&buf, ctx->scope); + + appendStringInfoString(&buf, " }"); + + *output = buf.data; + *outputlen = buf.len; +} + +/*----- + * Validates the provided Authorization header and returns the token from + * within it. NULL is returned on validation failure. + * + * Only Bearer tokens are accepted. The ABNF is defined in RFC 6750, Sec. + * 2.1: + * + * b64token = 1*( ALPHA / DIGIT / + * "-" / "." / "_" / "~" / "+" / "/" ) *"=" + * credentials = "Bearer" 1*SP b64token + * + * The "credentials" construction is what we receive in our auth value. + * + * Since that spec is subordinate to HTTP (i.e. the HTTP Authorization + * header format; RFC 7235 Sec. 2), the "Bearer" scheme string must be + * compared case-insensitively. (This is not mentioned in RFC 6750, but + * it's pointed out in RFC 7628 Sec. 4.) + * + * Invalid formats are technically a protocol violation, but we shouldn't + * reflect any information about the sensitive Bearer token back to the + * client; log at COMMERROR instead. + * + * TODO: handle the Authorization spec, RFC 7235 Sec. 2.1. + */ +static const char * +validate_token_format(const char *header) +{ + size_t span; + const char *token; + static const char *const b64token_allowed_set = + "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "0123456789-._~+/"; + + /* If the token is empty or simply too short to be correct */ + if (!header || strlen(header) <= 7) + { + ereport(COMMERROR, + (errmsg("malformed OAuth bearer token 1"))); + return NULL; + } + + if (pg_strncasecmp(header, BEARER_SCHEME, strlen(BEARER_SCHEME))) + return NULL; + + /* Pull the bearer token out of the auth value. */ + token = header + strlen(BEARER_SCHEME); + + /* Swallow any additional spaces. */ + while (*token == ' ') + token++; + + /* Tokens must not be empty. */ + if (!*token) + { + ereport(COMMERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed OAuth bearer token 2"), + errdetail("Bearer token is empty."))); + return NULL; + } + + /* + * Make sure the token contains only allowed characters. Tokens may end + * with any number of '=' characters. + */ + span = strspn(token, b64token_allowed_set); + while (token[span] == '=') + span++; + + if (token[span] != '\0') + { + /* + * This error message could be more helpful by printing the + * problematic character(s), but that'd be a bit like printing a piece + * of someone's password into the logs. + */ + ereport(COMMERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed OAuth bearer token 3"), + errdetail("Bearer token is not in the correct format."))); + return NULL; + } + + return token; +} + +static bool +validate(Port *port, const char *auth) +{ + int map_status; + ValidatorModuleResult *ret; + const char *token; + + /* Ensure that we have a correct token to validate */ + if (!(token = validate_token_format(auth))) + return false; + + /* Call the validation function from the validator module */ + ret = ValidatorCallbacks->validate_cb(validator_module_state, + token, port->user_name); + + if (!ret->authorized) + return false; + + if (ret->authn_id) + set_authn_id(port, ret->authn_id); + + if (port->hba->oauth_skip_usermap) + { + /* + * If the validator is our authorization authority, we're done. + * Authentication may or may not have been performed depending on the + * validator implementation; all that matters is that the validator + * says the user can log in with the target role. + */ + return true; + } + + /* Make sure the validator authenticated the user. */ + if (ret->authn_id == NULL || ret->authn_id[0] == '\0') + { + /* TODO: use logdetail; reduce message duplication */ + ereport(LOG, + (errmsg("OAuth bearer authentication failed for user \"%s\": validator provided no identity", + port->user_name))); + return false; + } + + /* Finally, check the user map. */ + map_status = check_usermap(port->hba->usermap, port->user_name, + MyClientConnectionInfo.authn_id, false); + return (map_status == STATUS_OK); +} + +static void +load_validator_library(void) +{ + OAuthValidatorModuleInit validator_init; + + if (OAuthValidatorLibrary[0] == '\0') + ereport(ERROR, + (errcode(ERRCODE_INVALID_PARAMETER_VALUE), + errmsg("oauth_validator_library is not set"))); + + validator_init = (OAuthValidatorModuleInit) + load_external_function(OAuthValidatorLibrary, + "_PG_oauth_validator_module_init", false, NULL); + + if (validator_init == NULL) + ereport(ERROR, + (errmsg("%s module \"%s\" have to define the symbol %s", + "OAuth validator", OAuthValidatorLibrary, "_PG_oauth_validator_module_init"))); + + ValidatorCallbacks = (*validator_init) (); + + validator_module_state = (ValidatorModuleState *) palloc0(sizeof(ValidatorModuleState)); + if (ValidatorCallbacks->startup_cb != NULL) + ValidatorCallbacks->startup_cb(validator_module_state); + + before_shmem_exit(shutdown_validator_library, 0); +} + +static void +shutdown_validator_library(int code, Datum arg) +{ + if (ValidatorCallbacks->shutdown_cb != NULL) + ValidatorCallbacks->shutdown_cb(validator_module_state); +} diff --git a/src/backend/libpq/auth-sasl.c b/src/backend/libpq/auth-sasl.c index 08b24d90b4..4039e7fa3e 100644 --- a/src/backend/libpq/auth-sasl.c +++ b/src/backend/libpq/auth-sasl.c @@ -20,14 +20,6 @@ #include "libpq/pqformat.h" #include "libpq/sasl.h" -/* - * Maximum accepted size of SASL messages. - * - * The messages that the server or libpq generate are much smaller than this, - * but have some headroom. - */ -#define PG_MAX_SASL_MESSAGE_LENGTH 1024 - /* * Perform a SASL exchange with a libpq client, using a specific mechanism * implementation. @@ -103,7 +95,7 @@ CheckSASLAuth(const pg_be_sasl_mech *mech, Port *port, char *shadow_pass, /* Get the actual SASL message */ initStringInfo(&buf); - if (pq_getmessage(&buf, PG_MAX_SASL_MESSAGE_LENGTH)) + if (pq_getmessage(&buf, mech->max_message_length)) { /* EOF - pq_getmessage already logged error */ pfree(buf.data); diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c index 4161959914..486a34e719 100644 --- a/src/backend/libpq/auth-scram.c +++ b/src/backend/libpq/auth-scram.c @@ -113,7 +113,9 @@ static int scram_exchange(void *opaq, const char *input, int inputlen, const pg_be_sasl_mech pg_be_scram_mech = { scram_get_mechanisms, scram_init, - scram_exchange + scram_exchange, + + PG_MAX_SASL_MESSAGE_LENGTH }; /* diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c index 2b607c5270..0a5a8640fc 100644 --- a/src/backend/libpq/auth.c +++ b/src/backend/libpq/auth.c @@ -29,6 +29,7 @@ #include "libpq/auth.h" #include "libpq/crypt.h" #include "libpq/libpq.h" +#include "libpq/oauth.h" #include "libpq/pqformat.h" #include "libpq/sasl.h" #include "libpq/scram.h" @@ -45,7 +46,6 @@ */ static void auth_failed(Port *port, int status, const char *logdetail); static char *recv_password_packet(Port *port); -static void set_authn_id(Port *port, const char *id); /*---------------------------------------------------------------- @@ -201,22 +201,6 @@ static int CheckRADIUSAuth(Port *port); static int PerformRadiusTransaction(const char *server, const char *secret, const char *portstr, const char *identifier, const char *user_name, const char *passwd); -/* - * Maximum accepted size of GSS and SSPI authentication tokens. - * We also use this as a limit on ordinary password packet lengths. - * - * Kerberos tickets are usually quite small, but the TGTs issued by Windows - * domain controllers include an authorization field known as the Privilege - * Attribute Certificate (PAC), which contains the user's Windows permissions - * (group memberships etc.). The PAC is copied into all tickets obtained on - * the basis of this TGT (even those issued by Unix realms which the Windows - * realm trusts), and can be several kB in size. The maximum token size - * accepted by Windows systems is determined by the MaxAuthToken Windows - * registry setting. Microsoft recommends that it is not set higher than - * 65535 bytes, so that seems like a reasonable limit for us as well. - */ -#define PG_MAX_AUTH_TOKEN_LENGTH 65535 - /*---------------------------------------------------------------- * Global authentication functions *---------------------------------------------------------------- @@ -305,6 +289,9 @@ auth_failed(Port *port, int status, const char *logdetail) case uaRADIUS: errstr = gettext_noop("RADIUS authentication failed for user \"%s\""); break; + case uaOAuth: + errstr = gettext_noop("OAuth bearer authentication failed for user \"%s\""); + break; default: errstr = gettext_noop("authentication failed for user \"%s\": invalid authentication method"); break; @@ -340,7 +327,7 @@ auth_failed(Port *port, int status, const char *logdetail) * lifetime of MyClientConnectionInfo, so it is safe to pass a string that is * managed by an external library. */ -static void +void set_authn_id(Port *port, const char *id) { Assert(id); @@ -627,6 +614,9 @@ ClientAuthentication(Port *port) case uaTrust: status = STATUS_OK; break; + case uaOAuth: + status = CheckSASLAuth(&pg_be_oauth_mech, port, NULL, NULL); + break; } if ((status == STATUS_OK && port->hba->clientcert == clientCertFull) diff --git a/src/backend/libpq/hba.c b/src/backend/libpq/hba.c index 18271def2e..aabe0b0e68 100644 --- a/src/backend/libpq/hba.c +++ b/src/backend/libpq/hba.c @@ -114,7 +114,8 @@ static const char *const UserAuthName[] = "ldap", "cert", "radius", - "peer" + "peer", + "oauth", }; /* @@ -1743,6 +1744,8 @@ parse_hba_line(TokenizedAuthLine *tok_line, int elevel) #endif else if (strcmp(token->string, "radius") == 0) parsedline->auth_method = uaRADIUS; + else if (strcmp(token->string, "oauth") == 0) + parsedline->auth_method = uaOAuth; else { ereport(elevel, @@ -2062,8 +2065,9 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline, hbaline->auth_method != uaPeer && hbaline->auth_method != uaGSS && hbaline->auth_method != uaSSPI && - hbaline->auth_method != uaCert) - INVALID_AUTH_OPTION("map", gettext_noop("ident, peer, gssapi, sspi, and cert")); + hbaline->auth_method != uaCert && + hbaline->auth_method != uaOAuth) + INVALID_AUTH_OPTION("map", gettext_noop("ident, peer, gssapi, sspi, cert, and oauth")); hbaline->usermap = pstrdup(val); } else if (strcmp(name, "clientcert") == 0) @@ -2446,6 +2450,27 @@ parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline, hbaline->radiusidentifiers = parsed_identifiers; hbaline->radiusidentifiers_s = pstrdup(val); } + else if (strcmp(name, "issuer") == 0) + { + if (hbaline->auth_method != uaOAuth) + INVALID_AUTH_OPTION("issuer", gettext_noop("oauth")); + hbaline->oauth_issuer = pstrdup(val); + } + else if (strcmp(name, "scope") == 0) + { + if (hbaline->auth_method != uaOAuth) + INVALID_AUTH_OPTION("scope", gettext_noop("oauth")); + hbaline->oauth_scope = pstrdup(val); + } + else if (strcmp(name, "trust_validator_authz") == 0) + { + if (hbaline->auth_method != uaOAuth) + INVALID_AUTH_OPTION("trust_validator_authz", gettext_noop("oauth")); + if (strcmp(val, "1") == 0) + hbaline->oauth_skip_usermap = true; + else + hbaline->oauth_skip_usermap = false; + } else { ereport(elevel, diff --git a/src/backend/libpq/meson.build b/src/backend/libpq/meson.build index 7c65314512..c85527fb01 100644 --- a/src/backend/libpq/meson.build +++ b/src/backend/libpq/meson.build @@ -1,6 +1,7 @@ # Copyright (c) 2022-2024, PostgreSQL Global Development Group backend_sources += files( + 'auth-oauth.c', 'auth-sasl.c', 'auth-scram.c', 'auth.c', diff --git a/src/backend/utils/misc/guc_tables.c b/src/backend/utils/misc/guc_tables.c index d28b0bcb40..461094f288 100644 --- a/src/backend/utils/misc/guc_tables.c +++ b/src/backend/utils/misc/guc_tables.c @@ -48,6 +48,7 @@ #include "jit/jit.h" #include "libpq/auth.h" #include "libpq/libpq.h" +#include "libpq/oauth.h" #include "libpq/scram.h" #include "nodes/queryjumble.h" #include "optimizer/cost.h" @@ -4707,6 +4708,17 @@ struct config_string ConfigureNamesString[] = check_synchronized_standby_slots, assign_synchronized_standby_slots, NULL }, + { + {"oauth_validator_library", PGC_SIGHUP, CONN_AUTH_AUTH, + gettext_noop("Sets the library that will be called to validate OAuth v2 bearer tokens."), + NULL, + GUC_SUPERUSER_ONLY | GUC_NOT_IN_SAMPLE + }, + &OAuthValidatorLibrary, + "", + NULL, NULL, NULL + }, + /* End-of-list marker */ { {NULL, 0, 0, NULL, NULL}, NULL, NULL, NULL, NULL, NULL diff --git a/src/common/Makefile b/src/common/Makefile index f1da2ed13d..beb9830030 100644 --- a/src/common/Makefile +++ b/src/common/Makefile @@ -41,7 +41,7 @@ override CPPFLAGS += -DVAL_LDFLAGS_SL="\"$(LDFLAGS_SL)\"" override CPPFLAGS += -DVAL_LIBS="\"$(LIBS)\"" override CPPFLAGS := -DFRONTEND -I. -I$(top_srcdir)/src/common -I$(libpq_srcdir) $(CPPFLAGS) -LIBS += $(PTHREAD_LIBS) +LIBS += $(PTHREAD_LIBS) $(libpq_pgport) OBJS_COMMON = \ archive.o \ diff --git a/src/include/libpq/auth.h b/src/include/libpq/auth.h index 227b41daf6..22f6ab9f1d 100644 --- a/src/include/libpq/auth.h +++ b/src/include/libpq/auth.h @@ -16,6 +16,22 @@ #include "libpq/libpq-be.h" +/* + * Maximum accepted size of GSS and SSPI authentication tokens. + * We also use this as a limit on ordinary password packet lengths. + * + * Kerberos tickets are usually quite small, but the TGTs issued by Windows + * domain controllers include an authorization field known as the Privilege + * Attribute Certificate (PAC), which contains the user's Windows permissions + * (group memberships etc.). The PAC is copied into all tickets obtained on + * the basis of this TGT (even those issued by Unix realms which the Windows + * realm trusts), and can be several kB in size. The maximum token size + * accepted by Windows systems is determined by the MaxAuthToken Windows + * registry setting. Microsoft recommends that it is not set higher than + * 65535 bytes, so that seems like a reasonable limit for us as well. + */ +#define PG_MAX_AUTH_TOKEN_LENGTH 65535 + extern PGDLLIMPORT char *pg_krb_server_keyfile; extern PGDLLIMPORT bool pg_krb_caseins_users; extern PGDLLIMPORT bool pg_gss_accept_delegation; @@ -23,6 +39,7 @@ extern PGDLLIMPORT bool pg_gss_accept_delegation; extern void ClientAuthentication(Port *port); extern void sendAuthRequest(Port *port, AuthRequest areq, const char *extradata, int extralen); +extern void set_authn_id(Port *port, const char *id); /* Hook for plugins to get control in ClientAuthentication() */ typedef void (*ClientAuthentication_hook_type) (Port *, int); diff --git a/src/include/libpq/hba.h b/src/include/libpq/hba.h index 8ea837ae82..53c999a69f 100644 --- a/src/include/libpq/hba.h +++ b/src/include/libpq/hba.h @@ -39,7 +39,8 @@ typedef enum UserAuth uaCert, uaRADIUS, uaPeer, -#define USER_AUTH_LAST uaPeer /* Must be last value of this enum */ + uaOAuth, +#define USER_AUTH_LAST uaOAuth /* Must be last value of this enum */ } UserAuth; /* @@ -135,6 +136,9 @@ typedef struct HbaLine char *radiusidentifiers_s; List *radiusports; char *radiusports_s; + char *oauth_issuer; + char *oauth_scope; + bool oauth_skip_usermap; } HbaLine; typedef struct IdentLine diff --git a/src/include/libpq/oauth.h b/src/include/libpq/oauth.h new file mode 100644 index 0000000000..6f98e84cc9 --- /dev/null +++ b/src/include/libpq/oauth.h @@ -0,0 +1,49 @@ +/*------------------------------------------------------------------------- + * + * oauth.h + * Interface to libpq/auth-oauth.c + * + * Portions Copyright (c) 1996-2024, PostgreSQL Global Development Group + * Portions Copyright (c) 1994, Regents of the University of California + * + * src/include/libpq/oauth.h + * + *------------------------------------------------------------------------- + */ +#ifndef PG_OAUTH_H +#define PG_OAUTH_H + +#include "libpq/libpq-be.h" +#include "libpq/sasl.h" + +extern PGDLLIMPORT char *OAuthValidatorLibrary; + +typedef struct ValidatorModuleState +{ + void *private_data; +} ValidatorModuleState; + +typedef struct ValidatorModuleResult +{ + bool authorized; + char *authn_id; +} ValidatorModuleResult; + +typedef void (*ValidatorStartupCB) (ValidatorModuleState *state); +typedef void (*ValidatorShutdownCB) (ValidatorModuleState *state); +typedef ValidatorModuleResult *(*ValidatorValidateCB) (ValidatorModuleState *state, const char *token, const char *role); + +typedef struct OAuthValidatorCallbacks +{ + ValidatorStartupCB startup_cb; + ValidatorShutdownCB shutdown_cb; + ValidatorValidateCB validate_cb; +} OAuthValidatorCallbacks; + +typedef const OAuthValidatorCallbacks *(*OAuthValidatorModuleInit) (void); +extern PGDLLEXPORT const OAuthValidatorCallbacks *_PG_oauth_validator_module_init(void); + +/* Implementation */ +extern const pg_be_sasl_mech pg_be_oauth_mech; + +#endif /* PG_OAUTH_H */ diff --git a/src/include/libpq/sasl.h b/src/include/libpq/sasl.h index 7a1f970cca..3f2c02b8f2 100644 --- a/src/include/libpq/sasl.h +++ b/src/include/libpq/sasl.h @@ -26,6 +26,14 @@ #define PG_SASL_EXCHANGE_SUCCESS 1 #define PG_SASL_EXCHANGE_FAILURE 2 +/* + * Maximum accepted size of SASL messages. + * + * The messages that the server or libpq generate are much smaller than this, + * but have some headroom. + */ +#define PG_MAX_SASL_MESSAGE_LENGTH 1024 + /* * Backend SASL mechanism callbacks. * @@ -127,6 +135,9 @@ typedef struct pg_be_sasl_mech const char *input, int inputlen, char **output, int *outputlen, const char **logdetail); + + /* The maximum size allowed for client SASLResponses. */ + int max_message_length; } pg_be_sasl_mech; /* Common implementation for auth.c */ diff --git a/src/test/modules/Makefile b/src/test/modules/Makefile index 256799f520..150dc1d908 100644 --- a/src/test/modules/Makefile +++ b/src/test/modules/Makefile @@ -11,6 +11,7 @@ SUBDIRS = \ dummy_index_am \ dummy_seclabel \ libpq_pipeline \ + oauth_validator \ plsample \ spgist_name_ops \ test_bloomfilter \ diff --git a/src/test/modules/meson.build b/src/test/modules/meson.build index d8fe059d23..60efa07b42 100644 --- a/src/test/modules/meson.build +++ b/src/test/modules/meson.build @@ -9,6 +9,7 @@ subdir('gin') subdir('injection_points') subdir('ldap_password_func') subdir('libpq_pipeline') +subdir('oauth_validator') subdir('plsample') subdir('spgist_name_ops') subdir('ssl_passphrase_callback') diff --git a/src/test/modules/oauth_validator/.gitignore b/src/test/modules/oauth_validator/.gitignore new file mode 100644 index 0000000000..5dcb3ff972 --- /dev/null +++ b/src/test/modules/oauth_validator/.gitignore @@ -0,0 +1,4 @@ +# Generated subdirectories +/log/ +/results/ +/tmp_check/ diff --git a/src/test/modules/oauth_validator/Makefile b/src/test/modules/oauth_validator/Makefile new file mode 100644 index 0000000000..655ce75796 --- /dev/null +++ b/src/test/modules/oauth_validator/Makefile @@ -0,0 +1,22 @@ +export PYTHON +export with_oauth + +MODULES = validator +PGFILEDESC = "validator - test OAuth validator module" + +NO_INSTALLCHECK = 1 + +TAP_TESTS = 1 + +REGRESS = validator + +ifdef USE_PGXS +PG_CONFIG = pg_config +PGXS := $(shell $(PG_CONFIG) --pgxs) +include $(PGXS) +else +subdir = src/test/modules/oauth_validator +top_builddir = ../../../.. +include $(top_builddir)/src/Makefile.global +include $(top_srcdir)/contrib/contrib-global.mk +endif diff --git a/src/test/modules/oauth_validator/expected/validator.out b/src/test/modules/oauth_validator/expected/validator.out new file mode 100644 index 0000000000..360caa2cb3 --- /dev/null +++ b/src/test/modules/oauth_validator/expected/validator.out @@ -0,0 +1,6 @@ +SELECT 1; + ?column? +---------- + 1 +(1 row) + diff --git a/src/test/modules/oauth_validator/meson.build b/src/test/modules/oauth_validator/meson.build new file mode 100644 index 0000000000..3db2ddea1c --- /dev/null +++ b/src/test/modules/oauth_validator/meson.build @@ -0,0 +1,37 @@ +# Copyright (c) 2024, PostgreSQL Global Development Group + +validator_sources = files( + 'validator.c', +) + +if host_system == 'windows' + validator_sources += rc_lib_gen.process(win32ver_rc, extra_args: [ + '--NAME', 'validator', + '--FILEDESC', 'validator - test OAuth validator module',]) +endif + +validator = shared_module('validator', + validator_sources, + kwargs: pg_test_mod_args, +) +test_install_libs += validator + +tests += { + 'name': 'oauth_validator', + 'sd': meson.current_source_dir(), + 'bd': meson.current_build_dir(), + 'regress': { + 'sql': [ + 'validator', + ], + }, + 'tap': { + 'tests': [ + 't/001_server.pl', + ], + 'env': { + 'PYTHON': python.path(), + 'with_oauth': oauth_library, + }, + }, +} diff --git a/src/test/modules/oauth_validator/sql/validator.sql b/src/test/modules/oauth_validator/sql/validator.sql new file mode 100644 index 0000000000..e0ac49d1ec --- /dev/null +++ b/src/test/modules/oauth_validator/sql/validator.sql @@ -0,0 +1 @@ +SELECT 1; diff --git a/src/test/modules/oauth_validator/t/001_server.pl b/src/test/modules/oauth_validator/t/001_server.pl new file mode 100644 index 0000000000..e3cf3ac7f2 --- /dev/null +++ b/src/test/modules/oauth_validator/t/001_server.pl @@ -0,0 +1,79 @@ + +# Copyright (c) 2021-2024, PostgreSQL Global Development Group + +use strict; +use warnings FATAL => 'all'; + +use PostgreSQL::Test::Cluster; +use PostgreSQL::Test::Utils; +use PostgreSQL::Test::OAuthServer; +use Test::More; + +if ($ENV{with_oauth} ne 'curl') +{ + plan skip_all => 'client-side OAuth not supported by this build'; +} + +my $node = PostgreSQL::Test::Cluster->new('primary'); +$node->init; +$node->append_conf('postgresql.conf', "log_connections = on\n"); +$node->append_conf('postgresql.conf', "shared_preload_libraries = 'validator'\n"); +$node->append_conf('postgresql.conf', "oauth_validator_library = 'validator'\n"); +$node->start; + +$node->safe_psql('postgres', 'CREATE USER test;'); +$node->safe_psql('postgres', 'CREATE USER testalt;'); + +my $webserver = PostgreSQL::Test::OAuthServer->new(); +$webserver->run(); + +my $port = $webserver->port(); +my $issuer = "127.0.0.1:$port"; + +unlink($node->data_dir . '/pg_hba.conf'); +$node->append_conf('pg_hba.conf', qq{ +local all test oauth issuer="$issuer" scope="openid postgres" +local all testalt oauth issuer="$issuer/alternate" scope="openid postgres alt" +}); +$node->reload; + +my ($log_start, $log_end); +$log_start = $node->wait_for_log(qr/reloading configuration files/); + +my $user = "test"; +$node->connect_ok("user=$user dbname=postgres oauth_client_id=f02c6361-0635", "connect", + expected_stderr => qr@Visit https://example\.com/ and enter the code: postgresuser@); + +$log_end = $node->wait_for_log(qr/connection authorized/, $log_start); +$node->log_check("user $user: validator receives correct parameters", $log_start, + log_like => [ + qr/oauth_validator: token="9243959234", role="$user"/, + qr/oauth_validator: issuer="\Q$issuer\E", scope="openid postgres"/, + ]); +$node->log_check("user $user: validator sets authenticated identity", $log_start, + log_like => [ + qr/connection authenticated: identity="test" method=oauth/, + ]); +$log_start = $log_end; + +# The /alternate issuer uses slightly different parameters. +$user = "testalt"; +$node->connect_ok("user=$user dbname=postgres oauth_client_id=f02c6361-0636", "connect", + expected_stderr => qr@Visit https://example\.org/ and enter the code: postgresuser@); + +$log_end = $node->wait_for_log(qr/connection authorized/, $log_start); +$node->log_check("user $user: validator receives correct parameters", $log_start, + log_like => [ + qr/oauth_validator: token="9243959234-alt", role="$user"/, + qr|oauth_validator: issuer="\Q$issuer/alternate\E", scope="openid postgres alt"|, + ]); +$node->log_check("user $user: validator sets authenticated identity", $log_start, + log_like => [ + qr/connection authenticated: identity="testalt" method=oauth/, + ]); +$log_start = $log_end; + +$webserver->stop(); +$node->stop; + +done_testing(); diff --git a/src/test/modules/oauth_validator/t/oauth_server.py b/src/test/modules/oauth_validator/t/oauth_server.py new file mode 100755 index 0000000000..77e3883a81 --- /dev/null +++ b/src/test/modules/oauth_validator/t/oauth_server.py @@ -0,0 +1,114 @@ +#! /usr/bin/env python3 + +import http.server +import json +import os +import sys + + +class OAuthHandler(http.server.BaseHTTPRequestHandler): + JsonObject = dict[str, object] # TypeAlias is not available until 3.10 + + def _check_issuer(self): + """ + Switches the behavior of the provider depending on the issuer URI. + """ + self._alt_issuer = self.path.startswith("/alternate/") + if self._alt_issuer: + self.path = self.path.removeprefix("/alternate") + + def do_GET(self): + self._check_issuer() + + if self.path == "/.well-known/openid-configuration": + resp = self.config() + else: + self.send_error(404, "Not Found") + return + + self._send_json(resp) + + def do_POST(self): + self._check_issuer() + + if self.path == "/authorize": + resp = self.authorization() + elif self.path == "/token": + resp = self.token() + else: + self.send_error(404, "Not Found") + return + + self._send_json(resp) + + def _send_json(self, js: JsonObject) -> None: + """ + Sends the provided JSON dict as an application/json response. + """ + + resp = json.dumps(js).encode("ascii") + + self.send_response(200, "OK") + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(resp))) + self.end_headers() + + self.wfile.write(resp) + + def config(self) -> JsonObject: + port = self.server.socket.getsockname()[1] + issuer = f"http://localhost:{port}" + if self._alt_issuer: + issuer += "/alternate" + + return { + "issuer": issuer, + "token_endpoint": issuer + "/token", + "device_authorization_endpoint": issuer + "/authorize", + "response_types_supported": ["token"], + "subject_types_supported": ["public"], + "id_token_signing_alg_values_supported": ["RS256"], + "grant_types_supported": ["urn:ietf:params:oauth:grant-type:device_code"], + } + + def authorization(self) -> JsonObject: + uri = "https://example.com/" + if self._alt_issuer: + uri = "https://example.org/" + + return { + "device_code": "postgres", + "user_code": "postgresuser", + "interval": 0, + "verification_uri": uri, + "expires-in": 5, + } + + def token(self) -> JsonObject: + token = "9243959234" + if self._alt_issuer: + token += "-alt" + + return { + "access_token": token, + "token_type": "bearer", + } + + +def main(): + s = http.server.HTTPServer(("127.0.0.1", 0), OAuthHandler) + + # Give the parent the port number to contact (this is also the signal that + # we're ready to receive requests). + port = s.socket.getsockname()[1] + print(port) + + stdout = sys.stdout.fileno() + sys.stdout.close() + os.close(stdout) + + s.serve_forever() # we expect our parent to send a termination signal + + +if __name__ == "__main__": + main() diff --git a/src/test/modules/oauth_validator/validator.c b/src/test/modules/oauth_validator/validator.c new file mode 100644 index 0000000000..09a4bf61d2 --- /dev/null +++ b/src/test/modules/oauth_validator/validator.c @@ -0,0 +1,82 @@ +/*------------------------------------------------------------------------- + * + * validator.c + * Test module for serverside OAuth token validation callbacks + * + * Portions Copyright (c) 1996-2024, PostgreSQL Global Development Group + * Portions Copyright (c) 1994, Regents of the University of California + * + * src/test/modules/oauth_validator/validator.c + * + *------------------------------------------------------------------------- + */ + +#include "postgres.h" + +#include "fmgr.h" +#include "libpq/oauth.h" +#include "miscadmin.h" +#include "utils/memutils.h" + +PG_MODULE_MAGIC; + +static void validator_startup(ValidatorModuleState *state); +static void validator_shutdown(ValidatorModuleState *state); +static ValidatorModuleResult * validate_token(ValidatorModuleState *state, + const char *token, + const char *role); + +static const OAuthValidatorCallbacks validator_callbacks = { + .startup_cb = validator_startup, + .shutdown_cb = validator_shutdown, + .validate_cb = validate_token +}; + +void +_PG_init(void) +{ + /* no-op */ +} + +const OAuthValidatorCallbacks * +_PG_oauth_validator_module_init(void) +{ + return &validator_callbacks; +} + +#define PRIVATE_COOKIE ((void *) 13579) + +static void +validator_startup(ValidatorModuleState *state) +{ + state->private_data = PRIVATE_COOKIE; +} + +static void +validator_shutdown(ValidatorModuleState *state) +{ + /* do nothing */ +} + +static ValidatorModuleResult * +validate_token(ValidatorModuleState *state, const char *token, const char *role) +{ + ValidatorModuleResult *res; + + /* Check to make sure our private state still exists. */ + if (state->private_data != PRIVATE_COOKIE) + elog(ERROR, "oauth_validator: private state cookie changed to %p", + state->private_data); + + res = palloc(sizeof(ValidatorModuleResult)); + + elog(LOG, "oauth_validator: token=\"%s\", role=\"%s\"", token, role); + elog(LOG, "oauth_validator: issuer=\"%s\", scope=\"%s\"", + MyProcPort->hba->oauth_issuer, + MyProcPort->hba->oauth_scope); + + res->authorized = true; + res->authn_id = pstrdup(role); + + return res; +} diff --git a/src/test/perl/PostgreSQL/Test/Cluster.pm b/src/test/perl/PostgreSQL/Test/Cluster.pm index 0135c5a795..f14839f4c5 100644 --- a/src/test/perl/PostgreSQL/Test/Cluster.pm +++ b/src/test/perl/PostgreSQL/Test/Cluster.pm @@ -2388,6 +2388,11 @@ instead of the default. If this regular expression is set, matches it with the output generated. +=item expected_stderr => B + +If this regular expression is set, matches it against the standard error +stream; otherwise the stderr must be empty. + =item log_like => [ qr/required message/ ] =item log_unlike => [ qr/prohibited message/ ] @@ -2431,7 +2436,14 @@ sub connect_ok like($stdout, $params{expected_stdout}, "$test_name: stdout matches"); } - is($stderr, "", "$test_name: no stderr"); + if (defined($params{expected_stderr})) + { + like($stderr, $params{expected_stderr}, "$test_name: stderr matches"); + } + else + { + is($stderr, "", "$test_name: no stderr"); + } $self->log_check($test_name, $log_location, %params); } diff --git a/src/test/perl/PostgreSQL/Test/OAuthServer.pm b/src/test/perl/PostgreSQL/Test/OAuthServer.pm new file mode 100644 index 0000000000..d96733f531 --- /dev/null +++ b/src/test/perl/PostgreSQL/Test/OAuthServer.pm @@ -0,0 +1,65 @@ +#!/usr/bin/perl + +package PostgreSQL::Test::OAuthServer; + +use warnings; +use strict; +use threads; +use Scalar::Util; +use Socket; +use IO::Select; + +local *server_socket; + +sub new +{ + my $class = shift; + + my $self = {}; + bless($self, $class); + + return $self; +} + +sub port +{ + my $self = shift; + + return $self->{'port'}; +} + +sub run +{ + my $self = shift; + my $port; + + my $pid = open(my $read_fh, "-|", $ENV{PYTHON}, "t/oauth_server.py") + // die "failed to start OAuth server: $!"; + + read($read_fh, $port, 7) // die "failed to read port number: $!"; + chomp $port; + die "server did not advertise a valid port" + unless Scalar::Util::looks_like_number($port); + + $self->{'pid'} = $pid; + $self->{'port'} = $port; + $self->{'child'} = $read_fh; + + print("# OAuth provider (PID $pid) is listening on port $port\n"); +} + +sub stop +{ + my $self = shift; + + print("# Sending SIGTERM to OAuth provider PID: $self->{'pid'}\n"); + + kill(15, $self->{'pid'}); + $self->{'pid'} = undef; + + # Closing the popen() handle waits for the process to exit. + close($self->{'child'}); + $self->{'child'} = undef; +} + +1; diff --git a/src/tools/pgindent/typedefs.list b/src/tools/pgindent/typedefs.list index 8f323da558..b02ee48898 100644 --- a/src/tools/pgindent/typedefs.list +++ b/src/tools/pgindent/typedefs.list @@ -1715,6 +1715,7 @@ NumericSortSupport NumericSumAccum NumericVar OAuthStep +OAuthValidatorCallbacks OM_uint32 OP OSAPerGroupState @@ -3055,6 +3056,7 @@ VacuumRelation VacuumStmt ValidIOData ValidateIndexState +ValidatorModuleState ValuesScan ValuesScanState Var @@ -3647,6 +3649,7 @@ normal_rand_fctx nsphash_hash ntile_context numeric +oauth_state object_access_hook_type object_access_hook_type_str off_t -- 2.34.1