
Model Training Guide

This guide explains how to retrieve datasets for training inside Michelangelo workflows and how to optionally scale training using RayTask and the Lightning Trainer SDK.

The focus is simplicity: you control your training logic, and Michelangelo provides the dataset plumbing and optional distributed compute.

What You'll Learn

  • How datasets are passed to training tasks
  • How to load Ray, Pandas, or Spark datasets
  • How to scale training with Ray workers
  • How to use the Lightning Trainer SDK for deep learning

Understanding Training Inputs

Michelangelo workflows pass datasets using DatasetVariable.

A DatasetVariable may contain:

  • Ray Dataset (recommended for distributed training)
  • Pandas DataFrame (small or local datasets)
  • Spark DataFrame (large-scale preprocessing)

Access the dataset inside a training task using:

```python
dataset = train_dv.value
```

Dataset Formats

| Format | When It Appears | How to Use It |
|---|---|---|
| Ray Dataset | From data prep tasks that use Ray | Best for distributed training |
| Pandas DataFrame | Local CSV or small data | Convert to tensors directly |
| Spark DataFrame | Spark preprocessing step | Convert to Pandas or Ray before training |
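The conversions in the table can be wrapped in one helper. The sketch below is hypothetical (`to_pandas_frame` is not part of the Michelangelo SDK): it duck-types on the real conversion methods, `to_pandas()` on Ray Datasets and `toPandas()` on PySpark DataFrames, and passes anything else through unchanged on the assumption that it is already a Pandas DataFrame. Both conversions pull every row into a single process, so use it only for data that fits in memory.

```python
from typing import Any


def to_pandas_frame(value: Any) -> Any:
    """Hypothetical helper: return a Pandas DataFrame for a Ray, Spark, or Pandas value."""
    # Ray Datasets expose to_pandas(), which materializes all rows locally.
    if hasattr(value, "to_pandas"):
        return value.to_pandas()
    # PySpark DataFrames expose toPandas(), which collects all rows to the driver.
    if hasattr(value, "toPandas"):
        return value.toPandas()
    # Assumption: anything else is already a Pandas DataFrame.
    return value
```

Duck typing keeps the helper free of hard dependencies on Ray or PySpark, so it can be imported in tasks that only ever see Pandas data.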

Simple Training Example

For basic training (scikit-learn or lightweight PyTorch models), load each dataset and convert it to a Pandas DataFrame:

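```python
import michelangelo.uniflow.core as uniflow
from michelangelo.sdk.workflow.variables import DatasetVariable
from sklearn.ensemble import RandomForestRegressor


@uniflow.task()
def train_model(train_dv: DatasetVariable, val_dv: DatasetVariable):
    """Simple training with scikit-learn."""

    # Convert each input to a Ray Dataset (the same call the distributed
    # example uses), then materialize it as a Pandas DataFrame. Calling
    # to_pandas() directly on a Pandas or Spark value would fail: Pandas
    # needs no conversion and Spark uses toPandas().
    train_dv.load_ray_dataset()
    val_dv.load_ray_dataset()
    train_df = train_dv.value.to_pandas()
    val_df = val_dv.value.to_pandas()

    # Every column except the label is used as a feature.
    feature_cols = [col for col in train_df.columns if col != "target"]
    X_train, y_train = train_df[feature_cols], train_df["target"]
    X_val, y_val = val_df[feature_cols], val_df["target"]

    model = RandomForestRegressor(n_estimators=100)
    model.fit(X_train, y_train)

    # For regressors, score() returns R² on the validation set.
    score = model.score(X_val, y_val)
    print(f"Validation R² score: {score:.3f}")
    return model
```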
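For scikit-learn regressors, `model.score` returns the coefficient of determination, R² = 1 - SS_res / SS_tot. A plain-Python sketch of that formula can help when reading the printed value. `sklearn.metrics.r2_score` is the library equivalent; this illustrative version raises an error on constant targets instead of following scikit-learn's convention for that case.

```python
from typing import Sequence


def r2_score(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("inputs must be non-empty and the same length")
    mean = sum(y_true) / len(y_true)
    # Residual sum of squares: error left over after the model's predictions.
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    # Total sum of squares: error of always predicting the mean.
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    if ss_tot == 0:
        raise ValueError("R² is undefined when all targets are equal")
    return 1.0 - ss_res / ss_tot


# Perfect predictions give 1.0; always predicting the mean gives 0.0.
print(r2_score([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]))  # 1.0
print(r2_score([1.0, 2.0, 3.0], [2.0, 2.0, 2.0]))  # 0.0
```

A score near 0 means the model is no better than predicting the mean. Negative scores are possible when it is worse.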

Distributed Training with the Lightning Trainer SDK

To scale training across multiple CPUs or GPUs, configure the training task with RayTask and run the Lightning Trainer SDK inside it.

Example: Distributed Deep Learning with Ray Workers

```python
import michelangelo.uniflow.core as uniflow
from michelangelo.sdk.workflow.variables import DatasetVariable
from michelangelo.sdk.trainer.torch.pytorch_lightning.lightning_trainer import (
    LightningTrainer,
    LightningTrainerParam,
)
from michelangelo.uniflow.plugins.ray import RayTask
from michelangelo.maf.ray.train import create_run_config, create_scaling_config
from ray.train import CheckpointConfig


# Cluster requested for this task: one head node plus 4 worker pods.
@uniflow.task(
    config=RayTask(
        head_cpu=2,
        head_memory="8Gi",
        worker_cpu=4,
        worker_memory="16Gi",
        worker_instances=4,
    )
)
def train_distributed_model(
    train_dv: DatasetVariable,
    validation_dv: DatasetVariable,
    model_name: str = "gpt2",
    num_epochs: int = 3,
    batch_size: int = 32,
    learning_rate: float = 5e-5,
    num_workers: int = 4,
    use_gpu: bool = True,
):
    """Distributed training with Ray and the Lightning Trainer SDK."""

    # Load both inputs as Ray Datasets so Ray can shard them across workers.
    train_dv.load_ray_dataset()
    train_data = train_dv.value

    validation_dv.load_ray_dataset()
    validation_data = validation_dv.value

    # Scaling config: keep num_workers * cpu_per_worker within the RayTask
    # worker resources (4 instances x 4 CPUs here). use_gpu=True requires
    # GPU-equipped workers.
    scaling_config = create_scaling_config(
        trainer_cpu=2,
        cpu_per_worker=4,
        num_workers=num_workers,
        use_gpu=use_gpu,
    )

    # Run config: keep only the best checkpoint, ranked by the lowest
    # val_loss. The model must log a metric named "val_loss".
    run_config = create_run_config(
        name=f"distributed-training-{model_name}",
        checkpoint_config=CheckpointConfig(
            num_to_keep=1,
            checkpoint_score_attribute="val_loss",
            checkpoint_score_order="min",
        ),
    )

    # create_model_function is your own function that builds the Lightning
    # model from model_kwargs; it is not defined in this snippet.
    trainer_param = LightningTrainerParam(
        create_model=create_model_function,
        model_kwargs={
            "model_name": model_name,
            "learning_rate": learning_rate,
        },
        train_data=train_data,
        validation_data=validation_data,
        batch_size=batch_size,
        num_epochs=num_epochs,
        lightning_trainer_kwargs={
            "precision": "16-mixed",     # 16-bit mixed precision, intended for GPUs
            "log_every_n_steps": 10,
            "val_check_interval": 0.25,  # validate four times per epoch
        },
    )

    trainer = LightningTrainer(trainer_param)
    return trainer.train(run_config, scaling_config)
```
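The RayTask resources and the scaling config are set separately, so it is easy to request more training workers than the cluster can hold. Below is a rough pre-flight check using hypothetical dataclasses that mirror the parameter names from the example. It tests necessary conditions only and does not model Ray's actual placement logic.

```python
from dataclasses import dataclass


@dataclass
class ClusterSpec:
    # Hypothetical: mirrors the RayTask fields used in the example.
    head_cpu: int
    worker_cpu: int
    worker_instances: int


@dataclass
class ScalingSpec:
    # Hypothetical: mirrors the create_scaling_config arguments used in the example.
    trainer_cpu: int
    cpu_per_worker: int
    num_workers: int


def check_fit(cluster: ClusterSpec, scaling: ScalingSpec) -> list[str]:
    """Return reasons the scaling config cannot fit on this cluster (empty if none found)."""
    if scaling.cpu_per_worker <= 0:
        raise ValueError("cpu_per_worker must be positive")
    problems = []
    # A training worker's CPUs must come from a single node, so count whole slots per node.
    slots = cluster.head_cpu // scaling.cpu_per_worker + cluster.worker_instances * (
        cluster.worker_cpu // scaling.cpu_per_worker
    )
    if scaling.num_workers > slots:
        problems.append(f"num_workers={scaling.num_workers} exceeds {slots} worker slots")
    # The trainer plus all workers must fit in the cluster's total CPU count.
    total_cpu = cluster.head_cpu + cluster.worker_instances * cluster.worker_cpu
    requested = scaling.trainer_cpu + scaling.num_workers * scaling.cpu_per_worker
    if requested > total_cpu:
        problems.append(f"requested {requested} CPUs but the cluster has {total_cpu}")
    return problems


# The values from the example fit exactly: 2 + 4 * 4 = 18 CPUs.
print(check_fit(ClusterSpec(2, 4, 4), ScalingSpec(2, 4, 4)))  # []
```

An empty result does not guarantee the job will schedule. A non-empty result, however, means the configuration cannot fit as requested.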

What Ray Handles for You

  • Worker creation
  • Dataset sharding
  • Parallel batch execution
  • GPU scheduling
  • Automatic checkpointing
  • Fault recovery
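Dataset sharding means each training worker reads a different slice of the data. The sketch below illustrates the idea with an even, contiguous split. `shard_evenly` is a hypothetical helper; Ray's actual sharding streams data blocks to workers and need not produce these exact boundaries.

```python
from typing import Sequence, TypeVar

T = TypeVar("T")


def shard_evenly(rows: Sequence[T], num_workers: int) -> list[list[T]]:
    """Split rows into num_workers contiguous shards whose sizes differ by at most one."""
    if num_workers <= 0:
        raise ValueError("num_workers must be positive")
    base, extra = divmod(len(rows), num_workers)
    shards, start = [], 0
    for rank in range(num_workers):
        # The first `extra` workers each take one additional row.
        size = base + (1 if rank < extra else 0)
        shards.append(list(rows[start:start + size]))
        start += size
    return shards


print([len(s) for s in shard_evenly(list(range(10)), 4)])  # [3, 3, 2, 2]
```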

Benefits of the Lightning Trainer SDK

| Benefit | Description |
|---|---|
| Automatic dataset sharding | No manual sampler or dataloader |
| Automatic distributed setup | Multi-node, multi-GPU ready |
| Automatic checkpointing | Lightning checkpoints and model weights are saved |
| Minimal boilerplate | Focus on model logic, not infrastructure |
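The run config in the distributed example keeps one checkpoint, ranked by val_loss with lower being better. The following is a simplified model of that ranking. `best_checkpoints` is hypothetical, the report values are made up, and Ray's checkpoint manager may also retain the most recent checkpoint for recovery.

```python
def best_checkpoints(
    reports: list[dict], num_to_keep: int, score_attribute: str, score_order: str
) -> list[dict]:
    """Keep the num_to_keep reports ranked best by score_attribute."""
    if score_order not in ("min", "max"):
        raise ValueError("score_order must be 'min' or 'max'")
    # "min" means lower scores are better, so sort ascending; "max" sorts descending.
    ranked = sorted(reports, key=lambda r: r[score_attribute], reverse=(score_order == "max"))
    return ranked[:num_to_keep]


reports = [
    {"epoch": 0, "val_loss": 0.92},
    {"epoch": 1, "val_loss": 0.61},
    {"epoch": 2, "val_loss": 0.67},
]
print(best_checkpoints(reports, num_to_keep=1, score_attribute="val_loss", score_order="min"))
# [{'epoch': 1, 'val_loss': 0.61}]
```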

You do not need to implement:

  • PyTorch DataLoaders or distributed samplers
  • DDP setup or multiprocessing

The SDK handles the distributed setup; you supply the model and the training parameters.

Best Practices

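  • Use the Lightning Trainer SDK for distributed deep learning
  • Start small (a data sample, one or two workers), then scale up
  • Track experiments consistently (parameters, metrics, and checkpoints for every run)
  • Size compute resources (CPU, memory, GPU) to your model and dataset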

Avoid

  • Writing manual distributed training loops unless necessary
  • Training without a validation dataset
  • Ignoring memory, CPU, and GPU limits
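If your data prep step produces only one dataset, hold out validation rows before training. Below is a minimal, seeded sketch over an in-memory list; `train_val_split` is hypothetical. For Ray Datasets, `Dataset.train_test_split` offers the same idea.

```python
import random
from typing import Sequence, TypeVar

T = TypeVar("T")


def train_val_split(
    rows: Sequence[T], val_fraction: float = 0.2, seed: int = 42
) -> tuple[list[T], list[T]]:
    """Shuffle with a fixed seed, then hold out val_fraction of the rows for validation."""
    if not 0.0 < val_fraction < 1.0:
        raise ValueError("val_fraction must be between 0 and 1")
    shuffled = list(rows)
    # A seeded generator makes the split reproducible across runs.
    random.Random(seed).shuffle(shuffled)
    n_val = max(1, int(len(shuffled) * val_fraction))
    return shuffled[n_val:], shuffled[:n_val]


train_rows, val_rows = train_val_split(list(range(100)))
print(len(train_rows), len(val_rows))  # 80 20
```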

Next Steps

Once training completes, your models are ready for the next stages:

  • Continue to Model Registry to save and version your models
  • Continue to Model Deployment (Coming Soon) to serve models for inference

Troubleshooting

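  • Out of memory? Lower batch_size or increase worker_memory in RayTask
  • Slow training? Increase num_workers (with enough worker_instances to host them) or set use_gpu=True on GPU-equipped workers
  • Loss not converging? Verify preprocessing and check the learning rate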
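For the out-of-memory case, one common approach is to halve the batch size until a trial step fits. In this minimal sketch, `find_fitting_batch_size` and the `try_batch` callback are hypothetical. The code catches Python's MemoryError; on GPUs you would catch the framework's error type instead, for example `torch.cuda.OutOfMemoryError` in recent PyTorch. Lightning's `Tuner.scale_batch_size` performs a related automated search.

```python
from typing import Callable


def find_fitting_batch_size(
    try_batch: Callable[[int], None], start: int, minimum: int = 1
) -> int:
    """Halve the batch size until try_batch(batch_size) runs without a MemoryError."""
    batch_size = start
    while batch_size >= minimum:
        try:
            try_batch(batch_size)  # e.g. one forward/backward pass at this size
            return batch_size
        except MemoryError:
            batch_size //= 2
    raise RuntimeError(f"no batch size >= {minimum} fits in memory")


def fake_step(batch_size: int) -> None:
    # Stand-in for a real training step that only fits up to 10 samples.
    if batch_size > 10:
        raise MemoryError


print(find_fitting_batch_size(fake_step, start=64))  # 8
```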