专栏首页机器学习入门挑战程序竞赛系列(36):3.3线段树和平方分割

挑战程序竞赛系列(36):3.3线段树和平方分割

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u014688145/article/details/77571884

挑战程序竞赛系列(36):3.3线段树和平方分割

详细代码可以fork下Github上leetcode项目,不定期更新。

练习题如下:

分桶法和平方分割

具体可以参考《挑战》P183页,这里简单说说思想。

我的理解:空间换时间,举个例子:

1 2 3 4 5 6 7 8 9 10

求指定区间内的最小值

区间 [1, 3]中的最小值为1
区间 [4, 8]中的最小值为4

传统做法,遍历指定区间需要O(n)次,能够求出答案,但由于频繁查询可能需要O(m)次,所以整体时间复杂度为O(nm)次,有没有办法把时间复杂度降低一些?平方分桶法可以降低到O(mn√)O(m\sqrt n)。

说白了,上述情况的每个结点维护自己的信息,分桶法的思想是:

组合几个个体成一个桶,由桶统一维护信息,所以对我们来说,它的呈现形式是多个个体和一个个桶,也就是所谓的空间换时间。

比如上述例子:

{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}

如果给定查询区间[2, 9]

那就意味着要遍历个体元素2和个体元素9,但查询区间中,有几个桶完全包含了该区间,如{3, 4}

那么,我们就可以直接拿出桶维护的最小元素信息

所以遍历时,我们只需要遍历 2个元素 + 3个桶
传统做法需要遍历 8个元素

谁快?

当然分桶法实现比线段树简单,但划分整体比线段树粗暴,所以时间复杂度略慢于线段树。

POJ 2104: K-th Number

思路很简单,根据分桶法,可以把它们放在一个个桶内单独维护,在区间内的桶,因为全部包含,所以排序后可以用二分快速找出答案,而桶不完全包含在区间内的,需要单独计算。整体再采用二分,在所有候选答案中猜出即可。

具体思路可以参考《挑战》P186页,代码如下:

    static final int MAX_N = 100000 + 16;
    static final int B = 1000;
    List<Integer>[] bucket = new ArrayList[MAX_N / B + 1];

    void solve() {
        int n = ni();
        int m = ni();
        int[] sort = new int[n];
        int[] arra = new int[n];

        for (int i = 0; i <= n / B; ++i) bucket[i] = new ArrayList<Integer>();

        for (int i = 0; i < n; ++i){
            arra[i] = ni();
            bucket[i / B].add(arra[i]);
            sort[i] = arra[i];
        }

        for (int i = 0; i < n / B; ++i){
            Collections.sort(bucket[i]);
        }

        Arrays.sort(sort);

        for (int t = 0; t < m; ++t){
            int i = ni();
            int j = ni();
            int k = ni();
            i--;
            j--;

            int lf = -1, rt = n - 1;
            while (rt - lf > 1){
                int l = i;
                int r = j;

                int s = l / B;
                int e = r / B;
                int mid = (lf + rt) / 2;

                int key = sort[mid];

                int x = 0;
                if (e - s <= 1){
                    for (int y = l; y <= r; ++y){
                        if (arra[y] <= key) x++;
                    }
                }
                else{

                    while (l < n && l / B == s){
                        if (arra[l] <= key) x++;
                        l++;
                    }

                    while (r >= 0 && r / B == e){
                        if (arra[r] <= key) x++;
                        r--;
                    }

                    for (int y = s + 1; y < e; ++y){
                        x += binarySearch(bucket[y], key) + 1;
                    }
                }
                if (x < k){
                    lf = mid;
                }
                else{
                    rt = mid;
                }
            }
            out.println(sort[rt]);
        }

    }

    public int binarySearch(List<Integer> aux, int key){
        int lf = 0, rt = aux.size() - 1;
        while (lf < rt){
            int mid = lf + (rt - lf + 1) / 2;
            if (aux.get(mid) > key){
                rt = mid - 1;
            }
            else lf = mid;
        }
        if (aux.get(lf) <= key) return lf;
        return -1;
    }

TLE了,显然此题用分桶法还不够快,因此我们采用线段树来解决,线段树维护的独立个体是自底向上慢慢长大的,所以空间复杂度更高,但速度会更快,代码如下:

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.InputMismatchException;
import java.util.List;

