From bf2cfaf0ede0e9744408f52538fb4bcd87a6d5b8 Mon Sep 17 00:00:00 2001
From: r <r@freesoftwareextremist.com>
Date: Sat, 25 Jan 2020 10:07:06 +0000
Subject: Add CSRF protection

---
 service/auth.go      | 43 ++++++++++++++++++++++++++++++++++-
 service/service.go   | 63 ++++++++++++++++++++++++++--------------------------
 service/transport.go | 25 +++++++++++++++++++--
 3 files changed, 96 insertions(+), 35 deletions(-)

(limited to 'service')

diff --git a/service/auth.go b/service/auth.go
index e517383..909a9a2 100644
--- a/service/auth.go
+++ b/service/auth.go
@@ -11,7 +11,8 @@ import (
 )
 
 var (
-	ErrInvalidSession = errors.New("invalid session")
+	ErrInvalidSession   = errors.New("invalid session")
+	ErrInvalidCSRFToken = errors.New("invalid csrf token")
 )
 
 type authService struct {
@@ -47,6 +48,14 @@ func (s *authService) getClient(ctx context.Context) (c *model.Client, err error
 	return c, nil
 }
 
+func checkCSRF(ctx context.Context, c *model.Client) (err error) {
+	csrfToken, ok := ctx.Value("csrf_token").(string)
+	if !ok || csrfToken != c.Session.CSRFToken {
+		return ErrInvalidCSRFToken
+	}
+	return nil
+}
+
 func (s *authService) GetAuthUrl(ctx context.Context, instance string) (
 	redirectUrl string, sessionID string, err error) {
 	return s.Service.GetAuthUrl(ctx, instance)
@@ -184,6 +193,10 @@ func (s *authService) SaveSettings(ctx context.Context, client io.Writer, c *mod
 	if err != nil {
 		return
 	}
+	err = checkCSRF(ctx, c)
+	if err != nil {
+		return
+	}
 	return s.Service.SaveSettings(ctx, client, c, settings)
 }
 
@@ -192,6 +205,10 @@ func (s *authService) Like(ctx context.Context, client io.Writer, c *model.Clien
 	if err != nil {
 		return
 	}
+	err = checkCSRF(ctx, c)
+	if err != nil {
+		return
+	}
 	return s.Service.Like(ctx, client, c, id)
 }
 
@@ -200,6 +217,10 @@ func (s *authService) UnLike(ctx context.Context, client io.Writer, c *model.Cli
 	if err != nil {
 		return
 	}
+	err = checkCSRF(ctx, c)
+	if err != nil {
+		return
+	}
 	return s.Service.UnLike(ctx, client, c, id)
 }
 
@@ -208,6 +229,10 @@ func (s *authService) Retweet(ctx context.Context, client io.Writer, c *model.Cl
 	if err != nil {
 		return
 	}
+	err = checkCSRF(ctx, c)
+	if err != nil {
+		return
+	}
 	return s.Service.Retweet(ctx, client, c, id)
 }
 
@@ -216,6 +241,10 @@ func (s *authService) UnRetweet(ctx context.Context, client io.Writer, c *model.
 	if err != nil {
 		return
 	}
+	err = checkCSRF(ctx, c)
+	if err != nil {
+		return
+	}
 	return s.Service.UnRetweet(ctx, client, c, id)
 }
 
@@ -224,6 +253,10 @@ func (s *authService) PostTweet(ctx context.Context, client io.Writer, c *model.
 	if err != nil {
 		return
 	}
+	err = checkCSRF(ctx, c)
+	if err != nil {
+		return
+	}
 	return s.Service.PostTweet(ctx, client, c, content, replyToID, format, visibility, isNSFW, files)
 }
 
@@ -232,6 +265,10 @@ func (s *authService) Follow(ctx context.Context, client io.Writer, c *model.Cli
 	if err != nil {
 		return
 	}
+	err = checkCSRF(ctx, c)
+	if err != nil {
+		return
+	}
 	return s.Service.Follow(ctx, client, c, id)
 }
 
@@ -240,5 +277,9 @@ func (s *authService) UnFollow(ctx context.Context, client io.Writer, c *model.C
 	if err != nil {
 		return
 	}
+	err = checkCSRF(ctx, c)
+	if err != nil {
+		return
+	}
 	return s.Service.UnFollow(ctx, client, c, id)
 }
diff --git a/service/service.go b/service/service.go
index bfacf80..db851f7 100644
--- a/service/service.go
+++ b/service/service.go
@@ -78,12 +78,21 @@ func NewService(clientName string, clientScope string, clientWebsite string,
 	}
 }
 
-func getRendererContext(s model.Settings) *renderer.Context {
+func getRendererContext(c *model.Client) *renderer.Context {
+	var settings model.Settings
+	var session model.Session
+	if c != nil {
+		settings = c.Session.Settings
+		session = c.Session
+	} else {
+		settings = *model.NewSettings()
+	}
 	return &renderer.Context{
-		MaskNSFW:       s.MaskNSFW,
-		ThreadInNewTab: s.ThreadInNewTab,
-		FluorideMode:   s.FluorideMode,
-		DarkMode:       s.DarkMode,
+		MaskNSFW:       settings.MaskNSFW,
+		ThreadInNewTab: settings.ThreadInNewTab,
+		FluorideMode:   settings.FluorideMode,
+		DarkMode:       settings.DarkMode,
+		CSRFToken:      session.CSRFToken,
 	}
 }
 
@@ -98,9 +107,11 @@ func (svc *service) GetAuthUrl(ctx context.Context, instance string) (
 	}
 
 	sessionID = util.NewSessionId()
+	csrfToken := util.NewCSRFToken()
 	session := model.Session{
 		ID:             sessionID,
 		InstanceDomain: instance,
+		CSRFToken:      csrfToken,
 		Settings:       *model.NewSettings(),
 	}
 	err = svc.sessionRepo.Add(session)
@@ -199,13 +210,6 @@ func (svc *service) GetUserToken(ctx context.Context, sessionID string, c *model
 	if err != nil {
 		return
 	}
-	/*
-		err = c.AuthenticateToken(ctx, code, svc.clientWebsite+"/oauth_callback")
-		if err != nil {
-			return
-		}
-		err = svc.sessionRepo.Update(sessionID, c.GetAccessToken(ctx))
-	*/
 
 	return res.AccessToken, nil
 }
@@ -226,13 +230,7 @@ func (svc *service) ServeErrorPage(ctx context.Context, client io.Writer, c *mod
 		Error:      errStr,
 	}
 
-	var s model.Settings
-	if c != nil {
-		s = c.Session.Settings
-	} else {
-		s = *model.NewSettings()
-	}
-	rCtx := getRendererContext(s)
+	rCtx := getRendererContext(c)
 
 	svc.renderer.RenderErrorPage(rCtx, client, data)
 }
@@ -247,7 +245,7 @@ func (svc *service) ServeSigninPage(ctx context.Context, client io.Writer) (err
 		CommonData: commonData,
 	}
 
-	rCtx := getRendererContext(*model.NewSettings())
+	rCtx := getRendererContext(nil)
 	return svc.renderer.RenderSigninPage(rCtx, client, data)
 }
 
