一. 背景:

Join算子谓词下推的主要作用是将where条件中的谓词尽可能地下推到Join左右两张表的TableScan中。比如对于如下SQL:


SELECT * FROM orders JOIN lineitem ON orders.orderkey = lineitem.orderkey where lineitem.orderkey = 2


由于Join条件是orders.orderkey = lineitem.orderkey,右表的谓词条件lineitem.orderkey = 2可以通过等值推导得到orders.orderkey = 2,因此也可以推给左表。这样左右两张表都能先大幅减少数据量,再进行Join操作。因此上边的SQL将被优化成:

SELECT * FROM orders JOIN lineitem ON orders.orderkey = lineitem.orderkey where lineitem.orderkey = 2 and orders.orderkey = 2

本文主要通过走读presto谓词下推中对于Join部分处理的代码来探究上边的优化过程是如何实现的。

二. 代码走读

如下代码位于PredicatePushDown::visitJoin中。

/**
 * Pushes predicates through a join node.
 *
 * <p>The predicate inherited from above, the join's own predicate and the predicates extracted
 * from both inputs are split (per join type) into a predicate for the left input, a predicate
 * for the right input, a new join predicate and a post-join predicate. The children are then
 * rewritten with their predicates, and the join is rebuilt when anything changed.
 *
 * @param node the join being rewritten
 * @param context rewrite context; {@code context.get()} is the predicate inherited from above
 * @return the rewritten plan, or {@code node} itself when nothing changed and no post-join
 *         filter or output projection had to be added
 * @throws UnsupportedOperationException for join types other than INNER, LEFT, RIGHT and FULL
 */
