前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >挑战程序竞赛系列(36):3.3线段树和平方分割

挑战程序竞赛系列(36):3.3线段树和平方分割

作者头像
用户1147447
发布2019-05-26 09:29:59
6450
发布2019-05-26 09:29:59
举报
文章被收录于专栏:机器学习入门

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://cloud.tencent.com/developer/article/1434672

挑战程序竞赛系列(36):3.3线段树和平方分割

详细代码可以fork下Github上leetcode项目,不定期更新。

练习题如下:

分桶法和平方分割

具体可以参考《挑战》P183页,这里简单说说思想。

我的理解:空间换时间,举个例子:

代码语言:javascript
复制
1 2 3 4 5 6 7 8 9 10

求指定区间内的最小值

区间 [1, 3]中的最小值为1
区间 [4, 8]中的最小值为4

传统做法,遍历指定区间需要O(n)次,能够求出答案,但由于频繁查询可能需要O(m)次,所以整体时间复杂度为O(nm)次,有没有办法把时间复杂度降低一些?平方分桶法可以降低到O(mn√)O(m\sqrt n)。

说白了,上述情况的每个结点维护自己的信息,分桶法的思想是:

组合几个个体成一个桶,由桶统一维护信息,所以对我们来说,它的呈现形式是多个个体和一个个桶,也就是所谓的空间换时间。

比如上述例子:

代码语言:javascript
复制
{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}

如果给定查询区间[2, 9]

那就意味着要遍历个体元素2和个体元素9,但查询区间中,有几个桶完全包含了该区间,如{3, 4}

那么,我们就可以直接拿出桶维护的最小元素信息

所以遍历时,我们只需要遍历 2个元素 + 3个桶
传统做法需要遍历 8个元素

谁快?

当然分桶法实现比线段树简单,但划分整体比线段树粗暴,所以时间复杂度略慢于线段树。

POJ 2104: K-th Number

思路很简单,根据分桶法,可以把它们放在一个个桶内单独维护,在区间内的桶,因为全部包含,所以排序后可以用二分快速找出答案,而桶不完全包含在区间内的,需要单独计算。整体再采用二分,在所有候选答案中猜出即可。

具体思路可以参考《挑战》P186页,代码如下:

代码语言:javascript
复制
    static final int MAX_N = 100000 + 16;
    static final int B = 1000;
    List<Integer>[] bucket = new ArrayList[MAX_N / B + 1];

    void solve() {
        int n = ni();
        int m = ni();
        int[] sort = new int[n];
        int[] arra = new int[n];

        for (int i = 0; i <= n / B; ++i) bucket[i] = new ArrayList<Integer>();

        for (int i = 0; i < n; ++i){
            arra[i] = ni();
            bucket[i / B].add(arra[i]);
            sort[i] = arra[i];
        }

        for (int i = 0; i < n / B; ++i){
            Collections.sort(bucket[i]);
        }

        Arrays.sort(sort);

        for (int t = 0; t < m; ++t){
            int i = ni();
            int j = ni();
            int k = ni();
            i--;
            j--;

            int lf = -1, rt = n - 1;
            while (rt - lf > 1){
                int l = i;
                int r = j;

                int s = l / B;
                int e = r / B;
                int mid = (lf + rt) / 2;

                int key = sort[mid];

                int x = 0;
                if (e - s <= 1){
                    for (int y = l; y <= r; ++y){
                        if (arra[y] <= key) x++;
                    }
                }
                else{

                    while (l < n && l / B == s){
                        if (arra[l] <= key) x++;
                        l++;
                    }

                    while (r >= 0 && r / B == e){
                        if (arra[r] <= key) x++;
                        r--;
                    }

                    for (int y = s + 1; y < e; ++y){
                        x += binarySearch(bucket[y], key) + 1;
                    }
                }
                if (x < k){
                    lf = mid;
                }
                else{
                    rt = mid;
                }
            }
            out.println(sort[rt]);
        }

    }

    public int binarySearch(List<Integer> aux, int key){
        int lf = 0, rt = aux.size() - 1;
        while (lf < rt){
            int mid = lf + (rt - lf + 1) / 2;
            if (aux.get(mid) > key){
                rt = mid - 1;
            }
            else lf = mid;
        }
        if (aux.get(lf) <= key) return lf;
        return -1;
    }

