REBIT - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2

Setter: Kritagya Agarwal
Tester: Felipe Mota
Editorialist: Taranpreet Singh

DIFFICULTY:

Easy-Med

PREREQUISITES:

Probabilities and Implementation would do.

PROBLEM:

Given a logical expression in which every operand has been replaced by ‘#’, find the probability that the expression evaluates to each of ‘0’, ‘1’, ‘a’ and ‘A’, if each ‘#’ takes one of the values ‘0’, ‘1’, ‘a’ and ‘A’ uniformly at random, independently of the others.

QUICK EXPLANATION

  • We can convert the given logical expression into a full binary tree, where leaves represent operands and non-leaf nodes represent operators. The subtree of a node represents a sub-expression of the given expression, and the root node represents the whole expression.
  • We can compute for each node the probability of obtaining ‘0’, ‘1’, ‘a’ and ‘A’ in a bottom-up manner, from the leaves towards the root, by brute-forcing every pair of possible values for the left and right child.

EXPLANATION

First of all, we need to note that the given expression can also be represented as a full binary tree, also known as Binary Expression Tree.

This tree has the following properties.

  • Each leaf node represents an operand, while each non-leaf node represents an operator and has exactly two children, which are the operands of this operator.
  • The subtree rooted at each node corresponds to some sub-expression of the given expression.

The binary expression tree is easy to construct using a stack, or by making a few observations about the structure of the expression, as nicely explained here as well as here.
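As a rough illustration of the "observations" route, here is a minimal recursive-descent sketch. It assumes the expression is fully parenthesised, i.e. it follows the grammar E := # | (E op E), which is what the tester's and editorialist's solutions check with their assertions. The names `Node`, `parse` and `pool` are hypothetical and chosen only for this example.

```cpp
#include <cassert>
#include <string>
#include <vector>

// One node of the expression tree. Children are indices into the node pool,
// or -1 for a leaf. (Hypothetical layout, for illustration only.)
struct Node {
    char symbol;  // '#' for a leaf, otherwise '&', '|' or '^'
    int left;
    int right;
};

// Parses the sub-expression that starts at s[pos], appends its nodes to pool,
// advances pos past the sub-expression and returns the index of its root.
int parse(const std::string& s, std::size_t& pos, std::vector<Node>& pool) {
    assert(pos < s.size());
    if (s[pos] == '#') {  // a single operand is a leaf
        ++pos;
        pool.push_back({'#', -1, -1});
        return static_cast<int>(pool.size()) - 1;
    }
    assert(s[pos] == '(');
    ++pos;                            // skip '('
    int left = parse(s, pos, pool);   // left operand
    char op = s[pos++];               // operator between the operands
    int right = parse(s, pos, pool);  // right operand
    assert(s[pos] == ')');
    ++pos;                            // skip ')'
    pool.push_back({op, left, right});
    return static_cast<int>(pool.size()) - 1;
}
```

Calling `parse` with `pos = 0` on `((#|#)&#)` yields five nodes with an `&` at the root. The recursion depth grows with the nesting depth of the expression, so a very deep input may need a larger stack or the iterative stack-based construction instead.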

Now, we have constructed the binary tree. We need to compute the probabilities of each value for each node in the tree.

  • For a leaf, all four values are equiprobable, thus equal to 1/4.
  • For a non-leaf node, we can consider each pair of values for the left and right subtree. Since the operator is fixed, we can find the probability of each value appearing at the current node.
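The pairwise step for a non-leaf node can be sketched as follows. This version uses plain `double` probabilities for readability rather than the modular arithmetic the judge expects. It also swaps in an encoding that is not in the reference solutions: each value is mapped to a 2-bit mask (0 as 00, a as 01, A as 10, 1 as 11), so the three operators become ordinary bitwise operators. That mapping agrees with the case analysis in the tester's solution, but the names `Dist`, `kMask`, `kIndex`, `apply` and `combine` are assumptions made for this example.

```cpp
#include <array>
#include <cassert>

using Dist = std::array<double, 4>;  // probabilities of 0, 1, a, A, in this order

// Assumed encoding: position in Dist -> 2-bit mask (0 = 00, 1 = 11, a = 01, A = 10).
constexpr int kMask[4] = {0, 3, 1, 2};
// Inverse mapping: 2-bit mask -> position in Dist.
constexpr int kIndex[4] = {0, 2, 3, 1};

// Result of (x op y), where x, y and the result are positions in Dist.
int apply(char op, int x, int y) {
    int mx = kMask[x], my = kMask[y];
    int m = (op == '&') ? (mx & my) : (op == '|') ? (mx | my) : (mx ^ my);
    return kIndex[m];
}

// Distribution of (L op R), given independent distributions of L and R.
Dist combine(char op, const Dist& l, const Dist& r) {
    assert(op == '&' || op == '|' || op == '^');
    Dist res{};  // all zeros
    for (int x = 0; x < 4; ++x)
        for (int y = 0; y < 4; ++y)
            res[apply(op, x, y)] += l[x] * r[y];  // one of the 16 value pairs
    return res;
}
```

For two leaves joined by `|`, `combine` gives 1/16, 9/16, 3/16, 3/16 for 0, 1, a, A respectively.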

Applying this in a bottom-up manner, we can recursively compute the probability of each value at the root node. Since the root represents the whole expression, these probabilities are the required answer.
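Putting the pieces together, a compact end-to-end sketch might look like the following. It evaluates the string directly instead of materialising the tree, and works modulo 998244353 as the reference solutions do. The constant 748683265 for 1/4 is taken from the editorialist's code. The names `kMod`, `kQuarter`, `Dist`, `apply`, `eval` and `solve` are hypothetical, and the input is assumed to be a valid, fully parenthesised expression.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <string>

constexpr std::int64_t kMod = 998244353;
constexpr std::int64_t kQuarter = 748683265;  // 1/4 modulo kMod
using Dist = std::array<std::int64_t, 4>;     // P(0), P(1), P(a), P(A) modulo kMod

// Result of (x op y) in the encoding 0 -> 0, 1 -> 1, a -> 2, A -> 3.
int apply(char op, int x, int y) {
    if (op == '^') return x ^ y;
    int absorbing = (op == '&') ? 0 : 1;  // 0 absorbs under &, 1 absorbs under |
    int identity = 1 - absorbing;         // 1 is neutral for &, 0 is neutral for |
    if (x == y) return x;
    if (x == identity) return y;
    if (y == identity) return x;
    return absorbing;  // an absorbing operand, or the pair {a, A}
}

// Evaluates the sub-expression starting at s[pos] and advances pos past it.
Dist eval(const std::string& s, std::size_t& pos) {
    if (s[pos] == '#') {
        ++pos;
        return {kQuarter, kQuarter, kQuarter, kQuarter};
    }
    assert(s[pos] == '(');
    ++pos;
    Dist l = eval(s, pos);
    char op = s[pos++];
    Dist r = eval(s, pos);
    assert(s[pos] == ')');
    ++pos;
    Dist res{};
    for (int x = 0; x < 4; ++x)
        for (int y = 0; y < 4; ++y) {
            int v = apply(op, x, y);
            res[v] = (res[v] + l[x] * r[y]) % kMod;
        }
    return res;
}

Dist solve(const std::string& s) {
    std::size_t pos = 0;
    Dist res = eval(s, pos);
    assert(pos == s.size());  // the whole string must be consumed
    return res;
}
```

On the test case posted in the replies, `solve("((#|#)|#)")` gives `982646785 233963521 889061377 889061377`, matching the correct output quoted there.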

Exercise: Prove that the length of a valid expression is always of the form $4p+1$ for some integer $p \geq 0$.
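One possible route for the exercise is structural induction, assuming the grammar E := # | (E op E) that the solutions rely on. Here $|E|$ denotes the length of the string E, and $E_1$, $E_2$ are the two operands. This is only a sketch of the key step:

```latex
\begin{aligned}
|\#| &= 1 = 4 \cdot 0 + 1, \\
|(E_1 \;\mathrm{op}\; E_2)| &= |E_1| + |E_2| + 3 \\
  &= (4p + 1) + (4q + 1) + 3 \\
  &= 4(p + q + 1) + 1 .
\end{aligned}
```

The constant 3 accounts for the two brackets and the operator, and $p$, $q$ come from the induction hypothesis applied to the two operands.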

TIME COMPLEXITY

The overall time complexity is $O(16 \cdot \sum |L|)$, where $|L|$ is the length of an expression, summed over all test cases, and 16 is the number of value pairs considered at each node.

SOLUTIONS:

Setter's Solution
/*
Author : Kritagya Agarwal
April Long Challenge
REBIT
*/

#include<bits/stdc++.h>
#define ll long long int
#define MAX 5005
#define M 998244353
#define ld long long int
using namespace std;

struct node
{
	char number;
	ll sum;
	node * left;
	node * right;
	ll dp[4];
};

node * newNode(char u)
{
	node * temp = new node();
	temp->left = NULL;
	temp->right = NULL;
	temp->dp[0] = temp->dp[1] = temp->dp[2] = temp->dp[3] = 1;
	temp->sum = 4;
	temp->number = u;
	return temp;
}

ll poe(ll a, ll n)
{
	if(n == 0) return 1;

	ll ans = 1;
	ll vl = a;

	while(n)
	{
	    if(n%2)
	    {
	        ans *= vl;
	        ans %= M;
	    }

	    vl *= vl;
	    vl %= M;
	    n /= 2;
	}

	return ans;
}

node * constructTree(string s)
{
	stack<node *> st;
	node * t, *t1, *t2;

	for(int i = 0 ; i < s.size() ; i++)
	{
	    if(s[i] == '#')
	    {
	        t = newNode(s[i]);
	        st.push(t);
	    }  

	    else
	    {
	        t = newNode(s[i]);

	        t1 = st.top();
	        st.pop();
	        t2 = st.top();
	        st.pop();

	        t->right = t1;
	        t->left = t2;

	        st.push(t);
	    }
	}

	t = st.top();
	st.pop();

	return t;
}

 
int prec(char c) 
{ 
	if(c == '^') 
	return 1; 
	else if(c == '&') 
	return 3; 
	else if(c == '|') 
	return 2; 
	else
	return -1; 
} 
  
string infixToPostfix(string s) 
{ 
	std::stack<char> st; 
	st.push('N'); 
	int l = s.length(); 
	string ns; 
	for(int i = 0; i < l; i++) 
	{ 
	    // If the scanned character is an operand, add it to output string. 
	    if(s[i] == '#') 
	    ns += s[i]; 
  
	    // If the scanned character is an ‘(‘, push it to the stack. 
	    else if(s[i] == '(') 
	      
	    st.push('('); 
	      
	    // If the scanned character is a ‘)’, pop from the stack and append to the
	    // output string until an ‘(‘ is encountered. 
	    else if(s[i] == ')') 
	    { 
	        while(st.top() != 'N' && st.top() != '(') 
	        { 
	            char c = st.top(); 
	            st.pop(); 
	           ns += c; 
	        } 
	        if(st.top() == '(') 
	        { 
	            char c = st.top(); 
	            st.pop(); 
	        } 
	    } 
	      
	    //If an operator is scanned 
	    else{ 
	        while(st.top() != 'N' && prec(s[i]) <= prec(st.top())) 
	        { 
	            char c = st.top(); 
	            st.pop(); 
	            ns += c; 
	        } 
	        st.push(s[i]); 
	    } 
	} 
	//Pop all the remaining elements from the stack 
	while(st.top() != 'N') 
	{ 
	    char c = st.top(); 
	    st.pop(); 
	    ns += c; 
	} 
	  
	return ns;
  
} 

void process(node * root)
{
	if(root->left == NULL && root->right == NULL)
	{
	    return;
	}
	
	process(root->left);
	process(root->right);
	char op = root->number;

	if(op == '&')
	{
	    root->dp[0] = ((root->left->dp[0])*(root->right->sum))%M
	    + ((root->left->sum)*(root->right->dp[0]))%M +
	    ((root->left->dp[2])*(root->right->dp[3]))%M +
	    ((root->left->dp[3])*(root->right->dp[2]))%M -
	    ((root->left->dp[0])*(root->right->dp[0]))%M + M ;
	    root->dp[0] %= M;

	    root->dp[1] = ((root->left->dp[1])*(root->right->dp[1]))%M;
	    root->dp[1] %= M;

	    root->dp[2] = ((root->left->dp[2])*(root->right->dp[2]))%M
	    + ((root->left->dp[2])*(root->right->dp[1]))%M +
	    ((root->left->dp[1])*(root->right->dp[2]))%M ;
	    root->dp[2] %= M;

	    root->dp[3] = ((root->left->dp[3])*(root->right->dp[3]))%M
	    + ((root->left->dp[3])*(root->right->dp[1]))%M +
	    ((root->left->dp[1])*(root->right->dp[3]))%M;
	    root->dp[3] %= M;
	}
	else if( op == '|')
	{
	    root->dp[0] = ((root->left->dp[0])*(root->right->dp[0]))%M;
	    root->dp[0] %= M;

	    root->dp[1] = ((root->left->sum)*(root->right->dp[1]))%M 
	    + ((root->left->dp[1])*(root->right->sum))%M + 
	    ((root->left->dp[2])*(root->right->dp[3]))%M +
	    ((root->left->dp[3])*(root->right->dp[2]))%M -
	    ((root->left->dp[1])*(root->right->dp[1]))%M + M;
	    root->dp[1] %= M;

	    root->dp[2] = ((root->left->dp[2])*(root->right->dp[2]))%M
	    + ((root->left->dp[2])*(root->right->dp[0]))%M +
	    ((root->left->dp[0])*(root->right->dp[2]))%M;
	    root->dp[2] %= M;

	    root->dp[3] = ((root->left->dp[3])*(root->right->dp[3]))%M
	    + ((root->left->dp[3])*(root->right->dp[0]))%M +
	    ((root->left->dp[0])*(root->right->dp[3]))%M;
	    root->dp[3] %= M;          
	}
	else if( op == '^')
	{
	    root->dp[0] = ((root->left->dp[0])*(root->right->dp[0]))%M
	    + ((root->left->dp[1])*(root->right->dp[1]))%M +
	    ((root->left->dp[2])*(root->right->dp[2]))%M +
	    ((root->left->dp[3])*(root->right->dp[3]))%M; 
	    root->dp[0] %= M;

	    root->dp[1] = ((root->left->dp[1])*(root->right->dp[0]))%M
	    + ((root->left->dp[0])*(root->right->dp[1]))%M +
	    ((root->left->dp[2])*(root->right->dp[3]))%M +
	    ((root->left->dp[3])*(root->right->dp[2]))%M ;
	    root->dp[1] %= M;

	    root->dp[2] = ((root->left->dp[1])*(root->right->dp[3]))%M
	    + ((root->left->dp[3])*(root->right->dp[1]))%M + 
	    ((root->left->dp[2])*(root->right->dp[0]))%M +
	    ((root->left->dp[0])*(root->right->dp[2]))%M ;
	    root->dp[2] %= M;

	    root->dp[3] = ((root->left->dp[1])*(root->right->dp[2]))%M
	    + ((root->left->dp[2])*(root->right->dp[1]))%M + 
	    ((root->left->dp[3])*(root->right->dp[0]))%M +
	    ((root->left->dp[0])*(root->right->dp[3]))%M ;
	    root->dp[3] %= M;
	}

	root->sum = root->dp[0] + root->dp[1] + root->dp[2] + root->dp[3];
	root->sum %= M;
}

int main()
{
	int t;
	cin>>t;

	for(int k = 1 ; k <= t ; k++)
	{
	    string s;
	    cin>>s;
	    string post = infixToPostfix(s);
	    int length = s.size();
	    
	 
	    int ans = 0;
	    char an;
	    ll sum = 0;
	    if(length == 1)
	    {
	        ld f0 = 1;
	        ld f1 = 1;
	        ld f2 = 1;
	        ld f3 = 1;
	        
	        ld fn = f0 + f1 + f2 + f3;
	        fn %= M;

	        f0 *= poe(fn,M-2);
	        f0 %= M;

	        f1 *= poe(fn,M-2);
	        f1 %= M;

	        f2 *= poe(fn,M-2);
	        f2 %= M;

	        f3 *= poe(fn,M-2);
	        f3 %= M;

	    
	        cout<<f0<<" "<<f1<<" "<<f2<<" "<<f3<<endl;

	    }
	    else
	    {
	        node * root = constructTree(post);
	        process(root);
	        
	        ld f0 = root->dp[0];
	        ld f1 = root->dp[1];
	        ld f2 = root->dp[2];
	        ld f3 = root->dp[3];
	        
	        ld fn = f0 + f1 + f2 + f3;
	        fn %= M;

	        f0 *= poe(fn,M-2);
	        f0 %= M;

	        f1 *= poe(fn,M-2);
	        f1 %= M;

	        f2 *= poe(fn,M-2);
	        f2 %= M;

	        f3 *= poe(fn,M-2);
	        f3 %= M;

	        cout<<f0<<" "<<f1<<" "<<f2<<" "<<f3<<endl;
	    }
	}
}
Tester's Solution
#include <bits/stdc++.h>
using namespace std;
template<typename T = int> vector<T> create(size_t n){ return vector<T>(n); }
template<typename T, typename... Args> auto create(size_t n, Args... args){ return vector<decltype(create<T>(args...))>(n, create<T>(args...)); }
template<typename T = int, T mod = 1'000'000'007, typename U = long long>
struct umod{
	T val;
	umod(): val(0){}
	umod(U x){ x %= mod; if(x < 0) x += mod; val = x;}
	umod& operator += (umod oth){ val += oth.val; if(val >= mod) val -= mod; return *this; }
	umod& operator -= (umod oth){ val -= oth.val; if(val < 0) val += mod; return *this; }
	umod& operator *= (umod oth){ val = ((U)val) * oth.val % mod; return *this; }
	umod& operator /= (umod oth){ return *this *= oth.inverse(); }
	umod& operator ^= (U oth){ return *this = pwr(*this, oth); }
	umod operator + (umod oth) const { return umod(*this) += oth; }
	umod operator - (umod oth) const { return umod(*this) -= oth; }
	umod operator * (umod oth) const { return umod(*this) *= oth; }
	umod operator / (umod oth) const { return umod(*this) /= oth; }
	umod operator ^ (long long oth) const { return umod(*this) ^= oth; }
	bool operator < (umod oth) const { return val < oth.val; }
	bool operator > (umod oth) const { return val > oth.val; }
	bool operator <= (umod oth) const { return val <= oth.val; }
	bool operator >= (umod oth) const { return val >= oth.val; }
	bool operator == (umod oth) const { return val == oth.val; }
	bool operator != (umod oth) const { return val != oth.val; }
	umod pwr(umod a, U b) const { umod r = 1; for(; b; a *= a, b >>= 1) if(b&1) r *= a; return r; }
	umod inverse() const {
	    U a = val, b = mod, u = 1, v = 0;
	    while(b){
	        U t = a/b;
	        a -= t * b; swap(a, b);
	        u -= t * v; swap(u, v);
	    }
	    if(u < 0)
	        u += mod;
	    return u;
	}
};
using U = umod<int, 998244353>;
int main(){
	ios::sync_with_stdio(false);
	cin.tie(0);
	int t;
	cin >> t;
	while(t--){
		string s;
		cin >> s;
		/**
		Run through the evaluation tree and calculate probabilities.
		Use the accumulated sum to find the correct point k at which to divide the interval (i, j).
 
		Probabilities for 0, 1, a, A
		 * */
		int n = s.size();
		vector<int> sum(n + 1, 0);
		for(int i = 0; i < n; i++){
			sum[i + 1] = sum[i];
			if(s[i] == '(') sum[i + 1]++;
			if(s[i] == ')') sum[i + 1]--;
		}
		vector<vector<int>> at(n);
		for(int i = 0; i <= n; i++){
			assert(sum[i] >= 0);
			at[sum[i]].push_back(i);
		}
		U i4 = U(1)/4;
		vector<U> op = {i4, i4, i4, i4};
		function<vector<U>(int,int)> solve = [&](int i, int j){
			if(i == j) return op;
			assert(s[i] == '(');
			assert(s[j] == ')');
			int k = *lower_bound(at[sum[i + 1]].begin(), at[sum[i + 1]].end(), i + 2);
			auto l = solve(i + 1, k - 1);
			auto r = solve(k + 1, j - 1);
			vector<U> res(4, 0);
			if(s[k] == '&'){
				for(int x = 0; x < 4; x++){
					for(int y = 0; y < 4; y++){
						if(x == 0 || y == 0) res[0] += l[x] * r[y];
						else if(x == 1) res[y] += l[x] * r[y];
						else if(y == 1) res[x] += l[x] * r[y];
						else if(x == y) res[x] += l[x] * r[y];
						else res[0] += l[x] * r[y];
					}
				}
			} else if(s[k] == '|'){
				for(int x = 0; x < 4; x++){
					for(int y = 0; y < 4; y++){
						if(x == 1 || y == 1) res[1] += l[x] * r[y];
						else if(x == 0) res[y] += l[x] * r[y];
						else if(y == 0) res[x] += l[x] * r[y];
						else if(x == y) res[x] += l[x] * r[y];
						else res[1] += l[x] * r[y];
					}
				}
			} else if(s[k] == '^'){
				for(int x = 0; x < 4; x++){
					for(int y = 0; y < 4; y++){
						if(x == 0 || y == 0) res[x ^ y] += l[x] * r[y];
						else if(x == y) res[0] += l[x] * r[y];
						else if(x == 1) res[y ^ 1] += l[x] * r[y];
						else if(y == 1) res[x ^ 1] += l[x] * r[y];
						else res[1] += l[x] * r[y];
					}
				}
			}
			return res;
		};
		auto ans = solve(0, n - 1);
		for(int i = 0; i < 4; i++){
			if(i) cout << ' ';
			cout << ans[i].val;
		}
		cout << '\n';
	}
	return 0;
}
Editorialist's Solution
import java.util.*;
import java.io.*;
import java.text.*;
class REBIT{
	//SOLUTION BEGIN
	long MOD = 998244353;
	void pre() throws Exception{}
	void solve(int TC) throws Exception{
	    s = n();pos = 0;
	    hold(s.length()%4 == 1);
	    Node root = new Node();
	    for(long l:root.f)p(l+" ");pn("");
	}
	String s;int pos = 0;//pos stores the next character to be considered
	class Node{
	    long[] f;
	    Node le, ri;
	    public Node() throws Exception{
	        if(s.charAt(pos) == '#'){
	            //Leaf node
	            f = new long[]{748683265,748683265 ,748683265 ,748683265};//1/4 for each value
	            le = ri = null;
	            pos++;
	        }else{
	            hold(s.charAt(pos++) == '(');
	            le = new Node();
	            char op = s.charAt(pos++);
	            ri = new Node();
	            hold(s.charAt(pos++) == ')');
	            
	            //Computing probabilities from children
	            f = new long[4];
	            for(int i = 0; i< 4; i++){
	                for(int j = 0; j< 4; j++){
	                    int res = op(i, j, op);
	                    f[res] = (f[res]+le.f[i]*ri.f[j])%MOD;
	                }
	            }
	        }
	    }
	}
	//return the result of a (op) b
	int op(int a, int b, char op){
	    switch(op){
	        case '&':
	            if(a == b)return a;
	            if(a == 0 || b == 0)return 0;
	            if(a == 1 || b == 1)return a+b-1;
	            return 0;
	        case '|':
	            if(a == b)return a;
	            if(a == 1 || b == 1)return 1;
	            if(a == 0 || b == 0)return a+b;
	            return 1;
	        case '^':
	            return a^b;
	        default: return -1;
	    }
	}
	//SOLUTION END
	void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
	DecimalFormat df = new DecimalFormat("0.00000000000");
	static boolean multipleTC = true;
	FastReader in;PrintWriter out;
	void run() throws Exception{
	    in = new FastReader();
	    out = new PrintWriter(System.out);
	    //Solution Credits: Taranpreet Singh
	    int T = (multipleTC)?ni():1;
	    pre();for(int t = 1; t<= T; t++)solve(t);
	    out.flush();
	    out.close();
	}
	public static void main(String[] args) throws Exception{
//        new REBIT().run();
	    new Thread(null, new Runnable() {public void run(){try{new REBIT().run();}catch(Exception e){e.printStackTrace();}}}, "1", 1 << 28).start();
	}
	int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
	void p(Object o){out.print(o);}
	void pn(Object o){out.println(o);}
	void pni(Object o){out.println(o);out.flush();}
	String n()throws Exception{return in.next();}
	String nln()throws Exception{return in.nextLine();}
	int ni()throws Exception{return Integer.parseInt(in.next());}
	long nl()throws Exception{return Long.parseLong(in.next());}
	double nd()throws Exception{return Double.parseDouble(in.next());}

	class FastReader{
	    BufferedReader br;
	    StringTokenizer st;
	    public FastReader(){
	        br = new BufferedReader(new InputStreamReader(System.in));
	    }

	    public FastReader(String s) throws Exception{
	        br = new BufferedReader(new FileReader(s));
	    }

	    String next() throws Exception{
	        while (st == null || !st.hasMoreElements()){
	            try{
	                st = new StringTokenizer(br.readLine());
	            }catch (IOException  e){
	                throw new Exception(e.toString());
	            }
	        }
	        return st.nextToken();
	    }

	    String nextLine() throws Exception{
	        String str = "";
	        try{   
	            str = br.readLine();
	        }catch (IOException e){
	            throw new Exception(e.toString());
	        }  
	        return str;
	    }
	}
}

Feel free to share your approach. Suggestions are welcomed as always. :slight_smile:

if someone needs video

Detailed solution in Layman Language.

https://www.codechef.com/viewsolution/31836777
Here is my solution, can anyone plz provide some test cases where it fails?

https://www.codechef.com/viewsolution/31838592
can you please provide test cases where program fails

If anyone wants a video explanation of expression trees : Expression Tree - SPOJ - Complicated Expressions - YouTube

I have used infix to postfix conversion to deal with the expressions.
Here is my python code: CodeChef: Practical coding for everyone

Correct me if I am wrong… shouldn't the value of fn in the main function always be pow(4, c), where c is the total number of ‘#’ present in the expression?

Yes. It should be pow(4, number of #). Also, to the setter: this was my favorite question of the contest, because I think coming up with such a question requires a hell of a lot of creative thinking. (idk, this was the first time I was dealing with such a problem.) my solution.

Let E be a valid expression,
E = (E1 op E2).
Then P[E] can be written in terms of combinations of P[E1](0,1,a,A) and P[E2](0,1,a,A) under the different operators.
The recursion can be carried out iteratively using a stack: whenever we find ‘#’ we push (1/4,1/4,1/4,1/4) onto the stack, and whenever we find a closing bracket we pop two values from the stack and push their result. Separate stacks can be maintained for the different probabilities. The final result is the value left in the stack.
Base case: when E is ‘#’, P[E](0,1,a,A) = {1/4,1/4,1/4,1/4}, as all values are equally probable.
Link to solution -
https://www.codechef.com/viewsolution/31584883
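The iterative idea described above could be sketched roughly like this. It is not the linked submission: it keeps a single stack of 4-element distributions instead of separate stacks per value, adds an operator stack that the description leaves implicit, and uses `double` instead of modular arithmetic. The hard-coded tables `kAnd` and `kOr` and the names `Dist` and `evaluate` are assumptions made for this example, with values indexed as 0, 1, a, A.

```cpp
#include <array>
#include <cassert>
#include <stack>
#include <string>

using Dist = std::array<double, 4>;  // P(0), P(1), P(a), P(A)

// Hard-coded result tables, indexed as [left][right] with 0, 1, a, A -> 0..3.
const int kAnd[4][4] = {{0, 0, 0, 0}, {0, 1, 2, 3}, {0, 2, 2, 0}, {0, 3, 0, 3}};
const int kOr[4][4] = {{0, 1, 2, 3}, {1, 1, 1, 1}, {2, 1, 2, 1}, {3, 1, 1, 3}};

Dist evaluate(const std::string& s) {
    std::stack<Dist> values;  // distributions of finished sub-expressions
    std::stack<char> ops;     // operators still waiting for their right operand
    for (char c : s) {
        if (c == '#') {
            values.push(Dist{0.25, 0.25, 0.25, 0.25});
        } else if (c == '&' || c == '|' || c == '^') {
            ops.push(c);
        } else if (c == ')') {
            assert(values.size() >= 2 && !ops.empty());
            Dist r = values.top(); values.pop();  // right operand is on top
            Dist l = values.top(); values.pop();
            char op = ops.top(); ops.pop();
            Dist res{};
            for (int x = 0; x < 4; ++x)
                for (int y = 0; y < 4; ++y) {
                    int v = (op == '&') ? kAnd[x][y] : (op == '|') ? kOr[x][y] : (x ^ y);
                    res[v] += l[x] * r[y];
                }
            values.push(res);
        }  // '(' needs no action
    }
    assert(values.size() == 1);
    return values.top();
}
```

Because every operator application is closed by its own `)`, the operator on top of `ops` always belongs to the two distributions on top of `values`, so no precedence handling is needed in this sketch.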

Wow I was wildly off on this one. I tried to use infix evaluation algorithm and hard-coding all probabilities. Couldn’t think of test cases that’d fail, but still didn’t get AC.

Try this:

2
(((#^#)|(#&#))&((#&#)^#))
((#|#)|#)

Your output
526417921 900759553 783777793 783777793
857866241 233963521 889061377 889061377

Correct output
526417921 900759553 783777793 783777793
982646785 233963521 889061377 889061377

I basically considered all possibilities like 0 xor {0,1,a,A}, 0 & {0,1,a,A},0 | {0,1,a,A} similarly for all 1,a,A. And with little observation it can be seen that the possibilities of {1,0,a,A} is getting multiplied with each expression. Code

I find something fishy in your expression on line 49, in the operation == ‘|’ branch. If you are using v[0] for the probability of zero, then it should be v[0]=v1[0]*v2[0], as 0=0|0, but you wrote something else. Please verify your findval function properly. If you commented your code, it would be easier for me to help.

I tried the above test case; mine is also printing the same value.
Let's take the equation ((#|#)|#).
The probabilities of 0, 1, a, A are 1/64, 49/64, 7/64, 7/64 respectively.
In the result, 857866241 is 1*inv(64)%998244353, so 49*inv(64)%998244353 should be 109182983, but it is 233963521.
But in my answer, 982646785 is 1*inv(64)%998244353, so 49*inv(64)%998244353 is 233963521. What am I missing? Could you please help me understand why I am getting this wrong answer?

Can Anyone help me in finding modulo of a fraction…??
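For a prime modulus such as 998244353, one common way to get a fraction p/q modulo the prime is to multiply p by the modular inverse of q, which by Fermat's little theorem is q raised to the power (modulus − 2). The setter's `poe(fn, M-2)` call does the same thing. A minimal sketch, assuming p is non-negative and q is not a multiple of the modulus (the names `kMod`, `power` and `fraction` are hypothetical):

```cpp
#include <cassert>
#include <cstdint>

constexpr std::int64_t kMod = 998244353;  // prime modulus of this problem

// base^exp modulo kMod by binary exponentiation.
std::int64_t power(std::int64_t base, std::int64_t exp) {
    std::int64_t result = 1;
    base %= kMod;
    while (exp > 0) {
        if (exp & 1) result = result * base % kMod;
        base = base * base % kMod;
        exp >>= 1;
    }
    return result;
}

// p / q modulo kMod, i.e. p * q^(kMod - 2) modulo kMod.
std::int64_t fraction(std::int64_t p, std::int64_t q) {
    assert(p >= 0 && q % kMod != 0);
    return (p % kMod) * power(q, kMod - 2) % kMod;
}
```

With these definitions, `fraction(1, 64)` is 982646785 and `fraction(49, 64)` is 233963521, which are the values that appear in the correct output quoted earlier in this thread.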

Thank you :")
brother

Thank you Thank you thank you ,
idk how that v2[1] came :sob: . now it passed.
Thanks

Sad it happens at times :pensive: Anyways welcome :sweat_smile:. Did you solve ppdiv I need some help

Thanks a lot it is really helpful.

https://www.codechef.com/viewsolution/31545809
Can i get any test cases where my program gives WA and NZEC ? i tried few test cases later with my AC code but all of them were giving the same answers