题意:给一个N×MN \times MN×M的矩阵,可以进行任意多次操作将一列轮换,求每一行的最大值之和的最大值。多组数据。
Easy VersionN≤4N \leq 4N≤4,M≤100M \leq100M≤100
Hard VersionN≤12N \leq 12N≤12,M≤2000M \leq2000M≤2000
看这数据范围显然是个状压
相当于是每一行只能选一个
设f(i,S)f(i,S)f(i,S)表示当前到iii,已经选了SSS的最大值
记忆化搜索一波,暴力转一下
复杂度O(4nn2m)O(4^nn^2m)O(4nn2m) 可以通过Easy Version
#include <iostream>#include <cstdio>#include <cstring>#include <cctype>using namespace std;int n,m;int a[20][105],dp[105][1<<4];int Move(int s){return (s>>1)|((s&1)<<(n-1));}int dfs(int pos,int s){if (dp[pos][s]) return dp[pos][s];if (s==(1<<n)-1) return 0;if (pos==m) return 0;int &ans=dp[pos][s];for (int t=0;t<(1<<n);t++){if (s&t) continue;int res=dfs(pos+1,s|t);int mx=0;for (int k=0,cur=t;k<n;k++){int sum=0;cur=Move(cur);for (int i=0;i<n;i++)if ((1<<i)&cur)sum+=a[i][pos];mx=max(mx,sum);}ans=max(ans,res+mx);}return ans;}void solve(){memset(dp,0,sizeof(dp));scanf("%d%d",&n,&m);for (int i=0;i<n;i++)for (int j=0;j<m;j++)scanf("%d",&a[i][j]);printf("%d\n",dfs(0,0));}int main(){int T;scanf("%d",&T);while (T--) solve();return 0;}
复杂度里有MMM,过不了HV
考虑如何消掉M
我们发现因为只有NNN行,所以最多只有NNN列选了数
能否快速确定这NNN列?
贪心!
我们把列按最大值从大到小排序,如果前面的一列没选数而后面的选了
那我们完全可以改成前面没有选的,一定更优
因为只能选NNN个数,所以只用考虑最大的NNN列
复杂度O(4nn3)O(4^nn^3)O(4nn3)
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
#include <algorithm>
using namespace std;
int n,m;
int a[20][2005],dp[2005][1<<12],ord[2005];
int Move(int s){return (s>>1)|((s&1)<<(n-1));}
int dfs(int pos,int s)
{if (dp[pos][s]) return dp[pos][s];if (s==(1<<n)-1) return 0;if (pos==m) return 0;int &ans=dp[pos][s];for (int t=0;t<(1<<n);t++){if (s&t) continue;int res=dfs(pos+1,s|t);int mx=0;
// printf("%d",t);for (int k=0,cur=t;k<n;k++){int sum=0;cur=Move(cur);
// printf("->%d",cur);for (int i=0;i<n;i++)if ((1<<i)&cur)sum+=a[i][ord[pos]];mx=max(mx,sum);}
// puts("");ans=max(ans,res+mx);}return ans;
}
int mx[2005];
inline bool cmp(const int& a,const int& b){return mx[a]>mx[b];}
void solve()
{memset(dp,0,sizeof(dp));memset(mx,0,sizeof(mx));scanf("%d%d",&n,&m);for (int i=0;i<n;i++)for (int j=0;j<m;j++)scanf("%d",&a[i][j]),mx[j]=max(mx[j],a[i][j]);for (int i=0;i<m;i++) ord[i]=i;sort(ord,ord+m,cmp);m=min(n,m);printf("%d\n",dfs(0,0));
}
int main()
{int T;scanf("%d",&T);while (T--) solve();return 0;
}
然而发现慢得飞起
我们发现:在dp的时候,会花费O(n)O(n)O(n)找出最佳的旋转方案,而这个O(n)O(n)O(n)是在O(4n)O(4^n)O(4n)的基础上的
为什么不预处理呢
然后得到了O(2nn3+4nn)O(2^nn^3+4^nn)O(2nn3+4nn)
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
#include <algorithm>
using namespace std;
int n,m;
int a[20][2005],mem[2005][1<<12],dp[2005][1<<12],ord[2005];
int dfs(int pos,int s)
{if (dp[pos][s]) return dp[pos][s];if (s==(1<<n)-1) return 0;if (pos==m) return 0;int &ans=dp[pos][s];for (int t=0;t<(1<<n);t++)if (!(s&t))ans=max(ans,dfs(pos+1,s|t)+mem[ord[pos]][t]);return ans;
}
int mx[2005];
inline bool cmp(const int& a,const int& b){return mx[a]>mx[b];}
void solve()
{memset(dp,0,sizeof(dp));memset(mx,0,sizeof(mx));memset(mem,0,sizeof(mem));scanf("%d%d",&n,&m);for (int i=0;i<n;i++)for (int j=0;j<m;j++)scanf("%d",&a[i][j]),mx[j]=max(mx[j],a[i][j]);for (int i=0;i<m;i++) ord[i]=i;sort(ord,ord+m,cmp);m=min(n,m);for (int pos=0;pos<m;pos++)for (int s=0;s<(1<<n);s++)for (int k=0;k<n;k++){int sum=0;for (int i=0;i<n;i++)if (s&(1<<i))sum+=a[(i+k)%n][ord[pos]];mem[ord[pos]][s]=max(mem[ord[pos]][s],sum); } printf("%d\n",dfs(0,0));
}
int main()
{int T;scanf("%d",&T);while (T--) solve();return 0;
}
然而还是慢得飞起
这玩意还有多组数据
我们发现这一句
if (!(s&t))
是求和sss没有交集的ttt
能否快速得到这玩意呢?
还真不行
但我们可以换一种实现方式
直接递推实现dp
这样从上一维转移过来只用枚举子集
可以用这句
for (int t=s;t;t=((t-1)&s))
这样就只枚举了sss的子集
复杂度?
加上枚举sss
一共是
∑i=0nCni2i\sum_{i=0}^{n}C_n^i2^ii=0∑nCni2i
=∑i=0nCnn−i2i=\sum_{i=0}^{n}C_n^{n-i}2^i=i=0∑nCnn−i2i
=(1+2)n=3n=(1+2)^n=3^n=(1+2)n=3n
所以复杂度是O(2nn3+3nn)O(2^nn^3+3^nn)O(2nn3+3nn)
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
#include <algorithm>
using namespace std;
int n,m;
int a[12][2000],mx[2000],pos[2000],mem[2000][1<<12],dp[2000][1<<12];
inline bool cmp(const int& a,const int& b){return mx[a]>mx[b];}
void solve()
{scanf("%d%d",&n,&m);memset(mx,0,sizeof(mx));for (int i=0;i<n;i++)for (int j=0;j<m;j++)scanf("%d",&a[i][j]),mx[j]=max(mx[j],a[i][j]);for (int i=0;i<m;i++) pos[i]=i;sort(pos,pos+m,cmp);m=min(n,m);for (int p=0;p<m;p++)for (int s=0;s<(1<<n);s++){mem[p][s]=0;for (int k=0;k<n;k++){int sum=0;for (int i=0;i<n;i++)if ((1<<i)&s)sum+=a[(i+k)%n][pos[p]];mem[p][s]=max(mem[p][s],sum);}} memset(dp,0,sizeof(dp));for (int s=0;s<(1<<n);s++) dp[0][s]=mem[0][s];for (int p=1;p<m;p++)for (int s=0;s<(1<<n);s++){for (int t=s;t;t=((t-1)&s))dp[p][s]=max(dp[p][s],dp[p-1][s^t]+mem[p][t]);dp[p][s]=max(dp[p][s],dp[p-1][s]); }printf("%d\n",dp[m-1][(1<<n)-1]);
}
int main()
{int T;scanf("%d",&T);while (T--) solve();return 0;
}
过于毒瘤