前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >挑战程序竞赛系列(35):3.3Binary Indexed Tree

挑战程序竞赛系列(35):3.3Binary Indexed Tree

作者头像
用户1147447
发布2019-05-26 09:29:24
5110
发布2019-05-26 09:29:24
举报
文章被收录于专栏:机器学习入门机器学习入门

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://cloud.tencent.com/developer/article/1434668

挑战程序竞赛系列(35):3.3Binary Indexed Tree

详细代码可以fork下Github上leetcode项目,不定期更新。

练习题如下:

Binary Indexed Tree简介

Binary Indexed Tree是线段树的升级版,主要用于求前缀和,简单说说思想:

线段树的产生是为了满足频繁更新和求区间和的需求,所以用数组表示成一棵树的形式,使得更新和区间求和都能控制在O(logn)O(\log n)内。

接着观察线段树求和的性质,会发现有趣的现象,具体参考《挑战程序设计竞赛》P175页,右孩子都可以由它的父结点和父结点的左孩子相减求出,于是在线段数的基础上优化了空间复杂度,使它控制在了O(n)O(n)内。

那么如何表达元素的更新呢?将结点编号按从左至右递增,从1开始,与其对应的二进制我们能发现规律,参考书P176页的内容:

以1结尾的1,3,5,7的长度是1,最后有一个0的2,6的长度为2,最后有2个0的4的长度是4…..

于是有了更新元素和求和的规则:

更新元素:

因为更新某个位置的元素,则首先找到该位置对元素进行删改,那么后续的前缀后都需要更新,所以逐层向上,更新编号是依次递增的,代码如下:

代码语言:javascript
复制
public void add(int i, int val){
    while (i <= N){
        BIT[i] += val;
        i += i & -i;
    }
}

求前缀和:

那么自然的,从某个位置开始求“之前”的所有区间之和,巧妙的是这种更新规则刚好满足需求,好用好用。具体为啥是这个表达式可以参考博文https://www.topcoder.com/community/data-science/data-science-tutorials/binary-indexed-trees/,代码如下:

代码语言:javascript
复制
public int sum(int i){
    int s = 0;
    while (i > 0){
        s += BIT[i];
        i -= i & -i;
    }
    return s;
}

好了,有了这些我们就可以快速求解区域和,并且较快的更新每个元素。

POJ 1990: MooFest

思路:这种多个元素绑定在一个对象上的题,好像都需要排序啊。嘿嘿,这里还需要构建出两个有用的思想,否则想到题解较困难。

第一点,先来简单模拟下暴力解题过程,很简单,握手的循环结构,总共需要遍历O(n⋅(n−1)/2)O(n \cdot (n - 1) / 2),但在比较时需要求出每个pair的最大V(j),以及它们之间的距离,最后相乘累加即可。

如果按照Volume从小到大排序,每次加入一头牛时,就可以直接得到当前每个Pair的最大Volume是谁,必然是当前这头牛的Volume,那么问题就转换成了,求解Volume * (已加入牛的距离之和)。

这样就可以思考如何快速的求解现在队伍中所有牛和当前牛的距离之和了。

技巧二:把位置看成树状数组的编号,位置只在(1 - 20,000)之间,于是就可以根据位置所以求解在当前牛的左边牛的个数和右边牛的个数以及对应的距离之和。

核心的核心是把位置看成数组下标进行索引。

代码如下:

代码语言:javascript
复制
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

public class Main{
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201708/P1990.txt";

    class Cow implements Comparable<Cow>{
        int pos;
        int volume;
        public Cow(int volume, int pos){
            this.pos = pos;
            this.volume = volume;
        }

        @Override
        public int compareTo(Cow that) {
            return this.volume - that.volume;
        }
    }

