Documentation ¶
Overview ¶
Package replicate is a client for Replicate's API.
See https://replicate.com/docs for more information.
Index ¶
- Constants
- Variables
- func Paginate[T any](ctx context.Context, client *Client, initialPage *Page[T]) (<-chan []T, <-chan error)
- func ValidateWebhookRequest(req *http.Request, secret WebhookSigningSecret) (bool, error)
- type APIError
- type Account
- type Backoff
- type Client
- func (r *Client) CancelPrediction(ctx context.Context, id string) (*Prediction, error)
- func (r *Client) CancelTraining(ctx context.Context, trainingID string) (*Training, error)
- func (r *Client) CreateFileFromBuffer(ctx context.Context, buf *bytes.Buffer, options *CreateFileOptions) (*File, error)
- func (r *Client) CreateFileFromBytes(ctx context.Context, data []byte, options *CreateFileOptions) (*File, error)
- func (r *Client) CreateFileFromPath(ctx context.Context, filePath string, options *CreateFileOptions) (*File, error)
- func (r *Client) CreateModel(ctx context.Context, modelOwner string, modelName string, ...) (*Model, error)
- func (r *Client) CreatePrediction(ctx context.Context, version string, input PredictionInput, webhook *Webhook, ...) (*Prediction, error)
- func (r *Client) CreatePredictionWithDeployment(ctx context.Context, deploymentOwner string, deploymentName string, ...) (*Prediction, error)
- func (r *Client) CreatePredictionWithModel(ctx context.Context, modelOwner string, modelName string, ...) (*Prediction, error)
- func (r *Client) CreateTraining(ctx context.Context, modelOwner string, modelName string, version string, ...) (*Training, error)
- func (r *Client) DeleteFile(ctx context.Context, fileID string) error
- func (r *Client) DeleteModelVersion(ctx context.Context, modelOwner string, modelName string, versionID string) error
- func (r *Client) GetCollection(ctx context.Context, slug string) (*Collection, error)
- func (r *Client) GetCurrentAccount(ctx context.Context) (*Account, error)
- func (r *Client) GetDefaultWebhookSecret(ctx context.Context) (*WebhookSigningSecret, error)
- func (r *Client) GetDeployment(ctx context.Context, deploymentOwner string, deploymentName string) (*Deployment, error)
- func (r *Client) GetFile(ctx context.Context, fileID string) (*File, error)
- func (r *Client) GetModel(ctx context.Context, modelOwner string, modelName string) (*Model, error)
- func (r *Client) GetModelVersion(ctx context.Context, modelOwner string, modelName string, versionID string) (*ModelVersion, error)
- func (r *Client) GetPrediction(ctx context.Context, id string) (*Prediction, error)
- func (r *Client) GetTraining(ctx context.Context, trainingID string) (*Training, error)
- func (r *Client) ListCollections(ctx context.Context) (*Page[Collection], error)
- func (r *Client) ListFiles(ctx context.Context) (*Page[File], error)
- func (r *Client) ListHardware(ctx context.Context) (*[]Hardware, error)
- func (r *Client) ListModelVersions(ctx context.Context, modelOwner string, modelName string) (*Page[ModelVersion], error)
- func (r *Client) ListModels(ctx context.Context) (*Page[Model], error)
- func (r *Client) ListPredictions(ctx context.Context) (*Page[Prediction], error)
- func (r *Client) ListTrainings(ctx context.Context) (*Page[Training], error)
- func (r *Client) Run(ctx context.Context, identifier string, input PredictionInput, ...) (PredictionOutput, error)
- func (r *Client) Stream(ctx context.Context, identifier string, input PredictionInput, ...) (<-chan SSEEvent, <-chan error)
- func (r *Client) StreamPrediction(ctx context.Context, prediction *Prediction) (<-chan SSEEvent, <-chan error)
- func (r *Client) Wait(ctx context.Context, prediction *Prediction, opts ...WaitOption) error
- func (r *Client) WaitAsync(ctx context.Context, prediction *Prediction, opts ...WaitOption) (<-chan *Prediction, <-chan error)
- type ClientOption
- type Collection
- type ConstantBackoff
- type CreateFileOptions
- type CreateModelOptions
- type Deployment
- type DeploymentConfiguration
- type DeploymentRelease
- type ExponentialBackoff
- type File
- type Hardware
- type Identifier
- type Model
- type ModelVersion
- type Page
- type Prediction
- type PredictionInput
- type PredictionMetrics
- type PredictionOutput
- type PredictionProgress
- type SSEEvent
- type Source
- type Status
- type Training
- type TrainingInput
- type WaitOption
- type Webhook
- type WebhookEventType
- type WebhookSigningSecret
Examples ¶
Constants ¶
const ( // SSETypeDone is the type of SSEEvent that indicates the prediction is done. The Data field will contain an empty JSON object. SSETypeDone = "done" // SSETypeError is the type of SSEEvent that indicates an error occurred during the prediction. The Data field will contain JSON with the error. SSETypeError = "error" // SSETypeLogs is the type of SSEEvent that contains logs from the prediction. SSETypeLogs = "logs" // SSETypeOutput is the type of SSEEvent that contains output from the prediction. SSETypeOutput = "output" )
Variables ¶
var (
ErrInvalidIdentifier = errors.New("invalid identifier, it must be in the format \"owner/name\" or \"owner/name:version\"")
)
var (
ErrInvalidUTF8Data = errors.New("invalid UTF-8 data")
)
var (
ErrNoAuth = errors.New(`no auth token or token source provided -- perhaps you forgot to pass replicate.WithToken("...")`)
)
var WebhookEventAll = []WebhookEventType{ WebhookEventStart, WebhookEventOutput, WebhookEventLogs, WebhookEventCompleted, }
Functions ¶
func Paginate ¶
func Paginate[T any](ctx context.Context, client *Client, initialPage *Page[T]) (<-chan []T, <-chan error)
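Paginate takes a Client and an initial Page and iterates through the pages of results. Each page's results are delivered on the returned []T channel, and errors are delivered on the returned error channel.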
func ValidateWebhookRequest ¶ added in v0.16.0
func ValidateWebhookRequest(req *http.Request, secret WebhookSigningSecret) (bool, error)
ValidateWebhookRequest validates the signature of an incoming webhook request using the provided secret.
Types ¶
type APIError ¶
type APIError struct { // Type is a URI that identifies the error type. Type string `json:"type,omitempty"` // Title is a short human-readable summary of the error. Title string `json:"title,omitempty"` // Status is the HTTP status code. Status int `json:"status,omitempty"` // Detail is a human-readable explanation of the error. Detail string `json:"detail,omitempty"` // Instance is a URI that identifies the specific occurrence of the error. Instance string `json:"instance,omitempty"` }
APIError represents an error returned by the Replicate API.
func (*APIError) WriteHTTPResponse ¶ added in v0.14.0
func (e *APIError) WriteHTTPResponse(w http.ResponseWriter)
type Account ¶ added in v0.15.0
type Account struct { Type string `json:"type"` Username string `json:"username"` Name string `json:"name"` GithubURL string `json:"github_url"` // contains filtered or unexported fields }
func (Account) MarshalJSON ¶ added in v0.15.0
func (*Account) UnmarshalJSON ¶ added in v0.15.0
type Client ¶
type Client struct {
// contains filtered or unexported fields
}
Client is a client for the Replicate API.
func NewClient ¶
func NewClient(opts ...ClientOption) (*Client, error)
NewClient creates a new Replicate API client.
func (*Client) CancelPrediction ¶ added in v0.17.0
CancelPrediction cancels a running prediction by its ID.
func (*Client) CancelTraining ¶
CancelTraining sends a request to the Replicate API to cancel a training.
func (*Client) CreateFileFromBuffer ¶ added in v0.15.0
func (r *Client) CreateFileFromBuffer(ctx context.Context, buf *bytes.Buffer, options *CreateFileOptions) (*File, error)
CreateFileFromBuffer creates a new file from a buffer.
func (*Client) CreateFileFromBytes ¶ added in v0.15.0
func (r *Client) CreateFileFromBytes(ctx context.Context, data []byte, options *CreateFileOptions) (*File, error)
CreateFileFromBytes creates a new file from bytes.
func (*Client) CreateFileFromPath ¶ added in v0.15.0
func (r *Client) CreateFileFromPath(ctx context.Context, filePath string, options *CreateFileOptions) (*File, error)
CreateFileFromPath creates a new file from a file path.
func (*Client) CreateModel ¶ added in v0.11.0
func (r *Client) CreateModel(ctx context.Context, modelOwner string, modelName string, options CreateModelOptions) (*Model, error)
CreateModel creates a new model.
func (*Client) CreatePrediction ¶
func (r *Client) CreatePrediction(ctx context.Context, version string, input PredictionInput, webhook *Webhook, stream bool) (*Prediction, error)
CreatePrediction sends a request to the Replicate API to create a prediction.
Example ¶
ctx := context.Background() // You can also provide a token directly with `replicate.NewClient(replicate.WithToken("r8_..."))` r8, err := replicate.NewClient(replicate.WithTokenFromEnv()) if err != nil { return } // https://replicate.com/stability-ai/stable-diffusion version := "ac732df83cea7fff18b8472768c88ad041fa750ff7682a21affe81863cbe77e4" input := replicate.PredictionInput{ "prompt": "an astronaut riding a horse on mars, hd, dramatic lighting", } webhook := replicate.Webhook{ URL: "https://example.com/webhook", Events: []replicate.WebhookEventType{"start", "completed"}, } // The `Run` method is a convenience method that // creates a prediction, waits for it to finish, and returns the output. // If you want a reference to the prediction, you can call `CreatePrediction`, // call `Wait` on the prediction, and access its `Output` field. prediction, err := r8.CreatePrediction(ctx, version, input, &webhook, false) if err != nil { return } // Wait for the prediction to finish err = r8.Wait(ctx, prediction) if err != nil { return } fmt.Println("output: ", prediction.Output)
Output:
func (*Client) CreatePredictionWithDeployment ¶ added in v0.9.0
func (r *Client) CreatePredictionWithDeployment(ctx context.Context, deploymentOwner string, deploymentName string, input PredictionInput, webhook *Webhook, stream bool) (*Prediction, error)
CreatePredictionWithDeployment sends a request to the Replicate API to create a prediction using the specified deployment.
func (*Client) CreatePredictionWithModel ¶ added in v0.13.0
func (r *Client) CreatePredictionWithModel(ctx context.Context, modelOwner string, modelName string, input PredictionInput, webhook *Webhook, stream bool) (*Prediction, error)
CreatePredictionWithModel sends a request to the Replicate API to create a prediction for a model.
func (*Client) CreateTraining ¶
func (r *Client) CreateTraining(ctx context.Context, modelOwner string, modelName string, version string, destination string, input TrainingInput, webhook *Webhook) (*Training, error)
CreateTraining sends a request to the Replicate API to create a new training.
func (*Client) DeleteFile ¶ added in v0.15.0
DeleteFile deletes a file.
func (*Client) DeleteModelVersion ¶ added in v0.18.0
func (r *Client) DeleteModelVersion(ctx context.Context, modelOwner string, modelName string, versionID string) error
DeleteModelVersion deletes a model version and all associated predictions, including all output files.
func (*Client) GetCollection ¶
GetCollection returns a collection by slug.
func (*Client) GetCurrentAccount ¶ added in v0.15.0
GetCurrentAccount returns the authenticated user or organization.
func (*Client) GetDefaultWebhookSecret ¶ added in v0.16.0
func (r *Client) GetDefaultWebhookSecret(ctx context.Context) (*WebhookSigningSecret, error)
GetDefaultWebhookSecret gets the default webhook signing secret.
func (*Client) GetDeployment ¶ added in v0.16.0
func (r *Client) GetDeployment(ctx context.Context, deploymentOwner string, deploymentName string) (*Deployment, error)
GetDeployment retrieves the details of a specific deployment.
func (*Client) GetModelVersion ¶
func (r *Client) GetModelVersion(ctx context.Context, modelOwner string, modelName string, versionID string) (*ModelVersion, error)
GetModelVersion retrieves a specific version of a model.
func (*Client) GetPrediction ¶
GetPrediction retrieves a prediction from the Replicate API by its ID.
func (*Client) GetTraining ¶
GetTraining sends a request to the Replicate API to get a training.
func (*Client) ListCollections ¶
ListCollections returns a list of all collections.
func (*Client) ListHardware ¶ added in v0.11.0
ListHardware returns a list of available hardware.
func (*Client) ListModelVersions ¶
func (r *Client) ListModelVersions(ctx context.Context, modelOwner string, modelName string) (*Page[ModelVersion], error)
ListModelVersions lists the versions of a model.
func (*Client) ListModels ¶ added in v0.10.0
ListModels lists public models.
func (*Client) ListPredictions ¶
ListPredictions returns a paginated list of predictions.
func (*Client) ListTrainings ¶
ListTrainings returns a list of trainings.
func (*Client) Run ¶
func (r *Client) Run(ctx context.Context, identifier string, input PredictionInput, webhook *Webhook) (PredictionOutput, error)
Example ¶
ctx := context.Background() // You can also provide a token directly with `replicate.NewClient(replicate.WithToken("r8_..."))` r8, err := replicate.NewClient(replicate.WithTokenFromEnv()) if err != nil { return } // https://replicate.com/stability-ai/stable-diffusion version := "ac732df83cea7fff18b8472768c88ad041fa750ff7682a21affe81863cbe77e4" input := replicate.PredictionInput{ "prompt": "an astronaut riding a horse on mars, hd, dramatic lighting", } webhook := replicate.Webhook{ URL: "https://example.com/webhook", Events: []replicate.WebhookEventType{"start", "completed"}, } // Run a model and wait for its output output, err := r8.Run(ctx, version, input, &webhook) if err != nil { return } fmt.Println("output: ", output)
Output:
func (*Client) StreamPrediction ¶ added in v0.13.1
func (*Client) Wait ¶
func (r *Client) Wait(ctx context.Context, prediction *Prediction, opts ...WaitOption) error
Wait waits for a prediction to finish.
This function blocks until the prediction has finished or the context is canceled. If the prediction has already finished, the function returns immediately. If the polling interval is less than or equal to zero, an error is returned.
func (*Client) WaitAsync ¶ added in v0.7.0
func (r *Client) WaitAsync(ctx context.Context, prediction *Prediction, opts ...WaitOption) (<-chan *Prediction, <-chan error)
WaitAsync returns a channel that receives the prediction as it progresses.
The channel is closed when the prediction has finished or the context is canceled. If the prediction has already finished, the channel is closed immediately. If the polling interval is less than or equal to zero, an error is sent to the error channel.
type ClientOption ¶
type ClientOption func(*clientOptions) error
ClientOption is a function that modifies an options struct.
func WithBaseURL ¶
func WithBaseURL(baseURL string) ClientOption
WithBaseURL sets the base URL for the client.
func WithHTTPClient ¶
func WithHTTPClient(httpClient *http.Client) ClientOption
WithHTTPClient sets the HTTP client used by the client.
func WithRetryPolicy ¶ added in v0.7.0
func WithRetryPolicy(maxRetries int, backoff Backoff) ClientOption
WithRetryPolicy sets the retry policy used by the client.
func WithToken ¶ added in v0.6.0
func WithToken(token string) ClientOption
WithToken sets the auth token used by the client.
func WithTokenFromEnv ¶ added in v0.6.0
func WithTokenFromEnv() ClientOption
WithTokenFromEnv configures the client to use the auth token provided in the REPLICATE_API_TOKEN environment variable.
func WithUserAgent ¶
func WithUserAgent(userAgent string) ClientOption
WithUserAgent sets the User-Agent header on requests made by the client.
type Collection ¶
type Collection struct { Name string `json:"name"` Slug string `json:"slug"` Description string `json:"description"` Models *[]Model `json:"models,omitempty"` // contains filtered or unexported fields }
func (Collection) MarshalJSON ¶ added in v0.8.1
func (c Collection) MarshalJSON() ([]byte, error)
func (*Collection) UnmarshalJSON ¶ added in v0.8.1
func (c *Collection) UnmarshalJSON(data []byte) error
type ConstantBackoff ¶ added in v0.7.0
ConstantBackoff is a backoff strategy that returns a constant delay with some jitter.
type CreateFileOptions ¶ added in v0.15.0
type CreateModelOptions ¶ added in v0.11.0
type CreateModelOptions struct { Visibility string `json:"visibility"` Hardware string `json:"hardware"` Description *string `json:"description,omitempty"` GithubURL *string `json:"github_url,omitempty"` PaperURL *string `json:"paper_url,omitempty"` LicenseURL *string `json:"license_url,omitempty"` CoverImageURL *string `json:"cover_image_url,omitempty"` }
type Deployment ¶ added in v0.16.0
type Deployment struct { Owner string `json:"owner"` Name string `json:"name"` CurrentRelease DeploymentRelease `json:"current_release"` // contains filtered or unexported fields }
func (Deployment) MarshalJSON ¶ added in v0.16.0
func (d Deployment) MarshalJSON() ([]byte, error)
func (*Deployment) UnmarshalJSON ¶ added in v0.16.0
func (d *Deployment) UnmarshalJSON(data []byte) error
type DeploymentConfiguration ¶ added in v0.16.0
type DeploymentRelease ¶ added in v0.16.0
type ExponentialBackoff ¶ added in v0.7.0
ExponentialBackoff is a backoff strategy that returns an exponentially increasing delay with some jitter.
type File ¶ added in v0.15.0
type File struct { ID string `json:"id"` Name string `json:"name"` ContentType string `json:"content_type"` Size int `json:"size"` Etag string `json:"etag"` Checksums map[string]string `json:"checksums"` Metadata map[string]string `json:"metadata"` CreatedAt string `json:"created_at"` ExpiresAt string `json:"expires_at"` URLs map[string]string `json:"urls"` }
type Hardware ¶ added in v0.11.0
type Hardware struct { SKU string `json:"sku"` Name string `json:"name"` // contains filtered or unexported fields }
func (Hardware) MarshalJSON ¶ added in v0.11.0
func (*Hardware) UnmarshalJSON ¶ added in v0.11.0
type Identifier ¶ added in v0.13.0
type Identifier struct { // Owner is the username of the model owner. Owner string // Name is the name of the model. Name string // Version is the version of the model. Version *string }
Identifier represents a reference to a Replicate model with an optional version.
func ParseIdentifier ¶ added in v0.13.0
func ParseIdentifier(identifier string) (*Identifier, error)
func (*Identifier) String ¶ added in v0.13.1
func (i *Identifier) String() string
type Model ¶
type Model struct { URL string `json:"url"` Owner string `json:"owner"` Name string `json:"name"` Description string `json:"description"` Visibility string `json:"visibility"` GithubURL string `json:"github_url"` PaperURL string `json:"paper_url"` LicenseURL string `json:"license_url"` RunCount int `json:"run_count"` CoverImageURL string `json:"cover_image_url"` DefaultExample *Prediction `json:"default_example"` LatestVersion *ModelVersion `json:"latest_version"` // contains filtered or unexported fields }
func (Model) MarshalJSON ¶ added in v0.8.1
func (*Model) UnmarshalJSON ¶ added in v0.8.1
type ModelVersion ¶
type ModelVersion struct { ID string `json:"id"` CreatedAt string `json:"created_at"` CogVersion string `json:"cog_version"` OpenAPISchema interface{} `json:"openapi_schema"` // contains filtered or unexported fields }
func (ModelVersion) MarshalJSON ¶ added in v0.8.1
func (m ModelVersion) MarshalJSON() ([]byte, error)
func (*ModelVersion) UnmarshalJSON ¶ added in v0.8.1
func (m *ModelVersion) UnmarshalJSON(data []byte) error
type Page ¶
type Page[T any] struct { Previous *string `json:"previous,omitempty"` Next *string `json:"next,omitempty"` Results []T `json:"results"` // contains filtered or unexported fields }
Page represents a paginated response from Replicate's API.
func (Page[T]) MarshalJSON ¶ added in v0.8.1
func (*Page[T]) UnmarshalJSON ¶ added in v0.8.1
type Prediction ¶
type Prediction struct { ID string `json:"id"` Status Status `json:"status"` Model string `json:"model"` Version string `json:"version"` Input PredictionInput `json:"input"` Output PredictionOutput `json:"output,omitempty"` Source Source `json:"source"` Error interface{} `json:"error,omitempty"` Logs *string `json:"logs,omitempty"` Metrics *PredictionMetrics `json:"metrics,omitempty"` Webhook *string `json:"webhook,omitempty"` WebhookEventsFilter []WebhookEventType `json:"webhook_events_filter,omitempty"` URLs map[string]string `json:"urls,omitempty"` CreatedAt string `json:"created_at"` StartedAt *string `json:"started_at,omitempty"` CompletedAt *string `json:"completed_at,omitempty"` // contains filtered or unexported fields }
func (Prediction) MarshalJSON ¶ added in v0.8.1
func (p Prediction) MarshalJSON() ([]byte, error)
func (Prediction) Progress ¶ added in v0.8.0
func (p Prediction) Progress() *PredictionProgress
func (*Prediction) UnmarshalJSON ¶ added in v0.8.1
func (p *Prediction) UnmarshalJSON(data []byte) error
type PredictionInput ¶
type PredictionInput map[string]interface{}
type PredictionMetrics ¶ added in v0.18.0
type PredictionMetrics struct { PredictTime *float64 `json:"predict_time,omitempty"` TotalTime *float64 `json:"total_time,omitempty"` InputTokenCount *int `json:"input_token_count,omitempty"` OutputTokenCount *int `json:"output_token_count,omitempty"` TimeToFirstToken *float64 `json:"time_to_first_token,omitempty"` TokensPerSecond *float64 `json:"tokens_per_second,omitempty"` }
type PredictionOutput ¶
type PredictionOutput interface{}
type PredictionProgress ¶ added in v0.8.0
type Training ¶
type Training Prediction
type TrainingInput ¶
type TrainingInput PredictionInput
type WaitOption ¶ added in v0.7.0
type WaitOption func(*waitOptions) error
WaitOption is a function that modifies an options struct.
func WithPollingInterval ¶ added in v0.7.0
func WithPollingInterval(interval time.Duration) WaitOption
WithPollingInterval sets the interval between polling attempts.
type Webhook ¶
type Webhook struct { URL string Events []WebhookEventType }
type WebhookEventType ¶
type WebhookEventType string
const ( WebhookEventStart WebhookEventType = "start" WebhookEventOutput WebhookEventType = "output" WebhookEventLogs WebhookEventType = "logs" WebhookEventCompleted WebhookEventType = "completed" )
func (WebhookEventType) String ¶
func (w WebhookEventType) String() string
type WebhookSigningSecret ¶ added in v0.16.0
type WebhookSigningSecret struct {
Key string `json:"key"`
}