public class Main{
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201708/P2104.txt";

    static final int SIZE = (1 << 18) - 1;
    List<Integer>[] dat = new ArrayList[SIZE];
    int[] A;

    void solve(){
        int N = ni();
        int M = ni();
        A = new int[N];
        int[] sort = new int[N];
        for (int i = 0; i < N; ++i){
            A[i] = ni();
            sort[i] = A[i];
        }
        for (int i = 0; i < dat.length; ++i) dat[i] = new ArrayList<Integer>();

        Arrays.sort(sort);
        init(0, 0, N);

        for (int t = 0; t < M; ++t){
            int i = ni();
            int j = ni();
            int k = ni();
            i--;
            int lf = -1, rt = N - 1;
            while (rt - lf > 1){
                int mid = (lf + rt) / 2;
                int query = query(0, i, j, sort[mid], 0, N);
                if (query < k){
                    lf = mid;
                }
                else{
                    rt = mid;
                }
            }

            out.println(sort[rt]);
        }
    }


    /******************以下是线段树******************/

    /***
     * 区间 [l, r)
     * @param k
     * @param l
     * @param r
     */
    public void init(int k, int l, int r){
        if (r - l == 1){
            dat[k].add(A[l]);
        }
        else{
            int lch = 2 * k + 1;
            int rch = 2 * k + 2;
            init(lch, l, (l + r) / 2); //为了能够准确的划分区间
            init(rch, (l + r) / 2, r);

            merge (dat[lch], dat[rch], dat[k]);
        }
    }

    /**
     * 查询区间 [i, j)
     * 线段树区间 [l, r)
     * @param k
     * @param i
     * @param j
     * @param x
     * @param l
     * @param r
     * @return
     */
    public int query(int k, int i, int j, int x, int l, int r){
        if (j <= l || i >= r) return 0;
        else if (i <= l && j >= r){
            return binarySearch(dat[k], x) + 1;
        }else{
            int ans = 0;
            ans += query(2 * k + 1, i, j, x, l, (l + r) / 2);
            ans += query(2 * k + 2, i, j, x, (l + r) / 2, r);
            return ans;
        }
    }

    public void merge(List<Integer> lch, List<Integer> rch, List<Integer> k){
        int l = 0, r = 0;
        while (l < lch.size() && r < rch.size()){
            if (lch.get(l) <= rch.get(r)){
                k.add(lch.get(l));
                l++;
            }
            else{
                k.add(rch.get(r));
                r++;
            }
        }

        while (l < lch.size()) k.add(lch.get(l++));
        while (r < rch.size()) k.add(rch.get(r++));
    }


    public int binarySearch(List<Integer> aux, int key){
        int lf = 0, rt = aux.size() - 1;
        while (lf < rt){
            int mid = lf + (rt - lf + 1) / 2;
            if (aux.get(mid) > key){
                rt = mid - 1;
            }
            else lf = mid;
        }
        if (aux.get(lf) <= key) return lf;
        return -1;
    }

    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf++];
    }

    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    private double nd() {
        return Double.parseDouble(ns());
    }

    private char nc() {
        return (char) skip();
    }

    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '
                                    // ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++)
            map[i] = ns(m);
        return map;
    }

    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++)
            a[i] = ni();
        return a;
    }

    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}

POJ 3264: Balanced Lineup

水题,思路很直接,关键怎么加快速度,采用分桶法,可以直接参考《挑战》P187页代码。

代码如下:

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

public class Main{
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201708/P3264.txt";

    static final int B = 1000;
    static final int MAX_N = 50000 * 2;
    int[] max = new int[MAX_N / B + 1];
    int[] min = new int[MAX_N / B + 1];

