e633e77d08657f5e71e662026129dca8b91e3a81
[Benchmarks_CSolver.git] / sudoku-csolver / Sudoku.py
1 import pycosat
2 import sys, getopt 
3 import time
4 import numpy as np
5 import re
6 import random
7 import csolversudoku as cs
8
9 def main(argv): 
10         argument = '' 
11         try:
12                 opts, args = getopt.getopt(argv,"emhvb",["easy","medium","hard","evil","blank","help", "file", "gen", "csolver"])
13         except getopt.GetoptError:
14                 print('Argument error, check -h | --help')
15                 sys.exit(2)
16         for indx,(opt, arg) in enumerate(opts): 
17                 if opt in ("--help"):
18                         help()
19                 elif opt in ("-e", "--easy"): 
20                         solve_problem(easy) 
21                 elif opt in ("-m", "--medium"): 
22                         solve_problem(medium) 
23                 elif opt in ("-h", "--hard"):
24                         solve_problem(hard) 
25                 elif opt in ("-v", "--evil"):
26                         solve_problem(evil) 
27                 elif opt in ("-b", "--blank"):
28                         solve_problem(blank) 
29                 elif opt in ( "--file"):
30                         print opts, args
31                         if "--csolver" in args:
32                                 print "Solving the problem using csolver ..."
33                                 read_problem_from_file(args[indx], 1)
34                         elif "--dump" in args:
35                                 print "Solving the problem using csolver ..."
36                                 read_problem_from_file(args[indx], 2)
37                         elif "--alloy" in args:
38                                 print "Solving the problem using alloy ..."
39                                 read_problem_from_file(args[indx], 3)
40                         else:
41                                 read_problem_from_file(args[indx], 0)
42                 elif opt in ( "--gen"):
43                         N, K = extractProblemSpecs(args)
44                         if "--csolver" in args:
45                                 print "Generating problem using csolver ..."
46                                 generate_problem_csolver(N,K)
47                         else:
48                                 generate_problem(N, K)
49                 else:
50                         help()
51
52             
53 def help():
54         print('Usage:')
55         print('Sudoku.py -e [or] --easy')
56         print('Sudoku.py -m [or] --medium')
57         print('Sudoku.py -h [or] --hard')
58         print('Sudoku.py -v [or] --evil')
59         print('Sudoku.py -b [or] --blank')
60         print('Sudoku.py --file file.problem [--csolver/--dump/--alloy]')
61         print('Sudoku.py --gen 9 20 [--csolver/--dump]')
62         print('All problems generated by websudoku.com')
63         sys.exit()
64
65
66 def removeKDigits(mat, N, K):
67         count = K;
68         while (count != 0):
69                 cellId = random.randint(1,N*N)
70                 i = cellId//N
71                 j = cellId%N
72                 if j != 0:
73                         j = j - 1
74                 if i == N:
75                         i = i-1
76                 #print 'cellId=' + str(cellId) + ' i='+ str(i) + ' j=' + str(j) + ' count='+ str(count)
77                 if mat[i][j] != 0:
78                         count = count -1
79                         mat[i][j] = 0
80
81 def extractProblemSpecs(args):
82         assert len(args) >= 2
83         global N
84         N = int(args[0])
85         K = int(args[1])
86         print N
87         return N, K
88
89 def printValidationStatus(problem):
90         if validateSolution(problem):
91                 print "***CORRECT***"
92         else:
93                 print "***WRONG*****"
94
95 def validateSolution(problem):
96         global N
97         for row in problem:
98                 for i in range(1, N):
99                         row2 = [r for r in row]
100                         if row2.count(i) !=1:
101                                 return False
102         for col in range(1,N):
103                 for i in range(1, N):
104                         if [problem[k][col] for k in range(N) ].count(i) >1:
105                                 return False
106         root = int(N**(0.5))
107         for i in range( root):
108                 for j in range(root):
109                         cube = [ problem[i*root + k % root][ j*root + k // root] for k in range(N)]
110                         for num in range( N):
111                                 if cube.count(num) >1:
112                                         return False
113         return True
114
115 def generate_problem_csolver(N,K):
116         problem = cs.generateProblem(N)
117         pprint(problem)
118         printValidationStatus(problem)
119         np.savetxt('solved/'+str(N) + 'x' + str(N) + '.sol',problem)
120         removeKDigits(problem, N, K)
121 #       np.savetxt('problems/'+str(N) + 'x' + str(N) + '-' + str(K) + '.problem',problem)
122
123 def generate_problem(N, K):
124         problem = [[0 for i in range(N)] for i in range(N)]
125         solve(problem)
126         np.savetxt('solved/'+str(N) + 'x' + str(N) + '.sol',problem)
127         printValidationStatus(problem)  
128         removeKDigits(problem, N, K)
129         pprint(problem)
130         np.savetxt('problems/'+str(N) + 'x' + str(N) + '-' + str(K) + '.problem',problem)
131
132 def read_problem_from_file(filename, solverID):
133         problem = np.loadtxt(filename)
134         global N
135         N=int(re.findall('\d+', filename)[0])
136         problem = problem.astype(int)
137         solve_problem(problem, solverID)
138
139 def solve_problem(problemset, solverID):
140         print('Problem:') 
141         pprint(problemset)
142         if solverID != 0:
143                 problemset=cs.solveProblem(N, problemset, solverID)
144                 np.savetxt('solved/'+str(N) + 'x' + str(N) + '.problem',problemset)
145         else: 
146                 solve(problemset) 
147         print('Answer:')
148         pprint(problemset)
149         printValidationStatus(problemset)  
150     
151 def v(i, j, d): 
152         return N**2 * (i - 1) + N * (j - 1) + d
153
154 #Reduces Sudoku problem to a SAT clauses 
155 def sudoku_clauses(): 
156         res = []
157         # for all cells, ensure that the each cell:
158         for i in range(1, N+1):
159                 for j in range(1, N+1):
160                         # does not denote two different digits at once (36 clauses)
161                         for d in range(1, N+1):
162                                 for dp in range(d + 1, N+1):
163                                         res.append([ -v(i, j, dp), -v(i, j, d)])
164                         # denotes (at least) one of the 9 digits (1 clause)
165                         res.append([v(i, j, d) for d in range(1, N+1)])
166                         
167         print "First one :" + str( len(res))
168         
169         def valid(cells): 
170                 for i, xi in enumerate(cells):
171                         for j, xj in enumerate(cells):
172                                 if i < j:
173                                         for d in range(1, N+1):
174                                                 res.append([-v(xi[0], xi[1], d), -v(xj[0], xj[1], d)])
175
176         # ensure rows and columns have distinct values
177         for i in range(1, N+1):
178                 valid([(i, j) for j in range(1, N+1)])
179                 valid([(j, i) for j in range(1, N+1)])
180         print "Second one :" + str(len(res))
181         # ensure rootxroot (e.g. 3*3) sub-grids "regions" have distinct values
182         root = int(N**(0.5))
183         collections = [ root*i+1 for i in range(root)]
184         for i in collections:
185                 for j in collections:
186                         valid([(i + k % root, j + k // root) for k in range(N)])
187         print "Third one :" + str( len(res))
188 #       assert len(res) == 81 * (1 + 36) + 27 * 324
189         return res
190
191 def solve(grid):
192         #solve a Sudoku problem
193         clauses = sudoku_clauses()
194         for i in range(1, N+1):
195                 for j in range(1, N+1):
196                         d = grid[i - 1][j - 1]
197                         # For each digit already known, a clause (with one literal). 
198                         if d:
199                                 clauses.append([v(i, j, d)])
200
201         # Print number SAT clause 
202         numclause = len(clauses)
203         print "P CNF " + str(numclause) +"(number of clauses)"
204 #       for c in clauses:
205 #               print c
206         # solve the SAT problem
207         start = time.time()
208         sol = set(pycosat.solve(clauses, N**3))
209         end = time.time()
210         print("SUDOKU SAT SOLVING TIME: "+str(end - start))
211     
212         def read_cell(i, j):
213                 # return the digit of cell i, j according to the solution
214                 for d in range(1, N+1):
215                         if v(i, j, d) in sol:
216                                 return d
217
218         for i in range(1, N+1):
219                 for j in range(1, N+1):
220                         grid[i - 1][j - 1] = read_cell(i, j)
221
222
223 if __name__ == '__main__':
224         from pprint import pprint
225         N = 9
226         # Sudoku problem generated by websudoku.com
227         easy = [[0, 0, 0, 1, 0, 9, 4, 2, 7],
228                 [1, 0, 9, 8, 0, 0, 0, 0, 6],
229                 [0, 0, 7, 0, 5, 0, 1, 0, 8],
230                 [0, 5, 6, 0, 0, 0, 0, 8, 2],
231                 [0, 0, 0, 0, 2, 0, 0, 0, 0],
232                 [9, 4, 0, 0, 0, 0, 6, 1, 0],
233                 [7, 0, 4, 0, 6, 0, 9, 0, 0],
234                 [6, 0, 0, 0, 0, 8, 2, 0, 5],
235                 [2, 9, 5, 3, 0, 1, 0, 0, 0]]
236
237         medium = [[5, 8, 0, 0, 0, 1, 0, 0, 0],
238                 [0, 3, 0, 0, 6, 0, 0, 7, 0],
239                 [9, 0, 0, 3, 2, 0, 1, 0, 6],
240                 [0, 0, 0, 0, 0, 0, 0, 5, 0],
241                 [3, 0, 9, 0, 0, 0, 2, 0, 1],
242                 [0, 5, 0, 0, 0, 0, 0, 0, 0],
243                 [6, 0, 2, 0, 5, 7, 0, 0, 8],
244                 [0, 4, 0, 0, 8, 0, 0, 1, 0],
245                 [0, 0, 0, 1, 0, 0, 0, 6, 5]]
246
247         evil = [[0, 2, 0, 0, 0, 0, 0, 0, 0],
248                 [0, 0, 0, 6, 0, 0, 0, 0, 3],
249                 [0, 7, 4, 0, 8, 0, 0, 0, 0],
250                 [0, 0, 0, 0, 0, 3, 0, 0, 2],
251                 [0, 8, 0, 0, 4, 0, 0, 1, 0],
252                 [6, 0, 0, 5, 0, 0, 0, 0, 0],
253                 [0, 0, 0, 0, 1, 0, 7, 8, 0],
254                 [5, 0, 0, 0, 0, 9, 0, 0, 0],
255                 [0, 0, 0, 0, 0, 0, 0, 4, 0]]
256
257         hard = [[0, 2, 0, 0, 0, 0, 0, 3, 0],
258                 [0, 0, 0, 6, 0, 1, 0, 0, 0],
259                 [0, 6, 8, 2, 0, 0, 0, 0, 5],
260                 [0, 0, 9, 0, 0, 8, 3, 0, 0],
261                 [0, 4, 6, 0, 0, 0, 7, 5, 0],
262                 [0, 0, 1, 3, 0, 0, 4, 0, 0],
263                 [9, 0, 0, 0, 0, 7, 5, 1, 0],
264                 [0, 0, 0, 1, 0, 4, 0, 0, 0],
265                 [0, 1, 0, 0, 0, 0, 0, 9, 0]]
266     
267         blank = [[0, 0, 0, 0, 0, 0, 0, 0, 0],
268                 [0, 0, 0, 0, 0, 0, 0, 0, 0],
269                 [0, 0, 0, 0, 0, 0, 0, 0, 0],
270                 [0, 0, 0, 0, 0, 0, 0, 0, 0],
271                 [0, 0, 0, 0, 0, 0, 0, 0, 0],
272                 [0, 0, 0, 0, 0, 0, 0, 0, 0],
273                 [0, 0, 0, 0, 0, 0, 0, 0, 0],
274                 [0, 0, 0, 0, 0, 0, 0, 0, 0],
275                 [0, 0, 0, 0, 0, 0, 0, 0, 0]]
276     
277         if(len(sys.argv[1:]) == 0):
278                 print('Argument error, check --help')
279         else:
280                 main(sys.argv[1:])