Blending IPython's widgets and mpld3's plugins

This notebook does much the same job as the 'sliderPlugin' example: acting on the visualisation in the browser triggers recalculations in the IPython backend. Whereas the sliderPlugin connects to the kernel itself, we use IPython's own facilities: interact does the heavy lifting for us.
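Conceptually, interact hooks a callback to the widget's value. It calls our function once when it is displayed, and again every time that value changes. The pure-Python sketch below shows that pattern. It assumes a hypothetical ValueBox class that stands in for the text widget; it is not the IPython API and does not use IPython at all.

```python
class ValueBox(object):
    """Hypothetical stand-in for a synced widget value (not the IPython API)."""

    def __init__(self, value=""):
        self._value = value
        self._callbacks = []

    def observe(self, callback):
        # register a function to be called with each new value
        self._callbacks.append(callback)

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, new):
        # notify observers only on an actual change, like a 'change' event
        if new != self._value:
            self._value = new
            for callback in self._callbacks:
                callback(new)


def interact_sketch(func, box):
    """Hypothetical mini-interact: call func once now, then on every change."""
    box.observe(func)
    func(box.value)


calls = []
box = ValueBox()
interact_sketch(calls.append, box)
box.value = "0,1.5,2.0"   # what the browser would send after a drop
box.value = "0,1.5,2.0"   # same value again: no new call
print(calls)
```

In the notebook, the "func" role is played by the function wrapping Fit.redraw, and the browser plays the role of whoever assigns box.value.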

Because it needs a running IPython kernel, you cannot use this notebook directly on a static viewer such as nbviewer. You have to download it and run it in IPython yourself.

I used IPython 3.0.0-dev as of 2014/11/03. The widget interface does not seem very stable yet, so you may have to tinker to get this working. If you run into problems, the examples this notebook builds on (linked in the code comments) should be good material for getting the whole thing working again.

Objective

We want to fit a curve through a cloud of points. The user of the notebook can drag and drop the points, and each time a point is dropped, the fit is recalculated.
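With as many points as parameters, "fitting" means solving the square system y_i = f(x_i, params) exactly, i.e. finding a root of the vector of residuals. The notebook delegates this to scipy.optimize.fsolve, a wrapper around MINPACK's hybrd/hybrj routines. To illustrate the same idea only, here is a sketch that uses a plain Newton iteration with a forward-difference Jacobian instead. That is a different, simpler algorithm than fsolve's. The helpers residuals and newton_exact_fit are hypothetical.

```python
import numpy as np


def exp_ech(x, params):
    # same model as in the notebook: a * (1 - exp(-b * x))
    return params[0] * (1 - np.exp(-params[1] * x))


def residuals(simulate, pts, params):
    # one equation per point: observed y minus model y
    return pts[:, 1] - simulate(pts[:, 0], params)


def newton_exact_fit(simulate, pts, guess, tol=1e-12, max_iter=50, h=1e-7):
    """Hypothetical helper: solve residuals(params) = 0 by Newton's method."""
    params = np.asarray(guess, dtype=float)
    for _ in range(max_iter):
        r = residuals(simulate, pts, params)
        if np.max(np.abs(r)) < tol:
            break
        # forward-difference Jacobian of the residuals, one column per parameter
        jac = np.empty((len(r), len(params)))
        for j in range(len(params)):
            step = np.zeros_like(params)
            step[j] = h
            jac[:, j] = (residuals(simulate, pts, params + step) - r) / h
        # Newton step: solve J @ delta = -r, then move by delta
        params = params - np.linalg.solve(jac, r)
    return params


# two points generated from a = 2, b = 0.5, so an exact fit exists
true = np.array([2.0, 0.5])
xs = np.array([1.0, 3.0])
pts = np.column_stack([xs, exp_ech(xs, true)])
fit = newton_exact_fit(exp_ech, pts, guess=[1.5, 0.7])
print(fit)
```

Plain Newton needs a reasonable starting guess. fsolve's hybrid method is generally more forgiving, which is one reason to prefer it in the notebook.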

The model can be pretty much any $\mathbb{R} \to \mathbb{R}$ function, with any number of parameters; the cloud then contains as many points as the model has parameters.
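Because the model is an arbitrary Python function, its number of parameters has to be discovered at run time. The Fit class further down does this by calling the model with a growing list of dummy parameters until it stops raising IndexError. Below is the same probe, isolated as a hypothetical count_params helper. It assumes the model reads its parameters by integer index and fails with IndexError (not some other error) when one is missing.

```python
from __future__ import print_function

import math


def count_params(simulate, limit=100):
    """Hypothetical helper: number of parameters simulate(x, params) reads.

    Grows a list of dummy parameters until indexing it no longer fails.
    """
    params = []
    while len(params) < limit:
        try:
            simulate(0.0, params)
            return len(params)
        except IndexError:
            params.append(1.0)
    raise ValueError("model needs more than %d parameters" % limit)


def exp_ech(x, params):
    return params[0] * (1 - math.exp(-params[1] * x))


def arctan(x, params):
    return params[0] * math.atan(params[1] * x + params[2])


print(count_params(exp_ech), count_params(arctan))  # 2 3
```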

Below, you will see it:

  • (partially) applied to a "first-order exponential response to a Heaviside step function" (for lack of a better name on my side);
  • applied to an arctangent.
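For reference, here are the two models as defined in the code cells (exp_ech and arctan). Below them is the square system that is solved once the cloud has as many points $(x_i, y_i)$ as the model has parameters $\theta$:

$$
\begin{aligned}
f_{\exp}(x; a, b) &= a\,\bigl(1 - e^{-b x}\bigr), \\
f_{\arctan}(x; a, b, c) &= a \arctan(b x + c), \\
y_i &= f(x_i; \theta), \qquad i = 1, \dots, n, \quad n = \dim \theta.
\end{aligned}
$$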

Architecture

Here is how things are organized:

  1. code copied from the ClickInfo/DragPoints examples on the mpld3 side generates an update whenever the user drags and drops one of the circles;
  2. each update carries the index and the new coordinates of the moved point of the cloud;
  3. the update triggers the 'change' event on an IPython text widget (code taken from the custom widget example);
  4. IPython's widget machinery transmits the update back to the Python kernel;
  5. where we recalculate the parameters and redraw everything.
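In steps 2 and 3, the update travels as a plain string "i,x,y": the point index, then the new data coordinates. The drag plugin's dragended handler writes it into the text widget, and Fit.redraw parses it back. The sketch below covers the Python side of that small protocol. It uses a hypothetical apply_update helper that also rejects malformed messages; the notebook itself does no validation.

```python
import numpy as np


def apply_update(points, message):
    """Hypothetical helper: apply an "i,x,y" drag message to an (n, 2) array.

    Returns True if a point was moved, False for the empty initial value.
    """
    if message == "":
        # the widget starts empty: nothing has been dragged yet
        return False
    fields = message.split(",")
    if len(fields) != 3:
        raise ValueError("expected 'i,x,y', got %r" % message)
    i = int(fields[0])
    if not 0 <= i < len(points):
        raise IndexError("no point with index %d" % i)
    points[i, 0] = float(fields[1])
    points[i, 1] = float(fields[2])
    return True


cloud = np.zeros((3, 2))
apply_update(cloud, "1,0.25,1.5")
print(cloud[1])
```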
In [1]:
# imports widget side
# see https://github.com/ipython/ipython/blob/2.x/examples/Interactive%20Widgets/Custom%20Widgets.ipynb
# and https://github.com/ipython/ipython/blob/master/examples/Interactive%20Widgets/Custom%20Widget%20-%20Hello%20World.ipynb

from __future__ import print_function # For py 2.7 compat

from IPython.html import widgets # Widget definitions
from IPython.display import display # Used to display widgets in the notebook
from IPython.utils.traitlets import Unicode # Used to declare attributes of our widget
from IPython.html.widgets import interact, interactive, fixed
In [2]:
# imports render side
# see http://mpld3.github.io/examples/drag_points.html

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

import mpld3
from mpld3 import plugins, utils
In [3]:
# imports solve side
# see http://stackoverflow.com/questions/8739227/how-to-solve-a-pair-of-nonlinear-equations-using-python

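from scipy.optimize import fsolve

# scratch check (disabled): solve for (a, b) so that the curve
# a * (1 - exp(-b * x)) passes exactly through (1, 1) and (2, 4)
#def expchelon(a, b, x):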
#    return a * (1 - np.exp(-b * x))

#def fun(p1, p2):
#    x1, y1 = p1
#    x2, y2 = p2
#    def equations(p):
#        a, b = p
#        return (y1 - expchelon(a, b, x1), y2 - expchelon(a, b, x2))
#    return equations

#equations = fun((1,1), (2,4))
#a, b =  fsolve(equations, (1, 1))

#print((a, b), expchelon(a, b, 1), expchelon(a, b, 2))
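The disabled scratch check asks fsolve for (a, b) such that a * (1 - exp(-b * x)) passes through (1, 1) and (2, 4). Because the second abscissa is twice the first, this particular case can also be solved by hand, which gives a reference to compare fsolve's answer against. With u = exp(-b), the ratio of the two equations is (1 - u^2) / (1 - u) = 1 + u = 4, so u = 3, b = -ln 3 and a = 1 / (1 - 3) = -0.5. Note that b comes out negative, so this exact fit is a growing exponential rather than a saturating one. The sketch below checks that derivation and assumes only the standard library.

```python
from __future__ import print_function

import math


def expchelon(a, b, x):
    # same formula as the commented-out scratch code
    return a * (1 - math.exp(-b * x))


# closed form, valid only because x2 == 2 * x1 (here x1 = 1, x2 = 2):
# y2 / y1 = (1 - u**2) / (1 - u) = 1 + u, with u = exp(-b * x1)
x1, y1 = 1.0, 1.0
x2, y2 = 2.0, 4.0
u = y2 / y1 - 1          # u = 3
b = -math.log(u) / x1    # b = -ln 3
a = y1 / (1 - u)         # a = -0.5

print((a, b), expchelon(a, b, x1), expchelon(a, b, x2))
```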
In [4]:
# widget sync'd python side
class GraphWidget(widgets.DOMWidget):
    _view_name = Unicode('GraphView', sync=True)
    description = 'coord'    
    value = Unicode(sync=True)
In [5]:
%%javascript
//widget javascript side
require(["widgets/js/widget", "widgets/js/manager"], function(widget, manager){
    // based on the DatePickerView
    var GraphView = widget.DOMWidgetView.extend({
        render: function() {
            // the id lets the DragPlugin's dragended handler find this input
            this.$text = $('<input />')
                .attr('type', 'text')
                .attr('id', 'feedback_widget')                
                .appendTo(this.$el);
        },
        
        update: function() {
            this.$text.val(this.model.get('value'));
            return GraphView.__super__.update.apply(this);
        },
        
        events: {"change": "handle_change"},
        
        handle_change: function(event) {
            this.model.set('value', this.$text.val());
            this.touch();
        },
    });
    
    manager.WidgetManager.register_widget_view('GraphView', GraphView);
});
In [6]:
# visualisation plugin
# based on mpld3's drag_points example (DragPlugin)
class DragPlugin(plugins.PluginBase):
    JAVASCRIPT = r"""
$('#feedback_widget').hide();
mpld3.register_plugin("drag", DragPlugin);
DragPlugin.prototype = Object.create(mpld3.Plugin.prototype);
DragPlugin.prototype.constructor = DragPlugin;
DragPlugin.prototype.requiredProps = ["id"];
DragPlugin.prototype.defaultProps = {}
function DragPlugin(fig, props){
    mpld3.Plugin.call(this, fig, props);
    mpld3.insert_css("#" + fig.figid + " path.dragging",
                     {"fill-opacity": "1.0 !important",
                      "stroke-opacity": "1.0 !important"});
};

DragPlugin.prototype.draw = function(){
    var obj = mpld3.get_element(this.props.id);

    var drag = d3.behavior.drag()
        .origin(function(d) { return {x:obj.ax.x(d[0]),
                                      y:obj.ax.y(d[1])}; })
        .on("dragstart", dragstarted)
        .on("drag", dragged)
        .on("dragend", dragended);

    obj.elements()
       .data(obj.offsets)
       .style("cursor", "default")
       .call(drag);

    function dragstarted(d) {
      d3.event.sourceEvent.stopPropagation();
      d3.select(this).classed("dragging", true);
    }

    function dragged(d, i) {
      d[0] = obj.ax.x.invert(d3.event.x);
      d[1] = obj.ax.y.invert(d3.event.y);
      d3.select(this)
        .attr("transform", "translate(" + [d3.event.x,d3.event.y] + ")");
    }

    function dragended(d,i) {
      d3.select(this).classed("dragging", false);
      // feed back the new position to python, triggering 'change' on the widget
      $('#feedback_widget').val("" + i + "," + d[0] + "," + d[1]).trigger("change");
    }
}"""

    def __init__(self, points):
        if isinstance(points, mpl.lines.Line2D):
            suffix = "pts"
        else:
            suffix = None

        self.dict_ = {"type": "drag",
                      "id": utils.get_id(points, suffix)}
In [7]:
# fit and draw
class Fit(object):
    def __init__(self, simulate, double_seeding=False):
        self.simulate = simulate
         
        # Probe the model for its number of parameters: call it with a growing
        # list of dummy parameters until it stops raising IndexError.
        # The cloud will then get one point per parameter.
        pseudo_fit = []
        while len(pseudo_fit) < 100:
            # cap the search to avoid an infinite loop
            try:
                simulate(0, pseudo_fit)
                print("we have %d parameters"%len(pseudo_fit))
                break
            except IndexError:
                pseudo_fit.append(1)
                
        # generate a random cloud with one point per parameter; standard
        # exponential samples are positive, so the points lie in the x > 0, y > 0 quadrant
        self.p = np.random.standard_exponential((len(pseudo_fit), 2))
        
        # initial parameter guess: all ones
        self.fit = np.array(pseudo_fit)
                
    def make_equations(self):
        def equations(params):
            return self.p[:,1] - self.simulate(self.p[:,0], params)
        self.equations = equations
    
    def recalc_param(self):
        self.make_equations()
        self.fit = fsolve(self.equations, np.ones(self.fit.shape), xtol=0.01)
        
    def redraw(self, coord):
        # coord is the "i,x,y" string sent by the DragPlugin's dragended
        # handler, or "" before any point has been dragged
        
        # record the new position of point i
        if coord != "":
            i, x, y = coord.split(",")
            i = int(i)
            self.p[i][0] = float(x)
            self.p[i][1] = float(y)
            
        # recalculate best fit
        self.recalc_param()
        
        # draw things
        x = np.linspace(0, 10, 50) # 50 x points from 0 to 10
        y = self.simulate(x, self.fit)
    
        fig, ax = plt.subplots()

        points = ax.plot(self.p[:,0], self.p[:,1],'or', alpha=0.5, markersize=10, markeredgewidth=1)
        
        ax.plot(x,y,'r-')
        ax.set_title("Click and drag the points\nfitted parameters: %s" % np.array_str(self.fit, precision=2), fontsize=12)

        plugins.connect(fig, DragPlugin(points[0]))

        fig_h = mpld3.display(fig)  # HTML object rendering this figure
        display(fig_h)
In [8]:
# Dragging has no effect in this cell (no widget is attached); it only shows the fit.

def exp_ech(x, params):
    return params[0] * (1 - np.exp(-params[1] * x))

# Pin point 0 at the origin: exp_ech(0, params) == 0 for any params, so that
# equation always holds and the fit gains, in effect, one degree of freedom
Fit(exp_ech).redraw("0,0,0")
we have 2 parameters
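Pinning point 0 at the origin works because exp_ech(0, params) is 0 whatever the parameters are. The equation for that point is therefore satisfied identically, and its row of the Jacobian is zero. Only the other point then constrains (a, b), which leaves a one-parameter family of exact fits instead of a single one. The small numpy sketch below makes this visible with a forward-difference Jacobian. The jacobian helper is hypothetical and not part of the notebook.

```python
import numpy as np


def exp_ech(x, params):
    return params[0] * (1 - np.exp(-params[1] * x))


def jacobian(simulate, xs, params, h=1e-7):
    # forward-difference derivative of the model at each x, one column per parameter
    params = np.asarray(params, dtype=float)
    base = simulate(xs, params)
    cols = []
    for j in range(len(params)):
        step = np.zeros_like(params)
        step[j] = h
        cols.append((simulate(xs, params + step) - base) / h)
    return np.column_stack(cols)


xs = np.array([0.0, 1.3])          # point 0 pinned at x = 0
J = jacobian(exp_ech, xs, [1.0, 1.0])
print(J[0])                         # row for the pinned point: all zeros
print(np.linalg.matrix_rank(J))     # rank 1: one equation is degenerate
```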
In [9]:
def arctan(x, params):
    return params[0] * np.arctan(params[1] * x + params[2])

my_fit = Fit(arctan)

# For reasons not yet understood, passing the bound method directly,
#     interact(my_fit.redraw, coord=GraphWidget())
# does not work, so we wrap it in a plain function:
def f(coord):
    return my_fit.redraw(coord)
    
interact(f, coord=GraphWidget())