# Copyright 2024 Niels Provos
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import inspect
import logging
import threading
import uuid
from abc import ABC, abstractmethod
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
get_type_hints,
)
from pydantic import BaseModel, Field, PrivateAttr
from planai.utils import dict_dump_xml
if TYPE_CHECKING:
from .graph import Graph
from .provenance import ProvenanceChain
TaskStatusCallback = Callable[
[Dict, "ProvenanceChain", "TaskWorker", "Task", Optional[str], Optional[BaseModel]],
None,
]
class Task(BaseModel):
"""Base class for all tasks in the system.
A Task represents a unit of work that can be processed by TaskWorkers. Tasks maintain
their execution provenance and can carry both public and private state.
Attributes:
_provenance (List[Tuple[str, int]]): List of worker name and ID tuples tracking task history
_input_provenance (List[Task]): List of input tasks that led to this task
_private_state (Dict[str, Any]): Private state storage
_retry_count (int): Number of times this task has been retried
_start_time (Optional[float]): When task processing started
_end_time (Optional[float]): When task processing completed
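Example:
    A minimal sketch of defining and instantiating a custom task (the field
    names here are illustrative only)::

        class SearchQuery(Task):
            query: str
            max_results: int = 10

        task = SearchQuery(query="provenance tracking")
        print(task.name)  # "SearchQuery"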
"""
_provenance: List[Tuple[str, int]] = PrivateAttr(default_factory=list)
_input_provenance: List[Task] = PrivateAttr(default_factory=list)
_private_state: Dict[str, Any] = PrivateAttr(default_factory=dict)
_retry_count: int = PrivateAttr(default=0)
_start_time: Optional[float] = PrivateAttr(default=None)
_end_time: Optional[float] = PrivateAttr(default=None)
@property
def name(self) -> str:
return self.__class__.__name__
@property
def retry_count(self) -> int:
"""
Read-only property to access the current retry count.
"""
return self._retry_count
def copy_public(self, deep: bool = False) -> Task:
"""
Creates a copy of the Task instance, excluding private attributes as well as any unset, default, or None fields.
This is a safer way than model_copy() to create a new task from an existing one and can be used in conjunction with enabling strict mode on a graph.
Args:
deep: Whether to perform a deep copy of the public fields.
Returns:
A new Task instance without the private attributes.
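Example:
    A minimal sketch (SearchQuery is a hypothetical Task subclass)::

        original = SearchQuery(query="planai")
        fresh = original.copy_public()
        # fresh carries none of original's provenance or private state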
"""
return self.model_validate(
self.model_dump(
exclude_unset=True, exclude_defaults=True, exclude_none=True
)
)
def increment_retry_count(self) -> None:
"""
Increments the retry count by 1.
"""
self._retry_count += 1
def copy_provenance(self) -> List[Tuple[str, int]]:
return self._provenance.copy()
def prefix(self, length: int) -> "ProvenanceChain":
"""
Get a prefix of specified length from task's provenance chain.
Args:
length (int): The desired length of the prefix to extract.
Returns:
ProvenanceChain: A tuple containing the first 'length' elements of the task's provenance chain.
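Example:
    A minimal sketch (the provenance entries shown are illustrative)::

        # Suppose the task's provenance is [("Fetcher", 1), ("Parser", 4), ("Ranker", 2)]
        task.prefix(2)  # (("Fetcher", 1), ("Parser", 4))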
"""
return tuple(self._provenance[:length])
def _add_worker_provenance(self, worker: "TaskWorker") -> Task:
provenance = worker.get_next_provenance()
self._provenance.append(provenance)
return self
def _add_input_provenance(self, input_task: Optional[Task]) -> Task:
# Copy provenance from input task if provided
if input_task is not None:
self._provenance = input_task.copy_provenance()
self._input_provenance = input_task.copy_input_provenance() + [
input_task.copy_public()
]
# merge private state
self._private_state.update(input_task._private_state)
else:
self._provenance = []
self._input_provenance = []
return self
def add_private_state(self, key: str, value: Any) -> None:
    """Stores a value under the given key in the task's private state."""
    self._private_state[key] = value
def get_private_state(self, key: str) -> Any:
    """Returns the value stored under the given key in the task's private state and removes it, or None if the key is not present."""
    return self._private_state.pop(key, None)
def model_dump_xml(self) -> str:
"""Formats the task as XML."""
return dict_dump_xml(self.model_dump(), root=self.name)
def is_type(self, task_class: Type[Task]) -> bool:
"""
Check if this task is of the specified task class type.
Args:
task_class (Type[Task]): The task class type to check against.
Returns:
bool: True if the task is of the specified type, False otherwise.
"""
return type(self).__name__ == task_class.__name__
class WorkBufferContext:
    """Thread-local context that records the (consumer, task) pairs a worker publishes while processing a single input task."""

    def __init__(self, worker: TaskWorker, input_task=None):
        self.worker: TaskWorker = worker
        self.input_task: Task = input_task
        self.work_buffer: List[Tuple[TaskWorker, Task]] = []
def __enter__(self):
self.worker._local.ctx = self
return self
def __exit__(self, exc_type, exc_value, traceback):
self._flush_work_buffer()
self.worker._local.ctx = None
def get_input_and_outputs(self):
return self.input_task, self.work_buffer
def _flush_work_buffer(self):
self.work_buffer.clear()
def add_to_buffer(self, consumer: "TaskWorker", task: Task):
self.work_buffer.append((consumer, task))
class TaskWorker(BaseModel, ABC):
"""Base class for all task workers.
TaskWorker implements the core task processing functionality. Workers consume tasks,
process them, and can produce new tasks for downstream workers. The system ensures
type safety between workers and maintains execution provenance.
Attributes:
output_types (List[Type[Task]]): Types of tasks this worker can produce
num_retries (int): Number of times to retry failed tasks
_id (int): Internal worker ID counter
_consumers (Dict[Type[Task], List[TaskWorker]]): Registered downstream consumers
_graph (Optional[Graph]): Reference to containing workflow graph
_instance_id (UUID): Unique worker instance identifier
_local (threading.local): Thread-local storage
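Example:
    A minimal sketch of a custom worker (SearchQuery and SearchResult are
    hypothetical Task subclasses)::

        class SearchWorker(TaskWorker):
            output_types: List[Type[Task]] = [SearchResult]

            def consume_work(self, task: SearchQuery):
                result = SearchResult(urls=["https://example.com"])
                self.publish_work(result, input_task=task)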
"""
output_types: List[Type[Task]] = Field(default_factory=list)
num_retries: int = Field(default=0)
_state_lock: threading.RLock = PrivateAttr(default_factory=threading.RLock)
_id: int = PrivateAttr(default=0)
_consumers: Dict[Type[Task], List[TaskWorker]] = PrivateAttr(default_factory=dict)
_graph: Optional["Graph"] = PrivateAttr(default=None)
_instance_id: uuid.UUID = PrivateAttr(default_factory=uuid.uuid4)
_local: threading.local = PrivateAttr(default_factory=threading.local)
_strict_checking: bool = PrivateAttr(default=False)
# Allows a TaskWorker implementation to associate state with a task.
# This is useful for tracking the state of a task across multiple calls to
# consume_work, for example when there are circular dependencies in the graph.
_user_state: Dict["ProvenanceChain", Dict[str, Any]] = PrivateAttr(
default_factory=dict
)
def __init__(self, **data):
super().__init__(**data)
def __hash__(self):
return hash(self._instance_id)
def __eq__(self, other):
if isinstance(other, TaskWorker):
return self._instance_id == other._instance_id
return False
def work_buffer_context(self, input_task):
if input_task is None:
raise ValueError("Input task cannot be None")
return WorkBufferContext(self, input_task)
@property
def name(self) -> str:
"""
Returns the name of this worker class.
"""
return self.__class__.__name__
@property
def lock(self) -> threading.RLock:
"""
Returns the lock object for this worker.
:return: The lock object.
:rtype: threading.RLock
"""
return self._state_lock
def set_graph(self, graph: "Graph"):
self._graph = graph
def next(self, downstream: "TaskWorker"):
"""
Sets a dependency between this worker and the downstream worker.
Parameters:
downstream (TaskWorker): The downstream worker to set as a dependency.
Returns:
TaskWorker: The downstream worker, so that calls can be chained.
Raises:
ValueError: If this worker has not been added to a Graph before setting dependencies.
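Example:
    Because next() returns the downstream worker, dependencies can be chained
    (fetcher, parser, and ranker are illustrative workers already added to
    the graph)::

        fetcher.next(parser).next(ranker)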
"""
if self._graph is None:
raise ValueError(
"Task must be added to a Graph before setting dependencies"
)
self._graph.set_dependency(self, downstream)
return downstream
def sink(
self,
output_type: Type[Task],
notify: Optional[Callable[[Dict[str, Any], Task], None]] = None,
):
"""
Designates the current task worker as a sink in the associated graph.
This method marks the current task worker as a sink, which means its output
will be collected and can be retrieved after the graph execution.
Parameters:
output_type (Type[Task]): The type of task to send to the sink.
notify: Optional callback function to be called when the sink is executed. It will receive
any metadata associated with the task and the task itself.
Raises:
ValueError: If the task worker is not associated with a graph.
Note:
- Only one sink can be set per graph. Attempting to set multiple sinks
will raise a RuntimeError from the graph's set_sink method.
- The task worker must have exactly one output type to be eligible as a sink.
- Results from the sink can be retrieved using the graph's get_output_tasks() method
after the graph has been executed.
See Also:
Graph.set_sink(): The underlying method called to set the sink.
Graph.get_output_tasks(): Method to retrieve results from the sink after graph execution.
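Example:
    A minimal sketch (FinalResult is a hypothetical Task subclass produced by
    this worker)::

        def on_result(metadata: Dict[str, Any], task: Task):
            print("sink received:", task)

        worker.sink(FinalResult, notify=on_result)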
"""
if self._graph is None:
raise ValueError(
"Task must be added to a Graph before setting a sink dependency"
)
self._graph.set_sink(self, output_type, notify)
def trace(self, prefix: "ProvenanceChain"):
"""
Traces the provenance chain for a given prefix in the graph.
This method sets up a trace on a given prefix in the provenance chain. It will be visible
in the dispatcher dashboard.
Parameters:
-----------
prefix : ProvenanceChain
The prefix to trace. Must be a tuple representing a part of a task's provenance chain.
This is the sequence of task identifiers leading up to (but not including) the current task.
"""
assert self._graph is not None
self._graph.trace(prefix)
def watch(self, prefix: "ProvenanceChain") -> bool:
"""
Watches for the completion of a specific provenance chain prefix in the task graph.
This method sets up a watch on a given prefix in the provenance chain. The worker will be notified
via its notify method when this prefix is no longer part of any active task's provenance, indicating
that all tasks with this prefix have been completed.
Parameters:
-----------
prefix : ProvenanceChain
The prefix to watch. Must be a tuple representing a part of a task's provenance chain.
This is the sequence of task identifiers leading up to (but not including) the current task.
Returns:
--------
bool
True if the watch was successfully added for the given prefix.
False if a watch for this prefix was already present.
Raises:
-------
ValueError
If the provided prefix is not a tuple.
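Example:
    A minimal sketch: watch a prefix from consume_work and react in notify
    (the override should call the base class implementation when done)::

        class Collector(TaskWorker):
            def consume_work(self, task: Task):
                self.watch(task.prefix(1))

            def notify(self, prefix: "ProvenanceChain"):
                self.print("all tasks under", prefix, "are done")
                super().notify(prefix)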
"""
if not isinstance(prefix, tuple):
raise ValueError("Prefix must be a tuple")
assert self._graph is not None
return self._graph.watch(prefix, self)
def unwatch(self, prefix: "ProvenanceChain") -> bool:
"""
Removes a previously registered watch for the given provenance prefix.
Parameters:
prefix (ProvenanceChain): The prefix to stop watching. Must be a tuple.
Returns:
True if the watch was removed, False if the watch was not present.
"""
if not isinstance(prefix, tuple):
raise ValueError("Prefix must be a tuple")
assert self._graph is not None
return self._graph.unwatch(prefix, self)
def get_worker_state(self, provenance: "ProvenanceChain") -> Dict[str, Any]:
"""
Allows a worker to store state for a specific provenance chain.
This is helpful when we expect the worker to be called multiple times with the same provenance chain.
For example, this can happen when there are circular dependencies in the graph. The most common case
is when a worker needs to ask for more data from upstream workers and sends a task back to them.
The state will be cleaned up automatically when the provenance chain is no longer active in the graph.
Returns:
Dict[str, Any]: The mutable state dictionary associated with the given provenance chain.
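Example:
    A minimal sketch from inside consume_work (the state key is illustrative)::

        state = self.get_worker_state(task.prefix(1))
        state["attempts"] = state.get("attempts", 0) + 1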
"""
with self.lock:
if provenance in self._user_state:
return self._user_state[provenance]
self.watch(provenance)
self._user_state[provenance] = {}
return self._user_state[provenance]
def print(self, *args):
"""
Prints a message to the console.
Parameters:
*args: The message to print.
"""
assert self._graph is not None
self._graph.print(*args)
def get_next_provenance(self) -> Tuple[str, int]:
"""
Gets the next provenance tuple for this worker.
Returns:
Tuple[str, int]: The next provenance tuple.
"""
with self.lock:
self._id += 1
return (self.name, self._id)
def remove_state(self, task: Task):
"""
Remove the state for a task.
Args:
task (Task): The task to remove the state for.
"""
provenance = task.prefix(1)
assert self._graph is not None
self._graph._provenance_tracker.remove_state(provenance)
def get_state(self, task: Task) -> Dict[str, Any]:
"""
Get the state of a task.
Parameters:
task (Task): The task to get the state for.
Returns:
Dict[str, Any]: The state of the task.
"""
provenance = task.prefix(1)
assert self._graph is not None
return self._graph._provenance_tracker.get_state(provenance)
def add_work(
self,
task: Task,
metadata: Optional[Dict] = None,
status_callback: Optional[TaskStatusCallback] = None,
) -> "ProvenanceChain":
"""
Adds a new top-level work item to the graph, using this worker as the entry point.
Returns:
ProvenanceChain: The provenance chain assigned to the submitted task.
Raises:
RuntimeError: If the worker is not associated with a graph.
"""
if self._graph is None:
    raise RuntimeError("Graph is not initialized.")
return self._graph.add_work(self, task, metadata, status_callback)
def notify_status(
self,
task: Task,
message: Optional[str] = None,
object: Optional[BaseModel] = None,
):
"""Notify registered callback about task status updates."""
assert self._graph is not None
self._graph._provenance_tracker.notify_status(self, task, message, object)
def _pre_consume_work(self, task: Task):
with self.work_buffer_context(task):
self.consume_work(task)
def init(self):
"""
Called when the graph is fully constructed and starts work.
"""
pass
@abstractmethod
def consume_work(self, task: Task):
"""
Abstract method to consume a work item.
This method must be implemented by subclasses to define specific work consumption logic. It needs to be thread-safe
as it may be called concurrently by multiple threads.
Args:
task (Task): The work item to be consumed.
"""
pass
def publish_work(
self,
task: Task,
input_task: Optional[Task],
consumer: Optional[TaskWorker] = None,
):
"""
Publish a work item.
This method handles the publishing of work items, including provenance tracking and consumer routing.
It is important that task is a newly created object and not a reference to an existing task. You can
use copy_public() (or model_copy()) to create a new object with the same data.
Args:
task (Task): The work item to be published.
input_task (Task): The input task that led to this work item.
consumer (TaskWorker): The TaskWorker to publish to if there are multiple consumers for the task type.
Raises:
ValueError: If the task type is not in the output_types or if no consumer is registered for the task type.
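Example:
    A minimal sketch from inside consume_work (SearchResult and PageTask are
    hypothetical Task subclasses; PageTask must be listed in this worker's
    output_types)::

        def consume_work(self, task: SearchResult):
            for url in task.urls:
                self.publish_work(PageTask(url=url), input_task=task)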
"""
if type(task) not in self.output_types:
raise ValueError(
f"Task {self.name} cannot publish work of type {type(task).__name__}"
)
if self._strict_checking and (task._provenance or task._input_provenance):
raise ValueError(
"Cannot publish a task that has already been published. Use copy_public() to create a new task."
)
# The order of these calls matters: _add_input_provenance overwrites any existing provenance before the worker's own entry is appended.
task._add_input_provenance(input_task)
task._add_worker_provenance(self)
# find the consumer for this task to publish to
consumer = self._get_consumer(task, worker=consumer)
logging.info(
"Worker %s publishing work to consumer %s with task type %s and provenance %s",
self.name,
consumer.name,
task.__class__.__name__,
task._provenance,
)
if self._graph and self._graph._dispatcher:
logging.info(
"Worker %s publishing work to buffer with consumer %s",
self.name,
consumer.name,
)
self._graph._dispatcher.add_work(consumer, task)
else:
self._dispatch_work(task)
# this requires that anything that might call publish_work is wrapped in a work_buffer_context
self._local.ctx.add_to_buffer(consumer, task)
def _get_consumer_by_name(self, task: Task, worker_name: str) -> TaskWorker:
# Verify if there is a consumer for the given task class
consumers = self._consumers.get(task.__class__)
if not consumers:
logging.error(
"%s: No consumer registered for %s, available consumers: %s",
self.name,
task.__class__.__name__,
[c.name for consumers in self._consumers.values() for c in consumers],
)
raise ValueError(f"No consumer registered for {task.__class__.__name__}")
for consumer in consumers:
if consumer.name == worker_name:
return consumer
raise ValueError(
f"No consumer registered for {task.__class__.__name__} with name {worker_name}"
)
def _get_consumer(
self, task: Task, worker: Optional[TaskWorker] = None
) -> TaskWorker:
# Verify if there is a consumer for the given task class
consumers = self._consumers.get(task.__class__)
if not consumers:
logging.error(
"%s: No consumer registered for %s, available consumers: %s",
self.name,
task.__class__.__name__,
[c.name for consumers in self._consumers.values() for c in consumers],
)
raise ValueError(f"No consumer registered for {task.__class__.__name__}")
if len(consumers) == 1:
if worker and consumers[0] != worker:
raise ValueError(
f"Worker {worker.name} is not a registered consumer for {task.__class__.__name__}"
)
return consumers[0]
if worker is None:
raise ValueError(
f"Multiple consumers registered for {task.__class__.__name__}, specify worker_name"
)
if worker not in consumers:
raise ValueError(
f"Worker {worker.name} is not a registered consumer for {task.__class__.__name__}"
)
return worker
def completed(self):
"""Called to let the worker know that it has finished processing all work."""
pass
def notify(self, prefix: "ProvenanceChain"):
"""Called to notify the worker that no tasks with this provenance prefix are remaining.
Children implementing this method need to call the base class method to ensure that the
state is fully removed.
"""
with self.lock:
if prefix in self._user_state:
del self._user_state[prefix]
logging.info("Removing watch for %s in %s", prefix, self.name)
# remove the watch - which may already have been removed by the child class
self.unwatch(prefix)
def _dispatch_work(self, task: Task):
consumer: Optional[TaskWorker] = self._get_consumer(task)
assert consumer is not None
consumer.consume_work(task)
def validate_task(
self, task_cls: Type[Task], consumer: TaskWorker
) -> Tuple[bool, Optional[BaseException]]:
"""
Validate that a consumer can handle a specific Task type.
This method checks if the consumer has a properly typed consume_work method for the given task class.
Args:
task_cls (Type[Task]): The Task subclass to validate.
consumer (TaskWorker): The consumer to validate against.
Returns:
Tuple[bool, Optional[BaseException]]: A tuple containing a boolean indicating success and an exception if validation failed.
"""
# Ensure consumer has a consume_work method taking task_cls as parameter
first_param_type = consumer.get_task_class()
if first_param_type is not task_cls:
return False, TypeError(
f"TaskWorker {consumer.__class__.__name__} cannot consume tasks of type {task_cls.__name__}. It can only consume tasks of type {first_param_type.__name__}"
)
return True, None
def get_task_class(self) -> Type[Task]:
"""
Get the Task subclass that this worker can consume.
This method inspects the type annotation of the task parameter in the consume_work method.
Returns:
Type[Task]: The Task subclass this worker can consume.
Raises:
AttributeError: If the consume_work method is not defined.
TypeError: If consume_work does not accept exactly one parameter or lacks a type hint for it.
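Example:
    A minimal sketch (SearchQuery is a hypothetical Task subclass)::

        class SearchWorker(TaskWorker):
            def consume_work(self, task: SearchQuery):
                ...

        SearchWorker().get_task_class()  # returns SearchQuery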
"""
consume_method = getattr(self, "consume_work", None)
if not consume_method:
raise AttributeError(
f"{self.__class__.__name__} has no method consume_work"
)
signature = inspect.signature(consume_method)
parameters = signature.parameters
if len(parameters) != 1:
raise TypeError(
f"Method consume_work in {self.__class__.__name__} must accept one parameter"
)
type_hints = get_type_hints(consume_method)
first_param_type = type_hints.get("task", None)
if not first_param_type:
raise TypeError(
f"consume_work method in {self.__class__.__name__} must have type hints"
)
return first_param_type
def register_consumer(self, task_cls: Type[Task], consumer: "TaskWorker"):
"""
Register a consumer for a specific Task type.
This method performs type checking to ensure that the consumer can handle the specified Task type.
Args:
task_cls (Type[Task]): The Task subclass to register a consumer for.
consumer (TaskWorker): The consumer to register.
Raises:
TypeError: If task_cls is not a subclass of Task or if the consumer cannot handle the task type.
ValueError: If the task type is not in the output_types or if a consumer is already registered for the task type.
"""
# Ensure task_cls is a subclass of Task
if not issubclass(task_cls, Task):
raise TypeError(f"{task_cls.__name__} is not a subclass of Task")
success, error = self.validate_task(task_cls, consumer)
if not success:
assert error is not None
raise error
if task_cls not in self.output_types:
raise ValueError(
f"Downstream consumer {consumer.name} only accepts work of type {task_cls.__name__} but Worker {self.name} does not produce it"
)
if task_cls in self._consumers and consumer in self._consumers[task_cls]:
raise ValueError(f"Consumer for {task_cls.__name__} already registered")
if task_cls not in self._consumers:
self._consumers[task_cls] = []
self._consumers[task_cls].append(consumer)
# Special cases like JoinedTaskWorker need to validate that their join_type is upstream.
consumer._validate_connection()
def _validate_connection(self):
pass