diff options
-rw-r--r-- | main.go | 24 | ||||
-rw-r--r-- | main_test.go | 42 |
2 files changed, 50 insertions, 16 deletions
@@ -3,15 +3,18 @@ package main import ( "net/http" "net/url" + "path" + "strings" ) const ( - errSrcInvalid = "source is not a parsable URL" + errSrcInvalid = "source is not a parsable URL" + errTgtNotAccepted = "can not process webmentions for this target" ) // endpoint is a webmention receiver. type endpoint struct { - allowPrefix string + allowPrefix string // host (or host:port) and path prefix for the targets served by this endpoint } // ServeHTTP is http.Handler implementation. @@ -21,4 +24,21 @@ func (ep endpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusBadRequest) w.Write([]byte(errSrcInvalid)) } + target, err := url.Parse(r.PostFormValue("target")) + if err != nil || !ep.targetAllowed(target) { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(errTgtNotAccepted)) + } +} + +// targetAllowed shows whether ep can accept a webmention for the target. +func (ep endpoint) targetAllowed(target *url.URL) bool { + if !strings.HasSuffix(ep.allowPrefix, "/") { + ep.allowPrefix = ep.allowPrefix + "/" + } + tgt := path.Join(target.Host, target.Path) + if !strings.HasSuffix(tgt, "/") { + tgt = tgt + "/" + } + return strings.HasPrefix(tgt, ep.allowPrefix) } diff --git a/main_test.go b/main_test.go index d1b864d..0510af8 100644 --- a/main_test.go +++ b/main_test.go @@ -9,22 +9,36 @@ import ( ) func TestSyncRejection(t *testing.T) { - server := httptest.NewServer(endpoint{"my.site"}) + server := httptest.NewServer(endpoint{"my.site/part"}) defer server.Close() - client := http.DefaultClient - r, err := client.PostForm(server.URL, url.Values{ - "source": []string{"https||:example.org/somewhere"}, - "target": []string{"my.site/target"}, - }) - if err != nil { - t.Fatal(err) + tests := []struct { + name string + source string + target string + expect string + }{ + {"invalid source", "https||:example.org/somewhere", "my.site/part/target", errSrcInvalid}, + {"target no accepted", "https://example.org/somewhere", "wrong.site/tgt", errTgtNotAccepted}, } - if r.StatusCode != 400 { - t.Fatalf("want 400, got %v", r.Status) - } - bb, _ := io.ReadAll(r.Body) - if string(bb) != errSrcInvalid { - t.Fatalf("want %s, got %s", errSrcInvalid, string(bb)) + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := http.DefaultClient + r, err := client.PostForm(server.URL, url.Values{ + "source": []string{tc.source}, + "target": []string{tc.target}, + }) + if err != nil { + t.Fatal(err) + } + if r.StatusCode != 400 { + t.Fatalf("want 400, got %v", r.Status) + } + bb, _ := io.ReadAll(r.Body) + if string(bb) != tc.expect { + t.Fatalf("want %s, got %s", tc.expect, string(bb)) + } + }) } } |