-
Notifications
You must be signed in to change notification settings - Fork 0
/
dlx.py
188 lines (156 loc) · 4.95 KB
/
dlx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
# Data object (Knuth's "x"): one node of the toroidal dancing-links grid.
class X:
    """A node in the four-way circular linked structure.

    A freshly built node points at itself in every direction, i.e. it is a
    one-element circular list both horizontally and vertically.
    """

    def __init__(self, column=None, row=None):
        self.up = self
        self.down = self
        self.right = self
        self.left = self
        # Header node of the column this node lives in.
        self.column = column
        # Index of the matrix row this node was created for.
        self.row = row

    def __str__(self):
        # Requires self.column to expose a ``name`` attribute.
        return '{}:{}'.format(self.column.name, self.row)
# Column object (Knuth's "y"): the header node of one column.
class Y(X):
    """Column header node; it is its own ``column`` pointer."""

    def __init__(self, row=None, name='', size=0):
        # A header belongs to its own column, so it passes itself as column.
        super().__init__(self, row)
        self.name = name  # Symbolic identifier used when printing answers.
        self.size = size  # Number of 1s currently linked into the column.
class DLX:
    """Knuth's Algorithm X implemented with Dancing Links.

    ``solve(A)`` accepts a 0/1 matrix: either an object exposing a 2-D
    ``shape`` (e.g. a numpy array) or a sequence of equal-length rows.
    Only entries equal to 1 are treated as set. It returns the indices of
    rows that cover every column exactly once, or ``None`` when no exact
    cover exists.
    """

    @staticmethod
    def solve(A):
        """Return the row indices of one exact cover of A, or None."""
        toroidal = DLX.init_toroidal(A)
        return DLX.search(toroidal)

    @staticmethod
    def _num_columns(A):
        """Return the number of columns in A.

        Uses ``A.shape`` when present (it must then be 2-D, as before).
        Otherwise A is treated as a sequence of rows and the first row's
        length is used; an empty sequence has zero columns.
        """
        shape = getattr(A, 'shape', None)
        if shape is not None:
            _, cols = shape
            return cols
        return len(A[0]) if len(A) > 0 else 0

    @staticmethod
    def init_column_labels(A):
        """Return the labels used to name A's columns.

        When A has at most 26 columns the 26 letters 'A'..'Z' are returned
        (only the first ``cols`` of them are used). With more columns the
        labels are the decimal strings '0' .. str(cols - 1).
        """
        cols = DLX._num_columns(A)
        letters = list('ABCDEFGHIJKLMNOPQRSTUVWXYZ')
        # Exactly 26 columns still fit in A..Z; fall back only beyond that.
        if cols > len(letters):
            return [str(c) for c in range(cols)]
        return letters

    @staticmethod
    def init_column_header(A):
        """Build the circular list of column headers for A; return its root.

        The root is a Y named 'root' with infinite size, so it never wins the
        smallest-column selection. One Y per column is linked to the right
        in column order, and the rightmost header links back to the root.
        """
        labels = DLX.init_column_labels(A)
        root = Y(name='root', size=float('inf'))
        curr = root
        for col in range(DLX._num_columns(A)):
            curr.right = Y(name=labels[col])
            curr.right.left = curr
            curr = curr.right
        # Close the ring: rightmost header <-> root.
        curr.right = root
        root.left = curr
        return root

    @staticmethod
    def header_pointer(root):
        """Map each column header's name to the header for O(1) lookup."""
        header = {}
        curr = root.right
        while curr is not root:
            header[curr.name] = curr
            curr = curr.right
        return header

    @staticmethod
    def smallest_column_object(root):
        """Return the remaining column header with the fewest 1s.

        Ties go to the leftmost candidate (strict ``<``). Because the root's
        size is infinite, the root is returned only when no headers remain.
        """
        best = root
        curr = root.column.right
        while curr is not root:
            if curr.size < best.size:
                best = curr
            curr = curr.right
        return best

    @staticmethod
    def init_toroidal(A):
        """Build the complete dancing-links structure for A; return its root.

        For every entry equal to 1, an X node is appended to the bottom of
        its column (incrementing that column's size) and to the right end of
        its row's circular list. Rows without any 1s produce no nodes.
        """
        labels = DLX.init_column_labels(A)
        root = DLX.init_column_header(A)
        header = DLX.header_pointer(root)
        for i, row in enumerate(A):
            first = None  # leftmost node of this row
            prev = None   # most recently linked node of this row
            for j, value in enumerate(row):
                if value != 1:
                    continue
                head = header[labels[j]]
                head.size += 1
                node = X(column=head, row=i)
                # Splice in just above the header, i.e. at the column's bottom.
                node.up = head.up
                node.down = head
                head.up.down = node
                head.up = node
                # Append to the row; the ring is closed after the loop.
                if prev is None:
                    first = node
                else:
                    prev.right = node
                    node.left = prev
                prev = node
            # prev stays None only when the row contains no 1s.
            if prev is not None:
                prev.right = first
                first.left = prev
        return root

    @staticmethod
    def cover(col):
        """Knuth's cover: unlink a column and every row that uses it.

        ``col`` may be any node of the column; its header is used. The header
        is removed from the header ring and each row with a 1 in this column
        is unlinked from all of its other columns (sizes decremented).
        """
        col = col.column
        col.right.left = col.left
        col.left.right = col.right
        i = col.down
        while i is not col:
            j = i.right
            while j is not i:
                j.up.down = j.down
                j.down.up = j.up
                j.column.size -= 1
                j = j.right
            i = i.down

    @staticmethod
    def uncover(col):
        """Exactly undo ``cover(col)`` by relinking in the reverse order."""
        col = col.column
        i = col.up
        while i is not col:
            j = i.left
            while j is not i:
                j.column.size += 1
                j.up.down = j
                j.down.up = j
                j = j.left
            i = i.up
        col.right.left = col
        col.left.right = col

    @staticmethod
    def search(root, k=0, solution=None):
        """Depth-first Algorithm X search for one exact cover.

        Args:
            root: header-ring root produced by ``init_toroidal``.
            k: current search depth; only forwarded to recursive calls.
            solution: row indices chosen so far (a new list when None). It is
                appended to and popped from while searching, so it holds the
                same rows on return as on entry.

        Returns:
            A new list of row indices when every column is covered, else
            None. All covers are undone before returning (even on success),
            so the linked structure is left exactly as it was found.
        """
        if solution is None:
            solution = []
        if root.right is root:
            return solution[:]
        col = DLX.smallest_column_object(root)
        DLX.cover(col)
        result = None
        r = col.down
        while r is not col:
            solution.append(r.row)
            j = r.right
            while j is not r:
                DLX.cover(j)
                j = j.right
            result = DLX.search(root, k + 1, solution)
            # result is an independent copy, so unwinding is always safe.
            solution.pop()
            j = r.left
            while j is not r:
                DLX.uncover(j)
                j = j.left
            if result is not None:
                break
            r = r.down
        DLX.uncover(col)
        return result