前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >使用栈的记忆化搜索来加速子集和算法

使用栈的记忆化搜索来加速子集和算法

作者头像
算法之名
发布2020-12-02 09:56:22
4500
发布2020-12-02 09:56:22
举报
文章被收录于专栏:算法之名算法之名

所谓子集和就是在一个数组中找出它的子集,使得该子集的和等于某个固定值。

一般我们都是使用递归加回溯的方式来处理的,代码如下(此处我们只找出一组满足的条件即可)

代码语言:javascript
复制
public class SubSet {

    private List<Integer> list = new ArrayList<>();   //用于存放求取子集中的元素
    @Getter
    private List<Integer> res = new ArrayList<>();

    //求取数组列表中元素和
    public int getSum(List<Integer> list) {
        int sum = 0;
        for(int i = 0;i < list.size();i++)
            sum += list.get(i);
        return sum;
    }

    public void getSubSet(int[] A, int m, int step) {
        if (res.size() > 0) {
            return;
        }
        while(step < A.length) {
            list.add(A[step]);
            if (getSum(list) == m) {
                if (getSum(res) == 0) {
                    res.addAll(list);
                }
            }
            step++;
            getSubSet(A, m, step);
            list.remove(list.size() - 1);   //回溯执行语句,删除列表最后一个元素
        }
    }

    public static void main(String[] args) {
        SubSet test = new SubSet();
        int[] A = new int[6];
        for(int i = 0;i < 6;i++) {
            A[i] = i + 1;
        }
        test.getSubSet(A, 8, 0);
        System.out.println(test.getRes());
    }
}

运行结果

代码语言:javascript
复制
[1, 2, 5]

但是这个算法的时间复杂度非常高,是NP级别的。如果数据量比较大的时候,将很难完成运算。

现在我们用栈和哈希缓存来加速这个算法。主要是缓存计算结果,不用每次都去getSum中把list的和算一遍。其思想主要是记忆化搜索,可以参考本人这篇博客动态规划、回溯、贪心,分治

代码语言:javascript
复制
public class SubSet {

    private List<Integer> list = new ArrayList<>();   //用于存放求取子集中的元素
    @Getter
    private List<Integer> res = new ArrayList<>();
    private Deque<Integer> deque = new ArrayDeque<>();
    private Map<String,Integer> map = new HashMap<>();

    //求取数组列表中元素和
    public int getSum(List<Integer> list) {
        int sum = 0;
        for(int i = 0;i < list.size();i++)
            sum += list.get(i);
        return sum;
    }

    public void getSubSet(int[] A, int m, int step) {
        if (res.size() > 0) {
            return;
        }
        while(step < A.length) {
            list.add(A[step]);
            if (!map.containsKey(deque.toString())) {
                int sum = getSum(list);
                deque.push(A[step]);
                map.put(deque.toString(),sum);
                if (sum == m) {
                    if (getSum(res) == 0) {
                        res.addAll(list);
                    }
                }
            }else {
                int sum = map.get(deque.toString()) + A[step];
                deque.push(A[step]);
                map.put(deque.toString(),sum);
                if (sum == m) {
                    if (getSum(res) == 0) {
                        res.addAll(list);
                    }
                }
            }
            step++;
            getSubSet(A, m, step);
            list.remove(list.size() - 1);   //回溯执行语句,删除列表最后一个元素
            deque.pop();
        }
    }

    public static void main(String[] args) {
        SubSet test = new SubSet();
        int[] A = new int[6];
        for(int i = 0;i < 6;i++) {
            A[i] = i + 1;
        }
        test.getSubSet(A, 8, 0);
        System.out.println(test.getRes());
    }
}

运算结果

代码语言:javascript
复制
[1, 2, 5]

但C#无法满足获取栈的值,只能获取栈的类型,如果我们用遍历的方式去获取栈的值又回到了以前NP级的时间复杂度,故直接使用数字来做哈希表的键。内容如下

代码语言:javascript
复制
using System;
using System.Collections.Generic;
using System.Collections;
using System.Text.RegularExpressions;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace ConsoleApplication1
{
    class Program
    {
        private class Oranize
        {
            public List<decimal> array = new List<decimal>();
            public List<decimal> res = new List<decimal>();
            public Stack<decimal> stack = new Stack<decimal>();
            public Hashtable table = new Hashtable();
            public decimal index = 0;

            public decimal getSum(List<decimal> list)
            {
                decimal sum = 0;
                for (int i = 0; i < list.Count; i++)
                {
                    sum += list[i];
                }
                return sum;
            }

            public String stackValue(Stack<decimal> stack)
            {
                StringBuilder sb = new StringBuilder();
                foreach (decimal s in stack)
                {
                    sb.Append(s.ToString());
                }
                return sb.ToString();
            }

            public void org(decimal[] arr,decimal all, int step)
            {
                if (res.Count > 0)
                {
                    return;
                }
                while (step < arr.Length)
                {
                    array.Add(arr[step]);                    
                    if (!table.ContainsKey(index.ToString()))
                    {
                        decimal sum = getSum(array);
                        stack.Push(index);
                        table.Add(stack.Peek().ToString(), sum);
                        if (sum == all)
                        {
                            if (getSum(res) == 0)
                            {
                                foreach (decimal a in array)
                                {
                                    res.Add(a);
                                }
                            }
                        }
                    }
                    else
                    {
                        decimal sum = 0;
                        if (stack.Count > 0)
                        {
                            sum = Convert.ToDecimal(table[stack.Peek().ToString()]) + arr[step];
                        }
                        else
                        {
                            sum = Convert.ToDecimal(table["0"]) + arr[step];
                        }
                        index++;
                        stack.Push(index);
                        if (table.ContainsKey(stack.Peek().ToString()))
                        {
                            table.Remove(stack.Peek().ToString());
                        }
                        table.Add(stack.Peek().ToString(), sum);
                        if (sum == all)
                        {
                            if (getSum(res) == 0)
                            {
                                foreach (decimal a in array)
                                {
                                    res.Add(a);
                                }
                            }
                        }
                    }
                    step++;
                    org(arr, all, step);
                    array.RemoveAt(array.Count - 1);
                    stack.Pop();
                }
            }
        }
        static void Main(string[] args)
        {
            decimal[] A = new decimal[6];
            for (int i = 0; i < 6; i++)
            {
                A[i] = i + 1;
            }
            Oranize oranize = new Oranize();
            oranize.org(A, 8, 0);

            foreach (decimal r in oranize.res)
            {
                Console.Write(r + ",");
            }
            Console.ReadLine();
        }
    }
}

这里我们可以看到如果使用stackValue来获取栈的各个值的字符串是不可取的,同样会非常慢。

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档