From bf2cfaf0ede0e9744408f52538fb4bcd87a6d5b8 Mon Sep 17 00:00:00 2001 From: r Date: Sat, 25 Jan 2020 10:07:06 +0000 Subject: Add CSRF protection --- service/transport.go | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) (limited to 'service/transport.go') diff --git a/service/transport.go b/service/transport.go index 8cca4f5..e878f8d 100644 --- a/service/transport.go +++ b/service/transport.go @@ -160,6 +160,8 @@ func NewHandler(s Service, staticDir string) http.Handler { r.HandleFunc("/like/{id}", func(w http.ResponseWriter, req *http.Request) { ctx := getContextWithSession(context.Background(), req) + ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) + id, _ := mux.Vars(req)["id"] retweetedByID := req.FormValue("retweeted_by_id") @@ -179,6 +181,8 @@ func NewHandler(s Service, staticDir string) http.Handler { r.HandleFunc("/unlike/{id}", func(w http.ResponseWriter, req *http.Request) { ctx := getContextWithSession(context.Background(), req) + ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) + id, _ := mux.Vars(req)["id"] retweetedByID := req.FormValue("retweeted_by_id") @@ -198,6 +202,8 @@ func NewHandler(s Service, staticDir string) http.Handler { r.HandleFunc("/retweet/{id}", func(w http.ResponseWriter, req *http.Request) { ctx := getContextWithSession(context.Background(), req) + ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) + id, _ := mux.Vars(req)["id"] retweetedByID := req.FormValue("retweeted_by_id") @@ -217,6 +223,8 @@ func NewHandler(s Service, staticDir string) http.Handler { r.HandleFunc("/unretweet/{id}", func(w http.ResponseWriter, req *http.Request) { ctx := getContextWithSession(context.Background(), req) + ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) + id, _ := mux.Vars(req)["id"] retweetedByID := req.FormValue("retweeted_by_id") @@ -236,6 +244,8 @@ func NewHandler(s Service, staticDir string) http.Handler { r.HandleFunc("/fluoride/like/{id}", func(w http.ResponseWriter, req *http.Request) { ctx := getContextWithSession(context.Background(), req) + ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) + id, _ := mux.Vars(req)["id"] count, err := s.Like(ctx, w, nil, id) if err != nil { @@ -252,6 +262,8 @@ func NewHandler(s Service, staticDir string) http.Handler { r.HandleFunc("/fluoride/unlike/{id}", func(w http.ResponseWriter, req *http.Request) { ctx := getContextWithSession(context.Background(), req) + ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) + id, _ := mux.Vars(req)["id"] count, err := s.UnLike(ctx, w, nil, id) if err != nil { @@ -268,6 +280,8 @@ func NewHandler(s Service, staticDir string) http.Handler { r.HandleFunc("/fluoride/retweet/{id}", func(w http.ResponseWriter, req *http.Request) { ctx := getContextWithSession(context.Background(), req) + ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) + id, _ := mux.Vars(req)["id"] count, err := s.Retweet(ctx, w, nil, id) if err != nil { @@ -284,6 +298,8 @@ func NewHandler(s Service, staticDir string) http.Handler { r.HandleFunc("/fluoride/unretweet/{id}", func(w http.ResponseWriter, req *http.Request) { ctx := getContextWithSession(context.Background(), req) + ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) + id, _ := mux.Vars(req)["id"] count, err := s.UnRetweet(ctx, w, nil, id) if err != nil { @@ -299,14 +315,16 @@ func NewHandler(s Service, staticDir string) http.Handler { }).Methods(http.MethodPost) r.HandleFunc("/post", func(w http.ResponseWriter, req *http.Request) { - ctx := getContextWithSession(context.Background(), req) - err := req.ParseMultipartForm(4 << 20) if err != nil { s.ServeErrorPage(ctx, w, nil, err) return } + ctx := getContextWithSession(context.Background(), req) + ctx = context.WithValue(ctx, "csrf_token", + getMultipartFormValue(req.MultipartForm, "csrf_token")) + content := getMultipartFormValue(req.MultipartForm, "content") replyToID := getMultipartFormValue(req.MultipartForm, "reply_to_id") format := getMultipartFormValue(req.MultipartForm, "format") @@ -358,6 +376,7 @@ func NewHandler(s Service, staticDir string) http.Handler { r.HandleFunc("/follow/{id}", func(w http.ResponseWriter, req *http.Request) { ctx := getContextWithSession(context.Background(), req) + ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) id, _ := mux.Vars(req)["id"] @@ -373,6 +392,7 @@ func NewHandler(s Service, staticDir string) http.Handler { r.HandleFunc("/unfollow/{id}", func(w http.ResponseWriter, req *http.Request) { ctx := getContextWithSession(context.Background(), req) + ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) id, _ := mux.Vars(req)["id"] @@ -442,6 +462,7 @@ func NewHandler(s Service, staticDir string) http.Handler { r.HandleFunc("/settings", func(w http.ResponseWriter, req *http.Request) { ctx := getContextWithSession(context.Background(), req) + ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) visibility := req.FormValue("visibility") copyScope := req.FormValue("copy_scope") == "true" -- cgit v1.2.3