TLE了,显然此题用分桶法还不够快,因此我们采用线段树来解决,线段树维护的独立个体是自底向上慢慢长大的,所以空间复杂度更高,但速度会更快,代码如下:

代码语言:javascript
复制
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.InputMismatchException;
import java.util.List;

public class Main{
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201708/P2104.txt";

    static final int SIZE = (1 << 18) - 1;
    List<Integer>[] dat = new ArrayList[SIZE];
    int[] A;

    void solve(){
        int N = ni();
        int M = ni();
        A = new int[N];
        int[] sort = new int[N];
        for (int i = 0; i < N; ++i){
            A[i] = ni();
            sort[i] = A[i];
        }
        for (int i = 0; i < dat.length; ++i) dat[i] = new ArrayList<Integer>();

        Arrays.sort(sort);
        init(0, 0, N);

        for (int t = 0; t < M; ++t){
            int i = ni();
            int j = ni();
            int k = ni();
            i--;
            int lf = -1, rt = N - 1;
            while (rt - lf > 1){
                int mid = (lf + rt) / 2;
                int query = query(0, i, j, sort[mid], 0, N);
                if (query < k){
                    lf = mid;
                }
                else{
                    rt = mid;
                }
            }

            out.println(sort[rt]);
        }
    }


    /******************以下是线段树******************/

    /***
     * 区间 [l, r)
     * @param k
     * @param l
     * @param r
     */
    public void init(int k, int l, int r){
        if (r - l == 1){
            dat[k].add(A[l]);
        }
        else{
            int lch = 2 * k + 1;
            int rch = 2 * k + 2;
            init(lch, l, (l + r) / 2); //为了能够准确的划分区间
            init(rch, (l + r) / 2, r);

            merge (dat[lch], dat[rch], dat[k]);
        }
    }

    /**
     * 查询区间 [i, j)
     * 线段树区间 [l, r)
     * @param k
     * @param i
     * @param j
     * @param x
     * @param l
     * @param r
     * @return
     */
    public int query(int k, int i, int j, int x, int l, int r){
        if (j <= l || i >= r) return 0;
        else if (i <= l && j >= r){
            return binarySearch(dat[k], x) + 1;
        }else{
            int ans = 0;
            ans += query(2 * k + 1, i, j, x, l, (l + r) / 2);
            ans += query(2 * k + 2, i, j, x, (l + r) / 2, r);
            return ans;
        }
    }

    public void merge(List<Integer> lch, List<Integer> rch, List<Integer> k){
        int l = 0, r = 0;
        while (l < lch.size() && r < rch.size()){
            if (lch.get(l) <= rch.get(r)){
                k.add(lch.get(l));
                l++;
            }
            else{
                k.add(rch.get(r));
                r++;
            }
        }

        while (l < lch.size()) k.add(lch.get(l++));
        while (r < rch.size()) k.add(rch.get(r++));
    }


    public int binarySearch(List<Integer> aux, int key){
        int lf = 0, rt = aux.size() - 1;
        while (lf < rt){
            int mid = lf + (rt - lf + 1) / 2;
            if (aux.get(mid) > key){
                rt = mid - 1;
            }
            else lf = mid;
        }
        if (aux.get(lf) <= key) return lf;
        return -1;
    }

    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf++];
    }

    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    private double nd() {
        return Double.parseDouble(ns());
    }

    private char nc() {
        return (char) skip();
    }

    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '
                                    // ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++)
            map[i] = ns(m);
        return map;
    }

    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++)
            a[i] = ni();
        return a;
    }

    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}

POJ 3264: Balanced Lineup

