#!/bin/bash
#SBATCH --job-name=vllm-serve
#SBATCH --nodes=2
#SBATCH --gpus=8
#SBATCH --time=4:00:00
#SBATCH --exclusive
#SBATCH --output=out/%x.%j.out

# Load the NCCL module on top of a clean module environment and record what
# ended up loaded in the job log.
module reset
module load brics/nccl
module list

# Activate the project virtualenv (path is relative to the job's working
# directory). The srun steps below inherit this environment and invoke `ray`
# from it, so stop here with a clear message instead of failing later inside
# every step with "command not found".
source .venv/bin/activate || {
    echo "ERROR: failed to activate .venv (cwd: $PWD)" >&2
    exit 1
}

echo "SERVING ON $HOSTNAME"

# Address of this host's "-hsn0" interface name.
# NOTE(review): SERVER_ADDRESS is not referenced elsewhere in the visible
# script, and it resolves "<host>-hsn0" while HEAD_NODE_IP below resolves the
# bare hostname -- confirm which interface Ray/vLLM traffic is meant to use.
SERVER_ADDRESS=$(dig +short "${HOSTNAME}-hsn0")

# Expand the compact Slurm node list into one hostname per line: the first
# node hosts the Ray head, the remaining ones become Ray workers.
: "${SLURM_NODELIST:?must be run inside a Slurm allocation}"
HEAD_NODE=$(scontrol show hostnames "$SLURM_NODELIST" | head -n1)
WORKER_NODES=$(scontrol show hostnames "$SLURM_NODELIST" | tail -n+2)

# Every later step passes this IP to `ray start`; an empty value would only
# surface as an obscure Ray error after the startup sleeps, so check it now.
HEAD_NODE_IP=$(dig +short "$HEAD_NODE")
if [[ -z "$HEAD_NODE_IP" ]]; then
    echo "ERROR: could not resolve an IP address for head node '$HEAD_NODE'" >&2
    exit 1
fi
RAY_PORT=6378
RAY_ADDRESS=$HEAD_NODE_IP:$RAY_PORT

# Exported so every srun step below inherits them.
# Shared location for tiktoken encoding files (presumably so tokenizers load
# without network access -- confirm).
export TIKTOKEN_ENCODINGS_BASE="/projects/public/brics/distributed_vllm/etc/encodings"
export VLLM_LOGGING_LEVEL=DEBUG
# Disable vLLM's symmetric-memory all-reduce / NCCL paths (per the variable
# names; confirm against the vLLM version in use).
export VLLM_ALLREDUCE_USE_SYMM_MEM=0
export VLLM_NCCL_USE_SYMM_MEM=0

# Start the Ray head on the first node as a background srun step.
# `ray start --block` keeps the step (and thus the Ray head) running until it
# is stopped or the job ends.
echo "Starting head node $HEAD_NODE..."
srun \
    --nodelist "$HEAD_NODE" \
    --nodes=1 \
    --gpus=4 \
    --cpus-per-task 72 \
    --ntasks-per-node 1 \
    env VLLM_HOST_IP="$HEAD_NODE_IP" \
        ray start --block --head --node-ip-address="$HEAD_NODE_IP" --port="$RAY_PORT" &
HEAD_PID=$!

# Fixed grace period for the head to come up before workers try to join.
sleep 20

# If the head step already died (bad environment, port in use, ...), workers
# have nothing to join; abort with a clear message instead.
if ! kill -0 "$HEAD_PID" 2>/dev/null; then
    echo "ERROR: Ray head step on $HEAD_NODE exited during startup" >&2
    exit 1
fi

# Join every remaining node to the head as a Ray worker: one background srun
# step per node, followed by a fixed grace period for them to register.
echo "Starting worker nodes..."
# WORKER_NODES is newline-separated hostnames; the unquoted expansion is the
# intended word split (hostnames contain no spaces or glob characters).
for WORKER in $WORKER_NODES; do
    WORKER_IP=$(dig +short "$WORKER")
    if [[ -z "$WORKER_IP" ]]; then
        # Exiting ends the batch script; Slurm then tears down the steps
        # already started rather than holding a partial cluster.
        echo "ERROR: could not resolve an IP address for worker node '$WORKER'" >&2
        exit 1
    fi
    echo "Starting worker node: $WORKER with IP $WORKER_IP"

    srun \
        --nodelist "$WORKER" \
        --nodes=1 \
        --gpus=4 \
        --cpus-per-task 72 \
        --ntasks-per-node 1 \
        env VLLM_HOST_IP="$WORKER_IP" \
            ray start --block --address="$RAY_ADDRESS" --node-ip-address="$WORKER_IP" &
done
sleep 20

# One-shot cluster report from the head node. --overlap lets this step share
# the node's resources with the still-running Ray head step. The address is
# passed explicitly because the head listens on a non-default port.
echo "Checking cluster status..."
srun \
    --overlap \
    --nodelist "$HEAD_NODE" \
    --nodes=1 \
    --gpus=4 \
    --ntasks-per-node 1 \
    ray status --address="$RAY_ADDRESS"

# Block on the background `ray start --block` steps so the allocation stays
# up until they exit (or the job reaches its time limit).
wait