    static final int INF = 1 << 29;
    void solve() {
        int N = ni();
        int Q = ni();

        int[] cows = new int[N];

        Arrays.fill(max, -INF);
        Arrays.fill(min, INF);

        for (int i = 0; i < N; ++i){
            cows[i] = ni();
            max[i / B] = Math.max(max[i / B], cows[i]);
            min[i / B] = Math.min(min[i / B], cows[i]);
        }

        for (int q = 0; q < Q; ++q){
            int i = ni();
            int j = ni();
            i--;
            // [i, j)
            int minHeight = INF;
            int maxHeight = -INF;

            int l = i, r = j;
            while (l < r && l % B != 0){
                minHeight = Math.min(minHeight, cows[l]);
                maxHeight = Math.max(maxHeight, cows[l++]);
            }

            while (l < r && r % B != 0){
                minHeight = Math.min(minHeight, cows[--r]);
                maxHeight = Math.max(maxHeight, cows[r]);
            }

            while (l < r){
                int b = l / B;
                minHeight = Math.min(minHeight, min[b]);
                maxHeight = Math.max(maxHeight, max[b]);
                l += B;
            }

            out.println(maxHeight - minHeight);
        }
    }

    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf++];
    }

    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    private double nd() {
        return Double.parseDouble(ns());
    }

    private char nc() {
        return (char) skip();
    }

    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '
                                    // ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++)
            map[i] = ns(m);
        return map;
    }

    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++)
            a[i] = ni();
        return a;
    }

    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}

POJ 3368: Frequent values

思路:一开始采用分桶法,用Map记录每个元素的个数,最后再拼接答案,但这种做法超时了,后来想了下,因为统计个数这操作太慢,且并没有充分利用原数组的非递减性质。

超时代码如下:

    static final int B = 1000;
    static final int MAX_N = 100000 + 2000;
    Map<Integer, Integer>[] bucket = new HashMap[MAX_N / B];

    void solve() {
        while (true){
            int N = ni();
            if (N == 0) break;

            int Q = ni();
            int[] A = new int[N];

            for (int i = 0; i <= N / B; ++i) bucket[i] = new HashMap<Integer, Integer>();
            for (int i = 0; i < N; ++i){
                A[i] = ni();
                int b = i / B;
                if (!bucket[b].containsKey(A[i])) bucket[b].put(A[i], 0);
                bucket[b].put(A[i], bucket[b].get(A[i]) + 1);
            }

            for (int q = 0; q < Q; ++q){
                int i = ni();
                int j = ni();
                i--;

                int l = i, r = j;
                Map<Integer, Integer> map = new HashMap<Integer, Integer>();
                int max = 0;
                while (l < r && l % B != 0){
                    int key = A[l++];
                    if (!map.containsKey(key)) map.put(key, 0);
                    map.put(key, map.get(key) + 1);
                }

                while (l < r && r % B != 0){
                    int key = A[--r];
                    if (!map.containsKey(key)) map.put(key, 0);
                    map.put(key, map.get(key) + 1);
                }

                while (l < r){
                    int b = l / B;
                    for (int key : bucket[b].keySet()){
                        if (!map.containsKey(key)) map.put(key, 0);
                        map.put(key, map.get(key) + bucket[b].get(key));
                    }
                    l += B;
                }

                for (int key : map.keySet()){
                    max = Math.max(map.get(key), max);
                }

                out.println(max);
            }
        }
    }

此题采用了线段树,我们维护三元组分别表示为{当前区间的最大频次,左边界元素的频次,右边界出现的频次},这样我们就可以从下往上构造每个区间的三元组了,且能够由左孩子和右孩子不断向上合并,用分治的手段解决了统计频次问题。

参考至:http://www.hankcs.com/program/algorithm/poj-3368-frequent-values-am.html

代码如下:

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

public class Main{
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201708/P3368.txt";

    static final int SIZE = (1 << 18) - 1;
    class Pair{
        int max;
        int left;
        int right;


        public Pair(int max, int left, int right){
            this.max = max;
            this.left = left;
            this.right = right;
        }

        @Override
        public String toString() {
            return max + " " + left + " " + right;
        }
    }

    Pair[] dat = new Pair[SIZE];
    int[] A;
    void solve() {
        while (true){
            int N = ni();
            if (N == 0) break;

            int Q = ni();
             A = new int[N];

            for (int i = 0; i < N; ++i){
                A[i] = ni();
            }

            init(0, 0, N);
            for (int q = 0; q < Q; ++q){
                int i = ni();
                int j = ni();
                i--;
                out.println(query(0, i, j, 0, N).max);
            }

        }
    }