水题,思路很直接,关键怎么加快速度,采用分桶法,可以直接参考《挑战》P187页代码。

代码如下:

代码语言:javascript
复制
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

public class Main{
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201708/P3264.txt";

    static final int B = 1000;
    static final int MAX_N = 50000 * 2;
    int[] max = new int[MAX_N / B + 1];
    int[] min = new int[MAX_N / B + 1];

    static final int INF = 1 << 29;
    void solve() {
        int N = ni();
        int Q = ni();

        int[] cows = new int[N];

        Arrays.fill(max, -INF);
        Arrays.fill(min, INF);

        for (int i = 0; i < N; ++i){
            cows[i] = ni();
            max[i / B] = Math.max(max[i / B], cows[i]);
            min[i / B] = Math.min(min[i / B], cows[i]);
        }

        for (int q = 0; q < Q; ++q){
            int i = ni();
            int j = ni();
            i--;
            // [i, j)
            int minHeight = INF;
            int maxHeight = -INF;

            int l = i, r = j;
            while (l < r && l % B != 0){
                minHeight = Math.min(minHeight, cows[l]);
                maxHeight = Math.max(maxHeight, cows[l++]);
            }

            while (l < r && r % B != 0){
                minHeight = Math.min(minHeight, cows[--r]);
                maxHeight = Math.max(maxHeight, cows[r]);
            }

            while (l < r){
                int b = l / B;
                minHeight = Math.min(minHeight, min[b]);
                maxHeight = Math.max(maxHeight, max[b]);
                l += B;
            }

            out.println(maxHeight - minHeight);
        }
    }

    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf++];
    }

    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    private double nd() {
        return Double.parseDouble(ns());
    }

    private char nc() {
        return (char) skip();
    }

    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '
                                    // ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++)
            map[i] = ns(m);
        return map;
    }

    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++)
            a[i] = ni();
        return a;
    }

    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}

POJ 3368: Frequent values

思路:一开始采用分桶法,用Map记录每个元素的个数,最后再拼接答案,但这种做法超时了,后来想了下,因为统计个数这操作太慢,且并没有充分利用原数组的非递减性质。

超时代码如下:

代码语言:javascript
复制
    static final int B = 1000;
    static final int MAX_N = 100000 + 2000;
    Map<Integer, Integer>[] bucket = new HashMap[MAX_N / B];

    void solve() {
        while (true){
            int N = ni();
            if (N == 0) break;

            int Q = ni();
            int[] A = new int[N];

            for (int i = 0; i <= N / B; ++i) bucket[i] = new HashMap<Integer, Integer>();
            for (int i = 0; i < N; ++i){
                A[i] = ni();
                int b = i / B;
                if (!bucket[b].containsKey(A[i])) bucket[b].put(A[i], 0);
                bucket[b].put(A[i], bucket[b].get(A[i]) + 1);
            }

            for (int q = 0; q < Q; ++q){
                int i = ni();
                int j = ni();
                i--;

                int l = i, r = j;
                Map<Integer, Integer> map = new HashMap<Integer, Integer>();
                int max = 0;
                while (l < r && l % B != 0){
                    int key = A[l++];
                    if (!map.containsKey(key)) map.put(key, 0);
                    map.put(key, map.get(key) + 1);
                }

                while (l < r && r % B != 0){
                    int key = A[--r];
                    if (!map.containsKey(key)) map.put(key, 0);
                    map.put(key, map.get(key) + 1);
                }

                while (l < r){
                    int b = l / B;
                    for (int key : bucket[b].keySet()){
                        if (!map.containsKey(key)) map.put(key, 0);
                        map.put(key, map.get(key) + bucket[b].get(key));
                    }
                    l += B;
                }

                for (int key : map.keySet()){
                    max = Math.max(map.get(key), max);
                }

                out.println(max);
            }
        }
    }

