Scientific Machine Learning: Interpretable Neural Networks That Accurately Extrapolate From Small Data


The fundamental problems of classical machine learning are:

  1. Machine learning models require big data to train
  2. Machine learning models cannot extrapolate outside of their training data well
  3. Machine learning models are not interpretable

However, in our recent paper, we have shown that this does not have to be the case. In Universal Differential Equations for Scientific Machine Learning, we start by showing the following figure:

Indeed, it shows that by seeing only the tiny first part of the time series, we can automatically learn the equations in such a manner that the model predicts the time series will be cyclic in the future, in a way that even gets the periodicity correct. Not only that, but our result is not a neural network; rather, the program itself spits out the LaTeX for the differential equation:

$$x^\prime = \alpha x - \beta x y$$

$$y^\prime = -\delta y + \gamma x y$$

with the correct coefficients, which is exactly how we generated the data.

Rather than just explaining the method, what I really want to convey in this blog post is why it intuitively makes sense that our methods work. Intuitively, we utilize all of the known scientific structure to embed as much prior knowledge as possible. This is the opposite of most modern machine learning, which tries to use black-box architectures to fit as wide a range of behaviors as possible. Instead, we look at our problem and ask: what do I know has to be true about the system, and how can I constrain the neural network so that the parameter search only looks at cases where it is true? In the context of science, we do so by directly embedding neural networks into existing scientific simulations, essentially saying that the mechanistic model is at least approximately accurate, so our neural network should only learn whatever the scientific model didn't cover or simplified away. Our approach has many more applications than what we show in the paper, and if you know the underlying idea, it is quite straightforward to apply to your own work.

Quick Overview of the Universal Differential Equation Approach

The starting point for universal differential equations is the now classic work on neural ordinary differential equations. Neural ODEs are defined by the equation:

$$u^\prime = \text{NN}_\theta (u),$$

i.e. it's an arbitrary function described as the solution to an ODE defined by a neural network. The authors went down this route because a neural ODE is a continuous form of a recurrent neural network, which makes it natural for handling irregularly spaced time series data.
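To make this concrete, here is a minimal neural ODE sketch in the style of DiffEqFlux.jl. This is illustrative, not code from the paper: the layer sizes, time span, and save points are made up, and the exact API has varied across DiffEqFlux versions.

```julia
using DiffEqFlux, OrdinaryDiffEq, Flux

# NN_θ: a small 2 -> 16 -> 2 network defining the right-hand side u' = NN_θ(u)
dudt = Chain(Dense(2, 16, tanh), Dense(16, 2))

# Irregularly spaced save points are no problem for a continuous-time model
ts = Float32[0.0, 0.3, 0.35, 1.1, 1.5]
node = NeuralODE(dudt, (0.0f0, 1.5f0), Tsit5(), saveat = ts)

u0 = Float32[2.0, 0.0]
pred = node(u0)   # solve the ODE forward; training fits `dudt`'s weights to data
```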

However, this formulation can have another interesting purpose. ODEs, and differential equations in general, are well-studied because they are the language of science. Newtonian physics is described in terms of differential equations. So are Einstein’s equations, and quantum mechanics. Not only physics, but also biological models of chemical reactions in cells, population sizes in ecology, the motion of fluids, etc.: differential equations are really the language of science.

Thus it's not a surprise that in many cases we already have and know some differential equations. They may be an approximate model, but we know this approximation is "true" in some sense. This is the jumping-off point for the universal differential equation. Instead of trying to make the entire differential equation a neural-network-like object, since science is encoded in differential equations, it would be scientifically informative to actually learn the differential equation itself. But in any scientific context we already know parts of the differential equation, so we might as well hard-code that information as a form of prior structural knowledge. This gives us the form of the universal differential equation:

$$u^\prime = f(u,U_\theta,t)$$

where $$U_\theta$$ is an arbitrary universal approximator, i.e. a finite-parameter object that can represent "any possible function". It just so happens that neural networks are universal approximators, but other forms, like Chebyshev polynomials, have this property as well; neural networks simply do well in high dimensions and on irregular grids (properties we utilize in some of our other examples).

What happens when we describe a differential equation in this form is that the trained neural network becomes a tangible numerical approximation of the missing function. By doing this, we can train a program that has the exact same input/output behavior as the missing term of our model. And this is precisely what we do. We assume we only know part of the differential equation:

$$x^\prime = \alpha x - U^1_\theta (x,y)$$

$$y^\prime = -\delta y + U^2_\theta(x,y)$$

and train the embedded neural networks so that the resulting universal ODE fits our data. When trained, the neural network is a numerical approximation to the missing function. But since it's just a simple function, it's fairly straightforward to plot it and say "hey! We were missing a quadratic term", and there you go: it's interpreted back to the original generating equations. In the paper we describe how to make this more rigorous with the SInDy technique, a sparse regression against a basis of possible terms, but it's the same story: in the end we learn exactly the differential equations that generated the data, and hence the extrapolation accuracy even beyond the original time series, and the nice picture shown at the top of this post.
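In code, a sketch of this setup might look as follows. This is not the paper's script: the coefficient values, network size, and initial condition are placeholders, and the network is hand-rolled here just to keep the example self-contained.

```julia
using DifferentialEquations

α, δ = 1.3, 1.8   # the "known" linear coefficients (illustrative values)

# One tiny 2-in/2-out network standing in for (U¹_θ, U²_θ)
U(u, W1, b1, W2, b2) = W2 * tanh.(W1 * u .+ b1) .+ b2

function ude!(du, u, p, t)
    x, y = u
    Uxy = U(u, p...)
    du[1] =  α * x - Uxy[1]   # x' = αx - U¹(x, y)
    du[2] = -δ * y + Uxy[2]   # y' = -δy + U²(x, y)
end

p = (randn(16, 2), randn(16), randn(2, 16), randn(2))   # untrained weights
prob = ODEProblem(ude!, [1.0, 1.0], (0.0, 3.0), p)
sol = solve(prob, Tsit5())
# Training adjusts `p` so `sol` matches the measured time series (e.g. with
# the adjoint-based gradients in DiffEqFlux.jl); SInDy is then applied to
# the trained network's input/output samples to recover symbolic terms.
```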

Understanding Universal Differential Equations As A Method to Simplify Learning

Trying to approximate data might be a much harder problem than trying to understand the processes that create the data. Indeed, the Lotka-Volterra equations are a simple set of equations that are defined by 4 interpretable terms. The first simply states that the number of rabbits would grow exponentially if there weren't a predator eating them. The second states that the number of rabbits goes down when they are eaten by predators (and more predators means more eating). The third is that more prey means more food, and hence growth of the wolf population. Finally, the wolves die off with an exponential decay due to old age.
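For reference, here are those 4 terms written out as code; the parameter values and initial condition are illustrative, not the ones from the paper.

```julia
using DifferentialEquations

function lotka!(du, u, p, t)
    x, y = u                     # x: rabbits (prey), y: wolves (predators)
    α, β, γ, δ = p
    du[1] = α * x - β * x * y    # exponential growth minus predation
    du[2] = γ * x * y - δ * y    # growth from eating prey minus death by old age
end

prob = ODEProblem(lotka!, [1.0, 1.0], (0.0, 10.0), (1.5, 1.0, 1.0, 3.0))
sol = solve(prob, Tsit5())       # the cyclic time series with no analytical solution
```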

Lotka-Volterra with Emojis

That's it: a simple set of quadratic equations describing 4 mechanisms of interaction. Each mechanism is interpretable and can be independently verified. Meanwhile, that cyclic solution that is the data set? The time series itself is such a complicated function that you can prove there is no way to even express its analytical solution! This phenomenon is known as emergence: simple mechanisms can give rise to complex behavior. From this it should be clear that a method which is trying to predict a time series has a much more difficult problem to solve than one which is trying to learn mechanisms!

One way to really solidify this idea is our next example, where we showcase how reconstructing a partial differential equation can be straightforwardly done through universal approximators embedded within partial differential equations, what we call a universal PDE. If you want the full details behind what I will handwave here, take a look at the MIT 18.337 Scientific Machine Learning course notes or the MIT 18.S096 Applications of Scientific Machine Learning course notes, which contain a derivation showcasing how one can interpret partial differential equations as large systems of ODEs. Specifically, the Fisher-KPP equation that we look at in our paper:

$$\frac{\partial u}{\partial t} = D \frac{\partial^2 u}{\partial x^2} + u(1-u)$$

can be semi-discretized in space into a system of ODEs:

$$u_i^\prime = D(u_{i-1} - 2u_i + u_{i+1}) + u_i (1 - u_i)$$

Here, the term $$u_{i-1} - 2u_i + u_{i+1}$$ is known as a stencil: basically, you go to each point, sum the values to the left and the right, and subtract twice the middle. Sound familiar? It turns out that a convolutional layer from a convolutional neural network is actually just a parameterized form of a stencil. For example, a picture of a two-dimensional stencil looks like:

stencil gif

where this stencil is applying the operation:

1 0 1
0 1 0
1 0 1

A convolutional layer is just a parameterized form of a stencil operation:

w1 w2 w3
w4 w5 w6
w7 w8 w9
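As a quick sketch of this equivalence, here is a 3×3 stencil applied at every interior point of a grid in plain Julia; a convolutional layer performs exactly this loop, with the nine entries of `w` as its trainable parameters. The function name is mine, not a library call:

```julia
# Apply a 3x3 stencil `w` to every interior point of the matrix `u`
function apply_stencil(w::AbstractMatrix, u::AbstractMatrix)
    out = zeros(size(u, 1) - 2, size(u, 2) - 2)
    for j in 2:size(u, 2)-1, i in 2:size(u, 1)-1
        out[i-1, j-1] = sum(w .* u[i-1:i+1, j-1:j+1])
    end
    out
end

u = rand(6, 6)
fixed   = apply_stencil([1 0 1; 0 1 0; 1 0 1], u)  # the fixed stencil pictured above
learned = apply_stencil(rand(3, 3), u)             # a "conv layer": learnable weights
```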

Thus one way to approach learning spatiotemporal data which we think may come from such a generating process is:

$$u_i^\prime = D(w_1 u_{i-1} + w_2 u_i + w_3 u_{i+1}) + \text{NN}(u_i)$$

i.e., the discretization of the second spatial derivative, what's known as diffusion, is physically represented as the stencil of weights $$[1, -2, 1]$$. Notice that in this form, the entire spatiotemporal data set is described by a 1-input 1-output neural network plus 3 parameters. Globally, over the array of all $$u_i$$, this becomes:

$$u^\prime = D \text{CNN}(u) + \text{NN}.(u)$$

i.e. it's a universal differential equation with a 3-parameter CNN and (the same) small $$\mathbb{R} \rightarrow \mathbb{R}$$ neural network applied at each spatial point.
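A sketch of this universal PDE's right-hand side, in the same hand-rolled style as before. Here the diffusion constant D is absorbed into the stencil weights, and the boundary handling (periodic), network size, and parameter values are my own illustrative choices, not the paper's:

```julia
using DifferentialEquations

# One shared scalar network applied at every grid point (R -> R)
nn(v, W1, b1, W2, b2) = (W2 * tanh.(W1 * [v] .+ b1) .+ b2)[1]

function fisher_kpp_ude!(du, u, p, t)
    w, W1, b1, W2, b2 = p                  # w = (w1, w2, w3): the learnable stencil
    N = length(u)
    for i in 1:N
        l, r = u[mod1(i - 1, N)], u[mod1(i + 1, N)]   # periodic boundary
        du[i] = w[1] * l + w[2] * u[i] + w[3] * r + nn(u[i], W1, b1, W2, b2)
    end
end

p = ([0.1, -0.2, 0.1], randn(8, 1), randn(8), randn(1, 8), randn(1))
prob = ODEProblem(fisher_kpp_ude!, rand(64), (0.0, 5.0), p)
sol = solve(prob, Tsit5())
# After training, finding w1 ≈ w3 and w2 ≈ -(w1 + w3) recovers D*[1, -2, 1],
# and rescaling `w` changes the diffusion constant without retraining `nn`.
```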

Can this tiny neural network actually fit the data? It does. And not only does it fit the data, it is also interpretable:

Trained Fisher-KPP

You see, not only did it fit and accurately match the data, it also tells us exactly the PDE that generated the data. Notice that figure (C) says that $$w_1 = w_3$$ in the fitted equation, and that $$w_2 = -(w_1 + w_3)$$. This means that the convolutional layer learned to be $$D[1,-2,1]$$, exactly as the theory would predict if the only spatial physical process was diffusion. But secondly, figure (D) shows that the neural network representing the pointwise behavior seems to be quadratic. Indeed, recall from the PDE discretization:

$$u_i^\prime = D(u_{i-1} - 2u_i + u_{i+1}) + u_i (1 - u_i)$$

it really is quadratic. This tells us that the only physical generating process that could have given us this data is a diffusion equation with a quadratic reaction, i.e.

$$\frac{\partial u}{\partial t} = D \frac{\partial^2 u}{\partial x^2} + u(1-u)$$

thus the neural networks are interpreted to precisely recover a PDE governing the evolution of the spatiotemporal data. Not only that, this trained form can predict beyond its data set: if we wished for it to predict the behavior of a fluid with a different diffusion constant $$D$$, we would know exactly how to change that term without retraining the neural networks, since the weights of the convolutional layer are $$D[1,-2,1]$$. We'd simply rescale those weights and suddenly have a model that predicts for a different underlying fluid.

Small neural network, small data, trained in an interpretable form that extrapolates, even on hard problems like spatiotemporal data.

Why It Works: Scientific Knowledge is Encoded in Structure, Not Data Points

From this explanation, it should be very clear that our approach is general, but in every application it's specific. We utilize prior knowledge in the form of differential equations, like known physics or biological interactions, to hard-code as much of the equation as possible. The neural networks are then just stand-ins for the little pieces that are left over. Thus the neural networks have a very easy job: they don't have to learn very much, just the leftovers! The problem becomes easy because we imposed so much knowledge on the neural architecture by utilizing the differential equation form to its fullest.

I like to think of it as follows. There is a certain amount of knowledge $$K$$ that is required to effectively learn the problem. Knowledge can come from prior information $$P$$ or it can come from data $$D$$. Either way, you need enough knowledge $$P + D \geq K$$ to effectively learn the model and make accurate predictions. Machine learning has gone the route of relying almost entirely on data, but that doesn't need to be the case. We know how physics works, and how time series relate to derivatives, so there's no reason to force a neural network to learn these parts. Instead, by writing small neural networks inside of differential equations, we can embed everything that we know about the physics as true structurally-imposed prior knowledge, and then what's left is a simple training problem. That way a big $$P$$ and a small $$D$$ still gives you $$K$$, and that's how you make a "neural network" accurately extrapolate from only a small amount of training data. And now that it's only learning a simple function, what it learned is easily interpretable through sparse regression techniques.
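The sparse-regression step itself is simple enough to sketch. The following is a generic sequentially thresholded least-squares loop (the core idea behind SInDy); DataDrivenDiffEq.jl contains the real, more robust implementations, and the basis here is just an example library of candidate terms:

```julia
using LinearAlgebra

# Candidate library for 2 states: columns [1, x, y, x², xy, y²]
basis(X) = hcat(ones(size(X, 1)), X[:, 1], X[:, 2],
                X[:, 1] .^ 2, X[:, 1] .* X[:, 2], X[:, 2] .^ 2)

# Sequentially thresholded least squares: fit, zero small coefficients, refit
function stlsq(Θ, dX; λ = 0.1, iters = 10)
    Ξ = Θ \ dX
    for _ in 1:iters
        small = abs.(Ξ) .< λ
        Ξ[small] .= 0
        for k in axes(dX, 2)
            keep = .!small[:, k]
            Ξ[keep, k] = Θ[:, keep] \ dX[:, k]
        end
    end
    Ξ
end

# X: sampled states; dX: the trained networks U¹, U² evaluated at those states.
# A sparse Ξ with a single nonzero in the xy row recovers the βxy and γxy terms:
# Ξ = stlsq(basis(X), dX)
```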

Software and Performance Issues

Once we had this idea of wanting to embed structure, all of the hard work began. Essentially, in big-neural-network machine learning you can get away with a lot of performance issues when 99.9% of your time is spent in the neural network's calculations. But once we got into the regime of small-data, small-neural-network structured machine learning, our neural networks were no longer the big time consumer, which meant every little detail mattered. Thus we needed to hyper-optimize the solution of small ODE solves to make this a reality. As a result of our optimizations, we have easily reproducible benchmarks which showcase a 50,000x acceleration over the torchdiffeq neural ODE library. In fact, benchmarks show across-the-board orders-of-magnitude performance advantages over SciPy, MATLAB, and R's deSolve as well. This is not a small detail: training neural networks within scientific simulations is a costly endeavor which takes many millions of ODE solves, and therefore these performance optimizations changed the problem from "impractical" to "reality". Again, when very large neural networks are involved this may be masked by the cost of the neural network passes themselves, but in the context of small-network scientific machine learning, this change was a godsend.

But the even larger difficulty we noticed was that traditional numerical analysis ideas like stability really came into play once real physical models got involved. There is a property of ODEs called stiffness, and when it comes into play, the simple Runge-Kutta or Adams-Bashforth-Moulton methods are no longer stable enough to accurately solve the equations. Thus, when looking at the universal partial differential equations, we had to make use of a set of ODE solvers which have package implementations in Julia and Fortran. Additionally, any form of backwards solving is unconditionally unstable on the diffusion-advection equation, meaning it ended up being a practical case where simple adjoint methods, like the backsolve approach of the original neural ODEs paper and torchdiffeq, actually diverge to infinity in finite time for any tolerance on the ODE solver. Thus we had to implement a bunch of different (checkpointed) adjoint implementations in order to accurately and efficiently train neural networks within these equations.

Then, once we had a partial differential equation form, we had to build tools that would integrate with automatic differentiation to automatically specialize on sparsity. The result was a full set of advanced methods for efficiently handling stiffness that was fully compatible with neural network backpropagation. It was only when all of this came together that the most difficult examples of what we showed actually worked. Now, our software DifferentialEquations.jl with DiffEqFlux.jl is able to handle:

  • Stiffness and ill-conditioned problems
  • Universal ordinary differential equations
  • Universal stochastic differential equations
  • Universal delay differential equations
  • Universal differential algebraic equations
  • Universal partial differential equations
  • Universal (event-driven) hybrid differential equations

all with GPUs, sparse and structured Jacobians, preconditioned Newton-Krylov methods, and the list of features just keeps going. This is the real limitation: when real scientific models get involved, the numerical complexity drastically increases. But now this is something that at least the Julia libraries have solved.

Final Thoughts

The takeaway here is that not only do you need to use all of the scientific knowledge available, but you also need to make use of all of the numerical analysis knowledge. When you combine all of this knowledge with the most recent advances in machine learning, you get small neural networks that train on small data in a way that is interpretable and accurately extrapolates. So no, it's not magic: we just replaced the big data requirement with the requirement of having some prior scientific knowledge, and if you go talk to any scientist you'll know this data source exists. I think it's time we use it in machine learning.

Code and Reproducibility

The code to reproduce our results is in this GitHub repository. However, I would like to see people try other examples. All of the tooling is now in the open-source DifferentialEquations.jl and DiffEqFlux.jl packages. The implementation of SInDy is in the DataDrivenDiffEq.jl package (this one isn't quite released yet, but the tools used for these examples are released, and it does spit out ModelingToolkit expressions which you can call Latexify on; documentation on this will come soon!). The final example in the paper is a library call in NeuralNetDiffEq.jl with the algorithm choice of LambaEM().

For examples on how to use these tools in your own packages, I would consult this part of the DiffEqFlux.jl README.

15 thoughts on “Scientific Machine Learning: Interpretable Neural Networks That Accurately Extrapolate From Small Data”

  1. Henri Laurie says:

    Hi Chris — Lovely paper, sparkling blog. Thanks!

    Query about VolterraExp.jl. I am teaching a class and referring to it extensively; I want my students to modify it in as many ways as they can. I find that simply changing the initial u0 has an impact on whether the final SInDy step ("test on uode derivative data") actually works, i.e. gives the desired model.

    Specifically, u0 = Float32[2.24, 0.1] works (even though the initial NN training produces lots of warnings), while u0 = Float32[2.24, 10.0] fails (even though the initial NN training seems fine).

    Seems that when it fails the SInDy doesn’t converge, or converges to another model. Any ideas why?


    • The SInDy step is definitely the most finicky. It's quite robust, but there are ways to break it. Using a different Lasso optimizer, like SR3, helps a lot. Also, using a different SInDy formulation like iSInDy can help a ton as well. This is why we're building out DataDrivenDiffEq.jl with a ton of methods, since the robustness of this portion is definitely something we can keep working on and exploring.


  2. Carl says:

    1) According to your presentation in the video titled: “Universal Differential Equations for Scientific Machine Learning” you stated that you can “wrap” a neural network around a simulator, to allow it to run in real time. My understanding is that once you know what model and equations you are using, you can just run it as many times as you want to collect a nice distribution of data, and then train a neural network on it in order to represent that data in a compressed form, which can then be run in real time because all that is being computed is the forward pass of that NN, is this correct?

    Regarding this NN as a surrogate, I have been playing with DiffEqFlux.jl using the Lotka-Volterra toy model as a proof of concept. I am having difficulty training a surrogate that predicts new parameters of the model well. I have also looked at Surrogates.jl but am not able to follow the code due to limited documentation.
    Do you have examples of a surrogate using tools in DiffEqFlux.jl?
    Lastly, do you think that it is worth trying to employ neural ODEs towards training a surrogate?


    • (1) Yup that’s precisely it.

      (2) “I am having difficulty training a surrogate that predicts new parameters of the model well” What do you mean by that?

      (3) “Do you have examples of a surrogate using tools in DiffEqFlux.jl?” Generally I keep the two separate: DiffEqFlux is more for neural networks inside of simulators, while surrogates are usually neural networks around simulators. They do different things and have different purposes. Using a straight neural ODE as a surrogate doesn't make too much sense because even after you train it, you still have to solve an ODE to get a prediction, so why not compress it a bit more? You've got to do something more with it to make it useful, like pick problems where a surrogate around a simulator is difficult to train but a neural ODE is much easier to train, or do something to the neural network in the ODE. That said, there are ways I am investigating to use DiffEqFlux to build forms of surrogates, but that's still at a very early phase of the research and I'll share when I know what works.


  3. Theo says:

    This is most fascinating. To clarify: when choosing the structural equation, I could supply even only one term, with a guessed parameter, and if I know in an abstract way that there are 2, or maybe 3, other terms that either add or subtract to the first, I would just represent those as an NN —> train the parameters until we get a good fit —> turn back into real equations using SInDy —> inverse-train the coefficients to better fit further, and voila? In theory, shouldn't this actually be able to acquire waaayy better simulation models than what's used in out-of-the-box software that doesn't use data-driven techniques, since not only do we discover more complex functions that eluded the simplified expressions created by us humans, but we also learn the kinetic coefficients etc. that we usually just approximate? It sounds so magical….


  4. Theo says:

    Hey Christopher!

    Just wanted to say this was a REALLY well written article, and it goes perfectly both with your paper and the presentation you have on YouTube, which I have watched at least 3 times to really let everything you said sink in, and shared with all my colleagues. I am working on a rather big personal project of my own, and this framework seems to be a PERFECT fit for what it is that I am trying to do; however, I do want to clarify some things and seek your advice on the direction I am taking. Right now I am using Aspen HYSYS to create process models of chemical plants to optimize them in real time. However, the quality of the optimal setpoints doesn't reflect the optimal reality of the plant that I have access to in real life, and this is due to the accuracy of the process models, despite the fact that the models I am using are a mechanistic representation of the problem at hand and use “state of the art simulators”. Another problem, as you pointed out in your video, is that these simulators can't run in real time due to their computational complexity. My questions and points of clarification are as follows:

    1) According to your presentation in the video titled: “Universal Differential Equations for Scientific Machine Learning” you stated that you can “wrap” a neural network around a simulator, to allow it to run in real time. My understanding is that once you know what model and equations you are using, you can just run it as many times as you want to collect a nice distribution of data, and then train a neural network on it in order to represent that data in a compressed form, which can then be run in real time because all that is being computed is the forward pass of that NN, is this correct?

    2) Since chemical plants are simply individual unit processes like distillation columns, reactors, etc., strung together in a process flow diagram, the focus is to first model and optimize each unit individually, then work on the whole plant after. Given that I have access to sensor data for, say, a reactor, how can I use both the inverse method to find the coefficients of the ODE and the forward method to find the differential equations that properly model the process simultaneously – is this even possible? How much of the ODE of the original equation do I have to supply in order to learn the rest? Is it possible to write a super general form (something like this: flow in – flow out + production rate) and then find the ODEs that fit this? My guess is that with your brilliant framework, a far more precise model could be obtained than what is usually offered in commercial simulators, because the reaction kinetic coefficients etc. are experimentally determined for other contexts but are used as rough approximations for the user's application, and the differential equations used for the dynamic simulator were on some level hand-crafted by an expert, but also simplified based on whatever limited understanding or complexity the originator could capture. Is my line of thought correct here – that you could learn these models specific to the reality of the situation per process unit per site? I would be deeply grateful for your insight into this in terms of how one would start, given access to data from one of these units. Really looking forward to hearing from you


    • 1) Yup, that’s it.

      2) The more flexibility you give, the less constraint it has. It's always a trade-off. In some cases the neural network can compensate for existing terms. For example, if you add a +NN term, then the neural network could learn the known terms as well, so a possible fit would be to have the other parameters as zero. So if you want interpretability of the parameter values, then you don't want to estimate them at the same time as the neural network. That said, if you estimate the neural network at the same time as the parameters and then apply SInDy afterwards to get back to structured equations like we do in Lotka-Volterra, that is fine, and it will be easier to fit than just a neural network since you're essentially giving it a head start in the training process. What you'd do there is (pick some structural equations and guess some parameters) -> (train all parameters) -> (transform the NN to equations via SInDy) -> (do parameter estimation on the found structure to fix parameter values), and that should work out pretty well from what I've seen.


  5. Tom Dupree says:

    Hi Chris, thanks for this post! I found the paper very interesting and really relevant to some work I'm doing at the moment (identifying the dynamics of SCF CO2 extraction). But I had no idea how I was going to implement these ideas; this post has given me a starting point.


  6. Mahdi says:

    Dear Christopher,

    Thanks for the very interesting post. I am using deep neural networks to learn the solution of partial differential equations in a framework that I call theory-training (my paper can be found here: https://arxiv.org/abs/1912.09800). Theory-trained networks (TTNs) do not need ANY external data for training. They can self-train! I am now trying to develop TTNs that can extrapolate.

    I have a question and I would appreciate it if you could comment: do you (somehow) solve your differential equations in the regions where you extrapolate?

    Best,
    Mahdi


    • Interesting work. I do not see how TTNs (or PINNs) can have the same property though, since once you represent the solution of a PDE by a neural network and utilize theory only in the cost function, the network itself doesn’t have the structure and so it’s only guaranteed to satisfy the structure where you trained it. This is why I started moving in this direction, essentially using classical numerics to enforce structure and having as little as possible enforced by a neural network.


  7. arita37 says:

    Wonderful post.
    Is this called structural learning?


  8. Srikanth Sivaramakrishnan says:

    Thank you for this article. I’ve always wondered how one can incorporate some structural knowledge into machine learning instead of going full-on black box.

    Is it possible to learn the system without a direct measurement of the states of the ODE but using its outputs (y) which are a function of the states (xdot = f(x,u), y = g(x,u)) ?

    Is there some criterion, like observability, that indicates the minimum measurements needed to learn a given ODE system?


    • >Is it possible to learn the system without a direct measurement of the states of the ODE but using its outputs (y) which are a function of the states (xdot = f(x,u), y = g(x,u)) ?

      Do you already know the function `g`? If not, I’m not sure how you’ll decouple that.

      >Is there some criterion, like observability, that indicates the minimum measurements needed to learn a given ODE system?

      We need to do a lot more numerical analysis to fully quantify uncertainties and really know bounds on how well we can learn things. This is just the start of what will likely be a long research journey.

