From 7eeb715d7a3dfcd17c0b7a0a75a3ac7b9b2427df Mon Sep 17 00:00:00 2001 From: Antoine Balliet Date: Fri, 29 Sep 2023 10:30:52 +0200 Subject: [PATCH] feat(tasks): decouple job status check from trigger run (#65) closes #62 --- .../kestra/plugin/dbt/cloud/CheckStatus.java | 251 ++++++++++++++++++ .../kestra/plugin/dbt/cloud/TriggerRun.java | 166 +----------- .../plugin/dbt/cloud/CheckStatusTest.java | 63 +++++ 3 files changed, 325 insertions(+), 155 deletions(-) create mode 100644 src/main/java/io/kestra/plugin/dbt/cloud/CheckStatus.java create mode 100644 src/test/java/io/kestra/plugin/dbt/cloud/CheckStatusTest.java diff --git a/src/main/java/io/kestra/plugin/dbt/cloud/CheckStatus.java b/src/main/java/io/kestra/plugin/dbt/cloud/CheckStatus.java new file mode 100644 index 0000000..382ff72 --- /dev/null +++ b/src/main/java/io/kestra/plugin/dbt/cloud/CheckStatus.java @@ -0,0 +1,251 @@ +package io.kestra.plugin.dbt.cloud; + +import io.kestra.core.exceptions.IllegalVariableEvaluationException; +import io.kestra.core.models.annotations.Example; +import io.kestra.core.models.annotations.Plugin; +import io.kestra.core.models.annotations.PluginProperty; +import io.kestra.core.models.tasks.RunnableTask; +import io.kestra.core.runners.RunContext; +import io.kestra.core.utils.Await; +import io.kestra.plugin.dbt.ResultParser; +import io.kestra.plugin.dbt.cloud.models.JobStatusHumanizedEnum; +import io.kestra.plugin.dbt.cloud.models.RunResponse; +import io.kestra.plugin.dbt.cloud.models.Step; +import io.micronaut.core.type.Argument; +import io.micronaut.http.HttpMethod; +import io.micronaut.http.HttpRequest; +import io.micronaut.http.uri.UriTemplate; +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.*; +import lombok.experimental.SuperBuilder; +import org.slf4j.Logger; + +import java.io.IOException; +import java.net.URI; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.time.Duration; +import java.util.*; + +import static io.kestra.core.utils.Rethrow.throwSupplier; +import static java.lang.Math.max; + +@SuperBuilder +@ToString +@EqualsAndHashCode +@Getter +@NoArgsConstructor +@Schema( + title = "Check the status of a dbt Cloud job" +) +@Plugin( + examples = { + @Example( + code = { + "accountId: \"\"", + "token: \"\"", + "runId: \"\"", + } + ) + } +) +public class CheckStatus extends AbstractDbtCloud implements RunnableTask { + private static final List ENDED_STATUS = List.of( + JobStatusHumanizedEnum.ERROR, + JobStatusHumanizedEnum.CANCELLED, + JobStatusHumanizedEnum.SUCCESS + ); + + @Schema( + title = "The job run id to check the status for" + ) + @PluginProperty(dynamic = true) + Integer runId; + + + @Schema( + title = "Specify how often the task should poll for the job status" + ) + @PluginProperty(dynamic = false) + @Builder.Default + Duration pollFrequency = Duration.ofSeconds(5); + + @Schema( + title = "The maximum duration the task should poll for the job completion" + ) + @PluginProperty(dynamic = false) + @Builder.Default + Duration maxDuration = Duration.ofMinutes(60); + + @Builder.Default + @Schema( + title = "Parse run result", + description = "Whether to parse the run result to display the duration of each dbt node in the Gantt view" + ) + @PluginProperty + protected Boolean parseRunResults = true; + + @Builder.Default + @Getter(AccessLevel.NONE) + private transient List loggedStatus = new ArrayList<>(); + + @Builder.Default + @Getter(AccessLevel.NONE) + private transient Map loggedSteps = new HashMap<>(); + + @Override + public CheckStatus.Output run(RunContext runContext) throws Exception { + Logger logger = runContext.logger(); + + // wait for end + RunResponse finalRunResponse = Await.until( + throwSupplier(() -> { + Optional fetchRunResponse = fetchRunResponse( + runContext, + runId, + false + ); + + if (fetchRunResponse.isPresent()) { + logSteps(logger, fetchRunResponse.get()); + + // we rely on truncated logs to be sure + boolean allLogs = fetchRunResponse.get() + .getData() + .getRunSteps() + .stream() + .filter(step -> step.getTruncatedDebugLogs() != null) + .count() == + fetchRunResponse.get() + .getData() + .getRunSteps().size(); + + // ended + if (ENDED_STATUS.contains(fetchRunResponse.get().getData().getStatusHumanized()) && allLogs) { + return fetchRunResponse.get(); + } + } + + return null; + }), + this.pollFrequency, + this.maxDuration + ); + + // final response + logSteps(logger, finalRunResponse); + + if (!finalRunResponse.getData().getStatusHumanized().equals(JobStatusHumanizedEnum.SUCCESS)) { + throw new Exception("Failed run with status '" + finalRunResponse.getData().getStatusHumanized() + + "' after " + finalRunResponse.getData().getDurationHumanized() + ": " + finalRunResponse + ); + } + + Path runResultsArtifact = downloadArtifacts(runContext, runId, "run_results.json"); + Path manifestArtifact = downloadArtifacts(runContext, runId, "manifest.json"); + + if (this.parseRunResults) { + ResultParser.parseRunResult(runContext, runResultsArtifact.toFile()); + } + + return Output.builder() + .runResults(runResultsArtifact.toFile().exists() ? runContext.putTempFile(runResultsArtifact.toFile()) : null) + .manifest(manifestArtifact.toFile().exists() ? runContext.putTempFile(manifestArtifact.toFile()) : null) + .build(); + } + + private void logSteps(Logger logger, RunResponse runResponse) { + // status changed + if (!loggedStatus.contains(runResponse.getData().getStatusHumanized())) { + logger.debug("Status changed to '{}' after {}", + runResponse.getData().getStatusHumanized(), + runResponse.getData().getDurationHumanized() + ); + loggedStatus.add(runResponse.getData().getStatusHumanized()); + } + + // log steps + for (Step step : runResponse.getData().getRunSteps()) { + if (!step.getLogs().isEmpty()){ + if (!loggedSteps.containsKey(step.getId())){ + loggedSteps.put(step.getId(), 0); + } + + if (step.getLogs().length() > loggedSteps.get(step.getId())) { + for (String s : step.getLogs().substring(max(loggedSteps.get(step.getId()) -1, 0)).split("\n")) { + logger.info("[Step {}]: {}", step.getName(), s); + } + loggedSteps.put(step.getId(), step.getLogs().length()); + } + } + } + } + + private Optional fetchRunResponse(RunContext runContext, Integer id, Boolean debug) throws IllegalVariableEvaluationException { + return this + .request( + runContext, + HttpRequest + .create( + HttpMethod.GET, + UriTemplate + .of("/api/v2/accounts/{accountId}/runs/{runId}/" + + "?include_related=" + URLEncoder.encode( + "[\"trigger\",\"job\"," + (debug ? "\"debug_logs\"" : "") + ",\"run_steps\", \"environment\"]", + StandardCharsets.UTF_8 + ) + ) + .expand(Map.of( + "accountId", runContext.render(this.accountId), + "runId", id + )) + ), + Argument.of(RunResponse.class) + ) + .getBody(); + } + + private Path downloadArtifacts(RunContext runContext, Integer runId, String path) throws IllegalVariableEvaluationException, IOException { + String artifact = this + .request( + runContext, + HttpRequest + .create( + HttpMethod.GET, + UriTemplate + .of("/api/v2/accounts/{accountId}/runs/{runId}/artifacts/{path}") + .expand(Map.of( + "accountId", runContext.render(this.accountId), + "runId", runId, + "path", path + )) + ), + Argument.of(String.class) + ) + .getBody() + .orElseThrow(); + + Path tempFile = runContext.tempFile(".json"); + + Files.writeString(tempFile, artifact, StandardOpenOption.TRUNCATE_EXISTING); + + return tempFile; + } + + @Builder + @Getter + public static class Output implements io.kestra.core.models.tasks.Output { + @Schema( + title = "URI of the run result" + ) + private URI runResults; + + @Schema( + title = "URI of a manifest" + ) + private URI manifest; + } +} diff --git a/src/main/java/io/kestra/plugin/dbt/cloud/TriggerRun.java b/src/main/java/io/kestra/plugin/dbt/cloud/TriggerRun.java index 3a1399f..3822714 100644 --- a/src/main/java/io/kestra/plugin/dbt/cloud/TriggerRun.java +++ b/src/main/java/io/kestra/plugin/dbt/cloud/TriggerRun.java @@ -6,11 +6,7 @@ import io.kestra.core.models.annotations.PluginProperty; import io.kestra.core.models.tasks.RunnableTask; import io.kestra.core.runners.RunContext; -import io.kestra.core.utils.Await; -import io.kestra.plugin.dbt.ResultParser; -import io.kestra.plugin.dbt.cloud.models.JobStatusHumanizedEnum; import io.kestra.plugin.dbt.cloud.models.RunResponse; -import io.kestra.plugin.dbt.cloud.models.Step; import io.micronaut.core.type.Argument; import io.micronaut.http.HttpMethod; import io.micronaut.http.HttpRequest; @@ -21,20 +17,11 @@ import lombok.experimental.SuperBuilder; import org.slf4j.Logger; -import java.io.IOException; import java.net.URI; -import java.net.URLEncoder; -import java.nio.charset.StandardCharsets; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.StandardOpenOption; import java.time.Duration; import java.util.*; import javax.validation.constraints.NotNull; -import static io.kestra.core.utils.Rethrow.throwSupplier; -import static java.lang.Math.max; - @SuperBuilder @ToString @EqualsAndHashCode @@ -58,11 +45,6 @@ } ) public class TriggerRun extends AbstractDbtCloud implements RunnableTask { - private static final List ENDED_STATUS = List.of( - JobStatusHumanizedEnum.ERROR, - JobStatusHumanizedEnum.CANCELLED, - JobStatusHumanizedEnum.SUCCESS - ); @Schema( title = "Numeric ID of the job" @@ -163,14 +145,6 @@ public class TriggerRun extends AbstractDbtCloud implements RunnableTask loggedStatus = new ArrayList<>(); - - @Builder.Default - @Getter(AccessLevel.NONE) - private transient Map loggedSteps = new HashMap<>(); - @Override public TriggerRun.Output run(RunContext runContext) throws Exception { Logger logger = runContext.logger(); @@ -245,142 +219,24 @@ public TriggerRun.Output run(RunContext runContext) throws Exception { .build(); } - // wait for end - RunResponse finalRunResponse = Await.until( - throwSupplier(() -> { - Optional fetchRunResponse = fetchRunResponse( - runContext, - runId, - false - ); - - if (fetchRunResponse.isPresent()) { - logSteps(logger, fetchRunResponse.get()); - - // we rely on truncated logs to be sure - boolean allLogs = fetchRunResponse.get() - .getData() - .getRunSteps() - .stream() - .filter(step -> step.getTruncatedDebugLogs() != null) - .count() == - fetchRunResponse.get() - .getData() - .getRunSteps().size(); - - // ended - if (ENDED_STATUS.contains(fetchRunResponse.get().getData().getStatusHumanized()) && allLogs) { - return fetchRunResponse.get(); - } - } - - return null; - }), - this.pollFrequency, - this.maxDuration - ); + CheckStatus checkStatusJob = CheckStatus.builder() + .runId(runId) + .token(getToken()) + .accountId(getAccountId()) + .pollFrequency(getPollFrequency()) + .maxDuration(getMaxDuration()) + .parseRunResults(getParseRunResults()) + .build(); - // final response - logSteps(logger, finalRunResponse); - - if (!finalRunResponse.getData().getStatusHumanized().equals(JobStatusHumanizedEnum.SUCCESS)) { - throw new Exception("Failed run with status '" + finalRunResponse.getData().getStatusHumanized() + - "' after " + finalRunResponse.getData().getDurationHumanized() + ": " + finalRunResponse - ); - } - - Path runResultsArtifact = downloadArtifacts(runContext, runId, "run_results.json"); - Path manifestArtifact = downloadArtifacts(runContext, runId, "manifest.json"); - - if (this.parseRunResults) { - ResultParser.parseRunResult(runContext, runResultsArtifact.toFile()); - } + CheckStatus.Output runOutput = checkStatusJob.run(runContext); return Output.builder() .runId(runId) - .runResults(runResultsArtifact.toFile().exists() ? runContext.putTempFile(runResultsArtifact.toFile()) : null) - .manifest(manifestArtifact.toFile().exists() ? runContext.putTempFile(manifestArtifact.toFile()) : null) + .runResults(runOutput.getRunResults()) + .manifest(runOutput.getManifest()) .build(); } - private void logSteps(Logger logger, RunResponse runResponse) { - // status changed - if (!loggedStatus.contains(runResponse.getData().getStatusHumanized())) { - logger.debug("Status changed to '{}' after {}", - runResponse.getData().getStatusHumanized(), - runResponse.getData().getDurationHumanized() - ); - loggedStatus.add(runResponse.getData().getStatusHumanized()); - } - - // log steps - for (Step step : runResponse.getData().getRunSteps()) { - if (!step.getLogs().isEmpty()){ - if (!loggedSteps.containsKey(step.getId())){ - loggedSteps.put(step.getId(), 0); - } - - if (step.getLogs().length() > loggedSteps.get(step.getId())) { - for (String s : step.getLogs().substring(max(loggedSteps.get(step.getId()) -1, 0)).split("\n")) { - logger.info("[Step {}]: {}", step.getName(), s); - } - loggedSteps.put(step.getId(), step.getLogs().length()); - } - } - } - } - - private Optional fetchRunResponse(RunContext runContext, Integer id, Boolean debug) throws IllegalVariableEvaluationException { - return this - .request( - runContext, - HttpRequest - .create( - HttpMethod.GET, - UriTemplate - .of("/api/v2/accounts/{accountId}/runs/{runId}/" + - "?include_related=" + URLEncoder.encode( - "[\"trigger\",\"job\"," + (debug ? "\"debug_logs\"" : "") + ",\"run_steps\", \"environment\"]", - StandardCharsets.UTF_8 - ) - ) - .expand(Map.of( - "accountId", runContext.render(this.accountId), - "runId", id - )) - ), - Argument.of(RunResponse.class) - ) - .getBody(); - } - - private Path downloadArtifacts(RunContext runContext, Integer runId, String path) throws IllegalVariableEvaluationException, IOException { - String artifact = this - .request( - runContext, - HttpRequest - .create( - HttpMethod.GET, - UriTemplate - .of("/api/v2/accounts/{accountId}/runs/{runId}/artifacts/{path}") - .expand(Map.of( - "accountId", runContext.render(this.accountId), - "runId", runId, - "path", path - )) - ), - Argument.of(String.class) - ) - .getBody() - .orElseThrow(); - - Path tempFile = runContext.tempFile(".json"); - - Files.writeString(tempFile, artifact, StandardOpenOption.TRUNCATE_EXISTING); - - return tempFile; - } - @Builder @Getter public static class Output implements io.kestra.core.models.tasks.Output { diff --git a/src/test/java/io/kestra/plugin/dbt/cloud/CheckStatusTest.java b/src/test/java/io/kestra/plugin/dbt/cloud/CheckStatusTest.java new file mode 100644 index 0000000..65a736b --- /dev/null +++ b/src/test/java/io/kestra/plugin/dbt/cloud/CheckStatusTest.java @@ -0,0 +1,63 @@ +package io.kestra.plugin.dbt.cloud; + +import com.google.common.collect.ImmutableMap; +import io.kestra.core.runners.RunContext; +import io.kestra.core.runners.RunContextFactory; +import io.kestra.core.utils.IdUtils; +import io.micronaut.context.annotation.Value; +import io.micronaut.test.extensions.junit5.annotation.MicronautTest; +import jakarta.inject.Inject; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +import java.time.Duration; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; + +@MicronautTest +class CheckStatusTest { + @Inject + private RunContextFactory runContextFactory; + + @Value("${dbt.cloud.account-id}") + private String accountId; + + @Value("${dbt.cloud.token}") + private String token; + + @Value("${dbt.cloud.job-id}") + private String jobId; + + @Test + @Disabled("Trial account can't trigger run through api") + void run() throws Exception { + + RunContext runContext = runContextFactory.of(ImmutableMap.of()); + + TriggerRun task = TriggerRun.builder() + .id(IdUtils.create()) + .type(TriggerRun.class.getName()) + .accountId(this.accountId) + .wait(false) + .token(this.token) + .jobId(this.jobId) + .build(); + + TriggerRun.Output runOutput = task.run(runContext); + + CheckStatus checkStatus = CheckStatus.builder() + .runId(runOutput.getRunId()) + .token(this.token) + .accountId(this.accountId) + .maxDuration(Duration.ofMinutes(60)) + .parseRunResults(false) + .build(); + + CheckStatus.Output checkStatusOutput = checkStatus.run(runContext); + + assertThat(checkStatusOutput, is(notNullValue())); + assertThat(checkStatusOutput.getManifest(), is(notNullValue())); + } +}