# Weak Pair

Time Limit: 4000/2000 MS (Java/Others)    Memory Limit: 262144/262144 K (Java/Others)

Problem Description

You are given a rooted tree of $N$ nodes, labeled from 1 to $N$. To the $i$-th node a non-negative value $a_i$ is assigned. An ordered pair of nodes $(u, v)$ is said to be *weak* if (1) $u$ is an ancestor of $v$ (Note: in this problem a node $u$ is not considered an ancestor of itself); and (2) $a_u \times a_v \le k$. Can you find the number of weak pairs in the tree?

Input

There are multiple cases in the data set. The first line of input contains an integer $T$ denoting the number of test cases. For each case, the first line contains two space-separated integers, $N$ and $k$, respectively. The second line contains $N$ space-separated integers, denoting $a_1$ to $a_N$. Each of the subsequent $N-1$ lines contains two space-separated integers defining an edge connecting nodes $u$ and $v$, where node $u$ is the parent of node $v$. Constraints: $1 \le N \le 10^5$, $0 \le a_i \le 10^9$, $0 \le k \le 10^{18}$.

Output

For each test case, print a single integer on a single line denoting the number of weak pairs in the tree.

Sample Input

```
1
2 3
1 2
1 2
```

Sample Output

```
1
```

```cpp
#include <string.h>
#include <stdlib.h>
#include <algorithm>
#include <math.h>
#include <stdio.h>
#include <map>

using namespace std;
const int maxn=1e5;
typedef long long int LL;
// Stand-in for k / a[i] when a[i] == 0: such a node pairs with every ancestor
// (0 * a_u = 0 <= k), so its bound only has to be >= every possible a_u.
const LL INF=(LL)1<<62;
int n;
LL k;
LL a[maxn+5];      // a[i]: value of node i
LL b[maxn+5];      // b[i]: largest ancestor value that pairs with node i
LL e[maxn*2+5];    // every a[i] and b[i]; sorted and deduplicated per case
int cnt;           // number of distinct values in e[1..cnt]
int c[maxn*2+5];   // Fenwick tree indexed by compressed value (1..cnt)

// Adjacency list of parent -> child edges.
struct Node
{
    int value;     // child endpoint
    int next;      // next edge from the same parent; -1 ends the list
}edge[maxn*2+5];
int head[maxn+5];  // first edge out of each node; -1 if none
int tot;           // number of edges stored
int tag[maxn+5];   // number of parents of each node; 0 marks the root
LL ans;            // number of weak pairs in the current case

void add(int x,int y)
{
    edge[tot].value=y;
    edge[tot].next=head[x];
    head[x]=tot++;
}
int lowbit(int x)
{
    return x&(-x);
}
// Add num at compressed position x.
void update(int x,int num)
{
    for(;x<=cnt;x+=lowbit(x))
        c[x]+=num;
}
// Total count over compressed positions 1..x.
int sum(int x)
{
    int s=0;
    for(;x>0;x-=lowbit(x))
        s+=c[x];
    return s;
}
// 1-based compressed index of v; v must be one of e[1..cnt].
int idx(LL v)
{
    return lower_bound(e+1,e+cnt+1,v)-e;
}
// The Fenwick tree holds the values of `root` and all its ancestors, so each
// child v adds the number of those ancestors u with a[u] <= b[v].
void dfs(int root)
{
    for(int i=head[root];i!=-1;i=edge[i].next)
    {
        int v=edge[i].value;
        ans+=sum(idx(b[v]));
        update(idx(a[v]),1);
        dfs(v);
        update(idx(a[v]),-1);
    }
}
// Reset all per-case state.
void init()
{
    memset(c,0,sizeof(c));
    memset(head,-1,sizeof(head));
    memset(tag,0,sizeof(tag));
    tot=0;
}
int main()
{
    int t;
    scanf("%d",&t);
    while(t--)
    {
        scanf("%d%lld",&n,&k);
        init();
        int m=0;
        for(int i=1;i<=n;i++)
        {
            scanf("%lld",&a[i]);
            // a_u * a_i <= k  <=>  a_u <= k / a_i (integer division) for a_i > 0.
            b[i]= a[i]==0 ? INF : k/a[i];
            e[++m]=a[i];
            e[++m]=b[i];
        }
        sort(e+1,e+m+1);
        cnt=unique(e+1,e+m+1)-(e+1);
        for(int i=1;i<=n-1;i++)
        {
            int x,y;
            scanf("%d%d",&x,&y);
            add(x,y);
            tag[y]++;
        }
        int root=1;
        for(int i=1;i<=n;i++)
        {
            if(tag[i]==0)
                root=i;
        }
        ans=0;
        update(idx(a[root]),1);
        dfs(root);
        printf("%lld\n",ans);
    }
    return 0;
}

#include <string.h>
#include <algorithm>
#include <stdlib.h>
#include <math.h>
#include <stdio.h>
#include <string>
#include <map>
#include <vector>

using namespace std;
typedef long long int LL;
const int maxn=1e5;
// v[x] lists the children of node x (edges are given parent -> child).
std::vector<int> v[maxn+5];
// Segment-tree counts over compressed value indices 1..2n (4 * 2n slots).
int sum[maxn*8+5];
int n;               // node count of the current case
LL k;                // product bound
LL a[maxn+5];        // a[i]: value of node i
LL b[maxn+5];        // b[i]: k / a[i]
LL e[maxn*2+5];      // every a[i] and b[i], sorted for coordinate compression
int aa[maxn+5];      // compressed index of a[i]
int bb[maxn+6];      // compressed index of b[i]
std::map<LL,int> m;  // value -> compressed index (1-based)

// Refresh an internal node: its count is the total of its two children's counts.
void PushUp(int node)
{
    int total=sum[2*node];
    total+=sum[2*node+1];
    sum[node]=total;
}
// Add `num` to the count at leaf position `ind` in the subtree rooted at
// `node` (covering [begin, end]), then refresh the counts on the way back up.
void update(int node,int begin,int end,int ind,int num)
{
    if(begin!=end)
    {
        int mid=(begin+end)>>1;
        if(ind>mid)
            update(node<<1|1,mid+1,end,ind,num);
        else
            update(node<<1,begin,mid,ind,num);
        PushUp(node);
        return;
    }
    // Leaf: covers exactly one position.
    sum[node]+=num;
}
// Count the stored values whose compressed index lies in [left, right],
// searching the subtree rooted at `node`, which covers [begin, end].
// Read-only: counts are kept consistent by update(), so no PushUp is needed here.
LL Query(int node,int begin,int end,int left,int right)
{
    if(left<=begin&&end<=right)
        return sum[node];
    int mid=(begin+end)>>1;
    LL ret=0;
    if(left<=mid)
        ret+=Query(node<<1,begin,mid,left,right);
    if(right>mid)
        ret+=Query(node<<1|1,mid+1,end,left,right);
    return ret;
}
int tag[maxn+5];   // number of parents per node; the root is the one with 0
LL ans;            // weak pairs counted so far in the current case

// Depth-first walk below `root`. The segment tree holds the values on the
// current root-to-`root` path, so each child first counts the ancestors it
// pairs with, then adds itself while its own subtree is visited.
void dfs(int root)
{
    typedef std::vector<int>::const_iterator Iter;
    for(Iter it=v[root].begin();it!=v[root].end();++it)
    {
        const int w=*it;
        ans+=Query(1,1,2*n,1,bb[w]);
        update(1,1,2*n,aa[w],1);
        dfs(w);
        update(1,1,2*n,aa[w],-1);
    }
}
// Reset per-test-case state: parent counters and segment-tree counts.
void init()
{
    memset(tag,0,sizeof(tag));
    memset(sum,0,sizeof(sum));
}
int main()
{
int t;
scanf("%d",&t);
int x,y;
while(t--)
{
scanf("%d%lld",&n,&k);
int cnt=0;
init();
m.clear();
for(int i=1;i<=n;i++)
{
scanf("%lld",&a[i]);
e[++cnt]=a[i];
b[i]=k/a[i];
e[++cnt]=b[i];
v[i].clear();
}
sort(e+1,e+cnt+1);
int cot=1;
for(int i=1;i<=cnt;i++)
{
if(!m.count(e[i]))
m[e[i]]=cot++;
}
for(int i=1;i<=n;i++)
{
aa[i]=m[a[i]];
bb[i]=m[b[i]];
}
for(int i=1;i<=n-1;i++)
{
scanf("%d%d",&x,&y);
v[x].push_back(y);
tag[y]++;
}
int root;
for(int i=1;i<=n;i++)
{
if(tag[i]==0)
root=i;
}
ans=0;
update(1,1,2*n,m[a[root]],1);
dfs(root);
printf("%lld\n",ans);
}
return 0;
}

```