    // 区间 [l, r)
    public void init(int k, int l, int r){
        if (r - l == 1){
            dat[k] = new Pair(1, 1, 1);
        }
        else{
            int lch = 2 * k + 1;
            int rch = 2 * k + 2;
            init(lch, l, (l + r) / 2);
            init(rch, (l + r) / 2, r);

            dat[k] = new Pair(0, 0, 0);
            dat[k].max = Math.max(dat[lch].max, dat[rch].max);

            int mid = (l + r) / 2;
            if (A[mid - 1] == A[mid]){
                dat[k].max = Math.max(dat[k].max, dat[lch].right + dat[rch].left);
            }

            if (A[l] == A[mid]){
                dat[k].left = dat[lch].left + dat[rch].left;
            }
            else{
                dat[k].left = dat[lch].left;
            }

            if (A[r - 1] == A[mid - 1]){
                dat[k].right = dat[lch].right + dat[rch].right;
            }
            else{
                dat[k].right = dat[rch].right;
            }
        }
    }

    // 查询
    public Pair query(int k, int i, int j, int l, int r){
        if (j <= l || i >= r) return new Pair(0, 0, 0);
        else if (i <= l && j >= r){
            return dat[k];
        }
        else{
            int mid = (l + r) / 2;
            Pair lch = query(2 * k + 1, i, j, l, mid);
            Pair rch = query(2 * k + 2, i, j, mid, r);

            Pair ans = new Pair(Math.max(lch.max, rch.max), lch.left, rch.right);

            if (A[mid] == A[mid - 1]){
                ans.max = Math.max(ans.max, lch.right + rch.left);
            }

            if (A[l] == A[mid]) ans.left += rch.left;
            if (A[r - 1] == A[mid - 1]) ans.right += lch.right;

            return ans;
        }
    }

    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf++];
    }

    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    private double nd() {
        return Double.parseDouble(ns());
    }

    private char nc() {
        return (char) skip();
    }

    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '
                                    // ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++)
            map[i] = ns(m);
        return map;
    }

    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++)
            a[i] = ni();
        return a;
    }

    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}

POJ 3470: Walls

累觉不爱,待解决……

题解参考博文:http://www.hankcs.com/program/algorithm/poj-3470-walls.html

POJ 1201: Intervals

思路:排序+贪心+归简

首先按照右区间进行从小到达排序,这样开始选第一个区间时,选择最大的几个数,可以证明这种与后续出现区间存在交集的“可能性”最大,接着再考虑第二个区间时,把选择过的元素排除,继续取剩余最大的几个数,这样一来问题规模逐步缩小,完美解决。

如果快速求解区间内有多少个元素被选?BIT或线段树,BIT更简洁易懂。

代码如下:

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

public class SolutionDay24_P1201 {
    InputStream is;
    PrintWriter out;
    String INPUT = "./data/judge/201708/P1201.txt";

    class Range implements Comparable<Range>{
        int l;
        int r;
        int c;

        public Range(int l, int r, int c){
            this.l = l;
            this.r = r;
            this.c = c;
        }

        @Override
        public int compareTo(Range that) {
            return this.r - that.r;
        }

        @Override
        public String toString() {
            return l + " " + r + " " + c;
        }
    }

    void solve() {
        int n = ni();
        init();
        Range[] intervals = new Range[n];
        for (int i = 0; i < n; ++i){
            intervals[i] = new Range(ni(), ni(), ni());
        }

        Arrays.sort(intervals);
        boolean[] visited = new boolean[MAX_N];
        int res = 0;
        for (int i = 0; i < n; ++i){
            Range now = intervals[i];
            int picked = sum(now.l, now.r);
            if (picked == 0){
                res += now.c;
                for (int j = 0; j < now.c; ++j){
                    add(now.r - j, 1);
                    visited[now.r - j] = true;
                }
            }
            else{
                int choose = now.c - picked;
                if (choose <= 0) continue;
                res += choose;
                int pos = now.r;
                while (choose > 0){
                    if (visited[pos]){
                        pos --;
                    }
                    else{
                        add(pos, 1);
                        visited[pos] = true;
                        pos --;
                        choose --;
                    }
                }
            }
        }

        out.println(res);

    }


