From 0f5f23370fcbfa25b6705bcf4a4667e51c2dbaf9 Mon Sep 17 00:00:00 2001 From: Richard Guo Date: Thu, 16 May 2024 06:17:37 +0000 Subject: [PATCH v4] Introduce a RTE for the grouping step --- .../postgres_fdw/expected/postgres_fdw.out | 2 +- src/backend/commands/explain.c | 21 +- src/backend/nodes/nodeFuncs.c | 14 ++ src/backend/nodes/outfuncs.c | 3 + src/backend/nodes/print.c | 4 + src/backend/nodes/readfuncs.c | 3 + src/backend/optimizer/path/allpaths.c | 4 + src/backend/optimizer/path/equivclass.c | 12 + src/backend/optimizer/plan/initsplan.c | 4 + src/backend/optimizer/plan/planner.c | 31 ++- src/backend/optimizer/plan/setrefs.c | 1 + src/backend/optimizer/prep/prepjointree.c | 9 +- src/backend/optimizer/util/var.c | 125 ++++++++++ src/backend/parser/parse_agg.c | 214 +++++++++++++++++- src/backend/parser/parse_clause.c | 4 +- src/backend/parser/parse_relation.c | 79 ++++++- src/backend/parser/parse_target.c | 2 + src/backend/utils/adt/ruleutils.c | 19 +- src/include/commands/explain.h | 1 + src/include/nodes/nodeFuncs.h | 2 + src/include/nodes/parsenodes.h | 7 + src/include/nodes/pathnodes.h | 5 + src/include/optimizer/optimizer.h | 1 + src/include/parser/parse_clause.h | 2 + src/include/parser/parse_node.h | 2 + src/include/parser/parse_relation.h | 2 + src/test/regress/expected/groupingsets.out | 49 ++++ src/test/regress/sql/groupingsets.sql | 23 ++ 28 files changed, 624 insertions(+), 21 deletions(-) diff --git a/contrib/postgres_fdw/expected/postgres_fdw.out b/contrib/postgres_fdw/expected/postgres_fdw.out index 078b8a966f..edc8f1d51b 100644 --- a/contrib/postgres_fdw/expected/postgres_fdw.out +++ b/contrib/postgres_fdw/expected/postgres_fdw.out @@ -3669,7 +3669,7 @@ select count(*), sum(t1.c1), avg(t2.c1) from (select c1 from ft4 where c1 betwee Foreign Scan Output: (count(*)), (sum(ft4.c1)), (avg(ft5.c1)) Relations: Aggregate on ((public.ft4) FULL JOIN (public.ft5)) - Remote SQL: SELECT count(*), sum(s4.c1), avg(s5.c1) FROM ((SELECT c1 FROM "S 1"."T 3" 
WHERE ((c1 >= 50)) AND ((c1 <= 60))) s4(c1) FULL JOIN (SELECT c1 FROM "S 1"."T 4" WHERE ((c1 >= 50)) AND ((c1 <= 60))) s5(c1) ON (((s4.c1 = s5.c1)))) + Remote SQL: SELECT count(*), sum(s5.c1), avg(s6.c1) FROM ((SELECT c1 FROM "S 1"."T 3" WHERE ((c1 >= 50)) AND ((c1 <= 60))) s5(c1) FULL JOIN (SELECT c1 FROM "S 1"."T 4" WHERE ((c1 >= 50)) AND ((c1 <= 60))) s6(c1) ON (((s5.c1 = s6.c1)))) (4 rows) select count(*), sum(t1.c1), avg(t2.c1) from (select c1 from ft4 where c1 between 50 and 60) t1 full join (select c1 from ft5 where c1 between 50 and 60) t2 on (t1.c1 = t2.c1); diff --git a/src/backend/commands/explain.c b/src/backend/commands/explain.c index 94511a5a02..6840c3d596 100644 --- a/src/backend/commands/explain.c +++ b/src/backend/commands/explain.c @@ -877,6 +877,7 @@ ExplainPrintPlan(ExplainState *es, QueryDesc *queryDesc) { Bitmapset *rels_used = NULL; PlanState *ps; + ListCell *lc; /* Set up ExplainState fields associated with this plan tree */ Assert(queryDesc->plannedstmt != NULL); @@ -887,6 +888,14 @@ ExplainPrintPlan(ExplainState *es, QueryDesc *queryDesc) es->deparse_cxt = deparse_context_for_plan_tree(queryDesc->plannedstmt, es->rtable_names); es->printed_subplans = NULL; + es->rtable_size = list_length(es->rtable); + foreach (lc, es->rtable) + { + RangeTblEntry *rte = lfirst_node(RangeTblEntry, lc); + + if (rte->rtekind == RTE_GROUP) + es->rtable_size--; + } /* * Sometimes we mark a Gather node as "invisible", which means that it's @@ -2463,7 +2472,7 @@ show_plan_tlist(PlanState *planstate, List *ancestors, ExplainState *es) context = set_deparse_context_plan(es->deparse_cxt, plan, ancestors); - useprefix = list_length(es->rtable) > 1; + useprefix = es->rtable_size > 1; /* Deparse each result column (we now include resjunk ones) */ foreach(lc, plan->targetlist) @@ -2547,7 +2556,7 @@ show_upper_qual(List *qual, const char *qlabel, { bool useprefix; - useprefix = (list_length(es->rtable) > 1 || es->verbose); + useprefix = (es->rtable_size > 1 || 
es->verbose); show_qual(qual, qlabel, planstate, ancestors, useprefix, es); } @@ -2637,7 +2646,7 @@ show_grouping_sets(PlanState *planstate, Agg *agg, context = set_deparse_context_plan(es->deparse_cxt, planstate->plan, ancestors); - useprefix = (list_length(es->rtable) > 1 || es->verbose); + useprefix = (es->rtable_size > 1 || es->verbose); ExplainOpenGroup("Grouping Sets", "Grouping Sets", false, es); @@ -2777,7 +2786,7 @@ show_sort_group_keys(PlanState *planstate, const char *qlabel, context = set_deparse_context_plan(es->deparse_cxt, plan, ancestors); - useprefix = (list_length(es->rtable) > 1 || es->verbose); + useprefix = (es->rtable_size > 1 || es->verbose); for (keyno = 0; keyno < nkeys; keyno++) { @@ -2889,7 +2898,7 @@ show_tablesample(TableSampleClause *tsc, PlanState *planstate, context = set_deparse_context_plan(es->deparse_cxt, planstate->plan, ancestors); - useprefix = list_length(es->rtable) > 1; + useprefix = es->rtable_size > 1; /* Get the tablesample method name */ method_name = get_func_name(tsc->tsmhandler); @@ -3339,7 +3348,7 @@ show_memoize_info(MemoizeState *mstate, List *ancestors, ExplainState *es) * It's hard to imagine having a memoize node with fewer than 2 RTEs, but * let's just keep the same useprefix logic as elsewhere in this file. 
*/ - useprefix = list_length(es->rtable) > 1 || es->verbose; + useprefix = es->rtable_size > 1 || es->verbose; /* Set up deparsing context */ context = set_deparse_context_plan(es->deparse_cxt, diff --git a/src/backend/nodes/nodeFuncs.c b/src/backend/nodes/nodeFuncs.c index 89ee4b61f2..6f0f8e8c54 100644 --- a/src/backend/nodes/nodeFuncs.c +++ b/src/backend/nodes/nodeFuncs.c @@ -2862,6 +2862,11 @@ range_table_entry_walker_impl(RangeTblEntry *rte, case RTE_RESULT: /* nothing to do */ break; + case RTE_GROUP: + if (!(flags & QTW_IGNORE_GROUPEXPRS)) + if (WALK(rte->groupexprs)) + return true; + break; } if (WALK(rte->securityQuals)) @@ -3900,6 +3905,15 @@ range_table_mutator_impl(List *rtable, case RTE_RESULT: /* nothing to do */ break; + case RTE_GROUP: + if (!(flags & QTW_IGNORE_GROUPEXPRS)) + MUTATE(newrte->groupexprs, rte->groupexprs, List *); + else + { + /* else, copy group exprs as-is */ + newrte->groupexprs = copyObject(rte->groupexprs); + } + break; } MUTATE(newrte->securityQuals, rte->securityQuals, List *); newrt = lappend(newrt, newrte); diff --git a/src/backend/nodes/outfuncs.c b/src/backend/nodes/outfuncs.c index 3337b77ae6..9827cf16be 100644 --- a/src/backend/nodes/outfuncs.c +++ b/src/backend/nodes/outfuncs.c @@ -562,6 +562,9 @@ _outRangeTblEntry(StringInfo str, const RangeTblEntry *node) case RTE_RESULT: /* no extra fields */ break; + case RTE_GROUP: + WRITE_NODE_FIELD(groupexprs); + break; default: elog(ERROR, "unrecognized RTE kind: %d", (int) node->rtekind); break; diff --git a/src/backend/nodes/print.c b/src/backend/nodes/print.c index 02798f4482..03416e8f4a 100644 --- a/src/backend/nodes/print.c +++ b/src/backend/nodes/print.c @@ -300,6 +300,10 @@ print_rt(const List *rtable) printf("%d\t%s\t[result]", i, rte->eref->aliasname); break; + case RTE_GROUP: + printf("%d\t%s\t[group]", + i, rte->eref->aliasname); + break; default: printf("%d\t%s\t[unknown rtekind]", i, rte->eref->aliasname); diff --git a/src/backend/nodes/readfuncs.c 
b/src/backend/nodes/readfuncs.c index c4d01a441a..818e472a3b 100644 --- a/src/backend/nodes/readfuncs.c +++ b/src/backend/nodes/readfuncs.c @@ -422,6 +422,9 @@ _readRangeTblEntry(void) case RTE_RESULT: /* no extra fields */ break; + case RTE_GROUP: + READ_NODE_FIELD(groupexprs); + break; default: elog(ERROR, "unrecognized RTE kind: %d", (int) local_node->rtekind); diff --git a/src/backend/optimizer/path/allpaths.c b/src/backend/optimizer/path/allpaths.c index 4895cee994..2ee478195f 100644 --- a/src/backend/optimizer/path/allpaths.c +++ b/src/backend/optimizer/path/allpaths.c @@ -731,6 +731,10 @@ set_rel_consider_parallel(PlannerInfo *root, RelOptInfo *rel, case RTE_RESULT: /* RESULT RTEs, in themselves, are no problem. */ break; + case RTE_GROUP: + /* Shouldn't happen; we're only considering baserels here. */ + Assert(false); + return; } /* diff --git a/src/backend/optimizer/path/equivclass.c b/src/backend/optimizer/path/equivclass.c index 21ce1ae2e1..61c450bb99 100644 --- a/src/backend/optimizer/path/equivclass.c +++ b/src/backend/optimizer/path/equivclass.c @@ -737,6 +737,10 @@ get_eclass_for_sort_expr(PlannerInfo *root, { RelOptInfo *rel = root->simple_rel_array[i]; + /* ignore GROUP RTE */ + if (i == root->group_rtindex) + continue; + if (rel == NULL) /* must be an outer join */ { Assert(bms_is_member(i, root->outer_join_rels)); @@ -1098,6 +1102,10 @@ generate_base_implied_equalities(PlannerInfo *root) { RelOptInfo *rel = root->simple_rel_array[i]; + /* ignore GROUP RTE */ + if (i == root->group_rtindex) + continue; + if (rel == NULL) /* must be an outer join */ { Assert(bms_is_member(i, root->outer_join_rels)); @@ -3353,6 +3361,10 @@ get_eclass_indexes_for_relids(PlannerInfo *root, Relids relids) { RelOptInfo *rel = root->simple_rel_array[i]; + /* ignore GROUP RTE */ + if (i == root->group_rtindex) + continue; + if (rel == NULL) /* must be an outer join */ { Assert(bms_is_member(i, root->outer_join_rels)); diff --git a/src/backend/optimizer/plan/initsplan.c 
b/src/backend/optimizer/plan/initsplan.c index e2c68fe6f9..48fad35051 100644 --- a/src/backend/optimizer/plan/initsplan.c +++ b/src/backend/optimizer/plan/initsplan.c @@ -1328,6 +1328,10 @@ mark_rels_nulled_by_join(PlannerInfo *root, Index ojrelid, { RelOptInfo *rel = root->simple_rel_array[relid]; + /* ignore GROUP RTE */ + if (relid == root->group_rtindex) + continue; + if (rel == NULL) /* must be an outer join */ { Assert(bms_is_member(relid, root->outer_join_rels)); diff --git a/src/backend/optimizer/plan/planner.c b/src/backend/optimizer/plan/planner.c index 032818423f..b969aa3bcf 100644 --- a/src/backend/optimizer/plan/planner.c +++ b/src/backend/optimizer/plan/planner.c @@ -748,6 +748,7 @@ subquery_planner(PlannerGlobal *glob, Query *parse, PlannerInfo *parent_root, */ root->hasJoinRTEs = false; root->hasLateralRTEs = false; + root->group_rtindex = 0; hasOuterJoins = false; hasResultRTEs = false; foreach(l, parse->rtable) @@ -781,6 +782,9 @@ subquery_planner(PlannerGlobal *glob, Query *parse, PlannerInfo *parent_root, case RTE_RESULT: hasResultRTEs = true; break; + case RTE_GROUP: + root->group_rtindex = list_cell_number(parse->rtable, l) + 1; + break; default: /* No work here for other RTE types */ break; @@ -836,10 +840,6 @@ subquery_planner(PlannerGlobal *glob, Query *parse, PlannerInfo *parent_root, preprocess_expression(root, (Node *) parse->targetList, EXPRKIND_TARGET); - /* Constant-folding might have removed all set-returning functions */ - if (parse->hasTargetSRFs) - parse->hasTargetSRFs = expression_returns_set((Node *) parse->targetList); - newWithCheckOptions = NIL; foreach(l, parse->withCheckOptions) { @@ -969,6 +969,13 @@ subquery_planner(PlannerGlobal *glob, Query *parse, PlannerInfo *parent_root, rte->values_lists = (List *) preprocess_expression(root, (Node *) rte->values_lists, kind); } + else if (rte->rtekind == RTE_GROUP) + { + /* Preprocess the groupexprs lists fully */ + rte->groupexprs = (List *) + preprocess_expression(root, (Node *) 
rte->groupexprs, + EXPRKIND_TARGET); + } /* * Process each element of the securityQuals list as if it were a @@ -984,6 +991,22 @@ subquery_planner(PlannerGlobal *glob, Query *parse, PlannerInfo *parent_root, } } + /* + * Replace any Vars that reference GROUP outputs in the subquery's + * targetlist and havingQual with the underlying grouping expressions. + */ + if (root->group_rtindex > 0) + { + parse->targetList = (List *) + flatten_group_exprs(root, root->parse, (Node *) parse->targetList); + parse->havingQual = + flatten_group_exprs(root, root->parse, parse->havingQual); + } + + /* Constant-folding might have removed all set-returning functions */ + if (parse->hasTargetSRFs) + parse->hasTargetSRFs = expression_returns_set((Node *) parse->targetList); + /* * Now that we are done preprocessing expressions, and in particular done * flattening join alias variables, get rid of the joinaliasvars lists. diff --git a/src/backend/optimizer/plan/setrefs.c b/src/backend/optimizer/plan/setrefs.c index 37abcb4701..631d4d2c70 100644 --- a/src/backend/optimizer/plan/setrefs.c +++ b/src/backend/optimizer/plan/setrefs.c @@ -557,6 +557,7 @@ add_rte_to_flat_rtable(PlannerGlobal *glob, List *rteperminfos, newrte->coltypes = NIL; newrte->coltypmods = NIL; newrte->colcollations = NIL; + newrte->groupexprs = NIL; newrte->securityQuals = NIL; glob->finalrtable = lappend(glob->finalrtable, newrte); diff --git a/src/backend/optimizer/prep/prepjointree.c b/src/backend/optimizer/prep/prepjointree.c index 5482ab85a7..728c07f464 100644 --- a/src/backend/optimizer/prep/prepjointree.c +++ b/src/backend/optimizer/prep/prepjointree.c @@ -1235,6 +1235,7 @@ pull_up_simple_subquery(PlannerInfo *root, Node *jtnode, RangeTblEntry *rte, case RTE_CTE: case RTE_NAMEDTUPLESTORE: case RTE_RESULT: + case RTE_GROUP: /* these can't contain any lateral references */ break; } @@ -2218,7 +2219,8 @@ perform_pullup_replace_vars(PlannerInfo *root, } /* - * Replace references in the joinaliasvars lists of join 
RTEs. + * Replace references in the joinaliasvars lists of join RTEs and the + * groupexprs list of group RTE. */ foreach(lc, parse->rtable) { @@ -2228,6 +2230,10 @@ perform_pullup_replace_vars(PlannerInfo *root, otherrte->joinaliasvars = (List *) pullup_replace_vars((Node *) otherrte->joinaliasvars, rvcontext); + else if (otherrte->rtekind == RTE_GROUP) + otherrte->groupexprs = (List *) + pullup_replace_vars((Node *) otherrte->groupexprs, + rvcontext); } } @@ -2293,6 +2299,7 @@ replace_vars_in_jointree(Node *jtnode, case RTE_CTE: case RTE_NAMEDTUPLESTORE: case RTE_RESULT: + case RTE_GROUP: /* these shouldn't be marked LATERAL */ Assert(false); break; diff --git a/src/backend/optimizer/util/var.c b/src/backend/optimizer/util/var.c index 844fc30978..fa7860bec7 100644 --- a/src/backend/optimizer/util/var.c +++ b/src/backend/optimizer/util/var.c @@ -81,6 +81,8 @@ static bool pull_var_clause_walker(Node *node, pull_var_clause_context *context); static Node *flatten_join_alias_vars_mutator(Node *node, flatten_join_alias_vars_context *context); +static Node *flatten_group_exprs_mutator(Node *node, + flatten_join_alias_vars_context *context); static Node *add_nullingrels_if_needed(PlannerInfo *root, Node *newnode, Var *oldvar); static bool is_standard_join_alias_expression(Node *newnode, Var *oldvar); @@ -902,6 +904,129 @@ flatten_join_alias_vars_mutator(Node *node, (void *) context); } +/* + * flatten_group_exprs + * Replace Vars that reference GROUP outputs with copies of the underlying + * grouping expressions. + */ +Node * +flatten_group_exprs(PlannerInfo *root, Query *query, Node *node) +{ + flatten_join_alias_vars_context context; + + /* + * We do not expect this to be applied to the whole Query, only to + * expressions or LATERAL subqueries. Hence, if the top node is a Query, + * it's okay to immediately increment sublevels_up. 
+ */ + Assert(node != (Node *) query); + + context.root = root; + context.query = query; + context.sublevels_up = 0; + /* flag whether grouping expressions could possibly contain SubLinks */ + context.possible_sublink = query->hasSubLinks; + /* if hasSubLinks is already true, no need to work hard */ + context.inserted_sublink = query->hasSubLinks; + + return flatten_group_exprs_mutator(node, &context); +} + +static Node * +flatten_group_exprs_mutator(Node *node, + flatten_join_alias_vars_context *context) +{ + if (node == NULL) + return NULL; + if (IsA(node, Var)) + { + Var *var = (Var *) node; + RangeTblEntry *rte; + Node *newvar; + + /* No change unless Var belongs to the GROUP of the target level */ + if (var->varlevelsup != context->sublevels_up) + return node; /* no need to copy, really */ + rte = rt_fetch(var->varno, context->query->rtable); + if (rte->rtekind != RTE_GROUP) + return node; + + /* Expand group exprs reference */ + Assert(var->varattno > 0); + newvar = (Node *) list_nth(rte->groupexprs, var->varattno - 1); + Assert(newvar != NULL); + newvar = copyObject(newvar); + + /* + * If we are expanding an expr carried down from an upper query, must + * adjust its varlevelsup fields. + */ + if (context->sublevels_up != 0) + IncrementVarSublevelsUp(newvar, context->sublevels_up, 0); + + /* Preserve original Var's location, if possible */ + if (IsA(newvar, Var)) + ((Var *) newvar)->location = var->location; + + /* Detect if we are adding a sublink to query */ + if (context->possible_sublink && !context->inserted_sublink) + context->inserted_sublink = checkExprHasSubLink(newvar); + + /* + * TODO var->varnullingrels might have the nullingrel bit that + * references RTE_GROUP. We're supposed to add it to the replacement + * expression. + * + * Maybe we can do something like add_nullingrels_if_needed(). 
+ */ + return newvar; + } + + if (IsA(node, Aggref)) + { + Aggref *agg = (Aggref *) node; + + if ((int) agg->agglevelsup > context->sublevels_up) + return node; + + agg = copyObject(agg); + agg->aggdirectargs = (List *) + flatten_group_exprs_mutator((Node *) agg->aggdirectargs, context); + + return (Node *) agg; + } + + if (IsA(node, GroupingFunc)) + { + GroupingFunc *grp = (GroupingFunc *) node; + + if ((int) grp->agglevelsup >= context->sublevels_up) + return node; + } + + if (IsA(node, Query)) + { + /* Recurse into RTE subquery or not-yet-planned sublink subquery */ + Query *newnode; + bool save_inserted_sublink; + + context->sublevels_up++; + save_inserted_sublink = context->inserted_sublink; + context->inserted_sublink = ((Query *) node)->hasSubLinks; + newnode = query_tree_mutator((Query *) node, + flatten_group_exprs_mutator, + (void *) context, + QTW_IGNORE_GROUPEXPRS); + newnode->hasSubLinks |= context->inserted_sublink; + context->inserted_sublink = save_inserted_sublink; + context->sublevels_up--; + return (Node *) newnode; + } + + return expression_tree_mutator(node, flatten_group_exprs_mutator, + (void *) context); +} + /* * Add oldvar's varnullingrels, if any, to a flattened join alias expression. * The newnode has been copied, so we can modify it freely. 
diff --git a/src/backend/parser/parse_agg.c b/src/backend/parser/parse_agg.c index bee7d8346a..7e2ec2ef4a 100644 --- a/src/backend/parser/parse_agg.c +++ b/src/backend/parser/parse_agg.c @@ -26,6 +26,7 @@ #include "parser/parse_clause.h" #include "parser/parse_coerce.h" #include "parser/parse_expr.h" +#include "parser/parse_relation.h" #include "parser/parsetree.h" #include "rewrite/rewriteManip.h" #include "utils/builtins.h" @@ -53,6 +54,15 @@ typedef struct bool in_agg_direct_args; } check_ungrouped_columns_context; +typedef struct +{ + ParseState *pstate; + List *groupClauses; + List *groupClauseCommonExprs; + bool have_non_var_grouping; + int sublevels_up; +} substitute_group_exprs_context; + static int check_agg_arguments(ParseState *pstate, List *directargs, List *args, @@ -65,6 +75,11 @@ static void check_ungrouped_columns(Node *node, ParseState *pstate, Query *qry, List **func_grouped_rels); static bool check_ungrouped_columns_walker(Node *node, check_ungrouped_columns_context *context); +static Node *substitute_group_exprs(Node *node, ParseState *pstate, + List *groupClauses, List *groupClauseCommonExprs, + bool have_non_var_grouping); +static Node *substitute_group_exprs_mutator(Node *node, + substitute_group_exprs_context *context); static void finalize_grouping_exprs(Node *node, ParseState *pstate, Query *qry, List *groupClauses, bool hasJoinRTEs, bool have_non_var_grouping); @@ -1082,6 +1097,7 @@ parseCheckAggregates(ParseState *pstate, Query *qry) List *gset_common = NIL; List *groupClauses = NIL; List *groupClauseCommonVars = NIL; + List *groupClauseCommonExprs = NIL; bool have_non_var_grouping; List *func_grouped_rels = NIL; ListCell *l; @@ -1201,13 +1217,26 @@ parseCheckAggregates(ParseState *pstate, Query *qry) { have_non_var_grouping = true; } - else if (!qry->groupingSets || - list_member_int(gset_common, tle->ressortgroupref)) + + if (!qry->groupingSets || + list_member_int(gset_common, tle->ressortgroupref)) { - groupClauseCommonVars = 
lappend(groupClauseCommonVars, tle->expr); + groupClauseCommonExprs = lappend(groupClauseCommonExprs, tle->expr); + + if (IsA(tle->expr, Var)) + groupClauseCommonVars = lappend(groupClauseCommonVars, tle->expr); } + } + /* + * Now build an RTE and nsitem for the result of the grouping step. + */ + pstate->p_grouping_nsitem = + addRangeTableEntryForGroup(pstate, groupClauses); + + qry->rtable = pstate->p_rtable; + /* * Check the targetlist and HAVING clause for ungrouped variables. * @@ -1241,6 +1270,15 @@ parseCheckAggregates(ParseState *pstate, Query *qry) have_non_var_grouping, &func_grouped_rels); + qry->targetList = (List *) + substitute_group_exprs((Node *) qry->targetList, pstate, + groupClauses, groupClauseCommonExprs, + have_non_var_grouping); + qry->havingQual = + substitute_group_exprs(qry->havingQual, pstate, + groupClauses, groupClauseCommonExprs, + have_non_var_grouping); + /* * Per spec, aggregates can't appear in a recursive term. */ @@ -1470,6 +1508,176 @@ check_ungrouped_columns_walker(Node *node, (void *) context); } +static Node * +substitute_group_exprs(Node *node, ParseState *pstate, + List *groupClauses, List *groupClauseCommonExprs, + bool have_non_var_grouping) +{ + substitute_group_exprs_context context; + + context.pstate = pstate; + context.groupClauses = groupClauses; + context.groupClauseCommonExprs = groupClauseCommonExprs; + context.have_non_var_grouping = have_non_var_grouping; + context.sublevels_up = 0; + return substitute_group_exprs_mutator(node, &context); +} + +static Node * +substitute_group_exprs_mutator(Node *node, + substitute_group_exprs_context *context) +{ + ListCell *gl; + + if (node == NULL) + return NULL; + + if (IsA(node, Aggref)) + { + Aggref *agg = (Aggref *) node; + + if ((int) agg->agglevelsup == context->sublevels_up) + { + /* + * If we find an aggregate call of the original level, do not + * recurse into its normal arguments, ORDER BY arguments, or + * filter; grouped vars there do not need to be replaced. 
But we + * should modify direct arguments as though they weren't in an + * aggregate. + */ + agg = copyObject(agg); + agg->aggdirectargs = (List *) + substitute_group_exprs_mutator((Node *) agg->aggdirectargs, + context); + return (Node *) agg; + } + + /* + * We can skip recursing into aggregates of higher levels altogether, + * since they could not possibly contain Vars of concern to us (see + * transformAggregateCall). We do need to look at aggregates of lower + * levels, however. + */ + if ((int) agg->agglevelsup > context->sublevels_up) + return node; + } + + if (IsA(node, GroupingFunc)) + { + GroupingFunc *grp = (GroupingFunc *) node; + + if ((int) grp->agglevelsup >= context->sublevels_up) + return node; + } + + /* + * If we have any GROUP BY items that are not simple Vars, check to see if + * subexpression as a whole matches any GROUP BY item. We need to do this + * at every recursion level so that we recognize GROUPed-BY expressions + * before reaching variables within them. But this only works at the outer + * query level, as noted above. + */ + if (context->have_non_var_grouping && context->sublevels_up == 0) + { + int attnum = 0; + foreach(gl, context->groupClauses) + { + TargetEntry *tle = lfirst(gl); + + attnum++; + if (equal(node, tle->expr)) + { + Var *newvar; + int group_rtindex; + ParseNamespaceColumn *group_nscolumns; + + group_rtindex = context->pstate->p_grouping_nsitem->p_rtindex; + group_nscolumns = context->pstate->p_grouping_nsitem->p_nscolumns; + + newvar = buildVarFromNSColumn(context->pstate, + group_nscolumns + attnum - 1); + + if (!list_member(context->groupClauseCommonExprs, node)) + newvar->varnullingrels = + bms_add_member(newvar->varnullingrels, group_rtindex); + + return (Node *) newvar; + } + } + } + + if (IsA(node, Const) || + IsA(node, Param)) + return node; + + /* + * We are only interested in Vars of the original query level. 
+ */ + if (IsA(node, Var)) + { + Var *var = (Var *) node; + + if (var->varlevelsup != context->sublevels_up) + return node; /* it's not local to my query, ignore */ + + /* + * Check for a match, if we didn't do it above. + */ + if (!context->have_non_var_grouping || context->sublevels_up != 0) + { + int attnum = 0; + foreach(gl, context->groupClauses) + { + Var *gvar = (Var *) ((TargetEntry *) lfirst(gl))->expr; + + attnum++; + if (IsA(gvar, Var) && + gvar->varno == var->varno && + gvar->varattno == var->varattno && + gvar->varlevelsup == 0) + { + Var *newvar; + int group_rtindex; + ParseNamespaceColumn *group_nscolumns; + + group_rtindex = + context->pstate->p_grouping_nsitem->p_rtindex; + group_nscolumns = + context->pstate->p_grouping_nsitem->p_nscolumns; + + newvar = buildVarFromNSColumn(context->pstate, + group_nscolumns + attnum - 1); + newvar->varlevelsup = context->sublevels_up; + + if (!list_member(context->groupClauseCommonExprs, node)) + newvar->varnullingrels = + bms_add_member(newvar->varnullingrels, group_rtindex); + + return (Node *) newvar; + } + } + } + + return node; + } + + if (IsA(node, Query)) + { + /* Recurse into subselects */ + Query *newnode; + + context->sublevels_up++; + newnode = query_tree_mutator((Query *) node, + substitute_group_exprs_mutator, + (void *) context, + 0); + context->sublevels_up--; + return (Node *) newnode; + } + return expression_tree_mutator(node, substitute_group_exprs_mutator, + (void *) context); +} + /* * finalize_grouping_exprs - * Scan the given expression tree for GROUPING() and related calls, diff --git a/src/backend/parser/parse_clause.c b/src/backend/parser/parse_clause.c index 8118036495..350ca1d515 100644 --- a/src/backend/parser/parse_clause.c +++ b/src/backend/parser/parse_clause.c @@ -74,8 +74,6 @@ static ParseNamespaceItem *getNSItemForSpecialRelationTypes(ParseState *pstate, static Node *transformFromClauseItem(ParseState *pstate, Node *n, ParseNamespaceItem **top_nsitem, List **namespace); -static 
Var *buildVarFromNSColumn(ParseState *pstate, - ParseNamespaceColumn *nscol); static Node *buildMergedJoinVar(ParseState *pstate, JoinType jointype, Var *l_colvar, Var *r_colvar); static void markRelsAsNulledBy(ParseState *pstate, Node *n, int jindex); @@ -1636,7 +1634,7 @@ transformFromClauseItem(ParseState *pstate, Node *n, * Note also that no column SELECT privilege is requested here; that would * happen only if the column is actually referenced in the query. */ -static Var * +Var * buildVarFromNSColumn(ParseState *pstate, ParseNamespaceColumn *nscol) { Var *var; diff --git a/src/backend/parser/parse_relation.c b/src/backend/parser/parse_relation.c index 2f64eaf0e3..6947638425 100644 --- a/src/backend/parser/parse_relation.c +++ b/src/backend/parser/parse_relation.c @@ -2557,6 +2557,79 @@ addRangeTableEntryForENR(ParseState *pstate, tupdesc); } +/* + * Add an entry for grouping step to the pstate's range table (p_rtable). + * Then, construct and return a ParseNamespaceItem for the new RTE. + */ +ParseNamespaceItem * +addRangeTableEntryForGroup(ParseState *pstate, + List *groupClauses) +{ + RangeTblEntry *rte = makeNode(RangeTblEntry); + Alias *eref; + List *groupexprs; + List *coltypes, + *coltypmods, + *colcollations; + ListCell *lc; + ParseNamespaceItem *nsitem; + + Assert(pstate != NULL); + + rte->rtekind = RTE_GROUP; + rte->alias = NULL; + + eref = makeAlias("*GROUP*", NIL); + + /* fill in any unspecified alias columns, and extract column type info */ + groupexprs = NIL; + coltypes = coltypmods = colcollations = NIL; + foreach(lc, groupClauses) + { + TargetEntry *te = (TargetEntry *) lfirst(lc); + char *colname = te->resname ? 
pstrdup(te->resname) : "unamed_col"; + + eref->colnames = lappend(eref->colnames, makeString(colname)); + + groupexprs = lappend(groupexprs, copyObject(te->expr)); + + coltypes = lappend_oid(coltypes, + exprType((Node *) te->expr)); + coltypmods = lappend_int(coltypmods, + exprTypmod((Node *) te->expr)); + colcollations = lappend_oid(colcollations, + exprCollation((Node *) te->expr)); + } + + rte->eref = eref; + rte->groupexprs = groupexprs; + + /* + * Set flags. + * + * The grouping step is never checked for access rights, so no need to + * perform addRTEPermissionInfo(). + */ + rte->lateral = false; + rte->inFromCl = false; + + /* + * Add completed RTE to pstate's range table list, so that we know its + * index. But we don't add it to the join list --- caller must do that if + * appropriate. + */ + pstate->p_rtable = lappend(pstate->p_rtable, rte); + + /* + * Build a ParseNamespaceItem, but don't add it to the pstate's namespace + * list --- caller must do that if appropriate. + */ + nsitem = buildNSItemFromLists(rte, list_length(pstate->p_rtable), + coltypes, coltypmods, colcollations); + + return nsitem; +} + /* * Has the specified refname been selected FOR UPDATE/FOR SHARE? 
@@ -3003,6 +3076,7 @@ expandRTE(RangeTblEntry *rte, int rtindex, int sublevels_up, } break; case RTE_RESULT: + case RTE_GROUP: /* These expose no columns, so nothing to do */ break; default: @@ -3317,10 +3391,11 @@ get_rte_attribute_is_dropped(RangeTblEntry *rte, AttrNumber attnum) case RTE_TABLEFUNC: case RTE_VALUES: case RTE_CTE: + case RTE_GROUP: /* - * Subselect, Table Functions, Values, CTE RTEs never have dropped - * columns + * Subselect, Table Functions, Values, CTE, GROUP RTEs never have + * dropped columns */ result = false; break; diff --git a/src/backend/parser/parse_target.c b/src/backend/parser/parse_target.c index ee6fcd0503..1f8edc05c9 100644 --- a/src/backend/parser/parse_target.c +++ b/src/backend/parser/parse_target.c @@ -380,6 +380,7 @@ markTargetListOrigin(ParseState *pstate, TargetEntry *tle, case RTE_TABLEFUNC: case RTE_NAMEDTUPLESTORE: case RTE_RESULT: + case RTE_GROUP: /* not a simple relation, leave it unmarked */ break; case RTE_CTE: @@ -1579,6 +1580,7 @@ expandRecordVariable(ParseState *pstate, Var *var, int levelsup) case RTE_VALUES: case RTE_NAMEDTUPLESTORE: case RTE_RESULT: + case RTE_GROUP: /* * This case should not occur: a column of a table, values list, diff --git a/src/backend/utils/adt/ruleutils.c b/src/backend/utils/adt/ruleutils.c index 9618619762..f539693bfe 100644 --- a/src/backend/utils/adt/ruleutils.c +++ b/src/backend/utils/adt/ruleutils.c @@ -5433,11 +5433,27 @@ get_query_def(Query *query, StringInfo buf, List *parentnamespace, { deparse_context context; deparse_namespace dpns; + int rtable_size; + ListCell *lc; /* Guard against excessively long or deeply-nested queries */ CHECK_FOR_INTERRUPTS(); check_stack_depth(); + rtable_size = list_length(query->rtable); + foreach (lc, query->rtable) + { + RangeTblEntry *rte = lfirst_node(RangeTblEntry, lc); + + if (rte->rtekind == RTE_GROUP) + rtable_size--; + } + + query->targetList = (List *) + flatten_group_exprs(NULL, query, (Node *) query->targetList); + query->havingQual = + 
flatten_group_exprs(NULL, query, query->havingQual); + /* * Before we begin to examine the query, acquire locks on referenced * relations, and fix up deleted columns in JOIN RTEs. This ensures @@ -5454,7 +5470,7 @@ get_query_def(Query *query, StringInfo buf, List *parentnamespace, context.windowClause = NIL; context.windowTList = NIL; context.varprefix = (parentnamespace != NIL || - list_length(query->rtable) != 1); + rtable_size != 1); context.prettyFlags = prettyFlags; context.wrapColumn = wrapColumn; context.indentLevel = startIndent; @@ -7838,6 +7854,7 @@ get_name_for_var_field(Var *var, int fieldno, case RTE_VALUES: case RTE_NAMEDTUPLESTORE: case RTE_RESULT: + case RTE_GROUP: /* * This case should not occur: a column of a table, values list, diff --git a/src/include/commands/explain.h b/src/include/commands/explain.h index 9b8b351d9a..35be084869 100644 --- a/src/include/commands/explain.h +++ b/src/include/commands/explain.h @@ -67,6 +67,7 @@ typedef struct ExplainState List *deparse_cxt; /* context list for deparsing expressions */ Bitmapset *printed_subplans; /* ids of SubPlans we've printed */ bool hide_workers; /* set if we find an invisible Gather */ + int rtable_size; /* length of rtable excluding GROUP entries */ /* state related to the current plan node */ ExplainWorkersState *workers_state; /* needed if parallel plan */ } ExplainState; diff --git a/src/include/nodes/nodeFuncs.h b/src/include/nodes/nodeFuncs.h index eaba59bed8..1f0de5b3d8 100644 --- a/src/include/nodes/nodeFuncs.h +++ b/src/include/nodes/nodeFuncs.h @@ -31,6 +31,8 @@ struct PlanState; /* avoid including execnodes.h too */ #define QTW_DONT_COPY_QUERY 0x40 /* do not copy top Query */ #define QTW_EXAMINE_SORTGROUP 0x80 /* include SortGroupClause lists */ +#define QTW_IGNORE_GROUPEXPRS 0x100 /* GROUP expression lists */ + /* callback function for check_functions_in_node */ typedef bool (*check_function_callback) (Oid func_id, void *context); diff --git a/src/include/nodes/parsenodes.h 
b/src/include/nodes/parsenodes.h index ddfed02db2..a7b6fd3976 100644 --- a/src/include/nodes/parsenodes.h +++ b/src/include/nodes/parsenodes.h @@ -1036,6 +1036,7 @@ typedef enum RTEKind RTE_RESULT, /* RTE represents an empty FROM clause; such * RTEs are added by the planner, they're not * present during parsing or rewriting */ + RTE_GROUP, /* the grouping step */ } RTEKind; typedef struct RangeTblEntry @@ -1242,6 +1243,12 @@ typedef struct RangeTblEntry /* estimated or actual from caller */ Cardinality enrtuples pg_node_attr(query_jumble_ignore); + /* + * Fields valid for GROUP RTEs (else NULL/zero): + */ + /* list of expressions grouped on */ + List *groupexprs pg_node_attr(query_jumble_ignore); + /* * Fields valid in all RTEs: */ diff --git a/src/include/nodes/pathnodes.h b/src/include/nodes/pathnodes.h index 14ef296ab7..c082693e7c 100644 --- a/src/include/nodes/pathnodes.h +++ b/src/include/nodes/pathnodes.h @@ -505,6 +505,11 @@ struct PlannerInfo /* true if planning a recursive WITH item */ bool hasRecursion; + /* + * The rangetable index for the GROUP RTE, or 0 if there is no GROUP RTE. + */ + int group_rtindex; + /* * Information about aggregates. Filled by preprocess_aggrefs(). 
*/ diff --git a/src/include/optimizer/optimizer.h b/src/include/optimizer/optimizer.h index 7b63c5cf71..93e3dc719d 100644 --- a/src/include/optimizer/optimizer.h +++ b/src/include/optimizer/optimizer.h @@ -201,5 +201,6 @@ extern bool contain_vars_of_level(Node *node, int levelsup); extern int locate_var_of_level(Node *node, int levelsup); extern List *pull_var_clause(Node *node, int flags); extern Node *flatten_join_alias_vars(PlannerInfo *root, Query *query, Node *node); +extern Node *flatten_group_exprs(PlannerInfo *root, Query *query, Node *node); #endif /* OPTIMIZER_H */ diff --git a/src/include/parser/parse_clause.h b/src/include/parser/parse_clause.h index e71762b10c..1a1cf3570e 100644 --- a/src/include/parser/parse_clause.h +++ b/src/include/parser/parse_clause.h @@ -17,6 +17,8 @@ #include "parser/parse_node.h" extern void transformFromClause(ParseState *pstate, List *frmList); +extern Var *buildVarFromNSColumn(ParseState *pstate, + ParseNamespaceColumn *nscol); extern int setTargetTable(ParseState *pstate, RangeVar *relation, bool inh, bool alsoSource, AclMode requiredPerms); diff --git a/src/include/parser/parse_node.h b/src/include/parser/parse_node.h index 5b781d87a9..ef78fd8224 100644 --- a/src/include/parser/parse_node.h +++ b/src/include/parser/parse_node.h @@ -237,6 +237,8 @@ struct ParseState ParseParamRefHook p_paramref_hook; CoerceParamHook p_coerce_param_hook; void *p_ref_hook_state; /* common passthrough link for above */ + + ParseNamespaceItem *p_grouping_nsitem; /* NSItem for grouping, or NULL */ }; /* diff --git a/src/include/parser/parse_relation.h b/src/include/parser/parse_relation.h index bea2da5496..91fd8e243b 100644 --- a/src/include/parser/parse_relation.h +++ b/src/include/parser/parse_relation.h @@ -100,6 +100,8 @@ extern ParseNamespaceItem *addRangeTableEntryForCTE(ParseState *pstate, extern ParseNamespaceItem *addRangeTableEntryForENR(ParseState *pstate, RangeVar *rv, bool inFromCl); +extern ParseNamespaceItem 
*addRangeTableEntryForGroup(ParseState *pstate, + List *groupClauses); extern RTEPermissionInfo *addRTEPermissionInfo(List **rteperminfos, RangeTblEntry *rte); extern RTEPermissionInfo *getRTEPermissionInfo(List *rteperminfos, diff --git a/src/test/regress/expected/groupingsets.out b/src/test/regress/expected/groupingsets.out index e1f0660810..9c7590e7ba 100644 --- a/src/test/regress/expected/groupingsets.out +++ b/src/test/regress/expected/groupingsets.out @@ -2150,4 +2150,53 @@ select (select grouping(v1)) from (values ((select 1))) v(v1) group by v1; 0 (1 row) +-- test handling of subqueries in grouping sets +create temp table gstest5(id integer primary key, v integer); +insert into gstest5 select i, i from generate_series(1,5)i; +explain (costs off) +select grouping((select t1.v from gstest5 t2 where id = t1.id)), + (select t1.v from gstest5 t2 where id = t1.id) as s +from gstest5 t1 +group by grouping sets(v, s) +order by case when grouping((select t1.v from gstest5 t2 where id = t1.id)) = 0 + then (select t1.v from gstest5 t2 where id = t1.id) + else null end + nulls first; + QUERY PLAN +----------------------------------------------------------------------------------------------------------- + Sort + Sort Key: (CASE WHEN (GROUPING((SubPlan 2)) = 0) THEN ((SubPlan 3)) ELSE NULL::integer END) NULLS FIRST + -> HashAggregate + Hash Key: t1.v + Hash Key: (SubPlan 3) + -> Seq Scan on gstest5 t1 + SubPlan 3 + -> Bitmap Heap Scan on gstest5 t2 + Recheck Cond: (id = t1.id) + -> Bitmap Index Scan on gstest5_pkey + Index Cond: (id = t1.id) +(11 rows) + +select grouping((select t1.v from gstest5 t2 where id = t1.id)), + (select t1.v from gstest5 t2 where id = t1.id) as s +from gstest5 t1 +group by grouping sets(v, s) +order by case when grouping((select t1.v from gstest5 t2 where id = t1.id)) = 0 + then (select t1.v from gstest5 t2 where id = t1.id) + else null end + nulls first; + grouping | s +----------+--- + 1 | + 1 | + 1 | + 1 | + 1 | + 0 | 1 + 0 | 2 + 0 | 3 + 0 | 
4 + 0 | 5 +(10 rows) + -- end diff --git a/src/test/regress/sql/groupingsets.sql b/src/test/regress/sql/groupingsets.sql index 90ba27257a..0520e44aeb 100644 --- a/src/test/regress/sql/groupingsets.sql +++ b/src/test/regress/sql/groupingsets.sql @@ -589,4 +589,27 @@ explain (costs off) select (select grouping(v1)) from (values ((select 1))) v(v1) group by v1; select (select grouping(v1)) from (values ((select 1))) v(v1) group by v1; +-- test handling of subqueries in grouping sets +create temp table gstest5(id integer primary key, v integer); +insert into gstest5 select i, i from generate_series(1,5)i; + +explain (costs off) +select grouping((select t1.v from gstest5 t2 where id = t1.id)), + (select t1.v from gstest5 t2 where id = t1.id) as s +from gstest5 t1 +group by grouping sets(v, s) +order by case when grouping((select t1.v from gstest5 t2 where id = t1.id)) = 0 + then (select t1.v from gstest5 t2 where id = t1.id) + else null end + nulls first; + +select grouping((select t1.v from gstest5 t2 where id = t1.id)), + (select t1.v from gstest5 t2 where id = t1.id) as s +from gstest5 t1 +group by grouping sets(v, s) +order by case when grouping((select t1.v from gstest5 t2 where id = t1.id)) = 0 + then (select t1.v from gstest5 t2 where id = t1.id) + else null end + nulls first; + -- end -- 2.34.1