[groovy] branch master updated: Minor refactoring: simplify parallel querying

Previous Topic Next Topic
 
classic Classic list List threaded Threaded
1 message Options
Reply | Threaded
Open this post in threaded view
|

[groovy] branch master updated: Minor refactoring: simplify parallel querying

Daniel.Sun
This is an automated email from the ASF dual-hosted git repository.

sunlan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/groovy.git


The following commit(s) were added to refs/heads/master by this push:
     new 2ec2b9a  Minor refactoring: simplify parallel querying
2ec2b9a is described below

commit 2ec2b9adc6fae786c3237d6dc8478e7a45be722d
Author: Daniel Sun <[hidden email]>
AuthorDate: Sun Jan 10 14:56:36 2021 +0800

    Minor refactoring: simplify parallel querying
---
 .../ginq/provider/collection/GinqAstWalker.groovy  | 69 +++++++++++++---------
 .../collection/runtime/QueryableCollection.java    | 57 +++---------------
 .../collection/runtime/QueryableHelper.groovy      | 14 +++++
 3 files changed, 64 insertions(+), 76 deletions(-)

diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
index 1182761..a3b8b64 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
@@ -89,13 +89,13 @@ import static org.codehaus.groovy.ast.tools.GeneralUtils.fieldX
 import static org.codehaus.groovy.ast.tools.GeneralUtils.lambdaX
 import static org.codehaus.groovy.ast.tools.GeneralUtils.listX
 import static org.codehaus.groovy.ast.tools.GeneralUtils.localVarX
+import static org.codehaus.groovy.ast.tools.GeneralUtils.nullX
 import static org.codehaus.groovy.ast.tools.GeneralUtils.param
 import static org.codehaus.groovy.ast.tools.GeneralUtils.params
 import static org.codehaus.groovy.ast.tools.GeneralUtils.propX
 import static org.codehaus.groovy.ast.tools.GeneralUtils.returnS
 import static org.codehaus.groovy.ast.tools.GeneralUtils.stmt
 import static org.codehaus.groovy.ast.tools.GeneralUtils.varX
