From 22a7ce068dd53c54e037c07488a14716ef979b34 Mon Sep 17 00:00:00 2001 From: "okbob@github.com" Date: Wed, 12 Jun 2024 21:34:05 +0200 Subject: [PATCH 1/3] use strict rules for parsing PL/pgSQL expressions Originally the rule PLpgSQL_Expr allowed almost all SQL clauses. It was designed to allow the old undocumented syntax var := col FROM tab; The reason for supporting this "strange" syntax was technical: the PL/pgSQL parser could not use the SQL parser accurately (it was really primitive), and people discovered this undocumented syntax. Later, when exact parsing became possible, the parsing of PL/pgSQL expressions continued to allow the described syntax for compatibility reasons. Unfortunately, because almost all SQL clauses are supported, PL/pgSQL can accept really broken code like DO $$ DECLARE l_cnt int; BEGIN l_cnt := 1 DELETE FROM foo3 WHERE id=1; END; $$; The proposed patch introduces a new extra check, strict_expr_check, that solves this issue. --- doc/src/sgml/plpgsql.sgml | 19 ++++ src/pl/plpgsql/src/pl_comp.c | 7 ++ src/pl/plpgsql/src/pl_gram.y | 138 ++++++++++++++++++++++---- src/pl/plpgsql/src/pl_handler.c | 2 + src/pl/plpgsql/src/plpgsql.h | 1 + src/test/regress/expected/plpgsql.out | 14 +++ src/test/regress/sql/plpgsql.sql | 14 +++ 7 files changed, 177 insertions(+), 18 deletions(-) diff --git a/doc/src/sgml/plpgsql.sgml b/doc/src/sgml/plpgsql.sgml index e937491e6b8..cbe5c2be28b 100644 --- a/doc/src/sgml/plpgsql.sgml +++ b/doc/src/sgml/plpgsql.sgml @@ -5388,6 +5388,25 @@ a_output := a_output || $$ if v_$$ || referrer_keys.kind || $$ like '$$ + + + strict_expr_check + + + Enabling this check will cause PL/pgSQL to + check if a PL/pgSQL expression is just an + expression without any SQL clauses like FROM, + ORDER BY. This undocumented form of expressions + is allowed for compatibility reasons, but in some special cases + it prevents broken code from being detected. + + + + This check is allowed only when plpgsql.extra_errors + is set to "strict_expr_check". 
+ + + The following example shows the effect of plpgsql.extra_warnings diff --git a/src/pl/plpgsql/src/pl_comp.c b/src/pl/plpgsql/src/pl_comp.c index b80c59447fb..840c6ee0c06 100644 --- a/src/pl/plpgsql/src/pl_comp.c +++ b/src/pl/plpgsql/src/pl_comp.c @@ -796,6 +796,13 @@ plpgsql_compile_inline(char *proc_source) function->extra_warnings = 0; function->extra_errors = 0; + /* + * Although function->extra_errors is disabled, we want to + * do strict_expr_check inside anonymous blocks too. + */ + if (plpgsql_extra_errors & PLPGSQL_XCHECK_STRICTEXPRCHECK) + function->extra_errors = PLPGSQL_XCHECK_STRICTEXPRCHECK; + function->nstatements = 0; function->requires_procedure_resowner = false; function->has_exception_block = false; diff --git a/src/pl/plpgsql/src/pl_gram.y b/src/pl/plpgsql/src/pl_gram.y index 5612e66d023..dcc581afdbf 100644 --- a/src/pl/plpgsql/src/pl_gram.y +++ b/src/pl/plpgsql/src/pl_gram.y @@ -18,6 +18,7 @@ #include "catalog/namespace.h" #include "catalog/pg_proc.h" #include "catalog/pg_type.h" +#include "nodes/nodeFuncs.h" #include "parser/parser.h" #include "parser/parse_type.h" #include "parser/scanner.h" @@ -71,6 +72,7 @@ static PLpgSQL_expr *read_sql_construct(int until, const char *expected, RawParseMode parsemode, bool isexpression, + bool allowlist, bool valid_sql, int *startloc, int *endtoken, @@ -106,7 +108,7 @@ static PLpgSQL_row *make_scalar_list1(char *initial_name, PLpgSQL_datum *initial_datum, int lineno, int location, yyscan_t yyscanner); static void check_sql_expr(const char *stmt, - RawParseMode parseMode, int location, yyscan_t yyscanner); + RawParseMode parseMode, bool allowlist, int location, yyscan_t yyscanner); static void plpgsql_sql_error_callback(void *arg); static PLpgSQL_type *parse_datatype(const char *string, int location, yyscan_t yyscanner); static void check_labels(const char *start_label, @@ -117,6 +119,7 @@ static PLpgSQL_expr *read_cursor_args(PLpgSQL_var *cursor, int until, YYSTYPE *yylvalp, YYLTYPE *yyllocp, yyscan_t 
yyscanner); static List *read_raise_options(YYSTYPE *yylvalp, YYLTYPE *yyllocp, yyscan_t yyscanner); static void check_raise_parameters(PLpgSQL_stmt_raise *stmt); +static bool is_strict_expr(List *parsetree, int *errpos, bool allowlist); %} @@ -193,6 +196,7 @@ static void check_raise_parameters(PLpgSQL_stmt_raise *stmt); %type expr_until_semi %type expr_until_then expr_until_loop opt_expr_until_when %type opt_exitcond +%type expressions_until_then %type cursor_variable %type decl_cursor_arg @@ -914,7 +918,7 @@ stmt_perform : K_PERFORM */ new->expr = read_sql_construct(';', 0, 0, ";", RAW_PARSE_DEFAULT, - false, false, + false, false, false, &startloc, NULL, &yylval, &yylloc, yyscanner); /* overwrite "perform" ... */ @@ -924,7 +928,7 @@ stmt_perform : K_PERFORM strlen(new->expr->query)); /* offset syntax error position to account for that */ check_sql_expr(new->expr->query, new->expr->parseMode, - startloc + 1, yyscanner); + false, startloc + 1, yyscanner); $$ = (PLpgSQL_stmt *) new; } @@ -1001,7 +1005,7 @@ stmt_assign : T_DATUM plpgsql_push_back_token(T_DATUM, &yylval, &yylloc, yyscanner); new->expr = read_sql_construct(';', 0, 0, ";", pmode, - false, true, + false, false, true, NULL, NULL, &yylval, &yylloc, yyscanner); mark_expr_as_assignment_source(new->expr, $1.datum); @@ -1262,7 +1266,7 @@ case_when_list : case_when_list case_when } ; -case_when : K_WHEN expr_until_then proc_sect +case_when : K_WHEN expressions_until_then proc_sect { PLpgSQL_case_when *new = palloc(sizeof(PLpgSQL_case_when)); @@ -1292,6 +1296,15 @@ opt_case_else : } ; +expressions_until_then : + { + $$ = read_sql_construct(K_THEN, 0, 0, "THEN", + RAW_PARSE_PLPGSQL_EXPR, /* expr_list */ + true, true, true, NULL, NULL, + &yylval, &yylloc, yyscanner); + } + ; + stmt_loop : opt_loop_label K_LOOP loop_body { PLpgSQL_stmt_loop *new; @@ -1495,6 +1508,7 @@ for_control : for_variable K_IN RAW_PARSE_DEFAULT, true, false, + false, &expr1loc, &tok, &yylval, &yylloc, yyscanner); @@ -1513,7 +1527,7 @@ 
for_control : for_variable K_IN */ expr1->parseMode = RAW_PARSE_PLPGSQL_EXPR; check_sql_expr(expr1->query, expr1->parseMode, - expr1loc, yyscanner); + false, expr1loc, yyscanner); /* Read and check the second one */ expr2 = read_sql_expression2(K_LOOP, K_BY, @@ -1570,7 +1584,7 @@ for_control : for_variable K_IN /* Check syntax as a regular query */ check_sql_expr(expr1->query, expr1->parseMode, - expr1loc, yyscanner); + false, expr1loc, yyscanner); new = palloc0(sizeof(PLpgSQL_stmt_fors)); new->cmd_type = PLPGSQL_STMT_FORS; @@ -1902,7 +1916,7 @@ stmt_raise : K_RAISE expr = read_sql_construct(',', ';', K_USING, ", or ; or USING", RAW_PARSE_PLPGSQL_EXPR, - true, true, + true, false, true, NULL, &tok, &yylval, &yylloc, yyscanner); new->params = lappend(new->params, expr); @@ -2040,7 +2054,7 @@ stmt_dynexecute : K_EXECUTE expr = read_sql_construct(K_INTO, K_USING, ';', "INTO or USING or ;", RAW_PARSE_PLPGSQL_EXPR, - true, true, + true, false, true, NULL, &endtoken, &yylval, &yylloc, yyscanner); @@ -2080,7 +2094,7 @@ stmt_dynexecute : K_EXECUTE expr = read_sql_construct(',', ';', K_INTO, ", or ; or INTO", RAW_PARSE_PLPGSQL_EXPR, - true, true, + true, false, true, NULL, &endtoken, &yylval, &yylloc, yyscanner); new->params = lappend(new->params, expr); @@ -2713,7 +2727,7 @@ read_sql_expression(int until, const char *expected, YYSTYPE *yylvalp, YYLTYPE * { return read_sql_construct(until, 0, 0, expected, RAW_PARSE_PLPGSQL_EXPR, - true, true, NULL, NULL, + true, false, true, NULL, NULL, yylvalp, yyllocp, yyscanner); } @@ -2724,7 +2738,7 @@ read_sql_expression2(int until, int until2, const char *expected, { return read_sql_construct(until, until2, 0, expected, RAW_PARSE_PLPGSQL_EXPR, - true, true, NULL, endtoken, + true, false, true, NULL, endtoken, yylvalp, yyllocp, yyscanner); } @@ -2734,7 +2748,7 @@ read_sql_stmt(YYSTYPE *yylvalp, YYLTYPE *yyllocp, yyscan_t yyscanner) { return read_sql_construct(';', 0, 0, ";", RAW_PARSE_DEFAULT, - false, true, NULL, NULL, + false, false, 
true, NULL, NULL, yylvalp, yyllocp, yyscanner); } @@ -2747,6 +2761,7 @@ read_sql_stmt(YYSTYPE *yylvalp, YYLTYPE *yyllocp, yyscan_t yyscanner) * expected: text to use in complaining that terminator was not found * parsemode: raw_parser() mode to use * isexpression: whether to say we're reading an "expression" or a "statement" + * allowlist: the result can be list of expressions * valid_sql: whether to check the syntax of the expr * startloc: if not NULL, location of first token is stored at *startloc * endtoken: if not NULL, ending token is stored at *endtoken @@ -2759,6 +2774,7 @@ read_sql_construct(int until, const char *expected, RawParseMode parsemode, bool isexpression, + bool allowlist, bool valid_sql, int *startloc, int *endtoken, @@ -2854,7 +2870,7 @@ read_sql_construct(int until, pfree(ds.data); if (valid_sql) - check_sql_expr(expr->query, expr->parseMode, startlocation, yyscanner); + check_sql_expr(expr->query, expr->parseMode, allowlist, startlocation, yyscanner); return expr; } @@ -3175,7 +3191,7 @@ make_execsql_stmt(int firsttoken, int location, PLword *word, YYSTYPE *yylvalp, expr = make_plpgsql_expr(ds.data, RAW_PARSE_DEFAULT); pfree(ds.data); - check_sql_expr(expr->query, expr->parseMode, location, yyscanner); + check_sql_expr(expr->query, expr->parseMode, false, location, yyscanner); execsql = palloc0(sizeof(PLpgSQL_stmt_execsql)); execsql->cmd_type = PLPGSQL_STMT_EXECSQL; @@ -3775,11 +3791,15 @@ make_scalar_list1(char *initial_name, * If no error cursor is provided, we'll just point at "location". 
*/ static void -check_sql_expr(const char *stmt, RawParseMode parseMode, int location, yyscan_t yyscanner) +check_sql_expr(const char *stmt, + RawParseMode parseMode, bool allowlist, + int location, yyscan_t yyscanner) { sql_error_callback_arg cbarg; ErrorContextCallback syntax_errcontext; MemoryContext oldCxt; + List *parsetree; + int errpos; if (!plpgsql_check_syntax) return; @@ -3793,11 +3813,25 @@ check_sql_expr(const char *stmt, RawParseMode parseMode, int location, yyscan_t error_context_stack = &syntax_errcontext; oldCxt = MemoryContextSwitchTo(plpgsql_compile_tmp_cxt); - (void) raw_parser(stmt, parseMode); + parsetree = raw_parser(stmt, parseMode); MemoryContextSwitchTo(oldCxt); /* Restore former ereport callback */ error_context_stack = syntax_errcontext.previous; + + if (plpgsql_curr_compile->extra_warnings & PLPGSQL_XCHECK_STRICTEXPRCHECK || + plpgsql_curr_compile->extra_errors & PLPGSQL_XCHECK_STRICTEXPRCHECK) + { + /* do this check only for expressions */ + if (parseMode == RAW_PARSE_DEFAULT) + return; + + if (!is_strict_expr(parsetree, &errpos, allowlist)) + ereport(plpgsql_curr_compile->extra_errors & PLPGSQL_XCHECK_STRICTEXPRCHECK ? ERROR : WARNING, + (errcode(ERRCODE_SYNTAX_ERROR), + errmsg("syntax of expression is not strict"), + parser_errposition(errpos != -1 ? location + errpos : location))); + } } static void @@ -3831,6 +3865,74 @@ plpgsql_sql_error_callback(void *arg) errposition(0); } +/* + * Returns true, when the only targetList is in parsetree. Cursors + * can require list of expressions or list of named expressions. 
+ */ +static bool +is_strict_expr(List *parsetree, int *errpos, bool allowlist) +{ + RawStmt *rawstmt; + SelectStmt *select; + int targets = 0; + ListCell *lc; + + /* Top should be RawStmt */ + rawstmt = castNode(RawStmt, linitial(parsetree)); + + if (IsA(rawstmt->stmt, SelectStmt)) + { + select = (SelectStmt *) rawstmt->stmt; + } + else if (IsA(rawstmt->stmt, PLAssignStmt)) + { + select = castNode(SelectStmt, ((PLAssignStmt *) rawstmt->stmt)->val); + } + else + elog(ERROR, "unexpected node type"); + + if (!select->targetList) + { + *errpos = -1; + return false; + } + else + *errpos = exprLocation((Node *) select->targetList); + + if (select->distinctClause || + select->fromClause || + select->whereClause || + select->groupClause || + select->groupDistinct || + select->havingClause || + select->windowClause || + select->sortClause || + select->limitOffset || + select->limitCount || + select->limitOption || + select->lockingClause) + return false; + + foreach(lc, select->targetList) + { + ResTarget *rt = castNode(ResTarget, lfirst(lc)); + + if (targets++ >= 1 && !allowlist) + { + *errpos = exprLocation((Node *) rt); + return false; + } + + if (rt->name) + { + *errpos = exprLocation((Node *) rt); + return false; + } + } + + return true; +} + /* * Parse a SQL datatype name and produce a PLpgSQL_type structure. 
* @@ -4014,7 +4116,7 @@ read_cursor_args(PLpgSQL_var *cursor, int until, YYSTYPE *yylvalp, YYLTYPE *yyll item = read_sql_construct(',', ')', 0, ",\" or \")", RAW_PARSE_PLPGSQL_EXPR, - true, true, + true, false, true, NULL, &endtoken, yylvalp, yyllocp, yyscanner); diff --git a/src/pl/plpgsql/src/pl_handler.c b/src/pl/plpgsql/src/pl_handler.c index e9a72929947..b3ba3163e9a 100644 --- a/src/pl/plpgsql/src/pl_handler.c +++ b/src/pl/plpgsql/src/pl_handler.c @@ -97,6 +97,8 @@ plpgsql_extra_checks_check_hook(char **newvalue, void **extra, GucSource source) extrachecks |= PLPGSQL_XCHECK_TOOMANYROWS; else if (pg_strcasecmp(tok, "strict_multi_assignment") == 0) extrachecks |= PLPGSQL_XCHECK_STRICTMULTIASSIGNMENT; + else if (pg_strcasecmp(tok, "strict_expr_check") == 0) + extrachecks |= PLPGSQL_XCHECK_STRICTEXPRCHECK; else if (pg_strcasecmp(tok, "all") == 0 || pg_strcasecmp(tok, "none") == 0) { GUC_check_errdetail("Key word \"%s\" cannot be combined with other key words.", tok); diff --git a/src/pl/plpgsql/src/plpgsql.h b/src/pl/plpgsql/src/plpgsql.h index 41e52b8ce71..459f5f2e223 100644 --- a/src/pl/plpgsql/src/plpgsql.h +++ b/src/pl/plpgsql/src/plpgsql.h @@ -1195,6 +1195,7 @@ extern bool plpgsql_check_asserts; #define PLPGSQL_XCHECK_SHADOWVAR (1 << 1) #define PLPGSQL_XCHECK_TOOMANYROWS (1 << 2) #define PLPGSQL_XCHECK_STRICTMULTIASSIGNMENT (1 << 3) +#define PLPGSQL_XCHECK_STRICTEXPRCHECK (1 << 4) #define PLPGSQL_XCHECK_ALL ((int) ~0) extern int plpgsql_extra_warnings; diff --git a/src/test/regress/expected/plpgsql.out b/src/test/regress/expected/plpgsql.out index d8ce39dba3c..8f4f5cb1183 100644 --- a/src/test/regress/expected/plpgsql.out +++ b/src/test/regress/expected/plpgsql.out @@ -3084,6 +3084,20 @@ select shadowtest(1); t (1 row) +-- test of strict expression check +set plpgsql.extra_errors to 'strict_expr_check'; +create or replace function strict_expr_check_func() +returns void as $$ +declare var int; +begin + var = 1 + delete from pg_class where false; +end; +$$ 
language plpgsql; +ERROR: syntax of expression is not strict +LINE 5: var = 1 + ^ +reset plpgsql.extra_errors; -- runtime extra checks set plpgsql.extra_warnings to 'too_many_rows'; do $$ diff --git a/src/test/regress/sql/plpgsql.sql b/src/test/regress/sql/plpgsql.sql index d413d995d17..dd0d908d422 100644 --- a/src/test/regress/sql/plpgsql.sql +++ b/src/test/regress/sql/plpgsql.sql @@ -2618,6 +2618,20 @@ declare f1 int; begin return 1; end $$ language plpgsql; select shadowtest(1); +-- test of strict expression check +set plpgsql.extra_errors to 'strict_expr_check'; + +create or replace function strict_expr_check_func() +returns void as $$ +declare var int; +begin + var = 1 + delete from pg_class where false; +end; +$$ language plpgsql; + +reset plpgsql.extra_errors; + -- runtime extra checks set plpgsql.extra_warnings to 'too_many_rows'; -- 2.49.0