[groovy] branch master updated: Tweak parallel querying in GINQ further

Previous Topic Next Topic
 
Classic | List | Threaded
1 message Options
Reply | Threaded
Open this post in threaded view
|

[groovy] branch master updated: Tweak parallel querying in GINQ further

Daniel.Sun
This is an automated email from the ASF dual-hosted git repository.

sunlan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/groovy.git


The following commit(s) were added to refs/heads/master by this push:
     new 4781630  Tweak parallel querying in GINQ further
4781630 is described below

commit 478163025d5747d9fe3199bd6ad9deda20e4b110
Author: Daniel Sun <[hidden email]>
AuthorDate: Sun Jan 10 01:54:16 2021 +0800

    Tweak parallel querying in GINQ further
---
 .../ginq/provider/collection/GinqAstWalker.groovy  | 35 +++++++++++++++-------
 .../collection/runtime/QueryableCollection.java    |  3 +-
 .../collection/runtime/QueryableHelper.groovy      |  7 +++--
 3 files changed, 30 insertions(+), 15 deletions(-)

diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
index 749b3bb..1182761 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
@@ -45,7 +45,6 @@ import org.apache.groovy.ginq.provider.collection.runtime.ValueBound
 import org.apache.groovy.ginq.provider.collection.runtime.WindowDefinition
 import org.apache.groovy.util.Maps
 import org.codehaus.groovy.GroovyBugError
-import org.codehaus.groovy.ast.ClassHelper
 import org.codehaus.groovy.ast.ClassNode
 import org.codehaus.groovy.ast.CodeVisitorSupport
 import org.codehaus.groovy.ast.Parameter
@@ -77,6 +76,7 @@ import java.util.function.Consumer
 import java.util.stream.Collectors
 
 import static groovy.lang.Tuple.tuple
+import static org.codehaus.groovy.ast.ClassHelper.DYNAMIC_TYPE
 import static org.codehaus.groovy.ast.ClassHelper.makeCached
 import static org.codehaus.groovy.ast.ClassHelper.makeWithoutCaching
 import static org.codehaus.groovy.ast.tools.GeneralUtils.args
@@ -95,6 +95,7 @@ import static org.codehaus.groovy.ast.tools.GeneralUtils.propX
 import static org.codehaus.groovy.ast.tools.GeneralUtils.returnS
 import static org.codehaus.groovy.ast.tools.GeneralUtils.stmt
 import static org.codehaus.groovy.ast.tools.GeneralUtils.varX
