您现在的位置是:首页 > 文章详情

使用栈的记忆化搜索来加速子集和算法

日期:2020-12-01点击:418

所谓子集和就是在一个数组中找出它的子集,使得该子集的和等于某个固定值。

一般我们都是使用递归加回溯的方式来处理的,代码如下(此处我们只找出一组满足的条件即可)

 public class SubSet { private List<Integer> list = new ArrayList<>(); //用于存放求取子集中的元素  @Getter  private List<Integer> res = new ArrayList<>();   //求取数组列表中元素和  public int getSum(List<Integer> list) { int sum = 0;  for(int i = 0;i < list.size();i++) sum += list.get(i);  return sum;  } public void getSubSet(int[] A, int m, int step) { if (res.size() > 0) { return;  } while(step < A.length) { list.add(A[step]);  if (getSum(list) == m) { if (getSum(res) == 0) { res.addAll(list);  } } step++;  getSubSet(A, m, step);  list.remove(list.size() - 1); //回溯执行语句,删除列表最后一个元素  } } public static void main(String[] args) { SubSet test = new SubSet();  int[] A = new int[6];  for(int i = 0;i < 6;i++) { A[i] = i + 1;  } test.getSubSet(A, 8, 0);  System.out.println(test.getRes());  } }

运行结果

 [1, 2, 5]

但是这个算法的时间复杂度非常高,是NP级别的。如果数据量比较大的时候,将很难完成运算。

现在我们用栈和哈希缓存来加速这个算法。主要是缓存计算结果,不用每次都去getSum中把list的和算一遍。其思想主要是记忆化搜索,可以参考本人这篇博客动态规划、回溯、贪心,分治

 public class SubSet { private List<Integer> list = new ArrayList<>(); //用于存放求取子集中的元素  @Getter  private List<Integer> res = new ArrayList<>();  private Deque<Integer> deque = new ArrayDeque<>();  private Map<String,Integer> map = new HashMap<>();   //求取数组列表中元素和  public int getSum(List<Integer> list) { int sum = 0;  for(int i = 0;i < list.size();i++) sum += list.get(i);  return sum;  } public void getSubSet(int[] A, int m, int step) { if (res.size() > 0) { return;  } while(step < A.length) { list.add(A[step]);  if (!map.containsKey(deque.toString())) { int sum = getSum(list);  deque.push(A[step]);  map.put(deque.toString(),sum);  if (sum == m) { if (getSum(res) == 0) { res.addAll(list);  } } }else { int sum = map.get(deque.toString()) + A[step];  deque.push(A[step]);  map.put(deque.toString(),sum);  if (sum == m) { if (getSum(res) == 0) { res.addAll(list);  } } } step++;  getSubSet(A, m, step);  list.remove(list.size() - 1); //回溯执行语句,删除列表最后一个元素  deque.pop();  } } public static void main(String[] args) { SubSet test = new SubSet();  int[] A = new int[6];  for(int i = 0;i < 6;i++) { A[i] = i + 1;  } test.getSubSet(A, 8, 0);  System.out.println(test.getRes());  } }

运算结果

 [1, 2, 5]

但C#无法满足获取栈的值,只能获取栈的类型,如果我们用遍历的方式去获取栈的值又回到了以前NP级的时间复杂度,故直接使用数字来做哈希表的键。内容如下

 using System; using System.Collections.Generic; using System.Collections; using System.Text.RegularExpressions; using System.Linq; using System.Text; using System.Threading.Tasks; namespace ConsoleApplication1 { class Program { private class Oranize { public List<decimal> array = new List<decimal>(); public List<decimal> res = new List<decimal>(); public Stack<decimal> stack = new Stack<decimal>(); public Hashtable table = new Hashtable(); public decimal index = 0; public decimal getSum(List<decimal> list) { decimal sum = 0; for (int i = 0; i < list.Count; i++) { sum += list[i]; } return sum; } public String stackValue(Stack<decimal> stack) { StringBuilder sb = new StringBuilder(); foreach (decimal s in stack) { sb.Append(s.ToString()); } return sb.ToString(); } public void org(decimal[] arr,decimal all, int step) { if (res.Count > 0) { return; } while (step < arr.Length) { array.Add(arr[step]); if (!table.ContainsKey(index.ToString())) { decimal sum = getSum(array); stack.Push(index); table.Add(stack.Peek().ToString(), sum); if (sum == all) { if (getSum(res) == 0) { foreach (decimal a in array) { res.Add(a); } } } } else { decimal sum = 0; if (stack.Count > 0) { sum = Convert.ToDecimal(table[stack.Peek().ToString()]) + arr[step]; } else { sum = Convert.ToDecimal(table["0"]) + arr[step]; } index++; stack.Push(index); if (table.ContainsKey(stack.Peek().ToString())) { table.Remove(stack.Peek().ToString()); } table.Add(stack.Peek().ToString(), sum); if (sum == all) { if (getSum(res) == 0) { foreach (decimal a in array) { res.Add(a); } } } } step++; org(arr, all, step); array.RemoveAt(array.Count - 1); stack.Pop(); } } } static void Main(string[] args) { decimal[] A = new decimal[6]; for (int i = 0; i < 6; i++) { A[i] = i + 1; } Oranize oranize = new Oranize(); oranize.org(A, 8, 0); foreach (decimal r in oranize.res) { Console.Write(r + ","); } Console.ReadLine(); } } }

这里我们可以看到如果使用stackValue来获取栈的各个值的字符串是不可取的,同样会非常慢。

由于C#本身的Hashtable在数据量大的情况下存在溢出风险,所以我们要重写哈希表。重写的哈希表的每个节点由红黑树组成,由于我们并不需要删除哈希表内的元素,所以就不写红黑树和哈希表的删除方法。

  private class RedBlackTreeMap         {             private static bool RED = true;             private static bool BLACK = false;             private class Node             {                 public String key;                 public decimal value;                 public Node left;                 public Node right;                 public bool color;                 public Node(String key,decimal value,Node left,Node right,bool color)                 {                     this.key = key;                     this.value = value;                     this.left = left;                     this.right = right;                     this.color = color;                 }                 public Node(String key): this(key, 0, null, null, RED)                 { }                 public Node(String key,decimal value): this(key, value, null, null, RED)                 { }                                                 }             private Node root;             private int size;             public ISet<String> keySet = new HashSet<String>();             public RedBlackTreeMap()             {                 root = null;                 size = 0;             }             private bool isRed(Node node)             {                 if (node == null)                 {                     return BLACK;                 }                 return node.color;             }             private Node leftRotate(Node node)             {                 Node ret = node.right;                 Node retLeft = ret.left;                 node.right = retLeft;                 ret.left = node;                 ret.color = node.color;                 node.color = RED;                 return ret;             }             private Node rightRotate(Node node)             {                 Node ret = node.left;                 Node retRight = ret.right;                 node.left = retRight;                 ret.right = node;                 ret.color = node.color;                 node.color = RED;                 return ret;             }             private void flipColors(Node node)             {                 node.color = RED;                 node.left.color = BLACK;                 node.right.color = BLACK;             }             public void add(String key,decimal value)             {                 root = add(root, key, value);                 keySet.Add(key);             }             private Node add(Node node,String key,decimal value)             {                 if (node == null)                 {                     size++;                     return new Node(key, value);                 }                 if (key.CompareTo(node.key) < 0)                 {                     node.left = add(node.left, key, value);                 }else if (key.CompareTo(node.key) > 0)                 {                     node.right = add(node.right, key, value);                 }else                 {                     node.value = value;                 }                 if (isRed(node.right) && !isRed(node.left))                 {                     node = leftRotate(node);                 }                 if (isRed(node.left) && isRed(node.left.left))                 {                     node = rightRotate(node);                 }                 if (isRed(node.left) && isRed(node.right))                 {                     flipColors(node);                 }                 return node;             }             public bool contains(String key)             {                 return getNode(root, key) != null;             }             public decimal get(String key)             {                 Node node = getNode(root, key);                 return node == null ? 0 : node.value;             }             public void set(String key,decimal value)             {                 Node node = getNode(root, key);                 if (node == null)                 {                     throw new ArgumentException(key + "不存在");                 }                 node.value = value;             }             public int getSize()             {                 return size;             }             public bool isEmpty()             {                 return size == 0;             }             private Node getNode(Node node,String key)             {                 if (node == null)                 {                     return null;                 }                 if (key.CompareTo(node.key) == 0)                 {                     return node;                 }else if (key.CompareTo(node.key) < 0)                 {                     return getNode(node.left, key);                 }else                 {                     return getNode(node.right, key);                 }             }         }         private class HashFind         {             private int[] capacity = {53,97,193,389,769,1543,3079,6151,12289,24593,             49157,98317,196613,393241,786433,1572869,3145739,             6291469,12582917,25165843,50331653,100663319,             201326611,402653189,805306457,1610612741};             //容忍度上界             private static int upperTol = 10;             //容忍度下届             private static int lowerTol = 2;             private int capacityIndex = 0;             private RedBlackTreeMap[] tables;             private int M;             private int size;             public HashFind()             {                 this.M = capacity[capacityIndex];                 this.size = 0;                 tables = new RedBlackTreeMap[M];                 for (int i = 0; i < M; i++)                 {                     tables[i] = new RedBlackTreeMap();                 }             }             private int hash(String key)             {                 return (key.GetHashCode() & 0x7fffffff) % M;             }             public void add(String key,decimal value)             {                 RedBlackTreeMap map = tables[hash(key)];                 if (map.contains(key))                 {                     map.add(key, value);                 }else                 {                     map.add(key, value);                     size++;                     if (size >= upperTol * M && capacityIndex + 1 < capacity.Length)                     {                         capacityIndex++;                         resize(capacity[capacityIndex]);                     }                 }             }             public bool contains(String key)             {                 int index = hash(key);                 return tables[index].contains(key);             }             public decimal get(String key)             {                 int index = hash(key);                 return tables[index].get(key);             }             public void set(String key,decimal value)             {                 int index = hash(key);                 RedBlackTreeMap map = tables[index];                 if(!map.contains(key))                 {                     throw new ArgumentException(key + "不存在");                 }                 map.add(key, value);             }             public int getSize()             {                 return size;             }             public bool isEmpty()             {                 return size == 0;             }             private void resize(int newM)             {                 RedBlackTreeMap[] newTables = new RedBlackTreeMap[newM];                 for (int i = 0; i < newM; i++)                 {                     newTables[i] = new RedBlackTreeMap();                 }                 int oldM = this.M;                 this.M = newM;                 for (int i = 0; i < oldM; i++)                 {                     RedBlackTreeMap map = tables[i];                     foreach (String key in map.keySet)                     {                         int index = hash(key);                         newTables[index].add(key, map.get(key));                     }                 }                 this.tables = newTables;             }         }         private class Oranize         {             public List<decimal> array = new List<decimal>();             public List<decimal> res = new List<decimal>();             public Stack<decimal> stack = new Stack<decimal>();             public HashFind table = new HashFind();             public decimal index = 0;             public decimal getSum(List<decimal> list)             {                 decimal sum = 0;                 for (int i = 0; i < list.Count; i++)                 {                     sum += list[i];                 }                 return sum;             }             //public String stackValue(Stack<decimal> stack)             //{             //    StringBuilder sb = new StringBuilder();             //    foreach (decimal s in stack)             //    {             //        sb.Append(s.ToString());             //    }             //    return sb.ToString();             //}             public void org(decimal[] arr, decimal all, int step)             {                 if (res.Count > 0)                 {                     return;                 }                 while (step < arr.Length)                 {                     array.Add(arr[step]);                     if (!table.contains(index.ToString()))                     {                         decimal sum = getSum(array);                         stack.Push(index);                         table.add(stack.Peek().ToString(), sum);                         if (sum == all)                         {                             if (getSum(res) == 0)                             {                                 foreach (decimal a in array)                                 {                                     res.Add(a);                                 }                             }                         }                     }                     else                     {                         decimal sum = 0;                         if (stack.Count > 0)                         {                             sum = Convert.ToDecimal(table.get(stack.Peek().ToString())) + arr[step];                         }                         else                         {                             sum = Convert.ToDecimal(table.get("0")) + arr[step];                         }                         index++;                         stack.Push(index);                         if (!table.contains(stack.Peek().ToString()))                         {                             table.add(stack.Peek().ToString(), sum);                         }                         if (sum == all)                         {                             if (getSum(res) == 0)                             {                                 foreach (decimal a in array)                                 {                                     res.Add(a);                                 }                             }                         }                     }                     step++;                     org(arr, all, step);                     array.RemoveAt(array.Count - 1);                     stack.Pop();                 }             }         }

虽然该算法进行了加速,但是能否算出,依然在于数组元素的个数所组成的和的组合数,比如有1、2、3、4四个数,则这四个数的和的组合数为1、2、3、4、1+2、1+2+3、1+2+4、1+2+3+4、1+3、1+3+4、1+4、2+3、2+3+4、2+4、3+4总共15个。

我们可以用计算组合数算法来进行验证,该算法也是使用递归加记忆化搜索的方式

 public class Combine { private static Map<String,Long> map= new HashMap<>();   /**  * 计算从m个元素中拿出n个元素的组合数  * @param m  * @param n  * @return  */  private static long comb(int m,int n){ String key= m+","+n;  if(n == 0) return 1;  if (n == 1) return m;  if(n > m / 2) return comb(m,m-n);  if(n > 1){ if(!map.containsKey(key)) map.put(key, comb(m-1,n-1)+comb(m-1,n));  return map.get(key);  } return -1;  } public static void main(String[] args) { long total = 0;  for (int i = 1 ; i <= 4; i++) { total += comb(4,i);  } System.out.println(total);  } }

运行结果

15

我们现在的主要目的是寻找可计算的节点,我们可以先给出一个比较大的数,比如一个数组中有40个元素

 public static void main(String[] args) { long total = 0;  for (int i = 1 ; i <= 40; i++) { total += comb(40,i);  } System.out.println(total); }

运行结果

1099511627775

由结果可知,40个数的组合数达到了万亿级别,一般我们计算机的计算级数量在亿级别就差不多了,再多的话就比较难算的出来了。当然这里我的个人建议是数组元素数量在28个

 public static void main(String[] args) { long total = 0;  for (int i = 1 ; i <= 28; i++) { total += comb(28,i);  } System.out.println(total); }

运行结果

268435455

这里是2.6亿,最后我们来看一下30的组合数

 public static void main(String[] args) { long total = 0;  for (int i = 1 ; i <= 30; i++) { total += comb(30,i);  } System.out.println(total); }

运行结果

1073741823

运行结果为10亿,所以我们可以看出从28到30,增长的组合数绝对不是一点点。这是一个几何级数的增长。

 

原文链接:https://my.oschina.net/u/3768341/blog/4769129
关注公众号

低调大师中文资讯倾力打造互联网数据资讯、行业资源、电子商务、移动互联网、网络营销平台。

持续更新报道IT业界、互联网、市场资讯、驱动更新,是最及时权威的产业资讯及硬件资讯报道平台。

转载内容版权归作者及来源网站所有,本站原创内容转载请注明来源。

文章评论

共有0条评论来说两句吧...

文章二维码

扫描即可查看该文章

点击排行

推荐阅读

最新文章