forEach loops as streams in Java 8 - k-means

I have an implementation of the k-means algorithm, and I would like to speed up the process using Java 8 streams and multi-core processing.

I have this code in Java 7:

    //Step 2: For each point p:
    //find nearest clusters c
    //assign the point p to the closest cluster c
    for (Point p : points) {
        double minDst = Double.MAX_VALUE;
        int minClusterNr = 1;
        for (Cluster c : clusters) {
            double tmpDst = determineDistance(p, c);
            if (tmpDst < minDst) {
                minDst = tmpDst;
                minClusterNr = c.clusterNumber;
            }
        }
        clusters.get(minClusterNr - 1).points.add(p);
    }
    //Step 3: For each cluster c
    //find the central point of all points p in c
    //set c to the center point
    ArrayList<Cluster> newClusters = new ArrayList<Cluster>();
    for (Cluster c : clusters) {
        double newX = 0;
        double newY = 0;
        for (Point p : c.points) {
            newX += p.x;
            newY += p.y;
        }
        newX = newX / c.points.size();
        newY = newY / c.points.size();
        newClusters.add(new Cluster(newX, newY, c.clusterNumber));
    }

I would like to use Java 8 parallel streams to speed up the process. I experimented a little and came up with this solution:

    points.stream().forEach(p -> {
        minDst = Double.MAX_VALUE; //<- THESE ARE GLOBAL VARIABLES NOW
        minClusterNr = 1;          //<- THESE ARE GLOBAL VARIABLES NOW
        clusters.stream().forEach(c -> {
            double tmpDst = determineDistance(p, c);
            if (tmpDst < minDst) {
                minDst = tmpDst;
                minClusterNr = c.clusterNumber;
            }
        });
        clusters.get(minClusterNr - 1).points.add(p);
    });
    ArrayList<Cluster> newClusters = new ArrayList<Cluster>();
    clusters.stream().forEach(c -> {
        newX = 0; //<- THESE ARE GLOBAL VARIABLES NOW
        newY = 0; //<- THESE ARE GLOBAL VARIABLES NOW
        c.points.stream().forEach(p -> {
            newX += p.x;
            newY += p.y;
        });
        newX = newX / c.points.size();
        newY = newY / c.points.size();
        newClusters.add(new Cluster(newX, newY, c.clusterNumber));
    });

This solution with streams is much faster than the one without. Does it already use multi-core processing? Why else would it suddenly be almost twice as fast?

Without streams: elapsed time 202 ms; with streams: elapsed time 116 ms.
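Whether a stream actually runs on several cores can be checked directly: `Stream.isParallel()` reports the mode, and recording `Thread.currentThread().getName()` inside the action shows which threads did the work. The following is a minimal sketch; `StreamParallelismCheck` and `threadsUsed` are hypothetical names, not part of the question's code.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Stream;

class StreamParallelismCheck {
    // Returns the names of all threads that executed the forEach action.
    static Set<String> threadsUsed(List<Integer> data, boolean parallel) {
        Set<String> names = ConcurrentHashMap.newKeySet(); // thread-safe set
        Stream<Integer> s = parallel ? data.parallelStream() : data.stream();
        s.forEach(x -> names.add(Thread.currentThread().getName()));
        return names;
    }

    public static void main(String[] args) {
        List<Integer> data = new ArrayList<>();
        for (int i = 0; i < 100_000; i++) {
            data.add(i);
        }
        System.out.println(data.stream().isParallel());         // false
        System.out.println(data.parallelStream().isParallel()); // true
        // Sequential: only the calling thread
        System.out.println(threadsUsed(data, false));
        // Parallel: usually the caller plus ForkJoinPool.commonPool workers
        System.out.println(threadsUsed(data, true));
    }
}
```

With the sequential stream the set should contain only the calling thread; with the parallel one it will usually also contain `ForkJoinPool.commonPool-worker-…` threads, depending on the machine and the data size.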

Would it also be useful to use parallelStream in any of these methods to speed them up even further? Everything I try right now leads to ArrayIndexOutOfBoundsException and NullPointerException when I change stream().forEach(CODE) to stream().parallel().forEach(CODE).

---- EDIT (added the source code as requested, so you can try it yourself) ----

--- Clustering.java ---

    package algo;

    import java.awt.Color;
    import java.awt.Graphics2D;
    import java.awt.image.BufferedImage;
    import java.util.ArrayList;
    import java.util.Random;
    import java.util.function.BiFunction;

    import graphics.SimpleColorFun;

    /**
     * An implementation of the k-means-algorithm.
     * <p>
     * Step 0: Determine the max size of the canvas
     * <p>
     * Step 1: Place clusters at random
     * <p>
     * Step 2: For each point p:<br>
     * find nearest clusters c<br>
     * assign the point p to the closest cluster c
     * <p>
     * Step 3: For each cluster c<br>
     * find the central point of all points p in c<br>
     * set c to the center point
     * <p>
     * Stop when none of the cluster x,y values change
     * @author makt
     *
     */
    public class Clustering {
        private BiFunction<Integer, Integer, Color> colorFun = new SimpleColorFun();
        // private BiFunction<Integer, Integer, Color> colorFun = new GrayScaleColorFun();

        public Random rngGenerator = new Random();

        public double max_x;
        public double max_y;
        public double max_xy;

        //---------------------------------
        //TODO: IS IT GOOD TO HAVE THOSE VALUES UP HERE?
        double minDst = Double.MAX_VALUE;
        int minClusterNr = 1;
        double newX = 0;
        double newY = 0;
        //----------------------------------

        public boolean workWithStreams = false;

        public ArrayList<ArrayList<Cluster>> allGeneratedClusterLists = new ArrayList<ArrayList<Cluster>>();
        public ArrayList<BufferedImage> allGeneratedImages = new ArrayList<BufferedImage>();

        public Clustering(int seed) {
            rngGenerator.setSeed(seed);
        }

        public Clustering(Random rng) {
            rngGenerator = rng;
        }

        public void setup(int centroidCount, ArrayList<Point> points, int maxIterations) {
            //Step 0: Determine the max size of the canvas
            determineSize(points);

            ArrayList<Cluster> clusters = new ArrayList<Cluster>();
            //Step 1: Place clusters at random
            for (int i = 0; i < centroidCount; i++) {
                clusters.add(new Cluster(rngGenerator.nextInt((int) max_x), rngGenerator.nextInt((int) max_y), i + 1));
            }

            int iterations = 0;

            if (workWithStreams) {
                allGeneratedClusterLists.add(doClusteringWithStreams(points, clusters));
            } else {
                allGeneratedClusterLists.add(doClustering(points, clusters));
            }

            iterations += 1;

            //do until maxIterations is reached or until none of the cluster x and y values change anymore
            while (iterations < maxIterations) {
                //Step 2: happens inside doClustering
                if (workWithStreams) {
                    allGeneratedClusterLists.add(doClusteringWithStreams(points, allGeneratedClusterLists.get(iterations - 1)));
                } else {
                    allGeneratedClusterLists.add(doClustering(points, allGeneratedClusterLists.get(iterations - 1)));
                }

                if (!didPointsChangeClusters(allGeneratedClusterLists.get(iterations - 1), allGeneratedClusterLists.get(iterations))) {
                    break;
                }

                iterations += 1;
            }

            System.out.println("Finished with " + iterations + " out of " + maxIterations + " max iterations");
        }

        /**
         * checks if the cluster x and y values changed compared to the previous x and y values
         * @param previousCluster
         * @param currentCluster
         * @return true if any cluster x or y values changed, false if all of them are the same
         */
        private boolean didPointsChangeClusters(ArrayList<Cluster> previousCluster, ArrayList<Cluster> currentCluster) {
            for (int i = 0; i < previousCluster.size(); i++) {
                if (previousCluster.get(i).x != currentCluster.get(i).x || previousCluster.get(i).y != currentCluster.get(i).y) {
                    return true;
                }
            }
            return false;
        }

        /**
         *
         * @param points - all given points
         * @param clusters - its point list gets filled in this method
         * @return a new Clusters Array which has an <b> empty </b> point list.
         */
        private ArrayList<Cluster> doClustering(ArrayList<Point> points, ArrayList<Cluster> clusters) {
            //Step 2: For each point p:
            //find nearest clusters c
            //assign the point p to the closest cluster c
            for (Point p : points) {
                double minDst = Double.MAX_VALUE;
                int minClusterNr = 1;
                for (Cluster c : clusters) {
                    double tmpDst = determineDistance(p, c);
                    if (tmpDst < minDst) {
                        minDst = tmpDst;
                        minClusterNr = c.clusterNumber;
                    }
                }
                clusters.get(minClusterNr - 1).points.add(p);
            }

            //Step 3: For each cluster c
            //find the central point of all points p in c
            //set c to the center point
            ArrayList<Cluster> newClusters = new ArrayList<Cluster>();
            for (Cluster c : clusters) {
                double newX = 0;
                double newY = 0;
                for (Point p : c.points) {
                    newX += p.x;
                    newY += p.y;
                }
                newX = newX / c.points.size();
                newY = newY / c.points.size();
                newClusters.add(new Cluster(newX, newY, c.clusterNumber));
            }

            allGeneratedImages.add(createImage(clusters));

            return newClusters;
        }

        /**
         * Does the same as doClustering but about twice as fast!<br>
         * Uses Java8 streams to achieve this
         * @param points
         * @param clusters
         * @return
         */
        private ArrayList<Cluster> doClusteringWithStreams(ArrayList<Point> points, ArrayList<Cluster> clusters) {
            points.stream().forEach(p -> {
                minDst = Double.MAX_VALUE;
                minClusterNr = 1;
                clusters.stream().forEach(c -> {
                    double tmpDst = determineDistance(p, c);
                    if (tmpDst < minDst) {
                        minDst = tmpDst;
                        minClusterNr = c.clusterNumber;
                    }
                });
                clusters.get(minClusterNr - 1).points.add(p);
            });

            ArrayList<Cluster> newClusters = new ArrayList<Cluster>();

            clusters.stream().forEach(c -> {
                newX = 0;
                newY = 0;
                c.points.stream().forEach(p -> {
                    newX += p.x;
                    newY += p.y;
                });
                newX = newX / c.points.size();
                newY = newY / c.points.size();
                newClusters.add(new Cluster(newX, newY, c.clusterNumber));
            });

            allGeneratedImages.add(createImage(clusters));

            return newClusters;
        }

        //draw all centers from clusters
        //draw all points
        //color points according to cluster value
        private BufferedImage createImage(ArrayList<Cluster> clusters) {
            //add 10% of the max size left and right to the image bounds
            //BufferedImage bi = new BufferedImage((int) (max_xy * 1.05), (int) (max_xy * 1.05), BufferedImage.TYPE_BYTE_INDEXED);
            BufferedImage bi = new BufferedImage((int) (max_xy * 1.05), (int) (max_xy * 1.05), BufferedImage.TYPE_INT_ARGB); // support 32-bit RGBA values
            Graphics2D g2d = bi.createGraphics();

            int numClusters = clusters.size();
            for (Cluster c : clusters) {
                //color points according to cluster value
                Color col = colorFun.apply(c.clusterNumber, numClusters);
                //draw all points
                g2d.setColor(col);
                for (Point p : c.points) {
                    g2d.fillRect((int) p.x, (int) p.y, (int) (max_xy * 0.02), (int) (max_xy * 0.02));
                }
                //draw all centers from clusters
                g2d.setColor(new Color(160, 80, 80, 200)); // use RGBA: transparency=200
                g2d.fillOval((int) c.x, (int) c.y, (int) (max_xy * 0.03), (int) (max_xy * 0.03));
            }

            return bi;
        }

        /**
         * Calculates the euclidean distance without square root
         * @param p
         * @param c
         * @return
         */
        private double determineDistance(Point p, Cluster c) {
            //math.sqrt not needed because the relative distance does not change by applying the square root
            // return Math.sqrt(Math.pow((p.x - c.x), 2)+Math.pow((p.y - c.y),2));
            return Math.pow((p.x - c.x), 2) + Math.pow((p.y - c.y), 2);
        }

        //TODO: What if coordinates can also be negative?
        private void determineSize(ArrayList<Point> points) {
            for (Point p : points) {
                if (p.x > max_x) {
                    max_x = p.x;
                }
                if (p.y > max_y) {
                    max_y = p.y;
                }
            }
            if (max_x > max_y) {
                max_xy = max_x;
            } else {
                max_xy = max_y;
            }
        }
    }

--- Point.java ---

    package algo;

    public class Point {
        public double x;
        public double y;

        public Point(int x, int y) {
            this.x = x;
            this.y = y;
        }

        public Point(double x, double y) {
            this.x = x;
            this.y = y;
        }
    }

--- Cluster.java ---

    package algo;

    import java.util.ArrayList;

    public class Cluster {
        public double x;
        public double y;
        public int clusterNumber;

        public ArrayList<Point> points = new ArrayList<Point>();

        public Cluster(double x, double y, int clusterNumber) {
            this.x = x;
            this.y = y;
            this.clusterNumber = clusterNumber;
        }
    }

--- SimpleColorFun.java ---

    package graphics;

    import java.awt.Color;
    import java.util.function.BiFunction;

    /**
     * Simple function for selecting a color for a specific cluster identified with an integer-ID.
     *
     * @author makl, hese
     */
    public class SimpleColorFun implements BiFunction<Integer, Integer, Color> {

        /**
         * Selects a color value.
         * @param n current index
         * @param numCol number of color-values possible
         */
        @Override
        public Color apply(Integer n, Integer numCol) {
            Color col = Color.BLACK;
            //color points according to cluster value
            switch (n) {
                case 1: col = Color.RED; break;
                case 2: col = Color.GREEN; break;
                case 3: col = Color.BLUE; break;
                case 4: col = Color.ORANGE; break;
                case 5: col = Color.MAGENTA; break;
                case 6: col = Color.YELLOW; break;
                case 7: col = Color.CYAN; break;
                case 8: col = Color.PINK; break;
                case 9: col = Color.LIGHT_GRAY; break;
                default: break;
            }
            return col;
        }
    }

--- Main.java --- (replace Stopwatch with some other time measurement; it comes from our internal work environment)

    package main;

    import java.awt.image.BufferedImage;
    import java.io.File;
    import java.io.IOException;
    import java.util.ArrayList;
    import java.util.Random;
    import java.util.concurrent.TimeUnit;

    import javax.imageio.ImageIO;

    import algo.Clustering;
    import algo.Point;
    import eu.lbase.common.util.Stopwatch;
    // import persistence.DataHandler;

    public class Main {
        private static final String OUTPUT_DIR = (new File("./output/withoutStream")).getAbsolutePath() + File.separator;
        private static final String OUTPUT_DIR_2 = (new File("./output/withStream")).getAbsolutePath() + File.separator;

        public static void main(String[] args) {
            Random rng = new Random();
            int numPoints = 300;
            int seed = 2;

            ArrayList<Point> points = new ArrayList<Point>();
            rng.setSeed(rng.nextInt());
            for (int i = 0; i < numPoints; i++) {
                points.add(new Point(rng.nextInt(1000), rng.nextInt(1000)));
            }

            Stopwatch stw = Stopwatch.create(TimeUnit.MILLISECONDS);
            {
                // Stopwatch start
                System.out.println("--- Started without streams ---");
                stw.start();

                Clustering algo = new Clustering(seed);
                algo.setup(8, points, 25);

                // Stopwatch stop
                stw.stop();
                System.out.println("--- Finished without streams ---");
                System.out.printf("Elapsed time: %d msec%n%n", stw.getElapsed());

                System.out.printf("Writing images to '%s' ...%n", OUTPUT_DIR);
                deleteOldFiles(new File(OUTPUT_DIR));
                makeImages(OUTPUT_DIR, algo);
                System.out.println("Finished writing.\n");
            }
            {
                System.out.println("--- Started with streams ---");
                stw.start();

                Clustering algo = new Clustering(seed);
                algo.workWithStreams = true;
                algo.setup(8, points, 25);

                // Stopwatch stop
                stw.stop();
                System.out.println("--- Finished with streams ---");
                System.out.printf("Elapsed time: %d msec%n%n", stw.getElapsed());

                System.out.printf("Writing images to '%s' ...%n", OUTPUT_DIR_2);
                deleteOldFiles(new File(OUTPUT_DIR_2));
                makeImages(OUTPUT_DIR_2, algo);
                System.out.println("Finished writing.\n");
            }
        }

        /**
         * creates one image for each iteration in the given directory
         * @param algo
         */
        private static void makeImages(String dir, Clustering algo) {
            int i = 1;
            for (BufferedImage img : algo.allGeneratedImages) {
                try {
                    String filename = String.format("%03d.png", i);
                    ImageIO.write(img, "png", new File(dir + filename));
                } catch (IOException e) {
                    // TODO Auto-generated catch block
                    e.printStackTrace();
                }
                i++;
            }
        }

        /**
         * deletes old files from the target directory<br>
         * Does <b>not</b> delete directories!
         * @param dir - directory to delete files from
         * @return
         */
        private static boolean deleteOldFiles(File file) {
            File[] allContents = file.listFiles();
            if (allContents != null) {
                for (File f : allContents) {
                    deleteOldFiles(f);
                }
            }
            if (!file.isDirectory()) {
                return file.delete();
            }
            return false;
        }
    }
2 answers

If you want to use streams efficiently, you should stop using forEach to write basically the same thing as a loop, and instead learn about aggregate operations. See also the comprehensive java.util.stream package documentation.
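As an example of an aggregate operation replacing a hand-written loop, the bounds computation in the question's `determineSize` could use `mapToDouble(...).summaryStatistics()`, which yields min and max in one pass and also reports negative coordinates (the TODO in the question). This is only a sketch; `BoundsWithAggregates` and its nested `Point` are hypothetical stand-ins for the question's classes.

```java
import java.util.Arrays;
import java.util.DoubleSummaryStatistics;
import java.util.List;

class BoundsWithAggregates {
    // Hypothetical stand-in for the question's algo.Point.
    static class Point {
        final double x, y;
        Point(double x, double y) { this.x = x; this.y = y; }
    }

    // count, min, max, sum and average of all x values, computed in one pass
    static DoubleSummaryStatistics xStats(List<Point> points) {
        return points.stream().mapToDouble(p -> p.x).summaryStatistics();
    }

    static DoubleSummaryStatistics yStats(List<Point> points) {
        return points.stream().mapToDouble(p -> p.y).summaryStatistics();
    }

    public static void main(String[] args) {
        List<Point> pts = Arrays.asList(new Point(-5, 2), new Point(10, 7), new Point(3, -1));
        DoubleSummaryStatistics xs = xStats(pts);
        DoubleSummaryStatistics ys = yStats(pts);
        System.out.println("x: " + xs.getMin() + " .. " + xs.getMax()); // x: -5.0 .. 10.0
        System.out.println("y: " + ys.getMin() + " .. " + ys.getMax()); // y: -1.0 .. 7.0
    }
}
```

Unlike the original loop, which starts `max_x` and `max_y` at 0, these statistics also expose negative minima, so the drawing code would still need to offset coordinates by the minimum; that part is left open here.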

A thread-safe solution might look like

    points.stream().forEach(p -> {
        Cluster min = clusters.stream()
            .min(Comparator.comparingDouble(c -> determineDistance(p, c))).get();
        // your original code used the clusterNumber to look up the Cluster in
        // the list, don't know whether this is really necessary
        min = clusters.get(min.clusterNumber - 1);
        // didn't find a better way considering your current code structure
        synchronized(min) {
            min.points.add(p);
        }
    });
    List<Cluster> newClusters = clusters.stream()
        .map(c -> new Cluster(
            c.points.stream().mapToDouble(p -> p.x).sum()/c.points.size(),
            c.points.stream().mapToDouble(p -> p.y).sum()/c.points.size(),
            c.clusterNumber))
        .collect(Collectors.toList());
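To try this approach in isolation, here is a self-contained sketch with minimal stand-ins for `Point` and `Cluster`; `AnswerSketch` and `step` are hypothetical names. It deviates from the snippet above in two places, both assumptions on my side: it synchronizes on the `Cluster` returned by `min()` directly instead of looking it up again via `clusterNumber`, and it uses `average().orElse(...)` so that a cluster without points keeps its position instead of becoming NaN, which is what `sum()/size()` yields for an empty cluster.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;

class AnswerSketch {
    // Minimal stand-ins for algo.Point and algo.Cluster from the question.
    static class Point {
        final double x, y;
        Point(double x, double y) { this.x = x; this.y = y; }
    }

    static class Cluster {
        final double x, y;
        final int clusterNumber;
        final List<Point> points = new ArrayList<>();
        Cluster(double x, double y, int clusterNumber) {
            this.x = x; this.y = y; this.clusterNumber = clusterNumber;
        }
    }

    static double determineDistance(Point p, Cluster c) {
        double dx = p.x - c.x, dy = p.y - c.y;
        return dx * dx + dy * dy; // squared distance, as in the question
    }

    // One k-means iteration: fills clusters' point lists, returns moved clusters.
    static List<Cluster> step(List<Point> points, List<Cluster> clusters) {
        // Step 2: nearest cluster per point; the outer stream runs in parallel.
        points.parallelStream().forEach(p -> {
            Cluster min = clusters.stream()
                .min(Comparator.comparingDouble(c -> determineDistance(p, c)))
                .get();
            synchronized (min) { // ArrayList.add is not thread-safe
                min.points.add(p);
            }
        });
        // Step 3: move each cluster to the mean of its points.
        return clusters.stream()
            .map(c -> new Cluster(
                c.points.stream().mapToDouble(p -> p.x).average().orElse(c.x),
                c.points.stream().mapToDouble(p -> p.y).average().orElse(c.y),
                c.clusterNumber))
            .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<Point> points = Arrays.asList(
            new Point(0, 0), new Point(2, 0), new Point(10, 10), new Point(12, 10));
        List<Cluster> clusters = Arrays.asList(new Cluster(1, 1, 1), new Cluster(9, 9, 2));
        for (Cluster c : step(points, clusters)) {
            System.out.println(c.clusterNumber + ": " + c.x + ", " + c.y);
        }
        // 1: 1.0, 0.0
        // 2: 11.0, 10.0
    }
}
```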

but you did not provide enough context to verify this. There are some open questions. For example, you used the clusterNumber of a Cluster instance to look it up again in the list of clusters; I don't know whether clusterNumber always represents the actual list index of the Cluster instance we already have, i.e. whether it is redundant or has a different meaning.

I also don't know a better solution than synchronizing on the specific Cluster to make the list modification thread-safe (given your current code structure). This is only necessary if you decide to use a parallel stream, i.e. points.parallelStream().forEach(p -> …); the other operations are not affected.
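A different technique from the one in this answer, which avoids synchronization entirely: let a collector build the per-cluster lists. `Collectors.groupingBy` is not a concurrent collector, so in a parallel stream each sub-task fills its own intermediate map and the maps are merged afterwards; no shared `ArrayList` is modified concurrently. The sketch assumes centroids are plain points and clusters are identified by list index; `GroupingSketch`, `nearest` and `assign` are hypothetical names.

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

class GroupingSketch {
    // Hypothetical stand-in for the question's algo.Point; also used for centroids.
    static class Point {
        final double x, y;
        Point(double x, double y) { this.x = x; this.y = y; }
    }

    static double squaredDistance(Point a, Point b) {
        double dx = a.x - b.x, dy = a.y - b.y;
        return dx * dx + dy * dy;
    }

    // Index of the nearest centroid, found with an aggregate min() instead of a loop.
    static int nearest(Point p, List<Point> centroids) {
        return IntStream.range(0, centroids.size()).boxed()
            .min(Comparator.comparingDouble(i -> squaredDistance(p, centroids.get(i))))
            .get();
    }

    // Step 2 without shared mutable state: the collector builds the lists,
    // so no synchronized block is needed even for a parallel stream.
    static Map<Integer, List<Point>> assign(List<Point> points, List<Point> centroids) {
        return points.parallelStream()
            .collect(Collectors.groupingBy(p -> nearest(p, centroids)));
    }

    public static void main(String[] args) {
        List<Point> points = Arrays.asList(new Point(0, 0), new Point(2, 0), new Point(10, 10));
        List<Point> centroids = Arrays.asList(new Point(1, 1), new Point(9, 9));
        assign(points, centroids)
            .forEach((k, v) -> System.out.println(k + " -> " + v.size() + " point(s)"));
    }
}
```

A centroid that receives no points simply has no key in the resulting map, so the centroid update (Step 3) has to handle missing keys.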

Now you have several streams that you can try in parallel and sequentially to find out where you benefit and where you don't. Usually, only the outer streams bring a significant benefit, if any...
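To compare both modes on the same input, the outer stream can be switched with a flag. This sketch uses plain `double[]` arrays instead of the question's classes, and `ParallelToggleSketch`, `assign`, `nearest` and `randomPoints` are hypothetical names. Each mode is timed twice because the first run of any variant also pays for JIT compilation. In the question's Main the stream version always runs second, so part of the 202 ms vs. 116 ms difference may come from warm-up rather than from streams; for reliable numbers, a benchmark harness such as JMH is the better tool.

```java
import java.util.Random;
import java.util.stream.IntStream;

class ParallelToggleSketch {
    // Nearest-centroid index for every point; the flag switches the outer stream.
    static int[] assign(double[][] points, double[][] centroids, boolean parallel) {
        IntStream indices = IntStream.range(0, points.length);
        if (parallel) {
            indices = indices.parallel();
        }
        // map/toArray keeps encounter order, so both modes return the same array
        return indices.map(i -> nearest(points[i], centroids)).toArray();
    }

    static int nearest(double[] p, double[][] centroids) {
        int best = 0;
        double bestDst = Double.MAX_VALUE;
        for (int k = 0; k < centroids.length; k++) {
            double dx = p[0] - centroids[k][0];
            double dy = p[1] - centroids[k][1];
            double d = dx * dx + dy * dy;
            if (d < bestDst) {
                bestDst = d;
                best = k;
            }
        }
        return best;
    }

    // Random integer coordinates in [0, 1000), like the question's Main.
    static double[][] randomPoints(int n, long seed) {
        Random rng = new Random(seed);
        double[][] pts = new double[n][2];
        for (double[] p : pts) {
            p[0] = rng.nextInt(1000);
            p[1] = rng.nextInt(1000);
        }
        return pts;
    }

    public static void main(String[] args) {
        double[][] points = randomPoints(500_000, 2);
        double[][] centroids = randomPoints(8, 3);
        // Alternate and repeat the modes: the first runs also include JIT warm-up.
        for (boolean parallel : new boolean[] {false, true, false, true}) {
            long start = System.nanoTime();
            int[] result = assign(points, centroids, parallel);
            long ms = (System.nanoTime() - start) / 1_000_000;
            System.out.println((parallel ? "parallel  " : "sequential") + ": " + ms + " ms, "
                + result.length + " points");
        }
    }
}
```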

This solution with streams is much faster than without

Which is strange, since the streams created by Collection.stream() are sequential and have nothing to do with multi-core processing. It should also not be necessary to move variables (e.g. minDst) out of the forEach loop. Doing so will probably cause parallel streams to compute wrong results; however, I do not see why the provided code would throw an AIOOBE or NPE.
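Regarding the shared fields: the question presumably moved minDst and minClusterNr to fields because the inner `clusters.stream().forEach(c -> ...)` lambda cannot assign to a local variable of the enclosing lambda (captured locals must be effectively final). Locals declared inside the outer lambda and updated by a plain inner loop avoid that. A minimal sketch with hypothetical stand-in classes (`LocalVariablesSketch` and its nested types); the assignment still uses an unsynchronized `ArrayList.add`, so the stream is deliberately kept sequential.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class LocalVariablesSketch {
    // Hypothetical stand-ins for the question's algo.Point and algo.Cluster.
    static class Point {
        final double x, y;
        Point(double x, double y) { this.x = x; this.y = y; }
    }

    static class Cluster {
        final double x, y;
        final int clusterNumber;
        final List<Point> points = new ArrayList<>();
        Cluster(double x, double y, int clusterNumber) {
            this.x = x; this.y = y; this.clusterNumber = clusterNumber;
        }
    }

    static void assign(List<Point> points, List<Cluster> clusters) {
        points.stream().forEach(p -> {
            // Declared inside the lambda: every invocation gets its own copies,
            // so they do not have to be fields of the enclosing class.
            double minDst = Double.MAX_VALUE;
            Cluster nearest = null;
            // A plain loop may update these locals. A nested
            // clusters.stream().forEach(c -> { minDst = ...; }) would not compile,
            // because lambdas cannot assign captured local variables.
            for (Cluster c : clusters) {
                double dx = p.x - c.x;
                double dy = p.y - c.y;
                double dst = dx * dx + dy * dy;
                if (dst < minDst) {
                    minDst = dst;
                    nearest = c;
                }
            }
            // Unsynchronized ArrayList.add: only safe because the stream is sequential.
            nearest.points.add(p);
        });
    }

    public static void main(String[] args) {
        List<Cluster> clusters = Arrays.asList(new Cluster(1, 1, 1), new Cluster(9, 9, 2));
        assign(Arrays.asList(new Point(0, 0), new Point(10, 10), new Point(12, 10)), clusters);
        for (Cluster c : clusters) {
            System.out.println("cluster " + c.clusterNumber + ": " + c.points.size() + " point(s)");
        }
    }
}
```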

Source: https://habr.com/ru/post/1275259/

