[V1] Avoid sending text prompt to core engine (#11963)

Signed-off-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
Roger Wang 2025-01-11 22:36:38 -08:00 committed by GitHub
parent 4b657d3292
commit b25cfab9a0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 2 deletions

View File

@ -19,8 +19,8 @@ class EngineCoreRequest:
# due to circular imports and typing we have in data.py # due to circular imports and typing we have in data.py
request_id: str request_id: str
#NOTE(Nick): I don't think we need to pass prompt here since it should # NOTE(ywang96): original text prompt is needed when a request is added to
# always be tokenized? # Detokenizer, but set to None when it is added to EngineCoreClient.
prompt: Optional[str] prompt: Optional[str]
prompt_token_ids: List[int] prompt_token_ids: List[int]
mm_inputs: Optional[List[Optional["MultiModalKwargs"]]] mm_inputs: Optional[List[Optional["MultiModalKwargs"]]]

View File

@ -219,6 +219,9 @@ class SyncMPClient(MPClient):
self.input_socket.send_multipart(msg, copy=False) self.input_socket.send_multipart(msg, copy=False)
def add_request(self, request: EngineCoreRequest) -> None: def add_request(self, request: EngineCoreRequest) -> None:
# NOTE: text prompt is not needed in the core engine as it has been
# tokenized.
request.prompt = None
self._send_input(EngineCoreRequestType.ADD, request) self._send_input(EngineCoreRequestType.ADD, request)
def abort_requests(self, request_ids: List[str]) -> None: def abort_requests(self, request_ids: List[str]) -> None:
@ -257,6 +260,9 @@ class AsyncMPClient(MPClient):
await self.input_socket.send_multipart(msg, copy=False) await self.input_socket.send_multipart(msg, copy=False)
async def add_request_async(self, request: EngineCoreRequest) -> None: async def add_request_async(self, request: EngineCoreRequest) -> None:
# NOTE: text prompt is not needed in the core engine as it has been
# tokenized.
request.prompt = None
await self._send_input(EngineCoreRequestType.ADD, request) await self._send_input(EngineCoreRequestType.ADD, request)
async def abort_requests_async(self, request_ids: List[str]) -> None: async def abort_requests_async(self, request_ids: List[str]) -> None: