## Business Context
You are joining an OpenAI research team that has a promising single-GPU training prototype for a transformer-based text classifier used in internal model quality workflows. The prototype works on one NVIDIA A100, but the team now needs to train on a much larger corpus and reduce wall-clock training time by scaling to a multi-node cluster.
## Dataset
The training job uses tokenized text examples stored in sharded Parquet files on object storage. Each row contains token IDs, attention masks, metadata, and a binary label indicating whether the sample belongs to a target quality bucket.
| Feature Group | Count | Examples |
|---|---|---|
| Token features | 2 | input_ids, attention_mask |
| Metadata | 5 | language, source_surface, prompt_length, response_length, model_family |
| Target | 1 | label |
| Split keys | 2 | shard_id, created_at |
- Size: 42M examples, sequence length up to 2048, 7 input fields
- Target: Binary — positive quality bucket (1) vs other (0)
- Class balance: 18% positive, 82% negative
- Missing data: ~6% missing in some metadata columns; token fields are complete
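For orientation, here is a minimal sketch of reading these shards with a disjoint per-rank assignment. It assumes the shard list has already been fetched from object storage, that pyarrow is available, and that column names match the table above; the helper names are hypothetical, not part of the existing prototype:

```python
# Hypothetical sketch: partition the sorted shard list across ranks so each
# worker reads a disjoint subset of files. Paths and columns are illustrative.
import pyarrow.parquet as pq
import torch.distributed as dist

def shards_for_rank(shard_paths, rank=None, world_size=None):
    """Deterministically assign shards round-robin; sorting first makes the
    assignment identical across restarts."""
    rank = dist.get_rank() if rank is None else rank
    world_size = dist.get_world_size() if world_size is None else world_size
    return sorted(shard_paths)[rank::world_size]

def iter_examples(shard_paths):
    """Stream rows from a rank's shards one file at a time."""
    for path in shard_paths:
        table = pq.read_table(path, columns=["input_ids", "attention_mask", "label"])
        for batch in table.to_batches():
            yield from batch.to_pylist()
```

File-level assignment like this keeps reads disjoint, but it only balances load well when shards are similar in size.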
## Success Criteria
A good solution should:
- Achieve at least 3.5x throughput scaling when moving from 1 GPU to 8 GPUs, and at least 10x on 32 GPUs
- Keep final validation AUC within 1 percentage point of the single-GPU baseline
- Train reproducibly with restart-safe checkpointing and no duplicated or skipped samples across workers
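One common way to approach the no-duplicated/no-skipped requirement in PyTorch is `DistributedSampler`. Note that its default behavior pads the final batch by repeating samples; `drop_last=True` avoids those duplicates at the cost of dropping a small tail each epoch. A minimal sketch, assuming the default process group is already initialized and using a placeholder `TensorDataset` in place of the real Parquet-backed dataset:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Placeholder dataset; the real one would wrap the tokenized Parquet shards.
train_dataset = TensorDataset(torch.randn(1000, 8), torch.randint(0, 2, (1000,)))

# drop_last=True drops the uneven tail instead of padding with duplicates.
sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True)
loader = DataLoader(train_dataset, batch_size=32, sampler=sampler, num_workers=4)

for epoch in range(3):
    sampler.set_epoch(epoch)  # distinct but reproducible shuffle each epoch
    for inputs, labels in loader:
        pass  # forward/backward step goes here
```

Forgetting `set_epoch` is a classic bug: every epoch then replays the same sample order on every rank.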
## Constraints
- Training runs on a multi-node GPU cluster with preemptions possible
- GPU memory is limited; full-batch scaling is not possible
- The solution must support experiment tracking and checkpoint artifacts through OpenAI-compatible workflows
- Inference latency is not critical, but training efficiency and correctness are
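Because preemptions are possible, checkpoint writes should be atomic so that a kill mid-write never corrupts the last good checkpoint. A minimal sketch, assuming a DDP-wrapped model, an initialized process group, and a POSIX-like filesystem path (an object-storage upload would replace the rename step):

```python
import os
import torch
import torch.distributed as dist

def save_checkpoint(path, model, optimizer, epoch, step):
    """Rank 0 writes to a temp file, then renames it into place, so a partial
    write from a preemption never clobbers the previous checkpoint."""
    if dist.get_rank() == 0:
        tmp = path + ".tmp"
        torch.save({
            "model": model.module.state_dict(),  # unwrap the DDP wrapper
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
            "step": step,
        }, tmp)
        os.replace(tmp, path)  # atomic rename on POSIX filesystems
    dist.barrier()  # keep other ranks from racing ahead of the save
```

Saving the epoch and step alongside the weights lets a restarted job resume the data pipeline (via `set_epoch` and skipping consumed batches) without duplicating or skipping samples.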
## Deliverables
- Design a training plan to move from single-GPU to distributed multi-node training.
- Explain data sharding, gradient synchronization, checkpointing, and failure recovery.
- Implement a distributed PyTorch training loop with `DistributedDataParallel` (a sketch appears after this list).
- Measure throughput, convergence, and validation quality versus the single-GPU baseline.
- Propose tuning and debugging steps for poor scaling efficiency or unstable optimization.
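As a starting point for the third deliverable, here is a minimal `DistributedDataParallel` loop launched with `torchrun`, with gradient accumulation standing in for larger per-GPU batches and a rough examples/sec measurement. The model, dataset, and hyperparameters are placeholders, not the team's actual code:

```python
# Launch with: torchrun --nproc_per_node=<gpus> train.py
import os
import time
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

def main():
    dist.init_process_group("nccl")  # torchrun supplies rank/world-size env vars
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)

    # Placeholder model and data; the real job uses the transformer classifier
    # and the Parquet-backed dataset described above.
    model = DDP(torch.nn.Linear(128, 2).to(device), device_ids=[local_rank])
    dataset = TensorDataset(torch.randn(4096, 128), torch.randint(0, 2, (4096,)))
    sampler = DistributedSampler(dataset, shuffle=True, drop_last=True)
    loader = DataLoader(dataset, batch_size=64, sampler=sampler)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    loss_fn = torch.nn.CrossEntropyLoss()
    accum_steps = 4  # accumulation stands in for a larger per-GPU batch

    for epoch in range(2):
        sampler.set_epoch(epoch)
        start, seen = time.time(), 0
        optimizer.zero_grad(set_to_none=True)
        for i, (x, y) in enumerate(loader):
            x, y = x.to(device), y.to(device)
            loss = loss_fn(model(x), y) / accum_steps
            # backward() all-reduces gradients across ranks; model.no_sync()
            # could skip the sync on accumulation micro-steps as an optimization.
            loss.backward()
            if (i + 1) % accum_steps == 0:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
            seen += x.size(0)
        if dist.get_rank() == 0:
            rate = seen * dist.get_world_size() / (time.time() - start)
            print(f"epoch {epoch}: ~{rate:.0f} examples/sec (global, approximate)")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```

Logging the global examples/sec per epoch, as above, gives the throughput curve needed to verify the 3.5x (8-GPU) and 10x (32-GPU) scaling targets against the single-GPU baseline.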