|
3 | 3 | import typing
|
4 | 4 | from typing import Optional
|
5 | 5 | from typing import TextIO
|
| 6 | +from typing import Any |
| 7 | +from io import StringIO |
6 | 8 |
|
7 | 9 |
|
8 |
| -class ExtendedLogger(logging.Logger): |
9 |
| - def __init__(self, logger: logging.Logger): |
| 10 | +class ExtendedLogger: |
| 11 | + def __init__(self, logger: logging.Logger, byte_limit: int = 50 * 1024 * 1024): |
10 | 12 | self._wrapped_logger = logger
|
| 13 | + self._total_bytes = 0 |
| 14 | + self._byte_limit = byte_limit |
| 15 | + self._string_buffer = StringIO() |
11 | 16 |
|
12 |
| - def __getattribute__(self: 'ExtendedLogger', name: str) -> typing.Any: |
13 |
| - # ExtendedLogger is-a logging.Logger, but it delegates most calls to |
14 |
| - # the wrapped-logger (which is also a logging.Logger). |
15 |
| - if name == 'error_and_exit': |
| 17 | + def _process_logged_content(self) -> None: |
| 18 | + string_buffer = object.__getattribute__(self, '_string_buffer') |
| 19 | + logged_content = string_buffer.getvalue() |
| 20 | + if logged_content: |
| 21 | + total_bytes = object.__getattribute__(self, '_total_bytes') |
| 22 | + new_bytes = len(logged_content.encode('utf-8')) |
| 23 | + new_total = total_bytes + new_bytes |
| 24 | + object.__setattr__(self, '_total_bytes', new_total) |
| 25 | + |
| 26 | + byte_limit = object.__getattribute__(self, '_byte_limit') |
| 27 | + if new_total > byte_limit: |
| 28 | + wrapped_logger = object.__getattribute__(self, '_wrapped_logger') |
| 29 | + wrapped_logger.error(f"Log limit of {byte_limit} bytes exceeded") |
| 30 | + os._exit(-1) |
| 31 | + |
| 32 | + string_buffer.seek(0) |
| 33 | + string_buffer.truncate(0) |
| 34 | + |
| 35 | + def __getattribute__(self: 'ExtendedLogger', name: str) -> Any: |
| 36 | + if name in ('error_and_exit', '_wrapped_logger', '_total_bytes', '_byte_limit', '_string_buffer', '_process_logged_content'): |
16 | 37 | return object.__getattribute__(self, name)
|
| 38 | + if name in ('debug', 'info', 'warning', 'error', 'critical', 'exception'): |
| 39 | + logger = object.__getattribute__(self, '_wrapped_logger') |
| 40 | + original_method = logger.__getattribute__(name) |
| 41 | + |
| 42 | + def counting_wrapper(*args: Any, **kwargs: Any) -> Any: |
| 43 | + result = original_method(*args, **kwargs) |
| 44 | + self._process_logged_content() |
| 45 | + return result |
| 46 | + |
| 47 | + return counting_wrapper |
| 48 | + |
17 | 49 | logger = object.__getattribute__(self, '_wrapped_logger')
|
18 | 50 | return logger.__getattribute__(name)
|
19 | 51 |
|
@@ -49,7 +81,15 @@ def configure_logger(lvl: Optional[int] = None) -> ExtendedLogger:
|
49 | 81 | prev_handler = handler
|
50 | 82 | logger.addHandler(handler)
|
51 | 83 |
|
52 |
| - return ExtendedLogger(logger) |
| 84 | + extended_logger = ExtendedLogger(logger) |
| 85 | + |
| 86 | + # Add string handler to capture logs |
| 87 | + string_handler = logging.StreamHandler(extended_logger._string_buffer) |
| 88 | + string_handler.setLevel(lvl) |
| 89 | + string_handler.setFormatter(formatter) |
| 90 | + logger.addHandler(string_handler) |
| 91 | + |
| 92 | + return extended_logger |
53 | 93 |
|
54 | 94 |
|
55 | 95 | prev_handler: Optional['logging.StreamHandler[TextIO]'] = None
|
|
0 commit comments