mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 22:50:18 +00:00
Added support for model suffix and added migrations for new OpenAI SDK (#169)
This commit is contained in:
+32
-23
@@ -7,6 +7,8 @@ from rich.live import Live
|
||||
from rich.table import Table
|
||||
from rich.console import Console
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
from openai.types.fine_tuning import FineTuningJob
|
||||
|
||||
client = OpenAI()
|
||||
app = typer.Typer()
|
||||
@@ -15,7 +17,8 @@ console = Console()
|
||||
|
||||
def generate_table(jobs):
|
||||
# Sorting the jobs by creation time
|
||||
jobs = sorted(jobs, key=lambda x: x["created_at"], reverse=True)
|
||||
jobs = sorted(jobs, key=lambda x: (cast(FineTuningJob, x)).created_at, reverse=True)
|
||||
jobs = cast(List[FineTuningJob], jobs)
|
||||
|
||||
table = Table(
|
||||
title="OpenAI Fine Tuning Job Monitoring",
|
||||
@@ -37,23 +40,21 @@ def generate_table(jobs):
|
||||
"succeeded": "✅",
|
||||
"failed": "❌",
|
||||
"cancelled": "🚫",
|
||||
}.get(job["status"], "❓")
|
||||
}.get(job.status, "❓")
|
||||
|
||||
finished_at = (
|
||||
str(datetime.fromtimestamp(job["finished_at"]))
|
||||
if job["finished_at"]
|
||||
else "N/A"
|
||||
str(datetime.fromtimestamp(job.finished_at)) if job.finished_at else "N/A"
|
||||
)
|
||||
|
||||
table.add_row(
|
||||
job["id"],
|
||||
f"{status_emoji} [{status_color(job['status'])}]{job['status']}[/]",
|
||||
str(datetime.fromtimestamp(job["created_at"])),
|
||||
job.id,
|
||||
f"{status_emoji} [{status_color(job.status)}]{job.status}[/]",
|
||||
str(datetime.fromtimestamp(job.created_at)),
|
||||
finished_at,
|
||||
job["fine_tuned_model"],
|
||||
job["training_file"],
|
||||
str(job["hyperparameters"]["n_epochs"]),
|
||||
job["model"],
|
||||
job.fine_tuned_model,
|
||||
job.training_file,
|
||||
str(job.hyperparameters.n_epochs),
|
||||
job.model,
|
||||
)
|
||||
|
||||
return table
|
||||
@@ -66,12 +67,12 @@ def status_color(status: str) -> str:
|
||||
|
||||
|
||||
def get_jobs(limit: int = 5) -> List:
|
||||
return client.fine_tuning.list(limit=limit)["data"]
|
||||
return client.fine_tuning.jobs.list(limit=limit).data
|
||||
|
||||
|
||||
def get_file_status(file_id: str) -> str:
|
||||
response = client.files.retrieve(file_id)
|
||||
return response["status"]
|
||||
return response.status
|
||||
|
||||
|
||||
@app.command(
|
||||
@@ -124,7 +125,7 @@ def create_from_id(
|
||||
with console.status(
|
||||
f"[bold green]Creating fine-tuning job from ID {id}...", spinner="dots"
|
||||
):
|
||||
job = client.fine_tuning.create(
|
||||
job = client.fine_tuning.jobs.create(
|
||||
training_file=id,
|
||||
model=model,
|
||||
hyperparameters=hyperparameters_dict if hyperparameters_dict else None,
|
||||
@@ -151,6 +152,7 @@ def create_from_file(
|
||||
None, help="Learning rate multiplier for fine-tuning", show_default=False
|
||||
),
|
||||
validation_file: str = typer.Option(None, help="Path to the validation file"),
|
||||
model_suffix: str = typer.Option(None, help="Suffix to identify the model"),
|
||||
):
|
||||
hyperparameters_dict = {}
|
||||
if n_epochs is not None:
|
||||
@@ -163,13 +165,13 @@ def create_from_file(
|
||||
with open(file, "rb") as file:
|
||||
response = client.files.create(file=file, purpose="fine-tune")
|
||||
|
||||
file_id = response["id"]
|
||||
file_id = response.id
|
||||
|
||||
validation_file_id = None
|
||||
if validation_file:
|
||||
with open(validation_file, "rb") as val_file:
|
||||
val_response = client.files.create(file=val_file, purpose="fine-tune")
|
||||
validation_file_id = val_response["id"]
|
||||
validation_file_id = val_response.id
|
||||
|
||||
with console.status(f"Monitoring upload: {file_id} before finetuning...") as status:
|
||||
status.spinner_style = "dots"
|
||||
@@ -190,19 +192,26 @@ def create_from_file(
|
||||
|
||||
time.sleep(poll)
|
||||
|
||||
job = client.fine_tuning.create(
|
||||
additional_params = {}
|
||||
if hyperparameters_dict:
|
||||
additional_params["hyperparameters"] = hyperparameters_dict
|
||||
if validation_file:
|
||||
additional_params["validation_file"] = validation_file
|
||||
if model_suffix:
|
||||
additional_params["suffix"] = model_suffix
|
||||
|
||||
job = client.fine_tuning.jobs.create(
|
||||
training_file=file_id,
|
||||
model=model,
|
||||
hyperparameters=hyperparameters_dict if hyperparameters_dict else None,
|
||||
validation_file=validation_file_id if validation_file else None,
|
||||
**additional_params,
|
||||
)
|
||||
if validation_file_id:
|
||||
console.log(
|
||||
f"[bold green]Fine-tuning job created with ID: {job['id']} from file ID: {file_id} and validation_file ID: {validation_file_id}"
|
||||
f"[bold green]Fine-tuning job created with ID: {job.id} from file ID: {file_id} and validation_file ID: {validation_file_id}"
|
||||
)
|
||||
else:
|
||||
console.log(
|
||||
f"[bold green]Fine-tuning job created with ID: {job['id']} from file ID: {file_id}"
|
||||
f"[bold green]Fine-tuning job created with ID: {job.id} from file ID: {file_id}"
|
||||
)
|
||||
watch(limit=5, poll=poll, screen=False)
|
||||
|
||||
@@ -213,7 +222,7 @@ def create_from_file(
|
||||
def cancel(id: str = typer.Argument(..., help="ID of the fine-tuning job to cancel")):
|
||||
with console.status(f"[bold red]Cancelling job {id}...", spinner="dots"):
|
||||
try:
|
||||
client.fine_tuning.cancel(id)
|
||||
client.fine_tuning.jobs.cancel(id)
|
||||
console.log(f"[bold red]Job {id} cancelled successfully!")
|
||||
except Exception as e:
|
||||
console.log(f"[bold red]Error cancelling job {id}: {e}")
|
||||
|
||||
Reference in New Issue
Block a user