
Enhancing Probabilistic Circuit Generalization with Sharpness-Aware Learning

TLDR: A new method called “Tractable Sharpness-Aware Learning” improves Probabilistic Circuits (PCs) by guiding them to “flatter” solutions, which generalize better and prevent overfitting, especially with limited data. This is achieved by efficiently computing a measure of “sharpness” (Hessian trace) that is usually difficult for other deep learning models, and incorporating it as a regularizer during training.

In the evolving landscape of artificial intelligence, Probabilistic Circuits (PCs) stand out as a powerful class of generative models. They are prized for their ability to perform exact and tractable inference: they can compute exact probabilities for a wide range of queries, a task that is often intractable for other deep generative models such as GANs or VAEs.

However, as PCs become deeper and more expressive, their increased capacity can lead to a common problem in machine learning: overfitting. This is particularly true when the amount of available data is limited. Overfitting occurs when a model learns the training data too well, including its noise, and consequently performs poorly on new, unseen data. Researchers have observed that this often happens when PCs converge to ‘sharp optima’ in their log-likelihood landscape—regions where the model’s performance is very sensitive to small changes in its parameters, leading to poor generalization.

Inspired by techniques used in neural networks, such as Sharpness-Aware Minimization (SAM), a new research paper introduces a novel approach to tackle this issue for Probabilistic Circuits. The core idea is to regularize the training process using a Hessian-based method. The Hessian matrix, which contains second-order partial derivatives of a loss function, is a natural way to quantify ‘flatness’ or ‘sharpness’ in the model’s performance landscape. A ‘flat minimum’ is generally associated with better generalization.
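To make this concrete, such a sharpness-aware training objective can be written as penalized maximum likelihood: the usual negative log-likelihood plus a weighted curvature term. The formulation below is a plausible sketch in our own notation (the paper's exact objective and weighting scheme may differ):

```latex
% Sketch of a Hessian-trace-regularized objective (our notation;
% the paper's exact formulation may differ).
\min_{\theta} \;
\underbrace{-\log p(\mathcal{D};\theta)}_{\text{negative log-likelihood}}
\;+\;
\lambda \,
\underbrace{\operatorname{tr}\!\left(\nabla_{\theta}^{2}
  \bigl(-\log p(\mathcal{D};\theta)\bigr)\right)}_{\text{sharpness (Hessian trace)}}
```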

A significant contribution of this work is the discovery that the trace of the Hessian of the log-likelihood—a measure of overall curvature—can be computed efficiently for PCs. This is a crucial distinction from deep neural networks, where such exact Hessian computations are typically intractable. By minimizing this Hessian trace during training, the model is encouraged to converge to flatter minima.
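Why can this trace be tractable for PCs when it is not for neural networks? The output of a PC is multilinear in its sum-node parameters, so the diagonal second derivatives of the circuit output vanish, and the Hessian trace of the log-likelihood collapses to a (negative) squared gradient norm. The toy sketch below checks this identity on a hypothetical one-sum-node "circuit" in PyTorch; it illustrates the mechanism and is not the paper's implementation:

```python
import torch

# Toy one-sum-node "circuit": p(x; theta) = theta_1 * f1(x) + theta_2 * f2(x).
# Like real PC outputs, p is multilinear in theta (hypothetical example,
# not the paper's circuits).
f = torch.tensor([0.3, 0.8])           # fixed leaf likelihoods f1(x), f2(x)
theta = torch.tensor([0.6, 0.4], requires_grad=True)

def log_p(th):
    return torch.log(th @ f)           # log-likelihood of the mixture

# Hessian trace via generic autodiff (what standard tools compute).
H = torch.autograd.functional.hessian(log_p, theta)
trace_autodiff = torch.trace(H)

# Closed form: multilinearity makes d^2 p / d theta_i^2 = 0, so the
# trace reduces to the negative squared gradient norm of log p.
(g,) = torch.autograd.grad(log_p(theta), theta)
trace_closed = -(g ** 2).sum()

print(trace_autodiff.item(), trace_closed.item())   # both -2.92
```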

The researchers show that minimizing the Hessian trace effectively induces a gradient-norm-based regularizer. For Expectation-Maximization (EM) based learning, a common method for training PCs, this leads to simple, closed-form parameter updates, making the approach scalable and easy to integrate into existing training pipelines. For gradient-based learning methods, it seamlessly integrates as an additional penalty term.
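For the gradient-based case, such a penalty can be added directly to the loss. The sketch below shows a hypothetical PyTorch training step, assuming `model(x)` returns per-example log-likelihoods and `lam` is an illustrative regularization weight; the paper's closed-form EM updates are not reproduced here:

```python
import torch

def sharpness_aware_step(model, x, optimizer, lam=0.01):
    """One step on NLL plus a gradient-norm penalty (illustrative sketch).

    Assumes `model(x)` returns log p(x) per example; `lam` is a
    hypothetical regularization weight.
    """
    log_lik = model(x).mean()
    params = [p for p in model.parameters() if p.requires_grad]
    # create_graph=True so the penalty term itself can be differentiated.
    grads = torch.autograd.grad(log_lik, params, create_graph=True)
    grad_sq = sum((g ** 2).sum() for g in grads)
    loss = -log_lik + lam * grad_sq    # gradient-norm regularizer
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```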

Experimental results on both synthetic and real-world datasets demonstrate the effectiveness of the new method. Models trained with the Hessian-based regularizer consistently converge to flatter minima, which in turn significantly improves their generalization performance. This is especially evident in limited-data regimes, where the regularizer dramatically reduces overfitting. For instance, on synthetic datasets with only 1% of the data, the method cut overfitting by up to 65%, flattened the loss surface by 89%, and improved average test log-likelihood by up to 49%. Similar improvements were observed on real-world binary datasets.

The researchers also empirically validate the correctness and efficiency of their Hessian trace computation. While standard automatic differentiation tools suffer exponential runtime growth with model depth, the proposed closed-form formula scales linearly, making it a more practical and accurate way to analyze curvature in deep PCs.


This work, detailed in the paper Tractable Sharpness-Aware Learning of Probabilistic Circuits, opens up a promising new direction for understanding and improving the training of Probabilistic Circuits by focusing on the geometry of their log-likelihood surface. It suggests future research could explore asymmetric valleys in PC landscapes, develop theoretical frameworks for over-parameterized PCs, and design new optimization strategies leveraging this tractable second-order information.

