[rl] Using JobConfig as the centralized config system for inference and simple GRPO #2191
base: gh/wwwjn/2/base
Conversation
allenwang28 left a comment:
I like this direction, thanks! Mostly nits here
```diff
-Right now we only support VLLM_COMPAT mode, which could achieve trainer and generator bitwise identical. We are working on support UNIFIED mode,
-which uses a unified model definition for trainer and generator.
+We uses a unified model definition for trainer and generator, which could achieve trainer and generator bitwise identical.
```
Suggested change:
```diff
-We uses a unified model definition for trainer and generator, which could achieve trainer and generator bitwise identical.
+We use a unified model definition for the trainer and generator, ensuring bitwise-identical models to address a class of subtle correctness bugs in RL for LLMs.
```
nit
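For readers skimming the thread, a minimal, hypothetical sketch of how the two generator modes described above might be expressed in a centralized JobConfig; the `GeneratorMode` enum and the `Generation` field names are illustrative assumptions, not the actual definitions in this PR.

```python
from dataclasses import dataclass
from enum import Enum


class GeneratorMode(Enum):
    # Reuse vLLM's model definition; per the docs above, this is the mode that
    # currently achieves trainer/generator bitwise-identical behavior.
    VLLM_COMPAT = "vllm_compat"
    # Share a single unified model definition between trainer and generator
    # (described above as work in progress).
    UNIFIED = "unified"


@dataclass
class Generation:
    # Hypothetical config section: only VLLM_COMPAT is supported today.
    mode: GeneratorMode = GeneratorMode.VLLM_COMPAT
    temperature: float = 1.0
    max_new_tokens: int = 512
```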
```diff
-    math_reward_function if use_real_dataset else trivial_reward_function
-)
+# Reward function. TODO: Use a real reward function
+self.reward_fn = trivial_reward_function
```
I guess the RL job definition would need to define a callable here, is this the idea for later?
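One rough sketch of how that could look later: the RL job definition owns the reward callable, so swapping `trivial_reward_function` for `math_reward_function` becomes a config change rather than a code edit. The `RLJobSpec` name and its field are assumptions for illustration, not code from this PR.

```python
from dataclasses import dataclass
from typing import Callable


def trivial_reward_function(prompt: str, response: str) -> float:
    # Placeholder reward: every response gets the same score.
    return 0.0


@dataclass
class RLJobSpec:
    # Hypothetical job definition: it carries the reward callable, and the
    # trainer loop just invokes whatever was configured.
    reward_fn: Callable[[str, str], float] = trivial_reward_function


def compute_reward(job: RLJobSpec, prompt: str, response: str) -> float:
    return job.reward_fn(prompt, response)
```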
```python
sampling_params = SamplingParams(
    temperature=temperature,
    max_tokens=max_new_tokens,
    n=n_samples_per_prompt,
```
n_samples_per_prompt fulfills the same purpose as group_size I assume. I see below that we're preferring to submit a prompt multiple times instead of relying on vLLM. Is this due to batch invariance or something else? I'd assume that letting vLLM do it is better performance wise
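To make the trade-off concrete, a small sketch of both options against plain vLLM (model name and parameters are placeholders): option A uses `SamplingParams(n=group_size)` so vLLM fans one request out into a group, while option B submits the same prompt `group_size` times with `n=1`, which is what the comment above observes the PR doing.

```python
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen3-0.6B")  # example model
prompt = "What is 2 + 2?"
group_size = 4  # plays the role of n_samples_per_prompt / group_size

# Option A: one request, vLLM produces n completions; the prompt is prefilled
# once and shared across the group, which is usually the faster path.
fan_out = SamplingParams(temperature=1.0, max_tokens=64, n=group_size)
group_a = [o.text for o in llm.generate([prompt], fan_out)[0].outputs]

# Option B: repeat the prompt and keep n=1; each copy is an independent
# request, which can simplify per-sample bookkeeping (and reasoning about
# batch invariance) at the cost of redundant prefill work.
one_each = SamplingParams(temperature=1.0, max_tokens=64, n=1)
group_b = [out.outputs[0].text for out in llm.generate([prompt] * group_size, one_each)]
```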
```diff
-    type=int,
-    default=1,
-    help="Number of GPUs for tensor parallelism (default: 1 for single GPU)",
+def infer():
```
Suggested change:
```diff
-def infer():
+def generate():
```
nit, but infer() makes me think of like getting the logits
```diff
 # Create process meshes
-trainer_mesh = this_host().spawn_procs(per_host={"gpus": 2})
+trainer_mesh = this_host().spawn_procs(
+    per_host={"gpus": trainer_ddp_size * trainer_tp_size}
```
in the future we should deduce the total number of GPUs needed for a given trainer parallelism
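As a rough sketch of that deduction, assuming torchtitan-style parallelism fields (the exact names may differ from this PR's schema): take the product of the configured degrees and hand the result to the proc mesh, instead of threading individual sizes through.

```python
from dataclasses import dataclass


@dataclass
class Parallelism:
    # Field names follow common torchtitan-style conventions; treat them as
    # assumptions rather than this PR's exact schema.
    data_parallel_replicate_degree: int = 1
    data_parallel_shard_degree: int = 1
    tensor_parallel_degree: int = 1
    pipeline_parallel_degree: int = 1
    context_parallel_degree: int = 1

    def world_size(self) -> int:
        # Total trainer GPUs = product of all parallel degrees.
        return (
            self.data_parallel_replicate_degree
            * self.data_parallel_shard_degree
            * self.tensor_parallel_degree
            * self.pipeline_parallel_degree
            * self.context_parallel_degree
        )


# e.g. trainer_mesh = this_host().spawn_procs(per_host={"gpus": parallelism.world_size()})
```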
Stack from ghstack (oldest at bottom):
- trainer's config and generator's config are not symmetric, e.g. `Parallelism` vs. `Generation.parallelism` (see the sketch below the test screenshot)
- Add a `run_configs/qwen3_0.6b.toml` file.
- Test: (trainer ddp = 2, n_generator = 1)

Follow-up refactors: