Browse Source

Refactor to using Init funcs and add gob serializer

Signed-off-by: jolheiser <john.olheiser@gmail.com>
main v0.0.4
jolheiser 1 year ago
parent
commit
de4e822f7d
No known key found for this signature in database GPG Key ID: 83E486E71AFEB820
  1. 5
      bucket.go
  2. 3
      example_key_test.go
  3. 3
      example_seq_test.go
  4. 21
      serial.go
  5. 52
      serial_test.go
  6. 12
      shock.go
  7. 16
      shock_test.go
  8. 22
      store.go
  9. 41
      store_test.go

5
bucket.go

@ -35,3 +35,8 @@ func (b *Bucket) ViewEach(fn EachFunc) error {
func (b *Bucket) UpdateEach(fn EachFunc) error {
return b.DB.UpdateEach(b.Name, fn)
}
// Init initializes the bucket if it doesn't exist
func (b *Bucket) Init() error {
return b.DB.Init(b.Name)
}

3
example_key_test.go

@ -27,6 +27,9 @@ func ExampleTestUserSettings() {
// Get a bucket to work with instead of specifying each time
bucket := db.Bucket("test-user-settings")
if err := bucket.Init(); err != nil {
panic(err)
}
// Add a new TestUserSettings
t := &TestUserSettings{

3
example_seq_test.go

@ -34,6 +34,9 @@ func ExampleTestUser() {
// Get a bucket to work with instead of specifying each time
bucket := db.Bucket("test-user")
if err := bucket.Init(); err != nil {
panic(err)
}
// Add a new TestUser
t := &TestUser{

21
serial.go

@ -1,6 +1,10 @@
package shock
import "encoding/json"
import (
"bytes"
"encoding/gob"
"encoding/json"
)
// Serializer defines a way to encode/decode objects in the database
type Serializer interface {
@ -20,3 +24,18 @@ func (j JSONSerializer) Marshal(v interface{}) ([]byte, error) {
func (j JSONSerializer) Unmarshal(data []byte, v interface{}) error {
return json.Unmarshal(data, &v)
}
// GobSerializer is a Serializer that uses gob
type GobSerializer struct{}
// Marshal encodes an object to gob
func (g GobSerializer) Marshal(v interface{}) ([]byte, error) {
var buf bytes.Buffer
err := gob.NewEncoder(&buf).Encode(v)
return buf.Bytes(), err
}
// Unmarshal decodes an object from gob
func (g GobSerializer) Unmarshal(data []byte, v interface{}) error {
return gob.NewDecoder(bytes.NewBuffer(data)).Decode(v)
}

52
serial_test.go

@ -0,0 +1,52 @@
package shock
import (
"fmt"
"os"
"path"
"testing"
"go.etcd.io/bbolt"
)
func TestGob(t *testing.T) {
dbPath := path.Join(tmpDir, "gob.db")
db, err := Open(dbPath, os.ModePerm, &Options{
Bolt: bbolt.DefaultOptions,
Serializer: GobSerializer{},
})
if err != nil {
panic(err)
}
bucket := db.Bucket("test")
if err := bucket.Init(); err != nil {
t.Log(err)
t.FailNow()
}
for _, tc := range tt {
t.Run(tc.Name, func(t *testing.T) {
if err := bucket.Put(tc); err != nil {
t.Log(err)
t.FailNow()
}
var tcc TestUser
err := bucket.Get(tc.ID, &tcc)
if err != nil {
t.Log(err)
t.FailNow()
}
if !tcc.Equal(*tc) {
t.Log("Serialized struct is not the same")
t.FailNow()
}
})
}
if err := db.Bolt.Close(); err != nil {
fmt.Printf("Could not close DB %s: %v\n", dbPath, err)
}
}

12
shock.go

@ -34,3 +34,15 @@ func (d *DB) Bucket(name string) *Bucket {
DB: d,
}
}
// Init initializes buckets if they don't exist
func (d *DB) Init(buckets ...string) error {
return d.Bolt.Update(func(tx *bbolt.Tx) error {
for _, bucket := range buckets {
if _, err := tx.CreateBucketIfNotExists([]byte(bucket)); err != nil {
return err
}
}
return nil
})
}

16
shock_test.go

@ -8,16 +8,20 @@ import (
"testing"
)
var db *DB
var (
db *DB
tmpDir string
)
func TestMain(m *testing.M) {
dir, err := ioutil.TempDir(os.TempDir(), "shock")
var err error
tmpDir, err = ioutil.TempDir(os.TempDir(), "shock")
if err != nil {
panic(err)
}
dbPath := path.Join(dir, "shock.db")
db, err = Open(dbPath, os.ModePerm, DefaultOptions)
dbPath := path.Join(tmpDir, "shock.db")
db, err = Open(dbPath, os.ModePerm, nil)
if err != nil {
panic(err)
}
@ -27,8 +31,8 @@ func TestMain(m *testing.M) {
if err := db.Bolt.Close(); err != nil {
fmt.Printf("Could not close DB %s: %v\n", dbPath, err)
}
if err := os.RemoveAll(dir); err != nil {
fmt.Printf("Could not delete temp dir %s: %v\n", dir, err)
if err := os.RemoveAll(tmpDir); err != nil {
fmt.Printf("Could not delete temp dir %s: %v\n", tmpDir, err)
}
os.Exit(exit)

22
store.go

@ -13,9 +13,6 @@ type Sequencer interface {
// Put adds a new value to a bucket
func (d *DB) Put(bucket string, val Sequencer) error {
if err := d.initBucket(bucket); err != nil {
return err
}
return d.Bolt.Update(func(tx *bbolt.Tx) error {
b := tx.Bucket([]byte(bucket))
@ -35,9 +32,6 @@ func (d *DB) Put(bucket string, val Sequencer) error {
// PutWithKey adds a new value to the bucket with a defined key
func (d *DB) PutWithKey(bucket string, key, val interface{}) error {
if err := d.initBucket(bucket); err != nil {
return err
}
return d.Bolt.Update(func(tx *bbolt.Tx) error {
b := tx.Bucket([]byte(bucket))
@ -51,9 +45,6 @@ func (d *DB) PutWithKey(bucket string, key, val interface{}) error {
// Get returns a value from a bucket with the specified sequence ID
func (d *DB) Get(bucket string, id, val interface{}) error {
if err := d.initBucket(bucket); err != nil {
return err
}
if err := d.Bolt.View(func(tx *bbolt.Tx) error {
serial := tx.Bucket([]byte(bucket)).Get([]byte(fmt.Sprintf("%v", id)))
return d.Serializer.Unmarshal(serial, val)
@ -66,9 +57,6 @@ func (d *DB) Get(bucket string, id, val interface{}) error {
// Count returns the number of objects in a bucket
func (d *DB) Count(bucket string) (int, error) {
count := 0
if err := d.initBucket(bucket); err != nil {
return count, err
}
if err := d.Bolt.View(func(tx *bbolt.Tx) error {
return tx.Bucket([]byte(bucket)).ForEach(func(_, _ []byte) error {
count++
@ -94,9 +82,6 @@ func (d *DB) UpdateEach(bucket string, fn EachFunc) error {
}
func (d *DB) forEach(bucket string, writable bool, fn EachFunc) error {
if err := d.initBucket(bucket); err != nil {
return err
}
tx, err := d.Bolt.Begin(writable)
if err != nil {
return err
@ -110,10 +95,3 @@ func (d *DB) forEach(bucket string, writable bool, fn EachFunc) error {
}
return tx.Rollback()
}
func (d *DB) initBucket(bucket string) error {
return d.Bolt.Update(func(tx *bbolt.Tx) error {
_, err := tx.CreateBucketIfNotExists([]byte(bucket))
return err
})
}

41
store_test.go

@ -4,25 +4,25 @@ import (
"testing"
)
func TestStore(t *testing.T) {
var tt = []*TestUser{
{
Name: "user1",
Age: 25,
Admin: true,
},
{
Name: "user2",
Age: 30,
Admin: false,
},
{
Name: "user3",
Age: 40,
Admin: false,
},
}
tt := []*TestUser{
{
Name: "user1",
Age: 25,
Admin: true,
},
{
Name: "user2",
Age: 30,
Admin: false,
},
{
Name: "user3",
Age: 40,
Admin: false,
},
}
func TestStore(t *testing.T) {
ttt := []*TestUserSettings{
{
@ -36,6 +36,11 @@ func TestStore(t *testing.T) {
}
bucket := db.Bucket("test")
if err := bucket.Init(); err != nil {
t.Log(err)
t.FailNow()
}
for _, tc := range tt {
t.Run(tc.Name, func(t *testing.T) {
if err := bucket.Put(tc); err != nil {

Loading…
Cancel
Save