    static final int MAX_N = 20000 + 16;
    Cow[] cows;
    void solve() {
        int n = ni();
        init(MAX_N);
        cows = new Cow[n];
        for (int i = 0; i < n; ++i){
            cows[i] = new Cow(ni(), ni());
        }

        Arrays.sort(cows);
        long ans = 0;
        for (int i = 0; i < n; ++i){
            int v = cows[i].volume;
            int x = cows[i].pos;
            long left = sum(count, 1, x - 1);
            long right = sum(count, x + 1, MAX_N);
            ans += v * (left * x - sum (distance, 1, x - 1) + sum (distance, x + 1, MAX_N) - right * x);
            addCount(x, 1);
            addDistance(x, x);
        }
        out.println(ans);
    }

    /***********************Binary index tree********************************/
    long[] count;
    long[] distance;
    int N;
    public void init(int n){
        N = n;
        count = new long[n + 1];
        distance = new long[n + 1];
    }

    public void addCount(int i, int val){
        while (i <= N){
            count[i] += val;
            i += i & -i;
        }
    }

    public void addDistance(int i, int val){
        while (i <= N){
            distance[i] += val;
            i += i & -i;
        }
    }

    public long sum(long[] BIT, int i){ // [1, i]
        long sum = 0;
        while (i > 0){
            sum += BIT[i];
            i -= i & -i;
        }
        return sum;
    }

    public long sum(long[] BIT, int l, int r){ //[l, r]
        return sum(BIT, r) - sum(BIT, l - 1);
    }

    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf++];
    }

    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    private double nd() {
        return Double.parseDouble(ns());
    }

    private char nc() {
        return (char) skip();
    }

    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '
                                    // ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++)
            map[i] = ns(m);
        return map;
    }

    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++)
            a[i] = ni();
        return a;
    }

    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}

POJ 2155: Matrix

好吧,此题被坑了一天,一开始参考博文:http://www.hankcs.com/program/algorithm/poj-2155-matrix.html的方法,接着根据《挑战》P181的公式推导了一个2D的区域更新binary Index tree,但死活出不来,此题其实可以这样。。。

翻转的好处在于不断用1累加求和,最后结果&1,即为开关操作,所以我们要做的无非就是对指定区域进行累加操作。那么给定了(x1,y1)和(x2,y2),如何翻转?

参考博文:http://blog.csdn.net/zxy_snow/article/details/6264135

嗯,图话的很清晰,先(x2,y2)翻一次,接着(x1 - 1, y2)和(x2, y1 - 1)再翻回来,最后多翻的一次(x1 - 1, y1 - 1)再翻一次,用二维BIT完美解决。

代码如下:

代码语言:javascript
复制
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

public class Main{
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201708/P2155.txt";

    int N;
    void solve() {
        int t = ni();
        while (t --> 0){
            N = ni();
            int C = ni();
            BIT2DTree bit = new BIT2DTree();
            for (int i = 0; i < C; ++i){
                char c = nc();
                if (c == 'C'){
                    int x = ni();
                    int y = ni();
                    int l = ni();
                    int r = ni();
                    x ++;
                    y ++;
                    l ++;
                    r ++;
                    bit.add(l , r, 1);
                    bit.add(x - 1, y - 1, 1);
                    bit.add(x - 1, r, 1);
                    bit.add(l, y - 1, 1);
                }
                else{
                    int x = ni();
                    int y = ni();
                    out.println(bit.sum(x, y) & 1);
                }
            }
            out.println();
        }
    }

    /***************Binary Index Tree*******************/

    static final int MAX_N = 1000 + 16;

    class BIT2DTree{
        int[][] bit;

        public BIT2DTree(){
            bit = new int[MAX_N][MAX_N];
        }

        public void add(int x, int y, int val){
            for (int i = x; i <= N; i += i & -i){
                for (int j = y; j <= N; j += j & -j){
                    bit[i][j] += val;
                }
            }
        }

        public int sum(int x, int y){
            int ans = 0;
            for (int i = x; i > 0; i -= i & -i){
                for (int j = y; j > 0; j -= j & -j){
                    ans += bit[i][j];
                }
            }
            return ans;
        }
    }


    //以下代码不能AC
    BIT2DTree bit_XY, bit_X, bit_Y, bit;

