ai-research: name: ray-train
Usage
After installing, this skill will be available to your AI coding assistant.
Verify installation:
skills list
Skill Instructions
name: ray-train
description: Distributed training orchestration across clusters. Scales PyTorch/TensorFlow/HuggingFace from laptop to 1000s of nodes. Built-in hyperparameter tuning with Ray Tune, fault tolerance, elastic scaling. Use when training massive models across multiple machines or running distributed hyperparameter sweeps.
version: 1.0.0
author: Orchestra Research
license: MIT
tags: [Ray Train, Distributed Training, Orchestration, Ray, Hyperparameter Tuning, Fault Tolerance, Elastic Scaling, Multi-Node, PyTorch, TensorFlow]
dependencies: [ray[train], torch, transformers]
Ray Train - Distributed Training Orchestration
Quick start
Ray Train scales machine learning training from single GPU to multi-node clusters with minimal code changes.
Installation:
pip install -U "ray[train]"
Basic PyTorch training (single node):
import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
import torch
import torch.nn as nn

# Define training function (runs once on every worker)
def train_func(config):
    # Your normal PyTorch code
    model = nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    # Prepare for distributed (Ray handles device placement)
    model = train.torch.prepare_model(model)
    # Inputs must live on the same device as the prepared model
    device = train.torch.get_device()
    for epoch in range(10):
        # Your training loop
        output = model(torch.randn(32, 10, device=device))
        loss = output.sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # Report metrics (logged automatically)
        train.report({"loss": loss.item(), "epoch": epoch})

# Run distributed training
trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(
        num_workers=4,  # 4 workers, one GPU each
        use_gpu=True,
    ),
)
result = trainer.fit()
print(f"Final loss: {result.metrics['loss']}")
That's it! Ray handles:
- Distributed coordination
- GPU allocation
- Fault tolerance (worker restarts, when enabled through a failure config)
- Checkpoint persistence (for the checkpoints you report)
- Metric reporting
Common workflows
Workflow 1: Scale existing PyTorch code
Original single-GPU code:
model = MyModel().cuda()
optimizer = torch.optim.Adam(model.parameters())
for epoch in range(epochs):
    for batch in dataloader:
        optimizer.zero_grad()
        loss = model(batch)
        loss.backward()
        optimizer.step()
Ray Train version (scales to multi-GPU/multi-node):
import torch
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

def train_func(config):
    model = MyModel()
    optimizer = torch.optim.Adam(model.parameters())
    # Prepare for distributed (automatic device placement)
    model = train.torch.prepare_model(model)
    # Wrap the existing DataLoader; the new name avoids shadowing the outer `dataloader`
    train_loader = train.torch.prepare_data_loader(dataloader)
    for epoch in range(epochs):
        for batch in train_loader:
            optimizer.zero_grad()
            loss = model(batch)
            loss.backward()
            optimizer.step()
        # Report metrics once per epoch
        train.report({"loss": loss.item()})

# Scale to 8 GPUs
trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=8, use_gpu=True),
)
trainer.fit()
Benefits: Same code runs on 1 GPU or 1000 GPUs
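Scaling the worker count also changes how many samples feed each optimizer step, because every worker draws its own batch. The arithmetic can be sketched as follows, assuming the DataLoader batch size is a per-worker value; the helper name `global_batch_size` is hypothetical and not part of Ray:

```python
def global_batch_size(per_worker_batch: int, num_workers: int,
                      accumulation_steps: int = 1) -> int:
    """Samples contributing to one optimizer step across all workers."""
    if min(per_worker_batch, num_workers, accumulation_steps) < 1:
        raise ValueError("all arguments must be positive")
    return per_worker_batch * num_workers * accumulation_steps


# Same per-worker batch, growing cluster: the global batch grows with it.
sizes = {workers: global_batch_size(32, workers) for workers in (1, 8, 64)}
```

When the global batch grows this way, the learning rate and schedule may need retuning; that choice is model-specific and not something this sketch decides.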
Workflow 2: HuggingFace Transformers integration
from ray.train import ScalingConfig
from ray.train.huggingface.transformers import RayTrainReportCallback, prepare_trainer
from ray.train.torch import TorchTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

def train_func(config):
    # Load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # Training arguments (HuggingFace API)
    training_args = TrainingArguments(
        output_dir="./output",
        num_train_epochs=3,
        per_device_train_batch_size=8,
        learning_rate=2e-5,
    )
    # train_dataset: your tokenized dataset, built here with `tokenizer`
    hf_trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
    )
    # Forward metrics and checkpoints to Ray Train, then let Ray set up distributed training
    hf_trainer.add_callback(RayTrainReportCallback())
    hf_trainer = prepare_trainer(hf_trainer)
    hf_trainer.train()

# Scale to multi-node (2 nodes × 8 GPUs = 16 workers)
# Recent Ray releases run HuggingFace training through TorchTrainer
trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(
        num_workers=16,
        use_gpu=True,
        resources_per_worker={"GPU": 1},
    ),
)
result = trainer.fit()
Workflow 3: Hyperparameter tuning with Ray Tune
import torch
from ray import train, tune
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
from ray.tune.schedulers import ASHAScheduler

def train_func(config):
    # Use hyperparameters from config
    lr = config["lr"]
    batch_size = config["batch_size"]
    model = MyModel()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model = train.torch.prepare_model(model)
    for epoch in range(10):
        # Training loop
        loss = train_epoch(model, optimizer, batch_size)
        train.report({"loss": loss, "epoch": epoch})

# Define search space; TorchTrainer passes "train_loop_config" to train_func as `config`
param_space = {
    "train_loop_config": {
        "lr": tune.loguniform(1e-5, 1e-2),
        "batch_size": tune.choice([16, 32, 64, 128]),
    }
}
# Run 20 trials with early stopping
tuner = tune.Tuner(
    TorchTrainer(
        train_func,
        scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
    ),
    param_space=param_space,
    tune_config=tune.TuneConfig(
        num_samples=20,
        scheduler=ASHAScheduler(metric="loss", mode="min"),
    ),
)
results = tuner.fit()
best = results.get_best_result(metric="loss", mode="min")
print(f"Best hyperparameters: {best.config['train_loop_config']}")
Result: Distributed hyperparameter search across cluster
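To make the search space concrete, here is a minimal standard-library sketch of what drawing 20 trial configurations from a log-uniform learning rate and a categorical batch size looks like. It illustrates the sampling idea only and is not Ray Tune's implementation; `sample_loguniform` and `sample_config` are hypothetical helper names:

```python
import math
import random


def sample_loguniform(low: float, high: float, rng: random.Random) -> float:
    """Draw a value whose logarithm is uniform on [log(low), log(high)]."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


def sample_config(rng: random.Random) -> dict:
    """One trial configuration matching the search space above."""
    return {
        "lr": sample_loguniform(1e-5, 1e-2, rng),
        "batch_size": rng.choice([16, 32, 64, 128]),
    }


rng = random.Random(0)  # fixed seed so the draw is repeatable
trials = [sample_config(rng) for _ in range(20)]
```

Sampling in log space spreads trials evenly across orders of magnitude, which is usually what you want for a learning rate spanning 1e-5 to 1e-2.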
Workflow 4: Checkpointing and fault tolerance
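import os
import tempfile
import torch
from ray import train
from ray.train import Checkpoint, FailureConfig, RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer

def train_func(config):
    model = train.torch.prepare_model(MyModel())
    # Unwrap DDP so saved keys carry no "module." prefix
    base_model = model.module if hasattr(model, "module") else model
    optimizer = torch.optim.Adam(model.parameters())
    # Try to resume from checkpoint
    start_epoch = 0
    checkpoint = train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            state = torch.load(os.path.join(checkpoint_dir, "model.pt"), map_location="cpu")
        base_model.load_state_dict(state["model"])
        optimizer.load_state_dict(state["optimizer"])
        start_epoch = state["epoch"] + 1  # continue after the last saved epoch
    for epoch in range(start_epoch, 100):
        loss = train_epoch(model, optimizer)
        with tempfile.TemporaryDirectory() as tmpdir:
            checkpoint = None
            # Save checkpoint every 10 epochs, from rank 0 only
            if epoch % 10 == 0 and train.get_context().get_world_rank() == 0:
                torch.save({
                    "model": base_model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "epoch": epoch,
                }, os.path.join(tmpdir, "model.pt"))
                checkpoint = Checkpoint.from_directory(tmpdir)
            # Every worker must call report(); only rank 0 attaches a checkpoint
            train.report({"loss": loss}, checkpoint=checkpoint)

trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=8, use_gpu=True),
    # Restart up to 3 times; each restart resumes from the latest reported checkpoint
    run_config=RunConfig(failure_config=FailureConfig(max_failures=3)),
)
result = trainer.fit()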
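The resume logic is easy to get off by one, so here is a standard-library sketch of the same save-and-resume pattern with a simulated failure. It uses a JSON file in place of a model checkpoint; `run_training` and its parameters are hypothetical and stand in for the training function, not for any Ray API:

```python
import json
import os
import tempfile
from typing import List, Optional


def run_training(ckpt_dir: str, total_epochs: int, save_every: int,
                 fail_at: Optional[int] = None) -> List[int]:
    """Run epochs, resuming from ckpt_dir/state.json when it exists.

    Returns the epochs executed during this call.
    """
    path = os.path.join(ckpt_dir, "state.json")
    start_epoch = 0
    if os.path.exists(path):
        with open(path) as f:
            start_epoch = json.load(f)["epoch"] + 1  # continue after the saved epoch
    executed = []
    for epoch in range(start_epoch, total_epochs):
        if fail_at is not None and epoch == fail_at:
            raise RuntimeError(f"simulated worker failure at epoch {epoch}")
        executed.append(epoch)  # stand-in for one epoch of training
        if epoch % save_every == 0:
            with open(path, "w") as f:
                json.dump({"epoch": epoch}, f)
    return executed


with tempfile.TemporaryDirectory() as ckpt_dir:
    try:
        run_training(ckpt_dir, total_epochs=30, save_every=10, fail_at=25)
    except RuntimeError:
        pass  # a supervisor would restart the run here
    resumed = run_training(ckpt_dir, total_epochs=30, save_every=10)
```

The first run saves at epochs 0, 10, and 20 and fails at 25, so the second run repeats epochs 21 to 24 before finishing. Work done since the last checkpoint is always lost, which is the trade-off behind the checkpoint interval.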
Workflow 5: Multi-node training
import ray
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

# Connect to Ray cluster
ray.init(address="auto")  # Or ray.init("ray://head-node:10001")
# Train across 4 nodes × 8 GPUs = 32 workers (train_func as defined in the workflows above)
trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(
        num_workers=32,
        use_gpu=True,
        resources_per_worker={"GPU": 1, "CPU": 4},
        placement_strategy="SPREAD",  # Spread across nodes
    ),
)
result = trainer.fit()
Launch Ray cluster:
# On head node
ray start --head --port=6379
# On worker nodes
ray start --address=<head-node-ip>:6379
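The worker count and per-worker resources in the scaling config follow from the cluster shape. A minimal sketch of that calculation, assuming one worker per GPU and one CPU per node held back for Ray's own processes (both are assumptions to tune); `ClusterPlan` is a hypothetical helper, and the returned dict only mirrors the `ScalingConfig` arguments used above:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ClusterPlan:
    nodes: int
    gpus_per_node: int
    cpus_per_node: int

    def scaling_kwargs(self) -> dict:
        """Keyword arguments for one worker per GPU, splitting CPUs evenly."""
        num_workers = self.nodes * self.gpus_per_node
        # Assumption: reserve one CPU per node for Ray's own processes.
        cpus_per_worker = max(1, (self.cpus_per_node - 1) // self.gpus_per_node)
        return {
            "num_workers": num_workers,
            "use_gpu": True,
            "resources_per_worker": {"GPU": 1, "CPU": cpus_per_worker},
        }


plan = ClusterPlan(nodes=4, gpus_per_node=8, cpus_per_node=33)
kwargs = plan.scaling_kwargs()  # could be passed as ScalingConfig(**kwargs)
```

Requesting more CPUs per worker than the nodes can supply leaves workers unschedulable, so it helps to derive these numbers from the actual node shape rather than hard-coding them.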
When to use vs alternatives
Use Ray Train when:
- Training across multiple machines (multi-node)
- Need hyperparameter tuning at scale
- Want fault tolerance (auto-restart failed workers)
- Elastic scaling (add/remove nodes during training)
- Unified framework (same code for PyTorch/TF/HF)
Key advantages:
- Multi-node orchestration: One ScalingConfig covers single-node and multi-node runs
- Ray Tune integration: Trainers plug directly into tune.Tuner for distributed sweeps
- Fault tolerance: Automatic recovery from failures, resuming from the latest reported checkpoint
- Elastic: Add/remove nodes without restarting
- Framework agnostic: PyTorch, TensorFlow, HuggingFace, XGBoost
Use alternatives instead:
- Accelerate: Single-node multi-GPU, simpler
- PyTorch Lightning: High-level abstractions, callbacks
- DeepSpeed: Maximum performance, complex setup
- Raw DDP: Maximum control, minimal overhead
Common issues
Issue: Ray cluster not connecting
Check ray status:
ray status
# The output should list every node as active and report the expected
# resource totals, for example 4 nodes and 32 GPUs on a 4 × 8 GPU cluster.
If not connected:
# Restart head node
ray stop
ray start --head --port=6379 --dashboard-host=0.0.0.0
# Restart worker nodes
ray stop
ray start --address=<head-ip>:6379
Issue: Out of memory
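Reduce the per-worker batch size and use gradient accumulation; lowering the worker count only helps when several workers share one node's host memory:
scaling_config=ScalingConfig(
    num_workers=4,  # Reduce from 8 to ease host-memory pressure per node
    use_gpu=True
)
# In train_func, accumulate gradients over several small batches
accumulation_steps = 4
for i, batch in enumerate(dataloader):
    loss = model(batch) / accumulation_steps
    loss.backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()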
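Dividing each loss by `accumulation_steps` is what makes the accumulated gradient match the gradient of one large batch. The following standard-library sketch checks this for a one-parameter least-squares model with equal-sized micro-batches; the function names and data are illustrative only:

```python
import math
from typing import List, Sequence, Tuple

Sample = Tuple[float, float]  # (x, y)


def mse_grad(w: float, batch: Sequence[Sample]) -> float:
    """Gradient of mean((w*x - y)^2) with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w: float, micro_batches: List[Sequence[Sample]]) -> float:
    """Sum micro-batch gradients, each scaled by 1/accumulation_steps."""
    steps = len(micro_batches)
    total = 0.0
    for mb in micro_batches:
        total += mse_grad(w, mb) / steps  # mirrors `loss / accumulation_steps`
    return total


data = [(float(i), 3.0 * i + 1.0) for i in range(8)]
micro = [data[i:i + 2] for i in range(0, 8, 2)]  # 4 micro-batches of 2 samples
full = mse_grad(0.5, data)
accum = accumulated_grad(0.5, micro)
```

The two values agree up to floating-point rounding. With unequal micro-batch sizes the equality no longer holds exactly, because each micro-batch mean is weighted the same.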
Issue: Slow training
Check if data loading is bottleneck:
import time

def train_func(config):
    for epoch in range(epochs):
        data_time = 0.0
        start = time.time()
        for batch in dataloader:
            # Time spent waiting for this batch
            data_time += time.time() - start
            # Train...
            start = time.time()
        print(f"Data loading: {data_time:.3f}s")
If data loading is slow, increase the number of DataLoader worker processes:
dataloader = DataLoader(dataset, num_workers=8)
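To decide whether more loader processes would help, it is useful to compare time spent waiting for batches against time spent in the training step. A minimal sketch of such a profiler, with an injectable clock so it can be checked deterministically; `profile_loader` and `step_fn` are hypothetical names:

```python
import time
from typing import Callable, Iterable, Tuple


def profile_loader(
    loader: Iterable,
    step_fn: Callable[[object], None],
    clock: Callable[[], float] = time.perf_counter,
) -> Tuple[float, float]:
    """Return (seconds waiting for batches, seconds inside step_fn)."""
    data_time = 0.0
    step_time = 0.0
    batches = iter(loader)
    while True:
        t0 = clock()
        try:
            batch = next(batches)  # blocks while the loader produces a batch
        except StopIteration:
            break
        t1 = clock()
        step_fn(batch)
        t2 = clock()
        data_time += t1 - t0
        step_time += t2 - t1
    return data_time, step_time


# Deterministic check with a fake clock: each batch takes 3s to load, 1s to train.
ticks = iter([0.0, 3.0, 4.0, 4.0, 7.0, 8.0, 8.0])
data_s, step_s = profile_loader([1, 2], lambda b: None, clock=lambda: next(ticks))
```

If the data time dominates, more loader processes are likely to help. On GPUs the step time can be understated because kernels run asynchronously, so treat the split as a rough guide.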
Advanced topics
Multi-node setup: See references/multi-node.md for Ray cluster deployment on AWS, GCP, Kubernetes, and SLURM.
Hyperparameter tuning: See references/hyperparameter-tuning.md for Ray Tune integration, search algorithms (Optuna, HyperOpt), and population-based training.
Custom training loops: See references/custom-loops.md for advanced Ray Train usage, custom backends, and integration with other frameworks.
Hardware requirements
- Single node: 1+ GPUs (or CPUs)
- Multi-node: 2+ machines with network connectivity
- Cloud: AWS, GCP, Azure (Ray autoscaling)
- On-prem: Kubernetes, SLURM clusters
Supported accelerators:
- NVIDIA GPUs (CUDA)
- AMD GPUs (ROCm)
- TPUs (Google Cloud)
- CPUs
Resources
- Docs: https://docs.ray.io/en/latest/train/train.html
- GitHub: https://github.com/ray-project/ray ⭐ 36,000+
- Version: 2.40.0+
- Examples: https://docs.ray.io/en/latest/train/examples.html
- Slack: https://forms.gle/9TSdDYUgxYs8SA9e8
- Used by: OpenAI, Uber, Spotify, Shopify, Instacart
