From bf2cfaf0ede0e9744408f52538fb4bcd87a6d5b8 Mon Sep 17 00:00:00 2001 From: r Date: Sat, 25 Jan 2020 10:07:06 +0000 Subject: Add CSRF protection --- service/service.go | 63 +++++++++++++++++++++++++++--------------------------- 1 file changed, 31 insertions(+), 32 deletions(-) (limited to 'service/service.go') diff --git a/service/service.go b/service/service.go index bfacf80..db851f7 100644 --- a/service/service.go +++ b/service/service.go @@ -78,12 +78,21 @@ func NewService(clientName string, clientScope string, clientWebsite string, } } -func getRendererContext(s model.Settings) *renderer.Context { +func getRendererContext(c *model.Client) *renderer.Context { + var settings model.Settings + var session model.Session + if c != nil { + settings = c.Session.Settings + session = c.Session + } else { + settings = *model.NewSettings() + } return &renderer.Context{ - MaskNSFW: s.MaskNSFW, - ThreadInNewTab: s.ThreadInNewTab, - FluorideMode: s.FluorideMode, - DarkMode: s.DarkMode, + MaskNSFW: settings.MaskNSFW, + ThreadInNewTab: settings.ThreadInNewTab, + FluorideMode: settings.FluorideMode, + DarkMode: settings.DarkMode, + CSRFToken: session.CSRFToken, } } @@ -98,9 +107,11 @@ func (svc *service) GetAuthUrl(ctx context.Context, instance string) ( } sessionID = util.NewSessionId() + csrfToken := util.NewCSRFToken() session := model.Session{ ID: sessionID, InstanceDomain: instance, + CSRFToken: csrfToken, Settings: *model.NewSettings(), } err = svc.sessionRepo.Add(session) @@ -199,13 +210,6 @@ func (svc *service) GetUserToken(ctx context.Context, sessionID string, c *model if err != nil { return } - /* - err = c.AuthenticateToken(ctx, code, svc.clientWebsite+"/oauth_callback") - if err != nil { - return - } - err = svc.sessionRepo.Update(sessionID, c.GetAccessToken(ctx)) - */ return res.AccessToken, nil } @@ -226,13 +230,7 @@ func (svc *service) ServeErrorPage(ctx context.Context, client io.Writer, c *mod Error: errStr, } - var s model.Settings - if c != nil { - s = c.Session.Settings - } else { - s = *model.NewSettings() - } - rCtx := getRendererContext(s) + rCtx := getRendererContext(c) svc.renderer.RenderErrorPage(rCtx, client, data) } @@ -247,7 +245,7 @@ func (svc *service) ServeSigninPage(ctx context.Context, client io.Writer) (err CommonData: commonData, } - rCtx := getRendererContext(*model.NewSettings()) + rCtx := getRendererContext(nil) return svc.renderer.RenderSigninPage(rCtx, client, data) } @@ -334,7 +332,7 @@ func (svc *service) ServeTimelinePage(ctx context.Context, client io.Writer, PostContext: postContext, CommonData: commonData, } - rCtx := getRendererContext(c.Session.Settings) + rCtx := getRendererContext(c) err = svc.renderer.RenderTimelinePage(rCtx, client, data) if err != nil { @@ -416,7 +414,7 @@ func (svc *service) ServeThreadPage(ctx context.Context, client io.Writer, c *mo ReplyMap: replyMap, CommonData: commonData, } - rCtx := getRendererContext(c.Session.Settings) + rCtx := getRendererContext(c) err = svc.renderer.RenderThreadPage(rCtx, client, data) if err != nil { @@ -478,7 +476,7 @@ func (svc *service) ServeNotificationPage(ctx context.Context, client io.Writer, NextLink: nextLink, CommonData: commonData, } - rCtx := getRendererContext(c.Session.Settings) + rCtx := getRendererContext(c) err = svc.renderer.RenderNotificationPage(rCtx, client, data) if err != nil { @@ -525,7 +523,7 @@ func (svc *service) ServeUserPage(ctx context.Context, client io.Writer, c *mode NextLink: nextLink, CommonData: commonData, } - rCtx := getRendererContext(c.Session.Settings) + rCtx := getRendererContext(c) err = svc.renderer.RenderUserPage(rCtx, client, data) if err != nil { @@ -544,7 +542,7 @@ func (svc *service) ServeAboutPage(ctx context.Context, client io.Writer, c *mod data := &renderer.AboutData{ CommonData: commonData, } - rCtx := getRendererContext(c.Session.Settings) + rCtx := getRendererContext(c) err = svc.renderer.RenderAboutPage(rCtx, client, data) if err != nil { @@ -569,7 +567,7 @@ func (svc *service) ServeEmojiPage(ctx context.Context, client io.Writer, c *mod Emojis: emojis, CommonData: commonData, } - rCtx := getRendererContext(c.Session.Settings) + rCtx := getRendererContext(c) err = svc.renderer.RenderEmojiPage(rCtx, client, data) if err != nil { @@ -594,7 +592,7 @@ func (svc *service) ServeLikedByPage(ctx context.Context, client io.Writer, c *m CommonData: commonData, Users: likers, } - rCtx := getRendererContext(c.Session.Settings) + rCtx := getRendererContext(c) err = svc.renderer.RenderLikedByPage(rCtx, client, data) if err != nil { @@ -619,7 +617,7 @@ func (svc *service) ServeRetweetedByPage(ctx context.Context, client io.Writer, CommonData: commonData, Users: retweeters, } - rCtx := getRendererContext(c.Session.Settings) + rCtx := getRendererContext(c) err = svc.renderer.RenderRetweetedByPage(rCtx, client, data) if err != nil { @@ -660,7 +658,7 @@ func (svc *service) ServeFollowingPage(ctx context.Context, client io.Writer, c HasNext: hasNext, NextLink: nextLink, } - rCtx := getRendererContext(c.Session.Settings) + rCtx := getRendererContext(c) err = svc.renderer.RenderFollowingPage(rCtx, client, data) if err != nil { @@ -701,7 +699,7 @@ func (svc *service) ServeFollowersPage(ctx context.Context, client io.Writer, c HasNext: hasNext, NextLink: nextLink, } - rCtx := getRendererContext(c.Session.Settings) + rCtx := getRendererContext(c) err = svc.renderer.RenderFollowersPage(rCtx, client, data) if err != nil { @@ -750,7 +748,7 @@ func (svc *service) ServeSearchPage(ctx context.Context, client io.Writer, c *mo HasNext: hasNext, NextLink: nextLink, } - rCtx := getRendererContext(c.Session.Settings) + rCtx := getRendererContext(c) err = svc.renderer.RenderSearchPage(rCtx, client, data) if err != nil { @@ -770,7 +768,7 @@ func (svc *service) ServeSettingsPage(ctx context.Context, client io.Writer, c * CommonData: commonData, Settings: &c.Session.Settings, } - rCtx := getRendererContext(c.Session.Settings) + rCtx := getRendererContext(c) err = svc.renderer.RenderSettingsPage(rCtx, client, data) if err != nil { @@ -828,6 +826,7 @@ func (svc *service) getCommonData(ctx context.Context, client io.Writer, c *mode } data.HeaderData.NotificationCount = notificationCount + data.HeaderData.CSRFToken = c.Session.CSRFToken } return -- cgit v1.2.3