// File: gotenberg/pkg/modules/api/middlewares.go (444 lines, 14 KiB, Go).

package api
import (
"context"
"crypto/subtle"
"errors"
"fmt"
"log/slog"
"net/http"
"path/filepath"
"strings"
"time"
"github.com/google/uuid"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/propagation"
"go.opentelemetry.io/otel/trace"
"github.com/gotenberg/gotenberg/v8/pkg/gotenberg"
semconvutil "github.com/gotenberg/gotenberg/v8/pkg/gotenberg/semconv"
)
var (
	// ErrAsyncProcess happens when a handler or middleware handles a request
	// in an asynchronous fashion. When contextMiddleware receives it, it
	// replies with a 204 No Content and does not cancel the request context.
	ErrAsyncProcess = errors.New("async process")
	// ErrNoOutputFile happens when a handler or middleware handles a request
	// without sending any output file. When contextMiddleware receives it, it
	// cancels the request context and returns without building or sending an
	// output file.
	ErrNoOutputFile = errors.New("no output file")
)
// ParseError maps an error to the HTTP status code and the message to send
// back to the client.
//
// The lookup order is: an [echo.HTTPError] found in the error chain (its code
// and the matching standard status text), the well-known sentinel errors, a
// [gotenberg.PdfEngineInvalidArgsError] (its own message), and finally any
// error implementing HttpError. Everything else yields a 500 Internal Server
// Error.
func ParseError(err error) (int, string) {
	var fromEcho *echo.HTTPError
	if errors.As(err, &fromEcho) {
		return fromEcho.Code, http.StatusText(fromEcho.Code)
	}

	// Sentinel errors; the first matching case wins.
	switch {
	case errors.Is(err, context.DeadlineExceeded):
		return http.StatusServiceUnavailable, "The request exceeded the time limit. Increase it with --api-timeout, or reduce the workload."
	case errors.Is(err, gotenberg.ErrFiltered):
		return http.StatusForbidden, http.StatusText(http.StatusForbidden)
	case errors.Is(err, gotenberg.ErrMaximumQueueSizeExceeded):
		return http.StatusTooManyRequests, "The request queue is full. Retry shortly, or raise the limit with --chromium-max-queue-size or --libreoffice-max-queue-size."
	case errors.Is(err, gotenberg.ErrPdfSplitModeNotSupported):
		return http.StatusBadRequest, "The requested split mode is not supported, or no PDF engine could process it. Valid modes: 'intervals', 'pages'."
	case errors.Is(err, gotenberg.ErrPdfFormatNotSupported):
		return http.StatusBadRequest, "The requested PDF format is not supported, or no PDF engine could apply it. Valid formats include PDF/A-1b, PDF/A-2b, PDF/A-3b, and PDF/UA."
	case errors.Is(err, gotenberg.ErrPdfEngineMetadataValueNotSupported):
		return http.StatusBadRequest, "The requested metadata could not be written; ensure values are valid and free of control characters."
	case errors.Is(err, gotenberg.ErrPdfStampSourceNotSupported):
		return http.StatusBadRequest, "The requested stamp source is not supported, or no PDF engine could process it. Valid sources: 'text', 'image', 'pdf'."
	case errors.Is(err, gotenberg.ErrPdfRotateAngleNotSupported):
		return http.StatusBadRequest, "The requested rotation angle is not supported. Valid angles: 90, 180, 270."
	}

	if argsErr, ok := errors.AsType[*gotenberg.PdfEngineInvalidArgsError](err); ok {
		return http.StatusBadRequest, argsErr.Error()
	}

	var custom HttpError
	if errors.As(err, &custom) {
		return custom.HttpError()
	}

	// Nothing matched: fall back to a 500 status code.
	const fallback = http.StatusInternalServerError
	return fallback, http.StatusText(fallback)
}
// httpErrorHandler returns the centralized HTTP error handler. The handler
// converts the error with ParseError and replies with the resulting status and
// message as "text/plain; charset=UTF-8".
//
// Nothing is written when the response is already committed: the status line
// and possibly part of a body are on the wire at that point, so appending an
// error message would only corrupt the payload.
func httpErrorHandler() echo.HTTPErrorHandler {
	return func(err error, c echo.Context) {
		if c.Response().Committed {
			return
		}

		// Guard the type assertion: a missing logger must not turn an error
		// response into a panic. Fall back to the default logger in that case.
		logger, ok := c.Get("logger").(*slog.Logger)
		if !ok || logger == nil {
			logger = slog.Default()
		}

		status, message := ParseError(err)

		// Set (rather than Add) so that a content type left behind by a
		// handler that failed before writing is replaced instead of ending up
		// next to ours as a second header value.
		c.Response().Header().Set(echo.HeaderContentType, echo.MIMETextPlainCharsetUTF8)

		if sendErr := c.String(status, message); sendErr != nil {
			logger.ErrorContext(c.Request().Context(), fmt.Sprintf("send error response: %s", sendErr.Error()))
		}
	}
}
// latencyMiddleware records the time at which the request entered the
// middleware chain and stores it in the [echo.Context] under the "startTime"
// key. Later middlewares read it back to compute the request latency.
//
//	startTime := c.Get("startTime").(time.Time)
func latencyMiddleware() echo.MiddlewareFunc {
	stampStart := func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			// Reference point for the latency computation.
			c.Set("startTime", time.Now())

			// Hand over to the next middleware in the chain.
			return next(c)
		}
	}

	return stampStart
}
// rootPathMiddleware stores the given root path in the [echo.Context] under
// the "rootPath" key. Later middlewares may use it to build absolute URIs, for
// instance, to bypass their own logic for a specific route.
//
//	rootPath := c.Get("rootPath").(string)
//	healthURI := fmt.Sprintf("%s/health", rootPath)
//
//	// Bypass the middleware for the health check URI.
//	if c.Request().RequestURI == healthURI {
//		// Call the next middleware in the chain.
//		return next(c)
//	}
func rootPathMiddleware(rootPath string) echo.MiddlewareFunc {
	exposeRootPath := func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			c.Set("rootPath", rootPath)

			// Hand over to the next middleware in the chain.
			return next(c)
		}
	}

	return exposeRootPath
}
// outputFilenameMiddleware reads the "Gotenberg-Output-Filename" request
// header, reduces it to its last path element, and stores the result in the
// [echo.Context] under the "outputFilename" key. An absent or empty header is
// stored as an empty string.
//
//	outputFilename := c.Get("outputFilename").(string)
func outputFilenameMiddleware() echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			outputFilename := ""

			// Only a non-empty header goes through filepath.Base, which would
			// otherwise turn "" into ".".
			// See https://github.com/gotenberg/gotenberg/issues/1227.
			if raw := c.Request().Header.Get("Gotenberg-Output-Filename"); raw != "" {
				outputFilename = filepath.Base(raw)
			}

			c.Set("outputFilename", outputFilename)

			// Hand over to the next middleware in the chain.
			return next(c)
		}
	}
}
// telemetryMiddleware manages telemetry: tracing, access logging and metrics.
// It sets the correlation ID in the [echo.Context] under "correlationId", the
// name of its header under "correlationIdHeader", and a request-scoped logger
// under "logger".
//
//	correlationIdHeader := c.Get("correlationIdHeader").(string)
//	correlationId := c.Get("correlationId").(string)
//
// It reads "startTime" and "rootPath" from the [echo.Context] with unchecked
// type assertions, so both values must have been set before this middleware
// runs.
//
// Requests whose RequestURI equals the root path followed by one of the
// disableTelemetryForPaths entries get a discarding logger and no span, access
// log or metrics.
//
// Errors returned by the rest of the chain are passed to c.Error here; the
// middleware itself always returns nil.
func telemetryMiddleware(logger *slog.Logger, serverName, correlationIdHeader string, disableTelemetryForPaths []string) echo.MiddlewareFunc {
	meter := gotenberg.Meter()
	semconvSrv := semconvutil.NewHTTPServer(meter)
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			startTime := c.Get("startTime").(time.Time)
			rootPath := c.Get("rootPath").(string)
			request := c.Request()
			savedCtx := request.Context()
			// On exit, put the request captured above back on the echo
			// context, carrying the context it had when it entered this
			// middleware.
			defer func() {
				request = request.WithContext(savedCtx)
				c.SetRequest(request)
			}()
			// The URL path, or "/" when the path is empty. It is used as the
			// route attribute, in the span name and in the access log.
			routePath := func() string {
				path := c.Request().URL.Path
				if path == "" {
					path = "/"
				}
				return path
			}()
			// Evaluate if we should skip telemetry for this path.
			// The comparison is an exact match on the raw RequestURI.
			skipTelemetry := false
			for _, path := range disableTelemetryForPaths {
				URI := fmt.Sprintf("%s%s", rootPath, path)
				if c.Request().RequestURI == URI {
					skipTelemetry = true
					break
				}
			}
			if skipTelemetry {
				// Downstream code still finds a logger, but its output is
				// discarded.
				c.Set("logger", slog.New(slog.DiscardHandler))
				err := next(c)
				if err != nil {
					c.Error(err)
				}
				return nil
			}
			// Reuse the caller's correlation ID when the header is present,
			// otherwise generate a new UUID.
			correlationId := request.Header.Get(correlationIdHeader)
			if correlationId == "" {
				correlationId = uuid.NewString()
			}
			c.Set("correlationIdHeader", correlationIdHeader)
			c.Set("correlationId", correlationId)
			// Extract any trace context carried by the request headers, then
			// start a server span from it.
			ctx := otel.GetTextMapPropagator().Extract(savedCtx, propagation.HeaderCarrier(request.Header))
			rAttr := semconvSrv.Route(routePath)
			opts := []trace.SpanStartOption{
				trace.WithAttributes(
					semconvSrv.RequestTraceAttrs(serverName, request, semconvutil.RequestTraceAttrsOpts{})...,
				),
				trace.WithSpanKind(trace.SpanKindServer),
				trace.WithAttributes(rAttr),
			}
			spanName := strings.ToUpper(c.Request().Method) + " " + routePath
			tracer := gotenberg.Tracer()
			ctx, span := tracer.Start(ctx, spanName, opts...)
			defer span.End()
			// Expose the trace context and the correlation ID on the response
			// headers, and make the span context available downstream through
			// the request.
			otel.GetTextMapPropagator().Inject(ctx, propagation.HeaderCarrier(c.Response().Header()))
			c.Response().Header().Set(correlationIdHeader, correlationId)
			c.SetRequest(c.Request().WithContext(ctx))
			appLogger := logger.
				With(slog.String("log_type", "application")).
				With(slog.String("correlation_id", correlationId))
			// The logger name is the URL path with every occurrence of the
			// root path and every slash removed.
			loggerName := strings.ReplaceAll(
				strings.ReplaceAll(c.Request().URL.Path, rootPath, ""),
				"/",
				"",
			)
			c.Set("logger", appLogger.With(slog.String("logger", loggerName)))
			// Call the next middleware in the chain.
			err := next(c)
			finishTime := time.Now()
			status := c.Response().Status
			if err != nil {
				// The status reported on the span and in the metrics comes
				// from ParseError; c.Error then hands the error to the HTTP
				// error handler.
				parsedStatus, _ := ParseError(err)
				status = parsedStatus
				span.SetAttributes(attribute.String("error", err.Error()))
				c.Error(err)
			}
			span.SetStatus(semconvSrv.Status(status))
			span.SetAttributes(semconvSrv.ResponseTraceAttrs(semconvutil.ResponseTelemetry{
				StatusCode: status,
				WriteBytes: c.Response().Size,
			})...)
			// Access log entry. "status" is read from the response itself and
			// "latency" is the elapsed time.Duration as an integer, i.e.,
			// nanoseconds.
			accessLogger := logger.
				With(slog.String("log_type", "access")).
				With(slog.String("correlation_id", correlationId)).
				With(slog.String("remote_ip", c.RealIP())).
				With(slog.String("host", c.Request().Host)).
				With(slog.String("uri", c.Request().RequestURI)).
				With(slog.String("method", c.Request().Method)).
				With(slog.String("path", routePath)).
				With(slog.String("referer", c.Request().Referer())).
				With(slog.String("user_agent", c.Request().UserAgent())).
				With(slog.Int("status", c.Response().Status)).
				With(slog.Int64("latency", int64(finishTime.Sub(startTime)))).
				With(slog.String("latency_human", finishTime.Sub(startTime).String())).
				With(slog.Int64("bytes_in", c.Request().ContentLength)).
				With(slog.Int64("bytes_out", c.Response().Size))
			if err != nil {
				accessLogger.ErrorContext(ctx, err.Error())
			} else {
				accessLogger.InfoContext(ctx, "request handled")
			}
			additionalAttributes := []attribute.KeyValue{
				semconvSrv.Route(routePath),
			}
			// Record the server metrics; the elapsed time is expressed in
			// milliseconds and measured at this point rather than at
			// finishTime.
			semconvSrv.RecordMetrics(ctx, semconvutil.ServerMetricData{
				ServerName:   serverName,
				ResponseSize: c.Response().Size,
				MetricAttributes: semconvutil.MetricAttributes{
					Req:                  request,
					StatusCode:           status,
					AdditionalAttributes: additionalAttributes,
				},
				MetricData: semconvutil.MetricData{
					RequestSize: request.ContentLength,
					ElapsedTime: float64(time.Since(startTime)) / float64(time.Millisecond),
				},
			})
			return nil
		}
	}
}
// basicAuthMiddleware manages basic authentication. A request is accepted
// only when both the submitted username and password match the given
// username and password.
func basicAuthMiddleware(username, password string) echo.MiddlewareFunc {
	return middleware.BasicAuth(func(u string, p string, e echo.Context) (bool, error) {
		// Always run both comparisons and combine the results with a bitwise
		// AND. A short-circuiting && would skip the password comparison when
		// the username is wrong, letting the response time reveal whether the
		// username alone was valid.
		userMatch := subtle.ConstantTimeCompare([]byte(u), []byte(username))
		passMatch := subtle.ConstantTimeCompare([]byte(p), []byte(password))

		return userMatch&passMatch == 1, nil
	})
}
// contextMiddleware, middleware for "multipart/form-data" requests, sets the
// [Context] and related context.CancelFunc in the [echo.Context] under
// "context" and "cancel". If the process is synchronous, it also handles the
// result of a "multipart/form-data" request.
//
//	ctx := c.Get("context").(*api.Context)
//	cancel := c.Get("cancel").(context.CancelFunc)
//
// Outcome of the rest of the chain:
//   - ErrAsyncProcess: replies with a 204 No Content and leaves the context
//     un-canceled.
//   - ErrNoOutputFile: cancels the context and returns nil without sending an
//     output file.
//   - any other error: cancels the context and returns the error as is.
//   - no error: builds the output file, sends it as an attachment, and cancels
//     the context.
//
// It returns an error when no logger is available in the [echo.Context].
func contextMiddleware(fs *gotenberg.FileSystem, timeout time.Duration, bodyLimit int64, downloadFromCfg downloadFromConfig) echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			logger, _ := c.Get("logger").(*slog.Logger)
			if logger == nil {
				return errors.New("no logger in context (possible pool reuse)")
			}

			// We create a context with a timeout so that underlying processes are
			// able to stop early and correctly handle a timeout scenario.
			ctx, cancel, err := newContext(c, logger, fs, timeout, bodyLimit, downloadFromCfg)
			if err != nil {
				// Release whatever was acquired before the failure. The
				// cancel function is checked first so that a nil value on the
				// error path cannot hide the original error behind a panic.
				if cancel != nil {
					cancel()
				}
				return fmt.Errorf("create request context: %w", err)
			}

			c.Set("context", ctx)
			c.Set("cancel", cancel)

			// Call the next middleware in the chain.
			err = next(c)

			if errors.Is(err, ErrAsyncProcess) {
				// A middleware/handler tells us that it's handling the process
				// in an asynchronous fashion. Therefore, we must not cancel
				// the context nor send an output file.
				return c.NoContent(http.StatusNoContent)
			}

			// From here on the process is synchronous: the context is
			// canceled whatever the outcome.
			defer cancel()

			if errors.Is(err, ErrNoOutputFile) {
				// A middleware/handler tells us that there is no output file
				// to send. We have nothing left to do besides canceling the
				// context.
				return nil
			}

			if err != nil {
				return err
			}

			// No error, let's build the output file.
			outputPath, err := ctx.BuildOutputFile()
			if err != nil {
				return fmt.Errorf("build output file: %w", err)
			}

			// Send the output file.
			err = c.Attachment(outputPath, ctx.OutputFilename(outputPath))
			if err != nil {
				return fmt.Errorf("send response: %w", err)
			}

			return nil
		}
	}
}
// hardTimeoutMiddleware manages hard timeout scenarios, i.e., when a route
// handler fails to timeout as expected.
//
// The rest of the chain runs in its own goroutine. The middleware returns
// whichever comes first: the result of the chain, an error describing a panic
// recovered from the chain, or a "hard timeout" error wrapping
// context.DeadlineExceeded once hardTimeout has elapsed.
//
// It returns an error when no logger is available in the [echo.Context].
func hardTimeoutMiddleware(hardTimeout time.Duration) echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			// Guard the type assertion so a pooled [echo.Context] whose
			// store has been recycled under us does not crash the process.
			// See the webhook async handler for the race this protects
			// against.
			logger, _ := c.Get("logger").(*slog.Logger)
			if logger == nil {
				return errors.New("no logger in context (possible pool reuse)")
			}

			// Define a hard timeout if the route handler fails to timeout as
			// expected.
			hardTimeoutCtx, hardTimeoutCancel := context.WithTimeout(
				context.Background(),
				hardTimeout,
			)
			defer hardTimeoutCancel()

			// Buffered so that the goroutine can always deliver its result and
			// exit, even when nobody is listening anymore after a hard
			// timeout.
			errChan := make(chan error, 1)

			go func() {
				// In case of hard timeout, a panic may occur.
				// This deferred function allows us to recover from such scenarios.
				defer func() {
					r := recover()
					if r == nil {
						return
					}

					logger.DebugContext(hardTimeoutCtx, fmt.Sprintf("recovering from a panic (possible cause being a hard timeout): %v", r))

					// The panic prevented the regular send below. Report it
					// so that a panic occurring before the deadline ends the
					// request right away instead of leaving it blocked until
					// the hard timeout. The send never blocks.
					select {
					case errChan <- fmt.Errorf("panic in route handler: %v", r):
					default:
					}
				}()

				// Call the next middleware in the chain.
				errChan <- next(c)
			}()

			select {
			case err := <-errChan:
				return err
			case <-hardTimeoutCtx.Done():
				logger.DebugContext(hardTimeoutCtx, "hard timeout as the route handler did not timeout as expected")
				return fmt.Errorf("hard timeout: %w", hardTimeoutCtx.Err())
			}
		}
	}
}