```cpp
#include <iostream>
#include <string.h>
#include <stdlib.h>
#include <algorithm>
#include <math.h>
#include <string>
#include <stdio.h>
#include <vector>

using namespace std;
const int maxn=1e5;
const long long int len=1e18;   // value range of the dynamic segment tree: [0, len]
typedef long long int LL;
LL a[maxn+5];                   // a[i]: value of node i
LL b[maxn+5];                   // b[i]: k / a[i]
int n;
LL k;
std::vector<int> v[maxn+5];     // v[x]: children of node x

// Node of a dynamically allocated segment tree; a child index of -1 means
// the child has not been created.
struct Node
{
    int lch,rch;
    LL sum;
    Node(){}
    Node(int lch,int rch,LL sum):lch(lch),rch(rch),sum(sum){}
};
Node tr[maxn*100+5];            // node pool
int p;                          // index of the most recently allocated node
// Recompute a node's count from its two children. Expects both children to
// exist (index != -1).
void PushUp(int node)
{
    const int left=tr[node].lch;
    const int right=tr[node].rch;
    tr[node].sum=tr[left].sum+tr[right].sum;
}

// Allocate the next pool slot as a childless node with count 0; return its index.
int newnode()
{
    ++p;
    tr[p]=Node(-1,-1,0);
    return p;
}
void update(int node,LL begin,LL end,LL ind,int num)
{
if(begin==end)
{
tr[node].sum+=num;
return;
}
LL m=(begin+end)>>1;
if(tr[node].lch==-1) tr[node].lch=newnode();
if(tr[node].rch==-1) tr[node].rch=newnode();
if(ind<=m)
update(tr[node].lch,begin,m,ind,num);
else
update(tr[node].rch,m+1,end,ind,num);
PushUp(node);
}
// Count the stored values in [left, right] within the subtree at `node`,
// which covers [begin, end]. A node index of -1 is an empty subtree.
// Read-only: no PushUp here. A node's children may both be -1, and
// tr[-1] is out of bounds.
LL query(int node,LL begin,LL end,LL left,LL right)
{
    if(node==-1)
        return 0;
    if(left<=begin&&end<=right)
        return tr[node].sum;
    // Nothing stored below: no need to descend.
    if(tr[node].sum==0)
        return 0;
    LL m=(begin+end)>>1;
    LL ret=0;
    if(left<=m)
        ret+=query(tr[node].lch,begin,m,left,right);
    if(right>m)
        ret+=query(tr[node].rch,m+1,end,left,right);
    return ret;
}
int tag[maxn+5];
LL ans;
void dfs(int root)
{
int len1=v[root].size();
for(int i=0;i<len1;i++)
{
int w=v[root][i];
ans+=query(1,0,len,0,b[w]);
update(1,0,len,a[w],1);
dfs(w);
update(1,0,len,a[w],-1);
}
}
// Reset per-case state: empty the node pool, re-create the root node
// (index 1), and clear the parent counters.
void init()
{
    p=0;
    memset(tag,0,sizeof(tag));
    newnode();
}
int main()
{
    int t;
    scanf("%d",&t);
    int x,y;
    while(t--)
    {
        scanf("%d%lld",&n,&k);
        for(int i=1;i<=n;i++)
        {
            scanf("%lld",&a[i]);
            // a[i]==0 pairs with every ancestor (0 * a_u = 0 <= k): query the whole
            // range [0, len] instead of dividing by zero.
            b[i]= a[i]==0 ? len : k/a[i];
            v[i].clear();
        }
        init();
        for(int i=1;i<=n-1;i++)
        {
            scanf("%d%d",&x,&y);
            v[x].push_back(y);
            tag[y]++;
        }
        int root=1;
        for(int i=1;i<=n;i++)
        {
            if(!tag[i])
                root=i;
        }
        ans=0;
        update(1,0,len,a[root],1);
        dfs(root);
        printf("%lld\n",ans);
    }
    return 0;
}
```

