Skip to content

Commit

Permalink
adding 'L2' to lattice constructors (#749)
Browse files Browse the repository at this point in the history
* adding 'L2' to lattice constrcutors

* updating JSON codegen.

* fixing typos.
  • Loading branch information
weinbe58 authored Oct 25, 2023
1 parent d786277 commit cdda67b
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 31 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "bloqade"
version = "0.9.0"
version = "0.10.0"
description = "Neutral atom software development kit"
authors = [
{name = "QuEra Computing Inc.", email = "[email protected]"},
Expand Down
16 changes: 6 additions & 10 deletions src/bloqade/codegen/common/assign_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,11 @@ def __init__(self, mapping: Dict[str, numbers.Real]):
self.scalar_visitor = AssignScalar(mapping)

def visit_chain(self, ast: location.Chain) -> location.Chain:
return location.Chain(
ast.shape[0], self.scalar_visitor.emit(ast.lattice_spacing)
)
return location.Chain(*ast.shape, self.scalar_visitor.emit(ast.lattice_spacing))

def visit_square(self, ast: location.Square) -> location.Square:
return location.Square(
ast.shape[0], self.scalar_visitor.emit(ast.lattice_spacing)
*ast.shape, self.scalar_visitor.emit(ast.lattice_spacing)
)

def visit_rectangular(self, ast: location.Rectangular) -> location.Rectangular:
Expand All @@ -183,23 +181,21 @@ def visit_rectangular(self, ast: location.Rectangular) -> location.Rectangular:

def visit_triangular(self, ast: location.Triangular) -> location.Triangular:
return location.Triangular(
ast.shape[0], self.scalar_visitor.emit(ast.lattice_spacing)
*ast.shape, self.scalar_visitor.emit(ast.lattice_spacing)
)

def visit_honeycomb(self, ast: location.Honeycomb) -> location.Honeycomb:
return location.Honeycomb(
ast.shape[0], self.scalar_visitor.emit(ast.lattice_spacing)
*ast.shape, self.scalar_visitor.emit(ast.lattice_spacing)
)

def visit_kagome(self, ast: location.Kagome) -> location.Kagome:
return location.Kagome(
ast.shape[0], self.scalar_visitor.emit(ast.lattice_spacing)
*ast.shape, self.scalar_visitor.emit(ast.lattice_spacing)
)

def visit_lieb(self, ast: location.Lieb) -> location.Lieb:
return location.Lieb(
ast.shape[0], self.scalar_visitor.emit(ast.lattice_spacing)
)
return location.Lieb(*ast.shape, self.scalar_visitor.emit(ast.lattice_spacing))

def visit_list_of_locations(
self, ast: location.ListOfLocations
Expand Down
15 changes: 10 additions & 5 deletions src/bloqade/codegen/common/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,23 +216,26 @@ def visit_honeycomb(self, ast: Honeycomb) -> Any:
return {
"honeycomb": {
"lattice_spacing": self.scalar_serializer.visit(ast.lattice_spacing),
"L": ast.shape[0],
"L1": ast.shape[0],
"L2": ast.shape[1],
}
}

def visit_kagome(self, ast: Kagome) -> Any:
return {
"kagome": {
"lattice_spacing": self.scalar_serializer.visit(ast.lattice_spacing),
"L": ast.shape[0],
"L1": ast.shape[0],
"L2": ast.shape[1],
}
}

def visit_lieb(self, ast: Lieb) -> Any:
return {
"lieb": {
"lattice_spacing": self.scalar_serializer.visit(ast.lattice_spacing),
"L": ast.shape[0],
"L1": ast.shape[0],
"L2": ast.shape[1],
}
}

Expand Down Expand Up @@ -263,15 +266,17 @@ def visit_square(self, ast: Square) -> Any:
"lattice_spacing": self.visit(
self.scalar_serializer.visit(ast.lattice_spacing)
),
"L": ast.shape[0],
"L1": ast.shape[0],
"L2": ast.shape[1],
}
}

def visit_triangular(self, ast: Triangular) -> Any:
return {
"triangular": {
"lattice_spacing": self.scalar_serializer.visit(ast.lattice_spacing),
"L": ast.shape[0],
"L1": ast.shape[0],
"L2": ast.shape[1],
}
}

Expand Down
59 changes: 45 additions & 14 deletions src/bloqade/ir/location/bravais.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,9 @@ class Square(BoundedBravais):
- loc (0,0)
Args:
L (int): number of sites in linear direction. n_atoms = L * L.
L1 (int): number of sites in linear direction. n_atoms = L1 * L1.
L2 (Optional[int]): number of sites in direction a2.
n_atoms = L1 * L2, default is L1
lattice_spacing (Scalar, Real): lattice spacing. Defaults to 1.0.
Expand All @@ -233,8 +235,12 @@ class Square(BoundedBravais):
"""

@beartype
def __init__(self, L: int, lattice_spacing: ScalarType = 1.0):
super().__init__(L, L, lattice_spacing=lattice_spacing)
def __init__(
self, L1: int, L2: Optional[int] = None, lattice_spacing: ScalarType = 1.0
):
if L2 is None:
L2 = L1
super().__init__(L1, L2, lattice_spacing=lattice_spacing)

def __repr__(self):
return super().__repr__()
Expand Down Expand Up @@ -367,7 +373,10 @@ class Honeycomb(BoundedBravais):
Args:
L (int): number of sites in linear direction. n_atoms = L * L * 2.
L1 (int): number of unit cells in linear direction. n_atoms = L1 * L1 * 2.
L2 (Optional[int]): number of unit cells in direction a2.
n_atoms = L1 * L2 * 2, default is L1.
lattice_spacing (Scalar, Real):
lattice spacing. Defaults to 1.0.
Expand All @@ -379,8 +388,12 @@ class Honeycomb(BoundedBravais):
"""

@beartype
def __init__(self, L: int, lattice_spacing: ScalarType = 1.0):
super().__init__(L, L, lattice_spacing=lattice_spacing)
def __init__(
self, L1: int, L2: Optional[int] = None, lattice_spacing: ScalarType = 1.0
):
if L2 is None:
L2 = L1
super().__init__(L1, L2, lattice_spacing=lattice_spacing)

def __repr__(self):
return super().__repr__()
Expand All @@ -406,6 +419,8 @@ class Triangular(BoundedBravais):
Args:
L (int): number of sites in linear direction. n_atoms = L * L.
L2 (Optional[int]): number of sites along a2 direction,
n_atoms = L1 * L2, default is L1.
lattice_spacing (Scalar, Real):
lattice spacing. Defaults to 1.0.
Expand All @@ -417,8 +432,12 @@ class Triangular(BoundedBravais):
"""

@beartype
def __init__(self, L: int, lattice_spacing: ScalarType = 1.0):
super().__init__(L, L, lattice_spacing=lattice_spacing)
def __init__(
self, L1: int, L2: Optional[int] = None, lattice_spacing: ScalarType = 1.0
):
if L2 is None:
L2 = L1
super().__init__(L1, L2, lattice_spacing=lattice_spacing)

def __repr__(self):
return super().__repr__()
Expand All @@ -444,7 +463,9 @@ class Lieb(BoundedBravais):
- loc3 (0 ,0.5)
Args:
L (int): number of sites in linear direction. n_atoms = L * L.
L1 (int): number of unit cells in linear direction. n_atoms = 3* L1 * L1.
L2 (Optional[int]): number of unit cells along a2 direction,
n_atoms = 3 * L1 * L2, default is L1.
lattice_spacing (Scalar, Real):
lattice spacing. Defaults to 1.0.
Expand All @@ -456,8 +477,12 @@ class Lieb(BoundedBravais):
"""

@beartype
def __init__(self, L: int, lattice_spacing: ScalarType = 1.0):
super().__init__(L, L, lattice_spacing=lattice_spacing)
def __init__(
self, L1: int, L2: Optional[int] = None, lattice_spacing: ScalarType = 1.0
):
if L2 is None:
L2 = L1
super().__init__(L1, L2, lattice_spacing=lattice_spacing)

def __repr__(self):
return super().__repr__()
Expand All @@ -483,7 +508,9 @@ class Kagome(BoundedBravais):
- loc3 (0.25 ,0.25sqrt(3))
Args:
L (int): number of sites in linear direction. n_atoms = L * L.
L1 (int): number of sites in linear direction. n_atoms = 3 * L1 * L1.
L2 (Optional[int]): number of unit cells along a2 direction,
n_atoms = 3 * L1 * L2, default is L1.
lattice_spacing (Scalar, Real):
lattice spacing. Defaults to 1.0.
Expand All @@ -495,8 +522,12 @@ class Kagome(BoundedBravais):
"""

@beartype
def __init__(self, L: int, lattice_spacing: ScalarType = 1.0):
super().__init__(L, L, lattice_spacing=lattice_spacing)
def __init__(
self, L1: int, L2: Optional[int] = None, lattice_spacing: ScalarType = 1.0
):
if L2 is None:
L2 = L1
super().__init__(L1, L2, lattice_spacing=lattice_spacing)

def __repr__(self):
return super().__repr__()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_lattice_pprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def test_square_pprint():
square_pprint_var_output = open(square_pprint_var_output_path, "r").read()

bl = cast("bl")
assert str(Square(7, bl)) == square_pprint_var_output
assert str(Square(7, lattice_spacing=bl)) == square_pprint_var_output


def test_rectangular_pprint():
Expand Down

0 comments on commit cdda67b

Please sign in to comment.