[groovy] branch master updated: Add built-in `stdev` function

Previous Topic | Next Topic
 
Classic | List | Threaded
1 message Options
Reply | Threaded
Open this post in threaded view
|

[groovy] branch master updated: Add built-in `stdev` function

Daniel.Sun
This is an automated email from the ASF dual-hosted git repository.

sunlan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/groovy.git


The following commit(s) were added to refs/heads/master by this push:
     new 1f27d29  Add built-in `stdev` function
1f27d29 is described below

commit 1f27d29e071e6de32b908acdb3223e552af44cdd
Author: Daniel Sun <[hidden email]>
AuthorDate: Sat Jan 9 17:44:39 2021 +0800

    Add built-in `stdev` function
---
 .../ginq/provider/collection/GinqAstWalker.groovy  |  5 +++--
 .../provider/collection/runtime/Queryable.java     |  9 ++++++++
 .../collection/runtime/QueryableCollection.java    | 24 ++++++++++++++++++--
 .../groovy-ginq/src/spec/doc/ginq-userguide.adoc   | 17 +++++++++++---
 .../test/org/apache/groovy/ginq/GinqTest.groovy    | 26 ++++++++++++++++++++++
 5 files changed, 74 insertions(+), 7 deletions(-)

diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
index 1549a3e..fd8ef9a 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
@@ -1510,8 +1510,9 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
     private static final String FUNCTION_SUM = 'sum'
     private static final String FUNCTION_AVG = 'avg'
     private static final String FUNCTION_MEDIAN = 'median'
+    private static final String FUNCTION_STDEV = 'stdev'
     private static final String FUNCTION_AGG = 'agg'
-    private static final List<String> AGG_FUNCTION_NAME_LIST = [FUNCTION_COUNT, FUNCTION_MIN, FUNCTION_MAX, FUNCTION_SUM, FUNCTION_AVG, FUNCTION_MEDIAN, FUNCTION_AGG]
+    private static final List<String> AGG_FUNCTION_NAME_LIST = [FUNCTION_COUNT, FUNCTION_MIN, FUNCTION_MAX, FUNCTION_SUM, FUNCTION_AVG, FUNCTION_MEDIAN, FUNCTION_STDEV, FUNCTION_AGG]
 
     private static final String FUNCTION_ROW_NUMBER = 'rowNumber'
     private static final String FUNCTION_LEAD = 'lead'
@@ -1520,7 +1521,7 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
     private static final String FUNCTION_LAST_VALUE = 'lastValue'
     private static final String FUNCTION_RANK = 'rank'
     private static final String FUNCTION_DENSE_RANK = 'denseRank'
