package api

import (
	"context"
	"crypto/subtle"
	"errors"
	"fmt"
	"log/slog"
	"net/http"
	"path/filepath"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/labstack/echo/v4"
	"github.com/labstack/echo/v4/middleware"
	"go.opentelemetry.io/otel"
	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/propagation"
	"go.opentelemetry.io/otel/trace"

	"github.com/gotenberg/gotenberg/v8/pkg/gotenberg"
	semconvutil "github.com/gotenberg/gotenberg/v8/pkg/gotenberg/semconv"
)

var (
	// ErrAsyncProcess happens when a handler or middleware handles a request
	// in an asynchronous fashion.
	ErrAsyncProcess = errors.New("async process")

	// ErrNoOutputFile happens when a handler or middleware handles a request
	// without sending any output file.
	ErrNoOutputFile = errors.New("no output file")
)

// ParseError parses an error and returns the corresponding HTTP status and
// HTTP message.
//
// The checks run in a fixed order and the first match wins:
//
//  1. an [echo.HTTPError] yields its own code with the standard status text;
//  2. well-known sentinel errors yield a dedicated status and message;
//  3. a [gotenberg.PdfEngineInvalidArgsError] yields a 400 with its own
//     message;
//  4. an [HttpError] yields whatever its HttpError method returns;
//  5. anything else yields a 500 with the standard status text.
func ParseError(err error) (int, string) {
	var echoErr *echo.HTTPError
	if errors.As(err, &echoErr) {
		// The echo message is deliberately not forwarded; only the standard
		// text for the status code is returned.
		return echoErr.Code, http.StatusText(echoErr.Code)
	}

	switch {
	case errors.Is(err, context.DeadlineExceeded):
		return http.StatusServiceUnavailable, "The request exceeded the time limit. Increase it with --api-timeout, or reduce the workload."
	case errors.Is(err, gotenberg.ErrFiltered):
		return http.StatusForbidden, http.StatusText(http.StatusForbidden)
	case errors.Is(err, gotenberg.ErrMaximumQueueSizeExceeded):
		return http.StatusTooManyRequests, "The request queue is full. Retry shortly, or raise the limit with --chromium-max-queue-size or --libreoffice-max-queue-size."
	case errors.Is(err, gotenberg.ErrPdfSplitModeNotSupported):
		return http.StatusBadRequest, "The requested split mode is not supported, or no PDF engine could process it. Valid modes: 'intervals', 'pages'."
	case errors.Is(err, gotenberg.ErrPdfFormatNotSupported):
		return http.StatusBadRequest, "The requested PDF format is not supported, or no PDF engine could apply it. Valid formats include PDF/A-1b, PDF/A-2b, PDF/A-3b, and PDF/UA."
	case errors.Is(err, gotenberg.ErrPdfEngineMetadataValueNotSupported):
		return http.StatusBadRequest, "The requested metadata could not be written; ensure values are valid and free of control characters."
	case errors.Is(err, gotenberg.ErrPdfStampSourceNotSupported):
		return http.StatusBadRequest, "The requested stamp source is not supported, or no PDF engine could process it. Valid sources: 'text', 'image', 'pdf'."
	case errors.Is(err, gotenberg.ErrPdfRotateAngleNotSupported):
		return http.StatusBadRequest, "The requested rotation angle is not supported. Valid angles: 90, 180, 270."
	}

	if invalidArgsError, ok := errors.AsType[*gotenberg.PdfEngineInvalidArgsError](err); ok {
		return http.StatusBadRequest, invalidArgsError.Error()
	}

	var httpErr HttpError
	if errors.As(err, &httpErr) {
		return httpErr.HttpError()
	}

	// Default 500 status code.
	return http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)
}

// httpErrorHandler is the centralized HTTP error handler. It parses the error,
// returns a response as "text/plain; charset=UTF-8".
func httpErrorHandler() echo.HTTPErrorHandler {
	return func(err error, c echo.Context) {
		// Nothing sensible can be sent once the response has been committed;
		// writing again would append the error text to whatever was already
		// sent to the client.
		if c.Response().Committed {
			return
		}

		// Guard the type assertion: other middlewares in this file return an
		// error precisely when the logger is missing from the context, and
		// that error ends up here. Panicking on it would defeat their guard.
		logger, _ := c.Get("logger").(*slog.Logger)
		if logger == nil {
			logger = slog.Default()
		}

		status, message := ParseError(err)

		c.Response().Header().Add(echo.HeaderContentType, echo.MIMETextPlainCharsetUTF8)

		err = c.String(status, message)
		if err != nil {
			logger.ErrorContext(c.Request().Context(), fmt.Sprintf("send error response: %s", err.Error()))
		}
	}
}

