mirror of
https://github.com/gotenberg/gotenberg.git
synced 2026-07-02 08:27:41 +08:00
feat(api): add 'downloadForm' form field
This commit is contained in:
@@ -37,6 +37,10 @@ API_TRACE_HEADER=Gotenberg-Trace
|
||||
API_ENABLE_BASIC_AUTH=false
|
||||
GOTENBERG_API_BASIC_AUTH_USERNAME=
|
||||
GOTENBERG_API_BASIC_AUTH_PASSWORD=
|
||||
API-DOWNLOAD-FROM-ALLOW-LIST=
|
||||
API-DOWNLOAD-FROM-DENY-LIST=
|
||||
API-DOWNLOAD-FROM-FROM-MAX-RETRY=4
|
||||
API-DISABLE-DOWNLOAD-FROM=false
|
||||
API_DISABLE_HEALTH_CHECK_LOGGING=false
|
||||
CHROMIUM_RESTART_AFTER=0
|
||||
CHROMIUM_MAX_QUEUE_SIZE=0
|
||||
@@ -95,6 +99,10 @@ run: ## Start a Gotenberg container
|
||||
--api-root-path=$(API_ROOT_PATH) \
|
||||
--api-trace-header=$(API_TRACE_HEADER) \
|
||||
--api-enable-basic-auth=$(API_ENABLE_BASIC_AUTH) \
|
||||
--api-download-from-allow-list=$(API-DOWNLOAD-FROM-ALLOW-LIST) \
|
||||
--api-download-from-deny-list=$(API-DOWNLOAD-FROM-DENY-LIST) \
|
||||
--api-download-from-max-retry=$(API-DOWNLOAD-FROM-FROM-MAX-RETRY) \
|
||||
--api-disable-download-from=$(API-DISABLE-DOWNLOAD-FROM) \
|
||||
--api-disable-health-check-logging=$(API_DISABLE_HEALTH_CHECK_LOGGING) \
|
||||
--chromium-restart-after=$(CHROMIUM_RESTART_AFTER) \
|
||||
--chromium-auto-start=$(CHROMIUM_AUTO_START) \
|
||||
@@ -151,7 +159,7 @@ build-tests: ## Build the tests' Docker image
|
||||
tests: ## Start the testing environment
|
||||
docker run --rm -it \
|
||||
-v $(PWD):/tests \
|
||||
$(DOCKER_REGISTRY)/$(DOCKER_REPOSITORY)g:$(GOTENBERG_VERSION)-tests \
|
||||
$(DOCKER_REGISTRY)/$(DOCKER_REPOSITORY):$(GOTENBERG_VERSION)-tests \
|
||||
bash
|
||||
|
||||
.PHONY: tests-once
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
package gotenberg
|
||||
|
||||
import "go.uber.org/zap"
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/go-retryablehttp"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// LoggerProvider is an interface for a module that supplies a method for
|
||||
// creating a [zap.Logger] instance for use by other modules.
|
||||
@@ -12,3 +17,41 @@ import "go.uber.org/zap"
|
||||
type LoggerProvider interface {
|
||||
Logger(mod Module) (*zap.Logger, error)
|
||||
}
|
||||
|
||||
// LeveledLogger is wrapper around a [zap.Logger] so that it may be used by a
|
||||
// [retryablehttp.Client].
|
||||
type LeveledLogger struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewLeveledLogger instantiates a [LeveledLogger].
|
||||
func NewLeveledLogger(logger *zap.Logger) *LeveledLogger {
|
||||
return &LeveledLogger{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Error logs a message at error level using the wrapped zap.Logger.
|
||||
func (leveled LeveledLogger) Error(msg string, keysAndValues ...interface{}) {
|
||||
leveled.logger.Error(fmt.Sprintf("%s: %+v", msg, keysAndValues))
|
||||
}
|
||||
|
||||
// Warn logs a message at warning level using the wrapped zap.Logger.
|
||||
func (leveled LeveledLogger) Warn(msg string, keysAndValues ...interface{}) {
|
||||
leveled.logger.Warn(fmt.Sprintf("%s: %+v", msg, keysAndValues))
|
||||
}
|
||||
|
||||
// Info logs a message at info level using the wrapped zap.Logger.
|
||||
func (leveled LeveledLogger) Info(msg string, keysAndValues ...interface{}) {
|
||||
leveled.logger.Info(fmt.Sprintf("%s: %+v", msg, keysAndValues))
|
||||
}
|
||||
|
||||
// Debug logs a message at debug level using the wrapped zap.Logger.
|
||||
func (leveled LeveledLogger) Debug(msg string, keysAndValues ...interface{}) {
|
||||
leveled.logger.Debug(fmt.Sprintf("%s: %+v", msg, keysAndValues))
|
||||
}
|
||||
|
||||
// Interface guards.
|
||||
var (
|
||||
_ retryablehttp.LeveledLogger = (*LeveledLogger)(nil)
|
||||
)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package webhook
|
||||
package gotenberg
|
||||
|
||||
import (
|
||||
"testing"
|
||||
@@ -7,17 +7,17 @@ import (
|
||||
)
|
||||
|
||||
func TestLeveledLogger_Error(t *testing.T) {
|
||||
leveledLogger{logger: zap.NewNop()}.Error("foo")
|
||||
NewLeveledLogger(zap.NewNop()).Error("foo")
|
||||
}
|
||||
|
||||
func TestLeveledLogger_Warn(t *testing.T) {
|
||||
leveledLogger{logger: zap.NewNop()}.Warn("foo")
|
||||
NewLeveledLogger(zap.NewNop()).Warn("foo")
|
||||
}
|
||||
|
||||
func TestLeveledLogger_Info(t *testing.T) {
|
||||
leveledLogger{logger: zap.NewNop()}.Info("foo")
|
||||
NewLeveledLogger(zap.NewNop()).Info("foo")
|
||||
}
|
||||
|
||||
func TestLeveledLogger_Debug(t *testing.T) {
|
||||
leveledLogger{logger: zap.NewNop()}.Debug("foo")
|
||||
NewLeveledLogger(zap.NewNop()).Debug("foo")
|
||||
}
|
||||
+20
-1
@@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/alexliesenfeld/health"
|
||||
"github.com/dlclark/regexp2"
|
||||
"github.com/labstack/echo/v4"
|
||||
flag "github.com/spf13/pflag"
|
||||
"go.uber.org/multierr"
|
||||
@@ -36,6 +37,7 @@ type Api struct {
|
||||
traceHeader string
|
||||
basicAuthUsername string
|
||||
basicAuthPassword string
|
||||
downloadFromCfg downloadFromConfig
|
||||
disableHealthCheckLogging bool
|
||||
|
||||
routes []Route
|
||||
@@ -47,6 +49,13 @@ type Api struct {
|
||||
srv *echo.Echo
|
||||
}
|
||||
|
||||
type downloadFromConfig struct {
|
||||
allowList *regexp2.Regexp
|
||||
denyList *regexp2.Regexp
|
||||
maxRetry int
|
||||
disable bool
|
||||
}
|
||||
|
||||
// Router is a module interface which adds routes to the [Api].
|
||||
type Router interface {
|
||||
Routes() ([]Route, error)
|
||||
@@ -168,6 +177,10 @@ func (a *Api) Descriptor() gotenberg.ModuleDescriptor {
|
||||
fs.String("api-root-path", "/", "Set the root path of the API - for service discovery via URL paths")
|
||||
fs.String("api-trace-header", "Gotenberg-Trace", "Set the header name to use for identifying requests")
|
||||
fs.Bool("api-enable-basic-auth", false, "Enable basic authentication - will look for the GOTENBERG_API_BASIC_AUTH_USERNAME and GOTENBERG_API_BASIC_AUTH_PASSWORD environment variables")
|
||||
fs.String("api-download-from-allow-list", "", "Set the allowed URLs for the download from feature using a regular expression")
|
||||
fs.String("api-download-from-deny-list", "", "Set the denied URLs for the download from feature using a regular expression")
|
||||
fs.Int("api-download-from-max-retry", 4, "Set the maximum number of retries for the download from feature")
|
||||
fs.Bool("api-disable-download-from", false, "Disable the download from feature")
|
||||
fs.Bool("api-disable-health-check-logging", false, "Disable health check logging")
|
||||
return fs
|
||||
}(),
|
||||
@@ -185,6 +198,12 @@ func (a *Api) Provision(ctx *gotenberg.Context) error {
|
||||
a.timeout = flags.MustDuration("api-timeout")
|
||||
a.rootPath = flags.MustString("api-root-path")
|
||||
a.traceHeader = flags.MustString("api-trace-header")
|
||||
a.downloadFromCfg = downloadFromConfig{
|
||||
allowList: flags.MustRegexp("api-download-from-allow-list"),
|
||||
denyList: flags.MustRegexp("api-download-from-deny-list"),
|
||||
maxRetry: flags.MustInt("api-download-from-max-retry"),
|
||||
disable: flags.MustBool("api-disable-download-from"),
|
||||
}
|
||||
a.disableHealthCheckLogging = flags.MustBool("api-disable-health-check-logging")
|
||||
|
||||
// Port from env?
|
||||
@@ -436,7 +455,7 @@ func (a *Api) Start() error {
|
||||
}
|
||||
|
||||
if route.IsMultipart {
|
||||
middlewares = append(middlewares, contextMiddleware(a.fs, a.timeout))
|
||||
middlewares = append(middlewares, contextMiddleware(a.fs, a.timeout, a.downloadFromCfg))
|
||||
|
||||
for _, externalMultipartMiddleware := range externalMultipartMiddlewares {
|
||||
middlewares = append(middlewares, externalMultipartMiddleware.Handler)
|
||||
|
||||
+152
-2
@@ -3,9 +3,11 @@ package api
|
||||
import (
|
||||
"compress/flate"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -14,9 +16,11 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/hashicorp/go-retryablehttp"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mholt/archiver/v3"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/text/unicode/norm"
|
||||
|
||||
"github.com/gotenberg/gotenberg/v8/pkg/gotenberg"
|
||||
@@ -46,6 +50,14 @@ type Context struct {
|
||||
context.Context
|
||||
}
|
||||
|
||||
type downloadFrom struct {
|
||||
// Url is the URL to download a file from.
|
||||
Url string `json:"url"`
|
||||
|
||||
// ExtraHttpHeaders are the HTTP headers to send alongside.
|
||||
ExtraHttpHeaders map[string]string `json:"extraHttpHeaders"`
|
||||
}
|
||||
|
||||
type osPathRename struct{}
|
||||
|
||||
func (o *osPathRename) Rename(oldpath, newpath string) error {
|
||||
@@ -53,7 +65,7 @@ func (o *osPathRename) Rename(oldpath, newpath string) error {
|
||||
}
|
||||
|
||||
// newContext returns a [Context] by parsing a "multipart/form-data" request.
|
||||
func newContext(echoCtx echo.Context, logger *zap.Logger, fs *gotenberg.FileSystem, timeout time.Duration) (*Context, context.CancelFunc, error) {
|
||||
func newContext(echoCtx echo.Context, logger *zap.Logger, fs *gotenberg.FileSystem, timeout time.Duration, downloadFromCfg downloadFromConfig, traceHeader, trace string) (*Context, context.CancelFunc, error) {
|
||||
processCtx, processCancel := context.WithTimeout(context.Background(), timeout)
|
||||
|
||||
ctx := &Context{
|
||||
@@ -126,6 +138,144 @@ func newContext(echoCtx echo.Context, logger *zap.Logger, fs *gotenberg.FileSyst
|
||||
ctx.values = form.Value
|
||||
ctx.files = make(map[string]string)
|
||||
|
||||
// First, try to download files listed in the "downloadFrom" form field, if
|
||||
// any.
|
||||
raw, ok := ctx.values["downloadFrom"]
|
||||
if !downloadFromCfg.disable && ok {
|
||||
var dls []downloadFrom
|
||||
err = json.Unmarshal([]byte(raw[0]), &dls)
|
||||
if err != nil {
|
||||
return nil, cancel, WrapError(
|
||||
fmt.Errorf("unmarshal json: %w", err),
|
||||
NewSentinelHttpError(http.StatusBadRequest, fmt.Sprintf("Invalid 'downloadFrom' form field value: %s", err)),
|
||||
)
|
||||
}
|
||||
|
||||
eg, _ := errgroup.WithContext(ctx)
|
||||
for i, dl := range dls {
|
||||
eg.Go(func() error {
|
||||
deadline, ok := ctx.Deadline()
|
||||
if !ok {
|
||||
// Should not happen, as context is created with a timeout.
|
||||
return errors.New("context has no deadline")
|
||||
}
|
||||
|
||||
if strings.TrimSpace(dl.Url) == "" {
|
||||
return WrapError(
|
||||
errors.New("empty download from URL"),
|
||||
NewSentinelHttpError(http.StatusBadRequest, fmt.Sprintf("Invalid 'downloadFrom' form field entry %d: URL must be set", i)),
|
||||
)
|
||||
}
|
||||
|
||||
err := gotenberg.FilterDeadline(downloadFromCfg.allowList, downloadFromCfg.denyList, dl.Url, deadline)
|
||||
if err != nil {
|
||||
return fmt.Errorf("filter URL: %w", err)
|
||||
}
|
||||
|
||||
logger.Debug(fmt.Sprintf("download file from '%s'", dl.Url))
|
||||
|
||||
req, err := retryablehttp.NewRequest(http.MethodGet, dl.Url, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create request to '%s': %w", dl.Url, err)
|
||||
}
|
||||
|
||||
req.Header.Set("User-Agent", "Gotenberg")
|
||||
for key, value := range dl.ExtraHttpHeaders {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
req.Header.Set(traceHeader, trace)
|
||||
|
||||
client := &retryablehttp.Client{
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: time.Until(deadline),
|
||||
},
|
||||
RetryMax: downloadFromCfg.maxRetry,
|
||||
RetryWaitMin: time.Duration(1) * time.Second,
|
||||
RetryWaitMax: time.Until(deadline),
|
||||
Logger: gotenberg.NewLeveledLogger(logger),
|
||||
CheckRetry: retryablehttp.DefaultRetryPolicy,
|
||||
Backoff: retryablehttp.DefaultBackoff,
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return WrapError(
|
||||
fmt.Errorf("download file from to '%s': %w", dl.Url, err),
|
||||
NewSentinelHttpError(http.StatusBadRequest, fmt.Sprintf("Unable to download file from '%s': %s", dl.Url, err)),
|
||||
)
|
||||
}
|
||||
defer func() {
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("close response body from '%s': %s", dl.Url, err))
|
||||
}
|
||||
}()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return WrapError(
|
||||
fmt.Errorf("download file from to '%s': got status: '%s'", dl.Url, resp.Status),
|
||||
NewSentinelHttpError(http.StatusBadRequest, fmt.Sprintf("Unable to download file from '%s': got status: '%s'", dl.Url, resp.Status)),
|
||||
)
|
||||
}
|
||||
|
||||
contentDisposition := resp.Header.Get("Content-Disposition")
|
||||
if contentDisposition == "" {
|
||||
return WrapError(
|
||||
fmt.Errorf("no 'Content-Disposition' header from '%s'", dl.Url),
|
||||
NewSentinelHttpError(http.StatusBadRequest, fmt.Sprintf("No 'Content-Disposition' header from '%s'", dl.Url)),
|
||||
)
|
||||
}
|
||||
|
||||
_, params, err := mime.ParseMediaType(contentDisposition)
|
||||
if err != nil {
|
||||
return WrapError(
|
||||
fmt.Errorf("parse 'Content-Disposition' header '%s' from '%s': %w", contentDisposition, dl.Url, err),
|
||||
NewSentinelHttpError(http.StatusBadRequest, fmt.Sprintf("Invalid 'Content-Disposition' header '%s' from '%s': %s", contentDisposition, dl.Url, err)),
|
||||
)
|
||||
}
|
||||
|
||||
filename, ok := params["filename"]
|
||||
if !ok {
|
||||
return WrapError(
|
||||
fmt.Errorf("get filename from 'Content-Disposition' header '%s' from '%s'", contentDisposition, dl.Url),
|
||||
NewSentinelHttpError(http.StatusBadRequest, fmt.Sprintf("Invalid 'Content-Disposition' header '%s' from '%s': no filename", contentDisposition, dl.Url)),
|
||||
)
|
||||
}
|
||||
|
||||
// Avoid directory traversal and make sure filename characters are
|
||||
// normalized.
|
||||
// See: https://github.com/gotenberg/gotenberg/issues/662.
|
||||
filename = norm.NFC.String(filepath.Base(filename))
|
||||
path := fmt.Sprintf("%s/%s", ctx.dirPath, filename)
|
||||
|
||||
out, err := os.Create(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create local file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
err := out.Close()
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("close local file: %s", err))
|
||||
}
|
||||
}()
|
||||
|
||||
_, err = io.Copy(out, resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("copy downloaded file from '%s' to local file: %v", dl.Url, err)
|
||||
}
|
||||
|
||||
ctx.files[filename] = path
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
err = eg.Wait()
|
||||
if err != nil {
|
||||
return ctx, cancel, err
|
||||
}
|
||||
}
|
||||
|
||||
copyToDisk := func(fh *multipart.FileHeader) error {
|
||||
in, err := fh.Open()
|
||||
if err != nil {
|
||||
@@ -149,7 +299,6 @@ func newContext(echoCtx echo.Context, logger *zap.Logger, fs *gotenberg.FileSyst
|
||||
if err != nil {
|
||||
return fmt.Errorf("create local file: %w", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := out.Close()
|
||||
if err != nil {
|
||||
@@ -167,6 +316,7 @@ func newContext(echoCtx echo.Context, logger *zap.Logger, fs *gotenberg.FileSyst
|
||||
return nil
|
||||
}
|
||||
|
||||
// Then, copy the form files, if any.
|
||||
for _, files := range form.File {
|
||||
for _, fh := range files {
|
||||
err = copyToDisk(fh)
|
||||
|
||||
@@ -2,6 +2,7 @@ package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -15,6 +16,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/dlclark/regexp2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
"go.uber.org/zap"
|
||||
@@ -75,9 +77,27 @@ func TestOsPathRename_Rename(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewContext(t *testing.T) {
|
||||
defaultAllowList, err := regexp2.Compile("", 0)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error but got: %v", err)
|
||||
}
|
||||
defaultDenyList, err := regexp2.Compile("", 0)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error but got: %v", err)
|
||||
}
|
||||
defaultDownloadFromCfg := downloadFromConfig{
|
||||
allowList: defaultAllowList,
|
||||
denyList: defaultDenyList,
|
||||
maxRetry: 1,
|
||||
disable: false,
|
||||
}
|
||||
|
||||
for _, tc := range []struct {
|
||||
scenario string
|
||||
request *http.Request
|
||||
downloadFromCfg downloadFromConfig
|
||||
downloadFromSrv *echo.Echo
|
||||
expectContext *Context
|
||||
expectError bool
|
||||
expectHttpError bool
|
||||
expectHttpStatus int
|
||||
@@ -123,6 +143,236 @@ func TestNewContext(t *testing.T) {
|
||||
expectHttpError: true,
|
||||
expectHttpStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
scenario: "invalid downloadFrom form field: cannot unmarshal",
|
||||
request: func() *http.Request {
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
defer func() {
|
||||
err := writer.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error but got: %v", err)
|
||||
}
|
||||
}()
|
||||
err := writer.WriteField("downloadFrom", "foo")
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error but got: %v", err)
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodPost, "/", body)
|
||||
req.Header.Set(echo.HeaderContentType, writer.FormDataContentType())
|
||||
return req
|
||||
}(),
|
||||
downloadFromCfg: defaultDownloadFromCfg,
|
||||
expectError: true,
|
||||
expectHttpError: true,
|
||||
expectHttpStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
scenario: "invalid downloadFrom form field: no URL",
|
||||
request: func() *http.Request {
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
defer func() {
|
||||
err := writer.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error but got: %v", err)
|
||||
}
|
||||
}()
|
||||
err := writer.WriteField("downloadFrom", `[{}]`)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error but got: %v", err)
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodPost, "/", body)
|
||||
req.Header.Set(echo.HeaderContentType, writer.FormDataContentType())
|
||||
return req
|
||||
}(),
|
||||
downloadFromCfg: defaultDownloadFromCfg,
|
||||
expectError: true,
|
||||
expectHttpError: true,
|
||||
expectHttpStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
scenario: "invalid downloadFrom form field: filtered URL",
|
||||
request: func() *http.Request {
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
defer func() {
|
||||
err := writer.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error but got: %v", err)
|
||||
}
|
||||
}()
|
||||
err := writer.WriteField("downloadFrom", `[{"url":"https://foo.bar"}]`)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error but got: %v", err)
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodPost, "/", body)
|
||||
req.Header.Set(echo.HeaderContentType, writer.FormDataContentType())
|
||||
return req
|
||||
}(),
|
||||
downloadFromCfg: func() downloadFromConfig {
|
||||
denyList, err := regexp2.Compile("https://foo.bar", 0)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error but got: %v", err)
|
||||
}
|
||||
return downloadFromConfig{allowList: defaultAllowList, denyList: denyList, maxRetry: 1, disable: false}
|
||||
}(),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
scenario: "invalid downloadFrom form field: unreachable URL",
|
||||
request: func() *http.Request {
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
defer func() {
|
||||
err := writer.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error but got: %v", err)
|
||||
}
|
||||
}()
|
||||
err := writer.WriteField("downloadFrom", `[{"url":"http://localhost:80/"}]`)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error but got: %v", err)
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodPost, "/", body)
|
||||
req.Header.Set(echo.HeaderContentType, writer.FormDataContentType())
|
||||
return req
|
||||
}(),
|
||||
downloadFromCfg: defaultDownloadFromCfg,
|
||||
expectError: true,
|
||||
expectHttpError: true,
|
||||
expectHttpStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
scenario: "invalid downloadFrom form field: invalid status code",
|
||||
request: func() *http.Request {
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
defer func() {
|
||||
err := writer.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error but got: %v", err)
|
||||
}
|
||||
}()
|
||||
err := writer.WriteField("downloadFrom", `[{"url":"http://localhost:80/"}]`)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error but got: %v", err)
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodPost, "/", body)
|
||||
req.Header.Set(echo.HeaderContentType, writer.FormDataContentType())
|
||||
return req
|
||||
}(),
|
||||
downloadFromSrv: func() *echo.Echo {
|
||||
srv := echo.New()
|
||||
srv.HideBanner = true
|
||||
srv.GET("/", func(c echo.Context) error {
|
||||
return c.String(http.StatusNotFound, http.StatusText(http.StatusNotFound))
|
||||
})
|
||||
return srv
|
||||
}(),
|
||||
downloadFromCfg: defaultDownloadFromCfg,
|
||||
expectError: true,
|
||||
expectHttpError: true,
|
||||
expectHttpStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
scenario: "invalid downloadFrom form field: no 'Content-Disposition' header",
|
||||
request: func() *http.Request {
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
defer func() {
|
||||
err := writer.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error but got: %v", err)
|
||||
}
|
||||
}()
|
||||
err := writer.WriteField("downloadFrom", `[{"url":"http://localhost:80/"}]`)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error but got: %v", err)
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodPost, "/", body)
|
||||
req.Header.Set(echo.HeaderContentType, writer.FormDataContentType())
|
||||
return req
|
||||
}(),
|
||||
downloadFromSrv: func() *echo.Echo {
|
||||
srv := echo.New()
|
||||
srv.HideBanner = true
|
||||
srv.GET("/", func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, http.StatusText(http.StatusOK))
|
||||
})
|
||||
return srv
|
||||
}(),
|
||||
downloadFromCfg: defaultDownloadFromCfg,
|
||||
expectError: true,
|
||||
expectHttpError: true,
|
||||
expectHttpStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
scenario: "invalid downloadFrom form field: malformed 'Content-Disposition' header",
|
||||
request: func() *http.Request {
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
defer func() {
|
||||
err := writer.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error but got: %v", err)
|
||||
}
|
||||
}()
|
||||
err := writer.WriteField("downloadFrom", `[{"url":"http://localhost:80/"}]`)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error but got: %v", err)
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodPost, "/", body)
|
||||
req.Header.Set(echo.HeaderContentType, writer.FormDataContentType())
|
||||
return req
|
||||
}(),
|
||||
downloadFromSrv: func() *echo.Echo {
|
||||
srv := echo.New()
|
||||
srv.HideBanner = true
|
||||
srv.GET("/", func(c echo.Context) error {
|
||||
c.Response().Header().Set(echo.HeaderContentDisposition, ";;")
|
||||
return c.String(http.StatusOK, http.StatusText(http.StatusOK))
|
||||
})
|
||||
return srv
|
||||
}(),
|
||||
downloadFromCfg: defaultDownloadFromCfg,
|
||||
expectError: true,
|
||||
expectHttpError: true,
|
||||
expectHttpStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
scenario: "invalid downloadFrom form field: no filename parameter in 'Content-Disposition' header",
|
||||
request: func() *http.Request {
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
defer func() {
|
||||
err := writer.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error but got: %v", err)
|
||||
}
|
||||
}()
|
||||
err := writer.WriteField("downloadFrom", `[{"url":"http://localhost:80/"}]`)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error but got: %v", err)
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodPost, "/", body)
|
||||
req.Header.Set(echo.HeaderContentType, writer.FormDataContentType())
|
||||
return req
|
||||
}(),
|
||||
downloadFromSrv: func() *echo.Echo {
|
||||
srv := echo.New()
|
||||
srv.HideBanner = true
|
||||
srv.GET("/", func(c echo.Context) error {
|
||||
c.Response().Header().Set(echo.HeaderContentDisposition, "inline;")
|
||||
return c.String(http.StatusOK, http.StatusText(http.StatusOK))
|
||||
})
|
||||
return srv
|
||||
}(),
|
||||
downloadFromCfg: defaultDownloadFromCfg,
|
||||
expectError: true,
|
||||
expectHttpError: true,
|
||||
expectHttpStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
scenario: "success",
|
||||
request: func() *http.Request {
|
||||
@@ -146,17 +396,69 @@ func TestNewContext(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error but got: %v", err)
|
||||
}
|
||||
err = writer.WriteField("downloadFrom", `[{"url":"http://localhost:80/","extraHttpHeaders":{"X-Foo":"Bar"}}]`)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error but got: %v", err)
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodPost, "/", body)
|
||||
req.Header.Set(echo.HeaderContentType, writer.FormDataContentType())
|
||||
return req
|
||||
}(),
|
||||
downloadFromSrv: func() *echo.Echo {
|
||||
srv := echo.New()
|
||||
srv.HideBanner = true
|
||||
srv.GET("/", func(c echo.Context) error {
|
||||
if c.Request().Header.Get("User-Agent") != "Gotenberg" {
|
||||
t.Fatalf("expected 'Gotenberg' from header 'User-Agent', but got '%s'", c.Request().Header.Get("User-Agent"))
|
||||
}
|
||||
if c.Request().Header.Get("X-Foo") != "Bar" {
|
||||
t.Fatalf("expected 'Bar' from header 'X-Foo', but got '%s'", c.Request().Header.Get("X-Foo"))
|
||||
}
|
||||
if c.Request().Header.Get("Gotenberg-Trace") != "123" {
|
||||
t.Fatalf("expected '123' from header 'Gotenberg-Trace', but got '%s'", c.Request().Header.Get("Gotenberg-Trace"))
|
||||
}
|
||||
c.Response().Header().Set(echo.HeaderContentDisposition, `attachment; filename="bar.txt"`)
|
||||
c.Response().Header().Set(echo.HeaderContentType, "text/plain")
|
||||
return c.String(http.StatusOK, http.StatusText(http.StatusOK))
|
||||
})
|
||||
return srv
|
||||
}(),
|
||||
downloadFromCfg: defaultDownloadFromCfg,
|
||||
expectContext: &Context{
|
||||
values: map[string][]string{
|
||||
"foo": {"foo"},
|
||||
"downloadFrom": {
|
||||
`[{"url":"http://localhost:80/","extraHttpHeaders":{"X-Foo":"Bar"}}]`,
|
||||
},
|
||||
},
|
||||
files: map[string]string{
|
||||
"foo.txt": "foo.txt",
|
||||
"bar.txt": "bar.txt", // downloadFrom.
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
expectHttpError: false,
|
||||
},
|
||||
} {
|
||||
t.Run(tc.scenario, func(t *testing.T) {
|
||||
if tc.downloadFromSrv != nil {
|
||||
go func() {
|
||||
err := tc.downloadFromSrv.Start(":80")
|
||||
if !errors.Is(err, http.ErrServerClosed) {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
err := tc.downloadFromSrv.Shutdown(context.TODO())
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
handler := func(c echo.Context) error {
|
||||
_, cancel, err := newContext(c, zap.NewNop(), gotenberg.NewFileSystem(), time.Duration(10)*time.Second)
|
||||
ctx, cancel, err := newContext(c, zap.NewNop(), gotenberg.NewFileSystem(), time.Duration(10)*time.Second, tc.downloadFromCfg, "Gotenberg-Trace", "123")
|
||||
defer cancel()
|
||||
// Context already cancelled.
|
||||
defer cancel()
|
||||
@@ -165,6 +467,20 @@ func TestNewContext(t *testing.T) {
|
||||
return err
|
||||
}
|
||||
|
||||
if tc.expectContext != nil {
|
||||
if !reflect.DeepEqual(tc.expectContext.values, ctx.values) {
|
||||
t.Fatalf("expected context.values to be %v but got %v", tc.expectContext.values, ctx.values)
|
||||
}
|
||||
if len(tc.expectContext.files) != len(ctx.files) {
|
||||
t.Fatalf("expected context.files to contain %d items but got %d", len(tc.expectContext.files), len(ctx.files))
|
||||
}
|
||||
for key, value := range tc.expectContext.files {
|
||||
if !strings.HasSuffix(ctx.files[key], value) {
|
||||
t.Fatalf("expected context.files to contain '%s' but got '%s'", value, ctx.files[key])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -236,14 +236,16 @@ func basicAuthMiddleware(username, password string) echo.MiddlewareFunc {
|
||||
//
|
||||
// ctx := c.Get("context").(*api.Context)
|
||||
// cancel := c.Get("cancel").(context.CancelFunc)
|
||||
func contextMiddleware(fs *gotenberg.FileSystem, timeout time.Duration) echo.MiddlewareFunc {
|
||||
func contextMiddleware(fs *gotenberg.FileSystem, timeout time.Duration, downloadFromCfg downloadFromConfig) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
logger := c.Get("logger").(*zap.Logger)
|
||||
traceHeader := c.Get("traceHeader").(string)
|
||||
trace := c.Get("trace").(string)
|
||||
|
||||
// We create a context with a timeout so that underlying processes are
|
||||
// able to stop early and handle correctly a timeout scenario.
|
||||
ctx, cancel, err := newContext(c, logger, fs, timeout)
|
||||
ctx, cancel, err := newContext(c, logger, fs, timeout, downloadFromCfg, traceHeader, trace)
|
||||
if err != nil {
|
||||
cancel()
|
||||
|
||||
|
||||
@@ -458,10 +458,11 @@ func TestContextMiddleware(t *testing.T) {
|
||||
|
||||
c := srv.NewContext(tc.request, recorder)
|
||||
c.Set("logger", zap.NewNop())
|
||||
c.Set("traceHeader", "Gotenberg-Trace")
|
||||
c.Set("trace", "foo")
|
||||
c.Set("startTime", time.Now())
|
||||
|
||||
err := contextMiddleware(gotenberg.NewFileSystem(), time.Duration(10)*time.Second)(tc.next)(c)
|
||||
err := contextMiddleware(gotenberg.NewFileSystem(), time.Duration(10)*time.Second, downloadFromConfig{})(tc.next)(c)
|
||||
|
||||
if tc.expectErr && err == nil {
|
||||
t.Errorf("test %d: expected error but got: %v", i, err)
|
||||
|
||||
@@ -27,9 +27,9 @@ type client struct {
|
||||
|
||||
// send call the webhook either to send the success response or the error response.
|
||||
func (c client) send(body io.Reader, headers map[string]string, erroed bool) error {
|
||||
URL := c.url
|
||||
url := c.url
|
||||
if erroed {
|
||||
URL = c.errorUrl
|
||||
url = c.errorUrl
|
||||
}
|
||||
|
||||
method := c.method
|
||||
@@ -37,9 +37,9 @@ func (c client) send(body io.Reader, headers map[string]string, erroed bool) err
|
||||
method = c.errorMethod
|
||||
}
|
||||
|
||||
req, err := retryablehttp.NewRequest(method, URL, body)
|
||||
req, err := retryablehttp.NewRequest(method, url, body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create '%s' request to '%s': %w", method, URL, err)
|
||||
return fmt.Errorf("create '%s' request to '%s': %w", method, url, err)
|
||||
}
|
||||
|
||||
req.Header.Set("User-Agent", "Gotenberg")
|
||||
@@ -75,17 +75,17 @@ func (c client) send(body io.Reader, headers map[string]string, erroed bool) err
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("send '%s' request to '%s': %w", method, URL, err)
|
||||
return fmt.Errorf("send '%s' request to '%s': %w", method, url, err)
|
||||
}
|
||||
|
||||
if resp.StatusCode >= http.StatusBadRequest {
|
||||
return fmt.Errorf("send '%s' request to '%s': got status: '%s'", method, URL, resp.Status)
|
||||
return fmt.Errorf("send '%s' request to '%s': got status: '%s'", method, url, resp.Status)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
c.logger.Error(fmt.Sprintf("close response body from '%s': %s", URL, err))
|
||||
c.logger.Error(fmt.Sprintf("close response body from '%s': %s", url, err))
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -94,7 +94,7 @@ func (c client) send(body io.Reader, headers map[string]string, erroed bool) err
|
||||
|
||||
// Now let's log!
|
||||
fields := make([]zap.Field, 5)
|
||||
fields[0] = zap.String("webhook_url", URL)
|
||||
fields[0] = zap.String("webhook_url", url)
|
||||
fields[1] = zap.String("method", method)
|
||||
fields[2] = zap.Int64("latency", int64(finishTime.Sub(c.startTime)))
|
||||
fields[3] = zap.String("latency_human", finishTime.Sub(c.startTime).String())
|
||||
@@ -110,34 +110,3 @@ func (c client) send(body io.Reader, headers map[string]string, erroed bool) err
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// leveledLogger is wrapper around a [zap.Logger] which is used by the
|
||||
// [retryablehttp.Client].
|
||||
type leveledLogger struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// Error logs a message at error level using the wrapped zap.Logger.
|
||||
func (leveled leveledLogger) Error(msg string, keysAndValues ...interface{}) {
|
||||
leveled.logger.Error(fmt.Sprintf("%s: %+v", msg, keysAndValues))
|
||||
}
|
||||
|
||||
// Warn logs a message at warning level using the wrapped zap.Logger.
|
||||
func (leveled leveledLogger) Warn(msg string, keysAndValues ...interface{}) {
|
||||
leveled.logger.Warn(fmt.Sprintf("%s: %+v", msg, keysAndValues))
|
||||
}
|
||||
|
||||
// Info logs a message at info level using the wrapped zap.Logger.
|
||||
func (leveled leveledLogger) Info(msg string, keysAndValues ...interface{}) {
|
||||
leveled.logger.Info(fmt.Sprintf("%s: %+v", msg, keysAndValues))
|
||||
}
|
||||
|
||||
// Debug logs a message at debug level using the wrapped zap.Logger.
|
||||
func (leveled leveledLogger) Debug(msg string, keysAndValues ...interface{}) {
|
||||
leveled.logger.Debug(fmt.Sprintf("%s: %+v", msg, keysAndValues))
|
||||
}
|
||||
|
||||
// Interface guards.
|
||||
var (
|
||||
_ retryablehttp.LeveledLogger = (*leveledLogger)(nil)
|
||||
)
|
||||
|
||||
@@ -100,11 +100,11 @@ func webhookMiddleware(w *Webhook) api.Middleware {
|
||||
}
|
||||
|
||||
// What about extra HTTP headers?
|
||||
var extraHTTPHeaders map[string]string
|
||||
var extraHttpHeaders map[string]string
|
||||
|
||||
extraHTTPHeadersJSON := c.Request().Header.Get("Gotenberg-Webhook-Extra-Http-Headers")
|
||||
if extraHTTPHeadersJSON != "" {
|
||||
err = json.Unmarshal([]byte(extraHTTPHeadersJSON), &extraHTTPHeaders)
|
||||
extraHttpHeadersJson := c.Request().Header.Get("Gotenberg-Webhook-Extra-Http-Headers")
|
||||
if extraHttpHeadersJson != "" {
|
||||
err = json.Unmarshal([]byte(extraHttpHeadersJson), &extraHttpHeaders)
|
||||
if err != nil {
|
||||
return api.WrapError(
|
||||
fmt.Errorf("unmarshal webhook extra HTTP headers: %w", err),
|
||||
@@ -118,7 +118,7 @@ func webhookMiddleware(w *Webhook) api.Middleware {
|
||||
method: webhookMethod,
|
||||
errorUrl: webhookErrorUrl,
|
||||
errorMethod: webhookErrorMethod,
|
||||
extraHttpHeaders: extraHTTPHeaders,
|
||||
extraHttpHeaders: extraHttpHeaders,
|
||||
startTime: c.Get("startTime").(time.Time),
|
||||
|
||||
client: &retryablehttp.Client{
|
||||
@@ -128,11 +128,9 @@ func webhookMiddleware(w *Webhook) api.Middleware {
|
||||
RetryMax: w.maxRetry,
|
||||
RetryWaitMin: w.retryMinWait,
|
||||
RetryWaitMax: w.retryMaxWait,
|
||||
Logger: leveledLogger{
|
||||
logger: ctx.Log(),
|
||||
},
|
||||
CheckRetry: retryablehttp.DefaultRetryPolicy,
|
||||
Backoff: retryablehttp.DefaultBackoff,
|
||||
Logger: gotenberg.NewLeveledLogger(ctx.Log()),
|
||||
CheckRetry: retryablehttp.DefaultRetryPolicy,
|
||||
Backoff: retryablehttp.DefaultBackoff,
|
||||
},
|
||||
logger: ctx.Log(),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user