diff --git a/instructor/dsl/multitask.py b/instructor/dsl/multitask.py index 727936d..e03da8c 100644 --- a/instructor/dsl/multitask.py +++ b/instructor/dsl/multitask.py @@ -56,19 +56,20 @@ class MultiTaskBase: def extract_json(completion, mode: Mode): for chunk in completion: try: - if mode == Mode.FUNCTIONS: - if json_chunk := chunk.choices[0].delta.function_call.arguments: - yield json_chunk - elif mode in {Mode.JSON, Mode.MD_JSON, Mode.JSON_SCHEMA}: - if json_chunk := chunk.choices[0].delta.content: - yield json_chunk - elif mode == Mode.TOOLS: - if json_chunk := chunk.choices[0].delta.tool_calls: - yield json_chunk[0].function.arguments - else: - raise NotImplementedError( - f"Mode {mode} is not supported for MultiTask streaming" - ) + if chunk.choices: + if mode == Mode.FUNCTIONS: + if json_chunk := chunk.choices[0].delta.function_call.arguments: + yield json_chunk + elif mode in {Mode.JSON, Mode.MD_JSON, Mode.JSON_SCHEMA}: + if json_chunk := chunk.choices[0].delta.content: + yield json_chunk + elif mode == Mode.TOOLS: + if json_chunk := chunk.choices[0].delta.tool_calls: + yield json_chunk[0].function.arguments + else: + raise NotImplementedError( + f"Mode {mode} is not supported for MultiTask streaming" + ) except AttributeError: pass @@ -76,19 +77,20 @@ class MultiTaskBase: async def extract_json_async(completion, mode: Mode): async for chunk in completion: try: - if mode == Mode.FUNCTIONS: - if json_chunk := chunk.choices[0].delta.function_call.arguments: - yield json_chunk - elif mode in {Mode.JSON, Mode.MD_JSON, Mode.JSON_SCHEMA}: - if json_chunk := chunk.choices[0].delta.content: - yield json_chunk - elif mode == Mode.TOOLS: - if json_chunk := chunk.choices[0].delta.tool_calls: - yield json_chunk[0].function.arguments - else: - raise NotImplementedError( - f"Mode {mode} is not supported for MultiTask streaming" - ) + if chunk.choices: + if mode == Mode.FUNCTIONS: + if json_chunk := chunk.choices[0].delta.function_call.arguments: + yield json_chunk + elif mode in {Mode.JSON, Mode.MD_JSON, Mode.JSON_SCHEMA}: + if json_chunk := chunk.choices[0].delta.content: + yield json_chunk + elif mode == Mode.TOOLS: + if json_chunk := chunk.choices[0].delta.tool_calls: + yield json_chunk[0].function.arguments + else: + raise NotImplementedError( + f"Mode {mode} is not supported for MultiTask streaming" + ) except AttributeError: pass