Move all DB function to db.py + add testsuite for db.py
This commit is contained in:
parent
2d5eae04ea
commit
b8f170688e
4 changed files with 219 additions and 88 deletions
54
srv/db.py
54
srv/db.py
|
|
@ -1,5 +1,6 @@
|
|||
import datetime
|
||||
import os
|
||||
from typing import List
|
||||
from typing import List, Dict
|
||||
|
||||
from databases import Database
|
||||
from sqlalchemy import (Column, DateTime, Integer, Float, String,
|
||||
|
|
@ -36,8 +37,8 @@ class EnergyDB():
|
|||
def metadata(self):
|
||||
return self._metadata
|
||||
|
||||
def url(self):
|
||||
return self._engine.url
|
||||
def url(self) -> str:
|
||||
return str(self._engine.url)
|
||||
|
||||
def engine(self):
|
||||
return self._engine
|
||||
|
|
@ -73,7 +74,7 @@ class EnergyDB():
|
|||
def getChannels(self) -> dict:
|
||||
try:
|
||||
table_channels = self.table("channels")
|
||||
query = sqlalchemy.select([table_channels.c.name, table_channels.c.id]).select_from(table_channels)
|
||||
query = sqlalchemy.select([table_channels.c.name]).select_from(table_channels)
|
||||
channels = [dict(r.items()) for r in self.execute(query).fetchall()]
|
||||
except Exception as e:
|
||||
raise Exception(f"Database error in getChannels(): {type(e)} - {str(e)}")
|
||||
|
|
@ -91,3 +92,48 @@ class EnergyDB():
|
|||
if chId is None:
|
||||
raise Exception(f"Database error in getChannelId(): channel '{channelName}' not found")
|
||||
return chId
|
||||
|
||||
def getChannelData(self, channelIds : List[int], fromTime : datetime.datetime, tillTime : datetime.datetime) -> Dict[int, list]:
|
||||
try:
|
||||
chData = {}
|
||||
query = sqlalchemy.sql.text(
|
||||
"""SELECT timestamp, value FROM energy
|
||||
WHERE channel_id == :channel_id
|
||||
AND timestamp >= :fromTime
|
||||
AND timestamp <= :tillTime
|
||||
ORDER BY timestamp"""
|
||||
)
|
||||
for ch in channelIds:
|
||||
result = self.execute(query, channel_id = ch, fromTime = fromTime, tillTime = tillTime)
|
||||
data = [{"timestamp": datetime.datetime.fromisoformat(row[0]), "value": row[1]} for row in result.fetchall() ]
|
||||
chData[ch] = data
|
||||
return chData
|
||||
except Exception as e:
|
||||
raise Exception(f"Database error in getChannelData(): {type(e)} - {str(e)}")
|
||||
|
||||
def addChannelData(self, channel_id : int, data : List[Dict[datetime.datetime, float]]):
|
||||
try:
|
||||
# query = sqlalchemy.sql.text(
|
||||
# "INSERT INTO energy (channel_id, timestamp, value) VALUES "
|
||||
# )
|
||||
queryStr = "INSERT INTO energy (channel_id, timestamp, value) VALUES "
|
||||
valueStr = ""
|
||||
for d in data:
|
||||
timestamp = d["timestamp"]
|
||||
value = d["value"]
|
||||
# self.execute(
|
||||
# # query,
|
||||
# # channel_id=channel_id,
|
||||
# # timestamp = d["timestamp"],
|
||||
# # value=d["value"]
|
||||
# sqlalchemy.sql.text(
|
||||
# f"""INSERT INTO energy (channel_id, timestamp, value)
|
||||
# VALUES ('{channel_id}', '{timestamp}', '{value}')"""
|
||||
# )
|
||||
# )
|
||||
if valueStr != "":
|
||||
valueStr += ",\n"
|
||||
valueStr += f"({channel_id}, '{timestamp}', {value})"
|
||||
self.execute(sqlalchemy.sql.text(queryStr + valueStr + ";"))
|
||||
except Exception as e:
|
||||
raise Exception(f"Database error in addChannelData(): {type(e)} - {str(e)}")
|
||||
|
|
|
|||
|
|
@ -4,8 +4,6 @@ from pydantic import BaseModel
|
|||
import datetime
|
||||
import os
|
||||
|
||||
# import sqlalchemy
|
||||
|
||||
from .db import EnergyDB
|
||||
|
||||
import logging
|
||||
|
|
@ -14,9 +12,9 @@ API_VERSION_MAJOR = 1
|
|||
API_VERSION_MINOR = 0
|
||||
|
||||
DB_URL = os.getenv("DATABASE_URL") #, default="sqlite://")
|
||||
if len(DB_URL) == 0:
|
||||
if DB_URL is None or len(DB_URL) == 0:
|
||||
raise Exception("Environment variable DATABASE_URL missed!")
|
||||
print(f"DB URL: {DB_URL}")
|
||||
# print(f"DB URL: {DB_URL}")
|
||||
db = EnergyDB(DB_URL)
|
||||
|
||||
app = FastAPI(debug=True)
|
||||
|
|
@ -109,81 +107,41 @@ async def apiGetInfo():
|
|||
@app.get(_restApiPath("/bulkData"), response_model = BulkData)
|
||||
async def getBulkEnergyData(bulkDataRequest: BulkDataRequest):
|
||||
bulkData = []
|
||||
trace = []
|
||||
exception = None
|
||||
try:
|
||||
for ch in bulkDataRequest.channel_ids:
|
||||
data = []
|
||||
table_energy = db.table("energy")
|
||||
query = sqlalchemy.select([table_energy.c.timestamp, table_energy.c.value]) \
|
||||
.select_from(table_energy) \
|
||||
.where(sqlalchemy.sql.and_(
|
||||
table_energy.c.channel_id == ch,
|
||||
table_energy.c.timestamp >= bulkDataRequest.fromTime,
|
||||
table_energy.c.timestamp <= bulkDataRequest.tillTime
|
||||
)
|
||||
)
|
||||
for row in db.execute(query).fetchall():
|
||||
data.append(dict(row.items()))
|
||||
bulkData.append({"channel_id": ch, "data": data})
|
||||
chData = db.getChannelData(bulkDataRequest.channel_ids, bulkDataRequest.fromTime, bulkDataRequest.tillTime)
|
||||
for chId in chData:
|
||||
bulkData.append({
|
||||
"channel_id": chId,
|
||||
"channel": None,
|
||||
"data": chData[chId],
|
||||
"msg": str(chData),
|
||||
})
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Database error: {type(e)} - {str(e)}"
|
||||
# detail=f"Database error: {str(e)}\nQuery: {str(query)}"
|
||||
)
|
||||
return {
|
||||
"bulk": bulkData,
|
||||
"msg": None #__name__ + " - " + str(query)
|
||||
"msg": str(chData.keys())
|
||||
}
|
||||
|
||||
@app.put(_restApiPath("/bulkData"))
|
||||
async def putBulkEnergyData(bulkData: BulkData):
|
||||
valuesToInsert = []
|
||||
result = "ok"
|
||||
# rows_before = {}
|
||||
# rows_after = {}
|
||||
|
||||
try:
|
||||
# rowCounter = 0
|
||||
# dbResult = db.execute( db.tables["energy"].select() )
|
||||
# for row in dbResult.fetchall():
|
||||
# rows_before[f"row_{rowCounter}"] = str(row)
|
||||
# rowCounter += 1
|
||||
|
||||
valuesToInsert = []
|
||||
for channelData in bulkData.bulk:
|
||||
if channelData.channel_id is None:
|
||||
try:
|
||||
table_channels = db.table("channels")
|
||||
channel_id = db.execute(
|
||||
sqlalchemy.select([table_channels.c.id]) \
|
||||
.select_from(table_channels) \
|
||||
.where(table_channels.c.name == channelData.channel))
|
||||
except:
|
||||
raise HTTPException(
|
||||
status_code = 500,
|
||||
detail = f"Database error: {type(ex)} - \"{ex}\""
|
||||
)
|
||||
for measurement in channelData.data:
|
||||
valuesToInsert.append({
|
||||
"channel_id": channelData.channel_id,
|
||||
"timestamp": measurement.timestamp,
|
||||
"value": measurement.value
|
||||
})
|
||||
db.execute(db.table("energy").insert(), valuesToInsert)
|
||||
|
||||
# rowCounter = 0
|
||||
# dbResult = db.execute( db.tables["energy"].select() )
|
||||
# for row in dbResult.fetchall():
|
||||
# rows_after[f"row_{rowCounter}"] = str(row)
|
||||
# rowCounter += 1
|
||||
channel_id = db.getChannelId(channelData.channel)
|
||||
else:
|
||||
channel_id = channelData.channel_id
|
||||
values = []
|
||||
for v in channelData.data:
|
||||
values.append({"timestamp": v.timestamp, "value": v.value})
|
||||
db.addChannelData(channel_id, values)
|
||||
return
|
||||
except Exception as e:
|
||||
result = f"Exception \"{str(e)}\""
|
||||
return {
|
||||
"result": result,
|
||||
# "rows_before": rows_before,
|
||||
# "rows_after": rows_after,
|
||||
}
|
||||
raise HTTPException(status_code=500, detail=f"Internal error: \"{str(e)}\"")
|
||||
|
||||
@app.put(_restApiPath("/channels"))
|
||||
async def putChannels(channel_info: Channels):
|
||||
|
|
|
|||
86
test/test_db.py
Normal file
86
test/test_db.py
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
import pytest
|
||||
|
||||
from datetime import datetime
|
||||
import os
|
||||
|
||||
# os.environ["DATABASE_URL"] = "sqlite:///./testDB.sqlite"
|
||||
os.environ["DATABASE_URL"] = "sqlite://"
|
||||
|
||||
from srv import db
|
||||
|
||||
class Test_db:
|
||||
_DB_URL = "sqlite://"
|
||||
|
||||
def setup(self):
|
||||
self._db = db.EnergyDB(self._DB_URL)
|
||||
|
||||
def teardown(self):
|
||||
pass
|
||||
|
||||
# --- helper functions
|
||||
|
||||
def _clearTable(self, tableName : str):
|
||||
self._db.execute(f"DELETE FROM {tableName}")
|
||||
|
||||
# --- test functions
|
||||
|
||||
def test_url(self):
|
||||
assert self._db.url() == self._DB_URL
|
||||
|
||||
def test_table(self):
|
||||
assert self._db.table("energy") == self._db._tables["energy"]
|
||||
assert self._db.table("channels") == self._db._tables["channels"]
|
||||
|
||||
def _test_table_unknownTable(self):
|
||||
pass
|
||||
|
||||
def test_getChannels_emptyDatabase(self):
|
||||
channels = self._db.getChannels()
|
||||
assert channels["channels"] == []
|
||||
|
||||
def test_addChannels(self):
|
||||
self._db.addChannels(["abc", "def", "ghi"])
|
||||
result = self._db.getChannels()
|
||||
assert type(result) == dict
|
||||
channels = result["channels"]
|
||||
assert channels[0] == {"name": "abc"}
|
||||
assert channels[1] == {"name": "def"}
|
||||
assert channels[2] == {"name": "ghi"}
|
||||
|
||||
def test_getChannelId(self):
|
||||
self._db.addChannels(["abc", "def", "ghi"])
|
||||
assert self._db.getChannelId("abc") == 1
|
||||
assert self._db.getChannelId("def") == 2
|
||||
assert self._db.getChannelId("ghi") == 3
|
||||
|
||||
def test_getChannelId_ExceptionIfChannelIsUnknown(self):
|
||||
with pytest.raises(Exception):
|
||||
assert self._db.getChannelId("jkl") == 0
|
||||
|
||||
def test_getChannelData_EmptyDatabase(self):
|
||||
fromTime = datetime.now()
|
||||
tillTime = datetime.now()
|
||||
result = self._db.getChannelData([1], fromTime, tillTime)
|
||||
assert list(result.keys()) == [1]
|
||||
assert result[1] == []
|
||||
|
||||
def test_addChannelData(self):
|
||||
data = [
|
||||
{"timestamp": datetime.fromisoformat("2020-12-12T09:00:01"), "value": 900.01},
|
||||
{"timestamp": datetime.fromisoformat("2020-12-12T09:05:02"), "value": 905.02},
|
||||
{"timestamp": datetime.fromisoformat("2020-12-12T09:10:03"), "value": 910.03},
|
||||
]
|
||||
self._db.addChannelData(8, data)
|
||||
result = self._db.getChannelData(
|
||||
[8],
|
||||
datetime.fromisoformat("2020-12-12T09:00:00"),
|
||||
datetime.now()
|
||||
)
|
||||
assert isinstance(result, dict)
|
||||
assert len(result) == 1
|
||||
assert 8 in result
|
||||
channelData = result[8]
|
||||
assert len(channelData) == 3
|
||||
assert channelData[0] == data[0]
|
||||
assert channelData[1] == data[1]
|
||||
assert channelData[2] == data[2]
|
||||
|
|
@ -119,13 +119,13 @@ class Test_energyDB:
|
|||
# self._dumpRequestAndResponse("test_getInfo", response)
|
||||
assert response.status_code == 200
|
||||
|
||||
def _test_getChannelsOfEmptyTable(self):
|
||||
def test_getChannelsOfEmptyTable(self):
|
||||
response = self.client.get(self._apiUrl("/channels"))
|
||||
# self._dumpRequestAndResponse("test_getChannelsOfEmptyTable", response)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["channels"] == []
|
||||
|
||||
def _test_getBulkDataOfEmptyTable(self):
|
||||
def test_getBulkDataOfEmptyTable(self):
|
||||
response = self.client.get(
|
||||
self._apiUrl("/bulkData"),
|
||||
json = {
|
||||
|
|
@ -133,17 +133,16 @@ class Test_energyDB:
|
|||
"fromTime": "0001-01-01T00:00:00"
|
||||
}
|
||||
)
|
||||
self._dumpRequestAndResponse("test_getBulkDataOfEmptyTable", response)
|
||||
# self._dumpRequestAndResponse("test_getBulkDataOfEmptyTable", response)
|
||||
assert response.status_code == 200
|
||||
assert "bulk" in response.json()
|
||||
assert response.json()["bulk"] == [
|
||||
{
|
||||
"channel_id": 1,
|
||||
"channel": None,
|
||||
"data": [],
|
||||
"msg": None
|
||||
}
|
||||
]
|
||||
bulkData = response.json()["bulk"]
|
||||
assert len(bulkData) == 1
|
||||
assert "channel_id" in bulkData[0]
|
||||
assert "data" in bulkData[0]
|
||||
assert bulkData[0]["channel_id"] == 1
|
||||
channelData = bulkData[0]["data"]
|
||||
assert len(channelData) == 0
|
||||
|
||||
def test_fillDatabase(self):
|
||||
self._fillDatabase()
|
||||
|
|
@ -185,20 +184,62 @@ class Test_energyDB:
|
|||
# self._dumpRequestAndResponse("test_getChannelId", response)
|
||||
assert response.status_code == 200
|
||||
|
||||
def _test_putBulkData(self):
|
||||
response = self.client.put(
|
||||
def test_getBulkData(self):
|
||||
response = self.client.get(
|
||||
self._apiUrl("/bulkData"),
|
||||
json = {"bulk": {
|
||||
json = {
|
||||
"channel_ids": [1],
|
||||
"fromTime": "0001-01-01T00:00:00"
|
||||
}
|
||||
)
|
||||
# self._dumpRequestAndResponse("test_getBulkData", response)
|
||||
assert response.status_code == 200
|
||||
assert "bulk" in response.json()
|
||||
bulkData = response.json()["bulk"]
|
||||
assert len(bulkData) == 1
|
||||
channelData = bulkData[0]
|
||||
assert channelData["channel_id"] == 1
|
||||
referenceData = self.testData["bulkdata"][0]
|
||||
assert len(channelData["data"]) == len(referenceData["data"])
|
||||
assert channelData["data"][0] == referenceData["data"][0]
|
||||
assert channelData["data"][1] == referenceData["data"][1]
|
||||
assert channelData["data"][2] == referenceData["data"][2]
|
||||
assert channelData["data"][3] == referenceData["data"][3]
|
||||
assert channelData["data"][4] == referenceData["data"][4]
|
||||
assert channelData["data"][5] == referenceData["data"][5]
|
||||
assert channelData["data"][6] == referenceData["data"][6]
|
||||
assert channelData["data"][7] == referenceData["data"][7]
|
||||
assert channelData["data"][8] == referenceData["data"][8]
|
||||
assert channelData["data"][9] == referenceData["data"][9]
|
||||
|
||||
def test_putBulkData(self):
|
||||
newData = [{
|
||||
"channel_id": None,
|
||||
"channel": "total_yield",
|
||||
"data": [
|
||||
{ "timestamp": "2020-12-11T12:01:20", "value": 120120.1 },
|
||||
{ "timestamp": "2020-12-11T12:30:25", "value": 123025.2 },
|
||||
]
|
||||
}}
|
||||
}]
|
||||
response = self.client.put(
|
||||
self._apiUrl("/bulkData"),
|
||||
json = {"bulk": newData}
|
||||
)
|
||||
self._dumpRequestAndResponse("test_putBulkData", response)
|
||||
# self._dumpRequestAndResponse("test_putBulkData", response)
|
||||
assert response.status_code == 200
|
||||
response = self.client.get(
|
||||
self._apiUrl("/bulkData"),
|
||||
json = {
|
||||
"channel_ids": [3],
|
||||
"fromTime": "2020-12-11T12:00:00",
|
||||
"tillTime": "2020-12-11T12:59:59"
|
||||
}
|
||||
)
|
||||
channelData = response.json()["bulk"][0]
|
||||
assert channelData["channel_id"] == 3
|
||||
assert len(channelData["data"]) == 2
|
||||
assert channelData["data"][0] == newData[0]["data"][0]
|
||||
assert channelData["data"][1] == newData[0]["data"][1]
|
||||
|
||||
# def test_getRecordCount(self):
|
||||
# response = self.client.get(self._apiUrl("/1/count"))
|
||||
Loading…
Add table
Reference in a new issue