This is an automated email from the ASF dual-hosted git repository.
sunlan pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/groovy.git The following commit(s) were added to refs/heads/master by this push: new 2ec2b9a Minor refactoring: simplify parallel querying 2ec2b9a is described below commit 2ec2b9adc6fae786c3237d6dc8478e7a45be722d Author: Daniel Sun <[hidden email]> AuthorDate: Sun Jan 10 14:56:36 2021 +0800 Minor refactoring: simplify parallel querying --- .../ginq/provider/collection/GinqAstWalker.groovy | 69 +++++++++++++--------- .../collection/runtime/QueryableCollection.java | 57 +++--------------- .../collection/runtime/QueryableHelper.groovy | 14 +++++ 3 files changed, 64 insertions(+), 76 deletions(-) diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy index 1182761..a3b8b64 100644 --- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy +++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy @@ -89,13 +89,13 @@ import static org.codehaus.groovy.ast.tools.GeneralUtils.fieldX import static org.codehaus.groovy.ast.tools.GeneralUtils.lambdaX import static org.codehaus.groovy.ast.tools.GeneralUtils.listX import static org.codehaus.groovy.ast.tools.GeneralUtils.localVarX +import static org.codehaus.groovy.ast.tools.GeneralUtils.nullX import static org.codehaus.groovy.ast.tools.GeneralUtils.param import static org.codehaus.groovy.ast.tools.GeneralUtils.params import static org.codehaus.groovy.ast.tools.GeneralUtils.propX import static org.codehaus.groovy.ast.tools.GeneralUtils.returnS import static org.codehaus.groovy.ast.tools.GeneralUtils.stmt import static org.codehaus.groovy.ast.tools.GeneralUtils.varX - /** * Visit AST of GINQ to generate target method calls for GINQ * @@ -719,6 +719,7 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable lambdaCode = namedListCtorCallExpression } + final boolean parallel = isParallel() lambdaCode = ((ListExpression) new ListExpression(Collections.singletonList(lambdaCode)).transformExpression(new ExpressionTransformer() { @Override Expression transform(Expression expression) { @@ -785,13 +786,9 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable } } - def parallel = isParallel() - final supplyAsyncLambdaParamName = "__salp${System.nanoTime()}" - def windowDefinitionFactoryMethodCallExpression = constructWindowDefinitionFactoryMethodCallExpression(expression, dataSourceExpression) - def rowNumberGetMethodCall = callX(varX(rowNumberName), 'get') Expression newObjectExpression = callX(wqVar, 'over', args( - callX(TUPLE_TYPE, 'tuple', args(currentRecordVar, parallel ? varX(supplyAsyncLambdaParamName) : rowNumberGetMethodCall)), + callX(TUPLE_TYPE, 'tuple', args(currentRecordVar, parallel ? varX(supplyAsyncLambdaParamName) : getRowNumberMethodCall())), windowDefinitionFactoryMethodCallExpression )) result = callX( @@ -799,22 +796,6 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable windowFunctionMethodCallExpression.methodAsString, args(argumentExpressionList) ) - - if (parallel) { - result = callX( - new ClassExpression(QUERYABLE_HELPER_TYPE), - "supplyAsync", - args( - lambdaX( - params( - param(DYNAMIC_TYPE, supplyAsyncLambdaParamName) - ), - stmt(result) - ), - rowNumberGetMethodCall - ) - ) - } } else { GinqAstWalker.this.collectSyntaxError(new GinqSyntaxError( "Unsupported window function: `${windowFunctionMethodCallExpression.methodAsString}`", @@ -839,13 +820,18 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable extra << callX(varX(rowNumberName), 'getAndIncrement') } - def selectMethodCallExpression = callXWithLambda(selectMethodReceiver, "select", dataSourceExpression, lambdaCode, extra, param(DYNAMIC_TYPE, getWindowQueryableName())) + def selectMethodCallExpression = callXWithLambda(selectMethodReceiver, "select", dataSourceExpression, parallel, lambdaCode, extra, param(DYNAMIC_TYPE, getWindowQueryableName())) currentGinqExpression.putNodeMetaData(__VISITING_SELECT, false) return selectMethodCallExpression } + private MethodCallExpression getRowNumberMethodCall() { + final rowNumberGetMethodCall = callX(varX(rowNumberName), 'get') + return rowNumberGetMethodCall + } + private MethodCallExpression constructWindowDefinitionFactoryMethodCallExpression(MethodCallExpression methodCallExpression, DataSourceExpression dataSourceExpression) { Expression classifierExpr = null Expression orderExpr = null @@ -1118,6 +1104,18 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable return name } + private int supplyAsyncLambdaParamNameSeq = 0 + private String getSupplyAsyncLambdaParamName() { + String name = (String) currentGinqExpression.getNodeMetaData(SUPPLY_ASYNC_LAMBDA_PARAM_NAME_PREFIX) + + if (!name) { + name = "${SUPPLY_ASYNC_LAMBDA_PARAM_NAME_PREFIX}${supplyAsyncLambdaParamNameSeq++}" + currentGinqExpression.putNodeMetaData(SUPPLY_ASYNC_LAMBDA_PARAM_NAME_PREFIX, name) + } + + return name + } + private MethodCallExpression getMetaDataMethodCall(String key) { callX(varX(metaDataMapName), "get", new ConstantExpression(key)) } @@ -1367,8 +1365,8 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable this.callXWithLambda(receiver, methodName, dataSourceExpression, lambdaCode, Collections.emptyList(), extraParams) } - private MethodCallExpression callXWithLambda(Expression receiver, String methodName, DataSourceExpression dataSourceExpression, Expression lambdaCode, List<Expression> extraLambdaCode, Parameter... extraParams) { - LambdaExpression lambdaExpression = constructLambdaExpression(dataSourceExpression, lambdaCode, extraParams) + private MethodCallExpression callXWithLambda(Expression receiver, String methodName, DataSourceExpression dataSourceExpression, boolean async = false, Expression lambdaCode, List<Expression> extraLambdaCode, Parameter... extraParams) { + LambdaExpression lambdaExpression = constructLambdaExpression(dataSourceExpression, async, lambdaCode, extraParams) if (extraLambdaCode) { ((BlockStatement) lambdaExpression.code).getStatements().addAll(0, extraLambdaCode.collect { stmt(it) }) @@ -1377,7 +1375,7 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable callXWithLambda(receiver, methodName, lambdaExpression) } - private LambdaExpression constructLambdaExpression(DataSourceExpression dataSourceExpression, Expression lambdaCode, Parameter... extraParams) { + private LambdaExpression constructLambdaExpression(DataSourceExpression dataSourceExpression, boolean async = false, Expression lambdaCode, Parameter... extraParams) { Tuple3<String, List<DeclarationExpression>, Expression> paramNameAndLambdaCode = correctVariablesOfLambdaExpression(dataSourceExpression, lambdaCode) List<DeclarationExpression> declarationExpressionList = paramNameAndLambdaCode.v2 @@ -1385,7 +1383,23 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable if (!visitingWindowFunction) { statementList.addAll(declarationExpressionList.stream().map(e -> stmt(e)).collect(Collectors.toList())) } - statementList.add(stmt(paramNameAndLambdaCode.v3)) + + def transformedLambdCode = paramNameAndLambdaCode.v3 + if (async) { + ArgumentListExpression argumentListExpression + argumentListExpression = rowNumberUsed + ? args(lambdaX(params(param(DYNAMIC_TYPE, supplyAsyncLambdaParamName)), + stmt(transformedLambdCode)), + rowNumberUsed ? getRowNumberMethodCall() : nullX()) + : args(lambdaX(stmt(transformedLambdCode))) + + transformedLambdCode = callX( + new ClassExpression(QUERYABLE_HELPER_TYPE), + "supplyAsync", + argumentListExpression + ) + } + statementList.add(stmt(transformedLambdCode)) def paramList = [param(DYNAMIC_TYPE, paramNameAndLambdaCode.v1)] if (extraParams) { @@ -1564,6 +1578,7 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable private static final String __META_DATA_MAP_NAME_PREFIX = '__metaDataMap_' private static final String __WINDOW_QUERYABLE_NAME = '__wq_' private static final String __ROW_NUMBER_NAME_PREFIX = '__rowNumber_' + private static final String SUPPLY_ASYNC_LAMBDA_PARAM_NAME_PREFIX = "__salp_" private static final String __SOURCE_RECORD = "__sourceRecord" private static final String __GROUP = "__group" private static final String MD_GROUP_NAME_LIST = "groupNameList" diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java index 4db5412..6e3e621 100644 --- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java +++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableCollection.java @@ -30,8 +30,6 @@ import org.codehaus.groovy.runtime.dgmimpl.NumberNumberMinus; import org.codehaus.groovy.runtime.typehandling.NumberMath; import java.io.Serializable; -import java.lang.reflect.Constructor; -import java.lang.reflect.Method; import java.math.BigDecimal; import java.math.RoundingMode; import java.util.ArrayList; @@ -279,55 +277,31 @@ class QueryableCollection<T> implements Queryable<T>, Serializable { @Override public <U> Queryable<U> select(BiFunction<? super T, ? super Queryable<? extends T>, ? extends U> mapper) { - String originalParallel = null; + final String originalParallel = QueryableHelper.getVar(PARALLEL); + QueryableHelper.setVar(PARALLEL, FALSE_STR); // ensure the row number is generated sequentially boolean useWindowFunction = TRUE_STR.equals(QueryableHelper.getVar(USE_WINDOW_FUNCTION)); if (useWindowFunction) { - originalParallel = QueryableHelper.getVar(PARALLEL); - QueryableHelper.setVar(PARALLEL, FALSE_STR); // ensure the row number is generated sequentially this.makeReusable(); } Stream<U> stream = this.stream().map((T t) -> mapper.apply(t, this)); - if (useWindowFunction && TRUE_STR.equals(originalParallel)) { + if (TRUE_STR.equals(originalParallel)) { // invoke `collect` to trigger the intermediate operator, which will create `CompletableFuture` instances stream = stream.collect(Collectors.toList()).parallelStream().map((U u) -> { - Function<? super U, ?> transform = e -> { - try { - return e instanceof CompletableFuture ? ((CompletableFuture) e).get() : e; - } catch (InterruptedException | ExecutionException ex) { - throw new GroovyRuntimeException(ex); - } - }; - - if (instanceOfNamedRecord(u)) { - Tuple record = (Tuple) u; - List<?> transformed = (List<?>) record.stream().map(transform).collect(Collectors.toList()); - try { - List<String> nameList = (List<String>) GET_NAME_SET_METHOD.invoke(record); - List<String> aliasList = (List<String>) GET_ALIAS_LIST_METHOD.invoke(record); - return (U) NAMED_RECORD_CONSTRUCTOR.newInstance(transformed, nameList, aliasList); - } catch (ReflectiveOperationException ex) { - throw new GroovyRuntimeException(ex); - } + try { + return (U) ((CompletableFuture) u).get(); + } catch (InterruptedException | ExecutionException ex) { + throw new GroovyRuntimeException(ex); } - - return (U) transform.apply(u); }); } - if (useWindowFunction) { - QueryableHelper.setVar(PARALLEL, originalParallel); - } + QueryableHelper.setVar(PARALLEL, originalParallel); return from(stream); } - private static <U> boolean instanceOfNamedRecord(U u) { - // workaround joint compilation issue - return (u instanceof Tuple) && NAMED_RECORD_CLASS_NAME.equals(u.getClass().getName()); - } - @Override public Queryable<T> distinct() { Stream<T> stream = this.stream().distinct(); @@ -736,24 +710,9 @@ class QueryableCollection<T> implements Queryable<T>, Serializable { private final Lock readLock = rwl.readLock(); private final Lock writeLock = rwl.writeLock(); private static final BigDecimal BD_TWO = BigDecimal.valueOf(2); - private static final String NAMED_RECORD_CLASS_NAME = "org.apache.groovy.ginq.provider.collection.runtime.NamedRecord"; - private static final Constructor<?> NAMED_RECORD_CONSTRUCTOR; - private static final Method GET_ALIAS_LIST_METHOD; - private static final Method GET_NAME_SET_METHOD; private static final String USE_WINDOW_FUNCTION = "useWindowFunction"; private static final String PARALLEL = "parallel"; private static final String TRUE_STR = "true"; private static final String FALSE_STR = "false"; private static final long serialVersionUID = -5067092453136522893L; - - static { - try { - final Class<?> namedRecordClass = Class.forName(NAMED_RECORD_CLASS_NAME); - NAMED_RECORD_CONSTRUCTOR = namedRecordClass.getConstructor(List.class, List.class, List.class); - GET_ALIAS_LIST_METHOD = namedRecordClass.getDeclaredMethod("getAliasList"); - GET_NAME_SET_METHOD = namedRecordClass.getMethod("getNameList"); - } catch (NoSuchMethodException | ClassNotFoundException ex) { - throw new GroovyRuntimeException(ex); - } - } } diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableHelper.groovy b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableHelper.groovy index cfc6c77..bb458d5 100644 --- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableHelper.groovy +++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/runtime/QueryableHelper.groovy @@ -24,6 +24,7 @@ import java.util.concurrent.CompletableFuture import java.util.concurrent.ExecutorService import java.util.concurrent.Executors import java.util.function.Function +import java.util.function.Supplier import java.util.stream.Collectors import static org.apache.groovy.ginq.provider.collection.runtime.Queryable.from @@ -74,6 +75,10 @@ class QueryableHelper { throw new TooManyValuesException("subquery returns more than one value: $list") } + static <U> CompletableFuture<U> supplyAsync(Supplier<U> supplier) { + return CompletableFuture.supplyAsync(supplier, THREAD_POOL) + } + static <T, U> CompletableFuture<U> supplyAsync(Function<? super T, ? extends U> function, T param) { return CompletableFuture.supplyAsync(() -> { function.apply(param) }, THREAD_POOL) } @@ -99,5 +104,14 @@ class QueryableHelper { private static final String PARALLEL = "parallel" private static final String TRUE_STR = "true" + static { + Runtime.addShutdownHook { + try { + THREAD_POOL.shutdownNow() + } catch (ignored) { + } + } + } + private QueryableHelper() {} } |
Free forum by Nabble | Edit this page |