起因
都说懒惰是第一生产力,最近在玩数独游戏的时候,总会遇到拆解数独比较复杂的情况,就想着自己写个代码解题,解放双手。所以很快就写了一个简单的代码求解经典数独。拿来跑了几个最高难度的数独发现确实很爽!虽说是比较暴力的 DFS,但是由于数独中约束较多的性质,实际上要找出唯一解并不复杂,即使是最高难度的数独也可以在 0.04s 内解完,可以说是非常的方便。
思路
经典数独游戏由 9*9 的方格组成,每个方格可填 1~9 的数字,一般都有三种约束:同行,同列,同宫不可出现相同的数字。只要暴力时利用这些约束,就可以快速剪枝。
考虑最简单的情况:我们对于任何一个空位,可以尝试去填 1~9 的数字,并且检查三种约束是否满足。若满足,就继续填下一个空位。
处理约束
实际上,并不需要每个格子都去把 1~9 全部尝试。因为填的数字越多,约束就越强,我们就越容易发现之前填数时的错误。所以我们可以预先处理三种约束影响的格子范围:
void initializeRelation() {memset(digitsUsed, 0, sizeof digitsUsed);// sub-gridsfor (int i = 0; i < 3; i++) {for (int j = 0; j < 3; j++) {int num = i * 3 + j;for (int k = 0; k < 3; k++) {for (int l = 0; l < 3; l++) {int idx = calcIdx(i * 3 + k, j * 3 + l);group[2][idx] = num;r[num].push_back(idx);}}}}// rowsfor (int i = 0; i < N; i++) {for (int j = 0; j < N; j++) {int idx = calcIdx(i, j);group[0][idx] = i + N;r[i + N].push_back(idx);}}// columnsfor (int i = 0; i < N; i++) {for (int j = 0; j < N; j++) {int idx = calcIdx(j, i);group[1][idx] = i + N * 2;r[i + N * 2].push_back(idx);}}
}
预先处理完约束后,下次要找一个格子到底应该对应哪些约束时,就可以直接找到对应的 idx
序号了。
状态压缩
一个格子可以填 1~9 共九种数字,那么到底哪些是可以填的呢?就如同我们实际解数独时一样,我们可以在格子上标记一下有哪些数字是符合约束的。一个简单的方法是把这个状态压缩成二进制数,每个可用数字代表一个二进制位的 1,若不可用,则该位为 0。那么一个格子上的可用数字可用一个 9 位二进制数表示,范围是0~2^9
,也即一个格子至多只有 512 种状态。
接下来 gcc
有一些方便的内建函数可以帮到我们,它们都是以 __builtin
开头:
__builtin_popcount(unsigned int x)
返回无符号整型x
的二进制中 1 的个数__builtin_ctz(unsigned int x)
返回无符号整型x
的二进制的末尾有多少个 0
上述函数也可使用 std::bitset<N>::count
等实现,作用类似。
现在计算某个格子还有多少可用数字就可以这样:
inline int calcUsable(int idx) {return 9 - __builtin_popcount(digitsUsed[idx]);
}
DFS
当我们枚举数字时,其实就是从当前状态中找到下一个可用数字,并根据约束关系删除与其相关的格子中的可用数字。
那么搜索时如何快速判断当前填的数字否可行呢?一个简单的思路是每次找到可用数字最少的格子,这样的格子可以确定更多的约束,搜索空间也更少,一旦失败了,我们可以迅速回滚。
那么把所有的空格子按照他们的[可用数字个数,可用数字状态]
作为一个数对,我们就可以利用std::set
构造出一个暴力 DFS 方案:
bool dfs() {if (grid.empty()) {return true;}pair<int, int> p = *grid.begin();grid.erase(p);int idx = p.second;int digitBit = MASK & ~digitsUsed[idx];for (int nextDigitBit = digitBit; nextDigitBit; nextDigitBit ^= lowbit(nextDigitBit)) {int digit = lowbit0Count(nextDigitBit);int currentDigitBit = 1 << digit;g[idx] = digit + 1;vector<int> last;for (int j = 0; j < 3; j++) {for (auto & x: r[group[j][idx]]) {auto it = grid.find(make_pair(calcUsable(x), x));if (it != grid.end() && (digitsUsed[x] | currentDigitBit) != digitsUsed[x]) {grid.erase(it);digitsUsed[x] = digitsUsed[x] | currentDigitBit;grid.insert(make_pair(calcUsable(x), x));last.push_back(x);}}}if (dfs()) {return true;}for (auto &x: last) {grid.erase(make_pair(calcUsable(x), x));digitsUsed[x] = digitsUsed[x] & ~currentDigitBit;grid.insert(make_pair(calcUsable(x), x));}}grid.insert(p);return false;
}
结语
由于只考虑经典数独,代码还是非常简洁而且高效的。而对于各种各样的变形数独,也可以考虑根据这种简化约束的方式去暴力求解。如果想要模仿人类解法,对强弱链等逻辑进行推演而非简单暴力的话,还需要更多的工作。
当然,数独如果由机器暴力计算就会缺失很多乐趣,但去寻找现有问题的一种代码实现也同样是另一种乐趣。我觉得能在数学游戏中找到自己喜欢的部分,并发掘出其中的趣味,其本身也是一种快乐的事情。
附录
最终代码如下,输入重定向于sudoku.in
,输入格式中星号*
代表空位,可在代码最后注释中看到样例。
输出格式为先输出整体的解,再输出只包含原数独中空位的解。
#include <bits/stdc++.h>using namespace std;const int N = 9;
const int R_NUM = 27;
const int GRID_NUM = 81;
const int MASK = (1 << N) - 1;char str[10][100];
int s[9][9];
int g[GRID_NUM];
int group[3][GRID_NUM]; // groups
vector<int> r[R_NUM]; // relations
set<pair<int, int>> grid;
int digitsUsed[GRID_NUM];/**
group 0:
000000000
111111111
222222222
333333333
444444444
555555555
666666666
777777777
888888888group 1:
012345678
012345678
012345678
012345678
012345678
012345678
012345678
012345678
012345678group 2:
000111222
000111222
000111222
333444555
333444555
333444555
666777888
666777888
666777888
**/inline int calcX(int idx) {return group[0][idx];
}inline int calcIdx(int x, int y) {return x * N + y;
}inline int lowbit(int x) {return x & (-x);
}inline int lowbit0Count(int x) {return __builtin_ctz(x);
}inline int calcUsable(int idx) {return 9 - __builtin_popcount(digitsUsed[idx]);
}void initializeRelation() {memset(digitsUsed, 0, sizeof digitsUsed);// sub-gridsfor (int i = 0; i < 3; i++) {for (int j = 0; j < 3; j++) {int num = i * 3 + j;for (int k = 0; k < 3; k++) {for (int l = 0; l < 3; l++) {int idx = calcIdx(i * 3 + k, j * 3 + l);group[2][idx] = num;r[num].push_back(idx);}}}}// rowsfor (int i = 0; i < N; i++) {for (int j = 0; j < N; j++) {int idx = calcIdx(i, j);group[0][idx] = i + N;r[i + N].push_back(idx);}}// columnsfor (int i = 0; i < N; i++) {for (int j = 0; j < N; j++) {int idx = calcIdx(j, i);group[1][idx] = i + N * 2;r[i + N * 2].push_back(idx);}}
}void fail() {printf("IMPOSSIBLE\n");exit(0);
}void printResult() {printf("Result:\n");for (int i = 0; i < N; i++) {for (int j = 0; j < N; j++) {printf("%d", g[calcIdx(i, j)]);}printf("\n");}
}void printFillableResult() {printf("\nFillable Result:\n");for (int i = 0; i < N; i++) {for (int j = 0; j < N; j++) {printf("%c", (s[i][j] == 0) ? g[calcIdx(i, j)] + '0' : '*');}printf("\n");}
}bool dfs() {if (grid.empty()) {return true;}pair<int, int> p = *grid.begin();grid.erase(p);int idx = p.second;int digitBit = MASK & ~digitsUsed[idx];for (int nextDigitBit = digitBit; nextDigitBit; nextDigitBit ^= lowbit(nextDigitBit)) {int digit = lowbit0Count(nextDigitBit);int currentDigitBit = 1 << digit;g[idx] = digit + 1;vector<int> last;for (int j = 0; j < 3; j++) {for (auto & x: r[group[j][idx]]) {auto it = grid.find(make_pair(calcUsable(x), x));if (it != grid.end() && (digitsUsed[x] | currentDigitBit) != digitsUsed[x]) {grid.erase(it);digitsUsed[x] = digitsUsed[x] | currentDigitBit;grid.insert(make_pair(calcUsable(x), x));last.push_back(x);}}}if (dfs()) {return true;}for (auto &x: last) {grid.erase(make_pair(calcUsable(x), x));digitsUsed[x] = digitsUsed[x] & ~currentDigitBit;grid.insert(make_pair(calcUsable(x), x));}}grid.insert(p);return false;
}int main() {freopen("sudoku.in", "r", stdin);initializeRelation();// Enter a sudoku puzzle: (9 lines with 9 characters on each line, use * for blank)for (int i = 0; i < N; i++) {scanf("%s", str[i]);}for (int i = 0; i < N; i++) {if (strlen(str[i]) != N) {exit(0);}for (int j = 0; j < N; j++) {int idx = calcIdx(i, j);if (str[i][j] == '*') {g[idx] = s[i][j] = 0;digitsUsed[idx] = 0;} else if (str[i][j] >= '1' && str[i][j] <= '9') {g[idx] = s[i][j] = str[i][j] - '0';} else {exit(0);}}}for (int idx = 0; idx < GRID_NUM; idx++) {if (g[idx] == 0) {for (int j = 0; j < 3; j++) {for (auto & cur: r[group[j][idx]]) {if (g[cur] != 0) {digitsUsed[idx] |= 1 << (g[cur] - 1);}}}// pair is (digitsLeftCount, idx)grid.insert(make_pair(calcUsable(idx), idx));}}if (dfs()) {printResult();printFillableResult();} else {printResult();fail();}return 0;
}/*
<Sample Input>*23456789
456789123
789123456
312645978
645978312
978312645
231564897
564897231
897231564**95*8*7*
23769***4
5**32**1*
8*1935**7
49*8*2*51
**3**6*2*
*1*2*4**6
6*8*****2
*7*1*38*******2***
2*4****7*
****5**49
**6**85**
******83*
57**4****
*3*7****6
*65*3**9*
7***9*1***/