@@ -334,7 +332,7 @@ func (svc *service) ServeTimelinePage(ctx context.Context, client io.Writer,
 		PostContext: postContext,
 		CommonData:  commonData,
 	}
-	rCtx := getRendererContext(c.Session.Settings)
+	rCtx := getRendererContext(c)
 
 	err = svc.renderer.RenderTimelinePage(rCtx, client, data)
 	if err != nil {
@@ -416,7 +414,7 @@ func (svc *service) ServeThreadPage(ctx context.Context, client io.Writer, c *mo
 		ReplyMap:    replyMap,
 		CommonData:  commonData,
 	}
-	rCtx := getRendererContext(c.Session.Settings)
+	rCtx := getRendererContext(c)
 
 	err = svc.renderer.RenderThreadPage(rCtx, client, data)
 	if err != nil {
@@ -478,7 +476,7 @@ func (svc *service) ServeNotificationPage(ctx context.Context, client io.Writer,
 		NextLink:      nextLink,
 		CommonData:    commonData,
 	}
-	rCtx := getRendererContext(c.Session.Settings)
+	rCtx := getRendererContext(c)
 
 	err = svc.renderer.RenderNotificationPage(rCtx, client, data)
 	if err != nil {
@@ -525,7 +523,7 @@ func (svc *service) ServeUserPage(ctx context.Context, client io.Writer, c *mode
 		NextLink:   nextLink,
 		CommonData: commonData,
 	}
-	rCtx := getRendererContext(c.Session.Settings)
+	rCtx := getRendererContext(c)
 
 	err = svc.renderer.RenderUserPage(rCtx, client, data)
 	if err != nil {
@@ -544,7 +542,7 @@ func (svc *service) ServeAboutPage(ctx context.Context, client io.Writer, c *mod
 	data := &renderer.AboutData{
 		CommonData: commonData,
 	}
-	rCtx := getRendererContext(c.Session.Settings)
+	rCtx := getRendererContext(c)
 
 	err = svc.renderer.RenderAboutPage(rCtx, client, data)
 	if err != nil {
@@ -569,7 +567,7 @@ func (svc *service) ServeEmojiPage(ctx context.Context, client io.Writer, c *mod
 		Emojis:     emojis,
 		CommonData: commonData,
 	}
-	rCtx := getRendererContext(c.Session.Settings)
+	rCtx := getRendererContext(c)
 
 	err = svc.renderer.RenderEmojiPage(rCtx, client, data)
 	if err != nil {
@@ -594,7 +592,7 @@ func (svc *service) ServeLikedByPage(ctx context.Context, client io.Writer, c *m
 		CommonData: commonData,
 		Users:      likers,
 	}
-	rCtx := getRendererContext(c.Session.Settings)
+	rCtx := getRendererContext(c)
 
 	err = svc.renderer.RenderLikedByPage(rCtx, client, data)
 	if err != nil {
@@ -619,7 +617,7 @@ func (svc *service) ServeRetweetedByPage(ctx context.Context, client io.Writer,
 		CommonData: commonData,
 		Users:      retweeters,
 	}
-	rCtx := getRendererContext(c.Session.Settings)
+	rCtx := getRendererContext(c)
 
 	err = svc.renderer.RenderRetweetedByPage(rCtx, client, data)
 	if err != nil {
@@ -660,7 +658,7 @@ func (svc *service) ServeFollowingPage(ctx context.Context, client io.Writer, c
 		HasNext:    hasNext,
 		NextLink:   nextLink,
 	}
-	rCtx := getRendererContext(c.Session.Settings)
+	rCtx := getRendererContext(c)
 
 	err = svc.renderer.RenderFollowingPage(rCtx, client, data)
 	if err != nil {
@@ -701,7 +699,7 @@ func (svc *service) ServeFollowersPage(ctx context.Context, client io.Writer, c
 		HasNext:    hasNext,
 		NextLink:   nextLink,
 	}
-	rCtx := getRendererContext(c.Session.Settings)
+	rCtx := getRendererContext(c)
 
 	err = svc.renderer.RenderFollowersPage(rCtx, client, data)
 	if err != nil {
@@ -750,7 +748,7 @@ func (svc *service) ServeSearchPage(ctx context.Context, client io.Writer, c *mo
 		HasNext:    hasNext,
 		NextLink:   nextLink,
 	}
-	rCtx := getRendererContext(c.Session.Settings)
+	rCtx := getRendererContext(c)
 
 	err = svc.renderer.RenderSearchPage(rCtx, client, data)
 	if err != nil {
@@ -770,7 +768,7 @@ func (svc *service) ServeSettingsPage(ctx context.Context, client io.Writer, c *
 		CommonData: commonData,
 		Settings:   &c.Session.Settings,
 	}
-	rCtx := getRendererContext(c.Session.Settings)
+	rCtx := getRendererContext(c)
 
 	err = svc.renderer.RenderSettingsPage(rCtx, client, data)
 	if err != nil {
@@ -828,6 +826,7 @@ func (svc *service) getCommonData(ctx context.Context, client io.Writer, c *mode
 		}
 
 		data.HeaderData.NotificationCount = notificationCount
+		data.HeaderData.CSRFToken = c.Session.CSRFToken
 	}
 
 	return
diff --git a/service/transport.go b/service/transport.go
index 8cca4f5..e878f8d 100644
--- a/service/transport.go
+++ b/service/transport.go
@@ -160,6 +160,8 @@ func NewHandler(s Service, staticDir string) http.Handler {
 
 	r.HandleFunc("/like/{id}", func(w http.ResponseWriter, req *http.Request) {
 		ctx := getContextWithSession(context.Background(), req)
+		ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token"))
+
 		id, _ := mux.Vars(req)["id"]
 		retweetedByID := req.FormValue("retweeted_by_id")
 
@@ -179,6 +181,8 @@ func NewHandler(s Service, staticDir string) http.Handler {
 
 	r.HandleFunc("/unlike/{id}", func(w http.ResponseWriter, req *http.Request) {
 		ctx := getContextWithSession(context.Background(), req)
+		ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token"))
+
 		id, _ := mux.Vars(req)["id"]
 		retweetedByID := req.FormValue("retweeted_by_id")
 
@@ -198,6 +202,8 @@ func NewHandler(s Service, staticDir string) http.Handler {
 
 	r.HandleFunc("/retweet/{id}", func(w http.ResponseWriter, req *http.Request) {
 		ctx := getContextWithSession(context.Background(), req)
+		ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token"))
+
 		id, _ := mux.Vars(req)["id"]
 		retweetedByID := req.FormValue("retweeted_by_id")
 
@@ -217,6 +223,8 @@ func NewHandler(s Service, staticDir string) http.Handler {
 
 	r.HandleFunc("/unretweet/{id}", func(w http.ResponseWriter, req *http.Request) {
 		ctx := getContextWithSession(context.Background(), req)
+		ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token"))
+
 		id, _ := mux.Vars(req)["id"]
 		retweetedByID := req.FormValue("retweeted_by_id")
 
@@ -236,6 +244,8 @@ func NewHandler(s Service, staticDir string) http.Handler {
 
 	r.HandleFunc("/fluoride/like/{id}", func(w http.ResponseWriter, req *http.Request) {
 		ctx := getContextWithSession(context.Background(), req)
+		ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token"))
+
 		id, _ := mux.Vars(req)["id"]
 		count, err := s.Like(ctx, w, nil, id)
 		if err != nil {
@@ -252,6 +262,8 @@ func NewHandler(s Service, staticDir string) http.Handler {
 
 	r.HandleFunc("/fluoride/unlike/{id}", func(w http.ResponseWriter, req *http.Request) {
 		ctx := getContextWithSession(context.Background(), req)
+		ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token"))
+
 		id, _ := mux.Vars(req)["id"]
 		count, err := s.UnLike(ctx, w, nil, id)
 		if err != nil {
@@ -268,6 +280,8 @@ func NewHandler(s Service, staticDir string) http.Handler {
 
 	r.HandleFunc("/fluoride/retweet/{id}", func(w http.ResponseWriter, req *http.Request) {
 		ctx := getContextWithSession(context.Background(), req)
+		ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token"))
+
 		id, _ := mux.Vars(req)["id"]
 		count, err := s.Retweet(ctx, w, nil, id)
 		if err != nil {
@@ -284,6 +298,8 @@ func NewHandler(s Service, staticDir string) http.Handler {
 
 	r.HandleFunc("/fluoride/unretweet/{id}", func(w http.ResponseWriter, req *http.Request) {
 		ctx := getContextWithSession(context.Background(), req)
+		ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token"))
+
 		id, _ := mux.Vars(req)["id"]
 		count, err := s.UnRetweet(ctx, w, nil, id)
 		if err != nil {
@@ -299,14 +315,16 @@ func NewHandler(s Service, staticDir string) http.Handler {
 	}).Methods(http.MethodPost)
 
 	r.HandleFunc("/post", func(w http.ResponseWriter, req *http.Request) {
-		ctx := getContextWithSession(context.Background(), req)
-
 		err := req.ParseMultipartForm(4 << 20)
 		if err != nil {
 			s.ServeErrorPage(ctx, w, nil, err)
 			return
 		}
 
+		ctx := getContextWithSession(context.Background(), req)
+		ctx = context.WithValue(ctx, "csrf_token",
+			getMultipartFormValue(req.MultipartForm, "csrf_token"))
+
 		content := getMultipartFormValue(req.MultipartForm, "content")
 		replyToID := getMultipartFormValue(req.MultipartForm, "reply_to_id")
 		format := getMultipartFormValue(req.MultipartForm, "format")
@@ -358,6 +376,7 @@ func NewHandler(s Service, staticDir string) http.Handler {
 
 	r.HandleFunc("/follow/{id}", func(w http.ResponseWriter, req *http.Request) {
 		ctx := getContextWithSession(context.Background(), req)
+		ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token"))
 
 		id, _ := mux.Vars(req)["id"]
 
@@ -373,6 +392,7 @@ func NewHandler(s Service, staticDir string) http.Handler {
 
 	r.HandleFunc("/unfollow/{id}", func(w http.ResponseWriter, req *http.Request) {
 		ctx := getContextWithSession(context.Background(), req)
+		ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token"))
 
 		id, _ := mux.Vars(req)["id"]
 
@@ -442,6 +462,7 @@ func NewHandler(s Service, staticDir string) http.Handler {
 
 	r.HandleFunc("/settings", func(w http.ResponseWriter, req *http.Request) {
 		ctx := getContextWithSession(context.Background(), req)
+		ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token"))
 
 		visibility := req.FormValue("visibility")
 		copyScope := req.FormValue("copy_scope") == "true"
-- 
cgit v1.2.3