此题采用了线段树,我们维护三元组分别表示为{当前区间的最大频次,左边界元素的频次,右边界出现的频次},这样我们就可以从下往上构造每个区间的三元组了,且能够由左孩子和右孩子不断向上合并,用分治的手段解决了统计频次问题。

参考至:http://www.hankcs.com/program/algorithm/poj-3368-frequent-values-am.html

代码如下:

代码语言:javascript
复制
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

public class Main{
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201708/P3368.txt";

    static final int SIZE = (1 << 18) - 1;
    class Pair{
        int max;
        int left;
        int right;


        public Pair(int max, int left, int right){
            this.max = max;
            this.left = left;
            this.right = right;
        }

        @Override
        public String toString() {
            return max + " " + left + " " + right;
        }
    }

    Pair[] dat = new Pair[SIZE];
    int[] A;
    void solve() {
        while (true){
            int N = ni();
            if (N == 0) break;

            int Q = ni();
             A = new int[N];

            for (int i = 0; i < N; ++i){
                A[i] = ni();
            }

            init(0, 0, N);
            for (int q = 0; q < Q; ++q){
                int i = ni();
                int j = ni();
                i--;
                out.println(query(0, i, j, 0, N).max);
            }

        }
    }

    // 区间 [l, r)
    public void init(int k, int l, int r){
        if (r - l == 1){
            dat[k] = new Pair(1, 1, 1);
        }
        else{
            int lch = 2 * k + 1;
            int rch = 2 * k + 2;
            init(lch, l, (l + r) / 2);
            init(rch, (l + r) / 2, r);

            dat[k] = new Pair(0, 0, 0);
            dat[k].max = Math.max(dat[lch].max, dat[rch].max);

            int mid = (l + r) / 2;
            if (A[mid - 1] == A[mid]){
                dat[k].max = Math.max(dat[k].max, dat[lch].right + dat[rch].left);
            }

            if (A[l] == A[mid]){
                dat[k].left = dat[lch].left + dat[rch].left;
            }
            else{
                dat[k].left = dat[lch].left;
            }

            if (A[r - 1] == A[mid - 1]){
                dat[k].right = dat[lch].right + dat[rch].right;
            }
            else{
                dat[k].right = dat[rch].right;
            }
        }
    }

    // 查询
    public Pair query(int k, int i, int j, int l, int r){
        if (j <= l || i >= r) return new Pair(0, 0, 0);
        else if (i <= l && j >= r){
            return dat[k];
        }
        else{
            int mid = (l + r) / 2;
            Pair lch = query(2 * k + 1, i, j, l, mid);
            Pair rch = query(2 * k + 2, i, j, mid, r);

            Pair ans = new Pair(Math.max(lch.max, rch.max), lch.left, rch.right);

            if (A[mid] == A[mid - 1]){
                ans.max = Math.max(ans.max, lch.right + rch.left);
            }

            if (A[l] == A[mid]) ans.left += rch.left;
            if (A[r - 1] == A[mid - 1]) ans.right += lch.right;

            return ans;
        }
    }

    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf++];
    }

    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    private double nd() {
        return Double.parseDouble(ns());
    }

    private char nc() {
        return (char) skip();
    }

    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '
                                    // ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++)
            map[i] = ns(m);
        return map;
    }

    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++)
            a[i] = ni();
        return a;
    }

    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}

POJ 3470: Walls

累觉不爱,待解决……

题解参考博文:http://www.hankcs.com/program/algorithm/poj-3470-walls.html

POJ 1201: Intervals

思路:排序+贪心+归简

首先按照右区间进行从小到达排序,这样开始选第一个区间时,选择最大的几个数,可以证明这种与后续出现区间存在交集的“可能性”最大,接着再考虑第二个区间时,把选择过的元素排除,继续取剩余最大的几个数,这样一来问题规模逐步缩小,完美解决。

如果快速求解区间内有多少个元素被选?BIT或线段树,BIT更简洁易懂。

代码如下:

代码语言:javascript
复制
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

public class SolutionDay24_P1201 {
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201708/P1201.txt";

