Training

Create training jobs and drive low-level LoRA training from Python.

The SDK exposes two layers of training control:

  • High-level helpers that mirror the reflex training start command: they create a job, return a job handle, and let you poll for completion.
  • LoraTrainingClient — a step-level interface for forward/backward passes, optimizer steps, and adapter checkpoints.

High-level: create / poll / cancel

import reflex
import time

job = reflex.create_training_job(
    dataset_id="ds_abc123",
    base_model="pi0.5",
    fine_tuning_type="lora",
    model_name="my-adapter",
    epochs=5,
)

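# Poll until the job reaches a terminal state.
while True:
    status = reflex.get_training_job(job["training_job_id"])
    if status["status"] in ("succeeded", "failed", "stopped"):
        break
    time.sleep(10)  # wait between polls

print("final status:", status["status"])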
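
The loop above polls forever if a job never reaches a terminal state. A minimal sketch of a bounded wait that cancels the job once a deadline passes is shown below. The helper wait_for_job and its parameters are hypothetical (not part of the SDK), and the SDK calls are passed in as callables so the sketch runs without network access.

```python
import time
from typing import Callable

# Terminal states taken from the polling example above.
TERMINAL_STATES = ("succeeded", "failed", "stopped")


def wait_for_job(
    job_id: str,
    get_job: Callable[[str], dict],
    cancel_job: Callable[[str], object],
    timeout_s: float = 3600.0,
    poll_interval_s: float = 10.0,
    sleep: Callable[[float], None] = time.sleep,
    clock: Callable[[], float] = time.monotonic,
) -> dict:
    """Poll until the job is terminal; cancel it if the deadline passes."""
    deadline = clock() + timeout_s
    while True:
        status = get_job(job_id)
        if status["status"] in TERMINAL_STATES:
            return status
        if clock() >= deadline:
            # Hypothetical policy: stop the job rather than leave it running.
            cancel_job(job_id)
            raise TimeoutError(f"job {job_id} cancelled after {timeout_s}s")
        sleep(poll_interval_s)
```

With the SDK installed, this would presumably be called as wait_for_job(job["training_job_id"], reflex.get_training_job, reflex.cancel_training_job).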

Functions

| Function | What it does |
| --- | --- |
| create_training_job(*, dataset_id=None, hf_source_uri=None, base_model="pi0.5", fine_tuning_type="lora", ...) | Create a training job. Pass exactly one of dataset_id or hf_source_uri. |
| lora_finetune(...) | Convenience wrapper for LoRA training. |
| full_finetune(...) | Convenience wrapper for full fine-tunes. |
| get_training_job(training_job_id) | Fetch the current status. |
| list_training_jobs(status=None) | List jobs, optionally filtered by status. |
| cancel_training_job(training_job_id) | Stop a running job. |

All of the above accept convex_url= and api_key= overrides.
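
Because create_training_job expects exactly one of dataset_id or hf_source_uri, it can help to check this on the client before making the call. The sketch below shows one way to do that. The helper training_source_kwargs is hypothetical, and how the SDK itself reacts to zero or two sources is not documented here.

```python
from typing import Optional


def training_source_kwargs(
    dataset_id: Optional[str] = None,
    hf_source_uri: Optional[str] = None,
) -> dict:
    """Return the single source keyword to forward to create_training_job."""
    candidates = (("dataset_id", dataset_id), ("hf_source_uri", hf_source_uri))
    # Keep only the sources the caller actually supplied.
    provided = {key: value for key, value in candidates if value is not None}
    if len(provided) != 1:
        raise ValueError(
            "pass exactly one of dataset_id or hf_source_uri, "
            f"got {sorted(provided) or 'neither'}"
        )
    return provided
```

The result can then be splatted into the call, for example reflex.create_training_job(**training_source_kwargs(dataset_id="ds_abc123"), model_name="my-adapter").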

Low-level: LoraTrainingClient

For experiment loops that need step-level control (e.g. custom curricula, hand-crafted gradient accumulation), use LoraTrainingClient.

import reflex

service = reflex.ServiceClient()

client = service.create_lora_training_client(
    base_model="pi0.5",
    name="my-adapter",
    rank=32,
    dataset_id="ds_abc123",
)

# Forward + backward on a batch
batch = [
    reflex.Datum(
        observation={"state": [0.1, 0.2, 0.3]},
        actions=[[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]],
    ),
]
fb = client.forward_backward(batch, loss_fn="behavior_cloning")
print(fb.loss, fb.metrics)

# Apply an optimizer step
step = client.optim_step(reflex.AdamParams(learning_rate=1e-3))
print("step:", step.step)

# Checkpoint the adapter
adapter = client.save_adapter(name="my-adapter", version="v1")
print(adapter.lora)
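
As an illustration of hand-crafted gradient accumulation, the sketch below calls forward_backward several times before each optim_step. It ASSUMES that gradients from successive forward_backward calls accumulate until the next optim_step, which this page does not confirm. StubClient and train_with_accumulation are hypothetical names; the stub stands in for LoraTrainingClient so the loop runs offline.

```python
from types import SimpleNamespace


class StubClient:
    """Offline stand-in exposing the two LoraTrainingClient methods used here."""

    def __init__(self):
        self.pending_backward = 0  # backward passes since the last optim_step
        self.step = 0

    def forward_backward(self, data, loss_fn="behavior_cloning"):
        self.pending_backward += 1
        return SimpleNamespace(loss=float(len(data)), metrics={})

    def optim_step(self, params):
        self.pending_backward = 0
        self.step += 1
        return SimpleNamespace(step=self.step, metrics={})


def train_with_accumulation(client, batches, params, accum_steps=4):
    """Run one optim_step per accum_steps forward_backward calls."""
    losses = []
    steps_taken = 0
    for i, batch in enumerate(batches, start=1):
        fb = client.forward_backward(batch, loss_fn="behavior_cloning")
        if fb.loss is not None:  # ForwardBackwardResult.loss may be None
            losses.append(fb.loss)
        if i % accum_steps == 0:
            client.optim_step(params)
            steps_taken += 1
    # Flush a trailing partial group so no gradients are left unapplied.
    if len(batches) % accum_steps != 0:
        client.optim_step(params)
        steps_taken += 1
    mean_loss = sum(losses) / len(losses) if losses else None
    return steps_taken, mean_loss
```

With the real client, params would be a reflex.AdamParams instance and each batch a list of reflex.Datum objects. If the assumption about accumulation does not hold, the microbatch_size argument of forward_backward may be the intended mechanism instead.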

Methods

| Method | Description |
| --- | --- |
| forward_backward(data, loss_fn="behavior_cloning", microbatch_size=None, request_id="") | Run forward + backward. Returns ForwardBackwardResult. |
| optim_step(params, request_id="") | Apply an optimizer step. Returns OptimStepResult. |
| save_adapter(name="", version="") | Persist a LoRA adapter. Returns AdapterHandle. |
| save_state(name) | Checkpoint optimizer state. |
| status() | Current run status. |

Each method has an _async variant (e.g. forward_backward_async) for use with thread-pool executors.
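
One possible use of the _async variants is to overlap local batch preparation with an in-flight request. The return type of these methods is not documented on this page; the sketch below ASSUMES they return a concurrent.futures.Future-like object with a .result() method. StubAsyncClient and pipelined_losses are hypothetical names used only so the example runs offline.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from types import SimpleNamespace


class StubAsyncClient:
    """Offline stand-in; ASSUMES *_async methods return a Future."""

    def __init__(self):
        self._pool = ThreadPoolExecutor(max_workers=1)

    def forward_backward_async(self, data, loss_fn="behavior_cloning") -> Future:
        # Pretend the loss is the batch size, computed on a worker thread.
        return self._pool.submit(
            lambda: SimpleNamespace(loss=float(len(data)), metrics={})
        )

    def close(self):
        self._pool.shutdown(wait=True)


def pipelined_losses(client, raw_batches, prepare):
    """Prepare the next batch on the caller's thread while one is in flight."""
    losses = []
    in_flight = None
    for raw in raw_batches:
        batch = prepare(raw)  # overlaps with the previous request, if any
        if in_flight is not None:
            losses.append(in_flight.result().loss)
        in_flight = client.forward_backward_async(batch)
    if in_flight is not None:
        losses.append(in_flight.result().loss)
    return losses
```

At most one request is in flight at a time here, so results come back in submission order.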

Dataclasses

@dataclass(frozen=True)
class Datum:
    observation: dict
    actions: list[list[float]]
    loss_weights: list[float] | None = None
    metadata: dict | None = None

@dataclass(frozen=True)
class AdamParams:
    learning_rate: float
    beta1: float = 0.9
    beta2: float = 0.95
    eps: float = 1e-8
    weight_decay: float = 0.0
    max_grad_norm: float | None = None

@dataclass(frozen=True)
class ForwardBackwardResult:
    loss: float | None
    metrics: dict
    raw: dict

@dataclass(frozen=True)
class OptimStepResult:
    step: int | None
    metrics: dict
    raw: dict

@dataclass(frozen=True)
class AdapterHandle:
    lora: str
    name: str
    version: str
    adapter_id: str
    raw: dict
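
Because these dataclasses are frozen, fields cannot be reassigned after construction; a changed value needs a new instance. The sketch below uses dataclasses.replace to derive per-step AdamParams for a learning-rate schedule. The cosine decay and the helper cosine_params are illustrative choices, not something the SDK prescribes, and AdamParams is redefined locally (mirroring the definition above) so the example runs without the SDK.

```python
import dataclasses
import math
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class AdamParams:  # local copy of the definition above
    learning_rate: float
    beta1: float = 0.9
    beta2: float = 0.95
    eps: float = 1e-8
    weight_decay: float = 0.0
    max_grad_norm: Optional[float] = None


def cosine_params(base: AdamParams, step: int, total_steps: int) -> AdamParams:
    """Return a copy of base with a cosine-decayed learning rate."""
    progress = min(max(step / total_steps, 0.0), 1.0)
    scale = 0.5 * (1.0 + math.cos(math.pi * progress))
    # replace() builds a new frozen instance; base itself is left unchanged.
    return dataclasses.replace(base, learning_rate=base.learning_rate * scale)


base = AdamParams(learning_rate=1e-3, max_grad_norm=1.0)
halfway = cosine_params(base, step=50, total_steps=100)
```

In a training loop, the result of cosine_params(base, step, total_steps) would be passed to optim_step in place of a fixed AdamParams.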