diff options
| author | r <r@freesoftwareextremist.com> | 2020-01-28 17:51:00 +0000 | 
|---|---|---|
| committer | r <r@freesoftwareextremist.com> | 2020-01-28 17:58:29 +0000 | 
| commit | 2af37d47783aac8c650ffd1578e2297b5784c73d (patch) | |
| tree | 7b5c7a4b2fa530285bfaa16324e818d97dd00408 /service | |
| parent | 57d2a4288b02fd1245ee85ae629649798578cf6c (diff) | |
| download | bloat-2af37d47783aac8c650ffd1578e2297b5784c73d.tar.gz bloat-2af37d47783aac8c650ffd1578e2297b5784c73d.zip  | |
Refactor everything
Diffstat (limited to 'service')
| -rw-r--r-- | service/auth.go | 208 | ||||
| -rw-r--r-- | service/logging.go | 190 | ||||
| -rw-r--r-- | service/service.go | 811 | ||||
| -rw-r--r-- | service/transport.go | 650 | 
4 files changed, 916 insertions, 943 deletions
diff --git a/service/auth.go b/service/auth.go index 909a9a2..78934fd 100644 --- a/service/auth.go +++ b/service/auth.go @@ -3,7 +3,6 @@ package service  import (  	"context"  	"errors" -	"io"  	"mime/multipart"  	"bloat/model" @@ -11,28 +10,28 @@ import (  )  var ( -	ErrInvalidSession   = errors.New("invalid session") -	ErrInvalidCSRFToken = errors.New("invalid csrf token") +	errInvalidSession   = errors.New("invalid session") +	errInvalidCSRFToken = errors.New("invalid csrf token")  ) -type authService struct { -	sessionRepo model.SessionRepository -	appRepo     model.AppRepository +type as struct { +	sessionRepo model.SessionRepo +	appRepo     model.AppRepo  	Service  } -func NewAuthService(sessionRepo model.SessionRepository, appRepo model.AppRepository, s Service) Service { -	return &authService{sessionRepo, appRepo, s} +func NewAuthService(sessionRepo model.SessionRepo, appRepo model.AppRepo, s Service) Service { +	return &as{sessionRepo, appRepo, s}  } -func (s *authService) getClient(ctx context.Context) (c *model.Client, err error) { +func (s *as) authenticateClient(ctx context.Context, c *model.Client) (err error) {  	sessionID, ok := ctx.Value("session_id").(string)  	if !ok || len(sessionID) < 1 { -		return nil, ErrInvalidSession +		return errInvalidSession  	}  	session, err := s.sessionRepo.Get(sessionID)  	if err != nil { -		return nil, ErrInvalidSession +		return errInvalidSession  	}  	client, err := s.appRepo.Get(session.InstanceDomain)  	if err != nil { @@ -44,152 +43,163 @@ func (s *authService) getClient(ctx context.Context) (c *model.Client, err error  		ClientSecret: client.ClientSecret,  		AccessToken:  session.AccessToken,  	}) -	c = &model.Client{Client: mc, Session: session} -	return c, nil +	if c == nil { +		c = &model.Client{} +	} +	c.Client = mc +	c.Session = session +	return nil  }  func checkCSRF(ctx context.Context, c *model.Client) (err error) { -	csrfToken, ok := ctx.Value("csrf_token").(string) -	if !ok || csrfToken != c.Session.CSRFToken { -		return ErrInvalidCSRFToken +	token, ok := ctx.Value("csrf_token").(string) +	if !ok || token != c.Session.CSRFToken { +		return errInvalidCSRFToken  	}  	return nil  } -func (s *authService) GetAuthUrl(ctx context.Context, instance string) ( -	redirectUrl string, sessionID string, err error) { -	return s.Service.GetAuthUrl(ctx, instance) +func (s *as) ServeErrorPage(ctx context.Context, c *model.Client, err error) { +	s.authenticateClient(ctx, c) +	s.Service.ServeErrorPage(ctx, c, err)  } -func (s *authService) GetUserToken(ctx context.Context, sessionID string, c *model.Client, -	code string) (token string, err error) { -	c, err = s.getClient(ctx) +func (s *as) ServeSigninPage(ctx context.Context, c *model.Client) (err error) { +	return s.Service.ServeSigninPage(ctx, c) +} + +func (s *as) ServeTimelinePage(ctx context.Context, c *model.Client, tType string, +	maxID string, minID string) (err error) { +	err = s.authenticateClient(ctx, c)  	if err != nil {  		return  	} +	return s.Service.ServeTimelinePage(ctx, c, tType, maxID, minID) +} -	token, err = s.Service.GetUserToken(ctx, c.Session.ID, c, code) +func (s *as) ServeThreadPage(ctx context.Context, c *model.Client, id string, reply bool) (err error) { +	err = s.authenticateClient(ctx, c)  	if err != nil {  		return  	} +	return s.Service.ServeThreadPage(ctx, c, id, reply) +} -	c.Session.AccessToken = token -	err = s.sessionRepo.Add(c.Session) +func (s *as) ServeLikedByPage(ctx context.Context, c *model.Client, id string) (err error) { +	err = s.authenticateClient(ctx, c)  	if err != nil {  		return  	} - -	return -} - -func (s *authService) ServeErrorPage(ctx context.Context, client io.Writer, c *model.Client, err error) { -	c, _ = s.getClient(ctx) -	s.Service.ServeErrorPage(ctx, client, c, err) -} - -func (s *authService) ServeSigninPage(ctx context.Context, client io.Writer) (err error) { -	return s.Service.ServeSigninPage(ctx, client) +	return s.Service.ServeLikedByPage(ctx, c, id)  } -func (s *authService) ServeTimelinePage(ctx context.Context, client io.Writer, -	c *model.Client, timelineType string, maxID string, sinceID string, minID string) (err error) { -	c, err = s.getClient(ctx) +func (s *as) ServeRetweetedByPage(ctx context.Context, c *model.Client, id string) (err error) { +	err = s.authenticateClient(ctx, c)  	if err != nil {  		return  	} -	return s.Service.ServeTimelinePage(ctx, client, c, timelineType, maxID, sinceID, minID) +	return s.Service.ServeRetweetedByPage(ctx, c, id)  } -func (s *authService) ServeThreadPage(ctx context.Context, client io.Writer, c *model.Client, id string, reply bool) (err error) { -	c, err = s.getClient(ctx) +func (s *as) ServeFollowingPage(ctx context.Context, c *model.Client, id string, +	maxID string, minID string) (err error) { +	err = s.authenticateClient(ctx, c)  	if err != nil {  		return  	} -	return s.Service.ServeThreadPage(ctx, client, c, id, reply) +	return s.Service.ServeFollowingPage(ctx, c, id, maxID, minID)  } -func (s *authService) ServeNotificationPage(ctx context.Context, client io.Writer, c *model.Client, maxID string, minID string) (err error) { -	c, err = s.getClient(ctx) +func (s *as) ServeFollowersPage(ctx context.Context, c *model.Client, id string, +	maxID string, minID string) (err error) { +	err = s.authenticateClient(ctx, c)  	if err != nil {  		return  	} -	return s.Service.ServeNotificationPage(ctx, client, c, maxID, minID) +	return s.Service.ServeFollowersPage(ctx, c, id, maxID, minID)  } -func (s *authService) ServeUserPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) { -	c, err = s.getClient(ctx) +func (s *as) ServeNotificationPage(ctx context.Context, c *model.Client, +	maxID string, minID string) (err error) { +	err = s.authenticateClient(ctx, c)  	if err != nil {  		return  	} -	return s.Service.ServeUserPage(ctx, client, c, id, maxID, minID) +	return s.Service.ServeNotificationPage(ctx, c, maxID, minID)  } -func (s *authService) ServeAboutPage(ctx context.Context, client io.Writer, c *model.Client) (err error) { -	c, err = s.getClient(ctx) +func (s *as) ServeUserPage(ctx context.Context, c *model.Client, id string, +	maxID string, minID string) (err error) { +	err = s.authenticateClient(ctx, c)  	if err != nil {  		return  	} -	return s.Service.ServeAboutPage(ctx, client, c) +	return s.Service.ServeUserPage(ctx, c, id, maxID, minID)  } -func (s *authService) ServeEmojiPage(ctx context.Context, client io.Writer, c *model.Client) (err error) { -	c, err = s.getClient(ctx) +func (s *as) ServeAboutPage(ctx context.Context, c *model.Client) (err error) { +	err = s.authenticateClient(ctx, c)  	if err != nil {  		return  	} -	return s.Service.ServeEmojiPage(ctx, client, c) +	return s.Service.ServeAboutPage(ctx, c)  } -func (s *authService) ServeLikedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { -	c, err = s.getClient(ctx) +func (s *as) ServeEmojiPage(ctx context.Context, c *model.Client) (err error) { +	err = s.authenticateClient(ctx, c)  	if err != nil {  		return  	} -	return s.Service.ServeLikedByPage(ctx, client, c, id) +	return s.Service.ServeEmojiPage(ctx, c)  } -func (s *authService) ServeRetweetedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { -	c, err = s.getClient(ctx) +func (s *as) ServeSearchPage(ctx context.Context, c *model.Client, q string, +	qType string, offset int) (err error) { +	err = s.authenticateClient(ctx, c)  	if err != nil {  		return  	} -	return s.Service.ServeRetweetedByPage(ctx, client, c, id) +	return s.Service.ServeSearchPage(ctx, c, q, qType, offset)  } -func (s *authService) ServeFollowingPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) { -	c, err = s.getClient(ctx) +func (s *as) ServeSettingsPage(ctx context.Context, c *model.Client) (err error) { +	err = s.authenticateClient(ctx, c)  	if err != nil {  		return  	} -	return s.Service.ServeFollowingPage(ctx, client, c, id, maxID, minID) +	return s.Service.ServeSettingsPage(ctx, c)  } -func (s *authService) ServeFollowersPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) { -	c, err = s.getClient(ctx) +func (s *as) NewSession(ctx context.Context, instance string) (redirectUrl string, +	sessionID string, err error) { +	return s.Service.NewSession(ctx, instance) +} + +func (s *as) Signin(ctx context.Context, c *model.Client, sessionID string, +	code string) (token string, err error) { +	err = s.authenticateClient(ctx, c)  	if err != nil {  		return  	} -	return s.Service.ServeFollowersPage(ctx, client, c, id, maxID, minID) -} -func (s *authService) ServeSearchPage(ctx context.Context, client io.Writer, c *model.Client, q string, qType string, offset int) (err error) { -	c, err = s.getClient(ctx) +	token, err = s.Service.Signin(ctx, c, c.Session.ID, code)  	if err != nil {  		return  	} -	return s.Service.ServeSearchPage(ctx, client, c, q, qType, offset) -} -func (s *authService) ServeSettingsPage(ctx context.Context, client io.Writer, c *model.Client) (err error) { -	c, err = s.getClient(ctx) +	c.Session.AccessToken = token +	err = s.sessionRepo.Add(c.Session)  	if err != nil {  		return  	} -	return s.Service.ServeSettingsPage(ctx, client, c) + +	return  } -func (s *authService) SaveSettings(ctx context.Context, client io.Writer, c *model.Client, settings *model.Settings) (err error) { -	c, err = s.getClient(ctx) +func (s *as) Post(ctx context.Context, c *model.Client, content string, +	replyToID string, format string, visibility string, isNSFW bool, +	files []*multipart.FileHeader) (id string, err error) { +	err = s.authenticateClient(ctx, c)  	if err != nil {  		return  	} @@ -197,11 +207,11 @@ func (s *authService) SaveSettings(ctx context.Context, client io.Writer, c *mod  	if err != nil {  		return  	} -	return s.Service.SaveSettings(ctx, client, c, settings) +	return s.Service.Post(ctx, c, content, replyToID, format, visibility, isNSFW, files)  } -func (s *authService) Like(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { -	c, err = s.getClient(ctx) +func (s *as) Like(ctx context.Context, c *model.Client, id string) (count int64, err error) { +	err = s.authenticateClient(ctx, c)  	if err != nil {  		return  	} @@ -209,11 +219,11 @@ func (s *authService) Like(ctx context.Context, client io.Writer, c *model.Clien  	if err != nil {  		return  	} -	return s.Service.Like(ctx, client, c, id) +	return s.Service.Like(ctx, c, id)  } -func (s *authService) UnLike(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { -	c, err = s.getClient(ctx) +func (s *as) UnLike(ctx context.Context, c *model.Client, id string) (count int64, err error) { +	err = s.authenticateClient(ctx, c)  	if err != nil {  		return  	} @@ -221,11 +231,11 @@ func (s *authService) UnLike(ctx context.Context, client io.Writer, c *model.Cli  	if err != nil {  		return  	} -	return s.Service.UnLike(ctx, client, c, id) +	return s.Service.UnLike(ctx, c, id)  } -func (s *authService) Retweet(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { -	c, err = s.getClient(ctx) +func (s *as) Retweet(ctx context.Context, c *model.Client, id string) (count int64, err error) { +	err = s.authenticateClient(ctx, c)  	if err != nil {  		return  	} @@ -233,11 +243,11 @@ func (s *authService) Retweet(ctx context.Context, client io.Writer, c *model.Cl  	if err != nil {  		return  	} -	return s.Service.Retweet(ctx, client, c, id) +	return s.Service.Retweet(ctx, c, id)  } -func (s *authService) UnRetweet(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { -	c, err = s.getClient(ctx) +func (s *as) UnRetweet(ctx context.Context, c *model.Client, id string) (count int64, err error) { +	err = s.authenticateClient(ctx, c)  	if err != nil {  		return  	} @@ -245,11 +255,11 @@ func (s *authService) UnRetweet(ctx context.Context, client io.Writer, c *model.  	if err != nil {  		return  	} -	return s.Service.UnRetweet(ctx, client, c, id) +	return s.Service.UnRetweet(ctx, c, id)  } -func (s *authService) PostTweet(ctx context.Context, client io.Writer, c *model.Client, content string, replyToID string, format string, visibility string, isNSFW bool, files []*multipart.FileHeader) (id string, err error) { -	c, err = s.getClient(ctx) +func (s *as) Follow(ctx context.Context, c *model.Client, id string) (err error) { +	err = s.authenticateClient(ctx, c)  	if err != nil {  		return  	} @@ -257,11 +267,11 @@ func (s *authService) PostTweet(ctx context.Context, client io.Writer, c *model.  	if err != nil {  		return  	} -	return s.Service.PostTweet(ctx, client, c, content, replyToID, format, visibility, isNSFW, files) +	return s.Service.Follow(ctx, c, id)  } -func (s *authService) Follow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { -	c, err = s.getClient(ctx) +func (s *as) UnFollow(ctx context.Context, c *model.Client, id string) (err error) { +	err = s.authenticateClient(ctx, c)  	if err != nil {  		return  	} @@ -269,11 +279,11 @@ func (s *authService) Follow(ctx context.Context, client io.Writer, c *model.Cli  	if err != nil {  		return  	} -	return s.Service.Follow(ctx, client, c, id) +	return s.Service.UnFollow(ctx, c, id)  } -func (s *authService) UnFollow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { -	c, err = s.getClient(ctx) +func (s *as) SaveSettings(ctx context.Context, c *model.Client, settings *model.Settings) (err error) { +	err = s.authenticateClient(ctx, c)  	if err != nil {  		return  	} @@ -281,5 +291,5 @@ func (s *authService) UnFollow(ctx context.Context, client io.Writer, c *model.C  	if err != nil {  		return  	} -	return s.Service.UnFollow(ctx, client, c, id) +	return s.Service.SaveSettings(ctx, c, settings)  } diff --git a/service/logging.go b/service/logging.go index cafd815..e4f8985 100644 --- a/service/logging.go +++ b/service/logging.go @@ -2,7 +2,6 @@ package service  import (  	"context" -	"io"  	"log"  	"mime/multipart"  	"time" @@ -10,206 +9,215 @@ import (  	"bloat/model"  ) -type loggingService struct { +type ls struct {  	logger *log.Logger  	Service  }  func NewLoggingService(logger *log.Logger, s Service) Service { -	return &loggingService{logger, s} +	return &ls{logger, s}  } -func (s *loggingService) GetAuthUrl(ctx context.Context, instance string) ( -	redirectUrl string, sessionID string, err error) { +func (s *ls) ServeErrorPage(ctx context.Context, c *model.Client, err error) {  	defer func(begin time.Time) { -		s.logger.Printf("method=%v, instance=%v, took=%v, err=%v\n", -			"GetAuthUrl", instance, time.Since(begin), err) +		s.logger.Printf("method=%v, err=%v, took=%v\n", +			"ServeErrorPage", err, time.Since(begin))  	}(time.Now()) -	return s.Service.GetAuthUrl(ctx, instance) +	s.Service.ServeErrorPage(ctx, c, err)  } -func (s *loggingService) GetUserToken(ctx context.Context, sessionID string, c *model.Client, -	code string) (token string, err error) { +func (s *ls) ServeSigninPage(ctx context.Context, c *model.Client) (err error) {  	defer func(begin time.Time) { -		s.logger.Printf("method=%v, session_id=%v, code=%v, took=%v, err=%v\n", -			"GetUserToken", sessionID, code, time.Since(begin), err) +		s.logger.Printf("method=%v, took=%v, err=%v\n", +			"ServeSigninPage", time.Since(begin), err)  	}(time.Now()) -	return s.Service.GetUserToken(ctx, sessionID, c, code) +	return s.Service.ServeSigninPage(ctx, c)  } -func (s *loggingService) ServeErrorPage(ctx context.Context, client io.Writer, c *model.Client, err error) { +func (s *ls) ServeTimelinePage(ctx context.Context, c *model.Client, tType string, +	maxID string, minID string) (err error) {  	defer func(begin time.Time) { -		s.logger.Printf("method=%v, err=%v, took=%v\n", -			"ServeErrorPage", err, time.Since(begin)) +		s.logger.Printf("method=%v, type=%v, took=%v, err=%v\n", +			"ServeTimelinePage", tType, time.Since(begin), err)  	}(time.Now()) -	s.Service.ServeErrorPage(ctx, client, c, err) +	return s.Service.ServeTimelinePage(ctx, c, tType, maxID, minID)  } -func (s *loggingService) ServeSigninPage(ctx context.Context, client io.Writer) (err error) { +func (s *ls) ServeThreadPage(ctx context.Context, c *model.Client, id string, +	reply bool) (err error) {  	defer func(begin time.Time) { -		s.logger.Printf("method=%v, took=%v, err=%v\n", -			"ServeSigninPage", time.Since(begin), err) +		s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", +			"ServeThreadPage", id, time.Since(begin), err)  	}(time.Now()) -	return s.Service.ServeSigninPage(ctx, client) +	return s.Service.ServeThreadPage(ctx, c, id, reply)  } -func (s *loggingService) ServeTimelinePage(ctx context.Context, client io.Writer, -	c *model.Client, timelineType string, maxID string, sinceID string, minID string) (err error) { +func (s *ls) ServeLikedByPage(ctx context.Context, c *model.Client, id string) (err error) {  	defer func(begin time.Time) { -		s.logger.Printf("method=%v, timeline_type=%v, max_id=%v, since_id=%v, min_id=%v, took=%v, err=%v\n", -			"ServeTimelinePage", timelineType, maxID, sinceID, minID, time.Since(begin), err) +		s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", +			"ServeLikedByPage", id, time.Since(begin), err)  	}(time.Now()) -	return s.Service.ServeTimelinePage(ctx, client, c, timelineType, maxID, sinceID, minID) +	return s.Service.ServeLikedByPage(ctx, c, id)  } -func (s *loggingService) ServeThreadPage(ctx context.Context, client io.Writer, c *model.Client, id string, reply bool) (err error) { +func (s *ls) ServeRetweetedByPage(ctx context.Context, c *model.Client, id string) (err error) {  	defer func(begin time.Time) { -		s.logger.Printf("method=%v, id=%v, reply=%v, took=%v, err=%v\n", -			"ServeThreadPage", id, reply, time.Since(begin), err) +		s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", +			"ServeRetweetedByPage", id, time.Since(begin), err)  	}(time.Now()) -	return s.Service.ServeThreadPage(ctx, client, c, id, reply) +	return s.Service.ServeRetweetedByPage(ctx, c, id)  } -func (s *loggingService) ServeNotificationPage(ctx context.Context, client io.Writer, c *model.Client, maxID string, minID string) (err error) { +func (s *ls) ServeFollowingPage(ctx context.Context, c *model.Client, id string, +	maxID string, minID string) (err error) {  	defer func(begin time.Time) { -		s.logger.Printf("method=%v, max_id=%v, min_id=%v, took=%v, err=%v\n", -			"ServeNotificationPage", maxID, minID, time.Since(begin), err) +		s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", +			"ServeFollowingPage", id, time.Since(begin), err)  	}(time.Now()) -	return s.Service.ServeNotificationPage(ctx, client, c, maxID, minID) +	return s.Service.ServeFollowingPage(ctx, c, id, maxID, minID)  } -func (s *loggingService) ServeUserPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) { +func (s *ls) ServeFollowersPage(ctx context.Context, c *model.Client, id string, +	maxID string, minID string) (err error) {  	defer func(begin time.Time) { -		s.logger.Printf("method=%v, id=%v, max_id=%v, min_id=%v, took=%v, err=%v\n", -			"ServeUserPage", id, maxID, minID, time.Since(begin), err) +		s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", +			"ServeFollowersPage", id, time.Since(begin), err)  	}(time.Now()) -	return s.Service.ServeUserPage(ctx, client, c, id, maxID, minID) +	return s.Service.ServeFollowersPage(ctx, c, id, maxID, minID)  } -func (s *loggingService) ServeAboutPage(ctx context.Context, client io.Writer, c *model.Client) (err error) { +func (s *ls) ServeNotificationPage(ctx context.Context, c *model.Client, +	maxID string, minID string) (err error) {  	defer func(begin time.Time) {  		s.logger.Printf("method=%v, took=%v, err=%v\n", -			"ServeAboutPage", time.Since(begin), err) +			"ServeNotificationPage", time.Since(begin), err)  	}(time.Now()) -	return s.Service.ServeAboutPage(ctx, client, c) +	return s.Service.ServeNotificationPage(ctx, c, maxID, minID)  } -func (s *loggingService) ServeEmojiPage(ctx context.Context, client io.Writer, c *model.Client) (err error) { +func (s *ls) ServeUserPage(ctx context.Context, c *model.Client, id string, +	maxID string, minID string) (err error) {  	defer func(begin time.Time) { -		s.logger.Printf("method=%v, took=%v, err=%v\n", -			"ServeEmojiPage", time.Since(begin), err) +		s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", +			"ServeUserPage", id, time.Since(begin), err)  	}(time.Now()) -	return s.Service.ServeEmojiPage(ctx, client, c) +	return s.Service.ServeUserPage(ctx, c, id, maxID, minID)  } -func (s *loggingService) ServeLikedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { +func (s *ls) ServeAboutPage(ctx context.Context, c *model.Client) (err error) {  	defer func(begin time.Time) { -		s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", -			"ServeLikedByPage", id, time.Since(begin), err) +		s.logger.Printf("method=%v, took=%v, err=%v\n", +			"ServeAboutPage", time.Since(begin), err)  	}(time.Now()) -	return s.Service.ServeLikedByPage(ctx, client, c, id) +	return s.Service.ServeAboutPage(ctx, c)  } -func (s *loggingService) ServeRetweetedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { +func (s *ls) ServeEmojiPage(ctx context.Context, c *model.Client) (err error) {  	defer func(begin time.Time) { -		s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", -			"ServeRetweetedByPage", id, time.Since(begin), err) +		s.logger.Printf("method=%v, took=%v, err=%v\n", +			"ServeEmojiPage", time.Since(begin), err)  	}(time.Now()) -	return s.Service.ServeRetweetedByPage(ctx, client, c, id) +	return s.Service.ServeEmojiPage(ctx, c)  } -func (s *loggingService) ServeFollowingPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) { +func (s *ls) ServeSearchPage(ctx context.Context, c *model.Client, q string, +	qType string, offset int) (err error) {  	defer func(begin time.Time) { -		s.logger.Printf("method=%v, id=%v, max_id=%v, min_id=%v, took=%v, err=%v\n", -			"ServeFollowingPage", id, maxID, minID, time.Since(begin), err) +		s.logger.Printf("method=%v, took=%v, err=%v\n", +			"ServeSearchPage", time.Since(begin), err)  	}(time.Now()) -	return s.Service.ServeFollowingPage(ctx, client, c, id, maxID, minID) +	return s.Service.ServeSearchPage(ctx, c, q, qType, offset)  } -func (s *loggingService) ServeFollowersPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) { +func (s *ls) ServeSettingsPage(ctx context.Context, c *model.Client) (err error) {  	defer func(begin time.Time) { -		s.logger.Printf("method=%v, id=%v, max_id=%v, min_id=%v, took=%v, err=%v\n", -			"ServeFollowersPage", id, maxID, minID, time.Since(begin), err) +		s.logger.Printf("method=%v, took=%v, err=%v\n", +			"ServeSettingsPage", time.Since(begin), err)  	}(time.Now()) -	return s.Service.ServeFollowersPage(ctx, client, c, id, maxID, minID) +	return s.Service.ServeSettingsPage(ctx, c)  } -func (s *loggingService) ServeSearchPage(ctx context.Context, client io.Writer, c *model.Client, q string, qType string, offset int) (err error) { +func (s *ls) NewSession(ctx context.Context, instance string) (redirectUrl string, +	sessionID string, err error) {  	defer func(begin time.Time) { -		s.logger.Printf("method=%v, q=%v, type=%v, offset=%v, took=%v, err=%v\n", -			"ServeSearchPage", q, qType, offset, time.Since(begin), err) +		s.logger.Printf("method=%v, instance=%v, took=%v, err=%v\n", +			"NewSession", instance, time.Since(begin), err)  	}(time.Now()) -	return s.Service.ServeSearchPage(ctx, client, c, q, qType, offset) +	return s.Service.NewSession(ctx, instance)  } -func (s *loggingService) ServeSettingsPage(ctx context.Context, client io.Writer, c *model.Client) (err error) { +func (s *ls) Signin(ctx context.Context, c *model.Client, sessionID string, +	code string) (token string, err error) {  	defer func(begin time.Time) { -		s.logger.Printf("method=%v, took=%v, err=%v\n", -			"ServeSettingsPage", time.Since(begin), err) +		s.logger.Printf("method=%v, session_id=%v, took=%v, err=%v\n", +			"Signin", sessionID, time.Since(begin), err)  	}(time.Now()) -	return s.Service.ServeSettingsPage(ctx, client, c) +	return s.Service.Signin(ctx, c, sessionID, code)  } -func (s *loggingService) SaveSettings(ctx context.Context, client io.Writer, c *model.Client, settings *model.Settings) (err error) { +func (s *ls) Post(ctx context.Context, c *model.Client, content string, +	replyToID string, format string, visibility string, isNSFW bool, +	files []*multipart.FileHeader) (id string, err error) {  	defer func(begin time.Time) {  		s.logger.Printf("method=%v, took=%v, err=%v\n", -			"SaveSettings", time.Since(begin), err) +			"Post", time.Since(begin), err)  	}(time.Now()) -	return s.Service.SaveSettings(ctx, client, c, settings) +	return s.Service.Post(ctx, c, content, replyToID, format, +		visibility, isNSFW, files)  } -func (s *loggingService) Like(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { +func (s *ls) Like(ctx context.Context, c *model.Client, id string) (count int64, err error) {  	defer func(begin time.Time) {  		s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n",  			"Like", id, time.Since(begin), err)  	}(time.Now()) -	return s.Service.Like(ctx, client, c, id) +	return s.Service.Like(ctx, c, id)  } -func (s *loggingService) UnLike(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { +func (s *ls) UnLike(ctx context.Context, c *model.Client, id string) (count int64, err error) {  	defer func(begin time.Time) {  		s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n",  			"UnLike", id, time.Since(begin), err)  	}(time.Now()) -	return s.Service.UnLike(ctx, client, c, id) +	return s.Service.UnLike(ctx, c, id)  } -func (s *loggingService) Retweet(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { +func (s *ls) Retweet(ctx context.Context, c *model.Client, id string) (count int64, err error) {  	defer func(begin time.Time) {  		s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n",  			"Retweet", id, time.Since(begin), err)  	}(time.Now()) -	return s.Service.Retweet(ctx, client, c, id) +	return s.Service.Retweet(ctx, c, id)  } -func (s *loggingService) UnRetweet(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { +func (s *ls) UnRetweet(ctx context.Context, c *model.Client, id string) (count int64, err error) {  	defer func(begin time.Time) {  		s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n",  			"UnRetweet", id, time.Since(begin), err)  	}(time.Now()) -	return s.Service.UnRetweet(ctx, client, c, id) +	return s.Service.UnRetweet(ctx, c, id)  } -func (s *loggingService) PostTweet(ctx context.Context, client io.Writer, c *model.Client, content string, replyToID string, format string, visibility string, isNSFW bool, files []*multipart.FileHeader) (id string, err error) { +func (s *ls) Follow(ctx context.Context, c *model.Client, id string) (err error) {  	defer func(begin time.Time) { -		s.logger.Printf("method=%v, content=%v, reply_to_id=%v, format=%v, visibility=%v, is_nsfw=%v, took=%v, err=%v\n", -			"PostTweet", content, replyToID, format, visibility, isNSFW, time.Since(begin), err) +		s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", +			"Follow", id, time.Since(begin), err)  	}(time.Now()) -	return s.Service.PostTweet(ctx, client, c, content, replyToID, format, visibility, isNSFW, files) +	return s.Service.Follow(ctx, c, id)  } -func (s *loggingService) Follow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { +func (s *ls) UnFollow(ctx context.Context, c *model.Client, id string) (err error) {  	defer func(begin time.Time) {  		s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", -			"Follow", id, time.Since(begin), err) +			"UnFollow", id, time.Since(begin), err)  	}(time.Now()) -	return s.Service.Follow(ctx, client, c, id) +	return s.Service.UnFollow(ctx, c, id)  } -func (s *loggingService) UnFollow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { +func (s *ls) SaveSettings(ctx context.Context, c *model.Client, settings *model.Settings) (err error) {  	defer func(begin time.Time) { -		s.logger.Printf("method=%v, id=%v, took=%v, err=%v\n", -			"UnFollow", id, time.Since(begin), err) +		s.logger.Printf("method=%v, took=%v, err=%v\n", +			"SaveSettings", time.Since(begin), err)  	}(time.Now()) -	return s.Service.UnFollow(ctx, client, c, id) +	return s.Service.SaveSettings(ctx, c, settings)  } diff --git a/service/service.go b/service/service.go index c9fccb4..7ad860f 100644 --- a/service/service.go +++ b/service/service.go @@ -1,14 +1,10 @@  package service  import ( -	"bytes"  	"context" -	"encoding/json"  	"errors"  	"fmt" -	"io"  	"mime/multipart" -	"net/http"  	"net/url"  	"strings" @@ -19,37 +15,35 @@ import (  )  var ( -	ErrInvalidArgument = errors.New("invalid argument") -	ErrInvalidToken    = errors.New("invalid token") -	ErrInvalidClient   = errors.New("invalid client") -	ErrInvalidTimeline = errors.New("invalid timeline") +	errInvalidArgument = errors.New("invalid argument")  )  type Service interface { -	GetAuthUrl(ctx context.Context, instance string) (url string, sessionID string, err error) -	GetUserToken(ctx context.Context, sessionID string, c *model.Client, token string) (accessToken string, err error) -	ServeErrorPage(ctx context.Context, client io.Writer, c *model.Client, err error) -	ServeSigninPage(ctx context.Context, client io.Writer) (err error) -	ServeTimelinePage(ctx context.Context, client io.Writer, c *model.Client, timelineType string, maxID string, sinceID string, minID string) (err error) -	ServeThreadPage(ctx context.Context, client io.Writer, c *model.Client, id string, reply bool) (err error) -	ServeNotificationPage(ctx context.Context, client io.Writer, c *model.Client, maxID string, minID string) (err error) -	ServeUserPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) -	ServeAboutPage(ctx context.Context, client io.Writer, c *model.Client) (err error) -	ServeEmojiPage(ctx context.Context, client io.Writer, c *model.Client) (err error) -	ServeLikedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) -	ServeRetweetedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) -	ServeFollowingPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) -	ServeFollowersPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) -	ServeSearchPage(ctx context.Context, client io.Writer, c *model.Client, q string, qType string, offset int) (err error) -	ServeSettingsPage(ctx context.Context, client io.Writer, c *model.Client) (err error) -	SaveSettings(ctx context.Context, client io.Writer, c *model.Client, settings *model.Settings) (err error) -	Like(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) -	UnLike(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) -	Retweet(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) -	UnRetweet(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) -	PostTweet(ctx context.Context, client io.Writer, c *model.Client, content string, replyToID string, format string, visibility string, isNSFW bool, files []*multipart.FileHeader) (id string, err error) -	Follow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) -	UnFollow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) +	ServeErrorPage(ctx context.Context, c *model.Client, err error) +	ServeSigninPage(ctx context.Context, c *model.Client) (err error) +	ServeTimelinePage(ctx context.Context, c *model.Client, tType string, maxID string, minID string) (err error) +	ServeThreadPage(ctx context.Context, c *model.Client, id string, reply bool) (err error) +	ServeLikedByPage(ctx context.Context, c *model.Client, id string) (err error) +	ServeRetweetedByPage(ctx context.Context, c *model.Client, id string) (err error) +	ServeFollowingPage(ctx context.Context, c *model.Client, id string, maxID string, minID string) (err error) +	ServeFollowersPage(ctx context.Context, c *model.Client, id string, maxID string, minID string) (err error) +	ServeNotificationPage(ctx context.Context, c *model.Client, maxID string, minID string) (err error) +	ServeUserPage(ctx context.Context, c *model.Client, id string, maxID string, minID string) (err error) +	ServeAboutPage(ctx context.Context, c *model.Client) (err error) +	ServeEmojiPage(ctx context.Context, c *model.Client) (err error) +	ServeSearchPage(ctx context.Context, c *model.Client, q string, qType string, offset int) (err error) +	ServeSettingsPage(ctx context.Context, c *model.Client) (err error) +	NewSession(ctx context.Context, instance string) (redirectUrl string, sessionID string, err error) +	Signin(ctx context.Context, c *model.Client, sessionID string, code string) (token string, err error) +	Post(ctx context.Context, c *model.Client, content string, replyToID string, format string, +		visibility string, isNSFW bool, files []*multipart.FileHeader) (id string, err error) +	Like(ctx context.Context, c *model.Client, id string) (count int64, err error) +	UnLike(ctx context.Context, c *model.Client, id string) (count int64, err error) +	Retweet(ctx context.Context, c *model.Client, id string) (count int64, err error) +	UnRetweet(ctx context.Context, c *model.Client, id string) (count int64, err error) +	Follow(ctx context.Context, c *model.Client, id string) (err error) +	UnFollow(ctx context.Context, c *model.Client, id string) (err error) +	SaveSettings(ctx context.Context, c *model.Client, settings *model.Settings) (err error)  }  type service struct { @@ -59,13 +53,19 @@ type service struct {  	customCSS     string  	postFormats   []model.PostFormat  	renderer      renderer.Renderer -	sessionRepo   model.SessionRepository -	appRepo       model.AppRepository +	sessionRepo   model.SessionRepo +	appRepo       model.AppRepo  } -func NewService(clientName string, clientScope string, clientWebsite string, -	customCSS string, postFormats []model.PostFormat, renderer renderer.Renderer, -	sessionRepo model.SessionRepository, appRepo model.AppRepository) Service { +func NewService(clientName string, +	clientScope string, +	clientWebsite string, +	customCSS string, +	postFormats []model.PostFormat, +	renderer renderer.Renderer, +	sessionRepo model.SessionRepo, +	appRepo model.AppRepo, +) Service {  	return &service{  		clientName:    clientName,  		clientScope:   clientScope, @@ -96,137 +96,75 @@ func getRendererContext(c *model.Client) *renderer.Context {  	}  } -func (svc *service) GetAuthUrl(ctx context.Context, instance string) ( -	redirectUrl string, sessionID string, err error) { -	var instanceURL string -	if strings.HasPrefix(instance, "https://") { -		instanceURL = instance -		instance = strings.TrimPrefix(instance, "https://") -	} else { -		instanceURL = "https://" + instance -	} - -	sessionID, err = util.NewSessionId() -	if err != nil { -		return -	} -	csrfToken, err := util.NewCSRFToken() -	if err != nil { -		return -	} -	session := model.Session{ -		ID:             sessionID, -		InstanceDomain: instance, -		CSRFToken:      csrfToken, -		Settings:       *model.NewSettings(), -	} -	err = svc.sessionRepo.Add(session) -	if err != nil { +func addToReplyMap(m map[string][]mastodon.ReplyInfo, key interface{}, +	val string, number int) { +	if key == nil {  		return  	} -	app, err := svc.appRepo.Get(instance) -	if err != nil { -		if err != model.ErrAppNotFound { -			return -		} - -		var mastoApp *mastodon.Application -		mastoApp, err = mastodon.RegisterApp(ctx, &mastodon.AppConfig{ -			Server:       instanceURL, -			ClientName:   svc.clientName, -			Scopes:       svc.clientScope, -			Website:      svc.clientWebsite, -			RedirectURIs: svc.clientWebsite + "/oauth_callback", -		}) -		if err != nil { -			return -		} - -		app = model.App{ -			InstanceDomain: instance, -			InstanceURL:    instanceURL, -			ClientID:       mastoApp.ClientID, -			ClientSecret:   mastoApp.ClientSecret, -		} - -		err = svc.appRepo.Add(app) -		if err != nil { -			return -		} -	} - -	u, err := url.Parse("/oauth/authorize") -	if err != nil { +	keyStr, ok := key.(string) +	if !ok {  		return  	} -	q := make(url.Values) -	q.Set("scope", "read write follow") -	q.Set("client_id", app.ClientID) -	q.Set("response_type", "code") -	q.Set("redirect_uri", svc.clientWebsite+"/oauth_callback") -	u.RawQuery = q.Encode() - -	redirectUrl = instanceURL + u.String() +	_, ok = m[keyStr] +	if !ok { +		m[keyStr] = []mastodon.ReplyInfo{} +	} -	return +	m[keyStr] = append(m[keyStr], mastodon.ReplyInfo{val, number})  } -func (svc *service) GetUserToken(ctx context.Context, sessionID string, c *model.Client, -	code string) (token string, err error) { -	if len(code) < 1 { -		err = ErrInvalidArgument -		return +func (svc *service) getCommonData(ctx context.Context, c *model.Client, +	title string) (data *renderer.CommonData, err error) { + +	data = new(renderer.CommonData) +	data.HeaderData = &renderer.HeaderData{ +		Title:             title + " - " + svc.clientName, +		NotificationCount: 0, +		CustomCSS:         svc.customCSS,  	} -	session, err := svc.sessionRepo.Get(sessionID) -	if err != nil { +	if c == nil || !c.Session.IsLoggedIn() {  		return  	} -	app, err := svc.appRepo.Get(session.InstanceDomain) +	notifications, err := c.GetNotifications(ctx, nil)  	if err != nil { -		return +		return nil, err  	} -	data := &bytes.Buffer{} -	err = json.NewEncoder(data).Encode(map[string]string{ -		"client_id":     app.ClientID, -		"client_secret": app.ClientSecret, -		"grant_type":    "authorization_code", -		"code":          code, -		"redirect_uri":  svc.clientWebsite + "/oauth_callback", -	}) -	if err != nil { -		return +	var notificationCount int +	for i := range notifications { +		if notifications[i].Pleroma != nil && +			!notifications[i].Pleroma.IsSeen { +			notificationCount++ +		}  	} -	resp, err := http.Post(app.InstanceURL+"/oauth/token", "application/json", data) +	u, err := c.GetAccountCurrentUser(ctx)  	if err != nil { -		return +		return nil, err  	} -	defer resp.Body.Close() -	var res struct { -		AccessToken string `json:"access_token"` +	data.NavbarData = &renderer.NavbarData{ +		User:              u, +		NotificationCount: notificationCount,  	} -	err = json.NewDecoder(resp.Body).Decode(&res) -	if err != nil { -		return -	} +	data.HeaderData.NotificationCount = notificationCount +	data.HeaderData.CSRFToken = c.Session.CSRFToken -	return res.AccessToken, nil +	return  } -func (svc *service) ServeErrorPage(ctx context.Context, client io.Writer, c *model.Client, err error) { +func (svc *service) ServeErrorPage(ctx context.Context, c *model.Client, err error) {  	var errStr string  	if err != nil {  		errStr = err.Error()  	} -	commonData, err := svc.getCommonData(ctx, client, nil, "error") +	commonData, err := svc.getCommonData(ctx, nil, "error")  	if err != nil {  		return  	} @@ -237,12 +175,13 @@ func (svc *service) ServeErrorPage(ctx context.Context, client io.Writer, c *mod  	}  	rCtx := getRendererContext(c) - -	svc.renderer.RenderErrorPage(rCtx, client, data) +	svc.renderer.RenderErrorPage(rCtx, c.Writer, data)  } -func (svc *service) ServeSigninPage(ctx context.Context, client io.Writer) (err error) { -	commonData, err := svc.getCommonData(ctx, client, nil, "signin") +func (svc *service) ServeSigninPage(ctx context.Context, c *model.Client) ( +	err error) { + +	commonData, err := svc.getCommonData(ctx, nil, "signin")  	if err != nil {  		return  	} @@ -252,26 +191,23 @@ func (svc *service) ServeSigninPage(ctx context.Context, client io.Writer) (err  	}  	rCtx := getRendererContext(nil) -	return svc.renderer.RenderSigninPage(rCtx, client, data) +	return svc.renderer.RenderSigninPage(rCtx, c.Writer, data)  } -func (svc *service) ServeTimelinePage(ctx context.Context, client io.Writer, -	c *model.Client, timelineType string, maxID string, sinceID string, minID string) (err error) { - -	var hasNext, hasPrev bool -	var nextLink, prevLink string +func (svc *service) ServeTimelinePage(ctx context.Context, c *model.Client, +	tType string, maxID string, minID string) (err error) { +	var nextLink, prevLink, title string +	var statuses []*mastodon.Status  	var pg = mastodon.Pagination{  		MaxID: maxID,  		MinID: minID,  		Limit: 20,  	} -	var statuses []*mastodon.Status -	var title string -	switch timelineType { +	switch tType {  	default: -		return ErrInvalidTimeline +		return errInvalidArgument  	case "home":  		statuses, err = c.GetTimelineHome(ctx, &pg)  		title = "Timeline" @@ -293,29 +229,31 @@ func (svc *service) ServeTimelinePage(ctx context.Context, client io.Writer,  	}  	if len(maxID) > 0 && len(statuses) > 0 { -		hasPrev = true -		prevLink = fmt.Sprintf("/timeline/$s?min_id=%s", timelineType, statuses[0].ID) +		prevLink = fmt.Sprintf("/timeline/%s?min_id=%s", tType, +			statuses[0].ID)  	} +  	if len(minID) > 0 && len(pg.MinID) > 0 { -		newStatuses, err := c.GetTimelineHome(ctx, &mastodon.Pagination{MinID: pg.MinID, Limit: 20}) +		newPg := &mastodon.Pagination{MinID: pg.MinID, Limit: 20} +		newStatuses, err := c.GetTimelineHome(ctx, newPg)  		if err != nil {  			return err  		} -		newStatusesLen := len(newStatuses) -		if newStatusesLen == 20 { -			hasPrev = true -			prevLink = fmt.Sprintf("/timeline/%s?min_id=%s", timelineType, pg.MinID) +		newLen := len(newStatuses) +		if newLen == 20 { +			prevLink = fmt.Sprintf("/timeline/%s?min_id=%s", +				tType, pg.MinID)  		} else { -			i := 20 - newStatusesLen - 1 +			i := 20 - newLen - 1  			if len(statuses) > i { -				hasPrev = true -				prevLink = fmt.Sprintf("/timeline/%s?min_id=%s", timelineType, statuses[i].ID) +				prevLink = fmt.Sprintf("/timeline/%s?min_id=%s", +					tType, statuses[i].ID)  			}  		}  	} +  	if len(pg.MaxID) > 0 { -		hasNext = true -		nextLink = fmt.Sprintf("/timeline/%s?max_id=%s", timelineType, pg.MaxID) +		nextLink = fmt.Sprintf("/timeline/%s?max_id=%s", tType, pg.MaxID)  	}  	postContext := model.PostContext{ @@ -323,7 +261,7 @@ func (svc *service) ServeTimelinePage(ctx context.Context, client io.Writer,  		Formats:           svc.postFormats,  	} -	commonData, err := svc.getCommonData(ctx, client, c, timelineType+" timeline ") +	commonData, err := svc.getCommonData(ctx, c, tType+" timeline ")  	if err != nil {  		return  	} @@ -331,24 +269,21 @@ func (svc *service) ServeTimelinePage(ctx context.Context, client io.Writer,  	data := &renderer.TimelineData{  		Title:       title,  		Statuses:    statuses, -		HasNext:     hasNext,  		NextLink:    nextLink, -		HasPrev:     hasPrev,  		PrevLink:    prevLink,  		PostContext: postContext,  		CommonData:  commonData,  	} +  	rCtx := getRendererContext(c) +	return svc.renderer.RenderTimelinePage(rCtx, c.Writer, data) +} -	err = svc.renderer.RenderTimelinePage(rCtx, client, data) -	if err != nil { -		return -	} +func (svc *service) ServeThreadPage(ctx context.Context, c *model.Client, +	id string, reply bool) (err error) { -	return -} +	var postContext model.PostContext -func (svc *service) ServeThreadPage(ctx context.Context, client io.Writer, c *model.Client, id string, reply bool) (err error) {  	status, err := c.GetStatus(ctx, id)  	if err != nil {  		return @@ -359,19 +294,19 @@ func (svc *service) ServeThreadPage(ctx context.Context, client io.Writer, c *mo  		return  	} -	var postContext model.PostContext  	if reply {  		var content string +		var visibility string  		if u.ID != status.Account.ID {  			content += "@" + status.Account.Acct + " "  		}  		for i := range status.Mentions { -			if status.Mentions[i].ID != u.ID && status.Mentions[i].ID != status.Account.ID { +			if status.Mentions[i].ID != u.ID && +				status.Mentions[i].ID != status.Account.ID {  				content += "@" + status.Mentions[i].Acct + " "  			}  		} -		var visibility string  		if c.Session.Settings.CopyScope {  			s, err := c.GetStatus(ctx, id)  			if err != nil { @@ -400,16 +335,15 @@ func (svc *service) ServeThreadPage(ctx context.Context, client io.Writer, c *mo  	}  	statuses := append(append(context.Ancestors, status), context.Descendants...) - -	replyMap := make(map[string][]mastodon.ReplyInfo) +	replies := make(map[string][]mastodon.ReplyInfo)  	for i := range statuses {  		statuses[i].ShowReplies = true -		statuses[i].ReplyMap = replyMap -		addToReplyMap(replyMap, statuses[i].InReplyToID, statuses[i].ID, i+1) +		statuses[i].ReplyMap = replies +		addToReplyMap(replies, statuses[i].InReplyToID, statuses[i].ID, i+1)  	} -	commonData, err := svc.getCommonData(ctx, client, c, "post by "+status.Account.DisplayName) +	commonData, err := svc.getCommonData(ctx, c, "post by "+status.Account.DisplayName)  	if err != nil {  		return  	} @@ -417,224 +351,182 @@ func (svc *service) ServeThreadPage(ctx context.Context, client io.Writer, c *mo  	data := &renderer.ThreadData{  		Statuses:    statuses,  		PostContext: postContext, -		ReplyMap:    replyMap, +		ReplyMap:    replies,  		CommonData:  commonData,  	} -	rCtx := getRendererContext(c) -	err = svc.renderer.RenderThreadPage(rCtx, client, data) -	if err != nil { -		return -	} - -	return +	rCtx := getRendererContext(c) +	return svc.renderer.RenderThreadPage(rCtx, c.Writer, data)  } -func (svc *service) ServeNotificationPage(ctx context.Context, client io.Writer, c *model.Client, maxID string, minID string) (err error) { -	var hasNext bool -	var nextLink string +func (svc *service) ServeLikedByPage(ctx context.Context, c *model.Client, +	id string) (err error) { -	var pg = mastodon.Pagination{ -		MaxID: maxID, -		MinID: minID, -		Limit: 20, +	likers, err := c.GetFavouritedBy(ctx, id, nil) +	if err != nil { +		return  	} -	notifications, err := c.GetNotifications(ctx, &pg) +	commonData, err := svc.getCommonData(ctx, c, "likes")  	if err != nil {  		return  	} -	var unreadCount int -	for i := range notifications { -		if notifications[i].Status != nil { -			notifications[i].Status.CreatedAt = notifications[i].CreatedAt -			switch notifications[i].Type { -			case "reblog", "favourite": -				notifications[i].Status.HideAccountInfo = true -			} -		} -		if notifications[i].Pleroma != nil && !notifications[i].Pleroma.IsSeen { -			unreadCount++ -		} +	data := &renderer.LikedByData{ +		CommonData: commonData, +		Users:      likers,  	} -	if unreadCount > 0 { -		err := c.ReadNotifications(ctx, notifications[0].ID) -		if err != nil { -			return err -		} -	} +	rCtx := getRendererContext(c) +	return svc.renderer.RenderLikedByPage(rCtx, c.Writer, data) +} -	if len(pg.MaxID) > 0 { -		hasNext = true -		nextLink = "/notifications?max_id=" + pg.MaxID -	} +func (svc *service) ServeRetweetedByPage(ctx context.Context, c *model.Client, +	id string) (err error) { -	commonData, err := svc.getCommonData(ctx, client, c, "notifications") +	retweeters, err := c.GetRebloggedBy(ctx, id, nil)  	if err != nil {  		return  	} -	data := &renderer.NotificationData{ -		Notifications: notifications, -		HasNext:       hasNext, -		NextLink:      nextLink, -		CommonData:    commonData, -	} -	rCtx := getRendererContext(c) - -	err = svc.renderer.RenderNotificationPage(rCtx, client, data) +	commonData, err := svc.getCommonData(ctx, c, "retweets")  	if err != nil {  		return  	} -	return +	data := &renderer.RetweetedByData{ +		CommonData: commonData, +		Users:      retweeters, +	} + +	rCtx := getRendererContext(c) +	return svc.renderer.RenderRetweetedByPage(rCtx, c.Writer, data)  } -func (svc *service) ServeUserPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) { -	user, err := c.GetAccount(ctx, id) -	if err != nil { -		return -	} +func (svc *service) ServeFollowingPage(ctx context.Context, c *model.Client, +	id string, maxID string, minID string) (err error) { -	var hasNext bool  	var nextLink string -  	var pg = mastodon.Pagination{  		MaxID: maxID,  		MinID: minID,  		Limit: 20,  	} -	statuses, err := c.GetAccountStatuses(ctx, id, &pg) +	followings, err := c.GetAccountFollowing(ctx, id, &pg)  	if err != nil {  		return  	} -	if len(pg.MaxID) > 0 { -		hasNext = true -		nextLink = "/user/" + id + "?max_id=" + pg.MaxID +	if len(followings) == 20 && len(pg.MaxID) > 0 { +		nextLink = "/following/" + id + "?max_id=" + pg.MaxID  	} -	commonData, err := svc.getCommonData(ctx, client, c, user.DisplayName) +	commonData, err := svc.getCommonData(ctx, c, "following")  	if err != nil {  		return  	} -	data := &renderer.UserData{ -		User:       user, -		Statuses:   statuses, -		HasNext:    hasNext, -		NextLink:   nextLink, +	data := &renderer.FollowingData{  		CommonData: commonData, -	} -	rCtx := getRendererContext(c) - -	err = svc.renderer.RenderUserPage(rCtx, client, data) -	if err != nil { -		return +		Users:      followings, +		NextLink:   nextLink,  	} -	return +	rCtx := getRendererContext(c) +	return svc.renderer.RenderFollowingPage(rCtx, c.Writer, data)  } -func (svc *service) ServeAboutPage(ctx context.Context, client io.Writer, c *model.Client) (err error) { -	commonData, err := svc.getCommonData(ctx, client, c, "about") -	if err != nil { -		return -	} +func (svc *service) ServeFollowersPage(ctx context.Context, c *model.Client, +	id string, maxID string, minID string) (err error) { -	data := &renderer.AboutData{ -		CommonData: commonData, +	var nextLink string +	var pg = mastodon.Pagination{ +		MaxID: maxID, +		MinID: minID, +		Limit: 20,  	} -	rCtx := getRendererContext(c) -	err = svc.renderer.RenderAboutPage(rCtx, client, data) +	followers, err := c.GetAccountFollowers(ctx, id, &pg)  	if err != nil {  		return  	} -	return -} - -func (svc *service) ServeEmojiPage(ctx context.Context, client io.Writer, c *model.Client) (err error) { -	commonData, err := svc.getCommonData(ctx, client, c, "emojis") -	if err != nil { -		return +	if len(followers) == 20 && len(pg.MaxID) > 0 { +		nextLink = "/followers/" + id + "?max_id=" + pg.MaxID  	} -	emojis, err := c.GetInstanceEmojis(ctx) +	commonData, err := svc.getCommonData(ctx, c, "followers")  	if err != nil {  		return  	} -	data := &renderer.EmojiData{ -		Emojis:     emojis, +	data := &renderer.FollowersData{  		CommonData: commonData, +		Users:      followers, +		NextLink:   nextLink,  	}  	rCtx := getRendererContext(c) - -	err = svc.renderer.RenderEmojiPage(rCtx, client, data) -	if err != nil { -		return -	} - -	return +	return svc.renderer.RenderFollowersPage(rCtx, c.Writer, data)  } -func (svc *service) ServeLikedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { -	likers, err := c.GetFavouritedBy(ctx, id, nil) -	if err != nil { -		return +func (svc *service) ServeNotificationPage(ctx context.Context, c *model.Client, +	maxID string, minID string) (err error) { + +	var nextLink string +	var unreadCount int +	var pg = mastodon.Pagination{ +		MaxID: maxID, +		MinID: minID, +		Limit: 20,  	} -	commonData, err := svc.getCommonData(ctx, client, c, "likes") +	notifications, err := c.GetNotifications(ctx, &pg)  	if err != nil {  		return  	} -	data := &renderer.LikedByData{ -		CommonData: commonData, -		Users:      likers, +	for i := range notifications { +		if notifications[i].Status != nil { +			notifications[i].Status.CreatedAt = notifications[i].CreatedAt +			switch notifications[i].Type { +			case "reblog", "favourite": +				notifications[i].Status.HideAccountInfo = true +			} +		} +		if notifications[i].Pleroma != nil && !notifications[i].Pleroma.IsSeen { +			unreadCount++ +		}  	} -	rCtx := getRendererContext(c) -	err = svc.renderer.RenderLikedByPage(rCtx, client, data) -	if err != nil { -		return +	if unreadCount > 0 { +		err := c.ReadNotifications(ctx, notifications[0].ID) +		if err != nil { +			return err +		}  	} -	return -} - -func (svc *service) ServeRetweetedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { -	retweeters, err := c.GetRebloggedBy(ctx, id, nil) -	if err != nil { -		return +	if len(pg.MaxID) > 0 { +		nextLink = "/notifications?max_id=" + pg.MaxID  	} -	commonData, err := svc.getCommonData(ctx, client, c, "retweets") +	commonData, err := svc.getCommonData(ctx, c, "notifications")  	if err != nil {  		return  	} -	data := &renderer.RetweetedByData{ -		CommonData: commonData, -		Users:      retweeters, +	data := &renderer.NotificationData{ +		Notifications: notifications, +		NextLink:      nextLink, +		CommonData:    commonData,  	}  	rCtx := getRendererContext(c) - -	err = svc.renderer.RenderRetweetedByPage(rCtx, client, data) -	if err != nil { -		return -	} - -	return +	return svc.renderer.RenderNotificationPage(rCtx, c.Writer, data)  } -func (svc *service) ServeFollowingPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) { -	var hasNext bool +func (svc *service) ServeUserPage(ctx context.Context, c *model.Client, +	id string, maxID string, minID string) (err error) { +  	var nextLink string  	var pg = mastodon.Pagination{ @@ -643,104 +535,91 @@ func (svc *service) ServeFollowingPage(ctx context.Context, client io.Writer, c  		Limit: 20,  	} -	followings, err := c.GetAccountFollowing(ctx, id, &pg) +	user, err := c.GetAccount(ctx, id)  	if err != nil {  		return  	} -	if len(followings) == 20 && len(pg.MaxID) > 0 { -		hasNext = true -		nextLink = "/following/" + id + "?max_id=" + pg.MaxID +	statuses, err := c.GetAccountStatuses(ctx, id, &pg) +	if err != nil { +		return +	} + +	if len(pg.MaxID) > 0 { +		nextLink = "/user/" + id + "?max_id=" + pg.MaxID  	} -	commonData, err := svc.getCommonData(ctx, client, c, "following") +	commonData, err := svc.getCommonData(ctx, c, user.DisplayName)  	if err != nil {  		return  	} -	data := &renderer.FollowingData{ -		CommonData: commonData, -		Users:      followings, -		HasNext:    hasNext, +	data := &renderer.UserData{ +		User:       user, +		Statuses:   statuses,  		NextLink:   nextLink, +		CommonData: commonData,  	}  	rCtx := getRendererContext(c) +	return svc.renderer.RenderUserPage(rCtx, c.Writer, data) +} -	err = svc.renderer.RenderFollowingPage(rCtx, client, data) +func (svc *service) ServeAboutPage(ctx context.Context, c *model.Client) (err error) { +	commonData, err := svc.getCommonData(ctx, c, "about")  	if err != nil {  		return  	} -	return -} - -func (svc *service) ServeFollowersPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) { -	var hasNext bool -	var nextLink string - -	var pg = mastodon.Pagination{ -		MaxID: maxID, -		MinID: minID, -		Limit: 20, +	data := &renderer.AboutData{ +		CommonData: commonData,  	} -	followers, err := c.GetAccountFollowers(ctx, id, &pg) +	rCtx := getRendererContext(c) +	return svc.renderer.RenderAboutPage(rCtx, c.Writer, data) +} + +func (svc *service) ServeEmojiPage(ctx context.Context, c *model.Client) (err error) { +	commonData, err := svc.getCommonData(ctx, c, "emojis")  	if err != nil {  		return  	} -	if len(followers) == 20 && len(pg.MaxID) > 0 { -		hasNext = true -		nextLink = "/followers/" + id + "?max_id=" + pg.MaxID -	} - -	commonData, err := svc.getCommonData(ctx, client, c, "followers") +	emojis, err := c.GetInstanceEmojis(ctx)  	if err != nil {  		return  	} -	data := &renderer.FollowersData{ +	data := &renderer.EmojiData{ +		Emojis:     emojis,  		CommonData: commonData, -		Users:      followers, -		HasNext:    hasNext, -		NextLink:   nextLink,  	} -	rCtx := getRendererContext(c) -	err = svc.renderer.RenderFollowersPage(rCtx, client, data) -	if err != nil { -		return -	} - -	return +	rCtx := getRendererContext(c) +	return svc.renderer.RenderEmojiPage(rCtx, c.Writer, data)  } -func (svc *service) ServeSearchPage(ctx context.Context, client io.Writer, c *model.Client, q string, qType string, offset int) (err error) { -	var hasNext bool +func (svc *service) ServeSearchPage(ctx context.Context, c *model.Client, +	q string, qType string, offset int) (err error) { +  	var nextLink string +	var title = "search"  	results, err := c.Search(ctx, q, qType, 20, true, offset)  	if err != nil {  		return  	} -	switch qType { -	case "accounts": -		hasNext = len(results.Accounts) == 20 -	case "statuses": -		hasNext = len(results.Statuses) == 20 -	} - -	if hasNext { +	if (qType == "accounts" && len(results.Accounts) == 20) || +		(qType == "statuses" && len(results.Statuses) == 20) {  		offset += 20  		nextLink = fmt.Sprintf("/search?q=%s&type=%s&offset=%d", q, qType, offset)  	} -	var title = "search"  	if len(q) > 0 {  		title += " \"" + q + "\""  	} -	commonData, err := svc.getCommonData(ctx, client, c, title) + +	commonData, err := svc.getCommonData(ctx, c, title)  	if err != nil {  		return  	} @@ -751,21 +630,15 @@ func (svc *service) ServeSearchPage(ctx context.Context, client io.Writer, c *mo  		Type:       qType,  		Users:      results.Accounts,  		Statuses:   results.Statuses, -		HasNext:    hasNext,  		NextLink:   nextLink,  	} -	rCtx := getRendererContext(c) -	err = svc.renderer.RenderSearchPage(rCtx, client, data) -	if err != nil { -		return -	} - -	return +	rCtx := getRendererContext(c) +	return svc.renderer.RenderSearchPage(rCtx, c.Writer, data)  } -func (svc *service) ServeSettingsPage(ctx context.Context, client io.Writer, c *model.Client) (err error) { -	commonData, err := svc.getCommonData(ctx, client, c, "settings") +func (svc *service) ServeSettingsPage(ctx context.Context, c *model.Client) (err error) { +	commonData, err := svc.getCommonData(ctx, c, "settings")  	if err != nil {  		return  	} @@ -774,122 +647,125 @@ func (svc *service) ServeSettingsPage(ctx context.Context, client io.Writer, c *  		CommonData: commonData,  		Settings:   &c.Session.Settings,  	} +  	rCtx := getRendererContext(c) +	return svc.renderer.RenderSettingsPage(rCtx, c.Writer, data) +} -	err = svc.renderer.RenderSettingsPage(rCtx, client, data) +func (svc *service) NewSession(ctx context.Context, instance string) ( +	redirectUrl string, sessionID string, err error) { + +	var instanceURL string +	if strings.HasPrefix(instance, "https://") { +		instanceURL = instance +		instance = strings.TrimPrefix(instance, "https://") +	} else { +		instanceURL = "https://" + instance +	} + +	sessionID, err = util.NewSessionID()  	if err != nil {  		return  	} -	return -} - -func (svc *service) SaveSettings(ctx context.Context, client io.Writer, c *model.Client, settings *model.Settings) (err error) { -	session, err := svc.sessionRepo.Get(c.Session.ID) +	csrfToken, err := util.NewCSRFToken()  	if err != nil {  		return  	} -	session.Settings = *settings +	session := model.Session{ +		ID:             sessionID, +		InstanceDomain: instance, +		CSRFToken:      csrfToken, +		Settings:       *model.NewSettings(), +	} +  	err = svc.sessionRepo.Add(session)  	if err != nil {  		return  	} -	return -} - -func (svc *service) getCommonData(ctx context.Context, client io.Writer, c *model.Client, title string) (data *renderer.CommonData, err error) { -	data = new(renderer.CommonData) - -	data.HeaderData = &renderer.HeaderData{ -		Title:             title + " - " + svc.clientName, -		NotificationCount: 0, -		CustomCSS:         svc.customCSS, -	} +	app, err := svc.appRepo.Get(instance) +	if err != nil { +		if err != model.ErrAppNotFound { +			return +		} -	if c != nil && c.Session.IsLoggedIn() { -		notifications, err := c.GetNotifications(ctx, nil) +		mastoApp, err := mastodon.RegisterApp(ctx, &mastodon.AppConfig{ +			Server:       instanceURL, +			ClientName:   svc.clientName, +			Scopes:       svc.clientScope, +			Website:      svc.clientWebsite, +			RedirectURIs: svc.clientWebsite + "/oauth_callback", +		})  		if err != nil { -			return nil, err +			return "", "", err  		} -		var notificationCount int -		for i := range notifications { -			if notifications[i].Pleroma != nil && !notifications[i].Pleroma.IsSeen { -				notificationCount++ -			} +		app = model.App{ +			InstanceDomain: instance, +			InstanceURL:    instanceURL, +			ClientID:       mastoApp.ClientID, +			ClientSecret:   mastoApp.ClientSecret,  		} -		u, err := c.GetAccountCurrentUser(ctx) +		err = svc.appRepo.Add(app)  		if err != nil { -			return nil, err -		} - -		data.NavbarData = &renderer.NavbarData{ -			User:              u, -			NotificationCount: notificationCount, +			return "", "", err  		} - -		data.HeaderData.NotificationCount = notificationCount -		data.HeaderData.CSRFToken = c.Session.CSRFToken  	} -	return -} - -func (svc *service) Like(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { -	s, err := c.Favourite(ctx, id) +	u, err := url.Parse("/oauth/authorize")  	if err != nil {  		return  	} -	count = s.FavouritesCount -	return -} -func (svc *service) UnLike(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { -	s, err := c.Unfavourite(ctx, id) -	if err != nil { -		return -	} -	count = s.FavouritesCount +	q := make(url.Values) +	q.Set("scope", "read write follow") +	q.Set("client_id", app.ClientID) +	q.Set("response_type", "code") +	q.Set("redirect_uri", svc.clientWebsite+"/oauth_callback") +	u.RawQuery = q.Encode() + +	redirectUrl = instanceURL + u.String() +  	return  } -func (svc *service) Retweet(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { -	s, err := c.Reblog(ctx, id) -	if err != nil { +func (svc *service) Signin(ctx context.Context, c *model.Client, +	sessionID string, code string) (token string, err error) { + +	if len(code) < 1 { +		err = errInvalidArgument  		return  	} -	if s.Reblog != nil { -		count = s.Reblog.ReblogsCount -	} -	return -} -func (svc *service) UnRetweet(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) { -	s, err := c.Unreblog(ctx, id) +	err = c.AuthenticateToken(ctx, code, svc.clientWebsite+"/oauth_callback")  	if err != nil {  		return  	} -	count = s.ReblogsCount +	token = c.GetAccessToken(ctx) +  	return  } -func (svc *service) PostTweet(ctx context.Context, client io.Writer, c *model.Client, content string, replyToID string, format string, visibility string, isNSFW bool, files []*multipart.FileHeader) (id string, err error) { -	var mediaIds []string +func (svc *service) Post(ctx context.Context, c *model.Client, content string, +	replyToID string, format string, visibility string, isNSFW bool, +	files []*multipart.FileHeader) (id string, err error) { + +	var mediaIDs []string  	for _, f := range files {  		a, err := c.UploadMediaFromMultipartFileHeader(ctx, f)  		if err != nil {  			return "", err  		} -		mediaIds = append(mediaIds, a.ID) +		mediaIDs = append(mediaIDs, a.ID)  	}  	tweet := &mastodon.Toot{  		Status:      content,  		InReplyToID: replyToID, -		MediaIDs:    mediaIds, +		MediaIDs:    mediaIDs,  		ContentType: format,  		Visibility:  visibility,  		Sensitive:   isNSFW, @@ -903,29 +779,66 @@ func (svc *service) PostTweet(ctx context.Context, client io.Writer, c *model.Cl  	return s.ID, nil  } -func (svc *service) Follow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { -	_, err = c.AccountFollow(ctx, id) +func (svc *service) Like(ctx context.Context, c *model.Client, id string) ( +	count int64, err error) { +	s, err := c.Favourite(ctx, id) +	if err != nil { +		return +	} +	count = s.FavouritesCount  	return  } -func (svc *service) UnFollow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) { -	_, err = c.AccountUnfollow(ctx, id) +func (svc *service) UnLike(ctx context.Context, c *model.Client, id string) ( +	count int64, err error) { +	s, err := c.Unfavourite(ctx, id) +	if err != nil { +		return +	} +	count = s.FavouritesCount  	return  } -func addToReplyMap(m map[string][]mastodon.ReplyInfo, key interface{}, val string, number int) { -	if key == nil { +func (svc *service) Retweet(ctx context.Context, c *model.Client, id string) ( +	count int64, err error) { +	s, err := c.Reblog(ctx, id) +	if err != nil {  		return  	} +	if s.Reblog != nil { +		count = s.Reblog.ReblogsCount +	} +	return +} -	keyStr, ok := key.(string) -	if !ok { +func (svc *service) UnRetweet(ctx context.Context, c *model.Client, id string) ( +	count int64, err error) { +	s, err := c.Unreblog(ctx, id) +	if err != nil {  		return  	} -	_, ok = m[keyStr] -	if !ok { -		m[keyStr] = []mastodon.ReplyInfo{} +	count = s.ReblogsCount +	return +} + +func (svc *service) Follow(ctx context.Context, c *model.Client, id string) (err error) { +	_, err = c.AccountFollow(ctx, id) +	return +} + +func (svc *service) UnFollow(ctx context.Context, c *model.Client, id string) (err error) { +	_, err = c.AccountUnfollow(ctx, id) +	return +} + +func (svc *service) SaveSettings(ctx context.Context, c *model.Client, +	settings *model.Settings) (err error) { + +	session, err := svc.sessionRepo.Get(c.Session.ID) +	if err != nil { +		return  	} -	m[keyStr] = append(m[keyStr], mastodon.ReplyInfo{val, number}) +	session.Settings = *settings +	return svc.sessionRepo.Add(session)  } diff --git a/service/transport.go b/service/transport.go index e878f8d..fbab2e5 100644 --- a/service/transport.go +++ b/service/transport.go @@ -15,327 +15,292 @@ import (  	"github.com/gorilla/mux"  ) -var ( -	ctx       = context.Background() -	cookieAge = "31536000" -) +func newClient(w io.Writer) *model.Client { +	return &model.Client{ +		Writer: w, +	} +} + +func newCtxWithSesion(req *http.Request) context.Context { +	ctx := context.Background() +	sessionID, err := req.Cookie("session_id") +	if err != nil { +		return ctx +	} +	return context.WithValue(ctx, "session_id", sessionID.Value) +} + +func newCtxWithSesionCSRF(req *http.Request, csrfToken string) context.Context { +	ctx := newCtxWithSesion(req) +	return context.WithValue(ctx, "csrf_token", csrfToken) +} + +func getMultipartFormValue(mf *multipart.Form, key string) (val string) { +	vals, ok := mf.Value[key] +	if !ok { +		return "" +	} +	if len(vals) < 1 { +		return "" +	} +	return vals[0] +} + +func serveJson(w io.Writer, data interface{}) (err error) { +	var d = make(map[string]interface{}) +	d["data"] = data +	return json.NewEncoder(w).Encode(d) +}  func NewHandler(s Service, staticDir string) http.Handler {  	r := mux.NewRouter() -	r.PathPrefix("/static").Handler(http.StripPrefix("/static", -		http.FileServer(http.Dir(path.Join(".", staticDir))))) +	rootPage := func(w http.ResponseWriter, req *http.Request) { +		sessionID, _ := req.Cookie("session_id") -	r.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {  		location := "/signin" - -		sessionID, _ := req.Cookie("session_id")  		if sessionID != nil && len(sessionID.Value) > 0 {  			location = "/timeline/home"  		}  		w.Header().Add("Location", location)  		w.WriteHeader(http.StatusFound) -	}).Methods(http.MethodGet) - -	r.HandleFunc("/signin", func(w http.ResponseWriter, req *http.Request) { -		err := s.ServeSigninPage(ctx, w) -		if err != nil { -			s.ServeErrorPage(ctx, w, nil, err) -			return -		} -	}).Methods(http.MethodGet) +	} -	r.HandleFunc("/signin", func(w http.ResponseWriter, req *http.Request) { -		instance := req.FormValue("instance") -		url, sessionID, err := s.GetAuthUrl(ctx, instance) +	signinPage := func(w http.ResponseWriter, req *http.Request) { +		c := newClient(w) +		ctx := context.Background() +		err := s.ServeSigninPage(ctx, c)  		if err != nil { -			s.ServeErrorPage(ctx, w, nil, err) +			s.ServeErrorPage(ctx, c, err)  			return  		} +	} -		http.SetCookie(w, &http.Cookie{ -			Name:    "session_id", -			Value:   sessionID, -			Expires: time.Now().Add(365 * 24 * time.Hour), -		}) - -		w.Header().Add("Location", url) -		w.WriteHeader(http.StatusFound) -	}).Methods(http.MethodPost) +	timelinePage := func(w http.ResponseWriter, req *http.Request) { +		c := newClient(w) +		ctx := newCtxWithSesion(req) +		tType, _ := mux.Vars(req)["type"] +		maxID := req.URL.Query().Get("max_id") +		minID := req.URL.Query().Get("min_id") -	r.HandleFunc("/oauth_callback", func(w http.ResponseWriter, req *http.Request) { -		ctx := getContextWithSession(context.Background(), req) -		token := req.URL.Query().Get("code") -		_, err := s.GetUserToken(ctx, "", nil, token) +		err := s.ServeTimelinePage(ctx, c, tType, maxID, minID)  		if err != nil { -			s.ServeErrorPage(ctx, w, nil, err) +			s.ServeErrorPage(ctx, c, err)  			return  		} +	} +	timelineOldPage := func(w http.ResponseWriter, req *http.Request) {  		w.Header().Add("Location", "/timeline/home")  		w.WriteHeader(http.StatusFound) -	}).Methods(http.MethodGet) - -	r.HandleFunc("/timeline", func(w http.ResponseWriter, req *http.Request) { -		w.Header().Add("Location", "/timeline/home") -		w.WriteHeader(http.StatusFound) -	}).Methods(http.MethodGet) - -	r.HandleFunc("/timeline/{type}", func(w http.ResponseWriter, req *http.Request) { -		ctx := getContextWithSession(context.Background(), req) - -		timelineType, _ := mux.Vars(req)["type"] -		maxID := req.URL.Query().Get("max_id") -		sinceID := req.URL.Query().Get("since_id") -		minID := req.URL.Query().Get("min_id") - -		err := s.ServeTimelinePage(ctx, w, nil, timelineType, maxID, sinceID, minID) -		if err != nil { -			s.ServeErrorPage(ctx, w, nil, err) -			return -		} -	}).Methods(http.MethodGet) +	} -	r.HandleFunc("/thread/{id}", func(w http.ResponseWriter, req *http.Request) { -		ctx := getContextWithSession(context.Background(), req) +	threadPage := func(w http.ResponseWriter, req *http.Request) { +		c := newClient(w) +		ctx := newCtxWithSesion(req)  		id, _ := mux.Vars(req)["id"]  		reply := req.URL.Query().Get("reply") -		err := s.ServeThreadPage(ctx, w, nil, id, len(reply) > 1) + +		err := s.ServeThreadPage(ctx, c, id, len(reply) > 1)  		if err != nil { -			s.ServeErrorPage(ctx, w, nil, err) +			s.ServeErrorPage(ctx, c, err)  			return  		} -	}).Methods(http.MethodGet) +	} -	r.HandleFunc("/likedby/{id}", func(w http.ResponseWriter, req *http.Request) { -		ctx := getContextWithSession(context.Background(), req) +	likedByPage := func(w http.ResponseWriter, req *http.Request) { +		c := newClient(w) +		ctx := newCtxWithSesion(req)  		id, _ := mux.Vars(req)["id"] -		err := s.ServeLikedByPage(ctx, w, nil, id) +		err := s.ServeLikedByPage(ctx, c, id)  		if err != nil { -			s.ServeErrorPage(ctx, w, nil, err) +			s.ServeErrorPage(ctx, c, err)  			return  		} -	}).Methods(http.MethodGet) +	} -	r.HandleFunc("/retweetedby/{id}", func(w http.ResponseWriter, req *http.Request) { -		ctx := getContextWithSession(context.Background(), req) +	retweetedByPage := func(w http.ResponseWriter, req *http.Request) { +		c := newClient(w) +		ctx := newCtxWithSesion(req)  		id, _ := mux.Vars(req)["id"] -		err := s.ServeRetweetedByPage(ctx, w, nil, id) +		err := s.ServeRetweetedByPage(ctx, c, id)  		if err != nil { -			s.ServeErrorPage(ctx, w, nil, err) +			s.ServeErrorPage(ctx, c, err)  			return  		} -	}).Methods(http.MethodGet) - -	r.HandleFunc("/following/{id}", func(w http.ResponseWriter, req *http.Request) { -		ctx := getContextWithSession(context.Background(), req) +	} +	followingPage := func(w http.ResponseWriter, req *http.Request) { +		c := newClient(w) +		ctx := newCtxWithSesion(req)  		id, _ := mux.Vars(req)["id"]  		maxID := req.URL.Query().Get("max_id")  		minID := req.URL.Query().Get("min_id") -		err := s.ServeFollowingPage(ctx, w, nil, id, maxID, minID) +		err := s.ServeFollowingPage(ctx, c, id, maxID, minID)  		if err != nil { -			s.ServeErrorPage(ctx, w, nil, err) +			s.ServeErrorPage(ctx, c, err)  			return  		} -	}).Methods(http.MethodGet) - -	r.HandleFunc("/followers/{id}", func(w http.ResponseWriter, req *http.Request) { -		ctx := getContextWithSession(context.Background(), req) +	} +	followersPage := func(w http.ResponseWriter, req *http.Request) { +		c := newClient(w) +		ctx := newCtxWithSesion(req)  		id, _ := mux.Vars(req)["id"]  		maxID := req.URL.Query().Get("max_id")  		minID := req.URL.Query().Get("min_id") -		err := s.ServeFollowersPage(ctx, w, nil, id, maxID, minID) +		err := s.ServeFollowersPage(ctx, c, id, maxID, minID)  		if err != nil { -			s.ServeErrorPage(ctx, w, nil, err) +			s.ServeErrorPage(ctx, c, err)  			return  		} -	}).Methods(http.MethodGet) - -	r.HandleFunc("/like/{id}", func(w http.ResponseWriter, req *http.Request) { -		ctx := getContextWithSession(context.Background(), req) -		ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) +	} -		id, _ := mux.Vars(req)["id"] -		retweetedByID := req.FormValue("retweeted_by_id") +	notificationsPage := func(w http.ResponseWriter, req *http.Request) { +		c := newClient(w) +		ctx := newCtxWithSesion(req) +		maxID := req.URL.Query().Get("max_id") +		minID := req.URL.Query().Get("min_id") -		_, err := s.Like(ctx, w, nil, id) +		err := s.ServeNotificationPage(ctx, c, maxID, minID)  		if err != nil { -			s.ServeErrorPage(ctx, w, nil, err) +			s.ServeErrorPage(ctx, c, err)  			return  		} +	} -		rID := id -		if len(retweetedByID) > 0 { -			rID = retweetedByID -		} -		w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) -		w.WriteHeader(http.StatusFound) -	}).Methods(http.MethodPost) - -	r.HandleFunc("/unlike/{id}", func(w http.ResponseWriter, req *http.Request) { -		ctx := getContextWithSession(context.Background(), req) -		ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) - +	userPage := func(w http.ResponseWriter, req *http.Request) { +		c := newClient(w) +		ctx := newCtxWithSesion(req)  		id, _ := mux.Vars(req)["id"] -		retweetedByID := req.FormValue("retweeted_by_id") +		maxID := req.URL.Query().Get("max_id") +		minID := req.URL.Query().Get("min_id") -		_, err := s.UnLike(ctx, w, nil, id) +		err := s.ServeUserPage(ctx, c, id, maxID, minID)  		if err != nil { -			s.ServeErrorPage(ctx, w, nil, err) +			s.ServeErrorPage(ctx, c, err)  			return  		} +	} -		rID := id -		if len(retweetedByID) > 0 { -			rID = retweetedByID -		} -		w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) -		w.WriteHeader(http.StatusFound) -	}).Methods(http.MethodPost) - -	r.HandleFunc("/retweet/{id}", func(w http.ResponseWriter, req *http.Request) { -		ctx := getContextWithSession(context.Background(), req) -		ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) - -		id, _ := mux.Vars(req)["id"] -		retweetedByID := req.FormValue("retweeted_by_id") +	aboutPage := func(w http.ResponseWriter, req *http.Request) { +		c := newClient(w) +		ctx := newCtxWithSesion(req) -		_, err := s.Retweet(ctx, w, nil, id) +		err := s.ServeAboutPage(ctx, c)  		if err != nil { -			s.ServeErrorPage(ctx, w, nil, err) +			s.ServeErrorPage(ctx, c, err)  			return  		} +	} -		rID := id -		if len(retweetedByID) > 0 { -			rID = retweetedByID -		} -		w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) -		w.WriteHeader(http.StatusFound) -	}).Methods(http.MethodPost) - -	r.HandleFunc("/unretweet/{id}", func(w http.ResponseWriter, req *http.Request) { -		ctx := getContextWithSession(context.Background(), req) -		ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) +	emojisPage := func(w http.ResponseWriter, req *http.Request) { +		c := newClient(w) +		ctx := newCtxWithSesion(req) -		id, _ := mux.Vars(req)["id"] -		retweetedByID := req.FormValue("retweeted_by_id") - -		_, err := s.UnRetweet(ctx, w, nil, id) +		err := s.ServeEmojiPage(ctx, c)  		if err != nil { -			s.ServeErrorPage(ctx, w, nil, err) +			s.ServeErrorPage(ctx, c, err)  			return  		} +	} -		rID := id -		if len(retweetedByID) > 0 { -			rID = retweetedByID -		} -		w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) -		w.WriteHeader(http.StatusFound) -	}).Methods(http.MethodPost) - -	r.HandleFunc("/fluoride/like/{id}", func(w http.ResponseWriter, req *http.Request) { -		ctx := getContextWithSession(context.Background(), req) -		ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) +	searchPage := func(w http.ResponseWriter, req *http.Request) { +		c := newClient(w) +		ctx := newCtxWithSesion(req) +		q := req.URL.Query().Get("q") +		qType := req.URL.Query().Get("type") +		offsetStr := req.URL.Query().Get("offset") -		id, _ := mux.Vars(req)["id"] -		count, err := s.Like(ctx, w, nil, id) -		if err != nil { -			s.ServeErrorPage(ctx, w, nil, err) -			return +		var offset int +		var err error +		if len(offsetStr) > 1 { +			offset, err = strconv.Atoi(offsetStr) +			if err != nil { +				s.ServeErrorPage(ctx, c, err) +				return +			}  		} -		err = serveJson(w, count) +		err = s.ServeSearchPage(ctx, c, q, qType, offset)  		if err != nil { -			s.ServeErrorPage(ctx, w, nil, err) +			s.ServeErrorPage(ctx, c, err)  			return  		} -	}).Methods(http.MethodPost) - -	r.HandleFunc("/fluoride/unlike/{id}", func(w http.ResponseWriter, req *http.Request) { -		ctx := getContextWithSession(context.Background(), req) -		ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) +	} -		id, _ := mux.Vars(req)["id"] -		count, err := s.UnLike(ctx, w, nil, id) -		if err != nil { -			s.ServeErrorPage(ctx, w, nil, err) -			return -		} +	settingsPage := func(w http.ResponseWriter, req *http.Request) { +		c := newClient(w) +		ctx := newCtxWithSesion(req) -		err = serveJson(w, count) +		err := s.ServeSettingsPage(ctx, c)  		if err != nil { -			s.ServeErrorPage(ctx, w, nil, err) +			s.ServeErrorPage(ctx, c, err)  			return  		} -	}).Methods(http.MethodPost) +	} -	r.HandleFunc("/fluoride/retweet/{id}", func(w http.ResponseWriter, req *http.Request) { -		ctx := getContextWithSession(context.Background(), req) -		ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) +	signin := func(w http.ResponseWriter, req *http.Request) { +		c := newClient(w) +		ctx := context.Background() +		instance := req.FormValue("instance") -		id, _ := mux.Vars(req)["id"] -		count, err := s.Retweet(ctx, w, nil, id) +		url, sessionID, err := s.NewSession(ctx, instance)  		if err != nil { -			s.ServeErrorPage(ctx, w, nil, err) +			s.ServeErrorPage(ctx, c, err)  			return  		} -		err = serveJson(w, count) -		if err != nil { -			s.ServeErrorPage(ctx, w, nil, err) -			return -		} -	}).Methods(http.MethodPost) +		http.SetCookie(w, &http.Cookie{ +			Name:    "session_id", +			Value:   sessionID, +			Expires: time.Now().Add(365 * 24 * time.Hour), +		}) -	r.HandleFunc("/fluoride/unretweet/{id}", func(w http.ResponseWriter, req *http.Request) { -		ctx := getContextWithSession(context.Background(), req) -		ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) +		w.Header().Add("Location", url) +		w.WriteHeader(http.StatusFound) +	} -		id, _ := mux.Vars(req)["id"] -		count, err := s.UnRetweet(ctx, w, nil, id) -		if err != nil { -			s.ServeErrorPage(ctx, w, nil, err) -			return -		} +	oauthCallback := func(w http.ResponseWriter, req *http.Request) { +		c := newClient(w) +		ctx := newCtxWithSesion(req) +		token := req.URL.Query().Get("code") -		err = serveJson(w, count) +		_, err := s.Signin(ctx, c, "", token)  		if err != nil { -			s.ServeErrorPage(ctx, w, nil, err) +			s.ServeErrorPage(ctx, c, err)  			return  		} -	}).Methods(http.MethodPost) -	r.HandleFunc("/post", func(w http.ResponseWriter, req *http.Request) { +		w.Header().Add("Location", "/timeline/home") +		w.WriteHeader(http.StatusFound) +	} + +	post := func(w http.ResponseWriter, req *http.Request) { +		c := newClient(w)  		err := req.ParseMultipartForm(4 << 20)  		if err != nil { -			s.ServeErrorPage(ctx, w, nil, err) +			s.ServeErrorPage(context.Background(), c, err)  			return  		} -		ctx := getContextWithSession(context.Background(), req) -		ctx = context.WithValue(ctx, "csrf_token", +		ctx := newCtxWithSesionCSRF(req,  			getMultipartFormValue(req.MultipartForm, "csrf_token")) -  		content := getMultipartFormValue(req.MultipartForm, "content")  		replyToID := getMultipartFormValue(req.MultipartForm, "reply_to_id")  		format := getMultipartFormValue(req.MultipartForm, "format")  		visibility := getMultipartFormValue(req.MultipartForm, "visibility")  		isNSFW := "on" == getMultipartFormValue(req.MultipartForm, "is_nsfw") -  		files := req.MultipartForm.File["attachments"] -		id, err := s.PostTweet(ctx, w, nil, content, replyToID, format, visibility, isNSFW, files) +		id, err := s.Post(ctx, c, content, replyToID, format, visibility, isNSFW, files)  		if err != nil { -			s.ServeErrorPage(ctx, w, nil, err) +			s.ServeErrorPage(ctx, c, err)  			return  		} @@ -345,131 +310,129 @@ func NewHandler(s Service, staticDir string) http.Handler {  		}  		w.Header().Add("Location", location)  		w.WriteHeader(http.StatusFound) -	}).Methods(http.MethodPost) - -	r.HandleFunc("/notifications", func(w http.ResponseWriter, req *http.Request) { -		ctx := getContextWithSession(context.Background(), req) - -		maxID := req.URL.Query().Get("max_id") -		minID := req.URL.Query().Get("min_id") - -		err := s.ServeNotificationPage(ctx, w, nil, maxID, minID) -		if err != nil { -			s.ServeErrorPage(ctx, w, nil, err) -			return -		} -	}).Methods(http.MethodGet) - -	r.HandleFunc("/user/{id}", func(w http.ResponseWriter, req *http.Request) { -		ctx := getContextWithSession(context.Background(), req) +	} +	like := func(w http.ResponseWriter, req *http.Request) { +		c := newClient(w) +		ctx := newCtxWithSesionCSRF(req, req.FormValue("csrf_token"))  		id, _ := mux.Vars(req)["id"] -		maxID := req.URL.Query().Get("max_id") -		minID := req.URL.Query().Get("min_id") +		retweetedByID := req.FormValue("retweeted_by_id") -		err := s.ServeUserPage(ctx, w, nil, id, maxID, minID) +		_, err := s.Like(ctx, c, id)  		if err != nil { -			s.ServeErrorPage(ctx, w, nil, err) +			s.ServeErrorPage(ctx, c, err)  			return  		} -	}).Methods(http.MethodGet) -	r.HandleFunc("/follow/{id}", func(w http.ResponseWriter, req *http.Request) { -		ctx := getContextWithSession(context.Background(), req) -		ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) +		rID := id +		if len(retweetedByID) > 0 { +			rID = retweetedByID +		} +		w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) +		w.WriteHeader(http.StatusFound) +	} +	unlike := func(w http.ResponseWriter, req *http.Request) { +		c := newClient(w) +		ctx := newCtxWithSesionCSRF(req, req.FormValue("csrf_token"))  		id, _ := mux.Vars(req)["id"] +		retweetedByID := req.FormValue("retweeted_by_id") -		err := s.Follow(ctx, w, nil, id) +		_, err := s.UnLike(ctx, c, id)  		if err != nil { -			s.ServeErrorPage(ctx, w, nil, err) +			s.ServeErrorPage(ctx, c, err)  			return  		} -		w.Header().Add("Location", req.Header.Get("Referer")) +		rID := id +		if len(retweetedByID) > 0 { +			rID = retweetedByID +		} +		w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID)  		w.WriteHeader(http.StatusFound) -	}).Methods(http.MethodPost) - -	r.HandleFunc("/unfollow/{id}", func(w http.ResponseWriter, req *http.Request) { -		ctx := getContextWithSession(context.Background(), req) -		ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) +	} +	retweet := func(w http.ResponseWriter, req *http.Request) { +		c := newClient(w) +		ctx := newCtxWithSesionCSRF(req, req.FormValue("csrf_token"))  		id, _ := mux.Vars(req)["id"] +		retweetedByID := req.FormValue("retweeted_by_id") -		err := s.UnFollow(ctx, w, nil, id) +		_, err := s.Retweet(ctx, c, id)  		if err != nil { -			s.ServeErrorPage(ctx, w, nil, err) +			s.ServeErrorPage(ctx, c, err)  			return  		} -		w.Header().Add("Location", req.Header.Get("Referer")) +		rID := id +		if len(retweetedByID) > 0 { +			rID = retweetedByID +		} +		w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID)  		w.WriteHeader(http.StatusFound) -	}).Methods(http.MethodPost) +	} -	r.HandleFunc("/about", func(w http.ResponseWriter, req *http.Request) { -		ctx := getContextWithSession(context.Background(), req) +	unretweet := func(w http.ResponseWriter, req *http.Request) { +		c := newClient(w) +		ctx := newCtxWithSesionCSRF(req, req.FormValue("csrf_token")) +		id, _ := mux.Vars(req)["id"] +		retweetedByID := req.FormValue("retweeted_by_id") -		err := s.ServeAboutPage(ctx, w, nil) +		_, err := s.UnRetweet(ctx, c, id)  		if err != nil { -			s.ServeErrorPage(ctx, w, nil, err) +			s.ServeErrorPage(ctx, c, err)  			return  		} -	}).Methods(http.MethodGet) -	r.HandleFunc("/emojis", func(w http.ResponseWriter, req *http.Request) { -		ctx := getContextWithSession(context.Background(), req) - -		err := s.ServeEmojiPage(ctx, w, nil) -		if err != nil { -			s.ServeErrorPage(ctx, w, nil, err) -			return +		rID := id +		if len(retweetedByID) > 0 { +			rID = retweetedByID  		} -	}).Methods(http.MethodGet) - -	r.HandleFunc("/search", func(w http.ResponseWriter, req *http.Request) { -		ctx := getContextWithSession(context.Background(), req) -		q := req.URL.Query().Get("q") -		qType := req.URL.Query().Get("type") -		offsetStr := req.URL.Query().Get("offset") +		w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID) +		w.WriteHeader(http.StatusFound) +	} -		var offset int -		var err error -		if len(offsetStr) > 1 { -			offset, err = strconv.Atoi(offsetStr) -			if err != nil { -				s.ServeErrorPage(ctx, w, nil, err) -				return -			} -		} +	follow := func(w http.ResponseWriter, req *http.Request) { +		c := newClient(w) +		ctx := newCtxWithSesionCSRF(req, req.FormValue("csrf_token")) +		id, _ := mux.Vars(req)["id"] -		err = s.ServeSearchPage(ctx, w, nil, q, qType, offset) +		err := s.Follow(ctx, c, id)  		if err != nil { -			s.ServeErrorPage(ctx, w, nil, err) +			s.ServeErrorPage(ctx, c, err)  			return  		} -	}).Methods(http.MethodGet) -	r.HandleFunc("/settings", func(w http.ResponseWriter, req *http.Request) { -		ctx := getContextWithSession(context.Background(), req) +		w.Header().Add("Location", req.Header.Get("Referer")) +		w.WriteHeader(http.StatusFound) +	} + +	unfollow := func(w http.ResponseWriter, req *http.Request) { +		c := newClient(w) +		ctx := newCtxWithSesionCSRF(req, req.FormValue("csrf_token")) +		id, _ := mux.Vars(req)["id"] -		err := s.ServeSettingsPage(ctx, w, nil) +		err := s.UnFollow(ctx, c, id)  		if err != nil { -			s.ServeErrorPage(ctx, w, nil, err) +			s.ServeErrorPage(ctx, c, err)  			return  		} -	}).Methods(http.MethodGet) -	r.HandleFunc("/settings", func(w http.ResponseWriter, req *http.Request) { -		ctx := getContextWithSession(context.Background(), req) -		ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token")) +		w.Header().Add("Location", req.Header.Get("Referer")) +		w.WriteHeader(http.StatusFound) +	} +	settings := func(w http.ResponseWriter, req *http.Request) { +		c := newClient(w) +		ctx := newCtxWithSesionCSRF(req, req.FormValue("csrf_token"))  		visibility := req.FormValue("visibility")  		copyScope := req.FormValue("copy_scope") == "true"  		threadInNewTab := req.FormValue("thread_in_new_tab") == "true"  		maskNSFW := req.FormValue("mask_nsfw") == "true"  		fluorideMode := req.FormValue("fluoride_mode") == "true"  		darkMode := req.FormValue("dark_mode") == "true" +  		settings := &model.Settings{  			DefaultVisibility: visibility,  			CopyScope:         copyScope, @@ -479,17 +442,17 @@ func NewHandler(s Service, staticDir string) http.Handler {  			DarkMode:          darkMode,  		} -		err := s.SaveSettings(ctx, w, nil, settings) +		err := s.SaveSettings(ctx, c, settings)  		if err != nil { -			s.ServeErrorPage(ctx, w, nil, err) +			s.ServeErrorPage(ctx, c, err)  			return  		}  		w.Header().Add("Location", req.Header.Get("Referer"))  		w.WriteHeader(http.StatusFound) -	}).Methods(http.MethodPost) +	} -	r.HandleFunc("/signout", func(w http.ResponseWriter, req *http.Request) { +	signout := func(w http.ResponseWriter, req *http.Request) {  		// TODO remove session from database  		http.SetCookie(w, &http.Cookie{  			Name:    "session_id", @@ -498,32 +461,111 @@ func NewHandler(s Service, staticDir string) http.Handler {  		})  		w.Header().Add("Location", "/")  		w.WriteHeader(http.StatusFound) -	}).Methods(http.MethodGet) +	} -	return r -} +	fLike := func(w http.ResponseWriter, req *http.Request) { +		c := newClient(w) +		ctx := newCtxWithSesionCSRF(req, req.FormValue("csrf_token")) +		id, _ := mux.Vars(req)["id"] -func getContextWithSession(ctx context.Context, req *http.Request) context.Context { -	sessionID, err := req.Cookie("session_id") -	if err != nil { -		return ctx +		count, err := s.Like(ctx, c, id) +		if err != nil { +			s.ServeErrorPage(ctx, c, err) +			return +		} + +		err = serveJson(w, count) +		if err != nil { +			s.ServeErrorPage(ctx, c, err) +			return +		}  	} -	return context.WithValue(ctx, "session_id", sessionID.Value) -} -func getMultipartFormValue(mf *multipart.Form, key string) (val string) { -	vals, ok := mf.Value[key] -	if !ok { -		return "" +	fUnlike := func(w http.ResponseWriter, req *http.Request) { +		c := newClient(w) +		ctx := newCtxWithSesionCSRF(req, req.FormValue("csrf_token")) +		id, _ := mux.Vars(req)["id"] +		count, err := s.UnLike(ctx, c, id) +		if err != nil { +			s.ServeErrorPage(ctx, c, err) +			return +		} + +		err = serveJson(w, count) +		if err != nil { +			s.ServeErrorPage(ctx, c, err) +			return +		}  	} -	if len(vals) < 1 { -		return "" + +	fRetweet := func(w http.ResponseWriter, req *http.Request) { +		c := newClient(w) +		ctx := newCtxWithSesionCSRF(req, req.FormValue("csrf_token")) +		id, _ := mux.Vars(req)["id"] + +		count, err := s.Retweet(ctx, c, id) +		if err != nil { +			s.ServeErrorPage(ctx, c, err) +			return +		} + +		err = serveJson(w, count) +		if err != nil { +			s.ServeErrorPage(ctx, c, err) +			return +		}  	} -	return vals[0] -} -func serveJson(w io.Writer, data interface{}) (err error) { -	var d = make(map[string]interface{}) -	d["data"] = data -	return json.NewEncoder(w).Encode(d) +	fUnretweet := func(w http.ResponseWriter, req *http.Request) { +		c := newClient(w) +		ctx := newCtxWithSesionCSRF(req, req.FormValue("csrf_token")) +		id, _ := mux.Vars(req)["id"] + +		count, err := s.UnRetweet(ctx, c, id) +		if err != nil { +			s.ServeErrorPage(ctx, c, err) +			return +		} + +		err = serveJson(w, count) +		if err != nil { +			s.ServeErrorPage(ctx, c, err) +			return +		} +	} + +	r.HandleFunc("/", rootPage).Methods(http.MethodGet) +	r.HandleFunc("/signin", signinPage).Methods(http.MethodGet) +	r.HandleFunc("/timeline/{type}", timelinePage).Methods(http.MethodGet) +	r.HandleFunc("/timeline", timelineOldPage).Methods(http.MethodGet) +	r.HandleFunc("/thread/{id}", threadPage).Methods(http.MethodGet) +	r.HandleFunc("/likedby/{id}", likedByPage).Methods(http.MethodGet) +	r.HandleFunc("/retweetedby/{id}", retweetedByPage).Methods(http.MethodGet) +	r.HandleFunc("/following/{id}", followingPage).Methods(http.MethodGet) +	r.HandleFunc("/followers/{id}", followersPage).Methods(http.MethodGet) +	r.HandleFunc("/notifications", notificationsPage).Methods(http.MethodGet) +	r.HandleFunc("/user/{id}", userPage).Methods(http.MethodGet) +	r.HandleFunc("/about", aboutPage).Methods(http.MethodGet) +	r.HandleFunc("/emojis", emojisPage).Methods(http.MethodGet) +	r.HandleFunc("/search", searchPage).Methods(http.MethodGet) +	r.HandleFunc("/settings", settingsPage).Methods(http.MethodGet) +	r.HandleFunc("/signin", signin).Methods(http.MethodPost) +	r.HandleFunc("/oauth_callback", oauthCallback).Methods(http.MethodGet) +	r.HandleFunc("/post", post).Methods(http.MethodPost) +	r.HandleFunc("/like/{id}", like).Methods(http.MethodPost) +	r.HandleFunc("/unlike/{id}", unlike).Methods(http.MethodPost) +	r.HandleFunc("/retweet/{id}", retweet).Methods(http.MethodPost) +	r.HandleFunc("/unretweet/{id}", unretweet).Methods(http.MethodPost) +	r.HandleFunc("/follow/{id}", follow).Methods(http.MethodPost) +	r.HandleFunc("/unfollow/{id}", unfollow).Methods(http.MethodPost) +	r.HandleFunc("/settings", settings).Methods(http.MethodPost) +	r.HandleFunc("/signout", signout).Methods(http.MethodGet) +	r.HandleFunc("/fluoride/like/{id}", fLike).Methods(http.MethodPost) +	r.HandleFunc("/fluoride/unlike/{id}", fUnlike).Methods(http.MethodPost) +	r.HandleFunc("/fluoride/retweet/{id}", fRetweet).Methods(http.MethodPost) +	r.HandleFunc("/fluoride/unretweet/{id}", fUnretweet).Methods(http.MethodPost) +	r.PathPrefix("/static").Handler(http.StripPrefix("/static", +		http.FileServer(http.Dir(path.Join(".", staticDir))))) + +	return r  }  | 
