今日雅礼模拟赛day2 T2就只做了二维莫队,就来写博客更一下
· 前言
大家都知道普通莫队这个很暴力的算法。
利用玄学排序和分块来通过L,R区间移动改变次数尽量小来减少复杂度
当然也可以带修,也可在树上搞,主要是用于处理出现次数之类的问题
那么问题来了,二维甚至高维的莫队是否也可以实现?
· 例题
题目描述
输入一个$n*m$的矩阵,矩阵的每一个元素都是一个整数,然后有$q$个询问,每次询问一个子矩阵的权值。矩阵的权值是这样定义的,对于一个整数$x$,如果它在该矩阵中出现了$p$次,那么它给该矩阵的权值就贡献$p^2$。
输入格式
第一行两个整数$n$,$m$表示矩阵的规模。
接下来$n$行每行$m$个整数,表示这个矩阵的每个元素。
再下来一行一个整数$q$,表示询问个数。
接下来$q$行每行四个正整数$x_1$,$y_1$,$x_2$,$y_2$,询问以第$x_1$行第$y_1$列和第$x_2$行第$y_2$列的连线为对角线的子矩阵的权值。
输出格式
输出$q$行每行一个整数回答对应询问。
样例输入
3 4
1 3 2 1
1 3 2 4
1 2 3 4
1 2 2 1
8
1 1 2 1
1 1 3 4
1 1 1 1
2 2 3 3
3 4 2 2
1 3 3 1
2 4 3 4
样例输出
8
4
38
1
8
12
27
4
显然这是一个二维莫队板子
二维大家都很好想,主要是如何分块以及如何排序。
· 题解
1、排序与分块
众所周知莫队一维排序
inline bool cmp(node x,node y)
{
if(bl[x.l]==bl[y.l])
{
return bl[x.r]<bl[y.r];
}
return bl[x.l]<bl[y.l];
}
blo=sqrt(n);
for(int i=1;i<=n;i++)
{
a[i]=read();
bl[i]=(i-1)/blo+1;
}
所以模仿一维,二维的排序就是定义两个$pn,pm$
分别表示横向的和纵向的块的大小
$bl[i]$就变成了$bl[i][j]$
将每个点的横纵坐标所属的块分出来
排序就是先按照$L$ 的$x,y$,看他们是否在一个块内,若在,就按$x,y$的值排序
代替繁琐的判断可以直接用坐标转换$Hash$法,直接判断出$x,y$的大小关系,详见$c(x,y)$
若$L$不在同一个块内,就按$L$的块排序
struct ask
{
int x0,y0,x1,y1,id;
}q[M];
int c(int x,int y){return (x-1)*m+y;}
bool mmp (ask a,ask b)//核心排序
{
if(bl[a.x0][a.y0]==bl[b.x0][b.y0])
{
return c(a.x1,a.y1)<c(b.x1,b.y1);
}
else return bl[a.x0][a.y0]<bl[b.x0][b.y0];
}
//分块代码
pn=sqrt(n);pm=sqrt(m);
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)
{
mp[i][j]=read();
lower[++tot]=mp[i][j];
bl[i][j]=(i-1)/pn*pm+(j-1)/pm;
}
块的大小我用的$\sqrt{n}$
但据说有更加玄学优秀的块大小 like this
B=pow(r*c,0.5)/pow(m,0.25)+1.0;
以及玄学排序 代码来自ouuan dalao
bool operator<(Query& b)
{
return x/B==b.x/B?(y/B==b.y/B?(xx/B==b.xx/B?yy<b.yy:xx<b.xx):y<b.y):x<b.x;
}
证明我不清楚想了解去ouuan dalao 原博
2、答案更新
模仿一维莫队的4个while,由于变为了二维就有了8个while
定义四个变量 $L0,L1,R0,R1$
分别表示横纵向的矩阵坐标范围
转移直接先加再减,先横再纵按这个顺序写8个while就不容易写混写错
当然在一列/一排 加/减的时候直接一条来操作,一删删一条,一加加一列这样的也很好理解
int L0=1,L1=1,R0=0,R1=0;
for(int i=1;i<=Q;i++)
{
while(L0>q[i].x0)
{
L0--;
for(int j=L1;j<=R1;j++)
add(mp[L0][j]);
}
while(R0<q[i].x1)
{
R0++;
for(int j=L1;j<=R1;j++)
add(mp[R0][j]);
}
while(L1>q[i].y0)
{
L1--;
for(int j=L0;j<=R0;j++)
add(mp[j][L1]);
}
while(R1<q[i].y1)
{
R1++;
for(int j=L0;j<=R0;j++)
add(mp[j][R1]);
}
while(L0<q[i].x0)
{
for(int j=L1;j<=R1;j++)
del(mp[L0][j]);
L0++;
}
while(R0>q[i].x1)
{
for(int j=L1;j<=R1;j++)
del(mp[R0][j]);
R0--;
}
while(L1<q[i].y0)
{
for(int j=L0;j<=R0;j++)
del(mp[j][L1]);
L1++;
}
while(R1>q[i].y1)
{
for(int j=L0;j<=R0;j++)
del(mp[j][R1]);
R1--;
}
f[q[i].id]=ans;
}
3、总代码
很好理解就没写什么注释了。。。。
总之模拟赛手推搞了一个多小时还算是搞出来了。。
#include <bits/stdc++.h>
#define int long long
#define ll long long
using namespace std;
const int N=200+10;
const int M=1e5+110;
int n,m,p,pn,pm,bl[N][N],Q,mp[N][N],tot,lower[N*N*2];
int cnt[N*N*2],ans,f[M];
struct ask
{
int x0,y0,x1,y1,id;
}q[M];
int c(int x,int y){return (x-1)*m+y;}
bool mmp (ask a,ask b)
{
if(bl[a.x0][a.y0]==bl[b.x0][b.y0])
{
return c(a.x1,a.y1)<c(b.x1,b.y1);
}
else return bl[a.x0][a.y0]<bl[b.x0][b.y0];
}
ll read()
{
ll x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
void add(int x)
{
cnt[x]++;
ans+=2*cnt[x]-1;
}
void del(int x)
{
cnt[x]--;
ans-=2*cnt[x]+1;
}
main()
{
n=read();m=read();
pn=sqrt(n);pm=sqrt(m);
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)
{
mp[i][j]=read();
lower[++tot]=mp[i][j];
bl[i][j]=(i-1)/pn*pm+(j-1)/pm;
}
sort(lower+1,lower+1+tot);
int xx=unique(lower+1,lower+1+tot)-lower-1;
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)
{
mp[i][j]=lower_bound(lower+1,lower+1+xx,mp[i][j])-lower;
}
Q=read();
for(int i=1;i<=Q;i++)
{
q[i].x0=read();q[i].y0=read();q[i].x1=read();q[i].y1=read();
if(q[i].x0>q[i].x1)swap(q[i].x0,q[i].x1);//注意输入。。。有毒
if(q[i].y0>q[i].y1)swap(q[i].y0,q[i].y1);
q[i].id=i;
}
sort(q+1,q+1+Q,mmp);
int L0=1,L1=1,R0=0,R1=0;
for(int i=1;i<=Q;i++)
{
while(L0>q[i].x0)
{
L0--;
for(int j=L1;j<=R1;j++)
add(mp[L0][j]);
}
while(R0<q[i].x1)
{
R0++;
for(int j=L1;j<=R1;j++)
add(mp[R0][j]);
}
while(L1>q[i].y0)
{
L1--;
for(int j=L0;j<=R0;j++)
add(mp[j][L1]);
}
while(R1<q[i].y1)
{
R1++;
for(int j=L0;j<=R0;j++)
add(mp[j][R1]);
}
while(L0<q[i].x0)
{
for(int j=L1;j<=R1;j++)
del(mp[L0][j]);
L0++;
}
while(R0>q[i].x1)
{
for(int j=L1;j<=R1;j++)
del(mp[R0][j]);
R0--;
}
while(L1<q[i].y0)
{
for(int j=L0;j<=R0;j++)
del(mp[j][L1]);
L1++;
}
while(R1>q[i].y1)
{
for(int j=L0;j<=R0;j++)
del(mp[j][R1]);
R1--;
}
f[q[i].id]=ans;
}
for(int i=1;i<=Q;i++)
printf("%lld\n",f[i]);
return 0;
}