    public void init(){
        bit_XY = new BIT2DTree();
        bit_X = new BIT2DTree();
        bit_Y = new BIT2DTree();
        bit = new BIT2DTree();
    }

    public int sum (int x, int y){
        return bit.sum(x, y) + bit_XY.sum(x, y) * x * y + bit_X.sum(x, y) * x + bit_Y.sum(x, y) * y;
    }

    public int sumRange(int x, int y, int l, int r){ //l >= x && r >= y
        return sum(l, r) + sum(x - 1, y - 1) - (sum(l, y - 1) + sum(x - 1, r));
    }

    public void add(int x, int y, int l, int r, int val){
        bit_XY.add(x, y, val);
        bit_XY.add(x, r + 1, -val);
        bit_XY.add(l + 1, y, -val);
        bit_XY.add(l + 1, r + 1, val);

        bit_X.add(x, y, -val * (y - 1));
        bit_X.add(x, r + 1, val * r);
        bit_X.add(l + 1, y, (y - 1) * val);
        bit_X.add(l + 1, r + 1, - val * r);

        bit_Y.add(x, y, -val * (x - 1));
        bit_Y.add(x, r + 1, val * (x - 1));
        bit_Y.add(l + 1, y, val * l);
        bit_Y.add(l + 1, r + 1, - val * l);

        bit.add(x, y, (x - 1) * (y - 1) * val);
        bit.add(x, r + 1, -r * (x - 1) * val);
        bit.add(l + 1, y, -l * (y - 1) * val);
        bit.add(l + 1, r + 1, l * r * val);

    }


    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf++];
    }

    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    private double nd() {
        return Double.parseDouble(ns());
    }

    private char nc() {
        return (char) skip();
    }

    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '
                                    // ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++)
            map[i] = ns(m);
        return map;
    }

    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++)
            a[i] = ni();
        return a;
    }

    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}

POJ 2886: Who Gets the Most Candies?

此题的关键在于快速定位下标,因为每一轮都会有人离开,所以数组下标将不断发生变化,此题利用BIT的sum(i)去定位下标,比如,初始时,给每个位置add(i,1),得到了这样的sum:

代码语言:javascript
复制
1 2 3 4 5 6

如果位置2第一轮被选中,则有
new: 1 1 2 3 4 5
old: 1 2 3 4 5 6

所以BIT的优势在于更新坐标时,能够以log(n)的速度计算更新。

那么问题就转化为定位新坐标下的原始位置,用一个binary Search,去搜一次
比如binarySearch(3)对应的原始坐标为4.

接着就是根据题意更新坐标了,有几个坑点:

1. 负数和正数注意当前长度,因为循环,所以要取模
2. 在负数更新时,要时刻注意当前坐标变小了,所以求解step时,要减去1

当然关于F(P)的求解,可以用艾氏筛选法,边算素数边打表,这样就AC了。

约束个数详解可以参看http://www.hankcs.com/program/algorithm/poj-2886-who-gets-the-most-candies.html

代码如下:

代码语言:javascript
复制
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.HashMap;
import java.util.InputMismatchException;
import java.util.Map;
import java.util.Scanner;

public class Main{
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201708/P2886.txt";

    class Pair{
        String name;
        int step;
        public Pair(String name, int step){
            this.name = name;
            this.step = step;
        }
    }

    Scanner in;
    public Main(){
        in = new Scanner(System.in);
    }


    static final int MAX_N = 500000 + 16;
    void solve() {
        factors();
        while (in.hasNext()){
            int N = in.nextInt();
            int K = in.nextInt();

            /************** 计算 约数个数 ******************/
            int[] p = new int[MAX_N];
            int max = 0;
            int id = 0;

            for (int i = 1; i < N + 1; ++i){
                if (max < table[i]){
                    max = table[i];
                    id = i;
                }
                table[i] = max;
                p[i] = id;
            }

            init(N);
            for (int i = 1; i <= N; ++i) add(i, 1);
            Pair[] ps = new Pair[N];

            for (int i = 0; i < N; ++i){
                String name = in.next();
                int step = in.nextInt();
                ps[i] = new Pair(name, step);
            }

            int index = 0;
            int len = N;
            for (int i = 0; i < p[N]; ++i){
                index = binarySearch(K);
                add(index, -1);
                len --;
                if (len == 0) break;
                int step = ps[index - 1].step;
                if (step < 0){
                    step = -step;
                    K = (sum(index) - 1 + len - (step - 1) % len) % len + 1;
                }
                else{
                    K = (sum(index) - 1 + step) % len + 1;
                }
            }

            System.out.println(ps[index - 1].name + " " + table[N]);
        }

    }