    /*********************BIT************************/
    int MAX_N = 2 * (50000 + 16);
    int[] BIT;

    public void init(){
        BIT = new int[MAX_N];
    }

    public void add(int i, int val){
        while (i <= MAX_N){
            BIT[i] += val;
            i += i & -i;
        }
    }

    public int sum(int i){
        int res = 0;
        while (i > 0){
            res += BIT[i];
            i -= i & -i;
        }
        return res;
    }

    //区间 [l, r]
    public int sum(int l, int r){
        return sum(r) - sum(l - 1);
    }


    void run() throws Exception {
        is = oj ? System.in : new FileInputStream(new File(INPUT));
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        tr(System.currentTimeMillis() - s + "ms");
    }

    public static void main(String[] args) throws Exception {
        new SolutionDay24_P1201().run();
    }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    private int readByte() {
        if (lenbuf == -1)
            throw new InputMismatchException();
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (lenbuf <= 0)
                return -1;
        }
        return inbuf[ptrbuf++];
    }

    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b))
            ;
        return b;
    }

    private double nd() {
        return Double.parseDouble(ns());
    }

    private char nc() {
        return (char) skip();
    }

    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '
                                    // ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++)
            map[i] = ns(m);
        return map;
    }

    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++)
            a[i] = ni();
        return a;
    }

    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
            ;
        if (b == '-') {
            minus = true;
            b = readByte();
        }

        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

    private void tr(Object... o) {
        if (!oj)
            System.out.println(Arrays.deepToString(o));
    }
}

UVA 11990: Inversion

求逆序对个数的做法有很多种,可以采用分治合并,BST,线段树等等均可。

可以参考http://blog.csdn.net/u014688145/article/details/72864156

但此题除了求逆序对之外,还会动态删减,这样就需要使用一种好的数据结构去维护一些内部信息,并且在删除时,能够准确的表达出来。

此处我们采用分桶法,或者是一种分治的手段。我们把每个点映射到二维的平面上去(i, A[i]),这样一来,逆序对的个数为其左上点的个数和右下点的个数总和。如何表示动态删除?

因为在二维平面上,可以方便的表达某个点的左上和右下区域,所以删除也是很自然的事情,至于点没有了(即不参与计算),可以用(-1,-1)表示。

之所以可以这么做,因为题目给了一下额外性质:

  • permutation,范围是在1 ~ N,且小标在0 ~ N - 1,所以这些值可以方便的映射到二维坐标平面(都不需要离散化处理)
  • 当然在做题时,很重要的一点在于下标自然的排序了,所以这给我们累加逆序对的个数也带来了极大好处,时间复杂度只需要O(n)

JAVA 代码如下:

import java.util.Arrays;
import java.util.Scanner;

public class Main{

    static final int MAX_N = 200000 + 16;
    static final int MAX_M = 200000 + 16;
    static final int BUCKET_SIZE = 450;

    static int[] A;
    static int[] POS;
    static int N, M;

    static class Bucket{
        int count;
        int prefix_sum;
    }
    static Bucket[][] buckets;

    static class Space{
        int[] X;
        int[] Y;

        public Space(){
            X = new int[MAX_N];
            Y = new int[MAX_N];

            Arrays.fill(X, -1);
            Arrays.fill(Y, -1);
        }


        public void add(int x, int y){
            X[y] = x;
            Y[x] = y;
        }

        public void remove(int x, int y){
            X[y] = -1;
            Y[x] = -1;
        }

        public int getX(int y){
            return X[y];
        }

        public int getY(int x){
            return Y[x];
        }
    }
    static Space space;

    public static void update_prefix_sum(int bx, int by){
        int len = buckets[0].length;
        int sum = (bx > 0 ? buckets[bx - 1][by].prefix_sum : 0);
        for (int i = bx; i < len; ++i){
            sum += buckets[i][by].count;
            buckets[i][by].prefix_sum = sum;
        }
    }

    public static void add(int x, int y){
        space.add(x, y);
        int bx = x / BUCKET_SIZE;
        int by = y / BUCKET_SIZE;
        ++buckets[bx][by].count;
        update_prefix_sum(bx, by);
    }

