��վ�ܷ�������

点分治学习

蒟蒻又来学习神仙算法了!!

ps :此文章讲解有所借鉴 来源


点分治

· 1. 基本思想及用途

点分治,顾名思义就是在树上的每个节点进行分治。

点的步步拆开也是树的拆开

我们利用dfs将每个子树拆开进行各种操作

运用主要是处理树上距离的统计,当我们要大量处理树上路径时,就需要运用点分治进行大批量处理来完成统一,所以点分治也是离线操作。

· 分治点选择

选点是点分治的核心操作也是减小复杂度的核心步骤

当树退化成一条链之后,显然选择链的中心进行点分治,时间复杂度是最优秀的

为$O( \log n )$

可以看出选择的点左右子树越大,递归层数越多,复杂度越优秀

于是我们选择树的重心作为分治的点

void get_root(int x,int fa)//寻找重心
{
    siz[x]=1;maxp[x]=0;
    for(int i=head[x];i;i=edge[i].nt)
    {
        int v=edge[i].to;
        if(v==fa||vis[v])continue;
        get_root(v,x);
        siz[x]+=siz[v];
        maxp[x]=max(maxp[x],siz[v]);//更新最大的子树
    }
    maxp[x]=max(maxp[x],sum-siz[x]);//求出父亲节点的子树大小,这千万不能漏,利用容斥原理
    if(maxp[root]>maxp[x])root=x;//要求最大子树最小,也就是变向要求平衡
}

· 点分治的实现与应用

模板 luogu 3806

题目简述

给定一棵有n个点的树

询问树上距离为k的点对是否存在。

实现方案

可以直接根据代码进行理解

主要的统计答案思想是

我们不仅仅只找恰好等于$ask[k$]的路径是否存在

若当前路径长$dis$存在,且$dis<ask[k]$ ,就看$ask[k]-dis$这样长度的路径是否也存在

相当于在一棵子树中我们把重心以下的两条路径拼接起来成为一条新的路径

这应该很好啊理解,这样就可以直接判断是否存在路径了。

注意打标直接 |= ,只要一种存在就是真

void cal(int x)
{
    int num=0;
    for(int i=head[x];i;i=edge[i].nt)
    {
        int v=edge[i].to;
        if(vis[v])continue;
        rem[0]=0;dis[v]=edge[i].w;
        getdis(v,x);
         for(int j=rem[0];j;j--)
            for(int k=1;k<=m;k++)
        {
            if(ask[k]>=rem[j])
                test[k]|=judge[ask[k]-rem[j]];//路径分割
        }
        for(int j=rem[0];j;j--)
            q[++num]=rem[j],judge[rem[j]]=1;
    }
     for(int i=1;i<=num;i++)//节省清零时间
        judge[q[i]]=0;
}

核心请直接看代码吧

英文注释注意

#include <bits/stdc++.h>