+
 /**
  * Visit AST of GINQ to generate target method calls for GINQ
  *
@@ -410,14 +411,14 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
             statementList.add(stmt(listX(leftExpressionList)))
             argumentExpressionList << lambdaX(
                     params(
-                            param(ClassHelper.DYNAMIC_TYPE, otherParamName)
+                            param(DYNAMIC_TYPE, otherParamName)
                     ),
                     block(statementList as Statement[])
             )
 
             argumentExpressionList << lambdaX(
                     params(
-                            param(ClassHelper.DYNAMIC_TYPE, joinExpression.aliasExpr.text)
+                            param(DYNAMIC_TYPE, joinExpression.aliasExpr.text)
                     ),
                     block(stmt(listX(rightExpressionList)))
             )
@@ -425,8 +426,8 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
             statementList.add(stmt(filterExpr))
             argumentExpressionList << (null == onExpression ? EmptyExpression.INSTANCE : lambdaX(
                     params(
-                            param(ClassHelper.DYNAMIC_TYPE, otherParamName),
-                            param(ClassHelper.DYNAMIC_TYPE, joinExpression.aliasExpr.text)
+                            param(DYNAMIC_TYPE, otherParamName),
+                            param(DYNAMIC_TYPE, joinExpression.aliasExpr.text)
                     ),
                     block(statementList as Statement[])))
         }
@@ -772,7 +773,7 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
                                     }))).getExpression(0)
 
                                     argumentExpressionList << lambdaX(
-                                            params(param(ClassHelper.DYNAMIC_TYPE, windowFunctionLambdaName)),
+                                            params(param(DYNAMIC_TYPE, windowFunctionLambdaName)),
                                             block(stmt(windowFunctionLambdaCode))
                                     )
 
@@ -784,9 +785,13 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
                                     }
                                 }
 
+                                def parallel = isParallel()
+                                final supplyAsyncLambdaParamName = "__salp${System.nanoTime()}"
+
                                 def windowDefinitionFactoryMethodCallExpression = constructWindowDefinitionFactoryMethodCallExpression(expression, dataSourceExpression)
+                                def rowNumberGetMethodCall = callX(varX(rowNumberName), 'get')
                                 Expression newObjectExpression = callX(wqVar, 'over', args(
-                                        callX(TUPLE_TYPE, 'tuple', args(currentRecordVar, callX(varX(rowNumberName), 'get'))),
+                                        callX(TUPLE_TYPE, 'tuple', args(currentRecordVar, parallel ? varX(supplyAsyncLambdaParamName) : rowNumberGetMethodCall)),
                                         windowDefinitionFactoryMethodCallExpression
                                 ))
                                 result = callX(
@@ -795,11 +800,19 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
                                         args(argumentExpressionList)
                                 )
 
-                                if (isParallel()) {
+                                if (parallel) {
                                     result = callX(
                                             new ClassExpression(QUERYABLE_HELPER_TYPE),
                                             "supplyAsync",
-                                            lambdaX(stmt(result))
+                                            args(
+                                                    lambdaX(
+                                                            params(
+                                                                    param(DYNAMIC_TYPE, supplyAsyncLambdaParamName)
+                                                            ),
+                                                            stmt(result)
+                                                    ),
+                                                    rowNumberGetMethodCall
+                                            )
                                     )
                                 }
                             } else {
@@ -826,7 +839,7 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
             extra << callX(varX(rowNumberName), 'getAndIncrement')
         }
 
-        def selectMethodCallExpression = callXWithLambda(selectMethodReceiver, "select", dataSourceExpression, lambdaCode, extra, param(ClassHelper.DYNAMIC_TYPE, getWindowQueryableName()))
+        def selectMethodCallExpression = callXWithLambda(selectMethodReceiver, "select", dataSourceExpression, lambdaCode, extra, param(DYNAMIC_TYPE, getWindowQueryableName()))
 
         currentGinqExpression.putNodeMetaData(__VISITING_SELECT, false)
 
@@ -1374,7 +1387,7 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
         }
         statementList.add(stmt(paramNameAndLambdaCode.v3))
 
-        def paramList = [param(ClassHelper.DYNAMIC_TYPE, paramNameAndLambdaCode.v1)]
+        def paramList = [param(DYNAMIC_TYPE, paramNameAndLambdaCode.v1)]
         if (extraParams) {
             paramList.addAll(Arrays.asList(extraParams))
         }
diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java
index 2447fb1..5a279da 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java
@@ -290,7 +290,8 @@ class QueryableCollection<T> implements Queryable<T>, Serializable {
 
         Stream<U> stream = this.stream().map((T t) -> mapper.apply(t, this));
         if (useWindowFunction && TRUE_STR.equals(originalParallel)) {
-            stream = stream.map((U u) -> {
+            // invoke `collect` to trigger the intermediate operator, which will create `CompletableFuture` instances
+            stream = stream.collect(Collectors.toList()).stream().map((U u) -> {
                 Function<? super U, ?> transform = e -> {
                     try {
                         return e instanceof CompletableFuture ? ((CompletableFuture) e).get() : e;
diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableHelper.groovy b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableHelper.groovy
index 2543cf9..cfc6c77 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableHelper.groovy
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableHelper.groovy
@@ -23,10 +23,11 @@ import groovy.transform.CompileStatic
 import java.util.concurrent.CompletableFuture
 import java.util.concurrent.ExecutorService
 import java.util.concurrent.Executors
-import java.util.function.Supplier
+import java.util.function.Function
 import java.util.stream.Collectors
 
 import static org.apache.groovy.ginq.provider.collection.runtime.Queryable.from
+
 /**
  * Helper for {@link Queryable}
  *
@@ -73,8 +74,8 @@ class QueryableHelper {
         throw new TooManyValuesException("subquery returns more than one value: $list")
     }
 
-    static <U> CompletableFuture<U> supplyAsync(Supplier<U> supplier) {
-        return CompletableFuture.supplyAsync(supplier, THREAD_POOL)
+    static <T, U> CompletableFuture<U> supplyAsync(Function<? super T, ? extends U> function, T param) {
+        return CompletableFuture.supplyAsync(() -> { function.apply(param) }, THREAD_POOL)
     }
 
     static boolean isParallel() {

Apache Groovy committer & PMC member

Blog: http://blog.sunlan.me
Twitter: @daniel_sun