# Source code for planai.graph_task

# Copyright 2024 Niels Provos
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, List, Type

from pydantic import Field, PrivateAttr

from .graph import Graph
from .provenance import ProvenanceChain
from .task import Task, TaskWorker

# Private-state key under which SubGraphWorkerInternal.consume_work stores a
# deep copy of the original (outer-graph) task on the task it sends into the
# sub-graph; the adapter sink reads it back to rebuild the output's provenance.
PRIVATE_STATE_KEY = "_graph_task_private_state"


class SubGraphWorkerInternal(TaskWorker):
    """TaskWorker that runs a whole sub-graph as one step of an enclosing graph.

    Incoming tasks are copied and dispatched to ``entry_worker`` inside
    ``graph``. An adapter sink worker, attached behind ``exit_worker`` in
    ``model_post_init``, takes every task the exit worker produces, rewrites
    its provenance so that this worker appears as the producer, and hands it
    to the dispatcher for the consumer registered in the enclosing graph.

    While a task is inside the sub-graph, an extra provenance reference is held
    on a copy of the original task; it is released exactly once, either by the
    adapter sink or by ``notify``.
    """

    graph: Graph = Field(
        ..., description="The graph that will be run as part of this TaskWorker"
    )
    entry_worker: TaskWorker = Field(..., description="The entry point of the graph")
    exit_worker: TaskWorker = Field(..., description="The exit point of the graph")
    # Bookkeeping for in-flight sub-graph runs. Each value is a tuple
    # (copy_of_original_task, extra_provenance_still_held) written by
    # consume_work's inject_state callback.
    # NOTE(review): keys are the ProvenanceChain passed to inject_state, not
    # str as the annotation suggests - confirm and tighten the annotation.
    _state: Dict[str, Any] = PrivateAttr(default_factory=dict)

    def model_post_init(self, _: Any):
        """Validate the exit worker and wire an adapter sink behind it.

        Sets this worker's ``output_types`` to the exit worker's single output
        type, then adds an ``AdapterSinkWorker`` instance to the sub-graph as a
        consumer of ``exit_worker``.

        Raises
        ------
        ValueError
            If ``exit_worker`` does not declare exactly one output type.
        """
        if len(self.exit_worker.output_types) != 1:
            raise ValueError(
                f"Exit worker must have exactly one output type, got {self.exit_worker.output_types}"
            )

        output_type = self.exit_worker.output_types[0]
        # captured by the nested class below so it can reach this worker
        graph_task = self
        self.output_types = [output_type]

        class AdapterSinkWorker(TaskWorker):
            """Sink placed after the exit worker that moves results to the outer graph."""

            output_types: List[Type[Task]] = [output_type]

            def consume_work(self, task: output_type):
                """Re-parent ``task`` onto the enclosing worker and dispatch it onward.

                Raises
                ------
                ValueError
                    If the task carries no saved original task, or if no state
                    was recorded for its sub-graph provenance prefix.
                """
                # we need to move this task from the sub-graph to the main graph
                # first, we need to fix up the provenance so that it shows only the GraphTask as the parent
                old_task = task.get_private_state(PRIVATE_STATE_KEY)
                if old_task is None:
                    raise ValueError(
                        f"No task provenance found for {PRIVATE_STATE_KEY}"
                    )
                assert isinstance(task, Task)

                # remove any metadata and callbacks before changing the provenance
                self.remove_state(task)

                # remember the provenance of the task
                # (captured before the provenance is rewritten below; used as
                # the lookup key into graph_task._state further down)
                provenance = task.prefix(1)

                task._add_input_provenance(old_task)
                task._add_worker_provenance(graph_task)

                # then we can add it to the main graph using the right consumer
                assert graph_task.exit_worker._graph is not None
                assert graph_task.exit_worker._graph._dispatcher is not None
                consumer = graph_task._get_consumer(task)
                graph_task.exit_worker._graph._dispatcher.add_work(consumer, task)

                # we need to make sure that we remove the extra provenance only once
                # so we associated a refcount with the new task that was injected into
                # the sub-graph and then get a reference back to it here
                with graph_task.lock:
                    if provenance not in graph_task._state:
                        raise ValueError(
                            f"Task {provenance} does not have any associated state."
                        )
                    logging.debug(
                        "Subgraph is removing provenance for %s in %s",
                        provenance,
                        self.name,
                    )
                    # NOTE: rebinds `task` to the stored copy of the original
                    # outer-graph task; the sub-graph output was already
                    # dispatched above.
                    task, remove_provenance = graph_task._state.get(provenance)
                    if remove_provenance:
                        # NOTE(review): `self` is the AdapterSinkWorker here,
                        # whereas notify() passes the SubGraphWorkerInternal
                        # instance - confirm the worker argument does not
                        # affect which provenance entry is released.
                        graph_task._graph._provenance_tracker._remove_provenance(
                            task, self
                        )
                        # clear the flag so later outputs for the same prefix
                        # (and notify) do not release it a second time
                        graph_task._state[provenance] = (task, False)

        instance = AdapterSinkWorker()
        self.graph.add_workers(instance)
        self.graph.set_dependency(self.exit_worker, instance)

    def get_task_class(self) -> Type[Task]:
        """Return the task class accepted by the sub-graph's entry worker."""
        # usually the entry task gets dynamically determined from consume_work but we are overriding it here
        return self.entry_worker.get_task_class()

    def init(self):
        """Share the enclosing graph's dispatcher with the sub-graph and init its workers."""
        # we need to install the graph dispatcher into the sub-graph
        assert self._graph is not None
        self.graph._dispatcher = self._graph._dispatcher
        self.graph.init_workers()

    def consume_work(self, task: Task):
        """Send a copy of ``task`` into the sub-graph through ``entry_worker``.

        A deep copy of the original task is attached to the new task as
        private state and registered with the enclosing graph's provenance
        tracker. Any metadata and status callback associated with ``task`` are
        forwarded to the sub-graph dispatch.
        """
        new_task = task.copy_public()

        # save the task provenance
        # xxx - we really just need to remember the provenance of the task
        old_task = task.model_copy(deep=True)
        new_task.add_private_state(PRIVATE_STATE_KEY, old_task)

        # artificially increase the provenance
        logging.debug("Adding additional provenance for %s", task._provenance)
        self._graph._provenance_tracker._add_provenance(old_task)

        # get any associated state and re-inject it
        state = self.get_state(task)
        metadata = state["metadata"]
        callback = state["callback"]

        def inject_state(provenance: ProvenanceChain):
            """Watch ``provenance`` and record that its extra provenance is outstanding."""
            self.graph.watch(provenance, self)
            # We inject True to indicate that extra provenance is still associated with the task
            logging.debug("Injecting state for %s in %s", provenance, self.name)
            with self.lock:
                self._state[provenance] = (old_task, True)

        # and dispatch it to the sub-graph. this also sets the task provenance to InitialTaskWorker
        self.graph._add_work(
            self.entry_worker,
            new_task,
            metadata=metadata,
            status_callback=callback,
            provenance_callback=inject_state,
        )

    def notify(self, prefix: str):
        """Drop the state for a watched sub-graph prefix and release leftover provenance.

        NOTE(review): ``prefix`` is annotated ``str`` but is used as a key into
        ``_state``, which ``inject_state`` fills with ``ProvenanceChain`` keys -
        confirm the intended type.

        Raises
        ------
        ValueError
            If no state is recorded for ``prefix``.
        """
        # if tasks of the sub-graph fail, it's possible that the SinkWorker never gets called
        # so we need to remove additional provenance here and clean up state
        logging.debug("Removing state for %s in %s", prefix, self.name)
        self.graph.unwatch(prefix, self)
        with self.lock:
            if prefix not in self._state:
                raise ValueError(f"Task {prefix} does not have any associated state.")
            task, remove_provenance = self._state.pop(prefix)

        # The flag is still True only when the adapter sink never cleared it,
        # i.e. no output reached the sink for this prefix.
        if remove_provenance:
            logging.info(
                "Caught Subgraph execution error. Removing provenance for %s in %s",
                prefix,
                self.name,
            )
            self._graph._provenance_tracker._remove_provenance(task, self)

    def abort_work(self, provenance: ProvenanceChain):
        """Abort sub-graph work spawned from tasks whose provenance starts with ``provenance``.

        Each tracked entry's saved original task is compared against
        ``provenance`` by prefix; for every match the corresponding sub-graph
        prefix is aborted on ``self.graph``.
        """
        # map the provenance to the sub-graph provenance
        need_to_abort = []
        # collect matches under the lock; the aborts themselves run after it
        # has been released
        with self.lock:
            for prefix, (task, _) in self._state.items():
                if task._provenance[: len(provenance)] == list(provenance):
                    need_to_abort.append((prefix, provenance))
        # abort the mapped provenance in our graph
        for prefix, provenance in need_to_abort:
            logging.info(
                "Aborting %s in %s (mapped from %s)", prefix, self.name, provenance
            )
            self.graph.abort_work(prefix)


