From b8f170688e9c26384ed73498047814902afeb4d7 Mon Sep 17 00:00:00 2001 From: Me Date: Tue, 22 Dec 2020 22:40:55 +0000 Subject: [PATCH] Move all DB function to db.py + add testsuite for db.py --- srv/db.py | 54 ++++++++++++- srv/energyDB.py | 84 +++++--------------- test/test_db.py | 86 +++++++++++++++++++++ test/{test_EnergyDB.py => test_energyDB.py} | 83 +++++++++++++++----- 4 files changed, 219 insertions(+), 88 deletions(-) create mode 100644 test/test_db.py rename test/{test_EnergyDB.py => test_energyDB.py} (83%) diff --git a/srv/db.py b/srv/db.py index 1066f5a..ef002a3 100644 --- a/srv/db.py +++ b/srv/db.py @@ -1,5 +1,6 @@ +import datetime import os -from typing import List +from typing import List, Dict from databases import Database from sqlalchemy import (Column, DateTime, Integer, Float, String, @@ -36,8 +37,8 @@ class EnergyDB(): def metadata(self): return self._metadata - def url(self): - return self._engine.url + def url(self) -> str: + return str(self._engine.url) def engine(self): return self._engine @@ -73,7 +74,7 @@ class EnergyDB(): def getChannels(self) -> dict: try: table_channels = self.table("channels") - query = sqlalchemy.select([table_channels.c.name, table_channels.c.id]).select_from(table_channels) + query = sqlalchemy.select([table_channels.c.name]).select_from(table_channels) channels = [dict(r.items()) for r in self.execute(query).fetchall()] except Exception as e: raise Exception(f"Database error in getChannels(): {type(e)} - {str(e)}") @@ -91,3 +92,48 @@ class EnergyDB(): if chId is None: raise Exception(f"Database error in getChannelId(): channel '{channelName}' not found") return chId + + def getChannelData(self, channelIds : List[int], fromTime : datetime.datetime, tillTime : datetime.datetime) -> Dict[int, list]: + try: + chData = {} + query = sqlalchemy.sql.text( + """SELECT timestamp, value FROM energy + WHERE channel_id == :channel_id + AND timestamp >= :fromTime + AND timestamp <= :tillTime + ORDER BY timestamp""" + ) + for ch in channelIds: + result = self.execute(query, channel_id = ch, fromTime = fromTime, tillTime = tillTime) + data = [{"timestamp": datetime.datetime.fromisoformat(row[0]), "value": row[1]} for row in result.fetchall() ] + chData[ch] = data + return chData + except Exception as e: + raise Exception(f"Database error in getChannelData(): {type(e)} - {str(e)}") + + def addChannelData(self, channel_id : int, data : List[Dict[datetime.datetime, float]]): + try: + # query = sqlalchemy.sql.text( + # "INSERT INTO energy (channel_id, timestamp, value) VALUES " + # ) + queryStr = "INSERT INTO energy (channel_id, timestamp, value) VALUES " + valueStr = "" + for d in data: + timestamp = d["timestamp"] + value = d["value"] + # self.execute( + # # query, + # # channel_id=channel_id, + # # timestamp = d["timestamp"], + # # value=d["value"] + # sqlalchemy.sql.text( + # f"""INSERT INTO energy (channel_id, timestamp, value) + # VALUES ('{channel_id}', '{timestamp}', '{value}')""" + # ) + # ) + if valueStr != "": + valueStr += ",\n" + valueStr += f"({channel_id}, '{timestamp}', {value})" + self.execute(sqlalchemy.sql.text(queryStr + valueStr + ";")) + except Exception as e: + raise Exception(f"Database error in addChannelData(): {type(e)} - {str(e)}") diff --git a/srv/energyDB.py b/srv/energyDB.py index 2ece97a..e265247 100644 --- a/srv/energyDB.py +++ b/srv/energyDB.py @@ -4,8 +4,6 @@ from pydantic import BaseModel import datetime import os -# import sqlalchemy - from .db import EnergyDB import logging @@ -14,9 +12,9 @@ API_VERSION_MAJOR = 1 API_VERSION_MINOR = 0 DB_URL = os.getenv("DATABASE_URL") #, default="sqlite://") -if len(DB_URL) == 0: +if DB_URL is None or len(DB_URL) == 0: raise Exception("Environment variable DATABASE_URL missed!") -print(f"DB URL: {DB_URL}") +# print(f"DB URL: {DB_URL}") db = EnergyDB(DB_URL) app = FastAPI(debug=True) @@ -109,81 +107,41 @@ async def apiGetInfo(): @app.get(_restApiPath("/bulkData"), response_model = BulkData) async def getBulkEnergyData(bulkDataRequest: BulkDataRequest): bulkData = [] - trace = [] - exception = None try: - for ch in bulkDataRequest.channel_ids: - data = [] - table_energy = db.table("energy") - query = sqlalchemy.select([table_energy.c.timestamp, table_energy.c.value]) \ - .select_from(table_energy) \ - .where(sqlalchemy.sql.and_( - table_energy.c.channel_id == ch, - table_energy.c.timestamp >= bulkDataRequest.fromTime, - table_energy.c.timestamp <= bulkDataRequest.tillTime - ) - ) - for row in db.execute(query).fetchall(): - data.append(dict(row.items())) - bulkData.append({"channel_id": ch, "data": data}) + chData = db.getChannelData(bulkDataRequest.channel_ids, bulkDataRequest.fromTime, bulkDataRequest.tillTime) + for chId in chData: + bulkData.append({ + "channel_id": chId, + "channel": None, + "data": chData[chId], + "msg": str(chData), + }) except Exception as e: raise HTTPException( status_code=404, detail=f"Database error: {type(e)} - {str(e)}" - # detail=f"Database error: {str(e)}\nQuery: {str(query)}" ) return { "bulk": bulkData, - "msg": None #__name__ + " - " + str(query) + "msg": str(chData.keys()) } @app.put(_restApiPath("/bulkData")) async def putBulkEnergyData(bulkData: BulkData): - valuesToInsert = [] - result = "ok" - # rows_before = {} - # rows_after = {} - try: - # rowCounter = 0 - # dbResult = db.execute( db.tables["energy"].select() ) - # for row in dbResult.fetchall(): - # rows_before[f"row_{rowCounter}"] = str(row) - # rowCounter += 1 - + valuesToInsert = [] for channelData in bulkData.bulk: if channelData.channel_id is None: - try: - table_channels = db.table("channels") - channel_id = db.execute( - sqlalchemy.select([table_channels.c.id]) \ - .select_from(table_channels) \ - .where(table_channels.c.name == channelData.channel)) - except: - raise HTTPException( - status_code = 500, - detail = f"Database error: {type(ex)} - \"{ex}\"" - ) - for measurement in channelData.data: - valuesToInsert.append({ - "channel_id": channelData.channel_id, - "timestamp": measurement.timestamp, - "value": measurement.value - }) - db.execute(db.table("energy").insert(), valuesToInsert) - - # rowCounter = 0 - # dbResult = db.execute( db.tables["energy"].select() ) - # for row in dbResult.fetchall(): - # rows_after[f"row_{rowCounter}"] = str(row) - # rowCounter += 1 + channel_id = db.getChannelId(channelData.channel) + else: + channel_id = channelData.channel_id + values = [] + for v in channelData.data: + values.append({"timestamp": v.timestamp, "value": v.value}) + db.addChannelData(channel_id, values) + return except Exception as e: - result = f"Exception \"{str(e)}\"" - return { - "result": result, - # "rows_before": rows_before, - # "rows_after": rows_after, - } + raise HTTPException(status_code=500, detail=f"Internal error: \"{str(e)}\"") @app.put(_restApiPath("/channels")) async def putChannels(channel_info: Channels): diff --git a/test/test_db.py b/test/test_db.py new file mode 100644 index 0000000..10a762e --- /dev/null +++ b/test/test_db.py @@ -0,0 +1,86 @@ +import pytest + +from datetime import datetime +import os + +# os.environ["DATABASE_URL"] = "sqlite:///./testDB.sqlite" +os.environ["DATABASE_URL"] = "sqlite://" + +from srv import db + +class Test_db: + _DB_URL = "sqlite://" + + def setup(self): + self._db = db.EnergyDB(self._DB_URL) + + def teardown(self): + pass + + # --- helper functions + + def _clearTable(self, tableName : str): + self._db.execute(f"DELETE FROM {tableName}") + + # --- test functions + + def test_url(self): + assert self._db.url() == self._DB_URL + + def test_table(self): + assert self._db.table("energy") == self._db._tables["energy"] + assert self._db.table("channels") == self._db._tables["channels"] + + def _test_table_unknownTable(self): + pass + + def test_getChannels_emptyDatabase(self): + channels = self._db.getChannels() + assert channels["channels"] == [] + + def test_addChannels(self): + self._db.addChannels(["abc", "def", "ghi"]) + result = self._db.getChannels() + assert type(result) == dict + channels = result["channels"] + assert channels[0] == {"name": "abc"} + assert channels[1] == {"name": "def"} + assert channels[2] == {"name": "ghi"} + + def test_getChannelId(self): + self._db.addChannels(["abc", "def", "ghi"]) + assert self._db.getChannelId("abc") == 1 + assert self._db.getChannelId("def") == 2 + assert self._db.getChannelId("ghi") == 3 + + def test_getChannelId_ExceptionIfChannelIsUnknown(self): + with pytest.raises(Exception): + assert self._db.getChannelId("jkl") == 0 + + def test_getChannelData_EmptyDatabase(self): + fromTime = datetime.now() + tillTime = datetime.now() + result = self._db.getChannelData([1], fromTime, tillTime) + assert list(result.keys()) == [1] + assert result[1] == [] + + def test_addChannelData(self): + data = [ + {"timestamp": datetime.fromisoformat("2020-12-12T09:00:01"), "value": 900.01}, + {"timestamp": datetime.fromisoformat("2020-12-12T09:05:02"), "value": 905.02}, + {"timestamp": datetime.fromisoformat("2020-12-12T09:10:03"), "value": 910.03}, + ] + self._db.addChannelData(8, data) + result = self._db.getChannelData( + [8], + datetime.fromisoformat("2020-12-12T09:00:00"), + datetime.now() + ) + assert isinstance(result, dict) + assert len(result) == 1 + assert 8 in result + channelData = result[8] + assert len(channelData) == 3 + assert channelData[0] == data[0] + assert channelData[1] == data[1] + assert channelData[2] == data[2] diff --git a/test/test_EnergyDB.py b/test/test_energyDB.py similarity index 83% rename from test/test_EnergyDB.py rename to test/test_energyDB.py index f5ef774..ee02efd 100644 --- a/test/test_EnergyDB.py +++ b/test/test_energyDB.py @@ -119,13 +119,13 @@ class Test_energyDB: # self._dumpRequestAndResponse("test_getInfo", response) assert response.status_code == 200 - def _test_getChannelsOfEmptyTable(self): + def test_getChannelsOfEmptyTable(self): response = self.client.get(self._apiUrl("/channels")) # self._dumpRequestAndResponse("test_getChannelsOfEmptyTable", response) assert response.status_code == 200 assert response.json()["channels"] == [] - def _test_getBulkDataOfEmptyTable(self): + def test_getBulkDataOfEmptyTable(self): response = self.client.get( self._apiUrl("/bulkData"), json = { @@ -133,17 +133,16 @@ class Test_energyDB: "fromTime": "0001-01-01T00:00:00" } ) - self._dumpRequestAndResponse("test_getBulkDataOfEmptyTable", response) + # self._dumpRequestAndResponse("test_getBulkDataOfEmptyTable", response) assert response.status_code == 200 assert "bulk" in response.json() - assert response.json()["bulk"] == [ - { - "channel_id": 1, - "channel": None, - "data": [], - "msg": None - } - ] + bulkData = response.json()["bulk"] + assert len(bulkData) == 1 + assert "channel_id" in bulkData[0] + assert "data" in bulkData[0] + assert bulkData[0]["channel_id"] == 1 + channelData = bulkData[0]["data"] + assert len(channelData) == 0 def test_fillDatabase(self): self._fillDatabase() @@ -185,20 +184,62 @@ class Test_energyDB: # self._dumpRequestAndResponse("test_getChannelId", response) assert response.status_code == 200 - def _test_putBulkData(self): + def test_getBulkData(self): + response = self.client.get( + self._apiUrl("/bulkData"), + json = { + "channel_ids": [1], + "fromTime": "0001-01-01T00:00:00" + } + ) + # self._dumpRequestAndResponse("test_getBulkData", response) + assert response.status_code == 200 + assert "bulk" in response.json() + bulkData = response.json()["bulk"] + assert len(bulkData) == 1 + channelData = bulkData[0] + assert channelData["channel_id"] == 1 + referenceData = self.testData["bulkdata"][0] + assert len(channelData["data"]) == len(referenceData["data"]) + assert channelData["data"][0] == referenceData["data"][0] + assert channelData["data"][1] == referenceData["data"][1] + assert channelData["data"][2] == referenceData["data"][2] + assert channelData["data"][3] == referenceData["data"][3] + assert channelData["data"][4] == referenceData["data"][4] + assert channelData["data"][5] == referenceData["data"][5] + assert channelData["data"][6] == referenceData["data"][6] + assert channelData["data"][7] == referenceData["data"][7] + assert channelData["data"][8] == referenceData["data"][8] + assert channelData["data"][9] == referenceData["data"][9] + + def test_putBulkData(self): + newData = [{ + "channel_id": None, + "channel": "total_yield", + "data": [ + { "timestamp": "2020-12-11T12:01:20", "value": 120120.1 }, + { "timestamp": "2020-12-11T12:30:25", "value": 123025.2 }, + ] + }] response = self.client.put( self._apiUrl("/bulkData"), - json = {"bulk": { - "channel_id": None, - "channel": "total_yield", - "data": [ - { "timestamp": "2020-12-11T12:01:20", "value": 120120.1 }, - { "timestamp": "2020-12-11T12:30:25", "value": 123025.2 }, - ] - }} + json = {"bulk": newData} ) - self._dumpRequestAndResponse("test_putBulkData", response) + # self._dumpRequestAndResponse("test_putBulkData", response) assert response.status_code == 200 + response = self.client.get( + self._apiUrl("/bulkData"), + json = { + "channel_ids": [3], + "fromTime": "2020-12-11T12:00:00", + "tillTime": "2020-12-11T12:59:59" + } + ) + channelData = response.json()["bulk"][0] + assert channelData["channel_id"] == 3 + assert len(channelData["data"]) == 2 + assert channelData["data"][0] == newData[0]["data"][0] + assert channelData["data"][1] == newData[0]["data"][1] # def test_getRecordCount(self): # response = self.client.get(self._apiUrl("/1/count"))