[groovy] branch master updated: Tweak parallel querying in GINQ

Previous Topic Next Topic
 
Classic | List | Threaded
1 message Options
Reply | Threaded
Open this post in threaded view
|

[groovy] branch master updated: Tweak parallel querying in GINQ

Daniel.Sun
This is an automated email from the ASF dual-hosted git repository.

sunlan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/groovy.git


The following commit(s) were added to refs/heads/master by this push:
     new 7dedf36  Tweak parallel querying in GINQ
7dedf36 is described below

commit 7dedf360505e8a2df574f4349c4ba803b13b7f84
Author: Daniel Sun <[hidden email]>
AuthorDate: Sun Jan 10 00:23:58 2021 +0800

    Tweak parallel querying in GINQ
---
 .../ginq/provider/collection/GinqAstWalker.groovy  | 22 ++++++---
 .../collection/runtime/AsciiTableMaker.groovy      |  2 +-
 .../provider/collection/runtime/NamedRecord.groovy |  4 ++
 .../provider/collection/runtime/NamedTuple.groovy  |  6 +--
 .../collection/runtime/QueryableCollection.java    | 52 +++++++++++++++++++++-
 .../collection/runtime/QueryableHelper.groovy      | 16 ++++++-
 .../test/org/apache/groovy/ginq/GinqTest.groovy    | 48 ++++++++++++++++++++
 7 files changed, 139 insertions(+), 11 deletions(-)

diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
index fd8ef9a..749b3bb 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
@@ -174,13 +174,13 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
         MethodCallExpression selectMethodCallExpression = this.visitSelectExpression(selectExpression)
 
         List<Statement> statementList = []
-        boolean useWindowFunction = isUseWindowFunction(selectExpression)
+        boolean isRootGinqExpression = ginqExpression === ginqExpression.getNodeMetaData(GinqAstBuilder.ROOT_GINQ_EXPRESSION)
+        boolean useWindowFunction = isRootGinqExpression && isUseWindowFunction(ginqExpression)
         if (useWindowFunction) {
             statementList << stmt(callX(QUERYABLE_HELPER_TYPE, 'setVar', args(new ConstantExpression(USE_WINDOW_FUNCTION), new ConstantExpression(TRUE_STR))))
         }
 
-        boolean isRootGinqExpression = ginqExpression === ginqExpression.getNodeMetaData(GinqAstBuilder.ROOT_GINQ_EXPRESSION)
-        boolean parallelEnabled = isRootGinqExpression && TRUE_STR == configuration.get(GinqGroovyMethods.CONF_PARALLEL)
+        boolean parallelEnabled = isRootGinqExpression && isParallel()
         if (parallelEnabled) {
             statementList << stmt(callX(QUERYABLE_HELPER_TYPE, 'setVar', args(new ConstantExpression(PARALLEL), new ConstantExpression(TRUE_STR))))
         }
@@ -214,9 +214,13 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
         return result
     }
 