-
 /**
  * Visit AST of GINQ to generate target method calls for GINQ
  *
@@ -719,6 +719,7 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
             lambdaCode = namedListCtorCallExpression
         }
 
+        final boolean parallel = isParallel()
         lambdaCode = ((ListExpression) new ListExpression(Collections.singletonList(lambdaCode)).transformExpression(new ExpressionTransformer() {
             @Override
             Expression transform(Expression expression) {
@@ -785,13 +786,9 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
                                     }
                                 }
 
-                                def parallel = isParallel()
-                                final supplyAsyncLambdaParamName = "__salp${System.nanoTime()}"
-
                                 def windowDefinitionFactoryMethodCallExpression = constructWindowDefinitionFactoryMethodCallExpression(expression, dataSourceExpression)
-                                def rowNumberGetMethodCall = callX(varX(rowNumberName), 'get')
                                 Expression newObjectExpression = callX(wqVar, 'over', args(
-                                        callX(TUPLE_TYPE, 'tuple', args(currentRecordVar, parallel ? varX(supplyAsyncLambdaParamName) : rowNumberGetMethodCall)),
+                                        callX(TUPLE_TYPE, 'tuple', args(currentRecordVar, parallel ? varX(supplyAsyncLambdaParamName) : getRowNumberMethodCall())),
                                         windowDefinitionFactoryMethodCallExpression
                                 ))
                                 result = callX(
@@ -799,22 +796,6 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
                                         windowFunctionMethodCallExpression.methodAsString,
                                         args(argumentExpressionList)
                                 )
-
-                                if (parallel) {
-                                    result = callX(
-                                            new ClassExpression(QUERYABLE_HELPER_TYPE),
-                                            "supplyAsync",
-                                            args(
-                                                    lambdaX(
-                                                            params(
-                                                                    param(DYNAMIC_TYPE, supplyAsyncLambdaParamName)
-                                                            ),
-                                                            stmt(result)
-                                                    ),
-                                                    rowNumberGetMethodCall
-                                            )
-                                    )
-                                }
                             } else {
                                 GinqAstWalker.this.collectSyntaxError(new GinqSyntaxError(
                                         "Unsupported window function: `${windowFunctionMethodCallExpression.methodAsString}`",
@@ -839,13 +820,18 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
             extra << callX(varX(rowNumberName), 'getAndIncrement')
         }
 
-        def selectMethodCallExpression = callXWithLambda(selectMethodReceiver, "select", dataSourceExpression, lambdaCode, extra, param(DYNAMIC_TYPE, getWindowQueryableName()))
+        def selectMethodCallExpression = callXWithLambda(selectMethodReceiver, "select", dataSourceExpression, parallel, lambdaCode, extra, param(DYNAMIC_TYPE, getWindowQueryableName()))
 
         currentGinqExpression.putNodeMetaData(__VISITING_SELECT, false)
 
         return selectMethodCallExpression
     }
 
+    // Builds a new `<rowNumberName>.get()` call expression each time it is called
+    private MethodCallExpression getRowNumberMethodCall() {
+        return callX(varX(rowNumberName), 'get')
+    }
+
     private MethodCallExpression constructWindowDefinitionFactoryMethodCallExpression(MethodCallExpression methodCallExpression, DataSourceExpression dataSourceExpression) {
         Expression classifierExpr = null
         Expression orderExpr = null
@@ -1118,6 +1104,18 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
         return name
     }
 
+    private int supplyAsyncLambdaParamNameSeq = 0 // incremented once per newly generated name
+    private String getSupplyAsyncLambdaParamName() { // name cached on the current GINQ expression, generated on first use
+        String cached = (String) currentGinqExpression.getNodeMetaData(SUPPLY_ASYNC_LAMBDA_PARAM_NAME_PREFIX)
+        if (cached) {
+            return cached
+        }
+
+        String generated = SUPPLY_ASYNC_LAMBDA_PARAM_NAME_PREFIX + supplyAsyncLambdaParamNameSeq++
+        currentGinqExpression.putNodeMetaData(SUPPLY_ASYNC_LAMBDA_PARAM_NAME_PREFIX, generated)
+        return generated
+    }
+
     private MethodCallExpression getMetaDataMethodCall(String key) {
         callX(varX(metaDataMapName), "get", new ConstantExpression(key))
     }
@@ -1367,8 +1365,8 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
         this.callXWithLambda(receiver, methodName, dataSourceExpression, lambdaCode, Collections.emptyList(), extraParams)
     }
 
-    private MethodCallExpression callXWithLambda(Expression receiver, String methodName, DataSourceExpression dataSourceExpression, Expression lambdaCode, List<Expression> extraLambdaCode, Parameter... extraParams) {
-        LambdaExpression lambdaExpression = constructLambdaExpression(dataSourceExpression, lambdaCode, extraParams)
+    private MethodCallExpression callXWithLambda(Expression receiver, String methodName, DataSourceExpression dataSourceExpression, boolean async = false, Expression lambdaCode, List<Expression> extraLambdaCode, Parameter... extraParams) {
+        LambdaExpression lambdaExpression = constructLambdaExpression(dataSourceExpression, async, lambdaCode, extraParams)
 
         if (extraLambdaCode) {
             ((BlockStatement) lambdaExpression.code).getStatements().addAll(0, extraLambdaCode.collect { stmt(it) })
@@ -1377,7 +1375,7 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
         callXWithLambda(receiver, methodName, lambdaExpression)
     }
 
-    private LambdaExpression constructLambdaExpression(DataSourceExpression dataSourceExpression, Expression lambdaCode, Parameter... extraParams) {
+    private LambdaExpression constructLambdaExpression(DataSourceExpression dataSourceExpression, boolean async = false, Expression lambdaCode, Parameter... extraParams) {
         Tuple3<String, List<DeclarationExpression>, Expression> paramNameAndLambdaCode = correctVariablesOfLambdaExpression(dataSourceExpression, lambdaCode)
 
         List<DeclarationExpression> declarationExpressionList = paramNameAndLambdaCode.v2
@@ -1385,7 +1383,23 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
         if (!visitingWindowFunction) {
             statementList.addAll(declarationExpressionList.stream().map(e -> stmt(e)).collect(Collectors.toList()))
         }
-        statementList.add(stmt(paramNameAndLambdaCode.v3))
+
+        def transformedLambdCode = paramNameAndLambdaCode.v3
+        if (async) {
+            ArgumentListExpression  argumentListExpression
+            argumentListExpression = rowNumberUsed
+                                        ? args(lambdaX(params(param(DYNAMIC_TYPE, supplyAsyncLambdaParamName)),
+                                                        stmt(transformedLambdCode)),
+                                                rowNumberUsed ? getRowNumberMethodCall() : nullX())
+                                        : args(lambdaX(stmt(transformedLambdCode)))
+
+            transformedLambdCode = callX(
+                    new ClassExpression(QUERYABLE_HELPER_TYPE),
+                    "supplyAsync",
+                    argumentListExpression
+            )
+        }
+        statementList.add(stmt(transformedLambdCode))
 
         def paramList = [param(DYNAMIC_TYPE, paramNameAndLambdaCode.v1)]
         if (extraParams) {
@@ -1564,6 +1578,7 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
     private static final String __META_DATA_MAP_NAME_PREFIX = '__metaDataMap_'
     private static final String __WINDOW_QUERYABLE_NAME = '__wq_'
     private static final String __ROW_NUMBER_NAME_PREFIX = '__rowNumber_'
+    private static final String SUPPLY_ASYNC_LAMBDA_PARAM_NAME_PREFIX = "__salp_"
     private static final String __SOURCE_RECORD = "__sourceRecord"
     private static final String __GROUP = "__group"
     private static final String MD_GROUP_NAME_LIST = "groupNameList"
diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java
index 4db5412..6e3e621 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java
@@ -30,8 +30,6 @@ import org.codehaus.groovy.runtime.dgmimpl.NumberNumberMinus;
 import org.codehaus.groovy.runtime.typehandling.NumberMath;
 
 import java.io.Serializable;
-import java.lang.reflect.Constructor;
-import java.lang.reflect.Method;
 import java.math.BigDecimal;
 import java.math.RoundingMode;
 import java.util.ArrayList;
@@ -279,55 +277,31 @@ class QueryableCollection<T> implements Queryable<T>, Serializable {
 
     @Override
     public <U> Queryable<U> select(BiFunction<? super T, ? super Queryable<? extends T>, ? extends U> mapper) {
-        String originalParallel = null;
+        final String originalParallel = QueryableHelper.getVar(PARALLEL);
+        QueryableHelper.setVar(PARALLEL, FALSE_STR); // ensure the row number is generated sequentially
         boolean useWindowFunction = TRUE_STR.equals(QueryableHelper.getVar(USE_WINDOW_FUNCTION));
 
         if (useWindowFunction) {
-            originalParallel = QueryableHelper.getVar(PARALLEL);
-            QueryableHelper.setVar(PARALLEL, FALSE_STR); // ensure the row number is generated sequentially
             this.makeReusable();
         }
 
         Stream<U> stream = this.stream().map((T t) -> mapper.apply(t, this));
-        if (useWindowFunction && TRUE_STR.equals(originalParallel)) {
+        if (TRUE_STR.equals(originalParallel)) {
             // invoke `collect` to trigger the intermediate operator, which will create `CompletableFuture` instances
             stream = stream.collect(Collectors.toList()).parallelStream().map((U u) -> {
-                Function<? super U, ?> transform = e -> {
-                    try {
-                        return e instanceof CompletableFuture ? ((CompletableFuture) e).get() : e;
-                    } catch (InterruptedException | ExecutionException ex) {
-                        throw new GroovyRuntimeException(ex);
-                    }
-                };
-
-                if (instanceOfNamedRecord(u)) {
-                    Tuple record = (Tuple) u;
-                    List<?> transformed = (List<?>) record.stream().map(transform).collect(Collectors.toList());
-                    try {
-                        List<String> nameList = (List<String>) GET_NAME_SET_METHOD.invoke(record);
-                        List<String> aliasList = (List<String>) GET_ALIAS_LIST_METHOD.invoke(record);
-                        return (U) NAMED_RECORD_CONSTRUCTOR.newInstance(transformed, nameList, aliasList);
-                    } catch (ReflectiveOperationException ex) {
-                        throw new GroovyRuntimeException(ex);
-                    }
+                try {
+                    return (U) ((CompletableFuture) u).get();
+                } catch (InterruptedException | ExecutionException ex) {
+                    throw new GroovyRuntimeException(ex);
                 }
-
-                return (U) transform.apply(u);
             });
         }
 
-        if (useWindowFunction) {
-            QueryableHelper.setVar(PARALLEL, originalParallel);
-        }
+        QueryableHelper.setVar(PARALLEL, originalParallel);
 
         return from(stream);
     }
 
-    private static <U> boolean instanceOfNamedRecord(U u) {
-        // workaround joint compilation issue
-        return (u instanceof Tuple) && NAMED_RECORD_CLASS_NAME.equals(u.getClass().getName());
-    }
-
     @Override
     public Queryable<T> distinct() {
         Stream<T> stream = this.stream().distinct();
@@ -736,24 +710,9 @@ class QueryableCollection<T> implements Queryable<T>, Serializable {
     private final Lock readLock = rwl.readLock();
     private final Lock writeLock = rwl.writeLock();
     private static final BigDecimal BD_TWO = BigDecimal.valueOf(2);
-    private static final String NAMED_RECORD_CLASS_NAME = "org.apache.groovy.ginq.provider.collection.runtime.NamedRecord";
-    private static final Constructor<?> NAMED_RECORD_CONSTRUCTOR;
-    private static final Method GET_ALIAS_LIST_METHOD;
-    private static final Method GET_NAME_SET_METHOD;
     private static final String USE_WINDOW_FUNCTION = "useWindowFunction";
     private static final String PARALLEL = "parallel";
     private static final String TRUE_STR = "true";
     private static final String FALSE_STR = "false";
     private static final long serialVersionUID = -5067092453136522893L;
-
-    static {
-        try {
-            final Class<?> namedRecordClass = Class.forName(NAMED_RECORD_CLASS_NAME);
-            NAMED_RECORD_CONSTRUCTOR = namedRecordClass.getConstructor(List.class, List.class, List.class);
-            GET_ALIAS_LIST_METHOD = namedRecordClass.getDeclaredMethod("getAliasList");
-            GET_NAME_SET_METHOD = namedRecordClass.getMethod("getNameList");
-        } catch (NoSuchMethodException | ClassNotFoundException ex) {
-            throw new GroovyRuntimeException(ex);
-        }
-    }
 }
diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableHelper.groovy b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableHelper.groovy
index cfc6c77..bb458d5 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableHelper.groovy
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableHelper.groovy
@@ -24,6 +24,7 @@ import java.util.concurrent.CompletableFuture
 import java.util.concurrent.ExecutorService
 import java.util.concurrent.Executors
 import java.util.function.Function
+import java.util.function.Supplier
 import java.util.stream.Collectors
 
 import static org.apache.groovy.ginq.provider.collection.runtime.Queryable.from
@@ -74,6 +75,10 @@ class QueryableHelper {
         throw new TooManyValuesException("subquery returns more than one value: $list")
     }
 
+    static <U> CompletableFuture<U> supplyAsync(Supplier<U> supplier) { // run `supplier` asynchronously on THREAD_POOL
+        CompletableFuture.supplyAsync(supplier, THREAD_POOL)
+    }
+
     static <T, U> CompletableFuture<U> supplyAsync(Function<? super T, ? extends U> function, T param) {
         return CompletableFuture.supplyAsync(() -> { function.apply(param) }, THREAD_POOL)
     }
@@ -99,5 +104,14 @@ class QueryableHelper {
     private static final String PARALLEL = "parallel"
     private static final String TRUE_STR = "true"
 
+    static { // registers a JVM shutdown hook that stops THREAD_POOL
+        Runtime.addShutdownHook {
+            try {
+                THREAD_POOL.shutdownNow()
+            } catch (ignored) { // deliberately ignored: best-effort cleanup while the JVM is shutting down
+            }
+        }
+    }
+
     private QueryableHelper() {}
 }

Apache Groovy committer & PMC member

Blog: http://blog.sunlan.me
Twitter: @daniel_sun