Random Weighted Choice in Java

I want to select a random element from a set, but the probability of choosing an element should be proportional to its weight.

Input Example:

```
item                 weight
----                 ------
sword of misery          10
shield of happy           5
potion of dying           6
triple-edged sword        1
```

So, with 4 possible items and no weights, the probability of getting any one item would be 1 in 4.

In this case, the user should be 10 times more likely to get a sword of misery than a triple-edged sword.
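Concretely, each item's probability is its weight divided by the total weight. For the example table the weights sum to 10 + 5 + 6 + 1 = 22, which gives:

```latex
\begin{aligned}
P(i) &= \frac{w_i}{\sum_j w_j} \\
P(\text{sword of misery}) &= \tfrac{10}{22} \approx 0.455 \\
P(\text{shield of happy}) &= \tfrac{5}{22} \approx 0.227 \\
P(\text{potion of dying}) &= \tfrac{6}{22} \approx 0.273 \\
P(\text{triple-edged sword}) &= \tfrac{1}{22} \approx 0.045
\end{aligned}
```

The ratio between the first and last item is (10/22) / (1/22) = 10, which is exactly the "10 times more likely" requirement.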

How do I make a weighted random choice in Java?

Jun 20 2018-11-11T00:

I would use a NavigableMap:

```java
import java.util.NavigableMap;
import java.util.Random;
import java.util.TreeMap;

public class RandomCollection<E> {
    private final NavigableMap<Double, E> map = new TreeMap<Double, E>();
    private final Random random;
    private double total = 0;

    public RandomCollection() {
        this(new Random());
    }

    public RandomCollection(Random random) {
        this.random = random;
    }

    public RandomCollection<E> add(double weight, E result) {
        if (weight <= 0) return this; // ignore non-positive weights
        total += weight;
        map.put(total, result);       // key = running total of weights
        return this;
    }

    public E next() {
        double value = random.nextDouble() * total; // in [0, total)
        return map.higherEntry(value).getValue();   // first key strictly greater
    }
}
```

Say I have a list of animals (dog, cat, horse) with probabilities of 40%, 35% and 25%, respectively:

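```java
// The explicit <String> is needed: a bare diamond here would infer RandomCollection<Object>.
RandomCollection<String> rc = new RandomCollection<String>()
        .add(40, "dog").add(35, "cat").add(25, "horse");

for (int i = 0; i < 10; i++) {
    System.out.println(rc.next());
}
```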
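The NavigableMap lookup above is effectively a binary search over cumulative weights, which is why each draw costs O(log n). The same idea can be sketched with a plain array and Arrays.binarySearch. The class name CumulativeSampler is hypothetical, and the sketch assumes all weights are strictly positive (so the cumulative array is strictly increasing):

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical helper: weighted choice via prefix sums and binary search. */
public class CumulativeSampler {
    private final double[] cumulative; // cumulative[i] = w[0] + ... + w[i]
    private final Random random;

    public CumulativeSampler(double[] weights, Random random) {
        if (weights.length == 0) throw new IllegalArgumentException("no weights");
        this.cumulative = new double[weights.length];
        double total = 0;
        for (int i = 0; i < weights.length; i++) {
            if (weights[i] <= 0) throw new IllegalArgumentException("weights must be positive");
            total += weights[i];
            cumulative[i] = total;
        }
        this.random = random;
    }

    /** Returns an index i with probability w[i] / sum(w). */
    public int next() {
        double total = cumulative[cumulative.length - 1];
        double value = random.nextDouble() * total; // in [0, total)
        int pos = Arrays.binarySearch(cumulative, value);
        // Exact hit on a boundary belongs to the next item (same as higherEntry);
        // otherwise binarySearch returns -(insertion point) - 1.
        return pos >= 0 ? pos + 1 : -pos - 1;
    }
}
```

For example, new CumulativeSampler(new double[] {10, 5, 6, 1}, new Random()).next() should return 0 about ten times as often as 3. Storing indices rather than elements keeps the sampler independent of the item type.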
Jun 20 '11 at 10:23

You will not find a framework for this kind of problem, since the requested functionality is nothing more than a simple function. Do something like this:

```java
import java.util.List;

interface Item {
    double getWeight();
}

class RandomItemChooser {
    public Item chooseOnWeight(List<Item> items) {
        double completeWeight = 0.0;
        for (Item item : items)
            completeWeight += item.getWeight();

        double r = Math.random() * completeWeight; // in [0, completeWeight)
        double countWeight = 0.0;
        for (Item item : items) {
            countWeight += item.getWeight();
            if (countWeight > r) // strict, so a zero-weight item is never chosen
                return item;
        }
        throw new RuntimeException("Should never be shown.");
    }
}
```
Jun 20 '11 at 10:21

There is now a class in Apache Commons Math for this: EnumeratedDistribution.

```java
Item selectedItem = new EnumeratedDistribution<>(itemWeights).sample();
```

where itemWeights is a List<Pair<Item,Double>>, for example (assuming the Item interface from Arne's answer):

```java
List<Pair<Item, Double>> itemWeights = new ArrayList<>();
for (Item i : itemSet) {
    itemWeights.add(new Pair<>(i, i.getWeight()));
}
```

or in Java 8:

```java
List<Pair<Item, Double>> itemWeights = itemSet.stream()
        .map(i -> new Pair<Item, Double>(i, i.getWeight()))
        .collect(Collectors.toList());
```

Note: The Pair here should be org.apache.commons.math3.util.Pair , not org.apache.commons.lang3.tuple.Pair .

May 21 '15 at

Use the alias method

If you are going to roll many times (as in a game), you should use the alias method.

The code below is a fairly long implementation of the alias method, but that is because of the initialization part. Retrieving elements is very fast (see the next and applyAsInt methods: they do not loop).

Usage

```java
Set<Item> items = ... ;
ToDoubleFunction<Item> weighter = ... ;
Random random = new Random();
RandomSelector<Item> selector = RandomSelector.weighted(items, weighter);
Item drop = selector.next(random);
```

Implementation

This implementation:

  • uses Java 8;
  • is designed to be as fast as possible (well, at least I tried to make it so, using micro-benchmarking);
  • is completely thread-safe (keep one Random per thread for maximum performance; ThreadLocalRandom, perhaps?);
  • retrieves elements in O(1), unlike what you usually find on the Internet or on Stack Overflow, where naive implementations run in O(n) or O(log(n));
  • keeps elements independent of their weights, so an element can be assigned different weights in different contexts.

Anyway, here is the code. (Please note that I maintain an up-to-date version of this class.)

```java
import static java.util.Objects.requireNonNull;

import java.util.*;
import java.util.function.*;

public final class RandomSelector<T> {

  public static <T> RandomSelector<T> weighted(Set<T> elements, ToDoubleFunction<? super T> weighter)
      throws IllegalArgumentException {
    requireNonNull(elements, "elements must not be null");
    requireNonNull(weighter, "weighter must not be null");
    if (elements.isEmpty()) {
      throw new IllegalArgumentException("elements must not be empty");
    }

    // Array is faster than anything. Use that.
    int size = elements.size();
    @SuppressWarnings("unchecked")
    T[] elementArray = elements.toArray((T[]) new Object[size]);

    double totalWeight = 0d;
    double[] discreteProbabilities = new double[size];

    // Retrieve the probabilities
    for (int i = 0; i < size; i++) {
      double weight = weighter.applyAsDouble(elementArray[i]);
      if (weight < 0.0d) {
        throw new IllegalArgumentException("weighter may not return a negative number");
      }
      discreteProbabilities[i] = weight;
      totalWeight += weight;
    }
    if (totalWeight == 0.0d) {
      throw new IllegalArgumentException("the total weight of elements must be greater than 0");
    }

    // Normalize the probabilities
    for (int i = 0; i < size; i++) {
      discreteProbabilities[i] /= totalWeight;
    }
    return new RandomSelector<>(elementArray, new RandomWeightedSelection(discreteProbabilities));
  }

  private final T[] elements;
  private final ToIntFunction<Random> selection;

  private RandomSelector(T[] elements, ToIntFunction<Random> selection) {
    this.elements = elements;
    this.selection = selection;
  }

  public T next(Random random) {
    return elements[selection.applyAsInt(random)];
  }

  private static class RandomWeightedSelection implements ToIntFunction<Random> {
    // Alias method implementation O(1)
    // using Vose algorithm to initialize O(n)

    private final double[] probabilities;
    private final int[] alias;

    RandomWeightedSelection(double[] probabilities) {
      int size = probabilities.length;
      this.probabilities = probabilities;
      this.alias = new int[size];

      double average = 1.0d / size;
      int[] small = new int[size];
      int smallSize = 0;
      int[] large = new int[size];
      int largeSize = 0;

      // Describe a column as either small (below average) or large (above average).
      for (int i = 0; i < size; i++) {
        if (probabilities[i] < average) {
          small[smallSize++] = i;
        } else {
          large[largeSize++] = i;
        }
      }

      // For each column, saturate a small probability to average with a large probability.
      while (largeSize != 0 && smallSize != 0) {
        int less = small[--smallSize];
        int more = large[--largeSize];
        // "more" gives up (average - p[less]); use p[less] before it is rescaled.
        probabilities[more] += probabilities[less] - average;
        probabilities[less] = probabilities[less] * size;
        alias[less] = more;
        if (probabilities[more] < average) {
          small[smallSize++] = more;
        } else {
          large[largeSize++] = more;
        }
      }

      // Flush unused columns.
      while (smallSize != 0) {
        probabilities[small[--smallSize]] = 1.0d;
      }
      while (largeSize != 0) {
        probabilities[large[--largeSize]] = 1.0d;
      }
    }

    @Override
    public int applyAsInt(Random random) {
      // Call random once to decide which column will be used.
      int column = random.nextInt(probabilities.length);

      // Call random a second time to decide which will be used: the column or the alias.
      if (random.nextDouble() < probabilities[column]) {
        return column;
      } else {
        return alias[column];
      }
    }
  }
}
```
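Because the small/large bookkeeping in an alias table is easy to get wrong, it can help to check a table exactly rather than only by sampling. The sketch below is a separate, hypothetical AliasTable class (not the RandomSelector above) that follows Vose's algorithm in its commonly used scaled form, where each probability is multiplied by n and compared against 1. Its impliedProbability method reconstructs the distribution the table actually encodes, so it can be compared with the input. It assumes the input probabilities are non-negative and sum to 1:

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Random;

/** Hypothetical sketch of Vose's alias method in its scaled form. */
public class AliasTable {
    final double[] prob; // chance of keeping column i instead of taking its alias
    final int[] alias;

    public AliasTable(double[] p) { // assumes p[i] >= 0 and sum(p) == 1
        int n = p.length;
        prob = new double[n];
        alias = new int[n];
        double[] scaled = new double[n];
        Deque<Integer> small = new ArrayDeque<>();
        Deque<Integer> large = new ArrayDeque<>();
        for (int i = 0; i < n; i++) {
            scaled[i] = p[i] * n; // the average column height becomes 1
            if (scaled[i] < 1.0) small.push(i); else large.push(i);
        }
        while (!small.isEmpty() && !large.isEmpty()) {
            int less = small.pop();
            int more = large.pop();
            prob[less] = scaled[less];
            alias[less] = more;
            // "more" donates (1 - scaled[less]) to fill column "less" up to 1.
            scaled[more] = scaled[more] + scaled[less] - 1.0;
            if (scaled[more] < 1.0) small.push(more); else large.push(more);
        }
        // Leftover columns are full (up to rounding error).
        while (!large.isEmpty()) prob[large.pop()] = 1.0;
        while (!small.isEmpty()) prob[small.pop()] = 1.0;
    }

    /** O(1) draw: pick a column uniformly, then keep it or take its alias. */
    public int next(Random random) {
        int column = random.nextInt(prob.length);
        return random.nextDouble() < prob[column] ? column : alias[column];
    }

    /** Exact probability that next() returns k, computed from the table itself. */
    public double impliedProbability(int k) {
        int n = prob.length;
        double total = prob[k] / n; // column k chosen and kept
        for (int c = 0; c < n; c++) {
            // Column c falls through to its alias with probability (1 - prob[c]).
            if (alias[c] == k && prob[c] < 1.0) total += (1.0 - prob[c]) / n;
        }
        return total;
    }
}
```

A distribution such as {0.1, 0.45, 0.45} is a useful check, because there the large column that fills the first small column drops below average itself and must be reused as a small column.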
Aug 01 '15 at 2:36
```java
import java.util.NavigableMap;
import java.util.TreeMap;
import java.util.concurrent.ThreadLocalRandom;

public class RandomCollection<E> {
    private final NavigableMap<Double, E> map = new TreeMap<Double, E>();
    private double total = 0;

    public void add(double weight, E result) {
        // Ignore non-positive weights and elements that were already added.
        if (weight <= 0 || map.containsValue(result)) return;
        total += weight;
        map.put(total, result);
    }

    public E next() {
        double value = ThreadLocalRandom.current().nextDouble() * total;
        return map.ceilingEntry(value).getValue();
    }
}
```
Nov 24 '16 at 22:58

If you need to remove items after selecting them, you can use a different solution. Add all the elements to a LinkedList, adding each element as many times as its weight (so the weights must be integers), and then use Collections.shuffle(), which, according to the JavaDoc:

Randomly permutes the specified list using a default source of randomness. All permutations occur with approximately equal likelihood.

Finally, get and remove items using pop() or removeFirst().

```java
Map<String, Integer> map = new HashMap<>();
map.put("Five", 5);
map.put("Four", 4);
map.put("Three", 3);
map.put("Two", 2);
map.put("One", 1);

// Add each key as many times as its (integer) weight.
LinkedList<String> list = new LinkedList<>();
for (Map.Entry<String, Integer> entry : map.entrySet()) {
    for (int i = 0; i < entry.getValue(); i++) {
        list.add(entry.getKey());
    }
}

Collections.shuffle(list);

// Draw without replacement: each pop() removes the drawn item.
int size = list.size();
for (int i = 0; i < size; i++) {
    System.out.println(list.pop());
}
```
Jun 08 '17 at 5:14


There is a simple algorithm for choosing an item at random when each item has its own weight:

  1. Calculate the sum of all weights.

  2. Pick a random number that is greater than or equal to 0 and less than the sum of the weights.

  3. Go through the items one at a time, subtracting each item's weight from the random number, until you reach an item whose weight is greater than the remaining random number; that item is the one chosen.
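The three steps above can be sketched in Java as follows. The class name WeightedPick, the method name pickIndex and the use of a plain double[] of weights are assumptions for illustration; the weights are assumed to be non-negative with a positive sum:

```java
import java.util.Random;

public class WeightedPick {
    /** Returns index i with probability weights[i] / sum(weights). */
    public static int pickIndex(double[] weights, Random random) {
        // 1. Sum all the weights.
        double total = 0;
        for (double w : weights) total += w;

        // 2. Pick a random number r with 0 <= r < total.
        double r = random.nextDouble() * total;

        // 3. Walk the items, subtracting each weight until r falls inside one.
        for (int i = 0; i < weights.length; i++) {
            if (r < weights[i]) return i; // zero-weight items never satisfy this
            r -= weights[i];
        }

        // Only reachable through floating-point rounding: fall back to the last positive weight.
        for (int i = weights.length - 1; i >= 0; i--) {
            if (weights[i] > 0) return i;
        }
        throw new IllegalArgumentException("total weight must be positive");
    }
}
```

Each call is O(n), so this suits occasional draws; for many draws over the same weights, precomputing cumulative sums or an alias table avoids repeating the work.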

May 04 '19 at 8:17