-    private boolean isUseWindowFunction(SelectExpression selectExpression) {
+    private boolean isParallel() {
+        return TRUE_STR == configuration.get(GinqGroovyMethods.CONF_PARALLEL)
+    }
+
+    private boolean isUseWindowFunction(GinqExpression ginqExpression) {
         boolean useWindowFunction = false
-        selectExpression.projectionExpr.visit(new GinqAstBaseVisitor() {
+        ginqExpression.visit(new GinqAstBaseVisitor() {
             @Override
             void visitMethodCallExpression(MethodCallExpression call) {
                 if (isOverMethodCall(call)) {
@@ -790,6 +794,14 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
                                         windowFunctionMethodCallExpression.methodAsString,
                                         args(argumentExpressionList)
                                 )
+
+                                if (isParallel()) {
+                                    result = callX(
+                                            new ClassExpression(QUERYABLE_HELPER_TYPE),
+                                            "supplyAsync",
+                                            lambdaX(stmt(result))
+                                    )
+                                }
                             } else {
                                 GinqAstWalker.this.collectSyntaxError(new GinqSyntaxError(
                                         "Unsupported window function: `${windowFunctionMethodCallExpression.methodAsString}`",
diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/AsciiTableMaker.groovy b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/AsciiTableMaker.groovy
index dd379f6..0c0fafc 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/AsciiTableMaker.groovy
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/AsciiTableMaker.groovy
@@ -47,7 +47,7 @@ class AsciiTableMaker {
             List<String[]> list = new ArrayList<>(tableData.size() + 1)
             def firstRecord = tableData.get(0)
             if (firstRecord instanceof NamedRecord) {
-                list.add(((NamedRecord) firstRecord).nameSet as String[])
+                list.add(((NamedRecord) firstRecord).nameList.toArray(String[]::new))
                 tableData.stream().forEach(e -> {
                     if (e instanceof NamedRecord) {
                         String[] record = ((List) e).stream().map(c -> c?.toString()).toArray(String[]::new)
diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/NamedRecord.groovy b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/NamedRecord.groovy
index f99551f..a442eb8 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/NamedRecord.groovy
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/NamedRecord.groovy
@@ -47,6 +47,10 @@ class NamedRecord<E, T> extends NamedTuple<E> {
         return sourceRecord.get(name)
     }
 
+    List<String> getAliasList() {
+        return Collections.unmodifiableList(aliasList)
+    }
+
     NamedRecord<E, T> sourceRecord(T sr) {
         this.sourceRecord = new SourceRecord<>(sr, aliasList)
         return this
diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/NamedTuple.groovy b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/NamedTuple.groovy
index d7dbe03..4afcadb 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/NamedTuple.groovy
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/NamedTuple.groovy
@@ -59,13 +59,13 @@ class NamedTuple<E> extends Tuple<E> {
         return data.containsKey(name)
     }
 
-    Set<String> getNameSet() {
-        return Collections.unmodifiableSet(data.keySet())
+    List<String> getNameList() {
+        return Collections.unmodifiableList(data.keySet().toList())
     }
 
     @Override
     String toString() {
-        '(' + nameSet.withIndex()
+        '(' + nameList.withIndex()
                 .collect((String n, int i) -> { "${n}:${this[i]}" })
                 .join(', ') + ')'
     }
diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java
index 534b714..2447fb1 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java
@@ -18,6 +18,7 @@
  */
 package org.apache.groovy.ginq.provider.collection.runtime;
 
+import groovy.lang.GroovyRuntimeException;
 import groovy.lang.Tuple;
 import groovy.lang.Tuple2;
 import groovy.lang.Tuple3;
@@ -29,6 +30,8 @@ import org.codehaus.groovy.runtime.dgmimpl.NumberNumberMinus;
 import org.codehaus.groovy.runtime.typehandling.NumberMath;
 
 import java.io.Serializable;
+import java.lang.reflect.Constructor;
+import java.lang.reflect.Method;
 import java.math.BigDecimal;
 import java.math.RoundingMode;
 import java.util.ArrayList;
@@ -41,7 +44,9 @@ import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ExecutionException;
 import java.util.concurrent.locks.Lock;
 import java.util.concurrent.locks.ReadWriteLock;
 import java.util.concurrent.locks.ReentrantReadWriteLock;
@@ -284,6 +289,31 @@ class QueryableCollection<T> implements Queryable<T>, Serializable {
         }
 
         Stream<U> stream = this.stream().map((T t) -> mapper.apply(t, this));
+        if (useWindowFunction && TRUE_STR.equals(originalParallel)) {
+            stream = stream.map((U u) -> {
+                Function<? super U, ?> transform = e -> {
+                    try {
+                        return e instanceof CompletableFuture ? ((CompletableFuture) e).get() : e;
+                    } catch (InterruptedException | ExecutionException ex) {
+                        throw new GroovyRuntimeException(ex);
+                    }
+                };
+
+                if (instanceOfNamedRecord(u)) {
+                    Tuple record = (Tuple) u;
+                    List<?> transformed = (List<?>) record.stream().map(transform).collect(Collectors.toList());
+                    try {
+                        List<String> nameList = (List<String>) GET_NAME_SET_METHOD.invoke(record);
+                        List<String> aliasList = (List<String>) GET_ALIAS_LIST_METHOD.invoke(record);
+                        return (U) NAMED_RECORD_CONSTRUCTOR.newInstance(transformed, nameList, aliasList);
+                    } catch (ReflectiveOperationException ex) {
+                        throw new GroovyRuntimeException(ex);
+                    }
+                }
+
+                return (U) transform.apply(u);
+            });
+        }
 
         if (useWindowFunction) {
             QueryableHelper.setVar(PARALLEL, originalParallel);
@@ -292,6 +322,11 @@ class QueryableCollection<T> implements Queryable<T>, Serializable {
         return from(stream);
     }
 
+    private static <U> boolean instanceOfNamedRecord(U u) {
+        // workaround joint compilation issue
+        return (u instanceof Tuple) && NAMED_RECORD_CLASS_NAME.equals(u.getClass().getName());
+    }
+
     @Override
     public Queryable<T> distinct() {
         Stream<T> stream = this.stream().distinct();
@@ -624,7 +659,7 @@ class QueryableCollection<T> implements Queryable<T>, Serializable {
     }
 
     private static boolean isParallel() {
-        return TRUE_STR.equals(QueryableHelper.getVar(PARALLEL));
+        return QueryableHelper.isParallel();
     }
 
     private boolean isReusable() {
@@ -700,9 +735,24 @@ class QueryableCollection<T> implements Queryable<T>, Serializable {
     private final Lock readLock = rwl.readLock();
     private final Lock writeLock = rwl.writeLock();
     private static final BigDecimal BD_TWO = BigDecimal.valueOf(2);
+    private static final String NAMED_RECORD_CLASS_NAME = "org.apache.groovy.ginq.provider.collection.runtime.NamedRecord";
+    private static final Constructor<?> NAMED_RECORD_CONSTRUCTOR;
+    private static final Method GET_ALIAS_LIST_METHOD;
+    private static final Method GET_NAME_SET_METHOD;
     private static final String USE_WINDOW_FUNCTION = "useWindowFunction";
     private static final String PARALLEL = "parallel";
     private static final String TRUE_STR = "true";
     private static final String FALSE_STR = "false";
     private static final long serialVersionUID = -5067092453136522893L;
+
+    static {
+        try {
+            final Class<?> namedRecordClass = Class.forName(NAMED_RECORD_CLASS_NAME);
+            NAMED_RECORD_CONSTRUCTOR = namedRecordClass.getConstructor(List.class, List.class, List.class);
+            GET_ALIAS_LIST_METHOD = namedRecordClass.getDeclaredMethod("getAliasList");
+            GET_NAME_SET_METHOD = namedRecordClass.getMethod("getNameList");
+        } catch (NoSuchMethodException | ClassNotFoundException ex) {
+            throw new GroovyRuntimeException(ex);
+        }
+    }
 }
diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableHelper.groovy b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableHelper.groovy
index 179d026..2543cf9 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableHelper.groovy
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableHelper.groovy
@@ -20,10 +20,13 @@ package org.apache.groovy.ginq.provider.collection.runtime
 
 import groovy.transform.CompileStatic
 
+import java.util.concurrent.CompletableFuture
+import java.util.concurrent.ExecutorService
+import java.util.concurrent.Executors
+import java.util.function.Supplier
 import java.util.stream.Collectors
 
 import static org.apache.groovy.ginq.provider.collection.runtime.Queryable.from
-
 /**
  * Helper for {@link Queryable}
  *
@@ -70,6 +73,14 @@ class QueryableHelper {
         throw new TooManyValuesException("subquery returns more than one value: $list")
     }
 
+    static <U> CompletableFuture<U> supplyAsync(Supplier<U> supplier) {
+        return CompletableFuture.supplyAsync(supplier, THREAD_POOL)
+    }
+
+    static boolean isParallel() {
+        return TRUE_STR == getVar(PARALLEL)
+    }
+
     static <T> void setVar(String name, T value) {
         VAR_HOLDER.get().put(name, value)
     }
@@ -83,6 +94,9 @@ class QueryableHelper {
     }
 
     private static final ThreadLocal<Map<String, Object>> VAR_HOLDER = ThreadLocal.<Map<String, Object>> withInitial(() -> new LinkedHashMap<>())
+    private static final ExecutorService THREAD_POOL = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors())
+    private static final String PARALLEL = "parallel"
+    private static final String TRUE_STR = "true"
 
     private QueryableHelper() {}
 }
diff --git a/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy b/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy
index 4c2f405..b02beba 100644
--- a/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy
+++ b/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy
@@ -5603,6 +5603,54 @@ class GinqTest {
     @Test
     void "testGinq - window - 75"() {
         assertGinqScript '''
+            assert [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4],
+                    [5, 5], [6, 6], [7, 7], [8, 8], [9, 9]] == GQ {
+                from n in (
+                    from m in 0..<10
+                    select m
+                )
+                select n, (rowNumber() over(orderby n))
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - window - 76"() {
+        assertGinqScript '''
+            assert [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4],
+                    [5, 5], [6, 6], [7, 7], [8, 8], [9, 9]] == GQ(parallel:true) {
+                from v in (
+                    from n in (
+                        from m in 0..<10
+                        select m
+                    )
+                    select n, (rowNumber() over(orderby n)) as rn
+                )
+                select v.n, v.rn
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - window - 77"() {
+        assertGinqScript '''
+            assert [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4],
+                    [5, 5], [6, 6], [7, 7], [8, 8], [9, 9]] == GQ {
+                from v in (
+                    from n in (
+                        from m in 0..<10
+                        select m
+                    )
+                    select n, (rowNumber() over(orderby n)) as rn
+                )
+                select v.n, v.rn
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - window - 78"() {
+        assertGinqScript '''
 // tag::ginq_winfunction_35[]
             assert [[1, 0.816496580927726],
                     [2, 0.816496580927726],

Apache Groovy committer & PMC member

Blog: http://blog.sunlan.me
Twitter: @daniel_sun