#!/usr/bin/env bash
#
# Launch scripts/train_sft.py under torchrun.
#
# Paths (scripts/train_sft.py, data/*.jsonl, outputs/...) are relative, so the
# script must be started from the directory that contains them.
#
# Usage:
#   <this script> [--gpu-count 0|1|2] [--model-name NAME] [--output-dir DIR]
#                 [--resume-from-checkpoint PATH] [extra train_sft.py args...]
#
# --gpu-count selects a device layout rather than a plain count:
#   0 -> CUDA_VISIBLE_DEVICES=0   (one process)
#   1 -> CUDA_VISIBLE_DEVICES=1   (one process)
#   2 -> CUDA_VISIBLE_DEVICES=0,1 (two processes; default)
#
# Any argument that is not one of the options above is forwarded verbatim to
# scripts/train_sft.py after all generated flags.
#
# Environment overrides (each has a default below): MODEL_NAME, MAX_LENGTH,
# MAX_PROMPT_LENGTH, MAX_COMPLETION_LENGTH, GRAD_ACC, LEARNING_RATE,
# NUM_EPOCHS, TRAIN_BS, EVAL_BS, LORA_R, LORA_ALPHA, LORA_DROPOUT, SAVE_STEPS,
# EVAL_STEPS, LOGGING_STEPS, WARMUP_RATIO, WEIGHT_DECAY, ATTN_IMPL, OPTIM_NAME,
# PYTORCH_CUDA_ALLOC_CONF, TOKENIZERS_PARALLELISM.
set -euo pipefail

# Print a message to stderr and abort with status 1.
die() {
  printf '%s\n' "$*" >&2
  exit 1
}

# Abort with a readable message when an option that takes a value is the last
# argument; otherwise `set -u` would fail on "$2" with "unbound variable".
#   $1 - option name, $2 - number of arguments left (including the option)
require_value() {
  if (( $2 < 2 )); then
    die "Missing value for $1"
  fi
}

GPU_COUNT=2
MODEL_NAME="${MODEL_NAME:-Qwen/Qwen3.5-2B}"
OUTPUT_DIR=""
RESUME_FROM_CHECKPOINT=""
EXTRA_ARGS=()

while [[ $# -gt 0 ]]; do
  case "$1" in
    --gpu-count)
      require_value "$1" "$#"
      GPU_COUNT="$2"
      shift 2
      ;;
    --model-name)
      require_value "$1" "$#"
      MODEL_NAME="$2"
      shift 2
      ;;
    --output-dir)
      require_value "$1" "$#"
      OUTPUT_DIR="$2"
      shift 2
      ;;
    --resume-from-checkpoint)
      require_value "$1" "$#"
      RESUME_FROM_CHECKPOINT="$2"
      shift 2
      ;;
    *)
      # Unknown arguments are passed through to the training script.
      EXTRA_ARGS+=("$1")
      shift
      ;;
  esac
done

# Map the requested layout to visible devices, process count, and the default
# gradient-accumulation steps (16 for one process, 8 for two).
case "$GPU_COUNT" in
  0)
    export CUDA_VISIBLE_DEVICES=0
    NPROC=1
    DEFAULT_GRAD_ACC=16
    ;;
  1)
    export CUDA_VISIBLE_DEVICES=1
    NPROC=1
    DEFAULT_GRAD_ACC=16
    ;;
  2)
    export CUDA_VISIBLE_DEVICES=0,1
    NPROC=2
    DEFAULT_GRAD_ACC=8
    ;;
  *)
    die "Unsupported --gpu-count: $GPU_COUNT (expected 0, 1, or 2)"
    ;;
esac

TRAIN_FILE="data/train_sft.jsonl"
DEV_FILE="data/dev_sft.jsonl"

