Having fun with Scikit-Learn

Scikit-Learn is a great library to start machine learning with, because it combines a powerful API, solid documentation, and a large variety of methods with lots of different options and sensible defaults. For example, if we have a classification problem of predicting whether a sentence is about New York, London or both, we can create a pipeline including tokenization with case folding and stop-word removal, bigram extraction, tf-idf weighting and support for multiple labels, and then train and apply it in merely 8 lines of code (well, excluding imports and input specification).

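from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegressionCV
from sklearn.multiclass import OneVsRestClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import MultiLabelBinarizer

vec = TfidfVectorizer(ngram_range=(1, 2), stop_words='english', max_features=15)
clf = OneVsRestClassifier(LogisticRegressionCV())
pipeline = make_pipeline(vec, clf)
mlb = MultiLabelBinarizer()
y_train = mlb.fit_transform(y_train)  # label lists -> 0/1 indicator matrix
pipeline.fit(X_train, y_train)
predicted = pipeline.predict(X_test)
print(list(zip(X_test, mlb.inverse_transform(predicted))))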
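For the curious, the tf-idf step can be sketched in plain Python. This assumes TfidfVectorizer's default settings (raw term counts, smoothed idf = ln((1 + n) / (1 + df)) + 1, and L2-normalized rows) and replaces its tokenizer, stop-word list, bigrams and max_features with a naive whitespace split; the function name `tfidf` is made up for this example.

```python
import math
from collections import Counter


def tfidf(docs):
    """Return one {term: weight} dict per document (simplified TfidfVectorizer defaults)."""
    tokenized = [doc.lower().split() for doc in docs]  # naive whitespace tokenizer
    n = len(tokenized)
    # document frequency: in how many documents each term occurs
    df = Counter(term for tokens in tokenized for term in set(tokens))
    # smoothed idf, as with smooth_idf=True: ln((1 + n) / (1 + df)) + 1
    idf = {term: math.log((1 + n) / (1 + count)) + 1 for term, count in df.items()}
    vectors = []
    for tokens in tokenized:
        tf = Counter(tokens)  # raw term counts
        weights = {term: count * idf[term] for term, count in tf.items()}
        norm = math.sqrt(sum(w * w for w in weights.values()))  # L2 norm of the row
        vectors.append({term: w / norm for term, w in weights.items()})
    return vectors


vectors = tfidf(["london is in england", "nyc is nice", "london is great"])
print(vectors[0])
```

A term that occurs in every document (here "is") gets the lowest idf of 1, so rare words like "england" dominate the normalized vector.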

Furthermore, we can easily add cross-validation, parameter tuning, etc. in just a few more lines. Check out this great introductory video series for all the cool features you can use right from the beginning. But let's go back a bit! Assume we have the following training data:

[("new york is a hell of a town", ["New York"]),
 ("new york was originally dutch", ["New York"]),
 ("the big apple is great", ["New York"]),
 ("new york is also called the big apple", ["New York"]),
 ("nyc is nice", ["New York"]),
 ("people abbreviate new york city as nyc", ["New York"]),
 ("the capital of great britain is london", ["London"]),
 ("london is in the uk", ["London"]),
 ("london is in england", ["London"]),
 ("london is in great britain", ["London"]),
 ("it rains a lot in london", ["London"]),
 ("london hosts the british museum", ["London"]),
 ("new york is great and so is london", ["London", "New York"]),
 ("i like london better than new york", ["London", "New York"])]
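In the pipeline above, MultiLabelBinarizer turns such label lists into a 0/1 indicator matrix with one column per class, and (when no classes are given) the columns are the sorted class names, which is why the scores further down come in the order [London, New York]. A plain-Python sketch of that encoding on three of the training pairs; the helper `binarize` is hypothetical and not the scikit-learn implementation.

```python
def binarize(label_lists):
    """Multi-hot encode lists of labels; columns are the sorted class names."""
    classes = sorted({label for labels in label_lists for label in labels})
    column = {label: i for i, label in enumerate(classes)}
    matrix = []
    for labels in label_lists:
        row = [0] * len(classes)
        for label in labels:
            row[column[label]] = 1  # mark every label the sentence has
        matrix.append(row)
    return classes, matrix


train = [("nyc is nice", ["New York"]),
         ("london is in the uk", ["London"]),
         ("i like london better than new york", ["London", "New York"])]
X_train, y_lists = zip(*train)  # split the texts from their label lists
classes, y_train = binarize(y_lists)
print(classes)  # ['London', 'New York']
print(y_train)  # [[0, 1], [1, 0], [1, 1]]
```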

Our model can easily predict the following:

[("nice day in nyc", ["New York"]),
 ("welcome to london", ["London"]),
 ("hello simon welcome to new york. enjoy it here and london too", ["London", "New York"])]

But how? Well, we know that the model’s decision function computes a dot product between the input features and the trained weights (plus an intercept), and we can actually see the computed scores with pipeline.decision_function(X_test):

[[ -7.96273351   1.04803743]
 [ 22.19686347  -1.39109585]
 [  5.48828931   1.28660432]]

So, a positive score means that the label is assigned (the columns follow mlb.classes_, i.e. [London, New York]), and we can convert the scores into some sort of “probability” estimate with the logistic function:

import numpy as np
print(1 / (1 + np.exp(-pipeline.decision_function(X_test))))  # logistic (sigmoid) of each score
[[  3.48078806e-04   7.40397853e-01]
 [  1.00000000e+00   1.99232868e-01]
 [  9.95882115e-01   7.83571880e-01]]
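To make the arithmetic explicit: each score is the sum of feature value times weight plus an intercept, and the logistic function 1 / (1 + e^(-score)) maps it into (0, 1), so a score of 0 corresponds to 0.5. A small sketch; the feature values and weights in the first call are made up for illustration, while the last line reuses two scores from the decision_function output above.

```python
import math


def decision(features, weights, intercept):
    """Linear score: sum of feature value * weight over the features, plus the intercept."""
    return sum(value * weights.get(term, 0.0) for term, value in features.items()) + intercept


def sigmoid(score):
    """Logistic function mapping a score to a value in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-score))


# Hypothetical tf-idf vector, weights and intercept, for illustration only
score = decision({"london": 0.8, "welcome": 0.6}, {"london": 2.0}, -0.5)
print(score, sigmoid(score))

# Scores of the third test sentence, taken from the decision_function output above
print(sigmoid(5.48828931), sigmoid(1.28660432))
```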

However, it would be nice to know exactly which words have contributed to the final score. And guess what, there is an awesome library, eli5, that explains this to you like you are five years old:

import eli5
eli5.show_prediction(clf, X_test[2], vec=vec, target_names=mlb.classes_)

y=London (probability 0.996, score 5.488) top features

Weight Feature
+8.631 Highlighted in text (sum)

hello simon welcome to new york. enjoy it here and london too

y=New York (probability 0.784, score 1.287) top features

Weight Feature
+1.084 Highlighted in text (sum)

hello simon welcome to new york. enjoy it here and london too

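Note that the numbers here are exactly what we have seen above: the highlighted contributions plus the bias term (-3.143 for London and +0.202 for New York, see the weights below) add up to the scores 5.488 and 1.287 (up to rounding). In addition, we get a visual explanation of which words have triggered positive and negative signals. Moreover, with a bit of trickery to work around an eli5 limitation with OneVsRestClassifier, we can also display the feature weights learned by our model as a neat heat map: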

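import numpy as np
# Copy the per-class weights of the fitted one-vs-rest models onto the template estimator for eli5
est = clf.estimator
est.classes_ = clf.classes_
est.coef_ = np.vstack([e.coef_ for e in clf.estimators_])
est.intercept_ = np.concatenate([e.intercept_ for e in clf.estimators_])
eli5.show_weights(est, vec=vec, target_names=mlb.classes_)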
y=London top features

Weight Feature
+25.340 london
+4.706 great
+2.079 lot london
+2.079 lot
+1.266 great britain
+1.266 britain
+0.896 museum
-1.682 new
-1.682 new york
-1.682 york
-2.723 nice
-3.143 <BIAS>
-3.875 apple
-3.875 big
-3.875 big apple
-4.219 nyc

y=New York top features

Weight Feature
+1.154 york
+1.154 new york
+1.154 new
+0.661 nyc
+0.546 nice
+0.526 big apple
+0.526 big
+0.526 apple
+0.202 <BIAS>
+0.048 great
-0.500 lot
-0.500 lot london
-0.632 museum
-0.755 britain
-0.755 great britain
-1.594 london
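The prediction explanation is just a decomposition of the linear score: for each class, every feature present in the text contributes its tf-idf value times its weight, and the bias (listed as <BIAS>) is added on top. A minimal sketch of that idea, using a subset of the London weights from the table above and made-up tf-idf values (the real ones are not shown here); the helper `contributions` is hypothetical and not part of eli5.

```python
# Subset of the London weights from the table above
london_weights = {"london": 25.340, "new": -1.682, "new york": -1.682, "york": -1.682}
london_bias = -3.143

# Hypothetical tf-idf values of the features present in some sentence
features = {"london": 0.5, "new": 0.4, "new york": 0.4, "york": 0.4}


def contributions(features, weights):
    """Per-feature contribution to the score: feature value times class weight."""
    return {term: value * weights.get(term, 0.0) for term, value in features.items()}


contrib = contributions(features, london_weights)
highlighted = sum(contrib.values())  # corresponds to the "Highlighted in text (sum)" row
score = highlighted + london_bias    # the bias is added separately
for term, c in sorted(contrib.items(), key=lambda kv: -kv[1]):
    print(f"{c:+.3f} {term}")
print(f"score = {highlighted:.3f} + ({london_bias}) = {score:.3f}")
```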

So far, working with scikit-learn and eli5 is as much fun as Transformers were when I was five! And working with a relatively large dataset is just as easy as with the toy example shown here. Man, this is great, but what about using it in a production system, you say? Well, let's talk about that next time ;)

Recently I have been working on a text-classification task. Along the way I have tested out three interesting machine learning frameworks, which I would like to address in the next few posts. This time I start with Apache Spark’s MLlib.

Spark got my attention quite a long time ago, and it is extremely useful for data exploration tasks where you can simply put lots of data on HDFS and then use a Jupyter notebook to transform the data interactively. However, although training is preferably done offline using a large number of examples (where Spark comes in handy), classification itself often needs to be a low-latency/high-throughput task. As the framework brings quite a lot of overhead, it would be nice if the API methods could be executed without the Spark cluster when necessary. In that case you could use the cluster to build a model, serialize it and ship it to a worker, which then applies the model to the incoming instances.

My earlier implementation of text classification for Reuters-21578 can be executed as a simple JAR, but as I wrote earlier, it was quite a dance to do this correctly. Moreover, in that example I used the RDD part of the MLlib API and ended up with very verbose Java code.

Recently, the API has been extended with pipelines and many interesting features (most likely inspired by Scikit-Learn) making it really easy to implement a classifier in just a few lines of code, for example:

from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression, OneVsRest
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.feature import HashingTF, StopWordsRemover, StringIndexer, Tokenizer

# train_paired and test_paired are DataFrames with "text" and "label_text" columns
labelIndexer = StringIndexer(inputCol="label_text", outputCol="label")
tokenizer = Tokenizer(inputCol="text", outputCol="tokens")
remover = StopWordsRemover(inputCol=tokenizer.getOutputCol(), outputCol="filtered")
hashingTF = HashingTF(inputCol=remover.getOutputCol(), outputCol="features")
lr = LogisticRegression(maxIter=20, regParam=0.01)
ovr = OneVsRest(classifier=lr)
pipeline = Pipeline(stages=[labelIndexer, tokenizer, remover, hashingTF, ovr])
model = pipeline.fit(train_paired)
result = model.transform(test_paired)
predictionAndLabels = result.select("prediction", "label")
evaluator = MulticlassClassificationEvaluator(metricName="accuracy")
print("Test set accuracy = " + str(evaluator.evaluate(predictionAndLabels)))
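HashingTF skips building a vocabulary: each token is hashed to an index modulo a fixed number of features, and the count at that index is incremented (the "hashing trick"), so different tokens can occasionally collide. A plain-Python sketch of the idea; Spark uses MurmurHash3 and 2^18 features by default, whereas this sketch assumes zlib.crc32 and a tiny feature space purely for illustration, so the indices will not match Spark's.

```python
import zlib
from collections import defaultdict


def hashing_tf(tokens, num_features=16):
    """Map tokens to a sparse {index: count} vector via the hashing trick."""
    vector = defaultdict(float)
    for token in tokens:
        # crc32 is a stable stand-in for Spark's MurmurHash3
        index = zlib.crc32(token.encode("utf-8")) % num_features
        vector[index] += 1.0
    return dict(vector)


print(hashing_tf(["london", "rains", "london"]))
```

The trade-off is constant memory and no vocabulary to ship with the model, at the price of possible collisions and no way to map an index back to a word.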

Nevertheless, this new DataFrame-based part of the API has not yet reached parity with the old RDD-based part of the API. The latter is planned to be deprecated when this happens, but currently some of the methods are available only through the old part of the API. In other words, a strong dependency between the algorithms and the underlying data structure is a real problem here. I only hope that the same will not happen again if the DataFrame concept gets replaced by a better idea in a year or two.

Finally, although there are plenty of resources available online (books, courses, talks, slides, etc.), the documentation of MLlib is far from good (especially the API docs), and customization beyond simple examples is a nightmare, if possible at all (an exercise for the reader: add bigrams to the pipeline above). On the positive side, Spark is great for certain use cases and is being actively developed, with lots of interesting features and ideas coming up next.

A while ago, I was looking at cardinality estimators for use in a distributed setting: given a data set spread over a set of nodes, we want to compute the total number of unique keys without having to transfer all keys or a global bit signature. Counting sketches such as HyperLogLog (see here, here and here for an introduction) have superior memory usage and CPU performance when a small error margin in the estimated cardinality is acceptable. In the following, I summarize a comparison between two Java libraries, StreamLib and Java-HLL, that I did back in February 2014.


StreamLib implements several methods:

  • Linear counting (lincnt) - hashes values into positions in a bit vector and then estimates the number of items based on the number of unset bits.

  • LogLog (ll) - uses hashing to assign an element to one of m registers and updates that register with the maximum observed rank, updateRegister(h >>> (Integer.SIZE - k), Integer.numberOfLeadingZeros((h << k) | (1 << (k - 1))) + 1), where k = log2(m). The cardinality is estimated as Math.pow(2, Ravg) * a, where Ravg is the average maximum observed rank across the m registers and a is a correction factor that depends on m (see the paper for details).

  • HyperLogLog (hll) - improves the LogLog algorithm in several respects, for example by using a harmonic mean instead of LogLog's plain average of the ranks.

  • HyperLogLog++ (hlp) - Google’s take on HLL that improves memory usage and accuracy for small cardinalities
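To make the estimators above concrete, here is a compact Python sketch of linear counting and of the HyperLogLog estimate (a bias-corrected harmonic mean of 2^-register over m = 2^p registers, with the usual small-range fallback to linear counting). It is not StreamLib's or Java-HLL's code: the 64-bit hash is assumed to come from hashlib.blake2b instead of MurmurHash, alpha is the standard large-m approximation 0.7213 / (1 + 1.079 / m), and HLL++'s bias correction and sparse representation are left out.

```python
import hashlib
import math


def hash64(value):
    """Stable 64-bit hash (blake2b); an illustrative stand-in for MurmurHash."""
    digest = hashlib.blake2b(str(value).encode("utf-8"), digest_size=8).digest()
    return int.from_bytes(digest, "big")


def linear_count(values, m):
    """Linear counting: n is estimated as -m * ln(V), V = fraction of unset bits."""
    bits = bytearray(m)  # one byte per bit, for simplicity
    for value in values:
        bits[hash64(value) % m] = 1
    unset = m - sum(bits)
    if unset == 0:
        raise ValueError("bit vector is full; use a larger m")
    return -m * math.log(unset / m)


def hll_registers(values, p):
    """Fill m = 2^p registers with the maximum observed rank per register."""
    registers = [0] * (1 << p)
    for value in values:
        h = hash64(value)
        index = h >> (64 - p)                    # top p bits pick the register
        rest = h & ((1 << (64 - p)) - 1)         # remaining 64 - p bits
        rank = (64 - p) - rest.bit_length() + 1  # leading zeros in rest, plus one
        registers[index] = max(registers[index], rank)
    return registers


def hll_estimate(registers):
    """HyperLogLog: bias-corrected harmonic mean, with the small-range fallback."""
    m = len(registers)
    alpha = 0.7213 / (1 + 1.079 / m)  # standard approximation, intended for m >= 128
    raw = alpha * m * m / sum(2.0 ** -r for r in registers)
    zeros = registers.count(0)
    if raw <= 2.5 * m and zeros > 0:
        return m * math.log(m / zeros)  # linear counting over the empty registers
    return raw


keys = range(100_000)
print(round(linear_count(keys, 1 << 20)))
print(round(hll_estimate(hll_registers(keys, 12))))
```

Note the memory difference: linear counting needs a bit per expected key, while HLL with p = 12 keeps only 4096 small registers at an expected standard error of about 1.04 / sqrt(4096), roughly 1.6%.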

Java-HLL (hlx), on the other hand, provides a set of tweaks to HyperLogLog, mainly exploring the idea that a chunk of data, say 1280 bytes, can be used to fully represent either a short sorted list, a sparse/lazy map of non-empty registers, or a full register set (see the project page for details).

Performance comparison

I used two relatively small real-world data sets, similar to what was intended to be used in production. For hashing I used StreamLib’s MurmurHash.hash64, which for some reason performed better than Guava’s on the test data (I haven’t investigated why, though). The latency times given below are cold-start numbers, measured without accounting for JIT warm-up and similar effects. In other words, these are not scientific results.

Dataset A

The first data set has the following characteristics:

  • 3765844 tokens
  • 587913 unique keys (inserting into a Sets.newHashSet(): 977ms)
  • 587913 unique hashed keys (Sets.newHashSet(): 2520ms)

First, let's compare the StreamLib methods tuned for 1% error with 10 million keys. The collected data includes the name of the method, the relative error, the total estimator size and the total elapsed time. The number after ll, hll and hlp denotes the log2(m) parameter:

name error size time
lincnt 0.0017 137073B 1217ms
ll__14 0.0135 16384B 963ms
hll_13 0.0181 5472B 1000ms
hlp13 -0.0081 5473B 863ms

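Here HLP gives the best trade-off: only 0.81% error using about 5 KB of memory, and it is also the fastest. Linear counting is more accurate, but its estimator is about 25 times larger.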

Now, let's compare StreamLib and Java-HLL. The parameter after hlp is log2(m), while the parameters after hlx are log2(m), the register width (5 seems to be the only one that works), the promotion threshold (-1 denotes the auto mode) and the initial representation type.

name error size time
hlp10 0.0323 693B 818ms
hlp11 0.0153 1377B 967ms
hlp12 0.0132 2741B 790ms
hlp13 -0.0081 5473B 731ms
hlp14 -0.0081 10933B 697ms
hlx_105-1_FULL -0.0212 643B 723ms
hlx_105-1_SPARSE -0.0212 643B 680ms
hlx_115-1_FULL -0.0202 1283B 670ms
hlx_115-1_SPARSE -0.0202 1283B 710ms
hlx_125-1_FULL -0.0069 2563B 673ms
hlx_125-1_SPARSE -0.0069 2563B 699ms
hlx_135-1_FULL 0.0046 5123B 702ms
hlx_135-1_SPARSE 0.0046 5123B 672ms
hlx_145-1_FULL 0.0013 10243B 693ms
hlx_145-1_SPARSE 0.0013 10243B 678ms

Here Java-HLL is more accurate for most register counts (log2(m) = 11 being the exception) and slightly faster, at a similar size.

Dataset B

The second data set has the following characteristics:

  • 3765844 tokens
  • 2074012 unique keys (Sets.newHashSet(): 1195ms)
  • 2074012 unique hashed keys (Sets.newHashSet(): 2885ms)

StreamLib methods tuned for 1% error with 10 million keys:

name error size time
lincnt 0.0005 137073B 663ms
ll__14 -0.0080 16384B 578ms
hll_13 0.0131 5472B 515ms
hlp13 -0.0118 5473B 566ms

And StreamLib vs Java-HLL:

name error size time
hlp10 0.0483 693B 560ms
hlp11 0.0336 1377B 489ms
hlp12 -0.0059 2741B 560ms
hlp13 -0.0118 5473B 567ms
hlp14 -0.0025 10933B 495ms
hlx_105-1_FULL -0.0227 643B 575ms
hlx_105-1_SPARSE -0.0227 643B 570ms
hlx_115-1_FULL -0.0194 1283B 505ms
hlx_115-1_SPARSE -0.0194 1283B 573ms
hlx_125-1_FULL -0.0076 2563B 500ms
hlx_125-1_SPARSE -0.0076 2563B 570ms
hlx_135-1_FULL -0.0099 5123B 576ms
hlx_135-1_SPARSE -0.0099 5123B 501ms
hlx_145-1_FULL 0.0015 10243B 572ms
hlx_145-1_SPARSE 0.0015 10243B 500ms

So the results are similar to those with Dataset A.


This comparison was done more than two years ago, and I was quite skeptical of both libraries. I found many strange things in StreamLib (both the reported issues and more), while Java-HLL, for its part, did not work with register widths other than 5. I settled on Java-HLL since it had a better implementation and gave better results. However, things change fast, and StreamLib might have been improved a lot since then. I still want to look more at the code of both libraries, and perhaps at the libraries that have been published since then.

Nevertheless, HLL is clearly a method to use. A really nice feature of HLL is that you can have multiple counters and you can add (union) them together without loss. Intersection, however, can be tricky.
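The lossless union follows from how the registers work: each register only keeps a maximum, so merging two sketches built with the same p and the same hash function is a register-wise maximum, and the result is identical to the sketch of the combined data. A self-contained check of that property; the blake2b-based hash and the helper names are illustrative assumptions, not either library's API.

```python
import hashlib


def hash64(value):
    """Stable 64-bit hash (blake2b); an illustrative stand-in for MurmurHash."""
    digest = hashlib.blake2b(str(value).encode("utf-8"), digest_size=8).digest()
    return int.from_bytes(digest, "big")


def hll_registers(values, p=10):
    """HLL registers: for each of the 2^p buckets, the maximum rank seen."""
    registers = [0] * (1 << p)
    for value in values:
        h = hash64(value)
        index = h >> (64 - p)
        rank = (64 - p) - (h & ((1 << (64 - p)) - 1)).bit_length() + 1
        registers[index] = max(registers[index], rank)
    return registers


def union(a, b):
    """Merge two sketches built with the same p and hash: register-wise maximum."""
    return [max(x, y) for x, y in zip(a, b)]


node1 = hll_registers(range(0, 60_000))        # keys seen on one node
node2 = hll_registers(range(40_000, 100_000))  # overlapping keys on another node
merged = union(node1, node2)
print(merged == hll_registers(range(100_000)))  # True: same as sketching all keys at once
```

The overlapping keys are not double counted, because a key hashes to the same register and rank on both nodes. An intersection has no such register-wise operation, which is why it is the tricky case.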

Open question

The register width in LogLog methods is the number of bits needed to represent the maximum position of the first 1-bit. There are m = (beta / se)^2 such registers, where beta is a method-specific constant and se is the desired standard error, say 0.01. I guess this comes from StdErr = StdDev / sqrt(N) for the sample mean of a population (ref. Wikipedia), but my knowledge of statistics is a bit too rusty to really understand this. As far as I understand the papers, LogLog has beta = 1.30, HLL has beta = 1.106 and HLL++ has beta = 1.04, but I might be wrong. After all, the StreamLib code used these three numbers seemingly at random in methods and tests. When I asked which was correct, they asked me back. Honestly, I don’t know :)
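Rearranging se = beta / sqrt(m) gives m = (beta / se)^2. Plugging in the three constants for se = 0.01 is simple arithmetic; rounding log2(m) down happens to reproduce the parameters used in the tables above (ll__14, hll_13, hlp13), although that this is StreamLib's actual rounding rule is only an assumption. The helper `registers_for` is hypothetical.

```python
import math


def registers_for(beta, se):
    """Number of registers m = (beta / se)^2 and log2(m) rounded down."""
    m = (beta / se) ** 2
    return m, math.floor(math.log2(m))


for name, beta in [("LogLog", 1.30), ("HLL", 1.106), ("HLL++", 1.04)]:
    m, p = registers_for(beta, 0.01)
    actual_se = beta / math.sqrt(2 ** p)  # standard error with the rounded-down register count
    print(f"{name:6s} m = {m:8.1f}  log2(m) = {p}  se at 2^{p} = {actual_se:.4f}")
```

For example, LogLog needs 130^2 = 16900 registers, and 2^14 = 16384 comes closest from below.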

The Code

JMH caught my attention several times before, but I never had time to try it out. Although I have written several micro-benchmarks in the past, for example for my master's thesis, I doubt that any of them were as technically correct as what JMH delivers. For a brief introduction I would highly recommend looking at this presentation, these blog posts (one, two, three) and the official code examples. More advanced examples can be found in Aleksey Shipilёv’s and Nitsan Wakart’s blog posts. In the following, I write a simple benchmark to test a hypothesis that has bothered me for a while.

First, I generate the benchmark project with Maven and import it into Eclipse.

mvn archetype:generate -DinteractiveMode=false -DarchetypeGroupId=org.openjdk.jmh -DarchetypeArtifactId=jmh-java-benchmark-archetype -DgroupId=com.simonj -DartifactId=first-benchmark -Dversion=1.0

Then, I would like to test the following. Given a sorted array of integers (in non-decreasing order) in which we would like to count the number of distinct values, the simplest solution is to use a for loop with a condition on the equality of consecutive elements, in other words:

int uniques = numbers.length > 0 ? 1 : 0;  // the first element is always a new value
for (int i = 0; i < numbers.length - 1; i++) {
    if (numbers[i] != numbers[i + 1]) {
        uniques += 1;
    }
}

However, we can utilize the fact that the difference between two consecutive, non-equal elements will be negative, and thus we can just shift the sign bit into the rightmost position and increment the counter by it, in other words:

int uniques = numbers.length > 0 ? 1 : 0;
for (int i = 0; i < numbers.length - 1; i++) {
    // the difference is negative iff the next value is larger, so its sign bit is 1
    uniques += (numbers[i] - numbers[i + 1]) >>> 31;
}
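Both loops compute the same count on sorted input. A quick Python check of that equivalence; since Python integers are unbounded and >> is an arithmetic shift, the sketch emulates Java's 32-bit >>> 31 by masking the difference to 32 bits first (this assumes the difference fits in a Java int, which holds for non-negative values below 2^31).

```python
import random


def count_branched(numbers):
    """Count distinct values in a sorted list with an explicit comparison."""
    uniques = 1 if numbers else 0
    for i in range(len(numbers) - 1):
        if numbers[i] != numbers[i + 1]:
            uniques += 1
    return uniques


def count_shifted(numbers):
    """Same count without the branch, emulating Java's (a - b) >>> 31."""
    uniques = 1 if numbers else 0
    for i in range(len(numbers) - 1):
        # view the difference as an unsigned 32-bit int and keep only the sign bit
        uniques += ((numbers[i] - numbers[i + 1]) & 0xFFFFFFFF) >> 31
    return uniques


rng = random.Random(42)
numbers = sorted(rng.randrange(1000) for _ in range(10_000))
print(count_branched(numbers), count_shifted(numbers), len(set(numbers)))
```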

The latter eliminates the inner branch and therefore a potential branch penalty. Hence, it should be faster, but is it really so? That is exactly what we can test with the following benchmark:

Here, I try both variants on an array of 1 000 000 random numbers in the range from 0 to bound. I use the bounds 1 000, 10 000, 100 000, 1 000 000, 10 000 000 and 100 000 000 to vary the actual cardinality of the generated data set. On my MacBook, this gives the following results:

Benchmark                        (bound)   (size)  Mode  Cnt     Score     Error  Units
MyFirstBenchmark.testBranched       1000  1000000  avgt   20   764.576 ± 150.049  us/op
MyFirstBenchmark.testBranched      10000  1000000  avgt   20   837.783 ± 143.009  us/op
MyFirstBenchmark.testBranched     100000  1000000  avgt   20  1598.128 ±  17.773  us/op
MyFirstBenchmark.testBranched    1000000  1000000  avgt   20  1209.819 ±  39.535  us/op
MyFirstBenchmark.testBranched   10000000  1000000  avgt   20  1068.606 ±  42.052  us/op
MyFirstBenchmark.testBranched  100000000  1000000  avgt   20   625.952 ±  24.321  us/op
MyFirstBenchmark.testShifted        1000  1000000  avgt   20   973.910 ±  41.843  us/op
MyFirstBenchmark.testShifted       10000  1000000  avgt   20   966.573 ±  30.002  us/op
MyFirstBenchmark.testShifted      100000  1000000  avgt   20   966.102 ±  17.895  us/op
MyFirstBenchmark.testShifted     1000000  1000000  avgt   20   973.528 ±  24.396  us/op
MyFirstBenchmark.testShifted    10000000  1000000  avgt   20   957.287 ±  31.399  us/op
MyFirstBenchmark.testShifted   100000000  1000000  avgt   20  1049.479 ± 108.593  us/op

This shows that for both very large and very small cardinalities the branched code is significantly faster than the shifted one, although somewhere in the middle it is indeed significantly slower. The reason for this is, of course, branch prediction (read the answers to this question on StackOverflow for details): with a very small bound almost all neighboring elements are equal, and with a very large bound almost all of them differ, so the branch is easy to predict, while in between its outcome is much less regular. This illustrates exactly why we cannot assume that branch elimination by bit-twiddling will always improve performance.
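One way to see this is to measure how often the condition numbers[i] != numbers[i + 1] is true for each bound. The sketch below assumes the benchmark array is sorted (the counting method requires it) and scales both the array size and the bounds down by a factor of ten to keep it quick; the function name is hypothetical.

```python
import random


def fraction_distinct_neighbors(size, bound, seed=0):
    """Share of adjacent pairs in a sorted random array that differ (branch taken)."""
    rng = random.Random(seed)
    numbers = sorted(rng.randrange(bound) for _ in range(size))
    differing = sum(1 for a, b in zip(numbers, numbers[1:]) if a != b)
    return differing / (size - 1)


for bound in [100, 1_000, 10_000, 100_000, 1_000_000, 10_000_000]:
    print(bound, round(fraction_distinct_neighbors(100_000, bound), 3))
```

Fractions close to 0 or 1 mean the branch almost always goes the same way; values in between are where mispredictions become likely.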

So much for the first try! This was easy and fun, and I will definitely look more into JMH in the future. And by the way, JMH can also profile code via the -prof <profiler> option (-lprof lists the available profilers).

Apache Spark is an open source cluster computing framework, which is becoming extremely popular these days. By now it has taken over the role of many previously used MapReduce and machine learning frameworks. There are already plenty of recipes on how to launch a cluster and get the examples and the shell running. Nevertheless, assume that for educational purposes, or for any other odd reason, we would like to build a single JAR with all dependencies included, which then runs some Spark code on its own. In that case, here is a simple four-step recipe to get started from scratch.

Create a new Maven Java project

The easiest way to do this is from the command line (look here for an explanation):

mvn archetype:generate -DgroupId=com.simonj.demo -DartifactId=spark-fun -DarchetypeArtifactId=maven-archetype-quickstart -DinteractiveMode=false

Edit the POM file

In my example, I first explicitly state the Java version, 1.8. Then, I remove the junit dependency and add dependencies to spark-core_2.10, testng and guava (note the version 16.0 to avoid conflicts with the current version of spark-core). Finally, I use the Maven shade plugin to include the dependencies, with additional filters and transformers to get this stuff working.

Import the project into an IDE and edit the files

In the next step, I import the project into Eclipse and edit App.java and AppTest.java. The code illustrates a simple word count in Spark, but the important part here is something like the following (where I launch a new Spark context with a local master):

try (JavaSparkContext context = new JavaSparkContext("local[2]", "Spark fun!")) {
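The word count itself is the classic map/reduce pattern: split lines into words, map each word to (word, 1) and sum the ones per word (in Spark's Java API typically with flatMap, mapToPair and reduceByKey). A plain-Python sketch of the same logic, not the project's code; the tokenization (lowercase, split on anything that is not a letter or digit) is an assumption that is consistent with the output below, e.g. s=1 and 1960s=1.

```python
import re


def word_count(lines):
    """Count words the way a map/reduce word count does, but sequentially."""
    counts = {}
    for line in lines:  # each line corresponds to one record of the input RDD
        # "flatMap": lowercase the line and split it into words on non-alphanumerics
        for word in re.split(r"[^a-z0-9]+", line.lower()):
            if word:  # skip empty strings produced at the line boundaries
                # "reduceByKey" with +: add 1 to the word's running count
                counts[word] = counts.get(word, 0) + 1
    return counts


print(word_count(["Lorem Ipsum has been the industry's standard", "since the 1500s"]))
```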

Build the project and run

In the final step, I first build the project:

mvn clean package

Then I create a test file and run the App class from the command line (note that here I use the allinone.jar, which is the one with all dependencies included):

java -cp target/spark-fun-1.0-SNAPSHOT-allinone.jar com.simonj.demo.App test.txt

Finally, after a short time the example program spits out something like this:

{took=1, lorem=4, but=1, text=2, is=1, standard=1, been=1, sheets=1, including=1, electronic=1, of=4, not=1, software=1, type=2, survived=1, book=1, only=1, s=1, desktop=1, to=1, passages=1, containing=1, and=3, versions=1, more=1, typesetting=2, essentially=1, recently=1, ipsum=4, a=2, galley=1, aldus=1, 1960s=1, simply=1, when=1, ever=1, dummy=2, with=2, 1500s=1, in=1, publishing=1, like=1, printing=1, five=1, industry=2, letraset=1, pagemaker=1, since=1, was=1, an=1, into=1, the=6, make=1, has=2, it=3, remaining=1, unknown=1, popularised=1, leap=1, unchanged=1, centuries=1, specimen=1, also=1, printer=1, release=1, scrambled=1}

So it works – what a lovely evening and good night folks!

PS. Here is the complete project created through these steps.
