PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Setter: Kritagya Agarwal
Tester: Felipe Mota
Editorialist: Taranpreet Singh
DIFFICULTY:
Easy-Med
PREREQUISITES:
Probabilities and Implementation would do.
PROBLEM:
Given a logical expression whose operands have all been replaced by ‘#’, find the probability that the expression evaluates to each of ‘0’, ‘1’, ‘a’ and ‘A’, if each ‘#’ independently takes one of the values ‘0’, ‘1’, ‘a’ and ‘A’ uniformly at random.
QUICK EXPLANATION
- We can convert the given logical expression into a full binary tree, where leaves represent operands and non-leaf nodes represent operators. The subtree rooted at a node represents a sub-expression of the given expression, and the root node represents the whole expression.
- We can recursively compute, for each node, the probability of obtaining ‘0’, ‘1’, ‘a’ and ‘A’ in a bottom-up manner: first compute the probabilities for the children, then combine them at the parent by brute-forcing every pair of possible values of the left and right child.
EXPLANATION
First of all, note that the given expression can be represented as a full binary tree, also known as a binary expression tree.
This tree has the following properties.
- Each leaf node represents an operand, while each non-leaf node represents an operator and has exactly two children representing the operands of this operator.
- Each subtree of this tree corresponds to some sub-expression of the given expression.
The binary expression tree is easy to build using a stack, or just by making some observations, as nicely explained here as well as here.
Now, we have constructed the binary tree. We need to compute the probabilities of each value for each node in the tree.
- For a leaf, all four values are equiprobable, thus equal to 1/4.
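- For a non-leaf node, we can consider each pair of values for the left and right subtrees. Since the operator is fixed, each pair determines the value at the current node. Because the two subtrees contain disjoint sets of independent operands, the probability of a pair is the product of the children's probabilities. Summing these products over all pairs that produce a given value yields the probability of that value at the current node.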
Applying this in a bottom-up manner, we can recursively compute the probability of each value for the root node. The root represents the whole expression, so these probabilities are exactly the required answer.
Exercise: Prove that the length of a valid expression is always of the form $4p+1$ for some integer $p \geq 0$.
TIME COMPLEXITY
The overall time complexity is $O(16 \cdot \sum |L|)$, where 16 is the number of pairs of child values considered at each node.
SOLUTIONS:
Setter's Solution
/*
Author : Kritagya Agarwal
April Long Challenge
REBIT
*/
#include<bits/stdc++.h>
#define ll long long int
#define MAX 5005
#define M 998244353
#define ld long long int
using namespace std;
// One vertex of the binary expression tree built from the postfix form.
// A leaf carries '#'; an internal vertex carries its operator character and
// owns exactly two children.
struct node
{
    char number = 0;          // '#' for an operand, otherwise the operator symbol
    ll sum = 0;               // dp[0] + dp[1] + dp[2] + dp[3] (newNode sets 4; process() reduces mod M)
    node * left = NULL;       // left operand subtree, NULL for a leaf
    node * right = NULL;      // right operand subtree, NULL for a leaf
    ll dp[4] = {0, 0, 0, 0};  // dp[v]: number of leaf assignments (mod M) giving value v
                              // (0 -> '0', 1 -> '1', 2 -> 'a', 3 -> 'A')
};
// Allocates a detached vertex labelled `u` with no children.
// Each of the four values starts with exactly one way to be produced (this is
// the correct state for a '#' leaf), so the running total is 4; process()
// overwrites dp/sum for internal vertices. The caller owns the returned node.
node * newNode(char u)
{
    node * created = new node();
    created->number = u;
    created->left = created->right = NULL;
    for (int v = 0; v < 4; ++v)
        created->dp[v] = 1;
    created->sum = 4;
    return created;
}
// Computes a^n modulo M by binary exponentiation.
// The base is first normalised into [0, M) (negatives included), so every
// intermediate product (vl * vl, ans * vl) stays below M*M and fits in a
// signed 64-bit integer even if the caller passes an unreduced base.
// n is expected to be non-negative; poe(x, 0) == 1.
ll poe(ll a, ll n)
{
    ll vl = a % M;
    if (vl < 0)
        vl += M;
    ll ans = 1;
    while (n != 0)
    {
        if (n % 2 != 0)
            ans = ans * vl % M;
        vl = vl * vl % M;
        n /= 2;
    }
    return ans;
}
// Builds the expression tree from a postfix string. Every '#' becomes a leaf.
// Any other character becomes an operator vertex: its right child is the most
// recently completed subtree and its left child is the one completed before it.
// Assumes well-formed, non-empty postfix input; otherwise top() is called on an
// empty stack. Returns the root; vertices come from newNode and are not freed here.
node * constructTree(string s)
{
    stack<node *> pending;
    for (size_t idx = 0; idx < s.size(); ++idx)
    {
        const char ch = s[idx];
        node * cur = newNode(ch);
        if (ch != '#')
        {
            node * rhs = pending.top();
            pending.pop();
            node * lhs = pending.top();
            pending.pop();
            cur->left = lhs;
            cur->right = rhs;
        }
        pending.push(cur);
    }
    node * root = pending.top();
    pending.pop();
    return root;
}
// Binding strength used by infixToPostfix: '&' binds tightest (3), then '|' (2),
// then '^' (1). Every other character, notably a '(' sitting on the operator
// stack, gets -1, so an incoming operator never pops past it.
int prec(char c)
{
    switch (c)
    {
        case '&': return 3;
        case '|': return 2;
        case '^': return 1;
        default:  return -1;
    }
}
// Converts an infix expression over '#' operands, parentheses and binary
// operators into postfix (reverse Polish) order with the shunting-yard
// algorithm, using prec() for operator precedence. Any character other than
// '#', '(' or ')' is treated as an operator. 'N' is a bottom-of-stack sentinel,
// so st.top() is always valid.
string infixToPostfix(const string& s)
{
    stack<char> st;
    st.push('N');
    string ns;
    ns.reserve(s.size());  // the output never has more characters than the input
    for (char ch : s)
    {
        if (ch == '#')
        {
            ns += ch;  // operands go straight to the output
        }
        else if (ch == '(')
        {
            st.push('(');
        }
        else if (ch == ')')
        {
            // Flush pending operators back to the matching '('.
            while (st.top() != 'N' && st.top() != '(')
            {
                ns += st.top();
                st.pop();
            }
            if (st.top() == '(')
                st.pop();  // discard the '(' itself
        }
        else
        {
            // Pop every stacked operator that binds at least as tightly.
            // '(' has precedence -1, so this never pops past an open parenthesis.
            while (st.top() != 'N' && prec(ch) <= prec(st.top()))
            {
                ns += st.top();
                st.pop();
            }
            st.push(ch);
        }
    }
    // Flush the remaining operators.
    while (st.top() != 'N')
    {
        ns += st.top();
        st.pop();
    }
    return ns;
}
// Post-order DP over the expression tree. On return, root->dp[v] is the number
// of assignments (mod M) of the leaves below `root` that make the subtree
// evaluate to value v (0 -> '0', 1 -> '1', 2 -> 'a', 3 -> 'A'), and root->sum
// is dp[0] + ... + dp[3] mod M. Leaves are left untouched (their state comes
// from newNode). Assumes every vertex has either zero or two children. An
// operator character other than '&', '|', '^' keeps its dp values unchanged.
void process(node * root)
{
    // RESULT[k][x][y] = value index of x (op k) y, where k: 0 -> '&', 1 -> '|', 2 -> '^'.
    static const int RESULT[3][4][4] = {
        // '&': 0 absorbs, 1 is the identity, a & A = 0
        {{0, 0, 0, 0},
         {0, 1, 2, 3},
         {0, 2, 2, 0},
         {0, 3, 0, 3}},
        // '|': 1 absorbs, 0 is the identity, a | A = 1
        {{0, 1, 2, 3},
         {1, 1, 1, 1},
         {2, 1, 2, 1},
         {3, 1, 1, 3}},
        // '^': with this encoding, XOR of the indices
        {{0, 1, 2, 3},
         {1, 0, 3, 2},
         {2, 3, 0, 1},
         {3, 2, 1, 0}}};

    if (root->left == NULL && root->right == NULL)
        return;
    process(root->left);
    process(root->right);

    int kind;
    switch (root->number)
    {
        case '&': kind = 0; break;
        case '|': kind = 1; break;
        case '^': kind = 2; break;
        default:  kind = -1; break;
    }
    if (kind >= 0)
    {
        // Child counts are already reduced below M, so each product fits in ll.
        ll acc[4] = {0, 0, 0, 0};
        for (int x = 0; x < 4; ++x)
            for (int y = 0; y < 4; ++y)
            {
                const int v = RESULT[kind][x][y];
                acc[v] = (acc[v] + root->left->dp[x] * root->right->dp[y]) % M;
            }
        for (int v = 0; v < 4; ++v)
            root->dp[v] = acc[v];
    }
    root->sum = (root->dp[0] + root->dp[1] + root->dp[2] + root->dp[3]) % M;
}
int main()
{
int t;
cin>>t;
for(int k = 1 ; k <= t ; k++)
{
string s;
cin>>s;
string post = infixToPostfix(s);
int length = s.size();
int ans = 0;
char an;
ll sum = 0;
if(length == 1)
{
ld f0 = 1;
ld f1 = 1;
ld f2 = 1;
ld f3 = 1;
ld fn = f0 + f1 + f2 + f3;
fn %= M;
f0 *= poe(fn,M-2);
f0 %= M;
f1 *= poe(fn,M-2);
f1 %= M;
f2 *= poe(fn,M-2);
f2 %= M;
f3 *= poe(fn,M-2);
f3 %= M;
cout<<f0<<" "<<f1<<" "<<f2<<" "<<f3<<endl;
}
else
{
node * root = constructTree(post);
process(root);
ld f0 = root->dp[0];
ld f1 = root->dp[1];
ld f2 = root->dp[2];
ld f3 = root->dp[3];
ld fn = f0 + f1 + f2 + f3;
fn %= M;
f0 *= poe(fn,M-2);
f0 %= M;
f1 *= poe(fn,M-2);
f1 %= M;
f2 *= poe(fn,M-2);
f2 %= M;
f3 *= poe(fn,M-2);
f3 %= M;
cout<<f0<<" "<<f1<<" "<<f2<<" "<<f3<<endl;
}
}
}
Tester's Solution
#include <bits/stdc++.h>
using namespace std;
// create<T>(n)          -> vector<T> holding n value-initialised elements.
// create<T>(n, m, ...)  -> n copies of create<T>(m, ...), i.e. a nested vector.
template<typename T = int> vector<T> create(size_t n){
    vector<T> flat(n);
    return flat;
}
template<typename T, typename... Args> auto create(size_t n, Args... args){
    auto inner = create<T>(args...);
    return vector<decltype(inner)>(n, inner);
}
// Modular integer with compile-time modulus `mod`; values are kept in [0, mod).
// T is the storage type and U a wider signed type for intermediate results.
// inverse() and division assume gcd(val, mod) == 1 (always true for a prime mod
// and a non-zero value).
template<typename T = int, T mod = 1'000'000'007, typename U = long long>
struct umod{
    T val;
    umod(): val(0){}
    // Reduces any U, including negative values, into [0, mod).
    umod(U x){ x %= mod; if(x < 0) x += mod; val = x;}
    // The sum is formed in U so it cannot overflow T even when mod > max(T)/2.
    umod& operator += (umod oth){ U s = (U)val + oth.val; if(s >= mod) s -= mod; val = (T)s; return *this; }
    // Branch instead of testing val < 0, so this also works for an unsigned T.
    umod& operator -= (umod oth){ if(val < oth.val) val += mod - oth.val; else val -= oth.val; return *this; }
    umod& operator *= (umod oth){ val = ((U)val) * oth.val % mod; return *this; }
    umod& operator /= (umod oth){ return *this *= oth.inverse(); }
    umod& operator ^= (U oth){ return *this = pwr(*this, oth); }
    umod operator + (umod oth) const { return umod(*this) += oth; }
    umod operator - (umod oth) const { return umod(*this) -= oth; }
    umod operator * (umod oth) const { return umod(*this) *= oth; }
    umod operator / (umod oth) const { return umod(*this) /= oth; }
    umod operator ^ (long long oth) const { return umod(*this) ^= oth; }
    bool operator < (umod oth) const { return val < oth.val; }
    bool operator > (umod oth) const { return val > oth.val; }
    bool operator <= (umod oth) const { return val <= oth.val; }
    bool operator >= (umod oth) const { return val >= oth.val; }
    bool operator == (umod oth) const { return val == oth.val; }
    bool operator != (umod oth) const { return val != oth.val; }
    // a^b by binary exponentiation. A negative b means (a^-1)^|b|. The magnitude
    // is taken in the unsigned counterpart of U, so the most negative U is handled
    // without signed overflow; previously a negative b never terminated.
    umod pwr(umod a, U b) const {
        std::make_unsigned_t<U> e = static_cast<std::make_unsigned_t<U>>(b);
        if(b < 0){ a = a.inverse(); e = 0 - e; }
        umod r = 1;
        for(; e; a *= a, e >>= 1) if(e & 1) r *= a;
        return r;
    }
    // Multiplicative inverse via the extended Euclidean algorithm.
    umod inverse() const {
        U a = val, b = mod, u = 1, v = 0;
        while(b){
            U t = a/b;
            a -= t * b; std::swap(a, b);
            u -= t * v; std::swap(u, v);
        }
        if(u < 0)
            u += mod;
        return u;
    }
};
using U = umod<int, 998244353>;
int main(){
ios::sync_with_stdio(false);
cin.tie(0);
int t;
cin >> t;
while(t--){
string s;
cin >> s;
/**
Run through evalution tree and calculate probabilities
Use acumulated sum for finding the correct point k to divide interval (i, j)
Probabilites for 0, 1, a, A
* */
int n = s.size();
vector<int> sum(n + 1, 0);
for(int i = 0; i < n; i++){
sum[i + 1] = sum[i];
if(s[i] == '(') sum[i + 1]++;
if(s[i] == ')') sum[i + 1]--;
}
vector<vector<int>> at(n);
for(int i = 0; i <= n; i++){
assert(sum[i] >= 0);
at[sum[i]].push_back(i);
}
U i4 = U(1)/4;
vector<U> op = {i4, i4, i4, i4};
function<vector<U>(int,int)> solve = [&](int i, int j){
if(i == j) return op;
assert(s[i] == '(');
assert(s[j] == ')');
int k = *lower_bound(at[sum[i + 1]].begin(), at[sum[i + 1]].end(), i + 2);
auto l = solve(i + 1, k - 1);
auto r = solve(k + 1, j - 1);
vector<U> res(4, 0);
if(s[k] == '&'){
for(int x = 0; x < 4; x++){
for(int y = 0; y < 4; y++){
if(x == 0 || y == 0) res[0] += l[x] * r[y];
else if(x == 1) res[y] += l[x] * r[y];
else if(y == 1) res[x] += l[x] * r[y];
else if(x == y) res[x] += l[x] * r[y];
else res[0] += l[x] * r[y];
}
}
} else if(s[k] == '|'){
for(int x = 0; x < 4; x++){
for(int y = 0; y < 4; y++){
if(x == 1 || y == 1) res[1] += l[x] * r[y];
else if(x == 0) res[y] += l[x] * r[y];
else if(y == 0) res[x] += l[x] * r[y];
else if(x == y) res[x] += l[x] * r[y];
else res[1] += l[x] * r[y];
}
}
} else if(s[k] == '^'){
for(int x = 0; x < 4; x++){
for(int y = 0; y < 4; y++){
if(x == 0 || y == 0) res[x ^ y] += l[x] * r[y];
else if(x == y) res[0] += l[x] * r[y];
else if(x == 1) res[y ^ 1] += l[x] * r[y];
else if(y == 1) res[x ^ 1] += l[x] * r[y];
else res[1] += l[x] * r[y];
}
}
}
return res;
};
auto ans = solve(0, n - 1);
for(int i = 0; i < 4; i++){
if(i) cout << ' ';
cout << ans[i].val;
}
cout << '\n';
}
return 0;
}
Editorialist's Solution
import java.util.*;
import java.io.*;
import java.text.*;
class REBIT{
    //SOLUTION BEGIN
    long MOD = 998244353;
    void pre() throws Exception{}
    // Reads one expression, parses it recursively into Nodes and prints the
    // probabilities (mod MOD) of '0', '1', 'a', 'A', each followed by a space.
    void solve(int TC) throws Exception{
        s = n();
        pos = 0;
        // Every valid expression has length 4p+1.
        hold(s.length() % 4 == 1);
        Node root = new Node();
        StringBuilder line = new StringBuilder();
        for(int v = 0; v < root.f.length; v++) line.append(root.f[v]).append(' ');
        pn(line.toString());
    }
    String s;     // expression currently being parsed
    int pos = 0;  // index of the next character of s to consume
    // Parse-tree vertex. f[v] is the probability (mod MOD) that the sub-expression
    // evaluates to v (0 -> '0', 1 -> '1', 2 -> 'a', 3 -> 'A'). Constructing a Node
    // consumes exactly one sub-expression starting at s[pos].
    class Node{
        long[] f;
        Node le, ri;
        public Node() throws Exception{
            char first = s.charAt(pos++);
            if(first == '#'){
                long quarter = 748683265; // modular inverse of 4 modulo 998244353
                f = new long[]{quarter, quarter, quarter, quarter};
                le = null;
                ri = null;
                return;
            }
            hold(first == '(');
            le = new Node();
            char operator = s.charAt(pos++);
            ri = new Node();
            hold(s.charAt(pos++) == ')');
            // Operand values are independent: weight each (x, y) pair by the product
            // of the children's probabilities and add it to the resulting value.
            f = new long[4];
            for(int x = 0; x < 4; x++){
                long px = le.f[x];
                for(int y = 0; y < 4; y++){
                    int target = op(x, y, operator);
                    f[target] = (f[target] + px * ri.f[y]) % MOD;
                }
            }
        }
    }
    // Value index of a (op) b, or -1 for an unknown operator.
    int op(int a, int b, char op){
        if(op == '^') return a ^ b;
        if(op != '&' && op != '|') return -1;
        if(a == b) return a;
        int absorbing = (op == '&') ? 0 : 1; // 0 absorbs under AND, 1 under OR
        int identity = 1 - absorbing;        // 1 is neutral for AND, 0 for OR
        if(a == absorbing || b == absorbing) return absorbing;
        if(a == identity) return b;
        if(b == identity) return a;
        return absorbing; // remaining pair is {a, A}
    }
    //SOLUTION END
    void hold(boolean b)throws Exception{
        if(b) return;
        throw new Exception("Hold right there, Sparky!");
    }
    DecimalFormat df = new DecimalFormat("0.00000000000");
    static boolean multipleTC = true;
    FastReader in;
    PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int testCount = multipleTC ? ni() : 1;
        pre();
        for(int tc = 1; tc <= testCount; tc++) solve(tc);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        // Runs on a separate thread with a 1 << 28 stack-size request, because
        // Node's constructor recurses once per nesting level.
        Runnable task = new Runnable(){
            public void run(){
                try{
                    new REBIT().run();
                }catch(Exception e){
                    e.printStackTrace();
                }
            }
        };
        new Thread(null, task, "1", 1 << 28).start();
    }
    // Number of set bits in n.
    int bit(long n){
        int count = 0;
        for(long rest = n; rest != 0; rest &= rest - 1) count++;
        return count;
    }
    void p(Object o){ out.print(o); }
    void pn(Object o){ out.println(o); }
    void pni(Object o){ out.println(o); out.flush(); }
    String n() throws Exception{ return in.next(); }
    String nln() throws Exception{ return in.nextLine(); }
    int ni() throws Exception{ return Integer.parseInt(in.next()); }
    long nl() throws Exception{ return Long.parseLong(in.next()); }
    double nd() throws Exception{ return Double.parseDouble(in.next()); }
    // Whitespace tokenizer over a BufferedReader (stdin by default, or a file).
    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }
        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }
        String next() throws Exception{
            while(st == null || !st.hasMoreTokens()){
                String line;
                try{
                    line = br.readLine();
                }catch(IOException e){
                    throw new Exception(e.toString());
                }
                st = new StringTokenizer(line);
            }
            return st.nextToken();
        }
        String nextLine() throws Exception{
            try{
                return br.readLine();
            }catch(IOException e){
                throw new Exception(e.toString());
            }
        }
    }
}
Feel free to share your approach. Suggestions are welcome, as always.