BINARY JUNIPR: An Interpretable Probabilistic Model for Discrimination

JUNIPR is an approach to unsupervised learning in particle physics that scaffolds a probabilistic model for jets around their representation as binary trees. Separate JUNIPR models can be learned for different event or jet types, then compared and explored for physical insight. The relative probabilities can also be used for discrimination. In this Letter, we show how the training of the separate models can be refined in the context of classification to optimize discrimination power. We refer to this refined approach as BINARY JUNIPR. BINARY JUNIPR achieves state-of-the-art performance for quark-gluon discrimination and top tagging. The trained models can then be analyzed to provide physical insight into how the classification is achieved. As examples, we explore differences between quark and gluon jets and between gluon jets generated with two different simulations.

Modern machine learning has already made impressive contributions to particle physics. Convolutional [1][2][3][4][5][6], recurrent and recursive networks [7][8][9][10][11], autoencoders [12][13][14][15], adversarial networks [16][17][18], and more have been shown to be effective in applications including quark-gluon jet discrimination, top tagging, and pileup removal. A key question that is beginning to be addressed is this: what is the optimal representation of the information in an event? Is it through analogy with images [1,2], natural-language processing [8,11], or set theory [19,20]? In many of these approaches, there is a trade-off between effectiveness in some task (e.g., pileup removal, jet classification) and interpretability of the neural network. An approach to machine learning for particle physics called JUNIPR [21] builds a separate network for each jet type using a physical representation of the information in the jet: the jet clustering tree. In [21], a method for the construction and training of such a network was introduced. In this Letter, we show how the JUNIPR framework can be used in discrimination tasks, achieving state-of-the-art classification power while maintaining physical interpretability.

JUNIPR begins by taking each jet in some sample and clustering it into a binary tree according to some deterministic algorithm. See Fig. 2 below for an example of such a tree. The algorithm can be physically motivated (like the $k_T$ [22] or Cambridge-Aachen [23] algorithms) but does not have to be. In such a tree, the momentum of each mother branch is the sum of the momenta of her daughters. We denote the momenta of the particles in the jet by $\{p_1, \ldots, p_n\}$ and the momenta in the clustering tree at branching step $t$ by $\{k_1^{(t)}, \ldots, k_t^{(t)}\}$. To be concrete, at $t = 1$ we have the single momentum $k_1^{(1)} = p_1 + \cdots + p_n$, and at $t = n$ we have $\{k_1^{(n)}, \ldots, k_n^{(n)}\} = \{p_1, \ldots, p_n\}$. Each transition from $\{k_1^{(t)}, \ldots, k_t^{(t)}\}$ to $\{k_1^{(t+1)}, \ldots, k_{t+1}^{(t+1)}\}$ involves a single $1 \to 2$ momentum splitting.
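To illustrate the tree bookkeeping described above, the following Python sketch builds the sequence of states $\{k_1^{(t)}, \ldots, k_t^{(t)}\}$ from a list of four-momenta $(E, p_x, p_y, p_z)$. It assumes a toy, Cambridge-Aachen-like clustering that repeatedly merges the pair of momenta closest in pseudorapidity and azimuth. The function name `cluster_history` and this simplified distance are hypothetical and are not the clustering code used for JUNIPR. The sketch only demonstrates three properties: $t = 1$ holds the total jet momentum, $t = n$ holds the particle momenta, and each step adds exactly one momentum through a $1 \to 2$ split.

```python
import numpy as np

def cluster_history(momenta):
    """Return clustering-tree states, states[t-1] = [k_1^(t), ..., k_t^(t)].

    Hypothetical toy clustering: repeatedly merge the pair with the smallest
    (eta, phi) distance, so every mother is the sum of her two daughters.
    """
    def eta_phi(k):
        E, px, py, pz = k
        pt = np.hypot(px, py)
        return np.arcsinh(pz / pt), np.arctan2(py, px)  # pseudorapidity, azimuth

    state = [np.asarray(p, dtype=float) for p in momenta]
    states = [list(state)]  # finest state: the n particle momenta (t = n)
    while len(state) > 1:
        best = None
        for i in range(len(state)):
            for j in range(i + 1, len(state)):
                eta_i, phi_i = eta_phi(state[i])
                eta_j, phi_j = eta_phi(state[j])
                dphi = (phi_i - phi_j + np.pi) % (2 * np.pi) - np.pi
                d2 = (eta_i - eta_j) ** 2 + dphi ** 2
                if best is None or d2 < best[0]:
                    best = (d2, i, j)
        _, i, j = best
        mother = state[i] + state[j]  # momentum conservation at each node
        state = [k for idx, k in enumerate(state) if idx not in (i, j)] + [mother]
        states.append(list(state))
    # Reverse so that states[0] is t = 1 (whole jet) and states[-1] is t = n.
    return states[::-1]
```

Reading the merges in reverse turns a clustering algorithm into a generative sequence of $1 \to 2$ splittings, which is the ordering JUNIPR's step index $t$ follows.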
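JUNIPR learns to compute the probability $P_J(\text{jet})$ of the jet, meaning the probability that the corresponding set of final-state momenta $\{p_1, \ldots, p_n\}$ would be found in the given sample. This probability can be factorized as a product over branching steps in the clustering tree:
$$P_J(\text{jet}) = \left[\prod_{t=1}^{n-1} P\big(k^{(t+1)} \,\big|\, k^{(1)}, \ldots, k^{(t)}\big)\right] P\big(\text{end} \,\big|\, k^{(1)}, \ldots, k^{(n)}\big), \tag{1}$$
where $k^{(t)}$ is shorthand for $\{k_1^{(t)}, \ldots, k_t^{(t)}\}$. To learn these probability distributions, JUNIPR introduces a quantity $h^{(t)}$ as a representation of $\{k_1^{(t)}, \ldots, k_t^{(t)}\}$, i.e., the "state" of the jet at branching step $t$. JUNIPR learns to compute $h^{(t)}$ in training. In machine-learning language, $h^{(t)}$ is the autoregressive latent variable, which in our implementations is taken to be the latent state of a recurrent neural network. Then we can write, e.g.,
$$P_J(\text{jet}) = \prod_{t=1}^{n} P^{(t)}, \qquad P^{(n)} = P_{\text{end}}\big(\text{true} \,\big|\, h^{(n)}\big), \tag{2}$$
$$P^{(t)} = P_{\text{end}}\big(\text{false} \,\big|\, h^{(t)}\big)\, P_{\text{mother}}\big(m^{(t)} \,\big|\, h^{(t)}\big)\, P_{\text{branch}}\big(k_{d_1}^{(t+1)}, k_{d_2}^{(t+1)} \,\big|\, k_{m}^{(t)}, h^{(t)}\big) \quad (t < n). \tag{3}$$
Here, $P_{\text{end}}(\text{false}|h^{(t)})$ is the binary probability that the clustering tree does not end at branching step $t$, $P_{\text{mother}}(m^{(t)}|h^{(t)})$ is the discrete probability that tree momentum $k_m^{(t)}$ will participate in the $1 \to 2$ branching at step $t$, and $P_{\text{branch}}$ is the probability density for the momenta $k_{d_1}^{(t+1)}$ and $k_{d_2}^{(t+1)}$ of the two daughters produced in that branching. Even if there were no evidence for this factorization in the training data (as was explored with "printer jets" in [21]), JUNIPR would still learn the probability distributions, but physical interpretability would be lost.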
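The factorization into $P_{\text{end}}$, $P_{\text{mother}}$, and $P_{\text{branch}}$ turns the jet log likelihood into a sum of per-step terms. Below is a minimal sketch that assumes the three network outputs for each step are already available as numbers. The `Step` container and `log_prob_jet` are hypothetical names used for illustration, not the released BINARY JUNIPR code.

```python
import math
from dataclasses import dataclass

@dataclass
class Step:
    """Model outputs at one branching step t < n (hypothetical container)."""
    p_not_end: float  # P_end(false | h^(t))
    p_mother: float   # P_mother(m^(t) | h^(t)) for the mother actually chosen
    p_branch: float   # P_branch density for the observed daughter momenta

def log_prob_jet(steps, p_end_final):
    """log P_J(jet): sum of per-step log factors plus the final end factor.

    steps: the n - 1 branching steps of the tree.
    p_end_final: P_end(true | h^(n)), the probability that the tree stops at t = n.
    """
    log_p = 0.0
    for s in steps:
        log_p += math.log(s.p_not_end) + math.log(s.p_mother) + math.log(s.p_branch)
    return log_p + math.log(p_end_final)
```

Working in log space keeps long trees with many small factors numerically stable. The per-step terms are also the quantities that later get compared between models for interpretation.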
In [21], JUNIPR was trained to model jet dynamics via unsupervised learning. In that approach, the probabilistic model is learned by maximizing the log likelihood of $P_J$ over the training data:
$$\mathcal{L}_{\text{unary}} = \sum_{\text{jets}} \log P_J(\{p_1, \ldots, p_n\}), \tag{4}$$
where the sum is over jets $\{p_1, \ldots, p_n\}$ in the training set. We call this the "unary objective function." Despite being unsupervised, this approach can be used to discriminate between two jet types, say $a$ and $b$. To accomplish this, one trains two separate JUNIPR models: $P_J(\text{jet}|a)$ on a data set containing predominantly type-$a$ jets and $P_J(\text{jet}|b)$ on predominantly type-$b$ jets. Discrimination between $a$ and $b$ is then achieved by thresholding the likelihood ratio $P_J(\text{jet}|a)/P_J(\text{jet}|b)$. While discrimination by likelihood ratio is theoretically optimal in the perfect-model limit, it has been shown that deep neural networks classify out-of-distribution data poorly [24,25]. That is, the $P_J(\text{jet}|a)$ model, for example, is not expected to behave well on type-$b$ jets. It is thus advantageous in practice to refine the training for discrimination. By training directly for discrimination, JUNIPR can also focus model capacity on learning the often-subtle differences between type-$a$ and type-$b$ jets.

In fact, JUNIPR's probabilistic nature makes supervised discrimination learning very straightforward. Assuming a mixed sample of both jet types, the probability that a given jet drawn at random belongs to class $a$ is, through Bayes's theorem, given by
$$P(a|\text{jet}) = \frac{P(\text{jet}|a)\, P(a)}{P(\text{jet})}. \tag{5}$$
For binary discrimination, $P(a|\text{jet}) + P(b|\text{jet}) = 1$, so $P(\text{jet}) = P(\text{jet}|a)\,P(a) + P(\text{jet}|b)\,P(b)$ and
$$P(a|\text{jet}) = \frac{P(\text{jet}|a)\, P(a)}{P(\text{jet}|a)\, P(a) + P(\text{jet}|b)\, P(b)}. \tag{6}$$
Here $P(a)$ and $P(b)$ are simply the composition fractions $f_a$ and $f_b$ of the mixed sample, while $P(\text{jet}|a)$ and $P(\text{jet}|b)$ can be computed using two separate JUNIPR networks as laid out in the paragraphs above. This leads directly to the binary cross-entropy objective function one should use to train JUNIPR for discrimination:
$$\mathcal{L}_{\text{binary}} = \sum_{a\text{-jets}} \log P(a|\text{jet}) + \sum_{b\text{-jets}} \log P(b|\text{jet}), \tag{7}$$
where the sums extend over type-$a$ and type-$b$ jets in the training data, respectively. We call training with this objective function "BINARY JUNIPR."
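Equation (6) is a logistic sigmoid of the log-likelihood ratio shifted by the log prior ratio. This makes the binary objective straightforward to evaluate stably from the two models' log likelihoods. The NumPy sketch below only evaluates the objective. It assumes the two JUNIPR log likelihoods per jet are given as columns of an array, and the function names and array layout are hypothetical. In actual training, an autodiff framework would propagate gradients of this quantity through both networks.

```python
import numpy as np

def log_odds(log_p_jet_a, log_p_jet_b, f_a, f_b):
    """s = log[f_a P(jet|a)] - log[f_b P(jet|b)], so that P(a|jet) = sigmoid(s)."""
    return np.log(f_a) - np.log(f_b) + np.asarray(log_p_jet_a) - np.asarray(log_p_jet_b)

def posterior_a(log_p_jet_a, log_p_jet_b, f_a, f_b):
    """P(a|jet) from Bayes's theorem, computed as 1 / (1 + exp(-s)) via logaddexp."""
    s = log_odds(log_p_jet_a, log_p_jet_b, f_a, f_b)
    return np.exp(-np.logaddexp(0.0, -s))

def binary_objective(a_jets, b_jets, f_a, f_b):
    """Sum of log P(a|jet) over type-a jets plus log P(b|jet) over type-b jets.

    a_jets, b_jets: arrays of shape (N, 2) whose columns are
    (log P_J(jet|a), log P_J(jet|b)) from the two JUNIPR models (hypothetical layout).
    """
    s_a = log_odds(a_jets[:, 0], a_jets[:, 1], f_a, f_b)
    s_b = log_odds(b_jets[:, 0], b_jets[:, 1], f_a, f_b)
    log_p_a = -np.logaddexp(0.0, -s_a)  # log sigmoid(s) = log P(a|jet)
    log_p_b = -np.logaddexp(0.0, s_b)   # log sigmoid(-s) = log P(b|jet)
    return log_p_a.sum() + log_p_b.sum()
```

With identical likelihoods under both models and $f_a = f_b$, every jet gets $P(a|\text{jet}) = 1/2$. The objective then equals the number of jets times $\log(1/2)$, and it increases as the two models learn to disagree in the right direction.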
PHYSICAL REVIEW LETTERS 123, 182001 (2019)

Note that BINARY JUNIPR still learns the probabilities for type-$a$ and type-$b$ jets and still trains the same neural-network functions; however, it uses a more effective objective function for discrimination applications. We also note that training can easily be generalized to multiclass classification. As a test of the advantage that the binary objective function provides over its unary counterpart, we applied BINARY JUNIPR to the discrimination of quark and gluon jets. We used a mixed sample of $10^6$ PYTHIA quark jets and $10^6$ PYTHIA gluon jets from the data set at [26]; see also [19,27]. We set aside $10^5$ jets of each type for a test set and $10^5$ for validation, and used the remaining 80% of the jets for training. For the JUNIPR models $P_J(\text{jet}|\text{quark})$ and $P_J(\text{jet}|\text{gluon})$, we used an LSTM of dimension 30 to model $h^{(t)}$ and separate feed-forward networks, each with a single hidden layer of dimension 10, to model $P_{\text{end}}$, $P_{\text{mother}}$, and $P_{\text{branch}}$. (The BINARY JUNIPR architecture is available at [28] with example code.) We began by pretraining the two JUNIPR models using the original unary objective function of Eq. (4). We followed the same training schedule as in [21], but scaled down the number of epochs by a factor of 5 because this data set is larger than the one used there. Pretraining took about five hours on a 16-core CPU server for each model. After pretraining, we optimized the binary objective function of Eq. (7) using Adam with standard settings [29] together with a batch-size schedule. This segment of training took 12 hours on a 16-core CPU server. The final BINARY JUNIPR parameters were selected by evaluating the AUC (area under the ROC curve) on the validation set 10 times per epoch and choosing the model that achieved the maximal AUC during the final 10 training epochs. Note that different hyperparameters might be appropriate for different applications.
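The data split and the model-selection rule described above amount to simple arithmetic and an arg-max over a window of validation scores. The sketch below spells them out. It assumes validation AUCs are recorded in chronological order, 10 per epoch. `split_sizes` and `select_checkpoint` are hypothetical helpers, not part of the published code.

```python
def split_sizes(n_per_class=10**6, n_test=10**5, n_val=10**5):
    """Per-class (train, validation, test) sizes for the quark-gluon sample."""
    n_train = n_per_class - n_test - n_val
    return n_train, n_val, n_test

def select_checkpoint(val_aucs, evals_per_epoch=10, last_epochs=10):
    """Index of the best validation AUC among evaluations in the final epochs.

    val_aucs: chronological list of validation AUCs, one per evaluation.
    Only the last evals_per_epoch * last_epochs evaluations are eligible.
    """
    window = evals_per_epoch * last_epochs
    start = max(0, len(val_aucs) - window)
    tail = val_aucs[start:]
    return start + max(range(len(tail)), key=tail.__getitem__)
```

Restricting selection to the final epochs excludes early checkpoints that might score well on validation by chance before training has stabilized.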
In Fig. 1 we show the quark-versus-gluon Significance Improvement Curve (SIC) [30], $\varepsilon_S/\sqrt{\varepsilon_B}$, achieved by BINARY JUNIPR and compare it to recent results with previous state-of-the-art discriminants: a CNN approach based on jet images [3] (with architecture from [19]) and particle flow networks [19]. One can see that BINARY JUNIPR offers a small but significant advantage. Quantitatively, BINARY JUNIPR achieves an AUC of $0.8986 \pm 0.0004$, compared to $0.8911 \pm 0.0008$ for particle flow networks and $0.8799 \pm 0.0008$ for the CNN. (Each reported number is the mean and semi-interquartile range over 10 trainings.) Unary JUNIPR, trained with Eq. (4), performs significantly worse than the other methods, achieving an AUC of $0.6968 \pm 0.0008$. This demonstrates the importance of training JUNIPR with the binary objective function of Eq. (7) for classification.
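The figures of merit quoted above can be computed directly from classifier scores. The following sketch is one straightforward way to obtain ROC efficiencies, the SIC $\varepsilon_S/\sqrt{\varepsilon_B}$, the AUC, and the semi-interquartile range used to summarize repeated trainings. The function names are hypothetical. The AUC is computed via its equivalence to the probability that a random signal jet outscores a random background jet, which is quadratic in sample size and meant only for small illustrations.

```python
import numpy as np

def roc_points(scores_sig, scores_bkg):
    """Signal and background efficiencies at every distinct score threshold."""
    thresholds = np.sort(np.concatenate([scores_sig, scores_bkg]))[::-1]
    eps_s = np.array([(scores_sig >= c).mean() for c in thresholds])
    eps_b = np.array([(scores_bkg >= c).mean() for c in thresholds])
    return eps_s, eps_b

def sic(eps_s, eps_b):
    """Significance improvement eps_S / sqrt(eps_B); NaN where eps_B = 0."""
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.where(eps_b > 0, eps_s / np.sqrt(eps_b), np.nan)

def auc(scores_sig, scores_bkg):
    """P(signal score > background score), counting ties as one half."""
    diff = scores_sig[:, None] - scores_bkg[None, :]
    return (diff > 0).mean() + 0.5 * (diff == 0).mean()

def semi_iqr(values):
    """Half the distance between the 25th and 75th percentiles."""
    q1, q3 = np.percentile(values, [25, 75])
    return 0.5 * (q3 - q1)
```

For example, ten AUC values from independent trainings would be summarized as `np.mean(aucs)` together with `semi_iqr(aucs)`.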
As a second experiment, we trained and tested BINARY JUNIPR for boosted top-jet identification. We used the same architecture and training schedule that were optimized for quark-versus-gluon discrimination. In doing so, we obtain a sense of the performance one might expect from BINARY JUNIPR without specialized hyperparameter tuning. The training, validation, and test data for this experiment are taken from [7]. We found that untuned BINARY JUNIPR comes close to state-of-the-art top discrimination. Specifically, BINARY JUNIPR achieves an AUC of $0.9810 \pm 0.0002$, compared to $0.9819 \pm 0.0001$ attained using particle flow networks [19] and 0.9848 reported for ParticleNet [20]; all significantly outperform traditional boosted top-tagging methods [31]. For a recent overview of machine learning in top tagging, see [32].
Next we discuss the interpretability of JUNIPR models. As discussed below Eq. (3), each component of JUNIPR's output has a well-defined physical meaning. Moreover, the output is structured along a physically motivated binary tree, defined by clustering the momenta in a jet. One can thus decompose JUNIPR's prediction, say $P_J(\text{top}|\text{jet})$ as in Eq. (6), visually along the clustering tree. In Fig. 2, we show the clustering tree for an easily classifiable top jet drawn from the mixed top-QCD test set. In the figure, we label the $t$th node with $P^{(t)}(\text{top}|\text{jet})$, i.e., the probability that the jet is top type, given only the information present at branching step $t$; this is computed with BINARY JUNIPR by substituting Eq. (3) into Eq. (6). One can see, for example, that the three-prong structure characteristic of $t \to W^+ b \to u\bar{d}b$ contributes to a large $P_J(\text{top}|\text{jet})$. Quantitatively, the two hard branchings, with $P^{(t)}(\text{top}|\text{jet}) = 0.72$ and $0.71$, dominate the prediction. By analyzing such trees, one can develop intuition for which branchings are most decisive in classifying different types of jets.
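The per-node labels follow from applying Bayes's theorem to a single step's factors instead of the full product. This minimal sketch assumes that the per-step factors $P^{(t)}$ from the top-trained and QCD-trained models are already available for one jet, and that the sample fractions default to 50/50. The function names are hypothetical.

```python
import numpy as np

def per_step_posterior(p_top_steps, p_qcd_steps, f_top=0.5, f_qcd=0.5):
    """P^(t)(top|jet) for each step t, using only that step's factors.

    p_top_steps[t], p_qcd_steps[t]: per-step factors P^(t) evaluated on the
    same jet by the top-trained and QCD-trained models.
    """
    a = f_top * np.asarray(p_top_steps, dtype=float)
    b = f_qcd * np.asarray(p_qcd_steps, dtype=float)
    return a / (a + b)

def overall_posterior(p_top_steps, p_qcd_steps, f_top=0.5, f_qcd=0.5):
    """P(top|jet) with P_J(jet|class) taken as the product of all per-step factors."""
    s = (np.log(f_top) - np.log(f_qcd)
         + np.sum(np.log(p_top_steps)) - np.sum(np.log(p_qcd_steps)))
    return 1.0 / (1.0 + np.exp(-s))
```

With equal fractions, the overall odds $P/(1-P)$ equal the product of the per-step odds. A node labeled 0.5 is therefore uninformative, while a few nodes with odds far from 1 can dominate the prediction. For instance, per-step posteriors of 0.75 and 0.5 combine to an overall 0.75.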
To be concrete, let us return to the BINARY JUNIPR model used to create Fig. 1, which learned to discriminate quark and gluon jets from PYTHIA. Much is already known about the difference between quark and gluon jets: gluon jets are known to be bigger, with larger multiplicity and larger shape parameters such as mass and width [33,34]. Although many methods exist for quark-gluon discrimination, including other machine-learning approaches [3,11], it is not clear how well these methods will work on actual data. In particular, it is known that real gluon jets are more similar to real quark jets than PYTHIA leads us to believe [35], and it is the modeling of gluon jets that seems most inaccurate. An alternative generator, HERWIG, produces gluon jets that are more similar to its quark jets [36]. Thus, we also considered a secondary challenge: determining how PYTHIA and HERWIG gluon jets differ. To explore their differences, we trained a second BINARY JUNIPR model to discriminate PYTHIA 8.226 and HERWIG 7.1.4 gluon jets using $10^6$ samples of each from [27,37]. Figure 3 shows another visualization of how JUNIPR discriminates, complementary to Fig. 2. The top row of Fig. 3 shows how JUNIPR separates PYTHIA quarks from PYTHIA gluons, and the bottom row shows how JUNIPR separates PYTHIA gluons from HERWIG gluons. In the middle column, the overall probability that JUNIPR uses for discrimination is decomposed into branching steps $t$, averaged over all jets of the given class. From this, we see that near $t = 20$ to $50$ there is roughly three times as much quark-gluon discrimination power per branching step as for $t = 1$ to $10$. This echoes the well-known fact that multiplicity allows one to separate quark and gluon jets better than perturbatively calculable observables sensitive to only the first few splittings [33]. The lower-middle plot shows that differences between PYTHIA and HERWIG gluon jets are more uniformly spread over branching steps.

FIG. 1. Significance improvement curve ($\varepsilon_Q/\sqrt{\varepsilon_G}$) as a function of $\varepsilon_Q$ for quark-gluon discrimination. BINARY JUNIPR is compared to a particle flow network [19], a CNN using jet images [19], constituent multiplicity, and unary JUNIPR.
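The middle-column quantity of Fig. 3 is a ratio of class averages of $P^{(t)}(\text{signal}|\text{jet})$ at each step. Because trees have different lengths, some convention is needed for jets that end before step $t$. The sketch below assumes that a jet contributes to step $t$ only if its tree reaches that step. This convention and the function name `per_step_ratio` are assumptions, not necessarily the procedure used for Fig. 3.

```python
import numpy as np

def per_step_ratio(sig_jets, bkg_jets, max_steps):
    """<P^(t)(signal|jet)> over signal jets divided by the same average over background jets.

    sig_jets, bkg_jets: lists of 1D arrays, one per jet, holding
    P^(t)(signal|jet) for t = 1..n_jet. Steps no jet reaches give NaN.
    """
    def mean_by_step(jets):
        sums = np.zeros(max_steps)
        counts = np.zeros(max_steps)
        for p in jets:
            n = min(len(p), max_steps)
            sums[:n] += p[:n]   # only jets reaching step t contribute to step t
            counts[:n] += 1
        with np.errstate(invalid="ignore", divide="ignore"):
            return sums / counts
    return mean_by_step(sig_jets) / mean_by_step(bkg_jets)
```

A ratio near 1 means the step carries little discriminating information on average. Larger ratios mark the steps where signal jets look systematically more signal-like than background jets do.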
Not only can JUNIPR break discrimination power down into branching steps, it can further decompose the classification probability into components at each branching.
These components are displayed in the right column of Fig. 3. There are discrete components, such as whether branchings should end, as well as the energy fraction $z$ and angles $\theta$, $\phi$, $\delta$ of the branching itself. While multiplicity ($P_{\text{end}}$) is the main driver of performance for quark-gluon discrimination, the angle $\theta$ also contributes significantly over a wide range of branchings, echoing the importance of jet width in this context. For the PYTHIA-HERWIG task, both the angle $\theta$ and the energy fraction $z$ play a significant role in discrimination on early branchings, and multiplicity becomes important on later branchings.
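Because each step's factor is a product of components, the step's log-likelihood ratio between the signal-trained and background-trained models is a sum of per-component log ratios. This additivity is what makes a component-by-component breakdown meaningful. The sketch below assumes that $P_{\text{branch}}$ is itself a product of conditional factors for $z$, $\theta$, $\phi$, and $\delta$. The text lists these variables but does not state how $P_{\text{branch}}$ is factorized, so treat this as an assumption. The dictionary keys and function names are hypothetical.

```python
import math

# Hypothetical component names: the discrete P_end and P_mother factors, plus
# an assumed product of branch factors for z, theta, phi, and delta.
COMPONENTS = ("end", "mother", "z", "theta", "phi", "delta")

def component_log_ratios(sig_factors, bkg_factors):
    """log(P_sig / P_bkg) for each component at one branching step.

    sig_factors, bkg_factors: dicts mapping component name -> factor assigned
    by the signal-trained and background-trained models to the observed branching.
    """
    return {c: math.log(sig_factors[c]) - math.log(bkg_factors[c]) for c in COMPONENTS}

def step_log_ratio(sig_factors, bkg_factors):
    """Total log-likelihood ratio at the step: the sum of the component ratios."""
    return sum(component_log_ratios(sig_factors, bkg_factors).values())
```

A component whose log ratio is close to zero (here, for example, `mother`) contributes nothing to discrimination at that step.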
It is interesting that a significant fraction of the difference between PYTHIA and HERWIG results from the way energy and angles are distributed early on in the clustering trees. Early branchings are controlled primarily by perturbative elements of the simulated parton showers. This suggests that a substantial portion of the difference between PYTHIA and HERWIG gluon jets may be driven by the parton-shower implementations, rather than exclusively by the modeling of nonperturbative effects. To gain further insight into the importance of nonperturbative effects like hadronization in discrimination, JUNIPR could be upgraded to include quantum numbers of final-state particles, a straightforward next step.
In [21], JUNIPR was introduced as a new framework for unsupervised machine learning in particle physics that prioritizes interpretability. Given a jet, i.e., a set of momenta, JUNIPR learns to compute the probability of that jet, i.e., how consistent the distribution of momenta is with the training data. In this Letter, we used the same probabilistic framework as in [21], but we augmented the training to learn subtle differences between two samples, an enhancement we call BINARY JUNIPR. We demonstrated both its effectiveness and interpretability, using quark-gluon jets, boosted top jets, and Monte Carlo generator dependence as examples. It is satisfying that demanding interpretability does not lead to a loss in effectiveness: BINARY JUNIPR discriminates at levels competitive with the best machine-learning methods available.

FIG. 2. BINARY JUNIPR tree for a jet drawn from the mixed top-QCD test set. BINARY JUNIPR predicts "top" with high probability: $P_J(\text{top}|\text{jet}) = 0.99$. Each node is labeled with the probability that the jet is top type, given only the information at that branching. Planar angles correspond to 3D opening angles between clustered momenta, and color corresponds to energy. The final factor corresponding to the tree's true end is not shown: $P^{(n)} = 0.52$; see Eq. (2).

FIG. 3. Quark-gluon (top row) and PYTHIA-HERWIG (bottom row) discrimination with BINARY JUNIPR. Here we refer to quark jets and PYTHIA jets as "signal," and to gluon jets and HERWIG jets as "background." The left column shows the binary probability with which JUNIPR predicts each jet is a signal jet. The middle column breaks these probabilities down by branching step in the clustering tree. Specifically, the plots show the ratio of $P^{(t)}(\text{signal}|\text{jet})$ averaged over signal jets (numerator) to the same quantity averaged over background jets (denominator). The right column breaks these ratios down further by branching component.
While these case studies were all simulation based, there is a straightforward path to repeating these exercises on collider data. Although real data do not come with truth labels, there are established methods for working with mixed samples [38,39] that can be adapted to JUNIPR without much modification. One could then use a data-versus-simulation BINARY JUNIPR model to understand deficiencies in simulations. One could also use insights derived from BINARY JUNIPR trees to judge whether predictions should be trusted experimentally (was information below experimental resolution deemed important?) or to design new calculable observables (sensitive to previously overlooked decisive branchings). Having interpretable methods opens the door to whole new approaches to understanding data from particle colliders.
We thank P. Komiske and A. Parthak for assistance with the samples. This work is supported in part by the U.S. Department of Energy under Contract No. DE-SC0013607. Computations were performed on the CORI supercomputing resources at NERSC, Harvard's Odyssey cluster, and the Faculty Platform for machine learning.