Coverage for /usr/local/lib/python3.11/site-packages/twinpad_backend/db.py: 81%

53 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-06-03 07:30 +0000

1import os 

2import logging 

3import time 

4 

5from pymongo import MongoClient 

6 

MONGO_USERNAME = os.environ.get("MONGO_USERNAME", "mongo")
MONGO_PASSWORD = os.environ.get("MONGO_PASSWORD", "change_password")
MONGO_HOST = os.environ.get("MONGO_HOST", "localhost")

# Retry policy used when establishing the initial connection.
MAX_RETRIES = 5
RETRY_DELAY = 10  # seconds to wait between connection attempts


def build_mongo_conn_string(username: str, password: str, host: str) -> str:
    """Return a ``mongodb://`` connection URI for *username*/*password* at *host*.

    The credentials are percent-escaped with ``urllib.parse.quote_plus``, as
    the MongoDB connection-string format requires. Without escaping, a
    password containing ``@``, ``:``, ``/`` or ``%`` would corrupt the URI.
    NOTE(review): credentials that were already percent-encoded in the
    environment will now be double-encoded; supply them raw.
    """
    from urllib.parse import quote_plus

    return (
        f"mongodb://{quote_plus(username)}:{quote_plus(password)}@{host}"
        "/?retryWrites=true&w=majority"
    )


MONGO_CONN_STRING = build_mongo_conn_string(MONGO_USERNAME, MONGO_PASSWORD, MONGO_HOST)

logger = logging.getLogger("uvicorn.error")

# Only the username and host are logged; the password is deliberately omitted.
logger.info("Connecting to mongo database: %s @ %s", MONGO_USERNAME, MONGO_HOST)

19 

20 

def connect_to_mongo():
    """Connect to MongoDB, retrying on failure.

    Makes up to ``MAX_RETRIES`` attempts and sleeps ``RETRY_DELAY`` seconds
    between consecutive attempts. Each attempt builds a ``MongoClient`` from
    ``MONGO_CONN_STRING`` and sends an ``admin`` ``ping``. The ping forces a
    real round trip, so connectivity and authentication problems surface here.

    Returns:
        The connected ``MongoClient``, or ``None`` if every attempt failed.
    """
    for attempt in range(1, MAX_RETRIES + 1):
        client = None
        try:
            client = MongoClient(MONGO_CONN_STRING)
            client.admin.command("ping")
        except Exception as e:  # any failure (DNS, auth, timeout, ...) triggers a retry
            # Close the failed client so each retry does not leak its resources.
            if client is not None:
                client.close()
            logger.warning("Attempt %d/%d failed: %s", attempt, MAX_RETRIES, e)
            if attempt < MAX_RETRIES:
                logger.info("New attempt in %s seconds...", RETRY_DELAY)
                time.sleep(RETRY_DELAY)
        else:
            logger.info("Database connected")
            return client

    logger.error("Cannot connect to database after %d retries", MAX_RETRIES)
    return None

38 

39 

client = connect_to_mongo()
if client is None:
    # Fail fast with an explicit message. Otherwise the attribute accesses
    # below would raise an opaque AttributeError on None at import time.
    raise RuntimeError(
        f"Could not connect to MongoDB at {MONGO_HOST} after {MAX_RETRIES} attempts"
    )

# Database handles shared by the rest of this module.
signals_database = client.signals
systems_database = client.systems
devices_states_database = client.devices_states

46 

# Process-wide memo of resolved collections, keyed by (database, collection_name).
_collection_cache = {}


def get_collection(
    database,
    collection_name: str,
    create: bool = False,
    time_series: bool = False,
    time_field: str = "timestamp",
):
    """Return collection *collection_name* of *database*, optionally creating it.

    Resolved collections are memoised in ``_collection_cache``, so repeat
    lookups skip the ``list_collection_names()`` round trip.

    Args:
        database: Database handle. It must be hashable because it is part of
            the cache key.
        collection_name: Name of the collection to look up.
        create: If true and the collection does not exist, create it.
        time_series: When creating, create a time-series collection.
        time_field: Field passed as the time-series ``timeField``. Only used
            when *create* and *time_series* are both true.

    Returns:
        The collection, or ``None`` when it does not exist and *create* is
        false. A ``None`` result is not cached, so a later call can still find
        a collection created in the meantime.
    """
    key = (database, collection_name)
    # Compare against None explicitly: some collection types refuse truth-value testing.
    cached = _collection_cache.get(key)
    if cached is not None:
        return cached

    if collection_name in database.list_collection_names():
        collection = database[collection_name]
    elif create:
        options = {"timeseries": {"timeField": time_field}} if time_series else {}
        collection = database.create_collection(collection_name, **options)
    else:
        return None

    _collection_cache[key] = collection
    return collection

68 

69 

def get_signal_collection(signal_id: str, create=False):
    """Return the collection holding signal *signal_id* in ``signals_database``.

    The collection name is the signal id itself. When *create* is true, a
    missing collection is created without time-series options. Otherwise a
    missing signal yields ``None``.
    """
    return get_collection(
        database=signals_database,
        collection_name=signal_id,
        create=create,
    )

72 

73 

def get_signals_ids_from_collection_names():
    """Return the signal ids, sorted, derived from ``signals_database`` collection names.

    This is a low-level helper: it reads collection names directly and skips
    any whose name starts with ``system.`` (MongoDB-reserved namespaces).
    """
    all_names = signals_database.list_collection_names()
    signal_ids = [name for name in all_names if not name.startswith("system.")]
    signal_ids.sort()
    return signal_ids

79 

80 

81# def get_signals(): 

82# return [signal for signal in get_collection(systems_database, "signals", create=True).find({})] 

83 

84# def get_signals_ids(): 

85# return [signal['signal_id'] for signal in get_signals()] 

86 

87 

def signal_datasize():
    """Return the ``dataSize`` field of ``dbstats`` run against ``signals_database``.

    NOTE(review): presumably bytes (the dbStats default scale); confirm.
    """
    stats = signals_database.command("dbstats")
    return stats["dataSize"]