public PlanNode visitJoin(JoinNode node, RewriteContext<Expression> context)
{
    Expression inheritedPredicate = context.get();

    // See if we can rewrite outer joins in terms of a plain inner join
    node = tryNormalizeToOuterToInnerJoin(node, inheritedPredicate);

    // Predicates extracted from each input subtree. The per-join-type helpers below use them,
    // e.g. processInnerJoin derives predicates for the opposite side from them.
    // NOTE(review): presumably these are the predicates known to hold on every row each input
    // produces — confirm against the effective-predicate extractor.
    Expression leftEffectivePredicate = effectivePredicateExtractor.extract(session, node.getLeft(), types, typeAnalyzer);
    Expression rightEffectivePredicate = effectivePredicateExtractor.extract(session, node.getRight(), types, typeAnalyzer);
    // The join's own predicate (its ON condition) as a single expression
    Expression joinPredicate = extractJoinPredicate(node);

    Expression leftPredicate;
    Expression rightPredicate;
    Expression postJoinPredicate;
    Expression newJoinPredicate;

    switch (node.getType()) {
        case INNER:
            // Inner join: predicates may be pushed to both inputs (see processInnerJoin)
            InnerJoinPushDownResult innerJoinPushDownResult = processInnerJoin(inheritedPredicate,
                    leftEffectivePredicate,
                    rightEffectivePredicate,
                    joinPredicate,
                    node.getLeft().getOutputSymbols());
            leftPredicate = innerJoinPushDownResult.getLeftPredicate();
            rightPredicate = innerJoinPushDownResult.getRightPredicate();
            postJoinPredicate = innerJoinPushDownResult.getPostJoinPredicate();
            newJoinPredicate = innerJoinPushDownResult.getJoinPredicate();
            break;
        case LEFT:
            // Left outer join: the left input is the outer side. processLimitedOuterJoin is the
            // outer-join counterpart of processInnerJoin, with pushdown restricted by outer-join
            // semantics (its implementation is not shown here).
            OuterJoinPushDownResult leftOuterJoinPushDownResult = processLimitedOuterJoin(inheritedPredicate,
                    leftEffectivePredicate,
                    rightEffectivePredicate,
                    joinPredicate,
                    node.getLeft().getOutputSymbols());
            leftPredicate = leftOuterJoinPushDownResult.getOuterJoinPredicate();
            rightPredicate = leftOuterJoinPushDownResult.getInnerJoinPredicate();
            postJoinPredicate = leftOuterJoinPushDownResult.getPostJoinPredicate();
            newJoinPredicate = leftOuterJoinPushDownResult.getJoinPredicate();
            break;
        case RIGHT:
            // Right outer join: same as LEFT with the roles of the inputs swapped
            OuterJoinPushDownResult rightOuterJoinPushDownResult = processLimitedOuterJoin(inheritedPredicate,
                    rightEffectivePredicate,
                    leftEffectivePredicate,
                    joinPredicate,
                    node.getRight().getOutputSymbols());
            leftPredicate = rightOuterJoinPushDownResult.getInnerJoinPredicate();
            rightPredicate = rightOuterJoinPushDownResult.getOuterJoinPredicate();
            postJoinPredicate = rightOuterJoinPushDownResult.getPostJoinPredicate();
            newJoinPredicate = rightOuterJoinPushDownResult.getJoinPredicate();
            break;
        case FULL:
            // Full outer join: nothing is pushed into either input; the inherited predicate is
            // evaluated after the join and the join predicate is kept unchanged
            leftPredicate = TRUE_LITERAL;
            rightPredicate = TRUE_LITERAL;
            postJoinPredicate = inheritedPredicate;
            newJoinPredicate = joinPredicate;
            break;
        default:
            throw new UnsupportedOperationException("Unsupported join type: " + node.getType());
    }

    newJoinPredicate = simplifyExpression(newJoinPredicate);

    // A join predicate that simplified to FALSE is replaced with the equivalent comparison 0 = 1.
    // NOTE(review): the reason a bare FALSE literal is avoided as a join predicate is not visible
    // here — confirm.
    if (newJoinPredicate.equals(BooleanLiteral.FALSE_LITERAL)) {
        newJoinPredicate = new ComparisonExpression(ComparisonExpression.Operator.EQUAL, new LongLiteral("0"), new LongLiteral("1"));
    }

    // Create identity projections for all existing symbols
    Assignments.Builder leftProjections = Assignments.builder();
    leftProjections.putAll(node.getLeft()
            .getOutputSymbols().stream()
            .collect(Collectors.toMap(key -> key, Symbol::toSymbolReference)));

    Assignments.Builder rightProjections = Assignments.builder();
    rightProjections.putAll(node.getRight()
            .getOutputSymbols().stream()
            .collect(Collectors.toMap(key -> key, Symbol::toSymbolReference)));

    // Split the new join predicate into equi-join clauses and a residual join filter.
    // An equality between a left-side expression and a right-side expression becomes an
    // equi-join clause. When a side of it is not already an output symbol of that input, a new
    // symbol is assigned to it and added to that input's projections.
    List<JoinNode.EquiJoinClause> equiJoinClauses = new ArrayList<>();
    ImmutableList.Builder<Expression> joinFilterBuilder = ImmutableList.builder();
    for (Expression conjunct : extractConjuncts(newJoinPredicate)) {
        if (joinEqualityExpression(node.getLeft().getOutputSymbols()).test(conjunct)) {
            ComparisonExpression equality = (ComparisonExpression) conjunct;

            // Orient the equality so that leftExpression refers only to left-side symbols
            boolean alignedComparison = Iterables.all(SymbolsExtractor.extractUnique(equality.getLeft()), in(node.getLeft().getOutputSymbols()));
            Expression leftExpression = (alignedComparison) ? equality.getLeft() : equality.getRight();
            Expression rightExpression = (alignedComparison) ? equality.getRight() : equality.getLeft();

            Symbol leftSymbol = symbolForExpression(leftExpression);
            if (!node.getLeft().getOutputSymbols().contains(leftSymbol)) {
                leftProjections.put(leftSymbol, leftExpression);
            }

            Symbol rightSymbol = symbolForExpression(rightExpression);
            if (!node.getRight().getOutputSymbols().contains(rightSymbol)) {
                rightProjections.put(rightSymbol, rightExpression);
            }

            equiJoinClauses.add(new JoinNode.EquiJoinClause(leftSymbol, rightSymbol));
        }
        else {
            joinFilterBuilder.add(conjunct);
        }
    }

    // Dynamic filters are derived from the equi-join clauses; their predicates are added to the
    // predicate pushed into the left input
    DynamicFiltersResult dynamicFiltersResult = createDynamicFilters(node, equiJoinClauses, session, idAllocator);
    Map<String, Symbol> dynamicFilters = dynamicFiltersResult.getDynamicFilters();
    leftPredicate = combineConjuncts(leftPredicate, combineConjuncts(dynamicFiltersResult.getPredicates()));

    // Recursively rewrite both inputs with their pushed-down predicates. When the equi-join
    // clauses changed, each input is first wrapped in a projection so that any newly assigned
    // symbols exist below the join.
    PlanNode leftSource;
    PlanNode rightSource;
    boolean equiJoinClausesUnmodified = ImmutableSet.copyOf(equiJoinClauses).equals(ImmutableSet.copyOf(node.getCriteria()));
    if (!equiJoinClausesUnmodified) {
        leftSource = context.rewrite(new ProjectNode(idAllocator.getNextId(), node.getLeft(), leftProjections.build()), leftPredicate);
        rightSource = context.rewrite(new ProjectNode(idAllocator.getNextId(), node.getRight(), rightProjections.build()), rightPredicate);
    }
    else {
        leftSource = context.rewrite(node.getLeft(), leftPredicate);
        rightSource = context.rewrite(node.getRight(), rightPredicate);
    }

    // Residual (non-equi) join filter; a combined filter of TRUE means "no filter".
    // equals() is used rather than reference comparison, so a TRUE literal that is not the
    // shared TRUE_LITERAL instance is also recognized (consistent with the equals()-based
    // literal checks elsewhere in this method).
    Expression combinedJoinFilter = combineConjuncts(joinFilterBuilder.build());
    Optional<Expression> newJoinFilter = combinedJoinFilter.equals(TRUE_LITERAL)
            ? Optional.empty()
            : Optional.of(combinedJoinFilter);

    if (node.getType() == INNER && newJoinFilter.isPresent() && equiJoinClauses.isEmpty()) {
        // if we do not have any equi conjunct we do not pushdown non-equality condition into
        // inner join, so we plan execution as nested-loops-join followed by filter instead
        // hash join.
        // todo: remove the code when we have support for filter function in nested loop join
        postJoinPredicate = combineConjuncts(postJoinPredicate, newJoinFilter.get());
        newJoinFilter = Optional.empty();
    }

    boolean filtersEquivalent =
            newJoinFilter.isPresent() == node.getFilter().isPresent() &&
                    (!newJoinFilter.isPresent() || areExpressionsEquivalent(newJoinFilter.get(), node.getFilter().get()));

    // Rebuild the join only when an input, the filter, the dynamic filters or the equi-join
    // clauses changed; otherwise keep the original node. The identity checks on the sources are
    // intentional: they detect whether the recursive rewrite produced new nodes.
    PlanNode output = node;
    if (leftSource != node.getLeft() ||
            rightSource != node.getRight() ||
            !filtersEquivalent ||
            !dynamicFilters.equals(node.getDynamicFilters()) ||
            !equiJoinClausesUnmodified) {
        leftSource = new ProjectNode(idAllocator.getNextId(), leftSource, leftProjections.build());
        rightSource = new ProjectNode(idAllocator.getNextId(), rightSource, rightProjections.build());

        output = new JoinNode(
                node.getId(),
                node.getType(),
                leftSource,
                rightSource,
                equiJoinClauses,
                ImmutableList.<Symbol>builder()
                        .addAll(leftSource.getOutputSymbols())
                        .addAll(rightSource.getOutputSymbols())
                        .build(),
                newJoinFilter,
                node.getLeftHashSymbol(),
                node.getRightHashSymbol(),
                node.getDistributionType(),
                node.isSpillable(),
                dynamicFilters);
    }

    // Whatever could not be evaluated below or inside the join is applied on top of it
    if (!postJoinPredicate.equals(TRUE_LITERAL)) {
        output = new FilterNode(idAllocator.getNextId(), output, postJoinPredicate);
    }

    // Restore the original output symbols if the rewrite changed them
    if (!node.getOutputSymbols().equals(output.getOutputSymbols())) {
        output = new ProjectNode(idAllocator.getNextId(), output, Assignments.identity(node.getOutputSymbols()));
    }

    return output;
}

