aboutsummaryrefslogtreecommitdiff
path: root/service/service.go
diff options
context:
space:
mode:
authorr <r@freesoftwareextremist.com>2020-01-25 10:07:06 +0000
committerr <r@freesoftwareextremist.com>2020-01-26 06:49:29 +0000
commitbf2cfaf0ede0e9744408f52538fb4bcd87a6d5b8 (patch)
tree5d3be1dfa65395bddedd2fb6f06a990c23274f00 /service/service.go
parent5fdc7a59b2efc60e35f5421e28986c356810456e (diff)
downloadbloat-bf2cfaf0ede0e9744408f52538fb4bcd87a6d5b8.tar.gz
bloat-bf2cfaf0ede0e9744408f52538fb4bcd87a6d5b8.zip
Add CSRF protection
Diffstat (limited to 'service/service.go')
-rw-r--r--service/service.go63
1 files changed, 31 insertions, 32 deletions
diff --git a/service/service.go b/service/service.go
index bfacf80..db851f7 100644
--- a/service/service.go
+++ b/service/service.go
@@ -78,12 +78,21 @@ func NewService(clientName string, clientScope string, clientWebsite string,
}
}
-func getRendererContext(s model.Settings) *renderer.Context {
+func getRendererContext(c *model.Client) *renderer.Context {
+ var settings model.Settings
+ var session model.Session
+ if c != nil {
+ settings = c.Session.Settings
+ session = c.Session
+ } else {
+ settings = *model.NewSettings()
+ }
return &renderer.Context{
- MaskNSFW: s.MaskNSFW,
- ThreadInNewTab: s.ThreadInNewTab,
- FluorideMode: s.FluorideMode,
- DarkMode: s.DarkMode,
+ MaskNSFW: settings.MaskNSFW,
+ ThreadInNewTab: settings.ThreadInNewTab,
+ FluorideMode: settings.FluorideMode,
+ DarkMode: settings.DarkMode,
+ CSRFToken: session.CSRFToken,
}
}
@@ -98,9 +107,11 @@ func (svc *service) GetAuthUrl(ctx context.Context, instance string) (
}
sessionID = util.NewSessionId()
+ csrfToken := util.NewCSRFToken()
session := model.Session{
ID: sessionID,
InstanceDomain: instance,
+ CSRFToken: csrfToken,
Settings: *model.NewSettings(),
}
err = svc.sessionRepo.Add(session)
@@ -199,13 +210,6 @@ func (svc *service) GetUserToken(ctx context.Context, sessionID string, c *model
if err != nil {
return
}
- /*
- err = c.AuthenticateToken(ctx, code, svc.clientWebsite+"/oauth_callback")
- if err != nil {
- return
- }
- err = svc.sessionRepo.Update(sessionID, c.GetAccessToken(ctx))
- */
return res.AccessToken, nil
}
@@ -226,13 +230,7 @@ func (svc *service) ServeErrorPage(ctx context.Context, client io.Writer, c *mod
Error: errStr,
}
- var s model.Settings
- if c != nil {
- s = c.Session.Settings
- } else {
- s = *model.NewSettings()
- }
- rCtx := getRendererContext(s)
+ rCtx := getRendererContext(c)
svc.renderer.RenderErrorPage(rCtx, client, data)
}
@@ -247,7 +245,7 @@ func (svc *service) ServeSigninPage(ctx context.Context, client io.Writer) (err
CommonData: commonData,
}
- rCtx := getRendererContext(*model.NewSettings())
+ rCtx := getRendererContext(nil)
return svc.renderer.RenderSigninPage(rCtx, client, data)
}
@@ -334,7 +332,7 @@ func (svc *service) ServeTimelinePage(ctx context.Context, client io.Writer,
PostContext: postContext,
CommonData: commonData,
}
- rCtx := getRendererContext(c.Session.Settings)
+ rCtx := getRendererContext(c)
err = svc.renderer.RenderTimelinePage(rCtx, client, data)
if err != nil {
@@ -416,7 +414,7 @@ func (svc *service) ServeThreadPage(ctx context.Context, client io.Writer, c *mo
ReplyMap: replyMap,
CommonData: commonData,
}
- rCtx := getRendererContext(c.Session.Settings)
+ rCtx := getRendererContext(c)
err = svc.renderer.RenderThreadPage(rCtx, client, data)
if err != nil {
@@ -478,7 +476,7 @@ func (svc *service) ServeNotificationPage(ctx context.Context, client io.Writer,
NextLink: nextLink,
CommonData: commonData,
}
- rCtx := getRendererContext(c.Session.Settings)
+ rCtx := getRendererContext(c)
err = svc.renderer.RenderNotificationPage(rCtx, client, data)
if err != nil {
@@ -525,7 +523,7 @@ func (svc *service) ServeUserPage(ctx context.Context, client io.Writer, c *mode
NextLink: nextLink,
CommonData: commonData,
}
- rCtx := getRendererContext(c.Session.Settings)
+ rCtx := getRendererContext(c)
err = svc.renderer.RenderUserPage(rCtx, client, data)
if err != nil {
@@ -544,7 +542,7 @@ func (svc *service) ServeAboutPage(ctx context.Context, client io.Writer, c *mod
data := &renderer.AboutData{
CommonData: commonData,
}
- rCtx := getRendererContext(c.Session.Settings)
+ rCtx := getRendererContext(c)
err = svc.renderer.RenderAboutPage(rCtx, client, data)
if err != nil {
@@ -569,7 +567,7 @@ func (svc *service) ServeEmojiPage(ctx context.Context, client io.Writer, c *mod
Emojis: emojis,
CommonData: commonData,
}
- rCtx := getRendererContext(c.Session.Settings)
+ rCtx := getRendererContext(c)
err = svc.renderer.RenderEmojiPage(rCtx, client, data)
if err != nil {
@@ -594,7 +592,7 @@ func (svc *service) ServeLikedByPage(ctx context.Context, client io.Writer, c *m
CommonData: commonData,
Users: likers,
}
- rCtx := getRendererContext(c.Session.Settings)
+ rCtx := getRendererContext(c)
err = svc.renderer.RenderLikedByPage(rCtx, client, data)
if err != nil {
@@ -619,7 +617,7 @@ func (svc *service) ServeRetweetedByPage(ctx context.Context, client io.Writer,
CommonData: commonData,
Users: retweeters,
}
- rCtx := getRendererContext(c.Session.Settings)
+ rCtx := getRendererContext(c)
err = svc.renderer.RenderRetweetedByPage(rCtx, client, data)
if err != nil {
@@ -660,7 +658,7 @@ func (svc *service) ServeFollowingPage(ctx context.Context, client io.Writer, c
HasNext: hasNext,
NextLink: nextLink,
}
- rCtx := getRendererContext(c.Session.Settings)
+ rCtx := getRendererContext(c)
err = svc.renderer.RenderFollowingPage(rCtx, client, data)
if err != nil {
@@ -701,7 +699,7 @@ func (svc *service) ServeFollowersPage(ctx context.Context, client io.Writer, c
HasNext: hasNext,
NextLink: nextLink,
}
- rCtx := getRendererContext(c.Session.Settings)
+ rCtx := getRendererContext(c)
err = svc.renderer.RenderFollowersPage(rCtx, client, data)
if err != nil {
@@ -750,7 +748,7 @@ func (svc *service) ServeSearchPage(ctx context.Context, client io.Writer, c *mo
HasNext: hasNext,
NextLink: nextLink,
}
- rCtx := getRendererContext(c.Session.Settings)
+ rCtx := getRendererContext(c)
err = svc.renderer.RenderSearchPage(rCtx, client, data)
if err != nil {
@@ -770,7 +768,7 @@ func (svc *service) ServeSettingsPage(ctx context.Context, client io.Writer, c *
CommonData: commonData,
Settings: &c.Session.Settings,
}
- rCtx := getRendererContext(c.Session.Settings)
+ rCtx := getRendererContext(c)
err = svc.renderer.RenderSettingsPage(rCtx, client, data)
if err != nil {
@@ -828,6 +826,7 @@ func (svc *service) getCommonData(ctx context.Context, client io.Writer, c *mode
}
data.HeaderData.NotificationCount = notificationCount
+ data.HeaderData.CSRFToken = c.Session.CSRFToken
}
return