// latencyMiddleware sets the start time in the [echo.Context] under
// "startTime". Its value will be used later to calculate request latency.
//
//	startTime := c.Get("startTime").(time.Time)
func latencyMiddleware() echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			// First piece for calculating the latency.
			startTime := time.Now()
			c.Set("startTime", startTime)

			// Call the next middleware in the chain.
			return next(c)
		}
	}
}

// rootPathMiddleware sets the root path in the [echo.Context] under
// "rootPath". Its value may be used to skip a middleware execution based on a
// request URI.
//
//	rootPath := c.Get("rootPath").(string)
//	healthURI := fmt.Sprintf("%s/health", rootPath)
//
//	// Skip the middleware if it's the health check URI.
//	if c.Request().RequestURI == healthURI {
//	  // Call the next middleware in the chain.
//	  return next(c)
//	}
func rootPathMiddleware(rootPath string) echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			c.Set("rootPath", rootPath)

			// Call the next middleware in the chain.
			return next(c)
		}
	}
}

// outputFilenameMiddleware sets the output filename in the [echo.Context]
// under "outputFilename". The value comes from the "Gotenberg-Output-Filename"
// request header and is reduced to its base name; it is the empty string when
// the header is absent.
//
//	outputFilename := c.Get("outputFilename").(string)
func outputFilenameMiddleware() echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			filename := c.Request().Header.Get("Gotenberg-Output-Filename")

			// Keep only the last path element of the client-supplied name.
			// The empty string is left untouched because filepath.Base("")
			// returns ".".
			// See https://github.com/gotenberg/gotenberg/issues/1227.
			if filename != "" {
				filename = filepath.Base(filename)
			}

			c.Set("outputFilename", filename)

			// Call the next middleware in the chain.
			return next(c)
		}
	}
}

