feat(api): add 'downloadForm' form field

This commit is contained in:
Julien Neuhart
2024-09-15 16:01:50 +02:00
parent d21f87b543
commit f2b6bd3d4b
10 changed files with 569 additions and 63 deletions
+9 -1
View File
@@ -37,6 +37,10 @@ API_TRACE_HEADER=Gotenberg-Trace
API_ENABLE_BASIC_AUTH=false
GOTENBERG_API_BASIC_AUTH_USERNAME=
GOTENBERG_API_BASIC_AUTH_PASSWORD=
API-DOWNLOAD-FROM-ALLOW-LIST=
API-DOWNLOAD-FROM-DENY-LIST=
API-DOWNLOAD-FROM-FROM-MAX-RETRY=4
API-DISABLE-DOWNLOAD-FROM=false
API_DISABLE_HEALTH_CHECK_LOGGING=false
CHROMIUM_RESTART_AFTER=0
CHROMIUM_MAX_QUEUE_SIZE=0
@@ -95,6 +99,10 @@ run: ## Start a Gotenberg container
--api-root-path=$(API_ROOT_PATH) \
--api-trace-header=$(API_TRACE_HEADER) \
--api-enable-basic-auth=$(API_ENABLE_BASIC_AUTH) \
--api-download-from-allow-list=$(API-DOWNLOAD-FROM-ALLOW-LIST) \
--api-download-from-deny-list=$(API-DOWNLOAD-FROM-DENY-LIST) \
--api-download-from-max-retry=$(API-DOWNLOAD-FROM-FROM-MAX-RETRY) \
--api-disable-download-from=$(API-DISABLE-DOWNLOAD-FROM) \
--api-disable-health-check-logging=$(API_DISABLE_HEALTH_CHECK_LOGGING) \
--chromium-restart-after=$(CHROMIUM_RESTART_AFTER) \
--chromium-auto-start=$(CHROMIUM_AUTO_START) \
@@ -151,7 +159,7 @@ build-tests: ## Build the tests' Docker image
tests: ## Start the testing environment
docker run --rm -it \
-v $(PWD):/tests \
$(DOCKER_REGISTRY)/$(DOCKER_REPOSITORY)g:$(GOTENBERG_VERSION)-tests \
$(DOCKER_REGISTRY)/$(DOCKER_REPOSITORY):$(GOTENBERG_VERSION)-tests \
bash
.PHONY: tests-once
+44 -1
View File
@@ -1,6 +1,11 @@
package gotenberg
import "go.uber.org/zap"
import (
"fmt"
"github.com/hashicorp/go-retryablehttp"
"go.uber.org/zap"
)
// LoggerProvider is an interface for a module that supplies a method for
// creating a [zap.Logger] instance for use by other modules.
@@ -12,3 +17,41 @@ import "go.uber.org/zap"
type LoggerProvider interface {
Logger(mod Module) (*zap.Logger, error)
}
// LeveledLogger is wrapper around a [zap.Logger] so that it may be used by a
// [retryablehttp.Client].
type LeveledLogger struct {
logger *zap.Logger
}
// NewLeveledLogger instantiates a [LeveledLogger].
func NewLeveledLogger(logger *zap.Logger) *LeveledLogger {
return &LeveledLogger{
logger: logger,
}
}
// Error logs a message at error level using the wrapped zap.Logger.
func (leveled LeveledLogger) Error(msg string, keysAndValues ...interface{}) {
leveled.logger.Error(fmt.Sprintf("%s: %+v", msg, keysAndValues))
}
// Warn logs a message at warning level using the wrapped zap.Logger.
func (leveled LeveledLogger) Warn(msg string, keysAndValues ...interface{}) {
leveled.logger.Warn(fmt.Sprintf("%s: %+v", msg, keysAndValues))
}
// Info logs a message at info level using the wrapped zap.Logger.
func (leveled LeveledLogger) Info(msg string, keysAndValues ...interface{}) {
leveled.logger.Info(fmt.Sprintf("%s: %+v", msg, keysAndValues))
}
// Debug logs a message at debug level using the wrapped zap.Logger.
func (leveled LeveledLogger) Debug(msg string, keysAndValues ...interface{}) {
leveled.logger.Debug(fmt.Sprintf("%s: %+v", msg, keysAndValues))
}
// Interface guards.
var (
_ retryablehttp.LeveledLogger = (*LeveledLogger)(nil)
)
@@ -1,4 +1,4 @@
package webhook
package gotenberg
import (
"testing"
@@ -7,17 +7,17 @@ import (
)
func TestLeveledLogger_Error(t *testing.T) {
leveledLogger{logger: zap.NewNop()}.Error("foo")
NewLeveledLogger(zap.NewNop()).Error("foo")
}
func TestLeveledLogger_Warn(t *testing.T) {
leveledLogger{logger: zap.NewNop()}.Warn("foo")
NewLeveledLogger(zap.NewNop()).Warn("foo")
}
func TestLeveledLogger_Info(t *testing.T) {
leveledLogger{logger: zap.NewNop()}.Info("foo")
NewLeveledLogger(zap.NewNop()).Info("foo")
}
func TestLeveledLogger_Debug(t *testing.T) {
leveledLogger{logger: zap.NewNop()}.Debug("foo")
NewLeveledLogger(zap.NewNop()).Debug("foo")
}
+20 -1
View File
@@ -10,6 +10,7 @@ import (
"time"
"github.com/alexliesenfeld/health"
"github.com/dlclark/regexp2"
"github.com/labstack/echo/v4"
flag "github.com/spf13/pflag"
"go.uber.org/multierr"
@@ -36,6 +37,7 @@ type Api struct {
traceHeader string
basicAuthUsername string
basicAuthPassword string
downloadFromCfg downloadFromConfig
disableHealthCheckLogging bool
routes []Route
@@ -47,6 +49,13 @@ type Api struct {
srv *echo.Echo
}
type downloadFromConfig struct {
allowList *regexp2.Regexp
denyList *regexp2.Regexp
maxRetry int
disable bool
}
// Router is a module interface which adds routes to the [Api].
type Router interface {
Routes() ([]Route, error)
@@ -168,6 +177,10 @@ func (a *Api) Descriptor() gotenberg.ModuleDescriptor {
fs.String("api-root-path", "/", "Set the root path of the API - for service discovery via URL paths")
fs.String("api-trace-header", "Gotenberg-Trace", "Set the header name to use for identifying requests")
fs.Bool("api-enable-basic-auth", false, "Enable basic authentication - will look for the GOTENBERG_API_BASIC_AUTH_USERNAME and GOTENBERG_API_BASIC_AUTH_PASSWORD environment variables")
fs.String("api-download-from-allow-list", "", "Set the allowed URLs for the download from feature using a regular expression")
fs.String("api-download-from-deny-list", "", "Set the denied URLs for the download from feature using a regular expression")
fs.Int("api-download-from-max-retry", 4, "Set the maximum number of retries for the download from feature")
fs.Bool("api-disable-download-from", false, "Disable the download from feature")
fs.Bool("api-disable-health-check-logging", false, "Disable health check logging")
return fs
}(),
@@ -185,6 +198,12 @@ func (a *Api) Provision(ctx *gotenberg.Context) error {
a.timeout = flags.MustDuration("api-timeout")
a.rootPath = flags.MustString("api-root-path")
a.traceHeader = flags.MustString("api-trace-header")
a.downloadFromCfg = downloadFromConfig{
allowList: flags.MustRegexp("api-download-from-allow-list"),
denyList: flags.MustRegexp("api-download-from-deny-list"),
maxRetry: flags.MustInt("api-download-from-max-retry"),
disable: flags.MustBool("api-disable-download-from"),
}
a.disableHealthCheckLogging = flags.MustBool("api-disable-health-check-logging")
// Port from env?
@@ -436,7 +455,7 @@ func (a *Api) Start() error {
}
if route.IsMultipart {
middlewares = append(middlewares, contextMiddleware(a.fs, a.timeout))
middlewares = append(middlewares, contextMiddleware(a.fs, a.timeout, a.downloadFromCfg))
for _, externalMultipartMiddleware := range externalMultipartMiddlewares {
middlewares = append(middlewares, externalMultipartMiddleware.Handler)
+152 -2
View File
@@ -3,9 +3,11 @@ package api
import (
"compress/flate"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"mime"
"mime/multipart"
"net/http"
"os"
@@ -14,9 +16,11 @@ import (
"time"
"github.com/google/uuid"
"github.com/hashicorp/go-retryablehttp"
"github.com/labstack/echo/v4"
"github.com/mholt/archiver/v3"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
"golang.org/x/text/unicode/norm"
"github.com/gotenberg/gotenberg/v8/pkg/gotenberg"
@@ -46,6 +50,14 @@ type Context struct {
context.Context
}
type downloadFrom struct {
// Url is the URL to download a file from.
Url string `json:"url"`
// ExtraHttpHeaders are the HTTP headers to send alongside.
ExtraHttpHeaders map[string]string `json:"extraHttpHeaders"`
}
type osPathRename struct{}
func (o *osPathRename) Rename(oldpath, newpath string) error {
@@ -53,7 +65,7 @@ func (o *osPathRename) Rename(oldpath, newpath string) error {
}
// newContext returns a [Context] by parsing a "multipart/form-data" request.
func newContext(echoCtx echo.Context, logger *zap.Logger, fs *gotenberg.FileSystem, timeout time.Duration) (*Context, context.CancelFunc, error) {
func newContext(echoCtx echo.Context, logger *zap.Logger, fs *gotenberg.FileSystem, timeout time.Duration, downloadFromCfg downloadFromConfig, traceHeader, trace string) (*Context, context.CancelFunc, error) {
processCtx, processCancel := context.WithTimeout(context.Background(), timeout)
ctx := &Context{
@@ -126,6 +138,144 @@ func newContext(echoCtx echo.Context, logger *zap.Logger, fs *gotenberg.FileSyst
ctx.values = form.Value
ctx.files = make(map[string]string)
// First, try to download files listed in the "downloadFrom" form field, if
// any.
raw, ok := ctx.values["downloadFrom"]
if !downloadFromCfg.disable && ok {
var dls []downloadFrom
err = json.Unmarshal([]byte(raw[0]), &dls)
if err != nil {
return nil, cancel, WrapError(
fmt.Errorf("unmarshal json: %w", err),
NewSentinelHttpError(http.StatusBadRequest, fmt.Sprintf("Invalid 'downloadFrom' form field value: %s", err)),
)
}
eg, _ := errgroup.WithContext(ctx)
for i, dl := range dls {
eg.Go(func() error {
deadline, ok := ctx.Deadline()
if !ok {
// Should not happen, as context is created with a timeout.
return errors.New("context has no deadline")
}
if strings.TrimSpace(dl.Url) == "" {
return WrapError(
errors.New("empty download from URL"),
NewSentinelHttpError(http.StatusBadRequest, fmt.Sprintf("Invalid 'downloadFrom' form field entry %d: URL must be set", i)),
)
}
err := gotenberg.FilterDeadline(downloadFromCfg.allowList, downloadFromCfg.denyList, dl.Url, deadline)
if err != nil {
return fmt.Errorf("filter URL: %w", err)
}
logger.Debug(fmt.Sprintf("download file from '%s'", dl.Url))
req, err := retryablehttp.NewRequest(http.MethodGet, dl.Url, nil)
if err != nil {
return fmt.Errorf("create request to '%s': %w", dl.Url, err)
}
req.Header.Set("User-Agent", "Gotenberg")
for key, value := range dl.ExtraHttpHeaders {
req.Header.Set(key, value)
}
req.Header.Set(traceHeader, trace)
client := &retryablehttp.Client{
HTTPClient: &http.Client{
Timeout: time.Until(deadline),
},
RetryMax: downloadFromCfg.maxRetry,
RetryWaitMin: time.Duration(1) * time.Second,
RetryWaitMax: time.Until(deadline),
Logger: gotenberg.NewLeveledLogger(logger),
CheckRetry: retryablehttp.DefaultRetryPolicy,
Backoff: retryablehttp.DefaultBackoff,
}
resp, err := client.Do(req)
if err != nil {
return WrapError(
fmt.Errorf("download file from to '%s': %w", dl.Url, err),
NewSentinelHttpError(http.StatusBadRequest, fmt.Sprintf("Unable to download file from '%s': %s", dl.Url, err)),
)
}
defer func() {
err := resp.Body.Close()
if err != nil {
logger.Error(fmt.Sprintf("close response body from '%s': %s", dl.Url, err))
}
}()
if resp.StatusCode != http.StatusOK {
return WrapError(
fmt.Errorf("download file from to '%s': got status: '%s'", dl.Url, resp.Status),
NewSentinelHttpError(http.StatusBadRequest, fmt.Sprintf("Unable to download file from '%s': got status: '%s'", dl.Url, resp.Status)),
)
}
contentDisposition := resp.Header.Get("Content-Disposition")
if contentDisposition == "" {
return WrapError(
fmt.Errorf("no 'Content-Disposition' header from '%s'", dl.Url),
NewSentinelHttpError(http.StatusBadRequest, fmt.Sprintf("No 'Content-Disposition' header from '%s'", dl.Url)),
)
}
_, params, err := mime.ParseMediaType(contentDisposition)
if err != nil {
return WrapError(
fmt.Errorf("parse 'Content-Disposition' header '%s' from '%s': %w", contentDisposition, dl.Url, err),
NewSentinelHttpError(http.StatusBadRequest, fmt.Sprintf("Invalid 'Content-Disposition' header '%s' from '%s': %s", contentDisposition, dl.Url, err)),
)
}
filename, ok := params["filename"]
if !ok {
return WrapError(
fmt.Errorf("get filename from 'Content-Disposition' header '%s' from '%s'", contentDisposition, dl.Url),
NewSentinelHttpError(http.StatusBadRequest, fmt.Sprintf("Invalid 'Content-Disposition' header '%s' from '%s': no filename", contentDisposition, dl.Url)),
)
}
// Avoid directory traversal and make sure filename characters are
// normalized.
// See: https://github.com/gotenberg/gotenberg/issues/662.
filename = norm.NFC.String(filepath.Base(filename))
path := fmt.Sprintf("%s/%s", ctx.dirPath, filename)
out, err := os.Create(path)
if err != nil {
return fmt.Errorf("create local file: %w", err)
}
defer func() {
err := out.Close()
if err != nil {
logger.Error(fmt.Sprintf("close local file: %s", err))
}
}()
_, err = io.Copy(out, resp.Body)
if err != nil {
return fmt.Errorf("copy downloaded file from '%s' to local file: %v", dl.Url, err)
}
ctx.files[filename] = path
return nil
})
}
err = eg.Wait()
if err != nil {
return ctx, cancel, err
}
}
copyToDisk := func(fh *multipart.FileHeader) error {
in, err := fh.Open()
if err != nil {
@@ -149,7 +299,6 @@ func newContext(echoCtx echo.Context, logger *zap.Logger, fs *gotenberg.FileSyst
if err != nil {
return fmt.Errorf("create local file: %w", err)
}
defer func() {
err := out.Close()
if err != nil {
@@ -167,6 +316,7 @@ func newContext(echoCtx echo.Context, logger *zap.Logger, fs *gotenberg.FileSyst
return nil
}
// Then, copy the form files, if any.
for _, files := range form.File {
for _, fh := range files {
err = copyToDisk(fh)
+317 -1
View File
@@ -2,6 +2,7 @@ package api
import (
"bytes"
"context"
"errors"
"fmt"
"io"
@@ -15,6 +16,7 @@ import (
"testing"
"time"
"github.com/dlclark/regexp2"
"github.com/google/uuid"
"github.com/labstack/echo/v4"
"go.uber.org/zap"
@@ -75,9 +77,27 @@ func TestOsPathRename_Rename(t *testing.T) {
}
func TestNewContext(t *testing.T) {
defaultAllowList, err := regexp2.Compile("", 0)
if err != nil {
t.Fatalf("expected no error but got: %v", err)
}
defaultDenyList, err := regexp2.Compile("", 0)
if err != nil {
t.Fatalf("expected no error but got: %v", err)
}
defaultDownloadFromCfg := downloadFromConfig{
allowList: defaultAllowList,
denyList: defaultDenyList,
maxRetry: 1,
disable: false,
}
for _, tc := range []struct {
scenario string
request *http.Request
downloadFromCfg downloadFromConfig
downloadFromSrv *echo.Echo
expectContext *Context
expectError bool
expectHttpError bool
expectHttpStatus int
@@ -123,6 +143,236 @@ func TestNewContext(t *testing.T) {
expectHttpError: true,
expectHttpStatus: http.StatusBadRequest,
},
{
scenario: "invalid downloadFrom form field: cannot unmarshal",
request: func() *http.Request {
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
defer func() {
err := writer.Close()
if err != nil {
t.Fatalf("expected no error but got: %v", err)
}
}()
err := writer.WriteField("downloadFrom", "foo")
if err != nil {
t.Fatalf("expected no error but got: %v", err)
}
req := httptest.NewRequest(http.MethodPost, "/", body)
req.Header.Set(echo.HeaderContentType, writer.FormDataContentType())
return req
}(),
downloadFromCfg: defaultDownloadFromCfg,
expectError: true,
expectHttpError: true,
expectHttpStatus: http.StatusBadRequest,
},
{
scenario: "invalid downloadFrom form field: no URL",
request: func() *http.Request {
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
defer func() {
err := writer.Close()
if err != nil {
t.Fatalf("expected no error but got: %v", err)
}
}()
err := writer.WriteField("downloadFrom", `[{}]`)
if err != nil {
t.Fatalf("expected no error but got: %v", err)
}
req := httptest.NewRequest(http.MethodPost, "/", body)
req.Header.Set(echo.HeaderContentType, writer.FormDataContentType())
return req
}(),
downloadFromCfg: defaultDownloadFromCfg,
expectError: true,
expectHttpError: true,
expectHttpStatus: http.StatusBadRequest,
},
{
scenario: "invalid downloadFrom form field: filtered URL",
request: func() *http.Request {
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
defer func() {
err := writer.Close()
if err != nil {
t.Fatalf("expected no error but got: %v", err)
}
}()
err := writer.WriteField("downloadFrom", `[{"url":"https://foo.bar"}]`)
if err != nil {
t.Fatalf("expected no error but got: %v", err)
}
req := httptest.NewRequest(http.MethodPost, "/", body)
req.Header.Set(echo.HeaderContentType, writer.FormDataContentType())
return req
}(),
downloadFromCfg: func() downloadFromConfig {
denyList, err := regexp2.Compile("https://foo.bar", 0)
if err != nil {
t.Fatalf("expected no error but got: %v", err)
}
return downloadFromConfig{allowList: defaultAllowList, denyList: denyList, maxRetry: 1, disable: false}
}(),
expectError: true,
},
{
scenario: "invalid downloadFrom form field: unreachable URL",
request: func() *http.Request {
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
defer func() {
err := writer.Close()
if err != nil {
t.Fatalf("expected no error but got: %v", err)
}
}()
err := writer.WriteField("downloadFrom", `[{"url":"http://localhost:80/"}]`)
if err != nil {
t.Fatalf("expected no error but got: %v", err)
}
req := httptest.NewRequest(http.MethodPost, "/", body)
req.Header.Set(echo.HeaderContentType, writer.FormDataContentType())
return req
}(),
downloadFromCfg: defaultDownloadFromCfg,
expectError: true,
expectHttpError: true,
expectHttpStatus: http.StatusBadRequest,
},
{
scenario: "invalid downloadFrom form field: invalid status code",
request: func() *http.Request {
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
defer func() {
err := writer.Close()
if err != nil {
t.Fatalf("expected no error but got: %v", err)
}
}()
err := writer.WriteField("downloadFrom", `[{"url":"http://localhost:80/"}]`)
if err != nil {
t.Fatalf("expected no error but got: %v", err)
}
req := httptest.NewRequest(http.MethodPost, "/", body)
req.Header.Set(echo.HeaderContentType, writer.FormDataContentType())
return req
}(),
downloadFromSrv: func() *echo.Echo {
srv := echo.New()
srv.HideBanner = true
srv.GET("/", func(c echo.Context) error {
return c.String(http.StatusNotFound, http.StatusText(http.StatusNotFound))
})
return srv
}(),
downloadFromCfg: defaultDownloadFromCfg,
expectError: true,
expectHttpError: true,
expectHttpStatus: http.StatusBadRequest,
},
{
scenario: "invalid downloadFrom form field: no 'Content-Disposition' header",
request: func() *http.Request {
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
defer func() {
err := writer.Close()
if err != nil {
t.Fatalf("expected no error but got: %v", err)
}
}()
err := writer.WriteField("downloadFrom", `[{"url":"http://localhost:80/"}]`)
if err != nil {
t.Fatalf("expected no error but got: %v", err)
}
req := httptest.NewRequest(http.MethodPost, "/", body)
req.Header.Set(echo.HeaderContentType, writer.FormDataContentType())
return req
}(),
downloadFromSrv: func() *echo.Echo {
srv := echo.New()
srv.HideBanner = true
srv.GET("/", func(c echo.Context) error {
return c.String(http.StatusOK, http.StatusText(http.StatusOK))
})
return srv
}(),
downloadFromCfg: defaultDownloadFromCfg,
expectError: true,
expectHttpError: true,
expectHttpStatus: http.StatusBadRequest,
},
{
scenario: "invalid downloadFrom form field: malformed 'Content-Disposition' header",
request: func() *http.Request {
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
defer func() {
err := writer.Close()
if err != nil {
t.Fatalf("expected no error but got: %v", err)
}
}()
err := writer.WriteField("downloadFrom", `[{"url":"http://localhost:80/"}]`)
if err != nil {
t.Fatalf("expected no error but got: %v", err)
}
req := httptest.NewRequest(http.MethodPost, "/", body)
req.Header.Set(echo.HeaderContentType, writer.FormDataContentType())
return req
}(),
downloadFromSrv: func() *echo.Echo {
srv := echo.New()
srv.HideBanner = true
srv.GET("/", func(c echo.Context) error {
c.Response().Header().Set(echo.HeaderContentDisposition, ";;")
return c.String(http.StatusOK, http.StatusText(http.StatusOK))
})
return srv
}(),
downloadFromCfg: defaultDownloadFromCfg,
expectError: true,
expectHttpError: true,
expectHttpStatus: http.StatusBadRequest,
},
{
scenario: "invalid downloadFrom form field: no filename parameter in 'Content-Disposition' header",
request: func() *http.Request {
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
defer func() {
err := writer.Close()
if err != nil {
t.Fatalf("expected no error but got: %v", err)
}
}()
err := writer.WriteField("downloadFrom", `[{"url":"http://localhost:80/"}]`)
if err != nil {
t.Fatalf("expected no error but got: %v", err)
}
req := httptest.NewRequest(http.MethodPost, "/", body)
req.Header.Set(echo.HeaderContentType, writer.FormDataContentType())
return req
}(),
downloadFromSrv: func() *echo.Echo {
srv := echo.New()
srv.HideBanner = true
srv.GET("/", func(c echo.Context) error {
c.Response().Header().Set(echo.HeaderContentDisposition, "inline;")
return c.String(http.StatusOK, http.StatusText(http.StatusOK))
})
return srv
}(),
downloadFromCfg: defaultDownloadFromCfg,
expectError: true,
expectHttpError: true,
expectHttpStatus: http.StatusBadRequest,
},
{
scenario: "success",
request: func() *http.Request {
@@ -146,17 +396,69 @@ func TestNewContext(t *testing.T) {
if err != nil {
t.Fatalf("expected no error but got: %v", err)
}
err = writer.WriteField("downloadFrom", `[{"url":"http://localhost:80/","extraHttpHeaders":{"X-Foo":"Bar"}}]`)
if err != nil {
t.Fatalf("expected no error but got: %v", err)
}
req := httptest.NewRequest(http.MethodPost, "/", body)
req.Header.Set(echo.HeaderContentType, writer.FormDataContentType())
return req
}(),
downloadFromSrv: func() *echo.Echo {
srv := echo.New()
srv.HideBanner = true
srv.GET("/", func(c echo.Context) error {
if c.Request().Header.Get("User-Agent") != "Gotenberg" {
t.Fatalf("expected 'Gotenberg' from header 'User-Agent', but got '%s'", c.Request().Header.Get("User-Agent"))
}
if c.Request().Header.Get("X-Foo") != "Bar" {
t.Fatalf("expected 'Bar' from header 'X-Foo', but got '%s'", c.Request().Header.Get("X-Foo"))
}
if c.Request().Header.Get("Gotenberg-Trace") != "123" {
t.Fatalf("expected '123' from header 'Gotenberg-Trace', but got '%s'", c.Request().Header.Get("Gotenberg-Trace"))
}
c.Response().Header().Set(echo.HeaderContentDisposition, `attachment; filename="bar.txt"`)
c.Response().Header().Set(echo.HeaderContentType, "text/plain")
return c.String(http.StatusOK, http.StatusText(http.StatusOK))
})
return srv
}(),
downloadFromCfg: defaultDownloadFromCfg,
expectContext: &Context{
values: map[string][]string{
"foo": {"foo"},
"downloadFrom": {
`[{"url":"http://localhost:80/","extraHttpHeaders":{"X-Foo":"Bar"}}]`,
},
},
files: map[string]string{
"foo.txt": "foo.txt",
"bar.txt": "bar.txt", // downloadFrom.
},
},
expectError: false,
expectHttpError: false,
},
} {
t.Run(tc.scenario, func(t *testing.T) {
if tc.downloadFromSrv != nil {
go func() {
err := tc.downloadFromSrv.Start(":80")
if !errors.Is(err, http.ErrServerClosed) {
t.Error(err)
return
}
}()
defer func() {
err := tc.downloadFromSrv.Shutdown(context.TODO())
if err != nil {
t.Error(err)
}
}()
}
handler := func(c echo.Context) error {
_, cancel, err := newContext(c, zap.NewNop(), gotenberg.NewFileSystem(), time.Duration(10)*time.Second)
ctx, cancel, err := newContext(c, zap.NewNop(), gotenberg.NewFileSystem(), time.Duration(10)*time.Second, tc.downloadFromCfg, "Gotenberg-Trace", "123")
defer cancel()
// Context already cancelled.
defer cancel()
@@ -165,6 +467,20 @@ func TestNewContext(t *testing.T) {
return err
}
if tc.expectContext != nil {
if !reflect.DeepEqual(tc.expectContext.values, ctx.values) {
t.Fatalf("expected context.values to be %v but got %v", tc.expectContext.values, ctx.values)
}
if len(tc.expectContext.files) != len(ctx.files) {
t.Fatalf("expected context.files to contain %d items but got %d", len(tc.expectContext.files), len(ctx.files))
}
for key, value := range tc.expectContext.files {
if !strings.HasSuffix(ctx.files[key], value) {
t.Fatalf("expected context.files to contain '%s' but got '%s'", value, ctx.files[key])
}
}
}
return nil
}
+4 -2
View File
@@ -236,14 +236,16 @@ func basicAuthMiddleware(username, password string) echo.MiddlewareFunc {
//
// ctx := c.Get("context").(*api.Context)
// cancel := c.Get("cancel").(context.CancelFunc)
func contextMiddleware(fs *gotenberg.FileSystem, timeout time.Duration) echo.MiddlewareFunc {
func contextMiddleware(fs *gotenberg.FileSystem, timeout time.Duration, downloadFromCfg downloadFromConfig) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
logger := c.Get("logger").(*zap.Logger)
traceHeader := c.Get("traceHeader").(string)
trace := c.Get("trace").(string)
// We create a context with a timeout so that underlying processes are
// able to stop early and handle correctly a timeout scenario.
ctx, cancel, err := newContext(c, logger, fs, timeout)
ctx, cancel, err := newContext(c, logger, fs, timeout, downloadFromCfg, traceHeader, trace)
if err != nil {
cancel()
+2 -1
View File
@@ -458,10 +458,11 @@ func TestContextMiddleware(t *testing.T) {
c := srv.NewContext(tc.request, recorder)
c.Set("logger", zap.NewNop())
c.Set("traceHeader", "Gotenberg-Trace")
c.Set("trace", "foo")
c.Set("startTime", time.Now())
err := contextMiddleware(gotenberg.NewFileSystem(), time.Duration(10)*time.Second)(tc.next)(c)
err := contextMiddleware(gotenberg.NewFileSystem(), time.Duration(10)*time.Second, downloadFromConfig{})(tc.next)(c)
if tc.expectErr && err == nil {
t.Errorf("test %d: expected error but got: %v", i, err)
+8 -39
View File
@@ -27,9 +27,9 @@ type client struct {
// send call the webhook either to send the success response or the error response.
func (c client) send(body io.Reader, headers map[string]string, erroed bool) error {
URL := c.url
url := c.url
if erroed {
URL = c.errorUrl
url = c.errorUrl
}
method := c.method
@@ -37,9 +37,9 @@ func (c client) send(body io.Reader, headers map[string]string, erroed bool) err
method = c.errorMethod
}
req, err := retryablehttp.NewRequest(method, URL, body)
req, err := retryablehttp.NewRequest(method, url, body)
if err != nil {
return fmt.Errorf("create '%s' request to '%s': %w", method, URL, err)
return fmt.Errorf("create '%s' request to '%s': %w", method, url, err)
}
req.Header.Set("User-Agent", "Gotenberg")
@@ -75,17 +75,17 @@ func (c client) send(body io.Reader, headers map[string]string, erroed bool) err
resp, err := c.client.Do(req)
if err != nil {
return fmt.Errorf("send '%s' request to '%s': %w", method, URL, err)
return fmt.Errorf("send '%s' request to '%s': %w", method, url, err)
}
if resp.StatusCode >= http.StatusBadRequest {
return fmt.Errorf("send '%s' request to '%s': got status: '%s'", method, URL, resp.Status)
return fmt.Errorf("send '%s' request to '%s': got status: '%s'", method, url, resp.Status)
}
defer func() {
err := resp.Body.Close()
if err != nil {
c.logger.Error(fmt.Sprintf("close response body from '%s': %s", URL, err))
c.logger.Error(fmt.Sprintf("close response body from '%s': %s", url, err))
}
}()
@@ -94,7 +94,7 @@ func (c client) send(body io.Reader, headers map[string]string, erroed bool) err
// Now let's log!
fields := make([]zap.Field, 5)
fields[0] = zap.String("webhook_url", URL)
fields[0] = zap.String("webhook_url", url)
fields[1] = zap.String("method", method)
fields[2] = zap.Int64("latency", int64(finishTime.Sub(c.startTime)))
fields[3] = zap.String("latency_human", finishTime.Sub(c.startTime).String())
@@ -110,34 +110,3 @@ func (c client) send(body io.Reader, headers map[string]string, erroed bool) err
return nil
}
// leveledLogger is wrapper around a [zap.Logger] which is used by the
// [retryablehttp.Client].
type leveledLogger struct {
logger *zap.Logger
}
// Error logs a message at error level using the wrapped zap.Logger.
func (leveled leveledLogger) Error(msg string, keysAndValues ...interface{}) {
leveled.logger.Error(fmt.Sprintf("%s: %+v", msg, keysAndValues))
}
// Warn logs a message at warning level using the wrapped zap.Logger.
func (leveled leveledLogger) Warn(msg string, keysAndValues ...interface{}) {
leveled.logger.Warn(fmt.Sprintf("%s: %+v", msg, keysAndValues))
}
// Info logs a message at info level using the wrapped zap.Logger.
func (leveled leveledLogger) Info(msg string, keysAndValues ...interface{}) {
leveled.logger.Info(fmt.Sprintf("%s: %+v", msg, keysAndValues))
}
// Debug logs a message at debug level using the wrapped zap.Logger.
func (leveled leveledLogger) Debug(msg string, keysAndValues ...interface{}) {
leveled.logger.Debug(fmt.Sprintf("%s: %+v", msg, keysAndValues))
}
// Interface guards.
var (
_ retryablehttp.LeveledLogger = (*leveledLogger)(nil)
)
+8 -10
View File
@@ -100,11 +100,11 @@ func webhookMiddleware(w *Webhook) api.Middleware {
}
// What about extra HTTP headers?
var extraHTTPHeaders map[string]string
var extraHttpHeaders map[string]string
extraHTTPHeadersJSON := c.Request().Header.Get("Gotenberg-Webhook-Extra-Http-Headers")
if extraHTTPHeadersJSON != "" {
err = json.Unmarshal([]byte(extraHTTPHeadersJSON), &extraHTTPHeaders)
extraHttpHeadersJson := c.Request().Header.Get("Gotenberg-Webhook-Extra-Http-Headers")
if extraHttpHeadersJson != "" {
err = json.Unmarshal([]byte(extraHttpHeadersJson), &extraHttpHeaders)
if err != nil {
return api.WrapError(
fmt.Errorf("unmarshal webhook extra HTTP headers: %w", err),
@@ -118,7 +118,7 @@ func webhookMiddleware(w *Webhook) api.Middleware {
method: webhookMethod,
errorUrl: webhookErrorUrl,
errorMethod: webhookErrorMethod,
extraHttpHeaders: extraHTTPHeaders,
extraHttpHeaders: extraHttpHeaders,
startTime: c.Get("startTime").(time.Time),
client: &retryablehttp.Client{
@@ -128,11 +128,9 @@ func webhookMiddleware(w *Webhook) api.Middleware {
RetryMax: w.maxRetry,
RetryWaitMin: w.retryMinWait,
RetryWaitMax: w.retryMaxWait,
Logger: leveledLogger{
logger: ctx.Log(),
},
CheckRetry: retryablehttp.DefaultRetryPolicy,
Backoff: retryablehttp.DefaultBackoff,
Logger: gotenberg.NewLeveledLogger(ctx.Log()),
CheckRetry: retryablehttp.DefaultRetryPolicy,
Backoff: retryablehttp.DefaultBackoff,
},
logger: ctx.Log(),
}