    /******************* Binary Index Tree*******************/
    int[] bit;
    int N;

    public void init(int n){
        bit = new int[MAX_N];
        this.N = n;
    }

    public void add(int i, int val){
        while (i <= N){
            bit[i] += val;
            i += i & -i;
        }
    }

    public int sum(int i){
        int res = 0;
        while (i > 0){
            res += bit[i];
            i -= i & -i;
        }
        return res;
    }

    /******************* binary search *******************/
    public int binarySearch(int key){
        int lf = 1, rt = N;
        while (lf < rt){
            int mid = lf + (rt - lf) / 2;
            if (sum(mid) < key) lf = mid + 1;
            else rt = mid;
        }

        if (sum(lf) == key) return lf;
        return -1;
    }

    // 速度太慢,采用分解为素数的办法
    int[] table = new int[MAX_N];
    public void factors(){
        Arrays.fill(table, 1);
        for (int i = 2; i < MAX_N; ++i){
            if (table[i] == 1){
                for (int j = i; j < MAX_N; j += i){
                    int k = 0;
                    for (int m = j; m % i == 0; m /= i, k++);
                    table[j] *= (k + 1);
                }
            }
        }
    }

    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf++];
    }

    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    private double nd() {
        return Double.parseDouble(ns());
    }

    private char nc() {
        return (char) skip();
    }

    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '
                                    // ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++)
            map[i] = ns(m);
        return map;
    }

    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++)
            a[i] = ni();
        return a;
    }

    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}

POJ 3109: Inner Vertices

此题trick了很久,艰辛。个人认为关键还是在于如何把模拟的思路转成代码,首先需要明确一点,新生成的黑色点,在加入新的图中后,不会改变后续点的生成,并没有连锁反应。

所以,我们可以放心的采用任何一种模拟手段不断地标点,而不用担心标点之后先前遍历过的状态是否还要重新走一遍。

那么问题来了,如何模拟出有效的标点操作呢?采用扫面线算法,从下往上扫描。在扫描的过程中,我们必须得知道扫描线与垂直线相交的个数,这是一个trick,我们可以用BIT解决,它可以在给定扫面平行线的两个最远端点时,快速求出该区间范围内的相交垂直线个数。

BIT的用处:

简单来说,可以把BIT想象成一个动态变化的数组(做了那么多题的一个最大感受),从此题来看,我们最大的难点就是如何判断扫面线与垂直线相交,BIT可以记录横坐标的位置,我们假设扫描的该点上方还有黑点,则可以在此位置上加入:

代码语言:javascript
复制
 add(x , 1)

注意:这样做的前提条件是下方不存在黑点。同理,如果当前横坐标上方不存在黑点了呢?我们可以在该位置x上:

代码语言:javascript
复制
add(x, -1);

注意:这样做的前提条件是该横坐标下方是存在过黑点的。

嘿,这样一来,所有的问题都解决了,因为扫面线是从下往上的,而BIT也是从下往上慢慢构造的,所以在扫描的同时,我已经知道了扫面线扫过后的下方所有信息,而BIT能够记录这些垂直线的变化过程,完美解决。

最后就是预处理和离散化了,比较简单。代码如下:

代码语言:javascript
复制
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;
import java.util.TreeSet;

public class Main{
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201708/P3109.txt";

