��վ�ܷ�������

二维莫队

今日雅礼模拟赛day2 T2就只做了二维莫队,就来写博客更一下


二维莫队

· 前言

大家都知道普通莫队这个很暴力的算法。

利用玄学排序和分块来通过L,R区间移动改变次数尽量小来减少复杂度

当然也可以带修,也可在树上搞,主要是用于处理出现次数之类的问题

那么问题来了,二维甚至高维的莫队是否也可以实现?

· 例题

详见bzoj2639

题目描述

输入一个$n*m$的矩阵,矩阵的每一个元素都是一个整数,然后有$q$个询问,每次询问一个子矩阵的权值。矩阵的权值是这样定义的,对于一个整数$x$,如果它在该矩阵中出现了$p$次,那么它给该矩阵的权值就贡献$p^2$。

输入格式

第一行两个整数$n$,$m$表示矩阵的规模。
接下来$n$行每行$m$个整数,表示这个矩阵的每个元素。
再下来一行一个整数$q$,表示询问个数。
接下来$q$行每行四个正整数$x_1$,$y_1$,$x_2$,$y_2$,询问以第$x_1$行第$y_1$列和第$x_2$行第$y_2$列的连线为对角线的子矩阵的权值。

输出格式

输出$q$行每行一个整数回答对应询问。

样例输入

3 4 
1 3 2 1
1 3 2 4
1 2 3 4
1 2 2 1
8
1 1 2 1
1 1 3 4
1 1 1 1
2 2 3 3
3 4 2 2
1 3 3 1
2 4 3 4

样例输出

8
4
38
1
8
12
27
4

显然这是一个二维莫队板子

二维大家都很好想,主要是如何分块以及如何排序。

· 题解

1、排序与分块

众所周知莫队一维排序

inline bool cmp(node x,node y)
{
    if(bl[x.l]==bl[y.l])
    {
        return bl[x.r]<bl[y.r];
    }
    return bl[x.l]<bl[y.l];
}
    blo=sqrt(n);
    for(int i=1;i<=n;i++)
    {
        a[i]=read();
        bl[i]=(i-1)/blo+1;
    }

所以模仿一维,二维的排序就是定义两个$pn,pm$

分别表示横向的和纵向的块的大小

$bl[i]$就变成了$bl[i][j]$

将每个点的横纵坐标所属的块分出来

排序就是先按照$L$ 的$x,y$,看他们是否在一个块内,若在,就按$x,y$的值排序

代替繁琐的判断可以直接用坐标转换$Hash$法,直接判断出$x,y$的大小关系,详见$c(x,y)$

若$L$不在同一个块内,就按$L$的块排序

struct ask
{
    int x0,y0,x1,y1,id;
}q[M];
int c(int x,int y){return (x-1)*m+y;}
bool mmp (ask a,ask b)//核心排序
{
    if(bl[a.x0][a.y0]==bl[b.x0][b.y0])
    {
        return c(a.x1,a.y1)<c(b.x1,b.y1);
    }
    else return bl[a.x0][a.y0]<bl[b.x0][b.y0];

}
//分块代码
 pn=sqrt(n);pm=sqrt(m);
for(int i=1;i<=n;i++)
     for(int j=1;j<=m;j++)
    {
        mp[i][j]=read();
        lower[++tot]=mp[i][j];
        bl[i][j]=(i-1)/pn*pm+(j-1)/pm;
    }

块的大小我用的$\sqrt{n}$

但据说有更加玄学优秀的块大小 like this

  B=pow(r*c,0.5)/pow(m,0.25)+1.0;

以及玄学排序 代码来自ouuan dalao

 bool operator<(Query& b)
  {
     return x/B==b.x/B?(y/B==b.y/B?(xx/B==b.xx/B?yy<b.yy:xx<b.xx):y<b.y):x<b.x;
  }

证明我不清楚想了解去ouuan dalao 原博

2、答案更新

模仿一维莫队的4个while,由于变为了二维就有了8个while

定义四个变量 $L0,L1,R0,R1$

分别表示横纵向的矩阵坐标范围

转移直接先加再减,先横再纵按这个顺序写8个while就不容易写混写错

当然在一列/一排 加/减的时候直接一条来操作,一删删一条,一加加一列这样的也很好理解

