Browse Source

much cleaner backend

automate2
Stephen Lorenz 4 years ago
parent
commit
b8f82db8ff
  1. BIN
      cjs/.data/my_database/database.db
  2. BIN
      cjs/.data/state.db
  3. 31
      cjs/core/database.py
  4. 28
      cjs/core/state.py
  5. 291
      cjs/debug.py
  6. 26
      cjs/models/database.py
  7. 14
      cjs/models/state.py

BIN
cjs/.data/my_database/database.db

BIN
cjs/.data/state.db

31
cjs/core/database.py

@ -40,17 +40,30 @@ class Database:
return res
return wrapper
def _get(self, table):
s = self.get_session()
res = s.query(table).all()
s.close()
return res
@pass_session
def add_table(self, sess, Table):
# TODO: make insert one or many
sess.add(Table)
sess.commit()
@pass_session
def remove_table(sess, Table, expr):
sess.query(Table).filter(expr).delete()
sess.commit()
@pass_session
def update_table(sess, Table, expr):
pass
def _clear(self, table):
s = self.get_session()
s.query(table).delete()
@pass_session
def select_table(self, sess, Table, expr):
q = s.query(Table).filter(expr).all()
return q
@pass_session
def clear_table(self, sess, Table):
s.query(Table).delete()
s.commit()
s.close()
def get_tests(self):
return self._get(VaultTest)

28
cjs/core/state.py

@ -23,37 +23,35 @@ def _get_session():
# decorator to connect to the state database
# will close session after function completes
def pass_session(func):
def _pass_session(fn):
def wrapper(*args, **kwargs):
s = _get_session()
res = func(s, *args, **kwargs)
res = fn(s, *args, **kwargs)
s.close()
return res
return wrapper
# using pass_session is more convenient
@pass_session
@_pass_session
def add_state(sess, State):
'''Add a new entry to the given Table in the State database.'''
sess.add(State)
sess.commit()
@pass_session
def remove_state(sess, State, filter_expr):
@_pass_session
def remove_state(sess, State, expr=True):
'''Delete an entry that matches the given filter expression
from the given Table in the State database.'''
sess.query(State).filter(filter_expr).delete()
sess.query(State).filter(expr).delete()
sess.commit()
@pass_session
def select_state(sess, State, filter_expr):
@_pass_session
def update_state(sess, State, expr):
pass
@_pass_session
def select_state(sess, State, expr=True):
'''Return entries that match the given filter expression
from the given Table in the State database.'''
q = sess.query(State).filter(filter_expr).all() # TODO: make return first, all, and reveresed
return q
@pass_session
def list_state(sess, State):
'''Return all entires in the given Table in the State database.'''
q = sess.query(State).all()
q = sess.query(State).filter(expr).all() # TODO: make return first, all, and reveresed
return q

291
cjs/debug.py

