Sum of Squares with Segment Tree
Given below c++ code is for segsqrss spoj or sum of squares with segment tree spoj.
Main logic of code is within Merge and Split function in class 'node'
/* =================================================== Name :- Nishant Raj Email :- raj.nishant360@gmail.com College :- Indian School of Mines Branch :- Computer Science and Engineering Time :- 16 October 2015 (Friday) 02:09 ===================================================*/ #include <bits/stdc++.h> using namespace std; #define ll long long #define pii pair < int , int > #define pb push_back #define mp make_pair #define mod 1000000009 template<class T> class segmentTree{ public: segmentTree(){ height = 1; left_most = 1<<height; right_most = (left_most<<1) - 1; tree = new T[right_most]; } segmentTree(int s){ size=s; height = ceil(log2(s)); left_most = 1<<height; right_most = (left_most<<1) - 1; tree = new T[right_most+9]; } void init(T * arr){ build(arr); } void fill_ans(){ initalize(1,left_most,right_most); for(int i = left_most ;i< left_most+size ; i++){ for(int j=1;j<=26;j++) if(tree[i].arr[j]){ cout<<char(j+96); break; } } } void Update(int pos , T val){ point_update(1 , left_most , right_most , left_most+pos , val); } void Update(int l , int r , T val){ range_update(1 , left_most , right_most , left_most+l , left_most+r , val); } T Query(int pos){ return point_query(1 , left_most , right_most , left_most+pos); } T Query(int l ,int r){ return range_query(1 , left_most , right_most , left_most+l , left_most+r); } private: T *tree; int size , left_most , right_most , height; void build(T * arr){ for(int i = 0 ; i < size ; i++) tree[left_most+i] = arr[i]; initalize(1 , left_most , right_most); } void initalize(int root , int left_most , int right_most){ if(left_most == right_most) return; int mid = (left_most + right_most)>>1 , l_child = (root<<1) , r_child = (root<<1)+1; tree[root].split(tree[l_child] , tree[r_child]); initalize(l_child , left_most , mid); initalize(r_child , mid+1 , right_most); tree[root].merge(tree[l_child] , tree[r_child]); } void point_update(int root , int left_most , int right_most , int pos , T val){ if(left_most == right_most && root == pos) { tree[root].update(val); return ;} int mid = (left_most + right_most)>>1 , l_child = root<<1 , r_child = (root<<1)+1; tree[root].split(tree[l_child] , tree[r_child]); if(pos <= mid) point_update(l_child , left_most , mid , pos , val); else point_update(r_child , mid+1 , right_most , pos , val); tree[root].merge(tree[l_child] , tree[r_child]); } void range_update(int root , int left_most , int right_most , int l , int r , T val){ if(l <= left_most && r >= right_most){ tree[root].update(val);return;} int mid = (left_most + right_most)>>1 , l_child = root<<1 , r_child = (root<<1)+1; tree[root].split(tree[l_child] , tree[r_child]); if(l <= mid) range_update(l_child , left_most , mid, l , r , val); if(r > mid) range_update(r_child , mid+1 , right_most , l , r , val); tree[root].merge(tree[l_child] , tree[r_child]); } T range_query(int root , int left_most ,int right_most ,int l , int r){ if( l <= left_most && r >= right_most ) return tree[root]; int mid = (left_most + right_most)>>1 , l_child = root<<1 , r_child = (root<<1)+1; tree[root].split(tree[l_child] , tree[r_child]); T l_node , r_node , temp; if(l <= mid) l_node = range_query(l_child , left_most , mid , l , r ); if(r > mid) r_node = range_query(r_child , mid+1 , right_most , l , r ); tree[root].merge(tree[l_child] , tree[r_child]); temp.merge(l_node , r_node); return temp; } T point_query(int root , int left_most , int right_most , int pos){ if(left_most == right_most && root == pos) return tree[root]; int mid = (left_most + right_most)>>1 , l_child = root<<1 , r_child = (root<<1)+1; T temp; tree[root].split(tree[l_child] , tree[r_child]); if(pos <= mid) temp = point_query(l_child , left_most , mid , pos); else temp = point_query(r_child , mid+1 , right_most , pos); tree[root].merge(tree[l_child] , tree[r_child]); return temp; } }; class node{ public: ll sum , sq_sum , lazy1 , lazy2; int child_count; void merge(node &a , node &b){ sum = a.sum + b.sum; sq_sum = a.sq_sum + b.sq_sum; child_count = a.child_count + b.child_count; lazy1 = lazy2 = 0; } void split(node &a , node &b){ if(lazy1){ a.sq_sum += lazy1 * lazy1 * (ll)a.child_count + 2LL * lazy1 * a.sum; b.sq_sum += lazy1 * lazy1 * (ll)b.child_count + 2LL * lazy1 * b.sum; a.sum += lazy1 * a.child_count; b.sum += lazy1 * b.child_count; a.lazy1 += lazy1; b.lazy1 += lazy1; lazy1 = 0; } if(lazy2){ a.sq_sum = a.child_count * lazy2 * lazy2; a.sum = a.child_count * lazy2; a.lazy2 += lazy2; b.sq_sum = b.child_count * lazy2 * lazy2; b.sum = b.child_count * lazy2; b.lazy2 += lazy2; } } void update(node &a){ if(a.lazy1){ sq_sum = sq_sum + a.lazy1 * a.lazy1 * child_count + 2LL * a.lazy1 * sum; sum += a.lazy1 * child_count; lazy1 += a.lazy1; } if(a.lazy2){ sq_sum = child_count * a.lazy2 * a.lazy2; sum = child_count * a.lazy2; lazy2 += a.lazy2; } } node(){ sum = sq_sum = lazy1 = lazy2 = 0; child_count = 0; } node(ll a , ll l1 , ll l2){ sum = a; sq_sum = a*a; child_count = 1; lazy1 = l1; lazy2 = l2; } }; node arr[100009]; int main(){ int t; scanf("%d",&t); for(int test = 1 ; test <= t ; test++){ int n , temp ,q; scanf("%d%d",&n,&q); segmentTree<node> s(n); for(int i =0;i<n;i++){ scanf("%d",&temp); arr[i]=node(temp , 0 , 0); } s.init(arr); printf("Case %d:\n",test); while(q--){ int l,r,k,val; scanf("%d%d%d",&k,&l,&r); l-- , r--; if(k == 2){ printf("%lld\n",s.Query(l,r).sq_sum); } else if(k==1){ scanf("%d",&val); s.Update(l , r , node(0 , val , 0)); } else{ scanf("%d",&val); s.Update(l , r , node(0 , 0 , val)); } } } return 0; }
i tried running your code in Coding Ninja it gives WA for 1 Test case
ReplyDeletehmm, your code goes wrong in this test case:
ReplyDelete1
3 10
1 6 2
1 1 2 5
0 3 3 5
0 2 2 10
1 1 1 2
2 3 3
0 1 3 6
0 1 3 7
2 2 2
2 3 3
1 1 1 2
I think something went wrong in your code :>.
Sr for my poor english