```cpp
#include <iostream>
#include <string.h>
#include <stdlib.h>
#include <stdio.h>
#include <algorithm>
#include <math.h>

using namespace std;
const int maxn=1e5;
typedef long long int LL;
int rt[maxn+5];         // rt[i]: root of node i's value tree (0 = empty)
int ls[maxn*100+5];     // pool: left child
int rs[maxn*100+5];     // pool: right child
LL sum[maxn*100+5];     // pool: number of values stored below the node
int a[maxn+5];          // node values
LL k;
int n;
int p;                  // next free pool slot; slot 0 is the shared empty node
int l,r;                // value range [l, r] of the current case

// Take the next pool slot as an empty node.
int newnode()
{
    sum[p]=ls[p]=rs[p]=0;
    return p++;
}
// Insert the single value val into the empty tree `node` over [begin, end].
void build(int &node,int begin,int end,LL val)
{
    if(!node) node=newnode();
    sum[node]=1;
    if(begin==end) return;
    LL mid=(begin+end)>>1;
    if(val<=mid) build(ls[node],begin,mid,val);
    else build(rs[node],mid+1,end,val);
}
// Number of stored values <= val in the tree `node` over [begin, end].
LL Query(int node,int begin,int end,LL val)
{
    if(!node||val<begin) return 0;
    if(begin==end) return sum[node];
    LL mid=(begin+end)>>1;
    if(val<=mid) return Query(ls[node],begin,mid,val);
    else return sum[ls[node]]+Query(rs[node],mid+1,end,val);
}
// Merge tree y into tree x; x reuses y's nodes where it has none of its own.
void mergge(int &x,int y, int begin,int end)
{
    if(!x||!y) {x=x^y;return;}
    sum[x]+=sum[y];
    if(begin==end) return;
    LL mid=(begin+end)>>1;
    mergge(ls[x],ls[y],begin,mid);
    mergge(rs[x],rs[y],mid+1,end);
}
// Adjacency list of parent -> child edges.
struct Node
{
    int value;          // child endpoint
    int next;           // next edge from the same parent; -1 ends the list
}edge[maxn*2+5];
int head[maxn+5];       // first edge out of each node; -1 if none
int tot;                // number of edges stored
void add(int x,int y)
{
    edge[tot].value=y;
    edge[tot].next=head[x];
    head[x]=tot++;
}
LL ans;
// Post-order: fold every child's tree into rt[root], then count the nodes of
// root's subtree whose value pairs with a[root].
void dfs(int root)
{
    for(int i=head[root];i!=-1;i=edge[i].next)
    {
        int w=edge[i].value;
        dfs(w);
        mergge(rt[root],rt[w],l,r);
    }
    // a[root]==0 pairs with everything (0 * x = 0 <= k); otherwise a[v] <= k / a[root].
    LL limit= a[root]==0 ? (LL)r : k/a[root];
    ans+=Query(rt[root],l,r,limit);
    // rt[root] also holds root's own value, but root is not its own ancestor.
    if(k>=1ll*a[root]*a[root])
        ans--;
}
int tag[maxn+5];        // number of parents per node; 0 marks the root
int main()
{
    int t;
    scanf("%d",&t);
    int x,y;
    while(t--)
    {
        scanf("%d%lld",&n,&k);
        p=1;
        memset(tag,0,sizeof(tag));
        memset(head,-1,sizeof(head));
        tot=0;
        l=1e9;r=0;
        for(int i=1;i<=n;i++)
        {
            scanf("%d",&a[i]);
            l=min(l,a[i]);r=max(r,a[i]);
        }
        // One single-value tree per node; dfs merges them bottom-up.
        for(int i=1;i<=n;i++)
            build(rt[i]=0,l,r,a[i]);

        for(int i=1;i<=n-1;i++)
        {
            scanf("%d%d",&x,&y);
            add(x,y);
            tag[y]++;
        }
        int root=1;
        for(int i=1;i<=n;i++)
            if(tag[i]==0) root=i;
        ans=0;
        dfs(root);
        printf("%lld\n",ans);
    }
    return 0;
}
```

