diff --git a/fastdb_get_alerts/load_ppdb_from_alerts.py b/fastdb_get_alerts/load_ppdb_from_alerts.py index 4a83a39f..a11418e7 100644 --- a/fastdb_get_alerts/load_ppdb_from_alerts.py +++ b/fastdb_get_alerts/load_ppdb_from_alerts.py @@ -1,693 +1 @@ -import sys -import pathlib -import logging -import fastavro -import json -import multiprocessing -import signal -import time -import confluent_kafka -import io -import os -import re -import traceback -import datetime -import collections -import atexit -from optparse import OptionParser -from psycopg2.extras import execute_values -from psycopg2 import sql -import psycopg2 - -_rundir = pathlib.Path(__file__).parent -print(_rundir) - -_logger = logging.getLogger(__name__) -if not _logger.hasHandlers(): - _logout = logging.FileHandler( _rundir / f"logs/alerts.log" ) - _logger.addHandler( _logout ) - _formatter = logging.Formatter( f'[msgconsumer - %(asctime)s - %(levelname)s] - %(message)s', datefmt='%Y-%m-%d %H:%M:%S' ) - _logout.setFormatter( _formatter ) -_logger.setLevel( logging.DEBUG ) - -def _donothing( *args, **kwargs ): - pass - -def close_msg_consumer( obj ): - obj.close() - -class MsgConsumer(object): - def __init__( self, server, groupid, schema, topics=None, extraconsumerconfig=None,consume_nmsgs=10, consume_timeout=5, nomsg_sleeptime=1, logger=_logger ): - """Wraps a confluent_kafka.Consumer. - - server : the bootstrap.servers value - groupid : the group.id value - schema : filename where the schema of messages to be consumed can be found - topics : topic name, or list of topic names, to subscribe to - extraconsumerconfig : (optional) additional consumer config (dict) - consume_nmsgs : number of messages to pull from the server at once (default 10) - consume_timeout : timeout after waiting on the server for this many seconds - nomsg_sleeptime : sleep for this many seconds after a consume_timeout before trying again - logger : a logging object - - """ - - self.consumer = None - self.logger = logger - self.tot_handled = 0 - - self.schema = fastavro.schema.load_schema( schema ) - self.consume_nmsgs = consume_nmsgs - self.consume_timeout = consume_timeout - self.nomsg_sleeptime = nomsg_sleeptime - - consumerconfig = { "bootstrap.servers": server, - "auto.offset.reset": "earliest", - "group.id": groupid, - } - - if extraconsumerconfig is not None: - consumerconfig.update( extraconsumerconfig ) - - self.logger.debug( f'Initializing Kafka consumer with\n{json.dumps(consumerconfig, indent=4)}' ) - self.consumer = confluent_kafka.Consumer( consumerconfig ) - atexit.register( close_msg_consumer, self ) - - self.subscribed = False - self.subscribe( topics ) - - def close( self ): - if self.consumer is not None: - self.logger.info( "Closing MsgConsumer" ) - self.consumer.close() - self.consumer = None - - def subscribe( self, topics ): - if topics is None: - self.topics = [] - elif isinstance( topics, str ): - self.topics = [ topics ] - elif isinstance( topics, collections.abc.Sequence ): - self.topics = list( topics ) - else: - raise ValueError( f'topics must be either a string or a list' ) - - servertopics = self.get_topics() - subtopics = [] - for topic in self.topics: - if topic not in servertopics: - self.logger.warning( f'Topic {topic} not on server, not subscribing' ) - else: - subtopics.append( topic ) - self.topics = subtopics - - #for topic in self.topics: - # st = [i for i in servertopics if topic in i] - # if len(st) !=0: - # for t in st: - # subtopics.append(t) - # else: - # self.logger.warning( f'Topic {topic} not on server, not 
subscribing' ) - #self.topics = subtopics - - if self.topics is not None and len(self.topics) > 0: - self.logger.info( f'Subscribing to topics: {", ".join( self.topics )}' ) - self.consumer.subscribe( self.topics, on_assign=self._sub_callback ) - else: - self.logger.warning( f'No existing topics given, not subscribing.' ) - - def get_topics( self ): - cluster_meta = self.consumer.list_topics() - return [ n for n in cluster_meta.topics ] - - def print_topics( self, newlines=False ): - topics = self.get_topics() - if not newlines: - self.logger.info( f"\nTopics: {', '.join(topics)}" ) - else: - topicstr = '\n '.join( topics ) - self.logger.info( f"\nTopics:\n {topicstr}" ) - - def _get_positions( self, partitions ): - return self.consumer.position( partitions ) - - def _dump_assignments( self, ofp, partitions ): - ofp.write( f'{"Topic":<32s} {"partition":>9s} {"offset":>12s}\n' ) - for par in partitions: - ofp.write( f"{par.topic:32s} {par.partition:9d} {par.offset:12d}\n" ) - ofp.write( "\n" ) - - def print_assignments( self ): - asmgt = self._get_positions( self.consumer.assignment() ) - ofp = io.StringIO() - ofp.write( "Current partition assignments\n" ) - self._dump_assignments( ofp, asmgt ) - self.logger.info( ofp.getvalue() ) - ofp.close() - - def _sub_callback( self, consumer, partitions ): - self.subscribed = True - ofp = io.StringIO() - ofp.write( "Consumer subscribed. Assigned partitions:\n" ) - self._dump_assignments( ofp, self._get_positions( partitions ) ) - self.logger.info( ofp.getvalue() ) - ofp.close() - - def reset_to_start( self, topic ): - partitions = self.consumer.list_topics( topic ).topics[topic].partitions - self.logger.info( f'Resetting partitions for topic {topic}' ) - # partitions is a map - partlist = [] - # Must consume one message to really hook up to the topic - self.consume_one_message( handler=_donothing, timeout=10 ) - for i in range(len(partitions)): - self.logger.info( f'...resetting partition {i}' ) - curpart = confluent_kafka.TopicPartition( topic, i ) - lowmark, highmark = self.consumer.get_watermark_offsets( curpart ) - self.logger.debug( f'Partition {curpart.topic} has id {curpart.partition} ' - f'and current offset {curpart.offset}; lowmark={lowmark} ' - f'and highmark={highmark}' ) - curpart.offset = lowmark - if lowmark < highmark: - self.consumer.seek( curpart ) - partlist.append( curpart ) - self.logger.info( f'Committing partition offsets.' ) - self.consumer.commit( offsets=partlist ) - self.tot_handled = 0 - - def consume_one_message( self, timeout=None, handler=None ): - """Both calls handler and returns a batch of 1 message.""" - timeout = self.consume_timeout if timeout is None else timeout - self.logger.info( f"Trying to consume one message with timeout {timeout}...\n" ) - msgs = self.consumer.consume( 1, timeout=timeout ) - if len(msgs) == 0: - return None - else: - self.tot_handled += len(msgs) - if handler is not None: - handler( msgs ) - else: - self.default_handle_message_batch( msgs ) - - def default_handle_message_batch( self, msgs ): - self.logger.info( f'Got {len(msgs)}; have received {self._tot_handled} so far.' 
) - - def echoing_handle_message_batch( self, msgs ): - self.logger.info( f'Handling {len(msgs)} messages' ) - for msg in msgs: - ofp = io.StringIO( f"Topic: {msg.topic()} ; Partition: {msg.partition()} ; " - f"Offset: {msg.offset()} ; Key: {msg.key()}\n" ) - alert = fastavro.schemaless_reader( io.BytesIO(msg.value()), self.schema ) - ofp.write( json.dumps( alert, indent=4, sort_keys=True ) ) - ofp.write( "\n" ) - self.logger.info( ofp.getvalue() ) - ofp.close() - self.logger.info( f'Have handled {self.tot_handled} messages so far' ) - - def poll_loop( self, handler=None, max_consumed=None, pipe=None, max_runtime=datetime.timedelta(hours=1) ): - """Calls handler with batches of messages. - - handler : a callback that's called with batches of messages (the list - returned by confluent_kafka.Consumer.consume(). - max_consumed : Quit polling after this many messages have been - consumed (default: no limit) - pipe : A pipe to send regular heartbeats to, and to listen for "die" messages from. - max_runtime : Quit polling after this much time has elapsed; - must be a datetime.timedelta object. (Default: 1h.) - - returns True if consumed ?max_consumed or timed out, False if died due to die command - """ - nconsumed = 0 - starttime = datetime.datetime.now() - keepgoing = True - retval = True - while keepgoing: - self.logger.debug( f"Trying to consume {self.consume_nmsgs} messages " - f"with timeout {self.consume_timeout}..." ) - msgs = self.consumer.consume( self.consume_nmsgs, timeout=self.consume_timeout ) - if len(msgs) == 0: - self.logger.debug( f"No messages, sleeping {self.nomsg_sleeptime} sec" ) - time.sleep( self.nomsg_sleeptime ) - else: - self.logger.debug( f"...got {len(msgs)} messages" ) - self.tot_handled += len(msgs) - if handler is not None: - handler( msgs ) - else: - self.default_handle_message_batch( msgs ) - nconsumed += len( msgs ) - runtime = datetime.datetime.now() - starttime - if ( ( ( max_consumed is not None ) and ( nconsumed >= max_consumed ) ) - or - ( ( max_runtime is not None ) and ( runtime > max_runtime ) ) ): - keepgoing = False - if pipe is not None: - pipe.send( { "message": "ok", "nconsumed": nconsumed, "runtime": runtime } ) - if pipe.poll(): - msg = pipe.recv() - if ( 'command' in msg ) and ( msg['command'] == 'die' ): - self.logger.info( "Exiting poll loop due to die command." 
) - retval = False - keepgoing = False - else: - self.logger.error( f"Received unknown message from pipe, ignoring: {msg}" ) - - self.logger.info( f"Stopping poll loop after consuming {nconsumed} messages during {runtime}" ) - return retval - - -class BrokerConsumer: - def __init__( self, server, groupid, topics=None, updatetopics=False, - schemaless=True, reset=False, extraconfig={}, - schemafile=None, pipe=None, loggername="BROKER", **kwargs ): - - self.logger = logging.getLogger( loggername ) - self.logger.propagate = False - logout = logging.FileHandler( _rundir / f"logs/alerts.log" ) - self.logger.addHandler( logout ) - formatter = logging.Formatter( f'[%(asctime)s - {loggername} - %(levelname)s] - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' ) - logout.setFormatter( formatter ) - self.logger.setLevel( logging.DEBUG ) - - if schemafile is None: - schemafile = _rundir / "elasticc.v0_9_1.alert.avsc" - - self.countlogger = logging.getLogger( f"countlogger_{loggername}" ) - self.countlogger.propagate = False - _countlogout = logging.FileHandler( _rundir / f"logs/alertpoll_counts_{loggername}.log" ) - _countformatter = logging.Formatter( f'[%(asctime)s - %(levelname)s] - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' ) - _countlogout.setFormatter( _countformatter ) - self.countlogger.addHandler( _countlogout ) - self.countlogger.setLevel( logging.DEBUG ) - - self.countlogger.info( f"************ Starting Brokerconsumer for {loggername} ****************" ) - - self.pipe = pipe - self.server = server - self.groupid = groupid - self.topics = topics - self._updatetopics = updatetopics - self._reset = reset - self.extraconfig = extraconfig - - self.schemaless = schemaless - if not self.schemaless: - self.countlogger.error( "CRASHING. I only know how to handle schemaless streams." ) - raise RuntimeError( "I only know how to handle schemaless streams" ) - self.schemafile = schemafile - self.schema = fastavro.schema.load_schema( self.schemafile ) - - self.nmessagesconsumed = 0 - - - @property - def reset( self ): - return self._reset - - @reset.setter - def reset( self, val ): - self._reset = val - - def create_connection( self ): - countdown = 5 - while countdown >= 0: - try: - self.consumer = MsgConsumer( self.server, self.groupid, self.schemafile, self.topics, extraconsumerconfig=self.extraconfig, consume_nmsgs=1000, consume_timeout=1, nomsg_sleeptime=5, logger=self.logger ) - countdown = -1 - except Exception as e: - countdown -= 1 - strio = io.StringIO("") - strio.write( f"Exception connecting to broker: {str(e)}" ) - traceback.print_exc( file=strio ) - self.logger.warning( strio.getvalue() ) - if countdown >= 0: - self.logger.warning( "Sleeping 5s and trying again." ) - time.sleep(5) - else: - self.logger.error( "Repeated exceptions connecting to broker, punting." ) - self.countlogger.error( "Repeated exceptions connecting to broker, punting." ) - raise RuntimeError( "Failed to connect to broker" ) - - if self._reset and ( self.topics is not None ): - self.countlogger.info( f"*************** Resetting to start of broker kafka stream ***************" ) - self.reset_to_start() - # Only want to reset the first time the connection is opened! 
- self._reset = False - - self.countlogger.info( f"**************** Consumer connection opened *****************" ) - - def close_connection( self ): - self.countlogger.info( f"**************** Closing consumer connection ******************" ) - self.consumer.close() - self.consumer = None - - - def reset_to_start( self ): - self.logger.info( "Resetting all topics to start" ) - for topic in self.topics: - self.consumer.reset_to_start( topic ) - - def handle_message_batch( self, msgs ): - messagebatch = [] - self.countlogger.info( f"Handling {len(msgs)} messages; consumer has received " - f"{self.consumer.tot_handled} messages." ) - list_diaObject = [] - list_diaSource = [] - list_diaForcedSource = [] - - for msg in msgs: - timestamptype, timestamp = msg.timestamp() - - - if timestamptype == confluent_kafka.TIMESTAMP_NOT_AVAILABLE: - timestamp = None - else: - timestamp = datetime.datetime.fromtimestamp( timestamp / 1000, tz=datetime.timezone.utc ) - - payload = msg.value() - if not self.schemaless: - self.countlogger.error( "I only know how to handle schemaless streams" ) - raise RuntimeError( "I only know how to handle schemaless streams" ) - alert = fastavro.schemaless_reader( io.BytesIO( payload ), self.schema ) - - - diaObject = [] - diaSource = [] - diaForcedSource = [] - - diaObject.append(alert['diaObject']['diaObjectId']) - diaObject.append(timestamp) - diaObject.append(alert['diaObject']['ra']) - diaObject.append(alert['diaObject']['decl']) - - # Now store DiaSources and DiaForcedSources - - diaSource.append(alert['diaSource']['diaSourceId']) - diaSource.append(alert['diaSource']['diaObjectId']) - diaSource.append(alert['diaSource']['midPointTai']) - diaSource.append(alert['diaSource']['ra']) - diaSource.append(alert['diaSource']['decl']) - diaSource.append(alert['diaSource']['psFlux']) - diaSource.append(alert['diaSource']['psFluxErr']) - diaSource.append(alert['diaSource']['snr']) - diaSource.append(alert['diaSource']['filterName']) - diaSource.append(timestamp) - - list_diaObject.append(diaObject) - list_diaSource.append(diaSource) - - - for s in alert['prvDiaSources']: - - diaSource = [] - ccdVisit = [] - diaSource.append(s['diaSourceId']) - diaSource.append(s['diaObjectId']) - diaSource.append(s['midPointTai']) - diaSource.append(s['ra']) - diaSource.append(s['decl']) - diaSource.append(s['psFlux']) - diaSource.append(s['psFluxErr']) - diaSource.append(s['snr']) - diaSource.append(s['filterName']) - diaSource.append(timestamp) - - list_diaSource.append(diaSource) - - - for s in alert['prvDiaForcedSources']: - - diaForcedSource = [] - ccdVisit = [] - diaForcedSource.append(s['diaForcedSourceId']) - diaForcedSource.append(s['diaObjectId']) - diaForcedSource.append(s['psFlux']) - diaForcedSource.append(s['psFluxErr']) - diaForcedSource.append(s['filterName']) - diaForcedSource.append(timestamp) - - list_diaForcedSource.append(diaForcedSource) - - - # Connect to the fake PPDB - - # Get password - - secret = os.environ['PPDB_WRITER_PASSWORD'] - conn_string = "host='fastdb-ppdb-psql' dbname='ppdb' user='ppdb_writer' password='%s'" % secret.strip() - self.logger.info("Connecting to database %s" % conn_string) - conn = psycopg2.connect(conn_string) - - cursor = conn.cursor() - self.logger.info("Connected") - - names = ['diaObjectId','validityStart','ra','decl'] - - query = sql.SQL("INSERT INTO {} ({}) VALUES ({})").format(sql.Identifier('DiaObject'),sql.SQL(', ').join(map(sql.Identifier, names)),sql.SQL(', ').join(map(sql.Placeholder, names))) - 
self.logger.info(query.as_string(conn)) - q = sql.SQL("SELECT {} from {} where {} = %s").format(sql.Identifier('diaObjectId'),sql.Identifier('DiaObject'),sql.Identifier('diaObjectId')) - self.logger.info(q.as_string(conn)) - for r in list_diaObject: - # Make sure we haven't already added this object - - cursor.execute(q,(r[0],)) - if cursor.rowcount == 0: - cursor.execute(query, {'diaObjectId':r[0], 'validityStart':r[1], 'ra':r[2], 'decl':r[3]}) - conn.commit() - - # Now add data for each source and forced source - - names = ['diaSourceId','filterName','diaObjectId','midPointTai', 'ra','decl','psFlux', 'psFluxSigma', 'snr', 'observeDate'] - query = sql.SQL( "INSERT INTO {} ({}) VALUES ({})").format(sql.Identifier('DiaSource'),sql.SQL(',').join(map(sql.Identifier, names)),sql.SQL(',').join(map(sql.Placeholder, names))) - q = sql.SQL("SELECT {} from {} where {} = %s").format(sql.Identifier('diaSourceId'),sql.Identifier('DiaSource'),sql.Identifier('diaSourceId')) - self.logger.info(query.as_string(conn)) - for r in list_diaSource: - # Make sure we haven't already added this object - - cursor.execute(q,(r[0],)) - if cursor.rowcount == 0: - cursor.execute(query, {'diaSourceId':r[0],'filterName':r[8],'diaObjectId':r[1],'midPointTai':r[2], 'ra':r[3],'decl':r[4],'psFlux':r[5], 'psFluxSigma':r[6], 'snr':r[7], 'observeDate':r[9]}) - conn.commit() - - names = ['diaForcedSourceId','filterName','diaObjectId','psFlux', 'psFluxSigma', 'observeDate'] - query = sql.SQL( "INSERT INTO {} ({}) VALUES ({})").format(sql.Identifier('DiaForcedSource'),sql.SQL(',').join(map(sql.Identifier, names)),sql.SQL(',').join(map(sql.Placeholder, names))) - self.logger.info(query.as_string(conn)) - q = sql.SQL("SELECT {} from {} where {} = %s").format(sql.Identifier('diaForcedSourceId'),sql.Identifier('DiaForcedSource'),sql.Identifier('diaForcedSourceId')) - for r in list_diaForcedSource: - # Make sure we haven't already added this object - - cursor.execute(q,(r[0],)) - if cursor.rowcount == 0: - cursor.execute(query, {'diaForcedSourceId':r[0],'filterName':r[4],'diaObjectId':r[1],'psFlux':r[2], 'psFluxSigma':r[3], 'observeDate':r[5]}) - - conn.commit() - cursor.close() - conn.close() - - - - def poll( self, restart_time=datetime.timedelta(minutes=30) ): - self.create_connection() - while True: - if self._updatetopics: - self.update_topics() - strio = io.StringIO("") - if len(self.consumer.topics) == 0: - self.logger.info( "No topics, will wait 10s and reconnect." ) - time.sleep(10) - else: - self.logger.info( f"Subscribed to topics: {self.consumer.topics}; starting poll loop." ) - self.countlogger.info( f"Subscribed to topics: {self.consumer.topics}; starting poll loop." ) - try: - happy = self.consumer.poll_loop( handler=self.handle_message_batch, - max_consumed=None, max_runtime=restart_time, - pipe=self.pipe ) - if happy: - strio.write( f"Reached poll timeout for {self.server}; " - f"handled {self.consumer.tot_handled} messages. " ) - else: - strio.write( f"Poll loop received die command after handling " - f"{self.consumer.tot_handled} messages. Exiting." ) - self.logger.info( strio.getvalue() ) - self.countlogger.info( strio.getvalue() ) - self.close_connection() - return - except Exception as e: - otherstrio = io.StringIO("") - traceback.print_exc( file=otherstrio ) - self.logger.warning( otherstrio.getvalue() ) - strio.write( f"Exception polling: {str(e)}. 
" ) - - if self.pipe.poll(): - msg = self.pipe.recv() - if ( 'command' in msg ) and ( msg['command'] == 'die' ): - self.logger.info( "No topics, but also exiting broker poll due to die command." ) - self.countlogger.info( "No topics, but also existing broker poll due to die command." ) - self.close_connection() - return - strio.write( "Reconnecting.\n" ) - self.logger.info( strio.getvalue() ) - self.countlogger.info( strio.getvalue() ) - self.close_connection() - if self._updatetopics: - self.topics = None - self.create_connection() - - - def store(self, messages = None): - - messagebatch = messages - results = self.collection.insert_many(messagebatch) - count = len(results.inserted_ids) - self.logger.info(f"Inserted {count} messages") - - return - -class PublicZTFConsumer(BrokerConsumer): - - def __init__( self, loggername="PUBLIC_ZTF", **kwargs ): - - server = "public.alerts.ztf.uw.edu:9092" - #server = "127.0.0.0:9092" - groupid = "elasticc-lbnl" - topics = ['elasticc-1'] - #server = "kafka.antares.noirlab.edu:9092" - updatetopics = False - extraconfig = {} - - super().__init__( server, groupid, topics=topics, updatetopics=updatetopics, extraconfig=extraconfig, loggername=loggername, **kwargs ) - self.logger.info( f"Public ZTF group id is {groupid}" ) - - -class Broker(object): - - def __init__( self, reset=False, *args, **kwargs ): - - self.logger = logging.getLogger( "brokerpoll_baselogger" ) - self.logger.propagate = False - logout = logging.FileHandler( _rundir / f"logs/alertpoll.log" ) - self.logger.addHandler( logout ) - formatter = logging.Formatter( f'[%(asctime)s - alertpoll - %(levelname)s] - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' ) - logout.setFormatter( formatter ) - self.logger.setLevel( logging.DEBUG ) - self.reset = reset - - - def sigterm( self, sig="TERM" ): - self.logger.warning( f"Got a {sig} signal, trying to die." ) - self.mustdie = True - - def launch_broker( self, brokerclass, pipe, **options ): - signal.signal( signal.SIGINT, - lambda sig, stack: self.logger.warning( f"{brokerclass.__name__} ignoring SIGINT" ) ) - signal.signal( signal.SIGTERM, - lambda sig, stack: self.logger.warning( f"{brokerclass.__name__} ignoring SIGTERM" ) ) - consumer = brokerclass( pipe=pipe ) - consumer.poll() - - def broker_poll( self, *args, **options ): - self.logger.info( "******** brokerpoll starting ***********" ) - - self.mustdie = False - signal.signal( signal.SIGTERM, lambda sig, stack: self.sigterm( "TERM" ) ) - signal.signal( signal.SIGINT, lambda sig, stack: self.sigterm( "INT" ) ) - - #brokerstodo = { 'antares': AntaresConsumer, - # 'fink': FinkConsumer, - # 'alerce': AlerceConsumer, - # 'ztf': PublicZTFConsumer } - brokerstodo = {'ztf': PublicZTFConsumer } - - brokers = {} - - # Launch a process for each broker that will poll that broker indefinitely - - for name,brokerclass in brokerstodo.items(): - self.logger.info( f"Launching thread for {name}" ) - parentconn, childconn = multiprocessing.Pipe() - proc = multiprocessing.Process( target=self.launch_broker(brokerclass, childconn, **options) ) - proc.start() - brokers[name] = { "process": proc, - "pipe": parentconn, - "lastheartbeat": time.monotonic() } - - # Listen for a heartbeat from all processes. - # If we don't get a heartbeat for 5min, - # kill that process and restart it. 
- - heartbeatwait = 2 - toolongsilent = 300 - while not self.mustdie: - try: - pipelist = [ b['pipe'] for i,b in brokers.items() ] - whichpipe = multiprocessing.connection.wait( pipelist, timeout=heartbeatwait ) - - brokerstorestart = set() - for name, broker in brokers.items(): - try: - while broker['pipe'].poll(): - msg = broker['pipe'].recv() - if ( 'message' not in msg ) or ( msg['message'] != "ok" ): - self.logger.error( f"Got unexpected message from thread for {name}; " - f"will restart: {msg}" ) - brokerstorestart.add( name ) - else: - self.logger.debug( f"Got heartbeat from {name}" ) - broker['lastheartbeat'] = time.monotonic() - except Exception as ex: - self.logger.error( f"Got exception listening for heartbeat from {name}; will restart." ) - brokerstorestart.add( name ) - - for name, broker in brokers.items(): - dt = time.monotonic() - broker['lastheartbeat'] - if dt > toolongsilent: - self.logger.error( f"It's been {dt:.0f} seconds since last heartbeat from {name}; "f"will restart." ) - brokerstorestart.add( name ) - - for torestart in brokerstorestart: - self.logger.warning( f"Killing and restarting process for {torestart}" ) - brokers[torestart]['process'].kill() - brokers[torestart]['pipe'].close() - del brokers[torestart] - parentconn, childconn = multiprocessing.Pipe() - proc = multiprocessing.Process( target=lambda: self.launch_broker( brokerstodo[torestart], - childconn, **options ) ) - proc.start() - brokers[torestart] = { "process": proc, - "pipe": parentconn, - "lastheartbeat": time.monotonic() } - except Exception as ex: - self.logger.exception( "brokerpoll got an exception, going to shut down." ) - self.mustdie = True - - # I chose 20s since kubernetes sends a TERM and then waits 30s before shutting things down - self.logger.warning( "Shutting down. Sending die to all processes and waiting 20s" ) - for name, broker in brokers.items(): - broker['pipe'].send( { "command": "die" } ) - time.sleep( 20 ) - self.logger.warning( "Exiting." ) - return - - -if __name__ == '__main__': - - logger = logging.getLogger( "brokerpoll_baselogger" ) - logger.propagate = False - logout = logging.FileHandler( _rundir / f"logs/alertpoll.log" ) - logger.addHandler( logout ) - formatter = logging.Formatter( f'[%(asctime)s - alertrpoll - %(levelname)s] - %(message)s',datefmt='%Y-%m-%d %H:%M:%S' ) - logout.setFormatter( formatter ) - logger.setLevel( logging.DEBUG ) - - - parser = OptionParser() - parser.add_option('-r', '--reset', action='store_true', default=False, help='Reset all stream pointers') - - (options, args) = parser.parse_args() - - broker = Broker(reset=options.reset) - - poll = broker.broker_poll(reset=options.reset) +raise RuntimeError( 'Deprecated, use the fastdb_dev "fastdb_dev_brokerpoll" django management command.' ) diff --git a/tests/alertcycle_testbase.py b/tests/alertcycle_testbase.py new file mode 100644 index 00000000..8e138815 --- /dev/null +++ b/tests/alertcycle_testbase.py @@ -0,0 +1,536 @@ +# A base class for any class that tests the alert cycle. Right now there are two, in +# test_elasticc2_alertcycle.py and test_fastdb_dev_alertcycle.py. + +# Inefficiency note: the fixutres do everything needed for both the +# elasticc2 and fastdb_dev tests, but because those tests are separate, +# it all gets run twice. 
¯\_(ツ)_/¯
+
+import sys
+import os
+import pathlib
+import datetime
+import time
+import pytz
+import random
+import subprocess
+import multiprocessing
+import pytest
+import logging
+
+sys.path.insert( 0, "/tom_desc" )
+os.environ["DJANGO_SETTINGS_MODULE"] = "tom_desc.settings"
+import django
+django.setup()
+import django.db
+
+import elasticc2.models
+import fastdb_dev.models
+import tom_targets.models
+
+from tom_client import TomClient
+sys.stderr.write( "ALERTCYCLEFIXTURES IMPORTING testmsgconsumer\n" )
+from testmsgconsumer import MsgConsumer
+
+sys.path.insert( 0, pathlib.Path(__file__).parent )
+from fakebroker import FakeBroker
+
+class _AlertCounter:
+    def __init__( self ):
+        self._test_alerts_exist_count = 0
+
+    def handle_test_alerts_exist( self, msgs ):
+        self._test_alerts_exist_count += len(msgs)
+
+
+# The numbers in these tests are based on the SNANA files in the
+# directory elasticc2_alert_test_data under tests, which should
+# be unpacked from elasticc2_alert_test_data.tar.bz2.
+
+
+class AlertCycleTestBase:
+
+    # ....oh, pytest.  Too much magic.  I had this logger
+    # in an __init__ function, but it turns out you can't
+    # use __init__ functions in classes that are
+    # pytest classes.
+
+    logger = logging.getLogger( "alertcyclefixtures" )
+    logger.propagate = False
+    _logout = logging.StreamHandler( sys.stderr )
+    logger.addHandler( _logout )
+    _formatter = logging.Formatter( f'[%(asctime)s - alertcyclefixtures - %(levelname)s] - %(message)s',
+                                    datefmt='%Y-%m-%d %H:%M:%S' )
+    _logout.setFormatter( _formatter )
+    logger.setLevel( logging.INFO )
+
+    @pytest.fixture( scope="class", autouse=True )
+    def alertcycletestbase_cleanup( self, elasticc2_ppdb_class ):
+        # The loaded database will have some of the alerts tagged as
+        # having been sent (as the database dump was done for all of
+        # these fixtures having been run!), but we want them all to be
+        # null for these tests, so fix that.
+        with django.db.connection.cursor() as cursor:
+            cursor.execute( "UPDATE elasticc2_ppdbalert SET alertsenttimestamp=NULL" )
+
+        yield True
+
+        # Clean up the database.  This assumes that there aren't any
+        # session fixtures that have left behind things in any of the
+        # tables in self._models_to_cleanup....  Perhaps we should be
+        # better about keeping track of what we've created.
+
+        # (The PPDB and DiaObjectTruth tables are filled by a fixture
+        # in conftest.py that *is* a session-scope fixture.)
+
+        self.logger.info( "Cleaning up alertcycle database entries" )
+        for model in self._models_to_cleanup:
+            model.objects.all().delete()
+
+        self._cleanup()
+
+
+
+    # It's not really possible to clean up messages off of the kafka server.
+    # So, to have a "fresh" environment, we randomly generate topics each time
+    # we run this fixture, so those topics will begin empty.
+ @pytest.fixture( scope="class" ) + def topic_barf( self ): + return "".join( random.choices( "abcdefghijklmnopqrstuvwxyz", k=10 ) ) + + + @pytest.fixture( scope="class" ) + def fakebroker( self, topic_barf ): + broker = FakeBroker( "kafka-server:9092", + [ f"alerts-wfd-{topic_barf}", f"alerts-ddf-full-{topic_barf}" ], + "kafka-server:9092", + f"classifications-{topic_barf}" ) + proc = multiprocessing.Process( target=broker, args=[], daemon=True ) + proc.start() + + yield True + + proc.terminate() + proc.join() + + + @pytest.fixture( scope="class" ) + def brokerpoll_elasticc2( self, topic_barf ): + def run_brokerpoll( topic_barf ): + # Instead of using subprocess, use os.execvp so that signals we + # send to this process properly get to the command we're + # launching. (The brokerpoll2 management command loops + # forever, but captures SIGINT, SIGKILL, and SIGUSR1, + # and shuts itself down upon receiving any of those signals.) + # (That's the intention, anyway....) + sys.stdout.flush() + sys.stderr.flush() + os.chdir( "/tom_desc" ) + args = [ "python", "manage.py", "brokerpoll2", + "--do-test", + "--grouptag", "elasticc2", + "--test-topic", f"classifications-{topic_barf}" ] + os.execvp( args[0], args ) + + proc = multiprocessing.Process( target=run_brokerpoll, args=(topic_barf,), daemon=True ) + proc.start() + + yield True + + proc.terminate() + proc.join() + + + @pytest.fixture( scope="class" ) + def brokerpoll_fastdb_dev( self, topic_barf ): + def run_brokerpoll( topic_barf ): + # See comments in brokerpoll_elasticc2 + sys.stdout.flush() + sys.stderr.flush() + os.chdir( "/tom_desc" ) + args = [ "python", "manage.py", "fastdb_dev_brokerpoll", + "--do-test", + "--grouptag", "fastdb_dev", + "--test-topic", f"classifications-{topic_barf}" ] + os.execvp( args[0], args ) + + proc = multiprocessing.Process( target=run_brokerpoll, args=(topic_barf,), daemon=True ) + proc.start() + + yield True + + proc.terminate() + proc.join() + + + @pytest.fixture( scope="class" ) + def alerts_300days( self, topic_barf ): + result = subprocess.run( [ "python", "manage.py", "send_elasticc2_alerts", "-d", "60578", + "-k", "kafka-server:9092", + "--wfd-topic", f"alerts-wfd-{topic_barf}", + "--ddf-full-topic", f"alerts-ddf-full-{topic_barf}", + "--ddf-limited-topic", f"alerts-ddf-limited-{topic_barf}", + "-s", "/tests/schema/elasticc.v0_9_1.alert.avsc", + "-r", "sending_alerts_runningfile", + "--do" ], + cwd="/tom_desc", capture_output=True ) + sys.stderr.write( result.stderr.decode( 'utf-8' ) ) + assert result.returncode == 0 + + consumer = MsgConsumer( 'kafka-server:9092', + 'test_send_alerts', + [ f'alerts-wfd-{topic_barf}', f'alerts-ddf-full-{topic_barf}' ], + '/tests/schema/elasticc.v0_9_1.alert.avsc', + consume_nmsgs=100, + logger=self.logger ) + counter = _AlertCounter() + consumer.poll_loop( counter.handle_test_alerts_exist, timeout=10, stopafter=datetime.timedelta(seconds=10) ) + # I don't understand why this is 546. 545 were sent. + # The fake broker sees 545. 
+ assert counter._test_alerts_exist_count == 546 + consumer.close() + + yield True + + + @pytest.fixture( scope="class" ) + def classifications_300days_exist( self, alerts_300days, topic_barf, fakebroker ): + counter = _AlertCounter() + consumer = MsgConsumer( 'kafka-server:9092', + 'test_classifications_exist', + f'classifications-{topic_barf}', + '/tests/schema/elasticc.v0_9_1.brokerClassification.avsc', + consume_nmsgs=100, + logger=self.logger ) + consumer.reset_to_start( f'classifications-{topic_barf}' ) + + # fake broker has a 10s sleep loop, so we can't + # assume things will be there instantly; thus, the 16s timeout. + + consumer.poll_loop( counter.handle_test_alerts_exist, timeout=5, + stopafter=datetime.timedelta(seconds=16) ) + + # This is 2x545 + assert counter._test_alerts_exist_count == 1090 + consumer.close() + + yield True + + + @pytest.fixture( scope="class" ) + def classifications_300days_elasticc2_ingested( self, classifications_300days_exist, mongoclient, + brokerpoll_elasticc2 ): + # Have to have an additional sleep after the classifications exist, + # because brokerpoll itself has a 10s sleep loop + time.sleep( 11 ) + + # Have to have these tests here rather than in the actual test_* + # file because I can't clean up, and there is hysteresis. Once + # later fixtures have run, the tests below would fail, and these + # fixtures may be used in more than one test. + + brkmsg = elasticc2.models.BrokerMessage + cfer = elasticc2.models.BrokerClassifier + bsid = elasticc2.models.BrokerSourceIds + + assert brkmsg.objects.count() == 1090 + assert cfer.objects.count() == 2 + assert bsid.objects.count() == 545 + + numprobs = 0 + for msg in brkmsg.objects.all(): + assert len(msg.classid) == len(msg.probability) + numprobs += len(msg.classid) + + # 545 from NugentClassifier plus 20*545 for RandomSNType + assert numprobs == 11445 + + assert ( set( [ i.classifiername for i in cfer.objects.all() ] ) + == set( [ "NugentClassifier", "RandomSNType" ] ) ) + + yield True + + + @pytest.fixture( scope="class" ) + def classifications_300days_fastdb_dev_ingested( self, classifications_300days_exist, mongoclient, + brokerpoll_fastdb_dev ): + # Have to have an additional sleep after the classifications exist, + # because brokerpoll itself has a 10s sleep loop + time.sleep( 11 ) + + # Have to have these tests here rather than in the actual test_* + # file because I can't clean up, and there is hysteresis. Once + # later fixtures have run, the tests below would fail, and these + # fixtures may be used in more than one test. 
+ + db = mongoclient.alerts + + assert 'fakebroker' in db.list_collection_names() + + coll = db.fakebroker + assert coll.count_documents({}) == 1090 + + numprobs = 0 + for msg in coll.find(): + msg = msg['msg'] + assert msg['brokerName'] == 'FakeBroker' + assert msg['classifierName'] in [ 'RandomSNType', 'NugentClassifier' ] + if msg['classifierName'] == 'NugentClassifier': + assert len( msg['classifications'] ) == 1 + assert msg['classifications'][0]['classId'] == 2222 + assert msg['classifications'][0]['probability'] == 1.0 + numprobs += len( msg['classifications'] ) + assert numprobs == 11445 + + yield True + + + @pytest.fixture( scope="class" ) + def update_elasticc2_diasource_300days( self, classifications_300days_elasticc2_ingested ): + result = subprocess.run( [ "python", "manage.py", "update_elasticc2_sources" ], + cwd="/tom_desc", capture_output=True ) + assert result.returncode == 0 + + # Have to have tests here because of hysteresis (search for that word above) + obj = elasticc2.models.DiaObject + src = elasticc2.models.DiaSource + frced = elasticc2.models.DiaForcedSource + targ = tom_targets.models.Target + ooft = elasticc2.models.DiaObjectOfTarget + bsid = elasticc2.models.BrokerSourceIds + + assert bsid.objects.count() == 0 + assert obj.objects.count() == 102 + # TODO -- put these next two lines back in once we start doing this thing again + # assert ooft.objects.count() == obj.objects.count() + # assert targ.objects.count() == obj.objects.count() + assert src.objects.count() == 545 + assert frced.objects.count() == 4242 + + yield True + + + @pytest.fixture( scope="class" ) + def update_fastdb_dev_diasource_300days( self, classifications_300days_fastdb_dev_ingested ): + result = subprocess.run( [ "python", "manage.py", "load_fastdb", + "--pv", "test_pv", "--snapshot", "test_ss", + "--tag", "test_ss_tag", + "--brokers", "fakebroker" ], + cwd="/tom_desc", + capture_output=True ) + assert result.returncode == 0 + + lut = fastdb_dev.models.LastUpdateTime + obj = fastdb_dev.models.DiaObject + src = fastdb_dev.models.DiaSource + frced = fastdb_dev.models.DiaForcedSource + cfer = fastdb_dev.models.BrokerClassifier + cification = fastdb_dev.models.BrokerClassification + pver = fastdb_dev.models.ProcessingVersions + ss = fastdb_dev.models.Snapshots + dspvss = fastdb_dev.models.DStoPVtoSS + dfspvss = fastdb_dev.models.DFStoPVtoSS + + assert lut.objects.count() == 1 + assert lut.objects.first().last_update_time > datetime.datetime.fromtimestamp( 0, tz=datetime.timezone.utc ) + assert lut.objects.first().last_update_time < datetime.datetime.now( tz=datetime.timezone.utc ) + + # TODO : right now, load_fastdb.py imports the future -- that is, it imports + # the full ForcedSource lightcure for an object for which we got a source + # the first time that source is seen, and never looks at forcedsources + # again. Update the tests numbers if/when it simulates not knowing the + # future. + # (Really, we should probably creat a whole separate simulated PPDB server with + # an interface that will look something like the real PPDB interface... when + # we actually know what that is.) + + assert obj.objects.count() == 102 + assert src.objects.count() == 545 + assert frced.objects.count() == 15760 # 4242 + assert cfer.objects.count() == 2 + # assert cification.objects.count() == 831 # ???? WHy is this not 545 * 2 ? LOOK INTO THIS + # # ---> seems to be non-deterministic! + # TODO : pver, ss, dpvss, dfspvss + + yield True + + # WORRY. 
This will screw things up if this fixture is run
+    # before something like the update_300days_*_ingested fixtures.
+    # But, we don't want to make those prerequisites for this one,
+    # because we don't want to have to run (in particular) the fastdb_dev
+    # ingestion fixtures when we don't need to, because it's slow.  So,
+    # just figure that right now, the classes that are using these fixtures
+    # won't do things wrong, and leave the dependency out.
+    @pytest.fixture( scope="class" )
+    def alerts_100daysmore( self, alerts_300days, topic_barf, fakebroker ):
+        # This will send alerts up through mjd 60676.  Why not 60678, since the previous
+        # sent through 60578?  There were no alerts between 60675 and 60679, so the last
+        # alert sent will have been a source from mjd 60675.  That's what the 100 days
+        # are added to.
+        # This is an additional 105 alerts, for a total of 650 (coming from 131 objects).
+        # (Note that we have to make sure the update_*_diasource_300days
+        # fixtures have run before this fixture runs, because this fixture
+        # is going to add more alerts, leading the running fakebroker to
+        # add more classifications, and then there will be more
+        # classifications sitting on the kafka server than those fixtures
+        # (and their prerequisites) are expecting.)
+        result = subprocess.run( [ "python", "manage.py", "send_elasticc2_alerts",
+                                   "-a", "100",
+                                   "-k", "kafka-server:9092",
+                                   "--wfd-topic", f"alerts-wfd-{topic_barf}",
+                                   "--ddf-full-topic", f"alerts-ddf-full-{topic_barf}",
+                                   "--ddf-limited-topic", f"alerts-ddf-limited-{topic_barf}",
+                                   "-s", "/tests/schema/elasticc.v0_9_1.alert.avsc",
+                                   "-r", "sending_alerts_runningfile",
+                                   "--do" ],
+                                 cwd="/tom_desc", capture_output=True )
+        sys.stderr.write( result.stderr.decode( 'utf-8' ) )
+        assert result.returncode == 0
+
+        yield True
+
+
+    @pytest.fixture( scope="class" )
+    def classifications_100daysmore_elasticc2_ingested( self, alerts_100daysmore, mongoclient,
+                                                         brokerpoll_elasticc2 ):
+        # This time we need to allow for both the 10s sleep cycle timeout of
+        # brokerpoll and fakebroker (since we're not checking
+        # classifications exist separately from ingested)
+        time.sleep( 22 )
+
+        brkmsg = elasticc2.models.BrokerMessage
+        cfer = elasticc2.models.BrokerClassifier
+
+        # 650 total alerts times 2 classifiers = 1300 broker messages
+        assert len( brkmsg.objects.all() ) == 1300
+        assert cfer.objects.count() == 2
+        assert len( cfer.objects.all() ) == 2
+
+        numprobs = 0
+        for msg in brkmsg.objects.all():
+            assert len(msg.classid) == len(msg.probability)
+            numprobs += len(msg.classid)
+        # 650 from NugentClassifier plus 20*650 for RandomSNType
+        assert numprobs == 13650
+
+        assert ( set( [ i.classifiername for i in cfer.objects.all() ] )
+                 == set( [ "NugentClassifier", "RandomSNType" ] ) )
+
+        yield True
+
+
+    @pytest.fixture( scope="class" )
+    def classifications_100daysmore_fastdb_dev_ingested( self, alerts_100daysmore, mongoclient,
+                                                          brokerpoll_fastdb_dev ):
+        # This time we need to allow for both the 10s sleep cycle timeout of
+        # brokerpoll and fakebroker (since we're not checking
+        # classifications exist separately from ingested)
+        time.sleep( 22 )
+
+        db = mongoclient.alerts
+
+        assert 'fakebroker' in db.list_collection_names()
+
+        coll = db.fakebroker
+        assert coll.count_documents({}) == 1300
+
+        numprobs = 0
+        for msg in coll.find():
+            msg = msg['msg']
+            assert msg['brokerName'] == 'FakeBroker'
+            assert msg['classifierName'] in [ 'RandomSNType', 'NugentClassifier' ]
+            if msg['classifierName'] == 'NugentClassifier':
+                assert len(
msg['classifications'] ) == 1 + assert msg['classifications'][0]['classId'] == 2222 + assert msg['classifications'][0]['probability'] == 1.0 + numprobs += len( msg['classifications'] ) + assert numprobs == 13650 + + yield True + + + @pytest.fixture( scope="class" ) + def update_elasticc2_diasource_100daysmore( self, classifications_100daysmore_elasticc2_ingested ): + result = subprocess.run( [ "python", "manage.py", "update_elasticc2_sources" ], + cwd="/tom_desc", capture_output=True ) + assert result.returncode == 0 + + obj = elasticc2.models.DiaObject + src = elasticc2.models.DiaSource + frced = elasticc2.models.DiaForcedSource + targ = tom_targets.models.Target + ooft = elasticc2.models.DiaObjectOfTarget + bsid = elasticc2.models.BrokerSourceIds + + assert bsid.objects.count() == 0 + assert obj.objects.count() == 131 + # TODO: put these next two lines back in once we start doing this again + # assert ooft.objects.count() == obj.objects.count() + # assert targ.objects.count() == obj.objects.count() + assert src.objects.count() == 650 + assert frced.objects.count() == 5765 + + yield True + + + @pytest.fixture( scope="class" ) + def update_fastdb_dev_diasource_100daysmore( self, classifications_100daysmore_fastdb_dev_ingested ): + # SEE COMMENTS IN update_fastdb_dev_diasource_300days + + result = subprocess.run( [ "python", "manage.py", "load_fastdb", + "--pv", "test_pv", "--snapshot", "test_ss", + "--tag", "test_ss_tag", + "--brokers", "fakebroker" ], + cwd="/tom_desc", + capture_output=True ) + assert result.returncode == 0 + + lut = fastdb_dev.models.LastUpdateTime + obj = fastdb_dev.models.DiaObject + src = fastdb_dev.models.DiaSource + frced = fastdb_dev.models.DiaForcedSource + cfer = fastdb_dev.models.BrokerClassifier + cification = fastdb_dev.models.BrokerClassification + pver = fastdb_dev.models.ProcessingVersions + ss = fastdb_dev.models.Snapshots + dspvss = fastdb_dev.models.DStoPVtoSS + dfspvss = fastdb_dev.models.DFStoPVtoSS + + assert lut.objects.count() == 1 + assert lut.objects.first().last_update_time > datetime.datetime.fromtimestamp( 0, tz=datetime.timezone.utc ) + assert lut.objects.first().last_update_time < datetime.datetime.now( tz=datetime.timezone.utc ) + + # TODO : right now, load_fastdb.py imports the future -- that is, it imports + # the full ForcedSource lightcure for an object for which we got a source + # the first time that source is seen, and never looks at forcedsources + # again. Update the tests numbers if/when it simulates not knowing the + # future. + # (Really, we should probably creat a whole separate simulated PPDB server with + # an interface that will look something like the real PPDB interface... when + # we actually know what that is.) + + assert obj.objects.count() == 131 + assert src.objects.count() == 650 + assert frced.objects.count() == 20834 # 5765 + assert cfer.objects.count() == 2 + # assert cification.objects.count() == ... # ???? WHy is this not 650 * 2 ? 
LOOK INTO THIS + # TODO : pver, ss, dpvss, dfspvss + + yield True + + @pytest.fixture( scope="class" ) + def api_classify_existing_alerts( self, alerts_100daysmore, apibroker_client, topic_barf ): + result = subprocess.run( [ "python", "apiclassifier.py", + "--source", "kafka-server:9092", + "-t", f"alerts-wfd-{topic_barf}", f"alerts-ddf-full-{topic_barf}", + "-g", "apibroker", + "-u", "apibroker", + "-p", "testing", + "-s", "2", + "-a", "/tests/schema/elasticc.v0_9_1.alert.avsc", + "-b", "/tests/schema/elasticc.v0_9_1.brokerClassification.avsc" + ], + cwd="/tests", capture_output=True ) + sys.stderr.write( result.stderr.decode( 'utf-8' ) ) + assert result.returncode == 0 + + yield True diff --git a/tests/alertcyclefixtures.py b/tests/alertcyclefixtures.py index 8135a3a4..e69de29b 100644 --- a/tests/alertcyclefixtures.py +++ b/tests/alertcyclefixtures.py @@ -1,511 +0,0 @@ -# IMPORTANT -- running any tests that depend on fixtures in this file -# OTHER than alert_cycle_complete requires a completely fresh -# environment. After any run of "pytest ...", if you want to run tests -# (e.g. in test_alert_cycle.py) that use these fixtures, you have to -# completely tear down and rebuild the docker compose environment. This -# is because, as noted below, we can't easily clean up the kafka -# server's state, so on a rerun, the server state will be wrong. (We -# also use that as a reason to be lazy and not clean up the database; -# see the long comment below.) - -import sys -import os -import pathlib -import datetime -import time -import pytz -import random -import subprocess -import pytest - -from pymongo import MongoClient - -sys.path.insert( 0, "/tom_desc" ) -os.environ["DJANGO_SETTINGS_MODULE"] = "tom_desc.settings" -import django -django.setup() - -import elasticc2.models -import fastdb_dev.models -import tom_targets.models - -from tom_client import TomClient -from msgconsumer import MsgConsumer - -# NOTE -- in many of the fixtures below there are lots of tests that -# would normally be in the tests_* file that use the fixtures. The -# reason they're here is because it's hard (impossible without a bunch -# of ugly hacks) to really clean up after these fixtures -- in -# particular, cleaning up the kafka server topic is something I can't -# just do here, but would have to do *on* the kafka server. So, once -# some of the later fixtures have run, tests that depend only on earlier -# fixtures would start to fail. The solution is to make all the -# fixtures session scoped, and to put the tests that have this -# hysteresis problem inside the fixtures, so they'll only be run once, -# and we can control the order in which the fixtures are run. That will -# also then allow us to use these fixtures in more than one set of -# tests. - -# Because of this, lots of fixtures don't bother cleaning up, even if -# they could. In fact, they deliberately choose not to clean up, -# so that the database will be in the state it is at the end of the -# full alert cycle; the alert_cycle_complete fixture then detects that -# and runs the slow fixtures or not as necessary. - -# Any tests that use these fixtures and are going to test actual numbers -# in the database should only depend on alert_cycle_complete. Once all -# these fixtures have run (perhaps from an earlier test), the numbers -# that come out of earlier fixtures will no longer be right. 
If any -# fixture other than alert_cycle_complete is run when the other fixtures -# have already been run once in a given docker compose environment, the -# database will be changed, and the fixtures will fail. - -# The numbers in these tests are based on the SNANA files in the -# directory elasticc2_alert_test_data under tests, which should -# be unpacked from elasticc2_alert_test_data.tar.bz2. - -class AlertCounter: - def __init__( self ): - self._test_alerts_exist_count = 0 - - def handle_test_alerts_exist( self, msgs ): - self._test_alerts_exist_count += len(msgs) - - -@pytest.fixture( scope="session" ) -def alerts_300days( elasticc2_ppdb ): - result = subprocess.run( [ "python", "manage.py", "send_elasticc2_alerts", "-d", "60578", - "-k", "kafka-server:9092", - "--wfd-topic", "alerts-wfd", "--ddf-full-topic", "alerts-ddf-full", - "--ddf-limited-topic", "alerts-ddf-limited", - "-s", "/tests/schema/elasticc.v0_9_1.alert.avsc", - "-r", "sending_alerts_runningfile", "--do" ], - cwd="/tom_desc", capture_output=True ) - sys.stderr.write( result.stderr.decode( 'utf-8' ) ) - assert result.returncode == 0 - - consumer = MsgConsumer( 'kafka-server:9092', 'test_send_alerts', [ 'alerts-wfd', 'alerts-ddf-full' ], - '/tests/schema/elasticc.v0_9_1.alert.avsc', - consume_nmsgs=100 ) - counter = AlertCounter() - consumer.poll_loop( counter.handle_test_alerts_exist, timeout=10, stopafter=datetime.timedelta(seconds=10) ) - # I don't understand why this is 546. 545 were sent. - # The fake broker sees 545. - assert counter._test_alerts_exist_count == 546 - consumer.close() - - yield True - - -@pytest.fixture( scope="session" ) -def classifications_300days_exist( alerts_300days ): - - counter = AlertCounter() - consumer = MsgConsumer( 'kafka-server:9092', 'test_classifications_exist', 'classifications', - '/tests/schema/elasticc.v0_9_1.brokerClassification.avsc', - consume_nmsgs=100 ) - consumer.reset_to_start( 'classifications' ) - - # fake broker has a 10s sleep loop, so we can't - # assume things will be there instantly; thus, the 16s timeout. - - consumer.poll_loop( counter.handle_test_alerts_exist, timeout=5, - stopafter=datetime.timedelta(seconds=16) ) - - # This is 2x545 - assert counter._test_alerts_exist_count == 1090 - consumer.close() - - yield True - - -@pytest.fixture( scope="session" ) -def classifications_300days_elasticc2_ingested( classifications_300days_exist ): - # Have to have an additional sleep after the classifications exist, - # because brokerpoll itself has a 10s sleep loop - time.sleep( 11 ) - - # Have to have these tests here rather than in the actual test_* - # file because I can't clean up, and there is hysteresis. Once - # later fixtures have run, the tests below would fail, and these - # fixtures may be used in more than one test. 
- - brkmsg = elasticc2.models.BrokerMessage - cfer = elasticc2.models.BrokerClassifier - bsid = elasticc2.models.BrokerSourceIds - - assert brkmsg.objects.count() == 1090 - assert cfer.objects.count() == 2 - assert bsid.objects.count() == 545 - - numprobs = 0 - for msg in brkmsg.objects.all(): - assert len(msg.classid) == len(msg.probability) - numprobs += len(msg.classid) - - # 545 from NugentClassifier plus 20*545 for RandomSNType - assert numprobs == 11445 - - # TODO : check that the data is identical for - # corresponding entries in the two cassbroker - # tables - - assert ( set( [ i.classifiername for i in cfer.objects.all() ] ) - == set( [ "NugentClassifier", "RandomSNType" ] ) ) - - yield True - - -@pytest.fixture( scope="session" ) -def classifications_300days_fastdb_dev_ingested( classifications_300days_exist ): - # Have to have an additional sleep after the classifications exist, - # because brokerpoll itself has a 10s sleep loop - time.sleep( 11 ) - - # Have to have these tests here rather than in the actual test_* - # file because I can't clean up, and there is hysteresis. Once - # later fixtures have run, the tests below would fail, and these - # fixtures may be used in more than one test. - - host = os.getenv( 'MONGOHOST' ) - username = os.getenv( 'MONGODB_ALERT_READER' ) - password = os.getenv( 'MONGODB_ALERT_READER_PASSWORD' ) - client = MongoClient( f"mongodb://{username}:{password}@{host}:27017/?authSource=alerts" ) - db = client.alerts - - assert 'fakebroker' in db.list_collection_names() - - coll = db.fakebroker - assert coll.count_documents({}) == 1090 - - numprobs = 0 - for msg in coll.find(): - msg = msg['msg'] - assert msg['brokerName'] == 'FakeBroker' - assert msg['classifierName'] in [ 'RandomSNType', 'NugentClassifier' ] - if msg['classifierName'] == 'NugentClassifier': - assert len( msg['classifications'] ) == 1 - assert msg['classifications'][0]['classId'] == 2222 - assert msg['classifications'][0]['probability'] == 1.0 - numprobs += len( msg['classifications'] ) - assert numprobs == 11445 - - yield True - -@pytest.fixture( scope="session" ) -def update_elasticc2_diasource_300days( classifications_300days_elasticc2_ingested ): - result = subprocess.run( [ "python", "manage.py", "update_elasticc2_sources" ], - cwd="/tom_desc", capture_output=True ) - assert result.returncode == 0 - - # Have to have tests here because of hysteresis (search for that word above) - obj = elasticc2.models.DiaObject - src = elasticc2.models.DiaSource - frced = elasticc2.models.DiaForcedSource - targ = tom_targets.models.Target - ooft = elasticc2.models.DiaObjectOfTarget - bsid = elasticc2.models.BrokerSourceIds - - assert bsid.objects.count() == 0 - assert obj.objects.count() == 102 - # TODO -- put these next two lines back in once we start doing this thing again - # assert ooft.objects.count() == obj.objects.count() - # assert targ.objects.count() == obj.objects.count() - assert src.objects.count() == 545 - assert frced.objects.count() == 4242 - - yield True - - -@pytest.fixture( scope="session" ) -def update_fastdb_dev_diasource_300days( classifications_300days_fastdb_dev_ingested ): - result = subprocess.run( [ "python", "manage.py", "load_fastdb", - "--pv", "test_pv", "--snapshot", "test_ss", - "--tag", "test_ss_tag", - "--brokers", "fakebroker" ], - cwd="/tom_desc", - capture_output=True ) - assert result.returncode == 0 - - lut = fastdb_dev.models.LastUpdateTime - obj = fastdb_dev.models.DiaObject - src = fastdb_dev.models.DiaSource - frced = fastdb_dev.models.DiaForcedSource - 
cfer = fastdb_dev.models.BrokerClassifier - cification = fastdb_dev.models.BrokerClassification - pver = fastdb_dev.models.ProcessingVersions - ss = fastdb_dev.models.Snapshots - dspvss = fastdb_dev.models.DStoPVtoSS - dfspvss = fastdb_dev.models.DFStoPVtoSS - - assert lut.objects.count() == 1 - assert lut.objects.first().last_update_time > datetime.datetime.fromtimestamp( 0, tz=datetime.timezone.utc ) - assert lut.objects.first().last_update_time < datetime.datetime.now( tz=datetime.timezone.utc ) - - # TODO : right now, load_fastdb.py imports the future -- that is, it imports - # the full ForcedSource lightcure for an object for which we got a source - # the first time that source is seen, and never looks at forcedsources - # again. Update the tests numbers if/when it simulates not knowing the - # future. - # (Really, we should probably creat a whole separate simulated PPDB server with - # an interface that will look something like the real PPDB interface... when - # we actually know what that is.) - - assert obj.objects.count() == 102 - assert src.objects.count() == 545 - assert frced.objects.count() == 15760 # 4242 - assert cfer.objects.count() == 2 - # assert cification.objects.count() == 831 # ???? WHy is this not 545 * 2 ? LOOK INTO THIS - # # ---> seems to be non-deterministic! - # TODO : pver, ss, dpvss, dfspvss - - yield True - -@pytest.fixture( scope="session" ) -def alerts_100daysmore( alerts_300days ): - # This will send alerts up through mjd 60676. Why not 60678, since the previous - # sent through 60578? There were no alerts between 60675 and 60679, so the last - # alert sent will have been a source from mjd 60675. That's what the 100 days - # are added to. - # This is an additional 105 alerts, for a total of 650 (coming from 131 objects). 
- result = subprocess.run( [ "python", "manage.py", "send_elasticc2_alerts", "-a", "100", - "-k", "kafka-server:9092", - "--wfd-topic", "alerts-wfd", "--ddf-full-topic", "alerts-ddf-full", - "--ddf-limited-topic", "alerts-ddf-limited", - "-s", "/tests/schema/elasticc.v0_9_1.alert.avsc", - "-r", "sending_alerts_runningfile", "--do" ], - cwd="/tom_desc", capture_output=True ) - sys.stderr.write( result.stderr.decode( 'utf-8' ) ) - assert result.returncode == 0 - - yield True - - # Same issue as alerts_300days about not cleaning up - -@pytest.fixture( scope="session" ) -def classifications_100daysmore_elasticc2_ingested( alerts_100daysmore ): - # This time we need to allow for both the 10s sleep cycle timeout of - # brokerpoll and fakebroker (since we're not checking - # classifications exist separately from ingested) - time.sleep( 22 ) - - # Tests here because of hysteresis - - brkmsg = elasticc2.models.BrokerMessage - cfer = elasticc2.models.BrokerClassifier - - # 650 total alerts times 2 classifiers = 1300 broker messages - assert len( brkmsg.objects.all() ) == 1300 - assert cfer.objects.count() == 2 - assert len( cfer.objects.all() ) == 2 - - numprobs = 0 - for msg in brkmsg.objects.all(): - assert len(msg.classid) == len(msg.probability) - numprobs += len(msg.classid) - # 650 from NugentClassifier plus 20*650 for RandomSNType - assert numprobs == 13650 - - assert ( set( [ i.classifiername for i in cfer.objects.all() ] ) - == set( [ "NugentClassifier", "RandomSNType" ] ) ) - - yield True - - -@pytest.fixture( scope="session" ) -def classifications_100daysmore_fastdb_dev_ingested( alerts_100daysmore ): - # This time we need to allow for both the 10s sleep cycle timeout of - # brokerpoll and fakebroker (since we're not checking - # classifications exist separately from ingested) - time.sleep( 22 ) - - # Tests here because of hysteresis - - host = os.getenv( 'MONGOHOST' ) - username = os.getenv( 'MONGODB_ALERT_READER' ) - password = os.getenv( 'MONGODB_ALERT_READER_PASSWORD' ) - client = MongoClient( f"mongodb://{username}:{password}@{host}:27017/?authSource=alerts" ) - db = client.alerts - - assert 'fakebroker' in db.list_collection_names() - - coll = db.fakebroker - assert coll.count_documents({}) == 1300 - - numprobs = 0 - for msg in coll.find(): - msg = msg['msg'] - assert msg['brokerName'] == 'FakeBroker' - assert msg['classifierName'] in [ 'RandomSNType', 'NugentClassifier' ] - if msg['classifierName'] == 'NugentClassifier': - assert len( msg['classifications'] ) == 1 - assert msg['classifications'][0]['classId'] == 2222 - assert msg['classifications'][0]['probability'] == 1.0 - numprobs += len( msg['classifications'] ) - assert numprobs == 13650 - - yield True - - -@pytest.fixture( scope="session" ) -def update_elasticc2_diasource_100daysmore( classifications_100daysmore_elasticc2_ingested ): - result = subprocess.run( [ "python", "manage.py", "update_elasticc2_sources" ], - cwd="/tom_desc", capture_output=True ) - assert result.returncode == 0 - - obj = elasticc2.models.DiaObject - src = elasticc2.models.DiaSource - frced = elasticc2.models.DiaForcedSource - targ = tom_targets.models.Target - ooft = elasticc2.models.DiaObjectOfTarget - bsid = elasticc2.models.BrokerSourceIds - - assert bsid.objects.count() == 0 - assert obj.objects.count() == 131 - # TODO: put these next two lines back in once we start doing this again - # assert ooft.objects.count() == obj.objects.count() - # assert targ.objects.count() == obj.objects.count() - assert src.objects.count() == 650 - assert 
frced.objects.count() == 5765 - - yield True - - -@pytest.fixture( scope="session" ) -def update_fastdb_dev_diasource_100daysmore( classifications_100daysmore_fastdb_dev_ingested ): - # SEE COMMENTS IN update_fastdb_dev_diasource_300days - - result = subprocess.run( [ "python", "manage.py", "load_fastdb", - "--pv", "test_pv", "--snapshot", "test_ss", - "--tag", "test_ss_tag", - "--brokers", "fakebroker" ], - cwd="/tom_desc", - capture_output=True ) - assert result.returncode == 0 - - lut = fastdb_dev.models.LastUpdateTime - obj = fastdb_dev.models.DiaObject - src = fastdb_dev.models.DiaSource - frced = fastdb_dev.models.DiaForcedSource - cfer = fastdb_dev.models.BrokerClassifier - cification = fastdb_dev.models.BrokerClassification - pver = fastdb_dev.models.ProcessingVersions - ss = fastdb_dev.models.Snapshots - dspvss = fastdb_dev.models.DStoPVtoSS - dfspvss = fastdb_dev.models.DFStoPVtoSS - - assert lut.objects.count() == 1 - assert lut.objects.first().last_update_time > datetime.datetime.fromtimestamp( 0, tz=datetime.timezone.utc ) - assert lut.objects.first().last_update_time < datetime.datetime.now( tz=datetime.timezone.utc ) - - # TODO : right now, load_fastdb.py imports the future -- that is, it imports - # the full ForcedSource lightcure for an object for which we got a source - # the first time that source is seen, and never looks at forcedsources - # again. Update the tests numbers if/when it simulates not knowing the - # future. - # (Really, we should probably creat a whole separate simulated PPDB server with - # an interface that will look something like the real PPDB interface... when - # we actually know what that is.) - - assert obj.objects.count() == 131 - assert src.objects.count() == 650 - assert frced.objects.count() == 20834 # 5765 - assert cfer.objects.count() == 2 - # assert cification.objects.count() == ... # ???? WHy is this not 650 * 2 ? 
LOOK INTO THIS - # TODO : pver, ss, dpvss, dfspvss - - yield True - -@pytest.fixture( scope="session" ) -def api_classify_existing_alerts( alerts_100daysmore, apibroker_client ): - result = subprocess.run( [ "python", "apiclassifier.py", "--source", "kafka-server:9092", - "-t", "alerts-wfd", "alerts-ddf-full", - "-g", "apibroker", "-u", "apibroker", "-p", "testing", "-s", "2", - "-a", "/tests/schema/elasticc.v0_9_1.alert.avsc", - "-b", "/tests/schema/elasticc.v0_9_1.brokerClassification.avsc"], - cwd="/tests", capture_output=True ) - sys.stderr.write( result.stderr.decode( 'utf-8' ) ) - assert result.returncode == 0 - - yield True - -@pytest.fixture( scope="module" ) -def random_broker_classifications(): - brokers = { - 'rbc_test1': { - '1.0': { - 'classifiertest1': [ '1.0' ], - 'classifiertest2': [ '1.0' ] - } - }, - 'rbc_test2': { - '3.5': { - 'testing1': [ '42' ], - 'testing2': [ '23' ] - } - } - } - - minsrc = 10 - maxsrc = 20 - mincls = 1 - maxcls = 20 - - msgs = [] - for brokername, brokerspec in brokers.items(): - for brokerversion, versionspec in brokerspec.items(): - for classifiername, clsspec in versionspec.items(): - for classifierparams in clsspec: - nsrcs = random.randint( minsrc, maxsrc ) - for src in range(nsrcs): - ncls = random.randint( mincls, maxcls ) - probleft = 1.0 - classes = [] - probs = [] - for cls in range( ncls ): - classes.append( cls ) - prob = random.random() * probleft - probleft -= prob - probs.append( prob ) - classes.append( ncls ) - probs.append( probleft ) - - msgs.append( { 'sourceid': src, - 'brokername': brokername, - 'alertid': src, - 'elasticcpublishtimestamp': datetime.datetime.now( tz=pytz.utc ), - 'brokeringesttimestamp': datetime.datetime.now( tz=pytz.utc ), - 'brokerversion': brokerversion, - 'classifiername': classifiername, - 'classifierparams': classifierparams, - 'classid': classes, - 'probability': probs } ) - - yield msgs - - -@pytest.fixture( scope="session" ) -def alert_cycle_complete( request, tomclient ): - res = tomclient.post( 'db/runsqlquery/', - json={ 'query': 'SELECT COUNT(*) AS count FROM elasticc2_brokermessage' } ) - rows = res.json()[ 'rows' ] - if rows[0]['count'] == 0: - request.getfixturevalue( "update_elasticc2_diasource_100daysmore" ) - request.getfixturevalue( "update_fastdb_dev_diasource_100daysmore" ) - request.getfixturevalue( "api_classify_existing_alerts" ) - - yield True - - -__all__ = [ 'alerts_300days', - 'classifications_300days_exist', - 'classifications_300days_elasticc2_ingested', - 'classifications_300days_fastdb_dev_ingested', - 'update_elasticc2_diasource_300days', - 'update_fastdb_dev_diasource_300days', - 'alerts_100daysmore', - 'classifications_100daysmore_elasticc2_ingested', - 'classifications_100daysmore_fastdb_dev_ingested', - 'update_fastdb_dev_diasource_100daysmore', - 'api_classify_existing_alerts', - 'alert_cycle_complete' ] diff --git a/tests/apiclassifier.py b/tests/apiclassifier.py index ff2d189c..d5afaf32 100644 --- a/tests/apiclassifier.py +++ b/tests/apiclassifier.py @@ -7,12 +7,12 @@ import fastavro import datetime -from msgconsumer import MsgConsumer +from testmsgconsumer import MsgConsumer from tom_client import TomClient _rundir = pathlib.Path( __file__ ).parent -_logger = logging.getLogger( __name__ ) +_logger = logging.getLogger( "tests/apiclassifier" ) _logger.propagate = False if not _logger.hasHandlers(): _logout = logging.StreamHandler( sys.stderr ) diff --git a/tests/conftest.py b/tests/conftest.py index a03e2a4c..3ce40eb2 100644 --- a/tests/conftest.py +++ 
b/tests/conftest.py @@ -8,6 +8,9 @@ import subprocess import pytest +from pymongo import MongoClient + +# Make sure the django environment is fully set up sys.path.insert( 0, "/tom_desc" ) os.environ["DJANGO_SETTINGS_MODULE"] = "tom_desc.settings" import django @@ -16,30 +19,152 @@ import elasticc2.models from tom_client import TomClient +# Additional fixtures in other files +# sys.path.insert( 0, os.getenv("PWD") ) +# pytest_plugins = [ 'alertcyclefixtures' ] + + @pytest.fixture( scope="session" ) def tomclient(): return TomClient( "http://tom:8080", username="root", password="testing" ) +@pytest.fixture( scope="session" ) +def mongoclient(): + host = os.getenv( 'MONGOHOST' ) + username = os.getenv( 'MONGODB_ALERT_READER' ) + password = os.getenv( 'MONGODB_ALERT_READER_PASSWORD' ) + client = MongoClient( f"mongodb://{username}:{password}@{host}:27017/?authSource=alerts" ) + return client + @pytest.fixture( scope="session" ) def apibroker_client(): return TomClient( "http://tom:8080", username="apibroker", password="testing" ) -@pytest.fixture( scope="session" ) -def elasticc2_ppdb( tomclient ): - basedir = pathlib.Path( "/elasticc2data" ) - dirs = [] - for subdir in basedir.glob( '*' ): - if subdir.is_dir(): - result = subprocess.run( [ "python", "manage.py", "load_snana_fits", "-d", str(subdir), "--ppdb", "--do" ], - cwd="/tom_desc", capture_output=True ) - assert result.returncode == 0 +def load_elasticc2_database_snapshot( *args ): + models = args + for m in models: + assert m.objects.count() == 0 + + tables = [ m._meta.db_table for m in models ] + args = [ "pg_restore", + "--data-only", + "-h", "postgres", + "-U", "postgres", + "-d", "tom_desc" ] + for t in tables: + args.append( "-t" ) + args.append( t ) + args.append( "elasticc2_alertcycle_complete.psqlc" ) + res = subprocess.run( args, cwd="/tests", env={ "PGPASSWORD": "fragile" }, capture_output=True ) + assert res.returncode == 0 + + return models + + +@pytest.fixture +def elasticc2_ppdb(): + models = load_elasticc2_database_snapshot( elasticc2.models.PPDBAlert, + elasticc2.models.PPDBDiaForcedSource, + elasticc2.models.PPDBDiaObject, + elasticc2.models.PPDBDiaSource, + elasticc2.models.DiaObjectTruth ) yield True + for m in models: + m.objects.all().delete() + + +@pytest.fixture( scope="class" ) +def elasticc2_ppdb_class(): + models = load_elasticc2_database_snapshot( elasticc2.models.PPDBAlert, + elasticc2.models.PPDBDiaForcedSource, + elasticc2.models.PPDBDiaObject, + elasticc2.models.PPDBDiaSource, + elasticc2.models.DiaObjectTruth ) + yield True + for m in models: + m.objects.all().delete() + + +@pytest.fixture +def elasticc2_database_snapshot( elasticc2_ppdb ): + models = load_elasticc2_database_snapshot( elasticc2.models.BrokerClassifier, + elasticc2.models.BrokerMessage, + elasticc2.models.DiaForcedSource, + elasticc2.models.DiaObject, + elasticc2.models.DiaSource, + elasticc2.models.DiaObjectInfo, + elasticc2.models.BrokerSourceIds ) + yield True + for m in models: + m.objects.all().delete() + +@pytest.fixture( scope='class' ) +def elasticc2_database_snapshot_class( elasticc2_ppdb_class ): + models = load_elasticc2_database_snapshot( elasticc2.models.BrokerClassifier, + elasticc2.models.BrokerMessage, + elasticc2.models.DiaForcedSource, + elasticc2.models.DiaObject, + elasticc2.models.DiaSource, + elasticc2.models.DiaObjectInfo, + elasticc2.models.BrokerSourceIds ) + yield True + for m in models: + m.objects.all().delete() + + + +@pytest.fixture( scope="class" ) +def random_broker_classifications(): + brokers = { + 
'rbc_test1': { + '1.0': { + 'classifiertest1': [ '1.0' ], + 'classifiertest2': [ '1.0' ] + } + }, + 'rbc_test2': { + '3.5': { + 'testing1': [ '42' ], + 'testing2': [ '23' ] + } + } + } + + minsrc = 10 + maxsrc = 20 + mincls = 1 + maxcls = 20 + + msgs = [] + for brokername, brokerspec in brokers.items(): + for brokerversion, versionspec in brokerspec.items(): + for classifiername, clsspec in versionspec.items(): + for classifierparams in clsspec: + nsrcs = random.randint( minsrc, maxsrc ) + for src in range(nsrcs): + ncls = random.randint( mincls, maxcls ) + probleft = 1.0 + classes = [] + probs = [] + for cls in range( ncls ): + classes.append( cls ) + prob = random.random() * probleft + probleft -= prob + probs.append( prob ) + classes.append( ncls ) + probs.append( probleft ) - elasticc2.models.DiaObjectTruth.objects.all().delete() - elasticc2.models.PPDBAlert.objects.all().delete() - elasticc2.models.PPDBDiaForcedSource.objects.all().delete() - elasticc2.models.PPDBDiaSource.objects.all().delete() - elasticc2.models.PPDBDiaObject.objects.all().delete() + msgs.append( { 'sourceid': src, + 'brokername': brokername, + 'alertid': src, + 'elasticcpublishtimestamp': datetime.datetime.now( tz=pytz.utc ), + 'brokeringesttimestamp': datetime.datetime.now( tz=pytz.utc ), + 'brokerversion': brokerversion, + 'classifiername': classifiername, + 'classifierparams': classifierparams, + 'classid': classes, + 'probability': probs } ) + yield msgs diff --git a/tests/docker-compose.yaml b/tests/docker-compose.yaml index ded0f2f8..2bbb865c 100644 --- a/tests/docker-compose.yaml +++ b/tests/docker-compose.yaml @@ -1,6 +1,6 @@ services: kafka-zookeeper: - image: registry.nersc.gov/m1727/raknop/kafka + image: ${TOM_DOCKER_ARCHIVE:-ghcr.io/lsstdesc/tom_desc}/kafka:${TOM_DOCKER_VERSION:-latest} build: context: ../docker_kafka healthcheck: @@ -14,7 +14,7 @@ services: depends_on: kafka-zookeeper: condition: service_healthy - image: registry.nersc.gov/m1727/raknop/kafka + image: ${TOM_DOCKER_ARCHIVE:-ghcr.io/lsstdesc/tom_desc}/kafka:${TOM_DOCKER_VERSION:-latest} build: context: ../docker_kafka healthcheck: @@ -24,17 +24,8 @@ services: retries: 5 entrypoint: [ "bin/kafka-server-start.sh", "config/server.properties" ] - fakebroker: - depends_on: - kafka-server: - condition: service_healthy - image: registry.nersc.gov/m1727/raknop/fakebroker - build: - context: . 
- dockerfile: Dockerfile.fakebroker - postgres: - image: registry.nersc.gov/m1727/raknop/tom-postgres + image: ${TOM_DOCKER_ARCHIVE:-ghcr.io/lsstdesc/tom_desc}/tom-postgres:${TOM_DOCKER_VERSION:-latest} build: context: ../docker_postgres target: tom-postgres @@ -45,7 +36,7 @@ services: retries: 5 mongodb: - image: registry.nersc.gov/m1727/rknop/tom-mongodb:latest + image: ${TOM_DOCKER_ARCHIVE:-ghcr.io/lsstdesc/tom_desc}/tom-mongodb:${TOM_DOCKER_VERSION:-latest} build: context: ../docker_mongodb environment: @@ -73,7 +64,7 @@ services: condition: service_healthy mongodb: condition: service_healthy - image: registry.nersc.gov/m1727/raknop/tom_desc_bindmount + image: ${TOM_DOCKER_ARCHIVE:-ghcr.io/lsstdesc/tom_desc}/tom_desc_bindmount:${TOM_DOCKER_VERSION:-latest} build: context: ../ dockerfile: docker_server/Dockerfile @@ -120,7 +111,7 @@ services: depends_on: createdb: condition: service_completed_successfully - image: registry.nersc.gov/m1727/raknop/tom_desc_bindmount + image: ${TOM_DOCKER_ARCHIVE:-ghcr.io/lsstdesc/tom_desc}/tom_desc_bindmount:${TOM_DOCKER_VERSION:-latest} build: context: ../ dockerfile: docker_server/Dockerfile @@ -154,79 +145,6 @@ services: MONGODB_ALERT_READER_PASSWORD: reader - brokerpoll: - depends_on: - createdb: - condition: service_completed_successfully - tom: - condition: service_started - fakebroker: - condition: service_started - image: registry.nersc.gov/m1727/raknop/tom_desc_bindmount - build: - context: ../ - dockerfile: docker_server/Dockerfile - target: tom-server-bindmount - volumes: - - type: bind - source: ../secrets - target: /secrets - - type: bind - source: ../tom_desc - target: /tom_desc - - type: volume - source: logs - target: /logs - environment: - LOGDIR: /logs - DB_NAME: tom_desc - DB_HOST: postgres - DB_USER: postgres - DB_PASS: fragile - DB_PORT: 5432 - entrypoint: [ "python", "manage.py", "brokerpoll2", "--do-test", "--grouptag", "elasticc2" ] - - - brokerpoll_fastdb_dev: - depends_on: - createdb: - condition: service_completed_successfully - tom: - condition: service_started - fakebroker: - condition: service_started - image: registry.nersc.gov/m1727/raknop/tom_desc_bindmount - build: - context: ../ - dockerfile: docker_server/Dockerfile - target: tom-server-bindmount - volumes: - - type: bind - source: ../secrets - target: /secrets - - type: bind - source: ../tom_desc - target: /tom_desc - - type: volume - source: logs - target: /logs - environment: - LOGDIR: /logs - MONGOHOST: mongodb - MONGODB_ADMIN: mongodb_admin - MONGODB_ADMIN_PASSWORD: fragile - MONGODB_ALERT_WRITER: mongodb_alert_writer - MONGODB_ALERT_WRITER_PASSWORD: writer - MONGODB_ALERT_READER: mongdb_alert_reader - MONGODB_ALERT_READER_PASSWORD: reader - DB_NAME: tom_desc - DB_HOST: postgres - DB_USER: postgres - DB_PASS: fragile - DB_PORT: 5432 - entrypoint: [ "python", "manage.py", "fastdb_dev_brokerpoll", "--do-test", "--grouptag", "fastdb_dev" ] - - # Thought required: want to make this dependent on # createdb completed successfully, or just on the # database servers being up? 
The advantage of the latter @@ -241,13 +159,15 @@ services: condition: service_healthy tom: condition: service_started - fakebroker: - condition: service_started - brokerpoll: - condition: service_started - brokerpoll_fastdb_dev: - condition: service_started - image: registry.nersc.gov/m1727/raknop/tom_server_bindmount_dev + kafka-server: + condition: service_healthy + # fakebroker: + # condition: service_started + # brokerpoll: + # condition: service_started + # brokerpoll_fastdb_dev: + # condition: service_started + image: ${TOM_DOCKER_ARCHIVE:-ghcr.io/lsstdesc/tom_desc}/tom_server_bindmount_dev:${TOM_DOCKER_VERSION:-latest} build: context: ../ dockerfile: docker_server/Dockerfile diff --git a/tests/elasticc2_alertcycle_complete.psqlc b/tests/elasticc2_alertcycle_complete.psqlc new file mode 100644 index 00000000..5c8b4632 Binary files /dev/null and b/tests/elasticc2_alertcycle_complete.psqlc differ diff --git a/tests/fakebroker.py b/tests/fakebroker.py index 39a53875..dafcc7ef 100644 --- a/tests/fakebroker.py +++ b/tests/fakebroker.py @@ -10,18 +10,18 @@ import confluent_kafka import fastavro -from msgconsumer import MsgConsumer +sys.stderr.write( "FAKEBROKER importing testmsgconsumer\n" ) +from testmsgconsumer import MsgConsumer _rundir = pathlib.Path( __file__ ).parent -_logger = logging.getLogger( __name__ ) +_logger = logging.getLogger( "fakebroker" ) _logger.propagate = False -if not _logger.hasHandlers(): - _logout = logging.StreamHandler( sys.stderr ) - _logger.addHandler( _logout ) - _formatter = logging.Formatter( f'[%(asctime)s - %(levelname)s] - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' ) - _logout.setFormatter( _formatter ) +_logout = logging.StreamHandler( sys.stderr ) +_logger.addHandler( _logout ) +_formatter = logging.Formatter( f'[%(asctime)s - fakebroker - %(levelname)s] - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' ) +_logout.setFormatter( _formatter ) _logger.setLevel( logging.INFO ) # ====================================================================== @@ -44,7 +44,7 @@ def __init__( self, brokername, brokerversion, classifiername, classifierparams, self.nextlog = self.logevery def determine_types_and_probabilities( self, alert ): - """Given an alert (a dict in the format of the elasticc alert schema), return a list of + """Given an alert (a dict in the format of the elasticc alert schema), return a list of two-element tuples that is (classId, probability).""" raise RuntimeError( "Need to implement this function in a subclass!" 
) @@ -111,7 +111,79 @@ def determine_types_and_probabilities( self, alert ): retval.append( ( 2242, 1-totprob ) ) return retval -# ====================================================================== +# ====================================================================== + +class FakeBroker: + def __init__( self, + source, + source_topics, + dest, + dest_topic, + group_id="rknop-test", + alert_schema=f"{_rundir}/schema/elasticc.v0_9_1.alert.avsc", + brokermessage_schema=f"{_rundir}/schema/elasticc.v0_9_1.brokerClassification.avsc", + reset=False ): + self.source = source + self.source_topics = source_topics + self.dest = dest + self.dest_topic = dest_topic + self.group_id = group_id + self.reset = reset + + self.alert_schema = alert_schema + alertschemaobj = fastavro.schema.load_schema( alert_schema ) + brokermsgschema = fastavro.schema.load_schema( brokermessage_schema ) + self.classifiers = [ NugentClassifier( kafkaserver=self.dest, topic=self.dest_topic, + alertschema=alertschemaobj, brokermessageschema=brokermsgschema ), + RandomSNType( kafkaserver=self.dest, topic=self.dest_topic, + alertschema=alertschemaobj, brokermessageschema=brokermsgschema ) + ] + + def handle_message_batch( self, msgs ): + for cfer in self.classifiers: + cfer.classify_alerts( msgs ) + + def __call__( self ): + if self.reset: + topicstoreset = set( self.source_topics ) + else: + topicstoreset = set() + consumer = None + while True: + subbed = [] + if consumer is not None: + consumer.close() + consumer = MsgConsumer( self.source, self.group_id, [], self.alert_schema, logger=_logger, + consume_nmsgs=100 ) + # Wait for the topic to exist, and only then subscribe + while len(subbed) == 0: + topics = consumer.topic_list() + _logger.debug( f"Topics seen on server: {topics}" ) + for topic in self.source_topics: + if topic in topics: + subbed.append( topic ) + if len(subbed) > 0: + _logger.debug( f"Subscribing to topics {subbed}" ) + if len(subbed) < len( self.source_topics ): + missing = [ i for i in self.source_topics if i not in subbed ] + _logger.debug( f"(Didn't see topics: {missing})" ) + consumer.subscribe( subbed ) + else: + _logger.warning( f"No topics in {self.source_topics} exists, sleeping 10s and trying again." ) + time.sleep( 10 ) + + if len(topicstoreset) > 0: + for topic in subbed: + if topic in topicstoreset: + consumer.reset_to_start( topic ) + topicstoreset.remove( topic ) + + consumer.poll_loop( handler=self.handle_message_batch, + stoponnomessages=(len(subbed) 0: - _logger.info( f"Subscribing to topics {subbed}" ) - if len(subbed) < len( args.source_topics ): - missing = [ i for i in args.source_topics if i not in subbed ] - _logger.info( f"(Didn't see topics: {missing})" ) - consumer.subscribe( subbed ) - else: - _logger.warning( f"No topics in {args.source_topics} exists, sleeping 10s and trying again." 
) - time.sleep( 10 ) - - if len(topicstoreset) > 0: - for topic in subbed: - if topic in topicstoreset: - consumer.reset_to_start( topic ) - topicstoreset.remove( topic ) - - consumer.poll_loop( handler = handle_message_batch, stoponnomessages=(len(subbed)= nextprint: - t1 = time.perf_counter() - _logger.info( f"Cass: loaded {n} of {len(lots_of_alerts)} alerts " - f"in {t1-tinit:.2f} sec (current rate {printevery/(t1-t0):.0f} s⁻¹)" ) - nextprint += printevery - t0 = t1 - - t1 = time.perf_counter() - _logger.info( f"Cass: done loading {len(lots_of_alerts)} alerts in {t1-tinit:.2f} sec " - f"({len(lots_of_alerts)/(t1-tinit):.0f} s⁻¹)" ) - - yield lots_of_alerts - - cur = django.db.connection.cursor() - cur.execute( "TRUNCATE TABLE elasticc2_brokersourceids" ) - cur.execute( "TRUNCATE TABLE elasticc2_brokerclassifier" ) - casscur = django.db.connections['cassandra'].connection.cursor() - casscur.execute( "TRUNCATE TABLE tom_desc.cass_broker_message_by_time" ) - casscur.execute( "TRUNCATE TABLE tom_desc.cass_broker_message_by_source" ) - - def counter( self, rows ): - self.tot += len( rows ) - if self.future.has_more_pages: - self.future.start_fetching_next_page() - else: - self.done.set() - - def err( self, exc ): - self.error = exc - self.done.set() - - def test_cass( self, load_cass ): - # This doesn't work. I think we're getting paged or - # something. The django cassandra interface - # dissapoints me - # bysrc = CassBrokerMessageBySource.objects.all() - # assert len(bysrc) == len(load_cass) - # bytim = CassBrokerMessageByTime.objects.all() - # assert len(bytim) == len(load_cass ) - - casssession = django.db.connections['cassandra'].connection.session - - self.tot = 0 - self.error = None - self.done = threading.Event() - self.future = casssession.execute_async( "SELECT * FROM tom_desc.cass_broker_message_by_source" ) - self.future.add_callbacks( callback=self.counter, errback=self.err ) - self.done.wait() - if self.error: - raise self.error - assert self.tot == len(load_cass) - - self.tot = 0 - self.error = None - self.done = threading.Event() - self.future = casssession.execute_async( "SELECT * FROM tom_desc.cass_broker_message_by_time" ) - self.future.add_callbacks( callback=self.counter, errback=self.err ) - self.done.wait() - if self.error: - raise self.error - assert self.tot == len(load_cass) - - def test_gen_brokerdelay( self, load_cass ): - t0 = ( datetime.datetime.now( tz=pytz.utc ) + datetime.timedelta( days=-1 ) ).date().isoformat() - t1 = ( datetime.datetime.now( tz=pytz.utc ) + datetime.timedelta( days=2 ) ).date().isoformat() - dt = 1 - - tstart = time.perf_counter() - res = subprocess.run( [ "python", "manage.py", "gen_elasticc2_brokerdelaygraphs", - "--t0", t0, "--t1", t1, "--dt", str(dt) ], - cwd="/tom_desc", capture_output=True ) - assert res.returncode == 0 - tend = time.perf_counter() - _logger.info( f"gen_elasticc2_brokerdelaygraphs took {tend-tstart:.1f} sec" ) - - class TestPostgres: @pytest.fixture(scope='class') diff --git a/tests/test_db_createable.py b/tests/test_db_createable.py new file mode 100644 index 00000000..49c44c2b --- /dev/null +++ b/tests/test_db_createable.py @@ -0,0 +1,3 @@ +# TODO +# +# Write tests for db/models.py::Creatable diff --git a/tests/test_db_sqlquery.py b/tests/test_db_sqlquery.py index dce57f63..2f4343dc 100644 --- a/tests/test_db_sqlquery.py +++ b/tests/test_db_sqlquery.py @@ -18,12 +18,6 @@ import db.models -# Use the last alertcycle fixture api_classify_existing_alerts to get -# the database into a known state. 
This is kind of slow, due to all the -# delays in the various polling servers, so it means a ~1 min delay on -# the tests here actually running, but what can you do. -from alertcyclefixtures import * - class TestSQLWebInterface: def test_run_query( self, tomclient ): @@ -125,7 +119,7 @@ def test_run_query( self, tomclient ): # having been run, in order. @pytest.fixture( scope='class' ) - def submit_long_query( self, tomclient, alert_cycle_complete ): + def submit_long_query( self, tomclient, elasticc2_database_snapshot_class ): res = tomclient.post( 'db/submitsqlquery/', json= { 'query': 'SELECT * FROM elasticc2_brokermessage', 'format': 'pandas' } ); diff --git a/tests/test_elasticc2_alertcycle.py b/tests/test_elasticc2_alertcycle.py index 196d9453..44070960 100644 --- a/tests/test_elasticc2_alertcycle.py +++ b/tests/test_elasticc2_alertcycle.py @@ -1,14 +1,3 @@ -# WARNING -- if you run both this test and test_fastdb_dev_alertcycle -# within the same docker compose session, but different pytest -# sessions, one will fail. For the reason, see the comments in -# alertcyclefixtures.py. (Basically, the first one you run will load -# up both databases, so early tests that expect not-fully-loaded -# databases will fail.) -# -# Both should all pass if you run them both at once, i.e. -# -# pytest -v test_elasticc2_alertcycle.py test_fastdb_dev_alertcycle.py - import os import sys import datetime @@ -17,22 +6,28 @@ sys.path.insert( 0, "/tom_desc" ) import elasticc2.models +import tom_targets.models -from msgconsumer import MsgConsumer - -# pytest is mysterious. I tried importing just the fixtures I was using -# form alertcyclefixtures, but the a fixture there that used another -# fixture from alertcyclefixtures that I did *not* import here couldn't -# find that other fixture. So, I import *, and use an __all__ in -# alertcyclefixtures. -from alertcyclefixtures import * +from alertcycle_testbase import AlertCycleTestBase # NOTE -- many of the actual tests are run in the fixtures rather than -# the tests below. See comments in alercyclefixtures.py for the reason for +# the tests below. See comments in alertcycle_testbase.py for the reason for # this. -class TestElasticc2AlertCycle: - def test_ppdb_loaded( self, elasticc2_ppdb ): +class TestElasticc2AlertCycle( AlertCycleTestBase ): + _models_to_cleanup = [ elasticc2.models.BrokerMessage, + elasticc2.models.BrokerClassifier, + elasticc2.models.BrokerSourceIds, + elasticc2.models.DiaObjectOfTarget, + tom_targets.models.Target, + elasticc2.models.DiaForcedSource, + elasticc2.models.DiaSource, + elasticc2.models.DiaObject ] + + def _cleanup( self ): + pass + + def test_ppdb_loaded( self, elasticc2_ppdb_class ): # I should probably have some better tests than just object counts.... 
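The AlertCycleTestBase referenced here is not part of this diff; as a rough sketch of the pattern it implies -- an assumption about alertcycle_testbase.py, not its actual contents -- subclasses list their models child-table-first and the base class tears everything down once per test class:

import pytest

class AlertCycleTestBase:
    # Subclasses list models to purge, children before parents, so that
    # foreign-key constraints are respected on delete.
    _models_to_cleanup = []

    def _cleanup( self ):
        # Hook for non-ORM cleanup (e.g. dropping a mongo collection).
        pass

    @pytest.fixture( scope="class", autouse=True )
    def alertcycle_cleanup( self ):
        yield True
        for model in self._models_to_cleanup:
            model.objects.all().delete()
        self._cleanup()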
assert elasticc2.models.PPDBDiaObject.objects.count() == 346 assert elasticc2.models.PPDBDiaSource.objects.count() == 1862 @@ -41,10 +36,6 @@ def test_ppdb_loaded( self, elasticc2_ppdb ): assert elasticc2.models.DiaObjectTruth.objects.count() == elasticc2.models.PPDBDiaObject.objects.count() - def handle_test_send_alerts( self, msgs ): - self._test_send_alerts_count += len(msgs) - - def test_send_alerts( self, alerts_300days ): assert alerts_300days @@ -71,8 +62,6 @@ def test_100moredays_sources_updated( self, update_elasticc2_diasource_100daysmo def test_apibroker_existingsources( self, api_classify_existing_alerts ): cfer = elasticc2.models.BrokerClassifier - # brkmsgsrc = elasticc2.models.CassBrokerMessageBySource - # brkmsgtim = elasticc2.models.CassBrokerMessageByTime brkmsg = elasticc2.models.BrokerMessage assert cfer.objects.count() == 3 @@ -107,4 +96,3 @@ def test_apibroker_existingsources( self, api_classify_existing_alerts ): assert onemsg.msghdrtimestamp >= onemsg.brokeringesttimestamp assert onemsg.msghdrtimestamp - onemsg.brokeringesttimestamp < datetime.timedelta(seconds=5) assert onemsg.descingesttimestamp - onemsg.msghdrtimestamp < datetime.timedelta(seconds=5) - diff --git a/tests/test_elasticc2_api.py b/tests/test_elasticc2_api.py index 20580eea..a52cb1d7 100644 --- a/tests/test_elasticc2_api.py +++ b/tests/test_elasticc2_api.py @@ -3,7 +3,7 @@ class TestReconstructAlert: - def test_reconstruct_alert( self, elasticc2_ppdb ): + def test_reconstruct_alert( self, elasticc2_ppdb_class ): # Get an alert I know is from the first day, make sure that # it doesn't have any previous forced sources @@ -85,7 +85,7 @@ def test_reconstruct_alert( self, elasticc2_ppdb ): - def test_alert_api( self, elasticc2_ppdb, tomclient ): + def test_alert_api( self, elasticc2_ppdb_class, tomclient ): res = tomclient.post( "elasticc2/getalert", json={ 'alertid': 666 } ) assert res.status_code == 500 @@ -109,7 +109,7 @@ def test_alert_api( self, elasticc2_ppdb, tomclient ): class TestLtcv: - def test_ltcv_features( self, elasticc2_ppdb, tomclient ): + def test_ltcv_features( self, elasticc2_ppdb_class, tomclient ): # Default features for an object diff --git a/tests/test_elasticc2_cassandra.py b/tests/test_elasticc2_cassandra.py deleted file mode 100644 index c20569f1..00000000 --- a/tests/test_elasticc2_cassandra.py +++ /dev/null @@ -1,86 +0,0 @@ -import pytest -import os -import sys -import random -import datetime -import pytz - -# The next 4 lines are necessary to get access to -# django models without doing python manage.py shell -sys.path.insert( 0, "/tom_desc" ) -os.environ["DJANGO_SETTINGS_MODULE"] = "tom_desc.settings" -import django -django.setup() - -import django.db - -import elasticc2.models as m - -class TestElasticc2Cassandra: - - #TODO lots more - - @pytest.fixture(scope='class') - def load_some( self ): - brokers = { - 'test1': { - '1.0': { - 'classifiertest1': [ '1.0' ], - 'classifiertest2': [ '1.0' ] - } - }, - 'test2': { - '3.5': { - 'testing1': [ '42' ], - 'testing2': [ '23' ] - } - } - } - - minsrc = 10 - maxsrc = 20 - mincls = 1 - maxcls = 20 - - msgs = [] - classifier_id = 0 - for brokername, brokerspec in brokers.items(): - for brokerversion, versionspec in brokerspec.items(): - for classifiername, clsspec in versionspec.items(): - for classifierparams in clsspec: - nsrcs = random.randint( minsrc, maxsrc ) - for src in range(nsrcs): - ncls = random.randint( mincls, maxcls ) - probleft = 1.0 - classes = [] - probs = [] - for cls in range( ncls ): - classes.append( cls ) - prob 
= random.random() * probleft - probleft -= prob - probs.append( prob ) - classes.append( ncls ) - probs.append( probleft ) - - msg = m.CassBrokerMessage( sourceid=src, - classifier_id=classifier_id, - alertid=src, - elasticcpublishtimestmp=datetime.datetime.now( tz=pytz.utc ), - brokeringesttimestamp=datetime.datetime.now( tz=pytz.utc ), - classid=classes, - probability=probs ) - msgs.append( msg ) - - classifier_id += 1 - - m.CassBrokerMessage.objects.bulk_create( msgs ) - - yield true - - conn = django.db.connections['cassandra'].cursor().connection - conn.execute( "TRUNCATE TABLE tom_desc.cass_broker_message" ) - - @pytest.mark.xfail( reason="Fix fixture to not use no-longer-existing bulk_create" ) - def test_some_loaded( self, load_some ): - import pdb; pdb.set_trace() - assert m.CassBrokerMessage.objects.count() > 0 diff --git a/tests/test_elasticc2_dump_alert_tar.py b/tests/test_elasticc2_dump_alert_tar.py index 33f56730..8b361ea1 100644 --- a/tests/test_elasticc2_dump_alert_tar.py +++ b/tests/test_elasticc2_dump_alert_tar.py @@ -11,7 +11,7 @@ from alertcyclefixtures import * class TestDumpAlertTar: - def test_dump_tar( self, elasticc2_ppdb ): + def test_dump_tar( self, elasticc2_ppdb_class ): try: # Just make sure things are as expected assert elasticc2.models.PPDBDiaObject.objects.count() == 346 @@ -29,8 +29,6 @@ def test_dump_tar( self, elasticc2_ppdb ): cwd="/tom_desc", capture_output=True ) assert res.returncode == 0 - import pdb; pdb.set_trace() - # Check that the expected tar files exist, and spot-check one assert all( pathlib.Path( f"/tests/{i}.tar" ).is_file() for i in [ 60278, 60279, 60281, 60282, 60283, 60286 ] ) diff --git a/tests/test_elasticc2_load_snana_fits.py b/tests/test_elasticc2_load_snana_fits.py index 46a2e4fe..663d70f7 100644 --- a/tests/test_elasticc2_load_snana_fits.py +++ b/tests/test_elasticc2_load_snana_fits.py @@ -15,12 +15,31 @@ class TestLoadSnanaFits: - def test_ppdb_loaded( self, elasticc2_ppdb ): + @pytest.fixture( scope="class" ) + def snana_loaded_elasticc2_ppdb( tomclient ): + basedir = pathlib.Path( "/elasticc2data" ) + dirs = [] + for subdir in basedir.glob( '*' ): + if subdir.is_dir(): + result = subprocess.run( [ "python", "manage.py", "load_snana_fits", + "-d", str(subdir), "--ppdb", "--do" ], + cwd="/tom_desc", capture_output=True ) + assert result.returncode == 0 + + yield True + + m.DiaObjectTruth.objects.all().delete() + m.PPDBAlert.objects.all().delete() + m.PPDBDiaForcedSource.objects.all().delete() + m.PPDBDiaSource.objects.all().delete() + m.PPDBDiaObject.objects.all().delete() + + def test_ppdb_loaded( self, snana_loaded_elasticc2_ppdb ): # I should probably have some better tests than just object counts.... 
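One cheap way to go beyond bare row counts, offered here only as a suggestion (the diaobject_id field name is an assumption based on how these models are queried elsewhere in the tests), is a referential-integrity spot check:

import elasticc2.models as m

# Every loaded PPDB source should point at a PPDB object that was also loaded.
loaded_objects = set( m.PPDBDiaObject.objects.values_list( 'diaobject_id', flat=True ) )
referenced_objects = set( m.PPDBDiaSource.objects.values_list( 'diaobject_id', flat=True ) )
assert referenced_objects.issubset( loaded_objects )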
- assert m.PPDBDiaObject.objects.count() == 138 - assert m.PPDBDiaSource.objects.count() == 429 + assert m.PPDBDiaObject.objects.count() == 346 + assert m.PPDBDiaSource.objects.count() == 1862 assert m.PPDBAlert.objects.count() == m.PPDBDiaSource.objects.count() - assert m.PPDBDiaForcedSource.objects.count() == 34284 + assert m.PPDBDiaForcedSource.objects.count() == 52172 assert m.DiaObjectTruth.objects.count() == m.PPDBDiaObject.objects.count() @@ -37,6 +56,8 @@ def count_ppdb( self ): @pytest.fixture( scope="class" ) def elasticc2_training( self, count_ppdb ): + # Loding in exactly the same data for test purposes, + # just to differnt tables basedir = pathlib.Path( "/elasticc2data" ) dirs = [] for subdir in basedir.glob( '*' ): @@ -61,9 +82,9 @@ def test_training_tables_loaded( self, elasticc2_training ): assert m.PPDBDiaForcedSource.objects.count() == self.__class__._ppdbdiaforcedsources assert m.DiaObjectTruth.objects.count() == self.__class__._ppdbdiaobjecttruths - assert m.TrainingDiaObject.objects.count() == 138 - assert m.TrainingDiaSource.objects.count() == 429 - assert m.TrainingDiaForcedSource.objects.count() == 34284 + assert m.TrainingDiaObject.objects.count() == 346 + assert m.TrainingDiaSource.objects.count() == 1862 + assert m.TrainingDiaForcedSource.objects.count() == 52172 assert m.TrainingAlert.objects.count() == m.TrainingDiaSource.objects.count() assert ( m.TrainingDiaObjectTruth.objects.count() == m.TrainingDiaObject.objects.count() ) diff --git a/tests/test_elasticc2_models.py b/tests/test_elasticc2_models.py index 954ad6e6..22c32648 100644 --- a/tests/test_elasticc2_models.py +++ b/tests/test_elasticc2_models.py @@ -4,9 +4,7 @@ import dateutil.parser import pytz -from elasticc2.models import ( CassBrokerMessageBySource, - CassBrokerMessageByTime, - BrokerSourceIds, +from elasticc2.models import ( BrokerSourceIds, BrokerClassifier, BrokerMessage ) @@ -59,43 +57,10 @@ def loaded_broker_classifications( self, alerts ): BrokerMessage.objects.filter( classifier_id__in=[ i.classifier_id for i in cfers ] ).delete() - @pytest.fixture( scope='class' ) - def loaded_cass_broker_classifications( self, alerts ): - # Some of the hardcodes here depend on details of what's inside - # random_broker_classifications in conftest.py - yield CassBrokerMessageBySource.load_batch( alerts ) - - # Hardcoded from what I know is in random_broker_classifications - cfers = BrokerClassifier.objects.filter( brokername__in=[ 'rbc_test1', 'rbc_test2'] ) - - # This doesn't work; Cassandra is picky about deleting stuff in bulk - # CassBrokerMessage.objects.filter( brokername__in=[ 'rbc_test1', 'rbc_test2' ] ).delete() - # So we do the slow thing, which will be OK given the small number of messages - msgs = CassBrokerMessageByTime.objects.filter( classifier_id__in=[ i.classifier_id for i in cfers ] ) - for msg in msgs: - msg.delete() - msgs = CassBrokerMessageBySource.objects.filter( classifier_id__in= [ i.classifier_id for i in cfers ] ) - for msg in msgs: - msg.delete() - def test_hello_world( self ): # This is just here so I can get a timestamp to see how long the next test took assert True - # def test_alert_reconstruct( self, elasticc2_ppdb ): - # pass - - def test_cassbrokermessage_bulk( self, loaded_cass_broker_classifications ): - assert CassBrokerMessageBySource.objects.count() >= loaded_cass_broker_classifications[ 'addedmsgs' ] - assert CassBrokerMessageBySource.objects.count() == CassBrokerMessageByTime.objects.count() - cfers = BrokerClassifier.objects.filter( brokername__in=[ 'rbc_test1', 
'rbc_test2' ] ) - msgs = CassBrokerMessageBySource.objects.filter( classifier_id__in=[ i.classifier_id for i in cfers ] ) - assert msgs.count() == loaded_cass_broker_classifications[ 'addedmsgs' ] - sources = set() - for msg in msgs.all(): - sources.add( msg.diasource_id ) - assert sources.issubset( set( [ b.diasource_id for b in BrokerSourceIds.objects.all() ] ) ) - def test_brokermessage_bulk( self, loaded_broker_classifications ): assert BrokerMessage.objects.count() >= loaded_broker_classifications[ 'addedmsgs' ] cfers = BrokerClassifier.objects.filter( brokername__in=[ 'rbc_test1', 'rbc_test2' ] ) diff --git a/tests/test_elasticc2_spectrumcycle.py b/tests/test_elasticc2_spectrumcycle.py index b0a87236..f73f18d9 100644 --- a/tests/test_elasticc2_spectrumcycle.py +++ b/tests/test_elasticc2_spectrumcycle.py @@ -19,16 +19,16 @@ import elasticc2.models from tom_client import TomClient -from alertcyclefixtures import * - # I'm abusing pytest here by having tests depend on previous # tests, rather than making all dependencies fixtures. # I may fix that at some point. +# (But, truthfully, pytest abuses python so badly that +# you might as well just embrace it.) class TestSpectrumCycle: @pytest.fixture( scope='class' ) - def ask_for_spectra( self, update_diasource_100daysmore, tomclient ): + def ask_for_spectra( self, elasticc2_database_snapshot_class, tomclient ): objs = elasticc2.models.DiaObject.objects.all().order_by("diaobject_id") objs = list( objs ) @@ -51,7 +51,7 @@ def ask_for_spectra( self, update_diasource_100daysmore, tomclient ): # TODO : test things other than detected_since_mjd sent to gethottransients - def test_hot_sne( self, update_diasource_100daysmore, tomclient ): + def test_hot_sne( self, elasticc2_database_snapshot_class, tomclient ): # Testing detected_in_last_days is fraught because # the mjds in elasticc2 are what they are, are # in the future (as of this comment writing). @@ -59,10 +59,10 @@ def test_hot_sne( self, update_diasource_100daysmore, tomclient ): res = tomclient.post( 'elasticc2/gethottransients', json={ 'detected_since_mjd': 60660 } ) sne = res.json()['diaobject'] - assert len(sne) == 5 + assert len(sne) == 8 snids = { s['objectid'] for s in sne } - assert snids == { 15232, 416626, 1263066, 1286131, 1913410 } + assert snids == { 15232, 1913410, 2110476, 416626, 1286131, 1684659, 1045654, 1263066 } # Should probably check more than this... 
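Pinning an absolute detected_since_mjd, as this test does, sidesteps the problem that a relative window would move with the wall clock while the simulated MJDs stay put; a small illustration of the difference (assuming astropy is available; the cutoff value is the one used above):

from astropy.time import Time

# A relative window drifts as real time passes...
mjd_now = Time.now().mjd
moving_cutoff = mjd_now - 30      # "detected in the last 30 days"

# ...whereas an absolute cutoff inside the simulated campaign never moves,
# so the expected object list stays stable.
fixed_cutoff = 60660              # the detected_since_mjd used above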
assert set( sne[0].keys() ) == { 'objectid', 'ra', 'dec', 'photometry', 'zp', 'redshift', 'sncode' } @@ -78,6 +78,28 @@ def test_ask_for_spectra( self, ask_for_spectra, tomclient ): assert wnt.requester == "tests" assert wnt.priority == prio + # Verify that if we ask again for a spectrum, it overwrites the previous request + lstobj = objs[-1] + lstprio = prios[-1] + newprio = lstprio + 1 if lstprio < 5 else 1 + + oldn = elasticc2.models.WantedSpectra.objects.count() + old_dbobj = elasticc2.models.WantedSpectra.objects.filter( requester='tests', + diaobject_id=lstobj )[0] + assert old_dbobj.priority == lstprio + + res = tomclient.post( 'elasticc2/askforspectrum', json={ 'requester': 'tests', + 'objectids': [ lstobj ], + 'priorities': [ newprio ] } ) + assert res.status_code == 200 + + assert elasticc2.models.WantedSpectra.objects.count() == oldn + dbobjs = elasticc2.models.WantedSpectra.objects.filter( requester='tests', + diaobject_id=lstobj ) + assert len( dbobjs ) == 1 + assert dbobjs[0].wanttime > old_dbobj.wanttime + assert dbobjs[0].priority == newprio + def test_what_are_wanted_initial( self, ask_for_spectra, tomclient ): objs, prios = ask_for_spectra @@ -287,3 +309,10 @@ def test_get_spectrum_info( self, ask_for_spectra, tomclient ): assert data['spectra'][0]['mjd'] == pytest.approx( 65536., abs=0.01 ) assert data['spectra'][0]['z'] == pytest.approx( 0.25, abs=0.01 ) assert data['spectra'][0]['classid'] == 2222 + + + def test_cleanup( self ): + # Lots of previous tests left stuff in the database. Clean it out. + elasticc2.models.SpectrumInfo.objects.all().delete() + elasticc2.models.WantedSpectra.objects.all().delete() + elasticc2.models.PlannedSpectra.objects.all().delete() diff --git a/tests/test_fastdb_dev_alertcycle.py b/tests/test_fastdb_dev_alertcycle.py index ffc8ef85..1aa84244 100644 --- a/tests/test_fastdb_dev_alertcycle.py +++ b/tests/test_fastdb_dev_alertcycle.py @@ -1,38 +1,45 @@ -# WARNING -- if you run both this test and test_elasticc2_alertcycle -# within the same docker compose session, but different pytest -# sessions, one will fail. For the reason, see the comments in -# alertcyclefixtures.py. (Basically, the first one you run will load -# up both databases, so early tests that expect not-fully-loaded -# databases will fail.) -# -# Both should all pass if you run them both at once, i.e. -# -# pytest -v test_elasticc2_alertcycle.py test_fastdb_dev_alertcycle.py - import os import sys import datetime import time +from pymongo import MongoClient + sys.path.insert( 0, "/tom_desc" ) +import fastdb_dev.models import elasticc2.models -from msgconsumer import MsgConsumer - -# pytest is mysterious. I tried importing just the fixtures I was using -# form alertcyclefixtures, but the a fixture there that used another -# fixture from alertcyclefixtures that I did *not* import here couldn't -# find that other fixture. So, I import *, and use an __all__ in -# alertcyclefixtures. -from alertcyclefixtures import * +from alertcycle_testbase import AlertCycleTestBase # NOTE -- many of the actual tests are run in the fixtures rather than -# the tests below. See comments in alercyclefixtures.py for the reason for +# the tests below. See comments in alertcycle_testbase.py for the reason for # this. 
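The spectrum re-request behaviour exercised in test_elasticc2_spectrumcycle.py above (asking again for the same object replaces the old row's priority and wanttime instead of adding a second row) is a standard update-or-create pattern; the following sketch illustrates the semantics with the Django ORM and is not the actual askforspectrum implementation:

import datetime
import elasticc2.models

def upsert_wanted_spectrum( requester, diaobject_id, priority ):
    # One row per (requester, object); asking again refreshes priority and wanttime.
    wanted, created = elasticc2.models.WantedSpectra.objects.update_or_create(
        requester=requester,
        diaobject_id=diaobject_id,
        defaults={ 'priority': priority,
                   'wanttime': datetime.datetime.now( tz=datetime.timezone.utc ) } )
    return wanted, created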
-class TestFastDBDevAlertCycle: - def test_ppdb_loaded( self, elasticc2_ppdb ): +class TestFastDBDevAlertCycle( AlertCycleTestBase ): + _models_to_cleanup = [ fastdb_dev.models.BrokerClassification, + fastdb_dev.models.BrokerClassifier, + fastdb_dev.models.DiaForcedSource, + fastdb_dev.models.DiaSource, + fastdb_dev.models.DiaObject, + fastdb_dev.models.DStoPVtoSS, + fastdb_dev.models.DFStoPVtoSS, + fastdb_dev.models.Snapshots, + fastdb_dev.models.ProcessingVersions ] + + def _cleanup( self ): + host = os.getenv( 'MONGOHOST' ) + username = os.getenv( 'MONGODB_ADMIN' ) + password = os.getenv( 'MONGODB_ADMIN_PASSWORD' ) + client = MongoClient( f"mongodb://{username}:{password}@{host}:27017/" ) + db = client.alerts + if 'fakebroker' in db.list_collection_names(): + coll = db.fakebroker + coll.drop() + assert 'fakebroker' not in db.list_collection_names() + + + def test_ppdb_loaded( self, elasticc2_ppdb_class ): # I should probably have some better tests than just object counts.... assert elasticc2.models.PPDBDiaObject.objects.count() == 346 assert elasticc2.models.PPDBDiaSource.objects.count() == 1862 @@ -41,10 +48,6 @@ def test_ppdb_loaded( self, elasticc2_ppdb ): assert elasticc2.models.DiaObjectTruth.objects.count() == elasticc2.models.PPDBDiaObject.objects.count() - def handle_test_send_alerts( self, msgs ): - self._test_send_alerts_count += len(msgs) - - def test_send_alerts( self, alerts_300days ): assert alerts_300days diff --git a/tests/test_load_e2_snapshot.py b/tests/test_load_e2_snapshot.py new file mode 100644 index 00000000..50ec3615 --- /dev/null +++ b/tests/test_load_e2_snapshot.py @@ -0,0 +1,16 @@ +import elasticc2.models + +def test_elasticc2_database_snapshot( elasticc2_database_snapshot ): + assert elasticc2.models.BrokerClassifier.objects.count() == 3 + assert elasticc2.models.BrokerMessage.objects.count() == 1950 + assert elasticc2.models.DiaForcedSource.objects.count() == 5765 + assert elasticc2.models.DiaObject.objects.count() == 131 + assert elasticc2.models.DiaObjectTruth.objects.count() == 346 + assert elasticc2.models.DiaSource.objects.count() == 650 + assert elasticc2.models.PPDBAlert.objects.count() == 1862 + assert elasticc2.models.PPDBDiaForcedSource.objects.count() == 52172 + assert elasticc2.models.PPDBDiaObject.objects.count() == 346 + assert elasticc2.models.PPDBDiaSource.objects.count() == 1862 + assert elasticc2.models.DiaObjectInfo.objects.count() == 268 + assert elasticc2.models.BrokerSourceIds.objects.count() == 650 + diff --git a/tests/msgconsumer.py b/tests/testmsgconsumer.py similarity index 74% rename from tests/msgconsumer.py rename to tests/testmsgconsumer.py index 9d9f71d3..ad37c393 100644 --- a/tests/msgconsumer.py +++ b/tests/testmsgconsumer.py @@ -9,13 +9,16 @@ import fastavro import confluent_kafka -_logger = logging.getLogger(__name__) +_logger = logging.getLogger( "testmsgconsumer" ) if not _logger.hasHandlers(): + sys.stderr.write( "ADDING HANDLER TO testmsgconsumer\n" ) _logout = logging.StreamHandler( sys.stderr ) _logger.addHandler( _logout ) - _formatter = logging.Formatter( f'[msgconsumer - %(asctime)s - %(levelname)s] - %(message)s', + _formatter = logging.Formatter( f'[%(asctime)s - msgconsumer - %(levelname)s] - %(message)s', datefmt='%Y-%m-%d %H:%M:%S' ) _logout.setFormatter( _formatter ) +else: + sys.stderr.write( "OMG I AM SURPRISED, testmsgconsumer already had handlers.\n" ) # _logger.setLevel( logging.INFO ) _logger.setLevel( logging.DEBUG ) @@ -75,7 +78,7 @@ def subscribe( self, topics ): if topics is not None and len(topics) 
> 0: self.consumer.subscribe( topics, on_assign=self._sub_callback ) else: - self.logger.warning( f'No topics given, not subscribing.' ) + self.logger.debug( f'No topics given, not subscribing.' ) def reset_to_start( self, topic ): self.logger.info( f'Resetting partitions for topic {topic}\n' ) @@ -84,26 +87,26 @@ def reset_to_start( self, topic ): self.logger.debug( "got throwaway message" if msg is not None else "did't get throwaway message" ) # Now do the reset partitions = self.consumer.list_topics( topic ).topics[topic].partitions + self.logger.debug( f"Found {len(partitions)} for topic {topic}" ) # partitions is a kmap - partlist = [] - # for partid, partinfo in partitions.items(): - # self.logger.info( f'...resetting {partid} ( {partinfo} )' ) - # # Is this next one redundant? partinfo should already have the right stuff! - # curpart = confluent_kafka.TopicPartition( topic, partinfo.id ) - for i in range(len(partitions)): - self.logger.info( f'...resetting partition {i}' ) - curpart = confluent_kafka.TopicPartition( topic, i ) - lowmark, highmark = self.consumer.get_watermark_offsets( curpart ) - self.logger.debug( f'Partition {curpart.topic} has id {curpart.partition} ' - f'and current offset {curpart.offset}; lowmark={lowmark} ' - f'and highmark={highmark}' ) - curpart.offset = lowmark - # curpart.offset = confluent_kafka.OFFSET_BEGINNING - if lowmark < highmark: - self.consumer.seek( curpart ) - partlist.append( curpart ) - self.logger.info( f'Committing partition offsets.' ) - self.consumer.commit( offsets=partlist, asynchronous=False ) + if len(partitions) > 0: + partlist = [] + for i in range(len(partitions)): + self.logger.info( f'...resetting partition {i}' ) + curpart = confluent_kafka.TopicPartition( topic, i ) + lowmark, highmark = self.consumer.get_watermark_offsets( curpart ) + self.logger.debug( f'Partition {curpart.topic} has id {curpart.partition} ' + f'and current offset {curpart.offset}; lowmark={lowmark} ' + f'and highmark={highmark}' ) + curpart.offset = lowmark + # curpart.offset = confluent_kafka.OFFSET_BEGINNING + if lowmark < highmark: + self.consumer.seek( curpart ) + partlist.append( curpart ) + self.logger.info( f'Committing partition offsets.' ) + self.consumer.commit( offsets=partlist, asynchronous=False ) + else: + self.logger.info( f"Resetting partitions: no partitions found, hope that means we're already reset...!" ) def topic_list( self ): cluster_meta = self.consumer.list_topics() @@ -114,7 +117,7 @@ def topic_list( self ): def print_topics( self ): topics = self.topic_list() topicstxt = '\n '.join(topics) - self.logger.info( f"\nTopics:\n {topicstxt}" ) + self.logger.debug( f"\nTopics:\n {topicstxt}" ) def _get_positions( self, partitions ): return self.consumer.position( partitions ) @@ -130,7 +133,7 @@ def print_assignments( self ): ofp = io.StringIO() ofp.write( "Current partition assignments\n" ) self._dump_assignments( ofp, asmgt ) - self.logger.info( ofp.getvalue() ) + self.logger.debug( ofp.getvalue() ) ofp.close() def _sub_callback( self, consumer, partitions ): @@ -138,7 +141,7 @@ def _sub_callback( self, consumer, partitions ): ofp = io.StringIO() ofp.write( "Consumer subscribed. 
Assigned partitions:\n" ) self._dump_assignments( ofp, self._get_positions( partitions ) ) - self.logger.info( ofp.getvalue() ) + self.logger.debug( ofp.getvalue() ) ofp.close() def poll_loop( self, handler=None, timeout=None, stopafter=datetime.timedelta(hours=1), @@ -150,35 +153,36 @@ def poll_loop( self, handler=None, timeout=None, stopafter=datetime.timedelta(ho done = False nsleeps = 0 while not done: - self.logger.info( f"Trying to consume {self.consume_nmsgs} messages " - f"with timeout {timeout} sec...\n" ) + self.logger.debug( f"Trying to consume {self.consume_nmsgs} messages " + f"with timeout {timeout} sec...\n" ) msgs = self.consumer.consume( self.consume_nmsgs, timeout=timeout ) if len(msgs) == 0: if ( stopafternsleeps is not None ) and ( nsleeps >= stopafternsleeps ): - self.logger.info( f"Stopping after {nsleeps} consecutive sleeps." ) + self.logger.debug( f"Stopping after {nsleeps} consecutive sleeps." ) done = True if stoponnomessages: - self.logger.info( f"No messages, ending poll_loop." ) + self.logger.debug( f"...no messages, ending poll_loop." ) done = True else: - self.logger.info( f"No messages, sleeping {self.nomsg_sleeptime} sec" ) + self.logger.debug( f"...no messages, sleeping {self.nomsg_sleeptime} sec" ) time.sleep( self.nomsg_sleeptime ) nsleeps += 1 else: + self.logger.debug( f"...got {len(msgs)} messages" ) nsleeps = 0 if handler is not None: handler( msgs ) else: self.default_handle_message_batch( msgs ) if (not done) and ( datetime.datetime.now() - t0 ) >= stopafter: - self.logger.info( f"Ending poll loop after {stopafter} seconds of polling." ) + self.logger.debug( f"Ending poll loop after {stopafter} seconds of polling." ) done = True def consume_one_message( self, timeout=None, handler=None ): """Both calls handler and returns a batch of 1 message.""" if timeout is None: timeout = self.consume_timeout - self.logger.info( f"Trying to consume one message with timeout {timeout} sec...\n" ) + self.logger.debug( f"Trying to consume one message with timeout {timeout} sec...\n" ) # msgs = self.consumer.consume( 1, timeout=self.consume_timeout ) msg = self.consumer.poll( timeout ) if msg is not None: @@ -189,7 +193,7 @@ def consume_one_message( self, timeout=None, handler=None ): return msg def default_handle_message_batch( self, msgs ): - self.logger.info( f'Handling {len(msgs)} messages' ) + self.logger.debug( f'Handling {len(msgs)} messages' ) timestamp_name = { confluent_kafka.TIMESTAMP_NOT_AVAILABLE: "TIMESTAMP_NOT_AVAILABLE", confluent_kafka.TIMESTAMP_CREATE_TIME: "TIMESTAMP_CREATE_TIME", confluent_kafka.TIMESTAMP_LOG_APPEND_TIME: "TIMESTAMP_LOG_APPEND_TIME" } @@ -209,9 +213,9 @@ def default_handle_message_batch( self, msgs ): # alert['brokerIngestTimestamp'] = alert['brokerIngestTimestamp'].timestamp() ofp.write( json.dumps( alert, indent=4, sort_keys=True, cls=DateTimeEncoder ) ) ofp.write( "\n" ) - self.logger.info( ofp.getvalue() ) + self.logger.debug( ofp.getvalue() ) ofp.close() self.tot_handled += len(msgs) - self.logger.info( f'Have handled {self.tot_handled} messages so far' ) + self.logger.debug( f'Have handled {self.tot_handled} messages so far' ) self.print_assignments() diff --git a/tom_desc/db/management/commands/_consumekafkamsgs.py b/tom_desc/db/management/commands/_consumekafkamsgs.py index b7795ae9..f2953236 100644 --- a/tom_desc/db/management/commands/_consumekafkamsgs.py +++ b/tom_desc/db/management/commands/_consumekafkamsgs.py @@ -9,13 +9,12 @@ import fastavro import confluent_kafka -_logger = logging.getLogger(__name__) -if not 
_logger.hasHandlers(): - _logout = logging.StreamHandler( sys.stderr ) - _logger.addHandler( _logout ) - _formatter = logging.Formatter( f'[msgconsumer - %(asctime)s - %(levelname)s] - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' ) - _logout.setFormatter( _formatter ) +_logger = logging.getLogger( "db/_consumekafkamsgs" ) +_logout = logging.StreamHandler( sys.stderr ) +_logger.addHandler( _logout ) +_formatter = logging.Formatter( f'[%(asctime)s - db/msgconsumer - %(levelname)s] - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' ) +_logout.setFormatter( _formatter ) _logger.setLevel( logging.INFO ) # _logger.setLevel( logging.DEBUG ) diff --git a/tom_desc/db/models.py b/tom_desc/db/models.py index d2e37fa4..e73af859 100644 --- a/tom_desc/db/models.py +++ b/tom_desc/db/models.py @@ -131,16 +131,43 @@ def which_exist( cls, pks ): # This version uses postgres COPY and tries to be faster than mucking # about with ORM constructs. @classmethod - def bulk_insert_onlynew( cls, data, kwmap=None ): - """Insert a bunch of data into the database. Ignores records that conflict with things present. + def bulk_insert_or_upsert( cls, data, kwmap=None, upsert=False, assume_no_conflict=False ): + """Try to efficiently insert a bunch of data into the database. - data can be: - * a dict of { kw: iterable }. All of the iterables must have the same length, - and must be something that pandas.DataFrame could handle - * a list of dicts. The keys in all dicts must be the same - data and kwmap will be run through data_to_createdict + Parameters + ---------- + data: dict or list + Data can be: + * a dict of { kw: iterable }. All of the iterables must + have the same length, and must be something that + pandas.DataFrame could handle + * a list of dicts. The keys in all dicts must be the same - Returns the number of rows actually inserted (which may be less than len(data)). + kwmap: dict, default None + A map of { dict_keyword : model_keyword } of conversions + from data to the class model. (See data_to_createdict().) + Defaults to the _create_kws_map property of the object. + + upsert: bool, default False + If False, then objects whose primary key is already in the + database will be ignored. If True, then objects whose + primary key is already in the database will be updated with + the values in dict. (SQL will have ON CONFLICT DO NOTHING + if False, ON CONFLICT DO UPDATE if True.) + + assume_no_conflict: bool, default False + Usually you just want to leave this False. There are + obscure kludge cases (e.g. if you're playing games and have + removed primary key constraints and you know what you're + doing-- this happens in load_snana_fits.py, for instance) + where the conflict clauses cause the sql to fail. Set this + to True to avoid having those clauses. + + + Returns + ------- + inserted: int + The number of rows actually inserted (which may be less than len(data)). """ conn = None @@ -165,6 +192,9 @@ def bulk_insert_onlynew( cls, data, kwmap=None ): # float our double, thereby losing precision. I've checked it, and it seems # to be doing the right thing. But I have had issues in the past with # pandas silently converting data to doubles. 
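A usage sketch for bulk_insert_or_upsert() as documented above; MyModel and its columns are placeholders rather than real tom_desc models, and only the keyword behaviour stated in the docstring is assumed:

# Hypothetical usage; MyModel stands in for any model providing this method.
rows = [ { 'thing_id': 1, 'value': 0.25 },
         { 'thing_id': 2, 'value': 0.75 } ]

# Skip rows whose primary key already exists (ON CONFLICT ... DO NOTHING):
n_inserted = MyModel.bulk_insert_or_upsert( rows )

# Overwrite existing rows with the new values (ON CONFLICT ... DO UPDATE):
n_loaded = MyModel.bulk_insert_or_upsert( rows, upsert=True )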
+ # **** + # sys.stderr.write( f"Calling data_to_createdict on {data} with kwmap {kwmap}\n" ) + # **** df = pandas.DataFrame( cls.data_to_createdict( data, kwmap=kwmap ) ) strio = io.StringIO() df.to_csv( strio, index=False, header=False, sep='\t', na_rep='\\N' ) @@ -174,8 +204,22 @@ def bulk_insert_onlynew( cls, data, kwmap=None ): # this will break if columns aren't all lower case) # columns = [ f'"{c}"' for c in df.columns.values ] columns = df.columns.values + # **** + # sys.stderr.write( f"Bulk uploading from:\n{strio.getvalue()}\ncolmns={columns}\n" ) + # **** cursor.copy_from( strio, "bulk_upsert", columns=columns, size=1048576 ) - q = f"INSERT INTO {cls._meta.db_table} SELECT * FROM bulk_upsert ON CONFLICT DO NOTHING" + if not assume_no_conflict: + if not upsert: + conflict = f"ON CONFLICT ({cls._pk}) DO NOTHING" + else: + conflict = ( f"ON CONFLICT ({cls._pk}) DO UPDATE SET " + + ",".join( f"{c}=EXCLUDED.{c}" for c in columns ) ) + else: + conflict = "" + q = f"INSERT INTO {cls._meta.db_table} SELECT * FROM bulk_upsert {conflict}" + # **** + # sys.stderr.write( f"Running query {q}\n" ) + # **** cursor.execute( q ) ninserted = cursor.rowcount # I don't think I should have to do this; shouldn't it happen automatically diff --git a/tom_desc/elasticc2/management/commands/brokerpoll2.py b/tom_desc/elasticc2/management/commands/brokerpoll2.py index b6994a0d..893f02de 100644 --- a/tom_desc/elasticc2/management/commands/brokerpoll2.py +++ b/tom_desc/elasticc2/management/commands/brokerpoll2.py @@ -69,7 +69,7 @@ def add_arguments( self, parser ): parser.add_argument( '--pitt-project', default=None, help="Project name for PITT-Google" ) parser.add_argument( '--do-test', action='store_true', default=False, help="Poll from kafka-server:9092 (for testing purposes)" ) - parser.add_argument( '---test-topic', default='classifications', + parser.add_argument( '--test-topic', default='classifications', help="Topic to poll from on kafka-server:9092" ) parser.add_argument( '-g', '--grouptag', default=None, help="Tag to add to end of kafka group ids" ) parser.add_argument( '-r', '--reset', default=False, action='store_true', @@ -119,12 +119,13 @@ def handle( self, *args, **options ): # Launch a process for each broker that will poll that broker indefinitely - # We want to make sure that django doesn't send copies of database sessions - # to the subprocesses; at least for Cassandra, that breaks things. So, - # before launching all the processes, close all the database django connections + # We want to make sure that django doesn't send copies of + # database sessions to the subprocesses. So, before launching + # all the processes, close all the database django connections # so that each process will open a new one as it needs it. - # (They already open mongo connections as necessary, and django doesn't muck - # about with mongo, so we don't have to do things for that.) + # (They already open mongo connections as necessary, and django + # doesn't muck about with mongo, so we don't have to do things + # for that.) 
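The connection hygiene described in this comment matters because, with the default fork start method, a child process inherits the parent's open database sockets; each poller should therefore start with no connections and let the ORM open its own. Illustration only; poll_one_broker is a hypothetical stand-in for the real poller:

import multiprocessing
import django.db

def poll_one_broker( brokername ):
    # The first ORM query made here opens a fresh connection owned by this process.
    ...

def launch_pollers( brokernames ):
    # Close inherited connections in the parent before forking.
    django.db.connections.close_all()
    procs = []
    for name in brokernames:
        proc = multiprocessing.Process( target=poll_one_broker, args=( name, ), daemon=True )
        proc.start()
        procs.append( proc )
    return procs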
django.db.connections.close_all() brokers = {} diff --git a/tom_desc/elasticc2/management/commands/gen_elasticc2_brokercompleteness.py b/tom_desc/elasticc2/management/commands/gen_elasticc2_brokercompleteness.py index ca5786fa..6cc9ea2d 100644 --- a/tom_desc/elasticc2/management/commands/gen_elasticc2_brokercompleteness.py +++ b/tom_desc/elasticc2/management/commands/gen_elasticc2_brokercompleteness.py @@ -20,7 +20,6 @@ import django.db from matplotlib import pyplot from django.core.management.base import BaseCommand, CommandError -import cassandra.query _rundir = pathlib.Path(__file__).parent diff --git a/tom_desc/elasticc2/management/commands/gen_elasticc2_brokerdelaygraphs.py b/tom_desc/elasticc2/management/commands/gen_elasticc2_brokerdelaygraphs.py deleted file mode 100644 index 0c6ecca6..00000000 --- a/tom_desc/elasticc2/management/commands/gen_elasticc2_brokerdelaygraphs.py +++ /dev/null @@ -1,415 +0,0 @@ -import sys -import io -import re -import math -import copy -import pathlib -import traceback -import time -import datetime -import dateutil.parser -import dateutil.tz -import pytz -import logging -import threading -import numpy -import pandas -import psycopg2 -import psycopg2.extras -import django.db -from matplotlib import pyplot -from django.core.management.base import BaseCommand, CommandError -import cassandra.query - -_rundir = pathlib.Path(__file__).parent - -_logger = logging.getLogger( __name__ ) -_logger.propagate = False -_logout = logging.StreamHandler( sys.stderr ) -_logger.addHandler( _logout ) -_formatter = logging.Formatter( f'[%(asctime)s - %(levelname)s] - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' ) -_logout.setFormatter( _formatter ) -_logger.setLevel( logging.DEBUG ) - -def makesubdf( bucketnums ): - rows = [] - for which in [ 'full', 'broker', 'tom' ]: - for buck in bucketnums: - rows.append( { 'which': which, 'buck': buck, 'count': 0 } ) - return pandas.DataFrame( rows ) - -def calcbucks( bucketleft, bucketright, dbucket ): - # Given postgres' width_bucket handling, everything *less than* bucketleft will be in bucket 0 - # Anything between bucketleft + (n-1)*dbucket and bucketleft + n*dbucket will be in bucket n - # Anything >= bucketleft + nbuckets*dbucket will be in bucket nbuckets + 1 - nbuckets = ( bucketright - bucketleft ) / dbucket - if ( float(int(nbuckets)) != nbuckets ): - raise ValueError( f"Can't divide {bucketleft} to {bucketright} evenly by {dbucket}" ) - nbuckets = int( nbuckets ) - bucketnums = numpy.array( range( nbuckets+2 ) ) - bucketleftedges = bucketleft + ( bucketnums - 1 ) * dbucket - return nbuckets, bucketnums, bucketleftedges - - -class CassBrokerMessagePageHandler: - def __init__( self, future, pgcursor, bucketleft, bucketright, dbucket, lowcutoff=1, highcutoff=9.99e5 ): - self.future = future - self.pgcursor = pgcursor - self.bucketleft = bucketleft - self.bucketright = bucketright - self.dbucket = dbucket - self.nbuckets, self.bucketnums, self.bucketleftedges = calcbucks( self.bucketleft, - self.bucketright, - self.dbucket ) - self.lowcutoff = lowcutoff - self.highcutoff = highcutoff - self.nhandled = 0 - self.printevery = 20000 - self.nextprint = 20000 - self.trunctime = 0 - self.striotime = 0 - self.copyfromtime = 0 - self.executetime = 0 - self.pandastime = 0 - self.futuretime = 0 - self.tottime = None - self.debugfirst = _logger.getEffectiveLevel() >= logging.DEBUG - - # Create a temporary postgres table for storing the alert ids we need - pgcursor.execute( "CREATE TEMPORARY TABLE temp_alertids( " - " alert_id bigint," - " 
classifier_id bigint," - " brokeringesttimestamp timestamptz," - " descingesttimestamp timestamptz," - " msghdrtimestamp timestamptz " - ") ON COMMIT DROP" ) - - # Precompile the postgres query we're going to use - pgcursor.execute( "DEALLOCATE ALL" ) - q = ( f"""SELECT DISTINCT ON(t.alert_id,t.classifier_id) - EXTRACT(EPOCH FROM t.descingesttimestamp - a.alertsenttimestamp)::float AS fulldelay, - EXTRACT(EPOCH FROM t.msghdrtimestamp - a.alertsenttimestamp)::float AS brokerdelay, - EXTRACT(EPOCH FROM t.descingesttimestamp - t.msghdrtimestamp)::float AS tomdelay - FROM temp_alertids t - INNER JOIN elasticc2_ppdbalert a ON t.alert_id=a.alert_id - """ ) - _logger.debug( f"(no-longer-so-) Ugly query: {q}" ) - pgcursor.execute( f"PREPARE bucket_join_alert_tempids AS {q}" ) - - self.df = makesubdf( self.bucketnums ).set_index( [ 'which', 'buck' ] ) - - self.finished_event = threading.Event() - self.error = None - - self.future.add_callbacks( callback=self.handle_page, errback=self.handle_error ) - - def finalize( self ): - self.pgcursor.connection.rollback() - resid = None - if self.tottime is not None: - self.tottime = time.perf_counter() - self.tottime - resid = ( self.tottime - self.trunctime - self.striotime - - self.copyfromtime - self.executetime - self.pandastime - self.futuretime ) - outstr = io.StringIO() - _logger.info( f"Overall: handled {self.nhandled} rows in {self.tottime} sec:\n" - f" trunctime : {self.trunctime}\n" - f" striotime : {self.striotime}\n" - f" copyfromtime : {self.copyfromtime}\n" - f" executetime : {self.executetime}\n" - f" pandastime : {self.pandastime}\n" - f" futuretime : {self.futuretime}\n" - f" (residual) : {resid}\n" ) - - def handle_page( self, rows ): - if self.tottime is None: - self.tottime = time.perf_counter() - t0 = time.perf_counter() - self.pgcursor.execute( "TRUNCATE TABLE temp_alertids" ) - - t1 = time.perf_counter() - strio = io.StringIO() - for row in rows: - strio.write( f"{row['alert_id']}\t" - f"{row['classifier_id']}\t" - f"{row['brokeringesttimestamp'].isoformat()}Z\t" - f"{row['descingesttimestamp'].isoformat()}Z\t" - f"{row['msghdrtimestamp'].isoformat()}Z\n" ) - strio.seek( 0 ) - t2 = time.perf_counter() - self.pgcursor.copy_from( strio, 'temp_alertids', size=262144 ) - - t3 = time.perf_counter() - if self.debugfirst: - self.debugfirst = False - self.pgcursor.execute( "EXPLAIN ANALYZE EXECUTE bucket_join_alert_tempids" ) - analrows = self.pgcursor.fetchall() - nl = '\n' - _logger.debug( f'Analyzed query:\n{nl.join( [ r["QUERY PLAN"] for r in analrows ] )}' ) - self.pgcursor.execute( "EXECUTE bucket_join_alert_tempids" ) - - t4 = time.perf_counter() - tmpdf = pandas.DataFrame( self.pgcursor.fetchall() ) - - if len(tmpdf) > 0: - tmpdf.clip( lower=10**self.bucketleft, upper=10**self.bucketright, inplace=True ) - tmpdf = tmpdf.apply( numpy.log10 ) - fullhist, binedges = numpy.histogram( tmpdf['fulldelay'], - range=(self.bucketleft,self.bucketright+self.dbucket), - bins=self.nbuckets+1 ) - if not ( binedges == numpy.arange( self.bucketleft, self.bucketright+2*self.dbucket, - self.dbucket ) ).all(): - raise ValueError( "Unexpected bins." 
) - brokerhist, binedges = numpy.histogram( tmpdf['brokerdelay'], - range=(self.bucketleft,self.bucketright+self.dbucket), - bins=self.nbuckets+1 ) - tomhist, binedges = numpy.histogram( tmpdf['tomdelay'], - range=(self.bucketleft,self.bucketright+self.dbucket), - bins=self.nbuckets+1 ) - curdf = None - for which, hist in zip( [ 'full', 'broker', 'tom' ], [ fullhist, brokerhist, tomhist ] ): - whichdf = pandas.DataFrame( { 'which': which, - 'buck': numpy.array( ( binedges / self.dbucket + 1 )[:-1], dtype=int ), - 'count': hist } ) - if curdf is None: - curdf = whichdf - else: - curdf = pandas.concat( [ curdf, whichdf ] ) - - curdf.set_index( [ 'which', 'buck' ], inplace=True ) - # Sadly, this will convert ints to floats, but, oh well - self.df = self.df.add( curdf, fill_value=0 ) - - t5 = time.perf_counter() - if self.future.has_more_pages: - self.future.start_fetching_next_page() - else: - self.finished_event.set() - - t6 = time.perf_counter() - self.nhandled += len( rows ) - self.trunctime += t1 - t0 - self.striotime += t2 - t1 - self.copyfromtime += t3 - t2 - self.executetime += t4 - t3 - self.pandastime += t5 - t4 - self.futuretime += t6 - t5 - - if self.nhandled >= self.nextprint: - self.nextprint += self.printevery - _logger.info( f"Handled {self.nhandled} rows" ) - - def handle_error( self, exc ): - self.error = exc - self.finished_event.set() - - -class Command(BaseCommand): - help = 'Generate broker time delay graphs' - outdir = _rundir / "../../static/elasticc2/brokertiminggraphs" - - def add_arguments( self, parser) : - parser.add_argument( "--t0", default="2023-10-16", - help="First day to look at (YYYY-MM-DD) (default: 2023-10=16)" ) - parser.add_argument( "--t1", default="2023-10-19", - help="One past the last day to look at (YYYY-MM-DD) (default: 2023-10-19)" ) - parser.add_argument( "--dt", default=7, type=int, help="Step in days (default: 7)" ) - - self.bucketleft = 0 - self.bucketright = 6 - self.dbucket = 0.25 - self.lowcutoff = 1 - self.highcutoff = 9.99e5 - - def makeplots( self ): - brokers = set( self.df.index.get_level_values( 'broker' ) ) - weeks = set( self.df.index.get_level_values( 'week' ) ) - - whichtitle = { 'full': "Orig. 
Alert to Tom Ingestion", - 'broker': "Broker Delay", - - 'tom': "Tom Delay" } - - for broker in brokers: - for week in weeks: - fig = pyplot.figure( figsize=(18,4), tight_layout=True ) - for i, which in enumerate( [ 'full', 'broker', 'tom' ] ): - subdf = self.df.xs( ( broker, week, which ), level=( 'broker', 'week', 'which' ) ) - ax = fig.add_subplot( 1, 3, i+1 ) - ax.set_title( whichtitle[ which ], fontsize=18 ) - ax.set_xlim( self.bucketleft, self.bucketright + self.dbucket ) - ax.set_xlabel( r"$\log_{10}(\Delta t (\mathrm{s}))$", fontsize=14 ) - ax.set_ylabel( "N", fontsize=14 ) - tickvals = numpy.arange( 0, 7, 1 ) - # +1 since the lowest bucket in postgres is 1 (0 being < the lowest bucket) - xticks = tickvals / self.dbucket + self.bucketleft + 1 - xlabels = [ str(i) for i in tickvals ] - xlabels[0] = f'≤{xlabels[0]}' - xlabels[-1] = f'≥{xlabels[-1]}' - ax.set_xticks( xticks, labels=xlabels ) - ax.tick_params( 'both', labelsize=12 ) - ax.bar( subdf.index.values[1:], subdf['count'].values[1:], width=1, align='edge' ) - outfile = self.outdir / f'{broker}-{week}.svg' - _logger.info( f"Writing {outfile}" ) - if week == 'cumulative': - fig.suptitle( broker, fontsize=20 ) - else: - fig.suptitle( f"{broker}, {week} UTC", fontsize=20 ) - fig.savefig( outfile ) - pyplot.close( fig ) - - def makedf( self, weeklabs, brokernames ): - self.df = None - rows = [] - weeklabs = copy.deepcopy( weeklabs ) - weeklabs.insert( 0, 'cumulative' ) - for week in weeklabs: - for bname in brokernames: - df = makesubdf( self.bucketnums ) - df[ 'broker' ] = bname - df[ 'week' ] = week - if self.df is None: - self.df = df - else: - self.df = pandas.concat( [ self.df, df ] ) - self.df.set_index( [ 'broker', 'week', 'which', 'buck' ], inplace=True ) - - def handle( self, *args, **options ): - raise RuntimeError( "Deprecated. Run gen_elasticc2_brokerdelaygraphs_pg" ) - - _logger.info( "Starting genbrokerdelaygraphs" ) - - conn = None - # Jump through hoops to get access to the psycopg2 connection from django - conn = django.db.connection.cursor().connection - orig_autocommit = conn.autocommit - - try: - just_read_pickle = False - updatetime = None - - if not just_read_pickle: - - casssession = django.db.connections['cassandra'].connection.session - casssession.default_fetch_size = 10000 - # Perversely, it took longer per page using the PreparedStatement - # than it did using a simple statementbelow. ??? - # cassq = casssession.prepare( "SELECT * FROM tom_desc.cass_broker_message " - # "WHERE classifier_id IN ? " - # " AND descingesttimestamp >= ? " - # " AND descingesttimestamp < ? 
" - # "ALLOW FILTERING" ) - - conn.autocommit = False - - updatetime = datetime.datetime.utcnow().date().isoformat() - - self.outdir.mkdir( parents=True, exist_ok=True ) - - # Determine time buckets and weeks - - self.nbuckets, self.bucketnums, self.bucketleftedges = calcbucks( self.bucketleft, - self.bucketright, - self.dbucket ) - - t0 = pytz.utc.localize( datetime.datetime.fromisoformat( options['t0'] ) ) - t1 = pytz.utc.localize( datetime.datetime.fromisoformat( options['t1'] ) ) - dt = datetime.timedelta( days=options['dt'] ) - weeks = [] - week = t0 - while ( week < t1 ): - weeks.append( week ) - week += dt - weeklabs = [ f'[{w.year:04d}-{w.month:02d}-{w.day:02d} , ' - f'{(w+dt).year:04d}-{(w+dt).month:02d}-{(w+dt).day:02d})' for w in weeks ] - - with conn.cursor( cursor_factory=psycopg2.extras.RealDictCursor ) as cursor: - # Figure out which brokers we have - cursor.execute( 'SELECT * FROM elasticc2_brokerclassifier ' - 'ORDER BY "brokername","brokerversion","classifiername","classifierparams"' ) - brokers = { row["classifier_id"] : row for row in cursor.fetchall() } - conn.rollback() - - brokergroups = {} - for brokerid, row in brokers.items(): - if row['brokername'] not in brokergroups: - brokergroups[row['brokername']] = [] - brokergroups[row['brokername']].append( row['classifier_id'] ) - - # Choose the brokers to actually work on ( for debugging purposes ) - whichgroups = [ k for k in brokergroups.keys() ] - # whichgroups = [ 'Fink' ] - - # This is the master df that we'll append to as we - # iterate through brokers and weeks - self.makedf( weeklabs, whichgroups ) - - for broker in whichgroups: - for week, weeklab in zip( weeks, weeklabs ): - _logger.info( f"Doing broker {broker} week {weeklab}..." ) - - # Extract the data from the database for this broker - # and week (the CassBrokerMessagePageHandler will - # send a postgres query for each page returned from - # Cassandra) - - cassq = ( "SELECT * FROM tom_desc.cass_broker_message_by_time " - "WHERE classifier_id IN %(id)s " - " AND descingesttimestamp >= %(t0)s " - " AND descingesttimestamp < %(t1)s " - "ALLOW FILTERING" ) - future = casssession.execute_async( cassq, { 'id': tuple( brokergroups[broker] ), - 't0': week, - 't1': week+dt } ) - handler = CassBrokerMessagePageHandler( future, cursor, - self.bucketleft, self.bucketright, self.dbucket, - lowcutoff=self.lowcutoff, - highcutoff=self.highcutoff ) - _logger.info( "...waiting for finished..." ) - handler.finished_event.wait() - _logger.info( "...done." 
) - handler.finalize() - if handler.error: - _logger.error( handler.error ) - raise handler.error - - if len( handler.df ) > 0: - df = handler.df.reset_index() - df[ 'broker' ] = broker - df[ 'week' ] = weeklab - df.set_index( [ 'broker', 'week', 'which', 'buck' ], inplace=True ) - self.df = self.df.add( df, fill_value=0 ) - - - # It seems like there should be a more elegant way to do this - summeddf = ( self.df.query( 'week!="cumlative"' ) - .groupby( [ 'broker', 'which', 'buck' ] ) - .sum().reset_index() ) - summeddf['week'] = 'cumulative' - summeddf.set_index( [ 'broker', 'week', 'which', 'buck' ], inplace=True ) - self.df.loc[ ( slice(None), 'cumulative', slice(None), slice(None) ), : ] = summeddf - - # self.df.set_index( [ 'broker', 'week', 'which', 'buck' ], inplace=True ) - _logger.info( "Writing gen_elasticc2_brokerdelaygraphs.pkl" ) - self.df.to_pickle( "gen_elasticc2_brokerdelaygraphs.pkl" ) - else: - _logger.info( "Reading gen_elasticc2_brokerdelaygraphs.pkl" ) - self.df = pandas.read_pickle( "gen_elasticc2_brokerdelaygraphs.pkl" ) - - _logger.info( "Saving plots." ) - self.makeplots() - if updatetime is not None: - with open( self.outdir / "updatetime.txt", 'w' ) as ofp: - ofp.write( updatetime ) - _logger.info( "All done." ) - except Exception as e: - _logger.exception( e ) - _logger.exception( traceback.format_exc() ) - # import pdb; pdb.set_trace() - raise e - finally: - if conn is not None: - conn.autocommit = orig_autocommit - conn.close() - conn = None - diff --git a/tom_desc/elasticc2/management/commands/gen_elasticc2_brokerdelaygraphs_pg.py b/tom_desc/elasticc2/management/commands/gen_elasticc2_brokerdelaygraphs_pg.py index ffe4da51..ebf79969 100644 --- a/tom_desc/elasticc2/management/commands/gen_elasticc2_brokerdelaygraphs_pg.py +++ b/tom_desc/elasticc2/management/commands/gen_elasticc2_brokerdelaygraphs_pg.py @@ -20,7 +20,6 @@ import django.db from matplotlib import pyplot from django.core.management.base import BaseCommand, CommandError -import cassandra.query _rundir = pathlib.Path(__file__).parent diff --git a/tom_desc/elasticc2/management/commands/load_snana_fits.py b/tom_desc/elasticc2/management/commands/load_snana_fits.py index c88b518c..4e07d0b4 100644 --- a/tom_desc/elasticc2/management/commands/load_snana_fits.py +++ b/tom_desc/elasticc2/management/commands/load_snana_fits.py @@ -203,7 +203,7 @@ def load_one_file( self, headfile, photfile ): # row['isddf'] = True if self.really_do: - nobj = DiaObject.bulk_insert_onlynew( dict( head ) ) + nobj = DiaObject.bulk_insert_or_upsert( dict( head ), assume_no_conflict=True ) self.logger.info( f"PID {os.getpid()} loaded {nobj} objects from {headfile.name}" ) else: nobj = len(head) @@ -242,7 +242,7 @@ def load_one_file( self, headfile, photfile ): phot = phot[ phot['diaobject_id'] >= 0 ] if self.really_do: - nfrc = DiaForcedSource.bulk_insert_onlynew( phot ) + nfrc = DiaForcedSource.bulk_insert_or_upsert( phot, assume_no_conflict=True ) self.logger.info( f"PID {os.getpid()} loaded {nfrc} forced photometry points from {photfile.name}" ) else: nfrc = len(phot) @@ -254,7 +254,7 @@ def load_one_file( self, headfile, photfile ): phot = phot[ ( phot['photflag'] & self.photflag_detect ) !=0 ] if self.really_do: - nsrc = DiaSource.bulk_insert_onlynew( phot ) + nsrc = DiaSource.bulk_insert_or_upsert( phot, assume_no_conflict=True ) self.logger.info( f"PID {os.getpid()} loaded {nsrc} sources from {photfile.name}" ) else: nsrc = len(phot) @@ -266,7 +266,7 @@ def load_one_file( self, headfile, photfile ): 'diasource_id': phot[ 
'diasource_id' ], 'diaobject_id': phot[ 'diaobject_id' ] } if self.really_do: - nalrt = Alert.bulk_insert_onlynew( alerts ) + nalrt = Alert.bulk_insert_or_upsert( alerts, assume_no_conflict=True ) self.logger.info( f"PID {os.getpid()} loaded {nalrt} alerts" ) else: nalrt = len(alerts['alert_id']) diff --git a/tom_desc/elasticc2/management/commands/send_elasticc2_alerts.py b/tom_desc/elasticc2/management/commands/send_elasticc2_alerts.py index d0faacfe..ae8f36db 100644 --- a/tom_desc/elasticc2/management/commands/send_elasticc2_alerts.py +++ b/tom_desc/elasticc2/management/commands/send_elasticc2_alerts.py @@ -43,7 +43,7 @@ def __init__( self, *args, **kwargs ): if not self.logger.hasHandlers(): logout = logging.StreamHandler( sys.stderr ) self.logger.addHandler( logout ) - formatter = logging.Formatter( f'[%(asctime)s - %(levelname)s] - %(message)s', + formatter = logging.Formatter( f'[%(asctime)s - e2alert - %(levelname)s] - %(message)s', datefmt='%Y-%m-%d %H:%M:%S' ) logout.setFormatter( formatter ) self.logger.setLevel( logging.INFO ) diff --git a/tom_desc/elasticc2/models.py b/tom_desc/elasticc2/models.py index 9c5fc229..ed18639a 100644 --- a/tom_desc/elasticc2/models.py +++ b/tom_desc/elasticc2/models.py @@ -15,16 +15,13 @@ import django.db from django.db import models from django.db.models import F +import django.db.models.functions from guardian.shortcuts import assign_perm from django.contrib.auth.models import Group from django.contrib.postgres.fields import ArrayField import django.conf import django.utils -from cassandra.cqlengine import columns -import cassandra.query -from django_cassandra_engine.models import DjangoCassandraModel - import astropy.time @@ -54,7 +51,7 @@ # Link to tom targets import tom_targets.models -_logger = logging.getLogger(__name__) +_logger = logging.getLogger( "elasticc2/models" ) _logout = logging.StreamHandler( sys.stderr ) _formatter = logging.Formatter( f'[%(asctime)s - %(levelname)s] - %(message)s' ) _logout.setFormatter( _formatter ) @@ -950,170 +947,6 @@ def add_batch( cls, sources ): -class CassBrokerMessageBySource(DjangoCassandraModel): - classifier_id = columns.BigInt( primary_key=True ) - diasource_id = columns.BigInt( primary_key=True ) - id = columns.UUID( primary_key=True, default=uuid.uuid4 ) - - topicname = columns.Text() - streammessage_id = columns.BigInt() - alert_id = columns.BigInt() - msghdrtimestamp = columns.DateTime() - elasticcpublishtimestamp = columns.DateTime() - brokeringesttimestamp = columns.DateTime() - descingesttimestamp = columns.DateTime( default=datetime.datetime.utcnow ) - classid = columns.List( columns.Integer() ) - probability = columns.List( columns.Float() ) - - class Meta: - get_pk_field = 'id' - - @staticmethod - def load_batch( messages, logger=_logger ): - """Load an array of broker classification messages. - - Loads things to *both* CassBrokerMessageBySource and - CassBrokerMessageByTime - - This doesn't actually do any batching operation, because there's - no bulk_create in the Django Cassandra interface, and because I - don't understand Cassandra well enough to know how to do this -- - I've read that batching can be a bad idea. I'm worried about - the repeated network overhead, but we'll see how it goes. 
- - """ - - cfers = {} - sourceids = [] - for i, msgmeta in enumerate(messages): - msg = msgmeta['msg'] - if len( msg['classifications'] ) == 0: - logger.debug( 'Message with no classifications' ) - continue - keycfer = f"{msg['brokerName']}_{msg['brokerVersion']}_{msg['classifierName']}_{msg['classifierParams']}" - if keycfer not in cfers.keys(): - cfers[ keycfer ] = { 'brokername': msg['brokerName'], - 'brokerversion': msg['brokerVersion'], - 'classifiername': msg['classifierName'], - 'classifierparams': msg['classifierParams'], - 'classifier_id': None } - sourceids.append( msg['diaSourceId'] ) - - # Create any classifiers that don't already exist; this - # is one place where we do get efficiency by calling - # this batch method. - cferconds = models.Q() - logger.debug( f"Looking for pre-existing classifiers" ) - cferconds = models.Q() - i = 0 - for cferkey, cfer in cfers.items(): - newcond = ( models.Q( brokername = cfer['brokername'] ) & - models.Q( brokerversion = cfer['brokerversion'] ) & - models.Q( classifiername = cfer['classifiername'] ) & - models.Q( classifierparams = cfer['classifierparams'] ) ) - cferconds |= newcond - curcfers = BrokerClassifier.objects.filter( cferconds ) - numknown = 0 - for cur in curcfers: - keycfer = f"{cur.brokername}_{cur.brokerversion}_{cur.classifiername}_{cur.classifierparams}" - cfers[ keycfer ][ 'classifier_id' ] = cur.classifier_id - numknown += 1 - logger.debug( f'Found {numknown} existing classifiers that match the ones in this batch.' ) - - # Create new classifiers as necessary - - kwargses = [] - ncferstoadd = 0 - for keycfer, cfer in cfers.items(): - if cfer[ 'classifier_id' ] is None: - kwargses.append( { 'brokername': cfer['brokername'], - 'brokerversion': cfer['brokerversion'], - 'classifiername': cfer['classifiername'], - 'classifierparams': cfer['classifierparams'] } ) - ncferstoadd += 1 - ncferstoadd = len(kwargses) - logger.debug( f'Adding {ncferstoadd} new classifiers.' ) - if ncferstoadd > 0: - objs = ( BrokerClassifier( **k ) for k in kwargses ) - batch = list( itertools.islice( objs, len(kwargses) ) ) - newcfers = BrokerClassifier.objects.bulk_create( batch, len(kwargses) ) - for newcfer in newcfers: - keycfer = ( f'{newcfer.brokername}_{newcfer.brokerversion}_' - f'{newcfer.classifiername}_{newcfer.classifierparams}' ) - cfers[keycfer]['classifier_id'] = newcfer.classifier_id - - # It's pretty clear that django really wants - # to mediate your database access... 
otherwise - # there wouldn't be so many periods here - casssession = django.db.connections['cassandra'].connection.session - qm = ( "INSERT INTO tom_desc.cass_broker_message_by_source(classifier_id,diasource_id,id," - "topicname,streammessage_id,alert_id,msghdrtimestamp,elasticcpublishtimestamp," - "brokeringesttimestamp,descingesttimestamp,classid,probability) " - "VALUES (?,?,?,?,?,?,?,?,?,?,?,?)" ) - qt = ( "INSERT INTO tom_desc.cass_broker_message_by_time(classifier_id,diasource_id,id," - "topicname,streammessage_id,alert_id,msghdrtimestamp,elasticcpublishtimestamp," - "brokeringesttimestamp,descingesttimestamp,classid,probability) " - "VALUES (?,?,?,?,?,?,?,?,?,?,?,?)" ) - cassqm = casssession.prepare( qm ) - cassqt = casssession.prepare( qt ) - - # TODO - MAKE THIS BETTER -- this parameter should be configurable somewhere - batchsize = 1000 - nbatch = 0 - batch = None - for i, msgmeta in enumerate( messages ): - msg = msgmeta['msg'] - if len( msg['classifications' ] ) == 0: - continue - keycfer = f"{msg['brokerName']}_{msg['brokerVersion']}_{msg['classifierName']}_{msg['classifierParams']}" - classes = [] - probs = [] - for cification in msg['classifications']: - classes.append( cification['classId'] ) - probs.append( cification['probability'] ) - - cassid = uuid.uuid4() - descingesttimestamp = datetime.datetime.now() - - args = ( cfers[ keycfer][ 'classifier_id' ], - msg[ 'diaSourceId' ], - cassid, - msgmeta[ 'topic' ], - msgmeta[ 'msgoffset' ], - msg[ 'alertId' ], - msgmeta[ 'timestamp' ], - msg[ 'elasticcPublishTimestamp' ], - msg[ 'brokerIngestTimestamp' ], - descingesttimestamp, - classes, - probs ) - if batch is None: - batch = cassandra.query.BatchStatement() - nbatch = 0 - batch.add( cassqm.bind( args ) ) - batch.add( cassqt.bind( args ) ) - nbatch += 1 - if nbatch >= batchsize: - casssession.execute( batch ) - batch = None - nbatch = 0 - - if ( batch is not None ) and ( nbatch > 0 ): - casssession.execute( batch ) - - # Update the log of new broker source ids - BrokerSourceIds.add_batch( sourceids ) - - logger.debug( f"Classifiers in the messages just loaded: {list(cfers.keys())}" ) - - - # return newcfications - return { "addedmsgs": len(messages), - "addedclassifiers": ncferstoadd, - "addedclassifications": None, - "firstbrokermessage_id": None } - - # ====================================================================== # ====================================================================== # ====================================================================== @@ -1172,40 +1005,6 @@ class Meta: name='diaobjectclassification_unique' ) ] - -# ====================================================================== -# ====================================================================== -# ====================================================================== -# Cassandra tables (not currently used) - -# class CassBrokerMessageByTime(DjangoCassandraModel): -# classifier_id = columns.BigInt( primary_key=True ) -# descingesttimestamp = columns.DateTime( default=datetime.datetime.utcnow, primary_key=True ) -# id = columns.UUID( primary_key=True, default=uuid.uuid4 ) - -# topicname = columns.Text() -# streammessage_id = columns.BigInt() -# diasource_id = columns.BigInt() -# alert_id = columns.BigInt() -# msghdrtimestamp = columns.DateTime() -# elasticcpublishtimestamp = columns.DateTime() -# brokeringesttimestamp = columns.DateTime() -# classid = columns.List( columns.Integer() ) -# probability = columns.List( columns.Float() ) - -# class Meta: -# get_pk_field = 'id' - -# 
@staticmethod -# def load_batch( messages, logger=_logger ): -# """Calls CassBrokerMessageBySource.load_batch""" - -# CassBrokerMessageBySource.load_batch( messages, logger ) - - - - - # ====================================================================== class SpectrumInfo(Createable): @@ -1232,7 +1031,7 @@ class WantedSpectra(Createable): priority = models.IntegerField() _pk = 'wantspec_id' - _create_kws = [ _pk, 'diaobject_id', 'user_id', 'requester', 'priority' ] + _create_kws = [ _pk, 'diaobject_id', 'user_id', 'requester', 'priority', 'wanttime' ] class PlannedSpectra(Createable): reqspec_id = models.AutoField( primary_key=True, unique=True, db_index=True ) diff --git a/tom_desc/elasticc2/views.py b/tom_desc/elasticc2/views.py index 0e417912..5c6bf291 100644 --- a/tom_desc/elasticc2/views.py +++ b/tom_desc/elasticc2/views.py @@ -1,5 +1,6 @@ import sys import re +import io import pathlib import datetime import dateutil.parser @@ -39,7 +40,7 @@ # doesn't seem to have the formatting built in; # I guess djano makes its own formatting instead # of using logging's. Sigh. -_logger = logging.getLogger(__name__) +_logger = logging.getLogger( "elasticc2/views" ) _logout = logging.StreamHandler( sys.stderr ) _formatter = logging.Formatter( f'[%(asctime)s - %(levelname)s] - %(message)s' ) _logout.setFormatter( _formatter ) @@ -1164,23 +1165,71 @@ def post( self, request ): return HttpResponse( "Mal-formed data for askforspectrum", status=500, content_type='text/plain; charset=utf-8' ) + now = datetime.datetime.now( tz=datetime.timezone.utc ) tocreate = [ { 'requester': data['requester'], 'diaobject_id': data['objectids'][i], 'wantspec_id': f"{data['objectids'][i]} ; {data['requester']}", 'user_id': request.user.id, 'priority': ( 0 if int(data['priorities'][i]) < 0 else 5 if int(data['priorities'][i]) > 5 - else int(data['priorities'][i] )) } + else int(data['priorities'][i] )), + 'wanttime': now } for i in range(len(data['objectids'])) ] try: - objs = WantedSpectra.bulk_load_or_create( tocreate ) + # OK. Rant warning. + # + # I hate ORMs. I hate them all. The ONLY nice thing I have + # to say about Django is that it's not SQLAlchemy. (OMG, + # SQLAlchemy, I have so much past trauma.) + # + # For efficiency, I was using using Postgres' ability to + # bulk create by uploading a lot of data in one big blob, + # instead of with a bunch of SQL insert commands. But, + # that meant that the default fields weren't getting + # filled... because django was managing the defaults, not + # the database. This is what happens when you use an + # intervening layer to define your database instead of the + # database itself. As is often the case when using ORMs, + # you lose control of your database, and regret it later. + # + # Now, it turns out that django has the keyword + # "db_default"... but that's new in django 5, and right + # now I can't do the dependency-hell switch from django 4 + # to 5, because we've built on top of the TOM Toolkit, and + # just bumping django is likely to break that if we don't + # bump the TOM toolkit at the same time. I'm already two + # interrupt levels deep in this PR, and I don't want its + # scope to creep any more. (Which it may not, but I've + # learned through years of pain to have a deep and + # visceral fear of even the specter of dependency hell.) + # + # As such, I'm putting in the kwmap parameter here in hopes + # of working around the way I tried to set things up to be + # able to use SQL directly with django models. 
The right + # solution, of course, is to eschew and avoid ORMs at all + # times instead of trying to write ugly hack code to be + # able to cope with the fact that you have an ORM but want + # to do actual PostgreSQL sometimes. But, oh well. We go + # insane with the world we live in, not the world we + # wished we lived in. + ninserted = WantedSpectra.bulk_insert_or_upsert( tocreate, upsert=True, + kwmap={ 'wantspec_id': 'wantspec_id', + 'diaobject_id': 'diaobject_id', + 'user_id': 'user_id', + 'requester': 'requester', + 'priority': 'priority', + 'wanttime': 'wanttime' } ) except Exception as ex: + strio = io.StringIO() + strio.write( "Exception in AskForSpectrumView:\n" ) + traceback.print_exc( file=strio ) + _logger.error( strio.getvalue() ) return HttpResponse( str(ex), status=500, content_type='text/plain; charset=utf-8' ) return JsonResponse( { 'status': 'ok', 'message': f'wanted spectra created', - 'num': len(objs) } ) + 'num': ninserted } ) # ====================================================================== diff --git a/tom_desc/fastdb_dev/management/commands/fastdb_dev_brokerpoll.py b/tom_desc/fastdb_dev/management/commands/fastdb_dev_brokerpoll.py index 729583e5..36ab73d3 100644 --- a/tom_desc/fastdb_dev/management/commands/fastdb_dev_brokerpoll.py +++ b/tom_desc/fastdb_dev/management/commands/fastdb_dev_brokerpoll.py @@ -69,7 +69,7 @@ def add_arguments( self, parser ): parser.add_argument( '--pitt-project', default=None, help="Project name for PITT-Google" ) parser.add_argument( '--do-test', action='store_true', default=False, help="Poll from kafka-server:9092 (for testing purposes)" ) - parser.add_argument( '---test-topic', default='classifications', + parser.add_argument( '--test-topic', default='classifications', help="Topic to poll from on kafka-server:9092" ) parser.add_argument( '-g', '--grouptag', default=None, help="Tag to add to end of kafka group ids" ) parser.add_argument( '-r', '--reset', default=False, action='store_true', @@ -121,12 +121,13 @@ def handle( self, *args, **options ): # Launch a process for each broker that will poll that broker indefinitely - # We want to make sure that django doesn't send copies of database sessions - # to the subprocesses; at least for Cassandra, that breaks things. So, - # before launching all the processes, close all the database django connections + # We want to make sure that django doesn't send copies of + # database sessions to the subprocesses. So, before launching + # all the processes, close all the database django connections # so that each process will open a new one as it needs it. - # (They already open mongo connections as necessary, and django doesn't muck - # about with mongo, so we don't have to do things for that.) + # (They already open mongo connections as necessary, and django + # doesn't muck about with mongo, so we don't have to do things + # for that.) django.db.connections.close_all() brokers = {} diff --git a/tom_desc/tests/README.md b/tom_desc/tests/README.md new file mode 100644 index 00000000..b4c3dc74 --- /dev/null +++ b/tom_desc/tests/README.md @@ -0,0 +1,3 @@ +The tests for the DESC TOM are found in ../../tests + +This directory holds a django application named "tests" that we don't actually use. 
diff --git a/tom_desc/tom_desc/settings.py b/tom_desc/tom_desc/settings.py
index c0315e41..508d9cce 100644
--- a/tom_desc/tom_desc/settings.py
+++ b/tom_desc/tom_desc/settings.py
@@ -15,7 +15,6 @@
 import socket
 import tempfile
 import psqlextra
-# import cassandra
 
 # Build paths inside the project like this: os.path.join(BASE_DIR, ...)
 BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -37,8 +36,6 @@
 DATA_UPLOAD_MAX_MEMORY_SIZE = 524288000
 
 # Application definition
 
-# NOTE -- django_cassandra_engine demands to be first
-# Just gotta hope nobody else makes the same claim.
 INSTALLED_APPS = [
     'whitenoise.runserver_nostatic',