MAX_LENGTH="${MAX_LENGTH:-4096}"
MAX_PROMPT_LENGTH="${MAX_PROMPT_LENGTH:-3584}"
MAX_COMPLETION_LENGTH="${MAX_COMPLETION_LENGTH:-512}"
GRAD_ACC="${GRAD_ACC:-$DEFAULT_GRAD_ACC}"
LEARNING_RATE="${LEARNING_RATE:-1e-5}"
NUM_EPOCHS="${NUM_EPOCHS:-1.0}"
TRAIN_BS="${TRAIN_BS:-1}"
EVAL_BS="${EVAL_BS:-1}"
LORA_R="${LORA_R:-16}"
LORA_ALPHA="${LORA_ALPHA:-32}"
LORA_DROPOUT="${LORA_DROPOUT:-0.05}"
SAVE_STEPS="${SAVE_STEPS:-50}"
EVAL_STEPS="${EVAL_STEPS:-50}"
LOGGING_STEPS="${LOGGING_STEPS:-5}"
WARMUP_RATIO="${WARMUP_RATIO:-0.03}"
WEIGHT_DECAY="${WEIGHT_DECAY:-0.0}"
ATTN_IMPL="${ATTN_IMPL:-sdpa}"
OPTIM_NAME="${OPTIM_NAME:-paged_adamw_8bit}"

# Turn the model name into a directory-safe tag: '/' and ':' become '_'.
SAFE_MODEL_TAG="${MODEL_NAME//\//_}"
SAFE_MODEL_TAG="${SAFE_MODEL_TAG//:/_}"
OUTPUT_DIR="${OUTPUT_DIR:-outputs/${SAFE_MODEL_TAG}_full_log_sft}"

export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
export TOKENIZERS_PARALLELISM="${TOKENIZERS_PARALLELISM:-false}"

# Prepend the conda environment's library directories to the loader path.
# Outside a conda environment CONDA_PREFIX is unset; reading it directly would
# abort under `set -u`, so warn and leave LD_LIBRARY_PATH as it is instead.
if [[ -n "${CONDA_PREFIX:-}" ]]; then
  export LD_LIBRARY_PATH="$CONDA_PREFIX/lib:$CONDA_PREFIX/targets/x86_64-linux/lib:$CONDA_PREFIX/lib/python3.12/site-packages/nvidia/cu13/lib:${LD_LIBRARY_PATH:-}"
else
  printf 'warning: CONDA_PREFIX is not set; leaving LD_LIBRARY_PATH unchanged\n' >&2
fi

printf '%s\n' \
  "GPU_COUNT=$GPU_COUNT CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES NPROC=$NPROC" \
  "MODEL_NAME=$MODEL_NAME" \
  "MAX_LENGTH=$MAX_LENGTH MAX_PROMPT_LENGTH=$MAX_PROMPT_LENGTH MAX_COMPLETION_LENGTH=$MAX_COMPLETION_LENGTH" \
  "OUTPUT_DIR=$OUTPUT_DIR"

# Assemble the full command as an array so every value stays a single word.
TRAIN_CMD=(
  torchrun --nproc_per_node="$NPROC" scripts/train_sft.py
  --model-name "$MODEL_NAME"
  --train-file "$TRAIN_FILE"
  --dev-file "$DEV_FILE"
  --output-dir "$OUTPUT_DIR"
  --learning-rate "$LEARNING_RATE"
  --num-train-epochs "$NUM_EPOCHS"
  --per-device-train-batch-size "$TRAIN_BS"
  --per-device-eval-batch-size "$EVAL_BS"
  --gradient-accumulation-steps "$GRAD_ACC"
  --max-length "$MAX_LENGTH"
  --max-prompt-length "$MAX_PROMPT_LENGTH"
  --max-completion-length "$MAX_COMPLETION_LENGTH"
  --eval-steps "$EVAL_STEPS"
  --save-steps "$SAVE_STEPS"
  --logging-steps "$LOGGING_STEPS"
  --warmup-ratio "$WARMUP_RATIO"
  --weight-decay "$WEIGHT_DECAY"
  --lora-r "$LORA_R"
  --lora-alpha "$LORA_ALPHA"
  --lora-dropout "$LORA_DROPOUT"
  --optim "$OPTIM_NAME"
  --attn-implementation "$ATTN_IMPL"
  --bf16
)

if [[ -n "$RESUME_FROM_CHECKPOINT" ]]; then
  TRAIN_CMD+=(--resume-from-checkpoint "$RESUME_FROM_CHECKPOINT")
fi

# The ${arr[@]+...} guard keeps an empty EXTRA_ARGS from tripping `set -u` on
# Bash releases older than 4.4, where expanding an empty array is an error.
TRAIN_CMD+=(${EXTRA_ARGS[@]+"${EXTRA_ARGS[@]}"})

"${TRAIN_CMD[@]}"