int L0=1,L1=1,R0=0,R1=0;
    for(int i=1;i<=Q;i++)
    {
        while(L0>q[i].x0)
        {
            L0--;
            for(int j=L1;j<=R1;j++)
              add(mp[L0][j]);

        }
        while(R0<q[i].x1)
        {
            R0++;
            for(int j=L1;j<=R1;j++)
                add(mp[R0][j]);

        }
        while(L1>q[i].y0)
        {
            L1--;
            for(int j=L0;j<=R0;j++)
              add(mp[j][L1]);

        }
        while(R1<q[i].y1)
        {
            R1++;
            for(int j=L0;j<=R0;j++)
              add(mp[j][R1]);
        }
        while(L0<q[i].x0)
        {
            for(int j=L1;j<=R1;j++)
                del(mp[L0][j]);
            L0++;
        }
        while(R0>q[i].x1)
        {
            for(int j=L1;j<=R1;j++)
                del(mp[R0][j]);
            R0--;
        }
        while(L1<q[i].y0)
        {
            for(int j=L0;j<=R0;j++)
                del(mp[j][L1]);
            L1++;
        }
        while(R1>q[i].y1)
        {
            for(int j=L0;j<=R0;j++)
                del(mp[j][R1]);
            R1--;
        }
        f[q[i].id]=ans;
    }

3、总代码

很好理解就没写什么注释了。。。。

总之模拟赛手推搞了一个多小时还算是搞出来了。。

#include <bits/stdc++.h>
#define int long long
#define ll long long
using namespace std;
const int N=200+10;
const int M=1e5+110;
int n,m,p,pn,pm,bl[N][N],Q,mp[N][N],tot,lower[N*N*2];
int cnt[N*N*2],ans,f[M];
struct ask
{
    int x0,y0,x1,y1,id;
}q[M];
int c(int x,int y){return (x-1)*m+y;}
bool mmp (ask a,ask b)
{
    if(bl[a.x0][a.y0]==bl[b.x0][b.y0])
    {
        return c(a.x1,a.y1)<c(b.x1,b.y1);
    }
    else return bl[a.x0][a.y0]<bl[b.x0][b.y0];

}
ll read()
{
    ll x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
void add(int x)
{
    cnt[x]++;
    ans+=2*cnt[x]-1;
}
void del(int x)
{
    cnt[x]--;
    ans-=2*cnt[x]+1;
}
main()
{

    n=read();m=read();
    pn=sqrt(n);pm=sqrt(m);
    for(int i=1;i<=n;i++)
     for(int j=1;j<=m;j++)
    {
        mp[i][j]=read();
        lower[++tot]=mp[i][j];
        bl[i][j]=(i-1)/pn*pm+(j-1)/pm;
    }
    sort(lower+1,lower+1+tot);
    int xx=unique(lower+1,lower+1+tot)-lower-1;
    for(int i=1;i<=n;i++)
        for(int j=1;j<=m;j++)
    {
        mp[i][j]=lower_bound(lower+1,lower+1+xx,mp[i][j])-lower;
    }
    Q=read();
    for(int i=1;i<=Q;i++)
    {
        q[i].x0=read();q[i].y0=read();q[i].x1=read();q[i].y1=read();
        if(q[i].x0>q[i].x1)swap(q[i].x0,q[i].x1);//注意输入。。。有毒
        if(q[i].y0>q[i].y1)swap(q[i].y0,q[i].y1);
        q[i].id=i;
    }
    sort(q+1,q+1+Q,mmp);
    int L0=1,L1=1,R0=0,R1=0;
    for(int i=1;i<=Q;i++)
    {
        while(L0>q[i].x0)
        {
            L0--;
            for(int j=L1;j<=R1;j++)
              add(mp[L0][j]);

        }
        while(R0<q[i].x1)
        {
            R0++;
            for(int j=L1;j<=R1;j++)
                add(mp[R0][j]);

        }
        while(L1>q[i].y0)
        {
            L1--;
            for(int j=L0;j<=R0;j++)
              add(mp[j][L1]);

        }
        while(R1<q[i].y1)
        {
            R1++;
            for(int j=L0;j<=R0;j++)
              add(mp[j][R1]);
        }
        while(L0<q[i].x0)
        {
            for(int j=L1;j<=R1;j++)
                del(mp[L0][j]);
            L0++;
        }
        while(R0>q[i].x1)
        {
            for(int j=L1;j<=R1;j++)
                del(mp[R0][j]);
            R0--;
        }
        while(L1<q[i].y0)
        {
            for(int j=L0;j<=R0;j++)
                del(mp[j][L1]);
            L1++;
        }
        while(R1>q[i].y1)
        {
            for(int j=L0;j<=R0;j++)
                del(mp[j][R1]);
            R1--;
        }
        f[q[i].id]=ans;
    }
    for(int i=1;i<=Q;i++)
        printf("%lld\n",f[i]);
    return 0;
}