processInnerJoin的代码走读:

/**
 * Computes predicate pushdown for an INNER join.
 *
 * <p>Non-deterministic conjuncts of the inherited predicate and of the join predicate are never
 * pushed down; they stay in the resulting join predicate. The deterministic parts are combined
 * with the deterministic parts of both inputs' effective predicates into equality inferences.
 * These inferences are used to rewrite each conjunct purely in terms of left-side symbols
 * and/or right-side symbols, so it can be pushed into the corresponding input.
 *
 * @param inheritedPredicate predicate inherited from above the join
 * @param leftEffectivePredicate predicate of the left input; must reference only {@code leftSymbols}
 * @param rightEffectivePredicate predicate of the right input; must not reference {@code leftSymbols}
 * @param joinPredicate the join's own predicate
 * @param leftSymbols output symbols of the left input; every other symbol is treated as right-side
 * @return the left pushdown predicate, the right pushdown predicate and the new join predicate.
 *         The fourth constructor argument is always TRUE_LITERAL (presumably the post-join
 *         predicate — confirm against InnerJoinPushDownResult)
 */
private static InnerJoinPushDownResult processInnerJoin(Expression inheritedPredicate, Expression leftEffectivePredicate, Expression rightEffectivePredicate, Expression joinPredicate, Collection<Symbol> leftSymbols)
{
    // Preconditions: each effective predicate must refer only to its own side's symbols
    checkArgument(Iterables.all(SymbolsExtractor.extractUnique(leftEffectivePredicate), in(leftSymbols)), "leftEffectivePredicate must only contain symbols from leftSymbols");
    checkArgument(Iterables.all(SymbolsExtractor.extractUnique(rightEffectivePredicate), not(in(leftSymbols))), "rightEffectivePredicate must not contain symbols from leftSymbols");

    ImmutableList.Builder<Expression> leftPushDownConjuncts = ImmutableList.builder();
    ImmutableList.Builder<Expression> rightPushDownConjuncts = ImmutableList.builder();
    ImmutableList.Builder<Expression> joinConjuncts = ImmutableList.builder();

    // First set aside the non-deterministic conjuncts: those of the inherited predicate and of
    // the join predicate are kept as join conjuncts and never pushed down; only the
    // deterministic remainder takes part in the inference below.

    joinConjuncts.addAll(filter(extractConjuncts(inheritedPredicate), not(DeterminismEvaluator::isDeterministic)));
    inheritedPredicate = filterDeterministicConjuncts(inheritedPredicate);

    joinConjuncts.addAll(filter(extractConjuncts(joinPredicate), not(DeterminismEvaluator::isDeterministic)));
    joinPredicate = filterDeterministicConjuncts(joinPredicate);

    // Non-deterministic parts of the effective predicates are simply dropped
    leftEffectivePredicate = filterDeterministicConjuncts(leftEffectivePredicate);
    rightEffectivePredicate = filterDeterministicConjuncts(rightEffectivePredicate);

    // Build equality inferences from the predicates of both inputs and of the join
    // (e.g. a = b and b = c imply a = c). The two partial inferences each leave out one side's
    // effective predicate; they are used at the end to generate equalities for that side
    // (presumably so that a side does not receive equalities it already guarantees).
    // Generate equality inferences
    EqualityInference allInference = createEqualityInference(inheritedPredicate, leftEffectivePredicate, rightEffectivePredicate, joinPredicate);
    EqualityInference allInferenceWithoutLeftInferred = createEqualityInference(inheritedPredicate, rightEffectivePredicate, joinPredicate);
    EqualityInference allInferenceWithoutRightInferred = createEqualityInference(inheritedPredicate, leftEffectivePredicate, joinPredicate);

    // Each conjunct of the inherited predicate that is not consumed by the inference is
    // rewritten, where possible, purely in terms of left-side symbols and purely in terms of
    // right-side symbols (rewriteExpression returns null when that is impossible). Successful
    // rewrites are pushed to the corresponding side; a conjunct that cannot be pushed to either
    // side stays at the join.

    // Sort through conjuncts in inheritedPredicate that were not used for inference
    for (Expression conjunct : EqualityInference.nonInferrableConjuncts(inheritedPredicate)) {
        Expression leftRewrittenConjunct = allInference.rewriteExpression(conjunct, in(leftSymbols));
        if (leftRewrittenConjunct != null) {
            leftPushDownConjuncts.add(leftRewrittenConjunct);
        }

        Expression rightRewrittenConjunct = allInference.rewriteExpression(conjunct, not(in(leftSymbols)));
        if (rightRewrittenConjunct != null) {
            rightPushDownConjuncts.add(rightRewrittenConjunct);
        }

        // Drop predicate after join only if unable to push down to either side
        if (leftRewrittenConjunct == null && rightRewrittenConjunct == null) {
            joinConjuncts.add(conjunct);
        }
    }

    // Try to carry each input's own (non-inferrable) effective predicate over to the opposite
    // input. A conjunct that cannot be rewritten is not re-added: it already holds on its own side.
    // See if we can push the right effective predicate to the left side
    for (Expression conjunct : EqualityInference.nonInferrableConjuncts(rightEffectivePredicate)) {
        Expression rewritten = allInference.rewriteExpression(conjunct, in(leftSymbols));
        if (rewritten != null) {
            leftPushDownConjuncts.add(rewritten);
        }
    }

    // See if we can push the left effective predicate to the right side
    for (Expression conjunct : EqualityInference.nonInferrableConjuncts(leftEffectivePredicate)) {
        Expression rewritten = allInference.rewriteExpression(conjunct, not(in(leftSymbols)));
        if (rewritten != null) {
            rightPushDownConjuncts.add(rewritten);
        }
    }

    // Same treatment for the join predicate's non-inferrable conjuncts: push them to whichever
    // side(s) they can be rewritten for, otherwise keep them at the join.

    for (Expression conjunct : EqualityInference.nonInferrableConjuncts(joinPredicate)) {
        Expression leftRewritten = allInference.rewriteExpression(conjunct, in(leftSymbols));
        if (leftRewritten != null) {
            leftPushDownConjuncts.add(leftRewritten);
        }

        Expression rightRewritten = allInference.rewriteExpression(conjunct, not(in(leftSymbols)));
        if (rightRewritten != null) {
            rightPushDownConjuncts.add(rightRewritten);
        }

        if (leftRewritten == null && rightRewritten == null) {
            joinConjuncts.add(conjunct);
        }
    }

    // Finally add the equalities implied by the inferences. Equalities whose terms all belong to
    // one side are pushed to that side; equalities spanning both sides become join conjuncts.
    // (Presumably this is how, e.g., a join condition x = y combined with x = 1 yields y = 1 for
    // the right side — confirm against EqualityInference.)
    leftPushDownConjuncts.addAll(allInferenceWithoutLeftInferred.generateEqualitiesPartitionedBy(in(leftSymbols)).getScopeEqualities());
    rightPushDownConjuncts.addAll(allInferenceWithoutRightInferred.generateEqualitiesPartitionedBy(not(in(leftSymbols))).getScopeEqualities());
    joinConjuncts.addAll(allInference.generateEqualitiesPartitionedBy(in(leftSymbols)::apply).getScopeStraddlingEqualities()); // scope straddling equalities get dropped in as part of the join predicate

    return new InnerJoinPushDownResult(combineConjuncts(leftPushDownConjuncts.build()), combineConjuncts(rightPushDownConjuncts.build()), combineConjuncts(joinConjuncts.build()), TRUE_LITERAL);
}