Skip to content

Commit

Permalink
Merge pull request #67 from qchempku2017/main
Browse files Browse the repository at this point in the history
TST: fixes test solver issues.
  • Loading branch information
lbluque authored Jul 10, 2024
2 parents 3872e82 + bec0854 commit d6effd4
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 101 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ test = [
"polytope",
"cvxpy",
"gurobipy",
"pyscipopt",
"pyscipopt==4.3.0",
]
dev = [
"pre-commit >=2.12.1",
Expand Down
43 changes: 13 additions & 30 deletions tests/test_capp/test_solver/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,14 @@ def solver_test_prim():
lat,
[
{
"Li+": 1 / 6,
"Mn2+": 1 / 6,
"Mn3+": 1 / 6,
"Mn4+": 1 / 6,
"Ti4+": 1 / 6,
"Li+": 1 / 5,
"Mn2+": 1 / 5,
"Mn4+": 1 / 5,
"Ti4+": 1 / 5,
},
{
"O2-": 1 / 3,
"O-": 1 / 3,
"F-": 1 / 3,
"O2-": 1 / 2,
"F-": 1 / 2,
},
],
[[0, 0, 0], [0.5, 0.5, 0.5]],
Expand All @@ -42,7 +40,7 @@ def solver_test_prim():
def solver_test_subspace(solver_test_prim, request):
# Use sinusoid basis to test if useful.
space = ClusterSubspace.from_cutoffs(
solver_test_prim, {2: 3, 3: 2.1}, basis=request.param
solver_test_prim, {2: 2.1, 3: 2.1}, basis=request.param
)
space.add_external_term(EwaldTerm())
return space
Expand Down Expand Up @@ -80,19 +78,17 @@ def orig_ensemble(solver_test_expansion, request):
chemical_potentials = {
"Li+": np.random.normal(),
"Mn2+": np.random.normal(),
"Mn3+": np.random.normal(),
"Mn4+": np.random.normal(),
"Ti4+": np.random.normal(),
"Vacancy": np.random.normal(),
"O2-": np.random.normal(),
"O-": np.random.normal(),
"F-": np.random.normal(),
}
else:
chemical_potentials = None
return Ensemble.from_cluster_expansion(
solver_test_expansion,
np.diag([5, 2, 2]),
np.diag([3, 2, 2]),
request.param[1],
chemical_potentials=chemical_potentials,
)
Expand Down Expand Up @@ -122,41 +118,30 @@ def solver_test_ensemble(orig_ensemble, solver_test_initial_occupancy):
# Split the cation sublattice.
new_ensemble = Ensemble.from_dict(orig_ensemble.as_dict())

# Manually restrict 3 random li sites, 1 Vacancy site.
# Manually restrict 1 random li sites, 2 Vacancy site.
cation_sites = new_ensemble.sublattices[cation_id].sites
li_code = new_ensemble.sublattices[cation_id].encoding[
new_ensemble.sublattices[cation_id].species.index(Species("Li", 1))
]
li_sites = new_ensemble.sublattices[cation_id].sites[
np.where(solver_test_initial_occupancy[cation_sites] == li_code)[0]
]
li_restricts = np.random.choice(li_sites, size=3, replace=False)
li_restricts = np.random.choice(li_sites, size=1, replace=False)
new_ensemble.restrict_sites(li_restricts)
va_code = new_ensemble.sublattices[cation_id].encoding[
new_ensemble.sublattices[cation_id].species.index(Vacancy())
]
va_sites = new_ensemble.sublattices[cation_id].sites[
np.where(solver_test_initial_occupancy[cation_sites] == va_code)[0]
]
va_restricts = np.random.choice(va_sites, size=1, replace=False)
va_restricts = np.random.choice(va_sites, size=2, replace=False)
new_ensemble.restrict_sites(va_restricts)

# Manually restrict 2 random O2- sites.
anion_sites = new_ensemble.sublattices[anion_id].sites
o2_code = new_ensemble.sublattices[anion_id].encoding[
new_ensemble.sublattices[anion_id].species.index(Species("O", -2))
]
o2_sites = new_ensemble.sublattices[anion_id].sites[
np.where(solver_test_initial_occupancy[anion_sites] == o2_code)[0]
]
o2_restricts = np.random.choice(o2_sites, size=2, replace=False)
new_ensemble.restrict_sites(o2_restricts)

ca_partitions = [
[Species("Li", 1), Vacancy()],
[Species("Mn", 2), Species("Mn", 3), Species("Mn", 4), Species("Ti", 4)],
[Species("Mn", 2), Species("Mn", 4), Species("Ti", 4)],
]
an_partitions = [[Species("O", -2), Species("O", -1)], [Species("F", -1)]]
an_partitions = [[Species("O", -2)], [Species("F", -1)]]
new_ensemble.split_sublattice_by_species(
cation_id, solver_test_initial_occupancy, ca_partitions
)
Expand All @@ -172,7 +157,5 @@ def solver_test_ensemble(orig_ensemble, solver_test_initial_occupancy):
assert site in new_ensemble.restricted_sites
for site in va_restricts:
assert site in new_ensemble.restricted_sites
for site in o2_restricts:
assert site in new_ensemble.restricted_sites

return new_ensemble
22 changes: 11 additions & 11 deletions tests/test_capp/test_solver/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,16 @@ def test_comp_space_constraints(solver_test_ensemble, solver_test_initial_occupa
variable_indices,
solver_test_ensemble.processor.structure,
other_constraints=[
"Mn4+ == 1", # Broken when force_flip, kept when canonical. 2nd.
"Ti4+ = 2", # Broken when force_flip, kept when canonical. 3rd.
"Mn4+ + Mn3+ + Mn2+ >= 7", # Never true. 4th.
"Mn3+ + Mn2+ <= 3", # Always true. 5th.
"Mn4+ == 2", # Broken when force_flip, kept when canonical. 2nd.
"Mn4+ + Mn2+ >= 7", # Never true. 3rd.
"Mn4+ + Mn2+ <= 6", # Always true. 4th.
"0 >= -1", # Always True. Skipped.
"0 <= 1.5", # Always True. Skipped.
"0.0 = 0.0", # Always True. Skipped.
],
)

assert len(constraints) == 5
assert len(constraints) == 4
# Check with force_flip.
for _ in range(20):
rand_val = get_random_neutral_variable_values(
Expand All @@ -89,7 +88,7 @@ def test_comp_space_constraints(solver_test_ensemble, solver_test_initial_occupa
)
variables.value = rand_val
results = [c.value() for c in constraints]
assert results == [True, False, False, False, True]
assert results == [True, False, False, True]
# Check with canonical.
for _ in range(20):
rand_val = get_random_neutral_variable_values(
Expand All @@ -100,7 +99,7 @@ def test_comp_space_constraints(solver_test_ensemble, solver_test_initial_occupa
)
variables.value = rand_val
results = [c.value() for c in constraints]
assert results == [True, True, True, False, True]
assert results == [True, True, False, True]

# Bad test cases.
with pytest.raises(ValueError):
Expand Down Expand Up @@ -185,8 +184,8 @@ def test_fixed_composition_constraints(
fixed_composition=fixed_counts,
)

# F- is fixed and always satisfied, will not appear.
assert len(constraints) == 8
# F- and O2- are fixed and always satisfied, will not appear.
assert len(constraints) == 5
for _ in range(20):
rand_val = get_random_neutral_variable_values(
solver_test_ensemble.sublattices,
Expand All @@ -196,10 +195,11 @@ def test_fixed_composition_constraints(
) # Force canonical constraints, will always satisfy.
variables.value = rand_val
results = [c.value() for c in constraints]
assert results == [True for _ in range(8)]
assert results == [True for _ in range(5)]
flatten_bits = list(itertools.chain(*bits))
flatten_bits.remove(Species("F", -1))
assert len(flatten_bits) == 8
flatten_bits.remove(Species("O", -2))
assert len(flatten_bits) == 5
ti_id = flatten_bits.index(Species("Ti", 4))
mn4_id = flatten_bits.index(Species("Mn", 4))
for _ in range(20):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_capp/test_solver/test_objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_expansion_upper(solver_test_ensemble, solver_test_initial_occupancy):
objective, aux, aux_indices, aux_cons = get_expression_and_auxiliary_from_terms(
terms, variables
)
for _ in range(50):
for _ in range(20):
rand_val = get_random_variable_values(solver_test_ensemble.sublattices)
aux_val = get_auxiliary_variable_values(rand_val, aux_indices)
rand_occu = get_occupancy_from_variables(
Expand Down Expand Up @@ -100,7 +100,7 @@ def test_ewald_upper(solver_test_ensemble, solver_test_initial_occupancy):
for inds in aux_indices:
assert len(inds) == 2 # No more than pair terms.

for _ in range(50):
for _ in range(20):
# Should have the same ewald for either neutral or not neutral.
rand_val = get_random_variable_values(solver_test_ensemble.sublattices)
aux_val = get_auxiliary_variable_values(rand_val, aux_indices)
Expand Down Expand Up @@ -137,7 +137,7 @@ def test_chemical_potentials_upper(solver_test_ensemble, solver_test_initial_occ
assert aux is None
assert len(aux_indices) == 0
assert len(aux_cons) == 0
for _ in range(50):
for _ in range(20):
rand_val = get_random_variable_values(solver_test_ensemble.sublattices)
rand_occu = get_occupancy_from_variables(
solver_test_ensemble.sublattices,
Expand Down
6 changes: 6 additions & 0 deletions tests/test_capp/test_solver/test_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,12 @@ def simple_ensemble(simple_expansion, request):
),
),
"GUROBI",
# pytest.param(
# "GUROBI",
# marks=pytest.mark.xfail(
# reason="Gurobi requires license to run.",
# raises=IndexError,
# ),
]
)
def simple_solver(simple_ensemble, request):
Expand Down
71 changes: 36 additions & 35 deletions tests/test_capp/test_solver/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def test_variable_indices_for_components(
variable_indices,
solver_test_ensemble.processor.structure,
)
# Total 9 species on all sub-lattices.
assert len(var_inds_for_components) == 9
# Total 7 species on all sub-lattices.
assert len(var_inds_for_components) == 7
dim_id = 0
for sublattice in solver_test_ensemble.sublattices:
sub_bits = sublattice.species
Expand All @@ -51,11 +51,11 @@ def test_variable_indices_for_components(
dtype=int,
)
if Species("Li", 1) in sub_bits:
# 2 li sites are restricted.
# 1 li sites are restricted.
li_code = sublattice.encoding[sub_bits.index(Species("Li", 1))]
va_code = sublattice.encoding[sub_bits.index(Vacancy())]
restricted_all_sites = sublattice.restricted_sites
assert len(restricted_all_sites) == 4
assert len(restricted_all_sites) == 3
restricted_li_sites = restricted_all_sites[
np.where(
solver_test_initial_occupancy[restricted_all_sites] == li_code
Expand All @@ -64,8 +64,8 @@ def test_variable_indices_for_components(
restricted_vac_sites = np.setdiff1d(
restricted_all_sites, restricted_li_sites
)
assert len(restricted_li_sites) == 3
assert len(restricted_vac_sites) == 1
assert len(restricted_li_sites) == 1
assert len(restricted_vac_sites) == 2
# Only li sites restricted.
npt.assert_array_equal(
solver_test_initial_occupancy[restricted_li_sites], li_code
Expand All @@ -76,16 +76,16 @@ def test_variable_indices_for_components(
for sp_id, species in enumerate(sub_bits):
var_ids, n_fix = var_inds_for_components[dim_id]
if species == Species("Li", 1):
assert n_fix == 3 # 3 li sites manually restricted
assert n_fix == 1 # 1 li sites manually restricted
elif species == Vacancy():
assert n_fix == 1
assert n_fix == 2
else:
raise ValueError(
"Li/Vac sub-lattice was not correctly partitioned!"
f" Extra species {species}."
)
# 6 unrestricted li sites + 5 unrestricted vac sites.
assert len(var_ids) == 10
# 2 unrestricted li sites + 1 unrestricted vac sites.
assert len(var_ids) == 3
# Check indices are correct.
npt.assert_array_equal(sl_active_variables[:, sp_id], var_ids)
dim_id += 1
Expand All @@ -97,22 +97,16 @@ def test_variable_indices_for_components(
assert len(var_ids) == 6
npt.assert_array_equal(sl_active_variables[:, sp_id], var_ids)
dim_id += 1
elif Species("O", -2) in sub_bits:
for sp_id, species in enumerate(sub_bits):
var_ids, n_fix = var_inds_for_components[dim_id]
if species == Species("O", -2):
# 2 restricted o2- sites.
assert n_fix == 2
else:
assert n_fix == 0
# 6 unrestricted o2- sites, 2 unrestricted o- sites.
assert len(var_ids) == 8
npt.assert_array_equal(sl_active_variables[:, sp_id], var_ids)
dim_id += 1
elif Species("O", -2) in sub_bits: # O2- sublattice totally inactive.
assert list(sub_bits) == [Species("O", -2)]
var_ids, n_fix = var_inds_for_components[dim_id]
assert n_fix == 9
assert len(var_ids) == 0
dim_id += 1
else: # F sub-lattice totally inactive.
assert list(sub_bits) == [Species("F", -1)]
var_ids, n_fix = var_inds_for_components[dim_id]
assert n_fix == 10
assert n_fix == 3
assert len(var_ids) == 0
dim_id += 1

Expand Down Expand Up @@ -142,7 +136,7 @@ def test_ewald_indices(solver_test_ensemble, solver_test_initial_occupancy):

restricted_vac_sites = solver_test_ensemble.restricted_sites[
np.where(
solver_test_initial_occupancy[solver_test_ensemble.restricted_sites] == 5
solver_test_initial_occupancy[solver_test_ensemble.restricted_sites] == 4
)[0]
]
# print("supercell:\n", ew_processor.structure)
Expand All @@ -169,23 +163,30 @@ def test_ewald_indices(solver_test_ensemble, solver_test_initial_occupancy):
var_id += 1
continue
if site_id in solver_test_ensemble.restricted_sites:
# Not the inactive F sub-lattice, just manually restricted.
if Species("F", -1) not in sublattice.species:
# Not the inactive F or O sub-lattice, just manually restricted.
if (Species("F", -1) not in sublattice.species) and (
Species("O", -2) not in sublattice.species
):
# Always occupied by one non-vacancy species.
if spec == Species("O", -2) or (
spec == Species("Li", 1) and site_id not in restricted_vac_sites
):
if spec == Species("Li", 1) and site_id not in restricted_vac_sites:
expected = -1
# Always occupied by other species than this one.
else:
expected = -2
# Inactive F sub-lattice.
# Inactive F or O sub-lattice.
else:
if spec == Species("F", -1):
expected = -1
if Species("F", -1) in sublattice.species:
if spec == Species("F", -1):
expected = -1
else:
expected = -2
elif Species("O", -2) in sublattice.species:
if spec == Species("O", -2):
expected = -1
else:
expected = -2
else:
assert spec.symbol == "O"
expected = -2
raise ValueError(f"Invalid sublattice: {sublattice}!")
# Active site.
else:
# In the new sub-lattice.
Expand All @@ -199,8 +200,8 @@ def test_ewald_indices(solver_test_ensemble, solver_test_initial_occupancy):
ew_id += 1
expects.append(expected)

assert ew_id == n_ew_rows
# print("expected:\n", expects)
assert ew_id == n_ew_rows
npt.assert_array_equal(expects, ew_to_var_id)


Expand Down
Loading

0 comments on commit d6effd4

Please sign in to comment.