def SubGraphWorker(
    *,
    graph: Graph,
    entry_worker: TaskWorker,
    exit_worker: TaskWorker,
    name: str = "SubGraphWorker",
) -> SubGraphWorkerInternal:
    """
    Factory function to create a SubGraphWorker that manages a subgraph within
    a larger PlanAI graph.

    Parameters
    ----------
    name : str, optional
        Custom name for the SubGraphWorker class, defaults to "SubGraphWorker"
    graph : Graph
        The graph that will be run as part of this TaskWorker
    entry_worker : TaskWorker
        The entry point worker of the graph that receives initial tasks
    exit_worker : TaskWorker
        The exit point worker of the graph that produces final outputs.
        Must have exactly one output type

    Returns
    -------
    SubGraphWorkerInternal
        A new instance of a dynamically created SubGraphWorkerInternal
        subclass named ``name``, with the specified configuration

    Raises
    ------
    ValueError
        If the exit_worker does not have exactly one output type
    """
    # Create a new class with the custom name so that each sub-graph worker
    # carries its own class name instead of the shared internal one.
    CustomClass = type(name, (SubGraphWorkerInternal,), {})
    return CustomClass(graph=graph, entry_worker=entry_worker, exit_worker=exit_worker)
def main():
    """Run a small demo: a two-worker sub-graph feeding a final consumer.

    Pass ``--run-dashboard`` to run the web dashboard instead of the terminal
    display; in that mode Task2Worker also occasionally requests user input.
    """
    import argparse
    import random
    import time
    from typing import Type

    parser = argparse.ArgumentParser(description="Simple Graph Example")
    parser.add_argument(
        "--run-dashboard", action="store_true", help="Run the web dashboard"
    )
    args = parser.parse_args()

    # Define custom Task classes
    class Task1WorkItem(Task):
        data: str

    class Task2WorkItem(Task):
        processed_data: str

    class Task3WorkItem(Task):
        final_result: str

    # Define custom TaskWorker classes
    class Task1Worker(TaskWorker):
        output_types: List[Type[Task]] = [Task2WorkItem]

        def consume_work(self, task: Task1WorkItem):
            self.print(f"Task1 consuming: {task.data}")
            time.sleep(random.uniform(0.2, 0.9))
            for i in range(7):
                processed = f"Processed: {task.data.upper()} at iteration {i}"
                self.publish_work(
                    Task2WorkItem(processed_data=processed), input_task=task
                )

    class Task2Worker(TaskWorker):
        output_types: List[Type[Task]] = [Task3WorkItem]

        def consume_work(self, task: Task2WorkItem):
            self.print(f"Task2 consuming: {task.processed_data}")
            time.sleep(random.uniform(0.3, 2.5))

            if args.run_dashboard:
                # demonstrate the ability to request user input
                if random.random() < 0.15:
                    result, mime_type = self.request_user_input(
                        task=task,
                        instruction="Please provide a value",
                        accepted_mime_types=["text/html", "application/pdf"],
                    )
                    self.print(
                        f"User input: {len(result) if result else None} ({mime_type})"
                    )

            # Results are published whether or not the dashboard is running.
            for i in range(11):
                final = f"Final: {task.processed_data} at iteration {i}!"
                self.publish_work(Task3WorkItem(final_result=final), input_task=task)

    class Task3Worker(TaskWorker):
        output_types: List[Type[Task]] = []

        def consume_work(self, task: Task3WorkItem):
            self.print(f"Task3 consuming: {task.final_result}")
            time.sleep(random.uniform(0.4, 1.2))
            self.print("Workflow complete!")

    # Create Graph
    sub_graph = Graph(name="Simple SubGraph")

    # Create tasks
    task1 = Task1Worker()
    task2 = Task2Worker()

    # Add tasks to Graph
    sub_graph.add_workers(task1, task2)

    # Set dependencies
    sub_graph.set_dependency(task1, task2)

    # Create the graph task
    graph_task = SubGraphWorker(
        name="SubGraph", graph=sub_graph, entry_worker=task1, exit_worker=task2
    )

    # Create the final consumer
    task3 = Task3Worker()

    graph = Graph(name="Simple Graph")
    graph.add_workers(graph_task, task3)
    graph.set_dependency(graph_task, task3)

    # Prepare initial work item
    initial_work = [
        (graph_task, Task1WorkItem(data="Hello, Graph v1!")),
        (graph_task, Task1WorkItem(data="Hello, Graph v2!")),
    ]

    # Run the Graph
    graph.run(
        initial_work,
        run_dashboard=args.run_dashboard,
        display_terminal=not args.run_dashboard,
    )


if __name__ == "__main__":
    main()