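// C++ program: count the subtrees of a colored tree (rooted at its first node)
// that contain every red node and no blue node, or every blue node and no red node.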
#include <bits/stdc++.h>
using namespace std;
// Depth-first traversal from node v that accumulates, for every reached node
// u, the number of red (sub_red[u]) and blue (sub_blue[u]) nodes in u's DFS
// subtree, and increments *ans once for each subtree that holds all `red`
// red nodes and no blue node, and once for each subtree that holds all
// `blue` blue nodes and no red node.
//
// color    : per-node color; 1 = red, 2 = blue, any other value is uncolored.
// red/blue : total number of red / blue nodes in the whole tree.
// sub_red, sub_blue : per-node accumulators; they are only added to, so they
//            should start at zero for the values to be subtree totals.
// vis      : per-node visited flags; 0 = not yet visited, set to 1 on visit.
// adj      : adjacency lists. Accessed with operator[], so a node without an
//            entry gets an empty list inserted.
// ans      : counter incremented for every qualifying subtree.
//
// The root's own subtree (the whole component) is checked too. When the tree
// has no blue (or no red) node it therefore also counts. Presumably callers
// guarantee at least one node of each color; TODO confirm.
//
// Implemented with an explicit stack instead of recursion, so very deep trees
// (e.g. a path of 10^5+ nodes) cannot overflow the call stack. Nodes are
// visited, finished and counted in exactly the order the recursive form
// would use.
void Solution_dfs(int v, int color[], int red,
                  int blue, int* sub_red,
                  int* sub_blue, int* vis,
                  map<int, vector<int> >& adj,
                  int* ans)
{
    // One pending "call" of the recursive formulation: the node, its
    // neighbour list (std::map references stay valid across insertions),
    // and the index of the next neighbour to examine.
    struct Frame {
        int node;
        const vector<int>* nbrs;
        std::size_t next;
    };

    vector<Frame> stack;
    vis[v] = 1;
    stack.push_back({ v, &adj[v], 0 });

    while (!stack.empty()) {
        Frame& top = stack.back();

        if (top.next < top.nbrs->size()) {
            const int child = (*top.nbrs)[top.next++];
            if (vis[child] == 0) {
                // Descend into the unvisited neighbour. push_back may
                // reallocate, so `top` is not used again in this iteration.
                vis[child] = 1;
                stack.push_back({ child, &adj[child], 0 });
            }
            continue;
        }

        // Every child of this node is finished: add the node's own color.
        const int u = top.node;
        if (color[u] == 1) {
            sub_red[u]++;
        }
        if (color[u] == 2) {
            sub_blue[u]++;
        }

        // Subtree has all red nodes and no blue node.
        if (sub_red[u] == red && sub_blue[u] == 0) {
            (*ans)++;
        }
        // Subtree has all blue nodes and no red node.
        if (sub_red[u] == 0 && sub_blue[u] == blue) {
            (*ans)++;
        }

        // "Return" to the parent, which absorbs this subtree's totals.
        stack.pop_back();
        if (!stack.empty()) {
            const int parent = stack.back().node;
            sub_red[parent] += sub_red[u];
            sub_blue[parent] += sub_blue[u];
        }
    }
}
// Return how many of the first n entries of `color` are red (value 1).
// A non-positive n yields 0 without touching `color`.
int countRed(int color[], int n)
{
    if (n <= 0) {
        return 0;
    }
    const int kRed = 1;
    return static_cast<int>(std::count(color, color + n, kRed));
}
// Return how many of the first n entries of `color` are blue (value 2).
// A non-positive n yields 0 without touching `color`.
int countBlue(int color[], int n)
{
    if (n <= 0) {
        return 0;
    }
    const int kBlue = 2;
    return static_cast<int>(std::count(color, color + n, kBlue));
}
// Append the first n - 1 edges of `edge` to the undirected adjacency map `m`.
// Edge endpoints are given 1-based and stored 0-based; each edge (a, b) adds
// b to m[a] and then a to m[b].
void buildTree(int edge[][2],
               map<int, vector<int> >& m,
               int n)
{
    for (int e = 0; e + 1 < n; ++e) {
        const int from = edge[e][0] - 1;
        const int to = edge[e][1] - 1;
        m[from].push_back(to);
        m[to].push_back(from);
    }
}
// Count the subtrees of the tree (rooted at node index 0) that contain every
// red node and no blue node, or every blue node and no red node, and write
// the count to standard output (no trailing newline).
//
// color : n entries; 1 = red, 2 = blue, anything else = uncolored.
// n     : number of nodes.
// edge  : n - 1 edges given as 1-based node pairs.
void countSubtree(int color[], int n,
                  int edge[][2])
{
    // An empty tree has nothing to count; the guard also avoids reading
    // color[0] out of bounds when the traversal starts at node 0.
    if (n <= 0) {
        cout << 0;
        return;
    }

    // Count the red and blue nodes in the whole tree.
    const int red = countRed(color, n);
    const int blue = countBlue(color, n);

    // Adjacency lists of the tree.
    map<int, vector<int> > adj;
    buildTree(edge, adj, n);

    // Per-node red/blue subtree counters and visited flags, zero-initialized.
    // std::vector instead of `int a[n + 3] = { 0 }`: variable-length arrays
    // are not standard C++ (Clang rejects initializing one) and would sit on
    // the stack, which large n can overflow.
    vector<int> sub_red(n + 3, 0);
    vector<int> sub_blue(n + 3, 0);
    vector<int> vis(n + 3, 0);

    int ans = 0;
    Solution_dfs(0, color, red, blue,
                 sub_red.data(), sub_blue.data(),
                 vis.data(), adj, &ans);

    cout << ans;
}
// Driver: the path 1 - 2 - 3 - 4 - 5 where node 1 is red, node 5 is blue and
// the rest are uncolored. Prints the count produced by countSubtree.
int main()
{
    int color[] = { 1, 0, 0, 0, 2 };
    int edge[][2] = {
        { 1, 2 },
        { 2, 3 },
        { 3, 4 },
        { 4, 5 },
    };
    const int N = static_cast<int>(sizeof(color) / sizeof(color[0]));

    countSubtree(color, N, edge);
    return 0;
}