Skip to content

Commit

Permalink
Merge branch 'master' into monitor-fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
naueramant committed Mar 24, 2024
2 parents cfb6e93 + 7c9a9f1 commit 237351b
Show file tree
Hide file tree
Showing 5 changed files with 224 additions and 32 deletions.
5 changes: 5 additions & 0 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ func initConfig() {
logrus.Info("Using config file: ", viper.ConfigFileUsed())
}

viper.SetConfigName("secrets")
if err := viper.MergeInConfig(); err == nil {
logrus.Info("Using secret file: ", viper.ConfigFileUsed())
}

viper.AutomaticEnv()
}

Expand Down
94 changes: 90 additions & 4 deletions internal/billing/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"github.com/stripe/stripe-go/v76"
portalsession "github.com/stripe/stripe-go/v76/billingportal/session"
"github.com/stripe/stripe-go/v76/checkout/session"
"github.com/stripe/stripe-go/v76/price"
"github.com/stripe/stripe-go/v76/subscription"
"github.com/stripe/stripe-go/v76/webhook"
)

Expand All @@ -22,8 +24,13 @@ type Config struct {
type Service interface {
PostConfig() StripeConfig
CreateCheckoutSession(team *entities.Team, priceLookupKey string) (*stripe.CheckoutSession, error)
UpdateSubscribtion(team *entities.Team, priceLookupKey string) (*stripe.Subscription, error)
CancelSubscribtion(team *entities.Team) (*stripe.Subscription, error)
GetCheckoutSession(sessionID string) (*stripe.CheckoutSession, error)
GetLineItems(sessionID string) *session.LineItemIter
GetPrice(priceLookupKey string) (*stripe.Price, error)
GetCustomerSubscribtion(customerID string) *subscription.Iter
GetSubscribtion(subID string) (*stripe.Subscription, error)
CreateCustomerPortal(sessionID string) (*stripe.BillingPortalSession, error)
ConstructEvent(payload []byte, header string) (stripe.Event, error)
}
Expand Down Expand Up @@ -54,28 +61,77 @@ func (s *ServiceImpl) PostConfig() StripeConfig {
}

func (s *ServiceImpl) CreateCheckoutSession(team *entities.Team, priceLookupKey string) (*stripe.CheckoutSession, error) {
priceID, err := s.GetPrice(priceLookupKey)
if err != nil {
return nil, errors.Wrap(err, "failed to get price")
}

params := &stripe.CheckoutSessionParams{
SuccessURL: stripe.String(s.Config.Domain + "/team/plan"),
CancelURL: stripe.String(s.Config.Domain + "/canceled.html"),
SuccessURL: stripe.String(s.Config.Domain + "/team/subscription"),
CancelURL: stripe.String(s.Config.Domain + "/team/subscription"),
Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
ClientReferenceID: stripe.String(strconv.FormatUint(uint64(team.ID), 10)),

LineItems: []*stripe.CheckoutSessionLineItemParams{
{
Price: stripe.String(priceLookupKey),

Price: stripe.String(priceID.ID),
Quantity: stripe.Int64(1),
},
},
}

// Set Customer on session if already a customer
if team.StripeCustomerID != nil {
params.Customer = stripe.String(*team.StripeCustomerID)
}

return session.New(params)
}

func (s *ServiceImpl) UpdateSubscribtion(team *entities.Team, priceLookupKey string) (*stripe.Subscription, error) {
// Set Customer on session if already a customer

priceID, err := s.GetPrice(priceLookupKey)
if err != nil {
return nil, errors.Wrap(err, "failed to get price")
}
sub := s.GetCustomerSubscribtion(*team.StripeCustomerID)
sub.Next()
teamSubscription := sub.Subscription()

params := &stripe.SubscriptionParams{
Items: []*stripe.SubscriptionItemsParams{
{
ID: stripe.String(teamSubscription.Items.Data[0].ID),
Price: stripe.String(priceID.ID),
},
},
}
result, err := subscription.Update(teamSubscription.ID, params)
if err != nil {
return nil, errors.Wrap(err, "failed to update subscription")
}

return result, nil
}

func (s *ServiceImpl) CancelSubscribtion(team *entities.Team) (*stripe.Subscription, error) {
// Set Customer on session if already a customer

sub := s.GetCustomerSubscribtion(*team.StripeCustomerID)
sub.Next()
teamSubscription := sub.Subscription()

params := &stripe.SubscriptionCancelParams{}
result, err := subscription.Cancel(teamSubscription.ID, params)

if err != nil {
return nil, errors.Wrap(err, "failed to update subscription")
}

return result, nil
}

func (s *ServiceImpl) GetCheckoutSession(sessionID string) (*stripe.CheckoutSession, error) {
return session.Get(sessionID, &stripe.CheckoutSessionParams{})
}
Expand All @@ -87,6 +143,36 @@ func (s *ServiceImpl) GetLineItems(sessionID string) *session.LineItemIter {
return session.ListLineItems(params)
}

func (s *ServiceImpl) GetPrice(priceLookupKey string) (*stripe.Price, error) {
params := &stripe.PriceListParams{
LookupKeys: stripe.StringSlice([]string{
priceLookupKey,
}),
}
i := price.List(params)

var price *stripe.Price
for i.Next() {
p := i.Price()
price = p
}
if price == nil {
return nil, errors.New("Price not found for lookup key" + priceLookupKey)
}
return price, nil
}

func (s *ServiceImpl) GetCustomerSubscribtion(customerID string) *subscription.Iter {
params := &stripe.SubscriptionListParams{Customer: stripe.String(customerID)}
return subscription.List(params)
}

func (s *ServiceImpl) GetSubscribtion(subID string) (*stripe.Subscription, error) {
params := &stripe.SubscriptionParams{}
return subscription.Get(subID, params)

}

func (s *ServiceImpl) CreateCustomerPortal(sessionID string) (*stripe.BillingPortalSession, error) {
// For demonstration purposes, we're using the Checkout session to retrieve the customer ID.
// Typically this is stored alongside the authenticated user in your database.
Expand Down
45 changes: 40 additions & 5 deletions internal/rest/controllers/teams/billing.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ func (h *Handlers) PostCreateCheckoutSession(c hs.AuthenticatedContext) error {

return echo.ErrBadRequest
}
c.Log.Error(req.PriceLookupKey)
c.Log.Info(req.TeamID)

team, err := h.TeamService.GetByID(c.Request().Context(), req.TeamID)
if err != nil {
Expand All @@ -34,14 +32,51 @@ func (h *Handlers) PostCreateCheckoutSession(c hs.AuthenticatedContext) error {
return echo.ErrInternalServerError
}

s, err := h.BillingService.CreateCheckoutSession(team, req.PriceLookupKey)
if team.PaymentPlan == req.PriceLookupKey {
return c.JSON(http.StatusOK, "")
}

if team.StripeCustomerID == nil {
if req.PriceLookupKey == "FREE" {
return c.JSON(http.StatusOK, "")
}
s, err := h.BillingService.CreateCheckoutSession(team, req.PriceLookupKey)
if err != nil {
c.Log.WithError(err).Debug("create stripe checkout session")

return echo.ErrInternalServerError
}
return c.JSON(http.StatusOK, s.URL)
}

if req.PriceLookupKey == "FREE" {
_, err := h.BillingService.CancelSubscribtion(team)
if err != nil {
c.Log.WithError(err).Debug("cancel subscription")

return echo.ErrInternalServerError
}
return c.JSON(http.StatusOK, "")
}

if team.PaymentPlan == "FREE" {
s, err := h.BillingService.CreateCheckoutSession(team, req.PriceLookupKey)
if err != nil {
c.Log.WithError(err).Debug("create stripe checkout session")

return echo.ErrInternalServerError
}
return c.JSON(http.StatusOK, s.URL)
}

_, err = h.BillingService.UpdateSubscribtion(team, req.PriceLookupKey)
if err != nil {
c.Log.WithError(err).Debug("create stripe checkout session")
c.Log.WithError(err).Debug("update subscription")

return echo.ErrInternalServerError
}
return c.JSON(http.StatusOK, "")

return c.JSON(http.StatusOK, s.URL)
}

type GetCheckoutSession struct {
Expand Down
1 change: 1 addition & 0 deletions internal/rest/controllers/webhooks/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ func Register(
) {
h := &Handlers{
BillingService: billingService,
TeamService: teamService,
}

root := e.Group(
Expand Down
111 changes: 88 additions & 23 deletions internal/rest/controllers/webhooks/stripe.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,47 +37,112 @@ func (h *Handlers) handleWebhook(c hs.StripeContext) error {
return echo.ErrBadRequest
}

params := &stripe.CheckoutSessionParams{}
params.AddExpand("line_items")

// Retrieve the session. If you require line items in the response, you may include them by expanding line_items.
// sessionWithLineItems, err := h.BillingService.GetCheckoutSession(session.ID)
// if err != nil {
// c.Log.WithError(err).Debug("Error getting checkout session")
// return echo.ErrBadRequest
// }

c.Log.Error(session.ClientReferenceID)
c.Log.Error(session)
c.Log.Error(session.ID)
c.Log.Error(event.Data.Object["customer"].(string))
lineItems := h.BillingService.GetLineItems(session.ID).List()
c.Log.Error(lineItems)
c.Log.Error(event.Data.Object["subscription"].(string))
subscription, err := h.BillingService.GetSubscribtion(event.Data.Object["subscription"].(string))
if err != nil {
c.Log.WithError(err).Debug("Error getting subscription")
}
c.Log.Error(subscription.Items.Data[0])
c.Log.Error(subscription.Items.Data[0].Price.LookupKey)

teamID, err := strconv.ParseUint(session.ClientReferenceID, 10, 32)
if err != nil {
c.Log.WithError(err).Debug("Error parsing team id", session.ClientReferenceID)
return c.NoContent(http.StatusInternalServerError)
}

c.Log.Error(c.Request().Context())
c.Log.Error(uint(teamID))
customerTeam, err := h.TeamService.GetByID(c.Request().Context(), uint(teamID))
if err != nil {
c.Log.WithError(err).Debug("Error getting team by id", teamID)
return c.NoContent(http.StatusInternalServerError)
}

// TODO if not same plan remove old plan

if customerTeam.PaymentPlan == "TEST" { //lineItems.Price.Product.Name {
return c.NoContent(http.StatusOK)
}

customerID := event.Data.Object["customer"].(string)
if customerTeam.StripeCustomerID == nil {
customerTeam.StripeCustomerID = &customerID
}

h.TeamService.UpdateTeam(c.Request().Context(), customerTeam)
customerTeam.PaymentPlan = subscription.Items.Data[0].Price.LookupKey

err = h.TeamService.UpdateTeam(c.Request().Context(), customerTeam)
if err != nil {
c.Log.WithError(err).Debug("Error updating team")
return c.NoContent(http.StatusInternalServerError)
}
case "customer.subscription.updated":
var subscription stripe.Subscription
err := json.Unmarshal(event.Data.Raw, &subscription)
if err != nil {
c.Log.WithError(err).Debug("Error parsing webhook JSON")
return echo.ErrBadRequest
}

team, err := h.TeamService.GetByStripeID(c.Request().Context(), subscription.Customer.ID)
if err != nil {
c.Log.WithError(err).Debug("Error getting team by stripe id")
return c.NoContent(http.StatusInternalServerError)
}
team.PaymentPlan = subscription.Items.Data[0].Price.LookupKey

if subscription.Status == "canceled" {
team.PaymentPlan = "FREE"
}

err = h.TeamService.UpdateTeam(c.Request().Context(), team)
if err != nil {
c.Log.WithError(err).Debug("Error updating team")
return c.NoContent(http.StatusInternalServerError)
}

case "customer.subscription.deleted":
var subscription stripe.Subscription
err := json.Unmarshal(event.Data.Raw, &subscription)
if err != nil {
c.Log.WithError(err).Debug("Error parsing webhook JSON")
return echo.ErrBadRequest
}

team, err := h.TeamService.GetByStripeID(c.Request().Context(), subscription.Customer.ID)
if err != nil {
c.Log.WithError(err).Debug("Error getting team by stripe id")
return c.NoContent(http.StatusInternalServerError)
}
team.PaymentPlan = subscription.Items.Data[0].Price.LookupKey

if subscription.Status == "canceled" {
team.PaymentPlan = "FREE"
}

err = h.TeamService.UpdateTeam(c.Request().Context(), team)
if err != nil {
c.Log.WithError(err).Debug("Error updating team")
return c.NoContent(http.StatusInternalServerError)
}

case "customer.deleted":
var customer stripe.Customer
err := json.Unmarshal(event.Data.Raw, &customer)
if err != nil {
c.Log.WithError(err).Debug("Error parsing webhook JSON")
return echo.ErrBadRequest
}

team, err := h.TeamService.GetByStripeID(c.Request().Context(), customer.ID)
if err != nil {
c.Log.WithError(err).Debug("Error getting team by stripe id")
return c.NoContent(http.StatusInternalServerError)
}

team.PaymentPlan = "FREE"
team.StripeCustomerID = nil

err = h.TeamService.UpdateTeam(c.Request().Context(), team)
if err != nil {
c.Log.WithError(err).Debug("Error updating team")
return c.NoContent(http.StatusInternalServerError)
}

default:
c.Log.WithField("event", event.Type).Debug("Unhandled event type")
Expand Down

0 comments on commit 237351b

Please sign in to comment.