Make it easy to profile workers with nsight (#3162)

Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Authored by Philipp Moritz on 2024-03-03 16:19:13 -08:00; committed by GitHub
parent 996d095c54
commit 17c3103c56
4 changed files with 34 additions and 2 deletions


@@ -26,6 +26,7 @@ def main(args: argparse.Namespace):
         enforce_eager=args.enforce_eager,
         kv_cache_dtype=args.kv_cache_dtype,
         device=args.device,
+        ray_workers_use_nsight=args.ray_workers_use_nsight,
     )
 
     sampling_params = SamplingParams(
@@ -145,5 +146,10 @@ if __name__ == '__main__':
         default="cuda",
         choices=["cuda"],
         help='device type for vLLM execution, supporting CUDA only currently.')
+    parser.add_argument(
+        "--ray-workers-use-nsight",
+        action='store_true',
+        help="If specified, use nsight to profile ray workers",
+    )
     args = parser.parse_args()
     main(args)
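
For reference, a minimal sketch of how the new flag is consumed from the offline Python API (the model name is a placeholder; this assumes the LLM constructor forwards the keyword argument to EngineArgs, just as the script change above passes args.ray_workers_use_nsight into LLM(...)):

from vllm import LLM, SamplingParams

# Nsight profiling only applies when Ray manages the workers (see the
# ParallelConfig check below), so worker_use_ray is enabled explicitly.
llm = LLM(
    model="facebook/opt-125m",      # placeholder model for illustration
    tensor_parallel_size=2,
    worker_use_ray=True,
    ray_workers_use_nsight=True,
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, max_tokens=16))
print(outputs[0].outputs[0].text)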


@@ -382,6 +382,8 @@ class ParallelConfig:
             parallel and large models.
         disable_custom_all_reduce: Disable the custom all-reduce kernel and
             fall back to NCCL.
+        ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
+            https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
     """
 
     def __init__(
@@ -391,6 +393,7 @@ class ParallelConfig:
         worker_use_ray: bool,
         max_parallel_loading_workers: Optional[int] = None,
         disable_custom_all_reduce: bool = False,
+        ray_workers_use_nsight: bool = False,
     ) -> None:
         self.pipeline_parallel_size = pipeline_parallel_size
         if is_neuron():
@@ -404,6 +407,7 @@ class ParallelConfig:
         self.worker_use_ray = worker_use_ray
         self.max_parallel_loading_workers = max_parallel_loading_workers
         self.disable_custom_all_reduce = disable_custom_all_reduce
+        self.ray_workers_use_nsight = ray_workers_use_nsight
         self.world_size = pipeline_parallel_size * self.tensor_parallel_size
 
         # Ray worker is not supported for Neuron backend.
@@ -426,6 +430,9 @@ class ParallelConfig:
             logger.info(
                 "Disabled the custom all-reduce kernel because it is not "
                 "supported with pipeline parallelism.")
+        if self.ray_workers_use_nsight and not self.worker_use_ray:
+            raise ValueError("Unable to use nsight profiling unless workers "
+                             "run with Ray.")
 
         # FIXME(woosuk): Fix the stability issues and re-enable the custom
         # all-reduce kernel.
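
A small sketch of the new validation behavior (assuming ParallelConfig is importable from vllm.config as in this tree): requesting nsight profiling without Ray-managed workers is rejected by the check added above.

from vllm.config import ParallelConfig

# Without Ray workers there is no worker process for Ray to launch under
# nsys, so the config raises the ValueError introduced in this commit.
try:
    ParallelConfig(
        pipeline_parallel_size=1,
        tensor_parallel_size=1,
        worker_use_ray=False,
        ray_workers_use_nsight=True,
    )
except ValueError as err:
    print(err)  # Unable to use nsight profiling unless workers run with Ray.

# With Ray workers enabled the flag is accepted and simply stored for the
# engine to translate into a Ray runtime_env later.
config = ParallelConfig(
    pipeline_parallel_size=1,
    tensor_parallel_size=1,
    worker_use_ray=True,
    ray_workers_use_nsight=True,
)
print(config.ray_workers_use_nsight)  # True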


@@ -46,6 +46,7 @@ class EngineArgs:
     lora_dtype = 'auto'
     max_cpu_loras: Optional[int] = None
     device: str = 'auto'
+    ray_workers_use_nsight: bool = False
 
     def __post_init__(self):
         if self.tokenizer is None:
@@ -168,6 +169,10 @@ class EngineArgs:
                             help='load model sequentially in multiple batches, '
                             'to avoid RAM OOM when using tensor '
                             'parallel and large models')
+        parser.add_argument(
+            '--ray-workers-use-nsight',
+            action='store_true',
+            help='If specified, use nsight to profile ray workers')
         # KV cache arguments
         parser.add_argument('--block-size',
                             type=int,
@@ -305,7 +310,8 @@ class EngineArgs:
                                          self.tensor_parallel_size,
                                          self.worker_use_ray,
                                          self.max_parallel_loading_workers,
-                                         self.disable_custom_all_reduce)
+                                         self.disable_custom_all_reduce,
+                                         self.ray_workers_use_nsight)
         scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
                                            self.max_num_seqs,
                                            model_config.max_model_len,
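
A short sketch of how the flag travels through the CLI plumbing (assuming the usual EngineArgs.add_cli_args / from_cli_args helpers): the argparse option sets the dataclass field, which the engine-config construction shown above then passes into ParallelConfig.

import argparse

from vllm.engine.arg_utils import EngineArgs

parser = argparse.ArgumentParser()
parser = EngineArgs.add_cli_args(parser)   # now registers --ray-workers-use-nsight
args = parser.parse_args([
    "--model", "facebook/opt-125m",        # placeholder model for illustration
    "--tensor-parallel-size", "2",
    "--worker-use-ray",
    "--ray-workers-use-nsight",
])
engine_args = EngineArgs.from_cli_args(args)
print(engine_args.ray_workers_use_nsight)  # True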


@@ -124,7 +124,20 @@ class LLMEngine:
             ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
             if ray_usage != "1":
                 os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
-            self._init_workers_ray(placement_group)
+            # Pass additional arguments to initialize the worker
+            additional_ray_args = {}
+            if self.parallel_config.ray_workers_use_nsight:
+                logger.info("Configuring Ray workers to use nsight.")
+                additional_ray_args = {
+                    "runtime_env": {
+                        "nsight": {
+                            "t": "cuda,cudnn,cublas",
+                            "o": "'worker_process_%p'",
+                            "cuda-graph-trace": "node",
+                        }
+                    }
+                }
+            self._init_workers_ray(placement_group, **additional_ray_args)
         else:
             self._init_workers()
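
The options above are not vLLM-specific: _init_workers_ray forwards them to Ray, whose nsight runtime_env plugin launches each worker process under nsys profile. A standalone sketch of that Ray feature (assuming a Ray release recent enough to ship the plugin, an nsys binary on the worker nodes, and a GPU for the toy task; the reports should end up under Ray's session log directory):

import ray

ray.init()

# Each task or actor created with this runtime_env is run under
# `nsys profile` with the same options the engine builds above.
nsight_env = {
    "nsight": {
        "t": "cuda,cudnn,cublas",       # APIs to trace
        "o": "'worker_process_%p'",     # report name; %p expands to the worker PID
        "cuda-graph-trace": "node",     # capture CUDA graphs node by node
    }
}

@ray.remote(num_gpus=1, runtime_env=nsight_env)
def gpu_work():
    import torch
    return torch.ones(1, device="cuda").sum().item()

print(ray.get(gpu_work.remote()))

With vLLM, passing --ray-workers-use-nsight (together with --worker-use-ray) applies this runtime_env to every engine worker automatically, so each GPU worker process produces its own report.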