Skip to content

Commit

Permalink
Improve docstrings in "Task/braket.py"
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhusion committed May 31, 2024
1 parent aada63a commit d70d1a2
Showing 1 changed file with 118 additions and 12 deletions.
130 changes: 118 additions & 12 deletions src/bloqade/task/braket.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
"""
Module for managing Braket tasks in the bloqade framework.
This module defines the BraketTask class, which represents a task that can be submitted
to a Braket backend. It includes methods for task submission, validation, fetching results,
checking status, and cancellation. Additionally, serialization and deserialization
functions are provided for the BraketTask class.
"""

import warnings
from dataclasses import dataclass, field
from beartype.typing import Dict, Optional, Any

from bloqade.builder.base import ParamType
from bloqade.serialize import Serializer
from bloqade.submission.ir.parallel import ParallelDecoder
Expand All @@ -7,9 +20,6 @@

from bloqade.submission.base import ValidationError
from bloqade.submission.ir.task_results import QuEraTaskResults, QuEraTaskStatusCode
import warnings
from dataclasses import dataclass, field
from beartype.typing import Dict, Optional, Any


## keep the old conversion for now,
Expand All @@ -18,6 +28,18 @@
@dataclass
@Serializer.register
class BraketTask(RemoteTask):
"""
Represents a Braket Task which can be submitted to a Braket backend.
Attributes:
task_id (Optional[str]): The ID of the task.
backend (BraketBackend): The backend to which the task is submitted.
task_ir (QuEraTaskSpecification): The task specification.
metadata (Dict[str, ParamType]): Metadata associated with the task.
parallel_decoder (Optional[ParallelDecoder]): Parallel decoder for the task.
task_result_ir (QuEraTaskResults): The result of the task.
"""

task_id: Optional[str]
backend: BraketBackend
task_ir: QuEraTaskSpecification
Expand All @@ -30,18 +52,34 @@ class BraketTask(RemoteTask):
)

def submit(self, force: bool = False) -> "BraketTask":
"""
Submits the task to the backend.
Args:
force (bool): Whether to force submission even if the task is already submitted.
Returns:
BraketTask: The current task instance.
Raises:
ValueError: If the task is already submitted and force is False.
"""
if not force:
if self.task_id is not None:
raise ValueError(
"the task is already submitted with %s" % (self.task_id)
)
raise ValueError(f"the task is already submitted with {self.task_id}")
self.task_id = self.backend.submit_task(self.task_ir)

self.task_result_ir = QuEraTaskResults(task_status=QuEraTaskStatusCode.Enqueued)

return self

def validate(self) -> str:
"""
Validates the task specification.
Returns:
str: An empty string if validation is successful,otherwise the validation error message.
"""
try:
self.backend.validate_task(self.task_ir)
except ValidationError as e:
Expand All @@ -50,7 +88,15 @@ def validate(self) -> str:
return ""

def fetch(self) -> "BraketTask":
# non-blocking, pull only when its completed
"""
Fetches the task results if the task is completed.
Returns:
BraketTask: The current task instance.
Raises:
ValueError: If the task is not yet submitted.
"""
if self.task_result_ir.task_status is QuEraTaskStatusCode.Unsubmitted:
raise ValueError("Task ID not found.")

Expand All @@ -72,7 +118,15 @@ def fetch(self) -> "BraketTask":
return self

def pull(self) -> "BraketTask":
# blocking, force pulling, even its completed
"""
Forces pulling the task results.
Returns:
BraketTask: The current task instance.
Raises:
ValueError: If the task ID is not found.
"""
if self.task_id is None:
raise ValueError("Task ID not found.")

Expand All @@ -81,8 +135,12 @@ def pull(self) -> "BraketTask":
return self

def result(self) -> QuEraTaskResults:
# blocking, caching
"""
Gets the task results, blocking until results are available.
Returns:
QuEraTaskResults: The task results.
"""
if self.task_result_ir is None:
pass
else:
Expand All @@ -95,12 +153,27 @@ def result(self) -> QuEraTaskResults:
return self.task_result_ir

def status(self) -> QuEraTaskStatusCode:
"""
Gets the status of the task.
Returns:
QuEraTaskStatusCode: The status of the task.
"""
if self.task_id is None:
return QuEraTaskStatusCode.Unsubmitted

return self.backend.task_status(self.task_id)

def cancel(self) -> None:
"""
Cancels the task if it is currently submitted.
Returns:
None
Raises:
Warning: If the task ID is not found.
"""
if self.task_id is None:
warnings.warn("Cannot cancel task, missing task id.")
return
Expand All @@ -109,16 +182,34 @@ def cancel(self) -> None:

@property
def nshots(self):
"""
Gets the number of shots specified for the task.
Returns:
int: The number of shots.
"""
return self.task_ir.nshots

def _geometry(self) -> Geometry:
"""
Gets the geometry of the task lattice.
Returns:
Geometry: The geometry of the task lattice.
"""
return Geometry(
sites=self.task_ir.lattice.sites,
filling=self.task_ir.lattice.filling,
parallel_decoder=self.parallel_decoder,
)

def _result_exists(self) -> bool:
"""
Checks if the task results exist.
Returns:
bool: True if the task results exist and are completed, otherwise False.
"""
if self.task_result_ir is None:
return False
else:
Expand All @@ -127,12 +218,18 @@ def _result_exists(self) -> bool:
else:
return False

# def submit_no_task_id(self) -> "HardwareTaskShotResults":
# return HardwareTaskShotResults(hardware_task=self)


@BraketTask.set_serializer
def _serialize(obj: BraketTask) -> Dict[str, Any]:
"""
Serializes the BraketTask instance to a dictionary.
Args:
obj (BraketTask): The task instance to serialize.
Returns:
Dict[str, Any]: The serialized dictionary representation of the task.
"""
return {
"task_id": obj.task_id,
"backend": obj.backend.dict(),
Expand All @@ -147,6 +244,15 @@ def _serialize(obj: BraketTask) -> Dict[str, Any]:

@BraketTask.set_deserializer
def _deserialize(d: Dict[str, Any]) -> BraketTask:
"""
Deserializes a dictionary to a BraketTask instance.
Args:
d (Dict[str, Any]): The dictionary to deserialize.
Returns:
BraketTask: The deserialized task instance.
"""
d["backend"] = BraketBackend(**d["backend"])
d["task_ir"] = QuEraTaskSpecification(**d["task_ir"])
d["parallel_decoder"] = (
Expand Down

0 comments on commit d70d1a2

Please sign in to comment.