From 80e19af369d730722e531f9ea3d6e32ec71558e5 Mon Sep 17 00:00:00 2001 From: Ashutosh Sharma Date: Wed, 24 Jul 2024 11:48:15 +0000 Subject: [PATCH] Introduce new control file parameter 'protected' to define implicit search_path for extension functions. When enabled, this parameter defines the implicit search_path for functions and procedures created by extensions if no explicit search_path is specified. It includes $extension_schema, pg_temp, and function_schema (if different from the extension's schema). Here $extension_schema is a special name that dynamically resolves to all schemas on which the extension depends. This resolution occurs at the time function or procedure execution. --- doc/src/sgml/extend.sgml | 23 +++++++-- src/backend/catalog/namespace.c | 32 +++++++++++++ src/backend/catalog/pg_depend.c | 43 +++++++++++++++++ src/backend/commands/extension.c | 65 ++++++++++++++++++++----- src/backend/commands/functioncmds.c | 73 ++++++++++++++++++++++++++++- src/backend/utils/fmgr/fmgr.c | 16 +++++++ src/include/catalog/dependency.h | 1 + src/include/commands/extension.h | 5 ++ 8 files changed, 241 insertions(+), 17 deletions(-) diff --git a/doc/src/sgml/extend.sgml b/doc/src/sgml/extend.sgml index 218940ee5c..a319119a02 100644 --- a/doc/src/sgml/extend.sgml +++ b/doc/src/sgml/extend.sgml @@ -822,6 +822,21 @@ RETURNS anycompatible AS ... + + + protected (boolean) + + + This parameter, if set to true (which is not the default), defines the + implicit search_path for functions and procedures created by the + extension. It sets the search_path to + $extension_schema, pg_temp, where + $extension_schema is a special name that dynamically + resolves to all schemas on which the extension depends. This resolution + occurs at the time of function or procedure execution. + + + @@ -1288,10 +1303,10 @@ SELECT * FROM pg_extension_update_paths('extension_name If you cannot set the search_path to contain only - secure schemas, assume that each unqualified name could resolve to an - object that a malicious user has defined. Beware of constructs that - depend on search_path implicitly; for - example, IN + secure schemas, or mark the extension as protected, then assume that each + unqualified name could resolve to an object that a malicious user has + defined. Beware of constructs that depend on + search_path implicitly; for example, IN and CASE expression WHEN always select an operator using the search path. In their place, use OPERATOR(schema.=) ANY diff --git a/src/backend/catalog/namespace.c b/src/backend/catalog/namespace.c index 43b707699d..05fca3354c 100644 --- a/src/backend/catalog/namespace.c +++ b/src/backend/catalog/namespace.c @@ -42,6 +42,7 @@ #include "catalog/pg_ts_template.h" #include "catalog/pg_type.h" #include "commands/dbcommands.h" +#include "commands/extension.h" #include "common/hashfn_unstable.h" #include "funcapi.h" #include "mb/pg_wchar.h" @@ -4152,6 +4153,37 @@ preprocessNamespacePath(const char *searchPath, Oid roleid, *temp_missing = true; } } + else if (strcmp(curname, "$extension_schema") == 0) + { + /* + * $extension_schema --- substitute namespace on which the extension + * depends, if executing functions or procedures related to an + * extension that has search_path set in its proconfig to + * $extension_schema; otherwise, skip. + */ + Oid extOid = GetCurrentExtensionId(); + List *extList; + ListCell *lc; + + if (!OidIsValid(extOid)) + continue; + + extList = getExtensionsOfExtension(extOid); + extList = lappend_oid(extList, extOid); + + foreach(lc, extList) + { + extOid = lfirst_oid(lc); + + namespaceId = get_extension_schema(extOid); + if (OidIsValid(namespaceId) && + object_aclcheck(NamespaceRelationId, namespaceId, roleid, + ACL_USAGE) == ACLCHECK_OK) + oidlist = lappend_oid(oidlist, namespaceId); + } + + list_free(extList); + } else { /* normal namespace reference */ diff --git a/src/backend/catalog/pg_depend.c b/src/backend/catalog/pg_depend.c index cfd7ef51df..8a7f071c00 100644 --- a/src/backend/catalog/pg_depend.c +++ b/src/backend/catalog/pg_depend.c @@ -814,6 +814,49 @@ getAutoExtensionsOfObject(Oid classId, Oid objectId) return result; } +/* + * Return (possibly NIL) list of extensions that the given extension depends on + * in DEPENDENCY_NORMAL mode. + */ +List * +getExtensionsOfExtension(Oid objectId) +{ + List *result = NIL; + Relation depRel; + ScanKeyData key[2]; + SysScanDesc scan; + HeapTuple tup; + + depRel = table_open(DependRelationId, AccessShareLock); + + ScanKeyInit(&key[0], + Anum_pg_depend_classid, + BTEqualStrategyNumber, F_OIDEQ, + ObjectIdGetDatum(ExtensionRelationId)); + ScanKeyInit(&key[1], + Anum_pg_depend_objid, + BTEqualStrategyNumber, F_OIDEQ, + ObjectIdGetDatum(objectId)); + + scan = systable_beginscan(depRel, DependDependerIndexId, true, + NULL, 2, key); + + while (HeapTupleIsValid((tup = systable_getnext(scan)))) + { + Form_pg_depend depform = (Form_pg_depend) GETSTRUCT(tup); + + if (depform->refclassid == ExtensionRelationId && + depform->deptype == DEPENDENCY_NORMAL) + result = lappend_oid(result, depform->refobjid); + } + + systable_endscan(scan); + + table_close(depRel, AccessShareLock); + + return result; +} + /* * Detect whether a sequence is marked as "owned" by a column * diff --git a/src/backend/commands/extension.c b/src/backend/commands/extension.c index 1643c8c69a..2b4f52d8be 100644 --- a/src/backend/commands/extension.c +++ b/src/backend/commands/extension.c @@ -70,6 +70,8 @@ /* Globally visible state variables */ bool creating_extension = false; Oid CurrentExtensionObject = InvalidOid; +bool create_extension_set_search_path = false; +Oid CurrentExtensionId = InvalidOid; /* * Internal data structure to hold the results of parsing a control file @@ -86,6 +88,8 @@ typedef struct ExtensionControlFile bool relocatable; /* is ALTER EXTENSION SET SCHEMA supported? */ bool superuser; /* must be superuser to install? */ bool trusted; /* allow becoming superuser on the fly? */ + bool protected; /* should we protect extension by setting implicit + * search_path for functions and procedures? */ int encoding; /* encoding of the script file, or -1 */ List *requires; /* names of prerequisite extensions */ List *no_relocate; /* names of prerequisite extensions that @@ -117,7 +121,8 @@ static Oid get_required_extension(char *reqExtensionName, char *origSchemaName, bool cascade, List *parents, - bool is_create); + bool is_create, + bool set_search_path); static void get_available_versions_for_extension(ExtensionControlFile *pcontrol, Tuplestorestate *tupstore, TupleDesc tupdesc); @@ -128,12 +133,30 @@ static void ApplyExtensionUpdates(Oid extensionOid, List *updateVersions, char *origSchemaName, bool cascade, - bool is_create); + bool is_create, + bool set_search_path); static void ExecAlterExtensionContentsRecurse(AlterExtensionContentsStmt *stmt, ObjectAddress extension, ObjectAddress object); static char *read_whole_file(const char *filename, int *length); +/* + * SetCurrentExtensionId - Set the current extension Oid. + */ +void +SetCurrentExtensionId(Oid extensionOid) +{ + CurrentExtensionId = extensionOid; +} + +/* + * GetCurrentExtensionId - Get the current extension Oid. + */ +Oid +GetCurrentExtensionId() +{ + return CurrentExtensionId; +} /* * get_extension_oid - given an extension name, look up the OID @@ -585,6 +608,14 @@ parse_extension_control_file(ExtensionControlFile *control, errmsg("parameter \"%s\" requires a Boolean value", item->name))); } + else if (strcmp(item->name, "protected") == 0) + { + if (!parse_bool(item->value, &control->protected)) + ereport(ERROR, + (errcode(ERRCODE_INVALID_PARAMETER_VALUE), + errmsg("parameter \"%s\" requires a Boolean value", + item->name))); + } else if (strcmp(item->name, "encoding") == 0) { control->encoding = pg_valid_server_encoding(item->value); @@ -871,7 +902,8 @@ execute_extension_script(Oid extensionOid, ExtensionControlFile *control, const char *from_version, const char *version, List *requiredSchemas, - const char *schemaName, Oid schemaOid) + const char *schemaName, Oid schemaOid, + bool set_search_path) { bool switch_to_superuser = false; char *filename; @@ -992,6 +1024,7 @@ execute_extension_script(Oid extensionOid, ExtensionControlFile *control, */ creating_extension = true; CurrentExtensionObject = extensionOid; + create_extension_set_search_path = set_search_path; PG_TRY(); { char *c_sql = read_extension_script_file(control, filename); @@ -1116,6 +1149,7 @@ execute_extension_script(Oid extensionOid, ExtensionControlFile *control, { creating_extension = false; CurrentExtensionObject = InvalidOid; + create_extension_set_search_path = false; } PG_END_TRY(); @@ -1475,6 +1509,7 @@ CreateExtensionInternal(char *extensionName, Oid extensionOid; ObjectAddress address; ListCell *lc; + bool set_search_path = false; /* * Read the primary control file. Note we assume that it does not contain @@ -1542,6 +1577,10 @@ CreateExtensionInternal(char *extensionName, */ control = read_extension_aux_control_file(pcontrol, versionName); + /* Check if this extension requires protection */ + if (control->protected) + set_search_path = true; + /* * Determine the target schema to install the extension into */ @@ -1648,7 +1687,8 @@ CreateExtensionInternal(char *extensionName, origSchemaName, cascade, parents, - is_create); + is_create, + set_search_path); reqschema = get_extension_schema(reqext); requiredExtensions = lappend_oid(requiredExtensions, reqext); requiredSchemas = lappend_oid(requiredSchemas, reqschema); @@ -1677,7 +1717,7 @@ CreateExtensionInternal(char *extensionName, execute_extension_script(extensionOid, control, NULL, versionName, requiredSchemas, - schemaName, schemaOid); + schemaName, schemaOid, set_search_path); /* * If additional update scripts have to be executed, apply the updates as @@ -1685,7 +1725,7 @@ CreateExtensionInternal(char *extensionName, */ ApplyExtensionUpdates(extensionOid, pcontrol, versionName, updateVersions, - origSchemaName, cascade, is_create); + origSchemaName, cascade, is_create, set_search_path); return address; } @@ -1699,7 +1739,8 @@ get_required_extension(char *reqExtensionName, char *origSchemaName, bool cascade, List *parents, - bool is_create) + bool is_create, + bool set_search_path) { Oid reqExtensionOid; @@ -3115,7 +3156,7 @@ ExecAlterExtensionStmt(ParseState *pstate, AlterExtensionStmt *stmt) */ ApplyExtensionUpdates(extensionOid, control, oldVersionName, updateVersions, - NULL, false, false); + NULL, false, false, false); ObjectAddressSet(address, ExtensionRelationId, extensionOid); @@ -3137,7 +3178,8 @@ ApplyExtensionUpdates(Oid extensionOid, List *updateVersions, char *origSchemaName, bool cascade, - bool is_create) + bool is_create, + bool set_search_path) { const char *oldVersionName = initialVersion; ListCell *lcv; @@ -3232,7 +3274,8 @@ ApplyExtensionUpdates(Oid extensionOid, origSchemaName, cascade, NIL, - is_create); + is_create, + set_search_path); reqschema = get_extension_schema(reqext); requiredExtensions = lappend_oid(requiredExtensions, reqext); requiredSchemas = lappend_oid(requiredSchemas, reqschema); @@ -3269,7 +3312,7 @@ ApplyExtensionUpdates(Oid extensionOid, execute_extension_script(extensionOid, control, oldVersionName, versionName, requiredSchemas, - schemaName, schemaOid); + schemaName, schemaOid, set_search_path); /* * Update prior-version name and loop around. Since diff --git a/src/backend/commands/functioncmds.c b/src/backend/commands/functioncmds.c index 6593fd7d81..79764f2996 100644 --- a/src/backend/commands/functioncmds.c +++ b/src/backend/commands/functioncmds.c @@ -52,6 +52,7 @@ #include "executor/functions.h" #include "funcapi.h" #include "miscadmin.h" +#include "nodes/makefuncs.h" #include "nodes/nodeFuncs.h" #include "optimizer/optimizer.h" #include "parser/analyze.h" @@ -71,6 +72,7 @@ #include "utils/snapmgr.h" #include "utils/syscache.h" #include "utils/typcache.h" +#include "utils/varlena.h" /* * Examine the RETURNS clause of the CREATE FUNCTION statement @@ -705,6 +707,25 @@ interpret_func_support(DefElem *defel) return procOid; } +/* + * Returns true if search_path is set in set_items list. + */ +static bool +IsSearchPathSet(List *set_items) +{ + ListCell *l; + + foreach(l, set_items) + { + VariableSetStmt *sstmt = lfirst_node(VariableSetStmt, l); + + if (pg_strcasecmp(sstmt->name, "search_path") == 0 && + sstmt->kind == VAR_SET_VALUE) + return true; + } + + return false; +} /* * Dissect the list of options assembled in gram.y into function @@ -726,7 +747,8 @@ compute_function_attributes(ParseState *pstate, float4 *procost, float4 *prorows, Oid *prosupport, - char *parallel_p) + char *parallel_p, + Oid namespaceId) { ListCell *option; DefElem *as_item = NULL; @@ -813,6 +835,53 @@ compute_function_attributes(ParseState *pstate, *security_definer = boolVal(security_item->arg); if (leakproof_item) *leakproof_p = boolVal(leakproof_item->arg); + + /* + * If "create_extension_set_search_path" is enabled, it indicates that the + * user has set "protected" flag inside the extension control file. + * Therefore, we must ensure that the function(s) created by an extension + * have their search_path set to trusted schema(s), which includes the + * schema where the function is being created and the search_path set by the + * extension. See execute_extension_script() for details on search_path set + * by the extension. + */ + if (creating_extension && create_extension_set_search_path) + { + /* If the search_path is already set, there is nothing to do. */ + if (!set_items || !IsSearchPathSet(set_items)) + { + StringInfoData sp_string; + VariableSetStmt *sp_node = makeNode(VariableSetStmt); + List *schemaList; + ListCell *lc; + + sp_node->kind = VAR_SET_VALUE; + sp_node->name = "search_path"; + + initStringInfo(&sp_string); + + if (namespaceId != get_extension_schema(CurrentExtensionObject)) + { + appendStringInfoString(&sp_string, get_namespace_name(namespaceId)); + appendStringInfoString(&sp_string, ", "); + } + appendStringInfoString(&sp_string, "$extension_schema, pg_temp"); + + (void) SplitIdentifierString(sp_string.data, ',', &schemaList); + + foreach(lc, schemaList) + { + char *schema_name = lfirst(lc); + + sp_node->args = lappend(sp_node->args, + makeStringConst(pstrdup(schema_name), -1)); + } + + set_items = lappend(set_items, sp_node); + pfree(sp_string.data); + } + } + if (set_items) *proconfig = update_proconfig_value(NULL, set_items); if (cost_item) @@ -1079,7 +1148,7 @@ CreateFunction(ParseState *pstate, CreateFunctionStmt *stmt) &isWindowFunc, &volatility, &isStrict, &security, &isLeakProof, &proconfig, &procost, &prorows, - &prosupport, ¶llel); + &prosupport, ¶llel, namespaceId); if (!language) { diff --git a/src/backend/utils/fmgr/fmgr.c b/src/backend/utils/fmgr/fmgr.c index e48a86be54..e2211c82f3 100644 --- a/src/backend/utils/fmgr/fmgr.c +++ b/src/backend/utils/fmgr/fmgr.c @@ -16,6 +16,8 @@ #include "postgres.h" #include "access/detoast.h" +#include "commands/extension.h" +#include "catalog/dependency.h" #include "catalog/pg_language.h" #include "catalog/pg_proc.h" #include "catalog/pg_type.h" @@ -641,6 +643,15 @@ fmgr_security_definer(PG_FUNCTION_ARGS) *lc3; volatile int save_nestlevel; PgStat_FunctionCallUsage fcusage; + Oid extensionOid = InvalidOid; + + /* + * Let's check if this is an extension created function. If it is, we'll set + * the CurrentExtensionId before calling it, so that preprocessNamespacePath + * can handle $extension_schema correctly. + */ + extensionOid = getExtensionOfObject(ProcedureRelationId, + fcinfo->flinfo->fn_oid); if (!fcinfo->flinfo->fn_extra) { @@ -737,6 +748,9 @@ fmgr_security_definer(PG_FUNCTION_ARGS) */ save_flinfo = fcinfo->flinfo; + if (OidIsValid(extensionOid)) + SetCurrentExtensionId(extensionOid); + PG_TRY(); { fcinfo->flinfo = &fcache->flinfo; @@ -758,6 +772,7 @@ fmgr_security_definer(PG_FUNCTION_ARGS) PG_CATCH(); { fcinfo->flinfo = save_flinfo; + SetCurrentExtensionId(InvalidOid); if (fmgr_hook) (*fmgr_hook) (FHET_ABORT, &fcache->flinfo, &fcache->arg); PG_RE_THROW(); @@ -765,6 +780,7 @@ fmgr_security_definer(PG_FUNCTION_ARGS) PG_END_TRY(); fcinfo->flinfo = save_flinfo; + SetCurrentExtensionId(InvalidOid); if (fcache->configNames != NIL) AtEOXact_GUC(true, save_nestlevel); diff --git a/src/include/catalog/dependency.h b/src/include/catalog/dependency.h index 6908ca7180..1055c2f784 100644 --- a/src/include/catalog/dependency.h +++ b/src/include/catalog/dependency.h @@ -174,6 +174,7 @@ extern long changeDependenciesOn(Oid refClassId, Oid oldRefObjectId, extern Oid getExtensionOfObject(Oid classId, Oid objectId); extern List *getAutoExtensionsOfObject(Oid classId, Oid objectId); +extern List *getExtensionsOfExtension(Oid objectId); extern bool sequenceIsOwned(Oid seqId, char deptype, Oid *tableId, int32 *colId); extern List *getOwnedSequences(Oid relid); diff --git a/src/include/commands/extension.h b/src/include/commands/extension.h index c6f3f867eb..9512e8109c 100644 --- a/src/include/commands/extension.h +++ b/src/include/commands/extension.h @@ -29,6 +29,8 @@ */ extern PGDLLIMPORT bool creating_extension; extern PGDLLIMPORT Oid CurrentExtensionObject; +extern PGDLLIMPORT bool create_extension_set_search_path; +extern PGDLLIMPORT Oid CurrentExtensionId; extern ObjectAddress CreateExtension(ParseState *pstate, CreateExtensionStmt *stmt); @@ -53,4 +55,7 @@ extern bool extension_file_exists(const char *extensionName); extern ObjectAddress AlterExtensionNamespace(const char *extensionName, const char *newschema, Oid *oldschema); +extern void SetCurrentExtensionId(Oid extensionOid); +extern Oid GetCurrentExtensionId(void); + #endif /* EXTENSION_H */ -- 2.17.1