Coverage for /usr/local/lib/python3.11/site-packages/twinpad_backend/db.py: 92%

63 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-09-25 07:17 +0000

import logging
import os
import time
from urllib.parse import quote_plus

from pymongo import MongoClient, AsyncMongoClient, errors

6 

# Connection settings are read from the environment; the fallbacks below are
# used when a variable is not set.
MONGO_USERNAME = os.environ.get("MONGO_USERNAME", "mongo")
MONGO_PASSWORD = os.environ.get("MONGO_PASSWORD", "change_password")
MONGO_HOST = os.environ.get("MONGO_HOST", "localhost")

# Retry policy used by connect_to_mongo(): number of attempts and the pause
# (in seconds, passed to time.sleep) between two attempts.
MAX_RETRIES = 5
RETRY_DELAY = 10

# The username and password are percent-escaped before being embedded in the
# URI: reserved characters such as "@", ":", "/" or "%" in a credential would
# otherwise make the URI invalid or be decoded to a different value.
# NOTE(review): credentials must be supplied in the environment in raw
# (unescaped) form — confirm no deployment pre-escapes them.
MONGO_CONN_STRING = (
    f"mongodb://{quote_plus(MONGO_USERNAME)}:{quote_plus(MONGO_PASSWORD)}"
    f"@{MONGO_HOST}/?retryWrites=true&w=majority"
)

logger = logging.getLogger("uvicorn.error")

# Only the username and host are logged; the password is deliberately left out.
logger.info("Connecting to mongo database: %s @ %s", MONGO_USERNAME, MONGO_HOST)

19 

20 

def connect_to_mongo():
    """Open a synchronous MongoDB connection, retrying on connection failure.

    Up to ``MAX_RETRIES`` attempts are made. Each attempt builds a
    ``MongoClient`` from ``MONGO_CONN_STRING`` and validates it with a
    ``ping`` command; between failed attempts the function sleeps for
    ``RETRY_DELAY`` seconds.

    Returns:
        The connected ``MongoClient``, or ``None`` when every attempt raised
        ``errors.ConnectionFailure``.
    """
    for attempt in range(1, MAX_RETRIES + 1):
        client = None
        try:
            client = MongoClient(MONGO_CONN_STRING)
            client.admin.command("ping")
        except errors.ConnectionFailure as e:  # pragma: no cover
            # Release the client of the failed attempt: the next iteration
            # builds a fresh one, so this one would otherwise be abandoned
            # without ever being closed.
            if client is not None:
                client.close()
            logger.info("Attempt %d/%d failed: %s", attempt, MAX_RETRIES, str(e))
            if attempt < MAX_RETRIES:
                logger.info("New attempt in %d seconds...", RETRY_DELAY)
                time.sleep(RETRY_DELAY)
        else:
            logger.info("Database connected")
            return client
    logger.info("Cannot connect to database after %d retries", MAX_RETRIES)  # pragma: no cover
    return None  # pragma: no cover

37 

38 

client = connect_to_mongo()
if client is None:  # pragma: no cover
    # connect_to_mongo() returns None once all retries are exhausted. Fail
    # here with an explicit error instead of the AttributeError that
    # ``client.signals`` below would raise on None.
    raise errors.ConnectionFailure(
        f"Cannot connect to mongo database at {MONGO_HOST} after {MAX_RETRIES} attempts"
    )
async_client = AsyncMongoClient(MONGO_CONN_STRING)

# One synchronous and one asynchronous handle per logical database.
signals_database = client.signals
signals_async_database = async_client.signals
systems_database = client.systems
systems_async_database = async_client.systems
devices_states_database = client.devices_states
devices_states_async_database = async_client.devices_states

# Collection handles already resolved by get_collection() and
# get_async_collection(), keyed by (database, collection_name).
_collection_cache = {}

51 

52 

def get_collection(database, collection_name: str, create: bool = False, time_series: bool = False):
    """Return the collection ``collection_name`` of ``database``, optionally creating it.

    Handles are memoized in ``_collection_cache`` under the key
    ``(database, collection_name)``, so the server is only queried the first
    time a given collection is requested.

    Args:
        database: Synchronous database handle to look the collection up in.
        collection_name: Name of the collection.
        create: When true, create the collection if it does not exist yet.
        time_series: When creating, create a time-series collection whose
            time field is ``"timestamp"``.

    Returns:
        The collection handle, or ``None`` when the collection does not exist
        and ``create`` is false.
    """
    cache_key = (database, collection_name)
    if cache_key in _collection_cache:
        return _collection_cache[cache_key]

    existing_names = database.list_collection_names()
    if collection_name in existing_names:
        handle = database[collection_name]
    elif create:
        # Time-series collections only differ by the extra creation option.
        options = {"timeseries": {"timeField": "timestamp"}} if time_series else {}
        handle = database.create_collection(collection_name, **options)
    else:
        return None

    _collection_cache[cache_key] = handle
    return handle

71 

72 

async def get_async_collection(database, collection_name: str, create: bool = False, time_series: bool = False):
    """Asynchronously return the collection ``collection_name`` of ``database``.

    Handles are memoized in ``_collection_cache`` under the key
    ``(database, collection_name)``; a cached handle is returned without
    awaiting anything.

    Args:
        database: Asynchronous database handle to look the collection up in.
        collection_name: Name of the collection.
        create: When true, create the collection if it does not exist yet.
        time_series: When creating, create a time-series collection whose
            time field is ``"timestamp"``.

    Returns:
        The collection handle, or ``None`` when the collection does not exist
        and ``create`` is false.
    """
    cache_key = (database, collection_name)
    if cache_key in _collection_cache:
        return _collection_cache[cache_key]

    existing_names = await database.list_collection_names()
    if collection_name in existing_names:
        handle = database[collection_name]
    elif create:
        # Time-series collections only differ by the extra creation option.
        options = {"timeseries": {"timeField": "timestamp"}} if time_series else {}
        handle = await database.create_collection(collection_name, **options)
    else:
        return None

    _collection_cache[cache_key] = handle
    return handle

91 

92 

def get_signal_collection(signal_id: str, create=False):
    """Return the collection named ``signal_id`` from the signals database.

    Thin wrapper around ``get_collection`` bound to ``signals_database``;
    ``create`` is forwarded unchanged and the result is returned as is.
    """
    signal_collection = get_collection(signals_database, signal_id, create=create)
    return signal_collection

95 

96 

def get_signals_ids_from_collection_names():
    """List signal ids by reading the signals database collection names.

    This is a low level function: every collection name that does not start
    with ``"system."`` is reported, and the result is sorted.
    """
    signal_ids = [
        name
        for name in signals_database.list_collection_names()
        if not name.startswith("system.")
    ]
    signal_ids.sort()
    return signal_ids

102 

103 

104# def get_signals(): 

105# return [signal for signal in get_collection(systems_database, "signals", create=True).find({})] 

106 

107# def get_signals_ids(): 

108# return [signal['signal_id'] for signal in get_signals()] 

109 

110 

def signal_datasize():
    """Return the ``dataSize`` entry of the ``dbstats`` command run on the signals database."""
    db_stats = signals_database.command("dbstats")
    return db_stats["dataSize"]