#include<bits/stdc++.h>
// Competitive-programming shorthands and global input buffers.
// `endl` is redefined as a plain '\n' character literal so that printing a
// newline does not flush the stream the way std::endl does. It must use ASCII
// single quotes; typographic (curly) quotes are not literal delimiters and
// break compilation wherever the macro is expanded.
#define endl '\n'
#define ll long long
#define debug(x) #x<<": "<<x
#define armax 100001
// Static-storage arrays: zero-initialised, each holding up to `armax` values.
ll a[armax];
ll b[armax];
ll c[armax];
ll btmp[armax];
using namespace std;
int main()
{
ios::sync_with_stdio(0);
cin.tie(0);cout.tie(0);
ll n;
cin>>n;
ll bsize[armax+1];
memset(bsize,0,armax);
for(ll i = 0;i<n;i++)
cin>>a[i];
for(ll i = 0;i<n;i++)
{
cin>>b[i];
btmp[i] = b[i];
}
for(ll i = 0;i<n;i++)
cin>>c[i];
sort(btmp,btmp+n);
sort(c,c+n);
for(ll i = 0;i<=n;i++)
{
if(binary_search(c,c+n,i+1))
{
ll nooftimes = upper_bound(c,c+n,i+1)-lower_bound(c,c+n,i+1);
bsize[b[i]]+=nooftimes;
}
}
vector<ll> bpositions;
for(ll i = 0;i<n;i++)
if(binary_search(btmp,btmp+n,a[i]))
bpositions.push_back(a[i]);
ll ans = 0;
for(ll pos:bpositions)
ans+=bsize[pos];
cout<<ans;
}
Here’s my code. I have looked at the editorial, and my approach seems very similar to the one described there, but I can’t find what’s wrong. It gets WA on bound_00 and bound_01. Any help would be appreciated.
1 Like
In this problem you need to count the pairs $(i, j)$ with $A_i = B_{C_j}$.
So first, count the values in a map: `map[B[C[j]]]++` for every $j$.
Then, for every $i$, add `map[A[i]]` to the answer (`ans += map[A[i]]`),
and print `ans`.
My Java code:
import java.util.*;
import java.io.*;
import java.math.*;
/**
*
* @Har_Har_Mahadev
*/
/**
* Main , Solution , Remove Public
*/
public class Main {
    /**
     * Solves one test case: reads N and the arrays A, B, C and prints the
     * number of pairs (i, j) with A[i] == B[C[j]], where C holds 1-based
     * indices into B.
     *
     * count[v] is the number of j with B[C[j]] == v. The answer is the sum of
     * count[A[i]] over all i, kept in a long because it can reach N * N.
     * Values outside the problem's range [1, N] cause an
     * ArrayIndexOutOfBoundsException.
     */
    public static void process() throws IOException {
        int n = sc.nextInt();
        int count[] = new int[n + 1];
        int a[] = sc.readArray(n);
        int b[] = sc.readArray(n);
        int c[] = sc.readArray(n);
        for (int i = 0; i < n; i++) {
            count[b[c[i] - 1]]++;
        }
        long ans = 0;
        for (int val : a) {
            ans += count[val];
        }
        // Print through the shared buffered writer; main() flushes it.
        out.println(ans);
    }

    //=============================================================================
    //--------------------------The End---------------------------------
    //=============================================================================
    private static long INF = 2000000000000000000L, M = 1000000007, MM = 998244353;
    private static int N = 0;

    /** Prints the "Case #tt: " prefix used by Google-style judges. */
    private static void google(int tt) {
        // Go through `out` so the prefix cannot be reordered relative to
        // answers written to the same buffered writer.
        out.print("Case #" + (tt) + ": ");
    }

    static FastScanner sc;
    static PrintWriter out;

    /** Sets up I/O (stdin/stdout, or input.txt/output.txt locally) and runs the test cases. */
    public static void main(String[] args) throws IOException {
        boolean oj = true;
        if (oj) {
            sc = new FastScanner();
            out = new PrintWriter(System.out);
        } else {
            sc = new FastScanner(100);
            out = new PrintWriter("output.txt");
        }
        int t = 1;
        // t = sc.nextInt();
        int TTT = 1;
        while (t-- > 0) {
            // google(TTT++);
            process();
        }
        out.flush();
        out.close();
    }

