diff --git a/colly.go b/colly.go index 451b6766..d0cf4ee7 100644 --- a/colly.go +++ b/colly.go @@ -1343,12 +1343,26 @@ func (c *Collector) checkRedirectFunc() func(req *http.Request, via []*http.Requ return fmt.Errorf("Not following redirect to %q: %w", req.URL, err) } - // allow redirects to the original destination - // to support websites redirecting to the same page while setting - // session cookies - samePageRedirect := normalizeURL(req.URL.String()) == normalizeURL(via[0].URL.String()) + // Page may set cookies and respond with a redirect to itself. + // Some example of such redirect "cycles": + // + // example.com -(set cookie)-> example.com + // example.com -> auth.example.com -(set cookie)-> example.com + // www.example.com -> example.com -(set cookie)-> example.com + // + // We must not return "already visited" error in such cases. + // So ignore redirect cycles when checking for URL revisit. + redirectCycle := false + normalizedURL := normalizeURL(req.URL.String()) + for _, viaReq := range via { + viaURL := normalizeURL(viaReq.URL.String()) + if viaURL == normalizedURL { + redirectCycle = true + break + } + } - if !c.AllowURLRevisit && !samePageRedirect { + if !c.AllowURLRevisit && !redirectCycle { var body io.ReadCloser if req.GetBody != nil { var err error diff --git a/colly_test.go b/colly_test.go index 2382ecb1..9b3a650b 100644 --- a/colly_test.go +++ b/colly_test.go @@ -783,6 +783,30 @@ func TestSetCookieRedirect(t *testing.T) { } } +func TestSetCookieComplexRedirectCycle(t *testing.T) { + // server2 -> server1 -(set cookie)-> server1 + ts1 := newUnstartedTestServer() + ts1.Config.Handler = requireSessionCookieSimple(ts1.Config.Handler) + ts1.Start() + defer ts1.Close() + + ts2 := httptest.NewServer(http.RedirectHandler(ts1.URL, http.StatusMovedPermanently)) + defer ts2.Close() + + c := NewCollector() + c.OnResponse(func(r *Response) { + if got, want := r.Body, serverIndexResponse; !bytes.Equal(got, want) { + t.Errorf("bad response body got=%q want=%q", got, want) + } + if got, want := r.StatusCode, http.StatusOK; got != want { + t.Errorf("bad response code got=%d want=%d", got, want) + } + }) + if err := c.Visit(ts2.URL); err != nil { + t.Fatal(err) + } +} + func TestCollectorPostURLRevisitCheck(t *testing.T) { ts := newTestServer() defer ts.Close()