@ -0,0 +1,291 @@
#!/usr/bin/env python3
# standard library modules
import functools
from functools import wraps
# pip modules
import sqlalchemy
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import scoped_session
from sqlalchemy.exc import IntegrityError
# local modules
import core.state
from core.state import add_state
from core.state import remove_state
from core.state import select_state
import models.state
from models.state import DatabaseState
import models.database
from models.database import *
# BEGIN: Generic Database API
# modified from https://stackoverflow.com/a/36944992
# NOTE: ideally this would be inside the database class
def _pass_session(fn):
@wraps(fn) # get around the fact a decorator cannot have self
def wrapper(self, *fn_args, **fn_kwargs):
s = self._get_session()
res = fn(self, s, *fn_args, **fn_kwargs)
s.close()
return res
return wrapper
class Database:
def __init__(self, location):
self._engine = create_engine('sqlite:///%s/database.db' % location)
self._SessionFactory = sessionmaker(bind=self._engine)
def _get_session(self):
DatabaseBase.metadata.create_all(self._engine)
return self._SessionFactory()
# TODO: replace this function with a better method of doing this
@_pass_session
def initialize(self, sess):
# NOTE: this method was created because I was unable to get sqlalchemy's
# initial values to work
names = ['Pending', 'Processing', 'Complete', 'Failed']
events = [EventTable(n) for n in names]
self.add_tables(events)
@_pass_session
def add_tables(self, sess, Tables):
try:
for T in Tables:
sess.add(T)
sess.commit()
except:
# TODO: implement fine-grained handling
raise
@_pass_session
def remove_table(sess, Table, expr=True):
try:
sess.query(Table).filter(expr).delete()
sess.commit()
except:
# TODO: implement fine-grained handling
raise
@_pass_session
def update_table(sess, Table, expr):
pass
@_pass_session
def select_tables(self, sess, Table, expr=True):
try:
q = sess.query(Table).filter(expr).all()
return q
except:
# TODO: implement fine-grained handling
raise
# END: Generic Database API
# BEGIN: Internal functions
# BEGIN: Database object functions
def _get_database(name):
try:
ds = select_state(DatabaseState,
DatabaseState.name == name)
db = Database(ds[0].location) # select_state returns a list
return db
except:
# TODO: implement fine-grained handling
raise
# TODO: name is easily confused with add_database, consider refactoring
def _database_add(name, Tables):
try:
db = _get_database(name)
db.add_tables(Tables)
except:
# TODO: implement fine-grained handling
raise
# TODO: name is easily confused with remove_database, consider refactoring
def _database_remove(name, Table):
try:
db = _get_database(name)
# remove_table's expr parameter is true by default
# and will remove everything in the given table if
# not passed another value
db.remove_table(Table)
except:
# TODO: implement fine-grained handling
raise
# TODO: name is easily confused with add_database, consider refactoring
def _database_update(name, Table):
try:
db = _get_database(name)
pass
except:
# TODO: implement fine-grained handling
raise
# TODO: name is easily confused with add_database, consider refactoring
def _database_select(name, Table, expr):
try:
db = _get_database(name)
Ts = db.select_tables(Table, expr) # returns list of Tables
return Ts
except:
# TODO: implement fine-grained handling
raise
# END: Database object functions
# END: Internal functions
# BEGIN: External functions
# BEGIN: Database State functions
def add_database(name, location):
try:
ds = DatabaseState(name, location)
add_state(ds)
# TODO: replace with a better method of doing this
db = _get_database(name)
db.initialize()
except:
# TODO: implement fine-grained handling
raise
def remove_database(name):
try:
remove_state(DatabaseState,
DatabaseState.name == name)
except:
# TODO: implement fine-grained handling
raise
def update_database(name):
try:
pass
except:
# TODO: implement fine-grained handling
raise
def list_database():
try:
# select_state's expr parameter is true by default
# and will return everything in the state table if
# not passed another value
ds = select_state(DatabaseState)
return ds
except:
# TODO: implement fine-grained handling
raise
# END: Database State functions
# END: External functions
def tester():
name = 'debug'
#location = '.data/my_database'
#add_database(name, location)
#print(list_database())
#database_add(name, [EvidenceTable('my_file', '/path/to/file')])
print(_database_select(name, EventTable, EventTable.name == 'Pending')[0].name)
tester()
'''
def stage(name, batch):
vault_state = _state_get_by_name(name)
vault = Vault(vault_state.path)
batch_id = vault.insert(VaultBatch(batch))
tests = vault.get_tests()
l = [VaultJob(t.id, batch_id) for t in tests]
vault.insert(l)
def upload(name, evidence, comparisons):
vault_state = _state_get_by_name(name)
vault_root = vault_state.root
vault = Vault(vault_state.path)
evidence_name = evidence['name']
evidence_file = '%s/input/%s.json' % (vault_root, evidence_name)
evidence_file = str(abs_path(evidence_file))
write_json(evidence_file, evidence)
evidence_id = vault.insert(VaultEvidence(evidence_name, evidence_file))
for c in comparisons:
comparison_name = c['name']
comparison_file = '%s/input/%s.json' % (vault_root, comparison_name)
comparison_file = str(abs_path(comparison_file))
write_json(comparison_file, c)
comparison_id = vault.insert(VaultComparison(comparison_name, comparison_file))
vault.insert(VaultTest(evidence_id, comparison_id))
def next_job(name):
vault_state = _state_get_by_name(name)
vault = Vault(vault_state.path)
s = vault.get_session()
job_row = s.query(VaultJob).filter_by(event_id=1).first()
test_row = s.query(VaultTest).filter_by(id=job_row.test_id).first()
evidence = s.query(VaultEvidence).filter_by(id=test_row.evidence_id).first()
comparison = s.query(VaultComparison).filter_by(id=test_row.comparison_id).first()
s.close()
return [evidence, comparison]
def all_jobs(name):
vault_state = _state_get_by_name(name)
vault = Vault(vault_state.path)
s = vault.get_session()
jobs = []
for job_row in s.query(VaultJob).filter_by(event_id=1).all():
test_row = s.query(VaultTest).filter_by(id=job_row.test_id).first()
evidence = s.query(VaultEvidence).filter_by(id=test_row.evidence_id).first()
comparison = s.query(VaultComparison).filter_by(id=test_row.comparison_id).first()
jobs.append([evidence, comparison])
s.close()
return jobs
def insert_result(vault, name, path):
# get the current batch's id
batch_id = vault.select_last(VaultBatch).id
r = VaultResult(name, path, bath_id)
vault.insert(r)
def copy(name, destination):
vault = _state_get_by_name(name)
copy_file(vault.path, destination)
def status(name):
status = {}
vault_state = _state_get_by_name(name)
vault = Vault(vault_state.path)
s = vault.get_session()
status['num_jobs'] = vault.sizeof(VaultJob)
if status['num_jobs'] == 0:
return status
status['num_pending'] = s.query(VaultJob).filter_by(event_id=1).count()
status['num_processing'] = s.query(VaultJob).filter_by(event_id=2).count()
status['num_complete'] = s.query(VaultJob).filter_by(event_id=3).count()
job_row = s.query(VaultJob).filter_by(event_id=1).first()
test_row = s.query(VaultTest).filter_by(id=job_row.test_id).first()
evidence_name = s.query(VaultEvidence).filter_by(id=test_row.evidence_id).first().name
comparison_name = s.query(VaultComparison).filter_by(id=test_row.comparison_id).first().name
status['next_job'] = [evidence_name, comparison_name]
s.close()
return status
'''