    class Range implements Comparable<Range>{
        int l;
        int r;
        int c;

        public Range(int l, int r, int c){
            this.l = l;
            this.r = r;
            this.c = c;
        }

        @Override
        public int compareTo(Range that) {
            return this.r - that.r;
        }

        @Override
        public String toString() {
            return l + " " + r + " " + c;
        }
    }

    void solve() {
        int n = ni();
        init();
        Range[] intervals = new Range[n];
        for (int i = 0; i < n; ++i){
            intervals[i] = new Range(ni(), ni(), ni());
        }

        Arrays.sort(intervals);
        boolean[] visited = new boolean[MAX_N];
        int res = 0;
        for (int i = 0; i < n; ++i){
            Range now = intervals[i];
            int picked = sum(now.l, now.r);
            if (picked == 0){
                res += now.c;
                for (int j = 0; j < now.c; ++j){
                    add(now.r - j, 1);
                    visited[now.r - j] = true;
                }
            }
            else{
                int choose = now.c - picked;
                if (choose <= 0) continue;
                res += choose;
                int pos = now.r;
                while (choose > 0){
                    if (visited[pos]){
                        pos --;
                    }
                    else{
                        add(pos, 1);
                        visited[pos] = true;
                        pos --;
                        choose --;
                    }
                }
            }
        }

        out.println(res);

    }


    /*********************BIT************************/
    int MAX_N = 2 * (50000 + 16);
    int[] BIT;

    public void init(){
        BIT = new int[MAX_N];
    }

    public void add(int i, int val){
        while (i <= MAX_N){
            BIT[i] += val;
            i += i & -i;
        }
    }

    public int sum(int i){
        int res = 0;
        while (i > 0){
            res += BIT[i];
            i -= i & -i;
        }
        return res;
    }

    //区间 [l, r]
    public int sum(int l, int r){
        return sum(r) - sum(l - 1);
    }


    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new SolutionDay24_P1201().run();
    }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf++];
    }

    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    private double nd() {
        return Double.parseDouble(ns());
    }

    private char nc() {
        return (char) skip();
    }

    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '
                                    // ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++)
            map[i] = ns(m);
        return map;
    }

    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++)
            a[i] = ni();
        return a;
    }

    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}

UVA 11990: Inversion

求逆序对个数的做法有很多种,可以采用分治合并,BST,线段树等等均可。

可以参考http://blog.csdn.net/u014688145/article/details/72864156

但此题除了求逆序对之外,还会动态删减,这样就需要使用一种好的数据结构去维护一些内部信息,并且在删除时,能够准确的表达出来。

此处我们采用分桶法,或者是一种分治的手段。我们把每个点映射到二维的平面上去(i, Ai),这样一来,逆序对的个数为其左上点的个数和右下点的个数总和。如何表示动态删除?

因为在二维平面上,可以方便的表达某个点的左上和右下区域,所以删除也是很自然的事情,至于点没有了(即不参与计算),可以用(-1,-1)表示。

之所以可以这么做,因为题目给了一下额外性质:

  • permutation,范围是在1 ~ N,且小标在0 ~ N - 1,所以这些值可以方便的映射到二维坐标平面(都不需要离散化处理)
  • 当然在做题时,很重要的一点在于下标自然的排序了,所以这给我们累加逆序对的个数也带来了极大好处,时间复杂度只需要O(n)

JAVA 代码如下:

代码语言:javascript
复制
import java.util.Arrays;
import java.util.Scanner;

public class Main{

    static final int MAX_N = 200000 + 16;
    static final int MAX_M = 200000 + 16;
    static final int BUCKET_SIZE = 450;

    static int[] A;
    static int[] POS;
    static int N, M;

    static class Bucket{
        int count;
        int prefix_sum;
    }
    static Bucket[][] buckets;

    static class Space{
        int[] X;
        int[] Y;

