[groovy] branch master updated: Tweak row number in GINQ

Previous Topic | Next Topic

View: Classic | List | Threaded
1 message | Options
Reply | Threaded
Open this post in threaded view

[groovy] branch master updated: Tweak row number in GINQ

Daniel.Sun
This is an automated email from the ASF dual-hosted git repository.

sunlan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/groovy.git


The following commit(s) were added to refs/heads/master by this push:
     new 74c4585  Tweak row number in GINQ
74c4585 is described below

commit 74c45859917a215e61a0629a902d1e5ca3cd00e7
Author: Daniel Sun <[hidden email]>
AuthorDate: Sun Jan 10 17:16:05 2021 +0800

    Tweak row number in GINQ
---
 .../ginq/provider/collection/GinqAstWalker.groovy  |  9 ++---
 .../collection/runtime/QueryableCollection.java    |  2 +-
 .../test/org/apache/groovy/ginq/GinqTest.groovy    | 40 +++++++++++++++++++---
 3 files changed, 42 insertions(+), 9 deletions(-)

diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
index a3b8b64..09968f7 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
@@ -195,7 +195,7 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
                 ))
         )
         if (rowNumberUsed) {
-            statementList << declS(localVarX(rowNumberName), ctorX(ATOMIC_LONG_TYPE, new ConstantExpression(0L)))
+            statementList << declS(localVarX(rowNumberName), ctorX(ATOMIC_LONG_TYPE, new ConstantExpression(-1L)))
         }
 
         final resultName = "__r${System.nanoTime()}"
@@ -720,13 +720,14 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
         }
 
         final boolean parallel = isParallel()
+        final VariableExpression supplyAsyncLambdaParam = varX(supplyAsyncLambdaParamName)
         lambdaCode = ((ListExpression) new ListExpression(Collections.singletonList(lambdaCode)).transformExpression(new ExpressionTransformer() {
             @Override
             Expression transform(Expression expression) {
                 if (expression instanceof VariableExpression) {
                     if (_RN == expression.text) {
                         currentGinqExpression.putNodeMetaData(__RN_USED, true)
-                        return callX(varX(rowNumberName), 'getAndIncrement')
+                        return parallel ? supplyAsyncLambdaParam : callX(varX(rowNumberName), 'get')
                     }
                 }
 
@@ -788,7 +789,7 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
 
                                 def windowDefinitionFactoryMethodCallExpression = constructWindowDefinitionFactoryMethodCallExpression(expression, dataSourceExpression)
                                 Expression newObjectExpression = callX(wqVar, 'over', args(
-                                        callX(TUPLE_TYPE, 'tuple', args(currentRecordVar, parallel ? varX(supplyAsyncLambdaParamName) : getRowNumberMethodCall())),
+                                        callX(TUPLE_TYPE, 'tuple', args(currentRecordVar, parallel ? supplyAsyncLambdaParam : getRowNumberMethodCall())),
                                         windowDefinitionFactoryMethodCallExpression
                                 ))
                                 result = callX(
@@ -815,7 +816,7 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
         })).getExpression(0)
 
         def extra = []
-        if (enableCount) {
+        if (enableCount || rowNumberUsed) {
             currentGinqExpression.putNodeMetaData(__RN_USED, true)
             extra << callX(varX(rowNumberName), 'getAndIncrement')
         }
diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java
index 6e3e621..55aca1f 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java
@@ -552,7 +552,7 @@ class QueryableCollection<T> implements Queryable<T>, Serializable {
                 new PartitionCacheKey(windowDefinition.partitionBy().apply(currentRecord.getV1()), partitionId),
                 partitionCacheKey -> from(Collections.singletonList(currentRecord)).innerHashJoin(
                         allPartitionCache.computeIfAbsent(partitionId, pid -> {
-                            long[] rn = new long[]{1L};
+                            long[] rn = new long[]{0L};
                             List<Tuple2<T, Long>> listWithIndex =
                                     this.toList().stream()
                                             .map(e -> Tuple.tuple(e, rn[0]++))
diff --git a/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy b/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy
index b02beba..881205c 100644
--- a/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy
+++ b/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy
@@ -4713,6 +4713,39 @@ class GinqTest {
     }
 
     @Test
+    void "testGinq - parallel - 4"() {
+        assertGinqScript '''
+            assert [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5],
+                    [5, 6], [6, 7], [7, 8], [8, 9], [9, 10]] == GQ(parallel:true) {
+                from n in 1..10
+                select _rn, n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - parallel - 5"() {
+        assertGinqScript '''
+            assert [[0, 0, 1], [1, 1, 2], [2, 2, 3], [3, 3, 4], [4, 4, 5],
+                    [5, 5, 6], [6, 6, 7], [7, 7, 8], [8, 8, 9], [9, 9, 10]] == GQ(parallel:true) {
+                from n in 1..10
+                select _rn, (rowNumber() over(orderby n)), n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - parallel - 6"() {
+        assertGinqScript '''
+            assert [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4],
+                    [5, 5], [6, 6], [7, 7], [8, 8], [9, 9]] == GQ(parallel:true) {
+                from n in 0..<10
+                select n, (rowNumber() over(orderby n))
+            }.toList()
+        '''
+    }
+
+    @Test
     void "testGinq - window - 0"() {
         assertGinqScript '''
 // tag::ginq_winfunction_01[]
@@ -5292,10 +5325,9 @@ class GinqTest {
     @Test
     void "testGinq - window - 49"() {
         assertGinqScript '''
-            assert [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4],
-                    [5, 5], [6, 6], [7, 7], [8, 8], [9, 9]] == GQ(parallel:true) {
-                from n in 0..<10
-                select n, (rowNumber() over(orderby n))
+            assert [['a', 0], ['b', 1], ['aa', 0], ['bb', 1]] == GQ {
+                from s in ['a', 'b', 'aa', 'bb']
+                select s, (rowNumber() over(partitionby s.length() orderby s))
             }.toList()
         '''
     }

Apache Groovy committer & PMC member

Blog: http://blog.sunlan.me
Twitter: @daniel_sun