Training
Create training jobs and drive low-level LoRA training from Python.
The SDK exposes two layers of training control:
- High-level helpers that mirror `reflex training start`, return a job handle, and let you poll for completion.
- `LoraTrainingClient`, a step-level interface for forward/backward passes, optimizer steps, and adapter checkpoints.
High-level: create / poll / cancel
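```python
import time

import reflex

# Start a LoRA fine-tune of pi0.5 on a dataset referenced by ID.
job = reflex.create_training_job(
    dataset_id="ds_abc123",
    base_model="pi0.5",
    fine_tuning_type="lora",
    model_name="my-adapter",
    epochs=5,
)

# Poll every 10 seconds until the job reaches a terminal state.
while True:
    status = reflex.get_training_job(job["training_job_id"])
    if status["status"] in ("succeeded", "failed", "stopped"):
        break
    time.sleep(10)

print("final status:", status["status"])
```
Functions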
| Function | What it does |
|---|---|
| `create_training_job(*, dataset_id=None, hf_source_uri=None, base_model="pi0.5", fine_tuning_type="lora", ...)` | Create a training job. Pass exactly one of `dataset_id` or `hf_source_uri`. |
| `lora_finetune(...)` | Convenience wrapper for LoRA training. |
| `full_finetune(...)` | Convenience wrapper for full fine-tunes. |
| `get_training_job(training_job_id)` | Fetch the current status of a job. |
| `list_training_jobs(status=None)` | List jobs, optionally filtered by status. |
| `cancel_training_job(training_job_id)` | Stop a running job. |
All of the above accept `convex_url=` and `api_key=` keyword arguments to override the default backend URL and API key.
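The polling loop above waits indefinitely if a job never reaches a terminal state. A bounded variant might look like the sketch below. `wait_for_training_job` is a hypothetical helper, not part of the SDK; it assumes `succeeded`, `failed`, and `stopped` (the states checked above) are the only terminal statuses, and that the getter returns a dict with a `status` key, as `reflex.get_training_job` does in the example.

```python
import time
from typing import Any, Callable, Dict

# Terminal states taken from the polling example; assumed to be exhaustive.
TERMINAL_STATUSES = frozenset({"succeeded", "failed", "stopped"})


def wait_for_training_job(
    get_job: Callable[[str], Dict[str, Any]],
    training_job_id: str,
    *,
    poll_interval: float = 10.0,
    timeout: float = 6 * 60 * 60,
    sleep: Callable[[float], None] = time.sleep,
    clock: Callable[[], float] = time.monotonic,
) -> Dict[str, Any]:
    """Hypothetical helper: poll get_job(training_job_id) until a terminal status.

    Raises TimeoutError if the job is still non-terminal after `timeout` seconds.
    """
    deadline = clock() + timeout
    while True:
        job = get_job(training_job_id)
        if job["status"] in TERMINAL_STATUSES:
            return job
        if clock() >= deadline:
            raise TimeoutError(
                f"training job {training_job_id} still {job['status']!r} "
                f"after {timeout} s"
            )
        sleep(poll_interval)
```

Call it as `wait_for_training_job(reflex.get_training_job, job["training_job_id"])`. Passing the getter, `sleep`, and `clock` in as parameters keeps the helper independent of the SDK and lets you test it without real waits; the six-hour default timeout is an arbitrary placeholder.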
Low-level: LoraTrainingClient
For experiment loops that need step-level control (e.g. custom curricula,
hand-crafted gradient accumulation), use LoraTrainingClient.
```python
import reflex

# Open a step-level LoRA training session on pi0.5.
service = reflex.ServiceClient()
client = service.create_lora_training_client(
    base_model="pi0.5",
    name="my-adapter",
    rank=32,  # LoRA rank
    dataset_id="ds_abc123",
)

# Forward + backward on a batch
batch = [
    reflex.Datum(
        observation={"state": [0.1, 0.2, 0.3]},
        actions=[[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]],
    ),
]
fb = client.forward_backward(batch, loss_fn="behavior_cloning")
print(fb.loss, fb.metrics)

# Apply an optimizer step
step = client.optim_step(reflex.AdamParams(learning_rate=1e-3))
print("step:", step.step)

# Checkpoint the adapter
adapter = client.save_adapter(name="my-adapter", version="v1")
print(adapter.lora)
```
Methods
| Method | Description |
|---|---|
| `forward_backward(data, loss_fn="behavior_cloning", microbatch_size=None, request_id="")` | Run a forward and backward pass on `data`. Returns `ForwardBackwardResult`. |
| `optim_step(params, request_id="")` | Apply an optimizer step. Returns `OptimStepResult`. |
| `save_adapter(name="", version="")` | Persist a LoRA adapter. Returns `AdapterHandle`. |
| `save_state(name)` | Checkpoint optimizer state. |
| `status()` | Return the current run status. |
Each method has an `_async` variant (for example, `forward_backward_async`) for use with thread-pool executors.
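For the hand-crafted gradient accumulation mentioned above, a loop might call `forward_backward` on several batches and then take a single `optim_step`. This sketch rests on an unverified assumption: that gradients from successive `forward_backward` calls accumulate on the service until `optim_step` is applied, as in comparable step-level training APIs. Confirm that before relying on it; the `microbatch_size` parameter of `forward_backward` may already cover simple cases. `train_with_accumulation` and `chunked` are hypothetical helpers; `client` is anything with the documented `forward_backward` and `optim_step` methods, and `optim_params` would be a `reflex.AdamParams`.

```python
from itertools import islice
from typing import Any, Iterable, Iterator, List, Sequence


def chunked(items: Iterable[Any], size: int) -> Iterator[List[Any]]:
    """Yield consecutive lists of at most `size` items."""
    if size < 1:
        raise ValueError("size must be >= 1")
    it = iter(items)
    while chunk := list(islice(it, size)):
        yield chunk


def train_with_accumulation(
    client: Any,
    batches: Iterable[Sequence[Any]],
    optim_params: Any,
    *,
    accum_steps: int = 4,
    loss_fn: str = "behavior_cloning",
) -> List[float]:
    """Call forward_backward on `accum_steps` batches, then optim_step once.

    ASSUMPTION: gradients accumulate across forward_backward calls until
    optim_step is applied. Returns the mean reported loss per optimizer step
    (NaN if no loss was reported for that step).
    """
    step_losses: List[float] = []
    for group in chunked(batches, accum_steps):
        losses = []
        for batch in group:
            fb = client.forward_backward(batch, loss_fn=loss_fn)
            if fb.loss is not None:  # ForwardBackwardResult.loss may be None
                losses.append(fb.loss)
        client.optim_step(optim_params)  # one update per group of batches
        step_losses.append(sum(losses) / len(losses) if losses else float("nan"))
    return step_losses
```

A final group smaller than `accum_steps` still triggers an optimizer step. If the service sums rather than averages gradients across calls, the effective step size grows with `accum_steps`, so you may need to scale `learning_rate` to match.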
Dataclasses
```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Datum:
    observation: dict
    actions: list[list[float]]
    loss_weights: list[float] | None = None
    metadata: dict | None = None


@dataclass(frozen=True)
class AdamParams:
    learning_rate: float
    beta1: float = 0.9
    beta2: float = 0.95
    eps: float = 1e-8
    weight_decay: float = 0.0
    max_grad_norm: float | None = None


@dataclass(frozen=True)
class ForwardBackwardResult:
    loss: float | None
    metrics: dict
    raw: dict


@dataclass(frozen=True)
class OptimStepResult:
    step: int | None
    metrics: dict
    raw: dict


@dataclass(frozen=True)
class AdapterHandle:
    lora: str
    name: str
    version: str
    adapter_id: str
    raw: dict
```
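Because `AdamParams` is frozen, you cannot mutate its `learning_rate` between steps; instead, build a new instance per step with `dataclasses.replace`. The sketch below applies linear warmup followed by cosine decay. It assumes `optim_step` honors whatever `learning_rate` each call passes (the per-call `params` argument suggests so, but this is unverified). The schedule and the helpers `lr_at` and `params_for_step` are illustrative, not SDK features, and a local copy of `AdamParams` stands in for `reflex.AdamParams` so the snippet runs on its own.

```python
import math
from dataclasses import dataclass, replace
from typing import Optional


# Local stand-in mirroring the documented reflex.AdamParams (not the SDK class).
@dataclass(frozen=True)
class AdamParams:
    learning_rate: float
    beta1: float = 0.9
    beta2: float = 0.95
    eps: float = 1e-8
    weight_decay: float = 0.0
    max_grad_norm: Optional[float] = None


def lr_at(step: int, peak_lr: float, warmup_steps: int, total_steps: int,
          min_lr: float = 0.0) -> float:
    """Illustrative schedule: linear warmup to peak_lr, then cosine decay to min_lr."""
    if step < warmup_steps:
        # Steps 0..warmup_steps-1 ramp linearly up to peak_lr.
        return peak_lr * (step + 1) / warmup_steps
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)  # clamp past total_steps
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


def params_for_step(base: AdamParams, step: int, warmup_steps: int,
                    total_steps: int) -> AdamParams:
    """Return a copy of `base` whose learning_rate follows the schedule.

    base.learning_rate is treated as the peak rate; all other fields are kept.
    """
    return replace(base, learning_rate=lr_at(step, base.learning_rate,
                                             warmup_steps, total_steps))
```

In a training loop you would pass `params_for_step(reflex.AdamParams(learning_rate=1e-3), step, warmup_steps=100, total_steps=2000)` to `client.optim_step`; every field other than `learning_rate` carries over unchanged from the base instance.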