Added support for model suffix and added migrations for new OpenAI SDK (#169)

This commit is contained in:
Ivan Leo
2023-11-12 23:01:46 +08:00
committed by GitHub
parent 3d2035bb30
commit fac2171162
+32 -23
View File
@@ -7,6 +7,8 @@ from rich.live import Live
from rich.table import Table
from rich.console import Console
from datetime import datetime
from typing import cast
from openai.types.fine_tuning import FineTuningJob
client = OpenAI()
app = typer.Typer()
@@ -15,7 +17,8 @@ console = Console()
def generate_table(jobs):
# Sorting the jobs by creation time
jobs = sorted(jobs, key=lambda x: x["created_at"], reverse=True)
jobs = sorted(jobs, key=lambda x: (cast(FineTuningJob, x)).created_at, reverse=True)
jobs = cast(List[FineTuningJob], jobs)
table = Table(
title="OpenAI Fine Tuning Job Monitoring",
@@ -37,23 +40,21 @@ def generate_table(jobs):
"succeeded": "",
"failed": "",
"cancelled": "🚫",
}.get(job["status"], "")
}.get(job.status, "")
finished_at = (
str(datetime.fromtimestamp(job["finished_at"]))
if job["finished_at"]
else "N/A"
str(datetime.fromtimestamp(job.finished_at)) if job.finished_at else "N/A"
)
table.add_row(
job["id"],
f"{status_emoji} [{status_color(job['status'])}]{job['status']}[/]",
str(datetime.fromtimestamp(job["created_at"])),
job.id,
f"{status_emoji} [{status_color(job.status)}]{job.status}[/]",
str(datetime.fromtimestamp(job.created_at)),
finished_at,
job["fine_tuned_model"],
job["training_file"],
str(job["hyperparameters"]["n_epochs"]),
job["model"],
job.fine_tuned_model,
job.training_file,
str(job.hyperparameters.n_epochs),
job.model,
)
return table
@@ -66,12 +67,12 @@ def status_color(status: str) -> str:
def get_jobs(limit: int = 5) -> List:
return client.fine_tuning.list(limit=limit)["data"]
return client.fine_tuning.jobs.list(limit=limit).data
def get_file_status(file_id: str) -> str:
response = client.files.retrieve(file_id)
return response["status"]
return response.status
@app.command(
@@ -124,7 +125,7 @@ def create_from_id(
with console.status(
f"[bold green]Creating fine-tuning job from ID {id}...", spinner="dots"
):
job = client.fine_tuning.create(
job = client.fine_tuning.jobs.create(
training_file=id,
model=model,
hyperparameters=hyperparameters_dict if hyperparameters_dict else None,
@@ -151,6 +152,7 @@ def create_from_file(
None, help="Learning rate multiplier for fine-tuning", show_default=False
),
validation_file: str = typer.Option(None, help="Path to the validation file"),
model_suffix: str = typer.Option(None, help="Suffix to identify the model"),
):
hyperparameters_dict = {}
if n_epochs is not None:
@@ -163,13 +165,13 @@ def create_from_file(
with open(file, "rb") as file:
response = client.files.create(file=file, purpose="fine-tune")
file_id = response["id"]
file_id = response.id
validation_file_id = None
if validation_file:
with open(validation_file, "rb") as val_file:
val_response = client.files.create(file=val_file, purpose="fine-tune")
validation_file_id = val_response["id"]
validation_file_id = val_response.id
with console.status(f"Monitoring upload: {file_id} before finetuning...") as status:
status.spinner_style = "dots"
@@ -190,19 +192,26 @@ def create_from_file(
time.sleep(poll)
job = client.fine_tuning.create(
additional_params = {}
if hyperparameters_dict:
additional_params["hyperparameters"] = hyperparameters_dict
if validation_file:
additional_params["validation_file"] = validation_file
if model_suffix:
additional_params["suffix"] = model_suffix
job = client.fine_tuning.jobs.create(
training_file=file_id,
model=model,
hyperparameters=hyperparameters_dict if hyperparameters_dict else None,
validation_file=validation_file_id if validation_file else None,
**additional_params,
)
if validation_file_id:
console.log(
f"[bold green]Fine-tuning job created with ID: {job['id']} from file ID: {file_id} and validation_file ID: {validation_file_id}"
f"[bold green]Fine-tuning job created with ID: {job.id} from file ID: {file_id} and validation_file ID: {validation_file_id}"
)
else:
console.log(
f"[bold green]Fine-tuning job created with ID: {job['id']} from file ID: {file_id}"
f"[bold green]Fine-tuning job created with ID: {job.id} from file ID: {file_id}"
)
watch(limit=5, poll=poll, screen=False)
@@ -213,7 +222,7 @@ def create_from_file(
def cancel(id: str = typer.Argument(..., help="ID of the fine-tuning job to cancel")):
with console.status(f"[bold red]Cancelling job {id}...", spinner="dots"):
try:
client.fine_tuning.cancel(id)
client.fine_tuning.jobs.cancel(id)
console.log(f"[bold red]Job {id} cancelled successfully!")
except Exception as e:
console.log(f"[bold red]Error cancelling job {id}: {e}")