题目依旧给了一种很强烈的暗示为
答案就是两个字串的两两后缀的最长公共前缀长度之和。
把两个字串合成成一个字串。
于是求出$\text{height}$数组,就可以使用暴力了。
不妨再结合一些性质,两个后缀之间的最长公共前缀长度就是$\min\limits_{rk[i]\le k\le rk[j]}\{\text{height}[k]\}$
做一些优化,找出$\min\limits_{rk[i]\le k\le rk[j]}\{\text{height}[k]\}$相等的一段。
这个可以用单调栈实现。
比如$i,[L_i,R_i]$表示$[L_i,R_i]$经过$i$的子区间的最长公共前缀长度均为$\text{height}[i]$。
那么就可以统计左端点在$[L_i,i-1]$,右端点在$[i,R_i]$的数量,再将他们相乘即可。
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<cstdlib>
#define gc getchar()
#define ll long long
#define fin(s) freopen(s".in","r",stdin)
#define I inline
using namespace std;
const int N=4e5+5,M=N<<1;
template<class o>I void qr(o &x)
{
char c=gc;int f=1;x=0;
while(c<'0'||c>'9'){if(c=='-')f=-1;c=gc;}
while(c>='0'&&c<='9'){x=x*10+(c^48);c=gc;}
x*=f;
}
template<class o>I void qw(o x)
{
if(x<0)x=-x,putchar('-');
if(x/10)qw(x/10);
putchar(x%10+48);
}
int c[N],sa[N],wa[N],wb[N],wv[N],n,m,rk[N],height[N];char s1[N],s2[N],s3[N];
I bool cmp(int *y,int a,int b,int l){return y[a]==y[b]&&y[a+l]==y[b+l];}
void SA(char *s)
{
int *x=wa,*y=wb,*t;m=300;
for(int i=1;i<=m;i++)c[i]=0;
for(int i=1;i<=n;i++)++c[x[i]=s[i]-'a'+1];
for(int i=2;i<=m;i++)c[i]+=c[i-1];
for(int i=n;i>=1;i--)sa[c[x[i]]--]=i;
for(int j=1,p=0;p<n;j<<=1,m=p)
{
p=0;for(int i=n-j+1;i<=n;i++)y[++p]=i;
for(int i=1;i<=n;i++)if(sa[i]>j)y[++p]=sa[i]-j;
for(int i=1;i<=m;i++)c[i]=0;
for(int i=1;i<=n;i++)++c[wv[i]=x[y[i]]];
for(int i=2;i<=m;i++)c[i]+=c[i-1];
for(int i=n;i>=1;i--)sa[c[wv[i]]--]=y[i];
t=x;x=y;y=t;p=1;x[sa[1]]=1;
for(int i=2;i<=n;i++)
x[sa[i]]=cmp(y,sa[i-1],sa[i],j)?p:++p;
}
}
void gh(char *s)
{
for(int i=1;i<=n;i++)rk[sa[i]]=i;
for(int i=1,j=0,k=0;i<=n;height[rk[i++]]=k)
for(k?k--:k,j=sa[rk[i]-1];s[i+k]==s[j+k];++k);
}
int L[N],R[N],S1[N],S2[N],sta[N],top;
inline ll calc()
{
ll ans=0;int p1=strlen(s1+1),p2=p1+2;
for(int i=1;i<=n;i++)
{
for(;top&&height[sta[top]]>=height[i];--top)
R[sta[top]]=i;
L[i]=sta[top];sta[++top]=i;
}
while(top)R[sta[top--]]=n+1;
for(int i=1;i<=n;i++)
{
S1[i]=S1[i-1]+(1<=sa[i]&&sa[i]<=p1);
S2[i]=S2[i-1]+(p2<=sa[i]&&sa[i]<=n);
}
for(int i=1;i<=n;i++)
{
int l=L[i]+1,r=R[i]-1;
ans+=1ll*(S2[r]-S2[i-1])*(S1[i-1]-S1[l-2])*height[i];
ans+=1ll*(S1[r]-S1[i-1])*(S2[i-1]-S2[l-2])*height[i];
}
return ans;
}
int main()
{
ll ans=0;
scanf("%s%s",s1+1,s2+1);
int p1=strlen(s1+1),p2=strlen(s2+1);
for(int i=1;i<=p1;i++)
s3[i]=s1[i];
n=p1+1;
s3[n]='z'+1;
for(int i=1;i<=p2;i++)
s3[i+n]=s2[i];
n+=p2;
SA(s3);
gh(s3);
ans+=calc();
qw(ans);puts("");
return 0;
}
最后一次更新于2020-05-13
0 条评论