    /** Integer pair ordered by its first component only. */
    static class Pair implements Comparable<Pair> {
        int x, y;

        Pair(int x, int y) {
            this.x = x;
            this.y = y;
        }

        @Override
        public int compareTo(Pair o) {
            return Integer.compare(this.x, o.x);
        }
        // @Override
        // public boolean equals(Object o) {
        // if (this == o) return true;
        // if (!(o instanceof Pair)) return false;
        // Pair key = (Pair) o;
        // return x == key.x && y == key.y;
        // }
        //
        // @Override
        // public int hashCode() {
        // int result = x;
        // result = 31 * result + y;
        // return result;
        // }
    }

    /////////////////////////////////////////////////////////////////////////////////////////////////////////
    static void println(Object o) {
        out.println(o);
    }

    static void println() {
        out.println();
    }

    static void print(Object o) {
        out.print(o);
    }

    /** Prints a line and flushes immediately (for interactive problems). */
    static void pflush(Object o) {
        out.println(o);
        out.flush();
    }

    /**
     * Ceiling of x / y for any combination of signs (y != 0).
     * Java's division truncates toward zero, so the quotient is bumped up only
     * when there is a remainder and the exact quotient is positive, i.e. x and
     * y share a sign. The previous "x / y + 1" form was wrong for mixed signs.
     */
    static int ceil(int x, int y) {
        int q = x / y;
        return (x % y != 0 && (x ^ y) >= 0) ? q + 1 : q;
    }

    /** Ceiling of x / y for any combination of signs (y != 0); see ceil(int, int). */
    static long ceil(long x, long y) {
        long q = x / y;
        return (x % y != 0 && (x ^ y) >= 0) ? q + 1 : q;
    }

    static int max(int x, int y) {
        return Math.max(x, y);
    }

    static int min(int x, int y) {
        return Math.min(x, y);
    }

    static int abs(int x) {
        return Math.abs(x);
    }

    static long abs(long x) {
        return Math.abs(x);
    }

    /**
     * Floor of the square root of z.
     * The floating estimate is corrected by comparing r against z / r instead
     * of computing r * r, which overflowed near Long.MAX_VALUE. Negative
     * input returns 0; the previous version ran a very long loop on it.
     */
    static long sqrt(long z) {
        long r = (long) Math.sqrt(z);
        // Step down while r * r > z (for r > 0, r > z / r <=> r * r > z).
        while (r > 0 && r > z / r) {
            r--;
        }
        // Step up while (r + 1)^2 <= z.
        while (r + 1 <= z / (r + 1)) {
            r++;
        }
        return r;
    }

    /**
     * floor(log2(N)) for N > 0, computed exactly from the bit length rather
     * than by floating-point division, which can misround at powers of two.
     * Non-positive N keeps the previous floating-point result.
     */
    static int log2(int N) {
        if (N <= 0) {
            return (int) (Math.log(N) / Math.log(2));
        }
        return 31 - Integer.numberOfLeadingZeros(N);
    }

    static long max(long x, long y) {
        return Math.max(x, y);
    }

    static long min(long x, long y) {
        return Math.min(x, y);
    }

    /** Non-negative gcd of a and b via BigInteger. */
    public static int gcd(int a, int b) {
        BigInteger b1 = BigInteger.valueOf(a);
        BigInteger b2 = BigInteger.valueOf(b);
        BigInteger gcd = b1.gcd(b2);
        return gcd.intValue();
    }

    /** Non-negative gcd of a and b via BigInteger. */
    public static long gcd(long a, long b) {
        BigInteger b1 = BigInteger.valueOf(a);
        BigInteger b2 = BigInteger.valueOf(b);
        BigInteger gcd = b1.gcd(b2);
        return gcd.longValue();
    }

    /**
     * Least common multiple. Dividing before multiplying keeps the
     * intermediate value no larger than the result; the old (a * b) / gcd
     * form overflowed even when the lcm itself fits.
     */
    public static long lcm(long a, long b) {
        return a / gcd(a, b) * b;
    }

    /** Least common multiple (divide first to avoid intermediate overflow). */
    public static int lcm(int a, int b) {
        return a / gcd(a, b) * b;
    }