using namespace std;
const int N=1e6+110;
int n,m,head[N],cnt,siz[N],maxp[N],sum,root,ask[N],vis[N],q[N],rem[N];
int test[N],judge[N],dis[N];
struct node
{
    int nt,to,w;
}edge[N*2];
inline int read()
{
    int x=0,f=1;char ch=getchar();
    while(ch>'9'||ch<'0'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
void add(int x,int y,int z)
{
    edge[++cnt]=(node){head[x],y,z};head[x]=cnt;
    edge[++cnt]=(node){head[y],x,z};head[y]=cnt;
}

void get_root(int x,int fa)//find the center
{
    siz[x]=1;maxp[x]=0;
    for(int i=head[x];i;i=edge[i].nt)
    {
        int v=edge[i].to;
        if(v==fa||vis[v])continue;
        get_root(v,x);
        siz[x]+=siz[v];
        maxp[x]=max(maxp[x],siz[v]);
    }
    maxp[x]=max(maxp[x],sum-siz[x]);
    if(maxp[root]>maxp[x])root=x;
}
int getdis(int x,int fa)
{
    rem[++rem[0]]=dis[x];//memorize the recent road
    for(int i=head[x];i;i=edge[i].nt)
    {
        int v=edge[i].to;
        if(v==fa||vis[v])continue;
        dis[v]=dis[x]+edge[i].w;
        getdis(v,x);//update dis

    }

}

void cal(int x)
{
    int num=0;
    for(int i=head[x];i;i=edge[i].nt)
    {
        int v=edge[i].to;
        if(vis[v])continue;
        rem[0]=0;dis[v]=edge[i].w;
        getdis(v,x);
         for(int j=rem[0];j;j--)
            for(int k=1;k<=m;k++)
        {
            if(ask[k]>=rem[j])
                test[k]|=judge[ask[k]-rem[j]];
        }
        for(int j=rem[0];j;j--)
            q[++num]=rem[j],judge[rem[j]]=1;
    }
     for(int i=1;i<=num;i++)//save the time instead of memset
        judge[q[i]]=0;
}
void solve(int x)//partition every tree
{
    vis[x]=judge[0]=1;
    cal(x);//update the tree
    for(int i=head[x];i;i=edge[i].nt)
    {
        int v=edge[i].to;
        if(vis[v])continue;
        sum=siz[v]; root=0;maxp[root]=n;//be careful with the previous settings
        get_root(v,0);
        solve(root);
    }

}


int main()
{
    n=read();m=read();
    for(int i=1;i<n;i++)
    {
        int x,y,z;
        x=read();y=read();z=read();
        add(x,y,z);
    }
    for(int i=1;i<=m;i++)ask[i]=read();
    maxp[root]=sum=n;
    get_root(1,0);
    solve(root);
    for(int i=1;i<=m;i++)
    {
        if(test[i])printf("AYE\n");
        else printf("NAY\n");
    }
    return 0;
}

· 例题

1. Tree

题目简述

给你一棵TREE,以及这棵树上边的距离.问有多少对点它们两者间的距离小于等于K

题解

与上一道题有异曲同工之处,同样写法也很类似

这里就不用memset了非常开心

注意这一题统计的是所有满足题意得路径数量,所以需要前缀和的思想

$rem[x]$与上道题意义相同 ,增加了$sav[x]$数组,这个数组是找完一棵子树所有答案再更新,也就是x的所有子树

rem是当前找的v子树的当前答案,为了更快找到答案,用到了树状数组

如果求等于K的路径条数,非常简单。

本题求小于等于K的路径条数,可以考虑改变统计答案的方法。

原本统计答案是对于一个路径长度$len$,判断$K-len$在之前的子树中出现多少次。

现在统计答案是对于一个路径长度$len$,判断小于等于$K-len$的数在之前的子树中出现多少次。

在树状数组内存的是每个路径的出现次数,每一次统计就统计所有小于等于$len$的路径数,也就是需要求前缀和,利用树状数组就很方便,线段树re

至于为什么统计$K-len$,也和上一题是一个原理,路径合并

void cal(int x)
{
    sav[0]=0;
    for(int i=head[x];i;i=edge[i].nt)
    {
        int v=edge[i].to;
        if(vis[v])continue;
        rem[0]=0;dis[v]=edge[i].w;
        getdis(v,x);
        for(int j=rem[0];j;j--)
        {
            if(rem[j]>m)continue;
             ans+=query(m-rem[j]);//注意这里求得是前缀和!!!
        }
        for(int j=rem[0];j;j--)
        {
            if(rem[j]>m)continue;
            update(rem[j],1);//每出现合法路径,就将次数加一
            ans++;
        }
    }
    for(int i=sav[0];i;i--)
    {
        if(sav[i]>m)continue;
         update(sav[i],-1);//重新找重心,次数全部减一归零
    }

}

完整代码

// luogu-judger-enable-o2
//myloglast luogu p3806
#include <bits/stdc++.h>
#define rson 2*p+1
#define lson 2*p
using namespace std;
const int N=4e4+110;
int n,m,head[N],cnt,siz[N],maxp[N],sum,root,vis[N],q[N];
int dis[N],sav[N],ans,rem[N],t[N];
struct node
{
    int nt,to,w;
}edge[N*2];
void update(int x,int k)
{
    while(x<=m)
    {
        t[x]+=k;
        x+=x&(-x);
    }
}
int query(int x)
{
    int sum=0;
    while(x!=0)
    {
        sum+=t[x];
        x-=x&(-x);

    }
    return sum;
}
inline int read()
{
    int x=0,f=1;char ch=getchar();
    while(ch>'9'||ch<'0'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
void add(int x,int y,int z)
{
    edge[++cnt]=(node){head[x],y,z};head[x]=cnt;
    edge[++cnt]=(node){head[y],x,z};head[y]=cnt;
}
void get_root(int x,int fa)//find the center
{
    siz[x]=1;maxp[x]=0;
    for(int i=head[x];i;i=edge[i].nt)
    {
        int v=edge[i].to;
        if(v==fa||vis[v])continue;
        get_root(v,x);
        siz[x]+=siz[v];
        maxp[x]=max(maxp[x],siz[v]);
    }
    maxp[x]=max(maxp[x],sum-siz[x]);
    if(maxp[root]>maxp[x])root=x;
}
void getdis(int x,int fa)
{
    if(dis[x]>m)return;
    rem[++rem[0]]=dis[x];sav[++sav[0]]=dis[x];
    for(int i=head[x];i;i=edge[i].nt)
    {
        int v=edge[i].to;
        if(v==fa||vis[v])continue;
        dis[v]=dis[x]+edge[i].w;
        getdis(v,x);//update the dist of son tree

    }

}
void cal(int x)
{
    sav[0]=0;
    for(int i=head[x];i;i=edge[i].nt)
    {
        int v=edge[i].to;
        if(vis[v])continue;
        rem[0]=0;dis[v]=edge[i].w;
        getdis(v,x);
        for(int j=rem[0];j;j--)
        {
            if(rem[j]>m)continue;
             ans+=query(m-rem[j]);
        }
        for(int j=rem[0];j;j--)
        {
            if(rem[j]>m)continue;
            update(rem[j],1);
            ans++;
        }
    }
    for(int i=sav[0];i;i--)
    {
        if(sav[i]>m)continue;
         update(sav[i],-1);
    }

}
void solve(int x)//partition every tree
{
    vis[x]=1;
    cal(x);//update the tree
    for(int i=head[x];i;i=edge[i].nt)
    {
        int v=edge[i].to;
        if(vis[v])continue;root=0;
        maxp[root]=n;sum=siz[v];
        get_root(v,0);
        solve(root);
    }

}
int main()
{
    n=read();
    for(int i=1;i<n;i++)
    {
        int x,y,z;
        x=read();y=read();z=read();
        add(x,y,z);
    }

    m=read();
    maxp[root]=sum=n;
    get_root(1,0);
    solve(root);
    printf("%d",ans);
    return 0;
}

2. [国家集训队]聪聪可可

题目简述

求树上是3的倍数的路径数与总路径数的比,分数形式需要约分

题解

用$t[x]$数组记录路径长度为$x$的路径条数,那么答案就是

t[1]*t[2]*2+t[0]*t[0]

很容易理解1+2=3 是3的倍数合法

3+3=6 也是3的倍数,合法

但因为(1,2) (2,1)是不同的方案,需要乘以2

乘法原理得出路径数的组合就是相乘

另外这道题就用到了点分治基本思想中的容斥原理

在统计答案时,先处理所有的A子树的路径

还要减去不合法的组合路径

如图

统计A子树时得到的所有路径是

A—>A
A—>B
A—>B—>C
A—>B—>D
A—>E
A—>E—>F (按照先序遍历顺序罗列)

我们要做的事上述所有路径进行两两合并

然而A—>B—>C 和A—>B—>D 的合并显然不合法

于是我们需要减掉三个不合法的组合答案

A—>B 与A—>B—>C , A—>B与A—>B—>D ,A—>B—>D 与A—>B—>C

因为这并不是一条树上(简单)路径,出现了重边,我们要想办法把这种情况处理掉。
处理方法很简单,减去每个子树的单独贡献。
例如对于以B为根的子树,就会减去:
B—>B
B—>C
B—>D
这三条路径组合的贡献

会发现恰好也是三条,其实也是根据排列组合的性质所有就很简单

统计答案是只需要先
Ans+=cal(x,0);

再减掉不合法的答案就行

Ans-=cal(v,edge[i].w);

然而由于上两题写法比较特殊就没有这个问题。。

void solve(int x)//partition every tree
{
    vis[x]=1;
    Ans+=cal(x,0);//calculate the answers of the root tree first
    for(int i=head[x];i;i=edge[i].nt)
    {
        int v=edge[i].to;
        if(vis[v])continue;
        Ans-=cal(v,edge[i].w);//delete the duplicate one
        root=0;
        maxp[root]=n;sum=siz[v];
        get_root(v,0);
        solve(root);
    }

}

完整代码

// luogu-judger-enable-o2
//myloglast luogu p3806
#include <bits/stdc++.h>
#define rson 2*p+1
#define lson 2*p
#define ll long long
#define int long long
using namespace std;
const int N=4e4+110;
int n,m,head[N],cnt,siz[N],maxp[N],sum,root,vis[N],q[N],Ans;
int dis[N],ans,t[N];
struct node
{
    int nt,to,w;
}edge[N*2];
inline int read()
{
    int x=0,f=1;char ch=getchar();
    while(ch>'9'||ch<'0'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
void add(int x,int y,int z)
{
    edge[++cnt]=(node){head[x],y,z};head[x]=cnt;
    edge[++cnt]=(node){head[y],x,z};head[y]=cnt;
}
void get_root(int x,int fa)//find the center
{
    siz[x]=1;maxp[x]=0;
    for(int i=head[x];i;i=edge[i].nt)
    {
        int v=edge[i].to;
        if(v==fa||vis[v])continue;
        get_root(v,x);
        siz[x]+=siz[v];
        maxp[x]=max(maxp[x],siz[v]);
    }
    maxp[x]=max(maxp[x],sum-siz[x]);
    if(maxp[root]>maxp[x])root=x;
}
void getdis(int x,int fa)
{
    t[dis[x]]++;//memorize appearing times of every road
    for(int i=head[x];i;i=edge[i].nt)
    {
        int v=edge[i].to;
        if(v==fa||vis[v])continue;
        dis[v]=(dis[x]+edge[i].w)%3;
        getdis(v,x);//update the dist of son tree

    }

}
int cal(int x,int v)
{
    t[0]=t[1]=t[2]=0;//clear
    dis[x]=v%3;//start point
    getdis(x,0);
    return t[1]*t[2]*2+t[0]*t[0];
    //1+2=3%3=0,3+3=6%3=0
    //but (1,2) (2,1) is different answers,so we should mult 2
}
void solve(int x)//partition every tree
{
    vis[x]=1;
    Ans+=cal(x,0);//calculate the answers of the root tree first
    for(int i=head[x];i;i=edge[i].nt)
    {
        int v=edge[i].to;
        if(vis[v])continue;
        Ans-=cal(v,edge[i].w);//delete the duplicate one
        root=0;
        maxp[root]=n;sum=siz[v];
        get_root(v,0);
        solve(root);
    }

}
ll gcd(ll a,ll b)
{
    return !b?a:gcd(b,a%b);
}
 main()
{
    n=read();
    for(int i=1;i<n;i++)
    {
        int x,y,z;
        x=read();y=read();z=read()%3;
        add(x,y,z);
    }
    maxp[root]=sum=n;
    get_root(1,0);
    solve(root);
    ans=n*n;
    ll GCD=gcd(ans,Ans);
    printf("%lld/%lld",Ans/GCD,ans/GCD);
    return 0;
}