Tuesday, May 22, 2012

Clojure Lessons

Recently I've been working with Java code in the Spring framework. I'm not a big fan of Spring, since the bean approach means that everything has a similar public face, so the data types don't document the system very well. The bean approach also means that most types can be plugged into most places (kind of like Lego), but just because something can be connected doesn't mean it will do anything meaningful. It can make for a confusing system. As a result, I'm not really having much fun at work.

To get myself motivated again, I thought I'd try something fun and render a Mandelbrot set. I know these are easy, but it's something I've never done myself. I also thought it might be fun to do something with graphics on the JVM, since I'm always working on server-side code. It turned out to be fun, and it kept me up much later than it should have. Being tired tonight, I may end up rambling a bit. That may also be why I've decided to spell "colour" the way I grew up with, rather than the US way (except in code: I have to use the Color class, and it would be too confusing to have two different spellings in the same program).

To get my feet wet, I started with a simple Java application, planning to move it to Clojure later. This gave me a class called Complex that does basic complex arithmetic (trivial to write, but surprising that the JDK doesn't already have one), and an abstract class called Drawing that handles all of the window management and just expects a subclass to implement paint(Graphics). With that done, it was easy to write a pair of functions:

  • coord2Math to convert a canvas coordinate into a complex number.
  • mandelbrotColor to calculate a colour for a given complex number (using a logarithmic scale, since linear shows too many discontinuities in colour).
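The post doesn't show these two helpers, so here is a rough sketch of what they might look like. It is written in Clojure, since that's where this is heading, and uses the hyphenated names that appear later. The canvas size, viewport bounds, iteration cap and colour ramp are all my assumptions. Complex numbers are plain [re im] vectors rather than the Complex class.

```clojure
(import '(java.awt Color))

;; Assumed canvas size and viewport into the complex plane.
(def canvas-width 300)
(def canvas-height 200)
(def re-min -2.5)
(def re-max 1.0)
(def im-min -1.0)
(def im-max 1.0)
(def max-iterations 256)

(defn coord-2-math
  "Maps canvas pixel (x, y) to a complex number, here a plain [re im] vector."
  [x y]
  [(+ re-min (* (- re-max re-min) (/ (double x) canvas-width)))
   (- im-max (* (- im-max im-min) (/ (double y) canvas-height)))])

(defn escape-count
  "Iterates z -> z^2 + c from z = 0 and returns the number of steps taken
  before |z| exceeds 2, or max-iterations if it never does."
  [[cr ci]]
  (let [cr (double cr)
        ci (double ci)]
    (loop [zr 0.0, zi 0.0, n 0]
      (if (or (= n max-iterations)
              (> (+ (* zr zr) (* zi zi)) 4.0))
        n
        (recur (+ (- (* zr zr) (* zi zi)) cr)
               (+ (* 2.0 zr zi) ci)
               (inc n))))))

(defn mandelbrot-color
  "Points that never escape are black. The rest get a cyan brightness
  proportional to log(escape count), which smooths out the colour bands."
  [c]
  (let [n (escape-count c)]
    (if (= n max-iterations)
      Color/BLACK
      (let [t (/ (Math/log (double (inc n)))
                 (Math/log (double (inc max-iterations))))]
        (Color. (float 0.0) (float t) (float t))))))
```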
Drawing this onto a graphical context is easy in Java:
for (int x = 0; x < gWidth; x++) {
  for (int y = 0; y < gHeight; y++) {
    g.setColor(mandelbrotColor(coord2Math(x, y)));
    plot(g, x, y);
  }
}

(plot(Graphics,int,int) is a simple function that draws one pixel at the given location).

On this MacBook Pro, a small image (300x200 pixels) takes ~360ms, and a big one (1397x856) takes ~11500ms. There's room for improvement, but it'll do. So with a working Java implementation in hand, I turned to writing the same thing in Clojure.

Clojure Graphics

Initially I tried extending my Drawing class using proxy, planning to move to a pure Clojure implementation later. However, after getting it working that way, I realised that doing the entire thing in Clojure wasn't going to take much at all, so I did that straight away. The resulting code is reasonably simple boilerplate:

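(ns mandelbrot.core   ;; hypothetical namespace name
  (:import (javax.swing JFrame JPanel SwingUtilities)
           (java.awt Color Dimension Graphics)))

(def window-name "Mandelbrot")
(def default-width 1397)   ;; placeholder size; the original values aren't shown
(def default-height 856)
;; Defined elsewhere as a (fn [graphics-context width height] ...).
(declare draw-fn)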

(defn new-drawing-obj []
  (proxy [JPanel] []
    (paint [^Graphics graphics-context]
      (let [width (proxy-super getWidth)
            height (proxy-super getHeight)]
        (draw-fn graphics-context width height)))))

(defn show-window []
  (let [^JPanel drawing-obj (new-drawing-obj)
        frame (JFrame. window-name)]
    (.setPreferredSize drawing-obj (Dimension. default-width default-height))
    (.add (.getContentPane frame) drawing-obj)
    (doto frame
      (.setDefaultCloseOperation JFrame/EXIT_ON_CLOSE)
      (.pack)
      (.setBackground Color/WHITE)
      (.setVisible true))))

(defn start-window []
  (SwingUtilities/invokeLater #(show-window)))

Calling start-window queues show-window to run asynchronously on Swing's event dispatch thread (the thread that runs the event loop, started on demand). show-window uses new-drawing-obj to create a proxy object that handles the paint event. It then sets the preferred size of the panel, puts the panel into a frame (the main window), and sets the frame up for display.

The only thing that seems worth noting from a Clojure perspective is the proxy object returned by new-drawing-obj. This is a simple extension of javax.swing.JPanel that overrides the paint(Graphics) method of that class. Almost every part of the drawing can be done in an external function (draw-fn here). However, the width and height come from calling getWidth() and getHeight() on the JPanel object, and that object isn't directly available to draw-fn. Inside paint, I called these inherited methods with proxy-super. That works, but it turns out not to be the only reasonable way. proxy binds an implicit this to the proxy object in every method body, so (.getWidth this) does the same job. (I can also think of some unreasonable ways, like setting a reference to the proxy and using that reference in paint. But we won't talk about that kind of abuse.)
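Clojure's proxy passes each method body an implicit first argument bound to this, the proxy instance. Below is a minimal sketch of the this-based version. new-sized-panel is a hypothetical variant of new-drawing-obj that takes the drawing function as a parameter, so it can be exercised without opening a window.

```clojure
(import '(javax.swing JPanel)
        '(java.awt Graphics))

(defn new-sized-panel
  "Hypothetical variant of new-drawing-obj: draw-fn is passed in, and the
  panel size is read through the implicit `this` instead of proxy-super."
  [draw-fn]
  (proxy [JPanel] []
    (paint [^Graphics g]
      ;; `this` is bound to the proxy instance inside every proxy method,
      ;; so inherited JPanel methods can be called on it directly.
      (draw-fn g (.getWidth ^JPanel this) (.getHeight ^JPanel this)))))
```

proxy-super is still the right tool for calling the superclass's version of a method the proxy itself overrides, such as (proxy-super paint g).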

While I haven't shown it here, I also wanted to close the window by pressing the "q" key. This takes just a couple of lines of code: create a proxy for KeyListener, then add it to the frame with (.addKeyListener frame the-key-listener-proxy). Compared to the equivalent Java code, it's strikingly terse.
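Here is a minimal sketch of such a listener. It assumes the listener should fire on a typed "q", since keyTyped is the event that reliably carries a character. quit-key-listener and its on-quit callback are hypothetical names. The callback keeps the listener testable instead of calling System/exit directly.

```clojure
(import '(java.awt.event KeyEvent KeyListener))

(defn quit-key-listener
  "Returns a KeyListener proxy that calls on-quit when \"q\" is typed.
  A proxy of an interface throws UnsupportedOperationException for any
  method it doesn't define, so all three KeyListener methods are given."
  [on-quit]
  (proxy [KeyListener] []
    (keyTyped [^KeyEvent e]
      (when (= \q (.getKeyChar e))
        (on-quit)))
    (keyPressed [_])
    (keyReleased [_])))

;; Hypothetical wiring, given the JFrame created in show-window:
;; (.addKeyListener frame (quit-key-listener #(System/exit 0)))
```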

Rendering

The Java code for rendering used a pair of nested loops to generate coordinates, calculating the colour for each coordinate as it went. However, this imperative style is something functional programming generally tries to avoid. So the question for me at this point was: how should I think about the problem?

Each call to mandelbrotColor maps a coordinate to a colour. This gave me my first hint: I needed to map coordinates to colours. That implies calling map on a seq of coordinates and ending up with a seq of colours. (Actually, it needn't be a seq at all; anything reducible will do.) But what order are the colours in? Row by row? That would work, but it would mean keeping track of the offset while working through the seq. That seems onerous, particularly when the required coordinates are available at the moment the colour is calculated. So why not include the coordinates in the seq along with the colour? Not only does that simplify processing, it also makes rendering the result stateless, since any element of the seq can be rendered independently of any other.

Coordinates can be created as pairs of integers using a comprehension:

  (for [x (range width) y (range height)] [x y])

and the calculation can be done by mapping on a function that unpacks x and y and returns a triple of these two coordinates along with the calculated colour. I'll rename x and y to "a" and "b" in the mapping function to avoid ambiguity:

  (map (fn [[a b]] [a b (mandelbrot-color (coord-2-math a b))])
       (for [x (range width) y (range height)] [x y]))
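In the comprehension, x is the outer binding and y the inner one, so y varies fastest. The pairs therefore come out column by column, in the same order as the nested Java loops. A 2x3 case shows this:

```clojure
;; x is bound first (outer), y second (inner), so y changes fastest.
(def pairs (for [x (range 2) y (range 3)] [x y]))

pairs
;; => ([0 0] [0 1] [0 2] [1 0] [1 1] [1 2])
```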

So now we have a sequence of coordinates and colours, but how do these get turned into an image? Again, the form of the problem provides the solution. We have a sequence (of tuples), and we want to reduce it to a single value (an image). Reductions like this are done with reduce. The first parameter of the reducing function is the image, the second is the next tuple to draw, and the result is a new image with that tuple drawn in it. A reducing function isn't really supposed to mutate its first parameter. Here, though, we have no use for the image as it was before the pixel was drawn, so mutating it in place works for us. The result is the following reducing function (the type hint avoids reflection on the graphics context):

  (defn plot [^Graphics g [x y c]]
    (.setColor g c)
    (.fillRect g x y 1 1)
    g)

Note that the original graphics context is returned, since this is the "new" value that plot has created (i.e. the image with the pixel added to it). Also note that the second parameter is a 3-element tuple, which is simply destructured into x, y and c.

So now the entire render process can be given as:

(reduce plot g
  (map (fn [[a b]] [a b (mandelbrot-color (coord-2-math a b))])
       (for [x (range width) y (range height)] [x y])))
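Putting the pieces together, here is a self-contained sketch that renders into an off-screen BufferedImage instead of a window, so it can run without a display. The real coord-2-math and mandelbrot-color are replaced by trivial, hypothetical stand-ins that produce a simple gradient. Only the shape of the map/reduce pipeline matters here.

```clojure
(import '(java.awt Color Graphics))

;; Hypothetical stand-ins, just enough to make the pipeline run.
(defn coord-2-math [a b] [a b])
(defn mandelbrot-color [[a b]]
  (Color. (int (mod a 256)) (int (mod b 256)) 0))

(defn plot
  "Reducing function: draws one [x y colour] tuple and returns the same
  (mutated) Graphics object as the accumulator."
  [^Graphics g [x y c]]
  (.setColor g c)
  (.fillRect g x y 1 1)
  g)

(defn render
  "Maps every (x, y) to [x y colour], then folds the tuples into g."
  [g width height]
  (reduce plot g
          (map (fn [[a b]] [a b (mandelbrot-color (coord-2-math a b))])
               (for [x (range width) y (range height)] [x y]))))
```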

This works, but performance was a problem, and dealing with that turned out to be the most interesting part of the process. The full-screen render (1397x856) took ~682 seconds, up from the 11.5 seconds it took in Java. Obviously there were a few things to fix. There's still more to do, but I'll share what I've found so far.

Reflection

The first thing that @objcmdo suggested was to look for reflection. I planned on doing that, but thought I'd continue cleaning the program up first. The Complex class was still written in Java, so I embarked on rewriting that in Clojure.

The easiest way to do this was to implement a protocol that describes the actions (plus, minus, times, divide, absolute value), and to then define a record (of real/imaginary) that extends the protocol. It would have been nicer than the equivalent Java, but for one thing. Java allows method overloading based on parameter types, which means that a method like plus can be defined differently depending on whether it receives a double value, or another Complex number. My understanding is that Clojure only overloads functions based on the parameter count, meaning that different function names are required to redefine the same operation for different types. So for instance, the plus functions were written in Java as:

  public final Complex plus(Complex that) {
    return new Complex(real + that.real, imaginary + that.imaginary);
  }

  public final Complex plus(double that) {
    return new Complex(real + that, imaginary);
  }
But in Clojure I had to give them different names:
  (plus [this {that-real :real, that-imaginary :imaginary}]
        (Complex. (+ real that-real) (+ imaginary that-imaginary)))
  (plus-dbl [this that] (Complex. (+ real that) imaginary))

Not a big deal, but code like math manipulation looks prettier when function overloading is available.
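For reference, here is a minimal sketch of the protocol-plus-record approach with only a few of the operations. The protocol name ComplexOps and the times and absolute functions are my own hypothetical names; plus and plus-dbl follow the post.

```clojure
(defprotocol ComplexOps
  (plus [this that] "Adds another Complex.")
  (plus-dbl [this that] "Adds a real number (a double).")
  (times [this that] "Multiplies by another Complex.")
  (absolute [this] "Returns the magnitude |z|."))

;; ^double fields are stored as primitive doubles in the record.
(defrecord Complex [^double real ^double imaginary]
  ComplexOps
  (plus [this {that-real :real, that-imaginary :imaginary}]
    (Complex. (+ real that-real) (+ imaginary that-imaginary)))
  (plus-dbl [this that]
    (Complex. (+ real that) imaginary))
  (times [this {r :real, i :imaginary}]
    ;; (a + bi)(c + di) = (ac - bd) + (ad + bc)i
    (Complex. (- (* real r) (* imaginary i))
              (+ (* real i) (* imaginary r))))
  (absolute [this]
    (Math/sqrt (+ (* real real) (* imaginary imaginary)))))
```

Protocol functions dispatch on the type of their first argument only. That's why plus can't also be specialised on the type of that, and the double case needs its own name.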

It may be worth pointing out that I used the names of the operations (like "plus") instead of the symbolic operators ("+"). The lack of type-based overloading would have made symbols awkward anyway (+dbl is no clearer than plus-dbl), but the bigger problem is that they clash with functions of the same name in clojure.core. Some namespaces do this (* is a popular one to reuse), but I don't like it. You have to explicitly exclude the core function from your namespace (with :refer-clojure :exclude in the ns declaration). After that, you have to refer to the core version by its fully qualified name if you do need it. Since Complex itself needs the ordinary arithmetic operators for its real and imaginary parts, that would happen constantly.
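For comparison, this is roughly what shadowing + would involve (complex.ops is a hypothetical namespace):

```clojure
(ns complex.ops
  ;; Stop clojure.core/+ from being referred into this namespace,
  ;; so that + can be redefined here.
  (:refer-clojure :exclude [+]))

(defn +
  "Adds two complex numbers represented as [re im] vectors."
  [[ar ai] [br bi]]
  ;; The core + is still available, but only by its fully qualified name.
  [(clojure.core/+ ar br) (clojure.core/+ ai bi)])
```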

So I created my protocol containing all the operators, defined a Complex record to implement it, and then I replaced all use of the original Java Complex class. Once I was finished I ran it again just to make sure that I hadn't broken anything.

To my great surprise, the full screen render went from 682 seconds down to 112 seconds. Protocols are an efficient mechanism, but they shouldn't be that good. At that point I realised that I hadn't used type hints around the Complex class, and that as a consequence the Clojure code had to perform reflection on the complex numbers. Just as @objcmdo had suggested.

Wondering what other reflection I might have missed, I tried enabling the *warn-on-reflection* flag in the REPL, but no warnings were forthcoming. My first suspicion was that the code runs on a thread belonging to the Swing runtime, and that this was hiding the warnings. A likelier explanation is that the flag is checked when code is compiled, not when it runs. It therefore only reports on code that is loaded (or reloaded) after the flag has been set. I tried adding some other type hints, but nothing I added had any effect. Either the Clojure compiler was already able to figure out the types involved, or the code just wasn't on a critical path.
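One way to check this is to compile a form with the flag bound and capture what the compiler prints. The helper name reflection-warnings is hypothetical. It relies only on reflection warnings being emitted at compile time and written to *err*.

```clojure
(defn reflection-warnings
  "Compiles form with *warn-on-reflection* switched on, and returns whatever
  the compiler wrote to *err* while doing so (an empty string if nothing)."
  [form]
  (let [err (java.io.StringWriter.)]
    (binding [*warn-on-reflection* true
              *err* err]
      (eval form))
    (str err)))

;; Unhinted: the compiler can't know what x is, so .length is resolved by
;; reflection on every call, and a "Reflection warning" is printed.
(reflection-warnings '(fn [x] (.length x)))

;; Hinted: compiled as a direct call to String.length(), no warning.
(reflection-warnings '(fn [^String x] (.length x)))
```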

Composable Abstractions

The next thing I wondered about was the map/reduce part of the algorithm. While it made for elegant programming, it was creating unnecessary tuples at every step of the way. Could these be having an impact?

Once you have a nice list comprehension, it's tough to break it out into an imperative-style loop. Aside from ruining the elegance of the original construct, once you've seen your way through to viewing a problem in such clear terms, it's difficult to reconceptualize it as a series of steps. Even when you do, how do you make Clojure work against itself?

Creating a loop without burning through resources is easily done with tail recursion. Clojure doesn't optimise tail calls automatically (the JVM doesn't support tail-call elimination), but loop/recur emulates it well. Since I want to loop from 0 (inclusive) to the width/height (exclusive), I decremented the upper limits for convenience. Also, the plot function is no longer constrained to two arguments. So I changed it to accept all four arguments directly (the graphics context, x, y and the colour), which eliminates the need to construct the 3-tuple:

(let [dwidth (dec width)
      dheight (dec height)]
  (loop [x 0 y 0]
    (let [[next-x next-y] (if (= x dwidth)
                            (if (= y dheight)
                              [-1 -1]      ;; signal to terminate
                              [0 (inc y)]) ;; end of row: wrap to the next one
                            [(inc x) y])]
      (plot g x y (mandelbrot-color (coord-2-math x y)))
      (if (= -1 next-x)
        :end    ;; anything can be returned here
        (recur next-x next-y)))))

My word, that's ugly. The let that assigns next-x and next-y has a nasty nested if that increments x and resets it at the end of each row. It also returns a flag (-1 here, though any value that can't be a valid coordinate would do) to indicate that the loop should terminate. The loop itself ends by testing for that flag and returning a value (:end) that is simply ignored.
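For comparison, nested dotimes forms can walk the same row-by-row order without the decremented limits, the next-x/next-y pair or the sentinel. This is my substitution, not what the post measured. color-fn is a hypothetical parameter standing in for (comp mandelbrot-color coord-2-math), and plot is the four-argument version described above.

```clojure
(import '(java.awt Color Graphics))

(defn plot
  "Four-argument plot: draws colour c at (x, y)."
  [^Graphics g x y ^Color c]
  (.setColor g c)
  (.fillRect g (int x) (int y) 1 1))

(defn render-loop
  "Row-by-row imperative render; color-fn maps (x, y) to a Color."
  [g width height color-fn]
  ;; dotimes binds y to 0..height-1 and x to 0..width-1, so the
  ;; bookkeeping and termination flag of the loop/recur version go away.
  (dotimes [y height]
    (dotimes [x width]
      (plot g x y (color-fn x y)))))
```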

But it all works as intended. Now instead of creating a tuple for every coordinate, it simply iterates through each coordinate and plots the point directly, just as the Java code did. So what's the performance difference here?

So far, the numbers I've provided are rounded to the nearest second. Repeated runs have usually taken a similar amount of time to the ones that I've reported here. However, there is always some jitter, sometimes by several seconds. Because of this, I was unable to see any difference whatsoever between using map/reduce on a for comprehension, versus using loop/recur.

That's an interesting result: it shows that the Clojure compiler and the JVM really are as clever as we're told, since the better abstraction runs just as efficiently as the direct approach. It's all well and good for a language to make powerful constructs easy to write. But a language offers really useful power when elegant code runs as efficiently as direct, imperative code.

Aside from the obvious clarity issues, the composability of the for/map/reduce makes an enormous difference. Because each element in the range being mapped is completely independent, we are free to use the pmap function instead of map. The documentation claims that this function is,

"Only useful for computationally intensive functions where the time of f dominates the coordination overhead."

Yup. That's us.

So how much does this change make for us? Using map on the current code, a full screen render takes 112 seconds. Changing map to pmap improves it to 75 seconds. That's a 33% improvement with no work, simply because the correct abstraction was applied. That's a very powerful abstraction.
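Here is a sketch of that one-word change, with the mapping function passed in so both versions can be compared. The map-fn and color-fn parameters are my additions for illustration. pmap only parallelises the colour calculations. reduce still consumes the results one at a time on the calling thread, so the Graphics object is never touched by more than one thread.

```clojure
(import '(java.awt Color Graphics))

(defn plot [^Graphics g [x y c]]
  (.setColor g c)
  (.fillRect g x y 1 1)
  g)

(defn render
  "map-fn is map or pmap; color-fn is a hypothetical stand-in for
  (comp mandelbrot-color coord-2-math)."
  [map-fn g width height color-fn]
  ;; Under pmap the colour tuples are computed in parallel, but reduce
  ;; (and therefore every drawing call) runs on the calling thread.
  (reduce plot g
          (map-fn (fn [[a b]] [a b (color-fn a b)])
                  (for [x (range width) y (range height)] [x y]))))
```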

Future Work

(Hmmm, that makes this sound like an academic paper. Should I be drawing charts?)

The final result is still a long way short of the 11.5 seconds the naïve Java code renders at. The single threaded version is particularly bad, taking about 10 times as long. I don't expect Clojure to be as fast as Java, but a factor of 10 suggests that there are some obvious things that I've missed, most likely related to reflection. If I can get it down to the same order of magnitude as the Java code, then using pmap could make the Clojure version faster due to being multi-threaded. Of course, Java can be multi-threaded as well, but the effort and infrastructure for doing this would be significant.