```cpp
#include <iostream>
#include <string.h>
#include <stdlib.h>
#include <stdio.h>
#include <algorithm>
#include <math.h>
#include <queue>

using namespace std;
const int maxn=1e5;
typedef long long int LL;
int rt[maxn*100+5];   // rt[i]: root of node i's value tree (0 = empty)
int ls[maxn*100+5];   // pool: left child
int rs[maxn*100+5];   // pool: right child
LL sum[maxn*100+5];   // pool: number of values stored below the node
int a[maxn+5];        // a[i]: value of node i
int f[maxn+5];        // f[y]: parent of y
LL k;
int n;
int p;                // next free pool slot
int l,r;              // value range [l, r] of the current case
std::queue<int> q;    // nodes whose children have all been merged in
// Hand out the next pool slot `p` as an empty node (count 0, no children).
int newnode()
{
    const int id=p++;
    sum[id]=0;
    ls[id]=0;
    rs[id]=0;
    return id;
}
// Create the root-to-leaf path for value `val` in tree `node` over
// [begin, end], setting every count on the path to 1 (assumes the tree holds
// no other value yet).
void build(int &node,int begin,int end,LL val)
{
    int *slot=&node;
    while(true)
    {
        if(!*slot) *slot=newnode();
        int cur=*slot;
        sum[cur]=1;
        if(begin==end) return;
        LL mid=(begin+end)>>1;
        if(val<=mid)
        {
            slot=&ls[cur];
            end=mid;
        }
        else
        {
            slot=&rs[cur];
            begin=mid+1;
        }
    }
}
// Count the stored values <= val in the tree rooted at `node`, which covers
// [begin, end]. Node 0 is an empty subtree.
LL Query(int node,int begin,int end,LL val)
{
    LL total=0;
    while(node && val>=begin)
    {
        if(begin==end) return total+sum[node];
        LL mid=(begin+end)>>1;
        if(val>mid)
        {
            // The whole left half is <= val.
            total+=sum[ls[node]];
            node=rs[node];
            begin=mid+1;
        }
        else
        {
            node=ls[node];
            end=mid;
        }
    }
    return total;
}
// Merge tree y into tree x (both over [begin, end]); where x has no node,
// it adopts y's subtree as-is.
void mergge(int &x,int y, int begin,int end)
{
    if(x==0)
    {
        x=y;
        return;
    }
    if(y==0) return;
    sum[x]+=sum[y];
    if(begin<end)
    {
        LL mid=(begin+end)>>1;
        mergge(ls[x],ls[y],begin,mid);
        mergge(rs[x],rs[y],mid+1,end);
    }
}
LL ans;
int tag[maxn+5];
int main()
{
int t;
scanf("%d",&t);
int x,y;
while(t--)
{
scanf("%d%lld",&n,&k);
p=1;
memset(tag,0,sizeof(tag));
l=1e9;r=0;
for(int i=1;i<=n;i++)
{
scanf("%d",&a[i]);
l=min(l,a[i]);r=max(r,a[i]);
}
for(int i=1;i<=n;i++)
build(rt[i]=0,l,r,a[i]);

for(int i=1;i<=n-1;i++)
{
scanf("%d%d",&x,&y);
tag[x]++;
f[y]=x;
}
for(int i=1;i<=n;i++)
{
if(tag[i]==0)
q.push(i);
}
ans=0;
while(!q.empty())
{
int x=q.front();q.pop();
if(1LL*a[x]*a[x]<=k) ans--;
ans+=Query(rt[x],l,r,k/a[x]);
mergge(rt[f[x]],rt[x],l,r);
if(!--tag[f[x]]) q.push(f[x]);

}
printf("%lld\n",ans);
}
return 0;
}```

