diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java b/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java index dc4df0cc9..8e311fa89 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java @@ -27,11 +27,14 @@ * split into 3 parts and cursors for each of those parts are processing the segment simultaneously in the same thread. * Results are accumulated into {@link Result} objects and a tree map is used to sequentially accumulate the results in * the end. - * Runs in 0.39s on an Intel i9-13900K. + * Runs in 0.31 on an Intel i9-13900K while the reference implementation takes 120.37s. * Credit: * Quan Anh Mai for branchless number parsing code * Alfonso² Peterssen for suggesting memory mapping with unsafe and the subprocess idea * Artsiom Korzun for showing the benefits of work stealing at 2MB segments instead of equal split between workers + * Jaromir Hamala for showing that avoiding the branch misprediction between <8 and 8-16 cases is a big win even if + * more work is performed + * Van Phu DO for demonstrating the lookup tables based on masks instead of bit shifting */ public class CalculateAverage_thomaswue { private static final String FILE = "./measurements.txt"; @@ -141,9 +144,15 @@ private static void parseLoop(AtomicLong counter, long fileEnd, long fileStart, long delimiterMask1 = findDelimiter(word1); long delimiterMask2 = findDelimiter(word2); long delimiterMask3 = findDelimiter(word3); - Result existingResult1 = findResult(word1, delimiterMask1, scanner1, results, collectedResults); - Result existingResult2 = findResult(word2, delimiterMask2, scanner2, results, collectedResults); - Result existingResult3 = findResult(word3, delimiterMask3, scanner3, results, collectedResults); + long word1b = scanner1.getLongAt(scanner1.pos() + 8); + long word2b = scanner2.getLongAt(scanner2.pos() + 8); + long word3b = scanner3.getLongAt(scanner3.pos() + 8); + long delimiterMask1b = findDelimiter(word1b); + long delimiterMask2b = findDelimiter(word2b); + long delimiterMask3b = findDelimiter(word3b); + Result existingResult1 = findResult(word1, delimiterMask1, word1b, delimiterMask1b, scanner1, results, collectedResults); + Result existingResult2 = findResult(word2, delimiterMask2, word2b, delimiterMask2b, scanner2, results, collectedResults); + Result existingResult3 = findResult(word3, delimiterMask3, word3b, delimiterMask3b, scanner3, results, collectedResults); long number1 = scanNumber(scanner1); long number2 = scanNumber(scanner2); long number3 = scanNumber(scanner3); @@ -155,76 +164,70 @@ private static void parseLoop(AtomicLong counter, long fileEnd, long fileStart, while (scanner1.hasNext()) { long word = scanner1.getLong(); long pos = findDelimiter(word); - record(findResult(word, pos, scanner1, results, collectedResults), scanNumber(scanner1)); + long wordB = scanner1.getLongAt(scanner1.pos() + 8); + long posB = findDelimiter(wordB); + record(findResult(word, pos, wordB, posB, scanner1, results, collectedResults), scanNumber(scanner1)); } while (scanner2.hasNext()) { long word = scanner2.getLong(); long pos = findDelimiter(word); - record(findResult(word, pos, scanner2, results, collectedResults), scanNumber(scanner2)); + long wordB = scanner2.getLongAt(scanner2.pos() + 8); + long posB = findDelimiter(wordB); + record(findResult(word, pos, wordB, posB, scanner2, results, collectedResults), scanNumber(scanner2)); } while (scanner3.hasNext()) { long word = scanner3.getLong(); long pos = findDelimiter(word); - record(findResult(word, pos, scanner3, results, collectedResults), scanNumber(scanner3)); + long wordB = scanner3.getLongAt(scanner3.pos() + 8); + long posB = findDelimiter(wordB); + record(findResult(word, pos, wordB, posB, scanner3, results, collectedResults), scanNumber(scanner3)); } } } - private static Result findResult(long initialWord, long initialDelimiterMask, Scanner scanner, Result[] results, List collectedResults) { + private static final long[] MASK1 = new long[]{ 0xFFL, 0xFFFFL, 0xFFFFFFL, 0xFFFFFFFFL, 0xFFFFFFFFFFL, 0xFFFFFFFFFFFFL, 0xFFFFFFFFFFFFFFL, 0xFFFFFFFFFFFFFFFFL, + 0xFFFFFFFFFFFFFFFFL }; + private static final long[] MASK2 = new long[]{ 0x00L, 0x00L, 0x00L, 0x00L, 0x00L, 0x00L, 0x00L, 0x00L, 0xFFFFFFFFFFFFFFFFL }; + + private static Result findResult(long initialWord, long initialDelimiterMask, long wordB, long delimiterMaskB, Scanner scanner, Result[] results, + List collectedResults) { Result existingResult; long word = initialWord; long delimiterMask = initialDelimiterMask; long hash; long nameAddress = scanner.pos(); - - // Search for ';', one long at a time. There are two common cases that a specially treated: - // (b) the ';' is found in the first 16 bytes - if (delimiterMask != 0) { - // Special case for when the ';' is found in the first 8 bytes. - int trailingZeros = Long.numberOfTrailingZeros(delimiterMask); - word = (word << (63 - trailingZeros)); - scanner.add(trailingZeros >>> 3); - hash = word; + long word2 = wordB; + long delimiterMask2 = delimiterMaskB; + if ((delimiterMask | delimiterMask2) != 0) { + int letterCount1 = Long.numberOfTrailingZeros(delimiterMask) >>> 3; // value between 1 and 8 + int letterCount2 = Long.numberOfTrailingZeros(delimiterMask2) >>> 3; // value between 0 and 8 + long mask = MASK2[letterCount1]; + word = word & MASK1[letterCount1]; + word2 = mask & word2 & MASK1[letterCount2]; + hash = word ^ word2; existingResult = results[hashToIndex(hash, results)]; - if (existingResult != null && existingResult.lastNameLong == word) { + scanner.add(letterCount1 + (letterCount2 & mask)); + if (existingResult != null && existingResult.firstNameWord == word && existingResult.secondNameWord == word2) { return existingResult; } } else { - // Special case for when the ';' is found in bytes 9-16. - hash = word; - long prevWord = word; - scanner.add(8); - word = scanner.getLong(); - delimiterMask = findDelimiter(word); - if (delimiterMask != 0) { - int trailingZeros = Long.numberOfTrailingZeros(delimiterMask); - word = (word << (63 - trailingZeros)); - scanner.add(trailingZeros >>> 3); - hash ^= word; - existingResult = results[hashToIndex(hash, results)]; - if (existingResult != null && existingResult.lastNameLong == word && existingResult.secondLastNameLong == prevWord) { - return existingResult; + // Slow-path for when the ';' could not be found in the first 16 bytes. + hash = word ^ word2; + scanner.add(16); + while (true) { + word = scanner.getLong(); + delimiterMask = findDelimiter(word); + if (delimiterMask != 0) { + int trailingZeros = Long.numberOfTrailingZeros(delimiterMask); + word = (word << (63 - trailingZeros)); + scanner.add(trailingZeros >>> 3); + hash ^= word; + break; } - } - else { - // Slow-path for when the ';' could not be found in the first 16 bytes. - scanner.add(8); - hash ^= word; - while (true) { - word = scanner.getLong(); - delimiterMask = findDelimiter(word); - if (delimiterMask != 0) { - int trailingZeros = Long.numberOfTrailingZeros(delimiterMask); - word = (word << (63 - trailingZeros)); - scanner.add(trailingZeros >>> 3); - hash ^= word; - break; - } - else { - scanner.add(8); - hash ^= word; - } + else { + scanner.add(8); + hash ^= word; } } } @@ -249,8 +252,8 @@ private static Result findResult(long initialWord, long initialDelimiterMask, Sc } } - int remainingShift = (64 - (nameLength + 1 - i) << 3); - if (existingResult.lastNameLong == (scanner.getLongAt(nameAddress + i) << remainingShift)) { + int remainingShift = (64 - ((nameLength + 1 - i) << 3)); + if (((scanner.getLongAt(existingResult.nameAddress + i) ^ (scanner.getLongAt(nameAddress + i))) << remainingShift) == 0) { break; } else { @@ -297,7 +300,7 @@ private static void record(Result existingResult, long number) { } private static int hashToIndex(long hash, Result[] results) { - long hashAsInt = hash ^ (hash >>> 37) ^ (hash >>> 17); + long hashAsInt = hash ^ (hash >>> 33) ^ (hash >>> 15); return (int) (hashAsInt & (results.length - 1)); } @@ -324,21 +327,23 @@ private static long findDelimiter(long word) { private static Result newEntry(Result[] results, long nameAddress, int hash, int nameLength, Scanner scanner, List collectedResults) { Result r = new Result(); results[hash] = r; - int i = 0; - for (; i < nameLength + 1 - Long.BYTES; i += Long.BYTES) { + int totalLength = nameLength + 1; + r.firstNameWord = scanner.getLongAt(nameAddress); + r.secondNameWord = scanner.getLongAt(nameAddress + 8); + if (totalLength <= 8) { + r.firstNameWord = r.firstNameWord & MASK1[totalLength - 1]; + r.secondNameWord = 0; } - if (nameLength + 1 > 8) { - r.secondLastNameLong = scanner.getLongAt(nameAddress + i - 8); + else if (totalLength < 16) { + r.secondNameWord = r.secondNameWord & MASK1[totalLength - 9]; } - int remainingShift = (64 - (nameLength + 1 - i) << 3); - r.lastNameLong = (scanner.getLongAt(nameAddress + i) << remainingShift); r.nameAddress = nameAddress; collectedResults.add(r); return r; } private static final class Result { - long lastNameLong, secondLastNameLong; + long firstNameWord, secondNameWord; short min, max; int count; long sum;