-    private static final List<String> WINDOW_FUNCTION_LIST = [FUNCTION_COUNT, FUNCTION_MIN, FUNCTION_MAX, FUNCTION_SUM, FUNCTION_AVG, FUNCTION_MEDIAN,
+    private static final List<String> WINDOW_FUNCTION_LIST = [FUNCTION_COUNT, FUNCTION_MIN, FUNCTION_MAX, FUNCTION_SUM, FUNCTION_AVG, FUNCTION_MEDIAN, FUNCTION_STDEV,
                                                               FUNCTION_ROW_NUMBER, FUNCTION_LEAD, FUNCTION_LAG, FUNCTION_FIRST_VALUE, FUNCTION_LAST_VALUE, FUNCTION_RANK, FUNCTION_DENSE_RANK]
 
     private static final String NAMEDRECORD_CLASS_NAME = NamedRecord.class.name
diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/Queryable.java b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/Queryable.java
index 19e40f8..36e0d6e 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/Queryable.java
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/Queryable.java
@@ -418,6 +418,15 @@ public interface Queryable<T> {
     BigDecimal median(Function<? super T, ? extends Number> mapper);
 
     /**
+     * Aggregate function {@code stdev}, similar to SQL's {@code stdev}
+     *
+     * @param mapper choose the field to calculate the statistical standard deviation
+     * @return statistical standard deviation result
+     * @since 4.0.0
+     */
+    BigDecimal stdev(Function<? super T, ? extends Number> mapper);
+
+    /**
      * The most powerful aggregate function in GINQ, it will receive the grouped result({@link Queryable} instance) and apply any processing
      *
      * @param mapper map the grouped result({@link Queryable} instance) to aggregate result
diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java
index faa3081..040ccd6 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java
@@ -25,6 +25,7 @@ import groovy.transform.Internal;
 import org.apache.groovy.internal.util.Supplier;
 import org.apache.groovy.util.SystemUtil;
 import org.codehaus.groovy.runtime.DefaultGroovyMethods;
+import org.codehaus.groovy.runtime.dgmimpl.NumberNumberMinus;
 import org.codehaus.groovy.runtime.typehandling.NumberMath;
 
 import java.io.Serializable;
@@ -53,12 +54,15 @@ import java.util.stream.Stream;
 import java.util.stream.StreamSupport;
 
 import static groovy.lang.Tuple.tuple;
+import static java.lang.Math.pow;
+import static java.lang.Math.sqrt;
 import static java.util.Comparator.naturalOrder;
 import static java.util.Comparator.nullsFirst;
 import static java.util.Comparator.nullsLast;
 import static java.util.Comparator.reverseOrder;
 import static org.apache.groovy.ginq.provider.collection.runtime.Queryable.from;
 import static org.apache.groovy.ginq.provider.collection.runtime.WindowImpl.composeOrders;
+import static org.codehaus.groovy.runtime.typehandling.NumberMath.toBigDecimal;
 
 /**
  * Represents the queryable collections
@@ -345,7 +349,7 @@ class QueryableCollection<T> implements Queryable<T>, Serializable {
                     Number n = mapper.apply(e);
                     if (null == n) return BigDecimal.ZERO;
 
-                    return NumberMath.toBigDecimal(n);
+                    return toBigDecimal(n);
                 }).reduce(BigDecimal.ZERO, BigDecimal::add));
     }
 
@@ -362,7 +366,7 @@ class QueryableCollection<T> implements Queryable<T>, Serializable {
                 }, (o1, o2) -> o1)
         );
 
-        return ((BigDecimal) result[1]).divide(BigDecimal.valueOf((Long) result[0]), 16, RoundingMode.HALF_UP);
+        return ((BigDecimal) result[1]).divide(toBigDecimal((Long) result[0]), 16, RoundingMode.HALF_UP);
     }
 
     @Override
@@ -409,6 +413,22 @@ class QueryableCollection<T> implements Queryable<T>, Serializable {
     }
 
     @Override
+    public BigDecimal stdev(Function<? super T, ? extends Number> mapper) {
+        BigDecimal avg = this.avg(mapper);
+        Object[] result = agg(q -> q.stream()
+                .map(mapper)
+                .filter(Objects::nonNull)
+                .map(e -> toBigDecimal(pow(NumberNumberMinus.minus(e, avg).doubleValue(), 2)))
+                .reduce(new Object[]{0L, BigDecimal.ZERO}, (r, e) -> {
+                    r[0] = (Long) r[0] + 1;
+                    r[1] = ((BigDecimal) r[1]).add(e);
+                    return r;
+                }, (o1, o2) -> o1));
+
+        return toBigDecimal(sqrt(((BigDecimal) result[1]).divide(toBigDecimal((Long) result[0]), 16, RoundingMode.HALF_UP).doubleValue()));
+    }
+
+    @Override
     public <U> U agg(Function<? super Queryable<? extends T>, ? extends U> mapper) {
         return mapper.apply(this);
     }
diff --git a/subprojects/groovy-ginq/src/spec/doc/ginq-userguide.adoc b/subprojects/groovy-ginq/src/spec/doc/ginq-userguide.adoc
index 3eab36a..04bb22c 100644
--- a/subprojects/groovy-ginq/src/spec/doc/ginq-userguide.adoc
+++ b/subprojects/groovy-ginq/src/spec/doc/ginq-userguide.adoc
@@ -284,9 +284,9 @@ include::../test/org/apache/groovy/ginq/GinqTest.groovy[tags=ginq_grouping_09,in
 
 ===== Aggregate Functions
 GINQ provides some built-in aggregate functions, e.g.
-`count`, `min`, `max`, `sum`, `avg`, `median` and the most powerful function `agg`.
+`count`, `min`, `max`, `sum`, `avg`, `median`, `stdev` and the most powerful function `agg`.
 [NOTE]
-`count(...)`, `min(...)`, `max(...)`, `avg(...)` and `median(...)` just operate on non-`null` values,
+`count(...)`, `min(...)`, `max(...)`, `avg(...)`, `median(...)` and `stdev(...)` just operate on non-`null` values,
 and `count()` is similar to `count(*)` in SQL.
 [source, sql]
 ----
@@ -332,11 +332,17 @@ Also, we could apply the aggregate functions for the whole GINQ result, i.e. no
 ----
 include::../test/org/apache/groovy/ginq/GinqTest.groovy[tags=ginq_aggfunction_02,indent=0]
 ----
+
 [source, groovy]
 ----
 include::../test/org/apache/groovy/ginq/GinqTest.groovy[tags=ginq_aggfunction_01,indent=0]
 ----
 
+[source, groovy]
+----
+include::../test/org/apache/groovy/ginq/GinqTest.groovy[tags=ginq_aggfunction_03,indent=0]
+----
+
 ==== Sorting
 `orderby` is equivalent to SQL's `ORDER BY`
 
@@ -581,7 +587,7 @@ include::../test/org/apache/groovy/ginq/GinqTest.groovy[tags=ginq_winfunction_15
 include::../test/org/apache/groovy/ginq/GinqTest.groovy[tags=ginq_winfunction_11,indent=0]
 ----
 
-===== `min`, `max`, `count`, `sum`, `avg` and `median`
+===== `min`, `max`, `count`, `sum`, `avg`, `median` and `stdev`
 [source, groovy]
 ----
 include::../test/org/apache/groovy/ginq/GinqTest.groovy[tags=ginq_winfunction_22,indent=0]
@@ -632,6 +638,11 @@ include::../test/org/apache/groovy/ginq/GinqTest.groovy[tags=ginq_winfunction_29
 include::../test/org/apache/groovy/ginq/GinqTest.groovy[tags=ginq_winfunction_30,indent=0]
 ----
 
+[source, groovy]
+----
+include::../test/org/apache/groovy/ginq/GinqTest.groovy[tags=ginq_winfunction_35,indent=0]
+----
+
 === GINQ Tips
 ==== Row Number
 `_rn` is the implicit variable representing row number for each record in the result set. It starts with `0`
diff --git a/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy b/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy
index f0fc89c..4c2f405 100644
--- a/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy
+++ b/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy
@@ -4660,6 +4660,18 @@ class GinqTest {
     }
 
     @Test
+    void "testGinq - aggregate function - 13"() {
+        assertGinqScript '''
+// tag::ginq_aggfunction_03[]
+            assert [0.816496580927726] == GQ {
+                from n in [1, 2, 3]
+                select stdev(n)
+            }.toList()
+// end::ginq_aggfunction_03[]
+        '''
+    }
+
+    @Test
     void "testGinq - parallel - 1"() {
         assertGinqScript '''
 // tag::ginq_tips_08[]
@@ -5588,6 +5600,20 @@ class GinqTest {
         '''
     }
 
+    @Test
+    void "testGinq - window - 75"() {
+        assertGinqScript '''
+// tag::ginq_winfunction_35[]
+            assert [[1, 0.816496580927726],
+                    [2, 0.816496580927726],
+                    [3, 0.816496580927726]] == GQ {
+                from n in [1, 2, 3]
+                select n, (stdev(n) over())
+            }.toList()
+// end::ginq_winfunction_35[]
+        '''
+    }
+
     private static void assertGinqScript(String script) {
         String deoptimizedScript = script.replaceAll(/\bGQ\s*[{]/, 'GQ(optimize:false) {')
         List<String> scriptList = [deoptimizedScript, script]

Apache Groovy committer & PMC member

Blog: http://blog.sunlan.me
Twitter: @daniel_sun