// telemetryMiddleware manages telemetry. It sets the correlation ID in the
// [echo.Context] under "correlationId".
//
//	correlationIdHeader := c.Get("correlationIdHeader").(string)
//	correlationId := c.Get("correlationId").(string)
//
// It expects "startTime" and "rootPath" to already be present in the
// [echo.Context]. It also sets the request-scoped "logger", starts a server
// span, writes the access log entry and records the HTTP server metrics.
//
// Errors returned by the rest of the chain are passed to c.Error from here,
// so this middleware itself always returns nil.
func telemetryMiddleware(logger *slog.Logger, serverName, correlationIdHeader string, disableTelemetryForPaths []string) echo.MiddlewareFunc {
	meter := gotenberg.Meter()
	semconvSrv := semconvutil.NewHTTPServer(meter)

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			startTime := c.Get("startTime").(time.Time)
			rootPath := c.Get("rootPath").(string)

			request := c.Request()
			savedCtx := request.Context()

			// Restore the original request context on the way out.
			defer func() {
				request = request.WithContext(savedCtx)
				c.SetRequest(request)
			}()

			routePath := request.URL.Path
			if routePath == "" {
				routePath = "/"
			}

			// Evaluate if we should skip telemetry for this path.
			skipTelemetry := false
			for _, path := range disableTelemetryForPaths {
				if request.RequestURI == rootPath+path {
					skipTelemetry = true
					break
				}
			}

			if skipTelemetry {
				// Downstream code still expects a logger; give it one that
				// drops everything.
				c.Set("logger", slog.New(slog.DiscardHandler))

				err := next(c)
				if err != nil {
					c.Error(err)
				}

				return nil
			}

			correlationId := request.Header.Get(correlationIdHeader)
			if correlationId == "" {
				correlationId = uuid.NewString()
			}

			c.Set("correlationIdHeader", correlationIdHeader)
			c.Set("correlationId", correlationId)

			ctx := otel.GetTextMapPropagator().Extract(savedCtx, propagation.HeaderCarrier(request.Header))

			rAttr := semconvSrv.Route(routePath)
			opts := []trace.SpanStartOption{
				trace.WithAttributes(
					semconvSrv.RequestTraceAttrs(serverName, request, semconvutil.RequestTraceAttrsOpts{})...,
				),
				trace.WithSpanKind(trace.SpanKindServer),
				trace.WithAttributes(rAttr),
			}

			spanName := strings.ToUpper(c.Request().Method) + " " + routePath

			tracer := gotenberg.Tracer()
			ctx, span := tracer.Start(ctx, spanName, opts...)
			defer span.End()

			otel.GetTextMapPropagator().Inject(ctx, propagation.HeaderCarrier(c.Response().Header()))

			c.Response().Header().Set(correlationIdHeader, correlationId)

			// From here on, downstream handlers see the span context.
			c.SetRequest(c.Request().WithContext(ctx))

			appLogger := logger.
				With(slog.String("log_type", "application")).
				With(slog.String("correlation_id", correlationId))

			// The logger name is the URL path with the root path and every
			// slash removed.
			loggerName := strings.ReplaceAll(
				strings.ReplaceAll(c.Request().URL.Path, rootPath, ""),
				"/",
				"",
			)
			c.Set("logger", appLogger.With(slog.String("logger", loggerName)))

			// Call the next middleware in the chain.
			err := next(c)

			finishTime := time.Now()

			status := c.Response().Status
			if err != nil {
				parsedStatus, _ := ParseError(err)
				status = parsedStatus
				span.SetAttributes(attribute.String("error", err.Error()))
				c.Error(err)
			}

			span.SetStatus(semconvSrv.Status(status))
			span.SetAttributes(semconvSrv.ResponseTraceAttrs(semconvutil.ResponseTelemetry{
				StatusCode: status,
				WriteBytes: c.Response().Size,
			})...)

			accessLogger := logger.
				With(slog.String("log_type", "access")).
				With(slog.String("correlation_id", correlationId)).
				With(slog.String("remote_ip", c.RealIP())).
				With(slog.String("host", c.Request().Host)).
				With(slog.String("uri", c.Request().RequestURI)).
				With(slog.String("method", c.Request().Method)).
				With(slog.String("path", routePath)).
				With(slog.String("referer", c.Request().Referer())).
				With(slog.String("user_agent", c.Request().UserAgent())).
				With(slog.Int("status", c.Response().Status)).
				With(slog.Int64("latency", int64(finishTime.Sub(startTime)))).
				With(slog.String("latency_human", finishTime.Sub(startTime).String())).
				With(slog.Int64("bytes_in", c.Request().ContentLength)).
				With(slog.Int64("bytes_out", c.Response().Size))

			if err != nil {
				accessLogger.ErrorContext(ctx, err.Error())
			} else {
				accessLogger.InfoContext(ctx, "request handled")
			}

			additionalAttributes := []attribute.KeyValue{
				semconvSrv.Route(routePath),
			}

			semconvSrv.RecordMetrics(ctx, semconvutil.ServerMetricData{
				ServerName:   serverName,
				ResponseSize: c.Response().Size,
				MetricAttributes: semconvutil.MetricAttributes{
					Req:                  request,
					StatusCode:           status,
					AdditionalAttributes: additionalAttributes,
				},
				MetricData: semconvutil.MetricData{
					RequestSize: request.ContentLength,
					ElapsedTime: float64(time.Since(startTime)) / float64(time.Millisecond),
				},
			})

			return nil
		}
	}
}

// basicAuthMiddleware manages basic authentication. Credentials are accepted
// only when both the username and the password match the configured values.
func basicAuthMiddleware(username, password string) echo.MiddlewareFunc {
	return middleware.BasicAuth(func(u string, p string, _ echo.Context) (bool, error) {
		// Always run both comparisons: short-circuiting on a username
		// mismatch would skip the password comparison and let response
		// timing reveal which of the two credentials was wrong.
		userMatch := subtle.ConstantTimeCompare([]byte(u), []byte(username))
		passMatch := subtle.ConstantTimeCompare([]byte(p), []byte(password))

		return userMatch&passMatch == 1, nil
	})
}