    public static void remove(int x, int y){
        space.remove(x, y);
        int bx = x / BUCKET_SIZE;
        int by = y / BUCKET_SIZE;
        --buckets[bx][by].count;
        update_prefix_sum(bx, by);
    }

    // 统计区间 [0,0] 到 [x, y] 的点的个数
    public static int sum(int x, int y){
        int bx = x / BUCKET_SIZE;
        int by = y / BUCKET_SIZE;

        int count = 0;
        for (int i = 0; i < by; ++i){
            if (bx > 0)
                count += buckets[bx - 1][i].prefix_sum;
        }

        for (int py = by * BUCKET_SIZE; py < y; ++py){
            if (space.getX(py) != -1 && space.getX(py) < x) count++;
        }

        for (int px = bx * BUCKET_SIZE; px < x; ++px){
            if (space.getY(px) != -1 && space.getY(px) < by * BUCKET_SIZE) count++;
        }
        return count;
    }

    public static int sum_inversion(int x, int y){
        int res = 0;
        int intersection = sum(x, y);
        res += sum(x, N) - intersection;
        res += sum(N, y) - intersection;
        return res;
    }

    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        while (in.hasNext()){
            N = in.nextInt();
            M = in.nextInt();
            A = new int[N];
            POS = new int[N];
            for (int i = 0; i < N; ++i){
                A[i] = in.nextInt();
                A[i]--;
                POS[A[i]] = i;
            }

            long res = 0;
            space = new Space();
            buckets = new Bucket[MAX_N / BUCKET_SIZE + 1][MAX_N / BUCKET_SIZE + 1];
            int len = buckets.length;
            for (int i = 0; i < len; ++i){
                for (int j = 0; j < len; ++j){
                    buckets[i][j] = new Bucket();
                }
            }

            for (int i = 0; i < N; ++i){
                add(i, A[i]);
                res += sum_inversion(i, A[i]);
            }

            for (int i = 0; i < M; ++i){
                int m = in.nextInt();
                m--;
                System.out.println(res);
                res -= sum_inversion(POS[m], m);
                remove(POS[m], m);
            }
        }
        in.close();
    }

}

中间利用了一些加速的手段,如前缀和,但总体就是分治的一种迭代版本。。。可惜还是TLE了,改成C++版本,能过,蛋疼。

代码如下:

#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
#define MAX_N 200000 + 16
#define MAX_M 100000 + 16
#define BUCKET_SIZE 450 // sqrt(MAX_N) = 447

int A[MAX_N], N, M;
struct Bucket
{
    int count;      // 内部数字的个数
    int prefix_sum; // 前缀和
}bucket[BUCKET_SIZE][BUCKET_SIZE];

// 平面坐标快速查询
struct Space
{
    int X[MAX_N], Y[MAX_N];

    void insert(const int& x, const int& y)
    {
        X[y] = x;
        Y[x] = y;
    }

    void remove(const int& x, const int& y)
    {
        X[y] = -1;
        Y[x] = -1;
    }

    int getX(const int& y)
    {
        return X[y];
    }

    int getY(const int& x)
    {
        return Y[x];
    }
    void init()
    {
        memset(X, -1, sizeof(X)); memset(Y, -1, sizeof(Y));
    }
} space;

void update_prefix_sum(int bx, int by) 
{
    int sum = (bx > 0 ? bucket[bx - 1][by].prefix_sum : 0);
    for (int i = bx; i < BUCKET_SIZE; ++i)
    {
        sum += bucket[i][by].count;
        bucket[i][by].prefix_sum = sum;
    }
}

// 加入一个点
void add(int x, int y) 
{
    space.insert(x, y);
    int bx = x / BUCKET_SIZE;
    int by = y / BUCKET_SIZE;

    ++bucket[bx][by].count;
    update_prefix_sum(bx, by);
}

// 删除一个点
void remove(int x, int y) 
{
    space.remove(x, y);
    int bx = x / BUCKET_SIZE;
    int by = y / BUCKET_SIZE;

    --bucket[bx][by].count;
    update_prefix_sum(bx, by);
}

