《算法竞赛·快冲300题》将于2024年出版,是《算法竞赛》的辅助练习册。
所有题目放在自建的OJ New Online Judge。
用C/C++、Java、Python三种语言给出代码,以中低档题为主,适合入门、进阶。
文章目录
- 题目描述
- 题解
- C++代码
- Java代码
- Python代码
“ x1 == x2” ,链接: http://oj.ecustacm.cn/problem.php?id=1725
题目描述
【题目描述】 现在给定一些变量的等于或者不等于的约束,请你判断这些约束能否同时满足。
输入时变量为x1,x2,…,xn,均以x开始,x后面为该变量对应的编号。
约束条件只有"=“或者”!="。
【输入格式】 输入第一行为正整数T,表示存在T组测试数据。(T≤10)
每组测试数据第一行为正整数n,表示约束条件的数量。(n≤1000000)
接下来n行,每行以下列形式输出
xi = xj
xi != xj
其中i和j表示对应变量的编号。(1≤i,j≤10^9)
【输出格式】 对于每组测试数据,输出一行Yes表示可以满足,输出No表示不能满足。
【输入样例】
2
4
x1 = x7
x9 != x7
x13 = x9
x1 = x13
2
x1 = x2
x2 = x1
【输出样例】
No
Yes
题解
本题的解法显然是并查集。首先建立并查集,把相等的约束合并。然后逐一检查不等的约束,如果有一个与并查集产生了矛盾,输出“No”;如果完全没有产生矛盾,输出“Yes”。
但是,如果直接用x的编号建立并查集会超内存。因为编号最大为 1 0 9 10^9 109,直接建立并查集,需要 1 0 9 10^9 109 = 1G的空间。
如何优化空间?由于最多只有n= 1 0 6 10^6 106个约束,x的编号数量最多只有 2 × 1 0 6 2×10^6 2×106个。这是典型的离散化,把x的原来 1 0 9 10^9 109个编号,转换为 2 × 1 0 6 2×10^6 2×106个新编号。
【重点】 离散化 。
C++代码
离散化的编码(离散化见《算法竞赛》,清华大学出版社,罗勇军、郭卫斌著,75页,2.7 离散化),可以手工编码,也可以用STL的lower_bound()和unique()。下面的代码用STL实现。
注意输入的处理。
#include<bits/stdc++.h>
using namespace std;
struct node{int x, y, z;}a[1000010];
int tot, b[2000010];
int s[2000010]; //并查集
int get_id(int x){ //返回离散化后的新值return lower_bound(b + 1, b + 1 + tot, x) - b;
}
int find_set(int x){if(x != s[x]) s[x] = find_set(s[x]);return s[x];
}
int main(){int T; cin >> T;while(T--) {int n; cin >> n;tot = 0;for(int i = 1; i <= n; i++) {char ch1, ch2; string str;cin >> ch1 >> a[i].x >> str >> ch2 >> a[i].y;if(str[0] == '=') a[i].z = 1; //相等else a[i].z = 0; //不等b[++tot] = a[i].x; //把a的编号记录在b中b[++tot] = a[i].y;}sort(b + 1, b + 1 + tot);tot = unique(b + 1, b + 1 + tot) - (b + 1); //b去重,留下唯一的编号for(int i = 1; i <= tot; i++) s[i] = i; //建立并查集for(int i = 1; i <= n; i++)if(a[i].z) { //处理相等约束,对离散化后的新编号合并并查集int sx = find_set(get_id(a[i].x));int sy = find_set(get_id(a[i].y));s[sx] = sy;}bool ok = true;for(int i = 1; i <= n; i++)if(!a[i].z) { //检查不等约束是否造成矛盾int sx = find_set(get_id(a[i].x));int sy = find_set(get_id(a[i].y));if( sx == sy ){ok = false;break;}}if(ok) cout<<"Yes"<<endl;else cout<<"No"<<endl;}return 0;
}
Java代码
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;
class Main {static class Node {int x, y, z;Node(int x, int y, int z) {this.x = x;this.y = y;this.z = z;}}static Node[] a = new Node[1000010];static int tot;static int[] b = new int[2000010];static int[] s = new int[2000010];static int get_id(int x) {return Arrays.binarySearch(b, 1, tot + 1, x);}static int find_set(int x) {if (x != s[x]) s[x] = find_set(s[x]);return s[x];}public static void main(String[] args) throws IOException {BufferedReader br = new BufferedReader(new InputStreamReader(System.in));int T = Integer.parseInt(br.readLine());while (T-- > 0) {int n = Integer.parseInt(br.readLine());tot = 0;for (int i = 1; i <= n; i++) {String str = br.readLine();int x = 0, y = 0;if (str.contains("!")) {String[] split = str.split(" != ");x = Integer.parseInt(split[0].substring(1));y = Integer.parseInt(split[1].substring(1));a[i] = new Node(x, y, 0);} else {String[] split = str.split(" = ");x = Integer.parseInt(split[0].substring(1));y = Integer.parseInt(split[1].substring(1));a[i] = new Node(x, y, 1);}b[++tot] = a[i].x;b[++tot] = a[i].y;}Arrays.sort(b, 1, tot + 1);tot = deduplicate(b, tot);for (int i = 1; i <= tot; i++) s[i] = i;for (int i = 1; i <= n; i++) {if (a[i].z == 1) {int sx = find_set(get_id(a[i].x));int sy = find_set(get_id(a[i].y));s[sx] = sy;}}boolean ok = true;for (int i = 1; i <= n; i++) {if (a[i].z == 0) {int sx = find_set(get_id(a[i].x));int sy = find_set(get_id(a[i].y));if (sx == sy) {ok = false;break;}}}if (ok) System.out.println("Yes");else System.out.println("No");}}static int deduplicate(int[] b, int n) { // 去重int p = 1;for (int i = 2; i <= n; i++)if (b[i] != b[p])b[++p] = b[i];return p;}
}
Python代码
用bisect.bisect_left()离散化,用set去重。
#pypy
import sys
sys.setrecursionlimit(10000)
import bisect
input = sys.stdin.readline
def get_id(x): return bisect.bisect_left(b, x)
def find_set(x):if x != s[x]: s[x] = find_set(s[x])return s[x]
T = int(input())
for _ in range(T):n = int(input())a = []b = []for i in range(n):str = input()x, y = 0, 0if "!" in str:split = str.split(" != ")x = int(split[0][1:])y = int(split[1][1:])a.append([x, y, 0])else:split = str.split(" = ")x = int(split[0][1:])y = int(split[1][1:])a.append([x, y, 1])b.append(a[i][0])b.append(a[i][1]) b = sorted(set(b))tot = len(b)s = [i for i in range(tot)]for i in range(n):if a[i][2] == 1:sx = find_set(get_id(a[i][0]))sy = find_set(get_id(a[i][1]))s[sx] = syok = Truefor i in range(n):if a[i][2] == 0:sx = find_set(get_id(a[i][0]))sy = find_set(get_id(a[i][1]))if sx == sy:ok = Falsebreakif ok: print("Yes")else: print("No")