From cdda67b2aceaaa314854f26f395dc6378063addd Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Wed, 25 Oct 2023 13:00:34 -0400 Subject: [PATCH] adding 'L2' to lattice constructors (#749) * adding 'L2' to lattice constrcutors * updating JSON codegen. * fixing typos. --- pyproject.toml | 2 +- .../codegen/common/assign_variables.py | 16 ++--- src/bloqade/codegen/common/json.py | 15 +++-- src/bloqade/ir/location/bravais.py | 59 ++++++++++++++----- tests/test_lattice_pprint.py | 2 +- 5 files changed, 63 insertions(+), 31 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6e68755c0..508589f45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "bloqade" -version = "0.9.0" +version = "0.10.0" description = "Neutral atom software development kit" authors = [ {name = "QuEra Computing Inc.", email = "info@quera.com"}, diff --git a/src/bloqade/codegen/common/assign_variables.py b/src/bloqade/codegen/common/assign_variables.py index 83c1b5a16..402e2b0ee 100644 --- a/src/bloqade/codegen/common/assign_variables.py +++ b/src/bloqade/codegen/common/assign_variables.py @@ -164,13 +164,11 @@ def __init__(self, mapping: Dict[str, numbers.Real]): self.scalar_visitor = AssignScalar(mapping) def visit_chain(self, ast: location.Chain) -> location.Chain: - return location.Chain( - ast.shape[0], self.scalar_visitor.emit(ast.lattice_spacing) - ) + return location.Chain(*ast.shape, self.scalar_visitor.emit(ast.lattice_spacing)) def visit_square(self, ast: location.Square) -> location.Square: return location.Square( - ast.shape[0], self.scalar_visitor.emit(ast.lattice_spacing) + *ast.shape, self.scalar_visitor.emit(ast.lattice_spacing) ) def visit_rectangular(self, ast: location.Rectangular) -> location.Rectangular: @@ -183,23 +181,21 @@ def visit_rectangular(self, ast: location.Rectangular) -> location.Rectangular: def visit_triangular(self, ast: location.Triangular) -> location.Triangular: return location.Triangular( - ast.shape[0], self.scalar_visitor.emit(ast.lattice_spacing) + *ast.shape, self.scalar_visitor.emit(ast.lattice_spacing) ) def visit_honeycomb(self, ast: location.Honeycomb) -> location.Honeycomb: return location.Honeycomb( - ast.shape[0], self.scalar_visitor.emit(ast.lattice_spacing) + *ast.shape, self.scalar_visitor.emit(ast.lattice_spacing) ) def visit_kagome(self, ast: location.Kagome) -> location.Kagome: return location.Kagome( - ast.shape[0], self.scalar_visitor.emit(ast.lattice_spacing) + *ast.shape, self.scalar_visitor.emit(ast.lattice_spacing) ) def visit_lieb(self, ast: location.Lieb) -> location.Lieb: - return location.Lieb( - ast.shape[0], self.scalar_visitor.emit(ast.lattice_spacing) - ) + return location.Lieb(*ast.shape, self.scalar_visitor.emit(ast.lattice_spacing)) def visit_list_of_locations( self, ast: location.ListOfLocations diff --git a/src/bloqade/codegen/common/json.py b/src/bloqade/codegen/common/json.py index 4778e8a66..1057064a7 100644 --- a/src/bloqade/codegen/common/json.py +++ b/src/bloqade/codegen/common/json.py @@ -216,7 +216,8 @@ def visit_honeycomb(self, ast: Honeycomb) -> Any: return { "honeycomb": { "lattice_spacing": self.scalar_serializer.visit(ast.lattice_spacing), - "L": ast.shape[0], + "L1": ast.shape[0], + "L2": ast.shape[1], } } @@ -224,7 +225,8 @@ def visit_kagome(self, ast: Kagome) -> Any: return { "kagome": { "lattice_spacing": self.scalar_serializer.visit(ast.lattice_spacing), - "L": ast.shape[0], + "L1": ast.shape[0], + "L2": ast.shape[1], } } @@ -232,7 +234,8 @@ def visit_lieb(self, ast: Lieb) -> Any: return { "lieb": { "lattice_spacing": self.scalar_serializer.visit(ast.lattice_spacing), - "L": ast.shape[0], + "L1": ast.shape[0], + "L2": ast.shape[1], } } @@ -263,7 +266,8 @@ def visit_square(self, ast: Square) -> Any: "lattice_spacing": self.visit( self.scalar_serializer.visit(ast.lattice_spacing) ), - "L": ast.shape[0], + "L1": ast.shape[0], + "L2": ast.shape[1], } } @@ -271,7 +275,8 @@ def visit_triangular(self, ast: Triangular) -> Any: return { "triangular": { "lattice_spacing": self.scalar_serializer.visit(ast.lattice_spacing), - "L": ast.shape[0], + "L1": ast.shape[0], + "L2": ast.shape[1], } } diff --git a/src/bloqade/ir/location/bravais.py b/src/bloqade/ir/location/bravais.py index 2a36ab649..886014431 100644 --- a/src/bloqade/ir/location/bravais.py +++ b/src/bloqade/ir/location/bravais.py @@ -222,7 +222,9 @@ class Square(BoundedBravais): - loc (0,0) Args: - L (int): number of sites in linear direction. n_atoms = L * L. + L1 (int): number of sites in linear direction. n_atoms = L1 * L1. + L2 (Optional[int]): number of sites in direction a2. + n_atoms = L1 * L2, default is L1 lattice_spacing (Scalar, Real): lattice spacing. Defaults to 1.0. @@ -233,8 +235,12 @@ class Square(BoundedBravais): """ @beartype - def __init__(self, L: int, lattice_spacing: ScalarType = 1.0): - super().__init__(L, L, lattice_spacing=lattice_spacing) + def __init__( + self, L1: int, L2: Optional[int] = None, lattice_spacing: ScalarType = 1.0 + ): + if L2 is None: + L2 = L1 + super().__init__(L1, L2, lattice_spacing=lattice_spacing) def __repr__(self): return super().__repr__() @@ -367,7 +373,10 @@ class Honeycomb(BoundedBravais): Args: - L (int): number of sites in linear direction. n_atoms = L * L * 2. + L1 (int): number of unit cells in linear direction. n_atoms = L1 * L1 * 2. + L2 (Optional[int]): number of unit cells in direction a2. + n_atoms = L1 * L2 * 2, default is L1. + lattice_spacing (Scalar, Real): lattice spacing. Defaults to 1.0. @@ -379,8 +388,12 @@ class Honeycomb(BoundedBravais): """ @beartype - def __init__(self, L: int, lattice_spacing: ScalarType = 1.0): - super().__init__(L, L, lattice_spacing=lattice_spacing) + def __init__( + self, L1: int, L2: Optional[int] = None, lattice_spacing: ScalarType = 1.0 + ): + if L2 is None: + L2 = L1 + super().__init__(L1, L2, lattice_spacing=lattice_spacing) def __repr__(self): return super().__repr__() @@ -406,6 +419,8 @@ class Triangular(BoundedBravais): Args: L (int): number of sites in linear direction. n_atoms = L * L. + L2 (Optional[int]): number of sites along a2 direction, + n_atoms = L1 * L2, default is L1. lattice_spacing (Scalar, Real): lattice spacing. Defaults to 1.0. @@ -417,8 +432,12 @@ class Triangular(BoundedBravais): """ @beartype - def __init__(self, L: int, lattice_spacing: ScalarType = 1.0): - super().__init__(L, L, lattice_spacing=lattice_spacing) + def __init__( + self, L1: int, L2: Optional[int] = None, lattice_spacing: ScalarType = 1.0 + ): + if L2 is None: + L2 = L1 + super().__init__(L1, L2, lattice_spacing=lattice_spacing) def __repr__(self): return super().__repr__() @@ -444,7 +463,9 @@ class Lieb(BoundedBravais): - loc3 (0 ,0.5) Args: - L (int): number of sites in linear direction. n_atoms = L * L. + L1 (int): number of unit cells in linear direction. n_atoms = 3* L1 * L1. + L2 (Optional[int]): number of unit cells along a2 direction, + n_atoms = 3 * L1 * L2, default is L1. lattice_spacing (Scalar, Real): lattice spacing. Defaults to 1.0. @@ -456,8 +477,12 @@ class Lieb(BoundedBravais): """ @beartype - def __init__(self, L: int, lattice_spacing: ScalarType = 1.0): - super().__init__(L, L, lattice_spacing=lattice_spacing) + def __init__( + self, L1: int, L2: Optional[int] = None, lattice_spacing: ScalarType = 1.0 + ): + if L2 is None: + L2 = L1 + super().__init__(L1, L2, lattice_spacing=lattice_spacing) def __repr__(self): return super().__repr__() @@ -483,7 +508,9 @@ class Kagome(BoundedBravais): - loc3 (0.25 ,0.25sqrt(3)) Args: - L (int): number of sites in linear direction. n_atoms = L * L. + L1 (int): number of sites in linear direction. n_atoms = 3 * L1 * L1. + L2 (Optional[int]): number of unit cells along a2 direction, + n_atoms = 3 * L1 * L2, default is L1. lattice_spacing (Scalar, Real): lattice spacing. Defaults to 1.0. @@ -495,8 +522,12 @@ class Kagome(BoundedBravais): """ @beartype - def __init__(self, L: int, lattice_spacing: ScalarType = 1.0): - super().__init__(L, L, lattice_spacing=lattice_spacing) + def __init__( + self, L1: int, L2: Optional[int] = None, lattice_spacing: ScalarType = 1.0 + ): + if L2 is None: + L2 = L1 + super().__init__(L1, L2, lattice_spacing=lattice_spacing) def __repr__(self): return super().__repr__() diff --git a/tests/test_lattice_pprint.py b/tests/test_lattice_pprint.py index 359911175..0a3ca1aa6 100644 --- a/tests/test_lattice_pprint.py +++ b/tests/test_lattice_pprint.py @@ -101,7 +101,7 @@ def test_square_pprint(): square_pprint_var_output = open(square_pprint_var_output_path, "r").read() bl = cast("bl") - assert str(Square(7, bl)) == square_pprint_var_output + assert str(Square(7, lattice_spacing=bl)) == square_pprint_var_output def test_rectangular_pprint():