    /** First index i in the sorted array with arr[i] >= x (arr.length if none). */
    public static int lower_bound(int[] arr, int x) {
        int low = 0, high = arr.length;
        while (low < high) {
            int mid = (low + high) >>> 1; // unsigned shift: no overflow on the sum
            if (arr[mid] >= x)
                high = mid;
            else
                low = mid + 1;
        }
        return low;
    }

    /** First index i in the sorted array with arr[i] > x (arr.length if none). */
    public static int upper_bound(int[] arr, int x) {
        int low = 0, high = arr.length;
        while (low < high) {
            int mid = (low + high) >>> 1; // unsigned shift: no overflow on the sum
            if (arr[mid] > x)
                high = mid;
            else
                low = mid + 1;
        }
        return low;
    }

    /////////////////////////////////////////////////////////////////////////////////////////////////////////
    /** Whitespace-token reader over a BufferedReader. */
    static class FastScanner {
        BufferedReader br;
        StringTokenizer st;

        FastScanner() throws FileNotFoundException {
            br = new BufferedReader(new InputStreamReader(System.in));
        }

        FastScanner(int a) throws FileNotFoundException {
            br = new BufferedReader(new FileReader("input.txt"));
        }

        /**
         * Returns the next whitespace-separated token.
         * I/O errors propagate to the caller instead of being printed and
         * retried forever. Running out of input raises EOFException rather
         * than a NullPointerException from new StringTokenizer(null).
         */
        String next() throws IOException {
            while (st == null || !st.hasMoreElements()) {
                String line = br.readLine();
                if (line == null) {
                    throw new EOFException("no more tokens in input");
                }
                st = new StringTokenizer(line);
            }
            return st.nextToken();
        }

        int nextInt() throws IOException {
            return Integer.parseInt(next());
        }

        long nextLong() throws IOException {
            return Long.parseLong(next());
        }

        double nextDouble() throws IOException {
            return Double.parseDouble(next());
        }

        /** Reads the next raw line from the underlying reader (null at end of input). */
        String nextLine() throws IOException {
            return br.readLine();
        }

        /** Reads n ints from this scanner (not from the global `sc`). */
        int[] readArray(int n) throws IOException {
            int[] A = new int[n];
            for (int i = 0; i != n; i++) {
                A[i] = nextInt();
            }
            return A;
        }

        /** Reads n longs from this scanner (not from the global `sc`). */
        long[] readArrayLong(int n) throws IOException {
            long[] A = new long[n];
            for (int i = 0; i != n; i++) {
                A[i] = nextLong();
            }
            return A;
        }
    }

    /** Shuffles before sorting so adversarial inputs cannot force worst-case dual-pivot quicksort. */
    static void ruffleSort(int[] a) {
        Random get = new Random();
        for (int i = 0; i < a.length; i++) {
            int r = get.nextInt(a.length);
            int temp = a[i];
            a[i] = a[r];
            a[r] = temp;
        }
        Arrays.sort(a);
    }

    /** Shuffles before sorting so adversarial inputs cannot force worst-case dual-pivot quicksort. */
    static void ruffleSort(long[] a) {
        Random get = new Random();
        for (int i = 0; i < a.length; i++) {
            int r = get.nextInt(a.length);
            long temp = a[i];
            a[i] = a[r];
            a[r] = temp;
        }
        Arrays.sort(a);
    }
}
1 Like
Actually, during the contest I didn’t notice the constraint that the elements of A, B and C lie in the range $[1, n]$, so I just solved it in general with maps.
Anyway, you can construct an array $d$ with $d_i = b_{c_i}$. The problem then reduces to counting pairs of equal elements between $a$ and $d$, which can be done in $\mathcal{O}(n)$ time (I implemented it with maps, so mine was $\mathcal{O}(n \log n)$).
3 Likes
I understood the solution, and after the contest I was able to code it from scratch using the same idea as yours, but I still can’t find what’s wrong with the solution I posted above. I wrote that one during the contest and it got WA. I wish I could see where it’s wrong.
@leo_valdez I implemented the same approach in Python but got AC on 10 test cases and TLE on the rest. Any idea whether AtCoder has a time-limit multiplier for slower languages?
1 Like
I think Python usually gets a larger multiplier than C++. I also don’t think using a dict is as fast as using a map in C++. But of course, I could be wrong. Maybe AtCoder does have a multiplier system, but I’m not sure…