26
cjs/models/database.py

@ -7,12 +7,12 @@ from sqlalchemy import Column, Integer, String, ForeignKey, DateTime
DatabaseBase = declarative_base()
class VaultEvidence(DatabaseBase):
class EvidenceTable(DatabaseBase):
'''Used to store evidence files.'''
__tablename__ = 'Evidence'
id = Column('id', Integer, primary_key=True)
id_ = Column('id', Integer, primary_key=True)
name = Column('name', String(32), unique=True)
path = Column('path', String(260), unique=True)
@ -28,7 +28,7 @@ class VaultComparison(DatabaseBase):
__tablename__ = 'Comparison'
id = Column('id', Integer, primary_key=True)
id_ = Column('id', Integer, primary_key=True)
name = Column('name', String(32), unique=True)
path = Column('path', String(260), unique=True)
@ -44,7 +44,7 @@ class VaultTest(DatabaseBase):
__tablename__ = 'Test'
id = Column('id', Integer, primary_key=True)
id_ = Column('id', Integer, primary_key=True)
evidence_id = Column('evidence_id', ForeignKey('Evidence.id'))
comparison_id = Column('comparison_id', ForeignKey('Comparison.id'))
@ -55,12 +55,12 @@ class VaultTest(DatabaseBase):
def __str__(self):
return '%s, %s, %s' % (self.id, self.evidence_id, self.comparison_id)
class VaultEvent(DatabaseBase):
class EventTable(DatabaseBase):
'''Used to identify which set that a result belongs to.'''
__tablename__ = 'Event'
id = Column('id', Integer, primary_key=True)
id_ = Column('id', Integer, primary_key=True)
name = Column('name', String(32), unique=True)
def __init__(self, name):
@ -69,12 +69,12 @@ class VaultEvent(DatabaseBase):
def __str__(self):
return '%s, %s' % (self.id, self.name)
class VaultJob(DatabaseBase):
class JobTable(DatabaseBase):
'''Used to store comparison files.'''
__tablename__ = 'Job'
id = Column('id', Integer, primary_key=True)
id_ = Column('id', Integer, primary_key=True)
test_id = Column('test_id', ForeignKey('Test.id'))
event_id = Column('event_id', ForeignKey('Event.id'), default=1)
batch_id = Column('batch_id', ForeignKey('Batch.id'))
@ -84,12 +84,12 @@ class VaultJob(DatabaseBase):
self.test_id = test_id
self.batch_id = batch_id
class VaultBatch(DatabaseBase):
class BatchTable(DatabaseBase):
'''Used to identify which set that a result belongs to.'''
__tablename__ = 'Batch'
id = Column('id', Integer, primary_key=True)
id_ = Column('id', Integer, primary_key=True)
name = Column('name', String(32), unique=True)
def __init__(self, name):
@ -98,16 +98,16 @@ class VaultBatch(DatabaseBase):
def __str__(self):
return '%s, %s' % (self.id, self.name)
class VaultResult(DatabaseBase):
class ResultTable(DatabaseBase):
'''Used to store output files from completed jobs.'''
__tablename__ = 'Result'
id = Column('id', Integer, primary_key=True)
id_ = Column('id', Integer, primary_key=True)
batch_id = Column('batch_id', ForeignKey('Batch.id'))
name = Column('name', String(32))
path = Column('path', String(260))
processing_time = Column('processing_time', Integer())
batch_id = Column('batch_id', ForeignKey('Batch.id'))
def __init__(self, name, path, batch_id, time):
self.name = name

14
cjs/models/state.py

@ -5,15 +5,13 @@ from sqlalchemy import Column, Integer, String
StateBase = declarative_base()
class StateVault(StateBase):
__tablename__ = 'Vault'
class DatabaseState(StateBase):
__tablename__ = 'Database'
id = Column('id', Integer, primary_key=True)
id_ = Column('id', Integer, primary_key=True)
name = Column('name', String(32), unique=True)
path = Column('path', String(260), unique=True)
root = Column('root', String(260), unique=True)
location = Column('location', String(260), unique=True)
def __init__(self, name, path, root):
def __init__(self, name, location):
self.name = name
self.root = root
self.path = path
self.location = location
Loading…
Cancel
Save