一. 背景:
Join算子谓词下推的主要作用,是将where条件中的谓词尽可能地下推到Join左右两张表的TableScan中。比如对于如下sql:
SELECT * FROM orders JOIN lineitem ON orders.orderkey = lineitem.orderkey where lineitem.orderkey = 2
由于Join条件是orders.orderkey = lineitem.orderkey,右表的谓词条件lineitem.orderkey = 2也可以推导出左表的谓词orders.orderkey = 2并下推给左表,使左右表都先大幅减少数据量后再进行Join操作。因此上边的查询将被优化成:
SELECT * FROM orders JOIN lineitem ON orders.orderkey = lineitem.orderkey where lineitem.orderkey = 2 and orders.orderkey = 2
本文主要通过走读presto谓词下推中对于Join部分处理的代码来探究上边的优化过程是如何实现的。
二. 代码走读:
如下代码位于PredicatePushDown::visitJoin中。
/**
 * Rewrites a join node while pushing predicates towards its two inputs.
 *
 * <p>The predicate inherited from above (taken from {@code context.get()}) and the join's own
 * predicate are split, per join type, into a left-side predicate, a right-side predicate, a new
 * join predicate and a post-join predicate. Both inputs are then rewritten with their predicate,
 * the join is rebuilt if anything changed, any remaining post-join predicate is applied in a
 * {@code FilterNode}, and a final {@code ProjectNode} restores the original output symbols when
 * they differ.
 *
 * @param node the join node to rewrite
 * @param context rewrite context carrying the inherited predicate
 * @return the rewritten plan; the original {@code node} is reused when nothing changed
 * @throws UnsupportedOperationException if the join type is not INNER, LEFT, RIGHT or FULL
 */