        public Space(){
            X = new int[MAX_N];
            Y = new int[MAX_N];

            Arrays.fill(X, -1);
            Arrays.fill(Y, -1);
        }


        public void add(int x, int y){
            X[y] = x;
            Y[x] = y;
        }

        public void remove(int x, int y){
            X[y] = -1;
            Y[x] = -1;
        }

        public int getX(int y){
            return X[y];
        }

        public int getY(int x){
            return Y[x];
        }
    }
    static Space space;

    public static void update_prefix_sum(int bx, int by){
        int len = buckets[0].length;
        int sum = (bx > 0 ? buckets[bx - 1][by].prefix_sum : 0);
        for (int i = bx; i < len; ++i){
            sum += buckets[i][by].count;
            buckets[i][by].prefix_sum = sum;
        }
    }

    public static void add(int x, int y){
        space.add(x, y);
        int bx = x / BUCKET_SIZE;
        int by = y / BUCKET_SIZE;
        ++buckets[bx][by].count;
        update_prefix_sum(bx, by);
    }

    public static void remove(int x, int y){
        space.remove(x, y);
        int bx = x / BUCKET_SIZE;
        int by = y / BUCKET_SIZE;
        --buckets[bx][by].count;
        update_prefix_sum(bx, by);
    }

    // 统计区间 [0,0] 到 [x, y] 的点的个数
    public static int sum(int x, int y){
        int bx = x / BUCKET_SIZE;
        int by = y / BUCKET_SIZE;

        int count = 0;
        for (int i = 0; i < by; ++i){
            if (bx > 0)
                count += buckets[bx - 1][i].prefix_sum;
        }

        for (int py = by * BUCKET_SIZE; py < y; ++py){
            if (space.getX(py) != -1 && space.getX(py) < x) count++;
        }

        for (int px = bx * BUCKET_SIZE; px < x; ++px){
            if (space.getY(px) != -1 && space.getY(px) < by * BUCKET_SIZE) count++;
        }
        return count;
    }

    public static int sum_inversion(int x, int y){
        int res = 0;
        int intersection = sum(x, y);
        res += sum(x, N) - intersection;
        res += sum(N, y) - intersection;
        return res;
    }

    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        while (in.hasNext()){
            N = in.nextInt();
            M = in.nextInt();
            A = new int[N];
            POS = new int[N];
            for (int i = 0; i < N; ++i){
                A[i] = in.nextInt();
                A[i]--;
                POS[A[i]] = i;
            }

            long res = 0;
            space = new Space();
            buckets = new Bucket[MAX_N / BUCKET_SIZE + 1][MAX_N / BUCKET_SIZE + 1];
            int len = buckets.length;
            for (int i = 0; i < len; ++i){
                for (int j = 0; j < len; ++j){
                    buckets[i][j] = new Bucket();
                }
            }

            for (int i = 0; i < N; ++i){
                add(i, A[i]);
                res += sum_inversion(i, A[i]);
            }

            for (int i = 0; i < M; ++i){
                int m = in.nextInt();
                m--;
                System.out.println(res);
                res -= sum_inversion(POS[m], m);
                remove(POS[m], m);
            }
        }
        in.close();
    }

}

中间利用了一些加速的手段,如前缀和,但总体就是分治的一种迭代版本。。。可惜还是TLE了,改成C++版本,能过,蛋疼。

代码如下:

代码语言:javascript
复制
#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
#define MAX_N 200000 + 16
#define MAX_M 100000 + 16
#define BUCKET_SIZE 450 // sqrt(MAX_N) = 447

int A[MAX_N], N, M;
struct Bucket
{
    int count;      // 内部数字的个数
    int prefix_sum; // 前缀和
}bucket[BUCKET_SIZE][BUCKET_SIZE];

// 平面坐标快速查询
struct Space
{
    int X[MAX_N], Y[MAX_N];

    void insert(const int& x, const int& y)
    {
        X[y] = x;
        Y[x] = y;
    }