// contextMiddleware, middleware for "multipart/form-data" requests, sets the
// [Context] and related context.CancelFunc in the [echo.Context] under
// "context" and "cancel". If the process is synchronous, it also handles the
// result of a "multipart/form-data" request.
//
//	ctx := c.Get("context").(*api.Context)
//	cancel := c.Get("cancel").(context.CancelFunc)
//
// It returns an error when no logger is available in the [echo.Context], when
// the request context cannot be created, when the rest of the chain fails, or
// when the output file cannot be built or sent.
func contextMiddleware(fs *gotenberg.FileSystem, timeout time.Duration, bodyLimit int64, downloadFromCfg downloadFromConfig) echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			logger, _ := c.Get("logger").(*slog.Logger)
			if logger == nil {
				return errors.New("no logger in context (possible pool reuse)")
			}

			// We create a context with a timeout so that underlying processes are
			// able to stop early and correctly handle a timeout scenario.
			ctx, cancel, err := newContext(c, logger, fs, timeout, bodyLimit, downloadFromCfg)
			if err != nil {
				cancel()

				return fmt.Errorf("create request context: %w", err)
			}

			c.Set("context", ctx)
			c.Set("cancel", cancel)

			// Call the next middleware in the chain.
			err = next(c)

			if errors.Is(err, ErrAsyncProcess) {
				// A middleware/handler tells us that it's handling the process
				// in an asynchronous fashion. Therefore, we must not cancel
				// the context nor send an output file.
				return c.NoContent(http.StatusNoContent)
			}

			// Every remaining path is synchronous: cancel on the way out.
			defer cancel()

			if errors.Is(err, ErrNoOutputFile) {
				// A middleware/handler tells us that it handled the request
				// without producing an output file, so there is nothing to
				// build or send. The context is still canceled by the
				// deferred call above.
				return nil
			}

			if err != nil {
				return err
			}

			// No error, let's build the output file.
			outputPath, err := ctx.BuildOutputFile()
			if err != nil {
				return fmt.Errorf("build output file: %w", err)
			}

			// Send the output file.
			err = c.Attachment(outputPath, ctx.OutputFilename(outputPath))
			if err != nil {
				return fmt.Errorf("send response: %w", err)
			}

			return nil
		}
	}
}

// hardTimeoutMiddleware manages hard timeout scenarios, i.e., when a route
// handler fails to timeout as expected.
//
// The rest of the chain runs in its own goroutine. The middleware returns
// whichever comes first: the result of the chain, an error describing a panic
// recovered from the chain, or a "hard timeout" error wrapping the context
// error once hardTimeout has elapsed. The goroutine is not stopped on a hard
// timeout; it keeps running until the chain returns.
func hardTimeoutMiddleware(hardTimeout time.Duration) echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			// Guard the type assertion so a pooled [echo.Context] whose
			// store has been recycled under us does not crash the process.
			// See the webhook async handler for the race this protects
			// against.
			logger, _ := c.Get("logger").(*slog.Logger)
			if logger == nil {
				return errors.New("no logger in context (possible pool reuse)")
			}

			// Define a hard timeout if the route handler fails to timeout as
			// expected.
			hardTimeoutCtx, hardTimeoutCancel := context.WithTimeout(
				context.Background(),
				hardTimeout,
			)
			defer hardTimeoutCancel()

			// Buffered so the goroutine can always deliver its result, even
			// when nobody is listening anymore after a hard timeout.
			errChan := make(chan error, 1)

			go func() {
				// In case of hard timeout, a panic may occur.
				// This deferred function allows us to recover from such scenarios.
				defer func() {
					r := recover()
					if r == nil {
						return
					}

					logger.DebugContext(hardTimeoutCtx, fmt.Sprintf("recovering from a panic (possible cause being a hard timeout): %v", r))

					// Report the panic to the waiting request instead of
					// leaving it blocked until the hard timeout fires. The
					// send never blocks: the buffer is empty because
					// next(c) did not return.
					select {
					case errChan <- fmt.Errorf("route handler panicked: %v", r):
					default:
					}
				}()

				// Call the next middleware in the chain.
				errChan <- next(c)
			}()

			select {
			case err := <-errChan:
				return err
			case <-hardTimeoutCtx.Done():
				logger.DebugContext(hardTimeoutCtx, "hard timeout as the route handler did not timeout as expected")

				return fmt.Errorf("hard timeout: %w", hardTimeoutCtx.Err())
			}
		}
	}
}