diff --git a/src/main/java/io/appium/java_client/AppiumFluentWait.java b/src/main/java/io/appium/java_client/AppiumFluentWait.java index 9061600a0..6361a5652 100644 --- a/src/main/java/io/appium/java_client/AppiumFluentWait.java +++ b/src/main/java/io/appium/java_client/AppiumFluentWait.java @@ -29,12 +29,16 @@ import java.time.Duration; import java.time.Instant; import java.util.List; +import java.util.Optional; import java.util.function.Function; import java.util.function.Supplier; public class AppiumFluentWait extends FluentWait { private Function pollingStrategy = null; + private static final Duration DEFAULT_POLL_DELAY_DURATION = Duration.ZERO; + private Duration pollDelay = DEFAULT_POLL_DELAY_DURATION; + public static class IterationInfo { /** * The current iteration number. @@ -98,6 +102,18 @@ public AppiumFluentWait(T input, Clock clock, Sleeper sleeper) { super(input, clock, sleeper); } + /** + * Sets how long to wait before starting to evaluate condition to be true. + * The default pollDelay is {@link #DEFAULT_POLL_DELAY_DURATION}. + * + * @param pollDelay The pollDelay duration. + * @return A self reference. + */ + public AppiumFluentWait withPollDelay(Duration pollDelay) { + this.pollDelay = pollDelay; + return this; + } + private B getPrivateFieldValue(String fieldName, Class fieldType) { return ReflectionHelpers.getPrivateFieldValue(FluentWait.class, this, fieldName, fieldType); } @@ -200,10 +216,19 @@ public AppiumFluentWait withPollingStrategy(Function */ @Override public V until(Function isTrue) { - final Instant start = getClock().instant(); - final Instant end = getClock().instant().plus(getTimeout()); - long iterationNumber = 1; + final var start = getClock().instant(); + // Adding pollDelay to end instant will allow to verify the condition for the expected timeout duration. + final var end = start.plus(getTimeout()).plus(pollDelay); + + return performIteration(isTrue, start, end); + } + + private V performIteration(Function isTrue, Instant start, Instant end) { + var iterationNumber = 1; Throwable lastException; + + sleepInterruptibly(pollDelay); + while (true) { try { V value = isTrue.apply(getInput()); @@ -222,32 +247,51 @@ public V until(Function isTrue) { // Check the timeout after evaluating the function to ensure conditions // with a zero timeout can succeed. if (end.isBefore(getClock().instant())) { - String message = getMessageSupplier() != null ? getMessageSupplier().get() : null; - - String timeoutMessage = String.format( - "Expected condition failed: %s (tried for %d second(s) with %s interval)", - message == null ? "waiting for " + isTrue : message, - getTimeout().getSeconds(), getInterval()); - throw timeoutException(timeoutMessage, lastException); + handleTimeoutException(lastException, isTrue); } - try { - Duration interval = getInterval(); - if (pollingStrategy != null) { - final IterationInfo info = new IterationInfo(iterationNumber, - Duration.between(start, getClock().instant()), getTimeout(), - interval); - interval = pollingStrategy.apply(info); - } - getSleeper().sleep(interval); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new WebDriverException(e); - } + var interval = getIntervalWithPollingStrategy(start, iterationNumber); + sleepInterruptibly(interval); + ++iterationNumber; } } + private void handleTimeoutException(Throwable lastException, Function isTrue) { + var message = Optional.ofNullable(getMessageSupplier()) + .map(Supplier::get) + .orElseGet(() -> "waiting for " + isTrue); + + var timeoutMessage = String.format( + "Expected condition failed: %s (tried for %s ms with an interval of %s ms)", + message, + getTimeout().toMillis(), + getInterval().toMillis() + ); + + throw timeoutException(timeoutMessage, lastException); + } + + private Duration getIntervalWithPollingStrategy(Instant start, long iterationNumber) { + var interval = getInterval(); + return Optional.ofNullable(pollingStrategy) + .map(strategy -> strategy.apply(new IterationInfo( + iterationNumber, + Duration.between(start, getClock().instant()), getTimeout(), interval))) + .orElse(interval); + } + + private void sleepInterruptibly(Duration duration) { + try { + if (!duration.isZero() && !duration.isNegative()) { + getSleeper().sleep(duration); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new WebDriverException(e); + } + } + protected Throwable propagateIfNotIgnored(Throwable e) { for (Class ignoredException : getIgnoredExceptions()) { if (ignoredException.isInstance(e)) {