
This is the official implementation for the Generative Modeling Density Alignment (GMDA). This work was presented in the paper "Frugal Generative Modeling for Tabular Data" at ECML 2024.


ablacan/gmda


GMDA: Generative Modeling Density Alignment

Python 3.9 | License: MIT

GMDA is a Python package for generative modeling with density alignment. This README provides instructions for installation, usage, and key features of the package.

Table of Contents

  1. Installation
  2. Data Processing
  3. Model Training
  4. Generating Synthetic Data
  5. Metrics
  6. Command-Line Usage

Installation

Disclaimer: Installing this package also installs a specific version of PyTorch, which may not be compatible with your GPU driver. Before installing, check that the bundled PyTorch version supports your driver. If it does not, create your Python environment with the PyTorch version best suited to your system; the official PyTorch website gives the appropriate installation command for each setup.

Using a Virtual Environment

python -m venv env_gmda
source env_gmda/bin/activate
pip install .

Using conda

conda create -n env_gmda python=3.9
conda activate env_gmda
pip install ".[conda]"

For development (editable) mode, use:

pip install -e ".[conda]"
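Because the bundled PyTorch build may not match your GPU driver, it can help to check what was installed once `pip install` finishes. The sketch below is not part of GMDA: `torch_gpu_report` is a hypothetical helper that only reads standard PyTorch attributes (`torch.__version__`, `torch.version.cuda`, `torch.cuda.is_available()`, `torch.cuda.get_device_name()`), and it still returns a report if PyTorch is missing.

```python
import importlib.util


def torch_gpu_report() -> dict:
    """Summarize the installed PyTorch build and GPU visibility (hypothetical helper, not part of GMDA)."""
    if importlib.util.find_spec("torch") is None:
        return {"torch_installed": False}

    import torch

    report = {
        "torch_installed": True,
        "torch_version": torch.__version__,
        # None for CPU-only builds, otherwise the CUDA version PyTorch was compiled against
        "built_with_cuda": torch.version.cuda,
        # False when no GPU is visible or the driver cannot run this build
        "cuda_available": torch.cuda.is_available(),
    }
    if report["cuda_available"]:
        report["device_name"] = torch.cuda.get_device_name(0)
    return report


if __name__ == "__main__":
    for key, value in torch_gpu_report().items():
        print(f"{key}: {value}")
```

If `cuda_available` is False on a machine with a working NVIDIA GPU, the installed PyTorch build is a likely culprit; reinstalling PyTorch with the command the official website gives for your setup is the usual fix.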

Data Processing

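GMDA handles data loading and preprocessing through the DataProcessor class, which wraps two user-supplied functions: a loader that returns the raw data and a processor that splits and transforms it.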

from gmda.data_utils import DataProcessor

# Define custom data loading and processing functions
def custom_data_loader(train: bool = True, **kwargs):
    # Your custom data loading logic here. Should return a tuple of tabular data (X, y).
    raise NotImplementedError

def custom_data_processor(data, train: bool = True, **kwargs):
    # Your custom data processing logic here. Should return a tuple of tuples of processed
    # train and test data: ((X_train, y_train), (X_test, y_test)).
    raise NotImplementedError

# Instantiate the DataProcessor
data_processor = DataProcessor(custom_data_loader, custom_data_processor)

# Create dataloaders
train_loader, val_loader, X, y = data_processor.create_dataloaders(batch_size=64, density=0.1)
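One way to fill in the two placeholders is sketched below, assuming NumPy arrays are acceptable inputs. Random data stands in for a real dataset, and the `seed`, `n_samples`, and `test_size` keyword arguments are hypothetical; how DataProcessor actually passes `train` and `**kwargs` is not documented in this README, so the sketch simply accepts and ignores `train`.

```python
import numpy as np


def custom_data_loader(train: bool = True, **kwargs):
    """Hypothetical loader: returns (X, y) as NumPy arrays (random stand-in data)."""
    rng = np.random.default_rng(kwargs.get("seed", 0))
    n_samples = kwargs.get("n_samples", 1000)
    X = rng.normal(size=(n_samples, 5))
    y = (X[:, 0] + X[:, 1] > 0).astype(np.int64)
    return X, y


def custom_data_processor(data, train: bool = True, **kwargs):
    """Hypothetical processor: shuffle, split, and standardize (X, y).

    Returns ((X_train, y_train), (X_test, y_test)), the structure the README asks for.
    `train` is accepted only for signature compatibility.
    """
    X, y = data
    rng = np.random.default_rng(kwargs.get("seed", 0))
    test_size = kwargs.get("test_size", 0.2)

    perm = rng.permutation(len(X))
    n_test = int(len(X) * test_size)
    test_idx, train_idx = perm[:n_test], perm[n_test:]

    # Fit scaling statistics on the training split only, to avoid leaking test information
    mean = X[train_idx].mean(axis=0)
    std = X[train_idx].std(axis=0)
    std[std == 0] = 1.0  # guard against constant columns

    X_train = (X[train_idx] - mean) / std
    X_test = (X[test_idx] - mean) / std
    return (X_train, y[train_idx]), (X_test, y[test_idx])
```

These two functions can then be passed to DataProcessor exactly as in the example above.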

Model Training

To train a GMDA model:

from gmda.models import GMDARunner
from gmda.models.gmda.tools import get_config

# Load configuration
config = get_config('path/to/config.json')

# Initialize and train the model
model = GMDARunner(config)
model.train(train_loader, val_loader, X, config['training'])
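The snippets in this README read a nested configuration with at least a `model` section (containing `device`) and a `training` section (containing `nb_nn_for_prec_recall`). The sketch below writes and reloads a JSON file with only those keys, assuming `get_config` accepts an ordinary JSON object. It is not the full schema: GMDA's real configuration has further hyperparameters that this README does not document, and the values here are placeholders.

```python
import json
import tempfile
from pathlib import Path

# Only the keys referenced in this README; the real GMDA config contains more.
config = {
    "model": {
        "device": "cuda:0",  # or "cpu"; read as config['model']['device']
    },
    "training": {
        "nb_nn_for_prec_recall": 5,  # placeholder value, not a recommended setting
    },
}

# Write the kind of file get_config() would be pointed at, then read it back
path = Path(tempfile.gettempdir()) / "gmda_config_example.json"
path.write_text(json.dumps(config, indent=2))

loaded = json.loads(path.read_text())
print(loaded["model"]["device"], loaded["training"]["nb_nn_for_prec_recall"])
```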

Generating Synthetic Data

From a Trained Model:

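X_synthetic, y_synthetic = model.generate(y)
# Detach and move tensors to CPU (a no-op for CPU tensors) before converting to NumPy
X_synthetic, y_synthetic = X_synthetic.detach().cpu().numpy(), y_synthetic.detach().cpu().numpy()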

From a Pretrained Model:

from gmda.models import generate_from_pretrained

X_synthetic, y_synthetic = generate_from_pretrained(
    y, 
    config['model'], 
    path_pretrained=model.checkpoint_dir,
    device=config['model']['device'], 
    return_as_array=True
)
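To keep the generated data, one option is to write features and labels to a single CSV. The sketch below uses only NumPy; `save_synthetic_csv` and the default column names (`x0`, `x1`, ..., `y`) are hypothetical, and it assumes a single label column with one label per synthetic row.

```python
import numpy as np


def save_synthetic_csv(X_synthetic, y_synthetic, path, feature_names=None):
    """Hypothetical helper: write features plus a trailing label column to CSV."""
    X_synthetic = np.asarray(X_synthetic)
    y_synthetic = np.asarray(y_synthetic).reshape(-1, 1)  # assumes one label per row
    if len(X_synthetic) != len(y_synthetic):
        raise ValueError("X_synthetic and y_synthetic must have the same number of rows")

    if feature_names is None:
        feature_names = [f"x{i}" for i in range(X_synthetic.shape[1])]
    header = ",".join(list(feature_names) + ["y"])

    table = np.hstack([X_synthetic, y_synthetic])
    # comments="" keeps savetxt from prefixing the header line with "# "
    np.savetxt(path, table, delimiter=",", header=header, comments="")
```

For example, `save_synthetic_csv(X_synthetic, y_synthetic, "synthetic.csv")` after either generation call above.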

Metrics

GMDA provides metrics to evaluate the quality of generated data:

from gmda.metrics import get_corr_error, get_precision_recall
import numpy as np

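# Correlation error, computed on a random subsample of at most 1,500 rows
# (the same row indices are used for X_synthetic, so it needs at least as many rows as X)
idx = np.random.choice(len(X), size=min(len(X), 1500), replace=False)
corr_error, corr_error_matrix = get_corr_error(X[idx], X_synthetic[idx])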

# Precision/Recall
precision, recall = get_precision_recall(X, X_synthetic, nb_nn=config['training']['nb_nn_for_prec_recall'])
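To make the two metrics concrete, here is a NumPy sketch of common definitions, under stated assumptions. Correlation error is taken as the mean absolute difference between the Pearson correlation matrices of real and synthetic data. Precision/recall follow the k-nearest-neighbour formulation of Kynkäänniemi et al. (2019), with `nb_nn` presumably playing the role of k: precision is the fraction of synthetic rows inside at least one real row's k-NN ball, and recall swaps the roles. GMDA's `get_corr_error` and `get_precision_recall` may normalize or aggregate differently; every function name ending in `_sketch` is hypothetical.

```python
import numpy as np


def corr_error_sketch(X_real, X_syn):
    """Mean absolute difference between Pearson correlation matrices (one possible definition)."""
    diff = np.abs(np.corrcoef(X_real, rowvar=False) - np.corrcoef(X_syn, rowvar=False))
    return diff.mean(), diff


def _pairwise_dist(A, B):
    # Euclidean distances between rows of A and rows of B via ||a||^2 + ||b||^2 - 2ab
    sq = (A ** 2).sum(axis=1)[:, None] + (B ** 2).sum(axis=1)[None, :] - 2.0 * A @ B.T
    return np.sqrt(np.maximum(sq, 0.0))  # clip tiny negatives caused by rounding


def _knn_radii(X, k):
    # Distance from each row to its k-th nearest neighbour within X
    d = _pairwise_dist(X, X)
    return np.sort(d, axis=1)[:, k]  # column 0 is the row itself (distance ~0)


def _coverage(ref, radii, queries):
    # Fraction of query rows lying inside at least one k-NN ball centred on a ref row
    inside = _pairwise_dist(queries, ref) <= radii[None, :]
    return inside.any(axis=1).mean()


def precision_recall_sketch(X_real, X_syn, nb_nn=5):
    precision = _coverage(X_real, _knn_radii(X_real, nb_nn), X_syn)
    recall = _coverage(X_syn, _knn_radii(X_syn, nb_nn), X_real)
    return precision, recall
```

These sketches build full n-by-m distance matrices, so they only suit a few thousand rows; subsampling as in the correlation-error example above keeps them affordable.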

Command-Line Usage

GMDA can be run from the command line:

python main.py --dataset '<DATASET>' \
               --path_train '<PATH/TO/TRAIN/CSV>' \
               --path_test '<PATH/TO/TEST/CSV>' \
               --device 'cuda:0' \
               --config '<PATH/TO/CONFIG/JSON>' \
               --output_dir '<PATH/TO/OUTPUT/RESULTS>' \
               --compute_metrics \
               --save_generated

For more details on command-line options, run:

python main.py --help
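To run several datasets in a row, one option is to build the same command from Python. The sketch below uses only the flags shown above; the helper names, dataset names, and paths are hypothetical placeholders, and `python main.py --help` remains the authoritative list of options. With `dry_run=True` it only prints the commands.

```python
import shlex
import subprocess
import sys


def build_gmda_command(dataset, path_train, path_test, config, output_dir,
                       device="cuda:0", compute_metrics=True, save_generated=True):
    """Return the main.py invocation shown above as an argument list (hypothetical helper)."""
    cmd = [
        sys.executable, "main.py",  # use the interpreter of the active environment
        "--dataset", dataset,
        "--path_train", path_train,
        "--path_test", path_test,
        "--device", device,
        "--config", config,
        "--output_dir", output_dir,
    ]
    if compute_metrics:
        cmd.append("--compute_metrics")
    if save_generated:
        cmd.append("--save_generated")
    return cmd


def run_all(datasets, dry_run=True):
    """Launch one GMDA run per dataset; with dry_run=True, only print the commands."""
    for name in datasets:
        cmd = build_gmda_command(
            dataset=name,
            path_train=f"data/{name}_train.csv",  # placeholder paths
            path_test=f"data/{name}_test.csv",
            config="configs/gmda.json",
            output_dir=f"results/{name}",
        )
        if dry_run:
            print(shlex.join(cmd))
        else:
            subprocess.run(cmd, check=True)  # stop at the first failing run


if __name__ == "__main__":
    run_all(["dataset_a", "dataset_b"], dry_run=True)
```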

Contributing

We welcome contributions! Please contact me for more details.

License

This project is licensed under the MIT License.
