From 877e1f22943e4ecbe97ae71a359c142c3163ac86 Mon Sep 17 00:00:00 2001 From: houzj Date: Tue, 11 May 2021 08:54:00 +0800 Subject: [PATCH 1/2] check UDF parallel safety in fmgr_info --- src/backend/utils/fmgr/fmgr.c | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/backend/utils/fmgr/fmgr.c b/src/backend/utils/fmgr/fmgr.c index 3dfe6e5..a20faf3 100644 --- a/src/backend/utils/fmgr/fmgr.c +++ b/src/backend/utils/fmgr/fmgr.c @@ -16,6 +16,8 @@ #include "postgres.h" #include "access/detoast.h" +#include "access/parallel.h" +#include "access/xact.h" #include "catalog/pg_language.h" #include "catalog/pg_proc.h" #include "catalog/pg_type.h" @@ -56,6 +58,7 @@ static HTAB *CFuncHash = NULL; static void fmgr_info_cxt_security(Oid functionId, FmgrInfo *finfo, MemoryContext mcxt, bool ignore_security); +static void fmgr_check_parallel_safety(char parallel_safety, Oid functionId); static void fmgr_info_C_lang(Oid functionId, FmgrInfo *finfo, HeapTuple procedureTuple); static void fmgr_info_other_lang(Oid functionId, FmgrInfo *finfo, HeapTuple procedureTuple); static CFuncHashTabEntry *lookup_C_func(HeapTuple procedureTuple); @@ -183,6 +186,9 @@ fmgr_info_cxt_security(Oid functionId, FmgrInfo *finfo, MemoryContext mcxt, elog(ERROR, "cache lookup failed for function %u", functionId); procedureStruct = (Form_pg_proc) GETSTRUCT(procedureTuple); + /* Check parallel safety for other functions */ + fmgr_check_parallel_safety(procedureStruct->proparallel, functionId); + finfo->fn_nargs = procedureStruct->pronargs; finfo->fn_strict = procedureStruct->proisstrict; finfo->fn_retset = procedureStruct->proretset; @@ -264,6 +270,20 @@ fmgr_info_cxt_security(Oid functionId, FmgrInfo *finfo, MemoryContext mcxt, ReleaseSysCache(procedureTuple); } +static void +fmgr_check_parallel_safety(char parallel_safety, Oid functionId) +{ + if (IsInParallelMode() && + ((IsParallelWorker() && + parallel_safety == PROPARALLEL_RESTRICTED) || + parallel_safety == PROPARALLEL_UNSAFE)) + ereport(ERROR, + (errcode(ERRCODE_INVALID_TRANSACTION_STATE), + errmsg("parallel-safety execution violation of function \"%s\" (%c)", + get_func_name(functionId), parallel_safety))); +} + + /* * Return module and C function name providing implementation of functionId. * -- 2.7.2.windows.1