From fac2171162f141826c2ca6964cb0ca6ab10649a5 Mon Sep 17 00:00:00 2001 From: Ivan Leo Date: Sun, 12 Nov 2023 23:01:46 +0800 Subject: [PATCH] Added support for model suffix and added migrations for new OpenAI SDK (#169) --- instructor/cli/jobs.py | 55 ++++++++++++++++++++++++------------------ 1 file changed, 32 insertions(+), 23 deletions(-) diff --git a/instructor/cli/jobs.py b/instructor/cli/jobs.py index d378cd5..a289173 100644 --- a/instructor/cli/jobs.py +++ b/instructor/cli/jobs.py @@ -7,6 +7,8 @@ from rich.live import Live from rich.table import Table from rich.console import Console from datetime import datetime +from typing import cast +from openai.types.fine_tuning import FineTuningJob client = OpenAI() app = typer.Typer() @@ -15,7 +17,8 @@ console = Console() def generate_table(jobs): # Sorting the jobs by creation time - jobs = sorted(jobs, key=lambda x: x["created_at"], reverse=True) + jobs = sorted(jobs, key=lambda x: (cast(FineTuningJob, x)).created_at, reverse=True) + jobs = cast(List[FineTuningJob], jobs) table = Table( title="OpenAI Fine Tuning Job Monitoring", @@ -37,23 +40,21 @@ def generate_table(jobs): "succeeded": "✅", "failed": "❌", "cancelled": "🚫", - }.get(job["status"], "❓") + }.get(job.status, "❓") finished_at = ( - str(datetime.fromtimestamp(job["finished_at"])) - if job["finished_at"] - else "N/A" + str(datetime.fromtimestamp(job.finished_at)) if job.finished_at else "N/A" ) table.add_row( - job["id"], - f"{status_emoji} [{status_color(job['status'])}]{job['status']}[/]", - str(datetime.fromtimestamp(job["created_at"])), + job.id, + f"{status_emoji} [{status_color(job.status)}]{job.status}[/]", + str(datetime.fromtimestamp(job.created_at)), finished_at, - job["fine_tuned_model"], - job["training_file"], - str(job["hyperparameters"]["n_epochs"]), - job["model"], + job.fine_tuned_model, + job.training_file, + str(job.hyperparameters.n_epochs), + job.model, ) return table @@ -66,12 +67,12 @@ def status_color(status: str) -> str: def get_jobs(limit: int = 5) -> List: - return client.fine_tuning.list(limit=limit)["data"] + return client.fine_tuning.jobs.list(limit=limit).data def get_file_status(file_id: str) -> str: response = client.files.retrieve(file_id) - return response["status"] + return response.status @app.command( @@ -124,7 +125,7 @@ def create_from_id( with console.status( f"[bold green]Creating fine-tuning job from ID {id}...", spinner="dots" ): - job = client.fine_tuning.create( + job = client.fine_tuning.jobs.create( training_file=id, model=model, hyperparameters=hyperparameters_dict if hyperparameters_dict else None, @@ -151,6 +152,7 @@ def create_from_file( None, help="Learning rate multiplier for fine-tuning", show_default=False ), validation_file: str = typer.Option(None, help="Path to the validation file"), + model_suffix: str = typer.Option(None, help="Suffix to identify the model"), ): hyperparameters_dict = {} if n_epochs is not None: @@ -163,13 +165,13 @@ def create_from_file( with open(file, "rb") as file: response = client.files.create(file=file, purpose="fine-tune") - file_id = response["id"] + file_id = response.id validation_file_id = None if validation_file: with open(validation_file, "rb") as val_file: val_response = client.files.create(file=val_file, purpose="fine-tune") - validation_file_id = val_response["id"] + validation_file_id = val_response.id with console.status(f"Monitoring upload: {file_id} before finetuning...") as status: status.spinner_style = "dots" @@ -190,19 +192,26 @@ def create_from_file( time.sleep(poll) - job = client.fine_tuning.create( + additional_params = {} + if hyperparameters_dict: + additional_params["hyperparameters"] = hyperparameters_dict + if validation_file: + additional_params["validation_file"] = validation_file + if model_suffix: + additional_params["suffix"] = model_suffix + + job = client.fine_tuning.jobs.create( training_file=file_id, model=model, - hyperparameters=hyperparameters_dict if hyperparameters_dict else None, - validation_file=validation_file_id if validation_file else None, + **additional_params, ) if validation_file_id: console.log( - f"[bold green]Fine-tuning job created with ID: {job['id']} from file ID: {file_id} and validation_file ID: {validation_file_id}" + f"[bold green]Fine-tuning job created with ID: {job.id} from file ID: {file_id} and validation_file ID: {validation_file_id}" ) else: console.log( - f"[bold green]Fine-tuning job created with ID: {job['id']} from file ID: {file_id}" + f"[bold green]Fine-tuning job created with ID: {job.id} from file ID: {file_id}" ) watch(limit=5, poll=poll, screen=False) @@ -213,7 +222,7 @@ def create_from_file( def cancel(id: str = typer.Argument(..., help="ID of the fine-tuning job to cancel")): with console.status(f"[bold red]Cancelling job {id}...", spinner="dots"): try: - client.fine_tuning.cancel(id) + client.fine_tuning.jobs.cancel(id) console.log(f"[bold red]Job {id} cancelled successfully!") except Exception as e: console.log(f"[bold red]Error cancelling job {id}: {e}")