public PlanNode visitJoin(JoinNode node, RewriteContext<Expression> context)
{
    Expression inheritedPredicate = context.get();

    // See if we can rewrite outer joins in terms of a plain inner join
    node = tryNormalizeToOuterToInnerJoin(node, inheritedPredicate);

    // Predicates extracted from each join input by effectivePredicateExtractor.
    // NOTE(review): presumably these are the conditions already known to hold on each input's
    // output — confirm against EffectivePredicateExtractor.
    Expression leftEffectivePredicate = effectivePredicateExtractor.extract(session, node.getLeft(), types, typeAnalyzer);
    Expression rightEffectivePredicate = effectivePredicateExtractor.extract(session, node.getRight(), types, typeAnalyzer);

    // Predicate obtained from the join node itself (its ON condition)
    Expression joinPredicate = extractJoinPredicate(node);

    Expression leftPredicate;
    Expression rightPredicate;
    Expression postJoinPredicate;
    Expression newJoinPredicate;

    switch (node.getType()) {
        case INNER:
            // Inner join: processInnerJoin decides what goes to the left, to the right, and what stays on the join
            InnerJoinPushDownResult innerJoinPushDownResult = processInnerJoin(inheritedPredicate,
                    leftEffectivePredicate,
                    rightEffectivePredicate,
                    joinPredicate,
                    node.getLeft().getOutputSymbols());
            leftPredicate = innerJoinPushDownResult.getLeftPredicate();
            rightPredicate = innerJoinPushDownResult.getRightPredicate();
            postJoinPredicate = innerJoinPushDownResult.getPostJoinPredicate();
            newJoinPredicate = innerJoinPushDownResult.getJoinPredicate();
            break;
        case LEFT:
            // Left join: the left input is passed as the outer side, the right input as the inner side.
            // NOTE(review): processLimitedOuterJoin presumably restricts what may be pushed to the
            // inner side compared with processInnerJoin — confirm in its implementation.
            OuterJoinPushDownResult leftOuterJoinPushDownResult = processLimitedOuterJoin(inheritedPredicate,
                    leftEffectivePredicate,
                    rightEffectivePredicate,
                    joinPredicate,
                    node.getLeft().getOutputSymbols());
            leftPredicate = leftOuterJoinPushDownResult.getOuterJoinPredicate();
            rightPredicate = leftOuterJoinPushDownResult.getInnerJoinPredicate();
            postJoinPredicate = leftOuterJoinPushDownResult.getPostJoinPredicate();
            newJoinPredicate = leftOuterJoinPushDownResult.getJoinPredicate();
            break;
        case RIGHT:
            // Right join: mirror image of LEFT — the right input is the outer side, so the
            // effective predicates and the outer/inner results are swapped.
            OuterJoinPushDownResult rightOuterJoinPushDownResult = processLimitedOuterJoin(inheritedPredicate,
                    rightEffectivePredicate,
                    leftEffectivePredicate,
                    joinPredicate,
                    node.getRight().getOutputSymbols());
            leftPredicate = rightOuterJoinPushDownResult.getInnerJoinPredicate();
            rightPredicate = rightOuterJoinPushDownResult.getOuterJoinPredicate();
            postJoinPredicate = rightOuterJoinPushDownResult.getPostJoinPredicate();
            newJoinPredicate = rightOuterJoinPushDownResult.getJoinPredicate();
            break;
        case FULL:
            // Full join: nothing is pushed to either input; the inherited predicate stays above the join
            leftPredicate = TRUE_LITERAL;
            rightPredicate = TRUE_LITERAL;
            postJoinPredicate = inheritedPredicate;
            newJoinPredicate = joinPredicate;
            break;
        default:
            throw new UnsupportedOperationException("Unsupported join type: " + node.getType());
    }

    newJoinPredicate = simplifyExpression(newJoinPredicate);
    // Replace a constant-false join predicate with the comparison 0 = 1
    if (newJoinPredicate.equals(BooleanLiteral.FALSE_LITERAL)) {
        newJoinPredicate = new ComparisonExpression(ComparisonExpression.Operator.EQUAL, new LongLiteral("0"), new LongLiteral("1"));
    }

    // Create identity projections for all existing symbols
    Assignments.Builder leftProjections = Assignments.builder();
    leftProjections.putAll(node.getLeft()
            .getOutputSymbols().stream()
            .collect(Collectors.toMap(key -> key, Symbol::toSymbolReference)));

    Assignments.Builder rightProjections = Assignments.builder();
    rightProjections.putAll(node.getRight()
            .getOutputSymbols().stream()
            .collect(Collectors.toMap(key -> key, Symbol::toSymbolReference)));

    // Create new projections for the new join clauses: each equality conjunct of the new join
    // predicate becomes an equi-join clause; when the symbol for one of its sides is not already
    // an output of that input, a projection computing it is added. All other conjuncts are
    // collected as the join filter.
    List<JoinNode.EquiJoinClause> equiJoinClauses = new ArrayList<>();
    ImmutableList.Builder<Expression> joinFilterBuilder = ImmutableList.builder();
    for (Expression conjunct : extractConjuncts(newJoinPredicate)) {
        if (joinEqualityExpression(node.getLeft().getOutputSymbols()).test(conjunct)) {
            ComparisonExpression equality = (ComparisonExpression) conjunct;

            // "Aligned" means the left operand of the comparison only uses left-input symbols;
            // otherwise the operands are swapped so leftExpression always belongs to the left input.
            boolean alignedComparison = Iterables.all(SymbolsExtractor.extractUnique(equality.getLeft()), in(node.getLeft().getOutputSymbols()));
            Expression leftExpression = (alignedComparison) ? equality.getLeft() : equality.getRight();
            Expression rightExpression = (alignedComparison) ? equality.getRight() : equality.getLeft();

            Symbol leftSymbol = symbolForExpression(leftExpression);
            if (!node.getLeft().getOutputSymbols().contains(leftSymbol)) {
                leftProjections.put(leftSymbol, leftExpression);
            }

            Symbol rightSymbol = symbolForExpression(rightExpression);
            if (!node.getRight().getOutputSymbols().contains(rightSymbol)) {
                rightProjections.put(rightSymbol, rightExpression);
            }

            equiJoinClauses.add(new JoinNode.EquiJoinClause(leftSymbol, rightSymbol));
        }
        else {
            joinFilterBuilder.add(conjunct);
        }
    }

    // The predicates produced together with the dynamic filters are added to the left-side predicate
    DynamicFiltersResult dynamicFiltersResult = createDynamicFilters(node, equiJoinClauses, session, idAllocator);
    Map<String, Symbol> dynamicFilters = dynamicFiltersResult.getDynamicFilters();
    leftPredicate = combineConjuncts(leftPredicate, combineConjuncts(dynamicFiltersResult.getPredicates()));

    // Rewrite both inputs with their pushed-down predicate. When the equi-join clauses changed,
    // the inputs are first wrapped in the projections built above.
    PlanNode leftSource;
    PlanNode rightSource;

    boolean equiJoinClausesUnmodified = ImmutableSet.copyOf(equiJoinClauses).equals(ImmutableSet.copyOf(node.getCriteria()));
    if (!equiJoinClausesUnmodified) {
        leftSource = context.rewrite(new ProjectNode(idAllocator.getNextId(), node.getLeft(), leftProjections.build()), leftPredicate);
        rightSource = context.rewrite(new ProjectNode(idAllocator.getNextId(), node.getRight(), rightProjections.build()), rightPredicate);
    }
    else {
        leftSource = context.rewrite(node.getLeft(), leftPredicate);
        rightSource = context.rewrite(node.getRight(), rightPredicate);
    }

    Optional<Expression> newJoinFilter = Optional.of(combineConjuncts(joinFilterBuilder.build()));
    // Compare by value, like the other literal checks in this method, so the check does not
    // depend on combineConjuncts returning the TRUE_LITERAL instance itself.
    if (newJoinFilter.get().equals(TRUE_LITERAL)) {
        newJoinFilter = Optional.empty();
    }

    if (node.getType() == INNER && newJoinFilter.isPresent() && equiJoinClauses.isEmpty()) {
        // if we do not have any equi conjunct we do not pushdown non-equality condition into
        // inner join, so we plan execution as nested-loops-join followed by filter instead
        // hash join.
        // todo: remove the code when we have support for filter function in nested loop join
        postJoinPredicate = combineConjuncts(postJoinPredicate, newJoinFilter.get());
        newJoinFilter = Optional.empty();
    }

    boolean filtersEquivalent =
            newJoinFilter.isPresent() == node.getFilter().isPresent() &&
                    (!newJoinFilter.isPresent() || areExpressionsEquivalent(newJoinFilter.get(), node.getFilter().get()));

    // Rebuild the join node only if an input, the filter, the dynamic filters or the
    // equi-join clauses changed; otherwise keep the original node.
    PlanNode output = node;
    if (leftSource != node.getLeft() ||
            rightSource != node.getRight() ||
            !filtersEquivalent ||
            !dynamicFilters.equals(node.getDynamicFilters()) ||
            !equiJoinClausesUnmodified) {
        leftSource = new ProjectNode(idAllocator.getNextId(), leftSource, leftProjections.build());
        rightSource = new ProjectNode(idAllocator.getNextId(), rightSource, rightProjections.build());

        output = new JoinNode(
                node.getId(),
                node.getType(),
                leftSource,
                rightSource,
                equiJoinClauses,
                ImmutableList.<Symbol>builder()
                        .addAll(leftSource.getOutputSymbols())
                        .addAll(rightSource.getOutputSymbols())
                        .build(),
                newJoinFilter,
                node.getLeftHashSymbol(),
                node.getRightHashSymbol(),
                node.getDistributionType(),
                node.isSpillable(),
                dynamicFilters);
    }

    // Whatever could not be pushed into or below the join is applied on top of it
    if (!postJoinPredicate.equals(TRUE_LITERAL)) {
        output = new FilterNode(idAllocator.getNextId(), output, postJoinPredicate);
    }

    // Restore the original output symbols if the rewrite changed them
    if (!node.getOutputSymbols().equals(output.getOutputSymbols())) {
        output = new ProjectNode(idAllocator.getNextId(), output, Assignments.identity(node.getOutputSymbols()));
    }

    return output;
}
processInnerJoin的代码走读:
/**
 * Splits the predicates surrounding an inner join into conjuncts for the left input, conjuncts
 * for the right input, and conjuncts that stay on the join.
 *
 * @param inheritedPredicate predicate inherited from above the join
 * @param leftEffectivePredicate effective predicate of the left input; may only reference {@code leftSymbols}
 * @param rightEffectivePredicate effective predicate of the right input; may not reference {@code leftSymbols}
 * @param joinPredicate predicate of the join itself
 * @param leftSymbols symbols that belong to the left input; every other symbol is treated as right-side
 * @return the combined left, right and join predicates; the fourth constructor argument is
 *         {@code TRUE_LITERAL} (presumably the post-join predicate — confirm against
 *         {@code InnerJoinPushDownResult})
 * @throws IllegalArgumentException if an effective predicate references symbols from the wrong side
 *         (assuming {@code checkArgument} is Guava's {@code Preconditions.checkArgument})
 */