// (0,0)与(x,y)围起来的矩形区域的点的个数
int count_sum(int x, int y) 
{
    int block_w = x / BUCKET_SIZE;
    int block_h = y / BUCKET_SIZE;

    int count = 0;
    // 完全在内部的桶
    for (int i = 0; i < block_h; ++i) 
    {
        if (block_w > 0)
        {
            count += bucket[block_w - 1][i].prefix_sum;
        }
    }
    // 其他
    for (int i = block_w * BUCKET_SIZE; i < x; ++i) 
    {
        if (space.getY(i) != -1 && space.getY(i) < block_h * BUCKET_SIZE) count++;
    }
    for (int i = block_h * BUCKET_SIZE; i < y; ++i) 
    {
        if (space.getX(i) != -1 && space.getX(i) < x) count++;
    }
    return count;
}

// (x,y)的左上和右下方块内部点的个数就是逆序数对的个数
int count_inversion(int x, int y) 
{
    int count = 0;
    int intersection = count_sum(x, y);
    count += count_sum(x, N) - intersection;    // 左上
    count += count_sum(N, y) - intersection;    // 右下
    return count;
}

///////////////////////////SubMain//////////////////////////////////
int main(int argc, char *argv[])
{
#ifndef ONLINE_JUDGE
    freopen("in.txt", "r", stdin);
    freopen("out.txt", "w", stdout);
#endif
    while (scanf("%d %d", &N, &M) != EOF)
    {
        space.init();
        memset(bucket, 0, sizeof(bucket));
        for (int i = 0; i < N; ++i) 
        {
            scanf("%d", &A[i]);
            --A[i];
        }
        long long inversion = 0;
        for (int i = 0; i < N; ++i) 
        {
            add(i, A[i]);
            inversion += count_inversion(i, A[i]);
        }
        for (int i = 0; i < M; ++i) 
        {
            int q;
            scanf("%d", &q);
            --q;
            printf("%lld\n", inversion);
            inversion -= count_inversion(space.getX(q), q);
            remove(space.getX(q), q);
        }
    }
#ifndef ONLINE_JUDGE
    fclose(stdin);
    fclose(stdout);
    system("out.txt");
#endif
    return 0;
}

参考了:http://www.hankcs.com/program/algorithm/uva-11990-inversion.html

但count_sum的函数做了一些改动,以自己的方式计算了(0,0)到(x,y)的个数,大同小异,注意一些边界即可。

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 挑战程序竞赛系列(21):3.2反转

    版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.n...

    用户1147447
  • 挑战程序竞赛系列(26):3.5二分图匹配(1)

    版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.n...

    用户1147447
  • 挑战程序竞赛系列(28):3.5最小费用流

    版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.n...

    用户1147447
  • 一遍记住Java常用的八种排序算法与代码实现

    (如果每次比较都交换,那么就是交换排序;如果每次比较完一个循环再交换,就是简单选择排序。)

    田维常
  • 你必须知道的指针基础-7.void指针与函数指针

      void *表示一个“不知道类型”的指针,也就不知道从这个指针地址开始多少字节为一个数据。和用int表示指针异曲同工,只是更明确是“指针”。

    Edison Zhou
  • ICPC Asia Shenyang 2019 Dudu's maze

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。

    用户2965768
  • LeetCode 第 210 场周赛 解题报告

    那么在遍历过程中,栈中元素数量的最大值即为答案。栈中的(可以理解为还没遍历到匹配的),即那些嵌套的(。

    ACM算法日常
  • LeetCode 164. Maximum Gap (排序)

    题解:首先,当然我们可以用快排,排完序之后,遍历一遍数组,就能得到答案了。但是快速排序的效率是O(n* logn),不是题目要求的线性效率,也就是O(n)的效率...

    ShenduCC
  • 图论--拓扑排序--判断一个图能否被拓扑排序

    拓扑排序的实现条件,以及结合应用场景,我们都能得到拓扑排序适用于DAG图(Directed Acyclic Graph简称DAG)有向无环图, 根据关系我们能得...

    风骨散人Chiam
  • Educational Codeforces Round 67 (Rated for Div. 2) A~E 贪心,构造,线段树,树的子树

    Educational Codeforces Round 67 (Rated for Div. 2)

    用户2965768

扫码关注云+社区

领取腾讯云代金券