    int[] x;
    int[] y;
    class Pair implements Comparable<Pair>{
        int x;
        int y;
        public Pair(int x, int y){
            this.x = x;
            this.y = y;
        }
        @Override
        public int compareTo(Pair that) {
            int thiz = this.y;
            int thaz = that.y;
            return thiz == thaz ? this.x - that.x : thiz - thaz;
        }
    }

    Pair[] ps;
    int N;
    void solve() {
        N = ni();
        x = new int[N];
        y = new int[N];

        for (int i = 0; i < N; ++i){
            x[i] = ni();
            y[i] = ni();
        }

        compress(x);
        compress(y);

        ps = new Pair[N];
        for (int i = 0; i < N; ++i){
            ps[i] = new Pair(x[i], y[i]);
        }

        Arrays.sort(ps);
        init();
        doSolve();
    }

    public int compress(int[] arra){
        TreeSet<Integer> set = new TreeSet<Integer>();
        for (int i : arra) set.add(i);
        Integer[] find = set.toArray(new Integer[0]);
        for (int i = 0; i < arra.length; ++i){
            arra[i] = Arrays.binarySearch(find, arra[i]) + 1;
        }
        return find.length;
    }

    int[] l;  //ps[i] 的上方有黑点则为0,否则为1
    int[] r;  //ps[i] 的下方有黑点则为0,否则为1
    static final int MAX_N = 200000 + 16;
    public void init(){
        l = new int[N];
        r = new int[N];
        boolean[] visited = new boolean[MAX_N];
        for (int i = 0; i < N; ++i){
            if (!visited[ps[i].x]){
                l[i] = 1;
                visited[ps[i].x] = true;
            }
            else{
                l[i] = 0;
            }
        }

        visited = new boolean[MAX_N];
        for (int i = N - 1; i >= 0; --i){
            if (!visited[ps[i].x]){
                r[i] = 1;
                visited[ps[i].x] = true;
            }
            else{
                r[i] = 0;
            }
        }
    }

    public void doSolve(){
        //扫面线算法
        long ans = 0;
        for (int i = 0; i < N; ++i){
            int j = i;
            while (j + 1 < N && ps[j].y == ps[j + 1].y) j++; // ps[j].y != ps[j + 1].y
            //j 指向的是同一行最后一个点
            for (int k = i; k <= j; ++k){ //每一个点的上方没有点,则从累积和中删除
                if (r[k] == 1 && l[k] == 0){ //下方有黑点,且上方没有黑点了,则此条线已经不存在了
                    add(ps[k].x, -1);
                }
            }

            if (i == j);
            else{
                ans += sum(ps[j].x - 1) - sum(ps[i].x);
                for (int k = i + 1; k < j; ++k){
                    if (l[k] == 0 && r[k] == 0){ //上方和下方均有黑点的
                        ans --;
                    }
//                  else{
//                      ans --;
//                  }
                }
            }

            //开始标记
            for (int k = i; k <= j; ++k){
                if (r[k] == 0 && l[k] == 1){  //上方有黑点,且下方没有黑点(避免重复计算)
                    add(ps[k].x, 1);
                }
            }
            i = j;
        }
        out.println(ans + N);
    }

    /************************binary indexed tree***************************/
    long[] BIT = new long[MAX_N];

    public void add(int i, int val){
        while (i < MAX_N){
            BIT[i] += val;
            i += i & -i;
        }
    }

    public long sum(int i){
        int res = 0;
        while (i > 0){
            res += BIT[i];
            i -= i & -i;
        }
        return res;
    }

    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf++];
    }

    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    private double nd() {
        return Double.parseDouble(ns());
    }

    private char nc() {
        return (char) skip();
    }

    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '
                                    // ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++)
            map[i] = ns(m);
        return map;
    }

    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++)
            a[i] = ni();
        return a;
    }

    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2017年08月19日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 挑战程序竞赛系列(35):3.3Binary Indexed Tree
    • Binary Indexed Tree简介
      • POJ 1990: MooFest
        • POJ 2155: Matrix
          • POJ 2886: Who Gets the Most Candies?
            • POJ 3109: Inner Vertices
            领券
            问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档