private static InnerJoinPushDownResult processInnerJoin(Expression inheritedPredicate, Expression leftEffectivePredicate, Expression rightEffectivePredicate, Expression joinPredicate, Collection<Symbol> leftSymbols)
{
    checkArgument(Iterables.all(SymbolsExtractor.extractUnique(leftEffectivePredicate), in(leftSymbols)), "leftEffectivePredicate must only contain symbols from leftSymbols");
    checkArgument(Iterables.all(SymbolsExtractor.extractUnique(rightEffectivePredicate), not(in(leftSymbols))), "rightEffectivePredicate must not contain symbols from leftSymbols");

    ImmutableList.Builder<Expression> forLeft = ImmutableList.builder();
    ImmutableList.Builder<Expression> forRight = ImmutableList.builder();
    ImmutableList.Builder<Expression> forJoin = ImmutableList.builder();

    // Non-deterministic conjuncts of the inherited and join predicates go straight to the join
    // conjuncts; only the deterministic remainder of each predicate is used from here on.
    forJoin.addAll(filter(extractConjuncts(inheritedPredicate), not(DeterminismEvaluator::isDeterministic)));
    Expression inherited = filterDeterministicConjuncts(inheritedPredicate);

    forJoin.addAll(filter(extractConjuncts(joinPredicate), not(DeterminismEvaluator::isDeterministic)));
    Expression join = filterDeterministicConjuncts(joinPredicate);

    Expression leftEffective = filterDeterministicConjuncts(leftEffectivePredicate);
    Expression rightEffective = filterDeterministicConjuncts(rightEffectivePredicate);

    // Generate equality inferences.
    // NOTE(review): presumably this derives transitive equalities (e.g. a = b and b = c give
    // a = c) — confirm against EqualityInference.
    EqualityInference allInference = createEqualityInference(inherited, leftEffective, rightEffective, join);
    EqualityInference inferenceWithoutLeft = createEqualityInference(inherited, rightEffective, join);
    EqualityInference inferenceWithoutRight = createEqualityInference(inherited, leftEffective, join);

    // Sort through conjuncts in the inherited predicate that were not used for inference.
    // A conjunct stays on the join only if it cannot be rewritten for either side.
    for (Expression inheritedConjunct : EqualityInference.nonInferrableConjuncts(inherited)) {
        routeConjunct(inheritedConjunct, allInference, leftSymbols, forLeft, forRight, forJoin);
    }

    // See if we can push the right effective predicate to the left side
    for (Expression rightConjunct : EqualityInference.nonInferrableConjuncts(rightEffective)) {
        Expression onLeft = allInference.rewriteExpression(rightConjunct, in(leftSymbols));
        if (onLeft != null) {
            forLeft.add(onLeft);
        }
    }

    // See if we can push the left effective predicate to the right side
    for (Expression leftConjunct : EqualityInference.nonInferrableConjuncts(leftEffective)) {
        Expression onRight = allInference.rewriteExpression(leftConjunct, not(in(leftSymbols)));
        if (onRight != null) {
            forRight.add(onRight);
        }
    }

    // Take the join predicate apart and see which of its conjuncts can be pushed to either input
    for (Expression joinConjunct : EqualityInference.nonInferrableConjuncts(join)) {
        routeConjunct(joinConjunct, allInference, leftSymbols, forLeft, forRight, forJoin);
    }

    // Add the equalities generated by inference: equalities scoped to one side go to that
    // side, and scope straddling equalities get dropped in as part of the join predicate.
    forLeft.addAll(inferenceWithoutLeft.generateEqualitiesPartitionedBy(in(leftSymbols)).getScopeEqualities());
    forRight.addAll(inferenceWithoutRight.generateEqualitiesPartitionedBy(not(in(leftSymbols))).getScopeEqualities());
    forJoin.addAll(allInference.generateEqualitiesPartitionedBy(in(leftSymbols)::apply).getScopeStraddlingEqualities());

    Expression leftResult = combineConjuncts(forLeft.build());
    Expression rightResult = combineConjuncts(forRight.build());
    Expression joinResult = combineConjuncts(forJoin.build());
    return new InnerJoinPushDownResult(leftResult, rightResult, joinResult, TRUE_LITERAL);
}

// Tries to rewrite one conjunct for the left side and for the right side, adding each successful
// rewrite to that side's builder; the conjunct is kept for the join only when neither rewrite succeeds.
private static void routeConjunct(
        Expression conjunct,
        EqualityInference inference,
        Collection<Symbol> leftSymbols,
        ImmutableList.Builder<Expression> forLeft,
        ImmutableList.Builder<Expression> forRight,
        ImmutableList.Builder<Expression> forJoin)
{
    Expression onLeft = inference.rewriteExpression(conjunct, in(leftSymbols));
    if (onLeft != null) {
        forLeft.add(onLeft);
    }

    Expression onRight = inference.rewriteExpression(conjunct, not(in(leftSymbols)));
    if (onRight != null) {
        forRight.add(onRight);
    }

    if (onLeft == null && onRight == null) {
        forJoin.add(conjunct);
    }
}