diff --git a/anyscale_provider/operators/anyscale.py b/anyscale_provider/operators/anyscale.py index 98f8923..d8126cb 100644 --- a/anyscale_provider/operators/anyscale.py +++ b/anyscale_provider/operators/anyscale.py @@ -123,7 +123,7 @@ def hook(self) -> AnyscaleHook: """Return an instance of the AnyscaleHook.""" return AnyscaleHook(conn_id=self.conn_id) - def execute(self, context: Context) -> None: + def execute(self, context: Context) -> str: """ Execute the job submission to Anyscale. @@ -180,8 +180,9 @@ def execute(self, context: Context) -> None: ) else: raise Exception(f"Unexpected state `{current_state}` for job_id `{self.job_id}`.") + return self.job_id - def execute_complete(self, context: Context, event: Any) -> None: + def execute_complete(self, context: Context, event: Any) -> str: """ Complete the execution of the job based on the trigger event. @@ -192,7 +193,7 @@ def execute_complete(self, context: Context, event: Any) -> None: :param event: The event data from the trigger. :return: None """ - current_job_id = event["job_id"] + current_job_id: str = event["job_id"] if event["state"] == JobState.FAILED: self.log.info(f"Anyscale job {current_job_id} ended with state: {event['state']}") @@ -200,6 +201,8 @@ def execute_complete(self, context: Context, event: Any) -> None: else: self.log.info(f"Anyscale job {current_job_id} completed with state: {event['state']}") + return current_job_id + class RolloutAnyscaleService(BaseOperator): """ diff --git a/tests/operators/test_anyscale_operators.py b/tests/operators/test_anyscale_operators.py index 59cfcc8..62f5eaf 100644 --- a/tests/operators/test_anyscale_operators.py +++ b/tests/operators/test_anyscale_operators.py @@ -51,7 +51,7 @@ def test_on_kill(self, mock_hook): @patch("anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook") def test_execute_complete(self, mock_hook): event = {"state": JobState.SUCCEEDED, "job_id": "123", "message": "Job completed successfully"} - self.assertEqual(self.operator.execute_complete(Context(), event), None) + self.assertEqual(self.operator.execute_complete(Context(), event), "123") @patch("anyscale_provider.operators.anyscale.SubmitAnyscaleJob.hook") def test_execute_complete_failure(self, mock_hook):