```cpp
#include <iostream>
#include <string.h>
#include <stdlib.h>
#include <stdio.h>
#include <algorithm>
#include <math.h>
#include <stack>

using namespace std;
const int maxn=1e5;
typedef long long int LL;
int rt[maxn+5];          // rt[u]: version holding every node entered up to u (preorder)
int ls[maxn*100+5];      // pool: left child
int rs[maxn*100+5];      // pool: right child
LL sum[maxn*100+5];      // pool: number of values below the node
int p;                   // next free pool slot; slot 0 is the shared empty tree
int n;
LL k;
int l,r;                 // value range [l, r] of the current case
// Persistent insert: replace `node` with a new version that also contains
// val, copying only the root-to-leaf path.
void update(int &node,int l,int r,int val)
{
    ls[p]=ls[node];rs[p]=rs[node];
    sum[p]=sum[node]+1;
    node=p++;
    if(l==r) return;
    int mid=(l+r)>>1;
    if(val<=mid) update(ls[node],l,mid,val);
    else update(rs[node],mid+1,r,val);
}
// Count the values <= val in version `node` over [l, r].
LL query(int node,int l,int r,LL val)
{
    if(val<l) return 0;
    if(!node) return 0;
    if(l==r) return sum[node];
    LL mid=(l+r)>>1;
    if(val<=mid) return query(ls[node],l,mid,val);
    else return sum[ls[node]]+query(rs[node],mid+1,r,val);
}
// Adjacency list of parent -> child edges.
struct Node
{
    int value;           // child endpoint
    int next;            // next edge from the same parent; -1 ends the list
}edge[maxn*2+5];
int head[maxn+5];        // first edge out of each node; -1 if none
int tot;                 // number of edges stored
void add(int x,int y)
{
    edge[tot].value=y;
    edge[tot].next=head[x];
    head[x]=tot++;
}
int res[maxn*2+5];       // Euler tour: each node appears on entry and on exit
int a[maxn+5];
int cot;                 // length of res
void dfs(int root)
{
    res[cot++]=root;
    for(int i=head[root];i!=-1;i=edge[i].next)
        dfs(edge[i].value);
    res[cot++]=root;
}
int tag[maxn+5];         // number of parents per node; 0 marks the root
int flag[maxn+5];        // 1 once a node's entry in res has been processed
int main()
{
    int t;
    scanf("%d",&t);
    int x,y;
    while(t--)
    {
        scanf("%d%lld",&n,&k);
        l=1e9;r=0;
        for(int i=1;i<=n;i++)
        {
            scanf("%d",&a[i]);
            l=min(l,a[i]);r=max(r,a[i]);
        }

        memset(tag,0,sizeof(tag));
        memset(head,-1,sizeof(head));
        tot=0;
        p=1;
        for(int i=1;i<=n-1;i++)
        {
            scanf("%d%d",&x,&y);
            add(x,y);
            tag[y]++;
        }
        int root=1;
        for(int i=1;i<=n;i++)
        {
            if(!tag[i])
                root=i;
        }
        cot=0;
        dfs(root);
        memset(flag,0,sizeof(flag));
        // Start from the empty tree: rt[root] may still hold a version index from
        // the previous case, and the pool was just reset (p=1).
        rt[res[0]]=0;
        update(rt[res[0]],l,r,a[res[0]]);
        flag[res[0]]=1;
        LL ans=0;
        int now=0;   // position in res of the most recently entered node
        for(int i=1;i<cot;i++)
        {
            int u=res[i];
            if(flag[u]==1)
            {
                // Exit of u: the last node entered lies in u's subtree, so its version
                // minus u's version counts exactly u's strict descendants.
                LL limit= a[u]==0 ? (LL)r : k/a[u];
                ans+=query(rt[res[now]],l,r,limit)-query(rt[u],l,r,limit);
                continue;
            }
            flag[u]=1;
            update(rt[u]=rt[res[now]],l,r,a[u]);
            now=i;
        }
        printf("%lld\n",ans);
    }
    return 0;
}
```