    void remove(const int& x, const int& y)
    {
        X[y] = -1;
        Y[x] = -1;
    }

    int getX(const int& y)
    {
        return X[y];
    }

    int getY(const int& x)
    {
        return Y[x];
    }
    void init()
    {
        memset(X, -1, sizeof(X)); memset(Y, -1, sizeof(Y));
    }
} space;

void update_prefix_sum(int bx, int by) 
{
    int sum = (bx > 0 ? bucket[bx - 1][by].prefix_sum : 0);
    for (int i = bx; i < BUCKET_SIZE; ++i)
    {
        sum += bucket[i][by].count;
        bucket[i][by].prefix_sum = sum;
    }
}

// 加入一个点
void add(int x, int y) 
{
    space.insert(x, y);
    int bx = x / BUCKET_SIZE;
    int by = y / BUCKET_SIZE;

    ++bucket[bx][by].count;
    update_prefix_sum(bx, by);
}

// 删除一个点
void remove(int x, int y) 
{
    space.remove(x, y);
    int bx = x / BUCKET_SIZE;
    int by = y / BUCKET_SIZE;

    --bucket[bx][by].count;
    update_prefix_sum(bx, by);
}

// (0,0)与(x,y)围起来的矩形区域的点的个数
int count_sum(int x, int y) 
{
    int block_w = x / BUCKET_SIZE;
    int block_h = y / BUCKET_SIZE;

    int count = 0;
    // 完全在内部的桶
    for (int i = 0; i < block_h; ++i) 
    {
        if (block_w > 0)
        {
            count += bucket[block_w - 1][i].prefix_sum;
        }
    }
    // 其他
    for (int i = block_w * BUCKET_SIZE; i < x; ++i) 
    {
        if (space.getY(i) != -1 && space.getY(i) < block_h * BUCKET_SIZE) count++;
    }
    for (int i = block_h * BUCKET_SIZE; i < y; ++i) 
    {
        if (space.getX(i) != -1 && space.getX(i) < x) count++;
    }
    return count;
}

// (x,y)的左上和右下方块内部点的个数就是逆序数对的个数
int count_inversion(int x, int y) 
{
    int count = 0;
    int intersection = count_sum(x, y);
    count += count_sum(x, N) - intersection;    // 左上
    count += count_sum(N, y) - intersection;    // 右下
    return count;
}

///////////////////////////SubMain//////////////////////////////////
int main(int argc, char *argv[])
{
#ifndef ONLINE_JUDGE
    freopen("in.txt", "r", stdin);
    freopen("out.txt", "w", stdout);
#endif
    while (scanf("%d %d", &N, &M) != EOF)
    {
        space.init();
        memset(bucket, 0, sizeof(bucket));
        for (int i = 0; i < N; ++i) 
        {
            scanf("%d", &A[i]);
            --A[i];
        }
        long long inversion = 0;
        for (int i = 0; i < N; ++i) 
        {
            add(i, A[i]);
            inversion += count_inversion(i, A[i]);
        }
        for (int i = 0; i < M; ++i) 
        {
            int q;
            scanf("%d", &q);
            --q;
            printf("%lld\n", inversion);
            inversion -= count_inversion(space.getX(q), q);
            remove(space.getX(q), q);
        }
    }
#ifndef ONLINE_JUDGE
    fclose(stdin);
    fclose(stdout);
    system("out.txt");
#endif
    return 0;
}

参考了:http://www.hankcs.com/program/algorithm/uva-11990-inversion.html

但count_sum的函数做了一些改动,以自己的方式计算了(0,0)到(x,y)的个数,大同小异,注意一些边界即可。

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2017年08月25日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 挑战程序竞赛系列(36):3.3线段树和平方分割
    • 分桶法和平方分割
      • POJ 2104: K-th Number
        • POJ 3264: Balanced Lineup
          • POJ 3368: Frequent values
            • POJ 3470: Walls
              • POJ 1201: Intervals
                • UVA 11990: Inversion
                领券
                问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档